# 1.) Imports

In [None]:
# Mount Google Drive so the project code and data under /content/drive are reachable
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
# 📚 Core Libraries
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random


In [None]:
# 🔥 PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

In [None]:
# 🧠 Add project directory to path
# 🧠 Make the project package (utils/) and the models/ directory (presumably where cnn3d lives) importable
sys.path.extend([
    '/content/drive/MyDrive/BrainAgeRegression',
    '/content/drive/MyDrive/BrainAgeRegression/models',
])

In [None]:
# 🧰 Custom Utilities
from utils.utils import (
    BrainAgeDataset, set_seed, count_parameters,
    split_dataset, normalize_targets, denormalize
)
from utils.train_utils import BrainAgeTrainer
from utils.eval_utils import BrainAgeEvaluator
from utils.save_utils import BrainAgeSave

In [None]:
# Pick up edits made to utils/train_utils.py without restarting the kernel.
# Names bound earlier via `from utils.train_utils import ...` still point at the
# old objects until they are re-imported.
import importlib
import utils.train_utils
importlib.reload(sys.modules["utils.train_utils"])

In [None]:
# Reload both helper modules and rebind the names that were imported with
# `from ... import ...`: a bare importlib.reload() updates the module object,
# but BrainAgeTrainer / BrainAgeEvaluator would otherwise keep referring to the
# stale class objects from the first import.
import importlib
import utils.train_utils
import utils.eval_utils

importlib.reload(utils.eval_utils)
importlib.reload(utils.train_utils)

from utils.train_utils import BrainAgeTrainer
from utils.eval_utils import BrainAgeEvaluator

In [None]:
from utils.train_utils import BrainAgeTrainer

In [None]:
# 🧠 Model
from cnn3d import Medium3DCNN_v3

In [None]:
# Locations of the subject metadata table and the NIfTI volumes on Drive
csv_path = '/content/drive/MyDrive/BrainAgeRegression/matched_metadata.csv'
nifti_dir = '/content/drive/MyDrive/BrainAgeRegression/data/nifti'

# Read the metadata and wrap it (together with the image directory) in the project dataset
df = pd.read_csv(csv_path)
full_dataset = BrainAgeDataset(df, nifti_dir)


# 2.) Setup & Configuration

In [None]:
# Set device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # prefer the GPU when visible

In [None]:
# Add project path
sys.path.append('/content/drive/MyDrive/BrainAgeRegression')

## Paths

In [None]:
# Paths
model_save_path = "/content/drive/MyDrive/BrainAgeRegression/saved_models/Medium3DCNN_v3"
csv_path = "/content/drive/MyDrive/BrainAgeRegression/matched_metadata.csv"
nifti_dir = "/content/drive/MyDrive/BrainAgeRegression/data/nifti"

## Load Dataset

In [None]:
# (Re)build the raw dataset from the configured paths
df = pd.read_csv(filepath_or_buffer=csv_path)
full_dataset = BrainAgeDataset(df, nifti_dir)

## Initialize Model (checkpoint/eval-data loading currently commented out)

In [None]:
# Initialize model and save manager
model = Medium3DCNN_v3(dropout_rate=0.3)
model = model.to(device)

#save_manager = BrainAgeSave(model_save_path)

# Load model weights

#model = save_manager.load_model_weights(model, device=device)

# Load evaluation data

#eval_data = save_manager.load_eval_data()

# Restore variables

#mean_age = eval_data['mean_age']
#std_age = eval_data['std_age']
#train_true_real = eval_data['train_true_real']
#train_pred_real = eval_data['train_pred_real']
#test_true_real = eval_data['test_true_real']
#test_pred_real = eval_data['test_pred_real']
#train_metadata_df = eval_data['train_metadata_df']
#test_metadata_df = eval_data['test_metadata_df']

# 3.) Normalize Data

In [None]:
# Normalize the Age target; the inverse applied later is x * std_age + mean_age.
# NOTE(review): this runs on the full table before the train/val/test split, so
# mean_age / std_age presumably include val/test ages — confirm and consider
# computing them from the training split only.
(df, mean_age, std_age) = normalize_targets(df, target_col='Age')

# Rebuild the dataset so it serves the normalized target
full_dataset = BrainAgeDataset(df, nifti_dir, use_normalized_age=True)

# 4.) Train/Validation/Test Split

In [None]:
# Set seed for reproducibility
set_seed(42)

# Split the dataset
train_dataset, val_dataset, test_dataset = split_dataset(full_dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15)


# 5.) Create the DataLoaders

In [None]:
batch_size = 4  # Adjust based on memory

# Only the training loader shuffles; validation/test batches keep dataset order,
# which keeps their predictions in the same order as each split's indices.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


# 6.) Initialize Model and Components

In [None]:
# Create model with dropout
model = Medium3DCNN_v3(dropout_rate=0.01)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Debug: Check shape after conv layers
sample = next(iter(train_loader))[0]  # [B, C, D, H, W]
x = sample.to(device)
x = model.features(x)
print("Shape after conv layers:", x.shape)
print("Flattened size:", x.view(x.size(0), -1).shape[1])


# 7.) Training Loop

In [None]:
# Train with augmentation, recording per-sample predictions for later analysis
trainer = BrainAgeTrainer(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    augment=True,  # 👈 enable augmentation
)
trainer.train(epochs=20, track_predictions=True)

# Collect the loss history and the tracked (normalized-scale) predictions
history = trainer.get_history()
(train_pred, train_true) = trainer.get_predictions()['train']
(val_pred, val_true) = trainer.get_predictions()['val']


# 8.) Test-Set Evaluation

In [None]:
# Diagnostic: show which module sklearn's mean_squared_error is loaded from, then the function itself
from sklearn.metrics import mean_squared_error
print(mean_squared_error.__module__, mean_squared_error, sep="\n")


In [None]:
# Diagnostic: confirm which module the evaluator's compute_metrics comes from
evaluator = BrainAgeEvaluator(model, device)
print(f"Evaluator compute_metrics points to: {evaluator.compute_metrics.__module__}")


In [None]:
# Diagnostic: print the file backing utils.eval_utils
import utils.eval_utils
print(sys.modules['utils.eval_utils'].__file__)


In [None]:
# Score the held-out test split; the metrics are on the normalized target scale
evaluator = BrainAgeEvaluator(model, device)
metrics, y_pred, y_true = evaluator.evaluate(test_loader, criterion)

# Back to years: since ages are recovered as x * std_age + mean_age, MAE and RMSE
# scale by std_age (the mean cancels in the difference); R² is unaffected.
real_mae, real_rmse = (metrics[key] * std_age for key in ('mae', 'rmse'))

# Report both scales
print(f"📏 Normalized → MAE: {metrics['mae']:.2f} | RMSE: {metrics['rmse']:.2f} | R²: {metrics['r2']:.3f}")
print(f"📏 Real-World → MAE: {real_mae:.2f} years | RMSE: {real_rmse:.2f} years")


# 9.) Evaluate the Model (train vs. validation predictions)

## Match Split Indices to Metadata Rows

In [None]:
# Metadata rows for each split, in the same order as the split's index list
train_metadata_df, val_metadata_df, test_metadata_df = (
    df.iloc[subset.indices].reset_index(drop=True)
    for subset in (train_dataset, val_dataset, test_dataset)
)

In [None]:
# Re-fetch the tracked predictions so the conversion cell below always starts
# from the trainer's values (it rebinds these names to NumPy arrays)
(train_pred, train_true) = trainer.get_predictions()["train"]
(val_pred, val_true) = trainer.get_predictions()["val"]


In [None]:
import numpy as np

# Materialize the tracked predictions/targets as NumPy arrays (normalized scale)
train_pred, train_true, val_pred, val_true = (
    np.array(values) for values in (train_pred, train_true, val_pred, val_true)
)

# Undo the normalization: years = x * std_age + mean_age
train_pred_real, train_true_real = (arr * std_age + mean_age for arr in (train_pred, train_true))

# NOTE(review): despite the names, test_pred_real / test_true_real are built from
# the *validation* split's predictions (val_pred / val_true), so every downstream
# "test" metric and plot using them describes the validation set.
test_pred_real = val_pred * std_age + mean_age
test_true_real = val_true * std_age + mean_age


In [None]:
# 🧠 Initialize Evaluator with Train Metadata and the normalization stats
evaluator = BrainAgeEvaluator(model, device, metadata_df=train_metadata_df, mean=mean_age, std=std_age)


def _print_metrics(title, split_metrics):
    """Print MAE / RMSE / R² from a compute_metrics() result dict under a heading."""
    print(title)
    print(f"  MAE : {split_metrics['mae']:.2f}")
    print(f"  RMSE: {split_metrics['rmse']:.2f}")
    print(f"  R²  : {split_metrics['r2']:.3f}")


# 📊 Compute Metrics on Train Set
train_metrics = evaluator.compute_metrics(train_true_real, train_pred_real)
_print_metrics("📘 Train Set Metrics:", train_metrics)

# 🔄 Switch to Validation Metadata.
# test_true_real / test_pred_real are derived from the validation loader's
# predictions (val_pred / val_true) and are paired with val_metadata_df, so the
# numbers below describe the validation split and are labeled as such. The
# held-out test split is scored separately by evaluator.evaluate(test_loader, ...).
evaluator.metadata = val_metadata_df

# 📊 Compute Metrics on Validation Set (name kept as test_metrics for compatibility)
test_metrics = evaluator.compute_metrics(test_true_real, test_pred_real)
_print_metrics("\n📗 Validation Set Metrics:", test_metrics)

# 📊 Stratified MAE by Demographics (validation split)
print("\n📊 Stratified MAE by Sex (M/F):")
evaluator.stratify_mae(test_true_real, test_pred_real, by='M/F')

print("\n📊 Stratified MAE by Handedness:")
evaluator.stratify_mae(test_true_real, test_pred_real, by='Hand')

# 📊 Stratified MAE by Brain Volume Bins (validation split)
print("\n📊 Stratified MAE by nWBV Bins:")
evaluator.stratify_by_volume_bins(test_true_real, test_pred_real, col='nWBV')

print("\n📊 Stratified MAE by eTIV Bins:")
evaluator.stratify_by_volume_bins(test_true_real, test_pred_real, col='eTIV')



## Plots

In [None]:
# Train vs. Test comparison plot
evaluator.plot_train_vs_test(
    train_true_real, train_pred_real,
    test_true_real, test_pred_real
)

# CDR > 0 plot and CDR = 0
evaluator.metadata = val_metadata_df
evaluator.plot_by_cdr_status(test_true_real, test_pred_real, dataset_label="Test")



In [None]:
# Side-by-side train vs. held-out metric summary.
# NOTE(review): the "test" arrays here are the validation split's predictions.
evaluator.compare_train_test_metrics(
    train_true_real,
    train_pred_real,
    test_true_real,
    test_pred_real,
)


In [None]:
# Save model checkpoint.
# The normalization stats are stored as plain Python floats: if normalize_targets
# returns NumPy scalars, torch.load's restricted weights-only unpickler (the
# default since PyTorch 2.6) rejects them when the checkpoint is reloaded.
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'mean_age': float(mean_age),
    'std_age': float(std_age),
    'train_history': trainer.get_history()
}, "brain_age_model_checkpoint.pt")


In [None]:
import numpy as np

# Persist the tracked predictions and targets (normalized scale) in one .npz archive
np.savez(
    "brain_age_predictions.npz",
    **{
        "train_pred": train_pred,
        "train_true": train_true,
        "val_pred": val_pred,
        "val_true": val_true,
    },
)


In [None]:
# Save metadata DataFrames as CSV
train_metadata_df.to_csv("train_metadata.csv", index=False)
test_metadata_df.to_csv("test_metadata.csv", index=False)


In [None]:
# loading later

In [None]:
# Load model and optimizer.
# weights_only=False: besides tensors, the checkpoint holds the training history
# and normalization stats, which the restricted weights-only unpickler (default in
# recent PyTorch) may reject. Full unpickling can execute arbitrary code, so only
# use this on checkpoints written by this notebook, never on untrusted files.
checkpoint = torch.load("brain_age_model_checkpoint.pt", map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Restore normalization stats and training history
mean_age = checkpoint['mean_age']
std_age = checkpoint['std_age']
history = checkpoint['train_history']

# Load predictions; the context manager closes the .npz file handle once the
# arrays have been read into memory.
with np.load("brain_age_predictions.npz") as data:
    train_pred = data['train_pred']
    train_true = data['train_true']
    val_pred = data['val_pred']
    val_true = data['val_true']

# Load metadata (validation metadata only exists if it was saved)
import pandas as pd
train_metadata_df = pd.read_csv("train_metadata.csv")
test_metadata_df = pd.read_csv("test_metadata.csv")
if os.path.exists("val_metadata.csv"):
    val_metadata_df = pd.read_csv("val_metadata.csv")
