In [24]:
"""
Dataloaders for lstm_only model
"""
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence

import h5py
import numpy as np
import pandas as pd

In [23]:
class MultiModalDataset(Dataset):
    """Per-patient dataset combining time-series, flat and risk-score data.

    Three files are read from ``data_path``:

    * ``ts_each_patient_np.h5`` -- one HDF5 dataset per patient id (the keys
      define ``patient_ids``). The first column is dropped in ``__getitem__``
      (presumably a time/index column -- TODO confirm).
    * ``risk_scores_each_patient_np.h5`` -- one HDF5 dataset per patient id.
    * ``flat.h5`` -- a pandas HDF store loaded fully into memory and looked up
      with ``.loc[int(patient_id)]``.

    The two HDF5 handles are opened lazily, once per process. ``h5py.File``
    objects cannot be pickled and should not be shared across ``fork``, so
    deferring the open lets this dataset be used with
    ``DataLoader(num_workers>0)``.
    """

    def __init__(self, data_path):
        self.data_path = data_path
        self.ts_h5_file = os.path.join(self.data_path, 'ts_each_patient_np.h5')
        self.risks_h5_file = os.path.join(self.data_path, 'risk_scores_each_patient_np.h5')
        self.flat_h5_file = os.path.join(self.data_path, 'flat.h5')

        # Short-lived handles: read the patient index and fail fast if either
        # file is missing/unreadable, without keeping handles that would be
        # inherited by (or fail to pickle into) DataLoader workers.
        with h5py.File(self.ts_h5_file, 'r') as ts_file:
            self.patient_ids = list(ts_file.keys())
        with h5py.File(self.risks_h5_file, 'r'):
            pass
        self.flat_data = pd.read_hdf(self.flat_h5_file)

        # Per-process handles, populated by _ensure_open() on first access.
        self.ts_h5f = None
        self.risk_h5f = None
        self._owner_pid = None

    def _ensure_open(self):
        """Open the HDF5 handles in the current process if they are not open here."""
        pid = os.getpid()
        if self._owner_pid != pid or self.ts_h5f is None or self.risk_h5f is None:
            # Handles inherited through fork belong to the parent process;
            # open fresh ones instead of reusing them.
            self.ts_h5f = h5py.File(self.ts_h5_file, 'r')
            self.risk_h5f = h5py.File(self.risks_h5_file, 'r')
            self._owner_pid = pid

    def __getstate__(self):
        """Pickle support: drop the (unpicklable) HDF5 handles; they reopen lazily."""
        state = self.__dict__.copy()
        state['ts_h5f'] = None
        state['risk_h5f'] = None
        state['_owner_pid'] = None
        return state

    def __len__(self):
        return len(self.patient_ids)

    def __getitem__(self, idx):
        """Return ``(patient_id, ts, flat, risk)`` for the patient at position ``idx``.

        ``patient_id`` is the HDF5 key as stored (a string). ``ts`` (all rows,
        first column dropped), ``flat`` and ``risk`` are float32 tensors.
        """
        self._ensure_open()
        patient_id = self.patient_ids[idx]

        ts_data = self.ts_h5f[patient_id][:, 1:]  # drop the first column
        risk_data = self.risk_h5f[patient_id][:]
        flat_data = self.flat_data.loc[int(patient_id)].values

        ts_data = torch.tensor(ts_data, dtype=torch.float32)
        flat_data = torch.tensor(flat_data, dtype=torch.float32)
        risk_data = torch.tensor(risk_data, dtype=torch.float32)

        return patient_id, ts_data, flat_data, risk_data

    def close(self):
        """Close the HDF5 handles opened by this process; safe to call repeatedly."""
        owned = getattr(self, '_owner_pid', None) == os.getpid()
        for name in ('ts_h5f', 'risk_h5f'):
            handle = getattr(self, name, None)
            # Only close handles this process opened; inherited ones are just dropped.
            if handle is not None and owned:
                handle.close()
            setattr(self, name, None)
        self._owner_pid = None

    
def collate_fn(batch, padding_value=-99):
    """Collate variable-length patient samples into one padded batch.

    Samples are reordered by time-series length, longest first, and every
    returned per-sample field follows that same order.

    Args:
        batch: sequence of ``(patient_id, ts, flat, risk)`` tuples, e.g. as
            returned by ``MultiModalDataset.__getitem__``. ``ts`` and ``risk``
            are tensors whose first dimension is the sequence length; all
            ``flat`` tensors must share one shape so they can be stacked.
        padding_value: value written into the padded positions of both the
            time-series and risk tensors. Defaults to -99 (the original
            sentinel), so ``DataLoader(..., collate_fn=collate_fn)`` behaves as
            before.

    Returns:
        ``(patient_ids, padded_ts, flat_data, padded_risk, lengths)``:
        ``patient_ids`` is a list; ``padded_ts`` and ``padded_risk`` are
        batch-first tensors, each padded to its own longest sequence;
        ``flat_data`` is the stacked flat tensors; ``lengths`` is a
        ``torch.long`` tensor of time-series lengths in descending order.
    """
    patient_ids, ts_list, flat_list, risk_list = zip(*batch)
    lengths = torch.tensor([ts.shape[0] for ts in ts_list], dtype=torch.long)

    # Longest first: pack_padded_sequence requires decreasing lengths under its
    # default enforce_sorted=True. stable=True keeps equal-length samples in
    # their incoming order, so the output ordering is deterministic.
    lengths, sorted_idx = torch.sort(lengths, descending=True, stable=True)
    order = sorted_idx.tolist()  # plain ints for indexing Python tuples

    ts_list = [ts_list[i] for i in order]
    risk_list = [risk_list[i] for i in order]
    flat_list = [flat_list[i] for i in order]
    patient_ids = [patient_ids[i] for i in order]

    padded_ts = pad_sequence(ts_list, batch_first=True, padding_value=padding_value)
    padded_risk = pad_sequence(risk_list, batch_first=True, padding_value=padding_value)
    flat_data = torch.stack(flat_list)

    return patient_ids, padded_ts, flat_data, padded_risk, lengths



In [29]:
# Location of the training split. Set the THESIS_DATA_PATH environment variable
# to point elsewhere instead of editing this absolute path.
data_path = os.environ.get("THESIS_DATA_PATH", "/home/mei/nas/docker/thesis/data/hdf/train")
SEED = 42

dataset = MultiModalDataset(data_path)
# A seeded generator makes the shuffled batch order (and the sample batch
# inspected below) reproducible across kernel restarts.
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(SEED),
)

In [30]:
# Inspect one batch. The time series are *padded* (not yet packed) to the
# longest sequence in the batch; `lengths` holds the true lengths.
patient_ids, padded_ts_data, flat_data, risks_data, lengths = next(iter(dataloader))
print("Patient IDs:", patient_ids)
print("Padded Time Series Data Shape:", padded_ts_data.shape)
print("Flat Data Shape:", flat_data.shape)
print("Risks Data Shape:", risks_data.shape)
print("Lengths:", lengths)

Patient IDs: ['3132351', '1788546']
Packed Time Series Data Shape: torch.Size([2, 1759, 162])
Flat Data Shape: torch.Size([2, 104])
Risks Data Shape: torch.Size([2, 1759])
Lengths: tensor([1759,  422])


In [31]:
def _read_all_datasets(h5_path):
    """Load every top-level dataset of an HDF5 file into a dict of numpy arrays."""
    with h5py.File(h5_path, 'r') as h5:
        return {name: np.array(h5[name]) for name in h5.keys()}


# Full in-memory copies of both per-patient files, keyed by patient id.
ts_h5_file = os.path.join(data_path, 'ts_each_patient_np.h5')
risks_h5_file = os.path.join(data_path, 'risk_scores_each_patient_np.h5')
ts_data = _read_all_datasets(ts_h5_file)
risk_data = _read_all_datasets(risks_h5_file)


In [32]:

# Sanity check for one patient: the time-series and risk-score arrays should
# cover the same number of time steps.
patient_id = '1788546'
ts_series = ts_data.get(patient_id)
risk_series = risk_data.get(patient_id)

if ts_series is None:
    print(f"Patient {patient_id} not found in time-series file")
else:
    print(f"Time series length for patient {patient_id}: {len(ts_series)}")

if risk_series is None:
    print(f"Patient {patient_id} not found in risk-score file")
else:
    print(f"Risk series length for patient {patient_id}: {len(risk_series)}")

if ts_series is not None and risk_series is not None:
    assert len(ts_series) == len(risk_series), (
        f"length mismatch for patient {patient_id}: "
        f"{len(ts_series)} time steps vs {len(risk_series)} risk scores"
    )

Time series length for patient 1788546: 422
Risk series length for patient 1788546: 422
