In [None]:
# change working directory to the project root (two levels above this notebook).
# The root is resolved only once per kernel and remembered in _PROJECT_ROOT, so
# re-running this cell does not climb further up the directory tree.
import os
if '_PROJECT_ROOT' not in globals():
    _PROJECT_ROOT = os.path.abspath('../../')
os.chdir(_PROJECT_ROOT)

import sys
# make the project's model/utility modules importable (paths are relative to the
# project root); skip entries that are already present so re-running the cell
# does not keep growing sys.path
for _module_dir in ('models/utils', 'models/brian2', 'models/aln'):
    if _module_dir not in sys.path:
        sys.path.append(_module_dir)

In [None]:
# import python packages
from __future__ import print_function
import os
import datetime
import tqdm
import matplotlib.pyplot as plt
% matplotlib inline
import numpy as np
import scipy 
import pandas as pd
import pypet as pp

# import utils libs
import pypet_parameters as pe
import fitparams as fp
import functions as func
import runModels as rm
import paths

In [None]:
# matplotlib defaults for this notebook: keep SVG text as text (fonts not
# converted to paths), save figures at 300 dpi, and draw images with 'plasma'
plt.rcParams.update({
    'svg.fonttype': 'none',
    'savefig.dpi': 300,
    'image.cmap': 'plasma',
})

In [None]:
# base parameters: load the "A2" point via fitparams, then override a few
# settings for this experiment
N_neurons = 10000

params = fp.loadpoint([], "A2")
params['dt'] = 0.1          # simulation time step — NOTE(review): presumably ms, confirm in fitparams
params['duration'] = 6000   # NOTE(review): presumably ms, same unit as dt — confirm
params['sigma_ou'] = 0.0    # presumably the OU noise strength, switched off here — confirm
params['N'] = N_neurons
params['model'] = 'aln'

# Parameter exploration

In [None]:
# explore the following parameter combinations (full cartesian product):
#   f_sin      : 70 stimulation frequencies, 1..70
#   A_sin      : 6 stimulation amplitudes
#   model      : both the 'aln' and the 'brian' model
#   load_point : always "A2"
parametrizationA2 = pp.cartesian_product({
    'f_sin': [round(freq, 5) for freq in np.linspace(1, 70, 70)],
    'A_sin': [0.2, 0.4, 0.8, 1.0, 1.2, 1.4],
    'model': ['aln', 'brian'],
    'load_point': ['A2'],
})

# the exploration actually used below; a copy so parametrizationA2 stays untouched
parametrization = parametrizationA2.copy()

In [None]:
# number of runs = length of any one explored-parameter list (all lists produced
# by the cartesian product have the same length). next(iter(...)) is used
# because dict views are not indexable on Python 3.
print("Number of parameter configurations: {}".format(len(next(iter(parametrization.values())))))

## Run simulations

In [None]:
# ---- initialize pypet environment ----
# every execution gets its own timestamped trajectory name; all of them are
# written to the same HDF file
trajectoryName = datetime.datetime.now().strftime("results-%Y-%m-%d-%HH-%MM-%SS")
trajectoryFileName = HDF_FILE = os.path.join(paths.HDF_DIR, 'frequency-entrainment-adex-aln.hdf')

# use every available CPU core for the parameter exploration
import multiprocessing
ncores = multiprocessing.cpu_count()
print("Number of cores: {}".format(ncores))

env = pp.Environment(
    trajectory=trajectoryName,
    filename=trajectoryFileName,
    file_title='frequency entrainment',
    large_overview_tables=True,
    multiproc=True,
    ncores=ncores,
    wrap_mode='QUEUE',   # pypet storage wrapping mode for multiprocessing
    log_stdout=False,
)

# take the trajectory (and the name pypet actually used for it) from the environment
traj = env.v_trajectory
trajectoryName = traj.v_name

# register the simulation parameters on the trajectory
pe.add_parameters(traj, params)

In [None]:
# expand the trajectory over the parameter grid and run every configuration
# through rm.runModels_stimulus
traj.f_explore(parametrization)
try:
    env.f_run(rm.runModels_stimulus)
finally:
    # release pypet's logging even if a run fails, so log handlers/files are
    # not left open in the kernel
    env.f_disable_logging()
print("Done.")

# Data processing

## Load results from disk

In [None]:
# HDF file holding the simulation results (same file the simulation cells write to)
trajectoryFileName = HDF_FILE = os.path.join(
    paths.HDF_DIR, 'frequency-entrainment-adex-aln.hdf')

In [None]:
# ---- load pypet trajectory "trajectoryFileName" ----
print("Analyzing File \"{}\"".format(trajectoryFileName))
print("All Trajectories:")
# query the file once and reuse the list
allTrajectoryNames = pe.getTrajectoryNameInsideFile(trajectoryFileName)
print(allTrajectoryNames)
if len(allTrajectoryNames) == 0:
    raise ValueError("No trajectories found in \"{}\"".format(trajectoryFileName))
# analyse the last trajectory in the list (presumably the most recent run — confirm ordering)
trajectoryName = allTrajectoryNames[-1]

print("Analyzing trajectory \"{}\".".format(trajectoryName))
trajLoaded = pp.Trajectory(trajectoryName, add_time=False)
trajLoaded.f_load(trajectoryName, filename=trajectoryFileName, force=True)
# load result data lazily on access
trajLoaded.v_auto_load = True
print("{} results found".format(len(trajLoaded.f_get_results())))

## Process data

In [None]:
# number of runs stored in the trajectory, and the simulation time step
nResults = len(trajLoaded.f_get_run_names())
dt = trajLoaded.f_get_parameters()['parameters.simulation.dt'].f_get()

# ---- explored parameters: full pypet names -> short column names ----
exploredParameters = trajLoaded.f_get_explored_parameters()
niceParKeys = [fullName.split('.')[-1] for fullName in exploredParameters]

# ---- create a pandas df with one column per explored parameter, filled with its range ----
dfResults = pd.DataFrame(columns=niceParKeys, dtype=object)
for fullName, parameter in exploredParameters.items():
    dfResults[fullName.split('.')[-1]] = parameter.f_get_range()

### Serial Processing

In [None]:
# ---- make a dictionary with results ----
# One dict per run. resultDicts[i] must correspond to dfResults row i, because the
# plotting cell looks up resultDicts with indices taken from dfResults.
resultDicts = []

measures = ['spectrum']

for rInd in tqdm.tqdm(range(nResults), total=nResults):
    res = trajLoaded.results[rInd].f_to_dict()
    # store every requested measure under its own key
    for measure in measures:
        res[measure] = func.analyse_run(measure, res, dt)
    # [A_sin, f_sin] of the sinusoidal stimulus used in this run
    res['stimulation'] = [dfResults.loc[rInd, 'A_sin'], dfResults.loc[rInd, 'f_sin']]
    # append exactly once per run, regardless of how many measures were computed
    resultDicts.append(res)

print("done.")

# Plot data

In [None]:
from matplotlib.colors import LogNorm

# explored amplitude and working-point values (one entry per run)
A_sin_range = trajLoaded.f_get('parameters.localNetwork.A_sin').f_get_range()
load_point_range = trajLoaded.f_get('parameters.localNetwork.load_point').f_get_range()

# The amplitude scatter plot needs an 'amplitudes' entry in every result dict,
# which the processing cell does not compute (measures = ['spectrum']).
# Only enable this after adding that measure.
PLOT_AMPLITUDES = False

for point in np.unique(load_point_range):
    for stim_amp in np.unique(A_sin_range):
        for model in ['aln', 'brian']:
            selector = (dfResults.A_sin == stim_amp) & (dfResults.model == model) & (dfResults.load_point == point)
            # order runs by stimulation frequency so image columns match the x-extent below
            selectedResults = dfResults[selector].sort_values('f_sin')
            selectIndices = selectedResults.index
            if len(selectIndices) == 0:
                # this combination was not part of the exploration
                continue

            f_sin_range = np.unique(selectedResults.f_sin)

            # spectrum plot -----------
            # one column per stimulation frequency, one row per spectrum frequency bin
            thisResults = []
            for ind in selectIndices:
                power = resultDicts[ind]['spectrum'][1]
                frequency = resultDicts[ind]['spectrum'][0]
                # keep only the power values that have a matching frequency
                power = power[0:len(frequency)]

                thisResults.append(power)

            fig, ax = plt.subplots(dpi=300, figsize=(3, 2.0))
            ax.set_title("Point: {} Model: {} Amp: {} mV/ms".format(point, model, stim_amp))
            # x axis spans the stimulation frequencies, y axis the spectrum's own
            # frequency axis (taken from the last run; assumed identical for all
            # runs — NOTE(review): confirm)
            im = ax.imshow(np.array(thisResults).T, origin='lower', aspect='auto',
                           norm=LogNorm(vmin=0.01, vmax=1),
                           extent=[f_sin_range[0], f_sin_range[-1],
                                   frequency[0], frequency[-1]])

            ax.set_xlabel("Stimulation frequency [Hz]")
            ax.set_ylabel("Spectrum [Hz]")

            # temp: for reading plots precisely, remove for paper plot
            #ax.locator_params(axis='x', nbins=70)
            #ax.tick_params(axis='x', labelsize=4, labelrotation=90)
            #ax.grid()
            #cbar = fig.colorbar(im, ax=ax, ticks=[0.00001, 1])
            #cbar.ax.set_yticklabels(['0', '1'])

            # point label
            bbox_props = dict(boxstyle="circle", fc="w", ec="0.5", pad=0.2, alpha=0.9)
            ax.text(0.08, 0.88, point, ha="center", transform=ax.transAxes, va="center", size=10, bbox=bbox_props)

            # amplitude label — NOTE(review): the title shows the raw A_sin in mV/ms,
            # while this label converts it to pA with a factor of 200; confirm the conversion
            ax.text(0.92, 0.85, "{} pA".format(int(stim_amp*200)), fontweight='regular', transform=ax.transAxes, ha='right',
                    bbox={'facecolor': 'white', 'alpha': 0.85, 'pad': 5}, fontsize=8)

            plt.show()
            # release the figure so many plots in the loop do not accumulate in memory
            plt.close(fig)

            if PLOT_AMPLITUDES:
                # amplitude plots -----------
                thisMaxima = []
                thisMinima = []
                for ind in selectIndices:
                    thisMaxima.append(resultDicts[ind]['amplitudes'][0])
                    thisMinima.append(resultDicts[ind]['amplitudes'][1])

                # NOTE(review): x = run index / 3 — presumably intended to map runs onto
                # stimulation frequency; confirm against the exploration order
                for ma, mi, i in zip(thisMaxima, thisMinima, selectIndices):
                    for maximum in ma[::10]:
                        if maximum > 25:
                            plt.scatter(i/3, maximum, zorder=-1, c='C3', alpha=0.2, s=0.8)

                    for minimum in mi[::10]:
                        plt.scatter(i/3, minimum, zorder=-1, c='C0', alpha=0.2, s=0.8)
                plt.xlim(0, 70)
                plt.show()