In [1]:
import json
import os
import time
import random
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import shutil

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision import models
from torchsummary import summary
from sklearn.model_selection import train_test_split

from model_pytorch import EfficientNet
from utils import Bar,Logger, AverageMeter, accuracy, mkdir_p, savefig
from warmup_scheduler import GradualWarmupScheduler

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import collections

In [2]:
# GPU Device
gpu_id = 3
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
use_cuda = torch.cuda.is_available()
print("GPU device %d:" %(gpu_id), use_cuda)

GPU device 3: True


# Arguments

In [3]:
data_dir = '/media/data2/dataset/GAN_ImageData/StyleGAN_256/'

In [4]:
pretrained = ''
resume = ''

In [5]:
# Model
model_name = 'efficientnet-b0' # b0-b7 scale

# Optimization
num_classes = 2
epochs = 400
start_epoch = 0
train_batch = 300
test_batch = 300
lr = 0.04
schedule = [75, 150, 225]
momentum = 0.9
gamma = 0.1 # LR is multiplied by gamma on schedule

# CheckPoint
checkpoint = './log/style1/128/b0' # dir
if not os.path.isdir(checkpoint):
    os.mkdir(checkpoint)
num_workers = 8

# Seed
manual_seed = 7
random.seed(manual_seed)
torch.cuda.manual_seed_all(manual_seed)

# Image
size = (128, 128)

best_acc = 0

In [6]:
state = {}
state['num_classes'] = num_classes
state['epochs'] = epochs
state['start_epoch'] = start_epoch
state['train_batch'] = train_batch
state['test_batch'] = test_batch
state['lr'] = lr
state['schedule'] = schedule
state['momentum'] = momentum
state['gamma'] = gamma

# Dataset

In [7]:
train_dir = os.path.join(data_dir, 'train')
val_dir = os.path.join(data_dir, 'validation')    
train_aug = transforms.Compose([
#     transforms.RandomAffine(degrees=2, translate=(0.02, 0.02), scale=(0.98, 1.02), shear=2, fillcolor=(124,117,104)),
    transforms.Resize(size),
#     transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
#     transforms.RandomErasing(p=0.3, scale=(0.02, 0.10), ratio=(0.3, 3.3), value=0, inplace=True),
])
val_aug = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
])

# pin_memory : cuda pin memeory use
train_loader = DataLoader(datasets.ImageFolder(train_dir, transform=train_aug),
                          batch_size=train_batch, shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(datasets.ImageFolder(val_dir, val_aug),
                       batch_size=test_batch, shuffle=True, num_workers=num_workers, pin_memory=True)

# Model

In [8]:
model = EfficientNet.from_name(model_name, num_classes=num_classes,
                              override_params={'dropout_rate':0.2, 'drop_connect_rate':0.2})

# Pre-trained
if pretrained:
    print("=> using pre-trained model '{}'".format(pretrained))
    model.load_state_dict(torch.load(pretrained)['state_dict'])

In [9]:
model.named_parameters

<bound method Module.named_parameters of EfficientNet(
  (_conv_stem): Conv2dStaticSamePadding(
    3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
    (static_padding): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)
  )
  (_gn0): GroupNorm(8, 32, eps=1e-05, affine=True)
  (_blocks): ModuleList(
    (0): MBConvBlock(
      (_depthwise_conv): Conv2dStaticSamePadding(
        32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
        (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
      )
      (_gn1): GroupNorm(8, 32, eps=1e-05, affine=True)
      (_se_reduce): Conv2dStaticSamePadding(
        32, 8, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_se_expand): Conv2dStaticSamePadding(
        8, 32, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_project_conv): Conv2dStaticSamePadding(
        32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False
        (static_padding): Identity()


In [10]:
model.to('cuda')
cudnn.benchmark = True
print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

    Total params: 4.01M


# Loss

In [11]:
criterion = nn.CrossEntropyLoss().cuda()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=1e-4)
# optimizer = optim.Adam(model.parameters(), weight_decay=1e-4)
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=10, after_scheduler=scheduler_cosine)

In [12]:
# Resume
if resume:
    print('==> Resuming from checkpoint..')
    checkpoint = os.path.dirname(resume)
#     checkpoint = torch.load(resume)
    resume = torch.load(resume)
    best_acc = resume['best_acc']
    start_epoch = resume['epoch']
    model.load_state_dict(resume['state_dict'])
    optimizer.load_state_dict(resume['optimizer'])
    logger = Logger(os.path.join(checkpoint, 'log.txt'), resume=True)
else:
    logger = Logger(os.path.join(checkpoint, 'log.txt'))
    logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])

# Train

In [13]:
def train(train_loader, model, criterion, optimizer, epoch, use_cuda):
    model.train()
    torch.set_grad_enabled(True)
    
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()
    
    bar = Bar('Processing', max=len(train_loader))
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        batch_size = inputs.size(0)
        if batch_size < train_batch:
            continue
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1 = accuracy(outputs.data, targets.data)
        losses.update(loss.data.tolist(), inputs.size(0))
        top1.update(prec1[0], inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} '.format(
                    batch=batch_idx + 1,
                    size=len(train_loader),
                    data=data_time.val,
                    bt=batch_time.val,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss=losses.avg,
                    top1=top1.avg,
                    )
        bar.next()
        if batch_idx % 10 == 0:
            print('{batch}/{size} Data:{data:.3f} | Batch:{bt:.3f} | Total:{total:} | ETA:{eta:} | Loss:{loss:} | top1:{tp1:}'.format(
                 batch=batch_idx+1, size=len(train_loader), data=data_time.val, bt=batch_time.val, total=bar.elapsed_td, eta=bar.eta_td, loss=losses.avg, tp1=top1.avg))
    bar.finish()
    return (losses.avg, top1.avg)

In [14]:
def test(val_loader, model, criterion, epoch, use_cuda):
    global best_acc

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()
#     torch.set_grad_enabled(False)

    end = time.time()
    bar = Bar('Processing', max=len(val_loader))
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = torch.autograd.Variable(inputs, volatile=True), torch.autograd.Variable(targets)

            # compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            prec1 = accuracy(outputs.data, targets.data)
            losses.update(loss.data.tolist(), inputs.size(0))
            top1.update(prec1[0], inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:} | top1: {top1:}'.format(
                        batch=batch_idx + 1,
                        size=len(val_loader),
                        data=data_time.avg,
                        bt=batch_time.avg,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        loss=losses.avg,
                        top1=top1.avg,)
            bar.next()
    print('{batch}/{size} Data:{data:.3f} | Batch:{bt:.3f} | Total:{total:} | ETA:{eta:} | Loss:{loss:} | top1:{tp1:}'.format(
         batch=batch_idx+1, size=len(val_loader), data=data_time.val, bt=batch_time.val, total=bar.elapsed_td, eta=bar.eta_td, loss=losses.avg, tp1=top1.avg))
    bar.finish()
    return (losses.avg, top1.avg)

In [15]:
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))

def adjust_learning_rate(optimizer, epoch):
    global state
    lr_set = [1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
    lr_list = schedule.copy()
    lr_list.append(epoch)
    lr_list.sort()
    idx = lr_list.index(epoch)
    state['lr'] *= lr_set[idx]
    for param_group in optimizer.param_groups:
        param_group['lr'] = state['lr']

In [None]:
for epoch in range(start_epoch, epochs):
    state['lr'] = optimizer.state_dict()['param_groups'][0]['lr']
    adjust_learning_rate(optimizer, epoch)
    print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, epochs, state['lr']))
    train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, use_cuda)
    test_loss, test_acc = test(val_loader, model, criterion, epoch, use_cuda)
    
    logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc])
    scheduler_warmup.step()

    is_best = test_acc > best_acc
    best_acc = max(test_acc, best_acc)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict' : model.state_dict(),
        'acc': test_acc,
        'best_acc': best_acc,
        'optimizer': optimizer.state_dict(),
    }, is_best, checkpoint=checkpoint)


Epoch: [1 | 400] LR: 0.040000
1/230 Data:1.815 | Batch:6.516 | Total:0:00:06 | ETA:0:24:53 | Loss:0.7019357085227966 | top1:50.66666793823242
11/230 Data:0.055 | Batch:0.695 | Total:0:00:13 | ETA:0:04:41 | Loss:0.7241052009842612 | top1:49.66666793823242
21/230 Data:0.047 | Batch:0.691 | Total:0:00:20 | ETA:0:02:27 | Loss:0.7396161613010225 | top1:50.126983642578125
31/230 Data:0.059 | Batch:0.709 | Total:0:00:27 | ETA:0:02:19 | Loss:0.7264233558408676 | top1:50.66666793823242
41/230 Data:0.054 | Batch:0.695 | Total:0:00:34 | ETA:0:02:12 | Loss:0.7196704745292664 | top1:51.178863525390625
51/230 Data:0.059 | Batch:0.704 | Total:0:00:41 | ETA:0:02:05 | Loss:0.7167434528762219 | top1:51.32025909423828
61/230 Data:0.045 | Batch:0.690 | Total:0:00:48 | ETA:0:01:58 | Loss:0.7201931330024219 | top1:51.14754104614258
71/230 Data:0.059 | Batch:0.712 | Total:0:00:55 | ETA:0:01:51 | Loss:0.7208243608474731 | top1:51.09859085083008
81/230 Data:0.047 | Batch:0.691 | Total:0:01:02 | ETA:0:01:44 | 



26/26 Data:0.002 | Batch:0.200 | Total:0:00:06 | ETA:0:00:00 | Loss:0.48304608234992397 | top1:77.20512390136719

Epoch: [2 | 400] LR: 0.068000
1/230 Data:1.796 | Batch:2.437 | Total:0:00:02 | ETA:0:09:32 | Loss:0.49888381361961365 | top1:73.0
11/230 Data:0.059 | Batch:0.702 | Total:0:00:09 | ETA:0:03:13 | Loss:0.4729139344258742 | top1:77.93939208984375
21/230 Data:0.056 | Batch:0.703 | Total:0:00:16 | ETA:0:02:26 | Loss:0.4542351123832521 | top1:79.4285659790039
31/230 Data:0.059 | Batch:0.690 | Total:0:00:23 | ETA:0:02:20 | Loss:0.44507852100556894 | top1:79.93548583984375
41/230 Data:0.033 | Batch:0.607 | Total:0:00:30 | ETA:0:02:10 | Loss:0.4374533135716508 | top1:80.60975646972656
51/230 Data:0.051 | Batch:0.612 | Total:0:00:36 | ETA:0:01:58 | Loss:0.43113452897352333 | top1:80.98039245605469
61/230 Data:0.059 | Batch:0.532 | Total:0:00:43 | ETA:0:01:50 | Loss:0.42387275226780624 | top1:81.4699478149414
71/230 Data:0.045 | Batch:0.647 | Total:0:00:49 | ETA:0:01:38 | Loss:0.419560

26/26 Data:0.141 | Batch:0.340 | Total:0:00:12 | ETA:0:00:00 | Loss:0.027783757612968866 | top1:99.0

Epoch: [5 | 400] LR: 0.152000
1/230 Data:2.008 | Batch:2.649 | Total:0:00:02 | ETA:0:10:20 | Loss:0.00872308760881424 | top1:100.0
11/230 Data:0.059 | Batch:0.703 | Total:0:00:09 | ETA:0:03:17 | Loss:0.016366861591284924 | top1:99.36363983154297
21/230 Data:0.059 | Batch:0.701 | Total:0:00:16 | ETA:0:02:26 | Loss:0.015175748001118856 | top1:99.4285659790039
31/230 Data:0.061 | Batch:0.682 | Total:0:00:23 | ETA:0:02:15 | Loss:0.015929729568832103 | top1:99.40859985351562
41/230 Data:0.048 | Batch:0.712 | Total:0:00:30 | ETA:0:02:09 | Loss:0.019435636866724164 | top1:99.35772705078125
51/230 Data:0.048 | Batch:0.647 | Total:0:00:36 | ETA:0:01:59 | Loss:0.019640898192757925 | top1:99.35294342041016
61/230 Data:0.032 | Batch:0.596 | Total:0:00:43 | ETA:0:01:57 | Loss:0.0199364512361067 | top1:99.32787322998047
71/230 Data:0.067 | Batch:0.710 | Total:0:00:50 | ETA:0:01:41 | Loss:0.021167983

221/230 Data:0.059 | Batch:0.699 | Total:0:02:25 | ETA:0:00:07 | Loss:0.021394318657241364 | top1:99.26395416259766
26/26 Data:0.002 | Batch:0.185 | Total:0:00:06 | ETA:0:00:00 | Loss:0.030985599312071618 | top1:99.025634765625

Epoch: [8 | 400] LR: 0.236000
1/230 Data:1.835 | Batch:2.487 | Total:0:00:02 | ETA:0:09:43 | Loss:0.011803006753325462 | top1:99.66667938232422
11/230 Data:0.059 | Batch:0.706 | Total:0:00:09 | ETA:0:03:15 | Loss:0.032245799475772816 | top1:99.1212158203125
21/230 Data:0.059 | Batch:0.705 | Total:0:00:16 | ETA:0:02:27 | Loss:0.02603572021637644 | top1:99.20634460449219
31/230 Data:0.047 | Batch:0.696 | Total:0:00:23 | ETA:0:02:20 | Loss:0.022578796724818887 | top1:99.30107116699219
41/230 Data:0.059 | Batch:0.704 | Total:0:00:30 | ETA:0:02:12 | Loss:0.02183411644612689 | top1:99.32520294189453
51/230 Data:0.059 | Batch:0.688 | Total:0:00:37 | ETA:0:02:06 | Loss:0.02161105865996112 | top1:99.35294342041016
61/230 Data:0.059 | Batch:0.701 | Total:0:00:44 | ETA:0: