# Learning to Reweight Examples for Robust Deep Learning

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from model import *
from data_loader import *
# Imported explicitly (after the star imports) because the training cells use
# `np` directly and should not rely on a star import leaking it.
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import IPython
import gc

%matplotlib inline
%load_ext autoreload
%autoreload 2

  from ._conv import register_converters as _register_converters


In [3]:
# Training configuration read by the data loaders, the optimizer and both
# training loops in this notebook.
# NOTE: 'momentum' is declared here, but the optimizer built in this notebook
# is only given 'lr'.
hyperparameters = dict(
    lr=1e-3,              # step size for SGD and for the meta-network update
    momentum=0.9,
    batch_size=100,       # mini-batch size passed to both MNIST loaders
    num_iterations=8000,  # number of training steps in each experiment
)

### Dataset
Following the class-imbalance experiment in the paper, we use the digits 9 and 4 from the MNIST dataset to form a highly imbalanced training set in which 9 is the dominating class. The test set, on the other hand, is balanced.

In [4]:
# Training loader: digits 9 and 4 with proportion=0.995 (the heavily
# imbalanced split described in the text above).
data_loader = get_mnist_loader(
    hyperparameters['batch_size'],
    classes=[9, 4],
    proportion=0.995,
    mode="train",
)

# Test loader: same two digits with proportion=0.5 (the balanced split).
test_loader = get_mnist_loader(
    hyperparameters['batch_size'],
    classes=[9, 4],
    proportion=0.5,
    mode="test",
)

In [5]:
def to_var(x, requires_grad=True):
    """Wrap a tensor in an autograd ``Variable``, moving it to the GPU if one exists.

    Args:
        x: tensor to wrap.
        requires_grad: whether the returned variable should track gradients.

    Returns:
        A ``Variable`` built from ``x`` (or from ``x.cuda()`` when CUDA is
        available) with the requested ``requires_grad`` flag.
    """
    tensor = x.cuda() if torch.cuda.is_available() else x
    return Variable(tensor, requires_grad=requires_grad)

#### Since the validation data is small (only 10 examples), there is no need to wrap it in a dataloader

In [6]:
# The validation split is kept as two fixed tensors (inputs and labels) taken
# straight from the training dataset object; neither needs gradients.
val_data, val_labels = (
    to_var(tensor, requires_grad=False)
    for tensor in (data_loader.dataset.data_val, data_loader.dataset.labels_val)
)

In [7]:
# Peek at the first training batch only: print the image tensor size and the
# labels, then stop iterating.
for i, batch in enumerate(data_loader):
    img, label = batch
    print(img.size(), label)
    break

torch.Size([100, 1, 32, 32]) tensor([ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.])


In [8]:
# Peek at the first test batch only: print the image tensor size and the
# labels, then stop iterating.
for i, batch in enumerate(test_loader):
    img, label = batch
    print(img.size(), label)
    break

torch.Size([100, 1, 32, 32]) tensor([ 0.,  0.,  0.,  0.,  0.,  1.,  0.,  1.,  1.,  0.,  1.,  0.,
         1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  1.,  0.,  1.,  1.,
         1.,  0.,  1.,  0.,  0.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,
         0.,  1.,  0.,  1.,  0.,  0.,  1.,  1.,  0.,  1.,  1.,  0.,
         0.,  1.,  1.,  1.,  1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,
         0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  1.,
         0.,  1.,  0.,  1.,  1.,  1.,  0.,  1.,  0.,  1.,  0.,  1.,
         0.,  0.,  0.,  1.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,  0.,
         0.,  1.,  0.,  1.])


In [9]:
# LeNet with a single output; the training cells feed that output to
# binary_cross_entropy_with_logits.
net = LeNet(n_out=1)

if torch.cuda.is_available():
    # Turn on cuDNN benchmark mode and move the model to the GPU.
    torch.backends.cudnn.benchmark = True
    net.cuda()

NameError: name 'to_var' is not defined

In [None]:
# SGD over the network's parameters. Only the learning rate is taken from
# `hyperparameters`; the 'momentum' entry is not passed here.
opt = torch.optim.SGD(
    net.params(),
    lr=hyperparameters["lr"],
)

## Baseline Model
As a baseline for comparison, I trained a LeNet model on the MNIST data without reweighting the loss.

In [None]:
# Baseline experiment: train `net` with an unweighted binary cross-entropy
# loss and, every `plot_step` iterations, measure accuracy on `test_loader`
# and redraw the loss / accuracy curves.
net_losses = []
plot_step = 100  # evaluate and redraw every `plot_step` iterations

smoothing_alpha = 0.9
# Exponential moving average of the training loss. It has to be initialised
# here; otherwise the first update below raises NameError on a fresh kernel.
net_l = 0
accuracy_log = []
for i in tqdm(range(hyperparameters['num_iterations'])):
    net.train()
    # A new iterator is created on every step, so each step takes the first
    # batch produced by a fresh pass over the loader.
    image, labels = next(iter(data_loader))

    image = to_var(image, requires_grad=False)
    labels = to_var(labels, requires_grad=False)

    # real network
    y = net(image)
    cost = F.binary_cross_entropy_with_logits(y, labels)

    opt.zero_grad()
    cost.backward()
    opt.step()

    # Smoothed loss; dividing by (1 - alpha^(i+1)) corrects the bias caused by
    # starting the average at zero.
    net_l = smoothing_alpha * net_l + (1 - smoothing_alpha) * cost.item()
    net_losses.append(net_l / (1 - smoothing_alpha ** (i + 1)))

    if i % plot_step == 0:
        net.eval()

        acc = []
        # Evaluation needs no gradients, so skip building autograd graphs.
        with torch.no_grad():
            for test_img, test_label in test_loader:
                test_img = to_var(test_img, requires_grad=False)
                test_label = to_var(test_label, requires_grad=False)

                output = net(test_img)
                # Threshold the sigmoid of the logit at 0.5 to get a 0/1 prediction.
                predicted = (torch.sigmoid(output) > 0.5).int()

                acc.append((predicted == test_label.int()).float())

        # Convert to a Python float so NumPy never receives a (possibly CUDA) tensor.
        accuracy = torch.cat(acc, dim=0).mean().item()
        accuracy_log.append(np.array([i, accuracy])[None])

        IPython.display.clear_output()
        fig, axes = plt.subplots(1, 2, figsize=(13, 5))
        ax1, ax2 = axes.ravel()

        ax1.plot(net_losses, label='net_losses')
        ax1.set_ylabel("Losses")
        ax1.set_xlabel("Iteration")
        ax1.legend()

        # Rows of acc_log are (iteration, accuracy) pairs.
        acc_log = np.concatenate(accuracy_log, axis=0)
        ax2.plot(acc_log[:, 0], acc_log[:, 1])
        ax2.set_ylabel('Accuracy')
        ax2.set_xlabel('Iteration')
        plt.show()

As expected, due to the heavily imbalanced training data, the network could not learn how to differentiate between 9 and 4.

## Learning to Reweight Examples 
Below is the pseudocode of the method proposed in the paper. It is very straightforward.

![pseudocode](pseudocode.PNG)

In [None]:
# Learning-to-reweight experiment. On each step a throw-away copy of the
# network takes one gradient step on a per-example weighted loss (weights eps,
# initialised to zero); the gradient of the validation loss with respect to
# eps then yields the example weights used to train the real network.

# Start from a freshly initialised network and optimizer. Without this the
# cell would keep training the model already fitted by the baseline cell, so
# the two experiments would not be comparable.
net = LeNet(n_out=1)
if torch.cuda.is_available():
    net.cuda()
opt = torch.optim.SGD(net.params(), lr=hyperparameters["lr"])

meta_losses_clean = []
net_losses = []
plot_step = 100  # evaluate and redraw every `plot_step` iterations
torch.backends.cudnn.benchmark = True

smoothing_alpha = 0.9
# Exponential moving averages of the validation (meta) loss and training loss.
meta_l = 0
net_l = 0
accuracy_log = []
for i in tqdm(range(hyperparameters['num_iterations'])):
    net.train()
    # A new iterator is created on every step, so each step takes the first
    # batch produced by a fresh pass over the loader.
    image, labels = next(iter(data_loader))
    # since validation data is small I just fixed them instead of building an iterator
    # initialize a dummy network for the meta learning of the weights
    meta_net = LeNet(n_out=1)
    meta_net.load_state_dict(net.state_dict())

    if torch.cuda.is_available():
        meta_net.cuda()

    image = to_var(image, requires_grad=False)
    labels = to_var(labels, requires_grad=False)

    # Per-example losses of the copied network (no reduction).
    # `reduce=False` is kept for compatibility with older PyTorch releases;
    # newer releases spell this `reduction='none'`.
    y_f_hat = meta_net(image)
    cost = F.binary_cross_entropy_with_logits(y_f_hat, labels, reduce=False)
    # eps requires grad (to_var default) so the validation loss can be
    # differentiated with respect to it below.
    eps = to_var(torch.zeros(cost.size()))
    l_f_meta = torch.sum(cost * eps)

    meta_net.zero_grad()

    # create_graph=True keeps the graph of this gradient so that the updated
    # parameters remain differentiable with respect to eps.
    grads = torch.autograd.grad(l_f_meta, (meta_net.params()), create_graph=True)

    meta_net.update_params(hyperparameters['lr'], source_params=grads)

    # Validation loss of the copied network after its single update step.
    y_g_hat = meta_net(val_data)

    l_g_meta = F.binary_cross_entropy_with_logits(y_g_hat, val_labels)

    grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True)[0]

    # Keep only the non-negative part of the negated gradient and normalise it
    # to sum to one (left as all zeros when every weight is zero).
    w_tilde = torch.clamp(-grad_eps, min=0)
    norm_c = torch.sum(w_tilde)

    if norm_c != 0:
        w = w_tilde / norm_c
    else:
        w = w_tilde

    # real network
    y_f_hat = net(image)
    cost = F.binary_cross_entropy_with_logits(y_f_hat, labels, reduce=False)
    l_f = torch.sum(cost * w)

    opt.zero_grad()
    l_f.backward()
    opt.step()

    # Smoothed losses; dividing by (1 - alpha^(i+1)) corrects the bias caused
    # by starting the averages at zero.
    meta_l = smoothing_alpha * meta_l + (1 - smoothing_alpha) * l_g_meta.item()
    meta_losses_clean.append(meta_l / (1 - smoothing_alpha ** (i + 1)))

    net_l = smoothing_alpha * net_l + (1 - smoothing_alpha) * l_f.item()
    net_losses.append(net_l / (1 - smoothing_alpha ** (i + 1)))

    if i % plot_step == 0:
        net.eval()

        acc = []
        # Evaluation needs no gradients, so skip building autograd graphs.
        with torch.no_grad():
            for test_img, test_label in test_loader:
                test_img = to_var(test_img, requires_grad=False)
                test_label = to_var(test_label, requires_grad=False)

                output = net(test_img)
                # Threshold the sigmoid of the logit at 0.5 to get a 0/1 prediction.
                predicted = (torch.sigmoid(output) > 0.5).int()

                acc.append((predicted == test_label.int()).float())

        # Convert to a Python float so NumPy never receives a (possibly CUDA) tensor.
        accuracy = torch.cat(acc, dim=0).mean().item()
        accuracy_log.append(np.array([i, accuracy])[None])

        IPython.display.clear_output()
        fig, axes = plt.subplots(1, 2, figsize=(13, 5))
        ax1, ax2 = axes.ravel()

        ax1.plot(meta_losses_clean, label='meta_losses_clean')
        ax1.plot(net_losses, label='net_losses')
        ax1.set_ylabel("Losses")
        ax1.set_xlabel("Iteration")
        ax1.legend()

        # Rows of acc_log are (iteration, accuracy) pairs.
        acc_log = np.concatenate(accuracy_log, axis=0)
        ax2.plot(acc_log[:, 0], acc_log[:, 1])
        ax2.set_ylabel('Accuracy')
        ax2.set_xlabel('Iteration')
        plt.show()
            