
# Hymba (Hybrid Attention + Mamba) — 최신 토크나이저 · 자동 데이터 다운로드 · **ORPO(SFT+DPO 계열)** 동시학습
이 노트북은 기존 구현을 다음과 같이 보강합니다.

1) **최신 토크나이저**: 🤗 `tokenizers`의 **Unigram**(SentencePiece 계열)으로 **직접 학습**하여 사용  
2) **데이터 자동 다운로드**: 경로가 없으면 🤗 `datasets`에서 자동 로드(토치텍스트 미사용)  
3) **ORPO**: SFT 손실 + **상대 로그 오즈**(log-odds) 항을 결합한 **단일 단계** 선호 최적화 구현 및 해설  



## 0) 의존성 설치 (파이썬 코드로 실행)
- `tokenizers`의 Unigram 모델, `datasets`, `transformers`, `mamba-ssm`(가능 시) 등 설치


In [None]:
import sys, subprocess, pkg_resources

def pip_install(pkgs):
    """Install each requirement in *pkgs* unless the installed version already satisfies it.

    Each entry is a pip requirement string such as ``"torch>=2.2.0"``. The full
    specifier (not only the project name) is checked, so an installed but too-old
    package is upgraded instead of silently kept.

    A failed install is reported on stderr and the remaining packages are still
    processed: some entries (``mamba-ssm``) are optional for this notebook, and a
    genuinely required package will surface as an ``ImportError`` at its import.
    """
    for req in pkgs:
        try:
            # get_distribution accepts a full requirement string and raises
            # VersionConflict when the installed version does not satisfy it.
            pkg_resources.get_distribution(req)
            continue
        except (pkg_resources.DistributionNotFound, pkg_resources.VersionConflict):
            pass
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", req])
        except subprocess.CalledProcessError as err:
            print(f"[pip_install] failed to install {req!r}: {err}", file=sys.stderr)

# Install (or upgrade) the notebook's dependencies.
# NOTE(review): mamba-ssm is listed as "if possible" in the setup notes; its
# install may fail on machines without a suitable build environment — confirm.
pip_install([
    "torch>=2.2.0",
    "datasets>=2.18.0",
    "transformers>=4.41.0",
    "tokenizers>=0.15.2",
    "accelerate>=0.32.0",
    "mamba-ssm>=2.2.2",
    "matplotlib>=3.8.0"
])

import os, math, time, dataclasses, typing as T
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset
from datasets import load_dataset, Dataset as HFDataset
from transformers import get_cosine_schedule_with_warmup
# Use the CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)



## 1) 설정
- 토크나이저는 **unigram**으로 기본 설정  
- 데이터셋 지정이 **로컬 경로**면 해당 파일/폴더에서 읽고, 없으면 **HuggingFace에서 자동 다운로드**


In [None]:
@dataclasses.dataclass
class CFG:
    """Notebook-wide settings for data, tokenizer, model and training."""
    # Data
    sft_dataset: str = "karpathy/tiny_shakespeare"  # local path or HF Hub dataset name
    pref_dataset: str = "trl-lib/hh-rlhf-helpful-base"  # for ORPO: (prompt, chosen, rejected)
    seq_len: int = 512  # tokens per SFT training chunk
    batch_size: int = 16
    epochs: int = 1
    max_steps: int = 150  # cap on training-loop iterations (micro-batches)
    lr: float = 3e-4
    weight_decay: float = 0.1
    warmup_steps: int = 40  # warmup length of the cosine LR schedule
    grad_accum: int = 1  # optimizer step every `grad_accum` micro-batches
    grad_clip: float = 1.0  # max gradient norm; <= 0 disables clipping
    amp: bool = True  # mixed precision (autocast + GradScaler)

    # Tokenizer
    tokenizer_type: str = "unigram"   # ["unigram", "gpt2_fallback"]; NOTE(review): not read by the cells in this notebook
    vocab_size: int = 32000  # target vocabulary size for tokenizer training
    bos_token: str = "<|bos|>"
    eos_token: str = "<|eos|>"
    pad_token: str = "<|pad|>"

    # Model
    d_model: int = 512
    n_layers: int = 12
    n_heads: int = 8  # query heads
    n_kv_heads: int = 4  # key/value heads (grouped-query attention)
    attn_ratio: float = 0.5  # fraction of d_model given to the attention path; the rest goes to Mamba
    swa_window: int = 256  # sliding-window attention length
    swa_layers: T.Optional[T.List[int]] = None  # None: SWA everywhere except first/middle/last layers; NOTE(review): build_model does not read this field
    num_meta_tokens: int = 4  # learned tokens prepended to multi-token inputs
    kv_share: bool = True  # share KV caches between adjacent layers
    ffn_mult: float = 4.0  # FFN hidden-width multiplier
    dropout: float = 0.0
    max_position: int = 65536  # RoPE lookup-table length
    return_attn: bool = False

# Default configuration read by every cell below.
cfg = CFG()
cfg  # display in the notebook



## 2) 데이터 준비 — 로컬이 없으면 HF에서 자동 로드
- SFT: 순수 텍스트(예: Tiny Shakespeare) → 토큰 청크
- ORPO: **선호 데이터**(prompt, chosen, rejected) → 토큰 쌍


In [None]:
def read_local_text(path: str) -> T.Optional[str]:
    """Read plain text from a local file, or from every ``.txt`` file under a directory.

    Args:
        path: a file path, or a directory searched recursively for ``*.txt`` files.

    Returns:
        The file's text; for a directory, all ``.txt`` files joined with a blank
        line in sorted path order; ``None`` when *path* does not exist or the
        directory contains no ``.txt`` file. Undecodable bytes are dropped.
    """
    if not os.path.exists(path):
        return None
    if not os.path.isdir(path):
        with open(path, "r", encoding="utf-8", errors="ignore") as fh:
            return fh.read()
    # os.walk order depends on the filesystem; sort so the corpus (and the
    # tokenizer trained on it) is reproducible across machines.
    files = sorted(
        os.path.join(root, name)
        for root, _, names in os.walk(path)
        for name in names
        if name.endswith(".txt")
    )
    texts = []
    for file_path in files:
        with open(file_path, "r", encoding="utf-8", errors="ignore") as fh:
            texts.append(fh.read())
    return "\n\n".join(texts) if texts else None

def get_sft_corpus(spec: str) -> str:
    """Return the raw SFT training text for *spec*.

    *spec* is first tried as a local file or directory (``read_local_text``);
    otherwise it is treated as a Hugging Face Hub dataset name and downloaded
    with ``load_dataset``.

    For a downloaded dataset the ``train`` split is used (or the first split
    when there is no ``train``). The first text-like column found among
    ``text/content/document/raw/data`` is joined with blank lines. If no such
    column exists, up to 1000 whole rows are stringified instead.
    """
    text = read_local_text(spec)
    if text is not None:
        return text
    ds = load_dataset(spec)
    split = ds["train"] if "train" in ds else ds[next(iter(ds))]
    for key in ("text", "content", "document", "raw", "data"):
        if key in split.column_names:
            # Skip missing values so a single null row cannot break str.join.
            return "\n\n".join(str(v) for v in split[key] if v is not None)
    # Last resort: stringify whole rows (each row is a column -> value dict).
    n_rows = min(1000, len(split))
    return "\n\n".join(str(row) for row in split.select(range(n_rows)))

def _messages_to_text(messages, with_roles: bool) -> str:
    """Flatten a list of chat-message dicts (``role``/``content``) into newline-separated text."""
    if with_roles:
        return "\n".join([f"{m.get('role','user').capitalize()}: {m.get('content','')}".strip() for m in messages])
    return "\n".join([m.get("content", "") for m in messages])


def load_pref_dataset(name: str, split: str = "train"):
    """Load a preference dataset as ``{"prompt": [...], "chosen": [...], "rejected": [...]}``.

    Column names are matched against common aliases (e.g. ``question``/``input``
    for the prompt, ``chosen_response`` for the chosen answer). Conversational
    fields (lists of ``{"role", "content"}`` dicts) are flattened to text: the
    prompt keeps ``Role: content`` lines, while responses keep only the content.
    Rows missing any of the three fields are skipped.

    Raises:
        ValueError: if no row yields a complete (prompt, chosen, rejected)
            triple. Without this check ORPO would silently train for zero steps.
    """
    ds = load_dataset(name, split=split)

    def pick(row, keys):
        # First alias present with a non-None value; "" when none matches.
        for k in keys:
            if k in row and row[k] is not None:
                return row[k]
        return ""

    out = {"prompt": [], "chosen": [], "rejected": []}
    for row in ds:
        prompt = pick(row, ["prompt", "question", "input", "instruction"])
        chosen = pick(row, ["chosen", "chosen_response", "chosen_text"])
        rejected = pick(row, ["rejected", "rejected_response", "rejected_text"])
        if isinstance(prompt, list):
            prompt = _messages_to_text(prompt, with_roles=True)
        if isinstance(chosen, list):
            chosen = _messages_to_text(chosen, with_roles=False)
        if isinstance(rejected, list):
            rejected = _messages_to_text(rejected, with_roles=False)
        if prompt and chosen and rejected:
            out["prompt"].append(prompt)
            out["chosen"].append(chosen)
            out["rejected"].append(rejected)
    if not out["prompt"]:
        raise ValueError(
            f"{name!r} ({split}) yielded no (prompt, chosen, rejected) triples; "
            f"available columns: {ds.column_names}"
        )
    return out



## 3) 최신 토크나이저 — **Unigram(SentencePiece 계열)** 직접 학습
- 데이터에서 직접 학습 → 도메인 적합 & 최신 서브워드 품질  
- 특수 토큰 **BOS/EOS/PAD**를 어휘에 포함하고, 후처리 템플릿으로 모든 시퀀스에 **BOS/EOS**를 자동 부착 (PAD는 배치 패딩에만 사용)


In [None]:
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.normalizers import NFKC, Lowercase, Sequence as NormSeq
from tokenizers.processors import TemplateProcessing

def train_unigram_tokenizer(corpus_text:str, vocab_size:int, bos:str, eos:str, pad:str, unk:str="<|unk|>"):
    """Train a Unigram (SentencePiece-family) tokenizer on *corpus_text* and wrap it.

    Pipeline: NFKC + lowercase normalization, ``Whitespace`` pre-tokenization,
    then a Unigram subword model. ``pad``, ``bos``, ``eos`` and ``unk`` are
    registered as special tokens. A ``TemplateProcessing`` post-processor wraps
    every single-sequence encoding as ``bos ... eos``.

    Args:
        corpus_text: training text.
        vocab_size: target vocabulary size passed to the trainer.
        bos, eos, pad: special-token strings.
        unk: token used for pieces the model cannot represent. Without an unk
            token the Unigram model cannot encode characters absent from the
            training corpus. That happens when text from another source (e.g.
            the preference data) is encoded.

    Returns:
        A ``ModernTok`` wrapper with ``encode(str) -> list[int]``, ``decode(ids)``,
        ``vocab_size`` and ``bos/eos/pad/unk_token_id``. Lowercasing makes
        ``decode`` lossy: the original casing is not restored.
    """
    def batch_iter(text, bs=1000000):
        # Feed the corpus in ~1M-character slices instead of one huge string.
        for i in range(0, len(text), bs):
            yield text[i:i+bs]

    tok = Tokenizer(Unigram())
    tok.normalizer = NormSeq([NFKC(), Lowercase()])
    tok.pre_tokenizer = Whitespace()
    trainer = UnigramTrainer(vocab_size=vocab_size, special_tokens=[pad, bos, eos, unk], unk_token=unk)
    tok.train_from_iterator(batch_iter(corpus_text), trainer=trainer)
    pad_id = tok.token_to_id(pad)
    bos_id = tok.token_to_id(bos)
    eos_id = tok.token_to_id(eos)
    unk_id = tok.token_to_id(unk)
    tok.post_processor = TemplateProcessing(
        single=f"{bos} $A {eos}",
        pair=f"{bos} $A {eos} $B:1 {eos}:1",
        special_tokens=[(bos, bos_id), (eos, eos_id)]
    )

    class ModernTok:
        """Minimal wrapper exposing the attributes the rest of the notebook uses."""
        def __init__(self, tk, bos_id, eos_id, pad_id, unk_id=None):
            self.tk = tk
            self.bos_token_id = bos_id
            self.eos_token_id = eos_id
            self.pad_token_id = pad_id
            self.unk_token_id = unk_id

        def encode(self, s):
            return self.tk.encode(s).ids

        def decode(self, ids):
            return self.tk.decode(ids)

        @property
        def vocab_size(self):
            return self.tk.get_vocab_size()

    return ModernTok(tok, bos_id, eos_id, pad_id, unk_id)

# 코퍼스 수집 → 토크나이저 학습
# Collect the SFT corpus (local path or HF download), then train the tokenizer on it.
corpus = get_sft_corpus(cfg.sft_dataset)
modern_tok = train_unigram_tokenizer(corpus, cfg.vocab_size, cfg.bos_token, cfg.eos_token, cfg.pad_token)
# Globals read by later cells: VOCAB sizes the model, EOS_ID labels chunk ends.
VOCAB = modern_tok.vocab_size; EOS_ID = modern_tok.eos_token_id
print("vocab:", VOCAB, "eos:", EOS_ID)



## 4) SFT 데이터 → (X, Y) 청크
- 입력 `X`는 토큰 시퀀스, 타깃 `Y`는 한 칸 민(shifted) **다음 토큰** 시퀀스  
- 전체 토큰열을 `seq_len` 단위로 자르며(청크를 채우지 못한 꼬리 토큰은 버림), 각 청크의 마지막 타깃 위치는 EOS로 채움


In [None]:
def chunk_sft_tokens(text:str, seq_len:int, tok=None, eos_id=None):
    """Tokenize *text* and cut it into fixed-length next-token-prediction pairs.

    Args:
        text: raw training text.
        seq_len: tokens per chunk; trailing tokens that do not fill a chunk are dropped.
        tok: tokenizer with ``encode(str) -> list[int]``. Defaults to the
            notebook-wide ``modern_tok``, looked up at call time.
        eos_id: label for the last position of every chunk. Defaults to
            ``tok.eos_token_id``.

    Returns:
        ``TensorDataset(X, Y)`` of int64 tensors shaped ``(n_chunks, seq_len)``,
        with ``Y[:, t] == X[:, t + 1]`` and ``Y[:, -1] == eos_id``.

    Raises:
        ValueError: if ``seq_len`` is not positive.
    """
    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")
    if tok is None:
        tok = modern_tok
    if eos_id is None:
        eos_id = tok.eos_token_id
    ids = tok.encode(text)  # use the tokenizer that was passed in
    n_full = (len(ids) // seq_len) * seq_len
    x = np.asarray(ids[:n_full], dtype=np.int64).reshape(-1, seq_len)
    y = np.empty_like(x)
    y[:, :-1] = x[:, 1:]
    y[:, -1] = eos_id
    return TensorDataset(torch.from_numpy(x), torch.from_numpy(y))

# Shuffled SFT loader; drop_last keeps every batch at exactly cfg.batch_size chunks.
sft_ds = chunk_sft_tokens(corpus, cfg.seq_len, modern_tok, EOS_ID)
sft_dl = DataLoader(sft_ds, batch_size=cfg.batch_size, shuffle=True, drop_last=True)
len(sft_ds), cfg.seq_len  # number of chunks, tokens per chunk



## 5) 모델: Hymba (하이브리드 블록: Global/SWA 어텐션 + Mamba) + 메타 토큰 + 인접 KV 공유


In [None]:
# --- 핵심 모듈 (RoPE, RMSNorm, FFN, GQA(SWA/Global), Mamba, Block, Model) ---
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) with precomputed cos/sin tables.

    Each interleaved channel pair ``(x[..., 2i], x[..., 2i+1])`` at position
    ``p`` is rotated by angle ``p * base ** (-2i / dim)``.

    Args:
        dim: channels to rotate per head (the last dimension of the input).
        max_position: number of positions in the lookup tables.
        base: frequency base.
    """

    def __init__(self, dim, max_position=131072, base=10000.0):
        super().__init__()
        inv = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_position, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv)  # (max_position, dim/2)
        self.register_buffer("cos_cached", torch.cos(freqs), persistent=False)
        self.register_buffer("sin_cached", torch.sin(freqs), persistent=False)

    def forward(self, x, positions):
        """Rotate *x* of shape ``(..., T, dim)`` using the ``T`` integer *positions*.

        The result has the same shape and dtype as *x*.
        """
        # Index as (T, dim/2) so the table lines up with the trailing (T, dim/2)
        # of x; any leading batch/head dimensions broadcast automatically.
        cos = self.cos_cached[positions].to(x.dtype)
        sin = self.sin_cached[positions].to(x.dtype)
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2)

class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dim with a learned per-channel gain ``w``.

    No mean-centering and no bias; ``eps`` guards the reciprocal square root.
    """

    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.w = nn.Parameter(torch.ones(d))

    def forward(self, x):
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return x * inv_rms * self.w

class SwiGLU_FFN(nn.Module):
    """SwiGLU feed-forward block: ``w3(dropout(silu(w1 x) * w2 x))``.

    Args:
        d: input/output width.
        mult: hidden width is ``int(d * mult)``.
        dropout: dropout applied to the gated hidden activations.
    """

    def __init__(self, d, mult=4.0, dropout=0.0):
        super().__init__()
        hidden = int(d * mult)
        self.w1 = nn.Linear(d, hidden, bias=False)  # passed through SiLU (gate)
        self.w2 = nn.Linear(d, hidden, bias=False)  # multiplied by the gate
        self.w3 = nn.Linear(hidden, d, bias=False)  # projects back to d
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        gated = F.silu(self.w1(x)) * self.w2(x)
        return self.w3(self.drop(gated))

def band_mask(T, w, device, dtype):
    """Additive ``(T, T)`` causal sliding-window mask.

    Entry ``[i, j]`` is 0 when key ``j`` is at or before query ``i`` and within
    the last ``w`` positions (``i - w < j <= i``), and ``-inf`` otherwise.
    """
    rows = torch.arange(T, device=device)[:, None]
    cols = torch.arange(T, device=device)[None, :]
    visible = (cols <= rows) & (cols > rows - w)
    blocked = torch.full((T, T), float("-inf"), device=device, dtype=dtype)
    return blocked.masked_fill(visible, 0.0)

class GQA(nn.Module):
    """Grouped-query attention with optional sliding window, RoPE and an append-only KV cache.

    ``H`` query heads share ``KV`` key/value heads (each KV head serves
    ``H // KV`` query heads). Input and output are ``(B, T, d_model)``.

    Args:
        d_model: input/output width; must be divisible by ``H``.
        H: number of query heads.
        KV: number of key/value heads; must divide ``H``.
        rope: optional module called as ``rope(t, positions)`` on queries and keys.
        use_swa: if True, a query sees only the last ``swa_window`` positions
            (itself included), plus any keys flagged by ``global_mask``.
        swa_window: sliding-window length used when ``use_swa`` is True.
        dropout: attention dropout probability (training mode only).
        ret_attn: kept for API compatibility; attention weights are not returned.

    Raises:
        ValueError: if ``d_model % H`` or ``H % KV`` is non-zero.
    """

    def __init__(self, d_model, H, KV, rope=None, use_swa=False, swa_window=4096, dropout=0.0, ret_attn=False):
        super().__init__()
        if d_model % H:
            raise ValueError(f"d_model={d_model} must be divisible by H={H}")
        if H % KV:
            raise ValueError(f"H={H} must be divisible by KV={KV}")
        self.H = H; self.KV = KV; self.rep = H // KV; self.Dh = d_model // H
        self.q = nn.Linear(d_model, H * self.Dh, bias=False)
        self.k = nn.Linear(d_model, KV * self.Dh, bias=False)
        self.v = nn.Linear(d_model, KV * self.Dh, bias=False)
        self.o = nn.Linear(H * self.Dh, d_model, bias=False)
        self.drop = nn.Dropout(dropout)
        self.rope = rope
        self.use_swa = use_swa
        self.swa_window = swa_window
        self.ret = ret_attn

    def _attention_mask(self, T, Tc, global_mask, device):
        """Boolean mask (True = may attend), shaped ``(T, Tc)`` or ``(B, 1, T, Tc)``.

        The ``T`` queries are the newest ``T`` of the ``Tc`` key positions.
        Keys flagged by *global_mask* (meta tokens) stay visible to every later
        query even outside the sliding window. Causality is never relaxed, so
        no query can see a future position.
        """
        q_pos = torch.arange(Tc - T, Tc, device=device)[:, None]
        k_pos = torch.arange(Tc, device=device)[None, :]
        causal = k_pos <= q_pos
        allowed = causal & (k_pos > q_pos - self.swa_window) if self.use_swa else causal
        if global_mask is None:
            return allowed
        g = global_mask.to(device=device, dtype=torch.bool)
        n = min(g.size(1), Tc)
        # global_mask describes the current chunk, i.e. the newest key positions.
        is_global = torch.zeros(g.size(0), Tc, dtype=torch.bool, device=device)
        is_global[:, Tc - n:] = g[:, g.size(1) - n:]
        allowed = allowed[None] | (causal[None] & is_global[:, None, :])
        return allowed[:, None]  # broadcast over heads

    def forward(self, x, kv_cache=None, global_mask=None, need_weights=False):
        """Attend from the current chunk over cached + current keys.

        Args:
            x: ``(B, T, d_model)`` current chunk.
            kv_cache: optional ``(k, v)`` from a previous call, each
                ``(B, KV, T_past, Dh)`` and stored before RoPE; ``None`` or an
                empty tensor means no history.
            global_mask: optional bool ``(B, T)`` marking current-chunk positions
                whose keys every later query may see regardless of the window.
            need_weights: accepted for API compatibility; weights are not computed.

        Returns:
            ``(out, (k, v), None)``: output ``(B, T, d_model)`` and the detached,
            pre-RoPE cache covering all ``T_past + T`` positions.
        """
        B, T, _ = x.shape
        q = self.q(x).view(B, T, self.H, self.Dh).transpose(1, 2)
        k = self.k(x).view(B, T, self.KV, self.Dh).transpose(1, 2)
        v = self.v(x).view(B, T, self.KV, self.Dh).transpose(1, 2)
        if kv_cache is not None and kv_cache[0] is not None and kv_cache[0].numel() > 0:
            past_k, past_v = kv_cache
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        Tc = k.size(2)
        k_exp = k.repeat_interleave(self.rep, dim=1)
        v_exp = v.repeat_interleave(self.rep, dim=1)
        if self.rope is not None:
            # Absolute positions: the current queries are the last T of Tc.
            q = self.rope(q, torch.arange(Tc - T, Tc, device=x.device))
            k_exp = self.rope(k_exp, torch.arange(Tc, device=x.device))
        # RoPE tables may promote dtype; SDPA needs q, k and v in one dtype.
        q = q.to(v_exp.dtype)
        k_exp = k_exp.to(v_exp.dtype)
        mask = self._attention_mask(T, Tc, global_mask, x.device)
        out = F.scaled_dot_product_attention(
            q, k_exp, v_exp, attn_mask=mask, is_causal=False,
            dropout_p=float(self.drop.p) if self.training else 0.0,
        )
        out = out.transpose(1, 2).reshape(B, T, self.H * self.Dh)
        return self.o(out), (k.detach(), v.detach()), None



try:
    from mamba_ssm import Mamba as _Mamba
except (ImportError, OSError):
    # mamba-ssm is an optional dependency of this notebook; MambaLayer degrades below.
    import warnings
    warnings.warn("mamba_ssm is unavailable; MambaLayer falls back to nn.Identity")
    _Mamba = None


class MambaLayer(nn.Module):
    """Mamba mixer over ``(B, T, d_model)``; a pass-through when ``mamba_ssm`` is unavailable.

    Args:
        d_model, d_state, d_conv, expand: forwarded to ``mamba_ssm.Mamba``.
    """

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        if _Mamba is not None:
            self.inner = _Mamba(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand)
        else:
            # Keeps the hybrid model buildable/trainable without the SSM kernels.
            self.inner = nn.Identity()

    def forward(self, x):
        return self.inner(x)

class HymbaBlock(nn.Module):
    """One hybrid layer: parallel attention and Mamba paths on a shared normed input, then a SwiGLU FFN.

    The RMS-normed input is projected to ``attn_dim`` channels for attention and
    ``mamba_dim`` channels for Mamba; a path with width 0 is omitted. The path
    outputs are concatenated, projected back to ``d_model`` and added
    residually. A residual FFN follows. The same ``pre`` norm module is applied
    before both the mixer and the FFN.
    """

    def __init__(self, d_model, H, KV, attn_dim, use_swa, swa_window, rope, mamba_dim, ffn_mult=4.0, dropout=0.0, ret_attn=False):
        super().__init__()
        has_attn = attn_dim > 0
        has_mamba = mamba_dim > 0
        # Sub-modules are created in a fixed order so random initialization is reproducible.
        self.pre = RMSNorm(d_model)
        self.to_a = nn.Linear(d_model, attn_dim, bias=False) if has_attn else None
        self.to_m = nn.Linear(d_model, mamba_dim, bias=False) if has_mamba else None
        self.attn = GQA(attn_dim, H, KV, rope, use_swa, swa_window, dropout, ret_attn) if has_attn else None
        self.mamba = MambaLayer(mamba_dim) if has_mamba else None
        self.from_paths = nn.Linear(attn_dim + mamba_dim, d_model, bias=False)
        self.ffn = SwiGLU_FFN(d_model, mult=ffn_mult, dropout=dropout)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, kv_cache=None, global_mask=None, need_attn=False):
        """Return ``(x_out, kv_cache_out, None)``; ``kv_cache_out`` is None without an attention path."""
        normed = self.pre(x)
        paths = []
        cache_out = None
        if self.attn is not None:
            attn_out, cache_out, _ = self.attn(
                self.to_a(normed), kv_cache=kv_cache, global_mask=global_mask, need_weights=need_attn
            )
            paths.append(attn_out)
        if self.mamba is not None:
            paths.append(self.mamba(self.to_m(normed)))
        mixed = torch.cat(paths, -1) if len(paths) > 1 else paths[0]
        x = x + self.drop(self.from_paths(mixed))
        x = x + self.drop(self.ffn(self.pre(x)))
        return x, cache_out, None

class HymbaForCausalLM(nn.Module):
    """Causal LM stacking HymbaBlocks, with optional learned meta tokens and KV-cache sharing.

    Args:
        vocab_size: vocabulary size for the embedding and output head.
        d_model, n_layers, n_heads, n_kv_heads: model width, depth and head counts.
        attn_ratio: fraction of ``d_model`` used by the attention path; the rest goes to Mamba.
        swa_layers: layer indices using sliding-window attention; ``None`` means
            every layer except the first, middle and last.
        swa_window: sliding-window length.
        num_meta_tokens: learned vectors prepended to multi-token inputs and
            flagged to the attention layers through ``global_mask``.
        kv_share: if True, layer ``2i+1`` reads the KV cache owned by layer ``2i``;
            if False every layer owns its cache.
        max_position, ffn_mult, dropout, return_attn: forwarded to the sub-modules.
    """

    def __init__(self, vocab_size, d_model=512, n_layers=12, n_heads=8, n_kv_heads=4, attn_ratio=0.5,
                 swa_layers=None, swa_window=256, num_meta_tokens=4, kv_share=True,
                 max_position=65536, ffn_mult=4.0, dropout=0.0, return_attn=False):
        super().__init__()
        self.vocab_size = vocab_size; self.d_model = d_model; self.n_layers = n_layers
        self.num_meta_tokens = num_meta_tokens; self.kv_share = kv_share; self.ret = return_attn
        self.tok = nn.Embedding(vocab_size, d_model)
        attn_dim = int(d_model * attn_ratio); mamba_dim = d_model - attn_dim
        self.rope = RotaryEmbedding(dim=(attn_dim // n_heads), max_position=max_position)
        if swa_layers is None:
            mid = n_layers // 2
            swa_layers = [i for i in range(n_layers) if i not in (0, mid, n_layers - 1)]
        self.swa = set(swa_layers)
        self.layers = nn.ModuleList([
            HymbaBlock(d_model, n_heads, n_kv_heads, attn_dim, (i in self.swa), swa_window, self.rope, mamba_dim, ffn_mult, dropout, return_attn)
            for i in range(n_layers)
        ])
        self.norm = RMSNorm(d_model); self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.meta = nn.Parameter(torch.randn(1, num_meta_tokens, d_model) * 0.02) if num_meta_tokens > 0 else None
        # owner[i] = index of the layer whose KV cache layer i reads.
        # Pairs (0->1), (2->3), ... share only when kv_share is enabled.
        self.owner = list(range(n_layers))
        if kv_share:
            for a in range(0, n_layers - 1, 2):
                self.owner[a + 1] = a

    def forward(self, idx, targets=None, kv_caches=None, return_attn=False):
        """Run the LM.

        Args:
            idx: ``(B, T)`` token ids.
            targets: optional ``(B, T)`` next-token labels; the loss is the mean
                cross-entropy over the logits of the real (non-meta) positions.
            kv_caches: optional per-layer cache list (``None`` entries for empty
                caches); when given, the updated owner caches are returned.
            return_attn: forwarded as ``need_attn``; no weights are returned.

        Returns:
            dict with ``logits`` (meta positions first when meta tokens were
            prepended), ``loss`` (or None) and ``kv_caches`` (or None).
        """
        B, T = idx.shape
        x = self.tok(idx); meta_add = 0
        # Meta tokens are prepended only to multi-token inputs, not single-token steps.
        if self.meta is not None and T > 1:
            x = torch.cat([self.meta.expand(B, -1, -1), x], 1); meta_add = self.num_meta_tokens
        gmask = None
        if meta_add > 0:
            gmask = torch.zeros((B, x.size(1)), dtype=torch.bool, device=x.device)
            gmask[:, :meta_add] = True
        new = [None] * self.n_layers if kv_caches is not None else None
        h = x
        for li, layer in enumerate(self.layers):
            own = self.owner[li]
            # NOTE(review): a non-owner layer reads its owner's cache from the previous
            # call and appends its own current K/V; only owner caches are kept.
            kv = kv_caches[own] if kv_caches is not None else None
            h, kv_out, _ = layer(h, kv_cache=kv, global_mask=gmask, need_attn=(self.ret or return_attn))
            if new is not None and li == own:
                new[own] = kv_out
        h = self.norm(h); logits = self.head(h); loss = None
        if targets is not None:
            lf = logits[:, meta_add:meta_add + targets.size(1), :]
            loss = F.cross_entropy(lf.reshape(-1, lf.size(-1)), targets.reshape(-1))
        return {"logits": logits, "loss": loss, "kv_caches": new}

    def estimate_kv_cache_bytes(self, seq_len:int, dtype=torch.float16):
        """Analytic K+V cache size in bytes for *seq_len* tokens at batch 1, one cache per owner layer."""
        attn = self.layers[0].attn
        if attn is None:
            return 0
        elem_bytes = torch.finfo(dtype).bits // 8
        per_layer = 2 * attn.KV * seq_len * attn.Dh * elem_bytes  # keys and values
        return per_layer * len(set(self.owner))



## 6) 학습 유틸
- **Cosine 스케줄러 + Warmup**, **AMP**, **Gradient Clipping**  
- 간단 **Perplexity** 평가


In [None]:
def build_model(cfg:CFG, *, swa_layers=None, use_meta=True, kv_share=True):
    """Create a HymbaForCausalLM sized by *cfg* (vocabulary from ``VOCAB``) and move it to ``device``.

    Args:
        cfg: hyper-parameters for width, depth, heads, SWA window, FFN, dropout, etc.
        swa_layers: sliding-window layer indices; ``None`` uses the model's default.
        use_meta: when False the model gets no meta tokens.
        kv_share: forwarded to the model.
    """
    n_meta = cfg.num_meta_tokens if use_meta else 0
    model_kwargs = dict(
        vocab_size=VOCAB,
        d_model=cfg.d_model,
        n_layers=cfg.n_layers,
        n_heads=cfg.n_heads,
        n_kv_heads=cfg.n_kv_heads,
        attn_ratio=cfg.attn_ratio,
        swa_layers=swa_layers,
        swa_window=cfg.swa_window,
        num_meta_tokens=n_meta,
        kv_share=kv_share,
        max_position=cfg.max_position,
        ffn_mult=cfg.ffn_mult,
        dropout=cfg.dropout,
        return_attn=cfg.return_attn,
    )
    return HymbaForCausalLM(**model_kwargs).to(device)

def evaluate_ppl(model:nn.Module, data_loader:DataLoader, max_batches:int=20, *, amp=None):
    """Token-weighted perplexity of *model* over the first *max_batches* batches.

    Each batch ``(xb, yb)`` is scored with ``model(xb, targets=yb)["loss"]`` (a
    mean cross-entropy). Batches are weighted by ``xb.numel()``.

    Args:
        model: LM whose forward accepts ``targets=`` and returns a dict with ``"loss"``.
        data_loader: iterable of ``(input_ids, target_ids)`` pairs.
        max_batches: number of batches to evaluate.
        amp: enable CUDA autocast; defaults to the notebook's ``cfg.amp``.
            It has no effect on CPU.

    Returns:
        ``exp(mean NLL)``, or ``nan`` when no batch was evaluated. The model's
        train/eval mode is restored afterwards.
    """
    if amp is None:
        amp = cfg.amp
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device("cpu")
    use_amp = bool(amp) and dev.type == "cuda"
    was_training = model.training
    model.eval()
    nll, n_tok = 0.0, 0
    try:
        with torch.no_grad(), torch.autocast(device_type="cuda" if dev.type == "cuda" else "cpu", enabled=use_amp):
            for i, (xb, yb) in enumerate(data_loader):
                if i >= max_batches:
                    break
                xb, yb = xb.to(dev), yb.to(dev)
                loss = model(xb, targets=yb)["loss"]
                n = xb.numel()
                nll += loss.item() * n
                n_tok += n
    finally:
        # Restore the caller's mode instead of unconditionally switching to train().
        model.train(was_training)
    if n_tok == 0:
        return float("nan")
    return math.exp(nll / n_tok)

def train_sft(model:nn.Module, data_loader:DataLoader, cfg:CFG):
    """Next-token SFT with AdamW, a warmup + cosine schedule, optional AMP and gradient clipping.

    At most ``cfg.max_steps`` micro-batches are run, over up to ``cfg.epochs``
    passes of *data_loader*. An optimizer step is taken every
    ``cfg.grad_accum`` micro-batches, and any final partial window is flushed.

    Returns:
        dict with ``steps`` (micro-batches run), ``tokens`` seen and ``time_sec``.
    """
    dev = next(model.parameters()).device
    use_amp = cfg.amp and dev.type == "cuda"
    accum = max(1, cfg.grad_accum)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay, betas=(0.9, 0.95), eps=1e-8)
    total_steps = cfg.max_steps
    # The scheduler advances once per optimizer step, not once per micro-batch.
    sch = get_cosine_schedule_with_warmup(
        opt, num_warmup_steps=cfg.warmup_steps, num_training_steps=max(1, math.ceil(total_steps / accum))
    )
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    def optimizer_step():
        scaler.unscale_(opt)
        if cfg.grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        scaler.step(opt); scaler.update(); opt.zero_grad(set_to_none=True); sch.step()

    step = 0; t0 = time.time(); tokens = 0
    for _ in range(cfg.epochs):
        for xb, yb in data_loader:
            xb, yb = xb.to(dev), yb.to(dev)
            with torch.autocast(device_type="cuda" if dev.type == "cuda" else "cpu", enabled=use_amp):
                loss = model(xb, targets=yb)["loss"]
            # Average over the accumulation window so the update size does not grow with grad_accum.
            scaler.scale(loss / accum).backward()
            step += 1; tokens += xb.numel()
            if step % accum == 0:
                optimizer_step()
            if step % 20 == 0:
                tps = tokens / max(1e-9, time.time() - t0)
                print(f"[SFT] step {step}/{total_steps} loss {loss.item():.3f}  ~{tps:.1f} tok/s")
            if step >= total_steps:
                break
        if step >= total_steps:
            break
    if step % accum:
        optimizer_step()  # apply gradients from a trailing partial window
    return {"steps": step, "tokens": tokens, "time_sec": time.time() - t0}



## 7) SFT 베이스라인 (Global-only, 메타 off, KV 공유 off)


In [None]:
# Baseline: global attention in every layer (empty SWA list), no meta tokens, no KV sharing.
swa_layers_none = []
baseline = build_model(cfg, swa_layers=swa_layers_none, use_meta=False, kv_share=False)
baseline_summary = train_sft(baseline, sft_dl, cfg)
# NOTE(review): perplexity is measured on sft_dl, i.e. on training chunks rather than held-out data.
baseline_ppl = evaluate_ppl(baseline, sft_dl, max_batches=20)
print("Baseline PPL:", baseline_ppl); baseline_summary



## 8) Hybrid: Global+SWA(+메타 on), 인접 **KV 공유**
- 첫/가운데/마지막은 Global, 그 외는 SWA


In [None]:
# Hybrid: first/middle/last layers use global attention, all others SWA; meta tokens and KV sharing on.
mid = cfg.n_layers//2
default_swa = [i for i in range(cfg.n_layers) if i not in (0, mid, cfg.n_layers-1)]
hybrid_meta = build_model(cfg, swa_layers=default_swa, use_meta=True, kv_share=True)
hybrid_meta_sum = train_sft(hybrid_meta, sft_dl, cfg)
hybrid_meta_ppl = evaluate_ppl(hybrid_meta, sft_dl)
print("KV bytes estimate (no-share vs share) with seq_len):")
# Untrained twin with kv_share=False, built only to compare the analytic KV-cache size.
tmp_noshare = build_model(cfg, swa_layers=default_swa, use_meta=True, kv_share=False)
bytes_noshare = tmp_noshare.estimate_kv_cache_bytes(cfg.seq_len)
bytes_share   = hybrid_meta.estimate_kv_cache_bytes(cfg.seq_len)
print("no-share:", round(bytes_noshare/2**20,2), "MB  | share:", round(bytes_share/2**20,2), "MB")
hybrid_meta_ppl, hybrid_meta_sum



## 9) **ORPO (SFT + 로그 오즈 선호항)** 구현 및 해설
- **핵심**: SFT의 NLL 손실에 **선호 대비 항**을 추가  
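- ORPO 논문은 **참조 모델 없이** `L = L_SFT + λ · L_OR`를 제안 (이 노트북 코드에서 가중치 λ는 `lam_or`)  
- 여기서 `L_OR = -log σ( log( odds_θ(y_w|x) / odds_θ(y_l|x) ) )`, `odds_θ(y|x) = P_θ(y|x) / (1 - P_θ(y|x))` (`y_w` = chosen, `y_l` = rejected)  
  - `P_θ(y|x)`로는 **길이 정규화된** 확률 `exp(응답 토큰 평균 log p_θ)`를 사용하며, TRL 구현도 같은 방식입니다. 평균 로그확률 차이 `Δ = avg log p_θ(chosen|x) - avg log p_θ(rejected|x)`만 쓰면 odds의 `log(1 - P)` 보정항이 빠진 근사가 됩니다.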


In [None]:
def mean_logprob_for_response(model, prompt_ids, resp_ids, pad_id=None):
    """Per-sample mean log-probability of *resp_ids* given *prompt_ids*.

    The prompt and response are concatenated and scored in one forward pass.
    Gradients flow through the result, so it can be used inside a training loss.

    Args:
        model: causal LM returning ``{"logits": (B, S', V)}``. If the model
            prepends extra positions (meta tokens), only the last
            ``P + R`` logits are used.
        prompt_ids: ``(B, P)`` token ids, ``P >= 1``.
        resp_ids: ``(B, R)`` token ids.
        pad_id: response positions equal to this id are excluded from the mean.
            Defaults to ``modern_tok.pad_token_id``, looked up at call time.

    Returns:
        ``(B,)`` tensor of mean log-probabilities; 0 for a response made only of padding.

    Raises:
        ValueError: if the prompt is empty (no position predicts the first response token).
    """
    if prompt_ids.size(1) == 0:
        raise ValueError("prompt_ids must contain at least one token")
    if pad_id is None:
        pad_id = modern_tok.pad_token_id
    dev = next(model.parameters()).device
    prompt_ids = prompt_ids.to(dev)
    resp_ids = resp_ids.to(dev)
    x = torch.cat([prompt_ids, resp_ids], dim=1)
    n_prompt = prompt_ids.size(1)
    logits = model(x)["logits"][:, -x.size(1):, :]  # drop leading meta-token positions
    # Response token t sits at x[:, n_prompt + t] and is predicted by logits[:, n_prompt + t - 1].
    resp_logits = logits[:, n_prompt - 1:-1, :]
    logp = F.log_softmax(resp_logits.float(), dim=-1).gather(-1, resp_ids.unsqueeze(-1)).squeeze(-1)
    mask = (resp_ids != pad_id).to(logp.dtype)
    return (logp * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

def encode_prompt_and_resp(batch_prompts, batch_resps, max_len=256, *, tok=None, resp_max_len=None):
    """Tokenize chat-formatted prompts and raw responses into fixed-size id tensors.

    Each prompt is formatted as ``User: {prompt}``, a newline, then ``Assistant:``.
    Prompts are left-padded and left-truncated to ``max_len``, so the end of
    the prompt (the ``Assistant:`` cue) sits directly before the response once
    the two tensors are concatenated. Responses are right-padded and
    right-truncated to ``resp_max_len``.

    Args:
        batch_prompts, batch_resps: iterables of strings.
        max_len: prompt length in tokens.
        tok: tokenizer with ``encode`` and ``pad_token_id``; defaults to the
            notebook-wide ``modern_tok``, looked up at call time.
        resp_max_len: response length in tokens; defaults to ``max_len // 2``.

    Returns:
        ``(P, R)``: long tensors shaped ``(B, max_len)`` and ``(B, resp_max_len)``.
        NOTE(review): no attention mask is produced, so the model still attends to
        the left-padding tokens.
    """
    if tok is None:
        tok = modern_tok
    if resp_max_len is None:
        resp_max_len = max_len // 2
    pad = tok.pad_token_id
    prompts = [tok.encode(f"User: {p}\nAssistant:") for p in batch_prompts]
    resps = [tok.encode(r) for r in batch_resps]

    P = torch.full((len(prompts), max_len), pad, dtype=torch.long)
    for i, ids in enumerate(prompts):
        keep = ids[-max_len:] if max_len > 0 else []  # keep the end of long prompts
        if keep:
            P[i, max_len - len(keep):] = torch.tensor(keep, dtype=torch.long)

    R = torch.full((len(resps), resp_max_len), pad, dtype=torch.long)
    for i, ids in enumerate(resps):
        keep = ids[:resp_max_len]
        if keep:
            R[i, :len(keep)] = torch.tensor(keep, dtype=torch.long)
    return P, R

class PrefDataset(Dataset):
    """Map-style dataset over a ``{"prompt", "chosen", "rejected"}`` dict of parallel lists.

    Items are ``(prompt, chosen, rejected)`` string tuples. ``max_prompt_len`` is
    only stored; tokenization happens in the collate function.
    """

    def __init__(self, pref_dict, max_prompt_len=256):
        self.d = pref_dict
        self.max_prompt_len = max_prompt_len

    def __len__(self):
        return len(self.d["prompt"])

    def __getitem__(self, i):
        fields = ("prompt", "chosen", "rejected")
        return tuple(self.d[name][i] for name in fields)

def collate_pref(batch, max_prompt_len=256):
    """Turn a list of ``(prompt, chosen, rejected)`` triples into ``(P, C, R)`` id tensors.

    The prompts are encoded together with each response list. The prompt tensor
    from the rejected-response call is discarded.
    """
    prompts, chosens, rejecteds = map(list, zip(*batch))
    prompt_ids, chosen_ids = encode_prompt_and_resp(prompts, chosens, max_len=max_prompt_len)
    rejected_ids = encode_prompt_and_resp(prompts, rejecteds, max_len=max_prompt_len)[1]
    return prompt_ids, chosen_ids, rejected_ids

# First 2000 training rows of the preference data; collate_pref tokenizes each batch.
pref_raw = load_pref_dataset(cfg.pref_dataset, split="train[:2000]")
pref_dl = DataLoader(PrefDataset(pref_raw), batch_size=cfg.batch_size, shuffle=True, drop_last=True, collate_fn=collate_pref)
len(pref_raw["prompt"])  # number of usable (prompt, chosen, rejected) triples


In [None]:
def _orpo_response_logps(model, prompt_ids, resp_ids, pad_id):
    """Per-token log-probs of *resp_ids* following *prompt_ids*, plus the non-pad mask.

    Returns ``(logp, mask)``, both ``(B, R)``; gradients flow through ``logp``.
    Leading extra logits (meta-token positions) are dropped before alignment.
    """
    x = torch.cat([prompt_ids, resp_ids], dim=1)
    logits = model(x)["logits"][:, -x.size(1):, :]
    n_prompt = prompt_ids.size(1)
    # Response token t at x[:, n_prompt + t] is predicted by logits[:, n_prompt + t - 1].
    logp = F.log_softmax(logits[:, n_prompt - 1:-1, :].float(), dim=-1)
    logp = logp.gather(-1, resp_ids.unsqueeze(-1)).squeeze(-1)
    return logp, resp_ids != pad_id


def train_orpo(model:nn.Module, pref_loader:DataLoader, cfg:CFG, beta:float=0.1, lam_or:float=0.5, pad_id=None):
    """ORPO: chosen-response NLL plus a weighted odds-ratio preference term, with no reference model.

    For each batch ``(P, C, R)`` (prompt, chosen, rejected ids):
        L_sft  = mean NLL over the non-pad tokens of the chosen responses
        lp_w, lp_l = length-normalized log P(chosen|x), log P(rejected|x)
        log_odds(lp) = lp - log(1 - exp(lp))
        L_or   = -log sigmoid(beta * (log_odds(lp_w) - log_odds(lp_l)))
        L      = L_sft + lam_or * L_or
    With ``beta=1`` this is the ORPO paper's loss with weight lambda = ``lam_or``.

    Args:
        pad_id: padding id excluded from all averages; defaults to
            ``modern_tok.pad_token_id``, looked up at call time.

    Returns:
        dict with ``steps`` (micro-batches run) and ``time_sec``.
    """
    if pad_id is None:
        pad_id = modern_tok.pad_token_id
    dev = next(model.parameters()).device
    use_amp = cfg.amp and dev.type == "cuda"
    accum = max(1, cfg.grad_accum)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay, betas=(0.9, 0.95), eps=1e-8)
    # The scheduler advances once per optimizer step, not once per micro-batch.
    sch = get_cosine_schedule_with_warmup(
        opt, num_warmup_steps=cfg.warmup_steps, num_training_steps=max(1, math.ceil(cfg.max_steps / accum))
    )
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    def log_odds(lp):
        # log(p / (1 - p)) with p = exp(lp); the clamp keeps log1p finite as lp -> 0.
        return lp - torch.log1p(-torch.exp(lp).clamp(max=1.0 - 1e-6))

    def masked_mean(values, mask, dim=None):
        if dim is None:
            return (values * mask).sum() / mask.sum().clamp(min=1.0)
        return (values * mask).sum(dim) / mask.sum(dim).clamp(min=1.0)

    def optimizer_step():
        scaler.unscale_(opt)
        if cfg.grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        scaler.step(opt); scaler.update(); opt.zero_grad(set_to_none=True); sch.step()

    step = 0; t0 = time.time()
    for _ in range(cfg.epochs):
        for P, C, R in pref_loader:
            P, C, R = P.to(dev), C.to(dev), R.to(dev)
            with torch.autocast(device_type="cuda" if dev.type == "cuda" else "cpu", enabled=use_amp):
                # One forward per (prompt, response) pair; the chosen pass also feeds the SFT term.
                lp_c, m_c = _orpo_response_logps(model, P, C, pad_id)
                lp_r, m_r = _orpo_response_logps(model, P, R, pad_id)
                m_c = m_c.to(lp_c.dtype); m_r = m_r.to(lp_r.dtype)
                sft_loss = -masked_mean(lp_c, m_c)
                mean_c = masked_mean(lp_c, m_c, dim=1)
                mean_r = masked_mean(lp_r, m_r, dim=1)
                ratio_loss = -F.logsigmoid(beta * (log_odds(mean_c) - log_odds(mean_r))).mean()
                loss = sft_loss + lam_or * ratio_loss
            scaler.scale(loss / accum).backward()
            step += 1
            if step % accum == 0:
                optimizer_step()
            if step % 20 == 0:
                print(f"[ORPO] step {step}/{cfg.max_steps}  L_sft {sft_loss.item():.3f}  L_or {ratio_loss.item():.3f}  -> L {loss.item():.3f}")
            if step >= cfg.max_steps:
                break
        if step >= cfg.max_steps:
            break
    if step % accum:
        optimizer_step()  # apply gradients from a trailing partial window
    return {"steps": step, "time_sec": time.time() - t0}



### ORPO 실행 (하이브리드 모델 기반)
- 베이스로 **Global+SWA(+메타 on, KV 공유 on)** 모델을 사용


In [None]:
# ORPO on a freshly built hybrid model (Global+SWA, meta tokens on, KV sharing on).
orpo_model = build_model(cfg, swa_layers=default_swa, use_meta=True, kv_share=True)
orpo_summary = train_orpo(orpo_model, pref_dl, cfg, beta=0.1, lam_or=0.5)
# Rough sanity check via SFT-corpus perplexity (preference data has no direct PPL metric).
orpo_ppl = evaluate_ppl(orpo_model, sft_dl, max_batches=20)
orpo_ppl, orpo_summary



## 10) 추가 Ablation
- **메타 토큰 off**  
- **KV 공유 off** (메모리/속도 비교)


In [None]:
# Meta tokens OFF (same Global+SWA layout, KV sharing on)
hybrid_no_meta = build_model(cfg, swa_layers=default_swa, use_meta=False, kv_share=True)
hnm_sum = train_sft(hybrid_no_meta, sft_dl, cfg)
hnm_ppl = evaluate_ppl(hybrid_no_meta, sft_dl)
print("Hybrid(meta OFF) PPL:", hnm_ppl); hnm_sum


In [None]:
# KV sharing OFF (same Global+SWA layout, meta tokens on)
hybrid_noshare = build_model(cfg, swa_layers=default_swa, use_meta=True, kv_share=False)
hns_sum = train_sft(hybrid_noshare, sft_dl, cfg)
hns_ppl = evaluate_ppl(hybrid_noshare, sft_dl)
# Compare the analytic KV-cache sizes of the shared (hybrid_meta) and unshared models.
b_share = hybrid_meta.estimate_kv_cache_bytes(cfg.seq_len)
b_noshare = hybrid_noshare.estimate_kv_cache_bytes(cfg.seq_len)
print("KV bytes  share:", round(b_share/2**20,2),"MB  |  no-share:", round(b_noshare/2**20,2),"MB")
hns_ppl, hns_sum
