In [4]:
import torch
import math
# Report MPS (Apple-GPU backend) support:
# - is_available(): whether MPS can be used on this machine (macOS 12.3+).
# - is_built(): whether this PyTorch installation was built with MPS activated.
print(torch.backends.mps.is_available())
print(torch.backends.mps.is_built())

# Fall back to the CPU when MPS is unavailable so the cell still runs on
# machines without it instead of failing at the first tensor allocation.
dtype = torch.float
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

SEED = 0          # fixed seed so the random initial coefficients are reproducible
N_POINTS = 2000   # number of sample points in [-pi, pi]
N_STEPS = 2000    # gradient-descent iterations
LOG_EVERY = 100   # print the loss once every LOG_EVERY steps
learning_rate = 1e-6

torch.manual_seed(SEED)

# Create input and target data: we fit y = sin(x) on [-pi, pi]
x = torch.linspace(-math.pi, math.pi, N_POINTS, device=device, dtype=dtype)
y = torch.sin(x)

# The powers of x never change, so compute them once instead of on every step.
x2 = x ** 2
x3 = x ** 3

# Randomly initialize the cubic's coefficients (0-dim tensors)
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

for t in range(N_STEPS):
    # Forward pass: compute predicted y = a + b x + c x^2 + d x^3
    y_pred = a + b * x + c * x2 + d * x3

    # Compute and print the sum-of-squares loss. `.item()` copies the value back
    # to the host and forces a device sync, so only do it when it is printed.
    if t % LOG_EVERY == LOG_EVERY - 1:
        loss = (y_pred - y).pow(2).sum().item()
        print(t, loss)

    # Backprop (by hand) to compute gradients of a, b, c, d with respect to the loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x2).sum()
    grad_d = (grad_y_pred * x3).sum()

    # Update weights using plain gradient descent (in place; no autograd involved)
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d


print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

True
True
99 1965.019775390625
199 1382.242919921875
299 973.533447265625
399 686.7455444335938
499 485.4105224609375
599 344.00006103515625
699 244.635498046875
799 174.78558349609375
899 125.66407775878906
999 91.10664367675781
1099 66.78667449951172
1199 49.66564178466797
1299 37.608795166015625
1399 29.115638732910156
1499 23.13118553161621
1599 18.913305282592773
1699 15.939769744873047
1799 13.842973709106445
1899 12.364091873168945
1999 11.320807456970215
Result: y = -0.05199597030878067 + 0.8474884033203125 x + 0.008970167487859726 x^2 + -0.09201432764530182 x^3


In [6]:
import torch
from unet import UNet
# Build a 3-D U-Net and push one random volume through it as a shape smoke test.
torch.manual_seed(0)  # reproducible weights and input
model = UNet(in_channels=1,
             out_channels=1,
             n_blocks=5,
             start_filters=32,
             activation='relu',
             normalization='batch',
             conv_mode='same',
             dim=3)
# `torch.no_grad()` does not change layer behaviour: without eval(), the batch
# normalization layers would use batch statistics and update their running stats
# during this "inference" pass.
model.eval()

# Create a single random input volume; presumably laid out as
# (batch, channel, depth, height, width) for dim=3 -- confirm against unet.UNet.
x = torch.randn(size=(1, 1, 128, 96, 128), dtype=torch.float32)

with torch.no_grad():
    out = model(x)

print(f'Out: {out.shape}')
print(f'In: {x.shape}')

# The recorded run shows conv_mode='same' preserving the input shape; keep that
# as an explicit check rather than relying on reading the printed shapes.
assert out.shape == x.shape, f"UNet changed the shape: {x.shape} -> {out.shape}"

Out: torch.Size([1, 1, 128, 96, 128])
In: torch.Size([1, 1, 128, 96, 128])
