HePlas1-3 and HeDep1-3 are the comparable models: all were trained with the same parameters, and all on difficulty level 2.

In [None]:
from networks import DMTSNet #stsp and fixed (same but no x+u) networks
from spatial_task import DMTSDataModule #spatial version, distraction code removed
import torch
import numpy as np
import os
import matplotlib.pyplot as plt
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" #fix for weird issue where matplotlib kills  kernel

import time

In [None]:
def load_models(checkpoints, ckpt_dir):
    '''
    Load every checkpoint in `checkpoints` (filenames inside `ckpt_dir`) and
    return them keyed by the label encoded at the start of each filename.

    Training parameters are parsed from the '_'-separated filename fields and
    kept next to the model; nested dictionary format:
        models_dict
            label
                model : model
                params : param_dict
    '''
    # names for the first twelve '_'-separated filename fields, in filename order
    field_names = (
        'label', 'rnn', 'nonlinearity', 'hidden size', 'gamma',
        'learning rate', 'act reg', 'param reg', 'init method',
        'noise level', 'difficulty level', 'optimizer',
    )

    models_dict = {}
    for ckpt_name in checkpoints:
        loaded = DMTSNet.load_from_checkpoint(os.path.join(ckpt_dir, ckpt_name))

        ## read parameter values from checkpoint filename
        fields = ckpt_name.split('_')  # edit; this will become '-'
        params_dict = dict(zip(field_names, fields))

        # field 12 packs batch size and episodes-to-finish; field 13 holds the accuracy
        batch_eps_field = fields[12]
        params_dict['batch size'] = batch_eps_field[:2]  # edit...
        params_dict['eps to finish'] = int(batch_eps_field[8:10])  # edit...
        params_dict['accuracy'] = fields[13][4:8]  # edit...

        # note to self: much of this is also available on the model itself,
        # e.g. model.rnn.hidden_size, model.rnn.lr; see model.rnn.__dict__
        models_dict[params_dict['label']] = {'model': loaded, 'params': params_dict}

    return models_dict

In [None]:
plas_nets = [f'HePlas{idx}' for idx in (1, 2, 3)]
dep_nets = [f'HeDep{idx}' for idx in (1, 2, 3)]

In [None]:
## specify which models to load by label at start of filename
plas_nets = ['HePlas'+str(x) for x in range(1,4)]
dep_nets = ['HeDep'+str(x) for x in range(1,4)]
# build the path portably: the old backslash literal only worked on Windows and relied on
# an unrecognised escape sequence (a SyntaxWarning on Python 3.12+)
ckpt_dir = os.path.join('_lightning_sandbox', 'checkpoints')

## find checkpoints whose label field (text before the first '_', the same field that
## load_models keys on) matches exactly -- startswith('HePlas1') would also grab 'HePlas10_...'
all_labels = plas_nets + dep_nets
available = sorted(os.listdir(ckpt_dir))  # list the directory once, in a stable order
load_checkpoints = [ckpt for label in all_labels for ckpt in available if ckpt.split('_')[0] == label]

print("loading the following models:")
for n in load_checkpoints: print('    '+n)  # print already ends the line; the old '/n' was a literal slash-n

## load the selected models into a dictionary of { [label]: {[model],[params]} } nested dictionary format
models_dict = load_models(load_checkpoints, ckpt_dir)

## check that every checkpoint loaded under its own label and every requested label was found
missing = [label for label in all_labels if label not in models_dict]
if len(load_checkpoints) == len(models_dict) and not missing: print('loading successful!')
else: print('something is missing...', missing)

In [None]:
## load up the task to test models on
task = DMTSDataModule(dt_ann=15)
task.setup()
tester = task.test_dataloader()
inp, out_des, y, test_on = next(iter(tester))

## get output for each model and store in models_dict as [output], [hidden activity], and [hidden weights]
## nothing here needs gradients, so run under torch.no_grad(): the forward passes are faster and the
## stored tensors do not keep their autograd graphs (and all intermediates) alive in memory
start_time = time.perf_counter()  # monotonic clock meant for measuring elapsed time

with torch.no_grad():
    for entry in models_dict.values():

        # have model do the task
        print(entry['params']['label'], 'is running...')
        model = entry['model']
        out_readout, out_hidden, w_hidden, _ = model(inp) #ignoring fourth output which is just process noise

        # store model output + activity in models_dict
        entry['output'] = out_readout
        entry['hidden activity'] = out_hidden
        entry['hidden weights'] = w_hidden

print('elapsed time:', time.perf_counter() - start_time)

In [None]:
dt_ann = 15                                 # step size used to turn the 1000/500 durations (presumably ms) into steps
samp_on = int(1000/dt_ann)                  # sample onset, in time steps
samp_off = samp_on + int(500/dt_ann)        # sample offset, int(500/dt_ann) steps after onset
unique_delay_times = torch.unique(test_on)  # distinct test-onset steps across the test batch
num_delays = unique_delay_times.numel()
num_samps = 2

In [None]:
# pick the HePlas1 model object for the cells below (later cells read model.rnn.*)
model = models_dict['HePlas1']['model']

In [None]:
## per-trial test-window accuracy of `model` (selected above), grouped by delay length
# look up the readout that belongs to `model`; the bare `out_readout` global is left over
# from whichever model ran last in the evaluation loop
readout = next((entry['output'] for entry in models_dict.values() if entry['model'] is model), None)
if readout is None:
    raise ValueError('`model` is not one of the models in models_dict')

test_len = int(500/dt_ann)  # test window length in time steps
acc_dict = {}

for delay in unique_delay_times:

    #get trials with same delay length
    delay_inds = torch.where(test_on == delay)[0]
    accs = np.zeros(len(delay_inds))

    for i, trial in enumerate(delay_inds):
        #fraction of test-window time steps where the argmax over all readout channels but the last equals the target
        start = int(test_on[trial])
        curr_max = readout[trial, start:start+test_len, :-1].argmax(dim=1).cpu().detach().numpy()
        accs[i] = (y[trial].item() == curr_max).sum() / len(curr_max)

    # float() works whether delays are 0-d tensors or plain ints
    delay_s = (float(delay)*dt_ann)/1000
    acc_dict[delay_s] = accs
    #average single-trial accuracy for this delay
    print(f"delay: {delay_s}, {len(delay_inds)} trials, accuracy: {accs.mean()}")


In [None]:
## plot mean accuracy of output at each delay length
## NOTE(review): an earlier note said this plots the same for every model. The function only reads
## the readout passed in (plus the shared y / test_on), so identical plots mean the models behave
## alike or the same readout is being passed -- check the caller.
def plot_acc(out_readout, label):
    '''
    Scatter-plot mean test-window accuracy against delay length for one model.

    For each trial, accuracy is the fraction of time steps in the test window
    [test_on, test_on + int(500/dt_ann)) at which the argmax over all readout
    channels except the last equals the target y. Relies on the notebook
    globals y, test_on, unique_delay_times and dt_ann.

    out_readout : readout tensor indexed as [trial, time, channel]
    label       : plot title
    returns     : dict mapping delay length (steps * dt_ann / 1000) -> per-trial accuracies
    '''
    test_len = int(500/dt_ann)
    acc_by_delay = {}

    plt.figure(figsize=(5,4))

    for delay in unique_delay_times:

        #get trials with same delay length
        delay_inds = torch.where(test_on == delay)[0]
        accs = np.zeros(len(delay_inds))

        for i, trial in enumerate(delay_inds):
            #count of number of times NN made right choice during test window + divide by num timepoints
            start = int(test_on[trial])
            curr_max = out_readout[trial, start:start+test_len, :-1].argmax(dim=1).cpu().detach().numpy()
            accs[i] = (y[trial].item() == curr_max).sum() / len(curr_max)

        # dt_ann instead of a hard-coded 15; float() so plain-int delays work too
        delay_s = (float(delay)*dt_ann)/1000
        acc_by_delay[delay_s] = accs
        plt.scatter(delay_s, accs.mean())

    plt.xlabel('delay (s)')
    plt.ylabel('mean accuracy')
    plt.title(label)
    return acc_by_delay

In [None]:
# one accuracy-vs-delay plot per loaded model
for entry in models_dict.values():
    plot_acc(entry['output'], entry['params']['label'])

In [None]:
#check inputs, desired outputs, and actual outputs for one trial at a time

def plot_trial_x(out_readout, trial):
    '''
    Plot input, target output and model readout for a single trial.

    Three stacked panels (top to bottom): inp, out_des and out_readout for
    channels 0-2 of `trial` (per the original note: 2 samples + 1 fixation).
    Dashed gray lines on the readout panel bound the window
    [test_on[trial], test_on[trial] + 500/dt_ann), the span the original
    comment says the loss is calculated on. Relies on the notebook globals
    inp, out_des, test_on and dt_ann.
    '''
    test_start = float(test_on[trial])
    test_end = test_start + 500/dt_ann  # dt_ann instead of a hard-coded 15

    f, ax = plt.subplots(3, 1)
    with torch.no_grad():
        for node in range(3): #2 samples + 1 fixation
            ax[0].plot(inp[trial, :, node])
            ax[1].plot(out_des[trial, :, node])
            ax[2].plot(out_readout[trial, :, node])
    # draw the window lines once, not once per channel
    ax[2].axvline(test_start, linestyle='--', color='gray')
    ax[2].axvline(test_end, linestyle='--', color='gray')
    ax[0].set_ylabel('input')
    ax[1].set_ylabel('target')
    ax[2].set_ylabel('readout')

In [None]:
# trial 1 for every loaded model
for entry in models_dict.values():
    plot_trial_x(entry['output'], trial=1)

In [None]:
## recurrent weight matrix of one model, with its 5 most active hidden units outlined
# keep the label in its own variable: assigning the string to `model` clobbered the model
# object selected earlier, which later cells use as model.rnn.*
model_label = 'HePlas3'

W_matrix = models_dict[model_label]['model'].rnn.W.detach().cpu().numpy()

out_hidden = models_dict[model_label]['hidden activity']
meanHout = out_hidden.mean(dim=(0,1)).detach().cpu().numpy()  # per-unit mean over dims 0 and 1
sorted_inds = np.argsort(meanHout)  #get sorted indices (ascending)
high_act_inds = [i.item() for i in sorted_inds[-5:]] #last elements are highest

plt.imshow(W_matrix, origin='lower')
# one tick per column/row of W instead of a hard-coded 20
plt.xticks(range(W_matrix.shape[1]))
plt.yticks(range(W_matrix.shape[0]))

# outline the row and column of each highly active unit
for node in high_act_inds:
    plt.axhline(y=node-0.5, color='w')
    plt.axhline(y=node+0.5, color='w')
    plt.axvline(x=node-0.5, color='w')
    plt.axvline(x=node+0.5, color='w')

plt.show()

In [None]:
# get the 10 most / least active hidden nodes on every trial (repeats across trials kept)

# average across trial time; detach (and move to CPU) so numpy can sort it -- np.argsort on a
# tensor that still requires grad raises "Can't call numpy() on Tensor that requires grad"
meanHout = out_hidden.mean(dim=1).detach().cpu()

# ascending order within each trial's row: the last 10 columns are that trial's most active
# nodes, the first 10 its least active
sorted_inds = np.argsort(meanHout.numpy(), axis=1)
active_nodes = sorted_inds[:, -10:].ravel().tolist()
inactive_nodes = sorted_inds[:, :10].ravel().tolist()

In [None]:
#plot activity,x,u for each node for trial x
# black: out_hidden unit i; red: w_hidden channel i (x2.5); blue: w_hidden channel i+20 (x25)
# NOTE(review): out_hidden may have been reassigned to a single model's activity earlier while
# w_hidden is left from the last model run in the evaluation loop -- confirm they match

trial = 1

f, ax = plt.subplots(20, figsize=(5,20), sharey=True)

with torch.no_grad():
    for i in range(20):
        traces = (
            (out_hidden[trial, :, i], 'k'),
            (w_hidden[trial, :, i] * 2.5, 'r'),
            (w_hidden[trial, :, i + 20] * 25, 'b'),
        )
        for trace, colour in traces:
            ax[i].plot(trace, color=colour)

In [None]:
#plot hidden activity for a single trial (trial 0 by default)

# the original `def plot_hidden(out_hidden)` line had no colon and no body, a SyntaxError
# that broke the whole cell; this defines the function properly and draws the intended plot
def plot_hidden(out_hidden, trial=0, n_nodes=20):
    '''
    Plot out_hidden[trial, :, i] for hidden units i = 0..n_nodes-1, one subplot per unit.
    '''
    # squeeze=False keeps `ax` 2-D even when n_nodes == 1
    f, ax = plt.subplots(n_nodes, figsize=(5,10), sharey=True, squeeze=False)
    with torch.no_grad():
        for i in range(n_nodes):
            ax[i, 0].plot(out_hidden[trial, :, i])

plot_hidden(out_hidden)

In [None]:
#plot hidden x activity for trial 0
# first 20 channels of w_hidden, one subplot per channel

f, ax = plt.subplots(20, figsize=(5,10), sharey=True)

with torch.no_grad():
    for i, panel in enumerate(ax):
        panel.plot(w_hidden[0, :, i])

In [None]:
# display the dimensions of w_hidden (the cell's last expression is shown as output)
w_hidden.size()

In [None]:
#plot hidden u activity for trial 0
# w_hidden channels 20-39, one subplot per channel

f, ax = plt.subplots(20, figsize=(5,10), sharey=True)

with torch.no_grad():
    trial0_w = w_hidden[0]
    for i in range(20):
        ax[i].plot(trial0_w[:, 20 + i])

In [None]:
#view output for each sample-delay combo, averaged across all trials
# NOTE(review): out_readout here is the readout of whichever model ran last in the evaluation loop

# grid sized from the data instead of a hard-coded 2x5; squeeze=False keeps ax 2-D
f, ax = plt.subplots(num_samps, len(unique_delay_times), figsize=(20,5), squeeze=False)
with torch.no_grad():

    for sample in range(num_samps):
        for i, delay in enumerate(unique_delay_times):
            inds = torch.where((y == sample) & (test_on == delay))[0]
            test_start = float(delay)
            panel = ax[sample, i]

            # one plot() call draws every readout channel; the old per-channel loop
            # redrew all of them (and the marker lines) once per channel
            panel.plot(out_readout[inds].mean(0))
            panel.axvline(1000/dt_ann, linestyle='--', color='gray') #sample on
            panel.axvline(1500/dt_ann, linestyle='--', color='gray') #sample off
            panel.axvline(test_start, linestyle='--', color='gray') #test on
            panel.axvline(500/dt_ann+test_start, linestyle='--', color='gray') #test off

In [None]:
#plots trial-averaged hidden activity for each trial type

def plot_all_hidden(out_hidden):
    '''
    Plot trial-averaged hidden activity for each (sample, delay) trial type.

    One panel per combination: rows are samples 0..num_samps-1, columns follow
    unique_delay_times. Each panel shows out_hidden[inds].mean(0), drawn as one
    line per unit. Relies on the notebook globals y, test_on, num_samps and
    unique_delay_times.
    '''
    # grid sized from the data instead of a hard-coded 2x5; squeeze=False keeps ax 2-D
    f, ax = plt.subplots(num_samps, len(unique_delay_times), figsize=(20,5), squeeze=False)
    with torch.no_grad():

        for sample in range(num_samps):
            for i, delay in enumerate(unique_delay_times):
                inds = torch.where((y == sample) & (test_on == delay))[0]

                # a single plot() call draws every unit; the old per-unit loop redrew the
                # whole set once per unit (N identical copies), which made this slow
                ax[sample, i].plot(out_hidden[inds].mean(0))

In [None]:
# trial-type averages of hidden activity, one figure per loaded model
for entry in models_dict.values():
    plot_all_hidden(entry['hidden activity'])

In [None]:
# display the shape of out_readout (from the last model run in the evaluation loop)
out_readout.size()

In [None]:
## concatenate hidden layer activity for networks of the same type along dim 2 (units);
## author's note: result is 1024 trials * 434 timepoints * 60 nodes
plas_hidden, dep_hidden = (
    torch.cat([models_dict[label]['hidden activity'] for label in labels], dim=2)
    for labels in (plas_nets, dep_nets)
)

In [None]:
# trial-type averages for the concatenated HePlas hidden units
plot_all_hidden(plas_hidden)

In [None]:
# trial-type averages for the concatenated HeDep hidden units
plot_all_hidden(dep_hidden)

In [None]:
## sample activity: mean over every element in the sample window [samp_on, samp_off) along dim 1
print(plas_hidden[:, samp_on:samp_off, :].mean())
dep_hidden[:, samp_on:samp_off, :].mean()  # bare last expression: displayed as the cell's output

In [None]:
# choose which concatenated activity (plas_hidden or dep_hidden) the next cell summarizes
hidden = plas_hidden

In [None]:
## mean activity of `hidden` during the delay and test periods, averaged over delay lengths
# size the arrays from the data rather than assuming exactly 5 delays
n_delays = len(unique_delay_times)
mean_delay_act = np.zeros(n_delays)
mean_test_act = np.zeros(n_delays)
test_len = int(500/dt_ann)  # test window length in time steps

with torch.no_grad():
    for n, delay in enumerate(unique_delay_times):

        #find trials with this delay length
        inds = torch.where(test_on == delay)[0].tolist()
        test_start = int(delay)

        #split up activity by task event: delay = sample offset -> test onset, test = test window
        delay_activity = hidden[inds, samp_off:test_start, :]
        test_activity = hidden[inds, test_start:test_start+test_len, :]

        mean_delay_act[n] = delay_activity.mean().item()
        mean_test_act[n] = test_activity.mean().item()

print('mean delay activity:', mean_delay_act.mean().item())
print('mean test activity:', mean_test_act.mean().item())

plastic (`hidden = plas_hidden`, HePlas1-3 concatenated):
- mean delay activity: 0.0628472201526165
- mean test activity: 0.4721239745616913

depressing (`hidden = dep_hidden`, HeDep1-3 concatenated):
- mean delay activity: 0.06712958887219429
- mean test activity: 0.4595801293849945

In [None]:
# per-delay, per-sample trial-averaged traces, each expressed relative to its value at time
# index int(1000/dt_ann) (the sample-onset step), i.e. change since sample onset:
#   act_neur[d][s] -> out_hidden averaged over the trials of delay d / sample s, baselined
#   act_syn[d][s]  -> the same for w_hidden
with torch.no_grad():
    unique_delay_times = torch.unique(test_on)
    onset_step = int(1000/dt_ann)

    def _baselined_mean(acts, trial_inds):
        avg = acts[trial_inds].mean(0)
        return avg - avg[onset_step]

    act_neur, act_syn = [], []
    for i in unique_delay_times:
        groups = [torch.where((y == j) & (test_on == i))[0] for j in range(num_samps)]
        act_neur.append([_baselined_mean(out_hidden, g) for g in groups])
        act_syn.append([_baselined_mean(w_hidden, g) for g in groups])

In [None]:
# plain Python numbers (usable as dict keys) instead of 0-d tensors
unique_delay_times = torch.unique(test_on).tolist()
samps = list(range(2))

In [None]:
#create trial type dictionary: trial_d[sample][delay] -> {'inds': indices of matching trials}
trial_d = {samp: dict.fromkeys(unique_delay_times) for samp in (0, 1)}

with torch.no_grad():
    for delay in unique_delay_times:
        for samp in range(num_samps):
            trial_d[samp][delay] = {'inds': torch.where((y == samp) & (test_on == delay))[0]}
trial_d

In [None]:
# for each (sample, delay) trial group, record which hidden nodes rank among the 10 highest /
# 10 lowest time-averaged activations on at least one trial of that group
with torch.no_grad():

    for delay in unique_delay_times:
        for samp in range(num_samps):

            group = trial_d[samp][delay]
            # trials * nodes, averaged across trial time, then ranked ascending row by row
            per_trial_mean = out_hidden[group['inds'],:,:].mean(dim=1).numpy()
            ranked = np.argsort(per_trial_mean, axis=1)

            group['active_nodes'] = set(ranked[:, -10:].ravel().tolist())
            group['inactive_nodes'] = set(ranked[:, :10].ravel().tolist())

for samp in range(num_samps):
    for delay in unique_delay_times:
        print('samp:', samp)
        print('delay:', delay)
        print('active_nodes:', trial_d[samp][delay]['active_nodes'])
        print('')

In [None]:
# display the dimensions of w_hidden (the cell's last expression is shown as output)
w_hidden.size()

In [None]:
# time * nodes: hidden activity averaged over dim 0 (trials), units on the y axis
with torch.no_grad():
    trial_avg = out_hidden.mean(dim=0)
    plt.figure(figsize=(10,2))
    plt.imshow(trial_avg.T, cmap='hot', interpolation='nearest')
    plt.show()

In [None]:
# time * nodes (same heatmap as the cell above)
with torch.no_grad():
    plt.figure(figsize=(10, 2))
    plt.imshow(out_hidden.mean(0).t(), interpolation='nearest', cmap='hot')
    plt.show()

In [None]:
# trials * nodes: activity averaged over dim 1 (time), one column per trial
with torch.no_grad():
    plt.figure(figsize=(15,2))
    plt.imshow(torch.mean(out_hidden, dim=1).T, cmap='hot', interpolation='nearest')
    plt.show()

In [None]:
# trials * time: activity averaged over dim 2 (units), one column per trial
with torch.no_grad():
    unit_avg = out_hidden.mean(dim=2)
    plt.figure(figsize=(10,3))
    plt.imshow(unit_avg.T, cmap='hot', interpolation='nearest')
    plt.show()

In [None]:
#get trial indices sorted by delay, separately for sample 0 and sample 1
sorted_trial_inds0 = []
sorted_trial_inds1 = []
for delay in unique_delay_times: #sort by delay length
    sorted_trial_inds0 += trial_d[0][delay]['inds'].tolist()
    sorted_trial_inds1 += trial_d[1][delay]['inds'].tolist()

# detach before cloning (no graph bookkeeping for the copy), then gather trials with one
# index op instead of stacking them one at a time (this also copes with an empty index list)
sorted_out_hidden = out_hidden.detach().clone()
sorted_out_hidden0 = sorted_out_hidden[sorted_trial_inds0]
sorted_out_hidden1 = sorted_out_hidden[sorted_trial_inds1]
all_sorted_hidden = torch.cat((sorted_out_hidden0, sorted_out_hidden1))

In [None]:
# trials * time, trials ordered by sample then delay length, activity averaged over units
with torch.no_grad():
    unit_avg = all_sorted_hidden.mean(dim=2)
    plt.figure(figsize=(10,3))
    plt.imshow(unit_avg.T, cmap='hot', interpolation='nearest')
    plt.show()

In [None]:
# time * nodes: trial-averaged activity of each group's most active nodes,
# one panel per (sample, delay) combination
with torch.no_grad():

    # grid sized from the data instead of a hard-coded 2x5; squeeze=False keeps ax 2-D
    f, ax = plt.subplots(num_samps, len(unique_delay_times), figsize=(15,5),
                         sharex=True, sharey=True, squeeze=False)

    for samp in range(num_samps):
        for x, delay in enumerate(unique_delay_times):

            group = trial_d[samp][delay]
            act = out_hidden[group['inds']]  # gather this group's trials in one index op
            active_nodes = sorted(group['active_nodes'])  # ascending node order on the y axis
            panel = ax[samp, x]
            panel.imshow(act[:, :, active_nodes].mean(dim=0).T, cmap='hot', interpolation='nearest')
            panel.set_aspect('auto') #so x axis doesn't get squished
            panel.set_title(f'sample {samp}, delay {delay}')

In [None]:
# time * nodes: sample-0 minus sample-1 trial-averaged activity for each delay,
# restricted to nodes that are among the most active for either sample
with torch.no_grad():

    f, ax = plt.subplots(1, len(unique_delay_times), figsize=(15,3),
                         sharex=True, sharey=True, squeeze=False)

    for x, delay in enumerate(unique_delay_times):

        act0 = out_hidden[trial_d[0][delay]['inds']]
        act1 = out_hidden[trial_d[1][delay]['inds']]
        all_active = sorted(trial_d[0][delay]['active_nodes'] | trial_d[1][delay]['active_nodes'])

        # trials of the two samples are not paired, so compare their means over ALL trials;
        # the old version truncated both to min(len)-1 trials, silently dropping data
        diff = act0.mean(dim=0) - act1.mean(dim=0)  # dims: time, units
        sub = diff[:, all_active]

        # symmetric colour limits so white in the diverging 'bwr' map means zero difference
        lim = sub.abs().max().item() if sub.numel() else 1.0
        panel = ax[0, x]
        panel.imshow(sub.T, cmap='bwr', interpolation='nearest', vmin=-lim, vmax=lim)
        panel.set_aspect('auto') #so x axis doesn't get squished
        panel.set_title(str(delay))

In [None]:
# time * nodes: sample-0 minus sample-1 trial-averaged activity for each delay, all units
with torch.no_grad():

    f, ax = plt.subplots(1, len(unique_delay_times), figsize=(15,3),
                         sharex=True, sharey=True, squeeze=False)

    for x, delay in enumerate(unique_delay_times):

        act0 = out_hidden[trial_d[0][delay]['inds']]
        act1 = out_hidden[trial_d[1][delay]['inds']]

        # trials of the two samples are not paired, so compare their means over ALL trials;
        # the old version truncated both to min(len)-1 trials, silently dropping data
        diff = act0.mean(dim=0) - act1.mean(dim=0)  # dims: time, units

        # symmetric colour limits so white in the diverging 'bwr' map means zero difference
        lim = diff.abs().max().item() if diff.numel() else 1.0
        panel = ax[0, x]
        panel.imshow(diff.T, cmap='bwr', interpolation='nearest', vmin=-lim, vmax=lim)
        panel.set_aspect('auto') #so x axis doesn't get squished
        panel.set_title(str(delay))

In [None]:
# sum of model.rnn.D over axis 0. Per the author's note each entry is +1 (excitatory) or
# -1 (inhibitory) -- assumes D is a per-unit sign matrix, TODO confirm
E_I = model.rnn.D.sum(axis=0)

In [None]:
# model.rnn.facil_syn_inds -- presumably indices of the facilitating synapses/units (verify);
# not used by the visible cells, whose masks hard-code a split at index 50
F_inds = model.rnn.facil_syn_inds

In [None]:
# model.rnn.depress_syn_inds -- presumably indices of the depressing synapses/units (verify);
# not used by the visible cells, whose masks hard-code a split at index 50
D_inds = model.rnn.depress_syn_inds

In [None]:
# boolean unit masks: True where E_I is +1 / -1 (excitatory / inhibitory per the E/I note)
ex_bool = E_I==1
in_bool = E_I==-1

In [None]:
# excitatory units with index >= 50 (the first 50 masked out). Built fresh from E_I:
# `ex_fac_bool = ex_bool` only aliased the tensor, so the in-place edit on ex_bool in the
# next cell also wiped this mask.
# NOTE(review): the split at 50 is hard-coded; F_inds / D_inds may be the intended source
ex_fac_bool = E_I == 1
ex_fac_bool[:50] = False

In [None]:
# excitatory units with index < 50. Built fresh from E_I rather than from ex_bool, which the
# previous cell mutated in place (so the old result was always all-False), and without
# aliasing so later edits cannot change this mask.
# NOTE(review): the split at 50 is hard-coded; F_inds / D_inds may be the intended source
ex_dep_bool = E_I == 1
ex_dep_bool[50:] = False

In [None]:
# inhibitory units with index >= 50 (the first 50 masked out). Built fresh from E_I:
# `in_fac_bool = in_bool` only aliased the tensor, so the in-place edit on in_bool in the
# next cell also wiped this mask.
# NOTE(review): the split at 50 is hard-coded; F_inds / D_inds may be the intended source
in_fac_bool = E_I == -1
in_fac_bool[:50] = False

In [None]:
# inhibitory units with index < 50. Built fresh from E_I rather than from in_bool, which the
# previous cell mutated in place (so the old result was always all-False), and without
# aliasing so later edits cannot change this mask.
# NOTE(review): the split at 50 is hard-coded; F_inds / D_inds may be the intended source
in_dep_bool = E_I == -1
in_dep_bool[50:] = False

In [None]:
#get trial indices sorted by delay, separately for sample 0 and sample 1
sorted_trial_inds0 = []
sorted_trial_inds1 = []
for delay in unique_delay_times: #sort by delay length
    sorted_trial_inds0 += trial_d[0][delay]['inds'].tolist()
    sorted_trial_inds1 += trial_d[1][delay]['inds'].tolist()

# detach before cloning (no graph bookkeeping for the copy), then gather trials with one
# index op instead of stacking them one at a time (this also copes with an empty index list)
sorted_out_hidden = out_hidden.detach().clone()
sorted_out_hidden0 = sorted_out_hidden[sorted_trial_inds0]
sorted_out_hidden1 = sorted_out_hidden[sorted_trial_inds1]
all_sorted_hidden = torch.cat((sorted_out_hidden0, sorted_out_hidden1))

In [None]:
# trials * time, sorted by sample and delay length, averaged over a chosen subset of units.
# E_I_bool was used here but never defined (NameError). Define the unit mask explicitly:
# excitatory units by default -- swap in ex_fac_bool, ex_dep_bool, in_fac_bool or in_dep_bool.
E_I_bool = E_I == 1
with torch.no_grad():
    plt.figure(figsize=(10,3))
    plt.imshow(all_sorted_hidden[:,:,E_I_bool].mean(dim=2).T, cmap='hot', interpolation='nearest')
    plt.show()

In [None]:
# display model.rnn.weight_ho (presumably the hidden-to-output weights -- verify in networks.py)
model.rnn.weight_ho