In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import math

In [2]:
#cumulative reward = sum of r_timestep * gamma**(timestep - first timestep) 

In [3]:
dtype = torch.float
x = torch.linspace(-math.pi, math.pi, 2000, dtype=dtype)
y = torch.sin(x)

In [4]:
a = torch.randn((), dtype=dtype)
b = torch.randn((), dtype=dtype)
c = torch.randn((), dtype=dtype)
d = torch.randn((), dtype=dtype)

In [5]:
lr = 1e-6
for t in range(2000):
    y_pred = a + b*x + c*x**2 + d*x**3
    loss = (y_pred - y).pow(2).sum().item()
    grad_y_pred = 2.0*(y_pred - y) #derivative of individual squared diff
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x**2).sum()
    grad_d = (grad_y_pred * x**3).sum()
    
    a -= lr*grad_a
    b -= lr*grad_b
    c -= lr*grad_c
    d -= lr*grad_d

In [6]:
print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')
print(y_pred)

Result: y = 0.009927630424499512 + 0.8412808775901794 x + -0.0017126821912825108 x^2 + -0.09113134443759918 x^3
tensor([ 0.1757,  0.1699,  0.1641,  ..., -0.1779, -0.1838, -0.1896])


In [7]:
a = torch.randn((), dtype=dtype, requires_grad=True)
b = torch.randn((), dtype=dtype, requires_grad=True)
c = torch.randn((), dtype=dtype, requires_grad=True)
d = torch.randn((), dtype=dtype, requires_grad=True)

In [8]:
lr = 1e-6
for t in range(2000):
    y_pred = a + b*x + c*x**2 + d*x**3
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())
    
    loss.backward()
    
    with torch.no_grad():
        a -= lr*a.grad
        b -= lr*b.grad
        c -= lr*c.grad
        d -= lr*d.grad
        
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

99 2652.847900390625
199 1788.601806640625
299 1207.9345703125
399 817.4680786132812
499 554.6722412109375
599 377.6434020996094
699 258.27947998046875
799 177.7197265625
899 123.29637145996094
999 86.49241638183594
1099 61.578304290771484
1199 44.6949348449707
1299 33.24156951904297
1399 25.463294982910156
1499 20.17507553100586
1599 16.57571792602539
1699 14.123103141784668
1799 12.449989318847656
1899 11.307310104370117
1999 10.52601432800293


In [9]:
print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

Result: y = 0.03260684758424759 + 0.8299441337585449 x + -0.005625222343951464 x^2 + -0.08951878547668457 x^3


In [10]:
class ApproxModel(torch.autograd.Function):
    @staticmethod
    def forward(context, input):
        context.save_for_backward(input)
        return 0.5*(5*input**3 - 3*input)
    
    @staticmethod
    def backward(context, grad_output):
        input, = context.saved_tensors
        return grad_output * 1.5 * (5*input**2 - 1)

In [11]:
a = torch.full((), 0.0, dtype=dtype, requires_grad=True)
b = torch.full((), -1.0, dtype=dtype, requires_grad=True)
c = torch.full((), 0.0, dtype=dtype, requires_grad=True)
d = torch.full((), 0.3, dtype=dtype, requires_grad=True)

In [12]:
lr = 5e-6
for t in range(2000):
    model = ApproxModel().apply
    y_pred = a + b*model(c + d*x)
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())
    
    loss.backward()
    
    with torch.no_grad():
        a -= lr*a.grad
        b -= lr*b.grad
        c -= lr*c.grad
        d -= lr*d.grad
        
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

99 209.95834350585938
199 144.66018676757812
299 100.70249938964844
399 71.03519439697266
499 50.97850799560547
599 37.403133392333984
699 28.206867218017578
799 21.973188400268555
899 17.7457275390625
999 14.877889633178711
1099 12.931766510009766
1199 11.610918045043945
1299 10.714258193969727
1399 10.10548210144043
1499 9.692106246948242
1599 9.411375999450684
1699 9.220745086669922
1799 9.091285705566406
1899 9.003361701965332
1999 8.943639755249023


In [13]:
print(f'Result: y = {a.item()} + {b.item()} * P3({c.item()} + {d.item()} x)')

Result: y = -6.8844756562214116e-09 + -2.208526849746704 * P3(1.5037101563919464e-09 + 0.2554861009120941 x)
