In [None]:
# STEP 1: Build Token Vocabulary with Priority + Preserve Special Token IDs
import pandas as pd
from transformers import AutoTokenizer
import json

# Load the reference tokenizer; it supplies the base vocab and the special-token IDs.
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

# Read the base vocab once instead of calling get_vocab() inside every loop iteration.
base_vocab = tokenizer.get_vocab()

# Load the motif table.
df = pd.read_csv("./top_motifs_readable.csv")

# Parse the space-separated token IDs into lists of ints.
# Non-string cells (e.g. NaN) become an empty list.
df['ngram_decoded_ids'] = df['ngram_decoded'].apply(
    lambda x: [int(tok) for tok in x.strip().split()] if isinstance(x, str) else []
)

# Decode each ID list back to its token string (special tokens are kept).
df['ngram_tokens'] = df['ngram_decoded_ids'].apply(
    lambda ids: tokenizer.decode(ids, skip_special_tokens=False)
)

# Sort by token string length DESC (longest first) so longer motifs get lower IDs.
df_sorted = df.sort_values(
    by='ngram_tokens',
    key=lambda x: x.str.len(),
    ascending=False
).reset_index(drop=True)

# token string -> ID for the custom vocabulary
token_to_id = {}

# Step 1: Add special tokens first and preserve their original IDs.
special_tokens = ["<s>", "</s>", "<pad>", "<unk>", "<mask>"]
for tok in special_tokens:
    if tok in base_vocab:
        token_to_id[tok] = tokenizer.convert_tokens_to_ids(tok)

# Step 2: Add ngram tokens (longest first); new IDs start above the highest special-token ID.
next_id = max(token_to_id.values()) + 1 if token_to_id else 0

for token_str in df_sorted['ngram_tokens'].drop_duplicates():
    # Rows with no token IDs decode to an empty string; an empty token can
    # never match any input, so it is not registered in the vocabulary.
    if not token_str:
        continue
    if token_str not in token_to_id:
        token_to_id[token_str] = next_id
        next_id += 1

# Step 3: Add the remaining base-vocab tokens, only if not already added.
# Iterate in original-ID order so the assigned IDs do not depend on the
# iteration order of the dict returned by get_vocab().
for tok in sorted(base_vocab, key=base_vocab.get):
    if tok not in token_to_id:
        token_to_id[tok] = next_id
        next_id += 1

# Sanity check: every token must map to a distinct ID.
assert len(set(token_to_id.values())) == len(token_to_id), "Duplicate IDs in custom vocab"

#  Final token_to_id mapping ready
print(f"Total tokens in custom vocab: {len(token_to_id)}")

Total tokens in custom vocab: 1508


In [6]:
# STEP 2: Build Tokenizer Class Compatible with Common Pipeline
from pathlib import Path

from FastChemTokenizer import FastChemTokenizer

# Make sure the output directory exists: open(..., "w") creates the file
# but not its parent directory, so a fresh checkout would otherwise fail.
vocab_dir = Path("./smitok-proto/")
vocab_dir.mkdir(parents=True, exist_ok=True)

# Save vocab (token string -> ID) as JSON
with open(vocab_dir / "vocab.json", "w", encoding="utf-8") as f:
    json.dump(token_to_id, f, indent=2)

# Load the custom tokenizer from the directory that was just written.
# NOTE: this rebinds `tokenizer`, replacing the Hugging Face tokenizer used
# to build the vocabulary; later cells use the FastChemTokenizer instance.
tokenizer = FastChemTokenizer.from_pretrained('./smitok-proto/')

In [8]:
# Full Test
# Smoke tests for the custom tokenizer: plain encode/decode, encode_plus on a
# single sequence and on a pair, batch encoding, and a round-trip check.

# === Test 1: Basic encode/decode with trace ===
testsmi = "CC1=CC(CNC(=O)c2ccccc2)C(C(C)C)CC1Cc1nc2ccc(F)cc2[nH]1"
encoded = tokenizer.encode(testsmi)
print("✅ Encoded:", encoded)
decoded = tokenizer.decode(encoded)
print("✅ Decoded:", decoded)
tokenizer.decode_with_trace(encoded)

# === Test 2: Single sequence with special tokens ===
# The original cell referenced an undefined name (`benzene`) here, which raises
# NameError on a fresh kernel. The recorded output for this test is the encoding
# of `testsmi`, so that variable is used explicitly.
single = tokenizer.encode_plus(
    text=testsmi,
    add_special_tokens=True,
    padding=True,
    max_length=20,
    return_tensors="pt"
)
print("\n✅ Single Sequence Output:")
print(single)

# === Test 3: Sequence pair ===
seqA = "CC1=CC(CNC(=O)c2ccccc2)C(C(C)C)CC1Cc1nc2ccc(F)cc2[nH]1"
seqB = "Cc1ccc(C2CC(O)C(O)C2NCC(C)C)cc1"
pair = tokenizer.encode_plus(
    text=seqA,
    text_pair=seqB,
    add_special_tokens=True,
    padding=True,
    max_length=20,
    return_tensors="pt"
)
print("\n✅ Sequence Pair Output:")
print(pair)

# === Test 4: Batch encode ===
# Mixed batch: a single sequence, a (seqA, seqB) pair, and another single sequence.
batch = tokenizer.batch_encode_plus(
    [
        testsmi,
        (seqA, seqB),
        "Cc1ccc(C2CC(O)C(O)C2NCC(C)C)cc1"
    ],
    add_special_tokens=True,
    padding=True,
    max_length=32,
    return_tensors="pt"
)
print("\n✅ Batch Output:")
for k, v in batch.items():
    print(f"{k}: shape {v.shape}")
    print(f"  sample: {v[0][:10]}...")

# === Test 5: Round-trip with special tokens ===
# Special tokens are written into the text by hand, so none are added on encode.
test_with_special = "<s>Cc1ccc(C2CC(O)C(O)C2NCC(C)C)cc1</s>"
ids = tokenizer.encode_plus(test_with_special, add_special_tokens=False)
recovered = tokenizer.decode(ids["input_ids"])
print(f"\n✅ Round-trip with manual special tokens: {recovered}")
assert recovered == test_with_special.strip(), "Round-trip failed!"
print("🎉 All tests passed!")

✅ Encoded: [529, 762, 485, 765, 583, 759, 710, 630, 691, 780, 473, 644, 749, 767, 779, 623, 780]
✅ Decoded: CC1=CC(CNC(=O)c2ccccc2)C(C(C)C)CC1Cc1nc2ccc(F)cc2[nH]1

🔍 Decoding 17 tokens:
  [000] ID=  529 → 'CC1=CC'
  [001] ID=  762 → '(C'
  [002] ID=  485 → 'NC(=O)'
  [003] ID=  765 → 'c2'
  [004] ID=  583 → 'ccccc'
  [005] ID=  759 → '2)'
  [006] ID=  710 → 'C(C'
  [007] ID=  630 → '(C)C'
  [008] ID=  691 → ')CC'
  [009] ID=  780 → '1'
  [010] ID=  473 → 'Cc1nc2'
  [011] ID=  644 → 'ccc('
  [012] ID=  749 → 'F)'
  [013] ID=  767 → 'cc'
  [014] ID=  779 → '2'
  [015] ID=  623 → '[nH]'
  [016] ID=  780 → '1'

✅ Single Sequence Output:
{'input_ids': tensor([[  0, 529, 762, 485, 765, 583, 759, 710, 630, 691, 780, 473, 644, 749,
         767, 779, 623, 780,   2,   1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}

✅ Sequence Pair Output:
{'input_ids': tenso

  return [torch.tensor(item, dtype=torch.long) for item in lst]
