In [1]:
from copy import deepcopy

import torch

from test_model import TestModel

SEED = 42

# Seed every RNG before the model is built: weight initialisation and the
# input below both draw from it, so their order must not change.
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # safe (silently ignored) without CUDA

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = TestModel().to(device).eval()  # .eval() returns the module itself
x = torch.randn(16, 2).to(device)

In [2]:
# Inference only: disable autograd so no graph is recorded for this forward pass.
with torch.no_grad():
    out_fp32 = model(x)

In [3]:
def print_param_dtype(model: torch.nn.Module, include_buffers: bool = False) -> None:
    """Print the dtype of every parameter of ``model``, one line per tensor.

    Args:
        model: Module whose tensors are inspected.
        include_buffers: If True, also report registered buffers
            (e.g. BatchNorm's ``running_mean`` / ``running_var``). These are not
            parameters, but they are also cast by ``Module.to(dtype)``, which
            matters when reasoning about low-precision batch norm.
    """
    for name, param in model.named_parameters():
        print(f"{name} is loaded in {param.dtype}")
    if include_buffers:
        for name, buf in model.named_buffers():
            print(f"{name} is loaded in {buf.dtype}")

In [4]:
# Shape (followed by a blank line), then per-parameter dtypes, then the raw outputs.
print(out_fp32.shape, end="\n\n")
print_param_dtype(model)
print(out_fp32)

torch.Size([16, 8])

linear1.weight is loaded in torch.float32
linear1.bias is loaded in torch.float32
bn1.weight is loaded in torch.float32
bn1.bias is loaded in torch.float32
linear2.weight is loaded in torch.float32
linear2.bias is loaded in torch.float32
bn2.weight is loaded in torch.float32
bn2.bias is loaded in torch.float32
tensor([[0.8541, 0.1149, 0.1141, 0.0000, 0.1887, 0.2379, 0.0000, 0.4394],
        [0.4347, 0.0000, 0.1950, 0.0000, 0.4956, 0.0000, 0.0000, 0.0000],
        [0.4341, 0.0000, 0.1766, 0.0000, 0.5147, 0.0000, 0.0000, 0.0000],
        [1.0019, 0.3521, 0.0000, 0.0000, 0.4951, 0.2923, 0.0000, 0.1099],
        [1.0963, 0.3803, 0.0000, 0.0000, 0.0000, 0.5342, 0.0000, 0.5385],
        [0.7830, 0.0401, 0.2393, 0.0000, 0.0000, 0.2328, 0.0000, 0.5426],
        [0.6002, 0.0000, 0.2864, 0.0000, 0.2534, 0.0000, 0.0000, 0.4049],
        [0.5767, 0.0000, 0.2725, 0.0000, 0.3464, 0.0000, 0.0000, 0.3579],
        [0.6742, 0.0000, 0.0204, 0.0000, 0.5273, 0.0000, 0.0000, 0.0123],
 

### Cast test model into float16 (Half).
- Casting the entire model to fp16 can interact badly with batch norm layers.
- float16 is also not supported by some CPU kernels.

In [7]:
model_fp16 = TestModel().to(dtype = torch.float16, device = device) # or TestModel().half()
model_fp16.eval()
x = x.to(torch.float16)

In [13]:
print_param_dtype(model_fp16)
out_fp16 = model_fp16(x)
print(out_fp16)

linear1.weight is loaded in torch.float16
linear1.bias is loaded in torch.float16
bn1.weight is loaded in torch.float16
bn1.bias is loaded in torch.float16
linear2.weight is loaded in torch.float16
linear2.bias is loaded in torch.float16
bn2.weight is loaded in torch.float16
bn2.bias is loaded in torch.float16
tensor([[0.1161, 0.2583, 0.0906, 0.0000, 0.2610, 0.2627, 0.1870, 0.0000],
        [0.4341, 0.2399, 0.2913, 0.0344, 0.5327, 0.1770, 0.1666, 0.0000],
        [0.4849, 0.2382, 0.3276, 0.0719, 0.5742, 0.1587, 0.1602, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.6655, 0.0000, 0.0000, 0.0000],
        [0.1161, 0.2583, 0.0906, 0.0000, 0.2610, 0.2627, 0.1870, 0.0000],
        [0.1161, 0.2583, 0.0906, 0.0000, 0.2610, 0.2627, 0.1870, 0.0000],
        [0.1552, 0.2537, 0.1135, 0.0000, 0.2791, 0.2722, 0.2109, 0.0000],
        [0.1766, 0.2510, 0.1261, 0.0000, 0.2888, 0.2776, 0.2241, 0.0000],
        [0.0000, 0.1134, 0.0000, 0.0000, 0.4780, 0.0280, 0.0000, 0.0000],
        [0.1161, 0.258

#### Instead, use mixed precision (`torch.amp.autocast`):

In [18]:
# The fp32 model's parameters are left as-is; autocast runs eligible ops
# (e.g. linear layers) in reduced precision on the fly. NOTE: the default
# autocast dtype is float16 on CUDA but bfloat16 on CPU, so on a CPU-only
# machine this output is actually bf16 mixed precision.
with torch.no_grad(), torch.amp.autocast(device):
    out_fp16_autocast = model(x)

print(out_fp16_autocast)

tensor([[0.8540, 0.1149, 0.1140, 0.0000, 0.1888, 0.2378, 0.0000, 0.4390],
        [0.4348, 0.0000, 0.1949, 0.0000, 0.4956, 0.0000, 0.0000, 0.0000],
        [0.4341, 0.0000, 0.1765, 0.0000, 0.5146, 0.0000, 0.0000, 0.0000],
        [1.0020, 0.3521, 0.0000, 0.0000, 0.4951, 0.2925, 0.0000, 0.1100],
        [1.0967, 0.3809, 0.0000, 0.0000, 0.0000, 0.5347, 0.0000, 0.5386],
        [0.7832, 0.0404, 0.2394, 0.0000, 0.0000, 0.2334, 0.0000, 0.5425],
        [0.6006, 0.0000, 0.2861, 0.0000, 0.2537, 0.0000, 0.0000, 0.4048],
        [0.5767, 0.0000, 0.2725, 0.0000, 0.3467, 0.0000, 0.0000, 0.3579],
        [0.6743, 0.0000, 0.0203, 0.0000, 0.5273, 0.0000, 0.0000, 0.0123],
        [0.7778, 0.0338, 0.2205, 0.0000, 0.0494, 0.2075, 0.0000, 0.5088],
        [0.7363, 0.0000, 0.0278, 0.0000, 0.5000, 0.0000, 0.0000, 0.1305],
        [0.6284, 0.0000, 0.0658, 0.0000, 0.5176, 0.0000, 0.0000, 0.0363],
        [0.5996, 0.0000, 0.0719, 0.0000, 0.4763, 0.0000, 0.0000, 0.0000],
        [0.4529, 0.0000, 0.3083, 0.000

- The autocast outputs are very close to fp32, but this might not hold as well for deeper models.

### Cast test model into bfloat16.
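- bfloat16 is a more numerically stable alternative to fp16: it keeps fp32's 8-bit exponent (and therefore its dynamic range) and gives up mantissa precision instead.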

In [30]:
# Copy the fp32 model (same weights), then cast its floating-point parameters
# and buffers to bfloat16.
model_bf16 = deepcopy(model).to(torch.bfloat16)
print_param_dtype(model_bf16)
with torch.no_grad():
    out_bf16 = model_bf16(x.to(torch.bfloat16))
print(out_bf16)

linear1.weight is loaded in torch.bfloat16
linear1.bias is loaded in torch.bfloat16
bn1.weight is loaded in torch.bfloat16
bn1.bias is loaded in torch.bfloat16
linear2.weight is loaded in torch.bfloat16
linear2.bias is loaded in torch.bfloat16
bn2.weight is loaded in torch.bfloat16
bn2.bias is loaded in torch.bfloat16
tensor([[0.5703, 0.2402, 0.0532, 0.0000, 0.0000, 0.6211, 1.0781, 0.7109],
        [0.0000, 0.0000, 0.3457, 1.0000, 0.6211, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.3457, 1.0000, 0.6211, 0.0000, 0.0000, 0.0000],
        [0.9531, 1.4609, 0.0000, 0.0000, 0.8086, 0.9141, 0.0000, 0.0000],
        [1.9375, 1.6094, 0.0000, 0.0000, 0.0000, 2.0938, 2.4375, 1.3203],
        [0.6328, 0.0000, 0.9062, 0.0000, 0.0000, 0.6055, 1.5000, 1.8203],
        [0.0000, 0.0000, 0.8516, 0.0000, 0.0000, 0.0000, 0.2178, 0.9492],
        [0.0000, 0.0000, 0.7578, 0.0000, 0.0000, 0.0000, 0.0000, 0.5469],
        [0.0000, 0.0000, 0.3320, 0.9961, 0.6211, 0.0000, 0.0000, 0.0000],
        [0.507

### Compute output differences between fp32 and bf16

In [31]:
# Element-wise absolute error, computed once; bf16 - fp32 promotes to fp32.
abs_diff = (out_bf16 - out_fp32).abs()
mean_diff = abs_diff.mean().item()
max_diff = abs_diff.max().item()

print(f"Mean diff: {mean_diff} | Max diff: {max_diff}")

Mean diff: 0.00309927249327302 | Max diff: 0.029530048370361328


### The bf16 model's outputs hold values similar to the full-precision model's, with only small differences.
- Casting to bf16 does not lead to a large degradation in output quality, even on large models.