## Prelude

In [40]:
%pip install datasets
%pip install matplotlib

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [41]:
import numpy as np
import pandas as pd
from tqdm import tqdm
tqdm.pandas()
import matplotlib.pyplot as plt
import os
import json

import multiprocess as mp

from datasets import load_dataset

In [42]:
SEED = 42

## CalibrationDataset class def

In [43]:
class CalibrationDataset:
    def __init__(
        self,
        dataset_name,
        df,
        is_equiv,
        get_prompt_content,
        seed,
    ):
        self.dataset_name = dataset_name
        self.df = df
        self.is_equiv = is_equiv
        self.get_prompt_content = get_prompt_content
        self.seed = seed

        self.df['q_id'] = self.df.index

    def get_examples_text(self, num_shots):
        return "\n\n".join(
            self.df.sample(num_shots, random_state=self.seed)
            .apply(lambda row: f"Question: {row['question']}\n\nAnswer: {row['answer']}", axis=1)
        )
    
    def write_requests_file(
        self,
        requests_file_path,
        num_questions,
        num_attempts_per_question,
        num_shots,
        model_name,
        max_response_tokens,
    ):
        df = self.df.sample(num_questions, random_state=self.seed).copy()

        att = pd.concat([df]*num_attempts_per_question, ignore_index=True)
        with open(requests_file_path, "w") as f:
            for index, row in att.iterrows():
                dict_for_one_request = {
                    "custom_id": f"request_{index}_qid_{row.q_id}",
                    "method": "POST", "url": "/v1/chat/completions",
                    "body": {
                    "model": model_name,
                        "messages": [
                            {"role": "system", "content": "You are a helpful assistant."},
                            {
                                "role": "user",
                                "content": self.get_prompt_content(
                                    question=row.question,
                                    examples_text=self.get_examples_text(num_shots),
                                )
                            }
                        ],
                        "max_tokens": max_response_tokens
                    }
                }
                print(json.dumps(dict_for_one_request), file=f)

In [60]:
class Results:
    def __init__(self, dataset, results_file_path):
        def parse_json_string(line):
            try:
                return json.loads(line)
            except ValueError:
                return None
        
        with open(results_file_path, "r") as f:
            lines = [parse_json_string(line.rstrip()) for line in f]
        
        line_count = len(lines)
        lines = [line for line in lines if line is not None]
        lines_dropped = line_count - len(lines)
        print(f"Lines dropped: {lines_dropped}")
        df = pd.DataFrame.from_dict(lines)
        df["q_id"] = df["custom_id"].str.split("_").str[3].astype(int)
        df["attempt"] = df["response"].apply(
            lambda r: r['choices'][0]['message']['content']
        )

        df = df[['q_id', 'attempt']].merge(dataset.df[['q_id', 'question', 'answer']], how='left', on='q_id')
        df['correct'] = df.progress_apply(lambda row: dataset.is_equiv(row['attempt'], row['answer']), axis=1)

        self.df = df

    def top1_confs(self):
        num_attempts_per_question = self.df.groupby('q_id').size().max()
        confs = (
            self.df[['q_id', 'correct']]
            .assign(confidence=1/num_attempts_per_question)
            .groupby(['q_id', 'correct'])
            .sum()
            .reset_index()
        )
        return confs
    
    def top1_acc(self):
        return self.top1_confs().query("correct").confidence.sum()

## GSM8K (1000 problems)

In [44]:
def gsm8k_str_to_num_parser(s : str) -> float:
    if isinstance(s, float) or isinstance(s, int):
        return s
    try:
        ending = s.split("####")[-1]
        strip_non_numbers = "".join((c for c in ending if (c in "1234567890.-")))
        return float(strip_non_numbers)
    except Exception as e:
        return float("nan")
    
def gsm8k_is_equiv(s1, s2):
    return gsm8k_str_to_num_parser(s1) == gsm8k_str_to_num_parser(s2)
    
def gsm8k_get_prompt_content(question, examples_text):
    return (
        "Please answer the following question.\n\n"
        + f"Question: {question}\n\n"
        + "Please give your reasoning, then output your final answer as a single number immediately preceded by #### with nothing after.\n\n"
        + f"Examples:\n\n{examples_text}"
    )

In [49]:
gsm8k_all = CalibrationDataset(
    dataset_name="gsm8k_all",
    df = pd.DataFrame(load_dataset("gsm8k", "main")["test"]),
    is_equiv = gsm8k_is_equiv,
    get_prompt_content=gsm8k_get_prompt_content,
    seed=SEED,
)

gsm8k_all.write_requests_file(
    requests_file_path="requests/gsm8k_requests_all_llama3_8b.jsonl",
    num_questions=1319,
    num_attempts_per_question=20,
    num_shots=5,
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",
    max_response_tokens=1000,
)

In [61]:
gsm8k_all_results_llama3_8b = Results(
    dataset=gsm8k_all,
    results_file_path="results/results_gsm8k_llama3_8b_500q_20a_broken.jsonl",
)

Lines dropped: 18


100%|██████████| 9980/9980 [00:00<00:00, 172748.28it/s]


In [62]:
gsm8k_all_results_llama3_8b.get_top1_confs()

Unnamed: 0,q_id,correct,confidence
0,0,True,1.00
1,1,True,1.00
2,2,False,0.85
3,2,True,0.15
4,3,False,0.05
...,...,...,...
843,497,False,0.40
844,497,True,0.55
845,498,True,0.95
846,499,False,0.35
