In [7]:
import pandas as pd
from rdkit import Chem
import torch
import yaml
import sys
from functools import partial
from rdkit.Chem import AllChem
import torch.nn as nn
sys.path.append('/projects/mai/se_mai/users/kvvq085_Mary')
from MolAI.main.lib.dataset.chem import standardize_smiles, remove_isotopes
from MolAI.main.lib.model.model import LitMolformer

sys.path.append('/projects/mai/se_mai/users/kvvq085_Mary/MolAI/main')
def load_model(config_path, checkpoint_path, vocabulary_path, device="cuda"):
    """Build a ``LitMolformer`` from a YAML config and restore its weights.

    Args:
        config_path: Path (str or path-like) to the YAML hyper-parameter file.
            Its top-level keys are passed as keyword arguments to ``LitMolformer``.
        checkpoint_path: Path to a checkpoint whose ``"state_dict"`` entry holds
            the model weights.
        vocabulary_path: Value stored under ``hparams["vocabulary"]``. It overrides
            any ``vocabulary`` key present in the config.
        device: Target device (string or ``torch.device``). It is also used as
            ``map_location`` when loading the checkpoint.

    Returns:
        The model moved to ``device`` and switched to eval mode. It gets a
        ``mol_to_fingerprints`` callable attached: a ``partial`` of
        ``AllChem.GetMorganFingerprint`` (radius 2) when the config path contains
        ``"with_counts"``, otherwise ``AllChem.GetMorganFingerprintAsBitVect``
        (radius 2, 1024 bits).
    """
    # Context manager so the config file handle is closed; the original left it open.
    # NOTE(review): FullLoader kept for compatibility; yaml.safe_load suffices for plain configs.
    with open(config_path) as config_file:
        hparams = yaml.load(config_file, Loader=yaml.FullLoader)
    hparams["vocabulary"] = vocabulary_path
    model = LitMolformer(**hparams)

    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])

    # Move to the device only after the state dict is loaded, then freeze dropout/batch-norm.
    model = model.to(device).eval()

    # The fingerprint flavour is selected from the config *file name*.
    # str() makes this work for pathlib.Path inputs too.
    if "with_counts" in str(config_path):
        model.mol_to_fingerprints = partial(AllChem.GetMorganFingerprint, radius=2)
    else:
        model.mol_to_fingerprints = partial(
            AllChem.GetMorganFingerprintAsBitVect, radius=2, nBits=1024
        )
    return model

def is_good_smiles(smi, tokenizer, vocabulary):
    """Return True if ``smi`` survives preprocessing and tokenizes into known tokens.

    The pipeline is:

    1. ``Chem.MolFromSmiles``
    2. ``remove_isotopes``
    3. ``standardize_smiles``
    4. ``tokenizer.tokenize``

    The SMILES is "good" only when every resulting token is in ``vocabulary``.
    Any failure along the way is reported as False. That includes RDKit
    rejecting the SMILES and any exception raised by the helpers, e.g. on
    non-string input such as NaN from an empty CSV cell.

    Args:
        smi: SMILES string to check.
        tokenizer: Object exposing ``tokenize(str) -> iterable of tokens``.
        vocabulary: Container supporting ``token in vocabulary``.

    Returns:
        bool
    """
    try:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            # RDKit reports parse/sanitization failures by returning None, not by raising.
            return False
        smi_no_iso = remove_isotopes(mol)
        std_smi = standardize_smiles(smi_no_iso)
        tokens = tokenizer.tokenize(std_smi)
        # Generator form short-circuits on the first out-of-vocabulary token.
        return all(token in vocabulary for token in tokens)
    except Exception:
        # Best-effort filter: any processing error marks the SMILES as bad.
        # Catching Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # through, so a long filtering run can still be interrupted.
        return False

def process_pairs(pairs, tokenizer, vocabulary):
    """Keep only the molecule pairs whose two SMILES both pass ``is_good_smiles``.

    Each distinct SMILES is validated at most once. Verdicts (good or broken)
    are memoised, so a molecule repeated across many pairs, or used as both
    members of one pair, is not re-checked.

    Args:
        pairs: Iterable of 4-tuples ``(smiles1, smiles2, rule, rule_number)``.
        tokenizer: Forwarded to ``is_good_smiles``.
        vocabulary: Forwarded to ``is_good_smiles``.

    Returns:
        List of ``(smiles1, smiles2, rule, rule_number)`` tuples, in input order,
        for the pairs where both SMILES are good.
    """
    verdicts = {}  # SMILES -> bool, memoised result of is_good_smiles

    def _is_good(smi):
        # Validate a SMILES on first sight only; later lookups reuse the cached verdict.
        if smi not in verdicts:
            verdicts[smi] = is_good_smiles(smi, tokenizer, vocabulary)
        return verdicts[smi]

    good_pairs = []
    for s1, s2, rule, rule_number in pairs:
        # `and` short-circuits: s2 is not validated when s1 already fails,
        # since the pair is dropped either way.
        if _is_good(s1) and _is_good(s2):
            good_pairs.append((s1, s2, rule, rule_number))
    return good_pairs

# --- Model artefacts (the tokenizer/vocabulary decide which SMILES are usable) ---
config_path = "/projects/mai/se_mai/users/kvvq085_Mary/config.yml"
checkpoint_path = "/projects/mai/se_mai/users/kvvq085_Mary/weights.ckpt"
vocabulary_path = "/projects/mai/se_mai/users/kvvq085_Mary/vocabulary.pkl"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_molfeatures = load_model(config_path, checkpoint_path, vocabulary_path, device=device)

# Load the positive validation pairs
pos_val_pairs = pd.read_csv("../pos_pairs_val_new_rnuumber.csv")

# Extract (mol1, mol2, rule, rule_number) tuples. itertuples(name=None) yields plain
# tuples and avoids iterrows' costly per-row Series construction.
pos_pairs = list(
    pos_val_pairs[['mol1', 'mol2', 'rule', 'rule_number']].itertuples(index=False, name=None)
)

# Keep only the pairs whose two SMILES are both representable by the model's vocabulary
good_pairs = process_pairs(pos_pairs, model_molfeatures.tokenizer, model_molfeatures.vocabulary)
good_pairs_df = pd.DataFrame(good_pairs, columns=['mol1', 'mol2', 'rule', 'rule_number'])

# Persist the filtered pairs to CSV (without the pandas index)
good_pairs_df.to_csv("pos_goodpairs_val.csv", index=False)

  state_dict = torch.load(checkpoint_path, map_location=device)["state_dict"]
[11:20:52] Explicit valence for atom # 18 B, 5, is greater than permitted
[11:21:27] Explicit valence for atom # 0 B, 5, is greater than permitted
[11:23:20] Explicit valence for atom # 0 B, 5, is greater than permitted
[11:24:00] Can't kekulize mol.  Unkekulized atoms: 3 5 27 28 29 31
[11:24:00] Can't kekulize mol.  Unkekulized atoms: 3 5 27 28 29 31
[11:24:00] Can't kekulize mol.  Unkekulized atoms: 3 5 27 28 29 31
[11:24:00] Can't kekulize mol.  Unkekulized atoms: 3 5 27 28 29 31
[11:24:00] Can't kekulize mol.  Unkekulized atoms: 3 5 27 28 29 31
[11:24:00] Can't kekulize mol.  Unkekulized atoms: 3 5 27 28 29 31
[11:24:56] Explicit valence for atom # 16 B, 5, is greater than permitted
[11:25:16] Explicit valence for atom # 14 B, 5, is greater than permitted
[11:25:27] Can't kekulize mol.  Unkekulized atoms: 3 5
[11:25:27] Can't kekulize mol.  Unkekulized atoms: 3 5
[11:25:27] Can't kekulize mol.  Unkekulize

In [8]:
# Filtered shape vs. raw shape: (rows kept, columns) then (rows read, columns).
print(*(frame.shape for frame in (good_pairs_df, pos_val_pairs)))

(99959, 4) (99989, 5)


In [11]:
pos_tr_pairs = pd.read_csv("../pos_pairs_train_new_rnuumber.csv")

# Extract (mol1, mol2, rule, rule_number) tuples. itertuples(name=None) yields plain
# tuples and avoids iterrows' costly per-row Series construction on ~1e6 rows.
pos_pairs = list(
    pos_tr_pairs[['mol1', 'mol2', 'rule', 'rule_number']].itertuples(index=False, name=None)
)

# Keep only the pairs whose two SMILES are both representable by the model's vocabulary
good_pairs = process_pairs(pos_pairs, model_molfeatures.tokenizer, model_molfeatures.vocabulary)
good_pairs_df = pd.DataFrame(good_pairs, columns=['mol1', 'mol2', 'rule', 'rule_number'])

# Persist the filtered pairs to CSV (without the pandas index)
good_pairs_df.to_csv("pos_goodpairs_tr.csv", index=False)

[11:46:57] Can't kekulize mol.  Unkekulized atoms: 3 5 9 10 12 15
[11:46:57] Can't kekulize mol.  Unkekulized atoms: 3 5 9 10 12 15
[11:46:57] Can't kekulize mol.  Unkekulized atoms: 3 5 9 10 12 15
[11:46:57] Can't kekulize mol.  Unkekulized atoms: 3 5 9 10 12 15
[11:46:57] Can't kekulize mol.  Unkekulized atoms: 3 5 9 10 12 15
[11:46:57] Can't kekulize mol.  Unkekulized atoms: 3 5 9 10 12 15
[11:48:05] Explicit valence for atom # 0 B, 5, is greater than permitted
[11:48:45] Explicit valence for atom # 0 B, 5, is greater than permitted
[11:50:09] Explicit valence for atom # 15 B, 5, is greater than permitted
[11:52:11] Explicit valence for atom # 22 B, 5, is greater than permitted
[12:02:40] Explicit valence for atom # 4 B, 5, is greater than permitted
[12:03:45] Explicit valence for atom # 16 B, 5, is greater than permitted
[12:05:12] Explicit valence for atom # 13 B, 5, is greater than permitted
[12:08:51] Can't kekulize mol.  Unkekulized atoms: 3 5 9 10 12 15
[12:08:51] Can't kekuli

In [12]:
# Filtered shape vs. raw shape: (rows kept, columns) then (rows read, columns).
print(*(frame.shape for frame in (good_pairs_df, pos_tr_pairs)))

(999700, 4) (999985, 5)


In [None]:
neg_tr_pairs = pd.read_csv("../neg_pairs_train_new_rnuumber.csv")

# Extract (mol1, mol2, rule, rule_number) tuples. itertuples(name=None) yields plain
# tuples and avoids iterrows' costly per-row Series construction on ~1e6 rows.
neg_pairs = list(
    neg_tr_pairs[['mol1', 'mol2', 'rule', 'rule_number']].itertuples(index=False, name=None)
)

# Keep only the pairs whose two SMILES are both representable by the model's vocabulary
good_pairs = process_pairs(neg_pairs, model_molfeatures.tokenizer, model_molfeatures.vocabulary)
good_pairs_df = pd.DataFrame(good_pairs, columns=['mol1', 'mol2', 'rule', 'rule_number'])

# Persist the filtered pairs to CSV (without the pandas index); this file is re-read later
good_pairs_df.to_csv("neg_goodpairs_tr.csv", index=False)

[13:58:19] Can't kekulize mol.  Unkekulized atoms: 3 5
[13:58:19] Can't kekulize mol.  Unkekulized atoms: 3 5
[13:58:19] Can't kekulize mol.  Unkekulized atoms: 3 5
[13:58:19] Can't kekulize mol.  Unkekulized atoms: 3 5
[13:58:19] Can't kekulize mol.  Unkekulized atoms: 3 5
[13:58:19] Can't kekulize mol.  Unkekulized atoms: 3 5
[13:58:21] Can't kekulize mol.  Unkekulized atoms: 3 5
[13:58:21] Can't kekulize mol.  Unkekulized atoms: 3 5
[13:58:21] Can't kekulize mol.  Unkekulized atoms: 3 5
[13:58:21] Can't kekulize mol.  Unkekulized atoms: 3 5
[13:58:21] Can't kekulize mol.  Unkekulized atoms: 3 5
[13:58:21] Can't kekulize mol.  Unkekulized atoms: 3 5
[13:59:03] Explicit valence for atom # 1 B, 5, is greater than permitted
[14:00:19] Explicit valence for atom # 1 B, 5, is greater than permitted
[14:05:38] Can't kekulize mol.  Unkekulized atoms: 3 5
[14:05:38] Can't kekulize mol.  Unkekulized atoms: 3 5
[14:05:38] Can't kekulize mol.  Unkekulized atoms: 3 5
[14:05:38] Can't kekulize mol

In [1]:
import pandas as pd

# Reload both frames from disk so this check runs on a fresh kernel.
neg_tr_pairs = pd.read_csv("../neg_pairs_train_new_rnuumber.csv")
good_pairs_df = pd.read_csv('neg_goodpairs_tr.csv')

# Filtered shape vs. raw shape: (rows kept, columns) then (rows read, columns).
print(*(frame.shape for frame in (good_pairs_df, neg_tr_pairs)))

(999708, 4) (999974, 5)


In [None]:
neg_val_pairs = pd.read_csv("../neg_pairs_val_new_rnuumber.csv")

# Extract (mol1, mol2, rule, rule_number) tuples. itertuples(name=None) yields plain
# tuples and avoids iterrows' costly per-row Series construction.
neg_pairs = list(
    neg_val_pairs[['mol1', 'mol2', 'rule', 'rule_number']].itertuples(index=False, name=None)
)

# Keep only the pairs whose two SMILES are both representable by the model's vocabulary
good_pairs = process_pairs(neg_pairs, model_molfeatures.tokenizer, model_molfeatures.vocabulary)
good_pairs_df = pd.DataFrame(good_pairs, columns=['mol1', 'mol2', 'rule', 'rule_number'])

# Persist the filtered pairs to CSV (without the pandas index); this file is re-read later
good_pairs_df.to_csv("neg_goodpairs_val.csv", index=False)

In [2]:
# Reload both frames from disk so the comparison does not depend on earlier cells' state.
neg_val_pairs = pd.read_csv("../neg_pairs_val_new_rnuumber.csv")
good_pairs_df = pd.read_csv('neg_goodpairs_val.csv')

# Filtered shape vs. raw shape: (rows kept, columns) then (rows read, columns).
print(*(frame.shape for frame in (good_pairs_df, neg_val_pairs)))

(99939, 4) (99971, 5)
