<a href="https://colab.research.google.com/github/google-deepmind/disentangled_rnns/blob/main/disentangled_rnns/notebooks/train_single_disrnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install disentangled_rnns repo from github
!git clone https://github.com/google-deepmind/disentangled_rnns
%cd disentangled_rnns
!pip install .
%cd ..

# Import the things we need
import optax

from disentangled_rnns.library import rnn_utils
from disentangled_rnns.library import get_datasets
from disentangled_rnns.library import disrnn
from disentangled_rnns.library import plotting

# Define a dataset

In [None]:
# Build the dataset with get_q_learning_dataset, requesting 500 sessions.
n_sessions_total = 500
dataset = get_datasets.get_q_learning_dataset(n_sessions=n_sessions_total)

In [None]:
# Divide the full dataset into a training part and an evaluation part.
# NOTE(review): the second argument (2) presumably controls how sessions are
# apportioned between the two parts -- confirm against rnn_utils.split_dataset.
train_eval_pair = rnn_utils.split_dataset(dataset, 2)
dataset_train, dataset_eval = train_eval_pair

# Define and train the RNN

In [None]:
# Configuration object for the disentangled RNN. It collects the dataset
# interface, the network architecture, and the starting penalty values.
disrnn_config = disrnn.DisRnnConfig(
    # --- Dataset interface ---
    obs_size=2,  # Two observations: choice and reward
    output_size=2,  # Two outputs: choose left / choose right
    x_names=dataset.x_names,
    y_names=dataset.y_names,
    # --- Network architecture ---
    latent_size=5,
    update_net_n_units_per_layer=16,
    update_net_n_layers=4,
    choice_net_n_units_per_layer=4,
    choice_net_n_layers=2,
    activation='leaky_relu',
    # --- Penalties ---
    noiseless_mode=False,
    latent_penalty=1e-5,
    choice_net_latent_penalty=1e-5,
    update_net_obs_penalty=1e-5,
    update_net_latent_penalty=1e-5,
)

In [None]:
# INITIALIZE THE NETWORK
# Calling rnn_utils.train_network with n_steps=0 performs no training; it only
# creates the parameters and the optimizer state used by later cells.
params, opt_state, losses = rnn_utils.train_network(
    lambda: disrnn.HkDisentangledRNN(disrnn_config),
    dataset_train,
    dataset_eval,
    n_steps=0,
    loss="penalized_categorical",
    opt=optax.adam(1e-2),
)


In [None]:
# RUN THIS CELL AND THE ONES BELOW IT MANY TIMES
# Each run continues training the same network, because the current params and
# opt_state are passed back in and then overwritten with the updated values.
# The cells below plot what is going on inside the network.
# To start over from a freshly initialized network, re-run the cell above.
# Try adjusting the bottleneck penalties between runs to get a feel for how
# they affect training.
disrnn_config.choice_net_latent_penalty = 1e-3
disrnn_config.update_net_obs_penalty = 1e-3
disrnn_config.update_net_latent_penalty = 1e-3
disrnn_config.latent_penalty = 1e-2

# Number of training steps performed per run of this cell.
n_steps = 1000

params, opt_state, losses = rnn_utils.train_network(
    lambda: disrnn.HkDisentangledRNN(disrnn_config),
    dataset_train,
    dataset_eval,
    # Resume from the existing network state.
    params=params,
    opt_state=opt_state,
    # Optimization settings.
    opt=optax.adam(1e-3),
    loss="penalized_categorical",
    loss_param=1,
    n_steps=n_steps,
    do_plot=True,
)


In [None]:
# Visualize which bottlenecks are open and which are closed.
# The result is bound to a name so the cell does not echo its repr.
bottleneck_plot = plotting.plot_bottlenecks(params, disrnn_config)

In [None]:
# Visualize the learned choice rule.
# The result is bound to a name so the cell does not echo its repr.
choice_rule_plot = plotting.plot_choice_rule(params, disrnn_config)

In [None]:
# Visualize the learned update rules.
# The result is bound to a name so the cell does not echo its repr.
update_rules_plot = plotting.plot_update_rules(params, disrnn_config)