# A gradient reversal layer

In [None]:
from torch.autograd import Function

class RevGrad(Function):
    # Exercise: custom autograd function backing the gradient reversal layer.
    # Both method bodies are intentionally left as placeholders; the
    # "Solutions" section at the end of this notebook fills them in.
    
    @staticmethod
    def forward(ctx, input_):
        # Forward pass: receives the autograd context `ctx` and the input.
        # TODO(exercise): write the forward pass here.
        ######

    @staticmethod
    def backward(ctx, grad_output):  
        # Backward pass: receives the gradient arriving from downstream.
        # A gradient reversal layer is expected to reverse that gradient.
        # TODO(exercise): write the backward pass here.
        #####


In [None]:
from torch.nn import Module

class RevGradLayer(Module):
    # Module wrapper so the gradient reversal can be dropped into a
    # torch.nn.Sequential like any other layer.
    
    def __init__(self, *args, **kwargs):
        """
        A gradient reversal layer.
        This layer has no parameters, and simply reverses the gradient
        in the backward pass.
        """

        # All arguments are forwarded unchanged to torch.nn.Module.
        super().__init__(*args, **kwargs)

    def forward(self, input_):
        # TODO(exercise): write the forward pass here; the "Solutions"
        # section at the end of this notebook shows the intended body.
        #######

In [None]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

In [None]:
import torch

# Feature extractor ("green" network): two 5x5 convolution + ReLU stages,
# a 2x2 max-pool between them, and a final global max-pool that collapses
# each of the 48 feature maps to a single value.
network_green = torch.nn.Sequential(
    torch.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2, 2),
    torch.nn.Conv2d(in_channels=32, out_channels=48, kernel_size=5),
    torch.nn.ReLU(),
    torch.nn.AdaptiveMaxPool2d(output_size=1),
)

# Label predictor ("purple" head): maps the 48 extracted features to a
# probability distribution over the 10 digit classes.
network_purple = torch.nn.Sequential(
    torch.nn.Linear(48, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 10),
    # Normalise over the class (last) axis explicitly. Omitting `dim` relies
    # on a deprecated implicit choice and emits a warning on every forward
    # pass; for 1-D and 2-D inputs the implicit choice is this same axis.
    torch.nn.Softmax(dim=-1),
)

# Domain classifier ("pink" head): the leading RevGradLayer sits between the
# shared features and this head; the remaining layers map the 48 features to
# a single sigmoid output.
network_pink = torch.nn.Sequential(
    RevGradLayer(),
    torch.nn.Linear(in_features=48, out_features=100),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=100, out_features=1),
    torch.nn.Sigmoid(),
)

In [None]:
class MnistNetwork(torch.nn.Module):
    """Shared feature extractor feeding a class head and a domain head."""

    def __init__(self, network_green, network_purple, network_pink):
        """Store the three sub-networks as child modules.

        Args:
            network_green: feature extractor applied to the raw input.
            network_purple: head producing the class-label prediction.
            network_pink: head producing the domain prediction.
        """
        super().__init__()
        self.network_green = network_green
        self.network_purple = network_purple
        self.network_pink = network_pink

    def forward(self, x):
        """Return ``(class_label_prediction, domain_prediction)`` for ``x``."""
        # Flatten the extractor output to rows of 48 features so that both
        # heads receive the same 2-D tensor.
        features = self.network_green(x).reshape(-1, 48)
        return self.network_purple(features), self.network_pink(features)

In [None]:
# Train the combined model on MNIST.

model = MnistNetwork(network_green, network_purple, network_pink)

# The class head ends in a Softmax, so the model already outputs
# probabilities. CrossEntropyLoss expects raw logits (it applies log-softmax
# itself), which would apply softmax twice. NLLLoss on the log-probabilities
# computes the intended cross-entropy.
class_loss = torch.nn.NLLLoss()
domain_loss = torch.nn.BCELoss()
optimiser = torch.optim.Adam(model.parameters(), lr=0.0005)

data = torch.utils.data.DataLoader(
    MNIST(
        "mnist", download=True, transform=ToTensor()
    ),
    batch_size=64,
)

for epoch in range(5):
    for i, (batch_x, batch_y) in enumerate(data):
        # We only have one domain in this example, so for now each sample
        # is assigned a random domain label (0 or 1). The target is shaped
        # (batch, 1) to match the domain head's output: BCELoss requires
        # input and target to have the same size.
        domain = (torch.rand((len(batch_y), 1)) > 0.5).float()
        class_predictions, domain_predictions = model(batch_x)
        # Clamp before the log so a probability that underflowed to zero
        # cannot produce an infinite loss.
        log_class_predictions = torch.log(class_predictions.clamp_min(1e-12))
        loss_class = class_loss(log_class_predictions, batch_y)
        loss_domain = domain_loss(domain_predictions, domain)
        loss = loss_class + 0.1 * loss_domain

        # do the backward pass:
        loss.backward()
        # take a gradient descent step:
        optimiser.step()
        # reset the gradients
        optimiser.zero_grad()

        if (i % 200) == 0:
            # print the loss regularly to see what's going on
            # (.item() prints plain numbers instead of tensor reprs):
            print(
                f"Classification loss: {loss_class.item():.4f}, "
                f"domain loss: {loss_domain.item():.4f}"
            )

---
# _Solutions_
---

In [None]:
# Solution for the RevGrad.forward placeholder: pass the input through unchanged.
return input_

# Solution for the RevGrad.backward placeholder: negate the incoming gradient.
return grad_output * -1



In [None]:
# Solution for the RevGradLayer.forward placeholder: delegate to the RevGrad
# autograd function.
return RevGrad.apply(input_)