# Inference using the Synthetic Fermentation models

In principle, this notebook works the same way as `inference.ipynb`. We keep it separate so that the virtual-library predictions stay reproducible. It also performs some steps specific to this use case, such as reading from and writing to the database directly.


In [None]:
import pathlib
import statistics
import sys
import sqlite3
sys.path.append(str(pathlib.Path("__file__").absolute().parents[1]))

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from src.util.definitions import TRAINED_MODEL_DIR, LOG_DIR, DATA_ROOT
from src.model.classifier import load_trained_model
from src.data.dataloader import SynFermDataset, collate_fn
from reaction_generator import SFReactionGenerator

In [None]:
# Paths to the best 0D model (an FFN) and the artefacts it needs at inference time.
model_0D_name = "2024-01-04-085409_305115_fold0"
model_0D = TRAINED_MODEL_DIR / model_0D_name / "last-epoch72-val_loss0.19.ckpt"  # FFN
# path to the OneHotEncoder state for model_0D
ohe_state_dict = LOG_DIR / "OHE_state_dict_ohlvinnXkSzSXBJi.json"
# Fail early and name the missing file, instead of failing deep inside model/dataset loading.
assert model_0D.is_file(), f"Model checkpoint not found: {model_0D}"
assert ohe_state_dict.is_file(), f"OneHotEncoder state not found: {ohe_state_dict}"

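To use the notebook on your products, change `raw_dir` to the directory that contains your CSV file of SMILES. Then change `filename_base` to the name of your CSV file without the `.csv` suffix. If you do not want to use all the SMILES in your file (e.g. because some are not valid SLAP products), supply a `valid_idx_file`. You can set it to `None` if you want to use all SMILES. Note that the cells below do not define `filename_base` or `valid_idx_file`. Instead, they read the product SMILES directly from the `virtuallibrary` table of the SQLite database `dbname` in `DATA_ROOT`. Products containing building blocks that were not in the training data are filtered out further down.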

In [None]:
# Load the type-A products of the virtual library straight from the project database.
dbname = "50k_project.db"
raw_dir = DATA_ROOT
con = sqlite3.connect(raw_dir / dbname)
# NB: nothing is filtered yet. Products the 0D model cannot handle are removed further down.
# Whether a reaction has already been run does not matter here either: the predictions
# can be merged with the reaction data later.
product_query = "SELECT id, long_name, SMILES FROM virtuallibrary WHERE type = 'A';"
res = con.execute(product_query).fetchall()
df = pd.DataFrame.from_records(res, columns=["vl_id", "long_name", "product_A_smiles"])
df.head()

In [None]:
# Load the building blocks from the training data of model_0D (initiators, monomers and
# terminators). Only these are known to its one-hot encoder; the next cell filters on them.
train_bb_dir = TRAINED_MODEL_DIR / model_0D_name
dfs = [pd.read_csv(train_bb_dir / f"train_{bb_role}.csv") for bb_role in ("initiators", "monomers", "terminators")]
used_building_blocks = pd.concat(dfs)
used_building_blocks


In [None]:
# Building blocks that were not in the training data are not recognized by the one-hot encoder.
# Keep only products whose building blocks (the "+"-separated parts of long_name) were all seen.
# A set gives O(1) membership tests; the original `in ndarray` scanned the whole column per lookup.
known_long_names = set(used_building_blocks["long"])
df = df.loc[
    df["long_name"].str.split("+").apply(lambda parts: all(p.strip() in known_long_names for p in parts))
].copy()  # explicit copy: columns are added to this frame in later cells
len(df)

In [None]:
# Reaction generator; its get_reaction_smiles() turns each product SMILES into the
# atom-mapped reaction SMILES stored in "reaction_smiles_atom_mapped" below.
gen = SFReactionGenerator()

In [None]:
%%time
df[0:100]["product_A_smiles"].apply(lambda x: gen.get_reaction_smiles(x))

In [None]:
# Generate the atom-mapped reaction SMILES for every product.
# This is slow: the author's estimate for the full library was a bit over 2 h.
# It could be optimized or parallelized if it needs to be re-run often.
df["reaction_smiles_atom_mapped"] = df["product_A_smiles"].apply(gen.get_reaction_smiles)
df.head()

In [None]:
# Persist the reaction SMILES. SynFermDataset reads this file by name from DATA_ROOT.
# The DataFrame index is written as well (pandas' default, made explicit here).
reaction_smiles_csv = DATA_ROOT / "virtual-library_reactionSMILES.csv"
df.to_csv(reaction_smiles_csv, index=True)

In [None]:
# Build the inference dataset from the reaction-SMILES CSV (label_columns=None: no label
# columns). The OHE global features use the saved encoder state `ohe_state_dict` for model_0D.
# Instantiation takes a while.
data = SynFermDataset(
    name="virtual-library_reactionSMILES.csv",
    raw_dir=DATA_ROOT,
    smiles_columns=["reaction_smiles_atom_mapped"],
    label_columns=None,
    reaction=True,
    graph_type="bond_edges",
    featurizers="custom",
    global_features=["OHE"],
    global_featurizer_state_dict_path=ohe_state_dict,
    task="multilabel",
)

In [None]:
# Run the predictions with the 0D (FFN) model.

# Load the trained model into its own variable. `model_0D` keeps pointing at the checkpoint
# path, so this cell can be re-run without restarting the kernel.
model_0D_ffn = load_trained_model("FFN", model_0D)
model_0D_ffn.eval()
trainer = pl.Trainer(accelerator="gpu", logger=False, max_epochs=-1)
# Prepare data. DataLoader defaults to shuffle=False (and batch_size=1), so predictions
# come out in the same row order as `data`.
# NOTE(review): batch_size=1 is slow for a large library; a larger batch size may be fine
# since collate_fn is used, but confirm it does not change results.
dl = DataLoader(data, collate_fn=collate_fn, num_workers=0)
# trainer.predict returns one output per batch. Concatenate them and apply a sigmoid
# (the raw outputs are presumably logits) to get per-class probabilities.
probabilities_0D = torch.sigmoid(torch.concat(trainer.predict(model_0D_ffn, dl)))
    

In [None]:
# Inspect the predicted probabilities: one row per product, one column each for A, B and C.
probabilities_0D

In [None]:
# Load the decision thresholds of model_0D: one float per line, applied to classes A, B, C in order.
threshold_path = LOG_DIR / "thresholds" / f"{model_0D_name}.txt"
with open(threshold_path, "r") as f:
    # Skip blank lines (e.g. a trailing empty line) instead of failing on float("").
    thresholds = [float(line) for line in f if line.strip()]
# The following cells index thresholds[0], [1] and [2] for products A, B and C.
assert len(thresholds) >= 3, f"expected at least 3 thresholds (A, B, C) in {threshold_path}, got {len(thresholds)}"
print(thresholds)

In [None]:
# Binarize: a product is predicted positive (1) when its probability exceeds its class threshold.
# Threshold tensor matches the probabilities' dtype/device, like Python scalars in a comparison.
threshold_tensor = torch.tensor(thresholds[:3], dtype=probabilities_0D.dtype, device=probabilities_0D.device)
preds = (probabilities_0D[:, :3] > threshold_tensor).long()

In [None]:
# Attach the probabilities and binary predictions to the products.
# Rows are matched by position, which relies on `data` having been built from the CSV
# written from `df` in the same row order. Check at least that the counts agree.
assert len(probabilities_0D) == len(df), (
    f"{len(probabilities_0D)} predictions for {len(df)} products; rows cannot be aligned"
)
# Convert explicitly to NumPy (on CPU) rather than handing torch tensors to pandas.
df[["prob_A", "prob_B", "prob_C"]] = probabilities_0D.detach().cpu().numpy()
df[["pred_A", "pred_B", "pred_C"]] = preds.detach().cpu().numpy()

In [None]:
# Quick look at the products together with their probabilities and binary predictions.
df.head()

In [None]:
# Summarize the binary predictions for product A: counts of products predicted 1 vs 0.
df["pred_A"].value_counts()

In [None]:
# Distribution of the predicted probability of product A across the filtered virtual library.
ax = df["prob_A"].plot.hist(bins=100)
ax.set(title="Predicted probability of product A (model_0D)", xlabel="prob_A", ylabel="Number of products");

In [None]:
# Write all products with their probabilities and predictions to CSV (without the index).
# NOTE(review): the date in the file name (2023-12-20) is earlier than the date in
# model_0D_name (2024-01-04); confirm the name is intentional before relying on it.
predictions_csv = DATA_ROOT / "virtual-library_predictions_2023-12-20.csv"
df.to_csv(predictions_csv, index=False)

In [None]:
from collections import defaultdict

In [None]:
# Empty lookups, filled from the database in the next cell:
# vl_ids maps long_name -> {type: vl_id}; longnames maps a type-A vl_id -> long_name.
longnames = dict()
vl_ids = defaultdict(dict)

In [None]:
# Fill both lookups from every virtual-library entry (all types, not only A):
# long_name -> {type: vl_id} for all entries, and vl_id -> long_name for type-A entries.
res = con.execute("SELECT id, long_name, type FROM virtuallibrary").fetchall()
for entry_id, entry_name, entry_type in res:
    vl_ids[entry_name][entry_type] = entry_id
    if entry_type == "A":
        longnames[entry_id] = entry_name

In [None]:
# Write the binary predictions for all three product types (A, B, C) to the database.
# `df` only holds the type-A vl_id of each product; the vl_ids of the B and C products
# are looked up through the shared long_name via `longnames` and `vl_ids`.
# NOTE(review): re-running this cell inserts the same predictions again unless the table
# enforces uniqueness (not visible from here).
insert_sql = "INSERT INTO virtuallibrary_predictions (vl_id, binary_outcome, binary_model) VALUES (?, ?, ?);"
prediction_rows = []
for vl_id_A, pred_A, pred_B, pred_C in df[["vl_id", "pred_A", "pred_B", "pred_C"]].itertuples(index=False):
    # get vl_id for B and C (a missing entry raises here, before anything is inserted)
    other_ids = vl_ids[longnames[int(vl_id_A)]]
    # sqlite3 cannot bind NumPy integer types, so pass plain Python ints
    prediction_rows.append((int(vl_id_A), int(pred_A), model_0D_name))
    prediction_rows.append((other_ids["B"], int(pred_B), model_0D_name))
    prediction_rows.append((other_ids["C"], int(pred_C), model_0D_name))
cur = con.cursor()
# As a context manager, the connection commits if every insert succeeds and rolls back
# otherwise, so a failure cannot leave a partial set of predictions behind.
with con:
    cur.executemany(insert_sql, prediction_rows)