In [1]:
import torch
import copy

# MPS backend
- New backend is now available named MPS (Metal Perfomance Shaders)
- Can be used as drop in replacment for the 'cuda' device
- *Caveats* not all the ops are supported

In [4]:
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
print(device)

mps


In [35]:
import torch.utils.benchmark as benchmark
import torchvision.models as models

# Small helper which will be used for profiling inference
def run_in_inf_mode(model, tensor):
    with torch.inference_mode():
        model(tensor)

In [30]:
# This utility allow for benchmarking pytorch code
# https://pytorch.org/tutorials/recipes/recipes/benchmark.html

# Other Comparisons
# https://sebastianraschka.com/blog/2022/pytorch-m1-gpu.html

import torch.utils.benchmark as benchmark
import torchvision.models as models

In [36]:
# Profile VGG
# Import a network we know works with mps backend
vgg_weights = models.VGG16_BN_Weights.DEFAULT
vgg_preprocessor = vgg_weights.transforms()

vgg = models.vgg.vgg16_bn(weights=vgg_weights)
vgg_cpu = copy.deepcopy(vgg).eval()
vgg_mps = vgg.to(device).eval()

# Input for benchmarking:
x_cpu = vgg_preprocessor(torch.randint(0, 256, size=(64, 3, 224, 224)))
x_mps = x_cpu.clone().to(device)


print("Vgg".center(100,"*"))
t0 = benchmark.Timer(
    stmt='run_in_inf_mode(vgg_cpu, x_cpu)',
    globals={'run_in_inf_mode':run_in_inf_mode,'vgg_cpu':vgg_cpu, 'x_cpu': x_cpu},
    label='Vgg on CPU',
    num_threads=torch.get_num_threads())

t1 = benchmark.Timer(
    stmt='run_in_inf_mode(vgg_mps, x_mps)',
    globals={'run_in_inf_mode':run_in_inf_mode,'vgg_mps':vgg_mps, 'x_mps':x_mps},
    label='Vgg on MPS',
    num_threads=torch.get_num_threads())

m0 = t0.blocked_autorange(min_run_time=10)
m1 = t1.blocked_autorange(min_run_time=10)

print(m0)
print(m1)

************************************************Vgg*************************************************
<torch.utils.benchmark.utils.common.Measurement object at 0x12fcf2b30>
Vgg on CPU
  Median: 3.96 s
  3 measurements, 1 runs per measurement, 10 threads
<torch.utils.benchmark.utils.common.Measurement object at 0x29085ba00>
Vgg on MPS
  Median: 510.24 ms
  IQR:    3.37 ms (509.61 to 512.98)
  18 measurements, 1 runs per measurement, 10 threads


In [38]:
# Profile Efficient Net
# Import a network we know works with mps backend
eff_weights = models.EfficientNet_B0_Weights.DEFAULT
eff_preprocessor = eff_weights.transforms()

efficientnet_b0 = models.efficientnet_b0(weights=eff_weights)
effic_cpu = copy.deepcopy(efficientnet_b0).eval()
effic_mps = efficientnet_b0.to(device).eval()

# Input for benchmarking:
x_cpu = eff_preprocessor(torch.randint(0, 256, size=(64, 3, 224, 224)))
x_mps = x_cpu.clone().to(device)


print("EfficientNet".center(100,"*"))
t0 = benchmark.Timer(
    stmt='run_in_inf_mode(effic_cpu, x_cpu)',
    globals={'run_in_inf_mode':run_in_inf_mode,'effic_cpu':effic_cpu, 'x_cpu': x_cpu},
    label='EfficientNet on CPU',
    num_threads=torch.get_num_threads())

t1 = benchmark.Timer(
    stmt='run_in_inf_mode(effic_mps, x_mps)',
    globals={'run_in_inf_mode':run_in_inf_mode,'effic_mps':effic_mps, 'x_mps':x_mps},
    label='EfficientNet on MPS',
    num_threads=torch.get_num_threads())

m0 = t0.blocked_autorange(min_run_time=10)
m1 = t1.blocked_autorange(min_run_time=10)

print(m0)
print(m1)

********************************************EfficientNet********************************************
<torch.utils.benchmark.utils.common.Measurement object at 0x1483da200>
EfficientNet on CPU
  Median: 4.00 s
  3 measurements, 1 runs per measurement, 10 threads
<torch.utils.benchmark.utils.common.Measurement object at 0x156d602e0>
EfficientNet on MPS
  Median: 291.97 ms
  IQR:    3.31 ms (291.36 to 294.67)
  35 measurements, 1 runs per measurement, 10 threads


In [47]:
# Autograd works for these ops:
vgg_weights = models.VGG16_BN_Weights.DEFAULT
vgg_preprocessor = vgg_weights.transforms()

vgg = models.vgg.vgg16_bn(weights=vgg_weights)
vgg_cpu = copy.deepcopy(vgg)
vgg_mps = vgg.to(device)

# Input for benchmarking:
x_cpu = vgg_preprocessor(torch.randint(0, 256, size=(64, 3, 224, 224)))
x_mps = x_cpu.clone().to(device)

In [50]:
%%time
vgg_mps(x_mps).sum().backward()

CPU times: user 36.2 ms, sys: 2.44 s, total: 2.48 s
Wall time: 3.52 s


In [51]:
%%time
vgg_cpu(x_cpu).sum().backward()

CPU times: user 40.9 s, sys: 7.61 s, total: 48.5 s
Wall time: 15.2 s


In [44]:
# Not all operators are avialble though:
# See this issue for tracking: https://github.com/pytorch/pytorch/issues/77764
torch.equal(torch.randn(1,3,device=device),torch.randn(1,3,device=device))

NotImplementedError: The operator 'aten::equal' is not current implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

In [45]:
# Profile Resnet50 Breaks on MPS
res50_weights = models.ResNet50_Weights.IMAGENET1K_V2
preprocessor = res50_weights.transforms()
res50 = models.resnet50(weights=weights)
res_cpu = copy.deepcopy(res50).eval()
res_mps = res50.to(device).eval()

# Input for benchmarking:
x_cpu = res50(torch.randint(0, 256, size=(64, 3, 224, 224)))
x_mps = x_cpu.clone().to(device)

print("ResNet50".center(100,"*"))
t0 = benchmark.Timer(
    stmt='run_in_inf_mode(effic_cpu, x_cpu)',
    globals={'run_in_inf_mode':run_in_inf_mode,'res_cpu':res_cpu, 'x_cpu': x_cpu},
    label='ResNet50 on CPU',
    num_threads=torch.get_num_threads())

t1 = benchmark.Timer(
    stmt='run_in_inf_mode(effic_mps, x_mps)',
    globals={'run_in_inf_mode':run_in_inf_mode,'res_mps':res_mps, 'x_mps':x_mps},
    label='ResNet50 on MPS',
    num_threads=torch.get_num_threads())

m0 = t0.blocked_autorange(min_run_time=10)
m1 = t1.blocked_autorange(min_run_time=10)

print(m0)
print(m1)

NotImplementedError: The operator 'aten::_slow_conv2d_forward' is not current implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.