In [1]:
import torch
# Show the installed PyTorch build.
torch.__version__

'2.2.2+cu121'

## FP32 vs. TF32 

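- Starting in PyTorch 1.7, there are flags called `allow_tf32` that control whether TF32 may be used for float32 operations on CUDA: `torch.backends.cuda.matmul.allow_tf32` for matmuls and `torch.backends.cudnn.allow_tf32` for cuDNN.
    - The matmul flag defaults to True in PyTorch 1.7 to PyTorch 1.11, and False in PyTorch 1.12 and later. The cuDNN flag defaults to True.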

In [2]:
# Inspect both TF32 switches: the CUDA matmul flag, then the cuDNN flag.
torch.backends.cuda.matmul.allow_tf32, torch.backends.cudnn.allow_tf32

(False, True)

In [3]:
# Reference operands: two 10240 x 10240 standard-normal matrices, float64, on the GPU.
a_full = torch.randn((10240, 10240), device='cuda', dtype=torch.float64)
b_full = torch.randn((10240, 10240), device='cuda', dtype=torch.float64)

In [4]:
# Confirm the reference operand is stored in double precision.
a_full.dtype

torch.float64

In [5]:
# Float64 product of the reference operands, and the mean magnitude of its
# entries (used later to normalize the absolute errors).
ab_full = torch.matmul(a_full, b_full)
mean = torch.mean(torch.abs(ab_full))
mean

tensor(80.7286, device='cuda:0', dtype=torch.float64)

In [6]:
# Single-precision (float32) copies of the float64 operands.
a = a_full.to(torch.float32)
b = b_full.to(torch.float32)

### tf32 disabled

In [7]:
# Turn TF32 off for CUDA matmuls, then multiply the float32 operands.
torch.backends.cuda.matmul.allow_tf32 = False
ab_fp32 = torch.matmul(a, b)

In [8]:
error = (ab_fp32 - ab_full).abs().max()  # largest absolute deviation from the float64 product
relative_error = error / mean  # that deviation scaled by the mean |ab_full| entry
error, relative_error

(tensor(0.0028, device='cuda:0', dtype=torch.float64),
 tensor(3.4570e-05, device='cuda:0', dtype=torch.float64))

### tf32 enabled

In [9]:
# Turn TF32 on for CUDA matmuls, then repeat the same float32 product.
torch.backends.cuda.matmul.allow_tf32 = True
ab_tf32 = torch.matmul(a, b)

In [10]:
# Largest absolute deviation of the TF32 product from the float64 product,
# and that deviation scaled by the mean |ab_full| entry.
error = torch.max(torch.abs(ab_tf32 - ab_full))
relative_error = error / mean
error, relative_error

(tensor(0.1793, device='cuda:0', dtype=torch.float64),
 tensor(0.0022, device='cuda:0', dtype=torch.float64))