In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

from recording import Recording
import utils as utl

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("whitegrid")

import numpy as np
import pandas as pd
from scipy.ndimage import uniform_filter1d

# global settings
# note: interval-valued entries are (lower, upper) tuples; `None` means no limit on that side
params = dict(
    # bin size [s]
    bin_size=0.2,

    # --- filter units ---
    # thresholds for source (src) and target (trg) population
    # note: for within-region analysis, only source thresholds apply
    rate_src=(1, None),              # firing rate [Hz] interval
    rate_trg=(1, None),
    spike_width_src=(None, None),    # spike width [ms] interval
    spike_width_trg=(.5, None),
    # fraction of trials to be covered (0.9 = 90 %)
    # note: the valid trial range may differ greatly between neurons, so neurons are dropped
    #       until the remaining ones cover at least this part of the maximum trial range
    perc_trials=0.9,

    # --- filter trials ---
    first_lick=(None, None),         # lick time [s] interval, relative to cue
    # trial types (list of strings), matched against the beginning of the
    # `unit(1).Behavior.stim_type_name` strings; an empty list includes all trials
    type_incl=['l_n'],

    # --- score reported in the output ---
    # note: the score only affects reporting; it does not enter the loss function
    #       and therefore does not affect the model parameters
    # note: see https://scikit-learn.org/stable/modules/model_evaluation.html for available scorers
    scoring='r2',

    # --- subtract baseline ---
    # note: subtracts the average firing rate during the pre-cue period, per trial
    subtract_baseline=True,
)

# define trial epochs
# epoch_name = (start, end, alignment)
# note:
#   (i): epochs define the half-open interval [start, end)
#  (ii): lick times may fall between bins, because bins are aligned to cue
# (iii): the default epoch ('all') is defined as ['cue' - 2 s, 'lick' + 2 s); only complete bins are kept.
#        Example as originally documented: with bin_size = 0.2 and first_lick = 1.81 the last bin of the
#        trial is [1.98, 2.00)
#        NOTE(review): these numbers look inconsistent with bin_size = 0.2 - confirm against the binning code
epochs = dict(
    all=(None, None, 'cue'),         # None only works with 'cue' alignment
    pre_cue=(-.6, .0, 'cue'),
    post_cue1=(.0, .6, 'cue'),
    post_cue2=(.6, 1.2, 'cue'),
    pre_lick=(-.6, .0, 'lick'),
    post_lick=(.0, .6, 'lick'),
)

# define sets of trials
# e.g. 'lick_0.6': all trials with lick times relative to cue > 0.6 s
trial_groups = {f'lick_{t_min}': (t_min, None, 'lick') for t_min in (0.6, 1.2)}

# example

## select units: single region

## choose recordings

In [None]:
# Choose exactly one of the configurations below: uncomment it and comment out the others.
# `rec1` is passed as the first recording to `utl.select_data`, `rec2` as its `rec2=` argument.
# With `rec2 = None` the later cells fall back to `rec1` for the second data set (single-probe case).
# NOTE(review): presumably rec1 is the source/predictor probe and rec2 the target probe - confirm in `utl.select_data`.

# # single ALM probe
# rec1 = Recording('./data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe1.mat')
# rec2 = None

# # ALM-ALM
# rec1 = Recording('./data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe2.mat')
# rec2 = Recording('./data/zidan/ALM_ALM/MK22_20230301/MK22_20230301_2H2_g0_JRC_units_probe1.mat')

# ALM-Str (imec0: STR)  <- currently active
rec2 = Recording('./data/zidan/ALM_STR/ZY78_20211015/ZY78_20211015NP_g0_JRC_units.mat')
rec1 = Recording('./data/zidan/ALM_STR/ZY78_20211015/ZY78_20211015NP_g0_imec0_JRC_units.mat')

# # ALM-Thal (imec0: Thal)
# rec2 = Recording('./data/zidan/ALM_Thal/ZY113_20220617/ZY113_20220617_NPH2_g0_JRC_units.mat')
# rec1 = Recording('./data/zidan/ALM_Thal/ZY113_20220617/ZY113_20220617_NPH2_g0_imec0_JRC_units.mat')

## preprocess data

In [None]:
# select units and trials, and bin data
dfx_bin, dfy_bin = utl.select_data(rec1, rec2=rec2, params=params)

# single-probe case (rec2 is None): rec1 supplies spikes and trials for both data sets
rec_trg = rec1 if rec2 is None else rec2

# subtract baseline, if requested in `params` (same rule as the batch processing further down);
# otherwise the "0" frames are simply the raw binned data
if params['subtract_baseline']:
    dfx_bin0 = utl.subtract_baseline(dfx_bin, rec1.df_spk)
    dfy_bin0 = utl.subtract_baseline(dfy_bin, rec_trg.df_spk)
else:
    dfx_bin0, dfy_bin0 = dfx_bin, dfy_bin

# optional: restrict to a single epoch
# note: the regression cells below use the unrestricted `dfx_bin0` / `dfy_bin0`;
#       pass the `*_epo` frames instead to fit on this epoch only
dfx_bin0_epo = utl.select_epoch(dfx_bin0, epochs['pre_lick'], rec1.df_trl)
dfy_bin0_epo = utl.select_epoch(dfy_bin0, epochs['pre_lick'], rec_trg.df_trl)

## regression

In [None]:
# Plain linear regression, obtained as a ridge fit whose only candidate alpha is 0
lin_mods = utl.ridge_regression(
    dfx_bin0, dfy_bin0,
    alphas=[0],
    scoring=params['scoring'],
)
lin_mod = lin_mods.best_estimator_

# Ridge regression over a grid of 27 log-spaced alphas (1e-13 ... 1e13, one per decade)
alpha_grid = np.logspace(-13, 13, 27)
ridge_mods = utl.ridge_regression(
    dfx_bin0, dfy_bin0,
    alphas=alpha_grid,
    scoring=params['scoring'],
)
ridge_mod = ridge_mods.best_estimator_

# Compare the ridge search against the linear fit (log-scaled axis)
utl.plot_gridsearch(
    ridge_mods, 'ridge',
    other_mods={'linear': lin_mods},
    logscale=True,
)

In [None]:
# Reduced rank regression (RRR) on the same data as the linear/ridge fits
rr_mods = utl.reduced_rank_regression(
    dfx_bin0, dfy_bin0,
    scoring=params['scoring'],
)
rr_mod = rr_mods.best_estimator_

# Show the RRR search next to the previously fitted linear and ridge models
reference_mods = {'linear': lin_mods, 'ridge': ridge_mods}
utl.plot_gridsearch(rr_mods, 'reduced rank', other_mods=reference_mods, logscale=False)

In [None]:
# Predict the target activity with the selected ridge model, then plot it against the observed data
Y_pred, scores = utl.get_ypred(
    dfx_bin0, dfy_bin0, ridge_mod,
    scoring=params['scoring'],
)
utl.plot_mean_response(dfy_bin0, Y_pred, scores)

# remote batch processing

In [None]:
# batch process all recordings
p_root = Path(r'X:\Users\Zidan\DataForNico\data')

# top-level folders below `p_root` that are searched for recording folders
probe_pair_dirs = ['ALM_ALM', 'ALM_STR', 'ALM_Thal']

# Always define `p_dirs`, so that later cells see an empty list (and do nothing)
# instead of raising a NameError when the network drive is not reachable.
p_dirs = []
if not p_root.is_dir():
    print('Invalid root dir. No access to network drive?')
else:
    for pair in probe_pair_dirs:
        # keep directories only; stray files next to the recording folders are ignored
        p_dirs += [p for p in p_root.glob(f'{pair}/*') if p.is_dir()]

# file name of each .mat file -> label of the region it was recorded from
# (used to name the output folders of the batch fits)
name2region = {
    'MK22_20230301_2H2_g0_JRC_units_probe1.mat' : 'ALM1',
    'MK22_20230301_2H2_g0_JRC_units_probe2.mat' : 'ALM2',
    'MK22_20230303_2H2_g0_JRC_units_probe1.mat' : 'ALM1',
    'MK22_20230303_2H2_g0_JRC_units_probe2.mat' : 'ALM2',
    'MK25_20230314_2H2_g0_JRC_units_probe1.mat' : 'ALM1',
    'MK25_20230314_2H2_g0_JRC_units_probe2.mat' : 'ALM2',
    'ZY78_20211015NP_g0_imec0_JRC_units.mat'    : 'STR',
    'ZY78_20211015NP_g0_JRC_units.mat'          : 'ALM',
    'ZY82_20211028NP_g0_imec0_JRC_units.mat'    : 'STR',
    'ZY82_20211028NP_g0_JRC_units.mat'          : 'ALM',
    'ZY83_20211108NP_g0_imec0_JRC_units.mat'    : 'STR',
    'ZY83_20211108NP_g0_JRC_units.mat'          : 'ALM',
    'ZY113_20220617_NPH2_g0_imec0_JRC_units.mat': 'THA',
    'ZY113_20220617_NPH2_g0_JRC_units.mat'      : 'ALM',
    'ZY113_20220618_NPH2_g0_imec0_JRC_units.mat': 'THA',
    'ZY113_20220618_NPH2_g0_JRC_units.mat'      : 'ALM',
    'ZY113_20220620_NPH2_g0_imec0_JRC_units.mat': 'THA',
    'ZY113_20220620_NPH2_g0_JRC_units.mat'      : 'ALM',
}

## precalculate bins

In [None]:
# Run the binning once per bin size [s] for every .mat file found below `p_root`.
# NOTE(review): `_assign_df` presumably caches the result at `rec.path_bin` - confirm in `Recording`.
for bin_size in [.2, 1e-3]:
    print(bin_size)

    for p_mat in p_root.glob('**/*mat'):
        print(p_mat)
        rec = Recording(p_mat)

        # one file per bin size inside the recording's tmp folder
        rec.path_bin = rec._path_tmp / f'bin{bin_size}.hdf'
        df_bin = rec._assign_df(
            rec.path_bin,
            utl.bin_spikes,
            dict(df_spk=rec.df_spk, df_trl=rec.df_trl, bin_size=bin_size),
        )

In [None]:

# Quick check on 1 ms bins: only the first .mat file found is processed (note the `break`)
bin_size = 1e-3
for p_mat in p_root.glob('**/*mat'):

    print(p_mat)
    rec = Recording(p_mat)

    rec.path_bin = rec._path_tmp / f'bin{bin_size}.hdf'
    df_bin = rec._assign_df(
        rec.path_bin,
        utl.bin_spikes,
        dict(df_spk=rec.df_spk, df_trl=rec.df_trl, bin_size=bin_size),
    )

    # moving average along axis 0 with a window of 50 samples
    df_bin = df_bin.apply(uniform_filter1d, axis=0, size=50)
    # utl.plot_mean_response(df_bin, path=rec._path_tmp / 'psth1ms.png')

    break


## fit

In [None]:
def proc_wrapper(p_out, params, recX, recY):
    """Fit linear, ridge and reduced rank regressions per epoch and save the results.

    For every entry of the notebook-level `epochs` dict, a sub-folder `p_out/<epoch>` is
    created and filled with the cross-validation results (`reg_*.parquet`), the grid
    search plots (`reg_*.png`), the ridge prediction plot (`pred_ridge.png`) and the
    per-unit scores (`pred_ridge_scores.csv`). After all epochs are processed, `params`
    is written to `p_out/params.json`. That file doubles as a "done" marker: if it
    already exists, the function returns immediately.

    Parameters
    ----------
    p_out : pathlib.Path
        Output folder; created if missing.
    params : dict
        Settings forwarded to `utl.select_data`; the keys 'subtract_baseline' and
        'scoring' are also read here.
    recX : Recording
        Recording passed to `utl.select_data` as first argument (yields `dfx_bin`).
    recY : Recording or None
        Recording passed as `rec2`. If None, `recX` also supplies the spike and trial
        tables used for the Y data.

    Returns
    -------
    None
    """

    # create folder
    p_out.mkdir(exist_ok=True, parents=True)

    # path for params.json
    p_json = p_out / 'params.json'
    if p_json.exists():
        print(f'params.json found. Skipping {p_out}')
        return

    # recording that supplies spike/trial tables for the Y data
    rec_trg = recX if recY is None else recY
    # session label for log messages; recY may be None, so never dereference it directly
    sess_y = 'None' if recY is None else recY.session

    # load data, select trials and units based on `params`
    dfx_bin, dfy_bin = utl.select_data(recX, rec2=recY, params=params)

    # bail out before any further processing if the selection left nothing to fit
    if dfx_bin.empty or dfy_bin.empty:
        print(f'INFO no data left, skipping recX: {recX.session}, recY: {sess_y}')
        return

    # subtract baseline, if required
    if params['subtract_baseline']:
        dfx_bin = utl.subtract_baseline(dfx_bin, recX.df_spk)
        dfy_bin = utl.subtract_baseline(dfy_bin, rec_trg.df_spk)

    # the alpha grid does not depend on the epoch, so build it once
    scoring = params['scoring']
    alphas_ridge = np.logspace(-13, 13, 27)

    # do fit for each epoch separately
    for name, epo in epochs.items():

        # output folder for epoch
        p_out_epo = p_out / name
        p_out_epo.mkdir(exist_ok=True)

        # select subset of data
        dfx_bin_epo = utl.select_epoch(dfx_bin, epo, recX.df_trl)
        dfy_bin_epo = utl.select_epoch(dfy_bin, epo, rec_trg.df_trl)

        if dfx_bin_epo.empty:
            print(f'INFO no data left in epoch {name}, skipping recX: {recX.session}, recY: {sess_y}')
            continue

        # linear regression (= ridge with alpha=0)
        lin_mods = utl.ridge_regression(dfx_bin_epo, dfy_bin_epo, scoring=scoring, alphas=[0])
        utl.save_cv_results(lin_mods, path=p_out_epo / 'reg_linear.parquet')

        # ridge regression
        ridge_mods = utl.ridge_regression(dfx_bin_epo, dfy_bin_epo, scoring=scoring, alphas=alphas_ridge)
        ridge_mod = ridge_mods.best_estimator_
        utl.save_cv_results(ridge_mods, path=p_out_epo / 'reg_ridge.parquet')

        # RRR
        rr_mods = utl.reduced_rank_regression(dfx_bin_epo, dfy_bin_epo, scoring=scoring)
        utl.save_cv_results(rr_mods, path=p_out_epo / 'reg_rrr.parquet')

        # plot regressions
        utl.plot_gridsearch(ridge_mods, 'ridge', other_mods={'linear': lin_mods}, logscale=True, path=p_out_epo / 'reg_ridge.png')
        utl.plot_gridsearch(rr_mods, 'reduced rank', other_mods={'linear': lin_mods, 'ridge': ridge_mods}, logscale=False, path=p_out_epo / 'reg_rrr.png')

        # prediction with the best ridge model; scores are stored with one row per unit
        Y_pred, scores = utl.get_ypred(dfx_bin_epo, dfy_bin_epo, ridge_mod, scoring=scoring)
        utl.plot_mean_response(dfy_bin_epo, Y_pred, scores, path=p_out_epo / 'pred_ridge.png')
        ds = pd.Series(scores, name=scoring)
        ds.index.name = 'unit'
        ds.to_csv(p_out_epo / 'pred_ridge_scores.csv', index=True)

    # save params (written last, so an interrupted run is repeated next time)
    pd.Series(params).to_json(p_json)

In [None]:
def _param_set(**overrides):
    """Return a fresh copy of the baseline parameter set with `overrides` applied.

    The baseline is rebuilt on every call, so no two parameter sets share mutable
    values (e.g. the `type_incl` list). Overriding an existing key keeps its position,
    so every set lists its keys in the same order.
    """
    settings = {
        'bin_size'          : 0.2,
        'spike_width_src'   : (None, None),
        'spike_width_trg'   : (.5, None),
        'first_lick'        : (None, None),
        'type_incl'         : ['l_n'],
        'rate_src'          : (1, None),
        'rate_trg'          : (1, None),
        'perc_trials'       : 0.9,
        'scoring'           : 'r2',
        'subtract_baseline' : True,
    }
    settings.update(overrides)
    return settings


# parameter sets for batch processing: name -> params dict
# each set differs from the baseline ('all_depth_0.2') in at most one entry
d_params = {
    # baseline
    'all_depth_0.2'        : _param_set(),
    # target units with spike width below .5 instead of above
    'all_depth_0.2_sw'     : _param_set(spike_width_trg=(None, .5)),
    # upper limit on first lick time
    'all_depth_0.2_lck0.6' : _param_set(first_lick=(None, .6)),
    'all_depth_0.2_lck1.2' : _param_set(first_lick=(None, 1.2)),
    'all_depth_0.2_lck1.8' : _param_set(first_lick=(None, 1.8)),
    # no baseline subtraction
    'all_depth_raw_0.2'    : _param_set(subtract_baseline=False),
}

In [None]:
# Fit every parameter set on every recording folder.
# note: the loop variable is deliberately NOT called `params`, so the global `params`
#       from the settings cell is not overwritten by the last parameter set
for name, param_set in d_params.items():
    print(f'>>>> starting parameter set {name}....')

    for p_dir in p_dirs:

        print(p_dir)

        # exactly two .mat files (one per probe) are expected per folder;
        # sorted, because glob does not guarantee an order
        p_mats = sorted(p_dir.glob('*.mat'))
        if len(p_mats) != 2:
            print(f'WARNING expected 2 .mat files in {p_dir}, found {len(p_mats)}. Skipping')
            continue
        p_matA, p_matB = p_mats

        # load recordings
        recA, recB = Recording(p_matA, tmp_dir='analysis/tmp'), Recording(p_matB, tmp_dir='analysis/tmp')
        regA, regB = name2region[p_matA.name], name2region[p_matB.name]

        # (output folder name, recX, recY): both directions across regions,
        # then each region on its own (recY = None)
        jobs = [
            (f'{regA}_{regB}', recA, recB),
            (f'{regB}_{regA}', recB, recA),
            (regA,             recA, None),
            (regB,             recB, None),
        ]
        for inter, recX, recY in jobs:
            p_out = p_dir / f'analysis/{name}/{inter}'
            proc_wrapper(p_out, param_set, recX, recY)
        

### plot

In [None]:
def load_scores(ps_csv):
    """Collect per-unit score CSVs into one long-format DataFrame.

    Metadata is read from the folder names of each path, counted from the end:
    `<probes>/<recording>/<...>/<settings>/<interaction>/<epoch>/<file>.csv`,
    where `<recording>` must have the form `<animal>_<date>`.

    Parameters
    ----------
    ps_csv : iterable of pathlib.Path
        CSV files with a 'unit' column; the score is taken from the second column.

    Returns
    -------
    pandas.DataFrame
        One row per unit and file with columns 'unit', 'score', 'epoch', 'interaction',
        'settings', 'recording', 'animal', 'date', 'probes', plus
        'interaction_' ('ALM1'/'ALM2' pooled into 'ALM') and
        'n_regions' (number of '_'-separated regions in 'interaction').
    """

    def pool_alm(interaction):
        # do not distinguish between the two ALM probes
        return interaction.replace('ALM1', 'ALM').replace('ALM2', 'ALM')

    frames = []
    for p_csv in ps_csv:

        # metadata is encoded in the folder structure (indices count from the end)
        parts = p_csv.parts
        probe, recor, setti, inter, epoch = (parts[i] for i in (-7, -6, -4, -3, -2))
        anima, date = recor.split('_')

        raw = pd.read_csv(p_csv)

        columns = {
            'unit':  raw['unit'],  # TODO change this for newer data
            'score': raw.iloc[:, 1],
        }
        # scalar metadata is broadcast to every row of this file
        columns.update(
            epoch=epoch,
            interaction=inter,
            settings=setti,
            recording=recor,
            animal=anima,
            date=date,
            probes=probe,
        )
        frames.append(pd.DataFrame(data=columns))

    scores = pd.concat(frames, ignore_index=True)
    scores['interaction_'] = scores['interaction'].map(pool_alm)
    scores['n_regions'] = scores['interaction'].apply(lambda inter: len(inter.split('_')))

    return scores

In [None]:


# Collect the scores of all parameter sets into one long DataFrame `df` and add
# `dscore`: score in an epoch minus the score of the same unit in the 'all' epoch.

# columns that identify one unit of one fit within a parameter set
unit_keys = ['probes', 'recording', 'interaction', 'unit']

dfs = []
for name in d_params.keys():

    print(f'INFO: Now processing {name}')
    p_plot = p_root / f'plots/{name}'
    p_plot.mkdir(exist_ok=True, parents=True)

    p_csvs = [ *p_root.glob(f'**/{name}/**/pred_ridge_scores.csv') ]
    if not p_csvs:
        print(f'No CSV files found: skipping {name}')
        continue

    df_set = load_scores(p_csvs)

    # Align on the unit keys instead of relying on row order: epochs that were skipped
    # for a recording (no data left) then yield NaN rather than a shape error or a
    # silently misaligned difference. A left merge keeps the row order of `df_set`.
    df_ref = (
        df_set
        .loc[df_set.loc[:, 'epoch'] == 'all', unit_keys + ['score']]
        .rename(columns={'score': 'score_all'})
    )
    df_set = df_set.merge(df_ref, on=unit_keys, how='left', validate='many_to_one')
    df_set.loc[:, 'dscore'] = df_set.loc[:, 'score'] - df_set.loc[:, 'score_all']
    df_set = df_set.drop(columns='score_all')

    dfs.append(df_set)

df = pd.concat(dfs, ignore_index=True)


In [None]:
# Scores per epoch for the selected parameter set(s)
sel = {
    'all_depth_0.2': '> .5ms',
}

# Restrict the plot to the selected parameter sets; without this mask the scores of
# all parameter sets would be pooled into the same boxes.
m = df.loc[:, 'settings'].isin(sel.keys())
df_sel = df.loc[m, :]
hue_order = list(epochs)

# box plot (median and quartiles only) ...
g = sns.catplot(df_sel,
            x='interaction', y='score', col='probes',  hue='epoch',
            kind='box', whis=0, fliersize=0, palette='pastel', hue_order=hue_order,
            sharex=False, facet_kws={'ylim': (-.25, .85)})
# ... with the individual units drawn on top
g.map(sns.stripplot, 'interaction', 'score', 'epoch',
      hue_order=hue_order, dodge=True, palette='deep',
      edgecolor='auto', linewidth=.5, size=1)

g.fig.savefig('epoch.png')

In [None]:
# Score difference relative to the 'all' epoch (`dscore`) for the selected parameter set(s)
sel = {
    'all_depth_0.2': '> .5ms',
}

# Restrict the plot to the selected parameter sets; without this mask the values of
# all parameter sets would be pooled into the same boxes.
m = df.loc[:, 'settings'].isin(sel.keys())
df_sel = df.loc[m, :]
hue_order = list(epochs)

# box plot (median and quartiles only) ...
g = sns.catplot(df_sel,
            x='interaction', y='dscore', col='probes',  hue='epoch',
            kind='box', whis=0, fliersize=0, palette='pastel', hue_order=hue_order,
            sharex=False, facet_kws={'ylim': (-1, 1)})
# ... with the individual units drawn on top
g.map(sns.stripplot, 'interaction', 'dscore', 'epoch',
      hue_order=hue_order, dodge=True, palette='deep',
      edgecolor='auto', linewidth=.5, size=1)

g.fig.savefig('d_epoch.png')

In [None]:
# Compare parameter sets that differ in the spike width filter.
# `sel` maps parameter-set name -> legend entry; only these sets are plotted.
sel = {
    'all_depth_0.2': '> .5ms',
    'all_depth_0.2_sw': '< .5 ms'
}
label = 'spike width'

# legend entry for every row (NaN for sets that are not in `sel`), and mask of the selected sets
df.loc[:, label] = df.loc[:, 'settings'].map(sel)
m = df.loc[:, 'settings'].isin(sel.keys())

# box layer: median and quartiles only, no whiskers or fliers
box_kws = dict(kind='box', whis=0, fliersize=0, palette='pastel', dodge=True)
g = sns.catplot(
    df.loc[m, :],
    x='interaction', y='score', col='probes', hue=label, hue_order=sel.values(),
    sharex=False, facet_kws={'ylim': (-.25, .85)},
    **box_kws,
)

# strip layer: the individual units on top of the boxes
strip_kws = dict(dodge=True, palette='deep', edgecolor='auto', linewidth=.5, size=2)
g.map(sns.stripplot, 'interaction', 'score', label, hue_order=sel.values(), **strip_kws)

g.fig.savefig('sw.png')

In [None]:
# Compare fits with and without baseline subtraction.
# `sel` maps parameter-set name -> legend entry; only these sets are plotted.
sel = {
    'all_depth_0.2': 'True',
    'all_depth_raw_0.2': 'False'
}
label = 'baseline subtracted'

# legend entry for every row (NaN for sets that are not in `sel`), and mask of the selected sets
df.loc[:, label] = df.loc[:, 'settings'].map(sel)
m = df.loc[:, 'settings'].isin(sel.keys())

# hue levels in order of first appearance in the selected rows
hue_levels = df.loc[m, label].unique()

# box layer: median and quartiles only, no whiskers or fliers
box_kws = dict(kind='box', whis=0, fliersize=0, palette='pastel', dodge=True)
g = sns.catplot(
    df.loc[m, :],
    x='interaction', y='score', col='probes', hue=label, hue_order=hue_levels,
    sharex=False, facet_kws={'ylim': (-.25, .85)},
    **box_kws,
)

# strip layer: the individual units on top of the boxes
strip_kws = dict(dodge=True, palette='deep', edgecolor='auto', linewidth=.5, size=2)
g.map(sns.stripplot, 'interaction', 'score', label, hue_order=hue_levels, **strip_kws)

g.fig.savefig('baseline.png')

In [None]:
# Compare parameter sets that differ in the upper limit of the first lick time.
# `sel` maps parameter-set name -> legend entry; only these sets are plotted.
sel = {
    'all_depth_0.2_lck0.6': '< 0.6 s',
    'all_depth_0.2_lck1.2': '< 1.2 s',
    'all_depth_0.2_lck1.8': '< 1.8 s',
    'all_depth_0.2': 'all'
}
label = 'lick time'

# legend entry for every row (NaN for sets that are not in `sel`), and mask of the selected sets
df.loc[:, label] = df.loc[:, 'settings'].map(sel)
m = df.loc[:, 'settings'].isin(sel.keys())

# box layer: median and quartiles only, no whiskers or fliers
box_kws = dict(kind='box', whis=0, fliersize=0, palette='pastel', dodge=True)
g = sns.catplot(
    df.loc[m, :],
    x='interaction', y='score', col='probes', hue=label, hue_order=sel.values(),
    sharex=False, facet_kws={'ylim': (-.25, .85)},
    **box_kws,
)

# strip layer: the individual units on top of the boxes
strip_kws = dict(dodge=True, palette='deep', edgecolor='auto', linewidth=.5, size=2)
g.map(sns.stripplot, 'interaction', 'score', label, hue_order=sel.values(), **strip_kws)

g.fig.savefig('lick_time.png')

In [None]:
# Score distributions, one panel per number of regions in the interaction
fig, axarr = plt.subplots(ncols=2, figsize=(12, 5))

# common bin edges for all panels
score_bins = np.linspace(-1, 1, 40)

for ax, (_, df_n) in zip(axarr, df.groupby('n_regions')):
    # each interaction is normalised on its own (common_norm=False)
    sns.histplot(
        data=df_n, x='score', hue='interaction_',
        bins=score_bins, stat='density', common_norm=False, kde=True,
        ax=ax,
    )
    ax.set_xlim(-.25, 1)

fig.savefig('scores_hist.svg')

In [None]:
# Scores per interaction, one facet per recording, coloured by bin size:
# box plot (no whiskers, no fliers) with the individual points drawn on top.
# NOTE(review): this cell uses a 'bin_size' column, but no visible cell adds such a column to `df`
#               (`load_scores` does not create it) - presumably a leftover from an older version
#               of the notebook; confirm before running.
g = sns.catplot(data=df, 
                x='interaction', y='score', col='recording', col_wrap=3, hue='bin_size',
                sharex=False, facet_kws={'ylim': (-1, 1)},
                kind='box', whis=0, fliersize=0, palette='pastel',
                dodge=True, height=4, aspect=1.2)
g.map(sns.stripplot, 'interaction', 'score', 'bin_size', dodge=True, palette='deep', edgecolor='auto', linewidth=.5)
# NOTE(review): saving is disabled; `p_plot` and `name` would refer to the last iteration of an earlier loop
# g.fig.savefig(p_plot / f'scores_{name}.png')

In [None]:
# `dscore` per interaction, one row of panels per probe pairing, coloured by epoch.
# Uses the default catplot kind; the box + strip overlay used in other cells is not applied here.
dscore_kws = dict(
    x='interaction', y='dscore', row='probes',
    hue='epoch', hue_order=epochs.keys(),
    palette='deep', dodge=True,
    sharex=False, facet_kws={'ylim': (-1, 1)},
)
g = sns.catplot(df, **dscore_kws)


In [None]:
# Grid of score densities: one row per probe pairing, one panel per interaction, one curve per epoch
fig, axmat = plt.subplots(nrows=3, ncols=4, figsize=(20, 15))

# draw the legend in the very first panel only
legend = True
for row_axes, (probe_pair, df_probe) in zip(axmat, df.groupby('probes')):
    print(probe_pair, len(row_axes))

    for ax, (inter, df_inter) in zip(row_axes, df_probe.groupby('interaction')):
        sns.kdeplot(data=df_inter, x='score', hue='epoch', legend=legend, ax=ax)
        legend = False

        ax.set_xlim(-1, 1)
        ax.set_title(f'{probe_pair}: {inter}')
