Create and train different GNNs using Spektral.

## 1. Imports and Setup

In [235]:
import os, random, pickle
import spektral
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras import layers
from spektral.data import Dataset
from spektral.data.loaders import DisjointLoader
from spektral.layers import GCNConv, GATConv, ECCConv, GraphSageConv, GINConv, GlobalSumPool, GlobalAvgPool, GlobalMaxPool, GeneralConv
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import Dense, Dropout
from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import accuracy_score, mean_squared_error, r2_score

## 2. Reproducibility Setup 

In [178]:
SEED = 55
os.environ["PYTHONHASHSEED"] = str(SEED)
os.environ["TF_DETERMINISTIC_OPS"] = "1"   # enforce deterministic ops (GPU too)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Load Custom Dataset to be able to load pickle file
class CustomDataset(Dataset):
    def __init__(self, graph_list, **kwargs):
        self.graph_list = graph_list
        super().__init__(**kwargs)

    def read(self):
        return self.graph_list

## 3. Load Dataset and Train-Test Split

In [None]:
with open("../4_gnn_data_pipeline/dataset_deepchem.pkl", "rb") as f:
    dataset = pickle.load(f)   # already a Spektral Dataset object

graph = dataset[0]

# Looking at shape of features in the dataset
x = graph.x         # node features
a = graph.a         # adjacency / edge indices
e = graph.e         # edge features
y = graph.y         # target(s)
g = graph.globals   # global features

print(x.shape, a.shape, e.shape if e is not None else None, y.shape, g.shape)

(6, 30) (2, 12) (12, 11) (512,) (1, 1)


## 4. Define GNN Models

In [298]:
class GCN_EdgeModel(Model):
    def __init__(self, n_node_features, n_edge_features, n_globals=0, n_targets=1, hidden=64, dropout=0.0, pool="avg"):
        super().__init__()
        self.n_node_features = n_node_features
        self.n_edge_features = n_edge_features
        self.n_globals = n_globals
        self.n_targets = n_targets

        # Edge-aware convolutions
        self.conv1 = ECCConv(hidden, activation="relu")
        self.conv2 = ECCConv(hidden, activation="relu")

        # Dropout
        self.dropout = layers.Dropout(dropout)

        # Pooling
        self.pool = {"sum": GlobalSumPool(),
                     "avg": GlobalAvgPool(),
                     "max": GlobalMaxPool()}[pool]

        # Dense head
        self.fc = layers.Dense(hidden, activation="relu")
        self.out = layers.Dense(n_targets, activation="linear")  # regression output

    def call(self, inputs, training=False):
        """
        inputs: tuple from DisjointLoader:
            x: node features (num_nodes, n_node_features)
            a: adjacency matrix, SparseTensor (num_edges, 2)
            e: edge features (num_edges, n_edge_features)
            i: batch index vector (num_nodes,) mapping node -> graph
            g: global features (num_graphs_in_batch, n_globals) or None
        """
        x, a, i, e, g = inputs

        # Ensure edge features are 2D
        if e is not None and len(e.shape) == 1:
            e = tf.expand_dims(e, axis=-1)

        # ECCConv expects SparseTensor adjacency and edge features
        x = self.conv1([x, a, e])
        x = self.conv2([x, a, e])

        # Dropout
        x = self.dropout(x, training=training)

        # Graph pooling
        x = self.pool([x, i])

        # Concatenate global features if present
        if g is not None and self.n_globals > 0:
            x = tf.concat([x, g], axis=-1)

        # Dense head
        x = self.fc(x)
        return self.out(x)


In [129]:
# Doesn't take in edges; have to add a custom layer for that later
class GATModel(Model):
    def __init__(self, n_node_features, n_edge_features, n_globals, n_targets=1, hidden=64, dropout=0.0, heads=4, pool="avg"):
        super().__init__()
        self.n_node_features = n_node_features
        self.n_edge_features = n_edge_features
        self.n_globals = n_globals

        self.conv1 = GATConv(hidden, attn_heads=heads, concat_heads=True, activation="elu")
        self.conv2 = GATConv(hidden, attn_heads=1, concat_heads=False, activation="elu")
        self.dropout = Dropout(dropout)
        self.pool = {"sum": GlobalSumPool(),
                     "avg": GlobalAvgPool(),
                     "max": GlobalMaxPool()}[pool]
        self.fc = Dense(hidden, activation="relu")
        self.out = Dense(n_targets, activation="linear")

    def call(self, inputs, training=False):
        x, a, i, e, g = inputs
        x = self.conv1([x, a])
        x = self.conv2([x, a])
        x = self.dropout(x, training=training)
        x = self.pool([x, i])
        x = tf.concat([x, g], axis=-1)
        x = self.fc(x)
        return self.out(x)

In [130]:
# Also doesn't use edges
class GraphSAGEModel(Model):
    def __init__(self, n_node_features, n_edge_features, n_globals, n_targets=1, hidden=64, dropout=0.0, pool="avg"):
        super().__init__()
        self.n_node_features = n_node_features
        self.n_edge_features = n_edge_features
        self.n_globals = n_globals

        self.conv1 = GraphSageConv(hidden, activation="relu")
        self.conv2 = GraphSageConv(hidden, activation="relu")
        self.dropout = Dropout(dropout)
        self.pool = {"sum": GlobalSumPool(),
                     "avg": GlobalAvgPool(),
                     "max": GlobalMaxPool()}[pool]
        self.fc = Dense(hidden, activation="relu")
        self.out = Dense(n_targets, activation="linear")

    def call(self, inputs, training=False):
        x, a, i, e, g = inputs
        x = self.conv1([x, a])
        x = self.conv2([x, a])
        x = self.dropout(x, training=training)
        x = self.pool([x, i])
        x = tf.concat([x, g], axis=-1)
        x = self.fc(x)
        return self.out(x)

In [131]:
# Also doesn't use edges
class GINModel(Model):
    def __init__(self, n_node_features, n_edge_features, n_globals, n_targets=1, hidden=64, dropout=0.0, pool="avg"):
        super().__init__()
        self.n_node_features = n_node_features
        self.n_edge_features = n_edge_features
        self.n_globals = n_globals
        # Each GINConv requires an MLP
        mlp1 = tf.keras.Sequential([Dense(hidden, activation="relu"), Dense(hidden, activation="relu")])
        mlp2 = tf.keras.Sequential([Dense(hidden, activation="relu"), Dense(hidden, activation="relu")])
        self.conv1 = GINConv(mlp1)
        self.conv2 = GINConv(mlp2)
        self.dropout = Dropout(dropout)
        self.pool = {"sum": GlobalSumPool(),
                     "avg": GlobalAvgPool(),
                     "max": GlobalMaxPool()}[pool]
        self.fc = Dense(hidden, activation="relu")
        self.out = Dense(n_targets, activation="linear")

    def call(self, inputs, training=False):
        x, a, i, e, g = inputs
        x = self.conv1([x, a])
        x = self.conv2([x, a])
        x = self.dropout(x, training=training)
        x = self.pool([x, i])
        x = tf.concat([x, g], axis=-1)
        x = self.fc(x)
        return self.out(x)

In [132]:
class MPNNModel(Model):
    def __init__(self, n_node_features, n_edge_features, n_globals, n_targets=1, hidden=64, dropout=0.0, pool="avg"):
        super().__init__()
        self.n_node_features = n_node_features
        self.n_edge_features = n_edge_features
        self.n_globals = n_globals

        self.conv1 = MPNNConv(hidden, activation="relu")
        self.conv2 = MPNNConv(hidden, activation="relu")
        self.dropout = Dropout(dropout)
        self.pool = {"sum": GlobalSumPool(),
                     "avg": GlobalAvgPool(),
                     "max": GlobalMaxPool()}[pool]
        self.fc = Dense(hidden, activation="relu")
        self.out = Dense(n_targets, activation="linear")

    def call(self, inputs, training=False):
        x, a, i, e, g = inputs
        x = self.conv1([x, a, e])
        x = self.conv2([x, a, e])
        x = self.dropout(x, training=training)
        x = self.pool([x, i])
        x = tf.concat([x, g], axis=-1)
        x = self.fc(x)
        return self.out(x)

## 4. Evalution Helper Function

In [133]:
def compute_metrics(y_true, y_pred):
    mae = mean_absolute_error(y_true, y_pred)
    mse = mean_squared_error(y_true, y_pred)
    rmse = float(np.sqrt(mse))
    r2 = r2_score(y_true, y_pred)
    return {"MAE": mae, "MSE": mse, "RMSE": rmse, "R2": r2}

## 5. Cross-Validation and hyperparameter sweep

In [260]:
# Build model factory to make a fresh model for each kfold sweep to prevent biased weights

def model_factory(model_name,
                  n_node_features,
                  n_edge_features,
                  n_globals,
                  n_targets=1,
                  hidden=64,
                  dropout=0.0,
                  heads=4,
                  pool="avg"):

    model_name = model_name.lower()

    if model_name == "gcn":
        return GCNModel(n_node_features=n_node_features,
                        n_edge_features=n_edge_features,
                        n_globals=n_globals,
                        n_targets=n_targets,
                        hidden=hidden,
                        dropout=dropout,
                        pool=pool)
    elif model_name == "gat":
        return GATModel(n_node_features=n_node_features,
                        n_edge_features=n_edge_features,
                        n_globals=n_globals,
                        n_targets=n_targets,
                        hidden=hidden,
                        dropout=dropout,
                        pool=pool)
    elif model_name == "graphsage":
        return GraphSAGEModel(n_node_features=n_node_features,
                              n_edge_features=n_edge_features,
                              n_globals=n_globals,
                              n_targets=n_targets,
                              hidden=hidden,
                              dropout=dropout,
                              pool=pool)
    elif model_name == "gin":
        return GINModel(n_node_features=n_node_features,
                        n_edge_features=n_edge_features,
                        n_globals=n_globals,
                        n_targets=n_targets,
                        hidden=hidden,
                        dropout=dropout,
                        pool=pool)
    elif model_name == "mpnn":
        return MPNNModel(n_node_features=n_node_features,
                         n_edge_features=n_edge_features,
                         n_globals=n_globals,
                         n_targets=n_targets,
                         hidden=hidden,
                         dropout=dropout,
                         pool=pool)
    else:
        raise ValueError(f"Unknown model_name: {model_name}")


In [280]:
# Wrap DisjointLoader to include g so we don't have issues with tensor dimensions

class DisjointLoaderWithGlobals(DisjointLoader):
    def collate(self, batch):
        # Call the usual disjoint collate
        inputs, target = super().collate(batch)

        # Extract globals separately
        g = [graph.globals for graph in batch]
        g = tf.convert_to_tensor(g, dtype=tf.float32)

        # Return inputs + g
        x, a, e, i = inputs
        return (x, a, e, i, g), target

In [301]:
# -----------------------------
# Hyperparameters
# -----------------------------
hidden_options = [64, 128, 256]
dropout_options = [0.0, 0.2, 0.4]
lr_options = [1e-3, 5e-4, 1e-4]
pool_options = ["sum", "avg", "max"]

spectrum_length = 100 #Want to predict the absorption spectrum with this number of points

n_hidden = 64
n_outputs = spectrum_length

model_names = ["GCN", "GAT", "GraphSAGE", "GIN", "MPNN"]

# -----------------------------
# 10-fold CV
# -----------------------------
kf = KFold(n_splits=10, shuffle=True, random_state=SEED)

results = []
# ------------------------------
# Input feature sizes 
# ------------------------------
n_node_features = graph.x.shape[-1] # or dataset.node_features.shape[-1] 
n_edge_indices = graph.a.shape[-1]
n_edge_features = graph.e.shape # or 0 if no edge features 
n_globals = graph.globals.shape # or 0 if no global features 
n_targets = 1 # regression target

print(n_edge_features)
# -----------------------------
# Training loop
# -----------------------------
for model_name in model_names:
    print(f"\nStarting hyperparameter sweep for {model_name}")
    
    for hidden in hidden_options:
        for dropout in dropout_options:
            for lr in lr_options:
                for pool in pool_options:
                    fold_mae = []
                    
                    for fold, (train_idx, val_idx) in enumerate(kf.split(dataset), 1):
                        # Split dataset
                        train_dataset = dataset[train_idx.tolist()]
                        val_dataset = dataset[val_idx.tolist()]
                        
                        train_loader = DisjointLoaderWithGlobals(train_dataset, batch_size=32, shuffle=True, node_level=False)
                        val_loader = DisjointLoaderWithGlobals(val_dataset, batch_size=32, shuffle=False, node_level=False)

                        # Build model
                        model = model_factory(
                            model_name,
                            n_node_features, 
                            n_edge_features, 
                            n_globals, 
                            n_targets,
                            hidden,
                            dropout,
                            pool
                        )
                        optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
                        loss_fn = tf.keras.losses.MeanSquaredError()
                        
                        # Training per epoch
                        epochs = 50
                        patience = 10
                        best_val_mae = np.inf
                        wait = 0
                        best_weights = None
                        
                        for epoch in range(epochs):
                            # ---- Training ----
                            for batch in train_loader:  
                                # Looking at shape of features in the dataset
                                inputs, y = batch
                                x, a, i, e, g = inputs
                                print("Node features:", x.shape)
                                print("Edge indices (sparse):", a)
                                print("Edge features:", e.shape if e is not None else None)
                                print("Batch index:", i.shape)
                                print("Global features:", g.shape if g is not None else None)
                                print("Targets:", y.shape)

                                # Forward pass
                                with tf.GradientTape() as tape:
                                    y_pred = model([x, a, i, e, g], training=True)
                                    loss = loss_fn(y, y_pred)
                                # Backwards pass
                                grads = tape.gradient(loss, model.trainable_variables)
                                optimizer.apply_gradients(zip(grads, model.trainable_variables))
                            
                            # ---- Validation ----
                            val_maes = []
                            for batch in val_loader:
                                inputs, y = batch
                                x, a, i, e, g = inputs
                                y_pred = model([x, a, i, e, g], training=False)
                                val_maes.append(tf.reduce_mean(tf.keras.losses.mean_absolute_error(y, y_pred)).numpy())
                            
                            val_mae = np.mean(val_maes)
                            
                            if val_mae < best_val_mae:
                                best_val_mae = val_mae
                                best_weights = model.get_weights()
                                wait = 0
                            else:
                                wait += 1
                            
                            if wait >= patience:
                                break
                        
                        # Restore best weights
                        model.set_weights(best_weights)
                        fold_mae.append(best_val_mae)
                    
                    mean_mae = np.mean(fold_mae)
                    results.append({
                        "model": model_name,
                        "hidden": hidden,
                        "dropout": dropout,
                        "lr": lr,
                        "pool": pool,
                        "mean_mae": mean_mae
                    })
                    print(f"{model_name} | hidden={hidden} dropout={dropout} lr={lr} pool={pool} -> MAE={mean_mae:.4f}")

# Save hyperparameter sweep results
results_df = pd.DataFrame(results)
results_df.to_csv("hyperparameter_sweep_results.csv", index=False)


(12, 11)

Starting hyperparameter sweep for GCN
Node features: (490, 30)
Edge indices (sparse): SparseTensor(indices=tf.Tensor(
[[   0    0]
 [   0    3]
 [   0    4]
 ...
 [  63 1085]
 [  63 1086]
 [  63 1087]], shape=(2080, 2), dtype=int64), values=tf.Tensor([ 1.  7.  7. ... 18. 14. 18.], shape=(2080,), dtype=float32), dense_shape=tf.Tensor([  64 1088], shape=(2,), dtype=int64))
Edge features: (490,)
Batch index: (1088, 11)
Global features: (32, 1, 1)
Targets: (32, 512)


  np.random.shuffle(a)
1. The `call()` method of your layer may be crashing. Try to `__call__()` the layer eagerly on some test input first to see if it works. E.g. `x = np.random.random((3, 4)); y = layer(x)`
2. If the `call()` method is correct, then you may need to implement the `def build(self, input_shape)` method on your layer. It should create all variables used by the layer (e.g. by calling `layer.build()` on all its children layers).
Exception encountered: ''Exception encountered when calling ECCConv.call().

[1mDimensions must be equal, but are 2080 and 1088 for '{{node ecc_conv_198_1/einsum/Einsum}} = Einsum[N=2, T=DT_FLOAT, equation="...ab,...abc->...ac"](ecc_conv_198_1/GatherV2, ecc_conv_198_1/Reshape)' with input shapes: [2080,30], [1088,30,64].[0m

Arguments received by ECCConv.call():
  • inputs=['tf.Tensor(shape=(490, 30), dtype=float32)', 'tf.Tensor(shape=(64, 1088), dtype=float32)', 'tf.Tensor(shape=(1088, 11), dtype=float32)']
  • mask=['None', 'None', 'None']''


InvalidArgumentError: Exception encountered when calling ECCConv.call().

[1m{{function_node __wrapped____MklEinsum_N_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} Expected dimension 2080 at axis 0 of the input shaped [1088,30,64] but got dimension 1088 [Op:Einsum] name: [0m

Arguments received by ECCConv.call():
  • inputs=['tf.Tensor(shape=(490, 30), dtype=float32)', 'tf.Tensor(shape=(64, 1088), dtype=float32)', 'tf.Tensor(shape=(1088, 11), dtype=float32)']
  • mask=['None', 'None', 'None']

## 6. Retrain Folds with best hyperparameters

In [None]:
from tensorflow.keras.callbacks import EarlyStopping

# Store retrained fold models
fold_models = {name: [] for name in model_names}

# Prepare a list to collect summary rows for CSV
summary_rows = []

for fold, (train_idx, val_idx) in enumerate(kf.split(dataset)):
    train_dataset = dataset[train_idx.tolist()]
    val_dataset   = dataset[val_idx.tolist()]
    
    train_loader = DisjointLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader   = DisjointLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    
    # Collect true targets for validation
    val_true = []
    for g in val_dataset:
        val_true.append(g.y)
    val_true = np.vstack(val_true)
    
    for model_name in model_names:
        # Get best hyperparameters
        params = best_params_per_model[model_name]
        model = model_factory(model_name, hidden=params["hidden"], dropout=params["dropout"])
        model.compile(optimizer=tf.keras.optimizers.Adam(params["lr"]), loss="mse")

        # Early stopping
        es = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True, verbose=0)

        # Train
        model.fit(train_loader.load(), steps_per_epoch=train_loader.steps_per_epoch,
                  validation_data=val_loader.load(), validation_steps=val_loader.steps_per_epoch,
                  epochs=EPOCHS, verbose=0, callbacks=[es])
        
        # Save retrained model
        fold_models[model_name].append(model)
        
        # Predict on validation set
        val_preds = []
        for batch in val_loader:
            x, a, i, e, g = batch
            y_pred = model([x, a, i, e, g], training=False)
            val_preds.append(y_pred.numpy())
        val_preds = np.vstack(val_preds)
        
        # Compute metrics
        mae = mean_absolute_error(val_true, val_preds)
        mse = mean_squared_error(val_true, val_preds)
        rmse = np.sqrt(mse)
        r2 = r2_score(val_true, val_preds)
        
        # Append metrics to summary list
        summary_rows.append({
            "Fold": fold+1,
            "Model": model_name,
            "MAE": mae,
            "MSE": mse,
            "RMSE": rmse,
            "R2": r2
        })
        
        print(f"Fold {fold+1}, {model_name}: MAE={mae:.4f}, RMSE={rmse:.4f}, R2={r2:.4f}")

# Save summary to CSV
summary_df = pd.DataFrame(summary_rows)
summary_df.to_csv("crossval_summary_reg.csv", index=False)
print("\nCross-validation summary saved to crossval_summary_reg.csv")


## 7. Visualize CV Metrics

In [None]:
import matplotlib.pyplot as plt

# Load summary CSV
summary_df = pd.read_csv("crossval_summary_reg.csv")

# Metrics to analyze
metrics = ["MAE", "MSE", "RMSE", "R2"]

# Plot bar charts for each metric across folds
for metric in metrics:
    plt.figure(figsize=(8, 5))
    for model in summary_df["Model"].unique():
        model_data = summary_df[summary_df["Model"] == model]
        plt.bar(model_data["Fold"] + (0.15 * list(summary_df["Model"].unique()).tolist().index(model)),
                model_data[metric], width=0.15, label=model)
    plt.xlabel("Fold")
    plt.ylabel(metric)
    plt.title(f"{metric} across folds")
    plt.legend()
    plt.tight_layout()
    plt.show()

# Print mean ± std for each metric per model
for model in summary_df["Model"].unique():
    print(f"\n=== {model} ===")
    model_data = summary_df[summary_df["Model"] == model]
    for metric in metrics:
        mean_val = model_data[metric].mean()
        std_val = model_data[metric].std()
        print(f"{metric}: {mean_val:.4f} ± {std_val:.4f}")


## 8. Do Ensemble Averaging on the test set and visualize results

In [None]:
# ✅ Import your custom model classes
from models import GCNModel, GATModel, GraphSAGEModel, GINModel, MPNNModel

# Test set loader (no shuffle, one pass)
test_loader = DisjointLoader(test_dataset, batch_size=32, epochs=1, shuffle=False)

# Custom objects for loading
custom_objects = {
    "GCNModel": GCNModel,
    "GATModel": GATModel,
    "GraphSAGEModel": GraphSAGEModel,
    "GINModel": GINModel,
    "MPNNModel": MPNNModel,
}

# Use your model naming convention
model_names = ["GCN", "GAT", "GraphSAGE", "GIN", "MPNN"]

results = {}

for model_name in model_names:
    print(f"\n=== Ensemble Averaging for {model_name} ===")

    all_preds = []

    for fold in range(10):
        # Match your checkpoint filenames
        model = tf.keras.models.load_model(
            f"checkpoints/{model_name}_fold{fold}.h5",
            custom_objects=custom_objects
        )

        fold_preds = model.predict(test_loader.load(), verbose=0)
        all_preds.append(fold_preds)

    # Stack and average predictions across folds
    all_preds = np.stack(all_preds, axis=0)
    ensemble_preds = np.mean(all_preds, axis=0).squeeze()

    # Collect true values
    y_true = np.concatenate([y for _, y in test_loader], axis=0).squeeze()

    # ✅ Metrics
    mae = mean_absolute_error(y_true, ensemble_preds)
    mse = mean_squared_error(y_true, ensemble_preds)
    rmse = np.sqrt(mse)
    r2 = r2_score(y_true, ensemble_preds)

    results[model_name] = {"MAE": mae, "MSE": mse, "RMSE": rmse, "R2": r2}

    # 📈 Scatter plot
    plt.figure(figsize=(6, 6))
    plt.scatter(y_true, ensemble_preds, alpha=0.6)
    plt.plot([y_true.min(), y_true.max()],
             [y_true.min(), y_true.max()],
             color="red", linestyle="--")
    plt.xlabel("True Values")
    plt.ylabel("Ensemble Predictions")
    plt.title(f"{model_name}: True vs Ensemble Predictions")
    plt.tight_layout()
    plt.show()

# Save all results
df_results = pd.DataFrame(results).T
df_results.to_csv("ensemble_results_testset.csv", index=True)

print("\n=== Ensemble Results Across Models ===")
print(df_results)


## 9. Final model training and test evalution

In [None]:
# -----------------------------
# Parameters / directories
# -----------------------------
save_dir = "final_models"
os.makedirs(save_dir, exist_ok=True)

# List of model names and corresponding factory functions that build them
# Replace these factories with your custom model constructors
model_names = ["GCN", "GAT", "GraphSAGE", "GIN", "MPNN"]
model_factories = {
    "GCN": lambda: GCNModel(n_node_features, n_edge_features, n_globals, n_targets=1),
    "GAT": lambda: GATModel(n_node_features, n_edge_features, n_globals, n_targets=1),
    "GraphSAGE": lambda: GraphSAGEModel(n_node_features, n_edge_features, n_globals, n_targets=1),
    "GIN": lambda: GINModel(n_node_features, n_edge_features, n_globals, n_targets=1),
    "MPNN": lambda: MPNNModel(n_node_features, n_edge_features, n_globals, n_targets=1),
}

# -----------------------------
# Merge all folds into train+val
# -----------------------------
# Assume you already have your folds stored
all_graphs = []
for fold_dataset in folds:  # folds is a list of Spektral datasets
    all_graphs.extend(fold_dataset.graphs)

# Split off 10% for validation
np.random.seed(42)
idx = np.random.permutation(len(all_graphs))
split = int(0.9 * len(all_graphs))
train_graphs = [all_graphs[i] for i in idx[:split]]
val_graphs   = [all_graphs[i] for i in idx[split:]]

train_dataset = Dataset(train_graphs)
val_dataset   = Dataset(val_graphs)

train_loader = DisjointLoader(train_dataset, batch_size=32, epochs=1, shuffle=True)
val_loader   = DisjointLoader(val_dataset, batch_size=32, epochs=1, shuffle=False)
test_loader  = DisjointLoader(test_dataset, batch_size=32, epochs=1, shuffle=False)  # already held-out

# -----------------------------
# Train final models
# -----------------------------
results = []

for model_name in model_names:
    print(f"\n=== Training final {model_name} model ===")

    # Build model
    model = model_factories[model_name]()

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        loss="mse",
        metrics=["mae"]
    )

    # Callbacks
    callbacks = [
        tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=20, restore_best_weights=True),
        tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=10)
    ]

    # Train
    model.fit(
        train_loader.load(),
        steps_per_epoch=train_loader.steps_per_epoch,
        validation_data=val_loader.load(),
        validation_steps=val_loader.steps_per_epoch,
        epochs=300,
        verbose=1,
        callbacks=callbacks
    )

    # Save final model (one file per model)
    save_path = os.path.join(save_dir, f"{model_name.lower()}_final_model.keras")
    model.save(save_path)
    print(f"✅ Saved {model_name} model at {save_path}")

    # -----------------------------
    # Evaluate on test set
    # -----------------------------
    y_true_list, y_pred_list = [], []

    for batch in test_loader:
        x, y = batch
        preds = model(x, training=False)
        y_true_list.append(y.numpy())
        y_pred_list.append(preds.numpy())

    y_true = np.vstack(y_true_list).squeeze()
    y_pred = np.vstack(y_pred_list).squeeze()

    mae = mean_absolute_error(y_true, y_pred)
    mse = mean_squared_error(y_true, y_pred)
    rmse = np.sqrt(mse)
    r2 = r2_score(y_true, y_pred)

    results.append({"Model": model_name, "MAE": mae, "MSE": mse, "RMSE": rmse, "R2": r2})

    # -----------------------------
    # Scatter + regression line
    # -----------------------------
    plt.figure(figsize=(6,6))
    plt.scatter(y_true, y_pred, alpha=0.6, label="Predictions")

    # Perfect prediction line
    min_val, max_val = y_true.min(), y_true.max()
    plt.plot([min_val, max_val], [min_val, max_val], "r--", label="Perfect prediction")

    # Fit regression line
    slope, intercept = np.polyfit(y_true, y_pred, 1)
    plt.plot(y_true, slope*y_true + intercept, "b-", label=f"Fit: y={slope:.2f}x+{intercept:.2f}")

    plt.xlabel("True Values")
    plt.ylabel("Predicted Values")
    plt.title(f"{model_name} Final Model: True vs Predicted")
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f"{model_name.lower()}_final_scatter.png"))
    plt.close()

# -----------------------------
# Save metrics CSV
# -----------------------------
df_metrics = pd.DataFrame(results)
df_metrics.to_csv(os.path.join(save_dir, "final_metrics_reg.csv"), index=False)
print("\n✅ Saved final metrics to final_metrics_reg.csv")
print(df_metrics)
