Notebook Settings
=================

``` ipython
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

%run ../notebooks/setup.py
%matplotlib inline
%config InlineBackend.figure_format = 'png'

REPO_ROOT = "/home/leon/models/NeuroFlame"
pal = sns.color_palette("tab10")
```

Imports
=======

``` ipython
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
from torch.utils.data import Dataset, TensorDataset, DataLoader

DEVICE = 'cuda:1'
```

``` ipython
import sys
sys.path.insert(0, '../')

import pandas as pd
import torch.nn as nn
from time import perf_counter
from scipy.stats import circmean

from src.network import Network
from src.plot_utils import plot_con
from src.decode import decode_bump, circcvl
from src.lr_utils import masked_normalize, clamp_tensor, normalize_tensor
```

Helpers
=======

Data Split
----------

``` ipython
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit

def split_data(X, Y, train_perc=0.8, batch_size=32):
  """Split (X, Y) into stratified train/validation loaders.

  Args:
    X: input tensor, first axis indexes the samples.
    Y: label tensor (2-D or 3-D), first axis indexes the samples.
    train_perc: fraction of the samples assigned to the training set.
    batch_size: batch size used by both loaders.

  Returns:
    (train_loader, val_loader); only the training loader shuffles.
  """
  # Stratify on the first label entry of every sample, whatever the rank of Y.
  strat_key = Y[:, 0, 0] if Y.ndim == 3 else Y[:, 0]

  X_train, X_test, Y_train, Y_test = train_test_split(
      X, Y,
      train_size=train_perc,
      stratify=strat_key.cpu().numpy(),
      shuffle=True)

  print(X_train.shape, X_test.shape)
  print(Y_train.shape, Y_test.shape)

  train_loader = DataLoader(dataset=TensorDataset(X_train, Y_train),
                            batch_size=batch_size, shuffle=True)
  val_loader = DataLoader(dataset=TensorDataset(X_test, Y_test),
                          batch_size=batch_size, shuffle=False)

  return train_loader, val_loader
```

Optimization
------------

``` ipython
def accuracy_score(y_pred, labels):
  """Fraction of binary predictions that match the ground-truth labels.

  Args:
    y_pred: raw logits (scores before the sigmoid), same shape as ``labels``.
    labels: ground-truth tensor holding 0/1 values.

  Returns:
    Scalar float tensor with the accuracy averaged over every element.
  """
  # sigmoid(logit) > 0.5 is the same decision as logit > 0.
  predicted = (torch.sigmoid(y_pred) > 0.5).float()

  # Average over all elements: dividing by size(0) * size(-1) is only right
  # for 2-D inputs (it divides 1-D inputs by n**2 and ignores middle axes).
  return (predicted == labels).float().mean()
```

``` ipython
def torch_angle_AB(U, V):
    """Angle between two 1-D tensors, in degrees rounded to the nearest integer.

    Args:
        U, V: 1-D tensors of equal length (``torch.dot`` requires 1-D input).

    Returns:
        0-dim tensor holding the rounded angle in degrees.
    """
    cos_theta = torch.dot(U, V) / (torch.linalg.norm(U) * torch.linalg.norm(V))

    # Rounding can push the ratio slightly outside [-1, 1] for (anti)parallel
    # vectors, which would make acos return NaN.
    cos_theta = torch.clamp(cos_theta, -1.0, 1.0)

    return torch.round(torch.rad2deg(torch.acos(cos_theta)))
```

``` ipython
def training_step(dataloader, model, loss_fn, optimizer, penalty=None, lbd=1, clip_grad=0, zero_grad=0):
    """Run one pass over ``dataloader``, updating ``model`` after every batch.

    Args:
        dataloader: iterable yielding ``(X, y)`` batches.
        model: network; ``model(X)`` returns the rates, and the model exposes
            ``low_rank.U``, ``slices`` and ``Na``.
        loss_fn: callable ``loss_fn(y_pred, y)`` returning a scalar tensor.
        optimizer: optimizer holding the model parameters.
        penalty: ``None`` for no regularization, ``'l1'`` for an L1 penalty,
            any other value for an L2 (sum of squares) penalty.
        lbd: weight of the regularization term.
        clip_grad: if truthy, clip the global gradient norm to 10.
        zero_grad: if > 0, 1-based index of the low-rank column whose
            gradients are zeroed before the update (that column is not trained).

    Returns:
        The loss tensor of the last batch (regularization included).
    """
    device = torch.device(DEVICE if torch.cuda.is_available() else "cpu")

    model.train()
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        rates = model(X)

        # Readout along column 1 of U, memory overlap along column 0.
        y_pred = rates @ model.low_rank.U[model.slices[0], 1]
        overlap = rates @ model.low_rank.U[model.slices[0], 0] / model.Na[0]

        # Penalize |overlap| > 1 on the first 9 entries of the last axis
        # (presumably the pre-stimulus windows -- TODO confirm the constant).
        loss = loss_fn(y_pred, y) + F.relu(overlap[..., :9].abs() - 1.0).mean()

        if penalty is not None:
            if penalty == 'l1':
                reg_loss = sum(param.abs().sum() for param in model.parameters())
            else:
                reg_loss = sum(param.square().sum() for param in model.parameters())

            # Added once, after accumulating over all parameters; adding the
            # running sum inside the loop counts early parameters repeatedly.
            loss = loss + lbd * reg_loss

        # Backpropagation
        loss.backward()

        if zero_grad > 0:
            model.low_rank.U.grad[:, zero_grad-1] = 0
            try:
                model.low_rank.V.grad[:, zero_grad-1] = 0
            except (AttributeError, TypeError, IndexError):
                # Best effort: V may be absent, have no gradient, or lack
                # that column.
                pass

        # Clip gradients
        if clip_grad:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)

        optimizer.step()
        optimizer.zero_grad()

    return loss
```

``` ipython
def validation_step(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Each batch loss is ``loss_fn`` on the readout (rates projected on column 1
    of ``low_rank.U``) plus a penalty on memory overlaps (column 0) whose
    magnitude exceeds 1 on the first 9 entries of the last axis.

    Returns:
        Sample-weighted average loss over the dataset, as a Python float.
    """
    device = torch.device(DEVICE if torch.cuda.is_available() else "cpu")
    n_samples = len(dataloader.dataset)

    model.eval()
    total = 0.0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            rates = model(inputs)

            readout_vec = model.low_rank.U[model.slices[0], 1]
            memory_vec = model.low_rank.U[model.slices[0], 0]

            y_pred = rates @ readout_vec
            overlap = rates @ memory_vec / model.Na[0]

            batch_loss = loss_fn(y_pred, targets) + F.relu(overlap[..., :9].abs() - 1.0).mean()

            # Weight by the batch size so that a smaller last batch is not
            # over-represented in the average.
            total += batch_loss.item() * inputs.size(0)

    return total / n_samples
```

``` ipython
def optimization(model, train_loader, val_loader, loss_fn, loss2_fn, optimizer, num_epochs=100, penalty=None, lbd=1, thresh=.005, zero_grad=0):
    """Train ``model`` for up to ``num_epochs`` epochs with early stopping.

    Args:
        model: network to train; moved to the training device.
        train_loader, val_loader: loaders for training and validation batches.
        loss_fn: loss used for both training and validation.
        loss2_fn: currently unused; kept so existing calls keep working.
        optimizer: optimizer holding the model parameters.
        num_epochs: maximum number of epochs.
        penalty, lbd, zero_grad: forwarded to ``training_step``.
        thresh: training stops once both losses fall below this value.

    Returns:
        (loss_list, val_loss_list): per-epoch training loss (last batch of
        the epoch) and validation loss.
    """
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    device = torch.device(DEVICE if torch.cuda.is_available() else 'cpu')
    model.to(device)

    loss_list = []
    val_loss_list = []
    angle_list = []

    for epoch in range(num_epochs):
        loss = training_step(train_loader, model, loss_fn, optimizer, penalty, lbd, zero_grad=zero_grad)
        val_loss = validation_step(val_loader, model, loss_fn)

        # ExponentialLR decays the rate by gamma on every call. Its optional
        # (deprecated) positional argument is an epoch index, not a metric,
        # so the validation loss must not be passed here.
        scheduler.step()

        loss_list.append(loss.item())
        val_loss_list.append(val_loss)
        print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {loss.item():.4f}, Validation Loss: {val_loss:.4f}')

        # Angle between the memory vector (column 0 of U) and the linear readout.
        memory = model.low_rank.U[model.slices[0], 0]
        readout = model.low_rank.linear.weight.data[0]
        angle = torch_angle_AB(memory, readout).item()
        angle_list.append(angle)

        print(f'Angle(U, W) : {angle} °', 'performance')

        if val_loss < thresh and loss < thresh:
            print(f'Stopping training as loss has fallen below the threshold: {loss}, {val_loss}')
            break

        if val_loss > 300:
            print(f'Stopping training as loss is too high: {val_loss}')
            break

        if torch.isnan(loss):
            print(f'Stopping training as loss is NaN.')
            break

    return loss_list, val_loss_list
```

Loss
----

``` ipython
def performance_score(model, rates, labels):
    """Accuracy of the linear readout applied to the last two steps of ``rates``.

    Prints the shape of ``rates`` and returns the result of ``accuracy_score``.
    """
    print(rates.shape)
    # Keep only the last two entries along axis 1 before the readout.
    logits = model.low_rank.linear(rates[:, -2:])
    return accuracy_score(logits.squeeze(-1), labels)
```

``` ipython
def imbalance(target):
  """Per-element weights for ``target``: 1 where it equals 0 or 1, else 0.

  Both classes currently get the same weight (1).
  """
  is_binary = (target == 1) | (target == 0)
  return torch.where(is_binary, torch.ones_like(target), torch.zeros_like(target))
```

``` ipython
import torch
import torch.nn as nn
import torch.nn.functional as F

class SignBCELoss(nn.Module):
    """Mix of a BCE-with-logits loss and a hinge loss on the signed readout.

    ``alpha`` weights the hinge ("sign") term and ``1 - alpha`` the BCE term;
    ``thresh`` is the hinge margin and ``N`` rescales the readout.
    """

    def __init__(self, alpha=1.0, thresh=4.0, N=1000):
        super().__init__()
        self.alpha, self.thresh, self.N = alpha, thresh, N
        self.bce_with_logits = nn.BCEWithLogitsLoss()

    def forward(self, readout, targets):
        # The BCE term has zero weight when alpha == 1, so it is not computed.
        bce_loss = 0.0 if self.alpha == 1.0 else self.bce_with_logits(readout, targets)

        # Map 0/1 targets to -1/+1.
        signs = torch.sign(2 * targets - 1)

        # Readout averaged over axis 1, kept as a trailing singleton axis so
        # that it broadcasts against the signs.
        mean_activation = readout.mean(dim=1).unsqueeze(-1)
        sign_overlap = signs * mean_activation / (1.0 * self.N)

        # Hinge: zero once the signed, scaled readout exceeds the margin.
        sign_loss = F.relu(self.thresh - sign_overlap).mean()

        return (1-self.alpha) * bce_loss + self.alpha * sign_loss
```

``` ipython
class DualLoss(nn.Module):
    """Task loss on the readout plus a penalty keeping selected steps near zero.

    Args:
        alpha, thresh, N: forwarded to ``SignBCELoss``; ``N`` also rescales
            the readout in the baseline penalty.
        cue_idx: indices (axis 1 of the readout) of the second task's window;
            when empty only the first task is scored.
        rwd_idx: indices (axis 1 of the readout) of the first task's window.
        zero_idx: indices (axis 1 of the readout) where |readout / N| is
            penalized above 1; no penalty when empty.
    """

    def __init__(self, alpha=1.0, thresh=4.0, N=1000, cue_idx=None, rwd_idx=-1, zero_idx=None):
        super().__init__()
        self.alpha = alpha
        self.thresh = thresh
        self.N = N

        # ``None`` stands for "no indices"; avoids shared mutable defaults.
        self.zero_idx = [] if zero_idx is None else zero_idx
        cue_idx = [] if cue_idx is None else cue_idx

        self.cue_idx = torch.tensor(cue_idx, dtype=torch.int, device=DEVICE)
        self.rwd_idx = torch.tensor(rwd_idx, dtype=torch.int, device=DEVICE)

        self.loss = SignBCELoss(self.alpha, self.thresh, self.N)

    def forward(self, readout, targets):
        # Keep the overlap small on the "zero" steps. With no such steps the
        # mean of an empty selection would be NaN, so the term is skipped.
        if torch.as_tensor(self.zero_idx).numel() > 0:
            bl_loss = F.relu((readout[:, self.zero_idx] / self.N).abs() - 1.0).mean()
        else:
            bl_loss = 0.0

        if self.cue_idx.numel() == 0:
            # Single task: targets are used as given.
            self.DPA_loss = self.loss(readout[:, self.rwd_idx], targets)
            return (self.DPA_loss + bl_loss)

        # Two tasks: targets[:, 0] holds the labels of the first task and
        # targets[:, 1] those of the second, each cut to its window length.
        self.DPA_loss = self.loss(readout[:, self.rwd_idx], targets[:, 0, :self.rwd_idx.shape[0]])
        self.DRT_loss = self.loss(readout[:, self.cue_idx], targets[:, 1, :self.cue_idx.shape[0]])
        return (self.DPA_loss + self.DRT_loss) / 2.0 + bl_loss
```

``` ipython
class AccuracyLoss(nn.Module):
    """Accuracy of the readout on the task window(s), wrapped as a module.

    Args:
        N: stored but not used in ``forward``.
        cue_idx: indices (axis 1 of the readout) of the second task's window;
            when empty only the first task is scored.
        rwd_idx: indices (axis 1 of the readout) of the first task's window.
    """

    def __init__(self, N=1000, cue_idx=None, rwd_idx=-1):
        super().__init__()
        self.N = N

        # ``None`` stands for "no indices"; avoids a shared mutable default.
        cue_idx = [] if cue_idx is None else cue_idx

        self.cue_idx = torch.tensor(cue_idx, dtype=torch.int, device=DEVICE)
        self.rwd_idx = torch.tensor(rwd_idx, dtype=torch.int, device=DEVICE)

    def forward(self, readout, targets):
        if self.cue_idx.numel() == 0:
            # Single task: targets are used as given.
            self.DPA_loss = accuracy_score(readout[:, self.rwd_idx], targets)
            return self.DPA_loss

        # Two tasks: targets[:, 0] / targets[:, 1] hold their labels, each cut
        # to its window length; the two accuracies are averaged.
        self.DPA_loss = accuracy_score(readout[:, self.rwd_idx], targets[:, 0, :self.rwd_idx.shape[0]])
        self.DRT_loss = accuracy_score(readout[:, self.cue_idx], targets[:, 1, :self.cue_idx.shape[0]])
        return (self.DPA_loss + self.DRT_loss) / 2.0
```

Other
-----

``` ipython
def angle_AB(A, B):
    """Angle between vectors ``A`` and ``B`` in degrees, truncated to an int.

    A small constant is added to each norm so that zero vectors do not cause
    a division by zero.
    """
    A_norm = A / (np.linalg.norm(A) + 1e-5)
    B_norm = B / (np.linalg.norm(B) + 1e-5)

    # Rounding can push the cosine outside [-1, 1]; arccos would then return
    # NaN and int() would raise.
    cos_theta = np.clip(A_norm @ B_norm, -1.0, 1.0)

    return int(np.arccos(cos_theta) * 180 / np.pi)
```

``` ipython
def get_theta(a, b, GM=0, IF_NORM=0):
    """Element-wise angle ``arctan2(v, u)`` in [0, 2*pi) built from ``a`` and ``b``.

    Args:
        a, b: arrays of equal shape; ``u`` comes from ``a`` and ``v`` from ``b``.
        GM: if truthy, remove from ``b`` its projection on ``a`` (Gram-Schmidt).
        IF_NORM: if truthy, scale ``u`` and ``v`` to unit norm.

    Returns:
        Array of angles in radians, one per element.
    """
    u, v = a, b

    if GM:
        v = b - np.dot(b, a) / np.dot(a, a) * a

    if IF_NORM:
        # Normalize the current u and v, so that the orthogonalized v is kept
        # when GM is also set (normalizing the raw b would discard it).
        u = u / np.linalg.norm(u)
        v = v / np.linalg.norm(v)

    return np.arctan2(v, u) % (2.0 * np.pi)
```

``` ipython
def get_idx(model, rank=2):
    """Ordering of the neurons by their angle in a plane of low-rank vectors.

    The columns of ``low_rank.U`` and ``low_rank.V`` (cut to the first
    ``model.Na[0]`` neurons) are stacked as rows, followed by the linear
    readout weights; the angle is taken between row 0 and row ``rank``.

    Returns:
        Indices that sort the neurons by increasing angle.
    """
    low_rank = model.low_rank

    vectors = torch.hstack((low_rank.U, low_rank.V)).T[:, :model.Na[0]]
    ksi = torch.vstack((vectors, low_rank.linear.weight.data))

    print('ksi', ksi.shape)

    ksi_np = ksi.cpu().detach().numpy()
    return get_theta(ksi_np[0], ksi_np[rank]).argsort()
```

``` ipython
def get_overlap(model, rates):
    """Project ``rates`` on every row of ``model.odors``, divided by the size of the last axis."""
    odors = model.odors.cpu().detach().numpy()
    n_units = rates.shape[-1]
    return rates @ odors.T / n_units

```

``` ipython
import scipy.stats as stats

def plot_smooth(data, ax, color):
    """Plot the mean of ``data`` over axis 0 with a shaded band around it.

    Args:
        data: 2-D array; axis 0 is averaged, axis 1 is the x axis.
        ax: matplotlib axes to draw on.
        color: color of the line and of the band.
    """
    mean = data.mean(axis=0)
    # Band half-width: 1.96 sample standard deviations (not standard errors).
    ci = data.std(axis=0, ddof=1) * 1.96

    # Plot
    ax.plot(mean, color=color)
    ax.fill_between(range(data.shape[1]), mean - ci, mean + ci, alpha=0.25, color=color)

```

``` ipython
def convert_seconds(seconds):
    """Break a duration in seconds into ``(hours, minutes, seconds)``."""
    hours, within_hour = divmod(seconds, 3600)
    minutes = within_hour // 60
    return hours, minutes, seconds % 60
```

Plots
-----

``` ipython
def plot_rates_selec(rates, idx, thresh=0.5, figname='fig.svg'):
    """Show the rates of the first trial, unsorted (left) and reordered by ``idx`` (right).

    The color scale is capped at ``thresh`` times the maximum rate of the
    first trial. The figure is saved to ``figname`` and displayed. Relies on
    the notebook globals ``plt``, ``width``, ``height`` and ``model``.
    """
    ordered = rates[..., idx]
    fig, ax = plt.subplots(1, 2, figsize=[2*width, height])
    raw_ax, sorted_ax = ax[0], ax[1]
    r_max = thresh * np.max(rates[0])

    image_opts = dict(aspect='auto', cmap='jet', vmin=0, vmax=r_max)

    raw_ax.imshow(rates[0].T, **image_opts)
    raw_ax.set_ylabel('Neuron #')
    raw_ax.set_xlabel('Step')

    sorted_ax.imshow(ordered[0].T, **image_opts)
    # Relabel the neuron axis from 0 to 360 degrees.
    tick_pos = np.linspace(0, model.Na[0].cpu().detach(), 5)
    tick_lab = np.linspace(0, 360, 5).astype(int)
    sorted_ax.set_yticks(tick_pos, tick_lab)
    sorted_ax.set_ylabel('Pref. Location (°)')
    sorted_ax.set_xlabel('Step')

    plt.savefig(figname, dpi=300)
    plt.show()
```

``` ipython
def plot_overlap(rates, memory, readout, labels=['A', 'B'], figname='fig.svg'):
    """Plot the overlap of ``rates`` with the memory (left) and readout (right) vectors.

    The overlap is ``rates @ vector`` divided by the size of the last axis of
    ``rates``. With more than two trials, the first two are drawn solid and
    the others dashed; otherwise trial 0 is solid and trial 1 dashed. The
    figure is saved to ``figname`` and displayed. Relies on the notebook
    globals ``plt``, ``width`` and ``height``.
    """
    fig, ax = plt.subplots(1, 2, figsize=[2*width, height])

    panels = ((ax[0], memory, 'Memory'), (ax[1], readout, 'Readout'))
    for panel, vector, title in panels:
        overlap =(rates @ vector) / rates.shape[-1]
        per_step = overlap.T

        if overlap.shape[0]>2:
            solid, dashed = per_step[..., :2], per_step[..., 2:]
        else:
            solid, dashed = per_step[..., 0], per_step[..., 1]

        panel.plot(solid, label=labels[0])
        panel.plot(dashed, '--', label=labels[1])

        panel.set_xlabel('Step')
        panel.set_ylabel('Overlap')
        panel.set_title(title)

    plt.savefig(figname, dpi=300)
    plt.show()
```

``` ipython
def plot_m0_m1_phi(rates, idx, figname='fig.svg'):
    """Plot the three outputs of ``decode_bump`` on the reordered rates.

    ``decode_bump`` is applied along the last axis of ``rates[..., idx]``; its
    outputs are drawn as F0, F1 and phase (converted to degrees). The first
    two trials are solid, the others dashed. The figure is saved to
    ``figname`` and displayed. Relies on the notebook globals ``plt``,
    ``width`` and ``height``.
    """
    m0, m1, phi = decode_bump(rates[..., idx], axis=-1)
    fig, ax = plt.subplots(1, 3, figsize=[2*width, height])

    ax[0].plot(m0[:2].T)
    ax[0].plot(m0[2:].T, '--')
    # Raw strings: '\m' is an invalid escape sequence in a normal string.
    ax[0].set_ylabel(r'$\mathcal{F}_0$ (Hz)')
    ax[0].set_xlabel('Step')

    ax[1].plot(m1[:2].T)
    ax[1].plot(m1[2:].T, '--')
    ax[1].set_ylabel(r'$\mathcal{F}_1$ (Hz)')
    ax[1].set_xlabel('Step')

    ax[2].plot(phi[:2].T * 180 / np.pi)
    ax[2].plot(phi[2:].T * 180 / np.pi, '--')
    ax[2].set_ylim([0, 360])
    ax[2].set_yticks([0, 90, 180, 270, 360])
    ax[2].set_ylabel('Phase (°)')
    ax[2].set_xlabel('Step')

    plt.savefig(figname, dpi=300)
    plt.show()
```

Model
=====

``` ipython
REPO_ROOT = "/home/leon/models/NeuroFlame"
conf_name = "config_train.yml"
DEVICE = 'cuda:1'
seed = np.random.randint(0, 1e6)
print(seed)
```

``` ipython
model = Network(conf_name, REPO_ROOT, VERBOSE=0, DEVICE=DEVICE, SEED=seed, N_BATCH=16)
```

Sample Classification
=====================

Training
--------

### Parameters

``` ipython
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.shape)
```

``` ipython
model.LR_TRAIN = 1
model.LR_READOUT = 1
model.IF_RL = 0
```

The loss is evaluated on the steps from sample odor offset to the end of the trial.

``` ipython
# Step indices from 0 to N_STEPS - N_STEADY, one every N_WINDOW steps
# (presumably one per recorded rate window -- TODO confirm against Network).
steps = np.arange(0, model.N_STEPS - model.N_STEADY, model.N_WINDOW)

# Windows from the offset of stimulus 0 to the end of the trial: these are
# the entries on which the task loss is evaluated.
mask = (steps >= (model.N_STIM_OFF[0] - model.N_STEADY)) & (steps <= (model.N_STEPS - model.N_STEADY))
rwd_idx = np.where(mask)[0]
print('rwd', rwd_idx)

# Number of evaluation windows, stored on the model.
model.lr_eval_win = rwd_idx.shape[0]

# Windows during which stimulus 0 is on.
stim_mask = (steps >= (model.N_STIM_ON[0] - model.N_STEADY)) & (steps < (model.N_STIM_OFF[0] - model.N_STEADY))

# Everything that is neither evaluation window nor stimulus 0: passed to the
# loss as zero_idx.
zero_idx = np.where(~mask & ~stim_mask )[0]
print('zero', zero_idx)
```

### Inputs and Labels

``` ipython
model.N_BATCH = 80

model.I0[0] = 2.0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0

A = model.init_ff_input()

model.I0[0] = -2.0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0

B = model.init_ff_input()

ff_input = torch.cat((A, B))
print(ff_input.shape)
```

``` ipython
labels_A = torch.ones((model.N_BATCH, rwd_idx.shape[0]))
labels_B = torch.zeros((model.N_BATCH, rwd_idx.shape[0]))
labels = torch.cat((labels_A, labels_B))

print('labels', labels.shape)
```

### Run

``` ipython
batch_size = 16
train_loader, val_loader = split_data(ff_input, labels, train_perc=0.8, batch_size=batch_size)
```

``` ipython
criterion = DualLoss(alpha=1.0, thresh=5.0, N=model.Na[0], rwd_idx=rwd_idx, zero_idx=zero_idx)
criterion2 = AccuracyLoss(N=model.Na[0], rwd_idx=rwd_idx)

# SGD, Adam, Adam
learning_rate = 0.05
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
```

``` ipython
num_epochs = 15
start = perf_counter()
loss, val_loss = optimization(model, train_loader, val_loader, criterion, criterion2, optimizer, num_epochs)
end = perf_counter()
print("Elapsed (with compilation) = %dh %dm %ds" % convert_seconds(end - start))
```

Testing
-------

``` ipython
model.eval()
model.LR_READOUT = 0
```

``` ipython
model.N_BATCH = 10

model.I0[0] = 2
model.I0[1] = 0
model.I0[2] = 0

A = model.init_ff_input()

model.I0[0] = -2
model.I0[1] = 0
model.I0[2] = 0

B = model.init_ff_input()

ff_input = torch.cat((A, B))
print('ff_input', ff_input.shape)
```

``` ipython
rates = model.forward(ff_input=ff_input).cpu().detach().numpy()
print('rates', rates.shape)
```

``` ipython
# memory = model.odors.cpu().detach().numpy()[0]
memory = model.low_rank.U.cpu().detach().numpy()[model.slices[0], 0]
readout = model.low_rank.U.cpu().detach().numpy()[model.slices[0], 1]
# readout = model.low_rank.linear.weight.data.cpu().detach().numpy()[0]
plot_overlap(rates, memory, readout, labels=['A', 'B'])
```

``` ipython
idx = get_idx(model, 1)
plot_rates_selec(rates, idx)
```

``` ipython
plot_m0_m1_phi(rates, idx)
```

DPA
===

Training
--------

### Parameters

``` ipython
model.low_rank.U.data[:, 1] = torch.randn(model.low_rank.U.T.data[1].shape) * 0.01
```

``` ipython
model.LR_TRAIN = 1
model.LR_READOUT = 1
model.IF_RL = 0
```

Here we only evaluate performance from test onset to test offset

``` ipython
steps = np.arange(0, model.N_STEPS - model.N_STEADY, model.N_WINDOW)
# mask = (steps >= (model.N_STIM_OFF[2] - model.N_STEADY)) & (steps <= (model.N_STEPS - model.N_STEADY))
mask = (steps >= (model.N_STIM_ON[4] - model.N_STEADY)) & (steps <= (model.N_STEPS - model.N_STEADY))
rwd_idx = np.where(mask)[0]
print('rwd', rwd_idx)

model.lr_eval_win = rwd_idx.shape[0]

stim_mask = (steps >= (model.N_STIM_ON[0] - model.N_STEADY)) & (steps < (model.N_STIM_OFF[0] - model.N_STEADY))

stim_mask1 = (steps >= (model.N_STIM_ON[4] - model.N_STEADY)) # & (steps < (model.N_STIM_OFF[3] - model.N_STEADY))

mask_zero = ~mask & ~stim_mask & ~stim_mask1
zero_idx = np.where(mask_zero)[0]
print('zero', zero_idx)
```

### Inputs and Labels

``` ipython
model.N_BATCH = 80

A0 = 1

model.I0[0] = A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0
model.I0[4] = A0

AC_pair = model.init_ff_input()

model.I0[0] = A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0
model.I0[4] = -A0

AD_pair = model.init_ff_input()

model.I0[0] = -A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0
model.I0[4] = A0

BC_pair = model.init_ff_input()

model.I0[0] = -A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0
model.I0[4] = -A0

BD_pair = model.init_ff_input()

ff_input = torch.cat((AC_pair, BD_pair, AD_pair, BC_pair))
print('ff_input', ff_input.shape)
```

``` ipython
labels_pair = torch.ones((2 * model.N_BATCH, model.lr_eval_win))
labels_unpair = torch.zeros((2 * model.N_BATCH, model.lr_eval_win))

labels = torch.cat((labels_pair, labels_unpair))
print('labels', labels.shape)
```

### Run

``` ipython
batch_size = 16
train_loader, val_loader = split_data(ff_input, labels, train_perc=0.8, batch_size=batch_size)
```

``` ipython
# Loss
criterion = DualLoss(alpha=1.0, thresh=5.0, N=model.Na[0], rwd_idx=rwd_idx, zero_idx=zero_idx)
criterion2 = AccuracyLoss(N=model.Na[0], rwd_idx=rwd_idx)

# Optimizer: SGD, Adam, Adam
learning_rate = 0.05
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
```

``` ipython
num_epochs = 30
start = perf_counter()
loss, val_loss = optimization(model, train_loader, val_loader, criterion, criterion2, optimizer, num_epochs, zero_grad=0)
end = perf_counter()
print("Elapsed (with compilation) = %dh %dm %ds" % convert_seconds(end - start))
```

``` ipython
torch.save(model.state_dict(), 'models/dpa_%d.pth' % seed)
```

``` ipython
plt.plot(loss)
plt.plot(val_loss)
plt.xlabel('epochs')
plt.ylabel('Loss')
plt.show()
```

``` ipython
torch.save(model.state_dict(), 'models/dpa_%d.pth' % seed)
```

``` ipython
odors = model.odors.cpu().numpy()
U = model.low_rank.U.cpu().detach().numpy()[model.slices[0], 0]
V = model.low_rank.V.cpu().detach().numpy()[model.slices[0], 0]
W = model.low_rank.linear.weight.data.cpu().detach().numpy()[0]

print('   U  V  W  S  D')
print('U ', angle_AB(U, U), angle_AB(U, V), angle_AB(U, W), angle_AB(U, odors[0]), angle_AB(U, odors[1]))
print('V ', 'XXX', angle_AB(V, V), angle_AB(V, W), angle_AB(V, odors[0]), angle_AB(V, odors[1]))
print('W ', 'XXX', 'XXX', angle_AB(W, W), angle_AB(W, odors[0]), angle_AB(W, odors[1]))
print('S ', 'XXX', 'XXX', 'XXX', angle_AB(odors[0], odors[0]), angle_AB(odors[0], odors[1]))
print('D ', 'XXX', 'XXX', 'XXX', 'XXX', angle_AB(odors[1], odors[1]))

```

Testing
-------

``` ipython
model.eval()
model.LR_READOUT = 0
```

``` ipython
model.N_BATCH = 1
A0 = 1

model.I0[0] = A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0
model.I0[4] = A0

AC_pair = model.init_ff_input()

model.I0[0] = A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0
model.I0[4] = -A0

AD_pair = model.init_ff_input()

model.I0[0] = -A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0
model.I0[4] = A0

BC_pair = model.init_ff_input()

model.I0[0] = -A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0
model.I0[4] = -A0

BD_pair = model.init_ff_input()

ff_input = torch.cat((AC_pair, BD_pair, AD_pair, BC_pair))
print('ff_input', ff_input.shape)
```

``` ipython
labels_pair = torch.ones((2 * model.N_BATCH, 2))
labels_unpair = torch.zeros((2 * model.N_BATCH, 2))

labels = torch.cat((labels_pair, labels_unpair))
print('labels', labels.shape)
```

``` ipython
rates = model.forward(ff_input=ff_input)
print(rates.shape)
```

``` ipython
print(rates.shape)
print(labels.shape)
```

``` ipython
# Move the labels to wherever the rates live instead of a hard-coded GPU id.
perf = performance_score(model, rates, labels.to(rates.device))
```

``` ipython
print(perf.item())
```

``` ipython
# readout = model.low_rank.linear.weight.data.cpu().detach().numpy()[0]
memory = model.low_rank.U.cpu().detach().numpy()[model.slices[0], 0]
readout = model.low_rank.U.cpu().detach().numpy()[model.slices[0], 1]
plot_overlap(rates.detach().cpu().numpy(), memory, readout, labels=['pair', 'unpair'], figname='dpa_overlap.svg')
```

``` ipython
idx = get_idx(model, 1)
plot_rates_selec(rates.detach().cpu().numpy(), idx, figname='dpa_raster.svg')
```

``` ipython
plot_m0_m1_phi(rates.detach().cpu().numpy(), idx, figname='dpa_fourier.svg')
```

``` ipython
print(rates.shape)
```

``` ipython
from matplotlib.patches import Circle
m0, m1, phi = decode_bump(rates[..., idx].detach().cpu().numpy(), axis=-1)

x = m1 / m0 * np.cos(phi)
y = m1 / m0 * np.sin(phi)

xA = x
yA = y

fig, ax = plt.subplots(1, 1, figsize=[height, height])

ax.plot(xA.T[0], yA.T[0], 'x', alpha=.5, ms=10)
ax.plot(xA.T, yA.T, '-', alpha=.5)
ax.plot(xA.T[-1], yA.T[-1], 'o', alpha=.5, ms=10)
# ax.set_xlim([-.9, .9])
# ax.set_ylim([-.9, .9])
circle = Circle((0., 0.), 1, fill=False, edgecolor='k')
ax.add_patch(circle)

# Set the aspect of the plot to equal to make the circle circular
ax.set_aspect('equal')

plt.show()
```

``` ipython
```

Fixed points
------------

``` ipython
model.DURATION = 20
model.N_STEPS = int(model.DURATION / model.DT) + model.N_STEADY + model.N_WINDOW
model.IF_RL = 0
```

``` ipython
model.eval()
model.LR_READOUT = 0
```

``` ipython
model.N_BATCH = 1

model.I0[0] = A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0

AC_pair = model.init_ff_input()

model.I0[0] = A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0

AD_pair = model.init_ff_input()

model.I0[0] = -A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0

BC_pair = model.init_ff_input()

model.I0[0] = -A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0

BD_pair = model.init_ff_input()

ff_input = torch.cat((AC_pair, BD_pair, AD_pair, BC_pair))
print('ff_input', ff_input.shape, ff_input[0, 0, :4])
```

``` ipython
rates = model.forward(ff_input=ff_input).cpu().detach().numpy()
print(rates.shape)
```

``` ipython
memory = model.low_rank.U.cpu().detach().numpy()[model.slices[0], 0]
readout = model.low_rank.U.cpu().detach().numpy()[model.slices[0], 1]
# readout = model.low_rank.linear.weight.data[0].cpu().detach().numpy()
plot_overlap(rates, memory, readout, labels=['pair', 'unpair'])
```

``` ipython
idx = get_idx(model, 1)
plot_rates_selec(rates, idx)
```

``` ipython
print(rates.shape)
```

``` ipython
plot_m0_m1_phi(rates, idx)
```

``` ipython
print(rates.shape)
```

``` ipython
plt.plot(rates[:, :,1].T)
# plt.xlim([0, 10])
plt.show()
```

``` ipython
```

``` ipython
from matplotlib.patches import Circle
m0, m1, phi = decode_bump(rates[..., idx], axis=-1)

x = m1 / m0 * np.cos(phi)
y = m1 / m0 * np.sin(phi)

xA = x
yA = y

fig, ax = plt.subplots(1, 1, figsize=[height, height])

# ax.plot(xA.T[0], yA.T[0], 'x', alpha=.5, ms=10)
# ax.plot(xA.T, yA.T, '-', alpha=.5)
ax.plot(xA.T[-1], yA.T[-1], 'o', alpha=.5, ms=20)
# ax.set_xlim([-.9, .9])
# ax.set_ylim([-.9, .9])
circle = Circle((0., 0.), 1.8, fill=False, edgecolor='k')
ax.add_patch(circle)

# Set the aspect of the plot to equal to make the circle circular
ax.set_aspect('equal')
plt.savefig('fp_dpa.svg', dpi=300)
plt.show()
```

``` ipython
```

Go/NoGo
=======

Training
--------

``` ipython
# for param in model.low_rank.linear.parameters():
#     param.requires_grad = False

# model.low_rank.U.requires_grad = False
# model.low_rank.V.requires_grad = False
```

``` ipython
model.DURATION = 4
model.N_STEPS = int(model.DURATION / model.DT) + model.N_STEADY + model.N_WINDOW

model.T_STIM_ON =  [1.0, 3.0, 3.5]
model.T_STIM_OFF =  [2.0, 3.5, 4.0]

model.N_STIM_ON = np.array(
      [int(i / model.DT) + model.N_STEADY for i in model.T_STIM_ON]
  )

model.N_STIM_OFF = [int(i / model.DT) + model.N_STEADY for i in model.T_STIM_OFF]
```

``` ipython
model.LR_TRAIN = 1
model.LR_READOUT = 1
model.IF_RL = 1
model.RWD = 2
```

``` ipython
# Step indices from 0 to N_STEPS - N_STEADY, one every N_WINDOW steps
# (presumably one per recorded rate window -- TODO confirm against Network).
steps = np.arange(0, model.N_STEPS - model.N_STEADY, model.N_WINDOW)
# Windows from the onset of stimulus 0 to the onset of stimulus 1 (inclusive).
mask = (steps >= (model.N_STIM_ON[0] - model.N_STEADY)) & (steps <= (model.N_STIM_ON[1] - model.N_STEADY))

rwd_idx = np.where(mask)[0]
print('rwd', rwd_idx)

# Windows from the onset of stimulus 1 to the end of the trial.
mask_cue = (steps >= (model.N_STIM_ON[1] - model.N_STEADY))
cue_idx = np.where(mask_cue)[0]
print('cue', cue_idx)

# Everything from the onset of stimulus 0 onwards (the upper bound is disabled).
stim_mask = (steps >= (model.N_STIM_ON[0] - model.N_STEADY)) # & (steps < (model.N_STIM_OFF[0] - model.N_STEADY))

# Only the windows before the onset of stimulus 0 remain: passed to the loss
# as zero_idx.
mask_zero = ~mask & ~stim_mask
zero_idx = np.where(mask_zero)[0]
print('zero', zero_idx)

# The label tensors are built with the longer of the two windows.
model.lr_eval_win = np.max( (rwd_idx.shape[0], cue_idx.shape[0]))
```

``` ipython
# Temporarily switch the sample and distractor odors; ``odors`` keeps a copy
# of the original rows so they can be restored after training.
odors = model.odors.clone()
model.odors[0] = odors[1] # distractor Go
model.odors[4] = odors[4+1] # distractor NoGo

model.odors[1] = odors[2] # cue
model.odors[2] = odors[3] # rwd

model.N_BATCH = 80

A0 = 1
# NOTE(review): B0 was not defined anywhere in this notebook; the value is
# taken from the commented-out assignment that stood here -- confirm.
B0 = 1

# Go trials: positive sample amplitude, cue and reward on.
model.I0[0] = A0
model.I0[1] = float(B0) # cue
model.I0[2] = 1.0  # reward
model.I0[3] = 0

Go = model.init_ff_input()

# NoGo trials: negative sample amplitude, cue on, no reward.
model.I0[0] = -A0
model.I0[1] = float(B0) # cue
model.I0[2] = 0
model.I0[3] = 0

NoGo = model.init_ff_input()

ff_input = torch.cat((Go, NoGo))
print(ff_input.shape)
```

``` ipython
labels_Go = torch.ones((model.N_BATCH, model.lr_eval_win))
labels_NoGo = torch.zeros((model.N_BATCH, model.lr_eval_win))
labels = torch.cat((labels_Go, labels_NoGo))
labels =  labels.repeat((2, 1, 1))
labels = torch.transpose(labels, 0, 1)
print('labels', labels.shape)
```

``` ipython
batch_size = 16
train_loader, val_loader = split_data(ff_input, labels, train_perc=0.8, batch_size=batch_size)
```

``` ipython
criterion = DualLoss(alpha=1.0, thresh=5.0, N=model.Na[0], rwd_idx=rwd_idx, zero_idx=zero_idx, cue_idx=cue_idx)
criterion2 = AccuracyLoss(N=model.Na[0], rwd_idx=rwd_idx)

# SGD, Adam, Adam
learning_rate = 0.05
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
```

``` ipython
num_epochs = 15
start = perf_counter()
loss, val_loss = optimization(model, train_loader, val_loader, criterion, criterion2, optimizer, num_epochs, zero_grad=2)
end = perf_counter()
print("Elapsed (with compilation) = %dh %dm %ds" % convert_seconds(end - start))

# switching back sample and distractor odors
model.odors = odors
```

Test
----

``` ipython
model.RWD = 1
model.VERBOSE = 0
model.eval()
model.LR_READOUT = 0
```

``` ipython
odors = model.odors.clone()
model.odors[0] = odors[1] # distractor Go
model.odors[4] = odors[4+1] # distractor NoGo

model.odors[1] = odors[2] # cue
model.odors[2] = odors[3] # rwd
```

``` ipython
model.N_BATCH = 1

model.I0[0] = A0 # Go
model.I0[1] = float(B0) # cue
model.I0[2] = 1.0 # rwd

A = model.init_ff_input()

model.I0[0] = -A0 # NoGo
model.I0[1] = float(B0) # cue
model.I0[2] = 0 # rwd

B = model.init_ff_input()

ff_input = torch.cat((A, B))
print('ff_input', ff_input.shape)
```

``` ipython
rates = model.forward(ff_input=ff_input).cpu().detach().numpy()
model.odors = odors
print(rates.shape)
```

``` ipython
memory = model.low_rank.U.cpu().detach().numpy()[model.slices[0], 0]
readout = model.low_rank.U.cpu().detach().numpy()[model.slices[0], 1]
# readout = model.low_rank.linear.weight.data.cpu().detach().numpy()[0]
plot_overlap(rates, memory, readout, labels=['Go', 'NoGo'])
```

``` ipython
idx = get_idx(model, 1)
plot_rates_selec(rates, idx)
```

``` ipython
plot_m0_m1_phi(rates, idx)
```

Dual
====

``` ipython
model.DURATION = 8
model.N_STEPS = int(model.DURATION / model.DT) + model.N_STEADY + model.N_WINDOW
model.IF_RL = 1
model.RWD = 3
```

``` ipython
model.T_STIM_ON = [1.0, 3.0, 5.0, 5.5, 7.0]
model.T_STIM_OFF = [2.0, 4.0, 5.5, 6.0, 8.0]

model.N_STIM_ON = np.array(
    [int(i / model.DT) + model.N_STEADY for i in model.T_STIM_ON]
)

model.N_STIM_OFF = [int(i / model.DT) + model.N_STEADY for i in model.T_STIM_OFF]
```

Testing
-------

``` ipython
model.eval()
model.LR_READOUT = 0
```

``` ipython
model.N_BATCH = 1

A0 = 1
# NOTE(review): B0 was not defined anywhere in this notebook; the value is
# taken from the commented-out assignment that stood here -- confirm.
B0 = 1

# The four sample/test combinations differ only in the signs of I0[0]
# (sample) and I0[4] (test); distractor, cue and reward are the same.
model.I0[0] = A0 # sample A
model.I0[1] = A0 # distractor Go
model.I0[2] = float(B0) # cue
model.I0[3] = A0 # rwd
model.I0[4] = A0 # test

AC_pair = model.init_ff_input()

model.I0[0] = A0
model.I0[1] = A0
model.I0[2] = float(B0)
model.I0[3] = A0
model.I0[4] = -A0

AD_pair = model.init_ff_input()

model.I0[0] = -A0
model.I0[1] = A0
model.I0[2] = float(B0)
model.I0[3] = A0
model.I0[4] = A0

BC_pair = model.init_ff_input()

model.I0[0] = -A0
model.I0[1] = A0
model.I0[2] = float(B0)
model.I0[3] = A0
model.I0[4] = -A0

BD_pair = model.init_ff_input()

# Same-sign trials (AC, BD) first, then opposite-sign trials (AD, BC).
ff_input = torch.cat((AC_pair, BD_pair, AD_pair, BC_pair))
print('ff_input', ff_input.shape)
```

``` ipython
# Label 1 for the first half of the trials and 0 for the second half. The
# result of torch.ones must be bound to labels_pair, otherwise the cell
# silently reuses whatever labels_pair an earlier cell left behind.
labels_pair = torch.ones((2 * model.N_BATCH, 2))
labels_unpair = torch.zeros((2 * model.N_BATCH, 2))

labels = torch.cat((labels_pair, labels_unpair))
print('labels', labels.shape)
```

``` ipython
rates = model.forward(ff_input=ff_input).detach()
print(rates.shape)
```

``` ipython
# Move the labels to wherever the rates live instead of a hard-coded GPU id.
perf = performance_score(model, rates, labels.to(rates.device))
print(perf)
```

``` ipython
rates = rates.cpu().numpy()
memory = model.low_rank.U.cpu().detach().numpy()[model.slices[0], 0]
readout = model.low_rank.U.cpu().detach().numpy()[model.slices[0], 1]
# readout = model.low_rank.linear.weight.data.cpu().detach().numpy()[0]
plot_overlap(rates, memory, readout, labels=['pair', 'unpair'], figname='dual_naive_overlap.svg')
```

``` ipython
idx = get_idx(model, 1)
plot_rates_selec(rates, idx, figname='dual_naive_raster.svg')
```

``` ipython
plot_m0_m1_phi(rates, idx, figname='dual_naive_fourier.svg')
```

``` ipython
```

Fixed points
------------

``` ipython
model.DURATION = 20
model.N_STEPS = int(model.DURATION / model.DT) + model.N_STEADY + model.N_WINDOW
model.IF_RL = 0
```

``` ipython
model.eval()
model.LR_READOUT = 0
```

``` ipython
model.N_BATCH = 1

model.I0[0] = A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0

AC_pair = model.init_ff_input()

model.I0[0] = A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0

AD_pair = model.init_ff_input()

model.I0[0] = -A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0

BC_pair = model.init_ff_input()

model.I0[0] = -A0
model.I0[1] = 0
model.I0[2] = 0
model.I0[3] = 0

BD_pair = model.init_ff_input()

ff_input = torch.cat((AC_pair, BD_pair, AD_pair, BC_pair))
print('ff_input', ff_input.shape, ff_input[0, 0, :4])
```

``` ipython
rates = model.forward(ff_input=ff_input).cpu().detach().numpy()
print(rates.shape)
```

``` ipython
memory = model.low_rank.U.cpu().detach().numpy()[model.slices[0], 0]
readout = model.low_rank.U.cpu().detach().numpy()[model.slices[0], 1]
# readout = model.low_rank.linear.weight.data[0].cpu().detach().numpy()
plot_overlap(rates, memory, readout, labels=['pair', 'unpair'])
```

``` ipython
idx = get_idx(model, 1)
plot_rates_selec(rates, idx)
```

``` ipython
print(rates.shape)
```

``` ipython
plot_m0_m1_phi(rates, idx)
```

``` ipython
print(rates.shape)
```

``` ipython
plt.plot(rates[:, :,1].T)
# plt.xlim([0, 10])
plt.show()
```

``` ipython
```

``` ipython
from matplotlib.patches import Circle

# Decode the bump from the re-ordered rates, then map (m1 / m0, phi) to
# Cartesian coordinates.
m0, m1, phi = decode_bump(rates[..., idx], axis=-1)
x = m1 / m0 * np.cos(phi)
y = m1 / m0 * np.sin(phi)
xA, yA = x, y

fig, ax = plt.subplots(figsize=(height, height))

# Only the last entry along the transposed first axis is drawn (the end state).
# Optional overlays, currently disabled:
#   ax.plot(xA.T[0], yA.T[0], 'x', alpha=.5, ms=10)
#   ax.plot(xA.T, yA.T, '-', alpha=.5)
ax.plot(xA.T[-1], yA.T[-1], 'o', alpha=.5, ms=20)
# ax.set_xlim([-.9, .9])
# ax.set_ylim([-.9, .9])

# Reference circle centred on the origin.
circle = Circle((0., 0.), 1.8, fill=False, edgecolor='k')
ax.add_patch(circle)

# Equal aspect so the circle is drawn round.
ax.set_aspect('equal')
plt.savefig('fp_dual_naive.svg', dpi=300)
plt.show()
```

``` ipython
```

Training
--------

``` ipython
# Make the linear readout and both low-rank factors trainable.
for readout_param in model.low_rank.linear.parameters():
    readout_param.requires_grad_(True)

for factor in (model.low_rank.U, model.low_rank.V):
    factor.requires_grad_(True)
```

``` ipython
# Trial duration used for training, with the step count recomputed to match.
model.DURATION = 8
model.N_STEPS = int(model.DURATION / model.DT) + model.N_STEADY + model.N_WINDOW
model.IF_RL = 1

# Flags read by the network during training.
# NOTE(review): exact meaning of LR_TRAIN / LR_READOUT / RWD lives in src.network — confirm there.
model.LR_TRAIN = 1
model.LR_READOUT = 1
model.RWD = 3
```

``` ipython
# Start step of every output window, counted from the end of the
# steady-state period (hence the N_STEADY offsets below).
steps = np.arange(0, model.N_STEPS - model.N_STEADY, model.N_WINDOW)


def _stim_window(k):
    """Boolean mask over `steps` for stimulus k (onset inclusive, offset exclusive)."""
    onset = model.N_STIM_ON[k] - model.N_STEADY
    offset = model.N_STIM_OFF[k] - model.N_STEADY
    return (steps >= onset) & (steps < offset)


# rwd windows: from the offset of the last stimulus up to the end of the trial.
mask_rwd = (steps >= (model.N_STIM_OFF[-1] - model.N_STEADY)) & (steps <= (model.N_STEPS - model.N_STEADY))
rwd_idx = np.where(mask_rwd)[0]
print('rwd', rwd_idx)

# cue windows: from the offset of stimulus 1 up to the onset of the last stimulus.
mask_cue = (steps >= (model.N_STIM_OFF[1] - model.N_STEADY)) & (steps <= (model.N_STIM_ON[-1] - model.N_STEADY))
cue_idx = np.where(mask_cue)[0]
print('cue', cue_idx)

# One mask per stimulus. Previously the masks for stimuli 2 and 3 were both
# bound to `stim_mask2`, so the stimulus-2 window was overwritten and never
# reached `mask_zero`; each stimulus now has its own name.
stim_mask = _stim_window(0)
stim_mask1 = _stim_window(1)
stim_mask2 = _stim_window(2)
stim_mask3 = _stim_window(3)

# The last stimulus is left open-ended: everything from its onset onwards.
stim_mask4 = steps >= (model.N_STIM_ON[-1] - model.N_STEADY)

# "Zero" windows are those outside every stimulus, cue and rwd window.
mask_zero = ~(mask_rwd | mask_cue | stim_mask | stim_mask1 | stim_mask2 | stim_mask3 | stim_mask4)
zero_idx = np.where(mask_zero)[0]
print('zero', zero_idx)
```

``` ipython
model.N_BATCH = 80

# The evaluation window is the longer of the rwd and cue index ranges.
model.lr_eval_win = np.max((len(rwd_idx), len(cue_idx)))

# Every (sample, distractor, test) combination, in the order trials are stacked.
conditions = [(s, d, t) for s in (-1, 1) for d in (-1, 0, 1) for t in (1, -1)]

# (cue, rwd) amplitudes selected by the distractor value.
cue_rwd = {
    1: (float(B0), 1.0),   # Go
    -1: (float(B0), 0.0),  # NoGo
    0: (0, 0),             # no distractor: neither cue nor rwd
}

ff_input = []
# labels[0] marks pair trials (sample == test), labels[1] marks Go trials.
labels = np.zeros((2, len(conditions), model.N_BATCH, model.lr_eval_win))

for cond, (sample, distractor, test) in enumerate(conditions):
    model.I0[0] = sample      # sample
    model.I0[1] = distractor  # distractor
    model.I0[4] = test        # test
    model.I0[2], model.I0[3] = cue_rwd[distractor]  # cue, rwd

    if sample == test:  # Pair Trials
        labels[0, cond] = 1.0
    if distractor == 1:  # Go
        labels[1, cond] = 1.0

    ff_input.append(model.init_ff_input())

# Merge the condition and batch axes, then put the trial axis first.
labels = torch.tensor(labels, dtype=torch.float, device=DEVICE)
labels = labels.reshape(2, -1, model.lr_eval_win).transpose(0, 1)
ff_input = torch.vstack(ff_input)
print('ff_input', ff_input.shape, 'labels', labels.shape)
```

``` ipython
# 80/20 stratified train/validation split wrapped in DataLoaders (see split_data).
batch_size = 16
train_loader, val_loader = split_data(ff_input, labels, train_perc=0.8, batch_size=batch_size)
```

``` ipython
# Training loss, built from the cue / rwd / zero window indices computed above.
# Alternative kept for reference:
# criterion = nn.BCEWithLogitsLoss()
criterion = DualLoss(alpha=1.0, thresh=5.0, N=model.Na[0], cue_idx=cue_idx, rwd_idx=rwd_idx, zero_idx=zero_idx)
# Second criterion, handed to optimization() alongside the main loss.
criterion2 = AccuracyLoss(N=model.Na[0], rwd_idx=rwd_idx, cue_idx=cue_idx)

# Adam over all model parameters (SGD is a possible alternative).
learning_rate = 0.05
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
```

``` ipython
# Train for num_epochs and report the wall-clock time of the whole run.
num_epochs = 15
start = perf_counter()
loss, val_loss = optimization(model, train_loader, val_loader, criterion, criterion2, optimizer, num_epochs)
end = perf_counter()
# convert_seconds is expected to return an (hours, minutes, seconds) tuple for the format string.
print("Elapsed (with compilation) = %dh %dm %ds" % convert_seconds(end - start))
```

``` ipython
# Save the trained weights; the file name is tagged with `seed`.
# The 'models/' directory must already exist.
torch.save(model.state_dict(), 'models/dual_train_%d.pth' % seed)
```

``` ipython
odors = model.odors.cpu().numpy()
U = model.low_rank.U.cpu().detach().numpy()[model.slices[0], 0]
V = model.low_rank.V.cpu().detach().numpy()[model.slices[0], 0]
W = model.low_rank.linear.weight.data.cpu().detach().numpy()[0]

# Upper-triangular table of angle_AB between every pair of vectors;
# cells below the diagonal are filled with 'XXX'.
table_rows = [('U ', U), ('V ', V), ('W ', W), ('S ', odors[0]), ('D ', odors[1])]

print('   U  V  W  S  D')
for row, (row_label, row_vec) in enumerate(table_rows):
    cells = ['XXX'] * row
    cells.extend(angle_AB(row_vec, col_vec) for _, col_vec in table_rows[row:])
    print(row_label, *cells)


```

### Re-Testing

``` ipython
# Re-test with the same trial duration as training; recompute the step count to match.
model.DURATION = 8
model.N_STEPS = int(model.DURATION / model.DT) + model.N_STEADY + model.N_WINDOW
```

``` ipython
# Evaluation mode, with the LR_READOUT flag cleared
# (presumably forward() then returns rates rather than the readout — confirm).
model.eval()
model.LR_READOUT = 0
```

``` ipython
model.N_BATCH = 1

# (I0[0], I0[4]) amplitudes for the four trial types, in creation order.
# I0[1] and I0[3] are fixed at A0 and I0[2] at float(B0) for all of them.
retest_inputs = []
for sample_amp, test_amp in ((A0, A0), (A0, -A0), (-A0, A0), (-A0, -A0)):
    model.I0[0] = sample_amp
    model.I0[1] = A0
    model.I0[2] = float(B0)
    model.I0[3] = A0
    model.I0[4] = test_amp
    retest_inputs.append(model.init_ff_input())

AC_pair, AD_pair, BC_pair, BD_pair = retest_inputs

# Stack in the order AC, BD, AD, BC.
ff_input = torch.cat((AC_pair, BD_pair, AD_pair, BC_pair))
print('ff_input', ff_input.shape)
```

``` ipython
# Ones for the first 2*N_BATCH rows and zeros for the remaining 2*N_BATCH rows,
# following the stacking order of ff_input (AC, BD first, then AD, BC).
labels_A = torch.ones((2*model.N_BATCH, 2))
labels_B = torch.zeros((2*model.N_BATCH, 2))
labels = torch.cat((labels_A, labels_B))

print('labels', labels.shape)
```

``` ipython
# Run the trained network; the output stays a torch tensor (only detached)
# because performance_score in the next cell takes it together with device labels.
rates = model.forward(ff_input=ff_input).detach()
print(rates.shape)
```

``` ipython
# Score the trained network; labels are moved to DEVICE for the call.
perf = performance_score(model, rates, labels.to(DEVICE))
```

``` ipython
# Show the score computed by performance_score.
print(perf)
```

``` ipython
# Host-side NumPy copy of the rates for the plotting helpers.
rates = rates.cpu().detach().numpy()

# Columns 0 and 1 of the low-rank matrix U, restricted to model.slices[0].
lr_U = model.low_rank.U.detach().cpu().numpy()
pop_rows = model.slices[0]
memory = lr_U[pop_rows, 0]
readout = lr_U[pop_rows, 1]
# Alternative readout: model.low_rank.linear.weight.data[0].cpu().detach().numpy()
plot_overlap(rates, memory, readout, labels=['pair', 'unpair'], figname='dual_train_overlap.svg')
```

``` ipython
# Index array recomputed after training (the low-rank vectors have changed),
# then used to plot the rates.
idx = get_idx(model, 1)
plot_rates_selec(rates, idx, figname='dual_train_raster.svg')
```

``` ipython
# Same rates and ordering, passed to plot_m0_m1_phi with the trained-network figure name.
plot_m0_m1_phi(rates, idx, figname='dual_train_fourier.svg')
```

``` ipython
from matplotlib.patches import Circle

# Decode the bump from the re-ordered rates, then map (m1 / m0, phi) to
# Cartesian coordinates.
m0, m1, phi = decode_bump(rates[..., idx], axis=-1)
x = m1 / m0 * np.cos(phi)
y = m1 / m0 * np.sin(phi)
xA, yA = x, y

fig, ax = plt.subplots(figsize=(height, height))

# Start point ('x'), full trajectory ('-') and end point ('o').
ax.plot(xA.T[0], yA.T[0], 'x', alpha=.5, ms=10)
ax.plot(xA.T, yA.T, '-', alpha=.5)
ax.plot(xA.T[-1], yA.T[-1], 'o', alpha=.5, ms=10)
# ax.set_xlim([-.9, .9])
# ax.set_ylim([-.9, .9])

# Unit circle for reference.
circle = Circle((0., 0.), 1, fill=False, edgecolor='k')
ax.add_patch(circle)

# Equal aspect so the circle is drawn round.
ax.set_aspect('equal')

plt.show()
```

``` ipython
```

Fixed points
------------

``` ipython
# Lengthen the simulated duration and recompute the step count to match.
model.DURATION = 20
# Duration expressed in DT steps, plus the N_STEADY and N_WINDOW offsets.
model.N_STEPS = int(model.DURATION / model.DT) + model.N_STEADY + model.N_WINDOW
# NOTE(review): IF_RL presumably switches the reward-related inputs off — confirm in src.network.
model.IF_RL = 0
```

``` ipython
# Evaluation mode, with the LR_READOUT flag cleared
# (presumably forward() then returns rates rather than the readout — confirm).
model.eval()
model.LR_READOUT = 0
```

``` ipython
model.N_BATCH = 1

# Build one feed-forward input per condition. Only model.I0[0] is non-zero
# (+A0 for the two "A" conditions, -A0 for the two "B" conditions);
# entries 1-3 are zeroed before each call.
# NOTE(review): model.I0[4] is not reset here, so it keeps the value left by
# the re-testing cell — confirm this is intended for the fixed-point runs.
fp_inputs = []
for sample_amp in (A0, A0, -A0, -A0):
    model.I0[0] = sample_amp
    for slot in (1, 2, 3):
        model.I0[slot] = 0
    fp_inputs.append(model.init_ff_input())

AC_pair, AD_pair, BC_pair, BD_pair = fp_inputs

# Stack in the order AC, BD, AD, BC.
ff_input = torch.cat((AC_pair, BD_pair, AD_pair, BC_pair))
print('ff_input', ff_input.shape, ff_input[0, 0, :4])
```

``` ipython
# Run the trained network on the stacked inputs; detach and convert to NumPy for plotting.
rates = model.forward(ff_input=ff_input).cpu().detach().numpy()
print(rates.shape)
```

``` ipython
# Columns 0 and 1 of the low-rank matrix U, restricted to model.slices[0].
lr_U = model.low_rank.U.detach().cpu().numpy()
pop_rows = model.slices[0]
memory = lr_U[pop_rows, 0]
readout = lr_U[pop_rows, 1]
# Alternative readout: model.low_rank.linear.weight.data[0].cpu().detach().numpy()
plot_overlap(rates, memory, readout, labels=['pair', 'unpair'])
```

``` ipython
# Recompute the index array from get_idx and plot the rates with it
# (no figname is passed for this plot).
idx = get_idx(model, 1)
plot_rates_selec(rates, idx)
```

``` ipython
# Sanity check of the rates array layout.
print(rates.shape)
```

``` ipython
# Long-duration rates passed to plot_m0_m1_phi with the ordering in `idx`.
plot_m0_m1_phi(rates, idx)
```

``` ipython
# Sanity check of the rates array layout.
print(rates.shape)
```

``` ipython
# Index 1 along the last axis, transposed so each entry of the first axis
# becomes one line (assumes rates is (trial, time, unit) — TODO confirm).
plt.plot(rates[:, :,1].T)
# plt.xlim([0, 10])
plt.show()
```

``` ipython
```

``` ipython
from matplotlib.patches import Circle

# Decode the bump from the re-ordered rates, then map (m1 / m0, phi) to
# Cartesian coordinates.
m0, m1, phi = decode_bump(rates[..., idx], axis=-1)
x = m1 / m0 * np.cos(phi)
y = m1 / m0 * np.sin(phi)
xA, yA = x, y

fig, ax = plt.subplots(figsize=(height, height))

# Only the last entry along the transposed first axis is drawn (the end state).
# Optional overlays, currently disabled:
#   ax.plot(xA.T[0], yA.T[0], 'x', alpha=.5, ms=10)
#   ax.plot(xA.T, yA.T, '-', alpha=.5)
ax.plot(xA.T[-1], yA.T[-1], 'o', alpha=.5, ms=20)
# ax.set_xlim([-.9, .9])
# ax.set_ylim([-.9, .9])

# Reference circle centred on the origin.
circle = Circle((0., 0.), 1.7, fill=False, edgecolor='k')
ax.add_patch(circle)

# Equal aspect so the circle is drawn round.
ax.set_aspect('equal')
plt.savefig('fp_dual_train.svg', dpi=300)
plt.show()
```

``` ipython
# Show the amplitude B0 that the cells above assign to model.I0[2] (the cue input).
print(float(B0))
```