Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
216 lines (203 sloc) 5.81 KB
"""Run experiments with CLI"""
import argparse
import os
import shutil
from .experiments import registry
def _add_single_value_flag(parser, *flags, value_type, default, help_text, choices=None):
    """Register a flag on *parser* that consumes exactly one value.

    The flag is declared with ``nargs=1``, so argparse stores a one-element
    list; the default is wrapped in a list for the same reason, which lets
    callers always unwrap the value with ``[0]``. The metavar is the last
    (long) option string without its leading dashes.
    """
    parser.add_argument(
        *flags,
        metavar=flags[-1].lstrip("-"),
        type=value_type,
        nargs=1,
        default=[default],
        choices=choices,
        help=help_text,
    )


def init_flags(argv=None):
    """Init command line flags used for experiment configuration.

    Args:
        argv: optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, matching the
            previous behaviour; passing a list is useful for tests and
            programmatic invocation.

    Returns:
        The parsed ``argparse.Namespace``. Every value-taking flag is stored
        as a one-element list (``nargs=1``); ``adv_train`` and ``force`` are
        plain booleans.
    """
    parser = argparse.ArgumentParser(
        description="Runs experiments to find robust, sparse networks.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Default output location: an "output" directory next to this file.
    base_dir_default = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "output"
    )

    # Registration order is preserved so --help output is unchanged.
    _add_single_value_flag(
        parser, "--trials", value_type=int, default=10,
        help_text="number of trials per experiment",
    )
    _add_single_value_flag(
        parser, "--train_iters", value_type=int, default=50000,
        help_text="number of training iterations",
    )
    _add_single_value_flag(
        parser, "--prune_iters", value_type=int, default=20,
        help_text="number of pruning iterations",
    )
    _add_single_value_flag(
        parser, "--eval_every", value_type=int, default=500,
        help_text="number of training iterations between validation-set evaluations",
    )
    _add_single_value_flag(
        parser, "--batch_size", value_type=int, default=60,
        help_text="batch size",
    )
    _add_single_value_flag(
        parser, "--valid_size", value_type=int, default=10000,
        help_text="validation set size",
    )
    _add_single_value_flag(
        parser, "--dataset", value_type=str, default="digits",
        choices=["digits", "fashion"],
        help_text="source dataset",
    )
    _add_single_value_flag(
        parser, "--model", value_type=str, default="dense-300-100",
        choices=["dense-300-100"],
        help_text="model type",
    )
    _add_single_value_flag(
        parser, "--experiment", value_type=str, default="reinit_orig",
        choices=["no_pruning", "reinit_rand", "reinit_orig", "reinit_none"],
        help_text="the experiment to run",
    )
    _add_single_value_flag(
        parser, "--base_dir", value_type=str, default=base_dir_default,
        help_text="base output directory for results and checkpoints",
    )
    _add_single_value_flag(
        parser, "--attack", value_type=str, default="fgsm",
        choices=["fgsm", "pgd"],
        help_text="adversarial attack used for training and evaluation",
    )
    parser.add_argument(
        "--adv_train",
        action="store_true",
        default=False,
        help="use adversarial training for the given attack method",
    )
    _add_single_value_flag(
        parser, "-lr", "--learning_rate", value_type=float, default=0.0012,
        help_text="model's learning rate",
    )
    _add_single_value_flag(
        parser, "-l1", "--l1_reg", value_type=float, default=0.0,
        help_text="l1 regularization penalty",
    )
    _add_single_value_flag(
        parser, "--devices", value_type=str, default="0,1,2,3",
        help_text="gpu devices",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        default=False,
        help="force train, deleting old experiment dirs if existing.",
    )
    return parser.parse_args(argv)
def parse_args(args):
    """Convert parsed CLI flags into the flat ``hparams`` configuration dict.

    Value-taking flags are stored on *args* as one-element lists and are
    unwrapped here. The boolean switches ``adv_train`` and ``force`` are
    copied unchanged.

    ``base_dir`` is expanded to
    ``<base_dir>/<dataset>/<model>/<experiment>/<attack>/lr-<lr>_l1-<l1>_advtrain-<true|false>``,
    and a fixed per-layer ``percents`` table is attached. The resulting
    configuration is printed to stdout before it is returned.

    Args:
        args: the namespace produced by the CLI parser.

    Returns:
        dict of hyperparameters keyed by flag name, plus ``percents``.
    """
    def unwrap(flag):
        # Single-value flags are stored as one-element lists.
        return getattr(args, flag)[0]

    learning_rate = unwrap("learning_rate")
    l1_reg = unwrap("l1_reg")
    run_name = "lr-{}_l1-{}_advtrain-{}".format(
        learning_rate, l1_reg, str(args.adv_train).lower()
    )
    output_dir = os.path.join(
        unwrap("base_dir"),
        unwrap("dataset"),
        unwrap("model"),
        unwrap("experiment"),
        unwrap("attack"),
        run_name,
    )

    # Insertion order matters: it fixes the order of the printout below.
    hparams = {
        name: unwrap(name)
        for name in (
            "trials",
            "train_iters",
            "prune_iters",
            "eval_every",
            "batch_size",
            "valid_size",
            "dataset",
            "model",
            "experiment",
            "attack",
        )
    }
    hparams.update(
        adv_train=args.adv_train,
        base_dir=output_dir,
        learning_rate=learning_rate,
        l1_reg=l1_reg,
        devices=unwrap("devices"),
        force=args.force,
        percents={"hidden_1": 0.2, "hidden_2": 0.2, "output": 0.1},
    )

    rows = ["\t{:>13}: {}".format(key, val) for key, val in hparams.items()]
    print("-" * 40, "hparams", "-" * 40)
    print("Beginning experiments using the following configuration:\n")
    print("\n".join(rows))
    print()
    print("-" * 89)
    return hparams
def main():
    """Parse command line arguments and run the selected experiment.

    Steps:
      1. Build ``hparams`` from the CLI via ``init_flags`` + ``parse_args``.
      2. Export ``hparams["devices"]`` as ``CUDA_VISIBLE_DEVICES``.
      3. Prepare a fresh output directory at ``hparams["base_dir"]``. An
         existing path there is removed only when ``--force`` was given.
      4. Look up the experiment function in ``registry`` and call it with
         ``(dataset, model, hparams)``.

    Raises:
        FileExistsError: if ``hparams["base_dir"]`` already exists and
            ``--force`` was not passed.
    """
    hparams = parse_args(init_flags())
    # NOTE(review): only takes effect if no CUDA context exists yet; confirm
    # that importing `.experiments` does not initialise a GPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = hparams["devices"]

    base_dir = hparams["base_dir"]
    if os.path.exists(base_dir):
        if not hparams["force"]:
            raise FileExistsError(
                "directory '{}' already exists. ".format(base_dir)
                + "Run with --force to overwrite."
            )
        # --force: discard the previous run. rmtree refuses symlinks and
        # plain files, so those are unlinked instead (never the link target).
        if os.path.isdir(base_dir) and not os.path.islink(base_dir):
            shutil.rmtree(base_dir)
        else:
            os.remove(base_dir)
    os.makedirs(base_dir)

    run_fn = registry.get_experiment_fn(hparams["experiment"])
    run_fn(hparams["dataset"], hparams["model"], hparams)


if __name__ == "__main__":
    main()
You can’t perform that action at this time.