Domain Adaptation using WDGRL architecture

In [17]:
%set_env CUDA_DEVICE_ORDER=PCI_BUS_ID
%set_env CUDA_VISIBLE_DEVICES=3

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=3


In [18]:
from config import domainData
from config import num_classes as NUM_CLASSES
from torchvision import datasets, models, transforms
import yaml
import numpy as np
import pickle as pkl
import logit
from random import shuffle
import utils

In [19]:
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim
from torch.autograd import Variable

In [30]:
# src_pkl = domainData['amazon_resnet50']
# tar_pkl = domainData['webcam_resnet50']
src_pkl = "./amazon_alexnet_4096.pkl"
tar_pkl = "./webcam_alexnet_4096.pkl"
# src_pkl = './amazon_4096.yml'
# tar_pkl = './webcam_4096.yml'

In [31]:
with open(src_pkl, 'rb') as f:
    src_data = pkl.load(f)
with open(tar_pkl, 'rb') as f:
    tar_data = pkl.load(f)

# with open(src_pkl, 'r') as f:
#     source = yaml.load(f)
# with open(tar_pkl, 'r') as f:
#     target = yaml.load(f)

def batch_generator(data, batch_size, shuffle=True):
    len_data = len(data['features'])
    def shuffle_data():
        nonlocal len_data
        idx = list(range(len_data))
        idx = np.random.permutation(idx)
        t_features = [data['features'][i] for i in idx]
        t_labels = [data['labels'][i] for i in idx]
        data['features'] = t_features
        data['labels'] = t_labels
    if shuffle:
        shuffle_data()
    batch_count = 0
    while True:
        if batch_count * batch_size + batch_size >= len_data:
            batch_count = 0
            if shuffle:
                shuffle_data()
        start = (int) (batch_count * batch_size)
        end = (int) (start + batch_size)
        batch_count += 1
        yield torch.stack(data['features'][start:end]), np.vstack(src_data['labels'][start:end])

In [40]:
wd_param = 0.1
gp_param = 10
lr_wd_D = 1e-3
D_train_num = 10

l2_param = 1e-5
lr = 1e-4
batch_size = 64
num_steps = 1200
num_class = 31
n_input = 4096
n_hidden = [500, 100]

use_gpu = False

In [41]:
xs, ys = src_data['features'], src_data['labels']
xt, yt = tar_data['features'], tar_data['labels']
ys = ys.squeeze()
yt = yt.squeeze()
# xs = source['train']
# ys = source['train_labels']
# xt = target['train']
# yt = target['train_labels']

In [42]:
class _Generator(nn.Module):
    def __init__(self, act=nn.ReLU(inplace=True)):
        super(_Generator, self).__init__()
        self.h1 = nn.Sequential(nn.Linear(n_input, n_hidden[0]), act)
        self.h2 = nn.Sequential(nn.Linear(n_hidden[0], n_hidden[1]), act)
        utils.weight_init(self.h1)
        utils.weight_init(self.h2)
# 		for x in self.h1.modules():
# 			utils.weight_init(x)
# 		for x in self.h2.modules():
# 			utils.weight_init(x)

    def forward(self, x):
        x = self.h1(x)
        out = self.h2(x)
        return out

class _Classifier(nn.Module):
    """
    Using nn.CrossEntropyLoss() for tf.nn.softmax_cross_entropy_with_logits()
    There is difference in target labels. This loss requires target labels
    in range [0,C-1] instead of one-hot-encoding, as required in tensorflow.
    Also, CrossEntropyLoss() uses LogSoftmax instead of normal Softmax.
    """
    def __init__(self):
        super(_Classifier, self).__init__()
        self.h = nn.Linear(n_hidden[-1], num_class)
        utils.weight_init(self.h)
#         for x in self.h.modules():
#             utils.weight_init(x)

    def forward(self, x):
        out = self.h(x)
        return out

class _Critic(nn.Module):
    def __init__(self, act=nn.ReLU(inplace=True)):
        super(_Critic, self).__init__()
        self.h1 = nn.Sequential(nn.Linear(n_hidden[-1], 100), act)
        self.out = nn.Linear(100, 1)
        utils.weight_init(self.h1)
        utils.weight_init(self.out)
# 		for x in self.h1.modules():
# 			utils.weight_init(x)
# 		for x in self.out.modules():
# 			utils.weight_init(x)

    def forward(self, x):
        x = self.h1(x)
        out = self.out(x)
        return out

class Model(object):
    def __init__(self):
        self.cr = _Critic()
        self.gen = _Generator()
        self.cls = _Classifier()

In [43]:
model = Model()

In [44]:
# cls_criterion = nn.CrossEntropyLoss()
cls_criterion = utils.softmax_cross_entropy_with_logits(num_class)
l2_reg = utils.L2_Loss()
softmax_ = nn.Softmax(dim=1)

params_wd = [
{'params': model.cr.parameters(), 'lr': lr_wd_D}
]
wd_d_op = optim.Adam(params_wd)

params = [
{'params': model.gen.parameters()},
{'params': model.cls.parameters()}
]
train_op = optim.Adam(params, lr=lr)

S_batches = batch_generator(src_data, batch_size)
xs_batch, ys_batch = next(S_batches)

In [45]:
def test(i=0):
    XS = Variable(torch.from_numpy(xs).float(), volatile=True)
    y_true = Variable(torch.from_numpy(ys).long(), requires_grad=False)
    h2 = model.gen(XS)
    logits = model.cls(h2)
    clf_loss = cls_criterion(logits, y_true)
    logits = softmax_(logits)
    _, pred_idx = torch.max(logits, 1)
    clf_acc = torch.eq(pred_idx, y_true).float().mean()
#     clf_acc = torch.sum(pred_idx == y_true) / y_true.size(0)
    print('step: ', i)
    print('source classifier loss: %f, source accuracy: %f' % (clf_loss, clf_acc))

    XS = Variable(torch.from_numpy(xt).float(), volatile=True)
    y_true = Variable(torch.from_numpy(yt).long(), requires_grad=False)
    h2 = model.gen(XS)
    logits = model.cls(h2)
    clf_loss = cls_criterion(logits, y_true)
    logits = softmax_(logits)
    _, pred_idx = torch.max(logits, 1)
    clf_acc = torch.eq(pred_idx, y_true).float().mean()
#     clf_acc = torch.sum(pred_idx == y_true) / y_true.size(0)
    print('target classifier loss: %f, target accuracy: %f' % (clf_loss, clf_acc))

In [46]:
def train():

    # put all network in training mode
    model.gen.train()
    model.cr.train()
    model.cls.train()

    S_batches = utils.batch_generator([xs, ys], batch_size / 2)
    T_batches = utils.batch_generator([xt, yt], batch_size / 2)

    for i in range(num_steps):
        xs_batch, ys_batch = next(S_batches)
        xt_batch, yt_batch = next(T_batches)
        xb = np.vstack([xs_batch, xt_batch])
        # yb = np.hstack([ys_batch, yt_batch])

        for _ in range(D_train_num):
            wd_d_op.zero_grad()
            XB = torch.from_numpy(xb).float()
            if use_gpu:
                XB = XB.cuda()
            XB = Variable(XB)
            h2 = model.gen(XB)
            h2_s = h2[:batch_size // 2]
            h2_t = h2[batch_size // 2:]

            alpha = torch.Tensor(h2_s.size()).uniform_(0,1)
            if use_gpu:
                alpha = alpha.cuda()
            difference = h2_s.data - h2_t.data
            interpolates = h2_s.data + (alpha * difference)

            h2_whole = torch.cat([h2.data, interpolates])

            if use_gpu:
                h2_whole = h2_whole.cuda()
            h2_whole = Variable(h2_whole, requires_grad=True)
            critic_out = model.cr(h2_whole)

            critic_s = critic_out[:batch_size // 2]
            critic_t = critic_out[batch_size // 2:batch_size]
            
            wd_loss = torch.sigmoid(critic_t) ** 2 + (1. - torch.sigmoid(critic_s)) ** 2
            wd_loss = wd_loss.mean()

#             wd_loss = critic_s.mean() - critic_t.mean()


            ones = torch.ones(critic_out.size())
            if use_gpu:
                ones = ones.cuda()

            grads = autograd.grad(critic_out, h2_whole, grad_outputs=ones,
                retain_graph=True, create_graph=True)[0]

            slopes = torch.sqrt(torch.sum(grads ** 2, dim=1))
            gradient_penalty = ((slopes - 1.) ** 2).mean()

#             grads = grads.view(grads.size(0), -1)
#             gradient_penalty = ((grads.norm(2, dim=1) - 1.) ** 2).mean()

            wd_loss_t = wd_loss + gp_param * gradient_penalty
            wd_loss_t.backward()
            wd_d_op.step()

        train_op.zero_grad()
        utils.adjust
        XB = torch.from_numpy(xb).float()
        ys_batch = torch.from_numpy(ys_batch).long()
        if use_gpu:
            XB = XB.cuda()
            ys_batch = ys_batch.cuda()
        XB = Variable(XB)
        ys_batch = Variable(ys_batch, requires_grad=False)
        h2 = model.gen(XB)
        h2_s = h2[:batch_size // 2]
        h2_t = h2[batch_size // 2:]

        alpha = torch.Tensor(h2_s.size()).uniform_(0,1)
        if use_gpu:
            alpha = alpha.cuda()
        difference = h2_s.data - h2_t.data
        interpolates = h2_s.data + (alpha * difference)

        h2_whole = torch.cat([h2.data, interpolates])
        h2_whole = Variable(h2_whole, requires_grad=True)
        
        critic_out = model.cr(h2_whole)

        critic_s = critic_out[:batch_size // 2]
        critic_t = critic_out[batch_size // 2:batch_size]

        wd_loss = critic_s - critic_t
        wd_loss = wd_loss.mean()

        pred_logit = model.cls(h2_s)
        clf_loss = cls_criterion(pred_logit, ys_batch)

        # https://discuss.pytorch.org/t/simple-l2-regularization/139/2
        # https://discuss.pytorch.org/t/add-custom-regularizer-to-loss/4831/2
        l2_loss = None
        # l2_loss = torch.Tensor(1).zero_()
        # if use_gpu:
        # 	l2_loss = l2_loss.cuda()
        # l2_loss = Variable(l2_loss, requires_grad=True)
        for name, param in model.gen.named_parameters():
#             print(name, param, type(param))
            if 'weight' in name:
                if l2_loss is None:
                    l2_loss = l2_reg(param) * l2_param
                else:
                    l2_loss += l2_reg(param) * l2_param
        for name, param in model.cr.named_parameters():
            if 'weight' in name:	
                l2_loss += l2_reg(param) * l2_param

#         print("wd_loss: ", wd_loss)
        total_loss = clf_loss + l2_loss + wd_param * wd_loss
        # total_loss = clf_loss + wd_param * wd_loss
        total_loss.backward()
        train_op.step()

        if i % 10 == 0:
            test(i)

In [47]:
train()

step:  0
source classifier loss: 13.927851, source accuracy: 0.056088
target classifier loss: 11.250607, target accuracy: 0.037736
step:  10
source classifier loss: 6.940927, source accuracy: 0.091232
target classifier loss: 6.737860, target accuracy: 0.067925
step:  20
source classifier loss: 3.505122, source accuracy: 0.197728
target classifier loss: 3.917571, target accuracy: 0.096855
step:  30
source classifier loss: 2.958606, source accuracy: 0.217962
target classifier loss: 3.555614, target accuracy: 0.130818
step:  40
source classifier loss: 2.700550, source accuracy: 0.298900
target classifier loss: 3.533226, target accuracy: 0.133333
step:  50
source classifier loss: 2.426571, source accuracy: 0.349308
target classifier loss: 3.116075, target accuracy: 0.181132
step:  60
source classifier loss: 2.294974, source accuracy: 0.401846
target classifier loss: 3.204958, target accuracy: 0.194969
step:  70
source classifier loss: 2.160150, source accuracy: 0.413561
target classifier l

step:  630
source classifier loss: 0.535122, source accuracy: 0.848775
target classifier loss: 2.566674, target accuracy: 0.406289
step:  640
source classifier loss: 0.510184, source accuracy: 0.853390
target classifier loss: 2.621575, target accuracy: 0.361006
step:  650
source classifier loss: 0.544891, source accuracy: 0.848065
target classifier loss: 2.762738, target accuracy: 0.358491
step:  660
source classifier loss: 0.568942, source accuracy: 0.840256
target classifier loss: 2.825892, target accuracy: 0.348428
step:  670
source classifier loss: 0.573215, source accuracy: 0.828186
target classifier loss: 2.687329, target accuracy: 0.369811
step:  680
source classifier loss: 0.530244, source accuracy: 0.855165
target classifier loss: 2.622560, target accuracy: 0.382390
step:  690
source classifier loss: 0.518750, source accuracy: 0.850195
target classifier loss: 2.758844, target accuracy: 0.367296
step:  700
source classifier loss: 0.540509, source accuracy: 0.852325
target class