# Model2 - RNN model of the CPRO task using PyTorch, with trial dynamics

#### Taku Ito
#### 09/30/2018

In [1]:
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
np.set_printoptions(suppress=True)
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats
%matplotlib inline
import os
import model.model4 as model
model = reload(model)
import model.task as task
task = reload(task)
import time
os.sys.path.append('utils/bctpy/')
import bct
retrain = True

# Build training and test batches

In [169]:
# Reload the model module so edits are picked up, then generate every trial
# batch via TrialBatches.createAllBatches using a pool of worker processes.
model = reload(model)
n_workers = 25  # handed straight through as createAllBatches(nproc=...)
TrialInfo = model.TrialBatches()
TrialInfo.createAllBatches(nproc=n_workers)

Running batch 0
Running batch 1500
Running batch 3000
Running batch 4500
Running batch 6000
Running batch 7500
Running batch 9000
Running batch 10500
Running batch 12000
Running batch 13500
Running batch 15000
Running batch 16500
Running batch 18000
Running batch 19500
Running batch 21000
Running batch 22500
Running batch 24000
Running batch 25500
Running batch 28500
Running batch 27000
Running batch 30000
Running batch 33000
Running batch 31500
Running batch 34500
Running batch 36000
Running batch 27100
Running batch 28600
Running batch 25600
Running batch 10600
Running batch 18100
Running batch 24100
Running batch 100
Running batch 31600
Running batch 36100
Running batch 21100
Running batch 15100
Running batch 30100
Running batch 33100
Running batch 16600
Running batch 7600
Running batch 3100
Running batch 6100
Running batch 12100
Running batch 19600
Running batch 13600
Running batch 34600
Running batch 1600
Running batch 22600
Running batch 9100
Running batch 4600
Running batch 2720

# First, train the RNN on a subset of the tasks (half of them) using the CPU

In [171]:
# Load the training batches and the held-out test set, keeping tensors on the CPU.
model = reload(model)
use_cuda = False
input_batches, output_batches = model.load_training_batches(cuda=use_cuda)
test_inputs, test_outputs = model.load_testset(cuda=use_cuda)


In [187]:
model = reload(model)
# RNN hyperparameters. The total input width is
# num_rule_inputs + num_sensory_inputs = 12 + 16 = 28.
rnn_kwargs = {
    'num_rule_inputs': 12,
    'num_sensory_inputs': 16,
    'num_hidden': 128,
    'num_motor_decision_outputs': 4,
    'learning_rate': 0.0001,
    'thresh': 0.9,
}
# Construct the network and place it on the CPU.
Network = model.RNN(**rnn_kwargs).cpu()



In [188]:
model = reload(model)
# Train on the preloaded batches (CPU only) and report wall-clock time.
timestart = time.time()
model.batch_training(Network, input_batches, output_batches, cuda=False)
timeend = time.time()
# A single parenthesised expression prints the same text under Python 2
# (str() of the float, as the print statement did) and also parses under Python 3.
print('Time elapsed using CPU: %s' % (timeend - timestart))

Iteration: 0
	loss: 0.253582566977
Time elapsed... 0.0115549564362
	Accuracy: 0.0%
Iteration: 5000
	loss: 0.0899647250772
Time elapsed... 11.0452830791
	Accuracy: 0.0%
Iteration: 10000
	loss: 0.0598498694599
Time elapsed... 11.691521883
	Accuracy: 8.3333%
Iteration: 15000
	loss: 0.0687529146671
Time elapsed... 12.3160851002
	Accuracy: 29.1667%
Iteration: 20000
	loss: 0.0398500859737
Time elapsed... 10.9046399593
	Accuracy: 31.25%
Iteration: 25000
	loss: 0.0239020306617
Time elapsed... 10.999475956
	Accuracy: 58.3333%
Iteration: 30000
	loss: 0.00815108790994
Time elapsed... 10.4091379642
	Accuracy: 68.75%
Iteration: 35000
	loss: 0.00138975447044
Time elapsed... 10.4073472023
	Accuracy: 95.8333%
Iteration: 40000
	loss: 0.00036781351082
Time elapsed... 11.4844689369
	Accuracy: 97.9167%
Iteration: 45000
	loss: 0.000203750460059
Time elapsed... 11.3463718891
	Accuracy: 100.0%
Iteration: 50000
	loss: 5.77190912736e-05
Time elapsed... 11.5179691315
	Accuracy: 100.0%
Last 1000 batches had abov

# Test on held-out tasks

In [196]:
model = reload(model)
# Evaluate the trained network on the held-out test set; keep the outputs and
# hidden-state trajectories for the plots below.
outputs, hidden = model.eval(
    Network,
    test_inputs,
    test_outputs,
    cuda=False,
)

	loss: 0.100028008223
	Accuracy: 37.5%


# Plot example miniblocks

In [None]:
# Detach from the autograd graph and convert to NumPy for plotting.
outputs2, hidden2, inputs2 = (
    tensor.detach().numpy() for tensor in (outputs, hidden, test_inputs)
)

In [None]:
def _plot_unit_heatmap(activity, title, ylabel):
    """Draw a heatmap of the first sequence in ``activity`` (time on x, units on y).

    activity is indexed as activity[:, 0, :], i.e. presumably laid out as
    (time, batch, units) -- TODO confirm against model.eval. Returns the heatmap Axes.
    """
    plt.figure(figsize=(10,5))
    plt.title(title, fontsize=24)
    plt.ylabel(ylabel, fontsize=20)
    plt.xlabel('Time', fontsize=20)
    ax = sns.heatmap(activity[:, 0, :].T, cmap='Blues')
    ax.invert_yaxis()  # put unit 0 at the bottom of the plot
    return ax

# One figure per stage of the network; each y-axis now names what it shows
# (previously the hidden and output plots were both labelled 'Input').
_plot_unit_heatmap(inputs2, 'Input sequence', 'Input')
_plot_unit_heatmap(hidden2, 'Hidden unit activity', 'Hidden units')
a = _plot_unit_heatmap(outputs2, 'Output unit activity', 'Output units')

# Simulate spontaneous (rest) activity in the network

In [None]:
model = reload(model)
# Rest simulation: drive the trained network with all-zero input for ntps time
# steps and record its hidden-state trajectory. Any variability in rest_hidden
# must then come from the network itself (e.g. internal noise, if model.RNN
# injects any -- TODO confirm).
ntps = 1000
runs = 1
input_dim = 28  # 12 rule + 16 sensory inputs, as passed to model.RNN
# torch.zeros, not torch.empty: torch.empty returns *uninitialized* memory, so
# the "rest" input and the dummy target would hold arbitrary (possibly
# non-finite) values that change from run to run.
rest_time = torch.zeros(ntps, runs, input_dim)
output_dummy = torch.zeros(ntps, runs, 4)  # output dummy variable (4 motor outputs), rest so everything is 0
rest_output, rest_hidden = model.eval(Network, rest_time, output_dummy, cuda=False)

# Drop singleton axes (the runs=1 axis) for plotting/correlation.
rest_output = np.squeeze(rest_output.detach().numpy())
rest_hidden = np.squeeze(rest_hidden.detach().numpy())

In [None]:
# Heatmap of the simulated rest-period hidden activity (units on y, time on x).
plt.figure(figsize=(10,5))
ax = sns.heatmap(rest_hidden.T, cmap='Blues')
ax.set_title('Simulated random activity', fontsize=20)
ax.invert_yaxis()
ax.set_xlabel('Time', fontsize=18)
ax.set_ylabel('Units', fontsize=18)

# Construct the unit-by-unit correlation matrix of rest activity and partition
# units into communities with Louvain modularity (bctpy).
corrmat = np.corrcoef(rest_hidden.T)
# Units whose activity is constant over the rest period have zero variance, and
# np.corrcoef returns NaN for their rows/columns; treat them as uncorrelated so
# community detection receives a finite matrix.
corrmat = np.nan_to_num(corrmat)
sig = np.multiply(corrmat, corrmat > 0)  # keep positive correlations only
# bct.community_louvain uses NumPy's global RNG for its node orderings; seed it
# so the community assignment (and the sorted plots) are reproducible.
LOUVAIN_SEED = 0
np.random.seed(LOUVAIN_SEED)
ci, q = bct.community_louvain(sig)
# Unit indices sorted by community label (stable, so ties keep index order),
# stored as an (N, 1) column so that M[networkdef, networkdef.T] reorders
# both rows and columns of an N x N matrix.
networkdef = np.argsort(ci, kind='mergesort')
networkdef = networkdef.reshape(-1, 1)

plt.figure()
ax = sns.heatmap(corrmat[networkdef,networkdef.T],square=True,center=0,cmap='bwr')
ax.invert_yaxis()
plt.title('Correlation matrix',fontsize=24,y=1.04)
plt.xlabel('Regions',fontsize=20)
plt.ylabel('Regions',fontsize=20)
# plt.savefig('NoiseInduced_CorrelationMatrix.pdf')

## Analyze recurrent connectivity weights

In [None]:
# Recurrent (hidden-to-hidden) weights, with rows and columns reordered by the
# rest-state community assignment.
mat = Network.weight_hh_l0.detach().numpy()
community_order = networkdef.ravel()  # networkdef is the (N, 1) column of sorted unit indices
plt.figure()
ax = sns.heatmap(mat[np.ix_(community_order, community_order)], square=True, center=0, cmap='bwr')
ax.invert_yaxis()
ax.set_title('Recurrent connectivity weights', fontsize=24, y=1.04)
ax.set_xlabel('Regions', fontsize=20)
ax.set_ylabel('Regions', fontsize=20)
plt.tight_layout()
# plt.savefig('GroundTruth_RNN_weights.pdf')

In [None]:
# The same recurrent weight matrix in its original (unsorted) unit order, for
# comparison with the community-sorted view. The title and the commented-out
# savefig name differ from the sorted figure so the two are not confused and
# one file does not overwrite the other.
mat = Network.weight_hh_l0.detach().numpy()
plt.figure()
ax = sns.heatmap(mat,square=True,center=0,cmap='bwr')
ax.invert_yaxis()
plt.title('Recurrent connectivity weights (unsorted)',fontsize=24,y=1.04)
plt.xlabel('Regions',fontsize=20)
plt.ylabel('Regions',fontsize=20)
plt.tight_layout()
# plt.savefig('GroundTruth_RNN_weights_unsorted.pdf')