In [1]:
from load_data import load_pyg_obj
import torch
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
import random
import numpy as np
import os

from utils import fix_target_shapes,remove_unused_onehot_columns,set_seed
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


# Load the data

In [2]:
# Load the dataset. NOTE: the dataset is not tracked in the git repo; download it
# into ../../data/mof_syncondition_data/ before running this cell.
data_in = load_pyg_obj(path_to_mdb="../../data/mof_syncondition_data/")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Normalise the target tensor shapes and drop one-hot columns no sample uses.
dataset = fix_target_shapes(data_in, "metal_salts")
dataset = remove_unused_onehot_columns(dataset, "metal_salts")

# The number of classes must equal the width of the target vector: `train` passes
# the full one-hot target to CrossEntropyLoss, which requires logits and target to
# have the same shape. Counting classes from the largest argmax index instead would
# under-count whenever the last column is never the argmax of any sample.
num_classes = dataset[0]["metal_salts"].shape[-1]
Y_size = max(torch.argmax(d["metal_salts"]).item() for d in dataset)  # largest class index observed
assert Y_size < num_classes, "metal_salts argmax exceeds the target width"

set_seed(seed=42)
input_dim = dataset[0].x.shape[1]

# Train / validation / test split (80/10/10). A dedicated, fixed-seed generator
# keeps the split identical across runs and kernel restarts.
BATCH_SIZE = 32
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(42)
)

# BatchNorm1d raises in training mode on a batch containing a single graph, so drop
# the final training batch only in the case where it would hold exactly one sample.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=(train_size % BATCH_SIZE == 1),
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# Define the GNN 

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINEConv, global_mean_pool
from collections import defaultdict
from sklearn.metrics import f1_score



### GNN architecture 

class MetalSaltGNN(nn.Module):
    """Graph classifier predicting a metal-salt class from a graph plus its lattice.

    Node features go through a stack of GINEConv layers that also receive encoded
    edge attributes. The node embeddings are mean-pooled per graph and concatenated
    with an MLP encoding of the flattened lattice. An MLP head then produces the
    class logits.

    Args:
        node_in_dim: Width of ``data.x``.
        edge_in_dim: Width of ``data.edge_attr``.
        lattice_in_dim: Number of lattice values per graph after flattening
            (9 for a 3x3 lattice matrix).
        hidden_dim: Width of every hidden representation.
        num_classes: Number of output logits.
        num_gnn_layers: Number of GINEConv layers.
        num_lattice_layers: Number of Linear layers in the lattice encoder.
        num_mlp_layers: Number of Linear layers in the classification head.
        dropout: Dropout probability after each GNN layer and inside the head.
        use_batchnorm: Whether to insert BatchNorm1d after hidden activations.
    """

    def __init__(
        self,
        node_in_dim,
        edge_in_dim,
        lattice_in_dim,
        hidden_dim,
        num_classes,
        num_gnn_layers=4,
        num_lattice_layers=2,
        num_mlp_layers=2,
        dropout=0.2,
        use_batchnorm=True
    ):
        super().__init__()

        self.use_batchnorm = use_batchnorm
        self.dropout = dropout
        # Remembered so forward() flattens the lattice to the width the encoder expects.
        self.lattice_in_dim = lattice_in_dim

        # Edge encoder: projects raw edge attributes to hidden_dim, matching the
        # edge_dim given to every GINEConv layer below.
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # GINEConv message-passing layers (edge-aware GIN), each with its own MLP.
        self.gnn_layers = nn.ModuleList()
        self.gnn_bns = nn.ModuleList() if use_batchnorm else None

        for i in range(num_gnn_layers):
            in_dim = node_in_dim if i == 0 else hidden_dim
            mlp = nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            # edge_dim tells GINEConv the width of the encoded edge features.
            self.gnn_layers.append(GINEConv(mlp, edge_dim=hidden_dim))
            if use_batchnorm:
                self.gnn_bns.append(nn.BatchNorm1d(hidden_dim))

        # Lattice encoder: configurable-depth MLP over the flattened lattice.
        lattice_layers = []
        in_dim = lattice_in_dim
        for _ in range(num_lattice_layers - 1):
            lattice_layers.append(nn.Linear(in_dim, hidden_dim))
            lattice_layers.append(nn.ReLU())
            if use_batchnorm:
                lattice_layers.append(nn.BatchNorm1d(hidden_dim))
            in_dim = hidden_dim
        lattice_layers.append(nn.Linear(in_dim, hidden_dim))
        self.lattice_encoder = nn.Sequential(*lattice_layers)

        # Classification head over [pooled node embedding, lattice embedding].
        mlp_layers = []
        in_dim = hidden_dim * 2  # x_pool + lattice_feat
        for _ in range(num_mlp_layers - 1):
            mlp_layers.append(nn.Linear(in_dim, hidden_dim))
            mlp_layers.append(nn.ReLU())
            if use_batchnorm:
                mlp_layers.append(nn.BatchNorm1d(hidden_dim))
            mlp_layers.append(nn.Dropout(p=dropout))
            in_dim = hidden_dim
        mlp_layers.append(nn.Linear(in_dim, num_classes))
        self.final_mlp = nn.Sequential(*mlp_layers)

    def forward(self, data):
        """Return class logits, one row per pooled graph.

        Args:
            data: Batched graph exposing ``x``, ``edge_index``, ``edge_attr``,
                ``batch`` and ``lattice``.

        Returns:
            Tensor of shape ``[num_pooled_graphs, num_classes]``.
        """
        x, edge_index, edge_attr, batch, lattice = (
            data.x, data.edge_index, data.edge_attr, data.batch, data.lattice
        )

        # Encode edge attributes once; every GNN layer reuses them.
        edge_feat = self.edge_encoder(edge_attr)

        # Message passing: conv -> ReLU -> (BatchNorm) -> dropout.
        for i, conv in enumerate(self.gnn_layers):
            x = conv(x, edge_index, edge_feat)
            x = F.relu(x)
            if self.use_batchnorm:
                x = self.gnn_bns[i](x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # Global mean pooling: one embedding per graph.
        x_pool = global_mean_pool(x, batch)

        # Flatten the lattice to lattice_in_dim values per graph.
        # NOTE(review): assumes the batched lattice holds exactly lattice_in_dim
        # values per graph (e.g. per-graph 3x3 matrices stacked row-wise) — confirm.
        lattice_flat = lattice.reshape(-1, self.lattice_in_dim)
        lattice_feat = self.lattice_encoder(lattice_flat)

        # Combine both views and classify.
        out = torch.cat([x_pool, lattice_feat], dim=1)
        out = self.final_mlp(out)

        return out


# Train and eval functions 


########################################## training and eval
# TRAINING FUNCTION
def train(model, loader, criterion, optimizer, device, target):
    """Run one training epoch and return the mean per-graph loss.

    Args:
        model: Module mapping a batch to logits of shape ``[num_graphs, num_classes]``.
        loader: Iterable of batches that also has a ``.dataset`` attribute
            (e.g. a PyG DataLoader).
        criterion: Loss called as ``criterion(logits, target)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to.
        target: Name of the target attribute: ``"metal_salts"`` or ``"ligands"``.

    Returns:
        Sum over batches of ``loss * num_graphs``, divided by ``len(loader.dataset)``.

    Raises:
        ValueError: If ``target`` is not a supported target name.
    """
    if target not in ("metal_salts", "ligands"):
        raise ValueError(
            f"specify target: expected 'metal_salts' or 'ligands', got {target!r}"
        )

    model.train()
    total_loss = 0.0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        # Flatten the target to one row per graph, then drop only a trailing
        # singleton dim (class-index targets stored as [B, 1]). Unlike a bare
        # .squeeze(), this never removes the batch dimension when a batch holds a
        # single graph, which would mismatch the [1, num_classes] logits.
        y = getattr(data, target).reshape(out.size(0), -1).squeeze(1)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(loader.dataset)


def evaluate(model, loader, device, target_name):
    """Compute top-k accuracies and macro-F1 of ``model`` over ``loader``.

    Args:
        model: Module mapping a batch to logits of shape ``[num_graphs, num_classes]``.
        loader: Iterable of batches. Each batch exposes ``num_graphs`` and a
            ``metal_salts`` tensor whose per-row argmax is the true class.
        device: Device each batch is moved to.
        target_name: Target to score; only ``"metal_salts"`` is supported.

    Returns:
        Dict with ``top1_acc``, ``top10_acc``, ``top3_acc`` and ``top5_acc``:
        the fraction of graphs whose true class is among the k highest logits.
        When there are fewer than k classes, every class counts, so the
        accuracy is 1. Also ``macro_f1``: the macro-averaged F1 of the top-1
        predictions.

    Raises:
        ValueError: If ``target_name`` is unsupported or ``loader`` yields no graphs.
    """
    if target_name != "metal_salts":
        raise ValueError(f"specify target: unsupported target_name {target_name!r}")

    model.eval()
    topk_values = (3, 5, 10)
    correct_top1 = 0
    correct_topk = {k: 0 for k in topk_values}
    total = 0

    all_preds = []
    all_targets = []

    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data)  # [batch_size, num_classes]
            num_classes = out.size(1)

            target = data.metal_salts.argmax(dim=1)  # [batch_size]
            pred_top1 = out.argmax(dim=1)  # [batch_size]

            # Keep top-1 predictions and targets for the macro-F1 computation.
            all_preds.append(pred_top1.cpu())
            all_targets.append(target.cpu())

            correct_top1 += (pred_top1 == target).sum().item()
            for k in topk_values:
                # topk raises when k exceeds the number of classes; clamping makes
                # top-k with k >= num_classes count every prediction as a hit.
                pred_topk = out.topk(min(k, num_classes), dim=1).indices  # [batch_size, k']
                correct_topk[k] += (pred_topk == target.unsqueeze(1)).any(dim=1).sum().item()

            total += data.num_graphs

    if total == 0:
        raise ValueError("evaluate() received a loader that yielded no graphs")

    # Concatenate all predictions and targets into flat arrays for scikit-learn.
    all_preds = torch.cat(all_preds).numpy()
    all_targets = torch.cat(all_targets).numpy()

    # Macro-F1 over the labels present in either predictions or targets.
    macro_f1 = f1_score(all_targets, all_preds, average='macro', zero_division=0)

    return {
        'top1_acc': correct_top1 / total,
        'top10_acc': correct_topk[10] / total,
        'top3_acc': correct_topk[3] / total,
        'top5_acc': correct_topk[5] / total,
        'macro_f1': macro_f1
    }



In [None]:
# Model inputs: node features (x), edge attributes (edge_attr) and the lattice.
# Input widths are read from the first loaded graph so they always match the data.
node_in_dim = data_in[0].x.shape[1]
edge_in_dim = data_in[0].edge_attr.shape[1]
lattice_in_dim = 9  # 3x3 lattice matrix, flattened
hidden_dim = 32
dropout = 0.25

# Seeds for the repeated training runs (one model per seed).
number_of_runs = [0, 1, 2, 3, 4]

In [None]:


results = []
os.makedirs("tmp", exist_ok=True)

# Early-stopping settings. Validation runs every `eval_every` epochs; `patience`
# is measured in epochs, not in evaluations.
max_epochs = 1000
patience = 50
eval_every = 5

for seed in number_of_runs:
    config_name = f"HID{hidden_dim}_DO{dropout}_SEED{seed}_X_edgeAttr_lattice"
    print(f"\n===== Training config: {config_name} =====")

    # Seed the global torch / NumPy / Python RNGs for this run.
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model = MetalSaltGNN(
        node_in_dim, edge_in_dim, lattice_in_dim, hidden_dim, num_classes,
        num_gnn_layers=4,
        num_lattice_layers=2,
        num_mlp_layers=3,
        dropout=dropout,
        use_batchnorm=True
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    checkpoint_name = f"tmp/Metal_salts_{config_name}.pt"

    # Start below any achievable accuracy so the first evaluation always writes a
    # checkpoint. With 0.0, a run whose validation top-5 never exceeded 0 would
    # save nothing. torch.load below would then read a stale file from an earlier
    # execution, or fail if none exists.
    best_metric = float("-inf")
    epochs_no_improve = 0

    for epoch in range(1, max_epochs + 1):
        loss = train(model, train_loader, criterion, optimizer, device, "metal_salts")
        if epoch % eval_every != 0:
            continue

        res = evaluate(model, val_loader, device, "metal_salts")
        val_top5 = res["top5_acc"]  # model-selection metric

        if val_top5 > best_metric:
            best_metric = val_top5
            epochs_no_improve = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_metric': best_metric
            }, checkpoint_name)
        else:
            epochs_no_improve += eval_every

        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch} (no improvement for {patience} epochs).")
            break

    # Reload the best checkpoint of this run and score it on the test split.
    checkpoint = torch.load(checkpoint_name, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    res_test = evaluate(model, test_loader, device, "metal_salts")

    # Record the test metrics for this seed.
    results.append({
        'config': config_name,
        'top1_acc': res_test['top1_acc'],
        'top10_acc': res_test['top10_acc'],
        'top5_acc': res_test['top5_acc'],
        'top3_acc': res_test['top3_acc'],
        'macro_f1': res_test['macro_f1']
    })
    print(f"{config_name} TEST: top10_acc={res_test['top10_acc']:.4f}, top5_acc={res_test['top5_acc']:.4f}, top3_acc={res_test['top3_acc']:.4f}, macro_f1={res_test['macro_f1']:.4f}")


===== Training config: HID32_DO0.25_SEED0_X_edgeAttr_lattice_tmp =====
Early stopping at epoch 30 (no improvement for 5 evals).
HID32_DO0.25_SEED0_X_edgeAttr_lattice_tmp TEST: top10_acc=0.5261, top5_acc=0.4268, top3_acc=0.3573, macro_f1=0.0243


# Load trained models and evaluate

In [12]:
# Rebuild each trained model from its saved checkpoint and score it on the test split.
results = []
model_kwargs = dict(
    num_gnn_layers=4,
    num_lattice_layers=2,
    num_mlp_layers=3,
    dropout=dropout,
    use_batchnorm=True,
)
logged_metrics = ('top1_acc', 'top10_acc', 'top5_acc', 'top3_acc', 'macro_f1')

for seed in number_of_runs:
    config_name = f"HID{hidden_dim}_DO{dropout}_SEED{seed}_X_edgeAttr_lattice"
    print(f"\n===== Evaluating config: {config_name} =====")

    # Reseed the same way the training loop does; the checkpoint loaded below
    # overwrites the freshly initialised weights anyway.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    model = MetalSaltGNN(
        node_in_dim, edge_in_dim, lattice_in_dim, hidden_dim, num_classes, **model_kwargs
    ).to(device)

    os.makedirs("tmp", exist_ok=True)
    checkpoint_name = f"tmp/Metal_salts_{config_name}.pt"
    checkpoint = torch.load(checkpoint_name, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    res_test = evaluate(model, test_loader, device, "metal_salts")

    # Record the test metrics for this seed (same key order as the training loop).
    results.append({'config': config_name, **{name: res_test[name] for name in logged_metrics}})
    print(f"{config_name} TEST: top10_acc={res_test['top10_acc']:.4f}, top5_acc={res_test['top5_acc']:.4f}, top3_acc={res_test['top3_acc']:.4f}, macro_f1={res_test['macro_f1']:.4f}")


===== Evaluating config: HID32_DO0.25_SEED0_X_edgeAttr_lattice =====
HID32_DO0.25_SEED0_X_edgeAttr_lattice TEST: top10_acc=0.6650, top5_acc=0.5757, top3_acc=0.4839, macro_f1=0.0813

===== Evaluating config: HID32_DO0.25_SEED1_X_edgeAttr_lattice =====
HID32_DO0.25_SEED1_X_edgeAttr_lattice TEST: top10_acc=0.6675, top5_acc=0.5459, top3_acc=0.4665, macro_f1=0.0622

===== Evaluating config: HID32_DO0.25_SEED2_X_edgeAttr_lattice =====
HID32_DO0.25_SEED2_X_edgeAttr_lattice TEST: top10_acc=0.6650, top5_acc=0.5533, top3_acc=0.4839, macro_f1=0.0677

===== Evaluating config: HID32_DO0.25_SEED3_X_edgeAttr_lattice =====
HID32_DO0.25_SEED3_X_edgeAttr_lattice TEST: top10_acc=0.6526, top5_acc=0.5310, top3_acc=0.4665, macro_f1=0.0660

===== Evaluating config: HID32_DO0.25_SEED4_X_edgeAttr_lattice =====
HID32_DO0.25_SEED4_X_edgeAttr_lattice TEST: top10_acc=0.6749, top5_acc=0.5484, top3_acc=0.4864, macro_f1=0.0635


In [19]:
import pandas as pd

# Summarise the per-seed test metrics as mean ± std for each configuration.
metrics = ['top1_acc', 'top10_acc', 'top5_acc', 'top3_acc', 'macro_f1']
df = pd.DataFrame(results)

# Strip the seed tag so runs that differ only in seed share one base configuration.
df['base_config'] = df['config'].str.replace(r'_SEED\d+', '', regex=True)

grouped = df.groupby('base_config')[metrics].agg(['mean', 'std'])

for cfg in grouped.index:
    print(f"\nConfig: {cfg}")
    for m in metrics:
        mean = grouped.loc[cfg, (m, 'mean')]
        std = grouped.loc[cfg, (m, 'std')]
        print(f"{m} \t= {mean:.2f} ± {std:.2f}")


Config: HID32_DO0.25_X_edgeAttr_lattice
top1_acc 	= 0.31 ± 0.02
top10_acc 	= 0.67 ± 0.01
top5_acc 	= 0.55 ± 0.02
top3_acc 	= 0.48 ± 0.01
macro_f1 	= 0.07 ± 0.01
