# 08. Analysis and Visualization

실험 결과 분석 및 시각화

In [None]:
import sys
from pathlib import Path

PROJECT_ROOT = Path(".").resolve().parent
sys.path.insert(0, str(PROJECT_ROOT))

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import json

## 1. z 벡터 시각화

In [None]:
def visualize_z_vectors(doc_vectors, labels=None, method="tsne", perplexity=30):
    """
    Project per-document z vectors to 2D and draw them as a scatter plot.

    Each document's vectors are mean-pooled over the m_tokens axis, reduced
    to two dimensions, and plotted on a new 12x10 figure. The figure is
    neither shown nor saved here.

    Args:
        doc_vectors: torch tensor of shape [num_docs, m_tokens, z_dim].
        labels: optional sequence with one hashable category label per
            document. When given, every category gets its own color and
            legend entry, ordered by first appearance.
        method: "tsne" selects t-SNE; any other value selects PCA.
        perplexity: t-SNE perplexity. Values that are not below num_docs are
            lowered to num_docs - 1, since recent scikit-learn versions
            reject a perplexity >= the number of samples.

    Returns:
        NumPy array of shape [num_docs, 2] holding the projected coordinates.
    """
    # detach() lets this accept tensors that still require grad (e.g. an
    # nn.Parameter); Tensor.numpy() raises on those otherwise.
    z_mean = doc_vectors.detach().mean(dim=1).cpu().numpy()  # [num_docs, z_dim]

    # Dimensionality reduction
    if method == "tsne":
        max_perplexity = max(z_mean.shape[0] - 1, 1)
        reducer = TSNE(
            n_components=2,
            perplexity=min(perplexity, max_perplexity),
            random_state=42,
        )
    else:
        reducer = PCA(n_components=2)

    z_2d = reducer.fit_transform(z_mean)

    # Plot
    plt.figure(figsize=(12, 10))

    if labels is not None:
        # dict.fromkeys de-duplicates while keeping first-appearance order, so
        # color assignment and legend order are the same on every run (set
        # iteration order is not stable for strings across interpreter runs).
        unique_labels = list(dict.fromkeys(labels))
        colors = plt.cm.tab10(np.linspace(0, 1, len(unique_labels)))

        for color, label in zip(colors, unique_labels):
            mask = np.array([item == label for item in labels], dtype=bool)
            plt.scatter(
                z_2d[mask, 0], z_2d[mask, 1],
                c=[color], label=label, alpha=0.6, s=50
            )
        plt.legend()
    else:
        plt.scatter(z_2d[:, 0], z_2d[:, 1], alpha=0.6, s=50)

    plt.xlabel("Dimension 1")
    plt.ylabel("Dimension 2")
    plt.title(f"Document Vectors ({method.upper()})")
    plt.grid(True, alpha=0.3)

    return z_2d

# 예시 데이터
# doc_vectors = model.doc_vectors.data
# z_2d = visualize_z_vectors(doc_vectors, method="tsne")
# plt.show()

In [None]:
# Sample visualization on random data.
# torch.randn draws from torch's own generator, which np.random.seed does not
# touch, so torch is seeded as well to make the sample reproducible. The NumPy
# seed is kept because np.random is used elsewhere in this notebook.
np.random.seed(42)
torch.manual_seed(42)
sample_vectors = torch.randn(100, 4, 256)
sample_labels = ["Category A"] * 30 + ["Category B"] * 40 + ["Category C"] * 30

z_2d = visualize_z_vectors(sample_vectors, sample_labels, method="pca")
plt.show()

## 2. Selection 분포 분석

In [None]:
def analyze_selection_distribution(selection_scores, gold_doc_ids):
    """
    Compare selection-score distributions of gold and non-gold documents.

    Splits every score into a gold or non-gold group, draws overlaid density
    histograms of the two groups on a new figure, and prints the mean and
    standard deviation of each group. The figure is not shown here.

    Args:
        selection_scores: [num_queries, num_docs] scores, iterated row by row.
        gold_doc_ids: gold document indices per query, looked up by query
            position (gold_doc_ids[i] belongs to selection_scores[i]).

    Returns:
        Tuple (gold_scores, non_gold_scores) of flat lists of scores.
    """
    gold_scores = []
    non_gold_scores = []

    for query_idx, scores in enumerate(selection_scores):
        # Build the lookup once per query so each membership test is O(1).
        # int() normalizes NumPy / tensor scalar indices so they compare
        # equal to the plain-int document positions below.
        gold_ids = {int(doc_id) for doc_id in gold_doc_ids[query_idx]}
        for doc_idx, score in enumerate(scores):
            if doc_idx in gold_ids:
                gold_scores.append(score)
            else:
                non_gold_scores.append(score)

    # Distribution plot
    plt.figure(figsize=(10, 6))

    # An empty group has no density to draw, so it is left out of the plot.
    if non_gold_scores:
        plt.hist(non_gold_scores, bins=50, alpha=0.5, label="Non-gold", color="blue", density=True)
    if gold_scores:
        plt.hist(gold_scores, bins=50, alpha=0.5, label="Gold", color="red", density=True)

    plt.xlabel("Selection Score")
    plt.ylabel("Density")
    plt.title("Selection Score Distribution: Gold vs Non-gold Documents")
    if gold_scores or non_gold_scores:
        plt.legend()
    plt.grid(True, alpha=0.3)

    # Summary statistics; mean/std of an empty group would be NaN with a
    # runtime warning, so that case is reported explicitly instead.
    for group_name, group in (("Gold", gold_scores), ("Non-gold", non_gold_scores)):
        if group:
            print(f"{group_name} scores - Mean: {np.mean(group):.4f}, Std: {np.std(group):.4f}")
        else:
            print(f"{group_name} scores - no scores collected")

    return gold_scores, non_gold_scores

# Example inputs: random scores for 50 queries over 100 documents, plus one
# randomly drawn gold document index (0-99) per query.
sample_scores = np.random.randn(50, 100)
sample_gold = [[np.random.randint(0, 100)] for _ in range(50)]
# The call below is left disabled; uncomment both lines to draw the plot.
# analyze_selection_distribution(sample_scores, sample_gold)
# plt.show()

## 3. Error Analysis

In [None]:
from analysis.error_analysis import analyze_errors, analyze_multihop_errors

# Error taxonomy: category key -> short description. The descriptions are
# Korean display strings and are printed verbatim.
error_categories = dict(
    selection_failure="Gold doc이 top-k에 없음",
    generation_failure="Selection은 맞지만 답이 틀림",
    both_failure="Selection과 Generation 모두 실패",
    success="정답",
)

# Emit one "key: description" line per category.
for cat, desc in error_categories.items():
    print(cat, desc, sep=": ")

In [None]:
# Illustrative error distribution; the four percentages sum to 100.
error_stats = dict(
    success=45.2,
    selection_failure=25.3,
    generation_failure=18.5,
    both_failure=11.0,
)

# One color and one offset per wedge, in the same order as error_stats.
# Only the first wedge ("success") is pulled away from the center.
colors = ["#2ecc71", "#e74c3c", "#3498db", "#9b59b6"]
explode = (0.05, 0, 0, 0)

# Pie chart
plt.figure(figsize=(10, 8))
plt.pie(
    list(error_stats.values()),
    labels=list(error_stats),
    colors=colors,
    explode=explode,
    autopct="%1.1f%%",
    shadow=True,
)
plt.title("Error Distribution")
plt.show()

## 4. Multi-hop Error Analysis

In [None]:
# Illustrative multi-hop error breakdown; the four percentages sum to 100.
multihop_errors = dict(
    bridge_not_found=35.2,
    bridge_not_recognized=22.1,
    propagation_error=15.3,
    success=27.4,
)

# One color per bar, in the same order as multihop_errors.
colors = ["#e74c3c", "#f39c12", "#9b59b6", "#2ecc71"]

# Bar chart
plt.figure(figsize=(10, 6))
bars = plt.bar(list(multihop_errors), list(multihop_errors.values()), color=colors)

plt.title("Multi-hop Error Analysis (HotpotQA)")
plt.ylabel("Percentage (%)")
plt.grid(axis="y", alpha=0.3)

# Write each percentage one unit above the top of its bar, centered on it.
for bar, val in zip(bars, multihop_errors.values()):
    plt.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height() + 1,
        f"{val}%",
        ha="center",
        va="bottom",
    )

plt.tight_layout()
plt.show()

## 5. Attention Heatmap

In [None]:
def plot_attention_heatmap(attention_weights, query_tokens, doc_ids):
    """
    Draw an annotated heatmap of query-to-document attention weights.

    A new 12x8 figure is created; it is not shown or returned here.

    Args:
        attention_weights: [query_len, num_docs] matrix of weights.
        query_tokens: list of query tokens, used as row (y-axis) labels.
        doc_ids: list of document IDs, used as column (x-axis) labels.
    """
    # Every cell is annotated with its value, formatted to two decimals.
    heatmap_options = dict(
        xticklabels=doc_ids,
        yticklabels=query_tokens,
        cmap="YlOrRd",
        annot=True,
        fmt=".2f",
    )

    plt.figure(figsize=(12, 8))
    sns.heatmap(attention_weights, **heatmap_options)

    plt.title("Query-Document Attention Weights")
    plt.xlabel("Document ID")
    plt.ylabel("Query Token")
    plt.tight_layout()

# Example: random weights for 8 query tokens over 5 documents, with every
# row rescaled so that it sums to 1.
sample_attn = np.random.rand(8, 5)
sample_attn /= sample_attn.sum(axis=1, keepdims=True)
query_tokens = "What is the capital of France ? [PAD]".split()
doc_ids = [f"Doc{idx}" for idx in range(5)]

plot_attention_heatmap(sample_attn, query_tokens, doc_ids)
plt.show()

## 6. Training Curves

In [None]:
def plot_training_curves(history: dict):
    """
    Plot loss curves and evaluation metrics side by side.

    Creates a 1x2 figure: the left panel shows the loss series and the right
    panel shows the metric series. The figure is not shown here.

    Args:
        history: per-epoch series keyed by "train_loss", "val_loss", "em"
            and "f1". Every key is optional; a missing key simply omits that
            curve.
    """
    _, (loss_ax, metric_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # (axes, y label, title, ((history key, legend label, color), ...))
    panels = (
        (
            loss_ax, "Loss", "Training Loss",
            (("train_loss", "Train", "#3498db"), ("val_loss", "Validation", "#e74c3c")),
        ),
        (
            metric_ax, "Score (%)", "Evaluation Metrics",
            (("em", "EM", "#2ecc71"), ("f1", "F1", "#9b59b6")),
        ),
    )

    for ax, y_label, title, curves in panels:
        drew_curve = False
        for key, legend_label, color in curves:
            if key in history:
                ax.plot(history[key], label=legend_label, color=color)
                drew_curve = True

        ax.set_xlabel("Epoch")
        ax.set_ylabel(y_label)
        ax.set_title(title)
        # legend() emits a warning when the axes holds no labelled curve, so
        # it is only requested for panels that actually drew something.
        if drew_curve:
            ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

# Example history: ten epochs of train / validation loss and EM / F1 scores.
sample_history = dict(
    train_loss=[2.5, 1.8, 1.4, 1.1, 0.9, 0.75, 0.65, 0.58, 0.52, 0.48],
    val_loss=[2.6, 1.9, 1.5, 1.2, 1.0, 0.85, 0.78, 0.72, 0.68, 0.66],
    em=[15, 25, 32, 38, 42, 44, 45.5, 46.2, 46.5, 46.8],
    f1=[22, 32, 40, 46, 50, 52, 53.5, 54.2, 54.5, 54.8],
)

plot_training_curves(sample_history)
plt.show()

## 7. Corpus Size Scaling

In [None]:
# Corpus size vs. performance / efficiency: EM, selection latency and storage
# recorded at five corpus sizes.
scaling_data = dict(
    corpus_size=[1000, 5000, 10000, 50000, 100000],
    EM=[42.5, 44.2, 45.8, 46.5, 46.8],
    selection_latency_ms=[2.1, 3.5, 5.2, 15.8, 32.4],
    storage_mb=[4, 20, 40, 200, 400],
)

df_scaling = pd.DataFrame(scaling_data)

# One row of three panels sharing the corpus size as x axis.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
ax1, ax2, ax3 = axes

# Panel 1: EM, log-scaled x axis.
ax1.semilogx(df_scaling["corpus_size"], df_scaling["EM"], "o-", color="#3498db")
ax1.set(xlabel="Corpus Size", ylabel="EM (%)", title="Performance Scaling")
ax1.grid(True, alpha=0.3)

# Panel 2: selection latency, log-scaled x axis.
ax2.semilogx(df_scaling["corpus_size"], df_scaling["selection_latency_ms"], "o-", color="#e74c3c")
ax2.set(xlabel="Corpus Size", ylabel="Selection Latency (ms)", title="Latency Scaling")
ax2.grid(True, alpha=0.3)

# Panel 3: storage, both axes log-scaled.
ax3.loglog(df_scaling["corpus_size"], df_scaling["storage_mb"], "o-", color="#2ecc71")
ax3.set(xlabel="Corpus Size", ylabel="Storage (MB)", title="Storage Scaling")
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 8. 결과 저장

In [None]:
def _json_default(value):
    """Convert NumPy scalars and arrays to plain Python values for json.dump."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def save_analysis_results(results: dict, figures_dir: Path):
    """
    Write numeric analysis results to ``analysis_results.json``.

    Args:
        results: JSON-serializable mapping; NumPy scalars and arrays inside
            it are converted to plain Python numbers and lists.
        figures_dir: output directory (Path or str); created together with
            any missing parents.

    Raises:
        TypeError: if ``results`` holds a value that is neither
            JSON-serializable nor a NumPy scalar / array.
    """
    figures_dir = Path(figures_dir)
    figures_dir.mkdir(parents=True, exist_ok=True)

    # Explicit UTF-8 plus ensure_ascii=False keeps non-ASCII text (e.g. Korean
    # descriptions) readable in the file and independent of the platform's
    # default encoding.
    with open(figures_dir / "analysis_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False, default=_json_default)

    print(f"Results saved to {figures_dir}")

# 저장
# results = {
#     "error_distribution": error_stats,
#     "multihop_errors": multihop_errors,
#     "scaling": scaling_data,
# }
# save_analysis_results(results, PROJECT_ROOT / "results" / "analysis")