In [19]:
from load_data import load_pyg_obj
import torch
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split


import random
import numpy as np
from sklearn.metrics import f1_score
from utils import fix_target_shapes,remove_unused_onehot_columns,set_seed,filter_metals
from mofstructure import mofdeconstructor

from fairmofsyncondition.read_write.coords_library import pytorch_geometric_to_ase

from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

# Crystal-system name -> index into the one-hot `crystal_system` feature vector.
# Indices follow the order of the tuple below (alphabetical).
convert_struct = {
    system_name: position
    for position, system_name in enumerate(
        ("cubic", "hexagonal", "monoclinic", "orthorhombic", "tetragonal", "triclinic", "trigonal")
    )
}

# Enable IPython's autoreload so edits to local modules (e.g. `utils`, `load_data`)
# are picked up before each cell runs, without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# This code lists the names of the metal salts that are not used (classes with at most 5 samples)

In [20]:
from mofstructure.structure import MOFstructure
from tqdm import tqdm
import os

# Transition-metal symbol -> column of the OMS coordination vector
# (the first entry returned by transition_metals() is skipped).
convert_metals = {metal: idx for idx, metal in enumerate(mofdeconstructor.transition_metals()[1:])}

# Widths of the fixed-size feature vectors attached to every structure.
N_ELEMENTS = 120      # atomic-number histogram, indexed directly by atomic number
N_OMS_METALS = 96     # OMS coordination vector; NOTE(review): confirm this is >= len(convert_metals)
N_SPACE_GROUPS = 231  # indexed directly by space-group number (1..230); index 0 is unused

# Cache of the featurized dataset (a list of PyG Data objects).
CACHE_PATH = "../../dataset_cleen_all_info.pt"


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
set_seed(seed=42)

data_in = load_pyg_obj(path_to_mdb="../../data/mof_syncondition_data/")
dataset = fix_target_shapes(data_in, "metal_salts")
dataset = remove_unused_onehot_columns(dataset, "metal_salts")


def _featurize_structure(d):
    """Attach composition, open-metal-site and symmetry features to `d` in place.

    Any exception raised by the converters/analyzers propagates to the caller.
    This includes a KeyError for a metal or crystal system missing from
    `convert_metals` / `convert_struct`. The caller treats such a structure as
    un-featurizable. `d` may already be partially modified when that happens.
    """
    # Part 1: histogram of atomic numbers (column 0 of d.x is read as the atomic number).
    atom_num = d.x.numpy()[:, 0].astype(int)
    elements, counts = np.unique(atom_num, return_counts=True)
    composition = torch.zeros(N_ELEMENTS)
    for z, count in zip(elements, counts):
        composition[z] = count
    d.atomic_one_hot = composition

    # Part 2: convert to ASE, then build the MOFstructure and pymatgen views of it.
    ase_atoms = pytorch_geometric_to_ase(d)
    stru = MOFstructure(ase_atoms)
    pymat = AseAtomsAdaptor.get_structure(ase_atoms)

    # Part 3: highest coordination number reported for each metal among the OMS entries.
    max_coordination = {}
    for info in stru.get_oms()["metal_info"]:
        metal, cord = info["metal"], info["coordination_number"]
        if metal not in max_coordination or cord > max_coordination[metal]:
            max_coordination[metal] = cord

    oms_vector = torch.zeros(N_OMS_METALS)
    for metal, cord in max_coordination.items():
        oms_vector[convert_metals[metal]] = cord
    d.cordinates = oms_vector  # attribute name (sic) kept unchanged

    # Part 4: one-hot space group and crystal system.
    sga = SpacegroupAnalyzer(pymat)
    space_group = torch.zeros(N_SPACE_GROUPS)
    space_group[sga.get_space_group_number()] = 1
    d.space_group_number = space_group

    crystal_system = torch.zeros(len(convert_struct))
    crystal_system[convert_struct[sga.get_crystal_system()]] = 1
    d.crystal_system = crystal_system

    # Part 5: remaining attributes.
    d.oms = d.oms.view(1, 1).float()
    # Porosity (stru.get_porosity()) is deliberately skipped: too slow to compute.
    d.modified_scherrer = None
    d.microstrain = None


bad = []
good = []
failure_reasons = {}  # exception type name -> number of structures that raised it

if os.path.exists(CACHE_PATH):
    print("Loading precomputed dataset...")
    # The cache holds pickled PyG Data objects. torch.load refuses those under the
    # weights_only=True default of PyTorch >= 2.6. Only load caches written by this notebook.
    dataset = torch.load(CACHE_PATH, weights_only=False)
else:
    print("Computing dataset with all info...")
    for d in tqdm(dataset):
        try:
            _featurize_structure(d)
        except Exception as exc:
            # Best effort: skip structures that cannot be featurized, but keep a
            # tally of the reasons instead of swallowing the errors silently.
            bad.append(d)
            reason = type(exc).__name__
            failure_reasons[reason] = failure_reasons.get(reason, 0) + 1
        else:
            good.append(d)
    print(f"Featurized {len(good)} structures, skipped {len(bad)}: {failure_reasons}")
    torch.save(good, CACHE_PATH)  # save the list of Data objects
    # Continue with the cleaned list, exactly as the cache-hit branch does; previously
    # the first (uncached) run kept the unfiltered dataset.
    dataset = good

Loading precomputed dataset...


In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch_geometric.nn import GINEConv, global_mean_pool
from sklearn.metrics import f1_score

# ==================== Utils & setup ====================

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN kernels.

    NOTE(review): this redefinition shadows the `set_seed` imported from `utils`
    in the first cell; keep only one of them.

    Args:
        seed: value passed to every random number generator.
    """
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Count samples per metal-salt class (argmax of the one-hot target).
# This cell only counts; the rare-class selection uses `conv_y` in the next cell.
Y = [d.metal_salts.argmax(dim=1).item() for d in dataset]
salt_classes, salt_class_counts = np.unique(Y, return_counts=True)
# class index -> number of samples in that class
conv_y = dict(zip(salt_classes, salt_class_counts))

In [22]:
bad = [d for d in dataset if conv_y[d.metal_salts.argmax(dim=1).item()] <= 5]
#good = [d for d in dataset if conv_y[d.metal_salts.argmax(dim=1).item()] > 5]


from fairmofsyncondition.read_write import cheminfo2iupac, coords_library, filetyper

In [23]:
names = []
for b in bad:
    names.append(filetyper.category_names()["metal_salts"][b.metal_salts.argmax(dim=1).item()])
names,b = np.unique(names, return_counts=True)

In [26]:
# Write one metal-salt name per line. The encoding is explicit so the file's bytes
# do not depend on the platform's default text encoding.
with open("not_handeled_metal_salts.txt", "w", encoding="utf-8") as f:
    f.writelines(f"{name}\n" for name in names)


5