In [5]:
import torch
import math

class LegendrePolynomial3(torch.autograd.Function):
  @staticmethod
  def forward(ctx, input):
    ctx.save_for_backward(input)
    return 0.5 * (5 * input ** 3 - 3 * input)

  @staticmethod
  def backward(ctx, grad_output):
    input, = ctx.saved_tensors
    return grad_output * 1.5 * (5 * input ** 2 - 1)

dtype = torch.float
device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device('cpu')

x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
d = torch.full((), .3, device=device,  dtype=dtype, requires_grad=True)

learning_rate = 5e-6
for t in range(2000):
  P3 = LegendrePolynomial3.apply
  y_pred = a + b * P3(c + d * x)
  loss = (y_pred - y).pow(2).sum()
  if t % 100 == 99:
    print(t, loss.item())

loss.backward()

with torch.no_grad():
  a -= learning_rate * a.grad
  b -= learning_rate * b.grad
  c -= learning_rate * c.grad
  d -= learning_rate * d.grad

  a.grad = None
  b.grad = None
  c.grad = None
  d.grad = None

print(f'Result: y = {a.item()} + {b.item()} * P3({c.item()} + {d.item()} x)')


99 461.902587890625
199 461.902587890625
299 461.902587890625
399 461.902587890625
499 461.902587890625
599 461.902587890625
699 461.902587890625
799 461.902587890625
899 461.902587890625
999 461.902587890625
1099 461.902587890625
1199 461.902587890625
1299 461.902587890625
1399 461.902587890625
1499 461.902587890625
1599 461.902587890625
1699 461.902587890625
1799 461.902587890625
1899 461.902587890625
1999 461.902587890625
Result: y = 1.4305114384716155e-11 + -1.0016018152236938 * P3(0.0 + 0.2659043073654175 x)
