In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
import time

# --- 중요 ---
# Hymba의 Mamba 구성 요소는 성능 최적화를 위해 CUDA 커널을 사용합니다.
# 이 코드는 반드시 GPU가 활성화된 환경(예: Google Colab의 GPU 런타임)에서 실행해야 합니다.
# !pip install mamba-ssm causal-conv1d einops
from mamba_ssm import Mamba

# --- 1단계: Llama 기반 Transformer 구성 요소 구현 ---

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable scale.

    Args:
        d_model: Length of the learnable ``weight`` vector (initialised to ones).
        eps: Constant added to the mean square before taking the reciprocal
            square root.
    """

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x):
        """Return ``x`` divided by the RMS of its last dimension, scaled by ``weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

class RotaryEmbedding(nn.Module):
    """Precomputed cosine/sine tables for rotary position embeddings.

    The tables are built once in ``__init__`` for positions
    ``0 .. max_seq_len - 1`` and stored as buffers of shape
    ``(1, 1, max_seq_len, dim)``.

    Args:
        dim: Size of the rotated (per-head) dimension; must be even because
            the table is built from two concatenated halves of ``dim // 2``
            frequencies.
        max_seq_len: Number of positions to precompute.
        base: Base of the geometric frequency progression.

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim: int, max_seq_len: int, base: int = 10000):
        super().__init__()
        if dim % 2 != 0:
            # An odd dim would produce a (dim + 1)-wide table that cannot be
            # multiplied with dim-wide queries/keys.
            raise ValueError(f"RotaryEmbedding requires an even dim, got {dim}")

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        t = torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        # Outer product: one row of angles per position.
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)

        # Duplicate the angles so both halves of the feature dimension share them.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :])
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :])

    def forward(self, x: torch.Tensor, seq_len: int):
        """Return the ``(cos, sin)`` tables for the first ``seq_len`` positions.

        Args:
            x: Tensor whose dtype the returned tables are cast to; its values
                are not read.
            seq_len: Number of leading positions to return.

        Returns:
            Two tensors of shape ``(1, 1, seq_len, dim)``.

        Raises:
            ValueError: If ``seq_len`` is negative or exceeds the number of
                precomputed positions. Previously such a request silently
                returned a differently sized table.
        """
        table_len = self.cos_cached.size(2)
        if not 0 <= seq_len <= table_len:
            raise ValueError(
                f"seq_len must be in [0, {table_len}] (the precomputed table length), got {seq_len}"
            )
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys with the given cosine/sine tables.

    Each input ``t`` is mapped to ``t * cos + rotate_half(t) * sin``; ``cos``
    and ``sin`` must be broadcastable against both ``xq`` and ``xk``.

    Returns:
        The rotated ``(query, key)`` pair.
    """
    q_rot, k_rot = (t * cos + rotate_half(t) * sin for t in (xq, xk))
    return q_rot, k_rot

def rotate_half(x: torch.Tensor):
    """Split the last dimension at its midpoint and return ``cat(-back, front)``."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((back.neg(), front), dim=-1)

class GroupedQueryAttention(nn.Module):
    """Attention with ``n_heads`` query heads sharing ``n_kv_heads`` key/value heads.

    Queries and keys are rotated with rotary position embeddings before the
    attention product. Key/value heads are repeated so every query head has a
    matching key/value head.

    Args:
        d_model: Input and output feature size.
        n_heads: Number of query heads; must divide ``d_model``.
        n_kv_heads: Number of key/value heads; must divide ``n_heads``.
        max_seq_len: Number of positions precomputed by the rotary embedding.
        window_size: Stored on the module but not read by ``forward``; any
            windowing has to be supplied through ``attn_mask``.

    Raises:
        ValueError: If a head count is not positive, ``d_model`` is not
            divisible by ``n_heads``, or ``n_heads`` is not divisible by
            ``n_kv_heads``.
    """

    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int, max_seq_len: int, window_size: int = -1):
        super().__init__()
        # Fail early with a clear message instead of a reshape/attention
        # shape error deep inside forward().
        if n_heads <= 0 or n_kv_heads <= 0:
            raise ValueError(f"n_heads and n_kv_heads must be positive, got {n_heads} and {n_kv_heads}")
        if d_model % n_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        if n_heads % n_kv_heads != 0:
            raise ValueError(f"n_heads ({n_heads}) must be divisible by n_kv_heads ({n_kv_heads})")

        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = d_model // n_heads
        self.window_size = window_size

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

        self.rotary_emb = RotaryEmbedding(self.head_dim, max_seq_len)

    def forward(self, x: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_mask: Optional[torch.Tensor] = None, use_cache: bool = False):
        """Run attention over ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, d_model)``.
            past_kv: Optional ``(key, value)`` pair that is concatenated in
                front of the new keys/values along the sequence axis. Only
                used when ``use_cache`` is true.
            attn_mask: Optional mask forwarded to
                ``F.scaled_dot_product_attention``. When ``None`` the call
                uses ``is_causal=True`` instead.
            use_cache: Whether to consume ``past_kv`` and return the
                concatenated keys/values.

        Returns:
            ``(output, present_kv)`` where ``output`` has the shape of ``x``
            and ``present_kv`` is the ``(key, value)`` pair before head
            repetition, or ``None`` when ``use_cache`` is false.
        """
        batch_size, seq_len, _ = x.shape

        # Project, then move the head axis in front of the sequence axis.
        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # NOTE(review): rotary positions always start at 0 for the current
        # chunk, even when past_kv is supplied, so new tokens are not offset
        # by the cached length. Confirm before relying on incremental decoding.
        cos, sin = self.rotary_emb(q, seq_len)
        q, k = apply_rotary_emb(q, k, cos, sin)

        if use_cache and past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        present_kv = (k, v) if use_cache else None

        if self.n_kv_heads != self.n_heads:
            # Expand key/value heads so each query head has a partner.
            n_repeats = self.n_heads // self.n_kv_heads
            k = k.repeat_interleave(n_repeats, dim=1)
            v = v.repeat_interleave(n_repeats, dim=1)

        # NOTE(review): the PyTorch docs describe is_causal=True as an
        # upper-left aligned causal mask; once cached keys make the key length
        # exceed the query length that alignment likely hides the newest keys.
        # Verify against the caller's cache usage.
        attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=(attn_mask is None))

        output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.o_proj(output), present_kv

class FeedForward(nn.Module):
    """Gated feed-forward network: ``w2(silu(w1(x)) * w3(x))``.

    Args:
        d_model: Input and output feature size.
        hidden_dim: Width of the two parallel inner projections.
    """

    def __init__(self, d_model: int, hidden_dim: int):
        super().__init__()
        # Creation order w1, w2, w3 is kept so parameter initialisation and
        # state-dict ordering stay the same.
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, d_model, bias=False)
        self.w3 = nn.Linear(d_model, hidden_dim, bias=False)

    def forward(self, x):
        """Apply the SiLU-gated projection and map back to ``d_model``."""
        activated = F.silu(self.w1(x))
        linear_branch = self.w3(x)
        return self.w2(activated * linear_branch)

# --- 2단계: 최종 Hymba 레이어 및 모델 구현 ---

class TransformerBlock(nn.Module):
    """Pre-norm block: attention, then feed-forward, each added back as a residual.

    Args:
        d_model: Feature size of the block input and output.
        n_heads: Number of query heads for the attention sub-layer.
        n_kv_heads: Number of key/value heads for the attention sub-layer.
        max_seq_len: Passed through to the attention sub-layer.
        ffn_hidden_dim: Hidden width of the feed-forward sub-layer.
        **kwargs: Only ``window_size`` is forwarded to the attention
            sub-layer; any other extra keyword is ignored.
    """

    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int, max_seq_len: int, ffn_hidden_dim: int, **kwargs):
        super().__init__()
        optional_args = {key: kwargs[key] for key in ('window_size',) if key in kwargs}

        self.attention = GroupedQueryAttention(
            d_model=d_model,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            max_seq_len=max_seq_len,
            **optional_args,
        )
        self.feed_forward = FeedForward(d_model, ffn_hidden_dim)
        self.attention_norm = RMSNorm(d_model)
        self.ffn_norm = RMSNorm(d_model)

    def forward(self, x: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_mask: Optional[torch.Tensor] = None, use_cache: bool = False):
        """Return ``(block_output, present_kv)`` for input ``x``.

        ``past_kv``, ``attn_mask`` and ``use_cache`` are handed to the
        attention sub-layer unchanged; ``present_kv`` is whatever it returns.
        """
        attn_out, present_kv = self.attention(self.attention_norm(x), past_kv, attn_mask, use_cache=use_cache)
        hidden = x + attn_out
        ffn_out = self.feed_forward(self.ffn_norm(hidden))
        return hidden + ffn_out, present_kv

class HymbaLayer(nn.Module):
    """Layer that runs a Mamba block and an attention block side by side.

    Both branches read the same normalised input. Their outputs are
    concatenated on the feature axis and mixed back to ``d_model`` by the
    bias-free linear layer ``gate``, followed by a residual feed-forward step.

    Args:
        d_model: Feature size of the layer input and output.
        mamba_params: Extra keyword arguments for ``Mamba``.
        attn_params: Keyword arguments for ``GroupedQueryAttention``.
        ffn_hidden_dim: Hidden width of the feed-forward sub-layer.
    """

    def __init__(self, d_model: int, mamba_params: dict, attn_params: dict, ffn_hidden_dim: int):
        super().__init__()
        # Sub-module creation order is kept so initialisation and state-dict
        # ordering match the previous implementation.
        self.norm = RMSNorm(d_model)
        self.mamba_block = Mamba(d_model=d_model, **mamba_params)
        self.attn_block = GroupedQueryAttention(**attn_params)
        self.feed_forward = FeedForward(d_model, ffn_hidden_dim)
        self.ffn_norm = RMSNorm(d_model)
        self.gate = nn.Linear(d_model * 2, d_model, bias=False)

    def forward(self, x: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_mask: Optional[torch.Tensor] = None, use_cache: bool = False):
        """Return ``(layer_output, present_kv)`` for input ``x``.

        ``past_kv``, ``attn_mask`` and ``use_cache`` only affect the attention
        branch; the Mamba branch receives just the normalised input.
        """
        normed = self.norm(x)

        ssm_branch = self.mamba_block(normed)
        attn_branch, present_kv = self.attn_block(normed, past_kv, attn_mask, use_cache=use_cache)

        # Fuse both branches back to d_model and add the residual.
        fused = self.gate(torch.cat([ssm_branch, attn_branch], dim=-1))
        hidden = x + fused

        return hidden + self.feed_forward(self.ffn_norm(hidden)), present_kv

class HymbaModel(nn.Module):
    """Token-level language model whose layer type and extras are chosen by ``config``.

    ``config`` is a dict. Keys read by this class: ``vocab_size``,
    ``d_model``, ``n_layers``, ``n_heads``, ``n_kv_heads``, ``max_seq_len``,
    ``ffn_hidden_dim``, ``window_size``, ``n_meta_tokens``, ``mamba_params``
    and the boolean switches ``use_meta_tokens``, ``use_shared_kv_cache``,
    ``use_swa`` and ``use_ssm_head``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tok_embeddings = nn.Embedding(config['vocab_size'], config['d_model'])

        if config['use_meta_tokens']:
            # Learnable prefix that forward() prepends to every sequence.
            self.meta_tokens = nn.Parameter(torch.randn(1, config['n_meta_tokens'], config['d_model']))

        layers = []
        for _ in range(config['n_layers']):
            attn_params = {
                'd_model': config['d_model'],
                'n_heads': config['n_heads'],
                'n_kv_heads': config['n_kv_heads'],
                'max_seq_len': config['max_seq_len'],
                # -1 marks "no window" for the attention module.
                'window_size': config['window_size'] if config['use_swa'] else -1,
            }
            if config['use_ssm_head']:
                layers.append(HymbaLayer(
                    d_model=config['d_model'],
                    mamba_params=config['mamba_params'],
                    attn_params=attn_params,
                    ffn_hidden_dim=config['ffn_hidden_dim']
                ))
            else:
                layers.append(TransformerBlock(
                    d_model=config['d_model'],
                    n_heads=config['n_heads'],
                    n_kv_heads=config['n_kv_heads'],
                    max_seq_len=config['max_seq_len'],
                    ffn_hidden_dim=config['ffn_hidden_dim'],
                    window_size=attn_params['window_size']
                ))
        self.layers = nn.ModuleList(layers)
        self.norm = RMSNorm(config['d_model'])
        self.output = nn.Linear(config['d_model'], config['vocab_size'], bias=False)

    def _create_sliding_window_mask(self, seq_len: int, device: torch.device) -> Optional[torch.Tensor]:
        """Build the additive causal sliding-window mask.

        Args:
            seq_len: Number of positions (queries and keys).
            device: Device on which to allocate the mask.

        Returns:
            ``None`` when sliding-window attention is disabled or the window
            is not positive. Otherwise a float tensor of shape
            ``(1, 1, seq_len, seq_len)`` holding ``0`` where query ``i`` may
            attend key ``j`` (``0 <= i - j < window_size``) and ``-inf``
            elsewhere.
        """
        if not self.config['use_swa'] or self.config['window_size'] <= 0:
            return None

        positions = torch.arange(seq_len, device=device)
        # distance[i, j] = i - j: negative for future keys, >= window for
        # keys that have dropped out of the window.
        distance = positions[:, None] - positions[None, :]
        # The previous version applied the window to the upper triangle, which
        # is already masked as "future", so the window never took effect.
        blocked = (distance < 0) | (distance >= self.config['window_size'])

        mask = torch.zeros((1, 1, seq_len, seq_len), device=device)
        # NOTE(review): prepended meta tokens are masked like ordinary
        # positions once they leave the window — confirm whether they should
        # stay visible to every query.
        return mask.masked_fill(blocked, float('-inf'))

    def forward(self, tokens: torch.Tensor, use_cache: bool = False):
        """Compute logits for ``tokens``.

        Args:
            tokens: Integer tensor of shape ``(batch, seq_len)``.
            use_cache: Forwarded to every layer; also enables the
                ``use_shared_kv_cache`` branch below.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``; positions that
            belong to prepended meta tokens are removed before projection.
        """
        batch_size, seq_len = tokens.shape
        h = self.tok_embeddings(tokens)

        if self.config['use_meta_tokens']:
            meta_h = self.meta_tokens.expand(batch_size, -1, -1)
            h = torch.cat([meta_h, h], dim=1)
            seq_len += self.config['n_meta_tokens']

        attn_mask = self._create_sliding_window_mask(seq_len, tokens.device)

        # Fresh per call: nothing is carried over between forward passes.
        kv_cache = [None] * self.config['n_layers']

        for i, layer in enumerate(self.layers):
            past_kv = kv_cache[i] if use_cache else None

            # NOTE(review): this hands layer i-1's keys/values for the *same*
            # positions to layer i as if they were past tokens; the layer then
            # concatenates them with its own, so the key length grows with
            # depth (and no longer matches an attn_mask built for seq_len).
            # Confirm the intended cross-layer sharing semantics.
            if use_cache and self.config['use_shared_kv_cache'] and i > 0:
                past_kv = kv_cache[i-1]

            h, present_kv = layer(h, past_kv=past_kv, attn_mask=attn_mask, use_cache=use_cache)

            if use_cache:
                kv_cache[i] = present_kv

        h = self.norm(h)

        if self.config['use_meta_tokens']:
            # Drop the meta-token positions so the output aligns with `tokens`.
            h = h[:, self.config['n_meta_tokens']:, :]

        return self.output(h)

# --- 3단계: Ablation Study 재현을 위한 설정 ---
if __name__ == '__main__':
    # Smoke test: build the model for each ablation stage and check the
    # shape of the logits for one random batch.

    # 1. Check the runtime environment and select the device.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"사용 디바이스: {torch.cuda.get_device_name(0)}")
        print("Mamba의 CUDA 커널을 사용하여 최적화된 성능으로 실행됩니다.")
    else:
        device = torch.device("cpu")
        print("사용 디바이스: CPU")
        print("!!! 경고: Mamba의 핵심 기능은 CUDA를 필요로 합니다. 이 코드는 CPU 환경에서 실행되지 않습니다. !!!")
        print("!!! Google Colab 사용 시, [런타임] > [런타임 유형 변경]에서 GPU를 선택해주세요. !!!")
        # Mamba cannot run without CUDA, so stop here. SystemExit is raised
        # directly because the exit() helper is meant for interactive shells;
        # the non-zero status signals that the test did not run.
        raise SystemExit(1)

    base_config = {
        'vocab_size': 32000, 'd_model': 256, 'n_layers': 4,
        'n_heads': 8, 'n_kv_heads': 2, 'max_seq_len': 1024,
        'ffn_hidden_dim': 256 * 4, 'window_size': 256, 'n_meta_tokens': 4,
        'mamba_params': {'d_state': 16, 'd_conv': 4, 'expand': 2},
        'use_meta_tokens': False, 'use_shared_kv_cache': False,
        'use_swa': False, 'use_ssm_head': False,
    }

    # Each stage switches on one more feature on top of the previous stage.
    ablation_steps = {
        "1_Transformer_Baseline": {},
        "2_Plus_Meta_Tokens": {'use_meta_tokens': True},
        "3_Plus_Shared_KV": {'use_meta_tokens': True, 'use_shared_kv_cache': True},
        "4_Plus_SWA": {'use_meta_tokens': True, 'use_shared_kv_cache': True, 'use_swa': True},
        "5_Hymba_Final": {'use_meta_tokens': True, 'use_shared_kv_cache': True, 'use_swa': True, 'use_ssm_head': True},
    }

    print("\n--- Hymba Ablation Study 단계별 모델 구현 테스트 ---")
    for name, flags in ablation_steps.items():
        print(f"\n--- 단계: {name} ---")
        config = base_config.copy()
        config.update(flags)

        # 2. Move the model and the data to the selected device.
        model = HymbaModel(config).to(device)

        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"모델 파라미터 수: {num_params:,}")

        dummy_input = torch.randint(0, config['vocab_size'], (2, 512)).to(device)

        try:
            with torch.no_grad():
                output_logits = model(dummy_input, use_cache=False)

            # Derived from the input so the check follows the batch shape
            # instead of repeating literals.
            expected_batch, expected_seq_len = dummy_input.shape
            print(f"입력 shape: {dummy_input.shape}")
            print(f"출력 shape: {output_logits.shape}")
            assert output_logits.shape == (expected_batch, expected_seq_len, config['vocab_size'])
            print("테스트 성공!")
        except RuntimeError as e:
            print(f"!!! 테스트 실패: {e} !!!")
            print("GPU 환경에서 실행 중인지 다시 확인해주세요.")
            break


[2025-09-03 02:00:26,820] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)


2025-09-03 02:00:28.259547: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


사용 디바이스: NVIDIA A100 80GB PCIe
Mamba의 CUDA 커널을 사용하여 최적화된 성능으로 실행됩니다.

--- Hymba Ablation Study 단계별 모델 구현 테스트 ---

--- 단계: 1_Transformer_Baseline ---
모델 파라미터 수: 20,187,392
입력 shape: torch.Size([2, 512])
출력 shape: torch.Size([2, 512, 32000])
테스트 성공!

--- 단계: 2_Plus_Meta_Tokens ---
모델 파라미터 수: 20,188,416
입력 shape: torch.Size([2, 512])
출력 shape: torch.Size([2, 512, 32000])
테스트 성공!

--- 단계: 3_Plus_Shared_KV ---
모델 파라미터 수: 20,188,416
입력 shape: torch.Size([2, 512])
출력 shape: torch.Size([2, 512, 32000])
테스트 성공!

--- 단계: 4_Plus_SWA ---
모델 파라미터 수: 20,188,416
입력 shape: torch.Size([2, 512])
출력 shape: torch.Size([2, 512, 32000])
테스트 성공!

--- 단계: 5_Hymba_Final ---
모델 파라미터 수: 22,463,744
입력 shape: torch.Size([2, 512])
출력 shape: torch.Size([2, 512, 32000])
테스트 성공!


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
import time

# --- 중요 ---
# Hymba의 Mamba 구성 요소는 성능 최적화를 위해 CUDA 커널을 사용합니다.
# 이 코드는 반드시 GPU가 활성화된 환경(예: Google Colab의 GPU 런타임)에서 실행해야 합니다.
# !pip install mamba-ssm causal-conv1d einops
from mamba_ssm import Mamba

# --- 1단계: Llama 기반 Transformer 구성 요소 구현 ---

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable scale.

    Args:
        d_model: Length of the learnable ``weight`` vector (initialised to ones).
        eps: Constant added to the mean square before taking the reciprocal
            square root.
    """

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x):
        """Return ``x`` divided by the RMS of its last dimension, scaled by ``weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

class RotaryEmbedding(nn.Module):
    """Precomputed cosine/sine tables for rotary position embeddings.

    The tables are built once in ``__init__`` for positions
    ``0 .. max_seq_len - 1`` and stored as buffers of shape
    ``(1, 1, max_seq_len, dim)``.

    Args:
        dim: Size of the rotated (per-head) dimension; must be even because
            the table is built from two concatenated halves of ``dim // 2``
            frequencies.
        max_seq_len: Number of positions to precompute.
        base: Base of the geometric frequency progression.

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim: int, max_seq_len: int, base: int = 10000):
        super().__init__()
        if dim % 2 != 0:
            # An odd dim would produce a (dim + 1)-wide table that cannot be
            # multiplied with dim-wide queries/keys.
            raise ValueError(f"RotaryEmbedding requires an even dim, got {dim}")

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        t = torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        # Outer product: one row of angles per position.
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)

        # Duplicate the angles so both halves of the feature dimension share them.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :])
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :])

    def forward(self, x: torch.Tensor, seq_len: int):
        """Return the ``(cos, sin)`` tables for the first ``seq_len`` positions.

        Args:
            x: Tensor whose dtype the returned tables are cast to; its values
                are not read.
            seq_len: Number of leading positions to return.

        Returns:
            Two tensors of shape ``(1, 1, seq_len, dim)``.

        Raises:
            ValueError: If ``seq_len`` is negative or exceeds the number of
                precomputed positions. Previously such a request silently
                returned a differently sized table.
        """
        table_len = self.cos_cached.size(2)
        if not 0 <= seq_len <= table_len:
            raise ValueError(
                f"seq_len must be in [0, {table_len}] (the precomputed table length), got {seq_len}"
            )
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys with the given cosine/sine tables.

    Each input ``t`` is mapped to ``t * cos + rotate_half(t) * sin``; ``cos``
    and ``sin`` must be broadcastable against both ``xq`` and ``xk``.

    Returns:
        The rotated ``(query, key)`` pair.
    """
    q_rot, k_rot = (t * cos + rotate_half(t) * sin for t in (xq, xk))
    return q_rot, k_rot

def rotate_half(x: torch.Tensor):
    """Split the last dimension at its midpoint and return ``cat(-back, front)``."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((back.neg(), front), dim=-1)

class GroupedQueryAttention(nn.Module):
    """Attention with ``n_heads`` query heads sharing ``n_kv_heads`` key/value heads.

    Queries and keys are rotated with rotary position embeddings before the
    attention product. Key/value heads are repeated so every query head has a
    matching key/value head.

    Args:
        d_model: Input and output feature size.
        n_heads: Number of query heads; must divide ``d_model``.
        n_kv_heads: Number of key/value heads; must divide ``n_heads``.
        max_seq_len: Number of positions precomputed by the rotary embedding.
        window_size: Stored on the module but not read by ``forward``; any
            windowing has to be supplied through ``attn_mask``.

    Raises:
        ValueError: If a head count is not positive, ``d_model`` is not
            divisible by ``n_heads``, or ``n_heads`` is not divisible by
            ``n_kv_heads``.
    """

    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int, max_seq_len: int, window_size: int = -1):
        super().__init__()
        # Fail early with a clear message instead of a reshape/attention
        # shape error deep inside forward().
        if n_heads <= 0 or n_kv_heads <= 0:
            raise ValueError(f"n_heads and n_kv_heads must be positive, got {n_heads} and {n_kv_heads}")
        if d_model % n_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        if n_heads % n_kv_heads != 0:
            raise ValueError(f"n_heads ({n_heads}) must be divisible by n_kv_heads ({n_kv_heads})")

        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = d_model // n_heads
        self.window_size = window_size

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

        self.rotary_emb = RotaryEmbedding(self.head_dim, max_seq_len)

    def forward(self, x: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_mask: Optional[torch.Tensor] = None, use_cache: bool = False):
        """Run attention over ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, d_model)``.
            past_kv: Optional ``(key, value)`` pair that is concatenated in
                front of the new keys/values along the sequence axis. Only
                used when ``use_cache`` is true.
            attn_mask: Optional mask forwarded to
                ``F.scaled_dot_product_attention``. When ``None`` the call
                uses ``is_causal=True`` instead.
            use_cache: Whether ``past_kv`` is consumed.

        Returns:
            ``(output, present_kv)`` where ``output`` has the shape of ``x``
            and ``present_kv`` is the ``(key, value)`` pair before head
            repetition. It is returned regardless of ``use_cache``.
        """
        batch_size, seq_len, _ = x.shape

        # Project, then move the head axis in front of the sequence axis.
        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # NOTE(review): rotary positions always start at 0 for the current
        # chunk, even when past_kv is supplied, so new tokens are not offset
        # by the cached length. Confirm before relying on incremental decoding.
        cos, sin = self.rotary_emb(q, seq_len)
        q, k = apply_rotary_emb(q, k, cos, sin)

        if use_cache and past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        present_kv = (k, v)

        if self.n_kv_heads != self.n_heads:
            # Expand key/value heads so each query head has a partner.
            n_repeats = self.n_heads // self.n_kv_heads
            k = k.repeat_interleave(n_repeats, dim=1)
            v = v.repeat_interleave(n_repeats, dim=1)

        # NOTE(review): the PyTorch docs describe is_causal=True as an
        # upper-left aligned causal mask; once cached keys make the key length
        # exceed the query length that alignment likely hides the newest keys.
        # Verify against the caller's cache usage.
        attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=(attn_mask is None))

        output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.o_proj(output), present_kv

class FeedForward(nn.Module):
    """Gated feed-forward network: ``w2(silu(w1(x)) * w3(x))``.

    Args:
        d_model: Input and output feature size.
        hidden_dim: Width of the two parallel inner projections.
    """

    def __init__(self, d_model: int, hidden_dim: int):
        super().__init__()
        # Creation order w1, w2, w3 is kept so parameter initialisation and
        # state-dict ordering stay the same.
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, d_model, bias=False)
        self.w3 = nn.Linear(d_model, hidden_dim, bias=False)

    def forward(self, x):
        """Apply the SiLU-gated projection and map back to ``d_model``."""
        activated = F.silu(self.w1(x))
        linear_branch = self.w3(x)
        return self.w2(activated * linear_branch)

class TransformerBlock(nn.Module):
    """Pre-norm block: attention, then feed-forward, each added back as a residual.

    Args:
        d_model: Feature size of the block input and output.
        n_heads: Number of query heads for the attention sub-layer.
        n_kv_heads: Number of key/value heads for the attention sub-layer.
        max_seq_len: Passed through to the attention sub-layer.
        ffn_hidden_dim: Hidden width of the feed-forward sub-layer.
        **kwargs: Only ``window_size`` is forwarded to the attention
            sub-layer; any other extra keyword is ignored.
    """

    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int, max_seq_len: int, ffn_hidden_dim: int, **kwargs):
        super().__init__()
        optional_args = {key: kwargs[key] for key in ('window_size',) if key in kwargs}

        self.attention = GroupedQueryAttention(
            d_model=d_model,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            max_seq_len=max_seq_len,
            **optional_args,
        )
        self.feed_forward = FeedForward(d_model, ffn_hidden_dim)
        self.attention_norm = RMSNorm(d_model)
        self.ffn_norm = RMSNorm(d_model)

    def forward(self, x: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_mask: Optional[torch.Tensor] = None, use_cache: bool = False):
        """Return ``(block_output, present_kv)`` for input ``x``.

        ``past_kv``, ``attn_mask`` and ``use_cache`` are handed to the
        attention sub-layer unchanged; ``present_kv`` is whatever it returns.
        """
        attn_out, present_kv = self.attention(self.attention_norm(x), past_kv, attn_mask, use_cache=use_cache)
        hidden = x + attn_out
        ffn_out = self.feed_forward(self.ffn_norm(hidden))
        return hidden + ffn_out, present_kv

class HymbaLayer(nn.Module):
    """Layer that runs a Mamba block and an attention block side by side.

    Both branches read the same normalised input. Their outputs are
    concatenated on the feature axis and mixed back to ``d_model`` by the
    bias-free linear layer ``gate``, followed by a residual feed-forward step.

    Args:
        d_model: Feature size of the layer input and output.
        mamba_params: Extra keyword arguments for ``Mamba``.
        attn_params: Keyword arguments for ``GroupedQueryAttention``.
        ffn_hidden_dim: Hidden width of the feed-forward sub-layer.
    """

    def __init__(self, d_model: int, mamba_params: dict, attn_params: dict, ffn_hidden_dim: int):
        super().__init__()
        # Sub-module creation order is kept so initialisation and state-dict
        # ordering match the previous implementation.
        self.norm = RMSNorm(d_model)
        self.mamba_block = Mamba(d_model=d_model, **mamba_params)
        self.attn_block = GroupedQueryAttention(**attn_params)
        self.feed_forward = FeedForward(d_model, ffn_hidden_dim)
        self.ffn_norm = RMSNorm(d_model)
        self.gate = nn.Linear(d_model * 2, d_model, bias=False)

    def forward(self, x: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_mask: Optional[torch.Tensor] = None, use_cache: bool = False):
        """Return ``(layer_output, present_kv)`` for input ``x``.

        ``past_kv``, ``attn_mask`` and ``use_cache`` only affect the attention
        branch; the Mamba branch receives just the normalised input.
        """
        normed = self.norm(x)

        ssm_branch = self.mamba_block(normed)
        attn_branch, present_kv = self.attn_block(normed, past_kv, attn_mask, use_cache=use_cache)

        # Fuse both branches back to d_model and add the residual.
        fused = self.gate(torch.cat([ssm_branch, attn_branch], dim=-1))
        hidden = x + fused

        return hidden + self.feed_forward(self.ffn_norm(hidden)), present_kv

class HymbaModel(nn.Module):
    """Token-level language model whose layer type and extras are chosen by ``config``.

    ``config`` is a dict. Keys read by this class: ``vocab_size``,
    ``d_model``, ``n_layers``, ``n_heads``, ``n_kv_heads``, ``max_seq_len``,
    ``ffn_hidden_dim``, ``window_size``, ``n_meta_tokens``, ``mamba_params``
    and the boolean switches ``use_meta_tokens``, ``use_shared_kv_cache``,
    ``use_swa`` and ``use_ssm_head``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tok_embeddings = nn.Embedding(config['vocab_size'], config['d_model'])

        if config['use_meta_tokens']:
            # Learnable prefix that forward() prepends to every sequence.
            self.meta_tokens = nn.Parameter(torch.randn(1, config['n_meta_tokens'], config['d_model']))

        layers = []
        for _ in range(config['n_layers']):
            attn_params = {
                'd_model': config['d_model'],
                'n_heads': config['n_heads'],
                'n_kv_heads': config['n_kv_heads'],
                'max_seq_len': config['max_seq_len'],
                # -1 marks "no window" for the attention module.
                'window_size': config['window_size'] if config['use_swa'] else -1,
            }
            if config['use_ssm_head']:
                layers.append(HymbaLayer(
                    d_model=config['d_model'],
                    mamba_params=config['mamba_params'],
                    attn_params=attn_params,
                    ffn_hidden_dim=config['ffn_hidden_dim']
                ))
            else:
                layers.append(TransformerBlock(
                    d_model=config['d_model'],
                    n_heads=config['n_heads'],
                    n_kv_heads=config['n_kv_heads'],
                    max_seq_len=config['max_seq_len'],
                    ffn_hidden_dim=config['ffn_hidden_dim'],
                    window_size=attn_params['window_size']
                ))
        self.layers = nn.ModuleList(layers)
        self.norm = RMSNorm(config['d_model'])
        self.output = nn.Linear(config['d_model'], config['vocab_size'], bias=False)

    def _create_sliding_window_mask(self, seq_len: int, device: torch.device) -> Optional[torch.Tensor]:
        """Build the additive causal sliding-window mask.

        Args:
            seq_len: Number of positions (queries and keys).
            device: Device on which to allocate the mask.

        Returns:
            ``None`` when sliding-window attention is disabled or the window
            is not positive. Otherwise a float tensor of shape
            ``(1, 1, seq_len, seq_len)`` holding ``0`` where query ``i`` may
            attend key ``j`` (``0 <= i - j < window_size``) and ``-inf``
            elsewhere.
        """
        if not self.config['use_swa'] or self.config['window_size'] <= 0:
            return None

        positions = torch.arange(seq_len, device=device)
        # distance[i, j] = i - j: negative for future keys, >= window for
        # keys that have dropped out of the window.
        distance = positions[:, None] - positions[None, :]
        # The previous version applied the window to the upper triangle, which
        # is already masked as "future", so the window never took effect.
        blocked = (distance < 0) | (distance >= self.config['window_size'])

        mask = torch.zeros((1, 1, seq_len, seq_len), device=device)
        # NOTE(review): prepended meta tokens are masked like ordinary
        # positions once they leave the window — confirm whether they should
        # stay visible to every query.
        return mask.masked_fill(blocked, float('-inf'))

    def forward(self, tokens: torch.Tensor, use_cache: bool = False, return_kv_cache: bool = False):
        """Compute logits for ``tokens``.

        Args:
            tokens: Integer tensor of shape ``(batch, seq_len)``.
            use_cache: Whether the layers use the KV cache (inference) or not
                (training); also enables the ``use_shared_kv_cache`` branch.
            return_kv_cache: Whether to also return the per-layer cache list
                so callers can inspect it.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)`` (meta-token
            positions removed), or ``(logits, kv_cache)`` when
            ``return_kv_cache`` is true. ``kv_cache`` has one entry per layer
            and stays all ``None`` unless ``use_cache`` is true.
        """
        batch_size, seq_len = tokens.shape
        h = self.tok_embeddings(tokens)

        if self.config['use_meta_tokens']:
            meta_h = self.meta_tokens.expand(batch_size, -1, -1)
            h = torch.cat([meta_h, h], dim=1)
            seq_len += self.config['n_meta_tokens']

        attn_mask = self._create_sliding_window_mask(seq_len, tokens.device)

        # Fresh per call: nothing is carried over between forward passes.
        kv_cache = [None] * self.config['n_layers']

        for i, layer in enumerate(self.layers):
            past_kv = kv_cache[i] if use_cache else None

            # The shared KV cache is only meaningful at inference time.
            # NOTE(review): this hands layer i-1's keys/values for the *same*
            # positions to layer i as if they were past tokens; the layer then
            # concatenates them with its own, so each stored entry grows with
            # depth (consistent with the recorded 5.04 MB measured vs 1.01 MB
            # theoretical cache size) and no longer matches an attn_mask built
            # for seq_len. Confirm the intended cross-layer sharing semantics.
            if use_cache and self.config['use_shared_kv_cache'] and i > 0:
                past_kv = kv_cache[i-1]

            h, present_kv = layer(h, past_kv=past_kv, attn_mask=attn_mask, use_cache=use_cache)

            # Store the cache only when use_cache is true.
            if use_cache:
                kv_cache[i] = present_kv

        h = self.norm(h)

        if self.config['use_meta_tokens']:
            # Drop the meta-token positions so the output aligns with `tokens`.
            h = h[:, self.config['n_meta_tokens']:, :]

        logits = self.output(h)

        if return_kv_cache:
            return logits, kv_cache
        return logits

# --- 성능 측정 스크립트 ---
if __name__ == '__main__':
    # Benchmark script: KV-cache size, throughput and latency for the
    # "Transformer + meta tokens + shared KV" configuration.

    # The measurements are only meaningful on a GPU.
    if not torch.cuda.is_available():
        print("!!! 경고: 이 성능 측정 스크립트는 GPU 환경에서 실행해야 의미가 있습니다. !!!")
        print("!!! Google Colab 사용 시, [런타임] > [런타임 유형 변경]에서 GPU를 선택해주세요. !!!")
        # SystemExit is raised directly because the exit() helper is meant for
        # interactive shells; the non-zero status signals nothing was measured.
        raise SystemExit(1)

    device = torch.device("cuda")
    print(f"사용 디바이스: {torch.cuda.get_device_name(0)}")

    # Baseline Llama-style settings.
    base_config = {
        'vocab_size': 32000, 'd_model': 256, 'n_layers': 4,
        'n_heads': 8, 'n_kv_heads': 2, 'max_seq_len': 1024,
        'ffn_hidden_dim': 256 * 4, 'window_size': 256, 'n_meta_tokens': 4,
        'mamba_params': {'d_state': 16, 'd_conv': 4, 'expand': 2},
        'use_meta_tokens': False, 'use_shared_kv_cache': False,
        'use_swa': False, 'use_ssm_head': False,
    }

    # Stage 3 settings: Transformer + meta tokens + shared KV.
    config = base_config.copy()
    config.update({
        'use_meta_tokens': True,
        'use_shared_kv_cache': True,
    })

    model = HymbaModel(config).to(device)
    model.eval()

    batch_size = 2
    seq_len = 512
    dummy_input = torch.randint(0, config['vocab_size'], (batch_size, seq_len)).to(device)

    print("\n--- 3단계: + Shared KV Cache 모델 성능 측정 ---")

    # --- KV cache size ---
    print("\n--- KV Cache 크기 측정 ---")
    with torch.no_grad():
        # Enable both flags so the cache is filled and handed back.
        _, kv_cache = model(dummy_input, use_cache=True, return_kv_cache=True)

    # Count each tensor object once, even if several layers reference it.
    unique_tensors = {
        id(tensor): tensor
        for cache_tuple in kv_cache
        if cache_tuple is not None
        for tensor in cache_tuple
    }

    total_cache_bytes = sum(t.numel() * t.element_size() for t in unique_tensors.values())
    total_cache_mb = total_cache_bytes / (1024 * 1024)

    # Theoretical size for comparison: with sharing, every pair of layers is
    # assumed to hold one cache. Element size comes from the model parameters
    # instead of a hard-coded 4 bytes.
    # NOTE(review): the measured value can exceed this estimate if the model
    # stores concatenated per-layer tensors rather than truly shared ones.
    num_unique_layers = config['n_layers'] if not config['use_shared_kv_cache'] else (config['n_layers'] // 2 + config['n_layers'] % 2)
    bytes_per_element = next(model.parameters()).element_size()
    head_dim = config['d_model'] // config['n_heads']
    # The factor 2 accounts for keys and values.
    expected_bytes = num_unique_layers * batch_size * config['n_kv_heads'] * (seq_len + config['n_meta_tokens']) * head_dim * 2 * bytes_per_element
    expected_mb = expected_bytes / (1024 * 1024)

    print(f"입력 shape: ({batch_size}, {seq_len})")
    print(f"실제 측정된 총 KV 캐시 크기: {total_cache_mb:.2f} MB")
    print(f"이론적인 총 KV 캐시 크기: {expected_mb:.2f} MB (공유 O)")

    # --- Throughput ---
    print("\n--- 처리량(Throughput) 측정 ---")
    warmup_iterations = 5
    measurement_iterations = 20

    # All timed passes run under no_grad so autograd bookkeeping does not
    # distort the inference timings.
    with torch.no_grad():
        print(f"워밍업 실행 ({warmup_iterations}회)...")
        for _ in range(warmup_iterations):
            # use_cache=True approximates inference (the cache gets filled).
            _ = model(dummy_input, use_cache=True)
        torch.cuda.synchronize()

        print(f"성능 측정 실행 ({measurement_iterations}회)...")
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        for _ in range(measurement_iterations):
            _ = model(dummy_input, use_cache=True)
        # Wait for queued kernels before stopping the clock.
        torch.cuda.synchronize()
        end_time = time.perf_counter()

    total_time = end_time - start_time
    total_tokens = batch_size * seq_len * measurement_iterations
    tokens_per_second = total_tokens / total_time

    print(f"총 처리 시간: {total_time:.3f} 초")
    print(f"총 처리 토큰: {total_tokens:,} 개")
    print(f"처리량: {tokens_per_second:,.2f} tokens/sec")

    # --- Latency ---
    print("\n--- 지연 시간(Latency) 측정 ---")
    latency_iterations = 50
    latencies = []
    with torch.no_grad():
        for _ in range(latency_iterations):
            start = time.perf_counter()
            _ = model(dummy_input, use_cache=True)
            torch.cuda.synchronize()
            end = time.perf_counter()
            latencies.append(end - start)
    avg_latency = sum(latencies) / latency_iterations
    print(f"평균 지연 시간: {avg_latency * 1000:.2f} ms")


사용 디바이스: NVIDIA A100 80GB PCIe

--- 3단계: + Shared KV Cache 모델 성능 측정 ---

--- KV Cache 크기 측정 ---
입력 shape: (2, 512)
실제 측정된 총 KV 캐시 크기: 5.04 MB
이론적인 총 KV 캐시 크기: 1.01 MB (공유 O)

--- 처리량(Throughput) 측정 ---
워밍업 실행 (5회)...
성능 측정 실행 (20회)...
총 처리 시간: 0.070 초
총 처리 토큰: 20,480 개
처리량: 292,229.31 tokens/sec

--- 지연 시간(Latency) 측정 ---
평균 지연 시간: 4.13 ms
