Domain Adaptation using the WDGRL (Wasserstein Distance Guided Representation Learning) architecture

In [206]:
# Enumerate GPUs in PCI bus order and expose only device index 1 to this
# kernel. NOTE(review): the device index is hard-coded for one machine; these
# presumably need to run before torch first touches CUDA — confirm.
%set_env CUDA_DEVICE_ORDER=PCI_BUS_ID
%set_env CUDA_VISIBLE_DEVICES=1

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=1


In [207]:
%matplotlib notebook
from config import domainData
from config import num_classes as NUM_CLASSES
from torchvision import datasets, models, transforms
import yaml
import numpy as np
import matplotlib.pyplot as plt
import pickle as pkl
import logit
from random import shuffle
import math
import utils

In [208]:
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim
from torch.autograd import Variable

In [209]:
# Paths of the pickled source-domain and target-domain feature dumps loaded in
# the next cell. The file names suggest 4096-d AlexNet features of the
# amazon / webcam domains — TODO confirm against the extraction script.
# The commented-out assignments are alternative inputs kept for reference.
# src_pkl = domainData['amazon_resnet50']
# tar_pkl = domainData['webcam_resnet50']
src_pkl = "./amazon_alexnet_4096.pkl"
tar_pkl = "./webcam_alexnet_4096.pkl"
# src_pkl = './amazon_4096.yml'
# tar_pkl = './webcam_4096.yml'

In [210]:
# Load both domains into memory. Later cells read the 'features' and 'labels'
# entries of each loaded object.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point src_pkl / tar_pkl at dumps produced by a trusted pipeline.
with open(src_pkl, 'rb') as f:
    src_data = pkl.load(f)
with open(tar_pkl, 'rb') as f:
    tar_data = pkl.load(f)

# Alternative YAML-based loading path, kept for reference.
# with open(src_pkl, 'r') as f:
#     source = yaml.load(f)
# with open(tar_pkl, 'r') as f:
#     target = yaml.load(f)

def batch_generator(data, batch_size, shuffle=True):
    """Yield an endless stream of ``(features, labels)`` mini-batches.

    Fixes over the previous version:

    * labels are taken from ``data`` (they used to be read from the global
      ``src_data``, so any other dataset received mismatched labels);
    * the last full batch of an epoch is no longer skipped when the dataset
      length is an exact multiple of ``batch_size``;
    * the caller's ``data`` dict is never modified — shuffling is done through
      a private index permutation.

    Parameters
    ----------
    data : dict
        Must provide ``data['features']`` (indexable sequence of
        ``torch.Tensor`` rows; batches are built with ``torch.stack``) and
        ``data['labels']`` (indexable sequence accepted by ``np.vstack``),
        both of the same length.
    batch_size : int or float
        Samples per batch. Truncated with ``int()`` so callers that pass a
        float such as ``batch_size / 2`` keep working.
    shuffle : bool, optional
        Draw a fresh random order before the first batch and at every
        wrap-around.

    Yields
    ------
    (torch.Tensor, numpy.ndarray)
        Stacked features and vertically stacked labels of the same samples.

    Raises
    ------
    ValueError
        If the dataset is empty, the two sequences differ in length, or
        ``batch_size`` is smaller than 1. Because this is a generator the
        error surfaces on the first ``next()`` call.
    """
    features = data['features']
    labels = data['labels']
    n_samples = len(features)
    batch_size = int(batch_size)

    if n_samples == 0:
        raise ValueError("batch_generator: data['features'] is empty")
    if len(labels) != n_samples:
        raise ValueError("batch_generator: features and labels differ in length")
    if batch_size < 1:
        raise ValueError("batch_generator: batch_size must be >= 1")

    def epoch_order():
        # One index permutation keeps features and labels aligned without
        # touching the caller's containers.
        if shuffle:
            return np.random.permutation(n_samples)
        return np.arange(n_samples)

    order = epoch_order()
    start = 0
    while True:
        # Wrap only when a full batch no longer fits; `>` (not `>=`) so the
        # final, exactly-fitting batch of an epoch is still served.
        if start + batch_size > n_samples:
            start = 0
            order = epoch_order()
        picked = order[start:start + batch_size]
        start += batch_size
        yield (torch.stack([features[i] for i in picked]),
               np.vstack([labels[i] for i in picked]))

In [211]:
# --- Critic-related settings --------------------------------------------------
# NOTE(review): gp_param and D_train_num are not referenced anywhere in the
# visible notebook, and wd_param only appears in a commented-out loss line in
# train(); presumably leftovers of the full WDGRL objective — confirm.
wd_param = 0.1  # weight on wd_loss in the commented-out total loss
gp_param = 10  # presumably the gradient-penalty weight — TODO confirm
lr_wd_D = 1e-3  # learning rate of the critic parameter group
D_train_num = 50  # presumably critic updates per training step — TODO confirm

# --- Generator / classifier settings -------------------------------------------
l2_param = 1e-5  # scale applied to every L2 weight penalty term in train()
lr = 1e-3  # learning rate of the generator and classifier parameter groups
batch_size = 64  # train() draws half of this many samples per step
num_steps = 1800  # number of optimisation steps performed by train()
num_class = 31  # width of the classifier's output layer
n_input = 4096  # input width of the generator's first linear layer
n_hidden = [2048, 1024, 512, 256]  # output widths of the generator's four linear layers
lr_sch = None  # placeholder; a later cell rebinds it to an ExponentialLR scheduler

use_gpu = True  # when True, models and batches are moved with .cuda()

In [212]:
# Full source (xs, ys) and target (xt, yt) sets; test() evaluates on all of
# them and train() batches from the source pair.
xs, ys = src_data['features'], src_data['labels']
xt, yt = tar_data['features'], tar_data['labels']
# Drop singleton dimensions from the labels — presumably stored as (N, 1) and
# needed as 1-D class indices; TODO confirm.
ys = ys.squeeze()
yt = yt.squeeze()
# YAML-based alternative, kept for reference.
# xs = source['train']
# ys = source['train_labels']
# xt = target['train']
# yt = target['train_labels']

In [213]:
class _Generator(nn.Module):
    """Feature extractor: four ``nn.Linear`` layers, each followed by ``act``.

    Layer widths run ``n_input -> n_hidden[0] -> n_hidden[1] -> n_hidden[2]
    -> n_hidden[3]``. The same ``act`` module instance is placed after every
    linear layer, and the finished stack is passed to ``utils.weight_init``.
    """

    def __init__(self, act=nn.ReLU(inplace=True)):
        super().__init__()
        widths = [n_input, n_hidden[0], n_hidden[1], n_hidden[2], n_hidden[3]]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(act)
        self.h1 = nn.Sequential(*stack)
        utils.weight_init(self.h1)

    def forward(self, x):
        """Return the output of the layer stack for input ``x``."""
        return self.h1(x)

class _Classifier(nn.Module):
    """Label predictor head: ``n_hidden[-1] -> 128 -> 64 -> num_class``.

    ``forward`` returns the raw output of the last linear layer (no softmax
    is applied inside this module).

    Background note carried over from the original author:
    ``nn.CrossEntropyLoss()`` stands in for TensorFlow's
    ``tf.nn.softmax_cross_entropy_with_logits()``. The two differ in their
    target format: the PyTorch loss expects class indices in ``[0, C-1]``
    rather than one-hot vectors, and it uses LogSoftmax instead of a plain
    Softmax.
    """

    def __init__(self, act=nn.ReLU(inplace=True)):
        super().__init__()
        head = [
            nn.Linear(n_hidden[-1], 128), act,
            nn.Linear(128, 64), act,
            nn.Linear(64, num_class),
        ]
        self.h = nn.Sequential(*head)
        utils.weight_init(self.h)

    def forward(self, x):
        """Return the class scores for hidden features ``x``."""
        return self.h(x)

class _Critic(nn.Module):
    """Critic head: ``n_hidden[-1] -> 128 -> 64 -> 1``.

    Produces one scalar per input row; ``act`` follows the first two linear
    layers and the stack is passed to ``utils.weight_init``.
    """

    def __init__(self, act=nn.ReLU(inplace=True)):
        super().__init__()
        body = [
            nn.Linear(n_hidden[-1], 128), act,
            nn.Linear(128, 64), act,
            nn.Linear(64, 1),
        ]
        self.h1 = nn.Sequential(*body)
        utils.weight_init(self.h1)

    def forward(self, x):
        """Return the critic score for hidden features ``x``."""
        return self.h1(x)

class Model(object):
    """Plain container bundling the three sub-networks.

    Attributes set in ``__init__``:
        cr  -- a ``_Critic`` instance
        gen -- a ``_Generator`` instance
        cls -- a ``_Classifier`` instance

    This is not an ``nn.Module``; the notebook moves each part to the GPU and
    switches its train mode individually.
    """
    def __init__(self):
        # NOTE(review): construction order presumably determines which random
        # draws each network's initial weights receive — keep it stable if
        # runs must be comparable.
        self.cr = _Critic()
        self.gen = _Generator()
        self.cls = _Classifier()

In [214]:
# Build the three networks and, when use_gpu is set, move each of them to CUDA.
model = Model()
if use_gpu:
    model.cr = model.cr.cuda()
    model.gen = model.gen.cuda()
    model.cls = model.cls.cuda()

In [215]:
# cls_criterion = nn.CrossEntropyLoss()
# Classification loss and L2 penalty both come from the project-local `utils`
# module; softmax_ is applied to the logits inside test().
cls_criterion = utils.softmax_cross_entropy_with_logits(num_class)
l2_reg = utils.L2_Loss()
softmax_ = nn.Softmax(dim=1)

# Optimizer for the critic parameters.
# NOTE(review): wd_d_op is never stepped in the visible notebook.
params_wd = [
{'params': model.cr.parameters(), 'lr': lr_wd_D}
]
# wd_d_op = optim.SGD(params_wd, momentum=0.9)
wd_d_op = optim.Adam(params_wd)

# Generator and classifier are updated together by one SGD optimizer; each
# parameter group carries its own learning rate.
params = [
{'params': model.gen.parameters(), 'lr': lr},    
{'params': model.cls.parameters(), 'lr': lr}
]
train_op = optim.SGD(params, momentum=0.9)
# train_op = optim.Adam(params)
# Exponential decay with gamma=0.1, advanced by the lr_sch.step() calls in train().
lr_sch = optim.lr_scheduler.ExponentialLR(train_op, 0.1)
# lr_sch = optim.lr_scheduler.StepLR(train_op, 200, 0.1)

# NOTE(review): looks like a leftover smoke test of the local batch_generator —
# train() builds its own generator through utils.batch_generator and rebinds
# these names locally. Consider removing once confirmed unused.
S_batches = batch_generator(src_data, batch_size)
xs_batch, ys_batch = next(S_batches)

In [216]:
def _evaluate_domain(features, labels):
    """Run one gradient-free forward pass and return ``(loss, accuracy)``.

    Parameters
    ----------
    features, labels : numpy.ndarray
        Both are converted with ``torch.from_numpy``; features are cast to
        float, labels to long class indices.

    Returns
    -------
    (loss, accuracy)
        Whatever ``cls_criterion`` returns for the whole set, and the mean of
        the per-sample "prediction equals label" indicator as a tensor.
    """
    inputs = torch.from_numpy(features).float()
    targets = torch.from_numpy(labels).long()
    if use_gpu:
        inputs = inputs.cuda()
        targets = targets.cuda()

    # `Variable(..., volatile=True)` has had no effect since PyTorch 0.4, so
    # the old code built an autograd graph over the whole dataset on every
    # evaluation. torch.no_grad() is the supported way to disable tracking.
    with torch.no_grad():
        logits = model.cls(model.gen(inputs))
        loss = cls_criterion(logits, targets)
        _, pred_idx = torch.max(softmax_(logits), 1)
        accuracy = torch.eq(pred_idx, targets).float().mean()
    return loss, accuracy


def test(i, acc):
    """Evaluate the current model on the full source and target sets.

    Prints loss and accuracy for both domains and appends the accuracies
    (as 0-d numpy arrays) to ``acc['src']`` and ``acc['tar']``.

    Parameters
    ----------
    i : int
        Training step; only used in the printed header.
    acc : dict
        Must hold list values under the keys ``'src'`` and ``'tar'``; both
        lists are extended in place.
    """
    src_loss, src_acc = _evaluate_domain(xs, ys)
    print('step: ', i)
    print('source classifier loss: %f, source accuracy: %f' % (src_loss, src_acc))
    acc['src'].append(src_acc.detach().cpu().numpy())

    tar_loss, tar_acc = _evaluate_domain(xt, yt)
    print('target classifier loss: %f, target accuracy: %f' % (tar_loss, tar_acc))
    acc['tar'].append(tar_acc.detach().cpu().numpy())

In [217]:
acc = { 'src': [], 'tar': [] }

def train():
    """Train the generator and classifier on source-domain batches.

    Each of the ``num_steps`` iterations draws half a ``batch_size`` of source
    samples, minimises classification loss plus an L2 weight penalty with
    ``train_op``, advances ``lr_sch`` every 500 steps, and calls ``test()``
    every 10 steps (which appends to the module-level ``acc``).

    The Wasserstein term of the objective is currently disabled (see the
    commented-out ``total_loss`` line), so the critic is not trained here.
    """
    # put all network in training mode
    model.gen.train()
    model.cr.train()
    model.cls.train()

    # Floor division: `batch_size / 2` is a float under Python 3, and a batch
    # size has to be an integer count.
    S_batches = utils.batch_generator([xs, ys], batch_size // 2)

    for i in range(num_steps):
        xs_batch, ys_batch = next(S_batches)

        train_op.zero_grad()
        XB = torch.from_numpy(xs_batch).float()
        ys_batch = torch.from_numpy(ys_batch).long()
        if use_gpu:
            XB = XB.cuda()
            ys_batch = ys_batch.cuda()
        XB = Variable(XB)
        ys_batch = Variable(ys_batch, requires_grad=False)

        XB = model.gen(XB)
        pred_logit = model.cls(XB)
        clf_loss = cls_criterion(pred_logit, ys_batch)

        # Sum the scaled L2 penalty over every parameter whose name contains
        # 'weight', generator first and critic second (same order as before).
        # NOTE(review): the critic is penalised although train_op only updates
        # gen and cls, so that part never influences an update, while the
        # classifier weights are not penalised at all — confirm whether
        # model.cls was intended instead of model.cr.
        l2_loss = None
        for net in (model.gen, model.cr):
            for name, param in net.named_parameters():
                if 'weight' not in name:
                    continue
                term = l2_reg(param) * l2_param
                l2_loss = term if l2_loss is None else l2_loss + term

        # Guard against a model without any weight parameters instead of
        # failing on `clf_loss + None`.
        total_loss = clf_loss if l2_loss is None else clf_loss + l2_loss
        # total_loss = clf_loss + wd_param * wd_loss
        total_loss.backward()
        train_op.step()

        # NOTE(review): this also fires at i == 0; whether that first call
        # already decays the learning rate depends on the installed PyTorch's
        # scheduler semantics — verify the effective schedule.
        if i % 500 == 0 and lr_sch is not None:
            lr_sch.step()

        if i % 10 == 0:
            test(i, acc)


In [218]:
# Run the training loop; test() prints source/target metrics every 10 steps.
train()

step:  0
source classifier loss: 5.084923, source accuracy: 0.045438
target classifier loss: 4.674077, target accuracy: 0.036478
step:  10
source classifier loss: 3.736034, source accuracy: 0.055023
target classifier loss: 3.724208, target accuracy: 0.033962
step:  20
source classifier loss: 3.212682, source accuracy: 0.138445
target classifier loss: 3.343492, target accuracy: 0.085535
step:  30
source classifier loss: 3.019206, source accuracy: 0.208378
target classifier loss: 3.331100, target accuracy: 0.104403
step:  40
source classifier loss: 3.012053, source accuracy: 0.166134
target classifier loss: 3.217976, target accuracy: 0.118239
step:  50
source classifier loss: 2.786849, source accuracy: 0.272985
target classifier loss: 3.118044, target accuracy: 0.134591
step:  60
source classifier loss: 2.523650, source accuracy: 0.340078
target classifier loss: 2.946956, target accuracy: 0.176101
step:  70
source classifier loss: 2.482185, source accuracy: 0.315229
target classifier los

step:  630
source classifier loss: 0.517450, source accuracy: 0.861910
target classifier loss: 2.109501, target accuracy: 0.476730
step:  640
source classifier loss: 0.511075, source accuracy: 0.862975
target classifier loss: 2.101278, target accuracy: 0.471698
step:  650
source classifier loss: 0.505335, source accuracy: 0.866880
target classifier loss: 2.125921, target accuracy: 0.474214
step:  660
source classifier loss: 0.498204, source accuracy: 0.870785
target classifier loss: 2.132320, target accuracy: 0.471698
step:  670
source classifier loss: 0.495064, source accuracy: 0.865460
target classifier loss: 2.135585, target accuracy: 0.471698
step:  680
source classifier loss: 0.488346, source accuracy: 0.867590
target classifier loss: 2.148796, target accuracy: 0.479245
step:  690
source classifier loss: 0.488774, source accuracy: 0.867235
target classifier loss: 2.151665, target accuracy: 0.470440
step:  700
source classifier loss: 0.481859, source accuracy: 0.870785
target class

step:  1260
source classifier loss: 0.373926, source accuracy: 0.910898
target classifier loss: 2.164113, target accuracy: 0.457862
step:  1270
source classifier loss: 0.373751, source accuracy: 0.911963
target classifier loss: 2.165433, target accuracy: 0.456604
step:  1280
source classifier loss: 0.373581, source accuracy: 0.912673
target classifier loss: 2.165725, target accuracy: 0.455346
step:  1290
source classifier loss: 0.373167, source accuracy: 0.911608
target classifier loss: 2.166076, target accuracy: 0.457862
step:  1300
source classifier loss: 0.372819, source accuracy: 0.910898
target classifier loss: 2.165650, target accuracy: 0.456604
step:  1310
source classifier loss: 0.372504, source accuracy: 0.910898
target classifier loss: 2.169392, target accuracy: 0.455346
step:  1320
source classifier loss: 0.372125, source accuracy: 0.911253
target classifier loss: 2.170741, target accuracy: 0.454088
step:  1330
source classifier loss: 0.371831, source accuracy: 0.911608
targ

# Plot the accuracies collected in `acc` (one point per test() call) on the
# upper half of the figure, with the legend placed underneath the axes.
fig = plt.figure()
ax1 = fig.add_subplot(211)
fig.subplots_adjust(bottom=0.2)

line1 = ax1.plot(acc['src'], 'b', label='Source Accuracy')
line2 = ax1.plot(acc['tar'], 'g', label='Target Accuracy')
ax1.set_ylim(0, 1)

# Gather the line handles so that both curves share a single legend.
lines = [*line1, *line2]
labels = [handle.get_label() for handle in lines]
ax1.legend(lines, labels, loc=(0, -0.4), ncol=2)
plt.show()