# **5. Recovery from the generative model predictions**

This notebook was used to populate Table 1 (top-100 results) and Table 2 of the manuscript. It walks through the
evaluation of the generative model on the eMolecules dataset; the same steps can be applied to the MetaNetX
dataset. Paths are resolved relative to the current working directory, which is expected to be the `notebooks/`
folder.

___

## 5.1. Generation

In [1]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
import warnings

import pandas as pd
from rdkit import RDLogger  # for disabling RDKit warnings

from paper.learning import predict
from paper.learning.configure import Config
from paper.dataset.utils import mol_from_smiles, mol_to_smiles, mol_to_ecfp, ecfp_to_string

# Silence RDKit's error and warning log channels
for rdkit_channel in ("rdApp.error", "rdApp.warning"):
    RDLogger.DisableLog(rdkit_channel)

# Hide the warning about the number of workers of the 'predict_dataloader'
warnings.filterwarnings(
    action="ignore",
    message=".*The 'predict_dataloader' does not have many workers which may be a bottleneck.*",
)


# Utils ------------------------------------------------------------------------

def mol_to_ecfp_string(mol):
    """Return the string form of the ECFP computed for ``mol``.

    Chains the imported helpers ``mol_to_ecfp`` and ``ecfp_to_string``.
    """
    fingerprint = mol_to_ecfp(mol)
    return ecfp_to_string(fingerprint)


def mol_from_smiles_with_exception(mol):
    """Call ``mol_from_smiles`` and return None instead of raising.

    Parameters
    ----------
    mol : object
        The value forwarded to ``mol_from_smiles`` (presumably a SMILES
        string; the parameter name is kept for backward compatibility).

    Returns
    -------
    object or None
        Whatever ``mol_from_smiles`` returns, or None if it raised.
    """
    try:
        return mol_from_smiles(mol)
    except Exception:
        # Only ordinary errors are swallowed: a bare ``except`` would also trap
        # KeyboardInterrupt / SystemExit and make a long run hard to interrupt.
        return None


# Settings --------------------------------------------------------------------

# Everything is located relative to the parent of the current working directory
BASE_DIR = Path().resolve().parent
TEST_FILE = BASE_DIR.joinpath('data', 'emolecules', 'splitting', 'test.tsv')
OUT_FILE = BASE_DIR.joinpath('notebooks', 'table-2-recovery', 'emolecules-top100-raw.tsv')

# Prediction settings: beam search, beam size 100, one query per batch,
# limited to the first 1000 rows of the test file
CONFIG = Config(
    model_path=BASE_DIR.joinpath("data", "models", "finetuned.ckpt"),
    model_source_tokenizer=BASE_DIR.joinpath("data", "tokens", "ECFP.model"),
    model_target_tokenizer=BASE_DIR.joinpath("data", "tokens", "SMILES.model"),
    pred_mode="beam",
    pred_batch_size=1,
    pred_beam_size=100,
    pred_max_rows=1000,
)
CONFIG.device = "cpu"


# Data preparation -------------------------------------------------------------

# Load the test set (a non-positive pred_max_rows means "read every row"),
# discard the columns that are not needed and give the others explicit names
df = (
    pd.read_csv(
        TEST_FILE,
        sep='\t',
        nrows=None if CONFIG.pred_max_rows <= 0 else CONFIG.pred_max_rows,
    )
    .drop(columns=['SMILES_0', 'SIGNATURE', 'SIGNATURE_MORGANS'])
    .rename(columns={'SMILES': 'Query SMILES', 'ECFP': 'Query ECFP', 'ID': 'DB ID'})
)

# Number the queries from 1 so that each row can be referenced easily
df['Query ID'] = range(1, len(df) + 1)

# Rotate the column order so that the last column comes first
cols = [df.columns[-1], *df.columns[:-1]]
df = df[cols]

In [2]:
# Prediction -------------------------------------------------------------------
# Reference timings for 1k molecules (with CPU computing device):
# - top   1 using   beam search (batch= 10): 0 days 00:02:14
# - top   1 using   beam search (batch=100): 0 days 00:01:43
# - top  10 using   beam search (batch=  1): 0 days 00:15:55
# - top  10 using   beam search (batch= 10): 0 days 00:17:34
# - top 100 using   beam search (batch=  1): 0 days 02:26:01
# - top 100 using   beam search (batch= 10): 0 days 04:10:33

# ECFP queries handed to the model
query_ecfps = df["Query ECFP"].values

# Run the prediction, recording the wall-clock time just before and just after
time_before = pd.Timestamp.now()
results = predict.run(CONFIG, query_data=query_ecfps)
time_after = pd.Timestamp.now()

# Report how long the whole dataset took
time_diff = time_after - time_before
print(f"Time taken for prediction: {time_diff}")

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/tduigou/miniforge3/envs/retrosig/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


Predicting DataLoader 0: 100%|██████████| 1000/1000 [2:25:52<00:00,  0.11it/s] 
Time taken for prediction: 0 days 02:26:01.845196


In [3]:
# Refine results ---------------------------------------------------------------

# Attach the query information to the predictions through the "Query ID" column.
# "Query ECFP" exists on both sides, so pandas suffixes it with _x (from df) and
# _y (from the predictions); both copies must agree before one is discarded.
# NOTE: this cell rebinds `results`; re-running it without first re-running the
# prediction cell would merge the already-merged frame a second time.
results = pd.merge(df, results, on="Query ID", how="left")
assert results["Query ECFP_x"].equals(results["Query ECFP_y"])
results = (
    results
    .drop(columns=["Query ECFP_y"])
    .rename(columns={"Query ECFP_x": "Query ECFP"})
)


# Save results -----------------------------------------------------------------

# Create the output folder if needed so that a missing directory cannot make
# the export fail after the (hours-long) prediction has completed.
OUT_FILE.parent.mkdir(parents=True, exist_ok=True)
results.to_csv(OUT_FILE, sep='\t', index=False)

___

## 5.2. Refine results

In [15]:
from pathlib import Path

import numpy as np
import pandas as pd
from rdkit import RDLogger

from paper.learning.utils import (
    mol_from_smiles,
    mol_to_ecfp,
    ecfp_to_string,
    mol_to_smiles,
    tanimoto,
)


# Logging ---------------------------------------------------------------------

# Turn off RDKit's error and warning log channels
for rdkit_channel in ("rdApp.error", "rdApp.warning"):
    RDLogger.DisableLog(rdkit_channel)


# Utils ------------------------------------------------------------------------

def mol_from_smiles_with_exception(mol):
    """Parse a SMILES string with ``mol_from_smiles``, returning None on failure.

    Parameters
    ----------
    mol : str
        The SMILES string to parse (the parameter name is kept for backward
        compatibility; the function is applied to the SMILES columns).

    Returns
    -------
    object or None
        Whatever ``mol_from_smiles`` returns, or None if it raised.
    """
    try:
        return mol_from_smiles(mol)
    except Exception:
        # Only ordinary errors are swallowed: a bare ``except`` would also trap
        # KeyboardInterrupt / SystemExit and make the cell hard to interrupt.
        return None


# Load data -------------------------------------------------------------------

BASE_DIR = Path.cwd()
DATA_DIR = BASE_DIR.joinpath("table-2-recovery")
FILENAME = "emolecules-top100-raw.tsv"
OUTNAME = FILENAME.replace("-raw.tsv", "-refined.tsv")

# Read the raw predictions and give the database-specific ID column a generic name
df = pd.read_csv(DATA_DIR.joinpath(FILENAME), sep="\t").rename(
    columns={"EMOLECULES ID": "DB ID"}
)


# Refine data -----------------------------------------------------------------

# Query side: rebuild the molecule and fingerprint objects from the SMILES
df["Query Mol"] = df["Query SMILES"].apply(mol_from_smiles_with_exception)
df["Query ECFP Object"] = df["Query Mol"].apply(mol_to_ecfp)

# Sanity check: the recomputed fingerprints must match the stored ones
recomputed_query_ecfp = df["Query ECFP Object"].apply(ecfp_to_string)
assert df["Query ECFP"].equals(recomputed_query_ecfp)

# Prediction side: probability, molecule and fingerprint of every prediction
df["Prediction Prob"] = df["Prediction Log Prob"].apply(np.exp)
df["Prediction Mol"] = df["Prediction SMILES"].apply(mol_from_smiles_with_exception)
df["Prediction ECFP Object"] = df["Prediction Mol"].apply(mol_to_ecfp)
df["Prediction ECFP"] = df["Prediction ECFP Object"].apply(ecfp_to_string)

# A prediction is flagged as valid when a molecule object was obtained for it
df["SMILES Syntaxically Valid"] = df["Prediction Mol"].notna()

# SMILES equality: query SMILES versus the prediction converted back to SMILES
df["Prediction Canonic SMILES"] = df["Prediction Mol"].apply(mol_to_smiles)
df["SMILES Exact Match"] = df["Query SMILES"].eq(df["Prediction Canonic SMILES"])

# Fingerprint similarity between query and prediction; 1.0 counts as an exact match
df["Tanimoto"] = df.apply(
    lambda row: tanimoto(row["Query ECFP Object"], row["Prediction ECFP Object"]),
    axis=1,
)
df["Tanimoto Exact Match"] = df["Tanimoto"].eq(1.0)

# Export only the serialisable columns (the Mol / ECFP object columns are left out)
cols = [
    "DB ID",
    "Query ID",
    "Query SMILES",
    "Query ECFP",
    "Prediction Tokens",
    "Prediction Log Prob",
    "Prediction Prob",
    "Prediction SMILES",
    "Prediction ECFP",
    "Prediction Canonic SMILES",
    "Tanimoto",
    "SMILES Exact Match",
    "Tanimoto Exact Match",
    "SMILES Syntaxically Valid",
]
df.to_csv(DATA_DIR.joinpath(OUTNAME), sep="\t", index=False, columns=cols)

___

## 5.3. Summary statistics

In [64]:
from pathlib import Path

import pandas as pd


# Utils -----------------------------------------------------------------------

def get_summary(results: pd.DataFrame, topk=1) -> pd.DataFrame:
    """Summarise the top-k predictions of every query.

    The predictions are grouped by "Query ID" in a single pass (instead of
    re-scanning the whole frame once per query), and for each query only the
    ``topk`` rows with the largest "Prediction Log Prob" are considered.

    Parameters
    ----------
    results : pandas.DataFrame
        The refined results, one row per (query, prediction) pair. It must
        provide the columns "Query ID", "Query SMILES", "Query ECFP",
        "Prediction Log Prob", "Prediction Canonic SMILES",
        "SMILES Exact Match", "Tanimoto Exact Match" and
        "SMILES Syntaxically Valid".
    topk : int, optional
        The number of best-scoring predictions to consider per query,
        by default 1.

    Returns
    -------
    pandas.DataFrame
        One row per query, indexed by query ID in order of first appearance,
        with object-dtype columns:

        - "Query ID", "Query SMILES", "Query ECFP": copied from the query.
        - "SMILES Exact Match", "Tanimoto Exact Match",
          "SMILES Syntaxically Valid": True if any of the top-k rows has
          the flag set.
        - "Tanimoto Exact Match Unique Count": number of distinct
          "Prediction Canonic SMILES" among the top-k rows flagged as
          Tanimoto exact matches.
        - "Tanimoto Exact Match Unique List": those distinct SMILES, as the
          string representation of a list.

    Notes
    -----
    Rows whose "Query ID" is missing are ignored (pandas ``groupby`` default).
    """
    columns = [
        "Query ID",
        "Query SMILES",
        "Query ECFP",
        "SMILES Exact Match",
        "Tanimoto Exact Match",
        "SMILES Syntaxically Valid",
        "Tanimoto Exact Match Unique Count",
        "Tanimoto Exact Match Unique List",
    ]

    query_ids = []
    records = []

    # sort=False keeps the queries in their order of first appearance
    for query_id, query_rows in results.groupby("Query ID", sort=False):

        # Best-scoring predictions of this query
        top_rows = query_rows.nlargest(topk, "Prediction Log Prob")

        # Canonical SMILES of the top-k predictions that are Tanimoto exact matches
        exact_match_smiles = top_rows.loc[
            top_rows["Tanimoto Exact Match"], "Prediction Canonic SMILES"
        ]

        query_ids.append(query_id)
        records.append(
            {
                "Query ID": query_id,
                "Query SMILES": top_rows["Query SMILES"].iloc[0],
                "Query ECFP": top_rows["Query ECFP"].iloc[0],
                "SMILES Exact Match": any(top_rows["SMILES Exact Match"]),
                "Tanimoto Exact Match": any(top_rows["Tanimoto Exact Match"]),
                "SMILES Syntaxically Valid": any(top_rows["SMILES Syntaxically Valid"]),
                "Tanimoto Exact Match Unique Count": exact_match_smiles.nunique(),
                "Tanimoto Exact Match Unique List": str(list(exact_match_smiles.unique())),
            }
        )

    # Build the frame once; object dtype matches what a cell-by-cell fill of an
    # initially empty frame would produce.
    return pd.DataFrame(records, columns=columns, index=query_ids, dtype=object)


def get_statistics(df: pd.DataFrame, topk=1) -> pd.DataFrame:
    """Get statistics from the results DataFrame.

    Every statistic is the mean, over all queries, of one column of the
    per-query summary returned by ``get_summary``.

    Parameters
    ----------
    df : pandas.DataFrame
        The results DataFrame.
    topk : int, optional
        The top-k to consider, by default 1.

    Returns
    -------
    pandas.DataFrame
        A two-column frame ("Stat", "Value") with one row per statistic:

        - "SMILES Accuracy": mean of "SMILES Exact Match".
        - "Tanimoto Accuracy": mean of "Tanimoto Exact Match".
        - "SMILES Syntax Validity": mean of "SMILES Syntaxically Valid".
        - "Molecule Uniqueness": mean of "Tanimoto Exact Match Unique Count",
          i.e. the average number of distinct exact-match molecules per query.
    """

    # First we get summary information
    summary = get_summary(df, topk=topk)

    # Statistic label -> summary column to average. An explicit mapping avoids
    # renaming the aggregated columns by position.
    metrics = {
        "SMILES Accuracy": "SMILES Exact Match",
        "Tanimoto Accuracy": "Tanimoto Exact Match",
        "SMILES Syntax Validity": "SMILES Syntaxically Valid",
        "Molecule Uniqueness": "Tanimoto Exact Match Unique Count",
    }

    # The summary columns have object dtype, so cast to float before averaging
    rows = [
        {"Stat": label, "Value": summary[column].astype(float).mean()}
        for label, column in metrics.items()
    ]

    return pd.DataFrame(rows, columns=["Stat", "Value"])


def get_uniqueness(df: pd.DataFrame, topk=1) -> pd.DataFrame:
    """Get the number of unique molecules per query.

    Parameters
    ----------
    df : pandas.DataFrame
        The results DataFrame.
    topk : int, optional
        The top-k to consider, by default 1.

    Returns
    -------
    pandas.DataFrame
        A two-column frame sorted by "Distinct Molecules per Query": each row
        gives a value of "Tanimoto Exact Match Unique Count" and, in "Count",
        the number of queries having that value.
    """
    # First we get summary information
    summary = get_summary(df, topk=topk)

    # How many queries have 0, 1, 2, ... distinct exact-match molecules
    counts = summary["Tanimoto Exact Match Unique Count"].value_counts().sort_index()

    # Build the frame explicitly rather than renaming the value_counts() result:
    # its name is "count" only from pandas 2.0 onwards, so a rename-based
    # approach yields a differently named column on older versions.
    return pd.DataFrame(
        {
            "Distinct Molecules per Query": counts.index.to_numpy(),
            "Count": counts.to_numpy(),
        }
    )


# Load data -------------------------------------------------------------------
BASE_DIR = Path.cwd()
DATA_DIR = BASE_DIR.joinpath("table-2-recovery")
FILENAME = "emolecules-top100-refined.tsv"
TOPK = 100

# Options shared by every TSV export below
tsv_options = {"sep": "\t", "index": False}

df = pd.read_csv(DATA_DIR.joinpath(FILENAME), sep="\t")

# Summary ---------------------------------------------------------------------
# Per-query summary of the top-k predictions
summary = get_summary(df, topk=TOPK)
OUTFILE = FILENAME.replace("-refined.tsv", "-summary.tsv")
summary.to_csv(DATA_DIR.joinpath(OUTFILE), **tsv_options)

# Statistics -------------------------------------------------------------------
# Statistics averaged over all queries
stats = get_statistics(df, topk=TOPK)
OUTFILE = FILENAME.replace("-refined.tsv", "-statistics.tsv")
stats.to_csv(DATA_DIR.joinpath(OUTFILE), **tsv_options)

# Uniqueness -------------------------------------------------------------------
# Distribution of the number of distinct exact-match molecules per query
uniqueness = get_uniqueness(df, topk=TOPK)
OUTFILE = FILENAME.replace("-refined.tsv", "-uniqueness.tsv")
uniqueness.to_csv(DATA_DIR.joinpath(OUTFILE), **tsv_options)

# Show both tables, separated by a blank line
print(f"{FILENAME} stats:", stats, "", sep="\n")
print(f"{FILENAME} uniqueness:", uniqueness, sep="\n")

emolecules-top100-refined.tsv stats:
                     Stat  Value
0         SMILES Accuracy  0.998
1       Tanimoto Accuracy  0.998
2  SMILES Syntax Validity  1.000
3     Molecule Uniqueness  1.063

emolecules-top100-refined.tsv uniqueness:
  Distinct Molecules per Query  Count
0                            0      2
1                            1    947
2                            2     40
3                            3      9
4                            4      1
5                            5      1
