# Disentangled RNNs for Mouse Switching Dataset
The dataset below is from [Harvard Dataverse](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/7E0NM5). Each row corresponds to a trial, and the columns correspond to the trial number, block position, target direction, choice direction, and reward outcome, as well as the session and mouse identifiers and task conditions.

| Trial | blockTrial | Decision | Switch | Reward | Condition | Target | blockLength | Session | Mouse |
|-------|------------|----------|--------|--------|-----------|--------|-------------|---------|-------|
| 11.0  | 11.0       | 1.0      | 0.0    | 1.0    | 90-10     | 1.0    | 58.0        | m1_77   | m1    |
| 12.0  | 12.0       | 1.0      | 0.0    | 1.0    | 90-10     | 1.0    | 58.0        | m1_77   | m1    |
| 13.0  | 13.0       | 1.0      | 0.0    | 1.0    | 90-10     | 1.0    | 58.0        | m1_77   | m1    |

In [1]:
from disentangled_rnns.library import rnn_utils
from disentangled_rnns.library import disrnn
from disentangled_rnns import switch_utils
import optax
from tqdm.auto import tqdm
from datetime import datetime
import os
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import math
import seaborn as sns

# Trial-level behaviour table (one row per trial) plus batching/split settings.
addr = "/Users/michaelcondon/workspaces/pbm_group2/2ABT_behavior_models/bandit_data.csv"
batch_size = 30
# NOTE: despite its name, test_prop is the fraction of sessions assigned to the
# TRAINING set; the remaining sessions form the validation set.
test_prop = 0.7
df = pd.read_csv(addr)

# Number of trials in each session, with the sessions in random (unseeded) order.
eps = df['Session'].value_counts().sample(frac=1)

# Position at which the shuffled sessions are cut: training before, validation after.
n_train_sessions = math.floor(test_prop * len(eps))

# Training set: the first 70% of the shuffled sessions.
tr_eps = eps.iloc[:n_train_sessions]
tr_eps.name = "training_sessions"
is_train_row = df['Session'].isin(tr_eps.index)
ds_tr = switch_utils.get_dataset(df[is_train_row], batch_size)

# Validation set: the remaining 30% of the shuffled sessions.
va_eps = eps.iloc[n_train_sessions:]
va_eps.name = "validation_sessions"
is_validation_row = df['Session'].isin(va_eps.index)
ds_va = switch_utils.get_dataset(df[is_validation_row], batch_size)

In [2]:
# DisRNN architecture hyperparameters, read by the network factories below.
update_mlp_shape = (5, 5, 5)
choice_mlp_shape = (2, 2)
latent_size = 5


def make_network():
  """Build a DisRNN from the module-level hyperparameters.

  eval_mode is not passed here, so the network uses the library default.
  """
  network_kwargs = dict(
      update_mlp_shape=update_mlp_shape,
      choice_mlp_shape=choice_mlp_shape,
      latent_size=latent_size,
      obs_size=2,
      target_size=2,
  )
  return disrnn.HkDisRNN(**network_kwargs)


# Adam optimiser with a fixed learning rate.
learning_rate = 1e-3
opt = optax.adam(learning_rate)

In [None]:
"""
Iterate through the mice, and through the beta values, saving the trained
params and loss for each in a json to disk.
"""
betas = [1e-3]
# betas = [1e-3]
n_steps = 1e4

n_calls = len(betas)
dt = datetime.now().strftime("%Y-%m-%d_%H-%M")
print(f"start time: {dt}")
switch_utils.split_saver(tr_eps, va_eps, dt, test_prop)
with tqdm(total=n_calls, desc='Overall Progress', position=1) as outer_bar:
  for beta_j in betas:
    outer_bar.set_postfix(beta=f"{beta_j:.0e}")
    params, opt_state, losses = rnn_utils.train_network(
    make_network,
        ds_tr,
        ds_va,
        ltype_tr="penalized_categorical",
        opt = optax.adam(learning_rate),
        penalty_scale = beta_j,
        n_steps=n_steps,
        do_plot = False)
    switch_utils.model_saver(params, beta_j, dt=dt, loss=losses, test_prop=test_prop)
    outer_bar.update(1)


## Analysis
From here on, you can load models from disk for each mouse as trained above.

In [None]:

# Folder holding the saved parameter and loss files.
directory = "/Users/michaelcondon/workspaces/pbm_group2/disentangled_rnns/models/"

# Identify the saved run to load: split proportion, beta and run timestamp.
test_prop = 0.7
beta = 0.001
dt = "2025-04-10_12-51"
# Split tag as it appears in the file names, e.g. "70-30".
cv = f'{test_prop*100:.0f}-{(1-test_prop)*100:.0f}'

params_file = os.path.join(directory, f"params_{beta:.0e}_{cv}_{dt}.json")
loss_file = os.path.join(directory, f"loss_{beta:.0e}_{cv}_{dt}.json")

params, loss = switch_utils.model_loader(params_file=params_file, loss_file=loss_file)
training_loss, validation_loss = loss['training_loss'], loss['validation_loss']

# The two loss curves may have different lengths, so the validation points are
# spread evenly across the training-step axis.
validation_steps = np.linspace(0, len(training_loss), len(validation_loss))

plt.figure()
plt.semilogy(training_loss, color='black')
plt.semilogy(validation_steps, validation_loss, color='tab:red', linestyle='dashed')
plt.xlabel('Training Step')
plt.ylabel('Mean Loss')
plt.legend(('Training Set', 'Validation Set'))
plt.title('Loss over Training')

In [None]:
def make_network_eval():
  """Build the DisRNN with eval_mode=True (eval mode runs the network with no noise)."""
  network_kwargs = dict(
      update_mlp_shape=update_mlp_shape,
      choice_mlp_shape=choice_mlp_shape,
      latent_size=latent_size,
      obs_size=2,
      target_size=2,
      eval_mode=True,
  )
  return disrnn.HkDisRNN(**network_kwargs)


# Inspect the loaded parameters: bottlenecks first, then the update rules.
disrnn.plot_bottlenecks(params, sort_latents=True)
plt.show()
disrnn.plot_update_rules(params, make_network_eval)
plt.show()

## Switching Analysis
Here I will check how the RNN models from above behave from a switching perspective. This is based on the comparisons from the paper:

    Beron, C. C., Neufeld, S. Q., Linderman, S. W., & Sabatini, B. L. (2022). Mice exhibit stochastic and efficient action switching during probabilistic decision making. Proceedings of the National Academy of Sciences, 119(15), e2113961119. https://doi.org/10.1073/pnas.2113961119


In [None]:

# Switch statistics computed from the validation inputs alone.
# NOTE: this reads the dataset's private _xs array directly.
p_dict = switch_utils.switch_bars(ds_va._xs, ds_va._xs[:,:,0], symm=True, prob=True)

# Order the bars by increasing value.
sorted_items = sorted(p_dict.items(), key=lambda pair: pair[1])
sorted_labels = [label for label, _ in sorted_items]
sorted_heights = [height for _, height in sorted_items]

plot_rc = {'axes.labelsize':20, 'axes.titlesize':20}
sns.set(style='ticks', font_scale=1.7, rc=plot_rc)
sns.set_palette('deep')

fig, ax = plt.subplots(figsize=(14,4.2))

sns.barplot(x=sorted_labels, y=sorted_heights, color='k', alpha=0.5, ax=ax, edgecolor='gray')
# No yerr/xerr is supplied to errorbar here.
ax.errorbar(x=sorted_labels, y=sorted_heights, fmt=' ', color='k', label=None)

# Leave one bar-width of padding on each side; probabilities live in [0, 1].
ax.set(
    xlim=(-1, len(sorted_heights)),
    ylim=(0, 1),
    ylabel='P(switch)',
)
plt.xticks(rotation=90)
sns.despine()
plt.title('Empirical Switch Probabilities')
plt.tight_layout()

In [5]:
import jax
import haiku as hk
import jax.numpy as jnp

# make_network_eval, unroll_network and step_hk are the pieces needed to run
# the trained network forward on a batch of sessions.
def make_network_eval():
  """Build the DisRNN with eval_mode=True (no noise)."""
  return disrnn.HkDisRNN(update_mlp_shape=update_mlp_shape,
                        choice_mlp_shape=choice_mlp_shape,
                        latent_size=latent_size,
                        obs_size=2, target_size=2,
                        eval_mode=True)

def unroll_network(xs):
  """Unroll the eval-mode network over xs and return its outputs.

  Axis 1 of xs is read as the batch dimension.
  """
  # Use the eval-mode network so the outputs are not perturbed by the noise
  # that the training-mode network (make_network) applies.
  core = make_network_eval()
  # Local name chosen so it does not shadow the module-level batch_size.
  n_sessions = jnp.shape(xs)[1]
  state = core.initial_state(n_sessions)
  ys, _ = hk.dynamic_unroll(core, xs, state)
  return ys


# Keep only the apply function of the transformed network, then jit it.
_, step_hk = hk.transform(unroll_network)
step_hk = jax.jit(step_hk)

random_key = jax.random.PRNGKey(0)

# first two columns give the probability of left and right (but need to be put through
# softmax for normalising)
output = step_hk(params, random_key, ds_va._xs)[:,:,:2]
# sample from the output either greedily or with thompson sampling
y_sampled = switch_utils.sampler(output, 'thompson')

In [None]:
# Conditional switch-probability dictionaries for each 3 letter history:
# p_dict from the validation inputs, sim_p_dict from the sampled network choices.
p_dict = switch_utils.switch_bars(ds_va._xs, ds_va._xs[:,:,0], symm=True, prob=True)
sim_p_dict = switch_utils.switch_bars(ds_va._xs[1:], y_sampled[:-1,:], symm=True, prob=True)

# Order histories by increasing value in p_dict; sim_p_dict follows that order.
sorted_items = sorted(p_dict.items(), key=lambda pair: pair[1])
sorted_keys = [history for history, _ in sorted_items]
sorted_labels = list(sorted_keys)
sorted_heights = [height for _, height in sorted_items]
sim_sorted_heights = [sim_p_dict[history] for history in sorted_keys]

plot_rc = {'axes.labelsize':20, 'axes.titlesize':20}
sns.set(style='ticks', font_scale=1.7, rc=plot_rc)
sns.set_palette('deep')

fig, ax = plt.subplots(figsize=(14,4.2))

# Simulated bars are drawn first (opaque); the p_dict bars are overlaid half-transparent.
sns.barplot(x=sorted_labels, y=sim_sorted_heights, color='g', alpha=1, ax=ax, label='DisRNN Switch Prob')
sns.barplot(x=sorted_labels, y=sorted_heights, color='k', alpha=0.5, ax=ax, edgecolor='gray', label='Mouse Switch Prob')

ax.set(
    xlim=(-1, len(sorted_heights)),
    ylim=(0, 1),
    ylabel='P(switch)',
)
plt.xticks(rotation=90)
sns.despine()
plt.title('Empirical Switch Probabilities')
plt.tight_layout()
plt.legend()
plt.savefig('/Users/michaelcondon/workspaces/pbm_group2/disentangled_rnns/figs/switch_probs.pdf')
plt.show()

## Simulation Analysis
Here, I simulate sessions where the reward contingencies are first calculated independently of choices, and then the choices are simulated. The session length (number of trials) for each mouse is normally distributed, with the mean and std given below.

In [None]:
# eps holds the number of trials in each session, so these are the mean and
# standard deviation of session length measured in trials.
mean_len = eps.mean()
std_len = eps.std()
print(f"trial length mean: {mean_len:.2f}, std: {std_len:.2f}")