# Tutorial 04: Advanced Solver Configuration

Learn how to configure MFG solvers with custom parameters.

## Learning Objectives

By the end of this tutorial, you will understand:
- How to use MFGSolverConfig for fine-grained control
- How to compare different solver parameters
- How to tune convergence settings
- The SolverFactory API for advanced use cases

Note: Particle methods are available in the research repository (MFG-Research)
for specialized applications. This tutorial covers the standard grid-based
solver configuration available in the core package.

**Time estimate**: 15 minutes

## Step 1: Import Dependencies

In [None]:
import matplotlib.pyplot as plt
import numpy as np

from mfg_pde import MFGProblem
from mfg_pde.config import MFGSolverConfig, PicardConfig
from mfg_pde.core import MFGComponents
from mfg_pde.core.hamiltonian import QuadraticControlCost, SeparableHamiltonian
from mfg_pde.factory import SolverFactory
from mfg_pde.geometry import TensorProductGrid
from mfg_pde.geometry.boundary import no_flux_bc

## Step 2: Create Test Problem

We'll use a simple Linear-Quadratic problem to compare different solver configurations.
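The initial density used in this step, exp(-50 (x - 0.5)^2), is not normalized. Its integral over [0, 1] is about √(π/50) ≈ 0.25, not 1. The NumPy sketch below does not use mfg_pde. It computes the discrete mass of that density, assuming the 50 grid points are uniformly spaced and include both endpoints. Whether `MFGProblem` rescales `m_initial` to unit mass internally is an assumption to check for your installed version. It determines what value the mass check in Step 6 should report.

```python
import numpy as np

# Assumed discretization: 50 uniform points on [0, 1], endpoints included
x = np.linspace(0.0, 1.0, 50)
dx = x[1] - x[0]

m0 = np.exp(-50 * (x - 0.5) ** 2)  # the m_initial used in this tutorial
mass = np.sum(m0) * dx  # Riemann-sum mass, same formula as the Step 6 check
print(f"Unnormalized mass: {mass:.4f}")  # close to sqrt(pi/50)

m0_normalized = m0 / mass  # rescale so the discrete mass equals 1
print(f"Normalized mass:   {np.sum(m0_normalized) * dx:.4f}")
```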

In [None]:
print("=" * 70)
print("TUTORIAL 04: Advanced Solver Configuration")
print("=" * 70)
print()

# Create grid and components
grid = TensorProductGrid(
    bounds=[(0.0, 1.0)],
    Nx_points=[50],
    boundary_conditions=no_flux_bc(dimension=1),
)

hamiltonian = SeparableHamiltonian(
    control_cost=QuadraticControlCost(control_cost=1.0),
    coupling=lambda m: 0.3 * m,
    coupling_dm=lambda m: 0.3,
)

components = MFGComponents(
    hamiltonian=hamiltonian,
    m_initial=lambda x: np.exp(-50 * (x - 0.5) ** 2),
    u_terminal=lambda x: (x - 0.5) ** 2,
)

# Create problem
problem = MFGProblem(
    geometry=grid,
    T=1.0,
    Nt=50,
    diffusion=0.15,
    components=components,
)

print("Problem configuration:")
print("  Domain: [0, 1]")
print(f"  Grid points: {problem.geometry.get_grid_shape()[0]}")
print(f"  Time steps: {problem.Nt}")
print(f"  Diffusion: {problem.sigma}")
print()

## Step 3: Solve with Default Settings

First, we'll solve using the default configuration to establish a baseline.
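The comparisons in this tutorial use iteration counts and errors. The speed side of the trade-off is easier to judge with wall-clock time. One way to record it is a small helper built on `time.perf_counter`. The helper `timed` is a hypothetical addition for this tutorial and is not part of mfg_pde.

```python
import time


def timed(fn, *args, **kwargs):
    """Call fn(*args, **kwargs) and return (result, elapsed wall-clock seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed


# Demo on a stand-in workload
total, seconds = timed(sum, range(1_000_000))
print(f"sum = {total}, took {seconds:.4f} s")

# In the notebook, the same helper could wrap the baseline solve from this step:
# result_default, t_default = timed(problem.solve, verbose=True)
```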

In [None]:
print("Solving with DEFAULT settings...")
print("-" * 70)

result_default = problem.solve(verbose=True)

print()
print(f"Default: Converged in {result_default.iterations} iterations")
print(f"Default: stopped after {result_default.iterations} iterations")
print(f"  Final error: {result_default.max_error:.6e}")
print()

## Step 4: Custom Configuration with Tighter Tolerance

Now we'll use MFGSolverConfig for tighter convergence criteria.
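In MFG solvers, a Picard iteration typically alternates between the HJB and Fokker-Planck solves until successive iterates stop changing. `max_iterations` and `tolerance` bound that loop. The toy sketch below shows the stopping logic on the scalar fixed-point problem x = cos(x). It is an illustration, not mfg_pde's implementation. Its relative-change error measure is an assumption.

```python
import numpy as np


def picard(g, x0, max_iterations=50, tolerance=1e-6):
    """Toy Picard (fixed-point) iteration x_{k+1} = g(x_k).

    Stops when the relative change between iterates drops below
    `tolerance`, or after `max_iterations` updates.
    Returns (x, iterations, converged, error_history).
    """
    x = np.asarray(x0, dtype=float)
    history = []
    for k in range(1, max_iterations + 1):
        x_new = g(x)
        # Relative change between successive iterates (guarded against division by 0)
        err = np.linalg.norm(x_new - x) / max(np.linalg.norm(x_new), 1e-14)
        history.append(err)
        x = x_new
        if err < tolerance:
            return x, k, True, history
    return x, max_iterations, False, history


x, iters, converged, history = picard(np.cos, np.array([1.0]), max_iterations=50, tolerance=1e-6)
print(f"x* ~ {x[0]:.6f} after {iters} iterations (converged={converged})")

# With too few iterations allowed, the loop stops before reaching the tolerance
_, iters_short, converged_short, _ = picard(np.cos, np.array([1.0]), max_iterations=5, tolerance=1e-6)
print(f"max_iterations=5: stopped after {iters_short} iterations (converged={converged_short})")
```

If `max_iterations` is too small for the chosen `tolerance`, the loop ends without converging. A solver can then report an iteration count even though the tolerance was never reached.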

In [None]:
print("=" * 70)
print("CUSTOM CONFIGURATION")
print("=" * 70)
print()

# Configuration with tighter tolerance
config_tight = MFGSolverConfig(
    picard=PicardConfig(
        max_iterations=50,  # More iterations allowed
        tolerance=1e-6,  # Tighter convergence tolerance
    )
)

print(f"Config: max_iterations={config_tight.picard.max_iterations}, tol={config_tight.picard.tolerance}")
print()

# Create solver with custom config using SolverFactory
solver_tight = SolverFactory.create_solver(problem, config=config_tight)
result_tight = solver_tight.solve(verbose=True)

print()
print(f"Tight tolerance: stopped after {result_tight.iterations} iterations")
print(f"  Final error: {result_tight.max_error:.6e}")
print()

## Step 5: Looser Tolerance for Speed

Next, we use a looser tolerance, trading accuracy for fewer iterations and a faster solve.
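A looser tolerance is faster because of how linear convergence works. For a fixed-point iteration with contraction rate q, the change between iterates shrinks roughly like q^k. Reaching a tolerance ε therefore takes about log(ε)/log(q) iterations. Each factor of 1000 in ε costs about log(1000)/log(1/q) more. The toy below uses q = 0.5 and is not an MFG problem, which makes the counts exact. For an MFG problem the Picard rate depends on the problem, and convergence is not guaranteed, so treat these numbers as illustrative only.

```python
import math


def iterations_to_tolerance(g, x0, tolerance, max_iterations=1000):
    """Count fixed-point updates x <- g(x) until |x_new - x| < tolerance."""
    x = x0
    for k in range(1, max_iterations + 1):
        x_new = g(x)
        if abs(x_new - x) < tolerance:
            return k
        x = x_new
    return max_iterations


def g(x):
    # Contraction with rate q = 0.5 and fixed point x* = 2
    return 0.5 * x + 1.0


for tol in (1e-3, 1e-6):
    n = iterations_to_tolerance(g, 0.0, tol)
    print(f"tol={tol:.0e}: {n} iterations")
# tol=1e-03: 11 iterations
# tol=1e-06: 21 iterations

# Predicted extra cost of tightening the tolerance by 1000x when q = 0.5
print(f"{math.log(1000) / math.log(2):.2f}")  # 9.97
```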

In [None]:
print("=" * 70)
print("TOLERANCE VS SPEED TRADE-OFF")
print("=" * 70)
print()

# Configuration with looser tolerance for faster results
config_fast = MFGSolverConfig(
    picard=PicardConfig(
        max_iterations=20,
        tolerance=1e-3,  # Looser tolerance
    )
)

solver_fast = SolverFactory.create_solver(problem, config=config_fast)
result_fast = solver_fast.solve(verbose=True)

print()
print(f"Fast (loose tol): stopped after {result_fast.iterations} iterations")
print(f"  Final error: {result_fast.max_error:.6e}")
print()

## Step 6: Compare Solutions

Let's compare the solutions from different configurations.
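The difference metrics in this step divide the Frobenius norm of the space-time density difference by the norm of the default density. The sketch below writes the same quantity as a weighted discrete L2 norm. It shows why no dx or dt factors are needed: on a uniform grid, the quadrature weights appear in both numerator and denominator and cancel. The arrays are hypothetical random stand-ins for `result.M`, not solver output.

```python
import numpy as np


def relative_l2_difference(a, b, dx=1.0, dt=1.0):
    """Discrete relative L2 difference ||a - b|| / ||b|| on a uniform (t, x) grid.

    The weight dx * dt multiplies both sums, so it cancels in the ratio
    and the result equals the plain Frobenius-norm ratio.
    """
    w = dx * dt
    num = np.sqrt(w * np.sum((a - b) ** 2))
    den = np.sqrt(w * np.sum(b**2))
    return num / den


# Hypothetical (time, space) arrays standing in for result.M from two runs
rng = np.random.default_rng(0)
M_ref = rng.random((51, 50))
M_other = M_ref + 1e-4 * rng.standard_normal(M_ref.shape)

weighted = relative_l2_difference(M_other, M_ref, dx=1 / 49, dt=1 / 50)
frobenius = np.linalg.norm(M_other - M_ref) / np.linalg.norm(M_ref)
print(weighted, frobenius)  # equal up to round-off
```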

In [None]:
print("=" * 70)
print("SOLUTION COMPARISON")
print("=" * 70)
print()

# Calculate differences between solutions
diff_tight_default = np.linalg.norm(result_tight.M - result_default.M) / np.linalg.norm(result_default.M)
diff_fast_default = np.linalg.norm(result_fast.M - result_default.M) / np.linalg.norm(result_default.M)

print("Relative L2 differences in density M (all time steps) from default:")
print(f"  Tight tolerance vs Default: {diff_tight_default:.6e}")
print(f"  Fast (loose) vs Default:    {diff_fast_default:.6e}")
print()

# Mass conservation check
dx = problem.geometry.get_grid_spacing()[0]
mass_default = np.sum(result_default.M[-1, :]) * dx
mass_tight = np.sum(result_tight.M[-1, :]) * dx
mass_fast = np.sum(result_fast.M[-1, :]) * dx

print("Final mass (should stay close to the initial mass, ~1.0 for a normalized density):")
print(f"  Default:         {mass_default:.6f}")
print(f"  Tight tolerance: {mass_tight:.6f}")
print(f"  Fast:            {mass_fast:.6f}")
print()

# Visualization
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

x = np.linspace(0, 1, problem.geometry.get_grid_shape()[0])

# Plot 1: Final density comparison
axes[0].plot(x, result_default.M[-1, :], "b-", linewidth=2, label="Default")
axes[0].plot(x, result_tight.M[-1, :], "g--", linewidth=2, label="Tight", alpha=0.8)
axes[0].plot(x, result_fast.M[-1, :], "r:", linewidth=2, label="Fast", alpha=0.8)
axes[0].set_xlabel("x")
axes[0].set_ylabel("m(T, x)")
axes[0].set_title("Final Density Comparison")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Plot 2: Value function at t = 0 (u(T, x) is fixed by u_terminal, so it cannot differ between runs)
axes[1].plot(x, result_default.U[0, :], "b-", linewidth=2, label="Default")
axes[1].plot(x, result_tight.U[0, :], "g--", linewidth=2, label="Tight", alpha=0.8)
axes[1].plot(x, result_fast.U[0, :], "r:", linewidth=2, label="Fast", alpha=0.8)
axes[1].set_xlabel("x")
axes[1].set_ylabel("u(0, x)")
axes[1].set_title("Initial Value Function")
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Plot 3: Convergence comparison (density error per iteration)
for result, style, label in [
    (result_default, "b-", "Default"),
    (result_tight, "g--", "Tight"),
    (result_fast, "r:", "Fast"),
]:
    history = getattr(result, "error_history_M", None)
    # len() instead of truthiness: `if history:` raises ValueError for a multi-element NumPy array
    if history is not None and len(history) > 0:
        axes[2].semilogy(history, style, label=label, linewidth=2)
axes[2].set_xlabel("Iteration")
axes[2].set_ylabel("Error in M (log scale)")
axes[2].set_title("Convergence History")
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Summary

### What You Learned

1. How to use MFGSolverConfig for custom parameters
2. How to use SolverFactory for advanced control
3. The tolerance vs speed trade-off
4. How to compare solutions with different settings

### Tolerance Selection Guide

| Use Case                    | Recommended Tolerance |
|:----------------------------|:----------------------|
| Quick prototyping           | 1e-3 to 1e-4          |
| Standard research           | 1e-4 to 1e-5          |
| Publication quality         | 1e-6 to 1e-8          |
| Convergence studies         | Vary systematically   |
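The "Vary systematically" row can be made concrete with a sweep:

- Solve at each tolerance in a range.
- Solve once more at a much tighter tolerance to get a reference.
- Record how far each result is from that reference.

The toy sketch below does this bookkeeping on the scalar problem x = cos(x). For an MFG problem, you would instead run one solve per sweep point with `PicardConfig(tolerance=tol)`, keeping everything else fixed, and compare densities rather than scalars. That adaptation is an assumption, not something this sketch runs.

```python
import numpy as np


def fixed_point(g, x0, tolerance, max_iterations=10_000):
    """Iterate x <- g(x) until the change between iterates is below tolerance."""
    x = x0
    for _ in range(max_iterations):
        x_new = g(x)
        if abs(x_new - x) < tolerance:
            return x_new
        x = x_new
    return x


# Reference solution at a much tighter tolerance than any in the sweep
x_ref = fixed_point(np.cos, 1.0, tolerance=1e-14)

tolerances = np.logspace(-3, -8, 6)  # 1e-3, 1e-4, ..., 1e-8
for tol in tolerances:
    err = abs(fixed_point(np.cos, 1.0, tol) - x_ref)
    print(f"tol={tol:.0e}  error vs reference={err:.2e}")
```

For a linearly convergent iteration, the remaining error stays within a constant factor q/(1 - q) of the last step, so the error tracks the tolerance. If an MFG sweep shows the error stalling while the tolerance keeps shrinking, another error source has taken over. Discretization (grid size, time steps) is a likely candidate.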

### Key Takeaway

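Configuration choices trade speed against accuracy. A tighter tolerance costs more iterations, while a looser one returns sooner with a larger final error.
Start with the defaults, then tighten or loosen the tolerance based on how the results will be used.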

### Next Steps

Proceed to **Tutorial 05: Problem Variations** to learn about parameter studies.