In [1]:
"""
Dataloaders for lstm_only model
"""
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import  pad_sequence

import h5py
import numpy as np
import pandas as pd

In [26]:
class MultiModalDataset(Dataset):
    def __init__(self, data_path):
        
        self.data_path = data_path
        self.ts_h5_file = os.path.join(self.data_path, 'ts_each_patient.h5')
        self.risks_h5_file = os.path.join(self.data_path, 'risk_each_patient.h5')
        self.flat_h5_file = os.path.join(self.data_path, 'flat.h5')
        
        self.ts_h5f = h5py.File(self.ts_h5_file, 'r')
        self.risk_h5f = h5py.File(self.risks_h5_file, 'r')
        self.flat_data = pd.read_hdf(self.flat_h5_file)
        
        self.patient_ids = list(self.ts_h5f.keys())

    def __len__(self):
        return len(self.patient_ids)

    def __getitem__(self, idx):
        patient_id = self.patient_ids[idx]
        
        ts_data = self.ts_h5f[patient_id][:, 1:]  # exclude the first column which is the time
        risk_data = self.risk_h5f[patient_id][:] #
        flat_data = self.flat_data.loc[int(patient_id)].values

        category = int(risk_data[0][5])  # discharge_risk_category
        mortality_label = int(risk_data[0][4])  # unitdischargestatus
        
        ts_data = torch.tensor(ts_data, dtype=torch.float32)
        flat_data = torch.tensor(flat_data, dtype=torch.float32)
        risk_data = torch.tensor(risk_data[:, -1], dtype=torch.float32) # risk data is the last column

        
        return patient_id, flat_data,ts_data, risk_data, category,mortality_label

    def close(self):
        self.ts_h5f.close()
        self.risk_h5f.close()

    
def collate_fn(batch):
    patient_ids,  flat_list,ts_list, risk_list,category_list,mortality_labels = zip(*batch)
    lengths = [x.shape[0] for x in ts_list]
    lengths = torch.tensor(lengths, dtype=torch.long)

    # order by length
    lengths, sorted_idx = torch.sort(lengths, descending=True)
    ts_list = [ts_list[i] for i in sorted_idx]
    risk_list = [risk_list[i] for i in sorted_idx]
    flat_list = [flat_list[i] for i in sorted_idx]
    patient_ids = [patient_ids[i] for i in sorted_idx]
    category_list = [category_list[i] for i in sorted_idx]
    mortality_labels = [mortality_labels[i] for i in sorted_idx]

    # pad sequences
    padding_value = 0
    padded_ts = pad_sequence(ts_list, batch_first=True, padding_value=padding_value)
    padded_risk = pad_sequence(risk_list, batch_first=True, padding_value=padding_value)
    flat_data = torch.stack(flat_list)
    categories = torch.tensor(category_list, dtype=torch.long)
    mortality_labels = torch.tensor(mortality_labels, dtype=torch.long)
    
    return patient_ids,  flat_data, padded_ts, padded_risk, lengths,categories, mortality_labels


In [27]:
data_path="/home/mei/nas/docker/thesis/data/hdf/test"
dataset = MultiModalDataset(data_path)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

In [5]:
for packed_ts_data, lengths in dataloader:
    print("Packed Time Series Data Shape:", packed_ts_data.shape)
    print("Lengths:", lengths)
    break 

Packed Time Series Data Shape: torch.Size([2, 2767, 324])
Lengths: tensor([2767, 1577])


In [37]:
for patient_ids, flat_data,packed_ts_data, risks_data,lengths,categories, mortality_labels in dataloader:
    print("Patient IDs:", patient_ids)
    print("Packed Time Series Data Shape:", packed_ts_data.shape)
    print("Flat Data Shape:", flat_data.shape)
    print("Risks Data Shape:", risks_data.shape)
    print("Lengths:", lengths)
    print("Categories:", categories)
    print("Mortality Labels:", mortality_labels)
    break  

Patient IDs: ['3167984', '3066964']
Packed Time Series Data Shape: torch.Size([2, 1663, 154])
Flat Data Shape: torch.Size([2, 104])
Risks Data Shape: torch.Size([2, 1663])
Lengths: tensor([1663, 1198])
Categories: tensor([1, 1])
Mortality Labels: tensor([0, 0])


In [11]:
ts_h5_file = os.path.join(data_path, 'ts_each_patient.h5')
with h5py.File(ts_h5_file, 'r') as f:
    ts_data = {key: np.array(f[key]) for key in f.keys()}

risks_h5_file = os.path.join(data_path, 'risk_each_patient.h5')
with h5py.File(risks_h5_file, 'r') as f:
    risk_data = {key: np.array(f[key]) for key in f.keys()}


In [39]:

patient_id = '3066964' 
if patient_id in ts_data:
    ts_series = ts_data[patient_id]
    print(f"Time series length for patient {patient_id}: {len(ts_series)}")
    
if patient_id in risk_data:
    risk_series = risk_data[patient_id]
    categories = risk_series[:, 5]
    mortality_labels = risk_series[:, 4]
    print(f"Risk categories for patient {patient_id}: {categories}")
    print(f"Mortality labels for patient {patient_id}: {mortality_labels}")
    print(f"Risk series length for patient {patient_id}: {len(risk_series)}")

Time series length for patient 3066964: 1198
Risk categories for patient 3066964: [1. 1. 1. ... 1. 1. 1.]
Mortality labels for patient 3066964: [0. 0. 0. ... 0. 0. 0.]
Risk series length for patient 3066964: 1198


In [38]:

patient_id = '3167984' 
if patient_id in ts_data:
    ts_series = ts_data[patient_id]
    print(f"Time series length for patient {patient_id}: {len(ts_series)}")
    
if patient_id in risk_data:
    risk_series = risk_data[patient_id]
    categories = risk_series[:, 5]
    mortality_labels = risk_series[:, 4]
    print(f"Risk categories for patient {patient_id}: {categories}")
    print(f"Mortality labels for patient {patient_id}: {mortality_labels}")
    print(f"Risk series length for patient {patient_id}: {len(risk_series)}")

Time series length for patient 3167984: 1663
Risk categories for patient 3167984: [1. 1. 1. ... 1. 1. 1.]
Mortality labels for patient 3167984: [0. 0. 0. ... 0. 0. 0.]
Risk series length for patient 3167984: 1663
