# Differentiable Fluid Dynamics on TPUs

This workflow serves as a "conversion kit" for [Computational Fluid Dynamics](https://en.wikipedia.org/wiki/Computational_fluid_dynamics) researchers. It demonstrates that you can keep your Python/NumPy syntax while running much faster fluid dynamics simulations on [Tensor Processing Units](https://en.wikipedia.org/wiki/Tensor_Processing_Unit). This is a complete, cell-by-cell workflow.

## The Setup: Physics & Parameters

We start by simulating [Von Kármán Vortex Shedding](https://en.wikipedia.org/wiki/K%C3%A1rm%C3%A1n_vortex_street) (flow past a cylinder), first using the typical NumPy workflow and then the [JAX](https://docs.jax.dev/) equivalent, which runs on any attached TPU. The grid resolution is coarse enough for you to test on a CPU and can be increased to compare performance once you have access to a TPU. Note that very fine grids will likely exhaust a CPU machine's memory or take an extremely long time to run.

The first part of the notebook is a side-by-side performance comparison of JAX-on-TPU versus NumPy-on-CPU for the same model.

We then look at an advantage unique to JAX. By leveraging JAX's built-in automatic differentiation, we can "reverse" the simulation and compute how changing the obstacle's cross-section would affect the turbulence in its wake (what CFD modelers are most interested in). This is the basic ingredient for automating your iterations through different cross-sections.

## Physics Parameters

- Grid: 400 $\times$ 100 (High enough to see eddies, small enough to run quickly for the demo)
- Reynolds Number: ~80 (The "Unsteady" regime where vortices spontaneously shed)
- Lattice: D2Q9 (Standard 9-velocity model).
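What makes D2Q9 a usable lattice is that its weights reproduce the moments of a Maxwellian up to second order: the weights sum to 1, the first moments vanish, and the second moments equal $c_s^2 = 1/3$ in each direction with no cross term. A minimal NumPy check, using the same `w`, `c_x`, `c_y` arrays and direction ordering (Center, E, N, W, S, NE, NW, SW, SE) as the constants cell; the `moments` dictionary and `inverse_ok` flag are illustrative names of ours:

```python
import numpy as np

# D2Q9 weights and lattice velocities, ordered Center, E, N, W, S, NE, NW, SW, SE
w = np.array([4/9, 1/9, 1/9, 1/9, 1/9, 1/36, 1/36, 1/36, 1/36])
c_x = np.array([0, 1, 0, -1, 0, 1, -1, -1, 1])
c_y = np.array([0, 0, 1, 0, -1, 1, 1, -1, -1])

cs2 = 1.0 / 3.0  # lattice speed of sound squared

moments = {
    "zeroth": w.sum(),                    # expected 1
    "first_x": (w * c_x).sum(),           # expected 0
    "first_y": (w * c_y).sum(),           # expected 0
    "second_xx": (w * c_x * c_x).sum(),   # expected cs2
    "second_yy": (w * c_y * c_y).sum(),   # expected cs2
    "second_xy": (w * c_x * c_y).sum(),   # expected 0
}
for name, value in moments.items():
    print(f"{name:>10}: {value:.6f}")

# The bounce-back mapping used later must send every direction to its exact opposite
inverse_idxs = np.array([0, 3, 4, 1, 2, 7, 8, 5, 6])
inverse_ok = np.array_equal(c_x[inverse_idxs], -c_x) and np.array_equal(c_y[inverse_idxs], -c_y)
print("inverse mapping consistent:", inverse_ok)
```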

In [None]:
# Install JAX for TPU

# Check whether JAX is already importable; if not, install the TPU build.
# Note: an existing CPU-only JAX installation is left as-is.
try:
    import jax
    print("JAX is already installed.")
except ImportError:
    # This is the install command for Cloud TPU VMs
    print("Installing JAX for TPU...")
    !pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    import jax


In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import time
import matplotlib.pyplot as plt

# Check which device JAX will use (a TPU device if one is attached, otherwise CPU)
print(f"Target Device: {jax.devices()[0]}")

## The Physics Constants (Shared)

We define the physics once so that both implementations use exactly the same parameters.

**Note: The grid is deliberately coarse so you can begin your experiment on a CPU. Experiment with much larger values of NX and NY once you have a TPU.**
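Because the constants cell derives the viscosity and relaxation time from the grid ($\nu = U_0 D / Re$ and $\tau = 3\nu + 0.5$), refining the grid at fixed $Re$ and $U_0$ also changes $\tau$ and the memory footprint. The sketch below previews this before you scale up. The `lbm_setup` helper, its `aspect` parameter (NX = 4 × NY, matching 400 × 100) and `bytes_per_value` are our own assumptions: 4 bytes matches JAX's default float32, while NumPy defaults to float64 (8 bytes).

```python
import numpy as np

def lbm_setup(ny, u0=0.05, re=80.0, aspect=4, bytes_per_value=4):
    """Hypothetical helper: derive LBM parameters with the same recipe as the constants cell."""
    nx = aspect * ny
    radius = ny // 9
    nu = u0 * (2 * radius) / re        # kinematic viscosity from Re = U * D / nu
    tau = 3.0 * nu + 0.5               # BGK relaxation time from nu = (tau - 0.5) / 3
    mach = u0 * np.sqrt(3.0)           # Ma = U / c_s with c_s = 1 / sqrt(3)
    state_mb = 9 * ny * nx * bytes_per_value / 1e6   # one (9, NY, NX) population array
    return {"nx": nx, "radius": radius, "nu": nu, "tau": tau, "mach": mach, "state_mb": state_mb}

for ny in (100, 400, 1600):
    p = lbm_setup(ny)
    print(f"NY={ny:5d} NX={p['nx']:5d} R={p['radius']:4d} "
          f"tau={p['tau']:.4f} Ma={p['mach']:.3f} state={p['state_mb']:.1f} MB")
```

Note that $\tau$ grows with resolution here (the cylinder diameter grows while $U_0$ and $Re$ stay fixed), which generally moves the BGK scheme away from the less stable $\tau \to 0.5$ limit.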

In [None]:
# --- Domain Constants ---
NX = 400   # Width
NY = 100   # Height
R = NY // 9  # Cylinder Radius
CX, CY = NX // 4, NY // 2  # Cylinder coordinates

# --- Physics Constants ---
# We tune these for a Reynolds Number ~80 to get vortex shedding
U0 = 0.05             # Inflow velocity (Mach < 0.1 for stability)
rho0 = 1.0            # Rest density
Re = 80.0             # Reynolds number
# Viscosity derived from Re: nu = U * D / Re
nu = U0 * (2 * R) / Re
# Relaxation time (tau) derived from viscosity: nu = (tau - 0.5)/3
tau = 3.0 * nu + 0.5 

print(f"Simulating Reynolds Number: {Re}")
print(f"Computed Relaxation Time (tau): {tau:.4f}")

# --- D2Q9 Lattice Constants ---
# The 9 directions: Center, E, N, W, S, NE, NW, SW, SE
w = np.array([4/9, 1/9, 1/9, 1/9, 1/9, 1/36, 1/36, 1/36, 1/36])
c_x = np.array([0, 1, 0, -1, 0, 1, -1, -1, 1])
c_y = np.array([0, 0, 1, 0, -1, 1, 1, -1, -1])

# Create the Cylinder Mask (Boolean array)
Y, X = np.meshgrid(np.arange(NY), np.arange(NX), indexing='ij')
cylinder_mask = (X - CX)**2 + (Y - CY)**2 < R**2

## The "Legacy" Approach (NumPy/CPU)

The conventional approach to prototyping relies on standard `numpy` operations. Note the explicit Python loops over the nine lattice directions, including the one for streaming.

In [None]:
def init_numpy():
    return np.ones((9, NY, NX)) * w[:, None, None]

def step_numpy(f):
    # 1. Macroscopic variables
    rho = np.sum(f, axis=0)
    # Element-wise products and reductions over the 9 directions
    ux = np.sum(f * c_x[:, None, None], axis=0) / rho
    uy = np.sum(f * c_y[:, None, None], axis=0) / rho
    
    # 2. Collision (BGK approximation)
    # Calculate Equilibrium
    u_sq = ux**2 + uy**2
    for i in range(9):
        cu = c_x[i]*ux + c_y[i]*uy
        f_eq = rho * w[i] * (1 + 3*cu + 4.5*cu**2 - 1.5*u_sq)
        f[i] = f[i] - (f[i] - f_eq) / tau

    # 3. Streaming (The Bottleneck)
    # Each np.roll allocates and copies a full (NY, NX) array
    for i in range(9):
        f[i] = np.roll(f[i], shift=(c_x[i], c_y[i]), axis=(1, 0))

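    # 4. Boundary Conditions
    # Rigid Cylinder (Bounce-back)
    # Invert directions: e.g., North(2) becomes South(4)
    inverse_idxs = [0, 3, 4, 1, 2, 7, 8, 5, 6]
    # Snapshot the reflected populations first: swapping in place would
    # overwrite f[1] before f[3] reads it (and likewise for every opposite pair)
    f_bounced = f[inverse_idxs]  # fancy indexing returns a copy
    # Wherever the mask is True, reflect the particle
    f[:, cylinder_mask] = f_bounced[:, cylinder_mask]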
        
    # Inflow (Left side fixed velocity)
    # (Simplified for brevity: force equilibrium at x=0)
    col0_rho = np.sum(f[:, :, 0], axis=0)
    col0_u_sq = U0**2
    for i in range(9):
        cu = c_x[i]*U0
        f_eq_0 = col0_rho * w[i] * (1 + 3*cu + 4.5*cu**2 - 1.5*col0_u_sq)
        f[i, :, 0] = f_eq_0

    return f
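The collision step relaxes each population toward the second-order equilibrium $f_i^{eq} = \rho w_i \left(1 + 3\,\mathbf{c}_i\cdot\mathbf{u} + 4.5\,(\mathbf{c}_i\cdot\mathbf{u})^2 - 1.5\,|\mathbf{u}|^2\right)$. For D2Q9 this polynomial reproduces the local density and momentum exactly, which is why the BGK collision conserves mass and momentum. A quick NumPy check on random fields; the vectorized `equilibrium` function is our own restatement of the loop in `step_numpy`:

```python
import numpy as np

w = np.array([4/9, 1/9, 1/9, 1/9, 1/9, 1/36, 1/36, 1/36, 1/36])
c_x = np.array([0, 1, 0, -1, 0, 1, -1, -1, 1])
c_y = np.array([0, 0, 1, 0, -1, 1, 1, -1, -1])

def equilibrium(rho, ux, uy):
    """Second-order D2Q9 equilibrium, same polynomial as step_numpy, vectorized over i."""
    cu = c_x[:, None, None] * ux + c_y[:, None, None] * uy   # shape (9, NY, NX)
    u_sq = ux**2 + uy**2
    return rho * w[:, None, None] * (1 + 3*cu + 4.5*cu**2 - 1.5*u_sq)

rng = np.random.default_rng(0)
rho = 1.0 + 0.01 * rng.standard_normal((4, 6))
ux = 0.05 * rng.standard_normal((4, 6))
uy = 0.05 * rng.standard_normal((4, 6))

# Recover the macroscopic fields from the equilibrium populations
f_eq = equilibrium(rho, ux, uy)
rho_back = f_eq.sum(axis=0)
ux_back = (f_eq * c_x[:, None, None]).sum(axis=0) / rho_back
uy_back = (f_eq * c_y[:, None, None]).sum(axis=0) / rho_back
print("max |rho error|:", np.abs(rho_back - rho).max())
print("max |u error|  :", max(np.abs(ux_back - ux).max(), np.abs(uy_back - uy).max()))
```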

## The Alternative Approach (JAX / TPU)

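The logic is identical, but we decorate the step with `jax.jit`. This traces the whole physics step and compiles it into a single XLA program, so the compiler can fuse the element-wise operations and avoid returning to Python between them. On a TPU, the arrays (including the rolled copies produced by streaming) stay in the chip's High Bandwidth Memory.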
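To see what `jax.jit` does with such a step, here is a sketch on the BGK relaxation line alone (the `relax` helper and the toy shapes are ours, not part of the notebook). The jitted function returns the same numbers as the eager one, and `.lower(...).as_text()` prints the StableHLO program that JAX hands to XLA for compilation:

```python
import jax
import jax.numpy as jnp

def relax(f, f_eq, tau):
    # BGK collision: move each population a fraction 1/tau toward equilibrium
    return f - (f - f_eq) / tau

relax_jit = jax.jit(relax)

f = jnp.ones((9, 4, 8), dtype=jnp.float32)
f_eq = jnp.full((9, 4, 8), 0.5, dtype=jnp.float32)

out_eager = relax(f, f_eq, 0.6)
out_jit = relax_jit(f, f_eq, 0.6)   # first call traces and compiles; later calls reuse the binary

# The lowered program before XLA optimization; XLA typically fuses these element-wise ops
lowered_text = relax_jit.lower(f, f_eq, 0.6).as_text()
print(lowered_text[:400])
```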

In [None]:


# Move constants to JAX Device (Immutable)
J_w = jnp.array(w)
J_cx = jnp.array(c_x)
J_cy = jnp.array(c_y)
J_mask = jnp.array(cylinder_mask)
J_inv = jnp.array([0, 3, 4, 1, 2, 7, 8, 5, 6]) # Inversion mapping

def init_jax():
    return jnp.ones((9, NY, NX)) * J_w[:, None, None]

@jax.jit  # <--- The Magic: traces the function once and compiles it with XLA for the attached device
def step_jax(f):
    # 1. Macroscopic variables
    rho = jnp.sum(f, axis=0)
    ux = jnp.sum(f * J_cx[:, None, None], axis=0) / rho
    uy = jnp.sum(f * J_cy[:, None, None], axis=0) / rho

    # 2. Collision (Vectorized)
    u_sq = ux**2 + uy**2
    # Fully vectorized equilibrium for all 9 directions at once:
    # expand the (9,) constants to (9, 1, 1) so they broadcast against (NY, NX)
    cu = J_cx[:, None, None] * ux + J_cy[:, None, None] * uy
    f_eq = rho * J_w[:, None, None] * (1 + 3*cu + 4.5*cu**2 - 1.5*u_sq)
    
    f_out = f - (f - f_eq) / tau

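    # 3. Streaming
    # Use the NumPy direction arrays so the shifts are plain Python integers at
    # trace time; indexing J_cx inside jit would make them traced (dynamic) values
    for i in range(9):
        f_out = f_out.at[i].set(jnp.roll(f_out[i], shift=(int(c_x[i]), int(c_y[i])), axis=(1, 0)))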

    # 4. Boundary Conditions (Bounce-back)
    # Where mask is True, replace each post-streaming population with the one
    # travelling in the opposite direction (simple LBM bounce-back)
    bounced = f_out[J_inv]
    f_out = jnp.where(J_mask, bounced, f_out)
    
    # Inflow Condition
    # (Re-enforce Equilibrium at Left Wall)
    rho_0 = jnp.sum(f_out[:, :, 0], axis=0)
    cu_0 = J_cx[:, None] * U0
    f_eq_0 = rho_0 * J_w[:, None] * (1 + 3*cu_0 + 4.5*cu_0**2 - 1.5*U0**2)
    f_out = f_out.at[:, :, 0].set(f_eq_0)

    return f_out

## The Benchmark
We run `ITERATIONS` (500) steps of both. Note: the first JAX call includes compilation time, so we run it once as a warmup and exclude it from the timing.
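Because the first call to a jitted function includes tracing and compilation, and because JAX dispatches work asynchronously, a fair benchmark times the warmup separately and calls `block_until_ready()` before stopping the clock. A generic helper sketched with `time.perf_counter`; the `time_jitted` name and the `toy_step` update are illustrative stand-ins for `step_jax`:

```python
import time
import jax
import jax.numpy as jnp

def time_jitted(step_fn, state, n_steps):
    """Return (first_call_seconds, steady_state_seconds_per_step)."""
    t0 = time.perf_counter()
    state = step_fn(state).block_until_ready()   # trace + compile + one run
    first_call = time.perf_counter() - t0

    t0 = time.perf_counter()
    for _ in range(n_steps):
        state = step_fn(state)                   # returns immediately (async dispatch)
    state.block_until_ready()                    # wait for the device before reading the clock
    per_step = (time.perf_counter() - t0) / n_steps
    return first_call, per_step

@jax.jit
def toy_step(f):
    # Stand-in for step_jax: a cheap element-wise update plus a shift
    return 0.5 * (f + jnp.roll(f, 1, axis=-1))

first, per_step = time_jitted(toy_step, jnp.ones((9, 100, 400)), n_steps=50)
print(f"first call: {first*1e3:.1f} ms, steady state: {per_step*1e6:.1f} us/step")
```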

In [None]:
# @title 5. The Race: CPU vs TPU

ITERATIONS = 500

print(f"Running {ITERATIONS} steps of Fluid Dynamics...")

# --- 1. Run NumPy ---
f_state = init_numpy()
start_cpu = time.time()
for _ in range(ITERATIONS):
    f_state = step_numpy(f_state)
end_cpu = time.time()
cpu_time = end_cpu - start_cpu
print(f"NumPy (CPU): {cpu_time:.2f} seconds | {(ITERATIONS/cpu_time):.2f} steps/s")

# --- 2. Run JAX ---
f_jax = init_jax()
# WARMUP (Compiling)
print("Compiling JAX Kernel (Warmup)...")
_ = step_jax(f_jax).block_until_ready()

# THE REAL RUN
start_tpu = time.time()
# jax.lax.fori_loop would keep the whole loop on the device (see the animation cell),
# but for a like-for-like comparison with NumPy we loop in Python
for _ in range(ITERATIONS):
    f_jax = step_jax(f_jax)
# Force synchronization to measure actual time
f_jax.block_until_ready()
end_tpu = time.time()
tpu_time = end_tpu - start_tpu

print(f"JAX (TPU)  : {tpu_time:.2f} seconds | {(ITERATIONS/tpu_time):.2f} steps/s")
print(f"Speedup    : {cpu_time/tpu_time:.1f}x FASTER")

Even when JAX runs on the CPU backend, the speedup from XLA compilation is often noticeable.
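The benchmark loops in Python, so every step is a separate dispatch from the host. `jax.lax.fori_loop`, used later in the animation cell, moves the whole loop into one compiled program. A sketch comparing the two on a toy update, assuming a step function with the same `f -> f` signature as `step_jax` (`toy_step`, `N_STEPS` and the array shape are illustrative):

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def toy_step(f):
    # Stand-in for step_jax: relax toward the average of the x-neighbours
    return 0.9 * f + 0.05 * (jnp.roll(f, 1, axis=-1) + jnp.roll(f, -1, axis=-1))

N_STEPS = 100

def run_python_loop(f):
    for _ in range(N_STEPS):
        f = toy_step(f)            # one host-side dispatch per step
    return f

@jax.jit
def run_on_device(f):
    # The loop body is compiled once and the whole loop executes on the device
    return jax.lax.fori_loop(0, N_STEPS, lambda i, f_val: toy_step(f_val), f)

f0 = jnp.asarray(np.random.default_rng(0).random((9, 16, 32)), dtype=jnp.float32)
f_python = run_python_loop(f0)
f_device = run_on_device(f0)
print("max difference:", float(jnp.abs(f_python - f_device).max()))
```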

## Visualization
Speed means nothing if the math is wrong. We compute the vorticity (the curl of the velocity field) to visualize the eddies.

In [None]:
# @title 6. Visualize the Vorticity (The "Pretty Picture")

def get_curl(f_final):
    # Convert JAX array to Numpy for plotting
    f = np.array(f_final)
    rho = np.sum(f, axis=0)
    ux = np.sum(f * c_x[:, None, None], axis=0) / rho
    uy = np.sum(f * c_y[:, None, None], axis=0) / rho
    
    # Compute Curl (du_y/dx - du_x/dy) using gradients
    dy_ux, dx_ux = np.gradient(ux)
    dy_uy, dx_uy = np.gradient(uy)
    curl = dx_uy - dy_ux
    return curl

vorticity = get_curl(f_jax)

plt.figure(figsize=(20, 6), dpi=200)
plt.imshow(vorticity, cmap='RdBu', vmin=-0.02, vmax=0.02, origin='lower')
# Overlay Cylinder
circle = plt.Circle((CX, CY), R, color='black', fill=True)
plt.gca().add_patch(circle)

plt.title(f"Von Kármán Vortex Street (Re={Re})")
plt.axis('off')
plt.show()
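A quick way to validate the sign and axis conventions in `get_curl` is a field with known vorticity. For solid-body rotation, $\mathbf{u} = (-\Omega y,\ \Omega x)$, the vorticity $\partial u_y/\partial x - \partial u_x/\partial y$ equals $2\Omega$ everywhere, and `np.gradient` is exact for linear fields. The `omega` value and grid size below are arbitrary choices:

```python
import numpy as np

def curl_from_velocity(ux, uy):
    # Same convention as get_curl: arrays are indexed [y, x], so np.gradient
    # returns (d/dy, d/dx) for each velocity component
    dy_ux, _ = np.gradient(ux)
    _, dx_uy = np.gradient(uy)
    return dx_uy - dy_ux

omega = 0.01
yy, xx = np.meshgrid(np.arange(20.0), np.arange(30.0), indexing='ij')   # shape (NY, NX)
curl = curl_from_velocity(-omega * yy, omega * xx)
print(curl.min(), curl.max())   # both should be 2 * omega up to rounding
```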

## The Code Difference

Note that `step_jax` is nearly line-for-line identical to `step_numpy`. You don't need to learn a new language.

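`jnp.roll` is the streaming step. On a CPU, `np.roll` allocates and copies a full array for every direction. Under `jax.jit`, XLA can fuse these shifts with the surrounding arithmetic, and on a TPU they run out of the chip's High Bandwidth Memory. You may see a significant speedup even without a TPU.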
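Streaming moves each population one lattice cell along its own velocity $(c_x, c_y)$, wrapping periodically at the domain edges. A sketch of the same operation written as a single `jnp.stack` of rolls with static (Python integer) shifts, plus a check that a lone "particle" lands where expected. The `stream` function and demo sizes are ours; `step_jax` uses a loop of `.at[i].set(...)` updates instead:

```python
import jax
import jax.numpy as jnp
import numpy as np

C_X = (0, 1, 0, -1, 0, 1, -1, -1, 1)   # D2Q9 ordering: Center, E, N, W, S, NE, NW, SW, SE
C_Y = (0, 0, 1, 0, -1, 1, 1, -1, -1)

@jax.jit
def stream(f):
    # Static Python-int shifts: roll axis 1 (x) by c_x and axis 0 (y) by c_y
    return jnp.stack([
        jnp.roll(f[i], shift=(C_X[i], C_Y[i]), axis=(1, 0)) for i in range(9)
    ])

ny_demo, nx_demo = 6, 8
y0, x0 = 2, 3
f = jnp.zeros((9, ny_demo, nx_demo)).at[:, y0, x0].set(1.0)   # one particle per direction
f_streamed = stream(f)

# (y, x) position of each direction's particle after one step
landing = [tuple(int(v) for v in np.argwhere(np.asarray(f_streamed[i]))[0]) for i in range(9)]
print(landing)
```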

The output: once the wake has developed, the visualization shows alternating negative (red) and positive (blue) vortices trailing the cylinder; the `RdBu` colormap maps low values to red and high values to blue. Run the cell below to view the animation.

In [None]:
# Generate Animation
from matplotlib import animation
from IPython.display import HTML

# 1. COMPILE A "BATCH" STEP
# To make rendering fast, we run 40 physics steps per video frame.
# We JIT compile this loop so the TPU screams through the math.
STEPS_PER_FRAME = 40

@jax.jit
def evolve_batch(f):
    def body_fun(i, f_val):
        return step_jax(f_val)
    # This runs the loop inside the TPU, not in Python
    return jax.lax.fori_loop(0, STEPS_PER_FRAME, body_fun, f)

# 2. SETUP THE PLOT
fig, ax = plt.subplots(figsize=(20, 6), dpi=200)

# Initial Frame
f_anim = init_jax()
# Run a quick warmup to get past the initial "still" water
for _ in range(5):
    f_anim = evolve_batch(f_anim)

# Compute initial curl for the color scale
initial_curl = get_curl(f_anim)
im = ax.imshow(initial_curl, cmap='RdBu', vmin=-0.02, vmax=0.02, 
               origin='lower', interpolation='spline36')

# Add the cylinder visual
circle = plt.Circle((CX, CY), R, color='black', fill=True)
ax.add_patch(circle)
ax.axis('off')
title = ax.set_title("Time Step: 0")

# 3. ANIMATION LOOP
def animate(frame_num):
    global f_anim
    # Run the physics on TPU
    f_anim = evolve_batch(f_anim)
    f_anim.block_until_ready() # Wait for TPU to finish
    
    # Bring result to CPU for plotting
    curl = get_curl(f_anim)
    
    # Update the image
    im.set_data(curl)
    title.set_text(f"Time Step: {frame_num * STEPS_PER_FRAME}")
    return [im, title]

# Create animation (100 frames * 40 steps = 4000 simulation steps)
print("Rendering video... (This might take 30-60 seconds)")
anim = animation.FuncAnimation(fig, animate, frames=100, interval=50, blit=True)

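# Display as an interactive JavaScript animation.
# At dpi=200 the embedded frames can exceed Matplotlib's default 20 MB embed limit;
# raise the limit (in MB), or lower the figure dpi, if the animation is truncated.
plt.rcParams['animation.embed_limit'] = 200
HTML(anim.to_jshtml())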

## Download mp4

In [None]:
import os
import matplotlib.animation as animation
from IPython.display import FileLink, display

# 1. ROBUST INSTALLATION
# Try Conda first, then fall back to apt-get
if os.system("which ffmpeg") != 0:
    print("⚠️ FFmpeg not found. Attempting install via Conda...")
    # This works better in Vertex AI / JupyterLab environments
    result = os.system("conda install -c conda-forge ffmpeg -y")
    
    if result != 0:
        print("Conda install failed. Trying system apt-get...")
        !sudo apt-get update && sudo apt-get install ffmpeg -y > /dev/null

# Verify installation
if os.system("which ffmpeg") != 0:
    raise RuntimeError("❌ FFmpeg could not be installed. Please install it manually to save MP4s.")
else:
    print("✅ FFmpeg is installed and ready.")

# 2. SAVE THE VIDEO
output_filename = "tpu_fluid_simulation.mp4"
print(f"Rendering {output_filename}...")

# Create the writer explicitly so Matplotlib fails loudly if FFmpeg is missing,
# instead of falling back to Pillow, which cannot write MP4 files.
FFwriter = animation.FFMpegWriter(
    fps=30, 
    extra_args=['-vcodec', 'libx264']
)

try:
    anim.save(output_filename, writer=FFwriter)
    print(f"✅ Video saved successfully: {output_filename}")
except FileNotFoundError:
    print("❌ Error: Matplotlib still cannot find the 'ffmpeg' binary.")
    print("Try restarting the kernel to refresh system paths.")

# 3. TRIGGER DOWNLOAD
try:
    from google.colab import files
    files.download(output_filename)
except ImportError:
    print("Click below to download:")
    display(FileLink(output_filename))

# The "Fun" Stuff: Topology Optimization

In traditional CFD, optimizing a shape (e.g., "what shape of wing minimizes drag?") is challenging. You usually have to write a completely separate "Adjoint Solver" or just guess-and-check 1,000 times.

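In JAX, because the simulation is differentiable, we get the gradient of any scalar result (such as total turbulence) with respect to the geometry automatically, with no hand-written adjoint solver.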
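The mechanism is reverse-mode automatic differentiation through the time loop: `jax.grad` of a function that runs `jax.lax.scan` computes the same derivative an adjoint solver would. A toy sketch, where the one-variable `rollout_loss` model is an illustrative stand-in for the fluid solver, checked against a central finite difference:

```python
import jax
import jax.numpy as jnp

def rollout_loss(theta, n_steps=50):
    """Toy time-stepper: x relaxes toward theta with a quadratic damping term."""
    theta = jnp.asarray(theta, dtype=jnp.float32)

    def body(x, _):
        x_next = x + 0.1 * (theta - x) - 0.01 * x**2
        return x_next, None

    x_final, _ = jax.lax.scan(body, jnp.float32(0.0), None, length=n_steps)
    return x_final**2   # scalar objective evaluated at the final time

grad_ad = jax.grad(rollout_loss)(1.0)   # reverse mode: one backward pass through all steps

eps = 1e-2
grad_fd = (rollout_loss(1.0 + eps) - rollout_loss(1.0 - eps)) / (2 * eps)
print(f"autodiff: {float(grad_ad):.5f}  finite difference: {float(grad_fd):.5f}")
```

The finite-difference check needs two extra simulations per parameter, whereas the autodiff gradient costs roughly one forward plus one backward pass regardless of how many parameters (pixels) there are.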

We are going to ask the TPU:

"Here is the flow. For every pixel on the screen, would turning it into more 'solid wall' increase or decrease the turbulence, and by how much?"


## The Setup: "Soft Solver"

To make the physics differentiable, we have to make the solid cylinder "soft." Instead of a hard True/False mask, we treat the obstacle like a porous fog (values between 0.0 and 1.0). This allows JAX to calculate gradients through the object.
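Concretely, the soft interaction in `step_differentiable` blends the two candidate post-streaming states, `f_out = m * f_bounced + (1 - m) * f_streamed`. Because the blend is linear in the mask value m, its derivative with respect to m is simply `f_bounced - f_streamed`, which is finite everywhere, whereas a hard boolean mask offers no gradient. A small check with random populations; `blend`, `INV` and the shapes are illustrative:

```python
import jax
import jax.numpy as jnp
import numpy as np

INV = jnp.array([0, 3, 4, 1, 2, 7, 8, 5, 6])   # opposite direction for each D2Q9 index

def blend(mask, f_streamed):
    # mask: (NY, NX) with values in [0, 1]; f_streamed: (9, NY, NX)
    f_bounced = f_streamed[INV]
    return mask * f_bounced + (1.0 - mask) * f_streamed

key = jax.random.PRNGKey(0)
f_streamed = jax.random.uniform(key, (9, 5, 7))
mask = jnp.full((5, 7), 0.3)

# Gradient of a scalar readout: total population in the "east" channel (index 1)
readout = lambda m: blend(m, f_streamed)[1].sum()
g = jax.grad(readout)(mask)
expected = f_streamed[INV][1] - f_streamed[1]   # d/dm of m*b + (1-m)*s is b - s
print("max deviation:", np.abs(np.asarray(g - expected)).max())
```

At m = 1 the blend reduces to full bounce-back and at m = 0 to free streaming, so the hard-mask solver is recovered at the end points.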


In [None]:
# 1. DIFFERENTIABLE SOLVER
@jax.jit
def step_differentiable(f, continuous_mask):
    # Standard Macros
    rho = jnp.sum(f, axis=0)
    ux = jnp.sum(f * J_cx[:, None, None], axis=0) / rho
    uy = jnp.sum(f * J_cy[:, None, None], axis=0) / rho

    # Standard Collision
    u_sq = ux**2 + uy**2
    cu = J_cx[:, None, None] * ux + J_cy[:, None, None] * uy
    f_eq = rho * J_w[:, None, None] * (1 + 3*cu + 4.5*cu**2 - 1.5*u_sq)
    f_post_collision = f - (f - f_eq) / tau

    # Standard Streaming (NumPy direction arrays give static integer shifts)
    f_streamed = f_post_collision
    for i in range(9):
        f_streamed = f_streamed.at[i].set(
            jnp.roll(f_streamed[i], shift=(int(c_x[i]), int(c_y[i])), axis=(1, 0))
        )

    # SOFT INTERACTION
    # Use continuous_mask (0.0 to 1.0) to blend bounce-back vs stream
    f_bounced = f_streamed[J_inv]
    f_out = continuous_mask * f_bounced + (1.0 - continuous_mask) * f_streamed

    # Inflow BC
    rho_0 = jnp.sum(f_out[:, :, 0], axis=0)
    cu_0 = J_cx[:, None] * U0
    f_eq_0 = rho_0 * J_w[:, None] * (1 + 3*cu_0 + 4.5*cu_0**2 - 1.5*U0**2)
    f_out = f_out.at[:, :, 0].set(f_eq_0)
    
    return f_out


def init_moving_jax():
    # Calculate equilibrium for velocity U0
    rho_start = 1.0
    cu = J_cx[:, None, None] * U0
    u_sq = U0**2
    f_single_pixel = rho_start * J_w[:, None, None] * (1 + 3*cu + 4.5*cu**2 - 1.5*u_sq)
    
    # EXPLICITLY expand to full grid (9, NY, NX)
    # lax.scan requires the carry to keep the same shape on every step
    return jnp.broadcast_to(f_single_pixel, (9, NY, NX))

# 2. LOSS FUNCTION
def simulation_loss(mask_proposal):
    # Start with MOVING fluid
    f = init_moving_jax()
    
    # Run loop
    def loop_body(f_curr, _):
        f_next = step_differentiable(f_curr, mask_proposal)
        return f_next, None
    
    # Run 400 steps
    f_final, _ = jax.lax.scan(loop_body, f, jnp.arange(400)) 
    
    # Calculate Turbulence (Curl squared)
    rho = jnp.sum(f_final, axis=0)
    ux = jnp.sum(f_final * J_cx[:, None, None], axis=0) / rho
    uy = jnp.sum(f_final * J_cy[:, None, None], axis=0) / rho
    
    dy_ux = jnp.gradient(ux, axis=0)
    dx_uy = jnp.gradient(uy, axis=1)
    curl = dx_uy - dy_ux
    
    return jnp.sum(curl**2)

# 3. CALCULATE GRADIENTS
print("Computing Adjoint Gradient (Simulating backwards)...")

# Convert boolean cylinder to float
current_geometry = jnp.array(cylinder_mask, dtype=jnp.float32)

# Get gradients
grad_fn = jax.grad(simulation_loss)
sensitivity_map = grad_fn(current_geometry)

# Force computation
sensitivity_map_np = np.array(sensitivity_map)
print(f"Gradient stats: Min={sensitivity_map_np.min():.2e}, Max={sensitivity_map_np.max():.2e}")

# 4. VISUALIZATION
plt.figure(figsize=(18, 6), dpi=150)

# Robust limit calculation
abs_max = np.max(np.abs(sensitivity_map_np))
limit = abs_max * 0.8 if abs_max > 1e-9 else 1.0 

im = plt.imshow(sensitivity_map_np, cmap='seismic', vmin=-limit, vmax=limit, origin='lower')

# Outline the cylinder
plt.contour(cylinder_mask, levels=[0.5], colors='black', linewidths=2)

plt.title("The 'Adjoint' Map: Red = Adds Turbulence / Blue = Reduces Turbulence", fontsize=14)
plt.colorbar(im, label="Sensitivity (dLoss/dGeometry)")
plt.show()
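The sensitivity map is a single gradient evaluation. Automating the design iterations mentioned in the introduction amounts to repeatedly stepping the mask against the gradient and clipping it back to [0, 1], i.e. projected gradient descent. A minimal sketch, assuming a differentiable scalar `loss_fn(mask)`; in this notebook that could be `simulation_loss` (at the cost of one forward and backward simulation per iteration), but a cheap quadratic stand-in keeps the example runnable in seconds. The helper name, learning rate and iteration count are illustrative:

```python
import jax
import jax.numpy as jnp

def projected_gradient_descent(loss_fn, mask0, lr=0.1, n_iters=50):
    """Hypothetical helper: minimize loss_fn over masks constrained to [0, 1]."""
    value_and_grad = jax.jit(jax.value_and_grad(loss_fn))
    mask = jnp.clip(mask0, 0.0, 1.0)
    history = []
    for _ in range(n_iters):
        loss, g = value_and_grad(mask)
        history.append(float(loss))
        # Gradient step, then project back onto the feasible "porosity" range
        mask = jnp.clip(mask - lr * g, 0.0, 1.0)
    return mask, history

# Quadratic stand-in for simulation_loss whose optimum partly lies outside [0, 1]
target = jnp.linspace(-0.5, 1.5, 16).reshape(4, 4)
toy_loss = lambda m: jnp.sum((m - target) ** 2)

mask_opt, history = projected_gradient_descent(toy_loss, jnp.full((4, 4), 0.5))
print(f"loss: {history[0]:.3f} -> {history[-1]:.3f}")
```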

**Download the image**

In [None]:
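# @title 10. Save & Download Adjoint Map Image
import matplotlib.pyplot as plt
from IPython.display import FileLink, display

# 1. Re-draw the adjoint map into an explicit figure.
# With Jupyter's inline backend, the previous cell's figure is closed once it is shown,
# so a bare plt.savefig() here would write an empty image.
fig_adj, ax_adj = plt.subplots(figsize=(18, 6), dpi=150)
abs_max = np.max(np.abs(sensitivity_map_np))
limit = abs_max * 0.8 if abs_max > 1e-9 else 1.0
im_adj = ax_adj.imshow(sensitivity_map_np, cmap='seismic', vmin=-limit, vmax=limit, origin='lower')
ax_adj.contour(cylinder_mask.astype(float), levels=[0.5], colors='black', linewidths=2)
ax_adj.set_title("The 'Adjoint' Map: Red = Adds Turbulence / Blue = Reduces Turbulence", fontsize=14)
fig_adj.colorbar(im_adj, ax=ax_adj, label="Sensitivity (dLoss/dGeometry)")

# 2. Save the figure; bbox_inches='tight' removes the white margins
output_image = "adjoint_sensitivity_map.png"
print(f"Saving {output_image}...")
fig_adj.savefig(output_image, dpi=300, bbox_inches='tight')
plt.close(fig_adj)

# 3. Trigger Download
try:
    # Google Colab specific download trigger
    from google.colab import files
    files.download(output_image)
except ImportError:
    # For local Jupyter/JupyterLab, generate a clickable link
    print(f"✅ Image saved locally as '{output_image}'")
    display(FileLink(output_image))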