In [12]:
import os
import math
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [2]:
# Pick the compute backend, preferring CUDA, then MPS, then plain CPU.
# The MPS check is only reached when CUDA is not available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

Using cuda device


In [9]:
# Show the name of CUDA device 0. Querying the name without a CUDA device
# raises, so fall back to a plain notice when CUDA is unavailable.
gpu_name = (
    torch.cuda.get_device_name(0)
    if torch.cuda.is_available()
    else "no CUDA device available"
)
gpu_name

'NVIDIA GeForce RTX 4070 Ti'

In [5]:
# Build a three-element float32 tensor with torch.tensor (the recommended
# factory; torch.FloatTensor is a legacy type-specific constructor), then
# show whether it lives on a CUDA device.
X_train = torch.tensor([0.0, 1.0, 2.0], dtype=torch.float32)
X_train.is_cuda


False

In [6]:
# Rebind X_train to its counterpart on the selected device, then show
# whether the result lives on a CUDA device.
X_train = X_train.to(device=device)
X_train.is_cuda


True

In [13]:
# Run on the first CUDA GPU when one is present; otherwise fall back to the
# CPU so tensors created with device=device can still be allocated on
# machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Element type passed to the tensor factories (torch.float is float32).
dtype = torch.float

In [14]:
# Inputs: 2000 evenly spaced points spanning [-pi, pi].
x = torch.linspace(
    start=-math.pi,
    end=math.pi,
    steps=2000,
    dtype=dtype,
    device=device,
)
# Targets: the sine of every input point.
y = x.sin()

In [17]:
# Draw the four polynomial coefficients as 0-dimensional tensors sampled
# from a standard normal. They are generated in the order a, b, c, d.
a, b, c, d = (
    torch.randn((), device=device, dtype=dtype) for _ in range(4)
)

In [18]:
# Evaluate the cubic a + b*x + c*x^2 + d*x^3 at every point of x.
y_pred = a + b * x + c * x.pow(2) + d * x.pow(3)

In [24]:
# Display the polynomial's predictions for the current coefficients.
y_pred

tensor([-2.3295, -2.3123, -2.2952,  ..., 23.2061, 23.2713, 23.3365],
       device='cuda:0')

In [25]:
# Sum of squared differences between the predictions and the sine targets.
torch.sum(torch.pow(y_pred - y, 2))


tensor(77300.2969, device='cuda:0')

In [27]:
# Fit the third-order polynomial a + b*x + c*x^2 + d*x^3 to sin(x) using
# hand-written gradient descent on the sum-of-squares loss.
# NOTE: a, b, c and d are updated in place, so re-running this cell resumes
# from the current coefficients rather than from the random initial values.

learning_rate = 1e-6

# The powers of x are the same on every iteration, so compute them once.
x_sq = x ** 2
x_cu = x ** 3

for t in range(2000):

    # forward pass; the residual is reused by both the loss and the gradients
    y_pred = a + b * x + c * x_sq + d * x_cu
    residual = y_pred - y

    # Report the loss on every 100th step. .item() copies the scalar back to
    # Python, so it is only done on the steps that actually print.
    if t % 100 == 99:
        loss = residual.pow(2).sum().item()
        print(t, loss)

    # backprop: gradient of sum((y_pred - y)^2) with respect to a, b, c, d
    grad_y_pred = 2.0 * residual
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x_sq).sum()
    grad_d = (grad_y_pred * x_cu).sum()

    # gradient-descent step on each coefficient
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d


print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

99 3039.31982421875
199 2094.4619140625
299 1445.78466796875
399 999.9383544921875
499 693.1578979492188
599 481.8324890136719
699 336.1019287109375
799 235.49742126464844
899 165.9723358154297
999 117.87565612792969
1099 84.56925964355469
1199 61.482017517089844
1299 45.46310806274414
1399 34.338165283203125
1499 26.60502815246582
1599 21.224842071533203
1699 17.478525161743164
1799 14.867705345153809
1899 13.046825408935547
1999 11.775888442993164
Result: y = -0.05258806794881821 + 0.8352139592170715 x + 0.009072314947843552 x^2 + -0.09026837348937988 x^3
