# Preprocessing and training of the LFADS model (Session S6)


This notebook demonstrates how the dataset was preprocessed for LFADS model training, using Session S6 as an example.

Although the saved .pkl.gz file already includes the final LFADS and SpikeCount data, we reconstruct the intermediate steps here for reproducibility and understanding.

**Note:**  
This notebook takes as input a pre-filtered dataset saved in a gzip-compressed pickle file. The dataset originates from the raw recordings available at [Zenodo](https://zenodo.org/records/13207505), but includes only completed **decision trials** (`trialType==20`), excluding other tasks or incomplete trials.

After generating the LFADS training dataset, we load the results of running AutoLFADS on NeuroCAAS and augment the DataFrame with two new columns:

- `'SpikeCount'`: Trial-aligned spike count matrix used for training (10 ms bins, scaled to firing rates in Hz).
- `'LFADS'`: Smoothed firing rates inferred by AutoLFADS for the same neurons and time windows.

Both columns contain lists of arrays, one per trial, each of shape `(NCells × TimeSteps)`. Time is cropped from dots onset to saccade completion, so trial durations vary (arrays are not padded with zeros).
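Because the per-trial arrays differ in length, they cannot be stacked into a single 3-D array directly. The sketch below uses made-up data in place of `df['LFADS']` (the trial lengths, neuron count, and rate values are all assumptions) to show one way of reducing each trial separately.

```python
import numpy as np

# Synthetic stand-in for df['LFADS']: 3 trials, 4 neurons, different durations.
rng = np.random.default_rng(0)
n_cells = 4
trial_lengths = [50, 72, 61]  # number of 10 ms bins per trial (made-up values)
rates = [rng.uniform(0, 40, size=(n_cells, n_bins)) for n_bins in trial_lengths]

# Trials are not zero-padded, so reduce each trial on its own, then stack.
mean_rate_per_trial = np.stack([r.mean(axis=1) for r in rates])  # (n_trials, n_cells)

# Trial durations in seconds follow from the bin count and the 10 ms bin width.
durations_s = np.array([r.shape[1] for r in rates]) * 0.010
```

Averaging over time gives one value per neuron and trial, which can then be compared across trials regardless of their duration.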

## 🧩 Step 1: Load preprocessed dataset

In [None]:

from src.io_utils import load_dataframe_with_metadata

session = "S6"
df = load_dataframe_with_metadata(session)

print(f"Session: {df.attrs['Session']}, Monkey: {df.attrs['Monkey']}, Date: {df.attrs['Date'].date()}")
print(f"Number of neurons: {df.attrs['NCells']}")
print(f"Columns: {df.columns.tolist()}")
df.head()


## 🛠️ Step 2: Construct training and validation datasets for AutoLFADS

We split the trials into training and validation sets (two thirds and one third of each group, respectively), maintaining a split that is balanced across choice and coherence.

In [None]:

import numpy as np
import h5py

coherences = np.sort(df['coh'].unique())
train_idx, valid_idx = [], []

for choice in [0, 1]:
    for coh in coherences:
        idx = df[(df['choice'] == choice) & (df['coh'] == coh)].index.tolist()
        np.random.shuffle(idx)
        split = int(2 / 3 * len(idx))
        train_idx.extend(idx[:split])
        valid_idx.extend(idx[split:])

train_idx, valid_idx = np.array(train_idx), np.array(valid_idx)
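
The split above depends on NumPy's global random state, so it changes from run to run. A minimal sketch of a seeded variant follows. The function name `stratified_split`, the fixed seed, and the synthetic labels are assumptions for illustration, not part of the original pipeline. It returns positional indices rather than DataFrame index labels.

```python
import numpy as np

def stratified_split(choice, coh, train_frac=2 / 3, seed=0):
    """Split trial positions into train/valid within each (choice, coherence) group."""
    rng = np.random.default_rng(seed)  # fixed seed: an assumption, not in the notebook
    choice, coh = np.asarray(choice), np.asarray(coh)
    train, valid = [], []
    for c in np.unique(choice):
        for k in np.unique(coh):
            idx = np.flatnonzero((choice == c) & (coh == k))
            rng.shuffle(idx)
            split = int(train_frac * len(idx))
            train.extend(idx[:split])
            valid.extend(idx[split:])
    return np.array(train), np.array(valid)

# Synthetic labels: 2 choices x 3 coherences x 6 trials each (36 trials).
choice = np.repeat([0, 1], 18)
coh = np.tile(np.repeat([0.0, 0.128, 0.512], 6), 2)
train_pos, valid_pos = stratified_split(choice, coh)
```

Calling the function twice with the same seed yields the same partition, which makes the saved `IndT` and `IndV` reproducible.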


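### Binning spikes into 10 ms bins and creating inputs

Each external input is a unit step function that equals 1 while the dots are on (from `dotsOn` to `dotsOff`, within the trial) and 0 otherwise. There is one input channel per choice, and only the channel matching the trial's choice is active.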

In [None]:

bin_size_ms = 10
bin_size_s = bin_size_ms / 1000
max_duration = np.max(df['saccadeComplete'] - df['dotsOn'])
bins = np.arange(0, max_duration + bin_size_s, bin_size_s)
times = bins[:-1] + bin_size_s / 2

def process_trial(trial, df, times, bins, N):
    """Return binned spike counts (time x N) and external inputs (time x 2) for one trial."""
    t0, t1 = df['dotsOn'].loc[trial], df['saccadeComplete'].loc[trial]
    duration = t1 - t0
    # Bin each neuron's spike times relative to dots onset.
    spikes = [np.histogram(np.array(sp) - t0, bins=bins)[0] for sp in df['spCellPop'].loc[trial][:N]]
    spike_array = np.stack(spikes, axis=-1)
    # Step input: 1 while the dots are on and the trial is still running, 0 elsewhere.
    valid = times < duration
    dots_on = valid & (times > 0) & (times < df['dotsOff'].loc[trial] - t0)
    # One channel per choice; only the channel matching this trial's choice is active.
    left = dots_on & (df['choice'].loc[trial] == 0)
    right = dots_on & (df['choice'].loc[trial] == 1)
    inputs = np.stack([left, right], axis=-1).astype(float)
    return spike_array, inputs

NCells = df.attrs['NCells']
DataT, InputsT = zip(*[process_trial(trial, df, times, bins, NCells) for trial in train_idx])
DataV, InputsV = zip(*[process_trial(trial, df, times, bins, NCells) for trial in valid_idx])

DataT, InputsT = np.stack(DataT), np.stack(InputsT)
DataV, InputsV = np.stack(DataV), np.stack(InputsV)
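
To make the binning and input construction concrete, here is a toy trial with a single neuron and five 10 ms bins. All spike times and event times are invented, and times are already expressed relative to dots onset.

```python
import numpy as np

bin_size_s = 0.010
bins = np.arange(6) * bin_size_s        # bin edges: 0, 10, ..., 50 ms
times = bins[:-1] + bin_size_s / 2      # bin centres: 5, 15, ..., 45 ms

# Made-up spike times (s) relative to dots onset.
spike_times = np.array([0.003, 0.004, 0.017, 0.031, 0.049])
counts = np.histogram(spike_times, bins=bins)[0]

# Made-up trial events (s) relative to dots onset.
dots_off = 0.032
duration = 0.041   # saccade completion
choice = 1         # trial's choice; activates the second input channel

valid = times < duration
dots_on = valid & (times > 0) & (times < dots_off)
inputs = np.stack([dots_on & (choice == 0), dots_on & (choice == 1)], axis=-1).astype(float)

print(counts)        # spikes per bin
print(inputs[:, 1])  # step input on the channel for choice 1
```

The first bin holds two spikes, the third holds none, and the step input is 1 only for the three bins whose centres fall before `dots_off`.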


## 💾 Step 3: Save HDF5 file for AutoLFADS (NeuroCAAS compatible)

In [None]:

filename = f"DataLIP_10ms_{session}.h5"
with h5py.File(filename, 'w') as hf:
    hf.create_dataset('train_encod_data', data=DataT)
    hf.create_dataset('valid_encod_data', data=DataV)
    hf.create_dataset('train_recon_data', data=DataT)
    hf.create_dataset('valid_recon_data', data=DataV)
    hf.create_dataset('train_ext_input', data=InputsT)
    hf.create_dataset('valid_ext_input', data=InputsV)
    hf.create_dataset('IndT', data=train_idx)
    hf.create_dataset('IndV', data=valid_idx)

print(f"Saved LFADS dataset to {filename}")
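
Before uploading, it can help to confirm that the arrays are mutually consistent. The helper below is a hypothetical sanity check, not part of the original notebook or of AutoLFADS. It assumes the `(trials, time, neurons)` layout produced above and only uses NumPy, with synthetic arrays standing in for `DataT`, `DataV`, `InputsT`, and `InputsV`.

```python
import numpy as np

def check_lfads_arrays(arrays):
    """Return a list of problems found in the arrays destined for the HDF5 file."""
    problems = []
    for split in ("train", "valid"):
        encod = arrays[f"{split}_encod_data"]
        recon = arrays[f"{split}_recon_data"]
        ext = arrays[f"{split}_ext_input"]
        if encod.ndim != 3:
            problems.append(f"{split}_encod_data is not 3-D (trials, time, neurons)")
        if encod.shape != recon.shape:
            problems.append(f"{split} encod/recon shapes differ")
        if ext.shape[:2] != encod.shape[:2]:
            problems.append(f"{split}_ext_input does not match trials x time")
    if arrays["train_encod_data"].shape[1:] != arrays["valid_encod_data"].shape[1:]:
        problems.append("train and valid differ in time bins or neurons")
    return problems

def make_arrays(n_train=8, n_valid=4, n_time=20, n_cells=5):
    """Build synthetic, all-zero arrays with the expected layout."""
    out = {}
    for split, n in (("train", n_train), ("valid", n_valid)):
        counts = np.zeros((n, n_time, n_cells), dtype=int)
        out[f"{split}_encod_data"] = counts
        out[f"{split}_recon_data"] = counts
        out[f"{split}_ext_input"] = np.zeros((n, n_time, 2))
    return out

good = make_arrays()
bad = make_arrays()
bad["valid_ext_input"] = np.zeros((4, 19, 2))  # one time bin short

print(check_lfads_arrays(good))
print(check_lfads_arrays(bad))
```

In the notebook, the same check could be run on a dictionary built from `DataT`, `DataV`, `InputsT`, and `InputsV` before they are written.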


## 🧪 Step 4: Running AutoLFADS on NeuroCAAS

We use the preprocessed dataset `DataLIP_10ms_S6.h5` together with the config file `S6_AutoLFADS.yaml`.
This YAML file specifies all the training parameters, architecture details, and hyperparameter search settings.

NeuroCAAS is an online platform for reproducible, cloud-based neuroscience analyses.
To train the model, upload both the HDF5 data file and the `S6_AutoLFADS.yaml` configuration to [NeuroCAAS](https://neurocaas.org/analysis/20).
Running the training there enables reproducible, scalable training of LFADS models across sessions.

📖 How AutoLFADS works:
AutoLFADS (Keshtkaran et al., 2022) is a deep learning framework based on LFADS (Pandarinath et al., 2018) that uses population-based training (PBT) to automatically tune model hyperparameters.
It infers latent dynamical structure from neural spike trains using recurrent networks and variational inference.
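
The exploit-and-explore idea behind population-based training can be illustrated on a toy problem. The sketch below is not AutoLFADS and trains no network: each "worker" is reduced to a single hyperparameter value, and the objective, target value, population size, and perturbation factors are all invented for illustration.

```python
import random

def toy_pbt(n_workers=8, n_generations=30, seed=0):
    """Toy exploit/explore loop; returns the final population and best score per generation."""
    rng = random.Random(seed)
    target = 0.3  # made-up optimum of the toy objective

    def score(h):
        # Stand-in for validation performance: higher is better.
        return -(h - target) ** 2

    # Each worker is represented by one hyperparameter value.
    population = [rng.uniform(0.0, 1.0) for _ in range(n_workers)]
    history = [max(score(h) for h in population)]
    n_cut = max(1, n_workers // 4)

    for _ in range(n_generations):
        ranked = sorted(population, key=score, reverse=True)
        survivors = ranked[: n_workers - n_cut]
        # Exploit: the worst workers are replaced by copies of the best ones.
        # Explore: each copy is perturbed by a random multiplicative factor.
        children = [h * rng.choice([0.8, 1.2]) for h in ranked[:n_cut]]
        population = survivors + children
        history.append(max(score(h) for h in population))
    return population, history

population, history = toy_pbt()
```

Because the best worker always survives, the best score in `history` never decreases. In a real PBT run the workers also carry model weights that are copied along with the hyperparameters.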

**References:**  
- AutoLFADS paper: [Keshtkaran et al., 2022](https://doi.org/10.1038/s41592-022-01675-0)  
- Original LFADS paper: [Pandarinath et al., 2018](https://doi.org/10.1038/s41592-018-0109-9)


## 📥 Step 5: Load LFADS results and populate DataFrame

In [None]:

lfads_path = f'./autolfads{session}/lfads_output_{session}.h5'

with h5py.File(lfads_path, 'r') as data:
    IndT, IndV = data['IndT'][:], data['IndV'][:]
    lfads_train, lfads_valid = data['train_output_params'][:], data['valid_output_params'][:]
    data_train, data_valid = data['train_encod_data'][:], data['valid_encod_data'][:]

amp = 100  # 1 / (10 ms bin width): scales per-bin values to spikes per second (Hz)
df['LFADS'], df['SpikeCount'] = None, None

for i, trial in enumerate(IndT):
    t_end = int(100 * (df['saccadeDetected'].loc[trial] - df['dotsOn'].loc[trial]))
    df.at[trial, 'LFADS'] = (amp * lfads_train[i][:t_end]).T
    df.at[trial, 'SpikeCount'] = (amp * data_train[i][:t_end]).T

for i, trial in enumerate(IndV):
    t_end = int(100 * (df['saccadeDetected'].loc[trial] - df['dotsOn'].loc[trial]))
    df.at[trial, 'LFADS'] = (amp * lfads_valid[i][:t_end]).T
    df.at[trial, 'SpikeCount'] = (amp * data_valid[i][:t_end]).T
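
The crop-and-scale step in the loops above can be seen in isolation on a small synthetic trial. The array contents and the trial duration are invented. The sketch only mirrors the arithmetic: keep the first `t_end` bins, multiply by 100 to convert counts per 10 ms bin to Hz, and transpose to `(NCells × TimeSteps)`.

```python
import numpy as np

amp = 100  # counts per 10 ms bin -> spikes per second

# Synthetic padded trial: 8 time bins x 3 neurons, as stored in the HDF5 file.
counts = np.zeros((8, 3))
counts[2, 0] = 1   # neuron 0 fires once in bin 2
counts[4, 1] = 2   # neuron 1 fires twice in bin 4
counts[6, 2] = 1   # neuron 2 fires after the trial has ended

duration_s = 0.055                 # made-up time from dots onset to the saccade
t_end = int(100 * duration_s)      # number of complete 10 ms bins to keep

cropped = (amp * counts[:t_end]).T  # shape (NCells, TimeSteps)
print(cropped.shape)
```

The spike in bin 6 is discarded by the crop, and a single spike in a 10 ms bin corresponds to an instantaneous rate of 100 Hz.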
