In [1]:
import torch
from torch import Tensor
from jaxtyping import Float, Int

from cs336_basics.optimizer import ToySGD, AdamW

In [9]:
# Smoke test for AdamW: minimise mean(w^2) over a small random matrix and
# print the loss measured *before* each of the 10 updates.
torch.manual_seed(42)  # fixed seed so the printed losses are reproducible
weights = torch.nn.Parameter(torch.randn(2, 3) * 5)
opt = AdamW(params=[weights], lr=5e-1)

for t in range(10):
    opt.zero_grad()
    loss = weights.pow(2).mean()
    print(loss.item())
    loss.backward()
    opt.step()

6.389593601226807
4.725747585296631
3.5785844326019287
2.8173627853393555
2.2594964504241943
1.780196189880371
1.3381675481796265
0.9416506290435791
0.6101463437080383
0.3570791184902191


In [48]:
from math import sqrt
from collections.abc import Iterable
from typing import Any, Callable, Optional
import torch
from torch.optim import Optimizer
from torch import Tensor
from jaxtyping import Float, Bool, Int


class AdamW(Optimizer):
    '''
    Adam with decoupled weight decay.

    For every parameter with a gradient, one call to `step()` performs
        m   <- beta1 * m + (1 - beta1) * g
        v   <- beta2 * v + (1 - beta2) * g^2
        a_t <- lr * sqrt(1 - beta2^t) / (1 - beta1^t)
        p   <- p - a_t * m / (sqrt(v) + eps)
        p   <- p - lr * lmbda * p
    where t starts at 1. Note that the weight decay is applied *after* the
    Adam update, to the already-updated parameter.

    Per-parameter state lives in `self.state[p]` under the keys
    't' (index of the next step, an int), 'm' (first moment) and
    'v' (second moment).
    '''

    def __init__(self, params: Iterable, lr: float = 1e-3, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8, lmbda: float = 1e-2) -> None:
        '''
        params: iterable of parameters (or of param-group dicts) to optimize
        lr: learning rate
        beta1: decay rate of the first-moment (mean of gradients) estimate
        beta2: decay rate of the second-moment (mean of squared gradients) estimate
        eps: constant added to sqrt(v) in the denominator
        lmbda: weight decay rate

        Raises ValueError if a hyperparameter is outside its valid range.
        '''
        if lr < 0:
            raise ValueError(f'Invalid learning rate: {lr}')
        # beta == 1 would make the bias-correction term 1 - beta**t equal to 0
        # and only fail later, inside step(), with a ZeroDivisionError.
        if not 0 <= beta1 < 1:
            raise ValueError(f'Invalid beta1: {beta1}')
        if not 0 <= beta2 < 1:
            raise ValueError(f'Invalid beta2: {beta2}')
        if eps < 0:
            raise ValueError(f'Invalid eps: {eps}')
        if lmbda < 0:
            raise ValueError(f'Invalid weight decay rate: {lmbda}')

        defaults = {
            'lr': lr,
            'beta1': beta1,
            'beta2': beta2,
            'eps': eps,
            'lmbda': lmbda,
        }
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        '''
        Perform one optimization step on every parameter that has a gradient.

        closure: optional callable that re-evaluates the model and returns the
            loss; it is called once, with gradients enabled, before the update.

        Returns the closure's result, or None when no closure was given.
        '''
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure typically calls
            # backward(), so gradient tracking has to be switched back on.
            with torch.enable_grad():
                loss = closure()

        for param_group in self.param_groups:
            lr = param_group['lr']
            beta1, beta2 = param_group['beta1'], param_group['beta2']
            eps = param_group['eps']
            lmbda = param_group['lmbda']
            # p of type torch.nn.parameter.Parameter
            for p in param_group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                state = self.state[p]
                t = state.get('t', 1)
                # Create the moment buffers only when missing, instead of
                # allocating throw-away zero tensors on every single step.
                m = state['m'] if 'm' in state else torch.zeros_like(p)
                v = state['v'] if 'v' in state else torch.zeros_like(p)

                m = beta1 * m + (1 - beta1) * grad
                v = beta2 * v + (1 - beta2) * torch.square(grad)
                # Bias correction for both moments folded into the step size.
                lr_t = lr * sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
                # In-place updates on the parameter itself (not on `.data`) so
                # autograd's version counter records the modification.
                p.sub_(lr_t * m / (torch.sqrt(v) + eps))
                # Decoupled weight decay, applied to the updated parameter.
                p.sub_(lr * lmbda * p)

                state['t'] = t + 1
                state['m'] = m
                state['v'] = v

        return loss

In [3]:
# zeros_like keeps the shape and the (integer) dtype of its input.
torch.zeros_like(torch.arange(6).reshape(2, 3))

tensor([[0, 0, 0],
        [0, 0, 0]])

In [6]:
# What does `.data` return for a tensor that requires grad? A plain torch.Tensor.
t = torch.tensor([2.0, 3.0], dtype=torch.float32, requires_grad=True)
type(t.data)

torch.Tensor

In [7]:
# Print the built-in documentation for torch.square (element-wise square).
help(torch.square)

Help on built-in function square in module torch:

square(...)
    square(input, *, out=None) -> Tensor

    Returns a new tensor with the square of the elements of :attr:`input`.

    Args:
        input (Tensor): the input tensor.

    Keyword args:
        out (Tensor, optional): the output tensor.

    Example::

        >>> a = torch.randn(4)
        >>> a
        tensor([-2.0755,  1.0226,  0.0831,  0.4806])
        >>> torch.square(a)
        tensor([ 4.3077,  1.0457,  0.0069,  0.2310])



In [11]:
# sqrt of an integer tensor yields a floating-point result.
torch.tensor(4).sqrt()

tensor(2.)