# 02 - Reward Distribution Analysis

Objective: Understand reward statistics on a validation set of multi-view samples.

## Expected Distributions
- R_warp: Mean ~0.7, Std ~0.15
- R_epi: Mean ~0.8, Std ~0.12
- R_sem: Mean ~0.85, Std ~0.10

In [None]:
import os
import sys
sys.path.append('..')

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

from src.rewards import CompositeReward

# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

In [None]:
# Build the composite reward used by the reward-computation cell, passing it
# the torch device chosen during setup.
reward_fn = CompositeReward(device=device)

## 1. Generate Validation Set

Generate 500 synthetic multi-view samples (6 views each) whose cross-view consistency is drawn uniformly from 0.3 to 1.0.

In [None]:
def generate_varied_samples(n_samples=500, seed=42, height=320, width=320, n_views=6):
    """
    Generate synthetic multi-view samples with varying degrees of consistency.

    Each sample draws one clamped Gaussian-noise "base" image and a
    consistency weight in [0.3, 1.0). Every view is the blend
    ``consistency * base + (1 - consistency) * noise`` with fresh noise, so
    a higher weight yields views that are closer to the base image.

    Args:
        n_samples: Number of samples to generate.
        seed: Seed for both the torch noise and the consistency weights.
        height: Image height in pixels.
        width: Image width in pixels.
        n_views: Number of views per sample (must be at least 1).

    Returns:
        Tuple ``(samples, conditions)``. ``samples`` has shape
        ``(n_samples, n_views, 3, height, width)`` and ``conditions`` (the
        base images) has shape ``(n_samples, 3, height, width)``. All values
        lie in [0, 1].

    Raises:
        ValueError: If ``n_views`` is smaller than 1.
    """
    if n_views < 1:
        raise ValueError("n_views must be at least 1")

    torch.manual_seed(seed)
    # torch.manual_seed does not seed NumPy. Drawing the consistency level
    # from the global np.random state would make two runs with the same seed
    # produce different samples, so use a dedicated seeded generator.
    rng = np.random.default_rng(seed)

    def _noise_image():
        # Gaussian noise centred on mid-grey, clipped to the valid range.
        return torch.clamp(torch.randn(3, height, width) * 0.2 + 0.5, 0, 1)

    samples = []
    conditions = []

    for _ in range(n_samples):
        # Create base pattern
        base = _noise_image()

        # Varying consistency level (0 = inconsistent, 1 = fully consistent).
        # Cast to a plain float so the tensor arithmetic below never sees a
        # NumPy scalar on the left-hand side.
        consistency = float(rng.uniform(0.3, 1.0))

        views = []
        for _ in range(n_views):
            # Mix of base and random pattern
            random_pattern = _noise_image()
            view = consistency * base + (1 - consistency) * random_pattern
            views.append(torch.clamp(view, 0, 1))

        samples.append(torch.stack(views))
        conditions.append(base)

    if not samples:
        # torch.stack rejects an empty list; return correctly shaped empties.
        return (
            torch.empty(0, n_views, 3, height, width),
            torch.empty(0, 3, height, width),
        )

    return torch.stack(samples), torch.stack(conditions)

# Materialize the validation set that the reward-computation cell iterates over.
print("Generating 500 validation samples...")
samples, conditions = generate_varied_samples(n_samples=500, seed=42)
print(f"Samples shape: {samples.shape}")

## 2. Compute Rewards

In [None]:
# Output column name -> key in the dict returned by reward_fn.forward.
reward_keys = {
    'r_warp': 'warp',
    'r_epi': 'epipolar',
    'r_sem': 'semantic',
    'r_total': 'total',
    'r_geom': 'geometric',
}

results = []

# One forward call per sample; the i:i+1 slice keeps a leading batch axis of 1.
for i in tqdm(range(len(samples)), desc="Computing rewards"):
    sample = samples[i:i+1].to(device)
    condition = conditions[i:i+1].to(device)

    # Pure evaluation: no gradients are needed for the statistics.
    with torch.no_grad():
        result = reward_fn.forward(sample, condition, step=5000)

    # Flatten each reward tensor into a Python scalar for the DataFrame.
    row = {'sample_id': i}
    for column, key in reward_keys.items():
        row[column] = result[key].cpu().item()
    results.append(row)

df = pd.DataFrame(results)
print("\nReward statistics:")
display(df.describe())

## 3. Distribution Histograms

In [None]:
# savefig does not create missing directories, so make sure the output folder
# exists up front instead of failing after the figure has been built.
os.makedirs('../results', exist_ok=True)

# 2x2 grid: one histogram per reward component plus the combined reward.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

columns = ['r_warp', 'r_epi', 'r_sem', 'r_total']
titles = ['R_warp (Depth Warping)', 'R_epi (Epipolar)', 'R_sem (Semantic)', 'R_total (Combined)']

for ax, col, title in zip(axes.flat, columns, titles):
    col_mean = df[col].mean()
    col_median = df[col].median()

    ax.hist(df[col], bins=30, edgecolor='black', alpha=0.7)
    # Mean and median are drawn together so skew is visible at a glance.
    ax.axvline(col_mean, color='red', linestyle='--', label=f'Mean: {col_mean:.3f}')
    ax.axvline(col_median, color='green', linestyle='--', label=f'Median: {col_median:.3f}')
    ax.set_title(title)
    ax.set_xlabel('Reward Value')
    ax.set_ylabel('Frequency')
    ax.legend()
    ax.grid(True, alpha=0.3)

# Report the actual number of rows rather than a hard-coded sample count.
plt.suptitle(f'Reward Distributions ({len(df)} Samples)', fontsize=14)
plt.tight_layout()
plt.savefig('../results/reward_distributions.png', dpi=150, bbox_inches='tight')
plt.show()

## 4. Correlation Analysis

In [None]:
# Compute the pairwise correlation matrix of all reward columns
# (DataFrame.corr with its default method).
reward_cols = ['r_warp', 'r_epi', 'r_sem', 'r_total', 'r_geom']
corr_matrix = df[reward_cols].corr()

# savefig does not create missing directories, so make sure the output folder
# exists before plotting.
os.makedirs('../results', exist_ok=True)

# Plot heatmap; the colour scale is fixed to [-1, 1] and centred on 0 so the
# sign of each correlation maps to a consistent colour.
plt.figure(figsize=(8, 6))
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', center=0,
            vmin=-1, vmax=1, fmt='.3f')
plt.title('Reward Component Correlations')
plt.tight_layout()
plt.savefig('../results/reward_correlations.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nCorrelation Matrix:")
display(corr_matrix)

## 5. Scatter Plots

In [None]:
# savefig does not create missing directories, so make sure the output folder
# exists before plotting.
os.makedirs('../results', exist_ok=True)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# R_warp vs R_epi
axes[0].scatter(df['r_warp'], df['r_epi'], alpha=0.5, s=10)
axes[0].set_xlabel('R_warp')
axes[0].set_ylabel('R_epi')
axes[0].set_title('Geometric Components')

# R_geom vs R_sem
axes[1].scatter(df['r_geom'], df['r_sem'], alpha=0.5, s=10)
axes[1].set_xlabel('R_geom')
axes[1].set_ylabel('R_sem')
axes[1].set_title('Geometric vs Semantic')

# R_total distribution by R_sem quartile.
# NOTE: this adds a 'sem_quantile' column to df in place, so the column is
# also present in anything that reads or saves df afterwards.
quartile_labels = ['Q1', 'Q2', 'Q3', 'Q4']
df['sem_quantile'] = pd.qcut(df['r_sem'], q=4, labels=quartile_labels)
for q in quartile_labels:
    subset = df[df['sem_quantile'] == q]
    # Overlaid semi-transparent histograms, one per R_sem quartile.
    axes[2].hist(subset['r_total'], bins=20, alpha=0.5, label=f'{q}')
axes[2].set_xlabel('R_total')
axes[2].set_ylabel('Frequency')
axes[2].set_title('R_total by R_sem Quartile')
axes[2].legend()

plt.tight_layout()
plt.savefig('../results/reward_scatter_plots.png', dpi=150, bbox_inches='tight')
plt.show()

## 6. Outlier Analysis

In [None]:
# 5th / 95th percentile cut-offs on the combined reward; both comparisons
# below are strict, so rows exactly at a cut-off are excluded.
total_reward = df['r_total']
lower_cut = total_reward.quantile(0.05)
upper_cut = total_reward.quantile(0.95)

# Worst performers: rows below the 5th percentile.
low_total = df[total_reward < lower_cut]
print(f"\nSamples with very low R_total (bottom 5%): {len(low_total)}")
display(low_total.head(10))

# Best performers: rows above the 95th percentile.
high_total = df[total_reward > upper_cut]
print(f"\nSamples with very high R_total (top 5%): {len(high_total)}")
display(high_total.head(10))

## 7. Save Results

In [None]:
# Save to CSV. to_csv does not create missing directories, so make sure the
# output folder exists first rather than losing the computed statistics.
os.makedirs('../results', exist_ok=True)
df.to_csv('../results/reward_statistics.csv', index=False)
print("Saved reward statistics to ../results/reward_statistics.csv")