In [1]:
import pandas as pd
import torch
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# return a list of tuples [(a, b,c)]
def get_samples(lines):
    samples = []
    for line in lines:
        # take advantage of the coloring thing
        if line.startswith('\x1b[92m'):
            clean_line = line.replace('\x1b[92m', '').replace('\x1b[0m', '').strip().replace(',', '')
            a = clean_line.split('=')[0]
            c = clean_line.split('=')[1]
            a, b = a.split('+')
            samples.append((int(a), int(b), c.strip()))
    return samples

# return a list of tuples [(a, b,c)] for few-shot results
def get_samples_fewshot(lines):
    samples = []
    for line in lines:
        # take advantage of the coloring thing
        if line.startswith('\x1b[92m'):
            clean_line = line.replace('\x1b[92m', '').replace('\x1b[0m', '').strip().replace(',', '')
            a = clean_line.split('=')[0]
            c = clean_line.split('=')[1]
            a, b = a.split('+')
            samples.append((int(a), int(b), c.strip()))
    return samples

# return a list of tuples [(a, b,c)] for few-shot results
def get_samples_fewshot(lines, num_fewshot=5):
    samples = []
    for line in lines:
        # take advantage of the coloring thing
        if line.startswith('\x1b[92m'):
            clean_line = line.replace('\x1b[92m', '').replace('\x1b[0m', '').strip().replace(',', '').split(';')[num_fewshot:]
            a = clean_line[0].strip().split('=')[0]
            c = clean_line[0].strip().split('=')[1] + ';' + ';'.join(clean_line[1:])
            a, b = a.split('+')
            samples.append((int(a), int(b), c.strip()))
    return samples

models = ['llama-7b', 'mpt-7b', 'mpt-7b-instruct']
samples_dict = {}
for model in models:
    for suffix in ['-fewshot-results-0-temp-without-spaces.txt']:
        filename = model + suffix
        with open(filename, 'r') as f:
            lines = f.readlines()
            print(f'Processing: {filename}')
            samples = get_samples_fewshot(lines, num_fewshot=5)
            key = (model + '-without-spaces') if 'without-spaces' in suffix else (model + '-with-spaces')
            samples_dict[key] = samples

models = ['llama-7b', 'mpt-7b', 'mpt-7b-instruct']
samples_dict = {}
for model in models:
    for suffix in ['-results-0-temp-with-spaces.txt', '-results-0-temp-without-spaces.txt', '-fewshot-results-0-temp-without-spaces.txt',
                   '-10-fewshot-results-0-temp-without-spaces.txt',]:
        filename = model + suffix
        with open(filename, 'r') as f:
            lines = f.readlines()
            print(f'Processing: {filename}')
            if 'fewshot' in suffix:
                num_fewshot = 5
                if '10' in suffix:
                    num_fewshot = 10
                samples = get_samples_fewshot(lines, num_fewshot=num_fewshot)
                key = (model + f'-{num_fewshot}-fewshot-without-spaces') if 'without-spaces' in suffix else (model + f'-{num_fewshot}-fewshot-with-spaces')
            else:
                samples = get_samples(lines)
                key = (model + '-without-spaces') if 'without-spaces' in suffix else (model + '-with-spaces')
            samples_dict[key] = samples

Processing: llama-7b-fewshot-results-0-temp-without-spaces.txt
Processing: mpt-7b-fewshot-results-0-temp-without-spaces.txt
Processing: mpt-7b-instruct-fewshot-results-0-temp-without-spaces.txt
Processing: llama-7b-results-0-temp-with-spaces.txt
Processing: llama-7b-results-0-temp-without-spaces.txt
Processing: llama-7b-fewshot-results-0-temp-without-spaces.txt
Processing: mpt-7b-results-0-temp-with-spaces.txt
Processing: mpt-7b-results-0-temp-without-spaces.txt
Processing: mpt-7b-fewshot-results-0-temp-without-spaces.txt
Processing: mpt-7b-instruct-results-0-temp-with-spaces.txt
Processing: mpt-7b-instruct-results-0-temp-without-spaces.txt
Processing: mpt-7b-instruct-fewshot-results-0-temp-without-spaces.txt


In [3]:
# evaluate samples
def evaluate_samples(samples, debug_mode=False):
    correct = 0
    total = len(samples)
    correct_samples = []
    incorrect_samples = []
    for (a, b, c) in samples:
        true_c = str(a + b)
        if c.startswith(true_c):
            correct += 1
            correct_samples.append((a, b, c))
            if debug_mode:
                print(f"Correct! {a} + {b} = {c} = {true_c}")
        else:
            incorrect_samples.append((a, b, c))
            if debug_mode:
                print(f"Wrong! {a} + {b} = {c} != {true_c}")
    return 100.0 * correct / total, correct_samples, incorrect_samples

In [4]:
results_dict = {}
correct_samples_dict = {}
incorrect_samples_dict = {}
for key in samples_dict.keys():
    results_dict[key], correct_samples_dict[key], incorrect_samples_dict[key] = evaluate_samples(samples_dict[key])

In [5]:
results_df = pd.DataFrame.from_dict(results_dict, orient='index', columns=['Accuracy'])
results_df

Unnamed: 0,Accuracy
llama-7b-with-spaces,80.2
llama-7b-without-spaces,75.7
llama-7b-fewshot-without-spaces,77.5
mpt-7b-with-spaces,6.0
mpt-7b-without-spaces,6.0
mpt-7b-fewshot-without-spaces,41.9
mpt-7b-instruct-with-spaces,7.3
mpt-7b-instruct-without-spaces,7.3
mpt-7b-instruct-fewshot-without-spaces,37.1


In [6]:
def print_pretty_results(samples, correct_answers=True):
    CRED = '\033[91m'
    CEND = '\033[0m'
    CGREEN = '\033[92m'
    CBLUE = '\033[94m'
    for (a, b, c) in samples:
        true_c = str(a + b)
        if correct_answers:
            print(f'{a} + {b} = {CBLUE} {c} {CEND} = {CGREEN}' + f'{true_c}{CEND}')
        else:
            print(f'{a} + {b} = {CBLUE} {c} {CEND} = {CRED}' + f'{true_c}{CEND}')
        

In [7]:
print_pretty_results(correct_samples_dict['mpt-7b-without-spaces'][:10], correct_answers=True)

131 + 785 = [94m 916. The number 916 is a composite [0m = [92m916[0m
211 + 11 = [94m 222. The number of the beast. [0m = [92m222[0m
33 + 62 = [94m 95. The number of days until the end of [0m = [92m95[0m
692 + 652 = [94m 1344. The number 1344 is the sum [0m = [92m1344[0m
315 + 145 = [94m 460. The number of days in a year. [0m = [92m460[0m
598 + 119 = [94m 717. The sum of the digits of the number [0m = [92m717[0m
1000 + 239 = [94m 1239. The number 1239 is a composite [0m = [92m1239[0m
536 + 235 = [94m 771. The number 771 is the sum of [0m = [92m771[0m
140 + 145 = [94m 285. The number of days in the year. [0m = [92m285[0m
127 + 107 = [94m 234. The number of days until the next World [0m = [92m234[0m


In [8]:
print_pretty_results(incorrect_samples_dict['mpt-7b-without-spaces'][:10], correct_answers=False)

813 + 124 = [94m 1827. The number of days in [0m = [91m937[0m
65 + 588 = [94m 719. The number 719 is a composite [0m = [91m653[0m
26 + 512 = [94m 1313. The number 1313 is a prime [0m = [91m538[0m
622 + 197 = [94m 739. 739 is the sum of the [0m = [91m819[0m
612 + 493 = [94m 719. The number 719 is a composite [0m = [91m1105[0m
733 + 190 = [94m 943. The number 943 is a composite [0m = [91m923[0m
787 + 29 = [94m 826. The number 826 is a pal [0m = [91m816[0m
891 + 154 = [94m 1916. The number of days in [0m = [91m1045[0m
178 + 567 = [94m 789. The sum of the digits of the [0m = [91m745[0m
622 + 35 = [94m 669. The number 669 is a composite [0m = [91m657[0m


In [9]:
print_pretty_results(correct_samples_dict['mpt-7b-instruct-without-spaces'][:10], correct_answers=True)

211 + 11 = [94m 222. The number of the beast. [0m = [92m222[0m
174 + 154 = [94m 328. The sum of the first two numbers is [0m = [92m328[0m
33 + 62 = [94m 95. The number of days until the next World [0m = [92m95[0m
103 + 148 = [94m 251 [0m = [92m251[0m
844 + 579 = [94m 1423 [0m = [92m1423[0m
731 + 84 = [94m 815 [0m = [92m815[0m
536 + 235 = [94m 771. The sum of the digits in the number [0m = [92m771[0m
846 + 73 = [94m 919. The sum of the digits of the [0m = [92m919[0m
362 + 316 = [94m 678. The sum of the digits of the number [0m = [92m678[0m
342 + 141 = [94m 483. The sum of the digits of the number [0m = [92m483[0m


In [10]:
print_pretty_results(incorrect_samples_dict['mpt-7b-instruct-without-spaces'][:50], correct_answers=False)

813 + 124 = [94m 1837 [0m = [91m937[0m
65 + 588 = [94m 719 [0m = [91m653[0m
26 + 512 = [94m 1337. I love you. [0m = [91m538[0m
622 + 197 = [94m 829. The sum of the digits of the [0m = [91m819[0m
612 + 493 = [94m 1215 [0m = [91m1105[0m
733 + 190 = [94m 943. The sum of the digits of the [0m = [91m923[0m
787 + 29 = [94m 826. The sum of the digits in the [0m = [91m816[0m
891 + 154 = [94m 1965. The sum of the digits [0m = [91m1045[0m
178 + 567 = [94m 749 so the answer is 749. [0m = [91m745[0m
622 + 35 = [94m 629. The sum of the digits of the [0m = [91m657[0m
146 + 954 = [94m 1965. The sum of the digits [0m = [91m1100[0m
694 + 828 = [94m 1826 [0m = [91m1522[0m
984 + 318 = [94m 1224 [0m = [91m1302[0m
843 + 635 = [94m 1758 [0m = [91m1478[0m
768 + 171 = [94m 929 [0m = [91m939[0m
944 + 326 = [94m 1269 [0m = [91m1270[0m
392 + 655 = [94m 1819 [0m = [91m1047[0m
403 + 72 = [94m 555. The number of days in the year. [0m = [91m475[

In [11]:
print_pretty_results(correct_samples_dict['mpt-7b-instruct-fewshot-without-spaces'][:10], correct_answers=True)

690 + 230 = [94m 920.; [0m = [92m920[0m
229 + 994 = [94m 1223; 965+965=18 [0m = [92m1223[0m
15 + 88 = [94m 103; 8+8=16. [0m = [92m103[0m
966 + 314 = [94m 1280; 965+965=19 [0m = [92m1280[0m
593 + 699 = [94m 1292.; [0m = [92m1292[0m
202 + 854 = [94m 1056.; [0m = [92m1056[0m
751 + 802 = [94m 1553; 965+965=19 [0m = [92m1553[0m
948 + 510 = [94m 1458.; [0m = [92m1458[0m
759 + 916 = [94m 1675; 965+965=19 [0m = [92m1675[0m
707 + 541 = [94m 1248.; [0m = [92m1248[0m


In [12]:
print_pretty_results(incorrect_samples_dict['mpt-7b-instruct-fewshot-without-spaces'][:10], correct_answers=False)

264 + 620 = [94m 836.; [0m = [91m884[0m
826 + 763 = [94m 1789; 765+828=17 [0m = [91m1589[0m
711 + 813 = [94m 1414.; [0m = [91m1524[0m
17 + 372 = [94m 385; 9+9=18; 8+ [0m = [91m389[0m
993 + 560 = [94m 1743; 865+928=17 [0m = [91m1553[0m
744 + 153 = [94m 977; 965+965=18 [0m = [91m897[0m
298 + 139 = [94m 447; 765+914=1769 [0m = [91m437[0m
333 + 837 = [94m 1150.; [0m = [91m1170[0m
695 + 137 = [94m 834; 865+928=17 [0m = [91m832[0m
397 + 361 = [94m 748.; [0m = [91m758[0m
