# Hymba v2 — Flash/Efficient SDPA, SWA(global+local), KV-share, Meta tokens

In [1]:
# 0) 준비
import torch, math, time, pandas as pd
from backbone.hymba_v2 import HymbaV2, ModelCfg, TrainCfg, build_everything, train_loop

device = "cuda" if torch.cuda.is_available() else "cpu"

steps = 100
batch_size = 128

model, tok, train_dl, val_dl = build_everything(seq_len=512, bs=batch_size, vocab_size=6000)
display(model.layer_table())
print("est_cache_mb@512:", model.estimate_kv_cache_mb(512))

# 1) 짧게 학습

tcfg = TrainCfg(seq_len=512, batch_size=batch_size, steps=steps, lr=6e-4, warmup=int(steps*0.1), amp=True, grad_clip=1.0)
stats = train_loop(model, train_dl, val_dl, tcfg, device=device)
stats






Unnamed: 0,layer,attn,kv_owner,kv_share_group
0,0,GLOBAL,0,0
1,1,LOCAL(SWA),1,1
2,2,LOCAL(SWA),1,1
3,3,LOCAL(SWA),3,2
4,4,LOCAL(SWA),3,2
5,5,LOCAL(SWA),5,3
6,6,GLOBAL,6,4
7,7,LOCAL(SWA),7,5
8,8,LOCAL(SWA),7,5
9,9,LOCAL(SWA),9,6


est_cache_mb@512: 2.0


2025-10-02 01:09:36.092931: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


[    1] loss=8.813 lr=6.00e-05
[   50] loss=5.440 lr=3.52e-04
[  100] loss=4.412 lr=0.00e+00


{'train_loss': 5.527135457992554,
 'val_loss': 4.630681991577148,
 'ppl': 102.58400168446353,
 'tps': 196059}

In [2]:
# 2) 평가/벤치 유틸
from contextlib import nullcontext

def peak_gpu_mem_mb():
    if device == "cuda":
        torch.cuda.synchronize()
        m = torch.cuda.max_memory_allocated()/(1024**2)
        torch.cuda.reset_peak_memory_stats()
        return round(m,2)
    return 0.0

@torch.no_grad()
def evaluate_ppl(model, val_dl, amp=True):
    model.eval()
    nll=0.0; tok=0
    ctx = (torch.amp.autocast("cuda") if (amp and device=="cuda") else nullcontext())
    with ctx:
        for xb,yb in val_dl:
            xb,yb = xb.to(device), yb.to(device)
            out = model(xb, targets=yb)
            nll += out["loss"].item()*xb.numel()
            tok += xb.numel()
    return math.exp(nll/max(1,tok))
@torch.no_grad()
def bench_generate(model, prompt_len=512, gen_len=256, use_kv_cache=True, kv_share=True, warmup=1, repeat=2):
    model.eval()
    device = next(model.parameters()).device
    vocab = model.cfg.vocab_size
    torch.manual_seed(0)
    prompt = torch.randint(0, vocab, (1, prompt_len), device=device)

    # warmup
    for _ in range(warmup):
        _ = model.generate(prompt, max_new_tokens=16, use_kv_cache=use_kv_cache, kv_share=kv_share)

    if device.type == "cuda":
        torch.cuda.reset_peak_memory_stats()

    import time
    times = []
    for _ in range(repeat):
        if device.type == "cuda": torch.cuda.synchronize()
        t0 = time.time()
        _ = model.generate(prompt, max_new_tokens=gen_len, use_kv_cache=use_kv_cache, kv_share=kv_share)
        if device.type == "cuda": torch.cuda.synchronize()
        times.append(time.time() - t0)

    sec = sum(times) / len(times)
    tps = int((prompt_len + gen_len) / sec)
    mem = 0.0
    if device.type == "cuda":
        mem = torch.cuda.max_memory_allocated() / (1024**2)

    # <<< 표준화된 키명 >>>
    return {
        "gen_latency_s": round(sec, 3),
        "gen_tps": tps,
        "gen_peak_mb": round(mem, 2),
    }



In [3]:
# 3) KV-cache 비교 표 (NoCache / KV / KV+Share)
rows=[]

b1 = bench_generate(model, prompt_len=512, gen_len=256, use_kv_cache=False, kv_share=False)
p1 = evaluate_ppl(model, val_dl, amp=True)
rows.append({"title":"No Cache (recompute)", "ppl":round(p1,3), **b1, "est_cache_mb":"-"})

b2 = bench_generate(model, prompt_len=512, gen_len=256, use_kv_cache=True, kv_share=False)
p2 = evaluate_ppl(model, val_dl, amp=True)
rows.append({"title":"KV Cache (no share)", "ppl":round(p2,3), **b2, "est_cache_mb": model.estimate_kv_cache_mb(512)})

b3 = bench_generate(model, prompt_len=512, gen_len=256, use_kv_cache=True, kv_share=True)
p3 = evaluate_ppl(model, val_dl, amp=True)
rows.append({"title":"KV Cache + SWA Share", "ppl":round(p3,3), **b3, "est_cache_mb": model.estimate_kv_cache_mb(512)})

pd.DataFrame(rows)


Unnamed: 0,title,ppl,gen_latency_s,gen_tps,gen_peak_mb,est_cache_mb
0,No Cache (recompute),102.561,3.794,202,181.38,-
1,KV Cache (no share),102.561,3.591,213,160.99,2.0
2,KV Cache + SWA Share,102.561,3.306,232,158.99,2.0


In [4]:
# (C) 편의: 새 모델 만들기(변이)
def new_model_from(base_cfg:ModelCfg, **kw) -> HymbaV2:
    cfg = ModelCfg(**{**base_cfg.__dict__, **kw})
    m = HymbaV2(cfg).to(device)
    return m

# 동일 학습 레시피(짧게 돌려 비교)
train_recipe = TrainCfg(seq_len=512, batch_size=batch_size, steps=steps, lr=6e-4, warmup=int(steps*0.1), amp=True, grad_clip=1.0)

In [5]:
# (D) bench_train: 단계별로 컴포넌트를 하나씩 추가하며 학습/측정
def bench_train(tok, train_dl, val_dl, base_cfg:ModelCfg, recipe:TrainCfg):
    rows = []

    stages = [
        # title, cfg overrides, gen flags (nocache/kv/kv+share 측정)
        ("0) Global-only (no SWA, no Meta)",  {"swa_layers": (), "num_meta_tokens": 0},  {"kv_share": False}),
        ("1) + SWA (local windows)",          {"swa_layers": base_cfg.swa_layers, "num_meta_tokens": 0}, {"kv_share": False}),
        ("2) + KV-Share (SWA cross-layer)",   {"swa_layers": base_cfg.swa_layers, "num_meta_tokens": 0}, {"kv_share": True}),
        ("3) + MetaTokens (learnable M=4)",   {"swa_layers": base_cfg.swa_layers, "num_meta_tokens": 4}, {"kv_share": True}),
    ]

    for title, cfg_over, gen_flags in stages:
        model = new_model_from(base_cfg, **cfg_over)
        print(f"\n=== {title} ===")
        display(model.layer_table())

        stats = train_loop(model, train_dl, val_dl, recipe, device=device)
        ppl = evaluate_ppl(model, val_dl, amp=True)

        # 생성 속도/메모리 비교: NoCache vs KV(no-share) vs KV(share or not)
        b_nc = bench_generate(model, prompt_len=512, gen_len=256, use_kv_cache=False, kv_share=False)
        b_kv = bench_generate(model, prompt_len=512, gen_len=256, use_kv_cache=True,  kv_share=False)
        b_sh = bench_generate(model, prompt_len=512, gen_len=256, use_kv_cache=True,  kv_share=gen_flags["kv_share"])

        rows.append({
            "title": title,
            "train_loss": round(stats["train_loss"],4),
            "val_loss": round(stats["val_loss"],4),
            "ppl": round(ppl,3),
            # gen benches
            "gen_tps_nocache": b_nc["gen_tps"],
            "gen_tps_kv": b_kv["gen_tps"],
            "gen_tps_share": b_sh["gen_tps"],
            "gen_mb_nocache": b_nc["gen_peak_mb"],
            "gen_mb_kv": b_kv["gen_peak_mb"],
            "gen_mb_share": b_sh["gen_peak_mb"],
            # 추정 캐시 사용량(owners 기준)
            "est_cache_mb@512": model.estimate_kv_cache_mb(512),
        })
    return pd.DataFrame(rows)


In [7]:
# (E) 실행
base_cfg = model.cfg
base_cfg.n_layers = 15
train_recipe.steps = 500

df = bench_train(tok, train_dl, val_dl, base_cfg, train_recipe)
df


=== 0) Global-only (no SWA, no Meta) ===


Unnamed: 0,layer,attn,kv_owner,kv_share_group
0,0,GLOBAL,0,0
1,1,GLOBAL,1,1
2,2,GLOBAL,2,2
3,3,GLOBAL,3,3
4,4,GLOBAL,4,4
5,5,GLOBAL,5,5
6,6,GLOBAL,6,6
7,7,GLOBAL,7,7
8,8,GLOBAL,8,8
9,9,GLOBAL,9,9


[    1] loss=8.803 lr=6.00e-05
[   50] loss=5.440 lr=5.90e-04
[  100] loss=4.657 lr=5.51e-04
[  150] loss=3.288 lr=4.87e-04
[  200] loss=1.708 lr=4.04e-04
[  250] loss=0.772 lr=3.10e-04
[  300] loss=0.216 lr=2.15e-04
[  350] loss=0.039 lr=1.28e-04
[  400] loss=0.016 lr=5.96e-05
[  450] loss=0.011 lr=1.53e-05
[  500] loss=0.011 lr=0.00e+00

=== 1) + SWA (local windows) ===


Unnamed: 0,layer,attn,kv_owner,kv_share_group
0,0,GLOBAL,0,0
1,1,LOCAL(SWA),1,1
2,2,LOCAL(SWA),1,1
3,3,LOCAL(SWA),3,2
4,4,LOCAL(SWA),3,2
5,5,LOCAL(SWA),5,3
6,6,GLOBAL,6,4
7,7,LOCAL(SWA),7,5
8,8,LOCAL(SWA),7,5
9,9,LOCAL(SWA),9,6


[    1] loss=8.795 lr=6.00e-05
[   50] loss=5.393 lr=5.90e-04
[  100] loss=2.891 lr=5.51e-04
[  150] loss=1.472 lr=4.87e-04
[  200] loss=0.275 lr=4.04e-04
[  250] loss=0.018 lr=3.10e-04
[  300] loss=0.004 lr=2.15e-04
[  350] loss=0.001 lr=1.28e-04
[  400] loss=0.001 lr=5.96e-05
[  450] loss=0.000 lr=1.53e-05
[  500] loss=0.000 lr=0.00e+00

=== 2) + KV-Share (SWA cross-layer) ===


Unnamed: 0,layer,attn,kv_owner,kv_share_group
0,0,GLOBAL,0,0
1,1,LOCAL(SWA),1,1
2,2,LOCAL(SWA),1,1
3,3,LOCAL(SWA),3,2
4,4,LOCAL(SWA),3,2
5,5,LOCAL(SWA),5,3
6,6,GLOBAL,6,4
7,7,LOCAL(SWA),7,5
8,8,LOCAL(SWA),7,5
9,9,LOCAL(SWA),9,6


[    1] loss=8.795 lr=6.00e-05
[   50] loss=5.393 lr=5.90e-04
[  100] loss=2.891 lr=5.51e-04
[  150] loss=1.472 lr=4.87e-04
[  200] loss=0.275 lr=4.04e-04
[  250] loss=0.018 lr=3.10e-04
[  300] loss=0.004 lr=2.15e-04
[  350] loss=0.001 lr=1.28e-04
[  400] loss=0.001 lr=5.96e-05
[  450] loss=0.000 lr=1.53e-05
[  500] loss=0.000 lr=0.00e+00

=== 3) + MetaTokens (learnable M=4) ===


Unnamed: 0,layer,attn,kv_owner,kv_share_group
0,0,GLOBAL,0,0
1,1,LOCAL(SWA),1,1
2,2,LOCAL(SWA),1,1
3,3,LOCAL(SWA),3,2
4,4,LOCAL(SWA),3,2
5,5,LOCAL(SWA),5,3
6,6,GLOBAL,6,4
7,7,LOCAL(SWA),7,5
8,8,LOCAL(SWA),7,5
9,9,LOCAL(SWA),9,6


[    1] loss=8.802 lr=6.00e-05
[   50] loss=5.386 lr=5.90e-04
[  100] loss=3.215 lr=5.51e-04
[  150] loss=1.576 lr=4.87e-04
[  200] loss=0.368 lr=4.04e-04
[  250] loss=0.028 lr=3.10e-04
[  300] loss=0.005 lr=2.15e-04
[  350] loss=0.006 lr=1.28e-04
[  400] loss=0.002 lr=5.96e-05
[  450] loss=0.001 lr=1.53e-05
[  500] loss=0.001 lr=0.00e+00


Unnamed: 0,title,train_loss,val_loss,ppl,gen_tps_nocache,gen_tps_kv,gen_tps_share,gen_mb_nocache,gen_mb_kv,gen_mb_share,est_cache_mb@512
0,"0) Global-only (no SWA, no Meta)",1.9633,10.0192,22453.574,166,174,174,25185.07,25168.82,25168.82,3.75
1,1) + SWA (local windows),1.3444,5.1227,167.795,163,171,172,25184.55,25169.2,25169.2,2.75
2,2) + KV-Share (SWA cross-layer),1.3444,5.1227,167.795,158,168,180,25185.07,25168.88,25166.88,2.75
3,3) + MetaTokens (learnable M=4),1.4362,5.107,165.167,163,171,184,25184.61,25168.57,25166.57,2.75
