# 1. Imports

In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# 2. Utilities

## 2.1 Metric Monitor
Keeps a running average of the values added to each named metric. Useful for tracking loss and accuracy over batches and epochs.

In [None]:
from collections import defaultdict

class MetricMonitor:
    """Accumulate named scalar metrics and expose their running averages.

    Every metric is stored in ``self.metrics`` as a dict holding the running
    sum (``"val"``), the number of updates (``"count"``) and the current mean
    (``"avg"``).
    """

    def __init__(self, float_precision=4):
        # Number of decimals printed for each average by ``__str__``.
        self.float_precision = float_precision
        self.reset()

    def reset(self):
        """Discard every metric recorded so far."""
        # Unknown metric names start from an all-zero record on first access.
        self.metrics = defaultdict(lambda: {"val": 0, "count": 0, "avg": 0})

    def update(self, metric_name, val):
        """Add ``val`` to ``metric_name`` and refresh its running average."""
        record = self.metrics[metric_name]

        record["val"] += val
        record["count"] += 1
        record["avg"] = record["val"] / record["count"]

    def __str__(self):
        """Render all metrics as ``"name: avg | name: avg"``."""
        digits = self.float_precision
        rendered = (
            f"{name}: {record['avg']:.{digits}f}"
            for name, record in self.metrics.items()
        )
        return " | ".join(rendered)

## 2.2 Early Stopping
Early stopping is a form of regularization used to avoid overfitting on the training dataset. It keeps track of the validation loss; if the loss stops decreasing for several epochs in a row, training stops. The ```EarlyStopping``` class creates an object that tracks the validation loss and saves a checkpoint of the model each time the validation loss decreases. The ```patience``` argument sets how many epochs to wait after the last improvement of the validation loss before breaking the training loop.

In [None]:
class EarlyStopping:
    """Stop training when the validation loss has not improved for a while.

    Call the instance once per epoch with the validation loss and the model.
    The model's ``state_dict`` is saved to ``path`` on every improvement, and
    ``early_stop`` becomes ``True`` after ``patience`` consecutive calls
    without improvement.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Source:
            https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # ``np.Inf`` was removed in NumPy 2.0; ``np.inf`` works on every version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss and update the stopping state.

        Args:
            val_loss (float): Validation loss of the current epoch.
            model (torch.nn.Module): Model whose weights are checkpointed
                when the loss improves.
        """
        # Work with a score where "higher is better".
        score = -val_loss

        if self.best_score is None:
            # First call: always record the score and checkpoint the model.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Not an improvement: one more epoch of patience is used up.
            self.counter += 1
            self.trace_func(f'> early stopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

## 2.3 Attention
* Attention
* Gated Attention
* MIL pool (mean/max)

In [None]:
class Attention(nn.Module):
    """Attention-based multiple-instance pooling on top of a feature extractor.

    ``model`` must expose ``_to_linear``, the width its output is reshaped to
    per instance. Instance features are projected to ``L`` dimensions, scored
    by a small tanh network, softmax-normalised over the instances of the bag
    and used as weights to build one bag embedding, which a sigmoid classifier
    turns into a bag probability.
    """

    def __init__(self, model):
        super().__init__()
        self.name = 'Attention'
        self.L = 500  # width of the instance embedding
        self.D = 128  # hidden width of the attention scorer
        self.K = 1    # number of attention heads

        self.feature_extractor = model
        self._to_linear = model._to_linear

        # Instance embedding.
        self.fc = nn.Sequential(
            nn.Linear(self._to_linear, self.L),
            nn.ReLU(),
        )

        # Attention scorer producing K raw scores per instance.
        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K),
        )

        # Bag-level classifier.
        self.classifier = nn.Sequential(
            nn.Linear(self.L * self.K, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``(Y_prob, Y_hat)``: bag probability and its 0.5-thresholded label."""
        instances = x.squeeze(0)  # drop the leading dimension when it has size 1

        features = self.feature_extractor(instances).view(-1, self._to_linear)
        H = self.fc(features)  # [b x L]

        scores = self.attention(H).t()  # [b x K] -> [K x b]
        A = F.softmax(scores, dim=1)  # normalise over the b instances

        M = torch.mm(A, H)  # [K x L]

        Y_prob = self.classifier(M)
        Y_hat = (Y_prob >= 0.5).float()
        return Y_prob, Y_hat

    # AUXILIARY METHODS
    def calculate_classification_error(self, X, Y):
        """Return ``(error, Y_hat)`` where ``error`` is 1 minus the match rate with ``Y``."""
        target = Y.float()
        _, Y_hat = self.forward(X)
        match_rate = Y_hat.eq(target).cpu().float().mean()
        error = 1. - match_rate.data

        return error, Y_hat

    def calculate_objective(self, X, Y):
        """Return the negative Bernoulli log-likelihood of ``Y`` for bag ``X``."""
        target = Y.float()
        Y_prob, _ = self.forward(X)
        # Keep probabilities away from 0 and 1 so both logarithms stay finite.
        Y_prob = torch.clamp(Y_prob, min=1e-5, max=1. - 1e-5)
        log_likelihood = target * torch.log(Y_prob) + (1. - target) * torch.log(1. - Y_prob)
        return -1. * log_likelihood
    
    
class GatedAttention(nn.Module):
    """Gated-attention multiple-instance pooling on top of a feature extractor.

    ``model`` must expose ``_to_linear``, the width its output is reshaped to
    per instance. Works like plain attention pooling, except that the raw
    attention score of an instance is computed from the element-wise product
    of a tanh branch and a sigmoid (gate) branch.
    """

    def __init__(self, model):
        super().__init__()
        self.name = 'Gated Attention'
        self.L = 500  # width of the instance embedding
        self.D = 128  # hidden width of both attention branches
        self.K = 1    # number of attention heads

        self.feature_extractor = model
        self._to_linear = model._to_linear

        # Instance embedding.
        self.fc = nn.Sequential(
            nn.Linear(self._to_linear, self.L),
            nn.ReLU(),
        )

        # Tanh branch of the gated attention.
        self.attention_V = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
        )

        # Sigmoid (gate) branch of the gated attention.
        self.attention_U = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Sigmoid(),
        )

        # Maps the gated hidden vector to K raw scores per instance.
        self.attention_weights = nn.Linear(self.D, self.K)

        # Bag-level classifier.
        self.classifier = nn.Sequential(
            nn.Linear(self.L * self.K, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``(Y_prob, Y_hat)``: bag probability and its 0.5-thresholded label."""
        instances = x.squeeze(0)  # drop the leading dimension when it has size 1

        features = self.feature_extractor(instances).view(-1, self._to_linear)
        H = self.fc(features)  # [b x L]

        tanh_branch = self.attention_V(H)  # [b x D]
        gate_branch = self.attention_U(H)  # [b x D]
        scores = self.attention_weights(tanh_branch * gate_branch)  # [b x K]
        A = F.softmax(scores.t(), dim=1)  # [K x b], normalised over the b instances

        M = torch.mm(A, H)  # [K x L]

        Y_prob = self.classifier(M)
        Y_hat = (Y_prob >= 0.5).float()
        return Y_prob, Y_hat

    # AUXILIARY METHODS
    def calculate_classification_error(self, X, Y):
        """Return ``(error, Y_hat)`` where ``error`` is 1 minus the match rate with ``Y``."""
        target = Y.float()
        _, Y_hat = self.forward(X)
        match_rate = Y_hat.eq(target).cpu().float().mean()
        error = 1. - match_rate.data

        return error, Y_hat

    def calculate_objective(self, X, Y):
        """Return the negative Bernoulli log-likelihood of ``Y`` for bag ``X``."""
        target = Y.float()
        Y_prob, _ = self.forward(X)
        # Keep probabilities away from 0 and 1 so both logarithms stay finite.
        Y_prob = torch.clamp(Y_prob, min=1e-5, max=1. - 1e-5)
        log_likelihood = target * torch.log(Y_prob) + (1. - target) * torch.log(1. - Y_prob)

        return -1. * log_likelihood
    
    
class MIL_pool(nn.Module):
    """Multiple-instance pooling with a fixed (mean or max) aggregation.

    ``model`` must expose ``_to_linear``, the width its output is reshaped to
    per instance. Instance embeddings are reduced over the instance dimension
    with ``operator`` and the pooled vector is classified by a sigmoid layer.

    Raises:
        NotImplementedError: if ``operator`` is neither ``'mean'`` nor ``'max'``.
    """

    def __init__(self, model, operator='mean'):
        super().__init__()
        self.L = 500  # width of the instance embedding
        if operator not in ('mean', 'max'):
            raise NotImplementedError('Operator not supported: {}'.format(operator))
        self.operator = operator

        self.name = 'MIL pool ' + self.operator
        self.feature_extractor = model
        self._to_linear = model._to_linear

        # Instance embedding.
        self.fc = nn.Sequential(
            nn.Linear(self._to_linear, self.L),
            nn.ReLU(),
        )

        # Bag-level classifier.
        self.classifier = nn.Sequential(
            nn.Linear(self.L, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``(Y_prob, Y_hat)``: bag probability and its 0.5-thresholded label."""
        instances = x.squeeze(0)  # drop the leading dimension when it has size 1

        # Per-instance embedding.
        features = self.feature_extractor(instances).view(-1, self._to_linear)
        H = self.fc(features)  # [b x L]

        # Permutation-invariant aggregation over the b instances.
        M = H.mean(dim=0) if self.operator == 'mean' else H.amax(dim=0)

        # Bag-level prediction.
        Y_prob = self.classifier(M)
        Y_hat = (Y_prob >= 0.5).float()
        return Y_prob, Y_hat

    # AUXILIARY METHODS
    def calculate_classification_error(self, X, Y):
        """Return ``(error, Y_hat)`` where ``error`` is 1 minus the match rate with ``Y``."""
        target = Y.float()
        _, Y_hat = self.forward(X)
        match_rate = Y_hat.eq(target).cpu().float().mean()
        error = 1. - match_rate.data

        return error, Y_hat

    def calculate_objective(self, X, Y):
        """Return the negative Bernoulli log-likelihood of ``Y`` for bag ``X``."""
        target = Y.float()
        Y_prob, _ = self.forward(X)
        # Keep probabilities away from 0 and 1 so both logarithms stay finite.
        Y_prob = torch.clamp(Y_prob, min=1e-5, max=1. - 1e-5)
        log_likelihood = target * torch.log(Y_prob) + (1. - target) * torch.log(1. - Y_prob)
        return -1. * log_likelihood

## 2.4 Bag Dataset

In [None]:
class BagDataset(torch.utils.data.Dataset):
    """Dataset of randomly sampled bags of instances for multiple-instance learning.

    All ``(data, labels)`` batches produced by ``loader`` are gathered once,
    then ``num_bag`` bags are built by drawing instance indices (with
    replacement) from ``[0, dataset_length)``. An instance is positive when
    its label equals ``target_number``; a bag is positive when it contains at
    least one positive instance.

    Args:
        loader: Iterable yielding ``(data, labels)`` tensor pairs.
        dataset_length (int): Number of instances indices are drawn from
            (e.g. 60000 for the MNIST training split).
        target_number (int): Label value that marks an instance as positive.
        mean_bag_length (float): Mean of the normal draw for the bag size.
        var_bag_length (float): Passed as the ``scale`` (standard deviation)
            of the normal draw for the bag size, despite the name.
        num_bag (int): Number of bags to create.
        seed (int): Seed of the private ``numpy.random.RandomState``.

    Raises:
        ValueError: if ``loader`` yields nothing, or ``dataset_length``
            exceeds the number of instances the loader provided.
    """

    def __init__(self, loader, dataset_length, target_number=9, mean_bag_length=10, var_bag_length=2, num_bag=250, seed=1):
        self.target_number = target_number
        self.mean_bag_length = mean_bag_length
        self.var_bag_length = var_bag_length
        self.num_bag = num_bag
        self.loader = loader
        self.r = np.random.RandomState(seed)
        self.dataset_length = dataset_length  # 60.000 for train MNIST

        self.bag_list, self.labels_list = self._create_bags()

    def _load_all(self):
        """Return every image and label yielded by the loader as two tensors."""
        batches = list(self.loader)
        if not batches:
            raise ValueError('loader did not yield any batch')
        if len(batches) == 1:
            # Common case (one full-size batch): avoid copying the data.
            return batches[0][0], batches[0][1]
        all_imgs = torch.cat([data for data, _ in batches], dim=0)
        all_labels = torch.cat([labels for _, labels in batches], dim=0)
        return all_imgs, all_labels

    def _create_bags(self):
        """Sample ``num_bag`` bags and their per-instance boolean labels."""
        all_imgs, all_labels = self._load_all()
        if self.dataset_length > len(all_labels):
            raise ValueError(
                'dataset_length ({}) exceeds the {} instances provided by the loader'.format(
                    self.dataset_length, len(all_labels)))

        bags_list = []
        labels_list = []

        for _ in range(self.num_bag):
            # ``np.int`` was removed in NumPy 1.24; take the single drawn value
            # and truncate it with the builtin ``int``. A bag has at least one instance.
            drawn = self.r.normal(self.mean_bag_length, self.var_bag_length, 1)[0]
            bag_length = max(1, int(drawn))

            indices = torch.as_tensor(
                self.r.randint(0, self.dataset_length, bag_length), dtype=torch.long)

            bags_list.append(all_imgs[indices])
            labels_list.append(all_labels[indices] == self.target_number)

        return bags_list, labels_list

    def __len__(self):
        return len(self.labels_list)

    def __getitem__(self, index):
        """Return ``(bag, [bag_label, instance_labels])`` for bag ``index``."""
        instance_labels = self.labels_list[index]
        # The bag is positive as soon as one of its instances is positive.
        label = [instance_labels.any(), instance_labels]
        return self.bag_list[index], label

## 2.5 Calculate Accuracy
Accuracy is one metric for evaluating classification models. Informally, accuracy is the fraction of predictions our model got right.

In [None]:
def calculate_accuracy(output, target):
    """Return the fraction of samples whose "is class 1" decision matches the target.

    The prediction of a sample is the index of the largest value along
    dimension 1 of ``output``; both prediction and ``target`` are reduced to
    the boolean "equals 1" before being compared.
    """
    _, predicted_class = torch.max(output.data, dim=1, keepdim=True)
    predicted_positive = torch.flatten(predicted_class == 1.0)
    target_positive = torch.flatten(target == 1.0)
    num_correct = (target_positive == predicted_positive).sum(dim=0)
    return torch.true_divide(num_correct, predicted_positive.size(0)).item()

# 3. Method

## 3.1 Model
A simple Convolutional Neural Network (CNN) for image classification. Input consists of 28x28 grayscale images. Only the feature extractor is defined here, since the attention (or pooling) module built on top of it performs the classification. The specific CNN aims to demonstrate the layers commonly found in a CNN (convolutional layer, pooling layer; the fully-connected layers live in the attention module), regularization techniques (batch normalization, dropout), and activation functions (here: ReLU).

In [None]:
class Model(nn.Module):
    """Convolutional feature extractor made of three identical-style stages.

    Each stage is Conv(3x3, padding 1) -> BatchNorm -> ReLU -> MaxPool(2x2,
    stride 2) -> Dropout(0.2). The flattened output width is stored in
    ``_to_linear``.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 32 channels, 28x28 -> 14x14
        self.layer1 = self._conv_stage(1, 32)
        # Stage 2: 32 -> 64 channels, 14x14 -> 7x7
        self.layer2 = self._conv_stage(32, 64)
        # Stage 3: 64 -> 128 channels, 7x7 -> 4x4 (the pooling is padded by 1)
        self.layer3 = self._conv_stage(64, 128, pool_padding=1)
        self._to_linear = 4 * 4 * 128

    @staticmethod
    def _conv_stage(in_channels, out_channels, pool_padding=0):
        """Build one Conv -> BatchNorm -> ReLU -> MaxPool -> Dropout stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=pool_padding),
            nn.Dropout(p=0.2),
        )

    def forward(self, x):
        """Run the three stages and flatten everything but the first dimension."""
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
        return x.view(x.size(0), -1)  # one flat feature vector per input

## 3.2 Training
Main training loop of **model** for a number of batches over an epoch, as it is defined in [PyTorch](https://pytorch.org/). If CUDA is available, training will take place on the GPU. The ```MetricMonitor``` class is used to keep track of loss and accuracy. Returns the loss and accuracy of the epoch.

In [None]:
def training(epoch, model, train_loader, optimizer, criterion):
    """Train ``model`` for one epoch and return its average loss and accuracy.

    Args:
        epoch (int): Epoch number, used only in the printed summary.
        model: Module providing ``calculate_objective(X, Y)`` and
            ``calculate_classification_error(X, Y)``.
        train_loader: Iterable of ``(data, labels)`` where ``labels[0]`` is
            the bag-level label.
        optimizer: Optimizer stepping the parameters of ``model``.
        criterion: Unused; kept so the signature stays unchanged (the loss
            comes from ``model.calculate_objective``).

    Returns:
        tuple: ``(average loss, average accuracy)`` over the epoch, both as
        Python floats.
    """
    metric_monitor = MetricMonitor()
    model.train()

    # Feed the data to wherever the model lives; fall back to the CUDA check
    # for a model without parameters.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for data, labels in train_loader:
        bag_label = labels[0].to(device)
        data = data.to(device).float()

        loss = model.calculate_objective(data, bag_label)
        # The error is only a metric: evaluate it (with the pre-update weights)
        # without recording a second autograd graph.
        with torch.no_grad():
            error, _ = model.calculate_classification_error(data, bag_label)

        # Store plain floats so the averages are numbers, not tensors.
        metric_monitor.update("Loss", loss.item())
        metric_monitor.update("Accuracy", float(1 - error))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print("[Epoch: {epoch:03d}] Train      | {metric_monitor}".format(epoch=epoch, metric_monitor=metric_monitor))
    return metric_monitor.metrics['Loss']['avg'], metric_monitor.metrics['Accuracy']['avg']

## 3.3 Validation
Main validation loop of **model** for a number of batches over an epoch, as it is defined in [PyTorch](https://pytorch.org/). If CUDA is available, validation will take place on the GPU. The ```MetricMonitor``` class is used to keep track of loss and accuracy. Returns the loss and accuracy of the epoch.

In [None]:
def validation(epoch, model, valid_loader, criterion):
    """Evaluate ``model`` for one epoch and return its average loss and accuracy.

    Args:
        epoch (int): Epoch number, used only in the printed summary.
        model: Module providing ``calculate_objective(X, Y)`` and
            ``calculate_classification_error(X, Y)``.
        valid_loader: Iterable of ``(data, labels)`` where ``labels[0]`` is
            the bag-level label.
        criterion: Unused; kept so the signature stays unchanged (the loss
            comes from ``model.calculate_objective``).

    Returns:
        tuple: ``(average loss, average accuracy)`` over the epoch, both as
        Python floats.
    """
    metric_monitor = MetricMonitor()
    model.eval()

    # Feed the data to wherever the model lives; fall back to the CUDA check
    # for a model without parameters.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Nothing is optimised here, so gradients are never needed.
    with torch.no_grad():
        for data, labels in valid_loader:
            bag_label = labels[0].to(device)
            data = data.to(device).float()

            loss = model.calculate_objective(data, bag_label)
            error, _ = model.calculate_classification_error(data, bag_label)

            # Store plain floats so the averages are numbers, not tensors.
            metric_monitor.update("Loss", loss.item())
            metric_monitor.update("Accuracy", float(1 - error))

    print("[Epoch: {epoch:03d}] Validation | {metric_monitor}".format(epoch=epoch, metric_monitor=metric_monitor))
    return metric_monitor.metrics['Loss']['avg'], metric_monitor.metrics['Accuracy']['avg']

# 4. Main

In [None]:
def _plot_curves(train_values, valid_values, quantity):
    """Plot the per-epoch training and validation curves of ``quantity``."""
    plt.plot(range(1, len(train_values) + 1), train_values,
             color='b', label='training ' + quantity)
    plt.plot(range(1, len(valid_values) + 1), valid_values,
             color='r', linestyle='dashed', label='validation ' + quantity)
    plt.legend()
    plt.ylabel(quantity)
    plt.xlabel('epochs')
    plt.title(quantity.capitalize())
    plt.show()


def main():
    """Train and validate a MIL model on bags of MNIST digits, then plot the curves.

    Builds the model selected by ``attention_type``, creates bag datasets from
    the MNIST train/test splits (a bag is positive when it contains the digit
    9), runs up to ``num_epochs`` epochs with optional learning-rate scheduling
    and early stopping, and finally plots loss and accuracy per epoch.

    Raises:
        NotImplementedError: if ``attention_type`` names no known model.
    """
    num_epochs = 100
    use_early_stopping = True
    use_scheduler = True
    attention_type = 'mil_pool_max'  # choose among attention, gated_attention, mil_pool_mean, mil_pool_max

    # Factories are lazy so only the selected model is instantiated.
    model_builders = {
        'attention': lambda: Attention(Model()),
        'gated_attention': lambda: GatedAttention(Model()),
        'mil_pool_mean': lambda: MIL_pool(Model(), 'mean'),
        'mil_pool_max': lambda: MIL_pool(Model(), 'max'),
    }
    if attention_type not in model_builders:
        raise NotImplementedError('Attention mechanism is not implemented or does not exist')
    model = model_builders[attention_type]()
    if torch.cuda.is_available():
        model = model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-3)
    # Passed to training()/validation() to satisfy their signatures.
    criterion = torch.nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.9)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    train_set = datasets.MNIST('./data', download=True, train=True, transform=transform)
    valid_set = datasets.MNIST('./data', download=True, train=False, transform=transform)

    num_train = len(train_set)
    num_valid = len(valid_set)

    # One batch holding the whole split: BagDataset samples its bags from it.
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=num_train, shuffle=False)
    valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=num_valid, shuffle=False)

    train_bags = BagDataset(
        loader=train_loader,
        dataset_length=num_train,
        target_number=9,
        mean_bag_length=10,
        var_bag_length=2,
        num_bag=100,
        seed=1,
    )
    valid_bags = BagDataset(
        loader=valid_loader,
        dataset_length=num_valid,
        target_number=9,
        mean_bag_length=10,
        var_bag_length=2,
        num_bag=250,
        seed=1,
    )
    # One bag per step: bags have different numbers of instances.
    train_loader_bags = torch.utils.data.DataLoader(train_bags, batch_size=1, shuffle=True)
    valid_loader_bags = torch.utils.data.DataLoader(valid_bags, batch_size=1, shuffle=False)

    train_losses, train_accuracies = [], []
    valid_losses, valid_accuracies = [], []

    early_stopping = None
    if use_early_stopping:
        early_stopping = EarlyStopping(patience=30, verbose=False, delta=1e-4)

    for epoch in range(1, num_epochs + 1):
        train_loss, train_accuracy = training(epoch, model, train_loader_bags, optimizer, criterion)
        valid_loss, valid_accuracy = validation(epoch, model, valid_loader_bags, criterion)

        if use_scheduler:
            scheduler.step()

        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)
        valid_losses.append(valid_loss)
        valid_accuracies.append(valid_accuracy)

        if early_stopping is not None:
            early_stopping(valid_loss, model)

            if early_stopping.early_stop:
                # The checkpoint written by EarlyStopping is not reloaded here.
                print('Early stopping at epoch', epoch)
                break

    _plot_curves(train_losses, valid_losses, 'loss')
    # The accuracy figure previously carried a 'loss' y-axis label.
    _plot_curves(train_accuracies, valid_accuracies, 'accuracy')

In [None]:
# Run the whole experiment (data download, training, validation, plots).
main()