# Set up a new PyTorch GPU (Apple MPS) environment
## Create a new conda environment
`conda create -n torch-gpu python=3.8`
`conda activate torch-gpu`

## Install PyTorch nightly
`conda install pytorch torchvision torchaudio -c pytorch-nightly`

Install `ipykernel` into the environment so Jupyter can use it as a notebook kernel:
`conda install -n torch-gpu ipykernel --update-deps --force-reinstall`


In [1]:
# Check whether PyTorch can reach the Apple-silicon GPU through the MPS backend.
# (math is not used here; it is imported for the next cell.)
import math
import torch

# Reports whether MPS is available right now; per the original notes this
# requires macOS 12.3 or later.
mps_available = torch.backends.mps.is_available()
print(mps_available)

# Reports whether the installed PyTorch build was compiled with MPS support.
mps_built = torch.backends.mps.is_built()
print(mps_built)

True
True


In [2]:
# Fit y = sin(x) on [-pi, pi] with the cubic a + b*x + c*x^2 + d*x^3.
# Gradients are computed by hand (no autograd) and the weights are updated
# with plain gradient descent; the loss is printed every 100 iterations.
dtype = torch.float
# Prefer the Apple-silicon GPU (MPS); fall back to the CPU so the cell still
# runs on machines where MPS is unavailable instead of crashing.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Create random input and output data
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)
# The powers of x never change, so compute them once outside the loop.
x2 = x ** 2
x3 = x ** 3

# Randomly initialize weights
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(2000):
    # Forward pass: compute predicted y
    y_pred = a + b * x + c * x2 + d * x3
    residual = y_pred - y

    # Sum-of-squares loss, kept as a device tensor. .item() copies the value
    # back to the host (waiting for queued device work), so only call it when
    # the loss is actually printed rather than on every iteration.
    loss = residual.pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())

    # Backprop to compute gradients of a, b, c, d with respect to loss
    grad_y_pred = 2.0 * residual
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x2).sum()
    grad_d = (grad_y_pred * x3).sum()

    # Update weights using gradient descent
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d


print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

99 2227.208251953125
199 1532.0771484375
299 1055.947998046875
399 729.450439453125
499 505.30401611328125
599 351.24993896484375
699 245.2506561279297
799 172.23532104492188
899 121.88504028320312
999 87.12703704833984
1099 63.10728073120117
1199 46.49107360839844
1299 34.984832763671875
1399 27.009231567382812
1499 21.475540161132812
1599 17.63253402709961
1699 14.961262702941895
1799 13.102811813354492
1899 11.808754920959473
1999 10.906957626342773
Result: y = 0.04366664960980415 + 0.8758867383003235 x + -0.00753322197124362 x^2 + -0.09605374187231064 x^3
