# MCD
> Maximum Classifier Discrepancy for Unsupervised Domain Adaptation

In [None]:
#| default_exp ml.mcd

In [None]:
#| hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#| export

import torch
import torch.nn as nn
import torch.nn.functional as F
from bellek.ml.layer import GradReverse

In [None]:
#| export

def discrepancy(a, b):
    """Mean absolute difference between the softmax distributions of two logit tensors.

    Both inputs are converted to probabilities with a softmax over their last
    dimension. The element-wise absolute difference is then averaged over
    *all* elements, i.e. over every leading (batch) position and every class.

    Args:
        a: logits; the last dimension is treated as the class dimension.
        b: logits broadcast-compatible with ``a``.

    Returns:
        A 0-dim tensor; it is 0 when both distributions coincide.
    """
    prob_a = F.softmax(a, dim=-1)
    prob_b = F.softmax(b, dim=-1)
    return (prob_a - prob_b).abs().mean()

In [None]:
#| hide

a = torch.tensor([1.0, 2.0, 3.0])

# Identical logits give identical distributions, hence zero discrepancy.
test_eq(discrepancy(a, a).item(), 0.0)
# Negated logits mirror the distribution: check the sign and pin the exact value.
assert discrepancy(a, -a).item() > 0
test_close(discrepancy(a, -a).item(), 0.383473, eps=1e-4)

# Batched (batch, classes) logits: the mean runs over every element
# (the second row is identical and contributes 0), and the measure is symmetric.
x = torch.tensor([[0.0, 1.0], [2.0, -1.0]])
y = torch.tensor([[1.0, 0.0], [2.0, -1.0]])
test_close(discrepancy(x, y).item(), 0.231059, eps=1e-4)
test_eq(discrepancy(x, y).item(), discrepancy(y, x).item())

In [None]:
#| export

class Feature(nn.Module):
    """Convolutional feature extractor producing 3072-d features.

    Pipeline:
    - Three 5x5 conv -> BatchNorm -> ReLU stages. The first two stages are
      each followed by a 3x3, stride-2, padding-1 max-pool.
    - A fully-connected layer to 3072 units, followed by BatchNorm, ReLU and
      dropout.

    The flatten is hard-coded to 128 * 8 * 8 = 8192 features. Inputs must
    therefore be 3-channel images whose spatial size is 8x8 after the two
    pools, e.g. 32x32.
    """

    def __init__(self):
        super().__init__()
        # Keep this attribute order: it fixes the state_dict key order and
        # the order in which parameters are randomly initialised.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(num_features=64)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.bn2 = nn.BatchNorm2d(num_features=64)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=2)
        self.bn3 = nn.BatchNorm2d(num_features=128)
        self.fc1 = nn.Linear(in_features=8192, out_features=3072)
        self.bn1_fc = nn.BatchNorm1d(num_features=3072)

    def forward(self, x):
        """Map an image batch to (batch, 3072) features.

        Dropout uses the F.dropout default p=0.5 and is active only when
        the module is in training mode.
        """
        def pool(t):
            return F.max_pool2d(t, kernel_size=3, stride=2, padding=1)

        h = pool(F.relu(self.bn1(self.conv1(x))))
        h = pool(F.relu(self.bn2(self.conv2(h))))
        h = F.relu(self.bn3(self.conv3(h)))
        h = h.view(h.size(0), 8192)  # 128 channels * 8 * 8 spatial positions
        h = F.relu(self.bn1_fc(self.fc1(h)))
        return F.dropout(h, training=self.training)


class Predictor(nn.Module):
    """Classification head mapping 3072-d features (as produced by `Feature`) to logits.

    Forward path: fc2 (3072 -> 2048) -> bn2_fc -> ReLU -> fc3 (2048 -> n_classes).
    It returns raw logits; no softmax is applied. `discrepancy` can compare
    the logits of two such heads.

    Args:
        prob: stored as `self.prob`; not read anywhere in this class.
        lambd: stored as `self.lambd` and passed to `GradReverse` once, at
            construction time.
        n_classes: number of output logits. Defaults to 10, the value that
            was previously hard-coded.
    """

    def __init__(self, prob=0.5, lambd=1.0, n_classes=10):
        super().__init__()
        self.prob = prob
        self.lambd = lambd
        # NOTE(review): fc1, bn1_fc and bn_fc3 are never used in forward.
        # fc1 alone is ~25M parameters. They are kept so that state_dict keys
        # (checkpoint loading) and the random-init order of the later layers
        # stay unchanged.
        self.fc1 = nn.Linear(8192, 3072)
        self.bn1_fc = nn.BatchNorm1d(3072)
        self.fc2 = nn.Linear(3072, 2048)
        self.bn2_fc = nn.BatchNorm1d(2048)
        self.fc3 = nn.Linear(2048, n_classes)
        self.bn_fc3 = nn.BatchNorm1d(n_classes)
        # Built once from `lambd`. forward never reads `self.lambd`, so
        # reassigning it later has no effect on this layer.
        self.gr = GradReverse(lambd)

    def forward(self, x, reverse=False):
        """Return (batch, n_classes) logits for a batch of 3072-d features.

        If `reverse` is True, `x` first passes through `self.gr`. This is
        presumably a gradient-reversal layer from `bellek.ml.layer`, scaled
        by `lambd` — confirm there.
        """
        if reverse:
            x = self.gr(x)
        x = F.relu(self.bn2_fc(self.fc2(x)))
        x = self.fc3(x)
        return x


In [None]:
#| hide
from nbdev import nbdev_export
nbdev_export()  # write this notebook's `#| export` cells out to the `default_exp` module