# PyTorch FP8 (fused) matmul tutorial

In [13]:
import numpy as np
import torch

# Local GPU device
torch.device(0), torch.cuda.get_device_name(0)

(device(type='cuda', index=0), 'NVIDIA H100 PCIe')

### `_scaled_mm` FP8 matmul wrapper

PyTorch `_scaled_mm` defintion: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/Blas.cpp#L1176C1-L1176C16

`cublasLtMatmul` not supported `E5M2 @ E5M2` matmuls: https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp8#cublasltmatmul 

TorchAO is using `_scaled_mm` function for FP8 integration: https://github.com/pytorch/ao/blob/main/torchao/float8/float8_python_api.py

In [31]:
M, N, K = 128, 64, 256

a = torch.randn((M, K), dtype=torch.float16, device='cuda')
# Transpose as cuBLASLt requires column major on `rhs`
b = torch.randn((N, K), dtype=torch.float16, device='cuda').t()

# FP8 inputs & scales
# a_fp8 = a.to(torch.float8_e4m3fn)
# b_fp8 = b.to(torch.float8_e4m3fn)

a_fp8 = a.to(torch.float8_e5m2)
b_fp8 = b.to(torch.float8_e5m2)

a_scale = torch.ones((), dtype=torch.float32, device='cuda')
b_scale = torch.ones((), dtype=torch.float32, device='cuda')

# FP8 matmul
out = torch._scaled_mm(a_fp8, b_fp8, 
                       out_dtype=torch.float16,
                       scale_a=a_scale,
                       scale_b=b_scale,
                       use_fast_accum=True,
                       bias=None,
                       scale_result=None)

RuntimeError: Multiplication of two Float8_e5m2 matrices is not supported

In [28]:
out.shape, out.dtype

(torch.Size([128, 64]), torch.float16)