# Disentangled RNNs for Mouse Switching Dataset
The dataset below is from [Harvard Dataverse](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/7E0NM5). Each row corresponds to a trial, and the columns correspond to the trial number, block position, target direction, choice direction, and reward outcome, as well as the session and mouse identifiers and task conditions.

| Trial | blockTrial | Decision | Switch | Reward | Condition | Target | blockLength | Session | Mouse |
|-------|------------|----------|--------|--------|-----------|--------|-------------|---------|-------|
| 11.0  | 11.0       | 1.0      | 0.0    | 1.0    | 90-10     | 1.0    | 58.0        | m1_77   | m1    |
| 12.0  | 12.0       | 1.0      | 0.0    | 1.0    | 90-10     | 1.0    | 58.0        | m1_77   | m1    |
| 13.0  | 13.0       | 1.0      | 0.0    | 1.0    | 90-10     | 1.0    | 58.0        | m1_77   | m1    |

In [1]:
from disentangled_rnns.library import rnn_utils
from disentangled_rnns.library import disrnn
from disentangled_rnns import switch_utils
import optax
from tqdm.auto import tqdm
from datetime import datetime
import os
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import math

addr = "/Users/michaelcondon/workspaces/pbm_group2/2ABT_behavior_models/bandit_data.csv"
batch_size = 30
test_prop = 0.7
df = pd.read_csv(addr)

# shuffle the sessions
eps = df['Session'].value_counts().sample(frac=1)

# create training and validation datasets with a 70% train 30% validation split
tr_eps = eps.iloc[:math.floor(test_prop*len(eps))]
tr_eps.name = "training_sessions"
ds_tr = switch_utils.get_dataset(df[df['Session'].isin(tr_eps.index)], batch_size)

va_eps = eps.iloc[math.floor(test_prop*len(eps)):]
va_eps.name = "validation_sessions"
ds_va = switch_utils.get_dataset(df[df['Session'].isin(va_eps.index)], batch_size)

In [2]:
update_mlp_shape = (5,5,5)
choice_mlp_shape = (2,2)
latent_size = 5

def make_network():
  return disrnn.HkDisRNN(update_mlp_shape=update_mlp_shape,
                        choice_mlp_shape=choice_mlp_shape,
                        latent_size=latent_size,
                        obs_size=2, target_size=2)

learning_rate = 1e-3
opt = optax.adam(learning_rate)

In [3]:
"""
Iterate through the mice, and through the beta values, saving the trained
params and loss for each in a json to disk.
"""
# betas = [1e-3, 3e-3, 1e-2, 3e-2]
betas = [1e-3]
n_steps = 1e2

n_calls = len(betas)
dt = datetime.now().strftime("%Y-%m-%d_%H-%M")
print(f"start time: {dt}")
switch_utils.split_saver(tr_eps, va_eps, dt, test_prop)
with tqdm(total=n_calls, desc='Overall Progress', position=1) as outer_bar:
  for beta_j in betas:
    outer_bar.set_postfix(beta=f"{beta_j:.0e}")
    params, opt_state, losses = rnn_utils.train_network(
    make_network,
        ds_tr,
        ds_va,
        ltype_tr="penalized_categorical",
        opt = optax.adam(learning_rate),
        penalty_scale = beta_j,
        n_steps=n_steps,
        do_plot = False)
    switch_utils.model_saver(params, beta_j, dt=dt, loss=losses, test_prop=test_prop)
    outer_bar.update(1)


start time: 2025-04-10_11-49


Overall Progress:   0%|          | 0/1 [00:00<?, ?it/s]

Training Progress:   0%|          | 0/100 [00:00<?, ?it/s]

In [6]:
va_eps.index

Index(['m1_27', 'm6_7', 'm6_40', 'm6_55', 'm5_24', 'm1_62', 'm5_27', 'm6_60',
       'm3_81', 'm1_75',
       ...
       'm6_47', 'm1_30', 'm1_25', 'm6_58', 'm2_86', 'm2_45', 'm6_9', 'm5_32',
       'm3_5', 'm4_48'],
      dtype='object', name='Session', length=158)

In [14]:
pd.concat((pd.DataFrame({'Session': tr_eps.values,
                                    'Side': 'train'},
                                    index=tr_eps.index.to_list()),
                     pd.DataFrame({'Session': va_eps.values,
                                   'Side': 'validation'},
                                  index=va_eps.index.to_list())))

Unnamed: 0,Session,Side
m6_29,837,train
m5_88,742,train
m3_16,688,train
m3_34,786,train
m3_14,760,train
...,...,...
m2_45,694,validation
m6_9,482,validation
m5_32,1209,validation
m3_5,626,validation


In [9]:
def split_loader(file_path):
  df = pd.read_csv(file_path)
  tr_eps = df[df['Side']=='train']
  va_eps = df[df['Side']=='validation']
  return tr_eps, va_eps

file_path = '/Users/michaelcondon/workspaces/pbm_group2/disentangled_rnns/models/split_2025-04-10_11-49_70-30.csv'
a = pd.read_csv(file_path)
a

Unnamed: 0,Session,Session.1,Side
0,m6_29,837,train
1,m5_88,742,train
2,m3_16,688,train
3,m3_34,786,train
4,m3_14,760,train
...,...,...,...
520,m2_45,694,validation
521,m6_9,482,validation
522,m5_32,1209,validation
523,m3_5,626,validation


## Analysis
From here on, you can load models from disk for each mouse as trained above.

In [None]:

directory = "/Users/michaelcondon/workspaces/pbm_group2/disentangled_rnns/models/"

# choose mouse, beta and run time
mouse = "m2"
beta = 0.001
cv = 0
dt = "2025-04-09_18-18"


params_file = os.path.join(directory, f"params_{beta:.0e}_0_{dt}.json")
loss_file = os.path.join(directory, f"loss_{beta:.0e}_0_{dt}.json")

params, loss = switch_utils.model_loader(params_file=params_file, loss_file=loss_file)
training_loss = loss['training_loss']
validation_loss = loss['validation_loss']

plt.figure()
plt.semilogy(training_loss, color='black')
plt.semilogy(np.linspace(0, len(training_loss), len(validation_loss)), validation_loss, color='tab:red', linestyle='dashed')
plt.xlabel('Training Step')
plt.ylabel('Mean Loss')
plt.legend(('Training Set', 'Validation Set'))
plt.title('Loss over Training')

In [None]:
# Eval mode runs the network with no noise
def make_network_eval():
  return disrnn.HkDisRNN(update_mlp_shape=update_mlp_shape,
                        choice_mlp_shape=choice_mlp_shape,
                        latent_size=latent_size,
                        obs_size=2, target_size=2,
                        eval_mode=True)


disrnn.plot_bottlenecks(params, sort_latents=True)
plt.show()
disrnn.plot_update_rules(params, make_network_eval)
plt.show()

## Switching Analysis
Here I will check how the RNN models from above behave from a switching perspective. This is based on the comparisons from the paper:

    Beron, C. C., Neufeld, S. Q., Linderman, S. W., & Sabatini, B. L. (2022). Mice exhibit stochastic and efficient action switching during probabilistic decision making. Proceedings of the National Academy of Sciences, 119(15), e2113961119. https://doi.org/10.1073/pnas.2113961119


In [6]:
"""
Iterate through each session of each dataset.
"""
import itertools
import disentangled_rnns.switch_utils as switch_utils
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from numpy.lib.stride_tricks import sliding_window_view

addr = "/Users/michaelcondon/workspaces/pbm_group2/2ABT_behavior_models/bandit_data.csv"
# dataset with no division into training set etc
full_ds = switch_utils.get_dataset(addr, tr_prop=1.0, condition='80-20')
dss = [full_ds[i][1]._xs for i in range(len(full_ds))]
ds = dss[1]

chars = 'lrLR'
h_len = 3
seq_dict = {''.join(seq): [0,0] for seq in itertools.product(chars, repeat=3)}

for session_i in range(np.shape(ds)[1]):
  session = ds[:, session_i]
  for ts_i in range(3, np.shape(session)[0]):
    if session[ts_i, 0] == -1:
      break
    ts = session[ts_i-3: ts_i]
    key = ''.join([chars[int(a+2*b)] for a, b in ts])
    if session[ts_i, 0] == session[ts_i-1, 0]:
      seq_dict[key][0] += 1
    else:
      seq_dict[key][1] += 1


In [None]:
def mean(t1):
  print(t1)
  if t1 == [0, 0]:
    return 0
  return t1[1] / (t1[0]+t1[1])

p_dict = {key: mean(val) for key, val in seq_dict.items()}

eq_chars = 'aAbB'
eqs = list(itertools.product(eq_chars, repeat=h_len))[:len(eq_chars)**h_len//2]
eq_dict = {''.join(seq): 0 for seq in eqs}
for seq in eq_dict:
  tran1 = seq.translate(str.maketrans('abAB', 'lrLR'))
  tran2 = seq.translate(str.maketrans('abAB', 'rlRL'))
  eq_dict[seq] = (p_dict[tran1] + p_dict[tran2]) / 2

In [None]:
import seaborn as sns

sorted_items = sorted(eq_dict.items(), key=lambda item: item[1])
sorted_labels = [item[0] for item in sorted_items]
sorted_heights = [item[1] for item in sorted_items]
total_height = sum(sorted_heights)

sns.set(style='ticks', font_scale=1.7, rc={'axes.labelsize':20, 'axes.titlesize':20})
sns.set_palette('deep')


fig, ax = plt.subplots(figsize=(14,4.2))

sns.barplot(x=sorted_labels, y=sorted_heights, color='k', alpha=0.5, ax=ax, edgecolor='gray')
ax.errorbar(x=sorted_labels, y=sorted_heights, fmt=' ', color='k', label=None)

ax.set(xlim=(-1,len(sorted_heights)), ylim=(0,1), ylabel='P(switch)')
plt.xticks(rotation=90)
sns.despine()
plt.title('Empirical Switch Probabilities')
plt.tight_layout()

In [None]:
np.shape(ds_list[1][2]._xs)

In [None]:
import jax
import haiku as hk
import jax.numpy as jnp

def make_network_eval():
  return disrnn.HkDisRNN(update_mlp_shape=update_mlp_shape,
                        choice_mlp_shape=choice_mlp_shape,
                        latent_size=latent_size,
                        obs_size=2, target_size=2,
                        eval_mode=True)

def unroll_network(xs):
  core = make_network()
  batch_size = jnp.shape(xs)[1]
  state = core.initial_state(batch_size)
  ys, _ = hk.dynamic_unroll(core, xs, state)
  return ys



# Haiku, step two: Transform the network into a pair of functions
# (model.init and model.apply)
_, step_hk = hk.transform(unroll_network)
step_hk = jax.jit(step_hk)

random_key = jax.random.PRNGKey(0)
# If params have not been supplied, start training from scratch
random_key, key1 = jax.random.split(random_key)


xs, ys = full_ds[1][1]._xs, full_ds[1][1]._ys
np.shape(step_hk(params, key1, xs))
