# Evaluate conditionally generated molecules

In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from plotly import express as px
from tqdm import tqdm

## Load data

In [None]:
# Gather every per-method result CSV, tag its rows with a method label taken
# from the file name, and merge everything into a single sorted frame.
pred_dir = Path("/homes/buttensc/Projects/semla-flow/evaluation/generated/conditional")
files = list(pred_dir.glob("*_conditional.csv"))
dfs = []
for file in tqdm(files):
    # Method label = 2nd and 3rd underscore-separated tokens of the file stem.
    method_label = " ".join(file.stem.split("_")[1:3])
    frame = pd.read_csv(file).assign(method=method_label)
    # Rows without a reference molecule are discarded.
    dfs.append(frame.dropna(subset=["Reference molecule"]))
df = pd.concat(dfs).sort_values(by=["method", "Reference molecule"])


## Load references

In [None]:
# Read the reference table and keep only the rows whose `fail` column is not 1.0.
file = "/homes/buttensc/Projects/semla-flow/evaluation/truth/test_first_1000.csv"
truth = pd.read_csv(file)
not_failed = truth["fail"] != 1.0
truth = truth.loc[not_failed]
print(truth.shape[0])

In [None]:
# Maps a metric's column name to the human-readable label used for display.
metrics = {
    "sucos": "SuCOS",
    "tanimoto": "ECFP4 Bit Tanimoto",
    "ensemble_avg_energy": "Ensemble Average Energy",
    "mol_pred_energy": "Molecular Prediction Energy",
    "energy_ratio": "Energy Ratio",
    # Spelling fixed: "Accessability" -> "Accessibility".
    "sa": "Synthetic Accessibility Score",
    "sa_normalized": "Synthetic Accessibility Score (normalized)",
    # NOTE(review): "Spacial" is kept as-is — presumably the metric's
    # published name; confirm before changing.
    "spacial": "Spacial Score",
    "qed": "QED",
    "logp": "LogP",
    "lipinski": "Lipinski Rule of 5",
    "num_heavy": "Number of Heavy Atoms",
    "weight": "Molecular Weight",
    "num_rings": "Number of Rings",
}

## Show metric distributions

In [None]:
# Per-method density histogram (step outline, unfilled) of the selected metric.
metric = "sucos"  # alternative: "tanimoto"
name = metrics[metric]
# The index is reset before plotting — presumably because the concatenated
# frame carries repeated index labels; verify against how `df` is built.
plot_data = df[["method", metric]].reset_index(drop=True)
# Optional extras to try: cumulative=True, kde=True, legend=True,
# palette="tab10", linewidth=1.5
sns.histplot(
    data=plot_data,
    x=metric,
    hue="method",
    element="step",
    fill=False,
    stat="density",
    common_norm=False,
    bins=100,
)

In [None]:
# Per-method cumulative density histogram (step outline, unfilled) of the
# selected metric.
metric = "tanimoto"  # alternative: "sucos"
name = metrics[metric]
# The index is reset before plotting — presumably because the concatenated
# frame carries repeated index labels; verify against how `df` is built.
plot_data = df[["method", metric]].reset_index(drop=True)
# Optional extras to try: kde=True, legend=True, palette="tab10", linewidth=1.5
sns.histplot(
    data=plot_data,
    x=metric,
    hue="method",
    element="step",
    fill=False,
    stat="density",
    common_norm=False,
    cumulative=True,
    bins=100,
)

# Novelty and uniqueness

In [None]:
# Training split, currently disabled:
# df_train = pd.read_csv("evaluation/truth/train.csv")
# df_train["method"] = "GEOM Drugs Training"

# Load the test split and label its rows.
df_test = pd.read_csv("evaluation/truth/test.csv")
df_test["method"] = "GEOM Drugs Testing"

# Set of reference SMILES: missing values are dropped first, then the
# remaining placeholder entries (such as the empty string) are removed.
placeholders = {None, "", pd.NA, np.nan}
testing_smiles = set(df_test["smiles"].dropna()).difference(placeholders)

In [None]:
def compute_uniqueness(smiles: list[str]) -> float:
    """Return the fraction of valid SMILES entries that are distinct.

    Args:
        smiles: Iterable of SMILES entries. Anything that is not a non-empty
            string (``None``, ``""``, ``pd.NA``, NaN, ...) is ignored.

    Returns:
        ``len(set(valid)) / len(valid)``, or NaN when there is no valid entry
        (instead of raising ``ZeroDivisionError``).
    """
    # A type check is used rather than membership in
    # {None, "", pd.NA, np.nan}: NaN never compares equal to itself, so set
    # membership only catches the very same ``np.nan`` object and lets other
    # NaN floats through as if they were molecules.
    valid_smiles = [s for s in smiles if isinstance(s, str) and s]
    if not valid_smiles:
        return float("nan")
    return len(set(valid_smiles)) / len(valid_smiles)


# Backward-compatible alias for the original, misspelled function name.
compute_uniquenss = compute_uniqueness


def compute_novelty(
    smiles: list[str], reference_smiles: set[str] = testing_smiles
) -> float:
    """Return the fraction of valid SMILES entries absent from the reference set.

    Duplicated entries are counted individually, so a novel molecule that was
    generated twice contributes twice.

    Args:
        smiles: Iterable of SMILES entries. Anything that is not a non-empty
            string (``None``, ``""``, ``pd.NA``, NaN, ...) is ignored.
        reference_smiles: Collection of known SMILES to compare against.
            Defaults to the module-level ``testing_smiles`` as it was when
            this function was defined.

    Returns:
        Number of valid entries not in ``reference_smiles`` divided by the
        number of valid entries, or NaN when there is no valid entry (instead
        of raising ``ZeroDivisionError``).
    """
    # A type check is used rather than membership in
    # {None, "", pd.NA, np.nan}: NaN never compares equal to itself, so set
    # membership only catches the very same ``np.nan`` object.
    valid_smiles = [s for s in smiles if isinstance(s, str) and s]
    if not valid_smiles:
        return float("nan")
    novel_count = sum(s not in reference_smiles for s in valid_smiles)
    return novel_count / len(valid_smiles)


def compute_unique_novelty(
    smiles: list[str], reference_smiles: set[str] = testing_smiles
) -> float:
    """Return the share of valid entries that are distinct, novel molecules.

    The numerator counts each novel SMILES once; the denominator is the number
    of valid entries including duplicates.

    Args:
        smiles: Iterable of SMILES entries. Anything that is not a non-empty
            string (``None``, ``""``, ``pd.NA``, NaN, ...) is ignored.
        reference_smiles: Collection of known SMILES to compare against.
            Defaults to the module-level ``testing_smiles`` as it was when
            this function was defined.

    Returns:
        Number of distinct valid SMILES not in ``reference_smiles`` divided by
        the number of valid entries, or NaN when there is no valid entry
        (instead of raising ``ZeroDivisionError``).
    """
    # A type check is used rather than membership in
    # {None, "", pd.NA, np.nan}: NaN never compares equal to itself, so set
    # membership only catches the very same ``np.nan`` object.
    valid_smiles = [s for s in smiles if isinstance(s, str) and s]
    if not valid_smiles:
        return float("nan")
    # ``difference`` accepts any iterable, so the reference need not be a set.
    unique_novel = set(valid_smiles).difference(reference_smiles)
    return len(unique_novel) / len(valid_smiles)


In [None]:
# Uniqueness per method: how much repetition is there among the generated molecules?
smiles_by_method = df.groupby("method")["smiles_pred"]
smiles_by_method.agg(compute_uniqueness)

In [None]:
# Novelty per method: share of valid generated molecules that are NOT in the test set.
smiles_by_method = df.groupby("method")["smiles_pred"]
smiles_by_method.agg(compute_novelty)

In [None]:
# Complement of novelty per method: share of valid generated molecules that
# ARE in the test set.
novelty = df.groupby("method")["smiles_pred"].agg(compute_novelty)
1 - novelty

In [None]:
# Unique novelty per method: how many valid, distinct and new molecules were generated?
smiles_by_method = df.groupby("method")["smiles_pred"]
smiles_by_method.agg(compute_unique_novelty)