diff --git a/.circleci/config.yml b/.circleci/config.yml index d59fd1c7..5bae2340 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -160,7 +160,7 @@ commands: echo "Using $(python -V) ($(which python))" echo "Using $(pip -V) ($(which pip))" pip install tensorboard - python examples/cifar10.py --lr 0.1 --sigma 1.5 -c 10 --sample-rate 0.04 --epochs 10 --data-root runs/cifar10/data --log-dir runs/cifar10/logs --device <> + python examples/cifar10.py --lr 0.1 --sigma 1.5 -c 10 --batch-size 2000 --epochs 10 --data-root runs/cifar10/data --log-dir runs/cifar10/logs --device <> python -c "import torch; model = torch.load('model_best.pth.tar'); exit(0) if (model['best_acc1']>0.4 and model['best_acc1']<0.49) else exit(1)" when: always - store_test_results: diff --git a/examples/cifar10.py b/examples/cifar10.py index 90902dff..9f2305dd 100644 --- a/examples/cifar10.py +++ b/examples/cifar10.py @@ -262,7 +262,7 @@ def main(): train_loader = torch.utils.data.DataLoader( train_dataset, - batch_size=int(args.sample_rate * len(train_dataset)), + batch_size=args.batch_size, generator=generator, num_workers=args.workers, pin_memory=True, @@ -421,11 +421,11 @@ def parse_args(): "using Data Parallel or Distributed Data Parallel", ) parser.add_argument( - "--sample-rate", - default=0.04, + "--batch-size", + default=2000, type=float, - metavar="SR", - help="sample rate used for batch construction (default: 0.005)", + metavar="N", + help="train bacth size", ) parser.add_argument( "--lr",