In [1]:
import json
import os
from tqdm import tqdm
from collections import defaultdict
import numpy as np
from typing import *
import time
import copy
import warnings

from data.openai import *
from data.generation import *
from data.finetune import *
from data.inference import *
from data.io import *
from data.evaluation import *
from data.split import *

from utils.paths import *
from utils.metadata import *

In [2]:
import openai

# Credentials must not live in the notebook: read the key from the
# OPENAI_API_KEY environment variable instead (None if unset, so cells that
# only read cached completions still run). Any key previously committed in
# this cell should be considered leaked and revoked.
openai.api_key = os.environ.get("OPENAI_API_KEY")

In [3]:
# Base model identifiers iterated over when building per-model inspection files.
BASE_MODELS = ["ada", "babbage", "curie"]

In [4]:
# Dataset keys, in the order every per-dataset loop below visits them.
DATASETS = [
    "single_eq",
    "addsub",
    "multiarith",
    "gsm8k",
    "aqua",
    "svamp",
    "date_understanding",
    "coin_flip",
    "tracking_shuffled_objects",
    "last_letter_concatenation",
    "commonsense_qa",
    "strategy_qa",
]
# Lower-case alias bound to the very same list object (not a copy).
datasets = DATASETS

In [122]:
# Abbreviated display labels per dataset key. The trailing comments hold the
# part of the full name that was dropped to keep each label short.
DATASET_NAMES = dict(
    single_eq="SingleEq",
    addsub="AddSub",
    multiarith="MultiArith",
    gsm8k="GSM8K",
    aqua="AQUA",
    svamp="SVAMP",
    commonsense_qa="Common",  # SenseQA
    strategy_qa="Strategy",  # QA
    date_understanding="Date",  # Understanding
    tracking_shuffled_objects="Shuffled",  # Objects
    last_letter_concatenation="Last Letter",  # (4 words)
    coin_flip="Coin Flip",  # (4 times)
)

In [123]:
# Full display names per dataset key (unabbreviated counterparts of the
# short labels in DATASET_NAMES).
FULL_DATASET_NAMES = {
    "single_eq": "SingleEq",
    "addsub": "AddSub",
    "multiarith": "MultiArith",
    "gsm8k": "GSM8K",
    "aqua": "AQUA",
    "svamp": "SVAMP",
    "commonsense_qa": "CommonSenseQA",
    "strategy_qa": "StrategyQA",
    "date_understanding": "Date Understanding",
    "tracking_shuffled_objects": "Shuffled Objects",
    "last_letter_concatenation": "Last Letter (4 words)",
    # Coin Flip is parameterised by the number of flips ("4 times", matching
    # the note on the short label), not by a word count.
    "coin_flip": "Coin Flip (4 times)",
}


In [124]:
# Display labels for each method key.
METHOD_NAMES = dict([
    ("zs", "Zero-shot"),
    ("ft", "Fine-tune"),
    ("zs_cot", "Zero-shot-CoT"),
    ("ft_cot", "Fine-tune-CoT"),
    ("ft_cot_008shot", "8-shot Ft-CoT"),
    ("ft_cot_032shot", "32-shot Ft-CoT"),
    ("ft_cot_128shot", "128-shot Ft-CoT"),
])

# Student Models

In [5]:
def remove_whitespaces(string):
    """Collapse a multi-line string onto a single line.

    The text is split on newline characters; each piece is stripped of
    leading/trailing whitespace, empty pieces are discarded, and the rest are
    joined with single spaces. Whitespace inside a piece is left untouched.
    """
    stripped = (piece.strip() for piece in string.split("\n"))
    return " ".join(piece for piece in stripped if piece)

In [6]:
def write_samples(completion_data, dataset_key, f):
    """Write one markdown section per student sample to an open text file.

    For every index whose sample list is non-empty, only the first sample is
    rendered: a header with completeness/correctness flags, the question, the
    gold answer, the completion (cut at the "END" stop sequence), the
    cleansed prediction and, when there are two or more, all prediction
    candidates. A blank line is written after every index, including indices
    with no samples.

    Args:
        completion_data: mapping from integer sample index to a list of sample
            dicts; each dict is read for "question", "answer", "completion"
            and "finish_reason".
        dataset_key: dataset identifier passed to cleanse_answer /
            cleanse_prediction and shown upper-cased in each header.
        f: writable text file-like object.
    """
    for index, samples in completion_data.items():
        if samples:
            sample = samples[0]
            raw_completion = sample["completion"]
            # Keep only the text before the stop sequence, if present.
            completion = raw_completion.split("END", 1)[0]

            clean_answer = cleanse_answer(sample["answer"], dataset_key)
            clean_prediction, prediction_candidates = cleanse_prediction(
                completion, dataset_key, answer_prefix="-->", stop_sequence="END")

            correct = clean_answer == clean_prediction
            correctness = "CORRECT  " if correct else "INCORRECT"
            # Complete if generation stopped naturally or the model emitted the
            # "END" marker; the marker must be looked for in the untruncated
            # text because `completion` never contains it.
            complete = sample["finish_reason"] == "stop" or "END" in raw_completion
            finishedness = "COMPLETE  " if complete else "INCOMPLETE"

            f.write("### {} #{:d} (`{} - {}`)\n".format(dataset_key.upper(), index, finishedness, correctness))
            f.write("- **Question**: {}\n".format(sample["question"]))
            f.write("- **Answer**: {} (`{}`)\n".format(sample["answer"], clean_answer))
            f.write("- **Completion**: {}\n".format(remove_whitespaces(completion)))
            f.write("- **Prediction**: `{}`\n".format(clean_prediction))
            if len(prediction_candidates) >= 2:
                # Properly closed italic label, comma-separated without a trailing comma.
                f.write("- *Candidates*: {}\n".format(
                    ", ".join("`{}`".format(candidate) for candidate in prediction_candidates)))
        f.write("\n")

In [7]:
# Student models: for each base model and dataset, evaluate the
# "finetune_cot" completions and write metrics plus per-sample details to
# manual_inspection/finetune_cot_<base_model>_<dataset>.md.
template = "special"
completion_key = "finetune_cot"
output_dir = "manual_inspection"
os.makedirs(output_dir, exist_ok=True)

for base_model_key in BASE_MODELS:
    for dataset_key in DATASETS:
        file_key = "zs_cot_{}_{}_train".format(template, dataset_key)
        model_key = "{}_{}".format(base_model_key, file_key)

        completion_data = load_completion_data(completion_key, dataset_key, model_key)
        evaluation = evaluate_completions(completion_data, dataset_key)
        metrics = get_evaluation_metrics(evaluation)

        basename = "finetune_cot_{}_{}.md".format(base_model_key, dataset_key)
        path = os.path.join(output_dir, basename)
        # Explicit UTF-8: questions/completions may contain non-ASCII text and
        # the platform default encoding is not guaranteed to handle it.
        with open(path, "w", encoding="utf-8") as f:
            f.write("## Metrics\n")
            f.write("```\n")
            for key, value in metrics.items():
                f.write("{:40s}: {:6.2f}%\n".format(key, value * 100))
            f.write("```\n")
            f.write("\n")
            f.write("## Samples\n")
            f.write("\n")
            write_samples(completion_data, dataset_key, f)
        print("Samples saved:")
        print(path)

Samples saved:
manual_inspection/finetune_cot_ada_single_eq.md
Samples saved:
manual_inspection/finetune_cot_ada_addsub.md
Samples saved:
manual_inspection/finetune_cot_ada_multiarith.md
Samples saved:
manual_inspection/finetune_cot_ada_gsm8k.md
Samples saved:
manual_inspection/finetune_cot_ada_aqua.md
Samples saved:
manual_inspection/finetune_cot_ada_svamp.md
Samples saved:
manual_inspection/finetune_cot_ada_date_understanding.md
Samples saved:
manual_inspection/finetune_cot_ada_coin_flip.md
Samples saved:
manual_inspection/finetune_cot_ada_tracking_shuffled_objects.md
Samples saved:
manual_inspection/finetune_cot_ada_last_letter_concatenation.md
Samples saved:
manual_inspection/finetune_cot_ada_commonsense_qa.md
Samples saved:
manual_inspection/finetune_cot_ada_strategy_qa.md
Samples saved:
manual_inspection/finetune_cot_babbage_single_eq.md
Samples saved:
manual_inspection/finetune_cot_babbage_addsub.md
Samples saved:
manual_inspection/finetune_cot_babbage_multiarith.md
Samples save

  "accuracy_without_prefix": (~contains_prefix & correct).sum() / (~contains_prefix).sum(),
  "accuracy_when_incomplete": (~complete & correct).sum() / (~complete).sum(),
  "accuracy_without_prefix": (~contains_prefix & correct).sum() / (~contains_prefix).sum(),
  "accuracy_when_incomplete": (~complete & correct).sum() / (~complete).sum(),
  "accuracy_without_prefix": (~contains_prefix & correct).sum() / (~contains_prefix).sum(),
  "accuracy_when_incomplete": (~complete & correct).sum() / (~complete).sum(),
  "accuracy_without_prefix": (~contains_prefix & correct).sum() / (~contains_prefix).sum(),
  "accuracy_when_incomplete": (~complete & correct).sum() / (~complete).sum(),
  "accuracy_without_prefix": (~contains_prefix & correct).sum() / (~contains_prefix).sum(),
  "accuracy_when_incomplete": (~complete & correct).sum() / (~complete).sum(),
  "accuracy_without_prefix": (~contains_prefix & correct).sum() / (~contains_prefix).sum(),
  "accuracy_when_incomplete": (~complete & correct).s

Samples saved:
manual_inspection/finetune_cot_babbage_last_letter_concatenation.md
Samples saved:
manual_inspection/finetune_cot_babbage_commonsense_qa.md
Samples saved:
manual_inspection/finetune_cot_babbage_strategy_qa.md
Samples saved:
manual_inspection/finetune_cot_curie_single_eq.md
Samples saved:
manual_inspection/finetune_cot_curie_addsub.md
Samples saved:
manual_inspection/finetune_cot_curie_multiarith.md
Samples saved:
manual_inspection/finetune_cot_curie_gsm8k.md
Samples saved:
manual_inspection/finetune_cot_curie_aqua.md
Samples saved:
manual_inspection/finetune_cot_curie_svamp.md
Samples saved:
manual_inspection/finetune_cot_curie_date_understanding.md
Samples saved:
manual_inspection/finetune_cot_curie_coin_flip.md
Samples saved:
manual_inspection/finetune_cot_curie_tracking_shuffled_objects.md
Samples saved:
manual_inspection/finetune_cot_curie_last_letter_concatenation.md
Samples saved:
manual_inspection/finetune_cot_curie_commonsense_qa.md
Samples saved:
manual_inspecti

  "accuracy_without_prefix": (~contains_prefix & correct).sum() / (~contains_prefix).sum(),
  "accuracy_when_incomplete": (~complete & correct).sum() / (~complete).sum(),
  "accuracy_without_prefix": (~contains_prefix & correct).sum() / (~contains_prefix).sum(),
  "accuracy_when_incomplete": (~complete & correct).sum() / (~complete).sum(),
  "accuracy_without_prefix": (~contains_prefix & correct).sum() / (~contains_prefix).sum(),
  "accuracy_when_incomplete": (~complete & correct).sum() / (~complete).sum(),


# Teacher Models

In [35]:
# Spot-check the teacher's zero-shot-CoT evaluation on the first dataset.
dataset_key = DATASETS[0]
completion_key = "zs_cot"
model_key = "text-davinci-002"

print(dataset_key)

completion_data = load_completion_data(completion_key, dataset_key, model_key)
evaluation = evaluate_completions(completion_data, dataset_key)
metrics = get_evaluation_metrics(evaluation)
metrics

single_eq


  "accuracy_with_prefix": (contains_prefix & correct).sum() / contains_prefix.sum(),
  "accuracy_when_complete": (complete & correct).sum() / complete.sum(),


In [8]:
def write_teacher_samples(completion_data, dataset_key, f):
    """Write one markdown section per teacher sample to an open text file.

    For every index whose sample list is non-empty, only the first sample is
    rendered: a header with completeness/correctness flags, the question, the
    reasoning completion (collapsed onto one line), the gold answer, the
    answer completion (cut at the "END" stop sequence), the cleansed
    prediction and, when there are two or more, all prediction candidates.
    A blank line is written after every index, including indices with no
    samples.

    Args:
        completion_data: mapping from integer sample index to a list of sample
            dicts; each dict is read for "question", "reasoning_completion",
            "reasoning_finish_reason", "answer" and "completion".
        dataset_key: dataset identifier passed to cleanse_answer /
            cleanse_prediction and shown upper-cased in each header.
        f: writable text file-like object.
    """
    for index, samples in completion_data.items():
        if samples:
            sample = samples[0]
            reasoning = remove_whitespaces(sample["reasoning_completion"])
            raw_completion = sample["completion"]
            # Keep only the text before the stop sequence, if present.
            completion = raw_completion.split("END", 1)[0]

            clean_answer = cleanse_answer(sample["answer"], dataset_key)
            # NOTE(review): reuses the student-format answer_prefix "-->";
            # confirm this matches how teacher completions are evaluated.
            clean_prediction, prediction_candidates = cleanse_prediction(
                completion, dataset_key, answer_prefix="-->", stop_sequence="END")

            correct = clean_answer == clean_prediction
            correctness = "CORRECT  " if correct else "INCORRECT"
            # The "END" marker must be looked for in the untruncated text:
            # `completion` has already been cut at it and never contains it.
            complete = sample["reasoning_finish_reason"] == "stop" or "END" in raw_completion
            finishedness = "COMPLETE  " if complete else "INCOMPLETE"

            f.write("### {} #{:d} (`{} - {}`)\n".format(dataset_key.upper(), index, finishedness, correctness))
            f.write("- **Question**: {} *Let's think step by step*\n".format(sample["question"]))
            f.write("- **Reasoning**: {} *Therefore, the answer is*\n".format(reasoning))
            f.write("- **Answer**: {} (`{}`)\n".format(sample["answer"], clean_answer))
            f.write("- **Completion**: {}\n".format(remove_whitespaces(completion)))
            f.write("- **Prediction**: `{}`\n".format(clean_prediction))
            if len(prediction_candidates) >= 2:
                # Properly closed italic label, comma-separated without a trailing comma.
                f.write("- *Candidates*: {}\n".format(
                    ", ".join("`{}`".format(candidate) for candidate in prediction_candidates)))
        f.write("\n")

In [11]:
# Teacher model: for each dataset, evaluate the text-davinci-002 "zs_cot"
# completions (metrics restricted via indices=train_indices) and write metrics
# plus per-sample details to manual_inspection/zs_cot_teacher_<dataset>_train.md.
completion_key = "zs_cot"
model_key = "text-davinci-002"
output_dir = "manual_inspection"
os.makedirs(output_dir, exist_ok=True)

for dataset_key in DATASETS:
    train_indices, test_indices = get_train_test_indices(dataset_key)

    completion_data = load_completion_data(completion_key, dataset_key, model_key)
    evaluation = evaluate_completions(completion_data, dataset_key, template=None, indices=train_indices)
    metrics = get_evaluation_metrics(evaluation)

    basename = "zs_cot_teacher_{}_train.md".format(dataset_key)
    path = os.path.join(output_dir, basename)
    # NOTE(review): metrics are computed with indices=train_indices, but the
    # full completion_data is passed to write_teacher_samples, so the sample
    # listing may include non-train items. Confirm whether it should be filtered.
    # Explicit UTF-8: sample text may contain non-ASCII characters.
    with open(path, "w", encoding="utf-8") as f:
        f.write("## Metrics\n")
        f.write("```\n")
        for key, value in metrics.items():
            f.write("{:40s}: {:6.2f}%\n".format(key, value * 100))
        f.write("```\n")
        f.write("\n")
        f.write("## Samples\n")
        f.write("\n")
        write_teacher_samples(completion_data, dataset_key, f)
    print("Samples saved:")
    print(path)

Samples saved:
manual_inspection/zs_cot_teacher_single_eq_train.md
Samples saved:
manual_inspection/zs_cot_teacher_addsub_train.md
Samples saved:
manual_inspection/zs_cot_teacher_multiarith_train.md
Samples saved:
manual_inspection/zs_cot_teacher_gsm8k_train.md
Samples saved:
manual_inspection/zs_cot_teacher_aqua_train.md
Samples saved:
manual_inspection/zs_cot_teacher_svamp_train.md
Samples saved:
manual_inspection/zs_cot_teacher_date_understanding_train.md
Samples saved:
manual_inspection/zs_cot_teacher_coin_flip_train.md
Samples saved:
manual_inspection/zs_cot_teacher_tracking_shuffled_objects_train.md
Samples saved:
manual_inspection/zs_cot_teacher_last_letter_concatenation_train.md
Samples saved:
manual_inspection/zs_cot_teacher_commonsense_qa_train.md
Samples saved:
manual_inspection/zs_cot_teacher_strategy_qa_train.md
