Skip to content

Commit

Permalink
update all
Browse files Browse the repository at this point in the history
  • Loading branch information
dragen1860 committed Jan 28, 2019
1 parent 4495d70 commit a1ced28
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 41 deletions.
51 changes: 23 additions & 28 deletions .idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 7 additions & 10 deletions test.py
Expand Up @@ -11,30 +11,26 @@

from model import NetworkCIFAR as Network

parser = argparse.ArgumentParser("cifar")
parser = argparse.ArgumentParser("cifar10")
parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
parser.add_argument('--batchsz', type=int, default=16, help='batch size')
parser.add_argument('--batchsz', type=int, default=36, help='batch size')
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--init_ch', type=int, default=16, help='num of init channels')
parser.add_argument('--layers', type=int, default=8, help='total number of layers')
parser.add_argument('--exp_path', type=str, default='exp1/model.pt', help='path of pretrained model')
parser.add_argument('--init_ch', type=int, default=36, help='num of init channels')
parser.add_argument('--layers', type=int, default=20, help='total number of layers')
parser.add_argument('--exp_path', type=str, default='exp/model.pt', help='path of pretrained model')
parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--arch', type=str, default='MyDARTS', help='which architecture to use')
parser.add_argument('--arch', type=str, default='DARTS', help='which architecture to use')
args = parser.parse_args()

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
format=log_format, datefmt='%m/%d %I:%M:%S %p')

CIFAR_CLASSES = 10





def main():
Expand Down Expand Up @@ -68,6 +64,7 @@ def main():


def infer(test_queue, model, criterion):

objs = utils.AverageMeter()
top1 = utils.AverageMeter()
top5 = utils.AverageMeter()
Expand Down
8 changes: 5 additions & 3 deletions train.py
Expand Up @@ -17,7 +17,7 @@

parser = argparse.ArgumentParser("cifar10")
parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
parser.add_argument('--batchsz', type=int, default=48, help='batch size')
parser.add_argument('--batchsz', type=int, default=30, help='batch size')
parser.add_argument('--lr', type=float, default=0.025, help='init learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--wd', type=float, default=3e-4, help='weight decay')
Expand Down Expand Up @@ -91,11 +91,13 @@ def main():
logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

valid_acc, valid_obj = infer(valid_queue, model, criterion)
logging.info('valid_acc: %f', valid_acc)

train_acc, train_obj = train(train_queue, model, criterion, optimizer)
logging.info('train_acc: %f', train_acc)

valid_acc, valid_obj = infer(valid_queue, model, criterion)
logging.info('valid_acc: %f', valid_acc)


utils.save(model, os.path.join(args.save, 'trained.pt'))
print('saved to: trained.pt')
Expand Down

0 comments on commit a1ced28

Please sign in to comment.