# Differentiable Simulations

In the previous notebook (`mlp.ipynb`), we introduced Automatic Differentiation (AD) as the core engine that enables the training of neural networks by computing gradients of a loss function with respect to network parameters. 

However, AD is useful far beyond neural networks. It allows us to make **entire physical simulations differentiable**. This paradigm, often called **Differentiable Simulation** or **Differentiable Physics**, involves implementing a simulator (e.g., a PDE solver) in a framework that supports AD, such as PyTorch or JAX. By doing so, we can automatically compute the gradient of a final quantity (like a measurement or a loss function) with respect to any input or parameter of the simulation, such as an initial condition or a material property.

This notebook demonstrates this powerful concept. We will:
1. Briefly recall how gradients are computed in PyTorch.
2. Introduce the JAX framework for high-performance differentiable programming.
3. Build a differentiable simulator for the 1D acoustic wave equation.
4. Use this simulator to solve a challenging inverse problem: Full Waveform Inversion (FWI).

## A Quick Reminder: Gradients in PyTorch

As we saw previously, frameworks like PyTorch keep track of all operations on tensors. When we call `.backward()` on a final scalar output (like a loss), PyTorch uses reverse-mode AD (backpropagation) to compute the gradient of that output with respect to the inputs that have `requires_grad=True`.

In [5]:
import torch

# Define variables that require gradients
x1 = torch.tensor(2.0, requires_grad=True)
x2 = torch.tensor(3.0, requires_grad=True)

# Define a simple function
y = x1**2 + x2

# Compute gradients using reverse mode AD
y.backward()

# Access the computed gradients
print(f"dy/dx1: {x1.grad.item()}")
print(f"dy/dx2: {x2.grad.item()}")

dy/dx1: 4.0
dy/dx2: 1.0


This same principle applies not just to a single function, but to a whole sequence of operations, such as the time-stepping loop in a physics simulator.
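As a minimal sketch of that idea, the snippet below differentiates through a short Python loop. The update rule, the parameter name `a`, and the step count are made up for illustration; they stand in for a real time-stepping scheme.

```python
import torch

# Hypothetical "simulation": multiply the state by a parameter a for 3 steps,
# so the final state is u = a**3 (the state starts at 1).
a = torch.tensor(2.0, requires_grad=True)
u = torch.tensor(1.0)
for _ in range(3):
    u = u * a  # each step is recorded in the autograd graph

# Backpropagate through all three steps at once
u.backward()

# Analytically, du/da = 3 * a**2, which is 12 at a = 2
print(f"du/da: {a.grad.item()}")
```

The loop is ordinary Python; the framework only sees the sequence of tensor operations it produced.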

## JAX: High-Performance Differentiable Programming

For our differentiable simulation, we will use **JAX**, a library from Google for high-performance numerical computing. JAX combines a NumPy-like API with a powerful set of transformations:

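- **`grad()`**: Automatic differentiation as a function transformation: `grad(f)` returns a new function that evaluates the gradient of `f` (reverse-mode, as with `.backward()` in PyTorch).
- **`jit()`**: Just-in-time (JIT) compilation of numerical Python functions with XLA, so they run efficiently on CPUs, GPUs, and TPUs.
- **`vmap()`**: Automatic vectorization: maps a function written for a single input over a batch of inputs.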

This combination makes JAX exceptionally well-suited for writing fast and differentiable physics simulators.
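These transformations compose. As a small sketch (the function `f` and the sample points `xs` are illustrative choices), `vmap(grad(f))` evaluates the derivative of a scalar function at a whole batch of points:

```python
import jax.numpy as jnp
from jax import grad, vmap

def f(x):
    return jnp.sin(x)

# grad(f) differentiates a scalar-input function; vmap maps it over a batch
df_batched = vmap(grad(f))

xs = jnp.linspace(0.0, 1.0, 5)
print(df_batched(xs))  # derivative of sin at each point
print(jnp.cos(xs))     # analytic derivative, for comparison
```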

In [4]:
import jax.numpy as jnp
from jax import grad, jit, lax  # lax.fori_loop is used by the simulator below
import optax
import matplotlib.pyplot as plt

def f(x):
    return jnp.sin(x)

# grad() transformation
df = grad(f)
print(f"Gradient of sin(x) at x=1.0 is: {df(1.0)}")

# jit() transformation
@jit
def fast_f(x):
    return jnp.sin(x)

print(f"JIT-compiled sin(x) at x=1.0 is: {fast_f(1.0)}")


## Case Study: Differentiable 1D Wave Simulation

We will now build a differentiable simulator for the 1D acoustic wave equation. This equation describes how waves (like sound or seismic waves) propagate through a medium.

$$ \frac{\partial^2 u}{\partial t^2} = c^2 \frac{\partial^2 u}{\partial x^2} $$

Where:
- $ u(x, t) $ is the wave's displacement or pressure at position $x$ and time $t$.
- $ c $ is the wave speed in the medium, which can vary with position, $c(x)$.

### The Forward Problem: Simulation
The forward problem is to simulate the behavior of $u(x,t)$ given an initial state and the wave speed profile $c(x)$. We will solve this using a finite difference method on a uniform grid, writing $u_i^n$ for the value at grid point $i$ and timestep $n$. Approximating both second derivatives with central differences and solving for the unknown $u_i^{n+1}$, we can find the wave's state at the next timestep from its two previous states:

$$u_i^{n+1} = c_i^2 \frac{\Delta t^2}{\Delta x^2} (u_{i+1}^n - 2u_i^n + u_{i-1}^n) + 2u_i^n - u_i^{n-1} $$
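For reference, the rearrangement can be written out as follows, assuming the standard second-order central differences in time and space and reading $u_i^n$ as the value at grid point $i$ and timestep $n$:

$$ \frac{u_i^{n+1} - 2u_i^n + u_i^{n-1}}{\Delta t^2} = c_i^2 \, \frac{u_{i+1}^n - 2u_i^n + u_{i-1}^n}{\Delta x^2} $$

Multiplying both sides by $\Delta t^2$ and moving the known terms $2u_i^n - u_i^{n-1}$ to the right-hand side gives the update rule above.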

We can implement this time-stepping loop in JAX with `lax.fori_loop`. Decorating the function with `@jit` compiles the whole simulation, loop included, for high performance.
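The simulator below relies on `lax.fori_loop`, whose loop body takes the iteration index and a carry and returns the new carry. A minimal sketch of that calling convention, with a toy body (`add_index` and `sum_first_ten` are hypothetical names) that sums the integers 0 to 9:

```python
from jax import jit, lax

def add_index(i, total):
    # body(i, carry) -> new carry; here the carry is a running sum
    return total + i

@jit
def sum_first_ten():
    # fori_loop(lower, upper, body, init) runs body for i = lower, ..., upper - 1
    return lax.fori_loop(0, 10, add_index, 0)

print(sum_first_ten())
```

In the wave simulator the carry is a tuple of arrays rather than a single number, but the pattern is the same.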

In [4]:
# Set up an n-point uniform mesh
n = 1000
dx = 1.0/(n-1)
x0 = jnp.linspace(0.0, 1.0, n)

@jit
def wave_propagation(params):
    """Simulates 1D wave propagation using a finite difference scheme."""
    c = params # c can be a scalar or a vector for a spatially varying profile
    dt = 5e-4
    C = c * dt / dx
    C2 = C**2

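    # Set up initial conditions: u0 is a Gaussian pulse, u1 is the same pulse
    # shifted by c * dt (a right-moving wave one timestep later)
    u0 = jnp.exp(-(5 * (x0 - 0.5))**2)
    u1 = jnp.exp(-(5 * (x0 - 0.5 - c * dt))**2)
    u2 = jnp.zeros(n)  # placeholder so the loop carry has a fixed structure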

    def step(i, carry):
        u0, u1, _ = carry
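        # Shift with jnp.roll (which wraps around), then zero the wrapped entries
        # so that values outside the domain are 0 (Dirichlet-style boundaries)
        u1p = jnp.roll(u1, 1)   # u1p[i] = u1[i-1]
        u1p = u1p.at[0].set(0)
        u1n = jnp.roll(u1, -1)  # u1n[i] = u1[i+1]
        u1n = u1n.at[n - 1].set(0)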
        
        # Central difference update rule
        u2 = 2 * u1 - u0 + C2 * (u1p - 2 * u1 + u1n)
        u0, u1 = u1, u2
        return (u0, u1, u2)

    # Run the simulation loop
    u0, u1, u2 = lax.fori_loop(0, 5000, step, (u0, u1, u2))
    return u2

# --- Run a forward simulation ---
ctarget = 1.0 # Constant velocity profile
target_wave = wave_propagation(ctarget)

plt.figure(figsize=(12, 6))
plt.plot(x0, target_wave, 'b-', label='Final Wave State (u2)')
plt.title(f'Forward Simulation with c = {ctarget}')
plt.xlabel('Position x')
plt.ylabel('Displacement u')
plt.legend()
plt.grid(True, alpha=0.5)
plt.show()


### The Inverse Problem: Full Waveform Inversion (FWI)

Now for the exciting part. The inverse problem asks: **If we measure the final wave `target_wave`, can we determine the velocity profile `c` that produced it?**

This is a classic and difficult problem in geophysics. With a differentiable simulator, we can solve it using gradient descent.

1.  **Define a Loss Function**: We need a way to measure the difference between our simulation's output and the observed data. The L2 norm of the misfit (the square root of the sum of squared differences) is a standard choice.
    $$L(c) = || \text{wave_propagation}(c) - \text{target_wave} ||_2$$
2.  **Compute the Gradient**: Because our `wave_propagation` function is written in JAX, we do not need to derive the gradient of the loss with respect to the parameters `c` by hand; AD provides it directly: `grad(L)(c)`.
3.  **Optimize**: We start with an initial guess for `c` and iteratively update it by moving in the direction of the negative gradient, using an optimizer like Adam.
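
The three steps can be sketched on a much smaller problem before applying them to the wave equation. The toy simulator `decay_sim`, its decay rate `k`, the step size `DT`, and the learning rate are all hypothetical and chosen only so the example runs quickly. Plain gradient descent stands in for Adam to keep the sketch free of extra dependencies.

```python
import jax.numpy as jnp
from jax import grad, jit, lax

DT = 0.01  # time step of the toy simulator (illustrative value)

def decay_sim(k):
    """Toy forward model: 100 explicit Euler steps of du/dt = -k * u, u(0) = 1."""
    def step(i, u):
        return u * (1.0 - k * DT)
    return lax.fori_loop(0, 100, step, jnp.float32(1.0))

# "Observed" data generated with the true parameter
k_true = 2.0
target = decay_sim(k_true)

# Step 1: loss = squared misfit between simulation and observation
def loss(k):
    return (decay_sim(k) - target) ** 2

# Step 2: gradient of the loss with respect to the parameter
loss_grad = jit(grad(loss))

# Step 3: plain gradient descent from a wrong initial guess
k = 1.0
for _ in range(500):
    k = k - 5.0 * loss_grad(k)

print(f"Recovered k: {float(k):.4f} (true value: {k_true})")
```

The loop bounds are Python integers, so the trip count is static and `lax.fori_loop` can be differentiated in reverse mode.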

Let's try to recover the constant velocity profile `ctarget = 1.0` starting from a wrong guess.

In [None]:
# Define the loss function
@jit
def compute_loss(c):
    u2_simulated = wave_propagation(c)
    return jnp.linalg.norm(u2_simulated - target_wave)

# Define the gradient of the loss function
loss_grad_fn = jit(grad(compute_loss))

# --- Setup the optimization ---
learning_rate = 1e-3
optimizer = optax.adam(learning_rate)

# Initial guess for the velocity (wrong value)
params_c = 0.85 
opt_state = optimizer.init(params_c)

print(f"Target velocity c: {ctarget}")
print(f"Initial guess for c: {params_c}")

# --- Run the optimization loop ---
for i in range(1001):
    grads = loss_grad_fn(params_c)
    updates, opt_state = optimizer.update(grads, opt_state)
    params_c = optax.apply_updates(params_c, updates)
    if i % 200 == 0:
        loss = compute_loss(params_c)
        print(f"Iteration {i}, Loss: {loss:.6f}, Current c: {params_c:.6f}")

print(f"\nFinal recovered velocity c: {params_c:.6f}")

# --- Visualize the results ---
recovered_wave = wave_propagation(params_c)

plt.figure(figsize=(12, 6))
plt.plot(x0, target_wave, 'k-', label='Target Wave', linewidth=2)
plt.plot(x0, recovered_wave, 'r--', label='Recovered Wave', linewidth=2)
plt.title('FWI Result: Target vs. Recovered Wave')
plt.xlabel('Position x')
plt.ylabel('Displacement u')
plt.legend()
plt.grid(True, alpha=0.5)
plt.show()

### Advanced Case: Recovering a Spatially-Varying Profile

The real value of this method becomes apparent when we try to recover a more complex, spatially-varying velocity profile. Let's define a target velocity `c(x)` that changes linearly across the domain.

**No change to the structure of the loss or the gradient descent logic is needed**; only the target data and the shape of the initial guess differ. The `wave_propagation` function and the AD framework handle the fact that `c` is now a vector of parameters instead of a single scalar.
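
One way to see why the same logic carries over: `grad` returns a gradient with the same shape as its input, so the optimizer update has the same form for a scalar and for a vector. A small sketch with a stand-in loss (`toy_loss` is hypothetical, not the wave simulator):

```python
import jax.numpy as jnp
from jax import grad

def toy_loss(c):
    # Works for scalar or vector c: broadcast against a fixed grid, reduce to a scalar
    grid = jnp.linspace(0.0, 1.0, 4)
    return jnp.sum((c * grid - 1.0) ** 2)

g_scalar = grad(toy_loss)(0.85)                # scalar parameter
g_vector = grad(toy_loss)(jnp.ones(4) * 0.85)  # one parameter per grid point

print(g_scalar.shape)  # ()
print(g_vector.shape)  # (4,)
```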

In [None]:
# New target: a linear velocity profile
ctarget_linear = jnp.linspace(0.9, 1.1, n)
target_wave_linear = wave_propagation(ctarget_linear)

# The loss function remains the same, but now computes the loss against the new target
@jit
def compute_loss_linear(c_vector):
    u2_simulated = wave_propagation(c_vector)
    return jnp.linalg.norm(u2_simulated - target_wave_linear)

loss_grad_fn_linear = jit(grad(compute_loss_linear))

# --- Setup and run optimization for the vector case ---
learning_rate = 1e-3
optimizer = optax.adam(learning_rate)

# Initial guess: a constant, incorrect velocity profile
params_c_vector = jnp.ones(n) * 0.85
opt_state = optimizer.init(params_c_vector)

print("Starting optimization to find the linear velocity profile...")
for i in range(2001):
    grads = loss_grad_fn_linear(params_c_vector)
    updates, opt_state = optimizer.update(grads, opt_state)
    params_c_vector = optax.apply_updates(params_c_vector, updates)
    if i % 500 == 0:
        loss = compute_loss_linear(params_c_vector)
        print(f"Iteration {i}, Loss: {loss:.4f}")

print("\nOptimization finished!")

# --- Visualize the results for the linear profile ---
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Plot the recovered velocity profile
ax1.plot(x0, ctarget_linear, 'k-', label='Target Profile', linewidth=2)
ax1.plot(x0, params_c_vector, 'r--', label='Recovered Profile', linewidth=2)
ax1.set_title('Velocity Profile Recovery (c(x))')
ax1.set_xlabel('Position x')
ax1.set_ylabel('Wave Speed c')
ax1.legend()
ax1.grid(True, alpha=0.5)

# Plot the resulting waves
recovered_wave_linear = wave_propagation(params_c_vector)
ax2.plot(x0, target_wave_linear, 'k-', label='Target Wave', linewidth=2)
ax2.plot(x0, recovered_wave_linear, 'r--', label='Recovered Wave', linewidth=2)
ax2.set_title('Final Wave Comparison')
ax2.set_xlabel('Position x')
ax2.set_ylabel('Displacement u')
ax2.legend()
ax2.grid(True, alpha=0.5)

plt.tight_layout()
plt.show()

## Conclusion

This notebook demonstrated the core idea of differentiable simulation.

By implementing a physics simulator within a framework that supports automatic differentiation (like JAX), we can efficiently solve complex, gradient-based inverse problems. We simply define a loss function that measures the difference between our simulation's output and some observed data, and then use an optimizer to minimize this loss by adjusting the physical parameters of the simulation.

This powerful technique is a cornerstone of modern Scientific Machine Learning (SciML), enabling the fusion of traditional scientific models with data-driven methods.