In [None]:
%load_ext autoreload
%autoreload 2

from scipy.io import loadmat
from pathlib import Path

from recording import Recording
import utils as utl

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("whitegrid")
import numpy as np

import pandas as pd

# Global analysis settings shared by the example cells of this notebook.
params = dict(
    bin_size=0.1,        # time-bin width in s
    thresh_rate=1,       # minimum firing rate in Hz
    thresh_sw=0.5,       # maximum spike width in ms
    thresh_trials=0.9,   # trial-keeping threshold (original note: "number of trials to keep in %"; value looks like a fraction — TODO confirm)
    scoring='r2',        # scorer name, see https://scikit-learn.org/stable/modules/model_evaluation.html
    signal='fr',         # passed to utl.get_matrices as the signal to use
)

# example

## select units: single region

In [None]:
# single ALM probe: features and targets are both taken from this one recording
# (select_data is called without rec2)
rec = Recording('./data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe1.mat')

# bin and filter according to params (bin_size and thresh_* settings)
utl.preproc_rec(rec, params)

# features (dfx) and targets (dfy) as long DataFrames, then as matrices X, Y
dfx, dfy = utl.select_data(rec)
X, Y, basis_time = utl.get_matrices(dfx, dfy, params['signal'])

## select units: two regions

In [None]:
# Two-region example: features come from rec1, targets from rec2 (rec1 -> rec2).
# Exactly one pair of rec1/rec2 assignments below should be active.

# # ALM-ALM
# rec1 = Recording('./data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe2.mat')
# rec2 = Recording('./data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe1.mat')

# ALM-Str (imec0: STR)
rec2 = Recording('./data/zidan/ALM_STR/ZY78_20211015/ZY78_20211015NP_g0_JRC_units.mat')
# NOTE(review): force_overwrite=True is only set on rec1 — confirm this is intended
rec1 = Recording('./data/zidan/ALM_STR/ZY78_20211015/ZY78_20211015NP_g0_imec0_JRC_units.mat', force_overwrite=True)

# # ALM-Thal (imec0: Thal)
# rec2 = Recording('./data/zidan/ALM_Thal/ZY113_20220617/ZY113_20220617_NPH2_g0_JRC_units.mat')
# rec1 = Recording('./data/zidan/ALM_Thal/ZY113_20220617/ZY113_20220617_NPH2_g0_imec0_JRC_units.mat')

# preprocess both recordings with the same settings
utl.preproc_rec(rec1, params)
utl.preproc_rec(rec2, params)

# features (from rec1) and targets (from rec2)
dfx, dfy = utl.select_data(rec1, rec2=rec2)
X, Y, basis_time = utl.get_matrices(dfx, dfy, params['signal'])

In [None]:
# Heatmap of unit 10's 'dfr' values: one row per trial (0-based index), one column per time bin.
# The trial index is derived here rather than read from a 'trial_' column, so this cell no
# longer depends on a later cell having added that column to dfx. The mapping (order of
# first appearance over all of dfx) is the same one used by that later cell.
unit_dfx = dfx.groupby('unit').get_group(10)
trial_order = {trial: i for i, trial in enumerate(dfx.loc[:, 'trial'].unique())}
df = pd.pivot_table(
    data=unit_dfx.assign(trial_=unit_dfx.loc[:, 'trial'].map(trial_order)),
    values='dfr', columns='bins', index='trial_',
)
sns.heatmap(df)

In [None]:
# Trace of unit 10's 'dfr' values over all of its rows in dfx
# (the original cell ended in the incomplete statement `plt.sub`, which raises AttributeError).
y = dfx.groupby('unit').get_group(10).loc[:, 'dfr']
fig, ax = plt.subplots(figsize=(25, 4))
ax.plot(np.arange(len(y)), y.to_numpy())
ax.set_xlabel('row index within unit 10')
ax.set_ylabel('dfr');

In [None]:
# Number trials consecutively from 0 (order of first appearance) and store that index in
# dfx as 'trial_'.
trial_ids = dfx.loc[:, 'trial'].unique()
d = {trial: pos for pos, trial in enumerate(trial_ids)}
dfx['trial_'] = dfx['trial'].map(d)
dfx

In [None]:
# Wide views of features and targets: rows are 'T', columns are units, cells are mean 'dfr'.
df_piv1, df_piv2 = (
    pd.pivot_table(frame, values='dfr', index='T', columns='unit')
    for frame in (dfx, dfy)
)

In [None]:
# Inspect the unit-pivoted feature table
df_piv1

In [None]:

# Plot feature column 1 of X against its sample index.
y = X[:, 1]
fig, ax = plt.subplots(figsize=(25, 4))
ax.plot(np.arange(y.shape[0]), y)

In [None]:
# Heatmap of the transposed feature matrix (one row per column of X).
fig, ax = plt.subplots(figsize=(15, 7))
sns.heatmap(X.T, ax=ax)

In [None]:
# Spike raster of rec1 unit 10: one row per trial, rows in sorted trial order.
spikes = rec1.df_spk.groupby('unit').get_group(10)
t, trl = spikes.loc[:, 't'].values, spikes.loc[:, 'trial'].values

fig, ax = plt.subplots(figsize=(15, 7))
for row, trial_id in enumerate(np.unique(trl)):
    ax.eventplot(t[trl == trial_id], lineoffsets=row)

In [None]:
# Spike raster of rec2 unit 31: one row per trial, rows in sorted trial order.
spikes = rec2.df_spk.groupby('unit').get_group(31)
t, trl = spikes.loc[:, 't'].values, spikes.loc[:, 'trial'].values

fig, ax = plt.subplots(figsize=(15, 7))
for row, trial_id in enumerate(np.unique(trl)):
    ax.eventplot(t[trl == trial_id], lineoffsets=row)

## regression

In [None]:
# linear regression (= ridge with alpha=0)
lin_mods = utl.ridge_regression(X, Y, scoring=params['scoring'], alphas=[0])
lin_mod = lin_mods.best_estimator_

# ridge over 27 log-spaced alphas from 1e-13 to 1e13; ridge_mod (best estimator) is reused
# by the prediction cell and the "other" section
ridge_mods = utl.ridge_regression(X, Y, scoring=params['scoring'], alphas=np.logspace(-13, 13, 27))
ridge_mod = ridge_mods.best_estimator_
utl.plot_gridsearch(ridge_mods, 'ridge', other_mods={'linear': lin_mods}, logscale=True)

In [None]:
# RRR (reduced-rank regression); Y.shape[1] (number of target columns) is passed as the
# third argument — presumably the maximum rank to search, TODO confirm in utl
rr_mods = utl.reduced_rank_regression(X, Y, Y.shape[1], scoring=params['scoring'])
rr_mod = rr_mods.best_estimator_
utl.plot_gridsearch(rr_mods, 'reduced rank', other_mods={'linear': lin_mods, 'ridge': ridge_mods}, logscale=False)

In [None]:
# plot: predictions of the best ridge model next to the true mean responses, with per-target scores
dfy_pred, scores = utl.get_ypred(X, Y, dfx, dfy, basis_time, ridge_mod, scoring=params['scoring'])
utl.plot_mean_response(dfy, params['bin_size'], params['signal'], dfy_pred, scores)

# remote batch processing

In [None]:
# batch process all recordings
# Root of the network share: <p_root>/<probe combination>/<recording>/ holds the .mat files.
p_root = Path(r'X:\Users\Zidan\DataForNico')

if p_root.is_dir():
    # one entry per recording folder; sorted so the batch runs in a reproducible order
    p_dirs = sorted(p for p in p_root.glob('*/*/') if p.is_dir())
else:
    print('Invalid root dir. No access to network drive?')
    # keep p_dirs defined so the batch cells below run as no-ops instead of raising NameError
    p_dirs = []

# .mat file name -> region label used to name the output folders of the batch fits
name2region = {
    'MK22_20230301_2H2_g0_JRC_units_probe1.mat' : 'ALM1',
    'MK22_20230301_2H2_g0_JRC_units_probe2.mat' : 'ALM2',
    'MK22_20230303_2H2_g0_JRC_units_probe1.mat' : 'ALM1',
    'MK22_20230303_2H2_g0_JRC_units_probe2.mat' : 'ALM2',
    'MK25_20230314_2H2_g0_JRC_units_probe1.mat' : 'ALM1',
    'MK25_20230314_2H2_g0_JRC_units_probe2.mat' : 'ALM2',
    'ZY78_20211015NP_g0_imec0_JRC_units.mat'    : 'STR',
    'ZY78_20211015NP_g0_JRC_units.mat'          : 'ALM',
    'ZY82_20211028NP_g0_imec0_JRC_units.mat'    : 'STR',
    'ZY82_20211028NP_g0_JRC_units.mat'          : 'ALM',
    'ZY83_20211108NP_g0_imec0_JRC_units.mat'    : 'STR',
    'ZY83_20211108NP_g0_JRC_units.mat'          : 'ALM',
    'ZY113_20220617_NPH2_g0_imec0_JRC_units.mat': 'THA',
    'ZY113_20220617_NPH2_g0_JRC_units.mat'      : 'ALM',
    'ZY113_20220618_NPH2_g0_imec0_JRC_units.mat': 'THA',
    'ZY113_20220618_NPH2_g0_JRC_units.mat'      : 'ALM',
    'ZY113_20220620_NPH2_g0_imec0_JRC_units.mat': 'THA',
    'ZY113_20220620_NPH2_g0_JRC_units.mat'      : 'ALM',
}

## plot response, binned at 1 ms

In [None]:
# Build each recording found under p_root with calc_psth=True and save its PSTH plot as .png.
# The pattern is '*.mat' (the original '*mat' also matched any name merely ending in "mat").
for p_mat in p_root.glob('**/*.mat'):

    # load recording and create missing dataframes
    print(p_mat)

    rec = Recording(p_mat, calc_psth=True)
    rec.plot_psth(path=rec.path_psth.with_suffix('.png'))

## fits

In [None]:
def proc_wrapper(p_out, params, dfx, dfy):
    """Fit linear, ridge and reduced-rank regressions for one feature/target split and save all outputs.

    Parameters
    ----------
    p_out : pathlib.Path
        Output folder; created (including parents) if it does not exist.
    params : dict
        Settings; 'signal', 'scoring' and 'bin_size' are read here, and the whole
        dict is written to ``params.json``.
    dfx, dfy : pandas.DataFrame
        Feature and target data as returned by ``utl.select_data``.

    Files written to `p_out`: ``params.json``, ``reg_linear.parquet``,
    ``reg_ridge.parquet``, ``reg_rrr.parquet``, ``reg_ridge.png``, ``reg_rrr.png``,
    ``pred_ridge.png`` and ``pred_ridge_scores.csv``.
    """
    # create folder
    p_out.mkdir(exist_ok=True, parents=True)

    # record every setting used for this fit (not only the thresholds)
    pd.Series(params).to_json(p_out / 'params.json')

    # features and targets
    X, Y, basis_time = utl.get_matrices(dfx, dfy, params['signal'])

    # linear regression (= ridge with alpha=0)
    lin_mods = utl.ridge_regression(X, Y, scoring=params['scoring'], alphas=[0])
    utl.save_cv_results(lin_mods, path=p_out / 'reg_linear.parquet')

    # ridge regression over 27 log-spaced alphas; the best model is used for prediction below
    ridge_mods = utl.ridge_regression(X, Y, scoring=params['scoring'], alphas=np.logspace(-13, 13, 27))
    ridge_mod = ridge_mods.best_estimator_
    utl.save_cv_results(ridge_mods, path=p_out / 'reg_ridge.parquet')

    # RRR; Y.shape[1] (number of target columns) is passed as the third argument
    rr_mods = utl.reduced_rank_regression(X, Y, Y.shape[1], scoring=params['scoring'])
    utl.save_cv_results(rr_mods, path=p_out / 'reg_rrr.parquet')

    # plot regressions
    utl.plot_gridsearch(ridge_mods, 'ridge', other_mods={'linear': lin_mods}, logscale=True, path=p_out / 'reg_ridge.png')
    utl.plot_gridsearch(rr_mods, 'reduced rank', other_mods={'linear': lin_mods, 'ridge': ridge_mods}, logscale=False, path=p_out / 'reg_rrr.png')

    # prediction with the best ridge model
    dfy_pred, scores = utl.get_ypred(X, Y, dfx, dfy, basis_time, ridge_mod, scoring=params['scoring'])
    utl.plot_mean_response(dfy, params['bin_size'], params['signal'], dfy_pred, scores, path=p_out / 'pred_ridge.png')
    # NOTE(review): index=False drops the Series index; if `scores` is keyed by unit, the unit
    # IDs are lost and readers of this CSV only get row positions — confirm this is intended.
    pd.Series(scores, name=params['scoring']).to_csv(p_out / 'pred_ridge_scores.csv', index=False)


In [None]:
# Settings for the batch run (bin size 0.2 s). Kept under a separate name so that running
# this cell does not overwrite the `params` used by the interactive example cells.
batch_params = {
    'bin_size'      : 0.2,   # bin size in s
    'thresh_rate'   : 1,     # min firing rate in Hz
    'thresh_sw'     : 0.5,   # max spike width in ms
    'thresh_trials' : 0.9,   # trial-keeping threshold (appears to be a fraction — TODO confirm)
    'scoring'       : 'r2',  # see https://scikit-learn.org/stable/modules/model_evaluation.html
    'signal'        : 'fr',
}

# output sub-folder created inside every recording folder
name = 'all_depths_raw_0.2'

for p_dir in p_dirs:

    print(p_dir)

    # load both recordings of this session; sorted so reg1/reg2 are assigned reproducibly
    p_mats = sorted(p_dir.glob('*.mat'))
    if len(p_mats) != 2:
        print(f'skipping {p_dir}: expected 2 .mat files, found {len(p_mats)}')
        continue
    p_mat1, p_mat2 = p_mats
    rec1, rec2 = Recording(p_mat1), Recording(p_mat2)
    reg1, reg2 = name2region[p_mat1.name], name2region[p_mat2.name]

    # preprocessing
    utl.preproc_rec(rec1, batch_params)
    utl.preproc_rec(rec2, batch_params)

    # (output folder, feature recording, target recording or None for a single recording)
    fits = [
        (f'{reg1}_{reg2}', rec1, rec2),   # reg1 -> reg2
        (f'{reg2}_{reg1}', rec2, rec1),   # reg2 -> reg1
        (reg1, rec1, None),               # reg1 only
        (reg2, rec2, None),               # reg2 only
    ]
    for folder, rec_src, rec_tgt in fits:
        if rec_tgt is None:
            dfx, dfy = utl.select_data(rec_src)
        else:
            dfx, dfy = utl.select_data(rec_src, rec2=rec_tgt)
        p_out = p_dir / name / folder
        proc_wrapper(p_out, batch_params, dfx, dfy)
    

## plot

In [None]:
def load_scores(ps_csv):
    """Collect the per-target prediction scores written by the batch fits into one DataFrame.

    Each CSV is expected at
    ``<root>/<probes>/<recording>/<name>/<interaction>/pred_ridge_scores.csv``
    (the layout produced by the batch-fit cell). The folder names are parsed into
    metadata columns.

    Parameters
    ----------
    ps_csv : iterable of pathlib.Path
        Paths of ``pred_ridge_scores.csv`` files; may be a generator such as ``Path.glob``.

    Returns
    -------
    pandas.DataFrame
        One row per score with the columns:
        - ``unit``: row position within its CSV
        - ``score``, ``region``, ``interaction``, ``recording``, ``animal``, ``probes``
        - ``interaction_``: the interaction with ALM1/ALM2 collapsed to ALM

        The frame is empty (with these columns) when `ps_csv` yields no paths.
    """
    columns = ['unit', 'score', 'region', 'interaction', 'recording', 'animal', 'probes']

    dfs = []
    for p_csv in ps_csv:

        folder = p_csv.parent.name      # interaction folder, e.g. 'ALM_STR' or 'STR'
        rec = p_csv.parents[2].name     # recording folder
        pro = p_csv.parents[3].name     # probe-combination folder
        ani = rec.split('_')[0]
        df_csv = pd.read_csv(p_csv)

        dfs.append(pd.DataFrame(data={
            'unit': df_csv.index,       # positional: the CSV is written without an index
            'score': df_csv.iloc[:, 0],
            # drop the leading '<source>_' part; single-region folders are left unchanged
            'region': folder.replace(folder.split('_')[0] + '_', ''),
            'interaction': folder,
            'recording': rec,
            'animal': ani,
            'probes': pro,
        }))

    if not dfs:
        # pd.concat raises on an empty list; return an empty, correctly labelled frame instead
        return pd.DataFrame(columns=columns + ['interaction_'])

    df = pd.concat(dfs, ignore_index=True)
    df.loc[:, 'interaction_'] = df.loc[:, 'interaction'].str.replace(r'ALM[12]', 'ALM', regex=True)

    return df

In [None]:
# List the batch output folders ('all_depths*') of one example recording.
# Built from p_root instead of repeating the hard-coded absolute path.
for i in (p_root / 'ALM_ALM' / 'MK22_20230301').glob('all_depths*/'):
    print(i.name)

In [None]:
# Plot the saved ridge prediction scores for each batch variant (one output-folder name each).
# Output folder for the figures; created once, outside the loop.
p_plot = p_root / 'plots/'
p_plot.mkdir(exist_ok=True)

for name in [ 'all_depths', 'all_depths_0.2', 'all_depths_0.5', 'all_depths_raw', 'all_depths_raw_0.2', 'all_depths_raw_0.5' ]:

    df = load_scores(p_root.glob(f'**/{name}/**/pred_ridge_scores.csv'))
    if df.empty:
        print(f'no scores found for {name}, skipping')
        continue

    # one panel per probe combination
    g = sns.catplot(df, x='interaction', y='score', col='probes', hue='recording', sharex=False, facet_kws={'ylim': (-1, 1)})
    g.fig.savefig(p_plot / f'scores_{name}.png')
    plt.close(g.fig)

    # pooled over probe combinations; shorter (single-region) labels sort first
    d = (
        df.assign(_sort=df.loc[:, 'interaction_'].str.len())
          .sort_values(by=['_sort', 'interaction_'])
    )
    fig, ax = plt.subplots(figsize=(12, 6))
    sns.stripplot(data=d, ax=ax, x='interaction_', y='score', hue='recording')
    ax.set_ylim((-1, 1))
    ax.set_title(f'ridge prediction scores ({name})')
    fig.savefig(p_plot / f'scores_pooled_{name}.png')
    plt.close(fig)

# other (scratch cells; the first ones reuse X, Y, dfy and ridge_mod from the example section)

In [None]:
## all vs one
# Per-target ridge scores at the alpha of the joint ridge model (ridge_mod, example section).
ridge_scores = {}
alpha = ridge_mod.get_params()['mod__alpha']

# `dfy_mat` was never defined: rebuild it as the unit-pivoted target table (same pivot as df_piv2),
# whose columns label the target units.
# NOTE(review): assumes Y's columns follow this sorted-unit order — confirm against utl.get_matrices
dfy_mat = pd.pivot_table(dfy, values='dfr', index='T', columns='unit')

for y, u in zip(Y.T, dfy_mat.columns):
    mods = utl.ridge_regression(X, y, alphas=[alpha])
    mod = mods.best_estimator_
    # scored on the same X, y that were passed to ridge_regression
    ridge_scores[u] = mod.score(X, y)

# prediction
Y_pred = ridge_mod.predict(X)
dfy_pred = utl.matrix2df(Y_pred, dfy)

# plot true and predicted with score (`bin_size` was undefined; taken from params)
utl.plot_psth(dfy, params['bin_size'], df2=dfy_pred, scores=ridge_scores)

In [None]:

def matrix2df(X, dfx):
    """Convert a (time x unit) matrix back into a long DataFrame laid out like `dfx`.

    The layout comes from pivoting `dfx` (rows: sorted 'T', columns: sorted 'unit').
    `X` must have exactly that shape; its values replace the pivoted ones. The 'bins'
    and 'trial' values of each time point are looked up from `dfx` (last occurrence wins).

    Returns
    -------
    pandas.DataFrame
        Columns 'unit' (int), 'trial', 'dfr' (values of X), 'bins' and 'T', with one
        row per (T, unit) cell of the pivot, T-major.
    """
    # time point -> bin label / trial
    t_all = dfx.loc[:, 'T'].values
    t2bins = dict(zip(t_all, dfx.loc[:, 'bins'].values))
    t2trl = dict(zip(t_all, dfx.loc[:, 'trial'].values))

    # the pivot only fixes the row/column layout; every cell is overwritten with X
    layout = pd.pivot_table(dfx, values='dfr', index='T', columns='unit').fillna(0)
    layout.loc[:, :] = X
    stacked = layout.stack()

    t_vals, unit_vals = stacked.index.to_frame().to_numpy().T

    return pd.DataFrame(data={
        'unit': unit_vals.astype(int),
        'trial': [t2trl[tp] for tp in t_vals],
        'dfr': stacked.to_numpy(),
        'bins': [t2bins[tp] for tp in t_vals],
        'T': t_vals,
    })
    
# all vs one for best ridge model
def all_vs_one(dfx, dfy, ridge_mod):
    """Score each target unit with its own ridge fit at `ridge_mod`'s alpha, and predict all targets jointly.

    Parameters
    ----------
    dfx, dfy : pandas.DataFrame
        Feature / target data, turned into matrices with ``utl.get_matrix``.
    ridge_mod : estimator
        Fitted model exposing ``get_params()['mod__alpha']`` and ``predict``.

    Returns
    -------
    dfy_pred : pandas.DataFrame
        Joint prediction of `ridge_mod`, converted back with ``utl.matrix2df``.
    ridge_scores : dict
        unit -> score of a single-target ridge model, evaluated on the same X, y it was given.
    """
    best_alpha = ridge_mod.get_params()['mod__alpha']
    X, Y = utl.get_matrix(dfx), utl.get_matrix(dfy)

    # NOTE(review): pairs Y's columns with units in order of first appearance in dfy; if
    # utl.get_matrix orders unit columns differently (e.g. sorted), labels may not match — confirm.
    target_units = dfy.loc[:, 'unit'].unique()

    def _score_single(y):
        # one-target ridge at the fixed alpha, scored on the data it was fitted with
        return utl.ridge_regression(X, y, alphas=[best_alpha]).best_estimator_.score(X, y)

    ridge_scores = {u: _score_single(y) for y, u in zip(Y.T, target_units)}

    dfy_pred = utl.matrix2df(ridge_mod.predict(X), dfy)
    return dfy_pred, ridge_scores

## raw vs trial

In [None]:
# Load one probe both through Recording (with force_overwrite=True) and as the raw .mat struct,
# to compare raw and trial-based spike times below.
p = './data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe2.mat'
rec = Recording(p, force_overwrite=True)

from scipy.io import loadmat
m = loadmat(p, squeeze_me=True, struct_as_record=False)

# pyplot, not the discouraged pylab interface, so `plt` stays the module imported at the top
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Long table of every unit's 'RawSpikeTimes' (column 'T'), with units numbered from 1 in the
# order of m['unit']. Per-unit frames are collected and concatenated once (not in the loop).
frames = []
for i, u in enumerate(np.atleast_1d(m['unit'])):
    # squeeze_me=True turns a single spike into a scalar; keep it a 1-D array
    t = np.atleast_1d(vars(u)['RawSpikeTimes'])
    frames.append(pd.DataFrame(data={
        'T': t,
        'unit': i + 1,
    }))

df = pd.concat(frames, ignore_index=True)

# raw spike times of unit 1 (the first unit in the file)
raw = df.loc[ df.loc[:, 'unit'] == 1 ].loc[:, 'T'].values

In [None]:
# 'SpikeTimes' and the per-spike 'Trial_idx_of_spike' of the first unit in the .mat file
spk, idx = (vars(m['unit'][0])[field] for field in ('SpikeTimes', 'Trial_idx_of_spike'))

In [None]:
# Divisor that turns trial 'onset' values into the time axis of the plots below; presumably
# onsets are sample indices at a 25 kHz acquisition rate — TODO confirm against the recording.
ONSET_SAMPLES_PER_S = 2.5e4

# trial onsets and the first unit's behavioral event times (first lick, sample start)
t0s = np.array([vars(trial)['onset'] for trial in m['trial_info']]) / ONSET_SAMPLES_PER_S
behavior = vars(vars(m['unit'][0])['Behavior'])
t_lck = behavior['First_lick']
t_cue = behavior['Sample_start']

In [None]:
# Distribution over trials of 'Sample_start' minus 'First_lick', on explicit, labelled axes.
fig, ax = plt.subplots()
sns.histplot(t_cue - t_lck, ax=ax)
ax.set_xlabel('Sample_start - First_lick')
ax.set_ylabel('number of trials');

In [None]:
# Overlay the first unit's trial spikes (shifted by each trial's onset t0) for trial indices
# 2..29 with its raw spike train, to check how the two time bases line up.
fig, ax = plt.subplots(figsize=(10, 4))

for i in range(2, 30):
    # spikes whose Trial_idx_of_spike equals i
    mask = idx == i
    s = spk[mask]


    # t0 = df_trl.loc[ df_trl.loc[:, 'trial'] == i ].loc[:, 'T_0'].item()
    # NOTE(review): lick/cue are read at index i + 1 but the onset at i - 1 — confirm which
    # offset matches the numbering of Trial_idx_of_spike (looks like an alignment experiment).
    l = t_lck[i + 1]
    c = t_cue[i + 1]
    t0 = t0s[i-1]
    # keep spikes from 2 before sample start to 2 after the first lick
    mask = (s > (c-2)) & (s < ( l + 2) )
    s = s[mask]

    # onset, sample start and first lick of this trial, shifted onto the raw time axis
    ax.axvline(t0, c='C0', ls=':', lw=1)
    ax.axvline(t0 + c, c='C1', ls=':', lw=1)
    ax.axvline(l + t0, c='C2', ls=':', lw=1)

    ax.eventplot(s + t0, lineoffsets=i, color=f'C{i}')

# raw spikes before 300, drawn one row above the last trial
mask = raw < 300
ax.eventplot(raw[mask], lineoffsets=i + 1, ls='-')

ax.set_xlabel('times [s]')
ax.set_ylabel('trial index')
ax.set_xlim((0, 60))

## new spike processing

In [None]:
# Reload the same probe through Recording, this time with force_overwrite=False
p = './data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe2.mat'
rec = Recording(p, force_overwrite=False)

In [None]:
# Raster of unit 14: each trial's spike times 't' drawn on the row given by its trial ID,
# limited to -1..4 on x and trials 0..300 on y.
fig, ax = plt.subplots()

df = rec.df_spk.groupby('unit').get_group(14)
for trial_id, spike_times in df.groupby('trial')['t']:
    ax.eventplot(spike_times.values, lineoffsets=trial_id)
ax.set_xlim((-1, 4))
ax.set_ylim((0, 300))

In [None]:
# Raster of unit 14 as above, but with automatic y-limits so every trial row is visible.
fig, ax = plt.subplots()

df = rec.df_spk.groupby('unit').get_group(14)
for trial_id, spike_times in df.groupby('trial')['t']:
    ax.eventplot(spike_times.values, lineoffsets=trial_id)
ax.set_xlim((-1, 4))

In [None]:
# Plot unit 14's PSTH (x from -1 to 6) from the data currently stored in rec;
# uncomment the two lines below to recompute df_spk and df_psth first.
# rec.df_spk = rec._load_spike_times()
# rec.df_psth = rec._calculate_psth()
rec.plot_psth(unts=[14], xlims=(-1, 6))

In [None]:
# Recompute df_spk and df_psth via Recording's private helpers, then plot unit 14 (x from -1 to 1).
rec.df_spk = rec._load_spike_times()
print('loaded')
rec.df_psth = rec._calculate_psth()
# this step calculates the PSTH; the original repeated the misleading 'loaded' message
print('psth calculated')
rec.plot_psth(unts=[14], xlims=(-1, 1))