In [1]:
import torch
import torch.nn as nn
from d2l import torch as d2l
import numpy as np

In [2]:
class Sgd:
    """Plain mini-batch stochastic gradient descent.

    The gradient is NOT divided by the batch size here, because the loss fed to
    ``backward()`` in this notebook is already averaged over the batch
    (``.mean()``).
    """
    def __init__(self, lr):
        self.lr = lr

    def update(self, params):
        """Apply one SGD step ``param -= lr * param.grad`` and zero the gradient.

        Args:
            params: dict mapping a name to a leaf tensor (e.g. ``nn.Parameter``)
                whose ``.grad`` was populated by ``backward()``. Entries whose
                ``.grad`` is ``None`` (no gradient reached them) are skipped
                instead of raising.
        """
        # no_grad: the in-place step on a leaf that requires grad must not be
        # recorded by autograd.
        with torch.no_grad():
            for param in params.values():
                if param.grad is None:
                    continue
                param -= self.lr * param.grad
                param.grad.zero_()

class Adam:
    """Adam optimizer (http://arxiv.org/abs/1412.6980v8).

    Uses the reordered update suggested in Section 2 of the paper. Both bias
    corrections are folded into a per-step learning rate
    ``lr_t = lr * sqrt(1 - beta2**t) / (1 - beta1**t)``, and ``eps`` is added
    to ``sqrt(v)`` in the denominator.

    Args:
        lr: base step size.
        beta1: exponential decay rate of the first-moment estimate ``m``.
        beta2: exponential decay rate of the second-moment estimate ``v``.
        eps: small constant added to ``sqrt(v)`` to avoid division by zero
            (default ``1e-7``, the value that used to be hard-coded).
    """
    def __init__(self, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-7):
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.iter = 0   # number of update() calls so far (t in the paper)
        self.m = None   # dict of first-moment estimates, created on first update()
        self.v = None   # dict of second-moment estimates, created on first update()

    def update(self, params):
        """Apply one Adam step to every parameter in ``params`` and zero its gradient.

        Args:
            params: dict mapping a name to a leaf tensor (e.g. ``nn.Parameter``)
                whose ``.grad`` was populated by ``backward()``. Moment state is
                keyed by these names, so pass dicts with consistent keys across
                calls. Entries whose ``.grad`` is ``None`` are skipped.
        """
        if self.m is None:
            self.m, self.v = {}, {}

        self.iter += 1
        # Pure Python float math. torch.sqrt rejects a float argument, and
        # double precision is more accurate than a float32 tensor here.
        lr_t = (self.lr * (1.0 - self.beta2 ** self.iter) ** 0.5
                / (1.0 - self.beta1 ** self.iter))

        with torch.no_grad():
            for key, param in params.items():
                grad = param.grad
                if grad is None:
                    continue
                # Create state lazily so keys first seen later don't raise KeyError.
                if key not in self.m:
                    self.m[key] = torch.zeros_like(param)
                    self.v[key] = torch.zeros_like(param)
                m, v = self.m[key], self.v[key]
                # In-place EMAs: m = b1*m + (1-b1)*g ;  v = b2*v + (1-b2)*g^2
                m += (1 - self.beta1) * (grad - m)
                v += (1 - self.beta2) * (grad ** 2 - v)
                param -= lr_t * m / (torch.sqrt(v) + self.eps)
                grad.zero_()
        


In [3]:
def squared_loss(preds, labels):
    """Element-wise squared error, with no 1/2 factor and no reduction.

    ``labels`` is first reshaped to ``preds.shape``. A flat label vector of
    shape (n,) therefore lines up with (n, 1) predictions instead of
    broadcasting to an (n, n) result.
    """
    aligned_labels = labels.reshape(preds.shape)
    residual = preds - aligned_labels
    return residual.pow(2)

In [4]:
# Test the optimizers above on a tiny linear-regression problem
# (2 samples, 5 features). Seeded so that Restart & Run All reproduces the
# same data and initial weights.
SEED = 0
torch.manual_seed(SEED)
tmp_X = torch.normal(0, 0.01, size=(2, 5))
tmp_y = torch.ones(2)
tmp_w = nn.Parameter(torch.normal(0, 0.01, size=(5, 1)))
tmp_b = nn.Parameter(torch.zeros(1))
tmp_params = {'w': tmp_w, 'b': tmp_b}
tmp_num_epochs = 20
tmp_net = lambda X: torch.matmul(X, tmp_w) + tmp_b
tmp_loss = squared_loss
# Independent copy of the initial weights, so the Adam run starts from the same
# point as the SGD run (SGD updates tmp_w / tmp_b in place).
# detach().clone() is the recommended replacement for .data.clone().
origin_w = nn.Parameter(tmp_w.detach().clone())
origin_b = nn.Parameter(tmp_b.detach().clone())
origin_params = {'w': origin_w, 'b': origin_b}
origin_net = lambda X: torch.matmul(X, origin_w) + origin_b

In [5]:
# Train the toy model with SGD (lr=0.1). There are 2 samples and batch size 2,
# so every epoch is a single full batch. The loader is built once and
# re-iterated (and reshuffled) each epoch.
tmp_optimizer = Sgd(0.1)
train_iter = d2l.load_array((tmp_X, tmp_y), 2)
for epoch in range(tmp_num_epochs):
    for X, y in train_iter:
        batch_loss = tmp_loss(tmp_net(X), y).mean()
        batch_loss.backward()
        tmp_optimizer.update(tmp_params)
        print('epoch %d, loss: %f' % (epoch + 1, batch_loss.item()))

epoch 1, loss: 0.999857
epoch 2, loss: 0.639847
epoch 3, loss: 0.409463
epoch 4, loss: 0.262031
epoch 5, loss: 0.167684


In [6]:
# Same toy problem, optimized with Adam (lr=0.1). It trains the separate copy
# (origin_net / origin_params), so it starts from the initial weights rather
# than from the SGD-trained ones. The loader is built once and re-iterated.
tmp_optimizer = Adam(0.1)
train_iter = d2l.load_array((tmp_X, tmp_y), 2)
for epoch in range(tmp_num_epochs):
    for X, y in train_iter:
        batch_loss = tmp_loss(origin_net(X), y).mean()
        batch_loss.backward()
        tmp_optimizer.update(origin_params)
        print('epoch %d, loss: %f' % (epoch + 1, batch_loss.item()))

epoch 1, loss: 0.999857
epoch 2, loss: 0.804444
epoch 3, loss: 0.630984
epoch 4, loss: 0.479730
epoch 5, loss: 0.350718


```python
torch.sqrt(1.0 - 0.999**1)
```

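会报错 `TypeError: sqrt(): argument 'input' (position 1) must be Tensor, not float`，因为 `torch.sqrt` 只接受 `Tensor`，而 `1.0 - 0.999**1` 的结果是 Python 的 `float`。

应先把标量包装成张量（对纯 Python 标量，也可以直接用 `math.sqrt(...)` 或 `(...) ** 0.5`），例如改为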

```python
torch.sqrt(torch.tensor(1.0) - 0.999**1)
```