In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.autograd import Function

In [2]:
# Hyperparameters
latent_dim = 100
lr = 0.0002
batch_size = 256
image_size = 28*28

In [22]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=256, shuffle=False, drop_last = True)

# Baseline

In [31]:
class Linear(nn.Module):
    def __init__(self):
        super(Linear, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear (28 * 28 * 1, 1000),
            nn.ReLU(inplace=True),
            nn.Linear (1000, 100),
            nn.ReLU(inplace=True),
            nn.Linear (100, 10),
        )
    def forward(self, x):
        x = x.view(-1, 28 * 28 * 1)
        x = self.classifier(x)

        return x

In [None]:
def test_classification(testloader, model):
    correct = 0
    total = 0
    with torch.no_grad():
        for a, b in testloader:
            a, b = a.to(device), b.to(device)
            outputs = model(a)
            _, predicted = torch.max(outputs.data, 1)
            total += b.size(0)
            correct += (predicted == b).sum().item()
    return correct/total

In [32]:
base = Linear().to(device)
base.train()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(base.parameters(), lr=0.005)

data_iter = iter(train_loader)
images, labels = next(data_iter)
images, labels = images.to(device), labels.to(device)
for epoch in range(15):
    optimizer.zero_grad()
    outputs = base(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        torch.save(base.state_dict(), f'Basemodel_batch_A_{epoch}.pth')

    batch_acc = test_classification(testloader, base)

    print(f"Epoch [{epoch+1}/15], Loss: {loss.item():.4f}, test_acc: {batch_acc}")

images, labels = next(data_iter)
images, labels = images.to(device), labels.to(device)
for e in range(15):
    with torch.no_grad():
        load_model = Linear()
        load_model.load_state_dict(torch.load(f'Basemodel_batch_A_{e}.pth'))
        load_model = load_model.to(device)
        outputs = load_model(images)
        loss = criterion(outputs, labels)

        batch_acc = test_classification(testloader, load_model)

        print(f"Epoch [{e+1}/15], Loss: {loss.item():.4f}, test_acc: {batch_acc}")

Epoch [1/30], Loss: 2.3055, test_acc: 0.2936698717948718
Epoch [2/30], Loss: 2.6849, test_acc: 0.09805689102564102
Epoch [3/30], Loss: 4.7147, test_acc: 0.10116185897435898
Epoch [4/30], Loss: 2.8637, test_acc: 0.21764823717948717
Epoch [5/30], Loss: 2.7097, test_acc: 0.16075721153846154
Epoch [6/30], Loss: 2.8161, test_acc: 0.37159455128205127
Epoch [7/30], Loss: 2.5747, test_acc: 0.21424278846153846
Epoch [8/30], Loss: 2.3344, test_acc: 0.2364783653846154
Epoch [9/30], Loss: 2.0577, test_acc: 0.296474358974359
Epoch [10/30], Loss: 1.8395, test_acc: 0.35466746794871795
Epoch [11/30], Loss: 1.7064, test_acc: 0.5071113782051282
Epoch [12/30], Loss: 1.6011, test_acc: 0.5448717948717948
Epoch [13/30], Loss: 1.4855, test_acc: 0.5503806089743589
Epoch [14/30], Loss: 1.3465, test_acc: 0.547676282051282
Epoch [15/30], Loss: 1.2077, test_acc: 0.5735176282051282


In [4]:
class ReverseLayerF(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha

        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha

        return output, None

class DANN(nn.Module):
    def __init__(self):
        super(DANN, self).__init__()
        self.feature = nn.Sequential(
                    nn.Conv2d(1, 64, kernel_size=5),
                    nn.BatchNorm2d(64),
                    nn.MaxPool2d(2),
                    nn.ReLU(True),
                    nn.Conv2d(64, 50, kernel_size=5),
                    nn.BatchNorm2d(50),
                    nn.Dropout2d(),
                    nn.MaxPool2d(2),
                    nn.ReLU(True)
                )

        self.classifier = nn.Sequential(
                    nn.Linear(50 * 4 * 4, 100),
                    nn.BatchNorm1d(100),
                    nn.ReLU(True),
                    nn.Dropout(),
                    nn.Linear(100, 100),
                    nn.BatchNorm1d(100),
                    nn.ReLU(True),
                    nn.Linear(100, 10),
                )

        self.domain_classifier = nn.Sequential(
                    nn.Linear(50 * 4 * 4, 100),
                    nn.BatchNorm1d(100),
                    nn.ReLU(True),
                    nn.Linear(100, 2),
                )
    def forward(self, input_data, alpha):
        input_data = input_data.expand(input_data.data.shape[0], 1, 28, 28)
        feature = self.feature(input_data)
        feature = feature.view(-1, 50 * 4 * 4)
        reverse_feature = ReverseLayerF.apply(feature, alpha)
        class_output = self.classifier(feature)
        domain_output = self.domain_classifier(reverse_feature)

        return class_output, domain_output

In [5]:
def train_DANN(train_loader, model, criterion, optimizer, epoches):
    model.train()
    src_domain_label = torch.ones(batch_size).long().to(device)
    tgt_domain_label = torch.zeros(batch_size).long().to(device)

    alpha = 0.005

    for e in range(epoches):
        data_target_iter = iter(train_loader)
        correct_source_domain, correct_tgt_domain = 0, 0
        total = 0
        for i in range(40):
            # Src
            source, source_label = next(data_target_iter)
            total += source.size(0)

            source, source_label = source.to(device), source_label.to(device)

            class_output, domain_output = model(source, alpha)

            loss_s_label = criterion(class_output, source_label)
            loss_s_domain = criterion(domain_output, src_domain_label)

            _, predicted = torch.max(domain_output.data, 1)
            correct_source_domain += predicted.eq(src_domain_label.data).cpu().sum().item()

            # Tgt
            target, target_label  = next(data_target_iter)
            target, target_label = target.to(device), target_label.to(device)

            class_output, domain_output = model(target, alpha)

            loss_t_label = criterion(class_output, target_label)
            loss_t_domain = criterion(domain_output, tgt_domain_label)

            _, predicted = torch.max(domain_output.data, 1)
            correct_tgt_domain += predicted.eq(tgt_domain_label.data).cpu().sum().item()

            loss = loss_s_label + loss_s_domain + loss_t_domain + loss_t_label
            loss.backward()
            optimizer.step()
        if ((e + 1) % 5 == 0):
            print(f"{e}: source correct: {correct_source_domain/total}, target correct: {correct_tgt_domain/total}")

In [6]:
criterion = nn.CrossEntropyLoss()
dann = DANN().to(device)
optimizer = optim.Adam(dann.parameters(), lr=0.001)

train_DANN(train_loader, dann, criterion, optimizer, 50)

0: source correct: 0.41337890625, target correct: 0.5865234375
1: source correct: 0.80703125, target correct: 0.19306640625
2: source correct: 0.1201171875, target correct: 0.8833984375
3: source correct: 0.58515625, target correct: 0.415234375
4: source correct: 0.8009765625, target correct: 0.19931640625
5: source correct: 0.29814453125, target correct: 0.70107421875
6: source correct: 0.0919921875, target correct: 0.91083984375
7: source correct: 0.6775390625, target correct: 0.325390625
8: source correct: 0.84013671875, target correct: 0.1609375
9: source correct: 0.9439453125, target correct: 0.05595703125
10: source correct: 0.989453125, target correct: 0.01162109375
11: source correct: 0.77373046875, target correct: 0.2275390625
12: source correct: 0.0857421875, target correct: 0.91240234375
13: source correct: 0.01416015625, target correct: 0.98662109375
14: source correct: 0.00078125, target correct: 0.99912109375
15: source correct: 0.0, target correct: 1.0
16: source correct

In [35]:
class Linear(nn.Module):
    def __init__(self):
        super(Linear, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear (50 * 4 * 4, 1000),
            nn.ReLU(inplace=True),
            nn.Linear (1000, 100),
            nn.ReLU(inplace=True),
            nn.Linear (100, 10),
        )
    def forward(self, x):
        x = x.view(-1, 50 * 4 * 4)
        x = self.classifier(x)

        return x

In [41]:
leader = Linear().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(leader.parameters(), lr=0.005)

In [None]:
def test(testloader, dann, model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for a, b in testloader:
            a, b = a.to(device), b.to(device)
            feature = dann.feature(a)
            outputs = leader(feature)
            _, predicted = torch.max(outputs.data, 1)
            total += b.size(0)
            correct += (predicted == b).sum().item()
    return correct/total

In [42]:
data_iter = iter(train_loader)
images, labels = next(data_iter)
images, labels = images.to(device), labels.to(device)

for epoch in range(15):
    leader.train()
    optimizer.zero_grad()
    with torch.no_grad():
        feature = dann.feature(images)

    outputs = leader(feature)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        torch.save(leader.state_dict(), f'model_batch_A_{epoch}.pth')

    batch_acc = test(testloader, dann, leader)

    print(f"Epoch [{epoch+1}/15], Loss: {loss.item():.4f}, test_acc: {batch_acc}")

images, labels = next(data_iter)
images, labels = images.to(device), labels.to(device)
for e in range(15):
    with torch.no_grad():
        feature = dann.feature(images)
        new_model = Linear()
        new_model.load_state_dict(torch.load(f'model_batch_A_{e}.pth'))
        new_model = new_model.to(device)
        outputs = new_model(feature)
        loss = criterion(outputs, labels)

        batch_acc = test(testloader, dann, new_model)

        print(f"Epoch [{e+1}/15], Loss: {loss.item():.4f}, test_acc: {batch_acc}")

Epoch [1/30], Loss: 2.3496, test_acc: 0.3816105769230769
Epoch [2/30], Loss: 1.4788, test_acc: 0.6490384615384616
Epoch [3/30], Loss: 2.3385, test_acc: 0.5740184294871795
Epoch [4/30], Loss: 1.9101, test_acc: 0.7028245192307693
Epoch [5/30], Loss: 0.9868, test_acc: 0.7924679487179487
Epoch [6/30], Loss: 0.5209, test_acc: 0.9236778846153846
Epoch [7/30], Loss: 0.3131, test_acc: 0.909354967948718
Epoch [8/30], Loss: 0.2191, test_acc: 0.8883213141025641
Epoch [9/30], Loss: 0.2190, test_acc: 0.8960336538461539
Epoch [10/30], Loss: 0.2195, test_acc: 0.9142628205128205
Epoch [11/30], Loss: 0.1343, test_acc: 0.9380008012820513
Epoch [12/30], Loss: 0.0939, test_acc: 0.9512219551282052
Epoch [13/30], Loss: 0.0690, test_acc: 0.9573317307692307
Epoch [14/30], Loss: 0.0337, test_acc: 0.9555288461538461
Epoch [15/30], Loss: 0.0581, test_acc: 0.9519230769230769
