This notebook was inspired by neural network & machine learning labs led by [GMUM](https://gmum.net/).

See also [How does Batch Normalization Help Optimization?](https://gradientscience.org/batchnorm/) and [Chapter 7](https://www.deeplearningbook.org/contents/regularization.html) of the Deep Learning book.

Some utils for today's class (run and hide the cell):

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torchvision
from torchvision.transforms import Compose, Lambda, ToTensor


def load_fashionmnist(train=True, shrinkage=None):
    """Load a FashionMNIST split with every image flattened to a 1-D tensor.

    Args:
        train: load the training split when True, the test split otherwise.
        shrinkage: optional fraction of the split to keep. When truthy, a
            random subset of ``int(len(dataset) * shrinkage)`` samples (drawn
            with ``torch.randperm``) is returned as a ``Subset``; otherwise the
            full dataset is returned.
    """
    to_flat_tensor = Compose([ToTensor(), Lambda(torch.flatten)])
    dataset = torchvision.datasets.FashionMNIST(
        root='.',
        download=True,
        train=train,
        transform=to_flat_tensor,
    )
    if not shrinkage:
        return dataset

    n_total = len(dataset)
    n_keep = int(n_total * shrinkage)
    chosen = torch.randperm(n_total)[:n_keep]
    return torch.utils.data.Subset(dataset, chosen)


class ModelTrainer:
    """Train a classifier and record per-epoch loss and accuracy.

    Args:
        train_dataset: dataset of ``(x, y)`` pairs used for training
            (reshuffled every epoch).
        test_dataset: dataset of ``(x, y)`` pairs used for evaluation.
        batch_size: mini-batch size used by both data loaders.
    """

    def __init__(self, train_dataset, test_dataset, batch_size=128):
        self.batch_size = batch_size
        self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        self.test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def _run_epoch(self, model, loader, loss_fn, optimizer=None):
        """Make one pass over ``loader`` and return ``(mean loss, accuracy)``.

        When ``optimizer`` is given, a gradient step is taken after every batch
        (the metrics then describe the model as it was before each step).
        Both metrics are weighted by the real size of each batch, so a smaller
        final batch is counted correctly. An empty loader yields ``(nan, nan)``.
        """
        total_loss, correct, seen = 0.0, 0, 0
        for x, y in loader:
            x = x.to(self.device)
            y = y.to(self.device)
            if optimizer is not None:
                optimizer.zero_grad()
            output = model(x)
            loss = loss_fn(output, y)
            if optimizer is not None:
                loss.backward()
                optimizer.step()

            n = y.size(0)
            # loss_fn is assumed to return the batch *mean* (the default for
            # cross_entropy), so re-weight by the batch size before averaging.
            total_loss += loss.item() * n
            correct += (torch.argmax(output, dim=1) == y).sum().item()
            seen += n

        if seen == 0:
            return float('nan'), float('nan')
        return total_loss / seen, correct / seen

    def train(self, model, optimizer, loss_fn=torch.nn.functional.cross_entropy, n_epochs=100):
        """Train ``model`` for ``n_epochs`` epochs and return the logs.

        Returns:
            dict with lists ``'train_loss'``, ``'test_loss'``,
            ``'train_accuracy'`` and ``'test_accuracy'`` (one entry per epoch).
            Losses are per-sample means over the whole epoch; the same dict is
            also stored as ``self.logs``.
        """
        self.logs = {'train_loss': [], 'test_loss': [], 'train_accuracy': [], 'test_accuracy': []}
        model = model.to(self.device)
        for _ in range(n_epochs):
            model.train()
            train_loss, train_accuracy = self._run_epoch(model, self.train_loader, loss_fn, optimizer)

            model.eval()
            with torch.no_grad():
                test_loss, test_accuracy = self._run_epoch(model, self.test_loader, loss_fn)

            self.logs['train_loss'].append(train_loss)
            self.logs['train_accuracy'].append(train_accuracy)
            self.logs['test_loss'].append(test_loss)
            self.logs['test_accuracy'].append(test_accuracy)

        return self.logs


def show_results(orientation='horizontal', accuracy_bottom=None, loss_top=None, **histories):
    """Plot accuracy and loss curves for one or more training runs.

    Args:
        orientation: ``'horizontal'`` places the accuracy and loss plots side
            by side; any other value stacks them vertically.
        accuracy_bottom: optional lower y-limit for the accuracy plot.
        loss_top: optional upper y-limit for the loss plot.
        **histories: run name -> logs dict with ``'train_accuracy'``,
            ``'test_accuracy'``, ``'train_loss'`` and ``'test_loss'`` lists
            (the format returned by ``ModelTrainer.train``).

    Train curves are dashed and test curves solid, with one colour per run.
    With a single run, the titles report its best accuracies and minimal losses.
    """
    if orientation == 'horizontal':
        _, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(16, 5))
    else:
        _, (acc_ax, loss_ax) = plt.subplots(2, 1, figsize=(16, 16))

    for i, (name, h) in enumerate(histories.items()):
        color = 'C%s' % i
        acc_ax.plot(h['train_accuracy'], color=color, linestyle='--', label='%s train' % name)
        acc_ax.plot(h['test_accuracy'], color=color, label='%s test' % name)
        loss_ax.plot(h['train_loss'], color=color, linestyle='--', label='%s train' % name)
        loss_ax.plot(h['test_loss'], color=color, label='%s test' % name)

    if len(histories) == 1:
        (h,) = histories.values()
        acc_ax.set_title("Best test accuracy: {:.2f}% (train: {:.2f}%)".format(
            max(h['test_accuracy']) * 100,
            max(h['train_accuracy']) * 100
        ))
        loss_ax.set_title("Minimal train loss: {:.4f} (test: {:.4f})".format(
            min(h['train_loss']),
            min(h['test_loss'])
        ))
    elif histories:
        acc_ax.set_title("Accuracy")
        loss_ax.set_title("Loss")

    for axis, ylabel in ((acc_ax, 'accuracy'), (loss_ax, 'loss')):
        axis.set_xlabel('epochs')
        axis.set_ylabel(ylabel)
        if histories:
            axis.legend()

    # Set limits only after all runs are plotted: set_ylim turns off y
    # autoscaling, so calling it earlier would freeze the other limit at the
    # value autoscaled for the first run alone and clip later runs.
    # `is not None` so that an explicit 0 is honoured.
    if accuracy_bottom is not None:
        acc_ax.set_ylim(bottom=accuracy_bottom)
    if loss_top is not None:
        loss_ax.set_ylim(top=loss_top)

    plt.show()

    
def test_dropout(dropout_cls):

    drop = dropout_cls(0.5)
    drop.train()
    x = torch.randn(10, 30)
    out = drop(x)

    for row in out:
        zeros_in_row = len(torch.where(row == 0.)[0]) 
        assert zeros_in_row > 0 and zeros_in_row < len(row)

    drop_eval = dropout_cls(0.5)
    drop_eval.eval()
    x = torch.randn(10, 30)
    out_eval = drop_eval(x)

    for row in out_eval:
        zeros_in_row = len(torch.where(row == 0.)[0]) 
        assert zeros_in_row == 0

        
def test_bn(bn_cls):
    """Sanity-check a batch-norm module class on (20, 100) uniform inputs.

    Checks that training-mode outputs have mean ~0 and variance ~1, that the
    attributes ``sigma``, ``mu`` and ``beta`` exist and change during
    training, and that after a few SGD steps the eval-mode output on a fixed
    (seeded) standard-normal batch has mean within 0.1 of -0.5.
    Failures raise AssertionError.
    """

    torch.manual_seed(42)
    bn = bn_cls(num_features=100)

    opt = torch.optim.SGD(bn.parameters(), lr=0.1)

    bn.train()
    x = torch.rand(20, 100)
    out = bn(x)

    assert out.mean().abs().item() < 1e-4
    assert abs(out.var().item() - 1) < 1e-1

    assert (bn.sigma != 1).all()
    # NOTE(review): if mu is zero-initialized, `!= 1` passes even when mu is
    # never updated; `!= 0` may have been intended — confirm before changing.
    assert (bn.mu != 1).all()

    loss = 1 - out.mean()
    loss.backward()
    opt.step()

    assert (bn.beta != 0).all()
    
    n_steps = 10

    # NOTE(review): opt.zero_grad() is never called, so gradients accumulate
    # across these steps. The -0.5 target below may be calibrated to that,
    # so re-check the final assertion before adding zero_grad.
    for i in range(n_steps):
        x = torch.rand(20, 100)
        out = bn(x)
        loss = 1 - out.mean()
        loss.backward()
        opt.step()

    # Eval mode should normalize with the running statistics gathered above.
    torch.manual_seed(43)
    test_x = torch.randn(20, 100)
    bn.eval()
    test_out = bn(test_x)

    assert abs(test_out.mean() + 0.5) < 1e-1

# Regularization 
One of the main problems in machine learning is making a model that performs well not only on the training data, but also on new inputs. There are a lot of techniques and strategies designed to reduce test error, even at the expense of training error. Today we'll be talking about some of them. We'll be working (once again) with the FashionMNIST dataset. The cell below loads in the datasets. If the networks train too slowly, you can play with the `shrinkage` parameter, which determines what fraction of the dataset is used.

In [None]:
# Seed first: with shrinkage set, load_fashionmnist draws a random subset
# (torch.randperm), so the seed fixes which samples are kept.
torch.manual_seed(44)

# Keep 1% of the training split and 10% of the test split.
train_dataset = load_fashionmnist(train=True, shrinkage=0.01)
test_dataset = load_fashionmnist(train=False, shrinkage=0.1)

The cell below sets some hyperparameters for all of the models trained in the notebook. They should be set such that for all the models the learning curve flattens. The chosen hyperparameters should work, but the training might be a bit slow.

In [None]:
# Shared hyperparameters for every model trained in this notebook.
n_epochs, learning_rate, batch_size = 300, 0.05, 128

# A single trainer (and hence a single pair of data loaders) is reused by all experiments.
trainer = ModelTrainer(
    train_dataset,
    test_dataset,
    batch_size=batch_size,
)

## Task 1 (0.25p)
Use [`torch.nn.Sequential`](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) to create a simple neural network. It should have two hidden linear layers of size $256$ and ReLU activation functions and an output layer with the linear activation function (i.e. none). This network will serve as a baseline for today.

In [None]:
# Task 1 baseline: two hidden Linear layers of width 256, each followed by
# ReLU, then a linear output layer with no activation.
# NOTE(review): load_fashionmnist flattens each image (Lambda(torch.flatten));
# presumably that gives 28 * 28 = 784 input features and 10 classes — confirm.
model = torch.nn.Sequential(
    ???
)

# Plain SGD with the shared learning rate; the returned logs feed show_results.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
history = trainer.train(model, optimizer, n_epochs=n_epochs)
show_results(model=history)

If the above was defined correctly, you can see that after ~$50$ epochs (if you used the same hyperparameters as suggested above) the test loss starts rising and the test accuracy flattens out. This means that our model has started *overfitting* to the training set. If we want to get the best model out of this architecture, we would need to load in the parameters from the moment when the test accuracy was highest (so-called *early stopping*, perhaps the most popular regularization technique in deep learning). Your task today will be to improve upon these results.

## Parameter Norm Penalties
A simple way of doing non-architecture-dependent regularization is by adding a penalty $\Omega(\theta)$ to the loss function $J(\theta)$:

$$\tilde{J}(\theta) = J(\theta) + \alpha \Omega(\theta),$$

where $\alpha \in [0, \infty)$ is a hyperparameter that weights the contribution of the penalty term $\Omega$.

When using such regularization for neural networks, we typically penalize only the weights and not the biases: leaving the biases unregularized does not introduce much variance, whereas regularizing them can introduce significant underfitting.

$L_1$ and $L_2$ regularization, which you might know from regular machine learning, fall into this family of methods. We will not be talking about them today, but move to methods specific to neural networks.

## Task 2 (0.75p): Dropout
Dropout is a regularization method in which, during training, a randomly chosen subset of unit outputs is ignored (*dropped out*). This prevents complex co-adaptations from arising (e.g. one neuron learning to fix the mistakes of another), which makes the model more robust.

![dropout](figures/dropout.png)
<center>Source: <a href="https://jmlr.org/papers/v15/srivastava14a.html">Dropout: A Simple Way to Prevent Neural Networks from Overfitting</a>.</center>

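Dropout is not applied at test time, but we need to correct for the fact that the network was trained on smaller (partially zeroed) activations. With dropout probability $p$, we therefore scale the outputs at test time by the keep probability $1 - p$. (Equivalently, *inverted dropout* scales the kept outputs by $\frac{1}{1-p}$ during training and leaves them unchanged at test time.)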

In [None]:
class Dropout(torch.nn.Module):
    """Dropout layer for Task 2 (to be implemented).

    In training mode it should randomly zero some of its outputs; in eval mode
    it should zero none of them (test_dropout checks both properties).
    """
    
    def __init__(self, p=0.5):
        super(Dropout, self).__init__()
        # Dropout probability; presumably the probability of dropping a unit,
        # as in torch.nn.Dropout — confirm against the task description.
        self.p = p
        
    def forward(self, x):
        if self.training:
            # hint: use torch.bernoulli
            ???            
        else:
            # Eval mode: nothing is dropped; see the task text for how the
            # outputs should be scaled to match training-time magnitudes.
            ???

In [None]:
# Raises AssertionError if Dropout misbehaves in train or eval mode.
test_dropout(Dropout)

Add dropout with probability $0.5$ to the baseline model after each hidden layer.

In [None]:
# Task 2: the Task 1 baseline with Dropout(0.5) (the class above) inserted
# after each hidden layer.
model = torch.nn.Sequential(
    ???
)

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
dropout_history = trainer.train(model, optimizer, n_epochs=n_epochs)
show_results(model=dropout_history)

### Questions:
1. How does the testing accuracy curve after applying dropout look in comparison to the baseline model? What does that suggest? 
2. How can we think of dropout as ensemble learning (combining many models at once)?

[your answers here]

## Task 3 (1p): Batch Normalization

Batch Normalization was introduced as a technique to reduce *internal covariate shift*. To understand this perspective, recall that training a neural network can be viewed as solving a collection of separate optimization problems -- one for each layer:

![layer based](figures/layerbased.jpg)
<center>Source: <a href="https://gradientscience.org/batchnorm/">How does Batch Normalization Help Optimization?</a></center>

During training, each step involves updating the parameters for all layers simultaneously. This implies that updates to the earlier layers change the input distribution of later layers, hence the optimization problems change at each step (this is the aforementioned internal covariate shift). To fix that, we whiten the input, so the input distribution is the same:

$$\hat{\mathbf{x}_i} = \frac{\mathbf{x}_i - \mathbf{\mu_B}}{\sqrt{\mathbf{\sigma^2_B} + \epsilon}},$$

where $\mathbf{\mu_B}=\frac{1}{m}\sum_{i=1}^m \mathbf{x}_i$ and $\mathbf{\sigma^2_B} = \frac{1}{m} \sum_{i=1}^m (\mathbf{x}_i - \mathbf{\mu_B})^2$ are the batch mean and batch variance respectively (ideally we'd want to do this over the whole training set, but in the context of stochastic gradient methods that would be impractical) and $\epsilon$ is added for numerical stability. To restore the representational power of the network, we modify $\hat{\mathbf{x}_i}$ with learned parameters $\gamma$ and $\beta$, so any mean and variance can be learned. In the end, the batch norm layer looks like this:

$$\mathtt{BN}(\mathbf{x}_i) = \gamma \hat{\mathbf{x}_i} + \beta.$$

During testing, we replace the batch statistics with population statistics computed during training via running means:

$$\mathbf{\overline{\mu}_{new}} =  (1 - \lambda) \mathbf{\overline{\mu}_{old}} + \lambda \mathbf{\mu_B},$$

$$\mathbf{\overline{\sigma^2}_{new}} = (1 - \lambda) \mathbf{\overline{\sigma^2}_{old}} + \lambda \mathbf{\sigma^2_B},$$

where $\lambda$ is the momentum term for the running means.

While the effectiveness of batch normalization is hard to dispute, the proposed mechanism is contested. See [How does Batch Normalization Help Optimization?](https://gradientscience.org/batchnorm/) for more on this topic.

When defining model parameters it can be useful to utilize [`torch.nn.Parameter`](https://pytorch.org/docs/stable/generated/torch.nn.parameter.Parameter.html). When defining non-trainable parameters (e.g. running means) use [`torch.nn.Module.register_buffer`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) instead.

In [None]:
class BatchNorm(torch.nn.Module):
    """Batch normalization layer for Task 3 (to be implemented).

    test_bn feeds it (batch, num_features) inputs and expects a learnable shift
    ``beta`` plus running statistics ``mu`` and ``sigma`` that are updated in
    training mode and used for normalization in eval mode.
    """
    
    def __init__(self, num_features, eps=1e-05, momentum=0.1):
        super(BatchNorm, self).__init__()
        
        self.num_features = num_features
        # epsilon from the normalization formula (numerical stability).
        self.eps = eps
        # lambda from the running-statistics update.
        self.momentum = momentum
        
        # TODO: learnable gamma/beta (torch.nn.Parameter) and running mu/sigma
        # (register_buffer). NOTE(review): test_bn's final eval-mode check
        # appears to assume `sigma` tracks the running *variance* — confirm.
        ???

    def forward(self, x):
        if self.training:
            # Normalize with batch statistics; update the running statistics.
            ???
            
        else:
            # Normalize with the running statistics.
            ???
        

In [None]:
# Raises AssertionError if BatchNorm fails any of test_bn's checks.
test_bn(BatchNorm)

Add batch normalization to the baseline model after each hidden layer.

In [None]:
# Task 3: the Task 1 baseline with BatchNorm (the class above) inserted after
# each hidden layer.
model = torch.nn.Sequential(
    ???
)

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
bn_history = trainer.train(model, optimizer, n_epochs=n_epochs)
show_results(model=bn_history)

### Questions:
1. Should batch normalization be used before or after the activation function? Why? (Hint: not sure there is a good answer.)
2. Can we think of batch normalization as regularization? Why? 

[your answers here]

In [None]:
# Compare all three runs on shared, vertically stacked axes; the y-limits
# zoom in on the informative range of accuracy and loss.
show_results(vanilla=history, dropout=dropout_history, bn=bn_history, 
             orientation='vertical', accuracy_bottom=0.5, loss_top=1.75)