# 03. Write Phase Training

Stage 1: Document → z_i 학습 (Document Reconstruction)

In [None]:
import sys
from pathlib import Path

PROJECT_ROOT = Path(".").resolve().parent
sys.path.insert(0, str(PROJECT_ROOT))

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from tqdm.notebook import tqdm
import yaml
from omegaconf import OmegaConf

## 1. Configuration

In [None]:
# Load the experiment config.
# The file is opened explicitly as UTF-8 so parsing does not depend on the
# platform's locale encoding (a YAML file containing non-ASCII text would
# otherwise fail or be mis-decoded on non-UTF-8 locales).
config_path = PROJECT_ROOT / "configs" / "phase1_poc.yaml"
with open(config_path, encoding="utf-8") as f:
    config = OmegaConf.create(yaml.safe_load(f))

# Write-phase section of the config (batch size, learning rates, epochs, ...)
write_config = config.write_phase
print("Write Phase Configuration:")
print(OmegaConf.to_yaml(write_config))

In [None]:
# Hyperparameters for the write phase, taken from the write-phase config.
BATCH_SIZE = write_config.batch_size
LEARNING_RATE = write_config.lr_z  # learning rate used for the z vectors
NUM_EPOCHS = write_config.epochs
# Hard-coded here rather than read from the config; passed to the dataset
# as `max_length` (presumably a per-document token limit -- TODO confirm).
MAX_DOC_LENGTH = 512

# Report the values that were picked up, one per line.
print(
    f"Batch size: {BATCH_SIZE}",
    f"Learning rate (z): {LEARNING_RATE}",
    f"Epochs: {NUM_EPOCHS}",
    sep="\n",
)

## 2. 데이터 로드

In [None]:
import json

from data.dataloader import WritePhaseDataset

# Sample corpus keyed by integer document id (0..4).
# In a real run the preprocessed data would be loaded instead.
sample_corpus = dict(
    enumerate(
        [
            "France is a country in Western Europe. Paris is the capital and largest city.",
            "Romeo and Juliet is a tragedy written by William Shakespeare.",
            "The Great Wall of China is a series of fortifications made of various materials.",
            "Albert Einstein developed the theory of relativity, one of the two pillars of modern physics.",
            "The Amazon rainforest produces about 20% of the world's oxygen.",
        ]
    )
)

# Tokenizer for the configured LLM; fall back to the EOS token for padding
# when the tokenizer does not define a pad token of its own.
llm_name = config.model.llm_name
tokenizer = AutoTokenizer.from_pretrained(llm_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Dataset and shuffled loader over the sample corpus.
write_dataset = WritePhaseDataset(
    sample_corpus,
    tokenizer,
    max_length=MAX_DOC_LENGTH,
)
write_loader = DataLoader(
    write_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

print(f"Corpus size: {len(sample_corpus)}")
print(f"Dataset size: {len(write_dataset)}")

## 3. 모델 초기화

In [None]:
from models.parametric_qa import ParametricQA

# Device selection: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Model initialization (enable only when enough memory is available).
# model = ParametricQA(
#     llm_name=config.model.llm_name,
#     num_docs=len(sample_corpus),
#     z_dim=config.parametric_qa.z_dim,
#     m_tokens=config.parametric_qa.m_tokens,
#     selection_method=config.parametric_qa.selection_method,
#     lora_r=config.model.lora.r,
#     lora_alpha=config.model.lora.alpha,
#     use_4bit=config.model.use_4bit,
# )
# model.to(device)

## 4. Write Phase Trainer

In [None]:
class WritePhaseTrainer:
    """
    Write Phase Training

    Objective: max log P(D_i | z_i; θ)
    - train z_i so that document D_i can be reconstructed from it.

    The optimizer always covers ``model.doc_vectors`` and the parameters of
    ``model.z_to_embedding``. Parameters of ``model.llm`` whose name contains
    "lora" are added only when ``llm_frozen`` is False.
    """

    def __init__(
        self,
        model,
        tokenizer,
        lr_z: float = 1e-3,
        lr_lora: float = 1e-4,
        llm_frozen: bool = True,
        device: str = "cuda",
        max_grad_norm: float = 1.0,
    ):
        """
        Build the optimizer for the write phase.

        Args:
            model: object exposing ``doc_vectors``, ``z_to_embedding``,
                ``llm`` and ``write_phase_forward``.
            tokenizer: stored on the trainer; not used by the methods here.
            lr_z: learning rate for ``model.doc_vectors``.
            lr_lora: learning rate for the projection and, when the LLM is
                not frozen, for the LoRA parameters.
            llm_frozen: when False, LoRA parameters are optimized as well.
            device: device the batch tensors are moved to in ``train_step``.
            max_grad_norm: maximum total gradient norm used for clipping.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.llm_frozen = llm_frozen
        self.max_grad_norm = max_grad_norm

        # Optimizer parameter groups. The projection parameters are
        # materialized into a list so the generator is not consumed twice.
        param_groups = [
            {"params": [model.doc_vectors], "lr": lr_z, "name": "z"},
            {
                "params": list(model.z_to_embedding.parameters()),
                "lr": lr_lora,
                "name": "projection",
            },
        ]

        if not llm_frozen:
            lora_params = [
                p for n, p in model.llm.named_parameters() if "lora" in n.lower()
            ]
            param_groups.append({"params": lora_params, "lr": lr_lora, "name": "lora"})

        self.optimizer = torch.optim.AdamW(param_groups)

        # Exactly the parameters the optimizer updates (and zeroes). Gradient
        # clipping must be restricted to these: gradients of parameters
        # outside the optimizer are never cleared by optimizer.zero_grad(),
        # so including them would inflate the norm and shrink real updates.
        self._optimized_params = [
            p for group in self.optimizer.param_groups for p in group["params"]
        ]

    def train_step(self, batch):
        """
        Run a single training step and return the loss as a Python float.

        ``batch`` must provide "doc_id", "input_ids" and "attention_mask"
        entries that support ``.to(device)``.
        """
        self.model.train()

        doc_ids = batch["doc_id"].to(self.device)
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)

        # Forward
        loss = self.model.write_phase_forward(
            doc_ids=doc_ids,
            doc_input_ids=input_ids,
            doc_attention_mask=attention_mask,
        )

        # Backward
        self.optimizer.zero_grad()
        loss.backward()

        # Gradient clipping over the optimized parameters only
        torch.nn.utils.clip_grad_norm_(
            self._optimized_params, max_norm=self.max_grad_norm
        )

        self.optimizer.step()

        return loss.item()

    def train_epoch(self, dataloader, epoch: int):
        """
        Train for one epoch and return the mean per-batch loss.

        ``epoch`` is only used for the progress-bar label. An empty
        dataloader yields 0 instead of dividing by zero.
        """
        total_loss = 0
        num_batches = 0

        pbar = tqdm(dataloader, desc=f"Epoch {epoch}")
        for batch in pbar:
            loss = self.train_step(batch)
            total_loss += loss
            num_batches += 1
            pbar.set_postfix({"loss": f"{loss:.4f}"})

        avg_loss = total_loss / max(num_batches, 1)
        return avg_loss

print("WritePhaseTrainer defined")

## 5. Training Loop

In [None]:
# Run training (uncomment once the model has been loaded)
# trainer = WritePhaseTrainer(
#     model=model,
#     tokenizer=tokenizer,
#     lr_z=write_config.lr_z,
#     lr_lora=write_config.lr_lora,
#     llm_frozen=write_config.llm_frozen,
#     device=device,
# )

# losses = []
# for epoch in range(NUM_EPOCHS):
#     avg_loss = trainer.train_epoch(write_loader, epoch + 1)
#     losses.append(avg_loss)
#     print(f"Epoch {epoch + 1}/{NUM_EPOCHS}, Loss: {avg_loss:.4f}")

In [None]:
# Loss visualization
import matplotlib.pyplot as plt

# Uncomment once `losses` exists (i.e. after the training loop has been run)
# plt.figure(figsize=(10, 6))
# plt.plot(losses)
# plt.xlabel("Epoch")
# plt.ylabel("Loss")
# plt.title("Write Phase Training Loss")
# plt.grid(True)
# plt.show()

## 6. z_i 품질 검증

In [None]:
def evaluate_reconstruction(model, corpus, tokenizer, device, num_samples=5):
    """
    Evaluate how well each z_i reconstructs its document.

    For the first ``num_samples`` document ids in ``corpus`` (iteration
    order), text is generated from the document's z vector and paired with
    the original text. The model is switched to eval mode first.

    Returns:
        A list of dicts with keys "doc_id", "original" and "generated";
        both texts are truncated to their first 200 characters.
    """
    model.eval()

    def _generate_text(doc_id):
        # Generate from z_i without tracking gradients, then decode the
        # first (only) sequence of the returned batch.
        with torch.no_grad():
            id_tensor = torch.tensor([doc_id]).to(device)
            output_ids = model.generate_from_z(
                doc_ids=id_tensor,
                max_new_tokens=100,
            )
            return tokenizer.decode(output_ids[0], skip_special_tokens=True)

    selected_ids = list(corpus)[:num_samples]
    return [
        {
            "doc_id": doc_id,
            "original": corpus[doc_id][:200],
            "generated": _generate_text(doc_id)[:200],
        }
        for doc_id in selected_ids
    ]

# Run the evaluation (uncomment once the model has been trained)
# results = evaluate_reconstruction(model, sample_corpus, tokenizer, device)
# for r in results:
#     print(f"\n=== Doc {r['doc_id']} ===")
#     print(f"Original: {r['original']}")
#     print(f"Generated: {r['generated']}")

## 7. Checkpoint 저장

In [None]:
def save_checkpoint(model, optimizer, epoch, loss, save_path):
    """
    Save a write-phase checkpoint.

    Writes a dict with "epoch", "loss", a CPU copy of the document vectors,
    the ``z_to_embedding`` state dict and the optimizer state dict to
    ``save_path`` (parent directories are created as needed). The LLM is
    additionally saved via ``save_pretrained`` into a "lora_weights"
    directory next to the checkpoint file.
    """
    target = Path(save_path)
    out_dir = target.parent
    out_dir.mkdir(parents=True, exist_ok=True)

    payload = {"epoch": epoch, "loss": loss}
    payload["doc_vectors"] = model.doc_vectors.data.cpu()
    payload["z_to_embedding_state_dict"] = model.z_to_embedding.state_dict()
    payload["optimizer_state_dict"] = optimizer.state_dict()

    # The LoRA weights are stored separately, before the main checkpoint.
    lora_dir = out_dir / "lora_weights"
    model.llm.save_pretrained(lora_dir)

    torch.save(payload, target)
    print(f"Checkpoint saved to {target}")

# Save the final checkpoint (uncomment after training)
# save_checkpoint(
#     model, trainer.optimizer, NUM_EPOCHS, losses[-1],
#     PROJECT_ROOT / "checkpoints" / "phase1" / "write_final.pt"
# )

In [None]:
def load_checkpoint(model, checkpoint_path):
    """
    Load a write-phase checkpoint into ``model``.

    Restores the document vectors (moved to the device the model's vectors
    currently live on) and the ``z_to_embedding`` weights. The optimizer
    state and LoRA weights are not restored here; the full checkpoint dict
    is returned so the caller can use its remaining entries.
    """
    state = torch.load(checkpoint_path, map_location="cpu")

    # Document vectors: keep them on the model's current device.
    z_device = model.doc_vectors.device
    model.doc_vectors.data = state["doc_vectors"].to(z_device)

    # Projection weights.
    projection_state = state["z_to_embedding_state_dict"]
    model.z_to_embedding.load_state_dict(projection_state)

    epoch = state["epoch"]
    loss = state["loss"]
    print(f"Loaded checkpoint from epoch {epoch}, loss: {loss:.4f}")

    return state

# Load a saved checkpoint (uncomment once the model has been created)
# checkpoint = load_checkpoint(model, PROJECT_ROOT / "checkpoints" / "phase1" / "write_final.pt")