In [1]:
import torch
import pandas as pd
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gbyuvd/bionat-selfies-gen-tokenizer-wordlevel")
    
# Ensure [MASK] token exists
if not hasattr(tokenizer, 'mask_token_id') or tokenizer.mask_token_id is None:
    tokenizer.add_special_tokens({'mask_token': '[MASK]'})

print(f"Vocab size: {tokenizer.vocab_size}")
print(f"[MASK] token ID: {tokenizer.mask_token_id}")

Vocab size: 412
[MASK] token ID: 4


In [None]:
from model import ImplicitRefinementModel, ImplicitRefinementConfig 
seq_len = 90
config = ImplicitRefinementConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=320,
        num_layers=4,
        num_heads=4,
        max_seq_len=seq_len,          
        max_refinement_steps=10,
        dropout=0.1,
        use_self_cond=True,
        stop_threshold=0.02,
        min_refine_uncertainty=0.1,
        ema_decay=0.995,
        diversity_weight=0.05,
        sampling_temperature=1.0,
        use_refine_gate=True,  # Enable internal refinement gate
        use_gradient_checkpointing=False  # Enable for larger models
    )

model = ImplicitRefinementModel(config, tokenizer=tokenizer).to(device)
model.init_teacher()

In [10]:
checkpoint_path = "chemistrySELFIESmodelfinal.pth"  # or "refinement_model_z_final.pth"

torch.serialization.add_safe_globals([ImplicitRefinementConfig])
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # set to evaluation mode
print("✅ Model weights loaded successfully!")

✅ Model weights loaded successfully!


In [11]:
def generate_and_decode(model, tokenizer, num_samples=5, max_len=90, temperature=1.0):
    """Generate chemistry molecules using adaptive refinement."""
    model.eval()
    print(f"\n🧪 Generating {num_samples} SELFIES molecules...")
    print(f"   Temperature: {temperature}")
    print("="*70)
    
    with torch.no_grad():
        samples = model.sample(batch_size=2, max_len=90, device='cuda')
    
    # Compute actual lengths (stop at first pad or eos)
    actual_lengths = []
    for s in samples:
        # Find first pad or eos
        s_list = s if isinstance(s, list) else s.tolist()
        length = len(s_list)
        for i, tok in enumerate(s_list):
            if tok == tokenizer.pad_token_id or tok == tokenizer.eos_token_id:
                length = i
                break
        actual_lengths.append(torch.tensor(length))
    
    for i, (sample, length) in enumerate(zip(samples, actual_lengths)):
        decoded = tokenizer.decode(sample[:length.item()], skip_special_tokens=True)
        print(f"{i+1}. (len={length.item()}) {decoded}")
    
    print("="*70)


In [12]:
# Try different temperatures
for temp in [0.8, 1.0, 1.2]:
    print(f"\n--- Temperature: {temp} ---")
    generate_and_decode(
        model=model,
        tokenizer=tokenizer,
        num_samples=5,
        temperature=temp
    )


--- Temperature: 0.8 ---

🧪 Generating 5 SELFIES molecules...
   Temperature: 0.8
✅ Stopped at step 3 (change: 0.00%)
1. (len=65) [O] [=O] [#Branch2] [=Branch1] [=C] [Branch2] [=C] [C] [C] [Ring1] [C] [N] [C] [=O] [=O] [=Branch1] [C] [=C] [C] [=Branch1] [C] [=C] [C] [C] [Branch2] [C] [Branch1] [Ring1] [=C] [=C] [N] [Ring1] [Ring1] [Branch1] [NH1] [C] [C] [O] [C] [C] [C] [C] [C] [Branch1] [C] [=C] [=Branch1] [#C] [C] [=Branch2] [Ring1] [C] [Ring2] [Ring1] [C] [Branch1] [#Branch1] [C] [Ring1] [C] [C] [C] [C] [O]
2. (len=62) [Cl] [O] [=C] [=C] [C] [=C] [Branch1] [C] [=C] [C] [N] [=Branch1] [Ring2] [=O] [C] [Branch1] [C] [Branch1] [C] [=C] [C] [Branch1] [=Branch1] [C] [=O] [#C] [=C] [=Branch1] [C] [N] [Branch1] [Ring1] [=C] [C] [C] [=Branch1] [Ring1] [O] [O] [C] [C] [Ring2] [C] [Ring1] [Ring1] [=C] [=Branch1] [C] [C] [=C] [O] [Cl] [C] [=N] [C] [C] [Ring2] [Ring2] [O] [P] [O]

--- Temperature: 1.0 ---

🧪 Generating 5 SELFIES molecules...
   Temperature: 1.0
✅ Stopped at step 3 (change: 0.0

In [13]:
import selfies as sf

tokens = "[Cl] [O] [=C] [=C] [C] [=C] [Branch1] [C] [=C] [C] [N] [=Branch1] [Ring2] [=O] [C] [Branch1] [C] [Branch1] [C] [=C] [C] [Branch1] [=Branch1] [C] [=O] [#C] [=C] [=Branch1] [C] [N] [Branch1] [Ring1] [=C] [C] [C] [=Branch1] [Ring1] [O] [O] [C] [C] [Ring2] [C] [Ring1] [Ring1] [=C] [=Branch1] [C] [C] [=C] [O] [Cl] [C] [=N] [C] [C] [Ring2] [Ring2] [O] [P] [O]"
tokens = tokens.replace(" ", "")
print(sf.decoder(tokens))

ClOC=CC=C(C)CN(OC)C1=CC(C=O)CN(CC)C(OO2)CC21CC=COCl


In [14]:
model.eval()
analysis = model.analyze_refinement_trajectory(
    max_len=16,
    device='cuda',
    seed=22
)
model.print_refinement_trajectory(analysis, tokenizer=tokenizer)

🔍 Refinement Trajectory (INTERNAL GATE)

t=0: [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]
        ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑
       Gate: 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00
t=1: [MASK] [MASK] [ [6]] [[10]] [ [5]] [ [5]] [ [5]] [[14]] [ [8]] [ [5]] [ [6]] [ [5]] [ [5]] [[22]] [[37]] [ [5]]
            ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑
       Gate: 0.06 0.21 0.83 0.93 0.96 0.97 0.97 0.97 0.96 0.96 0.96 0.97 0.97 0.97 0.97 0.97
t=2: [MASK] [MASK] [ [9]] [[10]] [ [5]] [ [5]] [ [5]] [[13]] [ [9]] [ [5]] [[10]] [ [5]] [ [5]] [ [5]] [ [5]] [[13]]
            ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑
       Gate: 0.05 0.23 0.88 0.96 0.98 0.98 0.99 0.99 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98
t=3: [MASK] [MASK] [ [5]] [[18]] [ [6]] [ [5]] [ [5]] [ [7]] [ [8]] [ [5]] [[33]] [[12]] [ [6]] [ [7]] [[10]] [ [8]]
            ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑
       Gate: 0.05 0.21 0.86 0.96 0.98 0.98 0.98 0.98