# Pytorch using Metal backend

See the following links:

1. <https://pytorch.org/docs/master/notes/mps.html>
1. <https://stackoverflow.com/questions/73583061/what-is-the-synchronise-function-for-mac-mps>

In [1]:
import torch

# Check that MPS is available
if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")

else:
    mps_device = torch.device("mps")

    # Create a Tensor directly on the mps device
    x = torch.ones(5, device=mps_device)
    # Or
    x = torch.ones(5, device="mps")

    # Any operation happens on the GPU
    y = x * 2

In [2]:
import torch
if torch.has_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("using", device, "device")

import time

matrix_size = 32*512

x = torch.randn(matrix_size, matrix_size)
y = torch.randn(matrix_size, matrix_size)

print("***********cpu speed***************")
start = time.time()
result_cpu = torch.matmul(x,y)
print(time.time()-start)
print("verify device: ", result_cpu.device)


x_mps = x.to(device)
y_mps = y.to(device)

print("***********mps speed***************")
start = time.time()
result_mps = torch.matmul(x_mps,y_mps)
print(time.time()-start)
print("verify device: ", result_mps.device)

assert torch.all(result_mps.eq(result_mps)) == True

using mps device
***********cpu speed***************
36.39276432991028
verify device:  cpu
***********mps speed***************
0.020927906036376953
verify device:  mps:0


In [3]:
from timeit import timeit
import numpy as np

def numpy_sum(n=100_000_000):
    return np.sum(np.arange(n))

numpy_value = numpy_sum()

print('numpy sum: ', timeit(numpy_sum, number=1))

numpy sum:  0.28134358299939777


In [4]:
import torch

def torch_sum(n=100_000_000):
    x = torch.from_numpy(np.arange(n))
    x.to(torch.device("mps"))
    return torch.sum(x).item()

torch_value = torch_sum()

assert torch_value == numpy_value, f"Error: both sums are not equal, torch value is {torch_value}, while expected value is {numpy_value}, delta = {numpy_value - torch_value}"

print('Torch sum: ', timeit(torch_sum, number=1))

Torch sum:  0.21476745799918717
