# # Full Cross-Validation and Final Training Evaluation Notebook
# 
# This notebook performs graph neural network (GNN)-based classification or regression using a custom Message Passing Neural Network (MPNN) architecture. It supports training with 10-fold cross-validation and then retrains a final model on the combined data using early stopping. Performance is evaluated on a held-out test set.
# 
# ## 🧭 Workflow Overview
# 
# ### Step 1. Imports and Setup
# - Load libraries including PyTorch, PyTorch Geometric, and scikit-learn.
# - Configure plotting and file management tools.
# 
# ### Step 2. Task and Reproducibility Setup
# - Set task: `task = "classification"` or `"regression"`.
# - Fix the random seed so that runs are reproducible.
# - Define class names if applicable and set up result directories.
# 
# ### Step 3. Model Definition
# - Define an MPNN layer and model using PyTorch Geometric.
# - The model takes node and edge features and outputs class scores or scalar predictions.
# 
# ### Step 4. Evaluation Helper Function
# - Implements an `evaluate(model, loader)` function to return model outputs and ground truths.
# 
# ### Step 5. Input Dimensions and Device Setup
# - Determine input feature dimensions from dataset.
# - Detect and assign appropriate device (CPU or GPU).
# 
# ### Step 6. Cross-Validation Training and Evaluation
# - Train the model across 10 folds using pre-split datasets from `../4_train_test_split/10fold_cv/{task}/`.
# - Log metrics per fold:
#   - **Classification**: accuracy, precision, recall, F1, AUC-ROC
#   - **Regression**: MAE, MSE, RMSE, R²
# - Save all results to CSV.
# 
# ### Step 7. Final Model Training and Test Evaluation
# - Merge all training and validation folds.
# - Reserve a small stratified/random validation split.
# - Train model with:
#   - **Early stopping** (based on validation loss)
#   - **Learning rate scheduler**
#   - **Model checkpointing** (saves best model)
# - Evaluate on the held-out test set and save predictions + plots.
# 
# ### Step 8. Cross-Validation Results Visualization
# - Plot fold-wise bar charts for key metrics.
# - Display mean and standard deviation summary.
# 
# ### Step 9. Interpreting Final Model AUC-ROC and Confusion Matrix
# - Provide interpretation of model behavior:
#   - AUC-ROC curve → how well the model separates classes
#   - Confusion matrix → where the model gets confused
# - Identify any class-specific issues 
#
# ---
#
# ## 🛠 Parameters That Can Be Optimized/Tuned
#
# - **Hidden dimension** of MPNN: `hidden_dim`
# - **Learning rate**: `lr` in `Adam` optimizer
# - **Batch size**: affects convergence speed and generalization
# - **Early stopping patience**: number of epochs without improvement
# - **Scheduler factor/patience**: reduce LR when plateauing
# - **Model depth**: number of MPNN layers
# - **Number of epochs**: max training rounds
# - **Loss function**: CrossEntropy vs. MSE depending on task
# - **Stratified vs. random validation split**

# ---
#
# ✅ This notebook can be extended with additional model types (e.g., GAT, GIN), more complex featurization, or hyperparameter search.


## 1. Imports and Setup

# In [39]:
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, roc_auc_score,
    confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc,
    mean_absolute_error, mean_squared_error, r2_score
)
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
import random

## 2. Task and Reproducibility Setup

# In [40]:
# One seed drives Python's `random`, NumPy and PyTorch (CPU and, when a GPU is
# present, every CUDA device).
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Task selection. Set this to "regression" to point the paths below at the
# regression data instead.
task = "classification"
num_classes = 3
# Class index -> human-readable label.
class_names = dict(enumerate(("Low", "Medium", "High")))

# Input folds are read from base_path; every artefact is written to results_dir.
results_dir = f"MPNN_results/{task}/"
base_path = f"../4_train_test_split/10fold_cv/{task}/"
os.makedirs(results_dir, exist_ok=True)

## 3b. Define MPNN Layer and MPNN Model with Dropout Support
# Enhanced MPNN model with dropout layers after each message-passing block to reduce overfitting.

# In [41]:
class MPNNLayer(MessagePassing):
    """One message-passing layer with sum ('add') aggregation.

    Node features and edge features are each linearly projected to
    ``out_channels``. The message for an edge is ``msg_lin(x_j + edge_attr)``
    and the aggregated messages are passed through a ReLU.
    """

    def __init__(self, in_channels, out_channels, edge_dim):
        super().__init__(aggr="add")
        # These attribute names end up in the state_dict keys of saved models.
        self.lin = Linear(in_channels, out_channels)
        self.edge_proj = Linear(edge_dim, out_channels)
        self.msg_lin = Linear(out_channels, out_channels)

    def forward(self, x, edge_index, edge_attr):
        """Project node and edge features, then run one propagation step."""
        node_feats = self.lin(x)
        edge_feats = self.edge_proj(edge_attr)
        return self.propagate(edge_index, x=node_feats, edge_attr=edge_feats)

    def message(self, x_j, edge_attr):
        """Return the per-edge message built from ``x_j`` and ``edge_attr``."""
        combined = x_j + edge_attr
        return self.msg_lin(combined)

    def update(self, aggr_out):
        """Apply the non-linearity to the aggregated messages."""
        activated = F.relu(aggr_out)
        return activated

class MPNN(torch.nn.Module):
    """Two ``MPNNLayer`` blocks, mean pooling per graph, then a linear head.

    Args:
        input_dim: width of the node feature vectors (``data.x``).
        edge_dim: width of the edge feature vectors (``data.edge_attr``).
        hidden_dim: width of both message-passing layers.
        output_dim: width of the final linear layer's output.
        dropout: dropout probability applied after each message-passing block.
    """

    def __init__(self, input_dim, edge_dim, hidden_dim, output_dim, dropout=0.0):
        super().__init__()
        self.dropout = dropout
        self.mp1 = MPNNLayer(input_dim, hidden_dim, edge_dim)
        self.mp2 = MPNNLayer(hidden_dim, hidden_dim, edge_dim)
        self.out = Linear(hidden_dim, output_dim)

    def forward(self, data):
        """Return one output row per graph in ``data``."""
        hidden = data.x
        for layer in (self.mp1, self.mp2):
            hidden = F.relu(layer(hidden, data.edge_index, data.edge_attr))
            # Dropout is only active while the module is in training mode.
            hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        pooled = global_mean_pool(hidden, data.batch)
        return self.out(pooled)


## 4. Evaluation Helper Function

# In [42]:
def evaluate(model, loader):
    """Run ``model`` over every batch of ``loader`` without gradients.

    The model is switched to eval mode and left there. Each batch is moved to
    the module-level ``device`` before the forward pass.

    Returns:
        A tuple ``(outputs, targets)``: the model outputs and the batches'
        ``y`` attributes, each moved to the CPU and concatenated along dim 0.
    """
    model.eval()
    outputs = []
    targets = []
    with torch.no_grad():
        for graph_batch in loader:
            graph_batch = graph_batch.to(device)
            outputs.append(model(graph_batch).cpu())
            targets.append(graph_batch.y.cpu())
    return torch.cat(outputs), torch.cat(targets)

## 5. Input Dimensions and Device Setup

# In [44]:
# Read the feature widths from the first graph of the fold-0 training file.
sample_path = os.path.join(base_path, f"{task}_train_fold0.pt")
sample_data = torch.load(sample_path)[0]
input_dim = sample_data.x.shape[1]
edge_dim = sample_data.edge_attr.shape[1]

# Classification emits one score per class; anything else emits a single value.
if task == "classification":
    output_dim = num_classes
else:
    output_dim = 1

# Use the GPU when one is visible to PyTorch, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# # Steps 3–7: Hyperparameter Tuning, Retraining, and Evaluation
# This notebook covers the core stages in training and evaluating a GNN model using 10-fold CV and ensemble averaging.


## Step 3: Cross-Validation Training with Hyperparameter Sweep


# ## Step 4: Select Best Hyperparameters
# Use this section to manually define the best hyperparameters based on the sweep above.


# In [None]:
# Grid of hyperparameters to sweep. Every combination is trained for 50 epochs
# on each of the 10 folds and scored by one-vs-rest ROC AUC on that fold's
# validation set.
hidden_dims = [64, 128, 256]
dropouts    = [0.0, 0.2, 0.4]
lrs         = [1e-3, 5e-4, 1e-4]
results = []

# Dedicated seeded generator for the shuffling order. The helpers
# `seed_worker` / `generator` are only defined further down the file, so
# referring to them here raises NameError on a top-to-bottom run. No
# worker_init_fn is needed: the loaders use the default num_workers=0, so no
# worker processes are spawned.
sweep_generator = torch.Generator()
sweep_generator.manual_seed(seed)

for hd in hidden_dims:
    for dp in dropouts:
        for lr in lrs:
            print(f"\nüîß Config: hidden_dim={hd}, dropout={dp}, lr={lr}")
            fold_scores = []
            for fold in range(10):
                # Load fold data
                train_data = torch.load(os.path.join(base_path, f"{task}_train_fold{fold}.pt"))
                val_data   = torch.load(os.path.join(base_path, f"{task}_val_fold{fold}.pt"))

                # A fresh model and optimizer for every (config, fold) pair.
                model = MPNN(train_data[0].x.size(1),
                             train_data[0].edge_attr.size(1),
                             hidden_dim=hd,
                             output_dim=num_classes,
                             dropout=dp).to(device)
                opt = torch.optim.Adam(model.parameters(), lr=lr)
                tr = DataLoader(train_data, batch_size=32, shuffle=True,
                                generator=sweep_generator)
                vl = DataLoader(val_data, batch_size=32)

                # Train
                for epoch in range(1, 51):
                    model.train()
                    for batch in tr:
                        batch = batch.to(device)
                        opt.zero_grad()
                        out = model(batch)
                        loss = F.cross_entropy(out, batch.y.long())
                        loss.backward()
                        opt.step()

                # Eval
                preds, labels = evaluate(model, vl)
                y_true = labels.numpy().astype(int)
                y_probs = F.softmax(preds, dim=1).numpy()
                # Named `fold_auc`, not `auc`, so that the `auc` function
                # imported from sklearn.metrics is not rebound to a float.
                fold_auc = roc_auc_score(
                    label_binarize(y_true, classes=np.arange(num_classes)),
                    y_probs, multi_class='ovr')
                fold_scores.append(fold_auc)

            mean_auc, std_auc = np.mean(fold_scores), np.std(fold_scores)
            results.append((hd, dp, lr, mean_auc, std_auc))
            print(f"üìä AUC: {mean_auc:.4f} ¬± {std_auc:.4f}")

sweep_df = pd.DataFrame(results,
                        columns=["hidden_dim","dropout","lr","mean_auc","std_auc"])
# `display` is provided by the IPython/Jupyter runtime.
display(sweep_df.sort_values("mean_auc", ascending=False))


# In [45]:

# %%
# Hand-picked hyperparameters used for every retraining step that follows.
best_hidden_dim, best_dropout, best_lr = 128, 0.0, 0.0005

# ## Step 5a: Retrain All Folds with Best Hyperparameters

# In [None]:
# %%
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
import pandas as pd
# Retrain one model per fold with the selected hyperparameters, using early
# stopping on validation loss, and record validation metrics for the
# checkpoint that is actually saved to disk.
fold_metrics = []

# Dedicated seeded generator for shuffling. `seed_worker` / `generator` are
# only defined further down the file, so using them here raises NameError on a
# top-to-bottom run. The loaders use the default num_workers=0, so no
# worker_init_fn is needed.
retrain_generator = torch.Generator()
retrain_generator.manual_seed(seed)

for fold in range(10):
    print(f"\nüîÅ Retraining Fold {fold+1}/10")
    train_data = torch.load(os.path.join(base_path, f"{task}_train_fold{fold}.pt"))
    val_data   = torch.load(os.path.join(base_path, f"{task}_val_fold{fold}.pt"))
    ckpt_path = os.path.join(results_dir, f"fold{fold+1}_model.pt")

    model = MPNN(train_data[0].x.size(1),
                 train_data[0].edge_attr.size(1),
                 hidden_dim=best_hidden_dim,
                 output_dim=num_classes,
                 dropout=best_dropout).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=best_lr)
    tr = DataLoader(train_data, batch_size=32, shuffle=True,
                    generator=retrain_generator)
    vl = DataLoader(val_data, batch_size=32)

    best_val_loss = float('inf')
    patience = 0
    for epoch in range(1, 101):
        model.train()
        for batch in tr:
            batch = batch.to(device)
            opt.zero_grad()
            out = model(batch)
            loss = F.cross_entropy(out, batch.y.long())
            loss.backward()
            opt.step()

        preds, labels = evaluate(model, vl)
        val_loss = F.cross_entropy(preds, labels.long()).item()
        if val_loss < best_val_loss:
            # New best: reset the patience counter and checkpoint the weights.
            best_val_loss = val_loss
            patience = 0
            torch.save(model.state_dict(), ckpt_path)
        else:
            patience += 1
            if patience >= 10:
                break

    # Metrics
    # When early stopping fires, the in-memory model is 10 epochs past the
    # best one. Reload the saved checkpoint so the reported metrics describe
    # the same weights that the ensemble step later loads from disk.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    preds, labels = evaluate(model, vl)
    y_true = labels.numpy().astype(int)
    y_probs = F.softmax(preds, dim=1).numpy()
    y_pred = preds.argmax(dim=1).numpy()

    acc = accuracy_score(y_true, y_pred)
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')
    # Named `fold_auc`, not `auc`, so the sklearn `auc` function is not shadowed.
    fold_auc = roc_auc_score(label_binarize(y_true, classes=np.arange(num_classes)), y_probs, multi_class='ovr')
    fold_metrics.append({"fold":fold+1,"accuracy":acc,"precision":precision,"recall":recall,"f1_score":f1,"auc_roc":fold_auc})

# Save CV summary
cv_df = pd.DataFrame(fold_metrics)
cv_df.to_csv(os.path.join(results_dir, "crossval_summary.csv"), index=False)
print("‚úÖ Saved CV summary")


# ## Step 6: Visualize Cross-Validation Results


# In [None]:

# %%
# Reload the per-fold summary from disk and draw one bar chart per metric,
# side by side, with the fold number on the x-axis.
cv_df = pd.read_csv(os.path.join(results_dir, "crossval_summary.csv"))
metrics = ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc']
fig, axs = plt.subplots(1, len(metrics), figsize=(20, 4))
for ax, m in zip(axs, metrics):
    ax.bar(cv_df['fold'], cv_df[m])
    ax.set_title(m.upper())
    ax.set_xlabel('Fold')
    ax.set_ylabel(m)
plt.tight_layout()
plt.show()

# ## Step 7: Ensemble Averaging from 10 CV Models


# In [None]:

# %%
# Run each of the 10 saved fold models over the test set and average their
# raw outputs (logits) element-wise; the ensemble prediction is the argmax of
# that average.
ess_preds = []
test_data = torch.load(os.path.join(base_path, f"{task}_test.pt"))
tl = DataLoader(test_data, batch_size=32)
for fold in range(10):
    model = MPNN(
        test_data[0].x.size(1),
        test_data[0].edge_attr.size(1),
        hidden_dim=best_hidden_dim,
        output_dim=num_classes,
        dropout=best_dropout,
    ).to(device)
    fold_weights = torch.load(os.path.join(results_dir, f"fold{fold+1}_model.pt"))
    model.load_state_dict(fold_weights)
    model.eval()
    with torch.no_grad():
        outs = [model(b.to(device)).cpu() for b in tl]
    ess_preds.append(torch.cat(outs, 0))

# Shape after stacking: (fold, sample, class); the mean collapses the fold axis.
avg = torch.stack(ess_preds).mean(0)
f_pred = avg.argmax(1).numpy()
t_true = torch.cat([d.y for d in test_data]).numpy().astype(int)

ensemble_table = pd.DataFrame({'True': t_true, 'Pred': f_pred})
ensemble_table.to_csv(os.path.join(results_dir, 'ensemble_preds.csv'), index=False)
print('‚úÖ Ensemble preds saved')

# ## Step 7b: Ensemble Model Evaluation – Confusion Matrix & AUC-ROC


# In [None]:

# %%
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt

# Load ensemble predictions
ens_df = pd.read_csv(os.path.join(results_dir, 'ensemble_preds.csv'))
y_true_ens = ens_df['True'].values
y_pred_ens = ens_df['Pred'].values

# 1) Confusion matrix
# `labels=` pins the matrix to num_classes x num_classes even when a class is
# missing from both the true and predicted labels, so it always matches the
# display labels below.
cm = confusion_matrix(y_true_ens, y_pred_ens, labels=np.arange(num_classes))
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=[class_names[i] for i in range(num_classes)])
disp.plot(cmap='Blues')
plt.title("Ensemble Model ‚Äì Confusion Matrix")
plt.tight_layout()
plt.show()

# 2) AUC-ROC per class
# ROC curves need class probabilities, which are not stored in the CSV. They
# are taken from `avg`, the fold-averaged logits left in memory by the
# ensemble-averaging cell (that cell must have been run in this session). The
# name `avg_output` used here previously is never defined anywhere.
y_probs = avg.softmax(dim=1).numpy()
y_true_bin = label_binarize(y_true_ens, classes=np.arange(num_classes))

plt.figure(figsize=(7,5))
for i in range(num_classes):
    # One-vs-rest curve: class i against all the others.
    fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_probs[:, i])
    plt.plot(fpr, tpr, label=f"{class_names[i]} (AUC = {auc(fpr, tpr):.2f})")
plt.plot([0,1], [0,1], 'k--')  # chance-level diagonal
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Ensemble Model ‚Äì AUC‚ÄëROC by Class")
plt.legend()
plt.tight_layout()
plt.show()


# ## Step 7: Final Model Training on Combined Data & Test Evaluation


# In [None]:
# %% [markdown]
# # GNN Classification Pipeline: 10-Fold CV, Ensemble & Final Model
#
# This notebook trains a Message Passing Neural Network (MPNN) for a 3-class classification task
# (Low/Medium/High) using:
# 1. 10-Fold Cross-Validation with hyperparameter tuning
# 2. Retraining folds with best hyperparameters + early stopping
# 3. Visualizing CV metrics
# 4. Ensemble averaging of fold models
# 5. Final model training on combined data + hold-out test evaluation

# %%
# --- Reproducibility & Imports ---
import os, random
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool
from sklearn.metrics import (accuracy_score, precision_recall_fscore_support,
                             roc_auc_score, confusion_matrix, roc_curve)
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt

# Seed all RNGs: PyTorch, NumPy, then Python's `random`.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    # Request deterministic cuDNN kernels and turn off its auto-tuner.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# DataLoader seeding
from torch.utils.data import DataLoader as _DL
from torch.utils.data import get_worker_info

def seed_worker(worker_id):
    """Seed NumPy and ``random`` with the global ``seed`` plus ``worker_id``.

    Passed to ``DataLoader`` as ``worker_init_fn`` so that each loader worker
    gets its own fixed seed.
    """
    offset_seed = worker_id + seed
    for reseed in (np.random.seed, random.seed):
        reseed(offset_seed)

# Seeded generator handed to the shuffling DataLoaders.
# `manual_seed` returns the generator itself, so the calls can be chained.
generator = torch.Generator().manual_seed(seed)

# Task configuration: a 3-class problem with these index -> label names.
task = "classification"
num_classes = 3
class_names = dict(enumerate(("Low", "Medium", "High")))

# Fold files are read from base_path; models, CSVs and predictions are
# written under results_dir, which is created if missing.
results_dir = f"MPNN_results/{task}/"
base_path  = f"../4_train_test_split/10fold_cv/{task}/"
os.makedirs(results_dir, exist_ok=True)

# Use the GPU when one is visible to PyTorch, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# %% [markdown]
# ## Step 1: Model Definition

# %%
class MPNNLayer(MessagePassing):
    """Message-passing layer with sum aggregation and built-in dropout.

    Node and edge features are each projected to ``out_channels``. The
    per-edge message is ``msg_lin(x_j + edge_attr)``. The aggregated result
    goes through dropout and then a ReLU.
    """

    def __init__(self, in_channels, out_channels, edge_dim, dropout=0.0):
        super().__init__(aggr="add")
        # These attribute names end up in the state_dict keys of saved models.
        self.lin = Linear(in_channels, out_channels)
        self.edge_proj = Linear(edge_dim, out_channels)
        self.msg_lin = Linear(out_channels, out_channels)
        self.dropout = dropout

    def forward(self, x, edge_index, edge_attr):
        """Project inputs, propagate once, then apply dropout and ReLU."""
        node_feats = self.lin(x)
        edge_feats = self.edge_proj(edge_attr)
        aggregated = self.propagate(edge_index, x=node_feats, edge_attr=edge_feats)
        # Dropout is only active while the module is in training mode.
        dropped = F.dropout(aggregated, p=self.dropout, training=self.training)
        return F.relu(dropped)

    def message(self, x_j, edge_attr):
        """Return the per-edge message built from ``x_j`` and ``edge_attr``."""
        combined = x_j + edge_attr
        return self.msg_lin(combined)

class MPNN(torch.nn.Module):
    """Two ``MPNNLayer`` blocks, mean pooling per graph, then a linear head.

    Args:
        input_dim: width of the node feature vectors (``data.x``).
        edge_dim: width of the edge feature vectors (``data.edge_attr``).
        hidden_dim: width of both message-passing layers.
        output_dim: width of the final linear layer's output.
        dropout: dropout probability forwarded to both message-passing layers.
    """

    def __init__(self, input_dim, edge_dim, hidden_dim, output_dim, dropout=0.0):
        super().__init__()
        self.mp1 = MPNNLayer(input_dim, hidden_dim, edge_dim, dropout)
        self.mp2 = MPNNLayer(hidden_dim, hidden_dim, edge_dim, dropout)
        self.out = Linear(hidden_dim, output_dim)

    def forward(self, data):
        """Return one output row per graph in ``data``."""
        hidden = data.x
        for layer in (self.mp1, self.mp2):
            hidden = layer(hidden, data.edge_index, data.edge_attr)
        pooled = global_mean_pool(hidden, data.batch)
        return self.out(pooled)

# %% [markdown]
# ## Step 2: Evaluation Function

# %%
def evaluate(model, loader):
    """Collect model outputs and targets for every batch in ``loader``.

    Puts ``model`` into eval mode (and leaves it there), disables gradient
    tracking, and moves each batch to the module-level ``device``.

    Returns:
        ``(outputs, targets)``: CPU tensors concatenated along dim 0. Outputs
        come from ``model(batch)`` and targets from ``batch.y``.
    """
    model.eval()
    all_outputs, all_targets = [], []
    with torch.no_grad():
        for item in loader:
            on_device = item.to(device)
            scores = model(on_device)
            all_outputs.append(scores.cpu())
            all_targets.append(on_device.y.cpu())
    outputs = torch.cat(all_outputs)
    targets = torch.cat(all_targets)
    return outputs, targets

# %% [markdown]
# ## Step 3: Hyperparameter Sweep (10-Fold CV)

# %%
# Grid of hyperparameters to sweep. Every combination is trained for 50 epochs
# on each of the 10 folds and scored by one-vs-rest ROC AUC on that fold's
# validation set.
hidden_dims = [64, 128, 256]
dropouts    = [0.0, 0.2, 0.4]
lrs         = [1e-3, 5e-4]
results = []

for hd in hidden_dims:
    for dp in dropouts:
        for lr in lrs:
            print(f"\nüîß Config: hidden_dim={hd}, dropout={dp}, lr={lr}")
            fold_scores = []
            for fold in range(10):
                # Load fold data
                train_data = torch.load(os.path.join(base_path, f"{task}_train_fold{fold}.pt"))
                val_data   = torch.load(os.path.join(base_path, f"{task}_val_fold{fold}.pt"))

                # A fresh model and optimizer for every (config, fold) pair.
                model = MPNN(train_data[0].x.size(1),
                             train_data[0].edge_attr.size(1),
                             hidden_dim=hd,
                             output_dim=num_classes,
                             dropout=dp).to(device)
                opt = torch.optim.Adam(model.parameters(), lr=lr)
                tr = DataLoader(train_data, batch_size=32, shuffle=True,
                                worker_init_fn=seed_worker, generator=generator)
                vl = DataLoader(val_data, batch_size=32)

                # Train
                for epoch in range(1, 51):
                    model.train()
                    for batch in tr:
                        batch = batch.to(device)
                        opt.zero_grad()
                        out = model(batch)
                        loss = F.cross_entropy(out, batch.y.long())
                        loss.backward()
                        opt.step()

                # Eval
                preds, labels = evaluate(model, vl)
                y_true = labels.numpy().astype(int)
                y_probs = F.softmax(preds, dim=1).numpy()
                # Named `fold_auc`, not `auc`: elsewhere in this file `auc` is
                # the function imported from sklearn.metrics, and rebinding it
                # to a float breaks later `auc(fpr, tpr)` calls.
                fold_auc = roc_auc_score(
                    label_binarize(y_true, classes=np.arange(num_classes)),
                    y_probs, multi_class='ovr')
                fold_scores.append(fold_auc)

            mean_auc, std_auc = np.mean(fold_scores), np.std(fold_scores)
            results.append((hd, dp, lr, mean_auc, std_auc))
            print(f"üìä AUC: {mean_auc:.4f} ¬± {std_auc:.4f}")

sweep_df = pd.DataFrame(results,
                        columns=["hidden_dim","dropout","lr","mean_auc","std_auc"])
# `display` is provided by the IPython/Jupyter runtime.
display(sweep_df.sort_values("mean_auc", ascending=False))

# %% [markdown]
# ## Step 4: Select Best Hyperparameters
# Manually set these based on the sweep above.

# %%
# Hand-picked hyperparameters used for every retraining step that follows.
best_hidden_dim, best_dropout, best_lr = 128, 0.2, 0.001

# %% [markdown]
# ## Step 5: Retrain & Evaluate 10-Fold CV with Best Hyperparameters

# %%
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
import pandas as pd
# Retrain one model per fold with the selected hyperparameters, using early
# stopping on validation loss, and record validation metrics for the
# checkpoint that is actually saved to disk.
fold_metrics = []

for fold in range(10):
    print(f"\nüîÅ Retraining Fold {fold+1}/10")
    train_data = torch.load(os.path.join(base_path, f"{task}_train_fold{fold}.pt"))
    val_data   = torch.load(os.path.join(base_path, f"{task}_val_fold{fold}.pt"))
    ckpt_path = os.path.join(results_dir, f"fold{fold+1}_model.pt")

    model = MPNN(train_data[0].x.size(1),
                 train_data[0].edge_attr.size(1),
                 hidden_dim=best_hidden_dim,
                 output_dim=num_classes,
                 dropout=best_dropout).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=best_lr)
    tr = DataLoader(train_data, batch_size=32, shuffle=True,
                    worker_init_fn=seed_worker, generator=generator)
    vl = DataLoader(val_data, batch_size=32)

    best_val_loss = float('inf')
    patience = 0
    for epoch in range(1, 101):
        model.train()
        for batch in tr:
            batch = batch.to(device)
            opt.zero_grad()
            out = model(batch)
            loss = F.cross_entropy(out, batch.y.long())
            loss.backward()
            opt.step()

        preds, labels = evaluate(model, vl)
        val_loss = F.cross_entropy(preds, labels.long()).item()
        if val_loss < best_val_loss:
            # New best: reset the patience counter and checkpoint the weights.
            best_val_loss = val_loss
            patience = 0
            torch.save(model.state_dict(), ckpt_path)
        else:
            patience += 1
            if patience >= 10:
                break

    # Metrics
    # When early stopping fires, the in-memory model is 10 epochs past the
    # best one. Reload the saved checkpoint so the reported metrics describe
    # the same weights that the ensemble step later loads from disk.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    preds, labels = evaluate(model, vl)
    y_true = labels.numpy().astype(int)
    y_probs = F.softmax(preds, dim=1).numpy()
    y_pred = preds.argmax(dim=1).numpy()

    acc = accuracy_score(y_true, y_pred)
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')
    # Named `fold_auc`, not `auc`, so the sklearn `auc` function is not shadowed.
    fold_auc = roc_auc_score(label_binarize(y_true, classes=np.arange(num_classes)), y_probs, multi_class='ovr')
    fold_metrics.append({"fold":fold+1,"accuracy":acc,"precision":precision,"recall":recall,"f1_score":f1,"auc_roc":fold_auc})

# Save CV summary
cv_df = pd.DataFrame(fold_metrics)
cv_df.to_csv(os.path.join(results_dir, "crossval_summary.csv"), index=False)
print("‚úÖ Saved CV summary")

# %% [markdown]
# ## Step 6: Visualize Cross-Validation Results

# %%
# Reload the per-fold summary from disk and draw one bar chart per metric,
# side by side, with the fold number on the x-axis.
cv_df = pd.read_csv(os.path.join(results_dir, "crossval_summary.csv"))
metrics = ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc']
fig, axs = plt.subplots(1, len(metrics), figsize=(20, 4))
for ax, m in zip(axs, metrics):
    ax.bar(cv_df['fold'], cv_df[m])
    ax.set_title(m.upper())
    ax.set_xlabel('Fold')
    ax.set_ylabel(m)
plt.tight_layout()
plt.show()

# %% [markdown]
# ## Step 7: Ensemble Averaging from 10 CV Models

# %%
# Run each of the 10 saved fold models over the test set and average their
# raw outputs (logits) element-wise; the ensemble prediction is the argmax of
# that average.
ess_preds = []
test_data = torch.load(os.path.join(base_path, f"{task}_test.pt"))
tl = DataLoader(test_data, batch_size=32)
node_dim = test_data[0].x.size(1)
bond_dim = test_data[0].edge_attr.size(1)
for fold in range(10):
    model = MPNN(node_dim, bond_dim, hidden_dim=best_hidden_dim,
                 output_dim=num_classes, dropout=best_dropout).to(device)
    model.load_state_dict(torch.load(os.path.join(results_dir, f"fold{fold+1}_model.pt")))
    # `evaluate` switches to eval mode, disables gradients and returns the
    # concatenated CPU outputs; the targets it also returns are not needed.
    fold_outputs, _ = evaluate(model, tl)
    ess_preds.append(fold_outputs)

# Shape after stacking: (fold, sample, class); the mean collapses the fold axis.
avg = torch.stack(ess_preds).mean(0)
f_pred = avg.argmax(1).numpy()
t_true = torch.cat([d.y for d in test_data]).numpy().astype(int)

ensemble_table = pd.DataFrame({'True': t_true, 'Pred': f_pred})
ensemble_table.to_csv(os.path.join(results_dir, 'ensemble_preds.csv'), index=False)
print('‚úÖ Ensemble preds saved')

# %% [markdown]
# ## Step 8: Final Model Training on Combined Data & Test Evaluation

# %%
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
# Merge train+val
# Only fold 0 is merged. Concatenating the train and val files of all 10 folds
# puts each sample in the pool 10 times, so the "held-out" validation split
# below would contain copies of training graphs and early stopping would be
# driven by leaked data.
# NOTE(review): this assumes the fold files are a standard k-fold partition of
# one development set, i.e. train_fold0 + val_fold0 already covers every
# non-test sample exactly once. Confirm against the split script.
all_data = []
all_data += torch.load(os.path.join(base_path, f"{task}_train_fold0.pt"))
all_data += torch.load(os.path.join(base_path, f"{task}_val_fold0.pt"))

# small val
# Stratified 90/10 split so the early-stopping set keeps the class balance.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=seed)
strat_labels = [int(d.y.item()) for d in all_data]
train_idx, val_idx = next(sss.split(all_data, strat_labels))
train_split = [all_data[i] for i in train_idx]
val_split = [all_data[i] for i in val_idx]
tr = DataLoader(train_split, batch_size=32, shuffle=True,
                worker_init_fn=seed_worker, generator=generator)
vl = DataLoader(val_split, batch_size=32)

model = MPNN(all_data[0].x.size(1), all_data[0].edge_attr.size(1),
             hidden_dim=best_hidden_dim, output_dim=num_classes,
             dropout=best_dropout).to(device)
opt = torch.optim.Adam(model.parameters(), lr=best_lr)
# `verbose=True` is deprecated in recent PyTorch releases, so learning-rate
# reductions are reported by hand inside the loop instead.
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', patience=5, factor=0.5)
final_ckpt = os.path.join(results_dir, 'final_model.pt')

best_v = 1e9
pat = 0
for epoch in range(1, 101):
    model.train()
    for b in tr:
        b = b.to(device)
        opt.zero_grad()
        l = F.cross_entropy(model(b), b.y.long())
        l.backward()
        opt.step()

    preds, labels = evaluate(model, vl)
    vloss = F.cross_entropy(preds, labels.long()).item()

    lr_before = opt.param_groups[0]['lr']
    sched.step(vloss)
    lr_after = opt.param_groups[0]['lr']
    if lr_after < lr_before:
        print(f"Epoch {epoch}: reducing learning rate to {lr_after:.3e}")

    if vloss < best_v:
        # New best validation loss: checkpoint and reset patience.
        best_v = vloss
        pat = 0
        torch.save(model.state_dict(), final_ckpt)
    else:
        pat += 1
    if pat >= 10:
        break

# test eval
# Score the best checkpoint, not the last-epoch weights.
model.load_state_dict(torch.load(final_ckpt, map_location=device))
td = DataLoader(torch.load(os.path.join(base_path, f"{task}_test.pt")), batch_size=32)
preds, labels = evaluate(model, td)
y_pred = preds.argmax(1).numpy()
y_true = labels.numpy().astype(int)

# metrics
acc_f = accuracy_score(y_true, y_pred)
prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')

# confusion
# `labels=` keeps the matrix num_classes x num_classes even if a class is
# absent, so it always matches the display labels.
cm = confusion_matrix(y_true, y_pred, labels=np.arange(num_classes))
disp = ConfusionMatrixDisplay(cm, display_labels=list(class_names.values()))
disp.plot()
plt.title('Final Model Confusion Matrix')
plt.show()

# final auc
# `roc_curve` only accepts binary targets, so a one-vs-rest curve is drawn per
# class from the binarized labels and that class's probability column.
probs = F.softmax(preds, dim=1).numpy()
y_true_bin = label_binarize(y_true, classes=np.arange(num_classes))
for i in range(num_classes):
    fpr, tpr, _ = roc_curve(y_true_bin[:, i], probs[:, i])
    class_auc = roc_auc_score(y_true_bin[:, i], probs[:, i])
    plt.plot(fpr, tpr, label=f"{class_names[i]} (AUC = {class_auc:.2f})")
plt.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
plt.xlabel('FPR')
plt.ylabel('TPR')
plt.title('Final AUC-ROC')
plt.legend()
plt.show()

# Comparison table
ensemble_df = pd.read_csv(os.path.join(results_dir, 'ensemble_preds.csv'))
ensemble_acc = accuracy_score(ensemble_df['True'], ensemble_df['Pred'])
final_metrics = {'ensemble': ensemble_acc, 'final': acc_f}
comp = pd.DataFrame.from_dict(final_metrics, orient='index', columns=['Accuracy'])
display(comp)
