In [1]:
import torch
import math

In [4]:
# Show whether the MPS backend can actually be used in this session.
mps_available = torch.backends.mps.is_available()
print(mps_available)

True


In [5]:
# Show whether this PyTorch installation was built with MPS support.
mps_built = torch.backends.mps.is_built()
print(mps_built)

True


In [6]:
# Fit the cubic polynomial y = a + b*x + c*x^2 + d*x^3 to sin(x) on [-pi, pi]
# with hand-written gradient descent on plain tensors (no autograd).
dtype = torch.float
# Prefer the MPS backend, but fall back to the CPU when MPS is unavailable so
# this cell still runs instead of failing at the first tensor creation.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Create input and target data: 2000 evenly spaced points and their sine
# (deterministic -- only the weights below are random).
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# The powers of x never change during training, so compute them once here
# instead of twice per iteration inside the loop.
x_squared = x ** 2
x_cubed = x ** 3

# Randomly initialize weights (each one is a 0-dim tensor)
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(2000):
    # Forward pass: compute predicted y
    y_pred = a + b * x_squared ** 0 * x + c * x_squared + d * x_cubed if False else a + b * x + c * x_squared + d * x_cubed
    residual = y_pred - y

    # Compute and print the sum-of-squares loss every 100th step only.
    # .item() copies the value to the host, so it is skipped on the other
    # steps; the last step (t == 1999) is a reporting step, so `loss` still
    # holds the final iteration's loss after the loop.
    if t % 100 == 99:
        loss = residual.pow(2).sum().item()
        print(t, loss)

    # Backprop to compute gradients of a, b, c, d with respect to loss
    grad_y_pred = 2.0 * residual
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x_squared).sum()
    grad_d = (grad_y_pred * x_cubed).sum()

    # Update weights in place using gradient descent
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d


print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

99 456.59283447265625
199 305.18231201171875
299 204.98226928710938
399 138.66831970214844
499 94.7783203125
599 65.72753143310547
699 46.49774169921875
799 33.76776885986328
899 25.340036392211914
999 19.76006507873535
1099 16.06531524658203
1199 13.61861801147461
1299 11.998224258422852
1399 10.924973487854004
1499 10.214006423950195
1599 9.742998123168945
1699 9.430902481079102
1799 9.224080085754395
1899 9.087005615234375
1999 8.996147155761719
Result: y = -0.003042224096134305 + 0.8440393805503845 x + 0.0005248354282230139 x^2 + -0.09152372926473618 x^3
