In [1]:
import torch
import math

# Fit a cubic polynomial  y = a + b*x + c*x^2 + d*x^3  to sin(x) on [-pi, pi]
# using plain tensors and hand-derived gradients (no autograd).

# Fixed seed: the initial weights come from torch.randn, so without it every
# "Restart Kernel -> Run All" would print a different loss trace.
SEED = 0
torch.manual_seed(SEED)

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# Training data: 2000 evenly spaced inputs on [-pi, pi] and their sine as targets
# (deterministic -- only the weights below are random).
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# The powers of x never change during training, so compute them once instead
# of twice per iteration inside the loop.
x_sq = x ** 2
x_cu = x ** 3

# Randomly initialize weights. All four are 0-dim (scalar) tensors so they share
# one shape; they broadcast against x in the forward pass.
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(20000):
    # Forward pass: compute predicted y
    y_pred = a + b * x + c * x_sq + d * x_cu

    # Sum-of-squares loss. Kept as a tensor here; .item() is only called when
    # the value is actually printed, to avoid a host sync on every step.
    residual = y_pred - y
    loss_t = residual.pow(2).sum()
    if t % 100 == 99:
        print(t, loss_t.item())

    # Backprop by hand (chain rule): dLoss/dw = sum(dLoss/dY_pred * dY_pred/dw)
    grad_y_pred = 2.0 * residual            # dLoss/dY_pred
    grad_a = grad_y_pred.sum()              # dY_pred/da = 1
    grad_b = (grad_y_pred * x).sum()        # dY_pred/db = x
    grad_c = (grad_y_pred * x_sq).sum()     # dY_pred/dc = x^2
    grad_d = (grad_y_pred * x_cu).sum()     # dY_pred/dd = x^3

    # Gradient-descent update: w <- w - learning_rate * dLoss/dw
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d

# Loss of the last forward pass (i.e. before the final weight update), as a Python float.
loss = loss_t.item()
print(loss)
print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

99 21.031464183317993
199 16.919556233133967
299 14.193048335191524
399 12.38483083106658
499 11.185386485711899
599 10.389592806544261
699 9.861491121500475
799 9.510951306688039
899 9.278213863585364
999 9.123648926100412
1099 9.02097064651436
1199 8.952740472901862
1299 8.90738680227428
1399 8.877229351263463
1499 8.857169321523983
1599 8.843820830449044
1699 8.834934842268016
1799 8.829017020058462
1899 8.825074165347537
1999 8.82244593915707
2099 8.820693154366063
2199 8.819523602498617
2299 8.818742789767978
2399 8.818221207783376
2499 8.81787258267941
2599 8.817639415484363
2699 8.817483366292874
2799 8.817378857306942
2899 8.817308815607173
2999 8.81726183878418
3099 8.817230307071554
3199 8.817209125340995
3299 8.817194884415699
3399 8.817185301654998
3499 8.817178847620266
3599 8.817174496787093
3699 8.817171560990747
3799 8.817169578079803
3899 8.817168237430502
3999 8.817167330086152
4099 8.817166715357196
4199 8.81716629843155
4299 8.817166015354303
4399 8.817165822943334
