In [None]:
%load_ext autoreload
%autoreload 2

from scipy.io import loadmat
from pathlib import Path

from recording import Recording
import utils as utl

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

import pandas as pd
# from helpers_notebook import show_full_df

# global settings (read by the analysis cells below)
bin_size = 0.1 # passed as `bin_size` to the binning/PSTH helpers; presumably seconds -- TODO confirm
thresh_rate = 1 # firing rate in Hz
thresh_sw = 0.5 # spike width in ms
thresh_trials = 0.9 # passed as `thresh` to utl.filter_trials; the value is a fraction, presumably meaning 90 % of trials -- TODO confirm

# example

## load data

In [None]:
# load recording as Recording object
# (example session MK22_20230301, probe 1; `rec` is what the following example cells operate on)
p_mat = './data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe1.mat'
rec = Recording(p_mat)

In [None]:
# Binned spikes at resolution `bin_size`. `rec._assign_df` is handed the parquet path
# together with the function and kwargs that produce the table -- per the original note it
# loads from disk when possible (confirm the exact caching rule in Recording).
p_prq = rec._path_name(f'bin{bin_size}.parquet')
df = utl.rate_and_time(
    rec._assign_df(p_prq, rec._calculate_psth, dict(bin_size=bin_size)),
    bin_size,
)

In [None]:
# what to plot
unit = 1                # example unit
xlims = (1000, 1100)    # x-range for the zoomed-in view

# same unit twice: first without `xlims` (whole session), then restricted to `xlims`
utl.plot_unit(df, rec, bin_size, unit=unit)
utl.plot_unit(df, rec, bin_size, unit=unit, xlims=xlims)

## filter units and trials

In [None]:
# Step 1: units returned by the rate filter (`thresh_rate`) and by the spike-width
# filter (`thresh_sw`); `unts_ok` is their intersection.
unts_rate = utl.filter_rates(df, thresh_rate)
unts_sw = utl.filter_spike_width(rec.df_unt, thresh_sw)
unts_ok = unts_rate & unts_sw

# Step 2: run the trial filter on the rows of rec.df_unt that belong to those units
# (original note: "drop unit one at a time" -- presumably describes what
# utl.filter_trials does internally; confirm there).
m = rec.df_unt.loc[:, 'unit'].isin(unts_ok)
unts_range, trls_range = utl.filter_trials(rec.df_unt.loc[m], thresh=thresh_trials, plot=False)

# only keep units/trials meeting all criteria
u = unts_ok & unts_range
t = trls_range
print(f'{len(u)} units and {len(t)} trials survived')

# row mask over the binned table for the survivors (only consumed by the optional plot below)
m = df.loc[:, 'unit'].isin(u) & df.loc[:, 'trial'].isin(t)
# utl.plot_missing(df.loc[m], bin_size, vmax=5)

# save for later
rec.units, rec.trials = u, t

## get X and Y population

In [None]:
# generate features and targets (mean pre-cue firing rate gets subtracted)
# Use the notebook-wide `bin_size` rather than a second hard-coded 0.1, so this cell
# stays consistent with the binned table computed earlier if the setting is changed.
dfx, dfy = utl.select_data(rec, bin_size=bin_size)

# matrix form of both tables, then the bare arrays handed to the regressions
dfx_mat, dfy_mat = utl.get_matrix(dfx), utl.get_matrix(dfy)
X, Y = dfx_mat.values, dfy_mat.values

In [None]:
# Ridge regression of Y on X over 100 log-spaced penalties between 1e4 and 1e6
# alternative fit, currently disabled:
# lin_scores = utl.linear_regression(X, Y)
alpha_grid = np.logspace(4, 6, num=100)
ridge_mods = utl.ridge_regression(X, Y, alphas=alpha_grid)
ridge_mod = ridge_mods.best_estimator_  # best model reported by the search object
utl.plot_gridsearch(ridge_mods, 'param_mod__alpha')

In [None]:
# Reduced-rank regression. The third argument is the number of columns of Y
# (presumably the largest rank tried -- confirm in utl.reduced_rank_regression).
n_targets = Y.shape[1]
rr_mods = utl.reduced_rank_regression(X, Y, n_targets)
rr_mod = rr_mods.best_estimator_  # best model reported by the search object
utl.plot_gridsearch(rr_mods, 'param_mod__r', logscale=False)

### PSTH and all vs one

In [None]:
## all vs one

# penalty carried by the best ridge model (`ridge_mod`)
alpha = ridge_mod.get_params()['mod__alpha']

# One single-target ridge fit per column of Y at that fixed alpha. The score stored per
# column label is evaluated on the same X / y that were passed to the fit.
ridge_scores = {}
for unit_id, y_unit in zip(dfy_mat.columns, Y.T):
    fit = utl.ridge_regression(X, y_unit, alphas=[alpha]).best_estimator_
    ridge_scores[unit_id] = fit.score(X, y_unit)

# prediction of all targets by the multi-target model, converted back to a table
# using `dfy` as template
Y_pred = ridge_mod.predict(X)
dfy_pred = utl.matrix2df(Y_pred, dfy)

# plot true and predicted with score
utl.plot_psth(dfy, bin_size, df2=dfy_pred, scores=ridge_scores)

# ALM-ALM

In [None]:
# load (rec1: probe 2, rec2: probe 1 of the same session)
rec1 = Recording('./data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe2.mat')
rec2 = Recording('./data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe1.mat')

# process: same binning + unit/trial selection as in the example section, applied to
# both recordings; the surviving units/trials are stored on each Recording
for rec in (rec1, rec2):

    # bin spikes
    p_prq = rec._path_name(f'bin{bin_size}.parquet')
    binned = rec._assign_df(p_prq, rec._calculate_psth, dict(bin_size=bin_size))
    df = utl.rate_and_time(binned, bin_size)

    # units returned by both the rate and the spike-width filter
    unts_rate = utl.filter_rates(df, thresh_rate)
    unts_sw = utl.filter_spike_width(rec.df_unt, thresh_sw)
    unts_ok = unts_rate & unts_sw

    # trial filter on the rows of those units
    m = rec.df_unt.loc[:, 'unit'].isin(unts_ok)
    unts_range, trls_range = utl.filter_trials(rec.df_unt.loc[m], thresh=thresh_trials, plot=False)

    rec.units, rec.trials = unts_ok & unts_range, trls_range

# features and targets
X, Y = utl.get_xy(rec1, rec2=rec2, bin_size=bin_size)

In [None]:
# Ridge regression of Y on X over 100 log-spaced penalties between 1e-4 and 1e4
# alternative fit, currently disabled:
# lin_scores = utl.linear_regression(X, Y)
alpha_grid = np.logspace(-4, 4, num=100)
ridge_mods = utl.ridge_regression(X, Y, alphas=alpha_grid)
ridge_mod = ridge_mods.best_estimator_  # best model reported by the search object
utl.plot_gridsearch(ridge_mods, 'param_mod__alpha')

In [None]:
# Reduced-rank regression. The third argument is the number of columns of Y
# (presumably the largest rank tried -- confirm in utl.reduced_rank_regression).
n_targets = Y.shape[1]
rr_mods = utl.reduced_rank_regression(X, Y, n_targets)
rr_mod = rr_mods.best_estimator_  # best model reported by the search object
utl.plot_gridsearch(rr_mods, 'param_mod__r', logscale=False)

# ALM-STR

In [None]:
# imec0: STR
# (rec1 is the imec0 file, rec2 the other file of the same session)
rec2 = Recording('./data/zidan/ALM_STR/ZY78_20211015/ZY78_20211015NP_g0_JRC_units.mat')
rec1 = Recording('./data/zidan/ALM_STR/ZY78_20211015/ZY78_20211015NP_g0_imec0_JRC_units.mat')

# process: same binning + unit/trial selection as in the example section, applied to
# both recordings; the surviving units/trials are stored on each Recording
for rec in (rec1, rec2):

    # bin spikes
    p_prq = rec._path_name(f'bin{bin_size}.parquet')
    binned = rec._assign_df(p_prq, rec._calculate_psth, dict(bin_size=bin_size))
    df = utl.rate_and_time(binned, bin_size)

    # units returned by both the rate and the spike-width filter
    unts_rate = utl.filter_rates(df, thresh_rate)
    unts_sw = utl.filter_spike_width(rec.df_unt, thresh_sw)
    unts_ok = unts_rate & unts_sw

    # trial filter on the rows of those units
    m = rec.df_unt.loc[:, 'unit'].isin(unts_ok)
    unts_range, trls_range = utl.filter_trials(rec.df_unt.loc[m], thresh=thresh_trials, plot=False)

    rec.units, rec.trials = unts_ok & unts_range, trls_range

# features and targets
X, Y = utl.get_xy(rec1, rec2=rec2, bin_size=bin_size)

In [None]:
# Ridge regression of Y on X over 100 log-spaced penalties between 1e-3 and 1e3
# alternative fit, currently disabled:
# lin_scores = utl.linear_regression(X, Y)
alpha_grid = np.logspace(-3, 3, num=100)
ridge_mods = utl.ridge_regression(X, Y, alphas=alpha_grid)
ridge_mod = ridge_mods.best_estimator_  # best model reported by the search object
utl.plot_gridsearch(ridge_mods, 'param_mod__alpha')

In [None]:
# Reduced-rank regression. The third argument is the number of columns of Y
# (presumably the largest rank tried -- confirm in utl.reduced_rank_regression).
n_targets = Y.shape[1]
rr_mods = utl.reduced_rank_regression(X, Y, n_targets)
rr_mod = rr_mods.best_estimator_  # best model reported by the search object
utl.plot_gridsearch(rr_mods, 'param_mod__r', logscale=False)

# ALM-Thal

In [None]:
# imec0: Thal
# (rec1 is the imec0 file, rec2 the other file of the same session)
rec2 = Recording('./data/zidan/ALM_Thal/ZY113_20220617/ZY113_20220617_NPH2_g0_JRC_units.mat')
rec1 = Recording('./data/zidan/ALM_Thal/ZY113_20220617/ZY113_20220617_NPH2_g0_imec0_JRC_units.mat')

# process: same binning + unit/trial selection as in the example section, applied to
# both recordings; the surviving units/trials are stored on each Recording
for rec in (rec1, rec2):

    # bin spikes
    p_prq = rec._path_name(f'bin{bin_size}.parquet')
    binned = rec._assign_df(p_prq, rec._calculate_psth, dict(bin_size=bin_size))
    df = utl.rate_and_time(binned, bin_size)

    # units returned by both the rate and the spike-width filter
    unts_rate = utl.filter_rates(df, thresh_rate)
    unts_sw = utl.filter_spike_width(rec.df_unt, thresh_sw)
    unts_ok = unts_rate & unts_sw

    # trial filter on the rows of those units
    m = rec.df_unt.loc[:, 'unit'].isin(unts_ok)
    unts_range, trls_range = utl.filter_trials(rec.df_unt.loc[m], thresh=thresh_trials, plot=False)

    rec.units, rec.trials = unts_ok & unts_range, trls_range

# features and targets
X, Y = utl.get_xy(rec1, rec2=rec2, bin_size=bin_size)

In [None]:
# Ridge regression of Y on X over 100 log-spaced penalties between 1e-4 and 1e4
# alternative fit, currently disabled:
# lin_scores = utl.linear_regression(X, Y)
alpha_grid = np.logspace(-4, 4, num=100)
ridge_mods = utl.ridge_regression(X, Y, alphas=alpha_grid)
ridge_mod = ridge_mods.best_estimator_  # best model reported by the search object
utl.plot_gridsearch(ridge_mods, 'param_mod__alpha')

In [None]:
# Reduced-rank regression. The third argument is the number of columns of Y
# (presumably the largest rank tried -- confirm in utl.reduced_rank_regression).
n_targets = Y.shape[1]
rr_mods = utl.reduced_rank_regression(X, Y, n_targets)
rr_mod = rr_mods.best_estimator_  # best model reported by the search object
utl.plot_gridsearch(rr_mods, 'param_mod__r', logscale=False)

# local batch processing

In [None]:
# TODO
# Batch over every .mat file below ./data/zidan/: bin the spikes, smooth, and write a
# missing-data plot next to the recording.
# `bin_size` comes from the global settings cell; it is deliberately not re-assigned
# here so the notebook has a single source of truth for the bin width.
p_mats = Path('./data/zidan/').glob('**/*.mat')

for p_mat in p_mats:
    print(p_mat)
    rec = Recording(p_mat)

    # NOTE(review): this uses the same 'bin{bin_size}.parquet' name as the cells above but
    # passes an extra 'align': None -- if rec._assign_df loads an existing file without
    # checking the kwargs, whichever cell ran first decides the content. Confirm.
    p_out = rec._path_name(f'bin{bin_size}.parquet')
    df = rec._assign_df(p_out, rec._calculate_psth, {'align': None, 'bin_size': bin_size})

    df = utl.smooth_psth(df, 1, bin_size)
    p_png = rec._path_name(f'bin{bin_size}.png')
    utl.plot_missing(df, bin_size, path=p_png, vmax=5)
    


# remote batch processing

In [None]:
# batch process all recordings

p = Path(r'X:\Users\Zidan\DataForNico')

# `p` is a folder that gets globbed below, so the guard must be is_dir(): is_file() is
# never true for a directory, which meant the loop could not run. The guard still skips
# the cell quietly when the remote drive is not available.
if p.is_dir():
    # '**/*.mat' (with the dot) matches the pattern used by the local batch cell
    for p_mat in p.glob('**/*.mat'):

        # load recording and create missing dataframes
        print(p_mat)
        rec = Recording(p_mat)

        # plot PSTH and save it as PNG derived from rec.path_psth
        rec.plot_psth(xlims=(-1.5, 2), filter_size=50, path=rec.path_psth.with_suffix('.png'))

# other

## raw vs trial

In [None]:
# These names are already imported in the first cell; they are repeated here so the
# section can be run on its own. `plt` must stay bound to matplotlib.pyplot -- the
# previous `matplotlib.pylab` import silently rebound it to the discouraged pylab module.
from scipy.io import loadmat
import matplotlib.pyplot as plt
import seaborn as sns

# Same probe-2 file loaded twice: as a Recording (force_overwrite=True, presumably
# recomputing any cached tables -- confirm in Recording) and as the raw MATLAB structure.
p = './data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe2.mat'
rec = Recording(p, force_overwrite=True)

# struct_as_record=False -> structs are objects whose fields are read via vars(...)
m = loadmat(p, squeeze_me=True, struct_as_record=False)

In [None]:
# Long table of raw spike times: one row per spike, columns 'T' (value taken from each
# unit's 'RawSpikeTimes' field) and 'unit' (1-based position of the unit in m['unit']).
# The per-unit frames are collected in a list and concatenated once; growing a DataFrame
# with pd.concat inside the loop copies all previous rows on every iteration.
frames = []
# np.atleast_1d: with squeeze_me=True a single unit / a single spike can come back as a
# bare object / scalar instead of an array, which would break the loop or the
# DataFrame constructor.
for i, u in enumerate(np.atleast_1d(m['unit'])):
    t = np.atleast_1d(vars(u)['RawSpikeTimes'])
    frames.append(pd.DataFrame(data={'T': t, 'unit': i + 1}))

# pd.concat raises on an empty list, so fall back to an empty table with the same columns
df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame(columns=['T', 'unit'])

# raw spike times of unit 1 as a plain array
raw = df.loc[df.loc[:, 'unit'] == 1, 'T'].values

In [None]:
# fields of the first unit in the raw MATLAB structure
unit0 = vars(m['unit'][0])
spk = unit0['SpikeTimes']            # spike times of that unit
idx = unit0['Trial_idx_of_spike']    # per-spike trial index (used below as a mask on `spk`)

In [None]:
# 'onset' of every trial_info entry divided by 2.5e4
# (presumably samples -> seconds at 25 kHz -- TODO confirm)
onsets = [vars(trial)['onset'] for trial in m['trial_info']]
t0s = np.array(onsets) / 2.5e4

# behaviour fields stored with the first unit
first_unit = vars(m['unit'][0])
behavior = vars(first_unit['Behavior'])
# NOTE(review): `t_cue` is read from 'Sample_start' -- confirm this is the intended event
t_lck, t_cue = behavior['First_lick'], behavior['Sample_start']

In [None]:
# histogram of the difference t_cue - t_lck
cue_minus_lick = t_cue - t_lck
sns.histplot(cue_minus_lick)

In [None]:
# Raster of the first unit: trial-wise spike times shifted by the trial onset (one row
# per trial index 2..29), with the unit's raw spike train drawn on the row above them.
fig, ax = plt.subplots(figsize=(10, 4))

for i in range(2, 30):
    # spikes whose trial index equals i
    trial_spikes = spk[idx == i]

    # t0 = df_trl.loc[ df_trl.loc[:, 'trial'] == i ].loc[:, 'T_0'].item()
    # NOTE(review): lick/cue are read at i + 1 but the onset at i - 1 -- the differing
    # offsets are kept exactly as found; confirm they match the indexing of each array.
    lick = t_lck[i + 1]
    cue = t_cue[i + 1]
    t0 = t0s[i - 1]

    # keep spikes from 2 before the cue time to 2 after the lick time
    in_window = (trial_spikes > (cue - 2)) & (trial_spikes < (lick + 2))
    trial_spikes = trial_spikes[in_window]

    # dotted markers: trial onset (C0), cue (C1), lick (C2), all shifted by the onset
    ax.axvline(t0, c='C0', ls=':', lw=1)
    ax.axvline(t0 + cue, c='C1', ls=':', lw=1)
    ax.axvline(lick + t0, c='C2', ls=':', lw=1)

    ax.eventplot(trial_spikes + t0, lineoffsets=i, color=f'C{i}')

# raw spike times below 300, drawn one row above the last trial
# (`i` is the value left over from the loop, i.e. 29)
ax.eventplot(raw[raw < 300], lineoffsets=i + 1, ls='-')

ax.set_xlabel('times [s]')
ax.set_ylabel('trial index')
ax.set_xlim((0, 60))

## new spike processing

In [None]:
# probe-2 recording of session MK22_20230301, this time with force_overwrite=False
# (presumably reusing tables already on disk -- confirm in Recording)
p = './data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe2.mat'
rec = Recording(p, force_overwrite=False)

In [None]:
# Raster of unit 14 from rec.df_spk: one row per trial, y-axis clipped to 0..300
fig, ax = plt.subplots()

df = rec.df_spk.groupby('unit').get_group(14)
for trl, d in df.groupby('trial'):
    # column 't' holds the spike times drawn for this trial
    ax.eventplot(d['t'].values, lineoffsets=trl)

ax.set_xlim((-1, 4))
ax.set_ylim((0, 300))

In [None]:
# Raster of unit 14 from rec.df_spk: one row per trial, y-axis left unclipped
fig, ax = plt.subplots()

df = rec.df_spk.groupby('unit').get_group(14)
for trl, d in df.groupby('trial'):
    # column 't' holds the spike times drawn for this trial
    ax.eventplot(d['t'].values, lineoffsets=trl)

ax.set_xlim((-1, 4))
# ax.set_ylim((0, 300))

In [None]:
# PSTH of unit 14 using the tables currently held by `rec`.
# The two disabled lines below would recompute the spike table and the PSTH first
# (the next cell does exactly that).
# rec.df_spk = rec._load_spike_times()
# rec.df_psth = rec._calculate_psth()
rec.plot_psth(unts=[14], xlims=(-1, 6))

In [None]:
# Recompute both tables on `rec`, then plot unit 14 on a narrower x-range.
# The two progress messages are now distinct; previously both printed 'loaded', so the
# output could not tell which step had finished.
rec.df_spk = rec._load_spike_times()
print('spike times loaded')
rec.df_psth = rec._calculate_psth()
print('PSTH calculated')
rec.plot_psth(unts=[14], xlims=(-1, 1))