In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import re
import json
import random
import copy
import glob
import pandas as pd
import numpy as np
from itertools import combinations
from collections import Counter

import matplotlib.pyplot as plt
from tqdm import tqdm

from rdkit import Chem
from rdkit.Chem import BRICS
from rdkit import RDLogger
import libs.selfies as sf


RDLogger.DisableLog("rdApp.*")

# 1. Data preprocessing - PORMAKE

In [None]:
# Load the metal-cluster and topology lookup tables; the `with` blocks make sure
# the file handles are closed as soon as parsing is done.
with open("data/mc_to_idx.json") as fp:
    mc_to_idx = json.load(fp)
with open("data/topo_to_idx.json") as fp:
    topo_to_idx = json.load(fp)
len(mc_to_idx), len(topo_to_idx)

In [None]:
# Organic-linker name -> SMILES / SELFIES string tables (files closed explicitly).
with open("data/ol_to_smiles.json") as fp:
    ol_to_smiles = json.load(fp)
with open("data/ol_to_selfies.json") as fp:
    ol_to_selfies = json.load(fp)
print(len(ol_to_smiles), len(ol_to_selfies))

# 2. Dataset for Generator

## 1) BRICS

In [None]:
def make_ol_input(smile):
    """Build generator training pairs for one organic linker.

    The linker is decomposed with BRICS, fragments that are substructures of
    another fragment are dropped, the (at most) four largest remaining fragments
    are kept, and every non-empty combination of them is paired with the SELFIES
    encoding of the complete linker.

    Parameters
    ----------
    smile : str
        SMILES string of the linker; each ``*`` is counted as a connection point.

    Returns
    -------
    list[list]
        One ``[num_conn, comb_frags, smile_sf]`` entry per accepted combination,
        where ``comb_frags`` is the "."-joined SELFIES of the chosen fragments and
        ``smile_sf`` is the SELFIES of the whole linker. The list is empty when
        the linker SELFIES has more than 98 symbols.
    """
    pairs = []
    # number of connection points ("*" atoms) in the linker
    num_conn = smile.count("*")

    # BRICS decomposition; strip the numeric BRICS labels from the dummy atoms
    mol = Chem.MolFromSmiles(smile)
    frags = BRICS.BRICSDecompose(mol)
    frags = np.array(sorted({re.sub(r"\[\d+\*\]", "*", f) for f in frags}))

    # replace * with H
    du, hy = Chem.MolFromSmiles("*"), Chem.MolFromSmiles("[H]")
    subs = np.array([Chem.MolFromSmiles(f) for f in frags])
    subs = np.array(
        [
            Chem.RemoveHs(Chem.ReplaceSubstructs(f, du, hy, replaceAll=True)[0])
            for f in subs
        ]
    )
    subs = np.array([m for m in subs if m.GetNumAtoms() > 1])

    # delete fragments that are a substructure of another fragment: a fragment is
    # kept only when it matches exactly one entry of the list (itself)
    match = np.array([[m.HasSubstructMatch(f) for f in subs] for m in subs])
    frags = subs[match.sum(axis=0) == 1]

    # only use top four frags with descending order of number of atoms
    frags = sorted(frags, key=lambda x: -x.GetNumAtoms())[:4]  # [:voc.n_frags]
    frags = [Chem.MolToSmiles(Chem.RemoveHs(f)) for f in frags]

    # encoding selfies
    smile_sf = sf.encoder(smile)
    frags = [sf.encoder(f) for f in frags]

    # The target-length limit does not depend on the combination, so it is
    # evaluated once instead of once per combination.
    if len(list(sf.split_selfies(smile_sf))) > 98:
        return pairs

    for ix in range(1, len(frags) + 1):
        for comb in combinations(frags, ix):
            comb_frags = ".".join(comb)

            # skip inputs whose string is longer than the target string itself
            if len(comb_frags) > len(smile_sf):
                continue

            pairs.append([num_conn, comb_frags, smile_sf])

    return pairs

In [None]:
# Load the successfully built structures; the file is closed right after parsing.
with open("assets/success_cif_ids.json") as fp:
    cif_ids = json.load(fp)  # rmsd < 0.25
len(cif_ids)

In [None]:
# Expand every MOF id ("topo+mc+ol") into one row per fragment combination of
# its linker. Linkers without a SELFIES entry are skipped.
pairs_mof = []
for mof_name in tqdm(cif_ids.keys()):
    topo, mc, ol = mof_name.split("+")

    if ol in ol_to_selfies:
        pairs_mof.extend(
            [topo, mc, ol] + pair_ol for pair_ol in make_ol_input(ol_to_smiles[ol])
        )

In [None]:
# One row per (MOF, fragment combination) pair collected above.
generator_columns = ["topo", "mc", "ol", "num_conn", "frags", "selfies"]
df = pd.DataFrame(pairs_mof, columns=generator_columns)
# df.to_csv("data/dataset_generator/raw.csv")

## 2) Vocab_to_idx

In [None]:
# df = pd.read_csv("data/dataset_generator/raw.csv")

In [None]:
# Count every SELFIES symbol that appears in the fragment inputs or the targets.
# Symbols are visited row by row, fragments first, so first-seen order is stable.
counter = Counter()
for frag_str, target_str in tqdm(zip(df["frags"], df["selfies"]), total=len(df)):
    counter.update(sf.split_selfies(frag_str))
    counter.update(sf.split_selfies(target_str))

In [None]:
# Distinct symbols in first-seen order.
vocab = list(counter)
len(vocab)

In [None]:
# Indices 0-2 are reserved for the special tokens; data symbols start at 3.
special_tokens = ["[PAD]", "[SOS]", "[EOS]"]
vocab_to_idx = {tok: idx for idx, tok in enumerate(special_tokens + vocab)}

In [None]:
# json.dump(vocab_to_idx, open("data/vocab_to_idx.json", "w"))

In [None]:
# Use the vocabulary stored on disk (this replaces any in-memory `vocab_to_idx`).
with open("data/vocab_to_idx.json") as fp:
    vocab_to_idx = json.load(fp)
len(vocab_to_idx)

## 3) Split train, valid, test

In [None]:
# Shuffle all rows with a fixed seed so the train/val/test split below can be
# reproduced on a fresh kernel.
df = df.sample(frac=1, random_state=42)

In [None]:
# Contiguous, non-overlapping 80 / 10 / 10 split.
# `k` and `k_` are absolute END positions of train and val, so val must be
# sliced as [k:k_]; slicing to `k + k_` would run to the end of the frame and
# overlap with the test rows.
k = int(len(df) * 0.8)
k_ = int(len(df) * 0.9)

train = df.iloc[:k]
val = df.iloc[k:k_]
test = df.iloc[k_:]
assert len(train) + len(val) + len(test) == len(df), "splits must partition df"
print(len(train), len(val), len(test))
train.to_csv("data/dataset_generator/train.csv")
val.to_csv("data/dataset_generator/val.csv")
test.to_csv("data/dataset_generator/test.csv")

# 2. Dataset for Reinforcement Learning

In [None]:
# Reload the vocabulary and the generator splits from disk so this section can
# be run on its own; the JSON file handle is closed explicitly.
with open("data/vocab_to_idx.json") as fp:
    vocab_to_idx = json.load(fp)
train = pd.read_csv("data/dataset_generator/train.csv")
val = pd.read_csv("data/dataset_generator/val.csv")
test = pd.read_csv("data/dataset_generator/test.csv")
len(train), len(val), len(test)

In [None]:
# Plain NumPy copies of the two training columns used below.
frags = train["frags"].to_numpy(copy=True)
num_conn = train["num_conn"].to_numpy(copy=True)

In [None]:
# Every "."-separated fragment seen in the training inputs, in first-seen order.
counter = Counter(piece for joined in tqdm(frags) for piece in joined.split("."))
vocab_frag = list(counter)

In [None]:
# Independent deep copies, so editing them leaves the generator splits untouched.
new_train = train.iloc[:100000].copy(deep=True)
new_val = val.iloc[:10000].copy(deep=True)
new_test = test.copy(deep=True)
split = [new_train, new_val, new_test]

In [None]:
# Replace the fragment input of the first half of every split with randomly
# drawn fragments, then shuffle so random and original rows are interleaved.
#
# Fixes relative to the naive version:
#   * values are written with a single `.iloc[rows, col]` assignment instead of
#     chained `s["frags"].iloc[i] = ...`, which is a silent no-op under pandas
#     Copy-on-Write;
#   * `DataFrame.sample` returns a NEW frame, so its result is stored back into
#     `split` and the `new_*` names instead of being discarded;
#   * the fragment vocabulary is converted to an array once, not on every draw;
#   * a seeded generator makes the dataset reproducible.
rng = np.random.default_rng(42)
vocab_frag_arr = np.array(vocab_frag)

for idx, s in enumerate(split):
    n_random = len(s) // 2
    frag_col = s.columns.get_loc("frags")

    random_frags = []
    for _ in tqdm(range(n_random)):
        num_frags = rng.choice([1, 2, 3, 4], p=[0.3, 0.3, 0.3, 0.1])

        # redraw until the joined fragments stay below 100 SELFIES symbols
        while True:
            new_frags = ".".join(rng.choice(vocab_frag_arr, num_frags))
            if len(list(sf.split_selfies(new_frags))) < 100:
                break
        random_frags.append(new_frags)

    if random_frags:
        s.iloc[:n_random, frag_col] = random_frags
    split[idx] = s.sample(frac=1, random_state=42)

# keep the individual names in sync with the shuffled frames
new_train, new_val, new_test = split

In [None]:
# Write each split under its own name. The file name must come from `name`;
# interpolating the DataFrame itself would put its repr into the path.
name = ["train", "val", "test"]
for split_name, s in zip(name, split):
    s.to_csv(f"data/dataset_reinforce/{split_name}.csv")

In [None]:
# Persist the three reinforcement-learning splits as CSV (the index is written too).
new_train.to_csv("data/dataset_reinforce/train.csv")
new_val.to_csv("data/dataset_reinforce/val.csv")
new_test.to_csv("data/dataset_reinforce/test.csv")

# 3. Dataset for Predictor

In [None]:
# Reload all lookup tables so this section can be run on its own; every file is
# closed as soon as it has been parsed.
with open("data/mc_to_idx.json") as fp:
    mc_to_idx = json.load(fp)
with open("data/topo_to_idx.json") as fp:
    topo_to_idx = json.load(fp)
with open("data/ol_to_smiles.json") as fp:
    ol_to_smiles = json.load(fp)
with open("data/ol_to_selfies.json") as fp:
    ol_to_selfies = json.load(fp)
len(mc_to_idx), len(topo_to_idx), len(ol_to_smiles), len(ol_to_selfies)

In [None]:
# Symbol -> index vocabulary used to encode linker SELFIES (file closed explicitly).
with open("data/vocab_to_idx.json") as fp:
    vocab_to_idx = json.load(fp)

In [None]:
# Load the per-MOF RMSD table and the vocabulary; both files are closed explicitly.
with open("assets/mof_to_rmsd.json") as fp:
    mof_to_rmsd = json.load(fp)
with open("data/vocab_to_idx.json") as fp:
    vocab_to_idx = json.load(fp)
len(mof_to_rmsd), len(vocab_to_idx)

In [None]:
# Build one record per MOF id ("topo+mc+ol"): the building-block names, the
# linker SELFIES wrapped in start/end tokens, and the integer encodings.
# MOFs whose linker has no SELFIES entry are skipped.
dict_mof = {}
for mof_name in mof_to_rmsd:
    topo, mc, ol = mof_name.split("+")

    if ol not in ol_to_selfies:
        continue

    dict_mof[mof_name] = {
        "topo_name": topo,
        "mc_name": mc,
        "ol_name": ol,
        # SELFIES of the organic linker with start and end tokens
        # (SMILES alternative: "<" + ol_to_smiles[ol] + ">")
        "ol_selfies": "[SOS]" + ol_to_selfies[ol] + "[EOS]",
        # integer sequence of the MOF
        "topo": topo_to_idx[topo],
        "mc": mc_to_idx[mc],
        # label-encoded linker framed by the literal indices 1 and 2
        "ol": [1]
        + sf.selfies_to_encoding(
            selfies=ol_to_selfies[ol], vocab_stoi=vocab_to_idx, enc_type="label"
        )
        + [2],
    }

In [None]:
# Use the MOF records stored on disk (this replaces any in-memory `dict_mof`).
with open("data/dict_mof.json") as fp:
    dict_mof = json.load(fp)
len(dict_mof)

## 1) Q_kh

### charge check

In [None]:
# All charge-assigned CIF files across the batch folders.
cif_pattern = "assets/co2/uff_eqeq_cifs/charge_cifs_batch*/*.cif"
filenames_cif = glob.glob(cif_pattern)
len(filenames_cif)

In [None]:
# Sort the CIF files into "all charges inside (-3, 3)" and "at least one outside".
target_cif_ids, fail_cif_ids = [], []
for filename in tqdm(filenames_cif):
    with open(filename, "r") as f:
        lines = f.read().splitlines()

    # Skip the first 16 lines; on rows with more than 5 whitespace-separated
    # fields the last field is parsed as the charge.
    charge_ok = True
    for tokens in map(str.split, lines[16:]):
        if len(tokens) > 5 and not -3 < float(tokens[-1]) < 3:
            charge_ok = False
            break

    # file name without the ".cif" extension and without any "_charge" suffix
    cif_id = filename.split("/")[-1][:-4].split("_charge")[0]
    (target_cif_ids if charge_ok else fail_cif_ids).append(cif_id)

In [None]:
# Number of structures that passed / failed the charge check.
len(target_cif_ids), len(fail_cif_ids)

In [None]:
# NOTE(review): this cell repeats the charge check of the preceding charge-check
# cell and rebuilds the same two lists from scratch; it looks like an accidental
# duplicate and can probably be deleted.
target_cif_ids = []
fail_cif_ids = []
for filename in tqdm(filenames_cif):
    with open(filename, "r") as f:
        lines = f.read().splitlines()
        # redundant: the `with` block already closes the file on exit
        f.close()

    fail = False
    # skip the first 16 lines of the file
    for line in lines[16:]:
        tokens = line.split()
        # only rows with more than 5 fields are inspected; the last field is
        # parsed as the charge
        if len(tokens) > 5:
            c = float(tokens[-1])

            # any charge outside the open interval (-3, 3) rejects the structure
            if not -3 < c < 3:
                fail = True
                break
    # file name without ".cif" and without any "_charge" suffix
    cif_id = filename.split("/")[-1][:-4].split("_charge")[0]
    if fail is False:
        target_cif_ids.append(cif_id)
    else:
        fail_cif_ids.append(cif_id)

In [None]:
# Collect the CO2 heat of adsorption (column "q_kh_co2") for every structure that
# passed the charge check, keeping only values inside the open interval (-100, 0).
#
# Membership is tested against a set: `x in list` is linear in the list length,
# which made the whole loop quadratic in the number of structures.
target_cif_id_set = set(target_cif_ids)

mof_to_qkh = {}
for csv_file in tqdm(glob.glob("assets/co2/*.csv")):
    csv_ = pd.read_csv(csv_file)

    # iterate the two needed columns directly instead of building a row per index
    for structure, raw_qkh in zip(csv_["structure"], csv_["q_kh_co2"]):
        cif_id = structure.split(".")[0].split("_charge")[0]

        if cif_id not in target_cif_id_set:
            continue

        qkh = float(raw_qkh)
        if -100 < qkh < 0:
            mof_to_qkh[cif_id] = qkh

len(mof_to_qkh)

In [None]:
# MOFs that have both a Q_kh target and an encoded record. Sorting gives a
# deterministic order; plain set iteration order changes between interpreter runs.
mof_names = sorted(mof_to_qkh.keys() & dict_mof.keys())
len(mof_names)

In [None]:
# Attach the Q_kh target to a COPY of each MOF record. Updating the record in
# place would also add "target" to the shared entries of `dict_mof`, so any other
# target built from `dict_mof` later would alias and overwrite it.
dict_qkh = {}
for mof_name in tqdm(mof_names):
    dict_qkh[mof_name] = {**dict_mof[mof_name], "target": mof_to_qkh[mof_name]}

In [None]:
# Shuffle the ids reproducibly: start from a sorted list and use a dedicated,
# seeded generator so the split does not depend on dict order or global RNG state.
cif_id = sorted(dict_qkh)
random.Random(42).shuffle(cif_id)
k = int(len(cif_id) * 0.8)
k_ = int(len(cif_id) * 0.9)

In [None]:
# 80 / 10 / 10 split of the shuffled ids (train / test / val, in that order).
train = cif_id[:k]
test = cif_id[k:k_]
val = cif_id[k_:]
print(len(train), len(test), len(val))
split = ["train", "test", "val"]
data = [train, test, val]
for s, names in zip(split, data):
    print(s)
    d = {n: dict_qkh[n] for n in names}
    print(len(d))
    # `with` guarantees the JSON is flushed and the handle closed; passing a bare
    # open(..., "w") to json.dump leaves that to garbage collection.
    with open(f"data/dataset_predictor/qkh/{s}.json", "w") as fp:
        json.dump(d, fp)