In [1]:
import torch
from torch import nn

In [2]:
# Check that MPS is available; if it is not, explain why and fall back to the
# CPU. Falling back keeps `x`, `y`, `model` and `pred` defined on every
# machine, so later cells (e.g. the one displaying `pred`) still run after
# Restart & Run All instead of raising NameError.
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    device = mps_device
else:
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    print("Falling back to CPU.")
    device = torch.device("cpu")

# Seed the global RNG so the randomly initialised Linear layer (and therefore
# `pred`) is the same on every run.
torch.manual_seed(0)

# Create a tensor directly on the selected device.
# (On MPS, passing the string `device="mps"` is equivalent to `mps_device`.)
x = torch.ones(5, device=device)

# Element-wise operations run on whichever device the tensor lives on.
y = x * 2

# Move the model to the same device as its input; Module.to returns the module.
model = nn.Linear(5, 5).to(device)

# The forward pass runs on `device`.
pred = model(x)

In [3]:
# Show the forward-pass result computed by `model` in the previous cell
# (left as the last expression so Jupyter renders it with rich display).
pred

tensor([-0.4626,  0.7140,  0.2347, -0.4631,  0.8079], device='mps:0',
       grad_fn=<LinearBackward0>)