## Gradient Descent

In [1]:
import torch
import torch.nn.functional as F

In [2]:
# Fixed 3x3 target that the optimisation below should converge to.
# torch.tensor(..., dtype=torch.float32) replaces the legacy
# torch.FloatTensor(...) constructor; the values, shape and dtype are the same.
target = torch.tensor([[.1, .2, .3],
                       [.4, .5, .6],
                       [.7, .8, .9]], dtype=torch.float32)
target

tensor([[0.1000, 0.2000, 0.3000],
        [0.4000, 0.5000, 0.6000],
        [0.7000, 0.8000, 0.9000]])

In [5]:
# Start from random values with the same shape and dtype as `target`.
# requires_grad=True tells autograd to track operations on `x`, so that after
# the final scalar (the loss) is differentiated with `backward()`, the
# gradient of that scalar with respect to `x` is available as `x.grad`.
x = torch.rand_like(target, requires_grad=True)
x

tensor([[0.5425, 0.6250, 0.6944],
        [0.2591, 0.2604, 0.1496],
        [0.1317, 0.5961, 0.9846]], requires_grad=True)

In [6]:
# Mean squared error between the current guess `x` and `target`.
# 'mean' is the default reduction; it is spelled out here so it is visible
# that the loss is a single scalar averaged over all elements.
loss = F.mse_loss(x, target, reduction='mean')
loss

tensor(0.1315, grad_fn=<MseLossBackward>)

In [7]:
# Plain gradient descent on `x`: repeat until the MSE loss against `target`
# falls to `threshold` or below, or until `max_iter` steps have been taken.
threshold = 1e-5
learning_rate = 1.
# Safety cap: with an unsuitable learning_rate the loss may never drop below
# the threshold (the updates can overshoot and oscillate), which would
# otherwise leave this cell spinning forever.
max_iter = 1000
iter_cnt = 0

while loss.item() > threshold and iter_cnt < max_iter:
    iter_cnt += 1

    # Back-propagate from the scalar loss; this fills in x.grad.
    loss.backward()

    # Gradient step. The arithmetic itself would be recorded by autograd, so
    # the result is detached from that history and re-marked as requiring
    # grad. The new `x` is therefore a fresh leaf tensor with an empty .grad,
    # and gradients do not accumulate across iterations.
    x = (x - learning_rate * x.grad).detach().requires_grad_(True)

    # Re-evaluate the loss at the updated point for the next check/step.
    loss = F.mse_loss(x, target)

    # .item() turns the 0-dim loss tensor into a plain Python float.
    print('%d-th Loss: %.4e' % (iter_cnt, loss.item()))
    print(x)

1-th Loss: 7.9569e-02
tensor([[0.4442, 0.5305, 0.6068],
        [0.2904, 0.3136, 0.2497],
        [0.2580, 0.6414, 0.9658]], requires_grad=True)
2-th Loss: 4.8134e-02
tensor([[0.3677, 0.4571, 0.5386],
        [0.3148, 0.3550, 0.3276],
        [0.3562, 0.6767, 0.9512]], requires_grad=True)
3-th Loss: 2.9118e-02
tensor([[0.3082, 0.3999, 0.4856],
        [0.3337, 0.3872, 0.3881],
        [0.4326, 0.7041, 0.9398]], requires_grad=True)
4-th Loss: 1.7615e-02
tensor([[0.2619, 0.3555, 0.4443],
        [0.3484, 0.4123, 0.4352],
        [0.4920, 0.7254, 0.9309]], requires_grad=True)
5-th Loss: 1.0656e-02
tensor([[0.2260, 0.3210, 0.4123],
        [0.3599, 0.4318, 0.4718],
        [0.5382, 0.7420, 0.9241]], requires_grad=True)
6-th Loss: 6.4462e-03
tensor([[0.1980, 0.2941, 0.3873],
        [0.3688, 0.4469, 0.5003],
        [0.5742, 0.7549, 0.9187]], requires_grad=True)
7-th Loss: 3.8995e-03
tensor([[0.1762, 0.2732, 0.3679],
        [0.3757, 0.4587, 0.5225],
        [0.6021, 0.7649, 0.9146]], requi