# 01 - Sanity Check: Reward Function Validation

Objective: Verify that reward components correctly differentiate between "good" (geometrically consistent) and "bad" (Janus problem, texture drift) multi-view samples.

## Expected Results
- R_warp: Good > Bad (p < 0.01)
- R_epi: Good > Bad (p < 0.01)
- R_sem: Good > Bad (p < 0.05)
- R_total: Good > Bad (p < 0.001)

In [None]:
# Install dependencies if needed
# !pip install torch torchvision transformers kornia scipy matplotlib tqdm pandas

In [None]:
import sys
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from tqdm import tqdm

from src.rewards import (
    DepthWarpingReward,
    EpipolarReward,
    SemanticReward,
    CompositeReward
)

# Select the compute device: use CUDA when PyTorch can see a GPU, else CPU.
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')
print(f'Using device: {device}')

## 1. Initialize Reward Functions

In [None]:
# Instantiate each reward component on the selected device, in the order
# depth-warping, epipolar, semantic, composite.
# NOTE(review): only `composite_reward` is used by the cells below. If
# CompositeReward builds its own sub-rewards internally (not visible here),
# the three standalone objects duplicate model memory — consider dropping them.
print("Loading reward models...")

depth_reward, epipolar_reward, semantic_reward, composite_reward = (
    reward_cls(device=device)
    for reward_cls in (DepthWarpingReward, EpipolarReward, SemanticReward, CompositeReward)
)

print("All models loaded successfully!")

## 2. Generate Synthetic Test Data

For sanity checking, we create:
- **Good samples**: Consistent multi-view sets (the same object appearance in every view, differing only by a small global brightness shift)
- **Bad samples**: Inconsistent views (simulating the Janus problem: every view is an independent random texture)

In [None]:
def generate_consistent_samples(n_samples=25, seed=42, *, n_views=6, size=320):
    """
    Generate 'good' samples whose views share one appearance.

    For each sample, a noisy base image with a central disk of fixed
    per-channel colour is built and clamped to [0, 1]. It is then repeated
    across ``n_views`` views that differ only by a small global intensity
    shift, simulating a lighting change.

    Args:
        n_samples: Number of multi-view samples to generate.
        seed: Seed passed to ``torch.manual_seed``. This reseeds the global
            torch RNG, so later random draws in the session are affected too.
        n_views: Number of views per sample (default 6).
        size: Height and width of every image in pixels (default 320).

    Returns:
        Tuple ``(samples, conditions)``. ``samples`` has shape
        ``(n_samples, n_views, 3, size, size)`` and ``conditions`` has shape
        ``(n_samples, 3, size, size)``. All values lie in [0, 1].
    """
    torch.manual_seed(seed)
    H = W = size

    # The disk mask and per-channel fill levels are identical for every
    # sample, so build them once outside the loop.
    y, x = torch.meshgrid(torch.linspace(-1, 1, H), torch.linspace(-1, 1, W), indexing='ij')
    mask = (torch.sqrt(x**2 + y**2) < 0.5).float()
    levels = torch.tensor([0.3 + 0.4 * c / 2 for c in range(3)]).view(3, 1, 1)
    # Brightness offsets are symmetric around zero; for 6 views this gives
    # 0.02 * (v - 2.5).
    center = (n_views - 1) / 2

    samples = []
    conditions = []
    for _ in range(n_samples):
        # Noisy background texture.
        base = torch.clamp(torch.randn(3, H, W) * 0.2 + 0.5, 0, 1)
        # Half-dim the texture inside the disk and add a per-channel colour.
        base = base * (1 - mask * 0.5) + mask * levels
        # The composite can exceed 1 (up to 1.2 in the last channel). Clamp it
        # so the condition image has the same [0, 1] range as the views and as
        # the 'bad' conditions.
        base = torch.clamp(base, 0, 1)

        views = torch.stack([
            torch.clamp(base + 0.02 * (v - center), 0, 1) for v in range(n_views)
        ])  # (n_views, 3, H, W)
        samples.append(views)
        conditions.append(base)

    return torch.stack(samples), torch.stack(conditions)


def generate_inconsistent_samples(n_samples=25, seed=123, *, n_views=6, size=320):
    """
    Generate 'bad' samples with Janus-like cross-view inconsistencies.

    Every view is an independent random texture, so no content is shared
    between views or with the condition image.

    NOTE(review): unlike generate_consistent_samples, these samples have no
    central disk and use a larger view-noise std (0.3 vs 0.2). A reward gap
    may therefore partly reflect these distribution differences rather than
    cross-view inconsistency alone.

    Args:
        n_samples: Number of multi-view samples to generate.
        seed: Seed passed to ``torch.manual_seed``. This reseeds the global
            torch RNG, so later random draws in the session are affected too.
        n_views: Number of views per sample (default 6).
        size: Height and width of every image in pixels (default 320).

    Returns:
        Tuple ``(samples, conditions)``. ``samples`` has shape
        ``(n_samples, n_views, 3, size, size)`` and ``conditions`` has shape
        ``(n_samples, 3, size, size)``. All values lie in [0, 1].
    """
    torch.manual_seed(seed)
    H = W = size

    samples = []
    conditions = []
    for _ in range(n_samples):
        # The draw order (condition first, then each view) fixes which
        # tensors a given seed produces.
        condition = torch.clamp(torch.randn(3, H, W) * 0.2 + 0.5, 0, 1)
        # Each view is an unrelated random texture (Janus problem simulation).
        views = torch.stack([
            torch.clamp(torch.randn(3, H, W) * 0.3 + 0.5, 0, 1)
            for _ in range(n_views)
        ])  # (n_views, 3, H, W)
        samples.append(views)
        conditions.append(condition)

    return torch.stack(samples), torch.stack(conditions)

In [None]:
# Build both synthetic datasets (25 samples each) using each generator's default seed.
print("Generating synthetic test data...")

good_samples, good_conditions = generate_consistent_samples(25)
bad_samples, bad_conditions = generate_inconsistent_samples(25)

print("\n".join(
    f"{name} samples shape: {tensor.shape}"
    for name, tensor in (("Good", good_samples), ("Bad", bad_samples))
))

## 3. Compute Rewards for All Samples

In [None]:
def compute_rewards_batch(samples, conditions, composite, *, step=5000):
    """
    Score every sample with the composite reward, one sample at a time.

    Args:
        samples: Tensor of multi-view samples, indexed along dim 0. Each
            sample is passed to ``composite`` as a batch of one (leading dim 1).
        conditions: Tensor of condition images aligned index-for-index with
            ``samples``.
        composite: Reward object whose ``forward(sample, condition, step=...)``
            returns a mapping with 'warp', 'epipolar', 'semantic' and 'total'
            entries convertible with ``float``.
        step: Value forwarded as ``step=`` to ``composite.forward``. Its
            meaning is defined by the reward implementation (presumably a
            training-step schedule; TODO confirm).

    Returns:
        Dict mapping each reward key to a 1-D ``np.ndarray`` with one value
        per sample.

    Raises:
        ValueError: if ``samples`` and ``conditions`` differ in length.
    """
    if len(samples) != len(conditions):
        raise ValueError(
            "samples and conditions must have the same length, "
            f"got {len(samples)} and {len(conditions)}"
        )

    keys = ('warp', 'epipolar', 'semantic', 'total')
    results = {key: [] for key in keys}

    # Scoring only: no autograd graph is needed.
    with torch.no_grad():
        for i in tqdm(range(len(samples)), desc="Computing rewards"):
            # Slicing (rather than indexing) keeps a leading batch dim of 1.
            # `device` is the notebook-level device chosen during setup.
            sample = samples[i:i + 1].to(device)
            condition = conditions[i:i + 1].to(device)
            result = composite.forward(sample, condition, step=step)
            for key in keys:
                # float() handles single-element tensors on any device as
                # well as plain Python numbers.
                results[key].append(float(result[key]))

    return {key: np.array(values) for key, values in results.items()}

In [None]:
# Score each group with the same composite reward. The banner printed before
# each pass labels the tqdm progress bar that follows.
print("Computing rewards for good samples...")
good_rewards = compute_rewards_batch(
    samples=good_samples, conditions=good_conditions, composite=composite_reward
)

print("\nComputing rewards for bad samples...")
bad_rewards = compute_rewards_batch(
    samples=bad_samples, conditions=bad_conditions, composite=composite_reward
)

## 4. Statistical Analysis

In [None]:
def perform_statistical_tests(good, bad, *, alpha=0.05, equal_var=True):
    """
    Compare good vs bad reward distributions with two-sample t-tests.

    For each reward key, a two-sided independent-samples t-test is run:
    Student's test by default, or Welch's test with ``equal_var=False``.
    Cohen's d is computed from the pooled *sample* standard deviation.

    Args:
        good, bad: Mappings from reward key ('warp', 'epipolar', 'semantic',
            'total') to 1-D sequences of per-sample reward values.
        alpha: Significance level for the 'significant' flag (default 0.05).
        equal_var: Forwarded to ``scipy.stats.ttest_ind``.

    Returns:
        Dict keyed by reward name. Each value holds:
        - group means and sample standard deviations (ddof=1);
        - the t statistic and p-value;
        - Cohen's d, which is positive when good > bad;
        - whether ``p_value < alpha``.
    """
    results = {}

    for key in ['warp', 'epipolar', 'semantic', 'total']:
        g = np.asarray(good[key], dtype=float)
        b = np.asarray(bad[key], dtype=float)
        n_g, n_b = len(g), len(b)

        # Two-sample t-test
        t_stat, p_value = stats.ttest_ind(g, b, equal_var=equal_var)

        g_mean, b_mean = g.mean(), b.mean()
        # The pooled-SD formula weights each variance by (n - 1), which is
        # only correct for the unbiased sample variance (ddof=1). Using
        # np.std's default ddof=0 would under-estimate the SD and inflate d.
        g_std, b_std = g.std(ddof=1), b.std(ddof=1)

        # Cohen's d effect size
        pooled_std = np.sqrt(
            ((n_g - 1) * g_std**2 + (n_b - 1) * b_std**2) / (n_g + n_b - 2)
        )
        cohens_d = (g_mean - b_mean) / pooled_std if pooled_std > 0 else 0.0

        results[key] = {
            'good_mean': g_mean,
            'good_std': g_std,
            'bad_mean': b_mean,
            'bad_std': b_std,
            't_stat': t_stat,
            'p_value': p_value,
            'cohens_d': cohens_d,
            # A NaN p-value (e.g. two constant groups) compares False here.
            'significant': bool(p_value < alpha),
        }

    return results

In [None]:
# Run statistical tests
stats_results = perform_statistical_tests(good_rewards, bad_rewards)

# Per-component p-value thresholds from the "Expected Results" section.
# Each hypothesis is directional (Good > Bad), so a pass also requires
# Cohen's d > 0. Unknown keys fall back to 0.05.
expected_p = {'warp': 0.01, 'epipolar': 0.01, 'semantic': 0.05, 'total': 0.001}

# Display results
print("\n" + "="*80)
print("SANITY CHECK RESULTS")
print("="*80)

for key, result in stats_results.items():
    threshold = expected_p.get(key, 0.05)
    # A NaN p-value compares False, so it is reported as a failure.
    passed = result['p_value'] < threshold and result['cohens_d'] > 0
    status = "✅ PASS" if passed else "❌ FAIL"
    print(f"\n{key.upper()} Reward:")
    print(f"  Good: {result['good_mean']:.4f} ± {result['good_std']:.4f}")
    print(f"  Bad:  {result['bad_mean']:.4f} ± {result['bad_std']:.4f}")
    print(f"  p-value: {result['p_value']:.6f} (required < {threshold})")
    print(f"  Cohen's d: {result['cohens_d']:.3f}")
    print(f"  Status: {status}")

## 5. Visualization: Box Plots

In [None]:
from pathlib import Path

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

reward_names = ['warp', 'epipolar', 'semantic', 'total']
titles = ['R_warp (Depth Warping)', 'R_epi (Epipolar)', 'R_sem (Semantic)', 'R_total (Combined)']

for ax, key, title in zip(axes.flat, reward_names, titles):
    data = [good_rewards[key], bad_rewards[key]]
    # Tick labels are set explicitly: boxplot's `labels=` argument was renamed
    # to `tick_labels` in Matplotlib 3.9 and is deprecated there. Boxes sit at
    # x = 1, 2 by default.
    bp = ax.boxplot(data, patch_artist=True)
    ax.set_xticks([1, 2])
    ax.set_xticklabels(['Good', 'Bad'])

    # Color boxes
    bp['boxes'][0].set_facecolor('lightgreen')
    bp['boxes'][1].set_facecolor('lightcoral')

    ax.set_title(title)
    ax.set_ylabel('Reward Value')
    ax.grid(True, alpha=0.3)

    # Add p-value annotation. The `.3g` format keeps very small p-values
    # readable (e.g. 1.2e-05) instead of rounding them to 0.0000.
    p = stats_results[key]['p_value']
    ax.annotate(f'p = {p:.3g}', xy=(0.95, 0.95), xycoords='axes fraction',
                ha='right', va='top', fontsize=10,
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.suptitle('Sanity Check: Reward Distribution for Good vs Bad Samples', fontsize=14)
plt.tight_layout()
# Create the output directory first; savefig does not create missing folders.
results_dir = Path('../results')
results_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(results_dir / 'sanity_check_boxplots.png', dpi=150, bbox_inches='tight')
plt.show()

## 6. Summary Table

In [None]:
import pandas as pd

# Create summary table.
# Verdict legend:
#   ✅  significant, Good > Bad, and a large effect (Cohen's d > 0.8)
#   ⚠️  significant and Good > Bad, but a smaller effect
#   ❌  not significant, or significant in the wrong direction (Bad > Good)
summary_data = []
for key in reward_names:
    r = stats_results[key]
    right_direction = r['significant'] and r['cohens_d'] > 0
    if right_direction and r['cohens_d'] > 0.8:
        verdict = '✅'
    elif right_direction:
        verdict = '⚠️'
    else:
        verdict = '❌'
    summary_data.append({
        'Reward': key.capitalize(),
        'Good (mean ± std)': f"{r['good_mean']:.4f} ± {r['good_std']:.4f}",
        'Bad (mean ± std)': f"{r['bad_mean']:.4f} ± {r['bad_std']:.4f}",
        't-statistic': f"{r['t_stat']:.3f}",
        'p-value': f"{r['p_value']:.6f}",
        "Cohen's d": f"{r['cohens_d']:.3f}",
        'Verdict': verdict
    })

summary_df = pd.DataFrame(summary_data)
print("\nSummary Table:")
# `display` is the IPython/Jupyter rich-output helper; outside a notebook,
# use print(summary_df.to_string()) instead.
display(summary_df)

## 7. Conclusions

**Criteria for proceeding to GRPO training:**
- ✅ All reward components show statistically significant difference (p < 0.01)
- ✅ Effect size (Cohen's d) > 0.8 for at least R_total

If all criteria are met, the reward functions are validated for GRPO training.