In [None]:
import numpy as np
from torch.utils.data import DataLoader
import os, sys
sys.path.append(os.getcwd()+"/..")
from rnn_scripts.train import *
from rnn_scripts.utils import *
from tasks.seqDS import *

pltcolors = red_yellow_colours()
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Load models
# Two trained networks are loaded from ../models and compared in the cells below.

# Directories are built relative to the current working directory; fig_dir keeps
# its trailing slash because the plotting cell concatenates a file name onto it.
fig_dir=os.getcwd()+"/../figures/"
model_dir = os.getcwd()+"/../models/"

model1="N512_T0221-113711"
model2 = "N512_T0217-151523"

# load_rnn returns the network together with its three parameter collections.
rnn1,params1,task_params1,training_params1 = load_rnn(model_dir+model1)
rnn1.rnn.svd_orth()  # NOTE(review): presumably re-orthogonalises the weight factors via SVD -- confirm in rnn_scripts
# Fix: the second model's parameters used to be unpacked into `params1`,
# silently overwriting the first model's parameters.
rnn2,params2,task_params2,training_params2 = load_rnn(model_dir+model2)
rnn2.rnn.svd_orth()

In [None]:
# Setup task and run models
# Both networks are evaluated on the same batch, drawn from the task built
# from the first model's task parameters.
BATCH_SIZE = 128  # number of trials requested from the loader

ds = seqDS(task_params1)

dataloader = DataLoader(
    ds, batch_size=BATCH_SIZE, shuffle=True
)
test_input, test_target, test_mask = next(iter(dataloader))
labels = extract_labels(test_input)
rates1, _ = predict(rnn1, test_input, mse_loss, test_target, test_mask)
rates2, _ = predict(rnn2, test_input, mse_loss, test_target, test_mask)

# Trial indices per label. The index range is taken from the batch that was
# actually returned, so a dataset smaller than BATCH_SIZE no longer causes a
# boolean-mask length mismatch.
trial_ids = np.arange(test_input.shape[0])
ind0 = trial_ids[labels==0]
ind1 = trial_ids[labels==1]

In [None]:
# Calculate phase and rate statistics
# For every unit of each network: the difference between label 1 and label 0
# trials in (a) the circular-mean phase and (b) the mean rate, both measured
# over one cycle taken from the end of the trial.

# Length of one cycle in time steps.
# NOTE(review): assumes freq is in Hz and dt in ms -- confirm against the task definition.
period = 1000/(task_params1['freq']*task_params1['dt'])
qp = int(period/4)  # quarter period, used to read the input a quarter cycle later
period=int(period)
# One cycle of sine / cosine templates used to project the rates onto.
sin = np.sin(np.linspace(0,np.pi*2,period))
cos = np.cos(np.linspace(0,np.pi*2,period))

# Per-trial phase estimated from input channel 0 at t=0 and at t=qp, then
# converted to an integer offset in [1, period+1] (always >= 1, so the
# negative slice end below never becomes -0).
phase= np.arctan2(test_input[:,0,0],test_input[:,qp,0]).cpu().numpy()
phase_int=np.int_(((phase+np.pi)/(np.pi*2))*period)+1


def _cycle_phase_and_mean(rates, trial, unit):
    """Return (phase angle, mean rate) of one unit on one trial.

    The statistics are computed on a window of `period` samples that ends
    `phase_int[trial]` steps before the end of the trial. The angle is that
    of the mean-subtracted window projected onto the `cos` (real part) and
    `sin` (imaginary part) templates.
    Reads the cell-level names `period`, `phase_int`, `sin` and `cos`.
    """
    offset = phase_int[trial]
    window = rates[trial, -period-offset:-offset, unit]
    mean = np.mean(window)
    centered = window - mean
    angle = np.angle(1j*np.inner(sin, centered) + np.inner(cos, centered))
    return angle, mean


def _label_contrast(rates, unit):
    """Return (phase difference, mean-rate difference) between labels for one unit.

    The phase difference is circ_dif(circular mean over `ind1` trials,
    circular mean over `ind0` trials); the mean-rate difference is
    mean over `ind0` trials minus mean over `ind1` trials.
    """
    angs0, means0 = [], []
    for trial in ind0:
        angle, mean = _cycle_phase_and_mean(rates, trial, unit)
        angs0.append(angle)
        means0.append(mean)

    angs1, means1 = [], []
    for trial in ind1:
        angle, mean = _cycle_phase_and_mean(rates, trial, unit)
        angs1.append(angle)
        means1.append(mean)

    ang_diff = circ_dif(circ_mean(angs1)[0], circ_mean(angs0)[0])
    mean_diff = np.mean(means0) - np.mean(means1)
    return ang_diff, mean_diff


angs_dist1=[]
means_dist1 = []
angs_dist2=[]
means_dist2 = []

# The unit count is taken from the first network and used for both.
for ni in range(rates1.shape[-1]):

    #RNN 1
    ang_diff, mean_diff = _label_contrast(rates1, ni)
    angs_dist1.append(ang_diff)
    means_dist1.append(mean_diff)

    #RNN 2
    ang_diff, mean_diff = _label_contrast(rates2, ni)
    angs_dist2.append(ang_diff)
    means_dist2.append(mean_diff)


In [None]:
def rand_jitter(arr, stdev, rng=None):
    """Add Gaussian noise to `arr`, for the jittered effect in the scatter plots.

    Parameters
    ----------
    arr : array-like with a length
        Values to jitter; one noise sample is drawn per element along the
        first axis.
    stdev : float
        Standard deviation of the added noise.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. When omitted, numpy's global random state is
        used, so the result is only reproducible after `np.random.seed`.

    Returns
    -------
    `arr` plus zero-mean normal noise scaled by `stdev`.
    """
    if rng is None:
        noise = np.random.randn(len(arr))
    else:
        noise = rng.standard_normal(len(arr))
    return arr + noise * stdev

In [None]:
# Make scatter plots of the statistics
# Left panel: per-unit mean rate difference; right panel: per-unit phase
# difference. Each network gets its own jittered column (x=1 and x=2).

stdev=0.1  # horizontal jitter width
x1=rand_jitter(np.ones_like(angs_dist1),stdev)
x2=rand_jitter(np.ones_like(angs_dist2)*2,stdev)

color1 = 'darkred'
color2 = 'orange'
vbarcolor = 'grey'
alpha = 0.05  # points are nearly transparent so that density shows up
fig,ax =plt.subplots(1,2,figsize=(2.5,1.5))

ax[1].scatter(x1,angs_dist1,c=color1 ,alpha=alpha,label='alternative solution')
ax[1].scatter(x2,angs_dist2,c=color2,alpha=alpha,label='main text solution')
ax[1].axhline(0,ls='--',color =vbarcolor)
# Fix the tick positions first, then label them: labelling before the ticks
# are fixed attaches the labels to auto-located ticks and triggers a warning.
ax[1].set_yticks([-np.pi,0,np.pi])
ax[1].set_yticklabels([r'$-\pi$',r'$0$',r'$\pi$'])
ax[1].set_ylim(-np.pi*1.1,np.pi*1.1)
ax[1].spines['bottom'].set_visible(False)
ax[1].set_xticks([])

leg = ax[1].legend(bbox_to_anchor=(1,1.18))
# Legend markers inherit the low scatter alpha; make them opaque.
# `legendHandles` was renamed to `legend_handles` in matplotlib 3.7 and later
# removed, so use whichever attribute this version provides.
if hasattr(leg, "legend_handles"):
    legend_markers = leg.legend_handles
else:
    legend_markers = leg.legendHandles
for lh in legend_markers:
    lh.set_alpha(1)

ax[0].scatter(x1,means_dist1,c=color1 ,alpha=alpha)
ax[0].scatter(x2,means_dist2,c=color2,alpha=alpha)
ax[0].axhline(0,ls='--',color =vbarcolor)
ax[0].set_yticks([-12,0,12])
ax[0].set_ylim([-12,12])
ax[0].spines['bottom'].set_visible(False)
ax[0].set_xticks([])
ax[1].set_title("phase difference")
ax[0].set_title("mean rate difference")

plt.savefig(fig_dir + "sol_comp.svg")