# Data Preprocessing

In [20]:
import sys
import gzip
from pathlib import Path
import pandas as pd
import numpy as np

src_dir = Path.cwd().parent

# sys.path strictly for importing modules
sys.path.append(str(src_dir))
from utils.data_utils import *

COHORT_PATH = src_dir / "data" / "processed" / "diabetic_patient_day_table.csv.gz"

In [21]:
cohort = load_data(COHORT_PATH)
print(cohort.shape)
cohort.head()

KeyboardInterrupt: 

In [3]:
cohort.columns

Index(['subject_id', 'chartdate', '50803', '50809', '50822', '50824', '50837',
       '50841', '50842', '50847', '50848', '50852', '50854', '50882', '50912',
       '50931', '50971', '50983', '51021', '51022', '51027', '51032', '51034',
       '51041', '51042', '51048', '51052', '51053', '51057', '51058', '51061',
       '51064', '51065', '51067', '51070', '51073', '51076', '51080', '51081',
       '51082', '51084', '51097', '51099', '51100', '51106', '51478', '51790',
       '51981', '52024', '52546', '52569', '52610', '52623', 'gender', 'age',
       'n_admissions', 'first_admission_date', 'last_admission_date',
       'hypertension_flag', 'ckd_flag', 'obesity_flag', 'neuropathy_flag',
       'retinopathy_flag', 'heart_disease_flag', 'insulin_flag'],
      dtype='object')

In [4]:
lab_cols = cohort.columns[2:53]
static_cols = [col for col in cohort.columns if col not in lab_cols]
print(f"Lab columns: {lab_cols}")
print(f"Static columns: {static_cols}")

Lab columns: Index(['50803', '50809', '50822', '50824', '50837', '50841', '50842', '50847',
       '50848', '50852', '50854', '50882', '50912', '50931', '50971', '50983',
       '51021', '51022', '51027', '51032', '51034', '51041', '51042', '51048',
       '51052', '51053', '51057', '51058', '51061', '51064', '51065', '51067',
       '51070', '51073', '51076', '51080', '51081', '51082', '51084', '51097',
       '51099', '51100', '51106', '51478', '51790', '51981', '52024', '52546',
       '52569', '52610', '52623'],
      dtype='object')
Static columns: ['subject_id', 'chartdate', 'gender', 'age', 'n_admissions', 'first_admission_date', 'last_admission_date', 'hypertension_flag', 'ckd_flag', 'obesity_flag', 'neuropathy_flag', 'retinopathy_flag', 'heart_disease_flag', 'insulin_flag']


## Impute Missing Lab Values

In [None]:
cohort = cohort.sort_values(["subject_id", "chartdate"]).copy()
cohort["chartdate"] = pd.to_datetime(cohort["chartdate"])
cohort = cohort.set_index("chartdate")

def compute_time_since_last(s):
    last = None
    deltas = []
    for idx, val in s.items():
        if pd.notna(val):
            last = idx
            deltas.append(0)
        else:
            deltas.append(np.nan if last is None else (idx - last).days)
    return pd.Series(deltas, index=s.index)

for col in lab_cols:
    print(f"Current column: {col}")
    cohort[col + '_days_since_last'] = (
        cohort
        .groupby('subject_id')[col]
        .transform(compute_time_since_last)
    )

cohort.head()

Current column: 50803
Current column: 50809
Current column: 50822
Current column: 50824
Current column: 50837
Current column: 50841
Current column: 50842
Current column: 50847
Current column: 50848
Current column: 50852
Current column: 50854
Current column: 50882
Current column: 50912
Current column: 50931
Current column: 50971
Current column: 50983
Current column: 51021
Current column: 51022
Current column: 51027
Current column: 51032
Current column: 51034
Current column: 51041
Current column: 51042
Current column: 51048
Current column: 51052
Current column: 51053
Current column: 51057
Current column: 51058
Current column: 51061
Current column: 51064
Current column: 51065
Current column: 51067
Current column: 51070
Current column: 51073
Current column: 51076
Current column: 51080
Current column: 51081
Current column: 51082
Current column: 51084
Current column: 51097
Current column: 51099
Current column: 51100
Current column: 51106
Current column: 51478
Current column: 51790
Current co

Unnamed: 0_level_0,subject_id,50803,50809,50822,50824,50837,50841,50842,50847,50848,...,51100_days_since_last,51106_days_since_last,51478_days_since_last,51790_days_since_last,51981_days_since_last,52024_days_since_last,52546_days_since_last,52569_days_since_last,52610_days_since_last,52623_days_since_last
chartdate,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2136-04-08,10000635,,,,,,,,,,...,,,,,,,,,,
2138-09-29,10000635,,,,,,,,,,...,,,,,,,,,,
2141-08-15,10000635,,,,,,,,,,...,,,,,,,,,,
2142-12-23,10000635,,,,,,,,,,...,,,,,,,,,,
2143-06-06,10000635,,,,,,,,,,...,,,,,,,,,,


In [None]:
# Default forward-fill window (in days) per lab
# Shorter windows for frequently measured labs, longer for infrequent or stable labs

lab_window_dict = {
    '50803': 1,    # Glucose (daily measurement)
    '50809': 1,    # Glucose (fasting)
    '50822': 14,   # Creatinine
    '50824': 14,   # Creatinine
    '50837': 7,    # Sodium
    '50841': 7,    # Potassium
    '50842': 7,    # Chloride
    '50847': 30,   # Hemoglobin
    '50848': 30,   # Hematocrit
    '50852': 30,   # WBC
    '50854': 30,   # Platelet
    '50882': 90,   # HbA1c
    '50912': 90,   # HbA1c
    '50931': 30,   # ALT
    '50971': 30,   # AST
    '50983': 30,   # ALP
    '51021': 30,   # Total Bilirubin
    '51022': 30,   # Direct Bilirubin
    '51027': 30,   # Albumin
    '51032': 30,   # Total Protein
    '51034': 30,   # Calcium
    '51041': 30,   # Magnesium
    '51042': 30,   # Phosphate
    '51048': 30,   # Iron
    '51052': 30,   # TIBC
    '51053': 30,   # Ferritin
    '51057': 30,   # CRP
    '51058': 30,   # ESR
    '51061': 30,   # Bilirubin (conjugated)
    '51064': 30,   # Bilirubin (unconjugated)
    '51065': 30,   # LDH
    '51067': 30,   # CK
    '51070': 30,   # Troponin
    '51073': 30,   # BNP
    '51076': 30,   # NT-proBNP
    '51080': 30,   # GGT
    '51081': 30,   # Amylase
    '51082': 30,   # Lipase
    '51084': 30,   # CRP high-sensitivity
    '51097': 30,   # WBC differential
    '51099': 30,   # RBC
    '51100': 30,   # MCV
    '51106': 30,   # MCH
    '51478': 30,   # MCHC
    '51790': 30,   # Platelet mean volume
    '51981': 30,   # INR
    '52024': 30,   # PT
    '52546': 30,   # aPTT
    '52569': 30,   # Fibrinogen
    '52610': 30,   # D-dimer
    '52623': 30,   # Lactate
}

cohort_ff = cohort.groupby("subject_id")[lab_cols].ffill()

for col in lab_cols:
    window = lab_window_dict[col]  # define this dict
    too_old = cohort[col + "_days_since_last"] > window
    cohort_ff.loc[too_old, col] = np.nan

cohort[lab_cols] = cohort_ff

mask_df = cohort[lab_cols].notna().astype(int).add_suffix('_mask')
cohort = pd.concat([cohort, mask_df], axis=1)
cohort.head()

Unnamed: 0_level_0,subject_id,50803,50809,50822,50824,50837,50841,50842,50847,50848,...,51100_mask,51106_mask,51478_mask,51790_mask,51981_mask,52024_mask,52546_mask,52569_mask,52610_mask,52623_mask
chartdate,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2136-04-08,10000635,,,,,,,,,,...,0,0,0,0,0,0,0,0,0,0
2138-09-29,10000635,,,,,,,,,,...,0,0,0,0,0,0,0,0,0,0
2141-08-15,10000635,,,,,,,,,,...,0,0,0,0,0,0,0,0,0,0
2142-12-23,10000635,,,,,,,,,,...,0,0,0,0,0,0,0,0,0,0
2143-06-06,10000635,,,,,,,,,,...,0,0,0,0,0,0,0,0,0,0


## Applying PyTorch Transformers

In [26]:
import torch
from torch.utils.data import Dataset, DataLoader

class LabTimeSeriesDataset(Dataset):
    def __init__(self, cohort, lab_cols, static_cols, max_seq_len=100):
        self.patients = cohort['subject_id'].unique()
        self.cohort = cohort
        self.lab_cols = lab_cols
        self.static_cols = static_cols
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.patients)

    def __getitem__(self, idx):
        pid = self.patients[idx]
        df = self.cohort[self.cohort['subject_id'] == pid].sort_values('chartdate')

        # Extract dynamic features
        values = df[self.lab_cols].to_numpy()
        masks  = df[[col+'_mask' for col in self.lab_cols]].to_numpy()
        deltas = df[[col+'_days_since_last' for col in self.lab_cols]].to_numpy()

        dynamic_features = np.concatenate([values, masks, deltas], axis=1)

        # Pad/truncate to max_seq_len
        seq_len, feat_dim = dynamic_features.shape
        if seq_len < self.max_seq_len:
            pad = np.zeros((self.max_seq_len - seq_len, feat_dim))
            dynamic_features = np.vstack([dynamic_features, pad])
        else:
            dynamic_features = dynamic_features[-self.max_seq_len:]

        # Extract static features
        static_values = df[self.static_cols].iloc[0].to_numpy()  # static = first row
        static_values = torch.tensor(static_values, dtype=torch.float32)

        return {
            'dynamic': torch.tensor(dynamic_features, dtype=torch.float32),
            'static': static_values,
            'patient_id': pid
        }

dataset = LabTimeSeriesDataset(cohort, lab_cols, static_cols, max_seq_len=100)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

In [27]:
import torch.nn as nn

class LabTransformer(nn.Module):
    def __init__(self, input_dim, d_model=128, nhead=8, num_layers=2, num_static=0, output_dim=1):
        super().__init__()
        
        # Optional: embed static features
        self.static_fc = nn.Linear(num_static, d_model) if num_static > 0 else None

        # Project dynamic features to model dimension
        self.input_fc = nn.Linear(input_dim, d_model)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Prediction head
        self.fc_out = nn.Linear(d_model, output_dim)

    def forward(self, dynamic, static=None):
        """
        dynamic: (batch_size, seq_len, input_dim)
        static: (batch_size, num_static)
        """
        x = self.input_fc(dynamic)  # -> (batch_size, seq_len, d_model)

        # Add static features if available
        if static is not None and self.static_fc is not None:
            static_emb = self.static_fc(static).unsqueeze(1)  # (batch, 1, d_model)
            x = x + static_emb  # broadcast across seq_len

        # Transformer expects (seq_len, batch, d_model)
        x = x.permute(1, 0, 2)

        x = self.transformer(x)  # (seq_len, batch, d_model)

        # Take last timestep for prediction
        out = x[-1]             # (batch, d_model)
        out = self.fc_out(out)  # (batch, output_dim)
        return out

In [28]:
input_dim = len(lab_cols) * 3  # values + masks + deltas
num_static = len(static_cols)
model = LabTransformer(input_dim=input_dim, num_static=num_static, output_dim=1)

for batch in dataloader:
    dynamic = batch['dynamic']  # (batch_size, seq_len, input_dim)
    static = batch['static']    # (batch_size, num_static)
    
    preds = model(dynamic, static)  # (batch_size, 1)



KeyError: "['chartdate'] not in index"