In [None]:
import os
import sys
import time
import glob
import numpy as np
import torch
import utils
import logging
import argparse
import torch.nn as nn
import genotypes
import torch.utils
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn

from torch.autograd import Variable
from model import NetworkCIFAR as Network


# Hyper-parameters and run options for the CIFAR evaluation run.
parser = argparse.ArgumentParser("cifar")
parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
parser.add_argument('--batch_size', type=int, default=96, help='batch size')
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
# The report frequency is a number of steps (it is used as `step % report_freq`),
# so it is parsed as an integer rather than a float.
parser.add_argument('--report_freq', type=int, default=50, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--epochs', type=int, default=600, help='num of training epochs')
parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
parser.add_argument('--layers', type=int, default=20, help='total number of layers')
parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
# NOTE: --auxiliary and --cutout are store_true flags that already default to
# True, so passing them on the command line does not change their value.
parser.add_argument('--auxiliary', action='store_true', default=True, help='use auxiliary tower')
parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--arch', type=str, default='Random_NSAS', help='which architecture to use')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
# An explicit empty argument list is parsed, so sys.argv is ignored and every
# option takes the default declared above.
args = parser.parse_args([])

# Give each run its own output directory named 'eval-<save>-<timestamp>'.
# utils.create_exp_dir is a project helper; it is handed the directory name and
# the list of *.py files in the working directory (presumably to snapshot them
# alongside the results -- verify against utils).
run_stamp = time.strftime("%Y%m%d-%H%M%S")
args.save = 'eval-{}-{}'.format(args.save, run_stamp)
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

# Send INFO-level log records to stdout, and mirror them into
# <run directory>/log.txt through an extra handler on the root logger.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=log_format,
    datefmt='%m/%d %I:%M:%S %p',
)
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
root_logger = logging.getLogger()
root_logger.addHandler(fh)

# Number of target classes handed to the network constructor in main().
CIFAR_CLASSES = 10


def main():
    """Train the architecture named by ``args.arch`` on CIFAR-10.

    Builds the network, SGD optimizer and cosine learning-rate schedule from
    the module-level ``args``, then for ``args.epochs`` epochs runs one
    training pass and one validation pass, logs both accuracies, and
    overwrites ``<args.save>/weights.pt`` with the current model.

    Exits the process with status 1 when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # Seed NumPy and PyTorch (CPU and CUDA) and select the GPU.
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d', args.gpu)
    logging.info("args = %s", args)

    # Plain attribute lookup on the genotypes module; eval() would execute
    # whatever expression the --arch string happened to contain.
    genotype = getattr(genotypes, args.arch)
    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
    model = model.cuda()

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    train_transform, valid_transform = utils._data_transforms_cifar10(args)
    train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
    valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=2)

    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=64, shuffle=False, pin_memory=True, num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))

    for epoch in range(args.epochs):
        # get_last_lr() reports the rate the optimizer is about to use;
        # get_lr() is not meant to be called outside scheduler.step().
        logging.info('epoch %d lr %e', epoch, scheduler.get_last_lr()[0])
        # Drop-path probability grows linearly from 0 towards args.drop_path_prob.
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc %f', train_acc)

        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

        utils.save(model, os.path.join(args.save, 'weights.pt'))

        # Advance the schedule only after this epoch's optimizer steps
        # (PyTorch >= 1.1 ordering); stepping first would skip the initial
        # learning rate.
        scheduler.step()


def train(train_queue, model, criterion, optimizer):
    """Run one training epoch over ``train_queue``.

    Args:
        train_queue: iterable yielding ``(input, target)`` batches.
        model: network returning ``(logits, logits_aux)`` when called.
        criterion: loss applied to ``(logits, target)``.
        optimizer: optimizer stepped once per batch.

    Returns:
        ``(top1.avg, objs.avg)`` -- the running averages kept by the
        ``utils.AvgrageMeter`` instances for top-1 accuracy and for the loss.
    """
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    model.train()

    for step, (images, target) in enumerate(train_queue):
        # `non_blocking` replaces the old `async` keyword, which is a reserved
        # word (and therefore a SyntaxError) from Python 3.7 onwards.
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        optimizer.zero_grad()

        logits, logits_aux = model(images)
        loss = criterion(logits, target)
        if args.auxiliary:
            # Add the auxiliary head's loss, scaled by its configured weight.
            loss_aux = criterion(logits_aux, target)
            loss = loss + args.auxiliary_weight * loss_aux
        loss.backward()
        # In-place variant; the un-suffixed clip_grad_norm is deprecated.
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        n = images.size(0)
        # .item() stores a Python float in the meter instead of a tensor.
        objs.update(loss.item(), n)
        top1.update(prec1, n)
        top5.update(prec5, n)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

    return top1.avg, objs.avg


def infer(valid_queue, model, criterion):
    """Evaluate ``model`` on every batch of ``valid_queue`` without gradients.

    Args:
        valid_queue: iterable yielding ``(input, target)`` batches.
        model: network returning a pair whose first element is the logits.
        criterion: loss applied to ``(logits, target)``.

    Returns:
        ``(top1.avg, objs.avg)`` -- the running averages kept by the
        ``utils.AvgrageMeter`` instances for top-1 accuracy and for the loss.
    """
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    model.eval()

    # torch.no_grad() replaces Variable(..., volatile=True), which no longer
    # disables autograd tracking.
    with torch.no_grad():
        for step, (images, target) in enumerate(valid_queue):
            # `non_blocking` replaces the old `async` keyword, which is a
            # reserved word (SyntaxError) from Python 3.7 onwards.
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            logits, _ = model(images)
            loss = criterion(logits, target)

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = images.size(0)
            # .item() stores a Python float in the meter instead of a tensor.
            objs.update(loss.item(), n)
            top1.update(prec1, n)
            top5.update(prec5, n)

            if step % args.report_freq == 0:
                logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

    return top1.avg, objs.avg


# Start training only when this code is executed directly (script or notebook
# cell), not when it is imported as a module.
if __name__ == '__main__':
    main() 



Experiment dir : eval-EXP-20191106-143600
11/06 02:36:00 PM gpu device = 0
11/06 02:36:00 PM args = Namespace(arch='ENNAS_S4', auxiliary=True, auxiliary_weight=0.4, batch_size=96, cutout=True, cutout_length=16, data='../data', drop_path_prob=0.2, epochs=600, gpu=0, grad_clip=5, init_channels=36, layers=20, learning_rate=0.025, model_path='saved_models', momentum=0.9, report_freq=50, save='eval-EXP-20191106-143600', seed=0, weight_decay=0.0003)
108 108 36
108 144 36
144 144 36
144 144 36
144 144 36
144 144 36
144 144 72
144 288 72
288 288 72
288 288 72
288 288 72
288 288 72
288 288 72
288 288 144
288 576 144
576 576 144
576 576 144
576 576 144
576 576 144
576 576 144
11/06 02:36:03 PM param size = 3.083014MB
Files already downloaded and verified
Files already downloaded and verified
11/06 02:36:04 PM epoch 0 lr 2.500000e-02
11/06 02:36:06 PM train 000 3.245519e+00 13.541667 47.916668




11/06 02:36:34 PM train 050 3.164259e+00 15.584150 61.887257
11/06 02:37:01 PM train 100 3.022205e+00 18.440594 69.636963
11/06 02:37:29 PM train 150 2.921169e+00 20.854029 73.772072
11/06 02:37:56 PM train 200 2.844364e+00 22.761194 76.300789
11/06 02:38:24 PM train 250 2.787577e+00 24.302790 78.012947
11/06 02:38:51 PM train 300 2.744956e+00 25.415283 79.107834
11/06 02:39:19 PM train 350 2.697217e+00 26.759851 80.211296
11/06 02:39:46 PM train 400 2.653530e+00 28.067852 81.182465
11/06 02:40:13 PM train 450 2.615200e+00 29.076590 81.940598
11/06 02:40:41 PM train 500 2.586664e+00 29.906853 82.636810
11/06 02:40:52 PM train_acc 30.184000




11/06 02:40:53 PM valid 000 1.598739e+00 37.500000 90.625000
11/06 02:40:58 PM valid 050 1.705861e+00 38.694855 89.062500
11/06 02:41:02 PM valid 100 1.698782e+00 38.845917 88.892326
11/06 02:41:07 PM valid 150 1.696674e+00 39.093544 89.114235
11/06 02:41:07 PM valid_acc 38.939999
11/06 02:41:08 PM epoch 1 lr 2.499983e-02
11/06 02:41:09 PM train 000 2.271325e+00 38.541664 90.625000
11/06 02:41:38 PM train 050 2.220511e+00 40.277779 90.931374
11/06 02:42:07 PM train 100 2.203785e+00 40.965347 90.583748
11/06 02:42:37 PM train 150 2.194829e+00 41.280354 90.673286
11/06 02:43:07 PM train 200 2.169698e+00 42.076077 90.697556
11/06 02:43:36 PM train 250 2.151849e+00 42.563084 90.911354
11/06 02:44:06 PM train 300 2.138197e+00 43.002491 91.047203
11/06 02:44:35 PM train 350 2.121051e+00 43.604580 91.245247
11/06 02:45:04 PM train 400 2.105440e+00 44.038342 91.445869
11/06 02:45:33 PM train 450 2.090255e+00 44.521435 91.571968
11/06 02:46:03 PM train 500 2.071321e+00 45.120174 91.760223
11/06

11/06 03:23:39 PM train_acc 74.377998
11/06 03:23:39 PM valid 000 9.190739e-01 71.875000 96.875000
11/06 03:23:44 PM valid 050 8.388569e-01 72.579659 98.590691
11/06 03:23:48 PM valid 100 8.398958e-01 72.385521 98.530319
11/06 03:23:53 PM valid 150 8.350059e-01 72.247513 98.520279
11/06 03:23:53 PM valid_acc 72.220001
11/06 03:23:54 PM epoch 9 lr 2.498612e-02
11/06 03:23:54 PM train 000 1.130695e+00 72.916672 97.916672
11/06 03:24:24 PM train 050 9.659551e-01 76.000816 98.672386
11/06 03:24:53 PM train 100 9.522829e-01 76.278877 98.617989
11/06 03:25:22 PM train 150 9.672667e-01 75.903694 98.585815
11/06 03:25:52 PM train 200 9.654325e-01 75.860283 98.626656
11/06 03:26:21 PM train 250 9.687029e-01 75.805115 98.626328
11/06 03:26:50 PM train 300 9.642602e-01 75.792496 98.619186
11/06 03:27:20 PM train 350 9.707095e-01 75.664764 98.566589
11/06 03:27:49 PM train 400 9.714072e-01 75.657211 98.545303
11/06 03:28:18 PM train 450 9.726701e-01 75.639778 98.542587
11/06 03:28:48 PM train 500 

11/06 04:06:19 PM train 500 7.532424e-01 81.362274 99.081001
11/06 04:06:30 PM train_acc 81.369995
11/06 04:06:31 PM valid 000 2.587756e-01 95.312500 100.000000
11/06 04:06:35 PM valid 050 4.647587e-01 84.803925 98.988976
11/06 04:06:40 PM valid 100 4.613133e-01 84.591583 99.071777
11/06 04:06:44 PM valid 150 4.558442e-01 84.623344 99.172188
11/06 04:06:45 PM valid_acc 84.629997
11/06 04:06:45 PM epoch 17 lr 2.495051e-02
11/06 04:06:46 PM train 000 5.425017e-01 88.541672 100.000000
11/06 04:07:15 PM train 050 6.989810e-01 82.516342 98.999184
11/06 04:07:44 PM train 100 7.086259e-01 82.549507 99.092415
11/06 04:08:13 PM train 150 7.019632e-01 82.829742 99.116997
11/06 04:08:43 PM train 200 7.168804e-01 82.441956 99.077530
11/06 04:09:12 PM train 250 7.161223e-01 82.519920 99.086990
11/06 04:09:42 PM train 300 7.196797e-01 82.478546 99.058693
11/06 04:10:11 PM train 350 7.102574e-01 82.624046 99.088913
11/06 04:10:41 PM train 400 7.143474e-01 82.559227 99.083023
11/06 04:11:10 PM train 4

11/06 04:48:38 PM train 450 6.270215e-01 84.451218 99.406410
11/06 04:49:08 PM train 500 6.290877e-01 84.377075 99.405350
11/06 04:49:19 PM train_acc 84.374001
11/06 04:49:19 PM valid 000 5.125998e-01 85.937500 100.000000
11/06 04:49:24 PM valid 050 4.597166e-01 85.569855 98.988976
11/06 04:49:29 PM valid 100 4.450559e-01 85.457916 99.211014
11/06 04:49:33 PM valid 150 4.451975e-01 85.492546 99.296356
11/06 04:49:34 PM valid_acc 85.489998
11/06 04:49:34 PM epoch 25 lr 2.489306e-02
11/06 04:49:34 PM train 000 4.245305e-01 89.583336 100.000000
11/06 04:50:04 PM train 050 5.827578e-01 85.968140 99.468956
11/06 04:50:33 PM train 100 6.083041e-01 85.055695 99.432762
11/06 04:51:02 PM train 150 6.064464e-01 85.175224 99.455025
11/06 04:51:32 PM train 200 6.156759e-01 84.898422 99.419571
11/06 04:52:02 PM train 250 6.188223e-01 84.777557 99.365044
11/06 04:52:32 PM train 300 6.198094e-01 84.793747 99.380539
11/06 04:53:01 PM train 350 6.181993e-01 84.799377 99.394585
11/06 04:53:31 PM train 4

11/06 05:30:56 PM train 400 5.624021e-01 86.250519 99.542816
11/06 05:31:25 PM train 450 5.659025e-01 86.141907 99.524208
11/06 05:31:55 PM train 500 5.643611e-01 86.148537 99.517632
11/06 05:32:07 PM train_acc 86.159996
11/06 05:32:07 PM valid 000 4.488550e-01 84.375000 100.000000
11/06 05:32:11 PM valid 050 4.084399e-01 86.917892 99.571083
11/06 05:32:16 PM valid 100 3.799438e-01 87.670174 99.489479
11/06 05:32:20 PM valid 150 3.791148e-01 87.458611 99.575745
11/06 05:32:21 PM valid_acc 87.409996
11/06 05:32:21 PM epoch 33 lr 2.481387e-02
11/06 05:32:22 PM train 000 5.770319e-01 82.291672 98.958336
11/06 05:32:52 PM train 050 5.290713e-01 87.132355 99.611931
11/06 05:33:21 PM train 100 5.433530e-01 86.953384 99.504951
11/06 05:33:50 PM train 150 5.345740e-01 87.092995 99.544701
11/06 05:34:20 PM train 200 5.396619e-01 86.893654 99.554314
11/06 05:34:49 PM train 250 5.425290e-01 86.848442 99.539345
11/06 05:35:18 PM train 300 5.354113e-01 86.973976 99.581261
11/06 05:35:48 PM train 35

11/06 06:13:10 PM train 350 5.093474e-01 87.672127 99.596390
11/06 06:13:39 PM train 400 5.082907e-01 87.661057 99.605156
11/06 06:14:09 PM train 450 5.106772e-01 87.587769 99.595802
11/06 06:14:38 PM train 500 5.119432e-01 87.570686 99.590401
11/06 06:14:50 PM train_acc 87.589996
11/06 06:14:50 PM valid 000 5.575035e-01 82.812500 100.000000
11/06 06:14:54 PM valid 050 4.391784e-01 86.182602 99.387260
11/06 06:14:59 PM valid 100 4.115392e-01 86.881187 99.474007
11/06 06:15:03 PM valid 150 4.082235e-01 86.827400 99.524010
11/06 06:15:04 PM valid_acc 86.809998
11/06 06:15:04 PM epoch 41 lr 2.471307e-02
11/06 06:15:05 PM train 000 5.262319e-01 84.375000 100.000000
11/06 06:15:34 PM train 050 5.104333e-01 87.132355 99.591507
11/06 06:16:03 PM train 100 5.028355e-01 87.654709 99.587463
11/06 06:16:33 PM train 150 5.189091e-01 86.975716 99.634384
11/06 06:17:02 PM train 200 5.121628e-01 87.209785 99.616501
11/06 06:17:31 PM train 250 5.110152e-01 87.226097 99.605743
11/06 06:18:01 PM train 3

11/06 06:55:24 PM train 300 4.828324e-01 88.091782 99.681618
11/06 06:55:54 PM train 350 4.840752e-01 88.054962 99.682449
11/06 06:56:23 PM train 400 4.809275e-01 88.068893 99.667503
11/06 06:56:53 PM train 450 4.824628e-01 88.047394 99.660477
11/06 06:57:22 PM train 500 4.846637e-01 88.046822 99.642380
11/06 06:57:34 PM train_acc 88.059998
11/06 06:57:35 PM valid 000 2.411368e-01 92.187500 100.000000
11/06 06:57:39 PM valid 050 3.438650e-01 88.878677 99.479172
11/06 06:57:44 PM valid 100 3.338697e-01 88.923264 99.504951
11/06 06:57:48 PM valid 150 3.316445e-01 89.062500 99.565399
11/06 06:57:49 PM valid_acc 89.119995
11/06 06:57:49 PM epoch 49 lr 2.459085e-02
11/06 06:57:50 PM train 000 5.473651e-01 86.458336 98.958336
11/06 06:58:19 PM train 050 4.646658e-01 88.582520 99.591507
11/06 06:58:49 PM train 100 4.641036e-01 88.510727 99.618401
11/06 06:59:18 PM train 150 4.722604e-01 88.431290 99.620583
11/06 06:59:47 PM train 200 4.788038e-01 88.266998 99.637230
11/06 07:00:17 PM train 25

11/06 07:37:41 PM train 250 4.577881e-01 88.753326 99.684601
11/06 07:38:10 PM train 300 4.578388e-01 88.718163 99.685081
11/06 07:38:39 PM train 350 4.595106e-01 88.722694 99.661682
11/06 07:39:09 PM train 400 4.616805e-01 88.687141 99.636330
11/06 07:39:38 PM train 450 4.659998e-01 88.569382 99.632759
11/06 07:40:07 PM train 500 4.690641e-01 88.481369 99.629906
11/06 07:40:19 PM train_acc 88.517998
11/06 07:40:19 PM valid 000 3.615554e-01 87.500000 100.000000
11/06 07:40:24 PM valid 050 3.223931e-01 89.613976 99.571083
11/06 07:40:28 PM valid 100 3.194835e-01 89.325493 99.597771
11/06 07:40:33 PM valid 150 3.122733e-01 89.548843 99.710266
11/06 07:40:33 PM valid_acc 89.549995
11/06 07:40:34 PM epoch 57 lr 2.444741e-02
11/06 07:40:34 PM train 000 4.739388e-01 89.583336 96.875000
11/06 07:41:04 PM train 050 4.180723e-01 89.726311 99.652779
11/06 07:41:33 PM train 100 4.290462e-01 89.583336 99.649345
11/06 07:42:02 PM train 150 4.424633e-01 89.321190 99.641281
11/06 07:42:32 PM train 20

11/06 08:19:58 PM train 200 4.260488e-01 89.629974 99.761612
11/06 08:20:27 PM train 250 4.315705e-01 89.504486 99.759300
11/06 08:20:57 PM train 300 4.381827e-01 89.334167 99.750832
11/06 08:21:26 PM train 350 4.333342e-01 89.512108 99.726967
11/06 08:21:56 PM train 400 4.342879e-01 89.484627 99.722054
11/06 08:22:25 PM train 450 4.350314e-01 89.460922 99.722839
11/06 08:22:55 PM train 500 4.359650e-01 89.427391 99.717232
11/06 08:23:06 PM train_acc 89.383995
11/06 08:23:07 PM valid 000 4.891018e-01 87.500000 98.437500
11/06 08:23:11 PM valid 050 3.767351e-01 88.572304 99.080887
11/06 08:23:16 PM valid 100 3.569663e-01 89.016090 99.319305
11/06 08:23:20 PM valid 150 3.471820e-01 88.979721 99.461922
11/06 08:23:21 PM valid_acc 88.989998
11/06 08:23:21 PM epoch 65 lr 2.428302e-02
11/06 08:23:21 PM train 000 4.390177e-01 90.625000 100.000000
11/06 08:23:51 PM train 050 4.394233e-01 88.868462 99.734482
11/06 08:24:21 PM train 100 4.312381e-01 89.418320 99.721535
11/06 08:24:50 PM train 15

11/06 09:02:16 PM train 150 4.226015e-01 89.528145 99.730957
11/06 09:02:45 PM train 200 4.173083e-01 89.697350 99.730515
11/06 09:03:14 PM train 250 4.155557e-01 89.902893 99.717796
11/06 09:03:44 PM train 300 4.125207e-01 89.981316 99.702385
11/06 09:04:13 PM train 350 4.183334e-01 89.823715 99.712128
11/06 09:04:42 PM train 400 4.174574e-01 89.830116 99.706467
11/06 09:05:12 PM train 450 4.210854e-01 89.745010 99.699738
11/06 09:05:41 PM train 500 4.218197e-01 89.693527 99.700592
11/06 09:05:53 PM train_acc 89.695999
11/06 09:05:53 PM valid 000 4.285156e-01 85.937500 100.000000
11/06 09:05:57 PM valid 050 3.047664e-01 89.950981 99.632355
11/06 09:06:02 PM valid 100 2.927603e-01 90.222771 99.659653
11/06 09:06:06 PM valid 150 2.927761e-01 90.107613 99.710266
11/06 09:06:07 PM valid_acc 90.059998
11/06 09:06:07 PM epoch 73 lr 2.409795e-02
11/06 09:06:08 PM train 000 4.431114e-01 91.666672 100.000000
11/06 09:06:37 PM train 050 4.133730e-01 89.889709 99.714050
11/06 09:07:06 PM train 1