In [1]:
import torch
import numpy as np

In [2]:
inputs = torch.tensor([[11. , 22. , 33.],
                       [44. , 55. , 66.],
                       [77. , 25. , 68.],
                       [37. , 31. , 42.]])
inputs = inputs / torch.max(inputs)
inputs

tensor([[0.1429, 0.2857, 0.4286],
        [0.5714, 0.7143, 0.8571],
        [1.0000, 0.3247, 0.8831],
        [0.4805, 0.4026, 0.5455]])

In [3]:
targets = torch.tensor([[45. , 70.],
                        [27. , 57.],
                        [34. , 77.],
                        [41. , 68.]])
targets

tensor([[45., 70.],
        [27., 57.],
        [34., 77.],
        [41., 68.]])

In [4]:
w = torch.randn(2 , 3 , requires_grad  = True)
b = torch.randn(2 , requires_grad = True)

In [5]:
def model(inputs):
    return inputs @ w.t() + b

In [6]:
def mse(t1 , t2):
    diff = t2 - t1
    return torch.sum(diff * diff) / diff.numel()

In [7]:
preds = model(inputs)
print(preds)

tensor([[ 0.8328, -1.2842],
        [ 2.0987, -1.1811],
        [ 2.3130,  0.2657],
        [ 1.6089, -0.7006]], grad_fn=<AddBackward0>)


In [8]:
loss = mse(preds, targets)
print(loss)

tensor(3025.1182, grad_fn=<DivBackward0>)


In [9]:
loss.backward()

In [10]:
print(w)
print(w.grad)

tensor([[ 1.9516,  1.5596, -0.5574],
        [ 2.5163, -1.0288, -1.2469]], requires_grad=True)
tensor([[-17.7885, -14.1382, -22.4356],
        [-38.2940, -28.6243, -46.4146]])


In [11]:
for epochs in range(50):
    pred = model(inputs)
    pred.requires_grad_(True)
    pred.retain_grad()
    loss = mse(pred , targets)
    loss.backward()
    with torch.no_grad():
        w -= (1e-3 * w.grad)
        b -= (1e-3 * b.grad)
        w.grad.zero_()
        b.grad.zero_()

    print('Epoch : [{} / {}] --> Loss : {}'.format(epochs + 1, 50 , loss.item()))

Epoch : [1 / 50] --> Loss : 3025.1181640625
Epoch : [2 / 50] --> Loss : 3002.3427734375
Epoch : [3 / 50] --> Loss : 2991.034912109375
Epoch : [4 / 50] --> Loss : 2979.771728515625
Epoch : [5 / 50] --> Loss : 2968.55419921875
Epoch : [6 / 50] --> Loss : 2957.380859375
Epoch : [7 / 50] --> Loss : 2946.252197265625
Epoch : [8 / 50] --> Loss : 2935.168212890625
Epoch : [9 / 50] --> Loss : 2924.128662109375
Epoch : [10 / 50] --> Loss : 2913.1328125
Epoch : [11 / 50] --> Loss : 2902.18115234375
Epoch : [12 / 50] --> Loss : 2891.27294921875
Epoch : [13 / 50] --> Loss : 2880.40869140625
Epoch : [14 / 50] --> Loss : 2869.58740234375
Epoch : [15 / 50] --> Loss : 2858.809814453125
Epoch : [16 / 50] --> Loss : 2848.074951171875
Epoch : [17 / 50] --> Loss : 2837.383056640625
Epoch : [18 / 50] --> Loss : 2826.73388671875
Epoch : [19 / 50] --> Loss : 2816.127685546875
Epoch : [20 / 50] --> Loss : 2805.56298828125
Epoch : [21 / 50] --> Loss : 2795.041015625
Epoch : [22 / 50] --> Loss : 2784.5610351562