# 03 - Component Ablation Study

Objective: Quantify the contribution of each reward component and validate correlation with human judgment.

## Expected Correlations (with human scores)
- R_warp ↔ Human: ρ ~0.6-0.7
- R_epi ↔ Human: ρ ~0.5-0.6
- R_sem ↔ Human: ρ ~0.4-0.5
- R_total ↔ Human: ρ ~0.7-0.8 (best)

In [None]:
import os
import sys
sys.path.append('..')

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats
from tqdm import tqdm

from src.rewards import CompositeReward

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f'Using device: {device}')

In [None]:
# Instantiate the composite reward once on the selected device; the same
# instance scores every sample in the cells below.
reward_fn = CompositeReward(device=device)

## 1. Generate Samples for Human Annotation

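We create 100 synthetic samples spanning a range of multi-view consistency levels. The human scores used in this notebook are simulated (a linear function of the consistency level plus Gaussian noise) as a stand-in for real annotations.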

In [None]:
def generate_annotatable_samples(
    n_samples=100,
    seed=42,
    *,
    n_views=6,
    image_size=(320, 320),
):
    """Generate synthetic multi-view samples with simulated human scores.

    Each sample consists of a random base image (the "condition") and
    ``n_views`` noisy copies of it. The sample's consistency level rises
    linearly from 0.2 (first sample) to 1.0 (last sample); the per-view noise
    amplitude is ``0.1 * (1 - consistency)``, so the last sample's views equal
    its base image. The simulated human score is ``1 + 4 * consistency`` plus
    Gaussian noise (std 0.3), clipped to the range [1, 5].

    Args:
        n_samples: Number of samples to generate; must be at least 1.
        seed: Seed for both the torch global RNG (image data) and a local
            NumPy generator (score noise), so repeated calls with the same
            seed return identical results.
        n_views: Number of noisy views per sample; must be at least 1.
        image_size: ``(height, width)`` of every generated image.

    Returns:
        A tuple ``(samples, conditions, scores)`` where ``samples`` is a
        tensor of shape ``(n_samples, n_views, 3, H, W)``, ``conditions`` is a
        tensor of shape ``(n_samples, 3, H, W)`` and ``scores`` is a NumPy
        array of shape ``(n_samples,)``.

    Raises:
        ValueError: If ``n_samples`` or ``n_views`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    if n_views < 1:
        raise ValueError(f"n_views must be >= 1, got {n_views}")

    torch.manual_seed(seed)
    # The score noise is drawn from NumPy, which torch.manual_seed does not
    # touch; a local generator makes it reproducible without altering the
    # global NumPy random state.
    rng = np.random.default_rng(seed)

    H, W = image_size
    # Loop-invariant: one consistency level per sample, computed once.
    consistency_levels = np.linspace(0.2, 1.0, n_samples)

    samples = []
    conditions = []
    ground_truth = []  # Simulated human scores (1-5)

    for level in consistency_levels:
        consistency = float(level)

        base = torch.clamp(torch.randn(3, H, W) * 0.2 + 0.5, 0, 1)

        # Higher consistency -> less per-view noise around the shared base.
        noise_scale = 0.1 * (1 - consistency)
        views = torch.stack([
            torch.clamp(base + torch.randn(3, H, W) * noise_scale, 0, 1)
            for _ in range(n_views)
        ])

        samples.append(views)
        conditions.append(base)

        # Simulated human score (correlated with consistency)
        human_score = 1 + 4 * consistency + rng.normal(0, 0.3)
        ground_truth.append(np.clip(human_score, 1, 5))

    return torch.stack(samples), torch.stack(conditions), np.array(ground_truth)

# Build the 100-sample evaluation set: the stacked views, their base-image
# conditions, and the simulated human scores used as the correlation target.
samples, conditions, human_scores = generate_annotatable_samples(100)
print(f"Generated {len(samples)} samples for ablation study")

## 2. Compute All Reward Variants

In [None]:
# Score every sample with the composite reward and collect one row per sample:
# the individual component values, the total, and the matching human score.
component_keys = {
    'r_warp': 'warp',
    'r_epi': 'epipolar',
    'r_sem': 'semantic',
    'r_geom': 'geometric',
    'r_total': 'total',
}

results = []

for i in tqdm(range(len(samples)), desc="Computing rewards"):
    # Slice (rather than index) so the leading batch dimension of size 1 is kept.
    sample = samples[i:i+1].to(device)
    condition = conditions[i:i+1].to(device)

    # Pure evaluation: no gradients are needed.
    with torch.no_grad():
        result = reward_fn.forward(sample, condition, step=5000)

    row = {'sample_id': i}
    for column, key in component_keys.items():
        row[column] = result[key].cpu().item()
    row['human_score'] = human_scores[i]
    results.append(row)

df = pd.DataFrame(results)

## 3. Correlation Analysis

In [None]:
# Spearman rank correlation of each reward column against the human scores.
# The results are kept in `correlations` (keyed by column name) for later cells.
reward_cols = ['r_warp', 'r_epi', 'r_sem', 'r_geom', 'r_total']

print("\nSpearman Correlations with Human Scores:")
print("="*50)

correlations = {}
for col in reward_cols:
    rho, p_value = stats.spearmanr(df[col], df['human_score'])
    correlations[col] = dict(rho=rho, p_value=p_value)

    # Flag columns whose correlation does not exceed 0.5.
    if rho > 0.5:
        status = '✅'
    else:
        status = '⚠️'
    print(f"{col:10s}: ρ = {rho:.4f}, p = {p_value:.6f} {status}")

In [None]:
# Visualization: Reward vs Human Score
# One scatter panel per reward column, each with a least-squares trend line
# and the Spearman ρ computed earlier shown in the panel title.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# The 2x3 grid has six axes but there are only five reward columns, so the
# last axis is left out of the loop and hidden below.
for ax, col in zip(axes.flat[:-1], reward_cols):
    ax.scatter(df[col], df['human_score'], alpha=0.6, s=30)

    # Fit line (degree-1 polynomial: human_score ~ reward)
    z = np.polyfit(df[col], df['human_score'], 1)
    p = np.poly1d(z)
    x_line = np.linspace(df[col].min(), df[col].max(), 100)
    ax.plot(x_line, p(x_line), 'r--', alpha=0.8)

    rho = correlations[col]['rho']
    ax.set_xlabel(col)
    ax.set_ylabel('Human Score')
    ax.set_title(f'{col} (ρ = {rho:.3f})')
    ax.grid(True, alpha=0.3)

# Hide unused subplot
axes.flat[-1].axis('off')

plt.suptitle('Reward Components vs Human Judgment', fontsize=14)
plt.tight_layout()
# savefig does not create missing directories; make sure the output folder
# exists so a fresh checkout does not fail at the very end of the cell.
os.makedirs('../results', exist_ok=True)
plt.savefig('../results/ablation_correlations.png', dpi=150, bbox_inches='tight')
plt.show()

## 4. Ablation: Removing Components

In [None]:
# Compute ablated rewards
# Each variant drops one component by multiplying the remaining two
# (or, for the semantic ablation, by reusing the geometric reward directly).
df['r_no_warp'] = df['r_epi'] * df['r_sem']  # Without depth warping
df['r_no_epi'] = df['r_warp'] * df['r_sem']  # Without epipolar
df['r_no_sem'] = df['r_geom']  # Without semantic (pure geometric)

ablation_cols = ['r_total', 'r_no_warp', 'r_no_epi', 'r_no_sem']

print("\nAblation Study - Correlation with Human Scores:")
print("="*50)

# Spearman ρ of the full reward and of each ablated variant vs. human scores.
ablation_results = []
for col in ablation_cols:
    rho, p_value = stats.spearmanr(df[col], df['human_score'])
    print(f"{col:15s}: ρ = {rho:.4f}")
    ablation_results.append({'variant': col, 'correlation': rho})

# Bar chart: full reward in green, ablated variants in coral.
plt.figure(figsize=(10, 5))
ablation_df = pd.DataFrame(ablation_results)
colors = ['green' if x == 'r_total' else 'coral' for x in ablation_df['variant']]
plt.bar(ablation_df['variant'], ablation_df['correlation'], color=colors, edgecolor='black')
plt.ylabel('Spearman ρ with Human Score')
plt.title('Ablation Study: Impact of Removing Reward Components')
plt.axhline(y=0.65, color='red', linestyle='--', label='Target ρ > 0.65')
plt.legend()
plt.tight_layout()
# savefig does not create missing directories; make sure the output folder
# exists before writing the figure.
os.makedirs('../results', exist_ok=True)
plt.savefig('../results/ablation_bar_chart.png', dpi=150, bbox_inches='tight')
plt.show()

## 5. Top and Bottom Samples

In [None]:
# Show, for each reward column, the five highest- and the five lowest-scoring
# samples side by side with their human scores.
print("\nTop 5 samples per reward component:")
for col in ['r_warp', 'r_epi', 'r_sem', 'r_total']:
    print(f"\n{col.upper()}:")
    shown = ['sample_id', col, 'human_score']
    top5 = df[shown].nlargest(5, col)
    display(top5)

print("\nBottom 5 samples per reward component:")
for col in ['r_warp', 'r_epi', 'r_sem', 'r_total']:
    print(f"\n{col.upper()}:")
    shown = ['sample_id', col, 'human_score']
    bottom5 = df[shown].nsmallest(5, col)
    display(bottom5)

## 6. Summary Table

In [None]:
# Per-component summary: Spearman ρ and p-value against the human scores,
# mean and standard deviation of the reward, and whether ρ exceeds 0.5.
summary = []
for col in reward_cols:
    rho, p = stats.spearmanr(df[col], df['human_score'])
    summary.append({
        'Component': col,
        'Spearman ρ': f'{rho:.4f}',
        'p-value': f'{p:.6f}',
        'Mean': f'{df[col].mean():.4f}',
        'Std': f'{df[col].std():.4f}',
        'Meets Target': '✅' if rho > 0.5 else '❌'
    })

summary_df = pd.DataFrame(summary)
display(summary_df)

# Save
# to_csv does not create missing directories; make sure the output folder
# exists so the per-sample results are not lost after the full run.
os.makedirs('../results', exist_ok=True)
df.to_csv('../results/ablation_study_results.csv', index=False)