In [6]:

from torch.autograd import Function
import load_mnist_data

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import numpy as np
import torch.backends.cudnn as cudnn



In [None]:
class ReverseLayerF(Function):
    """Autograd function that is the identity going forward and flips the
    sign of the gradient (scaled by ``alpha``) going backward."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Keep the scale on the context; backward() reads it from there.
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = -grad_output * ctx.alpha
        # One entry per forward() input: a gradient for x, none for alpha.
        return reversed_grad, None

In [7]:
class DANN(nn.Module):
    """Convolutional network with a shared feature extractor and two heads.

    ``feature`` maps a 3x28x28 image to a 50x4x4 map
    (28 -> conv5 -> 24 -> pool -> 12 -> conv5 -> 8 -> pool -> 4).
    ``classifier`` turns the flattened 800-dim feature into 10 class logits.
    ``domain_classifier`` turns the same feature, after it has passed through
    ``ReverseLayerF``, into 2 domain logits.
    """

    def __init__(self):
        super(DANN, self).__init__()
        self.feature = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2),
            nn.ReLU(True),
            nn.Conv2d(64, 50, kernel_size=5),
            nn.BatchNorm2d(50),
            nn.Dropout2d(),
            nn.MaxPool2d(2),
            nn.ReLU(True),
        )

        # NOTE: the previous unused AdaptiveAvgPool2d((5, 5)) attribute was
        # removed: forward() never called it and a 5x5 output would not match
        # the 50 * 4 * 4 input size both heads are built for.
        self.classifier = nn.Sequential(
            nn.Linear(50 * 4 * 4, 100),
            nn.BatchNorm1d(100),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(100, 100),
            nn.BatchNorm1d(100),
            nn.ReLU(True),
            nn.Linear(100, 10),
        )

        self.domain_classifier = nn.Sequential(
            nn.Linear(50 * 4 * 4, 100),
            nn.BatchNorm1d(100),
            nn.ReLU(True),
            nn.Linear(100, 2),
        )

    def forward(self, input_data, alpha):
        """Compute class and domain logits for a batch of images.

        Args:
            input_data: batch of images; it is expanded (not copied) to
                ``(N, 3, 28, 28)``, so single-channel 28x28 input is accepted.
            alpha: scale handed to ``ReverseLayerF.apply`` for the feature
                that feeds the domain head.

        Returns:
            ``(class_output, domain_output)`` with shapes ``(N, 10)`` and
            ``(N, 2)``.
        """
        batch = input_data.size(0)
        input_data = input_data.expand(batch, 3, 28, 28)
        feature = self.feature(input_data)
        # Flatten per sample; keeps the batch dimension explicit.
        feature = feature.view(batch, 50 * 4 * 4)
        reverse_feature = ReverseLayerF.apply(feature, alpha)
        class_output = self.classifier(feature)
        domain_output = self.domain_classifier(reverse_feature)

        return class_output, domain_output


In [9]:
def _count_correct(logits, labels):
    """Return how many rows of ``logits`` have their arg-max equal to ``labels``."""
    _, predicted = torch.max(logits.detach(), 1)
    return predicted.eq(labels).sum().item()


def train(source, target, net, criterion, optimizer, epoch, use_cuda=True, total_epochs=200):
    """Run one epoch of domain-adversarial training.

    Each step draws one batch from ``source`` and one from ``target`` and
    optimises: source label loss + source domain loss + target domain loss.
    Target labels are used only for the accuracy counters, never for the loss.

    Args:
        source: sized iterable of ``(inputs, labels)`` batches (domain 0).
        target: sized iterable of ``(inputs, labels)`` batches (domain 1).
        net: module called as ``net(inputs, alpha)`` returning
            ``(class_output, domain_output)``.
        criterion: loss applied to both the class and the domain logits.
        optimizer: optimizer stepping ``net``'s parameters.
        epoch: zero-based epoch index, used by the ``alpha`` schedule.
        use_cuda: move batches to the GPU; ignored when CUDA is unavailable.
        total_epochs: number of epochs the ``alpha`` schedule spans.

    Returns:
        ``(correct_source_label, correct_source_domain, correct_target_label,
        correct_target_domain, total)`` where ``total`` is the number of
        *source* samples seen. Target counters share that denominator only
        when both loaders use the same batch sizes.
    """
    net.train()  # Sets the module in training mode.

    # Only touch the GPU when one actually exists, so the default
    # use_cuda=True does not crash on CPU-only machines.
    use_cuda = use_cuda and torch.cuda.is_available()

    correct_source_label = 0
    correct_source_domain = 0
    correct_target_label = 0
    correct_target_domain = 0
    total = 0

    # Stop at the shorter loader so neither iterator is over-drawn.
    len_dataloader = min(len(source), len(target))
    source_iter = iter(source)
    target_iter = iter(target)

    for batch_idx in range(len_dataloader):
        # Training progress in [0, 1] drives alpha smoothly from 0 towards 1.
        p = float(batch_idx + epoch * len_dataloader) / (total_epochs * len_dataloader)
        alpha = 2. / (1. + np.exp(-10 * p)) - 1

        # Feed source image to the network
        inputs, source_label = next(source_iter)
        source_label = source_label.long()
        source_domain = torch.zeros(inputs.size(0)).long()
        if use_cuda:
            inputs = inputs.cuda()
            source_label = source_label.cuda()
            source_domain = source_domain.cuda()
        total += inputs.size(0)

        optimizer.zero_grad()

        class_output, domain_output = net(inputs, alpha)
        correct_source_label += _count_correct(class_output, source_label)
        correct_source_domain += _count_correct(domain_output, source_domain)

        loss_s_label = criterion(class_output, source_label)
        loss_s_domain = criterion(domain_output, source_domain)

        # Feed target image to the network
        target_inputs, target_label = next(target_iter)
        # Size the domain labels from the target batch itself: it may differ
        # from the source batch size (e.g. a final partial batch).
        target_domain = torch.ones(target_inputs.size(0)).long()
        if use_cuda:
            target_inputs = target_inputs.cuda()
            target_label = target_label.cuda()
            target_domain = target_domain.cuda()

        class_output, domain_output = net(target_inputs, alpha)
        loss_t_domain = criterion(domain_output, target_domain)

        correct_target_label += _count_correct(class_output, target_label)
        correct_target_domain += _count_correct(domain_output, target_domain)

        loss = loss_s_label + loss_s_domain + loss_t_domain
        loss.backward()
        optimizer.step()

    return correct_source_label, correct_source_domain, correct_target_label, correct_target_domain, total


In [10]:
# Build the source- and target-domain loaders that train() iterates over.
# NOTE(review): load_mnist_data is not visible here (presumably a project-local
# module); confirm what the 1.0 argument to get_data_loader controls.
loader_source, loader_target = load_mnist_data.get_data_loader(1.0)

In [11]:
# Seed the torch RNG *before* the model is built: DANN() draws its initial
# weights at construction time, so seeding afterwards (or only the CUDA RNG)
# leaves the initialisation unreproducible.
torch.manual_seed(42)

use_cuda = torch.cuda.is_available()

criterion = nn.CrossEntropyLoss()

net = DANN()
if use_cuda:
    torch.cuda.manual_seed_all(42)
    # Lets cuDNN auto-tune kernels; faster, but may pick non-deterministic ones.
    cudnn.benchmark = True
    net.cuda()
    criterion = criterion.cuda()

optimizer = optim.Adam(net.parameters(), lr=0.001)

for epoch in range(0, 201):
    # Pass the detected device flag explicitly so CPU-only runs do not try to
    # move batches to a GPU that is not there.
    sl, sd, tl, td, total = train(loader_source, loader_target, net, criterion, optimizer, epoch, use_cuda=use_cuda)

    # Report source/target label and domain accuracy every 10 epochs.
    if (epoch % 10 == 0):
        print("e: %d, sl: %f, sd: %f, tl: %f, td: %f" % (epoch, sl/total, sd/total, tl/total, td/total))

e: 0, sl: 0.825350, sd: 0.819050, tl: 0.357367, td: 0.786783
e: 5, sl: 0.945483, sd: 0.866517, tl: 0.489200, td: 0.862967
e: 10, sl: 0.947933, sd: 0.792900, tl: 0.541817, td: 0.788483
e: 15, sl: 0.948433, sd: 0.766067, tl: 0.580033, td: 0.759033
e: 20, sl: 0.950317, sd: 0.740800, tl: 0.602883, td: 0.732117
e: 25, sl: 0.946933, sd: 0.725217, tl: 0.623417, td: 0.716683
e: 30, sl: 0.948483, sd: 0.703617, tl: 0.622317, td: 0.701867
e: 35, sl: 0.947033, sd: 0.694283, tl: 0.636033, td: 0.690783
e: 40, sl: 0.947600, sd: 0.688467, tl: 0.650367, td: 0.681950
e: 45, sl: 0.947450, sd: 0.680067, tl: 0.657450, td: 0.674250
e: 50, sl: 0.946117, sd: 0.673850, tl: 0.667700, td: 0.670467
e: 55, sl: 0.947133, sd: 0.666150, tl: 0.679850, td: 0.659467
e: 60, sl: 0.947967, sd: 0.663733, tl: 0.679967, td: 0.659450
e: 65, sl: 0.947383, sd: 0.657833, tl: 0.683667, td: 0.654183
e: 70, sl: 0.947800, sd: 0.650567, tl: 0.687700, td: 0.648967
e: 75, sl: 0.948533, sd: 0.651950, tl: 0.689917, td: 0.649033
e: 80, sl: