In [1]:
import json
import os
import time
import random
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import shutil

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision import models
from torchsummary import summary
from sklearn.model_selection import train_test_split

from model_pytorch import EfficientNet
from utils import Bar,Logger, AverageMeter, accuracy, mkdir_p, savefig
from warmup_scheduler import GradualWarmupScheduler

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import collections

In [2]:
# GPU Device: expose a single GPU to this process (set before CUDA is first queried)
# and record whether CUDA is usable; use_cuda gates the .cuda() calls further down.
gpu_id = 0
os.environ['CUDA_VISIBLE_DEVICES'] = '%d' % gpu_id
use_cuda = torch.cuda.is_available()
print("GPU device %d:" % gpu_id, use_cuda)

GPU device 0: True


# Arguments

In [3]:
# Root of the ImageFolder dataset (expects 'train/' and 'validation/' sub-directories).
# Override with the PGGAN_DATA_DIR environment variable instead of editing this path.
data_dir = os.environ.get('PGGAN_DATA_DIR', '/media/data2/dataset/GAN_ImageData/PGGAN_128')

In [4]:
# Optional checkpoint paths; an empty string leaves the feature off.
#   pretrained: weights used to initialise the model
#   resume:     full training checkpoint to continue from
pretrained = resume = ''

In [5]:
# Model
model_name = 'efficientnet-b1' # b0-b7 scale

# Optimization
num_classes = 2
epochs = 300
start_epoch = 0
train_batch = 256
test_batch = 200
lr = 0.1
schedule = [20, 125, 225, 275]
momentum = 0.9
gamma = 0.1 # LR is multiplied by gamma on schedule

# CheckPoint
checkpoint = './log/pggan/128x128/b1' # dir
# makedirs also creates the missing parent directories (./log/pggan/...);
# os.mkdir would raise FileNotFoundError on a fresh checkout. exist_ok makes re-runs safe.
os.makedirs(checkpoint, exist_ok=True)
num_workers = 4

# Seed: seed every RNG the notebook touches. The model is built before it is moved to
# the GPU, and torchvision augmentations use the torch CPU RNG, so
# torch.cuda.manual_seed_all alone does not make a run reproducible.
manual_seed = 7
random.seed(manual_seed)
np.random.seed(manual_seed)
torch.manual_seed(manual_seed)
torch.cuda.manual_seed_all(manual_seed)

# Image
size = (128, 128)

best_acc = 0

In [6]:
# Snapshot of the optimisation hyper-parameters; adjust_learning_rate reads and
# updates the running learning rate through state['lr'].
state = dict(
    num_classes=num_classes,
    epochs=epochs,
    start_epoch=start_epoch,
    train_batch=train_batch,
    test_batch=test_batch,
    lr=lr,
    schedule=schedule,
    momentum=momentum,
    gamma=gamma,
)

# Dataset

In [7]:
train_dir = os.path.join(data_dir, 'train')
val_dir = os.path.join(data_dir, 'validation')
train_aug = transforms.Compose([
    # NOTE: newer torchvision releases replace RandomAffine's ``fillcolor`` with ``fill``;
    # rename the argument when upgrading torchvision.
    transforms.RandomAffine(degrees=2, translate=(0.02, 0.02), scale=(0.98, 1.02), shear=2, fillcolor=(124,117,104)),
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.10), ratio=(0.3, 3.3), value=0, inplace=True),
])
val_aug = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
])

# pin_memory: page-locked host buffers so batches can be copied to the GPU asynchronously.
train_loader = DataLoader(datasets.ImageFolder(train_dir, transform=train_aug),
                          batch_size=train_batch, shuffle=True, num_workers=num_workers, pin_memory=True)
# Validation metrics do not depend on sample order, so iterate in a fixed order
# (reproducible runs, no shuffling work).
val_loader = DataLoader(datasets.ImageFolder(val_dir, transform=val_aug),
                        batch_size=test_batch, shuffle=False, num_workers=num_workers, pin_memory=True)

# Model

In [8]:
model = EfficientNet.from_name(model_name, num_classes=num_classes)

# Pre-trained: optionally initialise the weights from a checkpoint dict holding a
# 'state_dict' entry (the format written by save_checkpoint in the training loop).
if pretrained:
    print("=> using pre-trained model '{}'".format(pretrained))
    # The model is still on the CPU here, and the file may hold CUDA tensors, so load
    # onto the CPU explicitly. torch.load unpickles the file: only use trusted checkpoints.
    pretrained_ckpt = torch.load(pretrained, map_location='cpu')
    model.load_state_dict(pretrained_ckpt['state_dict'])

In [9]:
# Respect the use_cuda flag computed at the top (the train/test helpers already do),
# so the notebook can still run on a CPU-only machine.
if use_cuda:
    model.to('cuda')
# Let cuDNN autotune convolution algorithms; inputs have a fixed size.
cudnn.benchmark = True
print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

    Total params: 6.52M


In [10]:
# summary(model, input_size=(3,64,64), device='cuda')

# Loss

In [11]:
criterion = nn.CrossEntropyLoss().cuda()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=1e-4, nesterov=True)

# LR schedule: a warm-up phase of warmup_epochs epochs towards lr * warmup_multiplier
# (the log shows 0.10 -> 0.17 after the first epoch), then the cosine scheduler takes over.
warmup_epochs = 10
warmup_multiplier = 8
# Cosine annealing only starts once warm-up is finished, so it should span the remaining
# epochs; T_max=epochs would end training before the cosine curve reaches its minimum.
cosine_epochs = max(1, epochs - warmup_epochs)
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cosine_epochs)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=warmup_multiplier,
                                          total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)

In [12]:
# Resume: restore model/optimizer state, epoch counter and best accuracy from a
# checkpoint written by save_checkpoint, then append to that run's log.
# Otherwise start a fresh log in the checkpoint directory.
if resume:
    print('==> Resuming from checkpoint..')
    checkpoint = os.path.dirname(resume)
    # Load into a separate name: rebinding ``resume`` to the loaded dict made the cell
    # fail on a second run (os.path.dirname on a dict).
    # map_location='cpu' lets CUDA-saved checkpoints load on any host; load_state_dict
    # copies the values onto the existing parameters' devices.
    # torch.load unpickles the file: only resume from trusted checkpoints.
    resume_ckpt = torch.load(resume, map_location='cpu')
    best_acc = resume_ckpt['best_acc']
    start_epoch = resume_ckpt['epoch']
    model.load_state_dict(resume_ckpt['state_dict'])
    optimizer.load_state_dict(resume_ckpt['optimizer'])
    # NOTE(review): scheduler_warmup is not restored, so the warm-up/cosine schedule
    # appears to restart from scratch on resume — confirm this is acceptable.
    logger = Logger(os.path.join(checkpoint, 'log.txt'), resume=True)
else:
    logger = Logger(os.path.join(checkpoint, 'log.txt'))
    logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])

# Train

In [13]:
def train(train_loader, model, criterion, optimizer, epoch, use_cuda):
    """Train ``model`` for one epoch over ``train_loader``.

    Args:
        train_loader: DataLoader yielding ``(inputs, targets)`` batches.
        model: network being optimised; switched to train mode here.
        criterion: loss applied to ``(outputs, targets)``.
        optimizer: stepped once per batch.
        epoch: current epoch index (unused inside; kept for a uniform call signature).
        use_cuda: when true, every batch is copied to the GPU.

    Returns:
        ``(losses.avg, top1.avg)``: running averages of loss and top-1 accuracy.
        Batches smaller than the loader's batch size are skipped and not counted.
    """
    model.train()
    # An earlier evaluation pass may have disabled autograd globally; turn it back on.
    torch.set_grad_enabled(True)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # Only full batches are trained on. Read the expected size from the loader itself
    # rather than the notebook-level ``train_batch`` global (the fallback covers loaders
    # built with a batch_sampler, whose batch_size is None).
    full_batch = train_loader.batch_size or train_batch

    end = time.time()
    bar = Bar('Processing', max=len(train_loader))
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        batch_size = inputs.size(0)
        if batch_size < full_batch:
            continue
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            # With pin_memory=True loaders the copy can overlap with compute.
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)
        # torch.autograd.Variable is a no-op wrapper since PyTorch 0.4, so tensors are used directly.

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1 = accuracy(outputs.detach(), targets)
        losses.update(loss.item(), batch_size)
        top1.update(prec1[0], batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} '.format(
                    batch=batch_idx + 1,
                    size=len(train_loader),
                    data=data_time.val,
                    bt=batch_time.val,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss=losses.avg,
                    top1=top1.avg,
                    )
        bar.next()
        # Plain-text progress every 10 batches (the bar alone is not kept in notebook output).
        if batch_idx % 10 == 0:
            print('{batch}/{size} Data:{data:.3f} | Batch:{bt:.3f} | Total:{total:} | ETA:{eta:} | Loss:{loss:} | top1:{tp1:}'.format(
                 batch=batch_idx+1, size=len(train_loader), data=data_time.val, bt=batch_time.val, total=bar.elapsed_td, eta=bar.eta_td, loss=losses.avg, tp1=top1.avg))
    bar.finish()
    return (losses.avg, top1.avg)

In [14]:
def test(val_loader, model, criterion, epoch, use_cuda):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        val_loader: DataLoader yielding ``(inputs, targets)`` batches.
        model: network to evaluate; left in eval mode on return.
        criterion: loss applied to ``(outputs, targets)``.
        epoch: current epoch index (unused inside; kept for a uniform call signature).
        use_cuda: when true, every batch is copied to the GPU.

    Returns:
        ``(losses.avg, top1.avg)``: average loss and top-1 accuracy over all batches.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    num_batches = len(val_loader)
    end = time.time()
    bar = Bar('Processing', max=num_batches)
    # Scope "no autograd" to this pass; the context manager restores the previous mode on exit.
    # (The old ``Variable(..., volatile=True)`` flag was removed in PyTorch 0.4 and has no effect.)
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            if use_cuda:
                inputs = inputs.cuda(non_blocking=True)
                targets = targets.cuda(non_blocking=True)

            # compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            batch_size = inputs.size(0)
            prec1 = accuracy(outputs, targets)
            losses.update(loss.item(), batch_size)
            top1.update(prec1[0], batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:} | top1: {top1:}'.format(
                        batch=batch_idx + 1,
                        size=num_batches,
                        data=data_time.avg,
                        bt=batch_time.avg,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        loss=losses.avg,
                        top1=top1.avg,)
            bar.next()
    # Summary line. Uses the batch count instead of the loop variable, which is
    # undefined (NameError) when the loader is empty.
    print('{batch}/{size} Data:{data:.3f} | Batch:{bt:.3f} | Total:{total:} | ETA:{eta:} | Loss:{loss:} | top1:{tp1:}'.format(
         batch=num_batches, size=num_batches, data=data_time.val, bt=batch_time.val, total=bar.elapsed_td, eta=bar.eta_td, loss=losses.avg, tp1=top1.avg))
    bar.finish()
    return (losses.avg, top1.avg)

In [15]:
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """Write ``state`` to ``checkpoint/filename`` with torch.save.

    When ``is_best`` is true, the freshly written file is also copied to
    ``checkpoint/model_best.pth.tar``.
    """
    target = os.path.join(checkpoint, filename)
    torch.save(state, target)
    if not is_best:
        return
    best_target = os.path.join(checkpoint, 'model_best.pth.tar')
    shutil.copyfile(target, best_target)

def adjust_learning_rate(optimizer, epoch):
    """Apply the step decay configured in ``state`` and write the LR to ``optimizer``.

    ``state['lr']`` is multiplied by ``state['gamma']`` only on the epochs listed in
    ``state['schedule']``; on all other epochs it is left unchanged. The resulting value
    is then assigned to every parameter group.

    Args:
        optimizer: optimizer whose param groups receive the learning rate.
        epoch: 0-based epoch index about to be trained.
    """
    global state
    # Decay once per milestone. (A per-epoch factor looked up from the number of
    # milestones already passed would be re-applied on every later epoch and compound
    # the decay, driving the LR towards zero within a few epochs.)
    if epoch in state['schedule']:
        state['lr'] *= state['gamma']
    for param_group in optimizer.param_groups:
        param_group['lr'] = state['lr']

In [None]:
# Main loop: per epoch, set the LR, train, validate, log, step the LR scheduler,
# and checkpoint (copying to model_best.pth.tar when validation accuracy improves).
for epoch in range(start_epoch, epochs):
    # Start from the LR currently on the optimizer (scheduler_warmup.step() below changes
    # it after every epoch), then let adjust_learning_rate apply the step schedule on top.
    # NOTE(review): two mechanisms (adjust_learning_rate and scheduler_warmup) both write
    # the learning rate — confirm the combination is intended.
    state['lr'] = optimizer.state_dict()['param_groups'][0]['lr']
    adjust_learning_rate(optimizer, epoch)
    print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, epochs, state['lr']))
    
    train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, use_cuda)
    test_loss, test_acc = test(val_loader, model, criterion, epoch, use_cuda)
    
    # The logged LR is the one used for this epoch, i.e. before the scheduler step.
    logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc])
    scheduler_warmup.step()

    # best_acc is also restored from the checkpoint when resuming.
    is_best = test_acc > best_acc
    best_acc = max(test_acc, best_acc)
    # NOTE(review): the scheduler state is not stored here, so a resumed run appears to
    # rebuild the warm-up/cosine schedule from scratch — confirm.
    save_checkpoint({
        'epoch': epoch + 1,  # completed epochs; read back as start_epoch when resuming
        'state_dict' : model.state_dict(),
        'acc': test_acc,
        'best_acc': best_acc,
        'optimizer': optimizer.state_dict(),
    }, is_best, checkpoint=checkpoint)


Epoch: [1 | 300] LR: 0.100000
1/502 Data:1.177 | Batch:6.873 | Total:0:00:06 | ETA:0:57:24 | Loss:0.7125661969184875 | top1:45.3125
11/502 Data:0.001 | Batch:0.956 | Total:0:00:16 | ETA:0:12:36 | Loss:4.01255205002698 | top1:49.431819915771484
21/502 Data:0.007 | Batch:0.960 | Total:0:00:25 | ETA:0:07:34 | Loss:2.787099994364239 | top1:49.795387268066406
31/502 Data:0.001 | Batch:0.860 | Total:0:00:35 | ETA:0:07:22 | Loss:2.1186306957275636 | top1:50.151206970214844
41/502 Data:0.003 | Batch:0.950 | Total:0:00:44 | ETA:0:07:04 | Loss:1.773259055323717 | top1:50.14291000366211
51/502 Data:0.003 | Batch:0.981 | Total:0:00:53 | ETA:0:07:08 | Loss:1.5625986024445178 | top1:50.38296890258789
61/502 Data:0.003 | Batch:0.912 | Total:0:01:03 | ETA:0:07:00 | Loss:1.4228033267083715 | top1:50.08964920043945
71/502 Data:0.002 | Batch:0.931 | Total:0:01:12 | ETA:0:06:49 | Loss:1.3222534807635025 | top1:49.98349380493164
81/502 Data:0.003 | Batch:0.920 | Total:0:01:22 | ETA:0:06:37 | Loss:1.245906



161/161 Data:0.002 | Batch:0.684 | Total:0:00:40 | ETA:0:00:00 | Loss:0.6936783963274733 | top1:50.000003814697266

Epoch: [2 | 300] LR: 0.170000
1/502 Data:0.956 | Batch:2.011 | Total:0:00:02 | ETA:0:16:49 | Loss:0.6906084418296814 | top1:55.078125
11/502 Data:0.002 | Batch:0.951 | Total:0:00:11 | ETA:0:08:38 | Loss:0.7067780928178267 | top1:48.650569915771484
21/502 Data:0.001 | Batch:0.925 | Total:0:00:20 | ETA:0:07:22 | Loss:0.7044571865172613 | top1:49.572174072265625
31/502 Data:0.003 | Batch:0.876 | Total:0:00:30 | ETA:0:07:26 | Loss:0.7033193246010812 | top1:49.71017837524414
41/502 Data:0.002 | Batch:0.701 | Total:0:00:37 | ETA:0:06:07 | Loss:0.7033702222312369 | top1:49.952362060546875
51/502 Data:0.001 | Batch:0.776 | Total:0:00:45 | ETA:0:05:24 | Loss:0.7028513913061104 | top1:50.091915130615234
61/502 Data:0.001 | Batch:0.467 | Total:0:00:51 | ETA:0:04:59 | Loss:0.7028621757616762 | top1:50.0576286315918
71/502 Data:0.002 | Batch:0.928 | Total:0:01:01 | ETA:0:06:44 | Loss:

191/502 Data:0.002 | Batch:0.936 | Total:0:02:52 | ETA:0:04:47 | Loss:0.6960025582638086 | top1:50.020450592041016
201/502 Data:0.008 | Batch:0.956 | Total:0:03:01 | ETA:0:04:39 | Loss:0.6959513633405391 | top1:50.03303527832031
211/502 Data:0.001 | Batch:0.944 | Total:0:03:11 | ETA:0:04:37 | Loss:0.6959525716813255 | top1:50.02592086791992
221/502 Data:0.001 | Batch:0.891 | Total:0:03:20 | ETA:0:04:19 | Loss:0.6958533557831432 | top1:50.086612701416016
231/502 Data:0.001 | Batch:0.914 | Total:0:03:29 | ETA:0:04:10 | Loss:0.6957242352105838 | top1:50.180938720703125
241/502 Data:0.002 | Batch:0.925 | Total:0:03:38 | ETA:0:04:06 | Loss:0.695767320785285 | top1:50.13777542114258
251/502 Data:0.006 | Batch:0.951 | Total:0:03:48 | ETA:0:03:58 | Loss:0.6954904144028743 | top1:50.31592559814453
261/502 Data:0.001 | Batch:0.945 | Total:0:03:57 | ETA:0:03:43 | Loss:0.696003350718268 | top1:50.45797348022461
271/502 Data:0.002 | Batch:0.934 | Total:0:04:06 | ETA:0:03:35 | Loss:0.695757985115051

391/502 Data:0.001 | Batch:0.936 | Total:0:05:55 | ETA:0:01:41 | Loss:0.43292614276451835 | top1:79.7744140625
401/502 Data:0.000 | Batch:0.893 | Total:0:06:04 | ETA:0:01:33 | Loss:0.4330442741773372 | top1:79.76543426513672
411/502 Data:0.001 | Batch:0.941 | Total:0:06:14 | ETA:0:01:25 | Loss:0.4328284505800029 | top1:79.77873992919922
421/502 Data:0.001 | Batch:0.924 | Total:0:06:23 | ETA:0:01:16 | Loss:0.4322073504930437 | top1:79.83502960205078
431/502 Data:0.005 | Batch:0.909 | Total:0:06:32 | ETA:0:01:06 | Loss:0.43162299212059674 | top1:79.87419891357422
441/502 Data:0.014 | Batch:0.970 | Total:0:06:42 | ETA:0:00:58 | Loss:0.43162431704754733 | top1:79.86376953125
451/502 Data:0.001 | Batch:0.914 | Total:0:06:51 | ETA:0:00:48 | Loss:0.4314710033309962 | top1:79.85552978515625
461/502 Data:0.002 | Batch:0.926 | Total:0:07:00 | ETA:0:00:38 | Loss:0.43140099797228154 | top1:79.86459350585938
471/502 Data:0.002 | Batch:0.937 | Total:0:07:09 | ETA:0:00:29 | Loss:0.431456470628706 | t

71/502 Data:0.002 | Batch:0.928 | Total:0:01:07 | ETA:0:06:38 | Loss:0.4256589887007861 | top1:79.83604431152344
81/502 Data:0.001 | Batch:0.948 | Total:0:01:16 | ETA:0:06:30 | Loss:0.423126733597414 | top1:80.078125
91/502 Data:0.001 | Batch:0.945 | Total:0:01:25 | ETA:0:06:25 | Loss:0.42434693430806253 | top1:79.99656677246094
101/502 Data:0.002 | Batch:0.893 | Total:0:01:35 | ETA:0:06:20 | Loss:0.4250613636899703 | top1:79.95436096191406
111/502 Data:0.003 | Batch:0.893 | Total:0:01:44 | ETA:0:06:06 | Loss:0.42447531035354547 | top1:80.01478576660156
121/502 Data:0.002 | Batch:0.937 | Total:0:01:53 | ETA:0:05:51 | Loss:0.4227267200789176 | top1:80.165283203125
131/502 Data:0.003 | Batch:0.926 | Total:0:02:03 | ETA:0:05:40 | Loss:0.42142670413919986 | top1:80.24510955810547
141/502 Data:0.001 | Batch:0.944 | Total:0:02:12 | ETA:0:05:42 | Loss:0.4214046147275478 | top1:80.26927947998047
151/502 Data:0.003 | Batch:0.955 | Total:0:02:21 | ETA:0:05:29 | Loss:0.420866208558051 | top1:80.3

271/502 Data:0.003 | Batch:0.922 | Total:0:04:11 | ETA:0:03:35 | Loss:0.4129192807137746 | top1:80.784423828125
281/502 Data:0.002 | Batch:0.957 | Total:0:04:20 | ETA:0:03:27 | Loss:0.4129495388460329 | top1:80.78430938720703
291/502 Data:0.003 | Batch:0.979 | Total:0:04:29 | ETA:0:03:17 | Loss:0.4129787370101693 | top1:80.7828598022461
301/502 Data:0.002 | Batch:0.897 | Total:0:04:39 | ETA:0:03:07 | Loss:0.41320405687604633 | top1:80.76982879638672
311/502 Data:0.003 | Batch:0.496 | Total:0:04:46 | ETA:0:02:20 | Loss:0.4134148653488834 | top1:80.755126953125
321/502 Data:0.001 | Batch:0.483 | Total:0:04:53 | ETA:0:02:08 | Loss:0.4129210085698006 | top1:80.79366302490234
331/502 Data:0.001 | Batch:0.541 | Total:0:04:59 | ETA:0:01:53 | Loss:0.4126497974388549 | top1:80.82160949707031
341/502 Data:0.007 | Batch:0.902 | Total:0:05:08 | ETA:0:02:16 | Loss:0.4131396905767603 | top1:80.784912109375
351/502 Data:0.002 | Batch:0.952 | Total:0:05:17 | ETA:0:02:22 | Loss:0.41295065278680915 | to