In [1]:
from collections.abc import Callable, Iterable
import torch
import math

class SGD(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3):
        if lr < 0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = {"lr": lr}
        super().__init__(params, defaults)

    def step(self, closure: Callable | None = None):
        loss = None if closure is None else closure()
        for group in self.param_groups:
            lr = group["lr"] # Get the learning rate
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p] # Get state associated with p.
                t = state.get("t", 0) # Get iteration number from the state, or initial value.
                grad = p.grad.data # Get the gradient of loss with respect to p.
                p.data -= lr / math.sqrt(t + 1) * grad # Update weight tensor in-place.
                state["t"] = t + 1 # Increment iteration number.
        return loss

In [None]:
weights = 5 * torch.randn((10, 10))
def run_training_loop(weights: torch.nn.Parameter, lr: float, num_iter:int=10):
    opt = SGD([weights], lr=lr)
    for t in range(num_iter):
        opt.zero_grad() # Reset the gradients for all learnable parameters.
        loss = (weights**2).mean() # Compute a scalar loss value.
        print(f"LR: {lr}, Iter {num_iter}: {loss.cpu().item()}")
        loss.backward() # Run backward pass, which computes gradients.
        opt.step() # Run optimizer step

In [4]:
run_training_loop(torch.nn.Parameter(weights.clone()), 1)

LR: 1, Iter 10: 20.35976219177246
LR: 1, Iter 10: 19.553516387939453
LR: 1, Iter 10: 19.004369735717773
LR: 1, Iter 10: 18.568017959594727
LR: 1, Iter 10: 18.198514938354492
LR: 1, Iter 10: 17.874425888061523
LR: 1, Iter 10: 17.583728790283203
LR: 1, Iter 10: 17.318893432617188
LR: 1, Iter 10: 17.074832916259766
LR: 1, Iter 10: 16.84792709350586


In [5]:
run_training_loop(torch.nn.Parameter(weights.clone()), 10)

LR: 10, Iter 10: 20.35976219177246
LR: 10, Iter 10: 13.030248641967773
LR: 10, Iter 10: 9.605342864990234
LR: 10, Iter 10: 7.515154838562012
LR: 10, Iter 10: 6.087275505065918
LR: 10, Iter 10: 5.047048568725586
LR: 10, Iter 10: 4.256515979766846
LR: 10, Iter 10: 3.637314558029175
LR: 10, Iter 10: 3.1411068439483643
LR: 10, Iter 10: 2.736253261566162


In [6]:
run_training_loop(torch.nn.Parameter(weights.clone()), 100)

LR: 100, Iter 10: 20.35976219177246
LR: 100, Iter 10: 20.35976219177246
LR: 100, Iter 10: 3.493182420730591
LR: 100, Iter 10: 0.08359973132610321
LR: 100, Iter 10: 8.72220332208744e-17
LR: 100, Iter 10: 9.72143704710216e-19
LR: 100, Iter 10: 3.273549825837561e-20
LR: 100, Iter 10: 1.9500762746098454e-21
LR: 100, Iter 10: 1.6729012234271768e-22
LR: 100, Iter 10: 1.8587793825647003e-23


In [7]:
run_training_loop(torch.nn.Parameter(weights.clone()), 1000)

LR: 1000, Iter 10: 20.35976219177246
LR: 1000, Iter 10: 7349.87451171875
LR: 1000, Iter 10: 1269438.875
LR: 1000, Iter 10: 141211520.0
LR: 1000, Iter 10: 11438132224.0
LR: 1000, Iter 10: 721877139456.0
LR: 1000, Iter 10: 37058814935040.0
LR: 1000, Iter 10: 1594428762357760.0
LR: 1000, Iter 10: 5.876724294221824e+16
LR: 1000, Iter 10: 1.887081361391485e+18


Observation of above results: the higher the learning rate, the faster the rate of decrease in loss. 
However, if the learning rate is too high, we will overshoot the minima and end up oscillating, which is what we observe with a lr of 1000.