## Extracting data for VaLPACa

In [1]:
from pathlib import Path
import warnings

import h5py
from matplotlib import pyplot as plt
import numpy as np

## Hyperparameters

**Note:** The example file used here was obtained from [dandiset 000039 of the DANDI Archive](https://dandiarchive.org/dandiset/000039).

In [2]:
# Input file (swap the two FILENAME lines to use the other example session)
# FILENAME = "sub-699590236_ses-727153911_behavior+ophys.nwb" # VisP L2/3
FILENAME = "sub-746926904_ses-764182166_behavior+ophys.nwb" # RL L2/3
FILEPATH = Path("nwb_files", FILENAME)

# Train/valid split settings
SEED = 10 # seed for the random train/valid split
PROP_TRAIN = 0.8 # fraction of the trials assigned to the training set

# Trial settings
NUM_FR = 60 # frames extracted per trial
TRIAL_PARAMS = ["direction", "contrast"] # stimulus dataframe columns to carry over, per trial

# Output file name: the first two "_"-separated fields of the input file name, plus ".h5"
OUTPUT_NAME = f"{'_'.join(FILEPATH.name.split('_')[:2])}.h5"

## Extract data
This data file is in the NWB format, and therefore is read using the `pynwb` package.

In [3]:
import pynwb

# Everything is read into memory inside the 'with' block, before the file is closed.
# Note: the length warning for the MotionCorrection time series can be ignored - it is not relevant.
with pynwb.NWBHDF5IO(FILEPATH, "r") as io:
    read_nwbfile = io.read()

    # stimulus dataframe: the epoch start/stop columns are renamed to the
    # frame-based names expected by the rest of the notebook
    frame_col_names = {"start_time": "start_frame", "stop_time": "stop_frame"}
    stim_df = read_nwbfile.intervals["epochs"].to_dataframe().rename(columns=frame_col_names)

    # dF/F series, from which both the traces and their timestamps are taken
    dff_series = read_nwbfile.processing["brain_observatory_pipeline"]["Fluorescence"]["DfOverF"]

    # full fluorescence data
    dff_traces = dff_series.data[()]

    # full timestamps
    timestamps = dff_series.timestamps[()]

<hr style="border:1.5px solid black">

## Format data

### If `stim_df`, `dff_traces` and `timestamps` are properly specified, the rest should work on its own (unless additional exclusion criteria are needed: see below):
- `stim_df`: Stimulus dataframe, where each row carries information for a single trial: 
  - `start_frame`,
  - `stop_frame`, 
  - the info specified by `TRIAL_PARAMS`. 
- `dff_traces`: dF/F per frame, into which `stim_df['start_frame']` and `stim_df['stop_frame']` index.   
- `timestamps`: Timestamp for each dF/F frame (in sec).  

### Exclusion criteria:
- NaN under any of the necessary `stim_df` columns.
- Trial length shorter than 80% of `NUM_FR`.

In [4]:
def get_duration(num_sec):
    """
    Splits a duration in seconds into whole minutes and remaining seconds.

    Parameters
    ----------
    num_sec : number
        Duration, in seconds.

    Returns
    -------
    whole_min : int
        Number of full minutes contained in the duration.
    leftover_sec : number
        Seconds left over once the full minutes are removed.
    """

    whole_min = int(num_sec // 60)
    leftover_sec = num_sec - 60 * whole_min
    return whole_min, leftover_sec

### Check for potential problems in the data

In [5]:
# 1. dF/F traces and timestamps must describe the same frames
num_total_fr = len(dff_traces)
if num_total_fr != len(timestamps):
    raise ValueError("'dff_traces' and 'timestamps' should have the same length.")

# 2. all the required columns must be present (the first missing one is reported)
check_cols = ["start_frame", "stop_frame"] + TRIAL_PARAMS
missing_cols = [col for col in check_cols if col not in stim_df.columns]
if missing_cols:
    raise KeyError(f"'stim_df' missing '{missing_cols[0]}' column.")

# 3. frame numbers must lie within the recording
lowest_start = stim_df["start_frame"].min()
if lowest_start < 0:
    raise ValueError("Lowest start frame cannot be below 0.")

highest_stop = stim_df["stop_frame"].max()
if highest_stop > num_total_fr:
    raise ValueError("Highest stop frame cannot be greater than the length of 'dff_traces'.")

# 4. no trial may end before it starts
trial_lengths = stim_df["stop_frame"] - stim_df["start_frame"]
if trial_lengths.min() < 0:
    raise ValueError("No stop frame should be smaller than its corresponding start frame.")

### Compute relevant info

In [6]:
# compute dt: the mean interval between consecutive timestamps (sec / frame)
dt = np.diff(timestamps).mean()

# exclude trials, if applicable, and warn
# - missing values: isna() is used so that non-numeric trial parameters
#   (e.g., strings) are supported, which summing the columns would not allow
excl_nan = stim_df[["start_frame", "stop_frame"] + TRIAL_PARAMS].isna().any(axis=1)
# - trials that are too short
excl_leng = (stim_df["stop_frame"] - stim_df["start_frame"]) < 0.8 * NUM_FR
# - trials starting too close to the end of the recording: NUM_FR frames are
#   read from each start frame (see below), so the full window must fit
#   within 'dff_traces' (start frames are floored, as they are cast to int below)
excl_bounds = (np.floor(stim_df["start_frame"]) + NUM_FR) > len(dff_traces)
excl = excl_nan | excl_leng | excl_bounds
num_excl = int(excl.sum())
if num_excl:
    prop_excl = np.around(100 * num_excl / len(excl), 2)
    warnings.warn(f"Excluding {num_excl} trials ({prop_excl}%).")
num_excl_bounds = int((excl_bounds & ~(excl_nan | excl_leng)).sum())
if num_excl_bounds:
    warnings.warn(
        f"{num_excl_bounds} of the excluded trials start fewer than {NUM_FR} "
        "frames before the end of 'dff_traces'."
    )
stim_df_keep = stim_df.loc[~excl]
if stim_df_keep.empty:
    # fail early with a clear message, instead of on a zero-size reduction below
    raise ValueError("No trials left after applying the exclusion criteria.")

# aggregate trial parameter data
trial_param_data = {
    trial_param: stim_df_keep[trial_param].to_numpy() for trial_param in TRIAL_PARAMS
}

# get start and stop frames
start_fr = stim_df_keep["start_frame"].to_numpy().astype(int)
stop_fr = stim_df_keep["stop_frame"].to_numpy().astype(int)
num_fr = stop_fr - start_fr
num_trials = len(start_fr)

# check number of frames per trial
if NUM_FR < num_fr.min():
    warnings.warn("The number of frames per trial is lower than the minimum number of frames per trial.")
if NUM_FR > num_fr.max():
    warnings.warn("The number of frames per trial is higher than the maximum number of frames per trial.")
if NUM_FR > num_fr.min() * 1.1:
    warnings.warn("The number of frames per trial is higher by at least 10% than the minimum number of frames per trial.")

# get the number of train/valid trials
num_train = int(num_trials * PROP_TRAIN)
num_valid = num_trials - num_train

# calculate the duration of each trial, on average
# (a stop frame may equal the number of frames recorded, in which case it has
# no timestamp of its own: the last available timestamp is used instead)
last_fr = len(timestamps) - 1
durations = timestamps[np.minimum(stop_fr, last_fr)] - timestamps[start_fr]
sec_per_trial = durations.mean()
duration_min, duration_sec = get_duration(durations.sum())

# extract dF/F data: NUM_FR consecutive frames from each trial's start frame
# (index: trials x frames)
index = start_fr.reshape(-1, 1) + np.arange(NUM_FR)
dff = dff_traces[index]



### Report a few characteristics

In [7]:
# Gather the summary lines first, then print them in one go (one per line).
summary_lines = [
    f"Total : {num_trials} trials ({duration_min}m {duration_sec:05.2f}s)",
    f"Per   : {num_fr.min()} to {num_fr.max()} frames ({sec_per_trial:.4f}s)",
    f"dt    : {dt:.4f} sec / frame",
    f"Split : {num_train} train / {num_valid} valid",
    "\nOverall: {} trials x {} frames x {} ROIs".format(*dff.shape),
]
print("\n".join(summary_lines))

Total : 1152 trials (38m 07.05s)
Per   : 59 to 61 frames (1.9853s)
dt    : 0.0332 sec / frame
Split : 921 train / 231 valid

Overall: 1152 trials x 60 frames x 72 ROIs


### List of keys to include in the output h5 file
`dt` (1 value)  
`train_fluor` (trial x frame x ROI)  
`train_idx` (trial numbers)  
`train_{param}` (value per trial) (param of interest)  
`valid_fluor` (trial x frame x ROI)  
`valid_idx` (trial numbers)  
`valid_{param}` (value per trial) (param of interest)  

### Keys that will be added during preprocessing with OASIS
*`obs_bias_init` (a 0 per ROI)  
*`obs_gain_init` (a value per ROI)   
*`obs_tau_init` (a value per ROI)   
*`obs_var_init` (a value per ROI)  
*`train_ocalcium` (trial x ROI)  
*`train_ospikes` (trial x ROI)  
*`valid_ocalcium` (trial x ROI)  
*`valid_ospikes` (trial x ROI)  

### Add dt (sec / frame)

In [8]:
# Dictionary collecting everything that gets written to the output h5 file;
# it starts with the frame interval (sec / frame).
data_dict = dict(dt=dt)

### Select and add training/validation indices

In [9]:
# seeded random state, so that the train/valid split is reproducible
randst = np.random.RandomState(SEED)

# draw the training trials, without replacement
train_idxs = randst.choice(num_trials, num_train, replace=False)

# NOTE: despite its name, 'valid_mask' is True for the *training* trials.
# The validation trials are therefore its complement.
valid_mask = np.zeros(num_trials, dtype=bool)
valid_mask[train_idxs] = True
valid_idxs = np.flatnonzero(~valid_mask)
randst.shuffle(valid_idxs) # shuffled in place

data_dict.update(train_idx=train_idxs, valid_idx=valid_idxs)

### Add trial information

In [10]:
# For each split, store every trial parameter under "{split}_{param}",
# keeping only the values of the trials that belong to that split.
for split in ("train", "valid"):
    split_idxs = data_dict[f"{split}_idx"]
    data_dict.update(
        {f"{split}_{param}": param_vals[split_idxs] for param, param_vals in trial_param_data.items()}
    )

### Add fluorescence data

In [11]:
# Select each split's trials from the extracted dF/F array, and store them
# under "{split}_fluor".
data_dict.update(
    {f"{split}_fluor": dff[data_dict[f"{split}_idx"]] for split in ("train", "valid")}
)

### Save the file

In [12]:
# Write each entry of data_dict to the output file as its own dataset,
# named after its dictionary key.
with h5py.File(OUTPUT_NAME, "w") as h5_out:
    for name in data_dict:
        h5_out.create_dataset(name, data=data_dict[name])

### Create file with added OASIS-inferred spiking data, and initialization values
Took 20 min locally, for this example.

In [13]:
import os
import sys

# path to the OASIS preprocessing script, relative to the working directory
run_file = Path("..", "utils", "preprocessing_oasis.py")

# expose the script and data file paths as environment variables, so that
# they can be referenced from shell commands
os.environ.update(RUN_FILE=str(run_file), DATA_FILE=str(OUTPUT_NAME))

In [14]:
# Echo the full command first, so that it is recorded in the cell output, then run it.
# RUN_FILE and DATA_FILE are the environment variables set in the previous cell.
# NOTE(review): the variables are not quoted in the second command, so a path
# containing spaces would presumably be split by the shell - confirm before
# using such paths.
!echo "python ${RUN_FILE} --data_path ${DATA_FILE} --normalize --undo_train_test_split"
!python ${RUN_FILE} --data_path ${DATA_FILE} --normalize --undo_train_test_split

python ../utils/preprocessing_oasis.py --data_path sub-746926904_ses-764182166.h5 --normalize --undo_train_test_split
  c /= stddev[:, None]
  c /= stddev[None, :]
