In [1]:
cfg = dict(
    seq_length  = 160,
    # ---------------- model ----------------
    d_model     = 256,
    latent_dim  = 64,   # latent dimension
    enc_layers  = 3,
    dec_layers  = 7,
    dropout     = 0.05,
    emb_dropout = 0.05,

    # special token indices (match your vocabulary)
    pad_idx     = 0,
    sos_idx     = 2,
    eos_idx     = 3,

    # -------- validation / decoding --------
    metrics_every  = 5,   # run beam metrics every N epochs
    beam_size   = 5
)

In [2]:
import torch, torch.nn as nn
import model_bs as mdl
import data_utils as du

# --- paths/config you already have ---
vocab_path   = "/home/md_halim_mondol/LSTM_VAE_Paper/vocab.json"
ckpt_path    = "/home/md_halim_mondol/LSTM_VAE_Paper/checkpoints/model_epoch_50.pth" # best_model.pth
test_csv     = "/home/md_halim_mondol/Data/Test.csv"


# --- load vocab ---
token_to_idx, idx_to_token = du.load_or_create_vocabulary(csv_paths=[], cache_path=vocab_path, test_smiles=None)
assert token_to_idx["<PAD>"] == cfg["pad_idx"]
assert token_to_idx["<SOS>"] == cfg["sos_idx"]
assert token_to_idx["<EOS>"] == cfg["eos_idx"]

# --- build the same architecture you trained ---
model = mdl.CNNCharVAE(
    vocab_size=len(token_to_idx),
    d_model=cfg["d_model"],
    latent_dim=cfg["latent_dim"],
    pad_idx=cfg["pad_idx"],
    sos_idx=cfg["sos_idx"],
    eos_idx=cfg["eos_idx"],
    enc_layers=cfg.get("enc_layers"),
    dec_layers=cfg.get("dec_layers"),
    dropout=cfg.get("dropout"),
    emb_dropout=cfg.get("emb_dropout"),
    max_len=cfg["seq_length"])

# --- load weights robustly (handles 'module.' prefixes if any) ---
state = torch.load(ckpt_path, map_location="cpu")
try:
    model.load_state_dict(state, strict=True)
except RuntimeError:
    # remove a leading 'module.' if the checkpoint came from DataParallel
    from collections import OrderedDict
    new_state = OrderedDict()
    for k, v in state.items():
        new_state[k.replace("module.", "", 1)] = v
    model.load_state_dict(new_state, strict=True)

# --- device & optional DataParallel for speed (not required) ---
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
print("Trainable params:", du.count_parameters(model))
print(f"Encoder parameters: {du.count_parameters(model.encoder)}")
model.eval()

[vocab] loaded cached vocabulary from /home/md_halim_mondol/LSTM_VAE_Paper/vocab.json (49 tokens)


  state = torch.load(ckpt_path, map_location="cpu")


Trainable params: 4729973
Encoder parameters: 1924420


CNNCharVAE(
  (encoder): EncoderCNN(
    (emb): Embedding(49, 256, padding_idx=0)
    (emb_ln): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (emb_do): Dropout(p=0.05, inplace=False)
    (conv): Sequential(
      (0): ConstantPad1d(padding=(4, 4), value=0.0)
      (1): Conv1d(256, 256, kernel_size=(9,), stride=(1,))
      (2): GELU(approximate='none')
      (3): Dropout(p=0.05, inplace=False)
      (4): ConstantPad1d(padding=(4, 4), value=0.0)
      (5): Conv1d(256, 256, kernel_size=(9,), stride=(1,))
      (6): GELU(approximate='none')
      (7): Dropout(p=0.05, inplace=False)
      (8): ConstantPad1d(padding=(4, 5), value=0.0)
      (9): Conv1d(256, 256, kernel_size=(10,), stride=(1,))
      (10): GELU(approximate='none')
      (11): Dropout(p=0.05, inplace=False)
    )
    (proj): Sequential(
      (0): Linear(in_features=256, out_features=196, bias=True)
      (1): ReLU()
    )
    (to_mu): Linear(in_features=196, out_features=64, bias=True)
    (to_logvar): Linear(in_f

In [3]:
Dye_smiles = [
'CC1=CC(=O)c2c(Br)cc(Br)c(S(=O)(=O)O)c2C1=O',
'Cc1c(Br)cc(Br)c(C(C)c2ccc(C(C)C)c(Br)c2O)c1S(=O)(=O)O',
'Cc1ccccc1N=Nc1ccc(C(N)=O)cc1',
'Cc1ccc(-c2cccc(O)c2C(C)c2ccc(S(=O)(=O)O)cc2)cc1',
'CN(C)c1ccc2c(c1)CN=C2c1ccccc1',
'O=C1c2ccccc2C(=O)c2c(O)cccc21',
'O=C1Nc2ccccc2S(=O)(=O)[N-]c2c1cccc2S(=O)(=O)[O-]',
'CN(C)C1=CS(=O)(=O)c2ccc(N(C)C)cc21',
'CN(C)c1ccc(C(C)(c2ccc(N(C)C)cc2)c2ccc(N(C)C)cc2)cc1',
'Cc1ccc(C(c2ccc(O)cc2)(c2ccc(O)cc2)c2ccc(C(C)C)cc2)cc1',
'CN(C)c1ccc2c(c1)CC1=CC(=[N+](C)C)C(=N2)C=C1',
'O=C1c2cc(=O)cccc2C(=O)c2c(Br)cc(Br)cc21',
'O=c1cc(O)ccc(-c2ccc3c(oc(=O)c4ccccc43)c2O)c1']

In [4]:
import pandas as pd
from inference import reconstruct_smiles_table, tensor_to_smiles
import metrics as met

# Use the *unwrapped* model object for beam_search
m = model  # (if you ever wrap with DataParallel, use: model.module)

df_rec = reconstruct_smiles_table( smiles_list=None, test_csv=test_csv, model=m, token_to_idx=token_to_idx,
                                  idx_to_token=idx_to_token, seq_length=cfg["seq_length"], pad_idx=cfg["pad_idx"],
                                  sos_idx=cfg["sos_idx"], eos_idx=cfg["eos_idx"], device=device, mode="beam",
                                  beam_size=cfg["beam_size"])

# show a preview
display(df_rec.head(10))


# ------------------------------------------------------------------
# 1.  Token-level accuracy (micro-average over SMILES tokens)
# ------------------------------------------------------------------
def token_accuracy_row(gold, pred):
    g = du.tokenize_smiles(gold)
    p = du.tokenize_smiles(pred)
    L = min(len(g), len(p))
    if L == 0:                      # degenerate empty case
        return 0, 0
    correct = sum(gi == pi for gi, pi in zip(g[:L], p[:L]))
    total   = L
    return correct, total

tot_corr = tot_tok = 0
for g, p in zip(df_rec["input"], df_rec["reconstructed"]):
    c, t = token_accuracy_row(g, p)
    tot_corr += c
    tot_tok  += t

beam_token_acc = tot_corr / tot_tok if tot_tok else 0.0
print(f"Token level test accuracy (beam): {beam_token_acc:.4f}")

# ------------------------------------------------------------------
# 2.  Sequence-level (exact-match) accuracy
# ------------------------------------------------------------------
exact_match_acc = (df_rec["input"] == df_rec["reconstructed"]).mean()
print(f"Exact SMILES match accuracy (beam): {exact_match_acc:.4f}")


# ---- summary metrics (no retraining) ----
valid_ratio = (df_rec["valid"] == "yes").mean() if len(df_rec) else float("nan")
avg_lev     = df_rec["lev"].mean() if len(df_rec) else float("nan")

print(f"[beam] validity ratio: {valid_ratio:.3f}")
print(f"[beam] average Levenshtein: {avg_lev:.3f}")

Unnamed: 0,input,reconstructed,valid,lev
0,CCOC(=O)C1(CC2CC2)CC[NH+]([C@@H](C)CCC2=c3cccc...,CC1=CCN(CC2CC2)C[C@@]2(CC[NH+](Cc3ccccc3)C2=O)CC1,yes,29
1,O=C(N[C@@H]1CCN(CC(F)(F)F)C1=O)c1c[nH]c2cccc(F...,O=C(N[C@@H]1CCN(CC(F)(F)F)C1=O)c1c[nH]c2cccc(F...,yes,0
2,Cc1ccc(C[NH+]2CCC[C@@H]2c2ccc(C(=O)Nc3nc(C)n(C...,Cc1ccc(C[NH+]2CCC[C@@H]2c2ccc(C(=O)Nc3nc(C)n(C...,yes,0
3,O=C(C[C@@H]1C(=O)N=C2[N-]c3ccccc3N21)Nc1cccc(O...,O=C(C[C@@H]1C(=O)N(C)c2ccccc2N1)N1Cc2cccc(OC(F...,yes,13
4,CCc1ccc(-c2nc(CSc3nnnn3C3CC3)cs2)cc1,CCc1ccc(-c2nc(CSc3nnnn3C3CC3)sc2)cc1,yes,2
5,C[C@H]1CN(C(=O)N[C@@H](C)c2cccc(Br)c2)CCO1,C[C@H]1CN(C(=O)N[C@@H](C)c2cccc(Br)c2)CCO1,yes,0
6,Cc1ccn(C)c1[C@](O)(C(=O)[O-])C(F)(F)F,Cc1ccn(C)c1[C@](O)(C(=O)[O-])C(F)(F)F,yes,0
7,CC(C)N(C[C@@H]1CCCCO1)C(=O)[C@H]1CCc2n[nH]c(C(...,CC(C)N(C[C@@H]1CCCCO1)C(=O)[C@H]1CCc2n[nH]c(C(...,yes,0
8,N[C@H]1C=C[C@H](c2nnc3c(Cl)cc(C(F)(F)F)cn23)C1,N1[C@H](C)Cc2n(nc3c(c2)nc(C(F)(F)F)cn3)C1,no,14
9,CC[C@@H](C)N(Cc1c(-c2ccccc2F)noc1N1CCOCC1)C(=O...,CC[C@@H](C)N(Cc1c(-c2ccccc2)n2c(c1)OCCC2)C(=O)...,yes,9


Token level test accuracy (beam): 0.8633
Exact SMILES match accuracy (beam): 0.5911
[beam] validity ratio: 0.848
[beam] average Levenshtein: 4.331
