<center><img src='https://drive.google.com/uc?id=1_utx_ZGclmCwNttSe40kYA6VHzNocdET' height="60"></center>

AI TECH - Academy of Innovative Applications of Digital Technologies. Digital Poland Operational Programme 2014-2020
<hr>

<center><img src='https://drive.google.com/uc?id=1BXZ0u3562N_MqCLcekI-Ens77Kk4LpPm'></center>

<center>
Project co-financed by the European Union from the European Regional Development Fund under the
Digital Poland Operational Programme 2014-2020,
Priority Axis No. 3 "Digital competences of society", Measure No. 3.2 "Innovative solutions for digital activation"
Project title: "Academy of Innovative Applications of Digital Technologies (AI Tech)"
    </center>

Code based on https://github.com/pytorch/examples/blob/master/mnist/main.py

In this exercise we use high-level abstractions from torch.nn, such as nn.Linear.
Note: during the next lab session we will go one level deeper and implement more of these
building blocks by hand.

Tasks:

    1. Read the code.

    2. Check that the given implementation reaches 95% test accuracy for the architecture input-128-128-10 after a few epochs.

    3. Add the option to use SGD with momentum instead of ADAM.

    4. Experiment with different learning rates. Use the provided TrainingVisualizer
    to plot the learning curves and gradient-to-weight ratios. Compare the visualizations
    for different learning rates for both ADAM and SGD with momentum.

    5. Parameterize the constructor with a list of hidden layer sizes of the MLP.
    Note that this requires storing a list of layers as an attribute of the Net class,
    and one can't use a standard Python list of nn.Modules for this (why?).
    Check torch.nn.ModuleList.
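To see why task 5 needs nn.ModuleList, here is a minimal sketch (the class names PlainListNet and ModuleListNet and the layer sizes are hypothetical, chosen only for illustration). nn.Module registers a submodule only when it is assigned directly as an attribute or kept in a container such as nn.ModuleList. Layers stored in a plain Python list are invisible to parameters(), .to(device) and state_dict(), so an optimizer built from model.parameters() would not train them (PyTorch optimizers raise an error when given an empty parameter list).

```python
import torch.nn as nn


class PlainListNet(nn.Module):
    """Hypothetical example: layers hidden inside a plain Python list."""

    def __init__(self):
        super().__init__()
        # nn.Module does not look inside plain lists, so these layers are NOT registered.
        self.layers = [nn.Linear(4, 3), nn.Linear(3, 2)]


class ModuleListNet(nn.Module):
    """Hypothetical example: the same layers wrapped in nn.ModuleList."""

    def __init__(self):
        super().__init__()
        # nn.ModuleList registers every element as a submodule.
        self.layers = nn.ModuleList([nn.Linear(4, 3), nn.Linear(3, 2)])


def count_params(model):
    # Total number of scalar parameters that an optimizer would see.
    return sum(p.numel() for p in model.parameters())


plain, registered = PlainListNet(), ModuleListNet()
print(count_params(plain))       # 0: nothing for the optimizer to update
print(count_params(registered))  # 4*3 + 3 + 3*2 + 2 = 23
```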


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import plotly.graph_objects as go
from IPython.display import display  # used by TrainingVisualizer to render the figures
import sys
if 'google.colab' in sys.modules:
    from google.colab import output
    output.enable_custom_widget_manager()

In [2]:
# @title Visualize gradients

class TrainingVisualizer:
    def __init__(self, log_interval=10):
        self.log_interval = log_interval
        self.train_loss_fig = self.init_line_plot(title='Training loss', xaxis_title='Step')
        self.grad_to_weight_fig = self.init_line_plot(
            title='Gradient standard deviation to weight standard deviation ratio at 1st layer',
            xaxis_title='Step',
            yaxis_title='Gradient to weight ratio (log scale)',
            yaxis_type='log'
        )
        self.test_acc_fig = self.init_line_plot(
            title='Test accuracy', x=[], xaxis_title='Epoch', mode='lines+markers'
        )

        # Parameters related to current tracked model and its training
        self.first_linear_layer = None
        self.lr = None
        self.trace_idx = -1

    def init_line_plot(
        self,
        title,
        x=None, xaxis_title=None,
        yaxis_title=None, yaxis_type='linear',
        mode='lines'
    ):
        fig = go.Figure()
        fig.update_layout(
            title=title, title_x=0.5,
            xaxis_title=xaxis_title, yaxis_title=yaxis_title,
            height=400, width=1500, margin=dict(b=10, t=60)
        )
        fig.update_yaxes(type=yaxis_type)
        # Traces cannot be added dynamically because Colab has a problem with plotly
        # widgets (dynamically added traces are rendered twice).
        # As a workaround we create many empty traces up front and fill them with data
        # later, which limits the visualizer to 25 tracked runs. Empty traces are not plotted.
        for _ in range(25):
            fig.add_trace(go.Scatter(x=x, y=[], showlegend=True, mode=mode))

        fig_widget = go.FigureWidget(fig)
        display(fig_widget)
        return fig_widget

    def track_model(self, model, optimizer, lr):
        """
        Start tracking training metrics for a new model.
        """

        for field in model.children():
            if isinstance(field, nn.Linear):
                self.first_linear_layer = field
                break
            elif isinstance(field, nn.ModuleList):
                self.first_linear_layer = field[0]
                break

        self.lr = lr
        self.trace_idx += 1

        optim_name = type(optimizer).__name__
        self.train_loss_fig.data[self.trace_idx].name = f'{optim_name}, {lr}'
        self.grad_to_weight_fig.data[self.trace_idx].name = f'{optim_name}, {lr}'
        self.test_acc_fig.data[self.trace_idx].name = f'{optim_name}, {lr}'

    def plot_gradients_and_loss(self, batch_idx, loss):
        if batch_idx % self.log_interval == 0:
            self.train_loss_fig.data[self.trace_idx].y += (loss, )

            layer = self.first_linear_layer
            grad_to_weight_ratio = (self.lr * layer.weight.grad.std() / layer.weight.std()).item()

            self.grad_to_weight_fig.data[self.trace_idx].y += (grad_to_weight_ratio, )

    def plot_accuracy(self, epoch, accuracy):
        self.test_acc_fig.data[self.trace_idx].x += (epoch, )
        self.test_acc_fig.data[self.trace_idx].y += (accuracy, )
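The second plot shows lr · std(∇W) / std(W) for the first linear layer. For plain SGD, lr · ∇W is exactly the update applied to W, so the ratio estimates the relative size of each step; for Adam and SGD with momentum the actual update differs from lr · ∇W, so there the curve is only a proxy. A pure-Python sketch of the same formula, on hypothetical toy weights and gradients made up for illustration:

```python
import statistics


def update_to_weight_ratio(weights, grads, lr):
    """lr * std(grads) / std(weights), mirroring the TrainingVisualizer formula.

    statistics.stdev uses the sample (n - 1) estimator, which matches
    torch.Tensor.std() with its default settings.
    """
    return lr * statistics.stdev(grads) / statistics.stdev(weights)


# Hypothetical toy values: gradients are 50x smaller than the weights.
weights = [0.5, -0.5, 0.5, -0.5]
grads = [0.01, -0.01, 0.01, -0.01]
ratio = update_to_weight_ratio(weights, grads, lr=1e-2)
print(ratio)  # approximately 2e-4 (= 0.01 * 0.02)
```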

In [13]:
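class Net(nn.Module):
    def __init__(self, hidden_sizes=(128, 128)):
        super().__init__()
        # After flattening an image of size 28x28 we have 784 inputs; MNIST has 10 classes.
        sizes = [28 * 28, *hidden_sizes, 10]
        # nn.ModuleList registers every layer as a submodule, so its parameters are
        # returned by model.parameters() and moved by model.to(device).
        # The attribute must not be called `modules`: that would shadow nn.Module.modules().
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        )

    def forward(self, x):
        x = torch.flatten(x, 1)

        # Hidden layers: linear transformation followed by ReLU.
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))

        # Output layer without activation; log_softmax pairs with F.nll_loss.
        x = self.layers[-1](x)
        return F.log_softmax(x, dim=1)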


def train(model, device, train_loader, optimizer, epoch, log_interval, visualizer, verbose=False):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        visualizer.plot_gradients_and_loss(batch_idx, loss.item())
        optimizer.step()
        if batch_idx % log_interval == 0:
            if verbose:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader, visualizer, verbose=False):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    if verbose:
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))
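    # `epoch` is not a parameter of test(): it is the global loop variable of the
    # training cell below, so test() must be called from inside that loop.
    visualizer.plot_accuracy(epoch, 100. * correct / len(test_loader.dataset))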

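The model returns log-probabilities (F.log_softmax), and F.nll_loss then takes the negative log-probability of the target class; together they give the cross-entropy loss. The sketch below reimplements both for a single example in pure Python, assuming the usual log-sum-exp formulation for numerical stability (the logits are hypothetical). Note that F.nll_loss averages over the batch by default, while test() uses reduction='sum' and divides by the dataset size.

```python
import math


def log_softmax(logits):
    # Subtract the maximum before exponentiating (log-sum-exp trick),
    # so large logits do not overflow math.exp.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def nll_loss(log_probs, target):
    # Negative log-likelihood of the correct class for one example.
    return -log_probs[target]


logits = [2.0, 1.0, 0.1]  # hypothetical scores for 3 classes
log_probs = log_softmax(logits)
print(sum(math.exp(lp) for lp in log_probs))  # ~1.0: a valid distribution
print(nll_loss(log_probs, target=0))
# The predicted class is the argmax of the log-probabilities, as in test().
print(max(range(len(log_probs)), key=log_probs.__getitem__))  # 0
```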


In [4]:
batch_size = 256
test_batch_size = 1000
epochs = 5
lr = 1e-2
seed = 1
log_interval = 10
use_cuda = torch.cuda.is_available()
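With these settings the training-loss and gradient-ratio curves get one point every log_interval batches, so the x-axis labelled "Step" in the TrainingVisualizer counts logged points rather than individual optimizer steps. The arithmetic, assuming the standard 60,000-image MNIST training set and the DataLoader default drop_last=False (values copied from the cell above):

```python
import math

train_size = 60_000  # MNIST training images
batch_size = 256
log_interval = 10
epochs = 5

# The last batch is partial (60000 - 234 * 256 = 96 images) but is still used.
batches_per_epoch = math.ceil(train_size / batch_size)
# train() logs when batch_idx % log_interval == 0, i.e. batch_idx = 0, 10, ..., 230.
logged_per_epoch = len(range(0, batches_per_epoch, log_interval))
print(batches_per_epoch, logged_per_epoch, logged_per_epoch * epochs)  # 235 24 120
```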

In [5]:
torch.manual_seed(seed)
device = torch.device("cuda" if use_cuda else "cpu")

# Shuffle the training data every epoch on any device; the order of the test set does not matter.
train_kwargs = {'batch_size': batch_size, 'shuffle': True}
test_kwargs = {'batch_size': test_batch_size}
if use_cuda:
    cuda_kwargs = {'num_workers': 1,
                   'pin_memory': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

In [6]:
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
dataset1 = datasets.MNIST('../data', train=True, download=True,
                    transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                    transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
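transforms.ToTensor() scales pixel values to [0, 1], and transforms.Normalize((0.1307,), (0.3081,)) then maps each value x to (x - 0.1307) / 0.3081; the two constants are the commonly used mean and standard deviation of the MNIST training pixels. A worked sketch of where the extreme pixel values end up after normalization:

```python
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081


def normalize(pixel):
    # Same per-channel formula as transforms.Normalize: (x - mean) / std,
    # applied after ToTensor has scaled the pixel to [0, 1].
    return (pixel - MNIST_MEAN) / MNIST_STD


print(round(normalize(0.0), 4))  # -0.4242 (black background)
print(round(normalize(1.0), 4))  # 2.8215 (white stroke)
```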


In [11]:
visualizer = TrainingVisualizer(log_interval=log_interval)


In [12]:
params = [
    ("adam", 1e-2),
    ("adam", 1e-3),
    ("adam", 1e-4),
    ("momentum", 1e-2),
    ("momentum", 1e-3),
    ("momentum", 1e-4),
]
for optimizer_name, learning_rate in params:
    # A fresh model for every run, so that runs do not share weights.
    model = Net().to(device)
    if optimizer_name == "adam":
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    elif optimizer_name == "momentum":
        # momentum=0.9 is a common choice; without it optim.SGD is plain SGD.
        optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
    else:
        raise ValueError(f"Unknown optimizer: {optimizer_name}")
    visualizer.track_model(model, optimizer, learning_rate)
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch, log_interval, visualizer)
        test(model, device, test_loader, visualizer)
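For reference, torch.optim.SGD uses momentum only when a nonzero momentum argument is passed (the default is 0). With dampening=0 and nesterov=False it keeps a buffer b, updated as b = μ·b + g, and then sets w = w - lr·b. A pure-Python sketch of that rule on a hypothetical one-dimensional quadratic (the function sgd_momentum and the toy loss are illustrative assumptions), comparing μ = 0 with μ = 0.9:

```python
def sgd_momentum(grad_fn, w0, lr, momentum, steps):
    """Heavy-ball SGD in the form used by torch.optim.SGD (dampening=0, nesterov=False):
    buf = momentum * buf + grad;  w = w - lr * buf
    """
    w, buf = w0, 0.0  # starting buf at 0 makes the first buf equal to the first gradient
    history = [w]
    for _ in range(steps):
        g = grad_fn(w)
        buf = momentum * buf + g
        w = w - lr * buf
        history.append(w)
    return history


def grad(w):
    # Gradient of the hypothetical loss f(w) = 0.05 * w**2 (minimum at w = 0).
    return 0.1 * w


plain = sgd_momentum(grad, w0=1.0, lr=0.1, momentum=0.0, steps=100)
heavy = sgd_momentum(grad, w0=1.0, lr=0.1, momentum=0.9, steps=100)
print(plain[-1], heavy[-1])
print(min(heavy))  # negative: the momentum run overshoots the minimum
```

On this toy problem the momentum run ends much closer to the minimum after 100 steps, but it overshoots past zero on the way; oscillations of this kind are one thing to look for when comparing the learning curves of the two optimizers.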