In [1]:
import json
import os
import time
import random
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import shutil

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision import models
from torchsummary import summary
from sklearn.model_selection import train_test_split

from model_pytorch import EfficientNet
from utils import Bar,Logger, AverageMeter, accuracy, mkdir_p, savefig
from warmup_scheduler import GradualWarmupScheduler

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import collections

In [2]:
# GPU Device
gpu_id = 0
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
use_cuda = torch.cuda.is_available()
print("GPU device %d:" %(gpu_id), use_cuda)

GPU device 0: True


# Arguments

In [3]:
data_dir = '/media/data2/dataset/GAN_ImageData/PGGAN_128'

In [4]:
pretrained = ''
resume = './log/pggan/128/b1_2/checkpoint.pth.tar'

In [5]:
# Model
model_name = 'efficientnet-b1' # b0-b7 scale

# Optimization
num_classes = 2
epochs = 300
start_epoch = 0
train_batch = 190
test_batch = 190
lr = 0.1
schedule = [20, 75, 125, 175]
momentum = 0.9
gamma = 0.1 # LR is multiplied by gamma on schedule

# CheckPoint
checkpoint = './log/pggan/128/b1_2' # dir
if not os.path.isdir(checkpoint):
    os.mkdir(checkpoint)
num_workers = 8

# Seed
manual_seed = 7
random.seed(manual_seed)
torch.cuda.manual_seed_all(manual_seed)

# Image
size = (128, 128)

best_acc = 0

In [6]:
state = {}
state['num_classes'] = num_classes
state['epochs'] = epochs
state['start_epoch'] = start_epoch
state['train_batch'] = train_batch
state['test_batch'] = test_batch
state['lr'] = lr
state['schedule'] = schedule
state['momentum'] = momentum
state['gamma'] = gamma

# Dataset

In [7]:
train_dir = os.path.join(data_dir, 'train')
val_dir = os.path.join(data_dir, 'validation')    
train_aug = transforms.Compose([
    transforms.RandomAffine(degrees=2, translate=(0.02, 0.02), scale=(0.98, 1.02), shear=2, fillcolor=(124,117,104)),
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.10), ratio=(0.3, 3.3), value=0, inplace=True),
])
val_aug = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
])

# pin_memory : cuda pin memeory use
train_loader = DataLoader(datasets.ImageFolder(train_dir, transform=train_aug),
                          batch_size=train_batch, shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(datasets.ImageFolder(val_dir, val_aug),
                       batch_size=test_batch, shuffle=True, num_workers=num_workers, pin_memory=True)

# Model

In [8]:
model = EfficientNet.from_name(model_name, num_classes=num_classes,
                              override_params={'dropout_rate':0.3, 'drop_connect_rate':0.3})

# Pre-trained
if pretrained:
    print("=> using pre-trained model '{}'".format(pretrained))
    model.load_state_dict(torch.load(pretrained)['state_dict'])

In [9]:
model.to('cuda')
cudnn.benchmark = True
print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

    Total params: 6.52M


# Loss

In [10]:
criterion = nn.CrossEntropyLoss().cuda()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=1e-4)
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=10, after_scheduler=scheduler_cosine)

In [11]:
# Resume
if resume:
    print('==> Resuming from checkpoint..')
    checkpoint = os.path.dirname(resume)
#     checkpoint = torch.load(resume)
    resume = torch.load(resume)
    best_acc = resume['best_acc']
    start_epoch = resume['epoch']
    model.load_state_dict(resume['state_dict'])
    optimizer.load_state_dict(resume['optimizer'])
    logger = Logger(os.path.join(checkpoint, 'log.txt'), resume=True)
else:
    logger = Logger(os.path.join(checkpoint, 'log.txt'))
    logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])

==> Resuming from checkpoint..


# Train

In [12]:
def train(train_loader, model, criterion, optimizer, epoch, use_cuda):
    model.train()
    torch.set_grad_enabled(True)
    
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()
    
    bar = Bar('Processing', max=len(train_loader))
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        batch_size = inputs.size(0)
        if batch_size < train_batch:
            continue
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1 = accuracy(outputs.data, targets.data)
        losses.update(loss.data.tolist(), inputs.size(0))
        top1.update(prec1[0], inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} '.format(
                    batch=batch_idx + 1,
                    size=len(train_loader),
                    data=data_time.val,
                    bt=batch_time.val,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss=losses.avg,
                    top1=top1.avg,
                    )
        bar.next()
        if batch_idx % 10 == 0:
            print('{batch}/{size} Data:{data:.3f} | Batch:{bt:.3f} | Total:{total:} | ETA:{eta:} | Loss:{loss:} | top1:{tp1:}'.format(
                 batch=batch_idx+1, size=len(train_loader), data=data_time.val, bt=batch_time.val, total=bar.elapsed_td, eta=bar.eta_td, loss=losses.avg, tp1=top1.avg))
    bar.finish()
    return (losses.avg, top1.avg)

In [13]:
def test(val_loader, model, criterion, epoch, use_cuda):
    global best_acc

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()
    torch.set_grad_enabled(False)

    end = time.time()
    bar = Bar('Processing', max=len(val_loader))
    for batch_idx, (inputs, targets) in enumerate(val_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = torch.autograd.Variable(inputs, volatile=True), torch.autograd.Variable(targets)

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1 = accuracy(outputs.data, targets.data)
        losses.update(loss.data.tolist(), inputs.size(0))
        top1.update(prec1[0], inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:} | top1: {top1:}'.format(
                    batch=batch_idx + 1,
                    size=len(val_loader),
                    data=data_time.avg,
                    bt=batch_time.avg,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss=losses.avg,
                    top1=top1.avg,)
        bar.next()
    print('{batch}/{size} Data:{data:.3f} | Batch:{bt:.3f} | Total:{total:} | ETA:{eta:} | Loss:{loss:} | top1:{tp1:}'.format(
         batch=batch_idx+1, size=len(val_loader), data=data_time.val, bt=batch_time.val, total=bar.elapsed_td, eta=bar.eta_td, loss=losses.avg, tp1=top1.avg))
    bar.finish()
    return (losses.avg, top1.avg)

In [14]:
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))

def adjust_learning_rate(optimizer, epoch):
    global state
    lr_set = [1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
    lr_list = schedule.copy()
    lr_list.append(epoch)
    lr_list.sort()
    idx = lr_list.index(epoch)
    state['lr'] *= lr_set[idx]
    for param_group in optimizer.param_groups:
        param_group['lr'] = state['lr']

In [None]:
for epoch in range(start_epoch, epochs):
    state['lr'] = optimizer.state_dict()['param_groups'][0]['lr']
    adjust_learning_rate(optimizer, epoch)
    print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, epochs, state['lr']))
    
    train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, use_cuda)
    test_loss, test_acc = test(val_loader, model, criterion, epoch, use_cuda)
    
    logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc])
    scheduler_warmup.step()

    is_best = test_acc > best_acc
    best_acc = max(test_acc, best_acc)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict' : model.state_dict(),
        'acc': test_acc,
        'best_acc': best_acc,
        'optimizer': optimizer.state_dict(),
    }, is_best, checkpoint=checkpoint)


Epoch: [190 | 300] LR: 0.000073
1/676 Data:0.768 | Batch:5.641 | Total:0:00:05 | ETA:1:03:29 | Loss:0.14894811809062958 | top1:93.68421173095703
11/676 Data:0.001 | Batch:0.665 | Total:0:00:12 | ETA:0:13:12 | Loss:0.11345696246082132 | top1:95.2153091430664
21/676 Data:0.001 | Batch:0.666 | Total:0:00:19 | ETA:0:07:15 | Loss:0.11436259249846141 | top1:95.43859100341797
31/676 Data:0.002 | Batch:0.657 | Total:0:00:25 | ETA:0:07:08 | Loss:0.11889133506244229 | top1:95.07640075683594
41/676 Data:0.003 | Batch:0.660 | Total:0:00:32 | ETA:0:06:59 | Loss:0.12211005934854834 | top1:94.83953857421875
51/676 Data:0.002 | Batch:0.667 | Total:0:00:39 | ETA:0:06:52 | Loss:0.12064347766778048 | top1:94.89164733886719
61/676 Data:0.001 | Batch:0.664 | Total:0:00:45 | ETA:0:06:50 | Loss:0.12412355205074685 | top1:94.7195816040039
71/676 Data:0.002 | Batch:0.663 | Total:0:00:52 | ETA:0:06:42 | Loss:0.12249374484092417 | top1:94.76649475097656
81/676 Data:0.001 | Batch:0.662 | Total:0:00:58 | ETA:0:06



169/169 Data:0.002 | Batch:1.415 | Total:0:00:32 | ETA:0:00:00 | Loss:0.048176510560955864 | top1:98.17134094238281

Epoch: [191 | 300] LR: 0.000017
1/676 Data:0.775 | Batch:1.686 | Total:0:00:01 | ETA:0:19:00 | Loss:0.1032135933637619 | top1:96.84210968017578
11/676 Data:0.002 | Batch:0.664 | Total:0:00:08 | ETA:0:08:50 | Loss:0.11978998441587795 | top1:95.55023956298828
21/676 Data:0.003 | Batch:0.660 | Total:0:00:15 | ETA:0:07:13 | Loss:0.11676709353923798 | top1:95.41352844238281
31/676 Data:0.002 | Batch:0.664 | Total:0:00:21 | ETA:0:07:07 | Loss:0.11599493723723196 | top1:95.39898681640625
41/676 Data:0.002 | Batch:0.661 | Total:0:00:28 | ETA:0:07:02 | Loss:0.11721923129587639 | top1:95.37869262695312
51/676 Data:0.002 | Batch:0.661 | Total:0:00:35 | ETA:0:06:56 | Loss:0.11902291678330477 | top1:95.22187805175781
61/676 Data:0.002 | Batch:0.663 | Total:0:00:41 | ETA:0:06:48 | Loss:0.11885816237477005 | top1:95.25452423095703
71/676 Data:0.003 | Batch:0.670 | Total:0:00:48 | ETA:0

21/676 Data:0.002 | Batch:0.663 | Total:0:00:14 | ETA:0:07:14 | Loss:0.12610845799957002 | top1:94.46115112304688
31/676 Data:0.002 | Batch:0.663 | Total:0:00:21 | ETA:0:07:08 | Loss:0.11917218049207041 | top1:94.8896484375
41/676 Data:0.002 | Batch:0.661 | Total:0:00:27 | ETA:0:07:00 | Loss:0.12301543654828537 | top1:94.801025390625
51/676 Data:0.002 | Batch:0.661 | Total:0:00:34 | ETA:0:06:53 | Loss:0.1227696588372483 | top1:94.81940460205078
61/676 Data:0.002 | Batch:0.658 | Total:0:00:41 | ETA:0:06:46 | Loss:0.12124164695622491 | top1:94.93528747558594
71/676 Data:0.002 | Batch:0.660 | Total:0:00:47 | ETA:0:06:40 | Loss:0.12257842572642044 | top1:94.877685546875
81/676 Data:0.002 | Batch:0.660 | Total:0:00:54 | ETA:0:06:35 | Loss:0.12366642701772997 | top1:94.80831146240234
91/676 Data:0.002 | Batch:0.663 | Total:0:01:01 | ETA:0:06:28 | Loss:0.1245055031645429 | top1:94.76576232910156
101/676 Data:0.002 | Batch:0.659 | Total:0:01:07 | ETA:0:06:21 | Loss:0.12346474700930095 | top1:9

51/676 Data:0.002 | Batch:0.660 | Total:0:00:34 | ETA:0:06:54 | Loss:0.12979373014440723 | top1:94.55108642578125
61/676 Data:0.001 | Batch:0.662 | Total:0:00:41 | ETA:0:06:48 | Loss:0.12833000719547272 | top1:94.65918731689453
71/676 Data:0.002 | Batch:0.659 | Total:0:00:47 | ETA:0:06:41 | Loss:0.12812491369919038 | top1:94.64047241210938
81/676 Data:0.003 | Batch:0.658 | Total:0:00:54 | ETA:0:06:34 | Loss:0.12730689650332486 | top1:94.61988067626953
91/676 Data:0.002 | Batch:0.662 | Total:0:01:01 | ETA:0:06:28 | Loss:0.1286635419333374 | top1:94.52284240722656
101/676 Data:0.002 | Batch:0.660 | Total:0:01:07 | ETA:0:06:21 | Loss:0.12682508754700716 | top1:94.61177825927734
111/676 Data:0.002 | Batch:0.656 | Total:0:01:14 | ETA:0:06:12 | Loss:0.12642028655957532 | top1:94.65149688720703
121/676 Data:0.002 | Batch:0.671 | Total:0:01:20 | ETA:0:06:07 | Loss:0.12687122375388776 | top1:94.67594909667969
131/676 Data:0.002 | Batch:0.660 | Total:0:01:27 | ETA:0:06:01 | Loss:0.12675838209404