# MCD
> Maximum Classifier Discrepancy for Unsupervised Domain Adaptation

In [None]:
#| default_exp ml.mcd

In [None]:
#| hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#| export

import torch
import torch.nn as nn
import torch.nn.functional as F
from bellek.ml.layer import GradReverse
from fastcore.basics import store_attr
from fastai.callback.core import Callback
from fastai.learner import CancelBatchException
from fastai.losses import BaseLoss
from fastai.torch_core import default_device

In [None]:
#| export

def discrepancy(a, b):
    "Mean absolute difference between `softmax(a)` and `softmax(b)` (softmax over the last dim), averaged over all elements."
    probs_a = F.softmax(a, dim=-1)
    probs_b = F.softmax(b, dim=-1)
    return (probs_a - probs_b).abs().mean()

In [None]:
#| hide

# Sanity checks for `discrepancy` on a single logit vector.
a = torch.tensor([1.0, 2.0, 3.0])

# Identical logits give identical softmax distributions, hence zero discrepancy.
test_eq(discrepancy(a, a).item(), 0.0)
# Negating the logits reverses the probability ranking, so the distributions differ.
assert discrepancy(a, -a).item() > 0

In [None]:
#| export

class DiscrepancyLoss:
    """Negated `discrepancy` between two classifier outputs.

    Minimizing this loss maximizes the disagreement of the two outputs. `targets`
    and `kwargs` are accepted only so it can be called like a regular fastai loss
    (`loss_func(pred, *yb)`); they are ignored.
    """
    def __call__(self, outs, *targets, **kwargs):
        # `outs` must hold exactly the two classifier outputs.
        assert len(outs) == 2
        first, second = outs
        return -discrepancy(first, second)

def discrepancy_metric(pred, *targets):
    """Metric: `discrepancy` between the last two entries of `pred`.

    With `McdCallback`, `learn.pred` is the source outputs followed by the target
    outputs, so these are the two classifier outputs on the target batch.
    `targets` is ignored.
    """
    return discrepancy(pred[-2], pred[-1])

In [None]:
#| export

class Feature(nn.Module):
    """Convolutional feature extractor producing 3072-d features.

    Three 5x5 conv -> BatchNorm -> ReLU stages (the first two followed by a 3x3,
    stride-2 max-pool), then Linear(8192, 3072) -> BatchNorm -> ReLU -> dropout.
    The hard-coded 8192 = 128 * 8 * 8 means the conv stack must end with 8x8 maps,
    e.g. 3x32x32 input images.
    """
    def __init__(self):
        super().__init__()
        # Same-padding 5x5 convolutions keep the spatial size unchanged.
        conv_opts = dict(kernel_size=5, stride=1, padding=2)
        # Registration order is kept as-is so parameter / state_dict order is stable.
        self.conv1 = nn.Conv2d(3, 64, **conv_opts)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, **conv_opts)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, **conv_opts)
        self.bn3 = nn.BatchNorm2d(128)
        self.fc1 = nn.Linear(8192, 3072)
        self.bn1_fc = nn.BatchNorm1d(3072)

    def forward(self, x):
        "Map an image batch to a (batch, 3072) feature batch."
        # Two conv/BN/ReLU stages, each followed by a pool that halves H and W (rounding up).
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = F.max_pool2d(F.relu(norm(conv(x))), stride=2, kernel_size=3, padding=1)
        # Last conv stage is not pooled.
        x = F.relu(self.bn3(self.conv3(x)))
        # Flatten to 128 * 8 * 8 = 8192 features per sample.
        flat = x.view(x.size(0), 8192)
        hidden = F.relu(self.bn1_fc(self.fc1(flat)))
        # Dropout with F.dropout's default p=0.5, active only in training mode.
        return F.dropout(hidden, training=self.training)


class Predictor(nn.Module):
    """Classifier head mapping 3072-d features to 10 logits.

    `forward` only uses `fc2 -> bn2_fc -> ReLU -> fc3`. `fc1`, `bn1_fc`, `bn_fc3`
    and `prob` are created but never used in this class. They are presumably kept
    to mirror the reference MCD code and keep `state_dict` keys stable; confirm
    before removing them.
    """
    def __init__(self, prob=0.5, lambd=1.0):
        super().__init__()
        self.prob = prob  # NOTE(review): not read anywhere in this class
        self.lambd = lambd
        # Unused by `forward`; registered in this exact order so state_dict keys are unchanged.
        self.fc1 = nn.Linear(8192, 3072)
        self.bn1_fc = nn.BatchNorm1d(3072)
        # The layers `forward` actually uses.
        self.fc2 = nn.Linear(3072, 2048)
        self.bn2_fc = nn.BatchNorm1d(2048)
        self.fc3 = nn.Linear(2048, 10)
        self.bn_fc3 = nn.BatchNorm1d(10)  # unused by `forward`
        # NOTE(review): built once from `lambd`; whether reassigning `self.lambd`
        # later has any effect depends on GradReverse. Confirm before relying on it.
        self.gr = GradReverse(lambd)

    def forward(self, x, reverse=False):
        "Return 10 logits for features `x`; when `reverse` is set, `x` first goes through `self.gr`."
        h = self.gr(x) if reverse else x
        h = self.bn2_fc(self.fc2(h))
        return self.fc3(F.relu(h))


In [None]:
#| export

class McdDataset:
    """Zip a source and a target dataset index-by-index.

    Item `idx` is `(xs, xt, ys, yt)`, built from `source_ds[idx]` and `target_ds[idx]`.
    The length is that of the shorter dataset. Presumably consumed with two inputs
    (`n_inp=2`) so that `xb = (xs, xt)` and `yb = (ys, yt)`, as `McdCallback`
    expects; verify against the DataLoaders setup.
    """
    def __init__(self, source_ds, target_ds):
        store_attr()

    def __len__(self):
        return min(len(ds) for ds in (self.source_ds, self.target_ds))

    def __getitem__(self, idx):
        src_x, src_y = self.source_ds[idx]
        tgt_x, tgt_y = self.target_ds[idx]
        # Inputs first, then labels.
        return src_x, tgt_x, src_y, tgt_y

In [None]:
#| export

class McdModel(nn.Module):
    """One shared feature extractor feeding two classifier heads.

    `forward(img, grad_reverse)` returns `(classifier1(feat, grad_reverse),
    classifier2(feat, grad_reverse))`, where `feat = feature_extractor(img)` is
    computed once.
    """
    def __init__(self, feature_extractor, classifier1, classifier2):
        super().__init__()
        store_attr()

    def forward(self, img, grad_reverse=False):
        feat = self.feature_extractor(img)
        heads = (self.classifier1, self.classifier2)
        return tuple(head(feat, grad_reverse) for head in heads)

class McdCallback(Callback):
    """Replace fastai's default batch step with a one-step MCD update.

    Expects `xb == (source_img, target_img)` and either no targets (e.g. when
    predicting) or two targets `yb == (source_y, target_y)`. For each batch:

    - source pass: `model(xb[0], grad_reverse=False)`, loss from
      `classification_loss_func(pred, *yb)`;
    - target pass: `model(xb[1], grad_reverse=True)`, loss from
      `discrepancy_loss_func(pred, *yb)`;
    - in training, both losses are back-propagated, then `opt.step()` and
      `opt.zero_grad()` are called.

    `learn.pred` is the source outputs followed by the target outputs. When
    targets are present, `learn.loss` is the 2-element tensor
    `[source_loss, target_loss]`. Note that fastai's backward/step events are not
    fired, so callbacks hooked on them (e.g. gradient clipping) presumably won't run.
    """
    def __init__(self, classification_loss_func, discrepancy_loss_func):
        super().__init__()
        store_attr()

    def before_batch(self, *args, **kwargs):
        "Run the MCD step, then cancel fastai's default batch processing."
        self._do_one_batch()
        raise CancelBatchException

    def before_fit(self):
        "Move the loss functions to the DataLoaders' device when they are modules/losses."
        device = getattr(self.dls, 'device', default_device())
        if isinstance(self.classification_loss_func, (nn.Module, BaseLoss)):
            self.classification_loss_func.to(device)
        if isinstance(self.discrepancy_loss_func, (nn.Module, BaseLoss)):
            self.discrepancy_loss_func.to(device)

    def _do_one_batch(self):
        "Forward both domains, fire `after_pred`/`after_loss`, and step the optimizer in training."
        assert len(self.xb) == 2
        # Unlabeled batches (e.g. prediction) are allowed; every branch below handles empty `yb`.
        assert len(self.yb) in (0, 2)
        source_pred, source_loss = self._predict_source()
        target_pred, target_loss = self._predict_target()
        self.learn.pred = (*source_pred, *target_pred)
        self.learn('after_pred')
        if source_loss is not None and target_loss is not None:
            # Detached, on the losses' own device: no host sync and no graph kept alive.
            self.learn.loss = torch.stack([source_loss.detach().float(), target_loss.detach().float()])
        self.learn('after_loss')
        if not self.training or not len(self.yb):
            return
        self._do_grad_opt()

    def _predict_source(self):
        "Predict on source images; back-propagate the classification loss in training."
        img = self.xb[0]
        pred = self.model(img, grad_reverse=False)
        loss = None
        if len(self.yb):
            loss = self.classification_loss_func(pred, *self.yb)
            if self.training:
                # The target pass builds its own graph, so this one need not be retained.
                loss.backward()
        return pred, loss

    def _predict_target(self):
        "Predict on target images with gradient reversal; back-propagate the discrepancy loss in training."
        img = self.xb[1]
        pred = self.model(img, grad_reverse=True)
        loss = None
        if len(self.yb):
            loss = self.discrepancy_loss_func(pred, *self.yb)
            if self.training:
                loss.backward()
        return pred, loss

    def _do_grad_opt(self):
        "Apply the accumulated gradients of both passes, then clear them."
        self.learn.opt.step()
        self.learn.opt.zero_grad()


In [None]:
#| hide
# Export the `#| export` cells to Python modules with nbdev.
import nbdev; nbdev.nbdev_export()