<a href="https://colab.research.google.com/github/google-deepmind/disentangled_rnns/blob/main/disentangled_rnns/notebooks/train_multisubject_disrnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install disentangled_rnns repo from github
# Clone the source into the current working directory.
!git clone https://github.com/google-deepmind/disentangled_rnns
# Enter the checkout so that `pip install .` installs from it.
%cd disentangled_rnns
!pip install .
# Return to the directory the notebook started in.
%cd ..

import optax
import matplotlib.pyplot as plt
import copy

from disentangled_rnns.library import rnn_utils
from disentangled_rnns.library import get_datasets
from disentangled_rnns.library import plotting
from disentangled_rnns.library import multisubject_disrnn

# Define a dataset

In [None]:
# One simulated Q-learning subject per learning rate (passed as `alpha`).
learning_rates = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

# Generate every subject's dataset with the same trial and session counts.
datasets = [
    get_datasets.get_q_learning_dataset(
        n_trials=200, n_sessions=300, alpha=learning_rate
    )
    for learning_rate in learning_rates
]

# Combine the per-subject datasets into a single multisubject dataset.
dataset_multisubj = multisubject_disrnn.dataset_list_to_multisubject(datasets)

In [None]:
# Split the multisubject dataset into a training set and an evaluation set.
# NOTE(review): the positional `2` presumably means every 2nd session is held
# out for evaluation — confirm against rnn_utils.split_dataset.
dataset_train, dataset_eval = rnn_utils.split_dataset(dataset_multisubj, 2)

# Define and train RNN

In [None]:
# Base configuration for the multisubject DisRNN. Later cells copy this object
# (for the noiseless warm-up) and mutate its penalty scales in place.
disrnn_config = multisubject_disrnn.MultisubjectDisRnnConfig(
    # Input / output sizes
    obs_size=2,
    output_size=2,
    # Latent and network sizes
    latent_size=5,
    update_net_n_units_per_layer=16,
    update_net_n_layers=4,
    choice_net_n_units_per_layer=4,
    choice_net_n_layers=2,
    noiseless_mode=False,
    # Subject embedding: one slot per simulated learning rate
    max_n_subjects=len(learning_rates),
    subject_embedding_size=2,
    # Penalty scales (raised in the continued-training cell)
    latent_penalty_scale=1e-5,
    choice_net_penalty_scale=1e-5,
    update_net_penalty_scale=1e-5,
    subject_embedding_penalty_scale=1e-5,
    activation='leaky_relu',
)


In [None]:
# Initial training in noiseless mode. The flag is set on a copy of the config
# so that `disrnn_config` itself keeps noiseless_mode=False.
disrnn_config_noiseless = copy.copy(disrnn_config)
disrnn_config_noiseless.noiseless_mode = True


def make_network_noiseless():
  """Builds a MultisubjectDisRnn from the noiseless copy of the config."""
  return multisubject_disrnn.MultisubjectDisRnn(disrnn_config_noiseless)


# Starts from scratch: no `params` or `opt_state` are passed in, and the
# returned values seed the continued-training cell.
params, opt_state, losses = rnn_utils.train_network(
    make_network_noiseless,
    dataset_train,
    dataset_eval,
    loss="penalized_categorical",
    opt=optax.adam(1e-2),
    n_steps=1_000,
)


In [None]:
# RUN THIS CELL AND THE ONES BELOW IT MANY TIMES
# Each run resumes training of the same network from the current `params` and
# `opt_state`; the cells below plot what the network is doing.
# To reinitialize the network, re-run the noiseless training cell above.
# Try tweaking the bottleneck penalty scales as you train, to get a feel for
# how they affect things.
disrnn_config.latent_penalty_scale = 1e-2
disrnn_config.update_net_penalty_scale = 1e-3
disrnn_config.choice_net_penalty_scale = 1e-3
disrnn_config.subject_embedding_penalty_scale = 1e-3


def make_network():
  """Builds a MultisubjectDisRnn from the current `disrnn_config`."""
  return multisubject_disrnn.MultisubjectDisRnn(disrnn_config)


n_steps = 1_000

# Passing `params` and `opt_state` back in is what makes this cell continue
# training rather than start over.
params, opt_state, losses = rnn_utils.train_network(
    make_network,
    dataset_train,
    dataset_eval,
    params=params,
    opt_state=opt_state,
    opt=optax.adam(1e-3),
    loss="penalized_categorical",
    loss_param=1,
    n_steps=n_steps,
    do_plot=True,
)


In [None]:
# Plot the open/closed state of the bottlenecks
# The return value is bound to `_` so the notebook does not echo it.
_ = plotting.plot_bottlenecks(params, disrnn_config)

In [None]:
# Plot the subject embeddings, computed as the embedding layer's weights plus
# its bias, both read from the trained parameters.
embedding_params = params['multisubject_dis_rnn/subject_embedding_weights']
subject_embeddings = embedding_params['w'] + embedding_params['b']

# Scatter the first two embedding dimensions, colored by each agent's
# learning rate.
fig, ax = plt.subplots()
points = ax.scatter(
    subject_embeddings[:, 0], subject_embeddings[:, 1], c=learning_rates
)
ax.set_xlim([-1, 1])
ax.set_ylim([-1, 1])
ax.set_xlabel('Dimension One', fontsize=18)
ax.set_ylabel('Dimension Two', fontsize=18)
ax.set_title('Subject Embeddings', fontsize=24)
cbar = fig.colorbar(points, ax=ax)
cbar.set_label('Agent Learning Rate', fontsize=18)
plt.show()

In [None]:
# Plot the choice rule
# The return value is bound to `_` so the notebook does not echo it.
_ = plotting.plot_choice_rule(params, disrnn_config)

In [None]:
# Plot the update rules
# The return value is bound to `_` so the notebook does not echo it.
_ = plotting.plot_update_rules(params, disrnn_config)