In [1]:
from Models.TensorGen import BeamSearchOptimized as BeamSearch
from Models.Configure import VanillaTransformerConfig, prepare_model
from Data.Dataset import StepSM_Dataset_v2
from Utils.PostProcess import (
    BeamResultType, find_valid_paths, process_paths,
    find_matching_paths, find_top_n_accuracy, load_pharma_compounds
)
import torch
import yaml
from pathlib import Path
from rdkit import RDLogger, Chem
from tqdm import tqdm

# Mute every RDKit log channel ("rdApp.*") so RDKit warnings/errors do not
# clutter the notebook output.
RDLogger.DisableLog("rdApp.*")

# Directory layout, resolved relative to the notebook's working directory.
data_path = Path.cwd().joinpath("Data")
processed_path, ckpt_path, fig_path = (
    data_path / subdir for subdir in ("Processed", "Checkpoints", "Figures")
)

# Load Token Dictionary and Model

In [2]:
# Inverse vocabulary ("invdict") written during preprocessing; it is handed to
# the beam-search decoder below as idx_to_token.
with (processed_path / "character_dictionary.yaml").open("rb") as dict_file:
    data = yaml.safe_load(dict_file)
idx_to_token = data["invdict"]

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt_name = "van_6x3_6x3_final.ckpt"

# Encoder and decoder use identical hyper-parameters. Define them once instead
# of keeping two copy-pasted argument lists that could silently drift apart.
# NOTE(review): 53 / 52 presumably are the vocabulary size and the pad-token
# index from character_dictionary.yaml; confirm against that file.
van_shared_kwargs = dict(
    input_dim=53,
    output_dim=53,
    input_max_length=145 + 135,
    output_max_length=1074 + 1,  # 1074 is max length
    pad_index=52,
    n_layers=6,
    ff_mult=3,
    attn_bias=False,
    ff_activation="gelu",
    hid_dim=256,
)
# Two separate config objects (as before), so neither aliases the other.
van_enc_conf = VanillaTransformerConfig(**van_shared_kwargs)
van_dec_conf = VanillaTransformerConfig(**van_shared_kwargs)
model = prepare_model(enc_config=van_enc_conf, dec_config=van_dec_conf)

# torch.load unpickles the file: only load checkpoints from trusted sources.
ckpt_torch = torch.load(ckpt_path / ckpt_name, map_location=device)
model.load_state_dict(ckpt_torch)
model.to(device)
model.eval();

The model has 9,857,333 trainable parameters


# Set Up Beam Search and Load Pharma Compounds

In [4]:
# Beam-search decoder wrapping the loaded model.
# NOTE(review): start/pad/end indices (0 / 52 / 22) and max_length=1074 are
# assumed to agree with character_dictionary.yaml and the model config; confirm.
BSObject = BeamSearch(
    model=model,
    beam_size=50,
    start_idx=0,
    pad_idx=52,
    end_idx=22,
    max_length=1074,
    idx_to_token=idx_to_token,
    device=device,
)

# Target compounds with their starting materials, reference route strings and
# step counts; nameToIdx maps each compound name to its row indices.
(
    _products,
    _sms,
    _path_strings,
    _steps_list,
    nameToIdx,
) = load_pharma_compounds("pharma_compounds.json")

pharma_ds = StepSM_Dataset_v2(
    products=_products,
    starting_materials=_sms,
    path_strings=_path_strings,
    n_steps_list=_steps_list,
    metadata_path=processed_path / "character_dictionary.yaml",
)

# shuffle=False keeps batches in the same order as _products / _path_strings,
# which the evaluation step compares the beam results against.
pharma_dl = torch.utils.data.DataLoader(
    dataset=pharma_ds,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

# Generate Routes

In [5]:
# Decode every target with beam search, collecting one result per sample in
# dataloader order. model.eval() alone does not disable autograd, so run under
# torch.no_grad() to avoid building graphs during inference (a no-op if
# BSObject.decode already disables gradients internally).
all_beam_results_NS2: BeamResultType = []
with torch.no_grad():
    for prod_sm, _, steps in pharma_dl:
        beam_result_BS2 = BSObject.decode(
            src_BC=prod_sm.to(device), steps_B1=steps.to(device)
        )
        # decode returns one entry per batch element; keep them flat.
        all_beam_results_NS2.extend(beam_result_BS2)

 20%|█▉        | 211/1072 [00:02<00:09, 87.46it/s] 
 19%|█▉        | 209/1072 [00:02<00:09, 95.08it/s] 
 19%|█▉        | 201/1072 [00:02<00:08, 97.45it/s] 
 31%|███       | 332/1072 [00:04<00:10, 68.92it/s] 
 29%|██▉       | 310/1072 [00:04<00:10, 72.72it/s] 
 29%|██▉       | 311/1072 [00:04<00:10, 72.25it/s] 
 24%|██▍       | 256/1072 [00:03<00:09, 83.04it/s] 
 24%|██▎       | 254/1072 [00:03<00:09, 83.61it/s] 
 23%|██▎       | 243/1072 [00:02<00:09, 85.75it/s] 
 41%|████▏     | 443/1072 [00:08<00:11, 54.51it/s] 
 42%|████▏     | 453/1072 [00:08<00:11, 53.27it/s] 
 41%|████      | 440/1072 [00:08<00:11, 54.46it/s] 
 41%|████▏     | 443/1072 [00:08<00:11, 54.12it/s] 
 30%|███       | 322/1072 [00:04<00:10, 69.40it/s] 
 34%|███▍      | 368/1072 [00:05<00:11, 62.65it/s] 
 31%|███       | 332/1072 [00:04<00:10, 67.84it/s] 
 39%|███▊      | 414/1072 [00:07<00:11, 57.40it/s] 
 38%|███▊      | 410/1072 [00:07<00:11, 57.69it/s] 
 37%|███▋      | 396/1072 [00:06<00:11, 59.36it/s] 
 42%|████▏  

In [6]:
top_n_vals = [1, 3, 5, 10, 20, 50]

# Filter the raw beam output, keep routes whose product/starting materials
# match the targets, then rank them against the reference route strings, both
# strictly and with permutations allowed.
valid_paths_NS2n = find_valid_paths(all_beam_results_NS2)
correct_paths_NS2n = process_paths(
    paths_NS2n=valid_paths_NS2n,
    true_products=_products,
    true_reacs=_sms,
    commercial_stock=None,
    verbose=True,
)
match_accuracy_N, perm_match_accuracy_N = find_matching_paths(correct_paths_NS2n, _path_strings)

# Top-N summaries for the strict and the permutation-tolerant rankings.
freqs_noperm, freqs_wperm = (
    find_top_n_accuracy(ranks, top_n_vals)
    for ranks in (match_accuracy_N, perm_match_accuracy_N)
)
print("---- Top N Accuracy ---")
print("based on raw output:")
for label, freqs in (("W/o perms", freqs_noperm), ("W/  perms", freqs_wperm)):
    print(f"{label}: {freqs}")

Starting to canonicalize paths:


100%|██████████| 21/21 [00:01<00:00, 19.08it/s]


Failed to canonicalize counter=0 path strings
Starting to remove repetitions within beam results:


100%|██████████| 21/21 [00:00<00:00, 43.15it/s]


Starting to find paths with correct product and reactants:


21it [00:00, 63.84it/s]

---- Top N Accuracy ---
based on raw output:
W/o perms: {'Top 1': '0.0', 'Top 3': '4.8', 'Top 5': '9.5', 'Top 10': '14.3', 'Top 20': '14.3', 'Top 50': '14.3'}
W/  perms: {'Top 1': '19.0', 'Top 3': '38.1', 'Top 5': '42.9', 'Top 10': '47.6', 'Top 20': '47.6', 'Top 50': '47.6'}





In [7]:
# Per-target rank of the first matching route (permutations allowed); the
# printed values are consistent with 1-based ranks, with None = no match in the
# beam, since they reproduce the "W/ perms" Top-N figures above.
print(f"Ranks: {perm_match_accuracy_N}")
print("Compound names and their corresponding indices (depends on # of SMs):")
# NOTE(review): keys come from pharma_compounds.json; 'Mitavivat-2' and
# 'Daridoxerant' look like misspellings of 'Mitapivat-2' and 'Daridorexant'.
nameToIdx

Ranks: [1, 3, 6, None, None, None, 1, None, None, 2, 3, 3, 1, 4, None, 1, None, None, None, None, None]
Compound names and their corresponding indices (depends on # of SMs):


{'Vonoprazan-1': [0, 1, 2],
 'Vonoprazan-2': [3, 4, 5],
 'Vonoprazan-2 partial': [6, 7, 8],
 'Mitapivat-1': [9, 10, 11, 12],
 'Mitavivat-2': [13, 14, 15],
 'Daridoxerant': [16, 17, 18, 19, 20]}