<a href="https://colab.research.google.com/github/google-deepmind/disentangled_rnns/blob/main/disentangled_rnns/notebooks/train_single_gru.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install disentangled_rnns repo from github
!git clone https://github.com/google-deepmind/disentangled_rnns
%cd disentangled_rnns
!pip install .
!pip install -r requirements.txt
%cd ..


import optax
import matplotlib.pyplot as plt
import matplotlib as mpl
import haiku as hk

from disentangled_rnns.library import rnn_utils
from disentangled_rnns.library import get_datasets

# Define a dataset

In [None]:
# Generate the Q-learning dataset (500 sessions, 200 trials per session).
dataset = get_datasets.get_q_learning_dataset(
    n_trials=200,
    n_sessions=500,
)
# Split it into a training set and an evaluation set.
# NOTE(review): eval_every_n=2 presumably assigns every second session to the
# evaluation set -- confirm against rnn_utils.split_dataset.
dataset_train, dataset_eval = rnn_utils.split_dataset(
    dataset,
    eval_every_n=2,
)

# Define and train RNN

In [None]:
# Architecture hyperparameters for the network we'd like to train
n_hidden = 16
output_size = 2

def make_network():
  """Builds the network: a GRU with `n_hidden` units feeding a linear readout.

  Returns:
    An `hk.DeepRNN` that stacks `hk.GRU(n_hidden)` and an `hk.Linear` layer
    with `output_size` outputs.
  """
  layers = [hk.GRU(n_hidden), hk.Linear(output_size=output_size)]
  return hk.DeepRNN(layers)

In [None]:
# INITIALIZE THE NETWORK
# Calling rnn_utils.train_network with n_steps=0 performs no training; it only
# sets up the parameters and the optimizer state.
optimizer = optax.adam(learning_rate=1e-3)

params, opt_state, losses = rnn_utils.train_network(
    make_network=make_network,
    training_dataset=dataset_train,
    validation_dataset=dataset_eval,
    loss="categorical",
    opt=optimizer,
    n_steps=0,
)


In [None]:
# TRAIN THE NETWORK
# The current params and opt_state are passed back in, so running this cell
# repeatedly continues to train the same network.
# The cell below gives insight into what's going on in your network.
# To reinitialize the network and start over, re-run the cell above.

n_steps = 1000
optimizer = optax.adam(learning_rate=1e-3)

params, opt_state, losses = rnn_utils.train_network(
    make_network=make_network,
    training_dataset=dataset_train,
    validation_dataset=dataset_eval,
    opt=optimizer,
    params=params,
    opt_state=opt_state,
    loss="categorical",
    loss_param=1,
    n_steps=n_steps,
    do_plot=True,
)


In [None]:
# Run forward pass on the unseen data
xs_eval, ys_eval = dataset_eval.get_all()
network_output, network_states = rnn_utils.eval_network(make_network, params, xs_eval)

# Compute normalized likelihood
score = rnn_utils.normalized_likelihood(ys_eval, network_output)
print(f'Normalized Likelihood: {100*score:.1f}%')

# Plot network activations on an example session.
# The explicit figure/axes interface keeps the plot self-contained, and
# plt.show() renders it without echoing the repr of the last call as output.
example_session = 0
fig, ax = plt.subplots()
ax.plot(network_states[:, example_session, :])
ax.set_xlabel('Trial Number')
ax.set_ylabel('Network Activations')
ax.set_title(f'Network activations, session {example_session}')
plt.show()