In [1]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import imodelsx.process_results
from collections import defaultdict
import numpy as np
import viz
from sklearn.tree import DecisionTreeClassifier
from scipy import stats
from os.path import join
from tqdm import tqdm
import joblib
from scipy.sparse import issparse
import sys
sys.path.append('../experiments/')
results_dir = '/home/chansingh/mntv1/tree-prompt/dummy_runs/'

r = imodelsx.process_results.get_results_df(results_dir, results_fname='params.json')
experiment_filename = '../experiments/01_fit.py'
r = imodelsx.process_results.fill_missing_args_with_default(r, experiment_filename)
# imodelsx.process_results.delete_runs_in_dataframe(
    # r[(~r.dataset_name.str.startswith('knnp')) & (r.prompt_source == 'data_demonstrations')], actually_delete=True)

100%|██████████| 105/105 [00:08<00:00, 12.61it/s]


In [2]:
with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.max_colwidth', None):
    display(r.groupby(['dataset_name', 'prompt_source']).size())
    # display(r.groupby(['dataset_name', 'prompt_source'])
    # ['checkpoint'].unique())

dataset_name          prompt_source      
emotion               manual                 7
financial_phrasebank  manual                 7
imdb                  manual                 7
knnp__agnews          data_demonstrations    7
knnp__cb              data_demonstrations    7
knnp__cr              data_demonstrations    7
knnp__dbpedia         data_demonstrations    7
knnp__mpqa            data_demonstrations    7
knnp__mr              data_demonstrations    7
knnp__rte             data_demonstrations    7
knnp__sst2            data_demonstrations    7
knnp__subj            data_demonstrations    7
knnp__trec            data_demonstrations    7
rotten_tomatoes       manual                 7
sst2                  manual                 7
dtype: int64

In [3]:
DATA_PROCESSED_DIR = os.path.expanduser('~/cost-optimal-tree/data_processed/')
r = r.sort_values(by=['dataset_name']).reset_index().drop(columns=['index'])
for k in ['n_train', 'n_test', 'n_prompts', 'n_outputs']:
    r[k] = np.nan
for i in tqdm(range(r.shape[0])):
    args = r.iloc[i]
    print(args.dataset_name, args.checkpoint)
    run = joblib.load(join(args.save_dir_unique, 'cached_dset.pkl'))

    out_dir = join(DATA_PROCESSED_DIR, args.dataset_name)
    os.makedirs(out_dir, exist_ok=True)

    X_train = run['X_train']
    X_test = run['X_test']
    y_train = run['y_train']
    y_test = run['y_test']
    feature_names = run['feature_names']

    # run data checks
    # check that X only contains ones and zeros
    if issparse(X_train):
        X_train = X_train.toarray()
        X_test = X_test.toarray()
    assert np.all(np.isin(X_train, [0, 1]))
    assert np.all(np.isin(X_test, [0, 1]))

    # remove any cols that are all the same value
    idxs_constant = np.where(np.all(X_train == X_train[0, :], axis=0))[0]
    # print(f'removing {idxs_constant.size} constant cols')
    # print('unique', np.unique(X_train))
    X_train = np.delete(X_train, idxs_constant, axis=1)
    X_test = np.delete(X_test, idxs_constant, axis=1)
    feature_names = np.delete(feature_names, idxs_constant)

    # fit simple sklearn classifier
    clf = DecisionTreeClassifier(random_state=42, max_depth=3)
    clf.fit(X_train, y_train)
    acc = clf.score(X_test, y_test)
    acc_baseline = np.unique(y_test, return_counts=True)[1][0] / y_test.size
    acc_improvement = acc - acc_baseline
    print('acc improvement', acc_improvement.round(3))
    # assert acc > acc_baseline

    # export data
    X_train = X_train.astype(bool)
    X_test = X_test.astype(bool)

    joblib.dump(
        {
            'X_train': X_train,
            'X_test': X_test,
            # 'y_train': y_train,
            # 'y_test': y_test,
            'feature_names': feature_names,
            'verbalizer': args['verbalizer'],
            'acc_improvement': acc_improvement,
            'acc': acc,
        },
        join(out_dir, imodelsx.viz.CHECKPOINTS_RENAME_DICT.get(
            args.checkpoint) + '_features.pkl')
    )
    joblib.dump({
        'y_train': y_train,
        'y_test': y_test,
    }, join(out_dir, 'labels.pkl'))

    r.loc[i, 'n_train'] = X_train.shape[0]
    r.loc[i, 'n_test'] = X_test.shape[0]
    r.loc[i, 'n_prompts'] = X_train.shape[1]
    r.loc[i, 'n_outputs'] = np.unique(y_train).size

  0%|          | 0/105 [00:00<?, ?it/s]

emotion llama_7b


  1%|          | 1/105 [00:00<01:19,  1.31it/s]

acc improvement 0.354
emotion gpt2


  2%|▏         | 2/105 [00:01<01:45,  1.02s/it]

acc improvement 0.376
emotion gpt2-xl


  3%|▎         | 3/105 [00:03<01:50,  1.09s/it]

acc improvement 0.394
emotion meta-llama/Llama-2-7b-hf


  4%|▍         | 4/105 [00:04<01:54,  1.13s/it]

acc improvement 0.354
emotion gpt2-medium


  5%|▍         | 5/105 [00:05<01:52,  1.12s/it]

acc improvement 0.393
emotion EleutherAI/gpt-j-6B


  6%|▌         | 6/105 [00:06<01:50,  1.12s/it]

acc improvement 0.333
emotion gpt2-large


  8%|▊         | 8/105 [00:07<01:02,  1.56it/s]

acc improvement 0.375
financial_phrasebank gpt2-medium
acc improvement 0.197
financial_phrasebank EleutherAI/gpt-j-6B


 10%|▉         | 10/105 [00:07<00:37,  2.56it/s]

acc improvement 0.232
financial_phrasebank gpt2-large
acc improvement 0.225
financial_phrasebank llama_7b


 11%|█▏        | 12/105 [00:07<00:26,  3.54it/s]

acc improvement 0.225
financial_phrasebank gpt2
acc improvement 0.101
financial_phrasebank meta-llama/Llama-2-7b-hf


 13%|█▎        | 14/105 [00:08<00:20,  4.51it/s]

acc improvement 0.239
financial_phrasebank gpt2-xl
acc improvement 0.22
imdb llama_7b


 14%|█▍        | 15/105 [00:10<01:27,  1.03it/s]

acc improvement 0.382
imdb gpt2-medium


 15%|█▌        | 16/105 [00:13<02:01,  1.37s/it]

acc improvement 0.366
imdb EleutherAI/gpt-j-6B


 16%|█▌        | 17/105 [00:15<02:32,  1.73s/it]

acc improvement 0.335
imdb gpt2-large


 17%|█▋        | 18/105 [00:17<02:26,  1.68s/it]

acc improvement 0.399
imdb meta-llama/Llama-2-7b-hf


 18%|█▊        | 19/105 [00:18<02:19,  1.62s/it]

acc improvement 0.359
imdb gpt2


 19%|█▉        | 20/105 [00:20<02:16,  1.60s/it]

acc improvement 0.349
imdb gpt2-xl


 20%|██        | 21/105 [00:21<02:13,  1.59s/it]

acc improvement 0.407
knnp__agnews gpt2


 21%|██        | 22/105 [00:22<01:41,  1.22s/it]

acc improvement 0.367
knnp__agnews meta-llama/Llama-2-7b-hf


 22%|██▏       | 23/105 [00:22<01:19,  1.04it/s]

acc improvement 0.566
knnp__agnews llama_7b


 23%|██▎       | 24/105 [00:22<01:02,  1.29it/s]

acc improvement 0.453
knnp__agnews gpt2-large


 24%|██▍       | 25/105 [00:23<00:52,  1.53it/s]

acc improvement 0.504
knnp__agnews gpt2-medium


 25%|██▍       | 26/105 [00:23<00:46,  1.71it/s]

acc improvement 0.43
knnp__agnews gpt2-xl


 26%|██▌       | 27/105 [00:23<00:39,  1.96it/s]

acc improvement 0.477
knnp__agnews EleutherAI/gpt-j-6B


 27%|██▋       | 28/105 [00:24<00:35,  2.16it/s]

acc improvement 0.52
knnp__cb gpt2-large


 28%|██▊       | 29/105 [00:24<00:32,  2.32it/s]

acc improvement 0.196
knnp__cb gpt2-medium


 29%|██▊       | 30/105 [00:25<00:30,  2.48it/s]

acc improvement 0.196
knnp__cb llama_7b


 30%|██▉       | 31/105 [00:25<00:28,  2.56it/s]

acc improvement 0.0
knnp__cb gpt2-xl


 30%|███       | 32/105 [00:25<00:27,  2.63it/s]

acc improvement 0.036
knnp__cb gpt2


 31%|███▏      | 33/105 [00:26<00:27,  2.60it/s]

acc improvement 0.036
knnp__cb EleutherAI/gpt-j-6B


 32%|███▏      | 34/105 [00:26<00:26,  2.66it/s]

acc improvement 0.268
knnp__cb meta-llama/Llama-2-7b-hf


 34%|███▍      | 36/105 [00:27<00:21,  3.22it/s]

acc improvement 0.196
knnp__cr EleutherAI/gpt-j-6B
acc improvement 0.168
knnp__cr llama_7b


 36%|███▌      | 38/105 [00:27<00:16,  4.13it/s]

acc improvement 0.301
knnp__cr gpt2
acc improvement 0.246
knnp__cr gpt2-medium


 38%|███▊      | 40/105 [00:27<00:13,  4.89it/s]

acc improvement 0.168
knnp__cr gpt2-xl
acc improvement 0.281
knnp__cr gpt2-large


 40%|████      | 42/105 [00:28<00:11,  5.39it/s]

acc improvement 0.352
knnp__cr meta-llama/Llama-2-7b-hf
acc improvement 0.355
knnp__dbpedia gpt2


 41%|████      | 43/105 [00:30<00:44,  1.38it/s]

acc improvement 0.098
knnp__dbpedia gpt2-large


 42%|████▏     | 44/105 [00:31<00:50,  1.21it/s]

acc improvement 0.172
knnp__dbpedia gpt2-xl


 43%|████▎     | 45/105 [00:32<00:54,  1.11it/s]

acc improvement 0.152
knnp__dbpedia llama_7b


 44%|████▍     | 46/105 [00:43<04:04,  4.15s/it]

acc improvement 0.098
knnp__dbpedia EleutherAI/gpt-j-6B


 45%|████▍     | 47/105 [00:45<03:14,  3.35s/it]

acc improvement 0.141
knnp__dbpedia meta-llama/Llama-2-7b-hf


 46%|████▌     | 48/105 [00:46<02:33,  2.69s/it]

acc improvement 0.148
knnp__dbpedia gpt2-medium


 47%|████▋     | 49/105 [00:47<02:02,  2.18s/it]

acc improvement 0.133
knnp__mpqa gpt2-large


 49%|████▊     | 51/105 [00:47<01:03,  1.17s/it]

acc improvement 0.133
knnp__mpqa llama_7b
acc improvement -0.004
knnp__mpqa EleutherAI/gpt-j-6B


 50%|█████     | 53/105 [00:48<00:34,  1.51it/s]

acc improvement 0.203
knnp__mpqa gpt2
acc improvement 0.215
knnp__mpqa gpt2-medium


 52%|█████▏    | 55/105 [00:49<00:25,  1.99it/s]

acc improvement 0.117
knnp__mpqa gpt2-xl
acc improvement 0.203
knnp__mpqa meta-llama/Llama-2-7b-hf


 53%|█████▎    | 56/105 [00:49<00:21,  2.32it/s]

acc improvement 0.312
knnp__mr gpt2-xl
acc improvement 0.355


 55%|█████▌    | 58/105 [00:49<00:14,  3.26it/s]

knnp__mr EleutherAI/gpt-j-6B
acc improvement 0.375
knnp__mr gpt2-medium


 56%|█████▌    | 59/105 [00:49<00:12,  3.73it/s]

acc improvement 0.18
knnp__mr llama_7b


 58%|█████▊    | 61/105 [00:50<00:12,  3.59it/s]

acc improvement 0.277
knnp__mr gpt2-large
acc improvement 0.336
knnp__mr gpt2


 60%|██████    | 63/105 [00:51<00:10,  3.83it/s]

acc improvement 0.055
knnp__mr meta-llama/Llama-2-7b-hf
acc improvement 0.422
knnp__rte llama_7b
acc improvement -0.02


 61%|██████    | 64/105 [00:51<00:10,  4.08it/s]

knnp__rte EleutherAI/gpt-j-6B
acc improvement 0.086


 62%|██████▏   | 65/105 [00:51<00:09,  4.29it/s]

knnp__rte gpt2-medium


 63%|██████▎   | 66/105 [00:51<00:11,  3.49it/s]

acc improvement 0.023
knnp__rte gpt2-large


 64%|██████▍   | 67/105 [00:52<00:09,  3.80it/s]

acc improvement 0.012
knnp__rte meta-llama/Llama-2-7b-hf


 65%|██████▍   | 68/105 [00:52<00:14,  2.63it/s]

acc improvement 0.27
knnp__rte gpt2-xl


 66%|██████▌   | 69/105 [00:52<00:11,  3.03it/s]

acc improvement -0.027
knnp__rte gpt2


 68%|██████▊   | 71/105 [00:53<00:08,  3.79it/s]

acc improvement 0.047
knnp__sst2 gpt2-large
acc improvement 0.312
knnp__sst2 EleutherAI/gpt-j-6B


 70%|██████▉   | 73/105 [00:53<00:06,  4.67it/s]

acc improvement 0.414
knnp__sst2 gpt2-xl
acc improvement 0.363
knnp__sst2 gpt2-medium


 71%|███████▏  | 75/105 [00:54<00:06,  4.73it/s]

acc improvement 0.328
knnp__sst2 llama_7b
acc improvement 0.281
knnp__sst2 gpt2


 73%|███████▎  | 77/105 [00:54<00:05,  5.03it/s]

acc improvement 0.156
knnp__sst2 meta-llama/Llama-2-7b-hf
acc improvement 0.426
knnp__subj meta-llama/Llama-2-7b-hf


 75%|███████▌  | 79/105 [00:54<00:04,  5.21it/s]

acc improvement 0.387
knnp__subj gpt2-xl
acc improvement 0.336
knnp__subj llama_7b


 77%|███████▋  | 81/105 [00:55<00:05,  4.75it/s]

acc improvement 0.039
knnp__subj gpt2
acc improvement 0.039
knnp__subj EleutherAI/gpt-j-6B


 79%|███████▉  | 83/105 [00:55<00:04,  5.10it/s]

acc improvement 0.434
knnp__subj gpt2-large
acc improvement 0.379
knnp__subj gpt2-medium


 81%|████████  | 85/105 [00:56<00:03,  5.39it/s]

acc improvement 0.324
knnp__trec EleutherAI/gpt-j-6B
acc improvement 0.34
knnp__trec llama_7b


 83%|████████▎ | 87/105 [00:56<00:04,  3.81it/s]

acc improvement 0.312
knnp__trec gpt2-medium
acc improvement 0.211
knnp__trec gpt2-large


 84%|████████▍ | 88/105 [00:56<00:04,  4.17it/s]

acc improvement 0.312
knnp__trec gpt2


 85%|████████▍ | 89/105 [00:57<00:03,  4.07it/s]

acc improvement 0.223
knnp__trec meta-llama/Llama-2-7b-hf


 87%|████████▋ | 91/105 [00:57<00:03,  4.19it/s]

acc improvement 0.434
knnp__trec gpt2-xl
acc improvement 0.297
rotten_tomatoes gpt2-large


 88%|████████▊ | 92/105 [00:59<00:09,  1.41it/s]

acc improvement 0.336
rotten_tomatoes gpt2-xl


 89%|████████▊ | 93/105 [00:59<00:07,  1.60it/s]

acc improvement 0.337
rotten_tomatoes gpt2


 90%|████████▉ | 94/105 [01:00<00:06,  1.76it/s]

acc improvement 0.265
rotten_tomatoes gpt2-medium


 90%|█████████ | 95/105 [01:00<00:05,  1.91it/s]

acc improvement 0.29
rotten_tomatoes EleutherAI/gpt-j-6B


 91%|█████████▏| 96/105 [01:01<00:04,  1.95it/s]

acc improvement 0.287
rotten_tomatoes llama_7b


 92%|█████████▏| 97/105 [01:01<00:03,  2.01it/s]

acc improvement 0.28
rotten_tomatoes meta-llama/Llama-2-7b-hf


 93%|█████████▎| 98/105 [01:03<00:06,  1.12it/s]

acc improvement 0.356
sst2 meta-llama/Llama-2-7b-hf


 94%|█████████▍| 99/105 [01:05<00:06,  1.10s/it]

acc improvement 0.352
sst2 gpt2-xl


 95%|█████████▌| 100/105 [01:06<00:06,  1.25s/it]

acc improvement 0.384
sst2 llama_7b


 96%|█████████▌| 101/105 [01:11<00:09,  2.31s/it]

acc improvement 0.334
sst2 gpt2-large


 97%|█████████▋| 102/105 [01:13<00:06,  2.09s/it]

acc improvement 0.385
sst2 EleutherAI/gpt-j-6B


 98%|█████████▊| 103/105 [01:17<00:05,  2.89s/it]

acc improvement 0.303
sst2 gpt2-medium


 99%|█████████▉| 104/105 [01:19<00:02,  2.49s/it]

acc improvement 0.267
sst2 gpt2


100%|██████████| 105/105 [01:21<00:00,  1.30it/s]

acc improvement 0.227





In [5]:
with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.max_colwidth', None):
    metadata = r.groupby(['dataset_name', 'prompt_source', 'n_train',
                          'n_test', 'n_outputs', 'checkpoint'])[['n_prompts']].sum()
    display(metadata)
metadata.to_pickle(join(DATA_PROCESSED_DIR, 'metadata.pkl'))

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,n_prompts
dataset_name,prompt_source,n_train,n_test,n_outputs,checkpoint,Unnamed: 6_level_1
emotion,manual,10028.0,1254.0,2.0,EleutherAI/gpt-j-6B,32.0
emotion,manual,10028.0,1254.0,2.0,gpt2,32.0
emotion,manual,10028.0,1254.0,2.0,gpt2-large,32.0
emotion,manual,10028.0,1254.0,2.0,gpt2-medium,32.0
emotion,manual,10028.0,1254.0,2.0,gpt2-xl,32.0
emotion,manual,10028.0,1254.0,2.0,llama_7b,32.0
emotion,manual,10028.0,1254.0,2.0,meta-llama/Llama-2-7b-hf,32.0
financial_phrasebank,manual,871.0,436.0,2.0,EleutherAI/gpt-j-6B,40.0
financial_phrasebank,manual,871.0,436.0,2.0,gpt2,35.0
financial_phrasebank,manual,871.0,436.0,2.0,gpt2-large,38.0
