# Optimal Transport on 2D Grid with Obstacles

This notebook demonstrates the Flow Sinkhorn algorithm on a **2D regular grid** with **cost modulation** by Gaussian bumps.

We will:
1. Create a 30×30 square grid graph
2. Modulate edge costs with two Gaussian bumps (obstacles)
3. Place source at top-left and sink at bottom-right
4. Compute exact optimal transport using linear programming
5. Compute approximate transport using Flow Sinkhorn
6. Visualize how the flow avoids high-cost regions (obstacles)

This example illustrates **obstacle avoidance** in optimal transport.

In [None]:
# Install Flow Sinkhorn from GitHub (for Colab)
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    print('Running on Colab - installing flowsinkhorn from GitHub...')
    !pip install -q git+https://github.com/gpeyre/flow-sinkhorn.git
    print('Installation complete!')
else:
    print('Running locally - using local flowsinkhorn')

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from matplotlib.collections import LineCollection
import sys
import time
import warnings


# Import Flow Sinkhorn toolbox
from flowsinkhorn import sinkhorn_w1_sparse, solve_w1_exact
import sparse

# Silence every warning category for cleaner notebook output.
# NOTE(review): this also hides numerically meaningful warnings (e.g. a
# RuntimeWarning from np.log10 if an error history contains zeros); consider
# filtering only specific categories instead.
warnings.filterwarnings('ignore')

# Render figures inline and set the default figure size for all plots.
%matplotlib inline
plt.rcParams['figure.figsize'] = (12, 10)

## 1. Create 2D Square Grid

We create a regular grid graph:
- Grid size: 30×30 = 900 vertices
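- Each interior vertex is connected to its 4 neighbors (up, down, left, right); boundary vertices have 2 or 3, giving $2 \cdot 30 \cdot 29 = 1740$ edges in total
- Vertex $(i, j)$ with $i, j \in \{0, 1, \dots, 29\}$ (row $i$, column $j$) has index $30i + j$ and is drawn at coordinates $(x, y) = (j, i)$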

In [None]:
def create_grid_graph(n):
    """
    Build an n×n 4-neighbour grid graph.

    Vertices are numbered row-major: the vertex in row ``i`` and column ``j``
    has index ``i * n + j`` and sits at planar coordinates ``(x, y) = (j, i)``.

    Parameters
    ----------
    n : int
        Number of vertices along each side of the grid.

    Returns
    -------
    positions : ndarray of shape (n*n, 2)
        Float (x, y) coordinates of every vertex.
    A : ndarray of shape (n*n, n*n)
        Symmetric float adjacency matrix with 1.0 on edges and 0.0 elsewhere.
    edges : list of tuples
        Undirected edges ``(u, v)`` with ``u < v``. For each vertex in
        row-major order, the edge to its right neighbour is listed before the
        edge to its bottom neighbour.
    """
    cells = [(row, col) for row in range(n) for col in range(n)]

    positions = np.zeros((n * n, 2))
    if cells:
        # Column index is the x coordinate, row index is the y coordinate.
        positions[:] = [(col, row) for row, col in cells]

    edge_list = []
    for row, col in cells:
        here = row * n + col
        if col + 1 < n:            # a neighbour exists to the right
            edge_list.append((here, here + 1))
        if row + 1 < n:            # a neighbour exists below
            edge_list.append((here, here + n))

    adjacency = np.zeros((n * n, n * n))
    for u, v in edge_list:
        adjacency[u, v] = adjacency[v, u] = 1

    return positions, adjacency, edge_list

# Create grid
grid_size = 30
positions, A, edges = create_grid_graph(grid_size)
n_vertices = len(positions)
n_edges = len(edges)

# Sanity checks: an n×n 4-neighbour grid has n² vertices, 2·n·(n-1) edges,
# and a symmetric adjacency matrix.
assert n_vertices == grid_size ** 2, "unexpected vertex count"
assert n_edges == 2 * grid_size * (grid_size - 1), "unexpected edge count"
assert np.array_equal(A, A.T), "adjacency matrix must be symmetric"

print(f"Grid created:")
print(f"  - Grid size: {grid_size}×{grid_size}")
print(f"  - {n_vertices} vertices")
print(f"  - {n_edges} edges")
print(f"  - Average degree: {A.sum() / n_vertices:.1f}")

## 2. Add Gaussian Bumps as Obstacles

We create two Gaussian bumps in the middle of the grid to modulate edge costs.

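For each edge $(i, j)$ with endpoints $x_i, x_j$ and midpoint $m_{ij} = (x_i + x_j)/2$, we compute:
$$
W_{ij} = \|x_i - x_j\| \left( 1 + \alpha \left( g_1(m_{ij}) + g_2(m_{ij}) \right) \right)
$$

where $g_k(x) = \exp\left(-\|x - c_k\|^2 / (2\sigma_k^2)\right)$ is a Gaussian bump centered at $c_k$ with width $\sigma_k$. On the unit grid every edge has length $\|x_i - x_j\| = 1$, so this reduces to $W_{ij} = 1 + \alpha \left( g_1(m_{ij}) + g_2(m_{ij}) \right)$.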

This creates **high-cost regions** that the optimal transport will tend to avoid.

In [None]:
def gaussian_bump(positions, center, sigma):
    """
    Evaluate an unnormalised isotropic Gaussian at every row of ``positions``.

    The value at a point x is ``exp(-||x - center||^2 / (2 * sigma^2))``, so it
    equals 1 at the center and decays with distance.

    Parameters
    ----------
    positions : ndarray of shape (n, 2)
        Points at which to evaluate the bump.
    center : tuple (x, y)
        Location of the peak.
    sigma : float
        Width (standard deviation) of the bump.

    Returns
    -------
    values : ndarray of shape (n,)
        Bump value at each point, in (0, 1].
    """
    center_row = np.array(center)[None, :]
    sq_dist = ((positions - center_row) ** 2).sum(axis=1)
    return np.exp(-sq_dist / (2 * sigma ** 2))

def compute_modulated_costs(positions, A, edges, centers, sigmas, alpha=10.0,
                            non_edge_cost=1e9):
    """
    Compute edge costs modulated by Gaussian bumps.

    Each edge (i, j) gets cost
    ``||x_i - x_j|| * (1 + alpha * sum_k g_k(midpoint))``, where ``g_k`` is the
    Gaussian bump ``k`` evaluated at the edge midpoint.

    Parameters
    ----------
    positions : ndarray of shape (n, 2)
        Vertex positions.
    A : ndarray of shape (n, n)
        Adjacency matrix; entries equal to 0 are treated as non-edges.
    edges : list of tuples
        Edge list ``(i, j)``; costs are written symmetrically to W[i, j] and
        W[j, i].
    centers : list of tuples
        Centers of the Gaussian bumps.
    sigmas : list of floats
        Standard deviations of the bumps (one per center).
    alpha : float
        Amplitude of cost modulation.
    non_edge_cost : float, optional
        Value stored in W for vertex pairs with ``A == 0`` (default 1e9, a
        "no edge" sentinel).

    Returns
    -------
    W : ndarray of shape (n, n)
        Cost matrix: modulated cost on edges, ``non_edge_cost`` on non-edges,
        0 on the diagonal.
    cost_field : ndarray of shape (n,)
        Sum of the (unscaled) Gaussian bumps at each vertex, for visualization.

    Raises
    ------
    ValueError
        If ``centers`` and ``sigmas`` have different lengths (``zip`` would
        otherwise silently drop the extra bumps).
    """
    if len(centers) != len(sigmas):
        raise ValueError(
            f"centers and sigmas must have the same length, "
            f"got {len(centers)} and {len(sigmas)}"
        )
    n = len(positions)

    # Total Gaussian field at vertices (used only for plotting)
    cost_field = np.zeros(n)
    for center, sigma in zip(centers, sigmas):
        cost_field += gaussian_bump(positions, center, sigma)

    # Evaluate the bumps at all edge midpoints at once; reshape keeps an empty
    # edge list as a valid (0, 2) index array.
    edge_idx = np.asarray(edges, dtype=int).reshape(-1, 2)
    src, dst = edge_idx[:, 0], edge_idx[:, 1]
    midpoints = (positions[src] + positions[dst]) / 2

    bump_values = np.zeros(len(edge_idx))
    for center, sigma in zip(centers, sigmas):
        bump_values += gaussian_bump(midpoints, center, sigma)

    # Base cost (Euclidean edge length) scaled by the bump modulation
    base_costs = np.linalg.norm(positions[src] - positions[dst], axis=1)
    edge_costs = base_costs * (1 + alpha * bump_values)

    W = np.zeros((n, n))
    W[src, dst] = edge_costs
    W[dst, src] = edge_costs

    # Large sentinel value for non-edges, zero self-cost
    W[A == 0] = non_edge_cost
    np.fill_diagonal(W, 0)

    return W, cost_field

# Two equal-width Gaussian bumps side by side on the horizontal midline of the
# grid, at 40% and 60% of its width.
centers = [
    (grid_size * 0.4, grid_size * 0.5),  # Left-center bump
    (grid_size * 0.6, grid_size * 0.5),  # Right-center bump
]
sigmas = [3.0] * len(centers)  # Same standard deviation for every bump
alpha = 10.0  # Amplitude of cost modulation

W, cost_field = compute_modulated_costs(positions, A, edges, centers, sigmas, alpha)

# Costs of the actual grid edges only (W holds a large sentinel elsewhere)
edge_costs = W[A > 0]

print(f"\nCost modulation:")
print(f"  - Number of bumps: {len(centers)}")
print(f"  - Bump centers: {centers}")
print(f"  - Sigmas: {sigmas}")
print(f"  - Amplitude α: {alpha}")
print(f"  - Min edge cost: {edge_costs.min():.4f}")
print(f"  - Max edge cost: {edge_costs.max():.4f}")
print(f"  - Mean edge cost: {edge_costs.mean():.4f}")

## 3. Visualize Grid and Cost Field

Display the grid with the cost field (Gaussian bumps) as a heatmap.

In [None]:
def plot_grid_with_costs(positions, cost_field, grid_size, title="Grid with Cost Field"):
    """
    Plot the grid with cost field as a heatmap.

    Each heatmap pixel is centred on the grid vertex it represents. Vertex
    ``i * grid_size + j`` lives at (x, y) = (j, i) and is drawn as the pixel
    spanning [j - 0.5, j + 0.5] x [i - 0.5, i + 0.5]. Without this, an
    ``extent`` of [0, grid_size] would shift the heatmap by half a cell
    relative to the vertices.

    Parameters
    ----------
    positions : ndarray of shape (n, 2)
        Vertex positions (not used for drawing; kept for a signature that
        matches the other plotting helpers).
    cost_field : ndarray of shape (n,)
        Cost at each vertex, with n == grid_size**2 in row-major order.
    grid_size : int
        Size of the grid.
    title : str
        Plot title.
    """
    # Reshape cost field to 2D grid (row i, column j)
    cost_grid = cost_field.reshape(grid_size, grid_size)
    # Half-cell padding so pixel centres fall on integer vertex coordinates
    x_min, x_max = -0.5, grid_size - 0.5

    fig, ax = plt.subplots(figsize=(12, 11))

    # Plot cost field as heatmap (origin='upper': row 0 at the top)
    im = ax.imshow(cost_grid, cmap='YlOrRd', origin='upper',
                   extent=[x_min, x_max, x_max, x_min], alpha=0.8)
    plt.colorbar(im, ax=ax, label='Cost Field')

    # Light guide lines through every third row/column of vertices
    for k in range(0, grid_size, 3):
        ax.axhline(y=k, color='gray', linewidth=0.3, alpha=0.3)
        ax.axvline(x=k, color='gray', linewidth=0.3, alpha=0.3)

    ax.set_xlim(x_min, x_max)
    ax.set_ylim(x_max, x_min)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_title(title)
    ax.set_aspect('equal')

    plt.tight_layout()
    plt.show()

plot_grid_with_costs(positions, cost_field, grid_size, 
                     title="30×30 Grid with Gaussian Obstacles (High Cost = Red)")

## 4. Define Source and Sink

- **Source**: Top-left corner (vertex 0)
- **Sink**: Bottom-right corner (vertex n²-1)

The optimal transport will find the best path from top-left to bottom-right, avoiding the high-cost regions.

In [None]:
def create_corner_sources_sinks(grid_size):
    """
    Create a unit source at the top-left vertex and a unit sink at the
    bottom-right vertex of a grid_size×grid_size grid.

    Parameters
    ----------
    grid_size : int
        Size of the grid (must be at least 2).

    Returns
    -------
    source_idx : int
        Index of source vertex (row 0, column 0).
    sink_idx : int
        Index of sink vertex (last row, last column).
    z : ndarray of shape (grid_size**2,)
        Source/sink vector: +1 at the source, -1 at the sink, 0 elsewhere.

    Raises
    ------
    ValueError
        If ``grid_size < 2``. The source and sink would then be the same
        vertex, and ``z`` would no longer sum to zero.
    """
    if grid_size < 2:
        raise ValueError(f"grid_size must be at least 2, got {grid_size}")

    n = grid_size * grid_size
    source_idx = 0        # top-left corner
    sink_idx = n - 1      # bottom-right corner

    z = np.zeros(n)
    z[source_idx] = 1.0
    z[sink_idx] = -1.0

    return source_idx, sink_idx, z

source_idx, sink_idx, z = create_corner_sources_sinks(grid_size)

print(f"Source and sink:")
print(f"  - Source index: {source_idx} at position {positions[source_idx]}")
print(f"  - Sink index: {sink_idx} at position {positions[sink_idx]}")
print(f"  - Euclidean distance: {np.linalg.norm(positions[source_idx] - positions[sink_idx]):.2f}")

# Visualize source and sink on top of the cost field.
# Pixels are centred on integer vertex coordinates. The axes are padded by half
# a cell so the corner markers at (0, 0) and (n-1, n-1) are not clipped by
# the axes frame.
cost_grid = cost_field.reshape(grid_size, grid_size)
x_min, x_max = -0.5, grid_size - 0.5
fig, ax = plt.subplots(figsize=(12, 11))
im = ax.imshow(cost_grid, cmap='YlOrRd', origin='upper', 
               extent=[x_min, x_max, x_max, x_min], alpha=0.8)
plt.colorbar(im, ax=ax, label='Cost Field')

# Plot source and sink
ax.scatter(positions[source_idx, 0], positions[source_idx, 1], 
          s=500, c='blue', marker='s', edgecolors='black', linewidths=3,
          label='Source (top-left)', zorder=10)
ax.scatter(positions[sink_idx, 0], positions[sink_idx, 1], 
          s=500, c='green', marker='s', edgecolors='black', linewidths=3,
          label='Sink (bottom-right)', zorder=10)

ax.set_xlim(x_min, x_max)
ax.set_ylim(x_max, x_min)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('Grid with Source and Sink')
ax.legend(fontsize=12, loc='upper right')
ax.set_aspect('equal')
plt.tight_layout()
plt.show()

## 5. Exact Optimal Transport (Linear Programming)

Compute the exact optimal transport using CVXPY.

**Note**: For a 900-vertex graph, this may take some time (~10-60 seconds depending on the solver).

In [None]:
print("Computing exact optimal transport...")
print("(This may take 10-60 seconds for a 900-vertex graph)\n")

# perf_counter is a monotonic, high-resolution clock meant for measuring
# elapsed time; time.time() follows the wall clock and can jump.
start = time.perf_counter()
F_exact, obj_exact, status = solve_w1_exact(W, z, verbose=False)
time_exact = time.perf_counter() - start

# NOTE(review): later cells compare against obj_exact unconditionally; if
# `status` can report a non-optimal solve, those comparisons are meaningless.
print(f"Exact solver:")
print(f"  - Status: {status}")
print(f"  - Optimal cost: {obj_exact:.6f}")
print(f"  - Time: {time_exact:.2f}s")
print(f"  - Non-zero flows: {np.sum(F_exact > 1e-6)}")

## 6. Visualize Exact Flow

Display the exact optimal transport flow on the grid.

In [None]:
def plot_grid_with_flow(positions, cost_field, F, z, grid_size, threshold=1e-6,
                        title="Grid with Flow", flow_color='purple', 
                        flow_width_scale=5):
    """
    Plot a flow as undirected line segments over the cost-field heatmap.

    For every vertex pair (i, j), the drawn magnitude is
    ``max(F[i, j], F[j, i])``. Only pairs whose magnitude exceeds
    ``threshold`` are drawn. Line widths are proportional to magnitude
    divided by ``F.max()``.

    Parameters
    ----------
    positions : ndarray of shape (n, 2)
        Vertex positions.
    cost_field : ndarray of shape (n,)
        Cost at each vertex, with n == grid_size**2 in row-major order.
    F : array_like of shape (n, n)
        Flow matrix.
    z : ndarray of shape (n,)
        Source/sink vector (positive entries are sources, negative are sinks).
    grid_size : int
        Size of grid.
    threshold : float
        Minimum flow to display.
    title : str
        Plot title.
    flow_color : str
        Color for flow edges.
    flow_width_scale : float
        Line width of the edge carrying the maximum flow.
    """
    cost_grid = cost_field.reshape(grid_size, grid_size)
    # Half-cell padding: pixel centres on integer vertex coordinates, and the
    # corner source/sink markers are not clipped by the axes frame.
    x_min, x_max = -0.5, grid_size - 0.5

    fig, ax = plt.subplots(figsize=(14, 13))

    # Plot cost field
    im = ax.imshow(cost_grid, cmap='YlOrRd', origin='upper',
                   extent=[x_min, x_max, x_max, x_min], alpha=0.6)
    plt.colorbar(im, ax=ax, label='Cost Field')

    # Undirected flow magnitude per vertex pair, extracted from the strict
    # upper triangle in one vectorised pass. This replaces a Python double
    # loop over all n*(n-1)/2 pairs (~400k lookups on a 30×30 grid).
    F = np.asarray(F)
    flow_sym = np.maximum(F, F.T)
    rows, cols = np.nonzero(np.triu(flow_sym > threshold, k=1))
    flow_weights = flow_sym[rows, cols]
    max_flow = F.max()

    # Plot flow as line collection (guard against a non-positive maximum)
    if len(flow_weights) > 0 and max_flow > 0:
        flow_segments = np.stack([positions[rows], positions[cols]], axis=1)
        normalized_weights = flow_weights / max_flow
        lc = LineCollection(flow_segments,
                            linewidths=normalized_weights * flow_width_scale,
                            colors=flow_color, alpha=0.8, zorder=5)
        ax.add_collection(lc)

    # Plot source and sink
    sources = np.where(z > 0)[0]
    sinks = np.where(z < 0)[0]

    if len(sources) > 0:
        ax.scatter(positions[sources, 0], positions[sources, 1],
                   s=400, c='blue', marker='s', edgecolors='black', 
                   linewidths=3, label='Source', zorder=10)
    if len(sinks) > 0:
        ax.scatter(positions[sinks, 0], positions[sinks, 1],
                   s=400, c='green', marker='s', edgecolors='black',
                   linewidths=3, label='Sink', zorder=10)

    ax.set_xlim(x_min, x_max)
    ax.set_ylim(x_max, x_min)
    ax.set_xlabel('X', fontsize=12)
    ax.set_ylabel('Y', fontsize=12)
    ax.set_title(f"{title}\n({len(flow_weights)} edges with flow > {threshold:.2e})",
                 fontsize=14)
    ax.legend(fontsize=12, loc='upper right')
    ax.set_aspect('equal')

    plt.tight_layout()
    plt.show()

# Visualize exact flow: show edges carrying more than 1% of the peak flow
threshold_exact = F_exact.max() / 100
plot_grid_with_flow(positions, cost_field, F_exact, z, grid_size,
                    threshold=threshold_exact,
                    title="Exact Optimal Transport Flow (CVXPY)",
                    flow_color='purple',
                    flow_width_scale=8)

## 7. Flow Sinkhorn with Large Regularization

Use a large entropic regularization parameter ε: Sinkhorn converges faster, but the resulting flow is more diffuse (spread over more edges) than the exact solution.

In [None]:
# Create sparse cost matrix: one entry per directed grid edge (A is symmetric).
# The fill value matches the 1e9 "no edge" sentinel of the dense W.
edge_coords = A.nonzero()
Ws = sparse.COO(edge_coords, W[edge_coords], shape=W.shape, fill_value=1e9)

# Large regularization
epsilon_large = 0.5
niter = 2000

print(f"Computing Sinkhorn flow with large regularization (ε = {epsilon_large})...")
# perf_counter: monotonic, high-resolution clock for elapsed-time measurements
start = time.perf_counter()
F_sinkhorn_large, err_large, h_large = sinkhorn_w1_sparse(Ws, z, epsilon=epsilon_large, niter=niter)
time_sinkhorn_large = time.perf_counter() - start

# Convert to dense
F_sinkhorn_large_dense = F_sinkhorn_large.todense()

# Transport cost <F, W> of the Sinkhorn flow
cost_large = np.sum(F_sinkhorn_large_dense * W)

print(f"\nSinkhorn (large ε):")
print(f"  - Final error: {err_large[-1]:.2e}")
print(f"  - Cost: {cost_large:.6f}")
print(f"  - Relative cost error: {abs(cost_large - obj_exact) / obj_exact * 100:.2f}%")
print(f"  - Time: {time_sinkhorn_large:.2f}s")
print(f"  - Speedup vs exact: {time_exact / time_sinkhorn_large:.1f}x")

In [None]:
# Convergence of the large-ε run: log10 of the error history returned by
# sinkhorn_w1_sparse, one value per recorded iteration.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(np.log10(err_large))
ax.set_xlabel('Iteration', fontsize=12)
ax.set_ylabel('log10(Error)', fontsize=12)
ax.set_title(f'Sinkhorn Convergence (ε = {epsilon_large})', fontsize=14)
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

In [None]:
# Visualize the large-ε flow, keeping only edges carrying more than 2% of
# the peak flow.
threshold_large = F_sinkhorn_large_dense.max() / 50
plot_grid_with_flow(
    positions,
    cost_field,
    F_sinkhorn_large_dense,
    z,
    grid_size,
    threshold=threshold_large,
    title=f"Sinkhorn Flow (ε = {epsilon_large}, large regularization)",
    flow_color='orange',
    flow_width_scale=6,
)

## 8. Flow Sinkhorn with Small Regularization

Use a smaller ε to get closer to the exact (sparse) solution. Convergence is slower, so we run more iterations.

In [None]:
# Small regularization
epsilon_small = 0.05
niter = 5000

print(f"Computing Sinkhorn flow with small regularization (ε = {epsilon_small})...")
# perf_counter: monotonic, high-resolution clock for elapsed-time measurements
start = time.perf_counter()
F_sinkhorn_small, err_small, h_small = sinkhorn_w1_sparse(Ws, z, epsilon=epsilon_small, niter=niter)
time_sinkhorn_small = time.perf_counter() - start

# Convert to dense
F_sinkhorn_small_dense = F_sinkhorn_small.todense()

# Transport cost <F, W> of the Sinkhorn flow
cost_small = np.sum(F_sinkhorn_small_dense * W)

print(f"\nSinkhorn (small ε):")
print(f"  - Final error: {err_small[-1]:.2e}")
print(f"  - Cost: {cost_small:.6f}")
print(f"  - Relative cost error: {abs(cost_small - obj_exact) / obj_exact * 100:.2f}%")
print(f"  - Time: {time_sinkhorn_small:.2f}s")
print(f"  - Speedup vs exact: {time_exact / time_sinkhorn_small:.1f}x")

In [None]:
# Convergence comparison: log10 error histories of both Sinkhorn runs on
# shared axes.
fig, ax = plt.subplots(figsize=(12, 5))
for history, eps_value, tag in ((err_large, epsilon_large, 'large'),
                                (err_small, epsilon_small, 'small')):
    ax.plot(np.log10(history), label=f'ε = {eps_value} ({tag})', linewidth=2)
ax.set_xlabel('Iteration', fontsize=12)
ax.set_ylabel('log10(Error)', fontsize=12)
ax.set_title('Sinkhorn Convergence Comparison', fontsize=14)
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

In [None]:
# Visualize the small-ε flow, keeping only edges carrying more than 2% of
# the peak flow.
threshold_small = F_sinkhorn_small_dense.max() / 50
plot_grid_with_flow(
    positions,
    cost_field,
    F_sinkhorn_small_dense,
    z,
    grid_size,
    threshold=threshold_small,
    title=f"Sinkhorn Flow (ε = {epsilon_small}, small regularization)",
    flow_color='cyan',
    flow_width_scale=6,
)

## 9. Comparison Summary

Compare all methods and visualize the obstacle avoidance behavior.

In [None]:
# Summary table: costs and runtimes relative to the exact LP solution
rule = "=" * 80
print("\n" + rule)
print("COMPARISON SUMMARY")
print(rule)
print(f"{'Method':<30} {'Cost':>12} {'Rel. Error':>12} {'Time (s)':>10} {'Speedup':>8}")
print("-" * 80)
print(f"{'Exact (CVXPY)':<30} {obj_exact:>12.6f} {0.0:>11.2f}% {time_exact:>10.2f} {'1.0x':>8}")
for eps_value, cost_value, elapsed in ((epsilon_large, cost_large, time_sinkhorn_large),
                                       (epsilon_small, cost_small, time_sinkhorn_small)):
    row_label = 'Sinkhorn (ε=' + str(eps_value) + ')'
    rel_error_pct = abs(cost_value - obj_exact) / obj_exact * 100
    speedup = time_exact / elapsed
    print(f"{row_label:<30} {cost_value:>12.6f} {rel_error_pct:>11.2f}% "
          f"{elapsed:>10.2f} {speedup:>7.1f}x")
print(rule)

# Bar plots: transport cost (with the exact optimum as a reference line) and runtime
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

methods = ['Exact', f'Sinkhorn\n(ε={epsilon_large})', f'Sinkhorn\n(ε={epsilon_small})']
costs = [obj_exact, cost_large, cost_small]
times = [time_exact, time_sinkhorn_large, time_sinkhorn_small]
colors = ['purple', 'orange', 'cyan']
bar_style = dict(color=colors, alpha=0.7, edgecolor='black', linewidth=2)

# Cost comparison
ax1.bar(methods, costs, **bar_style)
ax1.axhline(y=obj_exact, color='purple', linestyle='--', linewidth=2, label='Exact cost')
ax1.set_ylabel('Transport Cost', fontsize=12)
ax1.set_title('Cost Comparison', fontsize=14)
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3, axis='y')

# Time comparison
ax2.bar(methods, times, **bar_style)
ax2.set_ylabel('Computation Time (s)', fontsize=12)
ax2.set_title('Runtime Comparison', fontsize=14)
ax2.grid(True, alpha=0.3, axis='y')

fig.tight_layout()
plt.show()

## 10. Side-by-Side Flow Comparison

Visualize all three solutions together to see how they handle obstacle avoidance.

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(20, 7))

flows = [F_exact, F_sinkhorn_large_dense, F_sinkhorn_small_dense]
titles = [
    'Exact Flow (CVXPY)',
    f'Sinkhorn ε={epsilon_large}',
    f'Sinkhorn ε={epsilon_small}'
]
flow_colors = ['purple', 'orange', 'cyan']
thresholds = [threshold_exact, threshold_large, threshold_small]

# The cost field is identical for every panel, so reshape it once.
cost_grid = cost_field.reshape(grid_size, grid_size)
# Half-cell padding: pixel centres on integer vertex coordinates, and the
# corner source/sink markers are not clipped by the axes frame.
x_min, x_max = -0.5, grid_size - 0.5

for ax, F, title, color, threshold in zip(axes, flows, titles, flow_colors, thresholds):
    # Plot cost field
    ax.imshow(cost_grid, cmap='YlOrRd', origin='upper',
              extent=[x_min, x_max, x_max, x_min], alpha=0.6)

    # Undirected flow magnitude max(F[i, j], F[j, i]) over the strict upper
    # triangle, extracted in one vectorised pass instead of a Python double
    # loop over every vertex pair.
    F_arr = np.asarray(F)
    flow_sym = np.maximum(F_arr, F_arr.T)
    rows, cols = np.nonzero(np.triu(flow_sym > threshold, k=1))
    flow_weights = flow_sym[rows, cols]
    max_flow = F_arr.max()

    # Guard against a non-positive maximum before normalising widths
    if len(flow_weights) > 0 and max_flow > 0:
        flow_segments = np.stack([positions[rows], positions[cols]], axis=1)
        normalized_weights = flow_weights / max_flow
        lc = LineCollection(flow_segments,
                            linewidths=normalized_weights * 6,
                            colors=color, alpha=0.8, zorder=5)
        ax.add_collection(lc)

    # Plot source and sink
    ax.scatter(positions[source_idx, 0], positions[source_idx, 1],
               s=300, c='blue', marker='s', edgecolors='black', linewidths=2, zorder=10)
    ax.scatter(positions[sink_idx, 0], positions[sink_idx, 1],
               s=300, c='green', marker='s', edgecolors='black', linewidths=2, zorder=10)

    ax.set_xlim(x_min, x_max)
    ax.set_ylim(x_max, x_min)
    ax.set_xlabel('X', fontsize=11)
    ax.set_ylabel('Y', fontsize=11)
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.set_aspect('equal')

plt.tight_layout()
plt.show()