In [1]:
# Automatically reload edited project modules (e.g. nsrr_data) before each cell executes,
# so changes to the package show up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from pathlib import Path
from typing import TypedDict

import torch
from einops import rearrange
from tqdm import tqdm

from nsrr_data.datamodule.stage_datamodule import SleepStageDataset
from nsrr_data.datamodule.transforms.stft_transform import STFTTransform

In [47]:
# Define collate fn to collect elements in a batch
def collate_fn(batch) -> dict:
    """Stack a list of dataset samples into a single batch dictionary.

    Args:
        batch: Non-empty sequence of samples. Each sample is a mapping with keys
            ``"record"`` (a string; the part before the first ``"_"`` is used as
            the subject id), ``"signal"`` and ``"stages"`` (array-likes whose
            shapes must match across samples so they can be stacked).

    Returns:
        dict with keys
            ``"waveform"``: float32 tensor. The stacked signals have shape
                ``(N, L, ..., T)`` with the sequence axis at position 1; that
                axis is folded into the last (time) axis, giving
                ``(N, ..., L*T)`` -- e.g. ``(N, L, C, T) -> (N, C, L*T)`` or
                ``(N, L, C, F, T) -> (N, C, F, L*T)``.
            ``"global_information"``: list of subject-id strings, one per sample.
            ``"targets"``: the ``"stages"`` entries stacked along a new axis 0.

    Raises:
        ValueError: if ``batch`` is empty or the stacked signals have fewer
            than three dimensions (batch, sequence, time).
    """
    if len(batch) == 0:
        raise ValueError("collate_fn received an empty batch")

    subject_ids = [x["record"].split("_")[0] for x in batch]
    waveforms = torch.stack([torch.as_tensor(x["signal"]) for x in batch]).to(torch.float32)
    targets = torch.stack([torch.as_tensor(x["stages"]) for x in batch])

    if waveforms.ndim < 3:
        raise ValueError(
            f"expected stacked signals of shape (N, L, ..., T), got {tuple(waveforms.shape)}"
        )

    # Move the sequence axis L next to the time axis and merge the two with L as
    # the outer index, so consecutive segments are concatenated in time. Being
    # rank-agnostic, this handles raw signals (N, L, C, T) and STFT output
    # (N, L, C, F, T) alike; for L == 1 it is equivalent to squeeze(1).
    waveforms = waveforms.movedim(1, -2).flatten(start_dim=-2)

    return dict(waveform=waveforms, global_information=subject_ids, targets=targets)

In [48]:
# Instantiate sleep stage object on a subset of the SHHS recordings, applying an
# STFT transform to the selected channel(s).
# The default data location is machine-specific; set the SHHS_DATA_DIR
# environment variable to point elsewhere instead of editing this cell.
SHHS_DATA_DIR = Path(
    os.environ.get("SHHS_DATA_DIR", "/home/aneol/waveform-conversion/data/processed/shhs")
)
N_RECORDS = 10  # number of records to load; set to None to use all of them
FS = 128  # sampling rate passed to both the dataset and the STFT (presumably Hz)

# Sort so the chosen subset is deterministic across filesystems.
records = sorted(SHHS_DATA_DIR.rglob("*.h5"))[:N_RECORDS]
ds = SleepStageDataset(
    records=records,
    sequence_length=1,
    cache_data=True,
    fs=FS,
    n_jobs=-1,
    picks=["c4"],
    scaling="standard",
    transform=STFTTransform(
        fs=FS,
        segment_size=128,
        step_size=16,
        nfft=128,
    ),
)

Using cache for data prep: /home/aneol/waveform-conversion/notebooks/data/.cache


100%|██████████| 10/10 [00:00<00:00, 8950.71it/s]


In [52]:
# Iterate over the dataset and collect every sample into a list.
# This materialises the whole dataset in memory (8836 samples for the 10 records
# in the logged run); tqdm shows progress using len(ds).
batch = list(tqdm(ds))

100%|██████████| 8836/8836 [00:21<00:00, 411.15it/s]


In [53]:
# Finally grab everything in one batch.
# NOTE(review): this rebinds `batch` from the list of samples (built in the
# previous cell) to the collated dict. Re-running this cell on its own therefore
# fails, because iterating the dict yields string keys, not samples. Re-run the
# previous cell first, or give the collated result a different name.
batch = collate_fn(batch=batch)