# Inference using the SynFerm models

Predict reaction outcome for Synthetic Fermentation products.

The inputs to this notebook are the SMILES of the desired product(s).
Inputs can be supplied directly as a csv file with one column named "smiles" and arbitrary additional columns.

The output is:
- the reactionSMILES that leads to this product using the SLAP platform
- a classification of whether the reaction is expected to work

The output is written to a new csv file containing all columns from the input file, and six new columns: `rxn1_smiles`, `rxn1_predictions`, `rxn1_confidence`, `rxn2_smiles`, `rxn2_predictions`, `rxn2_confidence`.

Predictions are given as `0` (meaning no reaction expected) or `1` (meaning successful reaction expected). 
If the reaction was in the acquired data set, the known outcome is returned instead.

Confidence is given as an integer in the range `0-4`, with `0` indicating the highest confidence.
Confidence is determined based on the complexity of the prediction problem using the following mapping:
- `0`: known reaction
- `1`: all three reactants known in other reactions
- `2`: exactly two reactants known in other reactions
- `3`: exactly one reactant known in other reactions
- `4`: none of the reactants known in other reactions


In [32]:
import pathlib
import statistics
import sys
sys.path.append(str(pathlib.Path("__file__").absolute().parents[1]))

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from src.util.definitions import TRAINED_MODEL_DIR, LOG_DIR, DATA_ROOT
from src.model.classifier import load_trained_model
from src.data.dataloader import SynFermDataset, collate_fn
from reaction_generator import SFReactionGenerator

In [33]:
def import_smiles(
    raw_dir: pathlib.Path, filename: str, valid_idx_file: pathlib.Path = None
):
    """Import SMILES from a CSV file, optionally keeping only selected rows.

    Args:
        raw_dir: Directory that contains the CSV file. A plain string is accepted as well.
        filename: Name of the CSV file inside `raw_dir`.
        valid_idx_file: Optional path (or string) of a second CSV file with a column named
            "index". If given, only the rows whose index label appears in that column are
            returned, in the order listed there. If None, all rows are returned.

    Returns:
        A pandas DataFrame with the (optionally filtered) contents of the CSV file.

    Raises:
        KeyError: If `valid_idx_file` has no "index" column, or if it lists a label that is
            not present in the index of the imported data.
    """
    # Coerce to Path so that a directory given as a plain string works with the `/` operator
    smiles_df = pd.read_csv(pathlib.Path(raw_dir) / filename)
    if valid_idx_file is None:
        return smiles_df
    indices_arr = pd.read_csv(valid_idx_file)["index"].to_numpy()
    # Label-based selection. read_csv assigns a default RangeIndex here, so labels coincide
    # with row positions in the file.
    return smiles_df.loc[indices_arr]

In [43]:
# Locations of the trained models.
# model_0D_name is the run identifier; model_0D is the checkpoint file of that run.
model_0D_name = "2023-11-20-175433_236136_fold0"
model_0D = LOG_DIR.joinpath("checkpoints", model_0D_name, "last-epoch38-val_loss0.20.ckpt")  # FFN
# Currently unused models, kept for reference:
#model_1D = TRAINED_MODEL_DIR / "2023-03-06-112027_188465" / "best.ckpt"  # D-MPNN
#model_2D = TRAINED_MODEL_DIR / "2023-03-06-112721_778803" / "best.ckpt"  # D-MPNN
#model_3D
# Saved OneHotEncoder state that belongs to model_0D
ohe_state_dict = LOG_DIR.joinpath("OHE_state_dict_KcEovvzIEafcIYUJ.json")

To use the notebook on your products, change `raw_dir` to the directory that your CSV file containing SMILES is in. Then change `filename` to the name of your CSV file (including the `.csv` suffix; `filename_base` is derived from it automatically). If you do not want to use all the SMILES in your file (e.g. because some are not valid products), supply a `valid_idx_file`. You can set the value to `None` if you want to use all SMILES.

In [26]:
# Read the input table. If valid_idx_file is set, only the rows listed there are kept.
raw_dir = DATA_ROOT  # <-- change me
filename = "synferm_dataset_2023-09-05_40018records.csv"  # <-- change me
valid_idx_file = "../data/splits/synferm_dataset_2023-09-05_0D_split_final-retrain/fold0_val.csv"  # <-- change me or set me to None
# Keep everything in front of the first ".csv", which also drops trailing
# compression suffixes such as .csv.bz2 or .csv.gz
filename_base = filename.partition(".csv")[0]
df = import_smiles(raw_dir, filename, valid_idx_file=valid_idx_file)
df

Unnamed: 0,I_long,M_long,T_long,product_A_smiles,I_smiles,M_smiles,T_smiles,reaction_smiles,reaction_smiles_atom_mapped,experiment_id,...,binary_H,scaled_A,scaled_B,scaled_C,scaled_D,scaled_E,scaled_F,scaled_G,scaled_H,major_A-C
22949,BiPh011,Fused006,TerTH020,CN(C)c1cccc(-c2nnc([C@@H]3CCC[C@@H]3NC(=O)c3cc...,[K+].[N-]=[N+]=NCCCOc1ccc(C(=O)[B-](F)(F)F)cc1,Cl.O=C1OC2(CCCCC2)O[C@]12ON[C@H]1CCC[C@H]12,CN(C)c1cccc(C(=S)NN)c1.Cl,[N-]=[N+]=NCCCOc1ccc(C(=O)[B-](F)(F)F)cc1.O=C1...,F[B-](F)(F)[C:2]([c:1]1[cH:13][cH:15][c:17]([O...,69958,...,1,0.449211,0.079516,0.000000,1.960700,1.283291,0.041321,0.017657,0.108452,A
22690,BiPh010,Mon097,TerTH005,COc1ccc(-c2nnc(C[C@H](NC(=O)c3cccc(OCCCCl)c3)c...,O=C(c1cccc(OCCCCl)c1)[B-](F)(F)F.[K+],O=C1OC2(CCCCC2)O[C@@]12C[C@@H](c1cnco1)NO2,COc1ccc(C(=S)NN)cn1.Cl,O=C(c1cccc(OCCCCl)c1)[B-](F)(F)F.O=C1OC2(CCCCC...,F[B-](F)(F)[C:2]([c:1]1[cH:13][cH:15][cH:17][c...,55855,...,1,0.423834,0.493011,0.000000,0.005785,0.173459,0.251768,0.136938,0.327969,B
2299,2-Pyr008,Spiro004,TerABT016,COc1cc2nc(C3(CNC(=O)c4ccc(F)cn4)CCN(C(=O)OCCc4...,O=C(c1ccc(F)cn1)[B-](F)(F)F.[K+],Cl.O=C(OCCc1ccc(F)cc1)N1CCC2(CC1)CNO[C@]21OC2(...,COc1cc(N)c(S)cc1OC,O=C(c1ccc(F)cn1)[B-](F)(F)F.O=C(OCCc1ccc(F)cc1...,F[B-](F)(F)[C:2]([c:1]1[cH:16][cH:18][c:20]([F...,45667,...,0,0.000000,0.000000,0.000000,0.000000,0.000000,0.052037,0.017801,,no_product
8671,4-Pyr002,Mon094,TerTH001,O=C(N[C@H](Cc1nnc(-c2ccccc2)s1)c1ccc2ccccc2c1)...,O=C(c1ccnc(Cl)c1)[B-](F)(F)F.[K+],O=C1OC2(CCCCC2)OC12C[C@H](c1ccc3ccccc3c1)[NH2+...,[Cl-].[NH3+]NC(=S)c1ccccc1,O=C(c1ccnc(Cl)c1)[B-](F)(F)F.O=C1OC2(CCCCC2)OC...,F[B-](F)(F)[C:2]([c:1]1[cH:13][cH:15][n:17][c:...,58689,...,1,0.040748,0.023838,0.261811,0.006519,1.432327,0.000000,0.053422,0.484402,C
29288,Ph013,Mon015,TerABT016,COc1cc2nc(C[C@@H](NC(=O)c3ccc(NC(=O)OC(C)(C)C)...,CC(C)(C)OC(=O)Nc1ccc(C(=O)[B-](F)(F)F)cc1.[K+],CC(C)(C)OC(=O)N1CCC([C@H]2C[C@]3(ON2)OC2(CCCCC...,COc1cc(N)c(S)cc1OC,CC(C)(C)OC(=O)Nc1ccc(C(=O)[B-](F)(F)F)cc1.CC(C...,F[B-](F)(F)[C:2]([c:1]1[cH:16][cH:18][c:20]([N...,45757,...,1,0.637875,0.000000,0.000000,0.001889,0.000000,0.890057,0.812710,0.307443,A
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
20186,BiPh002,Mon104,TerTH005,C#Cc1ccc(C(=O)NC[C@H](c2ccncc2)c2nnc(-c3ccc(OC...,C#Cc1ccc(C(=O)[B-](F)(F)F)cc1.[K+],O=C1OC2(CCCCC2)O[C@]12ONC[C@@H]2c1ccncc1,COc1ccc(C(=S)NN)cn1.Cl,C#Cc1ccc(C(=O)[B-](F)(F)F)cc1.O=C1OC2(CCCCC2)O...,F[B-](F)(F)[C:2]([c:1]1[cH:13][cH:15][c:17]([C...,66975,...,1,0.004200,0.000000,0.000000,0.720065,0.325878,0.007424,0.000000,0.574668,A
24942,BiPyr004,Mon007,TerTH022,C=Cc1ccnc(C(=O)N[C@@H](Cc2ccccc2)Cc2nnc(-c3c(C...,C=Cc1ccnc(C(=O)[B-](F)(F)F)c1.[K+],Cl.O=C1OC2(CCCCC2)O[C@@]12C[C@H](Cc1ccccc1)NO2,Cc1noc(C)c1C(=S)N[NH3+].[Cl-],C=Cc1ccnc(C(=O)[B-](F)(F)F)c1.O=C1OC2(CCCCC2)O...,F[B-](F)(F)[C:2]([c:1]1[n:13][cH:15][cH:17][c:...,64590,...,0,1.596361,1.429999,0.101866,0.015639,0.158343,3.693744,0.097508,0.000000,A
20060,BiPh002,Mon092,TerTH006,C#Cc1ccc(C(=O)N[C@@H](Cc2nnc(-c3ccc(N4CCOCC4)c...,C#Cc1ccc(C(=O)[B-](F)(F)F)cc1.[K+],O=C1OC2(CCCCC2)OC12C[C@@H](c1ccc3ccccc3c1)[NH2...,[Cl-].[NH3+]NC(=S)c1ccc(N2CCOCC2)cc1,C#Cc1ccc(C(=O)[B-](F)(F)F)cc1.O=C1OC2(CCCCC2)O...,F[B-](F)(F)[C:2]([c:1]1[cH:13][cH:15][c:17]([C...,37861,...,1,1.579454,0.000000,0.000000,0.002721,0.587177,0.003223,0.248133,1.522079,A
29858,Ph013,Mon104,TerABT005,Cc1ccc2nc([C@H](CNC(=O)c3ccc(NC(=O)OC(C)(C)C)c...,CC(C)(C)OC(=O)Nc1ccc(C(=O)[B-](F)(F)F)cc1.[K+],O=C1OC2(CCCCC2)O[C@]12ONC[C@@H]2c1ccncc1,Cc1ccc(N)c(S)c1,CC(C)(C)OC(=O)Nc1ccc(C(=O)[B-](F)(F)F)cc1.O=C1...,F[B-](F)(F)[C:2]([c:1]1[cH:16][cH:18][c:20]([N...,82241,...,0,0.000000,0.000000,0.000000,4.591078,0.000000,0.000000,0.000000,0.000000,no_product


In [30]:
# Persist the selected rows (including their index) as the input file for the dataset
inference_csv_path = DATA_ROOT / "inference_test.csv"
df.to_csv(inference_csv_path, index=True)

In [66]:
# For the moment, everything is treated as 0D data.
# The dataset reads the CSV file written to DATA_ROOT in the previous step.
inference_csv_name = "inference_test.csv"
data = SynFermDataset(
    name=inference_csv_name,
    raw_dir=DATA_ROOT,
    reaction=True,
    smiles_columns=["reaction_smiles_atom_mapped"],
    label_columns=None,  # no labels are passed for inference
    task="multilabel",
    graph_type="bond_edges",
    featurizers="custom",
    global_features=["OHE"],
    global_featurizer_state_dict_path=ohe_state_dict,
)


Done saving data into cached files.


In [67]:
# run the predictions

# load the trained model
# `model_0D` holds the checkpoint path. The loaded network gets its own name so that
# re-running this cell does not hand an already-loaded model to load_trained_model.
classifier_0D = load_trained_model("FFN", model_0D)
classifier_0D.eval()
# "auto" selects the GPU when one is available and falls back to the CPU otherwise
trainer = pl.Trainer(accelerator="auto", logger=False, max_epochs=-1)
# prepare data (DataLoader defaults: no shuffling, so outputs follow the dataset order)
dl = DataLoader(data, collate_fn=collate_fn, num_workers=0)
# predict: concatenate the per-batch model outputs and squash them with a sigmoid
probabilities_0D = torch.sigmoid(torch.concat(trainer.predict(classifier_0D, dl)))
    

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Predicting: 0it [00:00, ?it/s]

In [68]:
# inspect the sigmoid outputs computed in the prediction cell
probabilities_0D

tensor([[0.9989, 0.7657, 0.0236],
        [1.0000, 1.0000, 0.4029],
        [0.7920, 0.0038, 0.0105],
        ...,
        [1.0000, 0.0069, 0.0772],
        [0.1281, 0.0474, 0.0163],
        [1.0000, 1.0000, 0.2797]])

In [69]:
# load decision thresholds
# The file is read as one threshold per line. Blank lines (e.g. an empty line at the end
# of the file) are skipped instead of raising a ValueError in float().
with open(LOG_DIR / "thresholds" / f"{model_0D_name}.txt", "r") as f:
    thresholds = [float(line) for line in f if line.strip()]
print(thresholds)

[0.56, 0.54, 0.49]


In [70]:
# Binarize each of the three probability columns with its own threshold:
# 1 where the probability exceeds the threshold, 0 otherwise.
binarized_columns = []
for column_idx in range(3):
    above_threshold = probabilities_0D[:, column_idx] > thresholds[column_idx]
    binarized_columns.append(torch.where(above_threshold, 1, 0))
preds = torch.stack(binarized_columns, dim=1)

In [71]:
# combine with data
# The columns are assigned positionally, so make sure there is exactly one prediction row
# per data frame row and one column per target before writing into df.
assert tuple(probabilities_0D.shape) == (len(df), 3), (
    f"expected predictions of shape {(len(df), 3)}, got {tuple(probabilities_0D.shape)}"
)
assert preds.shape == probabilities_0D.shape
# convert explicitly to numpy instead of relying on pandas to coerce torch tensors
df[["prob_A", "prob_B", "prob_C"]] = probabilities_0D.detach().cpu().numpy()
df[["pred_A", "pred_B", "pred_C"]] = preds.detach().cpu().numpy()

In [72]:
# check accuracy
df

Unnamed: 0,I_long,M_long,T_long,product_A_smiles,I_smiles,M_smiles,T_smiles,reaction_smiles,reaction_smiles_atom_mapped,experiment_id,...,scaled_F,scaled_G,scaled_H,major_A-C,pred_A,pred_B,pred_C,prob_A,prob_B,prob_C
22949,BiPh011,Fused006,TerTH020,CN(C)c1cccc(-c2nnc([C@@H]3CCC[C@@H]3NC(=O)c3cc...,[K+].[N-]=[N+]=NCCCOc1ccc(C(=O)[B-](F)(F)F)cc1,Cl.O=C1OC2(CCCCC2)O[C@]12ON[C@H]1CCC[C@H]12,CN(C)c1cccc(C(=S)NN)c1.Cl,[N-]=[N+]=NCCCOc1ccc(C(=O)[B-](F)(F)F)cc1.O=C1...,F[B-](F)(F)[C:2]([c:1]1[cH:13][cH:15][c:17]([O...,69958,...,0.041321,0.017657,0.108452,A,1,1,0,0.998934,0.765715,0.023582
22690,BiPh010,Mon097,TerTH005,COc1ccc(-c2nnc(C[C@H](NC(=O)c3cccc(OCCCCl)c3)c...,O=C(c1cccc(OCCCCl)c1)[B-](F)(F)F.[K+],O=C1OC2(CCCCC2)O[C@@]12C[C@@H](c1cnco1)NO2,COc1ccc(C(=S)NN)cn1.Cl,O=C(c1cccc(OCCCCl)c1)[B-](F)(F)F.O=C1OC2(CCCCC...,F[B-](F)(F)[C:2]([c:1]1[cH:13][cH:15][cH:17][c...,55855,...,0.251768,0.136938,0.327969,B,1,1,0,1.000000,0.999999,0.402944
2299,2-Pyr008,Spiro004,TerABT016,COc1cc2nc(C3(CNC(=O)c4ccc(F)cn4)CCN(C(=O)OCCc4...,O=C(c1ccc(F)cn1)[B-](F)(F)F.[K+],Cl.O=C(OCCc1ccc(F)cc1)N1CCC2(CC1)CNO[C@]21OC2(...,COc1cc(N)c(S)cc1OC,O=C(c1ccc(F)cn1)[B-](F)(F)F.O=C(OCCc1ccc(F)cc1...,F[B-](F)(F)[C:2]([c:1]1[cH:16][cH:18][c:20]([F...,45667,...,0.052037,0.017801,,no_product,1,0,0,0.791972,0.003836,0.010543
8671,4-Pyr002,Mon094,TerTH001,O=C(N[C@H](Cc1nnc(-c2ccccc2)s1)c1ccc2ccccc2c1)...,O=C(c1ccnc(Cl)c1)[B-](F)(F)F.[K+],O=C1OC2(CCCCC2)OC12C[C@H](c1ccc3ccccc3c1)[NH2+...,[Cl-].[NH3+]NC(=S)c1ccccc1,O=C(c1ccnc(Cl)c1)[B-](F)(F)F.O=C1OC2(CCCCC2)OC...,F[B-](F)(F)[C:2]([c:1]1[cH:13][cH:15][n:17][c:...,58689,...,0.000000,0.053422,0.484402,C,1,1,0,0.999621,0.999215,0.012042
29288,Ph013,Mon015,TerABT016,COc1cc2nc(C[C@@H](NC(=O)c3ccc(NC(=O)OC(C)(C)C)...,CC(C)(C)OC(=O)Nc1ccc(C(=O)[B-](F)(F)F)cc1.[K+],CC(C)(C)OC(=O)N1CCC([C@H]2C[C@]3(ON2)OC2(CCCCC...,COc1cc(N)c(S)cc1OC,CC(C)(C)OC(=O)Nc1ccc(C(=O)[B-](F)(F)F)cc1.CC(C...,F[B-](F)(F)[C:2]([c:1]1[cH:16][cH:18][c:20]([N...,45757,...,0.890057,0.812710,0.307443,A,1,0,0,0.999859,0.009234,0.032983
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
20186,BiPh002,Mon104,TerTH005,C#Cc1ccc(C(=O)NC[C@H](c2ccncc2)c2nnc(-c3ccc(OC...,C#Cc1ccc(C(=O)[B-](F)(F)F)cc1.[K+],O=C1OC2(CCCCC2)O[C@]12ONC[C@@H]2c1ccncc1,COc1ccc(C(=S)NN)cn1.Cl,C#Cc1ccc(C(=O)[B-](F)(F)F)cc1.O=C1OC2(CCCCC2)O...,F[B-](F)(F)[C:2]([c:1]1[cH:13][cH:15][c:17]([C...,66975,...,0.007424,0.000000,0.574668,A,0,0,0,0.108916,0.013784,0.028272
24942,BiPyr004,Mon007,TerTH022,C=Cc1ccnc(C(=O)N[C@@H](Cc2ccccc2)Cc2nnc(-c3c(C...,C=Cc1ccnc(C(=O)[B-](F)(F)F)c1.[K+],Cl.O=C1OC2(CCCCC2)O[C@@]12C[C@H](Cc1ccccc1)NO2,Cc1noc(C)c1C(=S)N[NH3+].[Cl-],C=Cc1ccnc(C(=O)[B-](F)(F)F)c1.O=C1OC2(CCCCC2)O...,F[B-](F)(F)[C:2]([c:1]1[n:13][cH:15][cH:17][c:...,64590,...,3.693744,0.097508,0.000000,A,1,1,1,0.999998,0.999993,0.903402
20060,BiPh002,Mon092,TerTH006,C#Cc1ccc(C(=O)N[C@@H](Cc2nnc(-c3ccc(N4CCOCC4)c...,C#Cc1ccc(C(=O)[B-](F)(F)F)cc1.[K+],O=C1OC2(CCCCC2)OC12C[C@@H](c1ccc3ccccc3c1)[NH2...,[Cl-].[NH3+]NC(=S)c1ccc(N2CCOCC2)cc1,C#Cc1ccc(C(=O)[B-](F)(F)F)cc1.O=C1OC2(CCCCC2)O...,F[B-](F)(F)[C:2]([c:1]1[cH:13][cH:15][c:17]([C...,37861,...,0.003223,0.248133,1.522079,A,1,0,0,0.999996,0.006941,0.077242
29858,Ph013,Mon104,TerABT005,Cc1ccc2nc([C@H](CNC(=O)c3ccc(NC(=O)OC(C)(C)C)c...,CC(C)(C)OC(=O)Nc1ccc(C(=O)[B-](F)(F)F)cc1.[K+],O=C1OC2(CCCCC2)O[C@]12ONC[C@@H]2c1ccncc1,Cc1ccc(N)c(S)c1,CC(C)(C)OC(=O)Nc1ccc(C(=O)[B-](F)(F)F)cc1.O=C1...,F[B-](F)(F)[C:2]([c:1]1[cH:16][cH:18][c:20]([N...,82241,...,0.000000,0.000000,0.000000,no_product,0,0,0,0.128062,0.047368,0.016284


In [73]:
from sklearn.metrics import accuracy_score, balanced_accuracy_score, recall_score, precision_score, fbeta_score

In [74]:
# CONTROL: recompute the classification metrics for target A from the stored
# ground truth (binary_A) and the thresholded predictions (pred_A)
y_true, y_pred = df["binary_A"], df["pred_A"]

acc = accuracy_score(y_true, y_pred)
bal_acc = balanced_accuracy_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
precision = precision_score(y_true, y_pred)
fbeta = fbeta_score(y_true, y_pred, beta=0.5)

# report every metric with two decimals
metric_report = (
    ("accuracy", acc),
    ("balanced accuracy", bal_acc),
    ("recall", recall),
    ("precision", precision),
    ("f_0.5 score", fbeta),
)
for metric_label, metric_value in metric_report:
    print(f"{metric_label}:", f"{metric_value:.2f}")

accuracy: 0.94
balanced accuracy: 0.91
recall: 0.96
precision: 0.97
f_0.5 score: 0.97


In [17]:
# assemble outputs
# NOTE(review): this cell reads `data.idx_known`, `data.known_outcomes`, `data.idx_0D`,
# `data.idx_1D_slap`, `data.idx_1D_aldehyde`, `data.idx_2D` and the `predictions_*` arrays.
# None of the `predictions_*` names is assigned in the cells above (the 0D cell produces
# `probabilities_0D` / `preds`) — presumably left over from an earlier workflow; confirm
# before relying on this cell.
# Start with NaN everywhere; entries that are never filled stay NaN and are checked later.
predictions = np.full(len(data.reactions), np.nan, dtype=float)

predictions[data.idx_known] = [statistics.mean(data.known_outcomes[i]) for i in data.idx_known]  # for known reaction we add the average reaction outcome
# Each block below fills one subset. If the corresponding `predictions_*` variable does not
# exist, the NameError is swallowed and that subset stays NaN.
try:
    predictions[data.idx_0D] = predictions_0D
except NameError:
    pass
try:
    predictions[data.idx_1D_slap] = predictions_1D_slap
except NameError:
    pass
try:
    predictions[data.idx_1D_aldehyde] = predictions_1D_aldehyde
except NameError:
    pass
try:
    predictions[data.idx_2D] = predictions_2D
except NameError:
    pass


In [18]:
# Sanity check: the only reactions without a prediction (still NaN) must be
# the ones that belong to data.invalid_idxs
missing_mask = np.isnan(predictions)
rxn_idxs_no_pred = np.flatnonzero(missing_mask)

# translate every invalid product index into its position within data.product_idxs
rxn_idxs_invalid = []
for invalid_idx in data.invalid_idxs:
    rxn_idxs_invalid.append(data.product_idxs.index(invalid_idx))

assert set(rxn_idxs_no_pred) == set(rxn_idxs_invalid)

In [23]:
# obtain individual new columns for output df
# NOTE(review): `arr`, `rxn_problem_types`, `reactions_augmented`, `predictions_augmented`
# and `rxn_problem_types_augmented` are not assigned in the cells above — presumably left
# over from an earlier workflow; confirm they exist before running this cell.
# Column 0 of `arr` supplies the lookup indices for the rxn1_* columns,
# column 1 supplies the lookup indices for the rxn2_* columns.
df["rxn1_smiles"] = [data.reactions[i] for i in arr[:,0]]

df["rxn1_predictions"] = [predictions[i] for i in arr[:,0]]

df["rxn1_confidence"] = [rxn_problem_types[i] for i in arr[:,0]]

df["rxn2_smiles"] = [reactions_augmented[i] for i in arr[:,1]]

df["rxn2_predictions"] = [predictions_augmented[i] for i in arr[:,1]]

df["rxn2_confidence"] = [rxn_problem_types_augmented[i] for i in arr[:,1]]

In [24]:
# Summarize the dataset statistics, write them to a log file next to the input data
# and, if verbose is set, echo them in the notebook
verbose = True
log_output = f"""\
{len(data.reactions)} reactions generated from {len(data.smiles)} input SMILES
Known reactions: {(sum(x is not None for x in data.known_outcomes))}
0D reactions: 0, thereof 0 predicted positive
1D reactions with unknown aldehyde: {len(data.dataset_1D_aldehyde)}, thereof {np.count_nonzero(predictions_1D_aldehyde)} predicted positive
1D reactions with unknown SLAP reagent: {len(data.dataset_1D_slap)}, thereof {np.count_nonzero(predictions_1D_slap)} predicted positive
2D reactions: {len(data.dataset_2D)}, thereof {np.count_nonzero(predictions_2D)} predicted positive
"""

log_path = raw_dir / f"{filename_base}_reaction_prediction.log"
log_path.write_text(log_output)
if verbose:
    print(log_output)

15621 reactions generated from 10000 input SMILES
Known reactions: 0
0D reactions: 0, thereof 0 predicted positive
1D reactions with unknown aldehyde: 55, thereof 31 predicted positive
1D reactions with unknown SLAP reagent: 219, thereof 62 predicted positive
2D reactions: 15341, thereof 1838 predicted positive



In [14]:
# Store the data frame (without its index) next to the input file
prediction_csv_path = raw_dir / f"{filename_base}_reaction_prediction.csv"
df.to_csv(prediction_csv_path, index=False)