<center><img src='https://drive.google.com/uc?id=1_utx_ZGclmCwNttSe40kYA6VHzNocdET' height="60"></center>

AI TECH - Akademia Innowacyjnych Zastosowań Technologii Cyfrowych (Academy of Innovative Applications of Digital Technologies). Operational Programme Digital Poland 2014-2020
<hr>

<center><img src='https://drive.google.com/uc?id=1BXZ0u3562N_MqCLcekI-Ens77Kk4LpPm'></center>

<center>
Project co-financed by the European Union from the European Regional Development Fund
Operational Programme Digital Poland 2014-2020,
Priority Axis 3 "Digital competences of society", Measure 3.2 "Innovative solutions for digital activation"
Project title: "Akademia Innowacyjnych Zastosowań Technologii Cyfrowych (AI Tech)"
    </center>

This exercise covers two aspects:
* In tasks 1-6 you will implement mechanisms that allow training deeper models. After doing each of the tasks you can look at the plots and check how your changes impact gradients of network layers.
* In task 7 you will implement a convnet using [conv2d](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html) and try to reach 99% accuracy on MNIST.

For this notebook, running on a CPU is sufficient: tasks 1–5 run in under a minute, task 6 in under two minutes, and task 7 in up to 5 minutes.

Tasks:
1. Check that the given implementation reaches 95% test accuracy for
   the architecture `input-64-64-10` in 8 epochs.
2. Modify `Linear.reset_parameters()` to improve initialization and
   check that the network learns much faster and reaches over 97% test accuracy.
   A good basic initialization scheme is so-called *Glorot initialization*.
   For a linear layer with $n_{in}$ and  $n_{out}$ input and output features,
   it samples each weight from a normal distribution with $0$ mean and a standard deviation of $\sqrt{\frac{2}{n_{in}+n_{out}}}$.  
   Check how that changes the distribution of gradients at the first epoch.
3. Check that with proper initialization we can train the architecture
   `input-64-64-64-64-64-64-64-10` (seven hidden layers of `64`), for e.g. 3 epochs,
   while with the original, bad initialization, it does not even get off the ground.
4. Implement Dropout as a PyTorch module (without using torch.nn.Dropout).
   Use `p=0.25` as the zeroing probability.
   Use Dropout after activations of each layer, except the last.
5. Check that with 10 hidden layers (64 units each), even with proper initialization,
   but with dropout added, the network has a hard time starting to learn.
6. Implement batch normalization (without using torch.nn.BatchNorm).
    * compute batch mean and variance
    * add new variables beta and gamma
    * check that the network learns even for 10 hidden layers.
    * check how gradients change in comparison to network without batch norm.
   
   In pseudo-code:
    ```
    if training:
        var, mean = torch.var_mean(x, dim=/batch/)
        update running_mean, running_var
    else:
        mean, var = running_mean, running_var
    x = (x - mean) / sqrt(var + eps)
    return x * gamma + beta
    ```

7. Design and implement a simple convolutional network and achieve 99% test accuracy (in up to 10 epochs, 5 should be enough). The architecture is up to you, but even a few convolutional layers should be enough.
You may use `torch.nn.Conv2d, BatchNorm2d, MaxPool2d, ReLU, Dropout2d, Linear, Flatten, BatchNorm1d`.
Don't make it too generic, just try a few fixed layers.
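
The pseudo-code for Task 6 can be turned into a module roughly as follows. This is only one possible sketch, not the reference solution: the class name `BatchNorm`, the `momentum` update rule for the running statistics, the default `eps`, and the use of the biased batch variance throughout are all assumptions made for illustration.

```python
import torch
import torch.nn as nn
from torch import Tensor


class BatchNorm(nn.Module):
    """Batch normalization for inputs of shape (B, num_features)."""

    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1) -> None:
        super().__init__()
        self.eps = eps
        self.momentum = momentum  # assumed update rate for the running statistics
        self.gamma = nn.Parameter(torch.ones(num_features))  # learnable scale
        self.beta = nn.Parameter(torch.zeros(num_features))  # learnable shift
        # Buffers are stored with the module but are not updated by the optimizer.
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x: Tensor) -> Tensor:
        if self.training:
            # Statistics over the batch dimension (dim=0), one value per feature.
            var, mean = torch.var_mean(x, dim=0, unbiased=False)
            with torch.no_grad():
                # Exponential moving average of the batch statistics.
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * var)
        else:
            mean, var = self.running_mean, self.running_var
        x = (x - mean) / torch.sqrt(var + self.eps)
        return x * self.gamma + self.beta
```

In training mode the output of such a layer has approximately zero mean and unit variance per feature before `gamma` and `beta` are learned; in evaluation mode it relies on the accumulated running statistics instead.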

In [1]:
import math
import sys

import numpy as np
import PIL.Image
import plotly.express as px
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from plotly.subplots import make_subplots
from torch import Tensor
from torch.nn import init

if 'google.colab' in sys.modules:
    from google.colab import output
    output.enable_custom_widget_manager()

In [2]:
# @title Visualize gradients

class GradientVisualizer:
    def __init__(self, net: nn.Module, num_epochs: int) -> None:
        self.num_epochs = num_epochs
        self.linear_layers = self.get_linear_layers(net)

        self.grad_to_weight_fig = None
        self.grads_in_layers_fig = None
        self.grads_at_epochs_fig = None
        self.init_figures()

    @staticmethod
    def get_linear_layers(net):
        linear_layers = []
        for field in net.modules():
            if isinstance(field, nn.Linear | Linear):
                linear_layers.append(field)
        assert linear_layers, "No linear layers found. Forgot to use Sequential or ModuleList?"
        return linear_layers

    def get_epochs_for_one_layer(self):
        """
        We want to show gradient distributions from up to 7 selected epochs
        for one linear layer.
        """
        if self.num_epochs < 7:
            return list(range(self.num_epochs))
        else:
            return torch.linspace(0, self.num_epochs - 1, 7).int().tolist()

    def get_three_epochs(self):
        """
        We want to show gradient distributions from all layers at each of
        three epochs: first, middle and last.
        """
        return [0, self.num_epochs // 2, self.num_epochs - 1]

    def rgb_to_rgba(self, rgb_color, epoch):
        """
        Value of epoch parameter determines how transparent color should be
        in comparison to others.
        Colors for earlier epochs should be more transparent/less visible.
        """
        return f'rgba{rgb_color[3:-1]},{0.6 * (epoch + 1) / self.num_epochs + 0.15})'

    def init_figures(self):
        # Initialize figure with gradient to weight ratio plot
        fig = go.Figure()
        fig.update_layout(
            title='Gradient stddev to weight stddev ratio', title_x=0.5,
            xaxis_title='Epoch',
            yaxis_title='Gradient to weight ratio (log scale)',
            height=400, width=1500, margin=dict(b=10, t=60)
        )
        fig.update_yaxes(type='log')
        for i in range(len(self.linear_layers)):
            fig.add_trace(go.Scatter(
                x=[], y=[],
                mode='lines+markers', name=f'Linear layer {i}'
            ))

        self.grad_to_weight_fig = go.FigureWidget(fig)
        display(self.grad_to_weight_fig)

        # Initialize figure visualizing gradient distributions in layers
        chosen_layers = [0, len(self.linear_layers) // 2, len(self.linear_layers) - 1]
        num_rows = (len(chosen_layers) - 1) // 3 + 1
        fig = make_subplots(
            rows=num_rows, cols=3,
            subplot_titles=[f'Linear layer {i}' for i in chosen_layers],
            vertical_spacing=0.2 / num_rows
        )
        fig.update_layout(
            title='Gradient distributions by epoch (for three chosen layers)', title_x=0.5,
            height=num_rows * 400, width=1500, margin=dict(b=10, t=60)
        )

        colors, _ = px.colors.convert_colors_to_same_type(2 * px.colors.qualitative.Plotly)
        for i, layer_num in enumerate(chosen_layers):
            row = i // 3 + 1
            col = i % 3 + 1
            fig.update_xaxes(
                title_text='Gradient value', range=(-0.1, 0.1), row=row, col=col
            )
            fig.update_yaxes(
                title_text='Density (log scale)', type='log', row=row, col=col
            )

            # Create empty traces and update them later with actual gradient distributions.
            # Unfortunately, we cannot add new traces dynamically because Colab has a problem
            # with Plotly widgets (traces added dynamically are rendered twice).
            for epoch in self.get_epochs_for_one_layer():
                fig.add_trace(
                    go.Scatter(
                        mode='lines', name=f'Epoch {epoch + 1}',
                        line=dict(color=self.rgb_to_rgba(colors[layer_num], epoch)),
                        legendgroup=layer_num
                    ),
                    row=row, col=col
                )

        self.grads_in_layers_fig = go.FigureWidget(fig)
        display(self.grads_in_layers_fig)

        # Initialize figure comparing gradient distributions between layers at the
        # first, middle and last epoch
        selected_epochs_indices = self.get_three_epochs()
        fig = make_subplots(
            rows=1, cols=3,
            subplot_titles=[f'Epoch {epoch + 1}' for epoch in selected_epochs_indices]
        )
        fig.update_layout(
            title='Gradient distributions by layers (for three chosen epochs)', title_x=0.5,
            height=400, width=1500, margin=dict(b=10, t=60)
        )

        for col, epoch in enumerate(selected_epochs_indices, 1):
            fig.update_yaxes(title_text='Density (log scale)', type='log', row=1, col=col)
            fig.update_xaxes(
                title_text='Gradient value',
                range=(-0.05, 0.05) if epoch != 0 else (-1, 1),
                row=1, col=col
            )

            # Create empty traces and update them later with actual gradient distributions.
            for layer_num in range(len(self.linear_layers)):
                fig.append_trace(
                    go.Scatter(
                        mode='lines', name=f'Linear layer {layer_num}',
                        line=dict(color=colors[layer_num]), showlegend=(col == 1)
                    ),
                    row=1, col=col
                )

        self.grads_at_epochs_fig = go.FigureWidget(fig)
        display(self.grads_at_epochs_fig)

    def visualize_gradients(self, lr, epoch, batch_idx):
        # It is enough to use gradients calculated for the first batch.
        if batch_idx != 0:
            return

        epoch_grads = []
        epoch_grad_to_weight_ratios = []
        for layer in self.linear_layers:
            epoch_grads.append(layer.weight.grad.flatten().detach())
            epoch_grad_to_weight_ratios.append(
                (lr * layer.weight.grad.std() / layer.weight.std()).item()
            )

        # Update figure with gradient to weight ratio plot
        assert self.grad_to_weight_fig is not None
        for i, grad_to_weight_ratio in enumerate(epoch_grad_to_weight_ratios):
            x = self.grad_to_weight_fig.data[i].x
            next_x_val = x[-1] + 1 if x else 1
            self.grad_to_weight_fig.data[i].x += (next_x_val, )
            self.grad_to_weight_fig.data[i].y += (grad_to_weight_ratio, )

        # Update figure visualizing gradient distributions in layers
        assert self.grads_in_layers_fig is not None
        chosen_layers = [0, len(self.linear_layers) // 2, len(self.linear_layers) - 1]
        selected_epochs = self.get_epochs_for_one_layer()
        if epoch in selected_epochs:
            epoch_idx = selected_epochs.index(epoch)
            for layer_num, layer_grad in enumerate(epoch_grads):
                try:
                    layer_id = chosen_layers.index(layer_num)
                except ValueError:
                    continue
                trace_idx = layer_id * len(selected_epochs) + epoch_idx
                try:
                    hy, hx = torch.histogram(layer_grad, bins=50, density=True)
                except RuntimeError:
                    continue
                # Rescale to peak 1; the small offset keeps empty bins visible on the log axis.
                hy = hy / hy.max() + 0.001
                self.grads_in_layers_fig.data[trace_idx].x = hx[:-1].tolist()
                self.grads_in_layers_fig.data[trace_idx].y = hy.tolist()


        # Update figure visualizing gradient distributions at epochs
        assert self.grads_at_epochs_fig is not None
        selected_epochs = self.get_three_epochs()
        if epoch in selected_epochs:
            epoch_idx = selected_epochs.index(epoch)
            for layer_num, layer_grad in enumerate(epoch_grads):
                trace_idx = epoch_idx * len(self.linear_layers) + layer_num
                try:
                    hy, hx = torch.histogram(layer_grad, bins=50, density=True)
                except RuntimeError:
                    continue
                # Rescale to peak 1; the small offset keeps empty bins visible on the log axis.
                hy = hy / hy.max() + 0.001
                self.grads_at_epochs_fig.data[trace_idx].x = hx[:-1].tolist()
                self.grads_at_epochs_fig.data[trace_idx].y = hy.tolist()
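
The two quantities that `visualize_gradients` plots can be reproduced on a toy layer. The snippet below is a minimal sketch under assumed values: the layer sizes, the squared-output loss, and `lr = 0.05` are made up for illustration and are not taken from the training setup.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A toy layer and loss, used only to obtain some gradients.
layer = nn.Linear(8, 4)
loss = layer(torch.randn(16, 8)).pow(2).mean()
loss.backward()

lr = 0.05  # hypothetical learning rate

# Ratio of the scaled gradient spread to the weight spread: a rough indicator
# of how large one SGD update is relative to the weights themselves.
ratio = (lr * layer.weight.grad.std() / layer.weight.std()).item()

# Density histogram of the gradient values, rescaled so the tallest bin is 1.
# The 0.001 offset avoids zeros, which cannot be drawn on a log axis.
hy, hx = torch.histogram(layer.weight.grad.flatten().detach(), bins=10, density=True)
hy = hy / hy.max() + 0.001

print(f"gradient-to-weight ratio: {ratio:.4f}")
print(f"{hy.numel()} bins, {hx.numel()} bin edges")
```

`torch.histogram` returns one more bin edge than bins, which is why the visualizer drops the last edge (`hx[:-1]`) before plotting.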

In [3]:
class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, glorot_init: bool) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.empty(out_features))
        self.glorot_init = glorot_init
        self.reset_parameters()

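    def reset_parameters(self) -> None:
        if self.glorot_init:
            # TODO Task 2: implement Glorot initialization of self.weight.
            raise NotImplementedError("Task 2: implement Glorot initialization")
        else:
            self.weight.data.normal_(mean=0, std=0.25)
        init.zeros_(self.bias)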

    def forward(self, x: Tensor) -> Tensor:
        """
        Input: shape (*B, in_features)
        Output: shape (*B, out_features)
        """
        r = x.matmul(self.weight.T)
        r += self.bias
        return r


class Net(nn.Module):
    def __init__(self, glorot_init: bool = False) -> None:
        super().__init__()
        self.fc1 = Linear(784, 64, glorot_init=glorot_init)
        self.fc2 = Linear(64, 64, glorot_init=glorot_init)
        self.fc3 = Linear(64, 10, glorot_init=glorot_init)

    def forward(self, x: Tensor) -> Tensor:
        """
        Input: shape (B, 1, H=28, W=28), or any shape that flattens to (B, 784)
        Output: shape (B, num_classes=10)
        """
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
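
For Task 2, the Glorot formula $\sqrt{\frac{2}{n_{in}+n_{out}}}$ can be checked numerically before wiring it into `Linear.reset_parameters()`. The helper name `glorot_normal_` below is hypothetical, and the snippet is a sketch of one way to sample the weights rather than the expected solution.

```python
import math

import torch
from torch import Tensor


def glorot_normal_(weight: Tensor) -> Tensor:
    """Fill a (n_out, n_in) matrix in place with N(0, 2 / (n_in + n_out)) samples."""
    n_out, n_in = weight.shape
    std = math.sqrt(2.0 / (n_in + n_out))
    with torch.no_grad():
        return weight.normal_(mean=0.0, std=std)


torch.manual_seed(0)
w = glorot_normal_(torch.empty(64, 784))  # same shape as the weight of fc1
expected_std = math.sqrt(2.0 / (784 + 64))
print(f"target std {expected_std:.4f}, empirical std {w.std().item():.4f}")
```

For the first layer the target standard deviation is roughly 0.05, about five times smaller than the fixed `std=0.25` used by the original initialization. PyTorch's `torch.nn.init.xavier_normal_` with the default gain uses the same standard deviation.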

In [4]:
class MnistTrainer:
    def __init__(self, batch_size: int = 128) -> None:
        transform = transforms.Compose([transforms.ToTensor()])
        self.trainset = torchvision.datasets.MNIST(
            root='./data', download=True, train=True, transform=transform
        )
        self.trainloader = torch.utils.data.DataLoader(
            self.trainset, batch_size=batch_size, shuffle=True, num_workers=2
        )
        self.testset = torchvision.datasets.MNIST(
            root='./data', train=False, download=True, transform=transform
        )
        self.testloader = torch.utils.data.DataLoader(
            self.testset, batch_size=1000, shuffle=False, num_workers=2
        )

    def train(
        self,
        net: nn.Module,
        gradient_visualizer: GradientVisualizer | None = None,
        epochs: int = 20,
        lr: float = 0.05,
        momentum: float = 0.9
    ) -> None:
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
        log_freq = 100

        for epoch in range(epochs):
            net.train()
            running_loss = 0.0
            for i, data in enumerate(self.trainloader):
                inputs, labels = data
                optimizer.zero_grad()

                outputs = net(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                if gradient_visualizer is not None:
                    gradient_visualizer.visualize_gradients(lr, epoch, i)
                optimizer.step()

                running_loss += loss.item()
                if i % log_freq == log_freq - 1:
                    running_loss /= log_freq
                    print(f"epoch {epoch + 1:>5}, batch {i + 1:>5}, loss {running_loss:.3f}")
                    running_loss = 0.0
            self.test(net)

    def test(self, net: nn.Module) -> list[float]:
        net.eval()
        correct = 0
        total = 0
        all_gt_logits = []
        with torch.no_grad():
            for data in self.testloader:
                images, labels = data
                outputs = net(images)
                gt_logits = torch.gather(outputs, dim=1, index=labels.unsqueeze(-1))
                all_gt_logits.extend(float(logit.item()) for logit in gt_logits)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print(f"Accuracy of the network on the {total} test images: {correct / total:.2%}")
        return all_gt_logits

    def show_worst_misclassified(self, net: nn.Module, k: int = 10) -> None:
        all_gt_logits = self.test(net)
        _, indices = torch.topk(-torch.tensor(all_gt_logits), k)
        imgs = []
        labels = []
        for i in indices:
            img, label = self.testset[int(i.item())]
            imgs.append(img)
            labels.append(label)
        image = PIL.Image.fromarray(
            torch.cat(imgs, dim=-1).mul(255).to(torch.uint8).permute(1, 2, 0).numpy().squeeze()
        )
        display(image.resize((5 * image.width, 5 * image.height), PIL.Image.Resampling.NEAREST))
        print(labels)
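
`MnistTrainer` calls `net.train()` before each epoch and `net.eval()` inside `test()`, so a custom module such as the Dropout layer from Task 4 can branch on `self.training`. The sketch below assumes the common "inverted dropout" convention, in which surviving activations are rescaled by `1 / (1 - p)` during training so that evaluation needs no rescaling; the task text does not prescribe this, so treat it as one reasonable choice.

```python
import torch
import torch.nn as nn
from torch import Tensor


class Dropout(nn.Module):
    """Zero each element with probability p during training (inverted dropout)."""

    def __init__(self, p: float = 0.25) -> None:
        super().__init__()
        if not 0.0 <= p < 1.0:
            raise ValueError("p must be in [0, 1)")
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        # In eval mode (set by net.eval()) the layer is the identity.
        if not self.training or self.p == 0.0:
            return x
        # Each element is kept independently with probability 1 - p.
        keep = (torch.rand_like(x) >= self.p).to(x.dtype)
        # Rescale so the expected activation matches the eval-mode value.
        return x * keep / (1.0 - self.p)
```

With `p=0.25`, about a quarter of the activations are zeroed on each training forward pass and the rest are multiplied by `1 / 0.75`.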

In [5]:
epochs = 20

net = Net()
gradient_visualizer = GradientVisualizer(net, epochs)

[Figure widget: "Gradient standard deviation to weight standard deviation ratio"; x-axis: Epoch; one trace per layer (Linear layer 0, Linear layer 1, Linear layer 2)]

[Figure widget: gradient distributions by epoch for three chosen layers; traces labeled by epoch, including Epoch 1, Epoch 4 and Epoch 7]

[Figure widget: gradient distributions by layer for three chosen epochs; traces labeled Linear layer 0, Linear layer 1, Linear layer 2]

In [6]:
trainer = MnistTrainer(batch_size=128)
trainer.train(net, gradient_visualizer, epochs=epochs)

epoch     1, batch   100, loss 1.146
epoch     1, batch   200, loss 0.366
epoch     1, batch   300, loss 0.323
epoch     1, batch   400, loss 0.258
Accuracy of the network on the 10000 test images: 93.23%
epoch     2, batch   100, loss 0.212
epoch     2, batch   200, loss 0.208
epoch     2, batch   300, loss 0.194
epoch     2, batch   400, loss 0.194
Accuracy of the network on the 10000 test images: 94.53%
epoch     3, batch   100, loss 0.161
epoch     3, batch   200, loss 0.179
epoch     3, batch   300, loss 0.151
epoch     3, batch   400, loss 0.162
Accuracy of the network on the 10000 test images: 95.21%
epoch     4, batch   100, loss 0.144
epoch     4, batch   200, loss 0.129
epoch     4, batch   300, loss 0.130
epoch     4, batch   400, loss 0.141
Accuracy of the network on the 10000 test images: 95.38%
epoch     5, batch   100, loss 0.119
epoch     5, batch   200, loss 0.109
epoch     5, batch   300, loss 0.117
epoch     5, batch   400, loss 0.120
Accuracy of the network on the 1