In [None]:
import scipy.io
import os
import sys
sys.path.insert(0,'..')
from rnn.model import RNN
from rnn.task import trial_generator
import numpy as np
from analysis.tf_utils import *
from analysis.analysis_utils import *
from scipy.stats import zscore
from itertools import permutations
import copy
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from matplotlib.animation import FuncAnimation
# from parula import Parula
from cycler import cycler
from analysis.summary_parallel import Summary
from matplotlib import gridspec
from matplotlib.ticker import StrMethodFormatter
from matplotlib.colors import colorConverter as cc

import pickle
%matplotlib inline

# Experiment settings

In [None]:
# Paths: base_dir is the experiment root relative to this notebook ("" = here);
# task_dir names the pickled stimulus sweep under ../data (a file, despite the name).
base_dir, task_dir = "", "datasweep_SOAs.pkl"

# Settings dict handed to Summary.run_summary (which returns an updated version).
summary_settings = dict()
# Primary and alternative plotting palettes from the project helpers.
pltcolors, pltcolors_alt = steffiscolours()

# Object that runs / aggregates the analyses of all trained models.
Sum_obj = Summary()

# Run or load summary over many models

In [None]:

# Run (or load, depending on Summary.run_summary) the per-model analyses for
# every trained model in model_dir on the stimulus sweep stored in data_dir.
calc_vex = False  # forwarded to Summary.run_summary
n_jobs = 5        # forwarded to Summary.run_summary (presumably the worker count)
model_dir = os.path.join(base_dir, "..", "models", "sweep_main")
# Built from base_dir exactly like model_dir (previously a hard-coded "../data/"
# prefix), so changing base_dir moves both locations consistently.
data_dir = os.path.join(base_dir, "..", "data", str(task_dir))
data_list, summary_settings = Sum_obj.run_summary(summary_settings, model_dir, data_dir, n_jobs=n_jobs, calc_vex=calc_vex)

mod_indices = np.arange(len(data_list))
print("stats from " + str(len(data_list)) + " analyses")

# Generate plots

In [None]:
# Plot oscillation frequency for all trained models
# This is for preselection, we only include models that oscillate at the loss frequency:
# a model is kept (green) when its measured oscillation frequency ("vex_f") lies
# within FREQ_MATCH_TOL_HZ of the frequency in its loss ("loss_f"); otherwise red.
FREQ_MATCH_TOL_HZ = 0.5

plt_indices = []
kept_xy, rejected_xy = [], []
for i in mod_indices:
    loss_freq = data_list[i]["loss_f"]
    osc_freq = data_list[i]["vex_f"]
    if np.isclose(loss_freq, osc_freq, atol=FREQ_MATCH_TOL_HZ):
        plt_indices.append(i)
        kept_xy.append((loss_freq, osc_freq))
    else:
        rejected_xy.append((loss_freq, osc_freq))

fig, ax = plt.subplots()
for points, colour in ((kept_xy, 'green'), (rejected_xy, 'red')):
    if points:
        xs, ys = zip(*points)
        ax.scatter(xs, ys, color=colour, alpha=0.5)
# x is the frequency in the loss, y the frequency the model actually oscillates at.
ax.set_xlabel("Loss frequency (Hz)")
ax.set_ylabel("Model oscillation frequency (Hz)")
ax.set_title("Oscillation frequency");


In [None]:
# Obtain proportions of phase orders exploited by models from each training frequency and ISI
phase_orders = []           # per (freq, ISI) cell: mean fraction of each phase order
phase_orders_v = []         # per cell: standard error of those fractions across models
phase_orders_counts = []    # per cell: summed raw counts of each phase order
freq_ISI = []               # per cell: "F: <freq>, ISI: <SOA ms>" label
freqInds = []               # per cell: row index into frs
ISIInds = []                # per cell: column index into ISIs
phase_order_list = []       # flat: one [order_index, freq, SOA_ms] entry per counted occurrence
phase_orders_raw_data = []  # per cell: list of each model's fraction vector

# The six orderings; position j of every phase_order vector is labelled perms[j] in the plots.
perms = np.array([[3, 1, 2], [1, 3, 2], [3, 2, 1], [2, 3, 1], [1, 2, 3], [2, 1, 3]])

ISIs = summary_settings["ISIs"]
frs = [1.5, 2.04, 2.75, 3.73]

for i_fr, fr in enumerate(frs):
    for i_ISI, ISI in enumerate(ISIs):
        soa_ms = ISI*10+200  # ISI value -> label in ms used throughout the figures
        cell_fractions = []
        cell_counts = []

        # Preselected models whose oscillation frequency (vex_f) is within 0.2 Hz
        # of fr and whose ISI matches this column.
        for i in plt_indices:
            if not np.isclose(data_list[i]["vex_f"], fr, atol=0.2):
                continue
            if not np.isclose(data_list[i]["ISI"], ISI, atol=.1):
                continue
            counts = np.array(data_list[i]["phase_order"])
            cell_fractions.append(counts/np.sum(counts))
            cell_counts.append(counts)
            for order_ind in range(6):
                phase_order_list.extend([[order_ind, fr, soa_ms]]*int(counts[order_ind]))

        # Summarise this (freq, ISI) cell only if at least one model fell into it.
        if cell_fractions:
            print("Including frequency " + str(fr) + ", \n ISI " + str(ISI) +  ", n = " + str(len(cell_fractions)))
            phase_orders.append(np.mean(cell_fractions, axis=0))
            phase_orders_v.append(np.std(cell_fractions, axis=0)/np.sqrt(len(cell_fractions)))
            phase_orders_counts.append(np.sum(cell_counts, axis=0))
            phase_orders_raw_data.append(cell_fractions)
            freq_ISI.append("F: " + str(fr) + ", ISI: " + str(soa_ms))
            freqInds.append(i_fr)
            ISIInds.append(i_ISI)



In [None]:
# Check how many phase orders we predict correctly
# For every included (frequency, ISI) cell, compare the phase order the RNNs
# use most often with the order returned by get_phase_order.
n_cols = len(ISIs)
n_rows = len(frs)
grid_shape = (n_rows, n_cols)

result = np.zeros(grid_shape)         # % of the cell's phase-order mass on the predicted order
prediction = np.zeros(grid_shape)     # predicted order index (into perms)
sign_and_corr = np.zeros(grid_shape)  # 1 where the most common RNN order == prediction
RNN_order = np.zeros(grid_shape)      # most common RNN order index (into perms)
# NOTE(review): despite its name, sign_and_corr only checks correctness; the
# chi-square p-value below is printed but not used as a criterion.

for k, (row, col) in enumerate(zip(freqInds, ISIInds)):
    soa_ms = ISIs[col]*10+200
    pred_i, pred = get_phase_order(frs[row], soa_ms)
    top_order = np.argmax(phase_orders[k])
    print("PREDICTION")
    print(pred_i, pred, frs[row], soa_ms)
    print("STATS")
    print(phase_orders_counts[k], scipy.stats.chisquare(phase_orders_counts[k])[1])
    result[row, col] = phase_orders[k][pred_i]*100
    prediction[row, col] = pred_i
    RNN_order[row, col] = top_order
    if top_order == pred_i:
        sign_and_corr[row, col] = 1
print(np.sum(sign_and_corr))


In [None]:
# Plot phase order as a function of frequency and ISI bars Fig S3
# One panel per (frequency, ISI) cell: mean percentage of each phase order across
# models with standard-error bars. Dashed grey line at 16% (about 100/6, presumably
# chance over the six orders); short solid line near the top marks the order
# predicted by get_phase_order.

colors =['mediumorchid',
         'orchid',
         'thistle',
         'lightskyblue',
         'slateblue',
         'midnightblue']
shcolors = [[0.8,0.8,0.8],[0.7,0.7,0.7],[0.6,0.6,0.6],[0.5,0.5,0.5]]
n_plots = n_rows*n_cols


def _grid_subplot_index(row, col):
    """Return the 1-based subplot index of grid cell (row, col).

    Row 0 (the lowest frequency in `frs`) is placed in the bottom row of the figure.
    """
    return (n_rows - 1 - row) * n_cols + col + 1


def plt_ind(i):
    """Return the 1-based subplot index of the i-th cell in row-major (frequency, ISI) order."""
    return _grid_subplot_index(i // n_cols, i % n_cols)


with mpl.rc_context(fname="matplotlibrc"):
    fig = plt.figure(figsize=(n_cols, n_rows))

    # Iterate over the cells actually collected rather than range(n_plots): a
    # (frequency, ISI) combination without any qualifying model is left blank
    # instead of raising IndexError or shifting every later panel.
    for i, (row, col) in enumerate(zip(freqInds, ISIInds)):
        pred_i, _ = get_phase_order(frs[row], ISIs[col]*10+200)
        ax1 = fig.add_subplot(n_rows, n_cols, _grid_subplot_index(row, col))
        ax1.bar(np.arange(len(perms)), phase_orders[i]*100, color=colors,
                yerr=phase_orders_v[i]*100,
                align='edge', ecolor=shcolors[-1], capsize=4)
        ax1.axhline(16, linestyle="--", color="grey")
        # xmin/xmax are axes fractions: a short segment over the predicted order.
        ax1.axhline(98, pred_i/6+0.04, (pred_i+1)/6-0.04, linestyle="-", color=pltcolors[-1], zorder=1000)

        ax1.set_ylim(0, 100)
        ax1.set_yticks([0, 50, 100])
        ax1.set_xticks(np.arange(len(perms)))
        ax1.tick_params(axis='x', labelrotation=60)
        if row == n_rows - 1:  # top row: column titles give the SOA
            ax1.set_title(str(ISIs[col]*10+200)+" ms")
        if row == 0:  # bottom row: phase-order labels
            ax1.set_xticklabels([str(perm)[1:-1] for perm in perms])
            ax1.set_xlabel("Order")
        else:
            ax1.set_xticklabels([])
        if col != 0:
            ax1.set_yticklabels([])
        else:  # first column: row label gives the frequency
            ax1.set_ylabel(str(frs[row])+" Hz")
            ax1.set_yticklabels(["","50%","100%"])

    fig.tight_layout()


fig.savefig("../figures/model_phaseorders_bars.pdf", facecolor='white')

In [None]:
# Plot phase order as a function of frequency and ISI boxplots Fig S3
# Same grid as the bar version, but each panel shows the across-model
# distribution of every phase order's percentage as a box plot.

colors =['mediumorchid',
         'orchid',
         'thistle',
         'lightskyblue',
         'slateblue',
         'midnightblue']
n_plots = n_rows*n_cols
box_alpha = .4  # fill transparency of the boxes

with mpl.rc_context(fname="matplotlibrc"):
    fig = plt.figure(figsize=(n_cols, n_rows))

    # Iterate over the cells actually collected rather than range(n_plots), so a
    # (frequency, ISI) combination without models is left blank instead of
    # raising IndexError or shifting later panels.
    for i, (row, col) in enumerate(zip(freqInds, ISIInds)):
        pred_i, _ = get_phase_order(frs[row], ISIs[col]*10+200)
        # Row 0 (lowest frequency) goes to the bottom row of the figure.
        ax1 = fig.add_subplot(n_rows, n_cols, (n_rows - 1 - row) * n_cols + col + 1)
        # One row per model, one column per phase order, in percent.
        model_pcts = np.array(phase_orders_raw_data[i])*100
        for j, c in enumerate(colors):
            ax1.boxplot(model_pcts[:, j], positions=[j], widths=.6, patch_artist=True,
                        boxprops=dict(facecolor=cc.to_rgba(c, alpha=box_alpha), color=c),
                        capprops=dict(color=c),
                        whiskerprops=dict(color=c),
                        medianprops=dict(color=c),
                        flierprops={'marker': 'o', 'markersize': 1, 'markerfacecolor': c, 'markeredgecolor': c})

        ax1.axhline(16, linestyle="--", color="grey")
        ax1.axhline(0, linestyle="--", color="black")
        # xmin/xmax are axes fractions: a short segment over the predicted order's box.
        ax1.axhline(100, pred_i/6+0.04, (pred_i+1)/6-0.04, linestyle="-", color=pltcolors[-1], zorder=1000)
        ax1.set_ylim(-20, 100)
        ax1.set_yticks([0, 50, 100])
        # Ticks on the box positions (0..5); an extra +0.5 offset would put each
        # order label between two boxes.
        ax1.set_xticks(np.arange(len(perms)))
        ax1.tick_params(axis='x', labelrotation=60)
        if row == n_rows - 1:  # top row: column titles give the SOA
            ax1.set_title(str(ISIs[col]*10+200)+" ms")
        if row == 0:  # bottom row: phase-order labels
            ax1.set_xticklabels([str(perm)[1:-1] for perm in perms])
            ax1.set_xlabel("Order")
            ax1.tick_params(axis='x', which='major', length=0)
        else:
            ax1.set_xticklabels([])
            ax1.set_xticks([])

        if col != 0:
            ax1.set_yticklabels([])
        else:  # first column: row label gives the frequency
            ax1.set_ylabel(str(frs[row])+" Hz")
            ax1.set_yticklabels(["0%","50%","100%"])
        ax1.spines["bottom"].set_visible(False)
        ax1.tick_params(axis='y', which='major', length=2)

    fig.tight_layout()

# Own file name: previously this saved to model_phaseorders_bars.pdf and
# silently overwrote the bar-chart version of Fig S3.
fig.savefig("../figures/model_phaseorders_boxplots.pdf", facecolor='white')

In [None]:
# Plot percentage of phase orders as a function of frequency and ISI

def add_numbers(ax, grid, fontsize,float_labels=False,color='black'):
    """Write the value of every cell of a 2-D grid onto an image-style axes.

    Element grid[row, col] is drawn centred at x=col, y=row, which matches the
    pixel centres of ``imshow``.

    Args:
        ax: axes-like object providing ``text(x, y, s, **kwargs)``.
        grid: 2-D array of values to annotate.
        fontsize: font size of the annotations.
        float_labels: if True, format values with two decimals; otherwise
            truncate them to ``int``.
        color: text colour.
    """
    text_kwargs = dict(ha="center", va="center", fontsize=fontsize, color=color)
    for (row, col), value in np.ndenumerate(grid):
        shown = "{:.2f}".format(value) if float_labels else int(value)
        ax.text(col, row, shown, **text_kwargs)

# Heat map of `result`: for each (frequency, ISI) cell, the percentage of the
# models' phase-order mass on the predicted order, annotated as a fraction.
# Colour limits (-200, 200) are much wider than the data range, presumably to
# keep the colours muted behind the numbers.
fig, ax = plt.subplots()
im = ax.imshow(result, origin='lower', cmap='RdBu', vmin=-200, vmax=200)
ax.set_yticks(np.arange(len(frs)))
ax.set_yticklabels(frs)
ax.set_xticks(np.arange(len(ISIs)))
ax.set_xticklabels([ISI*10+200 for ISI in ISIs])
ax.set_xlabel("Inter stimulus interval (ms)")
ax.set_ylabel("Oscillation frequency (Hz)")
add_numbers(ax, result/100, 12, float_labels=True)
print(im.get_clim())

In [None]:
# Mean, over all (frequency, ISI) grid cells, of the percentage of model
# phase-order mass on the predicted order (cells without models count as 0).
# The number of cells whose most common order matches the prediction is
# np.sum(sign_and_corr).
result.mean()

In [None]:
# Load predictions from the reduced model
# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
# The `with` block closes the file handle (the bare open() previously leaked it).
with open("../data/order_pred.pkl", 'rb') as pred_file:
    data = pickle.load(pred_file)
freqs = data['freqs']
isis = data['isis']
result = data['result']

In [None]:
# Plot model phase orders on top of predictions (FIG 5i)

# Color map for the backdrop: colour for each phase-order index (into `perms`)
# used by the reduced model's prediction grid.
col_dict={0:'orchid',
        2:'thistle',
        4:'slateblue',
        5:'midnightblue'}

# Named order_cmap rather than `cm` so the imported matplotlib.cm module is not shadowed.
order_cmap = ListedColormap([col_dict[x] for x in col_dict.keys()])
labels = np.array([perms[k] for k in col_dict])
len_lab = len(labels)
# Bin edges halfway between consecutive order indices ([-0.5, 0.5, 2.5, 4.5, 5.5])
# so every key of col_dict gets its own colour bin.
norm_bins = np.sort([*col_dict.keys()]) + 0.5
norm_bins = np.insert(norm_bins, 0, np.min(norm_bins) - 1.0)
norm = mpl.colors.BoundaryNorm(norm_bins, len_lab, clip=True)
# Colour-bar tick labels show the phase order itself rather than its index.
fmt = mpl.ticker.FuncFormatter(lambda x, pos: labels[norm(x)])

fig=plt.figure(figsize=(3.1,1.6))
# Transposed so frequency runs along y (assumes result is indexed [isi, freq],
# as the y-axis/border plotting below implies; TODO confirm against order_pred.pkl).
im = plt.imshow(result.T, cmap=order_cmap, norm=norm, aspect='auto')
plt.ylabel("frequency (Hz)")
plt.xlabel("stimulus onset asynchrony (ms)")
diff = norm_bins[1:] - norm_bins[:-1]
tickz = norm_bins[:-1] + diff / 2
cb = plt.colorbar(im, format=fmt, ticks=tickz,fraction=0.024, pad=0.04)
fig.tight_layout()

# Add borders to backdrop: for each frequency, x = multiples of 1/3, 1/2 and a
# full period (1000/freq ms) minus 200 ms. This assumes the isis axis index equals
# SOA - 200 in ms (1 ms spacing from 0); TODO confirm.
lw = .5
color='darkgrey'
for divisor in (3, 2, 1):
    for i in range(20):
        plt.plot(1000*i/(freqs*divisor)-200, np.arange(len(freqs)), color=color, lw=lw, zorder=50)

# Overlay model phase orders: white halo, grey ring, then the RNNs' most common order.
m2="o"
s_in = 10
s_ring = 15
s_out=30
# Only (frequency, ISI) cells that actually had models; RNN_order is 0 elsewhere,
# which would otherwise be drawn as phase order 0.
included_cells = set(zip(freqInds, ISIInds))
for fi,fr in enumerate(frs):
    for ISIi, ISI in enumerate(ISIs):
        if (fi, ISIi) not in included_cells:
            continue
        # Raises KeyError if the dominant RNN order is not one of col_dict's keys.
        order = int(RNN_order[fi,ISIi])
        y = arg_is_close(fr,freqs)
        x = arg_is_close(ISI*10,(isis))
        plt.scatter(x,y,color='white',s=s_out,marker=m2,zorder=90)
        plt.scatter(x,y,color='grey',s=s_ring,marker=m2,zorder=90)
        plt.scatter(x,y,color=col_dict[order],s=s_in,marker=m2,zorder=100)

plt.ylim(arg_is_close(1,freqs),arg_is_close(4,freqs))
plt.xlim(0,arg_is_close(650,isis+200))
plt.yticks([arg_is_close(fr,freqs) for fr in frs])
plt.xticks([arg_is_close(isi*10,isis) for isi in ISIs])
plt.gca().set_xticklabels([ISI*10+200 for ISI in ISIs],rotation=-60)
plt.gca().set_yticklabels(frs)
fig.savefig("../figures/model_phaseorders_pred_n.pdf", facecolor='white')
