In [1]:
import json
import re
from transformers import AutoTokenizer
from tqdm import tqdm
import numpy as np
from functools import lru_cache

inf = -99999999

In [2]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [3]:
# 查看tokenize格式
print(tokenizer.convert_ids_to_tokens(tokenizer("Janet has 16 eggs per day.\nShe eats 3 for breakfast and bakes 4 for her friends.\nSo 16-3-4=<<16-3-4=9>>9 eggs are left.\nShe sells them for $2 each.\nSo she makes $2*9=<<2*9=18>>18 at the farmers' market.\n#### 18").input_ids))
print(len(tokenizer.convert_ids_to_tokens(tokenizer("Janet has 16 eggs per day.\nShe eats 3 for breakfast and bakes 4 for her friends.\nSo 16-3-4=<<16-3-4=9>>9 eggs are left.\nShe sells them for $2 each.\nSo she makes $2*9=<<2*9=18>>18 at the farmers' market.\n#### 18").input_ids)))

['Jan', 'et', 'Ġhas', 'Ġ16', 'Ġeggs', 'Ġper', 'Ġday', '.', 'Ċ', 'She', 'Ġeats', 'Ġ3', 'Ġfor', 'Ġbreakfast', 'Ġand', 'Ġb', 'akes', 'Ġ4', 'Ġfor', 'Ġher', 'Ġfriends', '.', 'Ċ', 'So', 'Ġ16', '-', '3', '-', '4', '=', '<<', '16', '-', '3', '-', '4', '=', '9', '>>', '9', 'Ġeggs', 'Ġare', 'Ġleft', '.', 'Ċ', 'She', 'Ġsells', 'Ġthem', 'Ġfor', 'Ġ$', '2', 'Ġeach', '.', 'Ċ', 'So', 'Ġshe', 'Ġmakes', 'Ġ$', '2', '*', '9', '=', '<<', '2', '*', '9', '=', '18', '>>', '18', 'Ġat', 'Ġthe', 'Ġfarmers', "'", 'Ġmarket', '.', 'Ċ', '####', 'Ġ18']
79


In [4]:
@lru_cache(maxsize=40960)
def my_tokenize(s):
    return tokenizer(s).input_ids


def get_tokens_from_text(text):
    return [post_process_text(x) for x in tokenizer.convert_ids_to_tokens(my_tokenize(text))]


def post_process_text(text):
    return text.replace("Ġ", " ").replace("Ċ", "\n").replace("ÃĹ", "×")


def get_all_indices(data_list, target):
    return [i for i, _ in enumerate(data_list) if _ == target]


def chunk_tokens_and_logprobs(tokens, logprobs):
    idx = len(tokens) - 1
    for i in range(len(tokens)-1):
        if tokens[i] == '\n' and tokens[i+1] == '\n':
            idx = i
            break
    return tokens[:idx], logprobs[:idx]


@lru_cache(maxsize=4096)
def get_answer(s):
    ans = ""
    if "####" in s:
        ans = s.split("####")[-1].replace("%%", "").replace(" ", "").strip()
    else:
        expression = re.findall("<<.+>>[0-9\.]+", s)
        if len(expression) == 0:
            ans = inf
        else:
            ans = expression[-1].split(">>")[-1].strip()
    return clean_ans(ans)


def clean_ans(s):
    s = str(s)
    if s and len(s) > 0 and s[-1] == '.':
        s = s[:-1]
    return s.lower()  # for CLUTRR and strategyQA use

### 加载带token logits的数据

In [5]:
# 从文件中加载数据
data = []

path = "/home/lyf/projects/aml-babel-components/datasets/gsm8k/prompts_samples.0(3).jsonl"
# path = ""

with open(path) as f:
    lines = f.readlines()
    for line in tqdm(lines):
        data.append(json.loads(line))

100%|██████████| 6595/6595 [00:34<00:00, 190.88it/s]


In [6]:
# 按一定格式，将数据装载到data_dict中
data_dict = {}
total_len = 0

for i, d in tqdm(enumerate(data)):
    question = d["metadata"]["question"]
    if question not in data_dict:
        data_dict[question] = {
            "metadata": d["metadata"].copy(),
            "samples": [d["choices"][x]["text"] for x in range(len(d["choices"]))],
            "tokens": [d["choices"][x]["logprobs"]["tokens"] for x in range(len(d["choices"]))],
            "token_logprobs": [d["choices"][x]["logprobs"]["token_logprobs"] for x in range(len(d["choices"]))],
        }
    else:
        data_dict[question]["samples"] += [d["choices"][x]["text"] for x in range(len(d["choices"]))]
        data_dict[question]["tokens"] += [d["choices"][x]["logprobs"]["tokens"] for x in range(len(d["choices"]))]
        data_dict[question]["token_logprobs"] += [d["choices"][x]["logprobs"]["token_logprobs"] for x in range(len(d["choices"]))]
        
    total_len += len(d["choices"])

print(total_len)
print(len(data_dict))
cnt = 0
for i, ques in tqdm(enumerate(data_dict)):
    for j in range(len(data_dict[ques]["samples"])):
        if "Question:" in data_dict[ques]["samples"][j]:
            data_dict[ques]["samples"][j] = data_dict[ques]["samples"][j].split("Question:")[0].strip()
        if ":" == data_dict[ques]["tokens"][j][0]:
            data_dict[ques]["tokens"][j] = data_dict[ques]["tokens"][j][1:]
            data_dict[ques]["token_logprobs"][j] = data_dict[ques]["token_logprobs"][j][1:]
        data_dict[ques]["tokens"][j], data_dict[ques]["token_logprobs"][j] = chunk_tokens_and_logprobs(data_dict[ques]["tokens"][j], data_dict[ques]["token_logprobs"][j])

6595it [00:00, 25146.99it/s]


131900
1319


1319it [00:04, 296.44it/s]


In [7]:
q1 = list(data_dict.keys())[0]
t = data_dict[q1]["samples"][3]
print(t)
print(get_answer(t))
token_list = [post_process_text(x) for x in tokenizer.convert_ids_to_tokens(my_tokenize(t))]
print(token_list)
print(data_dict[q1]["tokens"][3])
print(get_all_indices(token_list, '\n'))

Each duck lays <<16/3=5.33>>5.33 eggs per day
Janet eats 3 eggs for breakfast and 4 for baking so 3+4=<<3+4=7>>7 eggs per day
So the remainder is 5.33-7=<<5.33-7=-1.67>>-1.67 eggs per day
So Janet makes <<-1.67*2=-3.34>>-3.34 dollars per day
#### -3.34
-3.34
['Each', ' duck', ' lays', ' <<', '16', '/', '3', '=', '5', '.', '33', '>>', '5', '.', '33', ' eggs', ' per', ' day', '\n', 'Jan', 'et', ' eats', ' 3', ' eggs', ' for', ' breakfast', ' and', ' 4', ' for', ' baking', ' so', ' 3', '+', '4', '=', '<<', '3', '+', '4', '=', '7', '>>', '7', ' eggs', ' per', ' day', '\n', 'So', ' the', ' remainder', ' is', ' 5', '.', '33', '-', '7', '=', '<<', '5', '.', '33', '-', '7', '=-', '1', '.', '67', '>>', '-', '1', '.', '67', ' eggs', ' per', ' day', '\n', 'So', ' Janet', ' makes', ' <<', '-', '1', '.', '67', '*', '2', '=-', '3', '.', '34', '>>', '-', '3', '.', '34', ' dollars', ' per', ' day', '\n', '####', ' -', '3', '.', '34']
['Each', ' duck', ' lays', ' <<', '16', '/', '3', '=', '5', '.', '33

In [8]:
class Case():
    def __init__(self):
        self.question = ""
        self.ground_truth = ""
        self.samples = []
    
    def get_ground_truth_answer(self):
        ans = self.ground_truth.split("####")[-1].strip().replace("%%", "").replace(" ", "").replace("\n", "")
        return ans


class Sample():
    def __init__(self):
        self.text = ""
        self.steps = []
        self.step_tokens = []
        self.step_token_logprobs = []
        self.verifier_score = 0.0
    
    def get_final_answer(self):
        ans = ""
        if "####" in self.text:
            ans = self.text.split("####")[-1].strip().replace("%%", "").replace(" ", "").replace(", ", "").replace(",", "")
        else:
            ans = inf
        return clean_ans(ans)

cases = []
for i, ques in tqdm(enumerate(data_dict)):
    case = Case()
    case.question = ques
    case.ground_truth = data_dict[ques]["metadata"]["ground_truth"]
    for j, sample in enumerate(data_dict[ques]["samples"]):
        sample_obj = Sample()
        sample_obj.text = sample
        sample_obj.steps = [x for x in sample.split("\n")]
        sample_token_list = [post_process_text(x) for x in tokenizer.convert_ids_to_tokens(my_tokenize(sample))]
        step_symbol_indices = get_all_indices(sample_token_list, '\n')
        prev_index = 0
        for idx in range(len(step_symbol_indices)):
            sample_obj.step_tokens.append(data_dict[ques]["tokens"][j][prev_index : step_symbol_indices[idx]+1])
            sample_obj.step_token_logprobs.append(data_dict[ques]["token_logprobs"][j][prev_index : step_symbol_indices[idx]+1])
            prev_index = step_symbol_indices[idx] + 1
        sample_obj.step_tokens.append(data_dict[ques]["tokens"][j][prev_index : ])
        sample_obj.step_token_logprobs.append(data_dict[ques]["token_logprobs"][j][prev_index : ])
        case.samples.append(sample_obj)
    cases.append(case)

1319it [01:07, 19.57it/s]


In [9]:
print(len(cases))
print(cases[2].ground_truth)
print(cases[2].samples[2].step_tokens)
print(cases[2].samples[2].step_token_logprobs)

1319
The cost of the house and repairs came out to 80,000+50,000=$<<80000+50000=130000>>130,000
He increased the value of the house by 80,000*1.5=<<80000*1.5=120000>>120,000
So the new value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000
So he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000
#### 70000
[['The', ' value', ' of', ' the', ' house', ' before', ' the', ' repairs', ' was', ' 80', ',', '000', '.', '\n'], ['The', ' value', ' of', ' the', ' house', ' after', ' the', ' repairs', ' was', ' 80', ',', '000', '*', '150', '%', ' =', ' <<', '80', ',', '000', '*', '150', '%', '=', '120', ',', '000', '>>', '120', ',', '000', '.', '\n'], ['The', ' profit', ' was', ' 120', ',', '000', '-', '80', ',', '000', ' =', ' <<', '120', ',', '000', '-', '80', ',', '000', '=', '40', ',', '000', '>>', '40', ',', '000', '.', '\n'], ['####', ' 40', ',', '000']]
[[-0.87423295, -1.4335212, -0.067064025, -0.05757702, -0.05662273, -1.9179745, -0.76921195, -0.29460934, -0.4415

In [10]:
import numpy as np


total = []
for case in cases:
    total_step_len = 0
    for sample in case.samples:
        step_len = len(sample.step_token_logprobs)
        total_step_len += step_len
    total_step_len /= len(case.samples)
    total.append(total_step_len)
print("所有case中【平均step数目】的最大值:", np.max(total))
print("所有case中【平均step数目】的最小值:", np.min(total))
print("所有case中【平均step数目】的平均值:", np.mean(total))


total_cnt = []
for case in cases:
    max_step_len = max([len(x.step_token_logprobs) for x in case.samples])
    cnt = 0
    for sample in case.samples:
        if len(sample.step_token_logprobs) == max_step_len:
            cnt += 1
    total_cnt.append(cnt)
print("所有case中【step最长的sample的个数】的最大值:", np.max(total_cnt))
print("所有case中【step最长的sample的个数】的最小值:", np.min(total_cnt))
print("所有case中【step最长的sample的个数】的平均值:", np.mean(total_cnt))

所有case中【平均step数目】的最大值: 11.81
所有case中【平均step数目】的最小值: 2.26
所有case中【平均step数目】的平均值: 4.669598180439728
所有case中【step最长的sample的个数】的最大值: 99
所有case中【step最长的sample的个数】的最小值: 1
所有case中【step最长的sample的个数】的平均值: 2.4981046247156935


In [11]:
for i, case in tqdm(enumerate(cases)):
    case.samples = sorted(case.samples, key=lambda x : np.mean(x.step_token_logprobs[0]) if len(x.step_token_logprobs) > 0 and len(x.step_token_logprobs[0]) > 0 else -99999, reverse=True)

1319it [00:02, 461.66it/s]


In [12]:
import random


def top_1_hit(gt_ans, preds):
    pred0_ans = preds[0].get_final_answer()
    return 1 if pred0_ans == gt_ans else 0


def random_1_hit(gt_ans, preds):
    idx = random.randint(0, len(preds)-1)
    # random 1 acc
    pred0_ans = preds[idx].get_final_answer()
    return 1 if pred0_ans == gt_ans else 0


def recall_hit(gt_ans, preds):
    for pred in preds:
        if pred.get_final_answer() == gt_ans:
            return 1
    return 0


def voting_hit(gt_ans, preds):
    # voting acc
    answers = {}
    for pred in preds:
        if pred.get_final_answer() not in answers:
            answers[pred.get_final_answer()] = 0
        answers[pred.get_final_answer()] += 1
    answers = sorted(answers.items(), key=lambda x : x[1], reverse=True)
    for i in range(len(answers)):
        ans, ans_cnt = answers[i][0], answers[i][1]
        if ans != inf:
            return 1 if ans == gt_ans else 0
    return 0

def weighted_voting_hit(gt_ans, preds):
    # voting acc
    answers = {}
    for pred in preds:
        if pred.get_final_answer() not in answers:
            answers[pred.get_final_answer()] = 0
        answers[pred.get_final_answer()] += pred.verifier_score
    answers = sorted(answers.items(), key=lambda x : x[1], reverse=True)
    for i in range(len(answers)):
        ans, ans_cnt = answers[i][0], answers[i][1]
        if ans != inf:
            return 1 if ans == gt_ans else 0
    return 0


def verification_hit(gt_ans, preds):
    preds = sorted(preds, key=lambda x : x.verifier_score, reverse=True)
    for pred in preds:
        ans = pred.get_final_answer()
        if ans != inf:
            return 1 if ans == gt_ans else 0
    return 0


def most_steps_voting_hit(gt_ans, preds):
    # voting acc
    answers = {}
    max_step_len = max([len(x.step_token_logprobs) for x in preds])
    for pred in preds:
        if max_step_len != len(pred.step_token_logprobs):
            continue
        if pred.get_final_answer() not in answers:
            answers[pred.get_final_answer()] = 0
        answers[pred.get_final_answer()] += 1
    answers = sorted(answers.items(), key=lambda x : x[1], reverse=True)
    # print(len(answers))
    for i in range(len(answers)):
        ans, ans_cnt = answers[i][0], answers[i][1]
        if ans != inf:
            return 1 if ans == gt_ans else 0
    return 0


def sum_probs_hit(gt_ans, preds):
    arr = []
    for pred in preds:
        summ = 0.0
        for step_probs in pred.step_token_logprobs:
            for prob in step_probs:
                summ += prob
        arr.append((pred, summ))
    arr = sorted(arr, key=lambda x : x[1], reverse=True)
    ans = arr[0][0].get_final_answer()
    if gt_ans == ans:
        return 1
    return 0


def sum_avg_probs_hit(gt_ans, preds):
    arr = []
    for pred in preds:
        vals = []
        for step_probs in pred.step_token_logprobs:
            if len(step_probs) == 0:
                continue
            val = np.mean(step_probs)  # 取每一个step概率的最小值
            vals.append(val)
        if len(vals) > 0:
            arr.append((pred, np.mean(vals)))
        else:
            arr.append((pred, -999999))
    arr = sorted(arr, key=lambda x : x[1], reverse=True)
    ans = arr[0][0].get_final_answer()
    if gt_ans == ans:
        return 1
    return 0


def sum_avg_probs_among_max_steps_hit(gt_ans, preds):
    arr = []
    max_steps = max([len(x.step_token_logprobs) for x in preds])
    for pred in preds:
        if len(pred.step_token_logprobs) == max_steps:
            summ = 0.0
            for step_probs in pred.step_token_logprobs:
                for prob in step_probs:
                    summ += prob
            summ /= len(pred.step_token_logprobs)
            arr.append((pred, summ))
    arr = sorted(arr, key=lambda x : x[1], reverse=True)
    ans = arr[0][0].get_final_answer()
    if gt_ans == ans:
        return 1
    return 0


def most_steps_hit(gt_ans, preds):
    arr = []
    for pred in preds:
        arr.append((pred, len(pred.step_token_logprobs)))
    arr = sorted(arr, key=lambda x : x[1], reverse=True)
    ans = arr[0][0].get_final_answer()
    if gt_ans == ans:
        return 1
    return 0


def least_steps_hit(gt_ans, preds):
    arr = []
    for pred in preds:
        arr.append((pred, len(pred.step_token_logprobs)))
    arr = sorted(arr, key=lambda x : x[1], reverse=False)
    ans = arr[0][0].get_final_answer()
    if gt_ans == ans:
        return 1
    return 0


def compute_top1_and_recall(data, rand_k=100):
    total_random_hit_cnt = 0
    total_sum_avg_probs_among_max_steps_cnt = 0
    total_least_steps_cnt = 0
    total_most_steps_cnt = 0
    total_most_steps_voting_cnt = 0
    total_sum_probs_cnt = 0
    total_sum_avg_probs_cnt = 0
    total_vote_cnt = 0
    total_recall_cnt = 0
    total_top1_cnt = 0
    total_verification_cnt = 0
    total_weighted_voting_cnt = 0
    for i, x in enumerate(data):
        gt_ans = x.get_ground_truth_answer()
        # slice = x.samples if rand_k >= len(x.samples) else random.sample(x.samples, rand_k)
        slice = x.samples if rand_k >= len(x.samples) else x.samples[:rand_k]
        
        total_random_hit_cnt += random_1_hit(gt_ans, slice)
        # total_sum_avg_probs_among_max_steps_cnt += sum_avg_probs_among_max_steps_hit(gt_ans, slice)
        # total_least_steps_cnt += least_steps_hit(gt_ans, slice)
        # total_most_steps_cnt += most_steps_hit(gt_ans, slice)
        # total_most_steps_voting_cnt += most_steps_voting_hit(gt_ans, slice)
        # total_sum_probs_cnt += sum_probs_hit(gt_ans, slice)
        # total_sum_avg_probs_cnt += sum_avg_probs_hit(gt_ans, slice)
        total_vote_cnt += voting_hit(gt_ans, slice)
        total_recall_cnt += recall_hit(gt_ans, slice)
        total_top1_cnt += top_1_hit(gt_ans, slice)
        total_verification_cnt += verification_hit(gt_ans, slice)
        total_weighted_voting_cnt += weighted_voting_hit(gt_ans, slice)
    result = {
        "random_top1": total_random_hit_cnt / 1319,
        # "sum_avg_probs_among_max_steps_top1": total_sum_avg_probs_among_max_steps_cnt / len(data),
        # "least_steps_hit": total_least_steps_cnt / len(data),
        # "most_steps_hit": total_most_steps_cnt / len(data),
        # "most_steps_voting_top1": total_most_steps_voting_cnt / len(data),
        # "sum_probs_top1": total_sum_probs_cnt / len(data),
        # "sum_avg_probs_top1": total_sum_avg_probs_cnt / len(data),

        "voting_top1_accuracy": total_vote_cnt / 1319,
        "recall": total_recall_cnt / 1319,
        "top1": total_top1_cnt / 1319,
        "verification": total_verification_cnt / 1319,
        "weighted_voting": total_weighted_voting_cnt / 1319,
    }
    return result


def compute_results(data, rand_k=100):
    total_random_hit_cnt = 0
    total_recall_cnt = 0
    total_vote_cnt = 0
    total_weighted_vote_cnt = 0
    total_verification_cnt = 0
    for i, x in enumerate(data):
        gt_ans = x.ground_truth.get_final_answer()
        slice = x.preds if rand_k == len(x.preds) else random.sample(x.preds, rand_k)
        
        total_random_hit_cnt += random_1_hit(gt_ans, slice)
        total_vote_cnt += voting_hit(gt_ans, slice)
        total_recall_cnt += recall_hit(gt_ans, slice)
    result = {
        "random_top1": total_random_hit_cnt / len(data), 
        f"recall@{rand_k}": total_recall_cnt / len(data),
        f"voting_top1_accuracy@{rand_k}": total_vote_cnt / len(data),
    }
    return result

In [13]:
compute_top1_and_recall(cases, rand_k=100)

{'random_top1': 0.5458680818802123,
 'voting_top1_accuracy': 0.7937831690674754,
 'recall': 0.9727065959059894,
 'top1': 0.5655799848369977,
 'verification': 0.5655799848369977,
 'weighted_voting': 0.5655799848369977}

### 加载round8版本的数据

In [15]:
# 从文件中加载数据
round8_file = "/home/lyf/projects/aml-babel-components/datasets/gsm8k/end_round8_diverse.jsonl"
# round8_file = ""

generator_outputs = [json.loads(line) for line in open(round8_file)]

beam_data_dict = {}

for data in generator_outputs:
    metadata = data["metadata"]
    question = data["metadata"]["question"]
    if question not in beam_data_dict:
        beam_data_dict[question] = {
            "samples": [],
        }
    solution = data["context"].split("Answer:")[-1]
    score = data["logprob"]
    beam_data_dict[question]["samples"].append((solution, score))
    beam_data_dict[question]["metadata"] = metadata

In [16]:
beam_cases = []
for i, ques in tqdm(enumerate(beam_data_dict)):
    case = Case()
    case.question = ques
    case.ground_truth = beam_data_dict[ques]["metadata"]["ground_truth"]
    for sample, score in beam_data_dict[ques]["samples"]:
        sample_obj = Sample()
        sample_obj.text = sample
        sample_obj.steps = [x for x in sample.split("\n") if x != ' ']
        sample_obj.verifier_score = score
        case.samples.append(sample_obj)
    beam_cases.append(case)

1318it [00:00, 67692.31it/s]


In [17]:
cnt = [0, 0, 0, 0, 0, 0]
for case in beam_cases:
    cnt[len(case.samples)] += 1
print(cnt)

[0, 3, 11, 45, 166, 1093]


### 从老数据中copy出round8版本数据里所有len(samples)>0的question对应的samples，便于对比

In [18]:
data_dict_copy = {}

for question in tqdm(beam_data_dict):
    if question in data_dict:
        data_dict_copy[question] = data_dict[question].copy()

100%|██████████| 1318/1318 [00:00<00:00, 404041.27it/s]


In [19]:
cases_copy = []
for i, ques in tqdm(enumerate(data_dict_copy)):
    case = Case()
    case.question = ques
    case.ground_truth = data_dict_copy[ques]["metadata"]["ground_truth"]
    for j, sample in enumerate(data_dict_copy[ques]["samples"]):
        sample_obj = Sample()
        sample_obj.text = sample
        sample_obj.steps = [x for x in sample.split("\n")]
        sample_token_list = [post_process_text(x) for x in tokenizer.convert_ids_to_tokens(my_tokenize(sample))]
        step_symbol_indices = get_all_indices(sample_token_list, '\n')
        prev_index = 0
        for idx in range(len(step_symbol_indices)):
            sample_obj.step_tokens.append(data_dict_copy[ques]["tokens"][j][prev_index : step_symbol_indices[idx]+1])
            sample_obj.step_token_logprobs.append(data_dict_copy[ques]["token_logprobs"][j][prev_index : step_symbol_indices[idx]+1])
            prev_index = step_symbol_indices[idx] + 1
        sample_obj.step_tokens.append(data_dict_copy[ques]["tokens"][j][prev_index : ])
        sample_obj.step_token_logprobs.append(data_dict_copy[ques]["token_logprobs"][j][prev_index : ])
        # if i == 44 and j == 21:
        #     print(sample_obj.text)
        #     print(sample_obj.steps)
        #     print(sample_obj.step_tokens)
        #     print(sample_obj.step_token_logprobs)
        case.samples.append(sample_obj)
    cases_copy.append(case)


1318it [00:55, 23.70it/s]


### 老版本数据结果

In [20]:
compute_top1_and_recall(cases_copy, rand_k=100)

{'random_top1': 0.5655799848369977,
 'voting_top1_accuracy': 0.7915087187263078,
 'recall': 0.9719484457922669,
 'top1': 0.5549658832448825,
 'verification': 0.5549658832448825,
 'weighted_voting': 0.5549658832448825}

### round8数据结果

In [21]:
compute_top1_and_recall(beam_cases, rand_k=100)

{'random_top1': 0.5633055344958302,
 'voting_top1_accuracy': 0.6557998483699773,
 'recall': 0.8127369219105383,
 'top1': 0.5246398786959818,
 'verification': 0.5648218347232752,
 'weighted_voting': 0.40636846095526913}

In [22]:
cnt = 0
for i, case in enumerate(beam_cases):
    gt_ans = case.get_ground_truth_answer()
    pred_ans = case.samples[0].get_final_answer()
    if gt_ans == pred_ans:
        cnt += 1
    else:
        if i < 500:
            print("i:", i)
            print("gt:", case.ground_truth)
            print("sample0:", case.samples[0].text)
print(cnt)

cnt = 0
for i, case in enumerate(beam_cases):
    gt_ans = case.get_ground_truth_answer()
    pred_ans = case.samples[0].get_final_answer()
    if gt_ans == pred_ans:
        cnt += 1
print(cnt)

i: 2
gt: The candle burns for 5 - 1 = <<5-1=4>>4 hours.
Thus, the candle will be 2 * 4 = <<2*4=8>>8 centimeters shorter.
#### 8
sample0: The candle will burn for 5 hours, so it will melt by 5 x 2 = <<5*2=10>>10 centimeters.
#### 10

i: 4
gt: He traveled 20 miles + 15 miles = <<20+15=35>>35 miles not counting the distance between stops.
Henry traveled 60 miles - 35 miles = <<60-35=25>>25 miles between his first and second stop.
#### 25
sample0: The distance between the first and second stops is 60 - 15 = <<60-15=45>>45 miles.
#### 45

i: 5
gt: He spends 10*.5=<<10*.5=5>>5 hours per day
That means he spends 5*7=<<5*7=35>>35 hours per week
#### 35
sample0: He spends 7 hours a week because 5 times 10 equals <<5*10=50>>50
#### 50

i: 6
gt: Let x be the number of silver coins Gretchen has
Gretchen has x+30 gold coins.
x+x+30=110
2*x=80
x=<<40=40>>40
Gretchen has 40+30=<<40+30=70>>70 gold coins
#### 70
sample0: Gretchen has 30+110 = <<30+110=140>>140 gold coins.
#### 140

i: 7
gt: Half of Ray