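# PyTorch MPS test

Checks whether PyTorch's Apple MPS (Metal Performance Shaders) backend is built and available, then fits a cubic polynomial to sin(x) using manually computed gradients.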

In [4]:
import torch
import math

# Report whether this PyTorch build includes the MPS backend and whether an
# MPS device can actually be used on this machine.
print(f'MPS available: {torch.backends.mps.is_available()}, built with: {torch.backends.mps.is_built()}')

# Pick an accelerator: MPS first, then CUDA, falling back to the CPU.
# CUDA availability must be queried through `torch.cuda.is_available()`:
# `torch.backends.cuda` has no `is_available()`, so calling it would raise
# AttributeError on every machine without MPS.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

MPS avaiable: True, build with: True


In [7]:
# Fit y = sin(x) on [-pi, pi] with a cubic polynomial
#     y_pred = a + b*x + c*x^2 + d*x^3
# Gradients of the sum-of-squared-errors loss are computed by hand (no
# autograd), and the four scalar coefficients are updated with plain
# gradient descent.
dtype = torch.float
# Reuse the accelerator chosen in the setup cell (MPS / CUDA / CPU) instead of
# hard-coding the CPU, which would keep this MPS test from ever touching MPS.
device = torch.device(device)

SEED = 0          # fixed seed so the random initial coefficients are reproducible
N_POINTS = 2000   # number of sample points in [-pi, pi]
N_STEPS = 2000    # gradient-descent iterations
LOG_EVERY = 100   # print the loss once every LOG_EVERY steps
learning_rate = 1e-6

torch.manual_seed(SEED)

# Create input samples and target values
x = torch.linspace(-math.pi, math.pi, N_POINTS, device=device, dtype=dtype)
y = torch.sin(x)
# The powers of x never change during training, so compute them once.
x2 = x ** 2
x3 = x ** 3

# Randomly initialize the scalar (0-dim) coefficients
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

for t in range(N_STEPS):
    # Forward pass: compute predicted y
    y_pred = a + b * x + c * x2 + d * x3

    # Compute and print the loss. `.item()` copies the value to the host and
    # forces a device synchronization, so only do it when the loss is printed.
    if t % LOG_EVERY == LOG_EVERY - 1:
        loss = (y_pred - y).pow(2).sum().item()
        print(t, loss)

    # Backprop: gradients of sum((y_pred - y)^2) w.r.t. a, b, c, d
    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x2).sum()
    grad_d = (grad_y_pred * x3).sum()

    # Update weights using gradient descent
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d


print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

99 87.57077026367188
199 64.15534973144531
299 47.71859359741211
399 36.174774169921875
499 28.06369972229004
599 22.362226486206055
699 18.352909088134766
799 15.532451629638672
899 13.547626495361328
999 12.150397300720215
1099 11.16648006439209
1199 10.473418235778809
1299 9.985079765319824
1399 9.640910148620605
1499 9.398283958435059
1599 9.227201461791992
1699 9.106539726257324
1799 9.021421432495117
1899 8.961362838745117
1999 8.918980598449707
Result: y = -0.01050795242190361 + 0.8549855351448059 x + 0.0018127952935174108 x^2 + -0.09308070689439774 x^3
