# Tournament Selection for Streaming Global Optimization

This tutorial demonstrates tournament selection for memory-efficient global optimization
on large/streaming datasets where evaluating all candidates on all data is infeasible.

**Features demonstrated:**
- `TournamentSelector` for progressive elimination
- Tournament selection for memory-efficient global optimization
- Configuration: `elimination_rounds`, `elimination_fraction`, `batches_per_round`
- Streaming candidate processing
- Visualization of tournament progression

**Level: Advanced** | **Duration: 30 minutes**

In [None]:
# Configure matplotlib for inline plotting (MUST come before imports)
%matplotlib inline

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from nlsq import GlobalOptimizationConfig
from nlsq.global_optimization import (
    TournamentSelector,
    latin_hypercube_sample,
    scale_samples_to_bounds,
)

In [None]:
# Set random seed for reproducibility
np.random.seed(42)

## 1. The Challenge: Multi-Start on Large Datasets

Multi-start optimization evaluates many starting points to find the global optimum.
For large/streaming datasets, fully evaluating all candidates is expensive:

- **Memory:** Cannot load entire dataset at once
- **Time:** Evaluating N candidates on M data points = O(N * M)
- **Streaming:** Data arrives in batches, not all at once

**Tournament selection** solves this by progressively eliminating poor candidates
based on performance on data batches, keeping only the best for final evaluation.

## 2. Tournament Selection Algorithm

The tournament proceeds in rounds:

1. **Start:** N candidates
2. **Round 1:** Evaluate all N candidates on the first `batches_per_round` batches, then eliminate the worst `elimination_fraction * N`
3. **Round 2:** Evaluate the survivors on the next `batches_per_round` batches, then eliminate again
4. **...continue until top M candidates remain**

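This reduces evaluation cost from O(N * total_batches) to at most O(N * batches_per_round * rounds).
In practice the cost is lower still, because each round evaluates only the surviving candidates.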
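
The rounds above can be sketched in plain NumPy. This is a minimal illustration, not the `TournamentSelector` implementation: the function name `successive_elimination`, its arguments, the demo data, and the use of mean squared error as the ranking score are all assumptions made for this example. Each round scores the survivors on a few batches and drops the worst `elimination_fraction` share.

```python
import numpy as np


def successive_elimination(candidates, batches, model, rounds=3,
                           elimination_fraction=0.5, batches_per_round=10):
    """Hypothetical sketch: score survivors on batches, drop the worst each round."""
    batches = iter(batches)
    alive = np.arange(len(candidates))  # indices of surviving candidates
    for _ in range(rounds):
        losses = np.zeros(len(alive))
        for _ in range(batches_per_round):
            x, y = next(batches)
            for j, idx in enumerate(alive):
                residual = model(x, *candidates[idx]) - y
                losses[j] += np.mean(residual ** 2)  # assumed score: MSE per batch
        # Keep the best (1 - elimination_fraction) share, but never fewer than one
        n_keep = max(1, int(len(alive) * (1 - elimination_fraction)))
        alive = alive[np.argsort(losses)[:n_keep]]
    return alive  # survivor indices, best first


def sine_model(x, a, b, c):
    return a * np.sin(b * x) + c


rng = np.random.default_rng(0)
true_params = np.array([2.5, 1.8, 1.0])
# Seven random candidates plus the true parameters as candidate index 7
random_part = rng.uniform([0.5, 0.5, -2.0], [5.0, 4.0, 5.0], size=(7, 3))
demo_candidates = np.vstack([random_part, true_params])


def demo_batches(n_batches, batch_size=200):
    for _ in range(n_batches):
        x = rng.uniform(0, 4 * np.pi, batch_size)
        yield x, sine_model(x, *true_params) + 0.3 * rng.standard_normal(batch_size)


survivors = successive_elimination(
    demo_candidates, demo_batches(6), sine_model, rounds=3, batches_per_round=2
)
print("Surviving candidate indices:", survivors)
```

With 8 candidates and a fraction of 0.5 the field shrinks 8, 4, 2, 1, and the later rounds cost less because fewer candidates are scored on each batch.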

In [None]:
# Define a multimodal model for demonstration
def multimodal_model(x, a, b, c):
    """Sinusoidal model with multiple local minima.
    
    y = a * sin(b * x) + c
    """
    return a * jnp.sin(b * x) + c

In [None]:
# True parameters for synthetic data
true_a, true_b, true_c = 2.5, 1.8, 1.0

print(f"True parameters: a={true_a}, b={true_b}, c={true_c}")

## 3. Generate Candidate Starting Points

We generate candidate starting points using Latin Hypercube Sampling (LHS)
to ensure good coverage of the parameter space.
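
To make the idea concrete, here is a minimal NumPy sketch of Latin Hypercube Sampling. It is not the code behind `latin_hypercube_sample`; the helper name `lhs_unit_cube` and its signature are assumptions for illustration. Each dimension is cut into `n_samples` equal strata, one point is drawn inside every stratum, and the strata are shuffled independently per dimension.

```python
import numpy as np


def lhs_unit_cube(n_samples, n_dims, seed=0):
    """Hypothetical helper: Latin Hypercube sample in the unit cube [0, 1)^n_dims."""
    rng = np.random.default_rng(seed)
    # One uniform draw inside each of the n_samples strata, for every dimension
    strata = np.arange(n_samples)[:, None]
    points = (strata + rng.random((n_samples, n_dims))) / n_samples
    # Shuffle the strata independently in each dimension
    for d in range(n_dims):
        points[:, d] = rng.permutation(points[:, d])
    return points


unit = lhs_unit_cube(20, 3)
lower = np.array([0.5, 0.5, -2.0])
upper = np.array([5.0, 4.0, 5.0])
scaled = lower + unit * (upper - lower)  # affine map from the unit cube to the bounds
print(scaled[:3])
```

Unlike plain uniform sampling, every one of the 20 strata along each parameter axis contains exactly one candidate, so no region of any single parameter is left unexplored.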

In [None]:
# Define parameter bounds
lb = np.array([0.5, 0.5, -2.0])  # Lower bounds: a, b, c
ub = np.array([5.0, 4.0, 5.0])   # Upper bounds: a, b, c

# Number of candidate starting points
n_candidates = 20
n_params = 3

# Generate candidates using LHS
key = jax.random.PRNGKey(42)
lhs_samples = latin_hypercube_sample(n_candidates, n_params, rng_key=key)
candidates = scale_samples_to_bounds(lhs_samples, lb, ub)

print(f"Generated {n_candidates} candidates in {n_params}D parameter space")
print(f"Candidate shape: {candidates.shape}")
print(f"\nFirst 5 candidates:")
for i in range(5):
    print(f"  Candidate {i}: a={candidates[i,0]:.2f}, b={candidates[i,1]:.2f}, c={candidates[i,2]:.2f}")

## 4. Configure Tournament Selection

Tournament parameters in `GlobalOptimizationConfig`:

- `elimination_rounds`: Number of elimination rounds (default: 3)
- `elimination_fraction`: Fraction to eliminate each round (default: 0.5 = 50%)
- `batches_per_round`: Number of data batches per round (default: 50)
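
As a rough guide to how these three settings interact, the sketch below computes the survivor count after each round and the number of candidate-batch evaluations. The helper `tournament_schedule` is hypothetical, and it assumes survivors are rounded down with a floor of one candidate, which matches the progression arithmetic used in this tutorial but may differ from the library's rounding.

```python
def tournament_schedule(n_candidates, elimination_rounds,
                        elimination_fraction, batches_per_round):
    """Hypothetical helper: survivors per round and total candidate-batch evaluations."""
    survivors = [n_candidates]
    evaluations = 0
    for _ in range(elimination_rounds):
        # Every current survivor is scored on each batch of this round
        evaluations += survivors[-1] * batches_per_round
        remaining = int(survivors[-1] * (1 - elimination_fraction))
        survivors.append(max(1, remaining))
    return survivors, evaluations


for fraction in (0.25, 0.5, 0.75):
    survivors, evaluations = tournament_schedule(20, 3, fraction, 10)
    print(f"fraction={fraction}: survivors={survivors}, evaluations={evaluations}")
```

For comparison, scoring all 20 candidates on all 30 batches would take 600 candidate-batch evaluations; under these assumptions the three fractions need 460, 350, and 260.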

In [None]:
# Configure tournament parameters
config = GlobalOptimizationConfig(
    n_starts=n_candidates,
    sampler="lhs",
    elimination_rounds=3,         # 3 elimination rounds
    elimination_fraction=0.5,     # Eliminate 50% each round
    batches_per_round=10,         # Evaluate on 10 batches per round
)

print("Tournament Configuration:")
print(f"  n_starts:             {config.n_starts}")
print(f"  elimination_rounds:   {config.elimination_rounds}")
print(f"  elimination_fraction: {config.elimination_fraction}")
print(f"  batches_per_round:    {config.batches_per_round}")

# Calculate expected progression
expected_survivors = n_candidates
print(f"\nExpected tournament progression:")
print(f"  Start: {expected_survivors} candidates")
for r in range(config.elimination_rounds):
    expected_survivors = max(1, int(expected_survivors * (1 - config.elimination_fraction)))
    print(f"  After round {r+1}: {expected_survivors} survivors")

## 5. Create TournamentSelector

The `TournamentSelector` class manages the progressive elimination process.

In [None]:
# Create TournamentSelector
selector = TournamentSelector(candidates=candidates, config=config)

print(f"TournamentSelector initialized:")
print(f"  n_candidates: {selector.n_candidates}")
print(f"  n_params:     {selector.n_params}")
print(f"  n_survivors:  {selector.n_survivors}")
print(f"  current_round: {selector.current_round}")

## 6. Streaming Data Generator

In streaming scenarios, data arrives in batches. We simulate this with a generator
that yields (x_batch, y_batch) tuples.

In [None]:
def create_data_batch_generator(n_batches=100, batch_size=500, noise_level=0.3):
    """Generator that yields streaming data batches.
    
    Simulates a streaming scenario where data arrives in batches
    rather than being available all at once.
    
    Parameters
    ----------
    n_batches : int
        Total number of batches to generate
    batch_size : int
        Number of points per batch
    noise_level : float
        Standard deviation of Gaussian noise
    
    Yields
    ------
    tuple[np.ndarray, np.ndarray]
        (x_batch, y_batch) data pairs
    """
    for batch_idx in range(n_batches):
        # Generate random x values for this batch
        x_batch = np.random.uniform(0, 4 * np.pi, batch_size)
        
        # Generate y values with true parameters + noise
        y_true = true_a * np.sin(true_b * x_batch) + true_c
        noise = noise_level * np.random.randn(batch_size)
        y_batch = y_true + noise
        
        yield x_batch, y_batch


# Test the generator
test_gen = create_data_batch_generator(n_batches=3, batch_size=100)
for i, (x, y) in enumerate(test_gen):
    print(f"Batch {i}: x shape={x.shape}, y shape={y.shape}")

## 7. Run Tournament Selection

The `run_tournament` method processes data batches and progressively eliminates
poor-performing candidates.
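
One way to picture the scoring step inside a round: every surviving candidate is evaluated on the same batch and its loss is added to a running total. The sketch below uses mean squared error and maps non-finite results to `inf` so that failed candidates rank last. Both choices, and the helper name `batch_losses`, are assumptions for illustration; the loss that `run_tournament` uses internally is not shown in this tutorial.

```python
import numpy as np


def batch_losses(candidates, x_batch, y_batch):
    """Hypothetical helper: MSE of y = a*sin(b*x) + c for each candidate row."""
    # Split the (n_candidates, 3) array into three (n_candidates, 1) columns
    a, b, c = (candidates[:, k][:, None] for k in range(3))
    with np.errstate(all="ignore"):
        predictions = a * np.sin(b * x_batch[None, :]) + c  # (n_candidates, batch_size)
        losses = np.mean((predictions - y_batch[None, :]) ** 2, axis=1)
    # Treat NaN or overflow as a numerical failure that ranks last
    return np.where(np.isfinite(losses), losses, np.inf)


rng = np.random.default_rng(1)
x = rng.uniform(0, 4 * np.pi, 500)
y = 2.5 * np.sin(1.8 * x) + 1.0  # noise-free batch from the true parameters

demo = np.array([
    [2.5, 1.8, 1.0],     # true parameters
    [1.0, 3.0, 0.0],     # poor candidate
    [np.nan, 1.0, 1.0],  # candidate that fails numerically
])
losses = batch_losses(demo, x, y)
ranking = np.argsort(losses)  # best candidate first
print(losses, ranking)
```

Broadcasting evaluates all candidates in one pass over the batch, so only one batch needs to be held in memory at a time.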

In [None]:
# Create fresh selector and data generator
selector = TournamentSelector(candidates=candidates, config=config)

# Need enough batches for all rounds
total_batches_needed = config.elimination_rounds * config.batches_per_round + 10
data_gen = create_data_batch_generator(n_batches=total_batches_needed, batch_size=500)

print("Running tournament selection...")
print(f"Total batches available: {total_batches_needed}")
print()

# Run the tournament and get the top 3 candidates
best_candidates = selector.run_tournament(
    data_batch_iterator=data_gen,
    model=multimodal_model,
    top_m=3,  # Return top 3 candidates
)

print(f"\nTournament complete!")
print(f"Top 3 candidates:")
for i, params in enumerate(best_candidates):
    print(f"  {i+1}. a={params[0]:.3f}, b={params[1]:.3f}, c={params[2]:.3f}")

print(f"\nTrue parameters: a={true_a}, b={true_b}, c={true_c}")

## 8. Tournament Diagnostics

The selector provides detailed diagnostics about the tournament process.

In [None]:
# Get tournament diagnostics
diagnostics = selector.get_diagnostics()

print("Tournament Diagnostics:")
print(f"  Initial candidates: {diagnostics['n_candidates_initial']}")
print(f"  Final survivors:    {diagnostics['n_survivors']}")
print(f"  Elimination rate:   {diagnostics['elimination_rate']:.1%}")
print(f"  Rounds completed:   {diagnostics['rounds_completed']}")
print(f"  Total batches:      {diagnostics['total_batches_evaluated']}")
print(f"  Numerical failures: {diagnostics['numerical_failures']}")

if diagnostics['mean_survivor_loss'] is not None:
    print(f"  Mean survivor loss: {diagnostics['mean_survivor_loss']:.6f}")

In [None]:
# Round-by-round history
print("\nRound History:")
print("-" * 70)
print(f"{'Round':<8} {'Before':<10} {'After':<10} {'Eliminated':<12} {'Mean Loss':<12}")
print("-" * 70)

for round_info in diagnostics['round_history']:
    print(
        f"{round_info['round']:<8} "
        f"{round_info['n_survivors_before']:<10} "
        f"{round_info['n_survivors_after']:<10} "
        f"{round_info['n_eliminated']:<12} "
        f"{round_info['mean_loss']:.6f}"
    )

## 9. Visualize Tournament Progression

Let's visualize how candidates are eliminated through the tournament.

In [None]:
# Extract round history for plotting
rounds = [0] + [r['round'] + 1 for r in diagnostics['round_history']]
survivors = [diagnostics['n_candidates_initial']] + [r['n_survivors_after'] for r in diagnostics['round_history']]
mean_losses = [None] + [r['mean_loss'] for r in diagnostics['round_history']]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Survivor count over rounds
ax1 = axes[0]
ax1.plot(rounds, survivors, 'bo-', linewidth=2, markersize=10)
ax1.fill_between(rounds, survivors, alpha=0.3)
ax1.set_xlabel('Tournament Round')
ax1.set_ylabel('Number of Candidates')
ax1.set_title('Tournament Elimination: Candidate Survival')
ax1.set_xticks(rounds)
ax1.grid(True, alpha=0.3)

# Add annotations
for r, s in zip(rounds, survivors):
    ax1.annotate(str(s), (r, s), textcoords="offset points", xytext=(0, 10), ha='center')

# Right: Mean loss evolution
ax2 = axes[1]
valid_rounds = [r for r, m in zip(rounds, mean_losses) if m is not None]
valid_losses = [m for m in mean_losses if m is not None]

if valid_losses:
    ax2.plot(valid_rounds, valid_losses, 'ro-', linewidth=2, markersize=10)
    ax2.set_xlabel('Tournament Round')
    ax2.set_ylabel('Mean Survivor Loss')
    ax2.set_title('Tournament Elimination: Loss Improvement')
    ax2.set_xticks(valid_rounds)
    ax2.grid(True, alpha=0.3)

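plt.tight_layout()
# Make sure the output directory exists, otherwise savefig raises FileNotFoundError
import os
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/04_tournament_progression.png', dpi=300, bbox_inches='tight')
plt.show()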

## 10. Visualize Candidate Losses

Let's visualize the cumulative loss of each candidate and which ones survived.

In [None]:
# Get cumulative losses and survival status
cumulative_losses = selector.cumulative_losses
survival_mask = selector.survival_mask

# Sort candidates by loss for visualization
sorted_indices = np.argsort(cumulative_losses)

fig, ax = plt.subplots(figsize=(12, 6))

# Plot bars for each candidate
colors = ['green' if survival_mask[i] else 'red' for i in sorted_indices]
losses_sorted = [cumulative_losses[i] for i in sorted_indices]

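# Cap infinite values for visualization (fall back to 1.0 if no loss is finite)
finite_losses = [loss for loss in losses_sorted if np.isfinite(loss)]
max_finite_loss = max(finite_losses, default=1.0) * 1.5
losses_capped = [min(loss, max_finite_loss) for loss in losses_sorted]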

bars = ax.bar(range(len(sorted_indices)), losses_capped, color=colors, alpha=0.7)

ax.set_xlabel('Candidate (sorted by loss)')
ax.set_ylabel('Cumulative Loss')
ax.set_title('Tournament Results: Cumulative Loss by Candidate')

# Add legend
from matplotlib.patches import Patch
legend_elements = [
    Patch(facecolor='green', alpha=0.7, label='Survivor'),
    Patch(facecolor='red', alpha=0.7, label='Eliminated'),
]
ax.legend(handles=legend_elements, loc='upper left')

ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('figures/04_candidate_losses.png', dpi=300, bbox_inches='tight')
plt.show()

## 11. Compare Different Elimination Strategies

Let's compare different `elimination_fraction` values to understand the tradeoffs.

In [None]:
# Compare different elimination fractions
elimination_fractions = [0.25, 0.5, 0.75]
comparison_results = {}

print("Comparing elimination fractions:")
print("=" * 70)

for elim_frac in elimination_fractions:
    # Configure with this elimination fraction
    config = GlobalOptimizationConfig(
        n_starts=n_candidates,
        elimination_rounds=3,
        elimination_fraction=elim_frac,
        batches_per_round=10,
    )
    
    # Create fresh candidates and selector
    key = jax.random.PRNGKey(42)
    lhs_samples = latin_hypercube_sample(n_candidates, n_params, rng_key=key)
    candidates = scale_samples_to_bounds(lhs_samples, lb, ub)
    
    selector = TournamentSelector(candidates=candidates, config=config)
    data_gen = create_data_batch_generator(n_batches=50, batch_size=500)
    
    best = selector.run_tournament(
        data_batch_iterator=data_gen,
        model=multimodal_model,
        top_m=1,
    )
    
    diag = selector.get_diagnostics()
    
    comparison_results[elim_frac] = {
        'best_params': best[0],
        'n_survivors': diag['n_survivors'],
        'total_batches': diag['total_batches_evaluated'],
        'round_history': diag['round_history'],
    }
    
    print(f"\nelimination_fraction = {elim_frac}:")
    print(f"  Survivors: {diag['n_survivors']}")
    print(f"  Batches evaluated: {diag['total_batches_evaluated']}")
    print(f"  Best params: a={best[0][0]:.3f}, b={best[0][1]:.3f}, c={best[0][2]:.3f}")

In [None]:
# Visualize the comparison
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

colors = ['blue', 'green', 'orange']

# Plot survivor progression for each strategy
ax1 = axes[0]
for (elim_frac, data), color in zip(comparison_results.items(), colors):
    rounds = [0] + [r['round'] + 1 for r in data['round_history']]
    survivors = [n_candidates] + [r['n_survivors_after'] for r in data['round_history']]
    ax1.plot(rounds, survivors, 'o-', color=color, linewidth=2, markersize=8,
             label=f'elim_frac={elim_frac}')

ax1.set_xlabel('Tournament Round')
ax1.set_ylabel('Survivors')
ax1.set_title('Survivor Count by Elimination Fraction')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Bar chart: Total batches evaluated
ax2 = axes[1]
fracs = list(comparison_results.keys())
batches = [comparison_results[f]['total_batches'] for f in fracs]
ax2.bar([str(f) for f in fracs], batches, color=colors)
ax2.set_xlabel('Elimination Fraction')
ax2.set_ylabel('Total Batches Evaluated')
ax2.set_title('Computational Cost')

# Bar chart: Parameter error
ax3 = axes[2]
true_params = np.array([true_a, true_b, true_c])
errors = []
for f in fracs:
    best_p = comparison_results[f]['best_params']
    error = np.linalg.norm(best_p - true_params)
    errors.append(error)

ax3.bar([str(f) for f in fracs], errors, color=colors)
ax3.set_xlabel('Elimination Fraction')
ax3.set_ylabel('Parameter Error (L2)')
ax3.set_title('Best Candidate Accuracy')

plt.tight_layout()
plt.savefig('figures/04_elimination_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

## 12. Checkpointing for Fault Tolerance

For long-running tournaments, you can save and restore the selector state.
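
Because the checkpoint is a dictionary of arrays, lists, and scalars, one way to persist it between sessions is the standard-library `pickle` module. The sketch below is an assumption-laden illustration: `fake_checkpoint` and its keys are hypothetical stand-ins for whatever `to_checkpoint()` returns, and `save_checkpoint` and `load_checkpoint` are helper names invented here.

```python
import os
import pickle
import tempfile

import numpy as np

# Stand-in for the dict returned by to_checkpoint(); these keys are hypothetical
fake_checkpoint = {
    "current_round": 1,
    "survival_mask": np.array([True, False, True, True]),
    "cumulative_losses": np.array([0.12, np.inf, 0.30, 0.25]),
    "round_history": [{"round": 0, "n_survivors_after": 3}],
}


def save_checkpoint(checkpoint, path):
    # Write to a temporary file first, then rename it into place,
    # so an interrupted write never leaves a truncated checkpoint behind
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as fh:
        pickle.dump(checkpoint, fh)
    os.replace(tmp_path, path)


def load_checkpoint(path):
    with open(path, "rb") as fh:
        return pickle.load(fh)


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tournament.pkl")
    save_checkpoint(fake_checkpoint, path)
    restored = load_checkpoint(path)

print(sorted(restored))
```

A dictionary loaded this way could then be handed to `TournamentSelector.from_checkpoint`. Only unpickle files you wrote yourself, since `pickle` can execute arbitrary code on load.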

In [None]:
# Create a fresh selector; no rounds have run yet, so the checkpoint captures the initial state
config = GlobalOptimizationConfig(
    n_starts=n_candidates,
    elimination_rounds=3,
    elimination_fraction=0.5,
    batches_per_round=10,
)

key = jax.random.PRNGKey(42)
lhs_samples = latin_hypercube_sample(n_candidates, n_params, rng_key=key)
candidates = scale_samples_to_bounds(lhs_samples, lb, ub)

selector = TournamentSelector(candidates=candidates, config=config)

# Save checkpoint
checkpoint = selector.to_checkpoint()

print("Checkpoint contents:")
# Use `name` as the loop variable so the PRNG `key` defined above is not overwritten
for name, value in checkpoint.items():
    if isinstance(value, np.ndarray):
        print(f"  {name}: ndarray shape={value.shape}")
    elif isinstance(value, list):
        print(f"  {name}: list length={len(value)}")
    else:
        print(f"  {name}: {value}")

In [None]:
# Restore from checkpoint
restored_selector = TournamentSelector.from_checkpoint(checkpoint, config)

print(f"Restored selector:")
print(f"  n_candidates: {restored_selector.n_candidates}")
print(f"  n_survivors:  {restored_selector.n_survivors}")
print(f"  current_round: {restored_selector.current_round}")

# The restored selector can continue the tournament from where it left off:
# data_gen = create_data_batch_generator(...)
# best = restored_selector.run_tournament(
#     data_batch_iterator=data_gen, model=multimodal_model, top_m=1
# )

## 13. Key Takeaways

1. **Tournament selection** is memory-efficient for large/streaming datasets:
   - Evaluates candidates on data batches, not full dataset
   - Progressively eliminates poor performers
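   - Cuts the cost from O(N * M) for scoring N candidates on all M data points to at most N * batches_per_round * rounds batch evaluations, and fewer as candidates are eliminated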

2. **Configuration parameters:**
   - `elimination_rounds`: More rounds = more filtering, fewer survivors
   - `elimination_fraction`: Higher = more aggressive pruning
   - `batches_per_round`: More batches = better candidate ranking

3. **Tradeoffs:**
   - Aggressive elimination (0.75): Faster, but may eliminate good candidates early
   - Conservative elimination (0.25): Slower, but more robust
   - Default (0.5): Balanced approach

4. **Checkpointing:** Use `to_checkpoint()` and `from_checkpoint()` for fault tolerance

5. **Diagnostics:** `get_diagnostics()` provides detailed tournament statistics

In [None]:
# Summary: when to use tournament selection and typical settings
print("Summary")
print("=" * 50)
print(f"True parameters: a={true_a}, b={true_b}, c={true_c}")
print()
print("Tournament selection is ideal for:")
print("  - Large datasets that exceed memory")
print("  - Streaming data scenarios")
print("  - Multi-start searches with many candidate starting points")
print()
print("Use GlobalOptimizationConfig with:")
print("  - elimination_rounds: 2-4 (more = more filtering)")
print("  - elimination_fraction: 0.25-0.75 (higher = faster)")
print("  - batches_per_round: 10-100 (more = better ranking)")