# 02. Model Architecture Exploration

Parametric QA 모델 구조 이해 및 테스트

In [None]:
import sys
from pathlib import Path

# Presumably the notebook lives one directory below the project root; make
# that root importable so the project-local `models.*` imports resolve.
PROJECT_ROOT = Path(".").resolve().parent
if sys.path[:1] != [str(PROJECT_ROOT)]:
    # Only prepend when the root is not already first, so re-running this
    # cell does not keep stacking duplicate entries onto sys.path.
    sys.path.insert(0, str(PROJECT_ROOT))

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

## 1. Document Vector (z_i) 개념

In [None]:
# Core idea of Parametric QA: each document maps to a learnable vector z_i.

# Hyperparameters
NUM_DOCS = 1000      # number of documents
M_TOKENS = 4         # number of virtual tokens each z_i occupies
Z_DIM = 256          # dimensionality of each z_i virtual token

# Document vectors: [num_docs, m_tokens, z_dim], random init scaled by 0.02.
doc_vectors = nn.Parameter(torch.randn(NUM_DOCS, M_TOKENS, Z_DIM) * 0.02)

# Bytes needed to store one document's z_i, derived from the tensor's actual
# element size rather than assuming 4-byte float32.
bytes_per_doc = doc_vectors[0].numel() * doc_vectors.element_size()

print(f"Document vectors shape: {doc_vectors.shape}")
print(f"Total parameters: {doc_vectors.numel():,}")
print(f"Storage per doc: {bytes_per_doc} bytes ({str(doc_vectors.dtype).split('.')[-1]})")

In [None]:
# Index z_i for a handful of documents.
doc_indices = torch.tensor([0, 5, 10], dtype=torch.long)  # pick 3 documents
selected_z = doc_vectors[doc_indices]  # -> [3, m_tokens, z_dim]

print(f"Selected z shape: {selected_z.shape}")

## 2. Query Encoder

In [None]:
# E5-base-v2 serves as the query encoder.
QUERY_ENCODER_NAME = "intfloat/e5-base-v2"

# Loading is left disabled here because it can take a while:
# query_encoder = AutoModel.from_pretrained(QUERY_ENCODER_NAME)
# query_tokenizer = AutoTokenizer.from_pretrained(QUERY_ENCODER_NAME)

print(
    f"Query encoder: {QUERY_ENCODER_NAME}",
    "E5 output dim: 768",
    sep="\n",
)

In [None]:
# Query embedding extraction via masked mean pooling.
def get_query_embedding(query_encoder, input_ids, attention_mask):
    """Encode a batch of queries and mean-pool token states under the mask.

    Args:
        query_encoder: Callable accepting ``input_ids``/``attention_mask``
            keyword arguments and returning an object whose
            ``last_hidden_state`` is [B, seq_len, hidden].
        input_ids: Token ids, forwarded unchanged to the encoder.
        attention_mask: [B, seq_len] mask; each token is weighted by its mask
            value (0 excludes it from the average).

    Returns:
        [B, hidden] tensor of masked-mean token states. A row whose mask is
        all zeros yields zeros, because the denominator is clamped to 1e-9.
    """
    encoded = query_encoder(input_ids=input_ids, attention_mask=attention_mask)
    token_states = encoded.last_hidden_state                   # [B, seq_len, hidden]
    weights = attention_mask[..., None].float()                # [B, seq_len, 1]
    summed = torch.sum(token_states * weights, dim=1)          # [B, hidden]
    counts = torch.clamp(torch.sum(weights, dim=1), min=1e-9)  # [B, 1]
    return summed / counts

## 3. Selection Methods (Router)

In [None]:
from models.router import CosineSelector, LearnedRouter, AttentionSelector

# List the three selection methods offered by models.router.
print(
    "Available selection methods:",
    "1. CosineSelector: cosine similarity 기반",
    "2. LearnedRouter: 학습 가능한 MLP router",
    "3. AttentionSelector: cross-attention 기반",
    sep="\n",
)

In [None]:
# CosineSelector example
class CosineSelector(nn.Module):
    """Score documents by cosine similarity of a projected query and mean-pooled z_i.

    NOTE(review): this notebook-local class shadows the ``CosineSelector``
    imported from ``models.router`` in this notebook; code run after this
    cell uses this local version.

    Args:
        query_dim: Dimensionality of the incoming query embeddings.
        z_dim: Dimensionality of each document virtual token.
        temperature: Strictly positive divisor applied to the cosine scores.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """

    def __init__(self, query_dim: int, z_dim: int, temperature: float = 0.1):
        super().__init__()
        if not temperature > 0:
            # 0 would produce inf/nan scores; a negative value would silently
            # invert the ranking.
            raise ValueError(f"temperature must be > 0, got {temperature}")
        self.query_proj = nn.Linear(query_dim, z_dim)
        self.temperature = temperature

    def forward(self, query_emb, doc_vectors):
        """Compute temperature-scaled cosine scores for every document.

        Args:
            query_emb: [B, query_dim]
            doc_vectors: [num_docs, m_tokens, z_dim]
        Returns:
            scores: [B, num_docs], each in [-1/temperature, 1/temperature]
        """
        # Project the query into z space and L2-normalize it.
        q = nn.functional.normalize(self.query_proj(query_emb), dim=-1)  # [B, z_dim]

        # Collapse each document's m_tokens virtual tokens into one vector.
        z_mean = nn.functional.normalize(doc_vectors.mean(dim=1), dim=-1)  # [num_docs, z_dim]

        # Dot product of unit vectors == cosine similarity.
        return torch.matmul(q, z_mean.T) / self.temperature  # [B, num_docs]

# Smoke test: score two random 768-d query embeddings against every document.
selector = CosineSelector(query_dim=768, z_dim=Z_DIM)
dummy_query = torch.randn((2, 768))  # a batch of two queries
scores = selector(query_emb=dummy_query, doc_vectors=doc_vectors)
print(f"Scores shape: {scores.shape}")  # expected [2, NUM_DOCS]

## 4. LLM with QLoRA

In [None]:
# QLoRA setup: 4-bit quantized base LLM plus trainable LoRA adapters.
LLM_NAME = "Qwen/Qwen3-8B"

# 4-bit NF4 quantization config with double quantization enabled.
# NOTE(review): confirm the checkpoint's native dtype; if it is bfloat16 and
# the GPU supports bf16, torch.bfloat16 compute may avoid float16 overflow.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# LoRA config: adapters only on the attention projections.
lora_config = LoraConfig(
    r=16,                      # LoRA rank
    lora_alpha=32,             # LoRA scaling (alpha)
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Printed values are read back from the configs so the summary cannot drift
# from what is actually configured.
print("QLoRA configuration:")
print(f"  Quantization: 4-bit {bnb_config.bnb_4bit_quant_type.upper()}")
print(f"  LoRA rank: {lora_config.r}")
print(f"  Target modules: {lora_config.target_modules}")

In [None]:
# Load the LLM (only run this when enough memory is available).
# NOTE(review): trust_remote_code=True allows executing code shipped with the
# model repository — confirm it is actually required for this checkpoint.
# llm = AutoModelForCausalLM.from_pretrained(
#     LLM_NAME,
#     quantization_config=bnb_config,
#     device_map="auto",
#     trust_remote_code=True,
# )
# llm = get_peft_model(llm, lora_config)
# llm.print_trainable_parameters()

## 5. z → LLM Embedding Projection

In [None]:
# Map each z_i virtual token into the LLM's embedding space.
LLM_HIDDEN_DIM = 4096  # Qwen3-8B hidden size

z_to_embedding = nn.Sequential(
    *(
        nn.Linear(Z_DIM, LLM_HIDDEN_DIM),           # z_dim -> hidden
        nn.GELU(),
        nn.Linear(LLM_HIDDEN_DIM, LLM_HIDDEN_DIM),  # hidden -> hidden
        nn.LayerNorm(LLM_HIDDEN_DIM),               # normalize the output
    )
)

# Smoke test on the first document only (the slice keeps the batch axis).
z_sample = doc_vectors[:1]              # [1, m_tokens, z_dim]
z_projected = z_to_embedding(z_sample)  # [1, m_tokens, LLM_HIDDEN_DIM]
print(f"Projected z shape: {z_projected.shape}")

## 6. Full Model Architecture

In [None]:
# Overview of the ParametricQA model structure (illustrative skeleton only).

class ParametricQAOverview(nn.Module):
    """Simplified, non-functional sketch of the Parametric QA model.

    Components (only ``doc_vectors`` is instantiated here):
    1. doc_vectors: [num_docs, m_tokens, z_dim] - learnable document vectors
    2. query_encoder: E5-base-v2 - query embedding
    3. selector: CosineSelector/LearnedRouter/AttentionSelector
    4. z_to_embedding: z_i -> LLM embedding space
    5. llm: Qwen3-8B with QLoRA

    Forward pass:
    - Write phase: z_i -> LLM -> reconstruct D_i
    - Read phase: query -> select z_i -> z_i + query -> LLM -> answer

    Both phase methods are stubs that raise ``NotImplementedError``. The
    full model is imported separately from ``models.parametric_qa``.
    """

    def __init__(self, num_docs, z_dim, m_tokens):
        super().__init__()
        # Learnable per-document vectors, random init scaled by 0.02.
        self.doc_vectors = nn.Parameter(torch.randn(num_docs, m_tokens, z_dim) * 0.02)
        # Query encoder, selector, projection and LLM are omitted in this sketch.

    def write_phase_forward(self, doc_ids, doc_input_ids, doc_attention_mask):
        """Write phase (sketch): z_i -> reconstruct D_i.

        Intended lookup: ``self.doc_vectors[doc_ids]`` -> [B, m_tokens, z_dim]
        for 1-D ``doc_ids``, followed by projection and generation.

        Raises:
            NotImplementedError: Always; this class is only an overview.
        """
        raise NotImplementedError(
            "ParametricQAOverview is a sketch; use models.parametric_qa.ParametricQA"
        )

    def forward(self, query_ids, doc_indices, answer_ids):
        """Read phase (sketch): query + z_selected -> answer.

        Intended lookup: ``self.doc_vectors[doc_indices]`` ->
        [B, k, m_tokens, z_dim] for [B, k] ``doc_indices``, followed by
        answer generation.

        Raises:
            NotImplementedError: Always; this class is only an overview.
        """
        raise NotImplementedError(
            "ParametricQAOverview is a sketch; use models.parametric_qa.ParametricQA"
        )

print("Model architecture loaded (see docstring for details)")

In [None]:
# Load the real model (full module).
from models.parametric_qa import ParametricQA

# Config, built from the constants and configs defined earlier in this
# notebook, so the summary stays in sync if any of them changes.
model_config = {
    "llm_name": LLM_NAME,
    "num_docs": NUM_DOCS,
    "z_dim": Z_DIM,
    "m_tokens": M_TOKENS,
    "selection_method": "cosine",
    "query_encoder_name": QUERY_ENCODER_NAME,
    "lora_r": lora_config.r,
    "lora_alpha": lora_config.lora_alpha,
    "use_4bit": True,
}

print("Model configuration:")
for k, v in model_config.items():
    print(f"  {k}: {v}")

# Only instantiate when enough GPU memory is available.
# model = ParametricQA(**model_config)

## 7. Parameter Count 분석

In [None]:
def count_parameters(model):
    """Count a module's parameters.

    Args:
        model: Any ``nn.Module``.

    Returns:
        Tuple ``(total, trainable)``: element counts over all parameters and
        over those with ``requires_grad=True``.
    """
    n_total = 0
    n_trainable = 0
    for param in model.parameters():
        n_elems = param.numel()
        n_total += n_elems
        if param.requires_grad:
            n_trainable += n_elems
    return n_total, n_trainable

# Parameter counts: the small modules built in this notebook are counted
# exactly; the encoder and LLM figures are rough estimates.
print("Expected parameter counts:")
print(f"  doc_vectors: {NUM_DOCS * M_TOKENS * Z_DIM:,}")
print("  query_encoder (E5-base): ~110M (frozen)")
# Linear(Z->H) + Linear(H->H) + LayerNorm(H) = Z*H + H*H + 4*H parameters;
# the H x H layer dominates, so a 2*Z*H estimate undercounts it badly.
print(f"  z_to_embedding: {count_parameters(z_to_embedding)[0]:,}")
# LoRA adds r * (in_features + out_features) per adapted projection per layer.
# NOTE(review): assuming Qwen3-8B has 36 layers, 4096-d q/o and 1024-d k/v
# projections, r=16 on q/k/v/o gives 36 * 16 * (2*8192 + 2*5120) = 15,335,424
# (~15M) — verify against the model config / print_trainable_parameters().
print("  LLM (Qwen3-8B): ~8B (mostly frozen, QLoRA ~15M trainable)")
print(f"  selector: {count_parameters(selector)[0]:,}")