# Data Preprocessing

In [125]:
import sys
import gzip
from pathlib import Path
import pandas as pd
import numpy as np

src_dir = Path.cwd().parent

# sys.path strictly for importing modules
sys.path.append(str(src_dir))
from utils.data_utils import *

COHORT_PATH = src_dir / "data" / "processed" / "diabetic_patient_day_table.csv.gz"

In [126]:
cohort = load_data(COHORT_PATH)
print(cohort.shape)
cohort.head()

(1372192, 65)


Unnamed: 0,subject_id,chartdate,50803,50809,50822,50824,50837,50841,50842,50847,...,n_admissions,first_admission_date,last_admission_date,hypertension_flag,ckd_flag,obesity_flag,neuropathy_flag,retinopathy_flag,heart_disease_flag,insulin_flag
0,10000635,2136-04-08,,,,,,,,,...,2,2136-06-19 14:24:00,2143-12-23 14:55:00,True,,True,,,,False
1,10000635,2138-09-29,,,,,,,,,...,2,2136-06-19 14:24:00,2143-12-23 14:55:00,True,,True,,,,False
2,10000635,2141-08-15,,,,,,,,,...,2,2136-06-19 14:24:00,2143-12-23 14:55:00,True,,True,,,,False
3,10000635,2142-12-23,,,,,,,,,...,2,2136-06-19 14:24:00,2143-12-23 14:55:00,True,,True,,,,False
4,10000635,2143-06-06,,,,,,,,,...,2,2136-06-19 14:24:00,2143-12-23 14:55:00,True,,True,,,,False


In [127]:
cohort.columns

Index(['subject_id', 'chartdate', '50803', '50809', '50822', '50824', '50837',
       '50841', '50842', '50847', '50848', '50852', '50854', '50882', '50912',
       '50931', '50971', '50983', '51021', '51022', '51027', '51032', '51034',
       '51041', '51042', '51048', '51052', '51053', '51057', '51058', '51061',
       '51064', '51065', '51067', '51070', '51073', '51076', '51080', '51081',
       '51082', '51084', '51097', '51099', '51100', '51106', '51478', '51790',
       '51981', '52024', '52546', '52569', '52610', '52623', 'gender', 'age',
       'n_admissions', 'first_admission_date', 'last_admission_date',
       'hypertension_flag', 'ckd_flag', 'obesity_flag', 'neuropathy_flag',
       'retinopathy_flag', 'heart_disease_flag', 'insulin_flag'],
      dtype='object')

In [128]:
lab_cols = cohort.columns[2:53]
static_cols = [col for col in cohort.columns if col not in lab_cols]
static_cols = [c for c in static_cols if c != "chartdate"]
print(f"Lab columns: {lab_cols}")
print(f"Static columns: {static_cols}")

Lab columns: Index(['50803', '50809', '50822', '50824', '50837', '50841', '50842', '50847',
       '50848', '50852', '50854', '50882', '50912', '50931', '50971', '50983',
       '51021', '51022', '51027', '51032', '51034', '51041', '51042', '51048',
       '51052', '51053', '51057', '51058', '51061', '51064', '51065', '51067',
       '51070', '51073', '51076', '51080', '51081', '51082', '51084', '51097',
       '51099', '51100', '51106', '51478', '51790', '51981', '52024', '52546',
       '52569', '52610', '52623'],
      dtype='object')
Static columns: ['subject_id', 'gender', 'age', 'n_admissions', 'first_admission_date', 'last_admission_date', 'hypertension_flag', 'ckd_flag', 'obesity_flag', 'neuropathy_flag', 'retinopathy_flag', 'heart_disease_flag', 'insulin_flag']


## Impute Missing Lab Values

In [129]:
cohort = cohort.sort_values(["subject_id", "chartdate"]).copy()
cohort["chartdate"] = pd.to_datetime(cohort["chartdate"])
cohort = cohort.set_index("chartdate")

def compute_time_since_last(s):
    last = None
    deltas = []
    for idx, val in s.items():
        if pd.notna(val):
            last = idx
            deltas.append(0)
        else:
            deltas.append(np.nan if last is None else (idx - last).days)
    return pd.Series(deltas, index=s.index)

for col in lab_cols:
    print(f"Current column: {col}")
    cohort[col + '_days_since_last'] = (
        cohort
        .groupby('subject_id')[col]
        .transform(compute_time_since_last)
    )

cohort = cohort.reset_index()
cohort.head()

Current column: 50803
Current column: 50809
Current column: 50822
Current column: 50824
Current column: 50837
Current column: 50841
Current column: 50842
Current column: 50847
Current column: 50848
Current column: 50852
Current column: 50854
Current column: 50882
Current column: 50912
Current column: 50931
Current column: 50971
Current column: 50983
Current column: 51021
Current column: 51022
Current column: 51027
Current column: 51032
Current column: 51034
Current column: 51041
Current column: 51042
Current column: 51048
Current column: 51052
Current column: 51053
Current column: 51057
Current column: 51058
Current column: 51061
Current column: 51064
Current column: 51065
Current column: 51067
Current column: 51070
Current column: 51073
Current column: 51076
Current column: 51080
Current column: 51081
Current column: 51082
Current column: 51084
Current column: 51097
Current column: 51099
Current column: 51100
Current column: 51106
Current column: 51478
Current column: 51790
Current co

Unnamed: 0,chartdate,subject_id,50803,50809,50822,50824,50837,50841,50842,50847,...,51100_days_since_last,51106_days_since_last,51478_days_since_last,51790_days_since_last,51981_days_since_last,52024_days_since_last,52546_days_since_last,52569_days_since_last,52610_days_since_last,52623_days_since_last
0,2136-04-08,10000635,,,,,,,,,...,,,,,,,,,,
1,2138-09-29,10000635,,,,,,,,,...,,,,,,,,,,
2,2141-08-15,10000635,,,,,,,,,...,,,,,,,,,,
3,2142-12-23,10000635,,,,,,,,,...,,,,,,,,,,
4,2143-06-06,10000635,,,,,,,,,...,,,,,,,,,,


In [130]:
# Default forward-fill window (in days) per lab
# Shorter windows for frequently measured labs, longer for infrequent or stable labs

lab_window_dict = {
    '50803': 1,    # Glucose (daily measurement)
    '50809': 1,    # Glucose (fasting)
    '50822': 14,   # Creatinine
    '50824': 14,   # Creatinine
    '50837': 7,    # Sodium
    '50841': 7,    # Potassium
    '50842': 7,    # Chloride
    '50847': 30,   # Hemoglobin
    '50848': 30,   # Hematocrit
    '50852': 30,   # WBC
    '50854': 30,   # Platelet
    '50882': 90,   # HbA1c
    '50912': 90,   # HbA1c
    '50931': 30,   # ALT
    '50971': 30,   # AST
    '50983': 30,   # ALP
    '51021': 30,   # Total Bilirubin
    '51022': 30,   # Direct Bilirubin
    '51027': 30,   # Albumin
    '51032': 30,   # Total Protein
    '51034': 30,   # Calcium
    '51041': 30,   # Magnesium
    '51042': 30,   # Phosphate
    '51048': 30,   # Iron
    '51052': 30,   # TIBC
    '51053': 30,   # Ferritin
    '51057': 30,   # CRP
    '51058': 30,   # ESR
    '51061': 30,   # Bilirubin (conjugated)
    '51064': 30,   # Bilirubin (unconjugated)
    '51065': 30,   # LDH
    '51067': 30,   # CK
    '51070': 30,   # Troponin
    '51073': 30,   # BNP
    '51076': 30,   # NT-proBNP
    '51080': 30,   # GGT
    '51081': 30,   # Amylase
    '51082': 30,   # Lipase
    '51084': 30,   # CRP high-sensitivity
    '51097': 30,   # WBC differential
    '51099': 30,   # RBC
    '51100': 30,   # MCV
    '51106': 30,   # MCH
    '51478': 30,   # MCHC
    '51790': 30,   # Platelet mean volume
    '51981': 30,   # INR
    '52024': 30,   # PT
    '52546': 30,   # aPTT
    '52569': 30,   # Fibrinogen
    '52610': 30,   # D-dimer
    '52623': 30,   # Lactate
}

cohort_ff = cohort.groupby("subject_id")[lab_cols].ffill()

for col in lab_cols:
    window = lab_window_dict[col]  # define this dict
    too_old = cohort[col + "_days_since_last"] > window
    cohort_ff.loc[too_old, col] = np.nan

cohort[lab_cols] = cohort_ff

mask_df = cohort[lab_cols].notna().astype(int).add_suffix('_mask')
cohort = pd.concat([cohort, mask_df], axis=1)
cohort.head()

Unnamed: 0,chartdate,subject_id,50803,50809,50822,50824,50837,50841,50842,50847,...,51100_mask,51106_mask,51478_mask,51790_mask,51981_mask,52024_mask,52546_mask,52569_mask,52610_mask,52623_mask
0,2136-04-08,10000635,,,,,,,,,...,0,0,0,0,0,0,0,0,0,0
1,2138-09-29,10000635,,,,,,,,,...,0,0,0,0,0,0,0,0,0,0
2,2141-08-15,10000635,,,,,,,,,...,0,0,0,0,0,0,0,0,0,0
3,2142-12-23,10000635,,,,,,,,,...,0,0,0,0,0,0,0,0,0,0
4,2143-06-06,10000635,,,,,,,,,...,0,0,0,0,0,0,0,0,0,0


## Applying PyTorch Transformers

In [131]:
import torch
from torch.utils.data import Dataset, DataLoader

class LabTimeSeriesDataset(Dataset):
    def __init__(self, cohort, lab_cols, static_cols, max_seq_len=100):
        valid = cohort.groupby("subject_id").size()
        self.patients = valid[valid > 0].index.to_numpy()
        self.cohort = cohort
        self.lab_cols = lab_cols
        self.static_cols = static_cols
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.patients)

    def __getitem__(self, idx):
        pid = self.patients[idx]
        df = self.cohort[self.cohort['subject_id'] == pid].sort_values('chartdate')

        # Extract dynamic features
        values = df[self.lab_cols].to_numpy(dtype=np.float32)
        masks  = df[[f"{c}_mask" for c in self.lab_cols]].to_numpy(dtype=np.float32)
        deltas = df[[f"{c}_days_since_last" for c in self.lab_cols]].to_numpy(dtype=np.float32)

        dynamic_features = np.concatenate([values, masks, deltas], axis=1)

        # Pad/truncate to max_seq_len
        seq_len, feat_dim = dynamic_features.shape
        if seq_len < self.max_seq_len:
            pad = np.zeros((self.max_seq_len - seq_len, feat_dim))
            dynamic_features = np.vstack([dynamic_features, pad])
        else:
            dynamic_features = dynamic_features[-self.max_seq_len:]

        # Extract static features
        if len(df) == 0:
            static = np.zeros(len(self.static_cols), dtype=np.float32)
        else:
            static = df[self.static_cols].iloc[0].to_numpy(dtype=np.float32)

        static_values = torch.tensor(static, dtype=torch.float32)

        return {
            'dynamic': torch.tensor(dynamic_features, dtype=torch.float32),
            'static': static_values,
            'patient_id': pid
        }

In [132]:
import torch.nn as nn

class LabTransformer(nn.Module):
    def __init__(self, input_dim, d_model=128, nhead=8, num_layers=2, num_static=0, output_dim=1):
        super().__init__()
        
        # Optional: embed static features
        self.static_fc = nn.Linear(num_static, d_model) if num_static > 0 else None

        # Project dynamic features to model dimension
        self.input_fc = nn.Linear(input_dim, d_model)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Prediction head
        self.fc_out = nn.Linear(d_model, output_dim)

    def forward(self, dynamic, static=None):
        """
        dynamic: (batch_size, seq_len, input_dim)
        static: (batch_size, num_static)
        """
        x = self.input_fc(dynamic)  # -> (batch_size, seq_len, d_model)

        # Add static features if available
        if static is not None and self.static_fc is not None:
            static_emb = self.static_fc(static).unsqueeze(1)  # (batch, 1, d_model)
            x = x + static_emb  # broadcast across seq_len

        # Transformer expects (seq_len, batch, d_model)
        x = x.permute(1, 0, 2)

        x = self.transformer(x)  # (seq_len, batch, d_model)

        # Take last timestep for prediction
        out = x[-1]             # (batch, d_model)
        out = self.fc_out(out)  # (batch, output_dim)
        return out

In [133]:
static_cols

['subject_id',
 'gender',
 'age',
 'n_admissions',
 'first_admission_date',
 'last_admission_date',
 'hypertension_flag',
 'ckd_flag',
 'obesity_flag',
 'neuropathy_flag',
 'retinopathy_flag',
 'heart_disease_flag',
 'insulin_flag']

In [134]:
cohort["chartdate"] = pd.to_datetime(cohort["chartdate"])
cohort["first_admission_date"] = pd.to_datetime(cohort["first_admission_date"], errors="coerce")
cohort["last_admission_date"] = pd.to_datetime(cohort["last_admission_date"], errors="coerce")

cohort["gender"] = cohort["gender"].map({"M": 0, "F": 1}).astype("float32")
cohort["age"] = pd.to_numeric(cohort["age"], errors="coerce").astype("float32")
cohort["n_admissions"] = pd.to_numeric(cohort["n_admissions"], errors="coerce").astype("float32")

flag_cols = [
    'hypertension_flag', 'ckd_flag', 'obesity_flag', 'neuropathy_flag',
    'retinopathy_flag', 'heart_disease_flag', 'insulin_flag'
]

for col in flag_cols:
    cohort[col] = pd.to_numeric(cohort[col], errors="coerce").fillna(0).astype("float32")

cohort["days_since_first_admission"] = (
    (cohort["chartdate"] - cohort["first_admission_date"]).dt.days
)

cohort["days_since_last_admission"] = (
    (cohort["chartdate"] - cohort["last_admission_date"]).dt.days
)

cohort["days_since_first_admission"] = cohort["days_since_first_admission"].astype("float32")
cohort["days_since_last_admission"]  = cohort["days_since_last_admission"].astype("float32")

if ("first_admission_date" in static_cols) and ("last_admission_date" in static_cols):
    static_cols.remove("first_admission_date")
    static_cols.remove("last_admission_date")

if ("days_since_first_admission" not in static_cols) and ("days_since_last_admission" not in static_cols):
    static_cols.extend(["days_since_first_admission", "days_since_last_admission"])

In [135]:
cohort[static_cols].head()

Unnamed: 0,subject_id,gender,age,n_admissions,hypertension_flag,ckd_flag,obesity_flag,neuropathy_flag,retinopathy_flag,heart_disease_flag,insulin_flag,days_since_first_admission,days_since_last_admission
0,10000635,1.0,74.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-73.0,-2816.0
1,10000635,1.0,74.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,831.0,-1912.0
2,10000635,1.0,74.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1882.0,-861.0
3,10000635,1.0,74.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2377.0,-366.0
4,10000635,1.0,74.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2542.0,-201.0


In [136]:
for col in static_cols:
    if cohort[col].dtype == object:
        cohort[col] = pd.to_numeric(cohort[col], errors="coerce").astype("float32")

cohort[static_cols].dtypes

subject_id                    float32
gender                        float32
age                           float32
n_admissions                  float32
hypertension_flag             float32
ckd_flag                      float32
obesity_flag                  float32
neuropathy_flag               float32
retinopathy_flag              float32
heart_disease_flag            float32
insulin_flag                  float32
days_since_first_admission    float32
days_since_last_admission     float32
dtype: object

In [150]:
for col in lab_cols:
    cohort[col] = pd.to_numeric(cohort[col], errors="coerce").astype("float32")
    cohort[col] = cohort[col].fillna(0.0)
    cohort[col + "_mask"] = (cohort[col] != 0).astype(np.float32)

    cohort[col + "_days_since_last"] = cohort[col + "_days_since_last"].fillna(0.0)

# Create DataLoader
dataset = LabTimeSeriesDataset(cohort, lab_cols, static_cols, max_seq_len=100)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Transformer model
input_dim = len(lab_cols) * 3  # values + masks + deltas
num_static = len(static_cols)
model = LabTransformer(input_dim=input_dim, num_static=num_static, output_dim=1)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()  # no training, just embeddings

all_embeddings = []
all_patient_ids = []

with torch.no_grad():
    for batch in dataloader:
        dynamic = batch['dynamic'].to(device)  # (batch_size, seq_len, input_dim)
        static = batch['static'].to(device)    # (batch_size, num_static)
        patient_ids = batch['patient_id']
        
        embeddings = model(dynamic, static)  # (batch_size, 1)
        all_embeddings.append(embeddings.cpu())
        all_patient_ids.extend(patient_ids)

all_embeddings = torch.cat(all_embeddings, dim=0).numpy()  # shape: (num_patients, embedding_dim)
print("Embeddings shape:", all_embeddings.shape)
print("Patient IDs length:", len(all_patient_ids))



Embeddings shape: (45948, 1)
Patient IDs length: 45948


In [154]:
stacked_embeddings = np.vstack(all_embeddings)
np.save("../data/processed/patient_embeddings.npy", stacked_embeddings)
np.save("../data/processed/patient_ids.npy", np.array(all_patient_ids))