In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.nn import functional as F

In [3]:
import os

In [2]:
# Hugging Face checkpoint of the causal LM used to score candidate sentences.
model_name = "eryk-mazus/polka-1.1b"
# Device for the model; sentence_prob moves its input ids to this same device.
device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the causal-LM weights and place them on `device`.
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)


def log_probs_from_logits(logits, labels):
    """Return the log-probability assigned to each label by ``logits``.

    Args:
        logits: unnormalized scores whose LAST dimension is the vocabulary,
            e.g. (batch, seq, vocab) from a causal LM, but any rank works.
        labels: integer (int64) tensor with the shape of ``logits`` minus its
            last dimension; each entry is an index into the vocabulary.

    Returns:
        Tensor with the same shape as ``labels`` holding
        ``log_softmax(logits)[..., label]`` for every position.
    """
    logp = F.log_softmax(logits, dim=-1)
    # Gather along the vocabulary axis (last dim) so any number of leading
    # dimensions is supported, not just (batch, seq).
    return torch.gather(logp, -1, labels.unsqueeze(-1)).squeeze(-1)


def sentence_prob(sentence_txt):
    """Score a sentence with the global causal LM.

    The text is tokenized, and every token after the first is scored given its
    prefix; the per-token log-probabilities are summed. The first token's own
    probability is not included (NOTE(review): presumably a BOS token added by
    the tokenizer; confirm for this checkpoint).

    Returns:
        0-d numpy array holding the total log-probability (higher is better).
    """
    encoded = tokenizer(sentence_txt, return_tensors="pt")
    ids = encoded["input_ids"].to(device)
    with torch.no_grad():
        logits = model(input_ids=ids).logits
        # Shift by one: logits at position t predict the token at t + 1.
        token_logps = log_probs_from_logits(logits[:, :-1, :], ids[:, 1:])
        total = token_logps.sum()
    return total.cpu().numpy()


In [12]:
def parse_variant_line(sentence: str):
    """Split one dataset line into a list of per-position variant lists.

    Positions are separated by whitespace, and alternatives at a position are
    separated by '|'. For example, "a|b c" -> [['a', 'b'], ['c']].
    """
    # str.split() with no argument already discards empty fields.
    return [group.split('|') for group in sentence.split()]

def load_sentences(file_path: str):
    """Read a UTF-8 variant dataset and parse it line by line.

    Each non-blank line is parsed with ``parse_variant_line`` into a list of
    variant groups. Blank or whitespace-only lines are skipped. Otherwise they
    would become empty sentences (``[]``), which have no words to score and
    make per-sentence accuracy divide by zero.

    Returns:
        List of parsed sentences, in file order.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return [parse_variant_line(line) for line in f if line.strip()]

# Dataset: one sentence per line; positions separated by whitespace and
# alternative word forms separated by '|'. The evaluation below treats the
# first alternative of each group as the reference form.
# NOTE(review): relative path, so this assumes the notebook runs two levels
# below the directory containing `datasets/`; confirm or make configurable.
dataset_path = os.path.join("..", "..", "datasets", "p2", "zdania_z_wariantami.txt")
parsed_sentences = load_sentences(dataset_path)

In [18]:
# Sanity check: number of loaded sentences, then the variant groups of the
# first sentence, one group per line.
print(f"{len(parsed_sentences)} sentences total")
example = parsed_sentences[0]
for group in example:
    print(group)

300 sentences total
['parlament', 'parlamentem', 'parlamentów', 'parlamenty', 'parlamencie']
['zdecydował', 'zdecyduje', 'zdecydowałaby', 'zdecydujecie', 'zdecydowali', 'zdecydowała']
['jednak']
['inaczej']
['i']
['przyjął', 'przyjąć', 'przyjąłby', 'przyjmiecie', 'przyjęła', 'przyjęli']
['w']
['ustawie', 'ustawom', 'ustawami', 'ustawa', 'ustawą', 'ustawach']
['z']
['dnia', 'dzień', 'dniom', 'dniami', 'dni', 'dniach']
['28.06.1996']
['r.']
['jednoinstancyjne', 'jednoinstancyjnym', 'jednoinstancyjny', 'jednoinstancyjnego']
['postępowanie', 'postępowaniom', 'postępowania', 'postępowaniu', 'postępowań']
['orzeczniczo-lekarskie']
['.']


In [15]:
def beam_search_simple(variant_sentence, beam_width=3, score_fn=None):
    """Choose one variant per position with a left-to-right beam search.

    At each position, every surviving prefix is extended with every variant of
    that position. Each extended prefix is scored as a whole space-joined
    string, and only the ``beam_width`` best are kept.

    Args:
        variant_sentence: list of groups, each a non-empty list of alternative
            words (as produced by ``parse_variant_line``).
        beam_width: number of prefixes kept after each position (>= 1);
            ``None`` keeps every candidate (exhaustive, exponential cost).
        score_fn: callable mapping a prefix string to a sortable score, where
            higher is better. Defaults to ``sentence_prob`` (LM log-probability).

    Returns:
        The best-scoring full sentence with words joined by single spaces, or
        "" when ``variant_sentence`` is empty.

    Raises:
        ValueError: if ``beam_width`` < 1 or some group has no variants.
    """
    if score_fn is None:
        score_fn = sentence_prob
    if beam_width is not None and beam_width < 1:
        # Otherwise the beam empties and beam[0] fails with an opaque IndexError.
        raise ValueError(f"beam_width must be >= 1, got {beam_width}")

    beam = [("", 0.0)]
    for position, word_variants in enumerate(variant_sentence):
        if not word_variants:
            raise ValueError(f"no variants at position {position}")
        candidates = []
        for path_text, _ in beam:
            for variant in word_variants:
                new_sentence = f"{path_text} {variant}" if path_text else variant
                candidates.append((new_sentence, score_fn(new_sentence)))
        # list.sort is stable, so tied scores keep their generation order.
        candidates.sort(key=lambda cand: cand[1], reverse=True)
        beam = candidates[:beam_width]

    return beam[0][0]

In [22]:
# Evaluate beam search on the first N sentences. The first variant of every
# group is the reference form, so a predicted word is correct when it equals
# parsed_sentence[j][0].
N = 5
debug = False
results = []
for i, parsed_sentence in enumerate(parsed_sentences[:N]):
    solution = beam_search_simple(parsed_sentence, beam_width=3)
    predicted_words = solution.split()

    print(f"Sentence {i+1}")
    if debug:
        print("Sentence with variants:")
        for group in parsed_sentence:
            print(group)
        print(f"Prediction: {solution}")

    # Variants contain no whitespace (they come from str.split), so predicted
    # words line up one-to-one with the variant groups.
    total_words = len(predicted_words)
    correct_words = sum(
        pred_word == variants[0]
        for pred_word, variants in zip(predicted_words, parsed_sentence)
    )
    # An empty sentence has no defined accuracy; avoid ZeroDivisionError.
    accuracy = correct_words / total_words if total_words else float("nan")

    print(f"Result {correct_words}/{total_words} = {accuracy:.2%}")
    print(30*"=")

    results.append((correct_words, total_words, accuracy))


Sentence 1
Result 16/16 = 16/16
Sentence 2
Result 37/43 = 37/43
Sentence 3
Result 10/10 = 10/10
Sentence 4
Result 18/18 = 18/18
Sentence 5
Result 18/18 = 18/18
