# Hymba Refactor — Clean Training & Ablation

In [1]:
import os, warnings, logging
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")   # fork 후 deadlock 경고 제거
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")         # TF 로그 억제
os.environ.setdefault("HF_DATASETS_DISABLE_TF_WARNING", "1")

'1'

In [None]:
import math, time
import torch
import pandas as pd

from torch.utils.data import DataLoader, TensorDataset, random_split
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.normalizers import NFKC, Lowercase, Sequence as NormSeq
from tokenizers.processors import TemplateProcessing
from backbone.hymba_v1 import HymbaRef, ModelCfg

In [3]:
USE_AMP_DEVICE = "cuda"  # or "cpu" if no GPU

In [4]:
# 로깅/워닝 억제
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message="pkg_resources is deprecated as an API", category=UserWarning)
logging.getLogger().setLevel(logging.ERROR)   # JIT 커널 불러오기 경고 포함 대부분 억제

# DataLoader는 num_workers=0 권장(토크나이저 포크 경고 방지)
DATALOADER_NUM_WORKERS = 0
PIN_MEMORY = True

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42); 
if torch.cuda.is_available(): torch.cuda.manual_seed_all(42)

In [6]:
# 1) Unigram tokenizer (compact)
def get_corpus(spec:str):
    try:
        ds = load_dataset(spec)
        for k in ["text","content","document","raw","data"]:
            if k in ds["train"].column_names:
                return "\n\n".join(ds["train"][k])
        return "\n\n".join(map(str, ds["train"][:1000]))
    except Exception:
        with open(spec, "r", encoding="utf-8", errors="ignore") as f:
            return f.read()

In [7]:
def train_unigram(text, vocab_size=16000,
                  bos="<|bos|>", eos="<|eos|>", pad="<|pad|>", unk="<|unk|>"):
    # 1) 모델/전처리
    tok = Tokenizer(Unigram())
    tok.normalizer = NormSeq([NFKC(), Lowercase()])
    tok.pre_tokenizer = Whitespace()

    # 2) 트레이너: UNK 반드시 지정 + 특수토큰 등록
    trainer = UnigramTrainer(
        vocab_size=vocab_size,
        special_tokens=[pad, bos, eos, unk],
        unk_token=unk
    )

    # 3) 한 번만 학습 (iterator로 스트리밍)
    def it():
        bs = 1_000_000
        for i in range(0, len(text), bs):
            yield text[i:i+bs]
    tok.train_from_iterator(it(), trainer=trainer)

    # 4) ID 조회 및 후처리
    pad_id = tok.token_to_id(pad)
    bos_id = tok.token_to_id(bos)
    eos_id = tok.token_to_id(eos)
    unk_id = tok.token_to_id(unk)

    tok.post_processor = TemplateProcessing(
        single=f"{bos} $A {eos}",
        pair=f"{bos} $A {eos} $B:1 {eos}:1",
        special_tokens=[(bos, bos_id), (eos, eos_id)],
    )

    class ModernTok:
        def __init__(self, tk, bos_id, eos_id, pad_id, unk_id):
            self.tk = tk
            self.bos_token_id = bos_id
            self.eos_token_id = eos_id
            self.pad_token_id = pad_id
            self.unk_token_id = unk_id
        def encode(self, s): return self.tk.encode(s).ids
        def decode(self, ids): return self.tk.decode(ids)
        @property
        def vocab_size(self): return self.tk.get_vocab_size()

    return ModernTok(tok, bos_id, eos_id, pad_id, unk_id)

In [8]:
corpus = get_corpus("karpathy/tiny_shakespeare")
tok = train_unigram(corpus, vocab_size=8000)
VOCAB = tok.vocab_size; EOS = tok.eos_token_id
print("vocab:", VOCAB)



vocab: 5590


In [9]:
# 2) chunk & loaders
def chunk(text, L=512):
    ids = tok.encode(text)
    n = (len(ids)//L)*L; ids = ids[:n]
    import numpy as np
    arr = np.array(ids, dtype=np.int64).reshape(-1, L)
    y = np.copy(arr); y[:, :-1] = arr[:, 1:]; y[:, -1] = EOS
    X=torch.tensor(arr); Y=torch.tensor(y)
    return TensorDataset(X,Y)

In [10]:
ds_full = chunk(corpus, 512)
train_len = int(0.95*len(ds_full)); val_len = len(ds_full)-train_len
train_ds, val_ds = random_split(ds_full, [train_len, val_len])

train_dl = DataLoader(train_ds, batch_size=64, shuffle=True,
                      num_workers=DATALOADER_NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=True)
val_dl   = DataLoader(val_ds, batch_size=64, shuffle=False,
                      num_workers=DATALOADER_NUM_WORKERS, pin_memory=PIN_MEMORY)

In [11]:
# 3) training utilities
def run(cfg:ModelCfg, title:str, steps:int=150, lr=3e-4, wd=0.1, warmup=40, amp=True):
    model = HymbaRef(cfg).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd, betas=(0.9,0.95), eps=1e-8)
    from transformers import get_cosine_schedule_with_warmup
    sch = get_cosine_schedule_with_warmup(opt, warmup, steps)
    scaler = torch.amp.GradScaler(device=USE_AMP_DEVICE, enabled=amp)

    step=0; t0=time.time(); tok_count=0; lossm=0.0; cnt=0
    model.train()
    for xb,yb in train_dl:
        xb, yb = xb.to(device), yb.to(device)
        with torch.amp.autocast(device_type=USE_AMP_DEVICE, enabled=True):
            out = model(xb, targets=yb); loss = out["loss"]
        lossm += loss.item(); cnt += 1
        scaler.scale(loss).backward()
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt); scaler.update(); opt.zero_grad(set_to_none=True); sch.step()
        step += 1; tok_count += xb.numel()
        if step>=steps: break
    train_loss = lossm/max(1,cnt)
    elapsed = time.time()-t0
    tps = tok_count/max(1e-9, elapsed)

    # eval with consistent averaging
    model.eval(); nll=0.0; n_tok=0
    with torch.no_grad(), torch.amp.autocast(device_type=USE_AMP_DEVICE, enabled=True):
        for xb,yb in val_dl:
            xb, yb = xb.to(device), yb.to(device)
            out = model(xb, targets=yb); loss = out["loss"]
            nll += loss.item()*xb.numel(); n_tok += xb.numel()
    val_loss = nll/max(1,n_tok); ppl = math.exp(val_loss)

    # KV cache estimate (inference)
    cache_mb = model.estimate_kv_cache_mb(seq_len=512)
    return dict(title=title, train_loss=train_loss, val_loss=val_loss, ppl=ppl, time=elapsed, mem=cache_mb)

In [12]:
# 4) experiment grid (cleaned & minimal)
base = dict(vocab_size=VOCAB, d_model=512, n_layers=12, n_heads=8, n_kv_heads=4,
            attn_dim=256, mamba_dim=256, swa_layers=None, swa_window=256,
            num_meta_tokens=4, meta_dropout=0.1, kv_share=True, fusion="mean")

exps = [
    ("BASE: partial-SWA + KVshare + Meta + Hybrid",
     dict(attn_dim=256, mamba_dim=256, swa_layers=None, kv_share=True, num_meta_tokens=4, fusion="mean")),
    ("NO-KVSHARE",
     dict(attn_dim=256, mamba_dim=256, swa_layers=None, kv_share=False, num_meta_tokens=4, fusion="mean")),
    ("NO-META",
     dict(attn_dim=256, mamba_dim=256, swa_layers=None, kv_share=True, num_meta_tokens=0, fusion="mean")),
    ("WEAK-SWA(window=256)",
     dict(attn_dim=256, mamba_dim=256, swa_layers=None, kv_share=True, num_meta_tokens=4, fusion="mean")),
]

In [None]:
rows = []
for title, override in exps:
    cfg = ModelCfg(**{**base, **override})
    rows.append(run(cfg, title=title, steps=1000, lr=3e-4, warmup=40, amp=True))

df = pd.DataFrame(rows)
df


Unnamed: 0,title,train_loss,val_loss,ppl,time,mem
0,BASE: partial-SWA + KVshare + Meta + Hybrid,8.67523,8.379238,4355.689187,5.489295,1.5
1,NO-KVSHARE,8.679866,8.387218,4390.588165,3.706423,3.0
2,NO-META,8.688317,8.402307,4457.335998,3.978421,1.5
3,WEAK-SWA(window=256),8.706653,8.410793,4495.325259,3.710657,1.5


In [14]:
rows = []
for title, override in exps:
    cfg = ModelCfg(**{**base, **override})
    rows.append(run(cfg, title=title, steps=10000, lr=3e-4, warmup=40, amp=True))

df = pd.DataFrame(rows)
df


Unnamed: 0,title,train_loss,val_loss,ppl,time,mem
0,BASE: partial-SWA + KVshare + Meta + Hybrid,8.685378,8.392913,4415.661203,3.713562,1.5
1,NO-KVSHARE,8.694437,8.392723,4414.823273,3.709999,3.0
2,NO-META,8.695388,8.388641,4396.8399,3.05888,1.5
3,WEAK-SWA(window=256),8.71985,8.41023,4492.792311,3.711063,1.5
