# Before SFT training

## DIAGRAM
**(picture is a bit small. you may have to open the picture separately)**
<br/><br/>
<img src="./sft_diagram.png" width="1200"/>

that's going to be our overall plan. <br/>
so, similar to pretraining, we're going to split the whole run into pilot -> 8 chunks -> tail (remaining) <br/>
for training, we will be adding SFT with LoRA and EMA <br/>
**(TL;DR)** <br/>
in the diagram, i've omitted some of the features to fit everything into the picture

# Before SFT training
we'll run 15000 updates in total, split as: <br/>
- 5% for pilot (15000 * 0.05 = 750)
- 8 chunks ((15000 - 750) // 8 = 1781)
- remaining 2 for tail (15000 - (750 + 8 * 1781) = 2)
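
the arithmetic above can be checked with a few lines of python. this is only a sketch: `split_updates` is a hypothetical helper name, not something taken from train_sft.py.

```python
def split_updates(total_updates, pilot_frac=0.05, num_chunks=8):
    """Split a run into a pilot phase, equal-sized chunks, and a leftover tail."""
    pilot = round(total_updates * pilot_frac)            # 5% of all updates
    chunk = (total_updates - pilot) // num_chunks        # floor division -> equal chunks
    tail = total_updates - (pilot + num_chunks * chunk)  # updates that don't fit a chunk
    return pilot, chunk, tail


pilot, chunk, tail = split_updates(15000)
print(pilot, chunk, tail)  # 750 1781 2
```

the tail exists only because 14250 is not divisible by 8, so it will always be smaller than the number of chunks.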

# In training loop
before we talk about **LoRA** and **EMA**, let's talk about fine-tuning <br/>


# Problem
fine-tuning a full transformer with billions of parameters is
- heavy: requires huge GPU memory
- slow: every weight gets updated
- risky: can overwrite useful pretrained knowledge


# **LoRA (Low-Rank Adaptation)**
## idea
- instead of updating the entire weight matrix W, LoRA freezes the original W and only learns a low-rank update ΔW <br/>

let's say we have a large weight matrix W of real numbers, with shape d x k <br/>
then, LoRA introduces two smaller trainable matrices, A and B, whose product is the update: <br/>
- ΔW = A ⋅ B <br/>

where shape of A is d x r <br/>
and shape of B is r x k <br/>
where r is a small rank (like 4, 8, 16, ...) <br/>
(quick check: the shape of A ⋅ B is (d x r) ⋅ (r x k), and matrix multiplication gives d x k since the inner dimension r cancels out) <br/>
so the effective weight during training is: <br/>

### Formula

**W' = W + α ⋅ (A ⋅ B)** <br/>

intuitively, <br/>
- W stays frozen (keeps pretrained knowledge) <br/>
- A and B are tiny and trainable
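- α is a scaling factor (`--lora-alpha` in train_sft.py; note that PEFT scales the update by `lora_alpha / r`, so `--lora-alpha 32` with `--lora-r 16` gives an effective factor of 2)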

Consequently
- huge parameter savings, since instead of training the full d x k matrix, we train only r x (d + k) parameters
- memory efficient, since gradients and optimizer states are kept only for the LoRA layers
- pluggable since we can add LoRA to specific modules (in our case qkv, out_proj, mlp.fc, mlp.proj)
- composable since multiple LoRA adapters can be swapped in/out on top of the same base model
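
to make the formula concrete, here is a tiny plain-python sketch with toy sizes (d=4, k=3, r=2). the matrices and the helper names `matmul` / `effective_weight` are made up for illustration. it also assumes the usual LoRA setup where one of the two factors starts at zero, so that training begins exactly at the pretrained W.

```python
def matmul(X, Y):
    """Plain-Python matrix product of X (n x m) and Y (m x p)."""
    return [[sum(X[i][t] * Y[t][j] for t in range(len(Y)))
             for j in range(len(Y[0]))] for i in range(len(X))]


def effective_weight(W, A, B, alpha):
    """W' = W + alpha * (A @ B), following the formula above."""
    delta = matmul(A, B)  # low-rank update, same shape as W
    return [[W[i][j] + alpha * delta[i][j] for j in range(len(W[0]))]
            for i in range(len(W))]


d, k, r, alpha = 4, 3, 2, 2.0
W = [[1.0] * k for _ in range(d)]             # frozen pretrained weight, d x k
A = [[0.1 * (i + 1)] * r for i in range(d)]   # trainable, d x r
B = [[0.5] * k for _ in range(r)]             # trainable, r x k

W_eff = effective_weight(W, A, B, alpha)
print(len(W_eff), len(W_eff[0]))              # 4 3 -> still d x k

# With a zero factor (the assumed starting point) the update vanishes: W' == W.
B_zero = [[0.0] * k for _ in range(r)]
W_start = effective_weight(W, A, B_zero, alpha)
```

only A and B would receive gradients here; W is read but never written.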

here is the relevant part of my launch command and of train_sft.py: <br/>
```
    --lora --lora-r 16 --lora-alpha 32 `
    --lora-targets "attn.qkv,attn.out_proj,mlp.fc,mlp.proj" `
```
```
    lora_cfg = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_drop,
        target_modules=targets,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
```

that is:
- `--lora-r 16`: rank of the low-rank decomposition
- `--lora-alpha 32`: scaling factor for the LoRA updates
- `--lora-targets "attn.qkv,attn.out_proj,mlp.fc,mlp.proj"`: applying LoRA only to attention projections and MLP layers (the big matrices)
- `bias="none"`: don't train any bias terms either (the W matrices are already frozen by LoRA itself)

So, the base model stays mostly frozen, and only a tiny fraction (the LoRA matrices) is fine-tuned <br/>
that's shown in the console as `trainable params: 2,359,296 || all params: 133,888,512 || trainable%: 1.7621` when we run the code <br/>
as you can see, we're training only a small number of params: `trainable params: 2,359,296`
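
as a sanity check on that number, here is a rough count using r x (d + k) per targeted layer. the model dimensions are NOT stated in these notes, so `d_model = 768`, `n_layers = 12`, a fused qkv projection and a 4x MLP are all assumptions. they happen to reproduce the console line, which makes them plausible but doesn't prove them.

```python
r = 16                        # --lora-r
d_model, n_layers = 768, 12   # ASSUMPTION: dims are not given in the source

# (in_features, out_features) of each targeted linear layer in one block
targets = {
    "attn.qkv":      (d_model, 3 * d_model),   # assumes fused q, k, v
    "attn.out_proj": (d_model, d_model),
    "mlp.fc":        (d_model, 4 * d_model),   # assumes 4x MLP expansion
    "mlp.proj":      (4 * d_model, d_model),
}

# Each adapted layer adds r * (d_in + d_out) trainable parameters.
lora_per_block = sum(r * (d_in + d_out) for d_in, d_out in targets.values())
trainable = n_layers * lora_per_block

all_params = 133_888_512      # total taken from the console line above
pct = 100 * trainable / all_params
print(f"trainable params: {trainable:,} || trainable%: {pct:.4f}")
# trainable params: 2,359,296 || trainable%: 1.7621
```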

### Why I used it? 
the biggest reason was that training got faster <br/>
without LoRA & EMA (we'll talk about this later) it used to be 180s/update <br/>
and with LoRA only, it went down to 60s/update

# LoRA in train_sft.py
## ARGS
```
ap.add_argument("--lora", action="store_true")
ap.add_argument("--lora-r", type=int, default=16)
ap.add_argument("--lora-alpha", type=int, default=32)
ap.add_argument("--lora-drop", type=float, default=0.05)
ap.add_argument("--lora-targets", type=str, default="attn.qkv,attn.out_proj,mlp.fc,mlp.proj")
ap.add_argument("--lora-merge", action="store_true")
ap.add_argument("--load-lora-from", type=str, default="")
```

so, yeah. the other flags:
- `--lora`: turns LoRA on
- `--lora-merge`: folds the LoRA update into the base weights (the `+` in W' = W + α ⋅ (A ⋅ B)), so that we end up with a single .pt
- `--load-lora-from`: only load LoRA adapter weights on top of base
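
merging doesn't change what the model computes, it only changes where the numbers live. a minimal plain-python sketch with made-up toy matrices (r = 1; `matvec` is a hypothetical helper, not from train_sft.py):

```python
def matvec(x, M):
    """Row vector x (length n) times matrix M (n x m)."""
    return [sum(x[i] * M[i][j] for i in range(len(x))) for j in range(len(M[0]))]


alpha = 2.0
W = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # frozen base weight, 3 x 2
A = [[0.5], [1.0], [-1.0]]                 # LoRA factor, 3 x 1
B = [[2.0, -2.0]]                          # LoRA factor, 1 x 2
x = [1.0, 2.0, 3.0]                        # one input row

# Unmerged: base path plus a separate low-rank path (adapter kept on the side).
base_out = matvec(x, W)
lora_out = matvec(matvec(x, A), B)
unmerged = [b + alpha * l for b, l in zip(base_out, lora_out)]

# Merged: fold alpha * (A @ B) into W once, then use a single matrix.
W_merged = [[W[i][j] + alpha * A[i][0] * B[0][j] for j in range(2)] for i in range(3)]
merged = matvec(x, W_merged)

print(unmerged, merged)  # [20.0, 30.0] [20.0, 30.0]
```

the trade-off: the merged file is self-contained, but the adapter can no longer be swapped out on its own.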

## Base model Wrapper
```
base_model = DummyModel(cfg).to(device)

if args.load_lora_from:
    model = PeftModel.from_pretrained(base_model, args.load_lora_from, is_trainable=True).to(device)
else:
    targets = [t.strip() for t in args.lora_targets.split(",") if t.strip()]
    lora_cfg = LoraConfig(
        r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_drop,
        target_modules=targets, bias="none", task_type=TaskType.CAUSAL_LM
    )
    model = peft_get_peft_model(base_model, lora_cfg).to(device)
```
either load an existing LoRA adapter or wrap with a new LoRA config

## Optimizer
```
if args.lora:
    trainable = [p for p in model.parameters() if p.requires_grad]
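    # the exact fused condition is elided in the original snippet; enabling it
    # only when CUDA is available is one reasonable choice (an assumption)
    optimizer = torch.optim.AdamW(trainable, lr=BASE_LR, betas=(0.9, 0.95), eps=1e-8,
                                  fused=torch.cuda.is_available())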
```
if `--lora` is set, optimizer collects trainable parameters (which are the LoRA adapters) and builds AdamW just for them

## Save
```
base_sd = base_model.state_dict()
if isinstance(model, PeftModel):
    base_sd = _strip_lora_from_base_sd(base_sd)

payload = {"model": base_sd, ...}
if isinstance(model, PeftModel):
    payload["lora_state"] = get_peft_model_state_dict(model)
torch.save(payload, path)
```
base weights go under `model`, while the LoRA adapters are stored in `lora_state`

## Load
```
load_state_dict_safely(model, {"model": raw_sd})
if isinstance(model, PeftModel) and "lora_state" in ckpt:
    set_peft_model_state_dict(model, ckpt["lora_state"]) 
```
restores the base weights first, then restores `lora_state`


## Sanity check
```
total = sum(p.numel() for n,p in model.named_parameters() if ".lora_" in n)
...
probe_loss.backward()
any_grad = any((p.grad is not None and p.grad.abs().sum().item() > 0)
               for n,p in model.named_parameters() if ".lora_" in n)
print("LoRA grads flowing:", any_grad)
```
after resuming, it runs a tiny forward/backward to ensure LoRA parameters receive gradients

## EMA can be LoRA-only
```
ema = EMA(model, beta=(args.ema ** args.ema_every), trainable_only=args.ema_trainable_only)
```
the EMA class supports trainable-only averaging (so, it EMA-averages LoRA adapters, not the frozen base) when `--ema-trainable-only` is used

## Saving LoRA adapters as standalone directories
```
best_dir = CKPT_DIR / "lora_best"; model.save_pretrained(str(best_dir))
best_ema_dir = CKPT_DIR / "lora_best_ema"; model.save_pretrained(str(best_ema_dir))
final_dir = CKPT_DIR / "lora_final"; model.save_pretrained(str(final_dir))
```

on new best/latest/final, it writes an adapter folder so that we can reuse it in the future

# **EMA (Exponential Moving Average)**

## What's EMA? 
- EMA keeps a smoothed copy of the model's weights while training, instead of relying only on the noisy parameters being updated each step

the following is an example with/without EMA smoothing <br/>

<img src="./ema_example1.png" width="600" />

### Formula 
**θ_EMA = β ⋅ θ_EMA + (1 − β) ⋅ θ**
- θ: current trainable weights (in our case, LoRA params)
- θ_EMA: shadow weights (EMA copy)
- β: smoothing factor; usually close to 1 (i've used `--ema 0.9999`)
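
here is a small plain-python sketch of the formula on a single fake "weight" that bounces around 1.0. the data is made up and `ema_track` is a hypothetical helper. β = 0.9 is used only so the effect shows up within 200 steps.

```python
def ema_track(values, beta):
    """Return the EMA trajectory of a sequence, starting from its first value."""
    shadow = values[0]
    out = []
    for v in values:
        shadow = beta * shadow + (1 - beta) * v   # θ_EMA = β·θ_EMA + (1 − β)·θ
        out.append(shadow)
    return out


# A "weight" that oscillates between 1.5 and 0.5 from update to update.
raw = [1.0 + (0.5 if i % 2 == 0 else -0.5) for i in range(200)]
smooth = ema_track(raw, beta=0.9)

# Compare how much each one swings over the second half of the run.
raw_swing = max(raw[100:]) - min(raw[100:])
ema_swing = max(smooth[100:]) - min(smooth[100:])
print(raw_swing, ema_swing)

# Rule of thumb: the EMA averages over roughly 1 / (1 - beta) recent updates.
horizon = 1 / (1 - 0.9999)
print(round(horizon))
```

by that rule of thumb, `--ema 0.9999` averages over roughly the last 10,000 updates, which is a large share of a 15,000-update run, so the shadow weights lag the live weights noticeably early on.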

## Advantages
- stability: raw weights can oscillate from update to update, and EMA smooths that out
- better evaluation: models evaluated with EMA weights often have lower val_loss and perplexity
- regularization effect: dampens overfitting spikes in late training
- safety net: if the live weights diverge temporarily, the EMA copy drifts more slowly

### Why I used it?
EMA actually slows down the training <br/>
but it outputs better results <br/>

during the development, before implementing EMA, it used to be 30s/upd <br/>
after implementing EMA, it is 50s/upd <br/>

the quality of output: <br/>
i don't know if this is caused by implementing EMA <br/>
but i did a small fine-tuning run (around 8k, validation loss ≈6.7) <br/>
and for the question `"Hello, my name is Eric. Who are you?"` <br/>
answers (trimmed): <br/>
without EMA (forgot validation loss): `", , , paradox   .              "` <br/>
with EMA(≈6.7): `",          ,    helpful, me.         is name"` <br/>

# EMA in train_sft.py

## class EMA
```
class EMA:
    def __init__(self, model, beta=0.9999, trainable_only=False):
        self.beta = beta
        base = _unwrap_model_for_state_dict(model)
        self.params = [(n, p) for n, p in base.named_parameters()
                       if p.dtype.is_floating_point and (p.requires_grad if trainable_only else True)]
        with torch.no_grad():
            self.shadow = {n: p.detach().clone() for n, p in self.params}
            self.back = {}

    @torch.no_grad()
    def update(self, model):
        for n, p in self.params:
            self.shadow[n].mul_(self.beta).add_(p.data, alpha=1 - self.beta)

    @torch.no_grad()
    def swap_in(self, model):
        self.back = {n: p.detach().clone() for n, p in self.params}
        for n, p in self.params:
            p.data.copy_(self.shadow[n])

    @torch.no_grad()
    def swap_out(self, model):
        for n, p in self.params:
            p.data.copy_(self.back[n])
        self.back = {}
```

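tracks a shadow copy of the parameters and updates them with the exponential formula on each EMA step; `swap_in` / `swap_out` temporarily put the shadow weights into the model and then restore the live ones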

## ARGS
```
    ap.add_argument("--ema", type=float, default=0.9999)
    ap.add_argument("--ema-trainable-only", action="store_true")
    ap.add_argument("--ema-every", type=int, default=1)
```

- `--ema 0.9999`: set smoothing factor β close to 1
- `--ema-trainable-only`: restricts EMA to the trainable (LoRA) params
- `--ema-every`: update the EMA only every k steps; each EMA update then uses β^k, so the decay per training step stays the same
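
a quick check of why β is raised to the k-th power. the weight left on the old shadow after k per-step updates with β is the same as after one update with β^k (k = 4 here is just an example value):

```python
import math

beta, k = 0.9999, 4   # --ema, --ema-every (k = 4 is only an example)

# Weight left on the old shadow after k per-step updates with beta ...
per_step = 1.0
for _ in range(k):
    per_step *= beta

# ... versus a single update every k steps with the compounded factor.
compounded = beta ** k

print(math.isclose(per_step, compounded))  # True
```

what differs is the new information. with `--ema-every k` only every k-th set of weights is mixed in, and the ones in between are skipped.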

## Construction
```
ema = EMA(
    model,
    beta=(args.ema ** args.ema_every),   # smoothing factor (compounded if updating every k steps)
    trainable_only=args.ema_trainable_only
) if args.ema and args.ema > 0 else None
```

## Baseline evaluation uses EMA weights
```
if ema: ema.swap_in(model)
val0, ppl0 = evaluate(model, valid_loader, device, amp_dtype, max_batches=5)
if ema: ema.swap_out(model)
```
right after restart/start, we temporarily evaluate with EMA weights

## During training
```
if ema and (update % args.ema_every == 0):
    ema.update(model)
```

in `def train_for_updates(num_updates, tag):` after `optimizer.step()` <br/>
update EMA on schedule

## Quick validation using EMA
```
if args.quick_every > 0 and update % args.quick_every == 0:
    if ema: ema.swap_in(model)
    vq, pq = evaluate(model, valid_loader, device, amp_dtype, max_batches=10)
    if ema: ema.swap_out(model)
    print(f"[quick] upd {update}  val={vq:.4f}  ppl={pq:.2f}  (EMA)")
```

every `--quick-every` updates, it does: swap in the EMA weights -> evaluate -> swap the live weights back
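
since the same swap in -> evaluate -> swap out pattern repeats in several places, one option is a small context manager. this is NOT how train_sft.py does it, just a sketch. `ema_weights` and `FakeEMA` are hypothetical names, and the "model" is a plain dict so the example runs without torch.

```python
from contextlib import contextmanager


@contextmanager
def ema_weights(ema, model):
    """Temporarily run with EMA weights; always restore the live weights."""
    if ema is None:            # EMA disabled: nothing to swap
        yield model
        return
    ema.swap_in(model)
    try:
        yield model
    finally:
        ema.swap_out(model)    # runs even if the evaluation raises


class FakeEMA:
    """Stand-in with the same swap_in/swap_out interface, for the demo only."""
    def __init__(self, shadow):
        self.shadow, self.back = shadow, None

    def swap_in(self, model):
        self.back = dict(model)      # remember the live weights
        model.update(self.shadow)    # overwrite with the shadow weights

    def swap_out(self, model):
        model.update(self.back)      # put the live weights back
        self.back = None


model = {"w": 1.0}                   # toy "model": a dict of weights
ema = FakeEMA({"w": 0.25})

with ema_weights(ema, model) as m:
    seen_during_eval = m["w"]        # EMA weights are active here
restored = model["w"]                # live weights are back afterwards
print(seen_during_eval, restored)    # 0.25 1.0
```

the `try/finally` is the main point: if `evaluate` throws, training doesn't silently continue from the shadow weights.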

## Alert path (TPS drop/loss spike) uses EMA
```
if ema: ema.swap_in(model)
val_loss, val_ppl = evaluate(model, valid_loader, device, amp_dtype, max_batches=None)
if ema: ema.swap_out(model)
```

when an alert triggers an immediate eval/checkpoint, it's the same thing:
it swaps in the EMA weights -> evaluates -> swaps the live weights back

## Full validation & checkpoint uses EMA
```
if ema: ema.swap_in(model)
val_loss, val_ppl = evaluate(model, valid_loader, device, amp_dtype, max_batches=None)
if ema: ema.swap_out(model)
```

every `--eval-every` updates, same thing as above

## Saving special EMA checkpoints / adapters
```
if ema:
    ema.swap_in(model)
    save_checkpoint(CKPT_DIR, manifest, MANIFEST_PATH, model, optimizer, scheduler, scaler,
                    update, val_loss, tag="best_ema")
    if hasattr(model, "save_pretrained"):
        best_ema_dir = CKPT_DIR / "lora_best_ema"
        model.save_pretrained(str(best_ema_dir))
        print(f"💾 saved LoRA (EMA) adapter → {best_ema_dir}")
    ema.swap_out(model)
```

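when we hit a new best, it does the same swap: the `best_ema` checkpoint and the `lora_best_ema` adapter are saved while the EMA weights are swapped in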

# END
those are the major updates i've made from train.py to train_sft.py <br/>
i've also done some tweaks to decrease the training time <br/>
check out `2_SFT_tweeks.ipynb`