In [93]:

import os

import pandas as pd
import numpy as np
import mne
import h5py
from dataclasses import dataclass

from rembler.utils.dataset_utils import split_train_test
from rembler.utils.io_utils import aggregate_csvs, list_files_with_extension
from rembler.utils import sleep_utils as su

In [46]:
# NOTE(review): `full` is not assigned in any cell visible here; it presumably
# comes from an earlier or deleted cell — confirm it survives Restart & Run All.
full.iloc[:5]  # first five rows, equivalent to full.head()

Unnamed: 0,sleep,start,stop,activity,context,subject,session,day
0,A,0,5000,0.091589,AAA,MPSD4,Baseline,1
1,A,5000,10000,0.011654,AAAX,MPSD4,Baseline,1
2,A,10000,15000,1.001923,AAAXA,MPSD4,Baseline,1
3,X,15000,20000,4.275597,AAXAA,MPSD4,Baseline,1
4,A,20000,25000,2.070143,AXAAA,MPSD4,Baseline,1


In [47]:
# NOTE(review): `stage_matched` is not assigned in any cell visible here; it
# presumably comes from an earlier or deleted cell — confirm it survives
# Restart & Run All.
stage_matched.iloc[:5]  # first five rows, equivalent to stage_matched.head()

Unnamed: 0,sleep,start,stop,activity,context,subject,session,day
0,R,55000,60000,0.017417,SSRRR,MPSD1,Baseline,1
1,R,60000,65000,0.017363,SRRRR,MPSD1,Baseline,1
2,R,65000,70000,0.017319,RRRRR,MPSD1,Baseline,1
3,R,70000,75000,0.017282,RRRRR,MPSD1,Baseline,1
4,R,75000,80000,0.017252,RRRRR,MPSD1,Baseline,1


In [48]:
# NOTE(review): `final` is not assigned in any cell visible here; it presumably
# comes from an earlier or deleted cell — confirm it survives Restart & Run All.
# Displaying the whole frame relies on pandas' truncation; `final.head()` or
# `final.sample(5)` would keep the notebook output compact.
final

Unnamed: 0,sleep,start,stop,activity,context,subject,session,day,role
0,S,34970000,34975000,0.015893,SSSSS,MPSD1,Baseline,1,train
1,R,16815000,16820000,0.004257,RRRRR,MPSD1,Baseline,1,train
2,R,35295000,35300000,0.039750,RRRAA,MPSD1,Baseline,1,train
3,A,39020000,39025000,1.003210,AAAAA,MPSD1,Baseline,1,train
4,A,8015000,8020000,2.472006,AAAAX,MPSD1,Baseline,1,train
...,...,...,...,...,...,...,...,...,...
17632,A,27190000,27195000,0.005248,AAAAA,MPSD9,Recovery,End,test
17633,R,35075000,35080000,0.010631,SSRRR,MPSD9,Recovery,End,test
17634,A,23830000,23835000,3.184113,AAAAA,MPSD9,Recovery,End,test
17635,A,37550000,37555000,0.081009,AAAAA,MPSD9,Recovery,End,test


In [103]:
# Load the stage-matched bouts with their train/test role assignment.
# The `day` column mixes numeric labels (1) with text labels ("End"). Force it
# to str so the column has a single type. Otherwise pandas' chunked parsing can
# yield a mix of ints and strings, and string comparisons such as `day == '1'`
# silently miss the int-typed rows.
df = pd.read_csv("data/full_sleep_stage_matched_train_test_split.csv", dtype={"day": str})

In [75]:
# One row per recording (subject / session / day) that contributes training bouts.
# Deduplicate on all three keys. Non-Baseline EDF files are split per day, so
# deduplicating on subject/session alone kept only the first day of each
# session and silently skipped every other day's file.
file_components = (
    df.loc[df["role"] == "train", ["subject", "session", "day"]]
    .drop_duplicates()
)

In [94]:
@dataclass
class Config:
    """Settings used to cut bout-centred signal windows out of EDF recordings.

    The fields are passed straight to ``su.determine_buffering``; their exact
    units and meaning are defined there (presumably seconds / Hz — confirm).
    """

    bout_length: float  # length of a single bout (units per su.determine_buffering)
    bout_context: float  # amount of surrounding context to include
    sample_rate: float  # sampling rate of the EDF signals (assumed Hz — confirm)
    causal: bool  # presumably: True = context only before the bout — confirm


# Non-causal extraction, with 5 context units around 10-unit bouts at 500 samples/unit.
config = Config(
    causal=False,
    sample_rate=500.0,
    bout_context=5.0,
    bout_length=10.0,
)

In [118]:
# Extract every bout's EEG/EMG window into an HDF5 file.
# Each EDF recording is read once. For every bout in it, a group named by
# `bout_id` is written, holding one dataset per channel (lower-cased channel name).
signals_to_extract = ["EEG", "EMG"]
RAW_EDF_DIR = "/Volumes/DataCave/rembler_data/raw_edf"
path = os.path.join("/Volumes/DataCave/rembler_data/training_datasets", "5bout_context_non_causal.h5")


def _edf_path(subject, session, day, edf_dir=RAW_EDF_DIR):
    """Return the EDF file path for one recording.

    Baseline recordings are named "<subject> <session>.edf". All other
    sessions also carry the day: "<subject> <session> <day>.edf".
    """
    parts = [subject, session] if session == "Baseline" else [subject, session, day]
    return os.path.join(edf_dir, " ".join(str(part) for part in parts) + ".edf")


# The buffers depend only on `config`, so compute them once instead of per file.
leading_buffer, trailing_buffer = su.determine_buffering(
    config.bout_length, config.bout_context, config.sample_rate, config.causal
)

with h5py.File(path, "a") as f:
    for _, recording in file_components.iterrows():
        edf_filename = _edf_path(recording["subject"], recording["session"], recording["day"])
        # Select this recording's bouts with boolean masks rather than an
        # f-string `query`. This avoids quoting issues. Comparing `day` as str
        # matches both int- and str-typed labels.
        df_sub = df[
            (df["subject"] == recording["subject"])
            & (df["session"] == recording["session"])
            & (df["day"].astype(str) == str(recording["day"]))
        ]
        # NOTE(review): rows are not filtered by `role` here, so any non-train
        # bouts from a training recording are extracted too — confirm intended.
        edf_data = mne.io.read_raw_edf(edf_filename, preload=True, verbose="WARNING")
        signals = edf_data.get_data(signals_to_extract)
        for _, bout in df_sub.iterrows():
            bout_signals = su.get_bout_signal(signals, bout, leading_buffer, trailing_buffer)
            for channel_idx, signal_type in enumerate(signals_to_extract):
                dataset_name = f"{bout['bout_id']}/{signal_type.lower()}"
                # The file is opened in append mode, so re-running this cell
                # would make create_dataset raise on existing names. Replace
                # them so the cell is re-runnable. HDF5 does not reclaim the
                # freed space, so rebuild the file from scratch for a final
                # artifact.
                if dataset_name in f:
                    del f[dataset_name]
                f.create_dataset(dataset_name, data=bout_signals[channel_idx])
        # Flush once per recording rather than after every single dataset.
        f.flush()
    

0
  eeg: (25000,)
  emg: (25000,)
1
  eeg: (25000,)
  emg: (25000,)
10
  eeg: (25000,)
  emg: (25000,)
100
  eeg: (25000,)
  emg: (25000,)
1000
  eeg: (25000,)
  emg: (25000,)
1001
  eeg: (25000,)
  emg: (25000,)
1002
  eeg: (25000,)
  emg: (25000,)
1003
  eeg: (25000,)
  emg: (25000,)
1004
  eeg: (25000,)
  emg: (25000,)
1005
  eeg: (25000,)
  emg: (25000,)
1006
  eeg: (25000,)
  emg: (25000,)
1007
  eeg: (25000,)
  emg: (25000,)
1008
  eeg: (25000,)
  emg: (25000,)
1009
  eeg: (25000,)
  emg: (25000,)
101
  eeg: (25000,)
  emg: (25000,)
1010
  eeg: (25000,)
  emg: (25000,)
1011
  eeg: (25000,)
  emg: (25000,)
1012
  eeg: (25000,)
  emg: (25000,)
1013
  eeg: (25000,)
  emg: (25000,)
1014
  eeg: (25000,)
  emg: (25000,)
1015
  eeg: (25000,)
  emg: (25000,)
1016
  eeg: (25000,)
  emg: (25000,)
1017
  eeg: (25000,)
  emg: (25000,)
1018
  eeg: (25000,)
  emg: (25000,)
1019
  eeg: (25000,)
  emg: (25000,)
102
  eeg: (25000,)
  emg: (25000,)
1020
  eeg: (25000,)
  emg: (25000,)
1021
  eeg: 