In [None]:
import torch
import pandas as pd
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gbyuvd/bionat-selfies-gen-tokenizer-wordlevel")
    
# Ensure [MASK] token exists
if not hasattr(tokenizer, 'mask_token_id') or tokenizer.mask_token_id is None:
    tokenizer.add_special_tokens({'mask_token': '[MASK]'})

print(f"Vocab size: {tokenizer.vocab_size}")
print(f"[MASK] token ID: {tokenizer.mask_token_id}")

In [2]:
from model import ImplicitRefinementModel, ImplicitRefinementConfig

seq_len = 90  # must match training

config = ImplicitRefinementConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=320,
    num_layers=6,
    num_heads=4,
    max_seq_len=seq_len,
    max_refinement_steps=10,
    dropout=0.1,
    use_self_cond=True,
    stop_threshold=0.02,
    min_refine_uncertainty=0.1,
    ema_decay=0.995,
    diversity_weight=0.05,
    sampling_temperature=1.0
)

model = ImplicitRefinementModel(config, tokenizer=tokenizer).to(device)
print(f"Model loaded with {sum(p.numel() for p in model.parameters()):,} parameters")

Model loaded with 7,794,742 parameters


In [3]:
checkpoint_path = "best_chemistrySELFIESmodel.pth"  # or "refinement_model_z_final.pth"

torch.serialization.add_safe_globals([ImplicitRefinementConfig])
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # set to evaluation mode
print("✅ Model weights loaded successfully!")

✅ Model weights loaded successfully!


In [4]:
def generate_and_decode(model, tokenizer, num_samples=5, max_len=90, temperature=1.0):
    """Generate chemistry molecules using adaptive refinement."""
    model.eval()
    print(f"\n🧪 Generating {num_samples} SELFIES molecules...")
    print(f"   Temperature: {temperature}")
    print("="*70)
    
    with torch.no_grad():
        samples = model.sample(batch_size=2, max_len=90, device='cuda')
    
    # Compute actual lengths (stop at first pad or eos)
    actual_lengths = []
    for s in samples:
        # Find first pad or eos
        s_list = s if isinstance(s, list) else s.tolist()
        length = len(s_list)
        for i, tok in enumerate(s_list):
            if tok == tokenizer.pad_token_id or tok == tokenizer.eos_token_id:
                length = i
                break
        actual_lengths.append(torch.tensor(length))
    
    for i, (sample, length) in enumerate(zip(samples, actual_lengths)):
        decoded = tokenizer.decode(sample[:length.item()], skip_special_tokens=True)
        print(f"{i+1}. (len={length.item()}) {decoded}")
    
    print("="*70)


In [5]:
# Try different temperatures
for temp in [0.8, 1.0, 1.2]:
    print(f"\n--- Temperature: {temp} ---")
    generate_and_decode(
        model=model,
        tokenizer=tokenizer,
        num_samples=5,
        temperature=temp
    )


--- Temperature: 0.8 ---

🧪 Generating 5 SELFIES molecules...
   Temperature: 0.8
✅ Stopped at step 4 (change: 0.00%)
1. (len=60) [=C] [=C] [Branch2] [C] [C] [C] [=O] [=C] [C] [=C] [C] [Ring1] [C] [C] [#Branch1] [C] [=O] [O] [C] [=N] [C] [=C] [S] [Branch1] [Branch1] [=C] [Branch1] [Branch1] [S] [Branch1] [Branch1] [=Branch2] [Ring1] [Branch1] [S] [Branch1] [Ring1] [NH1] [N] [S] [Branch2] [Branch1] [S] [S] [C] [=O] [S] [=Branch1] [Ring1] [N] [=Branch1] [C] [=O] [=N] [=Branch2] [P] [N] [#Branch2] [Ring1]
2. (len=60) [C] [C] [O] [Ring2] [C] [Ring1] [Ring2] [C] [C] [C] [Branch2] [=Branch2] [Branch1] [C] [N] [C] [O] [O] [Branch1] [C] [C] [C] [C] [Branch1] [Branch1] [S] [Branch1] [F] [C] [=C] [=Branch2] [Branch1] [Branch2] [Branch1] [C] [Branch2] [Branch1] [=Branch1] [S] [Branch1] [Ring1] [Ring1] [=O] [Branch1] [=O] [S] [O] [C] [=Branch1] [Ring2] [C] [P] [Ring2] [#Branch2] [C] [S] [Branch2] [Ring1] [S]

--- Temperature: 1.0 ---

🧪 Generating 5 SELFIES molecules...
   Temperature: 1.0
✅ Stop

In [6]:
import selfies as sf

tokens = "[C] [O] [C] [N] [C] [Ring2] [C] [Ring2] [=C] [C] [C] [=Branch2] [N] [C] [C] [=C]"
tokens = tokens.replace(" ", "")
print(sf.decoder(tokens))

COCNC=CCCC=C


In [7]:
model.eval()
analysis = model.analyze_refinement_trajectory(
    max_len=5,
    device='cuda',
    seed=42
)
model.print_refinement_trajectory(analysis, tokenizer=tokenizer)

🔍 Refinement Trajectory (max_steps=10)

t=0: [[12]] [ [5]] [ [5]] [ [5]] [[10]]
        ↑ ↑ ↑ ↑ ↑  ← High entropy everywhere (max uncertainty)
t=1: [ [7]] [ [5]] [ [5]] [ [5]] [ EOS]
         ↑  ↑  ↑  ↑  ↑                     ← High uncertainty at pos 0, 1, 2, 3, 4
t=2: [ [7]] [ [5]] [[18]] [ [9]] [ EOS]
           ↑  ↑  ↑                         ← High uncertainty at pos 1, 2, 3
              ← change_ratio = 0.0% < 2% → ✅ Early stopping triggered
        (no critic — just self-consistency)

Final output: '[Ring1] [C] [S] [O] </s>'
