In [1]:
import torch
import torchvision
import os.path

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from sklearn import manifold
from math import exp, sqrt
from torch.autograd import Variable
from my_dataset import MNIST_M
from my_dataset import ST_Dataset

%matplotlib inline
%load_ext skip_kernel_extension

import torch.utils.model_zoo as model_zoo

from data_loader import get_train_test_loader, get_office31_dataloader

use_gpu = torch.cuda.is_available()
print("use_gpu = " + str(use_gpu))

def reset_seq(seq):
    for m in seq:
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, sqrt(2 / n))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            size = m.weight.size()
            fan_out = size[0] # number of rows
            fan_in = size[1] # number of columns
            m.weight.data.normal_(0, sqrt(2 / (fan_in + fan_out)))
            m.bias.data.zero_()
        elif hasattr(m, 'reset_parameters'):
            m.reset_parameters()
            
def evaluate_da_accuracy(model, dataloader, source):
    model.eval()
    correct_LC = 0
    correct_DC = 0
    total = 0
    
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            inputs, labels = data
            if (use_gpu):
                inputs, labels = inputs.cuda(), labels.cuda()
            inputs, labels = Variable(inputs), Variable(labels)
            outputs_LC, outputs_DC = model(inputs)
            correct_LC += (torch.max(outputs_LC.data, 1)[1] == labels.data).sum().item()
            if source:
                correct_DC += labels.size()[0] - outputs_DC.data.sum().item()
            else:
                correct_DC += outputs_DC.data.sum().item()
            total += labels.size()[0]
        acc_LC = correct_LC / total
        acc_DC = correct_DC / total
    return acc_LC, acc_DC

use_gpu = True


In [2]:
trainloader_dslr = get_office31_dataloader("dslr", batch_size=49)

[INFO] Loading datasets: dslr


In [3]:
trainloader_webcam = get_office31_dataloader("webcam", batch_size=79)

[INFO] Loading datasets: webcam


In [4]:
class GRL_func(torch.autograd.Function):

    @staticmethod
    def forward(ctx, inputs, lamda):
        ctx.save_for_backward(lamda)
        return inputs

    @staticmethod
    def backward(ctx, grad_outputs):
        lamda, = ctx.saved_tensors
        return -lamda * grad_outputs, None

class GRL(nn.Module):
    
    def __init__(self, lamda_init):
        super(GRL, self).__init__()
        self.GRL_func = GRL_func.apply
        self.lamda = nn.Parameter(torch.Tensor(1), requires_grad=False)
        self.set_lamda(lamda_init)
        
    def forward(self, x):
        return self.GRL_func(x, self.lamda)
    
    def set_lamda(self, lamda_new):
        self.lamda[0] = lamda_new

In [5]:
class AlexNet_DA(nn.Module):

    def __init__(self, lamda_init):
        super(AlexNet_DA, self).__init__()
        # lambda
        self.lamda = lamda_init
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 256),
            nn.ReLU(inplace=True)
        )
        self.label_classifier = nn.Sequential(
            nn.Linear(256, 31)
        )
        self.GRL_layer = GRL(lamda_init)
        self.domain_classifier = nn.Sequential(
            nn.Linear(256, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        x_l = self.label_classifier(x)
        x_d = self.GRL_layer(x)
        x_d = self.domain_classifier(x_d)
        return x_l, x_d
    
    def set_lamda(self, lamda_new):
        self.GRL_layer.set_lamda(lamda_new)
    
    def load_pretrained_part(self, state_dict):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            if isinstance(param, nn.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            own_state[name].copy_(param)

In [15]:
cnn_da = AlexNet_DA(0)

model_url = "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth"

cnn_da.load_pretrained_part(model_zoo.load_url(model_url))

if (use_gpu):
    cnn_da.cuda()
print(cnn_da)

AlexNet_DA(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Dropout(p=0.5)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace)
    (3): Dropout(p=0.5)
    (4): Linear(in_features=4096, out_feat

In [16]:
para_file_dslr = "./parameters/cnn_dslr_to_webcam_source_only.pt"
load_model_dslr = os.path.isfile(para_file_dslr)
print(load_model_dslr)

True


In [17]:
for param in cnn_da.features.parameters():
    param.requires_grad = False

lr_init = 0.01
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(filter(lambda p: p.requires_grad, cnn_da.parameters()), lr=lr_init, momentum=0.9)

def adjust_lr(optimizer, p):
    global lr_init
    lr_0 = lr_init
    alpha = 10
    beta = 0.75
    lr = lr_0 / (1 + alpha * p) ** beta
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

In [18]:
prev_loss = np.float("inf")
total_epoch = 10
finetune_epoch_start = 20
#reset_seq(cnn_da.classifier)
reset_seq(cnn_da.label_classifier)

early_stop_cnt = 0

for epoch in range(total_epoch):
    epoch_loss = 0.0
    running_loss = 0.0
    p = epoch * 1.0 / total_epoch
    lr = adjust_lr(optimizer, p)
    if epoch == finetune_epoch_start:
        for param in cnn_da.features.parameters():
            param.requires_grad = True
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, cnn_da.parameters()), lr=lr, momentum=0.9)
    for i, data in enumerate(trainloader_dslr):
        inputs, labels = data
        if (use_gpu):
            inputs, labels = inputs.cuda(), labels.cuda()
        inputs, labels = Variable(inputs), Variable(labels)
        optimizer.zero_grad()
        outputs_LC, _ = cnn_da(inputs)
        loss = criterion(outputs_LC, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        running_loss += loss.item()
        if i % 2 == 1:    # print every 50 mini-batches
            print('[%2d] batch loss: %.3f' %
                  (i + 1, running_loss / 2))
            running_loss = 0.0
    print("epoch %d loss: %.3f -> %.3f\n" % (epoch + 1, prev_loss, epoch_loss))
    if prev_loss - epoch_loss < 0.1:
        prev_loss = epoch_loss
        if epoch > finetune_epoch_start:
            early_stop_cnt += 1
            if early_stop_cnt > 2:
                break
    else:
        early_stop_cnt = 0
        prev_loss = epoch_loss

[ 2] batch loss: 3.625
[ 4] batch loss: 6.135
[ 6] batch loss: 4.907
[ 8] batch loss: 4.585
[10] batch loss: 4.508
epoch 1 loss: inf -> 51.622

[ 2] batch loss: 1.235
[ 4] batch loss: 2.368
[ 6] batch loss: 2.375
[ 8] batch loss: 2.524
[10] batch loss: 2.663
epoch 2 loss: 51.622 -> 23.105

[ 2] batch loss: 0.645
[ 4] batch loss: 1.399
[ 6] batch loss: 1.200
[ 8] batch loss: 1.656
[10] batch loss: 1.785
epoch 3 loss: 23.105 -> 13.486

[ 2] batch loss: 0.413
[ 4] batch loss: 0.664
[ 6] batch loss: 0.465
[ 8] batch loss: 0.869
[10] batch loss: 1.081
epoch 4 loss: 13.486 -> 6.997

[ 2] batch loss: 0.297
[ 4] batch loss: 0.418
[ 6] batch loss: 0.274
[ 8] batch loss: 0.479
[10] batch loss: 0.562
epoch 5 loss: 6.997 -> 4.069

[ 2] batch loss: 0.130
[ 4] batch loss: 0.269
[ 6] batch loss: 0.288
[ 8] batch loss: 0.231
[10] batch loss: 0.426
epoch 6 loss: 4.069 -> 2.692

[ 2] batch loss: 0.047
[ 4] batch loss: 0.160
[ 6] batch loss: 0.114
[ 8] batch loss: 0.265
[10] batch loss: 0.301
epoch 7 los

In [19]:
evaluate_da_accuracy(cnn_da, trainloader_dslr, source=True)

(1.0, 0.532371352475331)

In [20]:
evaluate_da_accuracy(cnn_da, trainloader_webcam, source=False)

(0.9144654088050315, 0.47331236773316965)

In [16]:
#torch.save(cnn_da.state_dict(), para_file_dslr)

In [8]:
#cnn_da.load_pretrained_part(torch.load(para_file_dslr))

In [21]:
from math import exp

lr_init = 0.01
criterion_LC = nn.CrossEntropyLoss()
criterion_DC = nn.BCELoss()
optimizer = optim.SGD(filter(lambda p: p.requires_grad, cnn_da.parameters()), lr=lr_init, momentum=0.9)

def adjust_lr(optimizer, p):
    global lr_init
    lr_0 = lr_init
    alpha = 10
    beta = 0.75
    lr = lr_0 / (1 + alpha * p) ** beta
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

def adjust_lamda(model, p):
    gamma = 10
    lamda = 2 / (1 + exp(- gamma * p)) - 1
    model.set_lamda(lamda)
    return lamda

In [22]:
for param in cnn_da.features.parameters():
    param.requires_grad = True

In [23]:
prev_loss = np.float("inf")
prev_loss_LC = np.float("inf")
prev_loss_DC = np.float("inf")
total_epoch = 100
reset_seq(cnn_da.domain_classifier)

for epoch in range(total_epoch):
    epoch_loss = 0.0
    epoch_loss_LC = 0.0
    epoch_loss_DC = 0.0
    running_loss = 0.0
    p = epoch * 1.0 / total_epoch
    lr = adjust_lr(optimizer, p)
    dslr_iter = iter(trainloader_dslr)
    webcam_iter = iter(trainloader_webcam)
    i = 0
    while True:
        try:
            images_s, labels_s = dslr_iter.next()
            images_t, labels_t = webcam_iter.next()
        except:
            break
        inputs, labels = torch.cat((images_s, images_t)), torch.cat((labels_s, labels_t))
        source_size, target_size = labels_s.size(0), labels_t.size(0)
        domains = torch.cat((torch.zeros(source_size), torch.ones(target_size)))
        if (use_gpu):
            inputs, labels, domains = inputs.cuda(), labels.cuda(), domains.cuda()
        inputs, labels, domains = Variable(inputs), Variable(labels), Variable(domains)
        optimizer.zero_grad()
        outputs_LC, outputs_DC = cnn_da(inputs)
        loss_LC = criterion_LC(outputs_LC[:source_size], labels[:source_size])
        loss_DC = criterion_DC(outputs_DC.view(-1), domains)
        loss = loss_LC + loss_DC
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_loss_LC += loss_LC.item()
        epoch_loss_DC += loss_DC.item()
        running_loss += loss.item()
        if i % 2 == 1:    # print every 50 mini-batches
            print('[%2d] batch loss: %.3f' %
                  (i + 1, running_loss / 2))
            running_loss = 0.0
        i += 1
    print("epoch %d loss: %.3f -> %.3f" % (epoch + 1, prev_loss, epoch_loss))
    print("LC loss: %.3f -> %.3f" % (prev_loss_LC, epoch_loss_LC))
    print("DC loss: %.3f -> %.3f" % (prev_loss_DC, epoch_loss_DC))
    if epoch % 10 == 9:
        acc_l, acc_d = evaluate_da_accuracy(cnn_da, trainloader_webcam, source=False)
        print(acc_l, acc_d)
    print()
    prev_loss = epoch_loss
    prev_loss_LC = epoch_loss_LC
    prev_loss_DC = epoch_loss_DC

[ 2] batch loss: 0.677
[ 4] batch loss: 0.714
[ 6] batch loss: 0.690
[ 8] batch loss: 0.745
[10] batch loss: 0.773
epoch 1 loss: inf -> 7.918
LC loss: inf -> 0.344
DC loss: inf -> 7.574

[ 2] batch loss: 0.836
[ 4] batch loss: 0.736
[ 6] batch loss: 0.721
[ 8] batch loss: 0.733
[10] batch loss: 0.724
epoch 2 loss: 7.918 -> 8.273
LC loss: 0.344 -> 0.195
DC loss: 7.574 -> 8.078

[ 2] batch loss: 0.650
[ 4] batch loss: 0.669
[ 6] batch loss: 0.631
[ 8] batch loss: 0.658
[10] batch loss: 0.696
epoch 3 loss: 8.273 -> 7.642
LC loss: 0.195 -> 0.207
DC loss: 8.078 -> 7.435

[ 2] batch loss: 0.619
[ 4] batch loss: 0.688
[ 6] batch loss: 0.683
[ 8] batch loss: 0.667
[10] batch loss: 0.905
epoch 4 loss: 7.642 -> 7.612
LC loss: 0.207 -> 0.709
DC loss: 7.435 -> 6.903

[ 2] batch loss: 0.583
[ 4] batch loss: 0.703
[ 6] batch loss: 0.655
[ 8] batch loss: 0.782
[10] batch loss: 2.174
epoch 5 loss: 7.612 -> 10.313
LC loss: 0.709 -> 3.552
DC loss: 6.903 -> 6.761

[ 2] batch loss: 0.621
[ 4] batch loss: 

[ 2] batch loss: 1.004
[ 4] batch loss: 1.774
[ 6] batch loss: 1.018
[ 8] batch loss: 0.858
[10] batch loss: 0.477
epoch 43 loss: 6.443 -> 10.400
LC loss: 0.012 -> 0.012
DC loss: 6.431 -> 10.388

[ 2] batch loss: 0.546
[ 4] batch loss: 0.580
[ 6] batch loss: 0.493
[ 8] batch loss: 0.510
[10] batch loss: 0.497
epoch 44 loss: 10.400 -> 5.439
LC loss: 0.012 -> 0.012
DC loss: 10.388 -> 5.427

[ 2] batch loss: 0.565
[ 4] batch loss: 0.539
[ 6] batch loss: 0.483
[ 8] batch loss: 0.502
[10] batch loss: 0.452
epoch 45 loss: 5.439 -> 5.227
LC loss: 0.012 -> 0.011
DC loss: 5.427 -> 5.216

[ 2] batch loss: 0.379
[ 4] batch loss: 0.525
[ 6] batch loss: 0.462
[ 8] batch loss: 0.473
[10] batch loss: 0.421
epoch 46 loss: 5.227 -> 4.638
LC loss: 0.011 -> 0.011
DC loss: 5.216 -> 4.626

[ 2] batch loss: 0.339
[ 4] batch loss: 0.501
[ 6] batch loss: 0.427
[ 8] batch loss: 0.446
[10] batch loss: 0.399
epoch 47 loss: 4.638 -> 4.328
LC loss: 0.011 -> 0.011
DC loss: 4.626 -> 4.317

[ 2] batch loss: 0.311
[ 4

[10] batch loss: 0.173
epoch 84 loss: 1.956 -> 1.850
LC loss: 0.007 -> 0.007
DC loss: 1.949 -> 1.842

[ 2] batch loss: 0.120
[ 4] batch loss: 0.241
[ 6] batch loss: 0.163
[ 8] batch loss: 0.185
[10] batch loss: 0.166
epoch 85 loss: 1.850 -> 1.755
LC loss: 0.007 -> 0.007
DC loss: 1.842 -> 1.747

[ 2] batch loss: 0.113
[ 4] batch loss: 0.231
[ 6] batch loss: 0.154
[ 8] batch loss: 0.175
[10] batch loss: 0.159
epoch 86 loss: 1.755 -> 1.670
LC loss: 0.007 -> 0.007
DC loss: 1.747 -> 1.663

[ 2] batch loss: 0.108
[ 4] batch loss: 0.222
[ 6] batch loss: 0.146
[ 8] batch loss: 0.166
[10] batch loss: 0.153
epoch 87 loss: 1.670 -> 1.595
LC loss: 0.007 -> 0.007
DC loss: 1.663 -> 1.588

[ 2] batch loss: 0.102
[ 4] batch loss: 0.214
[ 6] batch loss: 0.139
[ 8] batch loss: 0.158
[10] batch loss: 0.148
epoch 88 loss: 1.595 -> 1.527
LC loss: 0.007 -> 0.007
DC loss: 1.588 -> 1.520

[ 2] batch loss: 0.098
[ 4] batch loss: 0.207
[ 6] batch loss: 0.133
[ 8] batch loss: 0.151
[10] batch loss: 0.143
epoch 8

In [16]:
#torch.save(cnn_da.state_dict(), "./parameters/cnn_dslr_to_webcam_0.933.pt")

In [17]:
cnn_da.load_state_dict(torch.load("./parameters/cnn_dslr_to_webcam_0.934.pt"))

In [18]:
evaluate_da_accuracy(cnn_da, trainloader_webcam, source=False)

(0.9345911949685535, 0.6322475808221589)