In [2]:
"""
Dataloaders for lstm_only model
"""
import os
import numpy as np
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence 


In [20]:
class LSTMDataset(Dataset):
    """
    PyTorch Dataset for loading time series, labels, and flat features from HDF5 files.
    """
    def __init__(self, data_dir):
        """
        Args:
        - data_dir (str): Path to the dataset directory (e.g., 'train', 'val', 'test')
        """
        self.data_dir = data_dir
        stays_path = os.path.join(data_dir, "stays.txt")
        self.patients = pd.read_csv(stays_path, header=None)[0].tolist()   
        

    def __len__(self):
        return len(self.patients)

    def __getitem__(self,idx):
        
        patient_id = self.patients[idx]
 
        # **load time series**
        with pd.HDFStore(os.path.join(self.data_dir, "timeseries.h5")) as store:
            timeseries = store.get("/table").loc[patient_id] 
            ts_len = len(timeseries) 
            timeseries = torch.tensor(timeseries.values, dtype=torch.float)

        # ** flat features**
        with pd.HDFStore(os.path.join(self.data_dir, "flat.h5")) as store:
            flat = store.get("/table").loc[patient_id].values 
            flat = torch.tensor(flat, dtype=torch.float)

        # ** labels**
        with pd.HDFStore(os.path.join(self.data_dir, "labels.h5")) as store:
            label = store.get("/table").loc[patient_id, "unitdischargestatus"] 
            label = torch.tensor(label, dtype=torch.long)

        return timeseries, flat, label, ts_len, patient_id


def collate_fn(batch):
    """Dynamic padding for batch processing."""
    seqs, flats, labels, ts_lens,ids = zip(*batch)

    seq_lengths = torch.tensor(ts_lens, dtype=torch.long)

    seqs_padded = pad_sequence(seqs, batch_first=True, padding_value=-9999)   # pad with -1

    flats = torch.stack(flats).float()
    labels = torch.tensor(labels).long()
    ids = torch.tensor(ids).long()

    return (seqs_padded, flats), labels, ids 


In [19]:
from torch.utils.data import DataLoader

data_dir = "/home/mei/nas/docker/thesis/data/hdf/train"
dataset = LSTMDataset(data_dir)

train_loader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

for batch in train_loader:
    (seqs_padded, seq_lengths, flats), labels, ids = batch
    print(seqs_padded.shape, seq_lengths, flats.shape, labels.shape, ids.shape)
    break

torch.Size([32, 3914, 163]) tensor([ 898, 1420, 1521,  490,  412,  436,  739, 2511, 2051,  890,  933, 3914,
         720,  287,  751, 1066, 1229, 1380, 1277,  501,  515,  924, 1081,  839,
        1235, 3448, 2960,  602, 1207, 2377, 3017, 2165]) torch.Size([32, 104]) torch.Size([32]) torch.Size([32])
