# Disentangled RNNs for Mouse Switching Dataset
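The dataset used below is from [Harvard Dataverse](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/7E0NM5). Each row corresponds to one trial. The columns give the trial number (`Trial`), the position within the block (`blockTrial`), the choice direction (`Decision`), the target direction (`Target`), and the reward outcome (`Reward`), together with the task condition (`Condition`) and the session and mouse identifiers (`Session`, `Mouse`); the `Switch` and `blockLength` columns are also included. A few example rows are shown below.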

| Trial | blockTrial | Decision | Switch | Reward | Condition | Target | blockLength | Session | Mouse |
|-------|------------|----------|--------|--------|-----------|--------|-------------|---------|-------|
| 11.0  | 11.0       | 1.0      | 0.0    | 1.0    | 90-10     | 1.0    | 58.0        | m1_77   | m1    |
| 12.0  | 12.0       | 1.0      | 0.0    | 1.0    | 90-10     | 1.0    | 58.0        | m1_77   | m1    |
| 13.0  | 13.0       | 1.0      | 0.0    | 1.0    | 90-10     | 1.0    | 58.0        | m1_77   | m1    |

In [1]:
from disentangled_rnns.library import rnn_utils
from disentangled_rnns.library import disrnn
from disentangled_rnns import switch_utils
import optax
from tqdm.auto import tqdm
from datetime import datetime
import os
from matplotlib import pyplot as plt
import numpy as np


# Absolute path to the bandit CSV on the author's machine; edit to run elsewhere.
addr = "/Users/michaelcondon/workspaces/pbm_group2/2ABT_behavior_models/bandit_data.csv"

# List of tuples; the training loop unpacks each entry as
# (mouse id, training set, validation set, test set).
# NOTE(review): the proportions read as 25% train / 75% validation / 0% test,
# which is presumably deliberate -- confirm they are not swapped.
ds_list = switch_utils.get_dataset(
    addr,
    batch_size=30,
    tr_prop=0.25,
    va_prop=0.75,
    te_prop=0.0,
)


In [2]:
# Architecture settings shared by make_network and make_network_eval; they are
# forwarded unchanged to disrnn.HkDisRNN.
# NOTE(review): the tuples are presumably per-layer hidden widths -- confirm
# against the HkDisRNN constructor.
update_mlp_shape = (5, 5, 5)
choice_mlp_shape = (2, 2)
latent_size = 5

def make_network():
  """Return a new disrnn.HkDisRNN configured from the module-level settings.

  Uses update_mlp_shape, choice_mlp_shape and latent_size as defined in this
  notebook, with obs_size=2 and target_size=2. eval_mode is not passed, so
  the constructor's default applies. The function takes no arguments because
  it is handed to rnn_utils.train_network as a factory.
  """
  network_kwargs = dict(
      update_mlp_shape=update_mlp_shape,
      choice_mlp_shape=choice_mlp_shape,
      latent_size=latent_size,
      obs_size=2,
      target_size=2,
  )
  return disrnn.HkDisRNN(**network_kwargs)

# Adam step size used for every training run in this notebook.
learning_rate = 1e-3
# NOTE(review): the training loop further down builds its own
# optax.adam(learning_rate) for each run and does not reference `opt`.
opt = optax.adam(learning_rate)

In [None]:
"""
Iterate through the mice, and through the beta values, saving the trained
params and loss for each in a json to disk.
"""
# Penalty scales to sweep; one model is trained per (mouse, beta) pair.
betas = [1e-3, 3e-3, 1e-2, 3e-2]
# Integer step count: a float such as 8e4 is rejected by range() and by any
# other integer-only consumer of n_steps.
n_steps = 80_000

# One progress-bar tick per trained model.
n_calls = len(ds_list) * len(betas)
# Timestamp shared by every file written in this run, so that the analysis
# cell can select the run by its start time.
dt = datetime.now().strftime("%Y-%m-%d_%H-%M")
print(f"start time: {dt}")
with tqdm(total=n_calls, desc='Overall Progress', position=1) as outer_bar:
  # dataset_te is unpacked but not used by this loop.
  for m_i, dataset_tr, dataset_va, dataset_te in ds_list:
    for beta_j in betas:
      outer_bar.set_postfix(mouse=f"{m_i}", beta=f"{beta_j:.0e}")
      # No params/opt_state are passed in and a new optimizer is created, so
      # nothing from the previous run is handed to this one.
      params, opt_state, losses = rnn_utils.train_network(
          make_network,
          dataset_tr,
          dataset_va,
          ltype_tr="penalized_categorical",
          opt=optax.adam(learning_rate),
          penalty_scale=beta_j,
          n_steps=n_steps,
          do_plot=False)
      # Persist the trained params and loss history for this mouse and beta.
      switch_utils.model_saver(params, m_i, beta_j, dt=dt, loss=losses)
      outer_bar.update(1)


## Analysis
From here on, the models trained above can be loaded from disk for a chosen mouse, beta value, and training run.

In [None]:

# Folder holding the saved params/loss JSON files on the author's machine.
directory = "/Users/michaelcondon/workspaces/pbm_group2/disentangled_rnns/models/"

# choose mouse, beta and run time
mouse = "m1"
beta = 0.01
# NOTE(review): assumed to be the index that appears between beta and the
# timestamp in the saved file names -- confirm against switch_utils.model_saver.
cv = 0
dt = "2025-04-02_22-19"

# Build the shared part of both file names once so the params and loss files
# always refer to the same run; `cv` is interpolated rather than hard-coded.
file_tag = f"{mouse}_{beta:.0e}_{cv}_{dt}"
params_file = os.path.join(directory, f"params_{file_tag}.json")
loss_file = os.path.join(directory, f"loss_{file_tag}.json")

params, loss = switch_utils.model_loader(params_file=params_file, loss_file=loss_file)
training_loss = loss['training_loss']
validation_loss = loss['validation_loss']

# validation_loss may have a different number of entries than training_loss;
# spread its points evenly over [0, len(training_loss)] so both curves share
# the same x-axis.
n_train = len(training_loss)
validation_steps = np.linspace(0, n_train, len(validation_loss))

plt.figure()
plt.semilogy(training_loss, color='black')
plt.semilogy(
    validation_steps,
    validation_loss,
    color='tab:red',
    linestyle='dashed',
)
plt.xlabel('Training Step')
plt.ylabel('Mean Loss')
# Legend entries follow the order in which the two curves were drawn.
plt.legend(('Training Set', 'Validation Set'))
plt.title('Loss over Training')

In [None]:
def make_network_eval():
  """Return a new disrnn.HkDisRNN with eval_mode=True.

  Same configuration as the training factory (module-level update_mlp_shape,
  choice_mlp_shape and latent_size, obs_size=2, target_size=2). Eval mode
  runs the network with no noise.
  """
  network_kwargs = dict(
      update_mlp_shape=update_mlp_shape,
      choice_mlp_shape=choice_mlp_shape,
      latent_size=latent_size,
      obs_size=2,
      target_size=2,
      eval_mode=True,
  )
  return disrnn.HkDisRNN(**network_kwargs)


# Plot the bottlenecks of the loaded params, with latents sorted, and display
# that figure before the next one is created.
disrnn.plot_bottlenecks(params, sort_latents=True)
plt.show()
# Plot the update rules of the loaded params; the eval-mode factory is passed
# so the network is built with eval_mode=True.
disrnn.plot_update_rules(params, make_network_eval)
plt.show()