In [29]:
from typing import List
from transformers import LlamaTokenizerFast

class TrieNode:
    def __init__(self):
        self.children: dict[int, TrieNode] = {}
        self.is_leaf = False

class TokenTrie:
    def __init__(self, tokenizer: LlamaTokenizerFast):
        self.root = TrieNode()
        self.tokenizer = tokenizer

    def init_from_words(self, words: List[str]):
        for word in words:
            self.insert_word(" " + word)

    def insert_word(self, word: str):
        tokens = self.tokenizer.encode(word, add_special_tokens=False)
        tokens.append(self.tokenizer.eos_token_id)
        self._insert_tokens(tokens)

    def walk(self, token_ids: List[int]):
        cur = self.root
        for id in token_ids:
            if id not in cur.children:
                raise ValueError(f"Token {id} not in trie children at node {cur}")
            cur = cur.children[id]
        return cur

    def _insert_tokens(self, token_ids: List[int]):
        cur = self.root
        for id in token_ids:
            if id not in cur.children:
                cur.children[id] = TrieNode()
            cur = cur.children[id]
        cur.is_leaf = True

In [30]:
from transformers import LogitsProcessor
import torch
class TrieLogitsProcessor(LogitsProcessor):
    def __init__(self, trie: TokenTrie, prompt_len: int):
        self.trie = trie
        self.prompt_len = prompt_len

    def __call__(self, input_ids, scores):
        masked_scores = torch.full_like(scores, -float("inf"))

        batch_size = input_ids.shape[0]
        for i in range(batch_size):
            generated_tokens = input_ids[i, self.prompt_len:].tolist()
            cur_node = self.trie.walk(generated_tokens)
            allowed_tokens = list(cur_node.children.keys())

            if not allowed_tokens:
                raise RuntimeError("No possible continuation")
            masked_scores[i, allowed_tokens] = scores[i, allowed_tokens]

        return masked_scores

In [32]:
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "eryk-mazus/polka-1.1b"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [51]:
allowed_words = ["Kraków", "Warszawa", "Londyn", "Stachu", "Opole", "Alutek", "Wawel", "Wieża Eiffela", "Monke"]
trie = TokenTrie(tokenizer=tokenizer)
trie.init_from_words(allowed_words)

prompt = "Nazywa słynnego zamku w Krakowie:"
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs.input_ids.to(model.device)
prompt_len = input_ids.shape[1]

processor = TrieLogitsProcessor(trie, prompt_len=prompt_len)


outputs = model.generate(
    input_ids,
    max_new_tokens=10,
    logits_processor=[processor],
    num_beams=3,
    top_k=50,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id
)

decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(decoded)

RuntimeError: No possible continuation