In [None]:
from collections.abc import Callable, Iterable
from typing import Optional
import torch
import math

class SGD(torch.optim.Optimizer):
    """Gradient descent with a per-parameter decaying step size.

    Every parameter is updated as ``p <- p - lr / sqrt(t + 1) * grad``, where
    ``t`` is the number of updates already applied to that parameter. The
    counter is kept in the optimizer state under the key ``"t"``.
    """

    def __init__(self, params, lr=1e-3):
        """Create the optimizer.

        Args:
            params: Parameters (or parameter groups) to optimize, forwarded
                unchanged to ``torch.optim.Optimizer``.
            lr: Base learning rate stored in each group's ``"lr"`` entry.

        Raises:
            ValueError: If ``lr`` is negative.
        """
        # A negative rate would turn descent into ascent, so reject it early.
        if lr < 0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = {"lr": lr}
        super().__init__(params, defaults)

    def step(self, closure: Optional[Callable] = None):
        """Apply one update to every parameter that has a gradient.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is invoked once, before any parameter is updated.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is
            given.
        """
        loss = None
        if closure is not None:
            # The closure may run a backward pass, so autograd must be enabled
            # even when the caller invoked step() inside a no_grad context.
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            lr = group["lr"]
            for p in group["params"]:
                if p.grad is None:
                    # Parameter took no part in the last backward pass.
                    continue
                state = self.state[p]
                t = state.get("t", 0)  # Updates applied so far to this parameter.
                # Update in place without recording the operation in autograd.
                with torch.no_grad():
                    p.add_(p.grad, alpha=-lr / math.sqrt(t + 1))
                state["t"] = t + 1
        return loss

In [2]:
# Minimize mean(weights^2) with the SGD optimizer at base learning rate 1.
weights = torch.nn.Parameter(torch.randn(10, 10) * 5)
opt = SGD([weights], lr=1)
for t in range(100):
    # Clear any gradients left from the previous iteration.
    opt.zero_grad()
    # Scalar objective: mean of the squared entries.
    loss = weights.pow(2).mean()
    print(loss.cpu().item())
    # Backward pass fills weights.grad; the optimizer then applies the update.
    loss.backward()
    opt.step()

29.163795471191406
28.0089111328125
27.222299575805664
26.597261428833008
26.067970275878906
25.603740692138672
25.187339782714844
24.8079833984375
24.458383560180664
24.133359909057617
23.829057693481445
23.542539596557617
23.271474838256836
23.014019012451172
22.768644332885742
22.534099578857422
22.3093204498291
22.093412399291992
21.885604858398438
21.685230255126953
21.49170684814453
21.304519653320312
21.12322235107422
20.947410583496094
20.776723861694336
20.610841751098633
20.449474334716797
20.2923583984375
20.139251708984375
19.989938735961914
19.844219207763672
19.7019100189209
19.562843322753906
19.426862716674805
19.293821334838867
19.163593292236328
19.036048889160156
18.911075592041016
18.788562774658203
18.668413162231445
18.55052947998047
18.434825897216797
18.32122039794922
18.209630966186523
18.099987030029297
17.992223739624023
17.886266708374023
17.782060623168945
17.679546356201172
17.578662872314453
17.47936248779297
17.3815975189209
17.28531265258789
17.19047164

In [6]:
# Same quadratic objective, base learning rate 10, ten iterations.
weights = torch.nn.Parameter(torch.randn(10, 10) * 5)
opt = SGD([weights], lr=1e1)
for t in range(10):
    # Clear any gradients left from the previous iteration.
    opt.zero_grad()
    # Scalar objective: mean of the squared entries.
    loss = weights.pow(2).mean()
    print(loss.cpu().item())
    # Backward pass fills weights.grad; the optimizer then applies the update.
    loss.backward()
    opt.step()

28.669509887695312
18.348485946655273
13.525718688964844
10.58243179321289
8.571769714355469
7.106978893280029
5.9937944412231445
5.121867656707764
4.423135757446289
3.8530426025390625


In [7]:
# Same quadratic objective, base learning rate 100, ten iterations.
weights = torch.nn.Parameter(torch.randn(10, 10) * 5)
opt = SGD([weights], lr=1e2)
for t in range(10):
    # Clear any gradients left from the previous iteration.
    opt.zero_grad()
    # Scalar objective: mean of the squared entries.
    loss = weights.pow(2).mean()
    print(loss.cpu().item())
    # Backward pass fills weights.grad; the optimizer then applies the update.
    loss.backward()
    opt.step()

24.986722946166992
24.98672103881836
4.28704309463501
0.10259860754013062
9.912209835852357e-17
1.1047772562469953e-18
3.7201740422407775e-20
2.2161331833032014e-21
1.9011422860258998e-22
2.1123801775646166e-23


In [8]:
# Same quadratic objective, base learning rate 1000, ten iterations.
weights = torch.nn.Parameter(torch.randn(10, 10) * 5)
opt = SGD([weights], lr=1e3)
for t in range(10):
    # Clear any gradients left from the previous iteration.
    opt.zero_grad()
    # Scalar objective: mean of the squared entries.
    loss = weights.pow(2).mean()
    print(loss.cpu().item())
    # Backward pass fills weights.grad; the optimizer then applies the update.
    loss.backward()
    opt.step()

28.053892135620117
10127.4541015625
1749170.75
194576576.0
15760702464.0
994680700928.0
51063654711296.0
2196977106288640.0
8.097588746964173e+16
2.6002254278007194e+18
