In [33]:
# Code taken from here : http://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

import os

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [34]:
from config import domainData
from config import num_classes as NUM_CLASSES
from wdStackDomain_alexnet import WDDomain
from logger import Logger
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.nn.init as init
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.autograd import Function
import torch.nn.functional as F
import numpy as np
import logit
import torchvision
from torchvision import datasets, models, transforms
import time
import random
import copy
import datetime
import itertools
from tqdm import *

In [35]:
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
#         transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

In [36]:
use_gpu = True and torch.cuda.is_available()
train_dir = domainData['amazon'] # 'amazon', 'dslr', 'webcam'
val_dir = domainData['webcam']
num_classes = NUM_CLASSES['office']
wd_param = 0.1
gp_lambda = 10
num_iters = 75 # total number of training iterations
num_cls_train = 2 # epochs to train classifier
num_gen_train = 1 # epochs to train generator
base_lr = 1e-5
l2_param = 1e-5
weight_decay = 1e-6
g_loss_param = 0.2
EPOCHS=50
batch_size = 64 # batch_size for each of source and target samples
load_cls = False
log = False
text_log = True
exp_name = 'wd_tr_2step_Rsnt2blk_r3'

In [37]:
print("use gpu: ", use_gpu)

torch.manual_seed(7)
if use_gpu:
    torch.cuda.manual_seed(7)

use gpu:  True


In [38]:
image_datasets = {'train' : datasets.ImageFolder(train_dir,
                                          data_transforms['train']),
                  'val' : datasets.ImageFolder(val_dir,
                                          data_transforms['val'])
                 }
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
                                             shuffle=True, num_workers=4, drop_last=True)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

model_ft = WDDomain(num_classes)

if load_cls:
    model_ft.load_cls()

if use_gpu:
    model_ft = model_ft.cuda()

In [39]:
def get_inifinite_dataloader(dataloader):
    data_iter = iter(dataloader)
    while(1):
        try:
            data = next(data_iter)
            yield data
        except StopIteration:
            data_iter = iter(dataloader)

In [40]:
# clscriterion = nn.CrossEntropyLoss()
clscriterion = logit.softmax_cross_entropy_with_logits(num_classes)

param_group1 = [
{'params' : model_ft._discriminator.parameters(), 'lr': 10*base_lr, 'weight_decay' : weight_decay } # , 'lr' : 1e-4, 'betas' : (0.5, 0.9)
]
# disc_opt = optim.Adam(param_group1)
# disc_opt = optim.ASGD(param_group1)
disc_opt = optim.SGD(param_group1, momentum=0.9)

param_group2 = [
{'params' : model_ft.features.net.features.parameters(), 'lr' : 0.01 * base_lr,
 'weight_decay' : weight_decay},
{'params' : model_ft.features.net.classifier.parameters(), 'lr' : base_lr, 'weight_decay' : weight_decay}, # , 'betas' : (0.5, 0.9
{'params' : model_ft.features.extra.parameters(), 'lr' : 10*base_lr, 'weight_decay' : weight_decay}
]
# gen_opt = optim.Adam(param_group2)
# gen_opt = optim.ASGD(param_group2)
gen_opt = optim.SGD(param_group2, momentum=0.9)

param_group3 = [
{'params' : model_ft.features.net.features.parameters(), 'lr' : 0.01 * base_lr,
 'weight_decay' : weight_decay},
{'params' : model_ft.features.net.classifier.parameters(), 'lr' : base_lr, 'weight_decay' : weight_decay}, # , 'betas' : (0.5, 0.9
{'params' : model_ft.features.extra.parameters(), 'lr' : 10*base_lr, 'weight_decay' : weight_decay},
{'params' : model_ft.classifier.parameters(), 'lr': 10*base_lr, 'weight_decay' : weight_decay} # , 'lr' : 1e-4
]
# cls_opt = optim.Adam(param_group3)
# cls_opt = optim.ASGD(param_group3)
cls_opt = optim.SGD(param_group3, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
# tar_lr_scheduler = lr_scheduler.StepLR(taroptimizer, step_size=7, gamma=0.1)

# cls_scheduler = None
cls_scheduler = lr_scheduler.StepLR(cls_opt, 20, 0.1)
# cls_scheduler = lr_scheduler.ExponentialLR(cls_opt, 0.1)
# gen_scheduler = None
gen_scheduler = lr_scheduler.StepLR(gen_opt, 20, 0.1)
# gen_scheduler = lr_scheduler.ExponentialLR(gen_opt, 0.1)
# disc_scheduler = None
disc_scheduler = lr_scheduler.StepLR(disc_opt, 20, 0.1)
# disc_scheduler = lr_scheduler.ExponentialLR(disc_opt, 0.1)

lr_schedulers = [cls_scheduler, gen_scheduler, disc_scheduler]
# lr_schedulers = None

softmax = nn.Softmax(dim=1)
l2_reg = logit.L2_Loss()

In [41]:
def test_model(model_ft, criterion, save_model=False, save_name=None):
    data_iter = iter(dataloaders['val'])

    model_ft.features.eval()
    model_ft.classifier.eval()
    model_ft._discriminator.eval()
    
    acc_val = 0
    steps = 0.
    for data in data_iter:
        img, lbl = data
        if use_gpu:
            img = img.cuda()
            lbl = lbl.cuda()
        img = Variable(img, volatile=True)
        lbl = Variable(lbl, requires_grad=False)

        feat_out = model_ft.features(img)
        out = model_ft.classifier(feat_out)

        loss = criterion(out, lbl)
        
        out1 = softmax(out)
        _, preds = torch.max(out1.data, 1)
#         print("preds size: ", preds.size(), " lbl data size: ", lbl.data.size())
        acc_val += torch.eq(preds, lbl.data).float().mean()
        steps = steps + 1
    # acc = acc_val / dataset_sizes['val']
    acc = acc_val / steps
    print("validation accuracy: {:.4f}".format(acc))
    if text_log:
        text_file.write("validation accuracy: {:.4f}\n".format(acc))
    if save_model:
        torch.save(model_ft.state_dict(), save_name)
    return

In [42]:
def train_model(model, clscriterion, disc_opt, gen_opt, cls_opt, lr_schedulers=None, num_epochs=25):
    since = time.time()

    for i in range(num_iters):
        
        running_clsacc, running_clsloss, steps = 0., 0., 0
        
        # train classifier on source
        for _ in range(num_cls_train):
            cls_src = iter(dataloaders['train'])
            for data in cls_src:
                srcimgs, srclbls = data
                
                srcimgs = srcimgs.cuda() if use_gpu else srcimgs
                srclbls = srclbls.cuda() if use_gpu else srclbls
                srcimgs, srclbls = Variable(srcimgs), Variable(srclbls)
                
                outimgs = model.features(srcimgs)
                cls_out = model.classifier(outimgs)
                cls_loss = clscriterion(cls_out, srclbls)
                
#                 l2_loss = None
#                 for name, param in model.classifier.named_parameters():
#                     if 'weight' in name:
#                         if l2_loss is None:
#                             l2_loss = l2_reg(param) * l2_param
#                         else:
#                             l2_loss += l2_reg(param) * l2_param
#                 for name, param in model.features.named_parameters():
#                     if 'weight' in name:
#                         l2_loss += l2_reg(param) * l2_param
                        
                _, preds = torch.max(softmax(cls_out).data, 1)
                running_clsacc += torch.eq(preds, srclbls.data).float().mean()
                running_clsloss += cls_loss.data[0]
                steps = steps + 1
#                 cls_loss_t = cls_loss + l2_loss
                cls_loss_t = cls_loss
                cls_opt.zero_grad()
                cls_loss_t.backward()
                cls_opt.step()
            print("classification loss: {:.4f}, classification acc: {:.4f}".
                  format(running_clsloss/steps, running_clsacc/steps))
                           
        for _ in range(num_gen_train):
            src_data = iter(dataloaders['train'])
            tar_data = get_inifinite_dataloader(dataloaders['val'])
            running_critic_loss, running_gen_loss, running_gp_loss, steps = 0., 0., 0., 0
            
            net = copy.deepcopy(model.features) # copy the network
            for params in net.parameters():
                params.requires_grad=False
            
            for srcinps, srclbls in src_data:
                tarinps, _ = next(tar_data)

                srcinps = srcinps.cuda() if use_gpu else srcinps
                srclbls = srclbls.cuda() if use_gpu else srclbls
                tarinps = tarinps.cuda() if use_gpu else tarinps
                srcinps, srclbls = Variable(srcinps), Variable(srclbls)
                tarinps = Variable(tarinps)

                real_out = net(srcinps)
                gen_out = model.features(tarinps)
                critic_flag = random.randint(0,1)
                critic_in = torch.cat([real_out, gen_out]) if critic_flag else torch.cat([gen_out, real_out])
                critic_out = model._discriminator(critic_in)

                D = critic_out[:batch_size] if critic_flag else critic_out[batch_size:]
                D_ = critic_out[batch_size:] if critic_flag else critic_out[:batch_size]

                critic_loss = (torch.sigmoid(D_) ** 2).mean() + ((1. - torch.sigmoid(D)) ** 2).mean()
                gen_loss = D.mean() - D_.mean()
                
#                 print("critic_loss: ", critic_loss.data)
#                 print("gen_loss: ", gen_loss.data)

                running_critic_loss += critic_loss.data[0]
                running_gen_loss += gen_loss.data[0]

                alpha = torch.Tensor(real_out.size()).uniform_(0,1)
                alpha = alpha.cuda() if use_gpu else alpha
                diff = real_out.data - gen_out.data
                interpolates = real_out.data + (alpha * diff)
                interpolates = Variable(interpolates, requires_grad=True)

                inter_out = model._discriminator(interpolates)
                ones = torch.ones(inter_out.size())
                ones = ones.cuda() if use_gpu else ones

                grads = autograd.grad(inter_out, interpolates, grad_outputs=ones,
                    retain_graph=True, create_graph=True, only_inputs=False)[0]
                gp = ((grads.norm(2, dim=1) - 1) ** 2).mean()

                running_gp_loss += gp.data[0]
#                 print("gp_loss: ", gp.data)

                for params in model._discriminator.parameters():
                    params.requires_grad=True

#                 l2_loss = None
#                 for name, param in model._discriminator.named_parameters():
#                     if 'weight' in name:
#                         if l2_loss is None:
#                             l2_loss = l2_reg(param) * l2_param
#                         else:
#                             l2_loss += l2_reg(param) * l2_param

#                 critic_loss_t = critic_loss + l2_loss + gp * gp_lambda
                critic_loss_t = critic_loss * wd_param + gp * gp_lambda
                disc_opt.zero_grad()
                critic_loss_t.backward(retain_graph=True)
                disc_opt.step()

                for params in model._discriminator.parameters():
                    params.requires_grad=False

#                 l2_loss = None
#                 for name, param in model.features.named_parameters():
#                     if 'weight' in name:
#                         if l2_loss is None:
#                             l2_loss = l2_reg(param) * l2_param
#                         else:
#                             l2_loss += l2_reg(param) * l2_param

#                 gen_loss_t = gen_loss + l2_loss
                gen_loss_t = gen_loss * g_loss_param

                gen_opt.zero_grad()
                gen_loss_t.backward()
                gen_opt.step()
                steps = steps + 1
            
            del net
            
        print("critic loss: {:.4f}, gen loss: {:.4f}, gp loss: {:.4f}".
              format(running_critic_loss/steps, running_gen_loss/steps, running_gp_loss/steps))
            
        if lr_schedulers:
            for x in lr_schedulers:
                if x is not None:
                    x.step()

#             l2_loss = None
#             for name, param in model.classifier.named_parameters():
#                 if 'weight' in name:
#                     if l2_loss is None:
#                         l2_loss = l2_reg(param) * l2_param
#                     else:
#                         l2_loss += l2_reg(param) * l2_param
#             for name, param in model.features.basenet.fc.named_parameters():
#                 if 'weight' in name:
#                     l2_loss += l2_reg(param) * l2_param

#             total_loss = clsloss + l2_loss + g_l * g_loss_param
#             total_loss.backward()
#             cls_opt.step()

        test_model(model, clscriterion, False, None)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))    
    return

In [43]:
train_model(model_ft, clscriterion, disc_opt, gen_opt, cls_opt, lr_schedulers, num_epochs=EPOCHS)

save_name = "grl_model_with_transform.pth"
test_model(model_ft, clscriterion, False, save_name)
test_model(model_ft, clscriterion, False, save_name)
test_model(model_ft, clscriterion, False, save_name)
test_model(model_ft, clscriterion, False, save_name)
test_model(model_ft, clscriterion, False, save_name)

classification loss: 3.3006, classification acc: 0.1147
classification loss: 3.0096, classification acc: 0.2397
critic loss: 0.5056, gen loss: -0.0190, gp loss: 0.8125
validation accuracy: 0.3008
classification loss: 2.2655, classification acc: 0.5366
classification loss: 2.1152, classification acc: 0.5730
critic loss: 0.5091, gen loss: -0.0341, gp loss: 0.8120
validation accuracy: 0.4049
classification loss: 1.7462, classification acc: 0.6484
classification loss: 1.6660, classification acc: 0.6664
critic loss: 0.5119, gen loss: -0.0456, gp loss: 0.8120
validation accuracy: 0.4609
classification loss: 1.4605, classification acc: 0.7060
classification loss: 1.4109, classification acc: 0.7154
critic loss: 0.5148, gen loss: -0.0568, gp loss: 0.8116
validation accuracy: 0.4818
classification loss: 1.2819, classification acc: 0.7390
classification loss: 1.2495, classification acc: 0.7404
critic loss: 0.5178, gen loss: -0.0682, gp loss: 0.8111
validation accuracy: 0.4922
classification loss:

validation accuracy: 0.5156
classification loss: 0.5926, classification acc: 0.8544
classification loss: 0.5917, classification acc: 0.8532
critic loss: 0.6454, gen loss: -0.2664, gp loss: 0.8002
validation accuracy: 0.5169
classification loss: 0.5886, classification acc: 0.8516
classification loss: 0.5878, classification acc: 0.8540
critic loss: 0.6447, gen loss: -0.2639, gp loss: 0.7999
validation accuracy: 0.5169
classification loss: 0.5876, classification acc: 0.8537
classification loss: 0.5893, classification acc: 0.8535
critic loss: 0.6452, gen loss: -0.2654, gp loss: 0.7999
validation accuracy: 0.5156
classification loss: 0.5896, classification acc: 0.8544
classification loss: 0.5894, classification acc: 0.8551
critic loss: 0.6451, gen loss: -0.2657, gp loss: 0.7999
validation accuracy: 0.5156
classification loss: 0.5906, classification acc: 0.8537
classification loss: 0.5892, classification acc: 0.8528
critic loss: 0.6454, gen loss: -0.2663, gp loss: 0.8001
validation accuracy: