In [8]:
import torch
from torch.profiler import profile, record_function, ProfilerActivity
from torch.nn.functional import scaled_dot_product_attention as sdpa

# Profile scaled_dot_product_attention (SDPA) on Apple's MPS backend to see
# which internal ATen implementation the dispatcher selects for it.
# torch.backends.mps.is_available() is the long-standing availability check;
# torch.mps.is_available() only exists in newer PyTorch releases.
if torch.backends.mps.is_available():
    # NOTE: the module-level `sdpa` import would already have raised if SDPA
    # were missing, so this check is informational and always True here.
    has_sdpa = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
    print(f'Has SDPA: {has_sdpa}')
    device = 'mps'

    # Create pseudo Q, K, V tensors for attention, laid out as
    # (batch, heads, seq_len, head_dim).
    B, H, T, D = 2, 4, 128, 64
    dtype = torch.bfloat16
    q = torch.randn(B, H, T, D, device=device, dtype=dtype)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Only host-side (CPU) activity is recorded. That is enough to see which
    # ATen ops are dispatched for SDPA.
    activities = [ProfilerActivity.CPU]
    with profile(activities=activities, record_shapes=True) as prof:
        with record_function('sdpa_test'):
            y = sdpa(q, k, v, is_causal=True)
            # MPS work is queued asynchronously; wait for it so the profiled
            # range covers the actual execution, not just the dispatch.
            torch.mps.synchronize()

    # Report the internal SDPA implementation ops. The substring
    # '_scaled_dot_product' matches private ops such as
    # aten::_scaled_dot_product_attention_math_for_mps, but not the public
    # aten::scaled_dot_product_attention entry point (preceded by '::').
    for evt in prof.key_averages():
        if '_scaled_dot_product' not in evt.key:
            continue
        if '_scaled_dot_product_attention_math' in evt.key:
            print(f'SDPA is implemented by {evt.key} within the SDPBackend.MATH backend.\n'
                  'It is a C++ wrapper around MPS/Metal kernels and does not provide a flash attention implementation.')
        else:
            print(evt.key)
else:
    # Make the skip explicit instead of silently doing nothing.
    print('MPS is not available; skipping the SDPA profiling test.')

Hash SDPA: True
SDPA is implemented by aten::_scaled_dot_product_attention_math_for_mps within the SDPBackend.MATH backend.
It is a C++ wrapper around MPS/Metal kernels and does not provide a flash attention implementation.
