In [None]:
# Load the dotenv IPython extension and read variables from the .env file.
%load_ext dotenv
%dotenv

import os

# Change the working directory to PROJECT_PATH when it is set; otherwise stay
# in the current directory. The relative paths used later resolve against it.
%cd {os.getenv("PROJECT_PATH") or "."}

# NOTE(review): autoreload mode 1 is documented to reload only modules that
# were registered with %aimport, and no %aimport appears in this notebook --
# confirm whether mode 2 was intended.
%load_ext autoreload
%autoreload 1

from IPython.display import display

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import sys
from pathlib import Path
from absl import logging
from tqdm.notebook import tqdm, trange
from timeit import default_timer as timer
import pickle
from collections import defaultdict

logging.set_verbosity(logging.INFO)

In [None]:
from pandarallel import pandarallel

# Set up pandarallel (used later through DataFrame.parallel_apply) with one
# worker per CPU reported by os.cpu_count() and progress bars switched on.
pandarallel.initialize(
    nb_workers=os.cpu_count(),
    progress_bar=True,
    verbose=0
)

In [None]:
def show_df(df: pd.DataFrame, n: int = 5) -> None:
    """Render the first rows of a DataFrame followed by its shape.

    Args:
        df: Frame to preview.
        n: Number of leading rows to render. Defaults to 5, which is what
            ``DataFrame.head()`` returns when called without arguments, so
            existing ``show_df(df)`` calls behave as before.
    """
    display(df.head(n))
    print(df.shape)

In [None]:
from rdkit import Chem

from src.utils.scores import *
from src.vae import load_vae
from src.pinn.pde import load_wavepde
from src.pinn import VAEGenerator

In [None]:
# Property being analysed and whether it is one of the minimised properties.
prop = 'qed'
reverse = prop in MINIMIZE_PROPS

# One (csv path, display label) pair per optimisation method. The paths share
# a common template, so only the method tag and step mode are spelled out.
files = [
    (f'data/interim/optimization/{prop}_{tag}_0.1_{mode}.csv', label)
    for tag, mode, label in (
        ('random', 'absolute', 'Random'),
        ('random_1d', 'absolute', 'Random 1D'),
        ('limo', 'relative', 'Gradient Flow'),
        ('chemspace', 'absolute', 'ChemSpace'),
        ('wave_sup', 'relative', 'Wave eqn. (spv)'),
        ('wave_unsup', 'relative', 'Wave eqn. (unsup)'),
        ('hj_sup', 'relative', 'HJ eqn. (spv)'),
        ('hj_unsup', 'relative', 'HJ eqn. (unsup)'),
        ('fp', 'relative', 'Langevin Dynamics'),
    )
]

In [None]:
# For every optimisation run, annotate each valid molecule with its similarity
# to, and property change from, the starting molecule of the same trajectory,
# then write the annotated frame to a `sims/` folder next to the input file.
for file, name in files:
    df_raw = pd.read_csv(file, index_col=0)
    # Key the starting molecules by `idx` so the per-row lookups below do not
    # depend on how the CSV's own index column happens to be numbered.
    df_init = df_raw.query('t == 0').set_index('idx')


    def func(x: pd.Series):
        """Add `sim` and `delta` to one row.

        Rows whose SMILES cannot be parsed, or whose similarity/delta cannot
        be computed, are returned without (both of) the new fields; they end
        up with NaNs after the apply and are removed by the `dropna()` below.
        """
        mol = Chem.MolFromSmiles(x['smiles'])
        if mol is None:
            return x

        if x['t'] == 0:
            # The starting molecule is identical to itself.
            x['sim'] = 1
            x['delta'] = 0
            return x

        try:
            x['sim'] = ssim(x['smiles'], df_init.loc[x['idx'], 'smiles'])
            x['delta'] = x[prop] - df_init.loc[x['idx'], prop]
        except Exception:
            # Deliberate best-effort: leave the row incomplete so it is dropped.
            pass
        return x


    df_imp = df_raw.parallel_apply(func, axis=1).dropna()

    file_path = Path(file)
    out_dir = file_path.parent / 'sims'
    # to_csv does not create missing directories, so make sure it exists.
    out_dir.mkdir(parents=True, exist_ok=True)
    df_imp.to_csv(out_dir / file_path.name)

In [None]:
# Mirror every entry of `files` onto its counterpart inside the sibling
# `sims/` directory, keeping the display label.
sim_files = [
    (Path(src).parent / 'sims' / Path(src).name, label)
    for src, label in files
]

print(sim_files)

In [None]:
# Summarise, per method and similarity threshold, the best improvement each
# trajectory reaches: "mean ± std (success %)" over trajectories whose best
# delta is positive.
n = 800        # number of trajectories (rows of the delta grid)
steps = 1000   # number of optimisation steps (columns of the delta grid)
results = []

for file, name in sim_files:
    df = pd.read_csv(file, index_col=0)

    for sim in [0, 0.2, 0.4, 0.6]:
        # Filter from the full frame each time instead of overwriting `df`.
        df_sim = df.query(f'sim >= {sim}')

        # Scatter all deltas into an (n, steps) grid in one vectorised write.
        # Cells without a surviving row stay 0. The grid lives on the CPU, so
        # the cell also runs on machines without a GPU.
        grid = np.zeros((n, steps), dtype=np.float32)
        grid[
            df_sim['idx'].to_numpy(dtype=np.int64),
            df_sim['t'].to_numpy(dtype=np.int64),
        ] = df_sim['delta'].to_numpy()
        deltas = torch.from_numpy(grid)

        # Best delta reached at any step, per trajectory.
        improvements = torch.max(deltas, dim=1).values

        succ = (improvements > 0).sum().item() / n
        improvements = improvements[improvements > 0]
        r = f'{improvements.mean().item():.2f} ± {improvements.std().item():.2f} ({succ * 100:.1f})'

        print(f'{name:<20} {sim:.1f}: {r}')

        results.append({
            'name': name,
            'sim': sim,
            'improvement': r
        })

df_results = pd.DataFrame(results)

show_df(df_results)

In [None]:
# Reshape the long summary into a table: one row per similarity threshold,
# one column per method, each cell holding the formatted improvement string.
results = [
    [
        df_results.loc[
            (df_results['name'] == label) & (df_results['sim'] == threshold),
            'improvement',
        ].values[0]
        for _, label in sim_files
    ]
    for threshold in [0, 0.2, 0.4, 0.6]
]

df_table = pd.DataFrame(
    results,
    index=[f'{threshold:.1f}' for threshold in [0, 0.2, 0.4, 0.6]],
    columns=[label for _, label in sim_files],
)

show_df(df_table)
df_table.to_csv(f'data/interim/optimization/{prop}_improvement.csv')

In [None]:
# Stack the per-method similarity files into a single long frame, tagging
# every row with the display label of the method it came from.
df_all = pd.concat([
    pd.read_csv(path, index_col=0).assign(name=label)
    for path, label in sim_files
])

show_df(df_all)

In [None]:
# Write the combined frame to disk; the next cell reads this same file back.
df_all.to_csv(f'data/interim/optimization/{prop}_all.csv')

In [None]:
# Reload the combined frame from the CSV written by the previous cell, using
# its first column as the index.
df_all = pd.read_csv(f'data/interim/optimization/{prop}_all.csv', index_col=0)

show_df(df_all)

In [None]:
# KDE of the property at every 100th step (plus the final step) for a subset
# of methods, one facet per method.
sns.set_theme(
    context='paper',
    style='ticks',
    palette='tab10',
    font='serif',
)

# Label the x axis with the property actually plotted rather than a fixed name.
prop_label = {'qed': 'QED'}.get(prop, prop)

_df = df_all.query(
    '(t % 100 == 0 or t == 999) and name in ["Random", "ChemSpace", "Gradient Flow", "Wave eqn. (spv)", "Langevin Dynamics"]')

g = sns.displot(_df, x=prop, hue='t', kind='kde', fill=True, col='name', height=2.5,
                col_wrap=5, facet_kws={'sharey': False})
g.set_titles('{col_name}')
g.set_xlabels(prop_label)

if prop == 'plogp':
    # These limits and the dashed reference line were chosen for plogp values;
    # applying them to another property would squash its distribution.
    for ax in g.axes.flat:
        ax.set_xlim(-12, 5)
        ax.axvline(-2.5, color='black', linestyle='--', lw=0.5)

g.savefig(f'figures/optimization/{prop}_spv_kde.pdf')
g.savefig(f'figures/optimization/{prop}_spv_kde.png')

In [None]:
# KDE of the property at every 100th step (plus the final step) for all
# methods, one facet per method.
sns.set_theme(
    context='paper',
    style='ticks',
    palette='tab10',
    font='serif',
)

# Label the x axis with the property actually plotted.
prop_label = {'qed': 'QED'}.get(prop, prop)

_df = df_all.query('t % 100 == 0 or t == 999')

# Bind the grid to `g` only: the previous `g = plt = ...` rebound `plt`,
# shadowing matplotlib.pyplot for every cell executed afterwards.
g = sns.displot(_df, x=prop, hue='t', kind='kde', fill=True, col='name',
                col_wrap=3, height=3, facet_kws={'sharey': False})
g.set_titles('{col_name}')
g.set_xlabels(prop_label)
g.savefig(f'figures/optimization/{prop}_kde.pdf')
g.savefig(f'figures/optimization/{prop}_kde.png')

In [None]:
# Count how often each SMILES string occurs among the ChemSpace rows at t == 999.
df_all.query('name == "ChemSpace" and t == 999').smiles.value_counts()

In [None]:
# Count how often each SMILES string occurs among the Langevin Dynamics rows at t == 999.
df_all.query('name == "Langevin Dynamics" and t == 999').smiles.value_counts()

In [None]:
# The LIMO run is labelled "Gradient Flow" in `files`, so no row is ever named
# "LIMO"; query by the label that is actually stored in the `name` column.
df_all.query('name == "Gradient Flow" and t == 999').smiles.value_counts()

In [None]:
# Count how often each SMILES string occurs among the Random rows at t == 999.
df_all.query('name == "Random" and t == 999').smiles.value_counts()

In [None]:
# Count how often each SMILES string occurs among the Wave eqn. (unsup) rows at t == 999.
df_all.query('name == "Wave eqn. (unsup)" and t == 999').smiles.value_counts()

In [None]:
# Bare expression: the cell output is seaborn's "rocket" palette object.
sns.color_palette("rocket")

In [None]:
# Convergence curves: for each method and similarity threshold, the running
# best delta per trajectory, averaged over all trajectories at every step.
# Relies on `n` and `steps` defined in an earlier cell.
results = []

for file, name in tqdm(sim_files):
    df = pd.read_csv(file, index_col=0)

    for sim in [0, 0.2, 0.4, 0.6]:
        # Filter from the full frame each time instead of overwriting `df`.
        df_sim = df.query(f'sim >= {sim}')

        # Scatter all deltas into an (n, steps) grid in one vectorised write
        # instead of one scalar tensor assignment per row. Cells without a
        # surviving row stay 0.
        grid = np.zeros((n, steps), dtype=np.float32)
        grid[
            df_sim['idx'].to_numpy(dtype=np.int64),
            df_sim['t'].to_numpy(dtype=np.int64),
        ] = df_sim['delta'].to_numpy()
        deltas = torch.from_numpy(grid)

        # Running maximum along the step axis, then mean over trajectories.
        improvements = torch.cummax(deltas, dim=1).values.mean(dim=0)  # (steps,)

        results.append(pd.DataFrame({
            't': range(steps),
            'improvement': improvements.cpu().numpy(),
            'name': name,
            'sim': sim
        }))

df_results = pd.concat(results)

show_df(df_results)

In [None]:
# Line plot of the convergence curves, one facet per similarity threshold and
# one line per method.
sns.set_theme(
    context='paper',
    style='darkgrid',
    palette='tab10',
    font='serif',
)

# Name the property actually plotted on the y axis.
prop_label = {'qed': 'QED'}.get(prop, prop)

_df = df_results

g = sns.relplot(
    data=_df, x='t', y='improvement', hue='name', col='sim',
    kind='line', height=3, errorbar=None, col_wrap=2,
    facet_kws={'sharey': False}, aspect=1.5,
)
# Drop the legend title.
g.legend.set_title('')
# Raw string: '\d' is not a valid escape sequence in a normal string literal.
g.set_titles(r'$\delta$ = {col_name}')
g.set_xlabels('Steps')
g.set_ylabels(f'Improvement in {prop_label}')
g.savefig(f'figures/optimization/{prop}_conv.pdf')
g.savefig(f'figures/optimization/{prop}_conv.png')