In [1]:
from load_data import load_pyg_obj
import torch
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split


import random
import numpy as np
from sklearn.metrics import f1_score
from utils import fix_target_shapes,remove_unused_onehot_columns,set_seed,filter_metals
from mofstructure import mofdeconstructor

from fairmofsyncondition.read_write.coords_library import pytorch_geometric_to_ase

from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

convert_struct = {'cubic':0, 'hexagonal':1, 'monoclinic':2, 'orthorhombic':3, 'tetragonal':4,'triclinic':5, 'trigonal':6}

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


# GNN Energy


### 1) Load the Data


### to get oms
### stru.get_oms()["has_oms"]


In [2]:
from mofstructure.structure import MOFstructure
from tqdm import tqdm
import os
convert_metals = {j:i for i,j in enumerate(mofdeconstructor.transition_metals()[1:])}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
set_seed(seed=42)

In [3]:
from load_data import LMDBDataset

def load_pyg_obj_energy(path_to_mdb = "mof_syncondition_data"):    
    data = []
    for d in LMDBDataset(lmdb_path=path_to_mdb):
        d.x = d.x.float()
        data.append(d)
    return data


In [4]:
dataset = load_pyg_obj_energy(path_to_mdb="../../data/stability_energy_data/")

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch_geometric.nn import GINEConv, global_mean_pool
from sklearn.metrics import f1_score
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import RobustScaler


# ==================== Utils & setup ====================


In [6]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def set_seed(seed=42):
    import random, os
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Filtra classi rare (come nel tuo codice)
Y = [d.energy for d in dataset]
a,b = np.unique(Y, return_counts=True)
conv_y = {i:j for i,j in zip(a,b)}



# Split & loaders
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(42)
)

y_train = np.array([d.energy for d in train_dataset])
y_val = np.array([d.energy for d in val_dataset])
y_test = np.array([d.energy for d in test_dataset])


q_low, q_high = np.percentile(y_train, [1, 99])  # o [0.5, 99.5]
def clip_like_train(arr):
    return np.clip(arr, q_low, q_high)

y_train_c = clip_like_train(y_train)
y_val_c  = clip_like_train(y_val)
y_test_c  = clip_like_train(y_test)

# 3) Fit scaler SOLO sul train
# scaler = StandardScaler()
scaler = RobustScaler(with_centering=True, with_scaling=True)  # più robusto agli outlier
scaler.fit(y_train_c.reshape(-1,1))

import pickle

# salva
with open("target_scaler_energy.pkl", "wb") as f:
    pickle.dump(scaler, f)
    

# 4) Trasforma y
y_train_scaled = scaler.transform(y_train_c.reshape(-1,1)).ravel()
y_val_scaled  = scaler.transform(y_val_c.reshape(-1,1)).ravel()
y_test_scaled  = scaler.transform(y_test_c.reshape(-1,1)).ravel()


def attach_scaled_energy(subset, scaled_values):
    """
    subset: torch.utils.data.Subset (o lista di Data)
    scaled_values: array 1D di float già scalati, stessa lunghezza del subset
    """
    for data, y in zip(subset, scaled_values):
        data.energy_scaled = torch.tensor(y, dtype=torch.float32)

# Aggiungi il nuovo campo a ciascun split
attach_scaled_energy(train_dataset, y_train_scaled)
attach_scaled_energy(val_dataset,   y_val_scaled)
attach_scaled_energy(test_dataset,  y_test_scaled)


train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader   = DataLoader(val_dataset,   batch_size=128)
test_loader  = DataLoader(test_dataset,  batch_size=128)

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINEConv, global_mean_pool
from sklearn.metrics import r2_score
import numpy as np

# ==================== Helpers extras (immutati) ====================

def _reshape_feat(tensor, d):
    if hasattr(d, "num_graphs"):  # Batch
        return tensor.view(d.num_graphs, -1)
    else:  # singolo Data
        return tensor.view(1, -1)

EXTRA_GETTERS = {
    "atomic_one_hot":      lambda d: _reshape_feat(d.atomic_one_hot, d),
    "space_group_number":  lambda d: _reshape_feat(d.space_group_number, d),
    "crystal_system":      lambda d: _reshape_feat(d.crystal_system, d),
    "oms":                 lambda d: _reshape_feat(d.oms, d),
    "cordinates":          lambda d: _reshape_feat(d.cordinates, d),  # (tenuto il nome come nel tuo dataset)
}

def compute_extras_dim(sample_data, selected_extras):
    dim = 0
    for name in selected_extras:
        if name not in EXTRA_GETTERS:
            raise ValueError(f"Feature extra sconosciuta: {name}")
        dim += EXTRA_GETTERS[name](sample_data).shape[1]
    return dim

def build_extras_tensor(data, selected_extras):
    if not selected_extras:
        return None
    parts = [EXTRA_GETTERS[name](data) for name in selected_extras]
    return torch.cat(parts, dim=1)

def extras_suffix(selected_extras):
    if not selected_extras:
        return "no_extras"
    return "_".join(selected_extras)

# ==================== Model (head -> regressione) ====================

class MetalSaltGNN_Ablation(nn.Module):
    def __init__(
        self,
        node_in_dim,
        edge_in_dim,
        lattice_in_dim=9,
        hidden_dim=128,
        num_gnn_layers=4,
        num_lattice_layers=2,
        num_mlp_layers=2,
        dropout=0.2,
        use_batchnorm=True,
        selected_extras=None,
        extras_dim=0
    ):
        super().__init__()
        self.use_batchnorm = use_batchnorm
        self.dropout = dropout
        self.selected_extras = selected_extras or []
        self.extras_dim = extras_dim

        # Edge encoder
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # GINE stack
        self.gnn_layers = nn.ModuleList()
        self.gnn_bns = nn.ModuleList() if use_batchnorm else None
        for i in range(num_gnn_layers):
            in_dim = node_in_dim if i == 0 else hidden_dim
            mlp = nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            )
            self.gnn_layers.append(GINEConv(mlp, edge_dim=hidden_dim))
            if use_batchnorm:
                self.gnn_bns.append(nn.BatchNorm1d(hidden_dim))

        # Lattice encoder
        lattice_layers = []
        in_dim = lattice_in_dim
        for _ in range(max(1, num_lattice_layers - 1)):
            lattice_layers += [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
            if use_batchnorm:
                lattice_layers.append(nn.BatchNorm1d(hidden_dim))
            in_dim = hidden_dim
        lattice_layers.append(nn.Linear(in_dim, hidden_dim))
        self.lattice_encoder = nn.Sequential(*lattice_layers)

        # Final MLP head -> output 1 per regressione
        final_in = hidden_dim * 2 + self.extras_dim
        mlp_layers = []
        in_dim = final_in
        for _ in range(max(1, num_mlp_layers - 1)):
            mlp_layers += [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
            if use_batchnorm:
                mlp_layers.append(nn.BatchNorm1d(hidden_dim))
            mlp_layers.append(nn.Dropout(p=dropout))
            in_dim = hidden_dim
        mlp_layers.append(nn.Linear(in_dim, 1))  # <--- regressione
        self.final_mlp = nn.Sequential(*mlp_layers)

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        e = self.edge_encoder(edge_attr)

        for i, conv in enumerate(self.gnn_layers):
            x = conv(x, edge_index, e)
            x = F.relu(x)
            if self.use_batchnorm:
                x = self.gnn_bns[i](x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x_pool = global_mean_pool(x, batch)

        lattice = data.lattice.view(-1, 9)
        lattice_feat = self.lattice_encoder(lattice)

        extras = build_extras_tensor(data, self.selected_extras)
        if extras is not None:
            final_in = torch.cat([x_pool, lattice_feat, extras], dim=1)
        else:
            final_in = torch.cat([x_pool, lattice_feat], dim=1)

        out = self.final_mlp(final_in)        # shape [B, 1]
        return out.squeeze(-1)                # shape [B]

# ==================== Train / Eval per regressione ====================

def train_one_epoch(model, loader, criterion, optimizer, device, target_name="energy_scaled"):
    """
    Atteso: data[target_name] contiene il target SCALATO come float (shape [B] o [B,1])
    """
    model.train()
    total_loss = 0.0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()

        preds = model(data)  # [B]
        target = data[target_name].view(-1).float()  # [B], scaled
        loss = criterion(preds, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / max(1, len(loader))

@torch.no_grad()
def evaluate(model, loader, device, target_name, scaler=None):
    """
    Ritorna metriche sia nello spazio SCALATO (loss) sia, se fornito uno 'scaler',
    nello spazio FISICO (MAE/RMSE/R2).
    - scaler: lo stesso usato per scalare il target (fit sul TRAIN). Deve supportare inverse_transform.
    """
    model.eval()
    losses = []
    y_true_scaled, y_pred_scaled = [], []
    for data in loader:
        data = data.to(device)
        preds = model(data)  # [B]
        target = data[target_name].view(-1).float()
        # se vuoi valutare la loss in scala del training (consigliato):
        loss_batch = F.l1_loss(preds, target, reduction='mean')  # stessa loss del training se usi MAE
        losses.append(loss_batch.item())

        y_true_scaled.append(target.cpu())
        y_pred_scaled.append(preds.cpu())

    y_true_scaled = torch.cat(y_true_scaled).numpy()
    y_pred_scaled = torch.cat(y_pred_scaled).numpy()

    results = {
        "val_loss_scaled": float(np.mean(losses)),  # ad es. MAE in scala z/robust
    }

    # Se hai lo scaler, riportiamo alle unità reali
    if scaler is not None:
        y_true = scaler.inverse_transform(y_true_scaled.reshape(-1,1)).ravel()
        y_pred = scaler.inverse_transform(y_pred_scaled.reshape(-1,1)).ravel()

        mae = np.mean(np.abs(y_pred - y_true))
        rmse = float(np.sqrt(np.mean((y_pred - y_true)**2)))
        r2 = float(r2_score(y_true, y_pred))

        results.update({
            "MAE": float(mae),
            "RMSE": rmse,
            "R2": r2,
        })
    return results

# ==================== Scelta loss/optimizer (esempio) ====================

# Per outlier, spesso conviene Huber (smussata) o MAE:
# criterion = nn.SmoothL1Loss(beta=1.0)  # Huber
# oppure
# criterion = nn.L1Loss()  # MAE, robusta agli outlier


In [8]:

# ==================== Run ablation ====================

node_in_dim = dataset[0].x.shape[1]
edge_in_dim = dataset[0].edge_attr.shape[1]
lattice_in_dim = 9

# Scegli qui le extra da usare nell’ablation:
# Esempio richiesto: ["oms", "atomic_one_hot"]

EXTRA_GETTERS = {
    "atomic_one_hot":      lambda d: _reshape_feat(d.atomic_one_hot, d),
    "space_group_number":  lambda d: _reshape_feat(d.space_group_number, d),
    "crystal_system":      lambda d: _reshape_feat(d.crystal_system, d),
    "oms":                 lambda d: _reshape_feat(d.oms, d),
    "cordinates":          lambda d: _reshape_feat(d.cordinates[0], d),
}


selected_extras = ["atomic_one_hot", "cordinates", "oms","space_group_number"]
# imoprtant: for saving
selected_extras = np.sort(selected_extras).tolist()
selected_extras

# Calcolo dinamico della dimensione delle extra
extras_dim = compute_extras_dim(dataset[0], selected_extras)

# Classi
number_of_runs = [1,2,3,4,5]
hidden_dim = 64
dropout = 0.35


In [11]:
results = []
suffix = extras_suffix(selected_extras)

for seed in number_of_runs:
    config_name = f"HID{hidden_dim}_DO{dropout}_SEED{seed}__{suffix}_energy"
    print(f"\n===== Training config: {config_name} =====")
    set_seed(seed)

    model = MetalSaltGNN_Ablation(
        node_in_dim=node_in_dim,
        edge_in_dim=edge_in_dim,
        lattice_in_dim=lattice_in_dim,
        hidden_dim=hidden_dim,
        num_gnn_layers=4,
        num_lattice_layers=2,
        num_mlp_layers=2,
        dropout=dropout,
        use_batchnorm=True,
        selected_extras=selected_extras,
        extras_dim=extras_dim).to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    checkpoint_name = f"trained_models/Metal_salts_{config_name}_tmp_test.pt"

    best_metric = np.inf
    epochs_no_improve = 0
    patience = 5
    eval_every = 1

    for epoch in range(1, 1001):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        if epoch % eval_every == 0:
            res = evaluate(model, val_loader, device, "energy_scaled",scaler=scaler)
            print(f"train loss= {train_loss:.4f} \t val loss={res['val_loss_scaled']:.4f}")

            # Early stopping su top5, come nel tuo codice
            if res["val_loss_scaled"] < best_metric:
                best_metric = res["val_loss_scaled"]
                epochs_no_improve = 0
                torch.save(model.state_dict(), checkpoint_name)
            else:
                epochs_no_improve += eval_every
            if epochs_no_improve >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    # Valutazione test con il best checkpoint
    model.load_state_dict(torch.load(checkpoint_name, map_location=device))
    res_test = evaluate(model, test_loader, device, "energy_scaled")
    results.append({**res_test, 'config': config_name})
    print(f"{config_name} TEST: loss={res_test['val_loss_scaled']:.4f}")



===== Training config: HID64_DO0.35_SEED1__atomic_one_hot_cordinates_oms_space_group_number_energy =====
train loss= -1028.3870 	 val loss=6.5364
train loss= -3183.6568 	 val loss=16.8352
train loss= -6157.3815 	 val loss=28.0092
train loss= -10381.3831 	 val loss=54.3120
train loss= -15868.2311 	 val loss=78.9346
train loss= -21176.3532 	 val loss=95.5799
Early stopping at epoch 6
HID64_DO0.35_SEED1__atomic_one_hot_cordinates_oms_space_group_number_energy TEST: loss=6.4494

===== Training config: HID64_DO0.35_SEED2__atomic_one_hot_cordinates_oms_space_group_number_energy =====
train loss= -950.8655 	 val loss=7.4508
train loss= -3031.6677 	 val loss=17.0024
train loss= -6310.9382 	 val loss=30.9897


KeyboardInterrupt: 

In [10]:
res_test

{'val_loss_scaled': 7.151827010241422}

In [None]:
# carica
with open("target_scaler_energy.pkl", "rb") as f:
    scaler = pickle.load(f)

def inverse_to_physical(y_pred_scaled):
    return scaler.inverse_transform(y_pred_scaled.reshape(-1,1)).ravel()

# Esempio PyTorch: loss su target scalato
criterion = torch.nn.L1Loss()  # MAE spesso più robusta di MSE per outlier

# y_batch_scaled: tensore target già scalato
# pred_scaled: output della rete
# loss = criterion(pred_scaled, y_batch_scaled)
