In [3]:
import pandas as pd
import torch
from torch.utils.data import Subset, ConcatDataset
from pathlib import Path
import numpy as np
from lib import BrennanDataset

# Root directory of the EEG data. Later cells pass `base_dir / "phonemes"` and
# `base_dir / "phoneme_dict.txt"` to BrennanDataset alongside this path.
# NOTE(review): this is an absolute, cluster-specific path - adjust it when
# running the notebook elsewhere.
base_dir = Path("/ocean/projects/cis240129p/shared/data/eeg_alice")
# Subject IDs loaded by this notebook.
subjects_used = ["S04", "S13", "S19"]  # exclude 'S05' - fewer channels

# Reference only (not executed): loading a single subject directly.
# ds = BrennanDataset(
#     root_dir=base_dir,
#     phoneme_dir=base_dir / "phonemes",
#     idx="S01",
#     phoneme_dict_path=base_dir / "phoneme_dict.txt",
# )

In [None]:
def create_datasets(subjects, base_dir, train_fraction=0.8):
    """Load one dataset per subject and split each into train/test subsets.

    The split is positional, not random: the first ``train_fraction`` of a
    subject's items (by index) go to the train subset and the remaining
    items go to the test subset.

    Args:
        subjects: Iterable of subject identifiers; each one is passed to
            ``BrennanDataset`` as ``idx``.
        base_dir: Data root as a ``pathlib.Path``. It is forwarded as
            ``root_dir``; ``base_dir / "phonemes"`` and
            ``base_dir / "phoneme_dict.txt"`` are forwarded as
            ``phoneme_dir`` and ``phoneme_dict_path``.
        train_fraction: Fraction of each subject's items assigned to the
            train subset, in ``[0, 1]``. Defaults to ``0.8``.

    Returns:
        A ``(train_datasets, test_datasets)`` tuple of two lists of
        ``Subset`` objects, one entry per subject, in the order of
        ``subjects``.

    Raises:
        ValueError: If ``train_fraction`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(
            f"train_fraction must be within [0, 1], got {train_fraction!r}"
        )

    train_datasets = []
    test_datasets = []
    for subject in subjects:
        dataset = BrennanDataset(
            root_dir=base_dir,
            phoneme_dir=base_dir / "phonemes",
            idx=subject,
            phoneme_dict_path=base_dir / "phoneme_dict.txt",
        )
        num_data_points = len(dataset)

        # Truncating (rather than rounding) means any leftover item lands
        # in the test subset.
        split_index = int(num_data_points * train_fraction)
        train_indices = list(range(split_index))
        test_indices = list(range(split_index, num_data_points))

        # Subsets are index views over the same underlying dataset object.
        train_datasets.append(Subset(dataset, train_indices))
        test_datasets.append(Subset(dataset, test_indices))
    return train_datasets, test_datasets


# Per-subject splits first, then one concatenated dataset per split.
train_ds, test_ds = create_datasets(subjects=subjects_used, base_dir=base_dir)
train_dataset, test_dataset = ConcatDataset(train_ds), ConcatDataset(test_ds)

# Quick sanity check of the resulting split sizes.
print(
    f"Train dataset length: {len(train_dataset)}, Test dataset length: {len(test_dataset)}"
)

Extracting parameters from /ocean/projects/cis240129p/shared/data/eeg_alice/S04.vhdr...
Setting channel info structure...
Reading 0 ... 368449  =      0.000 ...   736.898 secs...
Extracting parameters from /ocean/projects/cis240129p/shared/data/eeg_alice/S13.vhdr...
Setting channel info structure...
Reading 0 ... 368274  =      0.000 ...   736.548 secs...
Extracting parameters from /ocean/projects/cis240129p/shared/data/eeg_alice/S19.vhdr...
Setting channel info structure...
Reading 0 ... 373374  =      0.000 ...   746.748 secs...
Train dataset length: 5109, Test dataset length: 1278


In [10]:
def collate_fn(batch):
    """
    Merge a list of per-sample dicts into a single dict of batched fields.

    The keys are taken from the first sample. For each key, the type of the
    first sample's value decides how the field is batched:

    * NumPy arrays are converted to tensors, then handled like tensors.
    * Tensors with at least one dimension are zero-padded along their first
      dimension and returned batch-first.
    * Zero-dimensional tensors are stacked.
    * Anything else (e.g. strings) is returned as a plain Python list.
    """

    def _merge(values):
        head = values[0]
        # Non-array fields are passed through untouched as a list.
        if not isinstance(head, (np.ndarray, torch.Tensor)):
            return values
        if isinstance(head, np.ndarray):
            values = [torch.tensor(v) for v in values]
        # Scalars have nothing to pad, so they are simply stacked.
        if values[0].ndim == 0:
            return torch.stack(values)
        # pad_sequence fills shorter sequences with zeros by default.
        return torch.nn.utils.rnn.pad_sequence(values, batch_first=True)

    return {
        key: _merge([sample[key] for sample in batch])
        for key in batch[0].keys()
    }


# Loader settings shared by the train and test loaders.
BATCH_SIZE = 2
NUM_WORKERS = 1


def _make_dataloader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader that batches samples with ``collate_fn``."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )


# Only the training loader shuffles; the test loader keeps dataset order.
train_dataloader = _make_dataloader(train_dataset, shuffle=True)
test_dataloader = _make_dataloader(test_dataset, shuffle=False)

# The original, misspelled names are kept as aliases so that cells still
# referring to them keep working.
train_dataloder = train_dataloader
test_dataloder = test_dataloader

In [8]:
# Inspect a single training item: print each field's name, its shape (when it
# has one) and its type.
item = train_dataset[0]
for k, v in item.items():
    # Explicit attribute check instead of a bare `except:`, which would also
    # have swallowed unrelated errors such as KeyboardInterrupt.
    if hasattr(v, "shape"):
        print(k, v.shape, type(v))
    else:
        print(k, type(v))

label <class 'str'>
audio_feats (104, 128) <class 'numpy.ndarray'>
audio_raw (16735,) <class 'numpy.ndarray'>
eeg_raw (520, 62) <class 'numpy.ndarray'>
eeg_feats (159, 310) <class 'numpy.ndarray'>
phonemes (104,) <class 'numpy.ndarray'>


In [11]:
# Smoke-test the dataloader: print the field shapes/types of the first few
# batches, then stop.
MAX_PREVIEW_BATCHES = 5

i = 0
for batch in train_dataloder:
    print(i)
    for k, v in batch.items():
        # Explicit attribute check instead of a bare `except:`, which would
        # also have swallowed unrelated errors such as KeyboardInterrupt.
        if hasattr(v, "shape"):
            print(k, v.shape, type(v))
        else:
            print(k, type(v))
    i += 1
    # Break at the end of the body so no extra batch is fetched.
    if i >= MAX_PREVIEW_BATCHES:
        break

0
label <class 'list'>
audio_feats torch.Size([2, 130, 128]) <class 'torch.Tensor'>
audio_raw torch.Size([2, 20800]) <class 'torch.Tensor'>
eeg_raw torch.Size([2, 520, 62]) <class 'torch.Tensor'>
eeg_feats torch.Size([2, 159, 310]) <class 'torch.Tensor'>
phonemes torch.Size([2, 130]) <class 'torch.Tensor'>
1
label <class 'list'>
audio_feats torch.Size([2, 130, 128]) <class 'torch.Tensor'>
audio_raw torch.Size([2, 20800]) <class 'torch.Tensor'>
eeg_raw torch.Size([2, 520, 62]) <class 'torch.Tensor'>
eeg_feats torch.Size([2, 159, 310]) <class 'torch.Tensor'>
phonemes torch.Size([2, 130]) <class 'torch.Tensor'>
2
label <class 'list'>
audio_feats torch.Size([2, 130, 128]) <class 'torch.Tensor'>
audio_raw torch.Size([2, 20800]) <class 'torch.Tensor'>
eeg_raw torch.Size([2, 520, 62]) <class 'torch.Tensor'>
eeg_feats torch.Size([2, 159, 310]) <class 'torch.Tensor'>
phonemes torch.Size([2, 130]) <class 'torch.Tensor'>
3
label <class 'list'>
audio_feats torch.Size([2, 130, 128]) <class 'torch.T