# 🫁 Multimodal Survival Modeling in Pulmonary Tuberculosis
### Chest X-ray + Clinical Covariates

**Author:** Dr. Ikechukwu Ephraim Ugbo, MD  
**Project:** Innovative AI Healthcare Solutions  
**Framework:** TensorFlow / Keras  
**Study Type:** Prognostic modeling (time-to-event)

---

### Clinical Question
In adults with suspected or confirmed pulmonary tuberculosis, can baseline chest X-ray
features combined with clinical covariates predict **time to major complication or death**?

---

### Modeling Strategy
1. Baseline Cox proportional hazards model (clinical variables only)
2. Multimodal neural Cox model (images + clinical)
3. Comparison using Harrell's C-index
4. Explainability via Grad-CAM
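
As a rough illustration of the metric in step 3, the sketch below computes Harrell's C-index in plain Python. The function name `concordance_index` and the toy data are assumptions made for this example; it is not the project's `harrell_c_index` implementation, and it treats pairs with tied follow-up times as non-comparable for simplicity.

```python
def concordance_index(times, events, risks):
    """Harrell's C-index: fraction of comparable pairs ranked correctly.

    A pair (i, j) is comparable when subject i had an observed event and a
    shorter follow-up time than subject j. A higher risk score should then
    belong to subject i.
    """
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # a censored subject cannot anchor a comparable pair
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risks[i] > risks[j]:
                    concordant += 1.0
                elif risks[i] == risks[j]:
                    concordant += 0.5  # tied predictions count as half
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Toy data: follow-up in days, event flag (1 = event, 0 = censored), risk score
times = [30, 90, 180, 365]
events = [1, 1, 0, 1]
risks = [2.1, 0.2, 0.3, 0.9]
print(concordance_index(times, events, risks))  # 0.6
```

A value of 0.5 corresponds to random ranking and 1.0 to perfect ranking, so the comparison between models comes down to which one orders patients by risk more consistently.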


In [None]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf

from src.config import RANDOM_SEED, LEARNING_RATE
from src.data_utils import generate_synthetic_tb_clinical_data, load_clinical_data, train_val_split
from src.model_utils import TBSurvivalNet
from src.survival_utils import harrell_c_index
from src.training_utils import compile_survival_model
from src.explainability_utils import generate_gradcam

In [None]:
tf.random.set_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

print("TensorFlow:", tf.__version__)
print("GPUs:", tf.config.list_physical_devices("GPU"))

## Data Overview

### Imaging
- Baseline chest X-rays
- Single frontal view per patient
- Pre-treatment images only

### Clinical Covariates
- Demographics (age, sex)
- Comorbidities (HIV, diabetes)
- Baseline labs (albumin, hemoglobin)
- Nutritional status (BMI)

### Outcome
**Primary outcome:**  
Composite of major complication or death

**Time scale:**  
Days from baseline CXR to event or censoring
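
To make the outcome definition concrete, here is a minimal sketch of how a `(time, event)` pair could be derived for one patient. The function `time_and_event` and its argument names are hypothetical; the notebook itself reads precomputed `time_to_event` and `event` columns.

```python
from datetime import date


def time_and_event(baseline_cxr, event_date, last_follow_up):
    """Return (days, event) for one patient.

    event_date is None when no major complication or death was observed,
    in which case the patient is censored at the last follow-up date.
    """
    if event_date is None:
        return (last_follow_up - baseline_cxr).days, 0
    return (event_date - baseline_cxr).days, 1


# Patient with an observed event
print(time_and_event(date(2023, 1, 10), date(2023, 3, 1), date(2023, 6, 30)))
# Patient censored at last follow-up
print(time_and_event(date(2023, 1, 10), None, date(2023, 6, 30)))
```

Censored patients still contribute information: they are known to be event-free up to their last follow-up.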

In [None]:
clinical_features = [
    "age", "sex", "hiv_status", "diabetes",
    "bmi", "albumin", "hemoglobin"
]

X_tabular, time, event, df = load_clinical_data(
    csv_path="data/processed/clinical.csv",
    feature_cols=clinical_features,
    time_col="time_to_event",
    event_col="event"
)


In [None]:
X_images = np.load("data/processed/images.npy")

In [None]:
(
    X_img_train, X_img_val,
    X_tab_train, X_tab_val,
    time_train, time_val,
    event_train, event_val
) = train_val_split(
    X_images, X_tabular, time, event
)


## Baseline Model: Cox Proportional Hazards

Before introducing imaging data, we establish a baseline
using a classical Cox proportional hazards model with
clinical covariates only.

This provides:
- Interpretability (hazard ratios)
- A benchmark for model comparison
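
Hazard ratios are the exponentiated Cox coefficients. The sketch below shows that conversion with made-up coefficients; the values in `example_coefs` are purely illustrative and are not estimates from this study. The ratios apply per unit of whatever scale each covariate had when the model was fitted.

```python
import math

# Hypothetical log-hazard coefficients, for illustration only;
# they are not estimates from this study.
example_coefs = {"age": 0.02, "hiv_status": 0.65, "albumin": -0.40}


def hazard_ratio(coef, units=1.0):
    """Hazard ratio for a `units`-sized increase in a covariate."""
    return math.exp(coef * units)


for name, coef in example_coefs.items():
    print(f"{name}: HR per unit = {hazard_ratio(coef):.2f}")

# A 10-year age difference compounds multiplicatively on the hazard scale
print(f"age: HR per 10 years = {hazard_ratio(example_coefs['age'], 10):.2f}")
```

A ratio above 1 indicates a higher hazard as the covariate increases, and a ratio below 1 indicates a lower hazard.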


In [None]:
from lifelines import CoxPHFitter

df_train = pd.DataFrame(X_tab_train, columns=clinical_features)
df_train["time"] = time_train
df_train["event"] = event_train

cph = CoxPHFitter()
cph.fit(df_train, duration_col="time", event_col="event")

cph.summary


In [None]:
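# Score the held-out validation split so the baseline is comparable with
# the multimodal model, which is also evaluated on validation data.
df_val = pd.DataFrame(X_tab_val, columns=clinical_features)
cox_risk = cph.predict_partial_hazard(df_val).values.ravel()

c_index_cox = harrell_c_index(
    time_val, event_val, cox_risk
)

print("Baseline Cox C-index (validation):", c_index_cox)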


## Multimodal Survival Architecture

The final model consists of:
- Image encoder → latent imaging embedding
- Tabular encoder → clinical feature representation
- Fusion layer
- Survival prediction head optimized via Cox partial likelihood
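
The data flow through these components can be sketched as follows. This is a deliberately simplified stand-in that assumes fusion by concatenation and a linear survival head; the internals of `TBSurvivalNet` are not shown in this notebook and may differ. The function `late_fusion_log_risk` and all numbers are hypothetical.

```python
def late_fusion_log_risk(image_embedding, tabular_embedding, weights, bias=0.0):
    """Concatenate the two embeddings and apply a linear survival head.

    Returns a single unbounded log-risk score (no activation), which is the
    kind of output a Cox partial likelihood loss expects.
    """
    fused = list(image_embedding) + list(tabular_embedding)
    if len(fused) != len(weights):
        raise ValueError("weights must match the fused embedding length")
    return sum(w * x for w, x in zip(weights, fused)) + bias


image_embedding = [0.5, -1.0, 2.0]  # stands in for the 1024-d DenseNet121 vector
tabular_embedding = [1.0, 0.0]      # stands in for encoded clinical features
weights = [0.2, 0.1, 0.3, -0.5, 0.4]
print(late_fusion_log_risk(image_embedding, tabular_embedding, weights))
```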

In [None]:
from tensorflow.keras.applications import DenseNet121

image_encoder = DenseNet121(
    include_top=False,
    weights="imagenet",
    input_shape=(224, 224, 3),
    pooling="avg"
)

# Freeze the pretrained backbone so its ImageNet weights are not updated
image_encoder.trainable = False


In [None]:
model = TBSurvivalNet(
    image_encoder=image_encoder,
    tabular_dim=X_tabular.shape[1]
)

model.build([
    (None, 224, 224, 3),
    (None, X_tabular.shape[1])
])

model.summary()


In [None]:
# Rebuild with an explicit hidden_dim; this replaces the model created above
model = TBSurvivalNet(
    image_encoder=image_encoder,
    tabular_dim=X_tabular.shape[1],
    hidden_dim=256
)

model.build([
    (None, 224, 224, 3),
    (None, X_tabular.shape[1])
])

model.summary()


## Multimodal Training Strategy

- Loss: Cox partial likelihood
- Optimizer: Adam
- Learning rate scheduling
- Early stopping
- Checkpointing
- TensorBoard logging

The neural network is trained to output a **log-risk score**
compatible with Cox partial likelihood optimization.
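
For reference, the loss can be written out in a few lines of plain Python. This is a minimal sketch of the negative Cox partial log-likelihood (Breslow form, averaged over observed events), not the loss wired into `compile_survival_model`, whose implementation is not shown here. The function name is an assumption made for this example.

```python
import math


def cox_negative_log_partial_likelihood(log_risks, times, events):
    """Mean negative Cox partial log-likelihood (Breslow form for ties)."""
    total = 0.0
    n_events = 0
    for eta_i, t_i, e_i in zip(log_risks, times, events):
        if not e_i:
            continue  # censored subjects only appear inside risk sets
        # Risk set: everyone still under observation at time t_i
        risk_set = [eta_j for eta_j, t_j in zip(log_risks, times) if t_j >= t_i]
        # log-sum-exp with max subtraction for numerical stability
        m = max(risk_set)
        log_denominator = m + math.log(sum(math.exp(eta - m) for eta in risk_set))
        total += eta_i - log_denominator
        n_events += 1
    if n_events == 0:
        raise ValueError("at least one observed event is required")
    return -total / n_events


times = [1, 2, 3]
events = [1, 1, 1]
well_ordered = cox_negative_log_partial_likelihood([2.0, 1.0, 0.0], times, events)
mis_ordered = cox_negative_log_partial_likelihood([0.0, 1.0, 2.0], times, events)
print(well_ordered < mis_ordered)  # True
```

The loss depends only on how log-risk scores compare within each risk set, which is why the network needs no explicit baseline hazard.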


In [None]:
model = compile_survival_model(
    model,
    lr=LEARNING_RATE
)

In [None]:
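# Score the validation set. The model must already be trained at this point;
# a freshly compiled model would not give a meaningful C-index.
risk_val = model.predict((X_img_val, X_tab_val)).ravel()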

c_index_multimodal = harrell_c_index(
    time_val, event_val, risk_val
)

print("Multimodal C-index:", c_index_multimodal)


## Explainability

Grad-CAM is used to visualize image regions contributing
to higher predicted risk, supporting clinical interpretation.
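
The core Grad-CAM computation is small enough to sketch with nested lists. The example below covers only the combination step (gradient-averaged channel weights, weighted sum, ReLU) on toy inputs; it assumes the feature maps and gradients have already been extracted, and it is not the project's `generate_gradcam`. Full implementations usually also upsample and normalize the heatmap before overlaying it on the X-ray.

```python
def grad_cam_heatmap(feature_maps, gradients):
    """Combine K feature maps (each H x W) into one Grad-CAM heatmap.

    feature_maps[k][r][c]: activation of channel k at position (r, c)
    gradients[k][r][c]:    d(risk score) / d(activation) at the same position
    """
    height, width = len(feature_maps[0]), len(feature_maps[0][0])
    # Channel weight = gradient averaged over all spatial positions
    channel_weights = [
        sum(sum(row) for row in grad) / (height * width) for grad in gradients
    ]
    heatmap = [[0.0] * width for _ in range(height)]
    for weight, fmap in zip(channel_weights, feature_maps):
        for r in range(height):
            for c in range(width):
                heatmap[r][c] += weight * fmap[r][c]
    # ReLU keeps only regions that push the predicted risk upward
    return [[max(0.0, value) for value in row] for row in heatmap]


# Two toy 2x2 channels: the first raises the risk score, the second lowers it
feature_maps = [[[1.0, 0.0], [0.0, 2.0]], [[0.0, 3.0], [1.0, 0.0]]]
gradients = [[[1.0, 1.0], [1.0, 1.0]], [[-1.0, -1.0], [-1.0, -1.0]]]
print(grad_cam_heatmap(feature_maps, gradients))
```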

In [None]:
generate_gradcam(
    model,
    image=X_img_val[0],
    tabular_dim=X_tabular.shape[1],
    layer_name="conv5_block16_concat"
)


## Interpretation and Next Steps

- Compare clinical vs multimodal performance
- Perform subgroup analyses (HIV, age)
- External validation
- Competing risks modeling
- Manuscript preparation


## Summary

This notebook provides a reproducible, interpretable framework
for survival modeling in pulmonary tuberculosis using
chest X-ray imaging and clinical data.

It serves as the foundation for further validation and publication
within the **Innovative AI Healthcare Solutions** initiative.
