<a href="https://colab.research.google.com/github/google-deepmind/disentangled_rnns/blob/main/disentangled_rnns/notebooks/train_neuro_disrnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install disentangled_rnns repo from github
!git clone https://github.com/google-deepmind/disentangled_rnns
%cd disentangled_rnns
!pip install .
%cd ..

import optax
import numpy as np

from disentangled_rnns.library import rnn_utils
from disentangled_rnns.library import neuro_disrnn
from disentangled_rnns.library import checkpoint_utils
from disentangled_rnns.library import two_armed_bandits_w_dopamine

# Define a dataset

In [None]:
# @title Dataset Selection

# Registry of the available synthetic datasets. Every entry holds:
#   "getter":    the function that builds the dataset,
#   "kwargs":    keyword arguments forwarded to that function,
#   "penalties": penalty values that are later copied onto the network config.
dataset_configs = {}

dataset_configs["q_learning_w_dopamine"] = {
    "getter": two_armed_bandits_w_dopamine.get_q_learning_with_dopamine_dataset,
    "kwargs": {"n_trials": 100, "n_sessions": 100},
    "penalties": {
        "latent_penalty": 1e-2,
        "choice_net_latent_penalty": 1e-4,
        "update_net_latent_penalty": 2e-3,
        "neural_activity_net_latent_penalty": 1e-4,
        "update_net_obs_penalty": 1e-5,
    },
}

# Same settings as above except for the getter and a smaller latent_penalty.
dataset_configs["reward_seeking"] = {
    "getter": two_armed_bandits_w_dopamine.get_reward_seeking_with_dopamine_dataset,
    "kwargs": {"n_trials": 100, "n_sessions": 100},
    "penalties": {
        "latent_penalty": 1e-3,
        "choice_net_latent_penalty": 1e-4,
        "update_net_latent_penalty": 2e-3,
        "neural_activity_net_latent_penalty": 1e-4,
        "update_net_obs_penalty": 1e-5,
    },
}

# Pick one entry and build the corresponding dataset.
dataset_name = "q_learning_w_dopamine"  # @param ["q_learning_w_dopamine", "reward_seeking"]
dataset_config = dataset_configs[dataset_name]
dataset = dataset_config["getter"](
    **dataset_config["kwargs"]
)

In [None]:
# Divide the dataset into a training part and an evaluation part.
# NOTE(review): the second argument (2) presumably controls how the data is
# divided between the two parts -- confirm against rnn_utils.split_dataset.
dataset_train, dataset_eval = rnn_utils.split_dataset(
    dataset,
    2,
)

# Define and train RNN

In [None]:
# Build the network configuration. The penalty fields are created with NaN
# placeholders and then overwritten with the values chosen for the selected
# dataset (dataset_config["penalties"]).
disrnn_w_neural_activity_config = neuro_disrnn.DisRnnWNeuralActivityConfig(
    # Dataset related
    obs_size=2,  # Choice, reward
    output_size=2,  # Choose left / choose right
    x_names=dataset.x_names,
    y_names=dataset.y_names,
    # Network architecture
    latent_size=7,
    update_net_n_units_per_layer=16,
    update_net_n_layers=4,
    choice_net_n_units_per_layer=2,
    choice_net_n_layers=2,
    neural_activity_net_n_units_per_layer=4,
    neural_activity_net_n_layers=2,
    activation='leaky_relu',
    # Penalties (NaN = placeholder, filled in by the loop below)
    noiseless_mode=False,
    latent_penalty=np.nan,
    choice_net_latent_penalty=np.nan,
    update_net_latent_penalty=np.nan,
    neural_activity_net_latent_penalty=np.nan,
)

# Copy the dataset-specific penalties onto the config.
# setattr accepts any name, so a misspelled penalty key would create a new,
# unused attribute and leave the NaN placeholder in place. Reject names the
# config does not already expose so that mistake fails loudly here instead.
for penalty_name, penalty_value in dataset_config["penalties"].items():
    if not hasattr(disrnn_w_neural_activity_config, penalty_name):
        raise AttributeError(
            f"Unknown penalty {penalty_name!r} for "
            f"{type(disrnn_w_neural_activity_config).__name__}"
        )
    setattr(disrnn_w_neural_activity_config, penalty_name, penalty_value)

In [None]:
# Create the initial `params` and `opt_state` that the training cell below
# continues from.
# NOTE(review): this call requests n_steps=0 and the config above sets
# noiseless_mode=False, so it presumably only initialises the network rather
# than performing a noiseless warm-up -- confirm against
# rnn_utils.train_network.

likelihood_weight = 0.5 # @param {type: "slider",min:0, max: 1, step: 0.1}

params, opt_state, losses = rnn_utils.train_network(
    lambda: neuro_disrnn.HkNeuroDisentangledRNN(disrnn_w_neural_activity_config),
    dataset_train,
    dataset_eval,
    loss="penalized_hybrid",
    loss_param={'likelihood_weight': likelihood_weight, 'penalty_scale': 1.0},
    opt=optax.adam(1e-3),
    n_steps=0,
)

In [None]:
# RUN THIS CELL AND THE ONES BELOW IT MANY TIMES
# The current `params` and `opt_state` are passed back into train_network and
# then overwritten with its results, so every run of this cell continues
# training the same network.
# The cells below make plots documenting what's going on in your network.
# To start over from a freshly initialised network, re-run the cell above.
# Try tweaking the bottleneck parameters as you train, to get a feel for how
# they affect things.

# Usually 15,000 steps in total should be sufficient.
n_steps = 15_000

params, opt_state, losses = rnn_utils.train_network(
    lambda: neuro_disrnn.HkNeuroDisentangledRNN(disrnn_w_neural_activity_config),
    dataset_train,
    dataset_eval,
    params=params,
    opt_state=opt_state,
    loss="penalized_hybrid",
    loss_param={'likelihood_weight': likelihood_weight, 'penalty_scale': 1.0},
    opt=optax.adam(1e-3),
    n_steps=n_steps,
    do_plot=True,
)

In [None]:
# Show which bottlenecks are open and which are closed.
# Original note: ideally the neural activity bottlenecks stay closed, because
# the neural activity readout is not being trained right now.
# NOTE(review): training above uses loss="penalized_hybrid" with
# likelihood_weight=0.5 -- confirm that this expectation still applies.

_ = neuro_disrnn.plot_bottlenecks(
    params,
    disrnn_w_neural_activity_config,
    sort_latents=False,
)

In [None]:
# Visualise the choice rule of the trained network.
neuro_disrnn.plot_choice_rule(
    params,
    disrnn_w_neural_activity_config,
)

In [None]:
# Visualise the update rules of the trained network.
neuro_disrnn.plot_update_rules(
    params,
    disrnn_w_neural_activity_config,
)

In [None]:
# Visualise the neural activity rules of the trained network.
# The return value is bound to `_` so it is not echoed as cell output.
_ = neuro_disrnn.plot_neural_activity_rules(
    params,
    disrnn_w_neural_activity_config,
    axis_lim=0.8,
)