In [None]:
#default_exp information_regularization

In [None]:
#export
from functools import partial

import pickle
import numpy as np
import torch

import matplotlib.pyplot as plt

from torch import nn
import torch.nn.functional as F

import torch.optim as optim
import torchvision
from torchvision import transforms

import ignite
from ignite.engine import create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Loss, Accuracy, Fbeta
from ignite.contrib.handlers import FastaiLRFinder, ProgressBar
import tqdm

In [None]:
#export


def weights_init(m):
    """Initialise a ``nn.Conv2d`` module in place; leave any other module untouched.

    Written to be passed to ``nn.Module.apply``, which calls it on every submodule.

    Args:
        m (nn.Module): the module to (maybe) initialise.
    """
    if not isinstance(m, nn.Conv2d):
        return
    torch.nn.init.kaiming_normal_(m.weight)
    # Conv2d(bias=False) stores ``bias`` as None, which zeros_ cannot handle.
    if m.bias is not None:
        torch.nn.init.zeros_(m.bias)

# Information Theory for Neural Network Regularization

We are going to use the CIFAR-10 dataset.

## CIFAR dataset

In [None]:
!ls data

In [None]:
#export
# Ten CIFAR-10 label names (presumably in label-index order — TODO confirm).
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Pair of 3-element arrays; presumably per-channel (mean, std) — not referenced in the visible code.
stats = (np.array([ 0.4914 ,  0.48216,  0.44653]), np.array([ 0.24703,  0.24349,  0.26159]))

In [None]:
#export
# Image side length; not referenced in the visible code (CIFAR images are presumably 32x32 — confirm).
img_size = 32 
# Mini-batch size handed to get_cifar_dataloaders when the loaders are built.
batch_size = 100 
# 2x3 tensor; presumably per-channel (mean, std) pairs — not referenced in the visible code.
normalization_values = torch.tensor(((0.4914, 0.4822, 0.4465), (1, 1, 1)))

## Whitening

In [None]:
#export

# Pass the path straight to np.load so NumPy opens *and closes* the file itself;
# wrapping it in a bare open(...) left the file handle dangling.
cifar_whitening_matrix = np.load('data/cifar_Z.npy').astype('float32')
# Flattened to 1-D so it broadcasts against the flattened images in BatchLinearTransformation.
cifar_mean = np.load('data/cifar_mean.npy').reshape(-1).astype('float32')


In [None]:
#export
class BatchLinearTransformation:
    """Apply one fixed affine map to every image of a batch.

    Each image is flattened to a row vector, shifted by ``transformation_mean``
    and right-multiplied by ``transformation_matrix``; the result is reshaped
    back to the input's shape.

    Args:
        transformation_matrix (Tensor): square matrix of size [D x D], where
            D must equal C * H * W of the images passed to ``__call__``.
        transformation_mean (Tensor): value subtracted from the flattened batch
            before the multiplication (must broadcast against [N x D]).

    Raises:
        ValueError: if ``transformation_matrix`` is not square.
    """

    def __init__(self, transformation_matrix, transformation_mean):
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError("transformation_matrix should be square. Got " +
                             "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))
        self.transformation_matrix = transformation_matrix
        self.transformation_mean = transformation_mean

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (N, C, H, W) to be whitened.
        Returns:
            Tensor: Transformed image, same shape as ``tensor``.
        Raises:
            ValueError: if C * H * W does not match the matrix size.
        """
        if tensor.size(1) * tensor.size(2) * tensor.size(3) != self.transformation_matrix.size(0):
            raise ValueError("tensor and transformation matrix have incompatible shape." +
                             "[{} x {} x {}] != ".format(*tensor[0].size()) +
                             "{}".format(self.transformation_matrix.size(0)))
        batch = tensor.size(0)

        # reshape (rather than view) so non-contiguous batches, e.g. permuted or
        # sliced tensors, are accepted instead of raising a RuntimeError.
        flat_tensor = tensor.reshape(batch, -1)
        transformed_tensor = torch.mm(flat_tensor - self.transformation_mean, self.transformation_matrix)

        # torch.mm returns a fresh contiguous tensor, so view is safe here.
        return transformed_tensor.view(tensor.size())

In [None]:
#export
# Module-level whitening transform read by process_function and evaluate_function;
# the cell is exported so those exported functions can resolve the name.
# Both tensors are placed on the GPU, so a CUDA device is required here.
whitening_transform = BatchLinearTransformation(
    torch.tensor(cifar_whitening_matrix).cuda(),
    torch.tensor(cifar_mean).cuda()
)

In [None]:
#export


def get_cifar_dataloaders(batch_size, num_workers=10):
    """Build the CIFAR-10 train and test data loaders.

    The datasets are stored under ``data`` and downloaded on first use. Images
    are only converted with ``transforms.ToTensor()``; no normalisation or
    augmentation is applied here.

    Args:
        batch_size (int): mini-batch size for both loaders.
        num_workers (int): number of loader worker processes for both loaders.

    Returns:
        tuple: ``(train_dl, test_dl)``; only the training loader shuffles.
    """
    datasets = {
        subset: torchvision.datasets.CIFAR10(
            root='data',
            train=subset == 'train',
            transform=transforms.Compose([
                    transforms.ToTensor()
                ]
            ),
            download=True
        ) for subset in ('train', 'test')
    }
    train_dl = torch.utils.data.DataLoader(
        datasets['train'],
        batch_size,
        num_workers=num_workers,
        shuffle=True
    )
    test_dl = torch.utils.data.DataLoader(
        datasets['test'],
        batch_size,
        num_workers=num_workers,
        shuffle=False
    )
    return train_dl, test_dl
    

In [None]:
!ls data/cifar10

In [None]:
#export

# Built once when the cell/module runs; get_cifar_dataloaders uses download=True,
# so the first execution may fetch CIFAR-10 into ./data.
cifar_dl_train, cifar_dl_test = get_cifar_dataloaders(batch_size)

In [None]:
#export


def to_model_dtype(model, x):
    """Move ``x`` to the device / half precision of ``model``'s first parameter.

    Args:
        model (nn.Module): module whose first parameter is used as reference.
        x (Tensor): tensor to convert.

    Returns:
        Tensor: ``x`` on the reference parameter's CUDA device (if it lives on
        one) and cast to float16 (if the parameter is float16). ``x`` is returned
        unchanged when ``model`` has no parameters.
    """
    # A default of None avoids StopIteration on parameterless modules.
    reference = next(model.parameters(), None)
    if reference is None:
        return x
    if reference.is_cuda:
        # Target the parameter's own device rather than the current default GPU.
        x = x.to(reference.device)
    if reference.dtype is torch.float16:
        x = x.half()
    return x
  

In [None]:
# Tiny all-ones float32 batch of 3x32x32 images used for the smoke tests below.
sample_mini_batch_size = 2
x = torch.ones((sample_mini_batch_size, 3, 32, 32), dtype=torch.float32)

In [None]:
#export


class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, input):
        n_rows = input.size(0)
        flat = input.view(n_rows, -1)
        return flat


class ConvDropoutBlock(nn.Module):
    """Convolution combined with one of three regularisation modes.

    * ``use_information_dropout=True``: the convolution is wrapped in
      ``InfoDropout`` (which also receives ``activation``) and the block
      returns the penalty that layer computes.
    * ``dropout_rate`` given: convolution followed by ``nn.Dropout2d``.
    * neither: the plain convolution.

    In the last two modes ``activation`` is not applied and the returned
    penalty is a scalar zero.

    Args:
        n_features_in (int): input channels of the convolution.
        n_features_out (int): output channels of the convolution.
        kernel_size (int): convolution kernel size.
        activation (nn.Module): activation forwarded to ``InfoDropout``.
        padding (int): convolution padding.
        stride (int): convolution stride.
        dropout_rate (float or None): ``Dropout2d`` probability; ignored when
            ``use_information_dropout`` is true.
        use_information_dropout (bool): select the ``InfoDropout`` mode.
    """

    def __init__(
            self,
            n_features_in,
            n_features_out,
            kernel_size,
            activation=nn.ReLU(),
            padding=0,
            stride=1,
            dropout_rate=None,
            use_information_dropout=False):
        super(ConvDropoutBlock, self).__init__()
        self.n_features_in = n_features_in
        self.n_features_out = n_features_out
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.use_information_dropout = bool(use_information_dropout)
        self.use_dropout = (not self.use_information_dropout) and dropout_rate is not None

        conv = nn.Conv2d(n_features_in, n_features_out, kernel_size, padding=padding, stride=stride)
        if self.use_information_dropout:
            self.dropout_layer = InfoDropout(conv, activation)
        else:
            # The convolution is created in the plain mode too; without it the
            # block was an identity that ignored stride and channel counts.
            self.conv = conv
            if self.use_dropout:
                self.dropout_layer = nn.Dropout2d(dropout_rate)

    def forward(self, X):
        """Return ``(features, penalty)``; ``penalty`` is a 0-dim tensor."""
        if self.use_information_dropout:
            X, information_dropout_loss = self.dropout_layer(X)
            return X, information_dropout_loss
        X = self.conv(X)
        if self.use_dropout:
            X = self.dropout_layer(X)
        # Scalar zero on the same device/dtype as the activations, so it can be
        # summed with real penalties without depending on module parameters.
        return X, X.new_zeros(())


class AllConvNet(nn.Module):
    """All-convolutional classifier with optional (information) dropout.

    Three stacks of conv + batch-norm + activation are separated by two
    stride-2 ``ConvDropoutBlock`` layers; an 8x8 average pool, flatten and a
    linear layer produce the class scores. With the default ``kernel_size=3``
    a 32x32 input is reduced to 8x8 before the pooling layer.

    Args:
        n_filters (sequence of int): channel counts of the first stack and of
            the second/third stacks.
        n_classes (int): size of the output layer.
        init_dropout_rate (float): ``Dropout2d`` probability applied to the input.
        dropout_rate (float or None): forwarded to both ``ConvDropoutBlock`` layers.
        use_information_dropout (bool): forwarded to both ``ConvDropoutBlock`` layers.
        kernel_size (int): kernel size of the non-1x1 convolutions.
        activation (nn.Module): activation module shared by all stacks.
        in_channels (int): number of input image channels.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        n_filters=(96, 192),
        n_classes=10,
        init_dropout_rate=0.2,
        dropout_rate=0.5,
        use_information_dropout=False,
        kernel_size=3,
        activation=nn.ReLU(),
        in_channels=3,
        **kwargs
    ):
        super(AllConvNet, self).__init__()
        self.init_dropout_rate = init_dropout_rate
        self.conv1 = torch.nn.Sequential(
            # The first positional argument is the input channel count; it used
            # to receive kernel_size, which only worked because both were 3.
            nn.Conv2d(in_channels, n_filters[0], kernel_size, padding=1),
            nn.BatchNorm2d(n_filters[0]),
            activation,
            nn.Conv2d(n_filters[0], n_filters[0], kernel_size, padding=1),
            nn.BatchNorm2d(n_filters[0]),
            activation,
            nn.Conv2d(n_filters[0], n_filters[0], kernel_size, padding=1),
            nn.BatchNorm2d(n_filters[0]),
            activation,
        )
        self.conv1_dropout = ConvDropoutBlock(n_filters[0], n_filters[0], kernel_size, dropout_rate=dropout_rate, use_information_dropout=use_information_dropout, stride=2)
        self.conv2 = torch.nn.Sequential(
            nn.Conv2d(n_filters[0], n_filters[1], kernel_size, padding=1),
            nn.BatchNorm2d(n_filters[1]),
            activation,
            nn.Conv2d(n_filters[1], n_filters[1], kernel_size, padding=1),
            nn.BatchNorm2d(n_filters[1]),
            activation,
        )
        self.conv2_dropout = ConvDropoutBlock(n_filters[1], n_filters[1], kernel_size, dropout_rate=dropout_rate, use_information_dropout=use_information_dropout, stride=2, padding=1)
        self.conv3 = torch.nn.Sequential(
            nn.Conv2d(n_filters[1], n_filters[1], kernel_size, padding=1),
            nn.BatchNorm2d(n_filters[1]),
            activation,
            nn.Conv2d(n_filters[1], n_filters[1], 1),
            nn.BatchNorm2d(n_filters[1]),
            activation,
            # NOTE(review): the 10 channels here are hard-coded and independent
            # of n_classes; the final Linear maps them to n_classes.
            nn.Conv2d(n_filters[1], 10, 1),
            nn.BatchNorm2d(10),
            activation,
        )
        self.reshape = torch.nn.Sequential(
            nn.AvgPool2d(8),
            Flatten(),
            nn.Linear(10, n_classes)
        )

    def forward(self, X):
        """Return ``(class scores, summed penalty of both dropout blocks)``."""
        # Functional dropout honours self.training; a Dropout2d module built
        # inside forward is always in training mode and kept dropping channels
        # during evaluation.
        X = F.dropout2d(X, self.init_dropout_rate, training=self.training)
        X = self.conv1(X)
        X, info_loss_1 = self.conv1_dropout(X)
        X = self.conv2(X)
        X, info_loss_2 = self.conv2_dropout(X)
        X = self.conv3(X)
        return self.reshape(X), info_loss_1 + info_loss_2

In [None]:
#export


class LossSumWrapper(nn.Module):
    """Combine a base loss with an auxiliary loss supplied by the model.

    Args:
        loss (callable): base loss called as ``loss(prediction, target)``.
        beta (float): weight of the auxiliary loss.
    """

    def __init__(self, loss, beta=0.0):
        super(LossSumWrapper, self).__init__()
        self.loss = loss
        self.beta = beta

    def forward(self, inputs, target, **kwargs):
        """Return ``loss(prediction, target) + beta * aux_loss``.

        Args:
            inputs (tuple): ``(prediction, aux_loss)`` pair.
            target: target passed to the base loss.
            **kwargs: accepted and ignored.
        """
        prediction, aux_loss = inputs
        loss_value = self.loss(prediction, target)
        # The weighted term is the auxiliary loss, not the base loss a second time.
        return loss_value + self.beta * aux_loss

In [None]:
# CPU instance in the non-information-dropout configuration, used by the smoke tests below.
net = AllConvNet(use_information_dropout=False)

In [None]:
net.conv1_dropout.use_information_dropout

In [None]:
next(net.conv1_dropout.parameters()).is_cuda

In [None]:
# Forward pass without autograd; preds is the (scores, auxiliary loss) pair returned by AllConvNet.forward.
with torch.autograd.no_grad():
    preds = net(to_model_dtype(net, x))

In [None]:
preds[1].device

In [None]:
# One row of 10 class scores per sample of the 2-image batch, plus a scalar auxiliary loss.
assert preds[0].numpy().shape == (2, 10)
assert preds[1].numpy().shape == ()

In [None]:
# Cross-entropy combined with the model's auxiliary loss via LossSumWrapper (beta=0.01).
wrapped_loss = LossSumWrapper(torch.nn.CrossEntropyLoss(), beta=0.01)

In [None]:
# Smoke test: every sample is labelled class 1; only the scalar shape of the result is checked.
# NOTE(review): dtype='int' is platform dependent in NumPy; the loss presumably needs int64 targets — confirm.
wrapped_loss_value = wrapped_loss(preds, torch.tensor(np.ones(sample_mini_batch_size, dtype='int')))

assert wrapped_loss_value.numpy().shape == ()

# [Regularizing Neural Networks by Penalizing Confident Output Distributions](https://openreview.net/pdf?id=HyhbYrGYe)


In [None]:
#export


def log_sum_exp(x):
    """Numerically stable ``log(sum(exp(x)))`` over each row of a 2-D tensor.

    Args:
        x (Tensor): tensor of shape (N, C).

    Returns:
        Tensor: shape (N,), one value per row.
    """
    # The hand-rolled version called .sum() with no axis, so the exponentials of
    # the whole batch were added together instead of each row's own.
    return torch.logsumexp(x, dim=1)


def entropy_from_logits(logits, eps=1e-4):
    """Shannon entropy (in nats) of the softmax distribution of each row.

    Args:
        logits (Tensor): unnormalised scores of shape (N, C).
        eps (float): unused; kept so existing callers keep working.

    Returns:
        Tensor: shape (N,), entropy of ``softmax(logits[i])`` for every row i.
    """
    # Normalise over the class axis (dim=1). The earlier code applied softmax
    # to the transposed logits along dim=1, i.e. across the batch.
    log_p = F.log_softmax(logits, dim=1)
    return -(log_p.exp() * log_p).sum(dim=1)


class EntropyPenalizedLogLoss(nn.Module):
    """Cross-entropy with a confidence penalty.

    Computes ``cross_entropy - beta * mean(entropy of the predictions)``, so
    that low-entropy (over-confident) outputs are penalised, as in
    "Regularizing Neural Networks by Penalizing Confident Output Distributions".

    Args:
        beta (float): weight of the entropy term.
    """

    def __init__(self, beta):
        super(EntropyPenalizedLogLoss, self).__init__()
        self.beta = beta

    def forward(self, input, target, **kwargs):
        """
        Args:
            input (Tensor): logits of shape (N, C).
            target (Tensor): class indices of shape (N,).
            **kwargs: accepted and ignored.
        Returns:
            Tensor: scalar loss.
        """
        cross_entropy = F.cross_entropy(input, target)
        # Entropy is subtracted: adding it would reward confident predictions,
        # the opposite of the confidence penalty.
        return cross_entropy - self.beta * entropy_from_logits(input).mean()

In [None]:
# Random float64 logits: 50 rows, 2 columns.
x_t = torch.tensor(np.random.randn(50, 2))

In [None]:
# Shape check only: one entropy value per row of x_t.
assert entropy_from_logits(x_t).numpy().shape == (50,)

In [None]:
# Loss instance and all-ones integer targets matching the 50 rows of x_t.
epll = EntropyPenalizedLogLoss(beta=0.01)
y =torch.tensor(np.ones([50], dtype=int)) 

In [None]:
# The penalised loss must reduce to a scalar.
assert epll(x_t, y).numpy().shape == ()

In [None]:
# Single random 3x32x32 input; not referenced elsewhere in the visible notebook.
in_t = torch.Tensor(
    np.random.rand(1, 3, 32, 32),
)

# [Information Dropout](https://arxiv.org/pdf/1611.01353.pdf)

In [None]:
#export


class InfoDropout(nn.Module):
    """Wrap a convolution with multiplicative log-normal noise and a KL penalty.

    ``forward`` computes ``activation(BatchNorm(wrapped_layer(X)))`` and, in
    training mode, multiplies it by ``exp(alpha * N(0, 1))`` where ``alpha``
    comes from a second convolution on the same input, squashed into
    ``(min_alpha, min_alpha + max_alpha)``.

    Args:
        wrapped_layer (nn.Conv2d): convolution to regularise; its channel
            counts, kernel size, padding and stride are reused for the alpha branch.
        activation (nn.Module): ``nn.ReLU`` or ``nn.Softplus`` instance; it
            also selects the KL penalty.
        max_alpha (float): scale applied to the sigmoid output.
        min_alpha (float): offset added to alpha; also the floor substituted
            for non-positive activations before taking the log.

    Raises:
        ValueError: if ``activation`` is neither ``nn.ReLU`` nor ``nn.Softplus``.
    """

    def __init__(
            self,
            wrapped_layer,
            activation,
            max_alpha=0.7,
            min_alpha=0.001,
        ):
        input_dim = wrapped_layer.in_channels
        output_dim = wrapped_layer.out_channels
        super(InfoDropout, self).__init__()
        self.get_alpha = nn.Sequential(
            nn.Conv2d(
                input_dim,
                output_dim,
                kernel_size=wrapped_layer.kernel_size,
                padding=wrapped_layer.padding,
                stride=wrapped_layer.stride),
            nn.Sigmoid()
        )
        self.layer = nn.Sequential(
            wrapped_layer,
            nn.BatchNorm2d(output_dim),
            activation
        )

        self.max_alpha = max_alpha
        self.min_alpha = min_alpha
        self.kl_loss = self.make_kl_loss(activation)

    def forward(self, X):
        """Return ``(output, kl)`` where ``kl`` is the batch-averaged penalty."""
        X_out = self.layer(X)
        alpha = self.min_alpha + self.max_alpha * self.get_alpha(X)
        # Replace non-positive activations by min_alpha so the log is finite.
        X_out_trunc = torch.where(X_out > 0, X_out, self.min_alpha * torch.ones_like(X_out))

        kl_loss = self.kl_loss(torch.log(X_out_trunc), alpha)
        if self.training:
            # Noise is only drawn when it is actually applied.
            X_out = self.sample_lognormal(alpha) * X_out
        return X_out, kl_loss.mean()

    def sample_lognormal(self, sigma):
        """Draw ``exp(sigma * z)`` with ``z ~ N(0, 1)``, elementwise.

        The sample is created with ``sigma``'s own shape, dtype and device.
        """
        return torch.exp(sigma * torch.randn_like(sigma))

    def make_kl_loss(self, activation):
        """Build the per-sample KL penalty matching ``activation``.

        Returns:
            callable: ``f(mu, alpha)`` returning a tensor of shape (N,).
        """
        if isinstance(activation, nn.Softplus):
            # Prior parameters are registered once, here. Creating them inside
            # the loss function replaced them on every forward pass, so they
            # were reset each time.
            self.mu1 = torch.nn.Parameter(torch.zeros([]))
            self.sigma1 = torch.nn.Parameter(torch.ones([]))

            def _get_kl_loss(mu, sigma):
                sigma1 = self.sigma1
                mu1 = self.mu1
                # KL between N(mu, sigma^2) and N(mu1, sigma1^2).
                kl = 0.5 * ((sigma / sigma1) ** 2 + (mu - mu1)** 2/ sigma1 ** 2 - 1 + 2 * (torch.log(sigma1) - torch.log(sigma)))
                return kl.view(kl.size(0), -1).mean(dim=1)
        elif isinstance(activation, nn.ReLU):
            def _get_kl_loss(mu=None, alpha=None):
                # alpha <= min_alpha + max_alpha, so the penalty is non-negative.
                kl = - torch.log(alpha / (self.max_alpha + self.min_alpha))
                return kl.view(kl.size(0), -1).mean(dim=1)
        else:
            raise ValueError(
                "InfoDropout supports nn.ReLU and nn.Softplus activations, got {}".format(
                    type(activation).__name__))
        return _get_kl_loss
 

In [None]:
# Smoke test: with a padded 3x3 convolution the output keeps the input's shape.
dropout = InfoDropout(nn.Conv2d(3, 3, 3, padding=1), activation=nn.ReLU())
assert dropout(x)[0].shape == x.shape

In [None]:
def convert_to_fp16(model):
    """Cast ``model`` to half precision on the GPU, keeping BatchNorm2d layers in float32.

    The conversion is done in place and the same ``model`` object is returned.
    """
    model.half()
    model.cuda()
    batch_norms = (m for m in model.modules() if isinstance(m, nn.BatchNorm2d))
    for bn in batch_norms:
        bn.float()
    return model

In [None]:
# Scratch GPU instance for the checks below; the training-configuration cell rebinds these names later.
all_conv_net = AllConvNet(use_information_dropout=True)
model = all_conv_net.cuda()
loss = torch.nn.CrossEntropyLoss()#, beta=0.01)

In [None]:
model(x.cuda())[0]

In [None]:
y = torch.ones([2], dtype=torch.long).cuda()

In [None]:
y.shape

In [None]:
y_pred = model(x.cuda())[0]#.shape#, y)

## Ignite training operators

In [None]:
#export


def process_function(engine, batch, model, loss, optimizer, beta):
    """Run one optimisation step on ``batch`` (Ignite process function).

    The inputs are moved to the GPU and whitened with the module-level
    ``whitening_transform``; the optimised objective is
    ``loss(y_pred, y) + beta * kl`` where ``kl`` is the model's second output.

    Returns:
        tuple: ``(y_pred, y, stats)`` where ``stats`` holds the Python floats
        ``'loss'`` (optimised objective), ``'log_loss'`` (plain cross-entropy)
        and ``'kl_loss'``.
    """
    model.train()
    optimizer.zero_grad()

    inputs, targets = batch
    inputs = whitening_transform(inputs.cuda())
    y_pred, kl_term = model(inputs)
    targets = targets.cuda()

    cross_entropy = F.cross_entropy(y_pred, targets)
    objective = loss(y_pred, targets) + beta * kl_term
    objective.backward()
    optimizer.step()

    stats = {
        'loss': objective.item(),
        'log_loss': cross_entropy.item(),
        'kl_loss': kl_term.item(),
    }
    return y_pred, targets, stats


def evaluate_function(engine, batch, model, loss, beta):
    """Evaluate ``model`` on one batch without gradients (Ignite process function).

    The inputs are moved to the GPU and whitened with the module-level
    ``whitening_transform``.

    Returns:
        tuple: ``(y_pred, y, kwargs)`` on the CPU. ``kwargs`` holds 0-dim
        tensors: ``'loss'`` (``loss(y_pred, y) + beta * kl``), ``'kl_loss'``
        and ``'log_loss'`` (plain cross-entropy).
    """
    model.eval()
    with torch.no_grad():
        x, y = batch
        x = whitening_transform(x.cuda())
        y = y.cuda()
        y_pred, kl_loss_value = model(x)
        kl_loss_value = kl_loss_value.cpu()
        log_loss_value = F.cross_entropy(y_pred, y).cpu()
        # Kept as a tensor, like the other two entries. The NumPy scalar
        # produced before is not a tensor, and the Ignite Loss metric that
        # consumes it may call tensor-only methods (version dependent — confirm).
        total_loss_value = beta * kl_loss_value + loss(y_pred, y).cpu()
        y = y.cpu()
        y_pred = y_pred.cpu().float()
        kwargs = {
            'loss': total_loss_value,
            'kl_loss': kl_loss_value,
            'log_loss': log_loss_value
        }
        return y_pred, y, kwargs

In [None]:
#export
import tensorboardX


def print_logs(engine, evaluator, dataloader, mode, history_dict, tb_writer):
    """Evaluate on ``dataloader`` for one epoch, then print and record the metrics.

    Every metric reported by ``evaluator`` is appended to the matching list in
    ``history_dict``; the three losses and the accuracy are also written to
    ``tb_writer`` under ``mode``. Accuracy is printed only when ``mode`` is
    ``'Validation'``.
    """
    evaluator.run(dataloader, max_epochs=1)
    metrics = evaluator.state.metrics
    epoch = engine.state.epoch
    is_validation = mode == 'Validation'

    losses = {name: metrics[name] for name in ('loss', 'kl_loss', 'log_loss')}
    accuracy = metrics['accuracy']

    print(mode + " Results - Epoch {}".format(epoch))
    if is_validation:
        print('Accuracy: {}'.format(accuracy))
    summary = "loss: {:.3f} log loss: {:.3f} kl_loss: {:.3f}".format(
        losses['loss'], losses['log_loss'], losses['kl_loss'])
    print(summary)
    if is_validation:
        print()

    for key, value in metrics.items():
        history_dict[key].append(value)

    tb_writer.add_scalars(mode, losses, epoch)
    tb_writer.add_scalar(mode + "/accuracy", accuracy, epoch)


In [None]:
#export
def run_training_loop(model, loss, epochs, beta=3.0, optimizer=None):
    """Train ``model`` with Ignite, evaluating on both loaders after every epoch.

    Uses the module-level ``cifar_dl_train`` / ``cifar_dl_test`` loaders, logs
    to TensorBoard under ``./logs`` and divides the learning rate at epochs
    80, 120 and 160 (``MultiStepLR`` with its default factor).

    Args:
        model (nn.Module): network returning ``(scores, kl)``.
        loss (callable): base loss called as ``loss(y_pred, y)``.
        epochs (int): number of training epochs.
        beta (float): weight of the KL term in the objective.
        optimizer (torch.optim.Optimizer or None): optimiser to step. When
            omitted, the module-level ``optimizer`` is used, which is what
            earlier callers implicitly relied on.

    Returns:
        tuple: ``(model, evaluator, training_history, validation_history)``.

    Raises:
        ValueError: if no optimiser is given and none exists at module level.
    """
    if optimizer is None:
        optimizer = globals().get('optimizer')
        if optimizer is None:
            raise ValueError("run_training_loop needs an optimizer: pass one or define a module-level `optimizer`.")

    writer = tensorboardX.SummaryWriter('./logs')
    trainer = ignite.engine.Engine(partial(process_function, model=model, loss=loss, optimizer=optimizer, beta=beta))
    evaluator = ignite.engine.Engine(partial(evaluate_function, model=model, loss=loss, beta=beta))
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 120, 160])

    Loss(nn.CrossEntropyLoss(), output_transform=lambda x: [x[0], x[1]]).attach(evaluator, 'log_loss')
    # One Accuracy instance per engine: a metric keeps running counters, and
    # sharing it lets the two engines reset and update each other's state.
    Accuracy(output_transform=lambda x: [x[0], x[1]]).attach(trainer, 'accuracy')
    Accuracy(output_transform=lambda x: [x[0], x[1]]).attach(evaluator, 'accuracy')

    # These two metrics just average values that evaluate_function already computed.
    Loss(lambda *input, **kwargs: kwargs['kl_loss']).attach(evaluator, 'kl_loss')
    Loss(lambda *input, **kwargs: kwargs['loss']).attach(evaluator, 'loss')

    training_history = {'log_loss': [], 'loss': [], 'kl_loss': [], 'accuracy': []}
    validation_history = {'log_loss': [], 'loss': [], 'kl_loss': [], 'accuracy': []}

    trainer.add_event_handler(ignite.engine.Events.EPOCH_COMPLETED, print_logs, evaluator, cifar_dl_train, 'Training', training_history, writer)
    trainer.add_event_handler(ignite.engine.Events.EPOCH_COMPLETED, print_logs, evaluator, cifar_dl_test, 'Validation', validation_history, writer)
    trainer.add_event_handler(ignite.engine.Events.EPOCH_COMPLETED, lambda engine: scheduler.step())

    try:
        trainer.run(cifar_dl_train, max_epochs=epochs)
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()
    return model, evaluator, training_history, validation_history

In [None]:
#export
# --- Optimiser hyper-parameters ---
lr = 0.01
momentum = 0.9
weight_decay = 0.001

# --- Regularisation switches ---
beta = 3.0
use_information_dropout = True
use_entropy_penalization = False
entropy_penalization_beta = 0.01

# --- Schedule ---
n_epochs = 200


all_conv_net = AllConvNet(use_information_dropout=use_information_dropout)
loss = (
    EntropyPenalizedLogLoss(entropy_penalization_beta)
    if use_entropy_penalization
    else torch.nn.CrossEntropyLoss()
)

# Move to the GPU first, then re-initialise every Conv2d through weights_init.
model = all_conv_net.cuda()
model.apply(weights_init)

optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

In [None]:
#export

# Pass the configured beta explicitly. Without it the function's own default
# was used and edits to `beta` in the configuration cell had no effect.
run_training_loop(model, loss, n_epochs, beta=beta)

In [None]:
#export
import dill
# dill is passed as the pickle module, presumably because the model holds
# locally defined functions (InfoDropout.kl_loss) — confirm.
# The context manager closes the file; the bare open(...) left it dangling.
with open('info_dropout_low.pkl', 'wb') as model_file:
    torch.save(model, model_file, dill)