In [27]:
import json
import os
import random
import shutil
import time

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader
from torchsummary import summary
from torchvision import datasets
from torchvision import models
from torchvision import transforms

from model_pytorch import EfficientNet
from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig

# Arguments

In [None]:
# Optional weight files. Both are tested for truthiness further down, so an
# empty string means "skip loading pre-trained weights" / "do not resume".
pretrained = resume = ''

In [36]:
# Model
model_name = 'efficient-b7' # b0-b7 scale
data_dir = '/media/data2/dataset/fake_detection/'

# Optimization
num_classes = 2
epochs = 400
start_epoch = 0
train_batch = 256
test_batch = 200
lr = 0.1
schedule = [150, 225]  # epochs at which the learning rate is decayed
momentum = 0.9
gamma = 0.1 # LR is multiplied by gamma on schedule

# CheckPoint
checkpoint = './log/' # dir
# makedirs(exist_ok=True) also handles nested paths and avoids the
# check-then-create race of isdir() + mkdir().
os.makedirs(checkpoint, exist_ok=True)
num_workers = 4

# GPU Device
gpu_id = 0
# Must be set before the first CUDA call so only this device is visible.
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
use_cuda = torch.cuda.is_available()
print("GPU device %d:" %(gpu_id), use_cuda)

# Seed
manual_seed = 7
random.seed(manual_seed)
# Seed the CPU generator as well: the model is built on the CPU and the
# DataLoader shuffles with the CPU RNG, so seeding only CUDA is not enough.
torch.manual_seed(manual_seed)
torch.cuda.manual_seed_all(manual_seed)

best_acc = 0

GPU device 0: True


# Dataset

In [None]:
traindir = os.path.join(data_dir, 'train')
valdir = os.path.join(data_dir, 'val')

# Flip while the sample is still a PIL image, then convert to a tensor.
# NOTE(review): older torchvision releases only accept PIL images in
# RandomHorizontalFlip, so this order is the portable one.
train_aug = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
val_aug = transforms.Compose([
    transforms.ToTensor(),
])

# pin_memory: use page-locked host memory for faster host-to-GPU copies;
# only useful when CUDA is actually in use.
# The loaders must read from the directories defined above (`traindir` /
# `valdir`); the former `train_dir` / `val_dir` names were never defined.
train_loader = DataLoader(datasets.ImageFolder(traindir, transform=train_aug),
                          batch_size=train_batch, shuffle=True, num_workers=num_workers, pin_memory=use_cuda)
val_loader = DataLoader(datasets.ImageFolder(valdir, transform=val_aug),
                        batch_size=test_batch, shuffle=False, num_workers=num_workers, pin_memory=use_cuda)

# Model

In [None]:
model = EfficientNet.from_name(model_name, num_classes=num_classes)

# Pre-trained
if pretrained:
    print("=> using pre-trained model '{}'".format(pretrained))
    # The model is still on the CPU here, so load the weights onto the CPU
    # regardless of the device they were saved from.
    pretrained_state = torch.load(pretrained, map_location='cpu')
    # Accept both a bare state_dict and a checkpoint dict of the form written
    # by save_checkpoint(), which nests the weights under 'state_dict'.
    if isinstance(pretrained_state, dict) and 'state_dict' in pretrained_state:
        pretrained_state = pretrained_state['state_dict']
    model.load_state_dict(pretrained_state)

In [22]:
# Follow the `use_cuda` flag from the configuration cell instead of
# hard-coding the device, so the cell also runs on a CPU-only machine.
model.to('cuda' if use_cuda else 'cpu')
# Let cuDNN auto-tune convolution algorithms (has no effect without CUDA).
cudnn.benchmark = True
print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

    Total params: 63.79M


In [23]:
# Print the per-layer table for a 3x64x64 input; the device follows the
# `use_cuda` flag instead of being hard-coded to 'cuda'.
summary(model, input_size=(3,64,64), device='cuda' if use_cuda else 'cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         ZeroPad2d-1            [-1, 3, 65, 65]               0
Conv2dStaticSamePadding-2           [-1, 64, 32, 32]           1,728
         GroupNorm-3           [-1, 64, 32, 32]             128
MemoryEfficientSwish-4           [-1, 64, 32, 32]               0
         ZeroPad2d-5           [-1, 64, 34, 34]               0
Conv2dStaticSamePadding-6           [-1, 64, 32, 32]             576
         GroupNorm-7           [-1, 64, 32, 32]             128
MemoryEfficientSwish-8           [-1, 64, 32, 32]               0
          Identity-9             [-1, 64, 1, 1]               0
Conv2dStaticSamePadding-10             [-1, 16, 1, 1]           1,040
MemoryEfficientSwish-11             [-1, 16, 1, 1]               0
         Identity-12             [-1, 16, 1, 1]               0
Conv2dStaticSamePadding-13             [-1, 64, 1, 1]           1,088
         I

        Identity-121          [-1, 288, 16, 16]               0
Conv2dStaticSamePadding-122           [-1, 48, 16, 16]          13,824
       GroupNorm-123           [-1, 48, 16, 16]              96
     MBConvBlock-124           [-1, 48, 16, 16]               0
        Identity-125           [-1, 48, 16, 16]               0
Conv2dStaticSamePadding-126          [-1, 288, 16, 16]          13,824
       GroupNorm-127          [-1, 288, 16, 16]             576
MemoryEfficientSwish-128          [-1, 288, 16, 16]               0
       ZeroPad2d-129          [-1, 288, 18, 18]               0
Conv2dStaticSamePadding-130          [-1, 288, 16, 16]           2,592
       GroupNorm-131          [-1, 288, 16, 16]             576
MemoryEfficientSwish-132          [-1, 288, 16, 16]               0
        Identity-133            [-1, 288, 1, 1]               0
Conv2dStaticSamePadding-134             [-1, 12, 1, 1]           3,468
MemoryEfficientSwish-135             [-1, 12, 1, 1]               0


     MBConvBlock-243             [-1, 80, 8, 8]               0
        Identity-244             [-1, 80, 8, 8]               0
Conv2dStaticSamePadding-245            [-1, 480, 8, 8]          38,400
       GroupNorm-246            [-1, 480, 8, 8]             960
MemoryEfficientSwish-247            [-1, 480, 8, 8]               0
       ZeroPad2d-248          [-1, 480, 12, 12]               0
Conv2dStaticSamePadding-249            [-1, 480, 8, 8]          12,000
       GroupNorm-250            [-1, 480, 8, 8]             960
MemoryEfficientSwish-251            [-1, 480, 8, 8]               0
        Identity-252            [-1, 480, 1, 1]               0
Conv2dStaticSamePadding-253             [-1, 20, 1, 1]           9,620
MemoryEfficientSwish-254             [-1, 20, 1, 1]               0
        Identity-255             [-1, 20, 1, 1]               0
Conv2dStaticSamePadding-256            [-1, 480, 1, 1]          10,080
        Identity-257            [-1, 480, 8, 8]               0


       GroupNorm-365            [-1, 960, 4, 4]           1,920
MemoryEfficientSwish-366            [-1, 960, 4, 4]               0
       ZeroPad2d-367            [-1, 960, 6, 6]               0
Conv2dStaticSamePadding-368            [-1, 960, 4, 4]           8,640
       GroupNorm-369            [-1, 960, 4, 4]           1,920
MemoryEfficientSwish-370            [-1, 960, 4, 4]               0
        Identity-371            [-1, 960, 1, 1]               0
Conv2dStaticSamePadding-372             [-1, 40, 1, 1]          38,440
MemoryEfficientSwish-373             [-1, 40, 1, 1]               0
        Identity-374             [-1, 40, 1, 1]               0
Conv2dStaticSamePadding-375            [-1, 960, 1, 1]          39,360
        Identity-376            [-1, 960, 4, 4]               0
Conv2dStaticSamePadding-377            [-1, 160, 4, 4]         153,600
       GroupNorm-378            [-1, 160, 4, 4]             320
     MBConvBlock-379            [-1, 160, 4, 4]               0


Conv2dStaticSamePadding-487           [-1, 1344, 4, 4]          33,600
       GroupNorm-488           [-1, 1344, 4, 4]           2,688
MemoryEfficientSwish-489           [-1, 1344, 4, 4]               0
        Identity-490           [-1, 1344, 1, 1]               0
Conv2dStaticSamePadding-491             [-1, 56, 1, 1]          75,320
MemoryEfficientSwish-492             [-1, 56, 1, 1]               0
        Identity-493             [-1, 56, 1, 1]               0
Conv2dStaticSamePadding-494           [-1, 1344, 1, 1]          76,608
        Identity-495           [-1, 1344, 4, 4]               0
Conv2dStaticSamePadding-496            [-1, 224, 4, 4]         301,056
       GroupNorm-497            [-1, 224, 4, 4]             448
     MBConvBlock-498            [-1, 224, 4, 4]               0
        Identity-499            [-1, 224, 4, 4]               0
Conv2dStaticSamePadding-500           [-1, 1344, 4, 4]         301,056
       GroupNorm-501           [-1, 1344, 4, 4]           2,6

        Identity-609           [-1, 1344, 1, 1]               0
Conv2dStaticSamePadding-610             [-1, 56, 1, 1]          75,320
MemoryEfficientSwish-611             [-1, 56, 1, 1]               0
        Identity-612             [-1, 56, 1, 1]               0
Conv2dStaticSamePadding-613           [-1, 1344, 1, 1]          76,608
        Identity-614           [-1, 1344, 4, 4]               0
Conv2dStaticSamePadding-615            [-1, 224, 4, 4]         301,056
       GroupNorm-616            [-1, 224, 4, 4]             448
     MBConvBlock-617            [-1, 224, 4, 4]               0
        Identity-618            [-1, 224, 4, 4]               0
Conv2dStaticSamePadding-619           [-1, 1344, 4, 4]         301,056
       GroupNorm-620           [-1, 1344, 4, 4]           2,688
MemoryEfficientSwish-621           [-1, 1344, 4, 4]               0
       ZeroPad2d-622           [-1, 1344, 8, 8]               0
Conv2dStaticSamePadding-623           [-1, 1344, 4, 4]          33,6

        Identity-731             [-1, 96, 1, 1]               0
Conv2dStaticSamePadding-732           [-1, 2304, 1, 1]         223,488
        Identity-733           [-1, 2304, 2, 2]               0
Conv2dStaticSamePadding-734            [-1, 384, 2, 2]         884,736
       GroupNorm-735            [-1, 384, 2, 2]             768
     MBConvBlock-736            [-1, 384, 2, 2]               0
        Identity-737            [-1, 384, 2, 2]               0
Conv2dStaticSamePadding-738           [-1, 2304, 2, 2]         884,736
       GroupNorm-739           [-1, 2304, 2, 2]           4,608
MemoryEfficientSwish-740           [-1, 2304, 2, 2]               0
       ZeroPad2d-741           [-1, 2304, 6, 6]               0
Conv2dStaticSamePadding-742           [-1, 2304, 2, 2]          57,600
       GroupNorm-743           [-1, 2304, 2, 2]           4,608
MemoryEfficientSwish-744           [-1, 2304, 2, 2]               0
        Identity-745           [-1, 2304, 1, 1]               0
Conv

Conv2dStaticSamePadding-853            [-1, 384, 2, 2]         884,736
       GroupNorm-854            [-1, 384, 2, 2]             768
     MBConvBlock-855            [-1, 384, 2, 2]               0
        Identity-856            [-1, 384, 2, 2]               0
Conv2dStaticSamePadding-857           [-1, 2304, 2, 2]         884,736
       GroupNorm-858           [-1, 2304, 2, 2]           4,608
MemoryEfficientSwish-859           [-1, 2304, 2, 2]               0
       ZeroPad2d-860           [-1, 2304, 4, 4]               0
Conv2dStaticSamePadding-861           [-1, 2304, 2, 2]          20,736
       GroupNorm-862           [-1, 2304, 2, 2]           4,608
MemoryEfficientSwish-863           [-1, 2304, 2, 2]               0
        Identity-864           [-1, 2304, 1, 1]               0
Conv2dStaticSamePadding-865             [-1, 96, 1, 1]         221,280
MemoryEfficientSwish-866             [-1, 96, 1, 1]               0
        Identity-867             [-1, 96, 1, 1]               0


# Loss

In [25]:
# ImageFolder yields integer class indices, and the network is built with
# num_classes outputs, so the matching loss is CrossEntropyLoss. BCELoss
# needs probabilities in [0, 1] and float targets of the same shape as the
# output, which this pipeline does not produce.
# NOTE(review): assumes the model returns raw (N, num_classes) logits — confirm
# against model_pytorch.EfficientNet.
criterion = nn.CrossEntropyLoss()
if use_cuda:
    criterion = criterion.cuda()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

In [33]:
# Resume
if resume:
    print('==> Resuming from checkpoint..')
    # `checkpoint` must remain the *directory* path: it is used for the log
    # file below and later passed to save_checkpoint(). The loaded dict
    # therefore gets its own name instead of overwriting it.
    checkpoint = os.path.dirname(resume)
    # Load onto the CPU; load_state_dict() copies the values to wherever the
    # model and optimizer state actually live.
    resume_state = torch.load(resume, map_location='cpu')
    best_acc = resume_state['best_acc']
    start_epoch = resume_state['epoch']
    model.load_state_dict(resume_state['state_dict'])
    optimizer.load_state_dict(resume_state['optimizer'])
    logger = Logger(os.path.join(checkpoint, 'log.txt'), resume=True)
else:
    logger = Logger(os.path.join(checkpoint, 'log.txt'))
    logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])

# Train

In [38]:
def train(train_loader, model, criterion, optimizer, epoch, use_cuda):
    """Run one training epoch.

    Args:
        train_loader: iterable yielding ``(inputs, targets)`` batches.
        model: network to optimise; switched to training mode.
        criterion: loss callable applied as ``criterion(outputs, targets)``.
        optimizer: optimizer stepped once per processed batch.
        epoch: current epoch index (unused here; kept for the caller's
            signature).
        use_cuda: when true, each batch is moved to the GPU first.

    Returns:
        ``(losses.avg, top1.avg)`` — the running averages of the loss and of
        the first value returned by ``accuracy`` over the processed batches.
    """
    model.train()
    torch.set_grad_enabled(True)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()

    # Size of a full batch. The original compared against `args.train_batch`,
    # but no `args` object exists in this notebook; the loader carries the
    # same number.
    full_batch = getattr(train_loader, 'batch_size', None)

    bar = Bar('Processing', max=len(train_loader))
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        batch_size = inputs.size(0)
        # Skip an incomplete (typically the last) batch, as the original did.
        if full_batch is not None and batch_size < full_batch:
            continue
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        # NOTE(review): assumes utils.accuracy returns a sequence whose first
        # entry is the top-1 score — confirm against utils.
        prec1 = accuracy(outputs.data, targets.data)
        # .item() replaces the removed `loss.data[0]` idiom (0-dim tensors
        # cannot be indexed).
        losses.update(loss.item(), batch_size)
        top1.update(prec1[0], batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress (the former `top5` field had no matching argument and
        # raised KeyError, so it is dropped from the template)
        bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f}'.format(
                    batch=batch_idx + 1,
                    size=len(train_loader),
                    data=data_time.val,
                    bt=batch_time.val,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss=losses.avg,
                    top1=top1.avg,
                    )
        bar.next()
    bar.finish()
    return (losses.avg, top1.avg)

In [39]:
def test(val_loader, model, criterion, epoch, use_cuda):
    """Evaluate the model on the validation loader.

    Args:
        val_loader: iterable yielding ``(inputs, targets)`` batches.
        model: network to evaluate; switched to eval mode.
        criterion: loss callable applied as ``criterion(outputs, targets)``.
        epoch: current epoch index (unused here; kept for the caller's
            signature).
        use_cuda: when true, each batch is moved to the GPU first.

    Returns:
        ``(losses.avg, top1.avg)`` — the running averages of the loss and of
        the first value returned by ``accuracy`` over all batches.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    bar = Bar('Processing', max=len(val_loader))
    # no_grad() replaces the obsolete `volatile=True` flag and, unlike a bare
    # set_grad_enabled(False), restores the previous autograd mode on exit.
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            # compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            # `(1)` is the int 1, not a tuple; pass a real one-element tuple.
            # NOTE(review): assumes utils.accuracy expects an iterable `topk`
            # and returns a sequence — confirm against utils.
            prec1 = accuracy(outputs.data, targets.data, topk=(1,))
            # .item() replaces the removed `loss.data[0]` idiom.
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1[0], inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress (the former `top5` field had no matching argument
            # and raised KeyError, so it is dropped from the template)
            bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f}'.format(
                        batch=batch_idx + 1,
                        size=len(val_loader),
                        data=data_time.avg,
                        bt=batch_time.avg,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        loss=losses.avg,
                        top1=top1.avg,)
            bar.next()
    bar.finish()
    return (losses.avg, top1.avg)

In [40]:
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``<checkpoint>/<filename>``.

    Args:
        state: object handed to ``torch.save``.
        is_best: when true, the written file is also copied to
            ``<checkpoint>/model_best.pth.tar``.
        checkpoint: target directory; created if it does not exist.
        filename: file name inside ``checkpoint``.
    """
    # The default directory ('checkpoint') is not created anywhere else, so
    # make sure it exists. An empty path means "current directory".
    if checkpoint:
        os.makedirs(checkpoint, exist_ok=True)
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))

def adjust_learning_rate(optimizer, epoch):
    """Apply the step schedule: multiply the LR by ``gamma`` on scheduled epochs.

    The current rate is mirrored in the module-level ``state['lr']`` so other
    cells can read it.

    Args:
        optimizer: optimizer whose ``param_groups`` receive the new rate.
        epoch: epoch index, checked for membership in the global ``schedule``.
    """
    global state
    # `state` is not defined by any other cell. Create it on first use from
    # the optimizer's current rate, which also gives the right value after
    # resuming from a checkpoint.
    if 'state' not in globals():
        state = {'lr': optimizer.param_groups[0]['lr']}
    if epoch in schedule:
        state['lr'] *= gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = state['lr']

In [None]:
for epoch in range(start_epoch, epochs):
    adjust_learning_rate(optimizer, epoch)
    # Read the rate from the optimizer itself: it is the value actually used
    # for the update, and the former `state['lr']` lookup relied on a global
    # that this notebook never defines up front.
    current_lr = optimizer.param_groups[0]['lr']
    print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, epochs, current_lr))

    train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, use_cuda)
    test_loss, test_acc = test(val_loader, model, criterion, epoch, use_cuda)

    logger.append([current_lr, train_loss, test_loss, train_acc, test_acc])

    # Track the best validation accuracy and checkpoint every epoch;
    # save_checkpoint additionally keeps a copy of the best model.
    is_best = test_acc > best_acc
    best_acc = max(test_acc, best_acc)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict' : model.state_dict(),
        'acc': test_acc,
        'best_acc': best_acc,
        'optimizer': optimizer.state_dict(),
    }, is_best, checkpoint=checkpoint)

In [None]:
# Close the training log and report the best validation accuracy tracked by
# the training loop.
logger.close()
print('Best acc:', best_acc)