In [1]:
import json
import os
import time
import random
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import shutil
import collections

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
from torchvision import transforms
from torchvision import models
from torchsummary import summary

from model_pytorch import resnext50_32x4d
from utils import Bar,Logger, AverageMeter, accuracy, mkdir_p, savefig
from warmup_scheduler import GradualWarmupScheduler
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split

from scipy.ndimage.filters import gaussian_filter
import cv2

from PIL import ImageFile, ImageOps
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [2]:
# GPU Device
gpu_id = 2
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
use_cuda = torch.cuda.is_available()
print("GPU device %d:" %(gpu_id), use_cuda)

GPU device 2: True


# Arguments

In [3]:
data_dir = '/media/data2/dataset/GAN_ImageData/PGGAN_128'

In [4]:
pretrained = ''
resume = ''

In [5]:
# Model
model_name = 'resnext32x4d' # b0-b7 scale

# Optimization
num_classes = 2
epochs = 400
start_epoch = 0
train_batch = 200
test_batch = 200
lr = 0.04
schedule = [75, 175, 250]
momentum = 0.9
gamma = 0.1 # LR is multiplied by gamma on schedule

# CheckPoint
checkpoint = './log/pggan/128/32x4d/aug' # dir
if not os.path.isdir(checkpoint):
    os.makedirs(checkpoint)
num_workers = 8

# Seed
manual_seed = 7
random.seed(manual_seed)
torch.cuda.manual_seed_all(manual_seed)

# Image
size = (128, 128)

# cutmix
cm_prob = 0.5
cm_beta = 1.0

# augmentation
blur_prob = 0.2
blog_sig = 0.5
jpg_prob = 0.2

best_acc = 0

In [6]:
state = {}
state['num_classes'] = num_classes
state['epochs'] = epochs
state['start_epoch'] = start_epoch
state['train_batch'] = train_batch
state['test_batch'] = test_batch
state['lr'] = lr
state['schedule'] = schedule
state['momentum'] = momentum
state['gamma'] = gamma

In [7]:
def data_augment(img):
    img = np.array(img)

    if random.random() < blur_prob:
        sig = np.random.uniform(0.0, 3.0)
        gaussian_blur(img, sig)

    if random.random() < jpg_prob:
        qual = np.random.uniform(30.0, 100.0)
        img = cv2_jpg(img, qual)

    return Image.fromarray(img)


def gaussian_blur(img, sigma):
    gaussian_filter(img[:,:,0], output=img[:,:,0], sigma=sigma)
    gaussian_filter(img[:,:,1], output=img[:,:,1], sigma=sigma)
    gaussian_filter(img[:,:,2], output=img[:,:,2], sigma=sigma)


def cv2_jpg(img, compress_val):
    img_cv2 = img[:,:,::-1]
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), compress_val]
    result, encimg = cv2.imencode('.jpg', img_cv2, encode_param)
    decimg = cv2.imdecode(encimg, 1)
    return decimg[:,:,::-1]

In [8]:
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = np.int(W * cut_rat)
    cut_h = np.int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

# Dataset

In [9]:
train_dir = os.path.join(data_dir, 'train')
val_dir = os.path.join(data_dir, 'validation')    

train_aug = transforms.Compose([
    transforms.Lambda(lambda img: data_augment(img)),
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
#     transforms.RandomErasing(p=0.3, scale=(0.02, 0.10), ratio=(0.3, 3.3), value=0, inplace=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
val_aug = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# pin_memory : cuda pin memeory use
train_loader = DataLoader(datasets.ImageFolder(train_dir, transform=train_aug),
                          batch_size=train_batch, shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(datasets.ImageFolder(val_dir, val_aug),
                       batch_size=test_batch, shuffle=True, num_workers=num_workers, pin_memory=True)

# Model

In [10]:
model = resnext50_32x4d(pretrained=False, num_classes=2)

# Pre-trained
if pretrained:
    print("=> using pre-trained model '{}'".format(pretrained))
    model.load_state_dict(torch.load(pretrained)['state_dict'])

In [11]:
model.to('cuda')
cudnn.benchmark = True
print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

    Total params: 22.98M


# Loss

In [12]:
criterion = nn.CrossEntropyLoss().cuda()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=1e-4)
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=10, after_scheduler=scheduler_cosine)

In [13]:
# Resume
if resume:
    print('==> Resuming from checkpoint..')
    checkpoint = os.path.dirname(resume)
#     checkpoint = torch.load(resume)
    resume = torch.load(resume)
    best_acc = resume['best_acc']
    start_epoch = resume['epoch']
    model.load_state_dict(resume['state_dict'])
    optimizer.load_state_dict(resume['optimizer'])
    logger = Logger(os.path.join(checkpoint, 'log.txt'), resume=True)
else:
    logger = Logger(os.path.join(checkpoint, 'log.txt'))
    logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.', 'Train AUROC.', 'Valid AUROC.'])

# Train

In [14]:
def train(train_loader, model, criterion, optimizer, epoch, use_cuda):
    model.train()
    
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    arc = AverageMeter()
    end = time.time()
    
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        batch_size = inputs.size(0)
        if batch_size < train_batch:
            continue
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
            
        r = np.random.rand(1)
        if cm_beta > 0 and r < cm_prob:
            
            rand_index = torch.randperm(inputs.size()[0]).cuda()
            tt= targets[rand_index]
            boolean = targets==tt
            rand_index = rand_index[boolean]
            lam = np.random.beta(cm_beta, cm_beta)
            bbx1, bby1, bbx2, bby2 = rand_bbox(inputs.size(), lam)
            inputs[boolean, :, bbx1:bbx2, bby1:bby2] = inputs[rand_index, :, bbx1:bbx2, bby1:bby2]

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1 = accuracy(outputs.data, targets.data)
        auroc = roc_auc_score(targets.cpu().detach().numpy(), outputs.cpu().detach().numpy()[:,1])
        losses.update(loss.data.tolist(), inputs.size(0))
        top1.update(prec1[0], inputs.size(0))
        arc.update(auroc, inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        
        if batch_idx % 100 == 0:
            print('{batch}/{size} | Loss:{loss:.4f} | top1:{tp1:.4f} | AUROC:{ac:.4f}'.format(
                     batch=batch_idx+1, size=len(train_loader), loss=losses.avg, tp1=top1.avg, ac=arc.avg))
    print('{batch}/{size} | Loss:{loss:.4f} | top1:{tp1:.4f} | AUROC:{ac:.4f}'.format(
                     batch=batch_idx+1, size=len(train_loader), loss=losses.avg, tp1=top1.avg, ac=arc.avg))
    return (losses.avg, top1.avg, arc.avg)

In [15]:
def test(val_loader, model, criterion, epoch, use_cuda):
    global best_acc

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    arc = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            # compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            prec1 = accuracy(outputs.data, targets.data)
            auroc = roc_auc_score(targets.cpu().detach().numpy(), outputs.cpu().detach().numpy()[:,1])
            losses.update(loss.data.tolist(), inputs.size(0))
            top1.update(prec1[0], inputs.size(0))
            arc.update(auroc, inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

    print('{batch}/{size} | Loss:{loss:.4f} | top1:{tp1:.4f} | AUROC:{ac:.4f}'.format(
         batch=batch_idx+1, size=len(val_loader), loss=losses.avg, tp1=top1.avg, ac=arc.avg))
    return (losses.avg, top1.avg, arc.avg)

In [16]:
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))

def adjust_learning_rate(optimizer, epoch):
    global state
    lr_set = [1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
    lr_list = schedule.copy()
    lr_list.append(epoch)
    lr_list.sort()
    idx = lr_list.index(epoch)
    state['lr'] *= lr_set[idx]
    for param_group in optimizer.param_groups:
        param_group['lr'] = state['lr']

In [17]:
for epoch in range(start_epoch, epochs):
    state['lr'] = optimizer.state_dict()['param_groups'][0]['lr']
    adjust_learning_rate(optimizer, epoch)
    print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, epochs, state['lr']))
    
    train_loss, train_acc, train_auroc = train(train_loader, model, criterion, optimizer, epoch, use_cuda)
    test_loss, test_acc, test_auroc = test(val_loader, model, criterion, epoch, use_cuda)
    
    logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc, train_auroc, test_auroc])
    scheduler_warmup.step()

    is_best = test_acc > best_acc
    best_acc = max(test_acc, best_acc)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict' : model.state_dict(),
        'acc': test_acc,
        'best_acc': best_acc,
        'optimizer': optimizer.state_dict(),
    }, is_best, checkpoint=checkpoint)


Epoch: [1 | 400] LR: 0.040000
1/643 | Loss:0.6926 | top1:51.0000 | AUROC:0.4973
101/643 | Loss:3.7560 | top1:60.7426 | AUROC:0.6804
201/643 | Loss:2.1371 | top1:68.4453 | AUROC:0.7447
301/643 | Loss:1.5797 | top1:71.8571 | AUROC:0.7706
401/643 | Loss:1.2955 | top1:73.8317 | AUROC:0.7842
501/643 | Loss:1.1244 | top1:74.9471 | AUROC:0.7939
601/643 | Loss:1.0055 | top1:76.0133 | AUROC:0.8012
643/643 | Loss:0.9692 | top1:76.2687 | AUROC:0.8030
161/161 | Loss:0.1993 | top1:98.9252 | AUROC:0.9999

Epoch: [2 | 400] LR: 0.068000
1/643 | Loss:0.4154 | top1:79.5000 | AUROC:0.8716
101/643 | Loss:0.4256 | top1:80.5545 | AUROC:0.8313
201/643 | Loss:0.4164 | top1:81.0771 | AUROC:0.8374
301/643 | Loss:0.4136 | top1:81.2824 | AUROC:0.8393
401/643 | Loss:0.4120 | top1:81.2294 | AUROC:0.8397
501/643 | Loss:0.4113 | top1:81.2126 | AUROC:0.8397
601/643 | Loss:0.4094 | top1:81.3012 | AUROC:0.8411
643/643 | Loss:0.4086 | top1:81.3388 | AUROC:0.8415
161/161 | Loss:0.1676 | top1:99.5203 | AUROC:1.0000

Epoch

401/643 | Loss:0.0887 | top1:96.5648 | AUROC:0.9961
501/643 | Loss:0.0902 | top1:96.5240 | AUROC:0.9961
601/643 | Loss:0.0924 | top1:96.4301 | AUROC:0.9959
643/643 | Loss:0.0915 | top1:96.4572 | AUROC:0.9960
161/161 | Loss:0.0052 | top1:99.8598 | AUROC:1.0000

Epoch: [18 | 400] LR: 0.319822
1/643 | Loss:0.1374 | top1:96.0000 | AUROC:0.9946
101/643 | Loss:0.0839 | top1:96.7624 | AUROC:0.9968
201/643 | Loss:0.0854 | top1:96.7164 | AUROC:0.9965
301/643 | Loss:0.0859 | top1:96.7243 | AUROC:0.9964
401/643 | Loss:0.0897 | top1:96.5598 | AUROC:0.9962
501/643 | Loss:0.0906 | top1:96.5269 | AUROC:0.9960
601/643 | Loss:0.0903 | top1:96.5532 | AUROC:0.9960
643/643 | Loss:0.0896 | top1:96.5818 | AUROC:0.9961
161/161 | Loss:0.0279 | top1:99.0343 | AUROC:1.0000

Epoch: [19 | 400] LR: 0.319758
1/643 | Loss:0.0630 | top1:97.5000 | AUROC:0.9978
101/643 | Loss:0.0865 | top1:96.7079 | AUROC:0.9965
201/643 | Loss:0.0797 | top1:96.9453 | AUROC:0.9968
301/643 | Loss:0.0806 | top1:96.9153 | AUROC:0.9966
401/


Epoch: [34 | 400] LR: 0.317617
1/643 | Loss:0.0336 | top1:98.5000 | AUROC:0.9993
101/643 | Loss:0.0685 | top1:97.4653 | AUROC:0.9980
201/643 | Loss:0.0650 | top1:97.5920 | AUROC:0.9980
301/643 | Loss:0.0650 | top1:97.5615 | AUROC:0.9980
401/643 | Loss:0.0662 | top1:97.4913 | AUROC:0.9980
501/643 | Loss:0.0686 | top1:97.3842 | AUROC:0.9978
601/643 | Loss:0.0691 | top1:97.3677 | AUROC:0.9977
643/643 | Loss:0.0686 | top1:97.3910 | AUROC:0.9978
161/161 | Loss:0.0218 | top1:99.1838 | AUROC:0.9999

Epoch: [35 | 400] LR: 0.317397
1/643 | Loss:0.0489 | top1:98.0000 | AUROC:0.9986
101/643 | Loss:0.0715 | top1:97.3218 | AUROC:0.9974
201/643 | Loss:0.0665 | top1:97.5572 | AUROC:0.9978
301/643 | Loss:0.0656 | top1:97.5930 | AUROC:0.9979
401/643 | Loss:0.0667 | top1:97.5536 | AUROC:0.9978
501/643 | Loss:0.0696 | top1:97.4152 | AUROC:0.9978
601/643 | Loss:0.0703 | top1:97.3810 | AUROC:0.9977
643/643 | Loss:0.0705 | top1:97.3715 | AUROC:0.9977
161/161 | Loss:0.0651 | top1:97.5140 | AUROC:0.9999

Epo

401/643 | Loss:0.0642 | top1:97.5162 | AUROC:0.9983
501/643 | Loss:0.0643 | top1:97.5200 | AUROC:0.9983
601/643 | Loss:0.0659 | top1:97.4609 | AUROC:0.9982
643/643 | Loss:0.0658 | top1:97.4642 | AUROC:0.9982
161/161 | Loss:0.0144 | top1:99.6137 | AUROC:1.0000

Epoch: [51 | 400] LR: 0.312553
1/643 | Loss:0.0777 | top1:97.5000 | AUROC:0.9966
101/643 | Loss:0.0592 | top1:97.7525 | AUROC:0.9983
201/643 | Loss:0.0595 | top1:97.7438 | AUROC:0.9984
301/643 | Loss:0.0575 | top1:97.8173 | AUROC:0.9984
401/643 | Loss:0.0587 | top1:97.7544 | AUROC:0.9985
501/643 | Loss:0.0606 | top1:97.6607 | AUROC:0.9984
601/643 | Loss:0.0600 | top1:97.6930 | AUROC:0.9984
643/643 | Loss:0.0602 | top1:97.6877 | AUROC:0.9983
161/161 | Loss:0.0135 | top1:99.5389 | AUROC:1.0000

Epoch: [52 | 400] LR: 0.312169
1/643 | Loss:0.1127 | top1:97.0000 | AUROC:0.9926
101/643 | Loss:0.0568 | top1:97.8416 | AUROC:0.9984
201/643 | Loss:0.0617 | top1:97.7289 | AUROC:0.9982
301/643 | Loss:0.0632 | top1:97.6478 | AUROC:0.9982
401/

KeyboardInterrupt: 