# Graph-Aware Remasking vs WINO Baseline - GSM8K Benchmark

이 노트북은 **Graph-Aware Historical Remasking** 디코더와 **WINO** baseline을 GSM8K 벤치마크에서 비교합니다.

## 목적
- WINO (baseline): Confidence 기반 remasking
- Graph-Remask (experimental): Attention 기반 responsibility + confidence remasking
- 평가 지표: Accuracy, Forward passes, Remask count

In [None]:
import os
import sys
import torch
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer

# Make modules that sit in the current working directory (normally the notebook's
# own directory under Jupyter) importable: modeling_llada, configuration_llada,
# decoding and benchmark_utils. The membership check avoids appending a duplicate
# sys.path entry every time this cell is re-run.
current_dir = os.getcwd()
if current_dir not in sys.path:
    sys.path.append(current_dir)

# Project-local imports; these rely on the sys.path update above.
from modeling_llada import LLaDAModelLM
from configuration_llada import LLaDAConfig
from decoding import decoding_wino, decoding_graph_remask
import benchmark_utils

print("Modules loaded successfully.")

## 1. Load Model

In [None]:
# Checkpoint resolution: a local copy takes priority over the HuggingFace Hub ID.
LOCAL_MODEL_PATH = "../Grok-1-LLaDA-8B"
HF_MODEL_ID = "GSAI-ML/LLaDA-8B-Base"

use_local = os.path.exists(LOCAL_MODEL_PATH)
model_path = LOCAL_MODEL_PATH if use_local else HF_MODEL_ID
print(f"Using local model: {model_path}" if use_local else f"Using HuggingFace model: {model_path}")

# Build the project-local LLaDA model class from an explicitly loaded config;
# torch_dtype="auto" lets transformers choose the dtype recorded for the checkpoint.
config = LLaDAConfig.from_pretrained(model_path)
model = LLaDAModelLM.from_pretrained(model_path, config=config, torch_dtype="auto")

# Inference only: move to the GPU when one is present, then switch to eval mode.
if torch.cuda.is_available():
    model.cuda()
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_path)
print("Model loaded successfully.")

## 2. Load GSM8K Dataset

In [None]:
# Run configuration shared by both decoders.
N_SAMPLES = 50      # how many GSM8K problems to evaluate
GEN_LENGTH = 256    # tokens generated per problem
BLOCK_LENGTH = 256  # decoding block length; equals GEN_LENGTH (presumably a single block per answer)

print(f"Loading GSM8K dataset (N={N_SAMPLES})...")
gsm8k_data = benchmark_utils.load_gsm8k(n_samples=N_SAMPLES)
print(f"Loaded {len(gsm8k_data)} samples.")

# Sanity check: show the first problem together with its reference answer.
if gsm8k_data:
    example = gsm8k_data[0]
    print("\nExample question:")
    print(example['question'])
    print(f"\nGround truth: {example['ground_truth']}")

## 3. Run Benchmark Comparison

In [None]:
# Run both decoders on every GSM8K problem and collect one record per sample.
# A decoder failure is recorded as incorrect, with steps/remasks set to -1.
results = []
mask_id = 126336  # mask token id passed to both decoders (LLaDA's [MASK]; verify against the tokenizer)


def _decode_generation(output, prompt_tokens):
    """Decode only the tokens a decoder generated, dropping an echoed prompt.

    LLaDA-style decoders typically return ``prompt + generation`` as a single
    sequence. Decoding that whole sequence would let answer extraction pick up
    numbers that appear in the question itself. The prompt prefix is therefore
    stripped whenever row 0 of ``output`` starts with it. If the output does not
    start with the prompt (e.g. a decoder that returns only the generation), the
    whole row is decoded unchanged.

    Args:
        output: Decoder output; row 0 is decoded (assumed shape (1, seq_len)).
        prompt_tokens: Encoded prompt from ``tokenizer.encode(..., return_tensors='pt')``.

    Returns:
        The decoded text with special tokens removed.
    """
    seq = output[0]
    if torch.is_tensor(seq):
        prompt = prompt_tokens[0].to(seq.device)
        n_prompt = prompt.shape[0]
        # Element-wise comparison (with dtype promotion) instead of torch.equal,
        # so differing integer dtypes between prompt and output cannot raise.
        if seq.shape[0] >= n_prompt and bool((seq[:n_prompt] == prompt).all()):
            seq = seq[n_prompt:]
    return tokenizer.decode(seq, skip_special_tokens=True)


for idx, item in enumerate(gsm8k_data):
    question = item['question']
    ground_truth = item['ground_truth']

    # Encode prompt (the raw question: no few-shot examples, no chat template).
    prompt_tokens = tokenizer.encode(question, return_tensors='pt').to(model.device)

    print(f"\n[{idx+1}/{len(gsm8k_data)}] Processing...")

    # === WINO Baseline ===
    try:
        output_wino, steps_wino = decoding_wino(
            model=model,
            prompt=prompt_tokens,
            gen_length=GEN_LENGTH,
            block_length=BLOCK_LENGTH,
            temperature=0.0,
            mask_id=mask_id,
            threshold=0.6,
            threshold_back=0.9
        )
        text_wino = _decode_generation(output_wino, prompt_tokens)
        correct_wino = benchmark_utils.check_correctness(text_wino, ground_truth, "Math")

        print(f"  WINO: {steps_wino} steps, Correct: {correct_wino}")
    except Exception as e:
        # Best-effort: one failing sample must not abort the whole benchmark run.
        print(f"  WINO failed: {type(e).__name__}: {e}")
        steps_wino = -1
        correct_wino = False
        text_wino = ""

    # === Graph-Aware Remasking ===
    try:
        output_graph, stats_graph = decoding_graph_remask(
            model=model,
            prompt=prompt_tokens,
            gen_length=GEN_LENGTH,
            block_length=BLOCK_LENGTH,
            temperature=0.0,
            mask_id=mask_id,
            threshold_forward=0.6,
            threshold_back=0.9,
            resp_threshold=0.3,
            gamma_decay=0.95,
            use_attention_layers=[-1],
            top_k_attention=10,
            max_remask_ratio=0.3
        )
        text_graph = _decode_generation(output_graph, prompt_tokens)
        correct_graph = benchmark_utils.check_correctness(text_graph, ground_truth, "Math")
        steps_graph = stats_graph['total_steps']
        remasks_graph = stats_graph['total_remasks']

        print(f"  Graph: {steps_graph} steps, {remasks_graph} remasks, Correct: {correct_graph}")
    except Exception as e:
        # Best-effort: one failing sample must not abort the whole benchmark run.
        print(f"  Graph-Remask failed: {type(e).__name__}: {e}")
        steps_graph = -1
        remasks_graph = -1
        correct_graph = False
        text_graph = ""

    # Store results (-1 in the steps/remasks columns marks a failed run).
    results.append({
        'question': question,
        'ground_truth': ground_truth,
        'wino_correct': 1 if correct_wino else 0,
        'wino_steps': steps_wino,
        'graph_correct': 1 if correct_graph else 0,
        'graph_steps': steps_graph,
        'graph_remasks': remasks_graph,
        'wino_output': text_wino,
        'graph_output': text_graph
    })

# Create DataFrame
df_results = pd.DataFrame(results)
print("\n=== Benchmark Complete ===")

## 4. Analyze Results

In [None]:
# Summary statistics.
# The benchmark loop stores failed decoder runs with steps/remasks == -1. Those
# runs count as incorrect for accuracy, but they are excluded from the
# step/remask averages so the sentinel value does not skew them.
wino_ok = df_results['wino_steps'] >= 0
graph_ok = df_results['graph_steps'] >= 0
remask_ok = df_results['graph_remasks'] >= 0

wino_acc = df_results['wino_correct'].mean()
graph_acc = df_results['graph_correct'].mean()

print("\n=== Summary Statistics ===")
print("\nWINO Baseline:")
print(f"  Accuracy: {wino_acc:.2%}")
print(f"  Avg Steps: {df_results.loc[wino_ok, 'wino_steps'].mean():.1f}")
print(f"  Failed runs: {int((~wino_ok).sum())}")

print("\nGraph-Aware Remasking:")
print(f"  Accuracy: {graph_acc:.2%}")
print(f"  Avg Steps: {df_results.loc[graph_ok, 'graph_steps'].mean():.1f}")
print(f"  Avg Remasks: {df_results.loc[remask_ok, 'graph_remasks'].mean():.1f}")
print(f"  Failed runs: {int((~graph_ok).sum())}")

print("\nImprovement:")
acc_diff = graph_acc - wino_acc
print(f"  Accuracy Δ: {acc_diff:+.2%}")
# Paired comparison: only samples on which both decoders produced a step count.
both_ok = wino_ok & graph_ok
step_diff = (df_results.loc[both_ok, 'graph_steps'] - df_results.loc[both_ok, 'wino_steps']).mean()
print(f"  Steps Δ: {step_diff:+.1f}")

## 5. Visualization

In [None]:
# Compare the decoders: accuracy, mean decoding steps, and the distribution of
# Graph-Remask remask counts. Failed runs are stored with steps/remasks == -1;
# they are excluded from the step and remask panels so the sentinel does not
# lower the averages or appear as a histogram bin.
wino_steps_ok = df_results.loc[df_results['wino_steps'] >= 0, 'wino_steps']
graph_steps_ok = df_results.loc[df_results['graph_steps'] >= 0, 'graph_steps']
remasks_ok = df_results.loc[df_results['graph_remasks'] >= 0, 'graph_remasks']

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Accuracy comparison
ax1 = axes[0]
accuracies = [df_results['wino_correct'].mean(), df_results['graph_correct'].mean()]
ax1.bar(['WINO', 'Graph-Remask'], accuracies, color=['#3498db', '#e74c3c'])
ax1.set_ylabel('Accuracy')
ax1.set_title('GSM8K Accuracy Comparison')
# Headroom above 1.0 keeps the value label of a 100% bar inside the axes.
ax1.set_ylim([0, 1.1])
for i, v in enumerate(accuracies):
    ax1.text(i, v + 0.02, f"{v:.2%}", ha='center', fontweight='bold')

# Steps comparison (successful runs only)
ax2 = axes[1]
steps = [wino_steps_ok.mean(), graph_steps_ok.mean()]
ax2.bar(['WINO', 'Graph-Remask'], steps, color=['#3498db', '#e74c3c'])
ax2.set_ylabel('Average Steps')
ax2.set_title('Decoding Steps Comparison')
for i, v in enumerate(steps):
    if pd.notna(v):  # a decoder with no successful run has no average to label
        ax2.text(i, v + 1, f"{v:.1f}", ha='center', fontweight='bold')

# Remask count (successful Graph-Remask runs only)
ax3 = axes[2]
if remasks_ok.empty:
    ax3.text(0.5, 0.5, 'No successful Graph-Remask runs', ha='center', va='center',
             transform=ax3.transAxes)
else:
    mean_remasks = remasks_ok.mean()
    ax3.hist(remasks_ok, bins=20, color='#e74c3c', alpha=0.7, edgecolor='black')
    ax3.axvline(mean_remasks, color='red', linestyle='--', linewidth=2, label=f"Mean: {mean_remasks:.1f}")
    ax3.legend()
ax3.set_xlabel('Remask Count')
ax3.set_ylabel('Frequency')
ax3.set_title('Graph-Remask: Remask Distribution')

plt.tight_layout()
plt.show()

## 6. Failure Case Analysis

In [None]:
def _show_disagreements(title, cases, max_cases=3):
    """Print how many samples are in ``cases`` and details for the first few.

    Args:
        title: Header text identifying which decoder won.
        cases: Rows of ``df_results`` to report.
        max_cases: How many rows to print in detail.
    """
    print(f"\n=== {title}: {len(cases)} cases ===")
    for idx, row in cases.head(max_cases).iterrows():
        question = row['question']
        # Truncate long questions; add an ellipsis only when text was actually cut.
        preview = question[:100] + "..." if len(question) > 100 else question
        # idx is the DataFrame index, i.e. the 0-based sample position in the run.
        print(f"\n--- Case {idx+1} ---")
        print(f"Question: {preview}")
        print(f"Ground Truth: {row['ground_truth']}")
        print(f"WINO Output: {benchmark_utils.extract_number(row['wino_output'])}")
        print(f"Graph Output: {benchmark_utils.extract_number(row['graph_output'])}")


# Find cases where WINO failed but Graph succeeded
graph_wins = df_results[(df_results['wino_correct'] == 0) & (df_results['graph_correct'] == 1)]
_show_disagreements("Graph-Remask Wins (WINO failed, Graph succeeded)", graph_wins)

# Find cases where WINO succeeded but Graph failed
wino_wins = df_results[(df_results['wino_correct'] == 1) & (df_results['graph_correct'] == 0)]
_show_disagreements("WINO Wins (Graph failed, WINO succeeded)", wino_wins)

## 7. Save Results

In [None]:
# Save to CSV
output_file = "graph_remask_vs_wino_gsm8k_results.csv"
df_results.to_csv(output_file, index=False)
print(f"\nResults saved to {output_file}")