In [1]:
from load_data import load_pyg_obj
import torch
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
import random
import numpy as np
from sklearn.metrics import f1_score
from utils import fix_target_shapes,remove_unused_onehot_columns,set_seed,filter_metals
from mofstructure import mofdeconstructor

from fairmofsyncondition.read_write.coords_library import pytorch_geometric_to_ase

from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

convert_struct = {'cubic':0, 'hexagonal':1, 'monoclinic':2, 'orthorhombic':3, 'tetragonal':4,'triclinic':5, 'trigonal':6}

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


# GNN for Metal Salt Prediction

This code implements a **Graph Neural Network (GNN)** to predict metal salts, using:
- **Node Features**  
- **Edge Features**  
- **Lattice**   
- **Oms**
- **Atomic number**
---

## Code Structure

1. **Load the Data**  
   Import and prepare the dataset for use in the model.

2. **Define the GNN Model**  
   Define the neural network architecture (layers, activation functions, etc.).

3. **Train the Model**  
   - Train the model on the dataset.  
   - Save the trained GNN weights into the `tmp/` folder.  

   ⚠️ **Note**:  
   If you only want to **test the model** without re-training, you can **skip this section** and avoid running the training step.

4. **Load and Evaluate the Model**  
   - Load the trained model weights.  
   - Evaluate the model performance.  


### 1) Load the Data


### to get oms
### stru.get_oms()["has_oms"]


In [2]:
# from mofstructure.structure import MOFstructure
# convert_metals = {j:i for i,j in enumerate(mofdeconstructor.transition_metals()[1:])}


# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# set_seed(seed=42)

# data_in = load_pyg_obj(path_to_mdb="../../data/mof_syncondition_data/")
# dataset = fix_target_shapes(data_in, "metal_salts")
# dataset = remove_unused_onehot_columns(dataset, "metal_salts")


# bad = []
# good = []

# c = 0
# for d in dataset:
#     if c % 100 == 0:
#         print(c)
#     c = c + 1

#     try:
#         # =======================
#         # Parte 1: atomic one-hot
#         # =======================
#         node_features = d.x.numpy()
#         atom_num = node_features[:, 0].astype(int)
#         a, b = np.unique(atom_num, return_counts=True)
#         emb = torch.zeros(120)
#         for aa, bb in zip(a, b):
#             emb[aa] = bb
#         d.atomic_one_hot = emb

#         # =======================
#         # Parte 2: struttura ASE
#         # =======================
#         ase_atoms = pytorch_geometric_to_ase(d)
#         stru = MOFstructure(ase_atoms)
#         pymat = AseAtomsAdaptor.get_structure(ase_atoms)

#         # =======================
#         # Parte 3: OMS
#         # =======================
#         emb = torch.zeros(96)
#         tmp_dict = dict()
#         for i in stru.get_oms()["metal_info"]:
#             cord = i["coordination_number"]
#             metal = i["metal"]

#             if metal in tmp_dict:
#                 if cord > tmp_dict[metal]:
#                     tmp_dict[metal] = cord
#             else:
#                 tmp_dict[metal] = cord

#         for i, j in tmp_dict.items():
#             emb[convert_metals[i]] = j
#         d.cordinates = emb

#         # =======================
#         # Parte 4: spazio e sistema cristallino
#         # =======================
#         sga = SpacegroupAnalyzer(pymat)
#         space_group_number = sga.get_space_group_number()
#         emb = torch.zeros(231)
#         emb[space_group_number] = 1
#         d.space_group_number = emb

#         get_crystal_system = sga.get_crystal_system()
#         emb = torch.zeros(7)
#         emb[convert_struct[get_crystal_system]] = 1
#         d.crystal_system = emb

#         # =======================
#         # Parte 5: altri attributi
#         # =======================
#         d.oms = d.oms.view(1, 1).float()

#         por = stru.get_porosity()
#         por = list(por.values())
#         d.porosity = torch.tensor(por)

#         d.modified_scherrer = None
#         d.microstrain = None

#         # Se arrivo qui senza eccezioni → struttura buona
#         good.append(d)

#     except Exception:
#         bad.append(d)
#         continue


# import torch

# torch.save(good, "dataset_cleen_all_info.pt")   # salva lista di Data
# # poi
# dataset = torch.load("dataset_cleen_all_info.pt")

In [11]:
import os

os.listdir("../../")

['new_data',
 'dinga',
 'fairmofsyncondition',
 'officialGIT',
 'Interactive_tool.png',
 'new_data_v2',
 'dataset_cleen_all_info.pt',
 'test_data.lmdb',
 'GIT',
 'disegno-1.svg']

In [13]:
from mofstructure.structure import MOFstructure
convert_metals = {j:i for i,j in enumerate(mofdeconstructor.transition_metals()[1:])}


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
set_seed(seed=42)
dataset = torch.load("../../dataset_cleen_all_info.pt")



Y = []
for d in dataset:
    Y.append(d.metal_salts.argmax(dim=1).item())
    
a,b = np.unique(Y,return_counts=True)
conv_y = {i:j for i,j in zip(a,b)}

good = []
for d in dataset:
    if conv_y[d.metal_salts.argmax(dim=1).item()] > 5:
        good.append(d)
len(good)

Y2 = []
for d in good:
    Y2.append(d.metal_salts.argmax(dim=1).item())
    
    
print("There are classess",len(np.unique(Y2)), len(Y2))


dataset = good

# Train/val/test split
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(42)
)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128)
test_loader = DataLoader(test_dataset, batch_size=128)

There are classess 122 3054


### 2) Define the GNN 

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINEConv, global_mean_pool,GINConv


# ==================== Model ====================

class MetalSaltGNN_v2(nn.Module):
    def __init__(
        self,
        node_in_dim=4,     # da x.shape[1]
        edge_in_dim=1,     # da edge_attr.shape[1]
        lattice_in_dim=9,  # 3x3 flatten
        hidden_dim=128,
        num_classes=10,
        num_gnn_layers=4,
        num_lattice_layers=2,
        num_mlp_layers=2,
        dropout=0.2,
        use_batchnorm=True
    ):
        super().__init__()
        self.use_batchnorm = use_batchnorm
        self.dropout = dropout

        # --- Edge encoder (per GINE)
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # --- GINE layers
        self.gnn_layers = nn.ModuleList()
        self.gnn_bns = nn.ModuleList() if use_batchnorm else None
        for i in range(num_gnn_layers):
            in_dim = node_in_dim if i == 0 else hidden_dim
            mlp = nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            )
            self.gnn_layers.append(GINEConv(mlp, edge_dim=hidden_dim))
            if use_batchnorm:
                self.gnn_bns.append(nn.BatchNorm1d(hidden_dim))

        # --- Lattice encoder
        lattice_layers = []
        in_dim = lattice_in_dim
        for _ in range(max(1, num_lattice_layers - 1)):
            lattice_layers += [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
            if use_batchnorm:
                lattice_layers.append(nn.BatchNorm1d(hidden_dim))
            in_dim = hidden_dim
        lattice_layers.append(nn.Linear(in_dim, hidden_dim))
        self.lattice_encoder = nn.Sequential(*lattice_layers)

        # --- Final MLP head
        # Extra graph-level features
        extra_dim = 120 + 231 + 7 + 1 + 8 + 96
        final_in = hidden_dim * 2 + extra_dim

        mlp_layers = []
        in_dim = final_in
        for _ in range(max(1, num_mlp_layers - 1)):
            mlp_layers += [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
            if use_batchnorm:
                mlp_layers.append(nn.BatchNorm1d(hidden_dim))
            mlp_layers.append(nn.Dropout(p=dropout))
            in_dim = hidden_dim
        mlp_layers.append(nn.Linear(in_dim, num_classes))
        self.final_mlp = nn.Sequential(*mlp_layers)

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        lattice = data.lattice

        # Per-graph extras
        atomic_oh = data.atomic_one_hot.view(-1, 120)
        sg_oh = data.space_group_number.view(-1, 231)
        cs_oh = data.crystal_system.view(-1, 7)
        oms = data.oms.view(-1, 1)
        porosity = data.porosity.view(-1, 8)
        coords = data.cordinates.view(-1, 96)

        # Encode edges
        e = self.edge_encoder(edge_attr)

        # GNN layers
        for i, conv in enumerate(self.gnn_layers):
            x = conv(x, edge_index, e)
            x = F.relu(x)
            if self.use_batchnorm:
                x = self.gnn_bns[i](x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # Global pooling
        x_pool = global_mean_pool(x, batch)

        # Lattice encoding
        lattice = lattice.view(-1, 9)
        lattice_feat = self.lattice_encoder(lattice)

        # Concatenate everything
        extras = torch.cat([atomic_oh, sg_oh, cs_oh, oms, porosity, coords], dim=1)
        final_in = torch.cat([x_pool, lattice_feat, extras], dim=1)

        out = self.final_mlp(final_in)
        return out


# ==================== Train / Eval helpers ====================
def train(model, loader, criterion, optimizer, device, target_name):
    model.train()
    total_loss = 0.0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        # target expected one-hot per-graph -> class indices
        target = torch.argmax(data[target_name], dim=1).long()
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / max(1, len(loader))

@torch.no_grad()
def evaluate(model, loader, device, target_name):
    model.eval()
    correct = {1: 0, 3: 0, 5: 0, 10: 0}
    total = 0
    all_preds, all_labels = [], []

    for data in loader:
        data = data.to(device)
        logits = model(data)
        labels = torch.argmax(data[target_name], dim=1).long()
        total += labels.size(0)

        # Top-k accuracy
        _, pred = logits.topk(10, dim=1)
        for k in correct.keys():
            correct[k] += (pred[:, :k] == labels.view(-1, 1)).any(dim=1).sum().item()

        # Per F1: usiamo solo la predizione top1
        top1_preds = torch.argmax(logits, dim=1)
        all_preds.append(top1_preds.cpu())
        all_labels.append(labels.cpu())

    # Concatena
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    macro_f1 = f1_score(all_labels.numpy(), all_preds.numpy(), average="macro")

    results = {f"top{k}_acc": correct[k] / max(1, total) for k in correct}
    results["macro_f1"] = macro_f1
    return results



In [16]:
node_in_dim = dataset[0].x.shape[1]
edge_in_dim = dataset[0].edge_attr.shape[1]
lattice_in_dim = 9

number_of_runs = [1,2,3]  # 

hidden_dim = 64
dropout = 0.35


Y_size = max([torch.argmax(d["metal_salts"]).item() for d in dataset])
num_classes = Y_size + 1
num_classes

666

### 3) Train the GNN

In [10]:
results = []
for seed in number_of_runs:
    config_name = f"HID{hidden_dim}_DO{dropout}_SEED{seed}__X_edgeAttr_lattice_oms_AtomicNumber_structsym"
    print(f"\n===== Training config: {config_name} =====")
    set_seed(seed)

    # ==================== Model ====================
    model = MetalSaltGNN_v2(
        node_in_dim=node_in_dim,
        edge_in_dim=edge_in_dim,
        lattice_in_dim=lattice_in_dim,
        hidden_dim=hidden_dim,
        num_classes=num_classes, 
        num_gnn_layers=4,
        num_lattice_layers=2,
        num_mlp_layers=2,
        dropout=dropout,
        use_batchnorm=True
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3,weight_decay=0.0001)
    checkpoint_name = f"tmp2/Metal_salts_{config_name}.pt"
    best_metric = 0.0
    epochs_no_improve = 0
    patience = 50
    eval_every = 5

    for epoch in range(1, 1001):
        train(model, train_loader, criterion, optimizer, device, "metal_salts")
        if epoch % eval_every == 0:
            res = evaluate(model, val_loader, device, "metal_salts")
            print(f"VAL: top1_acc={res['top1_acc']:.4f}, top5_acc={res['top5_acc']:.4f}, top3_acc={res['top3_acc']:.4f} macro_f1={res['macro_f1']:.4f}")
            
            if res["top5_acc"] > best_metric:
                best_metric = res["top5_acc"]
                epochs_no_improve = 0
                torch.save(model.state_dict(), checkpoint_name)
            else:
                epochs_no_improve += eval_every
            if epochs_no_improve >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    model.load_state_dict(torch.load(checkpoint_name))
    res_test = evaluate(model, test_loader, device, "metal_salts")
    results.append({**res_test, 'config': config_name})
    print(f"{config_name} TEST: top1_acc={res_test['top1_acc']:.4f}, top5_acc={res_test['top5_acc']:.4f}, top3_acc={res_test['top3_acc']:.4f} macro_f1={res_test['macro_f1']:.4f}")


===== Training config: HID64_DO0.35_SEED1__X_edgeAttr_lattice_oms_AtomicNumber_structsym =====
VAL: top1_acc=0.2984, top5_acc=0.5082, top3_acc=0.4689 macro_f1=0.0446
VAL: top1_acc=0.4033, top5_acc=0.6623, top3_acc=0.5836 macro_f1=0.1178
VAL: top1_acc=0.4754, top5_acc=0.6951, top3_acc=0.6262 macro_f1=0.1754
VAL: top1_acc=0.4656, top5_acc=0.7213, top3_acc=0.6328 macro_f1=0.1563
VAL: top1_acc=0.4820, top5_acc=0.7410, top3_acc=0.6623 macro_f1=0.1809
VAL: top1_acc=0.4852, top5_acc=0.7443, top3_acc=0.6656 macro_f1=0.2034
VAL: top1_acc=0.4820, top5_acc=0.7574, top3_acc=0.6787 macro_f1=0.2044
VAL: top1_acc=0.4787, top5_acc=0.7508, top3_acc=0.6557 macro_f1=0.1861
VAL: top1_acc=0.4852, top5_acc=0.7508, top3_acc=0.6459 macro_f1=0.1841
VAL: top1_acc=0.4852, top5_acc=0.7738, top3_acc=0.6787 macro_f1=0.1919
VAL: top1_acc=0.5082, top5_acc=0.7902, top3_acc=0.6656 macro_f1=0.2570
VAL: top1_acc=0.4951, top5_acc=0.7902, top3_acc=0.7049 macro_f1=0.2425
VAL: top1_acc=0.5082, top5_acc=0.7738, top3_acc=0.69

In [43]:
dataset[0]

Data(x=[92, 4], edge_index=[2, 124], edge_attr=[124, 1], lattice=[3, 3], metal_salts=[1, 667], ligands=[2, 1743], solvents=[1, 139], oms=[1, 1], atomic_one_hot=[120], cordinates=[96], space_group_number=[231], crystal_system=[7], porosity=[8])

### 4) Load and Evaluate the Model

In [38]:
results = []

for seed in number_of_runs:
    config_name = f"HID{hidden_dim}_DO{dropout}_SEED{seed}__X_edgeAttr_lattice_oms_AtomicNumber_structsym"
    print(f"\n===== Evaluating config: {config_name} =====")
    

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    
    model = MetalSaltGNN_v2(
        node_in_dim=node_in_dim,
        edge_in_dim=edge_in_dim,
        lattice_in_dim=lattice_in_dim,
        hidden_dim=hidden_dim,
        num_classes=num_classes, 
        num_gnn_layers=4,
        num_lattice_layers=2,
        num_mlp_layers=2,
        dropout=dropout,
        use_batchnorm=True
    ).to(device)
    
    checkpoint_name = f"tmp2/Metal_salts_{config_name}.pt"
    # Carica best model e valuta su test
    checkpoint = torch.load(checkpoint_name, map_location=device)
    model.load_state_dict(checkpoint)

    res_test = evaluate(model, test_loader, device, "metal_salts")

    results.append({
        'config': config_name,
        'top1_acc': res_test['top1_acc'],
        'top10_acc': res_test['top10_acc'],
        'top5_acc': res_test['top5_acc'],
        'top3_acc': res_test['top3_acc'],
        'macro_f1': res_test['macro_f1']
    })
    print(f"{config_name} TEST: top10_acc={res_test['top10_acc']:.4f}, top5_acc={res_test['top5_acc']:.4f}, top3_acc={res_test['top3_acc']:.4f}, macro_f1={res_test['macro_f1']:.4f}")


===== Evaluating config: HID64_DO0.35_SEED1__X_edgeAttr_lattice_oms_AtomicNumber_structsym =====
HID64_DO0.35_SEED1__X_edgeAttr_lattice_oms_AtomicNumber_structsym TEST: top10_acc=0.8954, top5_acc=0.7614, top3_acc=0.6601, macro_f1=0.2764

===== Evaluating config: HID64_DO0.35_SEED2__X_edgeAttr_lattice_oms_AtomicNumber_structsym =====
HID64_DO0.35_SEED2__X_edgeAttr_lattice_oms_AtomicNumber_structsym TEST: top10_acc=0.8824, top5_acc=0.7484, top3_acc=0.6503, macro_f1=0.2492

===== Evaluating config: HID64_DO0.35_SEED3__X_edgeAttr_lattice_oms_AtomicNumber_structsym =====
HID64_DO0.35_SEED3__X_edgeAttr_lattice_oms_AtomicNumber_structsym TEST: top10_acc=0.8954, top5_acc=0.7647, top3_acc=0.6797, macro_f1=0.2827


In [None]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
set_seed(seed=42)

data_in = load_pyg_obj(path_to_mdb="../../data/mof_syncondition_data/")
dataset = fix_target_shapes(data_in, "metal_salts")
dataset = remove_unused_onehot_columns(dataset, "metal_salts")

In [None]:
one_hot = torch.nn.functional.one_hot(torch.tensor(idx), num_classes=len(unique_atomic_numbers)).float()

In [None]:
atom_num

In [None]:
np.unique(atom_num,return_counts=True)

In [None]:

[0,8,0,0,0,..,32,0,4,0,0,0,0,0,0] # 1

In [None]:
from ase.data import atomic_numbers

atomic_numbers

In [None]:


a, _ = np.unique(filter_metals(d.x.numpy()[:,0].astype(int)), return_counts=True)

In [None]:
filter_metals(d.x.numpy()[:,0].astype(int))

In [None]:
from mofstructure.mofdeconstructor import transition_metals


from ase.data import atomic_numbers

len(transition_metals())


In [None]:
from ase.data import atomic_numbers

atomic_numbers


[atomic_numbers[i] for i in transition_metals()]

In [None]:


unique_atomic_numbers = np.unique(atomic_numb)
atomic_to_idx = {num: idx for idx, num in enumerate(unique_atomic_numbers)}

for data in dataset:
    node_features = data.x.numpy()
    atomic_number = filter_metals(node_features[:, 0].astype(int))
    atomic_number = np.unique(atomic_number)[0]
    idx = atomic_to_idx[atomic_number]
    one_hot = torch.nn.functional.one_hot(torch.tensor(idx), num_classes=len(unique_atomic_numbers)).float()
    data.atomic_one_hot = one_hot

In [None]:
from ase.data import chemical_symbols

In [None]:
chemical_symbols[25]

# real word experiment 

In [17]:
import os
from mofstructure import structure, mofdeconstructor
from mofstructure.filetyper import load_iupac_names
from fairmofsyncondition.read_write import cheminfo2iupac, coords_library, filetyper
from ase.data import atomic_numbers
from ase.io import read
import torch


inchi_corrector = {
    "FDTQOZYHNDMJCM-UHFFFAOYSA-N":"benzene-1,4-dicarboxylic acid"
}
def get_ligand_iupacname(ligand_inchi):
    print(ligand_inchi)
    name = load_iupac_names().get(ligand_inchi, None)
    if name is None:
        pubchem = cheminfo2iupac.pubchem_to_inchikey(ligand_inchi, name='inchikey')
        if pubchem is None:
            name = inchi_corrector.get(ligand_inchi, ligand_inchi)
        else:
            pubchem.get('iupac_name', ligand_inchi)
    return name


def load_system(filename):
    """
    A function to extract
    """
    data = {}
    os.makedirs('LigandsXYZ', exist_ok=True)
    ase_data = read(filename)
    structure_data = structure.MOFstructure(ase_atoms=ase_data)
    _, ligands = structure_data.get_ligands()

    inchikeys = [ligand.info.get('inchikey') for ligand in ligands]
    for inchi, ligand in zip(inchikeys, ligands):
        ligand.write(f'LigandsXYZ/{inchi}.xyz')
    ligands_names = [get_ligand_iupacname(i) for i in inchikeys]
    general = structure_data.get_oms()
    oms = general.get('has_oms')
    metal_symbols = general.get('metals')
    metals_atomic_number = [atomic_numbers[i] for i in metal_symbols]
    torch_data = coords_library.ase_to_pytorch_geometric(ase_data)
    oms = torch.tensor([1 if  oms else 0], dtype=torch.int16)
    torch_data.oms = oms
    data['general'] = general
    return torch_data, metals_atomic_number, inchikeys, ligands_names

In [18]:
torch_data, metal_atomic_number, inch, lig = load_system("EDUSIF.cif")

  #1 :Accepted unusual valence(s): Zn(1); Metal was disconnected
  #1 :Accepted unusual valence(s): Zn(1); Metal was disconnected
  #1 :Accepted unusual valence(s): O(1)
  #1 :Accepted unusual valence(s): O(1)


FDTQOZYHNDMJCM-UHFFFAOYSA-N


  with zopen(filename, mode=mode) as file:


In [26]:
torch_data, metal_atomic_number, inch, lig = load_system("Zn2C43N6H29O8.cif")

  Failed to kekulize aromatic bonds in OBMol::PerceiveBondOrders

  #1 :Accepted unusual valence(s): O(1); C(3)
  #1 :Accepted unusual valence(s): O(1); C(3)


MGFJDEHFNMWYBD-OWOJBTEDSA-N
VIORWCNXRPKALR-UHFFFAOYSA-N


2025-08-22 16:21:27,251 - INFO - 'PUGREST.NotFound: No CID found that matches the given InChI key'
  with zopen(filename, mode=mode) as file:


In [30]:
torch_data, metal_atomic_number, inch, lig = load_system("ZnC5HO5.cif")

  #1 :Accepted unusual valence(s): O(1)
  #1 :Accepted unusual valence(s): O(1)


RESFABUZNSNZRT-UHFFFAOYSA-N


2025-08-22 16:21:54,445 - INFO - 'PUGREST.NotFound: No CID found that matches the given InChI key'


In [31]:
d = torch_data

In [None]:
from mofstructure.structure import MOFstructure
convert_metals = {j:i for i,j in enumerate(mofdeconstructor.transition_metals()[1:])}


# Parte 1: atomic one-hot
# =======================
node_features = d.x.numpy()
atom_num = node_features[:, 0].astype(int)
a, b = np.unique(atom_num, return_counts=True)
emb = torch.zeros(120)
for aa, bb in zip(a, b):
    emb[aa] = bb
d.atomic_one_hot = emb

# =======================
# Parte 2: struttura ASE
# =======================
ase_atoms = pytorch_geometric_to_ase(d)
stru = MOFstructure(ase_atoms)
pymat = AseAtomsAdaptor.get_structure(ase_atoms)

# =======================
# Parte 3: OMS
# =======================
emb = torch.zeros(96)
tmp_dict = dict()
for i in stru.get_oms()["metal_info"]:
    cord = i["coordination_number"]
    metal = i["metal"]

    if metal in tmp_dict:
        if cord > tmp_dict[metal]:
            tmp_dict[metal] = cord
    else:
        tmp_dict[metal] = cord

for i, j in tmp_dict.items():
    emb[convert_metals[i]] = j
d.cordinates = emb

# =======================
# Parte 4: spazio e sistema cristallino
# =======================
sga = SpacegroupAnalyzer(pymat)
space_group_number = sga.get_space_group_number()
emb = torch.zeros(231)
emb[space_group_number] = 1
d.space_group_number = emb

get_crystal_system = sga.get_crystal_system()
emb = torch.zeros(7)
emb[convert_struct[get_crystal_system]] = 1
d.crystal_system = emb

# =======================
# Parte 5: altri attributi
# =======================
###
to get oms
stru.get_oms()["has_oms"]

d.oms = d.oms.view(1, 1).float()

por = stru.get_porosity()
por = list(por.values())
d.porosity = torch.tensor(por)

d.modified_scherrer = None
d.microstrain = None


True None
Reading input file: tmp.cssr
Radii analysis: the smallest atom r = 1.09 while the largest atoms r = 1.7.
Box dimensions:
  va=(6.862800 0 0)
  vb=(0.000000 25.902900 0)
  vc=(0.000000 -12.951450 22.432569)

Total particles = 1602

Internal grid size = (3 10 9)

Using voro++ with radii for particles.
Performing Voronoi decomposition.
Volume check:
  Total domain volume  = 3987.757607
  Total Voronoi volume = 3987.757607
Voronoi decomposition finished. Rerouting Voronoi network information.
Finished rerouting information.
Voronoi network with 9598 nodes. 891 of them are accessible. 

Finding channels and pockets in Dijkstra network of 9598 node(s). 891 are expected to compose pores.
Analyzed and assigned 9598 nodes.
Identified 3 channels and 0 pockets.
891 nodes assigned to pores. 
Radii analysis: the smallest atom r = 1.09 while the largest atoms r = 1.7.
Box dimensions:
  va=(6.862800 0 0)
  vb=(0.000000 25.902900 0)
  vc=(0.000000 -12.951450 22.432569)

Total particles = 160

In [39]:
pred = model(d.to(device))
prx = "Zn"

c = 0
for i in pred.argsort()[0]:
    a = filetyper.category_names()["metal_salts"][i.item()]
    
    if a[:2] == prx:
        print(a)
        c = c +1 
        
    if c == 5:
        break

ZnC13H11O8NP
ZnCl2·4H2O
Zn(CH3COO)2·2H2O
ZnSO4.7H2O
ZnSiF6


True

In [None]:
inch

In [None]:
pred.argmax()