# Definitions

In [None]:
# Change the working directory to the project root (two levels above the
# notebook) and make the model packages importable.
# The root is resolved only once per kernel session, so re-running this cell
# does not keep climbing up the directory tree; the package directories are
# added as absolute paths (independent of later cwd changes) and only once.
import os
if 'PROJECT_ROOT' not in globals():
    PROJECT_ROOT = os.path.abspath('../../')
os.chdir(PROJECT_ROOT)

import sys
for _subdir in ['models/utils', 'models/brian2', 'models/aln']:
    _path = os.path.join(PROJECT_ROOT, _subdir)
    if _path not in sys.path:
        sys.path.append(_path)

In [None]:
import os
import matplotlib.pyplot as plt
% matplotlib inline
import numpy as np
import scipy.signal

import fitparams as fp
import functions as func
import runModels as rm
import paths

In [None]:
# Global matplotlib defaults for this notebook: keep SVG text as editable text
# (not outlined paths), save figures at 300 dpi, and use the 'plasma' colormap
# for images.
plt.rcParams.update({
    'svg.fonttype': 'none',
    'savefig.dpi': 300,
    'image.cmap': 'plasma',
})

In [None]:
# Number of neurons handed to the simulations below as params['N'].
N_neurons = 25000

# Base parameter set: load point A2, then fix the integration step, the run
# length and switch off the Ornstein-Uhlenbeck input noise (sigma_ou = 0).
# dt/duration are presumably in ms (the plotting cell converts with 1000 / dt).
params = fp.loadpoint([], 'A2')
params['dt'] = 0.1
params['duration'] = 3000.0
params['sigma_ou'] = 0.0

# Plot data

In [None]:
# Simulate the mean-field model ('aln') and the spiking network ('brian') at
# the state-space points A2 and B3. For every run the dominant frequency of the
# excitatory rate is printed, and the last 1.2 s of the excitatory (r_E) and
# inhibitory (r_I) rates are plotted (x-axis limited to the first 1 s of that
# window). One two-panel figure per model is saved as png, svg and jpg into
# paths.FIGURES_DIR.
models = ['aln', 'brian']
points = ['A2', 'B3']

for model in models:
    # NOTE(review): 'model' is set before the point loaders run; this assumes
    # fp.loadpoint / fp.loadpoint_network keep this key -- confirm.
    params['model'] = model
    fig, axs = plt.subplots(2, 1, figsize=(2.75, 2), dpi=300)

    for ip, point in enumerate(points):

        if model == 'aln':
            params = fp.loadpoint(params, point, newIC=True)
        elif model == 'brian':
            params = fp.loadpoint_network(params, point)
        params['dt'] = 0.1
        params['duration'] = 2000.0
        params['sigma_ou'] = 0.0
        params['N'] = N_neurons
        params['load_point'] = point

        print(model, point)

        # number of samples in the last 1.2 s of the run (dt is in ms)
        plotLen = int(1.2 * 1000 / params['dt'])

        t, rates_exc, rates_inh, stimulus = rm.runModels(manual_params=params)

        plotData = rates_exc[-plotLen:]

        # Dominant frequency: Welch spectrum of r_E after discarding the first
        # second of the run, using one-second flattop-windowed segments.
        nperseg = int(1 * 1000 / params['dt'])  # one second worth of samples
        maxfr = 75  # upper bound (Hz) of the band kept in f_orig / Pxx_spec_orig
        f_orig, Pxx_spec_orig = scipy.signal.welch(rates_exc[nperseg:], 1 / params['dt'] * 1000,
                                                  'flattop', nperseg, scaling='spectrum')
        orig_peak_frequency = f_orig[Pxx_spec_orig.argmax()]
        maxfreqIndex = np.abs(f_orig - maxfr).argmin()
        print("Maximum Powerspectrum peak without stimulation: {} at {}Hz".format(np.max(np.sqrt(Pxx_spec_orig)),
                                                                                 orig_peak_frequency))
        # keep only the band below maxfr Hz (not used by the plot below)
        Pxx_spec_orig = Pxx_spec_orig[:maxfreqIndex]
        f_orig = f_orig[:maxfreqIndex]
        Pxx_spec_orig_max = np.max(np.sqrt(Pxx_spec_orig))

        # Plotting
        # Time axis in seconds with the true sample spacing dt; np.linspace over
        # [0, plotLen * dt] with plotLen points would stretch it by one sample.
        t_plot = np.arange(plotLen) * params['dt'] / 1000
        axs[ip].plot(t_plot, plotData.T, c='C3', label='$r_E$')
        axs[ip].plot(t_plot, rates_inh[-plotLen:].T, c='C0', label='$r_I$')

        axs[ip].set_xlim([0.0, 1.0])

        # circled point label near the upper left corner, in axes coordinates
        bbox_props = dict(boxstyle="circle", fc="w", ec="0.5", alpha=0.98)
        axs[ip].text(0.10, 0.8, point, ha="center", va="center", size=8, bbox=bbox_props,
                     transform=axs[ip].transAxes)

        axs[ip].set_xticks([0, 0.5, 1])

        if model == 'aln':
            axs[ip].set_ylabel("Population\nrate [Hz]", fontsize=8)

        # only the bottom panel carries x tick labels and the x-axis label
        if ip < len(points) - 1:
            axs[ip].set_xticks([])
        else:
            axs[ip].set_xlabel("Time [s]", fontsize=8)

    if model == 'brian':
        axs[0].legend(loc=1, prop={'size': 8})

    if model == 'aln':
        axs[0].set_yticks([0, 30])
        axs[1].set_yticks([0, 50])
    elif model == 'brian':
        axs[0].set_yticks([0, 100])
        axs[1].set_yticks([0, 50])

    # frameless axes; 'length' is the final tick length (earlier size/length
    # settings would be overridden by this call anyway)
    for ax in axs:
        for side in ('right', 'top', 'bottom', 'left'):
            ax.spines[side].set_visible(False)
        ax.tick_params(direction='out', length=4, width=1, colors='k', labelsize=6)

    # Lay out only after the final ticks and tick-label sizes are set, so the
    # padding is computed for what is actually drawn.
    fig.tight_layout(pad=0.3, w_pad=-0.3)

    # savefig does not create missing directories
    os.makedirs(paths.FIGURES_DIR, exist_ok=True)
    fname = "traces-{}".format(model)
    print("Saving {}".format(fname))
    for extension in ['png', 'svg', 'jpg']:
        fig.savefig(os.path.join(paths.FIGURES_DIR, "{}.{}".format(fname, extension)))
    plt.show()