In [None]:
import numpy as np
import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display

import torch, time, psutil, os
import torch.nn as nn

### KV Cache Sharing

In [4]:
def kv_cache_tradeoff(max_layers=24, seq_len=1024, hidden_dim=1024):
    layers = np.arange(2, max_layers+1, 2)
    
    # 단순화된 메모리 모델 (단위: MB)
    token_size = hidden_dim * 2  # Key + Value
    bytes_per_token = token_size * 4  # float32
    base_mem = seq_len * bytes_per_token / (1024**2)
    
    mem_no_share = layers * base_mem
    mem_pair_share = (layers/2) * base_mem
    
    # Latency 모델 (임의 스케일링: ms 단위)
    base_latency = 0.05  # 1 cache access = 0.05ms 가정
    latency_no_share = layers * base_latency
    latency_pair_share = (layers/2) * base_latency
    
    # Plot
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=layers, y=mem_no_share, mode="lines+markers",
        name="Memory No Sharing", line=dict(color="red", width=3)
    ))
    fig.add_trace(go.Scatter(
        x=layers, y=mem_pair_share, mode="lines+markers",
        name="Memory Pair Sharing", line=dict(color="green", width=3, dash="dot")
    ))
    fig.add_trace(go.Scatter(
        x=layers, y=latency_no_share*100, mode="lines+markers",
        name="Latency No Sharing (scaled)", line=dict(color="orange", width=3)
    ))
    fig.add_trace(go.Scatter(
        x=layers, y=latency_pair_share*100, mode="lines+markers",
        name="Latency Pair Sharing (scaled)", line=dict(color="blue", width=3, dash="dot")
    ))
    
    fig.update_layout(
        title="KV Cache Trade-off: Memory vs Latency",
        xaxis_title="Number of Layers",
        yaxis_title="Relative Units (Memory MB / Latency*100)",
        template="plotly_white"
    )
    fig.show()

widgets.interact(
    kv_cache_tradeoff,
    max_layers=(4, 48, 2),
    seq_len=(512, 8192, 512),
    hidden_dim=(256, 2048, 256)
);


interactive(children=(IntSlider(value=24, description='max_layers', max=48, min=4, step=2), IntSlider(value=10…

In [7]:
class KVCache:
    """KV 캐시 관리: 2개 레이어가 하나의 캐시를 공유"""
    def __init__(self, num_layers, share_pairs=True):
        self.share_pairs = share_pairs
        self.caches = {}
        self.num_layers = num_layers

    def get_key(self, layer_idx):
        if not self.share_pairs:
            return f"layer_{layer_idx}"
        # 두 레이어를 묶어 같은 캐시 키 반환
        pair_idx = layer_idx // 2
        return f"pair_{pair_idx}"

    def get(self, layer_idx):
        return self.caches.get(self.get_key(layer_idx))

    def set(self, layer_idx, kv):
        self.caches[self.get_key(layer_idx)] = kv

class KVLayer(nn.Module):
    def __init__(self, d_model, layer_idx, kv_cache: KVCache):
        super().__init__()
        self.fc = nn.Linear(d_model, d_model)
        self.layer_idx = layer_idx
        self.kv_cache = kv_cache

    def forward(self, x):
        cached = self.kv_cache.get(self.layer_idx)
        if cached is not None:
            return cached
        out = self.fc(x)
        self.kv_cache.set(self.layer_idx, out)
        return out

# 사용 예시
d_model = 64
layers = 6
kv_cache = KVCache(num_layers=layers, share_pairs=True)
x = torch.randn(1, 10, d_model)

model = nn.ModuleList([KVLayer(d_model, i, kv_cache) for i in range(layers)])
for layer in model:
    out = layer(x)


### SWA(Sliding Window Attention)

In [4]:
def swa_attention_map(seq_len=16, window_size=4, highlight_row=8):
    attn = np.zeros((seq_len, seq_len))
    for i in range(seq_len):
        for j in range(seq_len):
            if j <= i and j >= i - window_size + 1:
                attn[i,j] = 1
    fig = go.Figure(data=go.Heatmap(
        z=attn,
        x=[f"K{j}" for j in range(seq_len)],
        y=[f"Q{i}" for i in range(seq_len)],
        colorscale=[[0,"#f3f4f6"],[1,"#22c55e"]],
        showscale=False,
        hovertemplate="Query=%{y}, Key=%{x}, Active=%{z}<extra></extra>"
    ))
    # 강조 row에 라인
    fig.add_shape(type="rect", x0=-0.5, y0=highlight_row-0.5, x1=seq_len-0.5, y1=highlight_row+0.5,
                  line=dict(color="blue", width=2))
    fig.update_layout(
        title=f"SWA Attention Map (Window={window_size}, Highlight Q{highlight_row})",
        xaxis=dict(side="top")
    )
    fig.show()

widgets.interact(swa_attention_map, seq_len=(8,32,1), window_size=(1,16,1), highlight_row=(1,32,1));


interactive(children=(IntSlider(value=16, description='seq_len', max=32, min=8), IntSlider(value=4, descriptio…

In [4]:
# PyTorch 코드로 mask 비교
def build_mask(n,w,dev):
    m=torch.full((n,n),float("-inf"),device=dev)
    for i in range(n):
        m[i,max(0,i-w):i+1]=0
    return m

print("Full mask:", build_mask(8,8,"cpu"))
print("Window mask (w=3):", build_mask(8,3,"cpu"))


Full mask: tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
Window mask (w=3): tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [-inf, 0., 0., 0., 0., -inf, -inf, -inf],
        [-inf, -inf, 0., 0., 0., 0., -inf, -inf],
        [-inf, -inf, -inf, 0., 0., 0., 0., -inf],
        [-inf, -inf, -inf, -inf, 0., 0., 0., 0.]])


### Meta Token

In [None]:
def meta_attention_map(num_tokens=12, use_meta=True):
    size = num_tokens + (1 if use_meta else 0)
    attn = np.zeros((size, size))
    for i in range(size):
        for j in range(size):
            if j <= i:  # causal
                if use_meta and j == 0:
                    attn[i,j] = 0.9
                elif abs(i-j) <= 1:
                    attn[i,j] = 0.6
                else:
                    attn[i,j] = 0.2
    
    fig = go.Figure(data=go.Heatmap(
        z=attn,
        x=[f"K{j}" if not(use_meta and j==0) else "META" for j in range(size)],
        y=[f"Q{i}" if not(use_meta and i==0) else "META" for i in range(size)],
        colorscale="Blues",
        hoverongaps=False,
        colorbar=dict(title="Attention")
    ))
    fig.update_layout(
        title="Meta Token Heatmap (허브 역할)" if use_meta else "Standard Attention Heatmap",
        xaxis=dict(side="top")
    )
    fig.show()

widgets.interact(meta_attention_map, num_tokens=(4,16,1), use_meta=[False, True]);


interactive(children=(IntSlider(value=12, description='num_tokens', max=16, min=4), Dropdown(description='use_…

In [6]:
# PyTorch 코드 예시 (Meta Token 적용/비적용)
class MetaBlock(nn.Module):
    def __init__(self,d,heads=2,use_meta=True):
        super().__init__(); self.attn=nn.MultiheadAttention(d,heads,batch_first=True); self.use_meta=use_meta
    def forward(self,x):
        if self.use_meta:
            meta=torch.zeros(x.size(0),1,x.size(-1))
            x=torch.cat([meta,x],dim=1)
        out,w=self.attn(x,x,x,need_weights=True,average_attn_weights=False)
        return out,w

x=torch.randn(1,16,8)
for use_meta in [False,True]:
    blk=MetaBlock(8,use_meta=use_meta)
    _,w=blk(x)
    print("With Meta" if use_meta else "Without Meta", "weights shape:",w.shape)


Without Meta weights shape: torch.Size([1, 2, 16, 16])
With Meta weights shape: torch.Size([1, 2, 17, 17])
