Notebook Settings
=================

``` ipython
# Auto-reload edited project modules so changes under src/ are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
# Run the shared notebook setup script (its contents are not visible here).
%run ../notebooks/setup.py
# Render matplotlib figures inline, as PNG.
%matplotlib inline
%config InlineBackend.figure_format = 'png'
```

Imports
=======

``` ipython
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
from torch.utils.data import Dataset, TensorDataset, DataLoader

# Root of the NeuroFlame repository; passed to Network together with the config name.
REPO_ROOT = "/home/leon/models/NeuroFlame"

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Ten-colour qualitative palette for plots.
pal = sns.color_palette("tab10")
# Target GPU; several helpers below (optimization, validation_step, DualLoss, DualPerf) read this global directly.
DEVICE = 'cuda:1'
```

``` ipython
import sys
# Make the repository root importable so the `notebooks` and `src` packages resolve.
sys.path.insert(0, '../')

# NOTE(review): star import; presumably provides plotting globals such as `width`/`height`
# used by the plot helpers below — confirm against notebooks/setup.py.
from notebooks.setup import *

import pandas as pd
import torch.nn as nn
from time import perf_counter
from scipy.stats import circmean

from src.network import Network
from src.plot_utils import plot_con
from src.decode import decode_bump, circcvl
from src.lr_utils import masked_normalize, clamp_tensor, normalize_tensor
```

Helpers
=======

Data Split
----------

``` ipython
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit

def split_data(X, Y, train_perc=0.8, batch_size=32):
    """Split (X, Y) into stratified train/validation DataLoaders.

    Stratification uses the first label entry of each sample: ``Y[:, 0]``
    for 2-D labels, ``Y[:, 0, 0]`` for 3-D labels, and in general the first
    element along every non-sample axis (1-D labels are used as is).

    Args:
        X: input tensor whose first axis indexes samples.
        Y: label tensor whose first axis indexes samples.
        train_perc: fraction of samples placed in the training split.
        batch_size: mini-batch size for both loaders.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    # Flattening the non-sample axes and taking column 0 reproduces
    # Y[:, 0] (2-D) and Y[:, 0, 0] (3-D) without branching on the rank.
    strat_key = Y.reshape(Y.shape[0], -1)[:, 0]

    X_train, X_test, Y_train, Y_test = train_test_split(
        X, Y,
        train_size=train_perc,
        stratify=strat_key.cpu().numpy(),
        shuffle=True,
    )

    print(X_train.shape, X_test.shape)
    print(Y_train.shape, Y_test.shape)

    train_dataset = TensorDataset(X_train, Y_train)
    val_dataset = TensorDataset(X_test, Y_test)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader
```

Optimization
------------

``` ipython
def torch_angle_AB(U, V):
    """Return the angle between 1-D tensors ``U`` and ``V`` in whole degrees.

    A small epsilon in the denominator avoids division by zero, so a null
    vector yields 90 degrees. The cosine is clamped to [-1, 1] because
    floating-point rounding (notably in float32) can push it marginally
    outside that range, where ``acos`` would return NaN.

    Returns:
        0-d tensor holding the rounded angle in degrees.
    """
    dot_product = torch.dot(U, V)
    norm_product = torch.linalg.norm(U) * torch.linalg.norm(V)

    cos_theta = dot_product / (norm_product + .00001)
    cos_theta = torch.clamp(cos_theta, -1.0, 1.0)

    return torch.round(torch.rad2deg(torch.acos(cos_theta)))
```

``` ipython
def _zero_low_rank_grad(model, cols):
    """Best-effort zeroing of the gradients of low-rank U/V columns ``cols`` (``-1`` = all columns)."""
    try:
        sel = slice(None) if cols == -1 else cols
        model.low_rank.U.grad[:, sel] = 0
        model.low_rank.V.grad[:, sel] = 0
    except (AttributeError, TypeError, IndexError, RuntimeError):
        # Model without low-rank factors, a frozen factor whose ``grad`` is
        # None, or an index that does not apply: nothing to zero. Kept
        # best-effort, but no longer swallows KeyboardInterrupt/SystemExit.
        pass


def training_step(dataloader, model, loss_fn, optimizer, penalty=None, lbd=1, clip_grad=0, zero_grad=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        dataloader: iterable of ``(X, y)`` batches.
        model: network exposing ``device``; the loss is computed on
            ``model.readout`` (set by the forward pass), not on the forward's
            return value.
        loss_fn: callable ``loss_fn(readout, y)`` returning a scalar tensor.
        optimizer: torch optimizer over the model parameters.
        penalty, lbd: accepted for API compatibility; currently unused.
        clip_grad: if truthy, clip the global gradient norm to 10.
        zero_grad: low-rank columns whose gradient is zeroed before every
            optimizer step (``-1`` = all columns of ``low_rank.U`` and
            ``low_rank.V``); ``None`` disables it.

    Returns:
        Mean loss over the batches, or NaN when the loader yields no batch.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0

    for X, y in dataloader:
        X, y = X.to(model.device), y.to(model.device)

        optimizer.zero_grad()

        model(X)  # forward pass populates model.readout
        loss = loss_fn(model.readout, y)
        loss.backward()

        if zero_grad is not None:
            _zero_low_rank_grad(model, zero_grad)

        if clip_grad:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)

        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    # An empty loader has no defined loss; NaN lets callers stop cleanly
    # instead of hitting a ZeroDivisionError.
    return total_loss / n_batches if n_batches else float('nan')
```

``` ipython
def validation_step(dataloader, model, loss_fn):
    """Return the sample-weighted mean validation loss, without gradient tracking.

    Each batch loss is weighted by its batch size, so a smaller final batch
    does not bias the average (the per-sample weighting also used by the
    redefinition of ``validation_step`` in the next cell, which supersedes
    this one when the notebook runs in order). Batches go to ``model.device``.

    Returns:
        Mean loss per sample, or NaN when the loader is empty.
    """
    model.eval()

    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(model.device), y.to(model.device)

            model(X)  # forward pass populates model.readout
            batch_loss = loss_fn(model.readout, y)
            total_loss += batch_loss.item() * X.size(0)
            n_samples += X.size(0)

    return total_loss / n_samples if n_samples else float('nan')
```

``` ipython
def validation_step(dataloader, model, loss_fn):
    """Evaluate ``loss_fn`` over the validation loader and return the per-sample mean.

    Batches are moved to the global ``DEVICE`` when CUDA is available (CPU
    otherwise). Each batch loss is weighted by its batch size and the sum is
    divided by ``len(dataloader.dataset)``.
    """
    target_device = torch.device(DEVICE if torch.cuda.is_available() else "cpu")
    n_samples = len(dataloader.dataset)

    model.eval()
    running_loss = 0.0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(target_device)
            targets = targets.to(target_device)

            model(inputs)  # forward pass populates model.readout
            running_loss += loss_fn(model.readout, targets).item() * inputs.size(0)

    return running_loss / n_samples
```

``` ipython
def optimization(model, train_loader, val_loader, loss_fn, optimizer, num_epochs=100, penalty=None, lbd=1, thresh=.005, zero_grad=None):
    """Train ``model`` for up to ``num_epochs`` epochs with exponential LR decay.

    After every epoch the learning rate is multiplied by 0.9, the train and
    validation losses are recorded, and the angle between
    ``low_rank.U[slices[0], 0]`` and ``low_rank.V[slices[0], 1]`` is printed.
    Training stops early when both losses fall below ``thresh``, when the
    validation loss exceeds 300, or when either loss is NaN.

    Args:
        model: network with ``low_rank.U``/``low_rank.V`` and ``slices``;
            moved to the global ``DEVICE`` when CUDA is available, else CPU.
        train_loader, val_loader: batch iterables for training/validation.
        loss_fn: callable ``loss_fn(readout, targets)``.
        optimizer: torch optimizer whose learning rate the scheduler decays.
        num_epochs: maximum number of epochs.
        penalty, lbd: forwarded to ``training_step``.
        thresh: early-stop threshold applied to both losses.
        zero_grad: forwarded to ``training_step`` (low-rank columns whose
            gradient is zeroed).

    Returns:
        ``(loss_list, val_loss_list)``: per-epoch training and validation losses.
    """
    import math

    # Epoch-based schedulers such as ExponentialLR must be stepped without
    # arguments: a positional argument is taken as the deprecated ``epoch``,
    # which would set lr = lr0 * gamma ** val_loss. (Only ReduceLROnPlateau
    # takes the monitored metric.)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    device = torch.device(DEVICE if torch.cuda.is_available() else 'cpu')
    model.to(device)

    loss_list = []
    val_loss_list = []

    for epoch in range(num_epochs):
        loss = training_step(train_loader, model, loss_fn, optimizer, penalty, lbd, zero_grad=zero_grad)
        val_loss = validation_step(val_loader, model, loss_fn)

        scheduler.step()
        loss_list.append(loss)
        val_loss_list.append(val_loss)

        # Diagnostic only: no need to build an autograd graph.
        with torch.no_grad():
            memory = model.low_rank.U[model.slices[0], 0]
            readout = model.low_rank.V[model.slices[0], 1]
            angle = torch_angle_AB(memory, readout).item()

        print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {loss:.4f}, Validation Loss: {val_loss:.4f}, Angle(U, W) : {angle} °')

        if val_loss < thresh and loss < thresh:
            print(f'Stopping training as loss has fallen below the threshold: {loss}, {val_loss}')
            break

        if val_loss > 300:
            print(f'Stopping training as loss is too high: {val_loss}')
            break

        if math.isnan(loss) or math.isnan(val_loss):
            print('Stopping training as loss is NaN.')
            break

    return loss_list, val_loss_list
```

Loss
----

``` ipython
def imbalance_func(target, imbalance):
    """Map binary targets to per-element weights.

    Elements equal to 0 get ``imbalance``, elements equal to 1 get 1, and
    every other value gets 0. The result has ``target``'s shape and dtype.
    """
    weights = torch.zeros_like(target)
    for label_value, weight in ((0, imbalance), (1, 1)):
        weights[target == label_value] = weight
    return weights
```

``` ipython
import torch
import torch.nn as nn
import torch.nn.functional as F

class SignBCELoss(nn.Module):
    """Hinge-style sign loss on the time-averaged readout, optionally mixed with BCE.

    The returned loss is the mean of
    ``(1 - alpha) * BCEWithLogits(readout, targets) + alpha * sign_loss``;
    the BCE term is only evaluated when ``alpha != 1``.

    With ``overlap = sign(2 * targets - 1) * mean_readout`` (unless noted),
    ``sign_loss`` depends on ``imbalance``:

    * ``-1``: overlap is ``sign(targets) * mean_readout`` and the ``0`` rule
      is applied, so targets equal to 0 incur no penalty.
    * ``0``: ``relu(imbalance_func(targets, 0) * thresh - overlap)``.
    * ``> 1``: ``relu(sign(targets) * thresh - imbalance_func(targets, imbalance) * overlap)``.
    * otherwise: ``relu(thresh - overlap)``.
    """

    def __init__(self, alpha=1.0, thresh=2.0, imbalance=0):
        super().__init__()
        self.alpha = alpha
        self.thresh = thresh

        # Selects the sign-loss variant; owners such as DualLoss reassign it
        # between forward calls.
        self.imbalance = imbalance
        self.bce_with_logits = nn.BCEWithLogitsLoss()

    def forward(self, readout, targets):
        """Return the scalar combined loss for ``readout`` logits and ``targets``."""
        bce_loss = self.bce_with_logits(readout, targets) if self.alpha != 1.0 else 0.0

        # average readout over bins (axis 1), keeping a trailing axis so it
        # broadcasts against targets
        mean_readout = readout.mean(dim=1).unsqueeze(-1)

        # Use a local copy: the -1 mode must not rewrite self.imbalance,
        # otherwise every later call would silently switch to another mode.
        imbalance = self.imbalance
        if imbalance == -1:
            # only penalizing not licking when pair
            sign_overlap = torch.sign(targets) * mean_readout
            imbalance = 0
        else:
            sign_overlap = torch.sign(2 * targets - 1) * mean_readout

        if imbalance > 1.0:
            sign_loss = F.relu(torch.sign(targets) * self.thresh - imbalance_func(targets, imbalance) * sign_overlap)
        elif imbalance == 0:
            sign_loss = F.relu(imbalance_func(targets, imbalance) * self.thresh - sign_overlap)
        else:
            sign_loss = F.relu(self.thresh - sign_overlap)

        combined_loss = (1 - self.alpha) * bce_loss + self.alpha * sign_loss

        return combined_loss.mean()

```

``` ipython
class DualLoss(nn.Module):
    """Dual-task loss: DPA sign loss, optional DRT sign loss, and baseline penalties.

    Args:
        alpha, thresh: forwarded to the inner ``SignBCELoss``.
        cue_idx: time bins scored by the DRT term; empty disables that term.
        rwd_idx: time bins scored by the DPA term.
        zero_idx: time bins where readout channel ``read_idx[0]`` is pulled
            towards 0 (SmoothL1).
        read_idx: readout channels. ``read_idx[0]`` serves DPA and the
            baseline term; ``read_idx[1]`` (if given) serves DRT and is pulled
            to 0 over the first 9 bins.
        imbalance: ``SignBCELoss`` mode; a pair ``(dpa, drt)`` when
            ``cue_idx`` is non-empty.
    """

    def __init__(self, alpha=1.0, thresh=2.0, cue_idx=None, rwd_idx=-1, zero_idx=None, read_idx=None, imbalance=0):
        super().__init__()
        # None defaults avoid one mutable list being shared by every instance.
        cue_idx = [] if cue_idx is None else cue_idx
        zero_idx = [] if zero_idx is None else zero_idx
        read_idx = [-1] if read_idx is None else read_idx

        self.alpha = alpha
        self.thresh = thresh

        self.imbalance = imbalance

        # baseline (zeroed) time bins
        self.zero_idx = zero_idx
        # time bins scored for DRT
        self.cue_idx = torch.tensor(cue_idx, dtype=torch.int, device=DEVICE)
        # time bins scored for DPA
        self.rwd_idx = torch.tensor(rwd_idx, dtype=torch.int, device=DEVICE)

        # readout channels
        self.read_idx = read_idx

        self.loss = SignBCELoss(self.alpha, self.thresh, self.imbalance)
        self.l1loss = nn.SmoothL1Loss()

    def forward(self, readout, targets):
        """Return the total loss.

        ``readout`` is indexed as ``[:, time bin, channel]`` (assumed
        (trials, bins, channels) — TODO confirm). ``targets`` is
        (trials, bins) without DRT, or (trials, 2, bins) with DRT, where index
        0 holds DPA labels and index 1 DRT labels.
        """
        BL_loss = readout.new_zeros(())

        baseline = readout[:, self.zero_idx, self.read_idx[0]]
        # An empty selection would make SmoothL1Loss return NaN.
        if baseline.numel() > 0:
            BL_loss = BL_loss + self.l1loss(baseline, torch.zeros_like(baseline))

        # zero the second channel over the first 9 bins (presumably pre-stimulus);
        # the zero target is shaped from the slice itself so it always matches.
        if len(self.read_idx) > 1:
            early = readout[:, :9, self.read_idx[1]]
            BL_loss = BL_loss + self.l1loss(early, torch.zeros_like(early))

        if self.cue_idx.numel() == 0:
            DPA_loss = self.loss(readout[:, self.rwd_idx, self.read_idx[0]], targets)
            return DPA_loss + BL_loss

        self.loss.imbalance = self.imbalance[0]
        DPA_loss = self.loss(readout[:, self.rwd_idx, self.read_idx[0]], targets[:, 0, :self.rwd_idx.shape[0]])

        self.loss.imbalance = self.imbalance[1]
        DRT_loss = self.loss(readout[:, self.cue_idx, self.read_idx[1]], targets[:, 1, :self.cue_idx.shape[0]])

        return DPA_loss + DRT_loss + BL_loss
```

``` ipython
import torch
import torch.nn as nn
import torch.nn.functional as F

class Accuracy(nn.Module):
    """Per-trial correctness of a thresholded, time-averaged readout.

    A trial is predicted positive when its readout averaged over axis 1 is
    ``>= thresh``; the prediction is compared with ``targets[:, 0]`` and the
    result is returned as 1.0 (match) or 0.0 (mismatch) per trial.
    """

    def __init__(self, thresh=4.0):
        super().__init__()
        self.thresh = thresh

    def forward(self, readout, targets):
        predicted = readout.mean(dim=1) >= self.thresh
        matches = predicted == targets[:, 0]
        return 1.0 * matches

```

``` ipython
class DualPerf(nn.Module):
    """Per-trial accuracy for the dual task, using DualLoss's index layout.

    Args:
        alpha, imbalance, zero_idx: stored for signature parity with
            ``DualLoss``; not used by the accuracy computation.
        thresh: decision threshold on the time-averaged readout.
        cue_idx: time bins scored for DRT accuracy; empty means DPA only.
        rwd_idx: time bins scored for DPA accuracy.
        read_idx: readout channels, ``read_idx[0]`` for DPA and
            ``read_idx[1]`` for DRT.
    """

    def __init__(self, alpha=1.0, thresh=2.0, cue_idx=None, rwd_idx=-1, zero_idx=None, read_idx=None, imbalance=0):
        super().__init__()
        # None defaults avoid one mutable list being shared by every instance.
        cue_idx = [] if cue_idx is None else cue_idx
        zero_idx = [] if zero_idx is None else zero_idx
        read_idx = [-1] if read_idx is None else read_idx

        self.alpha = alpha
        self.thresh = thresh

        self.imbalance = imbalance

        # baseline time bins (unused here)
        self.zero_idx = zero_idx
        # time bins scored for DRT
        self.cue_idx = torch.tensor(cue_idx, dtype=torch.int, device=DEVICE)
        # time bins scored for DPA
        self.rwd_idx = torch.tensor(rwd_idx, dtype=torch.int, device=DEVICE)

        # readout channels
        self.read_idx = read_idx

        self.loss = Accuracy(thresh=self.thresh)

    def forward(self, readout, targets):
        """Return DPA accuracy per trial, or ``(DPA, DRT)`` when ``cue_idx`` is non-empty.

        Labels equal to -1 are scored as 0. The caller's ``targets`` tensor is
        left unmodified (a new tensor is built for the remapping).
        """
        targets = targets.masked_fill(targets == -1, 0)

        if self.cue_idx.numel() == 0:
            return self.loss(readout[:, self.rwd_idx, self.read_idx[0]], targets)

        DPA_perf = self.loss(readout[:, self.rwd_idx, self.read_idx[0]], targets[:, 0, :self.rwd_idx.shape[0]])
        DRT_perf = self.loss(readout[:, self.cue_idx, self.read_idx[1]], targets[:, 1, :self.cue_idx.shape[0]])

        return DPA_perf, DRT_perf
```

Other
-----

``` ipython
def angle_AB(A, B):
    """Return the angle between vectors ``A`` and ``B`` in degrees, truncated to int.

    A small epsilon guards against zero-norm vectors. The cosine is clipped to
    [-1, 1] because rounding (notably with float32 inputs) can push it just
    outside that range, where ``arccos`` returns NaN and ``int()`` raises.
    """
    A_norm = A / (np.linalg.norm(A) + 1e-5)
    B_norm = B / (np.linalg.norm(B) + 1e-5)

    cos_theta = np.clip(A_norm @ B_norm, -1.0, 1.0)
    return int(np.arccos(cos_theta) * 180 / np.pi)
```

``` ipython
def get_theta(a, b, GM=0, IF_NORM=0):
    """Return the angle of each point ``(a[i], b[i])``, wrapped to [0, 2*pi).

    Args:
        a, b: 1-D arrays giving the two coordinates of each point.
        GM: if truthy, first orthogonalise ``b`` against ``a`` (one
            Gram-Schmidt step).
        IF_NORM: if truthy, scale both coordinate vectors to unit norm. This
            is applied after the Gram-Schmidt step, so the orthogonalised
            ``b`` is the one that gets normalised.
    """
    u, v = a, b

    if GM:
        v = b - np.dot(b, a) / np.dot(a, a) * a

    if IF_NORM:
        # Normalise the (possibly orthogonalised) vectors; normalising the
        # raw ``b`` would silently undo the Gram-Schmidt step.
        u = u / np.linalg.norm(u)
        v = v / np.linalg.norm(v)

    return np.arctan2(v, u) % (2.0 * np.pi)
```

``` ipython
def get_idx(model, rank=1):
    """Return the neuron order that sorts neurons by angle in a low-rank plane.

    The rows of ``ksi`` are the columns of ``low_rank.U`` followed by those of
    ``low_rank.V``, restricted to the first ``model.Na[0]`` neurons. The
    linear readout weights are appended when the model has a compatible
    ``low_rank.linear``. Neurons are ordered by
    ``get_theta(ksi[0], ksi[rank])``.

    Returns:
        numpy array of neuron indices (argsort of the angles).
    """
    ksi = torch.hstack((model.low_rank.U, model.low_rank.V)).T
    ksi = ksi[:, :model.Na[0]]

    try:
        readout = model.low_rank.linear.weight.data
        ksi = torch.vstack((ksi, readout))
    except (AttributeError, RuntimeError):
        # No linear readout on this model (AttributeError), or its weights do
        # not match the sliced ksi width (RuntimeError): fall back to the
        # U/V rows only. Other errors, including KeyboardInterrupt, propagate.
        pass

    print('ksi', ksi.shape)

    ksi = ksi.cpu().detach().numpy()
    theta = get_theta(ksi[0], ksi[rank])

    return theta.argsort()
```

``` ipython
def get_overlap(model, rates):
    """Project ``rates`` onto every row of ``model.odors``, divided by the size of the last axis."""
    odor_vectors = model.odors.cpu().detach().numpy()
    n_units = rates.shape[-1]
    projection = rates @ odor_vectors.T
    return projection / n_units
```

``` ipython
import scipy.stats as stats

def plot_smooth(data, ax, color):
    """Plot the mean over axis 0 of ``data`` with a shaded band on ``ax``.

    The x axis is the column index of ``data``. The band is the mean ±1.96
    times the sample standard deviation (ddof=1) over axis 0. Note this is a
    spread band, not a confidence interval of the mean.
    """
    mean = data.mean(axis=0)
    spread = data.std(axis=0, ddof=1) * 1.96

    ax.plot(mean, color=color)
    ax.fill_between(range(data.shape[1]), mean - spread, mean + spread, alpha=0.25, color=color)

```

``` ipython
def convert_seconds(seconds):
    """Split a duration in seconds into ``(hours, minutes, seconds)``."""
    hours, within_hour = divmod(seconds, 3600)
    minutes = within_hour // 60
    return hours, minutes, seconds % 60
```

Plots
-----

``` ipython
import pickle as pkl
import os
def pkl_save(obj, name, path="."):
    """Pickle ``obj`` to ``<path>/<name>.pkl``, creating ``path`` if needed."""
    os.makedirs(path, exist_ok=True)
    destination = path + "/" + name + ".pkl"
    print("saving to", destination)
    # The context manager guarantees the file is flushed and closed.
    with open(destination, "wb") as f:
        pkl.dump(obj, f)


def pkl_load(name, path="."):
    """Load and return the object pickled at ``<path>/<name>.pkl``.

    Security: unpickling can execute arbitrary code, so only load trusted files.
    """
    source = path + "/" + name + '.pkl'
    with open(source, "rb") as f:
        return pkl.load(f)

```

``` ipython
def add_vlines(ax=None, mouse=""):
    """Shade task epochs on ``ax``, or on the current pyplot axes when ``ax`` is None.

    Sample (1-2), distractor (3-4) and test (7-8) are shaded blue and the
    cue (5-5.5) green, all with alpha 0.1. ``mouse`` is accepted but unused.
    """
    shaded_epochs = (
        ((1, 2), "b"),      # sample
        ((3, 4), "b"),      # distractor
        ((7.0, 8.0), "b"),  # test
        ((5, 5.5), "g"),    # cue
    )
    target = plt if ax is None else ax
    for (start, stop), color in shaded_epochs:
        target.axvspan(start, stop, alpha=0.1, color=color)

```

``` ipython
def plot_rates_selec(rates, idx, thresh=0.5, figname='fig.svg'):
    """Show the first entry of ``rates`` as heat maps, raw (left) and reordered by ``idx`` (right).

    The colour scale spans 0 to ``thresh`` times the maximum of ``rates[0]``.
    The right panel's y ticks map neuron rank onto 0-360 degrees.
    NOTE(review): reads the globals ``width``, ``height`` and ``model`` (for
    ``model.Na[0]``) instead of taking them as arguments.
    """
    sorted_rates = rates[..., idx]
    fig, axes = plt.subplots(1, 2, figsize=[2*width, height])
    vmax = thresh * np.max(rates[0])

    raw_ax, sorted_ax = axes[0], axes[1]

    raw_ax.imshow(rates[0].T, aspect='auto', cmap='jet', vmin=0, vmax=vmax)
    raw_ax.set_ylabel('Neuron #')
    raw_ax.set_xlabel('Step')

    sorted_ax.imshow(sorted_rates[0].T, aspect='auto', cmap='jet', vmin=0, vmax=vmax)
    tick_positions = np.linspace(0, model.Na[0].cpu().detach(), 5)
    tick_labels = np.linspace(0, 360, 5).astype(int)
    sorted_ax.set_yticks(tick_positions, tick_labels)
    sorted_ax.set_ylabel('Pref. Location (°)')
    sorted_ax.set_xlabel('Step')

    plt.savefig(figname, dpi=300)
    plt.show()
```

``` ipython
def plot_overlap_label(readout, y, y1, labels=['pair', 'unpair'], figname='fig.svg'):
    """Plot trial-averaged readouts split by ``y``, one column per value of ``y1``.

    Columns select trials with ``y1 == 0``, ``1`` and ``-1`` (colours r/b/g).
    Row ``i`` shows readout channel ``i``: solid lines average (NaN-aware)
    trials with ``y == 1``, dashed lines trials with ``y == 0``.

    Assumes ``readout`` is (trials, time, channels) — TODO confirm. Only as many
    channels as the figure has rows (2) are drawn. Reads the globals
    ``width``/``height``; saves the figure to ``figname`` and shows it.
    """
    fig, ax = plt.subplots(2, 3, figsize=[3*width, 2*height], sharey=True)

    time = np.linspace(0, 8, readout.shape[1])
    trial = [0, 1, -1]
    colors = ['r', 'b', 'g']
    # Channel i goes in row i, matching the row y-labels set below.
    n_rows = min(readout.shape[-1], len(ax))
    for j in range(3):
        in_condition = y1 == trial[j]
        for i in range(n_rows):
            ax[i][j].plot(time, np.nanmean(readout[(y == 1) & in_condition, :, i], 0), ls='-', label=labels[0], color=colors[j])
            ax[i][j].plot(time, np.nanmean(readout[(y == 0) & in_condition, :, i], 0), ls='--', label=labels[1], color=colors[j])

            add_vlines(ax[i][j])
            ax[i][j].set_xlabel('Time (s)')

        ax[0][j].set_ylabel('Sample Overlap (Hz)')
        ax[1][j].set_ylabel('Go/NoGo Overlap (Hz)')

    plt.savefig(figname, dpi=300)
    plt.show()
```

``` ipython
def plot_avg_overlap(readout, n_batch, labels=['A', 'B'], figname='fig.svg'):
    """Plot per-channel readout overlaps contrasting the first and second half of the trials.

    Panel 0 shows channel 0 as first-half minus second-half traces (solid);
    other panels show the second-half traces (dashed, labelled 'Go').
    ``n_batch`` is accepted but unused. Assumes ``readout`` is
    (trials, time, channels) with an even number of trials and at most 2
    channels — TODO confirm. Reads the globals ``width``/``height``; saves the
    figure to ``figname`` and shows it.
    """
    fig, ax = plt.subplots(1, 2, figsize=[2*width, height])

    time = np.linspace(0, 8, readout.shape[1])
    size = readout.shape[0] // 2
    print(readout.shape[0], size)

    # No reshape: ``readout.reshape((3, ))`` only fits a 3-element array and
    # raised for any real input, while the indexing below needs the 3-D array.
    for i in range(readout.shape[-1]):
        if i == 0:
            ax[i].plot(time, (readout[:size, :, i].T - readout[size:, :, i].T), ls='-', label=labels[0])
        else:
            ax[i].plot(time, readout[size:, :, i].T, ls='--', label='Go')

        add_vlines(ax[i])
        ax[i].set_xlabel('Time (s)')

    ax[0].set_ylabel('Sample Overlap (Hz)')
    ax[1].set_ylabel('Go/NoGo Overlap (Hz)')

    plt.savefig(figname, dpi=300)
    plt.show()
```

``` ipython
def plot_m0_m1_phi(rates, idx, figname='fig.svg'):
    """Plot bump statistics decoded from ``rates`` reordered by ``idx``.

    Panels: mean activity (m0), bump amplitude (m1) and bump centre (phi, in
    degrees, axis fixed to 0-360). In each panel rows 0-1 of the decoded
    array (presumably trials) are drawn solid and the remaining rows dashed.
    Reads the globals ``width``/``height``; saves the figure to ``figname``
    and shows it.
    """
    m0, m1, phi = decode_bump(rates[..., idx], axis=-1)
    fig, ax = plt.subplots(1, 3, figsize=[3*width, height])

    time = np.linspace(0, 8, m0.T.shape[0])

    # (values, y label) per panel. The former F_0 / F_1 labels were
    # immediately overwritten (and used invalid '\m' escapes), so only the
    # labels that were actually displayed are kept.
    panels = [
        (m0, 'Activity (Hz)'),
        (m1, 'Bump Amplitude (Hz)'),
        (phi * 180 / np.pi, 'Bump Center (°)'),
    ]
    for axis, (values, ylabel) in zip(ax, panels):
        axis.plot(time, values[:2].T)
        axis.plot(time, values[2:].T, '--')
        axis.set_ylabel(ylabel)
        axis.set_xlabel('Time (s)')
        add_vlines(axis)

    # The bump centre is an angle: show exactly one full turn.
    ax[2].set_ylim([0, 360])
    ax[2].set_yticks([0, 90, 180, 270, 360])

    plt.savefig(figname, dpi=300)
    plt.show()
```

``` ipython
from matplotlib.patches import Circle

def plot_fix_points(rates, ax, title=''):
    """Draw the decoded bump phase of ``rates[:, -1]`` as points on the unit circle.

    Phases come from ``decode_bump`` (presumably one per trial at the last
    time step). White reference markers sit at 90 degrees (Go) and 270
    degrees (NoGo). The axes are made equal-aspect, titled, and hidden.
    """
    _, _, phi = decode_bump(rates[:, -1], axis=-1)

    ax.plot(np.cos(phi), np.sin(phi), 'o', ms=15)

    # Go (pi/2) then NoGo (3*pi/2) reference locations on the ring.
    for ref_angle in (np.pi / 2., 3 * np.pi / 2.):
        ax.plot(np.cos(ref_angle), np.sin(ref_angle), 'o', ms=15, color='w', markeredgecolor='k')

    ax.add_patch(Circle((0., 0.), 1, fill=False, edgecolor='k'))

    ax.set_aspect('equal')  # keep the circle circular
    ax.set_title(title)
    ax.axis('off')
```

``` ipython
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# Define custom colormap with red at the center
# Custom colormap: blue at 0, red at the centre (0.5), yellow at 1.
# Each entry is (position, value below, value above) per colour channel.
cdict = {
    'red':   [(0.0, 0.0, 0.0),
              (0.5, 1.0, 1.0),
              (1.0, 1.0, 1.0)],
    'green': [(0.0, 0.0, 0.0),
              (0.5, 0.0, 0.0),
              (1.0, 1.0, 1.0)],
    'blue':  [(0.0, 1.0, 1.0),
              (0.5, 0.0, 0.0),
              (1.0, 0.0, 0.0)]
}

# Used later for the angle matrix plot.
custom_cmap = LinearSegmentedColormap('RedCenterMap', cdict)

# Plot to visualize the colormap
gradient = np.linspace(0, 1, 256)
gradient = np.vstack((gradient, gradient))

fig, ax = plt.subplots(figsize=(6, 1))
ax.imshow(gradient, aspect='auto', cmap=custom_cmap)
ax.set_axis_off()
plt.show()
```

``` ipython
def plot_overlap(readout, labels=['pair', 'unpair'], figname='fig.svg'):
    """Plot each readout channel over time: first half of the trials solid, second half dashed.

    For channel 0 the second-half traces are sign-flipped; other channels are
    drawn unchanged. Reads the globals ``width``/``height``; saves the figure
    to ``figname`` and shows it.
    """
    fig, ax = plt.subplots(1, 2, figsize=[2*width, height])

    time = np.linspace(0, 8, readout.shape[1])
    half = readout.shape[0] // 2
    first_half, second_half = readout[:half], readout[half:]

    for ch in range(readout.shape[-1]):
        panel = ax[ch]
        panel.plot(time, first_half[:, :, ch].T, ls='-', label=labels[0])

        dashed = -second_half[:, :, ch].T if ch == 0 else second_half[:, :, ch].T
        panel.plot(time, dashed, ls='--', label=labels[1])

        add_vlines(panel)
        panel.set_xlabel('Time (s)')

    ax[0].set_ylabel('Sample Overlap (Hz)')
    ax[1].set_ylabel('Go/NoGo Overlap (Hz)')

    plt.savefig(figname, dpi=300)
    plt.show()
```

Model
=====

``` ipython
# NOTE(review): B0 is assigned in the configuration cell below; unless an earlier
# import provides it, running the notebook top to bottom raises NameError here.
print(B0)
```

``` ipython
# Experiment configuration: repository, YAML config, device and seed.
REPO_ROOT = "/home/leon/models/NeuroFlame"
conf_name = "train_dual.yml"
DEVICE = 'cuda:1'

# A random seed is drawn here but overridden by the hard-coded seed below.
seed = np.random.randint(0, 1e6)

# seed = 971646 # good
# : 104378
# seed = 330502
# seed= 849639

# #+RESULTS:
#  387828
# seed = 305810
# seed = 312784
# seed = 763019
# seed = 713495
# 544891
# : 413416
# : 744944
# seed= 151689
# : 2261
# seed = 295741 # not bad
# seed= 404520
seed= 332246 # china

print(seed)
# Input amplitudes written into model.I0 by the cells below.
A0 = 1.0 # sample/dist
B0 = 2.0 # cue
C0 = 0.0 # DRT rwd
```

``` ipython
# Build the network from the YAML config; N_BATCH is reassigned before each run below.
model = Network(conf_name, REPO_ROOT, VERBOSE=0, DEVICE=DEVICE, SEED=seed, N_BATCH=1)
```

``` ipython
```

Sample Classification
=====================

Training
--------

### Parameters

``` ipython
# Sample-classification phase: train J_STP, keep lr_kappa fixed.
model.J_STP.requires_grad = True
model.low_rank.lr_kappa.requires_grad = False

# Freeze the linear readout when the model has one.
if model.LR_READOUT:
    for param in model.low_rank.linear.parameters():
        param.requires_grad = False
    model.low_rank.linear.bias.requires_grad = False
```

``` ipython
# List the parameters that the optimizer will actually update.
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.shape)
```

The loss is evaluated on the time steps from sample odor offset to the end of the trial.

``` ipython
# Readout time bins: one every N_WINDOW steps after the steady-state period.
steps = np.arange(0, model.N_STEPS - model.N_STEADY, model.N_WINDOW)

# Loss window: from sample offset to the end of the trial.
mask = (steps >= (model.N_STIM_OFF[0].cpu().numpy() - model.N_STEADY)) & (steps <= (model.N_STEPS - model.N_STEADY))
rwd_idx = np.where(mask)[0]
print('rwd', rwd_idx)

# Evaluation-window length (later cells size the label tensors with it).
model.lr_eval_win = rwd_idx.shape[0]

# Bins during sample presentation.
stim_mask = (steps >= (model.N_STIM_ON[0].cpu().numpy() - model.N_STEADY)) & (steps < (model.N_STIM_OFF[0].cpu().numpy() - model.N_STEADY))

# Remaining bins (before sample onset) are pulled to zero by DualLoss's baseline term.
zero_idx = np.where(~mask & ~stim_mask )[0]
print('zero', zero_idx)
```

### Inputs and Labels

``` ipython
# Presumably the number of trials init_ff_input generates per condition.
model.N_BATCH = 512

# Condition A: sample input +A0, all other stimuli off.
model.I0[0] = A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0
model.I0[4] = 0

A = model.init_ff_input()

# Condition B: sample input -A0, all other stimuli off.
model.I0[0] = -A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0
model.I0[4] = 0

B = model.init_ff_input()

# A trials first, then B trials (the label cell relies on this order).
ff_input = torch.cat((A, B))
print(ff_input.shape)
```

``` ipython
# Label 1 for every loss bin of A trials, 0 for B trials.
labels_A = torch.ones((model.N_BATCH, rwd_idx.shape[0]))
labels_B = torch.zeros((model.N_BATCH, rwd_idx.shape[0]))
labels = torch.cat((labels_A, labels_B))

print('labels', labels.shape)
```

### Run

``` ipython
# Stratified 80/20 train/validation split.
batch_size = 32
train_loader, val_loader = split_data(ff_input, labels, train_perc=0.8, batch_size=batch_size)
```

``` ipython
# Readout channel 0 scored on rwd_idx bins; imbalance=1 selects the plain hinge relu(thresh - overlap).
criterion = DualLoss(alpha=1.0, thresh=4.0, rwd_idx=rwd_idx, zero_idx=zero_idx, imbalance=1, read_idx=[0])
learning_rate = 0.1
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
```

``` ipython
print('Sample Classification')
num_epochs = 15
start = perf_counter()
# zero_grad=1: column 1 of low_rank.U/V receives no gradient during this phase.
loss, val_loss = optimization(model, train_loader, val_loader, criterion, optimizer, num_epochs, zero_grad=1)
end = perf_counter()
print("Elapsed (with compilation) = %dh %dm %ds" % convert_seconds(end - start))
```

``` ipython
# Checkpoint after sample classification (reloaded at the start of the DPA phase).
torch.save(model.state_dict(), 'models/dual/dpa_naive_%d.pth' % seed)
```

Testing
-------

``` ipython
# Switch to evaluation mode for testing.
model.eval()
```

``` ipython
model.N_BATCH = 10

# Test trials use sample amplitude +/-2 (training used +/-A0); I0[3] and I0[4] are not reset here.
model.I0[0] = 2
model.I0[1] = 0
model.I0[2] = 0

A = model.init_ff_input()

model.I0[0] = -2
model.I0[1] = 0
model.I0[2] = 0

B = model.init_ff_input()

ff_input = torch.cat((A, B))
print('ff_input', ff_input.shape)
```

``` ipython
# Simulate the test trials; the forward pass also sets model.readout (read below).
rates = model.forward(ff_input=ff_input).cpu().detach().numpy()
print('rates', rates.shape)
```

``` ipython
# Readout overlaps: first half of the trials are A, second half B.
readout = model.readout.cpu().detach().numpy()
print('readout', readout.shape)
plot_overlap(readout, labels=['A', 'B'])
```

``` ipython
# Order neurons by their angle in the low-rank plane, then plot the raster.
idx = get_idx(model, 1)
plot_rates_selec(rates, idx)
```

``` ipython
# Bump statistics (m0, m1, phi) with the same neuron ordering.
idx = get_idx(model, 1)
plot_m0_m1_phi(rates, idx)
```

DPA
===

``` ipython
# Start the DPA phase from the sample-classification checkpoint.
model_state_dict = torch.load('models/dual/dpa_naive_%d.pth' % seed)
model.load_state_dict(model_state_dict)
```

Training
--------

### Parameters

``` ipython
# DPA phase: keep J_STP and lr_kappa fixed.
model.J_STP.requires_grad = False
model.low_rank.lr_kappa.requires_grad = False
```

Here the loss is evaluated only from test odor onset to the end of the trial.

``` ipython
# Readout time bins after the steady-state period.
steps = np.arange(0, model.N_STEPS - model.N_STEADY, model.N_WINDOW)
# DPA loss window: from test onset (stimulus 4) to the end of the trial.
mask = (steps >= (model.N_STIM_ON[4].cpu().numpy() - model.N_STEADY)) & (steps <= (model.N_STEPS - model.N_STEADY))
rwd_idx = np.where(mask)[0]
print('rwd', rwd_idx)

# mask from sample onset (stimulus 0) to onset of the last stimulus (presumably the test)
cue_mask = (steps >= (model.N_STIM_ON[0].cpu().numpy() - model.N_STEADY)) & (steps < (model.N_STIM_ON[-1].cpu().numpy() - model.N_STEADY))
cue_idx = np.where(cue_mask)[0]
# NOTE(review): this override disables the DRT term in this phase, so cue_mask is unused.
cue_idx = []
print('cue', cue_idx)

if len(cue_idx) !=0:
    model.lr_eval_win = np.max((rwd_idx.shape[0], cue_idx.shape[0]))
else:
    model.lr_eval_win = rwd_idx.shape[0]

# Sample-presentation bins.
stim_mask = (steps >= (model.N_STIM_ON[0].cpu().numpy() - model.N_STEADY)) & (steps < (model.N_STIM_OFF[0].cpu().numpy() - model.N_STEADY))

# Baseline bins: outside both the test window and the sample presentation.
mask_zero = ~mask  & ~stim_mask
zero_idx = np.where(mask_zero)[0]
print('zero', zero_idx)
```

### Inputs and Labels

``` ipython
# Four sample/test combinations, N_BATCH trials each (assumed per init_ff_input call).
model.N_BATCH = 256

# A-C: sample +A0, test +A0 (pair).
model.I0[0] = A0 # sample
model.I0[1] = 0 # distractor
model.I0[2] = 0 # cue
model.I0[3] = 0 # drt rwd
model.I0[4] = A0 # test

AC_pair = model.init_ff_input()

# A-D: sample +A0, test -A0 (unpair, despite the variable name).
model.I0[0] = A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0
model.I0[4] = -A0

AD_pair = model.init_ff_input()

# B-C: sample -A0, test +A0 (unpair, despite the variable name).
model.I0[0] = -A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0
model.I0[4] = A0

BC_pair = model.init_ff_input()

# B-D: sample -A0, test -A0 (pair).
model.I0[0] = -A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0
model.I0[4] = -A0

BD_pair = model.init_ff_input()

# Pair trials first, unpair trials second (the label cell relies on this order).
ff_input = torch.cat((AC_pair, BD_pair, AD_pair, BC_pair))
print('ff_input', ff_input.shape)
```

``` ipython
# Pair trials -> 1, unpair trials -> 0, for every bin of the evaluation window.
labels_pair = torch.ones((2 * model.N_BATCH, model.lr_eval_win))
labels_unpair = torch.zeros((2 * model.N_BATCH, model.lr_eval_win))

labels = torch.cat((labels_pair, labels_unpair))

# With a DRT term, targets become (trials, 2, bins) with identical DPA/DRT labels,
# and J_STP is trained as well.
if len(cue_idx)!=0:
    labels =  labels.repeat((2, 1, 1))
    labels = torch.transpose(labels, 0, 1)
    model.J_STP.requires_grad = True

print('labels', labels.shape)
```

### Run

``` ipython
# Stratified 80/20 train/validation split.
batch_size = 32
train_loader, val_loader = split_data(ff_input, labels, train_perc=0.8, batch_size=batch_size)
```

``` ipython
# DPA only: channel 1 scored on rwd_idx. imbalance=0 pushes pair trials to a mean readout
# >= thresh and penalises only positive mean readout on unpair trials.
if len(cue_idx) == 0:
    criterion = DualLoss(alpha=1.0, thresh=4.0, rwd_idx=rwd_idx, zero_idx=zero_idx, imbalance=0, read_idx=[1])
else:
    criterion = DualLoss(alpha=1.0, thresh=4.0, rwd_idx=rwd_idx, zero_idx=zero_idx, cue_idx=cue_idx, imbalance=[0.0, 1.0], read_idx=[1, 0])

learning_rate = 0.1
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
```

``` ipython
print('training DPA')
num_epochs = 30
start = perf_counter()
# Without a DRT term, column 0 of low_rank.U/V receives no gradient (zero_grad=0).
if len(cue_idx) == 0:
    loss, val_loss = optimization(model, train_loader, val_loader, criterion, optimizer, num_epochs, zero_grad=0)
else:
    loss, val_loss = optimization(model, train_loader, val_loader, criterion, optimizer, num_epochs, zero_grad=None)
end = perf_counter()
print("Elapsed (with compilation) = %dh %dm %ds" % convert_seconds(end - start))
# Checkpoint after DPA training.
torch.save(model.state_dict(), 'models/dual/dpa_%d.pth' % seed)
```

``` ipython
```

Testing
-------

``` ipython
# Reload the DPA checkpoint for testing.
model_state_dict = torch.load('models/dual/dpa_%d.pth' % seed)
model.load_state_dict(model_state_dict)
```

``` ipython
# Switch to evaluation mode for testing.
model.eval()
```

``` ipython
# One trial per sample/test combination.
model.N_BATCH = 1

# A-C: sample +A0, test +A0 (pair).
model.I0[0] = A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0
model.I0[4] = A0

AC_pair = model.init_ff_input()

# A-D: sample +A0, test -A0 (unpair).
model.I0[0] = A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0
model.I0[4] = -A0

AD_pair = model.init_ff_input()

# B-C: sample -A0, test +A0 (unpair).
model.I0[0] = -A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0
model.I0[4] = A0

BC_pair = model.init_ff_input()

# B-D: sample -A0, test -A0 (pair).
model.I0[0] = -A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0
model.I0[4] = -A0

BD_pair = model.init_ff_input()

# Pair trials first, then unpair trials.
ff_input = torch.cat((AC_pair, BD_pair, AD_pair, BC_pair))
print('ff_input', ff_input.shape)
```

``` ipython
# Pair -> 1, unpair -> 0 (two columns per trial).
labels_pair = torch.ones((2 * model.N_BATCH, 2))
labels_unpair = torch.zeros((2 * model.N_BATCH, 2))

labels = torch.cat((labels_pair, labels_unpair))
print('labels', labels.shape)
```

``` ipython
# Simulate the test trials; keep a copy under rates_dpa (presumably for later comparison).
rates = model.forward(ff_input=ff_input).detach().cpu().numpy()
rates_dpa = rates
print(rates.shape)
```

``` ipython
# Readout overlaps for pair (first half) vs unpair (second half) trials, saved to figures/dual.
plot_overlap(model.readout.cpu().detach().numpy(), labels=['pair', 'unpair'], figname='./figures/dual/dpa_overlap_%d.svg' % seed)
```

``` ipython
# Bump statistics with neurons ordered by their low-rank angle.
idx = get_idx(model, 1)
plot_m0_m1_phi(rates, idx)
```

``` ipython
```

Go/NoGo
=======

``` ipython
# Start the Go/NoGo phase from the DPA checkpoint.
model_state_dict = torch.load('models/dual/dpa_%d.pth' % seed)
model.load_state_dict(model_state_dict)
```

Training
--------

``` ipython
# Go/NoGo phase: keep J_STP and lr_kappa fixed.
model.J_STP.requires_grad = False
model.low_rank.lr_kappa.requires_grad = False
```

``` ipython
# List the parameters that the optimizer will actually update.
for name, param in model.named_parameters():
      if param.requires_grad:
            print(name, param.shape)
```

``` ipython
# Readout time bins after the steady-state period.
steps = np.arange(0, model.N_STEPS - model.N_STEADY, model.N_WINDOW)

# mask for lick/nolick  from cue to test
rwd_mask = (steps >= (model.N_STIM_ON[2].cpu().numpy() - model.N_STEADY)) & (steps < (model.N_STIM_ON[4].cpu().numpy() - model.N_STEADY))
rwd_idx = np.where(rwd_mask)[0]
print('rwd', rwd_idx)

# mask for Go/NoGo memory from dist to cue
cue_mask = (steps >= (model.N_STIM_ON[1].cpu().numpy() - model.N_STEADY)) & (steps < (model.N_STIM_ON[2].cpu().numpy() - model.N_STEADY))
cue_idx = np.where(cue_mask)[0]
# cue_idx = []
print('cue', cue_idx)

# Baseline bins: everything before distractor onset.
mask_zero = (steps < (model.N_STIM_ON[1].cpu().numpy() - model.N_STEADY))
zero_idx = np.where(mask_zero)[0]
print('zero', zero_idx)

# Labels must cover the longer of the two scored windows.
if len(cue_idx)!=0:
    model.lr_eval_win = np.max( (rwd_idx.shape[0], cue_idx.shape[0]))
else:
    model.lr_eval_win = rwd_idx.shape[0]
```

``` ipython
model.N_BATCH = 512

# Go trials: no sample, distractor +A0, cue B0.
model.I0[0] = 0
model.I0[1] = A0
model.I0[2] = float(B0)
model.I0[3] = 0
model.I0[4] = 0

Go = model.init_ff_input()

# NoGo trials: no sample, distractor -A0, cue B0.
model.I0[0] = 0
model.I0[1] = -A0
model.I0[2] = float(B0)
model.I0[3] = 0
model.I0[4] = 0

NoGo = model.init_ff_input()

# Go trials first, then NoGo (the label cell relies on this order).
ff_input = torch.cat((Go, NoGo))
print(ff_input.shape)
```

``` ipython
# Go -> 1, NoGo -> 0 over the evaluation window.
labels_Go = torch.ones((model.N_BATCH, model.lr_eval_win))
labels_NoGo = torch.zeros((model.N_BATCH, model.lr_eval_win))
labels = torch.cat((labels_Go, labels_NoGo))
print(labels.shape)
# print(labels)
# With a DRT term, targets become (trials, 2, bins): the same labels for both loss terms.
if len(cue_idx)!=0:
    labels =  labels.repeat((2, 1, 1))
    labels = torch.transpose(labels, 0, 1)
print('labels', labels.shape)
```

### Run

``` ipython
# Stratified 80/20 train/validation split.
batch_size = 32
train_loader, val_loader = split_data(ff_input, labels, train_perc=0.8, batch_size=batch_size)
```

``` ipython
# Both terms read channel 1: rwd_idx bins with imbalance 0.0, cue_idx bins with imbalance 1.0.
criterion = DualLoss(alpha=1.0, thresh=4.0, rwd_idx=rwd_idx, zero_idx=zero_idx, cue_idx=cue_idx, imbalance=[0.0, 1.0], read_idx=[1, 1])
learning_rate = 0.1
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
```

``` ipython
print('training DRT')
num_epochs = 30
start = perf_counter()
# zero_grad=0: column 0 of low_rank.U/V receives no gradient during this phase.
loss, val_loss = optimization(model, train_loader, val_loader, criterion, optimizer, num_epochs, zero_grad=0)
end = perf_counter()
print("Elapsed (with compilation) = %dh %dm %ds" % convert_seconds(end - start))
```

``` ipython
# Checkpoint after Go/NoGo training.
torch.save(model.state_dict(), 'models/dual/dual_naive_%d.pth' % seed)
```

Test
----

``` ipython
# Reload the Go/NoGo checkpoint for testing.
model_state_dict = torch.load('models/dual/dual_naive_%d.pth' % seed)
model.load_state_dict(model_state_dict)
```

``` ipython
# Switch to evaluation mode for testing.
model.eval()
```

``` example
Network(
  (low_rank): LowRankWeights()
  (dropout): Dropout(p=0.0, inplace=False)
)
Network(
  (low_rank): LowRankWeights()
  (dropout): Dropout(p=0.0, inplace=False)
)
```

``` ipython
model.N_BATCH = 1

# Go trial: no sample, distractor +A0, cue B0.
model.I0[0] = 0
model.I0[1] = A0
model.I0[2] = float(B0)
model.I0[3] = 0.0
model.I0[4] = 0.0

A = model.init_ff_input()

# NoGo trial: no sample, distractor -A0, cue B0.
model.I0[0] = 0 # sample (off)
model.I0[1] = -A0 # distractor: NoGo
model.I0[2] = float(B0) # cue
model.I0[3] = 0.0
model.I0[4] = 0.0

B = model.init_ff_input()

ff_input = torch.cat((A, B))
print('ff_input', ff_input.shape)
```

:

``` example
ff_input torch.Size([2, 505, 1000])
```

``` ipython
# Simulate the Go and NoGo test trials.
rates = model.forward(ff_input=ff_input).cpu().detach().numpy()
print(rates.shape)
```

``` ipython
# Readout overlaps for the Go (first) vs NoGo (second) trial, saved to figures/dual.
plot_overlap(model.readout.cpu().detach().numpy(), labels=['Go', 'NoGo'], figname='./figures/dual/GoNoGo_overlaps_%d.svg' % seed)
```

``` ipython
```

Dual Naive
==========

Testing
-------

``` ipython
# Evaluate the naive dual-task network (Go/NoGo checkpoint).
model_state_dict = torch.load('models/dual/dual_naive_%d.pth' % seed)
model.load_state_dict(model_state_dict)
```

``` ipython
# Readout time bins after the steady-state period.
steps = np.arange(0, model.N_STEPS - model.N_STEADY, model.N_WINDOW)

# DPA window: from onset of the last stimulus (test) to the end of the trial.
mask_rwd = (steps >= (model.N_STIM_ON[-1].cpu().numpy() - model.N_STEADY))
rwd_idx = np.where(mask_rwd)[0]
print('rwd', rwd_idx)

# DRT window: from cue onset (stimulus 2) to offset of stimulus 3 (DRT reward).
mask_cue = (steps >= (model.N_STIM_ON[2].cpu().numpy() - model.N_STEADY)) & (steps <= (model.N_STIM_OFF[3].cpu().numpy() - model.N_STEADY))
cue_idx = np.where(mask_cue)[0]
print('cue', cue_idx)

# Go/NoGo memory window: from distractor offset (stimulus 1) to cue onset.
mask_GnG = (steps >= (model.N_STIM_OFF[1].cpu().numpy() - model.N_STEADY)) & (steps <= (model.N_STIM_ON[2].cpu().numpy() - model.N_STEADY))
GnG_idx = np.where(mask_GnG)[0]
print('GnG', GnG_idx)

# Every bin from sample onset on, so the zero bins below are the pre-sample baseline.
stim_mask = (steps >= (model.N_STIM_ON[0].cpu().numpy() - model.N_STEADY))

mask_zero = ~mask_rwd & ~mask_cue & ~stim_mask
zero_idx = np.where(mask_zero)[0]
print('zero', zero_idx)
```

``` ipython
# Low-rank vectors restricted to the first population slice: column 0 (AB) and column 1 (GnG).
U = model.low_rank.U.cpu().detach().numpy()[model.slices[0], 0]
V = model.low_rank.V.cpu().detach().numpy()[model.slices[0], 0]

odors = model.odors.cpu().numpy()

m = model.low_rank.U.cpu().detach().numpy()[model.slices[0], 1]
n = model.low_rank.V.cpu().detach().numpy()[model.slices[0], 1]

vectors = [U, V, m, n]
# NOTE: this (and `mask` below) overwrite the notebook-level `labels`/`mask` variables.
labels = ['$m_\\text{AB}$', '$n_\\text{AB}$', '$m_\\text{GnG}$', '$n_\\text{GnG}$']

import numpy as np
import matplotlib.pyplot as plt

# Pairwise angles in degrees (named cov_matrix, but it holds angles, not covariances).
num_vectors = len(vectors)
cov_matrix = np.zeros((num_vectors, num_vectors))

for i in range(num_vectors):
    for j in range(num_vectors):
        cov_matrix[i][j] = angle_AB(vectors[i], vectors[j])

# Mask the upper triangle, diagonal included.
mask = np.triu(np.ones_like(cov_matrix, dtype=bool))
masked_cov_matrix = np.ma.masked_array(cov_matrix, mask=mask)

plt.figure(figsize=(8, 6))

# Plot the masked angle matrix
img = plt.imshow(masked_cov_matrix, cmap=custom_cmap, interpolation='nearest', vmin=30, vmax=150)
cbar = plt.colorbar(label='Angle (°)')
# NOTE(review): ticks stop at 120 although vmax is 150 — confirm this is intended.
cbar.set_ticks([30, 90, 120])

# Set axis labels on top and left
# plt.gca().xaxis.tick_top()
plt.xticks(ticks=np.arange(num_vectors), labels=labels)
plt.yticks(ticks=np.arange(num_vectors), labels=labels)

# Invert y-axis
plt.gca().invert_yaxis()

# Annotate the lower triangle and the diagonal with the angle values.
for i in range(num_vectors):
    for j in range(i + 1):
        plt.text(j, i, f'{cov_matrix[i, j]:.0f}', ha='center', va='center', color='black')

plt.savefig('./figures/dual/cov_naive_%d.svg' % seed, dpi=300)
plt.show()
```

``` ipython
# Switch the model to evaluation mode before testing.
model.eval()
```

``` ipython
# Evaluation set: 3 task types (j) x 2 samples (i) x 2 tests (k) = 12
# conditions, each repeated N_BATCH times. Conditions are task-major
# (DPA, Go, NoGo), which the reshape(3, ...) in the overlap-plot cell relies on.
N_BATCH = 32
model.N_BATCH = N_BATCH

# Label window length: the longer of the reward and cue windows.
model.lr_eval_win = np.max( (rwd_idx.shape[0], cue_idx.shape[0]))

ff_input = []
# labels[0]: 1 on pair trials (sample == test), else 0.
# labels[1]: +1 on Go, -1 on NoGo, 0 on DPA-only trials.
labels = np.zeros((2, 12, model.N_BATCH, model.lr_eval_win))

l=0
for j in [0, 1, -1]:
    for i in [-1, 1]:
        for k in [-1, 1]:

            model.I0[0] = i # sample
            # NOTE(review): distractor amplitude is j*1.5 here but plain j in the
            # Training, Re-Testing and opto cells; the naive vs. expert
            # accuracies are compared later -- confirm this is intended.
            model.I0[1] = j*1.5 # distractor
            model.I0[4] = k # test

            if i==k: # Pair Trials
                labels[0, l] = np.ones((model.N_BATCH, model.lr_eval_win))

            if j==1: # Go
                model.I0[2] = float(B0) # cue
                model.I0[3] = float(C0) * model.IF_RL # rwd

                labels[1, l] = np.ones((model.N_BATCH, model.lr_eval_win))
            elif j==-1: # NoGo
                model.I0[2] = float(B0) # cue
                model.I0[3] = 0.0 # rwd
                labels[1, l] = -np.ones((model.N_BATCH, model.lr_eval_win))
            else: # DPA
                model.I0[2] = 0 # cue
                model.I0[3] = 0 # rwd

            l+=1

            ff_input.append(model.init_ff_input())

# Final labels shape: (12 * N_BATCH, 2, lr_eval_win).
labels = torch.tensor(labels, dtype=torch.float, device=DEVICE).reshape(2, 12 * model.N_BATCH, model.lr_eval_win).transpose(0,1)
# labels = torch.tensor(labels, dtype=torch.float, device=DEVICE).reshape(3, -1, model.lr_eval_win).transpose(0, 1)
ff_input = torch.vstack(ff_input)
print('ff_input', ff_input.shape, 'labels', labels.shape)
```

``` ipython
# Simulate the evaluation set; this also populates model.readout used below.
rates = model.forward(ff_input=ff_input).detach()
rates = rates.cpu().detach().numpy()
print(rates.shape)
```

``` ipython
def calculate_mean_accuracy_and_sem(accuracies):
    """Return the mean and the standard error of the mean of ``accuracies``.

    Args:
        accuracies: tensor (or array-like) of accuracies. Any shape is
            accepted; all elements are pooled.

    Returns:
        ``(mean_accuracy, sem)``. ``mean_accuracy`` is a 0-d torch tensor and
        ``sem`` is a NumPy float; both support ``.item()``. ``sem`` is the
        unbiased (n - 1) standard deviation divided by sqrt(n), where n is
        the total number of elements.
    """
    values = torch.as_tensor(accuracies)
    if not torch.is_floating_point(values):
        # mean()/std() are not defined for integer/bool tensors.
        values = values.float()

    mean_accuracy = values.mean()
    # std() defaults to the unbiased (Bessel-corrected) estimator, so the
    # deprecated ``unbiased=`` keyword is not needed.
    std_dev = values.std().item()
    # Use the total element count so the SEM matches the pooled mean/std;
    # len() would only count the first dimension of a multi-dim tensor.
    sem = std_dev / np.sqrt(values.numel())
    return mean_accuracy, sem
```

``` ipython
# Inspect the shape of the readout produced by the last forward pass.
readout = model.readout.cpu().detach().numpy()
print(readout.shape)
```

``` ipython
# Overlaps split by the pair label (y) and the Go/NoGo label (y1) at time bin 0; naive network.
plot_overlap_label(model.readout.cpu().detach().numpy(), y=labels[:,0,0].cpu().numpy(), y1=labels[:,1,0].cpu().numpy() , labels=['pair', 'unpair'], figname='./figures/dual/overlaps_task_%d.svg' %seed)
```

``` ipython
# Naive performance on the DPA and Go/NoGo (drt) tasks, with SEMs kept for
# the naive-vs-expert comparison plot.
criterion = DualPerf(alpha=1.0, thresh=4.0, cue_idx=cue_idx, rwd_idx=rwd_idx, zero_idx=zero_idx, imbalance=[0.0, 0.0], read_idx=[1, 1])
dpa_perf, drt_perf = criterion(model.readout, labels)
print(dpa_perf.mean(), drt_perf.mean())
dpa_mean, dpa_sem = calculate_mean_accuracy_and_sem(dpa_perf)
drt_mean, drt_sem = calculate_mean_accuracy_and_sem(drt_perf)
print('perf', dpa_mean, drt_mean)
```

``` ipython
# Readout overlaps of the naive network, averaged per task type.
# The reshape assumes trials are ordered task-major (DPA, Go, NoGo) as built in
# the evaluation cell, with 81 time bins and 2 readout channels.
# NOTE(review): the meaning of the /2 scaling is not shown here -- confirm.
readout = model.readout.cpu().detach().numpy().reshape(3, -1, 81, 2) / 2
print(readout.shape)

time = np.linspace(0, 8, readout.shape[-2])
fig, ax = plt.subplots(1, 2, figsize=[2*width, height])

color = ['r', 'b', 'g']
label = ['DPA', 'DualGo', 'DualNoGo']

for i in range(3):
    # Within a task the first 2*N_BATCH trials have sample -1 and the rest
    # sample +1, so this is (sample +1) - (sample -1) on readout channel 0.
    sample = (-readout[i, :2*N_BATCH, :, 0].T  + readout[i, 2*N_BATCH:, :, 0].T)
    # Sum of both sample halves on readout channel 1 (Go/NoGo axis).
    dist = (readout[i, :2*N_BATCH, :, 1].T  + readout[i, 2*N_BATCH:, :, 1].T)

    ax[0].plot(time, sample.mean(1), color=color[i])
    ax[1].plot(time, dist.mean(1), color=color[i], label=label[i])

add_vlines(ax[0])
add_vlines(ax[1])
ax[0].set_ylabel('Sample Overlap (Hz)')
ax[1].set_ylabel('Go/NoGo Overlap (Hz)')
ax[1].legend(frameon=False, fontsize=10)
ax[0].set_xlabel('Time (s)')
ax[1].set_xlabel('Time (s)')

plt.savefig('./figures/dual/dual_naive_%d_over.svg' % seed, dpi=300)
plt.show()
```

``` ipython
```

Training
--------

``` ipython
# Freeze J_STP and the low-rank kappa so dual-task training leaves them unchanged.
model.J_STP.requires_grad = False
model.low_rank.lr_kappa.requires_grad = False
```

``` ipython
# Training set: the same 12 conditions as evaluation, 86 trials each. Here the
# loop order is sample-major (i, j, k); order does not matter because
# split_data shuffles.
model.N_BATCH = 86

model.lr_eval_win = np.max( (rwd_idx.shape[0], cue_idx.shape[0]))

ff_input = []
# labels[0]: 1 on pair trials; labels[1]: +1 Go, -1 NoGo, 0 DPA.
labels = np.zeros((2, 12, model.N_BATCH, model.lr_eval_win))

l=0
for i in [-1, 1]:
    for j in [-1, 0, 1]:
        for k in [-1, 1]:

            model.I0[0] = i # sample
            model.I0[1] = j # distractor
            model.I0[4] = k # test

            if i==k: # Pair Trials
                labels[0, l] = np.ones((model.N_BATCH, model.lr_eval_win))

            if j==1: # Go
                model.I0[2] = float(B0) # cue
                model.I0[3] = float(C0) * model.IF_RL # rwd

                labels[1, l] = np.ones((model.N_BATCH, model.lr_eval_win))
            elif j==-1: # NoGo
                model.I0[2] = float(B0) # cue
                model.I0[3] = 0.0 # rwd
                labels[1, l] = -np.ones((model.N_BATCH, model.lr_eval_win))
            else: # DPA
                model.I0[2] = 0 # cue
                model.I0[3] = 0 # rwd

            l+=1

            ff_input.append(model.init_ff_input())

# Final labels shape: (12 * N_BATCH, 2, lr_eval_win).
labels = torch.tensor(labels, dtype=torch.float, device=DEVICE).reshape(2, -1, model.lr_eval_win).transpose(0, 1)
# labels = torch.tensor(labels, dtype=torch.float, device=DEVICE).reshape(3, -1, model.lr_eval_win).transpose(0, 1)
ff_input = torch.vstack(ff_input)
print('ff_input', ff_input.shape, 'labels', labels.shape)
```

``` ipython
# 80/20 train/validation split, stratified on labels[:, 0, 0] (the pair label)
# inside split_data, wrapped in DataLoaders.
batch_size = 32
train_loader, val_loader = split_data(ff_input, labels, train_perc=0.8, batch_size=batch_size)
```

``` ipython
# Training loss uses the same epoch windows as the DualPerf criterion. Adam is
# given all parameters; the frozen ones receive no gradient and are skipped.
criterion = DualLoss(alpha=1.0, thresh=4.0, cue_idx=cue_idx, rwd_idx=rwd_idx, zero_idx=zero_idx, imbalance=[0.0, 0.0], read_idx=[1, 1])
learning_rate = 0.1
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
```

``` ipython
# Train for num_epochs and report wall-clock time; `optimization` is defined
# elsewhere and returns training and validation loss histories.
print('training Dual')
num_epochs = 30
start = perf_counter()

loss, val_loss = optimization(model, train_loader, val_loader, criterion, optimizer, num_epochs, zero_grad=0)
end = perf_counter()
print("Elapsed (with compilation) = %dh %dm %ds" % convert_seconds(end - start))
```

``` ipython
# Persist the trained dual-task weights for this seed.
torch.save(model.state_dict(), 'models/dual/dual_train_%d.pth' % seed)
```

``` ipython
```

Re-Testing
----------

``` ipython
# Reload the trained dual-task checkpoint for re-testing.
model_state_dict = torch.load('models/dual/dual_train_%d.pth' % seed)
model.load_state_dict(model_state_dict)
```

``` ipython
# Pairwise angles between the two low-rank connectivity pairs (columns 0 and 1
# of U / V), restricted to the population slice model.slices[0], for the
# trained network.
U = model.low_rank.U.cpu().detach().numpy()[model.slices[0], 0]
V = model.low_rank.V.cpu().detach().numpy()[model.slices[0], 0]

odors = model.odors.cpu().numpy()

m = model.low_rank.U.cpu().detach().numpy()[model.slices[0], 1]
n = model.low_rank.V.cpu().detach().numpy()[model.slices[0], 1]

vectors = [U, V, m, n]
# Tick labels live in `vec_labels`, not `labels`. `labels` holds the trial-label
# tensor used by the evaluation/performance cells, and overwriting it here
# breaks those cells whenever this one is (re-)run after them.
vec_labels = ['$m_\\text{AB}$', '$n_\\text{AB}$', '$m_\\text{GnG}$', '$n_\\text{GnG}$']

import numpy as np
import matplotlib.pyplot as plt

# Angle matrix between every pair of vectors. The name `cov_matrix` is kept
# for compatibility, but the entries are angles from angle_AB, not covariances.
num_vectors = len(vectors)
cov_matrix = np.zeros((num_vectors, num_vectors))

for i, vec_i in enumerate(vectors):
    for j, vec_j in enumerate(vectors):
        cov_matrix[i, j] = angle_AB(vec_i, vec_j)

# Mask the upper triangle (the matrix is symmetric; only show each pair once).
mask = np.triu(np.ones_like(cov_matrix, dtype=bool))
masked_cov_matrix = np.ma.masked_array(cov_matrix, mask=mask)

plt.figure(figsize=(8, 6))

# Colour range 30-150 is treated as degrees (see the colorbar label).
# NOTE(review): ticks stop at 120 although vmax is 150 -- confirm intended.
img = plt.imshow(masked_cov_matrix, cmap=custom_cmap, interpolation='nearest', vmin=30, vmax=150)
cbar = plt.colorbar(label='Angle (°)')
cbar.set_ticks([30, 90, 120])

plt.xticks(ticks=np.arange(num_vectors), labels=vec_labels)
plt.yticks(ticks=np.arange(num_vectors), labels=vec_labels)

# Invert y-axis so the lower triangle reads top-left to bottom-right.
plt.gca().invert_yaxis()

# Annotate each visible cell (diagonal included) with its rounded angle.
for i in range(num_vectors):
    for j in range(i + 1):
        plt.text(j, i, f'{cov_matrix[i, j]:.0f}', ha='center', va='center', color='black')
plt.savefig('./figures/dual/cov_train_%d.svg' % seed, dpi=300)
plt.show()
```

``` ipython
# Switch the trained model to evaluation mode before re-testing.
model.eval()
```

``` ipython
# Re-testing set: 3 task types (j) x 2 samples (i) x 2 tests (k) = 12
# conditions x N_BATCH trials, task-major (DPA, Go, NoGo) to match the
# reshape(3, ...) used by the overlap-plot cells.
N_BATCH = 32
model.N_BATCH = N_BATCH

# Label window length: the longer of the reward and cue windows.
model.lr_eval_win = np.max( (rwd_idx.shape[0], cue_idx.shape[0]))

ff_input = []
# labels[0]: 1 on pair trials; labels[1]: +1 Go, -1 NoGo, 0 DPA.
labels = np.zeros((2, 12, model.N_BATCH, model.lr_eval_win))

l=0
for j in [0, 1, -1]:
    for i in [-1, 1]:
        for k in [-1, 1]:

            model.I0[0] = i # sample
            model.I0[1] = j # distractor
            model.I0[4] = k # test

            if i==k: # Pair Trials
                labels[0, l] = np.ones((model.N_BATCH, model.lr_eval_win))

            if j==1: # Go
                model.I0[2] = float(B0) # cue
                model.I0[3] = float(C0) * model.IF_RL # rwd

                labels[1, l] = np.ones((model.N_BATCH, model.lr_eval_win))
            elif j==-1: # NoGo
                model.I0[2] = float(B0) # cue
                model.I0[3] = 0.0 # rwd
                labels[1, l] = -np.ones((model.N_BATCH, model.lr_eval_win))
            else: # DPA
                model.I0[2] = 0 # cue
                model.I0[3] = 0 # rwd

            l+=1

            ff_input.append(model.init_ff_input())

# Final labels shape: (12 * N_BATCH, 2, lr_eval_win).
labels = torch.tensor(labels, dtype=torch.float, device=DEVICE).reshape(2, -1, model.lr_eval_win).transpose(0, 1)
# labels = torch.tensor(labels, dtype=torch.float, device=DEVICE).reshape(3, -1, model.lr_eval_win).transpose(0, 1)
ff_input = torch.vstack(ff_input)
print('ff_input', ff_input.shape, 'labels', labels.shape)
```

``` ipython
# Simulate the re-testing set; this also populates model.readout used below.
rates = model.forward(ff_input=ff_input).detach()
rates = rates.cpu().detach().numpy()
print(rates.shape)
```

``` ipython
# Overlaps split by the pair label (y) and the Go/NoGo label (y1) at time bin 0; trained network.
plot_overlap_label(model.readout.cpu().detach().numpy(), y=labels[:,0,0].cpu().numpy(), y1=labels[:,1,0].cpu().numpy() , labels=['pair', 'unpair'], figname='./figures/dual/overlaps_task_train_%d.svg' %seed)
```

``` ipython
# Trained-network overlaps. Saved under the expert filename so this cell does
# not overwrite the naive figure 'overlaps_task_%d.svg' from the Dual Naive section.
plot_overlap_label(model.readout.cpu().detach().numpy(), y=labels[:,0,0].cpu().numpy(), y1=labels[:,1,0].cpu().numpy() , labels=['pair', 'unpair'], figname='./figures/dual/overlaps_task_train_%d.svg' %seed)
```

``` ipython
# Expert (trained) performance on the DPA and Go/NoGo tasks, with SEMs.
criterion = DualPerf(alpha=1.0, thresh=4.0, cue_idx=cue_idx, rwd_idx=rwd_idx, zero_idx=zero_idx, imbalance=[0.0, 0.0], read_idx=[1, 1])
dpa_perf2, drt_perf2 = criterion(model.readout, labels)
dpa_mean2, dpa_sem2 = calculate_mean_accuracy_and_sem(dpa_perf2)
drt_mean2, drt_sem2 = calculate_mean_accuracy_and_sem(drt_perf2)
print('perf', dpa_mean2, drt_mean2)
```

``` ipython
# Naive vs. expert accuracy (mean +/- SEM): DPA on the left, Go/NoGo on the
# right. The dashed line marks chance level (0.5).
fig, ax = plt.subplots(1, 2, figsize=[1.5*width, height], sharex=True)

stage_names = ['Naive', 'Expert']
panels = [
    (ax[0], 'DPA Accuracy', [(dpa_mean, dpa_sem), (dpa_mean2, dpa_sem2)]),
    (ax[1], 'Go/NoGo Accuracy', [(drt_mean, drt_sem), (drt_mean2, drt_sem2)]),
]

for panel_ax, panel_ylabel, panel_stats in panels:
    for xpos, ((acc_mean, acc_sem), stage) in enumerate(zip(panel_stats, stage_names)):
        panel_ax.errorbar(xpos, acc_mean.item(), yerr=acc_sem.item(), fmt='o', label=stage,
                          color='k', ecolor='k', elinewidth=3, capsize=5)

    panel_ax.set_xlim(-1, 2)
    panel_ax.set_ylim(0.4, 1.1)

    panel_ax.set_ylabel(panel_ylabel)
    panel_ax.set_xticks([0, 1], stage_names)
    panel_ax.axhline(y=0.5, color='k', linestyle='--')

plt.savefig('./figures/dual/dual_perf_%d.svg' % seed, dpi=300)

plt.show()
```

``` ipython
# Readout overlaps of the trained network per task type, plus a phase-plane
# view (sample overlap vs. Go/NoGo overlap) in the third panel.
# Reshape assumes task-major trial order (DPA, Go, NoGo), 81 time bins, 2 readouts.
readout = model.readout.cpu().detach().numpy().reshape(3, -1, 81, 2) / 2
print(readout.shape)

# N_BATCH now counts trials per sample half (2 tests x N_BATCH).
# NOTE(review): this mutates the global; re-running the cell doubles it again,
# and the band-plot cell further down relies on it already being doubled.
N_BATCH *= 2

time = np.linspace(0, 8, readout.shape[-2])
fig, ax = plt.subplots(1, 3, figsize=[3*width, height])

color = ['r', 'b', 'g']
label = ['DPA', 'DualGo', 'DualNoGo']

for i in range(3):
    # (sample +1) - (sample -1) on readout 0; sum of both halves on readout 1.
    sample = (-readout[i, :N_BATCH, :, 0].T  + readout[i, N_BATCH:, :, 0].T)
    dist = (readout[i, :N_BATCH, :, 1].T  + readout[i, N_BATCH:, :, 1].T)

    ax[0].plot(time, sample.mean(1), color=color[i])
    ax[1].plot(time, dist.mean(1), color=color[i], label=label[i])
    # Phase plane for the first trial pair only.
    ax[2].plot(sample[:, 0] , dist[:,0], color=color[i], label=label[i])

add_vlines(ax[0])
add_vlines(ax[1])
ax[0].set_ylabel('Sample Overlap (Hz)')
ax[1].set_ylabel('Go/NoGo Overlap (Hz)')
ax[1].legend(frameon=False, fontsize=10)
ax[0].set_xlabel('Time (s)')
ax[1].set_xlabel('Time (s)')
plt.savefig('./figures/dual/dual_train_%d_over.svg' % seed, dpi=300)
plt.show()
```

``` ipython
# Distributions of the low-rank vectors: columns 0 and 1 of U (left) and V (right).
U = model.low_rank.U.cpu().detach().numpy()
V = model.low_rank.V.cpu().detach().numpy()
print(U.shape)
fig, ax = plt.subplots(1, 2, figsize=[2*width, height])
ax[0].hist(U[:, 0], histtype='step', bins='auto')
ax[0].hist(U[:, 1], histtype='step', bins='auto')
ax[1].hist(V[:, 0], histtype='step', bins='auto')
ax[1].hist(V[:, 1], histtype='step', bins='auto')
plt.show()
```

``` ipython
# Per-neuron scatter of U[:, 0] against V[:, 0].
plt.scatter(U[:, 0], V[:, 0])
plt.show()
```

``` ipython
# Trained-network overlaps per task type with a shaded spread band.
readout = model.readout.cpu().detach().numpy().reshape(3, -1, 81, 2) / 2
print(readout.shape)

# NOTE(review): relies on N_BATCH already having been doubled by the previous
# overlap-plot cell (trials per sample half).
# N_BATCH *= 2

time = np.linspace(0, 8, readout.shape[-2])
fig, ax = plt.subplots(1, 2, figsize=[2*width, height])

color = ['r', 'b', 'g']
label = ['DPA', 'DualGo', 'DualNoGo']

for i in range(3):
    # Calculate sample and dist
    sample = (-readout[i, :N_BATCH, :, 0].T + readout[i, N_BATCH:, :, 0].T)
    dist = (readout[i, :N_BATCH, :, 1].T + readout[i, N_BATCH:, :, 1].T)

    # Mean and spread across trials. The division by sqrt(n) is commented out,
    # so the *_sem variables actually hold the standard deviation (ddof=1):
    # the shaded bands are +/-1 SD, not SEM.
    sample_mean = sample.mean(axis=1)
    sample_sem = sample.std(axis=1, ddof=1) # / np.sqrt(sample.shape[1])

    dist_mean = dist.mean(axis=1)
    dist_sem = dist.std(axis=1, ddof=1) # / np.sqrt(dist.shape[1])

    # Plot mean +/- 1 SD for sample
    ax[0].plot(time, sample_mean, color=color[i])
    ax[0].fill_between(time, sample_mean - sample_sem, sample_mean + sample_sem, color=color[i], alpha=0.3)

    # Plot mean +/- 1 SD for dist
    ax[1].plot(time, dist_mean, color=color[i], label=label[i])
    ax[1].fill_between(time, dist_mean - dist_sem, dist_mean + dist_sem, color=color[i], alpha=0.3)

# Add vertical lines and labels
add_vlines(ax[0])
add_vlines(ax[1])
ax[0].set_ylabel('Sample Overlap (Hz)')
ax[1].set_ylabel('Go/NoGo Overlap (Hz)')
ax[1].legend(frameon=False, fontsize=10)
ax[0].set_xlabel('Time (s)')
ax[1].set_xlabel('Time (s)')

# Show the plot (saving is disabled)
#plt.savefig('./figures/dual/dual_train_%d_over.svg' % seed, dpi=300)
plt.show()
```

``` ipython
```

Fixed Points
------------

``` ipython
def get_fix_points(model, task, seed):
    """Load a saved checkpoint and return its rates on the fixed-point inputs.

    Loads ``models/dual/<task>_<seed>.pth`` into ``model`` (replacing its
    weights), runs it on the inputs from ``get_input`` and keeps only the
    neurons selected by ``get_idx(model, rank=1)``.

    Returns:
        NumPy array of rates indexed on the last axis by the selected neurons.
    """
    checkpoint_path = 'models/dual/%s_%d.pth' % (task, seed)
    model.load_state_dict(torch.load(checkpoint_path))

    probe_input = get_input(model)
    traces = model.forward(ff_input=probe_input).cpu().detach().numpy()

    return traces[..., get_idx(model, rank=1)]
```

``` ipython
def get_input(model):
    """Build the feed-forward inputs used for the fixed-point analysis.

    Sets ``model.N_BATCH = 10``. Generates one input batch with the sample
    channel ``I0[0]`` at +1 and one at -1, with channels ``I0[1]``..``I0[4]``
    all at 0, then concatenates the two along dim 0 (+1 batch first).

    Side effect: ``model.I0[0:5]`` is left at ``-1, 0, 0, 0, 0``.
    """
    model.N_BATCH = 10

    batches = []
    for sample_sign in (1, -1):
        model.I0[0] = sample_sign
        for channel in range(1, 5):
            model.I0[channel] = 0
        batches.append(model.init_ff_input())

    return torch.cat(tuple(batches))
```

``` ipython
# Fixed-point rates for the DPA, naive-dual and trained-dual checkpoints of the
# current seed. Each call reloads weights, so `model` ends up holding dual_train.
dpa = get_fix_points(model, 'dpa', seed)
dual_naive = get_fix_points(model, 'dual_naive', seed)
dual_train = get_fix_points(model, 'dual_train', seed)
```

``` ipython
# One panel per training stage for the current seed.
fig, ax = plt.subplots(1, 3, figsize=[3*height, height])
plot_fix_points(dpa, ax[0], 'DPA')
plot_fix_points(dual_naive, ax[1], 'Dual Naive')
plot_fix_points(dual_train, ax[-1], 'Dual Expert')
plt.savefig('./figures/dual/fixed_points_%d.svg' % seed, dpi=300)
```

``` ipython
# Stack the three stages (DPA, naive, expert) along axis 0 and save them for the cross-seed plot.
rates = np.stack((dpa, dual_naive, dual_train))
pkl_save(rates, './models/dual/rates_%d' % seed)
```

``` ipython
# Reset the trial duration to 8 and recompute the total number of simulation
# steps (duration / DT plus the steady-state and window padding).
model.DURATION = 8
model.N_STEPS = int(model.DURATION / model.DT) + model.N_STEADY + model.N_WINDOW
```

``` ipython
import os
import re

# Collect the seeds for which a `rates_<seed>.pkl` file exists. The full-name
# match skips files that do not follow the pattern exactly, so they no longer
# crash with AttributeError or get misparsed (e.g. 'rates_x_3.pkl' -> 3).
# Seeds are sorted so the plotting order does not depend on directory order.
model_directory = "./models/dual/"
rates_pattern = re.compile(r'rates_(\d+)\.pkl')
dpa_files = [f for f in os.listdir(model_directory) if f.startswith("rates_") and f.endswith(".pkl")]
seeds = sorted(int(match.group(1)) for match in map(rates_pattern.fullmatch, dpa_files) if match)
print(seeds)
```

``` ipython
# Overlay the fixed points of every saved seed. The loop variable is `s`, not
# `seed`, so the global `seed` (which the opto section below uses to choose the
# trained checkpoint) is not silently replaced by the last listed seed.
fig, ax = plt.subplots(1, 3, figsize=[3*height, height])
for s in seeds:
    rates = pkl_load('./models/dual/rates_%d' % s)
    plot_fix_points(rates[0], ax[0], 'DPA')
    plot_fix_points(rates[1], ax[1], 'Dual Naive')
    plot_fix_points(rates[2], ax[-1], 'Dual Expert')
# The all-seed overlay gets its own filename. Naming it after a single seed
# would overwrite that seed's per-seed figure from the cell above.
plt.savefig('./figures/dual/fixed_points_all.svg', dpi=300)
```

Opto
====

``` ipython
# Load the trained dual model and keep an untouched copy of W_stp_T so any
# perturbation applied for the opto test can be undone afterwards.
model_state_dict = torch.load('models/dual/dual_train_%d.pth' % seed)
model.load_state_dict(model_state_dict)
W_stp_T = model.W_stp_T.clone()
```

``` ipython
# Recompute the epoch time-bin indices for the (possibly changed) trial length.
steps = np.arange(0, model.N_STEPS - model.N_STEADY, model.N_WINDOW)

# "rwd" window: every bin from the onset of the last stimulus onward.
mask_rwd = (steps >= (model.N_STIM_ON[-1].cpu().numpy() - model.N_STEADY))
rwd_idx = np.where(mask_rwd)[0]
print('rwd', rwd_idx)

# "cue" window: bins from the onset of stimulus 2 to the offset of stimulus 3.
mask_cue = (steps >= (model.N_STIM_ON[2].cpu().numpy() - model.N_STEADY)) & (steps <= (model.N_STIM_OFF[3].cpu().numpy() - model.N_STEADY))
cue_idx = np.where(mask_cue)[0]
print('cue', cue_idx)

# Go/NoGo window: bins from the offset of stimulus 1 to the onset of stimulus 2.
mask_GnG = (steps >= (model.N_STIM_OFF[1].cpu().numpy() - model.N_STEADY)) & (steps <= (model.N_STIM_ON[2].cpu().numpy() - model.N_STEADY))
GnG_idx = np.where(mask_GnG)[0]
print('GnG', GnG_idx)

# Bins from the first stimulus onset onward.
stim_mask = (steps >= (model.N_STIM_ON[0].cpu().numpy() - model.N_STEADY))

# "zero" window: bins before the first stimulus onset.
mask_zero = ~mask_rwd & ~mask_cue & ~stim_mask
zero_idx = np.where(mask_zero)[0]
print('zero', zero_idx)
```

``` ipython
# Opto setup: sort neurons by their entry in the second low-rank vector
# U[:, 1]. The lines that would zero the rows/columns of W_stp_T for the k most
# negative and k//2 most positive neurons are commented out, so this cell
# currently only computes `idx` (and leaves k = 25).
k = 50
a, idx = torch.sort(model.low_rank.U[:,1])

# model.W_stp_T[:, idx[:k]] = 0
# model.W_stp_T[idx[:k], :] = 0

k = k//2
# model.W_stp_T[idx[-k:], :] = 0
# model.W_stp_T[:, idx[-k:]] = 0
```

``` ipython
# Opto evaluation set: the same 12 conditions x N_BATCH trials, in the same
# task-major order (DPA, Go, NoGo) as the Re-Testing cell.
N_BATCH = 32
model.N_BATCH = N_BATCH

# Label window length: the longer of the reward and cue windows.
model.lr_eval_win = np.max( (rwd_idx.shape[0], cue_idx.shape[0]))

ff_input = []
# labels[0]: 1 on pair trials; labels[1]: +1 Go, -1 NoGo, 0 DPA.
labels = np.zeros((2, 12, model.N_BATCH, model.lr_eval_win))

l=0
for j in [0, 1, -1]:
    for i in [-1, 1]:
        for k in [-1, 1]:

            model.I0[0] = i # sample
            model.I0[1] = j # distractor
            model.I0[4] = k # test

            if i==k: # Pair Trials
                labels[0, l] = np.ones((model.N_BATCH, model.lr_eval_win))

            if j==1: # Go
                model.I0[2] = float(B0) # cue
                model.I0[3] = float(C0) * model.IF_RL # rwd

                labels[1, l] = np.ones((model.N_BATCH, model.lr_eval_win))
            elif j==-1: # NoGo
                model.I0[2] = float(B0) # cue
                model.I0[3] = 0.0 # rwd
                # NoGo trials are labelled -1, as in the naive and Re-Testing
                # cells. Without this line they stay 0, and the opto Go/NoGo
                # accuracy is scored against different targets than the expert
                # baseline (drt_perf2) it is plotted against.
                labels[1, l] = -np.ones((model.N_BATCH, model.lr_eval_win))
            else: # DPA
                model.I0[2] = 0 # cue
                model.I0[3] = 0 # rwd

            l+=1

            ff_input.append(model.init_ff_input())

# Final labels shape: (12 * N_BATCH, 2, lr_eval_win).
labels = torch.tensor(labels, dtype=torch.float, device=DEVICE).reshape(2, -1, model.lr_eval_win).transpose(0, 1)
ff_input = torch.vstack(ff_input)
print('ff_input', ff_input.shape, 'labels', labels.shape)
```

``` ipython
# Simulate the (possibly perturbed) network, then restore the untouched
# W_stp_T copy saved at load time so later cells use the intact weights.
rates = model.forward(ff_input=ff_input).detach()
rates = rates.cpu().detach().numpy()
model.W_stp_T = W_stp_T
print(rates.shape)
```

``` ipython
# Readout overlaps per task type for the opto condition.
# Reshape assumes task-major trial order (DPA, Go, NoGo), 81 time bins, 2 readouts.
readout = model.readout.cpu().detach().numpy().reshape(3, -1, 81, 2) / 2
print(readout.shape)

# N_BATCH now counts trials per sample half.
# NOTE(review): mutates the global; re-running the cell doubles it again.
N_BATCH *= 2

time = np.linspace(0, 8, readout.shape[-2])
fig, ax = plt.subplots(1, 2, figsize=[2*width, height])

color = ['r', 'b', 'g']
label = ['DPA', 'DualGo', 'DualNoGo']

for i in range(3):
    # (sample +1) - (sample -1) on readout 0; sum of both halves on readout 1.
    sample = (-readout[i, :N_BATCH, :, 0].T  + readout[i, N_BATCH:, :, 0].T)
    dist = (readout[i, :N_BATCH, :, 1].T  + readout[i, N_BATCH:, :, 1].T)

    ax[0].plot(time, sample.mean(1), color=color[i])
    ax[1].plot(time, dist.mean(1), color=color[i], label=label[i])

add_vlines(ax[0])
add_vlines(ax[1])
ax[0].set_ylabel('Sample Overlap (Hz)')
ax[1].set_ylabel('Go/NoGo Overlap (Hz)')
ax[1].legend(frameon=False, fontsize=10)
ax[0].set_xlabel('Time (s)')
ax[1].set_xlabel('Time (s)')
plt.savefig('./figures/dual/dual_train_%d_over_opto.svg' % seed, dpi=300)
plt.show()
```

``` ipython
# Performance of the opto condition, for comparison with the expert baseline.
criterion = DualPerf(alpha=1.0, thresh=4.0, cue_idx=cue_idx, rwd_idx=rwd_idx, zero_idx=zero_idx, imbalance=[0.0, 0.0], read_idx=[1, 1])
dpa_opto_perf, drt_opto_perf = criterion(model.readout, labels)

dpa_mean_opto, dpa_sem_opto = calculate_mean_accuracy_and_sem(dpa_opto_perf)
drt_mean_opto, drt_sem_opto = calculate_mean_accuracy_and_sem(drt_opto_perf)
```

``` ipython
# Expert vs. suppressed (opto) accuracy (mean +/- SEM): DPA on the left,
# Go/NoGo on the right. The dashed line marks chance level (0.5).
fig, ax = plt.subplots(1, 2, figsize=[1.5*width, height], sharex=True)

condition_names = ['Expert', 'Suppr']
panels = [
    (ax[0], 'DPA Accuracy', [(dpa_mean2, dpa_sem2), (dpa_mean_opto, dpa_sem_opto)]),
    (ax[1], 'Go/NoGo Accuracy', [(drt_mean2, drt_sem2), (drt_mean_opto, drt_sem_opto)]),
]

for panel_ax, panel_ylabel, panel_stats in panels:
    for xpos, ((acc_mean, acc_sem), condition) in enumerate(zip(panel_stats, condition_names)):
        panel_ax.errorbar(xpos, acc_mean.item(), yerr=acc_sem.item(), fmt='o', label=condition,
                          color='k', ecolor='k', elinewidth=3, capsize=5)

    panel_ax.set_xlim(-1, 2)
    panel_ax.set_ylim(0.4, 1.1)

    panel_ax.set_ylabel(panel_ylabel)
    panel_ax.set_xticks([0, 1], condition_names)
    panel_ax.axhline(y=0.5, color='k', linestyle='--')

plt.savefig('./figures/dual/dual_perf_opto_%d.svg' % seed, dpi=300)

plt.show()
```