<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Define-network" data-toc-modified-id="Define-network-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Define network</a></span></li><li><span><a href="#Define-train-/-test-functions" data-toc-modified-id="Define-train-/-test-functions-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Define train / test functions</a></span></li><li><span><a href="#Create-model" data-toc-modified-id="Create-model-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Create model</a></span></li><li><span><a href="#Create-data-loaders" data-toc-modified-id="Create-data-loaders-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Create data loaders</a></span></li><li><span><a href="#Setup-libvis" data-toc-modified-id="Setup-libvis-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Setup libvis</a></span><ul class="toc-item"><li><span><a href="#Loss-function-history-graph" data-toc-modified-id="Loss-function-history-graph-5.1"><span class="toc-item-num">5.1&nbsp;&nbsp;</span>Loss function history graph</a></span></li><li><span><a href="#'Stop-training'-button" data-toc-modified-id="'Stop-training'-button-5.2"><span class="toc-item-num">5.2&nbsp;&nbsp;</span>'Stop training' button</a></span></li><li><span><a href="#Learning-rate-slider" data-toc-modified-id="Learning-rate-slider-5.3"><span class="toc-item-num">5.3&nbsp;&nbsp;</span>Learning rate slider</a></span></li><li><span><a href="#Callbacks" data-toc-modified-id="Callbacks-5.4"><span class="toc-item-num">5.4&nbsp;&nbsp;</span>Callbacks</a></span></li></ul></li><li><span><a href="#Train-model" data-toc-modified-id="Train-model-6"><span class="toc-item-num">6&nbsp;&nbsp;</span>Train model</a></span></li><li><span><a href="#Resulting-dashboard" data-toc-modified-id="Resulting-dashboard-7"><span class="toc-item-num">7&nbsp;&nbsp;</span>Resulting dashboard</a></span></li></ul></div>

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import sklearn.metrics

import matplotlib.pyplot as plt

from libvis import Vis

In [2]:
# Automatically reload edited Python modules (e.g. libvis during development) before running each cell.
%load_ext autoreload
%autoreload 2

## Define network

Nothing fancy here: a small network for MNIST classification with two convolutional layers followed by two fully connected layers.

In [3]:
class Net(nn.Module):
    """ Small CNN for MNIST classification.

    conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> 2x2 max-pool -> channel dropout
    -> flatten -> fc(9216->128) -> ReLU -> dropout -> fc(128->10) -> log-softmax.

    Expects single-channel 28x28 images (fc1's 9216 inputs = 64 * 12 * 12, i.e.
    28 -> 26 -> 24 after the two 3x3 convolutions, then 12 after pooling).
    Returns log-probabilities of shape (batch, 10), as consumed by `F.nll_loss`.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Dropout2d zeroes whole feature maps; it fits the 4-D conv output.
        self.dropout1 = nn.Dropout2d(0.25)
        # dropout2 acts on the flat (batch, 128) fc1 activations. Dropout2d on 2-D
        # input is deprecated; plain Dropout keeps the same element-wise behaviour.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """ Map a batch of images to per-class log-probabilities. """
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        # NOTE(review): the reference PyTorch MNIST example applies F.relu after
        # conv2; confirm that omitting it here is intentional.
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

## Define train / test functions

The train and test functions take a callback and call it after every batch (train or test),
passing information about the model's current performance.

You could instead share a global variable and access it inside the train function, but
that gets messy with many variables.
It is better to keep the visualization code separate from the training code.

In [4]:
from collections import namedtuple

# Per-batch snapshot that `train` hands to its callback.
TrainInfo = namedtuple(
    'TrainInfo',
    ['epoch', 'optimizer', 'model', 'output', 'loss', 'target', 'pred'],
)

def output2pred(output):
    """ Decision from the network output: index of the largest value along dim 1.

    keepdim=True keeps dim 1 with size 1, e.g. an (N, C) input gives an (N, 1) result.
    """
    return torch.argmax(output, dim=1, keepdim=True)  # index of the max log-probability

def train(model, device, train_loader, optimizer, epoch, callback):
    """ Run one training epoch over `train_loader`.

    Each batch does a forward pass, NLL loss and backward pass. Then
    `callback(TrainInfo(...))` is called while this batch's gradients are still in
    place, and a progress line is printed. Only after that is the optimizer step
    taken. An exception raised by `callback` propagates to the caller before that
    batch's optimizer step.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        out = model(inputs)
        batch_loss = F.nll_loss(out, labels)
        batch_loss.backward()

        # Report before stepping so the callback sees the fresh gradients.
        callback(TrainInfo(epoch, optimizer, model, out, batch_loss, labels,
                           output2pred(out)))

        samples_seen = step * len(inputs)
        percent_done = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, samples_seen, len(train_loader.dataset), percent_done,
            batch_loss.item()))
        optimizer.step()

def test(model, device, test_loader, callback):
    """ Evaluate `model` on `test_loader`, print and return average loss and accuracy.

    Args:
        model: network whose output is fed to `F.nll_loss` (log-probabilities).
        device: torch device that each batch is moved to.
        test_loader: DataLoader yielding (data, target) batches. The length of its
            `.dataset` is used to average the loss and compute accuracy.
        callback: called after every batch as
            `callback(data, target, pred, test_loss, output)`, where `test_loss` is
            the summed loss over all batches seen so far (not yet averaged).

    Returns:
        (average_loss, accuracy), with accuracy as a fraction in [0, 1].
        Returns (0.0, 0.0) without printing when the dataset is empty.
    """
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output2pred(output)
            correct += pred.eq(target.view_as(pred)).sum().item()

            callback(data, target, pred, test_loss, output)

    n_samples = len(test_loader.dataset)
    if n_samples == 0:
        # Nothing to average over; avoid a ZeroDivisionError.
        return 0.0, 0.0
    test_loss /= n_samples

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, n_samples,
        100. * correct / n_samples))
    return test_loss, correct / n_samples


## Create model

In [5]:
# Seed torch's global RNG so weight initialisation (and later torch RNG use such as
# DataLoader shuffling and dropout) is reproducible under Restart & Run All.
SEED = 1
torch.manual_seed(SEED)

# NOTE: keep False unless the dashboard callbacks move tensors to the CPU before
# converting them to numpy.
use_cuda = False

device = torch.device("cuda" if use_cuda else "cpu")
# Extra DataLoader options used only on the CUDA path.
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

epochs = 5
batch_size = 2000
model = Net().to(device)

## Create data loaders

In [6]:
# Same preprocessing for both splits. 0.1307 / 0.3081 are the normalisation
# constants from the PyTorch MNIST example (presumably the training-set pixel mean / std).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=mnist_transform),
    batch_size=batch_size, shuffle=True, **kwargs)

# No shuffling for evaluation: the order does not affect the metrics, and a fixed
# order keeps the per-batch test callback reproducible.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True, transform=mnist_transform),
    batch_size=batch_size, shuffle=False, **kwargs)

In [9]:
# Adadelta with its default settings; lr=1.0 is written out because the lr slider reads it back.
optimizer = torch.optim.Adadelta(params=model.parameters(), lr=1.0)

## Setup libvis

In [7]:
from libvis.modules import Image, uicontrols
from libvis import Vis
from loguru import logger as log
import bokeh
import bokeh.plotting

# Let's mute the libvis and legimens loggers (loguru) for now, so their
# log messages do not clutter the notebook output.
log.disable('libvis')
log.disable('legimens')

In [16]:
# Create the libvis dashboard object. As the output below shows, this tries to start a
# webapp HTTP server on port 7000 and a websocket server (Legimens). If the port is
# taken, the HTTP server can be started manually with `Vis.start_http(port)`.
vis = Vis()

HTTPServer start on 7000 failed: [Errno 98] Address already in use
Webapp HTTP server failed to start at localhost:7000. To start manually: `Vis.start_http(port)`. Error was: [Errno 98] Address already in use
Exception in thread Thread-5:
Traceback (most recent call last):
  File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.8/site-packages/trio/_core/_run.py", line 1804, in run
    raise runner.main_task_outcome.error
  File "/home/dali/side-projects-hobby/legimens/python/legimens/App.py", line 268, in _start
    await trio.sleep(CORO_SCHEDULER_DELAY)
  File "/usr/lib/python3.8/site-packages/trio/_core/_run.py", line 730, in __aexit__
    raise combined_error_from_nursery
  File "/home/dali/side-projects-hobby/legimens/python/legimens/websocket/server.py", line 54, in start_server
    await ws_serve(addr, port, handler_func

Exception: Failed to start Legimens. Reason may be printed above by other thread. Exception sharing is under development.

In [9]:
# The optimizer is created once, in the cell right after the data loaders;
# re-creating it here would only replace that instance with an identical one.

### Loss function history graph

In [10]:
# Per-batch training losses. `train_callback` appends to this list after every batch.
# vis.watch registers it with libvis under the name 'loss' (presumably by reference,
# since the list keeps growing after this call).
losses = []
vis.watch(losses, 'loss')

'Legi_0x7f63ba2c6180'

### 'Stop training' button

In [11]:
train_enable = True  # cleared by the 'Stop training' button, checked by the train callback


def disable_train():
    """ Request a stop of training by clearing the global `train_enable` flag. """
    global train_enable
    print('Stopping train after this batch')
    train_enable = False
    
# Dashboard button; pressing it calls disable_train to request a stop of training.
vis.vars.stop = uicontrols.Button(label='Stop training', on_press=disable_train)

### Learning rate slider

In [12]:
# The slider starts at the optimizer's current learning rate. With default Adadelta
# settings that is 1.0, far above 0.05, so widen the range to always contain it;
# otherwise the initial value would lie outside [min, max].
lr = optimizer.param_groups[0]['lr']
slider = uicontrols.Slider(value=lr, min=0, max=max(0.05, lr))
vis.vars.lr = slider

def on_slider(lr_new):
    """ Change the learning rate of the optimizer to `lr_new`.

    Updates every parameter group, not only the first. Reads the global
    `optimizer` at call time, so it follows the optimizer if that cell is re-run.
    """
    for group in optimizer.param_groups:
        group['lr'] = lr_new
    print('Changed lr to', lr_new)
    
# Apply learning-rate slider changes to the optimizer via on_slider.
vis.vars.lr.on_change = on_slider

### Callbacks

In [13]:
def _log_histogram(values, bins=200):
    """ Histogram of `values` as a (2, bins) array: right bin edges and log(count + 1). """
    counts, edges = np.histogram(values, bins=bins)
    return np.array([edges[1:], np.log(counts + 1)])


def train_callback(info):
    """ Publish one training batch's diagnostics to the libvis dashboard.

    `train` calls this after `loss.backward()` and before `optimizer.step()`, so the
    current batch's gradients are available. It appends the batch loss to the watched
    `losses` list and sets these `vis.vars` entries:
    - `epoch`
    - log-scaled histograms of all gradients (`hist`), of the fc2 weight gradients
      (`hist_fc2_grad`) and of the fc1 weights (`hist_fc1`)
    - a bokeh confusion-matrix figure (`confusion_matrix`)

    Raises StopIteration once (re-arming the flag) after the 'Stop training' button
    has cleared `train_enable`. `train` does not catch it, so it ends training.
    """
    global train_enable
    model = info.model
    v = vis.vars

    v.epoch = info.epoch
    losses.append(info.loss.item())

    # .detach().cpu() makes the numpy conversion also work for tensors on a GPU.
    grads = [p.grad.detach().flatten() for p in model.parameters() if p.grad is not None]
    if grads:
        v.hist = _log_histogram(torch.cat(grads).cpu().numpy())
    v.hist_fc2_grad = _log_histogram(model.fc2.weight.grad.detach().flatten().cpu().numpy())
    v.hist_fc1 = _log_histogram(model.fc1.weight.detach().flatten().cpu().numpy())

    # Fix the labels to the 10 classes so the matrix is always 10x10, matching the
    # dw=dh=10 extent below even when a batch happens to miss a class.
    confmat = sklearn.metrics.confusion_matrix(
        info.target.cpu().numpy(),
        info.pred.flatten().cpu().numpy(),
        labels=list(range(10)),
    )
    fig = bokeh.plotting.figure(
        title='confusion matrix',
        sizing_mode='stretch_both'
    )
    fig.image(image=[confmat], x=0, y=0, dw=10, dh=10)
    v.confusion_matrix = fig

    if not train_enable:
        train_enable = True  # re-arm the button for the next run
        raise StopIteration()
    
def test_callback(data, target, pred, loss, output):
    pass


## Train model

In [14]:
# `train_callback` raises StopIteration when the 'Stop training' button was pressed.
# Catch it here so stopping ends the run cleanly instead of with a traceback.
try:
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch, train_callback)
        test(model, device, test_loader, test_callback)
except StopIteration:
    print('Training stopped by user')


Changed lr to 0.0465
Changed lr to 0.045
Changed lr to 0.042
Changed lr to 0.0355
Changed lr to 0.029
Changed lr to 0.028
Changed lr to 0.0275
Changed lr to 0.025
Changed lr to 0.0245
Changed lr to 0.043
Changed lr to 0.006
Changed lr to 0.0
Changed lr to 0.042


KeyboardInterrupt: 

## Resulting dashboard


![](https://libvis.dev/pictures/torch_adv_demo.png)

In [105]:
# Shut down the libvis webapp HTTP server and websocket server.
vis.stop()

Stopping webapp http server: `Vis.stop_http()`... OK
Stopping websocket server: `Vis.app.stop()`... OK


In [16]:
# Inspect the model's first parameter: conv1's weight tensor, shape (32, 1, 3, 3).
param = model.conv1.weight
param.data


tensor([[[[-0.1778,  0.2160, -0.1562],
          [-0.2498,  0.3287,  0.3338],
          [ 0.0827,  0.1157, -0.2219]]],


        [[[-0.0362, -0.1317, -0.0864],
          [ 0.2682,  0.3209,  0.1445],
          [ 0.0390, -0.1257,  0.2463]]],


        [[[-0.1711, -0.1340, -0.1410],
          [ 0.0288, -0.2328,  0.2236],
          [ 0.0702,  0.2597,  0.0077]]],


        [[[ 0.2867,  0.4020, -0.2208],
          [ 0.2556,  0.3275,  0.1281],
          [ 0.0071, -0.3357, -0.2546]]],


        [[[ 0.1315,  0.0074, -0.1273],
          [ 0.1005,  0.2866, -0.3804],
          [ 0.0958, -0.1048,  0.0318]]],


        [[[-0.5412, -0.0485,  0.3391],
          [-0.3502,  0.2956,  0.1419],
          [-0.3551,  0.0093,  0.3081]]],


        [[[ 0.1048, -0.0904, -0.0148],
          [-0.0435,  0.0599, -0.0093],
          [ 0.2936, -0.1868,  0.1078]]],


        [[[ 0.1351,  0.1152,  0.1770],
          [ 0.0157,  0.1192, -0.2279],
          [-0.3656, -0.1112, -0.1889]]],


        [[[-0.0725, -0.3036,  0.