In [None]:
# Run this cell before any other: it defines the shared color palette so that
# every plot in the notebook uses consistent colors.
my_palette = [
    '#3C91E6',  # blue
    '#C03221',  # red
    '#1B9D78',  # green
    '#C1839F',  # mauve
    '#020402',  # near-black (the property plots reuse this last entry for QM9)
]

The following cells generate the uniqueness/novelty (UN) persistence plots, which show how each metric evolves as more samples are drawn from each prior.

In [None]:
import random
from amcg_utils.eval_utils import get_un_evaluation_df
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Build one uniqueness/novelty (UN) summary DataFrame per prior, each evaluated
# against the cleaned QM9 SMILES file. The unpacking order (VAE-like, GMM-F,
# GMM-D1, GMM-D2) matches the legend order used when plotting.
df_1, df_2, df_3, df_4 = [
    get_un_evaluation_df("data/orig_smiles/qm9_clean.txt", sampled_path, 2000)
    for sampled_path in (
        "data/sampled_smiles/qm9_vaelike.txt",
        "data/sampled_smiles/qm9_full.txt",
        "data/sampled_smiles/qm9_diag.txt",
        "data/sampled_smiles/qm9_diag_2.txt",
    )
]

In [None]:
# Show one line plot per UN metric (one column of the UN summary frames),
# with one line per prior.

LEGENDS = ['VAE-like', 'GMM-F', 'GMM-D1','GMM-D2']
# NOTE(review): RESULTS_FOLDER and dizz are not used in this cell; presumably kept
# for saving figures under short metric names — confirm before removing.
RESULTS_FOLDER="eval_results"
dizz = {'Uniqueness': 'uniqueness', 'Novelty':'novelty', 'Uniqueness * Novelty': 'un', 'Non Unique Novelty': 'novelty_nu'}

# Explicit list instead of composing variable names with eval(); order matches LEGENDS.
un_frames = [df_1, df_2, df_3, df_4]

for column in df_1.columns:
    # One column per prior; the Series are aligned on their index.
    df = pd.DataFrame({legend: un_frame[column] for legend, un_frame in zip(LEGENDS, un_frames)})
    fig, ax = plt.subplots(figsize=(10, 5))
    sns.lineplot(data=df, dashes=False, palette=my_palette, ax=ax)
    ax.set_xlabel('# Samples / 2000', fontsize=20)
    ax.set_ylabel(column, fontsize=20)
    ax.legend(prop={'family': 'monospace', 'size':14})
    plt.show()
    # Release the figure so a loop over many metrics does not accumulate open figures.
    plt.close(fig)
    

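Here we compute and visualize molecular properties (diversity, logP, QED, SA score and heavy molecular weight) of a QM9 reference subset and of the unique and novel samples generated by each prior.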

In [None]:
from amcg_utils.gen_utils import read_lines_list
from amcg_utils.eval_utils import get_un_smiles
from amcg_utils.mol_prop_utils import get_diversity, get_props_df_from_list, process_dataframe

In [None]:
# Reference set: a random subset of QM9. A dedicated, seeded RNG makes this subset
# reproducible on Restart & Run All. Every diversity and property result below
# depends on it.
QM9_SAMPLE_SIZE = 10000
QM9_SAMPLE_SEED = 0
# NOTE(review): the UN-evaluation cell reads "data/orig_smiles/qm9_clean.txt" while
# this cell reads "orig_smiles/qm9_clean_smiles.txt" — confirm which file/path is intended.
qm9 = random.Random(QM9_SAMPLE_SEED).sample(
    read_lines_list("orig_smiles/qm9_clean_smiles.txt"), QM9_SAMPLE_SIZE
)

# Read the samples drawn from the different priors (VAE-like, GMM-F, GMM-D1, GMM-D2).
# NOTE(review): the UN-evaluation cell uses qm9_diag.txt / qm9_diag_2.txt for the
# diagonal priors, but this cell uses qm9_diag_max.txt / qm9_diag_max_5.txt — confirm.
full_samples_1, full_samples_2, full_samples_3, full_samples_4 = [
    read_lines_list(path)
    for path in (
        "sampled_smiles/qm9_vaelike.txt",
        "sampled_smiles/qm9_full.txt",
        "sampled_smiles/qm9_diag_max.txt",
        "sampled_smiles/qm9_diag_max_5.txt",
    )
]

# Keep only the unique and novel samples of each prior, using the QM9 subset as reference.
un_1, un_2, un_3, un_4 = [
    get_un_smiles(samples, qm9)
    for samples in (full_samples_1, full_samples_2, full_samples_3, full_samples_4)
]

In [2]:
# Diversity score (computed by get_diversity, described as Tanimoto-based in the
# original notes) for the QM9 reference subset and for each prior's unique and
# novel samples.
diversities = {'qm9': get_diversity(qm9)}
for key, un_smiles in (('un_conf_1', un_1),
                       ('un_conf_2', un_2),
                       ('un_conf_3', un_3),
                       ('un_conf_4', un_4)):
    diversities[key] = get_diversity(un_smiles)

In [None]:
# Property tables, one per molecule set:
#   index 0 is the QM9 reference subset;
#   indices 1-4 are the unique and novel samples of each prior.
(prop_df_0, prop_df_1, prop_df_2, prop_df_3, prop_df_4) = [
    get_props_df_from_list(smiles_list)
    for smiles_list in (qm9, un_1, un_2, un_3, un_4)
]

In [None]:
# 2x2 grid of kernel density estimates, one panel per property.
# The QM9 reference is drawn thicker and dashed in the palette's last
# (near-black) color; each prior uses the next palette color.
palette=[my_palette[-1]] + my_palette
LEGENDS = ['QM9', 'VAE-like', 'GMM-F', 'GMM-D1','GMM-D2']

COLUMNS = ['logp','qed','sas','heavymolwt']
label_dizz = {'logp': 'logP', 'qed':'QED', 'sas': 'SA Score', 'nps': 'NP Score', 'num_heavy_atoms': '# heavy atoms', 'heavymolwt': "Heavy mol. weight", 'plogp': "Penalized logP"}

# bw_adjust multiplies the KDE bandwidth: values > 1 give smoother curves,
# values < 1 give more detailed (wigglier) ones. It is not a bin count.
BW_ADJUST = 1.5

# Explicit list instead of composing variable names with eval(); order matches LEGENDS.
prop_frames = [prop_df_0, prop_df_1, prop_df_2, prop_df_3, prop_df_4]

fig, axs = plt.subplots(2, 2, figsize=(12,10))
assert len(COLUMNS) <= axs.size, "COLUMNS has more entries than the subplot grid has panels"
# axs.flat walks the grid row by row, matching the original axs[j//2][j%2] layout.
for ax, column in zip(axs.flat, COLUMNS):
    for i, (label, prop_frame) in enumerate(zip(LEGENDS, prop_frames)):
        is_reference = i == 0
        sns.kdeplot(prop_frame[column], fill=False, color=palette[i], ax=ax,
                    bw_adjust=BW_ADJUST, label=label,
                    linewidth=2 if is_reference else 1,
                    linestyle='dashed' if is_reference else '-')
    ax.set_xlabel(label_dizz[column], fontsize=20)
    ax.set_ylabel('Density', fontsize=20)
    ax.legend(prop={'family': 'monospace', 'size':12})

plt.show()

In [None]:
# Summary table with one row per molecule set.
df_list = [prop_df_0, prop_df_1, prop_df_2, prop_df_3, prop_df_4]
LEGENDS = ['QM9', 'VAE-like', 'GMM-F', 'GMM-D1','GMM-D2']

# Each row condenses the describe() statistics of the first four property columns.
# NOTE(review): process_dataframe presumably flattens a describe() table into a
# single row — confirm in amcg_utils.mol_prop_utils.
rows = [process_dataframe(prop_df.iloc[:, :4].describe()) for prop_df in df_list]

summary = pd.DataFrame(rows, index=LEGENDS)
summary