In [None]:
import os
import sys
from rdkit import Chem
import time
import math
from tqdm import tqdm
import numpy as np
from sklearn.linear_model import LinearRegression 
from sklearn import metrics
import torch.optim as optim
import torch
from torch import nn
import numpy as np
import pandas as pd
# %matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from torch_geometric.loader import DataLoader
from rdkit.Chem import AllChem
from rdkit.Chem import RDConfig
from rdkit.Chem import Descriptors
from rdkit.Chem import rdDepictor
sys.path.append(os.path.join(RDConfig.RDContribDir, "SA_Score"))
import sascorer
sns.set_theme(style="white", palette=None)

from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
from rdkit.Chem.Draw import IPythonConsole
# IPythonConsole.molSize = (200, 200)   # Change image size
IPythonConsole.ipython_useSVG = True  # Change output to SVG

from catcvae.utils import smiles_to_mol, mol_to_smiles
from catcvae.dataset import getDatasetFromFile, getDatasetObject, getDatasetSplittingFinetune

# display_molecule

In [None]:
def display_molecule(molecules, title=None, texts=None):
    """Draw molecules as images on a grid that is five panels wide.

    Args:
        molecules: Sequence of molecule objects accepted by
            ``Chem.Draw.MolToImage``.
        title: Optional sequence holding one panel title per molecule.
        texts: Optional sequence holding one caption per molecule.

    Returns:
        None. The figure is created with ``plt.subplots`` and is neither
        returned nor explicitly shown here.
    """
    n_cols = 5
    n_mols = len(molecules)
    if n_mols == 0:
        # plt.subplots refuses a grid with zero rows, so there is nothing to draw.
        return

    n_rows = math.ceil(n_mols / n_cols)
    # squeeze=False always yields a 2-D axes array, whatever the row count.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, math.ceil(n_mols * 0.75)), dpi=300, squeeze=False)
    fig.subplots_adjust(hspace=.5, wspace=.001)

    # Explicit None/length checks: plain truthiness raises for array-like inputs.
    has_title = title is not None and len(title) > 0
    has_texts = texts is not None and len(texts) > 0

    for i, ax in enumerate(axs.ravel()):
        if i < n_mols:
            ax.imshow(Chem.Draw.MolToImage(molecules[i]))
            ax.axis('off')
            if has_title:
                ax.set_title(title[i])
            if has_texts:
                # Position is in image data coordinates; presumably just below
                # the default-size drawing -- TODO confirm for other image sizes.
                ax.text(100, 350, texts[i], fontsize=12)
        else:
            # Unused trailing panels of the last row are blanked out.
            ax.axis('off')

# Read file

In [None]:
# Identify the generation run to analyse and load its CSV
# (the file has no header row; the two columns are named here).
file = 'ps'
seed = 42
folder = 'output_0_42_20250428_203238_3809035'
df_name = 'generated_mol_lat_20250429_004122'
df_gen = pd.read_csv(
    '/'.join(['dataset', file, folder, df_name + '.csv']),
    header=None,
    names=['smiles', 'predicted'],
)
print("df: ", len(df_gen))

In [None]:
# group the predicted values by SMILES (no averaging is done in this function)
def get_dict(df):
    """Group the ``predicted`` values of ``df`` by their ``smiles`` value.

    Args:
        df: DataFrame with a ``smiles`` column and a ``predicted`` column.

    Returns:
        dict mapping each distinct SMILES to the list of its predicted values,
        in row order. Keys appear in order of first occurrence.
    """
    dict_mol = {}
    # Zipping the two columns avoids building a Series per row as iterrows() does.
    for smiles, predicted in zip(df['smiles'], df['predicted']):
        dict_mol.setdefault(smiles, []).append(predicted)
    return dict_mol

# Collect every predicted value recorded for each distinct generated SMILES
dict_gen = get_dict(df_gen)

# Number of distinct SMILES strings among the generated rows
print("dict_gen: ", len(dict_gen))

In [None]:
# One record per distinct SMILES: mean and standard deviation of its predictions
records = [
    {'smiles': smiles, 'avg': np.mean(values), 'std': np.std(values)}
    for smiles, values in dict_gen.items()
]

# Order by mean prediction, lowest first, and renumber the rows from 0
result_gen = (
    pd.DataFrame(records)
    .sort_values(by='avg', ascending=True)
    .reset_index(drop=True)
)

result_gen

In [None]:
# First 20 rows of result_gen in its current row order
result_rand_sample = result_gen.head(20)

molecules = [smiles_to_mol(smi, with_atom_index=False) for smi in result_rand_sample['smiles']]
# Alternative caption with the spread:
# texts = [f'{row["avg"]:.2f} ± {row["std"]:.2f}' for idx, row in result_rand_sample.iterrows()]
texts = [f'{avg:.2f}' for avg in result_rand_sample['avg']]
# Rank labels are prepared but not passed on (title=None below)
title = range(1, len(molecules)+1)
display_molecule(molecules, title=None, texts=texts)

# Yield

In [None]:
# Hex colour codes shared by the plotting cells of this notebook
color_1 = '#1C3077' # blue
color_2 = '#E97132' # orange
color_3 = '#196B24' # green
color_4 = '#0F9ED5' # sky blue
color_5 = '#A02B93' # purple
color_6 = '#CE1500' # red

In [None]:
# dataset: load the split file and separate the rows tagged 'test' in column 's'
df_dataset = pd.read_csv('dataset/'+file+'/datasets_dobj_split_0.csv')
is_test_row = df_dataset['s'] == 'test'
df_dataset_test = df_dataset[is_test_row]
df_dataset = df_dataset[~is_test_row]

print("dataset: ", len(df_dataset))
print("dataset_test: ", len(df_dataset_test))

In [None]:
# display test molecules: first ten test rows, captioned with their 'y' value
test_preview = df_dataset_test.iloc[:10]
molecules = [smiles_to_mol(smi, with_atom_index=False) for smi in test_preview['smiles_catalyst']]
texts = [f'{y:.2f}' for y in test_preview['y']]
display_molecule(molecules, title=None, texts=texts)


In [None]:
# plot distribution: dataset target vs. mean predicted target, as percentages
fig, ax = plt.subplots(figsize=(8, 5), dpi=300)
for series, colour in ((df_dataset['y'], color_1), (result_gen['avg'], color_2)):
    sns.histplot(list(series), bins=range(0, 101, 5), color=colour, stat='percent', ax=ax)
ax.set_xlabel('Dataset and predicted target')
ax.set_ylabel('Percentage of molecules (%)')
ax.set_title('Dataset and Predicted Target Distribution')
# Labels follow the order in which the histograms were drawn
ax.legend(['Dataset catalyst', 'Generated catalyst'])
sns.despine()
plt.show()

# SAScore

In [None]:
def _sascore_or_nan(smiles):
    """Return the SA score of ``smiles``; print the error and return NaN on failure."""
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Retry without sanitization and refresh the property cache non-strictly
            mol = Chem.MolFromSmiles(smiles, sanitize=False)
            mol.UpdatePropertyCache(strict=False)
        return sascorer.calculateScore(mol)
    except Exception as e:
        print(e)
        return np.nan

sascore = [_sascore_or_nan(smiles) for smiles in result_gen['smiles']]

result_gen['sascore'] = sascore
result_gen

In [None]:
# sort by sascore (lowest first), drop rows containing NaN, renumber from 0
result_gen = (
    result_gen
    .sort_values(by='sascore', ascending=True)
    .dropna()
    .reset_index(drop=True)
)
result_gen

In [None]:
# plot distribution of SA scores in unit-wide bins from 0 to 10
fig, ax = plt.subplots(figsize=(8, 5), dpi=300)
sns.histplot(list(result_gen['sascore']), bins=range(0, 11, 1), color=color_1, ax=ax)
ax.set_xlabel('SAScore')
ax.set_ylabel('Frequency')
ax.set_title('SAScore Distribution')
sns.despine()
plt.show()

# Filtering

In [None]:
from rdkit import DataStructs
def get_fingerprint_dictionary(smiles_list, radius=2, nBits=2048):
    """Precompute Morgan bit-vector fingerprints keyed by SMILES string.

    Args:
        smiles_list: Iterable of SMILES strings.
        radius: Morgan radius. Use the same value that is later passed to
            ``similarity`` so cached and freshly computed fingerprints agree.
        nBits: Length of the bit vector.

    Returns:
        dict mapping each parsable SMILES string to its fingerprint. Entries
        that are not strings, or that RDKit cannot parse, are left out.
    """
    result = {}
    for smiles in tqdm(smiles_list):
        # Missing values read from a CSV arrive as float NaN, not as strings
        if not isinstance(smiles, str):
            continue
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue
        result[smiles] = AllChem.GetMorganFingerprintAsBitVect(
            mol, radius=radius, nBits=nBits, useChirality=False
        )
    return result

def similarity(a, b, radius=2, dictionary=None):
    """Tanimoto similarity between the Morgan fingerprints of two SMILES strings.

    Args:
        a: First SMILES string, or None.
        b: Second SMILES string, or None.
        radius: Morgan radius used when fingerprints are computed here.
        dictionary: Optional mapping SMILES -> precomputed fingerprint. It is
            used only when BOTH strings are keys; ``radius`` is then ignored.

    Returns:
        The Tanimoto similarity, or 0.0 when either input is None or cannot be
        parsed by RDKit.
    """
    if a is None or b is None:
        return 0.0

    both_cached = dictionary and a in dictionary and b in dictionary
    if not both_cached:
        mols = [Chem.MolFromSmiles(smi) for smi in (a, b)]
        if any(mol is None for mol in mols):
            return 0.0
        fingerprints = [
            AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=2048, useChirality=False)
            for mol in mols
        ]
    else:
        fingerprints = [dictionary[a], dictionary[b]]

    return DataStructs.TanimotoSimilarity(fingerprints[0], fingerprints[1])

def similarity_to_nearest_neighbor(smiles_list, ref_list, radius=2, dictionary=None):
    """Mean and standard deviation of each query's best similarity to a reference set.

    Args:
        smiles_list: Query SMILES strings.
        ref_list: Reference SMILES strings each query is compared against.
        radius: Forwarded to ``similarity``.
        dictionary: Forwarded to ``similarity``.

    Returns:
        Tuple ``(mean, std)`` over the per-query maxima. A query with no
        reference scoring above zero contributes 0.
    """
    nearest = [
        # The leading 0 is the floor used when nothing scores higher
        max([0] + [similarity(query, ref, radius=radius, dictionary=dictionary) for ref in ref_list])
        for query in smiles_list
    ]
    return np.mean(nearest), np.std(nearest)

# get fingerprint dictionary from dataset (catalyst SMILES of the non-test rows)
training_smiles = df_dataset['smiles_catalyst'].tolist()
fingerprint_dict = get_fingerprint_dictionary(training_smiles)

In [None]:
def _passes_structure_filters(mol_catalyst, file):
    """Return True when ``mol_catalyst`` passes the structural rules selected by ``file``.

    ``file`` is matched by substring against the tags 'sm', 'l_sm' and 'ps';
    the rules of every matching tag are applied.
    """
    # Only the fragment count is needed, so the fragments are not materialised as Mols
    n_fragments = len(Chem.GetMolFrags(mol_catalyst))
    is_l_sm = 'l_sm' in file

    # 'sm' is also a substring of 'l_sm'. Without this exclusion an 'l_sm' run
    # would need exactly 4 AND exactly 1 fragments, so nothing could ever pass.
    if 'sm' in file and not is_l_sm:
        # check number of fragments
        if n_fragments != 4:
            return False

    if is_l_sm:
        # check number of fragments
        if n_fragments != 1:
            return False
        # require at least one P or N atom in the molecule
        symbols = {atom.GetSymbol() for atom in mol_catalyst.GetAtoms()}
        if 'P' not in symbols and 'N' not in symbols:
            return False
        for atom in mol_catalyst.GetAtoms():
            symbol = atom.GetSymbol()
            if symbol == 'P':
                # every P must have exactly 3 neighbours
                if atom.GetDegree() != 3:
                    return False
                # no P may sit in a 3-membered ring
                if atom.IsInRingSize(3):
                    return False
            elif symbol == 'O':
                # no O may have more than 2 neighbours
                if atom.GetDegree() > 2:
                    return False
        # reject any ring with fewer than 5 atoms (3- and 4-membered rings)
        if any(len(ring) < 5 for ring in Chem.GetSymmSSSR(mol_catalyst)):
            return False

    if 'ps' in file:
        # check number of fragments
        if n_fragments != 1:
            return False

    return True


filter_index = []
# Only filled by the disabled similarity filter below; kept for the cell that displays it
not_pass_sim = []
for i, row in result_gen.iterrows():
    mol_catalyst = Chem.MolFromSmiles(row['smiles'])
    if mol_catalyst is None:
        # SMILES that RDKit cannot sanitise cannot be checked; treat as not passing
        continue

    # Disabled: nearest-neighbour similarity filter
    # sim = similarity_to_nearest_neighbor([row['smiles']], training_smiles, dictionary=fingerprint_dict)
    # if sim[0] < 0.3:
    #     not_pass_sim.append(mol_catalyst)
    #     continue

    if _passes_structure_filters(mol_catalyst, file):
        filter_index.append(i)

# iterrows() yields index labels, so select by label rather than by position
result_gen_filtered = result_gen.loc[filter_index]
result_gen_filtered = result_gen_filtered.reset_index(drop=True)
print("result_gen_filtered: ", len(result_gen_filtered))


In [None]:
# molecules = [mol for mol in not_pass_sim[:20]]
# display_molecule(molecules, title=None, texts=None)

In [None]:
# Inspect the generated molecules that passed the structural filters
result_gen_filtered

In [None]:
# plot yield compare: all generated molecules vs. those kept by the filters
fig, ax = plt.subplots(figsize=(8, 5), dpi=300)
for frame, colour in ((result_gen, color_1), (result_gen_filtered, color_2)):
    sns.histplot(list(frame['avg']), bins=range(0, 101, 5), color=colour, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Number of molecules')
ax.set_title('Predicted Target Distribution')
# Labels follow the order in which the histograms were drawn
ax.legend(['Generated catalyst', 'Filtered generated catalyst'])
sns.despine()
plt.show()

In [None]:
# percent filtering: share of generated molecules kept by the filters
n_generated = len(result_gen)
n_filtered = len(result_gen_filtered)
print("all", n_generated)
print("filtered", n_filtered)
print("percent filtering: ", n_filtered/n_generated*100)

In [None]:
# Rank the filtered molecules by mean prediction, highest first, and keep the first 20
result_gen_filtered = (
    result_gen_filtered
    .sort_values(by='avg', ascending=False)
    .reset_index(drop=True)
)
result_rand_sample_filtered = result_gen_filtered.head(20)

molecules = [smiles_to_mol(smi, with_atom_index=False) for smi in result_rand_sample_filtered['smiles']]
# Alternative caption with the spread:
# texts = [f'{row["avg"]:.2f} ± {row["std"]:.2f}' for idx, row in result_rand_sample_filtered.iterrows()]
texts = [f'{avg:.2f}' for avg in result_rand_sample_filtered['avg']]
# Rank labels are prepared but not passed on (title=None below)
title = range(1, len(molecules)+1)
display_molecule(molecules, title=None, texts=texts)

In [None]:
# only not in training dataset (compared on canonical, non-chiral SMILES)
smiles_dataset = [Chem.CanonSmiles(smiles, useChiral=False) for smiles in training_smiles]
# Set copy for constant-time membership tests; the list itself is kept unchanged
smiles_dataset_lookup = set(smiles_dataset)

novel_molecules_index = [
    i
    for i, smiles in result_gen_filtered['smiles'].items()
    if Chem.CanonSmiles(smiles, useChiral=False) not in smiles_dataset_lookup
]

# The collected values are index labels, so select by label rather than by position
result_gen_filtered_novel = result_gen_filtered.loc[novel_molecules_index]
result_gen_filtered_novel = result_gen_filtered_novel.reset_index(drop=True)
print("result_gen_filtered_novel: ", len(result_gen_filtered_novel))

In [None]:
# Rank the novel filtered molecules by mean prediction, highest first, and keep the first 20
result_gen_filtered_novel = (
    result_gen_filtered_novel
    .sort_values(by='avg', ascending=False)
    .reset_index(drop=True)
)
result_rand_sample_filtered_novel = result_gen_filtered_novel.head(20)

molecules = [smiles_to_mol(smi, with_atom_index=False) for smi in result_rand_sample_filtered_novel['smiles']]
# Alternative caption with the spread:
# texts = [f'{row["avg"]:.2f} ± {row["std"]:.2f}' for idx, row in result_rand_sample_filtered.iterrows()]
texts = [f'{avg:.2f}' for avg in result_rand_sample_filtered_novel['avg']]
# Rank labels are prepared but not passed on (title=None below)
title = range(1, len(molecules)+1)
display_molecule(molecules, title=None, texts=texts)

In [None]:
# First 20 novel molecules whose SA score is at least 5.5
high_sascore = result_gen_filtered_novel[result_gen_filtered_novel['sascore'] >= 5.5]
molecules = [smiles_to_mol(smi, with_atom_index=False) for smi in high_sascore['smiles'].iloc[:20]]
display_molecule(molecules, title=None, texts=None)

In [None]:
# percent filtering: share of generated molecules that are both filtered and novel
n_generated = len(result_gen)
n_novel = len(result_gen_filtered_novel)
print("all", n_generated)
print("filtered_novel", n_novel)
print("percent filtered_novel: ", n_novel/n_generated*100)

In [None]:
# Molecule objects for the image column of the spreadsheet
result_gen_filtered_novel['MolImage'] = result_gen_filtered_novel['smiles'].apply(Chem.MolFromSmiles)
# rename() returns a new frame here; renaming a column slice in place raises
# SettingWithCopyWarning. reset_index() keeps the row number as an 'index' column.
result_gen_filtered_novel_columns = (
    result_gen_filtered_novel[['smiles', 'avg', 'MolImage']]
    .rename(columns={'smiles': 'SMILES', 'avg': 'Predicted'})
    .reset_index()
)

# save to excel with molecule image
from rdkit.Chem import PandasTools
PandasTools.SaveXlsxFromFrame(result_gen_filtered_novel_columns, 'dataset/'+file+'/'+folder+'/'+df_name+'_filtered.xlsx', molCol='MolImage')

In [None]:
# Inspect the table that was written to the spreadsheet
result_gen_filtered_novel_columns

In [None]:
# only not in training and test dataset (compared on canonical, non-chiral SMILES)
smiles_dataset_test = [Chem.CanonSmiles(smiles, useChiral=False) for smiles in df_dataset_test['smiles_catalyst'].values]
# assign() builds a new frame: df_dataset_test comes from a boolean selection of
# df_dataset, and writing a column into it in place raises SettingWithCopyWarning.
df_dataset_test = df_dataset_test.assign(smiles_canon=smiles_dataset_test)
# Set copy for constant-time membership tests
smiles_dataset_test_lookup = set(smiles_dataset_test)

novel_molecules_index = []
for i, row in result_gen_filtered_novel.iterrows():
    smiles_cat = Chem.CanonSmiles(row['smiles'], useChiral=False)
    if smiles_cat not in smiles_dataset_test_lookup:
        novel_molecules_index.append(i)
        continue

    # The generated molecule is already in the test split: show it next to the
    # first matching test row's 'y' value.
    print("in dataset test: ", smiles_cat)
    testing_output = df_dataset_test.loc[df_dataset_test['smiles_canon'] == smiles_cat, 'y'].values[0]
    display_molecule([Chem.MolFromSmiles(smiles_cat), Chem.MolFromSmiles(smiles_cat)], title=['testing', 'generated'], texts=[testing_output, f"{row['avg']:.2f}"])

# iterrows() yields index labels, so select by label rather than by position
result_gen_filtered_novel_test = result_gen_filtered_novel.loc[novel_molecules_index]
result_gen_filtered_novel_test = result_gen_filtered_novel_test.reset_index(drop=True)
print("result_con_lat_filtered_novel_test: ", len(result_gen_filtered_novel_test))