In [1]:
import torch
import torchvision

import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as pl

from torch.utils.data import DataLoader, random_split

In [2]:
pl.seed_everything(42)

Global seed set to 42


42

In [3]:
IMAGE_CHANNEL_NUM = 1
IMAGE_SIZE = 32
CLASS_NUM = 10
EPOCH = 100
BATCH_NUM = 32
LR = 0.1

In [4]:
transform_mnist_train = transforms.Compose([
    transforms.Resize(size=(32, 32)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
])

transform_svhn_train = transforms.Compose([
    transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])

train_set_src = torchvision.datasets.SVHN(root='./data', split="train", download=True, transform=transform_svhn_train)
train_set_src, val_set_src = torch.utils.data.random_split(train_set_src, [63257, 10000])
train_set_tgt = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_mnist_train)
train_set_tgt, val_set_tgt = torch.utils.data.random_split(train_set_tgt, [50000, 10000])
test_set_tgt = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_mnist_train)

Using downloaded and verified file: ./data/train_32x32.mat


# Model

In [5]:
class DebugLayer(nn.Module):
    def forward(self, x):
        print(x.shape)
        return x


class Flatten(nn.Module):
    def forward(self, x):
        batch_size = x.shape[0]
        return x.view(batch_size, -1)


class GRL(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha

        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha

        return output, None
    

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(in_channels=32, out_channels=48, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            Flatten(),
        )
        self.classifier = nn.Sequential(
            nn.Linear(in_features=1200, out_features=100),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=100, out_features=100),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=100, out_features=10),
            nn.Sigmoid(),
        )
        
        self.discriminator = nn.Sequential(
            nn.Linear(in_features=1200, out_features=100),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=100, out_features=1),
            nn.Sigmoid(),
        )
        
class SVHNCNN(nn.Module):
    def __init__(self):
        super(SVHNCNN, self).__init__()
        
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=2),
#             DebugLayer(),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((3,3), (2,2)),
#             DebugLayer(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5, padding=2),
#             DebugLayer(),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((3,3), (2,2)),
#             DebugLayer(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, padding=2),
#             DebugLayer(),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((3,3), (2,2)),
#             DebugLayer(),
            Flatten(),
        )

        self.classifier = nn.Sequential(
#             DebugLayer(),
            nn.Linear(in_features=128*3*3, out_features=3072),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=3072, out_features=2048),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=2048, out_features=10),
            nn.Sigmoid(),
        )
        
        self.discriminator = nn.Sequential(
            nn.Linear(in_features=128*3*3, out_features=1024),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=1024, out_features=1024),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=1024, out_features=1),
            nn.Sigmoid(),
        )

In [6]:
class DANN(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.feature_extractor = model.feature_extractor
        self.classifier = model.classifier
        self.discriminator = model.discriminator
        
        self.train_accuracy = pl.metrics.Accuracy()
        self.val_accuracy = pl.metrics.Accuracy()
        self.test_accuracy = pl.metrics.Accuracy()
    
    def training_step(self, batch, batch_idx):
        (inputs_src, targets_src), (inputs_tgt, _) = batch
        device = inputs_src.device
        iterations = self.global_step
        p = float(iterations / EPOCH * (50000 / BATCH_NUM))
        alpha = 2. / (1. + np.exp(-10 * p)) - 1
        
        features_src = self.feature_extractor(inputs_src)
        outputs_src = self.classifier(features_src)
        features_src = GRL.apply(features_src, alpha)
        outputs_domain_src = self.discriminator(features_src)
        
        features_tgt = self.feature_extractor(inputs_tgt)
        features_tgt = GRL.apply(features_tgt, alpha)
        outputs_domain_tgt = self.discriminator(features_tgt)
        
        outputs_domain = torch.cat([
            outputs_domain_src,
            outputs_domain_tgt
        ], axis=0)
        targets_domain = torch.cat([
            torch.ones(outputs_domain_src.shape[0]),
            torch.zeros(outputs_domain_tgt.shape[0]),
        ], axis=0).unsqueeze(1).to(device)
        loss_cls = F.cross_entropy(outputs_src, targets_src)
        loss_dsc = F.binary_cross_entropy(outputs_domain, targets_domain)
        loss = loss_cls + loss_dsc
        
        self.log("train_loss_cls", loss_cls, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_acc", self.train_accuracy(outputs_src, targets_src), on_epoch=True, prog_bar=True, logger=True)
        self.log("train_loss_dsc", loss_dsc, on_epoch=True, prog_bar=True, logger=True)
        
        return loss_cls
        
    def training_epoch_end(self, outs):
        self.log("train_acc_epoch", self.train_accuracy.compute(), on_epoch=True, prog_bar=True, logger=True)
    
    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        features = self.feature_extractor(inputs)
        outputs = self.classifier(features)
        
        loss_cls = F.cross_entropy(outputs, targets)
        self.log("val_acc", self.val_accuracy(outputs, targets), on_epoch=True, prog_bar=True, logger=True)
        self.log("val_loss_cls", loss_cls, on_epoch=True, prog_bar=True, logger=True)
        
        return loss_cls

    def validation_epoch_end(self, outs):
        self.log("val_acc_epoch", self.val_accuracy.compute(), on_epoch=True, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        inputs, targets = batch
        features = self.feature_extractor(inputs)
        outputs = self.classifier(features)
        
        loss_cls = F.cross_entropy(outputs, targets)
        self.log("test_acc", self.test_accuracy(outputs, targets), on_epoch=True, prog_bar=True, logger=True)
        self.log("test_loss_cls", loss_cls, on_epoch=True, prog_bar=True, logger=True)
        
        return loss_cls
    
    def test_epoch_end(self, outs):
        self.log("test_acc_epoch", self.test_accuracy.compute(), on_epoch=True, prog_bar=True, logger=True)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            list(self.feature_extractor.parameters()) + list(self.classifier.parameters()),
            lr=LR,
        )

        return optimizer

    def train_dataloader(self):
        src_loader = DataLoader(train_set_src, batch_size=BATCH_NUM, shuffle=True, num_workers=8)
        tgt_loader = DataLoader(train_set_tgt, batch_size=BATCH_NUM, shuffle=True, num_workers=8)
        return list(zip(src_loader, tgt_loader))
    
    def val_dataloader(self):
        return DataLoader(val_set_tgt, batch_size=BATCH_NUM, num_workers=8)
    
    def test_dataloader(self):
        return DataLoader(test_set_tgt, batch_size=BATCH_NUM, num_workers=8)

In [7]:
model = DANN(SVHNCNN())
trainer = pl.Trainer(
    deterministic=True,
    check_val_every_n_epoch=1, 
    gpus=1,
    max_epochs=EPOCH,
)
trainer.fit(model)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type       | Params
-------------------------------------------------
0 | feature_extractor | Sequential | 312 K 
1 | classifier        | Sequential | 9.9 M 
2 | discriminator     | Sequential | 2.2 M 
3 | train_accuracy    | Accuracy   | 0     
4 | val_accuracy      | Accuracy   | 0     
5 | test_accuracy     | Accuracy   | 0     
-------------------------------------------------
12.4 M    Trainable params
0         Non-trainable params
12.4 M    Total params


Validation sanity check: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

1

In [8]:
trainer.test()

Testing: |          | 0/? [00:00<?, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.7365000247955322,
 'test_acc_epoch': 0.7365000247955322,
 'test_loss_cls': 1.7305165529251099}
--------------------------------------------------------------------------------


[{'test_acc': 0.7365000247955322,
  'test_loss_cls': 1.7305165529251099,
  'test_acc_epoch': 0.7365000247955322}]