# **4. Recovery from the generative model predictions**

In this notebook, we evaluate the generative model predictions by checking whether the original molecules
can be recovered from the generated SMILES strings within the top-K predictions. Top-K is evaluated for K = 1, 10, and 100. Each test dataset (MetaNetX and eMolecules) contains 10k molecules.

Parameters to be set before running the notebook:
- `MAX_ROWS`: maximum number of rows to be considered for the evaluation (e.g. 10 000)
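- `DB`: dataset to be used for the evaluation (e.g. 'metanetx'); the value is used verbatim as a directory name under `data/` and under the output directory, so it must match the directory spelling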
- `TOP_K`: top-K predictions to be considered (e.g. 1, 10, 100)

In [21]:
# Evaluation parameters — adjust before running the notebook.
MAX_ROWS = 10_000  # maximum number of test rows read and sent to prediction
DB = "metanetx"    # dataset name; also used as a directory name in input/output paths
TOP_K = 100        # number of predictions kept per query (used as the beam size)

___

## 5.0 — Init & Settings

In [22]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

import numpy as np
import pandas as pd
from rdkit import RDLogger  # for disabling RDKit warnings

import handy
from paper.dataset.utils import (
    mol_from_smiles,
    mol_to_smiles,
    mol_to_ecfp,
    ecfp_to_string,
    tanimoto,
)
from handy import mol_to_ecfp_molforge, ecfp_to_string_molforge
from paper.learning import predict
from paper.learning.configure import Config

# Logging --------------------------------------------------------------------

# Turn off RDKit's error and warning log channels.
for _rdkit_channel in ("rdApp.error", "rdApp.warning"):
    RDLogger.DisableLog(_rdkit_channel)


# Utils ------------------------------------------------------------------------

def mol_to_ecfp_string(mol):
    """Return the string form of the ECFP computed for ``mol``.

    Applies ``mol_to_ecfp`` to ``mol`` and passes the result to
    ``ecfp_to_string``.
    """
    fingerprint = mol_to_ecfp(mol)
    return ecfp_to_string(fingerprint)


def mol_from_smiles_with_exception(mol):
    """Best-effort wrapper around ``mol_from_smiles``.

    Args:
        mol: Value forwarded unchanged to ``mol_from_smiles`` (the notebook
            passes entries of SMILES columns).

    Returns:
        The result of ``mol_from_smiles(mol)``, or ``None`` if that call
        raises an ``Exception``.
    """
    try:
        return mol_from_smiles(mol)
    except Exception:
        # Only ordinary errors are converted to None. KeyboardInterrupt and
        # SystemExit derive from BaseException and still propagate, so a long
        # run can be interrupted instead of silently producing None values.
        return None


# Settings --------------------------------------------------------------------

# Parent of the current working directory; all other paths hang off it.
# NOTE(review): presumably the notebook is launched from the `notebooks/`
# folder so that this resolves to the repository root — confirm.
BASE_DIR = Path().resolve().parent
TEST_FILE = BASE_DIR.joinpath('data', DB, 'splitting', 'test.tsv')
WORK_DIR = BASE_DIR.joinpath('notebooks', 'tables-1-2')


# Prediction settings: beam-search mode with a beam of TOP_K candidates,
# processed one query per batch and limited to MAX_ROWS queries.
CONFIG = Config(
    model_path=BASE_DIR.joinpath("data", "models", "finetuned.ckpt"),
    model_source_tokenizer=BASE_DIR.joinpath("data", "tokens", "ECFP.model"),
    model_target_tokenizer=BASE_DIR.joinpath("data", "tokens", "SMILES.model"),
    pred_mode="beam",
    pred_batch_size=1,
    pred_beam_size=TOP_K,
    pred_max_rows=MAX_ROWS,
)
# Set after construction, as in the original setup.
CONFIG.device = "cpu"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## 5.1 — Data Preparation

In [3]:
# Data preparation -------------------------------------------------------------

# Make sure the output directories exist (work dir and its DB/TOP_K sub-folder)
WORK_DIR.mkdir(parents=True, exist_ok=True)
(WORK_DIR / DB / str(TOP_K)).mkdir(parents=True, exist_ok=True)

# Read at most MAX_ROWS rows of the test split, discard the columns that are
# not used downstream, and prefix the remaining ones with "Query" for clarity
data = (
    pd.read_csv(TEST_FILE, sep='\t', nrows=MAX_ROWS)
    .drop(columns=['SMILES_0', 'SIGNATURE', 'SIGNATURE_MORGANS'])
    .rename(columns={'ID': 'Query ID', 'SMILES': 'Query SMILES', 'ECFP': 'Query ECFP'})
)

# Save the processed data so the generation step can reload it from disk
data.to_csv(WORK_DIR / DB / 'data.tsv', sep='\t', index=False)

## 5.1. Generation

In [4]:
data_filename = 'data.tsv'
out_filename = 'results.raw.tsv'


# Prediction -------------------------------------------------------------------

# Reload the prepared queries and keep at most CONFIG.pred_max_rows of them
data = pd.read_csv(WORK_DIR / DB / data_filename, sep="\t")
data = data.iloc[:CONFIG.pred_max_rows]

# Run the model on the (ID, ECFP) pairs and time the whole call
query_data = data[["Query ID", "Query ECFP"]]
with handy.Timer() as timer:
    results = predict.run(CONFIG, query_data=query_data)

print(f"Prediction completed in {timer.elapsed:.2f} seconds.")

# Post-processing --------------------------------------------------------------

# Attach the predictions to the query rows via "Query ID". Both frames carry a
# "Query ECFP" column, so the merge yields "_x" (queries) and "_y" (predictions)
# variants; they must agree before the duplicate is discarded.
results = data.merge(results, on="Query ID", how="left")
assert results["Query ECFP_x"].equals(results["Query ECFP_y"])
results = (
    results
    .drop(columns=["Query ECFP_y"])
    .rename(columns={"Query ECFP_x": "Query ECFP"})
)

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/tduigou/miniforge3/envs/signature-paper/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


Predicting DataLoader 0: 100%|██████████| 10000/10000 [21:50:36<00:00,  0.13it/s]  
Prediction completed in 78725.38 seconds.


In [5]:
# Save results -----------------------------------------------------------------

# Raw merged predictions go to <WORK_DIR>/<DB>/<TOP_K>/<out_filename>
results.to_csv(WORK_DIR.joinpath(DB, str(TOP_K), out_filename), sep='\t', index=False)

___

## 5.2. Refine results

In [18]:
results_file = WORK_DIR.joinpath(DB, str(TOP_K), "results.raw.tsv")
out_file = WORK_DIR.joinpath(DB, str(TOP_K), "results.refined.tsv")


def _add_fingerprint_columns(frame, side):
    """Add counted and binary ECFP columns for one side of the comparison.

    ``side`` is the column prefix ("Query" or "Predicted"); the column
    "<side> Mol" must already exist. ``frame`` is modified in place and gains,
    in this order, the counted ECFP object and string followed by the binary
    ECFP object and string.
    """
    mols = frame[f"{side} Mol"]
    frame[f"{side} Counted ECFP Object"] = mols.apply(mol_to_ecfp)
    frame[f"{side} Counted ECFP"] = frame[f"{side} Counted ECFP Object"].apply(ecfp_to_string)
    frame[f"{side} Binary ECFP Object"] = mols.apply(mol_to_ecfp_molforge)
    frame[f"{side} Binary ECFP"] = frame[f"{side} Binary ECFP Object"].apply(ecfp_to_string_molforge)


def _row_tanimoto(frame, kind):
    """Row-wise ``tanimoto`` between the query and predicted ``kind`` ECFP objects."""
    query_col = f"Query {kind} ECFP Object"
    predicted_col = f"Predicted {kind} ECFP Object"
    return frame.apply(lambda row: tanimoto(row[query_col], row[predicted_col]), axis=1)


# Load data --------------------------------------------------------------

data = pd.read_csv(results_file, sep="\t")


# Refine data ------------------------------------------------------------

# Query side: rebuild the molecules and their fingerprints from the SMILES
data["Query Mol"] = data["Query SMILES"].apply(mol_from_smiles_with_exception)
_add_fingerprint_columns(data, "Query")

# The recomputed counted ECFP must match the one stored in the test file;
# once checked, the stored copy is redundant
assert data["Query Counted ECFP"].equals(data["Query ECFP"])
data = data.drop(columns=["Query ECFP"])

# Prediction side: probability, molecule, canonical SMILES and fingerprints
data["Predicted Prob"] = data["Predicted Log Prob"].apply(np.exp)
data["Predicted Mol"] = data["Predicted SMILES"].apply(mol_from_smiles_with_exception)
data["Predicted Canonic SMILES"] = data["Predicted Mol"].apply(mol_to_smiles)
_add_fingerprint_columns(data, "Predicted")

# A prediction counts as valid when a (non-null) molecule was obtained from it
data["SMILES Syntaxically Valid"] = data["Predicted Mol"].notnull()

# String equality between the query SMILES and the canonical predicted SMILES
data["SMILES Exact Match"] = data["Query SMILES"] == data["Predicted Canonic SMILES"]

# Tanimoto similarity on counted and binary fingerprints; a value of exactly
# 1.0 is flagged as an exact match
for _kind in ("Counted", "Binary"):
    data[f"Tanimoto {_kind} ECFP"] = _row_tanimoto(data, _kind)
    data[f"Tanimoto {_kind} ECFP Exact Match"] = data[f"Tanimoto {_kind} ECFP"] == 1.0

# Export only the serialisable columns (Mol / ECFP objects are left out)
cols = [
    "Query ID",
    "Query SMILES",
    "Query Counted ECFP",
    "Predicted Tokens",
    "Predicted Log Prob",
    "Predicted Prob",
    "Predicted SMILES",
    "Predicted Counted ECFP",
    "Predicted Binary ECFP",
    "Predicted Canonic SMILES",
    "Tanimoto Counted ECFP",
    "Tanimoto Binary ECFP",
    "SMILES Exact Match",
    "Tanimoto Counted ECFP Exact Match",
    "Tanimoto Binary ECFP Exact Match",
    "SMILES Syntaxically Valid",
    "Time Elapsed",
]
data.to_csv(out_file, sep="\t", index=False, columns=cols)


___

## 5.3. Summary statistics

In [19]:
results_file = WORK_DIR.joinpath(DB, str(TOP_K), "results.refined.tsv")

# Load refined data -------------------------------------------------------------

data = pd.read_csv(results_file, sep="\t")

# Each table below is written next to the refined results file.

# Summary -----------------------------------------------------------------------
summary = handy.get_summary(data, topk=TOP_K)
out_file = results_file.with_name("summary.tsv")
summary.to_csv(out_file, sep="\t", index=False)

# Statistics  --------------------------------------------------------------------
stats = handy.get_statistics(data, topk=TOP_K)
out_file = results_file.with_name("statistics.tsv")
stats.to_csv(out_file, sep="\t", index=False)

# Uniqueness ----------------------------------------------------------------------
uniqueness = handy.get_uniqueness(data, topk=TOP_K)
out_file = results_file.with_name("uniqueness.tsv")
uniqueness.to_csv(out_file, sep="\t", index=False)

# Show both tables, separated by a blank line
print("stats: ", stats, "", "uniqueness: ", uniqueness, sep="\n")

stats: 
                             Stat     Value
0                 SMILES Accuracy  0.943100
1  Tanimoto Counted ECFP Accuracy  0.948800
2          SMILES Syntax Validity  0.998000
3   Tanimoto Binary ECFP Accuracy  0.960400
4            Average Time Elapsed  6.001689

uniqueness: 
   Distinct Molecules per Query  Count
0                             0    512
1                             1   8174
2                             2    657
3                             3    352
4                             4    115
5                             5     39
6                             6     27
7                             7     23
8                             8     16
9                             9     11
10                           10     10
11                           11      5
12                           12      9
13                           13      2
14                           14      7
15                           15      3
16                           17      4
17          