In [1]:
from load_data import load_pyg_obj
import torch
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
import random
import numpy as np
import os

from utils import fix_target_shapes,remove_unused_onehot_columns,set_seed
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


# GNN for Metal Salt Prediction


This code implements a **Graph Neural Network (GNN)** to predict metal salts, using:
- **Node Features**  
- **Edge Features**  
- **Lattice**  
- **Modified scherrer**
- **Oms**

---

## Code Structure

1. **Load the Data**  
   Import and prepare the dataset for use in the model.

2. **Define the GNN Model**  
   Define the neural network architecture (layers, activation functions, etc.).

3. **Train the Model**  
   - Train the model on the dataset.  
   - Save the trained GNN weights into the `tmp/` folder.  

   ⚠️ **Note**:  
   If you only want to **test the model** without re-training, you can **skip this section** and avoid running the training step.

4. **Load and Evaluate the Model**  
   - Load the trained model weights.  
   - Evaluate the model performance.  


### 1) Load the Data

In [2]:
# Load the dataset. Note the dataset is not included in the git repo, you have to download it!
data_in = load_pyg_obj(path_to_mdb="../../data/mof_syncondition_data/")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



#  Train / Validation / Test Split (80/10/10)
dataset = fix_target_shapes(data_in,"metal_salts")
dataset = remove_unused_onehot_columns(dataset,"metal_salts")
Y_size = max([torch.argmax(d["metal_salts"]).item() for d in dataset])
set_seed(seed=42) # 42
num_classes = Y_size+1
input_dim = dataset[0].x.shape[1]


# Add and reshape extra features
for data in dataset:
    data.modified_scherrer = data.modified_scherrer.view(1, 1).float()
    data.oms = data.oms.view(1, 1).float()



train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(42)
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)

### 2) Define the GNN 

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINEConv, global_mean_pool
from collections import defaultdict
from sklearn.metrics import f1_score

# ==================== Modello ====================
class MetalSaltGNN(nn.Module):
    def __init__(
        self,
        node_in_dim,
        edge_in_dim,
        lattice_in_dim,
        extra_feat_dim,
        hidden_dim,
        num_classes,
        num_gnn_layers=4,
        num_lattice_layers=2,
        num_mlp_layers=2,
        dropout=0.2,
        use_batchnorm=True
    ):
        super().__init__()

        self.use_batchnorm = use_batchnorm
        self.dropout = dropout

        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        self.gnn_layers = nn.ModuleList()
        self.gnn_bns = nn.ModuleList() if use_batchnorm else None

        for i in range(num_gnn_layers):
            in_dim = node_in_dim if i == 0 else hidden_dim
            mlp = nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            self.gnn_layers.append(GINEConv(mlp, edge_dim=hidden_dim))
            if use_batchnorm:
                self.gnn_bns.append(nn.BatchNorm1d(hidden_dim))

        # Lattice encoder
        lattice_layers = []
        in_dim = lattice_in_dim
        for _ in range(num_lattice_layers - 1):
            lattice_layers.append(nn.Linear(in_dim, hidden_dim))
            lattice_layers.append(nn.ReLU())
            if use_batchnorm:
                lattice_layers.append(nn.BatchNorm1d(hidden_dim))
            in_dim = hidden_dim
        lattice_layers.append(nn.Linear(in_dim, hidden_dim))
        self.lattice_encoder = nn.Sequential(*lattice_layers)

        # Extra feature encoder
        self.extra_feat_encoder = nn.Sequential(
            nn.Linear(extra_feat_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Final MLP
        mlp_layers = []
        in_dim = hidden_dim * 3  # x_pool + lattice_feat + extra_feat
        for _ in range(num_mlp_layers - 1):
            mlp_layers.append(nn.Linear(in_dim, hidden_dim))
            mlp_layers.append(nn.ReLU())
            if use_batchnorm:
                mlp_layers.append(nn.BatchNorm1d(hidden_dim))
            mlp_layers.append(nn.Dropout(p=dropout))
            in_dim = hidden_dim
        mlp_layers.append(nn.Linear(in_dim, num_classes))
        self.final_mlp = nn.Sequential(*mlp_layers)

    def forward(self, data):
        x, edge_index, edge_attr, batch, lattice = (
            data.x, data.edge_index, data.edge_attr, data.batch, data.lattice
        )

        edge_feat = self.edge_encoder(edge_attr)

        for i, conv in enumerate(self.gnn_layers):
            x = conv(x, edge_index, edge_feat)
            x = F.relu(x)
            if self.use_batchnorm:
                x = self.gnn_bns[i](x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x_pool = global_mean_pool(x, batch)

        lattice_flat = lattice.reshape(-1, 9)
        lattice_feat = self.lattice_encoder(lattice_flat)

        extra_feat = torch.cat([
            data.modified_scherrer,
            data.oms
        ], dim=1)
        extra_feat = self.extra_feat_encoder(extra_feat)

        out = torch.cat([x_pool, lattice_feat, extra_feat], dim=1)
        out = self.final_mlp(out)

        return out

def train(model, loader, criterion, optimizer, device, target_name):
    model.train()
    total_loss = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)

        # ✅ Converti one-hot target in indice classe
        target = torch.argmax(data[target_name], dim=1).long()

        loss = criterion(out, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


def evaluate(model, loader, device, target_name):
    model.eval()
    correct_top1 = 0
    correct_top3 = 0
    correct_top5 = 0
    correct_top10 = 0
    total = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data)
            
            labels = torch.argmax(data[target_name], dim=1).long()
            #labels = data[target_name]().long()
            total += labels.size(0)
            _, pred = out.topk(10, 1, True, True)

            correct_top1 += (pred[:, :1] == labels.view(-1, 1)).sum().item()
            correct_top3 += (pred[:, :3] == labels.view(-1, 1)).sum().item()
            correct_top5 += (pred[:, :5] == labels.view(-1, 1)).sum().item()
            correct_top10 += (pred[:, :10] == labels.view(-1, 1)).sum().item()

    return {
        "top1_acc": correct_top1 / total,
        "top3_acc": correct_top3 / total,
        "top5_acc": correct_top5 / total,
        "top10_acc": correct_top10 / total,
        "macro_f1": 0.0  # Calcolo F1 macro opzionale
    }

In [4]:
node_in_dim = data_in[0].x.shape[1]
edge_in_dim = data_in[0].edge_attr.shape[1]
lattice_in_dim = 9
extra_feat_dim = 2
hidden_dim = 32
dropout = 0.25

number_of_runs = [0,1,2,3,4]  # due seed come richiesto

### 3) Train the GNN

In [5]:
results = []


for seed in number_of_runs:
    config_name = f"HID{hidden_dim}_DO{dropout}_SEED{seed}_X_edgeAttr_lattice_modScherrer_Oms"
    print(f"\n===== Training config: {config_name} =====")

    set_seed(seed)

    model = MetalSaltGNN(
        node_in_dim, edge_in_dim, lattice_in_dim, extra_feat_dim,
        hidden_dim, num_classes,
        num_gnn_layers=4,
        num_lattice_layers=2,
        num_mlp_layers=3,
        dropout=dropout,
        use_batchnorm=True
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    checkpoint_name = f"tmp/Metal_salts_{config_name}.pt"
    patience = 50
    eval_every = 5
    best_metric = 0.0
    epochs_no_improve = 0

    for epoch in range(1, 1001):
        loss = train(model, train_loader, criterion, optimizer, device, "metal_salts")
        if epoch % eval_every == 0:
            res = evaluate(model, val_loader, device, "metal_salts")
            macro_top_k = res["top5_acc"]

            if macro_top_k > best_metric:
                best_metric = macro_top_k
                epochs_no_improve = 0
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_metric': best_metric
                }, checkpoint_name)
            else:
                epochs_no_improve += eval_every

            if epochs_no_improve >= patience:
                print(f"Early stopping at epoch {epoch} (no improvement for {patience} evals).")
                break

    checkpoint = torch.load(checkpoint_name, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    res_test = evaluate(model, test_loader, device, "metal_salts")

    results.append({
        'config': config_name,
        'top1_acc': res_test['top1_acc'],
        'top10_acc': res_test['top10_acc'],
        'top5_acc': res_test['top5_acc'],
        'top3_acc': res_test['top3_acc'],
        'macro_f1': res_test['macro_f1']
    })

    print(f"{config_name} TEST: top10_acc={res_test['top10_acc']:.4f}, top5_acc={res_test['top5_acc']:.4f}, top3_acc={res_test['top3_acc']:.4f}, macro_f1={res_test['macro_f1']:.4f}")




===== Training config: HID32_DO0.25_SEED0_X_edgeAttr_lattice_modScherrer_Oms =====
Early stopping at epoch 15 (no improvement for 5 evals).
HID32_DO0.25_SEED0_X_edgeAttr_lattice_modScherrer_Oms TEST: top10_acc=0.4218, top5_acc=0.3251, top3_acc=0.2953, macro_f1=0.0000

===== Training config: HID32_DO0.25_SEED1_X_edgeAttr_lattice_modScherrer_Oms =====


KeyboardInterrupt: 

### 4) Load and Evaluate the Model

In [None]:
results = []

for seed in number_of_runs:
    config_name = f"HID{hidden_dim}_DO{dropout}_SEED{seed}_X_edgeAttr_lattice_modScherrer_Oms"
    print(f"\n===== Evaluating config: {config_name} =====")
    

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    
    model = MetalSaltGNN(
        node_in_dim, edge_in_dim, lattice_in_dim, extra_feat_dim,
        hidden_dim, num_classes,
        num_gnn_layers=4,
        num_lattice_layers=2,
        num_mlp_layers=3,
        dropout=dropout,
        use_batchnorm=True
    ).to(device)
    
    checkpoint_name = f"tmp/Metal_salts_{config_name}.pt"
    # Carica best model e valuta su test
    checkpoint = torch.load(checkpoint_name, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    res_test = evaluate(model, test_loader, device, "metal_salts")

    results.append({
        'config': config_name,
        'top1_acc': res_test['top1_acc'],
        'top10_acc': res_test['top10_acc'],
        'top5_acc': res_test['top5_acc'],
        'top3_acc': res_test['top3_acc'],
        'macro_f1': res_test['macro_f1']
    })
    print(f"{config_name} TEST: top10_acc={res_test['top10_acc']:.4f}, top5_acc={res_test['top5_acc']:.4f}, top3_acc={res_test['top3_acc']:.4f}, macro_f1={res_test['macro_f1']:.4f}")


===== Evaluating config: HID32_DO0.25_SEED0_X_edgeAttr_lattice_modScherrer_Oms =====
HID32_DO0.25_SEED0_X_edgeAttr_lattice_modScherrer_Oms TEST: top10_acc=0.6625, top5_acc=0.5409, top3_acc=0.4665, macro_f1=0.0000

===== Evaluating config: HID32_DO0.25_SEED1_X_edgeAttr_lattice_modScherrer_Oms =====
HID32_DO0.25_SEED1_X_edgeAttr_lattice_modScherrer_Oms TEST: top10_acc=0.6650, top5_acc=0.5385, top3_acc=0.4615, macro_f1=0.0000

===== Evaluating config: HID32_DO0.25_SEED2_X_edgeAttr_lattice_modScherrer_Oms =====
HID32_DO0.25_SEED2_X_edgeAttr_lattice_modScherrer_Oms TEST: top10_acc=0.6898, top5_acc=0.5732, top3_acc=0.4764, macro_f1=0.0000

===== Evaluating config: HID32_DO0.25_SEED3_X_edgeAttr_lattice_modScherrer_Oms =====
HID32_DO0.25_SEED3_X_edgeAttr_lattice_modScherrer_Oms TEST: top10_acc=0.6675, top5_acc=0.5186, top3_acc=0.4541, macro_f1=0.0000

===== Evaluating config: HID32_DO0.25_SEED4_X_edgeAttr_lattice_modScherrer_Oms =====
HID32_DO0.25_SEED4_X_edgeAttr_lattice_modScherrer_Oms TEST

In [None]:
import pandas as pd
df = pd.DataFrame(results)


df['base_config'] = df['config'].str.replace(r'_SEED\d+', '', regex=True)

metrics = ['top1_acc','top10_acc','top5_acc','top3_acc','macro_f1']


grouped = df.groupby('base_config')[metrics].agg(['mean','std'])

for cfg, row in grouped.iterrows():
    print(f"\nConfig: {cfg}")
    for m in metrics:
        mean = row[(m,'mean')]
        std  = row[(m,'std')]
        print(f"{m} \t= {mean:.2f} ± {std:.2f}")


Config: HID32_DO0.25_X_edgeAttr_lattice_modScherrer_Oms
top1_acc 	= 0.31 ± 0.01
top10_acc 	= 0.67 ± 0.01
top5_acc 	= 0.54 ± 0.02
top3_acc 	= 0.46 ± 0.01
macro_f1 	= 0.00 ± 0.00
