In [None]:
# Use the %pip magic (not !pip) so the package is installed into the
# environment of the running kernel.
%pip install wandb

In [None]:
# @title necessary imports
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import wandb
from tqdm import tqdm
from functools import partial

import pickle
from copy import copy, deepcopy

from typing import Optional, Dict, Union, Tuple

from time import time
import math

from pandas import DataFrame
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.optim as optim
from torch.optim import Adam, NAdam, Adadelta
import torch.distributions as distr
from torch.utils.data import DataLoader

In [None]:
# @title import datasets
from torchvision.datasets import MNIST, CIFAR10
from torchvision import transforms
from torch.utils.data import TensorDataset, DataLoader

## Global variables and settings

In [None]:
# Mini-batch size handed to get_dataloaders() by run_grid_search().
BATCH_SIZE = 128
# True when PyTorch can see at least one CUDA device.
USE_CUDA = torch.cuda.is_available()

In [None]:
# Optional: pin a specific GPU index on a multi-GPU machine (disabled).
#if USE_CUDA:
#  torch.cuda.set_device(5)

# Global device used by get_model() and the experiments below.
device = torch.device('cuda' if USE_CUDA else 'cpu')
print("Torch uses", str(device).upper(),"now")

# Optimizer implementations

## AdaHessian

(source: https://github.com/jettify/pytorch-optimizer/blob/master/torch_optimizer/adahessian.py)

In [None]:
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

from torch import Tensor

# Type aliases used in the signatures of the optimizer implementation below.

# Either an iterable of tensors or an iterable of parameter-group dicts.
Params = Union[Iterable[Tensor], Iterable[Dict[str, Any]]]

# Zero-argument callable returning a float, and its optional variant
# (used as the type of the ``closure`` argument of ``step``).
LossClosure = Callable[[], float]
OptLossClosure = Optional[LossClosure]
# Pair of floats, used for the ``betas`` argument.
Betas2 = Tuple[float, float]
# Generic string-keyed dictionary alias.
State = Dict[str, Any]
# Optional float, used as the return type of ``step``.
OptFloat = Optional[float]
# Pair of floats (not referenced by the optimizer code in this notebook).
Nus2 = Tuple[float, float]

In [None]:
import math
from typing import List, Optional

import torch
from torch.optim.optimizer import Optimizer

Grads = Params

__all__ = ("Adahessian",)


class Adahessian(Optimizer):
    r"""Implements Adahessian Algorithm.

    It has been proposed in `ADAHESSIAN: An Adaptive Second Order Optimizer
    for Machine Learning`__.

    The gradients seen by :meth:`step` must carry an autograd graph, i.e. the
    caller has to use ``loss.backward(create_graph=True)``.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 0.15)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-4)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        hessian_power (float, optional): Hessian power, must lie in [0, 1]
            (default: 1)
        seed (int, optional): if given, ``torch.manual_seed(seed)`` is called,
            which seeds the *global* torch RNG (default: None)

    __ https://arxiv.org/abs/2006.00719

    Note:
        Reference code: https://github.com/amirgholami/adahessian
    """

    def __init__(
        self,
        params: Params,
        lr: float = 0.15,
        betas: Betas2 = (0.9, 0.999),
        eps: float = 1e-4,
        weight_decay: float = 0,
        hessian_power: float = 1,
        seed: Optional[int] = None,
    ) -> None:
        if lr <= 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if eps <= 0.0:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(betas[1])
            )
        if not 0.0 <= hessian_power <= 1.0:
            raise ValueError(
                "Invalid Hessian power value: {}".format(hessian_power)
            )
        if seed is not None:
            torch.manual_seed(seed)
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            hessian_power=hessian_power,
        )
        super(Adahessian, self).__init__(params, defaults)

    def get_trace(self, params: Params, grads: Grads) -> List[torch.Tensor]:
        """Get an estimate of the Hessian diagonal (Hutchinson's method).

        A random Rademacher vector ``v`` (entries in {-1, +1}) is drawn for
        every parameter and the Hessian-vector product ``H v`` is obtained by
        differentiating ``<grads, v>`` with respect to ``params``.

        :param params: list of parameters the gradients belong to
        :param grads: list of gradients, each with a ``grad_fn`` (i.e. created
            with ``create_graph=True``)
        :return: a list of tensors, one per parameter, holding ``|H v|``
            (averaged over the trailing dimensions for tensors of rank > 2)
        :raises RuntimeError: if a gradient has no ``grad_fn``
        """

        # Check backward was called with create_graph set to True
        for i, grad in enumerate(grads):
            if grad.grad_fn is None:
                msg = (
                    "Gradient tensor {:} does not have grad_fn. When "
                    "calling loss.backward(), make sure the option "
                    "create_graph is set to True."
                )
                raise RuntimeError(msg.format(i))

        # Rademacher vectors: randint_like yields {0, 1}, mapped to {-1, +1}.
        v = [
            2
            * torch.randint_like(
                p, high=2, memory_format=torch.preserve_format
            )
            - 1
            for p in params
        ]

        # this is for distributed setting with single node and multi-gpus,
        # for multi nodes setting, we have not support it yet.
        hvs = torch.autograd.grad(
            grads, params, grad_outputs=v, only_inputs=True, retain_graph=True
        )

        hutchinson_trace = []
        for hv in hvs:
            if hv.dim() <= 2:  # for 0/1/2D tensor
                # Hessian diagonal block size is 1 here.
                # We use that torch.abs(hv * vi) = hv.abs()
                tmp_output = hv.abs()
            else:
                # Convolution-like kernel: average |hv| over every dimension
                # after the first two (dims 2 and 3 for a Conv2d kernel).
                # Previously only rank-4 tensors were handled and any other
                # rank > 2 left ``tmp_output`` unbound or stale.
                tmp_output = torch.mean(
                    hv.abs(), dim=list(range(2, hv.dim())), keepdim=True
                )
            hutchinson_trace.append(tmp_output)

        return hutchinson_trace

    def step(self, closure: OptLossClosure = None) -> OptFloat:
        """Perform a single optimization step.

        Arguments:
            closure: A closure that reevaluates the model and returns the loss.

        Returns:
            The value returned by ``closure`` or ``None`` if no closure was
            given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        params = []
        groups = []
        grads = []

        # Flatten groups into lists, so that
        #  hut_traces can be called with lists of parameters
        #  and grads
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    params.append(p)
                    groups.append(group)
                    grads.append(p.grad)

        # Nothing to update (and torch.autograd.grad rejects empty inputs).
        if not params:
            return loss

        # get the Hessian diagonal

        hut_traces = self.get_trace(params, grads)

        for p, group, grad, hut_trace in zip(
            params, groups, grads, hut_traces
        ):
            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = 0
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(p.data)
                # Exponential moving average of Hessian diagonal square values
                state["exp_hessian_diag_sq"] = torch.zeros_like(p.data)

            exp_avg, exp_hessian_diag_sq = (
                state["exp_avg"],
                state["exp_hessian_diag_sq"],
            )

            beta1, beta2 = group["betas"]

            state["step"] += 1

            # Decay the first and second moment running average coefficient.
            # detach_() cuts p.grad from the autograd graph in place.
            exp_avg.mul_(beta1).add_(grad.detach_(), alpha=1 - beta1)
            exp_hessian_diag_sq.mul_(beta2).addcmul_(
                hut_trace, hut_trace, value=1 - beta2
            )

            bias_correction1 = 1 - beta1 ** state["step"]
            bias_correction2 = 1 - beta2 ** state["step"]

            # make the square root, and the Hessian power
            k = group["hessian_power"]
            denom = (
                (exp_hessian_diag_sq.sqrt() ** k)
                / math.sqrt(bias_correction2) ** k
            ).add_(group["eps"])

            # make update (weight decay is added to the step as an L2 term)
            p.data = p.data - group["lr"] * (
                exp_avg / bias_correction1 / denom
                + group["weight_decay"] * p.data
            )

        return loss

In [None]:
class Adahessian(torch.optim.Optimizer):
    """
    Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second Order Optimizer for Machine Learning"

    The gradients must be created with `loss.backward(create_graph=True)` so that
    Hessian-vector products can be taken in `set_hessian`.

    Arguments:
        params (iterable) -- iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional) -- learning rate (default: 0.1)
        betas ((float, float), optional) -- coefficients used for computing running averages of gradient and the squared hessian trace (default: (0.9, 0.999))
        eps (float, optional) -- term added to the denominator to improve numerical stability (default: 1e-8)
        weight_decay (float, optional) -- decoupled weight decay coefficient, applied as `p *= 1 - lr * weight_decay` (default: 0.0)
        hessian_power (float, optional) -- exponent of the hessian trace, must lie in [0, 1] (default: 1.0)
        update_each (int, optional) -- compute the hessian trace approximation only after *this* number of steps (to save time) (default: 1)
        n_samples (int, optional) -- how many times to sample `z` for the approximation of the hessian trace (default: 1)
        average_conv_kernel (bool, optional) -- if True, replace the hessian estimate of every 4-D parameter by the mean of its absolute value over dims 2 and 3 (default: False)
    """

    def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0,
                 hessian_power=1.0, update_each=1, n_samples=1, average_conv_kernel=False):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= hessian_power <= 1.0:
            raise ValueError(f"Invalid Hessian power value: {hessian_power}")

        self.n_samples = n_samples
        self.update_each = update_each
        self.average_conv_kernel = average_conv_kernel

        # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
        self.generator = torch.Generator().manual_seed(2147483647)

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power)
        super().__init__(params, defaults)

        # The hessian estimate lives on the parameter itself (`p.hess`); the float 0.0
        # is a placeholder until the first tensor estimate is accumulated.
        for p in self.get_params():
            p.hess = 0.0
            self.state[p]["hessian step"] = 0

    def get_params(self):
        """
        Returns a generator over all parameters in all param_groups that have `requires_grad` set
        """

        return (p for group in self.param_groups for p in group['params'] if p.requires_grad)

    def zero_hessian(self):
        """
        Zeros out the accumulated hessian traces of the parameters that are due for a fresh estimate.
        """

        for p in self.get_params():
            if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0:
                p.hess.zero_()

    @torch.no_grad()
    def set_hessian(self):
        """
        Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
        """

        params = []
        for p in filter(lambda p: p.grad is not None, self.get_params()):
            if self.state[p]["hessian step"] % self.update_each == 0:  # compute the trace only each `update_each` step
                params.append(p)
            self.state[p]["hessian step"] += 1

        if len(params) == 0:
            return

        if self.generator.device != params[0].device:  # hackish way of casting the generator to the right device
            self.generator = torch.Generator(params[0].device).manual_seed(2147483647)

        grads = [p.grad for p in params]

        for i in range(self.n_samples):
            zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]  # Rademacher distribution {-1.0, 1.0}
            # keep the graph alive until the last sample has been processed
            h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1)
            for h_z, z, p in zip(h_zs, zs, params):
                p.hess += h_z * z / self.n_samples  # approximate the expected values of z*(H@z)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.
        Arguments:
            closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None)
        Returns:
            the value returned by `closure`, or None when no closure is given
        """

        loss = None
        if closure is not None:
            # `step` runs under no_grad; the closure has to build an autograd graph
            # (forward + backward), so gradient tracking is re-enabled for it.
            with torch.enable_grad():
                loss = closure()

        self.zero_hessian()
        self.set_hessian()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None or p.hess is None:
                    continue

                if self.average_conv_kernel and p.dim() == 4:
                    p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone()

                # Perform correct stepweight decay as in AdamW
                p.mul_(1 - group['lr'] * group['weight_decay'])

                state = self.state[p]

                # State initialization ("hessian step" is already present, hence == 1)
                if len(state) == 1:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)  # Exponential moving average of gradient values
                    state['exp_hessian_diag_sq'] = torch.zeros_like(p.data)  # Exponential moving average of Hessian diagonal square values

                exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']
                beta1, beta2 = group['betas']
                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
                exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2)

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                k = group['hessian_power']
                # the division creates a new tensor, so the in-place ops do not touch the state
                denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps'])

                # make update
                step_size = group['lr'] / bias_correction1
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss

## OASIS

(modified from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py)

In [None]:
class OASIS(torch.optim.Optimizer):
    """
    Implements the OASIS algorithm from "Doubly Adaptive Scaled Algorithm for Machine Learning Using Second-Order Information"

    The gradients must be created with `loss.backward(create_graph=True)` so that
    Hessian-vector products can be taken in `set_hessian`.

    Arguments:
        params (iterable) -- iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional) -- learning rate (default: 0.1)
        betas ((float, float), optional) -- coefficients used for computing running averages of gradient and the hessian diagonal estimate (default: (0.9, 0.999))
        eps (float, optional) -- lower bound applied to the absolute value of the averaged hessian diagonal before dividing by it (default: 1e-2)
        weight_decay (float, optional) -- decoupled weight decay coefficient, applied as `p *= 1 - lr * weight_decay` (default: 0.0)
        hessian_power (float, optional) -- validated and stored in the param group, but not used by the update rule of this implementation (default: 1.0)
        update_each (int, optional) -- compute the hessian trace approximation only after *this* number of steps (to save time) (default: 1)
        n_samples (int, optional) -- how many times to sample `z` for the approximation of the hessian trace (default: 1)
        average_conv_kernel (bool, optional) -- if True, replace the hessian estimate of every 4-D parameter by the mean of its absolute value over dims 2 and 3 (default: False)
    """

    def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-2, weight_decay=0.0,
                 hessian_power=1.0, update_each=1, n_samples=1, average_conv_kernel=False):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= hessian_power <= 1.0:
            raise ValueError(f"Invalid Hessian power value: {hessian_power}")

        self.n_samples = n_samples
        self.update_each = update_each
        self.average_conv_kernel = average_conv_kernel

        # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
        self.generator = torch.Generator().manual_seed(2147483647)

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power)
        super().__init__(params, defaults)

        # The hessian estimate lives on the parameter itself (`p.hess`); the float 0.0
        # is a placeholder until the first tensor estimate is accumulated.
        for p in self.get_params():
            p.hess = 0.0
            self.state[p]["hessian step"] = 0

    def get_params(self):
        """
        Returns a generator over all parameters in all param_groups that have `requires_grad` set
        """

        return (p for group in self.param_groups for p in group['params'] if p.requires_grad)

    def zero_hessian(self):
        """
        Zeros out the accumulated hessian traces of the parameters that are due for a fresh estimate.
        """

        for p in self.get_params():
            if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0:
                p.hess.zero_()

    @torch.no_grad()
    def set_hessian(self):
        """
        Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
        """

        params = []
        for p in filter(lambda p: p.grad is not None, self.get_params()):
            if self.state[p]["hessian step"] % self.update_each == 0:  # compute the trace only each `update_each` step
                params.append(p)
            self.state[p]["hessian step"] += 1

        if len(params) == 0:
            return

        if self.generator.device != params[0].device:  # hackish way of casting the generator to the right device
            self.generator = torch.Generator(params[0].device).manual_seed(2147483647)

        grads = [p.grad for p in params]

        for i in range(self.n_samples):
            zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]  # Rademacher distribution {-1.0, 1.0}
            # keep the graph alive until the last sample has been processed
            h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1)
            for h_z, z, p in zip(h_zs, zs, params):
                p.hess += h_z * z / self.n_samples  # approximate the expected values of z*(H@z)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.
        Arguments:
            closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None)
        Returns:
            the value returned by `closure`, or None when no closure is given
        """

        loss = None
        if closure is not None:
            # `step` runs under no_grad; the closure has to build an autograd graph
            # (forward + backward), so gradient tracking is re-enabled for it.
            with torch.enable_grad():
                loss = closure()

        self.zero_hessian()
        self.set_hessian()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None or p.hess is None:
                    continue

                if self.average_conv_kernel and p.dim() == 4:
                    p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone()

                # Perform correct stepweight decay as in AdamW
                p.mul_(1 - group['lr'] * group['weight_decay'])

                state = self.state[p]

                # State initialization ("hessian step" is already present, hence == 1)
                if len(state) == 1:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)  # Exponential moving average of gradient values
                    state['exp_hessian_diag_sq'] = torch.zeros_like(p.data)  # Exponential moving average of the (signed) Hessian diagonal estimate

                exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']
                beta1, beta2 = group['betas']
                state['step'] += 1

                # Decay the first moment and the hessian-diagonal running averages.
                # Unlike AdaHessian, the hessian estimate enters linearly (not squared).
                exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
                exp_hessian_diag_sq.mul_(beta2).add_(p.hess, alpha=1 - beta2)

                bias_correction1 = 1 - beta1 ** state['step']

                # Truncated preconditioner max(|D|, eps). `abs()` allocates a new tensor so
                # that the stored running average keeps its sign and magnitude; the previous
                # in-place `abs_().clamp_()` overwrote the optimizer state on every step.
                denom = exp_hessian_diag_sq.abs().clamp_(min=group['eps'])

                # make update
                step_size = group['lr'] / bias_correction1
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss

# Experiment setup (from the paper)

## Helpers

In [None]:
def synth_func(k: int, x_k: nn.Parameter):
    """Synthetic linear objective whose slope depends on the step index.

    Args:
        k: Step index.
        x_k: Current iterate.

    Returns:
        ``(1010 * x_k).sum()`` when ``k % 101 == 1``, otherwise
        ``(-10 * x_k).sum()``.
    """
    slope = 1010 if k % 101 == 1 else -10
    return (slope * x_k).sum()

In [None]:
def get_dataloaders(dataset_name: str,
                    batch_size: int,
                    **kwargs):
    """Builds the train and test data loaders for a torchvision dataset.

    The data is stored under ``./data`` and downloaded if it is missing.
    Images are converted to tensors and normalized with the per-channel
    constants hard-coded below.

    Args:
        dataset_name: Either ``'mnist'`` or ``'cifar10'``.
        batch_size: Batch size used by both loaders.
        **kwargs: Extra keyword arguments forwarded to both ``DataLoader``s.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Only the training loader is
        shuffled.

    Raises:
        NotImplementedError: If ``dataset_name`` is not supported.
    """
    if dataset_name == 'mnist':
        dataset = MNIST
        # One-element tuples: MNIST has a single channel.
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))])
    elif dataset_name == 'cifar10':
        dataset = CIFAR10
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.4914, 0.4822, 0.4465),
                                  (0.2470, 0.2435, 0.2616))])
    else:
        raise NotImplementedError("The dataset is not supported!")

    train_loader = DataLoader(
        dataset(
            './data', train=True, download=True, transform=transform
        ),
        batch_size=batch_size, shuffle=True, **kwargs
    )

    # Evaluation data is not shuffled, which keeps the batch composition (and
    # therefore the mean of per-batch losses) identical from epoch to epoch.
    test_loader = DataLoader(
        dataset(
            './data', train=False, download=True, transform=transform
        ),
        batch_size=batch_size, shuffle=False, **kwargs
    )
    return train_loader, test_loader

In [None]:
class Experimentation:
    """Trains a classifier for a fixed number of epochs and logs losses to wandb.

    Every training step logs ``train_loss``; after each epoch the mean
    per-batch cross-entropy on the test loader is logged as ``test_loss``.

    Args:
        config: Run configuration; stored on the instance, not otherwise used.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        test_loader: Iterable of ``(inputs, targets)`` evaluation batches.
        model: The network to train.
        optimizer: Optimizer updating ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to.
    """

    def __init__(self,
                 config,
                 train_loader,
                 test_loader,
                 model: nn.Module,
                 optimizer = None,
                 num_epochs = 10,
                 device = 'cpu'):
        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.model = model
        self.criterion = torch.nn.CrossEntropyLoss()
        self.num_epochs = num_epochs
        self.device = device
        self.optimizer = optimizer

    def _train_one_epoch(self):
        """Runs one optimization pass over the training loader."""
        self.model.train()
        for batch in tqdm(self.train_loader):
            self.optimizer.zero_grad()
            x, y_true = batch
            x = x.to(self.device)
            y_true = y_true.to(self.device)

            y_pred = self.model(x)
            loss = self.criterion(y_pred, y_true)
            # create_graph=True keeps the graph of the gradients so that
            # second-order optimizers can take Hessian-vector products.
            loss.backward(create_graph=True)
            self.optimizer.step()
            wandb.log({"train_loss": loss.item()})
            # Gradients built with create_graph=True reference the graph, which
            # in turn references the parameters; drop them right after the step
            # to break that cycle instead of keeping it until the next batch.
            self.optimizer.zero_grad(set_to_none=True)

    @torch.no_grad()
    def _validate(self, loader):
        """Returns the mean per-batch loss of the model over ``loader``."""
        total_loss = 0
        self.model.eval()
        for batch in tqdm(loader):
            x, y_true = batch
            x = x.to(self.device)
            y_true = y_true.to(self.device)

            y_pred = self.model(x)
            loss = self.criterion(y_pred, y_true)
            total_loss += loss.item()
        total_loss /= len(loader)
        return total_loss

    def experiment(self):
        """Alternates training epochs and test evaluation, logging to wandb."""
        for epoch in range(self.num_epochs):
            print(f"Epoch: {epoch}")
            self._train_one_epoch()
            test_loss = self._validate(self.test_loader)
            wandb.log({"test_loss": test_loss})

## Models

In [None]:
class LogisticRegression(nn.Module):
    """Linear classifier: flattens the input and applies one linear layer.

    The forward pass returns the raw outputs of the linear layer (no softmax
    is applied here).
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(in_features=input_dim,
                                out_features=num_classes)

    def forward(self, x):
        """Maps a batch ``x`` to a ``(batch, num_classes)`` output."""
        return self.linear(self.flatten(x))

In [None]:
class FullyConnectedNetwork(nn.Module):
    """Two-layer perceptron: flatten -> linear -> ReLU -> linear.

    Args:
        input_dim: Number of features after flattening one sample.
        num_classes: Size of the output layer.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self,
                 input_dim,
                 num_classes,
                 hidden_dim=100):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.linear2 = nn.Linear(in_features=hidden_dim, out_features=num_classes)

    def forward(self, x):
        """Maps a batch ``x`` to a ``(batch, num_classes)`` output."""
        hidden = F.relu(self.linear1(self.flatten(x)))
        return self.linear2(hidden)

In [None]:
class CifarNet(nn.Module):
    """Small convolutional classifier for 3-channel square images.

    Two stages of (3x3 convolution -> 2x2 max-pool -> ReLU -> batch norm) are
    followed by a three-layer MLP with dropout. None of the convolution or
    linear layers has a bias term.

    Args:
        input_dim: Side length of the input image; each of the two pooling
            layers halves it, so the MLP sees ``(input_dim // 4) ** 2``
            positions per channel.
        num_classes: Size of the output layer.
        hidden_channels: Number of channels produced by both convolutions.
    """

    def __init__(self,
                 input_dim,
                 num_classes,
                 hidden_channels=64):
        super().__init__()

        # Settings shared by both convolutions.
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)

        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=hidden_channels,
                               **conv_kwargs)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.bn1 = nn.BatchNorm2d(hidden_channels)

        self.conv2 = nn.Conv2d(in_channels=hidden_channels,
                               out_channels=hidden_channels,
                               **conv_kwargs)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.bn2 = nn.BatchNorm2d(hidden_channels)

        self.flatten = nn.Flatten()

        flat_features = (input_dim // 4)**2 * hidden_channels
        self.mlp = nn.Sequential(
            nn.Linear(flat_features, 384, bias=False),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(384, 192, bias=False),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(192, num_classes, bias=False)
        )

    def forward(self, x):
        """Maps a batch of images to a ``(batch, num_classes)`` output."""
        features = self.bn1(F.relu(self.pool1(self.conv1(x))))
        features = self.bn2(F.relu(self.pool2(self.conv2(features))))
        return self.mlp(self.flatten(features))

## Grid search

In [None]:
def get_sweep_config(name: str) -> Dict:
    """Builds the wandb sweep configuration for hyperparameter search.

    The sweep uses the ``random`` search method, minimizes ``test_loss`` and
    samples ``lr`` and ``beta2`` from the fixed value lists below.

    Args:
        name: The name of a sweep/experiment.

    Returns:
        A sweep config.
    """
    return {
        'method': 'random',
        'name': name,
        # Objective the sweep optimizes.
        'metric': {
            'name': 'test_loss',
            'goal': 'minimize',
        },
        # Search space.
        'parameters': {
            'lr': {'values': [1e-2, 1e-3, 3e-4, 1e-4]},
            'beta2': {'values': [0.99, 0.999]},
        },
    }

In [None]:
def get_optimizer(optimizer_name: str,
                  model,
                  config: Dict) -> Optimizer:
    """Creates the optimizer selected by name for the given model.

    Args:
        optimizer_name: One of "Adam", "AMSGrad" or "OASIS".
        model: The model whose parameters are optimized.
        config: Hyperparameter container; its ``lr`` and ``beta2`` attributes
            are read. The first beta is fixed to 0.9.

    Returns:
        An instance of optimizer

    Raises:
        NotImplementedError: If ``optimizer_name`` is not supported.
    """
    if optimizer_name not in ("Adam", "AMSGrad", "OASIS"):
        raise NotImplementedError("This optimizer is not supported!")

    weights = model.parameters()
    hyperparams = dict(lr=config.lr, betas=(0.9, config.beta2))

    if optimizer_name == "Adam":
        return Adam(weights, **hyperparams)
    if optimizer_name == "AMSGrad":
        # AMSGrad is Adam with the amsgrad flag switched on.
        return Adam(weights, amsgrad=True, **hyperparams)
    return OASIS(weights, **hyperparams)

In [None]:
def get_model(model_name: str,
              input_dim) -> torch.nn.Module:
    """Creates a 10-class model by name and moves it to the global ``device``.

    Args:
        model_name: "FullyConnectedNetwork" or "CifarNet".
        input_dim: The input dimension. Only used for
            "FullyConnectedNetwork"; "CifarNet" is always built for an image
            side length of 32.

    Returns:
        An instance of the model.

    Raises:
        NotImplementedError: If ``model_name`` is not supported.
    """
    if model_name == "FullyConnectedNetwork":
        return FullyConnectedNetwork(input_dim, 10).to(device)
    if model_name == "CifarNet":
        return CifarNet(32, 10).to(device)
    raise NotImplementedError("This model is not supported!")

In [None]:
def run_grid_search(dataset_name: str,
                    model_name: str,
                    optimizer_name: str,
                    config=None):
    """Executes one sweep trial: builds model, optimizer and data, then trains.

    The hyperparameters are read from ``wandb.config`` after the run has been
    initialized.

    Args:
        dataset_name: "mnist" selects a flattened input size of 1 * 28**2;
            any other value selects 3 * 32**2.
        model_name: Name understood by ``get_model``.
        optimizer_name: Name understood by ``get_optimizer``.
        config: Initial config forwarded to ``wandb.init``.
    """
    with wandb.init(config=config):
        run_config = wandb.config
        input_dim = 1 * 28**2 if dataset_name == "mnist" else 3 * 32**2

        model = get_model(model_name, input_dim)
        optimizer = get_optimizer(optimizer_name, model, run_config)
        loaders = get_dataloaders(dataset_name=dataset_name,
                                  batch_size=BATCH_SIZE)
        train_loader, test_loader = loaders

        Experimentation(run_config,
                        train_loader,
                        test_loader,
                        model,
                        optimizer=optimizer,
                        device=device).experiment()

# MNIST

## Fully Connected Network

### Run Adam

In [None]:
wandb.login()

In [None]:
sweep_config = get_sweep_config(name="mnist-fcn-adam")

In [None]:
sweep_id = wandb.sweep(sweep_config, entity="beyondadam", project="adaptive_methods")

In [None]:
dataset_name = "mnist"
model_name = "FullyConnectedNetwork"
optimizer_name = "Adam"

In [None]:
# Execute up to 5 trials of the sweep; each trial calls run_grid_search with
# the dataset, model and optimizer names bound here via partial.
wandb.agent(sweep_id,
            partial(run_grid_search,
                    dataset_name=dataset_name,
                    model_name=model_name,
                    optimizer_name=optimizer_name),
            count=5)

### Run AMSGrad

In [None]:
wandb.login()

In [None]:
sweep_config = get_sweep_config(name="mnist-fcn-amsgrad")

In [None]:
sweep_id = wandb.sweep(sweep_config, entity="beyondadam", project="adaptive_methods")

In [None]:
dataset_name = "mnist"
model_name = "FullyConnectedNetwork"
optimizer_name = "AMSGrad"

In [None]:
# Execute up to 5 trials of the sweep; each trial calls run_grid_search with
# the dataset, model and optimizer names bound here via partial.
wandb.agent(sweep_id,
            partial(run_grid_search,
                    dataset_name=dataset_name,
                    model_name=model_name,
                    optimizer_name=optimizer_name),
            count=5)

### Run OASIS

In [None]:
wandb.login()

In [None]:
sweep_config = get_sweep_config(name="mnist-fcn-oasis")

In [None]:
sweep_id = wandb.sweep(sweep_config, entity="beyondadam", project="adaptive_methods")

In [None]:
dataset_name = "mnist"
model_name = "FullyConnectedNetwork"
optimizer_name = "OASIS"

In [None]:
# Execute up to 5 trials of the sweep; each trial calls run_grid_search with
# the dataset, model and optimizer names bound here via partial.
wandb.agent(sweep_id,
            partial(run_grid_search,
                    dataset_name=dataset_name,
                    model_name=model_name,
                    optimizer_name=optimizer_name),
            count=5)

# CIFAR10

## CifarNet

### Run Adam

In [None]:
wandb.login()

In [None]:
sweep_config = get_sweep_config(name="cifar-cifarnet-adam")

In [None]:
sweep_id = wandb.sweep(sweep_config, entity="beyondadam", project="adaptive_methods")

In [None]:
dataset_name = "cifar10"
model_name = "CifarNet"
optimizer_name = "Adam"

In [None]:
# Execute up to 5 trials of the sweep; each trial calls run_grid_search with
# the dataset, model and optimizer names bound here via partial.
wandb.agent(sweep_id,
            partial(run_grid_search,
                    dataset_name=dataset_name,
                    model_name=model_name,
                    optimizer_name=optimizer_name),
            count=5)

### Run AMSGrad

In [None]:
wandb.login()

In [None]:
sweep_config = get_sweep_config(name="cifar-cifarnet-amsgrad")

In [None]:
sweep_id = wandb.sweep(sweep_config, entity="beyondadam", project="adaptive_methods")

In [None]:
dataset_name = "cifar10"
model_name = "CifarNet"
optimizer_name = "AMSGrad"

In [None]:
# Execute up to 5 trials of the sweep; each trial calls run_grid_search with
# the dataset, model and optimizer names bound here via partial.
wandb.agent(sweep_id,
            partial(run_grid_search,
                    dataset_name=dataset_name,
                    model_name=model_name,
                    optimizer_name=optimizer_name),
            count=5)

### Run OASIS

In [None]:
wandb.login()

In [None]:
sweep_config = get_sweep_config(name="cifar-cifarnet-oasis")

In [None]:
sweep_id = wandb.sweep(sweep_config, entity="beyondadam", project="adaptive_methods")

In [None]:
dataset_name = "cifar10"
model_name = "CifarNet"
optimizer_name = "OASIS"

In [None]:
# Execute up to 5 trials of the sweep; each trial calls run_grid_search with
# the dataset, model and optimizer names bound here via partial.
wandb.agent(sweep_id,
            partial(run_grid_search,
                    dataset_name=dataset_name,
                    model_name=model_name,
                    optimizer_name=optimizer_name),
            count=5)