# Imports

In [3]:
import numpy as np
import scipy.io
from scipy.stats import zscore
import os, sys
sys.path.insert(0,'..')

from rnn.model import RNN
from rnn.task import trial_generator
from analysis.tf_utils import *
from analysis.analysis_utils import *

from pycircstat.tests import watson_williams as watson_williams_test, rayleigh
from itertools import permutations

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import gridspec
%matplotlib inline

# Initialise model and task
Here we choose and load a trained model

In [4]:
# Location of the trained model: <base_dir>/models/<task_dir>/<fname>
base_dir = ".."
task_dir = "sweep_main"
fname = "y7sfja1hSpecRad1.5DaleTrueTC20_100rand0sparse1osc3.73cost0.1"
model_dir = os.path.join(base_dir, "models", task_dir, fname)

# Raw variables stored with the model (read with scipy's .mat loader)
var = scipy.io.loadmat(model_dir)

# Network object with the saved model loaded into it
net = RNN()
net.load_model(model_dir)

## Experiment settings

Settings for running analysis and generating figures

In [None]:
# --- Phase reference and LFP options ---
ref_phase = "sine"  # Reference phase for the 'spike-phase' histogram: "sine" or "LFP"
onlyGaba = False  # Only use inhibitory neurons for calculating the LFP
substr_mean_LFP = False  # Subtract the stimulus-triggered mean from the LFP
sw_tanh2log = False  # Substitute logistic for tanh to get positive firing rates (adjusts W and B accordingly)

# --- Unit selection thresholds ---
cutoff_p = 0.01  # Max Wilcoxon p allowed for stimulus-triggered neurons (<1)
cutoff_d = 0  # Min d' allowed for stimulus-selective neurons (>=0)
vex_cuttoff = 0  # Exclude neurons without strong stimulus - phase effect (>=0)

# --- Trials and spectral analysis ---
normalize_fr_extract = True  # Normalize extracted firing rates before the Wilcoxon test
n_trials = 224  # Number of trials used in the analysis
randomize_onset = 1  # Non-zero enables random stimulus onsets; the jitter size is derived from the model later on
freqs_l = np.logspace(np.log10(1), np.log10(20), num=50)  # 50 log-spaced frequencies (1-20) for the scalogram



In [None]:
# Rebuild the two parameter containers from the loaded model variables.
# `settings` is used as a dict by the cells that follow.
# NOTE(review): reinstate_params comes from a star import; its exact contents are not visible here.
model_par, settings = reinstate_params(var)

## Instantiate variables

We instantiate some variables for convenience

In [None]:
# model parameters
out_channels = net.out_channels
n_channels = net.n_channels
# loadmat returns values wrapped in 2-D numpy arrays; cast to plain ints so
# later arithmetic and array shapes do not inherit a narrow numpy dtype
n_items = int(var["n_items"][0][0])
N = int(var["N"][0][0])

# instantiate useful dictionaries
model_par, settings = reinstate_params(var)

# set delay (in time steps): take it from the file name when that slice is
# numeric, otherwise fall back to 2500 ms converted with the model's deltaT
try:
    delay = int(fname[22:25])
except ValueError:
    # int() could not parse the slice -- only this failure triggers the fallback
    delay_ms = 2500
    delay = int(delay_ms/settings['deltaT'])
settings["delay"] = delay

# set stimulus onset (time steps) and the maximum random onset jitter
settings["stim_ons"] = 120
if randomize_onset:
    # one period of the frequency stored in var['lossF'], expressed in time steps
    settings["rand_ons"] = int(1000/(var['lossF'][0][0]*settings['deltaT']))
else:
    settings["rand_ons"]=0
if settings["rand_ons"]>settings["stim_ons"]:
    print("WARNING, can't use this random onset, defaulting to 50 steps")
    settings["rand_ons"]=50

# batch size
settings["batch_size"] = n_trials

# delay period boundaries (time steps): the delay starts after the last
# stimulus (n_items stimuli separated by n_items - 1 offsets)
delay_start = (
    settings["stim_ons"]
    + n_items * settings["stim_dur"]
    + (n_items - 1) * settings["stim_offs"]
)
delay_end = delay_start + settings["delay"]

# time steps
dt_sec = int(1000 / settings["deltaT"])  # number of time steps per second
timestep = settings["deltaT"] / 1000  # duration of one time step in seconds

# colors for plotting
pltcolors, pltcolors_alt = steffiscolours()

# substitute logistic for tanh
if sw_tanh2log:
    tanh2log(net)
    settings["x2x"] = 2

## Instantiate Trial Generator

In [None]:
# Build the trial generator for this model's task dimensions

val_perc = 0  # passed as the generator's fourth argument
trial_gen = trial_generator(n_items, n_channels, out_channels, val_perc)

# Reuse the training indices that were stored with the model
trial_gen.train_ind = var["train_ind"][0]

## Generate trials

In [None]:
# Draw trials with a balanced proportion of each stimulus

stims = draw_balanced_trials()

# For every drawn trial (one column of `stims`), look up the row of
# trial_gen.all_trials_arr that equals it element-wise
stim_ind = [
    np.argmax(np.all(np.equal(trial_gen.all_trials_arr, stims[:, col]), axis=1))
    for col in range(len(stims[0]))
]
# The same indices are used for match and non-match trials
trial_ind_match = stim_ind
trial_ind_non_match = stim_ind

stim, label, delays, stim_roll, isi_stim, isi_probe = trial_gen.generate_input(
    settings,
    settings["delay"],
    val=False,
    stim_ind_match=trial_ind_match,
    stim_ind_non_match=trial_ind_non_match,
)


# Derived quantities: trial length, targets and time axes
stim = stim.astype(np.float64)
T = np.shape(stim)[-1]  # number of time steps per trial
settings["T"] = T
z, mask = trial_gen.generate_target(settings, label, T, delays, stim_roll, isi_probe)

# Time in seconds from trial start
time = np.arange(T) * settings["deltaT"] / 1000
# Time in seconds relative to stimulus onset; it has T - rand_ons samples
plt_first_step = -settings["stim_ons"] + settings["rand_ons"]
plt_last_step = T - settings["stim_ons"]
plt_time = np.arange(plt_first_step, plt_last_step) * settings["deltaT"] / 1000

# Run the network on all trials and evaluate its accuracy
x1, r1, o1 = net.predict(settings, stim[:, :, :])
accuracy(settings, o1, label, delays, isi_probe, stim_roll)

# Map tanh rates from [-1, 1] to [0, 1] (in place) unless the logistic swap is active
if not sw_tanh2log and net.activation[0] == 'tanh':
    r1 += 1
    r1 /= 2

# Generate plots

## Figure 5d: LFPs

In [None]:
# extract local field potential

# NOTE(review): `trial` is not used in this cell; a new one is drawn before
# the example-trial plot.
trial = np.random.choice(np.arange(n_trials))
LFP = get_LFP(var, r1, stim, onlyGaba=onlyGaba, alpha = 1)
LFP = zscore(LFP, axis=0)

# Optionally remove the across-trial mean from every trial's LFP
substr = np.mean(LFP, axis=1) if substr_mean_LFP else 0

# Scalogram amplitude per trial, then averaged over trials.
# The per-trial result gets its own name so it does not shadow the average.
amps = []
for tr in range(n_trials):
    _, trial_amp = scalogram(
        LFP[:, tr] - substr,
        7,
        time,
        settings["deltaT"] / 1000,
        freqs_l,
    )
    amps.append(trial_amp)

amp = np.mean(np.array(amps), axis=0)  # indexed [frequency, time step]

# Time-averaged spectra for the delay period and the pre-stimulus baseline,
# computed once and reused for both the peak frequency and its power
delay_spectrum = np.mean(amp[:, delay_start:delay_end], axis=1)
baseline_spectrum = np.mean(amp[:, :settings["stim_ons"]], axis=1)

main_freq = freqs_l[np.argmax(delay_spectrum)]
main_power = np.max(delay_spectrum)
baseline_freq = freqs_l[np.argmax(baseline_spectrum)]
baseline_power = np.max(baseline_spectrum)

print("delay freq = " + str(main_freq) + " with power " + str(main_power) +
      "\nbaseline freq = " + str(baseline_freq) + " with power " + str(baseline_power))
f = main_freq  # used in the file names of the figures saved below

In [None]:
# Plot a spectrogram

plt_start = 0  # first time step shown
ytick_step = 14  # show every 14th frequency on the y axis
xtick_step = 200  # time steps between x ticks
with mpl.rc_context(fname="matplotlibrc"):

    fig = plt.figure(figsize=(2.5, 2))  # , dpi = 150)

    plt.imshow(amp[:, plt_start:], aspect="auto", cmap ='jet', vmin = -0, vmax=1.2,origin='lower')

    # Tick positions and labels are derived from the same indices so that
    # every label shows the frequency of the row it is attached to
    ytick_pos = np.arange(0, len(freqs_l))[::ytick_step]
    plt.yticks(
        ytick_pos,
        labels=["{:.1f}".format(freq) for freq in freqs_l[ytick_pos]],
    )

    # X ticks start at stimulus onset; labels are seconds since onset computed
    # from the model's time step (no hard-coded step duration, and no
    # zero-stripping that would blank "0" or turn "10" into "1")
    onset = settings["stim_ons"] - plt_start
    xt = np.arange(0, len(time[plt_start:]))[onset::xtick_step]
    plt.xticks(xt, labels=["{:g}".format(sec) for sec in (xt - onset) * timestep])
    plt.xlabel("Time (s)")
    plt.ylabel("Frequency (Hz)")
    cbar = plt.colorbar(ticks=[0,1])
    cbar.set_label("Power (Sd)", rotation=270)

    # One dashed line at the onset of every stimulus
    for i in range(n_items):
        plt.axvline(
            x=(
                settings["stim_ons"]
                - plt_start
                + i * (settings["stim_offs"] + settings["stim_dur"])
            ),
            color="white",
            linestyle="--",
            alpha=1,
            zorder=2,
        )

    # Dashed line at the end of the delay period
    plt.axvline(
        x=(delay_end - plt_start), color="white", linestyle="--", alpha=1, zorder=1
    )

    # Layout is adjusted after all artists exist; calling it on the empty
    # figure has no effect
    plt.tight_layout()
    plt.savefig("../figures/LFP_fr"+str(f)+".svg")
    
    

## Select stimulus selective units

In [None]:
# Analysis windows in time steps: the baseline window spans one period of the
# dominant baseline frequency and is centred in the pre-stimulus interval; the
# stimulus window is capped at the same length
baseline_len = int(1000/(baseline_freq*settings['deltaT']))
baseline_start = max((settings["stim_ons"]-settings['rand_ons'] - baseline_len) // 2, 0)
stim_len = min(settings["stim_dur"] + settings["stim_offs"], baseline_len)

# extract stimulus triggered activity
data, labels = extract_stim_trig_act(
    r1,
    stim[:n_channels],
    stim_roll,
    settings,
    baseline_start=baseline_start,
    baseline_len=baseline_len,
    stim_len=stim_len,
    normalize=normalize_fr_extract,
)

wilc_pvals = get_wilc_pvals(data, onesided=False)
d_primes, responsive, prefered_stim = get_dprime(data, wilc_pvals, cutoff=cutoff_p)
ranked_neurons = np.argsort(d_primes)  # ascending d'

# NOTE(review): cutoff_d is overwritten with the median below, so re-running
# this cell reports the percentage above the median instead of the configured
# cutoff.
print(
    "percentage of cells d prime > cutoff: "
    + str(np.sum(d_primes > cutoff_d) / N * 100)
)

# calculate upper and lower 50th percentile stim. selective cells
responsive_mask = responsive.astype(bool)
median_d = np.median(d_primes[responsive_mask])
cutoff_d = median_d  # later cells select neurons with d' above this value
# Indices are taken from the masks themselves, so no neuron count is hard-coded
up50th = np.flatnonzero(np.logical_and(d_primes > median_d, responsive_mask))
low50th = np.flatnonzero(np.logical_and(d_primes < median_d, responsive_mask))

## Figure 5b: input / output

In [None]:
# plot an example trial

trial = np.random.choice(np.arange(n_trials))

# Align the trial to plt_time the same way the output traces are aligned
# before plotting: roll by -stim_roll[trial] and drop the first rand_ons
# samples. plt_time has T - rand_ons samples, so plotting the raw T-sample
# trial against it fails whenever rand_ons > 0.
plotstim = np.roll(stim[:, trial], -stim_roll[trial], axis=-1)[
    :, settings["rand_ons"] :
]

# Colour every input channel by the position at which it was active; channels
# that are never the most active one stay light gray
colors0 = ["lightgray"] * n_channels
t1 = settings["stim_ons"] - settings["rand_ons"]  # onset of the first item in the aligned trial
for i in range(n_items):
    active_channel = np.argmax(
        np.sum(plotstim[:n_channels, t1 : t1 + settings["stim_dur"]], axis=1), axis=0
    )
    colors0[active_channel] = pltcolors[i]
    t1 += settings["stim_dur"] + settings["stim_offs"]

# Stack the channels vertically, 1.25 apart (done after the colour detection
# so the offsets do not influence it)
plotstim[:n_channels] += 1.25 * np.arange(n_channels)[:, None]


with mpl.rc_context(fname="matplotlibrc"):

    fig = plt.figure(figsize=(5, 3))
    gs = fig.add_gridspec(1, 1)  
    axs = fig.add_subplot(gs[0])
    for i in np.arange(n_channels):
        axs.plot(plt_time, plotstim[i].T, color=colors0[i], linewidth=2)

    # Dashed line at the onset of every stimulus (time 0 is the first onset)
    for i in range(n_items):
        axs.axvline(
            x=(i * (settings["stim_offs"] + settings["stim_dur"])) / dt_sec,
            color="black",
            linestyle="--",
            alpha=0.5,
            zorder=-1,
        )

    # Dashed line at the end of the delay period
    axs.axvline(
        x=(delay_end - settings["stim_ons"]) / dt_sec,
        color="black",
        linestyle="--",
        alpha=0.5,
        zorder=-1,
    )

    axs.spines["right"].set_visible(False)
    axs.spines["top"].set_visible(False)
    axs.set_xlim(plt_time[0], plt_time[-1])
    axs.set_xlabel("Time (s)")
    axs.spines["left"].set_visible(False)
    # One tick per channel at that channel's vertical offset
    axs.set_yticks(np.arange(n_channels) * 1.25)
    axs.set_yticklabels(np.arange(n_channels))
    axs.set_ylabel("Input neuron index")
    plt.savefig("../figures/Panel1A")
    
    

In [None]:
# Plot the network output for the example trial

# Probe time step of every trial. NOTE(review): not used by the plotting
# code in this cell.
training_params = settings
eval_delays = delay
cutoff_T = 0
probe_time = (
    training_params["stim_ons"]
    + stim_roll
    + training_params["n_items"]
    * (training_params["stim_dur"] + training_params["probe_dur"])
    + eval_delays
    + 1
    + training_params["response_ons"]
    + (training_params["n_items"] - 1) * (training_params["stim_offs"])
    + np.sum(isi_probe, axis=1)
    - cutoff_T
)

with mpl.rc_context(fname="matplotlibrc"):

    fig, axs = plt.subplots(1, 1, figsize=(1.5, 3))

    # Output of the chosen trial, indexed [time step, output channel]
    ploto = np.copy(o1[:, :, trial])
    colors = ["grey", "C2", "C3"]
    first_shown = settings["rand_ons"]  # samples before this index are not plotted

    if net.out_channels != 3:
        # Single read-out: only channel 0 is drawn, after aligning it
        # with this trial's onset shift
        ploto[:, 0] = np.roll(ploto[:, 0], -stim_roll[trial])
        axs.plot(plt_time, ploto[first_shown:, 0].T, color=colors[0], linewidth=2)
        axs.axhline(0, color="lightgrey", ls="--")
        axs.set_ylim(-1.5, 1.5)
        axs.set_yticks([-0.7, 0, 0.7])
        axs.set_yticklabels(["non-\nmatch", "0", "match"])  # np.arange(3))

        axs.set_ylabel("Output neuron current")
    else:
        # Three read-outs: align each channel, draw it, then lift all
        # remaining channels by 1.25 so the traces stack vertically
        for i in np.arange(net.out_channels):
            ploto[:, i] = np.roll(ploto[:, i], -stim_roll[trial])

            axs.plot(
                plt_time, ploto[first_shown:, i].T, color=colors[i], linewidth=2
            )
            ploto[:, i + 1 :] += 1.25
        axs.set_ylim(-0.1, 3 * 1.25)
        axs.set_yticks(np.arange(0 + 0.6125, 3 * 1.25 + 0.6125, 1.25))
        axs.set_yticklabels(["fixate", "match", "non-\nmatch"])  # np.arange(3))
        axs.set_ylabel("Output neuron class")

    # Dashed line at the end of the delay period (seconds after stimulus onset)
    delay_end_sec = (delay_end - settings["stim_ons"]) / dt_sec
    axs.axvline(
        x=delay_end_sec, color="black", linestyle="--", alpha=0.5, zorder=-1
    )

    # Only the last 4 * response_dur samples of the trial are shown
    axs.set_xlim(plt_time[-settings["response_dur"]*4], plt_time[-1])
    axs.set_xlabel("Time (s)")

    # Axis cosmetics: no top/left spine, y axis on the right-hand side
    axs.spines["top"].set_visible(False)
    axs.spines["left"].set_visible(False)
    axs.yaxis.set_label_position("right")
    axs.yaxis.tick_right()

    plt.savefig("../figures/Panel1C")  # transparent=False)
    plt.savefig("../figures/Example_output.eps")

## Figure 5c: mean currents

In [None]:
# Plot mean currents with Poisson spikes

scale =60 # for line plots (not used in this cell)
spikesize = 0.5 # spike size in plots
spike_scale = 7 # for scaling rate parameter in poisson

plt_tr_int = 1 # we take one out of every plt_tr_int trials for spike raster plots to avoid cluttering them

plt_neurons = ranked_neurons[-2:]  # the two neurons with the highest d'
n_pos = settings['n_items']  # number of stimulus positions; used everywhere instead of a literal 4

with mpl.rc_context(fname="matplotlibrc"):

    
    for neui, neuron in enumerate(plt_neurons):
        # Create Figure
        fig = plt.figure(figsize=(5,2))
        gs = gridspec.GridSpec(1, 1)    

        # Extract traces of pref stim: traces_pref[it] is indexed [time, trial]
        pref_stim = prefered_stim[neuron]
        traces_pref, _ = extract_traces(r1, stim, neuron, pref_stim, settings)
        traces_all = [
            [traces_pref[it][:, tr] for tr in range(traces_pref[it].shape[1])]
            for it in range(n_pos)
        ]
        # Trials actually drawn for every position (every plt_tr_int-th trial)
        shown = [traces_all[it][::plt_tr_int] for it in range(n_pos)]

        # Extract spiking information: one raster row per shown trial,
        # positions stacked in ascending order
        r_im = np.concatenate(shown, axis = 0)
        all_spikes, r = draw_spikes(r_im.T/spike_scale, 0, r_im.shape[1])
        spike_times = [np.arange(0,r_im.shape[1])[spikes] for spikes in all_spikes]
        spike_colors = ['black'] * sum(len(rows) for rows in shown)

        ax0 = plt.subplot(gs[0])
        # Spike raster plot
        ax0.eventplot(spike_times, color = spike_colors, linelengths=spikesize, linewidths=spikesize);
        
        # Mean trace per position, rescaled to the height of its block of rows.
        # NOTE(review): the blocks are filled in flipped order (last position
        # at the bottom), while the raster rows and the 'Pos' tick labels run
        # in ascending order -- confirm which ordering is intended.
        counter = 0
        hbars = []
        for i in np.flip(np.arange(n_pos)):
            n_rows = len(shown[i])
            meantr = np.mean(shown[i], axis = 0)
            meantr -= np.min(meantr)
            peak = np.max(meantr)
            if peak > 0:
                # a flat trace stays at zero instead of becoming NaN
                meantr /= peak #to [0,1]
            meantr *= n_rows*0.95 #scale to fit spike size
            ax0.plot(np.arange(r_im.shape[1]), counter + meantr, color = pltcolors[i])

            counter += n_rows
            hbars.append(counter - n_rows/2)  # vertical centre of this block
            ax0.axvline(x=settings['stim_ons']-settings['rand_ons']+(i*(settings['stim_offs']+settings['stim_dur'])), \
                        color = 'black', linestyle = '--', alpha = 0.5, zorder=-1)

        # Dashed line at the end of the delay period
        ax0.axvline(x=(delay_end - settings['rand_ons']), color = 'black', linestyle = '--', alpha = 0.5, zorder=-1)
        ax0.set_ylim(0,counter)
        ax0.set_xlim(0,r_im.shape[1])

        # One x tick per second, starting at stimulus onset
        all_ind = np.arange(r_im.shape[1])
        ind_start = settings['stim_ons']-settings['rand_ons']             
        ticks =all_ind[ind_start::dt_sec]
        ax0.set_xticks(ticks)
        ax0.set_xticklabels(["{:0.0f}".format(lab) for lab in plt_time[ticks]])
        ax0.set_xlabel("Time (s)")


        # Position labels only on the first figure
        ax0.set_yticks([])
        if neui==0:
            ax0.set_yticks(hbars)
            ax0.set_yticklabels(['Pos ' +str(i) for i in range(n_pos)])
        ax0.spines['right'].set_visible(False)
        ax0.spines['top'].set_visible(False)
        plt.savefig("../figures/single_unit"+ str(neui)+"fr"+str(f)+".svg")

## Calculate variance explained

In [None]:
# Time steps trimmed from the start (buffer1) and end (buffer2) of the delay window
delay_buffer1 = delay_buffer2 = 25
nbins = 20  # number of phase bins between -pi and pi
freqs = [main_freq]  # frequencies at which variance explained is computed

In [None]:
# Calculate variance explained

n_perm = 1  # number of position-shuffled histograms built per neuron

# Fail early on an unknown reference phase instead of printing a warning and
# then continuing with an undefined (or stale) phase
if ref_phase not in ("sine", "LFP"):
    raise ValueError("reference phase not recognised: " + str(ref_phase))

# Analysis window inside the delay period (indices into the traces)
t1 = delay_start + delay_buffer1 - settings["rand_ons"]
t2 = delay_end - delay_buffer2 - settings["rand_ons"]
delay_time = time[t1:t2]

# Neurons with d' above the cutoff. The start index is computed explicitly:
# slicing with [-0:] would select every neuron when none passes the cutoff.
n_selective = int(np.sum(d_primes > cutoff_d))
neurons_plot = ranked_neurons[len(ranked_neurons) - n_selective:]
numplot = len(neurons_plot)
for f in freqs:
    # fewer time steps per cycle than phase bins leaves some bins empty
    if dt_sec / f < nbins:
        print("Warning: too much bins for f = " + str(f))

bin_lims = np.linspace(-np.pi, np.pi, nbins + 1)
bin_centers = bin_lims[:-1] + np.pi / nbins
width = 2 * np.pi / (nbins)

# Variance explained per [frequency, neuron], for real and shuffled positions
vex = np.zeros((len(freqs), N))
shvex = np.zeros((len(freqs), N))
# Random position assigned to every trial, one column per permutation
shuffle_ind = np.random.choice(np.arange(n_items), (n_trials, n_perm))




for neui, neuron in enumerate(neurons_plot):

    if neui % 10 == 0:
        print("{:.2f}% done".format(100 * neui / len(neurons_plot)))

    pref_stim = prefered_stim[neuron]
    # pref_r[pos] and LFPs[pos] are indexed [time, trial]
    pref_r, _, LFPs = extract_traces(
        r1, stim, neuron, pref_stim, settings, True, var, onlyGaba
    )
    for fi, f in enumerate(freqs):

        watsdat = []
        watsw = []
        watsw_shuffle = [[] for _ in range(n_perm)]
        spikephasehist_shuffle = np.zeros((n_items, nbins, n_perm))
        counter = 0  # running trial index across positions, indexes shuffle_ind
        cwt = complex_wavelet(timestep, f, 7)
        # all permutations except the first are redrawn for every neuron
        shuffle_ind[:,1:] = np.random.choice(np.arange(n_items), (n_trials, n_perm-1))

        # The sine reference does not depend on the trial: compute it once
        if ref_phase == "sine":
            sine_phase = wrap(time * 2 * np.pi * f)

        for stim_pos in range(settings["n_items"]):
            spikephasehist = np.zeros(nbins)
            for tr in range(np.array(LFPs[stim_pos]).shape[1]):
                if ref_phase == "sine":
                    LFP_phase = sine_phase
                else:  # ref_phase == "LFP"
                    LFP_phase, _ = inst_phase(
                        LFPs[stim_pos][:, tr], cwt, time, f, ref_phase=False
                    )

                # Mean rate per phase bin for this trial, accumulated over trials
                bin_ind = np.digitize(LFP_phase[t1:t2], bin_lims) - 1
                firing_trace = pref_r[stim_pos][t1:t2, tr]
                for b in range(nbins):
                    summed_spikes = np.sum(firing_trace[bin_ind == b])
                    occ = np.count_nonzero(bin_ind == b)
                    if occ > 0:
                        spikephasehist[b] += summed_spikes / occ
                        # the same value goes to a random position in the shuffled histograms
                        for permi in range(n_perm):
                            spikephasehist_shuffle[shuffle_ind[counter,permi], b, permi] += (
                                summed_spikes / occ
                            )
                counter += 1
            watsw.append(np.array(spikephasehist))
            watsdat.append(bin_centers)
        for stim_pos in range(settings["n_items"]):
            for permi in range(n_perm):
                watsw_shuffle[permi].append(np.array(spikephasehist_shuffle[stim_pos, :, permi]))

        # One sample (the bin centres) per stimulus position, weighted by the
        # histograms; the ratio of the first to the third "SS" entry of the
        # returned table is stored as variance explained
        anovatable = watson_williams_test(
            *([bin_centers] * settings["n_items"]), w=watsw_shuffle[0]
        )[1]
        shvex[fi, neuron] = anovatable["SS"][0] / anovatable["SS"][2]

        anovatable = watson_williams_test(*watsdat, w=watsw)[1]
        vex[fi, neuron] = anovatable["SS"][0] / anovatable["SS"][2]

# Frequency with the largest summed variance explained
vex_fr_ind = np.argmax(np.sum(vex, axis=1))
vex_fr = freqs[vex_fr_ind]

## Figure 5e: 'spike' phase histograms

In [None]:
# Plot 'spike' phase histograms


nbins = 20

# Fail early on an unknown reference phase instead of printing a warning and
# then continuing with an undefined (or stale) phase
if ref_phase not in ("sine", "LFP"):
    raise ValueError("reference phase not recognised: " + str(ref_phase))

# Selective neurons ordered by variance explained. The start index is computed
# explicitly: slicing with [-0:] would select every neuron when none passes
# the cutoff.
n_selective = int(np.sum(d_primes > cutoff_d))
vex_order = np.argsort(vex[vex_fr_ind])
vex_ind = vex_order[len(vex_order) - n_selective:]
ind = [-2, -1]
plt_neurons = vex_ind[ind]  # the two neurons with the highest variance explained
n_pos = settings["n_items"]  # number of stimulus positions (one histogram row each)


f = vex_fr  # main_freq
cwt = complex_wavelet(timestep, f, 7)


delay_time = time[t1:t2]
bin_lims = np.linspace(-np.pi, np.pi, nbins + 1)
bin_centers = bin_lims[:-1] + np.pi / nbins
width = 2 * np.pi / (nbins)
circ_ticks = np.arange(-np.pi, np.pi, np.pi / 2)
circ_ticklabs = ["-pi", "-pi/2", "0", "pi/2"]
wrap_cut = -nbins // 3  # the histogram is repeated to the right, minus this many trailing bins

# The sine reference is the same for every neuron, position and trial
if ref_phase == "sine":
    sine_phase = wrap(time * 2 * np.pi * f)


with mpl.rc_context(fname="matplotlibrc"):

    fig = plt.figure(figsize=(7, 7))
    gs = gridspec.GridSpec(n_pos, len(plt_neurons))
    for neui, neuron in enumerate(plt_neurons):

        # Extract traces of pref stim: traces_pref[pos] and LFPs[pos] are
        # indexed [time, trial]
        pref_stim = prefered_stim[neuron]
        traces_pref, _, LFPs = extract_traces(
            r1, stim, neuron, pref_stim, settings, True, var, onlyGaba
        )

        # Extract Histogram information and plot Histogram

        maxh = np.zeros(n_pos)
        avgs = np.zeros(n_pos)
        for stim_pos in np.arange(n_pos):
            ax1 = plt.subplot(gs[stim_pos, neui])

            spikephasehist = np.zeros(nbins)
            for tr in range(np.array(traces_pref[stim_pos]).shape[1]):
                if ref_phase == "sine":
                    LFP_phase = sine_phase
                else:  # ref_phase == "LFP"
                    LFP_phase, _ = inst_phase(
                        LFPs[stim_pos][:, tr], cwt, time, f, ref_phase=False
                    )
                # Mean rate per phase bin for this trial, accumulated over trials
                bin_ind = np.digitize(LFP_phase[t1:t2], bin_lims) - 1
                firing_trace = traces_pref[stim_pos][t1:t2, tr]
                for b in range(nbins):
                    summed_spikes = np.sum(firing_trace[bin_ind == b])
                    occ = np.count_nonzero(bin_ind == b)
                    if occ > 0:
                        spikephasehist[b] += summed_spikes / occ
            spikephasehist /= np.sum(spikephasehist)  # normalise to a density

            # Circular mean mapped from [-pi, pi] to an axes fraction, so it
            # can be passed as xmin/xmax of axhline below
            avg, avglen = circ_mean(bin_centers, spikephasehist)
            avg += np.pi
            avg /= np.pi * 2
            avg /= 1.5
            maxh[stim_pos] = max(spikephasehist)
            avgs[stim_pos] = avg

            # Histogram over one cycle, continued to the right by repeating
            # its leading bins
            ax1.bar(
                np.concatenate([bin_centers, (bin_centers + np.pi * 2)[:wrap_cut]]),
                np.concatenate([spikephasehist, spikephasehist[:wrap_cut]]),
                width=width,
                color=pltcolors[stim_pos],
                ec=pltcolors[stim_pos],
                linewidth=3,
                alpha=0.9,
                label="Pos:" + str(stim_pos + 1),
            ) 
            
            # Short bar above the histogram marking the circular mean
            ax1.axhline(
                (maxh[stim_pos]) * 1.1,
                avgs[stim_pos] - 0.04,
                avgs[stim_pos] + 0.04,
                color=pltcolors[stim_pos],
                linewidth=3,
            )  
            ax1.set_xticks([])
            ax1.set_xlim(circ_ticks[0], (bin_centers + np.pi * 2)[wrap_cut])
            ax1.set_xticklabels([])
            if stim_pos == 0:
                ax1.set_title("VEX = {:.2f}".format(vex[vex_fr_ind, neuron]))
            if neui == 0:
                ax1.set_ylabel("density") 

        
            ax1.spines["right"].set_visible(False)
            ax1.spines["top"].set_visible(False)
        # Phase ticks only on the bottom panel of every column
        ax1.set_xticks(circ_ticks)
        ax1.set_xticklabels(circ_ticklabs)
    plt.savefig("../figures/spike_phase_hist_fr"+str(f)+".eps", transparent=False)