# Hymba 수정된 구현 평가

## 수정 사항

### 1. KV Sharing RoPE 버그 수정
- **문제**: Producer가 RoPE 적용 전 K를 공유 → Consumer에서 RoPE 중복 적용
- **해결**: Producer가 RoPE 적용 후의 K를 공유하고, Consumer는 RoPE를 재적용하지 않음

### 2. SWA Window Size 최적화
- **변경**: `window=256` → `window=128` (seq_len=1024의 1/8)
- **이유**: 더 명확한 local attention 효과

## 예상 결과
- Hybrid-Mine PPL: 40.70 → ~37-38 (Official 수준으로 개선)

In [1]:
import sys
import os
import warnings
import gc
sys.path.append('./backbone')

os.environ['TOKENIZERS_PARALLELISM'] = 'false'
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import time
from tqdm.auto import tqdm
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
from collections import defaultdict

# Reload the project modules if they were already imported, so that edits
# made to their source files take effect without restarting the kernel.
import importlib
for _module_name in ('hymba', 'hymba_official'):
    if _module_name in sys.modules:
        importlib.reload(sys.modules[_module_name])

# 수정된 내 구현
from hymba import Hymba, HymbaConfig, ArchType, AttentionType

# 공식 구현 스타일
from hymba_official import HymbaOfficialModel, HymbaOfficialConfig

from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer

# Output directory for figures and CSV files; created if it does not exist.
RESULTS_DIR = './results'
os.makedirs(RESULTS_DIR, exist_ok=True)

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Device: {device}')
if device == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name()}')
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f'Memory: {_gpu_props.total_memory / 1e9:.1f} GB')

# Plotting defaults applied to every figure in the notebook.
plt.rcParams.update({'figure.figsize': (16, 10), 'figure.dpi': 150})
sns.set_style('whitegrid')

# Seed the PyTorch and NumPy random number generators.
torch.manual_seed(42)
np.random.seed(42)

Device: cuda
GPU: NVIDIA A100 80GB PCIe
Memory: 85.1 GB


---

## 1. 데이터셋 준비

In [2]:
# Section banner (three lines: rule, title, rule).
print('=' * 70, 'WikiText-103 데이터셋 로드', '=' * 70, sep='\n')

# GPT-2 tokenizer; its EOS token is assigned as the padding token.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
VOCAB_SIZE = tokenizer.vocab_size

print(f'Tokenizer: GPT-2 (vocab_size={VOCAB_SIZE})')

# Raw WikiText-103 splits from the Hugging Face hub.
dataset = load_dataset('wikitext', 'wikitext-103-raw-v1')

# Row counts of the two splits used below.
print(f"\nTrain samples: {len(dataset['train']):,}")
print(f"Validation samples: {len(dataset['validation']):,}")

WikiText-103 데이터셋 로드
Tokenizer: GPT-2 (vocab_size=50257)

Train samples: 1,801,350
Validation samples: 3,760


In [None]:
class WikiTextDataset(Dataset):
    """Fixed-length next-token-prediction chunks over one token stream.

    All non-blank entries of ``texts`` are joined with single spaces and
    encoded with one ``tokenizer.encode`` call. Item ``i`` is built from the
    ``seq_len + 1`` tokens starting at ``i * seq_len`` and is returned as the
    pair ``(chunk[:-1], chunk[1:])``: the inputs and the same tokens shifted
    left by one position (the targets).

    Args:
        texts: Raw text entries; empty or whitespace-only entries are skipped.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            that returns a sequence of integer token ids.
        seq_len: Number of tokens in each input sequence; must be positive.
        max_tokens: When truthy, only the first ``max_tokens`` tokens are
            kept. ``None`` (or ``0``) keeps the whole stream.

    Raises:
        ValueError: If ``seq_len`` is not positive.
    """

    def __init__(self, texts: List[str], tokenizer, seq_len: int, max_tokens: Optional[int] = None):
        if seq_len <= 0:
            raise ValueError(f'seq_len must be positive, got {seq_len}')

        self.seq_len = seq_len
        self.tokenizer = tokenizer

        print('텍스트 토큰화 중...')
        all_text = ' '.join([t for t in texts if t.strip()])
        tokens = tokenizer.encode(all_text, add_special_tokens=False)

        if max_tokens:
            tokens = tokens[:max_tokens]

        self.tokens = torch.tensor(tokens, dtype=torch.long)
        # Every chunk needs seq_len + 1 tokens (inputs plus the shifted
        # target), hence the "- 1". Clamp at zero so that an empty stream
        # gives an empty dataset instead of a negative length, which would
        # make len() raise.
        self.n_chunks = max(0, (len(self.tokens) - 1) // seq_len)

        print(f'총 토큰: {len(self.tokens):,}')
        print(f'청크 수 (seq_len={seq_len}): {self.n_chunks:,}')

    def __len__(self):
        """Return the number of complete chunks."""
        return self.n_chunks

    def __getitem__(self, idx):
        """Return ``(inputs, targets)`` for chunk ``idx``.

        Negative indices count from the end. An out-of-range index raises
        ``IndexError`` so that plain iteration over the dataset terminates
        instead of yielding truncated or empty tensors forever.
        """
        if idx < 0:
            idx += self.n_chunks
        if not 0 <= idx < self.n_chunks:
            raise IndexError(f'chunk index out of range for {self.n_chunks} chunks')

        start = idx * self.seq_len
        end = start + self.seq_len + 1
        chunk = self.tokens[start:end]
        return chunk[:-1], chunk[1:]


# Sequence length, batch size, and the token budget of each split.
SEQ_LEN = 1024
BATCH_SIZE = 8
MAX_TRAIN_TOKENS = 50_000_000
MAX_VAL_TOKENS = 500_000

print('\n' + '=' * 70, f'데이터셋 생성 (seq_len={SEQ_LEN})', '=' * 70, sep='\n')

# Tokenize each split into fixed-length chunks, capped at its token budget.
train_dataset = WikiTextDataset(
    texts=dataset['train']['text'],
    tokenizer=tokenizer,
    seq_len=SEQ_LEN,
    max_tokens=MAX_TRAIN_TOKENS,
)
val_dataset = WikiTextDataset(
    texts=dataset['validation']['text'],
    tokenizer=tokenizer,
    seq_len=SEQ_LEN,
    max_tokens=MAX_VAL_TOKENS,
)

# Training batches are shuffled and the incomplete last batch is dropped;
# validation keeps its order and every chunk.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

print(f'\n학습: {len(train_dataset):,} 청크, {len(train_loader):,} 배치')
print(f'검증: {len(val_dataset):,} 청크')


데이터셋 생성 (seq_len=1024)
텍스트 토큰화 중...


---

## 2. 모델 설정 (수정된 버전)

In [None]:
@dataclass
class ExperimentConfig:
    """One named model configuration taking part in the comparison."""

    # Display name of the experiment.
    name: str
    # Model configuration object. In this notebook it is either a HymbaConfig
    # or a HymbaOfficialConfig, so the field is left untyped (typing.Any).
    config: Any
    # Human-readable description printed before training.
    description: str
    # True when ``config`` must be built with HymbaOfficialModel rather than Hymba.
    is_official: bool = False


# Revised settings:
# - sliding-window attention (SWA) window reduced from 256 to 128
SWA_WINDOW = 128  # revised: 1/8 of the 1024-token sequence length
NUM_META_TOKENS = 64

# Hybrid, my revised implementation (RoPE fix + smaller window).
_mine_fixed_config = HymbaConfig(
    vocab_size=VOCAB_SIZE,
    d_model=320,
    n_layers=11,
    n_heads=5,
    n_kv_heads=1,
    arch_type=ArchType.HYBRID,
    global_attn_idx=[0, 5, 10],
    use_meta_tokens=True,
    num_meta_tokens=NUM_META_TOKENS,
    swa_window=SWA_WINDOW,
    dropout=0.1,
)

# Hybrid, official-style implementation with matching sizes.
_official_config = HymbaOfficialConfig(
    vocab_size=VOCAB_SIZE,
    hidden_size=320,
    num_hidden_layers=11,
    num_attention_heads=5,
    num_key_value_heads=1,
    attn_hidden_size=320,
    global_attn_idx=[0, 5, 10],
    num_memory_tokens=NUM_META_TOKENS,
    attn_window_size=SWA_WINDOW,
    mamba_expand=2,
    mamba_d_state=16,
    mamba_d_conv=4,
    intermediate_size=320 * 3,
    attention_dropout=0.1,
)

# Registry keyed by experiment name; insertion order is the training order.
experiments: Dict[str, ExperimentConfig] = {
    exp.name: exp
    for exp in (
        ExperimentConfig(
            name='Hybrid-Mine-Fixed',
            config=_mine_fixed_config,
            description='Hybrid (수정됨: RoPE 버그 수정 + window=128)',
        ),
        ExperimentConfig(
            name='Hybrid-Official',
            config=_official_config,
            description='Hybrid (공식 스타일: 단일 in_proj, avg fusion)',
            is_official=True,
        ),
    )
}

# Summary table: each model is instantiated once only to count its parameters.
print('=' * 90, '실험 모델 설정', '=' * 90, sep='\n')
print(f'{"Name":<20} {"Params":>10} {"Layers":>7} {"d_model":>8} {"Meta":>6} {"SWA":>6}')
print('-' * 90)

for name, exp in experiments.items():
    cfg = exp.config

    # The two config classes expose the same quantities under different
    # attribute names, so read them per implementation.
    if not exp.is_official:
        model = Hymba(cfg)
        d_model, layers, swa = cfg.d_model, cfg.n_layers, cfg.swa_window
        meta = cfg.num_meta_tokens if cfg.use_meta_tokens else 0
    else:
        model = HymbaOfficialModel(cfg)
        d_model, layers, swa = cfg.hidden_size, cfg.num_hidden_layers, cfg.attn_window_size
        meta = cfg.num_memory_tokens

    params = model.count_parameters()['total']
    del model  # drop the throw-away instance before building the next one
    print(f'{name:<20} {params/1e6:>9.2f}M {layers:>7} {d_model:>8} {meta:>6} {swa:>6}')

torch.cuda.empty_cache()
print('=' * 90)

---

## 3. 학습

In [None]:
@dataclass
class TrainConfig:
    """Hyperparameters read by ``train_model``."""

    # Number of full passes over the training loader.
    epochs: int = 3
    # Peak learning rate handed to AdamW.
    lr: float = 3e-4
    # Learning rate reached at the end of the cosine decay.
    min_lr: float = 3e-5
    # Fraction of all optimizer steps spent on linear warmup.
    warmup_ratio: float = 0.05
    # AdamW weight decay.
    weight_decay: float = 0.1
    # Maximum gradient norm used for clipping.
    grad_clip: float = 1.0
    # Run validation every this many optimizer steps.
    eval_interval: int = 500
    # Refresh the progress-bar loss every this many optimizer steps.
    log_interval: int = 100


def _evaluate(model: nn.Module, val_loader: DataLoader) -> float:
    """Return the mean validation loss of ``model`` over ``val_loader``.

    Each batch loss is weighted by its batch size so that a smaller final
    batch does not skew the mean. The model's train/eval mode is restored
    before returning.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    loss_sum = 0.0
    n_seqs = 0
    try:
        with torch.no_grad():
            for vxb, vyb in val_loader:
                vxb, vyb = vxb.to(device), vyb.to(device)
                vout = model(vxb, targets=vyb)
                loss_sum += vout['loss'].item() * vxb.size(0)
                n_seqs += vxb.size(0)
    finally:
        model.train(was_training)

    if n_seqs == 0:
        raise ValueError('val_loader yielded no batches; cannot compute validation loss')
    return loss_sum / n_seqs


def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    config: TrainConfig,
    model_name: str = '',
) -> Dict:
    """Train ``model`` with AdamW and a warmup + cosine learning-rate schedule.

    The model is called as ``model(xb, targets=yb)`` and must return a mapping
    whose ``'loss'`` entry is the scalar training loss. Validation runs every
    ``config.eval_interval`` optimizer steps and once more after the last
    step, so the reported numbers always include the final weights even when
    the total step count is not a multiple of ``eval_interval``.

    Args:
        model: Model to train; it is moved to the module-level ``device``.
        train_loader: Yields ``(inputs, targets)`` training batches.
        val_loader: Yields ``(inputs, targets)`` validation batches.
        config: Optimizer, schedule and logging hyperparameters.
        model_name: Label shown in the progress bar.

    Returns:
        Dict with ``best_val_loss``, ``best_ppl``, ``final_ppl``, ``time_min``,
        ``tokens_per_sec`` and ``history`` (lists ``train_loss``, ``val_loss``,
        ``val_ppl``, ``step``, ``lr`` with one entry per evaluation).
        ``time_min`` and ``tokens_per_sec`` are wall-clock figures that
        include the time spent on validation.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model = model.to(device).train()

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.lr,
        betas=(0.9, 0.95),
        weight_decay=config.weight_decay,
    )

    total_steps = config.epochs * len(train_loader)
    warmup_steps = int(total_steps * config.warmup_ratio)
    min_lr_ratio = config.min_lr / config.lr

    def lr_lambda(step):
        # Linear warmup from 0 to 1, then cosine decay down to min_lr / lr.
        if step < warmup_steps:
            return step / warmup_steps
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
        return min_lr_ratio + (1 - min_lr_ratio) * cosine_decay

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    history = {'train_loss': [], 'val_loss': [], 'val_ppl': [], 'step': [], 'lr': []}
    step = 0
    tokens_processed = 0
    best_val_loss = float('inf')
    last_train_loss = float('nan')  # stays NaN only if no training step runs
    running_loss = 0.0
    t0 = time.time()

    def record_eval():
        # Evaluate now and append one point to every history series.
        nonlocal best_val_loss
        val_loss = _evaluate(model, val_loader)
        history['train_loss'].append(last_train_loss)
        history['val_loss'].append(val_loss)
        history['val_ppl'].append(np.exp(val_loss))
        history['step'].append(step)
        history['lr'].append(scheduler.get_last_lr()[0])
        best_val_loss = min(best_val_loss, val_loss)

    for epoch in range(config.epochs):
        model.train()
        pbar = tqdm(train_loader, desc=f'[{model_name}] Epoch {epoch+1}/{config.epochs}', leave=False)

        for xb, yb in pbar:
            xb, yb = xb.to(device), yb.to(device)

            out = model(xb, targets=yb)
            loss = out['loss']

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            step += 1
            # Count the tokens actually seen instead of assuming the
            # notebook-level BATCH_SIZE * SEQ_LEN for every batch.
            tokens_processed += xb.numel()
            last_train_loss = loss.item()
            running_loss += last_train_loss

            if step % config.log_interval == 0:
                avg_loss = running_loss / config.log_interval
                running_loss = 0.0
                pbar.set_postfix({'loss': f'{avg_loss:.3f}', 'lr': f'{scheduler.get_last_lr()[0]:.2e}'})

            if step % config.eval_interval == 0:
                record_eval()

    # Evaluate the final weights unless the last step was already evaluated.
    # Without this, steps after the last eval_interval multiple are never
    # measured, and a run shorter than eval_interval reports an infinite PPL.
    if not history['step'] or history['step'][-1] != step:
        record_eval()

    elapsed = time.time() - t0

    return {
        'best_val_loss': best_val_loss,
        'best_ppl': np.exp(best_val_loss),
        'final_ppl': history['val_ppl'][-1],
        'time_min': elapsed / 60,
        'tokens_per_sec': tokens_processed / elapsed,
        'history': history,
    }

In [None]:
# Shared training hyperparameters and the containers filled by the loop below.
train_config = TrainConfig(epochs=3, lr=3e-4, eval_interval=500, log_interval=100)
results: Dict[str, Dict] = {}
trained_models: Dict[str, nn.Module] = {}

print(
    '\n' + '=' * 80,
    '학습 시작 (수정된 버전)',
    f'총 에폭: {train_config.epochs}',
    f'배치당 토큰: {BATCH_SIZE * SEQ_LEN:,}',
    f'SWA Window: {SWA_WINDOW} (seq_len의 {SWA_WINDOW/SEQ_LEN*100:.1f}%)',
    '=' * 80,
    sep='\n',
)

# Train every registered experiment in registry order.
for exp_name, exp_config in experiments.items():
    print(f"\n{'='*70}")
    print(f"실험: {exp_name}")
    print(f"설명: {exp_config.description}")
    print(f"{'='*70}")

    cfg = exp_config.config

    # Pick the model class that matches the config type.
    model = (HymbaOfficialModel if exp_config.is_official else Hymba)(cfg)

    params = model.count_parameters()
    print(f'Parameters: {params["total"]/1e6:.2f}M')

    train_result = train_model(model, train_loader, val_loader, train_config, exp_name)

    # Keep the config and parameter count next to the training metrics.
    results[exp_name] = {
        'config': cfg,
        'params': params['total'],
        **{
            key: train_result[key]
            for key in ('best_ppl', 'final_ppl', 'time_min', 'tokens_per_sec', 'history')
        },
        'is_official': exp_config.is_official,
    }

    # Keep the trained model, switched to eval mode.
    model.eval()
    trained_models[exp_name] = model

    print(f'Best PPL: {train_result["best_ppl"]:.2f}')
    print(f'Throughput: {train_result["tokens_per_sec"]/1000:.1f}K tokens/sec')
    print(f'Time: {train_result["time_min"]:.1f} min')

print('\n' + '=' * 80, '모든 학습 완료!', '=' * 80, sep='\n')

---

## 4. 결과 분석

In [None]:
def _summary_row(name, r):
    """Build one row of the summary table from a ``results`` entry."""
    return {
        'Model': name,
        'Type': 'Official' if r.get('is_official', False) else 'Mine',
        'Params (M)': r['params'] / 1e6,
        'Best PPL': r['best_ppl'],
        'Final PPL': r['final_ppl'],
        'Throughput (K tok/s)': r['tokens_per_sec'] / 1000,
        'Time (min)': r['time_min'],
    }


# One row per trained model, then a copy sorted by best perplexity.
df = pd.DataFrame([_summary_row(name, r) for name, r in results.items()])
df_sorted = df.sort_values('Best PPL')

print('\n' + '=' * 100, '결과 요약 (Best PPL 기준 정렬)', '=' * 100, sep='\n')
print(df_sorted.to_string(index=False))
print('=' * 100)

# Comparison with the numbers recorded before the fix.
print('\n' + '=' * 80, '이전 결과와 비교', '=' * 80, sep='\n')
print('이전 Hybrid-Mine PPL: 40.70')
print('이전 Hybrid-Official PPL: 37.14')
print('이전 차이: 3.56')
print()

# A missing experiment falls back to 0, which skips the comparison below.
mine_ppl = results.get('Hybrid-Mine-Fixed', {}).get('best_ppl', 0)
official_ppl = results.get('Hybrid-Official', {}).get('best_ppl', 0)
if mine_ppl and official_ppl:
    diff = mine_ppl - official_ppl
    improvement = 40.70 - mine_ppl
    print(f'수정 후 Hybrid-Mine-Fixed PPL: {mine_ppl:.2f}')
    print(f'수정 후 Hybrid-Official PPL: {official_ppl:.2f}')
    print(f'수정 후 차이: {diff:.2f}')

    print(f'\nHybrid-Mine 개선: {improvement:.2f} PPL')

In [None]:
# Visualize the results: best-PPL bar chart (left), validation curves (right).
colors = {
    'Hybrid-Mine-Fixed': '#45B7D1',
    'Hybrid-Official': '#6B5B95',
}

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: best validation PPL per model, best model at the top.
ax = axes[0]
bar_colors = [colors.get(m, 'gray') for m in df_sorted['Model']]
bars = ax.barh(df_sorted['Model'], df_sorted['Best PPL'], color=bar_colors)
ax.set_xlabel('Best Validation PPL (lower is better)')
ax.set_title('Model Performance Comparison', fontweight='bold', fontsize=12)
ax.invert_yaxis()
# Write each value just to the right of its bar.
for bar, val in zip(bars, df_sorted['Best PPL']):
    y_mid = bar.get_y() + bar.get_height() / 2
    ax.text(val + 0.3, y_mid, f'{val:.2f}', va='center')

# Right panel: validation PPL over training steps; official runs are dashed.
ax = axes[1]
for name, r in results.items():
    if r.get('is_official', False):
        linestyle = '--'
    else:
        linestyle = '-'
    hist = r['history']
    ax.plot(hist['step'], hist['val_ppl'],
            label=name, linewidth=2, linestyle=linestyle,
            color=colors.get(name, 'gray'))
ax.set_xlabel('Training Steps')
ax.set_ylabel('Validation PPL')
ax.set_title('Training Curves', fontweight='bold', fontsize=12)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{RESULTS_DIR}/training_comparison_fixed.png', dpi=300, bbox_inches='tight')
plt.show()

---

## 5. 결론

In [None]:
# Conclusion header followed by a fixed box that summarizes the two changes
# (RoPE handling in KV sharing, and the SWA window going from 256 to 128).
# The box is a whitespace-exact literal; keep it unchanged.
print('=' * 80)
print('실험 결론')
print('=' * 80)

print('''
┌─────────────────────────────────────────────────────────────────────────────┐
│                              수정 사항                                        │
├─────────────────────────────────────────────────────────────────────────────┤
│ 1. KV Sharing RoPE 버그 수정                                                 │
│    - 이전: RoPE 적용 전 K 공유 → Consumer에서 RoPE 중복 적용                   │
│    - 수정: RoPE 적용 후 K 공유 → Consumer는 RoPE 재적용 안함                   │
│                                                                               │
│ 2. SWA Window Size 최적화                                                    │
│    - 이전: window=256 (seq_len의 25%)                                         │
│    - 수정: window=128 (seq_len의 12.5%)                                       │
│    - 효과: 더 명확한 local attention 패턴                                     │
└─────────────────────────────────────────────────────────────────────────────┘
''')

# Final comparison against the previously recorded run (Mine 40.70, gap 3.56).
# A missing experiment falls back to 0. Each figure is printed only when its
# result exists, so an absent run shows "N/A" instead of a PPL of 0.00 and a
# bogus 40.70 improvement.
mine_ppl = results.get('Hybrid-Mine-Fixed', {}).get('best_ppl', 0)
official_ppl = results.get('Hybrid-Official', {}).get('best_ppl', 0)

print('\n결과 비교:')
print('  이전 Hybrid-Mine PPL: 40.70')
if mine_ppl:
    print(f'  수정 후 Hybrid-Mine-Fixed PPL: {mine_ppl:.2f}')
    print(f'  개선량: {40.70 - mine_ppl:.2f} PPL')
else:
    print('  수정 후 Hybrid-Mine-Fixed PPL: N/A (학습 결과 없음)')

if official_ppl:
    print(f'\n  Hybrid-Official PPL: {official_ppl:.2f}')
else:
    print('\n  Hybrid-Official PPL: N/A (학습 결과 없음)')

# The gap is only meaningful when both runs are present.
if mine_ppl and official_ppl:
    print(f'  차이: {mine_ppl - official_ppl:.2f} (이전: 3.56)')

In [None]:
# Save the (unsorted) summary table.
df.to_csv(f'{RESULTS_DIR}/fixed_comparison_results.csv', index=False)
print(f'결과 저장 완료: {RESULTS_DIR}/fixed_comparison_results.csv')

# Free memory. Collect garbage first so that tensors held only by reference
# cycles are released to PyTorch's allocator, then return the cached blocks
# to the driver; the reverse order leaves that memory sitting in the cache.
# NOTE(review): the training cell's loop variable `model` may still reference
# the last trained model, which would keep its parameters alive. Confirm
# whether it should be deleted here as well.
del trained_models
gc.collect()
torch.cuda.empty_cache()
print('메모리 정리 완료')