<a href="https://colab.research.google.com/github/kchuri01/Bayesian-Neural-Network-CHD-Risk/blob/main/bnn_framingham.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import math
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, accuracy_score

In [2]:
# Load dataset
DATA_PATH = "/content/framingham.csv.xls"

In [3]:
df = pd.read_csv(DATA_PATH, na_values=["NA"])

print(df.head())
print(df.info())

   male  age  education  currentSmoker  cigsPerDay  BPMeds  prevalentStroke  \
0     1   39        4.0              0         0.0     0.0                0   
1     0   46        2.0              0         0.0     0.0                0   
2     1   48        1.0              1        20.0     0.0                0   
3     0   61        3.0              1        30.0     0.0                0   
4     0   46        3.0              1        23.0     0.0                0   

   prevalentHyp  diabetes  totChol  sysBP  diaBP    BMI  heartRate  glucose  \
0             0         0    195.0  106.0   70.0  26.97       80.0     77.0   
1             0         0    250.0  121.0   81.0  28.73       95.0     76.0   
2             0         0    245.0  127.5   80.0  25.34       75.0     70.0   
3             1         0    225.0  150.0   95.0  28.58       65.0    103.0   
4             0         0    285.0  130.0   84.0  23.10       85.0     85.0   

   TenYearCHD  
0           0  
1           0  
2 

In [4]:
TARGET = "TenYearCHD"
assert TARGET in df.columns, f"Target '{TARGET}' not found. Columns: {df.columns.tolist()}"

In [5]:
# Separate features/target
y = df[TARGET].astype(float).values
X = df.drop(columns=[TARGET])

In [6]:
for col in X.columns:
    if X[col].dtype == "object":
        X[col] = X[col].astype(str).str.lower().map({"male": 1.0, "m": 1.0, "female": 0.0, "f": 0.0})
        # if mapping fails, fall back to factorize
        if X[col].isna().any():
            X[col] = pd.factorize(df[col].astype(str))[0].astype(float)

X = X.astype(float).values

In [7]:
# Train/test split + preprocessing
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

imputer = SimpleImputer(strategy="median")
scaler = StandardScaler()

X_train = imputer.fit_transform(X_train)
X_test = imputer.transform(X_test)

X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

In [8]:
# Convert to torch tensors
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X_train_t = torch.tensor(X_train, dtype=torch.float32).to(device)
y_train_t = torch.tensor(y_train, dtype=torch.float32).to(device)

X_test_t  = torch.tensor(X_test, dtype=torch.float32).to(device)
y_test_t  = torch.tensor(y_test, dtype=torch.float32).to(device)

train_ds = TensorDataset(X_train_t, y_train_t)
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)

In [9]:
# Class imbalance handling
pos_rate = float(y_train.mean())
# pos_weight > 1 increases penalty on positive misclassification
pos_weight = torch.tensor([(1 - pos_rate) / max(pos_rate, 1e-6)], device=device)

In [10]:
# Bayesian Layers (Variational)
class BayesLinear(nn.Module):

    def __init__(self, in_features, out_features, prior_std=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Variational parameters
        self.weight_mu = nn.Parameter(torch.zeros(out_features, in_features))
        self.weight_rho = nn.Parameter(torch.full((out_features, in_features), -5.0))
        self.bias_mu = nn.Parameter(torch.zeros(out_features))
        self.bias_rho = nn.Parameter(torch.full((out_features,), -5.0))

        # Prior parameters
        self.prior_std = prior_std

    def _sigma(self, rho):
        return torch.log1p(torch.exp(rho))

    def kl_divergence(self):

        w_sigma = self._sigma(self.weight_rho)
        b_sigma = self._sigma(self.bias_rho)

        prior_var = self.prior_std ** 2
        w_var = w_sigma ** 2
        b_var = b_sigma ** 2

        # KL for each parameter then sum
        kl_w = 0.5 * torch.sum((w_var + self.weight_mu**2) / prior_var - 1.0 - torch.log(w_var / prior_var))
        kl_b = 0.5 * torch.sum((b_var + self.bias_mu**2) / prior_var - 1.0 - torch.log(b_var / prior_var))
        return kl_w + kl_b

    def forward(self, x, sample=True):
        if sample:
            w_sigma = self._sigma(self.weight_rho)
            b_sigma = self._sigma(self.bias_rho)

            # Reparameterisation
            eps_w = torch.randn_like(self.weight_mu)
            eps_b = torch.randn_like(self.bias_mu)
            weight = self.weight_mu + w_sigma * eps_w
            bias = self.bias_mu + b_sigma * eps_b
        else:
            # Use posterior mean for deterministic pass
            weight = self.weight_mu
            bias = self.bias_mu

        return F.linear(x, weight, bias)

class BayesianMLP(nn.Module):

    def __init__(self, in_dim, hidden_dim=16, prior_std=1.0):
        super().__init__()
        self.b1 = BayesLinear(in_dim, hidden_dim, prior_std=prior_std)
        self.b2 = BayesLinear(hidden_dim, 1, prior_std=prior_std)

    def kl(self):
        return self.b1.kl_divergence() + self.b2.kl_divergence()

    def forward(self, x, sample=True):
        x = self.b1(x, sample=sample)
        x = torch.relu(x)
        logits = self.b2(x, sample=sample).squeeze(-1)
        return logits

In [11]:
# Train with ELBO (NLL + KL)
model = BayesianMLP(in_dim=X_train.shape[1], hidden_dim=16, prior_std=1.0).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Bernoulli likelihood to BCEWithLogitsLoss (numerically stable)
bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

def train_epoch(model, loader, beta=1.0):
    model.train()
    total_loss = 0.0
    total_nll = 0.0
    total_kl = 0.0
    n = 0

    for xb, yb in loader:
        optimizer.zero_grad()

        logits = model(xb, sample=True)
        nll = bce(logits, yb)

        # Scale KL by dataset size
        kl = model.kl() / X_train_t.shape[0]

        loss = nll + beta * kl
        loss.backward()
        optimizer.step()

        bs = xb.size(0)
        total_loss += float(loss.item()) * bs
        total_nll += float(nll.item()) * bs
        total_kl += float(kl.item()) * bs
        n += bs

    return total_loss / n, total_nll / n, total_kl / n

@torch.no_grad()
def mc_predict_proba(model, X, mc_samples=200):

    model.eval()
    probs = []
    for _ in range(mc_samples):
        logits = model(X, sample=True)
        probs.append(torch.sigmoid(logits).unsqueeze(0))  # [1, N]
    probs = torch.cat(probs, dim=0)                       # [S, N]
    return probs       # return all samples to compute uncertainty

# Train
EPOCHS = 80
for epoch in range(1, EPOCHS + 1):
    loss, nll, kl = train_epoch(model, train_loader, beta=1.0)
    if epoch % 10 == 0 or epoch == 1:
        print(f"Epoch {epoch:03d} | Loss {loss:.4f} | NLL {nll:.4f} | KL {kl:.6f}")

Epoch 001 | Loss 1.5365 | NLL 1.1748 | KL 0.361644
Epoch 010 | Loss 1.3690 | NLL 1.0260 | KL 0.343074
Epoch 020 | Loss 1.3340 | NLL 1.0111 | KL 0.322900
Epoch 030 | Loss 1.3122 | NLL 1.0087 | KL 0.303557
Epoch 040 | Loss 1.2910 | NLL 1.0059 | KL 0.285126
Epoch 050 | Loss 1.2723 | NLL 1.0047 | KL 0.267609
Epoch 060 | Loss 1.2524 | NLL 1.0010 | KL 0.251422
Epoch 070 | Loss 1.2370 | NLL 1.0004 | KL 0.236669
Epoch 080 | Loss 1.2222 | NLL 0.9987 | KL 0.223514


In [12]:
# Evaluate (AUC/Accuracy) + Uncertainty
probs_samps = mc_predict_proba(model, X_test_t, mc_samples=300)  # [S, N]
probs_mean = probs_samps.mean(dim=0).cpu().numpy()               # [N]
probs_std  = probs_samps.std(dim=0).cpu().numpy()                # [N]

y_true = y_test_t.cpu().numpy()

auc = roc_auc_score(y_true, probs_mean)
pred = (probs_mean >= 0.5).astype(int)
acc = accuracy_score(y_true, pred)

# 95% credible interval (approx) from MC samples
lower = np.quantile(probs_samps.cpu().numpy(), 0.025, axis=0)
upper = np.quantile(probs_samps.cpu().numpy(), 0.975, axis=0)

print("\n--- Test Performance ---")
print(f"ROC-AUC:  {auc:.4f}")
print(f"Accuracy: {acc:.4f}")

# A few predictions with uncertainty
out = pd.DataFrame({"y_true": y_true.astype(int), "p_mean": probs_mean,
                    "p_std": probs_std, "p_2.5%": lower, "p_97.5%": upper})

print("\n--- Sample Predictions (with uncertainty) ---")
print(out.sort_values("p_std", ascending=False).head(10).to_string(index=False))


--- Test Performance ---
ROC-AUC:  0.6948
Accuracy: 0.6474

--- Sample Predictions (with uncertainty) ---
 y_true   p_mean    p_std   p_2.5%  p_97.5%
      0 0.702786 0.068225 0.579354 0.822898
      0 0.799691 0.053340 0.682819 0.891626
      1 0.638588 0.049528 0.541555 0.750512
      1 0.626004 0.046145 0.531368 0.710149
      0 0.425514 0.045079 0.339676 0.512058
      0 0.281426 0.041884 0.212521 0.377711
      0 0.757855 0.040245 0.674695 0.827601
      1 0.844172 0.040206 0.755036 0.905858
      0 0.568484 0.040086 0.491572 0.649084
      0 0.654470 0.040011 0.584784 0.732826


In [13]:
# uncertain cases
# flag cases where CI crosses 0.5 (model not confident about class)
out["ci_crosses_0.5"] = (out["p_2.5%"] < 0.5) & (out["p_97.5%"] > 0.5)
print("\nUncertain cases (CI crosses 0.5):", int(out["ci_crosses_0.5"].sum()), "out of", len(out))


Uncertain cases (CI crosses 0.5): 97 out of 848
