# DirectRandomResizedCrop vs RandomResizedCrop: Distribution Comparison

Compare crop parameter distributions between:
- **Rejection sampling** (torchvision-compatible, 10-attempt loop)
- **Direct/analytic** (single-pass, no loop, guaranteed valid)

We generate 50,000 crop parameter sets from each method across several image sizes and parameter settings, then compare the marginal distributions using overlaid histograms and two-sample Kolmogorov–Smirnov (KS) tests.

In [None]:
import math
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

from slipstream.decoders.numba_decoder import (
    _generate_random_crop_params_batch,
    _generate_direct_random_crop_params_batch,
)

In [None]:
def generate_params(func, width, height, n_samples, scale, ratio, seed=42):
    """Run one batched crop sampler on ``n_samples`` copies of a single image size.

    Args:
        func: Batched sampler, called positionally as
            ``func(widths, heights, scale_min, scale_max, log_ratio_min, log_ratio_max, seed)``;
            its result is indexed as an ``[N, 4]`` array of ``(x, y, crop_w, crop_h)``.
        width, height: Image size in pixels, broadcast into int32 arrays of
            length ``n_samples``.
        scale: ``(min, max)`` area fraction, forwarded unchanged.
        ratio: ``(min, max)`` aspect ratio (w/h), forwarded as natural logs.
        seed: Seed forwarded to ``func``.

    Returns:
        Dict of float arrays keyed 'x', 'y', 'crop_w', 'crop_h',
        'aspect_ratio' (crop_w / crop_h, with crop_h clamped to at least 1 so a
        zero height cannot divide by zero) and 'scale' (crop area / image area).
    """
    size_w = np.full(n_samples, width, dtype=np.int32)
    size_h = np.full(n_samples, height, dtype=np.int32)
    lo_log, hi_log = math.log(ratio[0]), math.log(ratio[1])

    raw = func(size_w, size_h, scale[0], scale[1], lo_log, hi_log, seed)

    # Convert once, then take the four leading columns: x, y, crop_w, crop_h.
    as_float = raw.astype(float)
    x_col, y_col, w_col, h_col = (as_float[:, col] for col in range(4))

    return {
        'x': x_col,
        'y': y_col,
        'crop_w': w_col,
        'crop_h': h_col,
        'aspect_ratio': w_col / np.maximum(h_col, 1),
        'scale': (w_col * h_col) / (width * height),
    }

In [None]:
def compare_distributions(width, height, scale, ratio, n_samples=50000, title_suffix="",
                          rej_seed=42, direct_seed=123):
    """Plot and tabulate crop-parameter distributions from both samplers.

    Draws ``n_samples`` crop boxes for a ``width`` x ``height`` image from the
    rejection sampler and from the direct sampler. It overlays density-normalized
    histograms of six derived quantities on shared bins and annotates each panel
    with a two-sample Kolmogorov-Smirnov statistic and p-value. It then prints a
    table of per-quantity means and standard deviations.

    Args:
        width, height: Image size in pixels.
        scale: ``(min, max)`` area fraction passed to both samplers.
        ratio: ``(min, max)`` aspect ratio (w/h) passed to both samplers.
        n_samples: Number of crop boxes drawn per method.
        title_suffix: Extra text appended to the figure title.
        rej_seed: Seed for the rejection sampler (default matches the
            previously hard-coded value).
        direct_seed: Seed for the direct sampler (default matches the
            previously hard-coded value).

    Returns:
        None. Output is a displayed matplotlib figure and a printed table.
    """
    rej = generate_params(_generate_random_crop_params_batch, width, height, n_samples,
                          scale, ratio, seed=rej_seed)
    direct = generate_params(_generate_direct_random_crop_params_batch, width, height,
                             n_samples, scale, ratio, seed=direct_seed)

    keys = ['crop_w', 'crop_h', 'aspect_ratio', 'scale', 'x', 'y']
    labels = ['Crop Width', 'Crop Height', 'Aspect Ratio (w/h)', 'Scale (area fraction)',
              'X Position', 'Y Position']

    fig, axes = plt.subplots(2, 3, figsize=(14, 8))
    fig.suptitle(f'Image {width}×{height}, scale={scale}, ratio={ratio}{title_suffix}',
                 fontsize=13)

    for ax, key, label in zip(axes.flat, keys, labels):
        r_vals = rej[key]
        d_vals = direct[key]

        # Shared bin edges so the two histograms are directly comparable.
        lo = min(r_vals.min(), d_vals.min())
        hi = max(r_vals.max(), d_vals.max())
        if hi <= lo:
            # Every sample has the same value: widen the range so the bins have
            # non-zero width (density=True divides each count by its bin width).
            lo, hi = lo - 0.5, hi + 0.5
        bins = np.linspace(lo, hi, 60)

        ax.hist(r_vals, bins=bins, alpha=0.5, density=True, label='Rejection')
        ax.hist(d_vals, bins=bins, alpha=0.5, density=True, label='Direct')
        ax.set_title(label, fontsize=10)
        ax.legend(fontsize=8)

        # Two-sample KS test. ks_2samp assumes continuous data; these values are
        # integers or ratios of integers, so ties make the p-value approximate.
        ks_stat, ks_p = stats.ks_2samp(r_vals, d_vals)
        ax.text(0.98, 0.95, f'KS={ks_stat:.3f}\np={ks_p:.2e}',
                transform=ax.transAxes, fontsize=7, ha='right', va='top',
                bbox=dict(boxstyle='round,pad=0.3', facecolor='wheat', alpha=0.5))

    plt.tight_layout()
    plt.show()

    # Summary table. Size the label column to the longest label so every row
    # lines up, and size the separator to the header actually printed.
    label_w = max(len(label) for label in labels)
    header = (f"{'Metric':<{label_w}} {'Rejection mean':>15} {'Direct mean':>15} "
              f"{'Rej std':>10} {'Dir std':>10}")
    print(f"\n{header}")
    print('-' * len(header))
    for key, label in zip(keys, labels):
        r_vals = rej[key]
        d_vals = direct[key]
        print(f"{label:<{label_w}} {r_vals.mean():>15.3f} {d_vals.mean():>15.3f} "
              f"{r_vals.std():>10.3f} {d_vals.std():>10.3f}")

## 1. Default parameters, square image (256×256)

In [None]:
# Baseline: the default scale/ratio ranges on a square 256x256 image.
compare_distributions(width=256, height=256, scale=(0.08, 1.0), ratio=(3/4, 4/3))

## 2. Extreme aspect ratio range

In [None]:
# Wider aspect-ratio range: from 1:2 up to 2:1 (w/h).
compare_distributions(width=256, height=256, scale=(0.08, 1.0), ratio=(0.5, 2.0))

## 3. Non-square image (512×256)

In [None]:
# Non-square image: its own aspect ratio (512/256 = 2.0) lies outside the ratio range.
compare_distributions(width=512, height=256, scale=(0.08, 1.0), ratio=(3/4, 4/3))

## 4. Small scale range

In [None]:
# Small crops only: between 1% and 10% of the image area.
compare_distributions(width=256, height=256, scale=(0.01, 0.1), ratio=(3/4, 4/3))

## 5. Fallback rate comparison

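Count how often each method produces the center-crop fallback box. Both counts should be near zero for typical ImageNet parameters. The check matches on output coordinates alone, so a genuinely sampled crop that happens to coincide with the fallback box (e.g. the full image for a square input) is counted as well.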

In [None]:
def count_center_crop_fallbacks(params_dict, width, height):
    """Return how many samples equal the assumed center-crop fallback box.

    The box checked is a ``min(width, height)`` square, centered with
    floor-divided offsets. Matching uses output coordinates only, so a genuinely
    sampled crop that lands exactly on this box is counted too.

    NOTE(review): this assumes the fallback is always that centered square.
    torchvision's fallback instead clamps to the ratio bound when the image's
    own aspect ratio lies outside ``ratio`` (e.g. 341x256 for a 512x256 image
    with ratio=(3/4, 4/3)). Confirm which fallback the samplers use before
    trusting counts for non-square images.

    Args:
        params_dict: Mapping with array values under 'crop_w', 'crop_h', 'x'
            and 'y' (as produced by ``generate_params``).
        width, height: Image size in pixels.

    Returns:
        Number of matching samples (the sum of a boolean mask).
    """
    side = min(width, height)
    x0 = (width - side) // 2
    y0 = (height - side) // 2

    width_ok = params_dict['crop_w'] == side
    height_ok = params_dict['crop_h'] == side
    x_ok = params_dict['x'] == x0
    y_ok = params_dict['y'] == y0

    return (width_ok & height_ok & x_ok & y_ok).sum()

# (width, height, scale range, ratio range, display name) for each experiment.
configs = [
    (256, 256, (0.08, 1.0), (3/4, 4/3), 'Default 256×256'),
    (512, 256, (0.08, 1.0), (3/4, 4/3), 'Non-square 512×256'),
    (256, 256, (0.08, 1.0), (0.5, 2.0), 'Wide ratio 256×256'),
    (256, 256, (0.01, 0.1), (3/4, 4/3), 'Small scale 256×256'),
]

n = 50000
header = f"{'Config':<25} {'Rejection fallbacks':>20} {'Direct fallbacks':>20}"
print(header)
print('-' * len(header))
for w, h, scale, ratio, name in configs:
    # Same seeds as compare_distributions' defaults, so these are the same samples.
    rej = generate_params(_generate_random_crop_params_batch, w, h, n, scale, ratio, seed=42)
    direct = generate_params(_generate_direct_random_crop_params_batch, w, h, n, scale, ratio, seed=123)
    rej_fb = count_center_crop_fallbacks(rej, w, h)
    dir_fb = count_center_crop_fallbacks(direct, w, h)
    # Build each "count (pct%)" cell as one string so it right-aligns under its
    # 20-wide header regardless of how many digits the count or percentage has.
    rej_cell = f"{rej_fb} ({rej_fb / n:.2%})"
    dir_cell = f"{dir_fb} ({dir_fb / n:.2%})"
    print(f"{name:<25} {rej_cell:>20} {dir_cell:>20}")