In [None]:
# Automatically reload edited modules (e.g. fit_glm_helpers) before executing
# code, so helper changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [1]:
import itertools
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import fit_glm_helpers as fgh

# Import Data
Input data must conform to the following conventions:

### Session ID
* Must be a unique string that identifies the session; it labels the combined signal and trial (behavior) data of that recording

### Signal File
* Each file should represent a single recording with a constant sampling rate
* Each row in the file should represent a single timestamp within a given recording (must be in chronological order)

Example:
| id_timestamp | identifier_1 | predictor_1 | predictor_2 | response_1 | response_2 |
| --- | --- | --- | --- | --- | --- |
| 0 | 0 | 0 | 0 | 1 | 0.3 |
| 1 | 0 | 0 | 0 | 0 | 1.4 |
| 2 | 0 | 0 | 0 | 1 | 2.3 |
| 3 | 0 | 1 | 0 | 1 | 0.3 |
| 4 | 0 | 0 | 0 | 0 | 1.4 |
| 5 | 0 | 0 | 0 | 1 | 2.3 |
| 6 | 1 | 0 | 0 | 0 | 1.4 |
| 7 | 0 | 0 | 0 | 1 | 2.3 |
| 8 | 0 | 0 | 0 | 0 | 1.4 |
| 9 | 1 | 0 | 0 | 1 | 2.3 |
| 10 | 0 | 0 | 0 | 0 | 1.4 |
| 11 | 0 | 0 | 0 | 1 | 2.3 |
| 12 | 0 | 0 | 0 | 0 | 1.4 |
| 13 | 0 | 0 | 0 | 1 | 2.3 |
| 14 | 0 | 0 | 1 | 0 | 1.4 |
| 15 | 0 | 0 | 0 | 0 | 2.3 |
| 16 | 0 | 0 | 0 | 0 | 1.4 |
| 17 | 0 | 0 | 0 | 0 | 2.3 |
| 18 | 0 | 1 | 0 | 0 | 1.4 |
| 19 | 0 | 0 | 0 | 1 | 2.3 |

### Trial File (Optional)
* Each file should correspond to a single recording (the same recording, with a constant sampling rate, as its paired signal file)
* Each row must represent a unique trial in chronological order, but does not need to start from zero

Columns (Alignment Columns + Information Columns):
* Alignment Columns: Contains the indices (in the signal table) at which the event in question occurs
* Information Columns: Contains information associated with the given trial

Example:
| id_trial | centerInIndex (Alignment) | centerOutIndex (Alignment) | sideInIndex (Alignment) | sideOutIndex (Alignment) | hasAllData (Information) |
| --- | --- | --- | --- | --- | --- |
| trial_0 | 12 | 13 | 17 | 18 | 0 |
| trial_1 | 20 | 22 | 23 | 25 | 1 |
| trial_2 | 27 | 28 | 31 | 35 | 1 |
| trial_3 | 50 | 54 | 57 | 60 | 1 |

In [None]:
# Root folders. The defaults are the original author's local paths; set the
# GLM_DIR_DATA / GLM_DIR_OUTPUT environment variables to run the notebook on
# another machine without editing this cell.
dir_data = Path(os.environ.get(
    'GLM_DIR_DATA',
    '/Users/josh/Documents/Harvard/GLM/sabatinilab-glm/data/old-data-version/raw-new/Figure_1_2',
))
dir_output = Path(os.environ.get(
    'GLM_DIR_OUTPUT',
    '/Users/josh/Desktop/example_output_folder',
))

# One dict per recording session (file conventions are described in "Import Data" above).
#   session_id                     : unique session identifier
#   filepath_signal / filepath_trial : paths to the signal table and the trial table
#   bool_trialTable_matlab_indexed : whether trial-table alignment indices are Matlab (1-based) indexed
#   columnName_trialTable_trialId  : trial-id column name (None presumably means "none" — confirm)
#   columnRenames_signal / _trial  : presumably {original name: new name} column renames, or None
lst_dict_inputdata = [
    {'session_id': 'WT63_11082021',
     'filepath_signal': dir_data / 'GLM_SIGNALS_WT63_11082021.txt',
     'filepath_trial': dir_data / 'GLM_TABLE_WT63_11082021.txt',
     'bool_trialTable_matlab_indexed': True,
     'columnName_trialTable_trialId': None,
     'columnRenames_signal': {'Ch1': 'gDA', 'Ch5': 'gACH'},
     'columnRenames_trial': None},
]

# Session IDs must be unique (see "Session ID" above).
lst_sessionIds = [d['session_id'] for d in lst_dict_inputdata]
assert len(lst_sessionIds) == len(set(lst_sessionIds)), f'Duplicate session_id in {lst_sessionIds}'

# Fail fast on missing input files instead of deep inside the loading loop.
lst_missingFiles = [
    fp
    for d in lst_dict_inputdata
    for fp in (d['filepath_signal'], d['filepath_trial'])
    if fp is not None and not Path(fp).is_file()
]
if lst_missingFiles:
    raise FileNotFoundError(f'Missing input files: {lst_missingFiles}')

dir_output.mkdir(parents=True, exist_ok=True)

In [None]:
# Alignment columns that presumably mark the start and end of each trial window
# (inferred from the names — confirm against fit_glm_helpers).
columnName_alignment_trial_start = 'photometryCenterInIndex'
columnName_alignment_trial_end = 'photometrySideOutIndex'

# Alignment columns contain the indices (in the signal table) at which each event occurs.
# A value of 0 in a Matlab-indexed trial table, or -1 in a Python-indexed trial table,
# is treated as a "no-data" value. Matlab-indexed trial tables should therefore only
# contain values >= 0, and Python-indexed trial tables only values >= -1.
# NOTE(review): neither this list nor lst_strColumns_information is consumed by any
# cell in this notebook — confirm where fit_glm_helpers is meant to receive them.
lst_strColumns_alignment = [
    'photometryCenterInIndex',
    'photometryCenterOutIndex',
    'photometrySideInIndex',
    'photometrySideOutIndex',
]

# Information columns contain per-trial information (not signal indices).
lst_strColumns_information = [
    'nTrial_raw', 'hasAllPhotometryData',
    'wasRewarded', 'word',
]

# Sanity checks: trial start/end must be alignment columns, no column is listed twice,
# and no column is both an alignment column and an information column.
assert columnName_alignment_trial_start in lst_strColumns_alignment, \
    f'{columnName_alignment_trial_start!r} is not an alignment column'
assert columnName_alignment_trial_end in lst_strColumns_alignment, \
    f'{columnName_alignment_trial_end!r} is not an alignment column'
assert len(set(lst_strColumns_alignment)) == len(lst_strColumns_alignment), 'duplicate alignment column'
assert not set(lst_strColumns_alignment) & set(lst_strColumns_information), \
    'a column is listed as both alignment and information'

In [None]:
# Presumably requests that trials with zero ("no-data") alignment values be dropped.
# NOTE(review): this flag is not passed to any fit_glm_helpers call in this notebook,
# so it currently has no effect — confirm how TrialPreprocessor should receive it.
bool_drop_zeroAlignments = True

# Fail early with a clear message instead of combining an empty aggregator.
assert lst_dict_inputdata, 'lst_dict_inputdata is empty: no sessions to load'

trialSignalAligned_agg = fgh.TrialSignalAlignerAggregator()

# Per-session objects keyed by session_id, so every session (not just the last one)
# stays inspectable after the loop. `trial`, `signal` and `trialSignalAligned` are
# still left bound to the last session, as later cells expect.
dict_sessionObjects = {}

for dict_inputdata in lst_dict_inputdata:
    session_id = dict_inputdata['session_id']
    if session_id in dict_sessionObjects:
        raise ValueError(f'Duplicate session_id: {session_id!r}')

    # Load data, applying the configured per-session column renames
    # (None means "no renames"; renaming with an empty mapping is a no-op).
    # NOTE(review): 'bool_trialTable_matlab_indexed' and 'columnName_trialTable_trialId'
    # are still not consumed here, and although the trial file is documented as
    # optional, a trial table is always read.
    df_trial_raw = pd.read_csv(dict_inputdata['filepath_trial']).rename(
        columns=dict_inputdata.get('columnRenames_trial') or {})
    df_signal_raw = pd.read_csv(dict_inputdata['filepath_signal']).rename(
        columns=dict_inputdata.get('columnRenames_signal') or {})
    trial = fgh.TrialPreprocessor(df_trial_raw)
    signal = fgh.SignalPreprocessor(df_signal_raw)

    # Preprocess the trial and signal tables
    trial.preprocess()
    signal.preprocess()

    # Trial / signal alignment
    trialSignalAligned = fgh.TrialSignalAligner(trial, signal)
    trialSignalAligned.align()
    trialSignalAligned.trialstamp()
    trialSignalAligned.timestamp()

    # Aggregate across sessions and keep this session's objects
    trialSignalAligned_agg.add(trialSignalAligned)
    dict_sessionObjects[session_id] = {
        'trial': trial,
        'signal': signal,
        'trialSignalAligned': trialSignalAligned,
    }

trialSignalAligned_agg.combine();

In [None]:
# Predictor columns used to build X, and the response column used for y.
# NOTE(review): these are the placeholder names from the example schema, not the
# renamed signal columns ('gDA', 'gACH'), and 'y' is not a documented column
# (the schema uses 'response_1', ...) — confirm the intended names.
predictors = ['predictor_1', 'predictor_2']
response = 'y'
# Value passed as `shift_amt` when time-shifting the predictor columns.
int_shiftAmt = 1

# Generate predictor dataframe X and response dataframe y
trialSignalAligned_agg.generate_Xy(predictors, response)

# Unroll the predictor columns into one-hot representations. Each helper gets a fresh
# copy of `predictors`, so a helper that mutates its argument cannot alter the shared list.
trialSignalAligned_agg.unroll_X_columns(list(predictors))

# Time-shift the predictor columns
trialSignalAligned_agg.timeshift_X_columns(list(predictors), shift_amt=int_shiftAmt)

# Split train/validation/test sets
trialSignalAligned_agg.split_train_validation_test()

# Fit the GLM, then generate and plot its summary
glm = fgh.GLM(trialSignalAligned_agg)
glm.fit_GLM()
glm.generate_GLM_summary()
glm.plot_GLM_summary()

# Generate predictions for train/validation/test sets and evaluate them
glm.generate_predictions()
glm.evaluate_predictions()
glm.generate_prediction_plots()

# NOTE(review): `trial`, `signal` and `trialSignalAligned` are loop variables left over
# from the loading cell, so with several sessions only the LAST session's
# preprocessing/alignment info is written below.

# Save preprocessing parameters
trial.save_preprocessing_info(dir_output / 'trial_preprocessing_info.json')
signal.save_preprocessing_info(dir_output / 'signal_preprocessing_info.json')

# Save alignment parameters
trialSignalAligned.save_alignment_info(dir_output / 'alignment_info.json')

# Save aggregation parameters
trialSignalAligned_agg.save_aggregation_info(dir_output / 'aggregation_info.json')

# Save GLM parameters
glm.save_GLM_info(dir_output / 'glm_info.json');