In [29]:
from load_data import load_pyg_obj
import torch
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split


import random
import numpy as np
from sklearn.metrics import f1_score
from utils import fix_target_shapes,remove_unused_onehot_columns,set_seed,filter_metals
from mofstructure import mofdeconstructor

from fairmofsyncondition.read_write.coords_library import pytorch_geometric_to_ase

from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

convert_struct = {'cubic':0, 'hexagonal':1, 'monoclinic':2, 'orthorhombic':3, 'tetragonal':4,'triclinic':5, 'trigonal':6}

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# GNN for Metal Salt Prediction

This code implements a **Graph Neural Network (GNN)** to predict metal salts, using:
- **Node Features**  
- **Edge Features**  
- **Lattice**   
- **Oms**
- **Atomic number**
---

## Code Structure

1. **Load the Data**  
   Import and prepare the dataset for use in the model.

2. **Define the GNN Model**  
   Define the neural network architecture (layers, activation functions, etc.).

3. **Train the Model**  
   - Train the model on the dataset.  
   - Save the trained GNN weights into the `tmp/` folder.  

   ⚠️ **Note**:  
   If you only want to **test the model** without re-training, you can **skip this section** and avoid running the training step.

4. **Load and Evaluate the Model**  
   - Load the trained model weights.  
   - Evaluate the model performance.  


### 1) Load the Data


### to get oms
### stru.get_oms()["has_oms"]


In [30]:
import os

os.listdir("../../data/")

['mof_syncondition_data']

In [None]:
from mofstructure.structure import MOFstructure
from tqdm import tqdm
import os
convert_metals = {j:i for i,j in enumerate(mofdeconstructor.transition_metals()[1:])}


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
set_seed(seed=42)

data_in = load_pyg_obj(path_to_mdb="../../data/mof_syncondition_data/")
dataset = fix_target_shapes(data_in, "metal_salts")
dataset = remove_unused_onehot_columns(dataset, "metal_salts")


bad = []
good = []

if os.path.exists("../../dataset_cleen_all_info.pt"):
    print("Loading precomputed dataset...")
    dataset = torch.load("../../dataset_cleen_all_info.pt")
else:
    print("Computing dataset with all info...")
    for d in tqdm(dataset):
        try:
            # =======================
            # Parte 1: atomic one-hot
            # =======================
            node_features = d.x.numpy()
            atom_num = node_features[:, 0].astype(int)
            a, b = np.unique(atom_num, return_counts=True)
            emb = torch.zeros(120)
            for aa, bb in zip(a, b):
                emb[aa] = bb
            d.atomic_one_hot = emb
            
            # =======================
            # Parte 2: struttura ASE
            # =======================
            ase_atoms = pytorch_geometric_to_ase(d)
            stru = MOFstructure(ase_atoms)
            pymat = AseAtomsAdaptor.get_structure(ase_atoms)
            
            # =======================
            # Parte 3: OMS
            # =======================
            emb = torch.zeros(96)
            tmp_dict = dict()
            for i in stru.get_oms()["metal_info"]:
                cord = i["coordination_number"]
                metal = i["metal"]

                if metal in tmp_dict:
                    if cord > tmp_dict[metal]:
                        tmp_dict[metal] = cord
                else:
                    tmp_dict[metal] = cord

            for i, j in tmp_dict.items():
                emb[convert_metals[i]] = j
            d.cordinates = emb
            
            # =======================
            # Parte 4: spazio e sistema cristallino
            # =======================
            sga = SpacegroupAnalyzer(pymat)
            space_group_number = sga.get_space_group_number()
            emb = torch.zeros(231)
            emb[space_group_number] = 1
            d.space_group_number = emb

            get_crystal_system = sga.get_crystal_system()
            emb = torch.zeros(7)
            emb[convert_struct[get_crystal_system]] = 1
            d.crystal_system = emb
            # =======================
            # Parte 5: altri attributi
            # =======================
            d.oms = d.oms.view(1, 1).float()

            ###################### no porosity is too long to compute
            #por = stru.get_porosity()
            #por = list(por.values())
            #d.porosity = torch.tensor(por)

            d.modified_scherrer = None
            d.microstrain = None

            # Se arrivo qui senza eccezioni → struttura buona
            good.append(d)
        except Exception:
            bad.append(d)
            continue
    torch.save(good, "../../dataset_cleen_all_info.pt")   # salva lista di Data

Loading precomputed dataset...


In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch_geometric.nn import GINEConv, global_mean_pool
from sklearn.metrics import f1_score

# ==================== Utils & setup ====================

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def set_seed(seed=42):
    import random, os
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Filtra classi rare (come nel tuo codice)
Y = [d.metal_salts.argmax(dim=1).item() for d in dataset]
a,b = np.unique(Y, return_counts=True)
conv_y = {i:j for i,j in zip(a,b)}
good = [d for d in dataset if conv_y[d.metal_salts.argmax(dim=1).item()] > 5]
dataset = good

print("There are classes", len(np.unique([d.metal_salts.argmax(dim=1).item() for d in dataset])), len(dataset))

# Split & loaders
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(42)
)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader   = DataLoader(val_dataset,   batch_size=128)
test_loader  = DataLoader(test_dataset,  batch_size=128)

# ==================== NEW: ablation helpers ====================

# 1) Getter per ciascuna feature extra (sempre reshaped a [B, D])
#    Aggiungi qui dentro eventuali nuove feature future.

def _reshape_feat(tensor, d):
    # Se è Batch (ha num_graphs)
    if hasattr(d, "num_graphs"):
        return tensor.view(d.num_graphs, -1)
    else:  # è un singolo Data
        return tensor.view(1, -1)

EXTRA_GETTERS = {
    "atomic_one_hot":      lambda d: _reshape_feat(d.atomic_one_hot, d),
    "space_group_number":  lambda d: _reshape_feat(d.space_group_number, d),
    "crystal_system":      lambda d: _reshape_feat(d.crystal_system, d),
    "oms":                 lambda d: _reshape_feat(d.oms, d),
    "cordinates":          lambda d: _reshape_feat(d.cordinates, d),
}
def compute_extras_dim(sample_data, selected_extras):
    dim = 0
    for name in selected_extras:
        if name not in EXTRA_GETTERS:
            raise ValueError(f"Feature extra sconosciuta: {name}")
        dim += EXTRA_GETTERS[name](sample_data).shape[1]
    return dim

def build_extras_tensor(data, selected_extras):
    if not selected_extras:
        return None
    parts = [EXTRA_GETTERS[name](data) for name in selected_extras]
    return torch.cat(parts, dim=1)

def extras_suffix(selected_extras):
    if not selected_extras:
        return "no_extras"
    return "_".join(selected_extras)

# ==================== Model ====================

class MetalSaltGNN_Ablation(nn.Module):
    def __init__(
        self,
        node_in_dim,
        edge_in_dim,
        lattice_in_dim=9,
        hidden_dim=128,
        num_classes=10,
        num_gnn_layers=4,
        num_lattice_layers=2,
        num_mlp_layers=2,
        dropout=0.2,
        use_batchnorm=True,
        selected_extras=None,      # NEW: lista di nomi feature extra
        extras_dim=0               # NEW: dimensione totale delle extra
    ):
        super().__init__()
        self.use_batchnorm = use_batchnorm
        self.dropout = dropout
        self.selected_extras = selected_extras or []
        self.extras_dim = extras_dim

        # --- Edge encoder (per GINE)
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # --- GINE layers
        self.gnn_layers = nn.ModuleList()
        self.gnn_bns = nn.ModuleList() if use_batchnorm else None
        for i in range(num_gnn_layers):
            in_dim = node_in_dim if i == 0 else hidden_dim
            mlp = nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            )
            self.gnn_layers.append(GINEConv(mlp, edge_dim=hidden_dim))
            if use_batchnorm:
                self.gnn_bns.append(nn.BatchNorm1d(hidden_dim))

        # --- Lattice encoder
        lattice_layers = []
        in_dim = lattice_in_dim
        for _ in range(max(1, num_lattice_layers - 1)):
            lattice_layers += [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
            if use_batchnorm:
                lattice_layers.append(nn.BatchNorm1d(hidden_dim))
            in_dim = hidden_dim
        lattice_layers.append(nn.Linear(in_dim, hidden_dim))
        self.lattice_encoder = nn.Sequential(*lattice_layers)

        # --- Final MLP head
        final_in = hidden_dim * 2 + self.extras_dim  # graph pooled + lattice + extras
        mlp_layers = []
        in_dim = final_in
        for _ in range(max(1, num_mlp_layers - 1)):
            mlp_layers += [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
            if use_batchnorm:
                mlp_layers.append(nn.BatchNorm1d(hidden_dim))
            mlp_layers.append(nn.Dropout(p=dropout))
            in_dim = hidden_dim
        mlp_layers.append(nn.Linear(in_dim, num_classes))
        self.final_mlp = nn.Sequential(*mlp_layers)

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        # Encode edges
        e = self.edge_encoder(edge_attr)

        # GNN layers
        for i, conv in enumerate(self.gnn_layers):
            x = conv(x, edge_index, e)
            x = F.relu(x)
            if self.use_batchnorm:
                x = self.gnn_bns[i](x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # Global pooling
        x_pool = global_mean_pool(x, batch)

        # Lattice encoding (sempre usato)
        lattice = data.lattice.view(-1, 9)
        lattice_feat = self.lattice_encoder(lattice)

        # Extras (abilitate in base alla lista)
        extras = build_extras_tensor(data, self.selected_extras)
        if extras is not None:
            final_in = torch.cat([x_pool, lattice_feat, extras], dim=1)
        else:
            final_in = torch.cat([x_pool, lattice_feat], dim=1)

        out = self.final_mlp(final_in)
        return out

# ==================== Train / Eval helpers ====================

def train_one_epoch(model, loader, criterion, optimizer, device, target_name):
    model.train()
    total_loss = 0.0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        target = torch.argmax(data[target_name], dim=1).long()
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / max(1, len(loader))

@torch.no_grad()
def evaluate(model, loader, device, target_name):
    model.eval()
    correct = {1: 0, 3: 0, 5: 0, 10: 0}
    total = 0
    all_preds, all_labels = [], []

    for data in loader:
        data = data.to(device)
        logits = model(data)
        labels = torch.argmax(data[target_name], dim=1).long()
        total += labels.size(0)

        _, pred = logits.topk(10, dim=1)
        for k in correct.keys():
            correct[k] += (pred[:, :k] == labels.view(-1, 1)).any(dim=1).sum().item()

        top1_preds = torch.argmax(logits, dim=1)
        all_preds.append(top1_preds.cpu())
        all_labels.append(labels.cpu())

    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    macro_f1 = f1_score(all_labels.numpy(), all_preds.numpy(), average="macro")

    results = {f"top{k}_acc": correct[k] / max(1, total) for k in correct}
    results["macro_f1"] = macro_f1
    return results

There are classes 122 3054


### 2) Define the GNN 

In [24]:

# ==================== Run ablation ====================

node_in_dim = dataset[0].x.shape[1]
edge_in_dim = dataset[0].edge_attr.shape[1]
lattice_in_dim = 9

# Scegli qui le extra da usare nell’ablation:
# Esempio richiesto: ["oms", "atomic_one_hot"]


EXTRA_GETTERS = {
    "atomic_one_hot":      lambda d: _reshape_feat(d.atomic_one_hot, d),
    "space_group_number":  lambda d: _reshape_feat(d.space_group_number, d),
    "crystal_system":      lambda d: _reshape_feat(d.crystal_system, d),
    "oms":                 lambda d: _reshape_feat(d.oms, d),
    "cordinates":          lambda d: _reshape_feat(d.cordinates, d),
}

selected_extras = []

selected_extras = ["atomic_one_hot"]
selected_extras = ["cordinates"]
selected_extras = ["crystal_system"]
selected_extras = ["oms"]
selected_extras = ["space_group_number"]


selected_extras = ["atomic_one_hot","cordinates"]
selected_extras = ["atomic_one_hot", "crystal_system"]
selected_extras = ["atomic_one_hot", "oms"]
selected_extras = ["atomic_one_hot", "space_group_number"]



selected_extras = ["atomic_one_hot", "oms", "cordinates"]
selected_extras = ["atomic_one_hot", "oms", "crystal_system"]
selected_extras = ["atomic_one_hot", "oms", "space_group_number"]



selected_extras = ["atomic_one_hot", "oms", "cordinates", "crystal_system"]
selected_extras = ["atomic_one_hot", "oms", "cordinates", "space_group_number"]
selected_extras = ["atomic_one_hot", "cordinates","crystal_system", "oms","space_group_number"]
# imoprtant: for saving
selected_extras = np.sort(selected_extras).tolist()
selected_extras

['atomic_one_hot', 'cordinates', 'crystal_system', 'oms', 'space_group_number']

In [25]:
# Calcolo dinamico della dimensione delle extra
extras_dim = compute_extras_dim(dataset[0], selected_extras)

# Classi
Y_size = max([torch.argmax(d["metal_salts"]).item() for d in dataset])
num_classes = Y_size + 1

number_of_runs = [1,2,3,4,5]
hidden_dim = 64
dropout = 0.35

results = []
suffix = extras_suffix(selected_extras)

In [27]:
num_classes,extras_dim

(666, 455)

In [None]:

for seed in number_of_runs:
    config_name = f"HID{hidden_dim}_DO{dropout}_SEED{seed}__{suffix}"
    print(f"\n===== Training config: {config_name} =====")
    set_seed(seed)

    model = MetalSaltGNN_Ablation(
        node_in_dim=node_in_dim,
        edge_in_dim=edge_in_dim,
        lattice_in_dim=lattice_in_dim,
        hidden_dim=hidden_dim,
        num_classes=num_classes,
        num_gnn_layers=4,
        num_lattice_layers=2,
        num_mlp_layers=2,
        dropout=dropout,
        use_batchnorm=True,
        selected_extras=selected_extras,
        extras_dim=extras_dim
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    checkpoint_name = f"trained_models/Metal_salts_{config_name}.pt"

    best_metric = 0.0
    epochs_no_improve = 0
    patience = 50
    eval_every = 5

    for epoch in range(1, 1001):
        _ = train_one_epoch(model, train_loader, criterion, optimizer, device, "metal_salts")
        if epoch % eval_every == 0:
            res = evaluate(model, val_loader, device, "metal_salts")
            print(f"VAL: top1_acc={res['top1_acc']:.4f}, top5_acc={res['top5_acc']:.4f}, top3_acc={res['top3_acc']:.4f} macro_f1={res['macro_f1']:.4f}")

            # Early stopping su top5, come nel tuo codice
            if res["top5_acc"] > best_metric:
                best_metric = res["top5_acc"]
                epochs_no_improve = 0
                torch.save(model.state_dict(), checkpoint_name)
            else:
                epochs_no_improve += eval_every
            if epochs_no_improve >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    # Valutazione test con il best checkpoint
    model.load_state_dict(torch.load(checkpoint_name, map_location=device))
    res_test = evaluate(model, test_loader, device, "metal_salts")
    results.append({**res_test, 'config': config_name})
    print(f"{config_name} TEST: top1_acc={res_test['top1_acc']:.4f}, top5_acc={res_test['top5_acc']:.4f}, top3_acc={res_test['top3_acc']:.4f} macro_f1={res_test['macro_f1']:.4f}")



===== Training config: HID64_DO0.35_SEED1__atomic_one_hot_cordinates_oms_space_group_number =====
VAL: top1_acc=0.4525, top5_acc=0.6984, top3_acc=0.6328 macro_f1=0.1590
VAL: top1_acc=0.4984, top5_acc=0.7738, top3_acc=0.6918 macro_f1=0.1921
VAL: top1_acc=0.4984, top5_acc=0.7902, top3_acc=0.7016 macro_f1=0.1935
VAL: top1_acc=0.5049, top5_acc=0.8164, top3_acc=0.7082 macro_f1=0.2104
VAL: top1_acc=0.5148, top5_acc=0.8066, top3_acc=0.6984 macro_f1=0.2369
VAL: top1_acc=0.5180, top5_acc=0.7967, top3_acc=0.6984 macro_f1=0.2564
VAL: top1_acc=0.5180, top5_acc=0.8131, top3_acc=0.7148 macro_f1=0.2476
VAL: top1_acc=0.5279, top5_acc=0.8098, top3_acc=0.7246 macro_f1=0.2424
VAL: top1_acc=0.5213, top5_acc=0.8131, top3_acc=0.7213 macro_f1=0.2518
VAL: top1_acc=0.5213, top5_acc=0.8230, top3_acc=0.7246 macro_f1=0.2398
VAL: top1_acc=0.5377, top5_acc=0.8262, top3_acc=0.7311 macro_f1=0.2573
VAL: top1_acc=0.5148, top5_acc=0.8066, top3_acc=0.7082 macro_f1=0.2572
VAL: top1_acc=0.5311, top5_acc=0.8197, top3_acc=0

# evalm

In [7]:
results = []
for seed in number_of_runs:
    config_name = f"HID{hidden_dim}_DO{dropout}_SEED{seed}__{suffix}"
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model = MetalSaltGNN_Ablation(
        node_in_dim=node_in_dim,
        edge_in_dim=edge_in_dim,
        lattice_in_dim=lattice_in_dim,
        hidden_dim=hidden_dim,
        num_classes=num_classes,
        num_gnn_layers=4,
        num_lattice_layers=2,
        num_mlp_layers=2,
        dropout=dropout,
        use_batchnorm=True,
        selected_extras=selected_extras,
        extras_dim=extras_dim
    ).to(device)

    checkpoint_name = f"trained_models/Metal_salts_{config_name}.pt"
    checkpoint = torch.load(checkpoint_name, map_location=device)
    model.load_state_dict(checkpoint)

    res_test = evaluate(model, test_loader, device, "metal_salts")

    results.append({
        'config': config_name,
        'top1_acc':  res_test['top1_acc'],
        'top3_acc':  res_test['top3_acc'],
        'top5_acc':  res_test['top5_acc'],
        'top10_acc': res_test['top10_acc'],
        'macro_f1':  res_test['macro_f1'],
    })
    
import numpy as np

# results è la tua lista di dict
metrics = ["top1_acc","top3_acc","top5_acc","top10_acc","macro_f1"]

print("==== Media ± Std sulle run ====\n")
for metric in metrics:
    values = [r[metric] for r in results]
    mean = np.mean(values)
    std  = np.std(values)
    print(f"{metric}: {mean:.4f} ± {std:.4f}")

==== Media ± Std sulle run ====

top1_acc: 0.4654 ± 0.0094
top3_acc: 0.6850 ± 0.0177
top5_acc: 0.7797 ± 0.0094
top10_acc: 0.8993 ± 0.0084
macro_f1: 0.2902 ± 0.0066


# real word experiment 

In [8]:
import os
from mofstructure import structure, mofdeconstructor
from mofstructure.filetyper import load_iupac_names
from fairmofsyncondition.read_write import cheminfo2iupac, coords_library, filetyper
from ase.data import atomic_numbers
from ase.io import read
import torch


inchi_corrector = {
    "FDTQOZYHNDMJCM-UHFFFAOYSA-N":"benzene-1,4-dicarboxylic acid"
}
def get_ligand_iupacname(ligand_inchi):
    print(ligand_inchi)
    name = load_iupac_names().get(ligand_inchi, None)
    if name is None:
        pubchem = cheminfo2iupac.pubchem_to_inchikey(ligand_inchi, name='inchikey')
        if pubchem is None:
            name = inchi_corrector.get(ligand_inchi, ligand_inchi)
        else:
            pubchem.get('iupac_name', ligand_inchi)
    return name


def load_system(filename):
    """
    A function to extract
    """
    data = {}
    os.makedirs('LigandsXYZ', exist_ok=True)
    ase_data = read(filename)
    structure_data = structure.MOFstructure(ase_atoms=ase_data)
    _, ligands = structure_data.get_ligands()

    inchikeys = [ligand.info.get('inchikey') for ligand in ligands]
    for inchi, ligand in zip(inchikeys, ligands):
        ligand.write(f'LigandsXYZ/{inchi}.xyz')
    ligands_names = [get_ligand_iupacname(i) for i in inchikeys]
    general = structure_data.get_oms()
    oms = general.get('has_oms')
    metal_symbols = general.get('metals')
    metals_atomic_number = [atomic_numbers[i] for i in metal_symbols]
    torch_data = coords_library.ase_to_pytorch_geometric(ase_data)
    oms = torch.tensor([1 if  oms else 0], dtype=torch.int16)
    torch_data.oms = oms
    data['general'] = general
    return torch_data, metals_atomic_number, inchikeys, ligands_names

# load cif structure

In [9]:
torch_data, metal_atomic_number, inch, lig = load_system("EDUSIF.cif")

  #1 :Accepted unusual valence(s): Zn(1); Metal was disconnected
  #1 :Accepted unusual valence(s): Zn(1); Metal was disconnected
  #1 :Accepted unusual valence(s): O(1)
  #1 :Accepted unusual valence(s): O(1)


FDTQOZYHNDMJCM-UHFFFAOYSA-N


  with zopen(filename, mode=mode) as file:


In [10]:
torch_data

Data(x=[424, 4], edge_index=[2, 512], edge_attr=[512, 1], lattice=[3, 3], oms=[1])

In [None]:
torch_data, metal_atomic_number, inch, lig = load_system("Zn2C43N6H29O8.cif")

In [50]:
torch_data, metal_atomic_number, inch, lig = load_system("ZnC5HO5.cif")

  #1 :Accepted unusual valence(s): O(1)
  #1 :Accepted unusual valence(s): O(1)


RESFABUZNSNZRT-UHFFFAOYSA-N


2025-09-07 15:25:37,259 - INFO - 'PUGREST.NotFound: No CID found that matches the given InChI key'
  with zopen(filename, mode=mode) as file:


In [53]:
lig

['RESFABUZNSNZRT-UHFFFAOYSA-N']

# add additional info
## atomic one-hot oms cordinates crystal_system space_group_number 

In [None]:
d = torch_data

# =======================
# Parte 1: atomic one-hot
# =======================
node_features = d.x.numpy()
atom_num = node_features[:, 0].astype(int)
a, b = np.unique(atom_num, return_counts=True)
emb = torch.zeros(120)
for aa, bb in zip(a, b):
    emb[aa] = bb
d.atomic_one_hot = emb



# =======================
# Parte 2: struttura ASE
# =======================
ase_atoms = pytorch_geometric_to_ase(d)
stru = MOFstructure(ase_atoms)
pymat = AseAtomsAdaptor.get_structure(ase_atoms)

# =======================
# Parte 3: OMS
# =======================
emb = torch.zeros(96)
tmp_dict = dict()
for i in stru.get_oms()["metal_info"]:
    cord = i["coordination_number"]
    metal = i["metal"]

    if metal in tmp_dict:
        if cord > tmp_dict[metal]:
            tmp_dict[metal] = cord
    else:
        tmp_dict[metal] = cord

for i, j in tmp_dict.items():
    emb[convert_metals[i]] = j
d.cordinates = emb

# =======================
# Parte 4: spazio e sistema cristallino
# =======================
sga = SpacegroupAnalyzer(pymat)
space_group_number = sga.get_space_group_number()
emb = torch.zeros(231)
emb[space_group_number] = 1
d.space_group_number = emb

get_crystal_system = sga.get_crystal_system()
emb = torch.zeros(7)
emb[convert_struct[get_crystal_system]] = 1
d.crystal_system = emb
# =======================
# Parte 5: altri attributi
# =======================
d.oms = d.oms.view(1, 1).float()


  return int(self._space_group_data["number"])


In [28]:
d.oms

tensor([[0.]], device='cuda:0')

In [16]:
kept = []
pred = model(d.to(device))
for i in pred.argsort()[0]:
    a = filetyper.category_names()["metal_salts"][i.item()]
    kept.append(a)

In [None]:
# prepare a file showing the names of metal_salts we removed.

# create a latex table in the ESA

666

In [None]:
pred = model(d.to(device))
for i in pred.argsort()[0][0:10]:
    a = filetyper.category_names()["metal_salts"][i.item()]
    print(a)

CdBr2
CdCl2·2.5H2O
[Mn3(μ3-O)(O2CPh)6(py)2(H2O)]
Ni–H4TCPP
CuSCN
InC20H28NO13
K2O
CrCl3.6H2O
CuCl2·4H2O
YbCl3·6H2O


In [19]:
extras_dim

448

In [56]:
pred = model(d.to(device))
prx = "Zn"

c = 0
for i in pred.argsort()[0]:
    a = filetyper.category_names()["metal_salts"][i.item()]
    if a[:2] == prx:
        print(a)
        c = c +1  
    if c == 5:
        break

ZnBr2
ZnC34H28N6O5
Zn(OAc)2·3H2O
ZnEt2
Zn(OAc)2·6H2O


In [None]:
inch

In [None]:
pred.argmax()