In [1]:
import json
import os
import time
import random
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import shutil

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision import models
from torchsummary import summary
from sklearn.model_selection import train_test_split

from model_pytorch import EfficientNet
from utils import Bar,Logger, AverageMeter, accuracy, mkdir_p, savefig
from warmup_scheduler import GradualWarmupScheduler

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import collections

In [2]:
# GPU Device
gpu_id = 3
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
use_cuda = torch.cuda.is_available()
print("GPU device %d:" %(gpu_id), use_cuda)

GPU device 3: True


# Arguments

In [3]:
data_dir = '/media/data2/dataset/GAN_ImageData/PGGAN_128/'

In [4]:
pretrained = ''
resume = ''

In [5]:
# Model
model_name = 'efficientnet-b1' # b0-b7 scale

# Optimization
num_classes = 2
epochs = 300
start_epoch = 0
train_batch = 756
test_batch = 500
lr = 0.04
schedule = [75, 150, 225]
momentum = 0.9
gamma = 0.1 # LR is multiplied by gamma on schedule

# CheckPoint
checkpoint = './log/pggan/64/b1' # dir
if not os.path.isdir(checkpoint):
    os.mkdir(checkpoint)
num_workers = 4

# Seed
manual_seed = 7
random.seed(manual_seed)
torch.cuda.manual_seed_all(manual_seed)

# Image
size = (64, 64)

best_acc = 0

In [6]:
state = {}
state['num_classes'] = num_classes
state['epochs'] = epochs
state['start_epoch'] = start_epoch
state['train_batch'] = train_batch
state['test_batch'] = test_batch
state['lr'] = lr
state['schedule'] = schedule
state['momentum'] = momentum
state['gamma'] = gamma

# Dataset

In [7]:
train_dir = os.path.join(data_dir, 'train')
val_dir = os.path.join(data_dir, 'validation')    
train_aug = transforms.Compose([
#     transforms.RandomAffine(degrees=2, translate=(0.02, 0.02), scale=(0.98, 1.02), shear=2, fillcolor=(124,117,104)),
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
#     transforms.RandomErasing(p=0.3, scale=(0.02, 0.10), ratio=(0.3, 3.3), value=0, inplace=True),
])
val_aug = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
])

# pin_memory : cuda pin memeory use
train_loader = DataLoader(datasets.ImageFolder(train_dir, transform=train_aug),
                          batch_size=train_batch, shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(datasets.ImageFolder(val_dir, val_aug),
                       batch_size=test_batch, shuffle=True, num_workers=num_workers, pin_memory=True)

# Model

In [8]:
model = EfficientNet.from_name(model_name, num_classes=num_classes)

# Pre-trained
if pretrained:
    print("=> using pre-trained model '{}'".format(pretrained))
    model.load_state_dict(torch.load(pretrained)['state_dict'])

In [9]:
model.to('cuda')
cudnn.benchmark = True
print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

    Total params: 6.52M


In [10]:
# summary(model, input_size=(3,64,64), device='cuda')

# Loss

In [11]:
criterion = nn.CrossEntropyLoss().cuda()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=1e-4, nesterov=True)
# optimizer = optim.Adam(model.parameters(), weight_decay=1e-4)
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=10, after_scheduler=scheduler_cosine)

In [12]:
# Resume
if resume:
    print('==> Resuming from checkpoint..')
    checkpoint = os.path.dirname(resume)
#     checkpoint = torch.load(resume)
    resume = torch.load(resume)
    best_acc = resume['best_acc']
    start_epoch = resume['epoch']
    model.load_state_dict(resume['state_dict'])
    optimizer.load_state_dict(resume['optimizer'])
    logger = Logger(os.path.join(checkpoint, 'log.txt'), resume=True)
else:
    logger = Logger(os.path.join(checkpoint, 'log.txt'))
    logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])

# Train

In [13]:
def train(train_loader, model, criterion, optimizer, epoch, use_cuda):
    model.train()
    torch.set_grad_enabled(True)
    
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()
    
    bar = Bar('Processing', max=len(train_loader))
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        batch_size = inputs.size(0)
        if batch_size < train_batch:
            continue
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1 = accuracy(outputs.data, targets.data)
        losses.update(loss.data.tolist(), inputs.size(0))
        top1.update(prec1[0], inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} '.format(
                    batch=batch_idx + 1,
                    size=len(train_loader),
                    data=data_time.val,
                    bt=batch_time.val,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss=losses.avg,
                    top1=top1.avg,
                    )
        bar.next()
        if batch_idx % 10 == 0:
            print('{batch}/{size} Data:{data:.3f} | Batch:{bt:.3f} | Total:{total:} | ETA:{eta:} | Loss:{loss:} | top1:{tp1:}'.format(
                 batch=batch_idx+1, size=len(train_loader), data=data_time.val, bt=batch_time.val, total=bar.elapsed_td, eta=bar.eta_td, loss=losses.avg, tp1=top1.avg))
    bar.finish()
    return (losses.avg, top1.avg)

In [14]:
def test(val_loader, model, criterion, epoch, use_cuda):
    global best_acc

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()
    torch.set_grad_enabled(False)

    end = time.time()
    bar = Bar('Processing', max=len(val_loader))
    for batch_idx, (inputs, targets) in enumerate(val_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = torch.autograd.Variable(inputs, volatile=True), torch.autograd.Variable(targets)

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1 = accuracy(outputs.data, targets.data)
        losses.update(loss.data.tolist(), inputs.size(0))
        top1.update(prec1[0], inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:} | top1: {top1:}'.format(
                    batch=batch_idx + 1,
                    size=len(val_loader),
                    data=data_time.avg,
                    bt=batch_time.avg,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss=losses.avg,
                    top1=top1.avg,)
        bar.next()
    print('{batch}/{size} Data:{data:.3f} | Batch:{bt:.3f} | Total:{total:} | ETA:{eta:} | Loss:{loss:} | top1:{tp1:}'.format(
         batch=batch_idx+1, size=len(val_loader), data=data_time.val, bt=batch_time.val, total=bar.elapsed_td, eta=bar.eta_td, loss=losses.avg, tp1=top1.avg))
    bar.finish()
    return (losses.avg, top1.avg)

In [15]:
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))

def adjust_learning_rate(optimizer, epoch):
    global state
    lr_set = [1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
    lr_list = schedule.copy()
    lr_list.append(epoch)
    lr_list.sort()
    idx = lr_list.index(epoch)
    state['lr'] *= lr_set[idx]
    for param_group in optimizer.param_groups:
        param_group['lr'] = state['lr']

In [None]:
for epoch in range(start_epoch, epochs):
    state['lr'] = optimizer.state_dict()['param_groups'][0]['lr']
    adjust_learning_rate(optimizer, epoch)
    print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, epochs, state['lr']))
    train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, use_cuda)
    test_loss, test_acc = test(val_loader, model, criterion, epoch, use_cuda)
    
    logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc])
    scheduler_warmup.step()

    is_best = test_acc > best_acc
    best_acc = max(test_acc, best_acc)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict' : model.state_dict(),
        'acc': test_acc,
        'best_acc': best_acc,
        'optimizer': optimizer.state_dict(),
    }, is_best, checkpoint=checkpoint)


Epoch: [1 | 300] LR: 0.040000
1/170 Data:1.570 | Batch:5.456 | Total:0:00:05 | ETA:0:15:23 | Loss:0.7146677374839783 | top1:48.94179916381836
11/170 Data:0.001 | Batch:0.505 | Total:0:00:10 | ETA:0:02:42 | Loss:1.4604379317977212 | top1:49.49494934082031
21/170 Data:0.002 | Batch:0.517 | Total:0:00:15 | ETA:0:01:17 | Loss:1.3155791220210848 | top1:50.35273742675781
31/170 Data:0.001 | Batch:0.510 | Total:0:00:20 | ETA:0:01:12 | Loss:1.1421904198585018 | top1:50.166412353515625
41/170 Data:0.001 | Batch:0.506 | Total:0:00:26 | ETA:0:01:07 | Loss:1.053890917359329 | top1:49.90966796875
51/170 Data:0.001 | Batch:0.509 | Total:0:00:31 | ETA:0:01:01 | Loss:0.9852287488825181 | top1:49.95850372314453
61/170 Data:0.001 | Batch:0.512 | Total:0:00:36 | ETA:0:00:56 | Loss:0.9385558767396895 | top1:49.91326141357422
71/170 Data:0.001 | Batch:0.524 | Total:0:00:41 | ETA:0:00:53 | Loss:0.9050314837778118 | top1:49.919891357421875
81/170 Data:0.001 | Batch:0.526 | Total:0:00:46 | ETA:0:00:47 | Loss



65/65 Data:0.000 | Batch:0.343 | Total:0:00:12 | ETA:0:00:00 | Loss:0.6988236948708507 | top1:50.000003814697266

Epoch: [2 | 300] LR: 0.068000
1/170 Data:1.980 | Batch:2.543 | Total:0:00:02 | ETA:0:07:10 | Loss:0.6981512904167175 | top1:50.925926208496094
11/170 Data:0.001 | Batch:0.534 | Total:0:00:07 | ETA:0:01:57 | Loss:1.1397061944007874 | top1:49.4348258972168
21/170 Data:0.001 | Batch:0.523 | Total:0:00:13 | ETA:0:01:21 | Loss:1.055970311164856 | top1:49.88032531738281
31/170 Data:0.001 | Batch:0.512 | Total:0:00:18 | ETA:0:01:14 | Loss:1.011167566622457 | top1:49.735450744628906
41/170 Data:0.001 | Batch:0.534 | Total:0:00:23 | ETA:0:01:08 | Loss:0.9467596920525155 | top1:49.87095260620117
51/170 Data:0.001 | Batch:0.526 | Total:0:00:29 | ETA:0:01:05 | Loss:0.8988683504216811 | top1:49.815853118896484
61/170 Data:0.001 | Batch:0.517 | Total:0:00:34 | ETA:0:01:00 | Loss:0.8696197144320754 | top1:49.770145416259766
71/170 Data:0.001 | Batch:0.524 | Total:0:00:39 | ETA:0:00:51 | L

65/65 Data:0.000 | Batch:0.070 | Total:0:00:11 | ETA:0:00:00 | Loss:0.6939394388614786 | top1:50.000003814697266

Epoch: [6 | 300] LR: 0.180000
1/170 Data:2.112 | Batch:2.694 | Total:0:00:02 | ETA:0:07:36 | Loss:0.6946014761924744 | top1:49.338623046875
11/170 Data:0.021 | Batch:0.786 | Total:0:00:10 | ETA:0:02:34 | Loss:0.6963306231932207 | top1:49.98797607421875
21/170 Data:0.001 | Batch:0.516 | Total:0:00:18 | ETA:0:02:03 | Loss:0.6956309335572379 | top1:50.176368713378906
31/170 Data:0.001 | Batch:0.816 | Total:0:00:26 | ETA:0:01:44 | Loss:0.6951173620839273 | top1:50.17494583129883
41/170 Data:0.001 | Batch:0.886 | Total:0:00:33 | ETA:0:01:39 | Loss:0.6950761879362711 | top1:50.23551559448242
51/170 Data:0.002 | Batch:0.759 | Total:0:00:41 | ETA:0:01:28 | Loss:0.6950677750157375 | top1:50.176368713378906
61/170 Data:0.001 | Batch:0.723 | Total:0:00:49 | ETA:0:01:30 | Loss:0.6954300892157633 | top1:50.02602005004883
71/170 Data:0.006 | Batch:0.556 | Total:0:00:57 | ETA:0:01:22 | Lo

65/65 Data:0.000 | Batch:0.062 | Total:0:00:12 | ETA:0:00:00 | Loss:0.6666027292655636 | top1:55.90031433105469

Epoch: [10 | 300] LR: 0.292000
1/170 Data:1.452 | Batch:2.020 | Total:0:00:02 | ETA:0:05:42 | Loss:0.6838404536247253 | top1:53.96825408935547
11/170 Data:0.001 | Batch:0.519 | Total:0:00:07 | ETA:0:01:48 | Loss:0.6777945377609946 | top1:57.455509185791016
21/170 Data:0.002 | Batch:0.532 | Total:0:00:12 | ETA:0:01:18 | Loss:0.6706057957240513 | top1:58.711265563964844
31/170 Data:0.001 | Batch:0.510 | Total:0:00:17 | ETA:0:01:14 | Loss:0.670417224207232 | top1:59.66461944580078
41/170 Data:0.001 | Batch:0.561 | Total:0:00:23 | ETA:0:01:08 | Loss:0.6784354462856199 | top1:57.155765533447266
51/170 Data:0.001 | Batch:0.509 | Total:0:00:28 | ETA:0:01:04 | Loss:0.6815501790420682 | top1:55.84604263305664
61/170 Data:0.001 | Batch:0.514 | Total:0:00:33 | ETA:0:00:58 | Loss:0.6844917633494393 | top1:54.82478713989258
71/170 Data:0.001 | Batch:0.520 | Total:0:00:38 | ETA:0:00:52 | 