In [21]:
from test_model import TestModel
import torch

torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

model = TestModel()
x = torch.randn(16, 2)

In [123]:
out_fp32 = model(x)

In [124]:
def print_param_dtype(model):
    for name, param in model.named_parameters():
        print(f"{name} is loaded in {param.dtype}")

In [125]:
print(f'{out.shape}\n')

print_param_dtype(model)
print(out_fp32)

torch.Size([16, 8])

linear1.weight is loaded in torch.float32
linear1.bias is loaded in torch.float32
bn1.weight is loaded in torch.float32
bn1.bias is loaded in torch.float32
linear2.weight is loaded in torch.float32
linear2.bias is loaded in torch.float32
bn2.weight is loaded in torch.float32
bn2.bias is loaded in torch.float32
tensor([[5.6570e-01, 2.6057e-01, 6.0959e-02, 0.0000e+00, 0.0000e+00, 6.4186e-01,
         1.0861e+00, 7.2191e-01],
        [0.0000e+00, 0.0000e+00, 3.4649e-01, 1.0093e+00, 6.2260e-01, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 3.4649e-01, 1.0093e+00, 6.2260e-01, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [9.5711e-01, 1.4652e+00, 0.0000e+00, 0.0000e+00, 8.1230e-01, 9.1862e-01,
         0.0000e+00, 0.0000e+00],
        [1.9353e+00, 1.6048e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.0986e+00,
         2.4531e+00, 1.3174e+00],
        [6.1562e-01, 0.0000e+00, 9.0805e-01, 0.0000e+00, 0.0000e+00, 6.0938e-01,
         1.

### Cast test model into float16 (Half).
- Outputs of the model are random, and change every time, even with a set seed. On such small values, the low-precision attribute of float16 bakes in rounding errors through multiple layers; in this case, batch norm, ReLU (removes negative values), linear. This generates different model outputs every run, in contrast to float32. float16 is also not supposed on some CPU kernels -> not an ideal data type in many cases.

In [126]:
model_fp16 = TestModel().to(torch.float16) # or TestModel().half()
print_param_dtype(model_fp16)
out_fp16 = model_fp16(x.to(torch.float16))
print(out_fp16)

linear1.weight is loaded in torch.float16
linear1.bias is loaded in torch.float16
bn1.weight is loaded in torch.float16
bn1.bias is loaded in torch.float16
linear2.weight is loaded in torch.float16
linear2.bias is loaded in torch.float16
bn2.weight is loaded in torch.float16
bn2.bias is loaded in torch.float16
tensor([[0.0000, 0.6367, 0.5112, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.5693, 0.0000, 0.0391, 0.8013, 0.9004, 0.8857, 0.0000, 0.9736],
        [1.9453, 0.0000, 0.1361, 0.9854, 0.9248, 0.9170, 0.0000, 1.2471],
        [0.0000, 0.0000, 0.0000, 0.8779, 0.0000, 0.0000, 0.7202, 0.8691],
        [0.4875, 0.1903, 0.9209, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3130, 1.5957, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3521, 0.6440, 0.0000, 0.4270, 0.4973, 1.0674, 0.0000],
        [0.0000, 0.3826, 0.2947, 0.0000, 0.5752, 0.6201, 1.1846, 0.0000],
        [0.5059, 0.7412, 0.0000, 0.0000, 0.4177, 0.2469, 0.2820, 0.0000],
        [0.0000, 0.440

### Cast test model into bfloat16.
- bfloat16 is stable, unlike float16, and achieves the same outputs after each run with a set seed.

In [127]:
from copy import deepcopy
model_bf16 = deepcopy(model)
model_bf16 = model_bf16.to(torch.bfloat16)
print_param_dtype(model_bf16)
out_bf16 = model_bf16(x.to(torch.bfloat16))
print(out_bf16)

linear1.weight is loaded in torch.bfloat16
linear1.bias is loaded in torch.bfloat16
bn1.weight is loaded in torch.bfloat16
bn1.bias is loaded in torch.bfloat16
linear2.weight is loaded in torch.bfloat16
linear2.bias is loaded in torch.bfloat16
bn2.weight is loaded in torch.bfloat16
bn2.bias is loaded in torch.bfloat16
tensor([[0.5703, 0.2412, 0.0532, 0.0000, 0.0000, 0.6211, 1.0859, 0.7109],
        [0.0000, 0.0000, 0.3457, 1.0000, 0.6211, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.3457, 1.0000, 0.6211, 0.0000, 0.0000, 0.0000],
        [0.9531, 1.4609, 0.0000, 0.0000, 0.8125, 0.9141, 0.0000, 0.0000],
        [1.9375, 1.6172, 0.0000, 0.0000, 0.0000, 2.0938, 2.4531, 1.3125],
        [0.6328, 0.0000, 0.9062, 0.0000, 0.0000, 0.6055, 1.5078, 1.8203],
        [0.0000, 0.0000, 0.8516, 0.0000, 0.0000, 0.0000, 0.2178, 0.9492],
        [0.0000, 0.0000, 0.7578, 0.0000, 0.0000, 0.0000, 0.0000, 0.5469],
        [0.0000, 0.0000, 0.3320, 0.9961, 0.6211, 0.0000, 0.0000, 0.0000],
        [0.507

### Compute output differences between fp32 and bf16

In [129]:
mean_diff = torch.abs(out_bf16 - out_fp32).mean().item()
max_diff = torch.abs(out_bf16 - out_fp32).max().item()

print(f"Mean diff: {mean_diff} | Max diff: {max_diff}")

Mean diff: 0.0030533424578607082 | Max diff: 0.029530048370361328


### Output logits hold similar values, small differences between the full-precision model and the bf16 model.
- Does not lead to a huge performance degradation, even on large models.