
# Hymba Ablation Notebook (Table-Style Replication)
첨부 표의 **논리 흐름(1→13)**에 맞춰, 하나의 **완성된 Hymba 모델**을 정의하고 설정만 바꿔가며
작은 데이터셋에 **직접 학습·평가**할 수 있도록 구성했습니다.

**핵심 특징**
- 현대적 **Unigram(SentencePiece 계열)** 토크나이저를 **직접 학습**
- 데이터셋 경로가 없으면 자동으로 🤗 `datasets`에서 다운로드 (torchtext 불사용)
- 모델: 하이브리드(Attention+Mamba), **SWA**, **Cross-layer KV sharing**, **Meta Tokens**, **GQA on/off**, **Fusion(Mean/Concat)**
- 훈련: **AdamW + Cosine with Warmup**, **AMP**, **Grad Clip**
- 결과: **pandas DataFrame**으로 **Throughput(tokens/s)**, **PPL**, **토큰 정확도**, **KV‑cache(MB)**, **Attn:Mamba 비율**을 보기 좋게 요약


## 0) 의존성 설치 (파이썬 코드로 실행)

In [None]:
import sys, subprocess, pkg_resources
def pip_install(pkgs):
    for p in pkgs:
        try:
            pkg_resources.get_distribution(p.split('==')[0].split('>=')[0])
        except Exception:
            subprocess.check_call([sys.executable, "-m", "pip", "install", p])
pip_install([
    "torch>=2.2.0",
    "datasets>=2.18.0",
    "tokenizers>=0.15.2",
    # Provides get_cosine_schedule_with_warmup, imported by the training-utilities cell.
    "transformers>=4.40.0",
    "accelerate>=0.32.0",
    "mamba-ssm>=2.2.2",
    "pandas>=2.1.0",
    "matplotlib>=3.8.0"
])
import os, math, time, dataclasses, typing as T, numpy as np, pandas as pd
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from datasets import load_dataset
# Run everything on the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device



## 1) 실험 설정
- A100 80GB 기준 기본값. 데모 목적으로 `max_steps`는 작게 두었습니다(필요 시 키우세요).


In [None]:
@dataclasses.dataclass
class CFG:
    """Experiment configuration shared by every ablation variant."""
    dataset: str = "karpathy/tiny_shakespeare"  # local path or HF hub name
    seq_len: int = 512
    batch_size: int = 16
    epochs: int = 1
    max_steps: int = 150
    lr: float = 3e-4
    weight_decay: float = 0.1
    warmup_steps: int = 40
    grad_accum: int = 1
    grad_clip: float = 1.0
    amp: bool = True
    # tokenizer
    vocab_size: int = 32000
    bos_token: str = "<|bos|>"
    eos_token: str = "<|eos|>"
    pad_token: str = "<|pad|>"
    # model base
    d_model: int = 512
    n_layers: int = 12
    n_heads: int = 8
    n_kv_heads: int = 4
    # SWA
    swa_window: int = 256
    # train/val split
    val_ratio: float = 0.05
    # Fields read by build_model_variant (previously missing -> AttributeError).
    num_meta_tokens: int = 128   # learnable tokens prepended when a variant enables meta tokens
    max_position: int = 65536    # RoPE table size; must cover seq_len + num_meta_tokens
    ffn_mult: float = 4.0        # SwiGLU hidden width multiplier

cfg = CFG(); cfg



## 2) 데이터 & 토크나이저
- 경로가 없으면 🤗 `datasets`에서 자동 다운로드
- 최신 **Unigram** 토크나이저를 **직접 학습**(BOS/EOS/PAD 포함)


In [None]:
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.normalizers import NFKC, Lowercase, Sequence as NormSeq
from tokenizers.processors import TemplateProcessing

def read_local_text(path: str):
    """Load text from a local file, or concatenate every ``.txt`` file under a directory.

    Returns None when `path` does not exist, or when it is a directory containing no
    ``.txt`` files. Directory contents are joined with blank lines in ``os.walk`` order.
    Files are decoded as UTF-8 with undecodable bytes ignored.
    """
    if not os.path.exists(path):
        return None
    if not os.path.isdir(path):
        with open(path, "r", encoding="utf-8", errors="ignore") as handle:
            return handle.read()
    collected = []
    for dirpath, _subdirs, filenames in os.walk(path):
        txt_names = [name for name in filenames if name.endswith(".txt")]
        for name in txt_names:
            full_path = os.path.join(dirpath, name)
            with open(full_path, "r", encoding="utf-8", errors="ignore") as handle:
                collected.append(handle.read())
    if not collected:
        return None
    return "\n\n".join(collected)

def get_corpus(spec:str)->str:
    """Return the raw training text for `spec`.

    `spec` is first tried as a local file/directory (via ``read_local_text``); otherwise it is
    loaded with ``datasets.load_dataset``. From the "train" split (or the first split if there
    is no "train"), the first known text column is joined with blank lines. If none of the
    known names exist, the first string-typed column is used. As a last resort the first
    1000 rows are stringified.
    """
    text = read_local_text(spec)
    if text is not None:
        return text
    ds = load_dataset(spec)
    split = ds["train"] if "train" in ds else ds[next(iter(ds))]
    # heuristic column choice
    for k in ["text","content","document","raw","data"]:
        if k in split.column_names:
            return "\n\n".join(str(t) for t in split[k] if t is not None)
    # Next best: any plain string column.
    for name, feature in split.features.items():
        if getattr(feature, "dtype", None) in ("string", "large_string"):
            return "\n\n".join(str(t) for t in split[name] if t is not None)
    # Fallback: stringify whole rows (iterating a Dataset yields one dict per row). The old code
    # iterated split[:1000], which is a dict of columns, and so joined column *names*.
    head = split.select(range(min(1000, len(split))))
    return "\n\n".join(str(row) for row in head)

def train_unigram_tokenizer(corpus_text:str, vocab_size:int, bos:str, eos:str, pad:str, unk:str="<|unk|>"):
    """Train a Unigram tokenizer on `corpus_text` and wrap it in a small encode/decode helper.

    Text is NFKC-normalized and lowercased, then split by the ``Whitespace`` pre-tokenizer
    before the Unigram model is fit. Every single encoded sequence is wrapped as
    ``bos ... eos``; pairs become ``bos A eos B eos``.

    Args:
        corpus_text: raw training text.
        vocab_size: target vocabulary size passed to the trainer.
        bos, eos, pad: special-token strings registered with the trainer.
        unk: special token for characters never seen during training. Without an unk token
            the trained Unigram model has no unknown id, and ``tokenizers`` errors when it
            encodes such text.

    Returns:
        A ``ModernTok`` exposing ``encode``, ``decode``, ``vocab_size`` and
        ``bos_token_id`` / ``eos_token_id`` / ``pad_token_id``.
    """
    def batch_iter(text, bs=1000000):
        # Feed the trainer fixed-size slices of the corpus.
        for i in range(0, len(text), bs):
            yield text[i:i+bs]
    tok = Tokenizer(Unigram())
    tok.normalizer = NormSeq([NFKC(), Lowercase()])
    tok.pre_tokenizer = Whitespace()
    # unk is listed after the existing specials so their relative order is unchanged.
    trainer = UnigramTrainer(vocab_size=vocab_size, special_tokens=[pad, bos, eos, unk], unk_token=unk)
    tok.train_from_iterator(batch_iter(corpus_text), trainer=trainer)
    pad_id = tok.token_to_id(pad); bos_id = tok.token_to_id(bos); eos_id = tok.token_to_id(eos)
    tok.post_processor = TemplateProcessing(
        single=f"{bos} $A {eos}",
        pair=f"{bos} $A {eos} $B:1 {eos}:1",
        special_tokens=[(bos, bos_id), (eos, eos_id)]
    )
    class ModernTok:
        """Thin wrapper exposing list-of-ids encode/decode and special-token ids."""
        def __init__(self, tk, bos_id, eos_id, pad_id):
            self.tk=tk; self.bos_token_id=bos_id; self.eos_token_id=eos_id; self.pad_token_id=pad_id
        def encode(self, s): return self.tk.encode(s).ids
        def decode(self, ids): return self.tk.decode(ids)
        @property
        def vocab_size(self): return self.tk.get_vocab_size()
    return ModernTok(tok, bos_id, eos_id, pad_id)

# Load the corpus, fit the tokenizer on it, and expose vocab size / EOS id to later cells.
corpus = get_corpus(cfg.dataset)
tok = train_unigram_tokenizer(
    corpus_text=corpus,
    vocab_size=cfg.vocab_size,
    bos=cfg.bos_token,
    eos=cfg.eos_token,
    pad=cfg.pad_token,
)
VOCAB = tok.vocab_size
EOS_ID = tok.eos_token_id
VOCAB, EOS_ID



### 2-1) 토큰화 & 청크 → Train/Val Split


In [None]:
def chunk_tokens(text:str, seq_len:int, tok=tok, eos_id=EOS_ID):
    """Tokenize `text` and cut it into non-overlapping (input, next-token target) windows.

    Window i uses ``ids[i*L:(i+1)*L]`` as input and ``ids[i*L+1:(i+1)*L+1]`` as target
    (L = seq_len). Every target is therefore the true next token, including the last position
    of each window, which previously was overwritten with EOS. Trailing tokens that do not
    fill a complete window are dropped.

    Args:
        text: raw text to tokenize.
        seq_len: window length L.
        tok: tokenizer object with an ``encode(str) -> list[int]`` method.
        eos_id: kept for signature compatibility; no longer written into targets.

    Returns:
        TensorDataset of int64 tensors X, Y, each of shape (n_windows, seq_len).
    """
    ids = np.asarray(tok.encode(text), dtype=np.int64)
    # One extra token is needed beyond the inputs so the final target exists.
    n_windows = max(0, (len(ids) - 1) // seq_len)
    n = n_windows * seq_len
    X = torch.tensor(ids[:n].reshape(-1, seq_len))
    Y = torch.tensor(ids[1:n + 1].reshape(-1, seq_len))
    return TensorDataset(X, Y)

full_ds = chunk_tokens(corpus, cfg.seq_len, tok, EOS_ID)
if len(full_ds) < 2:
    raise ValueError(f"corpus yields only {len(full_ds)} chunk(s) of seq_len={cfg.seq_len}; "
                     "need at least 2 for a train/val split")
val_len = max(1, int(len(full_ds)*cfg.val_ratio))
train_len = len(full_ds) - val_len
# Fixed-seed generator so the train/val partition is identical across runs, and every variant
# in the ablation is compared on the same data.
split_gen = torch.Generator().manual_seed(getattr(cfg, "seed", 0))
train_ds, val_ds = random_split(full_ds, [train_len, val_len], generator=split_gen)
train_dl = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, drop_last=True)
val_dl   = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, drop_last=False)
if len(train_dl) == 0:
    # drop_last=True discards the only (partial) batch when train_len < batch_size.
    raise ValueError(f"train split has {train_len} chunks (< batch_size={cfg.batch_size}); "
                     "lower batch_size/seq_len or use a larger corpus")
len(train_ds), len(val_ds)



## 3) 모델: 하나의 구현으로 모든 설정을 토글
- **Attn:Mamba 차원 비율**(Param Ratio)로 경로 분할
- **GQA on/off** (`n_kv_heads < n_heads`면 GQA)
- **SWA 레이어 집합**(전부/일부/없음), **Cross-layer KV sharing**
- **Fusion**: `mean`(기본) vs `concat`
- **Meta Tokens**: on/off


In [None]:
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (interleaved-pair variant) with precomputed cos/sin tables.

    Args:
        dim: feature size the rotation is applied to (per-head size); must be even.
        max_position: number of positions for which cos/sin are precomputed.
        base: frequency base of the rotation angles.
    """
    def __init__(self, dim, max_position=131072, base=10000.0):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RotaryEmbedding needs an even dim, got {dim}")
        inv = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_position, dtype=torch.float32)
        freqs = torch.outer(t, inv)  # (max_position, dim/2)
        self.register_buffer("cos_cached", torch.cos(freqs), persistent=False)
        self.register_buffer("sin_cached", torch.sin(freqs), persistent=False)

    def forward(self, x, positions):
        """Rotate the last dim of `x` (shape ``(..., T, dim)``) by the angles of `positions`.

        `positions` is a length-T index tensor. The result has the same shape and dtype as `x`.
        """
        # (T, dim/2) right-aligns with x's (..., T, dim/2). The old [:, None, None, :] indexing
        # produced (T, 1, 1, dim/2), which broadcast T against the batch axis.
        cos = self.cos_cached[positions].to(x.dtype)
        sin = self.sin_cached[positions].to(x.dtype)
        x1 = x[..., ::2]; x2 = x[..., 1::2]
        return torch.stack([x1*cos - x2*sin, x1*sin + x2*cos], dim=-1).flatten(-2)

class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dim with a learned per-channel gain `w`."""

    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.w = nn.Parameter(torch.ones(d))

    def forward(self, x):
        # Scale by 1/sqrt(mean(x^2) + eps); no mean subtraction and no bias.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms * self.w

class SwiGLU_FFN(nn.Module):
    """Gated feed-forward block computing ``w3(dropout(silu(w1 x) * w2 x))``.

    The hidden width is ``int(d * mult)``; none of the projections has a bias.
    """

    def __init__(self, d, mult=4.0, dropout=0.0):
        super().__init__()
        hidden = int(d * mult)
        self.w1 = nn.Linear(d, hidden, bias=False)   # gate branch
        self.w2 = nn.Linear(d, hidden, bias=False)   # value branch
        self.w3 = nn.Linear(hidden, d, bias=False)   # back to model width
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(self.drop(gate * value))

def band_mask(T, w, device, dtype):
    """Additive (T, T) causal sliding-window mask.

    Entry (i, j) is 0 when ``i - w < j <= i`` (key j is at most w-1 steps behind query i),
    and ``-inf`` otherwise.
    """
    rows = torch.arange(T, device=device).unsqueeze(1)
    cols = torch.arange(T, device=device).unsqueeze(0)
    visible = (cols <= rows) & (cols > rows - w)
    blocked = torch.full((T, T), float("-inf"), device=device, dtype=dtype)
    return blocked.masked_fill(visible, 0.0)

class GQA(nn.Module):
    def __init__(self, d_model, H, KV, rope=None, use_swa=False, swa_window=4096, dropout=0.0):
        super().__init__(); assert (H==0) or (d_model%max(1,H)==0)
        self.enabled = (d_model>0 and H>0)
        self.H=H; self.KV=KV if KV>0 else H
        self.rep=max(1,self.H//self.KV); self.Dh = (d_model//max(1,self.H)) if self.enabled else 1
        self.q=nn.Linear(d_model, self.H*self.Dh, bias=False) if self.enabled else None
        self.k=nn.Linear(d_model, self.KV*self.Dh, bias=False) if self.enabled else None
        self.v=nn.Linear(d_model, self.KV*self.Dh, bias=False) if self.enabled else None
        self.o=nn.Linear(self.H*self.Dh, d_model, bias=False) if self.enabled else None
        self.drop=nn.Dropout(dropout); self.rope=rope; self.use_swa=use_swa; self.swa_window=swa_window
    def forward(self, x, kv_cache=None, global_mask=None):
        if not self.enabled: 
            return torch.zeros_like(x), None
        B,T,C=x.shape
        q=self.q(x).view(B,T,self.H,self.Dh).transpose(1,2)
        k=self.k(x).view(B,T,self.KV,self.Dh).transpose(1,2)
        v=self.v(x).view(B,T,self.KV,self.Dh).transpose(1,2)
        if kv_cache is not None and kv_cache[0] is not None and kv_cache[0].numel()>0:
            pk,pv=kv_cache; k=torch.cat([pk,k],2); v=torch.cat([pv,v],2)
        if self.rope is not None:
            pos_q=torch.arange(k.size(2)-T, k.size(2), device=x.device)
            pos_k=torch.arange(0, k.size(2), device=x.device)
            q=self.rope(q, pos_q); k_exp=k.repeat_interleave(self.rep,1); k_exp=self.rope(k_exp, pos_k)
        else:
            k_exp=k.repeat_interleave(self.rep,1)
        v_exp=v.repeat_interleave(self.rep,1)
        Tc=k_exp.size(2)
        if self.use_swa:
            M = band_mask(Tc, self.swa_window, x.device, q.dtype)[:, -T:]
        else:
            i = torch.arange(Tc, device=x.device); j = torch.arange(Tc, device=x.device)
            causal = (j[None,:] <= i[:,None]); causal = causal[:, -T:]
            M = torch.zeros((T,Tc), device=x.device, dtype=q.dtype).masked_fill(~causal, float("-inf"))
        if global_mask is not None:
            full = torch.zeros_like(M)
            M = torch.where(global_mask[:, :, None], full[None,:,:], M[None,:,:]).squeeze(0)
        q_=q.reshape(B*self.H,T,self.Dh); k_=k_exp.reshape(B*self.H,Tc,self.Dh); v_=v_exp.reshape(B*self.H,Tc,self.Dh)
        M_=M.unsqueeze(0).expand(B*self.H, -1, -1)
        out = F.scaled_dot_product_attention(q_, k_, v_, attn_mask=M_, is_causal=False,
                                             dropout_p=float(self.drop.p) if self.training else 0.0)
        out=out.view(B,self.H,T,self.Dh).transpose(1,2).reshape(B,T,self.H*self.Dh)
        out=self.o(out)
        new_cache=(k.detach(), v.detach())
        return out, new_cache

try:
    from mamba_ssm import Mamba as _Mamba
except Exception:
    _Mamba=None
class MambaLayer(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__(); self.enabled=(d_model>0); 
        self.inner = (_Mamba(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand) if (_Mamba and self.enabled) else nn.Identity())
    def forward(self, x): return self.inner(x)

class HymbaBlock(nn.Module):
    """One hybrid layer: parallel attention and Mamba paths, fusion, then a SwiGLU FFN.

    Both paths read the same RMS-normalized input, projected to `attn_dim` / `mamba_dim`.
    Their outputs are fused back to `d_model` either by averaging separate projections
    (``fusion="mean"``) or by one projection of their concatenation (``fusion="concat"``).
    Both the mixer and the FFN are pre-norm residual sublayers with their own RMSNorm.

    Raises:
        ValueError: if both `attn_dim` and `mamba_dim` are <= 0.
    """
    def __init__(self, d_model, attn_dim, attn_heads, kv_heads, use_swa, swa_window, rope,
                 mamba_dim, ffn_mult=4.0, dropout=0.0, fusion:str="mean"):
        super().__init__()
        if attn_dim <= 0 and mamba_dim <= 0:
            raise ValueError("HymbaBlock needs attn_dim > 0 or mamba_dim > 0")
        self.attn_dim=attn_dim; self.mamba_dim=mamba_dim; self.fusion=fusion
        self.pre=RMSNorm(d_model)  # norm before the attention/Mamba mixer
        self.to_a=nn.Linear(d_model, attn_dim, bias=False) if attn_dim>0 else None
        self.to_m=nn.Linear(d_model, mamba_dim, bias=False) if mamba_dim>0 else None
        self.attn=GQA(attn_dim, attn_heads, kv_heads, rope=rope, use_swa=use_swa, swa_window=swa_window, dropout=dropout) if attn_dim>0 else None
        self.mamba=MambaLayer(mamba_dim) if mamba_dim>0 else None
        if fusion=="concat":
            self.mix = nn.Linear(attn_dim+mamba_dim, d_model, bias=False)
            self.mix2= None
        else:  # mean
            self.mix = nn.Linear(attn_dim, d_model, bias=False) if attn_dim>0 else None
            self.mix2= nn.Linear(mamba_dim, d_model, bias=False) if mamba_dim>0 else None
        self.ffn=SwiGLU_FFN(d_model, mult=ffn_mult, dropout=dropout); self.drop=nn.Dropout(dropout)
        # Separate norm for the FFN sublayer; previously `self.pre` was reused, tying the two norms.
        self.pre_ffn=RMSNorm(d_model)

    def forward(self, x, kv_cache=None, global_mask=None):
        """Apply the block to `x`; returns ``(x_out, new_kv_cache)`` (cache is None without attention)."""
        h = self.pre(x)
        attn_out = mamba_out = None
        new_cache = None
        if self.attn is not None:
            attn_out, new_cache = self.attn(self.to_a(h), kv_cache=kv_cache, global_mask=global_mask)
        if self.mamba is not None:
            mamba_out = self.mamba(self.to_m(h))
        if self.fusion == "concat":
            parts = [p for p in (attn_out, mamba_out) if p is not None]
            y = self.mix(torch.cat(parts, -1))
        else:
            projected = []
            if attn_out is not None:
                projected.append(self.mix(attn_out))
            if mamba_out is not None:
                projected.append(self.mix2(mamba_out))
            y = sum(projected) / len(projected)
        x = x + self.drop(y)
        x = x + self.drop(self.ffn(self.pre_ffn(x)))
        return x, new_cache

class Hymba(nn.Module):
    """Hybrid attention/Mamba decoder-only language model built from HymbaBlocks.

    Args:
        vocab_size: vocabulary size (embedding rows and output logits).
        d_model: residual-stream width.
        n_layers: number of HymbaBlocks.
        n_heads, n_kv_heads: query / key-value head counts of the attention path.
        attn_dim, mamba_dim: widths of the attention and Mamba paths in each block.
        swa_layers: indices of layers using sliding-window attention.
        swa_window: sliding-window width for those layers.
        fusion: "mean" or "concat" fusion of the two paths.
        num_meta_tokens: learnable tokens prepended to every multi-token input (0 disables).
        kv_share: if True, each odd layer 2k+1 is assigned owner layer 2k for KV caches.
        max_position: size of the RoPE tables.
        ffn_mult, dropout: forwarded to every block.
    """
    def __init__(self, vocab_size:int, d_model:int, n_layers:int, n_heads:int, n_kv_heads:int,
                 attn_dim:int, swa_layers:set, swa_window:int, mamba_dim:int, fusion:str="mean",
                 num_meta_tokens:int=0, kv_share:bool=True, max_position:int=65536, ffn_mult:float=4.0, dropout:float=0.0):
        super().__init__()
        self.vocab_size=vocab_size; self.d_model=d_model; self.n_layers=n_layers
        self.num_meta_tokens=num_meta_tokens; self.kv_share=kv_share
        self.swa_window=swa_window  # kept for the SWA-aware cache estimate
        self.tok=nn.Embedding(vocab_size, d_model)
        self.rope=RotaryEmbedding(dim=(attn_dim//max(1,n_heads)), max_position=max_position) if attn_dim>0 else None
        self.swa_layers=swa_layers
        self.blocks=nn.ModuleList([
            HymbaBlock(d_model, attn_dim, n_heads, n_kv_heads, (i in swa_layers), swa_window, self.rope, mamba_dim,
                      ffn_mult=ffn_mult, dropout=dropout, fusion=fusion)
            for i in range(n_layers)
        ])
        self.norm=RMSNorm(d_model); self.head=nn.Linear(d_model, vocab_size, bias=False)
        self.meta = nn.Parameter(torch.randn(1, num_meta_tokens, d_model)*0.02) if num_meta_tokens>0 else None
        # KV owner mapping for cross-layer sharing: layer a+1 uses layer a's cache slot.
        self.owner=list(range(n_layers))
        if kv_share:
            for a in range(0,n_layers-1,2): self.owner[a+1]=a

    def forward(self, idx, targets=None, kv_caches=None):
        """Run the model on token ids `idx` of shape (B, T).

        Meta tokens (if enabled) are prepended only when T > 1, so the returned ``logits`` have
        ``num_meta_tokens + T`` rows in that case. ``loss`` is the cross-entropy of the text
        rows against `targets`. ``kv_caches`` is a per-layer list filled only at owner indices,
        or None when no caches were passed.

        NOTE: cross-layer sharing only changes which cache slot each layer reads/writes when
        `kv_caches` is given. During training (``kv_caches=None``) every layer still computes
        its own keys/values.
        """
        B, T = idx.shape
        x = self.tok(idx)
        meta_add = 0
        if self.meta is not None and T > 1:
            x = torch.cat([self.meta.expand(B, -1, -1), x], 1)
            meta_add = self.num_meta_tokens
        gmask = None
        if meta_add > 0:
            # Mark meta positions so attention layers keep them globally visible.
            gmask = torch.zeros((B, x.size(1)), dtype=torch.bool, device=x.device)
            gmask[:, :meta_add] = True
        new = [None] * self.n_layers if kv_caches is not None else None
        h = x
        for li, blk in enumerate(self.blocks):
            owner = self.owner[li]
            kv_in = kv_caches[owner] if kv_caches is not None else None
            h, kv_out = blk(h, kv_cache=kv_in, global_mask=gmask)
            if new is not None and li == owner:
                new[owner] = kv_out
        h = self.norm(h)
        logits = self.head(h)
        loss = None
        if targets is not None:
            lf = logits[:, meta_add:meta_add + targets.size(1), :]
            loss = F.cross_entropy(lf.reshape(-1, lf.size(-1)), targets.reshape(-1))
        return {"logits": logits, "loss": loss, "kv_caches": new}

    def estimate_kv_cache_bytes(self, seq_len:int, attn_dim:int, n_kv_heads:int, n_heads:int, dtype=torch.float16):
        """Estimate the KV-cache size in bytes needed to hold `seq_len` text tokens.

        This is the requirement of a cache that evicts tokens outside the sliding window in
        SWA layers while always keeping meta tokens. Per layer the cached length is
        ``min(seq_len, swa_window)`` (SWA) or ``seq_len`` (global), plus ``num_meta_tokens``.
        Layers that share an owner need one cache sized for the longest member.
        Note: ``GQA`` itself never evicts, so this is a theoretical bound, not measured memory.
        """
        if attn_dim<=0 or n_heads<=0: return 0
        Dh = attn_dim // n_heads
        KV = n_kv_heads if n_kv_heads>0 else n_heads
        bytes_per = torch.empty((), dtype=dtype).element_size()

        def cached_len(layer):
            text_len = min(seq_len, self.swa_window) if layer in self.swa_layers else seq_len
            return text_len + self.num_meta_tokens

        group_len = {}
        for layer, owner in enumerate(self.owner):
            group_len[owner] = max(group_len.get(owner, 0), cached_len(layer))
        return 2 * KV * Dh * bytes_per * sum(group_len.values())  # factor 2: K and V

def dims_from_ratio(d_model:int, ratio_attn:float, ratio_mamba:float):
    """Split `d_model` between the attention and Mamba paths in proportion ratio_attn : ratio_mamba.

    Returns:
        ``(attn_dim, mamba_dim)`` with ``attn_dim + mamba_dim == d_model``. attn_dim is the
        rounded proportional share, clamped to [0, d_model].

    Raises:
        ValueError: if a ratio is negative or both are zero.
    """
    if ratio_attn < 0 or ratio_mamba < 0:
        raise ValueError(f"ratios must be non-negative, got {ratio_attn}:{ratio_mamba}")
    if ratio_attn == 0 and ratio_mamba == 0:
        raise ValueError("at least one of ratio_attn / ratio_mamba must be > 0")
    if ratio_attn == 0:
        return 0, d_model
    if ratio_mamba == 0:
        return d_model, 0
    frac = ratio_attn / (ratio_attn + ratio_mamba)
    attn_dim = max(0, min(d_model, int(round(d_model * frac))))
    return attn_dim, d_model - attn_dim



## 4) 훈련/평가 유틸(AMP, Cosine+Warmup, 토큰 정확도, Throughput)


In [None]:
from transformers import get_cosine_schedule_with_warmup

def build_model_variant(ratio_attn, ratio_mamba, n_heads, n_kv_heads, swa_mode:str, fusion:str,
                        use_meta:bool, kv_share:bool):
    """Build one ablation variant of Hymba on `device`.

    The attention width from the Attn:Mamba ratio is snapped to a multiple of ``2 * n_heads``
    (at least one such unit), and the Mamba path takes the remainder of ``cfg.d_model``.
    This is required because each head needs an integer head dim, and RoPE needs that dim
    to be even. Unsnapped values such as 54 or 110 break the head split for n_heads=8.

    swa_mode: "none" (no SWA), "all" (every layer), "default" (all but the first, middle and
    last layers); any other value falls back to no SWA.

    Returns:
        ``(model, attn_dim, mamba_dim)``.
    """
    attn_dim, mamba_dim = dims_from_ratio(cfg.d_model, ratio_attn, ratio_mamba)
    if attn_dim > 0:
        unit = 2 * max(1, n_heads)
        if cfg.d_model < unit:
            raise ValueError(f"d_model={cfg.d_model} is too small for {n_heads} heads with even head dim")
        units = min(cfg.d_model // unit, max(1, int(round(attn_dim / unit))))
        attn_dim = units * unit
        mamba_dim = cfg.d_model - attn_dim
    # SWA layer set
    if swa_mode=="none":
        swa_layers=set()
    elif swa_mode=="all":
        swa_layers=set(range(cfg.n_layers))
    elif swa_mode=="default":  # first/middle/last are global; others SWA
        mid=cfg.n_layers//2
        swa_layers=set([i for i in range(cfg.n_layers) if i not in (0,mid,cfg.n_layers-1)])
    else:
        swa_layers=set()
    num_meta = getattr(cfg, "num_meta_tokens", 128) if use_meta else 0
    # RoPE tables must cover the text plus any prepended meta tokens.
    max_position = max(getattr(cfg, "max_position", 65536), cfg.seq_len + num_meta)
    model = Hymba(
        vocab_size=VOCAB, d_model=cfg.d_model, n_layers=cfg.n_layers, n_heads=n_heads, n_kv_heads=n_kv_heads,
        attn_dim=attn_dim, swa_layers=swa_layers, swa_window=cfg.swa_window, mamba_dim=mamba_dim, fusion=fusion,
        num_meta_tokens=num_meta, kv_share=kv_share, max_position=max_position,
        ffn_mult=getattr(cfg, "ffn_mult", 4.0), dropout=0.0
    ).to(device)
    return model, attn_dim, mamba_dim

def tokens_accuracy(logits, targets):
    """Fraction of positions whose arg-max prediction equals the target id, as a Python float."""
    hits = logits.argmax(dim=-1).eq(targets)
    return hits.float().mean().item()

def train_one(model:nn.Module, train_dl, val_dl, attn_dim, n_kv_heads, n_heads):
    """Train one variant, then measure validation quality, throughput and estimated KV cache.

    Training runs for at most ``cfg.max_steps`` micro-batches (and at most ``cfg.epochs``
    passes over `train_dl`) with AdamW, a cosine schedule with warmup, gradient accumulation
    and clipping. On CUDA with ``cfg.amp`` it also uses fp16 autocast with loss scaling.

    Returns:
        dict with ``ppl`` (validation perplexity), ``acc`` (token-level next-token accuracy),
        ``tps`` (training tokens/s), ``cache_mb`` (estimated KV cache in MiB) and ``steps``
        (micro-batches run).
    """
    use_amp = bool(cfg.amp) and device.type == "cuda"  # autocast fp16 / GradScaler only make sense on CUDA
    accum = max(1, int(cfg.grad_accum))
    # Size the schedule by the micro-batches that will actually run, converted to optimizer updates.
    try:
        planned = min(cfg.max_steps, cfg.epochs * len(train_dl))
    except TypeError:  # loader without __len__
        planned = cfg.max_steps
    n_updates = max(1, math.ceil(planned / accum))
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay, betas=(0.9,0.95), eps=1e-8)
    sch = get_cosine_schedule_with_warmup(opt, num_warmup_steps=min(cfg.warmup_steps, n_updates),
                                          num_training_steps=n_updates)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    def _sync():
        # CUDA kernels are async; synchronize so wall-clock timing covers the real work.
        if device.type == "cuda":
            torch.cuda.synchronize()

    model.train()
    opt.zero_grad(set_to_none=True)
    step = 0; tokens = 0
    _sync(); t0 = time.time()
    for _ in range(cfg.epochs):
        for xb, yb in train_dl:
            xb, yb = xb.to(device), yb.to(device)
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
                loss = model(xb, targets=yb)["loss"]
            # Average over the accumulation window so the update size does not grow with grad_accum.
            scaler.scale(loss / accum).backward()
            step += 1; tokens += xb.numel()
            if step % accum == 0:
                scaler.unscale_(opt)
                if cfg.grad_clip > 0:
                    nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
                scaler.step(opt); scaler.update()
                opt.zero_grad(set_to_none=True)
                sch.step()
            if step >= cfg.max_steps: break
        if step >= cfg.max_steps: break
    _sync()
    elapsed = time.time() - t0
    tps = tokens / max(1e-9, elapsed)

    # Evaluate: token-weighted NLL and accuracy over the whole validation set.
    model.eval(); nll = 0.0; n_tok = 0; n_correct = 0.0
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
        for xb, yb in val_dl:
            xb, yb = xb.to(device), yb.to(device)
            out = model(xb, targets=yb)
            n = yb.numel()
            nll += out["loss"].item() * n
            n_tok += n
            # Meta tokens (if any) are prepended, so the text predictions are the LAST T rows.
            logits = out["logits"][:, -yb.size(1):, :]
            n_correct += tokens_accuracy(logits, yb) * n
    ppl = math.exp(nll / n_tok) if n_tok else float("nan")
    acc = n_correct / n_tok if n_tok else float("nan")
    cache_mb = model.estimate_kv_cache_bytes(cfg.seq_len, attn_dim, n_kv_heads, n_heads)/2**20
    return {"ppl":ppl, "acc":acc, "tps":tps, "cache_mb":cache_mb, "steps":step}



## 5) 실험 매트릭스(1→13)
- **Param Ratio (Attn:Mamba)**는 표 수치를 반영(예: `1:8.48` → 비율=1/8.48)
- **6)**은 **Attention-Only** (Mamba 0), **10)**은 **6의 GQA 버전**
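- **7)**은 모든 레이어에 SWA 적용, **8)**은 첫 번째·중간·마지막 레이어만 Full Attention이고 나머지 레이어에는 SWA 적용
- **9**는 8)+**Cross-layer KV sharing** (인접한 두 레이어가 하나의 KV 캐시 슬롯을 공유. 현재 구현에서는 추론 시 KV 캐시와 Cache(MB) 추정에만 반영되며, 학습 중에는 각 레이어가 여전히 자체 K/V를 계산함)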
- **11**은 9)+**Fusion=Concat**
- **12**는 1)+**Meta Tokens**
- **13**은 9)+**Meta Tokens**


In [None]:
kv_gqa = max(1, cfg.n_kv_heads)  # KV-head count for the GQA variants
exps = [
    # (id, desc, (attn, mamba) ratio, n_heads, n_kv_heads, swa_mode, fusion, meta, kv_share)
    (1, "Mamba Heads Only",               (0.0, 1.0), cfg.n_heads, cfg.n_kv_heads, "none",    "mean",   False, False),
    (2, "Mamba + 4 Attn Heads",           (1.0, 8.48), cfg.n_heads, cfg.n_heads,   "none",    "mean",   False, False),  # no GQA
    (3, "Mamba + 8 Attn Heads",           (1.0, 4.24), cfg.n_heads, cfg.n_heads,   "none",    "mean",   False, False),  # no GQA
    (4, "Mamba + 16 Attn Heads",          (1.0, 2.12), cfg.n_heads, cfg.n_heads,   "none",    "mean",   False, False),  # no GQA
    (5, "4) + GQA",                       (1.0, 3.64), cfg.n_heads, kv_gqa,        "none",    "mean",   False, False),
    (6, "Attn Heads Only (Llama)",        (1.0, 0.0), cfg.n_heads, cfg.n_heads,    "none",    "mean",   False, False),  # attn-only
    (7, "5) + All SWA's",                 (1.0, 3.64), cfg.n_heads, kv_gqa,        "all",     "mean",   False, False),
    (8, "5) + SWA's + Full Attn",         (1.0, 3.64), cfg.n_heads, kv_gqa,        "default", "mean",   False, False),
    (9, "8) + Cross-layer KV sharing",    (1.0, 5.23), cfg.n_heads, kv_gqa,        "default", "mean",   False, True),
    (10,"6) + Same KV compression",       (1.0, 0.0), cfg.n_heads, kv_gqa,        "none",    "mean",   False, False),
    (11,"9) Replace Mean by Concat",      (1.0, 5.23), cfg.n_heads, kv_gqa,        "default", "concat", False, True),
    (12,"1) + Meta Tokens",               (0.0, 1.0), cfg.n_heads, cfg.n_kv_heads, "none",    "mean",   True,  False),
    (13,"9) + Meta Tokens",               (1.0, 5.23), cfg.n_heads, kv_gqa,        "default", "mean",   True,  True),
]
results = []
for (eid, desc, ratio, n_heads, n_kv_heads, swa_mode, fusion, meta, kv_share) in exps:
    print(f"=== Exp {eid}: {desc} ===")
    # Same seed for every variant: identical init RNG stream and data order, so differences come
    # from the configuration rather than from randomness.
    torch.manual_seed(getattr(cfg, "seed", 0))
    model, attn_dim, mamba_dim = build_model_variant(ratio[0], ratio[1], n_heads, n_kv_heads, swa_mode, fusion, meta, kv_share)
    out = train_one(model, train_dl, val_dl, attn_dim, n_kv_heads, n_heads)
    attn_share = attn_dim / max(1, (attn_dim+mamba_dim))
    results.append({
        "ID": eid,
        "Configuration": desc,
        "ParamRatio(Attn:Mamba)": f"{ratio[0]}:{ratio[1]}",
        "AttnShare(%)": round(100*attn_share, 2),
        "PPL(↓)": round(out["ppl"], 3),
        "Acc(↑)": round(out["acc"], 4),
        "Throughput(tokens/s)↑": int(out["tps"]),
        "Cache(MB)↓": round(out["cache_mb"], 2),
        "Steps": out["steps"],
        "Fusion": fusion,
        "SWA": swa_mode,
        "Meta": meta,
        "KVshare": kv_share,
    })
    # Free this variant before the next one is built, so two models never coexist on the GPU.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

df = pd.DataFrame(results).sort_values("ID").reset_index(drop=True)
df



## 6) 결과 요약
- 표 형태로 정리하고, 구성 요소별 차이를 빠르게 비교합니다.


In [None]:
from IPython.display import display
import pandas as pd

# Widen text cells so long configuration descriptions are not truncated, then render the table.
pd.set_option("display.max_colwidth", 120)
display(df)
