In [38]:
import numpy as np
from scipy.io import loadmat

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torchvision.models as models
from torch.autograd import Function
from torchinfo import summary

from torch.utils.data import Dataset, DataLoader

In [2]:
# Load the ImageNet-pretrained ResNet-50 backbone. `pretrained=True` is
# deprecated since torchvision 0.13 in favour of weight enums;
# IMAGENET1K_V1 is the checkpoint `pretrained=True` loads, so the weights are
# unchanged. Older torchvision releases lack the enum, hence the fallback.
try:
    resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
except AttributeError:
    resnet50 = models.resnet50(pretrained=True)

In [9]:
# Replace the ImageNet classification head with a pass-through so the network
# returns the 2048-d features that fed the original head (see the [1, 2048]
# ResNet output in the GaninModel summary). This mutates the global `resnet50`
# object in place; the bare expression below displays the modified model.
resnet50.fc = nn.Identity()
resnet50

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [4]:
# Layer-by-layer shapes and parameter counts for one 3x224x224 image.
summary(resnet50, input_size=(1, 3, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   --                        --
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 256, 56, 56]          --
│    └─Bottleneck: 2-1                   [1, 256, 56, 56]          --
│    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           4,096
│    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-6                    [1, 64, 56, 56]           --
│ 

In [10]:
class GaninModel(nn.Module):
    """Domain-adversarial classifier in the style of Ganin et al.

    A backbone plus a projection MLP produce a 256-d embedding that feeds two
    heads: a label classifier, and a domain discriminator that sees the
    embedding through ``ReverseGradientLayer``. The discriminator's gradient
    therefore reaches the shared layers negated and scaled by ``alpha``.

    Args:
        feature_extractor: module mapping an image batch to ``feature_dim``
            features per sample. Defaults to the notebook-level ``resnet50``
            (ResNet-50 with ``fc`` replaced by ``nn.Identity``).
            NOTE: the default is that shared global object, not a copy, so
            re-running ``GaninModel()`` after training does NOT reset the
            backbone weights -- pass a fresh backbone for a clean start.
        feature_dim: per-sample width of the backbone output (2048 for
            ResNet-50).
        num_classes: number of label classes predicted by the classifier.
    """

    def __init__(self, feature_extractor=None, feature_dim=2048, num_classes=10):
        super().__init__()

        self.feature = resnet50 if feature_extractor is None else feature_extractor

        # Compress backbone features to the 256-d embedding shared by both heads.
        self.projection = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
        )

        # Label head: outputs raw class logits (no softmax).
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

        # Domain head: a single raw logit per sample (no sigmoid).
        self.discriminator = nn.Sequential(
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, input_data, alpha=0):
        """Run both heads on a batch.

        Args:
            input_data: image batch accepted by ``self.feature``
                (e.g. shape (N, 3, 28, 28) as in the summary cell).
            alpha: gradient-reversal strength for the discriminator branch.
                It does not change forward values; with the default 0 the
                discriminator's gradient is zeroed before reaching the
                shared layers.

        Returns:
            (class_output, domain_output): class logits of shape
            (N, num_classes) and domain logits of shape (N, 1). No
            softmax/sigmoid is applied, so pair ``domain_output`` with
            ``nn.BCEWithLogitsLoss`` rather than ``nn.BCELoss``.
        """
        feature = self.feature(input_data)
        # Flatten all non-batch dims. Unlike view(-1, 2048), a backbone whose
        # output width differs from feature_dim fails loudly in the Linear
        # below instead of being silently re-chunked into extra "samples".
        feature = torch.flatten(feature, 1)
        feature = self.projection(feature)
        # Identity in the forward pass; negates and scales the gradient by
        # alpha on its way back into projection/feature.
        reverse_feature = ReverseGradientLayer.apply(feature, alpha)
        class_output = self.classifier(feature)
        domain_output = self.discriminator(reverse_feature)

        return class_output, domain_output


class ReverseGradientLayer(Function):
    """Gradient reversal layer for domain-adversarial training.

    Forward: identity on ``x``. Backward: the incoming gradient is multiplied
    by ``-alpha``; ``alpha`` itself receives no gradient.

    Use as ``ReverseGradientLayer.apply(x, alpha)``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Remember the reversal strength for the backward pass.
        ctx.reversal_scale = alpha
        # Return a view rather than `x` itself: the idiomatic way for a custom
        # Function to emit an identity result as a distinct output tensor.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        flipped = grad_output * -ctx.reversal_scale
        # One gradient per forward input: x gets the flipped gradient,
        # alpha (a plain scalar) gets none.
        return flipped, None


class SimpleClassifier(nn.Module):
    """Classifier-only counterpart to the Ganin et al. architecture.

    Same convolutional feature extractor followed by a label classifier,
    without any domain discriminator. Sized for 3-channel 28x28 inputs
    (spatial trace: 28 -> 24 -> 12 -> 8 -> 4); ``forward`` returns raw
    logits over 10 classes.
    """

    def __init__(self):
        super().__init__()

        first, second = 32, 48   # conv channel widths
        hidden = 100             # classifier hidden width
        flat_dim = second * 4 * 4  # 48 maps of 4x4 after the second pool

        self.feature = nn.Sequential(
            # 3x28x28 -> 32x24x24 -> pool -> 32x12x12
            nn.Conv2d(in_channels=3, out_channels=first, kernel_size=(5, 5)),
            nn.BatchNorm2d(first),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 2)),
            # 32x12x12 -> 48x8x8 -> pool -> 48x4x4
            nn.Conv2d(in_channels=first, out_channels=second, kernel_size=(5, 5)),
            nn.BatchNorm2d(second),
            nn.Dropout2d(),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 2)),
        )

        self.classifier = nn.Sequential(
            nn.Linear(flat_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, 10),
        )

    def forward(self, input_data):
        """Return class logits for a batch of images."""
        maps = self.feature(input_data)
        return self.classifier(maps.reshape(-1, 48 * 4 * 4))


In [18]:
# Build the domain-adversarial model and inspect it on one 3x28x28 image.
model = GaninModel()
summary(model, input_size=(1, 3, 28, 28))

Layer (type:depth-idx)                        Output Shape              Param #
GaninModel                                    --                        --
├─ResNet: 1-1                                 [1, 2048]                 --
│    └─Conv2d: 2-1                            [1, 64, 14, 14]           9,408
│    └─BatchNorm2d: 2-2                       [1, 64, 14, 14]           128
│    └─ReLU: 2-3                              [1, 64, 14, 14]           --
│    └─MaxPool2d: 2-4                         [1, 64, 7, 7]             --
│    └─Sequential: 2-5                        [1, 256, 7, 7]            --
│    │    └─Bottleneck: 3-1                   [1, 256, 7, 7]            75,008
│    │    └─Bottleneck: 3-2                   [1, 256, 7, 7]            70,400
│    │    └─Bottleneck: 3-3                   [1, 256, 7, 7]            70,400
│    └─Sequential: 2-6                        [1, 512, 4, 4]            --
│    │    └─Bottleneck: 3-4                   [1, 512, 4, 4]            379,392

In [20]:
## Hyperparameters

batch_size = 128
epochs = 2
lr = 1e-3

# Seed the RNGs used from here on (DataLoader shuffling, dropout, any later
# model construction) so Restart & Run All reproduces the same numbers.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)


In [21]:
class DigitDataset(Dataset):
    """Map-style dataset over an (inputs, labels) pair.

    Args:
        data: indexable pair; ``data[0]`` holds the samples and ``data[1]``
            the matching labels (presumably the tuple stored in the ``.pt``
            files loaded later -- confirm the saved format).
        transform: optional callable applied to each sample (never to the
            label); skipped when falsy.
    """

    def __init__(self, data, transform=None):
        samples, targets = data[0], data[1]
        self.data = samples
        self.labels = targets
        self.transform = transform

    def __len__(self):
        # Number of samples. The original note said 60000 for training and
        # 10000 for test, which matches MNIST -- TODO confirm for mnist_m.
        return len(self.data)

    def __getitem__(self, idx):
        sample, target = self.data[idx], self.labels[idx]
        if not self.transform:
            return sample, target
        return self.transform(sample), target


def fetch(data,
          batch_size=64,
          transform=None,
          shuffle=True,
          num_workers=1,
          pin_memory=True):
    """Wrap an in-memory (inputs, labels) pair in a DataLoader.

    Args:
        data: pair passed straight to ``DigitDataset`` (samples, labels).
        batch_size, shuffle, num_workers, pin_memory: forwarded unchanged to
            ``torch.utils.data.DataLoader``.
        transform: per-sample transform forwarded to ``DigitDataset``.

    Returns:
        A ``DataLoader`` over ``DigitDataset(data, transform)``.
    """
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return DataLoader(DigitDataset(data=data, transform=transform), **loader_options)

In [39]:
# Per-channel normalization with mean 0.5 and std 0.5 on three channels,
# mapping values in [0, 1] to [-1, 1]. There is no ToTensor step, so the
# stored samples are assumed to already be float tensors shaped (3, H, W)
# -- TODO confirm against the .pt files.
transform = transforms.Compose([
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])


In [40]:
loaders_args = dict(
    batch_size=batch_size,
    transform=transform,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)
# Aggregate test metrics (e.g. accuracy) do not depend on sample order, so the
# test loaders skip shuffling; this keeps their iteration order deterministic.
eval_loaders_args = {**loaders_args, "shuffle": False}

# NOTE: torch.load unpickles the file -- only load trusted, locally produced data.

# Source domain (./data/mnist).
train_src = torch.load("./data/mnist/train.pt")
trainloader_src = fetch(data=train_src, **loaders_args)

# testloader_src is built for symmetry with the target domain; per the
# original note it is not needed by the training code.
test_src = torch.load("./data/mnist/test.pt")
testloader_src = fetch(data=test_src, **eval_loaders_args)

# Target domain (./data/mnist_m).
train_tgt = torch.load("./data/mnist_m/train.pt")
trainloader_tgt = fetch(data=train_tgt, **loaders_args)

test_tgt = torch.load("./data/mnist_m/test.pt")
testloader_tgt = fetch(data=test_tgt, **eval_loaders_args)


In [55]:
# Label loss: the classifier head ends in a Linear layer, i.e. raw logits,
# which is what CrossEntropyLoss expects.
criterion_class = nn.CrossEntropyLoss()
# Domain loss: the discriminator ends in nn.Linear(128, 1) with no sigmoid, so
# it also emits logits. nn.BCELoss expects probabilities in [0, 1] and raises
# on out-of-range inputs; BCEWithLogitsLoss applies the sigmoid internally in
# a numerically stable way. Targets must be float tensors shaped like the
# logits, e.g. (N, 1).
criterion_domain = nn.BCEWithLogitsLoss()
# Placeholder for a contrastive objective; not defined yet.
criterion_contrast = None



In [None]:
## Training Loop

