In [None]:
import math
from collections.abc import Callable

import torch


class SGD(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3):
        if lr < 0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = {"lr": lr}
        super().__init__(params, defaults)

    def step(self, closure: Callable | None = None):
        loss = None if closure is None else closure()
        for group in self.param_groups:
            lr = group["lr"]  # Get the learning rate.
            for p in group["params"]:
                if p.grad is None:
                    continue

                state = self.state[p]  # Get state associated with p.
                t = state.get("t", 0)  # Get iteration number from the state, or initial value.
                grad = p.grad.data  # Get the gradient of loss with respect to p.
                p.data -= lr / math.sqrt(t + 1) * grad  # Update weight tensor in-place.
                state["t"] = t + 1  # Increment iteration number.

        return loss

In [None]:
import copy

LEARNING_RATES = (1, 1e1, 1e2, 1e3)
NUM_STEPS = 100  # Optimizer steps taken for each learning rate.
SEED = 0  # Fixes the initial weights so a fresh "Run All" reproduces the same curves.

# Shared starting point for every run. It is drawn from a dedicated seeded
# generator so the experiment is reproducible without touching the global RNG.
WEIGHTS = torch.nn.Parameter(
    5 * torch.randn((10, 10), generator=torch.Generator().manual_seed(SEED))
)

# Maps each learning rate to the list of loss values recorded before each step.
results = {}

for lr in LEARNING_RATES:
    # Every run starts from its own copy of the same initial weights.
    weights = copy.deepcopy(WEIGHTS)
    opt = SGD([weights], lr=lr)
    results[lr] = losses = []

    for _ in range(NUM_STEPS):
        opt.zero_grad()  # Reset the gradients for all learnable parameters.
        loss = (weights**2).mean()  # Compute a scalar loss value.
        losses.append(loss.item())  # Record the loss before this step's update.
        loss.backward()  # Run backward pass, which computes gradients.
        opt.step()  # Run optimizer step.

In [None]:
import matplotlib.pyplot as plt

# Draw one loss curve per learning rate on a shared log-scaled axis.
fig, ax = plt.subplots(figsize=(8, 5))
ax.set_yscale("log")

for rate, history in results.items():
    steps = range(len(history))  # x-axis: iteration index of each recorded loss.
    ax.plot(steps, history, label=f"lr={rate:g}")

ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")
ax.set_title("Loss vs Iteration for different learning rates")
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()

In [None]:
# Show the recorded per-iteration losses for the largest learning rate. The int
# key 1000 finds the entry stored under the float 1e3 because the two compare
# and hash equal.
results[1000]