From 94269aa23a797797726bc252c5d9a34db34737c4 Mon Sep 17 00:00:00 2001 From: Chandan Singh Date: Mon, 12 Feb 2024 20:30:28 -0800 Subject: [PATCH] update expt params --- augdistill/experiments/01_eval.py | 9 +- augdistill/notebooks/01_model_results.ipynb | 591 ++++++++++++++++---- augdistill/scripts/01_eval_basic.py | 37 +- 3 files changed, 488 insertions(+), 149 deletions(-) diff --git a/augdistill/experiments/01_eval.py b/augdistill/experiments/01_eval.py index 3165af9..b673452 100644 --- a/augdistill/experiments/01_eval.py +++ b/augdistill/experiments/01_eval.py @@ -11,6 +11,7 @@ import joblib import imodels import inspect +import torch import os.path import imodelsx.cache_save_utils from imodelsx import AugLinearClassifier @@ -49,7 +50,9 @@ def add_main_args(parser): parser.add_argument( "--embedding_string_prompt", type=str, default="synonym", choices=set(list(EMBEDDING_STRING_SETTINGS.keys()) + ['None']), help="key for embedding string" ) - + parser.add_argument( + '--zeroshot_strategy', type=str, default='pos_class', choices=['pos_class', 'difference'], help='strategy for zeroshot' + ) # training misc args parser.add_argument("--seed", type=int, default=1, help="random seed") parser.add_argument( @@ -107,7 +110,7 @@ def add_computational_args(parser): # set seed np.random.seed(args.seed) random.seed(args.seed) - # torch.manual_seed(args.seed) + torch.manual_seed(args.seed) # load text data dset_val = datasets.load_dataset(args.dataset_name)['validation'] @@ -157,5 +160,5 @@ def add_computational_args(parser): r, join(save_dir_unique, "results.pkl") ) # caching requires that this is called results.pkl # joblib.dump(model, join(save_dir_unique, "model.pkl")) - print(r) + # print(r) logging.info("Succesfully completed :)\n\n") diff --git a/augdistill/notebooks/01_model_results.ipynb b/augdistill/notebooks/01_model_results.ipynb index 236aa54..c336e2c 100644 --- a/augdistill/notebooks/01_model_results.ipynb +++ b/augdistill/notebooks/01_model_results.ipynb @@ -2,15 +2,24 @@ "cells": [ { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" + "[2024-02-12 20:29:10,546] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/chansingh/imodelsx/.venv/lib/python3.11/site-packages/thinc/compat.py:36: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'\n", + " hasattr(torch, \"has_mps\")\n", + "/home/chansingh/imodelsx/.venv/lib/python3.11/site-packages/thinc/compat.py:37: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'\n", + " and torch.has_mps # type: ignore[attr-defined]\n" ] } ], @@ -40,14 +49,14 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 3/3 [00:00<00:00, 2117.27it/s]\n" + "100%|██████████| 19/19 [00:00<00:00, 3859.91it/s]\n" ] } ], @@ -55,150 +64,486 @@ "r = imodelsx.process_results.get_results_df(results_dir)\n", "experiment_filename = '../experiments/01_eval.py'\n", "r = imodelsx.process_results.fill_missing_args_with_default(\n", - " r, experiment_filename)" + " r, experiment_filename)\n", + "\n", + "r = r[['acc_val'] + [c for c in r.columns if not c == 'acc_val']]\n", + "r = r.sort_values('acc_val', ascending=False)" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", - "\n", - "\n", + "
\n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - "
dataset_namecheckpointngramsuse_all_ngramsuse_next_token_distr_embeddingembedding_string_promptseedsave_dirmodel_nameuse_cachebatch_sizesave_dir_uniqueroc_valacc_valacc_baselinemean_pred
 acc_valdataset_namecheckpointngramsuse_all_ngramsuse_next_token_distr_embeddingembedding_string_promptzeroshot_strategyseedsave_diruse_cachebatch_sizesave_dir_uniqueroc_valacc_baselinemean_pred
0rotten_tomatoestextattack/distilbert-base-uncased-rotten-toma...210None1/home/chansingh/augmented-interpretable-models...decision_tree18/home/chansingh/augmented-interpretable-models...0.620.5933330.4866670.56
1rotten_tomatoesbert-base-uncased210None1/home/chansingh/augmented-interpretable-models...decision_tree18/home/chansingh/augmented-interpretable-models...0.490.5066670.4866670.70
2rotten_tomatoeshkunlp/instructor-xl210instructor_sentiment1/home/chansingh/augmented-interpretable-models...decision_tree18/home/chansingh/augmented-interpretable-models...0.620.6266670.4866670.50150.63rotten_tomatoeshkunlp/instructor-xl210instructor_sentimentpos_class1/home/chansingh/augmented-interpretable-models/augdistill/results18/home/chansingh/augmented-interpretable-models/augdistill/results/bd5ac3875c805a0e67a5b648fcc59e0a0b5e9abff5d469fc5f2b244c3496c6db0.620.490.50
40.59rotten_tomatoestextattack/distilbert-base-uncased-rotten-tomatoes210Nonepos_class1/home/chansingh/augmented-interpretable-models/augdistill/results18/home/chansingh/augmented-interpretable-models/augdistill/results/1a41b8dab1b5bfe4b08cfa3ba9d1ed0736b56e39780cda131d231d1af44afc640.620.490.56
20.55rotten_tomatoesmeta-llama/Llama-2-7b-hf211synonympos_class1/home/chansingh/augmented-interpretable-models/augdistill/results18/home/chansingh/augmented-interpretable-models/augdistill/results/0f62167b4ec16d0a661a7ca02177350ac82d1743cfe92c39ef9fdd82bda804dd0.530.490.39
170.55rotten_tomatoesmeta-llama/Llama-2-7b-hf211synonymdifference1/home/chansingh/augmented-interpretable-models/augdistill/results18/home/chansingh/augmented-interpretable-models/augdistill/results/d7868bedeea734468190e50a82a9e25e6c3c55ac2bb1fdc5abc2a51dfd069e530.530.490.39
30.51rotten_tomatoesmistralai/Mistral-7B-v0.1211movie_sentimentdifference1/home/chansingh/augmented-interpretable-models/augdistill/results18/home/chansingh/augmented-interpretable-models/augdistill/results/0fbf0437785d09339cfc969a7035719bee6815edd3e4ab7bafbe0a2b637719e20.510.490.64
70.51rotten_tomatoesmistralai/Mistral-7B-v0.1211movie_sentimentpos_class1/home/chansingh/augmented-interpretable-models/augdistill/results18/home/chansingh/augmented-interpretable-models/augdistill/results/3ec7890e8e37b6ee89ecbe9c723f8bb53bc11b459641c469188b2c33418951730.510.490.64
120.51rotten_tomatoesbert-base-uncased210Nonepos_class1/home/chansingh/augmented-interpretable-models/augdistill/results18/home/chansingh/augmented-interpretable-models/augdistill/results/89d804ff462734dded1c8d121761efe6a8b882396c1a38f463e5e4096050a63e0.490.490.70
140.49rotten_tomatoesgpt2-xl211movie_sentimentpos_class1/home/chansingh/augmented-interpretable-models/augdistill/results18/home/chansingh/augmented-interpretable-models/augdistill/results/9f080515fba8ef69a7f7e0322c514985ef701d3181e37f07de9f1e72108584d90.470.490.58
110.49rotten_tomatoesgpt2-xl211movie_sentimentdifference1/home/chansingh/augmented-interpretable-models/augdistill/results18/home/chansingh/augmented-interpretable-models/augdistill/results/81a91f52e82394957676e66cd65dfd9dcb377592104e7c52e512649529191e830.470.490.58
160.48rotten_tomatoesmistralai/Mistral-7B-v0.1211synonympos_class1/home/chansingh/augmented-interpretable-models/augdistill/results18/home/chansingh/augmented-interpretable-models/augdistill/results/d4737ff416b5fe2dde57987451785b2fa2107252a8888ee6eef29f517f8a09840.530.490.35
10.48rotten_tomatoesmistralai/Mistral-7B-v0.1211synonymdifference1/home/chansingh/augmented-interpretable-models/augdistill/results18/home/chansingh/augmented-interpretable-models/augdistill/results/0c89e0508fbf4d9b35eedc15206d2283478f76239e5a457458b1c5368d9ba3df0.530.490.35
90.47rotten_tomatoesgpt2-xl211synonympos_class1/home/chansingh/augmented-interpretable-models/augdistill/results18/home/chansingh/augmented-interpretable-models/augdistill/results/645fb3ab6709f9bf6e2b9e9bacbf115bb07c39c6ad7f3638f73f1a81696349450.440.490.53
80.47rotten_tomatoesgpt2-xl211synonymdifference1/home/chansingh/augmented-interpretable-models/augdistill/results18/home/chansingh/augmented-interpretable-models/augdistill/results/4cfb28799569489386f1dda7062f3313e80d735195557d5c42495ed4d65d169f0.440.490.53
00.47rotten_tomatoesgpt2211synonympos_class1/home/chansingh/augmented-interpretable-models/augdistill/results18/home/chansingh/augmented-interpretable-models/augdistill/results/05a0ea0b6832a65d1f5e01a496f4aaaeae8fa031edc6bd42f4fc40a988dcad870.470.490.45
130.47rotten_tomatoesgpt2211synonymdifference1/home/chansingh/augmented-interpretable-models/augdistill/results18/home/chansingh/augmented-interpretable-models/augdistill/results/9129bc24fc4b7bbc3ac5d00e5a95bd5dbd4dfd63990d896ebd68476047a79e6b0.470.490.45
60.45rotten_tomatoesgpt2211movie_sentimentpos_class1/home/chansingh/augmented-interpretable-models/augdistill/results18/home/chansingh/augmented-interpretable-models/augdistill/results/30d5fbec069b556534642cdbded906dbc0b74494da45e716d153bc3737f2771c0.460.490.50
50.45rotten_tomatoesgpt2211movie_sentimentdifference1/home/chansingh/augmented-interpretable-models/augdistill/results18/home/chansingh/augmented-interpretable-models/augdistill/results/27fd34241cdd2fdb243a8ab5bceb88b877cbce78e6ef9dc57ba6c02c26f54c220.460.490.50
100.41rotten_tomatoesmeta-llama/Llama-2-7b-hf211movie_sentimentpos_class1/home/chansingh/augmented-interpretable-models/augdistill/results18/home/chansingh/augmented-interpretable-models/augdistill/results/8100aad2235cf37c339436d647581b621aa473f1b135061e49670c52b5eef7c60.350.490.45
180.41rotten_tomatoesmeta-llama/Llama-2-7b-hf211movie_sentimentdifference1/home/chansingh/augmented-interpretable-models/augdistill/results18/home/chansingh/augmented-interpretable-models/augdistill/results/e6b16955710de162534e262a330de205516f38df150f57ce2adf8a01cbdd302c0.350.490.45
\n", - "
" + "\n" ], "text/plain": [ - " dataset_name checkpoint ngrams \\\n", - "0 rotten_tomatoes textattack/distilbert-base-uncased-rotten-toma... 2 \n", - "1 rotten_tomatoes bert-base-uncased 2 \n", - "2 rotten_tomatoes hkunlp/instructor-xl 2 \n", - "\n", - " use_all_ngrams use_next_token_distr_embedding embedding_string_prompt \\\n", - "0 1 0 None \n", - "1 1 0 None \n", - "2 1 0 instructor_sentiment \n", - "\n", - " seed save_dir model_name \\\n", - "0 1 /home/chansingh/augmented-interpretable-models... decision_tree \n", - "1 1 /home/chansingh/augmented-interpretable-models... decision_tree \n", - "2 1 /home/chansingh/augmented-interpretable-models... decision_tree \n", - "\n", - " use_cache batch_size save_dir_unique \\\n", - "0 1 8 /home/chansingh/augmented-interpretable-models... \n", - "1 1 8 /home/chansingh/augmented-interpretable-models... \n", - "2 1 8 /home/chansingh/augmented-interpretable-models... \n", - "\n", - " roc_val acc_val acc_baseline mean_pred \n", - "0 0.62 0.593333 0.486667 0.56 \n", - "1 0.49 0.506667 0.486667 0.70 \n", - "2 0.62 0.626667 0.486667 0.50 " + "" ] }, - "execution_count": 5, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "r" + "# color the acc_val column\n", + "display(\n", + " r\n", + " .style\n", + " .background_gradient(\n", + " cmap='viridis', subset=['acc_val']\n", + " )\n", + " .format(precision=2)\n", + ")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/augdistill/scripts/01_eval_basic.py b/augdistill/scripts/01_eval_basic.py index 7b1dff0..fb5e49e 100644 --- a/augdistill/scripts/01_eval_basic.py +++ b/augdistill/scripts/01_eval_basic.py @@ -3,39 +3,28 @@ import os.path repo_dir = dirname(dirname(os.path.abspath(__file__))) -# Showcasing different ways to sweep over arguments -# Can pass any empty dict for any of these to avoid sweeping - -# List of values to sweep over (sweeps over all combinations of these) params_shared_dict = { 'seed': [1], 'save_dir': [join(repo_dir, 'results')], - # pass binary values with 0/1 instead of the ambiguous strings True/False 'use_cache': [1], } -# List of tuples to sweep over (these values are coupled, and swept over together) -# Note: this is a dictionary so you shouldn't have repeated keys params_coupled_dict = { - ('checkpoint', 'embedding_string_prompt', 'use_next_token_distr_embedding'): [ - ('bert-base-uncased', None, 0), - ('textattack/distilbert-base-uncased-rotten-tomatoes', None, 0), - ('hkunlp/instructor-xl', 'instructor_sentiment', 0), - - ('gpt2', 'synonym', 1), - ('gpt2-xl', 'synonym', 1), - ('meta-llama/Llama-2-7b-hf', 'synonym', 1), - ('mistralai/Mistral-7B-v0.1', 'synonym', 1), + # ('checkpoint', 'embedding_string_prompt', 'use_next_token_distr_embedding'): [ + # ('bert-base-uncased', None, 0), + # ('textattack/distilbert-base-uncased-rotten-tomatoes', None, 0), + # ('hkunlp/instructor-xl', 'instructor_sentiment', 0), + # ], + ('checkpoint', 'embedding_string_prompt', 'use_next_token_distr_embedding', 'zeroshot_strategy'): [ - ('gpt2', 'movie_sentiment', 1), - ('gpt2-xl', 'movie_sentiment', 1), - ('meta-llama/Llama-2-7b-hf', 'movie_sentiment', 1), - ('mistralai/Mistral-7B-v0.1', 'movie_sentiment', 1), + (checkpoint, string_prompt, 1, zeroshot_strategy) + # for checkpoint in ['gpt2', 'gpt2-xl', 'meta-llama/Llama-2-7b-hf', 'mistralai/Mistral-7B-v0.1'] + for checkpoint in ['mistralai/Mixtral-8x7B-v0.1', 'meta-llama/Llama-2-13b-hf'] + for string_prompt in ['synonym', 'movie_sentiment'] + for zeroshot_strategy in ['pos_class', 'difference'] ], } -# Args list is a list of dictionaries -# If you want to do something special to remove some of these runs, can remove them before calling run_args_list args_list = submit_utils.get_args_list( params_shared_dict=params_shared_dict, params_coupled_dict=params_coupled_dict, @@ -44,5 +33,7 @@ args_list, script_name=join(repo_dir, 'experiments', '01_eval.py'), actually_run=True, - gpu_ids=[0, 1, 2, 3], + # gpu_ids=[0, 1, 2, 3], + gpu_ids=[[0, 1], [2, 3]], + # debug_mode=True, )