Imports
=======

``` ipython
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
import gc
from torch.utils.data import Dataset, TensorDataset, DataLoader

# Root of the NeuroFlame repository; init_model() passes it to Network().
REPO_ROOT = "/home/leon/models/NeuroFlame"

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Device the models are built on, and a categorical colour palette for plots.
DEVICE = 'cuda:1'
pal = sns.color_palette("tab10")
```

``` ipython
import sys
# Presumably the repository root (three levels up from the notebook's working
# directory), so that `notebooks.setup` and `src.*` below can be imported.
sys.path.insert(0, '../../../')

from notebooks.setup import *

import pandas as pd
import torch.nn as nn
from time import perf_counter
from scipy.stats import circmean

from src.network import Network
from src.plot_utils import plot_con
from src.decode import decode_bump, circcvl
from src.lr_utils import masked_normalize, clamp_tensor, normalize_tensor
```

``` ipython
import pickle as pkl

def pkl_save(obj, name, path="."):
    """Pickle ``obj`` to the file ``<path>/<name>.pkl``.

    The file is opened in a ``with`` block so the handle is closed, and the
    data flushed to disk, before this function returns.
    """
    with open(path + "/" + name + ".pkl", "wb") as f:
        pkl.dump(obj, f)


def pkl_load(name, path="."):
    """Load and return the object pickled in ``<path>/<name>.pkl``.

    The file handle is closed before returning. Only use this on trusted
    files: unpickling can execute arbitrary code.
    """
    with open(path + "/" + name + '.pkl', "rb") as f:
        return pkl.load(f)

```

``` ipython
def add_vlines(model, ax=None):
    """Shade every stimulus window [T_STIM_ON[k], T_STIM_OFF[k]].

    Draws on ``ax`` when one is given, otherwise on the current pyplot axes.
    """
    target = plt if ax is None else ax
    for k, t_on in enumerate(model.T_STIM_ON):
        target.axvspan(t_on, model.T_STIM_OFF[k], alpha=0.25)

```

Notebook Settings
=================

``` ipython
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
%run ../../../notebooks/setup.py
%matplotlib inline
%config InlineBackend.figure_format = 'png'
```

``` example
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Python exe
/home/leon/mambaforge/envs/torch/bin/python
```

Utils
=====

``` ipython
def init_model(task, seed, **kwargs):
    """Build a ``Network`` and load the trained weights for ``task``/``seed``.

    Uses the globals ``conf_name``, ``REPO_ROOT`` and ``DEVICE``. The weights
    are read from ``<model.SAVE_PATH>/<task>_<seed>.pth``. Extra ``kwargs``
    are forwarded to ``Network``.
    """
    model = Network(conf_name, REPO_ROOT, VERBOSE=0, DEVICE=DEVICE, SEED=seed, N_BATCH=1, **kwargs)
    checkpoint = '%s/%s_%d.pth' % (model.SAVE_PATH, task, seed)

    # map_location deserializes straight onto DEVICE, so checkpoints saved on
    # another GPU (or one that is absent on this machine) still load.
    model.load_state_dict(torch.load(checkpoint, map_location=DEVICE))

    return model
```

``` ipython
def del_tensor(tensor):
    """Run garbage collection and release cached CUDA memory.

    ``tensor`` may be a ``torch.Tensor`` or an ``nn.Module``. ``del`` inside
    this function only drops the local name. The caller must also drop its
    own references (e.g. ``del x``) for the memory to actually be freed.
    CUDA synchronisation and stats resets are only done for CUDA objects,
    so this is safe to call on CPU tensors and modules.
    """
    device = getattr(tensor, 'device', None)
    if device is None and isinstance(tensor, torch.nn.Module):
        # Modules have no .device; infer it from the first parameter, if any.
        first_param = next(tensor.parameters(), None)
        device = None if first_param is None else first_param.device

    del tensor
    gc.collect()
    # No-op when CUDA has not been initialised.
    torch.cuda.empty_cache()

    if device is None or torch.device(device).type != 'cuda':
        return

    torch.cuda.synchronize(device)
    torch.cuda.reset_accumulated_memory_stats(device)
```

``` ipython
def run_grid(GRID_RANGE, seed, task, **kwargs):
    """Simulate a grid of initial conditions and return the stacked rates.

    The square [-GRID_RANGE, GRID_RANGE]^2 is split into four quadrants
    (x- and y-ranges taken from [-GRID_RANGE, 0] and [0, GRID_RANGE]). For
    each quadrant, a fresh model is loaded with ``init_model``. Its
    ``GRID_X_RANGE``/``GRID_Y_RANGE`` are set, it is run with
    ``N_BATCH = GRID_SIZE**2``, and its rates are kept as a numpy array.
    The four results are stacked along axis 0 with ``np.vstack``.
    """
    GRID_LIST = [[-GRID_RANGE, 0], [0, GRID_RANGE]]

    rates_grid = []
    with torch.no_grad():
        for GRID_X_RANGE in GRID_LIST:
            for GRID_Y_RANGE in GRID_LIST:
                model = init_model(task, seed, **kwargs)

                model.GRID_X_RANGE = GRID_X_RANGE
                model.GRID_Y_RANGE = GRID_Y_RANGE
                model.N_BATCH = int(model.GRID_SIZE * model.GRID_SIZE)

                ff_input = model.init_ff_input()
                rates = model(ff_input, RET_REC=0).cpu().detach().numpy()
                rates_grid.append(rates)

                # Drop our own references *before* collecting. Otherwise the
                # input and model stay alive until the next quadrant's model
                # has been built, doubling peak GPU memory.
                del ff_input, model
                gc.collect()
                torch.cuda.empty_cache()

    return np.vstack(rates_grid)
```

``` ipython
def get_low_rank(rates, model, IF_REC=0):
    if IF_REC==0:
        vec1 = model.low_rank.V.T[0]
        vec2 = model.low_rank.V.T[1]

        vec2 = vec2 - (vec2 @ vec1) * vec1 / (vec1 @ vec1)

        # vec1 = vec1 / torch.linalg.norm(vec1)
        # vec2 = vec2 / torch.linalg.norm(vec2)

        vec = torch.stack((vec1, vec2))
        overlaps = rates @ vec.T / model.Na[0]
    else:
        vec1 = model.low_rank.U.T[0]
        vec2 = model.low_rank.U.T[1]
        # vec2 = vec2 - (vec2 @ vec1) * vec1 / (vec1 @ vec1)
        vec1 = vec1 / torch.linalg.norm(vec1)**2
        vec2 = vec2 / torch.linalg.norm(vec2)**2

        vec = torch.stack((vec1, vec2))
        overlaps = model.rec_input[0, :, :] @ vec.T

    return overlaps.cpu().detach().numpy(), vec.cpu().detach().numpy()
```

``` ipython
import numpy as np

def get_bissec(point1, point2, length=100):
    """Return the endpoints of the perpendicular bisector of a 2-D segment.

    Parameters
    ----------
    point1, point2 : array-like of shape (2,)
        Endpoints of the segment.
    length : float
        Length of the returned bisector segment, centred on the midpoint.

    Returns
    -------
    np.ndarray of shape (2, 2)
        ``[endpoint1, endpoint2]``.

    Raises
    ------
    ValueError
        If the two points coincide (the bisector is undefined).
    """
    point1 = np.asarray(point1)
    point2 = np.asarray(point2)

    direction = point2 - point1
    midpoint = (point1 + point2) / 2

    # Rotate the segment direction by 90 degrees and normalise it.
    orthogonal_direction = np.array([-direction[1], direction[0]])
    norm = np.linalg.norm(orthogonal_direction)
    if norm == 0:
        raise ValueError("get_bissec: point1 and point2 coincide")
    orthogonal_direction = orthogonal_direction / norm

    half = length / 2
    return np.array([midpoint - half * orthogonal_direction,
                     midpoint + half * orthogonal_direction])
```

``` ipython
from scipy.interpolate import griddata

def create_mesh(x, y, size=100):
    """Linearly interpolate trajectory velocities onto a square regular grid.

    ``x`` and ``y`` are 2-D arrays; velocities are finite differences along
    axis 1. Both grid axes span the joint range of ``x`` and ``y``, padded
    by 1. Grid points outside the convex hull of the samples are NaN.

    NOTE: a later cell in this notebook redefines ``create_mesh``.

    Returns xi, yi, ui, vi.
    """
    lo = np.min((x, y)) - 1
    hi = np.max((x, y)) + 1
    axis_vals = np.linspace(lo, hi, size)
    xi, yi = np.meshgrid(axis_vals, axis_vals)

    samples = np.vstack((x.flatten(), y.flatten())).T
    ui, vi = (
        griddata(samples, np.gradient(comp, axis=1).flatten(), (xi, yi),
                 method='linear', fill_value=np.nan)
        for comp in (x, y)
    )
    return xi, yi, ui, vi
```

``` ipython
import numpy as np
from scipy.interpolate import griddata
from scipy.ndimage import gaussian_filter
from scipy.spatial import cKDTree

def create_mesh(x, y, size=100, sigma=0, interp_method='nearest', mask_radius=10):
    """Interpolate trajectory velocities onto a regular grid.

    Parameters
    ----------
    x, y : array-like of shape (n_traj, n_points)
        Trajectory coordinates; each row is one trajectory over time.
    size : int
        Number of grid points along each axis.
    sigma : float
        Std (in time steps) of Gaussian smoothing applied to the velocities
        along the time axis only (0 = none). Different trajectories are never
        blended together.
    interp_method : {'linear', 'cubic', 'nearest'}
        Method passed to ``scipy.interpolate.griddata``. Grid points it cannot
        fill (outside the convex hull for 'linear'/'cubic') fall back to
        nearest-neighbour values.
    mask_radius : float
        Currently unused; kept for backward compatibility.

    Returns
    -------
    xi, yi, ui, vi : np.ndarray of shape (size, size)
        Grid coordinates (``np.meshgrid`` default 'xy' indexing) and the
        interpolated velocity components, as plain (unmasked) arrays.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    x_flat = x.flatten()
    y_flat = y.flatten()

    # Dense grid covering the data, padded by 1 on every side.
    xi, yi = np.meshgrid(np.linspace(x_flat.min() - 1, x_flat.max() + 1, size),
                         np.linspace(y_flat.min() - 1, y_flat.max() + 1, size))

    # Velocities: finite differences along the time axis.
    dx = np.gradient(x, axis=1)
    dy = np.gradient(y, axis=1)

    if sigma > 0:
        # Smooth along time (axis 1) only. A scalar sigma would also smooth
        # along axis 0 and mix velocities of unrelated trajectories.
        time_sigma = (0, sigma)
        dx = gaussian_filter(dx, sigma=time_sigma)
        dy = gaussian_filter(dy, sigma=time_sigma)

    points = np.column_stack((x_flat, y_flat))
    dx_flat = dx.flatten()
    dy_flat = dy.flatten()

    ui = griddata(points, dx_flat, (xi, yi), method=interp_method, fill_value=np.nan)
    vi = griddata(points, dy_flat, (xi, yi), method=interp_method, fill_value=np.nan)

    # Fill whatever the chosen method left undefined with nearest-neighbour values.
    missing = np.isnan(ui) | np.isnan(vi)
    if np.any(missing):
        ui[missing] = griddata(points, dx_flat, (xi, yi), method='nearest')[missing]
        vi[missing] = griddata(points, dy_flat, (xi, yi), method='nearest')[missing]

    return xi, yi, ui, vi
```

``` ipython
import matplotlib as mpl

def plot_field(overlaps, ax, window, IF_FP=0, task=0, GRID_TEST=0, IF_CBAR=0):
    """Plot the flow field of 2-D overlap trajectories on ``ax``.

    The plot has three layers:
    - semi-transparent white streamlines of the interpolated velocity field;
    - a z-scored speed heatmap underneath, with its colour range fixed to
      [-1.5, 1.5];
    - white dots at the cluster centres returned by ``get_fp``.

    Parameters
    ----------
    overlaps : array indexed as [trajectory, time, component]
        Components 0 and 1 (from time index ``window`` on) are x and y.
    ax : matplotlib Axes to draw on.
    window : int
        First time index used.
    IF_FP : int
        Currently unused; the centres are always plotted.
    task, GRID_TEST :
        Forwarded to ``get_fp``.
    IF_CBAR : int
        If truthy, add a colorbar labelled 'Norm. Speed'.
    """
    x = overlaps[:, window:, 0]
    y = overlaps[:, window:, 1]

    xi, yi, ui, vi = create_mesh(x, y, size=300)

    # z-scored speed. NaN-aware, so grid points left undefined by the
    # interpolation do not turn the whole map into NaN.
    speed = np.sqrt(ui**2 + vi**2)
    speed = (speed - np.nanmean(speed)) / (np.nanstd(speed) + 1e-6)

    center, _ = get_fp(overlaps, window, task, GRID_TEST=GRID_TEST)

    # Fixed colour range shared by every panel.
    norm = mpl.colors.Normalize(-1.5, 1.5)

    ax.streamplot(xi, yi, ui, vi, density=0.5, arrowsize=1.25, color=('w', 0.5))
    heatmap = ax.pcolormesh(xi, yi, speed, cmap='coolwarm', shading='gouraud', norm=norm)

    ax.plot(center.T[0], center.T[1], 'o', color='w', ms=18)
    ax.set_yticks([-10, 0, 10])

    if IF_CBAR:
        cbar = plt.colorbar(heatmap, ax=ax)
        cbar.set_label('Norm. Speed')

    ax.set_xlabel('A/B Overlap')
    ax.set_ylabel('Choice Overlap')
```

``` ipython
def save_fig(figname, GRID_TEST, format='png'):
    """Save the current pyplot figure under ../figures/flow/ at 300 dpi.

    The file name is ``<figname>[_<tag>]_<seed>.<format>``.
    - ``<tag>`` depends on GRID_TEST: 4 -> test_C, 9 -> test_D, 1 -> go,
      6 -> nogo, 0 -> sample_A, 5 -> sample_B. Any other value gets no tag.
    - ``seed`` is the notebook-level global.

    The output directory is created if it does not exist.
    """
    import os

    # (GRID_TEST value, file-name tag), compared with == in this order.
    tags = ((4, 'test_C'), (9, 'test_D'), (1, 'go'), (6, 'nogo'),
            (0, 'sample_A'), (5, 'sample_B'))

    stem = figname
    for value, tag in tags:
        if GRID_TEST == value:
            stem = '%s_%s' % (figname, tag)
            break

    path = '../figures/flow/%s_%d.%s' % (stem, seed, format)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    plt.savefig(path, dpi=300)

```

``` ipython
from scipy.ndimage import map_coordinates

def integrate_to_attractor(xi, yi, ui, vi, attractors, n_steps=500, dt=0.05, tol=1e-2):
    """Assign each mesh point to the attractor its trajectory converges to.

    Every grid point is advanced with forward-Euler steps through the
    velocity field (ui, vi), which is bilinearly interpolated on the mesh.
    A point stops once it comes within ``tol`` of an attractor, or after
    ``n_steps`` steps.

    Parameters
    ----------
    xi, yi : 2-D arrays
        Uniform mesh coordinates from ``np.meshgrid``, with either 'xy'
        (default) or 'ij' indexing.
    ui, vi : 2-D arrays, same shape as ``xi``
        Velocity components on the mesh.
    attractors : array-like of shape (n_attractors, 2)
        Attractor positions as (x, y).
    n_steps : int
        Maximum number of Euler steps.
    dt : float
        Euler step size.
    tol : float
        Distance below which a trajectory counts as converged.

    Returns
    -------
    basin_map : int array, same shape as ``xi``
        Index of the attractor reached from each grid point, or -1 if none
        was reached. If several attractors are within ``tol``, the nearest
        one is chosen.
    """
    shape = xi.shape
    positions = np.stack([xi.flatten(), yi.flatten()], axis=1).astype(float)
    basin_idx = np.full(positions.shape[0], -1, dtype=int)

    attractors = np.asarray(attractors, dtype=float).reshape(-1, 2)
    if attractors.shape[0] == 0:
        return basin_idx.reshape(shape)

    # np.meshgrid's default 'xy' indexing varies x along axis 1; 'ij' along axis 0.
    x_along_cols = xi.shape[1] > 1 and xi[0, -1] != xi[0, 0]

    def frac_index(vals, grid, axis):
        # Map coordinates to fractional array indices along `axis` (uniform grid).
        first = grid[0, 0]
        last = grid[0, -1] if axis == 1 else grid[-1, 0]
        return (vals - first) / (last - first) * (grid.shape[axis] - 1)

    def interp_field(pos, field):
        # Bilinear interpolation of `field` at the (x, y) rows of `pos`.
        if x_along_cols:
            coords = [frac_index(pos[:, 1], yi, 0), frac_index(pos[:, 0], xi, 1)]
        else:
            coords = [frac_index(pos[:, 0], xi, 0), frac_index(pos[:, 1], yi, 1)]
        return map_coordinates(field, coords, order=1, mode='nearest')

    curr = positions.copy()
    for _ in range(n_steps):
        active = np.flatnonzero(basin_idx < 0)
        if active.size == 0:
            break

        pts = curr[active]
        # Evaluate both components at the current positions before moving.
        velocity = np.column_stack((interp_field(pts, ui), interp_field(pts, vi)))
        curr[active] = pts + dt * velocity

        # Distance from each moving point to every attractor: (n_active, n_attractors).
        dists = np.linalg.norm(curr[active][:, None, :] - attractors[None, :, :], axis=2)
        nearest = np.argmin(dists, axis=1)
        arrived = dists[np.arange(active.size), nearest] < tol
        basin_idx[active[arrived]] = nearest[arrived]

    return basin_idx.reshape(shape)
```

Model
=====

``` ipython
# Model configuration used by init_model(): repo location, config file, device.
REPO_ROOT, conf_name, DEVICE = (
    "/home/leon/models/NeuroFlame",
    "train_dual.yml",
    'cuda:1',
)
```

``` ipython
# Keyword overrides forwarded to Network() through init_model()/run_grid().
kwargs = dict(
    DURATION=40.0,
    TASK='dual_flow',
    T_STIM_ON=[1.0, 2.0],
    T_STIM_OFF=[2.0, 300.0],
    I0=[1.0, 1.0],
    GRID_SIZE=10,
    GRID_TEST=0,  # test condition; save_fig() tags 0 as 'sample_A'
    GRID_INPUT=0,
    IF_OPTO=0,
)
```

``` ipython
# Trained models to compare (e.g. set to ['dpa'] to run a single task).
tasks = ['dpa', 'dual_naive', 'dual_train']
# Fixed seed for reproducibility; use np.random.randint(100) for a random one.
seed = 3
print(seed)
# Half-width of the square grid of initial conditions passed to run_grid().
GRID_RANGE = 0.4
```

Flow
====

``` ipython
# Run the grid simulation with each task's trained model and stack the results
# into one array whose first axis indexes the task.
# NOTE: the loop variable `task` stays bound to the last entry of `tasks`
# after the loop.
rates = []
for task in tasks:
        rates.append(run_grid(GRID_RANGE, seed, task, **kwargs))
rates = np.array(rates)
```

``` ipython
# Copy the stacked rates onto DEVICE so they can be projected with torch.
rates_tensor = torch.tensor(rates, device=DEVICE)
print(rates_tensor.shape)
```

``` ipython
# The projection basis and timing come from the model of the last task in
# `tasks`. This is stated explicitly rather than relying on the loop variable
# left over from another cell.
model = init_model(tasks[-1], seed, **kwargs)
overlaps, vec = get_low_rank(rates_tensor, model, IF_REC=0)
print(overlaps.shape)

# Index of the first saved time bin after the first stimulus offset.
# Assumes rates are saved every N_WINDOW steps starting at N_STEADY (TODO confirm).
window = int((model.N_STIM_OFF[0] - model.N_STEADY) / model.N_WINDOW) + 1

# ff_overlaps = ff_input[..., model.N_STEADY: , model.slices[0]] @ vec.T
# ff_overlaps = ff_overlaps[:, ::10]
# print(overlaps.shape, ff_overlaps.shape)
```

Field
=====

``` ipython
from sklearn.cluster import KMeans
def get_fp(overlaps, window, task, GRID_TEST=None, x=None, y=None):
    """Estimate fixed points as k-means (k=5) centres of trajectory endpoints.

    Parameters
    ----------
    overlaps : array indexed as [trajectory, time, component]
        Only used when ``x`` is None. Components 0 and 1 from time index
        ``window`` on then become ``x`` and ``y``.
    window : int
    task : int or str
        Used in the cache file name. When ``task == 2``, only the first 3 of
        the 5 centres are kept.
    GRID_TEST : None or any
        None: save the centres to ``center_<task>.pkl`` under /home/leon/.
        Anything else: load previously saved centres from that file.
    x, y : 2-D arrays, optional
        Trajectories indexed as [trajectory, time]; the last column is used.

    Returns
    -------
    (centers, saved_centers)
        ``saved_centers`` is ``[]`` when GRID_TEST is None. Otherwise it is
        the content of the pickle file.

    NOTE(review): with ``random_state=None`` the order of the KMeans centres
    may differ between runs, so which 3 centres survive for task 2 is not
    guaranteed to be stable.
    """
    if x is None:
        x = overlaps[:, window:, 0]
        y = overlaps[:, window:, 1]

    # Final sample of each trajectory, one (x, y) row per trajectory.
    endpoints = np.stack((x[:, -1], y[:, -1])).T
    centers = np.array(KMeans(n_clusters=5, random_state=None).fit(endpoints).cluster_centers_)

    if task == 2:
        centers = centers[:3]

    cache_name = 'center_%s' % task
    if GRID_TEST is not None:
        return centers, pkl_load(cache_name, path="/home/leon/")

    pkl_save(centers, cache_name, path="/home/leon/")
    return centers, []
```

``` ipython
# fig, ax = plt.subplots(1, 1, figsize=[width, width])
# aplot_field(overlaps[0], ax, window, IF_FP=1, task=0, GRID_TEST=model.GRID_TEST, IF_CBAR=1);
# save_fig('flow_field_cbar_seed_%d' % (seed), GRID_TEST=model.GRID_TEST)
```

``` ipython
# One figure per task (no colorbar), each saved to its own file.
# The colorbar variant for overlaps[0] is the commented-out cell above; the
# loop index only reaches len(tasks) - 1, so it is not produced here.
for i in range(len(tasks)):
    fig, ax = plt.subplots(1, 1, figsize=[width, width])
    plot_field(overlaps[i], ax, window, IF_FP=1, task=i, GRID_TEST=model.GRID_TEST, IF_CBAR=0)
    save_fig('flow_field_task_%s_seed_%d' % (tasks[i], seed), GRID_TEST=model.GRID_TEST)
```

``` ipython
# All tasks side by side in a single figure.
# squeeze=False always returns a 2-D axes array, so ax[i] also works when
# len(tasks) == 1 (plain plt.subplots would return a single Axes then).
fig, axes = plt.subplots(1, len(tasks), figsize=[len(tasks) * width, width], squeeze=False)
ax = axes[0]

for i in range(len(tasks)):
    plot_field(overlaps[i], ax[i], window, IF_FP=1, task=i, GRID_TEST=model.GRID_TEST)

save_fig('flow_field_seed_%d' % seed, GRID_TEST=model.GRID_TEST)
plt.show()
```

``` ipython
```