## Stage-Based Compute Reuse Experiment

### Setup
We ran two fine-tuning trials on **Mistral-7B (LoRA)** using a stage-based execution model.

- **Trial A:** 50 iters @ LR=1e-5 → 50 iters @ LR=5e-5
- **Trial B:** 50 iters @ LR=1e-5 → 50 iters @ LR=3e-5

The `StageTreeRunner` cached the LoRA adapter checkpoint for each stage prefix and reused it whenever a later trial began with the same prefix.
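One plausible way to key such a cache is to hash the static config together with the stage prefix, so that any two trials sharing their first *k* stages map to the same cache entry. The sketch below is hypothetical: `StaticConfig`, `StageSpec`, and `prefix_key` are illustrative stand-ins, not `stage_runner`'s actual classes or hashing scheme. The 16-hex-character cache directories in the log (e.g. `233ac6dad23f2688`) are merely consistent with a truncated digest.

```python
import hashlib
import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class StaticConfig:
    """Hypothetical stand-in for the weight-affecting fields of stage_runner.StaticCfg."""
    model_path: str
    lora_layers: int
    dataset_id: str
    tokenizer_id: str


@dataclass(frozen=True)
class StageSpec:
    """Hypothetical stand-in for stage_runner.Stage: one training segment."""
    iters: int
    lr: float


def prefix_key(static: StaticConfig, stages: list[StageSpec], k: int) -> str:
    """Return a short cache key for the checkpoint produced after the first k stages."""
    payload = {
        "static": asdict(static),
        "prefix": [asdict(s) for s in stages[:k]],
    }
    # sort_keys makes the serialization, and therefore the digest, deterministic
    blob = json.dumps(payload, sort_keys=True).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()[:16]


static = StaticConfig(
    model_path="mlx-community/Mistral-7B-Instruct-v0.2-4bit",
    lora_layers=16,
    dataset_id="youtube-comments-v1",
    tokenizer_id="mistral-tokenizer-v0.2",
)
stages_a = [StageSpec(50, 1e-5), StageSpec(50, 5e-5)]
stages_b = [StageSpec(50, 1e-5), StageSpec(50, 3e-5)]

# Same first stage -> same key (cache hit for B); different second stage -> different keys
print(prefix_key(static, stages_a, 1) == prefix_key(static, stages_b, 1))  # True
print(prefix_key(static, stages_a, 2) == prefix_key(static, stages_b, 2))  # False
```

Leaving evaluation-only settings such as `steps_per_eval` and `val_batches` out of the key is a design choice of this sketch: they change what gets logged but presumably not the trained weights, so varying them would not invalidate cached prefixes.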

```python
import importlib

import stage_runner

importlib.reload(stage_runner)  # re-execute stage_runner so notebook edits take effect

from stage_runner import StageTreeRunner, StaticCfg, Stage, Trial

static = StaticCfg(
    model_path="mlx-community/Mistral-7B-Instruct-v0.2-4bit",
    lora_layers=16,
    val_batches=-1,
    steps_per_eval=10,
    dataset_id="youtube-comments-v1",
    tokenizer_id="mistral-tokenizer-v0.2",
)
runner = StageTreeRunner(static, cache_dir="./stage_cache")

# Define trials as stage sequences (warmup -> fine-tune)
# Trial A: 50 iters @ 1e-5, then 50 @ 5e-5
# Trial B: 50 iters @ 1e-5, then 50 @ 3e-5 (same first stage as A)
trial_A = Trial("A_lr1e-5_then_5e-5", [Stage(50, 1e-5), Stage(50, 5e-5)])
trial_B = Trial("B_lr1e-5_then_3e-5", [Stage(50, 1e-5), Stage(50, 3e-5)])

runner.run_trial(trial_A)  # builds both stages
runner.run_trial(trial_B)  # reuses stage 1 from A, trains only stage 2
```

```text
[build] A_lr1e-5_then_5e-5 stage 1/2 (iters=50, lr=1e-05)
>> python /Users/sanjeeb/Coding/HSSL/qlora-mlx/stage-tree/scripts/lora.py --model mlx-community/Mistral-7B-Instruct-v0.2-4bit --train --iters 50 --steps-per-eval 10 --val-batches -1 --learning-rate 1e-05 --lora-layers 16 --adapter-file ./stage_cache/233ac6dad23f2688/adapters.npz
Loading pretrained model
Fetching 7 files: 100%|██████████| 7/7 [00:00<00:00, 93503.59it/s]
Total parameters 1243.189M
Trainable parameters 0.852M
Loading datasets
Training
Iter 1: Val loss 4.250, Val took 15.061s
Iter 10: Train loss 4.034, It/sec 0.113, Tokens/sec 93.701
Iter 10: Val loss 3.076, Val took 13.949s
Iter 20: Train loss 2.774, It/sec 0.099, Tokens/sec 80.333
Iter 20: Val loss 2.292, Val took 14.754s
Iter 30: Train loss 1.778, It/sec 0.097, Tokens/sec 77.261
Iter 30: Val loss 1.633, Val took 15.656s
Iter 40: Train loss 1.363, It/sec 0.104, Tokens/sec 84.485
Iter 40: Val loss 1.507, Val
[...]

'./stage_cache/aa5051b0d4f20a46/adapters.npz'
```

## Tabular Summary

| Trial | Stage | Iters | LR   | Cache   | Runtime (approx) | Notes                                  |
| ----- | ----- | ----- | ---- | ------- | ---------------- | -------------------------------------- |
| A     | 1     | 50    | 1e-5 | Build   | \~710s           | Fresh training from scratch            |
| A     | 2     | 50    | 5e-5 | Build   | \~690s           | Resumed from A-Stage-1                 |
| B     | 1     | 50    | 1e-5 | **Hit** | \~0s             | Reused A-Stage-1 checkpoint            |
| B     | 2     | 50    | 3e-5 | Build   | \~690s           | Resumed from shared prefix (A-Stage-1) |

### Reuse Metrics
- **Total without reuse (naïve):** 4 stages × 50 iters = 200 iterations' worth of compute.
- **With stage reuse:** 150 iterations executed.
    - A1 trained (50 iters).
    - A2 trained (50 iters).
    - B1 skipped (reused A1).
    - B2 trained (50 iters).
- **Saved:** 50 iterations (25%).
- **Speedup (by iteration count):** 200 / 150 ≈ 1.33×.
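The same counting generalizes to any set of trials: a stage is free exactly when its full prefix (every stage up to and including it) has already been trained. A minimal sketch, assuming trials run in the given order and stages are identified only by `(iters, lr)`; `StageSpec` and `reuse_metrics` are hypothetical names, not part of `stage_runner`.

```python
from typing import NamedTuple


class StageSpec(NamedTuple):
    """Hypothetical stage record: iteration count and learning rate."""
    iters: int
    lr: float


def reuse_metrics(trials: dict[str, list[StageSpec]]) -> dict[str, float]:
    """Compare naive vs. prefix-reusing iteration counts.

    Assumes trials run in dict order and a stage is a cache hit exactly when
    the same stage prefix (same stages, same order) was already trained.
    """
    seen: set[tuple[StageSpec, ...]] = set()
    naive = executed = 0
    for stages in trials.values():
        for k in range(1, len(stages) + 1):
            prefix = tuple(stages[:k])
            naive += prefix[-1].iters        # cost if nothing were cached
            if prefix not in seen:           # cache miss: this stage must be trained
                seen.add(prefix)
                executed += prefix[-1].iters
    if naive == 0:
        raise ValueError("no stages to run")
    return {
        "naive": naive,
        "executed": executed,
        "saved_pct": 100 * (naive - executed) / naive,
        "speedup": naive / executed,
    }


metrics = reuse_metrics({
    "A": [StageSpec(50, 1e-5), StageSpec(50, 5e-5)],
    "B": [StageSpec(50, 1e-5), StageSpec(50, 3e-5)],
})
print(metrics)
# {'naive': 200, 'executed': 150, 'saved_pct': 25.0, 'speedup': 1.3333333333333333}
```

Keying on the whole prefix rather than on the individual stage matters: two trials that share a stage-2 config but differ in stage 1 start stage 2 from different weights, so that stage cannot be reused.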


### Stage-Tree Visualization
```
Base model ── Stage1: lr=1e-5 (A1/B1) ─┬─ Stage2: lr=5e-5 (A2)  [Trial A]
                                       └─ Stage2: lr=3e-5 (B2)  [Trial B]
```

- Stage1 (lr=1e-5, 50 iters) is a shared prefix.
    - Trial A runs first, so it builds Stage1 from scratch (cache miss).
    - Trial B reuses it (cache hit).
- The Stage2 branches differ, so both had to be trained independently:
    - Trial A continues with lr=5e-5.
    - Trial B continues with lr=3e-5.

So the tree has one common trunk (A1/B1) and two diverging branches (A2, B2).
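The tree above is a prefix tree (trie) over stage sequences. A minimal sketch of building and printing it from the trial definitions, assuming a stage is just an `(iters, lr)` tuple; `build_stage_tree` and `render` are hypothetical helpers, not `StageTreeRunner` methods.

```python
from typing import Any


def build_stage_tree(trials: dict[str, list[tuple[int, float]]]) -> dict[Any, dict]:
    """Merge stage sequences into a prefix tree; shared prefixes share one node."""
    root: dict[Any, dict] = {}
    for name, stages in trials.items():
        node = root
        for stage in stages:
            child = node.setdefault(stage, {"trials": [], "children": {}})
            child["trials"].append(name)
            node = child["children"]
    return root


def render(node: dict[Any, dict], depth: int = 0) -> list[str]:
    """Depth-first listing; a node used by several trials is trained once and reused."""
    lines = []
    for (iters, lr), info in node.items():
        lines.append(f"{'  ' * depth}- iters={iters}, lr={lr:g}, trials={info['trials']}")
        lines.extend(render(info["children"], depth + 1))
    return lines


tree = build_stage_tree({
    "A": [(50, 1e-5), (50, 5e-5)],
    "B": [(50, 1e-5), (50, 3e-5)],
})
print("\n".join(render(tree)))
# - iters=50, lr=1e-05, trials=['A', 'B']
#   - iters=50, lr=5e-05, trials=['A']
#   - iters=50, lr=3e-05, trials=['B']
```

Every node with more than one trial in its list corresponds to a cache hit for all but the first of those trials.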

```python
import pandas as pd
from pandas.errors import EmptyDataError

LOG_PATH = "stage_log.csv"
NUMERIC_COLS = ["stage_idx", "iters", "lr", "runtime", "cache_hit", "cache_miss"]

# Load everything as strings first so duplicated header rows can't break dtype inference.
# A zero-byte file raises EmptyDataError, so treat that (and a missing file) as "no runs".
try:
    raw = pd.read_csv(LOG_PATH, dtype=str, keep_default_na=False)
except (FileNotFoundError, EmptyDataError):
    raw = pd.DataFrame()

if raw.empty:
    print("⚠️ stage_log.csv is empty — no runs recorded yet.")
else:
    # Drop header rows duplicated by repeated appends: rows whose cells equal the
    # column names, plus rows whose "trial" cell is literally "trial"
    is_header_dup = raw.eq(raw.columns.tolist()).all(axis=1)
    if "trial" in raw.columns:
        is_header_dup |= raw["trial"] == "trial"
    raw = raw[~is_header_dup].copy()

    # Convert numeric columns; unparseable values become 0
    for col in NUMERIC_COLS:
        if col in raw.columns:
            raw[col] = pd.to_numeric(raw[col], errors="coerce").fillna(0)

    print("=== Cleaned log ===")
    print(raw)

    missing = {"iters", "cache_hit"} - set(raw.columns)
    if missing:
        print(f"Cannot compute reuse metrics; missing columns: {sorted(missing)}")
    else:
        # Cache hits were skipped; every other stage was actually trained
        hit = raw["cache_hit"] == 1
        iters_built = int(raw.loc[~hit, "iters"].sum())
        iters_cached = int(raw.loc[hit, "iters"].sum())

        iters_no_reuse = iters_built + iters_cached  # every stage of every trial trained
        iters_with_reuse = iters_built               # only cache misses trained

        saved_iters = iters_no_reuse - iters_with_reuse
        pct_saved = 100 * saved_iters / iters_no_reuse if iters_no_reuse > 0 else 0.0
        speedup = iters_no_reuse / iters_with_reuse if iters_with_reuse > 0 else 1.0

        print("\n=== Reuse summary ===")
        print(f"Iterations without reuse: {iters_no_reuse}")
        print(f"Iterations with reuse:    {iters_with_reuse}")
        print(f"Saved iterations:         {saved_iters}  ({pct_saved:.2f}%)")
        print(f"Speedup:                  {speedup:.2f}×")
```

```text
⚠️ stage_log.csv is empty — no runs recorded yet.
```
