In [35]:
import os
import random
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from utils import *




In [None]:

base_dir = 'C:/Users/janis/Documents/GitHub/PM_EEG_CONTROL/dataverse_files'
base_dir = "C:/Users/janis/Desktop/dataverse_files"

# Loading all participants into a single dictionary
all_data = load_all_data(base_dir)
print("Loaded data for subjects:", list(all_data.keys()))


In [40]:
print(len(all_data))

21


In [None]:

class CLDriveDataset(Dataset):

    def __init__(self, base_dir: str):
        self.base_dir = base_dir
        eeg_dir = os.path.join(base_dir, 'EEG')
        self.subject_ids = []
        if os.path.isdir(eeg_dir):
            for folder in os.listdir(eeg_dir):
                folder_path = os.path.join(eeg_dir, folder)
                if os.path.isdir(folder_path):
                    self.subject_ids.append(folder)

    def __len__(self):
        return len(self.subject_ids)

    def __getitem__(self, idx):
        subject_id = self.subject_ids[idx]
        subject_data = load_subject_data(self.base_dir, subject_id)
        return subject_data


In [44]:


def custom_collate_fn(batch):

    collated = {}
    modalities = ['EEG', 'ECG', 'EDA', 'Gaze']
    labels_key = 'Labels'
    
    # Initialize dictionaries for each modality and level
    for mod in modalities:
        collated[mod] = {}
        for level in range(1, 10):
            collated[mod][level] = {'data': [], 'baseline': []}

    collated[labels_key] = []

    # Iterate over each sample in the batch
    for sample in batch:
        # Process each modality
        for mod in modalities:
            mod_data = sample.get(mod, None)
            if mod_data is not None:
                for level, level_data in mod_data.items():
                    df_data = level_data['data']
                    df_baseline = level_data['baseline']
                    # Ensure 'Timestamp' is present
                    if 'Timestamp' not in df_data.columns or 'Timestamp' not in df_baseline.columns:
                        continue
                    # Merge data and baseline on 'Timestamp'
                    merged_data = pd.merge(df_data, df_baseline, on='Timestamp', suffixes=('_data', '_baseline'))
                    # Drop rows with any NaNs
                    merged_data = merged_data.dropna()
                    if merged_data.empty:
                        continue
                    # Extract features (drop 'Timestamp')
                    data_features = merged_data.drop(columns=['Timestamp'])
                    baseline_features = merged_data.drop(columns=['Timestamp'])
                    # Convert to numpy and then to tensors
                    data_tensor = torch.tensor(data_features.to_numpy(dtype=np.float32))
                    baseline_tensor = torch.tensor(baseline_features.to_numpy(dtype=np.float32))
                    # Append to collated
                    collated[mod][level]['data'].append(data_tensor)
                    collated[mod][level]['baseline'].append(baseline_tensor)
        
        # Process Labels
        labels_df = sample.get(labels_key, None)
        if labels_df is not None and 'time' in labels_df.columns:
            # Drop rows with any NaNs
            labels_df_clean = labels_df.dropna()
            if not labels_df_clean.empty:
                # Drop 'time' column
                labels_features = labels_df_clean.drop(columns=['time'])
                labels_tensor = torch.tensor(labels_features.to_numpy(dtype=np.float32))
                collated[labels_key].append(labels_tensor)

    # Stack tensors for each modality and level
    for mod in modalities:
        for level in range(1, 10):
            if collated[mod][level]['data']:
                # Concatenate along the first dimension (batch dimension)
                collated[mod][level]['data'] = torch.cat(collated[mod][level]['data'], dim=0)
            else:
                collated[mod][level]['data'] = None
            if collated[mod][level]['baseline']:
                collated[mod][level]['baseline'] = torch.cat(collated[mod][level]['baseline'], dim=0)
            else:
                collated[mod][level]['baseline'] = None

    # Stack Labels
    if collated[labels_key]:
        collated[labels_key] = torch.stack(collated[labels_key], dim=0)
    else:
        collated[labels_key] = None

    return collated


In [None]:
# Initialize the dataset
dataset = CLDriveDataset(base_dir=base_dir )

# Initialize the DataLoader with the custom collate function
dataloader = DataLoader(
    dataset,
    batch_size=4,           # Adjust batch size as needed
    shuffle=True,
    collate_fn=custom_collate_fn
)

# Example: Iterating through the DataLoader
for batch_idx, batch in enumerate(dataloader):
    print(f"Batch {batch_idx + 1}:")
    
    # Accessing EEG data
    eeg_data = batch['EEG']
    if eeg_data:
        for level, level_data in eeg_data.items():
            if level_data['data'] is not None:
                print(f"  EEG Level {level}: Data shape {level_data['data'].shape}, Baseline shape {level_data['baseline'].shape}")
            else:
                print(f"  EEG Level {level}: No data")
    
    # Accessing Labels
    labels = batch['Labels']
    if labels is not None:
        print(f"  Labels shape: {labels.shape}")
    else:
        print("  No Labels found")
    break


Batch 1:
  EEG Level 1: Data shape torch.Size([1024, 8]), Baseline shape torch.Size([1024, 8])
  EEG Level 2: Data shape torch.Size([1024, 8]), Baseline shape torch.Size([1024, 8])
  EEG Level 3: Data shape torch.Size([1024, 8]), Baseline shape torch.Size([1024, 8])
  EEG Level 4: Data shape torch.Size([1023, 8]), Baseline shape torch.Size([1023, 8])
  EEG Level 5: Data shape torch.Size([768, 8]), Baseline shape torch.Size([768, 8])
  EEG Level 6: Data shape torch.Size([767, 8]), Baseline shape torch.Size([767, 8])
  EEG Level 7: Data shape torch.Size([1381, 8]), Baseline shape torch.Size([1381, 8])
  EEG Level 8: Data shape torch.Size([768, 8]), Baseline shape torch.Size([768, 8])
  EEG Level 9: Data shape torch.Size([768, 8]), Baseline shape torch.Size([768, 8])
  Labels shape: torch.Size([3, 18, 9])
