In [62]:
from typing import List
from transformers import LlamaTokenizerFast

class TrieNode:
    """One node of the token trie.

    ``children`` maps a token id to the node reached by following that token.
    ``is_leaf`` starts out False; ``TokenTrie._insert_tokens`` sets it to True
    on the node that ends an inserted token sequence.
    """

    def __init__(self):
        # A fresh node is non-terminal and has no outgoing edges yet.
        self.is_leaf = False
        self.children: dict[int, TrieNode] = {}

class TokenTrie:
    """Prefix tree over tokenizer token ids for a fixed list of words.

    Each inserted word is stored as the token ids produced by
    ``tokenizer.encode(word, add_special_tokens=False)`` followed by the
    tokenizer's ``eos_token_id``, so every stored path ends with an EOS edge.
    """

    def __init__(self, tokenizer: LlamaTokenizerFast):
        self.tokenizer = tokenizer
        self.root = TrieNode()

    def init_from_words(self, words: List[str]):
        """Insert every entry of ``words``, each prefixed with one space."""
        for entry in words:
            # NOTE(review): the leading space presumably mirrors how a word is
            # tokenized when it follows the prompt; confirm for this tokenizer.
            self.insert_word(" " + entry)

    def insert_word(self, word: str):
        """Tokenize ``word`` without special tokens, append EOS, store the path."""
        encoded = self.tokenizer.encode(word, add_special_tokens=False)
        self._insert_tokens([*encoded, self.tokenizer.eos_token_id])

    def walk(self, token_ids: List[int]):
        """Follow ``token_ids`` from the root.

        Returns the node reached, or None as soon as a token has no matching
        child. An empty ``token_ids`` returns the root.
        """
        node = self.root
        for token_id in token_ids:
            node = node.children.get(token_id)
            if node is None:
                return None
        return node

    def _insert_tokens(self, token_ids: List[int]):
        """Create any missing nodes along ``token_ids`` and mark the last as a leaf."""
        node = self.root
        for token_id in token_ids:
            child = node.children.get(token_id)
            if child is None:
                child = TrieNode()
                node.children[token_id] = child
            node = child
        node.is_leaf = True

In [12]:
from transformers import LogitsProcessor
import torch
class TrieLogitsProcessor(LogitsProcessor):
    """Logits processor that restricts generation to paths stored in a TokenTrie.

    For every row of ``input_ids`` the tokens after the first ``prompt_len``
    positions are walked through the trie. Only the token ids that are children
    of the node reached keep their score; all others are set to -inf. When the
    walk falls off the trie, or reaches a node without children, only the
    tokenizer's EOS token id is allowed.
    """

    def __init__(self, trie: TokenTrie, prompt_len: int):
        self.prompt_len = prompt_len
        self.trie = trie

    def __call__(self, input_ids, scores):
        # Start with everything forbidden; permitted entries are copied over below.
        constrained = torch.full_like(scores, float("-inf"))

        # Rows are handled independently (input_ids is indexed as [row, position]).
        for row, sequence in enumerate(input_ids):
            suffix = sequence[self.prompt_len:].tolist()
            node = self.trie.walk(suffix)
            if node and node.children:
                permitted = list(node.children)
            else:
                # Off the trie or at the end of a stored path: force EOS.
                permitted = [self.trie.tokenizer.eos_token_id]
            constrained[row, permitted] = scores[row, permitted]

        return constrained

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer
# Checkpoint name passed to both from_pretrained calls below.
model_name = "eryk-mazus/polka-1.1b"

# `tokenizer` and `model` are module-level globals; generate_answers reads them
# directly instead of taking them as parameters.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [51]:
def generate_answers(question, trie, K, max_new_tokens=15):
    """Generate up to ``K`` distinct trie-constrained continuations of a prompt.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        question: Prompt text; the answer is whatever is generated after it.
        trie: TokenTrie handed to TrieLogitsProcessor to constrain the output.
        K: Passed as both ``num_beams`` and ``num_return_sequences``.
        max_new_tokens: Upper bound on generated tokens per sequence. Defaults
            to 15, the value that used to be hard-coded.

    Returns:
        List of decoded answers (special tokens skipped, whitespace stripped)
        in the order ``model.generate`` returned them, duplicates removed with
        the first occurrence kept. It can therefore be shorter than ``K``.
    """
    inputs = tokenizer(question, return_tensors="pt")
    input_ids = inputs.input_ids.to(model.device)
    prompt_len = input_ids.shape[1]

    # pad_token_id is set to eos_token_id below, so generate() cannot tell
    # padding from a real EOS by looking at the ids. Pass the tokenizer's mask
    # explicitly instead of leaving it to be guessed.
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)

    processor = TrieLogitsProcessor(trie, prompt_len=prompt_len)
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        logits_processor=[processor],
        num_beams=K,
        num_return_sequences=K,
        # NOTE(review): top_k is a sampling option; it presumably has no effect
        # unless sampling is enabled in the generation config. Confirm and drop.
        top_k=50,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        early_stopping=True,
    )

    # Decode only the part after the prompt.
    texts = (
        tokenizer.decode(output[prompt_len:], skip_special_tokens=True).strip()
        for output in outputs
    )
    # dict.fromkeys keeps first occurrences in order, i.e. an ordered dedup.
    return list(dict.fromkeys(texts))


In [52]:
# Smoke test of constrained generation on a tiny hand-written vocabulary.
# `allowed_words` and `trie` are rebound further down when the riddle
# vocabulary is loaded, so these values only matter for this cell.
# The Polish literals are written with their proper diacritics; mis-encoded
# text here would change both the trie contents and the prompt.
allowed_words = ["Kraków", "Warszawa", "Londyn", "Stachu", "Opole", "Alutek", "Wawel", "Wieża Eiffela", "Monke"]
trie = TokenTrie(tokenizer=tokenizer)
trie.init_from_words(allowed_words)

test_question = "Nazwa słynnego zamku w Krakowie:"

ans = generate_answers(test_question, trie, K=7)
print(ans)

['Wawel', 'Kraków', 'Stachu', 'Opole', 'Londyn']


## Riddles dataset and evaluation

In [54]:
from tqdm import tqdm
# Relative path of the directory holding the riddle data files read in this cell.
path_to_data = '../../datasets/riddles'

# Accumulators filled by the loading code further down in this cell.
# bases: lower-cased word -> base form, read from superbazy_clean.txt.
bases = {}
# allowed_words_set: single-word entries read from plwiktionary_definitions_clean.txt.
allowed_words_set = set()
# allowed_words: list version of allowed_words_set (reassigned after loading).
allowed_words = []
# answers[i] is the expected answer for queries[i] (zagadki_do_testow_clean.txt).
answers = []
queries = []

def get_word_base(word):
    """Return the base form stored in ``bases`` for ``word``.

    The lookup is done on the lower-cased word. If ``bases`` has no entry for
    it (or the entry is empty), the lower-cased word itself is returned.
    """
    lowered = word.lower()
    return bases.get(lowered) or lowered

# Load the word -> base-form dictionary consulted by get_word_base.
# A context manager closes the file handle, and the encoding is explicit so
# the Polish text is read the same way on every platform.
with open(f'{path_to_data}/superbazy_clean.txt', encoding='utf-8') as f:
    for x in f:
        fields = x.lower().split()
        if not fields:
            # Blank line: nothing to unpack. Malformed non-blank lines still
            # raise ValueError on the unpacking below, as before.
            continue
        word, base = fields
        bases[word] = base

print("Loading allowed vocabulary...")
with open(f'{path_to_data}/plwiktionary_definitions_clean.txt', 'r', encoding='utf-8') as f:
    for line in f:
        # The headword is everything before the first '###' separator.
        word = line.split('###')[0].strip()
        # Keep single-word headwords only. The `word` check stops a blank line
        # from adding the empty string to the vocabulary.
        if word and ' ' not in word:
            allowed_words_set.add(word)

# Sorted so the list order does not depend on set iteration order.
allowed_words = sorted(allowed_words_set)
print(f"Loaded {len(allowed_words)} unique allowed words.")

# Each riddle line: first whitespace-separated field is the answer, the rest
# (with ';;' removed) is the riddle text.
with open(f'{path_to_data}/zagadki_do_testow_clean.txt', encoding='utf-8') as file:
    for line in file:
        line = line.replace(';;', '').split()
        if not line:
            continue  # a blank line would otherwise raise IndexError
        answers.append(line[0])
        queries.append(' '.join(line[1:]))

Loading allowed vocabulary...
Loaded 8085 unique allowed words.


In [55]:
def mean_reciprocal_rank(real_answers, computed_answers, K=20):
    """Compute, print and return the mean reciprocal rank.

    Args:
        real_answers: Sequence of expected answers.
        computed_answers: For each expected answer, the ranked list of
            candidates (best first).
        K: Only the first ``K`` candidates of each list are considered.

    Returns:
        Sum of ``1 / rank`` over the answers found within the first ``K``
        candidates, divided by ``len(real_answers)``. Misses contribute 0.
        Returns 0.0 when ``real_answers`` is empty (previously a
        ZeroDivisionError).
    """
    total = 0.0
    for real_answer, computed_answer in zip(real_answers, computed_answers):
        top_candidates = computed_answer[:K]
        if real_answer in top_candidates:
            total += 1 / (top_candidates.index(real_answer) + 1)

    n_answers = len(real_answers)
    mrr = total / n_answers if n_answers else 0.0
    print('Mean Reciprocal Rank =', mrr)

    return mrr

def evaluate_algorithm(score_function, queries, answers, K):
    """Run ``score_function`` on every query and score the results with MRR.

    ``score_function`` is called as ``score_function(query, K=K)`` and its
    return value is used as the ranked candidate list for that query. The
    collected lists are passed, together with ``answers`` and the same ``K``,
    to ``mean_reciprocal_rank``, whose result is returned.
    """
    predictions = [
        score_function(query, K=K)
        for query in tqdm(queries, desc="queries answered")
    ]
    return mean_reciprocal_rank(answers, predictions, K=K)

In [69]:
# Rebuild the global `trie` from the loaded riddle vocabulary; answer_riddle
# reads this global, so it must be built before any riddle is answered.
print(f"Creating trie from {len(allowed_words)} words")
trie = TokenTrie(tokenizer=tokenizer)
trie.init_from_words(allowed_words)
print(f"Done")

Creating trie from 8085 words
Done


In [70]:
def answer_riddle(riddle, K):
    """Answer one riddle with trie-constrained generation.

    Wraps ``riddle`` in the "Zagadka: ... / Odpowiedź:" prompt template and
    passes it to ``generate_answers`` with the module-level ``trie``.

    Args:
        riddle: Riddle text, without any prompt formatting.
        K: Forwarded to ``generate_answers``.

    Returns:
        Whatever ``generate_answers`` returns for the formatted prompt.
    """
    # The answer cue must be spelled with a real "ź"; a mis-encoded literal
    # would feed garbage characters to the model right before the answer.
    prompt = f"Zagadka: {riddle}\nOdpowiedź:"
    return generate_answers(prompt, trie, K)

In [74]:
import random

# Spot-check a single randomly chosen riddle. No seed is set, so the index
# differs between runs.
i = random.randint(0, len(queries) - 1)

sample_query = queries[i]
sample_answer = answers[i]

print(f"🔍 TESTING RIDDLE INDEX {i}")
print(f"❓ Prompt: '{sample_query}'")
print(f"🎯 Real Answer: '{sample_answer}'")

# answer_riddle applies the prompt template itself, so the raw riddle text is
# passed here.
preds = answer_riddle(sample_query, K=20)

print(f"\n🤖 Model Predictions (Top {len(preds)}):")
print(preds)

# Report the 1-based rank of the expected answer, if it was produced at all.
if sample_answer in preds:
    rank = preds.index(sample_answer) + 1
    print(f"\n✅ SUCCESS! Found at rank #{rank}")
    print(f"   Score contribution: {1/rank:.4f}")
else:
    print("\n❌ FAILURE. Correct answer not found.")

🔍 TESTING RIDDLE INDEX 229
❓ Prompt: 'narzędzie lub przedmiot noszony na twarzy lub głowie, służący do zasłaniania lub ochrony, często stosowany w celach higienicznych, medycznych, dekoracyjnych lub w celach ochrony przed szkodliwymi substancjami.'
🎯 Real Answer: 'maska'

🤖 Model Predictions (Top 20):
['biżuteria', 'higiena', 'narzędzie', 'zabawka', 'zagadka', 'przyprawa', 'artykulacja', 'maseczka', 'zabójca', 'przypadek', 'zabójstwo', 'gąbka', 'medycyna', 'obuwie', 'skarbnik', 'zabieganie', 'plastik', 'gumka', 'zamek', 'zabieg']

❌ FAILURE. Correct answer not found.


In [73]:
# Evaluate on the first PART_OF_DATA riddles only, not the full list.
PART_OF_DATA = 100
# K reaches answer_riddle (number of beams / returned candidates) and is also
# the rank cutoff used by mean_reciprocal_rank.
K = 20
valid_queries = queries[:PART_OF_DATA]
valid_answers = answers[:PART_OF_DATA]
score = evaluate_algorithm(answer_riddle, valid_queries, valid_answers, K=K)
print(f"Score: {score}")

queries answered:   0%|          | 0/100 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
queries answered: 100%|██████████| 100/100 [04:24<00:00,  2.64s/it]

Mean Reciprocal Rank = 0.0645901686033265
Score: 0.0645901686033265



