diff --git a/.idea/workspace.xml b/.idea/workspace.xml index 1487a39..025947b 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -2,13 +2,8 @@ - - - - - @@ -223,16 +218,6 @@ - - - - - - - - - - @@ -277,20 +262,30 @@ - - + + - - + + + + + + + + + + + + \ No newline at end of file diff --git a/test.py b/test.py index e9169ae..dc54a3e 100644 --- a/test.py +++ b/test.py @@ -11,30 +11,26 @@ from model import NetworkCIFAR as Network -parser = argparse.ArgumentParser("cifar") +parser = argparse.ArgumentParser("cifar10") parser.add_argument('--data', type=str, default='../data', help='location of the data corpus') -parser.add_argument('--batchsz', type=int, default=16, help='batch size') +parser.add_argument('--batchsz', type=int, default=36, help='batch size') parser.add_argument('--report_freq', type=float, default=50, help='report frequency') parser.add_argument('--gpu', type=int, default=0, help='gpu device id') -parser.add_argument('--init_ch', type=int, default=16, help='num of init channels') -parser.add_argument('--layers', type=int, default=8, help='total number of layers') -parser.add_argument('--exp_path', type=str, default='exp1/model.pt', help='path of pretrained model') +parser.add_argument('--init_ch', type=int, default=36, help='num of init channels') +parser.add_argument('--layers', type=int, default=20, help='total number of layers') +parser.add_argument('--exp_path', type=str, default='exp/model.pt', help='path of pretrained model') parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower') parser.add_argument('--cutout', action='store_true', default=False, help='use cutout') parser.add_argument('--cutout_length', type=int, default=16, help='cutout length') parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability') parser.add_argument('--seed', type=int, default=0, help='random seed') -parser.add_argument('--arch', type=str, default='MyDARTS', help='which architecture to use') +parser.add_argument('--arch', type=str, default='DARTS', help='which architecture to use') args = parser.parse_args() log_format = '%(asctime)s %(message)s' logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p') -CIFAR_CLASSES = 10 - - - def main(): @@ -68,6 +64,7 @@ def main(): def infer(test_queue, model, criterion): + objs = utils.AverageMeter() top1 = utils.AverageMeter() top5 = utils.AverageMeter() diff --git a/train.py b/train.py index 942dc6c..08459b4 100644 --- a/train.py +++ b/train.py @@ -17,7 +17,7 @@ parser = argparse.ArgumentParser("cifar10") parser.add_argument('--data', type=str, default='../data', help='location of the data corpus') -parser.add_argument('--batchsz', type=int, default=48, help='batch size') +parser.add_argument('--batchsz', type=int, default=30, help='batch size') parser.add_argument('--lr', type=float, default=0.025, help='init learning rate') parser.add_argument('--momentum', type=float, default=0.9, help='momentum') parser.add_argument('--wd', type=float, default=3e-4, help='weight decay') @@ -91,11 +91,13 @@ def main(): logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0]) model.drop_path_prob = args.drop_path_prob * epoch / args.epochs + valid_acc, valid_obj = infer(valid_queue, model, criterion) + logging.info('valid_acc: %f', valid_acc) + train_acc, train_obj = train(train_queue, model, criterion, optimizer) logging.info('train_acc: %f', train_acc) - valid_acc, valid_obj = infer(valid_queue, model, criterion) - logging.info('valid_acc: %f', valid_acc) + utils.save(model, os.path.join(args.save, 'trained.pt')) print('saved to: trained.pt')