In [1]:
import glob
import numpy as np
import pandas as pd
import numpy as np
from sglm.models import sglm
from sglm.features import gen_signal_df as gsd
from sglm.features import build_features as bf

In [2]:
# Load Signal Data
# Every WT61 signal file under ../data/raw; the matching table file is assumed
# to share the same name with GLM_SIGNALS swapped for GLM_TABLE.
signal_files = glob.glob('../data/raw/GLM_SIGNALS_WT61_*')
table_files = [signal_path.replace('GLM_SIGNALS', 'GLM_TABLE') for signal_path in signal_files]

files_list = signal_files

# Channel -> sensor name, keyed by a tuple of file-name tags.
channel_definitions = {
    ('WT61',): {'Ch1': 'gACH', 'Ch2': 'rDA'},
    ('WT64',): {'Ch1': 'gACH', 'Ch2': 'empty'},
    ('WT63',): {'Ch1': 'gDA', 'Ch2': 'empty'},
}

channel_assignments = bf.get_rename_columns_by_file(files_list, channel_definitions)

# Load Table Data
# Only the first globbed session is used in this notebook.
# NOTE(review): glob order is filesystem dependent, so "first" is not a stable choice.
signal_file, table_file = signal_files[0], table_files[0]

signal_df = pd.read_csv(signal_file)
table_df = pd.read_csv(table_file)

('WT61',)
> GLM_SIGNALS_WT61_10152021.txt
> GLM_SIGNALS_WT61_10042021.txt
> GLM_SIGNALS_WT61_10062021.txt
> GLM_SIGNALS_WT61_10132021.txt
> GLM_SIGNALS_WT61_10082021.txt
> GLM_SIGNALS_WT61_10182021.txt
> GLM_SIGNALS_WT61_10112021.txt
('WT64',)
('WT63',)


In [4]:
# Break down Preprocess Lynne into component parts


In [5]:
# Generate AB Labels
from sglm.features import table_file as tbf

def set_first_prv_trial_letter(prv_wasRewarded_series, label_series, loc=0):
    """Write the previous-trial outcome letter into each label string.

    The character at position ``loc`` becomes ``'A'`` where
    ``prv_wasRewarded_series`` is True and ``'a'`` where it is False.

    Args:
        prv_wasRewarded_series: boolean Series aligned with ``label_series``.
        label_series: Series of label strings; it is not modified.
        loc: index of the character to overwrite.

    Returns:
        A new Series of updated label strings.
    """
    updated = label_series.copy()
    stop = loc + 1
    letter_by_mask = (
        (prv_wasRewarded_series, 'A'),
        (~prv_wasRewarded_series, 'a'),
    )
    for mask, letter in letter_by_mask:
        updated.loc[mask] = updated.loc[mask].str.slice_replace(loc, stop, letter)
    return updated

def set_current_trial_letter_switch(sameSide_series, wasRewarded_series, label_series, loc=1):
    """Write the current-trial stay/switch letter into each label string.

    The character at position ``loc`` is set to:
      * ``'A'`` - same side, rewarded
      * ``'B'`` - different side, rewarded
      * ``'a'`` - same side, not rewarded
      * ``'b'`` - different side, not rewarded

    Args:
        sameSide_series: boolean Series, True where the side matches the previous trial.
        wasRewarded_series: boolean Series, True where the trial was rewarded.
        label_series: Series of label strings; it is not modified.
        loc: index of the character to overwrite.

    Returns:
        A new Series of updated label strings.
    """
    updated = label_series.copy()
    stop = loc + 1
    cases = (
        (sameSide_series & wasRewarded_series, 'A'),
        (~sameSide_series & wasRewarded_series, 'B'),
        (sameSide_series & ~wasRewarded_series, 'a'),
        (~sameSide_series & ~wasRewarded_series, 'b'),
    )
    for mask, letter in cases:
        updated.loc[mask] = updated.loc[mask].str.slice_replace(loc, stop, letter)

    return updated

def set_current_trial_letter_side(choseRight, wasRewarded_series, label_series, loc=2):
    """Write a side/outcome letter into each label string.

    The character at position ``loc`` is set to:
      * ``'R'`` - chose right, rewarded
      * ``'L'`` - did not choose right, rewarded
      * ``'r'`` - chose right, not rewarded
      * ``'l'`` - did not choose right, not rewarded

    Args:
        choseRight: Series flagging right-side choices (combined with ``~`` and
            ``&``, so boolean values are expected).
        wasRewarded_series: boolean Series, True where the trial was rewarded.
        label_series: Series of label strings; it is not modified.
        loc: index of the character to overwrite.

    Returns:
        A new Series of updated label strings.
    """
    updated = label_series.copy()
    stop = loc + 1
    cases = (
        (choseRight & wasRewarded_series, 'R'),
        (~choseRight & wasRewarded_series, 'L'),
        (choseRight & ~wasRewarded_series, 'r'),
        (~choseRight & ~wasRewarded_series, 'l'),
    )
    for mask, letter in cases:
        updated.loc[mask] = updated.loc[mask].str.slice_replace(loc, stop, letter)

    return updated

def check_Ab_labels(df_t):
    """Verify that each two-letter ``label`` agrees with the boolean trial columns.

    Both the label and the flag columns are turned into the same integer code
    and compared row by row:

      * first letter:  ``'a'`` -> 0, ``'A'`` -> 1   (previous trial rewarded)
      * second letter: ``'b'`` -> 0, ``'B'`` -> 2, ``'a'`` -> 4, ``'A'`` -> 6
        (current trial rewarded = 2, same side = 4)
      * expected code: ``pwR + 2 * wR + 4 * sS``

    Side effect: integer helper columns ``pwR``, ``wR`` and ``sS`` are added to
    ``df_t`` in place.

    Args:
        df_t: DataFrame with columns ``prv_wasRewarded``, ``wasRewarded``,
            ``sameSide`` and ``label``.

    Raises:
        AssertionError: if any label disagrees with the flags, or contains a
            letter outside the scheme above.
    """
    df_t['pwR'] = df_t['prv_wasRewarded'].astype(int)
    df_t['wR'] = df_t['wasRewarded'].astype(int)
    df_t['sS'] = df_t['sameSide'].astype(int)

    first_letter_code = {'a': 0, 'A': 1}
    second_letter_code = {'b': 0, 'B': 2, 'a': 4, 'A': 6}

    # .map avoids the str -> int dtype juggling of chained .replace calls;
    # unknown letters become NaN and are reported as mismatches below.
    decoded = (df_t['label'].str.slice(0, 1).map(first_letter_code)
               + df_t['label'].str.slice(1, 2).map(second_letter_code))
    expected = df_t['pwR'] + df_t['wR'] * 2 + df_t['sS'] * 4

    mismatched = decoded != expected
    if mismatched.any():
        raise AssertionError(
            f'{int(mismatched.sum())} label(s) disagree with the trial flags; '
            f'first offending index: {mismatched[mismatched].index[0]!r}'
        )
    return

def generate_Ab_labels(df_t):
    """Build trial-history labels for every trial that has a preceding trial.

    Columns added to a copy of ``df_t``:

      * ``prv_wasRewarded`` / ``prv_choseLeft`` / ``prv_choseRight`` - the
        previous row's values, cast to bool.
      * ``sameSide`` - True where both choice columns match the previous row.
      * ``label`` - two letters: previous outcome (``A``/``a``), then the
        current trial's stay/switch and outcome (``A``/``a``/``B``/``b``).
      * ``label_side`` - two letters: previous side and outcome, then current
        side and outcome (``R``/``L``/``r``/``l``).
      * ``pwR`` / ``wR`` / ``sS`` - added by ``check_Ab_labels``.

    Rows without a preceding trial (the first row) get NaN labels and are
    removed by the final ``dropna``. Previously ``shift(1).astype(bool)`` turned
    that NaN into True, so the first trial was labelled as if a rewarded trial
    had come before it and the NaN masking never fired.

    Args:
        df_t: trial table with ``wasRewarded``, ``choseLeft`` and
            ``choseRight`` columns, one row per trial in chronological order.

    Returns:
        A new DataFrame with the columns above; rows containing NaN in any
        column are dropped.

    Raises:
        AssertionError: from ``check_Ab_labels`` if labels and flags disagree.
    """
    df_t = df_t.copy()

    df_t['wasRewarded'] = df_t['wasRewarded'].astype(bool)

    # Record which rows have no previous trial BEFORE casting to bool:
    # NaN.astype(bool) is True, which would otherwise hide them.
    no_prev_trial = df_t['wasRewarded'].shift(1).isna()

    df_t['prv_wasRewarded'] = df_t['wasRewarded'].shift(1).astype(bool)
    # NOTE(review): NaN inside choseLeft/choseRight is also coerced to True by
    # these casts; only the missing-first-row case is handled here.
    df_t['prv_choseLeft'] = df_t['choseLeft'].shift(1).astype(bool)
    df_t['prv_choseRight'] = df_t['choseRight'].shift(1).astype(bool)

    df_t['sameSide'] = ((df_t['choseLeft'] == df_t['prv_choseLeft'])
                        & (df_t['choseRight'] == df_t['prv_choseRight'])).astype(bool)

    df_t['label'] = '  '
    df_t['label'] = set_first_prv_trial_letter(df_t['prv_wasRewarded'], df_t['label'], loc=0)
    df_t['label'] = set_current_trial_letter_switch(df_t['sameSide'], df_t['wasRewarded'], df_t['label'], loc=1)

    df_t['label_side'] = '  '
    df_t['label_side'] = set_current_trial_letter_side(df_t['prv_choseRight'], df_t['prv_wasRewarded'], df_t['label_side'], loc=0)
    df_t['label_side'] = set_current_trial_letter_side(df_t['choseRight'], df_t['wasRewarded'], df_t['label_side'], loc=1)

    # Blank out both labels where there is no history, then drop those rows.
    df_t.loc[no_prev_trial, ['label', 'label_side']] = np.nan
    # .copy() so check_Ab_labels can add its helper columns to an independent frame.
    df_t = df_t.dropna().copy()

    check_Ab_labels(df_t)

    return df_t

# One indicator column per two-letter history label.
basis_Aa_cols = ['AA', 'Aa', 'aA', 'aa', 'AB', 'Ab', 'aB', 'ab']
df_t = generate_Ab_labels(table_df)
ab_dummies = pd.get_dummies(df_t['label'])

# Labels that never occur in this session still get an all-zero column,
# so downstream code can rely on all eight being present.
for basis_col in basis_Aa_cols:
    if basis_col in ab_dummies.columns:
        continue
    df_t[basis_col] = 0

# Labels that do occur take their indicator values from get_dummies.
df_t[ab_dummies.columns] = ab_dummies


In [6]:
def replace_missed_center_out_indexes(df_t, max_num_duplications=None, verbose=0):
    """Break up runs of consecutive trials that share a center-out index.

    On each pass, every row whose ``photometryCenterOutIndex`` equals the next
    row's value has it overwritten with that row's own
    ``photometryCenterInIndex``. Passes repeat until no positive center-out
    index is shared by more than one row.

    The loop also stops when a pass cannot make progress: no adjacent
    duplicates are left, or the replacement values equal the current ones.
    Previously those cases, and an empty table, looped forever.

    Args:
        df_t: trial table with ``photometryCenterInIndex``,
            ``photometryCenterOutIndex`` and ``hasAllPhotometryData`` columns.
            It is not modified.
        max_num_duplications: optional cap on passes. Falsy values (``None`` or
            ``0``) mean no cap. The cap is tested after a pass with
            ``i > max_num_duplications``, so up to ``max_num_duplications + 2``
            passes can run (kept for backward compatibility).
        verbose: if greater than 0, print the pass count and the remaining
            maximum number of rows per center-out index.

    Returns:
        A copy of ``df_t`` with the adjusted ``photometryCenterOutIndex``.
    """
    df_t = df_t.copy()

    out_col = 'photometryCenterOutIndex'
    in_col = 'photometryCenterInIndex'

    def _max_rows_per_center_out(frame):
        # Largest number of rows sharing one positive center-out index
        # (0 when there are no positive indices at all).
        counts = frame[frame[out_col] > 0].groupby(out_col)['hasAllPhotometryData'].count()
        return counts.max() if len(counts) else 0

    i = 0

    while _max_rows_per_center_out(df_t) > 1:
        duplicated_CO_inx = df_t[out_col] == df_t[out_col].shift(-1)
        if not duplicated_CO_inx.any():
            # Remaining duplicates are not adjacent; this rule cannot fix them.
            break

        replacement = df_t.loc[duplicated_CO_inx, in_col]
        current = df_t.loc[duplicated_CO_inx, out_col]
        if (replacement.to_numpy() == current.to_numpy()).all():
            # Overwriting would change nothing, so further passes are pointless.
            break

        df_t.loc[duplicated_CO_inx, out_col] = replacement

        if max_num_duplications and i > max_num_duplications:
            break
        i += 1

    if verbose > 0:
        # Recomputed so the number reflects the returned frame.
        print('# of iterations', i,'— Final max amount of duplicated Center Out Indices:', _max_rows_per_center_out(df_t))

    return df_t


In [7]:
def get_is_relevant_trial(hasAllData_srs, index_event_srs):
    """Flag trials that have data and a non-negative event index.

    Args:
        hasAllData_srs: Series that is positive where the trial has all its data.
        index_event_srs: Series of event indices for the same trials.

    Returns:
        Boolean Series, True where ``hasAllData_srs > 0`` and
        ``index_event_srs >= 0``.
    """
    has_data = hasAllData_srs > 0
    event_recorded = index_event_srs >= 0
    return has_data & event_recorded

In [8]:
def matlab_indexing_to_python(index_event_srs):
    """Shift indices down by one (1-based to 0-based).

    Args:
        index_event_srs: a Series, DataFrame or scalar of indices.

    Returns:
        The input with 1 subtracted from every value.
    """
    one_based_offset = 1
    return index_event_srs - one_based_offset

In [9]:
# Table columns that hold, for each trial, the index of a behavioural event.
# These same names are reused below as indicator columns in signal_df.
table_index_columns = ['photometryCenterInIndex',
                       'photometryCenterOutIndex',
                       'photometrySideInIndex',
                       'photometrySideOutIndex',
                       'photometryFirstLickIndex']

# Resolve shared center-out indices first, then subtract 1 from every index column.
# Order matters: replace_missed_center_out_indexes filters on index > 0, i.e. on
# the values before the shift.
# NOTE(review): this cell overwrites df_t, so re-running it subtracts 1 again;
# re-run the cell that builds df_t first.
df_t = replace_missed_center_out_indexes(df_t, verbose=1)
df_t[table_index_columns] = matlab_indexing_to_python(df_t[table_index_columns])

# of iterations 3 — Final max amount of duplicated Center Out Indices: 1


In [10]:
# review = df_t[((df_t['photometryCenterOutIndex'].shift(1) == df_t['photometryCenterOutIndex'])|(df_t['photometryCenterOutIndex'].shift(-1) == df_t['photometryCenterOutIndex']))&(df_t['hasAllPhotometryData'] > 0)]
# review
# review_2 = df_t[((df_t['photometryCenterOutIndex'].shift(1) == df_t['photometryCenterOutIndex'])|(df_t['photometryCenterOutIndex'].shift(-1) == df_t['photometryCenterOutIndex']))&(df_t['hasAllPhotometryData'] > 0)]
# review
# (df_t.groupby('photometryCenterOutIndex')['hasAllPhotometryData'].count() >= 2).sum()
# # tmp_a = df_t[((df_t['photometryCenterOutIndex'].shift(1) == df_t['photometryCenterOutIndex'])|(df_t['photometryCenterOutIndex'].shift(-1) == df_t['photometryCenterOutIndex']))&(df_t['hasAllPhotometryData'] > 0)]
# tmp_inx = df_t[df_t['photometryCenterOutIndex'] > -1].copy()
# num_inx_vals = tmp_inx.groupby('photometryCenterOutIndex')['hasAllPhotometryData'].transform(np.size)
# resp = tmp_inx[num_inx_vals > 1].copy()
# tmp_inx['cin_out_delta'] = tmp_inx['photometryCenterOutIndex'] - tmp_inx['photometryCenterInIndex']
# overwrite_inx = tmp_inx[tmp_inx['cin_out_delta'] != tmp_inx.groupby('photometryCenterOutIndex')['cin_out_delta'].transform(np.min)].index
# df_t.loc[overwrite_inx, 'photometryCenterOutIndex'] = df_t.loc[overwrite_inx, 'photometryCenterInIndex']
# df_t[((df_t['photometryCenterOutIndex'].shift(1) == df_t['photometryCenterOutIndex'])|(df_t['photometryCenterOutIndex'].shift(-1) == df_t['photometryCenterOutIndex']))&(df_t['hasAllPhotometryData'] > 0)]
# # tmp = df_t[(df_t['photometryCenterOutIndex'].shift(1) == df_t['photometryCenterOutIndex'])|(df_t['photometryCenterOutIndex'].shift(-1) == df_t['photometryCenterOutIndex'])]
# tmp = df_t[df_t['photometryCenterOutIndex'] > -1].copy()

# display(tmp)

# num_inx_vals = tmp.groupby('photometryCenterOutIndex')['hasAllPhotometryData'].count()
# print((num_inx_vals > 1).sum())
# num_inx_vals = tmp.groupby('photometryCenterOutIndex')['hasAllPhotometryData'].transform(np.size)
# reps = tmp[num_inx_vals > 1].copy()
# reps['cin_out_delta'] = reps['photometryCenterOutIndex'] - reps['photometryCenterInIndex']
# overwrite_inx = reps[reps['cin_out_delta'] != reps.groupby('photometryCenterOutIndex')['cin_out_delta'].transform(np.min)].index
# tmp.loc[overwrite_inx, 'photometryCenterOutIndex'] = tmp.loc[overwrite_inx, 'photometryCenterInIndex']

# num_inx_vals = tmp.groupby('photometryCenterOutIndex')['hasAllPhotometryData'].count()
# print((num_inx_vals > 1).sum())

# # display(tmp)

# tmp[(tmp['photometryCenterOutIndex'].shift(1) == tmp['photometryCenterOutIndex'])|(tmp['photometryCenterOutIndex'].shift(-1) == tmp['photometryCenterOutIndex'])]
# reps[reps['cin_out_delta'] != reps.groupby('photometryCenterOutIndex')['cin_out_delta'].transform(np.min)]
# # tmp = df_t[(df_t['photometryCenterOutIndex'].shift(1) == df_t['photometryCenterOutIndex'])|(df_t['photometryCenterOutIndex'].shift(-1) == df_t['photometryCenterOutIndex'])]
# tmp = tmp[tmp['photometryCenterOutIndex'] > -1]
# tmp
# # def check_srs_val_is_nan(srs):
    # return srs.isna().any()

# signal_df[col] = (df_t_tmp[df_t_tmp[col].isin(single_inx_vals)].set_index(col)['wasRewarded'] == df_t_tmp[df_t_tmp[col].isin(single_inx_vals)].set_index(col)['wasRewarded'])*1

# for col in table_index_columns:
#     print(((df_t.set_index(col)['wasRewarded'] == df_t.set_index(col)['wasRewarded'])*1).sum())
# # df_t_tmp.columns
# # df_t_tmp.set_index(col)['wasRewarded'] == df_t_tmp.set_index(col)['wasRewarded']
# len(np.unique(signal_df.columns)), len(signal_df.columns)
# (df_t_tmp.set_index(col)['wasRewarded'] == df_t_tmp.set_index(col)['wasRewarded'])*1
# # len(np.unique(df_t_tmp.set_index('photometryCenterOutIndex')['wasRewarded'].index)), len((df_t_tmp.set_index('photometryCenterOutIndex')['wasRewarded'].index))
# # df_t_tmp[(df_t_tmp['photometryCenterOutIndex'] == df_t_tmp['photometryCenterOutIndex'].shift(1))|(df_t_tmp['photometryCenterOutIndex'] == df_t_tmp['photometryCenterOutIndex'].shift(-1))]

In [11]:
# df_t[df_t['']]

In [12]:
def get_is_not_iti(df):
    '''
    Returns a boolean array that is True where the trial-start counter and the
    trial-end counter differ.
    Args:
        df: dataframe with 'nTrial' and 'nEndTrial' columns
    Returns:
        boolean Series, True where df['nTrial'] != df['nEndTrial']
    '''
    started_count = df['nTrial']
    ended_count = df['nEndTrial']
    return started_count != ended_count

def get_trial_start(center_in_srs):
    """Return 1 where ``center_in_srs`` equals 1 and is not missing, else 0.

    Args:
        center_in_srs: Series of center-in indicator values; may contain NaN.

    Returns:
        Integer Series of 0/1 flags aligned with the input.
    """
    is_present = ~center_in_srs.isna()
    is_event = center_in_srs == 1
    return (is_present & is_event) * 1
    
def get_trial_end(center_out_srs):
    """Return 1 where ``center_out_srs`` equals 1 and is not missing, else 0.

    Args:
        center_out_srs: Series of trial-end indicator values; may contain NaN.

    Returns:
        Integer Series of 0/1 flags aligned with the input.
    """
    is_present = ~center_out_srs.isna()
    is_event = center_out_srs == 1
    return (is_present & is_event) * 1

In [13]:
# Project per-trial table values onto the signal time series.
# For each event-index column, the trial table is re-indexed by that event's
# index, so assigning a column into signal_df aligns trial values with the
# signal rows whose index label equals the event index. All other rows get NaN.
for col in table_index_columns:
    df_t_tmp = df_t[get_is_relevant_trial(df_t['hasAllPhotometryData'], df_t[col])].copy()

    # Re-index once per event column instead of once per assignment.
    by_event = df_t_tmp.set_index(col)
    rewarded = by_event['wasRewarded']

    # Event indicator: 1 on rows where a relevant trial has this event.
    signal_df[col] = rewarded.notna() * 1
    # Rewarded / not-rewarded versions of the same event.
    signal_df[f'{col}r'] = rewarded
    signal_df[f'{col}nr'] = (1 - signal_df[f'{col}r'])

    # Trial-history (two-letter label) indicators are only attached to side-port events.
    if col in ['photometrySideInIndex', 'photometrySideOutIndex']:
        for basis in ['AA', 'Aa', 'aA', 'aa', 'AB', 'Ab', 'aB', 'ab']:
            signal_df[col+basis] = by_event[basis].fillna(0)

In [14]:
# Number of rows by which the running trial counters are shifted.
# nTrial is shifted earlier, so it increments this many rows before the
# center-in event; nEndTrial is shifted later, so it increments this many rows
# after the side-out event.
# NOTE(review): presumably a small pad around each trial; confirm the intended width.
TRIAL_EDGE_SHIFT_ROWS = 5

signal_df['nTrial'] = get_trial_start(signal_df['photometryCenterInIndex']).cumsum().shift(-TRIAL_EDGE_SHIFT_ROWS)
signal_df['nEndTrial'] = get_trial_end(signal_df['photometrySideOutIndex']).cumsum().shift(TRIAL_EDGE_SHIFT_ROWS)

# True where the start and end counters differ (see get_is_not_iti).
signal_df['wi_trial_keep'] = get_is_not_iti(signal_df)
# signal_df

In [15]:
# Keep only rows whose trial counter is positive; rows where nTrial is 0 or NaN
# fail the comparison and are dropped. Remaining NaNs in every column are then
# replaced with 0.
# NOTE(review): this overwrites signal_df, so the unfiltered frame is no longer
# available to later cells.
signal_df = signal_df[signal_df['nTrial'] > 0].fillna(0)
signal_df

Unnamed: 0,Ch1,Ch2,Ch5,Ch6,GP_1,GP_2,GP_5,GP_6,SGP_1,SGP_2,...,photometrySideOutIndexAB,photometrySideOutIndexAb,photometrySideOutIndexaB,photometrySideOutIndexab,photometryFirstLickIndex,photometryFirstLickIndexr,photometryFirstLickIndexnr,nTrial,nEndTrial,wi_trial_keep
1177,-0.931840,-0.137341,-0.462593,-0.870066,2.136375,-0.580275,2.021663,0.708039,2.136375,-0.580275,...,0.0,0.0,0.0,0.0,0.0,0,0,1.0,0.0,True
1178,-1.298430,1.125617,-1.120742,-1.043454,2.126542,1.545170,2.351971,0.042744,2.126542,1.545170,...,0.0,0.0,0.0,0.0,0.0,0,0,1.0,0.0,True
1179,-1.591052,-0.477382,-1.246694,-0.359070,1.753276,-0.619391,1.844064,1.785319,1.753276,-0.619391,...,0.0,0.0,0.0,0.0,0.0,0,0,1.0,0.0,True
1180,-1.568876,-0.078864,-1.860099,-0.562284,1.423486,0.563760,1.045771,1.468696,1.423486,0.563760,...,0.0,0.0,0.0,0.0,0.0,0,0,1.0,0.0,True
1181,-1.691941,0.006053,-0.465977,-1.762359,1.320113,-0.050453,-0.793027,0.944078,1.320113,-0.050453,...,0.0,0.0,0.0,0.0,0.0,0,0,1.0,0.0,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
42618,1.967829,-0.122234,3.051226,-0.705700,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.0,0.0,0.0,0.0,0.0,0,0,384.0,384.0,False
42619,1.217799,-0.023535,1.543815,-1.778901,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.0,0.0,0.0,0.0,0.0,0,0,384.0,384.0,False
42620,0.802294,0.660677,0.836298,0.488886,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.0,0.0,0.0,0.0,0.0,0,0,384.0,384.0,False
42621,0.001986,-0.505008,-0.273042,0.263355,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.0,0.0,0.0,0.0,0.0,0,0,384.0,384.0,False


In [16]:
# # signal_df['spnr'] = ((df['spnr'] == 1)&(df['photometrySideInIndex'] != 1)).astype(int)
# signal_df['spnnr'] = ((signal_df['spnnr'] == 1)&(signal_df['photometrySideInIndex'] != 1)).astype(int)
# # signal_df['spxr'] = ((df['spxr'] == 1)&(df['photometrySideOutIndex'] != 1)).astype(int)
# signal_df['spxnr'] = ((signal_df['spxnr'] == 1)&(signal_df['photometrySideOutIndex'] != 1)).astype(int)

In [17]:
# Rename Columns
## * File-Specific Columns
## * General Columns


In [18]:
# # Add Missing Y-Columns to File
# for y_col in y_col_lst_all:
#             if y_col not in df.columns:
#                 df[y_col] = np.nan
#                 continue
#             if 'SGP_' == y_col[:len('SGP_')]:
#                 df[y_col] = df[y_col].replace(0, np.nan)
#             if df[y_col].std() >= 90:
#                 df[y_col] /= 100

In [19]:
# Combine Signal & Table Files
## * Define trial starts / ends
## * Get is not ITI
## * Define rewarded/unrewarded trials (across entire trial)
## * Define side-agnostic events
## * Get first-time events
## * Generate AB Labels


In [20]:
# Detrend Data


In [21]:
# Timeshift value predictors


In [22]:
# Split Data


In [23]:
# Fit / cross-validate


In [24]:
# Plot Results

In [25]:
# Evaluate Results


In [26]:
# Save Results


In [6]:
import glob
from sglm.features import gen_signal_df as gsd
from sglm.features import build_features as bf

# Load Signal Data
signal_files = glob.glob(f'../data/raw/GLM_SIGNALS_WT61_*')
table_files = [_.replace('GLM_SIGNALS', 'GLM_TABLE') for _ in signal_files]

files_list = signal_files
channel_definitions = {
        ('WT61',): {'Ch1': 'gACH', 'Ch2': 'rDA'},
        ('WT64',): {'Ch1': 'gACH', 'Ch2': 'empty'},
        ('WT63',): {'Ch1': 'gDA', 'Ch2': 'empty'},
    }

channel_assignments = bf.get_rename_columns_by_file(files_list, channel_definitions)

# Load Table Data
signal_fn = signal_files[0]
table_fn = table_files[0]

signal_df = pd.read_csv(signal_fn)
table_df = pd.read_csv(table_fn)
signal_filename_out = signal_fn.split('/')[-1].replace('GLM_SIGNALS', 'GLM_SIGNALS_INTERIM').replace('txt', 'csv')
table_filename_out = signal_fn.split('/')[-1].replace('GLM_TABLE', 'GLM_TABLE_INTERIM').replace('txt', 'csv')

%load_ext autoreload
%autoreload 2
signal_df, table_df = gsd.generate_signal_df(signal_fn,
                                        table_fn,
                                        signal_filename_out=f'../data/interim/{signal_filename_out}',
                                        table_filename_out=f'../data/interim/{table_filename_out}'
                                        )

('WT61',)
> GLM_SIGNALS_WT61_10152021.txt
> GLM_SIGNALS_WT61_10042021.txt
> GLM_SIGNALS_WT61_10062021.txt
> GLM_SIGNALS_WT61_10132021.txt
> GLM_SIGNALS_WT61_10082021.txt
> GLM_SIGNALS_WT61_10182021.txt
> GLM_SIGNALS_WT61_10112021.txt
('WT64',)
('WT63',)
# of iterations 3 — Final max amount of duplicated Center Out Indices: 1


In [7]:
### Rename Columns

# signal_df, table_df = gsd.gen_signal_df(signal_fn,
#                                         table_fn,
#                                         signal_filename_out=f'../data/interim/{signal_filename_out}',
#                                         table_filename_out=f'../data/interim/{table_filename_out}'
#                                         )

In [16]:

# Break down Preprocess Lynne into component parts

# Rename Columns
df = bf.rename_consistent_columns(signal_df)
# Look up the channel renames with the same key that the guard tests.
# The lookup previously used `signal_file`, a name defined only in an earlier
# cell, so it could raise KeyError/NameError or apply another file's mapping.
if signal_fn in channel_assignments:
    df = df.rename(channel_assignments[signal_fn], axis=1)

## Set Reward Flags
# A trial is flagged 1.0 on all of its rows if any row in it carries the
# rewarded (resp. not-rewarded) side-in indicator.
# The string alias 'sum' replaces np.sum, which newer pandas deprecates as a
# groupby argument.
df['r_trial'] = (df.groupby('nTrial')['photometrySideInIndexr'].transform('sum') > 0) * 1.0
df['nr_trial'] = (df.groupby('nTrial')['photometrySideInIndexnr'].transform('sum') > 0) * 1.0

## Define Side Rewarded / Unrewarded Flags
df = bf.set_port_entry_exit_rewarded_unrewarded_indicators(df)

## Define Side Agnostic Events
df = bf.define_side_agnostic_events(df)

# print('Percent of Data in ITI:', (df['nTrial'] == df['nEndTrial']).mean())


In [17]:
# Show the columns of df after the renaming and flag-building steps above.
df.columns

Index(['Ch1', 'Ch2', 'Ch5', 'Ch6', 'GP_1', 'GP_2', 'GP_5', 'GP_6', 'SGP_1',
       'SGP_2', 'SGP_5', 'SGP_6', 'cpo', 'cpn', 'cpx', 'rpo', 'rpn', 'rpx',
       'rl', 'lpo', 'lpn', 'lpx', 'll', 'r', 'nr', 'photometryCenterInIndex',
       'photometryCenterInIndexr', 'photometryCenterInIndexnr',
       'photometryCenterOutIndex', 'photometryCenterOutIndexr',
       'photometryCenterOutIndexnr', 'photometrySideInIndex',
       'photometrySideInIndexr', 'photometrySideInIndexnr',
       'photometrySideInIndexAA', 'photometrySideInIndexAa',
       'photometrySideInIndexaA', 'photometrySideInIndexaa',
       'photometrySideInIndexAB', 'photometrySideInIndexAb',
       'photometrySideInIndexaB', 'photometrySideInIndexab',
       'photometrySideOutIndex', 'photometrySideOutIndexr',
       'photometrySideOutIndexnr', 'photometrySideOutIndexAA',
       'photometrySideOutIndexAa', 'photometrySideOutIndexaA',
       'photometrySideOutIndexaa', 'photometrySideOutIndexAB',
       'photometrySideOutIn