In [1]:
import json
import os
import time
import random
from PIL import Image
from PIL.Image import BICUBIC
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import shutil

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import Subset
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision import models
from torchsummary import summary
from torchvision.datasets.cifar import CIFAR10
import ignite
from ignite.contrib.handlers import ProgressBar

from model_pytorch import EfficientNet
from utils import Bar,Logger, AverageMeter, accuracy, mkdir_p, savefig
from warmup_scheduler import GradualWarmupScheduler

# Arguments

In [2]:
pretrained = ''
resume = 'd:/log2/checkpoint.pth.tar'

In [3]:
# Model
model_name = 'efficientnet-b0' # b0-b7 scale
data_dir = 'd:/dataset/cifar10/cifar-10-batches-py'

# Optimization
num_classes = 10
epochs = 400
start_epoch = 0
train_batch = 60
test_batch = 50
lr = 0.04
schedule = [150, 225]
momentum = 0.9
gamma = 0.1 # LR is multiplied by gamma on schedule

# CheckPoint
checkpoint = 'd:/log2/' # dir
if not os.path.isdir(checkpoint):
    os.mkdir(checkpoint)
num_workers = 4

# GPU Device
gpu_id = 0
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
use_cuda = torch.cuda.is_available()
print("GPU device %d:" %(gpu_id), use_cuda)

# Seed
manual_seed = 7
random.seed(manual_seed)
torch.cuda.manual_seed_all(manual_seed)

best_acc = 0
image_size = 224

GPU device 0: True


In [4]:
state = {}
state['num_classes'] = num_classes
state['epochs'] = epochs
state['start_epoch'] = start_epoch
state['train_batch'] = train_batch
state['test_batch'] = test_batch
state['lr'] = lr
state['schedule'] = schedule
state['momentum'] = momentum
state['gamma'] = gamma

# Dataset

In [5]:
train_aug = transforms.Compose([
    transforms.Resize(image_size, BICUBIC),
    transforms.RandomAffine(degrees=2, translate=(0.02, 0.02), scale=(0.98, 1.02), shear=2, fillcolor=(124,117,104)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=True),
])
val_aug = transforms.Compose([
    transforms.Resize(image_size, BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = CIFAR10(data_dir, train=True, transform=train_aug, download=False)
test_dataset = CIFAR10(data_dir, train=False, transform=val_aug, download=False)

train_eval_indices = [random.randint(0, len(train_dataset) - 1) for i in range(len(test_dataset))]
train_eval_dataset = Subset(train_dataset, train_eval_indices)


# pin_memory : cuda pin memeory use
train_loader = DataLoader(train_dataset, batch_size=train_batch, drop_last=True, 
                          shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(test_dataset, batch_size=test_batch, drop_last=False,
                        shuffle=False, num_workers=num_workers, pin_memory=True)
eval_train_loader = DataLoader(train_eval_dataset, batch_size=train_batch, drop_last=False,
                              shuffle=False, num_workers=num_workers, pin_memory=True)

# Model

In [6]:
model = EfficientNet.from_name(model_name, num_classes=num_classes)

# Pre-trained
if pretrained:
    print("=> using pre-trained model '{}'".format(pretrained))
    model.load_state_dict(torch.load(pretrained))

In [7]:
model.to('cuda')
cudnn.benchmark = True
print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

    Total params: 4.02M


In [8]:
# summary(model, input_size=(3,64,64), device='cuda')

# Loss

In [9]:
criterion = nn.CrossEntropyLoss().cuda()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, nesterov=True)

In [10]:
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=10, after_scheduler=scheduler_cosine)

In [11]:
# Resume
if resume:
    print('==> Resuming from checkpoint..')
    checkpoint = os.path.dirname(resume)
    checkpoint = torch.load(resume)
    best_acc = checkpoint['best_acc']
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    checkpoint = "d:/log2/"
    logger = Logger(os.path.join(checkpoint, 'log.txt'), resume=True)
else:
    logger = Logger(os.path.join(checkpoint, 'log.txt'))
    logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])

==> Resuming from checkpoint..


In [12]:
checkpoint

'd:/log2/'

# Train

In [13]:

def train(train_loader, model, criterion, optimizer, epoch, use_cuda):
    model.train()
    torch.set_grad_enabled(True)
    
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()
    
    bar = Bar('Processing', max=len(train_loader))
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        batch_size = inputs.size(0)
        if batch_size < train_batch:
            continue
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1 = accuracy(outputs.data, targets.data)
        losses.update(loss.data.tolist(), inputs.size(0))
        top1.update(prec1[0], inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} '.format(
                    batch=batch_idx + 1,
                    size=len(train_loader),
                    data=data_time.val,
                    bt=batch_time.val,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss=losses.avg,
                    top1=top1.avg,
                    )
        bar.next()
        if batch_idx % 10 == 0:
            print('{batch}/{size} Data:{data:.3f} | Batch:{bt:.3f} | Total:{total:} | ETA:{eta:} | Loss:{loss:} | top1:{tp1:}'.format(
                 batch=batch_idx+1, size=len(train_loader), data=data_time.val, bt=batch_time.val, total=bar.elapsed_td, eta=bar.eta_td, loss=losses.avg, tp1=top1.avg))
    bar.finish()
    return (losses.avg, top1.avg)

In [14]:
def test(val_loader, model, criterion, epoch, use_cuda):
    global best_acc

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()
    torch.set_grad_enabled(False)

    end = time.time()
    bar = Bar('Processing', max=len(val_loader))
    for batch_idx, (inputs, targets) in enumerate(val_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = torch.autograd.Variable(inputs, volatile=True), torch.autograd.Variable(targets)

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1 = accuracy(outputs.data, targets.data)
        losses.update(loss.data.tolist(), inputs.size(0))
        top1.update(prec1[0], inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:} | top1: {top1:}'.format(
                    batch=batch_idx + 1,
                    size=len(val_loader),
                    data=data_time.avg,
                    bt=batch_time.avg,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss=losses.avg,
                    top1=top1.avg,)
        bar.next()
    print('{batch}/{size} Data:{data:.3f} | Batch:{bt:.3f} | Total:{total:} | ETA:{eta:} | Loss:{loss:} | top1:{tp1:}'.format(
         batch=batch_idx+1, size=len(val_loader), data=data_time.val, bt=batch_time.val, total=bar.elapsed_td, eta=bar.eta_td, loss=losses.avg, tp1=top1.avg))
    bar.finish()
    return (losses.avg, top1.avg)

In [15]:
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))

def adjust_learning_rate(optimizer, epoch):
    global state
    if epoch in schedule:
        state['lr'] *= gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = state['lr']

In [None]:
for epoch in range(start_epoch, epochs):
#     adjust_learning_rate(optimizer, epoch)
    state['lr'] = optimizer.state_dict()['param_groups'][0]['lr']
    print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, epochs, state['lr']))
    
    train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, use_cuda)
    test_loss, test_acc = test(val_loader, model, criterion, epoch, use_cuda)
    
    logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc])
    scheduler_warmup.step()
    is_best = test_acc > best_acc
    best_acc = max(test_acc, best_acc)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict' : model.state_dict(),
        'acc': test_acc,
        'best_acc': best_acc,
        'optimizer': optimizer.state_dict(),
    }, is_best, checkpoint=checkpoint)


Epoch: [54 | 400] LR: 0.311374
1/833 Data:5.588 | Batch:8.555 | Total:0:00:08 | ETA:1:58:32 | Loss:0.2796719968318939 | top1:91.66667175292969
11/833 Data:0.040 | Batch:0.321 | Total:0:00:11 | ETA:0:15:55 | Loss:0.24074966731396588 | top1:92.12120819091797
21/833 Data:0.041 | Batch:0.328 | Total:0:00:15 | ETA:0:04:21 | Loss:0.21875845889250436 | top1:92.53968048095703
31/833 Data:0.040 | Batch:0.321 | Total:0:00:18 | ETA:0:04:20 | Loss:0.2078775999046141 | top1:92.79570007324219
41/833 Data:0.041 | Batch:0.324 | Total:0:00:21 | ETA:0:04:25 | Loss:0.2023892607994196 | top1:92.60162353515625
51/833 Data:0.041 | Batch:0.320 | Total:0:00:25 | ETA:0:04:24 | Loss:0.20497550902997747 | top1:92.74510192871094
61/833 Data:0.040 | Batch:0.323 | Total:0:00:28 | ETA:0:04:13 | Loss:0.2121195061529269 | top1:92.70491790771484
71/833 Data:0.041 | Batch:0.322 | Total:0:00:31 | ETA:0:04:15 | Loss:0.2126453695162921 | top1:92.76995849609375
81/833 Data:0.041 | Batch:0.328 | Total:0:00:34 | ETA:0:04:04 

711/833 Data:0.040 | Batch:0.323 | Total:0:04:04 | ETA:0:00:40 | Loss:0.22507685774992287 | top1:92.25269317626953
721/833 Data:0.041 | Batch:0.328 | Total:0:04:07 | ETA:0:00:37 | Loss:0.22582892281999203 | top1:92.21451568603516
731/833 Data:0.040 | Batch:0.321 | Total:0:04:10 | ETA:0:00:33 | Loss:0.22649623543309447 | top1:92.17738342285156
741/833 Data:0.039 | Batch:0.323 | Total:0:04:13 | ETA:0:00:30 | Loss:0.2265272733331853 | top1:92.1659927368164
751/833 Data:0.041 | Batch:0.322 | Total:0:04:17 | ETA:0:00:27 | Loss:0.22675292580803924 | top1:92.15711975097656
761/833 Data:0.040 | Batch:0.320 | Total:0:04:20 | ETA:0:00:24 | Loss:0.2271375424582291 | top1:92.15287017822266
771/833 Data:0.041 | Batch:0.323 | Total:0:04:23 | ETA:0:00:20 | Loss:0.22734364389832074 | top1:92.14224243164062
781/833 Data:0.040 | Batch:0.320 | Total:0:04:26 | ETA:0:00:17 | Loss:0.22746159857980877 | top1:92.14468383789062
791/833 Data:0.040 | Batch:0.318 | Total:0:04:30 | ETA:0:00:14 | Loss:0.22751268768



200/200 Data:0.000 | Batch:0.077 | Total:0:00:18 | ETA:0:00:00 | Loss:0.36399985127151013 | top1:89.30999755859375

Epoch: [55 | 400] LR: 0.068000
1/833 Data:5.446 | Batch:5.911 | Total:0:00:05 | ETA:1:22:33 | Loss:0.09281327575445175 | top1:96.66667175292969
11/833 Data:0.041 | Batch:0.322 | Total:0:00:09 | ETA:0:12:07 | Loss:0.2302336726676334 | top1:91.81817626953125
21/833 Data:0.041 | Batch:0.321 | Total:0:00:12 | ETA:0:04:23 | Loss:0.19779833583604722 | top1:92.93650817871094
31/833 Data:0.039 | Batch:0.321 | Total:0:00:15 | ETA:0:04:19 | Loss:0.1895282869377444 | top1:93.27957153320312
41/833 Data:0.041 | Batch:0.321 | Total:0:00:18 | ETA:0:04:15 | Loss:0.1867853797427038 | top1:93.41463470458984
51/833 Data:0.040 | Batch:0.319 | Total:0:00:22 | ETA:0:04:11 | Loss:0.19041235014504077 | top1:93.39869689941406
61/833 Data:0.040 | Batch:0.328 | Total:0:00:25 | ETA:0:04:09 | Loss:0.19428534225606528 | top1:93.38798522949219
71/833 Data:0.040 | Batch:0.320 | Total:0:00:28 | ETA:0:04:

701/833 Data:0.040 | Batch:0.325 | Total:0:03:53 | ETA:0:00:43 | Loss:0.16627850210806952 | top1:94.22254180908203
711/833 Data:0.040 | Batch:0.325 | Total:0:03:57 | ETA:0:00:40 | Loss:0.16603816248235487 | top1:94.24519348144531
721/833 Data:0.041 | Batch:0.327 | Total:0:04:00 | ETA:0:00:37 | Loss:0.16550325598159543 | top1:94.26953125
731/833 Data:0.040 | Batch:0.327 | Total:0:04:03 | ETA:0:00:33 | Loss:0.16572212905874004 | top1:94.2726821899414
741/833 Data:0.041 | Batch:0.321 | Total:0:04:06 | ETA:0:00:30 | Loss:0.16520756135685885 | top1:94.29374694824219
751/833 Data:0.041 | Batch:0.326 | Total:0:04:10 | ETA:0:00:27 | Loss:0.16474140967712422 | top1:94.30536651611328
761/833 Data:0.040 | Batch:0.321 | Total:0:04:13 | ETA:0:00:24 | Loss:0.16513094626631436 | top1:94.27945709228516
771/833 Data:0.041 | Batch:0.327 | Total:0:04:16 | ETA:0:00:21 | Loss:0.16512055767050204 | top1:94.28231811523438
781/833 Data:0.040 | Batch:0.320 | Total:0:04:19 | ETA:0:00:17 | Loss:0.165763806494456

561/833 Data:0.066 | Batch:0.367 | Total:0:03:11 | ETA:0:01:38 | Loss:0.15369501823157655 | top1:94.6197280883789
571/833 Data:0.050 | Batch:0.359 | Total:0:03:15 | ETA:0:01:34 | Loss:0.15319571066354182 | top1:94.64682006835938
581/833 Data:0.043 | Batch:0.365 | Total:0:03:18 | ETA:0:01:31 | Loss:0.15242807854050305 | top1:94.67584991455078
591/833 Data:0.046 | Batch:0.360 | Total:0:03:22 | ETA:0:01:28 | Loss:0.15247230411723078 | top1:94.68415069580078
601/833 Data:0.049 | Batch:0.362 | Total:0:03:26 | ETA:0:01:24 | Loss:0.15225874605013706 | top1:94.69218444824219
611/833 Data:0.046 | Batch:0.355 | Total:0:03:29 | ETA:0:01:20 | Loss:0.15236811063354516 | top1:94.68085479736328
621/833 Data:0.045 | Batch:0.354 | Total:0:03:33 | ETA:0:01:16 | Loss:0.15242041743715892 | top1:94.67256927490234
631/833 Data:0.043 | Batch:0.353 | Total:0:03:36 | ETA:0:01:12 | Loss:0.15274076210113505 | top1:94.65927124023438
641/833 Data:0.046 | Batch:0.354 | Total:0:03:40 | ETA:0:01:09 | Loss:0.151928094

421/833 Data:0.047 | Batch:0.355 | Total:0:02:29 | ETA:0:02:29 | Loss:0.15029688679000403 | top1:94.79414367675781
431/833 Data:0.058 | Batch:0.370 | Total:0:02:32 | ETA:0:02:24 | Loss:0.15053988409308575 | top1:94.77958679199219
441/833 Data:0.052 | Batch:0.359 | Total:0:02:36 | ETA:0:02:21 | Loss:0.15061447736036346 | top1:94.77324676513672
451/833 Data:0.044 | Batch:0.356 | Total:0:02:40 | ETA:0:02:17 | Loss:0.15111589959167984 | top1:94.73392486572266
461/833 Data:0.046 | Batch:0.354 | Total:0:02:43 | ETA:0:02:14 | Loss:0.1514536384169877 | top1:94.72523498535156
471/833 Data:0.046 | Batch:0.357 | Total:0:02:47 | ETA:0:02:10 | Loss:0.15135946539635997 | top1:94.72045135498047
481/833 Data:0.045 | Batch:0.360 | Total:0:02:50 | ETA:0:02:07 | Loss:0.15090932717581673 | top1:94.74012756347656
491/833 Data:0.046 | Batch:0.360 | Total:0:02:54 | ETA:0:02:05 | Loss:0.1513722642474711 | top1:94.7216567993164
501/833 Data:0.056 | Batch:0.379 | Total:0:02:58 | ETA:0:02:01 | Loss:0.15084769525

281/833 Data:0.040 | Batch:0.331 | Total:0:01:39 | ETA:0:03:02 | Loss:0.15194078554727428 | top1:94.86951446533203
291/833 Data:0.041 | Batch:0.330 | Total:0:01:42 | ETA:0:02:58 | Loss:0.15216472047249885 | top1:94.83963775634766
301/833 Data:0.041 | Batch:0.327 | Total:0:01:45 | ETA:0:02:55 | Loss:0.151959513075823 | top1:94.82835388183594
311/833 Data:0.040 | Batch:0.323 | Total:0:01:49 | ETA:0:02:51 | Loss:0.15156113431815932 | top1:94.84994506835938
321/833 Data:0.040 | Batch:0.326 | Total:0:01:52 | ETA:0:02:48 | Loss:0.15158966806865184 | top1:94.82347106933594
331/833 Data:0.042 | Batch:0.325 | Total:0:01:55 | ETA:0:02:46 | Loss:0.15021437528090176 | top1:94.88922882080078
341/833 Data:0.040 | Batch:0.326 | Total:0:01:59 | ETA:0:02:41 | Loss:0.14993738052039202 | top1:94.92668914794922
351/833 Data:0.040 | Batch:0.323 | Total:0:02:02 | ETA:0:02:38 | Loss:0.14936703425484504 | top1:94.93827056884766
361/833 Data:0.042 | Batch:0.326 | Total:0:02:05 | ETA:0:02:35 | Loss:0.1498806850

In [None]:
logger.close()
print('Best acc:', best_acc)