In [None]:
%load_ext autoreload
%autoreload 2

from scipy.io import loadmat
from pathlib import Path

from recording import Recording
import utils as utl

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("whitegrid")
import numpy as np

import pandas as pd

# Global analysis settings used by the single-recording example cells below.
params = dict(
    bin_size=0.1,        # bin size in s
    thresh_rate=1,       # min firing rate in Hz
    thresh_sw=0.5,       # min spike width in ms
    thresh_trials=0.9,   # share of trials to keep (presumably a fraction: 0.9 -> 90 %)
    scoring='r2',        # scikit-learn scorer, see https://scikit-learn.org/stable/modules/model_evaluation.html
    signal='fr',
)

# Trial epochs, given as  epoch_name -> (start, end, alignment event).
# Each entry selects the half-open interval [start, end) relative to the
# alignment event. Bins are aligned to the cue, so lick times may fall between
# bin edges. (None, None) selects everything and only works with 'cue'.
epochs = dict(
    all=(None, None, 'cue'),
    pre_cue=(-.6, .0, 'cue'),
    post_cue1=(.0, .6, 'cue'),
    post_cue2=(.6, 1.2, 'cue'),
    pre_lick=(-.6, .0, 'lick'),
    post_lick=(.0, .6, 'lick'),
)

# Example: regression on a single recording pair

## select units: single region

## choose recordings

In [None]:
# Candidate recording pairs: key -> (path for rec1, path for rec2).
# Keys follow '<rec1 region>_<rec2 region>' (region labels as in name2region
# further below); rec2 = None means a single probe.
RECORDING_PAIRS = {
    # single ALM probe
    'ALM1': (
        './data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe1.mat',
        None,
    ),
    # ALM-ALM
    'ALM2_ALM1': (
        './data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe2.mat',
        './data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe1.mat',
    ),
    # ALM-Str (imec0: STR)
    'STR_ALM': (
        './data/zidan/ALM_STR/ZY78_20211015/ZY78_20211015NP_g0_imec0_JRC_units.mat',
        './data/zidan/ALM_STR/ZY78_20211015/ZY78_20211015NP_g0_JRC_units.mat',
    ),
    # ALM-Thal (imec0: Thal)
    'THA_ALM': (
        './data/zidan/ALM_Thal/ZY113_20220617/ZY113_20220617_NPH2_g0_imec0_JRC_units.mat',
        './data/zidan/ALM_Thal/ZY113_20220617/ZY113_20220617_NPH2_g0_JRC_units.mat',
    ),
}
SELECTED_PAIR = 'STR_ALM'

p_rec1, p_rec2 = RECORDING_PAIRS[SELECTED_PAIR]
rec2 = Recording(p_rec2) if p_rec2 is not None else None
rec1 = Recording(p_rec1)

## preprocess data

In [None]:
# Optional: name of an entry in `epochs` to restrict the data to (None keeps all bins).
EPOCH = None

# Recording whose spikes/trials belong to the target (y) units. With rec2=None,
# select_data presumably draws X and y from rec1 -- TODO confirm in utl.select_data.
rec_y = rec2 if rec2 is not None else rec1

# select units and trials, and bin data
dfx_bin, dfy_bin = utl.select_data(rec1, rec2=rec2, params=params)

# subtract baseline
dfx_bin0 = utl.subtract_baseline(dfx_bin, rec1.df_spk)
dfy_bin0 = utl.subtract_baseline(dfy_bin, rec_y.df_spk)

# optional: keep only one trial epoch
if EPOCH is not None:
    dfx_bin0 = utl.select_epoch(dfx_bin0, epochs[EPOCH], rec1.df_trl)
    dfy_bin0 = utl.select_epoch(dfy_bin0, epochs[EPOCH], rec_y.df_trl)

## regression

In [None]:
# Plain least squares, obtained as a ridge fit whose only candidate alpha is 0.
lin_mods = utl.ridge_regression(
    dfx_bin0, dfy_bin0,
    scoring=params['scoring'],
    alphas=[0],
)
lin_mod = lin_mods.best_estimator_

# Ridge fits over 27 log-spaced alphas from 1e-13 to 1e13; keep the best estimator.
ridge_mods = utl.ridge_regression(
    dfx_bin0, dfy_bin0,
    scoring=params['scoring'],
    alphas=np.logspace(-13, 13, 27),
)
ridge_mod = ridge_mods.best_estimator_

# Plot the ridge grid search, with the linear fit as reference.
utl.plot_gridsearch(
    ridge_mods, 'ridge',
    other_mods={'linear': lin_mods},
    logscale=True,
)

In [None]:
# RRR
rr_mods = utl.reduced_rank_regression(dfx_bin0, dfy_bin0, scoring=params['scoring'])
rr_mod = rr_mods.best_estimator_
utl.plot_gridsearch(rr_mods, 'reduced rank', other_mods={'linear': lin_mods, 'ridge': ridge_mods}, logscale=False)

In [None]:
# plot
Y_pred, scores = utl.get_ypred(dfx_bin0, dfy_bin0, ridge_mod, scoring=params['scoring'])
utl.plot_mean_response(dfy_bin0, Y_pred, scores)

# remote batch processing

In [None]:
# batch process all recordings
# NOTE(review): hard-coded absolute path on a network drive.
p_root = Path(r'X:\Users\Zidan\DataForNico')

if not p_root.is_dir():
    print('Invalid root dir. No access to network drive?')
    # keep p_dirs defined so the batch loop below is a no-op instead of a NameError
    p_dirs = []
else:
    # session folders two levels below the root: <p_root>/<probes>/<recording>/
    # (sorted for a reproducible processing order)
    p_dirs = sorted(p for p in p_root.glob('*/*/') if p.is_dir())

# .mat file name -> region label; the labels name the output folders of each fit
name2region = {
    'MK22_20230301_2H2_g0_JRC_units_probe1.mat' : 'ALM1',
    'MK22_20230301_2H2_g0_JRC_units_probe2.mat' : 'ALM2',
    'MK22_20230303_2H2_g0_JRC_units_probe1.mat' : 'ALM1',
    'MK22_20230303_2H2_g0_JRC_units_probe2.mat' : 'ALM2',
    'MK25_20230314_2H2_g0_JRC_units_probe1.mat' : 'ALM1',
    'MK25_20230314_2H2_g0_JRC_units_probe2.mat' : 'ALM2',
    'ZY78_20211015NP_g0_imec0_JRC_units.mat'    : 'STR',
    'ZY78_20211015NP_g0_JRC_units.mat'          : 'ALM',
    'ZY82_20211028NP_g0_imec0_JRC_units.mat'    : 'STR',
    'ZY82_20211028NP_g0_JRC_units.mat'          : 'ALM',
    'ZY83_20211108NP_g0_imec0_JRC_units.mat'    : 'STR',
    'ZY83_20211108NP_g0_JRC_units.mat'          : 'ALM',
    'ZY113_20220617_NPH2_g0_imec0_JRC_units.mat': 'THA',
    'ZY113_20220617_NPH2_g0_JRC_units.mat'      : 'ALM',
    'ZY113_20220618_NPH2_g0_imec0_JRC_units.mat': 'THA',
    'ZY113_20220618_NPH2_g0_JRC_units.mat'      : 'ALM',
    'ZY113_20220620_NPH2_g0_imec0_JRC_units.mat': 'THA',
    'ZY113_20220620_NPH2_g0_JRC_units.mat'      : 'ALM',
}

## plot response, binned at 1 ms

In [None]:
# Plot the PSTH of every .mat recording below p_root; each figure is saved as a
# .png next to the recording's PSTH file (rec.path_psth).
# '*.mat' (not '*mat') so that only files with a .mat extension are picked up.
for p_mat in sorted(p_root.glob('**/*.mat')):

    # load recording and create missing dataframes
    print(p_mat)

    rec = Recording(p_mat, calc_psth=True)
    rec.plot_psth(path=rec.path_psth.with_suffix('.png'))

## fits

In [None]:
def proc_wrapper(p_out, params, recA, recB):
    """Fit linear, ridge and reduced-rank regressions for one recording pair and save all results.

    Predictors (X) come from ``recA`` and targets (y) from ``recB`` via
    ``utl.select_data(recA, rec2=recB, ...)``. With ``recB=None``, both presumably
    come from ``recA`` (TODO confirm in utl.select_data).

    Parameters
    ----------
    p_out : pathlib.Path
        Output folder, created if missing. If it already contains ``params.json``
        the run is considered complete and skipped.
    params : dict
        Settings forwarded to ``utl.select_data``; this function also reads
        ``'scoring'`` and ``'subtract_baseline'``.
    recA, recB : Recording or None
        Source recordings for X and y.

    Writes CV results (``reg_*.parquet``), grid-search plots (``reg_*.png``), the
    ridge prediction plot and scores (``pred_ridge.png``, ``pred_ridge_scores.csv``)
    and finally ``params.json``.
    """
    # create folder
    p_out.mkdir(exist_ok=True, parents=True)

    # params.json is written last, so its presence marks a completed run
    p_json = p_out / 'params.json'
    if p_json.exists():
        print(f'params.json found. Skipping {p_out}')
        return

    # load data
    dfx_bin, dfy_bin = utl.select_data(recA, rec2=recB, params=params)

    if params['subtract_baseline']:
        # Pass each recording's spike table, as the other utl.subtract_baseline
        # call in this notebook does. With recB=None the targets presumably come
        # from recA.
        rec_y = recA if recB is None else recB
        dfx_bin = utl.subtract_baseline(dfx_bin, recA.df_spk)
        dfy_bin = utl.subtract_baseline(dfy_bin, rec_y.df_spk)

    scoring = params['scoring']

    # linear regression (= ridge with alpha=0)
    lin_mods = utl.ridge_regression(dfx_bin, dfy_bin, scoring=scoring, alphas=[0])
    utl.save_cv_results(lin_mods, path=p_out / 'reg_linear.parquet')

    # ridge regression
    ridge_mods = utl.ridge_regression(dfx_bin, dfy_bin, scoring=scoring, alphas=np.logspace(-13, 13, 27))
    ridge_mod = ridge_mods.best_estimator_
    utl.save_cv_results(ridge_mods, path=p_out / 'reg_ridge.parquet')

    # RRR
    rr_mods = utl.reduced_rank_regression(dfx_bin, dfy_bin, scoring=scoring)
    utl.save_cv_results(rr_mods, path=p_out / 'reg_rrr.parquet')

    # plot regressions
    utl.plot_gridsearch(ridge_mods, 'ridge', other_mods={'linear': lin_mods}, logscale=True, path=p_out / 'reg_ridge.png')
    utl.plot_gridsearch(rr_mods, 'reduced rank', other_mods={'linear': lin_mods, 'ridge': ridge_mods}, logscale=False, path=p_out / 'reg_rrr.png')

    # prediction with the best ridge model
    Y_pred, scores = utl.get_ypred(dfx_bin, dfy_bin, ridge_mod, scoring=scoring)
    utl.plot_mean_response(dfy_bin, Y_pred, scores, path=p_out / 'pred_ridge.png')
    pd.Series(scores, name=scoring).to_csv(p_out / 'pred_ridge_scores.csv', index=False)

    # save params (completion marker, see above)
    pd.Series(params).to_json(p_json)

In [None]:
# Parameter sets for the batch fits: name -> settings passed to proc_wrapper.
# Every variant is run with ('all_depth_' prefix) and without ('all_depth_raw_')
# baseline subtraction; the 'allSW' variant at 0.1 s sets the spike-width
# threshold to 0.
d_params = {
    f'{prefix}{suffix}': {
        'bin_size'          : bin_size,
        'thresh_sw'         : thresh_sw,
        'thresh_rate'       : 1,
        'thresh_trials'     : 0.9,
        'scoring'           : 'r2',
        'subtract_baseline' : subtract_baseline,
    }
    # (name suffix, bin size in s, min spike width in ms)
    for suffix, bin_size, thresh_sw in (
        ('0.05',      0.05, 0.5),
        ('0.1',       0.1,  0.5),
        ('allSW_0.1', 0.1,  0.0),
        ('0.2',       0.2,  0.5),
        ('0.5',       0.5,  0.5),
    )
    for prefix, subtract_baseline in (('all_depth_', True), ('all_depth_raw_', False))
}

In [None]:
# Run every parameter set on every session folder: both cross-region directions
# (A -> B, B -> A) plus each region on its own.
# The loop variable is `run_params` (not `params`) so the global settings used by
# the example cells are not overwritten.
for name, run_params in d_params.items():
    print(f'>>>> starting parameter set {name}....')

    for p_dir in p_dirs:

        print(p_dir)

        # load recordings; sorted so the A/B assignment does not depend on filesystem order
        p_mats = sorted(p_dir.glob('*.mat'))
        if len(p_mats) != 2:
            raise ValueError(f'Expected exactly 2 .mat files in {p_dir}, found {len(p_mats)}')
        p_matA, p_matB = p_mats
        recA, recB = Recording(p_matA), Recording(p_matB)
        regA, regB = name2region[p_matA.name], name2region[p_matB.name]

        # (output subfolder, predictor recording, target recording)
        jobs = [
            (f'{regA}_{regB}', recA, recB),  # regA -> regB
            (f'{regB}_{regA}', recB, recA),  # regB -> regA
            (regA, recA, None),              # regA alone
            (regB, recB, None),              # regB alone
        ]
        for folder, rec_x, rec_y in jobs:
            proc_wrapper(p_dir / name / folder, run_params, rec_x, rec_y)
        
        

## plot

In [None]:
def load_scores(ps_csv):
    """Collect ridge prediction scores from ``pred_ridge_scores.csv`` files into one long DataFrame.

    Metadata is read from the folder layout written by the batch cells:
    ``<p_root>/<probes>/<recording>/<parameter set>/<interaction>/pred_ridge_scores.csv``.

    Parameters
    ----------
    ps_csv : iterable of pathlib.Path
        Score files; the first column of each file is read as the score.

    Returns
    -------
    pandas.DataFrame
        One row per CSV row, with these columns:

        * ``unit``: row position within its file.
        * ``score``.
        * ``region``: interaction folder name after its first ``'_'``, or the whole
          name if there is no ``'_'``.
        * ``interaction``: folder name.
        * ``recording``: 3rd parent folder.
        * ``animal``: recording name up to its first ``'_'``.
        * ``probes``: 4th parent folder.
        * ``interaction_``: ``interaction`` with ALM1/ALM2 merged into ALM.

        If ``ps_csv`` is empty, the frame is empty but has the same columns.
    """
    columns = ['unit', 'score', 'region', 'interaction', 'recording', 'animal', 'probes']

    dfs = []
    for p_csv in ps_csv:

        folder = p_csv.parent.name
        rec = p_csv.parents[2].name
        df = pd.read_csv(p_csv)

        dfs.append(pd.DataFrame(data={
            'unit': df.index,
            'score': df.iloc[:, 0],
            'region': folder.split('_', 1)[-1],
            'interaction': folder,
            'recording': rec,
            'animal': rec.split('_')[0],
            'probes': p_csv.parents[3].name,
        }))

    if dfs:
        df = pd.concat(dfs, ignore_index=True)
    else:
        # pd.concat([]) raises ValueError; return an empty frame with the expected columns
        df = pd.DataFrame(columns=columns)

    return df.assign(
        interaction_=df.loc[:, 'interaction'].map(lambda x: x.replace('ALM1', 'ALM').replace('ALM2', 'ALM'))
    )

In [None]:
# For each parameter set, save two figures:
#   * scores per interaction, split by probe combination;
#   * scores pooled per interaction, with ALM1/ALM2 merged into ALM.
p_plot = p_root / 'plots'
p_plot.mkdir(exist_ok=True)  # loop-invariant, so created once

for name in d_params:
    print(name)

    p_csvs = [*p_root.glob(f'**/{name}/**/pred_ridge_scores.csv')]
    if not p_csvs:
        print(f'No CSV files found: skipping {name}')
        continue

    df = load_scores(p_csvs)

    g = sns.catplot(df, x='interaction', y='score', col='probes', hue='recording', sharex=False, facet_kws={'ylim': (-1, 1)}, dodge=True)
    g.fig.savefig(p_plot / f'scores_{name}.png')
    plt.close(g.fig)

    # order: shorter interaction names (single regions) first, then alphabetically
    d = (
        df.assign(_sort=df.loc[:, 'interaction_'].str.len())
          .sort_values(by=['_sort', 'interaction_'])
    )
    fig, ax = plt.subplots(figsize=(12, 6))
    sns.stripplot(data=d, ax=ax, x='interaction_', y='score', hue='recording', dodge=True)
    ax.set_ylim((-1, 1))
    fig.savefig(p_plot / f'scores_pooled_{name}.png')
    plt.close(fig)

### Selected comparisons between parameter sets

In [None]:
# baseline subtraction: ridge scores at bin size 0.1 with vs. without baseline subtraction

name = 'all_depth_0.1_baseline_subtraction'
p_plot = p_root / 'plots/compare_settings'
# parents=True: 'plots' does not exist yet if the per-parameter-set plotting cell was not run
p_plot.mkdir(exist_ok=True, parents=True)

p_csvs = [ *p_root.glob('**/all_depth_0.1/**/pred_ridge_scores.csv')  ]
df1 = load_scores(p_csvs)
df1.loc[:, 'sub_baseline'] = True

p_csvs = [ *p_root.glob('**/all_depth_raw_0.1/**/pred_ridge_scores.csv')  ]
df2 = load_scores(p_csvs)
df2.loc[:, 'sub_baseline'] = False

df = pd.concat([df1, df2], ignore_index=True)

g = sns.catplot(data=df, 
                x='interaction', y='score', col='recording', col_wrap=3, hue='sub_baseline',
                sharex=False, facet_kws={'ylim': (-1, 1)},
                kind='box', whis=0, fliersize=0, palette='pastel',
                dodge=True, height=4, aspect=.8)
g.map(sns.stripplot, 'interaction', 'score', 'sub_baseline', dodge=True, palette='deep', edgecolor='auto', linewidth=.5)
g.fig.savefig(p_plot / f'scores_{name}.png')
plt.close(g.fig)


In [None]:
# bin size: ridge scores across bin sizes (all parameter sets with baseline subtraction)
name = 'all_depth_bin_size'
p_plot = p_root / 'plots/compare_settings'
# parents=True: 'plots' does not exist yet if the per-parameter-set plotting cell was not run
p_plot.mkdir(exist_ok=True, parents=True)

dfs = []
for _bin_size in (0.05, 0.1, 0.2, 0.5):
    # f'{0.05}' == '0.05' etc., i.e. the folder names of the d_params entries
    p_csvs = [*p_root.glob(f'**/all_depth_{_bin_size}/**/pred_ridge_scores.csv')]
    df_bin = load_scores(p_csvs)
    df_bin.loc[:, 'bin_size'] = _bin_size
    dfs.append(df_bin)

df = pd.concat(dfs, ignore_index=True)

g = sns.catplot(data=df, 
                x='interaction', y='score', col='recording', col_wrap=3, hue='bin_size',
                sharex=False, facet_kws={'ylim': (-1, 1)},
                kind='box', whis=0, fliersize=0, palette='pastel',
                dodge=True, height=4, aspect=1.2)
g.map(sns.stripplot, 'interaction', 'score', 'bin_size', dodge=True, palette='deep', edgecolor='auto', linewidth=.5)
g.fig.savefig(p_plot / f'scores_{name}.png')
plt.close(g.fig)


In [None]:
# spike width: ridge scores at bin size 0.1 with min spike width 0.5 ms vs. no threshold

name = 'all_depth_0.1_spike_width'
p_plot = p_root / 'plots/compare_settings'
# parents=True: 'plots' does not exist yet if the per-parameter-set plotting cell was not run
p_plot.mkdir(exist_ok=True, parents=True)

p_csvs = [ *p_root.glob('**/all_depth_0.1/**/pred_ridge_scores.csv')  ]
df1 = load_scores(p_csvs)
df1.loc[:, 'min_spike_width'] = 0.5

p_csvs = [ *p_root.glob('**/all_depth_allSW_0.1/**/pred_ridge_scores.csv')  ]
df2 = load_scores(p_csvs)
df2.loc[:, 'min_spike_width'] = 0.0

df = pd.concat([df1, df2], ignore_index=True)

g = sns.catplot(data=df, 
                x='interaction', y='score', col='recording', col_wrap=3, hue='min_spike_width',
                sharex=False, facet_kws={'ylim': (-1, 1)},
                kind='box', whis=0, fliersize=0, palette='pastel',
                dodge=True, height=4, aspect=1.2)
g.map(sns.stripplot, 'interaction', 'score', 'min_spike_width', dodge=True, palette='deep', edgecolor='auto', linewidth=.5)
g.fig.savefig(p_plot / f'scores_{name}.png')
plt.close(g.fig)


# Other (scratch analyses)

In [None]:
## all vs one
# One ridge model per target unit (at the alpha selected for the population ridge
# model), plotted next to the population model's prediction.
# Inputs come from the single-recording example cells (dfx_bin0, dfy_bin0, ridge_mod).
X, Y = utl.get_matrix(dfx_bin0), utl.get_matrix(dfy_bin0)
# NOTE(review): assumes utl.get_matrix orders Y's columns like these units -- confirm
units = dfy_bin0.loc[:, 'unit'].unique()

# get scores for best alpha
ridge_scores = {}
alpha = ridge_mod.get_params()['mod__alpha']

for y, u in zip(Y.T, units):
    mods = utl.ridge_regression(X, y, alphas=[alpha])
    mod = mods.best_estimator_
    # NOTE(review): in-sample score on the fitting data, not cross-validated
    ridge_scores[u] = mod.score(X, y)

# prediction
Y_pred = ridge_mod.predict(X)
dfy_pred = utl.matrix2df(Y_pred, dfy_bin0)

# plot true and predicted with score
# NOTE(review): utl.plot_psth is not called anywhere else in this notebook -- confirm it exists
utl.plot_psth(dfy_bin0, params['bin_size'], df2=dfy_pred, scores=ridge_scores)

In [None]:

def matrix2df(X, dfx):
    """Write a (time bin x unit) matrix back into the long format of ``dfx``.

    ``dfx`` needs the columns ``T``, ``bins``, ``trial``, ``unit`` and ``dfr``.
    It is pivoted to a ``T`` x ``unit`` table (mean of ``dfr``, missing entries 0),
    whose values are then overwritten with ``X``. ``X`` must therefore have that
    table's shape and its row/column order. The result is stacked back to one row
    per (T, unit) pair; ``bins`` and ``trial`` are looked up from ``dfx`` by ``T``.

    Returns
    -------
    pandas.DataFrame
        Columns ``unit``, ``trial``, ``dfr`` (values from ``X``), ``bins``, ``T``.
    """
    times = dfx.loc[:, 'T'].values
    # T -> bin and T -> trial lookups (if a T repeats, its last row wins)
    bin_of = dict(zip(times, dfx.loc[:, 'bins'].values))
    trial_of = dict(zip(times, dfx.loc[:, 'trial'].values))

    grid = pd.pivot_table(dfx, values='dfr', index='T', columns='unit').fillna(0)
    grid.loc[:, :] = X
    stacked = grid.stack()

    # (T, unit) pairs of the stacked index as a 2-column array
    levels = stacked.index.to_frame().values
    t_col, unit_col = levels[:, 0], levels[:, 1]

    return pd.DataFrame(data={
        'unit': unit_col.astype(int),
        'trial': [trial_of[key] for key in t_col],
        'dfr': stacked.values,
        'bins': [bin_of[key] for key in t_col],
        'T': t_col,
    })
    
# all vs one for best ridge model
def all_vs_one(dfx, dfy, ridge_mod):
    """Score one single-target ridge model per unit of ``dfy`` and predict ``dfy`` with ``ridge_mod``.

    Each per-unit model is fit with ``utl.ridge_regression`` using the
    ``mod__alpha`` parameter of ``ridge_mod``.

    Returns
    -------
    dfy_pred : pandas.DataFrame
        ``ridge_mod.predict(X)`` converted back to long format via ``utl.matrix2df``.
    ridge_scores : dict
        unit -> ``best_estimator_.score(X, y)`` of that unit's model (in-sample).
    """
    alpha = ridge_mod.get_params()['mod__alpha']

    X = utl.get_matrix(dfx)
    Y = utl.get_matrix(dfy)
    # NOTE(review): units are in order of first appearance in dfy; this assumes
    # utl.get_matrix orders Y's columns the same way -- confirm.
    units = dfy.loc[:, 'unit'].unique()

    ridge_scores = {
        unit: utl.ridge_regression(X, y, alphas=[alpha]).best_estimator_.score(X, y)
        for y, unit in zip(Y.T, units)
    }

    # population-model prediction in long format
    dfy_pred = utl.matrix2df(ridge_mod.predict(X), dfy)

    return dfy_pred, ridge_scores

## Raw vs. trial-aligned spike times

In [None]:
from scipy.io import loadmat
# NOTE(review): this rebinds `plt` to matplotlib.pylab (the first cell binds matplotlib.pyplot)
import matplotlib.pylab as plt
import seaborn as sns

p = './data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe2.mat'
# force_overwrite=True: presumably regenerates Recording's derived data -- confirm in recording.py
rec = Recording(p, force_overwrite=True)

# The raw MATLAB struct; its mat_struct entries are read via vars(...) in the cells below.
m = loadmat(p, squeeze_me=True, struct_as_record=False)

In [None]:
# Long-format table of raw spike times: one row per spike, with units numbered
# from 1 in the order of m['unit'].
# squeeze_me=True turns single-element MATLAB arrays into scalars, so atleast_1d
# is applied to the unit list and to each unit's spike times (a 1-spike unit
# would otherwise make the DataFrame constructor fail on all-scalar input).
frames = []
for i, u in enumerate(np.atleast_1d(m['unit'])):
    t = np.atleast_1d(vars(u)['RawSpikeTimes'])
    frames.append(pd.DataFrame(data={
        'T': t,
        'unit': i + 1,
    }))

# single concat instead of growing the frame inside the loop
df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame(columns=['T', 'unit'])

raw = df.loc[ df.loc[:, 'unit'] == 1 ].loc[:, 'T'].values

In [None]:
# First unit: trial-aligned spike times and the trial index of each spike
# (presumably 1-based, coming from MATLAB).
spk, idx = (vars(m['unit'][0])[field] for field in ('SpikeTimes', 'Trial_idx_of_spike'))

In [None]:
# Trial onsets divided by 2.5e4 -- presumably sample index -> seconds at 25 kHz
# (TODO confirm the sampling rate).
t0s = np.array([vars(trial)['onset'] for trial in m['trial_info']]) / 2.5e4

# Per-trial behavioral event times stored with the first unit.
behavior = vars(vars(m['unit'][0])['Behavior'])
t_cue = behavior['Sample_start']
t_lck = behavior['First_lick']

In [None]:
# Distribution over trials of the sample-start time minus the first-lick time.
ax = sns.histplot(t_cue - t_lck)
ax.set_xlabel('Sample_start - First_lick [s]')
ax.set_ylabel('trials');

In [None]:
# Alignment check for the first unit. Rows 2-29 show its trial-aligned spikes
# shifted by the trial onset t0. Dotted lines mark the onset (C0), onset + cue
# (C1) and onset + first lick (C2). The top row shows the same unit's raw spike
# times below 300, for comparison.
fig, ax = plt.subplots(figsize=(10, 4))

for i in range(2, 30):
    # spikes belonging to trial index i
    mask = idx == i
    s = spk[mask]


    # t0 = df_trl.loc[ df_trl.loc[:, 'trial'] == i ].loc[:, 'T_0'].item()
    # NOTE(review): lick/cue use index i + 1 but the onset uses i - 1; the
    # offsets differ by 2 -- confirm which indexing matches Trial_idx_of_spike.
    l = t_lck[i + 1]
    c = t_cue[i + 1]
    t0 = t0s[i-1]
    # keep spikes from 2 before the cue to 2 after the first lick
    mask = (s > (c-2)) & (s < ( l + 2) )
    s = s[mask]

    ax.axvline(t0, c='C0', ls=':', lw=1)
    ax.axvline(t0 + c, c='C1', ls=':', lw=1)
    ax.axvline(l + t0, c='C2', ls=':', lw=1)

    ax.eventplot(s + t0, lineoffsets=i, color=f'C{i}')

# raw (not trial-aligned) spikes, drawn one row above the last trial (i is 29 here)
mask = raw < 300
ax.eventplot(raw[mask], lineoffsets=i + 1, ls='-')

ax.set_xlabel('times [s]')
ax.set_ylabel('trial index')
ax.set_xlim((0, 60))

## new spike processing

In [None]:
# Same probe as in the raw-vs-trial section, but without force_overwrite
# (presumably reuses previously generated data -- confirm in recording.py).
p = './data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe2.mat'
rec = Recording(
    p,
    force_overwrite=False,
)

In [None]:
# Raster of unit 14 from rec.df_spk: one row per trial (trial number as row offset),
# restricted to trials 0-300.
fig, ax = plt.subplots()

df = rec.df_spk.groupby('unit').get_group(14)
for trl, d in df.groupby('trial'):
    x = d.loc[:, 't'].values
    ax.eventplot(x, lineoffsets=trl)
ax.set_xlim((-1, 4))
ax.set_ylim((0, 300))
ax.set(xlabel='t [s]', ylabel='trial', title='unit 14');

In [None]:
# Raster of unit 14 from rec.df_spk: one row per trial (trial number as row offset),
# over all trials (no y-limit).
fig, ax = plt.subplots()

df = rec.df_spk.groupby('unit').get_group(14)
for trl, d in df.groupby('trial'):
    x = d.loc[:, 't'].values
    ax.eventplot(x, lineoffsets=trl)
ax.set_xlim((-1, 4))
ax.set(xlabel='t [s]', ylabel='trial', title='unit 14, all trials');

In [None]:
# PSTH of unit 14 from the dataframes currently attached to `rec`.
# To recompute them first:
#   rec.df_spk = rec._load_spike_times(); rec.df_psth = rec._calculate_psth()
rec.plot_psth(xlims=(-1, 6), unts=[14])

In [None]:
# Recompute spike times and PSTH via Recording's private helpers, then plot
# unit 14 in the window (-1, 1).
rec.df_spk = rec._load_spike_times()
print('loaded')
rec.df_psth = rec._calculate_psth()
print('calculated')  # previously a second, misleading 'loaded'
rec.plot_psth(unts=[14], xlims=(-1, 1))

## Restrict to rewarded trials

In [None]:

# Keep only rewarded trials in rec.df_prec of each loaded recording.
# NOTE(review): this replaces df_prec on the Recording objects; reload them to undo.
# Dedicated names are used so globals such as `m` (the loadmat struct) are not overwritten.
for _rec in (rec1, rec2):
    if _rec is None:  # single-probe setup: no second recording
        continue
    rewarded = _rec.df_trl.loc[_rec.df_trl.loc[:, 'response'] == 'Reward', 'trial']
    _rec.df_prec = _rec.df_prec.loc[_rec.df_prec.loc[:, 'trial'].isin(rewarded)]