Anti-disconnect: run the snippet below in the browser console to keep the Colab session connected.
```javascript
// Clicks the "connect" button inside the Colab toolbar's shadow DOM and logs
// that it did so. Throws if the toolbar element cannot be found.
function ConnectButton(){
    console.log("Connect pushed"); 
    document.querySelector("#top-toolbar > colab-connect-button").shadowRoot.querySelector("#connect").click() 
}
// Re-run the click every 60000 ms (one minute).
setInterval(ConnectButton, 60000);
```

# Introduction
## Algorithms
Considered algorithms:
* SGD + momentum
* Adam
* [AdamW](https://arxiv.org/abs/1711.05101)
* [AdaBound](https://arxiv.org/abs/1902.09843)
* [RAdam](https://arxiv.org/abs/1908.03265)
* [COCOB](https://arxiv.org/abs/1705.07795)
* [Storm](https://arxiv.org/abs/1905.10018) (**to be implemented**)

If time allows:
* [Hypergradient Descent](https://arxiv.org/abs/1703.04782)
* [Lookahead](https://arxiv.org/abs/1907.08610)

Considered problems/dataset:
* Image recognition:
    * CIFAR-10
    * CIFAR-100
    * Street View House Numbers (SVHN)
* Language Modelling:
  * Penn Treebank
* Machine translation (maybe later):
  * ?




## Package installs

In [0]:
# Install the third-party packages that are not imported from the base runtime:
# AdaBound, hypergradient descent (from GitHub), Lookahead and the wandb client.
# NOTE(review): no versions are pinned, so a fresh run may install different
# releases than the ones the experiments were originally run with.
!pip install -q adabound
!pip install -q git+https://github.com/gbaydin/hypergradient-descent.git
!pip install -q lookahead
!pip install -q wandb

  Building wheel for hypergrad (setup.py) ... [?25l[?25hdone
  Building wheel for lookahead (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 1.4MB 2.5MB/s 
[K     |████████████████████████████████| 102kB 8.9MB/s 
[K     |████████████████████████████████| 460kB 16.1MB/s 
[K     |████████████████████████████████| 102kB 9.3MB/s 
[K     |████████████████████████████████| 112kB 19.0MB/s 
[K     |████████████████████████████████| 71kB 7.9MB/s 
[K     |████████████████████████████████| 71kB 7.2MB/s 
[?25h  Building wheel for watchdog (setup.py) ... [?25l[?25hdone
  Building wheel for subprocess32 (setup.py) ... [?25l[?25hdone
  Building wheel for gql (setup.py) ... [?25l[?25hdone
  Building wheel for pathtools (setup.py) ... [?25l[?25hdone
  Building wheel for graphql-core (setup.py) ... [?25l[?25hdone


## Imports

In [0]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, models, datasets
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.optim.lr_scheduler import StepLR
import time
import os
import copy
import adabound

import wandb

import matplotlib.pyplot as plt
import matplotlib.style as style 
import seaborn as sns
style.use('seaborn-poster') 
style.use('dark_background')

from sklearn.metrics import accuracy_score
from sklearn.model_selection import ParameterGrid

import numpy as np

## Optimizers

### Cocob

In [0]:
import torch.optim as optim
import torch

###########################################################################
# Training Deep Networks without Learning Rates Through Coin Betting
# Paper: https://arxiv.org/abs/1705.07795
#
# NOTE: The optimizer state is allocated on the device of each parameter,
# so the optimizer runs on CPU as well as on GPU.
###########################################################################

class COCOBBackprop(optim.Optimizer):
    """COCOB-Backprop optimizer (learning-rate-free, coin-betting based).

    Args:
        params: iterable of parameters (or parameter-group dicts) to optimize.
        alpha: scalar multiplying ``L`` in the lower bound of the update's
            denominator (``max(grad_norm_sum + L, alpha * L)``).
        epsilon: initial value of the per-coordinate running maximum ``L``
            of the absolute gradient.
        weight_decay: L2 penalty coefficient added to the gradient.
    """

    def __init__(self, params, alpha=100, epsilon=1e-8, weight_decay=0):

        self._alpha = alpha
        self.epsilon = epsilon
        defaults = dict(alpha=alpha, epsilon=epsilon, weight_decay=weight_decay)
        super(COCOBBackprop, self).__init__(params, defaults)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure
            was given.
        """
        loss = None

        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data

                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                state = self.state[p]

                if len(state) == 0:
                    # Allocate the float32 state next to the parameter instead
                    # of forcing it onto CUDA.
                    state['gradients_sum'] = torch.zeros_like(p.data, dtype=torch.float)
                    state['grad_norm_sum'] = torch.zeros_like(p.data, dtype=torch.float)
                    state['L'] = torch.full_like(p.data, self.epsilon, dtype=torch.float)
                    state['tilde_w'] = torch.zeros_like(p.data, dtype=torch.float)
                    state['reward'] = torch.zeros_like(p.data, dtype=torch.float)

                gradients_sum = state['gradients_sum']
                grad_norm_sum = state['grad_norm_sum']
                tilde_w = state['tilde_w']
                L = state['L']
                reward = state['reward']

                # One-element zero tensor on the parameter's device; it is
                # broadcast by torch.max to clip the reward at zero.
                zero = torch.zeros(1, dtype=torch.float, device=p.device)

                L_update = torch.max(L, torch.abs(grad))
                gradients_sum_update = gradients_sum + grad
                grad_norm_sum_update = grad_norm_sum + torch.abs(grad)
                reward_update = torch.max(reward - grad * tilde_w, zero)
                new_w = -gradients_sum_update/(L_update * (torch.max(grad_norm_sum_update + L_update, self._alpha * L_update)))*(reward_update + L_update)
                # tilde_w is the offset applied at the previous step: remove
                # it and apply the new offset.
                p.data = p.data - tilde_w + new_w
                tilde_w_update = new_w

                state['gradients_sum'] = gradients_sum_update
                state['grad_norm_sum'] = grad_norm_sum_update
                state['L'] = L_update
                state['tilde_w'] = tilde_w_update
                state['reward'] = reward_update

        return loss

### Radam

In [0]:
# Source: https://github.com/LiyuanLucasLiu/RAdam/tree/master/radam
import math
import torch
from torch.optim.optimizer import Optimizer, required

class RAdam(Optimizer):
    """Rectified Adam optimizer.

    Keeps exponential moving averages of the gradient and of its square.
    While the approximated SMA length ``N_sma`` is below 5, the adaptive
    denominator is skipped and, if ``degenerated_to_sgd`` is set, a
    bias-corrected momentum step is taken instead.

    Args:
        params: iterable of parameters or of parameter-group dicts.
        lr: learning rate.
        betas: coefficients of the two running averages.
        eps: term added to the denominator.
        weight_decay: decay coefficient applied directly to the weights
            (scaled by ``lr``).
        degenerated_to_sgd: take a momentum step while ``N_sma < 5``
            instead of skipping the update.

    Raises:
        ValueError: if ``lr``, ``eps`` or ``betas`` are out of range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        
        self.degenerated_to_sgd = degenerated_to_sgd
        # Groups with their own betas get a private cache, because the cached
        # N_sma / step_size values depend on the betas.
        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
            for param in params:
                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                    param['buffer'] = [[None, None, None] for _ in range(10)]
        # Each buffer slot caches [step, N_sma, step_size]; it is indexed by step % 10.
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)])
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or ``None``.

        Raises:
            RuntimeError: if a gradient is sparse.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                # The update is computed on an fp32 copy and written back at the end.
                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Keyword forms replace the deprecated positional-scalar overloads.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                buffered = group['buffer'][int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    # Another parameter of this group already computed the
                    # rectification term for this step.
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    elif self.degenerated_to_sgd:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    else:
                        # Negative step size marks "skip this update".
                        step_size = -1
                    buffered[2] = step_size

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)
                elif step_size > 0:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)

        return loss

class PlainRAdam(Optimizer):
    """Rectified Adam without the per-step cache used by ``RAdam``.

    The rectification term is recomputed for every parameter at every step.

    Args:
        params: iterable of parameters or of parameter-group dicts.
        lr: learning rate.
        betas: coefficients of the two running averages.
        eps: term added to the denominator.
        weight_decay: decay coefficient applied directly to the weights
            (scaled by ``lr``).
        degenerated_to_sgd: take a momentum step while ``N_sma < 5``
            instead of skipping the update.

    Raises:
        ValueError: if ``lr``, ``eps`` or ``betas`` are out of range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
                    
        self.degenerated_to_sgd = degenerated_to_sgd
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        super(PlainRAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(PlainRAdam, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or ``None``.

        Raises:
            RuntimeError: if a gradient is sparse.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                # The update is computed on an fp32 copy and written back at the end.
                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Keyword forms replace the deprecated positional-scalar overloads.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                beta2_t = beta2 ** state['step']
                N_sma_max = 2 / (1 - beta2) - 1
                N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)


                # more conservative since it's an approximated value
                if N_sma >= 5:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                    p.data.copy_(p_data_fp32)
                elif self.degenerated_to_sgd:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    step_size = group['lr'] / (1 - beta1 ** state['step'])
                    p_data_fp32.add_(exp_avg, alpha=-step_size)
                    p.data.copy_(p_data_fp32)

        return loss


class AdamW(Optimizer):
    """Adam with weight decay applied directly to the weights and optional
    linear learning-rate warmup.

    Args:
        params: iterable of parameters or of parameter-group dicts.
        lr: target learning rate.
        betas: coefficients of the two running averages.
        eps: term added to the denominator.
        weight_decay: decay coefficient applied to the weights, scaled by
            the current (possibly warmed-up) learning rate.
        warmup: number of steps over which the learning rate is ramped
            linearly up to ``lr``; 0 disables warmup.

    Raises:
        ValueError: if ``lr``, ``eps`` or ``betas`` are out of range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, warmup = warmup)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or ``None``.

        Raises:
            RuntimeError: if a gradient is sparse.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                # The update is computed on an fp32 copy and written back at the end.
                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Keyword forms replace the deprecated positional-scalar overloads.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                denom = exp_avg_sq.sqrt().add_(group['eps'])
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                
                # Linear warmup: ramp from ~0 to lr during the first `warmup` steps.
                if group['warmup'] > state['step']:
                    scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
                else:
                    scheduled_lr = group['lr']

                step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1
                
                # The decay shrinks the weights directly rather than being
                # added to the gradient.
                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * scheduled_lr)

                p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)

                p.data.copy_(p_data_fp32)

        return loss

## Architectures

### ResNet18/34/50/101/152 

In [0]:
# Source: https://github.com/uoguelph-mlrg/Cutout/blob/master/model/resnet.py 

'''ResNet18/34/50/101/152 in Pytorch.'''
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1 and the given stride."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    The shortcut is the identity unless the stride or the channel count
    changes, in which case it is a 1x1 convolution followed by batch norm.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        shape_preserved = stride == 1 and in_planes == out_planes
        if shape_preserved:
            # Empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()
        else:
            projection = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
            self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(out_planes))

    def forward(self, x):
        """Apply the two conv/bn stages, add the shortcut and rectify."""
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        hidden += self.shortcut(x)
        return F.relu(hidden)


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions with batch norm.

    The block outputs ``expansion * planes`` channels. The shortcut is the
    identity unless the stride or the channel count changes, in which case
    it is a 1x1 convolution followed by batch norm.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # Only the middle 3x3 convolution carries the stride.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        shape_preserved = stride == 1 and in_planes == out_planes
        if shape_preserved:
            # Empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()
        else:
            projection = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
            self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(out_planes))

    def forward(self, x):
        """Apply the three conv/bn stages, add the shortcut and rectify."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)


class ResNet(nn.Module):
    """ResNet with a 3x3 stem convolution and four residual stages.

    Args:
        block: residual block class exposing an ``expansion`` attribute.
        num_blocks: sequence with the number of blocks in each of the four stages.
        num_classes: size of the final linear layer's output.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running channel count; updated by _make_layer as stages are built.
        self.in_planes = 64

        self.conv1 = conv3x3(3,64)
        self.bn1 = nn.BatchNorm2d(64)
        depth1, depth2, depth3, depth4 = num_blocks[0], num_blocks[1], num_blocks[2], num_blocks[3]
        self.layer1 = self._make_layer(block, 64, depth1, stride=1)
        self.layer2 = self._make_layer(block, 128, depth2, stride=2)
        self.layer3 = self._make_layer(block, 256, depth3, stride=2)
        self.layer4 = self._make_layer(block, 512, depth4, stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        stage = []
        for block_stride in [stride] + [1]*(num_blocks-1):
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the stem, the four stages, 4x4 average pooling and the classifier."""
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = F.avg_pool2d(features, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)


def ResNet18(num_classes=10):
    """Build a ResNet of BasicBlocks with stage depths 2, 2, 2, 2."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths, num_classes)

def ResNet34(num_classes=10):
    """Build a ResNet of BasicBlocks with stage depths 3, 4, 6, 3."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, num_classes)

def ResNet50(num_classes=10):
    """Build a ResNet of Bottleneck blocks with stage depths 3, 4, 6, 3."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, num_classes)

def ResNet101(num_classes=10):
    """Build a ResNet of Bottleneck blocks with stage depths 3, 4, 23, 3."""
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths, num_classes)

def ResNet152(num_classes=10):
    """Build a ResNet of Bottleneck blocks with stage depths 3, 8, 36, 3."""
    stage_depths = [3, 8, 36, 3]
    return ResNet(Bottleneck, stage_depths, num_classes)

## Helpers

Progress bar

In [0]:
from IPython.display import HTML, display

def progress(value, max=100):
    """Return an IPython ``HTML`` object rendering a full-width <progress> bar.

    Args:
        value: current progress value.
        max: value at which the bar is full.
    """
    # The stray comma that followed the `max` attribute was removed: HTML
    # attributes are separated by whitespace only.
    return HTML("""
        <progress
            value='{value}'
            max='{max}'
            style='width: 100%'
        >
            {value}
        </progress>
    """.format(value=value, max=max))

Seeder

In [0]:
def random_state():
  """Draw a fresh seed and apply it to every random number generator in use.

  The seed is drawn from NumPy's global generator and then used to seed
  Python's ``random`` module, PyTorch and NumPy. cuDNN is switched to
  deterministic mode with benchmarking disabled.

  Returns:
    The seed (an int in ``[0, 100000)``), so it can be logged with the run.
  """
  import random  # local import: the notebook's import cell does not provide it

  seed = int(np.random.randint(100000))
  # Python's own generator was previously left unseeded; depending on the
  # torchvision version, random augmentations may draw from it.
  random.seed(seed)
  torch.manual_seed(seed)
  torch.backends.cudnn.deterministic = True
  torch.backends.cudnn.benchmark = False
  np.random.seed(seed)
  return seed

### DataLoader

In [0]:
import os
import torchvision.datasets
import torchvision.transforms 

class DataLoader:
  """Builds train/validation ``torch.utils.data.DataLoader`` objects for a
  torchvision dataset selected by name.

  Args:
    set_name: attribute name of the dataset class in ``torchvision.datasets``.
    train_transform: transform passed to the training dataset.
    test_transform: transform passed to the validation dataset.
    batch_size_train: batch size of the training loader.
    batch_size_val: batch size of the validation loader.
    shuffle: shuffle flag passed to both loaders.
    data_root: directory the dataset is downloaded to / read from.
  """

  DATA_ROOT = './data'

  def __init__(self, set_name, train_transform, test_transform, 
               batch_size_train=256, batch_size_val=1024,
               shuffle=True, data_root='./data'):
    
    self.data_root = data_root
    self.batch_size_train = batch_size_train
    self.batch_size_val = batch_size_val
    self.shuffle = shuffle
    # os.cpu_count() may return None; fall back to loading in the main process.
    self.num_workers = os.cpu_count() or 0
    self.set_name = set_name
    self._loader_params_train = self._get_loader_params(True)
    self._loader_params_val = self._get_loader_params(False)
    self._mean = None
    self._std = None
    self._train_transform = train_transform
    self._test_transform = test_transform

  def get_loaders(self):
    """Return the ``(train_loader, validation_loader)`` pair."""
    trainloader = self.get_loader(True)
    testloader = self.get_loader(False)
    return trainloader, testloader

  def get_loader(self, train):
    """Create the loader for the training (``train=True``) or validation split.

    The dataset class is looked up by name and instantiated with
    ``download=True``.
    """
    params = self._loader_params_train if train else self._loader_params_val
    transform = self._train_transform if train else self._test_transform
    dataset_f = getattr(torchvision.datasets, self.set_name)
    dataset = dataset_f(self.data_root, train=train, 
                        transform=transform, 
                        download=True)
    return torch.utils.data.DataLoader(dataset, 
                                       **params)
  
  def _get_loader_params(self, train):
    """Return the keyword arguments for the loader of the requested split."""
    # NOTE(review): the same shuffle flag is used for the validation loader;
    # confirm this is intended.
    return {
        'batch_size': self.batch_size_train if train else self.batch_size_val,
        'shuffle': self.shuffle,
        'num_workers': self.num_workers
    }

### Utils

In [0]:
import matplotlib.pyplot as plt
import numpy as np
import torchvision

class Utils:
  """Small plotting helpers for inspecting image batches."""

  @staticmethod
  def imshow(img):
      """Display a channel-first image tensor with matplotlib.

      NOTE(review): the un-normalization assumes the image was normalized
      with mean 0.5 and std 0.5 per channel; confirm against the
      transforms actually used.
      """
      img = img / 2 + 0.5 # unnormalize
      npimg = img.numpy()
      # matplotlib expects the channel axis last.
      plt.imshow(np.transpose(npimg, (1, 2, 0)))
      plt.show()
  
  @staticmethod
  def show_random_images(loader, classes=None):
    """Show the first batch yielded by ``loader`` as a grid and print its labels.

    Args:
      loader: iterable yielding ``(images, labels)`` batches.
      classes: optional sequence mapping a label index to a class name.
        When omitted, a module-level ``classes`` is used if one exists,
        otherwise the raw label indices are printed.
    """
    dataiter = iter(loader)
    # The built-in next() works with any iterator; `.next()` is not part of
    # the iterator protocol.
    images, labels = next(dataiter)
    Utils.imshow(torchvision.utils.make_grid(images))
    if classes is None:
      # Backward compatibility with the former implicit global lookup.
      classes = globals().get('classes')
    # Iterate over the labels actually present: a batch can be shorter than
    # loader.batch_size.
    indices = [int(label) for label in labels]
    names = indices if classes is None else [classes[i] for i in indices]
    # print labels
    print(' '.join('%5s' % name for name in names))


### Training

In [0]:
class Trainer:
  """Runs the train/validation loop of a classifier, shows an HTML progress
  bar per pass and optionally logs metrics to wandb.

  Args:
    model: the ``nn.Module`` to train; it is moved to the selected device.
    trainloader: loader yielding ``(X, y)`` training batches.
    testloader: loader yielding ``(X, y)`` validation batches.
    criterion: loss callable applied to ``(y_pred, y)``.
    optimizer: optimizer stepping the model parameters.
    num_epochs: number of epochs to run.
    score: callable ``score(y_true, y_pred)`` applied to lists of labels.
    scheduler: optional LR scheduler stepped once per epoch.
    wandb: optional wandb module/client used for logging; ``None`` disables
      logging.
    allow_cpu: run on CPU when no CUDA device is available instead of
      raising.

  Raises:
    EnvironmentError: if no CUDA device is available and CPU execution was
      not allowed.
  """

  HTML = False

  def __init__(self, model, trainloader, testloader, 
               criterion, optimizer, num_epochs, 
               score=accuracy_score,
               scheduler=None, wandb=None, allow_cpu=False):
    
    # A module-level `CPU` flag is still honoured for backward compatibility,
    # but it no longer has to exist.
    cpu_allowed = allow_cpu or bool(globals().get('CPU', False))
    if torch.cuda.is_available():
      self._device = torch.device('cuda:0')
    elif cpu_allowed:
      self._device = torch.device('cpu')
    else:
      raise EnvironmentError('Turn on GPU acceleration')

    self._model = model.to(self._device)
    self._trainloader = trainloader
    self._testloader = testloader
    self._criterion = criterion
    self._optimizer = optimizer
    self._num_epochs = num_epochs
    self._scheduler = scheduler
    self._score = score

    self._BATCH_SIZE = trainloader.batch_size
    self._TRAIN_SIZE = trainloader.batch_size * len(trainloader)
    self._TEST_SIZE = testloader.batch_size * len(testloader)

    self._best_model_wts = copy.deepcopy(self._model.state_dict())
    self._best_score = 0
    self._epoch = 0
    self._train_hist = self._get_hist_dict()
    self._val_hist = self._get_hist_dict()

    self._bar = None
    self._wandb = wandb
    # Metrics of the current epoch, flushed to wandb after the validation pass.
    self._wandb_entry = {}

    self._best_scores = {'train': 0,
                         'val': 0}
    

  def _get_hist_dict(self):
    """Return zero-filled per-epoch score and loss arrays."""
    return {
        'scores': np.zeros(self._num_epochs),
        'losses': np.zeros(self._num_epochs)
    }

  def train(self):
    """Run a training pass and a validation pass per epoch, then save the model."""
    for self._epoch in range(self._num_epochs):
      self._log_epoch()

      self._model.train()
      self._pass()

      self._model.eval()
      self._pass()
      self._scheduler_step()

    self._save_model()
      
  def _save_model(self):
    """Save model and optimizer state to ``model_<epoch>.pt``.

    With a wandb client the file is written to the run directory and
    registered with ``save``; otherwise it is written to the current
    working directory.
    """
    name = 'model_' + str(self._epoch) + '.pt'
    checkpoint = {
        'model_state_dict': self._model.state_dict(),
        'optimizer_state_dict': self._optimizer.state_dict(),
    }
    if self._wandb:
      path = os.path.join(self._wandb.run.dir, name)
      torch.save(checkpoint, path)
      self._wandb.save(path)
    else:
      torch.save(checkpoint, name)

  def _pass(self):
    """Iterate once over the loader matching the model's current mode.

    Gradients and optimizer steps are only computed in training mode. The
    epoch loss and score are unweighted means over the batches.
    """
    loader = self._get_loader()
    losses = np.zeros(len(loader))
    scores = np.zeros(len(loader))
    self._show_bar()
    for i, (X, y) in enumerate(loader):
      X = X.to(self._device)
      y = y.to(self._device)
      with torch.set_grad_enabled(self._model.training):
        self._optimizer.zero_grad()
        y_pred = self._model(X)
        # Predicted class = index of the largest output.
        _, preds = torch.max(y_pred, 1)
        loss = self._criterion(y_pred, y)
        losses[i] = loss.item()
        scores[i] = self._score(y.tolist(), preds.tolist())
        if self._model.training:
          loss.backward()
          self._optimizer.step()
        self._update_bar(i+1, loss.item(), scores[i].item())
    self._update_hist(np.mean(losses), np.mean(scores))
  
  def _update_hist(self, loss, score):
    """Record the epoch metrics and forward them to wandb when enabled."""
    training = self._model.training
    key = 'train' if training else 'val'
    hist = self._train_hist if training else self._val_hist
    hist['losses'][self._epoch] = loss
    hist['scores'][self._epoch] = score

    if self._best_scores[key] < score:
      self._best_scores[key] = score
      if self._wandb:
        self._wandb.run.summary["best_" + key + "_acc"] = score
        self._wandb.run.summary.update()
      
    if self._wandb:
      entry = self._wandb_entry
      entry[key + '_loss'] = loss
      entry[key + '_acc'] = score
      # The validation pass closes the epoch: log train and val metrics together.
      if key=='val':
        self._wandb.log(entry)
        self._wandb_entry = {}
        self._wandb.run.summary["finished_epochs"] = self._epoch
        self._wandb.run.summary.update()

  def _scheduler_step(self):
    """Step the LR scheduler if one was provided."""
    if self._scheduler:
      self._scheduler.step()
  
  def _get_loader(self):
    """Return the training loader in training mode, else the validation loader."""
    if not(self._model.training):
      return self._testloader
    return self._trainloader

  def _log_epoch(self):
    """Print the epoch header."""
    print('Epoch {}/{}'.format(self._epoch, self._num_epochs - 1))
    print('-' * 10)

  def _show_bar(self):
    """Display an empty progress bar and keep its display handle for updates."""
    html = self._get_html_bar(0, 0)
    self._bar = display(html, display_id=True)

  def _get_html_bar(self, value, loss=0, score=0):
    """Build the HTML for the progress bar with the current loss and score."""
    training = self._model.training
    total = len(self._trainloader) if training else len(self._testloader)
    prefix = 'train_' if training else 'val_'

    # Attributes are separated by whitespace only and the <span> is closed.
    return HTML("""
                <progress
                    value='{value}'
                    max='{total}'
                    style='width: 60%'
                >
                    {value}
                </progress> 
                <span style='margin-left: 10px'>
                  {prefix}loss: {loss:.3f} {prefix}acc: {score:.3f}
                </span>
              """.format(value=value, total=total, 
                         loss=loss, score=score,
                         prefix=prefix))
  
  def _update_bar(self, value, loss, score):
    """Redraw the progress bar.

    Raises:
      RuntimeError: if no bar is currently displayed.
    """
    if not(self._bar):
      raise RuntimeError('Bar not being shown')
    self._bar.update(self._get_html_bar(value, loss, score))

## Experiments

In [0]:
class OptimizerProvider:
  """Factory mapping an optimizer name to a configured optimizer instance."""

  OPTIMIZERS = ['SGD', 'Adam', 'AdamW', 'Cocob', 'Radam', 'Adabound']

  @staticmethod
  def get(opt_name, model_parameters, params):
    """Create the optimizer called ``opt_name``.

    Args:
      opt_name: one of ``OptimizerProvider.OPTIMIZERS``.
      model_parameters: parameters handed to the optimizer.
      params: mapping with the keys ``'lr'`` and ``'weight_decay'`` and,
        optionally, ``'momentum'`` (used by SGD only, default 0).

    Returns:
      The constructed optimizer. 'Cocob' takes no learning rate and
      ignores ``params['lr']``.

    Raises:
      RuntimeError: if ``opt_name`` is not a known optimizer.
    """
    if opt_name not in OptimizerProvider.OPTIMIZERS:
      raise RuntimeError('Unknown optimizer')

    lr = params['lr']
    wd = params['weight_decay']

    if opt_name == 'Adam':
      return optim.Adam(model_parameters,
                        lr=lr,
                        weight_decay=wd)

    elif opt_name == 'AdamW':
      return optim.AdamW(model_parameters,
                         lr=lr,
                         weight_decay=wd)

    elif opt_name == 'Radam':
      return RAdam(model_parameters,
                   lr=lr,
                   weight_decay=wd)

    elif opt_name == 'SGD': 
      # Momentum is opt-in so that existing configurations keep running
      # plain SGD.
      return optim.SGD(model_parameters,
                       lr=lr,
                       momentum=params.get('momentum', 0),
                       weight_decay=wd)

    elif opt_name == 'Cocob': 
      return COCOBBackprop(model_parameters,
                           weight_decay=wd)

    elif opt_name == 'Adabound': 
      return adabound.AdaBound(model_parameters,
                               lr=lr,
                               weight_decay=wd)

### CIFAR-10 / CIFAR-100

Normalization and augmentation are the same for **CIFAR-10** and **CIFAR-100** and are taken from  
[https://github.com/uoguelph-mlrg/Cutout](https://github.com/uoguelph-mlrg/Cutout)

In [0]:
class CifarExperiment:
  """Grid-search experiments training ResNet18 on CIFAR-10 / CIFAR-100.

  Every configuration of the grid is trained as a separate wandb run.
  """

  BATCH_SIZE = 128
  EPOCHS = 200
  CIFAR_10 = 'CIFAR10'
  CIFAR_100 = 'CIFAR100'
  CIFAR_10_PROJECT = 'ms-cifar10-resNet18-v2'
  CIFAR_100_PROJECT = 'ms-cifar100-resNet18'

  @staticmethod
  def train_resNet18_cifar10(opt_name, **params):
    """Run the grid search on CIFAR-10 (10 classes).

    ``params`` are forwarded to ``_train_resNet18_cifar`` (``lrs``,
    ``weight_decays`` and optionally ``tags``).
    """
    nclasses = 10
    CifarExperiment._train_resNet18_cifar(
                                         CifarExperiment.CIFAR_10,
                                         nclasses, 
                                         CifarExperiment.CIFAR_10_PROJECT,
                                         opt_name,
                                         **params)

  @staticmethod
  def train_resNet18_cifar100(opt_name, **params):
    """Run the grid search on CIFAR-100 (100 classes).

    ``params`` are forwarded to ``_train_resNet18_cifar`` (``lrs``,
    ``weight_decays`` and optionally ``tags``).
    """
    nclasses = 100
    CifarExperiment._train_resNet18_cifar(
                                         CifarExperiment.CIFAR_100, 
                                         nclasses,
                                         CifarExperiment.CIFAR_100_PROJECT,
                                         opt_name,
                                         **params)

  @staticmethod
  def _get_transforms():
    """Return ``(train_transform, test_transform)``.

    Both convert to tensors and normalize per channel; the training
    transform additionally applies a padded random crop and a random
    horizontal flip.
    """
    normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                  std=[x/255.0 for x in [63.0, 62.1, 66.7]])

    train_transform = transforms.Compose([])
    train_transform.transforms.append(transforms.RandomCrop(32, padding=4))
    train_transform.transforms.append(transforms.RandomHorizontalFlip())
    train_transform.transforms.append(transforms.ToTensor())
    train_transform.transforms.append(normalize)

    test_transform = transforms.Compose([])
    test_transform.transforms.append(transforms.ToTensor())
    test_transform.transforms.append(normalize)

    return train_transform, test_transform
  
  @staticmethod
  def _train_resNet18_cifar(dataset, nclasses, project_name, opt_name, lrs, weight_decays, tags=None):
    """Train ResNet18 once per (lr, weight_decay) combination.

    Args:
      dataset: torchvision dataset name.
      nclasses: number of output classes of the model.
      project_name: wandb project receiving the runs.
      opt_name: optimizer name understood by ``OptimizerProvider``.
      lrs: list of learning rates to try.
      weight_decays: list of weight-decay values to try.
      tags: optional list of wandb tags attached to every run.
    """
    # Avoid a shared mutable default argument.
    if tags is None:
      tags = []

    # LR step decay
    # src: "Lookahead Optimizer: k steps forward, 1 step back" - Hinton et al.
    STEP_SIZE = 50
    GAMMA = 1/5

    # 'Cocob' and 'Adabound' are run without an LR scheduler.
    uses_scheduler = opt_name not in ['Cocob', 'Adabound']
    
    train_transform, test_transform = CifarExperiment._get_transforms()

    parameters = {
        'lr': lrs,
        'batch_size': [CifarExperiment.BATCH_SIZE],
        'epochs': [CifarExperiment.EPOCHS],
        'weight_decay': weight_decays,
        # Log the decay policy that is actually applied to this optimizer.
        'lr_decay': ['StepLR' if uses_scheduler else 'None'],
        'init': ['default'],
        'optimizer': [opt_name]
    }

    grid = ParameterGrid(parameters)

    trainloader, testloader = DataLoader(dataset, 
                                        train_transform, test_transform,
                                        batch_size_train=CifarExperiment.BATCH_SIZE).get_loaders() 
    # Grid search 
    for config in grid:
      # Fresh seed per run, stored in the config so the run can be reproduced.
      seed = random_state()
      config['seed'] = seed
      
      wandb.init(reinit=True,
                project=project_name,
                tags=tags,
                config=config)
      
      model = ResNet18(nclasses)
      optimizer = OptimizerProvider.get(opt_name, model.parameters(), config) 
      criterion = nn.CrossEntropyLoss()

      # lr decay
      if uses_scheduler:
          scheduler = StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)
      else:
          scheduler = None

      # Training
      trainer = Trainer(model, trainloader, testloader, 
                        criterion, optimizer, CifarExperiment.EPOCHS, 
                        wandb=wandb, scheduler=scheduler)
      wandb.watch(model)
      trainer.train()
      wandb.join()

_wandb_ token:  
`XYZ`

In [0]:
# Log this runtime in to wandb before launching experiments; the training
# code below calls wandb.init for every run.
!wandb login

In [0]:
# Launch the CIFAR-10 grid search for SGD with a single configuration
# (one learning rate, one weight-decay value).
lrs = [0.1]
weight_decays = [0.01]
CifarExperiment.train_resNet18_cifar10('SGD', lrs=lrs, weight_decays=weight_decays)