In [1]:
import os
from pathlib import Path
import json
import numpy as np
import torch
import torch.nn as nn
import gym
import neurogym as ngym
from neurogym.wrappers import ScheduleEnvs
from neurogym.utils import scheduler
import matplotlib.pyplot as plt
import seaborn as sns
import os
import json
from neurogym.wrappers.block import MultiEnvs
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
import matplotlib
from matplotlib.colors import ListedColormap
import copy
import torch.nn.functional as F


  logger.warn(


Create the single-task and multitask arrays by loading each trained model and running `generate_matrix` on it (the files are too big to transfer easily).

In [5]:
class Net(nn.Module):
    """Recurrent network: an ``nn.RNN`` layer followed by a linear readout."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Run ``x`` through the RNN and the readout.

        Returns:
            A ``(readout, rnn_out)`` tuple: the linear layer applied to the RNN
            output, and the RNN output itself.
        """
        rnn_out, _ = self.rnn(x)
        return self.linear(rnn_out), rnn_out

In [9]:
def generate_matrix(model, dataset, num_trials=10000, timesteps=10, hidden_size=None):
    """Record inputs, labels, predictions and RNN hidden states over many trials.

    Every call to ``dataset()`` is treated as one trial. For each timestep of a
    trial one row is written with the layout
    ``[inputs | labels | predicted class | hidden state]``.

    Args:
        model: Module exposing an ``rnn`` submodule. ``model(inputs)`` must return
            a tuple whose first element holds the per-timestep output scores.
        dataset: Zero-argument callable returning ``(inputs, labels)``. ``inputs``
            is indexed as ``[timestep, batch, feature]`` and ``labels`` as
            ``[timestep, label]``. Only batch element 0 of the inputs, predictions
            and hidden states is stored.
        num_trials: Number of trials to sample from ``dataset``.
        timesteps: Accepted for backward compatibility only. The value is ignored;
            the number of timesteps is always taken from the first axis of the
            labels returned by ``dataset``.
        hidden_size: Width of the hidden-state section of each row. When ``None``
            it is read from ``model.rnn.hidden_size``.

    Returns:
        A float64 ``numpy.ndarray`` of shape
        ``(num_trials, timesteps, input_size + label_size + 1 + hidden_size)``.

    Side effects:
        Puts ``model`` in eval mode, prints the detected sizes, and draws
        ``num_trials + 1`` samples from ``dataset`` (the first one is used only to
        discover the shapes).
    """
    # One throw-away sample to discover the sizes.
    inputs, labels = dataset()

    input_size = inputs.shape[2]
    label_size = labels.shape[1]
    output_size = 1  # a single predicted class index per timestep

    if hidden_size is None:
        hidden_size = model.rnn.hidden_size

    # The sequence length of the data always wins over the `timesteps` argument.
    timesteps = labels.shape[0]
    print(f"inputs_size: {input_size}, label_size: {label_size}, output_size: {output_size}, hidden_size: {hidden_size}, timesteps = {timesteps}")

    # Column boundaries of the four sections inside each row.
    label_start = input_size
    pred_start = label_start + label_size
    hidden_start = pred_start + output_size

    output_matrix = np.zeros((num_trials, timesteps, hidden_start + hidden_size))
    model.eval()

    for trial in range(num_trials):
        inputs, labels = dataset()

        inputs = torch.tensor(inputs, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.float32)

        with torch.no_grad():
            hidden_states, _ = model.rnn(inputs)
            output, _ = model(inputs)
            _, predicted = torch.max(output, -1)

        # Write all timesteps of the trial at once. The hidden state is taken at
        # every timestep (not only the first), so each row holds the state that
        # belongs to its own timestep.
        output_matrix[trial, :, :label_start] = inputs[:timesteps, 0, :].numpy()
        output_matrix[trial, :, label_start:pred_start] = labels[:timesteps].numpy()
        output_matrix[trial, :, pred_start:hidden_start] = predicted[:timesteps, :1].numpy()
        output_matrix[trial, :, hidden_start:] = hidden_states[:timesteps, 0, :].numpy()

    return output_matrix


In [6]:
# Build the single-task dataset.
# The period durations are pinned so trials have discrete, fixed intervals.
kwargs = {
    'timing': {
        'fixation': 300,
        'stimulus': 500,
        'decision': 200,
    },
}

task_env = ngym.make('yang19.dm1-v0', **kwargs)
dataset = ngym.Dataset(task_env, batch_size=1, seq_len=10)
env = dataset.env

# Observation and action sizes read from the environment's spaces.
ob_size = env.observation_space.shape[0]
act_size = env.action_space.n

In [7]:
# small model
net = Net(input_size=env.observation_space.shape[0],
          hidden_size=20,
          output_size=act_size)

# Load the trained weights.
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code; map_location='cpu' keeps the
# checkpoint loadable on a machine without the device it was saved from.
net.load_state_dict(torch.load('small_net.pth', map_location='cpu', weights_only=True))

net.eval()

  net.load_state_dict(torch.load('small_net.pth'))


Net(
  (rnn): RNN(33, 20)
  (linear): Linear(in_features=20, out_features=17, bias=True)
)

In [10]:
# Record the single-task network over generate_matrix's default number of trials.
small_array = generate_matrix(net, dataset, hidden_size=20)

inputs_size: 33, label_size: 1, output_size: 1, hidden_size: 20, timesteps = 10


In [14]:
# Multitask array.
# The batch size is set to 1 for simplicity.
task_names = [
    'yang19.go-v0',
    'yang19.rtgo-v0',
    'yang19.dlygo-v0',
    'yang19.anti-v0',
    'yang19.rtanti-v0',
    'yang19.dlyanti-v0',
    'yang19.dm1-v0',
    'yang19.dm2-v0',
    'yang19.ctxdm1-v0',
    'yang19.ctxdm2-v0',
    'yang19.multidm-v0',
    'yang19.dlydm1-v0',
    'yang19.dlydm2-v0',
    'yang19.ctxdlydm1-v0',
    'yang19.ctxdlydm2-v0',
    'yang19.multidlydm-v0',
    'yang19.dms-v0',
    'yang19.dnms-v0',
    'yang19.dmc-v0',
    'yang19.dnmc-v0',
]

# One environment per task name.
task_list = []
for name in task_names:
    task_list.append(ngym.make(name))

# Random schedule for switching between the tasks.
schedule = scheduler.RandomSchedule(n=len(task_list))

# Wrap all tasks into one environment (env_input=True).
combined_env = ScheduleEnvs(task_list, schedule=schedule, env_input=True)

# Dataset over the combined environment.
dataset = ngym.Dataset(combined_env, batch_size=1, seq_len=100)
env = dataset.env

# Observation and action sizes read from the combined environment's spaces.
ob_size = env.observation_space.shape[0]
act_size = env.action_space.n

In [12]:
# Multitask model: same architecture as the small network, with 256 hidden units.
multitask_net = Net(input_size=env.observation_space.shape[0],
          hidden_size=256,
          output_size=act_size)
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code; map_location='cpu' keeps the
# checkpoint loadable on a machine without the device it was saved from.
multitask_net.load_state_dict(torch.load('multitask_net.pth', map_location='cpu', weights_only=True))

  multitask_net.load_state_dict(torch.load('multitask_net.pth'))


<All keys matched successfully>

In [15]:
# Record the multitask network over generate_matrix's default number of trials.
multitask_array = generate_matrix(multitask_net, dataset, hidden_size=256)

inputs_size: 53, label_size: 1, output_size: 1, hidden_size: 256, timesteps = 100


Navigating the arrays
Each array has dimensions (trial, timestep, data).

In [17]:
# Show the shape of each array, one per line.
print(multitask_array.shape, small_array.shape, sep='\n')

# Optional: save files, uncomment to run
#np.save('multitask_array', multitask_array)
#np.save('small_array', small_array)

(10000, 100, 311)
(10000, 10, 55)


In [10]:
# First line: shape of one trial (index the first axis).
# Second line: shape of one timestep of that trial (index the first two axes).
print(small_array[0].shape, small_array[0, 0].shape, sep='\n')


(10, 55)
(55,)


Unpacking trial data:
Each timestep's data vector is the concatenation [inputs, label, prediction, hidden states].
So for the small single-task RNN (input size 33, label size 1, prediction size 1, and 20 hidden units) the slices are:

In [11]:
# First timestep of the first trial of the single-task array.
data = small_array[0, 0]
# Cut the row at columns 33, 34 and 35: 33 inputs, 1 label, 1 prediction, the rest hidden states.
inputs, label, prediction, hidden_states = np.split(data, [33, 34, 35])

For the large multitask RNN (input size 53 = 33 regular inputs + a 20-dimensional one-hot task vector, label size 1, prediction size 1, and 256 hidden units) the slices are:

In [12]:
# First timestep of the first trial of the multitask array.
data = multitask_array[0, 0]
# Cut the row at columns 53, 54 and 55.
# inputs: fixation input, then 2 rings of size 16, then the 20-element one-hot task vector.
inputs, label, prediction, hidden_states = np.split(data, [53, 54, 55])