In [3]:
"""
Dataloaders for lstm_only model
"""
import os
import numpy as np
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence 

In [30]:
class LSTMTSDataset(Dataset):
    """
    PyTorch Dataset for loading time series, labels, and flat features from HDF5 files.

    ``data_dir`` must contain:
    - ``stays.txt``: one patient id per line, no header.
    - ``timeseries.h5``, ``flat.h5`` and ``labels.h5``: each stores a DataFrame under
      the key ``/table``, indexed by patient id.

    Each item is ``(patient_id, timeseries, ts_len, flat, label)``:
    - ``timeseries`` is a 2-D float tensor with one row per time-series row of the patient.
    - ``ts_len`` is its number of rows.
    - ``flat`` is a float tensor of the patient's flat features.
    - ``label`` is the patient's ``unitdischargestatus`` as a long tensor.

    Each HDF5 table is read once per dataset instance and then cached in memory.
    With DataLoader ``num_workers > 0``, each worker process presumably fills its
    own cache.
    """

    # HDF5 file backing each cached table, keyed by the cache name used in _table().
    _TABLE_FILES = {
        "timeseries": "timeseries.h5",
        "flat": "flat.h5",
        "labels": "labels.h5",
    }

    def __init__(self, data_dir):
        """
        Args:
        - data_dir (str): Path to the dataset directory (e.g., 'train', 'val', 'test')
        """
        self.data_dir = data_dir
        stays_path = os.path.join(data_dir, "stays.txt")
        self.patients = pd.read_csv(stays_path, header=None)[0].tolist()
        # Lazily populated by _table(). Reusing the tables avoids re-reading every
        # full HDF5 table for each item.
        self._tables = {}

    def __len__(self):
        return len(self.patients)

    def _table(self, name):
        """Return the cached DataFrame ``name``, reading it from its HDF5 file on first use."""
        table = self._tables.get(name)
        if table is None:
            path = os.path.join(self.data_dir, self._TABLE_FILES[name])
            # Read-only: never create or modify the data files.
            with pd.HDFStore(path, mode="r") as store:
                table = store.get("/table")
            self._tables[name] = table
        return table

    def __getitem__(self, idx):
        """Return ``(patient_id, timeseries, ts_len, flat, label)`` for the idx-th stay."""
        patient_id = self.patients[idx]

        # Time series: the list selector keeps a DataFrame even when the patient has a
        # single row. A scalar `.loc[patient_id]` would collapse that case to a 1-D
        # Series, giving the column count as ts_len and a 1-D tensor.
        ts_frame = self._table("timeseries").loc[[patient_id]]
        ts_len = len(ts_frame)
        timeseries = torch.tensor(ts_frame.values, dtype=torch.float)

        # Flat (static) features.
        flat = torch.tensor(self._table("flat").loc[patient_id].values, dtype=torch.float)

        # Label.
        label = self._table("labels").loc[patient_id, "unitdischargestatus"]
        label = torch.tensor(label, dtype=torch.long)

        return patient_id, timeseries, ts_len, flat, label


def collate_fn(batch, padding_value=-9999):
    """
    Collate ``LSTMTSDataset`` items into one dynamically padded batch.

    Args:
    - batch (sequence of tuples): items of the form
      ``(patient_id, timeseries, ts_len, flat, label)``.
    - padding_value (float): value written into the padded time steps of shorter
      sequences. Defaults to -9999.

    Returns:
    - ``((seqs_padded, seq_lengths, flats), labels, ids)``:
      - ``seqs_padded`` is padded to the longest sequence in the batch with
        ``batch_first=True``.
      - ``seq_lengths`` is a long tensor of the unpadded lengths.
      - ``flats`` is the stacked float flat features.
      - ``labels`` and ``ids`` are 1-D long tensors.
    """
    ids, seqs, ts_lens, flats, labels = zip(*batch)

    # Unpadded length of each sequence in the batch.
    seq_lengths = torch.tensor(ts_lens, dtype=torch.long)

    # Pad every sequence along the time axis up to the batch maximum with padding_value.
    seqs_padded = pad_sequence(seqs, batch_first=True, padding_value=padding_value)

    flats = torch.stack(flats).float()
    labels = torch.tensor(labels).long()
    ids = torch.tensor(ids).long()

    return (seqs_padded, seq_lengths, flats), labels, ids


In [19]:

# Root of the preprocessed HDF5 splits. Set the THESIS_HDF_DIR environment variable
# to point elsewhere instead of editing this hard-coded default.
HDF_ROOT = os.environ.get("THESIS_HDF_DIR", "/home/mei/nas/docker/thesis/data/hdf")
data_dir = os.path.join(HDF_ROOT, "train")
dataset = LSTMTSDataset(data_dir)

# shuffle=False keeps the inspected first batch the same across runs.
train_loader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

# Sanity-check tensor shapes on the first batch only.
for batch in train_loader:
    (seqs_padded, seq_lengths, flats), labels, ids = batch
    print(seqs_padded.shape, seq_lengths, flats.shape, labels.shape, ids.shape)
    break

torch.Size([32, 3914, 163]) tensor([ 898, 1420, 1521,  490,  412,  436,  739, 2511, 2051,  890,  933, 3914,
         720,  287,  751, 1066, 1229, 1380, 1277,  501,  515,  924, 1081,  839,
        1235, 3448, 2960,  602, 1207, 2377, 3017, 2165]) torch.Size([32, 104]) torch.Size([32]) torch.Size([32])


In [16]:
# Load the whole validation time-series table for inspection. The HDF root defaults
# to the original location and can be overridden with THESIS_HDF_DIR.
timeseries = pd.read_hdf(
    os.path.join(
        os.environ.get("THESIS_HDF_DIR", "/home/mei/nas/docker/thesis/data/hdf"),
        "val",
        "timeseries.h5",
    ),
    key="table",
)

In [17]:
# Rich display of the validation time-series table (pandas truncates long frames in the output).
timeseries

Unnamed: 0_level_0,time,-bands,-basos,-eos,-lymphs,-monos,-polys,24 h urine protein,24 h urine urea nitrogen,ALT (SGPT),...,sao2,heartrate,respiration,cvp,systemicsystolic,systemicdiastolic,systemicmean,st1,st2,st3
patient,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2048518,1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,100.0,92.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2048518,2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,100.0,81.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2048518,3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,100.0,84.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2048518,4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,100.0,86.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2048518,5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,100.0,87.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2892666,1179,0.0,0.0,0.0,4.0,8.0,88.0,0.0,0.0,0.0,...,95.0,70.0,15.0,9.0,0.0,0.0,0.0,0.0,0.0,0.0
2892666,1180,0.0,0.0,0.0,4.0,8.0,88.0,0.0,0.0,0.0,...,95.0,70.0,15.0,9.0,0.0,0.0,0.0,0.0,0.0,0.0
2892666,1181,0.0,0.0,0.0,4.0,8.0,88.0,0.0,0.0,0.0,...,95.0,70.0,15.0,9.0,0.0,0.0,0.0,0.0,0.0,0.0
2892666,1182,0.0,0.0,0.0,4.0,8.0,88.0,0.0,0.0,0.0,...,95.0,70.0,15.0,9.0,0.0,0.0,0.0,0.0,0.0,0.0
