In [1]:
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
from copy import deepcopy
from typing import Any, Tuple, Optional, Sequence


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import torchvision.transforms.functional as TF
from torch.autograd import Function

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
data_root = "/home/hhchung/data"

In [4]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        
        return x
    
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        
    def forward(self, x):
        x1 = self.fc1(x)
        x1 = F.relu(x1)
        x1 = self.dropout2(x1)
        output = self.fc2(x1)
        # if self.training:
        features = x1 # torch.stack([x, x1])
        return output, features
        # else:
        #     return output

In [5]:
class MyRotationTransform:
    """Rotate by one of the given angles."""

    def __init__(self, angle):
        self.angle = angle

    def __call__(self, x):
        return TF.rotate(x, self.angle)


In [6]:
class GradientReverseFunction(Function):

    @staticmethod
    def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:
        ctx.coeff = coeff
        output = input * 1.0
        return output

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        return grad_output.neg() * ctx.coeff, None


class GradientReverseLayer(nn.Module):
    def __init__(self):
        super(GradientReverseLayer, self).__init__()

    def forward(self, *input):
        return GradientReverseFunction.apply(*input)

In [7]:
class GaussianKernel(nn.Module):
    r"""Gaussian Kernel Matrix
    Gaussian Kernel k is defined by
    .. math::
        k(x_1, x_2) = \exp \left( - \dfrac{\| x_1 - x_2 \|^2}{2\sigma^2} \right)
    where :math:`x_1, x_2 \in R^d` are 1-d tensors.
    Gaussian Kernel Matrix K is defined on input group :math:`X=(x_1, x_2, ..., x_m),`
    .. math::
        K(X)_{i,j} = k(x_i, x_j)
    Also by default, during training this layer keeps running estimates of the
    mean of L2 distances, which are then used to set hyperparameter  :math:`\sigma`.
    Mathematically, the estimation is :math:`\sigma^2 = \dfrac{\alpha}{n^2}\sum_{i,j} \| x_i - x_j \|^2`.
    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and use a fixed :math:`\sigma` instead.
    Args:
        sigma (float, optional): bandwidth :math:`\sigma`. Default: None
        track_running_stats (bool, optional): If ``True``, this module tracks the running mean of :math:`\sigma^2`.
          Otherwise, it won't track such statistics and always uses fix :math:`\sigma^2`. Default: ``True``
        alpha (float, optional): :math:`\alpha` which decides the magnitude of :math:`\sigma^2` when track_running_stats is set to ``True``
    Inputs:
        - X (tensor): input group :math:`X`
    Shape:
        - Inputs: :math:`(minibatch, F)` where F means the dimension of input features.
        - Outputs: :math:`(minibatch, minibatch)`
    """

    def __init__(self, sigma: Optional[float] = None, track_running_stats: Optional[bool] = True,
                 alpha: Optional[float] = 1.):
        super(GaussianKernel, self).__init__()
        assert track_running_stats or sigma is not None
        self.sigma_square = torch.tensor(sigma * sigma) if sigma is not None else None
        self.track_running_stats = track_running_stats
        self.alpha = alpha

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        l2_distance_square = ((X.unsqueeze(0) - X.unsqueeze(1)) ** 2).sum(2)

        if self.track_running_stats:
            self.sigma_square = self.alpha * torch.mean(l2_distance_square.detach())

        return torch.exp(-l2_distance_square / (2 * self.sigma_square))
    
class JointMultipleKernelMaximumMeanDiscrepancy(nn.Module):
    r"""
    Args:
        kernels (tuple(tuple(torch.nn.Module))): kernel functions, where `kernels[r]` corresponds to kernel :math:`k^{\mathcal{L}[r]}`.
        linear (bool): whether use the linear version of JAN. Default: False
        thetas (list(Theta): use adversarial version JAN if not None. Default: None
    Inputs:
        - z_s (tuple(tensor)): multiple layers' activations from the source domain, :math:`z^s`
        - z_t (tuple(tensor)): multiple layers' activations from the target domain, :math:`z^t`
    Shape:
        - :math:`z^{sl}` and :math:`z^{tl}`: :math:`(minibatch, *)`  where * means any dimension
        - Outputs: scalar
    .. note::
        Activations :math:`z^{sl}` and :math:`z^{tl}` must have the same shape.
    .. note::
        The kernel values will add up when there are multiple kernels for a certain layer.
    Examples::
        >>> feature_dim = 1024
        >>> batch_size = 10
        >>> layer1_kernels = (GaussianKernel(alpha=0.5), GaussianKernel(1.), GaussianKernel(2.))
        >>> layer2_kernels = (GaussianKernel(1.), )
        >>> loss = JointMultipleKernelMaximumMeanDiscrepancy((layer1_kernels, layer2_kernels))
        >>> # layer1 features from source domain and target domain
        >>> z1_s, z1_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
        >>> # layer2 features from source domain and target domain
        >>> z2_s, z2_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
        >>> output = loss((z1_s, z2_s), (z1_t, z2_t))
    """

    def __init__(self, kernels: Sequence[Sequence[nn.Module]], linear: Optional[bool] = True, thetas: Sequence[nn.Module] = None):
        super(JointMultipleKernelMaximumMeanDiscrepancy, self).__init__()
        self.kernels = kernels
        self.index_matrix = None
        self.linear = linear
        if thetas:
            self.thetas = thetas
        else:
            self.thetas = [nn.Identity() for _ in kernels]

    def forward(self, z_s: torch.Tensor, z_t: torch.Tensor) -> torch.Tensor:
        batch_size = int(z_s[0].size(0))
        self.index_matrix = _update_index_matrix(batch_size, self.index_matrix, self.linear).to(z_s[0].device)

        kernel_matrix = torch.ones_like(self.index_matrix)
        for layer_z_s, layer_z_t, layer_kernels, theta in zip(z_s, z_t, self.kernels, self.thetas):
            layer_features = torch.cat([layer_z_s, layer_z_t], dim=0)
            layer_features = theta(layer_features)
            kernel_matrix *= sum(
                [kernel(layer_features) for kernel in layer_kernels])  # Add up the matrix of each kernel

        # Add 2 / (n-1) to make up for the value on the diagonal
        # to ensure loss is positive in the non-linear version
        loss = (kernel_matrix * self.index_matrix).sum() + 2. / float(batch_size - 1)
        return loss

def _update_index_matrix(batch_size: int, index_matrix: Optional[torch.Tensor] = None,
                         linear: Optional[bool] = True) -> torch.Tensor:
    r"""
    Update the `index_matrix` which convert `kernel_matrix` to loss.
    If `index_matrix` is a tensor with shape (2 x batch_size, 2 x batch_size), then return `index_matrix`.
    Else return a new tensor with shape (2 x batch_size, 2 x batch_size).
    """
    if index_matrix is None or index_matrix.size(0) != batch_size * 2:
        index_matrix = torch.zeros(2 * batch_size, 2 * batch_size)
        if linear:
            for i in range(batch_size):
                s1, s2 = i, (i + 1) % batch_size
                t1, t2 = s1 + batch_size, s2 + batch_size
                index_matrix[s1, s2] = 1. / float(batch_size)
                index_matrix[t1, t2] = 1. / float(batch_size)
                index_matrix[s1, t2] = -1. / float(batch_size)
                index_matrix[s2, t1] = -1. / float(batch_size)
        else:
            for i in range(batch_size):
                for j in range(batch_size):
                    if i != j:
                        index_matrix[i][j] = 1. / float(batch_size * (batch_size - 1))
                        index_matrix[i + batch_size][j + batch_size] = 1. / float(batch_size * (batch_size - 1))
            for i in range(batch_size):
                for j in range(batch_size):
                    index_matrix[i][j + batch_size] = -1. / float(batch_size * batch_size)
                    index_matrix[i + batch_size][j] = -1. / float(batch_size * batch_size)
    return index_matrix
    
class Theta(nn.Module):
    """
    maximize loss respect to :math:`\theta`
    minimize loss respect to features
    """
    def __init__(self, dim: int):
        super(Theta, self).__init__()
        self.grl1 = GradientReverseLayer()
        self.grl2 = GradientReverseLayer()
        self.layer1 = nn.Linear(dim, dim)
        nn.init.eye_(self.layer1.weight)
        nn.init.zeros_(self.layer1.bias)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        features = self.grl1(features)
        return self.grl2(self.layer1(features))


In [8]:
def train(encoder, classifier, device, train_loader, optimizer):
    encoder.train()
    classifier.train()
    
    total_train_loss = 0
    total_size = 0
    for data, target in tqdm(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output, _ = classifier(encoder(data))
        loss = F.nll_loss(F.log_softmax(output, dim=1), target)
        loss.backward()
        optimizer.step()
        
        batch_size = data.shape[0]
        total_train_loss += loss.item() * batch_size
        total_size += batch_size
    
    total_train_loss /= total_size
    return total_train_loss


  
@torch.no_grad()
def test(encoder, classifier, device, test_loader):
    encoder.eval()
    classifier.eval()
    
    total_test_loss = 0  
    total_correct = 0
    total_size = 0
    
    
    for data, target in tqdm(test_loader):
        
        data, target = data.to(device), target.to(device)
        output = classifier(encoder(data))
        loss = F.nll_loss(F.log_softmax(output, dim=1), target, reduction='sum')
        total_test_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        total_correct += pred.eq(target.view_as(pred)).sum().item()
        total_size += data.shape[0]
    
    total_test_loss /= total_size
    total_correct /= total_size
    
    return total_test_loss, total_correct


In [9]:
def adapt(encoder, classifier, jmmd_loss, device, src_loader, tgt_loader, optimizer, e, epochs, lambda_coeff):
    encoder.train()
    classifier.train()
    jmmd_loss.train()
    len_dataloader = min(len(src_loader), len(tgt_loader))
    src_iter = iter(src_loader)
    tgt_iter = iter(tgt_loader)
    
    total_loss = 0
    total_cls_loss = 0
    total_transfer_loss = 0
    total_src_data_size = 0
    
    for i in tqdm(range(len_dataloader)):
        src_data, src_label = src_iter.next()
        src_data, src_label = src_data.to(device), src_label.to(device)
        
        tgt_data, _ = tgt_iter.next()
        tgt_data = tgt_data.to(device)
        
        src_y, src_f = classifier(encoder(src_data))
        tgt_y, tgt_f = classifier(encoder(tgt_data))
        cls_loss = F.nll_loss(F.log_softmax(src_y, dim=1), src_label)
        transfer_loss = jmmd_loss((src_f, F.softmax(src_y, dim=1)), (tgt_f, F.softmax(tgt_y, dim=1)))
        loss = cls_loss + transfer_loss * lambda_coeff
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item() * src_data.size(0)
        total_cls_loss += cls_loss.item() * src_data.size(0)
        total_transfer_loss += transfer_loss.item() * src_data.size(0)
        total_src_data_size += src_data.size(0)
    
    total_loss /= total_src_data_size
    total_cls_loss /= total_src_data_size
    total_transfer_loss /= total_src_data_size
    return total_loss, total_cls_loss, total_transfer_loss
    
@torch.no_grad()
def adapt_test(encoder, classifier, jmmd_loss, device, src_loader, tgt_loader, e, epochs, lambda_coeff):
    encoder.eval()
    classifier.eval()
    jmmd_loss.eval()
    len_dataloader = min(len(src_loader), len(tgt_loader))
    src_iter = iter(src_loader)
    tgt_iter = iter(tgt_loader)
    
    total_loss = 0
    total_cls_loss = 0
    total_transfer_loss = 0
    total_src_data_size = 0
    
    for i in tqdm(range(len_dataloader)):
        src_data, src_label = src_iter.next()
        src_data, src_label = src_data.to(device), src_label.to(device)
        
        tgt_data, _ = tgt_iter.next()
        tgt_data = tgt_data.to(device)
        
        src_y, src_f = classifier(encoder(src_data))
        tgt_y, tgt_f = classifier(encoder(tgt_data))
        cls_loss = F.nll_loss(F.log_softmax(src_y, dim=1), src_label)
        transfer_loss = jmmd_loss((src_f, F.softmax(src_y, dim=1)), (tgt_f, F.softmax(tgt_y, dim=1)))
        loss = cls_loss + transfer_loss * lambda_coeff
        
        
        total_loss += loss.item() * src_data.size(0)
        total_cls_loss += cls_loss.item() * src_data.size(0)
        total_transfer_loss += transfer_loss.item() * src_data.size(0)
        total_src_data_size += src_data.size(0)
    
    total_loss /= total_src_data_size
    total_cls_loss /= total_src_data_size
    total_transfer_loss /= total_src_data_size
    return total_loss, total_cls_loss, total_transfer_loss
    


In [10]:
gpuID = 1
device = torch.device('cuda:' + str(gpuID) if torch.cuda.is_available() else 'cpu')
encoder = Encoder().to(device)
classifier = Classifier().to(device)

In [11]:
transform=transforms.Compose([
          transforms.ToTensor(),
          transforms.Normalize((0.1307,), (0.3081,))
          ])
train_dataset = datasets.MNIST(data_root, train=True, download=True,
                          transform=transform)

train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [50000, 10000])

test_dataset = datasets.MNIST(data_root, train=False,
                       transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)

epochs = 5
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(classifier.parameters()), lr=0.001)

In [12]:
for e in range(1, epochs + 1):
    train_loss = train(encoder, classifier, device, train_loader, optimizer)
    val_loss, correct = test(encoder, classifier, device, val_loader)
    print(f'Epoch:{e}/{epochs} Train Loss: {round(train_loss, 3)}, Val Loss: {round(val_loss, 3)}, Val Accuracy: {round(correct, 3)}')

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:07<00:00, 55.16it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 80.21it/s]


Epoch:1/5 Train Loss: 0.24, Val Loss: 0.066, Val Accuracy: 0.982


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:05<00:00, 71.20it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 80.27it/s]


Epoch:2/5 Train Loss: 0.089, Val Loss: 0.049, Val Accuracy: 0.986


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:05<00:00, 71.86it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 80.30it/s]


Epoch:3/5 Train Loss: 0.067, Val Loss: 0.043, Val Accuracy: 0.988


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:05<00:00, 71.90it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 80.28it/s]


Epoch:4/5 Train Loss: 0.054, Val Loss: 0.052, Val Accuracy: 0.986


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:05<00:00, 71.57it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 80.23it/s]

Epoch:5/5 Train Loss: 0.048, Val Loss: 0.042, Val Accuracy: 0.989





In [13]:
test_loss, test_acc = test(encoder, classifier, device, test_loader)
print(test_loss, test_acc)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 80.00it/s]

0.029825609946274197 0.9902





## Rotate 30 Degrees ##

In [14]:
transform=transforms.Compose([
          transforms.ToTensor(),
          transforms.Normalize((0.1307,), (0.3081,)),
          MyRotationTransform(30)
          ])

train_dataset_30 = datasets.MNIST(data_root, train=True, download=True,
                          transform=transform)
train_dataset_30, val_dataset_30 = torch.utils.data.random_split(train_dataset_30, [50000, 10000])
test_dataset_30 = datasets.MNIST(data_root, train=False,
                       transform=transform)


train_loader_30 = torch.utils.data.DataLoader(train_dataset_30, batch_size=128, shuffle=True)
val_loader_30 = torch.utils.data.DataLoader(val_dataset_30, batch_size=128, shuffle=True)
test_loader_30 = torch.utils.data.DataLoader(test_dataset_30, batch_size=128, shuffle=False)

epochs = 100
encoder_30 = deepcopy(encoder)
classifier_30 = deepcopy(classifier)
optimizer_30 = torch.optim.Adam(list(encoder_30.parameters()) + list(classifier_30.parameters()), lr=0.001)

In [15]:
thetas = None # none adversarial
jmmd_loss = JointMultipleKernelMaximumMeanDiscrepancy(
    kernels=(
        [GaussianKernel(alpha=2 ** k) for k in range(-3, 2)],
        (GaussianKernel(sigma=0.92, track_running_stats=False),)
    ),
    linear=False, thetas=thetas
).to(device)

In [16]:
# lambda_coef_list = [10, 5, 3, 1]

# for lambda_coef in lambda_coef_list:
lambda_coeff = 1.0
best_val_loss = 0
best_encoder_30, best_classifier_30 = None, None
for e in range(1, epochs + 1):
    total_train_loss, total_train_cls_loss, total_train_transfer_loss = adapt(encoder, classifier, jmmd_loss, device, train_loader, train_loader_30, optimizer_30, e, epochs, lambda_coeff)
    total_val_loss, total_val_cls_loss, total_val_transfer_loss = adapt_test(encoder, classifier, jmmd_loss, device, val_loader, val_loader_30, e, epochs, lambda_coeff)

    print(f'Train Total Loss: {round(total_train_loss,3)} Train Src Cls Loss: {round(total_train_cls_loss,3)} Train Tgt Transfer Loss: {round(total_train_transfer_loss,3)} \n Val Total Loss: {round(total_val_loss,3)} Val Src Cls Loss: {round(total_val_cls_loss,3)} Val Tgt Transfer Loss: {round(total_val_transfer_loss,3)}')
    test_loss, test_acc = test(encoder_30, classifier_30, device, test_loader_30)
    print(f"Test Loss: {test_loss} Test Acc: {test_acc}")


encoder_30 = deepcopy(best_encoder_30)
classifier_30 = deepcopy(best_classifier_30)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:18<00:00, 20.68it/s]
  0%|                                                                                                                                         | 0/391 [00:00<?, ?it/s]


ValueError: too many values to unpack (expected 2)

In [None]:
test(encoder_30, classifier_30, device, test_loader_30)