
# Hymba (Hybrid Attention + Mamba) — 최신 토크나이저 · 자동 데이터 다운로드 · **ORPO(SFT+DPO 계열)** 동시학습
이 노트북은 기존 구현을 다음과 같이 보강합니다.

1) **최신 토크나이저**: 🤗 `tokenizers`의 **Unigram**(SentencePiece 계열)으로 **직접 학습**하여 사용  
2) **데이터 자동 다운로드**: 경로가 없으면 🤗 `datasets`에서 자동 로드(토치텍스트 미사용)  
3) **ORPO**: SFT 손실 + **상대 로그 오즈**(log-odds) 항을 결합한 **단일 단계** 선호 최적화 구현 및 해설  


In [1]:
from __future__ import annotations
import math, time, typing as T, os, warnings
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.nn.attention import sdpa_kernel, SDPBackend

# env / warnings
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
warnings.filterwarnings("ignore")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


In [2]:
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.normalizers import NFKC, Lowercase, Sequence as NormSeq

def get_corpus(hf_spec:str="karpathy/tiny_shakespeare") -> str:
    ds = load_dataset(hf_spec)
    col = "text" if "text" in ds["train"].column_names else ds["train"].column_names[0]
    return "\n\n".join(ds["train"][col])

def train_unigram(text:str, vocab_size:int=8000, unk:str="<|unk|>"):
    tk = Tokenizer(Unigram())
    tk.normalizer = NormSeq([NFKC(), Lowercase()])
    tk.pre_tokenizer = Whitespace()
    trainer = UnigramTrainer(vocab_size=vocab_size, special_tokens=[unk], unk_token=unk)
    tk.train_from_iterator([text], trainer=trainer)

    class Wrap:
        def __init__(self, tk): self.tk=tk
        def encode(self, s): return self.tk.encode(s).ids
        def decode(self, ids): return self.tk.decode(ids)
        @property
        def vocab_size(self): return self.tk.get_vocab_size()
    return Wrap(tk)

def make_stream_dataset(tok, text:str, seq_len:int=512) -> TensorDataset:
    import numpy as np
    ids = np.array(tok.encode(text), dtype=np.int64)
    if ids.size < seq_len+1: raise RuntimeError("Text too short")
    x = ids[:-1]; y = ids[1:]
    n = (len(y)//seq_len)*seq_len
    X = torch.tensor(x[:n].reshape(-1, seq_len))
    Y = torch.tensor(y[:n].reshape(-1, seq_len))
    return TensorDataset(X,Y)

def build_dataloaders(tok, text:str, seq_len:int=512, bs:int=32, workers:int=0, pin:bool=True):
    ds_full = make_stream_dataset(tok, text, seq_len)
    tr_len = int(0.95*len(ds_full)); va_len = len(ds_full)-tr_len
    tr, va = random_split(ds_full, [tr_len, va_len])
    train_dl = DataLoader(tr, batch_size=bs, shuffle=True, drop_last=True, num_workers=workers, pin_memory=pin)
    val_dl   = DataLoader(va, batch_size=bs, shuffle=False, num_workers=workers, pin_memory=pin)
    return train_dl, val_dl


In [3]:
class RMSNorm(nn.Module):
    def __init__(self, d:int, eps:float=1e-6):
        super().__init__(); self.eps=eps; self.w=nn.Parameter(torch.ones(d))
    def forward(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True)+self.eps) * self.w

class SwiGLU(nn.Module):
    def __init__(self, d:int, mult:float=4.0, p:float=0.0):
        super().__init__()
        h=int(d*mult)
        self.w1=nn.Linear(d,h,bias=False); self.w2=nn.Linear(d,h,bias=False); self.w3=nn.Linear(h,d,bias=False)
        self.drop=nn.Dropout(p)
    def forward(self, x): return self.w3(self.drop(F.silu(self.w1(x))*self.w2(x)))

class RotaryEmbedding(nn.Module):
    def __init__(self, dim:int, base:float=10000.0):
        super().__init__()
        inv = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("_inv", inv, persistent=False)
        self.register_buffer("_cos", None, persistent=False)
        self.register_buffer("_sin", None, persistent=False)
    def _build(self, L:int, device, dtype):
        if self._cos is not None and self._cos.size(0) >= L: return
        t = torch.arange(L, device=device, dtype=self._inv.dtype)
        freqs = torch.einsum("i,j->ij", t, self._inv)
        self._cos = torch.cos(freqs).to(dtype); self._sin = torch.sin(freqs).to(dtype)
    def apply(self, x:torch.Tensor, pos:torch.Tensor):
        # x: (..., Dh) with Dh even
        self._build(int(pos.max().item())+1, x.device, x.dtype)
        cos = self._cos.index_select(0, pos)[None,None,:,:]
        sin = self._sin.index_select(0, pos)[None,None,:,:]
        x1, x2 = x[..., ::2], x[..., 1::2]
        o1 = x1*cos - x2*sin; o2 = x1*sin + x2*cos
        return torch.stack([o1,o2], dim=-1).flatten(-2)


In [4]:
def _scaled_dot_attn(q, k, v, mask_2d: torch.Tensor | None, p: float, training: bool, is_causal: bool):
    on_cuda = q.is_cuda
    dtype_ok = q.dtype in (torch.float16, torch.bfloat16)
    want_flash = on_cuda and dtype_ok and (mask_2d is None) and is_causal
    if want_flash:
        backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
    else:
        backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
    with sdpa_kernel(backends):
        return F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=mask_2d,
            dropout_p=p if training else 0.0,
            is_causal=is_causal and (mask_2d is None)
        )


In [5]:
class AttnLayer(nn.Module):
    """
    GQA + RoPE
    - 캐시에는 '비회전 K/V'만 저장
    - 매 호출 시 RoPE는 조회용으로만 적용
    - local=True → 슬라이딩윈도우 SWA (window tokens)
    """
    def __init__(self, d:int, n_heads:int, n_kv:int, local:bool=False, window:int=256, dropout:float=0.0):
        super().__init__()
        assert d % n_heads == 0, "d must be divisible by n_heads"
        assert n_heads % n_kv == 0, "n_heads must be divisible by n_kv for GQA"
        self.H = n_heads; self.KV = n_kv; self.Dh = d // n_heads
        self.q = nn.Linear(d, n_heads*self.Dh, bias=False)
        self.k = nn.Linear(d, n_kv*self.Dh, bias=False)
        self.v = nn.Linear(d, n_kv*self.Dh, bias=False)
        self.o = nn.Linear(n_heads*self.Dh, d, bias=False)
        self.rope = RotaryEmbedding(self.Dh)
        self.drop = nn.Dropout(dropout)
        self.local = local; self.window = window
        self.rep = self.H // self.KV

    def _local_slice(self, k, v):
        if not self.local: return k, v
        Tk = k.size(2)
        w = min(self.window, Tk)
        return k[:, :, Tk-w:Tk, :], v[:, :, Tk-w:Tk, :]

    def _apply_rope_for_ranges(self, q, k_cat, Tc:int, T:int):
        # q at positions [Tc-T, ..., Tc-1], k over [0, ..., Tc-1]
        pos_q = torch.arange(Tc - T, Tc, device=q.device)
        pos_k = torch.arange(0, Tc, device=k_cat.device)
        q = self.rope.apply(q, pos_q)
        k_rot = self.rope.apply(k_cat, pos_k)  # rotated view only
        return q, k_rot

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None,
        role: str = "owner",
        global_mask: torch.Tensor | None = None,
    ):
        B, T, C = x.shape
        q = self.q(x).view(B, T, self.H, self.Dh).transpose(1, 2)  # (B,H,T,Dh)

        if role == "follower":
            assert kv_cache is not None and kv_cache[0] is not None, "Follower requires owner's KV cache"
            k_owner, v_owner = kv_cache                   # unrotated (B,KV,Tc,Dh)
            Tc = k_owner.size(2)
            # RoPE apply for q and k (rotated view only)
            q, k_rot = self._apply_rope_for_ranges(q, k_owner, Tc=Tc, T=T)
            v_full = v_owner
            # KV→H replicate
            k_full = k_rot.repeat_interleave(self.rep, dim=1)  # (B,H,Tc,Dh)
            v_full = v_full.repeat_interleave(self.rep, dim=1) # (B,H,Tc,Dh)
            # local slice
            k_full, v_full = self._local_slice(k_full, v_full)
            Tk = k_full.size(2)
            # SDPA
            out = _scaled_dot_attn(
                q.reshape(B*self.H, T, self.Dh),
                k_full.reshape(B*self.H, Tk, self.Dh),
                v_full.reshape(B*self.H, Tk, self.Dh),
                mask_2d=None, p=float(self.drop.p), training=self.training, is_causal=True
            )
            out = out.view(B, self.H, T, self.Dh).transpose(1, 2).reshape(B, T, self.H * self.Dh)
            return self.o(out), None

        # owner path: project new K/V (unrotated), concat to unrotated cache
        k_new = self.k(x).view(B, T, self.KV, self.Dh).transpose(1, 2)  # (B,KV,T,Dh)
        v_new = self.v(x).view(B, T, self.KV, self.Dh).transpose(1, 2)
        if kv_cache is not None and kv_cache[0] is not None and kv_cache[0].numel() > 0:
            k_prev, v_prev = kv_cache  # unrotated
            k_cat = torch.cat([k_prev, k_new], dim=2)
            v_cat = torch.cat([v_prev, v_new], dim=2)
        else:
            k_cat, v_cat = k_new, v_new

        Tc = k_cat.size(2)
        # RoPE rotated view for attention only
        q, k_rot = self._apply_rope_for_ranges(q, k_cat, Tc=Tc, T=T)

        # KV→H replicate
        k_full = k_rot.repeat_interleave(self.rep, dim=1)
        v_full = v_cat.repeat_interleave(self.rep, dim=1)
        # local slice
        k_full, v_full = self._local_slice(k_full, v_full)
        Tk = k_full.size(2)

        out = _scaled_dot_attn(
            q.reshape(B*self.H, T, self.Dh),
            k_full.reshape(B*self.H, Tk, self.Dh),
            v_full.reshape(B*self.H, Tk, self.Dh),
            mask_2d=None, p=float(self.drop.p), training=self.training, is_causal=True
        )
        out = out.view(B, self.H, T, self.Dh).transpose(1, 2).reshape(B, T, self.H * self.Dh)
        out = self.o(out)

        # return unrotated cache
        return out, (k_cat, v_cat)


In [6]:
class Block(nn.Module):
    """하나의 어텐션(+FFN) 블록. local=True면 SWA, False면 Global."""
    def __init__(self, d:int, n_heads:int, n_kv:int, local:bool, window:int, dropout:float):
        super().__init__()
        self.pre = RMSNorm(d)
        self.attn = AttnLayer(d, n_heads, n_kv, local=local, window=window, dropout=dropout)
        self.post = RMSNorm(d)
        self.ffn = SwiGLU(d, mult=4.0, p=dropout)
    def forward(self, x, kv_cache=None, global_mask=None, training=True, role:str="owner"):
        h = self.pre(x)
        a, new_cache = self.attn(h, kv_cache=kv_cache if not training else None, role=role, global_mask=global_mask)
        x = x + a
        x = x + self.ffn(self.post(x))
        return x, new_cache


In [7]:
@dataclass
class ModelCfg:
    vocab_size: int
    d_model: int = 384
    n_layers: int = 12
    n_heads: int = 6
    n_kv_heads: int = 2
    dropout: float = 0.0
    seq_len: int = 512
    swa_layers: T.Tuple[int,...] = (1,2,3,4,5,7,8,9,10)
    swa_window: int = 256
    num_meta_tokens: int = 0

class HymbaV2(nn.Module):
    def __init__(self, cfg:ModelCfg):
        super().__init__(); self.cfg=cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)

        self.meta = None
        if cfg.num_meta_tokens > 0:
            self.meta = nn.Parameter(torch.randn(1, cfg.num_meta_tokens, cfg.d_model) * 0.02)

        self.blocks = nn.ModuleList()
        self.swa_layers = set(cfg.swa_layers)
        for li in range(cfg.n_layers):
            is_local = (li in self.swa_layers)
            self.blocks.append(Block(cfg.d_model, cfg.n_heads, cfg.n_kv_heads,
                                     local=is_local, window=cfg.swa_window, dropout=cfg.dropout))
        self.norm = RMSNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

        # KV-share 그룹/owner 도출
        self.owner = list(range(cfg.n_layers))
        self.kv_group_id = [0]*cfg.n_layers
        swa = self.swa_layers
        gid = -1; i=0; N=cfg.n_layers
        while i < N:
            if i in swa:
                j=i
                while j<N and (j in swa): j+=1
                k=i
                while k<j:
                    if k+1<j:
                        gid += 1
                        self.kv_group_id[k]=gid; self.kv_group_id[k+1]=gid
                        self.owner[k]=k; self.owner[k+1]=k
                        k+=2
                    else:
                        gid += 1
                        self.kv_group_id[k]=gid; self.owner[k]=k
                        k+=1
                i=j
            else:
                gid += 1
                self.kv_group_id[i]=gid; self.owner[i]=i
                i+=1

    def forward(self, input_ids:torch.LongTensor, targets:torch.LongTensor|None=None):
        """
        Meta tokens:
          - 입력 임베딩 앞에 [num_meta_tokens]개를 prepend.
          - loss는 항상 '다음 토큰 예측' 규칙으로 정렬.
        """
        B, T = input_ids.shape
        x = self.embed(input_ids)                         # (B,T,D)
        M = 0
        if self.meta is not None:
            M = self.meta.size(1)
            x = torch.cat([self.meta.expand(B, -1, -1), x], dim=1)  # (B,M+T,D)

        h = x
        for li, blk in enumerate(self.blocks):
            h,_ = blk(h, kv_cache=None, global_mask=None, training=True, role="owner")
        h = self.norm(h)
        logits = self.head(h)                             # (B,M+T,V)

        out = {"logits": logits}
        if targets is not None:
            if M > 0:
                # logits positions [M .. M+T-2] predict targets [1 .. T-1]
                logits_for_loss = logits[:, M: M+T-1, :]
                targets_for_loss = targets[:, 1:]
            else:
                logits_for_loss = logits[:, :-1, :]
                targets_for_loss = targets[:, 1:]
            loss = F.cross_entropy(
                logits_for_loss.reshape(-1, logits.size(-1)),
                targets_for_loss.reshape(-1)
            )
            out["loss"] = loss
        return out

    # ------ Generate helper ------
    def _owner_map_for(self, kv_share:bool):
        return self.owner if kv_share else list(range(len(self.blocks)))

    def _forward_blocks_once(self, h, owners, kv, global_mask=None):
        for li, blk in enumerate(self.blocks):
            owner = owners[li]
            role = "owner" if li == owner else "follower"
            h, kv_out = blk(h, kv_cache=kv[owner], global_mask=global_mask, training=False, role=role)
            if li == owner:
                kv[owner] = kv_out
        return h, kv

    def _forward_blocks_full_recompute(self, ids, global_mask=None):
        h = self.embed(ids)
        for blk in self.blocks:
            h,_ = blk(h, kv_cache=None, global_mask=global_mask, training=False, role="owner")
        return h

    @torch.no_grad()
    def generate(self, input_ids:torch.LongTensor, max_new_tokens:int=64,
                 temperature:float=1.0, top_k:int=0, eos_token_id:int|None=None,
                 use_kv_cache:bool=True, kv_share:bool=True):
        device = next(self.parameters()).device
        self.eval()
        ids = input_ids.to(device)

        if use_kv_cache:
            owners = self._owner_map_for(kv_share)
            kv = [None]*len(self.blocks)
            h = self.embed(ids)
            h, kv = self._forward_blocks_once(h, owners, kv, global_mask=None)
        else:
            h = self._forward_blocks_full_recompute(ids, global_mask=None)

        for _ in range(max_new_tokens):
            if use_kv_cache:
                x_step = self.embed(ids[:, -1:])
                h = x_step
                h, kv = self._forward_blocks_once(h, owners, kv, global_mask=None)
                h = self.norm(h); logits = self.head(h)[:, -1, :]
            else:
                h = self._forward_blocks_full_recompute(ids, global_mask=None)
                h = self.norm(h); logits = self.head(h)[:, -1, :]

            if temperature <= 0:
                next_id = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k and top_k < logits.size(-1):
                    topk_vals, topk_idx = torch.topk(logits, top_k, dim=-1)
                    mask = torch.full_like(logits, float("-inf"))
                    mask.scatter_(1, topk_idx, topk_vals)
                    logits = mask
                probs = F.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)

            ids = torch.cat([ids, next_id], dim=1)
            if eos_token_id is not None and bool((next_id == eos_token_id).all()):
                break
        return ids

    # ------ Utils ------
    def layer_table(self):
        import pandas as pd
        rows=[]
        for i,_ in enumerate(self.blocks):
            rows.append({
                "layer": i,
                "attn": "LOCAL(SWA)" if i in self.swa_layers else "GLOBAL",
                "kv_owner": self.owner[i],
                "kv_share_group": self.kv_group_id[i],
            })
        return pd.DataFrame(rows)

    def estimate_kv_cache_mb(self, seq_len:int, dtype=torch.float16):
        Dh = self.cfg.d_model // self.cfg.n_heads
        KV = max(1, self.cfg.n_kv_heads)
        bytes_per = torch.finfo(dtype).bits // 8
        owners = len(set(self.owner))
        per_owner = 2 * KV * seq_len * Dh * bytes_per
        return round(per_owner * owners / (1024**2), 3)


In [8]:
@dataclass
class TrainCfg:
    seq_len:int=512
    batch_size:int=32
    steps:int=600
    lr:float=6e-4
    warmup:int=100
    amp:bool=True
    wd:float=0.1
    grad_clip:float=1.0

def adamw_param_groups(model:nn.Module, wd:float):
    decay, no_decay = [], []
    for n,p in model.named_parameters():
        if not p.requires_grad: continue
        if p.dim() >= 2 and "norm" not in n.lower():
            decay.append(p)
        else:
            no_decay.append(p)
    return [{"params": decay, "weight_decay": wd},
            {"params": no_decay, "weight_decay": 0.0}]

def train_loop(model:HymbaV2, train_dl, val_dl, tcfg:TrainCfg, device:str="cuda"):
    import itertools, math
    from transformers import get_cosine_schedule_with_warmup
    from torch.amp import GradScaler, autocast

    torch.manual_seed(1337)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(1337)

    model.to(device).train()
    pg = adamw_param_groups(model, wd=tcfg.wd)
    opt = torch.optim.AdamW(pg, lr=tcfg.lr, betas=(0.9,0.95), eps=1e-8,
                            fused=torch.cuda.is_available())
    sch = get_cosine_schedule_with_warmup(opt, tcfg.warmup, tcfg.steps)
    scaler = GradScaler(device="cuda" if (device=="cuda" and torch.cuda.is_available()) else "cpu",
                        enabled=tcfg.amp)

    it = itertools.cycle(train_dl)
    step=0; tok_count=0; train_nll=0.0; train_tok=0
    t0=time.time()

    while step < tcfg.steps:
        xb,yb = next(it)
        xb,yb = xb.to(device,non_blocking=True), yb.to(device,non_blocking=True)
        with autocast(device_type=("cuda" if (device=="cuda" and torch.cuda.is_available()) else "cpu"),
                      enabled=tcfg.amp):
            out = model(xb, targets=yb)
            loss = out["loss"]

        train_nll += float(loss.detach())*xb.numel(); train_tok += xb.numel()
        scaler.scale(loss).backward()
        scaler.unscale_(opt)
        if tcfg.grad_clip>0: nn.utils.clip_grad_norm_(model.parameters(), tcfg.grad_clip)
        scaler.step(opt); scaler.update(); opt.zero_grad(set_to_none=True); sch.step()
        step += 1; tok_count += xb.numel()

        if step==1 or step%50==0:
            lr_now = opt.param_groups[0]["lr"]
            print(f"[{step:5d}] loss={loss.item():.3f} lr={lr_now:.2e}")

    elapsed = time.time()-t0
    tps = int(tok_count/max(1e-9, elapsed))
    train_loss = train_nll/max(1,train_tok)

    # validation
    model.eval(); val_nll=0.0; val_tok=0
    with torch.no_grad(), torch.amp.autocast("cuda", enabled=tcfg.amp and (device=="cuda")):
        for xb,yb in val_dl:
            xb,yb = xb.to(device), yb.to(device)
            out = model(xb, targets=yb)
            val_nll += float(out["loss"].detach())*xb.numel(); val_tok += xb.numel()
    val_loss = val_nll/max(1,val_tok); ppl = math.exp(val_loss)
    return {"train_loss": float(train_loss), "val_loss": float(val_loss), "ppl": float(ppl), "tps": tps}


In [9]:
def build_everything(seq_len:int=512, bs:int=32, vocab_size:int=8000):
    text = get_corpus("karpathy/tiny_shakespeare")
    tok = train_unigram(text, vocab_size=vocab_size)
    train_dl, val_dl = build_dataloaders(tok, text, seq_len=seq_len, bs=bs)

    cfg = ModelCfg(vocab_size=tok.vocab_size, seq_len=seq_len)
    model = HymbaV2(cfg)
    return model, tok, train_dl, val_dl


In [11]:
# (E) 실행
model, tok, train_dl, val_dl = build_everything(seq_len=512, bs=32, vocab_size=8000)
tcfg = TrainCfg(steps=1000, warmup=100, amp=True)
stats = train_loop(model, train_dl, val_dl, tcfg, device="cuda" if torch.cuda.is_available() else "cpu")
stats




[    1] loss=8.795 lr=6.00e-06
[   50] loss=5.852 lr=3.00e-04
[  100] loss=4.739 lr=6.00e-04
[  150] loss=2.822 lr=5.95e-04
[  200] loss=1.680 lr=5.82e-04
[  250] loss=0.397 lr=5.60e-04
[  300] loss=0.048 lr=5.30e-04
[  350] loss=0.009 lr=4.93e-04
[  400] loss=0.012 lr=4.50e-04
[  450] loss=0.003 lr=4.03e-04
[  500] loss=0.003 lr=3.52e-04
[  550] loss=0.002 lr=3.00e-04
[  600] loss=0.001 lr=2.48e-04
[  650] loss=0.000 lr=1.97e-04
[  700] loss=0.000 lr=1.50e-04
[  750] loss=0.000 lr=1.07e-04
[  800] loss=0.000 lr=7.02e-05
[  850] loss=0.000 lr=4.02e-05
[  900] loss=0.000 lr=1.81e-05
[  950] loss=0.000 lr=4.56e-06
[ 1000] loss=0.000 lr=0.00e+00


{'train_loss': 0.9886284915425604,
 'val_loss': 0.4139203131198883,
 'ppl': 1.5127365768230914,
 'tps': 176810}


## 9) **ORPO (SFT + 로그 오즈 선호항)** 구현 및 해설
- **핵심**: SFT의 NLL 손실에 **선호 대비 항**을 추가  
- ORPO 논문은 **참조 모델 없이**, `L = L_SFT + β · L_ratio`를 제안  
- 여기서 `L_ratio = -log σ( Δ )`, `Δ ≈ (avg log p_θ(chosen|x) - avg log p_θ(rejected|x))`  
  - 실무적으로 TRL 구현은 **길이 정규화된 토큰 평균 로그확률 차이**에 로지스틱을 적용합니다.  


In [12]:
# ===============================
# ORPO: HF 데이터셋 로드 → pairs 생성
# ===============================
# !pip install -q datasets
from datasets import load_dataset

# 사용 가능한 예:
#   "Anthropic/hh-rlhf"
#   "Dahoas/rm-static"
#   "HuggingFaceH4/ultrafeedback_binarized"
DATASET = "Anthropic/hh-rlhf"

ds = load_dataset(DATASET)

def make_pairs_from_hf(ds, dataset_name:str):
    """
    반환 형식: [(prompt, chosen, rejected), ...]
    """
    pairs = []
    split = "train" if "train" in ds else list(ds.keys())[0]
    for ex in ds[split]:
        if "Anthropic/hh-rlhf" in dataset_name:
            prompt = ""
            pos = ex["chosen"]
            neg = ex["rejected"]
        elif "Dahoas/rm-static" in dataset_name:
            prompt = ex.get("prompt","")
            pos = ex.get("chosen", ex.get("response",""))
            neg = ex.get("rejected","")
        elif "ultrafeedback_binarized" in dataset_name:
            prompt = ex.get("prompt","")
            pos = ex.get("chosen","")
            neg = ex.get("rejected","")
        else:
            continue
        if pos and neg and pos != neg:
            pairs.append((prompt, pos, neg))
    if not pairs:
        raise RuntimeError("pairs 생성 실패: 컬럼 확인 필요")
    return pairs

pairs = make_pairs_from_hf(ds, DATASET)
len(pairs)


README.md: 0.00B [00:00, ?B/s]

harmless-base/train.jsonl.gz:   0%|          | 0.00/13.2M [00:00<?, ?B/s]

helpful-base/train.jsonl.gz:   0%|          | 0.00/16.2M [00:00<?, ?B/s]

helpful-online/train.jsonl.gz:   0%|          | 0.00/20.1M [00:00<?, ?B/s]

helpful-rejection-sampled/train.jsonl.gz:   0%|          | 0.00/25.7M [00:00<?, ?B/s]

harmless-base/test.jsonl.gz:   0%|          | 0.00/743k [00:00<?, ?B/s]

helpful-base/test.jsonl.gz:   0%|          | 0.00/875k [00:00<?, ?B/s]

helpful-online/test.jsonl.gz:   0%|          | 0.00/1.05M [00:00<?, ?B/s]

helpful-rejection-sampled/test.jsonl.gz:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/160800 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/8552 [00:00<?, ? examples/s]

160060

In [13]:
# ==========================================
# ORPO 텍스트 → 토큰 텐서 리스트 생성기
# ==========================================
import torch
from typing import List, Tuple

# 전제: tok, seq_len 가 존재해야 함

def _chunks(ids: List[int], need:int, seq_len:int) -> List[List[int]]:
    out=[]
    L=len(ids)
    for s in range(0, L - need + 1, seq_len):
        out.append(ids[s:s+need])
    return out

def build_orpo_tensors(
    pairs: List[Tuple[str,str,str]],
    tok,
    seq_len: int
):
    """
    반환:
      pos_ids_list, pos_tgts_list, neg_ids_list, neg_tgts_list
      각 원소 텐서 크기: (seq_len,)
    """
    pos_ids_list=[]; pos_tgts_list=[]
    neg_ids_list=[]; neg_tgts_list=[]
    need = seq_len + 1  # 입력 T, 타깃 T 정렬용

    for prompt, pos, neg in pairs:
        text_pos = (prompt + "\n" + pos) if prompt else pos
        text_neg = (prompt + "\n" + neg) if prompt else neg

        pos_ids = tok.encode(text_pos)
        neg_ids = tok.encode(text_neg)

        pos_chunks = _chunks(pos_ids, need, seq_len)
        neg_chunks = _chunks(neg_ids, need, seq_len)

        m = min(len(pos_chunks), len(neg_chunks))
        for i in range(m):
            p = pos_chunks[i]; n = neg_chunks[i]

            p_x = torch.tensor(p[:-1], dtype=torch.long)  # (T,)
            p_y = torch.tensor(p[1:],  dtype=torch.long)  # (T,)
            n_x = torch.tensor(n[:-1], dtype=torch.long)
            n_y = torch.tensor(n[1:],  dtype=torch.long)

            pos_ids_list.append(p_x); pos_tgts_list.append(p_y)
            neg_ids_list.append(n_x); neg_tgts_list.append(n_y)

    if not (len(pos_ids_list)==len(neg_ids_list)==len(pos_tgts_list)==len(neg_tgts_list)):
        raise RuntimeError("ORPO 리스트 길이 불일치")
    if len(pos_ids_list)==0:
        raise RuntimeError("샘플 0개. seq_len을 줄이거나 텍스트 길이를 늘릴 것")

    return pos_ids_list, pos_tgts_list, neg_ids_list, neg_tgts_list


In [14]:
# =====================
# ORPO 기본 유틸/로스
# =====================
import torch
import torch.nn.functional as F
from torch import nn
from dataclasses import dataclass

@torch.no_grad()
def _meta_tokens(model) -> int:
    return 0 if getattr(model, "meta", None) is None else int(model.meta.size(1))

def _seq_logprob(model, input_ids:torch.LongTensor, targets:torch.LongTensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    returns:
      token_logp_sum: (B,)
      token_count:    (B,)
    모델의 loss 정렬과 동일하게 처리:
      - meta M>0: logits[:, M:M+T-1] vs targets[:, 1:]
      - meta M=0: logits[:, :-1]     vs targets[:, 1:]
    """
    out = model(input_ids, targets=None)
    logits = out["logits"]  # (B, M+T, V)
    B, L, V = logits.shape
    M = _meta_tokens(model)

    if M > 0:
        logit_slice = logits[:, M:L-1, :]    # (B, T-1, V)
        gold_tokens = targets[:, 1:]         # (B, T-1)
    else:
        logit_slice = logits[:, :L-1, :]
        gold_tokens = targets[:, 1:]

    logp = F.log_softmax(logit_slice, dim=-1)                         # (B, T-1, V)
    token_logp = torch.gather(logp, 2, gold_tokens.unsqueeze(-1)).squeeze(-1)  # (B, T-1)
    token_logp_sum = token_logp.sum(dim=1)                             # (B,)
    token_count = torch.full((B,), token_logp.size(1), device=token_logp.device, dtype=torch.long)
    return token_logp_sum, token_count

@dataclass
class ORPOCfg:
    beta: float = 1.0
    sft_weight: float = 1.0
    amp: bool = True

def orpo_loss(model:nn.Module,
              pos_ids:torch.LongTensor, pos_tgts:torch.LongTensor,
              neg_ids:torch.LongTensor, neg_tgts:torch.LongTensor,
              cfg:ORPOCfg) -> dict:
    """
    L = sft_weight * CE(y+|x)  +  mean( -log σ[ beta * (log p(y+|x) - log p(y-|x)) ] )
    반환: {"loss", "sft_loss", "orpo_term", "margin_mean"}
    """
    sft_out = model(pos_ids, targets=pos_tgts)
    sft_loss = sft_out["loss"]

    pos_lp_sum, pos_cnt = _seq_logprob(model, pos_ids, pos_tgts)
    neg_lp_sum, neg_cnt = _seq_logprob(model, neg_ids, neg_tgts)

    pos_lp = pos_lp_sum / pos_cnt.clamp_min(1)
    neg_lp = neg_lp_sum / neg_cnt.clamp_min(1)

    margin = pos_lp - neg_lp                       # (B,)
    orpo_term = -F.logsigmoid(cfg.beta * margin)   # (B,)
    loss = cfg.sft_weight * sft_loss + orpo_term.mean()

    return {
        "loss": loss,
        "sft_loss": sft_loss.detach(),
        "orpo_term": orpo_term.mean().detach(),
        "margin_mean": margin.mean().detach(),
    }


In [15]:
# =============================
# ORPO용 Pair 데이터셋/Collate
# =============================
from torch.utils.data import Dataset, DataLoader

class PairDataset(Dataset):
    """
    각 아이템: (pos_ids, pos_tgts, neg_ids, neg_tgts)
    각 텐서 크기: (T,)
    """
    def __init__(self, pos_ids, pos_tgts, neg_ids, neg_tgts):
        assert len(pos_ids)==len(neg_ids)==len(pos_tgts)==len(neg_tgts)
        self.pos_ids = pos_ids
        self.pos_tgts = pos_tgts
        self.neg_ids = neg_ids
        self.neg_tgts = neg_tgts
    def __len__(self): return len(self.pos_ids)
    def __getitem__(self, i):
        return (self.pos_ids[i], self.pos_tgts[i],
                self.neg_ids[i], self.neg_tgts[i])

def collate_pairs(batch):
    pos_ids  = torch.stack([b[0] for b in batch], dim=0)
    pos_tgts = torch.stack([b[1] for b in batch], dim=0)
    neg_ids  = torch.stack([b[2] for b in batch], dim=0)
    neg_tgts = torch.stack([b[3] for b in batch], dim=0)
    return pos_ids, pos_tgts, neg_ids, neg_tgts


In [16]:
# =============================
# ORPO용 Pair 데이터셋/Collate
# =============================
from torch.utils.data import Dataset, DataLoader

class PairDataset(Dataset):
    """
    각 아이템: (pos_ids, pos_tgts, neg_ids, neg_tgts)
    각 텐서 크기: (T,)
    """
    def __init__(self, pos_ids, pos_tgts, neg_ids, neg_tgts):
        assert len(pos_ids)==len(neg_ids)==len(pos_tgts)==len(neg_tgts)
        self.pos_ids = pos_ids
        self.pos_tgts = pos_tgts
        self.neg_ids = neg_ids
        self.neg_tgts = neg_tgts
    def __len__(self): return len(self.pos_ids)
    def __getitem__(self, i):
        return (self.pos_ids[i], self.pos_tgts[i],
                self.neg_ids[i], self.neg_tgts[i])

def collate_pairs(batch):
    pos_ids  = torch.stack([b[0] for b in batch], dim=0)
    pos_tgts = torch.stack([b[1] for b in batch], dim=0)
    neg_ids  = torch.stack([b[2] for b in batch], dim=0)
    neg_tgts = torch.stack([b[3] for b in batch], dim=0)
    return pos_ids, pos_tgts, neg_ids, neg_tgts


In [17]:
# ================
# ORPO 학습 루프
# ================
from torch.amp import autocast, GradScaler
from dataclasses import dataclass

@dataclass
class ORPOTrainCfg:
    steps:int = 1000
    batch_size:int = 16
    lr:float = 2e-5
    warmup:int = 100
    wd:float = 0.0
    grad_clip:float = 1.0
    beta:float = 1.0
    sft_weight:float = 1.0
    amp:bool = True

def _adamw_groups(model:nn.Module, wd:float):
    decay, no_decay = [], []
    for n,p in model.named_parameters():
        if not p.requires_grad: continue
        if p.dim() >= 2 and "norm" not in n.lower():
            decay.append(p)
        else:
            no_decay.append(p)
    return [{"params": decay, "weight_decay": wd},
            {"params": no_decay, "weight_decay": 0.0}]

def train_orpo(model:nn.Module, pair_loader:DataLoader, val_loader:DataLoader|None,
               cfg:ORPOTrainCfg, device:str="cuda"):
    import itertools, time, math
    from transformers import get_cosine_schedule_with_warmup

    torch.manual_seed(1337)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(1337)

    model.to(device).train()
    pg = _adamw_groups(model, wd=cfg.wd)
    opt = torch.optim.AdamW(pg, lr=cfg.lr, betas=(0.9,0.95), eps=1e-8,
                            fused=torch.cuda.is_available())
    sch = get_cosine_schedule_with_warmup(opt, cfg.warmup, cfg.steps)
    scaler = GradScaler(device="cuda" if (device=="cuda" and torch.cuda.is_available()) else "cpu",
                        enabled=cfg.amp)

    it = itertools.cycle(pair_loader)
    step=0; t0=time.time()
    log = []

    while step < cfg.steps:
        pos_ids, pos_tgts, neg_ids, neg_tgts = next(it)
        pos_ids = pos_ids.to(device, non_blocking=True)
        pos_tgts = pos_tgts.to(device, non_blocking=True)
        neg_ids = neg_ids.to(device, non_blocking=True)
        neg_tgts = neg_tgts.to(device, non_blocking=True)

        with autocast(device_type=("cuda" if (device=="cuda" and torch.cuda.is_available()) else "cpu"),
                      enabled=cfg.amp):
            loss_dict = orpo_loss(
                model,
                pos_ids, pos_tgts,
                neg_ids, neg_tgts,
                ORPOCfg(beta=cfg.beta, sft_weight=cfg.sft_weight, amp=cfg.amp),
            )
            loss = loss_dict["loss"]

        scaler.scale(loss).backward()
        scaler.unscale_(opt)
        if cfg.grad_clip>0: nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        scaler.step(opt); scaler.update(); opt.zero_grad(set_to_none=True); sch.step()
        step += 1

        if step==1 or step%50==0:
            print(f"[{step:5d}] loss={loss.item():.4f} sft={loss_dict['sft_loss']:.4f} "
                  f"orpo={loss_dict['orpo_term']:.4f} margin={loss_dict['margin_mean']:.4f}")
        if step%50==0:
            log.append({k: float(v) for k,v in loss_dict.items()})

    elapsed = time.time()-t0
    metrics = {"train_time_s": elapsed}
    if val_loader is not None:
        model.eval()
        wins=0; total=0; nll_sum=0.0; tok_sum=0
        with torch.no_grad(), torch.amp.autocast("cuda", enabled=cfg.amp and (device=="cuda")):
            for pos_ids, pos_tgts, neg_ids, neg_tgts in val_loader:
                pos_ids = pos_ids.to(device); pos_tgts = pos_tgts.to(device)
                neg_ids = neg_ids.to(device); neg_tgts = neg_tgts.to(device)
                pos_lp_sum, pos_cnt = _seq_logprob(model, pos_ids, pos_tgts)
                neg_lp_sum, neg_cnt = _seq_logprob(model, neg_ids, neg_tgts)
                wins += int((pos_lp_sum/pos_cnt - neg_lp_sum/neg_cnt).sum().item() > 0)
                total += pos_ids.size(0)

                out = model(pos_ids, targets=pos_tgts)
                nll_sum += float(out["loss"]) * pos_ids.numel()
                tok_sum += pos_ids.numel()
        if total>0: metrics["win_rate"] = wins/total
        if tok_sum>0:
            val_loss = nll_sum/tok_sum
            metrics["val_ppl_pos"] = math.exp(val_loss)
        model.train()
    return metrics, log


In [19]:
# =================
# 사용 예시 (전체)
# =================
# 1) HF pairs → 텐서 리스트
seq_len = 512
pos_ids_list, pos_tgts_list, neg_ids_list, neg_tgts_list = build_orpo_tensors(pairs, tok, seq_len)

# 2) DataLoader
pair_ds = PairDataset(pos_ids_list, pos_tgts_list, neg_ids_list, neg_tgts_list)
pair_dl = DataLoader(pair_ds, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_pairs)
val_dl  = DataLoader(pair_ds, batch_size=16, shuffle=False, drop_last=False, collate_fn=collate_pairs)

# 3) 학습
orpo_cfg = ORPOTrainCfg(steps=1000, batch_size=16, lr=2e-5, warmup=100, beta=1.0, sft_weight=1.0, amp=True)
metrics, log = train_orpo(model, pair_dl, val_dl, orpo_cfg, device="cuda" if torch.cuda.is_available() else "cpu")
metrics


[    1] loss=1.7929 sft=1.0949 orpo=0.6980 margin=-0.0096
[   50] loss=1.4319 sft=0.7454 orpo=0.6864 margin=0.0137
[  100] loss=1.1596 sft=0.4718 orpo=0.6878 margin=0.0110
[  150] loss=1.0718 sft=0.3793 orpo=0.6925 margin=0.0012
[  200] loss=1.0267 sft=0.3352 orpo=0.6915 margin=0.0033
[  250] loss=0.9722 sft=0.2789 orpo=0.6932 margin=-0.0001
[  300] loss=0.9404 sft=0.2466 orpo=0.6939 margin=-0.0014
[  350] loss=0.9272 sft=0.2329 orpo=0.6943 margin=-0.0022
[  400] loss=0.9191 sft=0.2273 orpo=0.6918 margin=0.0027
[  450] loss=0.8798 sft=0.1853 orpo=0.6945 margin=-0.0027
[  500] loss=0.8676 sft=0.1726 orpo=0.6950 margin=-0.0037
[  550] loss=0.8548 sft=0.1599 orpo=0.6949 margin=-0.0035
[  600] loss=0.8440 sft=0.1504 orpo=0.6936 margin=-0.0009
[  650] loss=0.8272 sft=0.1301 orpo=0.6971 margin=-0.0078
[  700] loss=0.8155 sft=0.1250 orpo=0.6905 margin=0.0053
[  750] loss=0.8474 sft=0.1537 orpo=0.6937 margin=-0.0010
[  800] loss=0.8161 sft=0.1213 orpo=0.6948 margin=-0.0033
[  850] loss=0.8145 

{'train_time_s': 163.42869114875793,
 'win_rate': 0.03599120952025542,
 'val_ppl_pos': 1.1234024763550756}