In [1]:
import torch
import math

In [2]:
# Can the MPS (Apple-silicon GPU) backend actually be used in this session?
mps_available = torch.backends.mps.is_available()
print(mps_available)

True


In [3]:
# Was this PyTorch binary compiled with MPS support? This alone does not
# guarantee a usable MPS device at runtime -- that is what is_available() checks.
mps_built = torch.backends.mps.is_built()
print(mps_built)

True


In [4]:
# Fit y = sin(x) on [-pi, pi] with the cubic a + b*x + c*x^2 + d*x^3,
# using plain tensors and hand-derived gradients (no autograd).
dtype = torch.float  # float32; the MPS backend does not support float64
# Prefer the Apple-silicon GPU, but fall back to CPU so the cell runs anywhere.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

SEED = 0          # fixed seed -> reproducible initial weights and printed losses
N_POINTS = 2000   # number of sample points in [-pi, pi]
N_STEPS = 2000    # gradient-descent iterations
LOG_EVERY = 100   # print the loss every LOG_EVERY steps
torch.manual_seed(SEED)

# Create random input and output data
x = torch.linspace(-math.pi, math.pi, N_POINTS, device=device, dtype=dtype)
y = torch.sin(x)
# Powers of x never change, so compute them once instead of twice per step.
x2 = x ** 2
x3 = x ** 3

# Randomly initialize weights (0-d tensors)
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(N_STEPS):
    # Forward pass: compute predicted y
    y_pred = a + b * x + c * x2 + d * x3
    residual = y_pred - y

    # Compute and print the sum-of-squares loss only on logging steps:
    # .item() copies to the host and forces a device sync, which would
    # otherwise stall the accelerator on every iteration.
    if t % LOG_EVERY == LOG_EVERY - 1:
        loss = residual.pow(2).sum().item()
        print(t, loss)

    # Backprop: d(loss)/d(y_pred) = 2 * (y_pred - y), then chain rule
    # through each polynomial term to get gradients of a, b, c, d.
    grad_y_pred = 2.0 * residual
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x2).sum()
    grad_d = (grad_y_pred * x3).sum()

    # Update weights in place using gradient descent
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d


print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

99 333.0504455566406
199 235.37448120117188
299 167.2323455810547
399 119.65789794921875
499 86.4192886352539
599 63.18062210083008
699 46.92267608642578
799 35.54131317138672
899 27.56899642944336
999 21.9814510345459
1099 18.063138961791992
1199 15.313982009887695
1299 13.384149551391602
1399 12.028831481933594
1499 11.076581954956055
1599 10.407220840454102
1699 9.93653678894043
1799 9.60543441772461
1899 9.372437477111816
1999 9.208417892456055
Result: y = -0.020283635705709457 + 0.8615069389343262 x + 0.0034992636647075415 x^2 + -0.09400834888219833 x^3
