## Imports

In [None]:
# Mount Google Drive at /content/drive so files stored on Drive are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# 📦 Core Imports
import os
import sys
import json
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
from sklearn.metrics import mean_absolute_error, r2_score

# 🧠 Project root on the mounted Drive
project_root = "/content/drive/MyDrive/BrainAgeRegression"

# 🛣️ Make the project root and its models/ folder importable
sys.path.extend([project_root, os.path.join(project_root, "models")])

# 🛠️ Custom Utilities
from utils.utils import (
    BrainAgeDataset, set_seed, count_parameters,
    split_dataframe, brain_mri_augment, stratified_split
)
from utils.train_utils import (
    BrainAgeTrainer, compute_age_weights, compute_balanced_age_weights
)
from utils.eval_utils import BrainAgeEvaluator
from utils.resnet import ResNet3D
from utils.brain_age_analysis import BrainAgeAnalysis

# ⚙️ Device Setup
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# 🔁 Reproducibility: fix the random seed through the project's set_seed helper.
# NOTE(review): which RNGs it seeds is decided in utils.utils (not visible here).
set_seed(42)

# 💾 Output locations, all rooted at the project folder
results_dir = os.path.join(project_root, "results")
save_dir = os.path.join(project_root, "saved_models", "healthy_model")
metrics_dir = os.path.join(results_dir, "metrics")
plots_dir = os.path.join(results_dir, "plots")

# 📁 Create every output directory up front (no-op when it already exists)
for output_dir in (save_dir, metrics_dir, plots_dir):
    os.makedirs(output_dir, exist_ok=True)



# Data / Model Loading

In [None]:
# 🧠 1. Load Metadata
# Build the path from project_root so that relocating the project only
# requires changing that one variable.
metadata_path = os.path.join(project_root, "data", "matched_metadata.csv")
df = pd.read_csv(metadata_path)
df['CDR'] = pd.to_numeric(df['CDR'], errors='coerce')  # blanks / non-numeric -> NaN

# ✅ 2. Define Healthy and Unhealthy Groups
# Healthy   = CDR == 0.0, OR (CDR missing AND Age < 65)
# Unhealthy = CDR > 0
# Rows with a missing CDR and Age >= 65 (or a missing Age) land in neither group.
cdr_missing = df['CDR'].isna()
healthy_mask = (df['CDR'] == 0.0) | (cdr_missing & (df['Age'] < 65))
healthy_df = df[healthy_mask].copy()
unhealthy_df = df[df['CDR'] > 0].copy()

# 🧾 Summary of how rows with an unknown CDR were handled
num_unknown_total = int(cdr_missing.sum())
num_unknown_used = int(healthy_df['CDR'].isna().sum())
num_unknown_excluded = num_unknown_total - num_unknown_used
# Guard the percentage against a metadata table with no rows.
percent_unknown = (num_unknown_total / len(df)) * 100 if len(df) else 0.0

print(f"🧠 Healthy individuals: {len(healthy_df)}")
print(f"⚠️ Unhealthy individuals: {len(unhealthy_df)}")
print(f"❓ Unknown CDR (excluded): {num_unknown_excluded}")
print(f"📊 Total unknown CDR: {num_unknown_total} ({percent_unknown:.2f}%)")

# 📂 3. Split the healthy cohort into train / val / test with the project's
#       stratified_split helper (bins=8)
healthy_splits = stratified_split(healthy_df, bins=8)
train_df, val_df, test_df = healthy_splits

# 🧪 4. Age weights computed from the training split only; custom_boost holds
#       eight factors, matching bins=8
custom_boost = [1.0, 1.0, 1.6, 2.0, 2.2, 1.6, 1.5, 1.3]
age_weights = compute_balanced_age_weights(
    train_df,
    bins=8,
    custom_boost=custom_boost,
)

# 🧠 5. Build the 3D ResNet, move it to the selected device, report its size
model = ResNet3D(layers=[1, 2, 2])
model = model.to(device)
n_parameters = count_parameters(model)
print(f"🔢 Total trainable parameters: {n_parameters:,}")

# 🧾 6. Dataset & DataLoader Setup
# Derive the scan directory from project_root instead of repeating the
# absolute Drive path.
nifti_dir = os.path.join(project_root, "data", "nifti")

# Only the training dataset is given age_weights.
train_dataset = BrainAgeDataset(train_df, nifti_dir=nifti_dir, transform=None, age_weights=age_weights)
val_dataset = BrainAgeDataset(val_df, nifti_dir=nifti_dir, transform=None)
test_dataset = BrainAgeDataset(test_df, nifti_dir=nifti_dir, transform=None)

# Settings shared by all three loaders. Pinned host memory only speeds up
# host-to-GPU copies, so it is enabled just when running on CUDA (PyTorch
# warns when pin_memory is set without an accelerator).
loader_kwargs = dict(batch_size=4, num_workers=4, pin_memory=(device.type == "cuda"))

# Only the training loader shuffles; val/test keep a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# 📊 7. Age-bin distribution of the training split (for reference):
#       label every row with its quantile bin (q=8, duplicate edges dropped),
#       then print the number of rows per bin in bin order
age_bin_column = pd.qcut(train_df['Age'], q=8, duplicates='drop')
train_df['age_bin'] = age_bin_column
bin_counts = train_df['age_bin'].value_counts().sort_index()
print(bin_counts)


# Loss, Optimizer, Scheduler, Age Weights

In [None]:
# ⚙️ 1. Loss, Optimizer, Scheduler Setup
# Mean-squared-error objective for the age regression
criterion = nn.MSELoss()

# Adam over all model parameters, with weight decay
optimizer = optim.Adam(
    params=model.parameters(),
    lr=1e-4,
    weight_decay=5e-4,
)

# Cosine annealing with warm restarts: first cycle of 10 steps, each
# following cycle twice as long
scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,
    T_mult=2,
)

In [None]:
# Re-execute utils.train_utils from disk so edits made to the module after its
# first import take effect without restarting the runtime. Names previously
# bound with `from utils.train_utils import ...` keep pointing at the old
# objects until they are imported again.
import importlib
import utils.train_utils
importlib.reload(utils.train_utils)

In [None]:
from utils.train_utils import (
    BrainAgeTrainer, compute_age_weights, compute_balanced_age_weights
)

# Train The Model

In [None]:
# 🧠 2. Collect the trainer settings in one place, then build the trainer
trainer_settings = {
    "model": model,
    "train_loader": train_loader,
    "val_loader": val_loader,
    "criterion": criterion,
    "optimizer": optimizer,
    "device": device,
    "scheduler": scheduler,
    "augment": True,                # apply data augmentation
    "use_weighted_loss": True,      # use the age-weighted loss
    "early_stopping_patience": 12,  # stop if validation does not improve
}
trainer = BrainAgeTrainer(**trainer_settings)

# 🚀 3. Run training (40 epochs requested, prediction tracking enabled)
trainer.train(epochs=40, track_predictions=True)


# Evaluate Performance

In [None]:
# 4. Evaluate Performance
# Fetch the tracked predictions once instead of querying the trainer twice.
predictions = trainer.get_predictions()
train_pred, train_true = predictions['train']
val_pred, val_true = predictions['val']

# Preview up to 10 predictions per split. Slicing (instead of indexing with
# range(10)) avoids an IndexError when a split holds fewer than 10 samples.
n_preview = 10
for true_age, pred_age in zip(train_true[:n_preview], train_pred[:n_preview]):
    print(f"[Train] True: {true_age:.1f}, Predicted: {pred_age:.1f}")
for true_age, pred_age in zip(val_true[:n_preview], val_pred[:n_preview]):
    print(f"[Val]   True: {true_age:.1f}, Predicted: {pred_age:.1f}")

# 5. Compute Metrics
def evaluate(y_true, y_pred, label="Set"):
    """Print and return MAE, RMSE and R² for one set of predictions.

    Args:
        y_true: Array-like of ground-truth values.
        y_pred: Array-like of predicted values, one per entry of ``y_true``.
        label: Name of the split; used as the prefix of the printed line.

    Returns:
        dict with the keys ``"mae"``, ``"rmse"`` and ``"r2"``, each holding a
        plain Python float so the result can be passed to ``json.dump``.
    """
    mae = mean_absolute_error(y_true, y_pred)
    # Flatten both inputs before subtracting: a (N,) target against (N, 1)
    # predictions would otherwise broadcast to an (N, N) matrix and silently
    # produce a wrong RMSE. For inputs of matching shape the result is the
    # same as before.
    errors = np.asarray(y_true).ravel() - np.asarray(y_pred).ravel()
    rmse = np.sqrt(np.mean(errors ** 2))
    r2 = r2_score(y_true, y_pred)
    print(f"{label} | MAE: {mae:.2f} | RMSE: {rmse:.2f} | R²: {r2:.3f}")
    return {"mae": float(mae), "rmse": float(rmse), "r2": float(r2)}


# Print the metrics for both splits; the returned dicts are not stored here.
evaluate(train_true, train_pred, "Train")
evaluate(val_true, val_pred, "Validation")

# Plot Loss + Predictions

In [None]:
# Create the results folders relative to the current working directory
# (these are separate from plots_dir / metrics_dir under project_root)
for relative_dir in ("results/plots", "results/metrics"):
    os.makedirs(relative_dir, exist_ok=True)

In [None]:
# Retrieve the training history recorded by the trainer; the plotting cell
# reads its 'train_loss' and 'val_loss' entries.
history = trainer.get_history()

In [None]:
# Loss curves: training and validation loss on one axis, written to plots_dir
fig, ax = plt.subplots()
for history_key, curve_label in (("train_loss", "Train Loss"), ("val_loss", "Val Loss")):
    ax.plot(history[history_key], label=curve_label)
ax.set(xlabel="Epoch", ylabel="MSE Loss", title="Loss Curve")
ax.legend()
ax.grid(True)
fig.tight_layout()
loss_curve_path = os.path.join(plots_dir, "loss_curve.png")
fig.savefig(loss_curve_path)
plt.close(fig)  # close the figure once it has been saved

# Validation scatter of predicted against true age, with the identity line
# (dashed red) as the reference for a perfect prediction
fig, ax = plt.subplots()
ax.scatter(val_true, val_pred, alpha=0.6)
age_lo, age_hi = min(val_true), max(val_true)
ax.plot([age_lo, age_hi], [age_lo, age_hi], 'r--')
ax.set(xlabel="True Age", ylabel="Predicted Age", title="Predicted vs True Age")
ax.grid(True)
fig.tight_layout()
scatter_path = os.path.join(plots_dir, "predicted_vs_true_val.png")
fig.savefig(scatter_path)
plt.close(fig)  # close the figure once it has been saved


# Save Model + Results

In [None]:
# Persist the trained weights (state_dict only) in save_dir
weights_path = os.path.join(save_dir, "resnet3d_brain_age.pth")
torch.save(model.state_dict(), weights_path)

# Store the train/val predictions and targets together in one .npz archive
prediction_arrays = {
    "train_pred": train_pred,
    "train_true": train_true,
    "val_pred": val_pred,
    "val_true": val_true,
}
np.savez(os.path.join(save_dir, "resnet3d_predictions.npz"), **prediction_arrays)

# Store the training history.
# NOTE(review): history is indexed by string keys elsewhere in this notebook,
# so it is presumably a dict; np.save stores non-array objects via pickle, and
# reading the file back then needs np.load(..., allow_pickle=True).
history = trainer.get_history()
history_path = os.path.join(save_dir, "training_history.npy")
np.save(history_path, history)


In [None]:
# Save metrics
train_metrics = evaluate(train_true, train_pred, "Train")
val_metrics = evaluate(val_true, val_pred, "Validation")

# Write into metrics_dir (created during setup, under project_root) rather
# than a cwd-relative "results/metrics" folder: the files then land next to
# the saved plots and do not depend on the working directory or on a separate
# cell having created the folder. evaluate() already returns plain Python
# floats, so no extra conversion is needed before json.dump.
metrics_to_save = {
    "healthy_model_train_metrics.json": train_metrics,
    "healthy_model_val_metrics.json": val_metrics,
}
for metrics_filename, metrics in metrics_to_save.items():
    with open(os.path.join(metrics_dir, metrics_filename), "w") as f:
        json.dump(metrics, f, indent=2)

In [None]:
# Re-execute utils.eval_utils from disk so edits to the module take effect
# without restarting the runtime, then re-bind BrainAgeEvaluator to the
# reloaded class.
import importlib
import utils.eval_utils  # importlib.reload needs the module object
importlib.reload(utils.eval_utils)
from utils.eval_utils import BrainAgeEvaluator


In [None]:
# Evaluate on Test Set
# Rebuild the test dataset/loader (same settings as during setup).
test_dataset = BrainAgeDataset(test_df, nifti_dir=nifti_dir, transform=None)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=4, pin_memory=True)

# Re-create the architecture and restore the saved weights. map_location
# remaps the stored tensors onto the current device, so a checkpoint written
# on a GPU runtime also loads on a CPU-only runtime instead of raising.
model = ResNet3D(layers=[1, 2, 2]).to(device)
checkpoint_path = os.path.join(save_dir, "resnet3d_brain_age.pth")
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()  # switch to inference mode before evaluating

evaluator = BrainAgeEvaluator(model, device)

test_metrics, test_pred, test_true = evaluator.evaluate(test_loader)

# Diagnostic plots for the test set. Each entry pairs an output file name
# with a callable that builds the figure, so figures are created, saved to
# plots_dir and closed one at a time, in the order listed.
diagnostic_plots = [
    ("predicted_vs_true_test.png",
     lambda: evaluator.plot_predictions(test_true, test_pred, title="Predicted vs. True Age (Test Set)")),
    ("prediction_distribution.png",
     lambda: evaluator.plot_prediction_distribution(test_pred)),
    ("residuals.png",
     lambda: evaluator.plot_residuals(test_true, test_pred)),
    ("prediction_bias.png",
     lambda: evaluator.plot_prediction_bias(test_true, test_pred, bins=10, method='qcut')),
    ("stratified_mae.png",
     lambda: evaluator.stratified_mae(test_true, test_pred, bins=10, method='qcut')),
]
for plot_filename, build_figure in diagnostic_plots:
    fig = build_figure()
    fig.savefig(os.path.join(plots_dir, plot_filename))
    plt.close(fig)

# Post-hoc bias correction of the test predictions.
# NOTE(review): the correction is given the test targets themselves; if it is
# fitted on them, the corrected metrics are optimistic — confirm in eval_utils.
corrected_pred, corrected_metrics = evaluator.apply_posthoc_bias_correction(
    test_true,
    test_pred,
)

print("📉 Post-hoc Bias-Corrected Metrics")
for metric_name, metric_value in corrected_metrics.items():
    print(f"{metric_name.upper()}: {metric_value:.2f}")

# Save the test metrics under their own file names. Writing them to the
# healthy_model_train_metrics.json / healthy_model_val_metrics.json names
# would overwrite the train and validation metrics with test-set numbers.
# Files go to metrics_dir (under project_root) so they sit beside the plots
# and do not depend on the current working directory.
test_metric_files = {
    "healthy_model_test_metrics.json": test_metrics,
    "healthy_model_test_corrected_metrics.json": corrected_metrics,
}
for metrics_filename, metrics in test_metric_files.items():
    # float() turns NumPy scalars into JSON-serialisable Python floats; every
    # value is already printed with a float format spec in this notebook.
    serialisable = {name: float(value) for name, value in metrics.items()}
    with open(os.path.join(metrics_dir, metrics_filename), "w") as f:
        json.dump(serialisable, f, indent=2)

# Archive the raw and bias-corrected test predictions together with the
# targets in save_dir
test_predictions_path = os.path.join(save_dir, "resnet3d_test_predictions.npz")
evaluator.save_predictions(
    test_predictions_path,
    test_pred=test_pred,
    test_true=test_true,
    corrected_pred=corrected_pred,
)


In [None]:
# Report the uncorrected test metrics, one "NAME: value" line per metric
print("📊 Raw Test Set Metrics")
for metric_name, metric_value in test_metrics.items():
    print(f"{metric_name.upper()}: {metric_value:.2f}")


In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_posthoc_correction(test_true, test_pred_raw, test_pred_corrected, save_path=None):
    """Scatter raw and bias-corrected predictions against the true ages.

    Args:
        test_true: Array-like of true ages.
        test_pred_raw: Array-like of uncorrected predictions.
        test_pred_corrected: Array-like of bias-corrected predictions.
        save_path: If given, the figure is written to this path at 300 dpi;
            otherwise it is shown with ``plt.show()``.

    Returns:
        None.
    """
    # Accept any array-like (lists included): the identity line below relies
    # on ndarray .min()/.max(), which plain lists do not provide.
    test_true = np.asarray(test_true)
    test_pred_raw = np.asarray(test_pred_raw)
    test_pred_corrected = np.asarray(test_pred_corrected)

    plt.figure(figsize=(8, 6))

    # Raw predictions in gray, corrected predictions on top in blue
    plt.scatter(test_true, test_pred_raw, alpha=0.4, label="Raw Prediction", color="gray")
    plt.scatter(test_true, test_pred_corrected, alpha=0.6, label="Bias-Corrected", color="royalblue")

    # Identity line (prediction == truth). Skipped for empty input, where
    # min()/max() would raise.
    if test_true.size:
        age_min, age_max = test_true.min(), test_true.max()
        plt.plot([age_min, age_max], [age_min, age_max],
                 linestyle='--', color='red', label="Ideal Fit")

    plt.xlabel("True Age")
    plt.ylabel("Predicted Age")
    plt.title("Predicted Age vs. True Age\nBefore and After Post-Hoc Correction")
    plt.legend()
    plt.grid(True)

    if save_path:
        plt.savefig(save_path, dpi=300)
        print(f"✅ Plot saved to: {save_path}")
    else:
        plt.show()




In [None]:
# Show the before/after scatter for the test set (no save_path, so the figure
# is displayed rather than written to disk). The inputs are converted to
# ndarrays so the helper can call .min()/.max() on them.
plot_posthoc_correction(
    test_true=np.array(test_true),
    test_pred_raw=np.array(test_pred),
    test_pred_corrected=np.array(corrected_pred)
)


In [None]:
def plot_residual_comparison(test_true, test_pred_raw, test_pred_corrected, save_path=None):
    """Plot residuals (predicted - true) before and after bias correction.

    Args:
        test_true: Array-like of true ages.
        test_pred_raw: Array-like of uncorrected predictions.
        test_pred_corrected: Array-like of bias-corrected predictions.
        save_path: If given, the figure is written to this path at 300 dpi;
            otherwise it is shown with ``plt.show()``.

    Returns:
        None.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    # Convert to ndarrays so the subtractions below are element-wise even
    # when plain Python lists are passed (list - list raises TypeError).
    test_true = np.asarray(test_true)
    residual_raw = np.asarray(test_pred_raw) - test_true
    residual_corrected = np.asarray(test_pred_corrected) - test_true

    plt.figure(figsize=(8, 6))
    plt.plot(test_true, residual_raw, 'o', alpha=0.4, label="Raw Residuals", color="gray")
    plt.plot(test_true, residual_corrected, 'o', alpha=0.6, label="Corrected Residuals", color="royalblue")
    # Horizontal reference at zero error
    plt.axhline(0, linestyle='--', color='red', label="Zero Error")
    plt.xlabel("True Age")
    plt.ylabel("Residual (Predicted - True)")
    plt.title("Prediction Error Before and After Post-Hoc Correction")
    plt.legend()
    plt.grid(True)

    if save_path:
        plt.savefig(save_path, dpi=300)
        print(f"✅ Plot saved to: {save_path}")
    else:
        plt.show()


In [None]:
# Show the residual comparison for the test set (no save_path, so the figure
# is displayed rather than written to disk). The inputs are converted to
# ndarrays so the residual subtraction is element-wise.
plot_residual_comparison(
    test_true=np.array(test_true),
    test_pred_raw=np.array(test_pred),
    test_pred_corrected=np.array(corrected_pred)
)


In [None]:
# Pull a single batch from the training loader for a quick sanity check
batch = next(iter(train_loader))

# A 3-item batch is unpacked as (images, ages, weights); anything else is
# unpacked as (images, ages)
if len(batch) == 3:
    images, ages, weights = batch
    print("✅ Using weighted loss")
else:
    images, ages = batch
    weights = None
    print("✅ Using unweighted loss")

# Report shapes and intensity range of the batch
image_shape = images.shape
print(f"🧠 Image tensor shape: {image_shape}")        # e.g., [batch_size, 1, D, H, W]
print(f"📏 Voxel dimensions: {image_shape[2:]}")      # everything after the first two axes
print(f"🎯 Age tensor shape: {ages.shape}")
intensity_min, intensity_max = images.min(), images.max()
print(f"🔎 Voxel intensity range: min={intensity_min:.2f}, max={intensity_max:.2f}")