In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.autograd import Function

In [None]:
# Hyperparameters shared by the cells below.
latent_dim = 100        # not referenced in the visible cells; kept for any later use
lr = 2e-4               # the optimizers built below pass their own literal learning rates
batch_size = 128        # mini-batch size handed to the DataLoader
image_size = 28 * 28    # number of pixels in one 28x28 image

In [None]:
# Device configuration: use the GPU when one is visible, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MNIST dataset: convert to tensors, then normalize the single channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
# drop_last=True guarantees that every batch holds exactly `batch_size` samples.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [None]:
class ReverseLayerF(Function):
    """Gradient reversal layer.

    The forward pass is the identity; the backward pass multiplies the
    incoming gradient by ``-alpha``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Stash ``alpha`` on the context and return ``x`` unchanged (as a view)."""
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the negated, ``alpha``-scaled gradient for ``x`` and ``None`` for ``alpha``."""
        reversed_grad = -(grad_output * ctx.alpha)
        return reversed_grad, None

class DANN(nn.Module):
    """Domain-adversarial network for 28x28 single-channel images.

    A shared convolutional feature extractor feeds two heads: a 10-way label
    classifier and a 2-way domain classifier. The domain head receives the
    features through ``ReverseLayerF``, so its gradient reaches the extractor
    negated and scaled by ``alpha``.
    """

    # Flattened size of the extractor output: 50 channels of 4x4 maps.
    _FLAT_FEATURES = 50 * 4 * 4

    def __init__(self):
        super().__init__()
        flat = self._FLAT_FEATURES

        feature_layers = [
            nn.Conv2d(1, 64, kernel_size=5),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 50, kernel_size=5),
            nn.BatchNorm2d(50),
            nn.Dropout2d(),
            nn.MaxPool2d(2),
            nn.ReLU(inplace=True),
        ]
        self.feature = nn.Sequential(*feature_layers)

        label_layers = [
            nn.Linear(flat, 100),
            nn.BatchNorm1d(100),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(100, 100),
            nn.BatchNorm1d(100),
            nn.ReLU(inplace=True),
            nn.Linear(100, 10),
        ]
        self.classifier = nn.Sequential(*label_layers)

        domain_layers = [
            nn.Linear(flat, 100),
            nn.BatchNorm1d(100),
            nn.ReLU(inplace=True),
            nn.Linear(100, 2),
        ]
        self.domain_classifier = nn.Sequential(*domain_layers)

    def forward(self, input_data, alpha):
        """Return ``(class_output, domain_output)`` logits for a batch.

        ``input_data`` is expanded to ``(batch, 1, 28, 28)`` before the
        extractor runs; ``alpha`` is passed to the gradient-reversal layer in
        front of the domain head.
        """
        batch = input_data.shape[0]
        images = input_data.expand(batch, 1, 28, 28)
        flat_features = self.feature(images).view(-1, self._FLAT_FEATURES)
        reversed_features = ReverseLayerF.apply(flat_features, alpha)
        class_output = self.classifier(flat_features)
        domain_output = self.domain_classifier(reversed_features)
        return class_output, domain_output

In [None]:
def _next_batch(batch_iter, loader):
    """Return ``(batch, iterator)``, restarting ``loader`` once if the iterator is exhausted."""
    try:
        return next(batch_iter), batch_iter
    except StopIteration:
        batch_iter = iter(loader)
    try:
        return next(batch_iter), batch_iter
    except StopIteration:
        raise ValueError("train_loader yielded no batches") from None


def train_DANN(train_loader, model, criterion, optimizer, epoches, alpha=0.05, steps_per_epoch=20):
    """Train a domain-adversarial model and print per-epoch domain accuracy.

    Every step draws two consecutive batches from ``train_loader``: the first
    is labelled as the source domain (1), the second as the target domain (0).
    Both batches come from the same loader. The summed label and domain losses
    of both batches are back-propagated in a single optimizer step.

    Args:
        train_loader: iterable yielding ``(inputs, labels)`` batches.
        model: module called as ``model(inputs, alpha)`` and returning
            ``(class_output, domain_output)``.
        criterion: loss applied to both the label and the domain logits.
        optimizer: optimizer over ``model``'s parameters.
        epoches: number of epochs to run.
        alpha: gradient-reversal strength forwarded to ``model``.
        steps_per_epoch: optimizer steps per epoch (two batches are consumed
            per step).

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()

    # Follow the model's own device so labels and inputs always land next to it.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    for e in range(epoches):
        batch_iter = iter(train_loader)
        correct_source_domain, correct_tgt_domain = 0, 0
        total_source, total_target = 0, 0

        for _ in range(steps_per_epoch):
            # backward() accumulates into .grad, so clear stale gradients every step.
            optimizer.zero_grad()

            # Src
            (source, source_label), batch_iter = _next_batch(batch_iter, train_loader)
            source, source_label = source.to(device), source_label.to(device)
            # Size the domain labels from the actual batch, on the model's device.
            src_domain_label = torch.ones(source.size(0), dtype=torch.long, device=device)

            class_output, domain_output = model(source, alpha)
            loss_s_label = criterion(class_output, source_label)
            loss_s_domain = criterion(domain_output, src_domain_label)

            correct_source_domain += domain_output.argmax(dim=1).eq(src_domain_label).sum().item()
            total_source += source.size(0)

            # Tgt
            (target, target_label), batch_iter = _next_batch(batch_iter, train_loader)
            target, target_label = target.to(device), target_label.to(device)
            tgt_domain_label = torch.zeros(target.size(0), dtype=torch.long, device=device)

            class_output, domain_output = model(target, alpha)
            loss_t_label = criterion(class_output, target_label)
            loss_t_domain = criterion(domain_output, tgt_domain_label)

            correct_tgt_domain += domain_output.argmax(dim=1).eq(tgt_domain_label).sum().item()
            total_target += target.size(0)

            loss = loss_s_label + loss_s_domain + loss_t_domain + loss_t_label
            loss.backward()
            optimizer.step()

        # Each domain is normalized by its own sample count; guard the zero-step case.
        source_acc = correct_source_domain / max(total_source, 1)
        target_acc = correct_tgt_domain / max(total_target, 1)
        print(f"{e}: source correct: {source_acc}, target correct: {target_acc}")

In [None]:
# Build the DANN with its loss and optimizer, then train it for 50 epochs.
criterion = nn.CrossEntropyLoss()
dann = DANN().to(device)
optimizer = optim.Adam(params=dann.parameters(), lr=1e-3)

train_DANN(train_loader, dann, criterion, optimizer, epoches=50)

In [None]:
class Linear(nn.Module):
    """Three-layer MLP head: ``in_features -> 1000 -> 100 -> num_classes``.

    The defaults match the flattened DANN feature map (50 channels of 4x4)
    and 10 output classes.

    Args:
        in_features: number of features per sample after flattening.
        num_classes: number of output logits.
    """

    def __init__(self, in_features=50 * 4 * 4, num_classes=10):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 1000),
            nn.ReLU(inplace=True),
            nn.Linear(1000, 100),
            nn.ReLU(inplace=True),
            nn.Linear(100, num_classes),
        )

    def forward(self, x):
        """Flatten every sample to a vector and return the class logits.

        The batch dimension is taken from ``x`` itself rather than hard-coded,
        so any batch size works, as long as each sample flattens to
        ``in_features`` values.
        """
        x = torch.flatten(x, start_dim=1)
        return self.classifier(x)

In [None]:
# Classifier head trained on top of the DANN feature extractor. It is placed on
# the same device as `dann` so features never need a cross-device copy.
leader = Linear().to(device)
leader.train()

# Loss and optimizer (built after the move so the optimizer tracks the
# on-device parameters).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(leader.parameters(), lr=0.005)

In [None]:
# Fit `leader` on one fixed batch ("batch A") of frozen DANN features.
num_epochs = 15
data_iter = iter(train_loader)
images, labels = next(data_iter)

# Resolve devices from the modules themselves so this cell works wherever they live.
feature_device = next(dann.parameters()).device
leader_device = next(leader.parameters()).device

# The extractor is only used for inference here: eval mode switches off Dropout2d
# and makes BatchNorm use (and stop updating) its running statistics, so the
# features are deterministic.
dann.eval()
with torch.no_grad():
    # The features do not change between iterations, so compute them once.
    feature = dann.feature(images.to(feature_device)).to(leader_device)
labels = labels.to(leader_device)

for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = leader(feature)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

In [None]:
# Persist the weights of the head trained on batch A.
torch.save(obj=leader.state_dict(), f='model_batch_A.pth')

In [None]:
# Evaluate the head trained on batch A against the next batch from the same
# iterator ("batch B").
images, labels = next(data_iter)

# Resolve devices from the modules themselves so this cell works wherever they live.
feature_device = next(dann.parameters()).device
leader_device = next(leader.parameters()).device

# Eval mode keeps Dropout2d/BatchNorm in the extractor deterministic and stops
# BatchNorm from updating its running statistics during evaluation.
dann.eval()
with torch.no_grad():
    feature = dann.feature(images.to(feature_device)).to(leader_device)
    outputs = leader(feature)
    loss = criterion(outputs, labels.to(leader_device))
    print(f"B all on A, Loss: {loss.item():.4f}")