Domain Adaptation using WDGRL architecture

In [24]:
%set_env CUDA_DEVICE_ORDER=PCI_BUS_ID
%set_env CUDA_VISIBLE_DEVICES=1

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=1


In [25]:
%matplotlib notebook
from config import domainData
from config import num_classes as NUM_CLASSES
from torchvision import datasets, models, transforms
import yaml
import numpy as np
import matplotlib.pyplot as plt
import pickle as pkl
import logit
from random import shuffle
import math
import utils

In [26]:
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim
from torch.autograd import Variable

In [27]:
# src_pkl = domainData['amazon_resnet50']
# tar_pkl = domainData['webcam_resnet50']
src_pkl = "./amazon_alexnet_4096.pkl"
tar_pkl = "./webcam_alexnet_4096.pkl"
# src_pkl = './amazon_4096.yml'
# tar_pkl = './webcam_4096.yml'

In [28]:
with open(src_pkl, 'rb') as f:
    src_data = pkl.load(f)
with open(tar_pkl, 'rb') as f:
    tar_data = pkl.load(f)

# with open(src_pkl, 'r') as f:
#     source = yaml.load(f)
# with open(tar_pkl, 'r') as f:
#     target = yaml.load(f)

def batch_generator(data, batch_size, shuffle=True):
    len_data = len(data['features'])
    def shuffle_data():
        nonlocal len_data
        idx = list(range(len_data))
        idx = np.random.permutation(idx)
        t_features = [data['features'][i] for i in idx]
        t_labels = [data['labels'][i] for i in idx]
        data['features'] = t_features
        data['labels'] = t_labels
    if shuffle:
        shuffle_data()
    batch_count = 0
    while True:
        if batch_count * batch_size + batch_size >= len_data:
            batch_count = 0
            if shuffle:
                shuffle_data()
        start = (int) (batch_count * batch_size)
        end = (int) (start + batch_size)
        batch_count += 1
        yield torch.stack(data['features'][start:end]), np.vstack(src_data['labels'][start:end])

In [29]:
wd_param = 0.1
gp_param = 10
lr_wd_D = 1e-4
D_train_num = 10

l2_param = 1e-5
lr = 1e-5
batch_size = 32
num_steps = 1800
num_class = 31
n_input = 4096
n_hidden = [2048, 1024, 512, 256]
lr_sch = None

use_gpu = True

In [30]:
xs, ys = src_data['features'], src_data['labels']
xt, yt = tar_data['features'], tar_data['labels']
ys = ys.squeeze()
yt = yt.squeeze()
# xs = source['train']
# ys = source['train_labels']
# xt = target['train']
# yt = target['train_labels']

In [31]:
class _Generator(nn.Module):
    def __init__(self, act=nn.ReLU(inplace=True)):
        super(_Generator, self).__init__()
        self.h1 = nn.Sequential(nn.Linear(n_input, n_hidden[0]), act,
                                nn.Linear(n_hidden[0], n_hidden[1]), act,
                                nn.Linear(n_hidden[1], n_hidden[2]), act,
                                nn.Linear(n_hidden[2], n_hidden[3]), act
                                )
        utils.weight_init(self.h1)

    def forward(self, x):
        out = self.h1(x)
        return out

class _Classifier(nn.Module):
    """
    Using nn.CrossEntropyLoss() for tf.nn.softmax_cross_entropy_with_logits()
    There is difference in target labels. This loss requires target labels
    in range [0,C-1] instead of one-hot-encoding, as required in tensorflow.
    Also, CrossEntropyLoss() uses LogSoftmax instead of normal Softmax.
    """
    def __init__(self, act=nn.ReLU(inplace=True)):
        super(_Classifier, self).__init__()
        self.h = nn.Sequential(nn.Linear(n_hidden[-1], 128), act,
                               nn.Linear(128, 64), act,
                               nn.Linear(64, num_class))
        utils.weight_init(self.h)

    def forward(self, x):
        out = self.h(x)
        return out

class _Critic(nn.Module):
    def __init__(self, act=nn.ReLU(inplace=True)):
        super(_Critic, self).__init__()
        self.h1 = nn.Sequential(nn.Linear(n_hidden[-1], 128), act,
                               nn.Linear(128, 64), act,
                               nn.Linear(64, 1))
        utils.weight_init(self.h1)

    def forward(self, x):
        out = self.h1(x)
        return out

class Model(object):
    def __init__(self):
        self.cr = _Critic()
        self.gen = _Generator()
        self.cls = _Classifier()

In [32]:
model = Model()
model.cr = model.cr.cuda() if use_gpu else model.cr
model.gen = model.gen.cuda() if use_gpu else model.gen
model.cls = model.cls.cuda() if use_gpu else model.cls

In [33]:
# cls_criterion = nn.CrossEntropyLoss()
cls_criterion = utils.softmax_cross_entropy_with_logits(num_class)
l2_reg = utils.L2_Loss()
softmax_ = nn.Softmax(dim=1)

params_wd = [
{'params': model.cr.parameters(), 'lr': lr_wd_D}
]
# wd_d_op = optim.SGD(params_wd, momentum=0.9)
wd_d_op = optim.Adam(params_wd)

params = [
{'params': model.gen.parameters(), 'lr': lr}
]
train_gen_op = optim.Adam(params)

params = [
{'params': model.gen.parameters(), 'lr': lr},
{'params': model.cls.parameters(), 'lr': lr}
]
# train_op = optim.SGD(params, momentum=0.9)
train_cls_op = optim.Adam(params)
# lr_sch = optim.lr_scheduler.ExponentialLR(train_op, 0.1)

S_batches = batch_generator(src_data, batch_size)
xs_batch, ys_batch = next(S_batches)

In [34]:
def test(i, acc):
    XS = torch.from_numpy(xs).float()
    XS = XS.cuda() if use_gpu else XS
    y_true = torch.from_numpy(ys).long()
    y_true = y_true.cuda() if use_gpu else y_true
    
    XS = Variable(XS, volatile=True)
    y_true = Variable(y_true, requires_grad=False)
    
    h2 = model.gen(XS)
    logits = model.cls(h2)
    clf_loss = cls_criterion(logits, y_true)
    logits = softmax_(logits)
    _, pred_idx = torch.max(logits, 1)
    clf_acc = torch.eq(pred_idx, y_true).float().mean()
#     clf_acc = torch.sum(pred_idx == y_true) / y_true.size(0)
    print('step: ', i)
    print('source classifier loss: %f, source accuracy: %f' % (clf_loss, clf_acc))
    acc['src'].append(clf_acc.cpu().data.numpy())

    XS = torch.from_numpy(xt).float()
    XS = XS.cuda() if use_gpu else XS
    y_true = torch.from_numpy(yt).long()
    y_true = y_true.cuda() if use_gpu else y_true
    
    XS = Variable(XS, volatile=True)
    y_true = Variable(y_true, requires_grad=False)
    
    h2 = model.gen(XS)
    logits = model.cls(h2)
    clf_loss = cls_criterion(logits, y_true)
    logits = softmax_(logits)
    _, pred_idx = torch.max(logits, 1)
    clf_acc = torch.eq(pred_idx, y_true).float().mean()
#     clf_acc = torch.sum(pred_idx == y_true) / y_true.size(0)
    print('target classifier loss: %f, target accuracy: %f' % (clf_loss, clf_acc))
#     ax.plot(clf_acc)
    acc['tar'].append(clf_acc.cpu().data.numpy())

In [35]:
acc = { 'src': [], 'tar': [] }
one = torch.FloatTensor([1]).cuda() if use_gpu else torch.FloatTensor([1])
mone = -1 * one

def train():

    # put all network in training mode
    model.gen.train()
    model.cr.train()
    model.cls.train()

    S_batches = utils.batch_generator([xs, ys], batch_size / 2)
    T_batches = utils.batch_generator([xt, yt], batch_size / 2)

    for i in range(num_steps):
        xs_batch, ys_batch = next(S_batches)
        xt_batch, yt_batch = next(T_batches)
        xb = np.vstack([xs_batch, xt_batch])
        # yb = np.hstack([ys_batch, yt_batch])
        
        for p in model.cr.parameters():
            p.requires_grad=True
        for p in model.gen.parameters():
            p.requires_grad=False

        for _ in range(D_train_num):
            X_src = torch.Tensor(xs_batch)
            X_src = X_src.cuda() if use_gpu else X_src
            X_src = Variable(X_src)
            
            model.cr.zero_grad()
            
            # train with src
            X_src = model.gen(X_src)
            D_src = model.cr(X_src)
            D_src = D_src.mean()
            D_src.backward(mone)
            
            # train with target
            X_tar = torch.Tensor(xt_batch)
            X_tar = X_tar.cuda() if use_gpu else X_tar
            X_tar = Variable(X_tar)
            
            X_tar = model.gen(X_tar)
            D_tar = model.cr(X_tar)
            D_tar = D_tar.mean()
            D_tar.backward(one)
            
            # train with gradient penalty
            alpha = torch.rand(batch_size//2, 1)
            alpha = alpha.expand_as(X_src.data)
            alpha = alpha.cuda() if use_gpu else alpha
            
            interpolates = alpha * X_src.data + ((1. - alpha) * X_tar.data)
            interpolates = interpolates.cuda() if use_gpu else interpolates
            interpolates = Variable(interpolates, requires_grad=True)
            
            cr_interpolates = model.cr(interpolates)
            
            ones = torch.ones(cr_interpolates.size())
            ones = ones.cuda() if use_gpu else ones
            grads = autograd.grad(outputs=cr_interpolates, inputs=interpolates,
                                 grad_outputs=ones, create_graph=True,
                                 retain_graph=True, only_inputs=True)[0]
            gradient_penalty = ((grads.norm(2, dim=1) - 1) ** 2).mean()
            gradient_penalty.backward()
            wd_d_op.step()
            
        for p in model.cr.parameters():
            p.requires_grad=False
        for p in model.gen.parameters():
            p.requires_grad=True
        
        model.gen.zero_grad()
        model.cls.zero_grad()
        
        X_tar = torch.Tensor(xt_batch)
        X_tar = X_tar.cuda() if use_gpu else X_tar
        X_tar = Variable(X_tar) 
        
        X_tar = model.gen(X_tar)
        D_tar = model.cr(X_tar)
        D_tar = D_tar.mean()
        D_tar.backward(mone)
        train_gen_op.step()
        
        model.gen.zero_grad()
        model.cls.zero_grad()
        
        X_src = torch.Tensor(xs_batch)
        X_src = X_src.cuda() if use_gpu else X_src
        X_src = Variable(X_src)
        
        ys_batch = torch.LongTensor(ys_batch)
        ys_batch = ys_batch.cuda() if use_gpu else ys_batch
        ys_batch = Variable(ys_batch, requires_grad=False)
        
        pred_logit = model.cls(model.gen(X_src))
        clf_loss = cls_criterion(pred_logit, ys_batch)
        clf_loss.backward()
        train_cls_op.step()

        # https://discuss.pytorch.org/t/simple-l2-regularization/139/2
        # https://discuss.pytorch.org/t/add-custom-regularizer-to-loss/4831/2
#         l2_loss = None
#         for name, param in model.gen.named_parameters():
#             if 'weight' in name:
#                 if l2_loss is None:
#                     l2_loss = l2_reg(param) * l2_param
#                 else:
#                     l2_loss += l2_reg(param) * l2_param
#         for name, param in model.cr.named_parameters():
#             if 'weight' in name:	
#                 l2_loss += l2_reg(param) * l2_param
                
#         P = i / num_steps
# #         utils.adjust_learning_rate(train_op, P)
#         lmbd = 2. / (1. + math.exp(-10 * P)) - 1.
# #         print("new wd_param:", lmbd)
#         total_loss = clf_loss * lmbd + l2_loss + (1. - lmbd) * wd_loss
#         # total_loss = clf_loss + wd_param * wd_loss
#         total_loss.backward()
#         train_op.step()
        
#         if i % 500 == 0 and lr_sch is not None:
#             lr_sch.step()

        if i % 10 == 0:
#             print(lmbd)
            test(i, acc)


In [36]:
%time train()

step:  0
source classifier loss: 8.767797, source accuracy: 0.031594
target classifier loss: 6.544956, target accuracy: 0.035220
step:  10
source classifier loss: 12.397827, source accuracy: 0.055023
target classifier loss: 9.171897, target accuracy: 0.067925
step:  20
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  30
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  40
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  50
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  60
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  70
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  80
source cla

step:  680
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  690
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  700
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  710
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  720
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  730
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  740
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  750
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  760
source classifier los

step:  1360
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  1370
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  1380
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  1390
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  1400
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  1410
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  1420
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  1430
source classifier loss: nan, source accuracy: 0.035144
target classifier loss: nan, target accuracy: 0.054088
step:  1440
source class

fig = plt.figure()
ax1 = fig.add_subplot(211)
fig.subplots_adjust(bottom=0.2)
line1 = ax1.plot(acc['src'],'b',label='Source Accuracy')
line2 = ax1.plot(acc['tar'],'g',label='Target Accuracy')
ax1.set_ylim(0,1)
lines = line1+line2
labels = [l.get_label() for l in lines]
ax1.legend(lines, labels, loc=(0,-0.4), ncol=2)
plt.show()