In [1]:
import torch
from torch_geometric.loader import DataLoader
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GCNConv  # Replace with GATConv, SAGEConv, etc.
from torch_geometric.data import Data
from sklearn.metrics import accuracy_score, mean_squared_error
from torch_geometric.nn import GATConv, global_mean_pool

  from .autonotebook import tqdm as notebook_tqdm


Define the model you want to use (GCN or MPNN or GraphSAGE, etc.)

In [3]:
class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        self.conv2 = GATConv(hidden_channels * heads, hidden_channels, heads=heads)
        self.lin = Linear(hidden_channels * heads, out_channels)
        
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = global_mean_pool(x, batch)  # Aggregate node features to graph level
        return self.lin(x)


In [4]:
# 🏋️ Training loop
def train():
    model.train()
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch)
        loss = criterion(out.squeeze(), batch.y if task == "regression" else batch.y.long())
        loss.backward()
        optimizer.step()

# 📈 Evaluation
def evaluate(loader):
    model.eval()
    preds, labels = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            out = model(batch)
            preds.append(out.squeeze().cpu())
            labels.append(batch.y.cpu())
    preds = torch.cat(preds)
    labels = torch.cat(labels)

    if task == "classification":
        pred_classes = preds.argmax(dim=1)
        return accuracy_score(labels, pred_classes)
    else:
        return mean_squared_error(labels, preds)


In [7]:
#5 fold cross-validation
task = "classification"  # or "regression"

test_data = torch.load(f"../4_train_test_split/5fold_cv/{task}/{task}_test.pt")
average_score = 0

for fold_idx in range(5):
    #load data from fold_idx-th fold
    train_data = torch.load(f"../4_train_test_split/5fold_cv/{task}/{task}_train_fold{fold_idx}.pt")
    val_data = torch.load(f"../4_train_test_split/5fold_cv/{task}/{task}_val_fold{fold_idx}.pt")

    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=32)
    test_loader = DataLoader(test_data, batch_size=32)


    #initialize models
    if task == "classification":
        num_classes = len(set([int(data.y.item()) for data in train_data]))
        model = GAT(in_channels=train_data[0].x.size(1), hidden_channels=64, out_channels=num_classes, heads=8)
        criterion = torch.nn.CrossEntropyLoss()
    else:
        model = GAT(in_channels=train_data[0].x.size(1), hidden_channels=64, out_channels=1, heads=8)
        criterion = torch.nn.MSELoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # training & evaluation
    for epoch in range(1, 101):
        train()
        metric = evaluate(val_loader)
        print(f"Epoch {epoch:03d} - {'Accuracy' if task == 'classification' else 'MSE'}: {metric:.4f}")

    # ✅ Final test evaluation
    test_metric = evaluate(test_loader)
    average_score += test_metric
    print(f"\n🧪 Test {'Accuracy' if task == 'classification' else 'MSE'} for {fold_idx}-th fold: {test_metric:.4f}")

average_score /= 5

  test_data = torch.load(f"../4_train_test_split/5fold_cv/{task}/{task}_test.pt")
  train_data = torch.load(f"../4_train_test_split/5fold_cv/{task}/{task}_train_fold{fold_idx}.pt")
  val_data = torch.load(f"../4_train_test_split/5fold_cv/{task}/{task}_val_fold{fold_idx}.pt")


Epoch 001 - Accuracy: 0.5444
Epoch 002 - Accuracy: 0.5444
Epoch 003 - Accuracy: 0.5444
Epoch 004 - Accuracy: 0.5444
Epoch 005 - Accuracy: 0.6889
Epoch 006 - Accuracy: 0.7444
Epoch 007 - Accuracy: 0.6222
Epoch 008 - Accuracy: 0.6111
Epoch 009 - Accuracy: 0.6667
Epoch 010 - Accuracy: 0.7000
Epoch 011 - Accuracy: 0.6778
Epoch 012 - Accuracy: 0.6667
Epoch 013 - Accuracy: 0.7111
Epoch 014 - Accuracy: 0.6667
Epoch 015 - Accuracy: 0.7222
Epoch 016 - Accuracy: 0.7222
Epoch 017 - Accuracy: 0.6889
Epoch 018 - Accuracy: 0.6667
Epoch 019 - Accuracy: 0.6667
Epoch 020 - Accuracy: 0.7222
Epoch 021 - Accuracy: 0.7222
Epoch 022 - Accuracy: 0.6889
Epoch 023 - Accuracy: 0.7222
Epoch 024 - Accuracy: 0.7222
Epoch 025 - Accuracy: 0.7111
Epoch 026 - Accuracy: 0.7333
Epoch 027 - Accuracy: 0.7000
Epoch 028 - Accuracy: 0.7222
Epoch 029 - Accuracy: 0.7111
Epoch 030 - Accuracy: 0.7222
Epoch 031 - Accuracy: 0.7000
Epoch 032 - Accuracy: 0.7111
Epoch 033 - Accuracy: 0.7222
Epoch 034 - Accuracy: 0.7222
Epoch 035 - Ac

  train_data = torch.load(f"../4_train_test_split/5fold_cv/{task}/{task}_train_fold{fold_idx}.pt")
  val_data = torch.load(f"../4_train_test_split/5fold_cv/{task}/{task}_val_fold{fold_idx}.pt")


Epoch 001 - Accuracy: 0.5444
Epoch 002 - Accuracy: 0.5444
Epoch 003 - Accuracy: 0.5444
Epoch 004 - Accuracy: 0.5444
Epoch 005 - Accuracy: 0.5444
Epoch 006 - Accuracy: 0.6222
Epoch 007 - Accuracy: 0.5444
Epoch 008 - Accuracy: 0.6889
Epoch 009 - Accuracy: 0.6778
Epoch 010 - Accuracy: 0.7333
Epoch 011 - Accuracy: 0.7111
Epoch 012 - Accuracy: 0.7111
Epoch 013 - Accuracy: 0.6889
Epoch 014 - Accuracy: 0.7333
Epoch 015 - Accuracy: 0.7000
Epoch 016 - Accuracy: 0.7000
Epoch 017 - Accuracy: 0.7000
Epoch 018 - Accuracy: 0.7000
Epoch 019 - Accuracy: 0.6889
Epoch 020 - Accuracy: 0.7111
Epoch 021 - Accuracy: 0.7111
Epoch 022 - Accuracy: 0.7000
Epoch 023 - Accuracy: 0.6889
Epoch 024 - Accuracy: 0.7222
Epoch 025 - Accuracy: 0.7000
Epoch 026 - Accuracy: 0.7222
Epoch 027 - Accuracy: 0.7111
Epoch 028 - Accuracy: 0.7000
Epoch 029 - Accuracy: 0.6889
Epoch 030 - Accuracy: 0.7000
Epoch 031 - Accuracy: 0.7000
Epoch 032 - Accuracy: 0.7000
Epoch 033 - Accuracy: 0.6889
Epoch 034 - Accuracy: 0.6778
Epoch 035 - Ac

  train_data = torch.load(f"../4_train_test_split/5fold_cv/{task}/{task}_train_fold{fold_idx}.pt")
  val_data = torch.load(f"../4_train_test_split/5fold_cv/{task}/{task}_val_fold{fold_idx}.pt")


Epoch 001 - Accuracy: 0.5444
Epoch 002 - Accuracy: 0.5444
Epoch 003 - Accuracy: 0.5444
Epoch 004 - Accuracy: 0.5444
Epoch 005 - Accuracy: 0.5444
Epoch 006 - Accuracy: 0.5444
Epoch 007 - Accuracy: 0.5667
Epoch 008 - Accuracy: 0.5444
Epoch 009 - Accuracy: 0.5444
Epoch 010 - Accuracy: 0.7333
Epoch 011 - Accuracy: 0.6889
Epoch 012 - Accuracy: 0.7222
Epoch 013 - Accuracy: 0.7000
Epoch 014 - Accuracy: 0.7222
Epoch 015 - Accuracy: 0.7333
Epoch 016 - Accuracy: 0.7333
Epoch 017 - Accuracy: 0.6889
Epoch 018 - Accuracy: 0.7444
Epoch 019 - Accuracy: 0.7333
Epoch 020 - Accuracy: 0.7111
Epoch 021 - Accuracy: 0.7111
Epoch 022 - Accuracy: 0.6556
Epoch 023 - Accuracy: 0.7333
Epoch 024 - Accuracy: 0.7444
Epoch 025 - Accuracy: 0.7333
Epoch 026 - Accuracy: 0.7111
Epoch 027 - Accuracy: 0.7111
Epoch 028 - Accuracy: 0.7444
Epoch 029 - Accuracy: 0.7111
Epoch 030 - Accuracy: 0.7000
Epoch 031 - Accuracy: 0.7000
Epoch 032 - Accuracy: 0.6778
Epoch 033 - Accuracy: 0.7000
Epoch 034 - Accuracy: 0.6667
Epoch 035 - Ac

  train_data = torch.load(f"../4_train_test_split/5fold_cv/{task}/{task}_train_fold{fold_idx}.pt")
  val_data = torch.load(f"../4_train_test_split/5fold_cv/{task}/{task}_val_fold{fold_idx}.pt")


Epoch 001 - Accuracy: 0.4000
Epoch 002 - Accuracy: 0.5333
Epoch 003 - Accuracy: 0.6000
Epoch 004 - Accuracy: 0.5333
Epoch 005 - Accuracy: 0.5333
Epoch 006 - Accuracy: 0.5333
Epoch 007 - Accuracy: 0.5333
Epoch 008 - Accuracy: 0.7111
Epoch 009 - Accuracy: 0.7000
Epoch 010 - Accuracy: 0.7000
Epoch 011 - Accuracy: 0.7111
Epoch 012 - Accuracy: 0.7111
Epoch 013 - Accuracy: 0.7444
Epoch 014 - Accuracy: 0.7000
Epoch 015 - Accuracy: 0.7000
Epoch 016 - Accuracy: 0.7000
Epoch 017 - Accuracy: 0.7333
Epoch 018 - Accuracy: 0.7333
Epoch 019 - Accuracy: 0.7444
Epoch 020 - Accuracy: 0.7222
Epoch 021 - Accuracy: 0.7222
Epoch 022 - Accuracy: 0.7333
Epoch 023 - Accuracy: 0.6889
Epoch 024 - Accuracy: 0.6556
Epoch 025 - Accuracy: 0.7222
Epoch 026 - Accuracy: 0.7333
Epoch 027 - Accuracy: 0.7444
Epoch 028 - Accuracy: 0.7333
Epoch 029 - Accuracy: 0.7333
Epoch 030 - Accuracy: 0.7222
Epoch 031 - Accuracy: 0.7444
Epoch 032 - Accuracy: 0.7444
Epoch 033 - Accuracy: 0.7333
Epoch 034 - Accuracy: 0.7333
Epoch 035 - Ac

  train_data = torch.load(f"../4_train_test_split/5fold_cv/{task}/{task}_train_fold{fold_idx}.pt")
  val_data = torch.load(f"../4_train_test_split/5fold_cv/{task}/{task}_val_fold{fold_idx}.pt")


Epoch 001 - Accuracy: 0.3933
Epoch 002 - Accuracy: 0.5393
Epoch 003 - Accuracy: 0.5393
Epoch 004 - Accuracy: 0.5393
Epoch 005 - Accuracy: 0.5393
Epoch 006 - Accuracy: 0.5393
Epoch 007 - Accuracy: 0.5393
Epoch 008 - Accuracy: 0.5393
Epoch 009 - Accuracy: 0.5730
Epoch 010 - Accuracy: 0.5506
Epoch 011 - Accuracy: 0.5843
Epoch 012 - Accuracy: 0.5843
Epoch 013 - Accuracy: 0.5955
Epoch 014 - Accuracy: 0.5843
Epoch 015 - Accuracy: 0.5843
Epoch 016 - Accuracy: 0.5843
Epoch 017 - Accuracy: 0.5843
Epoch 018 - Accuracy: 0.5843
Epoch 019 - Accuracy: 0.6067
Epoch 020 - Accuracy: 0.6517
Epoch 021 - Accuracy: 0.5955
Epoch 022 - Accuracy: 0.5843
Epoch 023 - Accuracy: 0.6292
Epoch 024 - Accuracy: 0.5843
Epoch 025 - Accuracy: 0.5730
Epoch 026 - Accuracy: 0.5955
Epoch 027 - Accuracy: 0.6517
Epoch 028 - Accuracy: 0.5955
Epoch 029 - Accuracy: 0.6180
Epoch 030 - Accuracy: 0.6629
Epoch 031 - Accuracy: 0.5843
Epoch 032 - Accuracy: 0.6517
Epoch 033 - Accuracy: 0.5955
Epoch 034 - Accuracy: 0.6629
Epoch 035 - Ac

In [8]:
print(f"\n🧪 Average Test {'Accuracy' if task == 'classification' else 'MSE'}: {average_score:.4f}")


🧪 Average Test Accuracy: 0.6600
