In [1]:
import json
import os
from tqdm import tqdm
from collections import defaultdict
import numpy as np
from typing import *
import time
import copy
import warnings

from data.openai import *
from data.generation import *
from data.finetune import *
from data.inference import *
from data.io import *
from data.evaluation import *
from data.split import *

from utils import cleanse_answer
from utils.paths import *
from utils.metadata import *

In [2]:
import openai
openai.api_key = "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"

In [3]:
ALL_DATASETS = [
    "single_eq", "addsub", "multiarith", "gsm8k", "aqua", "svamp",
    "date_understanding", "coin_flip",
    "tracking_shuffled_objects", "last_letter_concatenation",
    "commonsense_qa", "strategy_qa",
]
datasets = ALL_DATASETS

# Compile Summary

In [142]:
summary = []

with warnings.catch_warnings():
    warnings.simplefilter(action='ignore', category=RuntimeWarning)
    
    # Compile teacher model performance
    template = None
    completion_key = "zs_cot"
    file_key = "zs_cot"
    method = "zs_cot"
    base_model_key = "text-davinci-002"
    
    model_key = "{}".format(base_model_key)
    for dataset_key in datasets:
        completion_data = load_completion_data(completion_key, dataset_key, model_key)
        item = dict()
        item["dataset"] = dataset_key
        item["method"] = method
        item["base_model"] = "idavinci"
        item["shot"] = ""
        item["aug"] = ""
        evaluation = evaluate_completions(completion_data, dataset_key, template=template, print_metrics=False)
        item.update(get_evaluation_metrics(evaluation))
        summary.append(item)
    
    # Compile student model performance
    template = None
    for base_model_key in ["ada", "babbage", "curie"]:
        for dataset_key in datasets:
            completion_keys = [
                "ft",
                "zs_cot",
                "finetune_cot",
                "finetune_cot",
                "finetune_cot",
                "finetune_cot",
            ]
            file_keys = [
                "{}_train".format(dataset_key),
                "",
                "zs_cot_special_{}_train".format(dataset_key),
                "zs_cot_special_{}_8shot".format(dataset_key),
                "zs_cot_special_{}_32shot".format(dataset_key),
                "zs_cot_special_{}_128shot".format(dataset_key),
            ]
            methods = [
                "ft",
                "zs_cot",
                "ft_cot",
                "ft_cot",
                "ft_cot",
                "ft_cot",
            ]
            shots = [
                "", "", "", 8, 32, 128
            ]
            augs = [
                "", "", 1, 1, 1, 1
            ]
            for shot in ["", 8, 32, 128]:
                if shot == "":
                    shot_str = ""
                else:
                    shot_str = "{}shot_".format(shot)
                for aug in [2, 4, 8, 16, 32]:
                    completion_keys.append("finetune_cot")
                    file_keys.append("zs_cot_special_{}_{}{}aug".format(dataset_key, shot_str, aug))
                    methods.append("ft_cot")
                    shots.append(shot)
                    augs.append(aug)
                    
            for completion_key, file_key, method, shot, aug in zip(
                completion_keys, file_keys, methods, shots, augs):
                if method == "zs_cot":
                    model_key = base_model_key
                else:
                    model_key = "{}_{}".format(base_model_key, file_key)
                completion_data = load_completion_data(completion_key, dataset_key, model_key)
                if not completion_data:
                    continue
                item = dict()
                item["dataset"] = dataset_key
                item["method"] = method
                item["base_model"] = base_model_key
                item["shot"] = shot
                item["aug"] = aug
                if completion_data:
                    if "special" in file_key:
                        template = "special"
                    else:
                        template = None
                    evaluation = evaluate_completions(completion_data, dataset_key, template=template, print_metrics=False)
                item.update(get_evaluation_metrics(evaluation))
                summary.append(item)

In [143]:
summary = pd.DataFrame(summary)
summary.sort_values(["dataset", "method", "base_model"])[::4]

Unnamed: 0,dataset,method,base_model,shot,aug,accuracy,contains_prediction,contains_answer,reason_complete,accuracy_when_reason_complete,accuracy_when_reason_incomplete,complete,accuracy_when_complete,accuracy_when_incomplete,contains_prefix,accuracy_with_prefix,accuracy_without_prefix
18,addsub,ft,ada,,,0.084034,1.0,0.084034,,,,0.991597,0.084746,0.0,,,
21,addsub,ft_cot,ada,8.0,1.0,0.008403,0.857143,0.243697,,,,0.0,,0.008403,0.0,,0.008403
112,addsub,ft_cot,babbage,8.0,1.0,0.033613,0.915966,0.033613,,,,0.966387,0.034783,0.0,0.882353,0.028571,0.071429
204,addsub,ft_cot,curie,8.0,1.0,0.12605,0.798319,0.134454,,,,0.747899,0.146067,0.066667,0.89916,0.140187,0.0
110,addsub,zs_cot,babbage,,,0.0,0.89916,0.067227,0.0,,0.0,,,,,,
147,aqua,ft,babbage,,,0.212598,0.889764,0.0,,,,0.889764,0.238938,0.0,,,
60,aqua,ft_cot,ada,128.0,1.0,0.177165,0.637795,0.0,,,,0.531496,0.259259,0.084034,0.523622,0.24812,0.099174
152,aqua,ft_cot,babbage,128.0,1.0,0.149606,0.759843,0.0,,,,0.712598,0.198895,0.027397,0.728346,0.194595,0.028986
244,aqua,ft_cot,curie,128.0,1.0,0.125984,0.712598,0.0,,,,0.61811,0.159236,0.072165,0.625984,0.157233,0.073684
4,aqua,zs_cot,idavinci,,,0.2963,0.879,0.0,0.4404,0.394187,0.219264,,,,,,


In [144]:
pd.set_option('display.max_rows', 100)
summary = pd.DataFrame(summary)
summary = summary.sort_values(["dataset", "method", "base_model", "dataset", "shot", "aug"])
summary[summary.dataset == "multiarith"]

Unnamed: 0,dataset,method,base_model,shot,aug,accuracy,contains_prediction,contains_answer,reason_complete,accuracy_when_reason_complete,accuracy_when_reason_incomplete,complete,accuracy_when_complete,accuracy_when_incomplete,contains_prefix,accuracy_with_prefix,accuracy_without_prefix
24,multiarith,ft,ada,,,0.088889,1.0,0.0,,,,1.0,0.088889,,,,
115,multiarith,ft,babbage,,,0.177778,1.0,0.0,,,,1.0,0.177778,,,,
207,multiarith,ft,curie,,,0.15,0.988889,0.0,,,,0.988889,0.151685,0.0,,,
27,multiarith,ft_cot,ada,8.0,1.0,0.044444,0.983333,0.0,,,,0.0,,0.044444,0.0,,0.044444
35,multiarith,ft_cot,ada,8.0,2.0,0.033333,0.883333,0.0,,,,0.627778,0.044248,0.014925,0.6,0.027778,0.041667
36,multiarith,ft_cot,ada,8.0,4.0,0.016667,0.994444,0.0,,,,0.972222,0.017143,0.0,0.95,0.017544,0.0
37,multiarith,ft_cot,ada,8.0,8.0,0.033333,1.0,0.0,,,,0.944444,0.035294,0.0,0.933333,0.035714,0.0
38,multiarith,ft_cot,ada,8.0,16.0,0.022222,1.0,0.0,,,,0.977778,0.022727,0.0,0.977778,0.022727,0.0
39,multiarith,ft_cot,ada,8.0,32.0,0.027778,0.994444,0.0,,,,0.972222,0.028571,0.0,0.977778,0.028409,0.0
28,multiarith,ft_cot,ada,32.0,1.0,0.038889,1.0,0.0,,,,0.777778,0.028571,0.075,0.777778,0.028571,0.075


In [145]:
summary.to_csv("evaluation_summary.csv")

In [146]:
selection = fs.dataset == "multiarith"
selection = selection & (fs.method == "ft_cot")
selection = selection & (fs.shot == 128)
selection = selection & (fs.aug == 16)
selection = selection & (fs.base_model== "curie")

In [147]:
accuracy_summary = defaultdict(list)

fs = summary
s = summary.loc[:,["dataset", "method", "shot", "aug"]].drop_duplicates()
for d, m, s, a in zip(s.dataset, s.method, s.shot, s.aug):
    accuracy_summary["dataset"].append(d)
    accuracy_summary["method"].append(m)
    accuracy_summary["shot"].append(s)
    accuracy_summary["aug"].append(a)
    for model in ["ada", "babbage", "curie", "idavinci"]:
        sums = []
        selection = fs.dataset == d 
        selection = selection & (fs.method == m)
        selection = selection & (fs.shot == s)
        selection = selection & (fs.aug == a)
        selection = selection & (fs.base_model == model)
    
        try:
            accuracy = fs[selection].accuracy.item()
        except ValueError:
            accuracy = None
        accuracy_summary[model].append(accuracy)
        
accuracy_summary = pd.DataFrame(accuracy_summary)
accuracy_summary

Unnamed: 0,dataset,method,shot,aug,ada,babbage,curie,idavinci
0,addsub,ft,,,0.084034,0.176471,0.252101,
1,addsub,ft_cot,8.0,1.0,0.008403,0.033613,0.12605,
2,addsub,ft_cot,32.0,1.0,0.05042,0.042017,0.184874,
3,addsub,ft_cot,128.0,1.0,0.02521,0.092437,0.226891,
4,addsub,ft_cot,,1.0,0.067227,0.117647,0.201681,
5,addsub,zs_cot,,,0.0,0.0,0.02521,0.757246
6,aqua,ft,,,0.244094,0.212598,0.153543,
7,aqua,ft_cot,32.0,1.0,0.141732,0.208661,0.090551,
8,aqua,ft_cot,128.0,1.0,0.177165,0.149606,0.125984,
9,aqua,ft_cot,,1.0,0.165354,0.153543,0.161417,


In [148]:
accuracy_summary.to_csv("evaluation_summary_accuracy.csv")

In [149]:
pd.DataFrame(accuracy_summary).to_csv("evaluation_summary_accuracy.csv")

### Common Verbal Names

In [122]:
DATASET_NAMES = {
    "single_eq": "SingleEq",
    "addsub": "AddSub",
    "multiarith": "MultiArith",
    "gsm8k": "GSM8K",
    "aqua": "AQUA",
    "svamp": "SVAMP",
    "commonsense_qa": "Common",  # SenseQA
    "strategy_qa": "Strategy",  # QA
    "date_understanding": "Date",  # Understanding
    "tracking_shuffled_objects": "Shuffled",  # Objects
    "last_letter_concatenation": "Last Letter",  # (4 words)
    "coin_flip": "Coin Flip",  # (4 times)

}

In [123]:
FULL_DATASET_NAMES = {
    "single_eq": "SingleEq",
    "addsub": "AddSub",
    "multiarith": "MultiArith",
    "gsm8k": "GSM8K",
    "aqua": "AQUA",
    "svamp": "SVAMP",
    "commonsense_qa": "CommonSenseQA",  # SenseQA
    "strategy_qa": "StrategyQA",  # QA
    "date_understanding": "Date Understanding",  # Understanding
    "tracking_shuffled_objects": "Shuffled Objects",  # Objects
    "last_letter_concatenation": "Last Letter (4 words)",  # (4 words)
    "coin_flip": "Coin Flip (4 words)",  # (4 times)

}

In [124]:
METHOD_NAMES = {
    "zs": "Zero-shot",
    "ft": "Fine-tune",
    "zs_cot": "Zero-shot-CoT",
    "ft_cot": "Fine-tune-CoT",
    "ft_cot_008shot": "8-shot Ft-CoT",
    "ft_cot_032shot": "32-shot Ft-CoT",
    "ft_cot_128shot": "128-shot Ft-CoT",
}

In [125]:
DATASETS = ["single_eq", "addsub", "multiarith", "gsm8k", "aqua", "svamp", "commonsense_qa",
            "strategy_qa", "date_understanding", "tracking_shuffled_objects",
            "last_letter_concatenation", "coin_flip"]

### Table 1.

In [126]:
lines = [
    ["single_eq", "addsub", "multiarith", "gsm8k", "aqua", "svamp"],
    ["commonsense_qa", "strategy_qa", "date_understanding", "tracking_shuffled_objects", "last_letter_concatenation", "coin_flip"],
]

In [127]:
table1_dfs = []

base_model = "curie"
for line in lines:
    data = defaultdict(dict)
    for dataset in line:
        for method in ["ft", "zs_cot", "ft_cot"]:
            try:
                s = summary
                s = s[s.dataset == dataset]
                s = s[s.method == method]
                s = s[s.base_model == base_model]
                s = s[s.shot == ""]
                s = s[s.aug.isin([1, ""])]
                accuracy = s.accuracy.item() * 100
                accuracy = "{:.2f}".format(accuracy)
            except (KeyError, ValueError):
                accuracy = ""
            
            dataset_name = DATASET_NAMES[dataset]
            method_name = METHOD_NAMES[method]
            data[dataset_name][method_name] = accuracy
    table1_dfs.append(pd.DataFrame(data))

In [128]:
table1_dfs[0]

Unnamed: 0,SingleEq,AddSub,MultiArith,GSM8K,AQUA,SVAMP
Fine-tune,24.34,25.21,15.0,6.14,15.35,20.67
Zero-shot-CoT,1.32,2.52,5.0,2.35,21.26,1.33
Fine-tune-CoT,21.05,20.17,34.44,7.2,16.14,12.33


In [129]:
table1_dfs[1]

Unnamed: 0,Common,Strategy,Date,Shuffled,Last Letter,Coin Flip
Fine-tune,76.17,65.21,14.41,33.78,32.67,72.0
Zero-shot-CoT,19.98,51.09,15.32,31.11,0.0,46.67
Fine-tune-CoT,51.02,47.16,60.36,64.0,52.0,98.0


In [130]:
print(table1_dfs[0].to_latex())

\begin{tabular}{lllllll}
\toprule
{} & SingleEq & AddSub & MultiArith & GSM8K &   AQUA &  SVAMP \\
\midrule
Fine-tune     &    24.34 &  25.21 &      15.00 &  6.14 &  15.35 &  20.67 \\
Zero-shot-CoT &     1.32 &   2.52 &       5.00 &  2.35 &  21.26 &   1.33 \\
Fine-tune-CoT &    21.05 &  20.17 &      34.44 &  7.20 &  16.14 &  12.33 \\
\bottomrule
\end{tabular}



  print(table1_dfs[0].to_latex())


In [131]:
print(table1_dfs[1].to_latex())

\begin{tabular}{lllllll}
\toprule
{} & Common & Strategy &   Date & Shuffled & Last Letter & Coin Flip \\
\midrule
Fine-tune     &  76.17 &    65.21 &  14.41 &    33.78 &       32.67 &     72.00 \\
Zero-shot-CoT &  19.98 &    51.09 &  15.32 &    31.11 &        0.00 &     46.67 \\
Fine-tune-CoT &  51.02 &    47.16 &  60.36 &    64.00 &       52.00 &     98.00 \\
\bottomrule
\end{tabular}



  print(table1_dfs[1].to_latex())


### Table 2.

In [132]:
MODEL_NAMES = {
    "davinci": "175B",
    "curie": "6.7B",
    "babbage": "1.3B",
    "ada": "0.3B",
    "idavinci": "i175B",
    "icurie": "i6.7B",
    "ibabbage": "i1.3B",
    "iada": "i0.3B",
}

In [133]:
datasets = [
#     "multiarith", "gsm8k", "date_understanding", "tracking_shuffled_objects", "last_letter_concatenation", "coin_flip",
    "single_eq", "addsub", "multiarith", "gsm8k", "aqua", "svamp", "commonsense_qa", "strategy_qa", "date_understanding", "tracking_shuffled_objects", "last_letter_concatenation", "coin_flip",
]
#     "single_eq", "addsub", "multiarith", "gsm8k", "aqua", "svamp", "commonsense_qa", "strategy_qa", "date_understanding", "tracking_shuffled_objects", "last_letter_concatenation", "coin_flip",

In [134]:
data = defaultdict(dict)

for dataset in datasets:
    for method in ["zs_cot", "ft", "ft_cot"]:
        if method == "zs_cot":
            base_models = ["idavinci", "ada", "babbage", "curie"]
        else:
            base_models = ["ada", "babbage", "curie"]

        for base_model in base_models:
            try:
                s = summary
                s = s[s.dataset == dataset]
                s = s[s.method == method]
                s = s[s.base_model == base_model]
                s = s[s.shot == ""]
                s = s[s.aug.isin([1, ""])]
                accuracy = s.accuracy.item() * 100
                accuracy = "{:.2f}".format(accuracy)
            except (KeyError, ValueError):
                accuracy = ""

            dataset_name = DATASET_NAMES[dataset]
            method_name = METHOD_NAMES[method]
            model_name = MODEL_NAMES[base_model]
            data[dataset_name][(method_name, model_name)] = accuracy
            
table2_df = pd.DataFrame(data)

In [135]:
table2_df

Unnamed: 0,Unnamed: 1,SingleEq,AddSub,MultiArith,GSM8K,AQUA,SVAMP,Common,Strategy,Date,Shuffled,Last Letter,Coin Flip
Zero-shot-CoT,i175B,81.18,75.72,77.33,42.07,29.63,64.0,59.86,53.4,68.29,52.93,57.0,88.6
Zero-shot-CoT,0.3B,0.66,0.0,3.89,1.52,25.2,3.0,19.57,17.76,10.81,34.67,0.0,4.67
Zero-shot-CoT,1.3B,0.0,0.0,3.33,1.67,21.65,1.0,20.23,37.99,14.41,38.22,0.0,47.33
Zero-shot-CoT,6.7B,1.32,2.52,5.0,2.35,21.26,1.33,19.98,51.09,15.32,31.11,0.0,46.67
Fine-tune,0.3B,9.87,8.4,8.89,5.08,24.41,7.67,51.68,60.41,23.42,32.44,28.67,100.0
Fine-tune,1.3B,11.84,17.65,17.78,5.38,21.26,14.33,70.93,60.7,31.53,30.22,30.0,100.0
Fine-tune,6.7B,24.34,25.21,15.0,6.14,15.35,20.67,76.17,65.21,14.41,33.78,32.67,72.0
Fine-tune-CoT,0.3B,6.58,6.72,5.56,2.88,16.54,4.33,30.3,46.72,17.12,48.89,50.67,99.33
Fine-tune-CoT,1.3B,11.18,11.76,13.89,3.87,15.35,7.33,40.62,46.58,38.74,53.78,50.67,100.0
Fine-tune-CoT,6.7B,21.05,20.17,34.44,7.2,16.14,12.33,51.02,47.16,60.36,64.0,52.0,98.0


In [90]:
print(table2_df.to_latex())

\begin{tabular}{llllllllllllll}
\toprule
              &      & SingleEq & AddSub & MultiArith &  GSM8K &   AQUA &  SVAMP & Common & Strategy &   Date & Shuffled & Last Letter & Coin Flip \\
\midrule
Zero-shot-CoT & i175B &    81.18 &  75.72 &      77.33 &  42.07 &  29.63 &  64.00 &  59.86 &    53.40 &  68.29 &    52.93 &       57.00 &     88.60 \\
              & 0.3B &     0.66 &   0.00 &       3.89 &   1.52 &  25.20 &   3.00 &  19.57 &    17.76 &  10.81 &    34.67 &        0.00 &      4.67 \\
              & 1.3B &     0.00 &   0.00 &       3.33 &   1.67 &  21.65 &   1.00 &  20.23 &    37.99 &  14.41 &    38.22 &        0.00 &     47.33 \\
              & 6.7B &     1.32 &   2.52 &       5.00 &   2.35 &  21.26 &   1.33 &  19.98 &    51.09 &  15.32 &    31.11 &        0.00 &     46.67 \\
Fine-tune & 0.3B &     9.87 &   8.40 &       8.89 &   5.08 &  24.41 &   7.67 &  51.68 &    60.41 &  23.42 &    32.44 &       28.67 &    100.00 \\
              & 1.3B &    11.84 &  17.65 &      17.78

  print(table2_df.to_latex())


### Table 3.

In [136]:
lines = [
    ["single_eq", "addsub", "multiarith", "gsm8k", "aqua", "svamp"],
    ["commonsense_qa", "strategy_qa", "date_understanding", "tracking_shuffled_objects", "last_letter_concatenation", "coin_flip"],
]

In [138]:
data = defaultdict(dict)

datasets = ALL_DATASETS
base_model = "curie"
for dataset in datasets:
    dataset_name = FULL_DATASET_NAMES[dataset]
    shots = ["", 8, 32, 128, ""]
    methods = ["zs_cot"] + ["ft_cot"] * 4

    for method, shot in zip(methods, shots):
        method_name = METHOD_NAMES[method]
        if shot:
            shot_name = "{}-shot".format(shot)
        else:
            shot_name = ""
        try:
            s = summary
            s = s[s.dataset == dataset]
            s = s[s.method == method]
            s = s[s.base_model == base_model]
            s = s[s.aug.isin(["", 1])]
            s = s[s.shot == shot]
            accuracy = s.accuracy.item() * 100
            accuracy = "{:.2f}".format(accuracy)
        except (KeyError, ValueError):
            accuracy = ""
        data[dataset_name][(method_name, shot_name)] = accuracy

table3 = pd.DataFrame(data).T
table3

Unnamed: 0_level_0,Zero-shot-CoT,Fine-tune-CoT,Fine-tune-CoT,Fine-tune-CoT,Fine-tune-CoT
Unnamed: 0_level_1,Unnamed: 1_level_1,8-shot,32-shot,128-shot,Unnamed: 5_level_1
SingleEq,1.32,5.26,8.55,13.82,21.05
AddSub,2.52,12.61,18.49,22.69,20.17
MultiArith,5.0,0.56,12.22,17.78,34.44
GSM8K,2.35,0.91,2.43,2.65,7.2
AQUA,21.26,11.81,9.06,12.6,16.14
SVAMP,1.33,2.0,8.0,7.33,12.33
Date Understanding,15.32,0.9,18.02,26.13,60.36
Coin Flip (4 words),46.67,45.33,100.0,98.67,98.0
Shuffled Objects,31.11,28.44,34.22,47.56,64.0
Last Letter (4 words),0.0,2.0,21.33,42.67,52.0


In [140]:
print(table3.to_latex())

\begin{tabular}{llllll}
\toprule
{} & Zero-shot-CoT & \multicolumn{4}{l}{Fine-tune-CoT} \\
{} &        8-shot & 32-shot & \multicolumn{2}{l}{128-shot} \\
\midrule
SingleEq              &          1.32 &          5.26 &    8.55 &    13.82 &  21.05 \\
AddSub                &          2.52 &         12.61 &   18.49 &    22.69 &  20.17 \\
MultiArith            &          5.00 &          0.56 &   12.22 &    17.78 &  34.44 \\
GSM8K                 &          2.35 &          0.91 &    2.43 &     2.65 &   7.20 \\
AQUA                  &         21.26 &         11.81 &    9.06 &    12.60 &  16.14 \\
SVAMP                 &          1.33 &          2.00 &    8.00 &     7.33 &  12.33 \\
Date Understanding    &         15.32 &          0.90 &   18.02 &    26.13 &  60.36 \\
Coin Flip (4 words)   &         46.67 &         45.33 &  100.00 &    98.67 &  98.00 \\
Shuffled Objects      &         31.11 &         28.44 &   34.22 &    47.56 &  64.00 \\
Last Letter (4 words) &          0.00 &          2.00 

  print(table3.to_latex())


### Table 4.

In [150]:
data = defaultdict(list)

dataset = "multiarith"
dataset_name = FULL_DATASET_NAMES[dataset]
ft_augs = [1, 2, 4, 8, 16, 32]
methods = ["zs_cot"] + ["ft_cot"] * len(ft_augs)
augs = [""] + ft_augs
for method, aug in zip(methods, augs):
    data[("Method", "")].append(METHOD_NAMES[method])
    data[("Reasoning", "Samples")].append(aug)
    for base_model in ["ada", "babbage", "curie"]:
        try:
            s = summary
            s = s[s.dataset == dataset]
            s = s[s.method == method]
            s = s[s.base_model == base_model]
            s = s[s.aug == aug]
            s = s[s.shot == ""]
            accuracy = s.accuracy.item() * 100
            accuracy = "{:.2f}".format(accuracy)
        except (KeyError, ValueError):
            accuracy = ""

        base_model_name = MODEL_NAMES[base_model]
        
        data[("Model", base_model_name)].append(accuracy)
        
table4 = pd.DataFrame(data)
table4

Unnamed: 0_level_0,Method,Reasoning,Model,Model,Model
Unnamed: 0_level_1,Unnamed: 1_level_1,Samples,0.3B,1.3B,6.7B
0,Zero-shot-CoT,,3.89,3.33,5.0
1,Fine-tune-CoT,1.0,5.56,13.89,34.44
2,Fine-tune-CoT,2.0,7.22,15.56,27.22
3,Fine-tune-CoT,4.0,7.78,13.33,33.89
4,Fine-tune-CoT,8.0,7.78,19.44,47.22
5,Fine-tune-CoT,16.0,16.67,21.11,41.67
6,Fine-tune-CoT,32.0,21.11,30.0,55.56


In [151]:
print(table4.style.hide(axis="index").to_latex(multicol_align="c", hrules=True))

\begin{tabular}{lllll}
\toprule
Method & Reasoning & \multicolumn{3}{c}{Model} \\
 & Samples & 0.3B & 1.3B & 6.7B \\
\midrule
Zero-shot-CoT &  & 3.89 & 3.33 & 5.00 \\
Fine-tune-CoT & 1 & 5.56 & 13.89 & 34.44 \\
Fine-tune-CoT & 2 & 7.22 & 15.56 & 27.22 \\
Fine-tune-CoT & 4 & 7.78 & 13.33 & 33.89 \\
Fine-tune-CoT & 8 & 7.78 & 19.44 & 47.22 \\
Fine-tune-CoT & 16 & 16.67 & 21.11 & 41.67 \\
Fine-tune-CoT & 32 & 21.11 & 30.00 & 55.56 \\
\bottomrule
\end{tabular}



### Table 5.

In [379]:
data = defaultdict(list)

dataset = "multiarith"
dataset_name = FULL_DATASET_NAMES[dataset]
ft_shots = [8, 32, 128]
ft_augs = [1, 2, 4, 8, 16, 32]
n_shots, n_augs = len(ft_shots), len(ft_augs)
methods = ["zs_cot"] + ["ft_cot"] + ["ft_cot"] * n_shots * n_augs
shots = ["", ""] + np.repeat(ft_shots, n_augs).tolist()
augs = ["", 1] + ft_augs * n_shots

for method, shot, aug in zip(methods, shots, augs):
    data[("Method", "")].append(METHOD_NAMES[method])
    data[("Shots", "")].append(shot)
    data[("Reasoning", "Samples")].append(aug)
    for base_model in ["ada", "babbage", "curie"]:
        try:
            s = summary
            s = s[s.dataset == dataset]
            s = s[s.method == method]
            s = s[s.base_model == base_model]
            s = s[s.aug == aug]
            s = s[s.shot == shot]
            accuracy = s.accuracy.item() * 100
            accuracy = "{:.2f}".format(accuracy)
        except (KeyError, ValueError):
            accuracy = ""

        base_model_name = MODEL_NAMES[base_model]
        
        data[("Model", base_model_name)].append(accuracy)
        
table5 = pd.DataFrame(data)
table5

Unnamed: 0_level_0,Method,Shots,Reasoning,Model,Model,Model
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Samples,0.3B,1.3B,6.7B
0,Zero-shot-CoT,,,3.89,3.33,5.0
1,Fine-tune-CoT,,1.0,5.56,13.89,34.44
2,Fine-tune-CoT,8.0,1.0,4.44,3.33,0.56
3,Fine-tune-CoT,8.0,2.0,3.33,5.0,2.22
4,Fine-tune-CoT,8.0,4.0,1.67,0.56,2.78
5,Fine-tune-CoT,8.0,8.0,3.33,2.22,7.22
6,Fine-tune-CoT,8.0,16.0,2.22,3.89,6.11
7,Fine-tune-CoT,8.0,32.0,2.78,1.67,6.11
8,Fine-tune-CoT,32.0,1.0,3.89,5.0,12.22
9,Fine-tune-CoT,32.0,2.0,3.89,2.78,10.0


In [380]:
print(table5.style.hide(axis="index").to_latex(multicol_align="c", hrules=True))

\begin{tabular}{llllll}
\toprule
Method & Shots & Reasoning & \multicolumn{3}{c}{Model} \\
 &  & Samples & 0.3B & 1.3B & 6.7B \\
\midrule
Zero-shot-CoT &  &  & 3.89 & 3.33 & 5.00 \\
Fine-tune-CoT &  & 1 & 5.56 & 13.89 & 34.44 \\
Fine-tune-CoT & 8 & 1 & 4.44 & 3.33 & 0.56 \\
Fine-tune-CoT & 8 & 2 & 3.33 & 5.00 & 2.22 \\
Fine-tune-CoT & 8 & 4 & 1.67 & 0.56 & 2.78 \\
Fine-tune-CoT & 8 & 8 & 3.33 & 2.22 & 7.22 \\
Fine-tune-CoT & 8 & 16 & 2.22 & 3.89 & 6.11 \\
Fine-tune-CoT & 8 & 32 & 2.78 & 1.67 & 6.11 \\
Fine-tune-CoT & 32 & 1 & 3.89 & 5.00 & 12.22 \\
Fine-tune-CoT & 32 & 2 & 3.89 & 2.78 & 10.00 \\
Fine-tune-CoT & 32 & 4 & 2.78 & 5.56 & 8.89 \\
Fine-tune-CoT & 32 & 8 & 2.78 & 4.44 & 12.78 \\
Fine-tune-CoT & 32 & 16 & 3.33 & 8.89 & 17.22 \\
Fine-tune-CoT & 32 & 32 & 1.11 & 7.78 & 15.56 \\
Fine-tune-CoT & 128 & 1 & 6.11 & 9.44 & 17.78 \\
Fine-tune-CoT & 128 & 2 & 3.33 & 7.78 & 13.33 \\
Fine-tune-CoT & 128 & 4 & 7.22 & 8.33 & 23.33 \\
Fine-tune-CoT & 128 & 8 & 6.67 & 9.44 & 27.78 \\
Fine-tun

In [119]:
data = defaultdict(dict)

base_model = "babbage"
for dataset in sum(lines, []):
    for method in ["zs_cot", "ft_cot_008shot", "ft_cot_032shot", "ft_cot_128shot", "ft_cot"]:
        try:
            i = summary.dataset == dataset
            i = i & (summary.method == method)
            i = i & (summary.base_model == base_model)
            accuracy = summary.accuracy[i].item() * 100
            accuracy = "{:.2f}".format(accuracy)
        except (KeyError, ValueError):
            accuracy = ""

        dataset_name = FULL_DATASET_NAMES[dataset]
        if "ft_cot" in method:
            method_name = METHOD_NAMES["ft_cot"]
        else:
            method_name = METHOD_NAMES[method]
        
        size = ""
        if method == "ft_cot":
            size = "Full"
        if "008" in method:
            size = "8-shot"
        elif "032" in method: 
            size = "32-shot"
        elif "128" in method: 
            size = "128-shot"
        data[(method_name, size)][dataset_name] = accuracy
table3_babbage = pd.DataFrame(data)

In [120]:
table3_babbage

Unnamed: 0_level_0,Zero-shot-CoT,Fine-tune-CoT,Fine-tune-CoT,Fine-tune-CoT,Fine-tune-CoT
Unnamed: 0_level_1,Unnamed: 1_level_1,8-shot,32-shot,128-shot,Full
SingleEq,,,,,
AddSub,,,,,
MultiArith,3.33,3.33,5.0,9.44,13.89
GSM8K,1.67,1.67,1.82,1.67,3.87
AQUA,,,,,
SVAMP,,,,,
CommonSenseQA,,,,,
StrategyQA,,,,,
Date Understanding,14.41,16.22,18.92,23.42,38.74
Shuffled Objects,38.22,27.11,32.0,36.44,53.78


In [88]:
print(table1_dfs[0].to_latex())

\begin{tabular}{lllllll}
\toprule
{} & SingleEq & AddSub & MultiArith & GSM8K & AQUA & SVAMP \\
\midrule
Fine-tune     &          &        &      15.00 &  6.14 &      &       \\
Zero-shot-CoT &          &        &       5.00 &  2.35 &      &       \\
Fine-tune-CoT &          &        &      34.44 &  7.20 &      &       \\
\bottomrule
\end{tabular}



  print(table1_dfs[0].to_latex())


In [89]:
print(table1_dfs[1].to_latex())

\begin{tabular}{lllllll}
\toprule
{} & Common & Strategy &   Date & Shuffled & Last Letter & Coin Flip \\
\midrule
Fine-tune     &        &          &  14.41 &    33.78 &       32.67 &     72.00 \\
Zero-shot-CoT &        &          &  15.32 &    31.11 &        0.00 &     46.67 \\
Fine-tune-CoT &        &          &  60.36 &    64.00 &       52.00 &     98.00 \\
\bottomrule
\end{tabular}



  print(table1_dfs[1].to_latex())
