Skip to content

Commit

Permalink
update expt params
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Feb 13, 2024
1 parent 0eb7794 commit 94269aa
Show file tree
Hide file tree
Showing 3 changed files with 488 additions and 149 deletions.
9 changes: 6 additions & 3 deletions augdistill/experiments/01_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import joblib
import imodels
import inspect
import torch
import os.path
import imodelsx.cache_save_utils
from imodelsx import AugLinearClassifier
Expand Down Expand Up @@ -49,7 +50,9 @@ def add_main_args(parser):
parser.add_argument(
"--embedding_string_prompt", type=str, default="synonym", choices=set(list(EMBEDDING_STRING_SETTINGS.keys()) + ['None']), help="key for embedding string"
)

parser.add_argument(
'--zeroshot_strategy', type=str, default='pos_class', choices=['pos_class', 'difference'], help='strategy for zeroshot'
)
# training misc args
parser.add_argument("--seed", type=int, default=1, help="random seed")
parser.add_argument(
Expand Down Expand Up @@ -107,7 +110,7 @@ def add_computational_args(parser):
# set seed
np.random.seed(args.seed)
random.seed(args.seed)
# torch.manual_seed(args.seed)
torch.manual_seed(args.seed)

# load text data
dset_val = datasets.load_dataset(args.dataset_name)['validation']
Expand Down Expand Up @@ -157,5 +160,5 @@ def add_computational_args(parser):
r, join(save_dir_unique, "results.pkl")
) # caching requires that this is called results.pkl
# joblib.dump(model, join(save_dir_unique, "model.pkl"))
print(r)
# print(r)
logging.info("Succesfully completed :)\n\n")
Loading

0 comments on commit 94269aa

Please sign in to comment.