<a href="https://colab.research.google.com/github/miracle078/ELEC70121_TAIMI_2025/blob/main/bn2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AKI Prediction 48h Onset - Full Implementation
Bayesian Neural Network with Clinical Interpretability

## 1. Environment Setup & Configuration

In [16]:
# Install core dependencies
%pip install torch pyro-ppl shap pandas scikit-learn matplotlib ipywidgets
%pip install --upgrade numpy
%pip install --upgrade scipy shap
import numpy as np
print("NumPy version:", np.__version__)

Collecting numpy
  Using cached numpy-2.2.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
Using cached numpy-2.2.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.4 MB)
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.0.2
    Uninstalling numpy-2.0.2:
      Successfully uninstalled numpy-2.0.2
[31mERROR: Operation cancelled by user[0m[31m
Collecting numpy<2.5,>=1.23.5 (from scipy)
  Using cached numpy-2.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
Using cached numpy-2.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (19.5 MB)
[31mERROR: Operation cancelled by user[0m[31m
[0mNumPy version: 2.0.2


In [25]:
import torch
import pyro
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import shap
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.optim import Adam
from pyro.distributions import Normal, Bernoulli
from pyro.nn import PyroModule, PyroSample
from sklearn.model_selection import GroupShuffleSplit
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (roc_auc_score, average_precision_score,
                            confusion_matrix, classification_report)
from sklearn.calibration import calibration_curve
from tqdm.auto import trange
import ipywidgets as widgets
from IPython.display import display
import os

## 2. Hardware Optimization Setup

In [26]:
# Configure device with automatic fallback
def configure_hardware():
    device = None

    # Check CUDA first
    if torch.cuda.is_available():
        device = torch.device('cuda')
        torch.backends.cudnn.benchmark = True
        torch.set_float32_matmul_precision('high')
        print(f"Using NVIDIA GPU: {torch.cuda.get_device_name(0)}")
        print(f"Total VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")

    # Check Apple Silicon
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = torch.device('mps')
        print("Using Apple Silicon GPU")

    # Check Intel GPUs
    elif torch.xpu.is_available():
        device = torch.device('xpu')
        import intel_extension_for_pytorch as ipex
        print(f"Using Intel GPU: {torch.xpu.get_device_name(0)}")

    # Fallback to CPU
    else:
        device = torch.device('cpu')
        print("Using CPU")

    # Memory optimization
    if device.type in ['cuda', 'xpu']:
        torch.set_float32_matmul_precision('high')
        torch.backends.cuda.matmul.allow_tf32 = True

    return device

DEVICE = configure_hardware()

Using CPU


## 3. Data Loading & Preprocessing Pipeline

In [27]:
class AKIDataProcessor:
    def __init__(self, data_path):
        self.data_path = data_path
        self.labs = ['creatinine', 'bicarbonate', 'chloride', 'glucose', 'magnesium',
                     'potassium', 'sodium', 'urea_nitrogen', 'hemoglobin', 'platelet_count',
                     'wbc_count', 'lactate', 'paco2', 'ph', 'pao2', 'albumin', 'anion_gap']

        self.vitals = ['heart_rate', 'resp_rate', 'temperature', 'spo2',
                      'nbp_sys', 'nbp_dias', 'nbp_mean', 'gcs_total']

        self.coags = ['inr', 'pt', 'aptt']
        self.urine = ['urine_output_ml', 'urine_or', 'urine_pacu']
        self.demogs = ['age', 'gender', 'race', 'weight', 'height']

        # 'admission_type'

    def load_and_preprocess(self):
        # Load raw data
        df = pd.read_csv("/content/drive/MyDrive/AKI-sample-clean.csv")

        # Convert categorical demographics to numeric codes (if they are object types)
        for col in ['gender', 'race', 'admission_type']:
            if df[col].dtype == object:
                df[col] = pd.factorize(df[col])[0]

        # Temporal features
        df['hour_from_icu'] = df['hour_from_icu'].astype(int)
        df['early_icu_period'] = (df['hour_from_icu'] <= 24).astype(int)
        df['critical_window'] = ((df['hour_from_icu'] >= 24) & (df['hour_from_icu'] <= 48)).astype(int)

        # Creatinine dynamics
        df = df.sort_values(['stay_id', 'hour_from_icu'])
        df['creatinine_24h_change'] = df.groupby('stay_id')['creatinine'].transform(
            lambda x: x.diff().rolling(24, min_periods=1).mean()
        )

        # KDIGO staging
        df['aki_stage'] = 0
        df['aki_stage'] = np.where(df['creatinine'] >= 1.5 * df['creatinine'].shift(24), 1, df['aki_stage'])
        df['aki_stage'] = np.where(df['creatinine'] >= 2.0 * df['creatinine'].shift(24), 2, df['aki_stage'])
        df['aki_stage'] = np.where(df['creatinine'] >= 3.0 * df['creatinine'].shift(24), 3, df['aki_stage'])

        # Target definition
        df['aki_48h'] = df.groupby('stay_id')['aki_stage'].transform(
            lambda x: (x.shift(-48) >= 1).any().astype(int)
        )

        # Missing indicators for lab, vital, coags, and urine variables
        for col in self.labs + self.vitals + self.coags + self.urine:
            df[f'is_{col}_missing'] = df[col].isnull().astype(int)

        return df

## 4. Bayesian Neural Network Architecture

In [28]:
class BayesianAKINetwork(PyroModule):
    def __init__(self, input_dim):
        super().__init__()
        self.input_dim = input_dim

        # Informed priors scaled by input dimension
        self.fc1 = PyroModule[nn.Linear](input_dim, 64)
        self.fc1.weight = PyroSample(
            lambda prior: Normal(0, 1/np.sqrt(input_dim)).expand([64, input_dim]).to_event(2))
        self.fc1.bias = PyroSample(Normal(0, 0.1).expand([64]).to_event(1))

        self.fc2 = PyroModule[nn.Linear](64, 32)
        self.fc2.weight = PyroSample(
            lambda prior: Normal(0, 1/np.sqrt(64)).expand([32, 64]).to_event(2))
        self.fc2.bias = PyroSample(Normal(0, 0.1).expand([32]).to_event(1))

        self.out = PyroModule[nn.Linear](32, 1)
        self.out.weight = PyroSample(
            lambda prior: Normal(0, 1/np.sqrt(32)).expand([1, 32]).to_event(2))
        self.out.bias = PyroSample(Normal(0, 0.1).expand([1]).to_event(1))

        # Dropout for regularization
        self.dropout = nn.Dropout(0.3)

    def forward(self, x, y=None):
        x = x.to(DEVICE)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.relu(self.fc2(x))
        x = self.dropout(x)
        logits = self.out(x).squeeze()
        probs = torch.sigmoid(logits)

        with pyro.plate("data", x.size(0)):
            obs = pyro.sample("obs", Bernoulli(probs), obs=y)

        return probs

## 5. Training Pipeline with GPU Optimization

In [29]:
class AKITrainer:
    def __init__(self, data_path):
        self.processor = AKIDataProcessor(data_path)
        self.df = self.processor.load_and_preprocess()
        self.features = None
        self.scaler = StandardScaler()
        self.imputer = IterativeImputer(max_iter=50, random_state=42)

    def preprocess_data(self):
        # Feature selection
        features = (self.processor.labs + self.processor.vitals +
                   self.processor.demogs + self.processor.coags +
                   self.processor.urine + ['creatinine_24h_change',
                   'early_icu_period', 'critical_window'] +
                   [f'is_{col}_missing' for col in
                    self.processor.labs + self.processor.vitals +
                    self.processor.coags + self.processor.urine])

        # Impute missing values
        X = self.imputer.fit_transform(self.df[features])
        X = self.scaler.fit_transform(X)

        # Temporal split
        splitter = GroupShuffleSplit(n_splits=1, test_size=0.3, random_state=42)
        train_idx, test_idx = next(splitter.split(X, groups=self.df['stay_id']))
        val_idx = next(splitter.split(X[train_idx], groups=self.df.iloc[train_idx]['stay_id']))[1]

        # Convert to tensors
        X_train = torch.tensor(X[train_idx], dtype=torch.float32, device=DEVICE)
        y_train = torch.tensor(self.df.iloc[train_idx]['aki_48h'].values,
                              dtype=torch.float32, device=DEVICE)

        X_val = torch.tensor(X[val_idx], dtype=torch.float32, device=DEVICE)
        y_val = torch.tensor(self.df.iloc[val_idx]['aki_48h'].values,
                            dtype=torch.float32, device=DEVICE)

        X_test = torch.tensor(X[test_idx], dtype=torch.float32, device=DEVICE)
        y_test = torch.tensor(self.df.iloc[test_idx]['aki_48h'].values,
                             dtype=torch.float32, device=DEVICE)

        return (X_train, y_train), (X_val, y_val), (X_test, y_test)

    def train(self, num_epochs=1000, patience=20):
        (X_train, y_train), (X_val, y_val), _ = self.preprocess_data()

        # Initialize model
        input_dim = X_train.shape[1]
        model = BayesianAKINetwork(input_dim).to(DEVICE)
        guide = pyro.infer.autoguide.AutoDiagonalNormal(model)

        # Class weighting
        class_counts = torch.bincount(y_train.long())
        class_weights = 1. / class_counts
        class_weights = class_weights.to(DEVICE)

        # Configure SVI
        optimizer = Adam({"lr": 0.001, "betas": (0.95, 0.999)})
        svi = SVI(model, guide, optimizer, loss=self._weighted_loss(class_weights))

        # Training state
        best_loss = float('inf')
        patience_counter = 0
        training_loss = []
        validation_loss = []

        # Mixed precision
        scaler = torch.cuda.amp.GradScaler() if DEVICE.type == 'cuda' else None

        # Training loop
        progress = trange(num_epochs, desc="Training")
        for epoch in progress:
            # Train step
            epoch_loss = 0.0
            model.train()
            with torch.autocast(device_type=DEVICE.type, enabled=DEVICE.type != 'cpu'):
                loss = svi.step(X_train, y_train)

            if scaler:
                scaler.scale(loss).backward()
                scaler.step(svi.optim)
                scaler.update()
            else:
                loss.backward()
                svi.optim.step()

            # Validation
            model.eval()
            with torch.no_grad(), torch.autocast(device_type=DEVICE.type):
                val_loss = -guide.log_prob(X_val, y_val).item()

            # Early stopping
            if val_loss < best_loss:
                best_loss = val_loss
                patience_counter = 0
                torch.save(model.state_dict(), "best_model.pth")
            else:
                patience_counter += 1

            if patience_counter >= patience:
                progress.close()
                print(f"Early stopping at epoch {epoch}")
                break

            training_loss.append(loss)
            validation_loss.append(val_loss)
            progress.set_postfix({"Train Loss": loss, "Val Loss": val_loss})

        return model, guide, (training_loss, validation_loss)

    def _weighted_loss(self, class_weights):
        def loss_fn(model, guide, x, y):
            probs = model(x)
            return -torch.sum(Bernoulli(probs).log_prob(y) * class_weights[y.long()])
        return loss_fn

## 6. Clinical Evaluation & Interpretability

In [30]:
class AKIEvaluator:
    def __init__(self, model, guide, test_data):
        self.model = model
        self.guide = guide
        self.X_test, self.y_test = test_data

    def predict(self, uncertainty_threshold=0.15):
        with torch.no_grad(), torch.autocast(device_type=DEVICE.type):
            predictive = Predictive(self.model, guide=self.guide, num_samples=200)
            samples = predictive(self.X_test)

        probs = samples["obs"].float().mean(dim=0).cpu().numpy()
        uncertainty = samples["obs"].float().std(dim=0).cpu().numpy()

        # Adaptive deferral
        defer_mask = uncertainty > uncertainty_threshold
        predictions = np.where(defer_mask, -1, (probs > 0.5).astype(int))

        return predictions, probs, uncertainty

    def evaluate(self, predictions, probs):
        y_true = self.y_test.cpu().numpy()
        valid_mask = predictions != -1
        y_valid = y_true[valid_mask]
        pred_valid = predictions[valid_mask]

        metrics = {
            "defer_rate": 1 - valid_mask.mean(),
            "auc_roc": roc_auc_score(y_true, probs),
            "auc_prc": average_precision_score(y_true, probs),
            "sensitivity": confusion_matrix(y_valid, pred_valid)[1,1] / y_valid.sum(),
            "specificity": confusion_matrix(y_valid, pred_valid)[0,0] / (len(y_valid)-y_valid.sum()),
            "calibration": calibration_curve(y_true, probs, n_bins=10)
        }

        return metrics

    def shap_analysis(self, background_samples=100):
        background = self.X_test[:background_samples].to(DEVICE)
        test_samples = self.X_test[100:105].to(DEVICE)

        def model_wrapper(x):
            with torch.no_grad(), torch.autocast(device_type=DEVICE.type):
                return self.model(x).cpu().numpy()

        explainer = shap.DeepExplainer(model_wrapper, background)
        shap_values = explainer.shap_values(test_samples)

        return shap_values

## 7. Execution Pipeline

In [31]:
def main(data_path):
    # Initialize components
    trainer = AKITrainer(data_path)
    model, guide, losses = trainer.train()

    # Load test data
    _, _, test_data = trainer.preprocess_data()

    # Evaluate
    evaluator = AKIEvaluator(model, guide, test_data)
    preds, probs, uncertainty = evaluator.predict()
    metrics = evaluator.evaluate(preds, probs)

    # Interpretability
    shap_values = evaluator.shap_analysis()

    return metrics, shap_values

## 8. Run the Full Pipeline

In [32]:
if __name__ == "__main__":
    # Configure data path
    data_path = "AKI-data.csv"  # Update this path

    # Run pipeline
    metrics, shap_values = main("/content/drive/MyDrive/AKI-sample-clean.csv")

    # Display results
    print("Clinical Performance Metrics:")
    print(f"AUC-ROC: {metrics['auc_roc']:.3f}")
    print(f"AUC-PRC: {metrics['auc_prc']:.3f}")
    print(f"Deferral Rate: {metrics['defer_rate']:.1%}")
    print(f"Sensitivity: {metrics['sensitivity']:.1%}")
    print(f"Specificity: {metrics['specificity']:.1%}")

    # Plot calibration
    plt.figure(figsize=(8,8))
    plt.plot(metrics['calibration'][1], metrics['calibration'][0], 's-')
    plt.plot([0,1], [0,1], 'k--')
    plt.title("Model Calibration")
    plt.xlabel("Mean Predicted Probability")
    plt.ylabel("Observed Frequency")
    plt.show()

    # Plot SHAP values
    shap.summary_plot(shap_values, feature_names=trainer.features)

KeyError: 'admission_type'

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
if __name__ == "__main__":
    # Configure data path
    data_path = "/content/drive/MyDrive/AKI-sample-clean.csv"

    # Initialize progress tracking
    progress_bar = widgets.IntProgress(
        value=0,
        min=0,
        max=6,
        description='Pipeline Progress:',
        style={'description_width': 'initial'},
        bar_style='info'
    )
    display(progress_bar)

    # 1. Data Loading & Preprocessing
    print("🔄 [1/6] Loading and preprocessing data...")
    trainer = AKITrainer(data_path)
    (X_train, y_train), (X_val, y_val), (X_test, y_test) = trainer.preprocess_data()
    progress_bar.value += 1

    # 2. Model Training with Live Updates
    print("\n🔥 [2/6] Training model...")
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    plt.ion()  # Enable interactive mode

    def training_callback(epoch, train_loss, val_loss):
        ax1.clear()
        ax1.plot(train_loss, label='Training Loss')
        ax1.plot(val_loss, label='Validation Loss')
        ax1.set_title("Loss Trajectory")
        ax1.legend()

        ax2.clear()
        ax2.plot(np.log10(train_loss), label='Log Training Loss')
        ax2.set_title("Log-Scale Loss")
        ax2.legend()

        plt.pause(0.01)

    model, guide, (train_loss, val_loss) = trainer.train(callback=training_callback)
    progress_bar.value += 1
    plt.ioff()

    # 3. Model Evaluation
    print("\n📊 [3/6] Evaluating model performance...")
    evaluator = AKIEvaluator(model, guide, (X_test, y_test))
    preds, probs, uncertainty = evaluator.predict()
    progress_bar.value += 1

    # 4. Generate Metrics with Enhanced Display
    print("\n📈 [4/6] Calculating clinical metrics...")
    metrics = evaluator.evaluate(preds, probs)

    # Display metrics table
    from IPython.display import HTML
    metrics_html = f"""
    <style>
    .metrics-table {{background: #f8f9fa; padding: 15px; border-radius: 10px;}}
    .metric-value {{color: #2c3e50; font-weight: bold;}}
    </style>
    <div class='metrics-table'>
        <h3>Clinical Performance Report</h3>
        <p>✅ AUC-ROC: <span class='metric-value'>{metrics['auc_roc']:.3f}</span></p>
        <p>📉 AUC-PRC: <span class='metric-value'>{metrics['auc_prc']:.3f}</span></p>
        <p>⏸️ Deferral Rate: <span class='metric-value'>{metrics['defer_rate']:.1%}</span></p>
        <p>🎯 Sensitivity: <span class='metric-value'>{metrics['sensitivity']:.1%}</span></p>
        <p>🛡️ Specificity: <span class='metric-value'>{metrics['specificity']:.1%}</span></p>
    </div>
    """
    display(HTML(metrics_html))
    progress_bar.value += 1

    # 5. Visualizations
    print("\n🎨 [5/6] Generating visualizations...")

    # Calibration Plot
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.plot(metrics['calibration'][1], metrics['calibration'][0], 's-', color='#3498db')
    plt.plot([0,1], [0,1], 'k--', linewidth=1)
    plt.title("Model Calibration", fontsize=12)
    plt.xlabel("Predicted Probability", fontsize=10)
    plt.ylabel("Actual Frequency", fontsize=10)
    plt.grid(alpha=0.2)

    # SHAP Summary Plot
    plt.subplot(1, 2, 2)
    shap.summary_plot(shap_values, feature_names=trainer.features, plot_type='bar',
                     color='#e74c3c', show=False)
    plt.title("Feature Importance (SHAP Values)", fontsize=12)
    plt.tight_layout()
    progress_bar.value += 1

    # 6. Final Output
    print("\n✅ [6/6] Pipeline complete!")
    progress_bar.bar_style = 'success'

    # Show all plots
    plt.show()

    # Add interactive controls
    print("\n🔧 Additional Controls:")
    uncertainty_slider = widgets.FloatSlider(
        value=0.15,
        min=0.05,
        max=0.5,
        step=0.05,
        description='Uncertainty Threshold:',
        style={'description_width': 'initial'}
    )

    def update_threshold(change):
        new_preds = evaluator.predict(change.new)
        new_metrics = evaluator.evaluate(new_preds, probs)
        print(f"\nUpdated Metrics (Threshold={change.new:.2f}):")
        print(f"Sensitivity: {new_metrics['sensitivity']:.1%}")
        print(f"Specificity: {new_metrics['specificity']:.1%}")

    uncertainty_slider.observe(update_threshold, names='value')
    display(uncertainty_slider)

## 9. Save Notebook State

In [None]:
# Save complete notebook
from IPython.display import Javascript
from google.colab import files

def save_notebook():
    display(Javascript('''
    require(["base/js/namespace"], function(IPython) {
        IPython.notebook.save_checkpoint();
    });
    '''))
    !jupyter nbconvert --to notebook /content/bn2.ipynb
    files.download("/content/aki_prediction_full.ipynb")

# Uncomment to save
# save_notebook()