## Understand types

On my Mac I'm pretty sure that everything other than the inputs and targets is float32. On the GPU we're sometimes working with bfloat16, but with autocasting I'm unsure whether the interim tensors / activations are float32 or bfloat16.

To figure this out without going crazy, I'm hacking some debug prints into `my_gpt.py` and `my_base_train.py`. I'll show them here but won't commit them. I'll then run a tiny training run and summarize what I learn from the print statements.
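An alternative that avoids editing `my_gpt.py` at all is to attach forward hooks from the outside and log dtypes as data flows through. Below is a minimal sketch, assuming a hypothetical two-layer toy model in place of the real GPT and CPU autocast in place of CUDA; `record_output_dtypes` is a made-up helper, not part of nanochat.

```python
import torch
import torch.nn as nn

def record_output_dtypes(model):
    """Attach forward hooks that log (module name, input dtype, output dtype).
    Returns the log list and the hook handles (call .remove() to detach)."""
    log = []
    handles = []
    for name, module in model.named_modules():
        if name == "":  # skip the root container itself
            continue
        # bind `name` as a default arg so each hook keeps its own module name
        def hook(mod, inputs, output, name=name):
            log.append((name, inputs[0].dtype, output.dtype))
        handles.append(module.register_forward_hook(hook))
    return log, handles

# Hypothetical toy stand-in for GPT: token embedding -> linear "lm_head"
model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 100))
idx = torch.randint(0, 100, (1, 8))  # int64 token ids

log, handles = record_output_dtypes(model)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    logits = model(idx)
for h in handles:
    h.remove()

for name, in_dtype, out_dtype in log:
    print(f"{name}: input {in_dtype} -> output {out_dtype}")
```

With autocast on, the embedding output stays float32 (its weight is float32 and embedding is not an autocast op), while the linear layer's output comes out in bfloat16.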

### patch with the debug prints

In [36]:
!git diff ../my_nanochat/my_nanochat/my_gpt.py ../my_nanochat/scripts/my_base_train.py 

diff --git a/my_nanochat/my_nanochat/my_gpt.py b/my_nanochat/my_nanochat/my_gpt.py
index 1cfb8bf..6b26ab8 100644
--- a/my_nanochat/my_nanochat/my_gpt.py
+++ b/my_nanochat/my_nanochat/my_gpt.py
@@ -116,8 +116,10 @@ class Block(nn.Module):
         self.mlp = MLP(config)
 
     def forward(self, x, cos_sin, kv_cache):
+        print(f"input to transformer block layer {self.attn.layer_idx} type is {x.dtype}")
         x = x + self.attn(norm(x), cos_sin, kv_cache)
         x = x + self.mlp(norm(x))
+        print(f"output of transformer block layer {self.attn.layer_idx} type is {x.dtype}")
         return x
 
 class GPT(nn.Module):
@@ -233,18 +235,26 @@ class GPT(nn.Module):
         T0 = 0 if kv_cache is None else kv_cache.get_pos()
         cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
 
+        print(f"input to model type is {idx.dtype}")
     

### first on my Mac

In [7]:
import os
os.environ["PYTHONPATH"] = "../my_nanochat"

In [31]:
!python -m scripts.my_base_train \
    --depth=4 \
    --max_seq_len=128 \
    --device_batch_size=1 \
    --num_iterations=1 \
    --total_batch_size=128 \
    --eval_tokens=128 \
    --core_metric_every=0

overriding depth = 4
overriding max_seq_len = 128
overriding device_batch_size = 1
overriding num_iterations = 1
overriding total_batch_size = 128
overriding eval_tokens = 128
overriding core_metric_every = 0
user_config: {'run': 'dummy', 'device_type': '', 'depth': 4, 'max_seq_len': 128, 'num_iterations': 1, 'target_param_data_ratio': 20, 'device_batch_size': 1, 'total_batch_size': 128, 'embedding_lr': 0.2, 'unembedding_lr': 0.004, 'weight_decay': 0.0, 'matrix_lr': 0.02, 'grad_clip': 1.0, 'warmup_ratio': 0.0, 'warmdown_ratio': 0.2, 'final_lr_frac': 0.0, 'eval_every': 250, 'eval_tokens': 128, 'core_metric_every': 0, 'core_metric_max_per_task': 500, 'sample_every': 2000, 'model_tag': ''}
Autodetected device type: mps
This process is ddp_rank: 0, ddp_local_rank: 0, ddp_world_size: 1
Vocab size: 65,536
num_layers: 4
model_dim: 256
num_heads: 2
num_kv_heads: 2
Tokens / micro-batch / rank: 1 x 128 = 128
Tokens / micro-batch: 128
Total batch size 128 => gradient accumulation steps: 1
GPT(
  

#### what I see (on my mac / MPS)

- all parameters are float32
- precomputed cos and sin are bfloat16
- model input (token ids) is int32 during training but int64 during sampling (I noticed this earlier and don't understand it, but I can't imagine it matters, and int32 is more than enough for our vocab size)
- model targets (token ids) are int64
- all other interim tensors appear to be float32 (output of wte, input to each layer, output of lm_head, loss)
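On the int32 vs int64 question: `nn.Embedding` accepts either integer dtype for its indices, so the lookup itself shouldn't care. A quick CPU sketch; only the vocab size comes from the log above, and the embedding dimension of 8 is a made-up small value for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
wte = nn.Embedding(65_536, 8)  # vocab size from the log; tiny dim for illustration

ids64 = torch.randint(0, 65_536, (1, 16))  # default integer dtype: int64
ids32 = ids64.to(torch.int32)              # same ids, narrower dtype

out64 = wte(ids64)
out32 = wte(ids32)
print(out64.dtype, out32.dtype)      # both follow the float32 weight
print(torch.equal(out64, out32))     # identical lookups
print(torch.iinfo(torch.int32).max)  # far above the largest token id, 65,535
```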

Everything is exactly as expected, except that I'd forgotten cos and sin are bfloat16 even on MPS.
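If the rotary step is plain elementwise math, bfloat16 cos/sin won't drag float32 activations down to bfloat16, because PyTorch's type promotion makes a float32 × bfloat16 elementwise op return float32. The sketch below assumes a standard rotary precompute; the helper name `precompute_rotary` and `base=10000.0` are my assumptions, not necessarily nanochat's exact code. head_dim = 256 / 2 = 128 comes from the config above.

```python
import torch

def precompute_rotary(seq_len, head_dim, base=10000.0):
    """Hypothetical sketch: rotary cos/sin tables, cast to bfloat16 at the end."""
    channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32)
    inv_freq = 1.0 / (base ** (channel_range / head_dim))
    t = torch.arange(seq_len, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)           # (seq_len, head_dim // 2)
    cos, sin = freqs.cos(), freqs.sin()
    return cos.bfloat16(), sin.bfloat16()      # stored as bfloat16 regardless of device

cos, sin = precompute_rotary(seq_len=128, head_dim=128)
x = torch.randn(128, 64)   # float32 activations, as on MPS
y = x * cos                # elementwise op mixing float32 and bfloat16
print(cos.dtype, y.dtype)  # bfloat16 table, float32 result via type promotion
```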

### now on GPU machine

(In addition to the patch above, I added three more debug lines to spot-check the dtype of the gradients. You can see them in the output.)

In [8]:
!python -m scripts.my_base_train \
    --depth=4 \
    --max_seq_len=128 \
    --device_batch_size=1 \
    --num_iterations=1 \
    --total_batch_size=128 \
    --eval_tokens=128 \
    --core_metric_every=0

overriding depth = 4
overriding max_seq_len = 128
overriding device_batch_size = 1
overriding num_iterations = 1
overriding total_batch_size = 128
overriding eval_tokens = 128
overriding core_metric_every = 0
user_config: {'run': 'dummy', 'device_type': '', 'depth': 4, 'max_seq_len': 128, 'num_iterations': 1, 'target_param_data_ratio': 20, 'device_batch_size': 1, 'total_batch_size': 128, 'embedding_lr': 0.2, 'unembedding_lr': 0.004, 'weight_decay': 0.0, 'matrix_lr': 0.02, 'grad_clip': 1.0, 'warmup_ratio': 0.0, 'warmdown_ratio': 0.2, 'final_lr_frac': 0.0, 'eval_every': 250, 'eval_tokens': 128, 'core_metric_every': 0, 'core_metric_max_per_task': 500, 'sample_every': 2000, 'model_tag': ''}
Autodetected device type: cuda
  _C._set_float32_matmul_precision(precision)
This process is ddp_rank: 0, ddp_local_rank: 0, ddp_world_size: 1
Vocab size: 65,536
num_layers: 4
model_dim: 256
num_heads: 2
num_kv_heads: 2
Tokens / micro-batch / rank: 1 x 128 = 128
Tokens / micro-batch: 128
Total batch siz

#### what I see (on GPU machine)

- wte is bfloat16 (as expected)
- all other parameters are float32
- precomputed cos and sin are bfloat16
- model input (token ids) is int32 during training but int64 during sampling (same as on the Mac; noticed this earlier)
- model targets (token ids) are int64
- output of wte is bfloat16 and **we stay in bfloat16** through all transformer blocks and lm_head
- before F.cross_entropy we convert the logits back to float32
- param grads match the type of the params (bfloat16 for wte, float32 for others)
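These GPU observations can be reproduced in miniature on CPU: a bfloat16 embedding, a float32 `lm_head`, autocast for the forward pass, and `.float()` before `F.cross_entropy`. This is a sketch with made-up toy sizes, assuming CPU autocast treats `nn.Linear` the way CUDA autocast does (both cast it to bfloat16).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab, dim = 100, 16                                # toy sizes, not nanochat's
wte = nn.Embedding(vocab, dim).to(torch.bfloat16)   # embedding stored in bfloat16
lm_head = nn.Linear(dim, vocab, bias=False)         # stays float32

idx = torch.randint(0, vocab, (1, 8), dtype=torch.int32)  # int32 inputs
targets = torch.randint(0, vocab, (1, 8))                 # int64 targets

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    h = wte(idx)          # bfloat16, because the weight is bfloat16
    logits = lm_head(h)   # bfloat16: autocast runs the matmul in bfloat16
    loss = F.cross_entropy(logits.float().view(-1, vocab), targets.view(-1))

loss.backward()
for name, p in [("wte", wte.weight), ("lm_head", lm_head.weight)]:
    print(f"{name}: param {p.dtype}, grad {p.grad.dtype}")
```

Even though `lm_head`'s matmul ran in bfloat16, autograd accumulates each gradient in its parameter's own dtype, so `lm_head.weight.grad` comes back float32 while `wte.weight.grad` is bfloat16.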


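### simple example of mixed dtypes under autocast

If the input is bfloat16 but the parameters are float32, a plain matmul fails with a dtype mismatch; under autocast it runs and the output is bfloat16.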

In [76]:
import torch

In [80]:
m = torch.randn((2,2), dtype=torch.float32, device="cuda"); m

tensor([[ 1.3034, -0.2928],
        [ 0.0120,  2.1256]], device='cuda:0')

In [81]:
x = torch.randn((2), dtype=torch.bfloat16, device="cuda"); x

tensor([-0.8398,  0.6016], device='cuda:0', dtype=torch.bfloat16)

In [82]:
x @ m

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::BFloat16 != float

In [83]:
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)

In [85]:
with autocast_ctx:
    y = x @ m

In [86]:
y

tensor([-1.0859,  1.5234], device='cuda:0', dtype=torch.bfloat16)
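Under autocast, `matmul` casts both operands to bfloat16 before multiplying, so the float32 `m` is effectively rounded as well. A CPU sketch of the same check, assuming CPU autocast (which also lists `matmul` as a bfloat16 op); the values are random, so they won't match the cells above.

```python
import torch

torch.manual_seed(0)
m = torch.randn((2, 2), dtype=torch.float32)  # float32 "parameters"
x = torch.randn((2,), dtype=torch.bfloat16)   # bfloat16 "activations"

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = x @ m                                 # autocast casts m to bfloat16 first

manual = x @ m.to(torch.bfloat16)             # explicit cast, no autocast
print(y.dtype, torch.equal(y, manual))
```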

### difference in results computing in bfloat16 vs float32

In [91]:
y

tensor([-1.0859,  1.5234], device='cuda:0', dtype=torch.bfloat16)

In [92]:
(x.float() @ m)

tensor([-1.0875,  1.5245], device='cuda:0')

In [87]:
y[0].item()

-1.0859375

In [90]:
(x.float() @ m)[0].item()

-1.0874789953231812
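The gap of about 0.0015 is small relative to bfloat16's resolution: bfloat16 keeps 7 explicit mantissa bits, so values in [1, 2) are spaced 2^-7 = 0.0078125 apart. A worked check, reusing the float32 result printed above: in this particular case the bfloat16 answer equals the float32 answer rounded to bfloat16, though that need not hold in general, since `m` was also rounded and longer sums accumulate more error.

```python
import torch

exact = torch.tensor(-1.0874789953231812)  # float32 result from the cell above
eps = torch.finfo(torch.bfloat16).eps      # spacing of bfloat16 values in [1, 2)
rounded = exact.to(torch.bfloat16)         # round-to-nearest into bfloat16
print(eps, rounded.item())                 # 0.0078125 -1.0859375
```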