# Unsupervised Domain Adaptation By Backpropagation #
We are given a source dataset $D_{S}$ of pairs $(x_{i}, y_{i})$, where $x_{i}$ is an instance and $y_{i}$ is the corresponding label. 
Our task is to classify a target dataset $D_{T}$ of instances $x_{i}$ for which we have no labels. We therefore wish to classify $D_{T}$ without target labels by leveraging $D_{S}$ via domain adaptation. To do this, we will use insights from the paper "Unsupervised Domain Adaptation By Backpropagation" https://arxiv.org/pdf/1409.7495.pdf

The structure of this assignment is as follows.

**Part 1 : Upper Bounding Performance**

You will build a supervised classifier for each of $D_{S}$ and $D_{T}$, given the full labelled datasets. The classifier trained on $D_{T}$ gives an upper bound on the performance we can expect on $D_{T}$, since it is the case where target training labels are available. 

**Part 2 : Lower Bounding Performance**

You will apply the classifier trained on $D_{S}$ directly to $D_{T}$. Here, we assume that training a supervised classifier on $D_{S}$ (for which we have labels) and applying it unchanged to $D_{T}$ is sufficient. 
The performance of this baseline should provide a lower bound. 

**Part 3 : Gradient Reversal Layer**

In this part, we apply the trick from [1] to perform unsupervised classification on $D_{T}$ by leveraging the available data in $D_{S}$. We will do this by training a model that has 2 heads. Head 1 will be used for performing classification - this will only be trained on $D_{S}$ since it has available labels. Head 2 will be used for distinguishing the two domains $D_{S}$ and $D_{T}$. It turns out that by trying to produce representations that fool Head 2, you can build a reasonably accurate classifier for $D_{T}$.  

[1] "Unsupervised Domain Adaptation By Backpropagation" https://arxiv.org/pdf/1409.7495.pdf

## Part 1 : Upper Bounding Performance ##

In [None]:
import torch 
from torchvision import datasets 
import matplotlib.pyplot as plt 
from torch.utils.data import DataLoader 
from data_loader import MNISTM
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torchvision import transforms
from tqdm.notebook import tqdm as tqdm_notebook  # the top-level tqdm_notebook import is deprecated
import numpy as np
from sklearn.manifold import TSNE 
import math

In [None]:
pil_transform = transforms.ToPILImage()
to_tensor_tform = transforms.ToTensor() 

**Load MNIST**

The MNIST Dataset is a set of grayscale images of digits from 0-9. 

Run the Cell below to load the training and test splits of MNIST

In [None]:
mnist_train = datasets.MNIST(root='./mnist', train=True, transform=to_tensor_tform, download=True)
mnist_test = datasets.MNIST(root='./mnist', train=False, transform=to_tensor_tform, download=True)

In [None]:
mnist_img, mnist_label = mnist_train[0]
print('Mnist Images have this shape : ', mnist_img.shape)
plt.imshow(pil_transform(mnist_img))
_ = plt.title('Digit : {}'.format(mnist_label))

**Load MNIST-M**

MNIST-M is a dataset generated by placing the MNIST digits (0-9) on random color background patches.

Run the Cell below to load the training and test splits of MNIST-M

In [None]:
mnistm_train = MNISTM(root='./mnistm', mnist_root='./mnist', train=True, transform=to_tensor_tform, download=True)
mnistm_test = MNISTM(root='./mnistm', mnist_root='./mnist', train=False, transform=to_tensor_tform, download=True)

In [None]:
mnistm_img, mnistm_label = mnistm_train[0]
print('MNIST-M Images have this shape : ', mnistm_img.shape)
plt.imshow(pil_transform(mnistm_img))
_ = plt.title('Digit : {}'.format(mnistm_label))

## Domain Specific Classifier ##
Train a classifier for each of the individual domains 

In [None]:
# Declaring constants here 
NCLASSES = 10 # Because MNIST, MNIST-M have 10 classes (0 - 9)
NEPOCHS = 8   # Number of epochs to run the model for 
BATCH_SZ = 32 # The batch size for training

**TODO [YOU]**

Build a classifier for the MNIST and MNIST-M datasets. Feel free to design your own architecture. However, if you want, you can use the following architecture : 

Conv (output_channels = 8, kernel = 3)  
|  
ReLU  
|  
Conv (output_channels = 16, kernel = 5)  
|  
ReLU  
|  
Maxpool (kernel = 2, stride = 2)  
|  
Conv (output_channels = 32, kernel = 3)  
|  
ReLU  
|  
Conv (output_channels = 32, kernel = 5)  
|  
Maxpool (kernel = 2, stride = 2)  
|  
Flatten  
|  
Linear (in = 128 , out = NCLASSES)

If not specified, all other parameters to these functions are the default. 

inchannels specifies the number of channels the input image has. Remember MNIST is grayscale (inchannels = 1) and MNIST-M is color (inchannels = 3)
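
As a sanity check on the suggested architecture, the sketch below traces the spatial size of the feature map through each layer to show where `in = 128` for the final Linear layer comes from. It assumes 28 x 28 inputs and the default stride of 1 and padding of 0 for the convolutions. The helper names `conv_out` and `flattened_features` are illustrative and are not part of the assignment code.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial size after a convolution or pooling layer (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(input_size=28):
    """Number of features reaching the Linear layer of the suggested architecture."""
    size = input_size
    size = conv_out(size, kernel=3)            # Conv, 8 channels   : 28 -> 26
    size = conv_out(size, kernel=5)            # Conv, 16 channels  : 26 -> 22
    size = conv_out(size, kernel=2, stride=2)  # Maxpool            : 22 -> 11
    size = conv_out(size, kernel=3)            # Conv, 32 channels  : 11 -> 9
    size = conv_out(size, kernel=5)            # Conv, 32 channels  : 9 -> 5
    size = conv_out(size, kernel=2, stride=2)  # Maxpool            : 5 -> 2
    channels = 32                              # output channels of the last Conv
    return channels * size * size


print(flattened_features(28))  # prints 128
```

The number of input channels does not affect this count, so under these assumptions the same Linear layer size applies to both MNIST (inchannels = 1) and MNIST-M (inchannels = 3).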


In [None]:
from utils import Flatten 

def get_domain_specific_model(inchannels=3):
    ######################## YOUR CODE BEGINS ########################
    domain_model = None
    ######################## YOUR CODE ENDS ########################
    assert domain_model is not None, 'Domain Model has not yet been implemented'
    domain_model.cuda()
    return domain_model

In [None]:
def visualize_results(desc, train_stats, valid_stats):
    f, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    ax1.plot(train_stats[1], label='Train Acc', c='green')
    ax1.plot(valid_stats[1], label='Valid Acc', c='red')
    ax1.set_title('{} : Accuracy Metric'.format(desc))
    ax1.legend()
    
    ax2.plot(train_stats[0], label='Train Loss', c='green')
    ax2.plot(valid_stats[0], label='Valid Loss', c='red')
    ax2.set_title('{} : Loss Metric'.format(desc))
    ax2.legend()
    
    plt.show()
    
def domain_epoch(epoch, domain_model, domain_data, optim, criterion, is_train=True): 
    print('Starting {} Epoch {}'.format('Train' if is_train else 'Valid', epoch))
    if is_train:
        domain_model.train()
    else:
        domain_model.eval()
    
    losses, total_correct, total = [], 0, 0
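    for batch in tqdm_notebook(domain_data):
        imgs, labels = batch
        # The DataLoader already yields tensors, so just move them to the GPU
        imgs, labels = imgs.cuda(), labels.cuda()
        
        # Skip building the autograd graph during evaluation
        with torch.set_grad_enabled(is_train):
            logits = domain_model(imgs)
            loss = criterion(logits, labels)
        if is_train:
            domain_model.zero_grad()
            loss.backward()
            optim.step()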
        
        losses.append(loss.item())
        total_correct += logits.argmax(dim=-1).eq(labels).sum().item()
        total += logits.shape[0]
    
    avg_loss, avg_acc = np.mean(losses), (total_correct * 1.0) / total
    return avg_loss, avg_acc


def model_pipeline(train_data, valid_data, inchannels=3, lr=3e-4):
    criterion = nn.CrossEntropyLoss() 
    domain_model = get_domain_specific_model(inchannels=inchannels)
    parameters = list(domain_model.parameters())
    optim = Adam(parameters, lr=lr)
    train_stats, valid_stats = [[], []], [[], []]
    for i in range(NEPOCHS):
        t_loss_avg, t_acc_avg = domain_epoch(i, domain_model, train_data, optim, criterion, is_train=True)
        train_stats[0].append(t_loss_avg)
        train_stats[1].append(t_acc_avg)
        v_loss_avg, v_acc_avg = domain_epoch(i, domain_model, valid_data, None, criterion, is_train=False)
        valid_stats[0].append(v_loss_avg)
        valid_stats[1].append(v_acc_avg)
        
    return domain_model, train_stats, valid_stats

**Perform Domain Specific Training**  
Run the code below. The outputs of the model pipeline will be  
1. A trained model
2. An array containing the training statistics
3. An array containing the validation statistics 

In [None]:
print('Doing the domain specific training for MNIST')
mnist_train_loader = DataLoader(mnist_train, batch_size=BATCH_SZ)
mnist_val_loader = DataLoader(mnist_test, batch_size=BATCH_SZ)
mnist_model, mnist_train_stats, mnist_test_stats = model_pipeline(mnist_train_loader, mnist_val_loader, inchannels=1)
visualize_results('MNIST', mnist_train_stats, mnist_test_stats)
print("MNIST - Best Test Loss {}, Best Test Accuracy {}".format(min(mnist_test_stats[0]), max(mnist_test_stats[1])))

Verify that your test loss is decreasing and your test accuracy is increasing from the graphs above. 
Your Test Accuracy should be > 90 %

In [None]:
print('Doing the domain specific training for MNIST-M')
mnistm_train_loader = DataLoader(mnistm_train, batch_size=BATCH_SZ)
mnistm_val_loader = DataLoader(mnistm_test, batch_size=BATCH_SZ)
mnistm_model, mnistm_train_stats, mnistm_test_stats = model_pipeline(mnistm_train_loader, mnistm_val_loader, inchannels=3)
visualize_results('MNIST-M', mnistm_train_stats, mnistm_test_stats)
print("MNISTM - Best Test Loss {}, Best Test Accuracy {}".format(min(mnistm_test_stats[0]), max(mnistm_test_stats[1])))

Verify that your test loss is decreasing and your test accuracy is increasing from the graphs above. 
Your Test Accuracy should be > 90 %

## Part 2 : Lower Bounding Performance ##

In this section, you will perform Cross Domain Evaluation. This will set a lower bound to your classifier performance. 

**TODO [YOU]**  
In Part 1, you trained 2 models, **mnist_model** and **mnistm_model**. You will evaluate **mnist_model** on the MNIST-M dataset. 
Remember that the MNIST-M data is color whilst **mnist_model** was trained on grayscale images. You will have to apply the following transforms to the MNIST-M images  
1. GrayScale
2. Convert from Image to Tensor

After this, run 1 epoch of **mnist_model** in evaluation mode on the cross-domain MNIST-M dataset. Use the **domain_epoch** function defined above 

In [None]:
######################## YOUR CODE BEGINS ########################
mnistm_transforms = None
######################## YOUR CODE ENDS ########################

cross_mnistm_test = MNISTM(root='./mnistm', mnist_root='./mnist', train=False, transform=mnistm_transforms, download=True)
cross_mnistm_val_loader = DataLoader(cross_mnistm_test, batch_size=32)

cross_mnistm_train = MNISTM(root='./mnistm', mnist_root='./mnist', train=True, transform=mnistm_transforms, download=True)
cross_mnistm_train_loader = DataLoader(cross_mnistm_train, batch_size=32)

In [None]:
######################## YOUR CODE BEGINS ########################
# Run an epoch of mnist_model on cross_mnistm_val_loader. Remember we only care about evaluation
criterion = nn.CrossEntropyLoss() 
loss_avg, acc_avg = None, None
######################## YOUR CODE ENDS ########################
print('Doing Cross Domain Evaluation Yields : Loss - {}, Accuracy - {}'.format(loss_avg, acc_avg))

As you can see, the cross-domain accuracy is very poor (roughly 50%). In the next section, we will explore an approach to improve this accuracy. 

## Part 3 : Gradient Reversal Layer ##

In this section, we will implement an unsupervised approach for domain adaptation - "Unsupervised Domain Adaptation By Backpropagation" https://arxiv.org/pdf/1409.7495.pdf.  
![Unsupervised Domain Adaptation By Backpropagation](figs/paper_fig.png)

The main idea behind this paper is to train 1 model to classify both $D_{S}$ and $D_{T}$.  
Our model will have 1 head for classifying samples from both domains. This head, $G_y$ with parameters $\theta_y$, will classify samples into NCLASSES where NCLASSES = 10 for our running example of the digit classification task. 
In order for $G_y$ to be a good classifier for both domains, we need to project $D_{S}$ and $D_{T}$ into a single, domain-invariant space **F** (the hope is that this space captures relevant information for classifying digits and discards irrelevant information like background). What this means is that after embedding $D_{S}$ and $D_{T}$ into **F**, we should not be able to distinguish samples from the two domains.
Our model has a trunk $G_{f}$ with parameters $\theta_f$ that is responsible for embedding samples into the domain-invariant space **F**. In order to ensure that the space defined by $G_{f}$ is domain-invariant, we augment our model with another head $G_d$ with parameters $\theta_d$, which is responsible for identifying the domain of samples from **F**. Thus, whilst $G_d$ tries to achieve maximum domain classification accuracy, $G_f$ tries to make that accuracy as low as possible so that **F** is truly domain-invariant.  
The above setup thus defines a kind of min-max adversarial game between $G_f$ and $G_d$.

More formally, let $L_d$ be the domain classification loss and $L_y$ be the digit classification loss. Let $d_i = 0$ if instance $x_i$ belongs to the source domain and $d_i = 1$ if it belongs to the target domain. Our overall loss function will be :  

![Loss](figs/loss_fig.png)

Running plain gradient descent on this loss will not work. Remember that we want a min-max adversarial game between $G_f$ and $G_d$ on the domain loss $L_d$, but with the current objective both would minimize $L_d$. Since we want $G_f$ to maximize $L_d$, we have to perform gradient ascent rather than gradient descent on $\theta_f$ w.r.t. $L_d$. 
This is achieved by inserting a Gradient Reversal Layer (GRL) ($R_\lambda$ in the figure above) between $G_f$ and $G_d$. $R_\lambda$ reverses the gradient coming from the domain classification head before it reaches the feature trunk, as shown in the figure above. The factor $\lambda$ weights the domain loss against the classification loss in the update of $G_f$.
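
To make the effect of the reversal concrete, here is a toy scalar sketch in plain Python. It is not the PyTorch layer you are asked to implement. It assumes a one-parameter trunk `f = w_f * x`, a one-parameter domain head `z = w_d * f`, and a squared-error stand-in for $L_d$ chosen only to keep the arithmetic simple. All names are hypothetical.

```python
def domain_loss(w_f, w_d, x, d):
    """Toy stand-in for L_d: squared error between the domain output and label d."""
    z = w_d * (w_f * x)
    return 0.5 * (z - d) ** 2


def trunk_step(w_f, w_d, x, d, lambda_, lr, use_grl):
    """One gradient-descent step on the trunk parameter w_f only."""
    f = w_f * x
    z = w_d * f
    grad_f = (z - d) * w_d          # dL_d/df arriving from the domain head
    if use_grl:
        grad_f = -lambda_ * grad_f  # the GRL flips the sign and scales by lambda
    grad_w_f = grad_f * x           # chain rule through f = w_f * x
    return w_f - lr * grad_w_f


w_f, w_d, x, d = 1.0, 1.0, 1.0, 0.0
before = domain_loss(w_f, w_d, x, d)
plain = domain_loss(trunk_step(w_f, w_d, x, d, 1.0, 0.1, use_grl=False), w_d, x, d)
reversed_ = domain_loss(trunk_step(w_f, w_d, x, d, 1.0, 0.1, use_grl=True), w_d, x, d)
print(before, plain, reversed_)
```

Without the reversal, the descent step on the trunk lowers the domain loss. With it, the very same optimizer step raises the domain loss, which is the gradient ascent behaviour we want for $G_f$, while $G_d$ itself is still updated by ordinary descent.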

**TODO [YOU]**  
Implement a Gradient Reversal Layer. You can extend PyTorch with a new Autograd Layer by following the examples here :
https://pytorch.org/docs/stable/notes/extending.html

Your **forward** function should take in **input** and **lambda_**. The forward pass should return the input unchanged and only save the information required for the backward pass.  

![GRL](figs/grl_fig.png)

Your **backward** function should return **the incoming gradient scaled by $-\lambda$** as the gradient of the input, and **a dummy gradient for lambda_**, which you can set to zero (returning `None` for it also works in PyTorch). 
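
For reference, the behaviour described above can be written compactly. This is a sketch in the notation of this section, with $I$ denoting the identity matrix (a symbol not used elsewhere in this notebook):

$$
R_\lambda(x) = x, \qquad \frac{\partial R_\lambda}{\partial x} = -\lambda I
$$

So if $g$ is the gradient arriving at the layer from $G_d$ during backpropagation, the gradient passed on to $G_f$ is $-\lambda g$.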

In [None]:
from torch.autograd import Function
class GradRevLayer(Function):
    ######################## YOUR CODE BEGINS ########################
    @staticmethod
    def forward(ctx, input, lambda_):
        pass

    @staticmethod
    def backward(ctx, grad_output):
        pass
    ######################## YOUR CODE ENDS ##########################

**TODO [YOU]**  
Implement the full Gradient Reversal Model. Use the architecture below in your **init** method. 
![architecture](figs/arch.png)

You can instantiate a GRL via :  

grad_rev_layer = GradRevLayer.apply

Your **forward** method should take the input example **x** and the current value of  **lambda_** and return 3 things in the following order  
1. the logits for the digit classification task
2. the logits for the domain classification task 
3. the features from $G_f$ before they are passed through the gradient reversal layer

In [None]:
class GradReversalModel(nn.Module):
    def __init__(self):
        super(GradReversalModel, self).__init__()
        ######################## YOUR CODE BEGINS ########################
        pass
        ######################## YOUR CODE ENDS ##########################
    
    def forward(self, x, lambda_):
        ######################## YOUR CODE BEGINS ########################
        pass
        ######################## YOUR CODE ENDS ##########################

In [None]:
def get_features(model, dataset, max_examples=1000):
    feature_list, num_examples = [], 0
    for batch in dataset:
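        imgs, labels = batch
        imgs = imgs.cuda()
        # Per the GRL spec above, lambda_ only affects the backward pass, so any float works here
        dummy_lambda_ = torch.tensor(1.0).cuda()
        # Features are only visualized, so no autograd graph is needed
        with torch.no_grad():
            _, _, features = model(imgs, dummy_lambda_)
        feature_list.append(features)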
        num_examples += imgs.shape[0]
        if num_examples >  max_examples:
            break
    return torch.cat(feature_list)

def tsne_visualize(final_source, final_target, init_source, init_target):
    tsne =  TSNE(n_components=2)

    init_embeds = tsne.fit_transform(torch.cat([init_source, init_target]).detach().cpu().numpy())
    final_embeds = tsne.fit_transform(torch.cat([final_source, final_target]).detach().cpu().numpy())
    
    f, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    src_len = init_source.shape[0]
    ax1.scatter(init_embeds[:src_len, 0], init_embeds[:src_len, 1], marker='.', color='green', label='source')
    ax1.scatter(init_embeds[src_len:, 0], init_embeds[src_len:, 1], marker='.', color='red', label='target')
    ax1.set_title("Initial Distribution Before Domain Adaptation")
    ax1.legend()
    
    src_len = final_source.shape[0]
    ax2.scatter(final_embeds[:src_len, 0], final_embeds[:src_len, 1], marker='.', color='green', label='source')
    ax2.scatter(final_embeds[src_len:, 0], final_embeds[src_len:, 1], marker='.', color='red', label='target')
    ax2.set_title("Final Distribution After Domain Adaptation")
    ax2.legend()
    
    plt.show()

In [None]:
def unsup_da_epoch(epoch, model, source_data, target_data, optim, class_criterion, domain_criterion, is_train=True): 
    print('Starting {} Epoch {}'.format('Train' if is_train else 'Valid', epoch))
    
    # Set the model to train or eval mode
    if is_train:
        model.train()
    else:
        model.eval()
    
    def get_lambda(train_progress, gamma=10):
        lambda_ =  ( 2 / (1 + math.exp(-gamma * train_progress) ) ) - 1
        lambda_ = torch.tensor(lambda_, requires_grad=False).cuda()
        return lambda_
        
    def run_batch(batch, lambda_, is_source=True):
        imgs, labels = batch 
        imgs, labels = imgs.cuda(), labels.cuda()
        
        class_logits, domain_logits, _ = model(imgs, lambda_)
        class_loss = class_criterion(class_logits, labels)
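        if is_source : 
            # Source samples carry domain label d_i = 0 and contribute to both losses
            domain_labels = torch.zeros_like(domain_logits)
            domain_loss = domain_criterion(domain_logits, domain_labels)
            total_loss = class_loss + domain_loss
        else:
            # Target samples carry domain label d_i = 1; their digit loss is logged but never backpropagated
            domain_labels = torch.ones_like(domain_logits)
            domain_loss = domain_criterion(domain_logits, domain_labels)
            total_loss = domain_loss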

        if is_train:
            total_loss.backward()

        class_loss = class_loss.item()
        class_accuracy = class_logits.argmax(dim=-1).eq(labels).sum().item()
        
        domain_loss = domain_loss.item()
        domain_accuracy = ((torch.sigmoid(domain_logits) > 0.5).long()).eq(domain_labels.long()).sum().item()
        
        return class_loss, class_accuracy, domain_loss, domain_accuracy, labels.shape[0]
    
    
    
    source_stats = [[], [], 0, 0, 0]
    target_stats = [[], [], 0, 0, 0]

    target_data = iter(target_data)
    num_batches, cur_batch  = len(source_data), 1
    for batch in tqdm_notebook(source_data):
        lambda_ = get_lambda((num_batches * epoch + cur_batch ) / (num_batches * NEPOCHS))
        class_loss, class_corr, domain_loss, domain_corr, sz = run_batch(batch, lambda_)
        
        # Save the statistics of the batch
        source_stats[0].append(class_loss)
        source_stats[2] += class_corr
        source_stats[1].append(domain_loss)
        source_stats[3] += domain_corr
        source_stats[4] += sz
        
        class_loss, class_corr, domain_loss, domain_corr, sz = run_batch(next(target_data), lambda_, is_source=False)
       
        # Save the statistics of the batch
        target_stats[0].append(class_loss)
        target_stats[2] += class_corr
        target_stats[1].append(domain_loss)
        target_stats[3] += domain_corr
        target_stats[4] += sz
        
        if is_train:
            optim.step()
            model.zero_grad()
        
        cur_batch += 1
        
    source_stats = np.mean(source_stats[0]),  source_stats[2] * 1.0 /  source_stats[4], np.mean(source_stats[1]), source_stats[3] * 1.0 / source_stats[4]
    target_stats = np.mean(target_stats[0]),  target_stats[2] * 1.0 /  target_stats[4], np.mean(target_stats[1]), target_stats[3] * 1.0 / target_stats[4]        
    print('Source : digit loss {} | digit accuracy {} | domain loss {} | domain accuracy {}'.format(*source_stats))
    print('Target : digit loss {} | digit accuracy {} | domain loss {} | domain accuracy {}'.format(*target_stats))
    return source_stats, target_stats
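
A note on the `get_lambda` helper in the cell above: it ramps $\lambda$ up from 0 towards 1 as training progresses, following $\lambda_p = \frac{2}{1 + \exp(-\gamma p)} - 1$ with $\gamma = 10$, where $p$ is the fraction of training completed. This keeps the domain head's gradient from disturbing the trunk early in training. The standalone sketch below reproduces the schedule without PyTorch so you can inspect its values. The names `grl_lambda` and `train_progress` are illustrative and do not appear in the assignment code.

```python
import math


def grl_lambda(progress, gamma=10.0):
    """Ramp from 0 (progress = 0) towards 1 (progress = 1)."""
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0


def train_progress(epoch, cur_batch, num_batches, nepochs):
    """Fraction of all training batches seen so far, as computed in unsup_da_epoch."""
    return (num_batches * epoch + cur_batch) / (num_batches * nepochs)


for p in (0.0, 0.1, 0.25, 0.5, 1.0):
    print('progress {:.2f} -> lambda {:.4f}'.format(p, grl_lambda(p)))
```

Note that $\lambda$ is already close to 1 by the halfway point, so most of training runs with a nearly full-strength reversed gradient.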

**TODO [YOU]**  
What type of loss should be used for digit classification?  
What type of loss should be used for domain classification? 

In [None]:
def unsup_da_model_pipeline(model, source_train, source_valid, target_train, target_valid, lr=5e-5):
    ######################## YOUR CODE BEGINS ########################
    class_criterion = None 
    domain_criterion = None
    ######################## YOUR CODE ENDS ##########################
    parameters = list(model.parameters())
    optim = Adam(parameters, lr=lr)
    source_train_stats, source_valid_stats = [], []
    target_train_stats, target_valid_stats = [], []
    for i in range(NEPOCHS):
        # Perform Training Procedure
        source_stats, target_stats = unsup_da_epoch(i, model, source_train, target_train, optim, class_criterion, domain_criterion, is_train=True)
        source_train_stats.append(source_stats)
        target_train_stats.append(target_stats)
        
        # Perform Evaluation Procedure
        source_stats, target_stats = unsup_da_epoch(i, model, source_valid, target_valid, None, class_criterion, domain_criterion, is_train=False)
        source_valid_stats.append(source_stats)
        target_valid_stats.append(target_stats)
        
    return model, source_train_stats, source_valid_stats, target_train_stats, target_valid_stats

In [None]:
tform = transforms.Compose([ 
                            transforms.ToTensor(),
                            transforms.Lambda(lambda x: torch.cat([x, x, x], 0)),
                          ])

mnist_train = datasets.MNIST(root='./mnist', train=True, transform=tform, download=True)
mnist_test = datasets.MNIST(root='./mnist', train=False, transform=tform, download=True)
mnist_train_loader = DataLoader(mnist_train, batch_size=BATCH_SZ)
mnist_val_loader = DataLoader(mnist_test, batch_size=BATCH_SZ)

tform = transforms.Compose([
                            transforms.ToTensor(),
                          ])

cross_mnistm_test = MNISTM(root='./mnistm', mnist_root='./mnist', train=False, transform=tform, download=True)
cross_mnistm_train = MNISTM(root='./mnistm', mnist_root='./mnist', train=True, transform=tform, download=True)
cross_mnistm_train_loader = DataLoader(cross_mnistm_train, batch_size=BATCH_SZ)
cross_mnistm_val_loader = DataLoader(cross_mnistm_test, batch_size=BATCH_SZ)

In [None]:
model =  GradReversalModel().cuda()

In [None]:
# Get the initial features from the model before training 
init_source = get_features(model, mnist_val_loader)
init_target = get_features(model, cross_mnistm_val_loader)

In [None]:
NEPOCHS = 20
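# NOTE: the loaders above were already built with the earlier BATCH_SZ, so this value only affects loaders created after this cell
BATCH_SZ = 64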
# TRAIN THE GRL MODEL. 
model, source_train_stats, source_valid_stats, target_train_stats, target_valid_stats = unsup_da_model_pipeline(model, mnist_train_loader, mnist_val_loader, cross_mnistm_train_loader, cross_mnistm_val_loader)

Observe the target digit accuracy as the model trains. How does the best target accuracy compare with the cross-domain evaluation done above? 
You should be getting an accuracy of roughly 70% with your GRL model, about 20 percentage points more than the cross-domain evaluation accuracy.   
Run the cell below to visualize the accuracies and losses during training for both the source and target. Do the curves follow your expectations?

In [None]:
train_c_loss, train_c_acc, train_d_loss, train_d_acc = list(zip(*source_train_stats))
valid_c_loss, valid_c_acc, valid_d_loss, valid_d_acc = list(zip(*source_valid_stats))

visualize_results('MNIST CLASS', [train_c_loss, train_c_acc,], [valid_c_loss, valid_c_acc])
visualize_results('MNIST DOMAIN', [train_d_loss, train_d_acc,], [valid_d_loss, valid_d_acc])

In [None]:
train_c_loss, train_c_acc, train_d_loss, train_d_acc = list(zip(*target_train_stats))
valid_c_loss, valid_c_acc, valid_d_loss, valid_d_acc = list(zip(*target_valid_stats))

visualize_results('MNISTM CLASS', [train_c_loss, train_c_acc,], [valid_c_loss, valid_c_acc])
visualize_results('MNISTM DOMAIN', [train_d_loss, train_d_acc,], [valid_d_loss, valid_d_acc])

As mentioned in the instructions above, we want $G_f$ to embed samples into a domain-invariant space **F**. 
Run the cell below to visualize the features from $G_f$ before and after training the GRL model. 
What do you notice? Is this what you expect? 

In [None]:
final_source = get_features(model, mnist_val_loader)
final_target = get_features(model, cross_mnistm_val_loader)
tsne_visualize(final_source, final_target, init_source, init_target)