# 02 Memory-Based Strategy Selection

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/imewei/NLSQ/blob/main/examples/notebooks/08_workflow_system/02_workflow_tiers.ipynb)

Features demonstrated:
- Understanding the three strategies: standard, chunked, streaming
- MemoryBudget computation for dataset sizing
- MemoryBudgetSelector for automatic strategy selection
- Memory usage comparison across strategies

Run this example:
    python examples/scripts/08_workflow_system/02_workflow_tiers.py
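
Before running the full example, it helps to see where the memory figures come from. The Data GB and Jacobian GB columns that section 2 prints are consistent with plain float64 byte counts. The sketch below assumes that x and y are each one float64 array of length `n_points`, that the Jacobian is a dense float64 array with one row per point, and that 1 GB means 1024**3 bytes. These are assumptions inferred from the printed numbers, not from NLSQ's source. `data_gib`, `jacobian_gib`, and `largest_jacobian_block_gib` are hypothetical helpers, and the `chunk_size` and `batch_size` values are illustrative, not NLSQ defaults. The sketch does not model the Peak GB column, which is larger than data plus Jacobian and so includes further working memory.

```python
# Back-of-the-envelope memory arithmetic. Assumptions (not taken from NLSQ's
# source): x and y are each one float64 array of length n_points, the Jacobian
# is a dense float64 array with one row per point, and 1 GB = 1024**3 bytes.
BYTES_PER_FLOAT64 = 8
GIB = 1024**3


def data_gib(n_points: int, n_arrays: int = 2) -> float:
    """Memory held by the x and y arrays."""
    return n_points * n_arrays * BYTES_PER_FLOAT64 / GIB


def jacobian_gib(n_rows: int, n_params: int) -> float:
    """Memory held by a dense Jacobian block with n_rows rows."""
    return n_rows * n_params * BYTES_PER_FLOAT64 / GIB


def largest_jacobian_block_gib(
    n_points: int, n_params: int, chunk_size: int, batch_size: int
) -> dict[str, float]:
    """Jacobian rows each strategy would hold at once (illustrative only)."""
    return {
        "standard": jacobian_gib(n_points, n_params),
        "chunked": jacobian_gib(min(chunk_size, n_points), n_params),
        "streaming": jacobian_gib(min(batch_size, n_points), n_params),
    }


for n in (100_000, 1_000_000, 10_000_000):
    print(f"{n:>12,}  data={data_gib(n):.4f}  jacobian={jacobian_gib(n, 5):.4f}")

# chunk_size and batch_size are hypothetical values, not NLSQ defaults.
blocks = largest_jacobian_block_gib(
    100_000_000, 5, chunk_size=1_000_000, batch_size=10_000
)
for name, gib in blocks.items():
    print(f"{name:<10} {gib:.4f} GB")
```

Under these assumptions the three rows match the section 2 output (for example 0.0015 GB of data and 0.0037 GB of Jacobian for 100K points with 5 parameters). The second loop shows why the strategies differ: the standard path holds all N Jacobian rows, while the chunked and streaming paths hold only a bounded slice.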

In [1]:
# @title Install NLSQ (run once in Colab)
import sys

if 'google.colab' in sys.modules:
    print("Running in Google Colab - installing NLSQ...")
    !pip install -q nlsq
    print("NLSQ installed successfully!")
else:
    print("Not running in Colab - assuming NLSQ is already installed")

Not running in Colab - assuming NLSQ is already installed


In [2]:
from pathlib import Path

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from nlsq import fit
from nlsq.core.workflow import MemoryBudget, MemoryBudgetSelector
from nlsq.streaming.large_dataset import MemoryEstimator

FIG_DIR = Path.cwd() / "figures"
FIG_DIR.mkdir(parents=True, exist_ok=True)

In [3]:
def exponential_decay(x, a, b, c):
    """Exponential decay: y = a * exp(-b * x) + c"""
    return a * jnp.exp(-b * x) + c


def main():
    print("=" * 70)
    print("Memory-Based Strategy Selection")
    print("=" * 70)
    print()

    np.random.seed(42)

    # =========================================================================
    # 1. Overview of Strategies
    # =========================================================================
    print("1. Strategy Overview")
    print("-" * 50)

    strategy_info = {
        "standard": {
            "description": "Full in-memory computation",
            "memory": "O(N) - loads all data into memory",
            "best_for": "Small to medium datasets that fit in RAM",
        },
        "chunked": {
            "description": "Memory-managed chunk processing",
            "memory": "O(chunk_size) - processes data in chunks",
            "best_for": "Large datasets with Jacobian memory pressure",
        },
        "streaming": {
            "description": "Mini-batch gradient descent",
            "memory": "O(batch_size) - constant memory usage",
            "best_for": "Very large datasets (100M+ points)",
        },
    }

    for strategy, info in strategy_info.items():
        print(f"\n  {strategy.upper()}:")
        print(f"    Description: {info['description']}")
        print(f"    Memory: {info['memory']}")
        print(f"    Best for: {info['best_for']}")

    # =========================================================================
    # 2. MemoryBudget Computation
    # =========================================================================
    print()
    print()
    print("2. MemoryBudget Computation")
    print("-" * 70)

    dataset_configs = [
        (100_000, 5, "100K"),
        (1_000_000, 5, "1M"),
        (10_000_000, 5, "10M"),
        (50_000_000, 5, "50M"),
        (100_000_000, 5, "100M"),
    ]

    print(f"{'Dataset':<10} {'Data GB':<12} {'Jacobian GB':<15} {'Peak GB':<12} {'Fits?':<8}")
    print("-" * 70)

    for n_points, n_params, label in dataset_configs:
        budget = MemoryBudget.compute(
            n_points=n_points, n_params=n_params, safety_factor=0.75
        )
        fits = "Yes" if budget.fits_in_memory else "No"
        print(
            f"{label:<10} {budget.data_gb:<12.4f} {budget.jacobian_gb:<15.4f} "
            f"{budget.peak_gb:<12.4f} {fits:<8}"
        )

    # =========================================================================
    # 3. Decision Tree
    # =========================================================================
    print()
    print("3. Strategy Selection Decision Tree:")
    print("-" * 50)
    print()
    print("  ┌─────────────────────────────────────────────────┐")
    print("  │       Compute MemoryBudget                      │")
    print("  │  (data_gb, jacobian_gb, peak_gb, threshold_gb)  │")
    print("  └─────────────────────┬───────────────────────────┘")
    print("                        │")
    print("                        ▼")
    print("             ┌──────────────────────┐")
    print("             │ data_gb > threshold? │")
    print("             └──────────┬───────────┘")
    print("                 Yes │      │ No")
    print("                     │      │")
    print("                     ▼      ▼")
    print("           ┌─────────────┐ ┌──────────────────────┐")
    print("           │  STREAMING  │ │ peak_gb > threshold? │")
    print("           │  Strategy   │ └──────────┬───────────┘")
    print("           └─────────────┘      Yes │      │ No")
    print("                                    │      │")
    print("                                    ▼      ▼")
    print("                          ┌─────────────┐ ┌─────────────┐")
    print("                          │   CHUNKED   │ │  STANDARD   │")
    print("                          │  Strategy   │ │  Strategy   │")
    print("                          └─────────────┘ └─────────────┘")

    # =========================================================================
    # 4. MemoryBudgetSelector Usage
    # =========================================================================
    print()
    print()
    print("4. MemoryBudgetSelector Usage")
    print("-" * 60)

    available_memory = MemoryEstimator.get_available_memory_gb()
    safety_factor = 0.75
    selector = MemoryBudgetSelector(safety_factor=safety_factor)

    print(f"\n  Available memory: {available_memory:.1f} GB")
    print(f"  Threshold ({safety_factor:.0%}): {available_memory * safety_factor:.1f} GB")
    print()

    test_sizes = [10_000, 100_000, 1_000_000, 10_000_000, 100_000_000]
    n_params = 5

    print(f"  {'Dataset Size':<15} {'Strategy':<15} {'Config Type':<20}")
    print("  " + "-" * 50)

    for n_points in test_sizes:
        strategy, config = selector.select(n_points=n_points, n_params=n_params)
        config_type = type(config).__name__ if config is not None else "None"
        size_str = f"{n_points:,}"
        print(f"  {size_str:<15} {strategy:<15} {config_type:<20}")

    # =========================================================================
    # 5. Test Fit with Automatic Selection
    # =========================================================================
    print()
    print()
    print("5. Test Fit with Automatic Selection")
    print("-" * 50)

    n_samples = 1000
    x_data = np.linspace(0, 5, n_samples)
    true_a, true_b, true_c = 3.0, 1.2, 0.5
    y_true = true_a * np.exp(-true_b * x_data) + true_c
    y_data = y_true + 0.1 * np.random.randn(n_samples)

    print(f"  Test dataset: {n_samples} points")
    print(f"  True parameters: a={true_a}, b={true_b}, c={true_c}")

    popt, _ = fit(
        exponential_decay,
        x_data,
        y_data,
        p0=[1.0, 1.0, 0.0],
        workflow="auto",
    )
    print(f"  Fitted: a={popt[0]:.4f}, b={popt[1]:.4f}, c={popt[2]:.4f}")

    # =========================================================================
    # 6. Strategy Boundaries Visualization
    # =========================================================================
    print()
    print("6. Saving strategy boundaries visualization...")

    fig, ax = plt.subplots(figsize=(12, 8))

    dataset_sizes = np.logspace(4, 9, 100)  # 10K to 1B
    memory_limits = np.linspace(4, 128, 50)

    n_params = 5
    strategy_map = np.zeros((len(memory_limits), len(dataset_sizes)))

    for i, mem_limit in enumerate(memory_limits):
        for j, n_points in enumerate(dataset_sizes):
            strategy, _ = selector.select(
                n_points=int(n_points), n_params=n_params, memory_limit_gb=mem_limit
            )
            if strategy == "streaming":
                strategy_map[i, j] = 2
            elif strategy == "chunked":
                strategy_map[i, j] = 1
            else:
                strategy_map[i, j] = 0

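    cmap = plt.cm.RdYlGn_r
    im = ax.imshow(
        strategy_map,
        aspect="auto",
        origin="lower",
        cmap=cmap,
        vmin=0,  # pin the color scale so each strategy keeps its color
        vmax=2,  # even if some strategy never appears in the grid
        extent=[4, 9, 4, 128],
    )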

    ax.set_xlabel("Dataset Size (log10)")
    ax.set_ylabel("Memory Limit (GB)")
    ax.set_title("Strategy Selection Boundaries (5 parameters)")

    cbar = plt.colorbar(im, ax=ax, ticks=[0, 1, 2])
    cbar.ax.set_yticklabels(["Standard", "Chunked", "Streaming"])

    ax.axhline(y=available_memory, color="white", linestyle="--", linewidth=2)
    ax.text(
        9.05,
        available_memory,
        f"Current: {available_memory:.0f} GB",
        color="white",
        va="center",
    )

    plt.tight_layout()
    plt.savefig(FIG_DIR / "02_strategy_boundaries.png", dpi=300, bbox_inches="tight")
    plt.close()
    print(f"  Saved: {FIG_DIR / '02_strategy_boundaries.png'}")

    # =========================================================================
    # Summary
    # =========================================================================
    print()
    print("=" * 70)
    print("Summary")
    print("=" * 70)
    print()
    print("Strategies:")
    print("  standard:  Full in-memory computation")
    print("  chunked:   Memory-managed chunk processing")
    print("  streaming: Mini-batch gradient descent")
    print()
    print("Decision tree (in order):")
    print("  1. data_gb > threshold -> STREAMING")
    print("  2. peak_gb > threshold -> CHUNKED")
    print("  3. else -> STANDARD")
    print()
    print("Key APIs:")
    print("  MemoryBudget.compute(n_points, n_params) - Compute memory requirements")
    print("  MemoryBudgetSelector().select(...)      - Get optimal strategy")
    print("  fit(model, x, y, workflow='auto')       - Automatic selection in fit()")
    print()
    print(f"Current system memory: {available_memory:.1f} GB")
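
The decision tree that `main()` prints can be written as a small pure function. This is a sketch of the rules as the notebook states them (streaming if the data alone exceeds the threshold, chunked if the peak does, standard otherwise), not the source of NLSQ's `MemoryBudgetSelector`. `BudgetSketch`, `threshold_gb`, and `select_strategy` are hypothetical names; the 0.75 safety factor mirrors the value passed to `MemoryBudgetSelector` above, and 33.3 GB is the available memory this notebook's run reported.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BudgetSketch:
    """Hypothetical stand-in for the four quantities the decision tree uses."""

    data_gb: float
    jacobian_gb: float
    peak_gb: float
    threshold_gb: float


def threshold_gb(available_gb: float, safety_factor: float = 0.75) -> float:
    """Usable memory: a fixed fraction of what is currently available."""
    return available_gb * safety_factor


def select_strategy(budget: BudgetSketch) -> str:
    """Apply the three rules in order; the first match wins."""
    if budget.data_gb > budget.threshold_gb:
        return "streaming"  # even the raw arrays do not fit
    if budget.peak_gb > budget.threshold_gb:
        return "chunked"  # data fits, but the full working set does not
    return "standard"  # everything fits in a single pass


limit = threshold_gb(33.3)
print(f"threshold: {limit:.3f} GB")
for budget in (
    BudgetSketch(0.15, 0.37, 0.73, limit),
    BudgetSketch(5.0, 12.0, 30.0, limit),
    BudgetSketch(30.0, 75.0, 110.0, limit),
):
    print(budget.peak_gb, "->", select_strategy(budget))
```

Note that the comparisons are strict: a peak exactly equal to the threshold still selects the standard strategy in this sketch.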

In [4]:
if __name__ == "__main__":
    main()

INFO:nlsq.curve_fit:Starting curve fit n_params=3 | n_data_points=1000 | method=trf | solver=auto | batch_size=None | has_bounds=False | dynamic_sizing=False
INFO:nlsq.least_squares:Starting least squares optimization method=trf | n_params=3 | loss=linear | ftol=1.0000e-08 | xtol=1.0000e-08 | gtol=1.0000e-08

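======================================================================
Memory-Based Strategy Selection
======================================================================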

1. Strategy Overview
--------------------------------------------------

  STANDARD:
    Description: Full in-memory computation
    Memory: O(N) - loads all data into memory
    Best for: Small to medium datasets that fit in RAM

  CHUNKED:
    Description: Memory-managed chunk processing
    Memory: O(chunk_size) - processes data in chunks
    Best for: Large datasets with Jacobian memory pressure

  STREAMING:
    Description: Mini-batch gradient descent
    Memory: O(batch_size) - constant memory usage
    Best for: Very large datasets (100M+ points)


2. MemoryBudget Computation
----------------------------------------------------------------------
Dataset    Data GB      Jacobian GB     Peak GB      Fits?   
----------------------------------------------------------------------
100K       0.0015       0.0037          0.1063       Yes     
1M         0.0149       0.0373          0.1633       Yes     
10M        0.1490       0.3725          0.7333  

INFO:nlsq.optimizer.trf:Starting TRF optimization (no bounds) n_params=3 | n_residuals=1000 | max_nfev=None
PERFORMANCE:nlsq.optimizer.trf:Iteration iter=0 | cost=433.2687779247449 | grad_norm=802.7427 | nfev=1
PERFORMANCE:nlsq.optimizer.trf:Iteration iter=1 | cost=15.486553975599099 | grad_norm=28.3633 | step=2.8284271247461903 | nfev=2
PERFORMANCE:nlsq.optimizer.trf:Iteration iter=2 | cost=4.9908953514845145 | grad_norm=5.2477 | step=2.8284271247461903 | nfev=3
PERFORMANCE:nlsq.optimizer.trf:Iteration iter=3 | cost=4.783146907512251 | grad_norm=0.6351 | step=2.8284271247461903 | nfev=4
PERFORMANCE:nlsq.optimizer.trf:Iteration iter=4 | cost=4.782854992102272 | grad_norm=8.8347e-04 | step=2.8284271247461903 | nfev=5
PERFORMANCE:nlsq.least_squares:Timer: optimization elapsed=1.189131s
INFO:nlsq.least_squares:Convergence reason=`ftol` termination condition is satisfied. | iterations=5 | final_cost=4.7829 | elapsed=1.189s | final_gradient_norm=1.2022e-05
PERFORMANCE:nlsq.curve_fit:Timer: curve_fit elapsed=1.940183s

  Fitted: a=2.9824, b=1.1950, c=0.5028

6. Saving strategy boundaries visualization...


  Saved: /home/wei/Documents/GitHub/NLSQ/examples/notebooks/08_workflow_system/figures/02_strategy_boundaries.png

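======================================================================
Summary
======================================================================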

Strategies:
  standard:  Full in-memory computation
  chunked:   Memory-managed chunk processing
  streaming: Mini-batch gradient descent

Decision tree (in order):
  1. data_gb > threshold -> STREAMING
  2. peak_gb > threshold -> CHUNKED
  3. else -> STANDARD

Key APIs:
  MemoryBudget.compute(n_points, n_params) - Compute memory requirements
  MemoryBudgetSelector().select(...)      - Get optimal strategy
  fit(model, x, y, workflow='auto')       - Automatic selection in fit()

Current system memory: 33.3 GB
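
The overview describes the chunked strategy as processing data in chunks with O(chunk_size) memory. One standard way a Gauss-Newton-type solver can do that is to accumulate the normal-equation terms J^T J and J^T r chunk by chunk, so that only a (chunk_size, n_params) slice of the Jacobian exists at any moment. The sketch below shows this technique for the notebook's exponential decay model using NumPy. Whether NLSQ's chunked strategy works exactly this way is an assumption this notebook does not confirm, and `model_and_jacobian` and `normal_equations_chunked` are hypothetical helpers.

```python
import numpy as np


def model_and_jacobian(x, a, b, c):
    """Exponential decay y = a*exp(-b*x) + c and its analytic Jacobian."""
    e = np.exp(-b * x)
    y = a * e + c
    # Columns are dy/da, dy/db, dy/dc evaluated at every x.
    jac = np.column_stack([e, -a * x * e, np.ones_like(x)])
    return y, jac


def normal_equations_chunked(x, y_obs, params, chunk_size):
    """Accumulate J^T J, J^T r and the cost one chunk at a time.

    Only a (chunk_size, n_params) slice of the Jacobian exists at any moment;
    the accumulators are n_params x n_params and n_params long.
    """
    n_params = len(params)
    jtj = np.zeros((n_params, n_params))
    jtr = np.zeros(n_params)
    cost = 0.0
    for start in range(0, x.size, chunk_size):
        stop = start + chunk_size
        y_model, jac = model_and_jacobian(x[start:stop], *params)
        r = y_model - y_obs[start:stop]  # residuals for this chunk only
        jtj += jac.T @ jac
        jtr += jac.T @ r
        cost += 0.5 * float(r @ r)
    return jtj, jtr, cost


rng = np.random.default_rng(42)
x = np.linspace(0, 5, 1000)
y = 3.0 * np.exp(-1.2 * x) + 0.5 + 0.1 * rng.standard_normal(x.size)
p0 = np.array([1.0, 1.0, 0.0])

jtj, jtr, cost = normal_equations_chunked(x, y, p0, chunk_size=128)

# One Gauss-Newton step computed only from the chunked sums ...
step_chunked = np.linalg.solve(jtj, -jtr)
# ... agrees with the step from the full Jacobian solved by least squares.
y_full, jac_full = model_and_jacobian(x, *p0)
step_full, *_ = np.linalg.lstsq(jac_full, -(y_full - y), rcond=None)
print(step_chunked, step_full, np.allclose(step_chunked, step_full))
```

Because J^T J is only n_params × n_params, the accumulators stay tiny however many points are processed. The trade-off is numerical: forming J^T J squares the condition number compared with a QR or SVD solve on the full Jacobian, which is one reason to prefer the standard path whenever the full Jacobian fits.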
