In [5]:
# Load IPython's autoreload extension; mode 2 re-imports modules before each
# cell runs, so edits to library source are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [31]:
from genlm.control import Canonical, BoolFSA, AWRS, PromptedLLM, direct_token_sampler
import transformers
import numpy as np
from genlm.control.constant import EOS
import asyncio


**Simple tests verifying that the canonical potential works**

In [28]:
# Build the GPT-2 tokenizer (non-fast implementation requested via
# use_fast=False) and wrap it in the canonicality potential under test.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "gpt2",
    use_fast=False,
)
canonical_potential = Canonical(tokenizer, model_name="gpt2")

adding override b'\n' <-> b'\n'
adding override b".'" <-> b's'


In [12]:
# Check that the potential has the correct vocabulary
assert len(canonical_potential.vocab) == len(canonical_potential.canonicality_filter._decode)

# Check that EOS is added correctly
assert len(canonical_potential.vocab_eos) == len(canonical_potential.vocab) + 1
log_weight = await canonical_potential.complete([])
assert log_weight == 0.0


In [13]:
tokens = [b"Token", b"ization"]
log_weight = await canonical_potential.complete(tokens)
assert log_weight == 0.0

percent of mask:  97.26207294506239


In [14]:
scrambled = [b'To', b'ken', b'ization']
print(scrambled)
log_weight = await canonical_potential.complete(scrambled)
print(log_weight)
assert log_weight == float('-inf')

[b'To', b'ken', b'ization']
percent of mask:  88.56079750084565
-inf


In [19]:
scrambled = [b'To', b'ken']
logw = await canonical_potential.logw_next(scrambled)
assert logw[b'ization'] == float('-inf')
assert logw[EOS] == 0.0

percent of mask:  88.56079750084565


In [20]:
hello = b"hello"
world = b" world"

# Sequences that _check_canonicality is expected to accept, checked in order:
accepted_sequences = [
    [],                      # the empty sequence
    [b" the"],               # a single token
    [b"Token", b"ization"],  # a valid two-token split
    [hello, world],          # two whole-word tokens
]
for candidate in accepted_sequences:
    assert canonical_potential._check_canonicality(candidate)

# Expected to be rejected. NOTE(review): presumably because "hello" is split
# into b"hel" + b"lo" even though the single token b"hello" is accepted above
# -- confirm against the filter's rules.
assert not canonical_potential._check_canonicality([b"hel", b"lo", b" world"])

percent of mask:  97.26207294506239
percent of mask:  87.04658057584018
percent of mask:  95.87122191933462


Checking the potential on a few example sentences tokenized with the reference tokenizer.

In [30]:
sentences = [
    "Natural language processing",
    "The quick brown fox jumps over the lazy dog",
    "Artificial intelligence and machine learning"
]

for sentence in sentences:
    print(sentence)
    tokens = tokenizer.encode(sentence, add_special_tokens=False)
    token_bytes = [tokenizer.decode([token]).encode('utf-8') for token in tokens]
    
    # This should be canonical
    log_weight = await canonical_potential.complete(token_bytes)
    assert log_weight == 0.0
    
    # Also test prefix for each subsequence
    for i in range(1, len(token_bytes) + 1):
        prefix = token_bytes[:i]
        log_weight = await canonical_potential.prefix(prefix)
        assert log_weight == 0.0
        
    # Test that each valid prefix allows appropriate next tokens
    for i in range(len(token_bytes)):
        prefix = token_bytes[:i]
        next_token = token_bytes[i] if i < len(token_bytes) else p.eos
        print(prefix)
        print(next_token)
        lazy_weights = await canonical_potential.logw_next(prefix)
        
        # The next token in the sequence should be allowed
        token_idx = lazy_weights.encode.get(next_token)
        if token_idx is not None:
            assert not np.isneginf(lazy_weights.weights[token_idx])
    print("done")

Natural language processing
percent of mask:  98.49175239270151
percent of mask:  95.17082197504826
percent of mask:  98.49175239270151
percent of mask:  98.49175239270151
percent of mask:  95.17082197504826
[]
b'Natural'
[b'Natural']
b' language'
[b'Natural', b' language']
b' processing'
percent of mask:  98.49175239270151
done
The quick brown fox jumps over the lazy dog
percent of mask:  98.59721033885828
percent of mask:  98.74843305410191
percent of mask:  99.12251029707305
percent of mask:  99.54633185426906
percent of mask:  97.84905585291601
percent of mask:  96.76661957538253
percent of mask:  98.05400242752253
percent of mask:  96.71090594345067
percent of mask:  98.59721033885828
percent of mask:  98.59721033885828
percent of mask:  98.74843305410191
percent of mask:  98.59721033885828
percent of mask:  98.74843305410191
percent of mask:  99.12251029707305
percent of mask:  98.59721033885828
percent of mask:  98.74843305410191
percent of mask:  99.12251029707305
percent of ma

Now the interesting part: run SMC with a plain PromptedLLM, once without and once with the canonical potential, and compare the generated token sequences.

In [45]:
async def main():
    print("===== TESTING CANONICAL BPE POTENTIAL EFFECT ON GENERATION =====\n")
    
    tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2', use_fast=False)
    canonical_potential = Canonical(tokenizer, model_name="gpt2")
    prompt = "Tokeniz"  # v good prompt for non canonicals
    
    print("\n===== GENERATION WITHOUT CANONICAL ENFORCEMENT =====")
    llm = PromptedLLM.from_name("gpt2", temperature=0.7)  
    llm.set_prompt_from_str(prompt)
    
    
    sampler = direct_token_sampler(llm)
    sequences_without = await sampler.smc(n_particles=5, max_tokens=10, ess_threshold=0.5, verbosity=1)

    print("\nGeneration results WITHOUT canonical enforcement:")
    print(sequences_without.posterior)
    

    print("\nDetailed token-by-token view WITHOUT canonical enforcement:")
    for sequence, weight in sequences_without.posterior.items():
        token_str = ' | '.join([token.decode('utf-8', errors='replace') for token in sequence])
        print(f"{weight:.4f}: {token_str}")
    

    for sequence, weight in sequences_without.posterior.items():
        if weight > 0.1:  # Only show high-probability sequences, verbosity=1 already shows all
            text = tokenizer.decode([tokenizer.encode(token.decode('utf-8', errors='replace'))[0] for token in sequence])
            print(f"\nText: {prompt + text}")
    
    print("\n\n===== GENERATION WITH CANONICAL ENFORCEMENT =====")
    
    llm = PromptedLLM.from_name("gpt2", temperature=0.7)
    llm.set_prompt_from_str(prompt)
    
    product = llm * canonical_potential
    

    sampler = direct_token_sampler(product)
    sequences_with = await sampler.smc(n_particles=5, max_tokens=10, ess_threshold=0.5, verbosity=1)
    
    print("\nGeneration results WITH canonical enforcement:")
    print(sequences_with.posterior)
    
    print("\nDetailed token-by-token view WITH canonical enforcement:")
    for sequence, weight in sequences_with.posterior.items():
        token_str = ' | '.join([token.decode('utf-8', errors='replace') for token in sequence])
        print(f"{weight:.4f}: {token_str}")
    
    # Convert to actual text
    for sequence, weight in sequences_with.posterior.items():
        if weight > 0.1:
            text = tokenizer.decode([tokenizer.encode(token.decode('utf-8', errors='replace'))[0] for token in sequence])
            print(f"\nText: {prompt + text}")
    
await main()

===== TESTING CANONICAL BPE POTENTIAL EFFECT ON GENERATION =====

adding override b'\n' <-> b'\n'
adding override b".'" <-> b's'

===== GENERATION WITHOUT CANONICAL ENFORCEMENT =====




0.00:	[0;35m[[0mb'i'[0;35m][0m
0.00:	[0;35m[[0mb'y'[0;35m][0m
0.00:	[0;35m[[0mb'r'[0;35m][0m
0.00:	[0;35m[[0mb'r'[0;35m][0m
0.00:	[0;35m[[0mb'ah'[0;35m][0m
0.00:	[0;35m[[0mb'i'[0;35m|[0mb' will'[0;35m][0m
0.00:	[0;35m[[0mb'y'[0;35m|[0mb')'[0;35m][0m
0.00:	[0;35m[[0mb'r'[0;35m|[0mb'W'[0;35m][0m
0.00:	[0;35m[[0mb'r'[0;35m|[0mb','[0;35m][0m
0.00:	[0;35m[[0mb'ah'[0;35m|[0mb' ('[0;35m][0m
0.00:	[0;35m[[0mb'i'[0;35m|[0mb' will'[0;35m|[0mb' be'[0;35m][0m
0.00:	[0;35m[[0mb'y'[0;35m|[0mb')'[0;35m|[0mb' July'[0;35m][0m
0.00:	[0;35m[[0mb'r'[0;35m|[0mb'W'[0;35m|[0mb'5'[0;35m][0m
0.00:	[0;35m[[0mb'r'[0;35m|[0mb','[0;35m|[0mb' Kod'[0;35m][0m
0.00:	[0;35m[[0mb'ah'[0;35m|[0mb' ('[0;35m|[0mb'se'[0;35m][0m
0.00:	[0;35m[[0mb'i'[0;35m|[0mb' will'[0;35m|[0mb' be'[0;35m|[0mb' the'[0;35m][0m
0.00:	[0;35m[[0mb'y'[0;35m|[0mb')'[0;35m|[0mb' July'[0;35m|[0mb' 17'[0;35m][0m
0.00:	[0;35m[[0mb'r'[0;35m|

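Same code as above, except that the canonically constrained run uses the AWRS sampler instead of the direct token sampler over the product; the unconstrained baseline still uses the direct token sampler.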

In [46]:
async def main():
    print("===== TESTING CANONICAL BPE POTENTIAL EFFECT ON GENERATION =====\n")
    
    tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2', use_fast=False)
    canonical_potential = Canonical(tokenizer, model_name="gpt2")
    prompt = "Tokeniz"  # v good prompt for non canonicals
    
    print("\n===== GENERATION WITHOUT CANONICAL ENFORCEMENT =====")
    llm = PromptedLLM.from_name("gpt2", temperature=0.7)  
    llm.set_prompt_from_str(prompt)
    
    
    sampler = direct_token_sampler(llm)
    sequences_without = await sampler.smc(n_particles=5, max_tokens=10, ess_threshold=0.5, verbosity=1)

    print("\nGeneration results WITHOUT canonical enforcement:")
    print(sequences_without.posterior)
    

    print("\nDetailed token-by-token view WITHOUT canonical enforcement:")
    for sequence, weight in sequences_without.posterior.items():
        token_str = ' | '.join([token.decode('utf-8', errors='replace') for token in sequence])
        print(f"{weight:.4f}: {token_str}")
    

    for sequence, weight in sequences_without.posterior.items():
        if weight > 0.1:  # Only show high-probability sequences, verbosity=1 already shows all
            text = tokenizer.decode([tokenizer.encode(token.decode('utf-8', errors='replace'))[0] for token in sequence])
            print(f"\nText: {prompt + text}")
    
    print("\n\n===== GENERATION WITH CANONICAL ENFORCEMENT =====")
    
    llm = PromptedLLM.from_name("gpt2", temperature=0.7)
    llm.set_prompt_from_str(prompt)

    awrs_sampler = AWRS(llm, canonical_potential)
    sequences_with = await awrs_sampler.smc(n_particles=5, max_tokens=10, ess_threshold=0.5, verbosity=1)
    
    print("\nGeneration results WITH canonical enforcement:")
    print(sequences_with.posterior)
    
    print("\nDetailed token-by-token view WITH canonical enforcement:")
    for sequence, weight in sequences_with.posterior.items():
        token_str = ' | '.join([token.decode('utf-8', errors='replace') for token in sequence])
        print(f"{weight:.4f}: {token_str}")
    
    # Convert to actual text
    for sequence, weight in sequences_with.posterior.items():
        if weight > 0.1:
            text = tokenizer.decode([tokenizer.encode(token.decode('utf-8', errors='replace'))[0] for token in sequence])
            print(f"\nText: {prompt + text}")
    
await main()


===== TESTING CANONICAL BPE POTENTIAL EFFECT ON GENERATION =====

adding override b'\n' <-> b'\n'
adding override b".'" <-> b's'

===== GENERATION WITHOUT CANONICAL ENFORCEMENT =====
0.00:	[0;35m[[0mb'ah'[0;35m][0m
0.00:	[0;35m[[0mb"'"[0;35m][0m
0.00:	[0;35m[[0mb'er'[0;35m][0m
0.00:	[0;35m[[0mb'io'[0;35m][0m
0.00:	[0;35m[[0mb'Edge'[0;35m][0m
0.00:	[0;35m[[0mb'ah'[0;35m|[0mb'.'[0;35m][0m
0.00:	[0;35m[[0mb"'"[0;35m|[0mb' ['[0;35m][0m
0.00:	[0;35m[[0mb'er'[0;35m|[0mb')'[0;35m][0m
0.00:	[0;35m[[0mb'io'[0;35m|[0mb')'[0;35m][0m
0.00:	[0;35m[[0mb'Edge'[0;35m|[0mb':'[0;35m][0m
0.00:	[0;35m[[0mb'ah'[0;35m|[0mb'.'[0;35m|[0mb'\n'[0;35m][0m
0.00:	[0;35m[[0mb"'"[0;35m|[0mb' ['[0;35m|[0mb'2'[0;35m][0m
0.00:	[0;35m[[0mb'er'[0;35m|[0mb')'[0;35m|[0mb' :'[0;35m][0m
0.00:	[0;35m[[0mb'io'[0;35m|[0mb')'[0;35m|[0mb' Feb'[0;35m][0m
0.00:	[0;35m[[0mb'Edge'[0;35m|[0mb':'[0;35m|[0mb'0'[0;35m][0m
0.00:	[0;35m[[0mb'ah'[0;

Combining the JSON-schema potential with the canonical potential, to check how the two constraints behave together.

In [49]:
import asyncio
import json
import os

from huggingface_hub import login
from transformers import AutoTokenizer
from genlm.control import AWRS, Canonical, JsonSchema, PromptedLLM, direct_token_sampler

# Never hard-code an access token in a notebook: it ends up in version control
# and in shared copies. Read it from the HF_TOKEN environment variable instead;
# when that is unset, login(token=None) falls back to an interactive prompt.
login(token=os.environ.get("HF_TOKEN"))

async def main():
    person_schema = {
    "type": "object",
    "properties": {
        "name": {
            "type": "string",
            "enum": ["Alice", "Bob", "Charlie"],
            "description": "The name of the person"
        },
        "age": {
            "type": "integer",
            "minimum": 20,
            "maximum": 80,
            "description": "The age of the person"
        },
    },
    }

    book_schema = {
        "type": "object",
        "properties": {
            "title": {
                "type": "string",
                "minLength": 1,
                "description": "The title of the book"
            },
            "pages": {
                "type": "integer",
                "minimum": 1,
                "maximum": 2000,
                "description": "The number of pages in the book"
            },
            "genre": {
                "type": "string",
                "enum": ["fiction", "non-fiction", "mystery"],
                "description": "The genre of the book"
            }
        },
    }

    # Create a language model potential.
    # Since this task is harder, we use a larger model.
    # (You will need to login via the Hugging Face CLI and have access to the model.)
    llm = PromptedLLM.from_name(
        "meta-llama/Llama-3.2-1B-Instruct",
        eos_tokens=[b"<|eom_id|>", b"<|eot_id|>"],
        temperature=0.8,
        engine_opts={"dtype": "half"}
    )

    # Set the prompt for the language model.
    # Since we are using an instruction-tuned model, we use the chat template.
    # The prompt contains an example of a schema and a generated object,
    # followed by the schema we want to match.
    llm.prompt_ids = llm.model.tokenizer.apply_chat_template(
        conversation=[
            {"role": "system", "content": "You need to generate a JSON object that matches the schema below. Only generate the JSON object on a single line with no other text."},
            {"role": "user", "content": json.dumps(person_schema)},
            {"role": "assistant", "content": '{"name": "Alice", "age": 30}'},
            {"role": "user", "content": json.dumps(book_schema)},
        ],
        tokenize=True,
        add_generation_prompt=True
    )

    # Create a schema potential.
    schema_potential = JsonSchema(book_schema)

    # Coerce the schema potential so that it operates on the token type of the language model.
    coerced_schema = schema_potential.coerce(llm, f=b"".join)

    # Create a token sampler that combines the language model and the schema potential.
    token_sampler = AWRS(llm, coerced_schema)

    # Generate text using SMC.
    # Generation is asynchronous; use `await` if calling in an async context (like in an async
    # function or in a Jupyter notebook) and `asyncio.run(token_sampler.smc(...))` otherwise.
    sequences = await token_sampler.smc(
        n_particles=2, # Number of candidate sequences to maintain
        ess_threshold=0.5, # Threshold for resampling
        max_tokens=30, # Maximum sequence length
        verbosity=1 # Print particles at each step
    )

    # Show the inferred posterior distribution over complete UTF-8 decodable sequences.
    print("Original json potential sequences:")
    print(sequences.decoded_posterior)
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    canonical_potential = Canonical(tokenizer, "meta-llama/Llama-3.2-1B-Instruct")
    combined_constraint = coerced_schema * canonical_potential
    combined_sampler = AWRS(llm, combined_constraint)
    combined_sequences = await combined_sampler.smc(
        n_particles=3,
        max_tokens=50,
        ess_threshold=0.5,
        verbosity=1
    )
    print("Combined json-canonicality potentials sequences:")
    print(combined_sequences.decoded_posterior)


    # Example output:
    # {
#   '{"title": "The Lord of the Rings", "pages": 1200, "genre": "fiction"}': 0.5008318164809697,
#   '{"title": "The Great Gatsby", "pages": 178, "genre": "fiction"}': 0.49916818351903025,
# }

await main()



Original json potential sequences:
Chart {
  '{"title": "The Great Gatsby", "pages": 173, "genre": "fiction"}': 0.5000000091850528,
  '{"title": "Book", "pages": 500, "genre": "fiction"}': 0.49999999081494706,
}
Combined json-canonicality potentials sequences:
Chart {
  '{"title": "Harry Potter", "pages": 300, "genre": "fiction"}': 0.5000001368322415,
  '{"title": "The Great Gatsby", "pages": 195, "genre": "fiction"}': 0.4999998631677583,
}


Below, I experimented with the temperature and with setting `eos_tokens`; both make a noticeable difference to the samples.

In [37]:
myllm = PromptedLLM.from_name("gpt2", temperature=0.5, eos_tokens=[b'.'])

# Set the fixed prompt prefix for the language model
# All language model predictions will be conditioned on this prompt
myllm.set_prompt_from_str("Tokeniz")

# Load a sampler that proposes tokens by sampling directly from the LM's distribution
token_sampler = direct_token_sampler(myllm)

# Run SMC with 5 particles, a maximum of 25 tokens, and an ESS threshold of 0.5
sequences = await token_sampler.smc(n_particles=5, max_tokens=25, ess_threshold=0.5)

# Show the posterior over token sequences
sequences.posterior

# Show the posterior over complete UTF-8 decodable sequences
sequences.decoded_posterior



0,1
key,value
": You're right, but I'm not sure what the hell you're talking about",0.2500081447348389
r-2,0.24999831322512667
ar,0.24999693204360887
r,0.2499966099964255


Another run, this time with 10 particles; the samples are still quite bad with this prompt.

In [38]:
myllm = PromptedLLM.from_name("gpt2", temperature=0.5, eos_tokens=[b'.'])

# Set the fixed prompt prefix for the language model
# All language model predictions will be conditioned on this prompt
myllm.set_prompt_from_str("Tokeniz")

# Load a sampler that proposes tokens by sampling directly from the LM's distribution
token_sampler = direct_token_sampler(myllm)

# Run SMC with 5 particles, a maximum of 25 tokens, and an ESS threshold of 0.5
sequences = await token_sampler.smc(n_particles=10, max_tokens=25, ess_threshold=0.5)

# Show the posterior over token sequences
sequences.posterior

# Show the posterior over complete UTF-8 decodable sequences
sequences.decoded_posterior



0,1
key,value
r,0.42856975119565044
io,0.2857117677543588
r_1_0,0.14285979563810608
r = 1,0.14285868541188462


In [43]:
canonical_potential = Canonical(tokenizer, "gpt2")
product = myllm * canonical_potential
token_sampler = direct_token_sampler(product)
sequences = await token_sampler.smc(n_particles=5, max_tokens=25, ess_threshold=0.5, verbosity=1)
sequences.decoded_posterior


adding override b'\n' <-> b'\n'
adding override b".'" <-> b's'
0.00:	[0;35m[[0mb'_'[0;35m][0m
0.00:	[0;35m[[0mb','[0;35m][0m
0.00:	[0;35m[[0mEOS[0;35m][0m
0.00:	[0;35m[[0mb'io'[0;35m][0m
0.00:	[0;35m[[0mb'\n'[0;35m][0m
0.00:	[0;35m[[0mb'io'[0;35m|[0mEOS[0;35m][0m
-18.52:	[0;35m[[0mb'\n'[0;35m|[0mEOS[0;35m][0m
-0.00:	[0;35m[[0mb'_'[0;35m|[0mb'429'[0;35m][0m
0.00:	[0;35m[[0mb','[0;35m|[0mb'\n'[0;35m][0m
-0.00:	[0;35m[[0mb'_'[0;35m|[0mb'429'[0;35m|[0mb'48'[0;35m][0m
-19.00:	[0;35m[[0mb','[0;35m|[0mb'\n'[0;35m|[0mb'('[0;35m][0m
-0.00:	[0;35m[[0mb'_'[0;35m|[0mb'429'[0;35m|[0mb'48'[0;35m|[0mb'79'[0;35m][0m
-19.00:	[0;35m[[0mb','[0;35m|[0mb'\n'[0;35m|[0mb'('[0;35m|[0mb'uint'[0;35m][0m
-0.02:	[0;35m[[0mb'_'[0;35m|[0mb'429'[0;35m|[0mb'48'[0;35m|[0mb'79'[0;35m|[0mb'166'[0;35m][0m
-19.00:	[0;35m[[0mb','[0;35m|[0mb'\n'[0;35m|[0mb'('[0;35m|[0mb'uint'[0;35m|[0mb')'[0;35m][0m
-0.02:	[0;35m[[0mb'_

0,1
key,value
,0.5000009714158655
io,0.4999990240421109
,4.542023496616819e-09


With the prompt "Montreal is", generation under the canonical potential looks fine.

In [44]:
myllm.set_prompt_from_str("Montreal is")
product = myllm * canonical_potential
token_sampler = direct_token_sampler(product)
sequences = await token_sampler.smc(n_particles=5, max_tokens=25, ess_threshold=0.5, verbosity=1)
sequences.decoded_posterior

0.00:	[0;35m[[0mb' in'[0;35m][0m
0.00:	[0;35m[[0mb' the'[0;35m][0m
0.00:	[0;35m[[0mb' the'[0;35m][0m
0.00:	[0;35m[[0mb' a'[0;35m][0m
0.00:	[0;35m[[0mb' a'[0;35m][0m
0.00:	[0;35m[[0mb' in'[0;35m|[0mb' a'[0;35m][0m
0.00:	[0;35m[[0mb' the'[0;35m|[0mb' only'[0;35m][0m
0.00:	[0;35m[[0mb' the'[0;35m|[0mb' only'[0;35m][0m
0.00:	[0;35m[[0mb' a'[0;35m|[0mb' city'[0;35m][0m
0.00:	[0;35m[[0mb' a'[0;35m|[0mb' city'[0;35m][0m
0.00:	[0;35m[[0mb' in'[0;35m|[0mb' a'[0;35m|[0mb' similar'[0;35m][0m
0.00:	[0;35m[[0mb' the'[0;35m|[0mb' only'[0;35m|[0mb' city'[0;35m][0m
0.00:	[0;35m[[0mb' the'[0;35m|[0mb' only'[0;35m|[0mb' city'[0;35m][0m
0.00:	[0;35m[[0mb' a'[0;35m|[0mb' city'[0;35m|[0mb' of'[0;35m][0m
0.00:	[0;35m[[0mb' a'[0;35m|[0mb' city'[0;35m|[0mb' with'[0;35m][0m
0.00:	[0;35m[[0mb' in'[0;35m|[0mb' a'[0;35m|[0mb' similar'[0;35m|[0mb' situation'[0;35m][0m
0.00:	[0;35m[[0mb' the'[0;35m|[0mb' only'[0

0,1
key,value
the only city in Canada where there are no minimum wage laws,0.3333356962071329
a city with a history of racism and racism against black people,0.33333545001935444
"in a similar situation to Montreal, where the city's budget is $5",0.3333288537735127


Next, try combining the canonical potential with other potentials: the FSA and JSON-schema potentials.