In [1]:
import torch as t
from torch import nn, optim
import numpy as np

from typing import Callable, Iterable, Tuple

import utils

In [2]:
def rosenbrocks_banana(x: t.Tensor, y: t.Tensor, a=1, b=100) -> t.Tensor:
    '''Rosenbrock's "banana" function, shifted up by one.

    Evaluates (a - x)^2 + b * (y - x^2)^2 + 1 elementwise.

    x, y: coordinates at which to evaluate the function.
    a, b: shape constants of the function.

    Return: the function value. For b >= 0 both squared terms are non-negative, so
    the result is never below 1 (presumably so that log-scaled plots stay
    well-defined - verify against utils.plot_fn).
    '''
    distance_from_a = a - x
    distance_from_parabola = y - x**2
    return distance_from_a ** 2 + b * distance_from_parabola ** 2 + 1

# Plotting window for the surface: x in [-2, 2], y in [-1, 3].
x_range = [-2, 2]
y_range = [-1, 3]
# utils.plot_fn is defined outside this notebook; judging by the keyword names it
# draws the function on a log scale and marks the minimum - TODO confirm in utils.
fig = utils.plot_fn(rosenbrocks_banana, x_range, y_range, log_scale=True, show_min=True)

In [3]:
def opt_fn_with_sgd(fn: Callable, xy: t.Tensor, lr=0.001, momentum=0.98, n_iters: int = 100):
    '''
    Optimize a given function with torch.optim.SGD, starting from the specified point.

    fn: callable taking (x, y) and returning a scalar tensor to minimise.
    xy: shape (2,). The (x, y) starting point. Must have requires_grad=True.
        It is updated in place by the optimizer.
    lr: learning rate passed to torch.optim.SGD.
    momentum: momentum passed to torch.optim.SGD.
    n_iters: number of steps.

    Return: (n_iters, 2). The (x,y) BEFORE each step. So out[0] is the starting point.
    '''
    assert xy.requires_grad
    xys = t.zeros((n_iters, 2))
    optimizer = optim.SGD([xy], lr=lr, momentum=momentum)

    # backward() accumulates into xy.grad, so a gradient left over from earlier
    # work on this tensor would otherwise be added to the first step.
    optimizer.zero_grad()

    for i in range(n_iters):
        # Record the position before it is moved; detach() keeps the history
        # buffer out of the autograd graph.
        xys[i] = xy.detach()
        out = fn(xy[0], xy[1])
        out.backward()
        optimizer.step()
        optimizer.zero_grad()
    return xys

In [4]:
# Starting point; requires_grad=True is needed because opt_fn_with_sgd asserts it.
xy = t.tensor([-1.5, 2.5], requires_grad=True)
# Runs with the default lr, momentum and n_iters. Note that xy itself is moved
# in place by the optimizer, so it no longer holds the starting point afterwards.
xys = opt_fn_with_sgd(rosenbrocks_banana, xy)

In [5]:
# Fresh starting point: the previous cell's xy was moved in place by the optimizer.
xy = t.tensor([-1.5, 2.5], requires_grad=True)
x_range = [-2, 2]
y_range = [-1, 3]

# lr and momentum repeat opt_fn_with_sgd's defaults. utils.plot_optimization_sgd is
# defined outside this notebook; presumably it runs opt_fn_with_sgd and overlays the
# resulting path on the surface - TODO confirm in utils.
fig = utils.plot_optimization_sgd(opt_fn_with_sgd, rosenbrocks_banana, xy, x_range, y_range, lr=0.001, momentum=0.98, show_min=True)

# fig.show()

In [6]:
class SGD:
    # Parameters being optimised, stored as a list so they can be indexed in
    # lockstep with the per-parameter momentum buffers.
    params: list

    def __init__(self, params: Iterable[t.nn.parameter.Parameter], lr: float, momentum: float, weight_decay: float):
        '''Implements SGD with momentum.

        Like the PyTorch version, but assume nesterov=False, maximize=False, and dampening=0
            https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD

        params: iterable of tensors to optimise; consumed once and stored as a list.
        lr: learning rate.
        momentum: weight given to the previous update direction; 0 disables momentum.
        weight_decay: coefficient of the L2 term added to the gradient; 0 disables it.
        '''
        self.params = list(params)
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay

        # One buffer per parameter: the update direction applied on the previous step.
        self.prev_grads = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        '''Clear the gradient of every parameter by setting it to None.'''
        for param in self.params:
            param.grad = None

    @t.inference_mode()
    def step(self) -> None:
        '''Apply one SGD update to every parameter that currently has a gradient.

        Parameters whose .grad is None (e.g. they took no part in the last backward
        pass) are skipped and their momentum buffer is left untouched, matching
        torch.optim.SGD. Previously such a parameter raised a TypeError.
        '''
        for i, p in enumerate(self.params):
            if p.grad is None:
                continue

            new_grad = p.grad

            if self.weight_decay != 0:
                # L2 regularisation expressed as an extra gradient term.
                new_grad = new_grad + self.weight_decay * p

            if self.momentum != 0:
                # The buffer starts at zero, so the first step reduces to the plain gradient.
                new_grad = self.momentum * self.prev_grads[i] + new_grad

            p.sub_(self.lr * new_grad)
            self.prev_grads[i] = new_grad

    def __repr__(self) -> str:
        '''Return a summary such as "SGD(lr=0.1, momentum=0.5, weight_decay=0.0)".'''
        params_to_print = ["lr", "momentum", "weight_decay"]
        params_string = ', '.join(f'{p}={getattr(self, p)}' for p in params_to_print)
        return f'SGD({params_string})'

# Run the checks from utils against the SGD class above. utils.test_sgd is defined
# outside this notebook; presumably it compares against torch.optim.SGD - TODO confirm.
utils.test_sgd(SGD)


Testing configuration:  {'lr': 0.1, 'momentum': 0.0, 'weight_decay': 0.0}

Testing configuration:  {'lr': 0.1, 'momentum': 0.7, 'weight_decay': 0.0}

Testing configuration:  {'lr': 0.1, 'momentum': 0.5, 'weight_decay': 0.0}

Testing configuration:  {'lr': 0.1, 'momentum': 0.5, 'weight_decay': 0.05}

Testing configuration:  {'lr': 0.2, 'momentum': 0.8, 'weight_decay': 0.05}


In [7]:
class RMSprop:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float,
        alpha: float,
        weight_decay: float,
        momentum: float,
        eps: float = 1e-8,
    ):
        '''Implements RMSprop.

        Like the PyTorch version, but assumes centered=False
            https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html#torch.optim.RMSprop

        params: iterable of tensors to optimise; consumed once and stored as a list.
        lr: learning rate.
        alpha: smoothing constant of the running average of squared gradients.
        weight_decay: coefficient of the L2 term added to the gradient; 0 disables it.
        momentum: weight given to the previous update direction; values <= 0 disable it.
        eps: added to the denominator to avoid division by zero.
        '''
        self.params = list(params)
        self.lr = lr
        self.alpha = alpha
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.eps = eps

        # Per-parameter running average of squared gradients.
        self.vs = [t.zeros_like(p) for p in self.params]
        # Per-parameter momentum buffer (only updated when momentum > 0).
        self.bs = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        '''Clear the gradient of every parameter by setting it to None.'''
        for param in self.params:
            param.grad = None

    @t.inference_mode()
    def step(self) -> None:
        '''Apply one RMSprop update to every parameter that currently has a gradient.

        Parameters whose .grad is None are skipped and their state is left untouched,
        matching torch.optim.RMSprop. Previously such a parameter raised a TypeError.
        '''
        lmbda = self.weight_decay
        mu = self.momentum

        for i, p in enumerate(self.params):
            if p.grad is None:
                continue

            new_grad = p.grad

            if lmbda != 0:
                # L2 regularisation expressed as an extra gradient term.
                new_grad = new_grad + lmbda * p

            # Exponential moving average of the squared gradient.
            v = self.alpha * self.vs[i] + (1 - self.alpha) * new_grad ** 2
            self.vs[i] = v

            if mu > 0:
                # Momentum is applied to the already-normalised gradient.
                b = mu * self.bs[i] + new_grad / (t.sqrt(v) + self.eps)
                p.sub_(self.lr * b)
                self.bs[i] = b
            else:
                p.sub_(self.lr * new_grad / (t.sqrt(v) + self.eps))

    def __repr__(self) -> str:
        '''Return a summary of the hyperparameters, e.g. "RMSprop(lr=0.1, alpha=0.9, ...)".'''
        params_to_print = ["lr", "alpha", "weight_decay", "momentum", "eps"]
        params_string = ', '.join(f'{p}={getattr(self, p)}' for p in params_to_print)
        return f'RMSprop({params_string})'


# Run the checks from utils against the RMSprop class above. utils.test_rmsprop is defined
# outside this notebook; presumably it compares against torch.optim.RMSprop - TODO confirm.
utils.test_rmsprop(RMSprop)


Testing configuration:  {'lr': 0.1, 'alpha': 0.9, 'eps': 0.001, 'weight_decay': 0.0, 'momentum': 0.0}

Testing configuration:  {'lr': 0.1, 'alpha': 0.95, 'eps': 0.0001, 'weight_decay': 0.05, 'momentum': 0.0}

Testing configuration:  {'lr': 0.1, 'alpha': 0.95, 'eps': 0.0001, 'weight_decay': 0.05, 'momentum': 0.5}

Testing configuration:  {'lr': 0.1, 'alpha': 0.95, 'eps': 0.0001, 'weight_decay': 0.05, 'momentum': 0.0}


In [8]:
class Adam:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float,
        betas: Tuple[float, float],
        weight_decay: float,
        eps: float = 1e-8,
    ):
        '''Implements Adam.

        Like the PyTorch version, but assumes amsgrad=False and maximize=False
            https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam

        params: iterable of tensors to optimise; consumed once and stored as a list.
        lr: learning rate.
        betas: (beta1, beta2), decay rates of the first and second moment estimates.
        weight_decay: coefficient of the L2 term added to the gradient; 0 disables it.
        eps: added to the denominator to avoid division by zero.
        '''
        self.params = list(params)
        self.lr = lr
        self.betas = betas
        self.weight_decay = weight_decay
        self.eps = eps

        # Number of times step() has been called.
        self.t = 0
        # Number of updates actually applied to each parameter; used for bias
        # correction so that a parameter which skipped steps is corrected properly.
        self.steps = [0 for _ in self.params]
        # Per-parameter first and second moment estimates.
        self.ms = [t.zeros_like(p) for p in self.params]
        self.vs = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        '''Clear the gradient of every parameter by setting it to None.'''
        for param in self.params:
            param.grad = None

    @t.inference_mode()
    def step(self) -> None:
        '''Apply one Adam update to every parameter that currently has a gradient.

        Parameters whose .grad is None are skipped and their state is left untouched,
        matching torch.optim.Adam. Previously such a parameter raised a TypeError.
        When every parameter has a gradient on every call, the per-parameter step
        count equals self.t and the update is the same as before.
        '''
        self.t += 1
        b1, b2 = self.betas[0], self.betas[1]
        lmbda = self.weight_decay
        for i, p in enumerate(self.params):
            if p.grad is None:
                continue

            g = p.grad
            self.steps[i] += 1
            n_updates = self.steps[i]

            if lmbda != 0:
                # L2 regularisation expressed as an extra gradient term.
                g = g + lmbda * p

            # Exponential moving averages of the gradient and its square.
            m = b1 * self.ms[i] + (1 - b1) * g
            v = b2 * self.vs[i] + (1 - b2) * g ** 2
            self.ms[i] = m
            self.vs[i] = v

            # Bias correction: both averages start at zero and are biased towards it early on.
            mhat = m / (1 - b1 ** n_updates)
            vhat = v / (1 - b2 ** n_updates)

            p.sub_((self.lr * mhat) / (t.sqrt(vhat) + self.eps))

    def __repr__(self) -> str:
        '''Return a summary of the hyperparameters, e.g. "Adam(lr=0.1, betas=(0.9, 0.999), ...)".'''
        params_to_print = ["lr", "betas", "eps", "weight_decay"]
        params_string = ', '.join(f'{p}={getattr(self, p)}' for p in params_to_print)
        return f'Adam({params_string})'


# Run the checks from utils against the Adam class above. utils.test_adam is defined
# outside this notebook; presumably it compares against torch.optim.Adam - TODO confirm.
utils.test_adam(Adam)


Testing configuration:  {'lr': 0.1, 'betas': (0.8, 0.95), 'eps': 0.001, 'weight_decay': 0.0}

Testing configuration:  {'lr': 0.1, 'betas': (0.8, 0.9), 'eps': 0.001, 'weight_decay': 0.05}

Testing configuration:  {'lr': 0.2, 'betas': (0.9, 0.95), 'eps': 0.01, 'weight_decay': 0.08}


In [9]:
def opt_fn(fn: Callable, xy: t.Tensor, optimizer_class, optimizer_kwargs, n_iters: int = 100):
    '''Optimize a given function starting from the specified point.

    fn: callable taking (x, y) and returning a scalar tensor to minimise.
    xy: shape (2,). The (x, y) starting point. Must have requires_grad=True.
        It is updated in place by the optimizer.
    optimizer_class: one of the optimizers you've defined, either SGD, RMSprop, or Adam
    optimizer_kwargs: keyword arguments passed to your optimiser (e.g. lr and weight_decay)
    n_iters: number of steps.

    Return: (n_iters, 2). The (x, y) BEFORE each step, so out[0] is the starting point.
    '''
    assert xy.requires_grad
    xys = t.zeros((n_iters, 2))
    optimizer = optimizer_class([xy], **optimizer_kwargs)

    # backward() accumulates into xy.grad, so a gradient left over from earlier
    # work on this tensor would otherwise be added to the first step.
    optimizer.zero_grad()

    for i in range(n_iters):
        # Record the position before it is moved; detach() keeps the history
        # buffer out of the autograd graph.
        xys[i] = xy.detach()
        out = fn(xy[0], xy[1])
        out.backward()
        optimizer.step()
        optimizer.zero_grad()
    return xys

In [10]:
# Common starting point and plotting window for the comparison.
xy = t.tensor([-1.5, 2.5], requires_grad=True)
x_range = [-2, 2]
y_range = [-1, 3]
# Each entry pairs an optimizer class with the keyword arguments used to construct it.
optimizers = [
    (SGD, {'lr': 1e-3, 'weight_decay': 0.0, 'momentum': 0.98}),
    (RMSprop, {'lr': 1e-1, 'alpha': 0.2, 'eps': 0.001, 'weight_decay': 0.0, 'momentum': 0.98}),
    (Adam, {'lr': 2e-1, 'betas': (0.8, 0.8), 'eps': 0.001, 'weight_decay': 0.0}),
]

# NOTE(review): opt_fn moves the tensor it is given in place, so the three runs only
# share a starting point if utils.plot_optimization hands each one its own copy of
# xy - confirm in utils.
fig = utils.plot_optimization(opt_fn, rosenbrocks_banana, xy, optimizers, x_range, y_range, show_min=True)

# fig.show()