In [None]:
%load_ext autoreload
%autoreload 2

from scipy.io import loadmat
from pathlib import Path

from recording import Recording
import utils as utl

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

import pandas as pd
# from helpers_notebook import show_full_df

# within ALM

In [None]:
# example recording for the within-ALM analysis (ALM_ALM session MK22_20230301, probe 1)
p_mat = './data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe1.mat'
rec = Recording(p_mat)

## Select units based on trial ranges and firing rate

In [None]:
# PSTH parameters
bin_size = 0.1
filter_sigma = 1

# binned PSTH for this bin size (via the recording's parquet path), then smoothed
p_prq = rec._path_name(f'bin{bin_size}.parquet')
df = utl.smooth_psth(
    rec._assign_df(p_prq, rec._calculate_psth, {'bin_size': bin_size}),
    filter_sigma,
    bin_size,
)

In [None]:
# overview of every unit and trial, before any unit/trial selection
utl.plot_missing(df, bin_size, vmax=5)

In [None]:
# select units and trials from the recording's trial ranges (with diagnostic plot)
unts_range, trls_range = utl.filter_trial_ranges(rec, plot=True)

In [None]:
# rows whose unit AND trial survived the trial-range selection, re-plotted
m = df['unit'].isin(unts_range) & df['trial'].isin(trls_range)
utl.plot_missing(df.loc[m], bin_size, vmax=5)

In [None]:
# rate-based unit selection on the range-filtered rows (threshold argument 2)
unts_rate = utl.filter_rates(df.loc[m], 2, plot=True)

In [None]:
# final selection: units passing both filters, trials passing the range filter
u = unts_range & unts_rate
t = trls_range
m = df.loc[:, 'unit'].isin(u) & df.loc[:, 'trial'].isin(t)
print(f'{len(u)} units and {len(t)} trials survived')

# store the selection on the recording object
rec.units = u
rec.trials = t

utl.plot_missing(df.loc[m], bin_size, vmax=5)

## visualize bin size and filter width

In [None]:
unit = 1
xlims = (1000, 1100)  # zoom window, in seconds

# full session first, then the zoomed-in window
utl.plot_unit(df, rec, bin_size, unit=unit)
utl.plot_unit(df, rec, bin_size, unit=unit, xlims=xlims)

In [None]:
# visualize different filter widths (smoothing sigma) at a fixed bin size
bin_size, unit = .1, 1
filter_sigmas = [2, 1, .5, .1]
xlims = (1000, 1100)  # zoom window [s]; defined here so the cell does not depend on run order

# the binned PSTH does not depend on the filter width: load it once, outside the loop
p_prq = rec._path_name(f'bin{bin_size}.parquet')
df_binned = rec._assign_df(p_prq, rec._calculate_psth, {'bin_size': bin_size})

fig, ax = plt.subplots(figsize=(20, 5))

# plot one smoothed trace per filter sigma; `sigma` / `df_smooth` keep the global
# `filter_sigma` and `df` of the unit-selection cells untouched
for sigma in filter_sigmas:
    # smooth a copy so every sigma starts from the same unsmoothed PSTH,
    # even if smooth_psth were to modify its input in place
    df_smooth = utl.smooth_psth(df_binned.copy(), sigma, bin_size)
    d = df_smooth.groupby('unit').get_group(unit)

    # bin index -> time in seconds
    x = d.loc[:, 'bins_'].values * bin_size
    y = d.loc[:, 'fr'].values

    ax.plot(x, y, label=f'{sigma}')

ax.margins(x=0)
ax.set_xlabel('time [s]')
ax.set_ylabel('firing rate [Hz]')
ax.legend(title='filter sigma [s]')
ax.set_xlim(xlims)
ax.set_title(f'unit {unit} | bin size {bin_size}')

fig.tight_layout()

In [None]:
# visualize different bin sizes with negligible smoothing
filter_sigma, unit = .00001, 1
bin_sizes = [.5, .1, .05]
xlims = (1000, 1100)  # zoom window [s]; defined here so the cell does not depend on run order
fig, ax = plt.subplots(figsize=(20, 5))

# one trace per bin size; the loop variable `bs` and frame `df_bin` keep the global
# `bin_size` and `df` untouched for later cells
for bs in bin_sizes:
    p_prq = rec._path_name(f'bin{bs}.parquet')
    df_bin = rec._assign_df(p_prq, rec._calculate_psth, {'bin_size': bs})
    df_bin = utl.smooth_psth(df_bin, filter_sigma, bs)

    d = df_bin.groupby('unit').get_group(unit)

    # bin index -> time in seconds
    x = d.loc[:, 'bins_'].values * bs
    y = d.loc[:, 'fr'].values

    ax.plot(x, y, label=f'{bs}')

ax.margins(x=0)
ax.set_xlabel('time [s]')
ax.set_ylabel('firing rate [Hz]')
ax.legend(title='bin size [s]')
ax.set_xlim(xlims)
ax.set_title(f'unit {unit} | filter sigma {filter_sigma} s')

fig.tight_layout()

In [None]:
# visualize different bin sizes with a 0.1 s filter width
filter_sigma, unit = .1, 1
bin_sizes = [.5, .1, .05]
xlims = (1000, 1100)  # zoom window [s]; defined here so the cell does not depend on run order
fig, ax = plt.subplots(figsize=(20, 5))

# one trace per bin size; the loop variable `bs` and frame `df_bin` keep the global
# `bin_size` and `df` untouched for later cells
for bs in bin_sizes:
    p_prq = rec._path_name(f'bin{bs}.parquet')
    df_bin = rec._assign_df(p_prq, rec._calculate_psth, {'bin_size': bs})
    df_bin = utl.smooth_psth(df_bin, filter_sigma, bs)

    d = df_bin.groupby('unit').get_group(unit)

    # bin index -> time in seconds
    x = d.loc[:, 'bins_'].values * bs
    y = d.loc[:, 'fr'].values

    ax.plot(x, y, label=f'{bs}')

ax.margins(x=0)
ax.set_xlabel('time [s]')
ax.set_ylabel('firing rate [Hz]')
ax.legend(title='bin size [s]')
ax.set_xlim(xlims)
ax.set_title(f'unit {unit} | filter sigma {filter_sigma} s')

fig.tight_layout()

## Get X and Y populations

In [None]:
# X / Y population matrices for this recording (bin 0.1, negligible smoothing)
X, Y = utl.get_xy(rec, bin_size=.1, filter_sigma=.0001)
# compare the rate distributions of X and Y
utl.plot_rate_dist(X, Y)

In [None]:
# ridge regression, alpha searched over 100 log-spaced values in [1, 100]
ridge_mods = utl.ridge_regression(X, Y, alphas=np.logspace(0, 2, 100))
# model selected by the search
ridge_mod = ridge_mods.best_estimator_
utl.plot_gridsearch(ridge_mods, 'param_mod__alpha')

In [None]:
# reduced-rank regression; maximum rank argument = number of Y columns
rr_mods = utl.reduced_rank_regression(X, Y, Y.shape[1])
# model selected by the search
rr_mod = rr_mods.best_estimator_
utl.plot_gridsearch(rr_mods, 'param_mod__r', logscale=False)

# ALM-ALM

In [None]:
rec1 = Recording('./data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe2.mat')
rec2 = Recording('./data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe1.mat')

# PSTH bin size and filter width shared by both probes
bin_size, filter_sigma = 0.1, .0001

# select units and trials independently for each probe
for rec in [rec1, rec2]:
    # binned, smoothed PSTH
    p_prq = rec._path_name(f'bin{bin_size}.parquet')
    df = utl.smooth_psth(
        rec._assign_df(p_prq, rec._calculate_psth, {'bin_size': bin_size}),
        filter_sigma,
        bin_size,
    )

    # trial-range selection, then rate selection on the surviving rows only
    unts_range, trls_range = utl.filter_trial_ranges(rec, thresh=.9, plot=True)
    m = df['unit'].isin(unts_range) & df['trial'].isin(trls_range)
    unts_rate = utl.filter_rates(df.loc[m], 2, plot=True)

    # NOTE(review): unts_rate is computed on range-filtered rows, so presumably
    # already a subset of unts_range — TODO confirm in utl.filter_rates
    rec.units, rec.trials = unts_rate, trls_range

In [None]:
# population matrices from the two probes (presumably X from rec1, Y from rec2)
X, Y = utl.get_xy(rec1, rec2=rec2, bin_size=bin_size, filter_sigma=filter_sigma)
utl.plot_rate_dist(X, Y)

In [None]:
# ridge regression, alpha searched over 100 log-spaced values in [1e-4, 1e4]
ridge_mods = utl.ridge_regression(X, Y, alphas=np.logspace(-4, 4, 100))
# model selected by the search
ridge_mod = ridge_mods.best_estimator_
utl.plot_gridsearch(ridge_mods, 'param_mod__alpha')

In [None]:
# reduced-rank regression; maximum rank argument = number of Y columns
rr_mods = utl.reduced_rank_regression(X, Y, Y.shape[1])
# model selected by the search
rr_mod = rr_mods.best_estimator_
utl.plot_gridsearch(rr_mods, 'param_mod__r', logscale=False)

# ALM-STR

In [None]:
# ALM-STR session ZY78_20211015; the imec0 file is the STR probe
rec2 = Recording('./data/zidan/ALM_STR/ZY78_20211015/ZY78_20211015NP_g0_JRC_units.mat')
rec1 = Recording('./data/zidan/ALM_STR/ZY78_20211015/ZY78_20211015NP_g0_imec0_JRC_units.mat')

In [None]:
# PSTH parameters shared by both recordings
bin_size = 0.1
filter_sigma = 1

for rec in [rec1, rec2]:
    # binned PSTH for this bin size, smoothed with filter_sigma
    p_prq = rec._path_name(f'bin{bin_size}.parquet')
    df = rec._assign_df(p_prq, rec._calculate_psth, {'bin_size': bin_size})
    df = utl.smooth_psth(df, filter_sigma, bin_size)

    # step 1: trial-range selection
    unts_range, trls_range = utl.filter_trial_ranges(rec, thresh=.9, plot=True)

    # step 2: rate selection, evaluated only on rows that survived step 1
    m = (
        df.loc[:, 'unit'].isin(unts_range)
        & df.loc[:, 'trial'].isin(trls_range)
    )
    unts_rate = utl.filter_rates(df.loc[m], 2, plot=True)

    # store the selection on the recording
    rec.units = unts_rate
    rec.trials = trls_range


In [None]:
# population matrices from the two recordings (presumably X from rec1, Y from rec2)
X, Y = utl.get_xy(rec1, rec2=rec2, bin_size=bin_size, filter_sigma=filter_sigma)
utl.plot_rate_dist(X, Y)

In [None]:
# ridge regression, alpha searched over 100 log-spaced values in [1e-3, 1e3]
ridge_mods = utl.ridge_regression(X, Y, alphas=np.logspace(-3, 3, 100))
# model selected by the search
ridge_mod = ridge_mods.best_estimator_
utl.plot_gridsearch(ridge_mods, 'param_mod__alpha')

In [None]:
# reduced-rank regression; maximum rank argument = number of Y columns
rr_mods = utl.reduced_rank_regression(X, Y, Y.shape[1])
# model selected by the search
rr_mod = rr_mods.best_estimator_
utl.plot_gridsearch(rr_mods, 'param_mod__r', logscale=False)

# ALM-Thal

In [None]:
# ALM-Thal session ZY113_20220617; the imec0 file is the thalamus probe
rec2 = Recording('./data/zidan/ALM_Thal/ZY113_20220617/ZY113_20220617_NPH2_g0_JRC_units.mat')
rec1 = Recording('./data/zidan/ALM_Thal/ZY113_20220617/ZY113_20220617_NPH2_g0_imec0_JRC_units.mat')

# PSTH parameters shared by both recordings
bin_size, filter_sigma = 0.1, 1

# per-recording unit/trial selection
for rec in [rec1, rec2]:
    p_prq = rec._path_name(f'bin{bin_size}.parquet')
    df = utl.smooth_psth(
        rec._assign_df(p_prq, rec._calculate_psth, {'bin_size': bin_size}),
        filter_sigma,
        bin_size,
    )

    # trial ranges first, then firing rate on the range-filtered rows
    unts_range, trls_range = utl.filter_trial_ranges(rec, thresh=.9, plot=True)
    m = df['unit'].isin(unts_range) & df['trial'].isin(trls_range)
    unts_rate = utl.filter_rates(df.loc[m], 2, plot=True)

    rec.units, rec.trials = unts_rate, trls_range

In [None]:
# population matrices from the two recordings (presumably X from rec1, Y from rec2)
X, Y = utl.get_xy(rec1, rec2=rec2, bin_size=bin_size, filter_sigma=filter_sigma)
utl.plot_rate_dist(X, Y)

In [None]:
# ridge regression, alpha searched over 100 log-spaced values in [1e-4, 1e4]
ridge_mods = utl.ridge_regression(X, Y, alphas=np.logspace(-4, 4, 100))
# model selected by the search
ridge_mod = ridge_mods.best_estimator_
utl.plot_gridsearch(ridge_mods, 'param_mod__alpha')

In [None]:
# reduced-rank regression; maximum rank argument = number of Y columns
rr_mods = utl.reduced_rank_regression(X, Y, Y.shape[1])
# model selected by the search
rr_mod = rr_mods.best_estimator_
utl.plot_gridsearch(rr_mods, 'param_mod__r', logscale=False)

# local batch processing

In [None]:
# batch-process every local recording: binned PSTH (computed with 'align': None),
# smoothing with filter width 1, and a missing-data plot saved next to each recording
bin_size = 0.1

# sorted -> deterministic processing order across runs and platforms
for p_mat in sorted(Path('./data/zidan/').glob('**/*.mat')):
    print(p_mat)
    rec = Recording(p_mat)

    p_out = rec._path_name(f'bin{bin_size}.parquet')
    df = rec._assign_df(p_out, rec._calculate_psth, {'align': None, 'bin_size': bin_size})

    df = utl.smooth_psth(df, 1, bin_size)
    p_png = rec._path_name(f'bin{bin_size}.png')
    utl.plot_missing(df, bin_size, path=p_png, vmax=5)

    # the plot is presumably written to p_png by plot_missing; close it so figures
    # do not accumulate in memory over the whole dataset
    plt.close('all')
    


# remote batch processing

In [None]:
# batch process all recordings on the remote share: plot and save a PSTH for each
p = Path(r'X:\Users\Zidan\DataForNico')

# the root is a directory, so it must be checked with is_dir(): is_file() is never True
# for it and the loop silently never ran. This also skips cleanly if the share is not mounted.
if p.is_dir():
    # '*.mat' (with the dot) so only MATLAB files match, not any name ending in "mat"
    for p_mat in sorted(p.glob('**/*.mat')):

        # load recording (Recording presumably builds any missing cached dataframes)
        print(p_mat)
        rec = Recording(p_mat)

        # plot PSTH and save it next to the recording's PSTH file
        rec.plot_psth(xlims=(-1.5, 2), filter_size=50, path=rec.path_psth.with_suffix('.png'))
else:
    print(f'{p} is not an accessible directory, skipping remote batch processing')

# other

## Raw vs. trial-aligned spike times

In [None]:
# reload the ALM probe-2 recording; force_overwrite=True presumably rebuilds any cached
# intermediate data — TODO confirm in Recording
p = './data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe2.mat'
rec = Recording(p, force_overwrite=True)

# raw MATLAB struct for direct inspection (loadmat, plt and sns come from the import cell).
# struct_as_record=False returns mat_struct objects whose fields are attributes (see vars()),
# squeeze_me=True removes singleton dimensions.
m = loadmat(p, squeeze_me=True, struct_as_record=False)
import seaborn as sns

In [None]:
# long-format table of raw spike times: one row per spike with a 1-based unit id
# (enumeration index + 1). Frames are collected in a list and concatenated once;
# concatenating inside the loop copies the growing table every iteration (quadratic).
frames = [
    pd.DataFrame(data={
        # np.atleast_1d: with squeeze_me=True a unit with a single spike yields a scalar,
        # which pd.DataFrame cannot build a frame from together with the scalar unit id
        'T': np.atleast_1d(vars(u)['RawSpikeTimes']),
        'unit': i + 1,
    })
    for i, u in enumerate(m['unit'])
]
df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame(columns=['T', 'unit'])

# raw spike times of unit 1
raw = df.loc[df.loc[:, 'unit'] == 1, 'T'].values

In [None]:
# first unit: its SpikeTimes and, for each spike, the trial index it belongs to
spk, idx = vars(m['unit'][0])['SpikeTimes'], vars(m['unit'][0])['Trial_idx_of_spike']

In [None]:
# trial onsets divided by 2.5e4 — presumably samples -> seconds at a 25 kHz sampling rate (TODO confirm)
t0s = np.array([vars(trial)['onset'] for trial in m['trial_info']]) / 2.5e4

# behavioural event times, read from the first unit's Behavior struct
behavior = vars(vars(m['unit'][0])['Behavior'])
t_lck, t_cue = behavior['First_lick'], behavior['Sample_start']

In [None]:
# distribution of Sample_start minus First_lick
sns.histplot(t_cue - t_lck)

In [None]:
# Raster of spk/idx (first unit, earlier cell) for trial indices 2..29, shifted by each
# trial's onset t0, with dotted markers for onset (C0), sample start (C1) and first lick (C2);
# the raw spike times of unit 1 are drawn on the row above for comparison.
fig, ax = plt.subplots(figsize=(10, 4))

for i in range(2, 30):
    # spikes belonging to trial index i (presumably trial-relative times)
    mask = idx == i
    s = spk[mask]


    # t0 = df_trl.loc[ df_trl.loc[:, 'trial'] == i ].loc[:, 'T_0'].item()
    # NOTE(review): behavioural events are read at i + 1 but the onset at i - 1; these
    # offsets disagree with each other — confirm which indexing matches the trial numbering
    # of Trial_idx_of_spike before trusting this alignment
    l = t_lck[i + 1]
    c = t_cue[i + 1]
    t0 = t0s[i-1]
    # keep spikes from 2 before sample start to 2 after the first lick (same time units as spk)
    mask = (s > (c-2)) & (s < ( l + 2) )
    s = s[mask]

    # dotted markers: trial onset, sample start and first lick, all shifted by t0
    ax.axvline(t0, c='C0', ls=':', lw=1)
    ax.axvline(t0 + c, c='C1', ls=':', lw=1)
    ax.axvline(l + t0, c='C2', ls=':', lw=1)

    ax.eventplot(s + t0, lineoffsets=i, color=f'C{i}')

# raw spikes of unit 1 below 300, drawn one row above the last trial
# (`i` still holds the final loop value, 29, here)
mask = raw < 300
ax.eventplot(raw[mask], lineoffsets=i + 1, ls='-')

ax.set_xlabel('times [s]')
ax.set_ylabel('trial index')
ax.set_xlim((0, 60))

## new spike processing

In [None]:
# ALM probe-2 recording, this time without forcing an overwrite of existing data
p = './data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe2.mat'
rec = Recording(p, force_overwrite=False)

In [None]:
fig, ax = plt.subplots()

# raster of unit 14: one row of spike ticks per trial, y offset = trial number
df = rec.df_spk.groupby('unit').get_group(14)
for trl, spikes in df.groupby('trial'):
    ax.eventplot(spikes['t'].values, lineoffsets=trl)

ax.set_xlim((-1, 4))
ax.set_ylim((0, 300))

In [None]:
fig, ax = plt.subplots()

# raster of unit 14 over all trials (no y limit), y offset = trial number
df = rec.df_spk.groupby('unit').get_group(14)
for trl, spikes in df.groupby('trial'):
    ax.eventplot(spikes['t'].values, lineoffsets=trl)

ax.set_xlim((-1, 4))

In [None]:
# PSTH of unit 14 using the recording's current data, x range -1 to 6
rec.plot_psth(unts=[14], xlims=(-1, 6))

In [None]:
# reload spike times and recompute the PSTH from scratch, then plot unit 14 in a narrow window;
# each step reports what it actually did (previously both printed 'loaded')
rec.df_spk = rec._load_spike_times()
print('spike times loaded')
rec.df_psth = rec._calculate_psth()
print('PSTH calculated')
rec.plot_psth(unts=[14], xlims=(-1, 1))