In [1]:
import json
import os
import time
import random
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import shutil

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision import models
from torchsummary import summary
from sklearn.model_selection import train_test_split

from model_pytorch import EfficientNet
from utils import Bar,Logger, AverageMeter, accuracy, mkdir_p, savefig
from warmup_scheduler import GradualWarmupScheduler

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import collections

In [2]:
# GPU Device
# Index of the GPU this notebook is allowed to use.
gpu_id = 0
# Expose only that GPU to this process; this is set before the first CUDA query below.
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
# True when PyTorch can see a usable CUDA device; passed to train()/test() to decide whether batches go to the GPU.
use_cuda = torch.cuda.is_available()
print("GPU device %d:" %(gpu_id), use_cuda)

GPU device 0: True


# Arguments

In [3]:
# Dataset root; the 'train' and 'validation' sub-directories are joined onto it in the dataset cell.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
data_dir = '/media/data2/dataset/GAN_ImageData/StarGAN_128/'

In [4]:
# Path to a checkpoint whose 'state_dict' entry is loaded into the model; empty string skips loading.
pretrained = ''
# Path to a checkpoint to resume training from; empty string starts a fresh run with a new log.
resume = ''

In [5]:
# Model
model_name = 'efficientnet-b1' # b0-b7 scale

# Optimization
num_classes = 2
epochs = 300
start_epoch = 0
train_batch = 200
test_batch = 200
lr = 0.04
schedule = [75, 150, 225]
momentum = 0.9
gamma = 0.1 # LR is multiplied by gamma on schedule

# Checkpoint
checkpoint = './log/star/128/b1' # dir
# exist_ok makes the cell safe to re-run once the directory is already there.
os.makedirs(checkpoint, exist_ok=True)
num_workers = 4

# Seed
# Seed every RNG in play: Python's `random`, NumPy, the CPU torch generator
# (which drives DataLoader shuffling and weight initialisation) and all CUDA generators.
manual_seed = 7
random.seed(manual_seed)
np.random.seed(manual_seed)
torch.manual_seed(manual_seed)
torch.cuda.manual_seed_all(manual_seed)

# Image
size = (128, 128)

# Best validation accuracy seen so far; updated by the training loop (and by the resume cell).
best_acc = 0

In [6]:
# Snapshot of the optimisation hyper-parameters. state['lr'] is rewritten every
# epoch by the training loop and adjust_learning_rate().
state = {
    'num_classes': num_classes,
    'epochs': epochs,
    'start_epoch': start_epoch,
    'train_batch': train_batch,
    'test_batch': test_batch,
    'lr': lr,
    'schedule': schedule,
    'momentum': momentum,
    'gamma': gamma,
}

# Dataset

In [7]:
train_dir = os.path.join(data_dir, 'train')
val_dir = os.path.join(data_dir, 'validation')

# Training-time augmentation: small random affine jitter, resize to `size`,
# random horizontal flip, then random erasing on the resulting tensor.
train_aug = transforms.Compose([
    # NOTE(review): newer torchvision releases rename `fillcolor` to `fill` — confirm against the installed version.
    transforms.RandomAffine(degrees=2, translate=(0.02, 0.02), scale=(0.98, 1.02), shear=2, fillcolor=(124,117,104)),
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.10), ratio=(0.3, 3.3), value=0, inplace=True),
])
# Validation uses a deterministic pipeline: resize and tensor conversion only.
val_aug = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
])

# pin_memory=True asks the loader to place batches in pinned (page-locked) host memory.
train_loader = DataLoader(datasets.ImageFolder(train_dir, transform=train_aug),
                          batch_size=train_batch, shuffle=True, num_workers=num_workers, pin_memory=True)
# Validation is not shuffled: the averaged metrics do not depend on batch order,
# and a fixed order keeps evaluation deterministic without consuming RNG state.
val_loader = DataLoader(datasets.ImageFolder(val_dir, val_aug),
                        batch_size=test_batch, shuffle=False, num_workers=num_workers, pin_memory=True)

# Model

In [8]:
# Build the EfficientNet variant selected by `model_name` with `num_classes` outputs.
# NOTE(review): `from_name` presumably returns randomly initialised weights — confirm in model_pytorch.
model = EfficientNet.from_name(model_name, num_classes=num_classes)

# Pre-trained
# Only runs when `pretrained` is a non-empty path; the file must hold a dict with a 'state_dict' entry.
if pretrained:
    print("=> using pre-trained model '{}'".format(pretrained))
    model.load_state_dict(torch.load(pretrained)['state_dict'])

In [9]:
# Move the model to the GPU only when one is available, mirroring the
# `use_cuda` checks that train() and test() apply to each batch.
if use_cuda:
    model.to('cuda')
# Let cuDNN benchmark and pick convolution algorithms for the fixed input size.
cudnn.benchmark = True
print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

    Total params: 6.52M


In [10]:
# summary(model, input_size=(3,64,64), device='cuda')

# Loss, Optimizer and LR Scheduler

In [11]:
# Classification loss used by both train() and test().
criterion = nn.CrossEntropyLoss().cuda()
# SGD with Nesterov momentum; `lr` and `momentum` come from the configuration cell.
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=1e-4, nesterov=True)
# optimizer = optim.Adam(model.parameters(), weight_decay=1e-4)
# Cosine annealing over `epochs` steps, wrapped by a warm-up scheduler that hands over to it.
# NOTE(review): the warm-up presumably ramps the LR from `lr` towards multiplier * lr over
# `total_epoch` epochs (the logged epoch-2 LR of 0.068 = 0.04 * 1.7 is consistent) — confirm in warmup_scheduler.
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=10, after_scheduler=scheduler_cosine)

In [12]:
# Resume
# With a non-empty `resume` path, restore model/optimizer state and keep appending
# to the existing log; otherwise start a fresh log with the column names below.
if resume:
    print('==> Resuming from checkpoint..')
    # Keep writing checkpoints and logs next to the file we resume from.
    checkpoint = os.path.dirname(resume)
    # Load into a separate name so `resume` stays a path and this cell can be re-run.
    # On CPU-only hosts, remap tensors that were saved from the GPU.
    resume_state = torch.load(resume, map_location=None if use_cuda else 'cpu')
    best_acc = resume_state['best_acc']
    start_epoch = resume_state['epoch']
    model.load_state_dict(resume_state['state_dict'])
    optimizer.load_state_dict(resume_state['optimizer'])
    # NOTE(review): the warm-up/cosine scheduler state is not stored in the checkpoint,
    # so it restarts from its initial step after a resume — confirm this is intended.
    logger = Logger(os.path.join(checkpoint, 'log.txt'), resume=True)
else:
    logger = Logger(os.path.join(checkpoint, 'log.txt'))
    logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])

# Train

In [13]:
def train(train_loader, model, criterion, optimizer, epoch, use_cuda):
    """Run one training epoch and return the averaged loss and top-1 accuracy.

    Args:
        train_loader: Iterable of ``(inputs, targets)`` batches that supports ``len()``.
        model: Network to optimise; it is switched to training mode here.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: Optimizer stepped once per processed batch.
        epoch: Current epoch index (kept for call-site symmetry; not used in the body).
        use_cuda: When true, every batch is moved to the GPU before the forward pass.

    Returns:
        Tuple ``(losses.avg, top1.avg)`` accumulated over the processed batches.

    Note:
        Batches smaller than the notebook-level ``train_batch`` (typically the last
        one) are skipped entirely; the progress bar is not advanced for them.
    """
    model.train()
    # Make sure autograd is on, in case an earlier evaluation pass left it disabled.
    torch.set_grad_enabled(True)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()

    bar = Bar('Processing', max=len(train_loader))
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        batch_size = inputs.size(0)
        if batch_size < train_batch:
            continue
        # measure data loading time
        data_time.update(time.time() - end)

        # Tensors track gradients themselves, so no autograd.Variable wrapper is needed.
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss (detached so the metrics stay out of the graph)
        prec1 = accuracy(outputs.detach(), targets)
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1[0], inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} '.format(
                    batch=batch_idx + 1,
                    size=len(train_loader),
                    data=data_time.val,
                    bt=batch_time.val,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss=losses.avg,
                    top1=top1.avg,
                    )
        bar.next()
        # Also print a plain-text line every 10 batches so progress shows up in the cell output.
        if batch_idx % 10 == 0:
            print('{batch}/{size} Data:{data:.3f} | Batch:{bt:.3f} | Total:{total:} | ETA:{eta:} | Loss:{loss:} | top1:{tp1:}'.format(
                 batch=batch_idx+1, size=len(train_loader), data=data_time.val, bt=batch_time.val, total=bar.elapsed_td, eta=bar.eta_td, loss=losses.avg, tp1=top1.avg))
    bar.finish()
    return (losses.avg, top1.avg)

In [14]:
def test(val_loader, model, criterion, epoch, use_cuda):
    global best_acc

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()
    torch.set_grad_enabled(False)

    end = time.time()
    bar = Bar('Processing', max=len(val_loader))
    for batch_idx, (inputs, targets) in enumerate(val_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = torch.autograd.Variable(inputs, volatile=True), torch.autograd.Variable(targets)

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1 = accuracy(outputs.data, targets.data)
        losses.update(loss.data.tolist(), inputs.size(0))
        top1.update(prec1[0], inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:} | top1: {top1:}'.format(
                    batch=batch_idx + 1,
                    size=len(val_loader),
                    data=data_time.avg,
                    bt=batch_time.avg,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss=losses.avg,
                    top1=top1.avg,)
        bar.next()
    print('{batch}/{size} Data:{data:.3f} | Batch:{bt:.3f} | Total:{total:} | ETA:{eta:} | Loss:{loss:} | top1:{tp1:}'.format(
         batch=batch_idx+1, size=len(val_loader), data=data_time.val, bt=batch_time.val, total=bar.elapsed_td, eta=bar.eta_td, loss=losses.avg, tp1=top1.avg))
    bar.finish()
    return (losses.avg, top1.avg)

In [15]:
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """Serialize ``state`` into the checkpoint directory.

    Args:
        state: Object handed to ``torch.save``.
        is_best: When true, the saved file is also copied to ``model_best.pth.tar``
            in the same directory.
        checkpoint: Directory that receives the file(s).
        filename: File name of the regular checkpoint inside ``checkpoint``.
    """
    latest_path = os.path.join(checkpoint, filename)
    torch.save(state, latest_path)
    if not is_best:
        return
    best_path = os.path.join(checkpoint, 'model_best.pth.tar')
    shutil.copyfile(latest_path, best_path)

def adjust_learning_rate(optimizer, epoch):
    """Scale ``state['lr']`` by the milestone decay and push it to the optimizer.

    The decay factor is ``gamma ** k``, where ``k`` is the number of entries in
    the notebook-level ``schedule`` that are strictly smaller than ``epoch``
    (so the milestone epoch itself is not yet decayed). Using ``gamma`` instead
    of a fixed table honours the configured value and works for any number of
    milestones.

    Args:
        optimizer: Object exposing ``param_groups``; every group's ``'lr'`` is overwritten.
        epoch: Current epoch index.

    Note:
        ``state['lr']`` is multiplied in place, so the caller is expected to reset
        it to the current base learning rate before each call.
    """
    # Number of milestones already passed; `schedule` is only read, never modified.
    passed = sum(1 for milestone in schedule if milestone < epoch)
    state['lr'] *= gamma ** passed
    for param_group in optimizer.param_groups:
        param_group['lr'] = state['lr']

In [None]:
# Main loop: for every epoch run one training pass and one validation pass,
# log the metrics, step the LR scheduler and write a checkpoint.
for epoch in range(start_epoch, epochs):
    # Read back the learning rate currently stored in the optimizer's first param group.
    state['lr'] = optimizer.state_dict()['param_groups'][0]['lr']
    # Apply the milestone-based scaling on top of it and write the result to every param group.
    adjust_learning_rate(optimizer, epoch)
    print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, epochs, state['lr']))
    train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, use_cuda)
    test_loss, test_acc = test(val_loader, model, criterion, epoch, use_cuda)
    
    # Column order matches the names registered on the logger when it was created.
    logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc])
    # Advance the warm-up/cosine schedule once per epoch, after this epoch's LR has been logged.
    scheduler_warmup.step()

    # Track the best validation accuracy; save_checkpoint copies the file to model_best.pth.tar when it improves.
    is_best = test_acc > best_acc
    best_acc = max(test_acc, best_acc)
    save_checkpoint({
        # Stored as epoch + 1; the resume cell assigns this value to start_epoch.
        'epoch': epoch + 1,
        'state_dict' : model.state_dict(),
        'acc': test_acc,
        'best_acc': best_acc,
        'optimizer': optimizer.state_dict(),
    }, is_best, checkpoint=checkpoint)


Epoch: [1 | 300] LR: 0.040000
1/1374 Data:4.778 | Batch:8.775 | Total:0:00:08 | ETA:3:20:49 | Loss:0.6947049498558044 | top1:49.5
11/1374 Data:0.010 | Batch:0.719 | Total:0:00:15 | ETA:0:32:34 | Loss:0.7476149689067494 | top1:50.09090805053711
21/1374 Data:0.006 | Batch:0.640 | Total:0:00:22 | ETA:0:17:37 | Loss:0.7443042596181234 | top1:50.83333206176758
31/1374 Data:0.001 | Batch:0.730 | Total:0:00:33 | ETA:0:24:25 | Loss:0.7497971057891846 | top1:50.51613235473633
41/1374 Data:0.003 | Batch:0.704 | Total:0:00:44 | ETA:0:24:05 | Loss:0.7440140378184434 | top1:50.46341323852539
51/1374 Data:0.003 | Batch:0.778 | Total:0:00:54 | ETA:0:21:37 | Loss:0.7421530157912011 | top1:50.33333206176758
61/1374 Data:0.001 | Batch:0.550 | Total:0:01:03 | ETA:0:19:56 | Loss:0.7452344698984115 | top1:50.16393280029297
71/1374 Data:0.007 | Batch:0.664 | Total:0:01:12 | ETA:0:19:57 | Loss:0.7419868522966412 | top1:50.04225540161133
81/1374 Data:0.006 | Batch:0.609 | Total:0:01:19 | ETA:0:14:12 | Loss:0

721/1374 Data:0.001 | Batch:0.490 | Total:0:07:14 | ETA:0:05:24 | Loss:0.7066438475858818 | top1:50.16435623168945
731/1374 Data:0.003 | Batch:0.485 | Total:0:07:19 | ETA:0:04:42 | Loss:0.7065048694773672 | top1:50.167579650878906
741/1374 Data:0.040 | Batch:0.484 | Total:0:07:24 | ETA:0:05:35 | Loss:0.7063694976762882 | top1:50.17409133911133
751/1374 Data:0.001 | Batch:0.425 | Total:0:07:29 | ETA:0:05:04 | Loss:0.7062083797829446 | top1:50.20372772216797
761/1374 Data:0.001 | Batch:0.516 | Total:0:07:33 | ETA:0:04:47 | Loss:0.7060715964996549 | top1:50.21156311035156
771/1374 Data:0.001 | Batch:0.470 | Total:0:07:38 | ETA:0:04:39 | Loss:0.7059846882226403 | top1:50.209468841552734
781/1374 Data:0.001 | Batch:0.510 | Total:0:07:43 | ETA:0:04:47 | Loss:0.7058205488549305 | top1:50.20870590209961
791/1374 Data:0.001 | Batch:0.420 | Total:0:07:48 | ETA:0:04:39 | Loss:0.7056997121510704 | top1:50.208595275878906
801/1374 Data:0.001 | Batch:0.482 | Total:0:07:52 | ETA:0:04:16 | Loss:0.7056



153/153 Data:0.000 | Batch:0.786 | Total:0:01:02 | ETA:0:00:00 | Loss:0.6940237873503609 | top1:50.429229736328125

Epoch: [2 | 300] LR: 0.068000
1/1374 Data:1.162 | Batch:1.852 | Total:0:00:01 | ETA:0:42:24 | Loss:0.6997387409210205 | top1:47.5
11/1374 Data:0.000 | Batch:0.394 | Total:0:00:06 | ETA:0:13:49 | Loss:0.6987679275599393 | top1:50.681819915771484
21/1374 Data:0.001 | Batch:0.482 | Total:0:00:10 | ETA:0:09:54 | Loss:0.6989927973066058 | top1:51.35714340209961
31/1374 Data:0.001 | Batch:0.371 | Total:0:00:15 | ETA:0:10:04 | Loss:0.6977941778398329 | top1:51.000003814697266
41/1374 Data:0.001 | Batch:0.459 | Total:0:00:19 | ETA:0:09:44 | Loss:0.6977050537016334 | top1:50.70731735229492
51/1374 Data:0.001 | Batch:0.476 | Total:0:00:24 | ETA:0:10:10 | Loss:0.6974992857259863 | top1:50.97058868408203
61/1374 Data:0.001 | Batch:0.565 | Total:0:00:28 | ETA:0:09:32 | Loss:0.6974737888476887 | top1:51.02458953857422
71/1374 Data:0.001 | Batch:0.462 | Total:0:00:33 | ETA:0:10:44 | Los

711/1374 Data:0.001 | Batch:0.607 | Total:0:05:52 | ETA:0:06:33 | Loss:0.6956410822989065 | top1:51.14134979248047
721/1374 Data:0.002 | Batch:0.636 | Total:0:05:58 | ETA:0:07:12 | Loss:0.695562124004311 | top1:51.17753219604492
731/1374 Data:0.003 | Batch:0.637 | Total:0:06:05 | ETA:0:07:12 | Loss:0.6955384610452665 | top1:51.193572998046875
741/1374 Data:0.001 | Batch:0.640 | Total:0:06:12 | ETA:0:06:58 | Loss:0.6955070173048619 | top1:51.194332122802734
751/1374 Data:0.001 | Batch:0.650 | Total:0:06:18 | ETA:0:06:47 | Loss:0.6954355785913379 | top1:51.228363037109375
761/1374 Data:0.003 | Batch:0.700 | Total:0:06:25 | ETA:0:06:54 | Loss:0.6953845577077076 | top1:51.25295639038086
771/1374 Data:0.002 | Batch:0.671 | Total:0:06:31 | ETA:0:06:34 | Loss:0.6953488003109835 | top1:51.27756118774414
781/1374 Data:0.002 | Batch:0.516 | Total:0:06:38 | ETA:0:06:41 | Loss:0.6952881658886215 | top1:51.30665969848633
791/1374 Data:0.001 | Batch:0.562 | Total:0:06:43 | ETA:0:04:40 | Loss:0.69529