# float16 vs bfloat16 numerical properties comparison

This is a short notebook to help understand `fp16` vs `bfloat16`, in particular when converting a model trained
in `bfloat16` to mixed precision: it should be possible to look at the numbers to know which ranges
are safe and which need to be scaled or avoided.

I needed this to understand why t5/mt5 models that were pretrained in bfloat16 had a lot of `nan`/`inf` problems when finetuned in mixed precision.

In [1]:
import torch

This is the main function: it does very simple math in `bfloat16`, then converts the numbers to `float16` and shows the discrepancies.

In [2]:
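def find_mismatch(start, incr):
    # walk upward from `start` in steps of `incr` and report the first 8
    # bfloat16 values whose float16 conversion differs from the original
    bf16 = torch.tensor(start, dtype=torch.bfloat16)
    print(f"\nfp32 start={start:.2e} using increment={incr}")
    print(f"{'bfloat16':>18} {'float16':>18} {'diff':>8}")
    c = 0      # mismatches found so far
    tries = 0  # steps taken so far
    while c < 8:
        fp16 = bf16.to(torch.float16)
        if not (fp16 == bf16):
            # diff = float16 value minus the bfloat16 original, computed in float32
            print(f"{bf16:.16f} {fp16:.16f} {torch.sub(fp16.to(dtype=torch.float32), bf16):+.2e}")
            c += 1
        # the addition is rounded to bfloat16: if `incr` is too small relative to
        # the bf16 spacing at this magnitude, the value stops changing
        bf16 += incr
        tries += 1
        if tries >= 1e5:
            print(f"gave up finding mismatch after {tries} steps")
            return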
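Because `find_mismatch` adds `incr` to a bfloat16 tensor, every step is itself rounded to bfloat16. A vectorized sketch that sidesteps this (the name `find_mismatch_vec` and its `stop`/`num`/`max_print` parameters are hypothetical, not part of the original notebook) builds an evenly spaced float32 grid, rounds it to bfloat16 once, and compares the float16 cast of every distinct value:

```python
import torch

def find_mismatch_vec(start, stop, num=100_000, max_print=8):
    """Cast every distinct bf16 value on a grid over [start, stop] to fp16; report the ones that change."""
    grid = torch.linspace(start, stop, num, dtype=torch.float32)  # built in fp32, no bf16 accumulation
    bf16 = grid.to(torch.bfloat16).to(torch.float32).unique()     # distinct bf16 values, held exactly in fp32
    fp16 = bf16.to(torch.float16).to(torch.float32)               # the same values after an fp16 cast
    diff = fp16 - bf16
    bad = diff != 0
    n_bad = int(bad.sum())
    print(f"\n[{start:.2e}, {stop:.2e}]: {n_bad} of {bf16.numel()} bf16 values change in fp16")
    print(f"{'bfloat16':>18} {'float16':>18} {'diff':>8}")
    rows = zip(bf16[bad][:max_print].tolist(),
               fp16[bad][:max_print].tolist(),
               diff[bad][:max_print].tolist())
    for b, f, d in rows:
        print(f"{b:.16f} {f:.16f} {d:+.2e}")
    return n_bad
```

For example, `find_mismatch_vec(1e-8, 2e-8)` lists values that underflow to zero, while `find_mismatch_vec(1e-4, 1e3)` should report none.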

## Large numbers range

float16: ±65,504 (largest finite value)
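These limits, and the ones used in the sections below, can be read from `torch.finfo`. A small sketch printing them side by side for both half-precision formats and float32; the smallest subnormal is derived here as `tiny * eps` (smallest normal value times machine epsilon):

```python
import torch

limits = {}
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    fi = torch.finfo(dtype)
    limits[dtype] = {
        "max": fi.max,                      # largest finite value
        "min_normal": fi.tiny,              # smallest positive normal value
        "min_subnormal": fi.tiny * fi.eps,  # smallest positive subnormal value
        "eps": fi.eps,                      # gap between 1.0 and the next value
    }
    row = "  ".join(f"{k}={v:.4e}" for k, v in limits[dtype].items())
    print(f"{str(dtype):>15}  bits={fi.bits}  {row}")
```

bfloat16 shares float32's smallest normal value, while float16's normal range ends far earlier, which is what the sections below probe.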

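## underflow for fp16

when numbers become 0.0: fp16 values below about 2.98e-8 (half of the smallest fp16 subnormal, 5.96e-8) round to zero, while bf16, which has the same exponent range as fp32, still represents them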
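The cutoff sits halfway between 0 and the smallest fp16 subnormal, 2^-24 ≈ 5.96e-8. A quick check of a few values on either side (the chosen values are illustrative):

```python
import torch

smallest_sub = 2.0 ** -24  # smallest positive fp16 subnormal, ~5.96e-8
results = {}
for x in (1e-8, 2e-8, 3e-8, smallest_sub):
    fp16 = torch.tensor(x, dtype=torch.float16).item()
    bf16 = torch.tensor(x, dtype=torch.bfloat16).item()
    results[x] = (fp16, bf16)
    print(f"x={x:.3e}  fp16={fp16:.3e}  bf16={bf16:.3e}")
```

The first two values vanish in fp16, the last two snap to 2^-24, and bf16 keeps all four close to their original value.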

In [3]:
find_mismatch(1e-08, 1e-09)


fp32 start=1.00e-08 using increment=1e-09
          bfloat16            float16     diff
0.0000000100117177 0.0000000000000000 -1.00e-08
0.0000000110012479 0.0000000000000000 -1.10e-08
0.0000000119907781 0.0000000000000000 -1.20e-08
0.0000000129803084 0.0000000000000000 -1.30e-08
0.0000000139698386 0.0000000000000000 -1.40e-08
0.0000000150175765 0.0000000000000000 -1.50e-08
0.0000000160653144 0.0000000000000000 -1.61e-08
0.0000000171130523 0.0000000000000000 -1.71e-08


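## subnormal range for fp16

from 5.96e-8 (2^-24) up to the smallest normal value, 6.10e-5 (2^-14)

in this range fp16 has a fixed absolute step of 5.96e-8, so precision is very low, and subnormal arithmetic can be slow on some hardware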
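Since every fp16 value in this range is an integer multiple of 2^-24, the absolute step is fixed and the relative error of a cast grows as values shrink. A sketch measuring the relative error of the bf16-to-fp16 cast at a few magnitudes (illustrative values):

```python
import torch

step = 2.0 ** -24  # spacing of fp16 subnormals (also the smallest fp16 subnormal)
rel_errors = {}
for x in (1e-7, 1e-6, 1e-5, 1e-4):
    bf16 = torch.tensor(x, dtype=torch.bfloat16)
    ref = bf16.to(torch.float32)                      # the bf16 value, held exactly
    fp16 = bf16.to(torch.float16).to(torch.float32)   # the same value after an fp16 cast
    rel_errors[x] = ((fp16 - ref).abs() / ref).item()
    steps = fp16.item() / step                        # fp16 result in units of 2**-24
    print(f"x={x:.0e}  bf16={ref.item():.6e}  fp16={fp16.item():.6e}  "
          f"({steps:.0f} x 2**-24)  rel_err={rel_errors[x]:.2e}")
```

Near 1e-7 only a handful of multiples of 2^-24 are available, so the cast is off by tens of percent; by 1e-5 the bf16 spacing has grown to 2^-24 itself and the cast becomes exact.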

In [4]:
# very limited precision for fp16: only multiples of 5.96e-8 are representable here
find_mismatch(1e-07, 1e-08)


fp32 start=1.00e-07 using increment=1e-08
          bfloat16            float16     diff
0.0000001001171768 0.0000001192092896 +1.91e-08
0.0000001098960638 0.0000001192092896 +9.31e-09
0.0000001201406121 0.0000001192092896 -9.31e-10
0.0000001303851604 0.0000001192092896 -1.12e-08
0.0000001406297088 0.0000001192092896 -2.14e-08
0.0000001508742571 0.0000001788139343 +2.79e-08
0.0000001611188054 0.0000001788139343 +1.77e-08
0.0000001713633537 0.0000001788139343 +7.45e-09


In [5]:
# things starting to improve slightly for fp16
find_mismatch(1e-06, 1e-07)


fp32 start=1.00e-06 using increment=1e-07
          bfloat16            float16     diff
0.0000009983778000 0.0000010132789612 +1.49e-08
0.0000010952353477 0.0000010728836060 -2.24e-08
0.0000012889504433 0.0000013113021851 +2.24e-08
0.0000013858079910 0.0000013709068298 -1.49e-08
0.0000014826655388 0.0000014901161194 +7.45e-09
0.0000015795230865 0.0000015497207642 -2.98e-08
0.0000016763806343 0.0000016689300537 -7.45e-09
0.0000017732381821 0.0000017881393433 +1.49e-08


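## normal numbers start

min positive normal fp16: 6.104e-05 (`np.finfo(np.float16).tiny`)

in these ranges no mismatch is found: every bfloat16 value inside the fp16 normal range is exactly representable in fp16, because bf16 stores only 7 mantissa bits to fp16's 10 (this already holds from 2^-17 ≈ 7.6e-6 upward, where the bf16 spacing is at least the fp16 subnormal step of 2^-24). Note also that `bf16 += incr` is rounded to bfloat16, so each of these runs eventually stops advancing once `incr` drops below half the bf16 spacing; `find_mismatch(1e4, 1)`, for instance, stays at 9984 from the first step.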
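That claim can be checked exhaustively, since bfloat16 has only 65,536 bit patterns. A sketch that reinterprets every 16-bit pattern as bfloat16 with `Tensor.view(dtype)`, keeps the finite values inside the fp16 normal range, and round-trips them through float16:

```python
import torch

# every 16-bit pattern, reinterpreted as bfloat16 (view needs equal element sizes)
bits = torch.arange(-32768, 32768, dtype=torch.int32).to(torch.int16)
bf16 = bits.view(torch.bfloat16).to(torch.float32)

fp16_info = torch.finfo(torch.float16)
# keep values inside the fp16 normal range; NaN and inf fail these comparisons
in_range = (bf16.abs() >= fp16_info.tiny) & (bf16.abs() <= fp16_info.max)
vals = bf16[in_range]

roundtrip = vals.to(torch.float16).to(torch.float32)
print(f"checked {vals.numel()} bf16 values, all exact in fp16: {torch.equal(roundtrip, vals)}")
```

If the mantissa argument holds, this reports 7,680 values (30 exponents × 128 mantissas × 2 signs), all exact.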

In [6]:
find_mismatch(1e-05, 1e-06)
find_mismatch(1e-04, 1e-06)
find_mismatch(1e-03, 1e-04)
find_mismatch(1e-02, 1e-03)
find_mismatch(1e-01, 1e-02)
find_mismatch(1e1, 1e-06)
find_mismatch(1e1, 1e1)
find_mismatch(1e4, 1)


fp32 start=1.00e-05 using increment=1e-06
          bfloat16            float16     diff
gave up finding mismatch after 100000 steps

fp32 start=1.00e-04 using increment=1e-06
          bfloat16            float16     diff
gave up finding mismatch after 100000 steps

fp32 start=1.00e-03 using increment=0.0001
          bfloat16            float16     diff
gave up finding mismatch after 100000 steps

fp32 start=1.00e-02 using increment=0.001
          bfloat16            float16     diff
gave up finding mismatch after 100000 steps

fp32 start=1.00e-01 using increment=0.01
          bfloat16            float16     diff
gave up finding mismatch after 100000 steps

fp32 start=1.00e+01 using increment=1e-06
          bfloat16            float16     diff
gave up finding mismatch after 100000 steps

fp32 start=1.00e+01 using increment=10.0
          bfloat16            float16     diff
gave up finding mismatch after 100000 steps

fp32 start=1.00e+04 using increment=1
          bfloat16            float16     diff
gave up finding mismatch after 100000 steps

In [7]:
# hitting max range for fp16
find_mismatch(5e4, 1e3)

print("")


fp32 start=5.00e+04 using increment=1000.0
          bfloat16            float16     diff
66048.0000000000000000 inf +inf
67072.0000000000000000 inf +inf
68096.0000000000000000 inf +inf
69120.0000000000000000 inf +inf
70144.0000000000000000 inf +inf
71168.0000000000000000 inf +inf
72192.0000000000000000 inf +inf
73216.0000000000000000 inf +inf
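Anything that rounds above 65,504 becomes `inf` in fp16. One way to keep such values finite is to rescale them into range before casting and undo the scale afterwards. A sketch using bf16 values from the output above and a power-of-two factor (2**-4 is an arbitrary illustrative choice; a power of two only changes the exponent, so the rescaling itself adds no rounding error):

```python
import torch

x = torch.tensor([1000.0, 66048.0, 73216.0], dtype=torch.bfloat16)
print(x.to(torch.float16))                # 66048 and 73216 exceed 65,504 and become inf

scale = 2.0 ** -4                         # power of two: rescaling only shifts the exponent
scaled = (x * scale).to(torch.float16)    # now well inside the fp16 range
restored = scaled.to(torch.float32) / scale
print(scaled)
print(restored)
```

The trade-off is that scaling down pushes small values toward the fp16 underflow and subnormal ranges, so the factor has to suit the actual distribution of values.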



In [8]:
# --- roundoff ---
# unit roundoff (eps / 2), the worst-case relative rounding error:
# fp16 4.88e-4 (2^-11)
# bf16 3.91e-3 (2^-8)
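These numbers are the unit roundoffs, i.e. half of each format's machine epsilon: fp16 stores 10 mantissa bits (u = 2^-11), bf16 only 7 (u = 2^-8). A sketch deriving them from `torch.finfo` and showing the practical difference with an increment of 2^-9, which sits between the two:

```python
import torch

delta = 2.0 ** -9  # larger than fp16's unit roundoff (2**-11), smaller than bf16's (2**-8)
results = {}
for dtype in (torch.float16, torch.bfloat16):
    eps = torch.finfo(dtype).eps   # gap between 1.0 and the next representable value
    u = eps / 2                    # unit roundoff: worst-case relative rounding error
    results[dtype] = (torch.tensor(1.0, dtype=dtype) + delta).item()
    print(f"{str(dtype):>15}  eps={eps:.3e}  u={u:.3e}  1 + 2**-9 -> {results[dtype]}")
```

fp16 keeps the increment, while bf16 rounds the sum back to exactly 1.0.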

## Big numbers

`bfloat16` has terrible precision for numbers > 1, while fp16 represents such numbers exactly as long as they are integers up to 2048;
e.g. one can't represent 283 in bf16:

```
python -c "import torch; print( torch.tensor(283, dtype=torch.bfloat16) )"
tensor(284., dtype=torch.bfloat16)
```
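The underlying reason is the gap between neighbouring representable values, which equals the value's power of two times the machine epsilon. bf16 has 8 significant bits, so it represents every integer only up to 2^8 = 256 and steps by 2 in [256, 512); fp16 has 11, so every integer up to 2^11 = 2048 is exact. A sketch (the helper `spacing` is hypothetical and only valid for normal, nonzero x):

```python
import math
import torch

def spacing(x, dtype):
    """Gap between consecutive representable values of `dtype` around a normal value x."""
    exponent = math.floor(math.log2(abs(x)))
    return 2.0 ** exponent * torch.finfo(dtype).eps

for dtype in (torch.bfloat16, torch.float16):
    print(f"{str(dtype):>15}  spacing near 283: {spacing(283, dtype)}")

# the first integers that are no longer exactly representable
print(torch.tensor(257, dtype=torch.bfloat16).item())
print(torch.tensor(2049, dtype=torch.float16).item())
```

A spacing of 2 near 283 is why bf16 can only land on the even numbers 282, 284, 286 in the cell below.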

In [9]:
start = 280
fp32 = torch.tensor(start, dtype=torch.float32)
for i in range(3):
    bf16 = fp32.to(torch.bfloat16)
    bf16d = bf16
    while bf16 == bf16d:
        fp32 += 1
        bf16d = fp32.to(torch.bfloat16)
    print(f"{bf16d:.16f}")
# 282
# 284
# 286

282.0000000000000000
284.0000000000000000
286.0000000000000000


## Summation

bfloat16's low precision (its dynamic range is as wide as fp32's) means that for largish numbers a NN trained in bfloat16 **expects** coarse rounding, and when the precision is suddenly higher (as in fp16), unexpected outcomes can happen:

In [10]:
# small sum
print(torch.tensor(282, dtype=torch.bfloat16)+1) # 284
print(torch.tensor(282, dtype=torch.float16)+1)  # 283

# scale by 10 (like summing 10 copies): the bf16 error grows from 1 to 18
print(torch.tensor(283, dtype=torch.bfloat16)*10) # 2848
print(torch.tensor(283, dtype=torch.float16)*10)  # 2830

tensor(284., dtype=torch.bfloat16)
tensor(283., dtype=torch.float16)
tensor(2848., dtype=torch.bfloat16)
tensor(2830., dtype=torch.float16)
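The same coarse rounding makes accumulation stall: once the increment is no more than half the spacing of the running sum, adding it no longer changes the sum. A sketch that adds 1.0 to an accumulator 3,000 times in each dtype (the helper `count_up` and the step count are illustrative):

```python
import torch

def count_up(n, dtype):
    """Add 1.0 to an accumulator of `dtype` n times, rounding to `dtype` after every step."""
    acc = torch.tensor(0.0, dtype=dtype)
    for _ in range(n):
        acc += 1.0
    return acc.item()

for dtype in (torch.bfloat16, torch.float16, torch.float32):
    print(f"{str(dtype):>15}  {count_up(3000, dtype)}")
```

Here bf16 gets stuck at 256 and fp16 at 2048, while fp32 reaches 3000; this is one reason mixed precision setups commonly accumulate sums in float32.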


## disabling subnormal numbers in pytorch

In [11]:
torch.set_flush_denormal(True)
torch.tensor([1e-39], dtype=torch.float32)
torch.set_flush_denormal(False)
torch.tensor([1e-39], dtype=torch.float32)

# no effect for fp16: 1e-6 is an fp16 subnormal but is not flushed
torch.set_flush_denormal(True)
torch.tensor([1e-6], dtype=torch.float16)
torch.set_flush_denormal(False)
torch.tensor([1e-6], dtype=torch.float16)

torch.set_flush_denormal(True)
torch.tensor([1e-39], dtype=torch.bfloat16)
torch.set_flush_denormal(False)
torch.tensor([1e-39], dtype=torch.bfloat16)

True

tensor([0.])

True

tensor([1.0000e-39])

True

tensor([1.0133e-06], dtype=torch.float16)

True

tensor([1.0133e-06], dtype=torch.float16)

True

tensor([0.], dtype=torch.bfloat16)

True

tensor([1.0102e-39], dtype=torch.bfloat16)
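As the output shows, `torch.set_flush_denormal` had no effect on the fp16 subnormal 1e-6. A more direct check before casting a bf16 (or fp32) tensor to fp16 is to count how many of its values would overflow, underflow to zero, or land in the fp16 subnormal range. A minimal sketch (the function `fp16_cast_report` and the sample values are hypothetical):

```python
import torch

def fp16_cast_report(x):
    """Count how the nonzero finite values of x fare when cast to float16."""
    tiny = torch.finfo(torch.float16).tiny           # smallest positive normal fp16 value
    x32 = x.to(torch.float32)
    y32 = x32.to(torch.float16).to(torch.float32)    # value after the fp16 cast
    live = torch.isfinite(x32) & (x32 != 0)          # ignore zeros, inf and nan in the input
    return {
        "total": int(live.sum()),
        "overflow_to_inf": int((live & torch.isinf(y32)).sum()),
        "underflow_to_zero": int((live & (y32 == 0)).sum()),
        "fp16_subnormal": int((live & (y32 != 0) & (y32.abs() < tiny)).sum()),
    }

x = torch.tensor([1e-20, 1e-8, 1e-6, 1e-3, 1.0, 283.0, 7e4], dtype=torch.bfloat16)
report = fp16_cast_report(x)
print(report)
```

Run on a model's weights or activations, nonzero counts in the last three fields point at the ranges that would need scaling before switching a bf16-pretrained model to fp16.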