In [3]:
import math
import torch
import torchvision
import torch.nn as nn
# Report the installed library versions.
print(f"PyTorch version: {torch.__version__}")
print(f"Torchvision version: {torchvision.__version__}")

# Report MPS backend status: first whether it is usable at runtime, then
# whether this PyTorch build includes MPS support (one value per line).
print(torch.backends.mps.is_available(), torch.backends.mps.is_built(), sep="\n")

PyTorch version: 2.1.0
Torchvision version: 0.16.0
True
True


In [4]:
# Fit y = sin(x) on [-pi, pi] with the cubic a + b*x + c*x^2 + d*x^3 using
# hand-derived gradients and plain gradient descent (no autograd).

dtype = torch.float  # torch.float is float32

# Prefer MPS, then CUDA, else CPU. Hard-coding torch.device("mps") makes the
# first tensor allocation fail on any machine without the MPS backend.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Run configuration.
SEED = 0            # makes the random initial weights reproducible across re-runs
N_POINTS = 2000     # number of sample points on [-pi, pi]
N_STEPS = 2000      # gradient-descent iterations
LOG_EVERY = 100     # print the loss every LOG_EVERY steps
learning_rate = 1e-6

torch.manual_seed(SEED)  # seeds the RNG on all devices

# Create the (deterministic) input grid and the target values
x = torch.linspace(-math.pi, math.pi, N_POINTS, device=device, dtype=dtype)
y = torch.sin(x)

# Powers of x do not change between steps, so compute them once.
x2 = x ** 2
x3 = x ** 3

# Randomly initialize the scalar (0-d) weights
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

for t in range(N_STEPS):
    # Forward pass: predicted y and its residual against the target
    y_pred = a + b * x + c * x2 + d * x3
    residual = y_pred - y

    # Only materialise the sum-of-squares loss when it is logged: .item()
    # copies the scalar back to the host, which forces a device sync on
    # GPU/MPS backends, so doing it every step stalls the loop.
    if t % LOG_EVERY == LOG_EVERY - 1:
        loss = residual.pow(2).sum().item()
        print(t, loss)

    # Backprop: d(loss)/d(y_pred) = 2 * residual, chained through each term
    grad_y_pred = 2.0 * residual
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x2).sum()
    grad_d = (grad_y_pred * x3).sum()

    # Gradient-descent update, in place on the weight tensors
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d


print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

99 439.94549560546875
199 294.1949768066406
299 197.73196411132812
399 133.88507080078125
499 91.62335968017578
599 63.64701843261719
699 45.12591552734375
799 32.86334991455078
899 24.743946075439453
999 19.367324829101562
1099 15.806583404541016
1199 13.44822883605957
1299 11.886041641235352
1399 10.851102828979492
1499 10.165414810180664
1599 9.711034774780273
1699 9.409889221191406
1799 9.210271835327148
1899 9.077938079833984
1999 8.990174293518066
Result: y = -0.003160446882247925 + 0.8691930770874023 x + 0.0005452316836453974 x^2 + -0.09510160982608795 x^3
