In [None]:
import os
import sys

import pandas as pd
import numpy as np
import pickle as pkl
import functools
from tqdm import tqdm
import jax.numpy as jnp
import jax
import seaborn as sns
sns.reset_defaults()

from rdkit import Chem
from rdkit.Chem import Crippen, QED, AllChem, DataStructs, Descriptors
from rdkit.Contrib.SA_Score import sascorer # type: ignore[ReportMissingImports]
from rdkit import Chem
from rdkit.rdBase import BlockLogs
import matplotlib.pyplot as plt
import matplotlib as mpl

In [None]:
# Load the saved optimisation run.
# NOTE: pickle.load can execute arbitrary code - only open result files you produced/trust.
opt_file_name = 'diffusion_es_opt.pkl'
with open(opt_file_name, mode='rb') as handle:
    opt_data = pkl.load(handle)

In [None]:
colors = sns.color_palette("crest", n_colors=30)
def plot_evo(ax: plt.Axes, scores: list, constraints: list, palette=None):
    """Scatter the constraint-satisfying scores of every generation on ``ax``.

    Generation ``idx`` is drawn in ``palette[idx]``, so colour encodes time along
    the palette gradient.

    Args:
        ax: target matplotlib axes.
        scores: per-generation score arrays. Each is indexed as ``data[:, 0]`` /
            ``data[:, 1]`` (x / y), so presumably shape (n_candidates, 2) - confirm.
        constraints: per-generation constraint arrays aligned with ``scores``. A
            candidate is kept only when the product over the last axis is non-zero,
            i.e. every constraint entry is non-zero.
        palette: optional sequence of colours; defaults to the module-level
            ``colors``. If it has fewer entries than there are generations, a
            "crest" palette with one colour per generation is generated instead of
            failing with IndexError.
    """
    if palette is None:
        palette = colors
    if len(scores) > len(palette):
        palette = sns.color_palette("crest", n_colors=len(scores))

    for idx, data in enumerate(scores):
        keep = np.prod(constraints[idx], -1).astype(bool)
        # asarray so list-of-rows input supports boolean masking as well
        kept = np.asarray(data)[keep]
        ax.scatter(
            kept[:, 0], kept[:, 1], color=palette[idx], alpha=1, s=20, marker='o',
            zorder=2, linewidths=0.5, edgecolors='w',
        )

In [None]:
## JNK3/GSK3b example
# Axis titles (LaTeX for the beta symbol).
xhead = r'JNK3'
yhead = r'GSK3$\beta$'

# Experimental reference compound scores (JNK3, GSK3b). Tick labels elsewhere
# prefix a minus sign, so these are presumably docking-score magnitudes.
exp1 = 9.97
exp2 = 10.08

# Plot windows; the +0.1 lets np.arange include the upper bound in unit-spaced ticks.
xmin = 8
xmax = 12
ymin = 8.5
ymax = 12.5
xticks = np.arange(xmin, xmax + 0.1, 1)
yticks = np.arange(ymin, ymax + 0.1, 1)

# Scores of the first entry at time step 0 (plotted as the start compound).
init1, init2 = opt_data['scores'][0][0]

In [None]:
# Trajectory of (JNK3, GSK3b) scores across generations, 70 mm x 70 mm figure.
bwidth = 0.5
fig = plt.figure(figsize=(70 / 25.4, 70 / 25.4), dpi=300)
ax_ = fig.gca()
fig.subplots_adjust(left=0.17, bottom=0.15, right=0.98, top=0.98)
ax_.spines[:].set_linewidth(bwidth)
ax_.tick_params(axis='both', which='major', width=bwidth, length=2)

# Generations from index 1 on; index 0 holds the start compound, drawn as a star below.
plot_evo(ax_, opt_data['scores'][1:], opt_data['constraints'][1:])

# Common look for the two reference markers, drawn above the scatter.
marker_style = dict(color='grey', linewidths=0.5, zorder=4, edgecolors='w')
ax_.scatter(init1, init2, s=75, marker="*", label='Start Compound', **marker_style)
ax_.scatter(x=exp1, y=exp2, s=50, marker="^", label='Exp. Ref. Compound', **marker_style)

ax_.set_aspect('equal', adjustable='box')
ax_.grid(linewidth=bwidth, ls='-.', alpha=0.4, zorder=0)
ax_.set_xlim(xmin - 0.5, xmax + 0.5)
ax_.set_ylim(ymin - 0.5, ymax + 0.5)
# Tick labels prepend a minus sign to the plotted values.
ax_.set_xticks(xticks, labels=[f'-{tick:.1f}' for tick in xticks])
ax_.set_yticks(yticks, labels=[f'-{tick:.1f}' for tick in yticks])
ax_.set_xlabel(f"{xhead} Docking Score")
ax_.set_ylabel(f"{yhead} Docking Score")
ax_.legend(frameon=False)
fig.savefig('evo_jnk3-gsk3b.pdf')
plt.show()

In [None]:
init_smiles = opt_data['smiles'][0][0]
init_fp = AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(init_smiles), radius=2, nBits=2048)
def calc_constraint_property(smiles: list, constraints: list, ref_fp=None) -> dict:
    """Compute QED, SA score and Tanimoto similarity for constraint-passing molecules.

    For each time step ``idx`` only the SMILES whose constraint row has a non-zero
    product over the last axis are kept. Every kept, parsable molecule contributes
    one entry to each of the flat output lists, tagged with its time step.

    Args:
        smiles: per-time-step sequences of SMILES strings (passed through
            ``np.asarray`` so plain lists work as well as arrays).
        constraints: per-time-step constraint arrays aligned with ``smiles``.
        ref_fp: reference fingerprint for the Tanimoto similarity. Defaults to
            ``init_fp``, the Morgan fingerprint (radius 2, 2048 bits) of
            ``opt_data['smiles'][0][0]``.

    Returns:
        dict of equal-length lists under 'time', 'qed', 'sa', 'tanimoto',
        suitable for ``pd.DataFrame``.
    """
    import warnings

    if ref_fp is None:
        ref_fp = init_fp
    time_list = []
    qed_list = []
    sa_list = []
    tanimoto_list = []
    for idx, smi in enumerate(smiles):
        _filt_idx = np.prod(constraints[idx], -1).astype(bool)
        kept_smiles = np.asarray(smi)[_filt_idx]
        mols = [Chem.MolFromSmiles(str(s)) for s in kept_smiles]
        # MolFromSmiles returns None on parse failure; QED/SA would crash on None.
        n_bad = sum(m is None for m in mols)
        if n_bad:
            warnings.warn(f"time step {idx}: skipping {n_bad} unparsable SMILES")
            mols = [m for m in mols if m is not None]
        fps = [AllChem.GetMorganFingerprintAsBitVect(m, radius=2, nBits=2048) for m in mols]
        qed_list.extend(QED.qed(m) for m in mols)
        sa_list.extend(sascorer.calculateScore(m) for m in mols)
        tanimoto_list.extend(DataStructs.TanimotoSimilarity(fp, ref_fp) for fp in fps)
        time_list.extend([idx] * len(mols))
    return {'time': time_list, 'qed': qed_list, 'sa': sa_list, 'tanimoto': tanimoto_list}

In [None]:
# One row per constraint-passing molecule per time step.
prop_records = calc_constraint_property(
    smiles=opt_data['smiles'],
    constraints=opt_data['constraints'],
)
df_prop = pd.DataFrame(prop_records)

In [None]:
# QED over time steps: per-step mean (seaborn default estimator) with a 50 % percentile band.
bwidth = 0.5
fig = plt.figure(figsize=(40 / 25.4, 30 / 25.4), dpi=300)  # 40 mm x 30 mm
ax = fig.gca()
fig.subplots_adjust(left=0.2, bottom=0.2, right=0.9, top=0.9)
ax.spines[:].set_linewidth(bwidth)
ax.tick_params(axis='both', which='major', width=bwidth, length=2)

main_color = '#234e80'
sns.lineplot(
    data=df_prop[df_prop['time'] > 0], x='time', y='qed', marker='none', ax=ax, linewidth=.5,
    errorbar=("pi", 50), color=main_color,
)
# Fade the percentile band (the only collection on the axes at this point).
for collection in ax.collections:
    collection.set_edgecolor('none')
    collection.set_alpha(0.2)

# Grey dashed reference: QED of the first time-0 entry (the start compound).
init_qed = df_prop.loc[df_prop['time'] == 0, 'qed'].iloc[0]
ax.axhline(y=init_qed, color='grey', linestyle='--', linewidth=bwidth, label='Initial QED')
# Red dashed reference at QED = 0.5 (presumably the QED constraint threshold - confirm).
# Previously it was also labelled 'Initial QED', which duplicated the line above.
ax.axhline(y=0.5, color='red', linestyle='--', linewidth=bwidth, label='QED = 0.5')

ax.set_xlabel('')
ax.set_ylabel('')
ax.set_xticks([1, 5, 10, 20, 30])
ax.set_xlim(1, 30)
ax.set_ylim(0.5, 0.8)
fig.savefig('evo_jnk3-gsk3b_qed.pdf')
plt.show()

In [None]:
# SA score over time steps, 40 mm x 30 mm figure.
bwidth = 0.5
fig = plt.figure(figsize=(40 / 25.4, 30 / 25.4), dpi=300)
ax = fig.gca()
fig.subplots_adjust(left=0.2, bottom=0.2, right=0.9, top=0.9)
ax.spines[:].set_linewidth(bwidth)
ax.tick_params(axis='both', which='major', width=bwidth, length=2)

main_color = '#30828c'
# Grey dashed reference: SA score of the first time-0 entry (start compound).
start_sa = df_prop.loc[df_prop['time'] == 0, 'sa'].values[0]
ax.axhline(y=start_sa, color='grey', linestyle='--', linewidth=bwidth)

# Per-step mean (seaborn default estimator) with the central 50 % percentile band.
evolved = df_prop[df_prop['time'] > 0]
sns.lineplot(data=evolved, x='time', y='sa', marker='none', ax=ax,
             linewidth=.5, errorbar=("pi", 50), color=main_color)
# Fade the band: no outline, translucent fill.
for band in ax.collections:
    band.set_edgecolor('none')
    band.set_alpha(0.2)

ax.set_xlabel('')
ax.set_ylabel('')
ax.set_xticks([1, 5, 10, 20, 30])
ax.set_xlim(1, 30)
ax.set_ylim(2.0, 2.6)
fig.savefig('evo_jnk3-gsk3b_sa.pdf')
plt.show()

In [None]:
# Tanimoto similarity to the start compound over time steps, 40 mm x 30 mm figure.
bwidth = 0.5
fig = plt.figure(figsize=(40 / 25.4, 30 / 25.4), dpi=300)
ax = fig.gca()
fig.subplots_adjust(left=0.2, bottom=0.2, right=0.9, top=0.9)
ax.spines[:].set_linewidth(bwidth)
ax.tick_params(axis='both', which='major', width=bwidth, length=2)

main_color = '#B48972'
# Per-step mean (seaborn default estimator) with the central 50 % percentile band.
evolved = df_prop[df_prop['time'] > 0]
sns.lineplot(data=evolved, x='time', y='tanimoto', marker='none', ax=ax,
             linewidth=.5, errorbar=("pi", 50), color=main_color)
# Fade the band: no outline, translucent fill.
for band in ax.collections:
    band.set_edgecolor('none')
    band.set_alpha(0.2)

ax.set_xlabel('')
ax.set_ylabel('')
ax.set_xticks([1, 5, 10, 20, 30])
ax.set_xlim(1, 30)
ax.set_ylim(0.4, 0.6)
fig.savefig('evo_jnk3-gsk3b_tanimoto.pdf')
plt.show()