# [TensorFloat-32 (TF32) on Ampere (and later) devices](https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices)

In [1]:
import torch

Each experiment below multiplies two 10240x10240 tensors, i.e. 2 * 104,857,600 elements (float32 or float64, depending on the section).

## Float 32 Matrix multiplication using CPU

In [2]:
# Seed torch's random number generators (all devices) so the random operands,
# and therefore every number displayed by the cells below, can be reproduced
# with Restart Kernel -> Run All on the same hardware / PyTorch version.
torch.manual_seed(0)

# Two 10240x10240 standard-normal operands in the default dtype (float32),
# allocated on the CPU.
a = torch.randn(10240, 10240, device='cpu')
b = torch.randn(10240, 10240, device='cpu')

In [3]:
# float32 matrix product computed on the CPU.
ab = a @ b
# For independent standard-normal operands every entry of the product is
# approximately N(0, 10240), so the mean absolute entry should land close to
# sqrt(2 * 10240 / pi) ~= 80.74.
# The value is only displayed: the name `mean` is deliberately not bound here,
# because the relative-error cells further down divide by the `mean` of the
# float64 GPU product, and re-running this cell must not silently replace it.
ab.abs().mean()

tensor(80.7383)

## Float 64 Matrix multiplication using CPU

In [4]:
# A fresh pair of standard-normal CPU operands, this time in double precision
# (torch.float64 is the same dtype object as torch.double).
a_full = torch.randn((10240, 10240), dtype=torch.float64, device='cpu')
b_full = torch.randn((10240, 10240), dtype=torch.float64, device='cpu')

In [5]:
# float64 matrix product computed on the CPU.
ab_full = a_full @ b_full
# Expected to be close to sqrt(2 * 10240 / pi) ~= 80.74 for standard-normal
# operands of this size.
# The value is only displayed: the name `mean` is deliberately not bound here,
# because the relative-error cells further down divide by the `mean` of the
# float64 GPU product, and re-running this cell must not silently replace it.
ab_full.abs().mean()

tensor(80.7250, dtype=torch.float64)

## Float 64 Matrix multiplication using GPU

In [6]:
# Double-precision standard-normal operands created directly on the GPU.
# This rebinds a_full / b_full, replacing the CPU tensors created earlier.
a_full = torch.randn((10240, 10240), dtype=torch.float64, device='cuda')
b_full = torch.randn((10240, 10240), dtype=torch.float64, device='cuda')

In [7]:
# float64 product on the GPU; the float32 results below are measured against it.
ab_full = torch.matmul(a_full, b_full)
# Mean absolute entry of the reference product. It stays bound to `mean`
# because the relative-error cells divide by it.
# Roughly 80.73 (the linked PyTorch note quotes 80.7277 for its own run).
mean = torch.mean(torch.abs(ab_full))
mean

tensor(80.7264, device='cuda:0', dtype=torch.float64)

## Float 32 Matrix multiplication using GPU in TF32 mode

In [8]:
# Cast the float64 GPU operands down to float32, so the float32 products below
# are computed from the same values as the float64 reference `ab_full`.
# This rebinds a / b (previously the float32 CPU operands).
a = a_full.to(torch.float32)
b = b_full.to(torch.float32)

In [9]:
# Allow float32 matmuls on CUDA to use TF32; this is a global backend switch
# and stays in effect until it is changed again.
torch.backends.cuda.matmul.allow_tf32 = True
# float32 product with TF32 permitted. The linked PyTorch note quotes about
# 0.016s for this matmul on GA100; nothing in this notebook times it.
ab_tf32 = torch.matmul(a, b)
# Mean absolute entry, displayed for comparison with the float64 `mean`.
torch.mean(torch.abs(ab_tf32))

tensor(80.7242, device='cuda:0')

In [10]:
# Worst-case absolute deviation of the TF32 product from the float64 reference.
# Subtracting a float64 tensor promotes the float32 ab_tf32, which is why the
# displayed result has dtype float64. Expected on the order of 0.17.
error = torch.max(torch.abs(ab_tf32 - ab_full))
error

tensor(0.1758, device='cuda:0', dtype=torch.float64)

In [11]:
# Worst-case TF32 error relative to the typical entry magnitude: `mean` is the
# mean absolute entry of the float64 GPU product. Expected around 0.0022.
relative_error = torch.div(error, mean)
relative_error

tensor(0.0022, device='cuda:0', dtype=torch.float64)

## Float 32 Matrix multiplication using GPU with TF32 disabled

In [12]:
# Forbid TF32 again so the same float32 matmul runs in full float32 precision.
torch.backends.cuda.matmul.allow_tf32 = False
# float32 product with TF32 disabled. The linked PyTorch note quotes about
# 0.11s for this matmul on GA100; nothing in this notebook times it.
ab_fp32 = torch.matmul(a, b)
# Mean absolute entry, displayed for comparison with the float64 `mean`.
torch.mean(torch.abs(ab_fp32))

tensor(80.7264, device='cuda:0')

In [13]:
# Worst-case absolute deviation of the full-precision float32 product from the
# float64 reference. Note that this rebinds `error`, replacing the TF32 figure
# computed earlier. Expected on the order of 0.003.
error = torch.max(torch.abs(ab_fp32 - ab_full))
error

tensor(0.0029, device='cuda:0', dtype=torch.float64)

In [14]:
# Worst-case float32 (TF32 disabled) error relative to `mean`, the mean absolute
# entry of the float64 GPU product. Expected around 0.00004, far below the
# TF32 figure.
relative_error = torch.div(error, mean)
relative_error

tensor(3.6156e-05, device='cuda:0', dtype=torch.float64)