In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import math
from typing import Optional, Tuple

# mamba-ssm and causal-conv1d are required for their low-level CUDA kernels.
# pip install mamba-ssm causal-conv1d
try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
    from causal_conv1d import causal_conv1d_fn
except ImportError:
    # Keep the module importable without the kernels: both names fall back to
    # None, and MambaBranch.__init__ raises ImportError when it sees None.
    print("Warning: mamba-ssm or causal-conv1d not found. MambaBranch will not work.")
    selective_scan_fn = None
    causal_conv1d_fn = None

# -----------------------------------------------------------------------------
# 1. 아키텍처의 기본 구성 요소 (Llama 및 Mamba)
# -----------------------------------------------------------------------------

class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned gain.

    The statistics are computed in float32; the normalized tensor is cast back
    to the input dtype before being multiplied by ``weight``.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        # Per-channel gain, initialized to the identity scaling.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``weight * x / sqrt(mean(x**2) + eps)`` along the last axis."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

class RotaryEmbedding(nn.Module):
    """Cache of rotary position-embedding cos/sin tables.

    ``inv_freq`` is a persistent buffer; ``cos_cached`` / ``sin_cached`` are
    non-persistent buffers of shape ``(cached_len, dim)`` that are rebuilt when
    a longer sequence is requested.
    """

    def __init__(self, dim: int, max_seq_len: int, base: int = 10000, device: Optional[str] = None):
        super().__init__()
        self.dim = dim
        # One inverse frequency per pair of channels: base ** (-2i / dim).
        exponents = torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents))
        self._set_cos_sin_cache(seq_len=max_seq_len, device=device)

    def _set_cos_sin_cache(self, seq_len: int, device: Optional[str], dtype: torch.dtype = torch.float32):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)
        # The angle table is duplicated so it spans all ``dim`` channels,
        # matching the half-split layout used by rotate_half.
        table = torch.cat((angles, angles), dim=-1)
        for buffer_name, values in (("cos_cached", table.cos()), ("sin_cached", table.sin())):
            self.register_buffer(buffer_name, values.to(dtype), persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)`` for the first ``seq_len`` positions in ``x``'s dtype.

        ``x`` is only consulted for its device and dtype.
        """
        needs_growth = seq_len > self.max_seq_len_cached
        if needs_growth:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        cos = self.cos_cached[:seq_len, ...].to(dtype=x.dtype)
        sin = self.sin_cached[:seq_len, ...].to(dtype=x.dtype)
        return cos, sin

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dimension at its midpoint and return ``cat(-back, front)``."""
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)

def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to the query and key tensors.

    ``cos`` / ``sin`` get a new leading axis and a new axis at position 2 so
    they broadcast against ``q`` and ``k``.
    """
    cos_b = cos[None, :, None]
    sin_b = sin[None, :, None]

    def _rotate(states: torch.Tensor) -> torch.Tensor:
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every head ``n_rep`` times along the head axis (axis 1).

    Takes a ``(batch, kv_heads, seq, head_dim)`` tensor and returns
    ``(batch, kv_heads * n_rep, seq, head_dim)``, with the copies of each head
    adjacent to one another. When ``n_rep == 1`` the input is returned as-is.
    """
    bsz, kv_heads, seq, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it
    # back into the head axis.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq, dim)

class AttentionBranch(nn.Module):
    """Grouped-query attention over pre-projected q/k/v with rotary embeddings.

    This module holds no projection weights: callers pass already-projected
    ``q`` (``n_heads * head_dim`` wide) and ``k`` / ``v``
    (``n_kv_heads * head_dim`` wide). ``window_size`` is accepted by the
    constructor but is not used inside this class; any windowing has to come
    in through ``attn_mask``.
    """

    def __init__(self, d_inner: int, n_heads: int, n_kv_heads: int, max_seq_len: int, window_size: int = -1):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        # Number of query heads that share each key/value head.
        self.n_key_value_groups = n_heads // n_kv_heads
        self.head_dim = d_inner // n_heads
        self.rotary_emb = RotaryEmbedding(self.head_dim, max_seq_len)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_mask: Optional[torch.Tensor] = None, use_cache: bool = False):
        """Run attention and return ``(output, present_kv)``.

        ``output`` has shape ``(batch, q_len, n_heads * head_dim)``.
        ``present_kv`` is the ``(k, v)`` pair laid out as
        ``(batch, kv_len, n_kv_heads, head_dim)`` (rotary already applied to
        ``k``) when ``use_cache`` is true, otherwise ``None``.
        """
        batch_size, q_len, _ = q.shape

        # Split the channel axis into heads: (batch, seq, heads, head_dim).
        q = q.view(batch_size, q_len, self.n_heads, self.head_dim)
        k = k.view(batch_size, q_len, self.n_kv_heads, self.head_dim)
        v = v.view(batch_size, q_len, self.n_kv_heads, self.head_dim)

        # Cached entries occupy the first positions; new tokens come after.
        cached_len = 0 if past_kv is None else past_kv[0].shape[1]
        cos, sin = self.rotary_emb(v, seq_len=cached_len + q_len)
        cos, sin = cos[cached_len:], sin[cached_len:]
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        if past_kv is not None:
            cached_k, cached_v = past_kv
            k = torch.cat([cached_k, k], dim=1)
            v = torch.cat([cached_v, v], dim=1)

        present_kv = (k, v) if use_cache else None

        # Attention expects (batch, heads, seq, head_dim).
        q, k, v = (states.transpose(1, 2) for states in (q, k, v))
        k = repeat_kv(k, self.n_key_value_groups)
        v = repeat_kv(v, self.n_key_value_groups)

        attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        merged = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, -1)
        return merged, present_kv

class MambaBranch(nn.Module):
    """Selective state-space (Mamba) branch built on the fused CUDA kernels.

    Pipeline: depthwise causal conv + SiLU -> projection to (dt, B, C) ->
    selective scan gated by ``z``.

    Args:
        d_inner: Channel width of the branch input and output.
        d_state: State size per channel (width of B and C).
        d_conv: Kernel size of the depthwise causal convolution.
        dt_rank: Width of the low-rank dt projection.

    Raises:
        ImportError: If ``causal_conv1d_fn`` or ``selective_scan_fn`` could
            not be imported (they are ``None`` in that case).
    """

    def __init__(self, d_inner, d_state, d_conv, dt_rank):
        super().__init__()
        if causal_conv1d_fn is None or selective_scan_fn is None:
            raise ImportError("Mamba packages not found. Please install them.")

        self.d_inner, self.d_state, self.d_conv, self.dt_rank = d_inner, d_state, d_conv, dt_rank

        # Only the weight and bias of this depthwise conv are used; the actual
        # convolution is performed by causal_conv1d_fn in forward().
        self.conv1d = nn.Conv1d(in_channels=self.d_inner, out_channels=self.d_inner, kernel_size=d_conv, bias=True, groups=self.d_inner, padding=d_conv - 1)
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + 2 * self.d_state, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # A is stored as log(1..d_state) per channel; forward() uses -exp(A_log).
        A = torch.arange(1, self.d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))

    def forward(self, x, z):
        """Run the branch.

        Args:
            x: Input of shape ``(batch, seq, d_inner)``.
            z: Gate passed to the scan kernel, transposed the same way as
                ``x`` (so presumably the same shape).

        Returns:
            Tensor of shape ``(batch, seq, d_inner)``.
        """
        # The kernels work channels-first: (batch, d_inner, seq).
        x_transposed = x.transpose(1, 2).contiguous()
        x_conv = causal_conv1d_fn(x_transposed, self.conv1d.weight.squeeze(1), self.conv1d.bias, activation="silu")

        x_dbl = self.x_proj(x_conv.transpose(1, 2))
        dt_pre, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)

        # Project with the dt_proj weight only. The bias is handed to the scan
        # kernel as delta_bias below; calling self.dt_proj(dt_pre) here would
        # add that bias a second time.
        dt = F.linear(dt_pre, self.dt_proj.weight).transpose(1, 2)
        A = -torch.exp(self.A_log.float())

        y = selective_scan_fn(
            x_conv, dt, A, B.transpose(1, 2), C.transpose(1, 2), self.D.float(),
            z=z.transpose(1, 2), delta_bias=self.dt_proj.bias.float(), delta_softplus=True
        )
        return y.transpose(1, 2)
        
# -----------------------------------------------------------------------------
# 2. Ablation을 위한 각 단계별 블록 정의
# -----------------------------------------------------------------------------
class FeedForward(nn.Module):
    """Gated feed-forward network: ``w2(act(w1(x)) * w3(x))`` with SiLU, no biases."""

    def __init__(self, d_model: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)  # gate projection
        self.w3 = nn.Linear(d_model, hidden_dim, bias=False)  # value projection
        self.w2 = nn.Linear(hidden_dim, d_model, bias=False)  # output projection
        self.act_fn = F.silu

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(..., d_model)`` to ``(..., d_model)`` through the gated hidden layer."""
        gate = self.act_fn(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)

class MambaOnlyBlock(nn.Module):
    """Stage 1: block used for the Mamba-only baseline.

    Pre-norm Mamba mixer followed by a pre-norm feed-forward network. The
    tensor returned by ``forward`` already includes both residual connections.
    """

    def __init__(self, d_model: int, ffn_hidden_dim: int, mamba_params: dict):
        super().__init__()
        self.d_inner = mamba_params['expand'] * d_model

        self.norm = RMSNorm(d_model)
        # One projection produces both the Mamba input and its gate.
        self.in_proj = nn.Linear(d_model, 2 * self.d_inner, bias=False)
        # 'expand' is consumed above; every other entry is forwarded to MambaBranch.
        branch_kwargs = {key: value for key, value in mamba_params.items() if key != 'expand'}
        self.mamba_branch = MambaBranch(d_inner=self.d_inner, **branch_kwargs)
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        self.ffn = FeedForward(d_model, ffn_hidden_dim)
        self.ffn_norm = RMSNorm(d_model)

    def forward(self, x: torch.Tensor, **kwargs):
        """Return ``(output, None)``; extra keyword arguments are ignored.

        The second element is always ``None`` because this block has no
        key/value cache.
        """
        x_mamba, z_mamba = self.in_proj(self.norm(x)).chunk(2, dim=-1)
        mixed = self.out_proj(self.mamba_branch(x_mamba, z_mamba))

        # First residual around the mixer, second around the feed-forward.
        h = x + mixed
        return h + self.ffn(self.ffn_norm(h)), None

class HymbaBlock(nn.Module):
    """Stages 2-5: hybrid block with an attention branch and a Mamba branch in parallel.

    The original author's note says this follows the logic of the official
    code. ``forward`` returns the projected average of the two normalized
    branch outputs; it does not add the input back.
    """

    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int, max_seq_len: int, window_size: int, mamba_params: dict):
        super().__init__()
        self.d_inner = mamba_params['expand'] * d_model
        kv_width = (self.d_inner // n_heads) * n_kv_heads

        # Fused input projection, laid out as [q | k | v | mamba_x | gate].
        fused_width = self.d_inner + 2 * kv_width + self.d_inner + self.d_inner
        self.in_proj = nn.Linear(d_model, fused_width, bias=True)
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=True)

        self.attn_branch = AttentionBranch(d_inner=self.d_inner, n_heads=n_heads, n_kv_heads=n_kv_heads, max_seq_len=max_seq_len, window_size=window_size)
        self.mamba_branch = MambaBranch(d_inner=self.d_inner, d_state=mamba_params['d_state'], d_conv=mamba_params['d_conv'], dt_rank=mamba_params['dt_rank'])

        self.norm = RMSNorm(d_model)       # pre-norm on the block input
        self.norm1 = RMSNorm(self.d_inner)  # attention branch output
        self.norm2 = RMSNorm(self.d_inner)  # Mamba branch output

    def forward(self, x: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_mask: Optional[torch.Tensor] = None, use_cache: bool = False):
        """Return ``(update, present_kv)`` for input ``x`` of width ``d_model``.

        ``present_kv`` is whatever the attention branch returns for its cache.
        """
        fused = self.in_proj(self.norm(x))

        kv_width = self.attn_branch.head_dim * self.attn_branch.n_kv_heads
        section_widths = [self.d_inner, kv_width, kv_width, self.d_inner, self.d_inner]
        q, k, v, mamba_x, gate = torch.split(fused, section_widths, dim=-1)

        attn_out, present_kv = self.attn_branch(q, k, v, past_kv, attn_mask, use_cache)
        mamba_out = self.mamba_branch(mamba_x, gate)

        # Each branch is normalized separately, then the two are averaged.
        fused_branches = (self.norm1(attn_out) + self.norm2(mamba_out)) / 2
        return self.out_proj(fused_branches), present_kv

# -----------------------------------------------------------------------------
# 3. 전체 모델 및 마스크 생성
# -----------------------------------------------------------------------------

class HymbaModel(nn.Module):
    """Token-level language model built from MambaOnlyBlock or HymbaBlock layers.

    ``config`` is a plain dict. Keys read here: ``vocab_size``, ``d_model``,
    ``n_layers``, ``use_mamba_only``, ``use_meta_tokens``, ``n_meta_tokens``,
    ``use_swa``, ``window_size``, ``use_shared_kv_cache``, ``mamba_params``,
    plus ``ffn_hidden_dim`` (Mamba-only) or ``n_heads`` / ``n_kv_heads`` /
    ``max_seq_len`` (hybrid).
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config['vocab_size'], config['d_model'])
        if config['use_meta_tokens']:
            # Learned embeddings prepended to every non-decoding sequence.
            self.meta_tokens = nn.Parameter(torch.randn(1, config['n_meta_tokens'], config['d_model']))

        layers = []
        for _ in range(config['n_layers']):
            if config['use_mamba_only']:
                layers.append(MambaOnlyBlock(
                    d_model=config['d_model'],
                    ffn_hidden_dim=config['ffn_hidden_dim'],
                    mamba_params=config['mamba_params']
                ))
            else:  # Hymba block for stages 2-5
                layers.append(HymbaBlock(
                    d_model=config['d_model'], n_heads=config['n_heads'],
                    n_kv_heads=config['n_kv_heads'], max_seq_len=config['max_seq_len'],
                    window_size=config['window_size'] if config['use_swa'] else -1,
                    mamba_params=config['mamba_params']
                ))
        self.layers = nn.ModuleList(layers)
        self.norm = RMSNorm(config['d_model'])
        self.lm_head = nn.Linear(config['d_model'], config['vocab_size'], bias=False)

    def _create_attention_mask(self, q_len: int, kv_len: int, window_size: int, n_meta_tokens: int, device: str) -> Optional[torch.Tensor]:
        """Build an additive float32 attention mask of shape ``(1, 1, q_len, kv_len)``.

        Entries are ``0`` where attention is allowed and ``-inf`` where it is
        blocked. Query ``i`` may attend to key ``j`` when ``j <= i`` and, if
        ``window_size > 0``, ``j >= i - window_size``. When
        ``n_meta_tokens > 0`` every query at or after index ``n_meta_tokens``
        may also attend to the first ``n_meta_tokens`` keys.

        Returns ``None`` when ``q_len <= 1``.
        """
        if q_len <= 1:
            return None

        rows = torch.arange(q_len, device=device).unsqueeze(-1)
        cols = torch.arange(kv_len, device=device).unsqueeze(0)

        blocked = cols > rows  # causal: no attention to future positions
        if window_size > 0:
            # Only keys that fall out of the window are blocked. The diagonal
            # stays open so every token can attend to itself; otherwise row 0
            # would be entirely -inf and softmax would produce NaN.
            blocked = blocked | (cols < rows - window_size)

        mask = torch.zeros((1, 1, q_len, kv_len), dtype=torch.float32, device=device)
        mask.masked_fill_(blocked, float("-inf"))
        if n_meta_tokens > 0:
            mask[..., n_meta_tokens:, :n_meta_tokens] = 0
        return mask

    def forward(self, tokens: torch.Tensor, use_cache: bool = False, return_kv_cache: bool = False):
        """Compute logits for ``tokens`` of shape ``(batch, seq_len)``.

        Args:
            tokens: Integer token ids.
            use_cache: Ask each layer to return its key/value tensors. A call
                with ``use_cache=True`` and ``seq_len == 1`` is treated as a
                decoding step: no meta tokens are prepended and no mask is
                built.
            return_kv_cache: Also return the per-layer cache list.

        Returns:
            ``logits`` of shape ``(batch, seq_len, vocab_size)``, or
            ``(logits, kv_cache_list)`` when ``return_kv_cache`` is true.

        NOTE(review): the cache list is created fresh on every call and there
        is no parameter for passing in a cache from a previous call, so a
        decoding step only sees entries produced earlier in the same call.
        """
        batch_size, seq_len = tokens.shape
        is_decoding = use_cache and seq_len == 1
        mamba_only = self.config['use_mamba_only']
        prepend_meta = self.config['use_meta_tokens'] and not is_decoding

        h = self.embedding(tokens)

        current_seq_len = seq_len
        if prepend_meta:
            meta_embeds = self.meta_tokens.expand(batch_size, -1, -1)
            h = torch.cat([meta_embeds, h], dim=1)
            current_seq_len += self.config['n_meta_tokens']

        kv_cache_list = [None] * self.config['n_layers']

        attn_mask = None
        if not mamba_only and not is_decoding:
            attn_mask = self._create_attention_mask(
                q_len=current_seq_len, kv_len=current_seq_len,
                window_size=self.config['window_size'] if self.config['use_swa'] else -1,
                n_meta_tokens=self.config['n_meta_tokens'] if self.config['use_meta_tokens'] else 0,
                device=h.device
            )

        for i, layer in enumerate(self.layers):
            past_kv = None
            if is_decoding:
                # With a shared cache, every layer after the first reads the
                # previous layer's entry instead of its own.
                share_previous = i > 0 and self.config['use_shared_kv_cache']
                past_kv = kv_cache_list[i - 1] if share_previous else kv_cache_list[i]

            output, present_kv = layer(h, past_kv=past_kv, attn_mask=attn_mask, use_cache=use_cache)

            # MambaOnlyBlock returns its output with the residual already
            # added; HymbaBlock returns only the update. Adding the residual
            # again for the Mamba-only path would count the input twice.
            h = output if mamba_only else h + output

            if use_cache:
                kv_cache_list[i] = present_kv

        h = self.norm(h)

        if prepend_meta:
            # Drop the meta-token positions so logits align with `tokens`.
            h = h[:, self.config['n_meta_tokens']:]

        logits = self.lm_head(h)

        if return_kv_cache:
            return logits, kv_cache_list
        return logits

# -----------------------------------------------------------------------------
# 4. 단계별 모델링 및 성능 측정
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # Ablation benchmark: build each stage's model, report the size of the
    # key/value cache produced by one cached forward pass, then measure
    # forward-pass throughput without caching.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not torch.cuda.is_available() and selective_scan_fn is not None:
        print("경고: 이 코드는 mamba-ssm의 CUDA 커널에 의존하므로 GPU 환경에서 실행해야 합니다.")

    base_config = {
        'vocab_size': 32000, 'd_model': 256, 'n_layers': 4,
        'n_heads': 8, 'n_kv_heads': 2, 'max_seq_len': 2048,
        'ffn_hidden_dim': 256 * 4, 'window_size': 256, 'n_meta_tokens': 4,
        'mamba_params': {'d_state': 16, 'd_conv': 4, 'expand': 2, 'dt_rank': math.ceil((256 * 2) / 16)},
    }

    # Each stage switches on one more feature than the previous one.
    ablation_stages = [
        ("1_Mamba_Only_Baseline",     {'use_mamba_only': True, 'use_ssm_head': False, 'use_meta_tokens': False, 'use_shared_kv_cache': False, 'use_swa': False}),
        ("2_+_Attention (Hymba)",     {'use_mamba_only': False, 'use_ssm_head': True, 'use_meta_tokens': False, 'use_shared_kv_cache': False, 'use_swa': False}),
        ("3_+_Meta_Tokens",           {'use_mamba_only': False, 'use_ssm_head': True, 'use_meta_tokens': True,  'use_shared_kv_cache': False, 'use_swa': False}),
        ("4_+_Shared_KV_Cache",       {'use_mamba_only': False, 'use_ssm_head': True, 'use_meta_tokens': True,  'use_shared_kv_cache': True,  'use_swa': False}),
        ("5_+_SWA",                   {'use_mamba_only': False, 'use_ssm_head': True, 'use_meta_tokens': True,  'use_shared_kv_cache': True,  'use_swa': True}),
    ]

    batch_size = 2
    seq_len = 512
    warmup_iterations = 5
    measurement_iterations = 10

    for name, flags in ablation_stages:
        print(f"\n{'='*20} {name} {'='*20}")
        config = {**base_config, **flags}

        # A failure in one stage is reported and the loop moves on to the next.
        try:
            model = HymbaModel(config).to(device)
            model.eval()

            dummy_input = torch.randint(0, config['vocab_size'], (batch_size, seq_len), device=device)

            with torch.no_grad():
                logits, kv_cache_list = model(dummy_input, use_cache=True, return_kv_cache=True)

            total_cache_bytes = 0
            if not config['use_mamba_only']:
                # With a shared cache only every second layer (1, 3, ...) is
                # counted; otherwise every layer's cache is counted.
                counted_caches = kv_cache_list[1::2] if config['use_shared_kv_cache'] else kv_cache_list
                for cache in counted_caches:
                    if cache is not None:
                        k, v = cache
                        total_cache_bytes += k.numel() * k.element_size() + v.numel() * v.element_size()

            total_cache_mb = total_cache_bytes / (1024 * 1024)
            print(f"--- KV Cache 크기: {total_cache_mb:.2f} MB")

            with torch.no_grad():
                for _ in range(warmup_iterations):
                    _ = model(dummy_input, use_cache=False)
                # Drain queued GPU work so it is not charged to the timed region.
                if torch.cuda.is_available():
                    torch.cuda.synchronize()

                # perf_counter is monotonic and high-resolution; time.time()
                # can jump when the system clock is adjusted.
                start_time = time.perf_counter()
                for _ in range(measurement_iterations):
                    _ = model(dummy_input, use_cache=False)
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                end_time = time.perf_counter()

            total_time = end_time - start_time
            total_tokens = batch_size * seq_len * measurement_iterations
            tokens_per_second = total_tokens / total_time
            print(f"--- 처리량: {tokens_per_second:,.2f} tokens/sec")

        except Exception as e:
            print(f"오류 발생: {e}")



--- KV Cache 크기: 0.00 MB
--- 처리량: 344,487.54 tokens/sec

--- KV Cache 크기: 4.00 MB
--- 처리량: 260,545.80 tokens/sec

--- KV Cache 크기: 4.03 MB
--- 처리량: 247,209.44 tokens/sec

--- KV Cache 크기: 2.02 MB
--- 처리량: 250,966.61 tokens/sec

--- KV Cache 크기: 2.02 MB
--- 처리량: 245,972.06 tokens/sec
