Notebook Settings
=================

``` ipython
# Notebook setup: auto-reload edited project modules, run the shared setup
# script, and render figures inline as PNG.
%load_ext autoreload
%autoreload 2
# NOTE(review): appears redundant right after %load_ext above; harmless.
%reload_ext autoreload

# `sns`, `width` and `height` are used in this notebook but never defined in it,
# so they presumably come from this setup script — TODO confirm.
%run ../../../notebooks/setup.py
%matplotlib inline
%config InlineBackend.figure_format = 'png'

# NOTE(review): REPO_ROOT is assigned again (same value) in the model-config cell.
REPO_ROOT = "/home/leon/models/NeuroFlame"
pal = sns.color_palette("tab10")
```

Imports
=======

``` ipython
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, TensorDataset, DataLoader
from scipy.stats import binned_statistic
```

``` ipython
import sys
sys.path.insert(0, '../../../')

import pandas as pd
import torch.nn as nn
from time import perf_counter
from scipy.stats import circmean

from src.network import Network
from src.plot_utils import plot_con
from src.decode import decode_bump, circcvl, decode_bump_torch
from src.lr_utils import masked_normalize, clamp_tensor, normalize_tensor
```

``` ipython
import pickle as pkl
import os

def pkl_save(obj, name, path="."):
    """Pickle ``obj`` to ``<path>/<name>.pkl``, creating ``path`` if needed.

    Args:
        obj: Any picklable object.
        name: File name without the ``.pkl`` extension.
        path: Destination directory; created (with parents) if missing.
    """
    os.makedirs(path, exist_ok=True)
    destination = path + "/" + name + ".pkl"
    print("saving to", destination)
    # The context manager guarantees the file is flushed and closed, even if
    # pickling fails part-way.
    with open(destination, "wb") as f:
        pkl.dump(obj, f)


def pkl_load(name, path="."):
    """Load and return the object pickled at ``<path>/<name>.pkl``.

    Args:
        name: File name without the ``.pkl`` extension.
        path: Directory containing the file.

    NOTE: unpickling can execute arbitrary code; only load trusted files.
    """
    source = path + "/" + name + '.pkl'
    print('loading from', source)
    with open(source, "rb") as f:
        return pkl.load(f)

```

Helpers
=======

``` ipython
def map2center(angles):
    """Map angles from [0, 2π] to (-π, π] using NumPy.

    Elements strictly greater than π are shifted down by 2π; all others are
    returned unchanged. Accepts scalars, lists or arrays and returns an ndarray.
    """
    return np.where(angles > np.pi, angles - 2 * np.pi, angles)

def map2pos(angles):
    """Map angles from [-π, π] to [0, 2π] using NumPy.

    Negative elements are shifted up by 2π; all others are returned unchanged.
    Accepts scalars, lists or arrays and returns an ndarray.
    """
    return np.where(angles < 0, angles + 2 * np.pi, angles)
```

``` ipython
def maptens2center(angles):
    """Wrap a tensor of angles from [0, 2π] into (-π, π].

    Elements strictly greater than π are shifted down by one full turn (2π);
    every other element is passed through unchanged.
    """
    shifted = angles - 2 * torch.pi
    return torch.where(angles > torch.pi, shifted, angles)

def maptens2pos(angles):
    """Wrap a tensor of angles from [-π, π] into [0, 2π].

    Negative elements are shifted up by one full turn (2π); every other
    element is passed through unchanged.
    """
    is_negative = angles < 0
    return torch.where(is_negative, angles + 2 * torch.pi, angles)
```

``` ipython
def add_vlines(model, ax=None):
    """Shade every stimulus window of ``model`` on a matplotlib axis.

    Window ``k`` spans ``model.T_STIM_ON[k]`` to ``model.T_STIM_OFF[k]`` and is
    drawn with ``axvspan`` at alpha 0.25.

    Args:
        model: Object exposing ``T_STIM_ON`` / ``T_STIM_OFF`` sequences.
        ax: Axes to draw on; when ``None``, ``plt.axvspan`` is used (current
            pyplot axes).
    """
    target = plt if ax is None else ax
    for k in range(len(model.T_STIM_ON)):
        target.axvspan(model.T_STIM_ON[k], model.T_STIM_OFF[k], alpha=0.25)

```

``` ipython
import torch
import numpy as np

def generate_weighted_phase_samples(N_BATCH, angles, preferred_angle, sigma):
    """Draw ``N_BATCH`` angles from ``angles`` using Gaussian weights.

    Each candidate ``a`` gets weight ``exp(-0.5 * ((a - preferred_angle) / sigma)**2)``.
    The weights are normalised to sum to 1, and samples are drawn with
    replacement from a categorical distribution over the candidates.

    Args:
        N_BATCH: Number of samples to draw.
        angles: 1-D sequence of candidate angles (list or NumPy array).
        preferred_angle: Centre of the Gaussian weighting (same units as ``angles``).
        sigma: Width of the Gaussian weighting (same units as ``angles``).

    Returns:
        Tensor of shape (N_BATCH, 1) holding the sampled angles.

    NOTE(review): the distance ``a - preferred_angle`` is linear, not circular,
    so candidates near the wrap-around point are not treated as close.
    """
    # Index tensor built from the caller's input, so its dtype is unchanged.
    angles_tensor = torch.tensor(angles)

    # np.asarray lets the arithmetic below work for plain Python lists too
    # (list - float raises TypeError).
    angles_arr = np.asarray(angles)

    # Gaussian weights centred on preferred_angle, normalised to probabilities.
    probs = np.exp(-0.5 * ((angles_arr - preferred_angle) / sigma) ** 2)
    probs /= probs.sum()

    # Sample candidate indices, then map them back to angles.
    distribution = torch.distributions.Categorical(torch.tensor(probs))
    indices = distribution.sample((N_BATCH,))

    phase_samples = angles_tensor[indices].reshape(N_BATCH, 1)
    return phase_samples
```

``` ipython
import torch
import numpy as np
import matplotlib.pyplot as plt

def continuous_biased_phases(N_BATCH, preferred_angle, sigma):
    """Draw ``N_BATCH`` Gaussian phases around ``preferred_angle``, in degrees.

    Samples come from N(preferred_angle, sigma) via ``torch.normal``. They are
    then wrapped into [0, 360) with a modulo and returned with shape (N_BATCH, 1).
    """
    draws = torch.normal(mean=preferred_angle, std=sigma, size=(N_BATCH, 1))
    return draws % 360
```

``` ipython
import torch
import numpy as np
import matplotlib.pyplot as plt

def continuous_bimodal_phases(N_BATCH, preferred_angle, sigma):
    """Draw ``N_BATCH`` phases from two Gaussian modes 180 degrees apart.

    The first ``N_BATCH // 2`` rows are drawn around ``preferred_angle``. The
    remaining rows are drawn around the opposite direction
    ``(preferred_angle + 180) % 360``. The rows are not shuffled. Everything is
    wrapped into [0, 360) and returned with shape (N_BATCH, 1).
    """
    n_first = N_BATCH // 2
    n_second = N_BATCH - n_first

    # Draw the preferred mode before the opposite one (RNG order matters).
    around_pref = torch.normal(mean=preferred_angle, std=sigma, size=(n_first, 1))
    opposite = (preferred_angle + 180) % 360
    around_opp = torch.normal(mean=opposite, std=sigma, size=(n_second, 1))

    return torch.cat([around_pref, around_opp], dim=0) % 360

# Example usage
# N_BATCH = 500
# preferred_angle = 45
# sigma = 45

# samples = continuous_bimodal_phases(N_BATCH, preferred_angle, sigma)

# plt.hist(samples.numpy(), bins='auto', density=True)
# plt.xlabel('Phase (degrees)')
# plt.ylabel('Probability Density')
# plt.title('Bimodal Distribution of Phases')
# plt.show()
```

Model
=====

``` ipython
# Stimulus protocol forwarded to Network(...) as keyword overrides.
# There are four stimulus epochs; epoch k is on between T_STIM_ON[k] and
# T_STIM_OFF[k].
# NOTE(review): epochs 1 and 3 use I0 = -10 and SIGMA0 = 0 and presumably act
# as suppression/erase inputs — TODO confirm against the model config.
kwargs = dict(
    DURATION=15.0,
    T_STIM_ON=[1.0, 5.0, 7.0, 11.0],
    T_STIM_OFF=[2.0, 6.0, 8.0, 12.0],
    I0=[1.0, -10.0, 1.0, -10.0],
    PHI0=[180.0, 180, 180, 180],
    SIGMA0=[1.0, 0.0, 1.0, 0.0],
)
```

``` ipython
REPO_ROOT = "/home/leon/models/NeuroFlame"
conf_name = "train_odr_EI.yml"
DEVICE = 'cuda:1'

# The seed selects the checkpoint '../models/odr/odr_<seed>.pth' loaded below.
# Set RANDOM_SEED = True to draw a fresh seed instead of the fixed default.
# The draw happens only when requested, so the global NumPy RNG is not
# consumed needlessly.
RANDOM_SEED = False
seed = np.random.randint(0, 1_000_000) if RANDOM_SEED else 1
print('seed', seed)
```

``` ipython
# Batch size 4 x 128 = 512. The stimulus protocol in `kwargs` overrides the
# values from the YAML config.
N_BATCH = 4 * 128
model = Network(
    conf_name,
    REPO_ROOT,
    VERBOSE=0,
    DEVICE=DEVICE,
    SEED=seed,
    N_BATCH=N_BATCH,
    **kwargs,
)
```

``` ipython
# Load the trained weights for this seed.
# map_location puts every tensor on DEVICE. This avoids depending on the GPU
# the checkpoint was saved from, which may not exist on this machine.
# NOTE: torch.load unpickles the file; only load trusted checkpoints.
model_state_dict = torch.load('../models/odr/odr_%d.pth' % seed, map_location=DEVICE)
model.load_state_dict(model_state_dict);
model.eval();
```

Batching Inputs
===============

``` ipython
print(N_BATCH)
# Propagate the batch size to the model, then draw an independent integer
# angle in [0, 360) for every trial and every stimulus epoch.
# Resulting shape: (N_BATCH, len(model.I0), 1).
model.N_BATCH = N_BATCH
model.PHI0 = torch.randint(low=0, high=360, size=(N_BATCH, len(model.I0), 1), device=DEVICE, dtype=torch.float)

# Build the feed-forward input (presumably from model.PHI0 — TODO confirm) and
# decode it along the last axis as a sanity check.
# NOTE(review): m0/m1 are overwritten by decode_bump on the rates further down.
ff_input = model.init_ff_input()
m0, m1, phase = decode_bump_torch(ff_input[..., model.slices[0]], axis=-1)
```

``` ipython
# Shape of the stimulus onset indices; later cells index them as [stimulus][trial].
print(model.start_indices.shape)
```

``` ipython
# Run the network on the batched input and copy the rates to NumPy.
# The trailing comment keeps an optional `[..., ::3]` slice (every third
# element of the last axis) disabled.
rates_tensor = model.forward(ff_input=ff_input)# [..., ::3]
# detach() drops the autograd graph so .numpy() is allowed.
rates = rates_tensor.cpu().detach().numpy()
print('rates', rates.shape)
```

``` ipython
# Decode the bump along the last axis of `rates`. phi is treated as radians
# by the 180/pi conversions below; m0/m1 are presumably the mean and
# first-harmonic amplitude — TODO confirm in src.decode.
m0, m1, phi = decode_bump(rates, axis=-1)
```

Results
=======

Rates
-----

``` ipython
fig, ax = plt.subplots(1, 3, figsize=[2.5*width, height])

# Left: rates of one random trial over time, with the first Na[0] units on the
# y axis relabelled as preferred locations 0-360 degrees.
idx = np.random.randint(0, model.N_BATCH)
ax[0].imshow(rates[idx].T, aspect='auto', cmap='jet', vmin=0, vmax=2, origin='lower', extent=[0, model.DURATION, 0, model.Na[0].cpu()])
ax[0].set_ylabel('Pref. Location (°)')
ax[0].set_yticks(np.linspace(0, model.Na[0].cpu(), 5), np.linspace(0, 360, 5).astype(int))
ax[0].set_xlabel('Time (s)')

# Middle: m1 for 8 randomly drawn trials (drawn with replacement), with the
# stimulus windows shaded.
xtime = np.linspace(0, model.DURATION, phi.shape[-1])
idx = np.random.randint(0, model.N_BATCH, 8)
ax[1].plot(xtime, m1[idx].T)
ax[1].set_ylabel('m1 (Hz)')
ax[1].set_xlabel('Time (s)')
add_vlines(model, ax[1])

# Right: decoded bump centre (radians -> degrees) for the same 8 trials.
ax[2].plot(xtime, phi[idx].T * 180 / np.pi, alpha=0.5)
ax[2].set_yticks(np.linspace(0, 360, 5).astype(int), np.linspace(0, 360, 5).astype(int))
ax[2].set_ylabel('Bump Center (°)')
ax[2].set_xlabel('Time (s)')
add_vlines(model, ax[2])
plt.show()
```

``` ipython
# Convert the stimulus angles with 180/pi. This implies the model stores PHI0
# in radians after init_ff_input, even though degrees were assigned — TODO
# confirm.
PHI0 = model.PHI0.cpu().detach().numpy() * 180.0 / np.pi
print(PHI0.shape)

# Spot-check one random trial. Draw from the whole batch (PHI0.shape[0]
# trials) rather than only the first 32. Compare the first stimulus angle
# with the decoded phase at time index `window_size`.
idx = np.random.randint(0, PHI0.shape[0])
print(PHI0[idx, 0, 0])
# Number of N_WINDOW-step windows between the end of the N_STEADY steps and
# the onset of stimulus 1; used as an index into phi's time axis.
window_size = int((model.N_STIM_ON[1]-model.N_STEADY) / model.N_WINDOW)
print(phi[idx, window_size] * 180 / np.pi)
```

Errors
------

``` ipython
def _wrap_to_deg(delta_rad):
    """Wrap a radian difference into [-pi, pi) and return it in degrees."""
    wrapped = (delta_rad + np.pi) % (2 * np.pi) - np.pi
    return wrapped * (180 / np.pi)


# Angle of stimulus 2 (degrees), used as the "current" target.
target_loc = PHI0[:, 2]

# Signed offset of stimulus 0 relative to stimulus 2, wrapped, in degrees.
rel_loc = _wrap_to_deg((PHI0[:, 0] - PHI0[:, 2]) * np.pi / 180.0)

# Decoded phase minus each stimulus angle at every time point, wrapped, in degrees.
error_curr = _wrap_to_deg(phi - PHI0[:, 2] * np.pi / 180.0)
error_prev = _wrap_to_deg(phi - PHI0[:, 0] * np.pi / 180.0)

# Axis 0 of `errors`: index 0 = error vs. stimulus 0, index 1 = error vs. stimulus 2.
errors = np.stack([error_prev, error_curr])
print(errors.shape, target_loc.shape, rel_loc.shape)
```

``` ipython
# Error time courses for 100 randomly drawn trials (axis 1 of `errors`), with
# the stimulus windows shaded.
time_points = np.linspace(0, model.DURATION, errors.shape[-1])
idx = np.random.randint(errors.shape[1], size=100)

fig, ax = plt.subplots(1, 2, figsize=[2*width, height])
# Left: error relative to stimulus 0.
ax[0].plot(time_points, errors[0][idx].T, alpha=.4)
add_vlines(model, ax[0])

ax[0].set_xlabel('t')
ax[0].set_ylabel('prev. error (°)')

# Right: error relative to stimulus 2.
ax[1].plot(time_points, errors[1][idx].T, alpha=.4)
add_vlines(model, ax[1])

ax[1].set_xlabel('t')
ax[1].set_ylabel('curr. error (°)')
plt.show()
```

``` ipython
print(phi.shape, PHI0.shape, model.start_indices.shape)
# Per-trial stimulus onset times: DT times the number of steps after the first
# N_STEADY steps (presumably in seconds). Indexed [stimulus][trial] below.
stim_start = (model.DT * (model.start_indices - model.N_STEADY)).cpu().numpy()
# The same onsets as integer indices into the decoded time axis, which has one
# entry per N_WINDOW steps.
# NOTE(review): the trailing -1 presumably picks the window just before the
# onset — confirm against how rates are recorded.
stim_start_idx = ((model.start_indices - model.N_STEADY) / model.N_WINDOW - 1).to(int).cpu().numpy()
# Compare stimulus-1 onsets with the configured T_STIM_ON.
print(stim_start[1][:5], model.T_STIM_ON)
print(stim_start_idx[1][:5])
```

``` ipython
time_points = np.linspace(0, model.DURATION, errors.shape[-1])
# `errors` stacks (prev, curr) on axis 0 and trials on axis 1, so a random
# trial must be drawn from axis 1. Drawing from axis 0 (size 2) would only
# ever show trial 0 or 1.
idx = np.random.randint(errors.shape[1])

fig, ax = plt.subplots(1, 2, figsize=[2*width, height])
# Left: error relative to stimulus 0 for the chosen trial.
ax[0].plot(time_points, errors[0][idx].T)
ax[0].set_xlabel('t')
ax[0].set_ylabel('prev. error (°)')

# Right: error relative to stimulus 2 for the same trial.
ax[1].plot(time_points, errors[1][idx].T)
ax[1].set_xlabel('t')
ax[1].set_ylabel('curr. error (°)')

# Dashed line at every stimulus onset of this trial
# (stim_start is indexed [stimulus][trial]).
for axis in ax:
    for k in range(len(stim_start)):
        axis.axvline(stim_start[k][idx], ls='--', c='k')
plt.show()
```

``` ipython
# Decoded error of every trial, read at the time index of a stimulus onset:
#   row 0: error vs. stimulus 0 (errors[0]) at the onset of stimulus 1
#   row 1: error vs. stimulus 2 (errors[1]) at the onset of stimulus 3
# Row 0 must use errors[0]: it is plotted as "Prev. Errors". errors[1] at that
# time measures the distance to the *next* stimulus, not a memory error.
# stim_start_idx is indexed [stimulus][trial].
end_point = np.array([
    [errors[err_row][trial][stim_start_idx[stim][trial]]
     for trial in range(errors.shape[1])]
    for stim, err_row in ((1, 0), (3, 1))
])
print(end_point.shape)
```

``` ipython
# Distributions of the target angles and of the two end-point error sets.
fig, ax = plt.subplots(1, 3, figsize=[3*width, height])
ax[0].hist(target_loc, bins='auto')
ax[0].set_xlabel('Targets (°)')

# end_point[0]: errors read at the onset of stimulus 1.
ax[1].hist(end_point[0], bins='auto')
ax[1].set_xlabel('Prev. Errors (°)')

# end_point[1]: errors vs. stimulus 2 read at the onset of stimulus 3.
ax[2].hist(end_point[1], bins='auto')
ax[2].set_xlabel('Curr. Errors (°)')
plt.show()
```

``` ipython
```

Biases
------

``` ipython
# One row per trial: target angle, offset of stimulus 0 relative to the
# target, and end-point error w.r.t. the target (all in degrees). The [:, -1]
# drops the trailing singleton axis of the angle arrays.
data = pd.DataFrame(dict(
    target_loc=target_loc[:, -1],
    rel_loc=rel_loc[:, -1],
    errors=end_point[1],
))
```

``` ipython
# Raw biases: per-trial error (dots) with the binned mean error overlaid.
fig, ax = plt.subplots(1, 2, figsize=[2*width, height])

n_bins=16
# Left: error vs. target location.
ax[0].plot(data['target_loc'], data['errors'], 'o', alpha=.1)
ax[0].set_xlabel('Target Loc. (°)')
ax[0].set_ylabel('Error (°)')

# Mean error in n_bins equal bins over [0, 360], plotted at the bin centres.
stt = binned_statistic(data['target_loc'], data['errors'], statistic='mean', bins=n_bins, range=[0, 360])
dstt = np.mean(np.diff(stt.bin_edges))
ax[0].plot(stt.bin_edges[:-1]+dstt/2,stt.statistic,'r')

ax[0].axhline(color='k', linestyle=":")

# Right: error vs. offset of stimulus 0 relative to the target.
ax[1].plot(data['rel_loc'], data['errors'], 'o', alpha=.1)
ax[1].set_xlabel('Rel. Loc. (°)')
ax[1].set_ylabel('Error (°)')

# Mean error in n_bins equal bins over [-180, 180], plotted at the bin centres.
stt = binned_statistic(data['rel_loc'], data['errors'], statistic='mean', bins=n_bins, range=[-180, 180])
dstt = np.mean(np.diff(stt.bin_edges))
ax[1].plot(stt.bin_edges[:-1]+dstt/2, stt.statistic, 'b')

# plt.savefig('../figures/figs/christos/uncorr_biases.svg', dpi=300)
plt.show()
```

``` ipython
n_bins = 16
angle_min = 0
angle_max = 360

# --- Remove the target-location bias ---------------------------------------
bin_edges = np.linspace(angle_min, angle_max, n_bins + 1)
data['bin_target'] = pd.cut(data['target_loc'], bins=bin_edges, include_lowest=True)

# observed=False keeps every bin, including empty ones, so per-bin results
# always have n_bins rows and line up with the bin centres. Passing it
# explicitly also pins the behaviour across pandas versions.
mean_errors_per_bin = data.groupby('bin_target', observed=False)['errors'].mean()
# Subtract each trial's target-bin mean error. This is a vectorised
# equivalent of a row-wise apply; trials that fall outside every bin get NaN
# instead of raising KeyError.
data['adjusted_errors'] = (
    data['errors']
    - data.groupby('bin_target', observed=False)['errors'].transform('mean')
)

# NOTE: by construction the adjusted errors average to ~0 within each target bin.
bin_target = data.groupby('bin_target', observed=False)['adjusted_errors'].agg(['mean', 'sem']).reset_index()
edges = bin_target['bin_target'].cat.categories
target_centers = (edges.left + edges.right) / 2

# --- Bin the corrected errors by relative location -------------------------
# rel_loc is wrapped into [-180, 180). Use fixed edges over that range instead
# of data-dependent bins. This makes bins comparable across runs and matches
# the range used with binned_statistic above.
rel_edges = np.linspace(-180, 180, n_bins + 1)
data['bin_rel'] = pd.cut(data['rel_loc'], bins=rel_edges, include_lowest=True)
bin_rel = data.groupby('bin_rel', observed=False)['adjusted_errors'].agg(['mean', 'sem']).reset_index()

edges = bin_rel['bin_rel'].cat.categories
centers = (edges.left + edges.right) / 2
```

``` ipython
fig, ax = plt.subplots(1, 2, figsize=[2*width, height])
# Left: corrected error vs. target location, mean ± SEM per target bin.
# The x-coordinates must be the target-bin centres (target_centers), not the
# relative-location centres (`centers`).
# NOTE: this curve is ~0 by construction, since the per-target-bin mean was
# subtracted.
ax[0].plot(target_centers, bin_target['mean'], 'b')
ax[0].fill_between(target_centers,
                   bin_target['mean'] - bin_target['sem'],
                   bin_target['mean'] + bin_target['sem'],
                   color='b', alpha=0.2)

ax[0].axhline(color='k', linestyle=":")
ax[0].set_xlabel('Target Loc. (°)')
ax[0].set_ylabel('Corrected Error (°)')

# Right: corrected error vs. relative location, mean ± SEM per bin.
ax[1].plot(centers, bin_rel['mean'], 'b')
ax[1].fill_between(centers,
                   bin_rel['mean'] - bin_rel['sem'],
                   bin_rel['mean'] + bin_rel['sem'],
                   color='b', alpha=0.2)

ax[1].axhline(color='k', linestyle=":")
ax[1].set_xlabel('Rel. Loc. (°)')
ax[1].set_ylabel('Corrected Error (°)')

plt.show()
```

``` ipython
#pkl_save(data, 'df_naive_%d' %seed, path="./figures/odr")
```

``` ipython
```