# Tutorial 03: 2D Geometry

Learn how to work with multi-dimensional MFG problems using the `GridBasedMFGProblem` class.

## Learning Objectives

- Understand `GridBasedMFGProblem` for nD problems
- Work with 2D spatial domains
- Handle multi-dimensional Hamiltonians and gradients
- Visualize 2D density evolution

## Problem Setup

We'll solve a 2D target attraction problem:
- **Domain**: $[0, 10] \times [0, 10]$ (2D spatial grid)
- **Target**: Agents want to reach $(x^*, y^*) = (5, 5)$
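- **Initial distribution**: Four Gaussian blobs centered at $(2, 2)$, $(8, 2)$, $(2, 8)$, and $(8, 8)$
- **Congestion**: A density-dependent term $\kappa \cdot m \cdot |\mathbf{p}|^2$ couples each agent's dynamics to the local crowd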

**Hamiltonian**:
$$H(\mathbf{x}, m, \mathbf{p}, t) = \frac{1}{2}|\mathbf{p}|^2 + \kappa \cdot m \cdot |\mathbf{p}|^2$$

where $\mathbf{p} = (p_x, p_y)$ is the 2D momentum and $\kappa$ is the congestion weight (`congestion_weight` in the code).
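Differentiating the Hamiltonian with respect to $\mathbf{p}$ gives the feedback velocity that Step 8 plots. This short derivation assumes the usual convention in which the optimal velocity is $-\nabla_{\mathbf{p}} H$ evaluated at $\mathbf{p} = \nabla u$:

$$\nabla_{\mathbf{p}} H = \mathbf{p} + 2\kappa\, m\, \mathbf{p} = (1 + 2\kappa m)\,\mathbf{p}, \qquad \mathbf{v}^*(t, \mathbf{x}) = -(1 + 2\kappa m)\,\nabla u(t, \mathbf{x})$$

With $\kappa = 0$ this reduces to the familiar $\mathbf{v}^* = -\nabla u$.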

**Time estimate**: 20 minutes

## Step 1: Import Dependencies

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm

from mfg_pde import solve_mfg
from mfg_pde.core.highdim_mfg_problem import GridBasedMFGProblem

## Step 2: Define 2D Problem Class

The `GridBasedMFGProblem` provides infrastructure for nD problems. Key differences from 1D:
- `domain_bounds`: Tuple of `(xmin, xmax, ymin, ymax, ...)`
- `hamiltonian(x, m, p, t)`: `x` and `p` are now `(N, 2)` arrays (one row per grid point) and `m` is an `(N,)` array
- Grid information: `geometry.grid` exposes the spacing, bounds, resolution, and grid points
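To make the `domain_bounds` layout and the `(N, 2)` point arrays concrete, here is a plain NumPy sketch that builds grid points from a bounds tuple. `grid_points_from_bounds` is a hypothetical helper, not part of `mfg_pde`, and it assumes each axis includes both endpoints and that points are flattened in C order (x varying slowest); the library's own grid may use different conventions.

```python
import numpy as np


def grid_points_from_bounds(domain_bounds, resolution):
    """Hypothetical helper: build an (N, D) array of tensor-grid points.

    Assumes each axis includes both endpoints and that points are flattened
    in C order (the last axis varies fastest).
    """
    lows = domain_bounds[0::2]   # (xmin, ymin, ...)
    highs = domain_bounds[1::2]  # (xmax, ymax, ...)
    if isinstance(resolution, int):
        resolution = (resolution,) * len(lows)
    axes = [np.linspace(lo, hi, n) for lo, hi, n in zip(lows, highs, resolution)]
    mesh = np.meshgrid(*axes, indexing="ij")
    return np.stack([g.ravel() for g in mesh], axis=1)


points = grid_points_from_bounds((0.0, 10.0, 0.0, 10.0), 30)
print(points.shape)  # (900, 2)
```

The same helper works unchanged for a 3D bounds tuple such as `(0, 1, 0, 1, 0, 1)`, which is the pattern the "Extension to Higher Dimensions" section describes.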

In [None]:
class TargetAttraction2D(GridBasedMFGProblem):
    """2D target attraction with congestion."""

    def __init__(self, target_location=(5.0, 5.0), congestion_weight=1.0):
        # Initialize with 2D domain [0,10] × [0,10]
        super().__init__(
            domain_bounds=(0.0, 10.0, 0.0, 10.0),  # (xmin, xmax, ymin, ymax)
            grid_resolution=30,  # 30×30 grid
            time_domain=(4.0, 40),  # T=4, Nt=40
            diffusion_coeff=0.2,  # σ = 0.2
        )
        self.target = np.array(target_location)
        self.congestion_weight = congestion_weight

    def initial_density(self, x):
        """
        Gaussian blobs at four corners.

        Args:
            x: (N, 2) array of positions
        """
        corners = np.array(
            [
                [2.0, 2.0],
                [8.0, 2.0],
                [2.0, 8.0],
                [8.0, 8.0],
            ]
        )

        density = np.zeros(len(x))
        for corner in corners:
            dist_sq = np.sum((x - corner) ** 2, axis=1)
            density += np.exp(-2.0 * dist_sq)

        # Normalize: ∫∫ m dx dy = 1
        dx, dy = self.geometry.grid.spacing
        dV = dx * dy
        return density / (np.sum(density) * dV + 1e-10)

    def terminal_cost(self, x):
        """
        Quadratic cost: 5 times the squared distance to the target.

        Args:
            x: (N, 2) array of positions
        """
        dist_sq = np.sum((x - self.target) ** 2, axis=1)
        return 5.0 * dist_sq

    def running_cost(self, x, t):
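        """Constant running cost of 0.1.

        With a fixed horizon T, a spatially constant cost only shifts u by
        0.1 * (T - t); it does not change the gradient of u or the optimal paths.
        """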
        return 0.1 * np.ones(x.shape[0])

    def hamiltonian(self, x, m, p, t):
        """
        H = (1/2)|p|² + κ·m·|p|²

        Args:
            x: (N, 2) positions
            m: (N,) densities
            p: (N, 2) momenta [px, py]
            t: scalar time
        """
        # Compute |p|² = px² + py²
        p_squared = np.sum(p**2, axis=1) if p.ndim > 1 else np.sum(p**2)

        # Standard control cost
        h = 0.5 * p_squared

        # Congestion cost
        h += self.congestion_weight * m * p_squared

        return h


print("2D problem class defined!")
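Because `hamiltonian` is vectorized over grid points, it can be sanity-checked without running the solver. The sketch below copies the same formula into a standalone function (`hamiltonian_2d` is a hypothetical name, not part of `mfg_pde`) and evaluates it at two hand-picked points:

```python
import numpy as np


def hamiltonian_2d(m, p, kappa=1.0):
    """Standalone copy of H = (1/2)|p|^2 + kappa * m * |p|^2 for (N, 2) momenta."""
    p_squared = np.sum(p**2, axis=1)  # |p|^2 for each row
    return 0.5 * p_squared + kappa * m * p_squared


p = np.array([[1.0, 0.0], [1.0, 1.0]])  # |p|^2 = 1 and 2
m = np.array([0.0, 0.5])  # densities at the two points
h = hamiltonian_2d(m, p, kappa=1.0)
# h[0] = 0.5 * 1 = 0.5, h[1] = 0.5 * 2 + 1 * 0.5 * 2 = 2.0
print(h)
```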

## Step 3: Create and Inspect Problem

In [None]:
problem = TargetAttraction2D(target_location=(5.0, 5.0), congestion_weight=1.0)

print("Problem Configuration:")
print(f"  Spatial dimension: {problem.geometry.grid.dimension}D")
print(f"  Domain: {problem.geometry.grid.bounds}")
print(f"  Grid resolution: {problem.geometry.grid.resolution}")
print(f"  Total grid points: {problem.geometry.grid.total_points()}")
print(f"  Time horizon: T = {problem.T}")
print(f"  Time steps: Nt = {problem.nt}")
print(f"  Grid spacing (dx, dy): {problem.geometry.grid.spacing}")
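The configuration values can be checked by hand. Assuming the 30-point grid includes both endpoints of each axis (an assumption about `GridBasedMFGProblem`, which this tutorial does not state), the expected numbers are:

```python
# Hand-check of the Step 3 configuration (pure arithmetic, no mfg_pde needed)
n_per_axis = 30
length = 10.0 - 0.0
T, Nt = 4.0, 40

total_points = n_per_axis**2  # 30 x 30 grid
dx = length / (n_per_axis - 1)  # assumes the endpoints are grid points
dt = T / Nt

print(total_points)  # 900
print(round(dx, 4))  # 0.3448
print(dt)  # 0.1
```

If the grid were cell-centered instead, `dx` would be 10/30, roughly 0.3333; comparing against `problem.geometry.grid.spacing` tells you which convention is in use.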

## Step 4: Visualize Initial Density

In [None]:
# Get grid points and initial density
grid_points = problem.geometry.grid.get_all_points()
m0 = problem.initial_density(grid_points)

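# Reshape to a (nx, ny) grid. This assumes get_all_points() orders points
# with x varying slowest (C order), which matches indexing="ij" below.
nx, ny = problem.geometry.grid.resolution
m0_grid = m0.reshape((nx, ny))

# Create meshgrid for plotting (assumes the grid includes both endpoints of [0, 10])
x_coords = np.linspace(0, 10, nx)
y_coords = np.linspace(0, 10, ny)
X, Y = np.meshgrid(x_coords, y_coords, indexing="ij")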

# Plot
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection="3d")

surf = ax.plot_surface(X, Y, m0_grid, cmap=cm.viridis, alpha=0.8)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("Density m(0, x, y)")
ax.set_title("Initial Density (4 corner blobs)")
fig.colorbar(surf, shrink=0.5, aspect=5)

plt.show()
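The `reshape((nx, ny))` call lines up with `np.meshgrid(..., indexing="ij")` only if the flat point array lists x in the outer (slowest) position. This self-contained NumPy check, which does not use `mfg_pde`, shows the correspondence on a small grid:

```python
import numpy as np

nx, ny = 4, 3
x = np.linspace(0.0, 10.0, nx)
y = np.linspace(0.0, 10.0, ny)
X, Y = np.meshgrid(x, y, indexing="ij")  # X[i, j] = x[i], Y[i, j] = y[j]

# Flatten in C order: flat index k corresponds to (i, j) = (k // ny, k % ny)
points = np.stack([X.ravel(), Y.ravel()], axis=1)
values = points[:, 0] + 100.0 * points[:, 1]  # any function of (x, y)

# Reshaping back recovers a grid aligned with X and Y
grid = values.reshape((nx, ny))
assert np.allclose(grid, X + 100.0 * Y)

# indexing="xy" produces (ny, nx) arrays, which no longer match this layout
Xxy, _ = np.meshgrid(x, y, indexing="xy")
print(Xxy.shape)  # (3, 4)
```

If the plotted density looks transposed, check the point ordering returned by `get_all_points()` before changing the plotting code.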

## Step 5: Solve 2D MFG System

In [None]:
print("Solving 2D MFG system...\n")
result = solve_mfg(problem, verbose=True)

print("\nSolution completed!")
print(f"  Converged: {result.converged}")
print(f"  Iterations: {result.iterations}")
print(f"  Final residual: {result.residual:.6e}")
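This tutorial does not show how `solve_mfg` iterates internally. A common strategy for coupled HJB/Fokker-Planck systems is a damped fixed-point (Picard) iteration, where the iteration count is the number of passes and the residual measures the change between successive iterates. The toy sketch below shows that pattern on a scalar map; `damped_fixed_point` is a hypothetical helper, and `math.cos` merely stands in for "solve HJB given m, then solve FP given u".

```python
import math


def damped_fixed_point(F, x0, damping=0.5, tol=1e-10, max_iter=500):
    """Iterate x <- (1 - damping) * x + damping * F(x) until the update is below tol.

    Returns (x, converged, iterations, residual).
    """
    x = x0
    residual = float("inf")
    for iteration in range(1, max_iter + 1):
        x_new = (1.0 - damping) * x + damping * F(x)
        residual = abs(x_new - x)  # change between successive iterates
        x = x_new
        if residual < tol:
            return x, True, iteration, residual
    return x, False, max_iter, residual


x, converged, iterations, residual = damped_fixed_point(math.cos, x0=1.0)
print(converged, iterations, f"{residual:.2e}")
```

The damping factor trades speed for robustness: smaller values take more iterations but are less likely to oscillate when the coupling between the two equations is strong.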

## Step 6: Visualize Density Evolution

In [None]:
# Select 4 time snapshots
time_indices = [0, result.M.shape[0] // 3, 2 * result.M.shape[0] // 3, -1]
time_labels = ["t = 0 (initial)", "t ≈ T/3", "t ≈ 2T/3", "t = T (final)"]

fig, axes = plt.subplots(2, 2, figsize=(14, 12))
axes = axes.ravel()

for idx, (t_idx, t_label) in enumerate(zip(time_indices, time_labels, strict=False)):
    # Get density at this time
    m_t = result.M[t_idx].reshape((nx, ny))

    # Contour plot
    contour = axes[idx].contourf(X, Y, m_t, levels=20, cmap="viridis")
    axes[idx].plot(5.0, 5.0, "r*", markersize=20, label="Target")
    axes[idx].set_xlabel("x")
    axes[idx].set_ylabel("y")
    axes[idx].set_title(t_label)
    axes[idx].legend()
    axes[idx].set_aspect("equal")
    plt.colorbar(contour, ax=axes[idx])

plt.tight_layout()
plt.show()
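The middle snapshot labels are approximate because integer division rarely lands exactly on T/3. Assuming `result.M` stores Nt + 1 = 41 time levels on a uniform grid over [0, T] (an assumption about the result layout), the selected indices map to these times:

```python
T, Nt = 4.0, 40
n_levels = Nt + 1  # assumed number of rows in result.M
dt = T / Nt

# Same index rule as the plotting cell; index -1 is the last level, n_levels - 1
time_indices = [0, n_levels // 3, 2 * n_levels // 3, n_levels - 1]
times = [round(i * dt, 6) for i in time_indices]
print(time_indices)  # [0, 13, 27, 40]
print(times)  # [0.0, 1.3, 2.7, 4.0]
```

So the two middle panels show t = 1.3 and t = 2.7 rather than exactly 4/3 and 8/3.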

## Step 7: Check Mass Conservation

Verify that mass is conserved, i.e. $\iint m(t,x,y)\,dx\,dy = 1$ for all $t$.

In [None]:
# Compute total mass at each timestep
dx, dy = problem.geometry.grid.spacing
dV = dx * dy

masses = []
for t_idx in range(result.M.shape[0]):
    mass_t = np.sum(result.M[t_idx]) * dV
    masses.append(mass_t)

masses = np.array(masses)

# Plot mass conservation
plt.figure(figsize=(10, 5))
plt.plot(problem.tSpace, masses, "b-", linewidth=2, label="Total mass")
plt.axhline(y=1.0, color="r", linestyle="--", label="Target mass = 1.0")
plt.xlabel("Time t")
plt.ylabel("Total mass $\\int\\int m(t,x,y) dx dy$")
plt.title("Mass Conservation Check")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

print("\nMass conservation:")
print(f"  Initial mass: {masses[0]:.6f}")
print(f"  Final mass:   {masses[-1]:.6f}")
print(f"  Mass change:  {abs(masses[-1] - masses[0]):.6e}")
print(f"  Relative error: {abs(masses[-1] - 1.0):.2%}")
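The initial mass is 1 by construction, because `initial_density` normalizes with the same rule (sum of nodal values times `dV`) that this check uses. A standalone NumPy reproduction, assuming a 30 × 30 grid that includes the endpoints of [0, 10]:

```python
import numpy as np

n = 30
x = np.linspace(0.0, 10.0, n)
X, Y = np.meshgrid(x, x, indexing="ij")
points = np.stack([X.ravel(), Y.ravel()], axis=1)  # (900, 2)
dx = dy = x[1] - x[0]
dV = dx * dy

# Same four-corner Gaussian as TargetAttraction2D.initial_density
corners = np.array([[2.0, 2.0], [8.0, 2.0], [2.0, 8.0], [8.0, 8.0]])
density = np.zeros(len(points))
for corner in corners:
    density += np.exp(-2.0 * np.sum((points - corner) ** 2, axis=1))
m0 = density / (np.sum(density) * dV)

mass = np.sum(m0) * dV  # same quadrature rule as the mass check above
print(f"{mass:.6f}")  # 1.000000
```

Any drift at later time steps therefore reflects the solver's time-stepping, not the initial normalization.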

## Step 8: Analyze Optimal Paths

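Compute the velocity field from the Hamiltonian gradient, $\mathbf{v} = -\nabla_{\mathbf{p}} H$, evaluated at $\mathbf{p} = \nabla u$. For the congestion Hamiltonian above this gives $\mathbf{v} = -(1 + 2\kappa m)\,\nabla u$.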

In [None]:
# Velocity field at t = T/2
t_mid = result.M.shape[0] // 2

u_mid = result.U[t_mid].reshape((nx, ny))
m_mid = result.M[t_mid].reshape((nx, ny))
dx, dy = problem.geometry.grid.spacing

# With indexing="ij", axis 0 is x and axis 1 is y, so np.gradient returns (du/dx, du/dy)
du_dx, du_dy = np.gradient(u_mid, dx, dy)

# v = -grad_p H at p = grad u; for H = (1/2)|p|² + κ·m·|p|² this is v = -(1 + 2κm) grad u
speed_factor = 1.0 + 2.0 * problem.congestion_weight * m_mid
vx = -speed_factor * du_dx
vy = -speed_factor * du_dy

# Plot velocity field over the density
plt.figure(figsize=(10, 8))
density_plot = plt.contourf(X, Y, m_mid, levels=15, cmap="viridis", alpha=0.6)

# Quiver plot (subsample for clarity)
skip = 3
plt.quiver(
    X[::skip, ::skip], Y[::skip, ::skip], vx[::skip, ::skip], vy[::skip, ::skip], color="white", scale=50, width=0.003
)

plt.plot(5.0, 5.0, "r*", markersize=20, label="Target")
plt.xlabel("x")
plt.ylabel("y")
plt.title("Velocity Field (white arrows) and Density (color) at t=T/2")
# Attach the colorbar to the density contours (otherwise plt.quiver is the current mappable)
plt.colorbar(density_plot, label="Density")
plt.legend()
plt.axis("equal")
plt.show()
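To check gradient ordering and sign, apply `np.gradient` to `indexing="ij"` arrays holding a function whose gradient is known. Here $u = |\mathbf{x} - \mathbf{x}^*|^2$ is only a stand-in for the value function, chosen because $\nabla u = 2(\mathbf{x} - \mathbf{x}^*)$ is exact, so $-\nabla u$ must point toward the target:

```python
import numpy as np

n = 11
x = np.linspace(0.0, 10.0, n)  # dx = 1
y = np.linspace(0.0, 10.0, n)  # dy = 1
X, Y = np.meshgrid(x, y, indexing="ij")
target = np.array([5.0, 5.0])

u = (X - target[0]) ** 2 + (Y - target[1]) ** 2
du_dx, du_dy = np.gradient(u, x[1] - x[0], y[1] - y[0])  # axis 0 -> x, axis 1 -> y

vx, vy = -du_dx, -du_dy  # kappa = 0 case: v = -grad u

# At (x, y) = (2, 6): interior central differences are exact for quadratics
i, j = 2, 6
print(vx[i, j], vy[i, j])  # 6.0 -2.0
```

The expected velocity at (2, 6) is $-2\big((2, 6) - (5, 5)\big) = (6, -2)$, pointing toward the target. Unpacking the two outputs of `np.gradient` in the opposite order would give a different direction. This mistake is easy to miss visually when dx = dy, because swapping the spacing arguments then has no effect.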

## Summary

### What You Learned

1. How to use `GridBasedMFGProblem` for multi-dimensional problems
2. How to work with 2D spatial domains and grids
3. How to handle multi-dimensional Hamiltonians
4. How to visualize 2D density evolution
5. How to verify mass conservation in 2D
6. How to analyze velocity fields

### Key Differences from 1D

| Aspect | 1D | 2D |
|--------|----|----|  
| Domain | `(xmin, xmax)` | `(xmin, xmax, ymin, ymax)` |
| Position | Scalar `x` | Array `(x, y)` shape `(N, 2)` |
| Momentum | Scalar `p` | Array `(px, py)` shape `(N, 2)` |
| Hamiltonian | `H(x, m, p, t)` | `H(x, m, p, t)` with squared norm `px² + py²` |
| Mass integral | `∫ m dx` | `∫∫ m dx dy` |

### Extension to Higher Dimensions

The same pattern extends to 3D, 4D, and beyond:
- 3D: `domain_bounds=(xmin, xmax, ymin, ymax, zmin, zmax)`
- 4D and beyond: add one `(min, max)` pair per extra dimension
- Positions `x` and momenta `p` become `(N, D)` arrays for D dimensions
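The pattern extends, but the cost grows quickly: with n points per axis, a tensor grid has n^D points. The following back-of-the-envelope sketch assumes that the solver stores the full space-time density as float64 values with Nt + 1 = 41 time levels (an assumption, not a documented `mfg_pde` detail):

```python
n, n_levels, bytes_per_value = 30, 41, 8  # points per axis, time levels, float64

for D in (1, 2, 3, 4):
    points = n**D  # tensor-grid size in D dimensions
    megabytes = points * n_levels * bytes_per_value / 1e6
    print(f"D={D}: {points:>7} points, ~{megabytes:.1f} MB per space-time array")
```

At D = 4 this is already about 266 MB for a single space-time array, before counting the value function `U` or any solver workspace.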

### Next Steps

Proceed to **Tutorial 04: Particle Methods** to learn about alternative numerical approaches.