# Model Evaluation & Analysis
## Đánh giá chi tiết các VLM models

In [None]:
# Setup
import sys
# Make the parent directory importable so the `src.*` imports below resolve;
# this presumes the notebook is run from a directory one level below the
# repository root (the other '../' paths in this notebook assume the same).
sys.path.append('../')

import torch
import yaml
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Project-local helpers used by the cells below.
from src.models.model_registry import build_model
from src.data.wad_dataset import build_dataset
from src.evaluation.evaluator import VLMEvaluator
from src.utils.visualization import plot_model_comparison

# Seaborn 'whitegrid' theme for all plots drawn in this notebook.
sns.set_style('whitegrid')
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

## 1. Load Trained Model

In [None]:
# Load the experiment config (YAML). Like the rest of this notebook, the path
# is relative to the notebook's directory, presumably one level below the
# repository root.
config_path = '../configs/llava_config.yaml'

# Decode explicitly as UTF-8 so the result does not depend on the OS locale
# (the platform default encoding differs, e.g. cp1252 on Windows).
with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

print("Config loaded:")
print(f"  Model: {config['model']['name']}")
print(f"  Output dir: {config['training']['output_dir']}")

In [None]:
# Build the model wrapper (exposes .model, .processor, .tokenizer) from config.
print("Loading model...")
vlm = build_model(config)

# Look for a trained PEFT adapter saved at <output_dir>/final_model.
# NOTE(review): output_dir is presumably relative to the repository root
# (experiments are launched from there via run_experiments.py), while this
# notebook runs one directory below it. So also try '../<output_dir>' before
# falling back to the base model. The path exactly as written in the config
# still takes priority.
output_dir = Path(config['training']['output_dir'])
candidate_paths = [output_dir / 'final_model']
if not output_dir.is_absolute():
    candidate_paths.append(Path('..') / output_dir / 'final_model')
checkpoint_path = next(
    (path for path in candidate_paths if path.exists()),
    candidate_paths[0],
)

if checkpoint_path.exists():
    # Imported lazily: peft is only needed when an adapter is present.
    from peft import PeftModel
    vlm.model = PeftModel.from_pretrained(vlm.model, str(checkpoint_path))
    print(f"✓ Loaded checkpoint from {checkpoint_path}")
else:
    print(" No checkpoint found, using base model")

# Switch to inference mode (disables dropout etc.).
vlm.model.eval()
print("✓ Model ready for evaluation")

## 2. Load Evaluation Dataset

In [None]:
# Build both splits from the config; only the evaluation split is used below.
print("Loading evaluation dataset...")
dataset_splits = build_dataset(config, vlm.processor, vlm.tokenizer)
train_dataset, eval_dataset = dataset_splits

eval_size = len(eval_dataset)
print(f"✓ Evaluation dataset: {eval_size} samples")

## 3. Run Evaluation

In [None]:
# Wrap the loaded model/tokenizer/processor in the project evaluator.
evaluator = VLMEvaluator(
    model=vlm.model,
    tokenizer=vlm.tokenizer,
    processor=vlm.processor,
    config=config,
)

# Score the whole evaluation split -- this is the slow step.
results = evaluator.evaluate(eval_dataset)

# Report each metric, formatted as a percentage with two decimals.
print("\nEvaluation Results:")
for metric_name, metric_value in results.items():
    line = f"  {metric_name}: {metric_value:.2f}%"
    print(line)

## 4. Visualize Results

In [None]:
# Plot every metric in `results` as a horizontal bar on a 0-100 % scale.
metrics_df = pd.DataFrame([results]).T
metrics_df.columns = ['Score (%)']

# Create the figure/axes explicitly and pass the axes to pandas:
# DataFrame.plot() without `ax=` opens a brand-new figure, which would
# ignore the requested figsize and leave an empty extra figure behind.
fig, ax = plt.subplots(figsize=(10, 6))
metrics_df.plot(kind='barh', legend=False, ax=ax)
ax.set_xlabel('Score (%)', fontsize=12)
ax.set_ylabel('Metric', fontsize=12)
ax.set_title('Model Evaluation Metrics', fontsize=14, fontweight='bold')
ax.set_xlim(0, 100)
ax.grid(axis='x', alpha=0.3)
fig.tight_layout()

# savefig does not create missing directories, so make sure the target exists.
metrics_png = Path('../experiments/results/evaluation_metrics.png')
metrics_png.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(metrics_png, dpi=300, bbox_inches='tight')
plt.show()

## 5. Qualitative Analysis (Sample Predictions)

In [None]:
# Qualitative check: greedily generate answers for the first few eval samples
# and print each one next to its ground-truth answer.
num_samples = 5
device = config['hardware']['device']

print("Generating sample predictions...\n")

# min() guards against an eval split with fewer than num_samples items.
for idx in range(min(num_samples, len(eval_dataset))):
    sample = eval_dataset[idx]
    labels = sample['labels']
    input_ids = sample['input_ids']
    attention_mask = sample['attention_mask']

    # NOTE(review): samples look training-formatted -- the answer is the set of
    # positions where labels != -100 (see the ground-truth extraction below).
    # When labels are position-aligned with input_ids, those answer tokens are
    # also present in input_ids. Cut the input at the first supervised position
    # so the model is not shown the answer it is asked to produce.
    supervised_positions = (labels != -100).nonzero(as_tuple=True)[0]
    if labels.shape == input_ids.shape and supervised_positions.numel() > 0:
        prompt_len = int(supervised_positions[0])
        input_ids = input_ids[:prompt_len]
        attention_mask = attention_mask[:prompt_len]

    # Prepare a batch of one.
    inputs = {
        'input_ids': input_ids.unsqueeze(0).to(device),
        'attention_mask': attention_mask.unsqueeze(0).to(device),
        'pixel_values': sample['pixel_values'].unsqueeze(0).to(device),
    }

    if 'image_sizes' in sample:
        # One (height, width) tuple per image, whether stored as (2,) or (n, 2).
        inputs['image_sizes'] = [
            tuple(size) for size in sample['image_sizes'].reshape(-1, 2).tolist()
        ]

    if 'image_grid_thw' in sample:
        # Qwen2-VL-style models take a 2-D (num_images, 3) grid -- not a batched
        # 3-D tensor. reshape normalizes (3,), (1, 3) or (1, 1, 3) to that form.
        inputs['image_grid_thw'] = sample['image_grid_thw'].reshape(-1, 3).to(device)

    # Greedy decoding; no gradients needed.
    with torch.no_grad():
        outputs = vlm.model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=False
        )

    # For decoder-only models generate() returns prompt + continuation;
    # decode only the newly generated tokens.
    prompt_width = inputs['input_ids'].shape[1]
    pred_text = vlm.tokenizer.decode(outputs[0][prompt_width:], skip_special_tokens=True)

    # Ground truth = the supervised (non -100) label tokens.
    gt_tokens = labels[labels != -100]
    gt_text = vlm.tokenizer.decode(gt_tokens, skip_special_tokens=True)

    print(f"{'='*80}")
    print(f"Sample {idx+1}")
    print(f"{'='*80}")
    print(f"Ground Truth:\n{gt_text}\n")
    print(f"Prediction:\n{pred_text}\n")

## 6. Compare Multiple Models (if available)

In [None]:
# Optional: compare several models if a combined results file has been
# written (presumably by run_experiments.py; maps model name -> metrics dict,
# which is why the DataFrame is transposed below to get one row per model).
comparison_file = '../experiments/comparison.json'
comparison_png = Path('../experiments/results/model_comparison.png')

if Path(comparison_file).exists():
    # JSON is UTF-8 by spec; don't rely on the platform default encoding.
    with open(comparison_file, 'r', encoding='utf-8') as f:
        all_results = json.load(f)

    # Ensure the output directory exists before the helper saves the plot.
    comparison_png.parent.mkdir(parents=True, exist_ok=True)
    plot_model_comparison(all_results, str(comparison_png))

    print("\nModel Comparison:")
    df = pd.DataFrame(all_results).T
    print(df.to_string())
else:
    print(" No comparison results found. Run: python run_experiments.py --configs configs/*.yaml")

## 7. Error Analysis

In [None]:
# Keep only the per-field accuracy metrics and list them weakest-first,
# so the fields most in need of work appear at the top.
field_accuracies = {
    metric_name: value
    for metric_name, value in results.items()
    if 'accuracy' in metric_name
}
sorted_fields = sorted(field_accuracies.items(), key=lambda item: item[1])

print("Field Accuracy (lowest to highest):")
for field_name, accuracy in sorted_fields:
    print(f"  {field_name}: {accuracy:.2f}%")

print("\n Focus improvement on lowest accuracy fields!")