<a href="https://colab.research.google.com/github/google-deepmind/disentangled_rnns/blob/main/disentangled_rnns/notebooks/train_single_disrnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install disentangled_rnns repo
!pip install disentangled_rnns

# Import the things we need
import optax
import copy
import matplotlib.pyplot as plt

from disentangled_rnns.library import rnn_utils
from disentangled_rnns.library import get_datasets
from disentangled_rnns.library import disrnn
from disentangled_rnns.library import plotting
from disentangled_rnns.library import two_armed_bandits


# Define a dataset

In [None]:
# Build the dataset used in the rest of the notebook: 500 sessions obtained
# from the library's Q-learning dataset helper.
dataset = get_datasets.get_q_learning_dataset(n_sessions=500)

In [None]:
# Split the sessions into a training set and a held-out evaluation set.
# NOTE(review): the second argument (2) presumably controls how sessions are
# divided between the two sets — confirm against rnn_utils.split_dataset.
dataset_train, dataset_eval = rnn_utils.split_dataset(dataset, 2)

# Define and train RNN

In [None]:
# Configuration for the penalized training phase. The penalty values set here
# are starting values; the penalized-training cell overwrites them.
disrnn_config = disrnn.DisRnnConfig(
    # --- Dataset ---
    obs_size=2,  # Choice, reward
    output_size=2,  # Choose left / choose right
    x_names=dataset.x_names,
    y_names=dataset.y_names,
    # --- Network architecture ---
    latent_size=5,
    update_net_n_units_per_layer=16,
    update_net_n_layers=4,
    choice_net_n_units_per_layer=4,
    choice_net_n_layers=2,
    activation='leaky_relu',
    # --- Penalties ---
    noiseless_mode=False,
    latent_penalty=1e-5,
    choice_net_latent_penalty=1e-5,
    update_net_obs_penalty=1e-5,
    update_net_latent_penalty=1e-5,
    l2_scale=1e-3,
)

# Variant of the same configuration for the initial training phase: noiseless
# mode is enabled and every penalty term is set to zero. The deep copy keeps
# the two configs independent, so editing one never changes the other.
disrnn_config_noiseless = copy.deepcopy(disrnn_config)
disrnn_config_noiseless.noiseless_mode = True
disrnn_config_noiseless.l2_scale = 0
disrnn_config_noiseless.latent_penalty = 0
disrnn_config_noiseless.choice_net_latent_penalty = 0
disrnn_config_noiseless.update_net_obs_penalty = 0
disrnn_config_noiseless.update_net_latent_penalty = 0

In [None]:
# PHASE 1: INITIAL TRAINING IN NOISELESS MODE
# Fit the network using the noiseless, zero-penalty configuration.
# Re-running this cell trains a new network, since no existing params are
# passed in.
n_steps_noiseless = 1_000  # @param {type: "integer"}
learning_rate = 1e-2  # @param {type: "number"}

# Only the trained parameters are kept; the returned optimizer state and the
# third return value are discarded.
params, _, _ = rnn_utils.train_network(
    lambda: disrnn.HkDisentangledRNN(disrnn_config_noiseless),
    training_dataset=dataset_train,
    validation_dataset=dataset_eval,
    loss="penalized_categorical",
    opt=optax.adam(learning_rate=learning_rate),
    n_steps=n_steps_noiseless,
)

# Start the next phase without an optimizer state: the next cell adds the
# bottleneck penalties, which changes the loss function quite a bit.
opt_state = None

In [None]:
# RUN THIS CELL AND THE ONES BELOW IT MANY TIMES
# Each run continues training the same network: `params` and `opt_state` from
# the previous run are passed back in and then overwritten with the results.
# The cells below make plots documenting what's going on in your network.
# If you'd like to reinitialize the network, re-run the above cell.
# Try tweaking the bottleneck parameters as you train, to get a feel for how
# they affect things.

# Bottleneck penalties and L2 scale. These are written onto the shared
# `disrnn_config` object, so the plotting cells below see the same values.
disrnn_config.choice_net_latent_penalty = 1e-3  # @param {type: "number"}
disrnn_config.update_net_obs_penalty = 1e-3  # @param {type: "number"}
disrnn_config.update_net_latent_penalty = 1e-3  # @param {type: "number"}
disrnn_config.latent_penalty = 1e-2   # @param {type: "number"}
disrnn_config.l2_scale = 1e-3  # @param {type: "number"}

learning_rate = 1e-3  # @param {type: "number"}
n_steps = 1_000  # @param {type: "integer"}

# Datasets are passed by keyword, matching the noiseless training call, so the
# role of each dataset is explicit at the call site.
# NOTE(review): loss_param=1 is presumably the weight of the penalty term in
# the "penalized_categorical" loss — confirm against rnn_utils.train_network.
params, opt_state, losses = rnn_utils.train_network(
    lambda: disrnn.HkDisentangledRNN(disrnn_config),
    training_dataset=dataset_train,
    validation_dataset=dataset_eval,
    loss="penalized_categorical",
    loss_param=1,
    params=params,
    opt_state=opt_state,
    opt=optax.adam(learning_rate=learning_rate),
    n_steps=n_steps,
    do_plot=True,
)


In [None]:
# Plot the open/closed state of the bottlenecks for the current parameters.
# The return value is assigned to `_` so it is not echoed as cell output.
_ = plotting.plot_bottlenecks(params, disrnn_config)

In [None]:
# Plot the choice rule for the current parameters.
# The return value is assigned to `_` so it is not echoed as cell output.
_ = plotting.plot_choice_rule(params, disrnn_config)

In [None]:
# Plot the update rules for the current parameters.
# The return value is assigned to `_` so it is not echoed as cell output.
_ = plotting.plot_update_rules(params, disrnn_config)

In [None]:
# Run a forward pass on the unseen (evaluation) data, using the noiseless
# configuration together with the trained parameters.
eval_data = dataset_eval.get_all()
xs_eval, ys_eval = eval_data['xs'], eval_data['ys']
network_output, network_states = rnn_utils.eval_network(
    lambda: disrnn.HkDisentangledRNN(disrnn_config_noiseless), params, xs_eval)

# Compute normalized likelihood.
# The first n_actions elements of the network output are the logits (the final
# one is the penalty). The count is read from the config instead of being
# hard-coded, so the slice stays correct if `output_size` is changed.
n_actions = disrnn_config_noiseless.output_size
logits = network_output[:, :, :n_actions]
normalized_likelihood = rnn_utils.normalized_likelihood(
    labels=ys_eval, output_logits=logits)

print(f'Normalized likelihood: {100*normalized_likelihood:.2f}%')

# Plot network activations on an example session
example_session = 0  # @param {type: "integer"}

# The second axis selects the session; along the last axis of the inputs,
# index 0 is the choice and index 1 is the reward.
# NOTE(review): the first axis is presumably the trial/time step — confirm.
choices = xs_eval[:, example_session, 0]
rewards = xs_eval[:, example_session, 1]
scalars = network_states[:, example_session, :]
two_armed_bandits.plot_2ab_sessdata(
    choices,
    rewards,
    scalars=scalars,
    scalar_types='agent_states',
    show_legend=False,
)