In [1]:
import torch
import torchvision
import os.path

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from sklearn import manifold
from math import exp, sqrt
from torch.autograd import Variable
from my_dataset import MNIST_M
from my_dataset import ST_Dataset

%matplotlib inline
%load_ext skip_kernel_extension

import torch.utils.model_zoo as model_zoo

from data_loader import get_train_test_loader, get_office31_dataloader

use_gpu = torch.cuda.is_available()
print("use_gpu = " + str(use_gpu))

def reset_seq(seq):
    for m in seq:
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, sqrt(2 / n))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            size = m.weight.size()
            fan_out = size[0] # number of rows
            fan_in = size[1] # number of columns
            m.weight.data.normal_(0, sqrt(2 / (fan_in + fan_out)))
            m.bias.data.zero_()
        elif hasattr(m, 'reset_parameters'):
            m.reset_parameters()
            
def evaluate_da_accuracy(model, dataloader, source):
    model.eval()
    correct_LC = 0
    correct_DC = 0
    total = 0
    
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            inputs, labels = data
            if (use_gpu):
                inputs, labels = inputs.cuda(), labels.cuda()
            inputs, labels = Variable(inputs), Variable(labels)
            outputs_LC, outputs_DC = model(inputs)
            correct_LC += (torch.max(outputs_LC.data, 1)[1] == labels.data).sum().item()
            if source:
                correct_DC += labels.size()[0] - outputs_DC.data.sum().item()
            else:
                correct_DC += outputs_DC.data.sum().item()
            total += labels.size()[0]
        acc_LC = correct_LC / total
        acc_DC = correct_DC / total
    return acc_LC, acc_DC

use_gpu = True


In [2]:
trainloader_source = get_office31_dataloader("webcam", batch_size=79)

[INFO] Loading datasets: webcam


In [3]:
trainloader_target = get_office31_dataloader("dslr", batch_size=49)

[INFO] Loading datasets: dslr


In [4]:
class GRL_func(torch.autograd.Function):

    @staticmethod
    def forward(ctx, inputs, lamda):
        ctx.save_for_backward(lamda)
        return inputs

    @staticmethod
    def backward(ctx, grad_outputs):
        lamda, = ctx.saved_tensors
        return -lamda * grad_outputs, None

class GRL(nn.Module):
    
    def __init__(self, lamda_init):
        super(GRL, self).__init__()
        self.GRL_func = GRL_func.apply
        self.lamda = nn.Parameter(torch.Tensor(1), requires_grad=False)
        self.set_lamda(lamda_init)
        
    def forward(self, x):
        return self.GRL_func(x, self.lamda)
    
    def set_lamda(self, lamda_new):
        self.lamda[0] = lamda_new

In [5]:
class AlexNet_DA(nn.Module):

    def __init__(self, lamda_init):
        super(AlexNet_DA, self).__init__()
        # lambda
        self.lamda = lamda_init
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 256),
            nn.ReLU(inplace=True)
        )
        self.label_classifier = nn.Sequential(
            nn.Linear(256, 31)
        )
        self.GRL_layer = GRL(lamda_init)
        self.domain_classifier = nn.Sequential(
            nn.Linear(256, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        x_l = self.label_classifier(x)
        x_d = self.GRL_layer(x)
        x_d = self.domain_classifier(x_d)
        return x_l, x_d
    
    def set_lamda(self, lamda_new):
        self.GRL_layer.set_lamda(lamda_new)
    
    def load_pretrained_part(self, state_dict):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            if isinstance(param, nn.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            own_state[name].copy_(param)

In [6]:
cnn_da = AlexNet_DA(0)

model_url = "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth"

cnn_da.load_pretrained_part(model_zoo.load_url(model_url))

if (use_gpu):
    cnn_da.cuda()
print(cnn_da)

AlexNet_DA(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Dropout(p=0.5)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace)
    (3): Dropout(p=0.5)
    (4): Linear(in_features=4096, out_feat

In [7]:
for param in cnn_da.features.parameters():
    param.requires_grad = False

lr_init = 0.01
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(filter(lambda p: p.requires_grad, cnn_da.parameters()), lr=lr_init, momentum=0.9)

def adjust_lr(optimizer, p):
    global lr_init
    lr_0 = lr_init
    alpha = 10
    beta = 0.75
    lr = lr_0 / (1 + alpha * p) ** beta
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

In [8]:
prev_loss = np.float("inf")
total_epoch = 10
finetune_epoch_start = 20
#reset_seq(cnn_da.classifier)
reset_seq(cnn_da.label_classifier)

early_stop_cnt = 0

for epoch in range(total_epoch):
    epoch_loss = 0.0
    running_loss = 0.0
    p = epoch * 1.0 / total_epoch
    lr = adjust_lr(optimizer, p)
    if epoch == finetune_epoch_start:
        for param in cnn_da.features.parameters():
            param.requires_grad = True
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, cnn_da.parameters()), lr=lr, momentum=0.9)
    for i, data in enumerate(trainloader_source):
        inputs, labels = data
        if (use_gpu):
            inputs, labels = inputs.cuda(), labels.cuda()
        inputs, labels = Variable(inputs), Variable(labels)
        optimizer.zero_grad()
        outputs_LC, _ = cnn_da(inputs)
        loss = criterion(outputs_LC, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        running_loss += loss.item()
        if i % 2 == 1:    # print every 50 mini-batches
            print('[%2d] batch loss: %.3f' %
                  (i + 1, running_loss / 2))
            running_loss = 0.0
    print("epoch %d loss: %.3f -> %.3f\n" % (epoch + 1, prev_loss, epoch_loss))
    if prev_loss - epoch_loss < 0.1:
        prev_loss = epoch_loss
        if epoch > finetune_epoch_start:
            early_stop_cnt += 1
            if early_stop_cnt > 2:
                break
    else:
        early_stop_cnt = 0
        prev_loss = epoch_loss

[ 2] batch loss: 3.722
[ 4] batch loss: 3.051
[ 6] batch loss: 2.204
[ 8] batch loss: 1.496
[10] batch loss: 0.929
epoch 1 loss: inf -> 23.945

[ 2] batch loss: 0.792
[ 4] batch loss: 0.470
[ 6] batch loss: 0.454
[ 8] batch loss: 0.489
[10] batch loss: 0.342
epoch 2 loss: 23.945 -> 5.105

[ 2] batch loss: 0.279
[ 4] batch loss: 0.258
[ 6] batch loss: 0.249
[ 8] batch loss: 0.175
[10] batch loss: 0.201
epoch 3 loss: 5.105 -> 2.646

[ 2] batch loss: 0.162
[ 4] batch loss: 0.175
[ 6] batch loss: 0.159
[ 8] batch loss: 0.287
[10] batch loss: 0.155
epoch 4 loss: 2.646 -> 1.878

[ 2] batch loss: 0.085
[ 4] batch loss: 0.076
[ 6] batch loss: 0.085
[ 8] batch loss: 0.065
[10] batch loss: 0.108
epoch 5 loss: 1.878 -> 1.257

[ 2] batch loss: 0.079
[ 4] batch loss: 0.149
[ 6] batch loss: 0.241
[ 8] batch loss: 0.171
[10] batch loss: 0.185
epoch 6 loss: 1.257 -> 2.000

[ 2] batch loss: 0.053
[ 4] batch loss: 0.138
[ 6] batch loss: 0.108
[ 8] batch loss: 0.113
[10] batch loss: 0.196
epoch 7 loss: 2

In [9]:
evaluate_da_accuracy(cnn_da, trainloader_source, source=True)

(1.0, 0.5158779534153969)

In [10]:
evaluate_da_accuracy(cnn_da, trainloader_target, source=False)

(0.9678714859437751, 0.48369674701767273)

In [30]:
from math import exp

lr_init = 0.01
criterion_LC = nn.CrossEntropyLoss()
criterion_DC = nn.BCELoss()
optimizer = optim.SGD(filter(lambda p: p.requires_grad, cnn_da.parameters()), lr=lr_init, momentum=0.9)

def adjust_lr(optimizer, p):
    global lr_init
    lr_0 = lr_init
    alpha = 10
    beta = 0.75
    lr = lr_0 / (1 + alpha * p) ** beta
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

def adjust_lamda(model, p):
    gamma = 10
    lamda = 2 / (1 + exp(- gamma * p)) - 1
    model.set_lamda(lamda)
    return lamda

In [31]:
for param in cnn_da.features.parameters():
    param.requires_grad = True

In [32]:
prev_loss = np.float("inf")
prev_loss_LC = np.float("inf")
prev_loss_DC = np.float("inf")
total_epoch = 100
reset_seq(cnn_da.domain_classifier)

for epoch in range(total_epoch):
    epoch_loss = 0.0
    epoch_loss_LC = 0.0
    epoch_loss_DC = 0.0
    running_loss = 0.0
    p = epoch * 1.0 / total_epoch
    lr = adjust_lr(optimizer, p)
    dslr_iter = iter(trainloader_source)
    webcam_iter = iter(trainloader_target)
    i = 0
    while True:
        try:
            images_s, labels_s = dslr_iter.next()
            images_t, labels_t = webcam_iter.next()
        except:
            break
        inputs, labels = torch.cat((images_s, images_t)), torch.cat((labels_s, labels_t))
        source_size, target_size = labels_s.size(0), labels_t.size(0)
        domains = torch.cat((torch.zeros(source_size), torch.ones(target_size)))
        if (use_gpu):
            inputs, labels, domains = inputs.cuda(), labels.cuda(), domains.cuda()
        inputs, labels, domains = Variable(inputs), Variable(labels), Variable(domains)
        optimizer.zero_grad()
        outputs_LC, outputs_DC = cnn_da(inputs)
        loss_LC = criterion_LC(outputs_LC[:source_size], labels[:source_size])
        loss_DC = criterion_DC(outputs_DC.view(-1), domains)
        loss = loss_LC + loss_DC
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_loss_LC += loss_LC.item()
        epoch_loss_DC += loss_DC.item()
        running_loss += loss.item()
        if i % 2 == 1:    # print every 50 mini-batches
            print('[%2d] batch loss: %.3f' %
                  (i + 1, running_loss / 2))
            running_loss = 0.0
        i += 1
    print("epoch %d loss: %.3f -> %.3f" % (epoch + 1, prev_loss, epoch_loss))
    print("LC loss: %.3f -> %.3f" % (prev_loss_LC, epoch_loss_LC))
    print("DC loss: %.3f -> %.3f" % (prev_loss_DC, epoch_loss_DC))
    if epoch % 10 == 9:
        acc_l, acc_d = evaluate_da_accuracy(cnn_da, trainloader_target, source=False)
        print(acc_l, acc_d)
    print()
    prev_loss = epoch_loss
    prev_loss_LC = epoch_loss_LC
    prev_loss_DC = epoch_loss_DC

[ 2] batch loss: 1.038
[ 4] batch loss: 0.926
[ 6] batch loss: 1.008
[ 8] batch loss: 1.434
[10] batch loss: 2.048
epoch 1 loss: inf -> 13.695
LC loss: inf -> 6.024
DC loss: inf -> 7.671

[ 2] batch loss: 0.773
[ 4] batch loss: 1.270
[ 6] batch loss: 0.917
[ 8] batch loss: 1.120
[10] batch loss: 0.997
epoch 2 loss: 13.695 -> 10.838
LC loss: 6.024 -> 2.574
DC loss: 7.671 -> 8.264

[ 2] batch loss: 0.689
[ 4] batch loss: 0.762
[ 6] batch loss: 0.700
[ 8] batch loss: 0.771
[10] batch loss: 1.259
epoch 3 loss: 10.838 -> 9.065
LC loss: 2.574 -> 1.468
DC loss: 8.264 -> 7.596

[ 2] batch loss: 0.634
[ 4] batch loss: 0.742
[ 6] batch loss: 0.718
[ 8] batch loss: 1.108
[10] batch loss: 0.756
epoch 4 loss: 9.065 -> 8.946
LC loss: 1.468 -> 1.188
DC loss: 7.596 -> 7.758

[ 2] batch loss: 0.675
[ 4] batch loss: 0.704
[ 6] batch loss: 0.703
[ 8] batch loss: 0.677
[10] batch loss: 0.692
epoch 5 loss: 8.946 -> 7.549
LC loss: 1.188 -> 0.176
DC loss: 7.758 -> 7.373

[ 2] batch loss: 0.616
[ 4] batch los

[ 2] batch loss: 0.686
[ 4] batch loss: 0.553
[ 6] batch loss: 0.514
[ 8] batch loss: 0.570
[10] batch loss: 0.513
epoch 43 loss: 9.793 -> 6.028
LC loss: 0.012 -> 0.012
DC loss: 9.781 -> 6.016

[ 2] batch loss: 0.513
[ 4] batch loss: 0.499
[ 6] batch loss: 0.485
[ 8] batch loss: 0.497
[10] batch loss: 0.465
epoch 44 loss: 6.028 -> 5.159
LC loss: 0.012 -> 0.012
DC loss: 6.016 -> 5.147

[ 2] batch loss: 0.442
[ 4] batch loss: 0.480
[ 6] batch loss: 0.447
[ 8] batch loss: 0.461
[10] batch loss: 0.441
epoch 45 loss: 5.159 -> 4.694
LC loss: 0.012 -> 0.012
DC loss: 5.147 -> 4.682

[ 2] batch loss: 0.413
[ 4] batch loss: 0.441
[ 6] batch loss: 0.412
[ 8] batch loss: 0.420
[10] batch loss: 0.401
epoch 46 loss: 4.694 -> 4.271
LC loss: 0.012 -> 0.012
DC loss: 4.682 -> 4.260

[ 2] batch loss: 0.383
[ 4] batch loss: 0.408
[ 6] batch loss: 0.380
[ 8] batch loss: 0.382
[10] batch loss: 0.379
epoch 47 loss: 4.271 -> 3.941
LC loss: 0.012 -> 0.011
DC loss: 4.260 -> 3.930

[ 2] batch loss: 0.355
[ 4] ba

[10] batch loss: 0.615
epoch 84 loss: 6.198 -> 7.637
LC loss: 0.008 -> 0.008
DC loss: 6.190 -> 7.629

[ 2] batch loss: 0.317
[ 4] batch loss: 0.448
[ 6] batch loss: 0.478
[ 8] batch loss: 0.495
[10] batch loss: 0.310
epoch 85 loss: 7.637 -> 4.192
LC loss: 0.008 -> 0.008
DC loss: 7.629 -> 4.184

[ 2] batch loss: 0.263
[ 4] batch loss: 0.368
[ 6] batch loss: 0.286
[ 8] batch loss: 0.305
[10] batch loss: 0.272
epoch 86 loss: 4.192 -> 3.091
LC loss: 0.008 -> 0.008
DC loss: 4.184 -> 3.083

[ 2] batch loss: 0.249
[ 4] batch loss: 0.273
[ 6] batch loss: 0.233
[ 8] batch loss: 0.245
[10] batch loss: 0.255
epoch 87 loss: 3.091 -> 2.575
LC loss: 0.008 -> 0.008
DC loss: 3.083 -> 2.568

[ 2] batch loss: 0.228
[ 4] batch loss: 0.247
[ 6] batch loss: 0.208
[ 8] batch loss: 0.211
[10] batch loss: 0.230
epoch 88 loss: 2.575 -> 2.300
LC loss: 0.008 -> 0.008
DC loss: 2.568 -> 2.293

[ 2] batch loss: 0.197
[ 4] batch loss: 0.227
[ 6] batch loss: 0.185
[ 8] batch loss: 0.189
[10] batch loss: 0.215
epoch 8

In [33]:
#torch.save(cnn_da.state_dict(), "./parameters/cnn_webcam_to_dslr_0.989.pt")

In [39]:
#cnn_da.load_state_dict(torch.load("./parameters/cnn_dslr_to_webcam.pt"))