In [1]:
import json
import os
import time
import random
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import shutil

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision import models
from torchsummary import summary
from sklearn.model_selection import train_test_split

from model_pytorch import EfficientNet
from utils import Bar,Logger, AverageMeter, accuracy, mkdir_p, savefig
from warmup_scheduler import GradualWarmupScheduler

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import collections

In [2]:
# GPU Device
gpu_id = 3
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
use_cuda = torch.cuda.is_available()
print("GPU device %d:" %(gpu_id), use_cuda)

GPU device 3: True


# Arguments

In [3]:
data_dir = '/media/data2/dataset/GAN_ImageData/StyleGAN_256/'

In [4]:
pretrained = ''
resume = ''

In [5]:
# Model
model_name = 'efficientnet-b1' # b0-b7 scale

# Optimization
num_classes = 2
epochs = 300
start_epoch = 0
train_batch = 190
test_batch = 190
lr = 0.1
schedule = [20, 75, 125, 175]
momentum = 0.9
gamma = 0.1 # LR is multiplied by gamma on schedule

# CheckPoint
checkpoint = './log/style1/128/b1_2' # dir
if not os.path.isdir(checkpoint):
    os.mkdir(checkpoint)
num_workers = 4

# Seed
manual_seed = 7
random.seed(manual_seed)
torch.cuda.manual_seed_all(manual_seed)

# Image
size = (128, 128)

best_acc = 0

In [6]:
state = {}
state['num_classes'] = num_classes
state['epochs'] = epochs
state['start_epoch'] = start_epoch
state['train_batch'] = train_batch
state['test_batch'] = test_batch
state['lr'] = lr
state['schedule'] = schedule
state['momentum'] = momentum
state['gamma'] = gamma

# Dataset

In [7]:
train_dir = os.path.join(data_dir, 'train')
val_dir = os.path.join(data_dir, 'validation')    
train_aug = transforms.Compose([
    transforms.RandomAffine(degrees=2, translate=(0.02, 0.02), scale=(0.98, 1.02), shear=2, fillcolor=(124,117,104)),
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.10), ratio=(0.3, 3.3), value=0, inplace=True),
])
val_aug = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
])

# pin_memory : cuda pin memeory use
train_loader = DataLoader(datasets.ImageFolder(train_dir, transform=train_aug),
                          batch_size=train_batch, shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(datasets.ImageFolder(val_dir, val_aug),
                       batch_size=test_batch, shuffle=True, num_workers=num_workers, pin_memory=True)

# Model

In [8]:
model = EfficientNet.from_name(model_name, num_classes=num_classes,
                              override_params={'dropout_rate':0.3, 'drop_connect_rate':0.3})

# Pre-trained
if pretrained:
    print("=> using pre-trained model '{}'".format(pretrained))
    model.load_state_dict(torch.load(pretrained)['state_dict'])

In [9]:
model.named_parameters

<bound method Module.named_parameters of EfficientNet(
  (_conv_stem): Conv2dStaticSamePadding(
    3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
    (static_padding): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)
  )
  (_gn0): GroupNorm(8, 32, eps=1e-05, affine=True)
  (_blocks): ModuleList(
    (0): MBConvBlock(
      (_depthwise_conv): Conv2dStaticSamePadding(
        32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
        (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
      )
      (_gn1): GroupNorm(8, 32, eps=1e-05, affine=True)
      (_se_reduce): Conv2dStaticSamePadding(
        32, 8, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_se_expand): Conv2dStaticSamePadding(
        8, 32, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_project_conv): Conv2dStaticSamePadding(
        32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False
        (static_padding): Identity()


In [10]:
model.to('cuda')
cudnn.benchmark = True
print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

    Total params: 6.52M


In [11]:
# summary(model, input_size=(3,64,64), device='cuda')

# Loss

In [12]:
criterion = nn.CrossEntropyLoss().cuda()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=1e-5, nesterov=True)
# optimizer = optim.Adam(model.parameters(), weight_decay=1e-4)
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=10, after_scheduler=scheduler_cosine)

In [13]:
# Resume
if resume:
    print('==> Resuming from checkpoint..')
    checkpoint = os.path.dirname(resume)
#     checkpoint = torch.load(resume)
    resume = torch.load(resume)
    best_acc = resume['best_acc']
    start_epoch = resume['epoch']
    model.load_state_dict(resume['state_dict'])
    optimizer.load_state_dict(resume['optimizer'])
    logger = Logger(os.path.join(checkpoint, 'log.txt'), resume=True)
else:
    logger = Logger(os.path.join(checkpoint, 'log.txt'))
    logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])

# Train

In [14]:
def train(train_loader, model, criterion, optimizer, epoch, use_cuda):
    model.train()
    torch.set_grad_enabled(True)
    
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()
    
    bar = Bar('Processing', max=len(train_loader))
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        batch_size = inputs.size(0)
        if batch_size < train_batch:
            continue
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1 = accuracy(outputs.data, targets.data)
        losses.update(loss.data.tolist(), inputs.size(0))
        top1.update(prec1[0], inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} '.format(
                    batch=batch_idx + 1,
                    size=len(train_loader),
                    data=data_time.val,
                    bt=batch_time.val,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss=losses.avg,
                    top1=top1.avg,
                    )
        bar.next()
        if batch_idx % 10 == 0:
            print('{batch}/{size} Data:{data:.3f} | Batch:{bt:.3f} | Total:{total:} | ETA:{eta:} | Loss:{loss:} | top1:{tp1:}'.format(
                 batch=batch_idx+1, size=len(train_loader), data=data_time.val, bt=batch_time.val, total=bar.elapsed_td, eta=bar.eta_td, loss=losses.avg, tp1=top1.avg))
    bar.finish()
    return (losses.avg, top1.avg)

In [15]:
def test(val_loader, model, criterion, epoch, use_cuda):
    global best_acc

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()
    torch.set_grad_enabled(False)

    end = time.time()
    bar = Bar('Processing', max=len(val_loader))
    for batch_idx, (inputs, targets) in enumerate(val_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = torch.autograd.Variable(inputs, volatile=True), torch.autograd.Variable(targets)

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1 = accuracy(outputs.data, targets.data)
        losses.update(loss.data.tolist(), inputs.size(0))
        top1.update(prec1[0], inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:} | top1: {top1:}'.format(
                    batch=batch_idx + 1,
                    size=len(val_loader),
                    data=data_time.avg,
                    bt=batch_time.avg,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss=losses.avg,
                    top1=top1.avg,)
        bar.next()
    print('{batch}/{size} Data:{data:.3f} | Batch:{bt:.3f} | Total:{total:} | ETA:{eta:} | Loss:{loss:} | top1:{tp1:}'.format(
         batch=batch_idx+1, size=len(val_loader), data=data_time.val, bt=batch_time.val, total=bar.elapsed_td, eta=bar.eta_td, loss=losses.avg, tp1=top1.avg))
    bar.finish()
    return (losses.avg, top1.avg)

In [16]:
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))

def adjust_learning_rate(optimizer, epoch):
    global state
    lr_set = [1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
    lr_list = schedule.copy()
    lr_list.append(epoch)
    lr_list.sort()
    idx = lr_list.index(epoch)
    state['lr'] *= lr_set[idx]
    for param_group in optimizer.param_groups:
        param_group['lr'] = state['lr']

In [None]:
for epoch in range(start_epoch, epochs):
    state['lr'] = optimizer.state_dict()['param_groups'][0]['lr']
    adjust_learning_rate(optimizer, epoch)
    print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, epochs, state['lr']))
    train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, use_cuda)
    test_loss, test_acc = test(val_loader, model, criterion, epoch, use_cuda)
    
    logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc])
    scheduler_warmup.step()

    is_best = test_acc > best_acc
    best_acc = max(test_acc, best_acc)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict' : model.state_dict(),
        'acc': test_acc,
        'best_acc': best_acc,
        'optimizer': optimizer.state_dict(),
    }, is_best, checkpoint=checkpoint)


Epoch: [1 | 300] LR: 0.100000
1/363 Data:1.199 | Batch:3.562 | Total:0:00:03 | ETA:0:21:30 | Loss:0.7181949615478516 | top1:47.894737243652344
11/363 Data:0.000 | Batch:0.323 | Total:0:00:06 | ETA:0:03:48 | Loss:5.107670469717546 | top1:50.19138717651367
21/363 Data:0.000 | Batch:0.323 | Total:0:00:09 | ETA:0:01:49 | Loss:3.4429485968181064 | top1:50.20050048828125
31/363 Data:0.001 | Batch:0.326 | Total:0:00:13 | ETA:0:01:47 | Loss:2.6384854201347596 | top1:50.135826110839844
41/363 Data:0.000 | Batch:0.316 | Total:0:00:16 | ETA:0:01:45 | Loss:2.178675684987045 | top1:50.1668815612793
51/363 Data:0.000 | Batch:0.325 | Total:0:00:19 | ETA:0:01:40 | Loss:1.893785727959053 | top1:50.19607925415039
61/363 Data:0.000 | Batch:0.324 | Total:0:00:22 | ETA:0:01:37 | Loss:1.6994556784629822 | top1:50.120792388916016
71/363 Data:0.000 | Batch:0.315 | Total:0:00:26 | ETA:0:01:35 | Loss:1.5593682611492319 | top1:50.35581970214844
81/363 Data:0.000 | Batch:0.319 | Total:0:00:29 | ETA:0:01:31 | Los



42/42 Data:0.000 | Batch:0.391 | Total:0:00:08 | ETA:0:00:00 | Loss:0.6954568112507845 | top1:50.0

Epoch: [2 | 300] LR: 0.170000
1/363 Data:1.104 | Batch:1.484 | Total:0:00:01 | ETA:0:08:58 | Loss:0.7008553147315979 | top1:50.52631759643555
11/363 Data:0.001 | Batch:0.320 | Total:0:00:05 | ETA:0:02:46 | Loss:0.7246841138059442 | top1:49.330142974853516
21/363 Data:0.000 | Batch:0.322 | Total:0:00:08 | ETA:0:01:58 | Loss:0.7133762240409851 | top1:49.77443313598633
31/363 Data:0.000 | Batch:0.322 | Total:0:00:11 | ETA:0:01:47 | Loss:0.7131567924253402 | top1:50.27164840698242
41/363 Data:0.000 | Batch:0.322 | Total:0:00:14 | ETA:0:01:44 | Loss:0.7109273017906561 | top1:50.397945404052734
51/363 Data:0.000 | Batch:0.327 | Total:0:00:17 | ETA:0:01:41 | Loss:0.7107974618088966 | top1:50.08256149291992
61/363 Data:0.000 | Batch:0.315 | Total:0:00:21 | ETA:0:01:38 | Loss:0.709744499355066 | top1:50.19844436645508
71/363 Data:0.000 | Batch:0.313 | Total:0:00:24 | ETA:0:01:34 | Loss:0.70876859

331/363 Data:0.000 | Batch:0.351 | Total:0:01:47 | ETA:0:00:12 | Loss:0.7002588539325219 | top1:50.17013931274414
341/363 Data:0.000 | Batch:0.350 | Total:0:01:51 | ETA:0:00:08 | Loss:0.700230688468452 | top1:50.200645446777344
351/363 Data:0.000 | Batch:0.315 | Total:0:01:54 | ETA:0:00:05 | Loss:0.7001496021903818 | top1:50.190433502197266
361/363 Data:0.000 | Batch:0.322 | Total:0:01:57 | ETA:0:00:01 | Loss:0.7000208416143613 | top1:50.2041130065918
42/42 Data:0.000 | Batch:0.035 | Total:0:00:08 | ETA:0:00:00 | Loss:0.6936797223029992 | top1:50.0

Epoch: [4 | 300] LR: 0.310000
1/363 Data:1.082 | Batch:1.430 | Total:0:00:01 | ETA:0:08:38 | Loss:0.699485719203949 | top1:47.36842346191406
11/363 Data:0.000 | Batch:0.326 | Total:0:00:04 | ETA:0:02:32 | Loss:0.6982612393119119 | top1:50.14353942871094
21/363 Data:0.001 | Batch:0.469 | Total:0:00:08 | ETA:0:02:02 | Loss:0.6975717516172499 | top1:50.85212707519531
31/363 Data:0.000 | Batch:0.323 | Total:0:00:11 | ETA:0:02:01 | Loss:0.697993

291/363 Data:0.000 | Batch:0.314 | Total:0:01:39 | ETA:0:00:24 | Loss:0.6957915052515534 | top1:50.32917404174805
301/363 Data:0.000 | Batch:0.315 | Total:0:01:42 | ETA:0:00:21 | Loss:0.6957692046498143 | top1:50.346214294433594
311/363 Data:0.000 | Batch:0.323 | Total:0:01:45 | ETA:0:00:18 | Loss:0.6957360162995636 | top1:50.30461883544922
321/363 Data:0.000 | Batch:0.345 | Total:0:01:49 | ETA:0:00:14 | Loss:0.6956656549207146 | top1:50.34103775024414
331/363 Data:0.001 | Batch:0.336 | Total:0:01:52 | ETA:0:00:11 | Loss:0.6957335801643187 | top1:50.286216735839844
341/363 Data:0.000 | Batch:0.314 | Total:0:01:55 | ETA:0:00:08 | Loss:0.6957700591632697 | top1:50.27318572998047
351/363 Data:0.001 | Batch:0.374 | Total:0:01:59 | ETA:0:00:05 | Loss:0.6957709052284219 | top1:50.301395416259766
361/363 Data:0.000 | Batch:0.317 | Total:0:02:02 | ETA:0:00:01 | Loss:0.6957540976043554 | top1:50.33095169067383
42/42 Data:0.000 | Batch:0.035 | Total:0:00:08 | ETA:0:00:00 | Loss:0.694989178730891

251/363 Data:0.000 | Batch:0.323 | Total:0:01:22 | ETA:0:00:37 | Loss:0.6951448941135786 | top1:50.61228942871094
261/363 Data:0.000 | Batch:0.323 | Total:0:01:25 | ETA:0:00:33 | Loss:0.6951685123059942 | top1:50.641258239746094
271/363 Data:0.000 | Batch:0.317 | Total:0:01:28 | ETA:0:00:30 | Loss:0.6952348424059879 | top1:50.57875061035156
281/363 Data:0.000 | Batch:0.323 | Total:0:01:31 | ETA:0:00:27 | Loss:0.695175001629731 | top1:50.65742492675781
291/363 Data:0.000 | Batch:0.322 | Total:0:01:35 | ETA:0:00:24 | Loss:0.6951316513146731 | top1:50.65473175048828
301/363 Data:0.000 | Batch:0.322 | Total:0:01:38 | ETA:0:00:21 | Loss:0.6951636042309758 | top1:50.62773132324219
311/363 Data:0.000 | Batch:0.314 | Total:0:01:41 | ETA:0:00:17 | Loss:0.6951389446902505 | top1:50.629547119140625
321/363 Data:0.000 | Batch:0.314 | Total:0:01:44 | ETA:0:00:14 | Loss:0.6951227418358824 | top1:50.61485290527344
331/363 Data:0.000 | Batch:0.314 | Total:0:01:47 | ETA:0:00:11 | Loss:0.695165754985233