# 1.) Imports

In [None]:
# Mount Google Drive at /content/drive; later cells read the project files
# from /content/drive/MyDrive/BrainAgeRegression.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Set the PYTORCH_CUDA_ALLOC_CONF environment variable for this process.
# This cell sits before the cell that imports torch.
# NOTE(review): presumably so the allocator option is in place before CUDA is
# first used — confirm against the PyTorch CUDA memory-management docs.
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline


In [None]:
# 🧱 Core PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F


# 🧠 Torchvision for pretrained 3D models
import torchvision
from torchvision import transforms

# 📊 Evaluation
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# 🧪 Scientific stack
import numpy as np
import pandas as pd
import random
import os
import sys

# 📈 Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# 🧰 Misc utilities
from tqdm import tqdm

import nibabel as nib  # For loading .nii or .nii.gz files

In [None]:
# Make the project root and its `models` folder importable from Drive.
sys.path.extend([
    '/content/drive/MyDrive/BrainAgeRegression',
    '/content/drive/MyDrive/BrainAgeRegression/models',
])

In [None]:
# 🔁 Reload the project modules BEFORE pulling names out of them.
# `from module import name` binds whatever objects the module holds at that
# moment, so a reload that runs afterwards leaves the notebook pointing at the
# stale classes/functions until the cell is executed a second time.
import importlib
import utils.utils
import utils.train_utils

# utils.utils goes first so that train_utils, reloaded next, sees the fresh
# helpers if it imports any of them.
importlib.reload(utils.utils)
importlib.reload(utils.train_utils)

from utils.utils import (
    BrainAgeDataset, set_seed, count_parameters,
    split_dataset, normalize_targets, denormalize
)
from utils.train_utils import BrainAgeTrainer
from utils.eval_utils import BrainAgeEvaluator
from utils.save_utils import BrainAgeSave


# 🧠 Dataset & Metadata
# Metadata table (one row per scan) and the folder holding the NIfTI volumes.
csv_path = '/content/drive/MyDrive/BrainAgeRegression/matched_metadata.csv'
df = pd.read_csv(csv_path)
nifti_dir = '/content/drive/MyDrive/BrainAgeRegression/data/nifti'
full_dataset = BrainAgeDataset(df, nifti_dir)

In [None]:
def set_seed(seed=42):
    """Seed every RNG this notebook relies on and pin the cuDNN flags.

    Note: this notebook-level definition replaces the `set_seed` name that an
    earlier cell imported from `utils.utils`.

    Args:
        seed: Integer passed to Python's `random`, NumPy, and PyTorch
            (`torch.manual_seed` and `torch.cuda.manual_seed_all`).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Disable cuDNN benchmarking and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# Seed once with the default so the following cells start from a known state.
set_seed()

In [None]:
# Instantiate torchvision's pretrained r3d_18 video model and swap its final
# fully connected layer for a single-output regression head.
# NOTE(review): `model` is rebuilt in the "Model Load" section further down, so
# this instance only matters to cells that run before that point (the optimizer
# cell in section 4 is one of them).
model = torchvision.models.video.r3d_18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 1)

# 2.) Setup + Config

In [None]:
# Project locations on the mounted Drive
nifti_dir = "/content/drive/MyDrive/BrainAgeRegression/data/nifti"
csv_path = "/content/drive/MyDrive/BrainAgeRegression/matched_metadata.csv"
model_save_path = "/content/drive/MyDrive/BrainAgeRegression/saved_models/r3d18_transfer"

# Keep the project root on the import path
sys.path.append('/content/drive/MyDrive/BrainAgeRegression')

# Use the GPU when one is visible, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Metadata table + dataset built with use_normalized_age=False
df = pd.read_csv(csv_path)
full_dataset = BrainAgeDataset(df, nifti_dir, use_normalized_age=False)


# 4.) Normalize / Train/Test / Dataloaders / Criterion / Optimizer

In [None]:
# Always start from the metadata as stored on disk so this cell is idempotent.
# Feeding an already-normalized `df` back into normalize_targets on a re-run
# would produce statistics of the normalized column, and the later
# `* std_age + mean_age` denormalization would no longer recover real ages.
# NOTE(review): assumes normalize_targets rewrites (or depends on) the 'Age'
# column of the frame it receives — confirm in utils.utils.
df = pd.read_csv(csv_path)
df, mean_age, std_age = normalize_targets(df, target_col='Age')

# Rebuild the dataset on the normalized table; mean_age/std_age are kept for
# mapping predictions back to real units in the evaluation cells.
full_dataset = BrainAgeDataset(df, nifti_dir, use_normalized_age=True)

In [None]:
# Re-seed immediately before splitting (presumably so the partition is
# reproducible — the split helper's use of the RNG is not visible here).
set_seed(42)
train_dataset, val_dataset, test_dataset = split_dataset(
    full_dataset,
    train_ratio=0.7,
    val_ratio=0.15,
    test_ratio=0.15,
)

In [None]:
batch_size = 4

# Only the training loader shuffles; validation and test use the default order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size)

In [None]:
# Mean-squared-error loss for the single-output regression head.
criterion = nn.MSELoss()

# NOTE(review): Adam is bound to whichever `model` object exists when this cell
# runs; at this point that is the instance built in section 1, not the one
# constructed later under "Model Load".
optimizer = optim.Adam(
    params=model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

# 5.) Model Load

In [None]:
from torchvision.models.video import r3d_18

# Pretrained R3D-18 backbone.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favor of
# `weights=`; kept as-is because the installed version is not visible here.
model = r3d_18(pretrained=True)

# Modify first conv layer to accept 1 channel instead of 3.
# The replacement layer is a new nn.Conv3d, so the pretrained stem filters are
# not carried over.
model.stem[0] = nn.Conv3d(1, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)

# Replace final FC layer for regression
model.fc = nn.Linear(model.fc.in_features, 1)

model = model.to(device)

# Rebind the optimizer to THIS model. The optimizer created in section 4 holds
# the parameters of the earlier `model` object, so stepping it would never
# update the network that is handed to the trainer below. Hyperparameters
# mirror the section 4 cell.
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# 6.) Training Loop

In [None]:
# Print the GPU status reported by nvidia-smi for this runtime before training.
!nvidia-smi


In [None]:
# List every parameter tensor of the model together with its requires_grad flag.
for param_name, tensor in model.named_parameters():
    print(param_name, tensor.requires_grad)

In [None]:
# Hand the model, the train/val loaders, the loss, the optimizer and the device
# to the project trainer (augment=False), then run 20 epochs.
# NOTE(review): `track_predictions=True` presumably makes the trainer keep the
# outputs that get_predictions() returns in later cells — confirm in
# utils.train_utils.
trainer = BrainAgeTrainer(
    model, train_loader, val_loader,
    criterion, optimizer, device,
    augment=False
)

trainer.train(epochs=20, track_predictions=True)


In [None]:
# Pull the training history and the (prediction, target) pairs recorded by the
# trainer for the 'train' and 'val' splits.
# NOTE(review): the two prediction lines are repeated verbatim in the
# "Eval Setup" section below.
history = trainer.get_history()
train_pred, train_true = trainer.get_predictions()['train']
val_pred, val_true = trainer.get_predictions()['val']

# 7.) Eval Setup

In [None]:
# Run the model over the test loader with an evaluator that has no metadata or
# normalization statistics attached.
# NOTE(review): `metrics`, `y_pred` and `y_true` are not referenced again in the
# visible cells, and `evaluator` is re-created a few cells below.
evaluator = BrainAgeEvaluator(model, device)
metrics, y_pred, y_true = evaluator.evaluate(test_loader, criterion)

In [None]:
# Slice the metadata table into per-split frames using each split's `.indices`.
# NOTE(review): assumes the split objects expose `.indices` as positions into
# `df` (i.e. dataset order == row order of `df`) — confirm in split_dataset /
# BrainAgeDataset.
train_metadata_df = df.iloc[train_dataset.indices].reset_index(drop=True)
val_metadata_df   = df.iloc[val_dataset.indices].reset_index(drop=True)
test_metadata_df  = df.iloc[test_dataset.indices].reset_index(drop=True)

# Re-fetch the tracked (prediction, target) pairs for the train and val splits.
# NOTE(review): train_loader uses shuffle=True, so the order of the tracked
# train predictions may not match the row order of train_metadata_df — verify
# before stratifying train results by metadata.
train_pred, train_true = trainer.get_predictions()['train']
val_pred, val_true     = trainer.get_predictions()['val']

In [None]:
# Materialize the tracked predictions/targets as NumPy arrays
train_pred, train_true, val_pred, val_true = (
    np.array(values)
    for values in (train_pred, train_true, val_pred, val_true)
)

# Map the values back to real units with the stored statistics:
# real = value * std_age + mean_age
train_pred_real = (train_pred * std_age) + mean_age
train_true_real = (train_true * std_age) + mean_age

val_pred_real = (val_pred * std_age) + mean_age
val_true_real = (val_true * std_age) + mean_age

In [None]:
# Evaluator carrying the training-split metadata and the target statistics.
evaluator = BrainAgeEvaluator(
    model,
    device,
    metadata_df=train_metadata_df,
    mean=mean_age,
    std=std_age,
)

In [None]:
# 📊 Train-set metrics, computed on the denormalized arrays
train_metrics = evaluator.compute_metrics(train_true_real, train_pred_real)
print("📘 Train Set Metrics:")
print(f"  MAE : {train_metrics['mae']:.2f}")
print(f"  RMSE: {train_metrics['rmse']:.2f}")
print(f"  R²  : {train_metrics['r2']:.3f}")

# 🔄 From here on the evaluator works against the validation metadata
evaluator.metadata = val_metadata_df

# 📊 Validation-set metrics, same units as above
val_metrics = evaluator.compute_metrics(val_true_real, val_pred_real)
print("\n📗 Validation Set Metrics:")
print(f"  MAE : {val_metrics['mae']:.2f}")
print(f"  RMSE: {val_metrics['rmse']:.2f}")
print(f"  R²  : {val_metrics['r2']:.3f}")

# 📊 Validation MAE broken down by the categorical metadata columns
for heading, column in (
    ("\n📊 Stratified MAE by Sex (M/F):", 'M/F'),
    ("\n📊 Stratified MAE by Handedness:", 'Hand'),
):
    print(heading)
    evaluator.stratify_mae(val_true_real, val_pred_real, by=column)

# 📊 Validation MAE by bins of the nWBV / eTIV metadata columns
print("\n📊 Stratified MAE by nWBV Bins:")
evaluator.stratify_by_volume_bins(val_true_real, val_pred_real, col='nWBV')

# Kept as the final bare expression so the cell output is unchanged.
print("\n📊 Stratified MAE by eTIV Bins:")
evaluator.stratify_by_volume_bins(val_true_real, val_pred_real, col='eTIV')

In [None]:
# 📈 Plots
# NOTE(review): the evaluator methods are named "train vs test", but the arrays
# passed in the second slot here are the validation-split values.
evaluator.plot_train_vs_test(
    train_true_real, train_pred_real,
    val_true_real, val_pred_real
)

# Point the evaluator at the validation metadata before the CDR-status plot, so
# this cell does not depend on the metadata switch made in the previous cell.
evaluator.metadata = val_metadata_df
evaluator.plot_by_cdr_status(val_true_real, val_pred_real, dataset_label="Validation")

# Side-by-side comparison of train and validation metrics.
evaluator.compare_train_test_metrics(
    train_true_real, train_pred_real,
    val_true_real, val_pred_real
)

In [None]:
# Quick look at the metadata table: column names, then subject id and age for
# the first rows.
# NOTE(review): `df` has been through normalize_targets by this point, so 'Age'
# may hold normalized values rather than years — confirm in utils.utils.
print(df.columns)
print(df[['SubjectID', 'Age']].head())


In [None]:
# Sanity-check one sample: take index 0 on the first axis, then the middle
# index along the second axis, and show the resulting 2-D slice.
# NOTE(review): assumes each sample is a (channel, depth, H, W) CPU tensor —
# confirm in BrainAgeDataset.
img_tensor, age = full_dataset[1]  # Try a different index
mid_slice = img_tensor[0, img_tensor.shape[1] // 2].numpy()

# `plt` comes from the notebook's import cell, so it is not re-imported here.
plt.imshow(mid_slice, cmap='gray')
# NOTE(review): `full_dataset` was last rebuilt with use_normalized_age=True, so
# the value shown may be the normalized target rather than the raw age.
plt.title(f"Age (raw): {age.item():.2f}")
plt.colorbar()
plt.show()



In [None]:
# 🔍 Evaluate on Test Set
evaluator.metadata = test_metadata_df
eval_metrics, test_pred, test_true = evaluator.evaluate(test_loader, criterion)

# Convert to arrays first, as the train/val cell does, so the arithmetic below
# is element-wise even if evaluate() hands back plain Python lists.
test_pred = np.asarray(test_pred)
test_true = np.asarray(test_true)

# Denormalize
# NOTE(review): assumes evaluate() returns predictions/targets in normalized
# units, as the original denormalization step implies — confirm in
# utils.eval_utils.
test_pred_real = test_pred * std_age + mean_age
test_true_real = test_true * std_age + mean_age

# Report the test metrics the same way as train/val: computed from the
# denormalized arrays via compute_metrics. Any extra keys returned by
# evaluate() are kept; 'mae'/'rmse'/'r2' come from the denormalized values.
test_metrics = {
    **eval_metrics,
    **evaluator.compute_metrics(test_true_real, test_pred_real),
}

# Print metrics
print("\n📕 Test Set Metrics:")
print(f"  MAE : {test_metrics['mae']:.2f}")
print(f"  RMSE: {test_metrics['rmse']:.2f}")
print(f"  R²  : {test_metrics['r2']:.3f}")
