diff --git a/compute/accelerator/benchmarks/mamf-finder.py b/compute/accelerator/benchmarks/mamf-finder.py
index 2556dacc..e1d09d01 100755
--- a/compute/accelerator/benchmarks/mamf-finder.py
+++ b/compute/accelerator/benchmarks/mamf-finder.py
@@ -28,6 +28,7 @@
 import sys
 import time
 import torch
+from packaging import version
 
 # important: when changing how the benchmark measures things bump up its version, so that the old
 # reports could be differentiated from the new ones
@@ -224,24 +225,25 @@ def func_wrapper(*args, **kwargs):
         return decorator
 
     total_iterations = num_iterations + num_warmup_iterations
-    if dtype == torch.float8_e4m3fn:
+
+    # fp8 requires special handling depending on the vendor:
+    # float8_e4m3fn for nvidia, float8_e4m3fnuz for amd
+    fp8_dtypes = [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
+    if dtype in fp8_dtypes:
+        # torch._scaled_mm is different before pt-2.5
+        if version.parse(torch.__version__) < version.parse("2.5"):
+            raise ValueError("float8 dtypes require torch>=2.5")
+
         A = torch.randn(m, k, dtype=torch.float32, device=device).contiguous()
         B = torch.randn(n, k, dtype=torch.float32, device=device).contiguous().t()
         scale = torch.tensor([1.0]).to(device)
-        A = A.to(torch.float8_e4m3fn)
-        B = B.to(torch.float8_e4m3fn)
-
-        # some torch versions require the scale arg, some don't so discover which is required
-        try:
-            C = torch._scaled_mm(A, B)
-            @time_it(total_iterations)
-            def time_iterations():
-                C = torch._scaled_mm(A, B)
-        except:
-            @time_it(total_iterations)
-            def time_iterations():
-                C = torch._scaled_mm(A, B, scale, scale)
+        A = A.to(dtype)
+        B = B.to(dtype)
+        # Simplified call for PyTorch 2.5+
+        @time_it(total_iterations)
+        def time_iterations():
+            C = torch._scaled_mm(A, B, scale, scale)
     else:
         A = torch.randn(m, k, dtype=dtype, device=device).contiguous()
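
# ---
# Addendum (not part of the patch): a minimal standalone sketch of the
# pt-2.5+ torch._scaled_mm call pattern the patch settles on. Assumes an
# fp8-capable accelerator; the device string, matrix sizes, and out_dtype
# below are illustrative choices, not taken from the benchmark itself.
import torch

device = "cuda"
fp8_dtype = torch.float8_e4m3fn  # nvidia; amd would use torch.float8_e4m3fnuz

m, k, n = 1024, 1024, 1024  # fp8 matmul kernels typically want dims divisible by 16
A = torch.randn(m, k, dtype=torch.float32, device=device).contiguous().to(fp8_dtype)
# B must be column-major, hence building it as (n, k) and transposing
B = torch.randn(n, k, dtype=torch.float32, device=device).contiguous().t().to(fp8_dtype)

# in pt-2.5+ the per-tensor scales are required positional arguments
scale = torch.tensor([1.0], device=device)
C = torch._scaled_mm(A, B, scale, scale, out_dtype=torch.bfloat16)
print(C.shape)  # torch.Size([1024, 1024])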