In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import numpy as np
import math
import torch.backends.cudnn as cudnn
import copy
from torch.autograd import Function
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
import os

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 256
image_size = 28*28
alpha = 0.005
DANN_EPOCHES = 50
DANN_TRAINING_BATCH = 40
dann_path = 'dann.pth'

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

# Split training set for training and validation
train_size = int(0.8 * len(train_set))
val_size = len(train_set) - train_size
train_set, val_set = random_split(train_set, [train_size, val_size])

# DataLoader for validation set
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, drop_last = True)
train_loader =  DataLoader(train_set, batch_size=256, shuffle=False, drop_last = True)
test_loader =  DataLoader(test_set, batch_size=256, shuffle=False, drop_last = True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 85835463.70it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


# DANN

In [4]:
class ReverseLayerF(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha

        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha

        return output, None
    
class ConvAutoencoder(nn.Module):
    def __init__(self):
        super(ConvAutoencoder, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1), 
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 7)
        )
        
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 7),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


class VGG(nn.Module):
    def __init__(self, out_dim = 10):
        super(VGG, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )

        self.classifier = nn.Sequential(
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, out_dim),
            nn.Softmax(dim=1)
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                m.bias.data.zero_()

    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 512)
        x = self.classifier(x)
        return x

In [5]:
class DANN(nn.Module):
    def __init__(self):
        super(DANN, self).__init__()
        self.autoencoder = ConvAutoencoder()
        self.classifier = VGG(out_dim = 10)
        self.domain_classifier = VGG(out_dim = 2)

    def forward(self, input_data, alpha):
        feature = self.autoencoder(input_data)
        reverse_feature = ReverseLayerF.apply(feature, alpha)
        class_output = self.classifier(feature)
        domain_output = self.domain_classifier(reverse_feature)
        
        return class_output, domain_output

In [6]:
def train_DANN(train_loader, model, criterion, optimizer, epoches):
    model.train()
    src_domain_label = torch.ones(batch_size).long().to(device)
    tgt_domain_label = torch.zeros(batch_size).long().to(device)

    for e in range(epoches):
        data_target_iter = iter(train_loader)
        correct_source_domain, correct_tgt_domain = 0, 0
        total = 0
        for i in range(DANN_TRAINING_BATCH):
            # Src
            source, source_label = next(data_target_iter)
            total += source.size(0)

            source, source_label = source.to(device), source_label.to(device)

            class_output, domain_output = model(source, alpha)

            loss_s_label = criterion(class_output, source_label)
            loss_s_domain = criterion(domain_output, src_domain_label)

            _, predicted = torch.max(domain_output.data, 1)
            correct_source_domain += predicted.eq(src_domain_label.data).cpu().sum().item()

            # Tgt
            target, target_label  = next(data_target_iter)
            target, target_label = target.to(device), target_label.to(device)

            class_output, domain_output = model(target, alpha)

            loss_t_label = criterion(class_output, target_label)
            loss_t_domain = criterion(domain_output, tgt_domain_label)

            _, predicted = torch.max(domain_output.data, 1)
            correct_tgt_domain += predicted.eq(tgt_domain_label.data).cpu().sum().item()

            loss = loss_s_label + loss_s_domain + loss_t_domain + loss_t_label
            loss.backward()
            optimizer.step()
        if ((e + 1) % 5 == 0):
            print(f"{e}: source correct: {correct_source_domain/total}, \
                        target correct: {correct_tgt_domain/total}")

In [7]:
if not os.path.exists(dann_path):
    criterion = nn.CrossEntropyLoss()
    dann = DANN().to(device)
    optimizer = optim.Adam(dann.parameters(), lr=0.001)

    train_DANN(train_loader, dann, criterion, optimizer, DANN_EPOCHES)
    with torch.no_grad():
        torch.save(dann.state_dict(), dann_path)
else:
    dann = DANN().to(device)
    dann.load_state_dict(torch.load(dann_path))
    print("Loaded model from file.")

4: source correct: 0.94189453125,                         target correct: 0.0453125
9: source correct: 0.4751953125,                         target correct: 0.519921875
14: source correct: 0.33388671875,                         target correct: 0.69306640625
19: source correct: 0.89814453125,                         target correct: 0.104296875
24: source correct: 0.03935546875,                         target correct: 0.9677734375
29: source correct: 0.990625,                         target correct: 0.0091796875
34: source correct: 0.5033203125,                         target correct: 0.4974609375
39: source correct: 0.02919921875,                         target correct: 0.9734375
44: source correct: 0.983203125,                         target correct: 0.016015625
49: source correct: 0.01865234375,                         target correct: 0.98388671875


# Classifier

In [None]:
def test(model, testloader):
    correct = 0
    total = 0

    for images, label in testloader:
        images, label = images.to(device), label.to(device)
        features = dann.autoencoder(images)
        outputs = model(features)
        _, predicted = torch.max(outputs.data, 1)
        total += label.size(0)
        correct += (predicted == label).sum().item()

    return correct / total

In [None]:
def train(model, epoches, criterion, optimizer):
    best_model_wts = None
    best_loss = float('inf')
    batch_num = 0
    warm_up_batch = 3

    for inputs, labels in train_loader:
        if (best_model_wts):
            model.load_state_dict(best_model_wts)

        inputs, labels = inputs.to(device), labels.to(device)
        prev_loss = float('inf')
        for epoch in range(epoches):
            model.train()
            optimizer.zero_grad()
            features = dann.autoencoder(inputs)
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for val_inputs, val_labels in val_loader:
                    val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
                    outputs = model(val_inputs)
                    batch_loss = criterion(outputs, val_labels)
                    val_loss += batch_loss.item()

                if val_loss < best_loss:
                    best_loss = val_loss
                    best_model_wts = copy.deepcopy(model.state_dict())

            print(f"Batch: {batch_num}, epoch: {epoch}, Train Loss: {loss.item()}, Val Loss: {val_loss}")
            if (prev_loss < val_loss and warm_up_batch < batch_num):
                break
            prev_loss = val_loss
        with torch.no_grad():
            test_acc = test(model, test_loader)
            
        print(f"epoch: {epoch}, Test Acc: {test_acc}")

        batch_num += 1

    return best_model_wts

In [None]:
model = VGG().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)

best_model_wts = train(model, 20, criterion, optimizer)