Getting started with loading CHUV data
--------------------------------------

In this notebook, you can see examples of how to load some of the CHUV data: *gdp*, *smr* and *dt5* 
generated files, a whole *absd* directory, and finally ABSD data wrapped in a preprocessed torch dataset.

Requirements:  
* having the *__UP2/* data 
* having set the environment variable ``DATA_DIR`` to the directory that *contains* *__UP2/* (paths are built as ``$DATA_DIR/__UP2/...``)

*Author: Etienne de Montalivet*

In [1]:
# Reload edited project modules (e.g. lighthouse) automatically, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio
import torchvision
import torchvision.transforms as T

import lighthouse.data_loader.files_folders as ff
import lighthouse.metadata as metadata
from lighthouse.data_loader.load_chuv import load_absd, load_dt5, load_gdp, load_smr
from lighthouse.data_loader.torch_dataset import TimeseriesDataset
from lighthouse.preprocessing.transform import MNEFilter

In [None]:
# Paths of the training sessions; the cells below use the first one as the example session.
training_sessions = metadata.get_training_sessions()
# Fail fast with a clear message rather than an IndexError on `training_sessions[0]` further down.
assert len(training_sessions) > 0, "metadata.get_training_sessions() returned no sessions"

### GDP files - stimulation

Loading GDP files for all sessions is still work in progress, which is why we load a single, hard-coded session here.

In [None]:
# Raw data of patient UP2_001, below $DATA_DIR/__UP2.
data_dir = Path(os.environ["DATA_DIR"]) / "__UP2" / "0_RAW_DATA" / "UP2_001"
# Hard-coded example: GeneralLogs folder of one GDP session of the
# "UP2001_2023_11_02_BSITraining_day11" recording folder.
gdp_log_dir = data_dir.joinpath(
    "UP2001_2023_11_02_BSITraining_day11",
    "GDP",
    "Patients",
    "Patient_UP2001Rostral",
    "Sessions",
    "Session_20231102141829",
    "GeneralLogs",
)
# load_gdp returns stimulation, prediction and enable-stim signals, `lm`, and the stimulation metadata.
(
    (stim_data, stim_times, stim_ch_names),
    (pred_data, pred_times, pred_ch_names),
    (enable_stim_data, enable_stim_times),
    lm,
    stim_metadata,
) = load_gdp(gdp_log_dir)

In [None]:
# Shapes of the stimulation, prediction and enable-stim arrays returned by load_gdp.
stim_data.shape, pred_data.shape, enable_stim_data.shape

In [None]:
# Channel names of the stimulation and prediction signals.
stim_ch_names, pred_ch_names

In [None]:
# Entries available in the stimulation metadata mapping.
stim_metadata.keys()

In [None]:
# Inspect one metadata entry, here the "newElbowExtension" one.
stim_metadata["newElbowExtension"]

In [None]:
# Enable-stim values together with the matching `enable_stim_times`.
enable_stim_data, enable_stim_times

### smr files - pure hardware data (ecog, trigger, temp, acc)

In [None]:
# Sort so the example file picked below does not depend on filesystem iteration order.
smr_files = sorted(Path(training_sessions[0]).glob("**/*.smr"))
# Clear error instead of a bare IndexError when the session has no .smr recording.
assert smr_files, f"No .smr file found under {training_sessions[0]}"
smr_file = smr_files[0]
display(smr_files)

In [None]:
# load_smr returns (signals, times, ch_names) for the selected .smr recording.
signals, times, ch_names = load_smr(smr_file)

In [None]:
# Shapes of the signal and time arrays, plus the available channel names.
signals.shape, times.shape, ch_names

### dt5 files - ecog + pred + features

In [None]:
# Sort so the example file picked below does not depend on filesystem iteration order.
dt5_files = sorted(Path(training_sessions[0]).glob("**/*.dt5"))
# Clear error instead of a bare IndexError when the session has no .dt5 file.
assert dt5_files, f"No .dt5 file found under {training_sessions[0]}"
dt5_file = dt5_files[0]
display(dt5_files)

In [None]:
# return_ch_names=True: load_dt5 also returns the channel names.
# return_all=True presumably keeps every stored stream (ecog + pred + features) -- confirm in load_dt5.
signals, ch_names = load_dt5(dt5_file, return_ch_names=True, return_all=True)

In [None]:
# Shape of the dt5 signal array and the names of its channels.
signals.shape, ch_names

In [None]:
# "is_updating" channel of the dt5 file, plotted against its sample index.
plt.plot(signals[ch_names.index("is_updating")])
plt.title("is_updating state (dt5)")
plt.xlabel("sample")
plt.show()

### absd data (whole folder)

In [None]:
# Directories (searched recursively) whose name contains "ABSD" or "absd", sorted so the
# example directory picked below is deterministic.
# The parentheses matter: `and` binds tighter than `or`, so without them any *file*
# yielded by the glob whose name contains "absd" would also be selected.
absd_dirs = sorted(
    p.resolve()
    for p in Path(training_sessions[0]).glob("**")
    if p.is_dir() and ("ABSD" in p.name or "absd" in p.name)
)
assert absd_dirs, f"No ABSD directory found under {training_sessions[0]}"
absd_dir = absd_dirs[0]
display(absd_dirs)

In [None]:
# Load the whole ABSD directory; load_absd returns (signals, ch_names).
signals, ch_names = load_absd(absd_dir)

In [None]:
# Shape of the ABSD signal array and the names of its channels.
signals.shape, ch_names

In [None]:
# "is_updating" channel of the ABSD recording, plotted against its sample index.
plt.plot(signals[ch_names.index("is_updating")])
plt.title("is_updating state")
plt.xlabel("sample")
plt.show()

### torch dataset with preprocessing

In [20]:
# use first absd dir of first training session
# Example data for the dataset: the first ABSD directory that ff.find_absd_dirs reports
# for the first training session.
first_session = training_sessions[0]
absd_dir = ff.find_absd_dirs(first_session)[0]

In [25]:
# Sampling frequency given to the MNE filter and used to size the windows -- assumed to be
# the Hz rate of the ABSD ECoG recordings; confirm against the acquisition setup.
SFREQ = 585
# Device on which the torch transforms run: the GPU when available, else the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [22]:
# loading functions to be used in the dataset
def load_absd_ecog(absd_dir, return_ch_names=False):
    """Load only the ECoG channels of an ABSD directory.

    A channel is kept when its name contains "ecog".

    Args:
        absd_dir: ABSD directory, forwarded unchanged to ``load_absd``.
        return_ch_names: if True, also return the names of the kept channels
            (mirrors ``load_absd_states``). Defaults to False, so callers that
            expect the signals alone are unaffected.

    Returns:
        The rows of the ``signals`` array from ``load_absd`` that belong to the
        kept channels, in their original order. When ``return_ch_names`` is True,
        a ``(signals, names)`` tuple instead.
    """
    signals, ch_names = load_absd(absd_dir)
    # One pass over the names gives the row indices directly. `ch_names.index(ch)` is
    # O(n) per lookup and always returns the first row when a name is duplicated.
    ecog_idx = [i for i, ch in enumerate(ch_names) if "ecog" in ch]
    ecog_signals = signals[ecog_idx]
    if return_ch_names:
        return ecog_signals, [ch_names[i] for i in ecog_idx]
    return ecog_signals


def load_absd_states(absd_dir, return_ch_names=False):
    """Load only the state channels of an ABSD directory.

    Calls ``load_absd(absd_dir, return_states=True)`` and keeps every channel
    whose name contains "state__".

    Args:
        absd_dir: ABSD directory, forwarded to ``load_absd``.
        return_ch_names: if True, also return the names of the kept channels.

    Returns:
        The rows of the ``signals`` array from ``load_absd`` that belong to the
        state channels, in their original order. When ``return_ch_names`` is True,
        a ``(signals, names)`` tuple instead.
    """
    signals, ch_names = load_absd(absd_dir, return_states=True)
    # One pass over the names gives the row indices directly. `ch_names.index(ch)` is
    # O(n) per lookup and always returns the first row when a name is duplicated.
    state_idx = [i for i, ch in enumerate(ch_names) if "state__" in ch]
    state_signals = signals[state_idx]
    if return_ch_names:
        return state_signals, [ch_names[i] for i in state_idx]
    return state_signals

In [None]:
# Only the state channel names are needed here.
# Index the returned tuple rather than unpacking into `_`: assigning `_` in the notebook
# namespace stops IPython from updating its `_` last-output variable.
state_ch_names = load_absd_states(absd_dir, return_ch_names=True)[1]

In [None]:
# Sliding-window parameters, in samples.
win_size = SFREQ  # 585 samples, i.e. 1 s if SFREQ is in Hz
win_step = 59  # stride between consecutive windows (~SFREQ / 10)
# Spectrogram parameters: one FFT over the whole window; hop_length sets the time resolution.
n_fft = win_size
hop_length = 10
# Notch frequencies 50, 100, 150, 200 (presumably 50 Hz mains and its harmonics).
notch_freqs = np.arange(50, 201, 50)
dataset = TimeseriesDataset(
    load_x_func=load_absd_ecog,
    load_x_args={"absd_dir": absd_dir},
    load_y_func=load_absd_states,
    load_y_args={"absd_dir": absd_dir},
    n_samples_step=win_step,
    n_samples_win=win_size,
    x_preprocess=T.Compose(
        [
            MNEFilter(sfreq=SFREQ, l_freq=1, h_freq=200, notch_freqs=notch_freqs, apply_car=True),
            # A 2-D ndarray (assumed (channels, time) -- confirm) becomes a (1, channels, time)
            # tensor; the extra leading dimension is squeezed out again in x_transform.
            T.ToTensor(),
        ],
    ),
    y_preprocess=T.Compose(
        [],
    ),
    x_transform=T.Compose(
        [
            # we first push to GPU, then apply transforms
            lambda x: x.to(DEVICE),
            # needs to be adjusted to the desired output size (FREQ_BINS, TIME_BINS)
            torchaudio.transforms.Spectrogram(
                n_fft=n_fft,
                win_length=n_fft,
                hop_length=hop_length,
                center=True,
                # Build the window directly on DEVICE: this Spectrogram module is never moved
                # with .to(DEVICE), so its window would otherwise stay on the CPU.
                window_fn=lambda x: torch.hann_window(x).to(DEVICE),
            ),
            lambda x: x.squeeze(0).float(),
        ]
    ),
    y_transform=T.Compose(
        [
            lambda x: torch.tensor(x).to(DEVICE),
            # for the sake of this example, we take the last state value
            lambda x: x[..., -1].flatten().float(),
        ]
    ),
    precompute=True,
)

In [16]:
# Batches of 8 windows, kept in dataset order (no shuffling) so the first batch is reproducible.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=8,
    shuffle=False,
)

In [17]:
# Fetch a single batch to check what the transforms produce.
X, y = next(iter(dataloader))

In [None]:
# Batch shapes of the inputs (spectrograms) and of the targets.
X.shape, y.shape

In [None]:
# Target of the first window in the batch: the last value of each state channel, per y_transform.
y[0]