In [1]:
from collections.abc import Callable, Iterable
import torch
import math

class SGD(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3):
        if lr < 0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = {"lr": lr}
        super().__init__(params, defaults)

    def step(self, closure: Callable | None = None):
        loss = None if closure is None else closure()
        for group in self.param_groups:
            lr = group["lr"] # Get the learning rate
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p] # Get state associated with p.
                t = state.get("t", 0) # Get iteration number from the state, or initial value.
                grad = p.grad.data # Get the gradient of loss with respect to p.
                p.data -= lr / math.sqrt(t + 1) * grad # Update weight tensor in-place.
                state["t"] = t + 1 # Increment iteration number.
        return loss

In [2]:
# Shared starting point for every run below: a 10x10 standard-normal draw
# scaled by 5. No seed is set in the visible cells, so the exact values (and
# the printed losses) differ between kernel sessions.
weights = torch.randn(10, 10) * 5
def run_training_loop(weights: torch.nn.Parameter, lr: float, num_iter:int=10):
    """Minimise ``mean(weights ** 2)`` with ``SGD``, printing the loss each iteration.

    Args:
        weights: parameter tensor handed to the optimizer and updated by it.
        lr: learning rate passed to ``SGD``.
        num_iter: number of optimisation steps to run.
    """
    optimizer = SGD([weights], lr=lr)
    for t in range(num_iter):
        # Clear gradients left over from the previous iteration.
        optimizer.zero_grad()
        # Scalar objective: mean of the squared entries.
        loss = weights.pow(2).mean()
        # Printed before the update, so "Iter 0" reports the initial loss.
        print(f"LR: {lr}, Iter {t}: {loss.cpu().item()}")
        # Backpropagate, then let the optimizer apply the update.
        loss.backward()
        optimizer.step()

In [3]:
# Wrap a clone of the shared `weights` so this run does not modify the original tensor.
run_training_loop(torch.nn.Parameter(weights.clone()), lr=1)

LR: 1, Iter 0: 18.745744705200195
LR: 1, Iter 1: 18.0034122467041
LR: 1, Iter 2: 17.497798919677734
LR: 1, Iter 3: 17.096038818359375
LR: 1, Iter 4: 16.755826950073242
LR: 1, Iter 5: 16.457427978515625
LR: 1, Iter 6: 16.18977928161621
LR: 1, Iter 7: 15.945937156677246
LR: 1, Iter 8: 15.72122573852539
LR: 1, Iter 9: 15.512308120727539


In [4]:
# Same starting weights (cloned), learning rate 10.
run_training_loop(torch.nn.Parameter(weights.clone()), lr=10)

LR: 10, Iter 0: 18.745744705200195
LR: 10, Iter 1: 11.997278213500977
LR: 10, Iter 2: 8.843878746032715
LR: 10, Iter 3: 6.919390678405762
LR: 10, Iter 4: 5.60470724105835
LR: 10, Iter 5: 4.64694356918335
LR: 10, Iter 6: 3.919081211090088
LR: 10, Iter 7: 3.348966360092163
LR: 10, Iter 8: 2.8920960426330566
LR: 10, Iter 9: 2.5193369388580322


In [5]:
# Same starting weights (cloned), learning rate 100.
run_training_loop(torch.nn.Parameter(weights.clone()), lr=100)

LR: 100, Iter 0: 18.745744705200195
LR: 100, Iter 1: 18.745742797851562
LR: 100, Iter 2: 3.2162604331970215
LR: 100, Iter 3: 0.0769723504781723
LR: 100, Iter 4: 1.3659212279264465e-16
LR: 100, Iter 5: 1.5224041005151913e-18
LR: 100, Iter 6: 5.1264699408732886e-20
LR: 100, Iter 7: 3.053873475311743e-21
LR: 100, Iter 8: 2.6198100954648563e-22
LR: 100, Iter 9: 2.9109002112535165e-23


In [6]:
# Same starting weights (cloned), learning rate 1000.
run_training_loop(torch.nn.Parameter(weights.clone()), lr=1000)

LR: 1000, Iter 0: 18.745744705200195
LR: 1000, Iter 1: 6767.212890625
LR: 1000, Iter 2: 1168804.125
LR: 1000, Iter 3: 130016984.0
LR: 1000, Iter 4: 10531376128.0
LR: 1000, Iter 5: 664650317824.0
LR: 1000, Iter 6: 34120975515648.0
LR: 1000, Iter 7: 1468030290755584.0
LR: 1000, Iter 8: 5.410846614644326e+16
LR: 1000, Iter 9: 1.7374829088279101e+18


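Observation of the above results: within the stable range (learning rates 1, 10 and 100), a higher learning rate makes the loss decrease faster.
However, if the learning rate is too high, every update overshoots the minimum by more than it corrects, so the weights oscillate around it with growing magnitude and the loss diverges, which is what we observe with a learning rate of 1000.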