In [1]:
import json
import os
import time
import random
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import shutil

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
from torchvision import transforms
from torchvision import models
from torchsummary import summary
from sklearn.model_selection import train_test_split

from model_pytorch import EfficientNet
from utils import Bar,Logger, AverageMeter, accuracy, mkdir_p, savefig
from warmup_scheduler import GradualWarmupScheduler
from sklearn.metrics import accuracy_score

from PIL import ImageFile, ImageOps
ImageFile.LOAD_TRUNCATED_IMAGES = True
import collections

In [2]:
# GPU Device
# Restrict this process to a single physical GPU by exporting
# CUDA_VISIBLE_DEVICES before the first CUDA query made in this cell.
gpu_id = 3
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
# Whether PyTorch reports a usable CUDA device; handed to train()/test() later on.
use_cuda = torch.cuda.is_available()
print("GPU device %d:" %(gpu_id), use_cuda)

GPU device 3: True


# Arguments

In [3]:
# Dataset root; the 'train' and 'validation' sub-folders are joined onto it further down.
# NOTE(review): absolute, machine-specific path - consider making it configurable.
data_dir = '/media/data2/dataset/GAN_ImageData/PGGAN_128/'

In [4]:
# Path to pre-trained weights loaded into the model before training; '' skips loading.
pretrained = ''
# Path of the checkpoint to resume from; an empty string selects the fresh-run branch instead.
resume = './log/pggan/128/b0/siamese/checkpoint.pth.tar'

In [5]:
# Model
model_name = 'efficientnet-b0' # b0-b7 scale

# Optimization
num_classes = 128   # output width requested from EfficientNet; the outputs are compared as embeddings
epochs = 300
start_epoch = 0     # overwritten from the checkpoint when resuming
train_batch = 160
test_batch = 160
lr = 0.04
schedule = [75,150,225]   # epochs after which adjust_learning_rate() applies the next decay factor
momentum = 0.9
gamma = 0.1 # LR is multiplied by gamma on schedule

# CheckPoint
checkpoint = './log/pggan/128/b0/siamese' # dir
# exist_ok makes the cell safe to re-run and avoids a check-then-create race.
os.makedirs(checkpoint, exist_ok=True)
num_workers = 8

# Seed
# Seed every RNG this notebook draws from: `random` (pair sampling in the
# dataset), NumPy (CutMix probability, Beta draw and box position) and
# PyTorch on both CPU and CUDA. Previously only `random` and the CUDA
# generators were seeded, so the NumPy and CPU-side torch draws differed
# from run to run.
manual_seed = 7
random.seed(manual_seed)
np.random.seed(manual_seed)
torch.manual_seed(manual_seed)
torch.cuda.manual_seed_all(manual_seed)

# Image
size = (128, 128)

# sp
sp_alpha = 0.1
sp_beta = 0.1
fc_name = '_fc.'

# iterative training
feedback = 0
# iter_time = [1000, 2000, 2500]

# cutmix
cm_prob = 0.5        # probability of applying the patch swap to a training batch
cm_prob_init = 0.99
cm_prob_low = 0.01
cm_beta = 1.0        # both parameters of the Beta distribution that sizes the patch

# contrastive
thresh = 0.5         # embedding distance at or above which a pair is predicted "different"

best_acc = 0

In [6]:
# Snapshot of the optimisation hyper-parameters. The 'lr' entry is the one
# that gets rewritten while training runs.
state = {
    'num_classes': num_classes,
    'epochs': epochs,
    'start_epoch': start_epoch,
    'train_batch': train_batch,
    'test_batch': test_batch,
    'lr': lr,
    'schedule': schedule,
    'momentum': momentum,
    'gamma': gamma,
}

In [7]:
class SiameseNetworkDataset(Dataset):
    """Random image pairs drawn from an ImageFolder-style dataset.

    Every item is a freshly sampled pair: the requested index is ignored, so
    the same index returns different pairs on different calls. Roughly half of
    the pairs share a class and half do not.

    The wrapped dataset must expose ``imgs`` as a sequence of
    ``(path, class_index)`` tuples. It needs at least two classes, otherwise
    the search for a different-class partner never terminates.
    """

    def __init__(self,imageFolderDataset,transform=None,should_invert=True):
        """
        Args:
            imageFolderDataset: object with an ``imgs`` list of (path, class_index).
            transform: optional callable applied to each image of the pair.
            should_invert: when true, both images are colour-inverted before
                the transform is applied.
        """
        self.imageFolderDataset = imageFolderDataset
        self.transform = transform
        self.should_invert = should_invert

    def __getitem__(self,index):
        """Return ``(img0, img1, label)``.

        ``label`` is a float32 tensor of shape (1,): 1.0 when the two images
        come from different classes, 0.0 when they share a class.
        """
        samples = self.imageFolderDataset.imgs
        img0_tuple = random.choice(samples)
        # we need to make sure approx 50% of images are in the same class
        want_same_class = bool(random.randint(0,1))
        # Rejection-sample the partner until its class relation to img0
        # matches what was drawn above.
        while True:
            img1_tuple = random.choice(samples)
            if (img0_tuple[1] == img1_tuple[1]) == want_same_class:
                break

        img0 = Image.open(img0_tuple[0])
        img1 = Image.open(img1_tuple[0])
#         img0 = img0.convert("L")
#         img1 = img1.convert("L")

        if self.should_invert:
            # `PIL` itself is never imported as a name in this notebook, only
            # `ImageOps` is, so go through that name (the old spelling
            # `PIL.ImageOps.invert` raised NameError).
            img0 = ImageOps.invert(img0)
            img1 = ImageOps.invert(img1)

        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)

        label = np.array([int(img1_tuple[1] != img0_tuple[1])], dtype=np.float32)
        return img0, img1, torch.from_numpy(label)

    def __len__(self):
        """Number of images in the wrapped dataset (used as the epoch length)."""
        return len(self.imageFolderDataset.imgs)

In [8]:
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss over pairs of embeddings.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    With label 0 a pair is penalised by its squared distance; with label 1 it
    is penalised by the squared shortfall of the distance below ``margin``.
    The per-pair terms are averaged into a scalar.
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # Per-pair Euclidean distance, kept as a column to line up with `label`.
        distance = F.pairwise_distance(output1, output2, keepdim=True)

        pull_term = (1 - label) * distance.pow(2)
        shortfall = (self.margin - distance).clamp(min=0.0)
        push_term = label * shortfall.pow(2)

        return (pull_term + push_term).mean()

# Dataset

In [9]:
# Split folders under the dataset root.
train_dir = os.path.join(data_dir, 'train')
val_dir = os.path.join(data_dir, 'validation')

# Training images are resized and randomly mirrored; validation images are
# only resized. (RandomAffine and RandomErasing were tried and are disabled.)
train_aug = transforms.Compose([
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
val_aug = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
])

# Pair datasets and loaders. pin_memory is left at the DataLoader default.
train_ = SiameseNetworkDataset(
    datasets.ImageFolder(train_dir),
    transform=train_aug,
    should_invert=False,
)
train_loader = DataLoader(
    train_,
    batch_size=train_batch,
    shuffle=True,
    num_workers=num_workers,
)
val_ = SiameseNetworkDataset(
    datasets.ImageFolder(val_dir),
    transform=val_aug,
    should_invert=False,
)
val_loader = DataLoader(
    val_,
    batch_size=test_batch,
    shuffle=True,
    num_workers=num_workers,
)

# Model

In [10]:
# Build the network from its architecture name, with num_classes outputs and
# the dropout / drop-connect rates overridden as given.
model = EfficientNet.from_name(model_name, num_classes=num_classes,
                              override_params={'dropout_rate':0.0, 'drop_connect_rate':0.2})

# Pre-trained
# Skipped entirely when `pretrained` is an empty string.
if pretrained:
    print("=> using pre-trained model '{}'".format(pretrained))
    # The file is expected to hold a dict with a 'state_dict' entry.
    pt = torch.load(pretrained)['state_dict']
    model.load_state_dict(pt)

In [11]:
# Move the model to the GPU only when CUDA is actually available, in line
# with the `use_cuda` flag that train()/test() consult before moving their
# batches. The unconditional move failed outright on a CPU-only machine.
if use_cuda:
    model.to('cuda')
# Let cuDNN pick convolution algorithms by benchmarking.
cudnn.benchmark = True
print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

    Total params: 4.17M


# Loss

In [12]:
# Contrastive loss with margin 1.0 (the class default is 2.0).
criterion = ContrastiveLoss(margin=1.0).cuda()
# Plain SGD with momentum over all model parameters.
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
# optimizer = optim.Adam(model.parameters(), weight_decay=0)
# Cosine annealing with T_max = epochs, used as the follow-up of a warm-up scheduler.
# NOTE(review): presumably the warm-up ramps the LR towards multiplier x the base LR
# over total_epoch scheduler steps and then defers to the cosine schedule -
# confirm against the warmup_scheduler package in use.
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=4, total_epoch=50, after_scheduler=scheduler_cosine)

In [13]:
# Resume
# When `resume` is a non-empty path: restore best accuracy, epoch counter,
# model weights and optimizer state from that file and reopen the existing log.
# Otherwise start a new log with the column names used by the training loop.
if resume:
    print('==> Resuming from checkpoint..')
    # From here on `checkpoint` is the directory that contains the resume file.
    checkpoint = os.path.dirname(resume)
#     checkpoint = torch.load(resume)
    # NOTE(review): `resume` is rebound from a path string to the loaded dict,
    # so re-running this cell in the same kernel would pass a dict to
    # os.path.dirname above.
    resume = torch.load(resume)
    best_acc = resume['best_acc']
    start_epoch = resume['epoch']
    model.load_state_dict(resume['state_dict'])
    optimizer.load_state_dict(resume['optimizer'])
    # NOTE(review): only model and optimizer state are restored; scheduler_warmup /
    # scheduler_cosine stay as freshly constructed. Confirm that restarting the
    # warm-up schedule on resume is intended.
    logger = Logger(os.path.join(checkpoint, 'log.txt'), resume=True)
else:
    logger = Logger(os.path.join(checkpoint, 'log.txt'))
    logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])

==> Resuming from checkpoint..


In [14]:
def rand_bbox(size, lam):
    """Sample a random box for CutMix.

    Args:
        size: 4-element shape; only ``size[2]`` and ``size[3]`` are read and
            used as the extents of the two box axes.
        lam: mixing ratio; the box covers roughly ``1 - lam`` of the area
            before clipping.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``: start/end along the ``size[2]`` axis
        (bbx*) and the ``size[3]`` axis (bby*), each clipped to the valid
        range, so the box may be smaller than requested near the borders.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # Built-in int: the `np.int` alias was removed in NumPy 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

# Train

In [15]:
def train(train_loader, model, criterion, optimizer, epoch, use_cuda):
    """Run one training epoch over pair batches.

    Args:
        train_loader: iterable yielding ``(inputs0, inputs1, targets)`` batches;
            must support ``len()``.
        model: embedding network, applied to each side of the pair in turn.
        criterion: loss called as ``criterion(outputs0, outputs1, targets)``.
        optimizer: optimizer stepped once per processed batch.
        epoch: current epoch index (not used inside the function).
        use_cuda: when true, each batch is moved to the GPU first.

    Returns:
        ``(average loss, average pair accuracy)`` over the processed batches.

    Batches smaller than the global ``train_batch`` are skipped. With
    probability ``cm_prob`` a CutMix-style patch is copied from ``inputs1``
    into ``inputs0`` for the same-class pairs of the batch.
    """
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()

    # Pre-set so the summary line after the loop works for an empty loader.
    batch_idx = -1
    bar = Bar('Processing', max=len(train_loader))
    for batch_idx, (inputs0, inputs1, targets) in enumerate(train_loader):
        batch_size = inputs0.size(0)
        if batch_size < train_batch:
            continue
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            inputs0, inputs1, targets = inputs0.cuda(), inputs1.cuda(), targets.cuda()

        r = np.random.rand(1)
        if cm_beta > 0 and r < cm_prob:
            # Rows whose label is 0, i.e. same-class pairs. The previous code
            # indexed with the label *values* (all zeros), which only ever
            # touched row 0 of the batch, and it also forced a .cuda() call
            # regardless of `use_cuda`.
            same_pair = targets.reshape(-1) == 0
            lam = np.random.beta(cm_beta, cm_beta)
            bbx1, bby1, bbx2, bby2 = rand_bbox(inputs0.size(), lam)
            inputs0[same_pair, :, bbx1:bbx2, bby1:bby2] = inputs1[same_pair, :, bbx1:bbx2, bby1:bby2]

        outputs0 = model(inputs0)
        outputs1 = model(inputs1)

        loss = criterion(outputs0, outputs1, targets)

        # measure accuracy and record loss
        # A pair is predicted "different" (1) when its embedding distance is at
        # or above `thresh`, otherwise "same" (0). Computed outside autograd
        # because it is only a metric.
        with torch.no_grad():
            distance = F.pairwise_distance(outputs0, outputs1, keepdim=True)
            pred = (distance >= thresh).float()
        prec1 = accuracy_score(targets.cpu().numpy(), pred.cpu().numpy())

        losses.update(loss.item(), inputs0.size(0))
        top1.update(prec1, inputs0.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} '.format(
                    batch=batch_idx + 1,
                    size=len(train_loader),
                    data=data_time.val,
                    bt=batch_time.val,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss=losses.avg,
                    top1=top1.avg,
                    )
        bar.next()
    # One summary line per epoch, printed after the last batch.
    print('{batch}/{size} Data:{data:.3f} | Batch:{bt:.3f} | Total:{total:} | ETA:{eta:} | Loss:{loss:} | top1:{tp1:}'.format(
                 batch=batch_idx+1, size=len(train_loader), data=data_time.val, bt=batch_time.val, total=bar.elapsed_td, eta=bar.eta_td, loss=losses.avg, tp1=top1.avg))
    bar.finish()
    return (losses.avg, top1.avg)

In [16]:
def test(val_loader, model, criterion, epoch, use_cuda):
    """Evaluate the model on pair batches without gradient tracking.

    Returns ``(average loss, average pair accuracy)`` over ``val_loader``.
    A pair counts as "different" (1) when the distance between its two
    embeddings is at least the global ``thresh``, else "same" (0).
    """
    timer_batch = AverageMeter()
    timer_data = AverageMeter()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    # switch to evaluate mode
    model.eval()

    tick = time.time()
    n_batches = len(val_loader)
    bar = Bar('Processing', max=n_batches)
    with torch.no_grad():
        for batch_idx, (inputs0, inputs1, targets) in enumerate(val_loader):
            # time spent waiting for the loader
            timer_data.update(time.time() - tick)

            if use_cuda:
                inputs0 = inputs0.cuda()
                inputs1 = inputs1.cuda()
                targets = targets.cuda()

            # embed both sides of the pair and score them
            emb0 = model(inputs0)
            emb1 = model(inputs1)
            loss = criterion(emb0, emb1, targets)
            dist = F.pairwise_distance(emb0, emb1, keepdim=True)

            # threshold the distances in place into 0/1 predictions
            pred = dist.data
            pred.masked_fill_(pred < thresh, 0.)
            pred.masked_fill_(pred >= thresh, 1.)
            batch_acc = accuracy_score(targets.data.cpu().numpy(), pred.cpu().numpy())

            n_pairs = inputs0.size(0)
            loss_meter.update(loss.item(), n_pairs)
            acc_meter.update(batch_acc, n_pairs)

            # time for the whole batch
            timer_batch.update(time.time() - tick)
            tick = time.time()

            # progress bar text
            bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:} | top1: {top1:}'.format(
                batch=batch_idx + 1,
                size=n_batches,
                data=timer_data.avg,
                bt=timer_batch.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss=loss_meter.avg,
                top1=acc_meter.avg,
            )
            bar.next()
        # one summary line for the whole pass
        print('{batch}/{size} Data:{data:.3f} | Batch:{bt:.3f} | Total:{total:} | ETA:{eta:} | Loss:{loss:} | top1:{tp1:}'.format(
            batch=batch_idx+1,
            size=len(val_loader),
            data=timer_data.val,
            bt=timer_batch.val,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=loss_meter.avg,
            tp1=acc_meter.avg,
        ))
        bar.finish()
    return (loss_meter.avg, acc_meter.avg)

In [17]:
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """Serialise ``state`` to ``<checkpoint>/<filename>``.

    When ``is_best`` is true the written file is also copied to
    ``<checkpoint>/model_best.pth.tar``.
    """
    target = os.path.join(checkpoint, filename)
    torch.save(state, target)
    if not is_best:
        return
    best_copy = os.path.join(checkpoint, 'model_best.pth.tar')
    shutil.copyfile(target, best_copy)

def adjust_learning_rate(optimizer, epoch):
    """Scale ``state['lr']`` by a step-decay factor and push it to the optimizer.

    The factor is chosen by how many entries of the global ``schedule`` are
    strictly below ``epoch`` (0 -> 1, 1 -> 1e-1, 2 -> 1e-2, ...). The scaled
    value is stored back in ``state['lr']`` and written to every parameter
    group, so the caller decides what ``state['lr']`` holds beforehand.
    """
    global state
    decay_factors = [1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
    milestones_passed = sum(1 for milestone in schedule if milestone < epoch)
    state['lr'] *= decay_factors[milestones_passed]
    for group in optimizer.param_groups:
        group['lr'] = state['lr']

In [None]:
# Main loop: one training pass and one validation pass per epoch, then log,
# checkpoint and advance the LR scheduler.
for epoch in range(start_epoch, epochs):    
    # Start from whatever LR the optimizer currently holds (left there by the
    # scheduler step at the end of the previous epoch, or by the restored
    # optimizer state on the first resumed epoch) ...
    state['lr'] = optimizer.state_dict()['param_groups'][0]['lr']
    # ... and scale it by the step-decay factor for this epoch. The printed LR
    # is therefore the product of both mechanisms.
    adjust_learning_rate(optimizer, epoch)
    print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, epochs, state['lr']))
    
    train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, use_cuda)
    test_loss, test_acc = test(val_loader, model, criterion, epoch, use_cuda)

    logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc])
    # Track the best validation accuracy seen so far; a new best also gets
    # copied to model_best.pth.tar by save_checkpoint.
    is_best = test_acc > best_acc
    best_acc = max(test_acc, best_acc)
    # 'epoch' stores the index of the NEXT epoch, which is what the resume
    # cell assigns to start_epoch.
    # NOTE(review): the scheduler state is not part of the checkpoint - confirm
    # that a resumed run is meant to restart the warm-up/cosine schedule.
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict' : model.state_dict(),
        'acc': test_acc,
        'best_acc': best_acc,
        'optimizer': optimizer.state_dict(),
    }, is_best, checkpoint=checkpoint)
    scheduler_warmup.step()


Epoch: [170 | 300] LR: 0.000004
803/803 Data:0.019 | Batch:0.396 | Total:0:10:13 | ETA:0:00:01 | Loss:0.019608123752113172 | top1:0.9775015586034913
201/201 Data:0.012 | Batch:0.641 | Total:0:00:29 | ETA:0:00:00 | Loss:0.07217540102677182 | top1:0.9173520249221184

Epoch: [171 | 300] LR: 0.000424
803/803 Data:0.023 | Batch:0.802 | Total:0:09:28 | ETA:0:00:01 | Loss:0.019560468421454664 | top1:0.9776028678304239
201/201 Data:0.048 | Batch:0.245 | Total:0:00:53 | ETA:0:00:00 | Loss:0.07033096096504514 | top1:0.9192834890965732

Epoch: [172 | 300] LR: 0.000448
803/803 Data:0.021 | Batch:0.758 | Total:0:10:29 | ETA:0:00:01 | Loss:0.01935934853292084 | top1:0.9778912094763093
201/201 Data:0.019 | Batch:0.159 | Total:0:00:49 | ETA:0:00:00 | Loss:0.07188618736939267 | top1:0.9170404984423676

Epoch: [173 | 300] LR: 0.000472
803/803 Data:0.050 | Batch:0.761 | Total:0:11:27 | ETA:0:00:01 | Loss:0.018587625312167184 | top1:0.978732855361596
201/201 Data:0.032 | Batch:0.309 | Total:0:00:56 | ETA

803/803 Data:0.010 | Batch:0.793 | Total:0:10:26 | ETA:0:00:01 | Loss:0.01915567363522357 | top1:0.978428927680798
201/201 Data:0.012 | Batch:0.152 | Total:0:00:47 | ETA:0:00:00 | Loss:0.0736843740944736 | top1:0.9146417445482866

Epoch: [202 | 300] LR: 0.001168
803/803 Data:0.010 | Batch:0.769 | Total:0:10:26 | ETA:0:00:01 | Loss:0.01897501092259842 | top1:0.9786081670822943
201/201 Data:0.011 | Batch:0.149 | Total:0:00:47 | ETA:0:00:00 | Loss:0.07135026071023347 | top1:0.9186292834890966

Epoch: [203 | 300] LR: 0.001192
803/803 Data:0.011 | Batch:0.780 | Total:0:10:24 | ETA:0:00:01 | Loss:0.018945925622717467 | top1:0.9783977556109725
201/201 Data:0.012 | Batch:0.150 | Total:0:00:47 | ETA:0:00:00 | Loss:0.07192592424952723 | top1:0.9172585669781932

Epoch: [204 | 300] LR: 0.001216
803/803 Data:0.010 | Batch:0.774 | Total:0:10:26 | ETA:0:00:01 | Loss:0.01915510846910118 | top1:0.9782029301745636
201/201 Data:0.010 | Batch:0.165 | Total:0:00:47 | ETA:0:00:00 | Loss:0.07413307763055849 

803/803 Data:0.026 | Batch:0.735 | Total:0:10:38 | ETA:0:00:01 | Loss:0.018526932411519998 | top1:0.9790211970074812
201/201 Data:0.012 | Batch:0.171 | Total:0:00:45 | ETA:0:00:00 | Loss:0.07034490730996444 | top1:0.9189096573208723

Epoch: [233 | 300] LR: 0.000159
803/803 Data:0.012 | Batch:0.794 | Total:0:10:37 | ETA:0:00:01 | Loss:0.018313874365797293 | top1:0.9790134039900249
201/201 Data:0.011 | Batch:0.159 | Total:0:00:46 | ETA:0:00:00 | Loss:0.07374223741322664 | top1:0.9163862928348909

Epoch: [234 | 300] LR: 0.000159
803/803 Data:0.012 | Batch:0.791 | Total:0:10:34 | ETA:0:00:01 | Loss:0.018539928412601988 | top1:0.9793640897755611
201/201 Data:0.012 | Batch:0.156 | Total:0:00:49 | ETA:0:00:00 | Loss:0.07241459677338229 | top1:0.9168535825545171

Epoch: [235 | 300] LR: 0.000159
803/803 Data:0.012 | Batch:0.809 | Total:0:10:30 | ETA:0:00:01 | Loss:0.018406592737976228 | top1:0.9792705735660848
201/201 Data:0.012 | Batch:0.171 | Total:0:00:49 | ETA:0:00:00 | Loss:0.0742249030701

803/803 Data:0.011 | Batch:0.773 | Total:0:10:20 | ETA:0:00:01 | Loss:0.018112806680317887 | top1:0.9797927057356608
201/201 Data:0.016 | Batch:0.170 | Total:0:00:47 | ETA:0:00:00 | Loss:0.06950906964013139 | top1:0.920404984423676

Epoch: [264 | 300] LR: 0.000152
803/803 Data:0.011 | Batch:0.813 | Total:0:10:27 | ETA:0:00:01 | Loss:0.01900968282147336 | top1:0.9783743765586035
201/201 Data:0.010 | Batch:0.150 | Total:0:00:47 | ETA:0:00:00 | Loss:0.06781050686106503 | top1:0.9212149532710281

Epoch: [265 | 300] LR: 0.000152
803/803 Data:0.021 | Batch:0.721 | Total:0:10:20 | ETA:0:00:01 | Loss:0.018461933472072426 | top1:0.9793017456359102
201/201 Data:0.017 | Batch:0.164 | Total:0:00:49 | ETA:0:00:00 | Loss:0.07342046785670278 | top1:0.9164174454828661

Epoch: [266 | 300] LR: 0.000151
803/803 Data:0.010 | Batch:0.774 | Total:0:10:22 | ETA:0:00:01 | Loss:0.018312809543081437 | top1:0.9795121571072319
201/201 Data:0.010 | Batch:0.165 | Total:0:00:47 | ETA:0:00:00 | Loss:0.069720189101599