## Understand torch.float32

My original notebook where I tried to figure this out is down below under "Original playing around to understand torch.float32".

I realized a few months later that I was still a little confused, so I took the time to understand it a bit better, this time "cheating" by looking things up. I'm putting this new code here because it's a much cleaner reference for the future. These functions also make it easy to look at the bytes/bits of other types.

In [2]:
import torch

### scalars to bytes and bits

In [4]:
def scalar_tensor(scalar, torch_type):
    return torch.tensor(scalar, dtype=torch_type)

In [5]:
def get_byte_vals(scalar_tensor):
    # return bytes from most to least significant
    # storage is in the machine's native byte order; on little-endian machines
    # (x86, Apple Silicon; check sys.byteorder) that is least significant first, so reverse it
    return list(reversed(scalar_tensor.untyped_storage()))

In [6]:
def bytes_to_bit_string(byte_vals):
    return ' '.join([format(byte_val, '08b') for byte_val in byte_vals])

In [7]:
def print_scalar(scalar, torch_type):
    print(f"{scalar} as {torch_type} -> {bytes_to_bit_string(get_byte_vals(scalar_tensor(scalar, torch_type)))}")
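As a torch-free cross-check of `print_scalar`, Python's standard `struct` module can pack a value as an IEEE 754 single-precision float. The `'>f'` format forces big-endian order, so the bytes already come out most significant first and no reversal is needed. `float32_bit_string` is a hypothetical helper name, not part of the notebook. One difference worth knowing: for values beyond the float32 range, `struct.pack` raises `OverflowError`, whereas `torch.tensor` produces `inf` (as seen later in this notebook).

```python
import struct

def float32_bit_string(x: float) -> str:
    # '>f' = big-endian 4-byte float, so the first byte is the most significant one
    # (it holds the sign bit and the top 7 exponent bits)
    return ' '.join(format(b, '08b') for b in struct.pack('>f', x))

for x in (1.0, -1.0, 3.5):
    print(f"{x} as float32 -> {float32_bit_string(x)}")
```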

### bytes and bits to scalars

In [9]:
def get_scalar_tensor_from_bytes(byte_vals, torch_type):
    # form a scalar tensor from a list of bytes from most to least significant
    return_tensor = torch.tensor(0, dtype=torch_type)
    n = len(byte_vals)
    assert n == torch_type.itemsize
    for i in range(n):
        return_tensor.untyped_storage()[i] = byte_vals[n-1-i] # storage is least significant byte first on little-endian machines, so reverse
    return return_tensor
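The reverse direction can also be checked without torch: `struct.unpack('>f', ...)` reads four bytes, most significant first, as a float32 and returns it widened to a Python float. `float32_from_bytes` is a hypothetical name that mirrors `get_scalar_tensor_from_bytes(...).item()`.

```python
import struct

def float32_from_bytes(byte_vals):
    # byte_vals: 4 ints, most significant byte first (same order as get_byte_vals)
    assert len(byte_vals) == 4
    (value,) = struct.unpack('>f', bytes(byte_vals))
    return value

print(float32_from_bytes([0b00111111, 0b10000000, 0, 0]))        # 1.0
print(float32_from_bytes([0b01111111, 0b01111111, 0xFF, 0xFF]))  # 3.4028234663852886e+38
```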

### get float components from float32

In [10]:
def get_sign(byte_vals):
    # example:
    # 10111101 01001100 11001100 11001101
    # ^
    return 1 if byte_vals[0] & 0b10000000 == 0 else -1

In [11]:
def get_exponent(byte_vals):
    # example:
    # 10111101 01001100 11001100 11001101
    #  ^^^^^^^ ^
    return (((byte_vals[0] & 0b01111111) << 1) | (byte_vals[1] >> 7)) - 127  # in a real implementation 0 and 255 have special meanings

In [12]:
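def get_mantissa(byte_vals):
    # example:
    # 10111101 01001100 11001100 11001101 
    #           ^^^^^^^ ^^^^^^^^ ^^^^^^^^
    # these are the 23 stored fraction bits; the result is shifted left by one extra bit
    # (a trailing 0), so callers divide by 2 ** 24 rather than 2 ** 23
    return ((byte_vals[1] & 0b01111111) << 17) | (byte_vals[2] << 9) | (byte_vals[3] << 1)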

In [13]:
byte_vals = get_byte_vals(scalar_tensor(1, torch.float32))
print(bytes_to_bit_string(byte_vals))
sign, exponent, mantissa = get_sign(byte_vals), get_exponent(byte_vals), get_mantissa(byte_vals)
print(sign, exponent, mantissa)
sign * 2 ** exponent * (1 + mantissa / 0b1000000000000000000000000)

00111111 10000000 00000000 00000000
1 0 0


1.0
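A sketch of the more common way to pull the fields apart: treat the four bytes as one 32-bit unsigned integer and use shifts and masks. Here the fraction comes out as the plain 23-bit field, so it is divided by `2**23`; `get_mantissa` above returns the same field shifted left by one bit and divides by `2**24`, which gives the same fraction (its `m` values are exactly twice these). `float32_fields` is a hypothetical helper.

```python
import struct

def float32_fields(x: float):
    # reinterpret the float32 bytes as an unsigned 32-bit integer (big-endian on both sides)
    (bits,) = struct.unpack('>I', struct.pack('>f', x))
    sign = bits >> 31                       # 1 bit
    biased_exponent = (bits >> 23) & 0xFF   # 8 bits, bias 127
    fraction = bits & 0x7FFFFF              # 23 bits; the leading 1 is implied, not stored
    return sign, biased_exponent, fraction

s, e, f = float32_fields(3.0)
print(s, e - 127, f)                                    # 0 1 4194304
print((-1) ** s * 2.0 ** (e - 127) * (1 + f / 2**23))  # 3.0
```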

### calculate float from components

In [14]:
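def get_float_value(sign, exponent, mantissa):
    # mantissa comes from get_mantissa (shifted left one bit), hence the division by 2 ** 24
    # caveat: exponent == -127 with a nonzero mantissa is a subnormal number (no implied
    # leading 1, exponent fixed at -126); that case falls through to the last line and
    # is decoded incorrectly (see the 1.401298464324817e-45 example below)
    if exponent == -127 and mantissa == 0:
        return 0
    elif exponent == 128:
        if mantissa == 0:
            return sign * float('inf')
        else:
            return float('nan')
    return sign * 2 ** exponent * (1 + mantissa / 0b1000000000000000000000000)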

### examples

In [15]:
for n in (0, 1, -1, 2, 3, 3.5, 3.0001, 0.5, 0.00005, 0.25):
    byte_vals = get_byte_vals(scalar_tensor(n, torch.float32))
    sign, exponent, mantissa = get_sign(byte_vals), get_exponent(byte_vals), get_mantissa(byte_vals)
    print(f"{n} -> {bytes_to_bit_string(byte_vals)} -> s={sign}, e={exponent}, m={mantissa} -> {get_float_value(sign, exponent, mantissa)}")

0 -> 00000000 00000000 00000000 00000000 -> s=1, e=-127, m=0 -> 0
1 -> 00111111 10000000 00000000 00000000 -> s=1, e=0, m=0 -> 1.0
-1 -> 10111111 10000000 00000000 00000000 -> s=-1, e=0, m=0 -> -1.0
2 -> 01000000 00000000 00000000 00000000 -> s=1, e=1, m=0 -> 2.0
3 -> 01000000 01000000 00000000 00000000 -> s=1, e=1, m=8388608 -> 3.0
3.5 -> 01000000 01100000 00000000 00000000 -> s=1, e=1, m=12582912 -> 3.5
3.0001 -> 01000000 01000000 00000001 10100011 -> s=1, e=1, m=8389446 -> 3.0000998973846436
0.5 -> 00111111 00000000 00000000 00000000 -> s=1, e=-1, m=0 -> 0.5
5e-05 -> 00111000 01010001 10110111 00010111 -> s=1, e=-15, m=10710574 -> 4.999999873689376e-05
0.25 -> 00111110 10000000 00000000 00000000 -> s=1, e=-2, m=0 -> 0.25


### examples that make it easy to see the extremes

In [18]:
big =                  get_scalar_tensor_from_bytes([0b01111111, 0b00000000, 0b00000000, 0b00000001], torch.float32).item()
even_bigger =          get_scalar_tensor_from_bytes([0b01111111, 0b01111111, 0b11111111, 0b11111111], torch.float32).item()
infinity =             get_scalar_tensor_from_bytes([0b01111111, 0b10000000, 0b00000000, 0b00000000], torch.float32).item()
nan =                  get_scalar_tensor_from_bytes([0b01111111, 0b10000000, 0b00000000, 0b00000001], torch.float32).item()
small =                get_scalar_tensor_from_bytes([0b00000000, 0b10000000, 0b00000000, 0b00000001], torch.float32).item()
smaller =              get_scalar_tensor_from_bytes([0b00000000, 0b10000000, 0b00000000, 0b00000000], torch.float32).item()
even_smaller =         get_scalar_tensor_from_bytes([0b00000000, 0b00000000, 0b00000000, 0b00000001], torch.float32).item()
zero =                 get_scalar_tensor_from_bytes([0b00000000, 0b00000000, 0b00000000, 0b00000000], torch.float32).item()
negative_even_bigger = get_scalar_tensor_from_bytes([0b11111111, 0b01111111, 0b11111111, 0b11111111], torch.float32).item()
negative_infinity =    get_scalar_tensor_from_bytes([0b11111111, 0b10000000, 0b00000000, 0b00000000], torch.float32).item()

In [19]:
for n in (big, even_bigger, infinity, nan, small, smaller, even_smaller, zero, negative_even_bigger, negative_infinity):
    byte_vals = get_byte_vals(scalar_tensor(n, torch.float32))
    sign, exponent, mantissa = get_sign(byte_vals), get_exponent(byte_vals), get_mantissa(byte_vals)
    print(f"{n} -> {bytes_to_bit_string(byte_vals)} -> s={sign}, e={exponent}, m={mantissa} -> {get_float_value(sign, exponent, mantissa)}")

1.7014120374287884e+38 -> 01111111 00000000 00000000 00000001 -> s=1, e=127, m=2 -> 1.7014120374287884e+38
3.4028234663852886e+38 -> 01111111 01111111 11111111 11111111 -> s=1, e=127, m=16777214 -> 3.4028234663852886e+38
inf -> 01111111 10000000 00000000 00000000 -> s=1, e=128, m=0 -> inf
nan -> 01111111 11000000 00000000 00000001 -> s=1, e=128, m=8388610 -> nan
1.175494490952134e-38 -> 00000000 10000000 00000000 00000001 -> s=1, e=-126, m=2 -> 1.175494490952134e-38
1.1754943508222875e-38 -> 00000000 10000000 00000000 00000000 -> s=1, e=-126, m=0 -> 1.1754943508222875e-38
1.401298464324817e-45 -> 00000000 00000000 00000000 00000001 -> s=1, e=-127, m=2 -> 5.87747245476067e-39
0.0 -> 00000000 00000000 00000000 00000000 -> s=1, e=-127, m=0 -> 0
-3.4028234663852886e+38 -> 11111111 01111111 11111111 11111111 -> s=-1, e=127, m=16777214 -> -3.4028234663852886e+38
-inf -> 11111111 10000000 00000000 00000000 -> s=-1, e=128, m=0 -> -inf
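The `1.401298464324817e-45` row above decodes to `5.87747245476067e-39` because `get_float_value` treats a stored exponent of all zeros like any other exponent. For a stored exponent of 0, IEEE 754 drops the implied leading 1 and fixes the exponent at -126 (these are the subnormal numbers), and zero keeps its sign. Here is a sketch of a decoder that handles those cases, working from the raw 32-bit pattern with the standard 23-bit fraction and checked against `struct`. `decode_float32` is a hypothetical name.

```python
import math
import struct

def decode_float32(bits: int) -> float:
    sign = -1.0 if bits >> 31 else 1.0
    biased_exponent = (bits >> 23) & 0xFF
    fraction = bits & 0x7FFFFF  # 23-bit stored fraction
    if biased_exponent == 0xFF:
        # all-ones exponent: infinity if the fraction is 0, otherwise NaN
        return sign * math.inf if fraction == 0 else math.nan
    if biased_exponent == 0:
        # zero or subnormal: no implied leading 1, exponent fixed at -126
        return sign * (fraction / 2**23) * 2.0 ** -126
    # normal number: implied leading 1, exponent = stored exponent - 127
    return sign * (1 + fraction / 2**23) * 2.0 ** (biased_exponent - 127)

for bits in (0x00000001, 0x00800000, 0x7F7FFFFF, 0x80000000):
    (reference,) = struct.unpack('>f', bits.to_bytes(4, 'big'))
    print(hex(bits), decode_float32(bits), reference)
```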


## Original playing around to understand torch.float32

I thought it would be a 5-minute side trek to understand how `torch.float32`s work. It turned out to take much longer: I got confused, and didn't want to give up and look it up. I moved the exploration to this notebook from `understand-memory-needed.ipynb` so that notebook wouldn't get too confusing.

In [1]:
import torch

In [2]:
torch.float32, torch.float16, torch.bfloat16

(torch.float32, torch.float16, torch.bfloat16)

I've noticed he makes all the non-integer tensors either `torch.float32` or `torch.bfloat16`. There also seems to be casting (and autocasting?) going on, and there was the compilation warning in the prior challenge: "UserWarning: Quadro RTX 4000 does not support bfloat16 compilation natively, skipping". (For example, if a GPU doesn't support bfloat16 compilation, is it better not to use bfloat16 at all?) This makes me think bfloat16 is something newer and important to understand.

Let me guess before I look it up.

It seems clear why, if you can "get away" with 16 bits instead of 32 bits per parameter, that's better: you can store twice as many parameters in the same memory. I also imagine it's faster (by what factor?) to multiply, for example, two 16-bit tensors than two 32-bit tensors of the same shape. There are half as many bits to read in and compute with.

There must be a way to figure out how many significant (binary) digits you need. Like are there some models with some params where 32 is not enough and you want 64? Let's come back to that.

The question is: what makes bfloat16 different from float16? I doubt it's just the same thing.

I imagine the way floating point works is that 1 bit is used for the sign, X bits are used for all the significant digits (not sure what you call that, like the 123 in 1.23 x 10^-2), and Y bits are used for the "-2" part. That "-2" part is like a signed integer, so for example if 10 bits went to it, it would allow -512 to 511. So for float32, 1 + X + Y = 32, and I imagine X is much bigger than Y. Does that mean it's possible to get an error when trying to create a float that is too large?
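For reference (this is the looked-up answer, not part of the guessing): IEEE 754 float32 uses 1 sign bit, 8 exponent bits and 23 fraction bits; float16 uses 1, 5 and 10; bfloat16 uses 1, 8 and 7, i.e. float32's exponent with a much shorter fraction. The exponent is stored with a bias (127 for 8 bits) rather than as a two's complement integer, and the all-zeros and all-ones exponents are reserved. A sketch that derives each format's largest finite value, smallest normal value and epsilon from those widths alone (`format_limits` is a hypothetical helper; in PyTorch, `torch.finfo(dtype)` reports the same limits as `.max`, `.tiny` and `.eps`):

```python
def format_limits(exponent_bits: int, fraction_bits: int):
    bias = 2 ** (exponent_bits - 1) - 1
    # largest finite value: all-ones fraction with the largest non-reserved exponent (= bias)
    largest = (2 - 2.0 ** -fraction_bits) * 2.0 ** bias
    # smallest normal value: stored exponent 1, fraction 0
    smallest_normal = 2.0 ** (1 - bias)
    # gap between 1.0 and the next representable value
    epsilon = 2.0 ** -fraction_bits
    return largest, smallest_normal, epsilon

for name, e_bits, f_bits in (("float32", 8, 23), ("float16", 5, 10), ("bfloat16", 8, 7)):
    largest, smallest_normal, epsilon = format_limits(e_bits, f_bits)
    print(f"{name:>8}: max={largest:.5g} smallest normal={smallest_normal:.5g} eps={epsilon:.3g}")
```

So float16 tops out at 65504, while bfloat16 keeps float32's range but gets a much coarser epsilon.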

In [3]:
torch.tensor(1.23, dtype=torch.float32)

tensor(1.2300)

In [4]:
torch.tensor(1.23e10, dtype=torch.float32)

tensor(1.2300e+10)

In [5]:
torch.tensor(1.23e32, dtype=torch.float32)

tensor(1.2300e+32)

In [6]:
torch.tensor(1.23e38, dtype=torch.float32)

tensor(1.2300e+38)

In [7]:
torch.tensor(1.23e39, dtype=torch.float32)

tensor(inf)

So after a certain point it stores it as inf? But why 39? What happens when it gets too small? Does it store it as 0?

In [8]:
torch.tensor(1.23e-38, dtype=torch.float32)

tensor(1.2300e-38)

In [9]:
torch.tensor(1.23e-39, dtype=torch.float32)

tensor(1.2300e-39)

In [10]:
torch.tensor(1.23e-45, dtype=torch.float32)

tensor(1.4013e-45)

In [11]:
torch.tensor(1.23e-46, dtype=torch.float32)

tensor(0.)

Yes, but not symmetrical. But also ~38 + ~45 = ~83 which isn't close to a power of 2. Hmm.
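The asymmetry comes from the subnormal numbers. Normal float32 values run from 2^-126 up to just under 2^128, but when the stored exponent is all zeros the implied leading 1 is dropped, which extends the bottom of the range down another 23 powers of 2, to 2^-149. That is also why 1.23e-39, below the smallest normal value of about 1.1755e-38, still survived above. Converting those powers of 2 to powers of 10 gives the ~38 and ~45:

```python
import math

log10_2 = math.log10(2)
print(f"largest float32    ~ 2^128  = 10^{128 * log10_2:.2f}")   # 10^38.53
print(f"smallest normal      2^-126 = 10^{-126 * log10_2:.2f}")  # 10^-37.93
print(f"smallest subnormal   2^-149 = 10^{-149 * log10_2:.2f}")  # 10^-44.85
print(2.0 ** -149)  # 1.401298464324817e-45, what 1.23e-45 was rounded to
```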

In [12]:
torch.tensor(1.2345678e-45, dtype=torch.float32)

tensor(1.4013e-45)

That's odd.

In [13]:
torch.tensor(1.2345678e-44, dtype=torch.float32)

tensor(1.2612e-44)

In [14]:
torch.tensor(1.2345678e-43, dtype=torch.float32)

tensor(1.2331e-43)

Let's look at the actual bytes. Maybe that will help.

In [15]:
torch.tensor(1, dtype=torch.float32).untyped_storage()

 0
 0
 128
 63
[torch.storage.UntypedStorage(device=cpu) of size 4]

In [16]:
torch.tensor(1e10, dtype=torch.float32).untyped_storage()

 249
 2
 21
 80
[torch.storage.UntypedStorage(device=cpu) of size 4]

In [17]:
def bits(scalar_tensor):
    return [format(byte_val, '08b') for byte_val in scalar_tensor.untyped_storage()]

In [18]:
bits(torch.tensor(1, dtype=torch.float32))

['00000000', '00000000', '10000000', '00111111']

In [19]:
bits(torch.tensor(1e0, dtype=torch.float32)) # hope this is the same!

['00000000', '00000000', '10000000', '00111111']

In [20]:
bits(torch.tensor(2, dtype=torch.float32))

['00000000', '00000000', '00000000', '01000000']

In [21]:
bits(torch.tensor(3, dtype=torch.float32))

['00000000', '00000000', '01000000', '01000000']

This seems weird. How about 0?

In [22]:
bits(torch.tensor(0, dtype=torch.float32))

['00000000', '00000000', '00000000', '00000000']

Ah, wait, I'm using decimal scientific notation, but it must be binary. Let's write out 0 - 8 decimal as "binary" scientific notation.

Since e means 10^ as in 1.23e2, let me use E to mean 2^ (with the exponent also written in binary, so 1E10 is 1 x 2^2):
```
 0 is 0     0E0
 1 is 1     1E0
 2 is 10    1E1
 3 is 11    1.1E1
 4 is 100   1E10
 5 is 101   1.01E10
 6 is 110   1.10E10
 7 is 111   1.11E10
 8 is 1000  1E11
```
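A sketch that produces this "binary scientific notation" for non-negative integers, writing both the mantissa and the exponent in binary and dropping trailing zeros (so 6 comes out as `1.1E10` rather than `1.10E10`). `binary_sci` is a hypothetical helper. Python's built-in `float.hex()` prints the same idea, with the fraction in hexadecimal and the exponent in decimal.

```python
def binary_sci(n: int) -> str:
    # binary scientific notation for a non-negative integer, e.g. 6 -> '1.1E10' (1.1 x 2^2)
    if n == 0:
        return '0E0'
    digits = bin(n)[2:]                 # e.g. 6 -> '110'
    exponent = len(digits) - 1          # position of the leading 1
    fraction = digits[1:].rstrip('0')   # binary digits after the point
    mantissa = digits[0] + ('.' + fraction if fraction else '')
    return f"{mantissa}E{bin(exponent)[2:]}"

for i in range(9):
    print(f"{i} is {bin(i)[2:]:<5} {binary_sci(i)}")

print((6.0).hex())  # 0x1.8000000000000p+2 : 1.1 in binary is 1.8 in hex, times 2^2
```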

In [23]:
for i in range(12):
    print(f"{i:02d} -> {bits(torch.tensor(i, dtype=torch.float32))}")

00 -> ['00000000', '00000000', '00000000', '00000000']
01 -> ['00000000', '00000000', '10000000', '00111111']
02 -> ['00000000', '00000000', '00000000', '01000000']
03 -> ['00000000', '00000000', '01000000', '01000000']
04 -> ['00000000', '00000000', '10000000', '01000000']
05 -> ['00000000', '00000000', '10100000', '01000000']
06 -> ['00000000', '00000000', '11000000', '01000000']
07 -> ['00000000', '00000000', '11100000', '01000000']
08 -> ['00000000', '00000000', '00000000', '01000001']
09 -> ['00000000', '00000000', '00010000', '01000001']
10 -> ['00000000', '00000000', '00100000', '01000001']
11 -> ['00000000', '00000000', '00110000', '01000001']


Not sure what's going on yet, but seeing a bit of binary counting that "might" involve the final digit of the last byte makes me wonder if we're even displaying the bytes in the right order. There could be some endian thing.
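Endianness does turn out to be the issue. x86 and Apple Silicon machines are little-endian, so a tensor's storage lists the least significant byte first, which is why `1.0` showed up as `0 0 128 63` above. Python reports the machine's order in `sys.byteorder`, and `struct` can pack a float32 explicitly in either order:

```python
import struct
import sys

print(sys.byteorder)                  # little on x86 and Apple Silicon
print(list(struct.pack('<f', 1.0)))   # little-endian: [0, 0, 128, 63], like the storage printout
print(list(struct.pack('>f', 1.0)))   # big-endian: [63, 128, 0, 0], most significant byte first
print(list(struct.pack('<f', 1e10)))  # [249, 2, 21, 80], like the 1e10 storage printout
```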

In [24]:
def bits_reversed(scalar_tensor):
    return list(reversed([format(byte_val, '08b') for byte_val in scalar_tensor.untyped_storage()]))

In [25]:
for i in [-2,-1,0,1,2]:
    print(f"{i:04d} -> {bits_reversed(torch.tensor(i, dtype=torch.float32))}")

-002 -> ['11000000', '00000000', '00000000', '00000000']
-001 -> ['10111111', '10000000', '00000000', '00000000']
0000 -> ['00000000', '00000000', '00000000', '00000000']
0001 -> ['00111111', '10000000', '00000000', '00000000']
0002 -> ['01000000', '00000000', '00000000', '00000000']


That first bit looks like it could be the sign bit. Is that a sign that this reversed order is the better way to look at things?

In [26]:
for n in [-5234,-2,-1.23,-1,0,1,1.23,2,5234]:
    print(f"{n:12.4f} -> {bits_reversed(torch.tensor(n, dtype=torch.float32))}")

  -5234.0000 -> ['11000101', '10100011', '10010000', '00000000']
     -2.0000 -> ['11000000', '00000000', '00000000', '00000000']
     -1.2300 -> ['10111111', '10011101', '01110000', '10100100']
     -1.0000 -> ['10111111', '10000000', '00000000', '00000000']
      0.0000 -> ['00000000', '00000000', '00000000', '00000000']
      1.0000 -> ['00111111', '10000000', '00000000', '00000000']
      1.2300 -> ['00111111', '10011101', '01110000', '10100100']
      2.0000 -> ['01000000', '00000000', '00000000', '00000000']
   5234.0000 -> ['01000101', '10100011', '10010000', '00000000']
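One way to confirm the first bit is the sign: XOR the 32-bit patterns of `n` and `-n`. If only the sign bit differs, the result is `0x80000000` (a 1 followed by 31 zeros). `float32_bits` is a hypothetical helper.

```python
import struct

def float32_bits(x: float) -> int:
    # the float32 bit pattern as an unsigned 32-bit integer
    (bits,) = struct.unpack('>I', struct.pack('>f', x))
    return bits

for n in (2.0, 1.23, 5234.0):
    diff = float32_bits(n) ^ float32_bits(-n)
    print(n, format(diff, '032b'))  # 10000000000000000000000000000000 every time
```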


ok, so if the first bit is the sign bit, what is the next one? It's 0 for 0 and 1, and 1 for integers from 2 up to at least 1000. Let's find some numbers where it's 0.

In [27]:
for n in [0, 1, 1.5, 1.55, 1.99, 1.9999, 2, 2.001, 2.5, 3, 3.5, 3.56, 1_000, 12_000]:
    print(f"{n:12.4f} -> {bits_reversed(torch.tensor(n, dtype=torch.float32))}")

      0.0000 -> ['00000000', '00000000', '00000000', '00000000']
      1.0000 -> ['00111111', '10000000', '00000000', '00000000']
      1.5000 -> ['00111111', '11000000', '00000000', '00000000']
      1.5500 -> ['00111111', '11000110', '01100110', '01100110']
      1.9900 -> ['00111111', '11111110', '10111000', '01010010']
      1.9999 -> ['00111111', '11111111', '11111100', '10111001']
      2.0000 -> ['01000000', '00000000', '00000000', '00000000']
      2.0010 -> ['01000000', '00000000', '00010000', '01100010']
      2.5000 -> ['01000000', '00100000', '00000000', '00000000']
      3.0000 -> ['01000000', '01000000', '00000000', '00000000']
      3.5000 -> ['01000000', '01100000', '00000000', '00000000']
      3.5600 -> ['01000000', '01100011', '11010111', '00001010']
   1000.0000 -> ['01000100', '01111010', '00000000', '00000000']
  12000.0000 -> ['01000110', '00111011', '10000000', '00000000']


Let's say bit 2 is 0 only when -2 < n < 2 ... why would that be?
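In hindsight: bit 2 is the top bit of the 8-bit exponent field (bits 2 to 9), which is stored with a bias of 127. It is 1 exactly when the stored exponent is at least 128, i.e. when the true exponent is at least 1, which for finite numbers means |n| >= 2 (after rounding to float32). A quick check using `struct` to get the bits (`bit2` is a hypothetical helper):

```python
import struct

def bit2(x: float) -> int:
    # bit 2 counting from the left (the sign bit is bit 1)
    (bits,) = struct.unpack('>I', struct.pack('>f', x))
    return (bits >> 30) & 1

for n in (0.0, 1.0, 1.9999, 2.0, 3.56, 1000.0, -5234.0, -1.23):
    print(f"{n:>8} bit2={bit2(n)} |n|>=2: {abs(n) >= 2}")
```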

Does anything jump out about that 3rd bit?

In [28]:
for n in [0, 1e-20, 1e-19, 1.1e-19, 2e-19, 1e-18, 1e-17, 1e-10, 1e-5, 1e-2, 1e-1]:
    print(f"{n:20.20f} -> {bits_reversed(torch.tensor(n, dtype=torch.float32))}")

0.00000000000000000000 -> ['00000000', '00000000', '00000000', '00000000']
0.00000000000000000001 -> ['00011110', '00111100', '11100101', '00001000']
0.00000000000000000010 -> ['00011111', '11101100', '00011110', '01001010']
0.00000000000000000011 -> ['00100000', '00000001', '11011101', '01110110']
0.00000000000000000020 -> ['00100000', '01101100', '00011110', '01001010']
0.00000000000000000100 -> ['00100001', '10010011', '10010010', '11101111']
0.00000000000000001000 -> ['00100011', '00111000', '01110111', '10101010']
0.00000000010000000000 -> ['00101110', '11011011', '11100110', '11111111']
0.00001000000000000000 -> ['00110111', '00100111', '11000101', '10101100']
0.01000000000000000021 -> ['00111100', '00100011', '11010111', '00001010']
0.10000000000000000555 -> ['00111101', '11001100', '11001100', '11001101']


I'm wasting time and am going to just look this up if I can't figure it out soon. Let me see if anything jumps out if I go the other way around and go from bytes to number.

In [29]:
def get_torch32_value(byte_array):
    temp = torch.tensor(0, dtype=torch.float32)
    assert len(byte_array) == 4
    for i in range(4):
        temp.untyped_storage()[i] = byte_array[3-i] # stick with the reverse order
    return temp.item()

Let's first try keeping the first bit (sign bit?) at 0 and the 2nd, 3rd, and 4th bytes at 0.

In [30]:
for b in range(0,128):
    n = get_torch32_value([b, 0, 0, 0])
    print(f"{n:30.30f} -> {bits_reversed(torch.tensor(n, dtype=torch.float32))}")

0.000000000000000000000000000000 -> ['00000000', '00000000', '00000000', '00000000']
0.000000000000000000000000000000 -> ['00000001', '00000000', '00000000', '00000000']
0.000000000000000000000000000000 -> ['00000010', '00000000', '00000000', '00000000']
0.000000000000000000000000000000 -> ['00000011', '00000000', '00000000', '00000000']
0.000000000000000000000000000000 -> ['00000100', '00000000', '00000000', '00000000']
0.000000000000000000000000000000 -> ['00000101', '00000000', '00000000', '00000000']
0.000000000000000000000000000000 -> ['00000110', '00000000', '00000000', '00000000']
0.000000000000000000000000000000 -> ['00000111', '00000000', '00000000', '00000000']
0.000000000000000000000000000000 -> ['00001000', '00000000', '00000000', '00000000']
0.000000000000000000000000000000 -> ['00001001', '00000000', '00000000', '00000000']
0.000000000000000000000000000000 -> ['00001010', '00000000', '00000000', '00000000']
0.000000000000000000000000000000 -> ['00001011', '00000000', '00000000', '00000000']
...

From 2 on (the values where bit 2 is 1), it looks like if x is the number formed by bits 3 to 8, then the value is 2^(2x+1).

And if we read 111111 as -1, 111110 as -2, 111101 as -3, etc., then it also works for 0.5, 0.125, etc.

So actually, let x = (value of bits 2...8) - 64. Then the value of the float is 2^(2x+1).

For example, 1000010 is 66, 66 - 64 = 2, and 2^(2*2+1) = 2^5 = 32.

For example, 0111100 is 60, 60 - 64 = -4, and 2^(2*(-4)+1) = 2^(-7) = 0.0078125.

Now let's look at 10000000 in the 2nd byte with a few different first bytes and 0 in the 3rd and 4th bytes.

In [31]:
for b in range(60,68):
    n = get_torch32_value([b, 128, 0, 0])
    print(f"{n:30.20f} -> {bits_reversed(torch.tensor(n, dtype=torch.float32))} {b-64}")

        0.01562500000000000000 -> ['00111100', '10000000', '00000000', '00000000'] -4
        0.06250000000000000000 -> ['00111101', '10000000', '00000000', '00000000'] -3
        0.25000000000000000000 -> ['00111110', '10000000', '00000000', '00000000'] -2
        1.00000000000000000000 -> ['00111111', '10000000', '00000000', '00000000'] -1
        4.00000000000000000000 -> ['01000000', '10000000', '00000000', '00000000'] 0
       16.00000000000000000000 -> ['01000001', '10000000', '00000000', '00000000'] 1
       64.00000000000000000000 -> ['01000010', '10000000', '00000000', '00000000'] 2
      256.00000000000000000000 -> ['01000011', '10000000', '00000000', '00000000'] 3


So for these, letting x be the same as above ((value of bits 2...8) - 64), this will work:

2^(2x+2)

but why would 128 in the second byte tell us to do that?
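The answer, found later in the notebook, is that the exponent field is 8 bits: bits 2 to 8 of byte 1 plus the first bit of byte 2. With byte 1 = b (sign bit 0), the stored exponent is 2b when byte 2 is 0 and 2b + 1 when byte 2 is 128, and the value is 2 raised to (stored exponent - 127). Substituting b = x + 64 gives exactly the 2^(2x+1) and 2^(2x+2) patterns above. A sketch that checks both against `struct`'s float32 decoding:

```python
import struct

for b in range(60, 68):
    x = b - 64
    # byte 2 = 0: low exponent bit clear, stored exponent 2b
    low_bit_clear = struct.unpack('>f', bytes([b, 0, 0, 0]))[0]
    # byte 2 = 128: low exponent bit set, stored exponent 2b + 1
    low_bit_set = struct.unpack('>f', bytes([b, 128, 0, 0]))[0]
    assert low_bit_clear == 2.0 ** (2 * b - 127) == 2.0 ** (2 * x + 1)
    assert low_bit_set == 2.0 ** (2 * b + 1 - 127) == 2.0 ** (2 * x + 2)
    print(b, x, low_bit_clear, low_bit_set)
```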

In [32]:
for b in range(62,66):
    n = get_torch32_value([b, 64, 0, 0])
    print(f"{n:20.15f} -> {bits_reversed(torch.tensor(n, dtype=torch.float32))} {b-64}")

   0.187500000000000 -> ['00111110', '01000000', '00000000', '00000000'] -2
   0.750000000000000 -> ['00111111', '01000000', '00000000', '00000000'] -1
   3.000000000000000 -> ['01000000', '01000000', '00000000', '00000000'] 0
  12.000000000000000 -> ['01000001', '01000000', '00000000', '00000000'] 1


In [33]:
for b in range(0,128):
    n = get_torch32_value([64, b, 0, 0])
    print(f"{n:20.15f} -> {bits_reversed(torch.tensor(n, dtype=torch.float32))} {b}")

   2.000000000000000 -> ['01000000', '00000000', '00000000', '00000000'] 0
   2.015625000000000 -> ['01000000', '00000001', '00000000', '00000000'] 1
   2.031250000000000 -> ['01000000', '00000010', '00000000', '00000000'] 2
   2.046875000000000 -> ['01000000', '00000011', '00000000', '00000000'] 3
   2.062500000000000 -> ['01000000', '00000100', '00000000', '00000000'] 4
   2.078125000000000 -> ['01000000', '00000101', '00000000', '00000000'] 5
   2.093750000000000 -> ['01000000', '00000110', '00000000', '00000000'] 6
   2.109375000000000 -> ['01000000', '00000111', '00000000', '00000000'] 7
   2.125000000000000 -> ['01000000', '00001000', '00000000', '00000000'] 8
   2.140625000000000 -> ['01000000', '00001001', '00000000', '00000000'] 9
   2.156250000000000 -> ['01000000', '00001010', '00000000', '00000000'] 10
   2.171875000000000 -> ['01000000', '00001011', '00000000', '00000000'] 11
   2.187500000000000 -> ['01000000', '00001100', '00000000', '00000000'] 12
   2.203125000000000 -> ['01000000', '00001101', '00000000', '00000000'] 13
...

Ah! So it looks like bit 1 is the sign, bits 2 - 8 are the "magnitude", and everything else is the "significant digits." But how exactly?

In [34]:
for b in range(0,10):
    n = get_torch32_value([64, 0, 0, b])
    print(f"{n:20.15f} -> {bits_reversed(torch.tensor(n, dtype=torch.float32))} {b}")

   2.000000000000000 -> ['01000000', '00000000', '00000000', '00000000'] 0
   2.000000238418579 -> ['01000000', '00000000', '00000000', '00000001'] 1
   2.000000476837158 -> ['01000000', '00000000', '00000000', '00000010'] 2
   2.000000715255737 -> ['01000000', '00000000', '00000000', '00000011'] 3
   2.000000953674316 -> ['01000000', '00000000', '00000000', '00000100'] 4
   2.000001192092896 -> ['01000000', '00000000', '00000000', '00000101'] 5
   2.000001430511475 -> ['01000000', '00000000', '00000000', '00000110'] 6
   2.000001668930054 -> ['01000000', '00000000', '00000000', '00000111'] 7
   2.000001907348633 -> ['01000000', '00000000', '00000000', '00001000'] 8
   2.000002145767212 -> ['01000000', '00000000', '00000000', '00001001'] 9


In [35]:
format(2 ** -21,'20.15f') # last byte 00000010 

'   0.000000476837158'

In [36]:
format(2 ** -21 + 2 ** -22,'20.15f') # last byte 00000011

'   0.000000715255737'

In [37]:
format(0b11 / 2 ** 22, '20.15f') # so it's basically DD.DD...--22 total--...DD where D is a binary digit  ?

'   0.000000715255737'
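Another way to see the 2^-22 step: adding 1 to the integer bit pattern of a positive float32 gives the next representable float32 up. Near 2.0 that step is 2^-22, and in general it is 2^(exponent - 23), since the 23 fraction bits split each power-of-2 interval into 2^23 equal pieces. `next_float32_up` is a hypothetical helper that assumes a positive, finite float32 input below the maximum.

```python
import struct

def next_float32_up(x: float) -> float:
    # assumes x is already a positive, finite float32 value: bump its bit pattern by one
    (bits,) = struct.unpack('>I', struct.pack('>f', x))
    (result,) = struct.unpack('>f', struct.pack('>I', bits + 1))
    return result

for x, exponent in ((1.0, 0), (2.0, 1), (4.0, 2), (1024.0, 10)):
    step = next_float32_up(x) - x
    print(x, step, step == 2.0 ** (exponent - 23))
```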

In [38]:
for b in range(63,67):
    n = get_torch32_value([64, b, 0, 0])
    print(f"{n:20.15f} -> {bits_reversed(torch.tensor(n, dtype=torch.float32))} {b}")

   2.984375000000000 -> ['01000000', '00111111', '00000000', '00000000'] 63
   3.000000000000000 -> ['01000000', '01000000', '00000000', '00000000'] 64
   3.015625000000000 -> ['01000000', '01000001', '00000000', '00000000'] 65
   3.031250000000000 -> ['01000000', '01000010', '00000000', '00000000'] 66


In [39]:
0b001111110000000000000000 / 2 ** 22

0.984375

In [40]:
0b010000000000000000000000 / 2 ** 22

1.0

In [41]:
0b010000010000000000000000 / 2 ** 22

1.015625

In [42]:
0b010000100000000000000000 / 2 ** 22

1.03125

In [43]:
for b in range(128,133):
    n = get_torch32_value([64, b, 0, 0])
    print(f"{n:20.15f} -> {bits_reversed(torch.tensor(n, dtype=torch.float32))} {b}")

   4.000000000000000 -> ['01000000', '10000000', '00000000', '00000000'] 128
   4.031250000000000 -> ['01000000', '10000001', '00000000', '00000000'] 129
   4.062500000000000 -> ['01000000', '10000010', '00000000', '00000000'] 130
   4.093750000000000 -> ['01000000', '10000011', '00000000', '00000000'] 131
   4.125000000000000 -> ['01000000', '10000100', '00000000', '00000000'] 132


In [44]:
0b100000000000000000000000 / 2 ** 22

2.0

In [45]:
0b100000010000000000000000 / 2 ** 22 # what -- why is that wrong?

2.015625

In [46]:
0b100000100000000000000000 / 2 ** 22

2.03125

In [47]:
0b100000110000000000000000 / 2 ** 22

2.046875

In [48]:
0b100001000000000000000000 / 2 ** 22

2.0625

They aren't right, why?

In [49]:
for b in range(63,66):
    for b2 in range(128,130):
        for b3 in range(0,2):
            n = get_torch32_value([b, b2, b3, 0])
            print(f"{n:30.20f} -> {bits_reversed(torch.tensor(n, dtype=torch.float32))}")

        1.00000000000000000000 -> ['00111111', '10000000', '00000000', '00000000']
        1.00003051757812500000 -> ['00111111', '10000000', '00000001', '00000000']
        1.00781250000000000000 -> ['00111111', '10000001', '00000000', '00000000']
        1.00784301757812500000 -> ['00111111', '10000001', '00000001', '00000000']
        4.00000000000000000000 -> ['01000000', '10000000', '00000000', '00000000']
        4.00012207031250000000 -> ['01000000', '10000000', '00000001', '00000000']
        4.03125000000000000000 -> ['01000000', '10000001', '00000000', '00000000']
        4.03137207031250000000 -> ['01000000', '10000001', '00000001', '00000000']
       16.00000000000000000000 -> ['01000001', '10000000', '00000000', '00000000']
       16.00048828125000000000 -> ['01000001', '10000000', '00000001', '00000000']
       16.12500000000000000000 -> ['01000001', '10000001', '00000000', '00000000']
       16.12548828125000000000 -> ['01000001', '10000001', '00000001', '00000000']


In [50]:
0b100000000000000100000000 / 2 ** 23

1.000030517578125

In [51]:
(0b100000000000000100000000 / 2 ** 23) * 4

4.0001220703125

In [52]:
# which is the same as
(0b100000000000000100000000 / 2 ** 21) 

4.0001220703125

In [53]:
(0b100000000000000100000000 / 2 ** 23) * 16

16.00048828125

In [54]:
# which is the same as
(0b100000000000000100000000 / 2 ** 19)

16.00048828125

So why can't it just be that bytes 2 - 4 are the "number" and bits 2 - 8 of byte 1 say where to place the "decimal" point, like normal scientific notation in binary? It can't be that, because 2 and 3 wouldn't work:

In [55]:
for n in [2, 3]:
    print(f"{n:30.20f} -> {bits_reversed(torch.tensor(n, dtype=torch.float32))}")

        2.00000000000000000000 -> ['01000000', '00000000', '00000000', '00000000']
        3.00000000000000000000 -> ['01000000', '01000000', '00000000', '00000000']


In [56]:
for b in range(64,67):
    for b2 in range(0,256):
        for b3 in [0]:
            n = get_torch32_value([b, b2, b3, 0])
            print(f"{n:30.20f} -> {bits_reversed(torch.tensor(n, dtype=torch.float32))}")

        2.00000000000000000000 -> ['01000000', '00000000', '00000000', '00000000']
        2.01562500000000000000 -> ['01000000', '00000001', '00000000', '00000000']
        2.03125000000000000000 -> ['01000000', '00000010', '00000000', '00000000']
        2.04687500000000000000 -> ['01000000', '00000011', '00000000', '00000000']
        2.06250000000000000000 -> ['01000000', '00000100', '00000000', '00000000']
        2.07812500000000000000 -> ['01000000', '00000101', '00000000', '00000000']
        2.09375000000000000000 -> ['01000000', '00000110', '00000000', '00000000']
        2.10937500000000000000 -> ['01000000', '00000111', '00000000', '00000000']
        2.12500000000000000000 -> ['01000000', '00001000', '00000000', '00000000']
        2.14062500000000000000 -> ['01000000', '00001001', '00000000', '00000000']
        2.15625000000000000000 -> ['01000000', '00001010', '00000000', '00000000']
        2.17187500000000000000 -> ['01000000', '00001011', '00000000', '00000000']
    

Looking above:

- from 2 to 4 we step by 0.015625
- from 4 to 8 we step by 0.03125
- from 8 to 16 we step by 0.0625
- ...

So it's like bytes 2, 3, 4 divide each interval between consecutive powers of 2 into equal steps.

Ugh! The first bit of byte 2 goes with byte 1. That makes much more sense, and it gets rid of the weird "powers stepping by 2" thing.

This means...

bits 2 to 9 give us the power of 2, call it x (x = (bits 2 to 9) - 127)

(bits 10 to 32 / 2 ** 23) * 2^x gives us the rest (because we want that fraction of the range, e.g., 32 -> 64 is 32 * fraction)

so 2^x + (bits 10 to 32 / 2 ** 23) * 2^x
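Written as one formula, with the sign bit s (bit 1), the stored exponent E (bits 2 to 9, read as an unsigned integer) and the stored fraction F (bits 10 to 32), the rule above becomes the first case below. The other two cases are the IEEE 754 rules for the reserved exponents, which this notebook only runs into at the end (looked up, not derived here).

```latex
v =
\begin{cases}
(-1)^{s}\, 2^{E-127}\left(1 + \dfrac{F}{2^{23}}\right) & 1 \le E \le 254 \quad \text{(normal)} \\[6pt]
(-1)^{s}\, 2^{-126}\, \dfrac{F}{2^{23}} & E = 0 \quad \text{(zero and subnormal)} \\[6pt]
(-1)^{s}\,\infty \text{ if } F = 0, \text{ otherwise NaN} & E = 255
\end{cases}
```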

In [57]:
# 126.50000000000000000000 -> ['01000010', '11111101', '00000000', '00000000']
x = 0b10000101 - 127
2 ** x + (0b11111010000000000000000 / 2 ** 23) * 2 ** x

126.5

In [58]:
# 6.46875000000000000000 -> ['01000000', '11001111', '00000000', '00000000']
x = 0b10000001 - 127
2 ** x + (0b10011110000000000000000 / 2 ** 23) * 2 ** x

6.46875

In [59]:
# 0.01562500000000000000 -> ['00111100', '10000000', '00000000', '00000000']
x = 0b01111001 - 127
2 ** x + (0b00000000000000000000000 / 2 ** 23) * 2 ** x

0.015625

yes, finally! try one more where I make up arbitrary bits:

In [60]:
n = get_torch32_value([0b01101011, 0b10001011, 0b00011100, 0b01010011])
n

3.3634889253542385e+26

In [61]:
x = 0b11010111 - 127
2 ** x + (0b00010110001110001010011 / 2 ** 23) * 2 ** x

3.3634889253542385e+26

In [62]:
# 1.00000000000000000000 -> ['00111111', '10000000', '00000000', '00000000']
x = 0b01111111 - 127
2 ** x + (0b00000000000000000000000 / 2 ** 23) * 2 ** x

1.0

Now 1.0 makes sense. It's just 2^0. How could I have missed that the first bit of byte 2 went with the bits to its left?

In [63]:
# 2.00000000000000000000 -> ['01000000', '00000000', '00000000', '00000000']
x = 0b10000000 - 127
2 ** x + (0b00000000000000000000000 / 2 ** 23) * 2 ** x

2.0

Is this binary scientific notation? Doesn't seem like it. Think more about that later. But maybe similar. I'm sure there are lots of reasons why this is a good representation for doing operations.
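For what it's worth, the formula can be read as binary scientific notation with the leading 1 left out: every normal float32 is 1.F (binary) x 2^x, and since the digit before the point is always 1 it doesn't need to be stored. A sketch of going the other way, from a number to its stored exponent and fraction, using `math.frexp` (which returns m and e with value = m x 2^e and 0.5 <= m < 1). `float32_encode` is a hypothetical helper; it assumes a positive value that is already exactly representable as a normal float32, and does no rounding.

```python
import math

def float32_encode(value: float):
    # assumes value > 0, is a normal float32 and is exactly representable (no rounding done)
    m, e = math.frexp(value)          # value = m * 2**e with 0.5 <= m < 1
    significand = 2 * m               # rescale to 1 <= significand < 2 ...
    x = e - 1                         # ... so value = significand * 2**x
    stored_exponent = x + 127
    fraction = int((significand - 1) * 2**23)  # the 23 bits after the binary point
    return stored_exponent, fraction

for value in (126.5, 6.46875, 0.015625):
    E, F = float32_encode(value)
    print(f"{value}: E={E:08b} F={F:023b} -> 1.{F:023b} x 2^{E - 127}")
```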

But you can start to see why it might not be ideal for certain types of model parameters. Say, for other reasons, we want to keep the absolute value of our parameters under 2. Then we're wasting a bit (the top exponent bit is always 0). Let's say absolute values smaller than some size are also not useful. Then we're wasting more bits. We might prefer to have more "room" within a smaller absolute range.

For example, if we were doing this in decimal scientific notation and could store 5 digits, we might prefer D.DDD x 10^D to D.DD x 10^DD.
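This trade-off is essentially the float16 vs bfloat16 choice: both are 16 bits, but float16 spends 5 bits on the exponent and 10 on the fraction, while bfloat16 keeps float32's 8 exponent bits and only 7 fraction bits. A sketch comparing the two, assuming bfloat16 is produced by rounding a float32 to its top 16 bits with round-half-to-even (`to_bfloat16` and `to_float16` are hypothetical helpers; NaN inputs are not handled; float16 uses `struct`'s built-in `'e'` half-precision format):

```python
import struct

def to_float16(x: float) -> float:
    # round-trip through IEEE half precision (struct raises OverflowError past its range)
    return struct.unpack('<e', struct.pack('<e', x))[0]

def to_bfloat16(x: float) -> float:
    # keep the top 16 bits of the float32 pattern, rounding the dropped 16 bits half-to-even
    # (assumption: mirrors the usual float32 -> bfloat16 conversion; NaN is not handled)
    (bits,) = struct.unpack('>I', struct.pack('>f', x))
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack('>f', struct.pack('>I', bits))[0]

print(to_float16(1 / 3), to_bfloat16(1 / 3))          # float16 keeps more digits
print(to_float16(1 + 2**-8), to_bfloat16(1 + 2**-8))  # bfloat16 cannot see a 2^-8 change
print(to_bfloat16(70000.0))                           # fits in bfloat16, coarsely rounded
try:
    to_float16(70000.0)
except OverflowError:
    print("70000 is beyond float16's max of 65504")
```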

btw, I think this "weird" thing from the top of the notebook now makes sense:

In [64]:
torch.tensor(1.2345678e-45, dtype=torch.float32)

tensor(1.4013e-45)

It needs to round to the closest number it can store.

In [65]:
n = 1.2345678e-45
print(f"{n:52.52f} -> {bits_reversed(torch.tensor(n, dtype=torch.float32))}")

0.0000000000000000000000000000000000000000000012345678 -> ['00000000', '00000000', '00000000', '00000001']


we should see it with bigger small numbers too:

In [66]:
torch.tensor(1.2345678e-44, dtype=torch.float32)

tensor(1.2612e-44)
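Near zero, "the closest number it can store" is a multiple of the smallest subnormal, 2^-149 (about 1.4013e-45), because in the subnormal range every float32 is just its 23-bit fraction times 2^-149. Dividing by that step shows why 1.2345678e-45 became 1 step and 1.2345678e-44 became 9 steps (9 x 1.4013e-45 is about 1.2612e-44). A sketch using `struct` to round to float32 (`to_float32` is a hypothetical helper):

```python
import struct

SMALLEST_SUBNORMAL = 2.0 ** -149

def to_float32(x: float) -> float:
    # round a Python float to the nearest float32 and widen it back
    return struct.unpack('>f', struct.pack('>f', x))[0]

for x in (1.23e-46, 1.2345678e-45, 1.2345678e-44, 1.2345678e-43):
    y = to_float32(x)
    print(f"{x:.7e} -> {y:.4e} = {y / SMALLEST_SUBNORMAL:g} x 2^-149 "
          f"(unrounded ratio {x / SMALLEST_SUBNORMAL:.3f})")
```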

We should see the same thing at the top end.

In [67]:
torch.tensor(1.2345678e44, dtype=torch.float32)

tensor(inf)

In [68]:
torch.tensor(1.2345678e40, dtype=torch.float32)

tensor(inf)

In [69]:
torch.tensor(1.2345678e36, dtype=torch.float32)

tensor(1.2346e+36)

In [70]:
torch.tensor(1.2345678e38, dtype=torch.float32)

tensor(1.2346e+38)

In [71]:
torch.tensor(1.2345678e39, dtype=torch.float32)

tensor(inf)

In [72]:
n = 1.2345678e38
print(f"{n:52.5f} -> {bits_reversed(torch.tensor(n, dtype=torch.float32))}")

       123456779999999998802545639784961802240.00000 -> ['01111110', '10111001', '11000001', '11010010']


In [73]:
format(2 ** (0b11111101 - 127), '.5e')

'8.50706e+37'

In [74]:
format(2 ** (0b11111111 - 127), '.5e')

'3.40282e+38'

In [75]:
torch.tensor(3.40282e+38, dtype=torch.float32)

tensor(3.4028e+38)

In [76]:
n = 3.40282e+38
print(f"{n:52.5f} -> {bits_reversed(torch.tensor(n, dtype=torch.float32))}")

       340282000000000014192072600942972764160.00000 -> ['01111111', '01111111', '11111111', '11101110']


In [78]:
get_torch32_value([0b01111111, 0b10000000, 0b00000000, 0b00000000])

inf

In [79]:
get_torch32_value([0b01111111, 0b01111111, 0b11111111, 0b11111111])

3.4028234663852886e+38

^ more to understand later
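Part of what's left to understand: the stored exponent 255 (`0b11111111`) is reserved for inf and NaN, so the largest finite float32 uses stored exponent 254 with an all-ones fraction, i.e. (2 - 2^-23) x 2^127 = 2^128 - 2^104. That is just below 2^128 = 2^(255 - 127), which is why `2 ** (0b11111111 - 127)` happened to print the same `3.40282e+38` to five digits. A check with `struct`:

```python
import struct

largest = (2 - 2.0 ** -23) * 2.0 ** 127
print(largest)               # 3.4028234663852886e+38
print(2.0 ** 128 - largest)  # the gap to 2^128 is 2^104, about 2.03e31
print(struct.unpack('>f', bytes([0b01111111, 0b01111111, 0xFF, 0xFF]))[0] == largest)  # True
print(struct.unpack('>f', bytes([0b01111111, 0b10000000, 0x00, 0x00]))[0])             # inf
```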