Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions compute/accelerator/benchmarks/mamf-finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import sys
import time
import torch
from packaging import version

# important: when changing how the benchmark measures things, bump up its version so that
# old reports can be differentiated from new ones
Expand Down Expand Up @@ -224,24 +225,25 @@ def func_wrapper(*args, **kwargs):
return decorator

total_iterations = num_iterations + num_warmup_iterations
if dtype == torch.float8_e4m3fn:

# fp8 requires special handling depending on the vendor:
# float8_e4m3fn for NVIDIA, float8_e4m3fnuz for AMD
fp8_dtypes = [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
if dtype in fp8_dtypes:
# the torch._scaled_mm API is different before PyTorch 2.5, so older versions are rejected
if version.parse(torch.__version__) < version.parse("2.5"):
raise ValueError("float8 dtypes require torch>=2.5")

A = torch.randn(m, k, dtype=torch.float32, device=device).contiguous()
B = torch.randn(n, k, dtype=torch.float32, device=device).contiguous().t()
scale = torch.tensor([1.0]).to(device)
A = A.to(dtype)
B = B.to(dtype)

A = A.to(torch.float8_e4m3fn)
B = B.to(torch.float8_e4m3fn)

# some torch versions require the scale arg and some don't, so discover which is required
try:
C = torch._scaled_mm(A, B)
@time_it(total_iterations)
def time_iterations():
C = torch._scaled_mm(A, B)
except:
@time_it(total_iterations)
def time_iterations():
C = torch._scaled_mm(A, B, scale, scale)
# Simplified call for PyTorch 2.5+
@time_it(total_iterations)
def time_iterations():
C = torch._scaled_mm(A, B, scale, scale)

else:
A = torch.randn(m, k, dtype=dtype, device=device).contiguous()
Expand Down