In [None]:
import pandas as pd
from collections import defaultdict
from FastChemTokenizer import FastChemTokenizer
from tqdm import tqdm 

# --- 1. Load tokenizer ---
tokenizer = FastChemTokenizer.from_pretrained("../smitok-proto")

# --- 2. Load dataset ---
df = pd.read_csv("../comb_smi.csv", header=None, names=["smiles"])
smiles_list = df["smiles"].dropna().astype(str).tolist()

print(f"📊 Loaded {len(smiles_list)} SMILES strings for analysis.")

# --- 3. Count token usage WITH PROGRESS BAR ---
token_usage = defaultdict(int)

for smiles in tqdm(smiles_list, desc="📈 Counting token usage", unit="smiles"):
    token_ids = tokenizer.encode(smiles)
    for tid in token_ids:
        token_usage[tid] += 1

print(f"✅ Found usage for {len(token_usage)} unique tokens.")

# --- 4. Identify tokens to KEEP ---
# Always keep special tokens, even if unused (required for model)
special_tokens = {
    tokenizer.bos_token_id,
    tokenizer.eos_token_id,
    tokenizer.pad_token_id,
    tokenizer.unk_token_id,
    tokenizer.mask_token_id,
}

# Tokens that appeared at least once OR are special
tokens_to_keep = set(token_usage.keys()) | special_tokens

print(f"🧃 Keeping {len(tokens_to_keep)} tokens (including special tokens).")

# --- 5. Build new token_to_id mapping ---
# We assign special tokens fixed low IDs for consistency
new_token_to_id = {
    tokenizer.bos_token: 0,
    tokenizer.eos_token: 1,
    tokenizer.pad_token: 2,
    tokenizer.unk_token: 3,
    tokenizer.mask_token: 4
}

# Then assign remaining tokens new contiguous IDs
next_id = 5
for old_id in sorted(tokens_to_keep - special_tokens):  # sort for determinism
    token = tokenizer.id_to_token[old_id]
    if token not in new_token_to_id:  # avoid duplicates (shouldn't happen)
        new_token_to_id[token] = next_id
        next_id += 1

print(f"🆕 New vocab size: {len(new_token_to_id)}")

# --- 6. Create new pruned & renumbered tokenizer ---
pruned_tokenizer = FastChemTokenizer(
    token_to_id=new_token_to_id,
    model_max_length=tokenizer.model_max_length
)

# Optional: Verify decoding still works
test_smiles = "CCO"
test_ids = pruned_tokenizer.encode(test_smiles)
decoded = pruned_tokenizer.decode(test_ids)
print(f"🧪 Test encode/decode: '{test_smiles}' → {test_ids} → '{decoded}'")

# --- 7. Save to ./chemtok/ ---
save_dir = "./chemtok"
pruned_tokenizer.save_pretrained(save_dir)

print(f"✅ Pruned tokenizer saved to {save_dir}/vocab.json")

📊 Loaded 2684370 SMILES strings for analysis.


📈 Counting token usage: 100%|██████████| 2684370/2684370 [02:29<00:00, 18008.39smiles/s]

✅ Found usage for 1233 unique tokens.
🧃 Keeping 1237 tokens (including special tokens).
🆕 New vocab size: 1238
🧪 Test encode/decode: 'CCO' → [671] → 'CCO'
✅ Tokenizer vocab saved to: ./chemtok\vocab.json
✅ Pruned tokenizer saved to ./chemtok/vocab.json





In [4]:
tokenizer.decode_with_trace(tokenizer.encode("COc1ccc2[nH]c3c(c2c1)CCNC3=O"))


🔍 Decoding 10 tokens:
  [000] ID=   59 → 'COc1ccc2[nH]c'
  [001] ID=  782 → '3'
  [002] ID=  754 → 'c('
  [003] ID=  765 → 'c2'
  [004] ID=  755 → 'c1'
  [005] ID=  691 → ')CC'
  [006] ID=  745 → 'NC'
  [007] ID=  782 → '3'
  [008] ID=  778 → '='
  [009] ID= 1211 → 'O'


In [5]:
original_vocab_size = len(tokenizer.token_to_id)
pruned_vocab_size = len(new_token_to_id)
pruned_count = original_vocab_size - pruned_vocab_size

print(f"✂️  Pruned {pruned_count} unused tokens.")

# List some pruned tokens (if any)
pruned_tokens = set(tokenizer.token_to_id.keys()) - set(new_token_to_id.keys())
if pruned_tokens:
    print("🗑️  Sample of pruned tokens (max 10):")
    for t in list(pruned_tokens)[:10]:
        print(f"   '{t}'")

✂️  Pruned 270 unused tokens.
🗑️  Sample of pruned tokens (max 10):
   '²'
   'È'
   'C@@H](O)['
   '¦'
   'ñ'
   'Ĭ'
   '(C(=O'
   'Æ'
   'Ö'
   '§'


In [6]:
print("🔍 Validating tokenizer on dataset...")
for i, smiles in enumerate(tqdm(smiles_list[:1000], desc="🧪 Validating", unit="smiles")):
    try:
        ids = pruned_tokenizer.encode(smiles)
        decoded = pruned_tokenizer.decode(ids)
        if decoded != smiles:
            print(f"⚠️  Mismatch at line {i}: '{smiles}' → decoded as '{decoded}'")
    except Exception as e:
        print(f"❌ Error encoding line {i}: {e}")

print("✅ Validation complete.")

🔍 Validating tokenizer on dataset...


🧪 Validating: 100%|██████████| 1000/1000 [00:00<00:00, 10554.26smiles/s]

✅ Validation complete.



