Notebook Settings
=================

``` ipython
# Load the autoreload extension and set it to mode 2.
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

# Run the shared notebook setup script in this namespace.
# NOTE(review): later cells use np, plt, sns, width and height without importing
# them here; presumably setup.py provides them — confirm.
%run ../../../notebooks/setup.py
%matplotlib inline
%config InlineBackend.figure_format = 'png'

# NOTE(review): REPO_ROOT is assigned again with the same value in the Model section.
REPO_ROOT = "/home/leon/models/NeuroFlame"
pal = sns.color_palette("tab10")
```

Imports
=======

``` ipython
import torch
import torch.nn as nn
import torch.optim as optim
# import torchmetrics
import torch.nn.functional as F
from torch.utils.data import Dataset, TensorDataset, DataLoader
```

``` ipython
import sys
sys.path.insert(0, '../../../')

import pandas as pd
import torch.nn as nn
from time import perf_counter
from scipy.stats import circmean

from src.network import Network
from src.plot_utils import plot_con
from src.decode import decode_bump, circcvl, decode_bump_torch
from src.lr_utils import masked_normalize, clamp_tensor, normalize_tensor
```

Helpers
=======

Plots
-----

``` ipython
def add_vlines(model, ax=None):
    """Shade every stimulus epoch of ``model`` as a translucent vertical band.

    Band ``i`` spans ``model.T_STIM_ON[i]`` to ``model.T_STIM_OFF[i]`` with
    ``alpha=0.25``. The bands are drawn on ``ax`` when one is given, otherwise
    through the module-level ``plt`` interface.
    """
    n_epochs = len(model.T_STIM_ON)
    # Both pyplot and Axes expose axvspan, so a single loop serves both cases.
    canvas = plt if ax is None else ax
    for k in range(n_epochs):
        canvas.axvspan(model.T_STIM_ON[k], model.T_STIM_OFF[k], alpha=0.25)

```

``` ipython
def plot_rates_selec(rates, idx=0, thresh=0.5, figname='fig.svg'):
    """Show the rate maps of two randomly drawn trials side by side.

    Each panel is an ``imshow`` of ``rates[k].T`` for a trial ``k`` drawn with
    ``np.random.randint``. The colour range of a panel is clipped to the
    5th-95th percentiles of that trial.

    Args:
        rates: array indexed by trial along its first axis; every ``rates[k]``
            must be 2-D (assumed (step, neuron) from the axis labels — TODO confirm).
        idx: accepted for backward compatibility; the displayed trials are
            always drawn at random.
        thresh: accepted for backward compatibility; not used.
        figname: accepted for backward compatibility; the figure is not saved.
    """
    n_trials = len(rates)
    fig, axes = plt.subplots(1, 2, figsize=[2*width, height])

    for panel in axes:
        # Sample among the trials that actually exist; the previous hard-coded
        # upper bound of 96 raised IndexError for smaller batches.
        trial = np.random.randint(0, n_trials)
        vmin, vmax = np.percentile(rates[trial].reshape(-1), [5, 95])

        panel.imshow(rates[trial].T, aspect='auto', cmap='jet', vmin=vmin, vmax=vmax)
        panel.set_ylabel('Neuron #')
        panel.set_xlabel('Step')

    plt.show()
```

``` ipython
def plot_m0_m1_phi(model, rates, idx, figname='fig.svg'):
    """Plot the decoded bump moments and phase over time for 16 random trials.

    ``decode_bump_torch(rates, axis=-1, RET_TENSOR=0)`` yields ``m0``, ``m1``
    and ``phi``. Each is plotted against a time axis running from 0 to
    ``model.DURATION``, with the stimulus epochs shaded by ``add_vlines``.
    The phase is converted from radians to degrees. The figure is written to
    ``figname`` and then shown.

    Args:
        model: object exposing ``DURATION``, ``T_STIM_ON`` and ``T_STIM_OFF``.
        rates: activity passed to ``decode_bump_torch``.
        idx: accepted for backward compatibility; the plotted trials are always
            drawn at random.
        figname: output path handed to ``plt.savefig``.
    """
    m0, m1, phi = decode_bump_torch(rates, axis=-1, RET_TENSOR=0)
    print(m0.shape, m1.shape, phi.shape)

    fig, ax = plt.subplots(1, 3, figsize=[2*width, height])

    xtime = np.linspace(0, model.DURATION, m0.shape[-1])
    # Sample among the trials that actually exist; the previous hard-coded
    # upper bound of 96 raised IndexError for smaller batches.
    idx = np.random.randint(0, m0.shape[0], 16)

    # Raw strings keep the LaTeX backslash without an invalid-escape warning.
    moment_panels = (
        (ax[0], m0, r'$\mathcal{F}_0$ (Hz)'),
        (ax[1], m1, r'$\mathcal{F}_1$ (Hz)'),
    )
    for panel, moment, ylabel in moment_panels:
        panel.plot(xtime, moment[idx].T)
        panel.set_ylabel(ylabel)
        panel.set_xlabel('Time (s)')
        add_vlines(model, panel)

    ax[2].plot(xtime, phi[idx].T * 180 / np.pi, alpha=.5)
    ax[2].set_ylim([0, 360])
    ax[2].set_yticks([0, 90, 180, 270, 360])
    ax[2].set_ylabel('Phase (°)')
    ax[2].set_xlabel('Time (s)')
    add_vlines(model, ax[2])

    plt.savefig(figname, dpi=300)
    plt.show()
```

Data Split
----------

``` ipython
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit

def split_data(X, Y, train_perc=0.8, batch_size=32, shuffle=True):
    """Split ``(X, Y)`` sequentially into train/validation loaders.

    The split itself is never shuffled or stratified: the first ``train_perc``
    fraction of samples becomes the training set and the rest the validation
    set. A histogram of the valid targets (entries different from ``-999``,
    converted from radians to degrees) of both sets is shown, and the shapes
    of the four resulting tensors are printed.

    Args:
        X: input tensor, samples along the first axis.
        Y: target tensor, samples along the first axis; ``-999`` marks entries
            excluded from the histogram.
        train_perc: fraction of samples assigned to the training set.
        batch_size: mini-batch size of both loaders.
        shuffle: whether the *training loader* reshuffles its samples; the
            validation loader never shuffles.

    Returns:
        ``(train_loader, val_loader)``.
    """
    X_train, X_test, Y_train, Y_test = train_test_split(X, Y,
                                                        train_size=train_perc,
                                                        stratify=None,
                                                        shuffle=False)

    plt.hist(Y_train[Y_train!=-999].cpu() * 180 / np.pi, bins=15, label='train')
    plt.hist(Y_test[Y_test!=-999].cpu() * 180 / np.pi, bins=15, label='test')
    plt.xlabel('Target Loc. (°)')
    plt.ylabel('Count')
    # The two histograms carry labels; without a legend they were never shown.
    plt.legend()
    plt.show()

    print(X_train.shape, X_test.shape)
    print(Y_train.shape, Y_test.shape)

    train_dataset = TensorDataset(X_train, Y_train)
    val_dataset = TensorDataset(X_test, Y_test)

    # Create data loaders
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=shuffle)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader
```

Optimization
------------

``` ipython
def prune_and_grow_vectorized(W, mask, K):
    """Return a copy of ``mask`` after one prune-and-regrow step per column.

    For every column, the ``K`` active entries (``mask != 0``) with the
    smallest ``|W|`` are switched off. Then ``K`` entries are switched on,
    chosen uniformly at random among the entries that are off after pruning
    (this includes the ones just pruned). ``mask`` itself is not modified.

    Args:
        W: 2-D weight tensor.
        mask: tensor of the same shape as ``W``; zero means inactive.
        K: number of entries to prune and regrow per column. A Python number
            or a 0-dim tensor is accepted.

    Returns:
        The updated mask, same shape/dtype/device as ``mask``.
    """
    # torch.topk needs a Python int; callers may pass a 0-dim tensor.
    K = int(K)
    new_mask = mask.clone()

    # --- PRUNE ---
    big_num = 1e9
    # Inactive weights get a huge magnitude so they are not selected for pruning
    to_prune = W.abs().clone()
    to_prune[new_mask == 0] = big_num

    # Indices of the K smallest magnitudes per column (negated for topk)
    _, prune_idx = torch.topk(-to_prune, K, dim=0)

    # Column index for each selected row, taken from W itself (shape and
    # device) instead of an undefined global N.
    n_cols = W.shape[1]
    cols = torch.arange(n_cols, device=W.device).view(1, -1).expand(K, -1)   # shape (K, n_cols)
    new_mask[prune_idx, cols] = 0.0

    # --- GROW ---
    # Regrow: randomly select K locations per column where mask == 0
    grow_candidates = (new_mask == 0).float()
    # Random score for each eligible position, hugely negative for the others
    grow_scores = torch.rand_like(W) * grow_candidates + (1 - grow_candidates) * (-big_num)
    _, grow_idx = torch.topk(grow_scores, K, dim=0)
    new_mask[grow_idx, cols] = 1.0

    return new_mask
```

``` ipython
def training_step(dataloader, model, loss_fn, optimizer, SET=0):
    """Run one training epoch and return the mean loss per mini-batch.

    For every ``(X, y)`` batch the inputs are moved to the target device, the
    model output is scored with ``loss_fn`` and one optimizer step is taken.

    Args:
        dataloader: iterable of ``(X, y)`` batches that supports ``len``.
        model: module called as ``model(X)``.
        loss_fn: callable ``loss_fn(rates, y)`` returning a scalar tensor.
        optimizer: optimizer over the model parameters.
        SET: sparse prune-and-grow switch; only the falsy default is supported.

    Returns:
        Sum of the batch losses divided by the number of batches.

    Raises:
        NotImplementedError: if ``SET`` is truthy.
    """
    if SET:
        # The previous SET branch called
        # prune_and_grow_vectorized(W, mask, model.Ka[0]) with ``W`` and
        # ``mask`` never defined in this scope, so it crashed only after the
        # first optimizer step had already modified the model. Fail before
        # touching anything instead.
        raise NotImplementedError(
            "SET (prune-and-grow) is not wired up: the weight matrix and mask "
            "to rewire are not defined in training_step"
        )

    device = torch.device(DEVICE if torch.cuda.is_available() else "cpu")

    model.train()
    total_loss = 0.0
    total_batches = len(dataloader)

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        optimizer.zero_grad()

        rates = model(X)
        loss = loss_fn(rates, y)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / total_batches
    return avg_loss
```

``` ipython
def validation_step(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and return the per-sample mean loss.

    Each batch loss is weighted by the batch size and the total is divided by
    ``len(dataloader.dataset)``. The model is put in eval mode and gradients
    are disabled for the whole pass.
    """
    n_samples = len(dataloader.dataset)

    if torch.cuda.is_available():
        device = torch.device(DEVICE)
    else:
        device = torch.device("cpu")

    model.eval()
    weighted_total = 0.0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            # weight by batch size so a smaller last batch is not over-counted
            weighted_total += loss_fn(outputs, targets).item() * inputs.size(0)

    return weighted_total / n_samples
```

``` ipython
def optimization(model, train_loader, val_loader, loss_fn, optimizer, num_epochs=100, thresh=0.005, gamma=0.9):
    """Train ``model`` for up to ``num_epochs`` epochs with early stopping.

    Every epoch runs ``training_step`` and ``validation_step``, steps an
    exponential learning-rate schedule and prints both losses. Training stops
    early when

    * both losses are below ``thresh``,
    * the validation loss exceeds 300, or
    * either loss is NaN.

    Args:
        model: network to train; it is moved to the target device first.
        train_loader: loader passed to ``training_step``.
        val_loader: loader passed to ``validation_step``.
        loss_fn: loss callable shared by both steps.
        optimizer: optimizer over the model parameters.
        num_epochs: maximum number of epochs.
        thresh: convergence threshold applied to both losses.
        gamma: multiplicative decay of ``ExponentialLR``.

    Returns:
        ``(loss_list, val_loss_list)``, one entry per completed epoch.
    """
    import math

    # Alternatives that were tried: ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.1)
    # and StepLR(optimizer, step_size=30, gamma=0.1).
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

    device = torch.device(DEVICE if torch.cuda.is_available() else 'cpu')
    model.to(device)

    loss_list = []
    val_loss_list = []

    for epoch in range(num_epochs):
        loss = training_step(train_loader, model, loss_fn, optimizer)
        val_loss = validation_step(val_loader, model, loss_fn)

        # ReduceLROnPlateau is the only scheduler that needs the monitored
        # metric; kept so the scheduler above can be swapped freely.
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(val_loss)
        else:
            scheduler.step()

        loss_list.append(loss)
        val_loss_list.append(val_loss)

        print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {loss:.4f}, Validation Loss: {val_loss:.4f}')

        if val_loss < thresh and loss < thresh:
            print(f'Stopping training as loss has fallen below the threshold: {loss}, {val_loss}')
            break

        if val_loss > 300:
            print(f'Stopping training as loss is too high: {val_loss}')
            break

        # Both losses are plain floats; a NaN validation loss compares False
        # against every threshold above, so it has to be tested explicitly.
        if math.isnan(loss) or math.isnan(val_loss):
            print('Stopping training as loss is NaN.')
            break

    return loss_list, val_loss_list
```

Loss
----

``` ipython
import torch
import torch.nn as nn

class CircularAngleLoss(nn.Module):
    """Loss between two angle tensors (radians) that respects periodicity.

    ``mode='angular'`` scores ``1 - cos(pred - target)`` and applies
    ``reduction`` ('mean', 'sum', anything else leaves it element-wise).
    ``mode='polar'`` averages the MSE between the sines and the MSE between
    the cosines, using ``nn.MSELoss(reduction=reduction)``.
    Any other mode raises ``ValueError`` when the loss is evaluated.
    """

    def __init__(self, mode='angular', reduction='mean'):
        super().__init__()
        self.mode = mode
        self.reduction = reduction
        self.mse = nn.MSELoss(reduction=reduction)

    def _reduce(self, values):
        """Apply the configured reduction to an element-wise error tensor."""
        if self.reduction == 'mean':
            return values.mean()
        if self.reduction == 'sum':
            return values.sum()
        return values

    def forward(self, pred_angle, target_angle):
        if self.mode == 'polar':
            sin_term = self.mse(torch.sin(pred_angle), torch.sin(target_angle))
            cos_term = self.mse(torch.cos(pred_angle), torch.cos(target_angle))
            return (sin_term + cos_term) / 2

        if self.mode == 'angular':
            return self._reduce(1 - torch.cos(pred_angle - target_angle))

        raise ValueError(f"Unknown loss mode: {self.mode}")
```

``` ipython
import torch
import torch.nn as nn
import torch.distributions

class VonMisesNLLLoss(nn.Module):
    """Negative log-likelihood of targets under a von Mises centred on the prediction.

    The distribution is ``VonMises(pred_angle, kappa)`` with a fixed
    concentration ``kappa``. ``reduction`` is 'mean', 'sum', or anything else
    for the element-wise values (the default).
    """

    def __init__(self, kappa=4.0, reduction='none'):
        super().__init__()
        self.kappa = kappa
        self.reduction = reduction

    def forward(self, pred_angle, target_angle):
        # both angle tensors are in radians and share one shape
        distribution = torch.distributions.VonMises(pred_angle, self.kappa)
        neg_log_lik = -distribution.log_prob(target_angle)

        if self.reduction == 'mean':
            return neg_log_lik.mean()
        if self.reduction == 'sum':
            return neg_log_lik.sum()
        return neg_log_lik
```

``` ipython
import torch
import torch.nn as nn
import torch.nn.functional as F

class AngularErrorLoss(nn.Module):
    """Training loss on the bump decoded from the network readout.

    ``decode_bump_torch`` turns the readout into ``m0``, ``m1`` and a decoded
    angle. Entries of ``theta_batch`` equal to ``-999`` are "invalid"; all
    others are "valid". The loss is the sum of

    * on valid entries, normalised by their count:
        - the circular error ``1 - cos(target - decoded)``,
        - ``reg_tuning * relu(thresh * m0 - m1)``,
        - ``relu(0.5 - m0)`` and ``relu(0.5 - m1)``;
    * on invalid entries, normalised by their count: ``m1 ** 2``.

    Each penalty is counted exactly once. The earlier implementation
    accumulated the penalties in place (``regularization += ...``) before
    summing, which re-added the first two penalties on every subsequent line
    (effective weights ``reg_tuning + 2``, 2 and 1). Expect different loss
    values than with that version.

    Args:
        thresh: factor applied to ``m0`` in the tuning-strength penalty.
        reg_tuning: weight of the tuning-strength penalty.
    """

    def __init__(self, thresh=1.0, reg_tuning=0.1):
        super().__init__()

        self.loss = nn.MSELoss(reduction='none')
        self.polar_loss = CircularAngleLoss(reduction='none')

        self.thresh = thresh
        self.reg_tuning = reg_tuning

    def forward(self, readout, theta_batch):
        m0, m1, y_pred = decode_bump_torch(readout, axis=-1, device=readout.device)

        valid_mask = theta_batch != -999
        invalid_mask = ~valid_mask

        # A batch with no valid (or no invalid) entry would otherwise give
        # 0 / 0 = NaN; the masked sums are 0 in that case, so dividing by 1
        # makes the corresponding term vanish.
        n_valid = valid_mask.sum().clamp(min=1)
        n_invalid = invalid_mask.sum().clamp(min=1)

        # angular loss on the decoded angle
        angular = (self.polar_loss(theta_batch, y_pred) * valid_mask).sum()

        # imposing tuning strength
        tuning = (F.relu(self.thresh * m0 - m1) * valid_mask).sum()

        # keeping both moments above 0.5
        floor_m0 = (F.relu(0.5 - m0) * valid_mask).sum()
        floor_m1 = (F.relu(0.5 - m1) * valid_mask).sum()

        # normalize over the valid batch/time entries
        total_loss = (angular + self.reg_tuning * tuning + floor_m0 + floor_m1) / n_valid

        # imposing zero tuning in invalid mask
        loss_zero = self.loss(m1, 0.0 * m1) * invalid_mask
        total_loss = total_loss + loss_zero.sum() / n_invalid

        return total_loss
```

Other
-----

``` ipython
import pickle as pkl

def pkl_save(obj, name, path="."):
    """Pickle ``obj`` to ``<path>/<name>.pkl``.

    The file is opened in a ``with`` block so the handle is flushed and closed
    even if pickling fails part-way.
    """
    with open(path + "/" + name + ".pkl", "wb") as handle:
        pkl.dump(obj, handle)


def pkl_load(name, path="."):
    """Load and return the object pickled in ``<path>/<name>.pkl``.

    The file is opened in a ``with`` block so the handle is always closed.
    """
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    with open(path + "/" + name + '.pkl', "rb") as handle:
        return pkl.load(handle)

```

``` ipython
def convert_seconds(seconds):
    """Break a duration in seconds into an ``(hours, minutes, seconds)`` tuple.

    The components keep the numeric type of the input (floats stay floats).
    """
    hours, within_hour = divmod(seconds, 3600)
    minutes = within_hour // 60
    leftover = seconds % 60
    return hours, minutes, leftover
```

Model
=====

``` ipython
# Keyword overrides forwarded to Network(...) together with the YAML config.
# NOTE(review): the meaning of each key is defined in src.network, which is not
# visible here; the group labels below are read off the key names — confirm.
kwargs = {
    'TRAINING': 1,
    'LR_INI': 0.001,

    # presumably global gain and simulation timing
    'GAIN': 1.0,
    'DURATION': 13.0,
    'T_STEADY': 10,

    # stimulus on/off times; add_vlines shades [T_STIM_ON[i], T_STIM_OFF[i]] on the plots
    'T_STIM_ON': [4.0, 8.0],
    'T_STIM_OFF': [5.0, 9.0],

    'Jab': [1.0, -1.5, 1.0, -1],
    'Ja0': [2.0, 1.0],

    # presumably stimulus parameters; PHI0 is overwritten per trial further down
    'STIM_EI': 0,
    'I0': [1.0, -10.0],
    'PHI0': [180.0, 180],
    'SIGMA0': [1.0, 0.0],
    'M0': 1.0,

    'RANDOM_DELAY': 1,
    'MIN_DELAY': 1,
    'MAX_DELAY': 6,

    # presumably short-term plasticity (STP) settings
    'IF_STP': 1,
    'IS_STP': [1, 0, 0, 0],
    'J_STP': 1.0,
    'USE': [0.03, 0.03, 0.03, 0.1],
    'TAU_FAC': [2.0, 2.0, 2.0, 0.0],
    'TAU_REC': [0.2, 0.2, 0.2, 0.1],
    'W_STP': [1.0, 3.0, 4.0, 1.0],

    'DT': 0.02,
    'RATE_DYN': 1,
    'TAU': [0.2, 0.1],

    'SYN_DYN': 0,
    'TAU_SYN': [0.2, 0.1],

    'IF_NMDA': 1,
    'R_NMDA': 1.0,
    'TAU_NMDA': [0.5, 0.5],

    'IF_ADAPT': 0,
    'A_ADAPT': 1.0,
    'TAU_ADAPT': 100.0,
}
```

``` ipython
# Run configuration: repository root, YAML config name and target device.
REPO_ROOT = "/home/leon/models/NeuroFlame"
conf_name = "train_odr_EI.yml"
# conf_name = "train_vanilla.yml"
DEVICE = 'cuda:1'

# NOTE(review): despite its name, total_batches holds a sample count here (128 * 5);
# it is re-assigned to the number of mini-batches in the "Inputs and Labels" cell.
total_batches = 128 * 5
batch_size = 128

# Number of whole mini-batches that fit in that sample count.
ratio = total_batches // batch_size

# Total number of simulated trials, a whole multiple of batch_size.
N_BATCH = int(batch_size * ratio)
print('N_BATCH', N_BATCH, 'batch_size', batch_size)

# NOTE(review): the random seed drawn on the next line is immediately overwritten
# by the fixed seed below; the draw still advances NumPy's global RNG.
seed = np.random.randint(0, 1e6)
seed = 0
print('seed', seed)
```

``` ipython
# Build the network from the YAML config, with the overrides collected in kwargs.
model = Network(conf_name, REPO_ROOT, VERBOSE=0, DEVICE=DEVICE, SEED=seed, N_BATCH=N_BATCH, **kwargs)
```

Training
========

``` ipython
# Enable gradient tracking on model.J_STP.
model.J_STP.requires_grad = True
```

### Parameters

``` ipython
# Print the name and shape of every parameter that has requires_grad set.
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.shape)
```

``` ipython
# Boolean mask of shape (N_BATCH, n_windows); True marks the windows whose target
# angle is used as a label (the other windows are set to -999 in a later cell).
model.N_BATCH = N_BATCH
rwd_mask = torch.zeros((model.N_BATCH, int((model.N_STEPS-model.N_STEADY) / model.N_WINDOW)), device=DEVICE, dtype=torch.bool)
print('rwd_mask', rwd_mask.shape)

for i in range(model.N_BATCH):
    # from first stim onset to second stim onset
    # start_indices are shifted by N_STEADY and divided by N_WINDOW to express
    # them as window indices; torch.arange then enumerates the windows in between.
    mask = torch.arange((model.start_indices[0, i] - model.N_STEADY)/ model.N_WINDOW,
                        (model.start_indices[1, i] - model.N_STEADY) / model.N_WINDOW).to(torch.int)
    # print(mask)
    rwd_mask[i, mask] = True

# Sanity check: print the selected window indices of one random trial.
idx = np.random.randint(N_BATCH)
print(torch.where(rwd_mask[idx]==1)[0])
```

### Inputs and Labels

``` ipython
# Number of mini-batches (overwrites the sample count stored under this name earlier).
total_batches = N_BATCH // batch_size

print('total_batches', N_BATCH // batch_size)

# Draw one integer target angle in [0, 360) degrees per trial, one mini-batch at a time.
labels = []
for _ in range(total_batches):
    batch_labels = torch.randint(0, 360, (batch_size, 1)).to(DEVICE)
    labels.append(batch_labels)

# Stack into a single (N_BATCH, 1) tensor.
labels = torch.cat(labels, dim=0)
print(labels.shape)
```

``` ipython
# Per-trial stimulus angles: entry 0 of PHI0 carries the target in radians,
# entry 1 keeps the initial value of 1.
model.PHI0 = torch.ones((N_BATCH, 2, 1), device=DEVICE, dtype=torch.float)
model.PHI0[:, 0] = labels * np.pi / 180.0

# Repeat each target (in radians) over all windows, then mark every window
# outside rwd_mask with the sentinel -999 that the loss treats as "invalid".
window_size = int((model.N_STEPS-model.N_STEADY) / model.N_WINDOW)
labels = labels.repeat(1, window_size) * np.pi / 180.0
labels[~rwd_mask] = -999

# Build the feed-forward input (after PHI0 has been set above).
ff_input = model.init_ff_input()
print(model.PHI0.shape, ff_input.shape, labels.shape)
```

``` ipython
# Histogram of the valid (non -999) targets, converted back to degrees.
plt.hist(labels[labels!=-999].cpu() * 180 / np.pi, bins=15)
plt.xlabel('Target Loc. (°)')
plt.show()
```

### Run

``` ipython
# Sequential 80/20 split; shuffle=False also keeps the training loader in order.
train_loader, val_loader = split_data(ff_input, labels, train_perc=0.8, batch_size=batch_size, shuffle=False)
```

``` ipython
# Loss with thresh=1.5 (factor on m0 in the tuning penalty) and Adam over all model parameters.
criterion = AngularErrorLoss(thresh=1.5)
learning_rate = 0.1
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
```

``` ipython
# Train and time the run.
num_epochs = 20
start = perf_counter()
# NOTE: optimization returns a (loss_list, val_loss_list) tuple, so `loss` holds both histories.
loss = optimization(model, train_loader, val_loader, criterion, optimizer, num_epochs, thresh=.005)
end = perf_counter()
# convert_seconds returns (h, m, s); %d truncates the float components for display.
print("Elapsed (with compilation) = %dh %dm %ds" % convert_seconds(end - start))
```

``` ipython
# Save the trained weights under a file name that carries the seed.
torch.save(model.state_dict(), '../models/odr/odr_%d.pth' % seed)
```

``` ipython
# Call the project's cache-clearing helper (implementation in src.utils, not visible here).
from src.utils import clear_cache
clear_cache()
```

Testing
=======

``` ipython
# Reload the weights saved for this seed and switch the model to eval mode.
model_state_dict = torch.load('../models/odr/odr_%d.pth' % seed);
print(model_state_dict.keys())
# Compares the in-memory Wab_train with the saved one *before* loading the state dict.
print(torch.allclose(model.Wab_train, model_state_dict['Wab_train']))
model.load_state_dict(model_state_dict);
model.eval();
print(model.J_STP)
```

``` ipython
# Build a fresh test set: one random integer-degree target per trial, stored in radians.
with torch.no_grad():
    model.N_BATCH = N_BATCH

    labels = torch.randint(0, 360, (N_BATCH, 1)).to(DEVICE) * torch.pi / 180.0
    # Entry 0 of PHI0 carries the target angle; entry 1 keeps the initial value of 1.
    model.PHI0 = torch.ones((N_BATCH, 2, 1), device=DEVICE, dtype=torch.float)
    model.PHI0[:, 0] = labels


    ff_input = model.init_ff_input()
    print(model.PHI0.shape, ff_input.shape, labels.shape)
```

``` ipython
# Test targets in degrees, on the CPU for plotting.
target_loc = labels[:, 0].cpu() * 180 / np.pi

# Normalised histogram of the test targets with ticks every 90 degrees.
plt.hist(target_loc, bins='auto', density=True)
plt.xlabel('Target Loc. (°)')
plt.ylabel('Density')
plt.xticks(np.linspace(0, 360, 5))
# plt.savefig('./figs/memhist/targets.svg', dpi=300)
plt.show()
```

``` ipython
# Simulate the test trials without gradients and bring the rates to NumPy.
with torch.no_grad():
    rates = model.forward(ff_input=ff_input).cpu().detach().numpy()
print('rates', rates.shape)
```

``` ipython
# NOTE(review): plot_rates_selec picks its two trials at random, so idx has no effect on what is shown.
plot_rates_selec(rates=rates, idx=20, thresh=20)
```

``` ipython
# NOTE(review): plot_m0_m1_phi draws its trials at random, so the idx argument (4) has no effect.
plot_m0_m1_phi(model, rates, 4)
```

``` ipython
# targets = (target_loc + np.pi) % (2 * np.pi) - np.pi

# fig, ax = plt.subplots(1, 2, figsize=[2*width, height])
# # ax[0].hist(targets[:, 0] * 180 / np.pi , bins=32 , histtype='step')
# ax[0].hist(errors2, bins=32, histtype='step')
# ax[0].set_xlabel('Encoding Errors (°)')

# ax[1].hist(errors, bins=32)
# ax[1].set_xlabel('Memory Errors (°)')
# # ax[1].set_xlim([-45, 45])
# plt.show()
```

Connectivity
============

``` ipython
# Connectivity of population 0 (block slices[0] x slices[0]) as a NumPy array.
from src.lr_utils import clamp_tensor
# Cij = model.GAIN * model.J_STP * (model.W_stp_T[0]  / torch.sqrt(model.Ka[0])
#                                 + model.Wab_train[model.slices[0], model.slices[0]]
#                                 * torch.sqrt(model.Ka[0]) / model.Na[0])

# Fixed matrix W_stp_T[0], scaled by GAIN * J_STP / sqrt(Ka[0]), modulated
# multiplicatively by (1 + trained weights).
Cij = model.GAIN * model.J_STP * model.W_stp_T[0] / torch.sqrt(model.Ka[0]) * (1.0 + model.Wab_train[model.slices[0], model.slices[0]])


# Cij = model.GAIN * ( model.W_stp_T[0]  + model.Wab_train[model.slices[0], model.slices[0]])
# clamp_tensor is a project helper (src.lr_utils); its exact effect is not visible here.
Cij = clamp_tensor(Cij, 0, model.slices).cpu().detach().numpy()
# Cij = Cij>0
```

``` ipython
# Summary figure for Cij: the matrix itself, its smoothed column and row sums,
# and the mean value along each upper off-diagonal.
from scipy.ndimage import gaussian_filter1d, uniform_filter1d

plt.figure(figsize=(2.5*width, 1.5*height))  # Set the figure size (width, height) in inches

# First column, both rows: the matrix (rows labelled postsynaptic, columns presynaptic)
ax1 = plt.subplot2grid((2, 3), (0, 0), rowspan=2)
im = ax1.imshow(Cij, cmap='jet', aspect=1, vmin=0)
ax1.set_xlabel("Presynaptic")
ax1.set_ylabel("Postsynaptic")

# Second column, first row
ax2 = plt.subplot2grid((2, 3), (0, 1))
Kj = np.sum(Cij, axis=0)  # sum over rows (the postsynaptic axis of the image): one value per column j
ax2.plot(uniform_filter1d(Kj, size=75))
# ax2.set_xticklabels([])
ax2.set_ylabel("$K_j$")

# Second column, second row
ax3 = plt.subplot2grid((2, 3), (1, 1))
Ki = np.sum(Cij, axis=1)  # sum over columns (the presynaptic axis of the image): one value per row i
ax3.plot(uniform_filter1d(Ki, size=75))
ax3.set_ylabel("$K_i$")

# Third column, both rows: sum of each upper off-diagonal divided by the matrix size.
# NOTE(review): the diagonal at offset i holds only N - i entries and is not
# wrapped around, yet it is divided by N — confirm this is intended.
ax4 = plt.subplot2grid((2, 3), (0, 2), rowspan=2)
diags = []
for i in range(int(Cij.shape[0] / 2)):
   diags.append(np.trace(Cij, offset=i) / Cij.shape[0])
diags = np.array(diags)
ax4.plot(diags)
ax4.set_xlabel("Neuron #")
ax4.set_ylabel("$P_{ij}$")

plt.tight_layout()
plt.show()
```

``` ipython
# Shuffled control: same entries as Cij, randomly permuted over the whole matrix
# (flatten() returns a copy, so Cij itself is left untouched).
Dij = Cij.flatten()
np.random.shuffle(Dij)
Dij = Dij.reshape(Cij.shape)

from scipy.ndimage import gaussian_filter1d, uniform_filter1d

plt.figure(figsize=(2.5*width, 1.5*height))  # Set the figure size (width, height) in inches

# First column, both rows: the shuffled matrix
ax1 = plt.subplot2grid((2, 3), (0, 0), rowspan=2)
im = ax1.imshow(Dij, cmap='jet', aspect=1, vmin=0)
ax1.set_xlabel("Presynaptic")
ax1.set_ylabel("Postsynaptic")

# Second column, first row
ax2 = plt.subplot2grid((2, 3), (0, 1))
Kj = np.sum(Dij, axis=0)  # sum over rows (the postsynaptic axis of the image): one value per column j
ax2.plot(uniform_filter1d(Kj, size=75))
# ax2.set_xticklabels([])
ax2.set_ylabel("$K_j$")

# Second column, second row
ax3 = plt.subplot2grid((2, 3), (1, 1))
Ki = np.sum(Dij, axis=1)  # sum over columns (the presynaptic axis of the image): one value per row i
ax3.plot(uniform_filter1d(Ki, size=75))
ax3.set_ylabel("$K_i$")

# Third column, both rows: sum of each upper off-diagonal divided by the matrix size.
# NOTE(review): the diagonal at offset i holds only N - i entries and is not
# wrapped around, yet it is divided by N — confirm this is intended.
ax4 = plt.subplot2grid((2, 3), (0, 2), rowspan=2)
diags = []
for i in range(int(Dij.shape[0] / 2)):
   diags.append(np.trace(Dij, offset=i) / Dij.shape[0])
diags = np.array(diags)
ax4.plot(diags)
ax4.set_xlabel("Neuron #")
ax4.set_ylabel("$P_{ij}$")

plt.tight_layout()
plt.show()
```

``` ipython
```