In [None]:
from genlm.control import SMC
from genlm.control.sampler import DirectTokenSampler, MultiTokenUnitSampler
from genlm.control.sampler.unit import BoundaryPredicate, TokenSetBoundary
from genlm.control.potential import Potential
from genlm.control.potential.built_in import PromptedLLM
from genlm.control.constant import EOS

class GoodWordsCritic(Potential):
    def __init__(self, vocab, good_words):
        super().__init__(vocabulary=vocab, eos=EOS)
        self.good_words = set(good_words)
    
    async def prefix(self, context):
        if any(word in context for word in self.good_words):
            return -0.3
        else:
            return -0.5
    
    async def complete(self, context):
        return 0


async def example_gpt2_units(llm, max_subunits=10):
    llm.set_prompt_from_str("Once upon a time")
    print("Each 'unit' is a complete sentence ending with . ! or ?\n")
    
    subunit_sampler = DirectTokenSampler(llm)
    boundary = TokenSetBoundary(set([v for v in llm.vocab if v.endswith(b"t")or v == EOS]))
    
    unit_sampler = MultiTokenUnitSampler(
        subunit_sampler=subunit_sampler,
        boundary_predicate=boundary,
        max_subunits_per_unit=max_subunits,
    )

    critic = GoodWordsCritic(llm.vocab, ["hello", "world"])
    sequences = await SMC(unit_sampler, critic=critic)(
        n_particles=3,
        ess_threshold=0.8,
        max_tokens=3,
        verbosity=2
    )

In [14]:
llm = PromptedLLM.from_name("gpt2")

await example_gpt2_units(llm)

Task was destroyed but it is pending!
task: <Task cancelling name='Task-2334' coro=<AsyncTokenByteTrie._background_loop() running at /opt/miniconda3/envs/gen/lib/python3.12/site-packages/genlm/bytes/trie.py:485> wait_for=<Future cancelled>>
Task was destroyed but it is pending!
task: <Task cancelling name='Task-2335' coro=<AsyncTokenByteTrie._background_loop() running at /opt/miniconda3/envs/gen/lib/python3.12/site-packages/genlm/bytes/trie.py:485> wait_for=<Future cancelled>>


Example 1: GPT-2 with Multi-Token Unit Sampling
--- Example 1a: Sentence-Level Units ---

Each 'unit' is a complete sentence ending with . ! or ?

Particles: [0.00:	[0;35m[[0m[0;35m][0m, 0.00:	[0;35m[[0m[0;35m][0m, 0.00:	[0;35m[[0m[0;35m][0m]
-0.50:	[0;35m[[0mb'␣they',␣b'␣reached',␣b'␣a',␣b'␣campaign',␣b'␣position',␣b'␣in',␣b'␣2018',␣b'␣by',␣b'␣Peter',␣b'␣to'[0;35m][0m
-0.50:	[0;35m[[0mb',',␣b'␣seven',␣b'␣times',␣b'␣a',␣b'␣day',␣b',',␣b'␣blind',␣b'fold',␣b'ed',␣b'␣men'[0;35m][0m
-0.50:	[0;35m[[0mb'␣life',␣b'␣is',␣b'␣as',␣b'␣mine',␣b'␣in',␣b'␣Iz',␣b'la',␣b':',␣b'␣clothing',␣b'␣is'[0;35m][0m
Particles: [0.00:	[0;35m[[0mb'␣they',␣b'␣reached',␣b'␣a',␣b'␣campaign',␣b'␣position',␣b'␣in',␣b'␣2018',␣b'␣by',␣b'␣Peter',␣b'␣to'[0;35m][0m, 0.00:	[0;35m[[0mb',',␣b'␣seven',␣b'␣times',␣b'␣a',␣b'␣day',␣b',',␣b'␣blind',␣b'fold',␣b'ed',␣b'␣men'[0;35m][0m, 0.00:	[0;35m[[0mb'␣life',␣b'␣is',␣b'␣as',␣b'␣mine',␣b'␣in',␣b'␣Iz',␣b'la',␣b':',␣b'␣clothing',␣b'␣is'[0;35m][0m]
-

# Test Byte Potential

In [15]:
from genlm.control.potential.built_in.llm import load_model_by_name
from genlm.bytes import BeamParams
from genlm.control.potential.built_in import ByteLLM

def build_bytelm(llm_name):
    llm = load_model_by_name("gpt2", backend="hf")
    beam_params = BeamParams(
        K=5,
        prune_threshold=0.0,
    )
    return ByteLLM(llm, beam_params)

byte_llm = build_bytelm("gpt2")
await example_gpt2_units(byte_llm, max_subunits=20)


Example 1: GPT-2 with Multi-Token Unit Sampling
--- Example 1a: Sentence-Level Units ---

Each 'unit' is a complete sentence ending with . ! or ?

Particles: [0.00:	[0;35m[[0m[0;35m][0m, 0.00:	[0;35m[[0m[0;35m][0m, 0.00:	[0;35m[[0m[0;35m][0m]


Task was destroyed but it is pending!
task: <Task cancelling name='Task-2336' coro=<AsyncTokenByteTrie._background_loop() running at /opt/miniconda3/envs/gen/lib/python3.12/site-packages/genlm/bytes/trie.py:485> wait_for=<Future cancelled>>


-0.50:	[0;35m[[0mb'␣',␣b't'[0;35m][0m
-0.50:	[0;35m[[0mb'␣',␣b'i',␣b'n',␣b'␣',␣b'N',␣b'o',␣b'r',␣b't'[0;35m][0m
-0.50:	[0;35m[[0mb',',␣b'␣',␣b'r',␣b'u',␣b'n',␣b'n',␣b'i',␣b'n',␣b'g',␣b'␣',␣b'S',␣b'l',␣b'a',␣b'c',␣b'k',␣b'␣',␣b'w',␣b'a',␣b's',␣b'␣'[0;35m][0m
Particles: [0.00:	[0;35m[[0mb'␣',␣b'i',␣b'n',␣b'␣',␣b'N',␣b'o',␣b'r',␣b't'[0;35m][0m, 0.00:	[0;35m[[0mb'␣',␣b't'[0;35m][0m, -0.00:	[0;35m[[0mb',',␣b'␣',␣b'r',␣b'u',␣b'n',␣b'n',␣b'i',␣b'n',␣b'g',␣b'␣',␣b'S',␣b'l',␣b'a',␣b'c',␣b'k',␣b'␣',␣b'w',␣b'a',␣b's',␣b'␣'[0;35m][0m]
-0.50:	[0;35m[[0mb'␣',␣b't'[0;35m|[0mb'h',␣b'e',␣b'r',␣b'e',␣b'␣',␣b'w',␣b'e',␣b'r',␣b'e',␣b'␣',␣b'n',␣b'o',␣b'␣',␣b's',␣b't'[0;35m][0m
-0.50:	[0;35m[[0mb'␣',␣b'i',␣b'n',␣b'␣',␣b'N',␣b'o',␣b'r',␣b't'[0;35m|[0mb'h',␣b'␣',␣b'A',␣b'm',␣b'e',␣b'r',␣b'i',␣b'c',␣b'a',␣b'n',␣b'␣',␣b'b',␣b'a',␣b's',␣b'e',␣b'b',␣b'a',␣b'l',␣b'l',␣b'␣'[0;35m][0m
-0.50:	[0;35m[[0mb',',␣b'␣',␣b'r',␣b'u',␣b'n',␣b'n',␣b'i',␣b'n',␣b'g',␣b'␣',␣b'S',␣b'l',␣b'a',␣