In [1]:
from IPython.display import Image

In [2]:
import torch

### TF32 (TensorFloat-32)

In [3]:
# Reference figure for the float-format discussion that follows (presumably the
# fp32 / tf32 / fp16 / bf16 bit layouts — TODO confirm against the image).
# NOTE(review): the image is served through a third-party WeChat-article proxy and may
# stop resolving; consider committing a local copy next to the notebook.
float_formats_figure_url = 'https://api.ibos.cn/v4/weapparticle/accesswximg?aid=86367&url=aHR0cHM6Ly9tbWJpei5xcGljLmNuL21tYml6X3BuZy9yY1Z1ZmlhVDAwVm5ZNTVQTnRnQ2dERVc0UzA5RVJKNERWSUtTRkZyRTZZOXVVUEFqQlJRTkpzc1owMjVjQUR2aFZvU2N5WVRJWnp3bFJnQnNKcmQ3bmcvNjQwP3d4X2ZtdD1wbmcmYW1w;from=appmsg'
Image(url=float_formats_figure_url, width=400)

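- bf16 的数值范围大（8 位指数，与 fp32 相同），但精度低（7 位尾数）；
- fp16 的精度较高（10 位尾数），但数值范围小（5 位指数）；
- tf32 兼顾两者：8 位指数（范围同 fp32/bf16）+ 10 位尾数（精度同 fp16）。它只是 Tensor Core 内部的计算模式，不是可以存储的 tensor dtype；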
- https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices
    - TF32 tensor cores are designed to achieve better performance on matmul and convolutions on torch.float32 tensors by rounding input data to have **10 bits of mantissa**, and **accumulating results with FP32 precision**, maintaining FP32 dynamic range.

In [4]:
# Without an explicit dtype, factory functions such as torch.randn produce the default
# floating dtype (float32 unless torch.set_default_dtype was changed). TF32 never
# appears here: it is a matmul/conv compute mode, not a storage dtype.
torch.randn((5,)).dtype

torch.float32

In [5]:
# Allow TF32 tensor-core math on Ampere+ GPUs for both float32 matmuls (cuda.matmul)
# and cuDNN convolutions (cudnn). These are process-wide flags, so every later float32
# matmul/conv in this notebook is affected. See also torch.set_float32_matmul_precision.
for tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
    tf32_backend.allow_tf32 = True

In [6]:
# TF32 is NOT a tensor dtype: it is a tensor-core compute mode that PyTorch may use
# internally for float32 matmul/convolution (controlled by
# torch.backends.cuda.matmul.allow_tf32 / torch.backends.cudnn.allow_tf32).
# There is no torch.tf32, and passing the string 'tf32' as dtype is rejected.
# The failure is the point of this cell, so catch and show it instead of letting it
# abort "Restart & Run All".
try:
    torch.randn(5, dtype='tf32')
except TypeError as exc:
    print(f"{type(exc).__name__}: {exc}")

TypeError: randn() received an invalid combination of arguments - got (int, dtype=str), but expected one of:
 * (tuple of ints size, *, torch.Generator generator, tuple of names names, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
 * (tuple of ints size, *, torch.Generator generator, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
 * (tuple of ints size, *, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
 * (tuple of ints size, *, tuple of names names, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)


In [7]:
# Float64 reference product used to measure the float32 / TF32 matmul error below.
# Shapes: (ROWS x INNER) @ (INNER x COLS). NOTE: each float64 operand holds
# ROWS * INNER * 8 bytes ≈ 4.2 GB here, so this cell needs a large-memory GPU.
ROWS, INNER, COLS = 10240, 10240 * 5, 10240
torch.manual_seed(0)  # reproducible operands, hence reproducible errors below
a_full = torch.randn(ROWS, INNER, dtype=torch.double, device='cuda')
b_full = torch.randn(INNER, COLS, dtype=torch.double, device='cuda')
ab_full = a_full @ b_full
# Mean |entry| of the reference product; used as the scale of the "relative" errors.
# Each entry is a sum of INNER products of N(0, 1) values, so this is roughly
# sqrt(2 * INNER / pi) ≈ 180 for INNER = 51200. (The 80.73 quoted in the PyTorch docs
# is for INNER = 10240 and does not apply here.)
mean = ab_full.abs().mean()

In [8]:
# Round the float64 operands down to float32: these are the inputs whose product is
# compared (with and without TF32) against the float64 reference.
a, b = (operand.to(torch.float32) for operand in (a_full, b_full))

In [9]:
# Matmul with TF32 allowed: on Ampere+ GPUs the tensor cores may round the float32
# inputs to a 10-bit mantissa while accumulating in FP32.
# The previous flag value is restored afterwards, so this cell leaves the notebook-wide
# setting untouched and can be re-run in any order.
# (The PyTorch docs quote ~0.016 s and relative error ~0.0022 on GA100 for a
# 10240x10240x10240 problem. The inner dimension here is 5x larger, so those numbers
# do not transfer; measure before quoting timings.)
prev_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
torch.backends.cuda.matmul.allow_tf32 = True
try:
    ab_tf32 = a @ b
finally:
    torch.backends.cuda.matmul.allow_tf32 = prev_allow_tf32
# "Relative" error = max |ab_tf32 - ab_full| / mean |ab_full|. ab_full is float64, so
# the subtraction promotes to float64 and the result is a float64 scalar tensor.
error = (ab_tf32 - ab_full).abs().max()
relative_error = error / mean
relative_error

tensor(0.0024, device='cuda:0', dtype=torch.float64)

In [10]:
# Same product with TF32 disabled, i.e. a plain FP32 matmul, for comparison.
# (The PyTorch docs quote ~0.11 s and relative error ~0.000039 on GA100 for a
# 10240x10240x10240 problem. The inner dimension here is 5x larger, so those numbers
# do not transfer; measure before quoting timings.)
prev_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
torch.backends.cuda.matmul.allow_tf32 = False
try:
    ab_fp32 = a @ b
finally:
    # Put the previous setting back; otherwise TF32 would stay silently disabled for
    # every later cell, contradicting the notebook-wide "enable TF32" configuration.
    torch.backends.cuda.matmul.allow_tf32 = prev_allow_tf32
# "Relative" error = max |ab_fp32 - ab_full| / mean |ab_full| (float64 scalar tensor).
error = (ab_fp32 - ab_full).abs().max()
relative_error = error / mean
relative_error

tensor(8.8293e-05, device='cuda:0', dtype=torch.float64)

### Inference data type