In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
import shutil
import sys
from collections import Counter
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from ase.io import read
from ase.visualize.plot import plot_atoms
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from rdkit.Chem import RDConfig
from tqdm import tqdm

# SA_Score ships in RDKit's Contrib directory, which is not on sys.path by default.
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer

In [None]:
# Global figure style: Arial sans-serif at 15 pt, no major tick marks.
plt.rcParams.update({
    "font.sans-serif": "Arial",
    "font.family": "sans-serif",
    "font.size": 15,
    "xtick.major.size": 0,
    "ytick.major.size": 0,
})

# Construction of the generated MOFs by PORMAKE

After running the reinforcement-learning model in test mode (command below), the results are saved as a JSON file in the `log_dir` directory. The JSON file describes the MOFs generated for a test set of 10,000 samples and contains the following keys:
```
$ python run_reinforce.py with v0_qkh_round3 log_dir=test test_only=True load_path=model/reinforce/best_v0_qkh_round3.ckpt
```

`rewards` : the rewards of generated MOFs

`preds` : the target values of the generated MOFs estimated by the predictor

`gen_sms` : the SMILES of the organic linkers of the generated MOFs

`gen_mcs` : the metal clusters of generated MOFs

`gen_topos` : the topologies of generated MOFs

In [None]:
# Results JSON written by the `run_reinforce.py ... log_dir=test test_only=True` command above.
path_json = (
    "test/"
    "results_v0_qkh_round3_seed0_from_best_v0_qkh_round3.json"
)

In [None]:
# Load the RL test results and stack them into one row per generated MOF:
#   [pred columns..., topology, metal cluster, linker SMILES]
# zip(*preds) transposes the per-MOF prediction lists into columns
# (assumes "preds" is a list of per-MOF lists — TODO confirm for multi-target runs).
# Because strings and floats are mixed, np.array stores every column as a string;
# numeric columns must be cast back with .astype(float) before use.
with open(path_json) as fp:
    results_optimized = json.load(fp)
ret = np.array(
    list(zip(*results_optimized["preds"]))
    + [
        results_optimized["gen_topos"],
        results_optimized["gen_mcs"],
        results_optimized["gen_sms"],
    ]
).T
len(ret)

## 0. Analysis of the top 500 generated MOFs

In [None]:
top_n = 500
# Drop exact duplicate rows (same prediction, topology, metal cluster and linker).
# The raw `ret` is left untouched so earlier output stays valid and this cell is idempotent.
ret_unique = np.unique(ret, axis=0)
# Ascending sort on column 0 (first predicted value, stored as a string → cast to float):
# the smallest predictions are ranked first.
# NOTE(review): assumes the objective is to minimize the prediction — confirm against the reward.
sorted_ret = ret_unique[np.argsort(ret_unique[:, 0].astype(float))][:top_n]

### top common organic linkers

In [None]:
# Rank the linker SMILES (last column) of the top MOFs by frequency and parse them into RDKit mols.
counter = Counter(sorted_ret[:, -1].tolist())
top_ol_sm, c = zip(*counter.most_common())
top_ol = list(map(Chem.MolFromSmiles, top_ol_sm))
len(top_ol)

In [None]:
# Draw the 16 most frequent linkers as a 4-column SVG grid.
# `Draw` is imported explicitly: `Chem.Draw` only exists as an attribute once the
# rdkit.Chem.Draw submodule has been imported somewhere, so relying on it is fragile.
img = Draw.MolsToGridImage(top_ol[:16], molsPerRow=4, subImgSize=(200, 200), useSVG=True)
img

### top common topologies

In [None]:
import pormake as pm
# NOTE: plot_atoms is already imported in the first import cell; this re-import is redundant.
from ase.visualize.plot import plot_atoms
# PORMAKE database with default settings (no bb_dir given), used below to look up
# topologies and metal clusters by name for plotting. It is re-created later with a
# custom bb_dir before the MOFs are constructed.
database = pm.Database()

In [None]:
# Rank topologies (column 1) of the top MOFs by frequency; `count` is aligned with `top_topo`.
counter = Counter(sorted_ret[:, 1].tolist())
top_topo, count = zip(*counter.most_common())
len(top_topo)

In [None]:
# Plot the (up to) 16 most frequent topologies on a 4-column grid. Each title shows the
# topology's share of all top MOFs (`count` covers the full, untruncated ranking).
top_topo = top_topo[:16]
n_cols = 4
# Ceil division so a partial last row still gets axes, and squeeze=False keeps `axes` 2-D
# for a single row. The old `len // 4` crashed unless the count was a multiple of 4 and >= 8.
n_rows = max(1, -(-len(top_topo) // n_cols))
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(8, 2 * n_rows), constrained_layout=True, squeeze=False
)
total = sum(count)  # hoisted: constant across panels
for i, ax in enumerate(axes.flat):
    if i < len(top_topo):
        t = top_topo[i]
        topo_ = database.get_topo(t)
        plot_atoms(topo_.atoms, ax=ax, radii=0.2)
        ax.set_title(f"{t}\n({round(count[i] / total * 100, 1)} %)", fontsize=20)
    ax.set_axis_off()  # also blanks unused panels of a partial last row

### top common metal clusters

In [None]:
# Rank metal clusters (column 2) of the top MOFs by frequency; `count` is aligned with `top_mc`.
counter = Counter(sorted_ret[:, 2].tolist())
top_mc, count = zip(*counter.most_common())
len(top_mc)

In [None]:
# Plot the (up to) 16 most frequent metal clusters on a 4-column grid. Each title shows the
# cluster's share of all top MOFs (`count` covers the full, untruncated ranking).
top_mc = top_mc[:16]
n_cols = 4
# Ceil division so a partial last row still gets axes, and squeeze=False keeps `axes` 2-D
# for a single row. The old `len // 4` crashed unless the count was a multiple of 4 and >= 8.
n_rows = max(1, -(-len(top_mc) // n_cols))
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(8, 2 * n_rows), constrained_layout=True, squeeze=False
)
total = sum(count)  # hoisted: constant across panels
for i, ax in enumerate(axes.flat):
    if i < len(top_mc):
        mc = top_mc[i]
        mc_ = database.get_bb(mc)
        plot_atoms(mc_.atoms, ax=ax)
        ax.set_title(f"{mc}\n({round(count[i] / total * 100, 1)} %)", fontsize=20)
    ax.set_axis_off()  # also blanks unused panels of a partial last row

## 1. Construction of the generated MOFs by PORMAKE

In [None]:
from rdkit import RDLogger  
# Silence RDKit messages and PORMAKE's print / file-print logging for the bulk build below.
RDLogger.DisableLog('rdApp.*')
for disable_log in (pm.log.disable_print, pm.log.disable_file_print):
    disable_log()

In [None]:
# Output locations: building blocks (PORMAKE bbs + generated linkers) and constructed CIFs.
save_dir_bb, save_dir_gen_mofs = "results/qkh/bb_dir", "results/qkh/gen_mofs"

In [None]:
# Start from a fresh copy of PORMAKE's bundled building blocks. Linker xyz files generated
# below are written into this copy, so the installed package directory is never modified.
if os.path.exists(save_dir_bb):
    shutil.rmtree(save_dir_bb)
shutil.copytree(Path(pm.__path__[0]) / "database" / "bbs", save_dir_bb)
database = pm.Database(bb_dir=Path(save_dir_bb))

# Also start from an empty CIF directory. Linker names in CIF filenames are indices into
# vocab_sm.json, which is rebuilt from scratch on every run, so CIFs left over from a
# previous run would be inconsistent with it (and would be picked up by the glob below).
if os.path.exists(save_dir_gen_mofs):
    shutil.rmtree(save_dir_gen_mofs)
os.makedirs(save_dir_gen_mofs, exist_ok=True)

In [None]:
def molblock_to_pormake_xyz(mol_block):
    """Convert an MDL V2000 mol block into the xyz-with-bonds text written for PORMAKE.

    Output layout:
      * line 1: number of atoms
      * line 2: the comment "mol to xyz file"
      * one line per atom: symbol, x, y, z. Dummy atoms ("R", or "*") become "X",
        the symbol PORMAKE uses for connection points.
      * one line per bond: 0-based indices of both atoms and a label S/D/T/A.

    Args:
        mol_block: mol block text, e.g. from ``Chem.MolToMolBlock``.

    Returns:
        The full file content, newline-terminated.

    Raises:
        Exception: "bond type error" for a bond order other than 1-4.
    """
    bond_labels = {1: "S", 2: "D", 3: "T", 4: "A"}
    lines = mol_block.splitlines()
    # V2000 counts line (4th line) is fixed-width: atoms in columns 1-3, bonds in 4-6.
    counts_line = lines[3]
    num_atoms = int(counts_line[:3])
    num_bonds = int(counts_line[3:6])
    atom_lines = lines[4:4 + num_atoms]
    bond_lines = lines[4 + num_atoms:4 + num_atoms + num_bonds]

    out = [f"{num_atoms}", "mol to xyz file"]
    for atom_line in atom_lines:
        x, y, z, symbol = atom_line.split()[:4]
        if symbol in ("R", "*"):
            symbol = "X"
        out.append(f"{symbol:<10}    {x:<10}    {y:<10}    {z:<10}")
    for bond_line in bond_lines:
        # Fixed-width: first atom (cols 1-3), second atom (4-6), bond order (7-9); 1-based atoms.
        first, second, order = int(bond_line[:3]), int(bond_line[3:6]), int(bond_line[6:9])
        if order not in bond_labels:
            raise Exception("bond type error")
        out.append(f"{first - 1:<10}{second - 1:<6}{bond_labels[order]:<6}")
    return "\n".join(out) + "\n"


def smiles_to_xyz(smiles, save_dir, bb_name="tmp"):
    """Embed a linker SMILES in 3D and save it as ``{save_dir}/{bb_name}.xyz`` for PORMAKE.

    Hydrogens are added, and one conformer is embedded with RDKit. No random seed is passed,
    so coordinates can differ between runs. The conformer is MMFF-optimized, and the
    resulting mol block is converted by ``molblock_to_pormake_xyz``.

    Args:
        smiles: linker SMILES; "*" dummy atoms mark connection points.
        save_dir: directory to write into.
        bb_name: file stem (building-block name).

    Returns:
        Path of the written file.

    Raises:
        ValueError: if the SMILES cannot be parsed or no 3D conformer can be embedded.
        Exception: "bond type error" from ``molblock_to_pormake_xyz``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"cannot parse SMILES: {smiles!r}")
    mol_h = Chem.AddHs(mol)
    # EmbedMolecule returns the new conformer id, or -1 on failure.
    if AllChem.EmbedMolecule(mol_h) == -1:
        raise ValueError(f"cannot embed a 3D conformer for SMILES: {smiles!r}")
    # Return status is ignored, as before.
    # NOTE(review): MMFF may lack parameters for dummy atoms, in which case the geometry
    # may stay unoptimized — confirm.
    AllChem.MMFFOptimizeMolecule(mol_h)
    xyz_text = molblock_to_pormake_xyz(Chem.MolToMolBlock(mol_h))

    save_path = os.path.join(save_dir, f"{bb_name}.xyz")
    # The text is fully built before opening the file, so a conversion error leaves no partial file.
    with open(save_path, "w") as f:
        f.write(xyz_text)
    return save_path

In [None]:
def topology_slot_cns(topo):
    """Return the connection numbers of a topology's building-block slots.

    Starts from a copy of ``topo.unique_cn`` (the input is not mutated). When it has a
    single entry, a 2 is appended to stand for the edge slot, so the result can be
    compared against the (metal cluster, linker) connection numbers.
    """
    topo_cn = list(topo.unique_cn)
    if len(topo_cn) == 1:
        topo_cn.append(2)
    return topo_cn


def assign_building_blocks(topo, topo_cn, mc, ol, mc_cn):
    """Map the metal cluster ``mc`` and organic linker ``ol`` onto topology slots.

    ``topo_cn`` is ``topology_slot_cns(topo)``, and ``mc_cn`` is the metal cluster's number
    of connection points. The caller must already have checked
    ``set(topo_cn) == {mc_cn, ol_cn}``.

    Returns:
        ``(node_bbs, edge_bbs)`` as passed to ``Builder.build_by_type``.
    """
    if mc_cn == topo_cn[0] and topo_cn[1] == 2:
        # metal cluster on node type 0, linker on the first edge type
        return {0: mc}, {tuple(topo.unique_edge_types[0]): ol}
    if mc_cn == topo_cn[0]:
        # metal cluster on node type 0, linker on node type 1
        return {0: mc, 1: ol}, {}
    # Linker on node type 0, metal cluster on node type 1.
    # NOTE(review): if topo_cn[1] is the appended edge slot (single node type) and the metal
    # cluster is the 2-c piece, node type 1 does not exist and the build will most likely
    # fail — confirm whether that combination can occur.
    return {0: ol, 1: mc}, {}


def construct_mofs(final_ret, save_dir_bb, save_dir_gen_mofs, verbose=False):
    """Build generated MOFs with PORMAKE and write the ones passing the filters as CIF.

    Each row of ``final_ret`` is ``(pred, topology, metal_cluster, linker_smiles)``.
    Each new linker SMILES is written to ``{save_dir_bb}/{k}.xyz``, where ``k`` is its order
    of first appearance. It is then looked up through the module-level ``database``, so
    that database must have been created with ``bb_dir=save_dir_bb``.

    A built MOF is written to ``save_dir_gen_mofs`` only if:
      (1) the linker's SA score is <= 6,
      (2) ``max_rmsd`` is <= 0.3,
      (3) it has at most 3000 atoms,
      (4) every cell length is <= 60 Å.
    The SMILES -> linker-name mapping is saved as ``vocab_sm.json`` in the same directory.

    Args:
        final_ret: iterable of 4-element rows as described above.
        save_dir_bb: directory receiving the linker ``.xyz`` files.
        save_dir_gen_mofs: output directory for CIFs and ``vocab_sm.json``.
        verbose: print each row and per-row diagnostics.

    Returns:
        dict counting rows skipped per reason, plus ``"written"`` CIFs.
    """
    stats = dict.fromkeys(
        ("invalid_linker", "cn_mismatch", "sa_score", "build_error",
         "max_rmsd", "n_atoms", "cell_length", "write_error", "written"),
        0,
    )
    builder = pm.Builder()  # reused for every build instead of being re-created per row
    vocab_sm = {}
    idx = 0
    for p, topo_, mc_, sm_ in tqdm(final_ret):
        if verbose:
            print(p, topo_, mc_, sm_)

        # Register each new linker once; its bb name is its order of first appearance.
        if sm_ not in vocab_sm:
            bb_name = f"{len(vocab_sm)}"
            try:
                smiles_to_xyz(sm_, save_dir=save_dir_bb, bb_name=bb_name)
            except Exception:
                stats["invalid_linker"] += 1
                if verbose:
                    print("The smile of organice linker can't be converted to xyz files")
                continue
            vocab_sm[sm_] = bb_name

        topo = database.get_topo(topo_)
        mc = database.get_bb(mc_)
        ol = database.get_bb(vocab_sm[sm_])

        # Connection points must match the topology; the linker's are its "*" dummy atoms.
        topo_cn = topology_slot_cns(topo)
        mc_cn = mc.n_connection_points
        ol_cn = sm_.count("*")
        if set(topo_cn) != {mc_cn, ol_cn}:
            stats["cn_mismatch"] += 1
            if verbose:
                print(f"fail : {topo_cn, mc_cn, ol_cn}")
            continue

        # (1) SA score <= 6. This depends only on the linker, so it is checked before the
        # expensive build; the set of written MOFs is the same as checking it afterwards.
        if sascorer.calculateScore(Chem.MolFromSmiles(sm_)) > 6:
            stats["sa_score"] += 1
            continue

        node_bbs, edge_bbs = assign_building_blocks(topo, topo_cn, mc, ol, mc_cn)
        try:
            gen_mof = builder.build_by_type(topology=topo, node_bbs=node_bbs, edge_bbs=edge_bbs)
        except Exception:
            stats["build_error"] += 1
            continue

        # (2) max RMSD <= 0.3
        if gen_mof.info["max_rmsd"] > 0.3:
            stats["max_rmsd"] += 1
            continue
        # (3) number of atoms <= 3000
        if len(gen_mof.atoms) > 3000:
            stats["n_atoms"] += 1
            continue
        # (4) every cell length <= 60 Å. cell.lengths() replaces the deprecated
        # Atoms.get_cell_lengths_and_angles()[:3].
        if gen_mof.atoms.cell.lengths().max() > 60:
            stats["cell_length"] += 1
            continue

        filename = f"{str(idx).zfill(3)}_{topo_}+{mc_}+{vocab_sm[sm_]}.cif"
        try:
            if verbose:
                print(f"write_cif {filename}")
            gen_mof.write_cif(f"{save_dir_gen_mofs}/{filename}")
        except Exception:
            stats["write_error"] += 1
            continue
        idx += 1
        stats["written"] += 1

    # Write the vocab for the organic-linker SMILES; the handle is closed deterministically.
    with open(f"{save_dir_gen_mofs}/vocab_sm.json", "w") as f:
        json.dump(vocab_sm, f)
    return stats

In [None]:
# Build the 100 best-ranked MOFs (the first 100 rows of sorted_ret).
final_ret = sorted_ret[:100]
construct_mofs(final_ret, save_dir_bb=save_dir_bb, save_dir_gen_mofs=save_dir_gen_mofs)

## 2. visualize the constructed MOFs

In [None]:
# Constructed CIFs; the zero-padded index prefix makes filename order equal construction order.
filenames = sorted(Path(save_dir_gen_mofs).glob("*.cif"))

In [None]:
# Quick look at a single constructed MOF: the second file in construction order.
assert len(filenames) > 1, f"expected at least 2 CIF files in {save_dir_gen_mofs}, found {len(filenames)}"
p = filenames[1]
atoms = read(p)
fig, ax = plt.subplots()
plot_atoms(atoms, ax=ax)
ax.set_title(p.stem)
ax.set_axis_off()
print(p)

In [None]:
# 3x3 grid of the first 9 constructed MOFs; panels stay blank if fewer were written.
fig, axes = plt.subplots(3, 3, figsize=(16, 16))
for i, ax in enumerate(axes.flat):
    if i < len(filenames):
        atoms = read(filenames[i])
        plot_atoms(atoms, ax=ax)
        # .stem drops only the ".cif" suffix; split(".")[0] would truncate names containing "."
        ax.set_title(filenames[i].stem, fontsize=20)
    ax.set_axis_off()
plt.show()