# Notebook Settings

``` ipython
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

%run ../notebooks/setup.py
%matplotlib inline
%config InlineBackend.figure_format = 'png'
```

# Imports

``` ipython
import sys
sys.path.insert(0, '../')

import torch
import pandas as pd
from time import perf_counter

from src.network import Network
from src.plot_utils import plot_con
from src.decode import decode_bump
```

# Helpers

## Connectivity

``` ipython
def plot_eigen(W):
    """Scatter-plot the eigenvalues of the square matrix ``W`` in the complex plane.

    Parameters
    ----------
    W : torch.Tensor
        Square matrix (e.g. a connectivity matrix). It may live on any device and
        may track gradients; it is detached before being converted to NumPy.
    """
    # detach() so that tensors requiring grad (e.g. trainable weights) can be
    # converted with .numpy(), which otherwise raises a RuntimeError.
    eigenvalues = torch.linalg.eigvals(W.detach()).cpu().numpy()

    plt.scatter(eigenvalues.real, eigenvalues.imag)
    plt.xlabel('Real Part')
    plt.ylabel('Imaginary Part')
    # Dashed axes through the origin of the complex plane
    plt.axhline(y=0, color='k', linestyle='--')
    plt.axvline(x=0, color='k', linestyle='--')

    plt.show()
```

## Miscellaneous

``` ipython
def convert_seconds(seconds):
    """Split a duration given in seconds into ``(hours, minutes, seconds)``.

    Works for ints and floats; each component uses floor division / modulo.
    """
    hours, remainder = divmod(seconds, 3600)
    minutes = remainder // 60
    return hours, minutes, seconds % 60
```

``` ipython
def get_theta(a, b, GM=0, IF_NORM=0):
    """Return the element-wise angle ``arctan2(b, a)`` of two coordinate vectors.

    Parameters
    ----------
    a, b : np.ndarray
        "Cosine" and "sine" coordinates (same shape).
    GM : truthy, optional
        If set, first make ``b`` orthogonal to ``a`` (one Gram-Schmidt step).
    IF_NORM : truthy, optional
        If set, scale ``a`` and ``b`` to unit Euclidean norm before the arctan2.
    """
    if GM:
        # Remove from b its component along a
        coef = np.dot(b, a) / np.dot(a, a)
        b = b - coef * a

    u, v = (a / np.linalg.norm(a), b / np.linalg.norm(b)) if IF_NORM else (a, b)
    return np.arctan2(v, u)
```

``` ipython
def normalize(v):
    """Return ``v`` divided by its ``np.linalg.norm`` (Euclidean norm for vectors)."""
    length = np.linalg.norm(v)
    return v / length

def project(x, u):
    """Return the element-wise product ``x * u``.

    Despite the name, this is NOT the vector projection ``np.dot(x, u) * u``;
    ``sort_by_angle`` relies on the per-component product.
    """
    product = x * u
    return product

def sort_by_angle(x, u, v):
    """Return, for each component of ``x``, its angle in the plane spanned by ``u`` and ``v``.

    ``u`` and ``v`` are normalized. The components ``x * u_hat`` and ``x * v_hat``
    are treated as cosine/sine coordinates, and the angle ``arctan2`` is shifted
    by ``+pi`` into ``[0, 2*pi]``.

    Despite the name, nothing is sorted. Apply ``np.argsort`` to the returned
    angles to obtain an ordering.
    """
    cos_part = project(x, normalize(u))
    sin_part = project(x, normalize(v))
    return np.arctan2(sin_part, cos_part) + np.pi
```

``` ipython
def get_idx(model):
    """Return an index array that orders neurons by an angle derived from ``model.PHI0``.

    Rows 0 and 2 of ``model.PHI0`` are passed to ``get_theta`` as the
    cosine/sine coordinates. The result is the ``argsort`` of the per-element
    angles.
    NOTE(review): confirm that rows 0 and 2 are the intended pair of vectors.

    Also prints the shape of ``PHI0`` as a diagnostic.
    """
    ksi = model.PHI0.cpu().detach().numpy()
    print(ksi.shape)

    theta = get_theta(ksi[0], ksi[2], GM=0, IF_NORM=0)
    return theta.argsort()
```

``` ipython
def get_overlap(model, rates):
    """Return the overlap of ``rates`` with every row of ``model.PHI0``.

    Computes ``rates @ PHI0.T`` divided by the size of the last axis of
    ``rates``. The product is taken with a detached NumPy copy of ``PHI0``, so
    ``rates`` is expected to be a NumPy array.
    """
    patterns = model.PHI0.cpu().detach().numpy()
    n_units = rates.shape[-1]
    return (rates @ patterns.T) / n_units

```

# Manifolds

## 1 population

``` ipython
# Repository root handed to Network; machine-specific, adjust when running elsewhere.
REPO_ROOT = "/home/leon/models/NeuroTorch"
model = Network(
    'config_1pop.yml',
    REPO_ROOT,
    VERBOSE=1,
    DEVICE='cuda:1',
    LIVE_FF_UPDATE=0,
    GAIN=2,
)
```

### Training

``` ipython
# Feedforward vectors: neuron i gets (cos θ_i, sin θ_i), θ_i evenly spaced on [0, 2π)
theta_list = torch.linspace(0, 2.0 * torch.pi, model.Na[0] + 1)[:-1]
Wfb = torch.stack((torch.cos(theta_list), torch.sin(theta_list))).to('cuda:1')
print('Wfb:', Wfb.shape)

# N_TRAIN cue directions psi, evenly spaced on the circle, with amplitude A_psi
N_TRAIN = 40
A_psi = 1.2  # / torch.sqrt(model.Ka[0])
psi_list = torch.linspace(0, 2.0 * torch.pi, N_TRAIN + 1)[:-1]
z = A_psi * torch.stack((torch.cos(psi_list), torch.sin(psi_list))).T.to('cuda:1')
print('z:', z.shape)

# Constant background plus the cue projected on the feedforward vectors
ff_input = model.Ja0[0] + z @ Wfb
print('input:', ff_input.shape)

# The network consumes a sequence: repeat the static input over all N_STEPS
ff_input = ff_input.unsqueeze(1).expand(-1, model.N_STEPS, -1)
print('reshaped input:', ff_input.shape)
```

``` ipython
# Simulate all training trials. REC_LAST_ONLY=0 presumably records every time step
# (later cells use rates[:, -1] as the final state) — confirm in Network.
rates = model(ff_input, REC_LAST_ONLY=0)
```

``` ipython
# Heatmap of the third-to-last training trial
# (assumes rates is (trial, time, neuron), so rows of the image are neurons).
trial_rates = rates[-3].cpu().numpy()
plt.imshow(trial_rates.T, aspect='auto', origin='lower', vmin=0, vmax=2, cmap='jet')
plt.show()
```

``` ipython
# Traces of the first three neurons (last axis) in the first trial
first_traces = rates[0, :, :3].cpu().numpy()
plt.plot(first_traces)
plt.show()
```

``` ipython
# Order the neuron axis of `rates` by each neuron's preferred angle, i.e. the angle
# of its feedforward vector (cos θ_i, sin θ_i) stored in the columns of Wfb.
# theta must be per-neuron: the cue angles from `z` have length N_TRAIN, and
# indexing the neuron axis with them selected/dropped neurons instead of ordering them.
print('Wfb:', Wfb.shape)
theta = get_theta(Wfb[0].cpu().numpy(), Wfb[1].cpu().numpy(), GM=0, IF_NORM=0)
print(theta.shape)
# Neurons not covered by Wfb (presumably none in the single-population model)
# keep their original order at the end.
idx = np.concatenate((theta.argsort(), np.arange(theta.shape[0], rates.shape[-1])))
rates_ord = rates[..., idx]
print(rates_ord.shape)
```

``` ipython
# First trial with the reordered neuron axis (rows of the image are neurons,
# assuming rates_ord is (trial, time, neuron)).
ordered_trial = rates_ord[0].cpu().numpy()
plt.imshow(ordered_trial.T, aspect='auto', origin='lower', vmax=2, cmap='jet')
plt.show()
```

``` ipython
# rates.size() is the same torch.Size as rates.shape
print('rates:', rates.size())
```

``` ipython
# Least-squares readout: Wout (N x 2) such that Phi.T @ Wout = z, i.e.
#   Wout = Phi @ (Phi.T @ Phi)^{-1} @ z
# where Phi (N x M) holds the last-time-step rates phi(theta_i, psi_m) of the M
# training trials and Corr = Phi.T @ Phi (M x M) is the rate-correlation matrix.
# Solving the linear system is numerically more stable than forming the inverse.
Phi = rates[:, -1].T
Corr = Phi.T @ Phi
print('Phi', Phi.shape, 'Corr', Corr.shape, 'z', z.shape)

Wout = Phi @ torch.linalg.solve(Corr, z)
print('Wout', Wout.shape)

# Rank-2 structured connectivity from the feedforward (Wfb) and readout (Wout) vectors
Wstruct = Wfb.T @ Wout.T
print('W', Wstruct.shape)
```

``` ipython
# Matrix shown in the figure below. To inspect the full network weights instead, use
# Cij = model.Wab_T.cpu().detach().numpy()
# detach() for consistency with the 2-population cell and in case Wstruct tracks grad.
Cij = Wstruct.cpu().detach().numpy()
```

``` ipython
plt.figure(figsize=(12, 5))  # Set the figure size (width, height) in inches

# Left column (both rows): the matrix itself, rows = postsynaptic, columns = presynaptic
ax1 = plt.subplot2grid((2, 3), (0, 0), rowspan=2)
im = ax1.imshow(Cij, cmap='jet', aspect=1)
ax1.set_xlabel("Presynaptic")
ax1.set_ylabel("Postsynaptic")

# Second column, first row: column sums (over postsynaptic rows), one per presynaptic neuron
ax2 = plt.subplot2grid((2, 3), (0, 1))
Kj = np.sum(Cij, axis=0)
ax2.plot(Kj)
ax2.set_ylabel("$K_j$")

# Second column, second row: row sums (over presynaptic columns), one per postsynaptic neuron
ax3 = plt.subplot2grid((2, 3), (1, 1))
Ki = np.sum(Cij, axis=1)
ax3.plot(Ki)
ax3.set_ylabel("$K_i$")

# Third column: sum of the i-th super-diagonal divided by N, for i < N/2.
# The array is built once, after all offsets are collected.
ax4 = plt.subplot2grid((2, 3), (0, 2), rowspan=2)
diags = np.array([np.trace(Cij, offset=i) / Cij.shape[0] for i in range(Cij.shape[0] // 2)])
ax4.plot(diags)
ax4.set_xlabel("Neuron #")
ax4.set_ylabel("$P_{ij}$")

plt.tight_layout()
plt.show()

```

### Testing

``` ipython
# Add the structured term in place to the slices[0] x slices[0] block of model.Wab_T.
# NOTE(review): this only modifies Wab_T if indexing with model.slices[0] returns a
# view (e.g. Python slice objects); with index tensors add_ would act on a copy.
# Re-running this cell adds Wstruct.T again.
model.Wab_T[model.slices[0],model.slices[0]].add_(Wstruct.T);
```

``` ipython
# Test run with the updated connectivity, without passing a feedforward input.
# NOTE(review): TASK is set to the string 'None', not Python None — presumably a
# task name understood by Network; confirm.
model.TASK = 'None'
rates = model()
```

``` ipython
# Shape of the test-run rates (size() returns the same torch.Size as .shape)
print(rates.size())
```

``` ipython
# Heatmap of the first test trial (rows of the image are the last axis of rates)
test_trial = rates[0].cpu().numpy()
plt.imshow(test_trial.T, aspect='auto', origin='lower', vmax=1, cmap='jet')
plt.show()
```

``` ipython
# Decode the bump over the last (neuron) axis into m0, m1 and phase phi
rates_np = rates.cpu().numpy()
m0, m1, phi = decode_bump(rates_np, axis=-1)
print(m0.shape)
```

``` ipython
# Decoded bump phase over time, converted from radians to degrees
phi_deg = phi.T * 180 / np.pi
plt.plot(phi_deg)
plt.show()
```

``` ipython

```

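## 2 populations

Note: this section reuses `model`. Re-create it from a two-population (E/I) configuration before running these cells. The instance above is built from `config_1pop.yml`, and the input construction below expects an inhibitory population (`model.N_NEURON - model.Na[0] > 0`).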

### Training

``` ipython
# NOTE(review): `model` must be rebuilt from a two-population config before this cell;
# the instance created above uses config_1pop.yml.

# Feedforward vectors for the first (excitatory) population:
# neuron i gets (cos θ_i, sin θ_i), θ_i evenly spaced on [0, 2π)
theta_list = torch.linspace(0, 2.0 * torch.pi, model.Na[0] + 1)[:-1]
Wfb = torch.stack((torch.cos(theta_list), torch.sin(theta_list))).to('cuda:1')
print('Wfb:', Wfb.shape)

# N_TRAIN cue directions psi, evenly spaced on the circle, with amplitude A_psi
N_TRAIN = 10
A_psi = 1.0  # / torch.sqrt(model.Ka[0])
psi_list = torch.linspace(0, 2.0 * torch.pi, N_TRAIN + 1)[:-1]
z = A_psi * torch.stack((torch.cos(psi_list), torch.sin(psi_list))).T.to('cuda:1')
print('z:', z.shape)

# Excitatory drive: constant background plus the cue projected on Wfb
baseline_E = model.Ja0[0] * torch.sqrt(model.Ka[0]) * model.M0
input_E = baseline_E + z @ Wfb
print('input:', input_E.shape)

# The network consumes a sequence: repeat the static input over all N_STEPS
input_E = input_E.unsqueeze(1).expand(-1, model.N_STEPS, -1)
print('reshaped input:', input_E.shape)

# Remaining (inhibitory) neurons receive only a constant background drive.
# NOTE(review): scaled by sqrt(Ka[0]) rather than Ka[1] — confirm this is intended.
N_I = model.N_NEURON - model.Na[0]
input_I = torch.ones(input_E.size(0), input_E.size(1), N_I, device=input_E.device) * model.Ja0[1] * torch.sqrt(model.Ka[0]) * model.M0
print(input_I.shape)

ff_input = torch.cat((input_E, input_I), dim=-1)
print('reshaped input:', ff_input.shape)
```

``` ipython
# Simulate all training trials. REC_LAST_ONLY=0 presumably records every time step
# (later cells use rates[:, -1] as the final state) — confirm in Network.
rates = model(ff_input, REC_LAST_ONLY=0)
```

``` ipython
# Heatmap of the third-to-last training trial
# (assumes rates is (trial, time, neuron), so rows of the image are neurons).
trial_rates = rates[-3].cpu().numpy()
plt.imshow(trial_rates.T, aspect='auto', origin='lower', vmax=10, cmap='jet')
plt.show()
```

``` ipython
# Order the neuron axis of `rates` by each excitatory neuron's preferred angle, i.e.
# the angle of its feedforward vector (cos θ_i, sin θ_i) stored in the columns of Wfb.
# theta must be per-neuron: the cue angles from `z` have length N_TRAIN, and
# indexing the neuron axis with them selected/dropped neurons instead of ordering them.
print('Wfb:', Wfb.shape)
theta = get_theta(Wfb[0].cpu().numpy(), Wfb[1].cpu().numpy(), GM=0, IF_NORM=0)
print(theta.shape)
# Neurons not covered by Wfb (the inhibitory ones) keep their original order at the end.
idx = np.concatenate((theta.argsort(), np.arange(theta.shape[0], rates.shape[-1])))
rates_ord = rates[..., idx]
print(rates_ord.shape)
```

``` ipython
# First trial with the reordered neuron axis (rows of the image are neurons,
# assuming rates_ord is (trial, time, neuron)).
ordered_trial = rates_ord[0].cpu().numpy()
plt.imshow(ordered_trial.T, aspect='auto', origin='lower', vmax=10, cmap='jet')
plt.show()
```

``` ipython
# rates.size() is the same torch.Size as rates.shape
print('rates:', rates.size())
```

``` ipython
# Least-squares readout: Wout (N_E x 2) such that Phi.T @ Wout = z, i.e.
#   Wout = Phi @ (Phi.T @ Phi)^{-1} @ z
# where Phi (N_E x M) holds the last-time-step rates phi(theta_i, psi_m) of the M
# training trials and Corr = Phi.T @ Phi (M x M) is the rate-correlation matrix.
# The inverse must be applied (solved for, which is more stable than inverting).
#
# Only the first population (model.slices[0]) is used: Wfb spans Na[0] neurons, and
# Wstruct.T is later added to the slices[0] x slices[0] block of Wab_T, so Wstruct
# must be Na[0] x Na[0].
Phi = rates[:, -1, model.slices[0]].T
Corr = Phi.T @ Phi
print('Phi', Phi.shape, 'Corr', Corr.shape, 'z', z.shape)

Wout = Phi @ torch.linalg.solve(Corr, z)
print('Wout', Wout.shape)

# Rank-2 structured connectivity from the feedforward (Wfb) and readout (Wout) vectors
Wstruct = Wfb.T @ Wout.T
print('W', Wstruct.shape)
```

``` ipython
# NumPy copy of the structured matrix for plotting
Cij = Wstruct.detach().cpu().numpy()
```

``` ipython
plt.figure(figsize=(12, 5))  # Set the figure size (width, height) in inches

# Left column (both rows): the matrix itself, rows = postsynaptic, columns = presynaptic
ax1 = plt.subplot2grid((2, 3), (0, 0), rowspan=2)
im = ax1.imshow(Cij, cmap='jet', aspect=1)
ax1.set_xlabel("Presynaptic")
ax1.set_ylabel("Postsynaptic")

# Second column, first row: column sums (over postsynaptic rows), one per presynaptic neuron
ax2 = plt.subplot2grid((2, 3), (0, 1))
Kj = np.sum(Cij, axis=0)
ax2.plot(Kj)
ax2.set_ylabel("$K_j$")

# Second column, second row: row sums (over presynaptic columns), one per postsynaptic neuron
ax3 = plt.subplot2grid((2, 3), (1, 1))
Ki = np.sum(Cij, axis=1)
ax3.plot(Ki)
ax3.set_ylabel("$K_i$")

# Third column: sum of the i-th super-diagonal divided by N, for i < N/2
ax4 = plt.subplot2grid((2, 3), (0, 2), rowspan=2)
diags = np.array([np.trace(Cij, offset=i) / Cij.shape[0] for i in range(Cij.shape[0] // 2)])
ax4.plot(diags)
ax4.set_xlabel("Neuron #")
ax4.set_ylabel("$P_{ij}$")

plt.tight_layout()
plt.show()

```

### Testing

``` ipython
# Add the structured term in place to the slices[0] x slices[0] block of model.Wab_T;
# Wstruct.T must have exactly that block's shape.
# NOTE(review): this only modifies Wab_T if indexing with model.slices[0] returns a
# view (e.g. Python slice objects); with index tensors add_ would act on a copy.
# Re-running this cell adds Wstruct.T again.
model.Wab_T[model.slices[0],model.slices[0]].add_(Wstruct.T);
```

``` ipython
# Test run with the updated connectivity, without passing a feedforward input
# (model.TASK keeps whatever value was set earlier in the notebook).
rates = model()
```

``` ipython
# Shape of the test-run rates (size() returns the same torch.Size as .shape)
print(rates.size())
```

``` ipython
# Heatmap of the first test trial (rows of the image are the last axis of rates)
test_trial = rates[0].cpu().numpy()
plt.imshow(test_trial.T, aspect='auto', origin='lower', vmax=10, cmap='jet')
plt.show()
```