Notebook Settings
=================

``` ipython
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
%run ../../../notebooks/setup.py
%matplotlib inline
%config InlineBackend.figure_format = 'png'
```

Imports
=======

``` ipython
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
from torch.utils.data import Dataset, TensorDataset, DataLoader

# Absolute path of the NeuroFlame checkout; handed to test_dual() and Network() in later cells.
REPO_ROOT = "/home/leon/models/NeuroFlame"

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Default colour palette. NOTE: `pal` is rebound to ['r', 'b', 'g'] in the "Perf" cell.
pal = sns.color_palette("tab10")
# Torch device string (CUDA device index 1) passed to test_dual() and Network().
DEVICE = 'cuda:1'
```

``` ipython
import sys
sys.path.insert(0, '../../../')

from notebooks.setup import *

import pandas as pd
import torch.nn as nn
from time import perf_counter
from scipy.stats import circmean

from src.network import Network
from src.plot_utils import plot_con
from src.decode import decode_bump, circcvl
from src.lr_utils import masked_normalize, clamp_tensor, normalize_tensor
```

Test
====

utils
-----

``` ipython
import pickle as pkl

def pkl_save(obj, name, path="."):
    """Pickle ``obj`` to the file ``<path>/<name>.pkl``.

    Args:
        obj: Any picklable object.
        name: File name without the ``.pkl`` extension.
        path: Directory that receives the file (must already exist).
    """
    # The context manager guarantees the handle is flushed and closed even
    # if pickling raises part-way through.
    with open(path + "/" + name + ".pkl", "wb") as handle:
        pkl.dump(obj, handle)


def pkl_load(name, path="."):
    """Load and return the object pickled in ``<path>/<name>.pkl``.

    Args:
        name: File name without the ``.pkl`` extension.
        path: Directory containing the file.

    Returns:
        The unpickled object.
    """
    # SECURITY: unpickling executes arbitrary code embedded in the file;
    # only load files produced by a trusted source.
    # The context manager closes the handle instead of leaking it.
    with open(path + "/" + name + '.pkl', "rb") as handle:
        return pkl.load(handle)

```

``` ipython
def add_vlines(ax=None, mouse=""):
    """Shade selected task epochs as translucent vertical bands.

    The bands are drawn on ``ax`` when one is given; otherwise they go through
    the ``plt`` state-machine interface (i.e. onto the current axes).
    ``mouse`` is accepted for call compatibility but is not used.
    """
    # Full epoch table (start, stop); only the epochs named in `shaded`
    # are actually drawn, the others are kept for reference.
    epochs = {
        "BL": [0, 1],
        "STIM": [1, 2],
        "ED": [2, 3],
        "DIST": [3, 4],
        "MD": [4, 5],
        "CUE": [5, 5.5],
        "RWD": [5.5, 6.0],
        "LD": [6.0, 7.0],
        "TEST": [7.0, 8.0],
        "RWD2": [11, 12],
    }

    # (epoch, colour) pairs, in drawing order.
    shaded = (("STIM", "b"), ("DIST", "b"), ("TEST", "b"), ("CUE", "g"))

    # Both pyplot and an Axes object expose axvspan with the same signature.
    target = plt if ax is None else ax
    for epoch_name, band_color in shaded:
        start, stop = epochs[epoch_name]
        target.axvspan(start, stop, alpha=0.1, color=band_color)


```

``` ipython
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import sem, t

def mean_ci(data):
    """Return the NaN-aware mean over axis 0 and its 95% CI half-width.

    The half-width is the two-tailed Student-t critical value (with
    ``n - 1`` degrees of freedom, ``n`` being the number of non-NaN entries
    along axis 0) multiplied by the standard error of the mean.

    Returns:
        (mean, ci): plot ``mean - ci`` / ``mean + ci`` to get the interval.
    """
    center = np.nanmean(data, axis=0)
    std_err = sem(data, axis=0, nan_policy='omit')

    # Per-column sample size, counting only the valid (non-NaN) entries.
    n_valid = np.sum(~np.isnan(data), axis=0)

    # 0.975 quantile <=> two-tailed 95% interval.
    half_width = t.ppf(0.975, df=n_valid - 1) * std_err

    return center, half_width
```

``` ipython
def plot_overlap_label(readout, y, task=0, label=['pair', 'unpair'], figname='fig.svg', title='first'):
    '''
    Plot condition-averaged readouts for one task and save the figure.

    Three panels are drawn: readout component 0 vs time, readout component 1
    vs time, and component 1 against component 0. Each panel holds four
    curves, one per (sample, pair) combination, labelled 'AD', 'AC', 'BD', 'BC'.

    readout is indexed as readout[trial_mask, :, k]: axis 0 is selected by the
            boolean trial mask, axis 1 has one entry per time point, axis 2 is
            the readout component (only k = 0 and 1 are used).
    y[0] is pairs, y[1] is samples, y[2] is task if present. When y has exactly
            2 rows it is padded with zero rows and `task` is forced to 0, so
            every trial is selected.
    task is the value matched against y[2]; it also indexes the colour list
            ['r', 'b', 'g'], so task=-1 maps to 'g'.
    label is overwritten inside the function, so the argument has no effect.
    figname is the file name written under '../figures/dual/'.
    title is not used.

    Relies on the module-level names `width`, `height`, `plt`, `mean_ci` and
    `add_vlines`.
    '''

    size = y.shape[0]
    if size ==2:
        # No task row supplied: append zero rows so y_[2] == 0 matches every trial.
        ones_slice = np.zeros(y.shape)
        y_ = np.vstack((y.copy(), ones_slice))
        task = 0
    else:
        y_ = y.copy()
        # NOTE(review): `tasks` is never read afterwards.
        tasks = [0, 1, -1]


    # sharey=True: all three panels, including the component-vs-component one, share the y axis.
    fig, ax = plt.subplots(1, 3, figsize=[3*width, height], sharey=True)

    # Time axis spans 0..10 with one sample per entry of readout's axis 1.
    time = np.linspace(0, 10, readout.shape[1])
    colors = ['r', 'b', 'g']
    ls = ['--', '-', '-.', ':']
    # NOTE(review): this rebinding discards the `label` argument.
    label = ['AD', 'AC', 'BD', 'BC']
    # Filled in (k, j, i) order so it can be reshaped to (2, 2, 2, n_time) below.
    mean_overlaps = []
    for k in range(2): # readout component
        for j in range(2): # sample
            for i in range(2): # pair
                # Trials matching this pair / sample / task combination.
                data = readout[(y_[0]==i) & (y_[1]==j) & (y_[2]==task), :, k]
                mean, ci = mean_ci(data)
                mean_overlaps.append(mean)
                # Line style and legend label encode (sample, pair) via index i + 2*j.
                ax[k].plot(time, mean, ls=ls[i+2*j], label=label[i+2*j], color=colors[task], alpha=1-j/4)
                ax[k].fill_between(time, mean - ci, mean + ci, color=colors[task], alpha=0.1)

        add_vlines(ax[k])
        ax[k].set_xlabel('Time (s)')

        if k==0:
            ax[k].set_ylabel('A/B Overlap (Hz)')
        elif k==1:
            ax[k].set_ylabel('GNG Overlap (Hz)')
        else:
            # Unreachable while the outer loop is range(2).
            ax[k].set_ylabel('Readout (Hz)')

        ax[k].axhline(0, color='k', ls='--')

    # Axes: (readout component, sample, pair, time).
    mean_overlaps = np.array(mean_overlaps).reshape((2, 2, 2, -1))

    # Last panel: trajectory of component 1 against component 0 for each condition.
    for j in range(2): # sample
        for i in range(2): # pair
            ax[-1].plot(mean_overlaps[0, j, i], mean_overlaps[1, j, i], color=colors[task], ls=ls[i+2*j], label=label[i+2*j])

    ax[-1].set_xlabel('A/B Overlap (Hz)')
    ax[-1].set_ylabel('Choice Overlap (Hz)')

    # plt.legend acts on the current axes.
    plt.legend(fontsize=10)
    plt.savefig('../figures/dual/%s' % figname, dpi=300)
    plt.show()
```

run
---

``` ipython
# Root of the NeuroFlame checkout (same value as in the Imports cell).
REPO_ROOT = "/home/leon/models/NeuroFlame"
# Config file name handed to test_dual() and Network().
conf_name = "train_dual.yml"
# Torch device string (CUDA device index 1).
DEVICE = 'cuda:1'
# Forwarded to test_dual(); its meaning is defined there — TODO document.
thresh= 5
# NOTE: rebound by the `for seed in range(0, 100)` sweep in the next cells.
seed = 1
```

``` ipython
sys.path.insert(0, '../../../src')
from src.train.dual.train_dual import test_dual
```

``` ipython
# Evaluate every seed under both IF_OPTO settings (0 and 1) and collect the
# results as nested lists indexed [seed][IF_OPTO].
accuracies, readouts, covariances, labels = [], [], [], []

# `seed`, `readout` and `y_labels` are intentionally left as module-level
# names: later cells reuse the values from the final iteration.
state = 'train'
for seed in range(0, 100):
    per_condition = []
    for IF_OPTO in range(2):
        print(seed, state)
        readout, y_labels, cov, accuracy = test_dual(REPO_ROOT, conf_name, seed, state, thresh, DEVICE, IF_OPTO)
        per_condition.append((accuracy, cov, readout, y_labels))

    # Transpose [(acc, cov, readout, labels), ...] into one list per quantity.
    acc_, cov_, readout_, labels_ = map(list, zip(*per_condition))

    accuracies.append(acc_)
    readouts.append(readout_)
    covariances.append(cov_)
    labels.append(labels_)
```

``` ipython
# `accuracies` is nested as [seed][IF_OPTO][...]; move the seed axis last so that
# acc[IF_OPTO] selects a condition and acc[..., s] selects a seed.
print(np.array(accuracies).shape)
acc = np.moveaxis(np.array(accuracies), 0, -1)
print(acc.shape)
```

Perf
----

``` ipython
# Performance figure: left panel = the first three accuracy entries (plotted as
# DPA performance), right panel = the last entry (plotted as Go/NoGo
# performance), each for IF_OPTO = 0 ('OFF') and 1 ('ON').
fig, ax = plt.subplots(1, 2, figsize=[1.5*width, height])

# The last axis of `acc` indexes seeds (see the moveaxis above).
n_seeds = acc.shape[-1]

# Horizontal jitter, one draw per seed, so the individual points do not overlap.
rd = np.random.normal(size=(n_seeds)) / 10

# NOTE: rebinds the seaborn palette defined in the Imports cell.
pal = ['r', 'b', 'g']
for j in range(2):
    for i in range(3):
        acc_mean = np.mean(acc[j][0][i], -1)
        # Standard error over seeds.
        acc_sem = np.std(acc[j][0][i], axis=-1, ddof=1) / np.sqrt(n_seeds)

        # Conditions are offset by 4 on the x axis: 0..2 for OFF, 4..6 for ON.
        ax[0].errorbar(i+4*j, acc_mean, yerr=acc_sem, fmt='o', color=pal[i], ecolor=pal[i], elinewidth=3, capsize=5)
        ax[0].plot(i+rd + 4*j, acc[j][0][i], 'o', alpha=0.25)

# ax[0].set_xlim(-1, 4)
# ax[0].set_ylim(0.4, 1.1)

ax[0].set_ylabel('DPA Perf.')
ax[0].set_xticks([1, 5], ['OFF', 'ON'])
ax[0].axhline(y=0.5, color='k', linestyle='--')

for i in range(2):
    acc_mean = np.mean(acc[i][0][-1], -1)
    # Standard error over seeds. The previous denominator, len(acc[0][-1]),
    # measured the task axis rather than the seed axis, which mis-scaled the
    # error bars; it now matches the DPA panel above.
    acc_sem = np.std(acc[i][0][-1], axis=-1, ddof=1) / np.sqrt(n_seeds)

    ax[1].errorbar(i, acc_mean, yerr=acc_sem, fmt='o', color='k', ecolor='k', elinewidth=3, capsize=5)
    ax[1].plot(rd+i, acc[i][0][-1], 'ko', alpha=.25)

ax[1].set_xlim(-1, 2)
ax[1].set_ylim(0.4, 1.1)

ax[1].set_ylabel('Go/NoGo Perf.')
ax[1].set_xticks([0, 1], ['OFF', 'ON'])
ax[1].axhline(y=0.5, color='k', linestyle='--')

# NOTE: `seed` here is whatever the sweep loop left behind (its last value when
# the cells are run in order); it only tags the file name.
plt.savefig('../figures/dual/dual_perf_multi_opto_%d.svg' % seed, dpi=300)

plt.show()
```

overlaps
--------

``` ipython
# NOTE: `readout` / `y_labels` hold the result of the LAST test_dual() call of
# the sweep (final seed, IF_OPTO = 1), not an aggregate over seeds.
print(readout.shape, y_labels.shape)
```

``` ipython
# Overlaps for the trials whose task label (y_labels[2]) equals 0; drawn in red.
plot_overlap_label(readout, y_labels, task=0, figname='overlaps_naive_dpa.svg')
```

``` ipython
# Overlaps for the trials whose task label (y_labels[2]) equals 1; drawn in blue.
plot_overlap_label(readout, y_labels, task=1, figname='overlaps_naive_go.svg')
```

``` ipython
# Overlaps for the trials whose task label (y_labels[2]) equals -1; the negative
# index picks the last colour ('g') inside plot_overlap_label.
plot_overlap_label(readout, y_labels, task=-1, figname='overlaps_naive_nogo.svg')
```

``` ipython
```

Perf vs overlap
---------------

``` ipython
# Per-seed accuracy difference between the two IF_OPTO conditions (index 1 minus index 0).
delta_perf = acc[1] - acc[0]
print(delta_perf.shape)
```

``` ipython
# Stack the nested [seed][IF_OPTO] readout lists into one array, then move the
# seed axis last (same layout as `acc`). Requires all readouts to share a shape.
readouts = np.array(readouts)
print(readouts.shape)
reads = np.moveaxis(readouts, 0, -1)
print(reads.shape)
```

``` ipython
# Readout difference between the two IF_OPTO conditions (index 1 minus index 0).
delta_read = reads[1] - reads[0]
print(delta_read.shape)
```

``` ipython
# Same settings as in the "run" section, restated so this part can be run on its own.
REPO_ROOT = "/home/leon/models/NeuroFlame"
conf_name = "train_dual.yml"
DEVICE = 'cuda:1'

# Seed used to build the Network below (only its step/timing attributes are read afterwards).
seed = 9
print(seed)
```

``` ipython
# Build a single-batch network from the config.
model = Network(conf_name, REPO_ROOT, VERBOSE=0, DEVICE=DEVICE, SEED=seed, N_BATCH=1)
# NOTE(review): the CPU fallback only applies to .to(device); Network() itself
# still receives the raw 'cuda:1' string — confirm it copes without CUDA.
device = torch.device(DEVICE if torch.cuda.is_available() else 'cpu')
model.to(device)
```

``` ipython
# Step index of every N_WINDOW-th step, counted from the end of the first N_STEADY steps.
# Presumably these line up with the time bins of the readouts — TODO confirm.
steps = np.arange(0, model.N_STEPS - model.N_STEADY, model.N_WINDOW)
# Select the bins from N_STIM_OFF[3] (inclusive) up to N_STIM_ON[4] (exclusive),
# both shifted by N_STEADY. Presumably the late-delay ("LD") epoch — TODO confirm.
mask = (steps >= (model.N_STIM_OFF[3].cpu().numpy() - model.N_STEADY)) & (steps < (model.N_STIM_ON[4].cpu().numpy() - model.N_STEADY))
LD_idx = np.where(mask)[0]
print('LD', LD_idx)
```

``` ipython
# Keep only the LD bins on axis 1, then NaN-average over axes 0 and 1.
# Presumably leaves (readout component, seed) — TODO confirm with the printed shape.
delta_read_LD = np.nanmean(delta_read[:, LD_idx], (0,1))
print(delta_read_LD.shape)
```

``` ipython
from scipy.stats import pearsonr

# One point per seed: change in performance (On - Off) against change in the
# LD-averaged readout component 1 (On - Off), with the Pearson correlation in
# each panel title.
fig, ax = plt.subplots(1, 2, figsize=[2*width, height])

# x values shared by both panels.
choice_delta = delta_read_LD[1]

# (performance change, y-axis label, y-axis limits) for each panel.
panel_specs = [
    (np.nanmean(delta_perf[0, :3], 0), '$\\Delta$ DPA Perf On-Off', [-.25, .1]),
    (delta_perf[0, -1], '$\\Delta$ GoNoGo Perf On-Off', [-.2, .1]),
]

for panel, (dperf, y_label, y_limits) in zip(ax, panel_specs):
    panel.scatter(choice_delta, dperf)
    panel.set_ylabel(y_label)
    panel.set_xlabel('$\\Delta$ Choice Overlap On-Off')
    panel.set_ylim(y_limits)

    corr, p_value = pearsonr(choice_delta, dperf)
    panel.set_title("Corr: %.2f, p-value: %.3f" % (corr, p_value))

plt.savefig('model_corr.svg')
plt.show()
```

``` ipython
```