Notebook Settings
=================

``` ipython
# Automatically reload edited modules before each cell runs (autoreload mode 2).
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

# Default figure size used throughout: `width` wide, `height` set by the golden ratio (~0.618).
golden_ratio = (5**.5 - 1) / 2
width = 6
height = width * golden_ratio

# Render matplotlib figures inline as PNG images.
%matplotlib inline
%config InlineBackend.figure_format = 'png'

```

Imports
=======

``` ipython
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from tqdm import tqdm
from scipy.stats import binned_statistic
```

Helpers
=======

``` ipython
import pickle as pkl
import os

def pkl_save(obj, name, path="."):
    """Pickle ``obj`` to ``<path>/<name>.pkl``, creating ``path`` if needed.

    Parameters
    ----------
    obj : object
        Any picklable object.
    name : str
        File stem; ``.pkl`` is appended.
    path : str, default "."
        Destination directory (created if missing).
    """
    os.makedirs(path, exist_ok=True)
    destination = os.path.join(path, name + ".pkl")
    print("saving to", destination)
    # The context manager guarantees the file is flushed and closed.
    with open(destination, "wb") as f:
        pkl.dump(obj, f)


def pkl_load(name, path="."):
    """Load and return the object pickled in ``<path>/<name>.pkl``.

    NOTE: unpickling can execute arbitrary code; only load files you trust.

    Parameters
    ----------
    name : str
        File stem; ``.pkl`` is appended.
    path : str, default "."
        Directory containing the file.
    """
    source = os.path.join(path, name + ".pkl")
    print('loading from', source)
    with open(source, "rb") as f:
        return pkl.load(f)

```

``` ipython
T_STIM_ON = [1.0, 5.0]
T_STIM_OFF = [2.0, 6.0]

def add_vlines(ax=None):
    """Shade each stimulus window ``[T_STIM_ON[i], T_STIM_OFF[i]]`` on a plot.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw on. When omitted, ``plt.axvspan`` (the current axes) is used.
    """
    target = plt if ax is None else ax
    for i, t_on in enumerate(T_STIM_ON):
        target.axvspan(t_on, T_STIM_OFF[i], alpha=0.25)
```

``` ipython
def decode_bump_torch(signal, axis=-1, device=None, RET_TENSOR=1):
    """Decode a bump from the first Fourier mode of a signal along one axis.

    The signal is cast to float32 on ``device`` (default CPU). ``axis`` is moved
    last, and the first DFT coefficient is computed with a *positive* exponent,
    ``e^{+2*pi*i*n/N}``. With that convention, an input ``cos(theta_n - phi0)``
    decodes to ``+phi0``.

    Parameters
    ----------
    signal : array-like or torch.Tensor
    axis : int, default -1
        Axis holding the N samples that are Fourier-decomposed.
    device : torch device or str, optional
        Target device; ``None`` means CPU.
    RET_TENSOR : truthy, default 1
        If truthy, return torch tensors; otherwise detached CPU numpy arrays.

    Returns
    -------
    m0 : NaN-ignoring mean over the decoded axis.
    m1 : amplitude of the first mode, ``2 * |DFT_1 / N|``.
    phi : phase of the first mode, wrapped to ``[0, 2*pi)``.
    """
    target_device = device or 'cpu'
    if torch.is_tensor(signal):
        x = signal.to(dtype=torch.float32, device=target_device)
    else:
        x = torch.as_tensor(signal, dtype=torch.float32, device=target_device)

    # A 1-D signal already has its only axis last.
    if not (axis == -1 or x.ndim == 1):
        x = x.movedim(axis, -1)

    offset = torch.nanmean(x, dim=-1)

    n_units = x.shape[-1]
    harmonic = 1
    positions = torch.arange(n_units, device=x.device)
    kernel = torch.exp(2j * torch.pi * harmonic * positions / n_units)  # '+' sign convention

    first_mode = (x * kernel).sum(dim=-1) / n_units
    amplitude = 2 * torch.abs(first_mode)
    phase = torch.angle(first_mode) % (2 * torch.pi)

    if not RET_TENSOR:
        return tuple(t.cpu().detach().numpy() for t in (offset, amplitude, phase))
    return offset, amplitude, phase
```

``` ipython
def _wrap_to_degrees(angle):
    """Wrap an angle in radians to [-pi, pi) and convert it to degrees."""
    return ((angle + np.pi) % (2 * np.pi) - np.pi) * (180 / np.pi)


def get_error_curr_prev(phi, curr, prev, reference):
    """Compute target/relative/reference locations and decoding errors, in degrees.

    Inputs may be numpy arrays or CPU torch tensors; they are converted with
    ``np.asarray`` so all arithmetic is done in numpy.

    Parameters
    ----------
    phi : array-like
        Decoded phase in radians; broadcast against ``curr``.
    curr : array-like
        Current stimulus location in radians.
    prev : array-like
        Previous stimulus location in radians (broadcast against ``curr``).
    reference : array-like
        Reference location in radians. A leading axis is prepended before
        broadcasting against ``curr``.

    Returns
    -------
    target_loc : ndarray
        ``np.vstack(curr in degrees)``; the first two axes are stacked row-wise. Not wrapped.
    rel_loc : ndarray
        ``np.vstack`` of ``prev - curr`` wrapped to [-180, 180).
    ref_loc : ndarray
        ``np.vstack`` of ``reference - curr`` wrapped to [-180, 180).
    error_curr : ndarray
        ``phi - curr`` wrapped to [-180, 180), with the broadcast shape of ``phi`` and ``curr``.
    """
    phi, curr, prev, reference = (np.asarray(a) for a in (phi, curr, prev, reference))

    target_loc = curr * 180.0 / np.pi
    rel_loc = _wrap_to_degrees(prev - curr)
    ref_loc = _wrap_to_degrees(reference[np.newaxis] - curr)
    error_curr = _wrap_to_degrees(phi - curr)

    return np.vstack(target_loc), np.vstack(rel_loc), np.vstack(ref_loc), error_curr
```

``` ipython
def get_end_point(errors, stim_start_idx, stim=1, offset=-1):
    """Sample each session's error trace at a time step relative to a stimulus onset.

    By default this takes, for every session ``k``, the time step just before the
    onset of stimulus 1, i.e. ``errors[:, k, stim_start_idx[1][k] - 1]``.

    Parameters
    ----------
    errors : array-like, shape (n_trials, n_sessions, n_time)
    stim_start_idx : array-like, indexed as ``stim_start_idx[stimulus][session]``
        Onset time index of each stimulus in each session. Only the first
        ``n_sessions`` entries of the selected row are used.
    stim : int, default 1
        Which stimulus onset to align to.
    offset : int, default -1
        Time-step offset added to the onset index.

    Returns
    -------
    ndarray, shape (n_trials * n_sessions, 1)
        Trial-major flattening: row ``t * n_sessions + k`` is trial ``t``, session ``k``.
    """
    errors = np.asarray(errors)
    n_sessions = errors.shape[1]
    time_idx = np.asarray(stim_start_idx[stim])[:n_sessions] + offset
    # Advanced indices on adjacent axes 1 and 2 pick errors[:, k, time_idx[k]] for every k,
    # giving shape (n_trials, n_sessions).
    end_point = errors[:, np.arange(n_sessions), time_idx]
    return end_point.reshape(-1, 1)
```

``` ipython
def _interval_centers(intervals):
    """Return the midpoints of a sequence of ``pd.Interval`` objects as a float Index."""
    return pd.Index([iv.mid for iv in intervals], dtype=float)


def get_correct_error(n_bins, df, error_type='rel_loc', thresh=25):
    """Bin decoding errors by relative (serial) or reference location.

    Parameters
    ----------
    n_bins : int
        Number of bins for target location, for ``df[error_type]`` and for its absolute value.
    df : pandas.DataFrame
        Needs columns ``'errors'``, ``'target_loc'`` (binned over [0, 360]) and ``error_type``.
    error_type : str, default 'rel_loc'
        Column to bin by. For ``'rel_loc'`` the mean error per target-location bin
        is subtracted first. For any other column the raw errors are used, and the
        flipped analysis keeps only ``|df[error_type]| <= 90``.
    thresh : float or None, default 25
        Keep only rows with ``-thresh <= errors <= thresh``; ``None`` keeps all.

    Returns
    -------
    centers : pandas.Index
        Midpoints of the non-empty equal-width bins of ``df[error_type]``, one per row of ``bin_error``.
    bin_error : pandas.DataFrame
        Columns ``'bin_error'``, ``'mean'``, ``'sem'`` of the adjusted errors per non-empty bin.
    centers_abs : pandas.Index
        Midpoints of the non-empty bins of ``|df[error_type]|``, one per row of ``bin_error_abs``.
    bin_error_abs : pandas.DataFrame
        Columns ``'bin_error_abs'``, ``'mean'``, ``'sem'`` of the adjusted errors
        multiplied by ``sign(df[error_type])``.
    """
    # 1. Threshold errors.
    if thresh is not None:
        data = df[(df['errors'] >= -thresh) & (df['errors'] <= thresh)].copy()
    else:
        data = df.copy()

    # 2. Bin target locations and compute the mean error in each target bin.
    bin_edges = np.linspace(0, 360, n_bins + 1)
    data['bin_target'] = pd.cut(data['target_loc'], bins=bin_edges, include_lowest=True)
    mean_errors_per_bin = data.groupby('bin_target', observed=True)['errors'].mean()

    # 3. For serial bias, remove the target-location-dependent mean error.
    if error_type == 'rel_loc':
        data['adjusted_errors'] = data['errors'] - data['bin_target'].map(mean_errors_per_bin).astype(float)
    else:
        data['adjusted_errors'] = data['errors']

    # 4. Signed curve: equal-width bins over the observed range of error_type.
    data['bin_error'] = pd.cut(data[error_type], bins=n_bins)
    bin_error = data.groupby('bin_error', observed=True)['adjusted_errors'].agg(['mean', 'sem']).reset_index()
    # observed=True drops empty bins, so the centers must come from the surviving rows
    # (not from all cut categories) to stay aligned with bin_error.
    centers = _interval_centers(bin_error['bin_error'])

    # 5. Flipped curve: bin |error_type| and flip errors so both directions share one sign.
    data['error_abs'] = np.abs(data[error_type])
    if error_type != 'rel_loc':
        # Only |ref_loc| in [0, 90] is analysed; copy so later assignments act on a frame, not a slice.
        data = data[data['error_abs'] <= 90.0].copy()
    data['bin_error_abs'] = pd.cut(data['error_abs'], bins=n_bins, include_lowest=True)
    data['adjusted_errors_abs'] = data['adjusted_errors'] * np.sign(data[error_type])

    bin_error_abs = data.groupby('bin_error_abs', observed=True)['adjusted_errors_abs'].agg(['mean', 'sem']).reset_index()
    centers_abs = _interval_centers(bin_error_abs['bin_error_abs'])

    return centers, bin_error, centers_abs, bin_error_abs
```

The Data
========

The data contain 100 consecutive trials for each of 128 sessions (see the
printed shapes below), with 750 neurons. Each trial lasts 6 s: the tuned
stimulus is on from 1 s to 2 s, and an inhibitory stimulus (which erases
the bump) is on from 5 s to 6 s.

``` ipython
# NOTE: these are pickle files; only load data you trust.
ref_list = pkl_load('ref_list') # reference location per session, (n_session, 1) in radians
prev_list = pkl_load('prev_list') # previous stimulus location (n_trials, n_session, 1) in radians
curr_list = pkl_load('curr_list') # current stimulus location (n_trials, n_session, 1) in radians
rates_list = pkl_load('rates_list') # rates of the simulations (n_trials, n_session, n_time, n_neurons) in Hz
```

``` ipython
# Sanity check: shapes of the reference, previous, current and rate arrays.
print('\n', ref_list.shape, prev_list.shape, curr_list.shape, rates_list.shape)
```

``` example
 torch.Size([128, 1]) (100, 128, 1) (100, 128, 1) (100, 128, 61, 750)
```

``` ipython
# Onset time index of each stimulus, indexed as stim_start_idx[stimulus][session] (see get_end_point).
stim_start_idx = pkl_load('stim_start_idx')
print(stim_start_idx.shape) # shape (N_STIMULI, N_SESSION); here 2 stimuli per trial, but their timing may vary across trials/sessions
```

``` ipython
# NOTE(review): assumes 10 time samples per second (dt = 0.1 s). With 61 samples this gives
# DURATION = 6.1, while the trials are described as lasting 6 s (samples at 0, 0.1, ..., 6.0?) - confirm.
DURATION = rates_list.shape[2] / 10
N_NEURONS = rates_list.shape[-1]  # last axis of rates_list
N_TRIALS = rates_list.shape[0]    # first axis: trials
N_SESSION = rates_list.shape[1]   # second axis: sessions
```

``` ipython
# Histograms of stimulus locations (session 0, all trials) and of reference
# locations (all sessions), converted from radians to degrees.
fig, ax = plt.subplots(1, 2, figsize=[2*width, height])

ax[0].hist(curr_list[:, 0, 0] * 180 / np.pi, bins=20, histtype='step')
ax[0].set_xlabel('Stimulus Loc. (°)')
ax[0].set_ylabel('Count')

ax[1].hist(ref_list[:, 0] * 180.0 / np.pi, bins=20, histtype='step')
ax[1].set_xlabel('Reference Loc. (°)')
ax[1].set_ylabel('Count')

plt.show()
```

``` ipython
# Population activity for the first `n_trials` trials of one randomly chosen
# session, concatenated in time (neurons on the y-axis, colour limits 0 to 5).
fig, ax = plt.subplots(1, 1, figsize=[2.5*width, height])

n_trials = 10
idx = np.random.randint(0, N_SESSION)  # random session index
# (n_trials, n_time, n_neurons) -> vstack -> (n_trials * n_time, n_neurons) -> .T -> (n_neurons, time)
rates = np.vstack(rates_list[:n_trials, idx]).T

plt.imshow(rates, aspect='auto', cmap='jet', vmin=0, vmax=5, origin='lower', extent=[0, n_trials * DURATION, 0, N_NEURONS])
plt.ylabel('Pref. Location (°)')
# Relabel the neuron-index axis as preferred location in degrees (0 to 360).
plt.yticks(np.linspace(0, N_NEURONS, 5), np.linspace(0, 360, 5).astype(int))
plt.xlabel('Time (s)')

plt.show()
```

``` ipython
```

Errors
======

``` ipython
# Number of trials in each half of a session (first half vs. last half).
n_half = N_TRIALS // 2
```

``` ipython
# Split current/previous stimulus locations into the first and last n_half trials.
curr_ini =  curr_list[:n_half]
curr_last = curr_list[-n_half:]

prev_ini =  prev_list[:n_half]
prev_last = prev_list[-n_half:]
print(curr_ini.shape, prev_ini.shape)
```

``` ipython
# Decoded bump phase (radians, in [0, 2*pi)) for every trial/session/time step of the first half.
# decode_bump_torch returns torch tensors by default (RET_TENSOR=1).
_, _, phi_ini = decode_bump_torch(rates_list[:n_half, ...], axis=-1)
print(phi_ini.shape)
```

``` ipython
# Same decoding for the last n_half trials.
_, _, phi_last = decode_bump_torch(rates_list[-n_half:, ...], axis=-1)
print(phi_last.shape)
```

``` ipython
# Target, relative (prev - curr) and reference (ref - curr) locations plus decoding
# errors over time, all in degrees, for each half.
targ_ini, rel_ini, ref_ini, errors_ini = get_error_curr_prev(phi_ini, curr_ini, prev_ini, ref_list)
targ_last, rel_last, ref_last, errors_last = get_error_curr_prev(phi_last, curr_last, prev_last, ref_list)
print(targ_ini.shape, rel_ini.shape, ref_ini.shape, errors_ini.shape)
```

``` ipython
# Shape check for the first-half reference/relative locations and errors.
print(ref_ini.shape, rel_ini.shape, errors_ini.shape)
```

``` ipython
# End-point error: the error at the time step just before the onset of stimulus 1
# (stim_start_idx[1][k] - 1 in get_end_point), one row per (trial, session).
end_point_ini = get_end_point(errors_ini, stim_start_idx)
end_point_last = get_end_point(errors_last, stim_start_idx)
print(end_point_ini.shape, end_point_last.shape)
```

``` ipython
# Distribution of end-point errors in the first and second halves.
fig, ax = plt.subplots(1, 2, figsize=[2*width, height])

ax[0].hist(end_point_ini[:, 0], bins=30)
ax[0].set_xlabel('Errors First Half(°)')

ax[1].hist(end_point_last[:, 0], bins=30)
ax[1].set_xlabel('Errors Second Half(°)')
plt.show()
```

``` ipython
# Error time courses of the first trial (index 0) of each half, for 100 randomly
# drawn sessions (sampled with replacement), with stimulus windows shaded.
time_points = np.linspace(0, DURATION, errors_ini.shape[-1])
idx = np.random.randint(errors_ini.shape[1], size=100)

fig, ax = plt.subplots(1, 2, figsize=[2*width, height])
ax[0].plot(time_points, errors_ini[0][idx].T, alpha=.4)
add_vlines(ax[0])

ax[0].set_xlabel('t')
ax[0].set_ylabel('Error first Half(°)')

ax[1].plot(time_points, errors_last[0][idx].T, alpha=.4)
add_vlines(ax[1])

ax[1].set_xlabel('t')
ax[1].set_ylabel('Error 2nd Half (°)')
plt.show()
```

``` ipython
```

Serial Bias Curves First/Second Half
====================================

``` ipython
# All four arrays should have one row per (trial, session) of the first half.
print(targ_ini.shape, rel_ini.shape, ref_ini.shape, end_point_ini.shape)
```

``` ipython
n_bins = 16
# One row per (trial, session): locations in degrees and the end-point error.
data_ini = pd.DataFrame({'target_loc': targ_ini[:, -1], 'rel_loc': rel_ini[:, -1], 'ref_loc': ref_ini[:, -1], 'errors': end_point_ini[:, 0]})
data_last = pd.DataFrame({'target_loc': targ_last[:, -1], 'rel_loc': rel_last[:, -1], 'ref_loc': ref_last[:, -1], 'errors': end_point_last[:, 0]})
```

``` ipython
# Serial-bias curves (error vs. previous-minus-current location) for each half.
centers_ini, bin_rel_ini, centers_abs_ini, bin_rel_abs_ini = get_correct_error(n_bins, data_ini)
centers_last, bin_rel_last, centers_abs_last, bin_rel_abs_last = get_correct_error(n_bins, data_last)
```

``` ipython
# Left: signed serial-bias curve (mean +/- SEM). Right: flipped curve vs. |Rel. Loc.|.
fig, ax = plt.subplots(1, 2, figsize=[2*width, height])

ax[0].plot(centers_ini, bin_rel_ini['mean'], 'r', label='First half')
ax[0].fill_between(centers_ini, bin_rel_ini['mean'] - bin_rel_ini['sem'], bin_rel_ini['mean'] + bin_rel_ini['sem'], color='r', alpha=0.2)

ax[0].plot(centers_last, bin_rel_last['mean'], 'b', label='Last half')
ax[0].fill_between(centers_last, bin_rel_last['mean'] - bin_rel_last['sem'], bin_rel_last['mean'] + bin_rel_last['sem'], color='b', alpha=0.2)

ax[0].axhline(0, color='k', linestyle=":")
ax[0].set_xlabel('Rel. Loc. (°)')
ax[0].set_ylabel('Error (°)')
ax[0].set_xticks(np.linspace(-180, 180, 5))

ax[1].plot(centers_abs_ini, bin_rel_abs_ini['mean'], 'r', label='First half')
ax[1].fill_between(centers_abs_ini, bin_rel_abs_ini['mean'] - bin_rel_abs_ini['sem'], bin_rel_abs_ini['mean'] + bin_rel_abs_ini['sem'], color='r', alpha=0.2)

ax[1].plot(centers_abs_last, bin_rel_abs_last['mean'], 'b', label='Last half')
ax[1].fill_between(centers_abs_last, bin_rel_abs_last['mean'] - bin_rel_abs_last['sem'], bin_rel_abs_last['mean'] + bin_rel_abs_last['sem'], color='b', alpha=0.2)

ax[1].axhline(0, color='k', linestyle=":")
# x-axis here is |Rel. Loc.| (0 to 180).
ax[1].set_xlabel('Rel. Loc. (°)')
ax[1].set_ylabel('Flip. Error (°)')
ax[1].legend(fontsize=12)
ax[1].set_xticks(np.linspace(0, 180, 3))

plt.tight_layout()
plt.show()
```

``` ipython
# Reference-bias curves (error vs. reference-minus-current location) for each half.
centers_ref_ini, bin_ref_ini, centers_ref_abs_ini, bin_ref_abs_ini = get_correct_error(n_bins, data_ini, error_type='ref_loc')
centers_ref_last, bin_ref_last, centers_ref_abs_last, bin_ref_abs_last = get_correct_error(n_bins, data_last, error_type='ref_loc')
```

``` ipython
# Left: signed reference-bias curve (mean +/- SEM). Right: flipped curve vs. |Ref. Loc.| in [0, 90].
fig, ax = plt.subplots(1, 2, figsize=[2*width, height])

ax[0].plot(centers_ref_ini, bin_ref_ini['mean'], 'r', label='First half')
ax[0].fill_between(centers_ref_ini, bin_ref_ini['mean'] - bin_ref_ini['sem'], bin_ref_ini['mean'] + bin_ref_ini['sem'], color='r', alpha=0.2)

ax[0].plot(centers_ref_last, bin_ref_last['mean'], 'b', label='Last half')
ax[0].fill_between(centers_ref_last, bin_ref_last['mean'] - bin_ref_last['sem'], bin_ref_last['mean'] + bin_ref_last['sem'], color='b', alpha=0.2)

ax[0].axhline(0, color='k', linestyle=":")
ax[0].set_xlabel('Ref. Loc. (°)')
ax[0].set_ylabel('Error (°)')
ax[0].set_xticks(np.linspace(-180, 180, 5))

ax[1].plot(centers_ref_abs_ini, bin_ref_abs_ini['mean'], 'r', label='First half')
ax[1].fill_between(centers_ref_abs_ini, bin_ref_abs_ini['mean'] - bin_ref_abs_ini['sem'], bin_ref_abs_ini['mean'] + bin_ref_abs_ini['sem'], color='r', alpha=0.2)

ax[1].plot(centers_ref_abs_last, bin_ref_abs_last['mean'], 'b', label='Last half')
ax[1].fill_between(centers_ref_abs_last, bin_ref_abs_last['mean'] - bin_ref_abs_last['sem'], bin_ref_abs_last['mean'] + bin_ref_abs_last['sem'], color='b', alpha=0.2)

ax[1].axhline(0, color='k', linestyle=":")
ax[1].set_xlabel('Ref. Loc. (°)')
ax[1].set_ylabel('Flip. Error (°)')
ax[1].legend(fontsize=12)
ax[1].set_xticks(np.linspace(0, 90, 3))

plt.tight_layout()
plt.show()
```

``` ipython
```

Bias Evolution along a session
==============================

``` ipython
# Decoded bump phase (radians) for every trial, session and time step.
_, _, phi_list = decode_bump_torch(rates_list, axis=-1)
print(phi_list.shape)
```

``` ipython
def _peak_bias(binned):
    """Return [mean, sem] of the bin with the largest |mean| (NaNs if no bin is populated)."""
    if binned.empty:
        return [np.nan, np.nan]
    k = int(np.argmax(np.abs(binned['mean'].to_numpy())))
    # argmax is positional, so index positionally as well.
    return [binned['mean'].iloc[k], binned['sem'].iloc[k]]


# Darker blue = later trial within the session.
cmap = plt.get_cmap('Blues')
colors = [cmap( (i+1) / phi_list.shape[0] ) for i in range(phi_list.shape[0])]

n_bins = 8

serial_list = []
ref_bias_list = []

fig, ax = plt.subplots(1, 2, figsize=[2*width, height])

for i in range(phi_list.shape[0]): # trial by trial: one trial index across all sessions
    targ_trial, rel_trial, ref_trial, errors_trial = get_error_curr_prev(phi_list[i, np.newaxis], curr_list[i, np.newaxis], prev_list[i, np.newaxis], ref_list)

    end_point_trial = get_end_point(errors_trial, stim_start_idx)

    data = pd.DataFrame({'target_loc': targ_trial[:, -1], 'rel_loc': rel_trial[:, -1], 'ref_loc': ref_trial[:, -1], 'errors': end_point_trial[:, 0]})

    centers, bin_rel, centers_abs, bin_rel_abs = get_correct_error(n_bins, data)
    centers_ref, bin_ref, centers_ref_abs, bin_ref_abs = get_correct_error(n_bins, data, error_type='ref_loc')

    ax[0].plot(centers, bin_rel['mean'], color=colors[i], alpha=1)
    ax[1].plot(centers_ref, bin_ref['mean'], color=colors[i], alpha=1)

    # Bias strength of this trial = peak of the flipped curve, with its SEM.
    serial_list.append(_peak_bias(bin_rel_abs))
    ref_bias_list.append(_peak_bias(bin_ref_abs))

# Axis decoration is done once, outside the per-trial loop.
for a, xlabel in zip(ax, ['Rel. Loc. (°)', 'Ref. Loc. (°)']):
    a.axhline(0, ls='--', color='k')
    a.set_xlabel(xlabel)
    a.set_ylabel('Error (°)')
    a.set_xticks(np.linspace(-180, 180, 5))

# Shape (2, n_trials): row 0 is the peak mean, row 1 its SEM.
serial_list = np.array(serial_list).T
ref_bias_list = np.array(ref_bias_list).T
print(serial_list.shape)
plt.show()
```

``` ipython
from scipy.ndimage import gaussian_filter1d

# 1-based trial number for each column of serial_list / ref_bias_list.
# (np.linspace(0, N_TRIALS, n) would space points by N_TRIALS / (n - 1), off the trial grid.)
xtrial = np.arange(1, serial_list.shape[1] + 1)

fig, ax = plt.subplots(1, 2, figsize=[2*width, height])

s0 = 5  # Gaussian smoothing width, in trials
# Smoothed peak bias +/- smoothed SEM band.
ax[0].plot(xtrial, gaussian_filter1d(serial_list[0], s0), '-')
ax[0].fill_between(xtrial, gaussian_filter1d(serial_list[0] - serial_list[1], s0), gaussian_filter1d(serial_list[0] + serial_list[1], s0), color='b', alpha=0.2)
ax[0].axhline(0, ls='--', color='k')

ax[0].set_xlabel('Trial #')
ax[0].set_ylabel('Serial Bias (°)')

ax[1].plot(xtrial, gaussian_filter1d(ref_bias_list[0], s0), '-')
ax[1].fill_between(xtrial, gaussian_filter1d(ref_bias_list[0] - ref_bias_list[1], s0), gaussian_filter1d(ref_bias_list[0] + ref_bias_list[1], s0), color='b', alpha=0.2)
ax[1].axhline(0, ls='--', color='k')

ax[1].set_xlabel('Trial #')
ax[1].set_ylabel('Ref. Bias (°)')

plt.show()
```

``` ipython
```