In [None]:
# Plotting data in Figure 5a. The box plots show the accumulated errors since the stimulus change

In [1]:
import numpy as np
import seaborn as sns
import pandas as pd
from scipy.stats import ks_2samp
import pickle
import matplotlib.pyplot as plt
from brian2 import *

In [2]:
# Accumulated decoding error of the readout layer after the stimulus change.
# For every trial and every decoding window, all readout spikes fired between the
# stimulus change (`offset`) and the end of the window are collected, and the RMSE
# between the preferred locations of the spiking neurons and the target s[1] is stored.

s = [0.25, 0.75]  # errors below are measured against s[1] (presumably the post-change stimulus)
#decode_time =  np.linspace(1,500,500)*ms
# Decoding windows of 50, 100, 150, 200, 300 and 500 ms, stored minus one
# (the plotting cell adds the 1 back when it builds the figure file names).
decode_time = np.asarray([50,100,150,200,300, 500]) - 1

n_trials = 30
n_decoding_times = len(decode_time)
N_readout = 400

dt = 0.0001 # dt = 0.1 ms
offset = 500 *ms # offset is the time of the stimulus change

# preferred locations of readout layer neurons: N_readout evenly spaced points in [0, 1)
readout_theta_0 = np.linspace(0, 1, N_readout + 1)[:-1]


def load_readout_trials(population, n_trials):
    """Load the recorded readout-layer spikes of one population.

    Parameters
    ----------
    population : str
        Sub-directory of ``./data`` that holds the pickle files.
    n_trials : int
        Number of trials; files with run index 0 .. n_trials - 1 are read.

    Returns
    -------
    list
        One unpickled object per trial, so entry ``k`` always comes from the file
        with run index ``k``. Only the first pickled object of each file is read.

    Raises
    ------
    FileNotFoundError
        If a trial file is missing.
    EOFError
        If a trial file contains no pickled object.
    """
    trials = []
    for k in range(n_trials):
        path = './data/{}/data_readout_layer_run_idx_{}.pickle'.format(population, k)
        # NOTE: pickle can execute arbitrary code while loading; only read trusted files.
        with open(path, "rb") as openfile:
            trials.append(pickle.load(openfile))
    return trials


def circular_sq_error(target, theta):
    """Squared distance between `target` and `theta`, allowing a wrap-around of +/- 1."""
    diff = target - theta
    return np.minimum(diff**2, np.minimum((diff + 1)**2, (diff - 1)**2))


def accumulated_rmse(trials, target, theta_0, t_start, window_ms):
    """RMSE of all spikes fired in (t_start, t_start + window) for each trial and window.

    Parameters
    ----------
    trials : list
        Per-trial objects providing spike neuron indices under ``"i"`` and spike
        times under ``"t"`` (spike times must be comparable with `t_start`).
    target : float
        Location the spiking neurons' preferred locations are compared with.
    theta_0 : ndarray
        Preferred location of every readout neuron, indexed by neuron index.
    t_start : brian2 quantity
        Start of every window (time of the stimulus change).
    window_ms : sequence of numbers
        Window lengths in milliseconds.

    Returns
    -------
    ndarray of shape (len(trials), len(window_ms))
        Accumulated RMSE; 0.5 where a window contains no spike.
    """
    rmse = np.zeros([len(trials), len(window_ms)])
    for k, spike_monitor in enumerate(trials):
        spike_i = np.asarray(spike_monitor["i"])
        spike_t = spike_monitor["t"]
        after_change = spike_t[:] > t_start  # spikes after the stimulus step change
        for m, width in enumerate(window_ms):
            # Convert the window length to a time before adding it to t_start; adding the
            # bare number would mix a dimensionless value with a quantity in ms.
            t_stop = t_start + float(width) * ms
            in_window = after_change & (spike_t[:] < t_stop)
            if np.any(in_window):
                sq_errs = circular_sq_error(target, theta_0[spike_i[in_window]])
                rmse[k, m] = np.sqrt(np.mean(sq_errs))
            else:
                # No spike in the window: fall back to 0.5, the largest value the
                # wrap-around distance can take on the [0, 1) ring.
                rmse[k, m] = 0.5
    return rmse


# decode_time holds "ms - 1", so add the 1 back to obtain the window lengths in ms.
decode_window_ms = decode_time + 1

# process data from the periodic population for trial number 0 - 29
rmse_Periodic = accumulated_rmse(
    load_readout_trials('Periodic_population', n_trials),
    s[1], readout_theta_0, offset, decode_window_ms)

# process data from the Gaussian population for trial number 0 - 29
rmse_Gaussian = accumulated_rmse(
    load_readout_trials('Gaussian_population', n_trials),
    s[1], readout_theta_0, offset, decode_window_ms)

In [8]:
# For every decoding window: draw a box plot of the per-trial accumulated RMSE of both
# populations, save it as PNG and EPS, and compare the two samples with a two-sample
# Kolmogorov-Smirnov test.
statistics = np.zeros([np.size(decode_time)])
p_val = np.zeros([np.size(decode_time)])

plt.rcParams['font.size'] = '26'  # same for every figure, so set it once

for count, value in enumerate(decode_time):
    # The RMSE arrays have one column per entry of decode_time, so the column is selected
    # by position (`count`); `value` (49, 99, ...) would index far past the last column.
    rmse_gaussian_col = rmse_Gaussian[:, count]
    rmse_periodic_col = rmse_Periodic[:, count]

    plot_data = np.transpose([rmse_gaussian_col, rmse_periodic_col])
    df = pd.DataFrame(plot_data, columns=["Gaussian", "Periodic"])
    # Create a box plot with one box per population
    p = sns.catplot(kind="box", data=df)

    p.set(ylabel = "Root mean squared error (a.u.)")
    plt.ylim(0,0.5)
    # decode_time stores the window minus one; value+1 restores it for the file name
    plt.savefig('./figures/accumulated_RMSE_readout_boxplot_{}.png'.format(value+1))
    plt.savefig('./figures/accumulated_RMSE_readout_boxplot_{}.eps'.format(value+1), format='eps')
    plt.close()
    statistics[count], p_val[count] = ks_2samp(rmse_gaussian_col, rmse_periodic_col)

In [6]:

# Show the KS statistics first, then the p-values (one entry per decoding window).
for ks_output in (statistics, p_val):
    print(ks_output)

[0.23333333 0.5        0.53333333 0.6        0.56666667 0.6       ]
[3.92945014e-01 8.99577684e-04 2.93340549e-04 2.36648847e-05
 8.73780359e-05 2.36648847e-05]
