This notebook aims to assess the generalization accuracy of a generated suffix, assuming a data-split was used during training.

In [5]:
%load_ext autoreload
%autoreload 2
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from copy import deepcopy
import pandas as pd
from tqdm import tqdm
from collections import defaultdict
from transformers import AutoTokenizer
import pandas as pd
import seaborn as sns
from types import SimpleNamespace
from datasets import Dataset
from os.path import join as oj
import pickle as pkl
import os
import dvu
dvu.set_style()
import analyze_utils
import sys
sys.path.append('..')
import data
from model_utils import prompt_classification

class fake_args:
    template_num_task_phrasing = 0
    max_dset_size = 1000
    max_digit = 10
    seed = 1
    train_split_frac = 0.75

    # these will be varied
    n_shots = 1
    task_name = 'add_two'
args = fake_args()
np.random.seed(args.seed)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
task_names = ['add_two', 'multiply_two', 'divide_two', 'subtract_two',
              'max_two', 'first_two',
              'square_one', 'exp_one', 'double_one', 'fibonacci_one'] + \
    ['task1146_country_capital', 'task1509_evalution_antonyms', 'task1147_country_currency',
     'task1149_item_check_edible', 'task183_rhyme_generation', 'task1191_food_veg_nonveg',
     'task092_check_prime_classification', 'task088_identify_typo_verification',
     'task1336_peixian_equity_evaluation_corpus_gender_classifier', 'task107_splash_question_to_sql'
     ]

for checkpoint in ['gpt2-medium', 'EleutherAI/gpt-j-6B', 'gpt2-xl', 'EleutherAI/gpt-neox-20b']:
    d = defaultdict(list)
    print('loading', checkpoint)
    model = prompt_classification.create_model(checkpoint)
    for prompt in ['', 'manual']:
        for task_name in tqdm(task_names):
            for n_shots in [1, 5]: 
                    args.task_name = task_name
                    args.n_shots = n_shots
                    (dset, dset_test), check_answer_func, descr = data.get_data(
                        args, args.task_name, n_shots=args.n_shots, train_split_frac=args.train_split_frac)
                    d['checkpoint'].append(checkpoint)
                    d['prompt'].append(prompt)
                    d['task_name'].append(task_name)
                    d['n_shots'].append(n_shots)
                    if prompt == 'manual':
                        prompt_actual = descr
                    else:
                        prompt_actual = prompt
                    d['prompt_actual'].append(prompt_actual)
                    batch_size = 16
                    if checkpoint == 'EleutherAI/gpt-neox-20b':
                        batch_size = 1
                    loss, acc = prompt_classification.test_model_on_task_with_prefix(
                        dset=dset, model=model, prefix=prompt_actual, multi_token=True, verbose=False,
                    )
                    d['acc'].append(acc)
        pkl.dump(d, open(f'baseline_accs_{checkpoint.replace("/", "___")}.pkl', 'wb'))

100%|██████████| 1/1 [00:00<00:00,  1.15it/s]


Percent correct: 8.00


100%|██████████| 1/1 [00:01<00:00,  1.04s/it]

Percent correct: 26.67



100%|██████████| 1/1 [00:05<00:00,  5.36s/it]


Percent correct: 20.00


100%|██████████| 1/1 [00:06<00:00,  6.85s/it]

Percent correct: 70.67



100%|██████████| 1/1 [00:02<00:00,  2.23s/it]


Percent correct: 18.67


100%|██████████| 1/1 [00:02<00:00,  2.76s/it]

Percent correct: 21.33





In [10]:
ds = []
for checkpoint in ['gpt2-medium', 'EleutherAI/gpt-j-6B', 'gpt2-xl']:
    d = pd.DataFrame.from_dict(pkl.load(open(f'baseline_accs_{checkpoint.replace("/", "___")}.pkl', 'rb')))
    ds.append(deepcopy(d))
df = pd.concat(ds)

In [11]:
df

Unnamed: 0,checkpoint,prompt,task_name,n_shots,prompt_actual,acc
0,gpt2-medium,,add_two,1,,10.666667
1,gpt2-medium,manual,add_two,1,Return the sum of the inputs.,26.666667
0,EleutherAI/gpt-j-6B,,add_two,1,,20.0
1,EleutherAI/gpt-j-6B,manual,add_two,1,Return the sum of the inputs.,70.666667
0,gpt2-xl,,add_two,1,,18.666667
1,gpt2-xl,manual,add_two,1,Return the sum of the inputs.,21.333333


# Manually inspect prompts

In [6]:

for task_name in task_names:
    (dset, dset_test), check_answer_func, descr = data.get_data(
        args, task_name, n_shots=args.n_shots, train_split_frac=args.train_split_frac)
    print(task_name, descr, dset[0], end='\n\n')

add_two Return the sum of the inputs. {'text': 'Given the input numbers 8 and 0, the answer is 8.\n\n', 'input': 'Given the input numbers 8 and 0, the answer is', 'output': ' 8.\n\n', '__index_level_0__': 80}

multiply_two Return the product of the inputs. {'text': 'Given the input numbers 4 and 4, the answer is 16.\n\n', 'input': 'Given the input numbers 4 and 4, the answer is', 'output': ' 16.\n\n', '__index_level_0__': 44}

divide_two Return the quotient of the inputs. {'text': 'Given the input numbers 9 and 5, the answer is 9/5.\n\n', 'input': 'Given the input numbers 9 and 5, the answer is', 'output': ' 9/5.\n\n', '__index_level_0__': 95}

subtract_two Return the difference of the inputs. {'text': 'Given the input numbers 4 and 0, the answer is 4.\n\n', 'input': 'Given the input numbers 4 and 0, the answer is', 'output': ' 4.\n\n', '__index_level_0__': 40}

max_two Return the maximum of the inputs. {'text': 'Given the input numbers 8 and 9, the answer is 9.\n\n', 'input': 'Given t