# Definitions

In [None]:
# Run the notebook from the project root so the relative paths below resolve.
# NOTE(review): assumes the notebook is started two directory levels below
# the project root -- confirm if the notebook is moved.
import os
import sys

os.chdir('../../')

# Make the project's helper modules importable by their bare names.
sys.path.extend(['models/utils', 'models/brian2', 'models/aln'])

In [None]:
import os
import time
import matplotlib.pyplot as plt
% matplotlib inline
import numpy as np
import scipy.signal
import pickle

import fitparams as fp
import functions as func
import runModels as rm
import paths

In [None]:
# Matplotlib defaults for notebooks on the server: SVG font handling set to
# 'none', figures saved at 300 dpi, and 'plasma' as the default image colormap.
plt.rcParams.update({
    'svg.fonttype': 'none',
    'savefig.dpi': 300,
    'image.cmap': 'plasma',
})

In [None]:
# Value stored in params['N'] below and re-applied for every 'net' run in the
# plotting cell (presumably the number of neurons in the simulated network --
# confirm against runModels).
N_neurons = 50000

# NOTE(review): an empty list is passed to fp.loadpoint as the initial
# parameter container; whatever loadpoint returns is indexed by string keys
# below, so it is presumably dict-like -- confirm in fitparams.
params = []
params = fp.loadpoint(params, 'A2')

# Overrides applied on top of the loaded point.
# NOTE(review): 'dt' and 'duration' are presumably in ms (construct_stimulus
# uses 1000 / dt samples for a one-second interval) -- confirm.
params['dt'] = 0.1
params['duration'] = 3000.0
params['sigma_ou'] = 0.0
params['model'] = 'brian'
params['N'] = N_neurons
# Cache of simulation results, keyed by subplot letter in the plotting cell.
results = {}

In [None]:
def construct_stimulus(stim='dc', stim_amp=0.2, stim_freq=1, stim_bias=0, n_periods=0,
                       nostim_before=0, nostim_after=0, dt=None, duration=None):
    """Build a stimulus time series of exactly ``int(duration / dt)`` samples.

    One stimulus pattern is assembled for the requested protocol, tiled, and
    cut to the simulation length.

    Protocols:
        'dc':   ``nostim_before`` at ``stim_bias``, a 1000-unit step at
                ``stim_bias + stim_amp``, then ``nostim_after`` at
                ``stim_bias``; negative samples are clipped to 0.
        'ac':   ``nostim_before`` at ``stim_bias``, ``n_periods`` sine
                periods (peak-to-peak ``stim_amp``) offset by ``stim_bias``,
                then ``nostim_after`` at ``stim_bias``. ``n_periods == 0``
                means ``int(stim_freq)`` periods.
        'cc':   2000 time units at ``stim_bias`` followed by a fixed
                40000-sample chirp (2 -> 25, ``phi=270``); ``stim_amp`` and
                the pause arguments are not used.
        'rect': ``nostim_before`` at ``stim_bias`` with a short negative kick
                in its middle, a positive 330-unit step whose second half
                decays exponentially, a pause, the mirrored negative step and
                a final pause; repeated ``n_periods`` times if that is > 0.

    Args:
        stim: Protocol name, one of 'dc', 'ac', 'cc', 'rect'.
        stim_amp: Stimulus amplitude.
        stim_freq: Frequency of the 'ac' sine (one period spans
            ``1000 / dt / stim_freq`` samples).
        stim_bias: Baseline value outside the stimulus.
        n_periods: Number of repetitions ('ac' and 'rect' only).
        nostim_before: Baseline duration before the stimulus, in the same
            time units as ``dt``.
        nostim_after: Baseline duration after the stimulus ('dc' and 'ac').
        dt: Sample spacing. Defaults to the global ``params['dt']``.
        duration: Total duration. Defaults to the global
            ``params['duration']``.

    Returns:
        1-D numpy array with ``int(duration / dt)`` samples.

    Raises:
        ValueError: If ``stim`` is not a known protocol or the constructed
            pattern is empty.
    """
    # Fall back to the notebook-wide simulation parameters when not given.
    if dt is None:
        dt = params['dt']
    if duration is None:
        duration = params['duration']

    def n_samples(interval):
        # Sample counts must be true integers: float counts are rejected by
        # np.linspace / np.tile / range / slicing on Python 3.
        return int(interval / dt)

    def constant(value, interval):
        return [value] * n_samples(interval)

    def sinus_stim(f=1, sinus_amplitude=0.2, positive=0, phase=0, cycles=1, t_pause=0):
        # One sine period sampled from +pi down to -pi, halved so that the
        # peak-to-peak size equals sinus_amplitude, followed by t_pause zeros.
        x = np.linspace(np.pi, -np.pi, int(1000 / dt / f))
        sinus_function = np.hstack(((np.sin(x + phase) + positive) / 2, np.tile(0, t_pause)))
        sinus_function *= sinus_amplitude
        return np.tile(sinus_function, cycles)

    def decay_ramp(n_points, amplitude, rate=0.0035):
        # amplitude * exp(-k * rate) for k = 0 .. n_points - 1; the exponent
        # is accumulated by repeated addition to keep the values unchanged.
        ramp = np.empty(n_points)
        exponent = 0
        for idx in range(n_points):
            ramp[idx] = np.exp(-exponent) * amplitude
            exponent += rate
        return ramp

    def overwrite_step_tail(signal, step_samples, amplitude):
        # Replace the second half of the most recently appended step with an
        # exponential decay. Indices are anchored at the end of the signal.
        half = step_samples // 2
        first = len(signal) + (-step_samples // 2)
        signal[first:first + half] = decay_ramp(half, amplitude)

    if stim == 'ac':
        if n_periods == 0:
            n_periods = int(stim_freq) * 1
        one_period = sinus_stim(stim_freq, stim_amp) + stim_bias
        stimulus = np.hstack((constant(stim_bias, nostim_before), np.tile(one_period, n_periods)))
        stimulus = np.hstack((stimulus, constant(stim_bias, nostim_after)))

    elif stim == 'cc':
        chirp = scipy.signal.chirp(np.linspace(0, 1, 40000), 2, 1, 25, phi=270)
        stimulus = np.hstack((constant(stim_bias, 2000), chirp))

    elif stim == 'dc':
        stimulus = np.hstack((constant(stim_bias, nostim_before), constant(stim_bias + stim_amp, 1000)))
        stimulus = np.hstack((stimulus, constant(stim_bias, nostim_after)))
        stimulus[stimulus < 0] = 0

    elif stim == 'rect':
        step_length = 330
        step_samples = n_samples(step_length)

        # Largely nothing but a small kick to ensure it's in the down state.
        # impulse_length is a sample count, not a duration. The list slice
        # assignment always inserts impulse_length values, so a pre-stimulus
        # period shorter than the kick is lengthened by it.
        before_stimulus = constant(stim_bias, nostim_before)
        impulse_length = 200
        middle = len(before_stimulus) // 2
        before_stimulus[middle - impulse_length // 2:middle + impulse_length // 2] = \
            np.repeat(stim_bias - 1.0, impulse_length)

        # Positive step with exponentially decaying second half.
        # NOTE(review): the decaying part is scaled by stim_amp only and does
        # not include stim_bias -- confirm this is intended for non-zero bias.
        stimulus = np.hstack((before_stimulus, constant(stim_bias + stim_amp, step_length)))
        overwrite_step_tail(stimulus, step_samples, stim_amp)

        # Pause, then the mirrored negative step, then a final pause.
        stimulus = np.hstack((stimulus, constant(stim_bias, step_length)))
        stimulus = np.hstack((stimulus, constant(stim_bias - stim_amp, step_length)))
        overwrite_step_tail(stimulus, step_samples, -stim_amp)
        stimulus = np.hstack((stimulus, constant(stim_bias, step_length)))

        if n_periods > 0:
            stimulus = np.tile(stimulus, n_periods)

    else:
        # Previously this only printed an unformatted message and then failed
        # later with an unrelated UnboundLocalError.
        raise ValueError("ERROR, stim protocol {} not found".format(stim))

    # Repeat the pattern until it covers the full simulation, then cut it.
    steps = int(duration / dt)
    stimlength = len(stimulus)
    if stimlength == 0:
        raise ValueError("stimulus protocol {} produced an empty pattern".format(stim))
    stimulus = np.tile(stimulus, steps // stimlength + 2)
    return stimulus[:steps]

# Load analyzed data

In [None]:
# Path of the pickled results dictionary inside paths.PICKLE_DIR.
# Uncomment the lines below to load previously computed results from disk;
# the plotting cell then reuses them instead of re-running the simulations.
# Pickle files must be opened in binary mode, and only trusted files should
# be unpickled.
file_path_p = os.path.join(paths.PICKLE_DIR, 'stimulus-all-results.p')
#with open(file_path_p, "rb") as results_file:
#    results = pickle.load(results_file)

# Plot data

In [None]:
# One figure row per parameter point, one column per model
# (left: 'aln', right: 'net'). The four lists below are parallel.
points = ['A1', 'A2', 'A3', 'B3', 'B4']
protocols = ['dc', 'dc', 'rect', 'ac', 'ac']
amps = [0.3, 0.2, 0.5, 0.4, 0.4]
freqs = [2, 2, 3, 3, 4]

import string
titles = string.ascii_lowercase  # subplot letters a, b, c, ...

nrows = len(points)
ncols = 2

f, axs = plt.subplots(nrows, ncols, figsize=(4.0, 3.5), dpi=600)

counter = 0  # running panel number, used to pick the subplot letter
for ip, point in enumerate(points):

    stim = protocols[ip]
    stim_amp = amps[ip]
    stim_freq = freqs[ip]
    stim_bias = 0

    # Phase entrainment plot (B3) is informative only if the stimulus and
    # the ongoing oscillation have an initial phase difference.
    # Unfortunately, that value is not easily controllable for, so it
    # might need a few runs (or a different onset) until a sensible plot
    # emerges.
    nostim_time = 1280 if point == 'B3' else 1000
    n_periods = 4 if point == 'B3' else 0  # more periods for phase entrainment plot

    for im, model in enumerate(['aln', 'net']):
        # Row/column come straight from the loop indices; the old
        # `counter / ncols` yields a float (invalid index) on Python 3.
        ax = axs[ip, im]

        title = titles[counter]
        print('{}) {} point {}'.format(title, model, point))

        start = time.time()

        if title not in results:
            if model == 'aln':
                params = fp.loadpoint(params, point, newIC=False)
                params['model'] = 'aln'
            else:
                params = fp.loadpoint_network(params, point)
                params['N'] = N_neurons
                params['model'] = 'brian'
            params['sigma_ou'] = 0.0
            params['duration'] = 3000.0
            stimulus = construct_stimulus(stim, stim_amp=stim_amp, stim_freq=stim_freq, stim_bias=stim_bias,
                                          nostim_before=nostim_time, n_periods=n_periods)
            params['ext_exc_current'] = stimulus

            t, rates_exc, rates_inh, stimulus = rm.runModels(manual_params=params)
            # Presumably converts ms to s (the x axis is labelled in seconds).
            t /= 1000
        else:
            print("loading precomputed result ...")
            cached = results[title]
            t, rates_exc, stimulus = cached['t'].copy(), cached['rates_exc'].copy(), cached['stimulus'].copy()

        end = time.time()
        print("Took %f seconds"%((end - start)))

        # Time window shown in the panel; longer for some rows.
        plotrange = [0.8, 2.2]
        if title in ['c', 'd']: plotrange = [0.8, 2.8]
        if title in ['g', 'h']: plotrange = [0.8, 3.6]

        plotIndex = (t > plotrange[0]) & (t < plotrange[1])
        ax.plot(t[plotIndex], rates_exc[plotIndex], lw=1.2, c='k')

        # STIMULUS PLOT on a twin y axis.
        # The stimulus is repeated and cut so that it has one sample per time
        # point. NOTE(review): the factor 200 presumably converts the
        # stimulus to pA (see the axis label below) -- confirm.
        ax_stimulus = ax.twinx()
        stimulus = np.tile(stimulus, 2)
        stimulus = stimulus[:len(t)]
        ax_stimulus.plot(t[plotIndex], stimulus[plotIndex]*200, c='C3', lw=1.5, alpha=0.7)

        for side in ('right', 'top', 'bottom', 'left'):
            ax_stimulus.spines[side].set_visible(False)
        ax_stimulus.tick_params(direction='out', length=4, width=1, colors='k', labelsize=4)
        ax_stimulus.tick_params('y', colors='C3')

        ax_stimulus.set_ylim(np.round(np.min(stimulus[plotIndex]*200), 2)*1.5-0.01*200,
                             np.round(np.max(stimulus[plotIndex]*200), 2)*1.2)

        # Only the right-hand column gets stimulus ticks.
        if im == 1:
            stim_axis_amp = stim_amp
            if stim == 'ac': stim_axis_amp /= 2  # sine peak is half of stim_amp
            ax_stimulus.set_yticks([0, stim_axis_amp*200])
        else:
            ax_stimulus.set_yticks([])

        # Cache the (possibly freshly computed) traces under the panel letter.
        results[title] = {'t' : t, 'rates_exc' : rates_exc, 'stimulus': stimulus}

        # subfigure titles
        ax.text(-0.1, 0.98, title, horizontalalignment='center', size=10,
            verticalalignment='center', transform=ax.transAxes, fontdict={'weight':'regular'})

        # point labels (left column only)
        if im == 0:
            bbox_props = dict(boxstyle="circle", fc="w", ec="0.5", pad=0.2, alpha=0.9)
            ax.text(0.05, 0.98, point, ha="center", transform=ax.transAxes, va="center", size=6, bbox=bbox_props)

        # Hide all spines
        for side in ('right', 'top', 'bottom', 'left'):
            ax.spines[side].set_visible(False)
        ax.tick_params(direction='out', length=4, width=1, colors='k', labelsize=4)

        ax.set_xticks([1.0, 2.0])

        counter += 1


# hacks: hand-picked upper y tick for every panel (row -> (left, right))
ytick_tops = [(5, 10), (20, 60), (30, 30), (25, 25), (25, 25)]
for row, tops in enumerate(ytick_tops):
    for col, top in enumerate(tops):
        axs[row, col].set_yticks([0, top])

axs[4,0].set_xlabel("Time [s]", fontsize=6)
axs[4,1].set_xlabel("Time [s]", fontsize=6)
axs[4,0].set_ylabel("Rate [Hz]", fontsize=6)
# ax_stimulus is the twin axis of the last panel drawn (bottom right).
ax_stimulus.set_ylabel("Stimulus [pA]", fontsize=6, color='C3')

# Column headers, placed in figure coordinates just above the panels.
plt.text(0.27, 1.03, 'Mean-Field', transform=f.transFigure, ha='center',weight="bold")
plt.text(0.735, 1.03, 'AdEx Network', transform=f.transFigure, ha='center',weight="bold")

plt.tight_layout(pad=0.1, h_pad=0.5)

for extension in ['png', 'svg', 'jpg']:
    plt.savefig(os.path.join(paths.FIGURES_DIR, "stimulus-all.{}".format(extension)))

# Save analyzed data

In [None]:
# Path of the pickled results dictionary inside paths.PICKLE_DIR.
# Uncomment the lines below to save the results dictionary filled by the
# plotting cell to disk. Pickle files must be opened in binary mode.
file_path_p = os.path.join(paths.PICKLE_DIR, 'stimulus-all-results.p')
#with open(file_path_p, "wb") as results_file:
#    pickle.dump(results, results_file)