Domain Adaptation using WDGRL architecture

In [139]:
%set_env CUDA_DEVICE_ORDER=PCI_BUS_ID
%set_env CUDA_VISIBLE_DEVICES=1

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=1


In [140]:
%matplotlib notebook
from config import domainData
from config import num_classes as NUM_CLASSES
from torchvision import datasets, models, transforms
import yaml
import numpy as np
import matplotlib.pyplot as plt
import pickle as pkl
import logit
from random import shuffle
import math
import utils

In [141]:
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim
from torch.autograd import Variable

In [142]:
# src_pkl = domainData['amazon_resnet50']
# tar_pkl = domainData['webcam_resnet50']
src_pkl = "./amazon_alexnet_4096.pkl"
tar_pkl = "./webcam_alexnet_4096.pkl"
# src_pkl = './amazon_4096.yml'
# tar_pkl = './webcam_4096.yml'

In [143]:
with open(src_pkl, 'rb') as f:
    src_data = pkl.load(f)
with open(tar_pkl, 'rb') as f:
    tar_data = pkl.load(f)

# with open(src_pkl, 'r') as f:
#     source = yaml.load(f)
# with open(tar_pkl, 'r') as f:
#     target = yaml.load(f)

def batch_generator(data, batch_size, shuffle=True):
    len_data = len(data['features'])
    def shuffle_data():
        nonlocal len_data
        idx = list(range(len_data))
        idx = np.random.permutation(idx)
        t_features = [data['features'][i] for i in idx]
        t_labels = [data['labels'][i] for i in idx]
        data['features'] = t_features
        data['labels'] = t_labels
    if shuffle:
        shuffle_data()
    batch_count = 0
    while True:
        if batch_count * batch_size + batch_size >= len_data:
            batch_count = 0
            if shuffle:
                shuffle_data()
        start = (int) (batch_count * batch_size)
        end = (int) (start + batch_size)
        batch_count += 1
        yield torch.stack(data['features'][start:end]), np.vstack(src_data['labels'][start:end])

In [144]:
wd_param = 0.1
gp_param = 10
lr_wd_D = 1e-3
D_train_num = 10

l2_param = 1e-5
lr = 1e-4
batch_size = 64
num_steps = 3200
num_class = 31
n_input = 4096
n_hidden = [2048, 1024, 512, 256]
lr_sch = None

use_gpu = True

In [145]:
xs, ys = src_data['features'], src_data['labels']
xt, yt = tar_data['features'], tar_data['labels']
ys = ys.squeeze()
yt = yt.squeeze()
# xs = source['train']
# ys = source['train_labels']
# xt = target['train']
# yt = target['train_labels']

In [146]:
class _Generator(nn.Module):
    def __init__(self, act=nn.ReLU(inplace=True)):
        super(_Generator, self).__init__()
        self.h1 = nn.Sequential(nn.Linear(n_input, n_hidden[0]), act,
                                nn.Linear(n_hidden[0], n_hidden[1]), act,
                                nn.Linear(n_hidden[1], n_hidden[2]), act,
                                nn.Linear(n_hidden[2], n_hidden[3]), act
                                )
        utils.weight_init(self.h1)

    def forward(self, x):
        out = self.h1(x)
        return out

class _Classifier(nn.Module):
    """
    Using nn.CrossEntropyLoss() for tf.nn.softmax_cross_entropy_with_logits()
    There is difference in target labels. This loss requires target labels
    in range [0,C-1] instead of one-hot-encoding, as required in tensorflow.
    Also, CrossEntropyLoss() uses LogSoftmax instead of normal Softmax.
    """
    def __init__(self, act=nn.ReLU(inplace=True)):
        super(_Classifier, self).__init__()
        self.h = nn.Sequential(nn.Linear(n_hidden[-1], 128), act,
                               nn.Linear(128,128), act,
                               nn.Linear(128, 64), act,
                               nn.Linear(64, num_class))
        utils.weight_init(self.h)

    def forward(self, x):
        out = self.h(x)
        return out

class _Critic(nn.Module):
    def __init__(self, act=nn.ReLU(inplace=True)):
        super(_Critic, self).__init__()
        self.h1 = nn.Sequential(nn.Linear(n_hidden[-1], 128), act,
                               nn.Linear(128, 64), act,
                               nn.Linear(64, 1))
        utils.weight_init(self.h1)

    def forward(self, x):
        out = self.h1(x)
        return out

class Model(object):
    def __init__(self):
        self.cr = _Critic()
        self.gen = _Generator()
        self.cls = _Classifier()

In [147]:
model = Model()
model.cr = model.cr.cuda() if use_gpu else model.cr
model.gen = model.gen.cuda() if use_gpu else model.gen
model.cls = model.cls.cuda() if use_gpu else model.cls

In [148]:
# cls_criterion = nn.CrossEntropyLoss()
cls_criterion = utils.softmax_cross_entropy_with_logits(num_class)
l2_reg = utils.L2_Loss()
softmax_ = nn.Softmax(dim=1)

params_wd = [
{'params': model.cr.parameters(), 'lr': lr_wd_D}
]
# wd_d_op = optim.SGD(params_wd, momentum=0.9)
wd_d_op = optim.Adam(params_wd)

params = [
{'params': model.gen.parameters(), 'lr': lr},
{'params': model.cls.parameters(), 'lr': lr}
]
# train_op = optim.SGD(params, momentum=0.9)
train_op = optim.Adam(params)
lr_sch = optim.lr_scheduler.ExponentialLR(train_op, 0.1)

S_batches = batch_generator(src_data, batch_size)
xs_batch, ys_batch = next(S_batches)

In [149]:
def test(i, acc):
    XS = torch.from_numpy(xs).float()
    XS = XS.cuda() if use_gpu else XS
    y_true = torch.from_numpy(ys).long()
    y_true = y_true.cuda() if use_gpu else y_true
    
    XS = Variable(XS, volatile=True)
    y_true = Variable(y_true, requires_grad=False)
    
    h2 = model.gen(XS)
    logits = model.cls(h2)
    clf_loss = cls_criterion(logits, y_true)
    logits = softmax_(logits)
    _, pred_idx = torch.max(logits, 1)
    clf_acc = torch.eq(pred_idx, y_true).float().mean()
#     clf_acc = torch.sum(pred_idx == y_true) / y_true.size(0)
    print('step: ', i)
    print('source classifier loss: %f, source accuracy: %f' % (clf_loss, clf_acc))
    acc['src'].append(clf_acc.cpu().data.numpy())

    XS = torch.from_numpy(xt).float()
    XS = XS.cuda() if use_gpu else XS
    y_true = torch.from_numpy(yt).long()
    y_true = y_true.cuda() if use_gpu else y_true
    
    XS = Variable(XS, volatile=True)
    y_true = Variable(y_true, requires_grad=False)
    
    h2 = model.gen(XS)
    logits = model.cls(h2)
    clf_loss = cls_criterion(logits, y_true)
    logits = softmax_(logits)
    _, pred_idx = torch.max(logits, 1)
    clf_acc = torch.eq(pred_idx, y_true).float().mean()
#     clf_acc = torch.sum(pred_idx == y_true) / y_true.size(0)
    print('target classifier loss: %f, target accuracy: %f' % (clf_loss, clf_acc))
#     ax.plot(clf_acc)
    acc['tar'].append(clf_acc.cpu().data.numpy())

In [150]:
acc = { 'src': [], 'tar': [] }

def train():

    # put all network in training mode
    model.gen.train()
    model.cr.train()
    model.cls.train()

    S_batches = utils.batch_generator([xs, ys], batch_size / 2)
    T_batches = utils.batch_generator([xt, yt], batch_size / 2)

    for i in range(num_steps):
        xs_batch, ys_batch = next(S_batches)
        xt_batch, yt_batch = next(T_batches)
        xb = np.vstack([xs_batch, xt_batch])
        # yb = np.hstack([ys_batch, yt_batch])

        for _ in range(D_train_num):
            wd_d_op.zero_grad()
            XB = torch.from_numpy(xb).float()
            if use_gpu:
                XB = XB.cuda()
            XB = Variable(XB)
            h2 = model.gen(XB)
            h2_s = h2[:batch_size // 2]
            h2_t = h2[batch_size // 2:]

            alpha = torch.Tensor(h2_s.size()).uniform_(0.1,0.9)
            if use_gpu:
                alpha = alpha.cuda()
            difference = h2_s.data - h2_t.data
            interpolates = h2_s.data + (alpha * difference)
#             h2_whole = interpolates

            h2_whole = torch.cat([h2.data, interpolates])

            if use_gpu:
                h2_whole = h2_whole.cuda()
            h2_whole = Variable(h2_whole, requires_grad=True)
            critic_out1 = model.cr(h2)
            critic_out = model.cr(h2_whole)

            critic_s = critic_out1[:batch_size // 2]
#             critic_t = critic_out1[batch_size // 2:]
            critic_t = critic_out[batch_size // 2:batch_size]
            
            wd_loss = torch.sigmoid(critic_t) ** 2 + (1. - torch.sigmoid(critic_s)) ** 2
            wd_loss = wd_loss.mean()

#             wd_loss = critic_s.mean() - critic_t.mean()


            ones = torch.ones(critic_out.size())
            if use_gpu:
                ones = ones.cuda()

            grads = autograd.grad(critic_out, h2_whole, grad_outputs=ones,
                retain_graph=True, create_graph=True, only_inputs=True)[0]

#             slopes = torch.sqrt(torch.sum(grads ** 2, dim=1))
#             gradient_penalty = ((slopes - 1.) ** 2).mean()
            gradient_penalty = ((grads.norm(2, dim=1) - 1) ** 2).mean() 

#             grads = grads.view(grads.size(0), -1)
#             gradient_penalty = ((grads.norm(2, dim=1) - 1.) ** 2).mean()l

            wd_loss_t = wd_loss + gp_param * gradient_penalty
            wd_loss_t.backward()
            wd_d_op.step()

        train_op.zero_grad()        
        XB = torch.from_numpy(xb).float()
        ys_batch = torch.from_numpy(ys_batch).long()
        if use_gpu:
            XB = XB.cuda()
            ys_batch = ys_batch.cuda()
        XB = Variable(XB)
        ys_batch = Variable(ys_batch, requires_grad=False)
        h2 = model.gen(XB)
        h2_s = h2[:batch_size // 2]
        h2_t = h2[batch_size // 2:]

        alpha = torch.Tensor(h2_s.size()).uniform_(0.1,0.9)
        if use_gpu:
            alpha = alpha.cuda()
        difference = h2_s.data - h2_t.data
        interpolates = h2_s.data + (alpha * difference)

        h2_whole = torch.cat([h2.data, interpolates])
        h2_whole = Variable(h2_whole, requires_grad=True)
        
        critic_out = model.cr(h2_whole)
#         critic_out = model.cr(h2)
    
        critic_s = critic_out[:batch_size // 2]
        critic_t = critic_out[batch_size // 2:batch_size]
#         critic_t = critic_out[batch_size // 2:]

        wd_loss = critic_s - critic_t
        if i%200==0:
            wd_loss = -1 * wd_loss
        wd_loss = wd_loss.mean()

        pred_logit = model.cls(h2_s)
        clf_loss = cls_criterion(pred_logit, ys_batch)

        # https://discuss.pytorch.org/t/simple-l2-regularization/139/2
        # https://discuss.pytorch.org/t/add-custom-regularizer-to-loss/4831/2
        l2_loss = None
        for name, param in model.gen.named_parameters():
            if 'weight' in name:
                if l2_loss is None:
                    l2_loss = l2_reg(param) * l2_param
                else:
                    l2_loss += l2_reg(param) * l2_param
        for name, param in model.cr.named_parameters():
            if 'weight' in name:	
                l2_loss += l2_reg(param) * l2_param
                
        P = i / num_steps
#         utils.adjust_learning_rate(train_op, P)
        lmbd = 2. / (1. + math.exp(-10 * P)) - 1.
#         print("new wd_param:", lmbd)
        total_loss = clf_loss + l2_loss + (1. - lmbd) * wd_loss
        # total_loss = clf_loss + wd_param * wd_loss
        total_loss.backward()
        train_op.step()
        
        if i % 500 == 0 and lr_sch is not None:
            lr_sch.step()

        if i % 10 == 0:
            print(lmbd)
            test(i, acc)


In [151]:
%time train()

0.0
step:  0
source classifier loss: 4.049394, source accuracy: 0.035144
target classifier loss: 3.996413, target accuracy: 0.021384
0.015623728558408878
step:  10
source classifier loss: 3.428989, source accuracy: 0.083777
target classifier loss: 3.501337, target accuracy: 0.080503
0.031239831446031152
step:  20
source classifier loss: 3.192865, source accuracy: 0.199858
target classifier loss: 3.393479, target accuracy: 0.139623
0.04684069787264811
step:  30
source classifier loss: 2.797700, source accuracy: 0.249911
target classifier loss: 3.191042, target accuracy: 0.163522
0.06241874674751258
step:  40
source classifier loss: 2.554735, source accuracy: 0.299255
target classifier loss: 3.038263, target accuracy: 0.181132
0.07796644137536823
step:  50
source classifier loss: 2.521787, source accuracy: 0.306354
target classifier loss: 3.097163, target accuracy: 0.181132
0.09347630396922768
step:  60
source classifier loss: 2.397920, source accuracy: 0.304224
target classifier loss: 3

0.6959355167556514
step:  550
source classifier loss: 0.665514, source accuracy: 0.830671
target classifier loss: 2.017829, target accuracy: 0.459119
0.7039056039366212
step:  560
source classifier loss: 0.651254, source accuracy: 0.832091
target classifier loss: 2.057248, target accuracy: 0.461635
0.7117022939345188
step:  570
source classifier loss: 0.641930, source accuracy: 0.833511
target classifier loss: 2.089514, target accuracy: 0.435220
0.7193275010198334
step:  580
source classifier loss: 0.636196, source accuracy: 0.832446
target classifier loss: 2.073282, target accuracy: 0.445283
0.7267832199475612
step:  590
source classifier loss: 0.621305, source accuracy: 0.839191
target classifier loss: 2.027400, target accuracy: 0.455346
0.7340715196043412
step:  600
source classifier loss: 0.616509, source accuracy: 0.841321
target classifier loss: 2.031158, target accuracy: 0.460377
0.7411945368167221
step:  610
source classifier loss: 0.612051, source accuracy: 0.841676
target cla

0.9377123389304431
step:  1100
source classifier loss: 0.445053, source accuracy: 0.893504
target classifier loss: 2.109438, target accuracy: 0.475472
0.9395708258652122
step:  1110
source classifier loss: 0.444572, source accuracy: 0.893859
target classifier loss: 2.108239, target accuracy: 0.475472
0.9413755384972873
step:  1120
source classifier loss: 0.443710, source accuracy: 0.894924
target classifier loss: 2.101396, target accuracy: 0.475472
0.9431279339102947
step:  1130
source classifier loss: 0.443047, source accuracy: 0.893859
target classifier loss: 2.091522, target accuracy: 0.477987
0.9448294355464197
step:  1140
source classifier loss: 0.442524, source accuracy: 0.892794
target classifier loss: 2.086099, target accuracy: 0.476730
0.9464814336291136
step:  1150
source classifier loss: 0.442026, source accuracy: 0.894924
target classifier loss: 2.087669, target accuracy: 0.476730
0.9480852856044062
step:  1160
source classifier loss: 0.441839, source accuracy: 0.895634
tar

0.988539506970016
step:  1650
source classifier loss: 0.428069, source accuracy: 0.899539
target classifier loss: 2.105427, target accuracy: 0.477987
0.988890150592618
step:  1660
source classifier loss: 0.428040, source accuracy: 0.899539
target classifier loss: 2.105994, target accuracy: 0.477987
0.9892301240764443
step:  1670
source classifier loss: 0.427996, source accuracy: 0.899539
target classifier loss: 2.105875, target accuracy: 0.476730
0.9895597486128833
step:  1680
source classifier loss: 0.427935, source accuracy: 0.899894
target classifier loss: 2.105392, target accuracy: 0.476730
0.9898793359360678
step:  1690
source classifier loss: 0.427894, source accuracy: 0.899539
target classifier loss: 2.104938, target accuracy: 0.476730
0.9901891885885559
step:  1700
source classifier loss: 0.427854, source accuracy: 0.899894
target classifier loss: 2.105107, target accuracy: 0.476730
0.9904896001803338
step:  1710
source classifier loss: 0.427811, source accuracy: 0.900248
targe

0.9979355379264905
step:  2200
source classifier loss: 0.426578, source accuracy: 0.899894
target classifier loss: 2.103462, target accuracy: 0.479245
0.9979989911981799
step:  2210
source classifier loss: 0.426575, source accuracy: 0.899894
target classifier loss: 2.103516, target accuracy: 0.479245
0.9980604960644348
step:  2220
source classifier loss: 0.426571, source accuracy: 0.899894
target classifier loss: 2.103539, target accuracy: 0.479245
0.9981201122386745
step:  2230
source classifier loss: 0.426567, source accuracy: 0.899539
target classifier loss: 2.103590, target accuracy: 0.479245
0.9981778976111988
step:  2240
source classifier loss: 0.426563, source accuracy: 0.899539
target classifier loss: 2.103630, target accuracy: 0.479245
0.9982339083044309
step:  2250
source classifier loss: 0.426559, source accuracy: 0.899539
target classifier loss: 2.103655, target accuracy: 0.479245
0.9982881987265098
step:  2260
source classifier loss: 0.426556, source accuracy: 0.899894
tar

0.9996295485134694
step:  2750
source classifier loss: 0.426457, source accuracy: 0.899894
target classifier loss: 2.103560, target accuracy: 0.479245
0.9996409440613065
step:  2760
source classifier loss: 0.426457, source accuracy: 0.899894
target classifier loss: 2.103560, target accuracy: 0.479245
0.9996519891289457
step:  2770
source classifier loss: 0.426457, source accuracy: 0.899894
target classifier loss: 2.103558, target accuracy: 0.479245
0.9996626944920217
step:  2780
source classifier loss: 0.426457, source accuracy: 0.899894
target classifier loss: 2.103558, target accuracy: 0.479245
0.9996730705950918
step:  2790
source classifier loss: 0.426457, source accuracy: 0.899894
target classifier loss: 2.103558, target accuracy: 0.479245
0.999683127561795
step:  2800
source classifier loss: 0.426457, source accuracy: 0.899894
target classifier loss: 2.103560, target accuracy: 0.479245
0.9996928752047007
step:  2810
source classifier loss: 0.426457, source accuracy: 0.899894
targ

fig = plt.figure()
ax1 = fig.add_subplot(211)
fig.subplots_adjust(bottom=0.2)
line1 = ax1.plot(acc['src'],'b',label='Source Accuracy')
line2 = ax1.plot(acc['tar'],'g',label='Target Accuracy')
ax1.set_ylim(0,1)
lines = line1+line2
labels = [l.get_label() for l in lines]
ax1.legend(lines, labels, loc=(0,-0.4), ncol=2)
plt.show()