# 1.) Imports

In [None]:
from google.colab import drive

# Mount Google Drive so the project code and data under MyDrive are reachable.
drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

In [None]:
import os
import sys
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

import matplotlib.pyplot as plt

# Add project paths so `utils.*` and `cnn3d` can be imported. Paths already on
# sys.path are skipped, so re-running this cell doesn't keep appending duplicates.
for _project_path in (
    '/content/drive/MyDrive/BrainAgeRegression',
    '/content/drive/MyDrive/BrainAgeRegression/models',
):
    if _project_path not in sys.path:
        sys.path.append(_project_path)

# Custom utilities
from utils.utils import BrainAgeDataset, set_seed, count_parameters, split_dataset
from utils.train_utils import BrainAgeTrainer
from utils.eval_utils import BrainAgeEvaluator
from cnn3d import Medium3DCNN

# 2.) Setup & Configuration

In [None]:
# Input metadata, image folder, and where trained artifacts are written.
csv_path = '/content/drive/MyDrive/BrainAgeRegression/matched_metadata.csv'
nifti_dir = '/content/drive/MyDrive/BrainAgeRegression/data/nifti'
model_save_path = '/content/drive/MyDrive/BrainAgeRegression/saved_models/Medium3DCNN_v2'

# Training hyperparameters.
batch_size, epochs = 4, 20
learning_rate, weight_decay = 1e-3, 1e-5

# Project seeding helper (presumably seeds the Python/NumPy/torch RNGs),
# called before the split and model init in later cells.
set_seed(42)

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

## Load + Normalize Dataset

In [None]:
# Keep the metadata table as `df`: later cells slice it per split and hand it
# to the evaluator.
df = pd.read_csv(filepath_or_buffer=csv_path)
dataset = BrainAgeDataset(df, nifti_dir)

# 3.) Create the Train/Val/Test DataLoaders



In [None]:
# Split into train/val/test subsets (ratios are split_dataset's defaults).
train_dataset, val_dataset, test_dataset = split_dataset(dataset)

# Only the training loader shuffles; val/test use DataLoader's default
# sequential order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(subset, batch_size=batch_size) for subset in (val_dataset, test_dataset)
)

# 4.) Initialize our Model and Components

In [None]:
# 4. 🧠 Initialize the model, loss, optimizer, and LR scheduler.
model = Medium3DCNN().to(device)
count_parameters(model)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Halve the LR once the value passed to scheduler.step() has failed to improve
# for more than `patience`=2 consecutive steps (the trainer presumably steps it
# once per epoch on validation loss; TODO confirm in train_utils).
# `verbose=True` was dropped: it is deprecated since PyTorch 2.2. Use
# scheduler.get_last_lr() to inspect the current learning rate instead.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

In [None]:
# Sanity-check one training batch: input intensity stats, target range, and
# what the (still untrained) model outputs for it.
images, targets = next(iter(train_loader))
images = images.to(device)
targets = targets.to(device)

# Run through the model in eval mode without gradients, then restore whatever
# mode it was in. Otherwise this check leaves the model in eval mode for the
# training cell below, which matters for any Dropout/BatchNorm layers
# Medium3DCNN has.
was_training = model.training
model.eval()
with torch.no_grad():
    outputs = model(images)
model.train(was_training)

# Print diagnostics
print("🧠 Input Image Stats:")
print(f"  Mean: {images.mean().item():.4f}, Std: {images.std().item():.4f}")
print(f"  Min: {images.min().item():.2f}, Max: {images.max().item():.2f}")

print("\n🎯 Target Age Stats:")
print(f"  Shape: {targets.shape}, Min: {targets.min().item()}, Max: {targets.max().item()}")

print("\n🔮 Model Output Stats:")
print(f"  Shape: {outputs.shape}, Min: {outputs.min().item()}, Max: {outputs.max().item()}")
print(f"  Sample predictions: {outputs[:5].cpu().numpy()}")


In [None]:
# Inspect the output layer's bias (assumes `model.classifier` is indexable,
# e.g. an nn.Sequential; the last entry is the final layer).
final_layer = model.classifier[-1]
print("Final layer bias:", final_layer.bias.detach())


# 5.) Training Loop

In [None]:
# Train with per-epoch prediction tracking enabled so that train/val
# predictions are available afterwards.
trainer = BrainAgeTrainer(model, train_loader, val_loader, criterion, optimizer, device, scheduler=scheduler)
trainer.train(epochs=epochs, track_predictions=True)

history = trainer.get_history()
tracked = trainer.get_predictions()
train_pred, train_true = tracked['train']
val_pred, val_true = tracked['val']


# 6.) Plot Training & Validation Loss

In [None]:
# Per-epoch loss curves recorded by the trainer.
for history_key, curve_label in (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')):
    plt.plot(history[history_key], label=curve_label)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training & Validation Loss")
plt.legend()
plt.grid(True)
plt.show()

# 7.) Evaluate Model

In [None]:
# Evaluate on the held-out test split.
# NOTE(review): metadata_df receives the full `df`, but test predictions cover
# only test_dataset's rows. If BrainAgeEvaluator aligns metadata to predictions
# by position, it should receive df.iloc[test_dataset.indices] instead.
# Confirm against eval_utils.
evaluator = BrainAgeEvaluator(model, device, metadata_df=df)
test_results = evaluator.evaluate(test_loader, criterion)
metrics, test_pred, test_true = test_results

summary = f"\n📊 Test Set Metrics:\n  MAE : {metrics['mae']:.2f}\n  RMSE: {metrics['rmse']:.2f}\n  R²  : {metrics['r2']:.3f}"
print(summary)

# 8.) 📊 Compare Train vs. Test

In [None]:
# Compare metrics on the train and test predictions. Kept as the cell's last
# expression so any value it returns is displayed by the notebook.
# NOTE(review): train_pred/train_true come from the trainer's track_predictions
# output, which is presumably collected during training (possibly in train mode
# while weights are still changing). TODO confirm they are comparable to the
# test predictions from evaluator.evaluate.
evaluator.compare_train_test_metrics(train_true, train_pred, test_true, test_pred)

# 9.) 📉 Plot Predictions

In [None]:
# Plot test-set predicted vs. true age, then train vs. test predictions.
# Plot details live in BrainAgeEvaluator (not shown here). The second call is
# the cell's last expression, so any value it returns is displayed.
evaluator.plot_predictions(test_true, test_pred, title="Test Set: Predicted vs. True Age")
evaluator.plot_train_vs_test(train_true, train_pred, test_true, test_pred)

# 10.) Save Results

In [None]:
import pickle

# Metadata rows for each split, re-indexed from 0 so row i pairs with
# prediction i. This assumes each subset's `.indices` refer to rows of `df`
# (i.e. BrainAgeDataset keeps df's row order). TODO confirm.
train_meta = df.iloc[train_dataset.indices].reset_index(drop=True)
val_meta = df.iloc[val_dataset.indices].reset_index(drop=True)
test_meta = df.iloc[test_dataset.indices].reset_index(drop=True)

eval_data = {
    'train_pred': train_pred,
    'train_true': train_true,
    'val_pred': val_pred,
    'val_true': val_true,
    'test_pred': test_pred,
    'test_true': test_true,
    'train_metadata_df': train_meta,
    'val_metadata_df': val_meta,
    'test_metadata_df': test_meta,
}

# Write the model weights (state_dict only) and the pickled evaluation bundle
# into the same directory, creating it if needed.
os.makedirs(model_save_path, exist_ok=True)
weights_file = os.path.join(model_save_path, 'model_weights.pth')
eval_file = os.path.join(model_save_path, 'eval_data.pkl')

torch.save(model.state_dict(), weights_file)
with open(eval_file, 'wb') as fh:
    pickle.dump(eval_data, fh)

print("✅ Model and evaluation data saved.")

## 📈 Evaluation Summary

### 📏 Test-Set Performance
- **MAE**: 6.51 years  
- **R²**: 0.869

---

### 📊 Stratified MAE by Sex (M/F)
- **M**: 5.76 years  
- **F**: 6.91 years

---

### ✋ Stratified MAE by Handedness
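- **Right-handed (R)**: 6.51 years (the only handedness group listed; it matches the overall test MAE, so presumably every subject is right-handed)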

---

### 🧠 Stratified MAE by Normalized Whole Brain Volume (nWBV)
- **High**: 4.03 years  
- **Medium**: 7.55 years  
- **Low**: 7.86 years

---

### 🧠 Stratified MAE by Estimated Total Intracranial Volume (eTIV)
- **High**: 3.34 years  
- **Medium**: 7.77 years  
- **Low**: 8.57 years


**Bin edges used for the Low / Medium / High groups above:**
- nWBV: [0.644, 0.76, 0.832, 0.893]
- eTIV: [1123, 1407, 1542, 1992]
