In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class HybridHead(nn.Module):
    def __init__(self, d_model, num_heads, ssm_module, gate_init=(1.0, 1.0)):
        """
        ssm_module: 모듈 인스턴스 (예: Mamba 등)
        gate_init: (alpha_init, beta_init)
        """
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ssm = ssm_module  # 예: Mamba module
        self.gate_attn = nn.Parameter(torch.tensor(gate_init[0]))
        self.gate_ssm = nn.Parameter(torch.tensor(gate_init[1]))
        # 또는 채널별 스케일 벡터를 쓸 수도 있음 (논문에서는 채널 스케일링 언급됨)
        # self.scale_attn = nn.Parameter(torch.ones(d_model))
        # self.scale_ssm = nn.Parameter(torch.ones(d_model))

    def forward(self, x, attn_mask=None, **ssm_kwargs):
        """
        x: (batch, seq_len, d_model)
        attn_mask: attention 마스크 (예: sliding window)
        ssm_kwargs: SSM 모듈이 필요로 하는 추가 인자
        """
        # 1) Attention branch
        attn_out, _ = self.attn(x, x, x, attn_mask=attn_mask)
        # 2) SSM branch
        ssm_out = self.ssm(x, **ssm_kwargs)
        # 3) 스케일 / 게이트 조정 (채널 단위 또는 스칼라)
        out = self.gate_attn * attn_out + self.gate_ssm * ssm_out
        return out


In [2]:
class KVCachePool:
    def __init__(self, num_layers, share_pairs: bool = True):
        self.share_pairs = share_pairs
        self.caches = {}

    def _key(self, layer_idx: int):
        if not self.share_pairs:
            return layer_idx
        return layer_idx // 2  # 예: 레이어 0 & 1 공유, 2 & 3 공유, etc.

    def get(self, layer_idx: int):
        return self.caches.get(self._key(layer_idx), None)

    def set(self, layer_idx: int, kv):
        self.caches[self._key(layer_idx)] = kv


In [3]:
def build_sliding_mask(seq_len: int, window_size: int, device):
    mask = torch.full((seq_len, seq_len), float("-inf"), device=device)
    for i in range(seq_len):
        start = max(0, i - window_size + 1)
        mask[i, start : i+1] = 0
    return mask

# 사용할지/말지 옵션
class AttentionWithOption(nn.Module):
    def __init__(self, d_model, num_heads, use_swa=False, window_size=16):
        super().__init__()
        self.use_swa = use_swa
        self.window_size = window_size
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

    def forward(self, x):
        mask = None
        if self.use_swa:
            mask = build_sliding_mask(x.size(1), self.window_size, x.device)
        out, weights = self.attn(x, x, x, attn_mask=mask, need_weights=True)
        return out, weights


In [4]:
class MetaTokenPrepend(nn.Module):
    def __init__(self, num_meta: int, d_model: int, use_meta: bool = True):
        super().__init__()
        self.use_meta = use_meta
        if use_meta:
            self.meta = nn.Parameter(torch.randn(1, num_meta, d_model))

    def forward(self, x):
        if not self.use_meta:
            return x
        B = x.size(0)
        meta = self.meta.expand(B, -1, -1)
        x2 = torch.cat([meta, x], dim=1)
        return x2


In [5]:
class HymbaBlock(nn.Module):
    def __init__(self, layer_idx, d_model, num_heads, ssm_module,
                 kv_cache_pool: KVCachePool,
                 use_kv_share=False,
                 use_swa=False, window_size=16,
                 use_meta=False, num_meta=1):
        super().__init__()
        self.layer_idx = layer_idx
        self.kv_cache_pool = kv_cache_pool
        self.use_kv_share = use_kv_share

        self.meta_block = MetaTokenPrepend(num_meta, d_model, use_meta)
        self.attn_opt = AttentionWithOption(d_model, num_heads, use_swa, window_size)
        self.hybrid = HybridHead(d_model, num_heads, ssm_module)  # 병합된 head
        self.norm = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4*d_model),
            nn.GELU(),
            nn.Linear(4*d_model, d_model)
        )

    def forward(self, x):
        # 입력 + meta token
        x = self.meta_block(x)

        # attention part: optional sliding mask
        attn_out, _ = self.attn_opt(x)

        # hybrid head (attention + SSM)
        # 여기선 하이브리드가 내부에서 attn 적용하므로, 단순히 전달
        out = self.hybrid(x)

        # 캐시 저장 / 재사용 (필요시)
        if self.use_kv_share and self.kv_cache_pool is not None:
            cached = self.kv_cache_pool.get(self.layer_idx)
            if cached is not None:
                out = cached
            else:
                self.kv_cache_pool.set(self.layer_idx, out)

        # residual + norm + FF
        x2 = self.norm(x + out)
        y = x2 + self.ff(x2)

        return y


In [6]:
class HymbaModel(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers,
                 ssm_constructor, use_kv_share=True,
                 swa_ratio=0.9, window_size=16, num_meta=1):
        """
        swa_ratio: 몇 %의 레이어에 SWA를 적용할지 (예: 0.9 → 90%)
        ssm_constructor: 예: lambda: Mamba(...)
        """
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.kv_cache_pool = KVCachePool(num_layers, share_pairs=use_kv_share)
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            use_swa = (i / num_layers) >= (1 - swa_ratio)  # 예: 하위 일부 레이어만 global
            self.layers.append(
                HymbaBlock(layer_idx=i,
                           d_model=d_model,
                           num_heads=num_heads,
                           ssm_module=ssm_constructor(),
                           kv_cache_pool=self.kv_cache_pool,
                           use_kv_share=use_kv_share,
                           use_swa=use_swa,
                           window_size=window_size,
                           use_meta=(i == 0),
                           num_meta=num_meta
                )
            )
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids):
        x = self.embed(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        logits = self.head(x)
        return logits


In [4]:
!pip install torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/cu121

[0mLooking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch
  Downloading https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp310-cp310-linux_x86_64.whl (780.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m780.4/780.4 MB[0m [31m108.5 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting torchvision
  Downloading https://download.pytorch.org/whl/cu121/torchvision-0.20.1%2Bcu121-cp310-cp310-linux_x86_64.whl (7.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.3/7.3 MB[0m [31m122.5 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting torchaudio
  Downloading https://download.pytorch.org/whl/cu121/torchaudio-2.5.1%2Bcu121-cp310-cp310-linux_x86_64.whl (3.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m129.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchtext
  Downloading https://download.pytorch.org/whl/torchtext-0.17.0%2Bcpu-cp310-cp310-linux_x86_

In [None]:
!pip install mamba-ssm

[0mCollecting mamba-ssm
  Downloading mamba_ssm-2.2.5.tar.gz (113 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m113.8/113.8 KB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l/

In [1]:
from mamba_ssm import Mamba

ImportError: /usr/local/lib/python3.10/dist-packages/selective_scan_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEab

In [1]:
import torch
import torch.nn as nn
from mamba_ssm import Mamba

# Device 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 예제 HymbaModel 생성
vocab_size = 10000
d_model = 128
num_heads = 4
num_layers = 6
window_size = 32
num_meta = 2

def make_ssm():
    return Mamba(d_model, d_state=16, d_conv=4, expand=2).to(device)

model = HymbaModel(vocab_size, d_model, num_heads, num_layers,
                   ssm_constructor=make_ssm,
                   use_kv_share=True,
                   swa_ratio=0.8, window_size=window_size,
                   num_meta=num_meta).to(device)

# 입력을 device로 이동
x = torch.randint(0, vocab_size, (2, 20), device=device)

# Forward pass
logits = model(x)
print("logits shape:", logits.shape)


ImportError: /usr/local/lib/python3.10/dist-packages/selective_scan_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN3c1021throwNullDataPtrErrorEv

In [None]:
from datasets import load_dataset

# WikiText-2 불러오기
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

train_text = " ".join(dataset["train"]["text"])
valid_text = " ".join(dataset["validation"]["text"])
test_text  = " ".join(dataset["test"]["text"])

print("Train sample:", train_text[:200])


In [None]:
from collections import Counter
import torch

def build_vocab(texts, min_freq=2):
    counter = Counter()
    for t in texts.split():
        counter[t] += 1
    vocab = {w: i for i, (w, c) in enumerate(counter.items()) if c >= min_freq}
    vocab["<unk>"] = len(vocab)
    return vocab

vocab = build_vocab(train_text)

def encode(text):
    return torch.tensor([vocab.get(w, vocab["<unk>"]) for w in text.split()], dtype=torch.long)

train_ids = encode(train_text)
val_ids   = encode(valid_text)
test_ids  = encode(test_text)


In [None]:
bptt = 64
batch_size = 32

def batchify(data, bsz):
    nbatch = data.size(0) // bsz
    data = data[:nbatch*bsz]
    return data.view(bsz, -1).t().contiguous()

train_data = batchify(train_ids, batch_size).to("cuda")
val_data   = batchify(val_ids, batch_size).to("cuda")

def get_batch(source, i):
    seq_len = min(bptt, source.size(0)-1-i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target


In [None]:
d_model, num_heads, num_layers = 128, 4, 6

model = HymbaModel(len(vocab), d_model, num_heads, num_layers,
                   swa_layers_ratio=0.8,   # 하위 80% 레이어는 SWA
                   use_kv_share=True,      # KV Cache Sharing
                   num_meta=2).to("cuda")


In [None]:
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
scaler = GradScaler()

def evaluate(data_source):
    model.eval()
    total_loss = 0.
    with torch.no_grad():
        for i in range(0, data_source.size(0)-1, bptt):
            data, targets = get_batch(data_source, i)
            with autocast(device_type="cuda"):
                output = model(data)
                loss = criterion(output.view(-1, len(vocab)), targets)
            total_loss += loss.item() * len(data)
    return total_loss / (len(data_source)-1)

model.train()
for epoch in range(1):
    for i in range(0, train_data.size(0)-1, bptt):
        data, targets = get_batch(train_data, i)
        optimizer.zero_grad()
        with autocast(device_type="cuda"):
            output = model(data)
            loss = criterion(output.view(-1, len(vocab)), targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        if i % (bptt*100) == 0:
            print(f"step {i}, loss {loss.item():.4f}")

val_loss = evaluate(val_data)
print("Validation Loss:", val_loss, " | PPL:", torch.exp(torch.tensor(val_loss)))


In [None]:
def generate(model, start_text="the", max_len=30):
    model.eval()
    tokens = torch.tensor([[vocab.get(w, vocab["<unk>"]) for w in start_text.split()]], device="cuda")
    for _ in range(max_len):
        with torch.no_grad(), autocast(device_type="cuda"):
            logits = model(tokens)
            next_token = torch.argmax(logits[:, -1], dim=-1).unsqueeze(0)
            tokens = torch.cat([tokens, next_token], dim=1)
    inv_vocab = {i: w for w, i in vocab.items()}
    return " ".join(inv_vocab[t.item()] for t in tokens[0])

print("Sample generation:", generate(model, "the king"))
