In [127]:
from transformers import pipeline
import re
import random

In [None]:
model_name = "eryk-mazus/polka-1.1b"
generator = pipeline("text-generation", model=model_name, device=0)

In [34]:
newline_token_id = generator.tokenizer.encode("\n", add_special_tokens=False)
newline_token_id

[29871, 13]

In [71]:
def text_to_int(txt: str) -> int | None:
    match = re.search(r'\d+', txt)
    if match:
        return int(match.group())
    return None

def lm_calculator(f, f_symbol: str, data: list[tuple[int, int]], few_shot_prompt: str):
    total = len(data)
    correct_cnt = 0

    incorrect = []
    correct = []
    for a, b in data:
        problem = f"{a} {f_symbol} {b} ="
        prompt = few_shot_prompt + problem
        g = generator(
            prompt,
            pad_token_id=generator.tokenizer.eos_token_id,
            max_new_tokens=10,
            return_full_text=False,
            eos_token_id=13, # stop on a newline
        )[0]["generated_text"]

        result = text_to_int(g)
        answer = f"{problem} {result}"
        if result is not None and result == f(a, b):
            correct_cnt += 1
            correct.append(answer)
        else:
            incorrect.append(answer)

    return correct_cnt, total, incorrect, correct

In [101]:
def add(x, y):
    return x + y

def sub(x, y):
    return x - y


add_prompt = """
Oblicz wynik dodawania:
2 + 2 = 4
4 + 3 = 7
"""

add_prompt_multidigit = """
Oblicz wynik dodawania:
4 + 3 = 7
89 + 12 = 101
50 + 35 =  85
31 + 17 = 48
"""


sub_prompt = """
Oblicz wynik Odejmowania:
2 - 2 = 0
4 - 3 = 1
2 - 5 = -3
"""

sub_prompt_multidigit = """
Oblicz wynik odejmowania:
2 - 2 = 0
4 - 3 = 1
2 - 5 = -3
95 - 31 = 64
18 - 99 = -81
"""

In [110]:
one_digit_data = [(a, b) for a in range(10) for b in range(10)]
two_digit_data_easy = [(a, b) for a in range(30, 51) for b in range(10, 21)]
two_digit_data_hard = [(a, b) for a in range(90, 100) for b in range(30, 41)]
two_digit_data_sub_hard = [(a, b) for a in range(30, 41) for b in range(90, 100)]

In [89]:
correct_cnt, total, incorrect, correct = lm_calculator(f=add, f_symbol='+', data=one_digit_data, few_shot_prompt=add_prompt)

In [90]:
print(correct_cnt, total)
print(incorrect)

98 100
['5 + 8 = 18', '9 + 8 = 18']


In [91]:
correct_cnt, total, incorrect, correct = lm_calculator(f=add, f_symbol='+', data=two_digit_data_easy, few_shot_prompt=add_prompt)

In [92]:
print(correct_cnt, total, correct_cnt/total)
print(incorrect)

190 231 0.8225108225108225
['31 + 17 = 58', '31 + 18 = 50', '32 + 17 = 59', '32 + 18 = 40', '33 + 16 = 59', '33 + 18 = 41', '33 + 19 = 42', '33 + 20 = 26', '34 + 16 = 40', '34 + 17 = 41', '35 + 15 = 40', '39 + 11 = 40', '39 + 12 = 41', '42 + 18 = 50', '42 + 19 = 59', '43 + 17 = 50', '44 + 12 = 52', '44 + 15 = None', '44 + 17 = 51', '44 + 18 = 52', '44 + 19 = 24', '45 + 15 = 59', '45 + 16 = 51', '45 + 19 = None', '46 + 14 = 50', '46 + 16 = 20', '46 + 18 = 26', '47 + 13 = 50', '47 + 14 = 51', '47 + 18 = None', '48 + 12 = 50', '48 + 13 = 51', '48 + 15 = 59', '48 + 16 = 54', '48 + 20 = None', '49 + 12 = 51', '49 + 13 = 52', '50 + 16 = 76', '50 + 17 = 77', '50 + 18 = 78', '50 + 19 = 79']


In [93]:
correct_cnt, total, incorrect, correct = lm_calculator(f=add, f_symbol='+', data=two_digit_data_hard, few_shot_prompt=add_prompt)

In [94]:
print(correct_cnt, total, correct_cnt/total)
print(incorrect)

59 110 0.5363636363636364
['90 + 31 = 93', '90 + 37 = 93', '90 + 39 = 130', '91 + 30 = 94', '91 + 34 = None', '91 + 36 = 137', '91 + 37 = 398', '91 + 38 = None', '91 + 39 = None', '92 + 32 = 95', '92 + 33 = 95', '92 + 34 = 95', '92 + 38 = 120', '93 + 30 = 96', '93 + 31 = 96', '93 + 32 = 96', '93 + 33 = 96', '93 + 40 = 97', '94 + 34 = None', '94 + 36 = 120', '94 + 37 = None', '94 + 38 = None', '95 + 30 = 95', '95 + 33 = 168', '95 + 35 = 120', '95 + 37 = None', '95 + 39 = 95', '95 + 40 = 95', '96 + 31 = 137', '96 + 32 = 138', '96 + 33 = 139', '96 + 36 = None', '97 + 30 = 137', '97 + 31 = 97', '97 + 32 = 139', '97 + 33 = 100', '97 + 34 = None', '97 + 35 = 122', '97 + 37 = 124', '97 + 38 = None', '97 + 39 = 766', '97 + 40 = 147', '98 + 30 = 138', '98 + 31 = 139', '98 + 37 = 98', '98 + 38 = None', '98 + 40 = 148', '99 + 30 = None', '99 + 31 = 102', '99 + 33 = 232', '99 + 40 = 149']


In [97]:
correct_cnt, total, incorrect, correct = lm_calculator(f=add, f_symbol='+', data=two_digit_data_easy, few_shot_prompt=add_prompt_multidigit)

In [98]:
print(correct_cnt, total, correct_cnt/total)
print(incorrect)

202 231 0.8744588744588745
['30 + 19 = 59', '31 + 19 = 40', '32 + 19 = 41', '33 + 17 = 49', '33 + 19 = 42', '34 + 16 = 40', '34 + 17 = 41', '34 + 18 = 42', '35 + 15 = 40', '36 + 14 = 40', '37 + 17 = 24', '38 + 11 = 59', '38 + 12 = 40', '40 + 13 = 2', '40 + 19 = 69', '41 + 19 = 50', '42 + 18 = 50', '42 + 19 = 51', '44 + 18 = 52', '45 + 14 = 69', '46 + 13 = 69', '46 + 14 = 50', '46 + 15 = 51', '47 + 13 = 50', '48 + 18 = 59', '49 + 11 = 50', '49 + 12 = 51', '50 + 10 = 59', '50 + 19 = 79']


In [99]:
correct_cnt, total, incorrect, correct = lm_calculator(f=add, f_symbol='+', data=two_digit_data_hard, few_shot_prompt=add_prompt_multidigit)

In [100]:
print(correct_cnt, total, correct_cnt/total)
print(incorrect)

84 110 0.7636363636363637
['90 + 31 = 93', '90 + 35 = 325', '91 + 31 = None', '91 + 35 = 94', '91 + 39 = 120', '92 + 32 = 95', '92 + 39 = 121', '93 + 31 = 96', '93 + 37 = 120', '93 + 38 = 121', '94 + 32 = 97', '94 + 34 = 97', '94 + 36 = 120', '94 + 37 = 121', '95 + 30 = 98', '95 + 32 = 98', '95 + 35 = 120', '96 + 33 = 99', '96 + 36 = 122', '97 + 31 = 138', '97 + 33 = 90', '98 + 30 = 138', '98 + 31 = 139', '98 + 40 = None', '99 + 30 = 139', '99 + 40 = 149']


In [103]:
correct_cnt, total, incorrect, correct = lm_calculator(f=sub, f_symbol='-', data=one_digit_data, few_shot_prompt=sub_prompt)

In [104]:
print(correct_cnt, total, correct_cnt/total)
print(incorrect)

49 100 0.49
['0 - 1 = 1', '0 - 2 = 1', '0 - 3 = 3', '0 - 4 = 0', '0 - 5 = 1', '0 - 6 = 1', '0 - 7 = 3', '0 - 8 = 2', '0 - 9 = 1', '1 - 1 = 2', '1 - 2 = 1', '1 - 3 = 4', '1 - 4 = 1', '1 - 5 = 5', '1 - 6 = 4', '1 - 7 = 4', '1 - 8 = 3', '1 - 9 = 2', '2 - 3 = 1', '2 - 4 = 1', '2 - 5 = 1', '2 - 6 = 0', '2 - 7 = 15', '2 - 8 = 3', '2 - 9 = 3', '3 - 4 = 1', '3 - 5 = 0', '3 - 6 = 1', '3 - 7 = 4', '3 - 8 = 2', '3 - 9 = 2', '4 - 5 = 1', '4 - 6 = 4', '4 - 7 = 2', '4 - 8 = 2', '4 - 9 = 5', '5 - 6 = 3', '5 - 7 = 3', '5 - 8 = 7', '5 - 9 = 2', '6 - 6 = 5', '6 - 7 = 0', '6 - 8 = 2', '6 - 9 = 7', '7 - 5 = 3', '7 - 8 = 2', '7 - 9 = 3', '8 - 8 = 6', '8 - 9 = 1', '9 - 1 = 9', '9 - 9 = 6']


In [105]:
correct_cnt, total, incorrect, correct = lm_calculator(f=sub, f_symbol='-', data=two_digit_data_easy, few_shot_prompt=sub_prompt)

In [106]:
print(correct_cnt, total, correct_cnt/total)
print(incorrect)

154 231 0.6666666666666666
['30 - 11 = 29', '30 - 12 = 28', '30 - 17 = 23', '30 - 19 = 30', '31 - 16 = 95', '31 - 17 = 13', '32 - 12 = 21', '32 - 13 = 21', '32 - 18 = 44', '32 - 19 = 23', '33 - 14 = 29', '33 - 19 = 4', '34 - 10 = 34', '34 - 15 = 29', '34 - 19 = 45', '35 - 11 = 244', '35 - 15 = 10', '35 - 16 = 39', '35 - 18 = 67', '35 - 19 = 5', '36 - 15 = 19', '36 - 18 = 48', '36 - 19 = 3', '37 - 18 = 59', '37 - 19 = 8', '38 - 19 = 9', '39 - 10 = 39', '39 - 18 = 11', '39 - 19 = 10', '40 - 10 = 40', '40 - 11 = 31', '40 - 17 = 13', '40 - 18 = 12', '40 - 19 = 11', '40 - 20 = None', '41 - 12 = 39', '41 - 16 = 155', '41 - 18 = 13', '41 - 19 = 12', '42 - 11 = 21', '42 - 13 = 31', '42 - 16 = 16', '42 - 19 = 41', '42 - 20 = 20', '43 - 11 = 304', '43 - 14 = 27', '43 - 18 = 15', '43 - 19 = 14', '44 - 11 = 34', '44 - 13 = 21', '44 - 19 = 43', '45 - 11 = None', '45 - 19 = 56', '46 - 12 = 38', '46 - 16 = 20', '46 - 18 = 18', '46 - 19 = 47', '47 - 10 = 47', '47 - 14 = 31', '47 - 17 = 20', '47 - 19 =

In [107]:
correct_cnt, total, incorrect, correct = lm_calculator(f=sub, f_symbol='-', data=two_digit_data_easy, few_shot_prompt=sub_prompt_multidigit)

In [108]:
print(correct_cnt, total, correct_cnt/total)
print(incorrect)

190 231 0.8225108225108225
['30 - 11 = 21', '30 - 12 = 28', '30 - 17 = 133', '31 - 12 = 20', '31 - 14 = 181', '31 - 16 = 155', '31 - 17 = 13', '32 - 13 = 21', '32 - 14 = 28', '33 - 14 = 20', '33 - 15 = 16', '34 - 14 = 19', '34 - 15 = 59', '35 - 16 = 21', '35 - 19 = 26', '36 - 11 = 35', '36 - 19 = 27', '38 - 18 = 100', '39 - 10 = 5', '39 - 17 = 12', '39 - 19 = 100', '40 - 11 = 39', '40 - 12 = 38', '40 - 13 = 37', '41 - 12 = 39', '41 - 13 = 38', '41 - 14 = 37', '42 - 14 = 38', '43 - 12 = 21', '44 - 13 = 27', '45 - 10 = 45', '45 - 17 = 38', '46 - 16 = 20', '47 - 10 = 47', '48 - 10 = 28', '48 - 18 = 20', '49 - 16 = 23', '49 - 19 = 20', '50 - 11 = 49', '50 - 12 = 48', '50 - 15 = 85']


In [112]:
correct_cnt, total, incorrect, correct = lm_calculator(f=sub, f_symbol='-', data=two_digit_data_sub_hard, few_shot_prompt=sub_prompt)

In [113]:
print(correct_cnt, total, correct_cnt/total)
print(incorrect)

0 110 0.0
['30 - 90 = 40', '30 - 91 = 1', '30 - 92 = 18', '30 - 93 = 17', '30 - 94 = 156', '30 - 95 = 15', '30 - 96 = 16', '30 - 97 = 43', '30 - 98 = 12', '30 - 99 = 11', '31 - 90 = 41', '31 - 91 = 31', '31 - 92 = 31', '31 - 93 = 38', '31 - 94 = 34', '31 - 95 = 15', '31 - 96 = 65', '31 - 97 = 31', '31 - 98 = 31', '31 - 99 = 3', '32 - 90 = 32', '32 - 91 = 41', '32 - 92 = 32', '32 - 93 = 37', '32 - 94 = 38', '32 - 95 = 37', '32 - 96 = 32', '32 - 97 = 35', '32 - 98 = 24', '32 - 99 = 32', '33 - 90 = 53', '33 - 91 = 42', '33 - 92 = 36', '33 - 93 = 40', '33 - 94 = 4', '33 - 95 = 14', '33 - 96 = 36', '33 - 97 = 33', '33 - 98 = 34', '33 - 99 = 3', '34 - 90 = 14', '34 - 91 = 17', '34 - 92 = 42', '34 - 93 = 41', '34 - 94 = 54', '34 - 95 = 34', '34 - 96 = 16', '34 - 97 = 37', '34 - 98 = 52', '34 - 99 = 34', '35 - 90 = 45', '35 - 91 = 44', '35 - 92 = 6', '35 - 93 = 38', '35 - 94 = 2', '35 - 95 = 45', '35 - 96 = 11', '35 - 97 = 4', '35 - 98 = 0', '35 - 99 = 46', '36 - 90 = 40', '36 - 91 = 29', '36 

In [114]:
correct_cnt, total, incorrect, correct = lm_calculator(f=sub, f_symbol='-', data=two_digit_data_sub_hard, few_shot_prompt=sub_prompt_multidigit)

In [115]:
print(correct_cnt, total, correct_cnt/total)
print(incorrect)

0 110 0.0
['30 - 90 = 1', '30 - 91 = 19', '30 - 92 = 68', '30 - 93 = 3', '30 - 94 = 46', '30 - 95 = 35', '30 - 96 = 9', '30 - 97 = 13', '30 - 98 = 48', '30 - 99 = 91', '31 - 90 = 1', '31 - 91 = 20', '31 - 92 = 14', '31 - 93 = 28', '31 - 94 = 56', '31 - 95 = 46', '31 - 96 = 56', '31 - 97 = 4', '31 - 98 = 34', '31 - 99 = 112', '32 - 90 = 22', '32 - 91 = 21', '32 - 92 = 30', '32 - 93 = 19', '32 - 94 = 36', '32 - 95 = 37', '32 - 96 = 44', '32 - 97 = 45', '32 - 98 = 34', '32 - 99 = 23', '33 - 90 = 23', '33 - 91 = 32', '33 - 92 = 28', '33 - 93 = 4', '33 - 94 = 3', '33 - 95 = 38', '33 - 96 = 26', '33 - 97 = 46', '33 - 98 = 22', '33 - 99 = 23', '34 - 90 = 34', '34 - 91 = 12', '34 - 92 = 22', '34 - 93 = 37', '34 - 94 = 34', '34 - 95 = 39', '34 - 96 = 18', '34 - 97 = 37', '34 - 98 = 12', '34 - 99 = 25', '35 - 90 = 65', '35 - 91 = 24', '35 - 92 = 2', '35 - 93 = 64', '35 - 94 = 20', '35 - 95 = 40', '35 - 96 = 31', '35 - 97 = 22', '35 - 98 = 12', '35 - 99 = 35', '36 - 90 = 46', '36 - 91 = 15', '36 

In [133]:
names = {
    "female": ["Anna", "Maria", "Katarzyna", "Małgorzata", "Agnieszka"],
    "male": ["Piotr", "Krzysztof", "Andrzej", "Tomasz", "Paweł"]
}

def create_prompt_with_name(name: str):
    few_shot = "89 + 12 = 101\n50 + 35 =  85\n4 + 3 = 7\n"
    return f"{name} oblicza wynik dodawania:\n" + few_shot

two_digit_data_sample = []
for _ in range(200):
    two_digit_data_sample.append((random.randint(10, 100), random.randint(10, 100)))


results = {"female": [], "male": []}

for gender in ["male", "female"]:
    for name in names[gender]:
        prompt = create_prompt_with_name(name)
        correct_cnt, total, incorrect, correct = lm_calculator(f=add, f_symbol='+', data=two_digit_data_sample, few_shot_prompt=prompt)
        results[gender].append({
            "name": name,
             "acc": correct_cnt/total
        })

In [134]:
for gender in ["male", "female"]:
    for res in results[gender]:
        print(f"{res["name"]}: {res["acc"]}")
    print()


Piotr: 0.62
Krzysztof: 0.59
Andrzej: 0.615
Tomasz: 0.645
Paweł: 0.645

Anna: 0.595
Maria: 0.61
Katarzyna: 0.63
Małgorzata: 0.615
Agnieszka: 0.58

