In [3]:
import torch
import math
# Report which accelerator backends this PyTorch install can use, one flag per line:
#   1. torch.backends.mps.is_available() - MPS can be used right now (needs macOS 12.3+)
#   2. torch.backends.mps.is_built()     - this PyTorch build was compiled with MPS support
#   3. torch.cuda.is_available()         - a CUDA device can be used
print(
    torch.backends.mps.is_available(),
    torch.backends.mps.is_built(),
    torch.cuda.is_available(),
    sep="\n",
)

True
True
False


In [2]:
# Fit y = sin(x) on [-pi, pi] with the cubic a + b*x + c*x^2 + d*x^3, using plain
# tensors and hand-derived gradients (no autograd), trained by gradient descent.
dtype = torch.float

# Prefer Apple's MPS backend, then CUDA, then CPU, so the cell still runs on
# machines where MPS is unavailable instead of failing at tensor creation.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

SEED = 0             # fixed seed so the random initial weights are reproducible
N_POINTS = 2000      # number of sample points in [-pi, pi]
N_STEPS = 2000       # gradient-descent iterations
LOG_EVERY = 100      # print the loss every LOG_EVERY steps
learning_rate = 1e-6

torch.manual_seed(SEED)

# Create input and target data
x = torch.linspace(-math.pi, math.pi, N_POINTS, device=device, dtype=dtype)
y = torch.sin(x)

# Powers of x do not change during training; compute them once.
x2 = x ** 2
x3 = x ** 3

# Randomly initialize weights (0-dim scalar tensors)
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

for t in range(N_STEPS):
    # Forward pass: compute predicted y
    y_pred = a + b * x + c * x2 + d * x3
    residual = y_pred - y

    # Sum-of-squares loss, kept on the device. .item() copies to the host and
    # synchronizes, so only call it when the value is actually printed.
    loss_t = residual.pow(2).sum()
    if (t + 1) % LOG_EVERY == 0:
        print(t, loss_t.item())

    # Backprop: gradients of the loss with respect to a, b, c, d
    grad_y_pred = 2.0 * residual
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x2).sum()
    grad_d = (grad_y_pred * x3).sum()

    # Update weights using gradient descent
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d

# Loss of the final iteration (computed before its weight update) as a Python float.
loss = loss_t.item()

print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

99 396.64990234375
199 276.2408447265625
299 193.41580200195312
399 136.37771606445312
499 97.054931640625
599 69.916015625
699 51.16594696044922
799 38.198246002197266
899 29.220487594604492
999 22.998886108398438
1099 18.683135986328125
1199 15.686593055725098
1299 13.604110717773438
1399 12.155586242675781
1499 11.147150039672852
1599 10.444525718688965
1699 9.95458984375
1799 9.612682342529297
1899 9.373908996582031
1999 9.207040786743164
Result: y = 0.01926896534860134 + 0.8493067622184753 x + -0.0033242187928408384 x^2 + -0.09227297455072403 x^3
