In [8]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import math

In [9]:
# Discounted cumulative reward: G = sum over timesteps t of r_t * gamma**(t - t_0), where t_0 is the first timestep.

In [10]:
# Fitting target: sin(x) sampled at 2000 evenly spaced points on [-pi, pi].
dtype = torch.float32  # same dtype object as torch.float
x = torch.linspace(start=-math.pi, end=math.pi, steps=2000, dtype=dtype)
y = x.sin()

In [11]:
# Random scalar starting values for the cubic a + b*x + c*x**2 + d*x**3.
# Seeding the RNG makes the starting point, and therefore the fitted result,
# reproducible when the notebook is re-run.
torch.manual_seed(0)
a, b, c, d = (torch.randn((), dtype=dtype) for _ in range(4))

In [12]:
# Fit y ~ a + b*x + c*x**2 + d*x**3 by gradient descent on the summed squared
# error, with the gradients derived by hand (no autograd).
lr = 1e-6
# Powers of x never change between iterations, so compute them once.
x2 = x**2
x3 = x**3
for t in range(2000):
    y_pred = a + b*x + c*x2 + d*x3
    residual = y_pred - y
    loss = residual.pow(2).sum().item()
    # Report progress every 100 steps so convergence is visible.
    if t % 100 == 99:
        print(t, loss)

    # For loss = sum((y_pred - y)**2), d(loss)/d(y_pred) = 2*(y_pred - y) elementwise;
    # the chain rule weights it by each term's derivative w.r.t. its coefficient.
    grad_y_pred = 2.0 * residual
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x2).sum()
    grad_d = (grad_y_pred * x3).sum()

    # Plain gradient-descent step (coefficients do not track gradients here).
    a -= lr*grad_a
    b -= lr*grad_b
    c -= lr*grad_c
    d -= lr*grad_d

In [13]:
# Fitted cubic from the hand-derived-gradient loop, then its final predictions.
result_line = 'Result: y = {} + {} x + {} x^2 + {} x^3'.format(
    *(coef.item() for coef in (a, b, c, d))
)
print(result_line)
print(y_pred)

Result: y = -0.00876771192997694 + 0.8424534797668457 x + 0.0015125765930861235 x^2 + -0.09129814058542252 x^3
tensor([ 0.1903,  0.1844,  0.1786,  ..., -0.1664, -0.1722, -0.1780])


In [14]:
# Random scalar cubic coefficients as leaf tensors that track gradients, so
# autograd can compute d(loss)/d(coefficient).
# Seeding the RNG makes the starting point reproducible when the cell is re-run.
torch.manual_seed(0)
a, b, c, d = (torch.randn((), dtype=dtype, requires_grad=True) for _ in range(4))

In [15]:
# Gradient descent on the cubic fit, letting autograd compute the gradients.
lr = 1e-6
coefficients = (a, b, c, d)
for t in range(2000):
    y_pred = a + b*x + c*x**2 + d*x**3
    loss = (y_pred - y).pow(2).sum()
    if (t + 1) % 100 == 0:
        print(t, loss.item())

    loss.backward()

    # The update itself must not be recorded in the graph. Gradients are cleared
    # afterwards because backward() accumulates into .grad.
    with torch.no_grad():
        for coef in coefficients:
            coef -= lr * coef.grad
            coef.grad = None

99 146.06045532226562
199 100.51876068115234
299 70.12959289550781
399 49.83741760253906
499 36.27960968017578
599 27.215787887573242
699 21.15247917175293
799 17.093643188476562
899 14.374722480773926
999 12.55205249786377
1099 11.329307556152344
1199 10.508342742919922
1299 9.95671558380127
1399 9.585732460021973
1499 9.336024284362793
1599 9.167787551879883
1699 9.054356575012207
1799 8.977773666381836
1899 8.92603588104248
1999 8.891046524047852


In [16]:
print('Result: y = {} + {} x + {} x^2 + {} x^3'.format(a.item(), b.item(), c.item(), d.item()))  # coefficients learned with autograd

Result: y = -0.00567973405122757 + 0.8502136468887329 x + 0.0009798483224585652 x^2 + -0.09240195155143738 x^3


In [17]:
class ApproxModel(torch.autograd.Function):
    """Custom autograd op computing P3(z) = 0.5 * (5*z**3 - 3*z) elementwise.

    P3 is the cubic Legendre polynomial. Call it as ``ApproxModel.apply(z)``;
    both passes are static methods.
    """

    @staticmethod
    def forward(context, input):
        """Return P3(input), saving ``input`` for the backward pass."""
        context.save_for_backward(input)
        cubic_part = 5*input**3 - 3*input
        return 0.5*cubic_part

    @staticmethod
    def backward(context, grad_output):
        """Chain rule: scale ``grad_output`` by P3'(input) = 1.5 * (5*input**2 - 1)."""
        (saved_input,) = context.saved_tensors
        slope_factor = 5*saved_input**2 - 1
        return grad_output * 1.5 * slope_factor

In [18]:
# Fixed starting point for the fit y ~ a + b * P3(c + d*x): a=0, b=-1, c=0, d=0.3.
# With x in [-pi, pi], d=0.3 keeps c + d*x within roughly [-0.94, 0.94] at the start.
a, b, c, d = (
    torch.tensor(value, dtype=dtype, requires_grad=True)
    for value in (0.0, -1.0, 0.0, 0.3)
)

In [19]:
# Fit y ~ a + b * P3(c + d*x), where P3 is evaluated by the custom autograd
# Function ApproxModel, and autograd computes the gradients.
lr = 5e-6
# Autograd Functions are invoked through the class-level `apply`. Instantiating
# one (`ApproxModel()`) is deprecated in PyTorch, and the original code did it
# on every iteration, so take `apply` from the class once.
model = ApproxModel.apply
for t in range(2000):
    y_pred = a + b*model(c + d*x)
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())

    loss.backward()

    # The update must not be recorded in the graph. Clear the gradients because
    # backward() accumulates into .grad.
    with torch.no_grad():
        for coef in (a, b, c, d):
            coef -= lr * coef.grad
            coef.grad = None

99 209.95834350585938
199 144.66018676757812
299 100.70249938964844
399 71.03519439697266
499 50.97850799560547
599 37.403133392333984
699 28.206867218017578
799 21.97318458557129
899 17.745729446411133
999 14.877889633178711
1099 12.93176555633545
1199 11.610918045043945
1299 10.714258193969727
1399 10.10548210144043
1499 9.692106246948242
1599 9.411375999450684
1699 9.220745086669922
1799 9.091286659240723
1899 9.003362655639648
1999 8.943641662597656


In [22]:
print('Result: y = {} + {} * P3({} + {} x)'.format(a.item(), b.item(), c.item(), d.item()))  # learned a, b, c, d of the P3 fit

Result: y = -2.9753338681715036e-10 + -2.208526849746704 * P3(-1.1693186696692948e-10 + 0.2554861009120941 x)
