In [1]:
import torch
import pandas as pd
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gbyuvd/bionat-selfies-gen-tokenizer-wordlevel")
    
# Ensure [MASK] token exists
if not hasattr(tokenizer, 'mask_token_id') or tokenizer.mask_token_id is None:
    tokenizer.add_special_tokens({'mask_token': '[MASK]'})

print(f"Vocab size: {tokenizer.vocab_size}")
print(f"[MASK] token ID: {tokenizer.mask_token_id}")

Vocab size: 412
[MASK] token ID: 4


In [2]:
from model import ImplicitRefinementModel, ImplicitRefinementConfig

seq_len = 90  # must match training

config = ImplicitRefinementConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=320,
    num_layers=6,
    num_heads=4,
    max_seq_len=seq_len,
    max_refinement_steps=10,
    dropout=0.1,
    use_self_cond=True,
    stop_threshold=0.02,
    min_refine_uncertainty=0.1,
    ema_decay=0.995,
    diversity_weight=0.05,
    sampling_temperature=1.0
)

model = ImplicitRefinementModel(config, tokenizer=tokenizer).to(device)
print(f"Model loaded with {sum(p.numel() for p in model.parameters()):,} parameters")

Model loaded with 7,794,742 parameters


In [3]:
checkpoint_path = "best_model_v.pth"  # or "refinement_model_z_final.pth"

torch.serialization.add_safe_globals([ImplicitRefinementConfig])
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # set to evaluation mode
print("✅ Model weights loaded successfully!")

✅ Model weights loaded successfully!


In [5]:
def generate_and_decode(model, tokenizer, num_samples=5, max_len=90, temperature=1.0):
    """Generate chemistry molecules using adaptive refinement."""
    model.eval()
    print(f"\n🧪 Generating {num_samples} SELFIES molecules...")
    print(f"   Temperature: {temperature}")
    print("="*70)
    
    with torch.no_grad():
        samples = model.sample(batch_size=2, max_len=90, device='cuda')
    
    # Compute actual lengths (stop at first pad or eos)
    actual_lengths = []
    for s in samples:
        # Find first pad or eos
        s_list = s if isinstance(s, list) else s.tolist()
        length = len(s_list)
        for i, tok in enumerate(s_list):
            if tok == tokenizer.pad_token_id or tok == tokenizer.eos_token_id:
                length = i
                break
        actual_lengths.append(torch.tensor(length))
    
    for i, (sample, length) in enumerate(zip(samples, actual_lengths)):
        decoded = tokenizer.decode(sample[:length.item()], skip_special_tokens=True)
        print(f"{i+1}. (len={length.item()}) {decoded}")
    
    print("="*70)


In [6]:
# Try different temperatures
for temp in [0.8, 1.0, 1.2]:
    print(f"\n--- Temperature: {temp} ---")
    generate_and_decode(
        model=model,
        tokenizer=tokenizer,
        num_samples=5,
        temperature=temp
    )


--- Temperature: 0.8 ---

🧪 Generating 5 SELFIES molecules...
   Temperature: 0.8
✅ Stopped at step 3 (change: 0.00%)
1. (len=57) [C] [O] [=C] [C] [Branch1] [C] [=O] [C] [S] [Branch2] [F] [C] [=C] [C] [C] [C] [N] [Ring2] [C] [C] [Ring1] [Branch1] [C] [C] [Branch1] [Branch1] [Branch1] [=O] [C] [Branch2] [Branch1] [P] [=N] [Branch1] [=N] [C] [=N] [Branch1] [C] [=C] [=O] [Ring1] [=C] [=C] [Branch1] [N] [S] [Ring2] [=O] [=O] [Branch1] [=N] [C] [Branch1] [C] [P] [C]
2. (len=58) [C] [O] [=C] [=C] [C] [C] [N] [C] [Ring1] [=N] [N] [O] [C] [=C] [C] [=O] [N] [C] [Ring1] [#Branch1] [=Branch1] [N] [Branch1] [P] [=O] [Branch1] [Branch1] [=O] [Branch1] [S] [Branch1] [=O] [C] [=C] [N] [=C] [Branch1] [Branch1] [Branch1] [Branch2] [C] [C] [Branch1] [=O] [=Branch1] [Branch1] [C] [Ring2] [Ring1] [Ring1] [C] [Ring1] [C] [N] [C] [Ring1] [#Branch2]

--- Temperature: 1.0 ---

🧪 Generating 5 SELFIES molecules...
   Temperature: 1.0
✅ Stopped at step 3 (change: 0.00%)
1. (len=59) [=C] [=C] [C] [N] [C] [=C] [R

In [27]:
import selfies as sf

tokens = "[#Branch2] [Branch1] [#Branch1] [=Branch2] [O] .[NH3+1] [N] [C] [C] [N] [C] [C] [C] [C] [Branch1] [N] [C] [C] [C] [C] [Ring1] [C] [C] [=O] [Ring1] [C] [C] [=C] [N] [C] [C] [C] [C] [C] [=Branch1] [C] [=O] [#Branch1]"
tokens = tokens.replace(" ", "")
print(sf.decoder(tokens))

O.[NH3+1]NCCNCCCC(CCC=CC=O)=CNCCCCC=O
