In [1]:
from genlm_control import InferenceEngine
from genlm_control.potential import PromptedLLM, BoolFSA, Potential
from genlm_control.sampler import direct_token_sampler, eager_token_sampler

# Sampling from a language model

In [2]:
# Load gpt2 (or any other HuggingFace model) using the HuggingFace backend.
# (Setting backend='vllm' will be much faster, but requires a GPU).
mtl_llm = PromptedLLM.from_name("gpt2", backend="hf", temperature=0.5)

INFO 02-26 08:49:15 __init__.py:183] Automatically detected platform cuda.




In [3]:
# Set the fixed prompt prefix for the language model.
# All language model predictions will be conditioned on the
# token ids which this string encodes to (via the LM's tokenizer).
mtl_llm.set_prompt_from_str("Montreal is")

In [4]:
# Load a sampler that proposes tokens by sampling directly
# from the language model's distribution.
sampler = direct_token_sampler(mtl_llm)

In [5]:
# Create an inference engine.
engine = InferenceEngine(sampler)

In [6]:
# Run SMC with 10 particles, a max sequence length of 25 tokens
# and an ESS threshold of 0.5.
sequences = await engine(n_particles=10, max_tokens=10, ess_threshold=0.5)

In [7]:
# Get the inferred posterior distribution over sequences.
sequences.posterior

0,1
key,value
"(b' a', b' city', b' of', b' a', b' thousand', b' people', b',', b' and', b' it', b""'s"")",0.10000062587195321
"(b' set', b' to', b' launch', b' a', b' new', b' tech', b' lab', b' in', b' its', b' own')",0.10000034590320922
"(b' a', b' city', b' with', b' a', b' high', b' number', b' of', b' businesses', b',', b' but')",0.10000022453141078
"(b' about', b' to', b' get', b' a', b' new', b' name', b' for', b' itself', b'.', b' The')",0.1000001402314081
"(b' a', b' city', b' of', b' about', b' 5', b' million', b' people', b'.', b' It', b' is')",0.10000006956968781
"(b' the', b' best', b' city', b' in', b' the', b' world', b' for', b' young', b' people', b'.')",0.09999994869131731
"(b' one', b' of', b' the', b' few', b' cities', b' in', b' Canada', b' that', b' offers', b' a')",0.09999975226188743
"(b' the', b' first', b' city', b' in', b' Canada', b' to', b' ban', b' smoking', b' in', b' public')",0.099999743833424
"(b' one', b' of', b' the', b' most', b' popular', b' destinations', b' in', b' Quebec', b',', b' and')",0.09999965745613204


# Prompt intersection

In [8]:
# Spawn a new language model. This is shallow copy, so both models
# share the same underlying language model.
bos_llm = mtl_llm.spawn()
# Set a different prompt for the new language model.
bos_llm.set_prompt_from_str("Boston is")

In [9]:
# Take the product of the two language models.
product = mtl_llm * bos_llm

In [10]:
# Load a token sampler that samples next tokens directly from the
# product of the two language models.
sampler = direct_token_sampler(product)

In [11]:
# Create an inference engine.
engine = InferenceEngine(sampler)

In [12]:
# Run the inference engine for 10 particles with a max sequence length of 25 tokens
# and an ESS threshold of 0.5.
sequences = await engine(n_particles=10, max_tokens=10, ess_threshold=0.5)

In [13]:
sequences.posterior

0,1
key,value
"(b' the', b' only', b' city', b' in', b' the', b' country', b' that', b' doesn', b""'t"", b' have')",0.5
"(b' a', b' great', b' place', b' to', b' live', b'.', b' It', b""'s"", b' a', b' great')",0.39999999999999997
"(b' a', b' city', b' of', b' about', b' 1', b'.', b'3', b' million', b' people', b',')",0.09999999999999999


# Adding a regex constraint

In [14]:
best_fsa = BoolFSA.from_regex(r"is\sthe\s(best|worst).*")

In [15]:
# The following is valid but will be slow!
# slow_sampler = direct_token_sampler(
#    product * best_fsa.coerce(product, f=b''.join)
# )

# This sampler is much faster.
sampler = eager_token_sampler(product, best_fsa)

  ).to_sparse_csr()


In [16]:
engine = InferenceEngine(sampler)

In [19]:
sequences = await engine(n_particles=10, max_tokens=10, ess_threshold=0.5)

In [20]:
sequences.posterior

0,1
key,value
"(b'is', b' the', b' best', b' team', b' in', b' the', b' league', b'.', b' They', b' have')",0.4888266356714295
"(b'is', b' the', b' best', b' team', b' in', b' the', b' league', b'.', b' They', b""'ve"")",0.24441333869533186
"(b'is', b' the', b' best', b' team', b' in', b' the', b' league', b'.', b' The', b' team')",0.1479736079456642
"(b'is', b' the', b' best', b' team', b' in', b' the', b' league', b',', b' but', b' they')",0.11356924044156694
"(b'is', b' the', b' best', b'-', b'known', b',', b' most', b' successful', b',', b' and')",0.004418511235382777
"(b'is', b' the', b' best', b'-', b'known', b' Canadian', b' team', b' to', b' have', b' won')",0.000798666010624802


## Criticizing with a custom `Potential`

In [21]:
# A custom potential that does sentiment analysis.

import torch
from transformers import (
    DistilBertTokenizer,
    DistilBertForSequenceClassification,
)


class SentimentAnalysis(Potential):
    def __init__(self, model, tokenizer, sentiment="POSITIVE"):
        self.model = model
        self.tokenizer = tokenizer

        self.sentiment_idx = model.config.label2id.get(sentiment, None)
        if self.sentiment_idx is None:
            raise ValueError(f"Sentiment {sentiment} not found in model labels")

        super().__init__(vocabulary=list(range(256)))  # Defined over bytes.

    def _forward(self, contexts):
        strings = [
            bytes(context).decode("utf-8", errors="ignore") for context in contexts
        ]  # Convert bytes to strings.
        inputs = self.tokenizer(strings, return_tensors="pt", padding=True)  # Tokenize.
        with torch.no_grad():
            logits = self.model(**inputs).logits
        return logits.log_softmax(dim=-1)[:, self.sentiment_idx].cpu().numpy()

    async def prefix(self, context):
        return self._forward([context])[0].item()

    async def complete(self, context):
        return self._forward([context])[0].item()

    async def batch_complete(self, contexts):
        return self._forward(contexts)

    async def batch_prefix(self, contexts):
        return self._forward(contexts)


model_name = "distilbert-base-uncased-finetuned-sst-2-english"

sentiment_analysis = SentimentAnalysis(
    model=DistilBertForSequenceClassification.from_pretrained(model_name),
    tokenizer=DistilBertTokenizer.from_pretrained(model_name),
    sentiment="POSITIVE",
)

In [22]:
await sentiment_analysis.prefix(b"so good"), await sentiment_analysis.prefix(b"so bad")

(-0.00015841660206206143, -8.44865894317627)

In [26]:
# Check that our custom potential satisfies the potential contract.
await sentiment_analysis.assert_logw_next_consistency(b"the best", top=5)
await sentiment_analysis.assert_autoreg_fact(b"the best")

In [27]:
# The following is valid but will be slow!
# slow_sampler = eager_token_sampler(
#    iter_potential=product, item_potential=best_fsa * sentiment_analysis
# )

# This setup will be much faster.
sampler = eager_token_sampler(product, best_fsa)
critic = sentiment_analysis.coerce(sampler.target, f=b"".join)
engine = InferenceEngine(sampler, critic=critic)

In [29]:
sequences = await engine(n_particles=10, max_tokens=10, ess_threshold=0.5)

In [30]:
sequences.posterior

0,1
key,value
"(b'is', b' the', b' best', b' team', b' in', b' the', b' league', b'.', b' They', b' have')",0.5140855398885724
"(b'is', b' the', b' best', b' team', b' in', b' the', b' league', b'.', b' They', b' are')",0.1285215094145529
"(b'is', b' the', b' best', b' team', b' in', b' the', b' league', b'.', b' They', b""'ve"")",0.1285213259135228
"(b'is', b' the', b' best', b' team', b' in', b' the', b' league', b',', b' but', b' they')",0.11940118987790568
"(b'is', b' the', b' best', b' team', b' in', b' the', b' league', b',', b' and', b' they')",0.10269826874282272
"(b'is', b' the', b' best', b'-', b'known', b' of', b' the', b' three', b'.', b' He')",0.004838426304815362
"(b'is', b' the', b' best', b'-', b'known', b' and', b' most', b' popular', b' city', b' in')",0.001933739857808242


## Optimizing with autobatching

In [31]:
# This creates a new potential that automatically batches concurrent
# requests to the instance methods (`prefix`, `complete`, `logw_next`)
# and processes them using the batch methods (`batch_complete`, `batch_prefix`, `batch_logw_next`).
critic = critic.to_autobatched()
engine = InferenceEngine(sampler, critic=critic)

In [32]:
sequences = await engine(n_particles=10, max_tokens=10, ess_threshold=0.5)

In [33]:
sequences.posterior

0,1
key,value
"(b'is', b' the', b' best', b' team', b' in', b' the', b' league', b'.', b' They', b""'re"")",0.25002511729300825
"(b'is', b' the', b' best', b' team', b' in', b' the', b' league', b'.', b' The', b' team')",0.15136989929033556
"(b'is', b' the', b' best', b' team', b' in', b' the', b' league', b'.', b' They', b' are')",0.12501213401630581
"(b'is', b' the', b' best', b' team', b' in', b' the', b' league', b'.', b' They', b' have')",0.12501200827576064
"(b'is', b' the', b' best', b' team', b' in', b' the', b' league', b',', b' but', b' the')",0.11615715275068081
"(b'is', b' the', b' best', b' team', b' in', b' the', b' league', b',', b' and', b' they')",0.09989401368109278
"(b'is', b' the', b' best', b' team', b' in', b' the', b' league', b',', b' and', b' the')",0.09989362075757438
"(b'is', b' the', b' best', b'-', b'known', b' and', b' most', b' famous', b' of', b' the')",0.017487806312315345
"(b'is', b' the', b' best', b' place', b' to', b' get', b' a', b' good', b' view', b' of')",0.015148247622926587
