# SFT tweaks

**!!!DISCLAIMER!!!**

you may completely ignore this notebook <br/>
i'm just jotting down the problems i've faced while modifying `train_sft.py`

## Problems
1. update speed is 180s/upd
2. tps ≈500 
3. on resume, forward ≈300ms, backward ≈6000ms

## idea
1. forward ≈300ms is fine, but backward is too slow. get it down to 400 ~ 800ms
2. increase tps ≈3000

#### conclusion
if we can bring backward time down to around 400 ~ 800ms and raise tps to ≈3000, i'm pretty sure update speed would improve <br/>
ok. let's see if we can reach around 20s/upd
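
quick back-of-the-envelope check on that target. this is only a sketch: it assumes `tps` counts the same tokens that make up one update and that tokens per update stay fixed. `seconds_per_update` is a made-up helper, not something from `train_sft.py`.

```python
def seconds_per_update(tokens_per_update: float, tps: float) -> float:
    """Time for one optimizer update if token throughput is the only bottleneck."""
    return tokens_per_update / tps


# Assumption: 180 s/upd at ~500 tok/s implies ~90,000 tokens per update.
tokens_per_update = 180 * 500

for tps in (500, 3000, 4500):
    secs = seconds_per_update(tokens_per_update, tps)
    print(f"tps={tps:>5} -> {secs:.0f} s/upd")
```

under that assumption, tps ≈3000 alone lands around 30 s/upd, so 20 s/upd would need roughly 4500 tok/s or fewer tokens per update.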

# Solving the backward ≈6000ms problem
## Problem found!
yup. for some reason, SDPBackend.MATH was slowing things down. <br/>
so, i completely disabled SDPBackend.MATH and forced either FLASH_ATTENTION or EFFICIENT_ATTENTION <br/>

- using fp16 + SDPBackend.FLASH_ATTENTION
    - result: backward ≈3000
- using bf16 + SDPBackend.EFFICIENT_ATTENTION
    - result: backward ≈2000

ok. stick with bf16 + SDPBackend.EFFICIENT_ATTENTION

## Conclusion
original was using fp16 + SDPBackend.MATH <br/>
now, it uses bf16 + SDPBackend.EFFICIENT_ATTENTION <br/>
(still can use fp16 + SDPBackend.FLASH_ATTENTION but i'm using the one above since it's faster)

# ENV var conflicts forcing bad kernels
## Problem
`SDPA_BACKEND`, `PYTORCH_SDP_KERNEL` and shell overrides nudged PyTorch into slow paths
## Solution
cleared env overrides in script and used: <br/>
- `cuda_backend.enable_flash_sdp(False)`
- `cuda_backend.enable_mem_efficient_sdp(True)`
- `cuda_backend.enable_math_sdp(True)`
## Result
yup. no longer seeing that kernel error
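
rough sketch of the "clear env overrides" step. the two variable names come from the problem above; `clear_sdpa_env_overrides` is a hypothetical helper. it would run before the `enable_*_sdp` toggles (assuming `cuda_backend` is an alias for `torch.backends.cuda`).

```python
import os

# Names taken from the notes above; extend the tuple if other overrides show up.
SDPA_ENV_OVERRIDES = ("SDPA_BACKEND", "PYTORCH_SDP_KERNEL")


def clear_sdpa_env_overrides(env=os.environ):
    """Remove SDPA-related overrides from env and return whatever was removed."""
    removed = {}
    for name in SDPA_ENV_OVERRIDES:
        if name in env:
            removed[name] = env.pop(name)
    return removed


os.environ["SDPA_BACKEND"] = "math"  # simulate a stale shell override
print("cleared:", clear_sdpa_env_overrides())
```

returning the removed values makes it easy to log which override was steering the kernel choice.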

# No FlashAttention in the wheel
### Problem
flash was unusable. torch was not compiled with flash attention
### Solution
i upgraded from cu121 to cu124 and set SDPA to prefer mem-efficient, using Flash only if it's actually available <br/>
-> YES! now it's using bf16 + SDPBackend.EFFICIENT_ATTENTION
### Result
bf16 + SDPBackend.EFFICIENT_ATTENTION is working with no error


# AMP/scaler mismatches
### Problem
running GradScaler logic with bf16 added awkward, unnecessary overhead in the code
### Solution
guard everything with `scaler and scaler.is_enabled()`
### Another Problem
when using fp16, the graph was drastically fluctuating...
### Solution for Another Problem
use scaler only for fp16
### Result
now, both fp16 and bf16 are working fine

# Stressed kernel
### Problem 
backward time is showing 5k
### Solution
welp. i played around with `--micro-bsz` and `--accum` <br/>
setting `--micro-bsz 8 --accum 8` did solve the problem
### Result
yeah... backward time dropped to ≈2k with bf16 (but that's still slow...)
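
small sketch for picking other splits to try. it assumes `--micro-bsz` is the per-step batch size and `--accum` is the number of gradient-accumulation steps, so their product is the effective batch per update. `splits` is a hypothetical helper.

```python
def splits(effective_batch: int):
    """All (micro_bsz, accum) pairs whose product is the effective batch size."""
    return [
        (micro, effective_batch // micro)
        for micro in range(1, effective_batch + 1)
        if effective_batch % micro == 0
    ]


# --micro-bsz 8 --accum 8 gives an effective batch of 64 sequences per update.
for micro_bsz, accum in splits(8 * 8):
    print(f"--micro-bsz {micro_bsz} --accum {accum}")
```

every pair keeps the same effective batch, so they can be swapped while looking for a faster kernel shape without changing the optimization setup.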

# State dict shape/name mismatches
### Problem
checkpoint had split `q/k/v` while the model expects fused `qkv`; PEFT also wraps layers under `.base_layer` names
### Solution
auto-fuse `q,k,v -> qkv` on load and map to `.base_layer` keys for LoRA
### Result
now, i no longer see those missing-keys messages
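
rough sketch of the fuse-on-load idea. the key names (`attn.q.weight`, `attn.qkv.base_layer.weight`) and the q, k, v concatenation order along dim 0 are assumptions for illustration; the real layout depends on the model. `fuse_qkv` is a hypothetical helper.

```python
import torch


def fuse_qkv(state, to_base_layer=False):
    """Fuse '<prefix>.q/.k/.v.<kind>' entries into '<prefix>.qkv.<kind>'.

    With to_base_layer=True the fused key is written as
    '<prefix>.qkv.base_layer.<kind>' to match a LoRA-wrapped module.
    """
    out = dict(state)
    for key in list(state):
        for kind in ("weight", "bias"):
            suffix = f".q.{kind}"
            if not key.endswith(suffix):
                continue
            prefix = key[: -len(suffix)]
            parts = [f"{prefix}.{name}.{kind}" for name in ("q", "k", "v")]
            if not all(part in state for part in parts):
                continue  # incomplete triple: leave it untouched
            fused = torch.cat([out.pop(part) for part in parts], dim=0)
            middle = ".qkv.base_layer." if to_base_layer else ".qkv."
            out[f"{prefix}{middle}{kind}"] = fused
    return out


# Toy checkpoint: q is all 0s, k all 1s, v all 2s, each of shape (2, 4).
state = {f"attn.{name}.weight": torch.full((2, 4), float(i)) for i, name in enumerate("qkv")}
fused = fuse_qkv(state, to_base_layer=True)
print({key: tuple(value.shape) for key, value in fused.items()})
```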

# Resume/scheduler drift
### Problem
LR scheduler wasn't aligned when resuming mid-run
### Solution
fast-forward / rebuild scheduler to match remaining steps and current LR
### Result
on resume, now i do see stable initial values
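
one way to picture the fix: if the LR is a pure function of the global step, "fast-forwarding" is just evaluating it at the checkpointed step. this is a sketch with a made-up warmup + cosine schedule and made-up hyperparameters, not the actual scheduler in `train_sft.py`.

```python
import math


def lr_at(step, base_lr=3e-4, warmup=100, total=1000, min_lr=3e-5):
    """Linear warmup then cosine decay, as a pure function of the global step."""
    if step < warmup:
        return base_lr * (step + 1) / warmup
    progress = min(1.0, (step - warmup) / max(1, total - warmup))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


resume_step = 400  # hypothetical step stored in the checkpoint
print("restarted schedule :", lr_at(0))            # warmup would start over
print("fast-forwarded     :", lr_at(resume_step))  # continues where the run stopped
```

the drift shows up as the gap between the two printed values: a rebuilt scheduler that starts at step 0 re-enters warmup instead of continuing the decay.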

# INFO
currently, i see <br/>
- `[profile] forward ≈200ms  backward ≈450ms  (no optimizer.step)`
- `tps ≈4000` 

not THAT(?) bad but would be better if we can make it even faster

# Experiment 1
let's try `--torch-compile`, `--compile-mode`
### Result
i tried all three
- `--compile-mode over-head`
- `--compile-mode default`
- `--compile-mode max-autotune`

but they all dramatically increased backward time (the worst one was ≈11000ms)<br/>
yeah... better not use it...

# Experiment 2
let's try adding `--workers`
### Result
even 1 worker slows down the update speed... (80s/upd) <br/>
i'm just gonna use no workers

# Experiment 3
let's try enabling TF32 for extra throughput <br/>
additionally set `is_causal=True`

### Result
oh hey! <br/>
i did `torch.backends.cuda.matmul.allow_tf32 = True` <br/>
and `torch.set_float32_matmul_precision("high")` <br/>

it did increase the tps (tps ≈5500)
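
a small sketch of both tweaks together. the TF32 lines are the two settings quoted above; the second half checks that `is_causal=True` gives the same result as an explicit lower-triangular mask. tensor shapes are arbitrary, and the comparison runs on CPU tensors just to keep it self-contained.

```python
import torch
import torch.nn.functional as F

# Allow TF32 for fp32 matmuls (only has an effect on GPUs with TF32 support).
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")

# (batch, heads, seq_len, head_dim) chosen arbitrarily for the check.
q, k, v = (torch.randn(1, 2, 8, 16) for _ in range(3))

# Explicit causal mask: True means "this query may attend to this key".
mask = torch.ones(8, 8, dtype=torch.bool).tril()
explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Same masking expressed through the flag, with no mask tensor to build or pass.
causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(explicit, causal, atol=1e-5))
```

passing the flag instead of a mask tensor leaves the backend free to use its built-in causal path.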