# **5. Cross-comparison from our and MolForge's generative models**

In this notebook, we evaluate the performance of our model on the MolForge dataset (10k molecules), and the performance of the MolForge model on the MetaNetX and eMolecules datasets (2 x 10k molecules).

Results are used to populate table 3 in the main text.

## Case 1 — Pass MolForge's test set through generative model

### 1.0 — Init & Settings

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

import numpy as np
import pandas as pd
from rdkit import RDLogger  # for disabling RDKit warnings

from paper.dataset.utils import (
    assign_stereo,
    mol_from_smiles,
    mol_to_smiles,
    mol_to_ecfp,
    ecfp_to_string,
    tanimoto,
)
import handy
from handy import mol_to_ecfp_molforge, ecfp_to_string_molforge
from paper.learning import predict
from paper.learning.configure import Config

# Logging --------------------------------------------------------------------

RDLogger.DisableLog("rdApp.error")
RDLogger.DisableLog('rdApp.warning')


# Utils ------------------------------------------------------------------------

def remove_spaces(s):
    return "".join(s.split()).strip()


def mol_to_ecfp_string(mol):
    return ecfp_to_string(mol_to_ecfp(mol))


def mol_from_smiles_full_stereo(smiles):
    """Convert SMILES to RDKit Mol object with full stereo information."""
    try:
        mol = mol_from_smiles(smiles, clear_stereo=False)
        return assign_stereo(mol)
    except:
        return None


def mol_from_smiles_no_stereo(smiles):
    """Convert SMILES to RDKit Mol object without stereo information."""
    try:
        return mol_from_smiles(smiles, clear_stereo=True)
    except:
        return None

def mol_from_smiles_with_exception(smiles):
    try:
        return mol_from_smiles(smiles)
    except:
        return None


# Settings --------------------------------------------------------------------

BASE_DIR = Path().resolve().parent
DATA_DIR = BASE_DIR / "data" / "molforge"
OUT_DIR = BASE_DIR / "notebooks" / "molforge" / "case-1"
FILENAME = "ECFP4.smiles.test"

# Prediction settings
CONFIG = Config(
    model_path= BASE_DIR / "data" / "models" / "finetuned.ckpt",
    model_source_tokenizer= BASE_DIR / "data" / "tokens" / "ECFP.model",
    model_target_tokenizer= BASE_DIR / "data" / "tokens" / "SMILES.model",
    pred_mode="beam",  # "greedy" or "beam"
    pred_batch_size=10,
    pred_beam_size=1,  # >1 for beam search, not used for greedy
    pred_max_rows=-1,  # -1 means no limit
)
CONFIG.device = "cpu"

  Referenced from: <EB3FF92A-5EB1-3EE8-AF8B-5923C1265422> /Users/tduigou/miniforge3/envs/signature-paper/lib/python3.11/site-packages/torchvision/image.so
  warn(


### 1.0 — Data preparation

In [2]:
# Data preparation -------------------------------------------------------------

for stereo_case in ["raw", "full_stereo", "no_stereo"]:

    # Create output directory if it doesn't exist
    out_case_dir = OUT_DIR / stereo_case
    out_case_dir.mkdir(parents=True, exist_ok=True)

    # Output filename
    out_case_filename = f"data.tsv"

    # Load data
    data = pd.read_csv(DATA_DIR / FILENAME, sep="\t", header=None)
    data.rename(columns={0: "Query SMILES", 1: "Query ECFP"}, inplace=True)

    # Remove spaces in SMILES
    data["Query SMILES"] = data["Query SMILES"].apply(remove_spaces)

    if stereo_case == "raw":
        # Use stereo state as provided in the dataset
        data["Query Mol"] = data["Query SMILES"].apply(mol_from_smiles_with_exception)
        data["Query SMILES"] = data["Query Mol"].apply(mol_to_smiles)
        data["Query ECFP"] = data["Query Mol"].apply(mol_to_ecfp_string)
    
    elif stereo_case == "full_stereo":
        # Use a fully specified stereo state
        data["Query Mol"] = data["Query SMILES"].apply(mol_from_smiles_full_stereo)
        data["Query SMILES"] = data["Query Mol"].apply(mol_to_smiles)
        data["Query ECFP"] = data["Query Mol"].apply(mol_to_ecfp_string)

    elif stereo_case == "no_stereo":
        # Use a flat stereo state
        data["Query Mol"] = data["Query SMILES"].apply(mol_from_smiles_no_stereo)
        data["Query SMILES"] = data["Query Mol"].apply(mol_to_smiles)
        data["Query ECFP"] = data["Query Mol"].apply(mol_to_ecfp_string)

    # Clean
    data.drop(columns=["Query Mol"], inplace=True)

    # Append a "Query ID" column containing the row number for easier reference
    data['Query ID'] = range(1, len(data) + 1)

    # Reorder columns to have "Query ID" first
    cols = data.columns.tolist()
    cols = cols[-1:] + cols[:-1]
    data = data[cols]

    # Save the processed data
    data.to_csv(out_case_dir / out_case_filename, sep="\t", index=False)

### 1.1 — Predict

In [2]:
stereo_case = "no_stereo"  # Change this to "raw" or "full_stereo" or "no_stereo" as needed
work_dir = OUT_DIR / stereo_case
data_filename = f"data.tsv"
out_filename = f"results.raw.tsv"


# Prediction -------------------------------------------------------------------

# Load the data for prediction
data = pd.read_csv(work_dir / data_filename, sep="\t")

# Truncate data according to the prediction limit
data = data.iloc[:CONFIG.pred_max_rows]

# Predict
with handy.Timer() as timer:
    results = predict.run(CONFIG, query_data=data["Query ECFP"].values)

# Append the average time per query to each row on the results
# This is a trick to get per query time without modifying the predict.run function
avg_time_per_query = timer.elapsed / len(results)
results["Time Elapsed"] = avg_time_per_query

# Post-processing --------------------------------------------------------------

# Merge results with the original data using the "Query ID" column
results = pd.merge(data, results, on="Query ID", how="left")
assert results["Query ECFP_x"].equals(results["Query ECFP_y"])
results.drop(columns=["Query ECFP_y"], inplace=True)
results = results.rename(columns={"Query ECFP_x": "Query ECFP"})

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/tduigou/miniforge3/envs/signature-paper/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


Predicting DataLoader 0: 100%|██████████| 1000/1000 [40:37<00:00,  0.41it/s]


In [3]:
# Save results -----------------------------------------------------------------
results.to_csv(work_dir / out_filename, sep='\t', index=False)

### 1.2 — Refine results

In [4]:
stereo_case = "no_stereo"  # Change this to "raw" or "full_stereo" or "no_stereo" as needed
work_dir = OUT_DIR / stereo_case
results_filename = "results.raw.tsv"
out_filename = "results.refined.tsv"


# Load data -------------------------------------------------------------------

data = pd.read_csv(work_dir / results_filename, sep="\t")


# Refine data -----------------------------------------------------------------

# Let's recompute required information on the Query side
data["Query Mol"] = data["Query SMILES"].apply(mol_from_smiles_with_exception)
data["Query Counted ECFP Object"] = data["Query Mol"].apply(mol_to_ecfp)
data["Query Counted ECFP"] = data["Query Counted ECFP Object"].apply(ecfp_to_string)
data["Query Binary ECFP Object"] = data["Query Mol"].apply(mol_to_ecfp_molforge)
data["Query Binary ECFP"] = data["Query Binary ECFP Object"].apply(ecfp_to_string_molforge)

# Now let's populate back the prediction side
data["Predicted Prob"] = data["Predicted Log Prob"].apply(np.exp)
data["Predicted Mol"] = data["Predicted SMILES"].apply(mol_from_smiles_with_exception)
data["Predicted Canonic SMILES"] = data["Predicted Mol"].apply(mol_to_smiles)
data["Predicted Counted ECFP Object"] = data["Predicted Mol"].apply(mol_to_ecfp)
data["Predicted Counted ECFP"] = data["Predicted Counted ECFP Object"].apply(ecfp_to_string)
data["Predicted Binary ECFP Object"] = data["Predicted Mol"].apply(mol_to_ecfp_molforge)
data["Predicted Binary ECFP"] = data["Predicted Binary ECFP Object"].apply(ecfp_to_string_molforge)

# Now let's check for Mol validity
data["SMILES Syntaxically Valid"] = data["Predicted Mol"].notnull()

# Get the No Stereo version
data["Query Mol No Stereo"] = data["Query SMILES"].apply(mol_from_smiles_no_stereo)
data["Query SMILES No Stereo"] = data["Query Mol No Stereo"].apply(mol_to_smiles)
data["Query Counted ECFP Object No Stereo"] = data["Query Mol No Stereo"].apply(mol_to_ecfp)
data["Query Counted ECFP No Stereo"] = data["Query Counted ECFP Object No Stereo"].apply(ecfp_to_string)

data["Predicted Mol No Stereo"] = data["Predicted Canonic SMILES"].apply(mol_from_smiles_no_stereo)
data["Predicted SMILES No Stereo"] = data["Predicted Mol No Stereo"].apply(mol_to_smiles)
data["Predicted Counted ECFP Object No Stereo"] = data["Predicted Mol No Stereo"].apply(mol_to_ecfp)
data["Predicted Counted ECFP No Stereo"] = data["Predicted Counted ECFP Object No Stereo"].apply(ecfp_to_string)

# Now let's check for SMILES equality (with and without stereo)
data["SMILES Exact Match"] = data["Query SMILES"] == data["Predicted Canonic SMILES"]
data["SMILES Exact Match No Stereo"] = data["Query SMILES No Stereo"] == data["Predicted SMILES No Stereo"]

# Now let's check for Tanimoto similarity (with and without stereo)
data["Tanimoto Counted ECFP"] = data.apply(lambda x: tanimoto(x["Query Counted ECFP Object"], x["Predicted Counted ECFP Object"]), axis=1)
data["Tanimoto Counted ECFP No Stereo"] = data.apply(lambda x: tanimoto(x["Query Counted ECFP Object No Stereo"], x["Predicted Counted ECFP Object No Stereo"]), axis=1)
data["Tanimoto Binary ECFP"] = data.apply(lambda x: tanimoto(x["Query Binary ECFP Object"], x["Predicted Binary ECFP Object"]), axis=1)
data["Tanimoto Counted ECFP Exact Match"] = data["Tanimoto Counted ECFP"] == 1.0
data["Tanimoto Counted ECFP Exact Match No Stereo"] = data["Tanimoto Counted ECFP No Stereo"] == 1.0
data["Tanimoto Binary ECFP Exact Match"] = data["Tanimoto Binary ECFP"] == 1.0

# Finally export the refined DataFrame
cols = [
    "Query ID",
    "Query SMILES",
    "Query SMILES No Stereo",
    "Query Counted ECFP",
    "Query Counted ECFP No Stereo",
    "Query Binary ECFP",
    "Predicted Tokens",
    "Predicted Log Prob",
    "Predicted Prob",
    "Predicted SMILES",
    "Predicted SMILES No Stereo",
    "Predicted Counted ECFP",
    "Predicted Counted ECFP No Stereo",
    "Predicted Binary ECFP",
    "Predicted Canonic SMILES",
    "Tanimoto Counted ECFP",
    "Tanimoto Counted ECFP No Stereo",
    "Tanimoto Binary ECFP",
    "SMILES Exact Match",
    "SMILES Exact Match No Stereo",
    "Tanimoto Counted ECFP Exact Match",
    "Tanimoto Counted ECFP Exact Match No Stereo",
    "Tanimoto Binary ECFP Exact Match",
    "SMILES Syntaxically Valid",
    "Time Elapsed",
]
data.to_csv(work_dir / out_filename, sep="\t", index=False, columns=cols)

### 1.3 — Summary statistics

In [5]:
stereo_case = "no_stereo"  # Change this to "raw" or "full_stereo" or "no_stereo" as needed
top_k = 1
work_dir = OUT_DIR / stereo_case
results_filename = f"results.refined.tsv"

# Load data -------------------------------------------------------------------

data = pd.read_csv(work_dir / results_filename, sep="\t")

# Summary ---------------------------------------------------------------------
summary = handy.get_summary(data, topk=top_k)
out_filename = results_filename.replace("refined.tsv", "summary.tsv")
summary.to_csv(work_dir / out_filename, sep="\t", index=False)

# Statistics -------------------------------------------------------------------
stats = handy.get_statistics(data, topk=top_k)
out_filename = results_filename.replace("refined.tsv", "stastitics.tsv")
stats.to_csv(work_dir / out_filename, sep="\t", index=False)

# Uniqueness -------------------------------------------------------------------
uniqueness = handy.get_uniqueness(data, topk=top_k)
out_filename = results_filename.replace("refined.tsv", "uniqueness.tsv")
uniqueness.to_csv(work_dir / out_filename, sep="\t", index=False)

print(f"{stereo_case} stats:")
print(stats)
print()
print(f"{stereo_case} uniqueness:")
print(uniqueness)

no_stereo stats:
                                       Stat     Value
0                           SMILES Accuracy  0.354835
1            Tanimoto Counted ECFP Accuracy  0.371737
2                    SMILES Syntax Validity  0.851885
3                 SMILES No Stereo Accuracy  0.662066
4  Tanimoto Counted ECFP No Stereo Accuracy  0.673267
5             Tanimoto Binary ECFP Accuracy  0.673267
6                      Average Time Elapsed  0.243924

no_stereo uniqueness:
  Distinct Molecules per Query  Count
0                            0   6282
1                            1   3717


___

## 2 — Pass MetaNetX test set through MolForge's model

### 2.0 — Init & Settings

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
from types import SimpleNamespace

import pandas as pd
import torch
import sentencepiece as spm
from rdkit import RDLogger

from paper.dataset.utils import (
    mol_from_smiles,
    mol_to_smiles,
    tanimoto,
    mol_to_ecfp,
    ecfp_to_string,
)
from MolForge.utils import pad_or_truncate, pad_id, build_model
from MolForge.predict import greedy_search, beam_search, setup
from handy import mol_to_ecfp_molforge, ecfp_to_string_molforge

# Logging --------------------------------------------------------------------

RDLogger.DisableLog("rdApp.error")
RDLogger.DisableLog('rdApp.warning')


# Utils ------------------------------------------------------------------------

def remove_spaces(x):
    return x.replace(" ", "")


def inference(model, input_sentence, method, args, src_sp, trg_sp, return_attn=False):

    tokenized = src_sp.EncodeAsIds(input_sentence)
    src_data = torch.LongTensor(pad_or_truncate(tokenized, args.src_seq_len)).unsqueeze(0).to(args.device) # (1, L)
    e_mask = (src_data != pad_id).unsqueeze(1).to(args.device) # (1, 1, L)

    model.eval()
    src_data = model.src_embedding(src_data)
    src_data = model.src_positional_encoder(src_data)
    e_output = model.encoder(src_data, e_mask) # (1, L, d_model)

    if method == 'greedy':
        result, attn = greedy_search(model, e_output, e_mask, trg_sp, args.device, True)

    elif method == 'beam':
        result, attn = beam_search(model, e_output, e_mask, trg_sp, args.device, True)

    if return_attn:
        return result, attn
    
    else:
        return result


def mol_from_smiles_with_exception(mol):
    try:
        return mol_from_smiles(mol)
    except:
        return None


def mol_from_smiles_no_stereo(smiles):
    """Convert SMILES to RDKit Mol object without stereo information."""
    try:
        return mol_from_smiles(smiles, clear_stereo=True)
    except:
        return None


# Settings --------------------------------------------------------------------

BASE_DIR = Path().resolve().parent
TEST_FILE = BASE_DIR / 'data' / 'metanetx' / 'splitting' / 'test.tsv'
WORK_DIR = BASE_DIR / 'notebooks' / 'molforge' / 'case-2'

SOURCE_SP = BASE_DIR / "data" / "molforge" / "ECFP4_vocab_sp.model"
TARGET_SP = BASE_DIR / "data" / "molforge" / "smiles_vocab_sp.model"

MODEL_PATH = BASE_DIR / "data" / "molforge" / "ECFP4_smiles_checkpoint.pth"
METHOD = "greedy"
MAX_ROWS = 10000

ARGS = SimpleNamespace(
    fp = "ECFP4",
    model_type = "smiles",
    checkpoint = str(MODEL_PATH),
    decode = "greedy",
    src_vocab_size = 2052,
    trg_vocab_size = 109,
    src_seq_len = 104,
    trg_seq_len = 130,
    root_dir = BASE_DIR,
    src_sp_prefix = str(SOURCE_SP).replace(".model", ""),
    trg_sp_prefix = str(TARGET_SP).replace(".model", ""),
    rank = "cpu",
    device = "cpu",
)

# Create output directory if it doesn't exist
WORK_DIR.mkdir(parents=True, exist_ok=True)

# Write settings
with open(WORK_DIR / "settings.txt", "w") as f:
    f.write("BASE_DIR: " + str(BASE_DIR) + "\n")
    f.write("TEST_FILE: " + str(TEST_FILE) + "\n")
    f.write("WORK_DIR: " + str(WORK_DIR) + "\n")
    f.write("MAX_ROWS: " + str(MAX_ROWS) + "\n")
    for key, value in ARGS.__dict__.items():
        f.write(f"{key}: {value}\n")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### 2.0 — Data preparation

In [7]:
out_filename = "data.tsv"

# Data preparation -------------------------------------------------------------

# Load test data
data = pd.read_csv(TEST_FILE, sep='\t', nrows=MAX_ROWS)

# Drop unnecessary columns
data.drop(columns=['SMILES_0', 'SIGNATURE', 'SIGNATURE_MORGANS'], inplace=True)

# Rename few columns for more clarity
data.rename(columns={'SMILES': 'Query SMILES', 'ECFP': 'Query ECFP'}, inplace=True)

# Compute ECFP compatible with the MolForge 
data["Query SMILES"] = data["Query SMILES"].apply(lambda x: x.replace(" ", ""))
data["Query Mol"] = data["Query SMILES"].apply(mol_from_smiles)
data["Query ECFP Object"] = data["Query Mol"].apply(mol_to_ecfp_molforge)
data["Query ECFP"] = data["Query ECFP Object"].apply(ecfp_to_string_molforge)

# Clean up the DataFrame
data.drop(columns=["Query Mol", "Query ECFP Object"], inplace=True)

# Save processed data
data.to_csv(WORK_DIR / out_filename, sep="\t", index=False)

### 2.1 — Predict

In [8]:
data_filename = "data.tsv"
results_filename = "results.raw.tsv"

# Set up model -------------------------------------------------------------

# Load SentencePiece models
src_sp = spm.SentencePieceProcessor()
src_sp.Load(str(SOURCE_SP))
trg_sp = spm.SentencePieceProcessor()
trg_sp.Load(str(TARGET_SP))

# Load model
model = setup(build_model(ARGS).to(ARGS.device), ARGS.checkpoint, ARGS)

# Results DataFrame --------------------------------------------------------
results = pd.DataFrame(columns=["Query ID", "Query SMILES", "Query ECFP", "Predicted SMILES", "Time Elapsed"])

# Prediction -------------------------------------------------------------------

data = pd.read_csv(WORK_DIR / data_filename, sep="\t")

# Loop over data ECFP
for idx, row in data.iterrows():

    if idx >= MAX_ROWS:
        break
    
    query_id = idx + 1
    query_smiles = row["Query SMILES"]
    query_ecfp = row["Query ECFP"]

    # Log progress
    if idx % 100 == 0:
        print(f"Processing {query_id} / {len(data)}")

    with handy.Timer() as timer:
        # Perform inference
        predi_tokens = inference(model, query_ecfp, METHOD, ARGS, src_sp, trg_sp)

    predi_smiles = inference(model, query_ecfp, METHOD, ARGS, src_sp, trg_sp)
    query_smiles = query_smiles.replace(" ", "")
    predi_smiles = predi_smiles.replace(" ", "")
    results.loc[idx] = [query_id, query_smiles, query_ecfp, predi_smiles, timer.elapsed]


# Save results -----------------------------------------------------------------
results.to_csv(WORK_DIR / results_filename, sep='\t', index=False)

The size of src vocab is 2052 and that of trg vocab is 109.
Loading checkpoint... ECFP4 smiles
Processing 1 / 10000


  checkpoint = torch.load(checkpoint_path, map_location=torch.device(args.rank))


Processing 101 / 10000
Processing 201 / 10000
Processing 301 / 10000
Processing 401 / 10000
Processing 501 / 10000
Processing 601 / 10000
Processing 701 / 10000
Processing 801 / 10000
Processing 901 / 10000
Processing 1001 / 10000
Processing 1101 / 10000
Processing 1201 / 10000
Processing 1301 / 10000
Processing 1401 / 10000
Processing 1501 / 10000
Processing 1601 / 10000
Processing 1701 / 10000
Processing 1801 / 10000
Processing 1901 / 10000
Processing 2001 / 10000
Processing 2101 / 10000
Processing 2201 / 10000
Processing 2301 / 10000
Processing 2401 / 10000
Processing 2501 / 10000
Processing 2601 / 10000
Processing 2701 / 10000
Processing 2801 / 10000
Processing 2901 / 10000
Processing 3001 / 10000
Processing 3101 / 10000
Processing 3201 / 10000
Processing 3301 / 10000
Processing 3401 / 10000
Processing 3501 / 10000
Processing 3601 / 10000
Processing 3701 / 10000
Processing 3801 / 10000
Processing 3901 / 10000
Processing 4001 / 10000
Processing 4101 / 10000
Processing 4201 / 10000
P

### 2.2 — Refine results

In [9]:
results_filename = "results.raw.tsv"
out_filename = "results.refined.tsv"

# Load data -------------------------------------------------------------------

data = pd.read_csv(WORK_DIR / results_filename, sep="\t")


# Refine data -----------------------------------------------------------------

# Let's recompute required information on the Query side
data["Query SMILES"] = data["Query SMILES"].apply(remove_spaces)
data["Query Mol"] = data["Query SMILES"].apply(mol_from_smiles_with_exception)
data["Query Counted ECFP Object"] = data["Query Mol"].apply(mol_to_ecfp)
data["Query Counted ECFP"] = data["Query Counted ECFP Object"].apply(ecfp_to_string)
data["Query Binary ECFP Object"] = data["Query Mol"].apply(mol_to_ecfp_molforge)
data["Query Binary ECFP"] = data["Query Binary ECFP Object"].apply(ecfp_to_string_molforge)

# Now let's populate back the prediction side
data["Predicted Mol"] = data["Predicted SMILES"].apply(mol_from_smiles_with_exception)
data["Predicted Counted ECFP Object"] = data["Predicted Mol"].apply(mol_to_ecfp)
data["Predicted Counted ECFP"] = data["Predicted Counted ECFP Object"].apply(ecfp_to_string)
data["Predicted Binary ECFP Object"] = data["Predicted Mol"].apply(mol_to_ecfp_molforge)
data["Predicted Binary ECFP"] = data["Predicted Binary ECFP Object"].apply(ecfp_to_string_molforge)
data["Predicted Canonic SMILES"] = data["Predicted Mol"].apply(mol_to_smiles)

# Now let's check for Mol validity
data["SMILES Syntaxically Valid"] = data["Predicted Mol"].notnull()

# Now let's check for SMILES equality (with and without stereo)
data["SMILES Exact Match"] = data["Query SMILES"] == data["Predicted Canonic SMILES"]

# Now let's check for Tanimoto identity (with and without stereo)
data["Tanimoto Counted ECFP"] = data.apply(lambda x: tanimoto(x["Query Counted ECFP Object"], x["Predicted Counted ECFP Object"]), axis=1)
data["Tanimoto Counted ECFP Exact Match"] = data["Tanimoto Counted ECFP"] == 1.0
data["Tanimoto Binary ECFP"] = data.apply(lambda x: tanimoto(x["Query Binary ECFP Object"], x["Predicted Binary ECFP Object"]), axis=1)
data["Tanimoto Binary ECFP Exact Match"] = data["Tanimoto Binary ECFP"] == 1.0

# Finally export the refined DataFrame
cols = [
    "Query ID",
    "Query SMILES",
    "Query Counted ECFP",
    "Query Binary ECFP",
    "Predicted SMILES",
    "Predicted Counted ECFP",
    "Predicted Binary ECFP",
    "Predicted Canonic SMILES",
    "Tanimoto Counted ECFP",
    "Tanimoto Counted ECFP Exact Match",
    "Tanimoto Binary ECFP",
    "Tanimoto Binary ECFP Exact Match",
    "SMILES Exact Match",
    "SMILES Syntaxically Valid",
    "Time Elapsed",
]
data.to_csv(WORK_DIR / out_filename, sep="\t", index=False, columns=cols)

### 2.3 — Summary statistics

In [10]:
results_filename = "results.refined.tsv"
top_k = 1

from handy import get_summary, get_statistics, get_uniqueness


# Load data -------------------------------------------------------------------

data = pd.read_csv(WORK_DIR / results_filename, sep="\t")

# Add dummy "Predicted Log Prob" column if not present (for compatibility)
if "Predicted Log Prob" not in data.columns:
    data["Predicted Log Prob"] = 0.0


# Summary ---------------------------------------------------------------------

summary = get_summary(data, topk=top_k)
out_filename = "results.summary.tsv"
summary.to_csv(WORK_DIR / out_filename, sep="\t", index=False)


# Statistics -------------------------------------------------------------------

stats = get_statistics(data, topk=top_k)
out_filename = "results.statistics.tsv"
stats.to_csv(WORK_DIR / out_filename, sep="\t", index=False)


# Uniqueness -------------------------------------------------------------------

uniqueness = get_uniqueness(data, topk=top_k)
out_filename = "results.uniqueness.tsv"
uniqueness.to_csv(WORK_DIR / out_filename, sep="\t", index=False)


print(f"stats:")
print(stats)
print()
print(f"uniqueness:")
print(uniqueness)

stats:
                             Stat     Value
0                 SMILES Accuracy  0.193200
1  Tanimoto Counted ECFP Accuracy  0.204300
2          SMILES Syntax Validity  0.975300
3   Tanimoto Binary ECFP Accuracy  0.496400
4            Average Time Elapsed  1.571109

uniqueness:
  Distinct Molecules per Query  Count
0                            0   7957
1                            1   2043


___

## 3 — Pass eMolecules test set through MolForge's model

### 3.0 — Init & Settings

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
from types import SimpleNamespace

import pandas as pd
import torch
import sentencepiece as spm
from rdkit import RDLogger

import handy
from paper.dataset.utils import (
    mol_from_smiles,
    mol_to_smiles,
    tanimoto,
    mol_to_ecfp,
    ecfp_to_string,
)
from handy import mol_to_ecfp_molforge, ecfp_to_string_molforge
from MolForge.utils import pad_or_truncate, pad_id, build_model
from MolForge.predict import greedy_search, beam_search, setup

# Logging --------------------------------------------------------------------

RDLogger.DisableLog("rdApp.error")
RDLogger.DisableLog('rdApp.warning')


# Utils ------------------------------------------------------------------------

def remove_spaces(x):
    return x.replace(" ", "")


def inference(model, input_sentence, method, args, src_sp, trg_sp, return_attn=False):

    tokenized = src_sp.EncodeAsIds(input_sentence)
    src_data = torch.LongTensor(pad_or_truncate(tokenized, args.src_seq_len)).unsqueeze(0).to(args.device) # (1, L)
    e_mask = (src_data != pad_id).unsqueeze(1).to(args.device) # (1, 1, L)

    model.eval()
    src_data = model.src_embedding(src_data)
    src_data = model.src_positional_encoder(src_data)
    e_output = model.encoder(src_data, e_mask) # (1, L, d_model)

    if method == 'greedy':
        result, attn = greedy_search(model, e_output, e_mask, trg_sp, args.device, True)

    elif method == 'beam':
        result, attn = beam_search(model, e_output, e_mask, trg_sp, args.device, True)

    if return_attn:
        return result, attn
    
    else:
        return result


def mol_from_smiles_with_exception(mol):
    try:
        return mol_from_smiles(mol)
    except:
        return None


def mol_from_smiles_no_stereo(smiles):
    """Convert SMILES to RDKit Mol object without stereo information."""
    try:
        return mol_from_smiles(smiles, clear_stereo=True)
    except:
        return None


# Settings --------------------------------------------------------------------

BASE_DIR = Path().resolve().parent
TEST_FILE = BASE_DIR / 'data' / 'emolecules' / 'splitting' / 'test.tsv'
WORK_DIR = BASE_DIR / 'notebooks' / 'molforge' / 'case-3'

SOURCE_SP = BASE_DIR / "data" / "molforge" / "ECFP4_vocab_sp.model"
TARGET_SP = BASE_DIR / "data" / "molforge" / "smiles_vocab_sp.model"

MODEL_PATH = BASE_DIR / "data" / "molforge" / "ECFP4_smiles_checkpoint.pth"
METHOD = "greedy"
MAX_ROWS = 10000

ARGS = SimpleNamespace(
    fp = "ECFP4",
    model_type = "smiles",
    checkpoint = str(MODEL_PATH),
    decode = "greedy",
    src_vocab_size = 2052,
    trg_vocab_size = 109,
    src_seq_len = 104,
    trg_seq_len = 130,
    root_dir = BASE_DIR,
    src_sp_prefix = str(SOURCE_SP).replace(".model", ""),
    trg_sp_prefix = str(TARGET_SP).replace(".model", ""),
    rank = "cpu",
    device = "cpu",
)

# Create output directory if it doesn't exist
WORK_DIR.mkdir(parents=True, exist_ok=True)

# Write settings
with open(WORK_DIR / "settings.txt", "w") as f:
    f.write("BASE_DIR: " + str(BASE_DIR) + "\n")
    f.write("TEST_FILE: " + str(TEST_FILE) + "\n")
    f.write("WORK_DIR: " + str(WORK_DIR) + "\n")
    f.write("MAX_ROWS: " + str(MAX_ROWS) + "\n")
    for key, value in ARGS.__dict__.items():
        f.write(f"{key}: {value}\n")

### 3.0 — Data preparation

In [2]:
out_filename = "data.tsv"

# Data preparation -------------------------------------------------------------

# Load test data
data = pd.read_csv(TEST_FILE, sep='\t', nrows=MAX_ROWS)

# Drop unnecessary columns
data.drop(columns=['SMILES_0', 'SIGNATURE', 'SIGNATURE_MORGANS'], inplace=True)

# Rename few columns for more clarity
data.rename(columns={'SMILES': 'Query SMILES', 'ECFP': 'Query ECFP'}, inplace=True)

# Compute ECFP compatible with the MolForge 
data["Query SMILES"] = data["Query SMILES"].apply(lambda x: x.replace(" ", ""))
data["Query Mol"] = data["Query SMILES"].apply(mol_from_smiles)
data["Query ECFP Object"] = data["Query Mol"].apply(mol_to_ecfp_molforge)
data["Query ECFP"] = data["Query ECFP Object"].apply(ecfp_to_string_molforge)

# Clean up the DataFrame
data.drop(columns=["Query Mol", "Query ECFP Object"], inplace=True)

# Save processed data
data.to_csv(WORK_DIR / out_filename, sep="\t", index=False)

### 3.1 — Predict

In [3]:
data_filename = "data.tsv"
results_filename = "results.raw.tsv"

# Set up model -------------------------------------------------------------

# Load SentencePiece models
src_sp = spm.SentencePieceProcessor()
src_sp.Load(str(SOURCE_SP))
trg_sp = spm.SentencePieceProcessor()
trg_sp.Load(str(TARGET_SP))

# Load model
model = setup(build_model(ARGS).to(ARGS.device), ARGS.checkpoint, ARGS)

# Results DataFrame --------------------------------------------------------
results = pd.DataFrame(columns=["Query ID", "Query SMILES", "Query ECFP", "Predicted SMILES", "Time Elapsed"])

# Prediction -------------------------------------------------------------------

data = pd.read_csv(WORK_DIR / data_filename, sep="\t")

# Loop over data ECFP
for idx, row in data.iterrows():

    if idx >= MAX_ROWS:
        break
    
    query_id = idx + 1
    query_smiles = row["Query SMILES"]
    query_ecfp = row["Query ECFP"]

    # Log progress
    if idx % 100 == 0:
        print(f"Processing {query_id} / {len(data)}")

    with handy.Timer() as timer:
        # Perform inference
        predi_tokens = inference(model, query_ecfp, METHOD, ARGS, src_sp, trg_sp)

    predi_smiles = inference(model, query_ecfp, METHOD, ARGS, src_sp, trg_sp)
    query_smiles = query_smiles.replace(" ", "")
    predi_smiles = predi_smiles.replace(" ", "")
    results.loc[idx] = [query_id, query_smiles, query_ecfp, predi_smiles, timer.elapsed]


# Save results -----------------------------------------------------------------
results.to_csv(WORK_DIR / results_filename, sep='\t', index=False)

The size of src vocab is 2052 and that of trg vocab is 109.
Loading checkpoint... ECFP4 smiles
Processing 1 / 10000


  checkpoint = torch.load(checkpoint_path, map_location=torch.device(args.rank))


Processing 101 / 10000
Processing 201 / 10000
Processing 301 / 10000
Processing 401 / 10000
Processing 501 / 10000
Processing 601 / 10000
Processing 701 / 10000
Processing 801 / 10000
Processing 901 / 10000
Processing 1001 / 10000
Processing 1101 / 10000
Processing 1201 / 10000
Processing 1301 / 10000
Processing 1401 / 10000
Processing 1501 / 10000
Processing 1601 / 10000
Processing 1701 / 10000
Processing 1801 / 10000
Processing 1901 / 10000
Processing 2001 / 10000
Processing 2101 / 10000
Processing 2201 / 10000
Processing 2301 / 10000
Processing 2401 / 10000
Processing 2501 / 10000
Processing 2601 / 10000
Processing 2701 / 10000
Processing 2801 / 10000
Processing 2901 / 10000
Processing 3001 / 10000
Processing 3101 / 10000
Processing 3201 / 10000
Processing 3301 / 10000
Processing 3401 / 10000
Processing 3501 / 10000
Processing 3601 / 10000
Processing 3701 / 10000
Processing 3801 / 10000
Processing 3901 / 10000
Processing 4001 / 10000
Processing 4101 / 10000
Processing 4201 / 10000
P

### 3.2 — Refine results

In [2]:
results_filename = "results.raw.tsv"
out_filename = "results.refined.tsv"

# Load data -------------------------------------------------------------------

data = pd.read_csv(WORK_DIR / results_filename, sep="\t")


# Refine data -----------------------------------------------------------------

# Let's recompute required information on the Query side
data["Query SMILES"] = data["Query SMILES"].apply(remove_spaces)
data["Query Mol"] = data["Query SMILES"].apply(mol_from_smiles_with_exception)
data["Query Counted ECFP Object"] = data["Query Mol"].apply(mol_to_ecfp)
data["Query Counted ECFP"] = data["Query Counted ECFP Object"].apply(ecfp_to_string)
data["Query Binary ECFP Object"] = data["Query Mol"].apply(mol_to_ecfp_molforge)
data["Query Binary ECFP"] = data["Query Binary ECFP Object"].apply(ecfp_to_string_molforge)

# Now let's populate back the prediction side
data["Predicted Mol"] = data["Predicted SMILES"].apply(mol_from_smiles_with_exception)
data["Predicted Counted ECFP Object"] = data["Predicted Mol"].apply(mol_to_ecfp)
data["Predicted Counted ECFP"] = data["Predicted Counted ECFP Object"].apply(ecfp_to_string)
data["Predicted Binary ECFP Object"] = data["Predicted Mol"].apply(mol_to_ecfp_molforge)
data["Predicted Binary ECFP"] = data["Predicted Binary ECFP Object"].apply(ecfp_to_string_molforge)
data["Predicted Canonic SMILES"] = data["Predicted Mol"].apply(mol_to_smiles)

# Now let's check for Mol validity
data["SMILES Syntaxically Valid"] = data["Predicted Mol"].notnull()

# Now let's check for SMILES equality (with and without stereo)
data["SMILES Exact Match"] = data["Query SMILES"] == data["Predicted Canonic SMILES"]

# Now let's check for Tanimoto identity (with and without stereo)
data["Tanimoto Counted ECFP"] = data.apply(lambda x: tanimoto(x["Query Counted ECFP Object"], x["Predicted Counted ECFP Object"]), axis=1)
data["Tanimoto Counted ECFP Exact Match"] = data["Tanimoto Counted ECFP"] == 1.0
data["Tanimoto Binary ECFP"] = data.apply(lambda x: tanimoto(x["Query Binary ECFP Object"], x["Predicted Binary ECFP Object"]), axis=1)
data["Tanimoto Binary ECFP Exact Match"] = data["Tanimoto Binary ECFP"] == 1.0

# Finally export the refined DataFrame
cols = [
    "Query ID",
    "Query SMILES",
    "Query Counted ECFP",
    "Query Binary ECFP",
    "Predicted SMILES",
    "Predicted Counted ECFP",
    "Predicted Binary ECFP",
    "Predicted Canonic SMILES",
    "Tanimoto Counted ECFP",
    "Tanimoto Counted ECFP Exact Match",
    "Tanimoto Binary ECFP",
    "Tanimoto Binary ECFP Exact Match",
    "SMILES Exact Match",
    "SMILES Syntaxically Valid",
    "Time Elapsed",
]
data.to_csv(WORK_DIR / out_filename, sep="\t", index=False, columns=cols)

### 3.3 — Summary statistics

In [3]:
results_filename = "results.refined.tsv"
top_k = 1

from handy import get_summary, get_statistics, get_uniqueness


# Load data -------------------------------------------------------------------

data = pd.read_csv(WORK_DIR / results_filename, sep="\t")

# Add dummy "Predicted Log Prob" column if not present (for compatibility)
if "Predicted Log Prob" not in data.columns:
    data["Predicted Log Prob"] = 0.0


# Summary ---------------------------------------------------------------------

summary = get_summary(data, topk=top_k)
out_filename = "results.summary.tsv"
summary.to_csv(WORK_DIR / out_filename, sep="\t", index=False)


# Statistics -------------------------------------------------------------------

stats = get_statistics(data, topk=top_k)
out_filename = "results.statistics.tsv"
stats.to_csv(WORK_DIR / out_filename, sep="\t", index=False)


# Uniqueness -------------------------------------------------------------------

uniqueness = get_uniqueness(data, topk=top_k)
out_filename = "results.uniqueness.tsv"
uniqueness.to_csv(WORK_DIR / out_filename, sep="\t", index=False)


print(f"stats:")
print(stats)
print()
print(f"uniqueness:")
print(uniqueness)

stats:
                             Stat     Value
0                 SMILES Accuracy  0.648700
1  Tanimoto Counted ECFP Accuracy  0.673700
2          SMILES Syntax Validity  0.998200
3   Tanimoto Binary ECFP Accuracy  0.913000
4            Average Time Elapsed  1.470123

uniqueness:
  Distinct Molecules per Query  Count
0                            0   3263
1                            1   6737


___

## 4 — Pass MolForge's test set through MolForge's model

### 4.0 — Init & Settings

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
from types import SimpleNamespace

import pandas as pd
import torch
import sentencepiece as spm
from rdkit import RDLogger

import handy
from paper.dataset.utils import (
    mol_from_smiles,
    mol_to_smiles,
    tanimoto,
    mol_to_ecfp,
    ecfp_to_string,
)
from handy import mol_to_ecfp_molforge, ecfp_to_string_molforge
from MolForge.utils import pad_or_truncate, pad_id, build_model
from MolForge.predict import greedy_search, beam_search, setup

# Logging --------------------------------------------------------------------

RDLogger.DisableLog("rdApp.error")
RDLogger.DisableLog('rdApp.warning')


# Utils ------------------------------------------------------------------------

def remove_spaces(x):
    return x.replace(" ", "")


def inference(model, input_sentence, method, args, src_sp, trg_sp, return_attn=False):

    tokenized = src_sp.EncodeAsIds(input_sentence)
    src_data = torch.LongTensor(pad_or_truncate(tokenized, args.src_seq_len)).unsqueeze(0).to(args.device) # (1, L)
    e_mask = (src_data != pad_id).unsqueeze(1).to(args.device) # (1, 1, L)

    model.eval()
    src_data = model.src_embedding(src_data)
    src_data = model.src_positional_encoder(src_data)
    e_output = model.encoder(src_data, e_mask) # (1, L, d_model)

    if method == 'greedy':
        result, attn = greedy_search(model, e_output, e_mask, trg_sp, args.device, True)

    elif method == 'beam':
        result, attn = beam_search(model, e_output, e_mask, trg_sp, args.device, True)

    if return_attn:
        return result, attn
    
    else:
        return result


def mol_from_smiles_with_exception(mol):
    try:
        return mol_from_smiles(mol)
    except:
        return None


def mol_from_smiles_no_stereo(smiles):
    """Convert SMILES to RDKit Mol object without stereo information."""
    try:
        return mol_from_smiles(smiles, clear_stereo=True)
    except:
        return None


# Settings --------------------------------------------------------------------

BASE_DIR = Path().resolve().parent
TEST_FILE = BASE_DIR / 'data' / 'molforge' / 'ECFP4.smiles.test'
WORK_DIR = BASE_DIR / 'notebooks' / 'molforge' / 'case-4'

SOURCE_SP = BASE_DIR / "data" / "molforge" / "ECFP4_vocab_sp.model"
TARGET_SP = BASE_DIR / "data" / "molforge" / "smiles_vocab_sp.model"

MODEL_PATH = BASE_DIR / "data" / "molforge" / "ECFP4_smiles_checkpoint.pth"
METHOD = "greedy"
MAX_ROWS = 100

ARGS = SimpleNamespace(
    fp = "ECFP4",
    model_type = "smiles",
    checkpoint = str(MODEL_PATH),
    decode = "greedy",
    src_vocab_size = 2052,
    trg_vocab_size = 109,
    src_seq_len = 104,
    trg_seq_len = 130,
    root_dir = BASE_DIR,
    src_sp_prefix = str(SOURCE_SP).replace(".model", ""),
    trg_sp_prefix = str(TARGET_SP).replace(".model", ""),
    rank = "cpu",
    device = "cpu",
)

# Create output directory if it doesn't exist
WORK_DIR.mkdir(parents=True, exist_ok=True)

# Write settings
with open(WORK_DIR / "settings.txt", "w") as f:
    f.write("BASE_DIR: " + str(BASE_DIR) + "\n")
    f.write("TEST_FILE: " + str(TEST_FILE) + "\n")
    f.write("WORK_DIR: " + str(WORK_DIR) + "\n")
    f.write("MAX_ROWS: " + str(MAX_ROWS) + "\n")
    for key, value in ARGS.__dict__.items():
        f.write(f"{key}: {value}\n")

### 4.0 — Data preparation

In [7]:
out_filename = "data.tsv"

# Data preparation -------------------------------------------------------------

# Load data
data = pd.read_csv(TEST_FILE, sep="\t", header=None, nrows=MAX_ROWS)
data.rename(columns={0: "Query SMILES", 1: "Query ECFP"}, inplace=True)

# Remove spaces in SMILES
data["Query SMILES"] = data["Query SMILES"].apply(remove_spaces)

# Append a "Query ID" column containing the row number for easier reference
data['Query ID'] = range(1, len(data) + 1)

# Reorder columns to have "Query ID" first
cols = data.columns.tolist()
cols = cols[-1:] + cols[:-1]
data = data[cols]

# Save the processed data
data.to_csv(WORK_DIR / out_filename, sep="\t", index=False)

### 4.1 — Predict

In [None]:
data_filename = "data.tsv"
results_filename = "results.raw.tsv"

# Set up model -------------------------------------------------------------

# Load SentencePiece models
src_sp = spm.SentencePieceProcessor()
src_sp.Load(str(SOURCE_SP))
trg_sp = spm.SentencePieceProcessor()
trg_sp.Load(str(TARGET_SP))

# Load model
model = setup(build_model(ARGS).to(ARGS.device), ARGS.checkpoint, ARGS)

# Results DataFrame ----------------------------------------------------------

results = pd.DataFrame(columns=["Query ID", "Query SMILES", "Query ECFP", "Predicted SMILES", "Time Elapsed"])

# Load data ------------------------------------------------------------------

data = pd.read_csv(WORK_DIR / data_filename, sep="\t")

# Prediction -----------------------------------------------------------------

# Loop over data ECFP
for idx, row in data.iterrows():

    if idx >= MAX_ROWS:
        break
    
    query_id = idx + 1
    query_smiles = row["Query SMILES"]
    query_ecfp = row["Query ECFP"]

    # Log progress
    if idx % 100 == 0:
        print(f"Processing {query_id} / {len(data)}")

    with handy.Timer() as timer:
        # Perform inference
        predi_tokens = inference(model, query_ecfp, METHOD, ARGS, src_sp, trg_sp)

    predi_smiles = inference(model, query_ecfp, METHOD, ARGS, src_sp, trg_sp)
    # query_smiles = query_smiles.replace(" ", "")
    predi_smiles = predi_smiles.replace(" ", "")
    results.loc[idx] = [query_id, query_smiles, query_ecfp, predi_smiles, timer.elapsed]


# Save results -----------------------------------------------------------------
results.to_csv(WORK_DIR / results_filename, sep='\t', index=False)

The size of src vocab is 2052 and that of trg vocab is 109.
Loading checkpoint... ECFP4 smiles
Processing 1 / 10000


  checkpoint = torch.load(checkpoint_path, map_location=torch.device(args.rank))


Processing 101 / 10000
Processing 201 / 10000
Processing 301 / 10000
Processing 401 / 10000
Processing 501 / 10000
Processing 601 / 10000
Processing 701 / 10000
Processing 801 / 10000
Processing 901 / 10000
Processing 1001 / 10000
Processing 1101 / 10000
Processing 1201 / 10000
Processing 1301 / 10000
Processing 1401 / 10000
Processing 1501 / 10000
Processing 1601 / 10000
Processing 1701 / 10000
Processing 1801 / 10000
Processing 1901 / 10000
Processing 2001 / 10000
Processing 2101 / 10000
Processing 2201 / 10000
Processing 2301 / 10000
Processing 2401 / 10000
Processing 2501 / 10000
Processing 2601 / 10000
Processing 2701 / 10000
Processing 2801 / 10000
Processing 2901 / 10000
Processing 3001 / 10000
Processing 3101 / 10000
Processing 3201 / 10000
Processing 3301 / 10000
Processing 3401 / 10000
Processing 3501 / 10000
Processing 3601 / 10000
Processing 3701 / 10000
Processing 3801 / 10000
Processing 3901 / 10000
Processing 4001 / 10000
Processing 4101 / 10000
Processing 4201 / 10000
P

### 4.2 - Refine results

In [2]:
results_filename = "results.raw.tsv"
out_filename = "results.refined.tsv"

# Load data -------------------------------------------------------------------

data = pd.read_csv(WORK_DIR / results_filename, sep="\t")


# Refine data -----------------------------------------------------------------

# Let's recompute required information on the Query side
data["Query SMILES"] = data["Query SMILES"].apply(remove_spaces)
data["Query Mol"] = data["Query SMILES"].apply(mol_from_smiles_with_exception)
data["Query Counted ECFP Object"] = data["Query Mol"].apply(mol_to_ecfp)
data["Query Counted ECFP"] = data["Query Counted ECFP Object"].apply(ecfp_to_string)
data["Query Binary ECFP Object"] = data["Query Mol"].apply(mol_to_ecfp_molforge)
data["Query Binary ECFP"] = data["Query Binary ECFP Object"].apply(ecfp_to_string_molforge)

# Now let's populate back the prediction side
data["Predicted Mol"] = data["Predicted SMILES"].apply(mol_from_smiles_with_exception)
data["Predicted Counted ECFP Object"] = data["Predicted Mol"].apply(mol_to_ecfp)
data["Predicted Counted ECFP"] = data["Predicted Counted ECFP Object"].apply(ecfp_to_string)
data["Predicted Binary ECFP Object"] = data["Predicted Mol"].apply(mol_to_ecfp_molforge)
data["Predicted Binary ECFP"] = data["Predicted Binary ECFP Object"].apply(ecfp_to_string_molforge)
data["Predicted Canonic SMILES"] = data["Predicted Mol"].apply(mol_to_smiles)

# Now let's check for Mol validity
data["SMILES Syntaxically Valid"] = data["Predicted Mol"].notnull()

# Now let's check for SMILES equality
data["SMILES Exact Match"] = data["Query SMILES"] == data["Predicted SMILES"]

# Now let's check for Tanimoto identity
data["Tanimoto Counted ECFP"] = data.apply(lambda x: tanimoto(x["Query Counted ECFP Object"], x["Predicted Counted ECFP Object"]), axis=1)
data["Tanimoto Counted ECFP Exact Match"] = data["Tanimoto Counted ECFP"] == 1.0
data["Tanimoto Binary ECFP"] = data.apply(lambda x: tanimoto(x["Query Binary ECFP Object"], x["Predicted Binary ECFP Object"]), axis=1)
data["Tanimoto Binary ECFP Exact Match"] = data["Tanimoto Binary ECFP"] == 1.0

# Finally export the refined DataFrame
cols = [
    "Query ID",
    "Query SMILES",
    "Query Counted ECFP",
    "Query Binary ECFP",
    "Predicted SMILES",
    "Predicted Counted ECFP",
    "Predicted Binary ECFP",
    "Predicted Canonic SMILES",
    "Tanimoto Counted ECFP",
    "Tanimoto Counted ECFP Exact Match",
    "Tanimoto Binary ECFP",
    "Tanimoto Binary ECFP Exact Match",
    "SMILES Exact Match",
    "SMILES Syntaxically Valid",
    "Time Elapsed",
]
data.to_csv(WORK_DIR / out_filename, sep="\t", index=False, columns=cols)

### 4.3 — Summary statistics

In [3]:
results_filename = "results.refined.tsv"
top_k = 1

from handy import get_summary, get_statistics, get_uniqueness


# Load data -------------------------------------------------------------------

data = pd.read_csv(WORK_DIR / results_filename, sep="\t")

# Add dummy "Predicted Log Prob" column if not present (for compatibility)
if "Predicted Log Prob" not in data.columns:
    data["Predicted Log Prob"] = 0.0


# Summary ---------------------------------------------------------------------

summary = get_summary(data, topk=top_k)
out_filename = "results.summary.tsv"
summary.to_csv(WORK_DIR / out_filename, sep="\t", index=False)


# Statistics -------------------------------------------------------------------

stats = get_statistics(data, topk=top_k)
out_filename = "results.statistics.tsv"
stats.to_csv(WORK_DIR / out_filename, sep="\t", index=False)


# Uniqueness -------------------------------------------------------------------

uniqueness = get_uniqueness(data, topk=top_k)
out_filename = "results.uniqueness.tsv"
uniqueness.to_csv(WORK_DIR / out_filename, sep="\t", index=False)


print(f"stats:")
print(stats)
print()
print(f"uniqueness:")
print(uniqueness)

stats:
                             Stat     Value
0                 SMILES Accuracy  0.604800
1  Tanimoto Counted ECFP Accuracy  0.667300
2          SMILES Syntax Validity  0.997700
3   Tanimoto Binary ECFP Accuracy  0.871600
4            Average Time Elapsed  1.715669

uniqueness:
  Distinct Molecules per Query  Count
0                            0   3327
1                            1   6673
