# Differentiable Simulation of a Deformable Flapping Body

This notebook demonstrates a full Fluid-Structure Interaction (FSI) simulation of a deformable body using the JAX-IB framework. 

**Key Differences from the Original Rigid Body Model:**
1.  **Dynamic vs. Kinematic:** The particle's motion is no longer prescribed by a mathematical function. Instead, it is *dynamically calculated* at each time step based on physical forces (internal elasticity, fluid pressure, viscosity, etc.).
2.  **Deformable Physics:** The particle is modeled as an elastic body with mass, stiffness, and surface tension, allowing it to deform in response to fluid forces. This is based on the penalty Immersed Boundary Method from Sustiel & Grier.
3.  **Stateful Simulation:** The core data structures have been rewritten to be *stateful*, holding the current positions and velocities of the particle markers. The entire simulation state is a JAX PyTree that is evolved over time.
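
As a rough illustration of the penalty idea in item 2, the sketch below advances a single pair of markers: a fluid-interacting marker `X` that moves with the local fluid velocity, and a mass-carrying marker `Y` tied to it by a stiff spring. It is a generic penalty immersed-boundary step written in NumPy under stated assumptions (one spring per marker pair, semi-implicit Euler, a prescribed fluid velocity). The function `penalty_step` and its arguments are hypothetical and are not the `jax_ib` implementation.

```python
import numpy as np

def penalty_step(X, Y, V, u_fluid, Kp, m, g, dt):
    """Advance one marker pair by dt (hypothetical helper, 2D vectors)."""
    # Spring force acting on the fluid marker; in an IB solver this is the
    # force that would be spread onto the fluid grid.
    F = Kp * (Y - X)
    # The mass marker feels the opposite spring force plus gravity.
    a = -F / m + g
    V_new = V + dt * a          # update velocity first (semi-implicit Euler)
    Y_new = Y + dt * V_new      # then move the mass marker
    X_new = X + dt * u_fluid    # fluid marker follows the local fluid velocity
    return X_new, Y_new, V_new, F

# Both markers start at the same point, at rest, in a fluid at rest.
X = Y = np.array([11.25, 7.5])
V = np.zeros(2)
g = np.array([0.0, -9.81])
for _ in range(2):
    X, Y, V, F = penalty_step(X, Y, V, u_fluid=np.zeros(2),
                              Kp=5e4, m=1.0 / 150, g=g, dt=5e-4)

print("spring force on fluid marker:", F)
```

With the fluid at rest, gravity pulls `Y` below `X`, so the spring force on the fluid marker points downward: the spring is what transmits the body's weight to the fluid.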

This first cell handles the initial setup: cloning the repository from GitHub and installing it as a Python package in the Colab environment.

In [1]:
# --- SETUP AND INSTALLATION ---

# 1. (Optional) Clean up any old repository versions to ensure a fresh start
!rm -rf Hashim

# 2. Clone the latest version of the code from your GitHub repository
!git clone https://github.com/nurmaton/Hashim.git

# 3. Change directory into the cloned repository
%cd Hashim

# 4. Verify that the project files (like pyproject.toml and the jax_ib folder) are present
!ls

# 5. Install the project as a Python package in editable mode.
# The '-e' flag is for development; it means any changes you make to the source .py files
# will be immediately available without needing to reinstall.
# pip will automatically read the dependencies from your pyproject.toml file.
!pip install -e '.[data]'

### Imports and Compatibility Fixes

This cell imports all the necessary modules from JAX and the `jax_ib` library. 

It also includes a **monkey-patch** for `jax.random.KeyArray`, which is a compatibility fix. The `jax_ib` library was written against an older JAX release that exposed `jax.random.KeyArray` as the type name for random keys; newer releases no longer provide that name. Keys created with `jax.random.PRNGKey` are still ordinary `uint32` arrays, so the patch re-creates the old name as an alias for the array type. The patch is applied before the `jax_ib` imports so that the library code can run without modification.

In [2]:
import jax
import jax.numpy as jnp

# --- COMPATIBILITY FIX ---
# Monkey-patch the missing 'KeyArray' attribute that the older jax_ib library expects.
# In modern JAX, a random key is just a regular JAX array.
jax.random.KeyArray = jnp.ndarray
# --- END FIX ---

# --- Core JAX-IB and JAX-CFD Imports ---
import jax_ib
from jax_ib.base import particle_class as pc
import jax_ib.base as ib
import jax_cfd.base as cfd # Using utilities from jax_cfd
import jax_ib.MD as MD
from jax import random
from jax_md import space, quantity
from jax_ib.base import grids
from jax_ib.base import boundaries
from jax_ib.base import advection, finite_differences

# --- NEW: Import the modules specific to the deformable body physics ---
from jax_ib.base import IBM_Force, convolution_functions, particle_motion

## Flow Problem Setup

This cell sets up the **Eulerian** part of the simulation: the fluid properties, the grid, boundary conditions, and the initial state of the fluid.

In [3]:
#-- Flow conditions --
density = 1.0
viscosity = 0.05
dt = 5e-4
inner_steps = 1000 # Number of solver steps per animation frame
outer_steps = 20   # Total number of animation frames to generate

# --- NEW: Physical Parameters for the Deformable Particle ---
# These parameters define the physical nature of the immersed body.
# They were not needed in the old kinematic model.
stiffness = 5e4   # Spring stiffness (Kp from the paper). This is a crucial parameter to tune.
total_mass = 1.0   # Total mass of the immersed object.
sigma = 1.0        # Surface tension coefficient. Set to 0.0 to turn this force off.
gravity = 9.81     # Gravitational acceleration, for simulating sedimentation.

#-- Domain and Grid Setup --
domain = ((0, 15.), (0, 15.0))
size = (600, 600)
grid = grids.Grid(size, domain=domain)

#-- Boundary Conditions --
# NOTE: For this simulation, we use simple, static periodic boundaries.
# The old model had complex functions for time-varying walls, which are not needed here.
velocity_bc = (boundaries.periodic_boundary_conditions(ndim=2),
               boundaries.periodic_boundary_conditions(ndim=2))

#-- Convection Scheme --
# Use a simple and stable upwind scheme for advection.
def convect(v):
  return tuple(advection.advect_upwind(u, v, dt) for u in v)

#-- Initial velocity profile (Fluid at Rest) --
vx_fn = lambda x, y: jnp.zeros_like(x + y)
vy_fn = lambda x, y: jnp.zeros_like(x + y)

velocity_fns = (vx_fn, vy_fn)
v0 = tuple(grid.eval_on_mesh(v_fn, offset) for v_fn, offset in zip(velocity_fns, grid.cell_faces))

# Wrap the raw velocity arrays in GridVariable objects, which attach the boundary conditions.
v0 = tuple(
      grids.GridVariable(u, bc) for u, bc in zip(v0, velocity_bc))

#-- Initial Pressure Profile --
# Start with zero pressure everywhere.
pressure0 = grids.GridVariable(grids.GridArray(jnp.zeros_like(v0[0].data), grid.cell_center, grid), boundaries.get_pressure_bc_from_velocity(v0))
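
The numbers chosen above can be sanity-checked with a few lines of arithmetic. The sketch below is plain Python and only re-uses the values from this cell; the variable names `diffusion_number`, `max_speed_cfl_1`, and `last_frame_time` are introduced here for illustration. It assumes that frame 0 of the saved trajectory is the initial state, which matches the `start_with_input=True` setting used in the rollout later in the notebook.

```python
# Values copied from the setup cell above.
density, viscosity, dt = 1.0, 0.05, 5e-4
inner_steps, outer_steps = 1000, 20
length, n_cells = 15.0, 600

dx = length / n_cells                           # grid spacing
nu = viscosity / density                        # kinematic viscosity (density is 1 here)
diffusion_number = nu * dt / dx**2              # dimensionless viscous step size
max_speed_cfl_1 = dx / dt                       # flow speed at which CFL reaches 1
frame_dt = inner_steps * dt                     # simulated time between saved frames
last_frame_time = (outer_steps - 1) * frame_dt  # frame 0 is the initial state

print(f"dx = {dx:.4f}, diffusion number = {diffusion_number:.3f}")
print(f"CFL = 1 at |u| = {max_speed_cfl_1:.1f}")
print(f"{frame_dt:.2f} time units per frame, last frame at t = {last_frame_time:.2f}")
```

The diffusion number is reported for information only, and the advective limit only becomes relevant if the flow speed approaches `dx / dt`.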

## Immersed Boundary Setup

This cell sets up the **Lagrangian** part of the simulation: the deformable particle.

**This is the biggest conceptual change from the rigid body model.** Instead of defining mathematical functions for motion (`kinematics`), we now define the particle's initial physical state: its shape, position, mass, stiffness, etc. This state is then stored in the `particle` object, which will be updated by the physics solver at every time step.
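
Because the body is now driven by a stiff spring rather than a prescribed path, it can help to compare the spring's natural frequency with the time step before running anything. The sketch below is a back-of-the-envelope check in plain Python. It assumes each mass marker is tied to its fluid marker by an independent spring of constant `stiffness` and carries `total_mass / num_markers`; how `jax_ib` actually scales the stiffness is not shown in this notebook, so treat the result as indicative only.

```python
import math

stiffness, total_mass, dt = 5e4, 1.0, 5e-4   # from the flow setup cell
num_markers = 150                            # markers on the ellipse (next cell)

m = total_mass / num_markers                 # mass carried by each marker
omega = math.sqrt(stiffness / m)             # natural frequency of one marker spring
steps_per_period = 2 * math.pi / (omega * dt)

print(f"omega * dt = {omega * dt:.3f}")
print(f"time steps per spring oscillation = {steps_per_period:.1f}")
```

For an undamped spring, explicit leapfrog-type integrators are stable only while `omega * dt` stays below 2. Whether that bound applies here depends on the integrator used in `particle_motion.py`, but it suggests why `stiffness` and `dt` need to be tuned together.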

In [4]:
#-- Immersed objects --

# --- 1. Define a function to generate the particle's shape ---
# This function returns ellipse coordinates sampled at evenly spaced parametric angles.
# Note: for A != B the resulting points are NOT evenly spaced along the arc.
def foil_XY_ELLIPSE(geometry_param, theta_dummy):
    A = geometry_param[0]  # Semi-major axis
    B = geometry_param[1]  # Semi-minor axis
    ntheta = 150 # Number of Lagrangian marker points; keep in sync with the shape of Grid_p below

    angle = jnp.linspace(0, 2 * jnp.pi, ntheta, endpoint=False)
    x = A * jnp.cos(angle)
    y = B * jnp.sin(angle)
    return x, y

# --- 2. NEW INITIALIZATION LOGIC FOR THE DEFORMABLE PARTICLE ---
particle_geometry_param = jnp.array([[0.5, 0.1]]) # Semi-axes [A, B] of the ellipse
particle_center_position = jnp.array([[domain[0][1]*0.75, domain[1][1]/2],]) # Initial [x, y] center
Shape_fn = foil_XY_ELLIPSE

# Generate the initial shape coordinates in the particle's local reference frame (centered at origin)
Grid_p = pc.Grid1d(shape=150, domain=(0, 2*jnp.pi)) # A dummy grid for the shape function
xp0_body, yp0_body = Shape_fn(particle_geometry_param[0], Grid_p.mesh())

# Calculate the initial positions of the marker points in the global simulation frame.
# In the penalty method, the fluid-interacting markers (X) and mass-carrying markers (Y)
# start at the same location.
initial_xp = xp0_body + particle_center_position[0, 0]
initial_yp = yp0_body + particle_center_position[0, 1]

# The particle starts from rest, so initial velocities for the mass markers are zero.
initial_Vm_x = jnp.zeros_like(initial_xp)
initial_Vm_y = jnp.zeros_like(initial_yp)

# Calculate the mass of each individual marker point.
num_markers = len(initial_xp)
mass_per_marker = total_mass / num_markers

# --- 3. Create the final particle object with its full dynamic state ---
# This object holds all the stateful information that will be updated by the solver.
deformable_particle = pc.particle(
    xp=initial_xp, yp=initial_yp, # Current fluid-interacting marker positions (X in paper)
    Ym_x=initial_xp.copy(), Ym_y=initial_yp.copy(), # Mass-carrying marker positions (Y in paper)
    Vm_x=initial_Vm_x, Vm_y=initial_Vm_y, # Mass-carrying marker velocities
    mass_per_marker=mass_per_marker,
    stiffness=stiffness, # The penalty spring constant, Kp
    sigma=sigma, # The surface tension coefficient
    particle_center=particle_center_position, # Static initial center
    geometry_param=particle_geometry_param, # Static shape parameters
    Grid=Grid_p, # Static grid object
    shape=Shape_fn # Static shape function
)

# This container holds the list of all particles in the simulation (in this case, just one).
# It is a JAX PyTree, which is crucial for JIT compilation.
particles_container = pc.particle_lista(particles=[deformable_particle])
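
Since the ellipse is sampled uniformly in angle, the distance between neighbouring markers varies around the boundary. The NumPy sketch below measures that spacing for the values used in this cell and compares it with the fluid grid spacing. Immersed boundary methods generally want the marker spacing to be no larger than the grid spacing so that fluid does not leak between markers; the variable `seg` is introduced here for illustration.

```python
import numpy as np

A, B, ntheta = 0.5, 0.1, 150          # values used in the cell above
dx = 15.0 / 600                       # grid spacing from the flow setup

angle = np.linspace(0, 2 * np.pi, ntheta, endpoint=False)
x, y = A * np.cos(angle), B * np.sin(angle)

# Distance from each marker to the next one, wrapping around the closed curve.
seg = np.hypot(np.roll(x, -1) - x, np.roll(y, -1) - y)

print(f"marker spacing: min {seg.min():.4f}, max {seg.max():.4f}, grid dx {dx:.4f}")
```

The markers cluster near the tips of the ellipse and spread out along its flat sides. Resampling by arc length would even this out, at the cost of a slightly more involved shape function.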

## Simulation Setup

This cell assembles the complete time-stepping function. It takes all the individual physics components we've defined (fluid convection, IBM force, particle motion) and combines them into a single `step_fn` that can advance the entire simulation state by one time step. 

In [5]:
#-- Setup Pytree Variable to be looped over during integration steps --

Intermediate_calcs = [0] # Placeholder for any post-processing calculations
Step_counter = 0
MD_state = [0] # Placeholder for Brownian motion state (not used here)

# The `All_Variables` object is the master PyTree that contains the ENTIRE simulation state.
# This is the object that will be passed into and out of the main JAX loop.
all_variables = pc.All_Variables(particles_container, v0, pressure0, Intermediate_calcs, Step_counter, MD_state)

#-- IB force coupling functions for the Deformable Model --
def internal_post_processing(all_variables, dt):
    # This function can be used for any calculations you want to perform inside the time-step loop.
    return all_variables

# This is the discrete delta function kernel used for all spreading/interpolation operations.
discrete_delta = lambda dist, center, width: convolution_functions.delta_approx_logistjax(dist, center, width)

# --- NEW, UPDATED IBM FORCING AND PARTICLE UPDATE FUNCTIONS ---
# These lambda functions adapt our specific physics modules to the generic API required by the main solver.

# The IBM forcing function now calculates the physical penalty/tension force from the particle's state.
# The `v` argument passed by the solver is the entire `all_variables` pytree.
IBM_forcing = lambda v, dt: IBM_Force.calc_IBM_force_NEW_MULTIPLE(v, discrete_delta, dt)

# The particle update function now calls our new massive, deformable particle stepper.
# It solves the equations of motion for the particle markers.
Update_position = lambda v, dt: particle_motion.update_massive_deformable_particle(v, dt, gravity_g=gravity)


# Note: The gravity force is applied directly to the particle in `particle_motion.py`,
# so we don't need a separate general forcing term for the fluid here.

# --- The main solver call now uses the new dynamic functions ---
# This assembles all the components into a single function that performs `inner_steps` of the simulation.
step_fn = cfd.funcutils.repeated(
    ib.equations.semi_implicit_navier_stokes_timeBC(
        density=density,
        viscosity=viscosity,
        dt=dt,
        grid=grid,
        convect=convect,
        pressure_solve= ib.pressure.solve_fast_diag,
        forcing=None, # No general forcing on the fluid
        time_stepper= ib.time_stepping.forward_euler_updated,
        IBM_forcing = IBM_forcing,          # Using the new penalty/tension force function
        Updating_Position = Update_position,  # Using the new deformable update function
        Drag_fn = internal_post_processing,
        ),
    steps=inner_steps)

# This `rollout` function wraps `step_fn` to run it for `outer_steps` and collect the full trajectory.
rollout_fn = cfd.funcutils.trajectory(
        step_fn, outer_steps, start_with_input=True)

# --- EXECUTE THE SIMULATION ---
# This is the main JAX call. `rollout_fn` is not explicitly wrapped in `jax.jit` here; the loops built by
# `cfd.funcutils` are still traced and compiled by JAX. `jax.device_get` then copies the results
# from the device (GPU/TPU) back to the host (CPU).
final_result, trajectory = jax.device_get(rollout_fn(all_variables))
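
The nesting of `repeated` inside `trajectory` determines which states end up in the saved frames. The sketch below uses plain-Python stand-ins with the same names to show the loop structure. It assumes the helpers behave like their `jax_cfd` counterparts and is not their implementation, which stacks frames into arrays with a leading time axis rather than a list.

```python
def repeated(f, steps):
    """Return a function that applies f `steps` times."""
    def f_repeated(x):
        for _ in range(steps):
            x = f(x)
        return x
    return f_repeated

def trajectory(f, steps, start_with_input=True):
    """Return a function that applies f `steps` times and records each frame."""
    def rollout(x):
        frames = []
        for _ in range(steps):
            x_next = f(x)
            # Record the state before the step (input) or after it (output).
            frames.append(x if start_with_input else x_next)
            x = x_next
        return x, frames
    return rollout

step = lambda t: t + 1                      # toy "solver step": advance a counter
final, frames = trajectory(repeated(step, 3), 4)(0)
print(final, frames)
```

With 3 inner steps and 4 outer steps this prints `12 [0, 3, 6, 9]`: the first frame is the initial state, and the final state is one block of inner steps past the last saved frame. This is why frame `i` corresponds to time `i * inner_steps * dt` in the plots that follow.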

## Visualization

The following cells visualize the simulation results. The first block shows how to create static plots of the vorticity field at different time steps. The second block generates a full MP4 video of the simulation, which is a much better way to see the dynamics.

In [6]:
import matplotlib.pyplot as plt
import seaborn
import numpy as np

# Grid mesh for plotting
X,Y = grid.mesh(grid.cell_center)

def calc_vorticity(trajectory, idx):
    """
    Calculates the vorticity at a specific time step from the trajectory.
    NOTE: This is now updated for static periodic boundary conditions.
    """
    # Reconstruct the simple periodic BC object for post-processing.
    vel_bc = (boundaries.periodic_boundary_conditions(ndim=2),
              boundaries.periodic_boundary_conditions(ndim=2))

    # Reconstruct the GridVariable for this specific time step's velocity data.
    velocity = (grids.GridVariable(grids.GridArray(trajectory.velocity[0].data[idx], trajectory.velocity[0].offset, trajectory.velocity[0].grid), vel_bc[0]),
                grids.GridVariable(grids.GridArray(trajectory.velocity[1].data[idx], trajectory.velocity[1].offset, trajectory.velocity[1].grid), vel_bc[1]))

    # Calculate vorticity using finite differences.
    return finite_differences.central_difference(velocity[1], axis=0).data - finite_differences.central_difference(velocity[0], axis=1).data

# --- Main Static Plotting Loop ---
# Plot every second frame. The figure is sized from the same index list so the
# number of axes always matches the number of plotted frames (also for odd outer_steps).
frame_indices = list(range(0, outer_steps, 2))
fig, ax = plt.subplots(figsize=(12, 26), nrows=len(frame_indices))

# plt.subplots returns a bare Axes when nrows is 1; wrap it so indexing works.
ax = np.atleast_1d(ax)

counter = 0
for idx in frame_indices:
    # Plot the vorticity field contours.
    ax[counter].contour(X, Y, calc_vorticity(trajectory, idx), cmap=seaborn.color_palette("vlag", as_cmap=True), levels=np.linspace(-10, 10, 10))

    # --- KEY CHANGE IN VISUALIZATION ---
    # The particle's shape is no longer recalculated from kinematic functions.
    # Instead, we get the DEFORMED particle shape directly from the trajectory data.

    # Access the particle's stacked trajectory. Each field has a leading time axis,
    # which is indexed with 'idx' below.
    particle_state_at_idx = trajectory.particles.particles[0]

    # Get the current, deformed coordinates of the fluid-interacting markers (X markers).
    xp = particle_state_at_idx.xp[idx]
    yp = particle_state_at_idx.yp[idx]

    # Plot the deformed particle shape.
    ax[counter].set_xlim([9.5, 12.7])
    ax[counter].set_ylim([5.5, 8.7])
    ax[counter].plot(xp, yp, 'k-', linewidth=2.0)

    # To see both the fluid and mass markers, you can uncomment this block:
    # Ym_x = particle_state_at_idx.Ym_x[idx]
    # Ym_y = particle_state_at_idx.Ym_y[idx]
    # ax[counter].plot(Ym_x, Ym_y, 'r--', linewidth=1.0) # Mass markers in dashed red

    ax[counter].set_aspect('equal', adjustable='box')

    counter += 1

plt.tight_layout() # Adjusts subplot params for a tight layout.

plt.show()
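
The vorticity formula used in `calc_vorticity`, the x-derivative of the y-velocity minus the y-derivative of the x-velocity, can be checked on a field with a known answer. The NumPy sketch below does this for a Taylor-Green vortex on a periodic grid. It is a standalone illustration: it treats both velocity components as living at the same points and uses its own `central_difference` helper rather than the `finite_differences` module.

```python
import numpy as np

n = 64
h = 2 * np.pi / n
coords = np.arange(n) * h
xx, yy = np.meshgrid(coords, coords, indexing="ij")   # axis 0 is x, axis 1 is y

# Taylor-Green vortex: the analytic vorticity is 2 sin(x) sin(y).
u = np.sin(xx) * np.cos(yy)
v = -np.cos(xx) * np.sin(yy)

def central_difference(f, axis, h):
    """Second-order periodic central difference along one axis."""
    return (np.roll(f, -1, axis=axis) - np.roll(f, 1, axis=axis)) / (2 * h)

vorticity = central_difference(v, 0, h) - central_difference(u, 1, h)
error = np.abs(vorticity - 2 * np.sin(xx) * np.sin(yy)).max()
print(f"max error: {error:.2e}")
```

The error shrinks roughly fourfold each time the grid is refined by a factor of two, as expected for a second-order stencil.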

In [7]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
import seaborn
import numpy as np

# --- 1. SETUP THE FIGURE AND HELPER FUNCTIONS ---
fig, ax = plt.subplots(figsize=(10, 8))
fig.patch.set_facecolor('white')
ax.set_facecolor('white')

# Same vorticity calculation as in the static plot cell, redefined here so this cell is self-contained.
def calc_vorticity_for_frame(trajectory, frame_index):
    vel_bc = (boundaries.periodic_boundary_conditions(ndim=2),
              boundaries.periodic_boundary_conditions(ndim=2))
    velocity_at_t = (
        grids.GridVariable(grids.GridArray(trajectory.velocity[0].data[frame_index], trajectory.velocity[0].offset, grid), vel_bc[0]),
        grids.GridVariable(grids.GridArray(trajectory.velocity[1].data[frame_index], trajectory.velocity[1].offset, grid), vel_bc[1])
    )
    vorticity = (finite_differences.central_difference(velocity_at_t[1], axis=0).data -
                 finite_differences.central_difference(velocity_at_t[0], axis=1).data)
    return vorticity

# --- 2. DEFINE THE ANIMATION FUNCTION (CALLED FOR EACH FRAME) ---

def animate(i):
    ax.clear() # Clear the previous frame
    ax.set_facecolor('white')
    vorticity = calc_vorticity_for_frame(trajectory, i)

    # Plot the fluid vorticity as a contour plot
    ax.contour(X, Y, vorticity, cmap=seaborn.color_palette("vlag", as_cmap=True), levels=np.linspace(-10, 10, 20))

    # Get the particle's state for the current frame 'i'
    particle_at_i = trajectory.particles.particles[0]
    xp = particle_at_i.xp[i]
    yp = particle_at_i.yp[i]

    # Draw the opaque filled shape of the particle (fluid-marker boundary)
    ax.fill(xp, yp, 'steelblue', alpha=1.0, zorder=2)
    # Draw the outline of the particle
    ax.plot(xp, yp, color='lightsteelblue', linewidth=2.5, zorder=3)

    # --- NEW: Optionally plot the internal mass markers ---
    if show_mass_markers:
        Ym_x = particle_at_i.Ym_x[i]
        Ym_y = particle_at_i.Ym_y[i]
        ax.plot(Ym_x, Ym_y, 'r--', linewidth=1.5, zorder=4, label='Mass Markers')
        if i == 0: # Add a legend only on the first frame for clarity
             ax.legend()

    # Set titles and limits for a clean plot
    current_t = i * inner_steps * dt
    ax.set_title(f"Deformable Particle Simulation | Time: {current_t:.4f} s")
    ax.set_xlim(x_lims)
    ax.set_ylim(y_lims)
    ax.set_aspect('equal', adjustable='box')


# --- 3. CONFIGURE AND CREATE THE ANIMATION ---
print("Creating animation... This may take a moment.")

# --- CONTROLS ---
animation_interval = 200 # milliseconds between frames in the display
video_fps = 10          # frames per second for the saved MP4 file
x_lims = [9.5, 13]      # Set the x-axis zoom
y_lims = [5.5, 8.7]      # Set the y-axis zoom
show_mass_markers = True # Set to True to see the internal mass skeleton, or False to hide it

# Get the total number of frames from the trajectory data
num_frames = trajectory.particles.particles[0].xp.shape[0]

# Create the animation object
anim = animation.FuncAnimation(fig, animate, frames=num_frames, interval=animation_interval, blit=False)

# Save the animation to an MP4 file
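# Use the video_fps control defined above so the saved file matches the configured frame rate.
writer = animation.FFMpegWriter(fps=video_fps)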
anim.save('deformable_particle.mp4', writer=writer)

print("Animation saved as 'deformable_particle.mp4'")

# --- 4. DISPLAY THE ANIMATION IN THE NOTEBOOK ---
plt.close() # Close the static plot figure to prevent it from displaying twice
HTML(anim.to_jshtml()) # Convert the animation to JS/HTML for embedding
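
Saving to MP4 requires the `ffmpeg` executable, which is not present on every machine. The sketch below shows one possible way to fall back to Matplotlib's Pillow-based GIF writer when it is missing. The helper `pick_writer` is hypothetical and uses only the standard library; the commented `anim.save` line shows how its result might be used in this notebook.

```python
import shutil

def pick_writer(preferred="ffmpeg", fallback="pillow"):
    """Return (writer_name, file_extension) based on what is installed."""
    if shutil.which("ffmpeg") is not None:
        return preferred, ".mp4"
    return fallback, ".gif"

writer_name, ext = pick_writer()
print(f"would save with writer={writer_name!r} to 'deformable_particle{ext}'")
# In the notebook:
# anim.save('deformable_particle' + ext, writer=writer_name, fps=video_fps)
```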

### Bonus: Velocity Field Animation

This final cell creates an alternative animation that shows the fluid speed as a colormap and the flow direction as quiver arrows, instead of vorticity. This gives a different and useful perspective on the flow dynamics.

In [8]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
import seaborn
import numpy as np
from jax_ib.base import interpolation # Need interpolation for this plot

# --- 1. SETUP THE FIGURE AND HELPER FUNCTIONS ---
fig, ax = plt.subplots(figsize=(12, 8))
fig.patch.set_facecolor('white')
ax.set_facecolor('white')

X, Y = grid.mesh(grid.cell_center)

# This helper function is needed because velocity is on a staggered grid,
# but for a quiver plot, we want all vectors to originate from cell centers.
def get_velocity_at_centers(trajectory, frame_index):
    """Reconstructs and interpolates the velocity field to cell centers for a given frame."""
    vel_bc = (boundaries.periodic_boundary_conditions(ndim=2),
              boundaries.periodic_boundary_conditions(ndim=2))
    
    u_staggered = grids.GridVariable(grids.GridArray(trajectory.velocity[0].data[frame_index], trajectory.velocity[0].offset, grid), vel_bc[0])
    v_staggered = grids.GridVariable(grids.GridArray(trajectory.velocity[1].data[frame_index], trajectory.velocity[1].offset, grid), vel_bc[1])

    # Interpolate the x and y velocity components to the cell center locations.
    u_centered = interpolation.linear(u_staggered, grid.cell_center)
    v_centered = interpolation.linear(v_staggered, grid.cell_center)

    return u_centered.data, v_centered.data

# --- 2. DEFINE THE ANIMATION FUNCTION (CALLED FOR EACH FRAME) ---

def animate_velocity(i):
    ax.clear()
    ax.set_facecolor('white')
    u_data, v_data = get_velocity_at_centers(trajectory, i)
    speed = np.sqrt(u_data**2 + v_data**2)

    # Plot the fluid speed as a colormap
    ax.pcolormesh(X, Y, speed, cmap='Blues', shading='gouraud')
    
    # Overlay quiver arrows to show velocity direction. `skip` plots fewer arrows for clarity.
    skip = 30
    ax.quiver(X[::skip, ::skip], Y[::skip, ::skip],
              u_data[::skip, ::skip], v_data[::skip, ::skip],
              color='black', scale=25)

    # Plot the particle shape (same logic as before)
    particle_at_i = trajectory.particles.particles[0]
    xp = particle_at_i.xp[i]
    yp = particle_at_i.yp[i]

    ax.fill(xp, yp, 'steelblue', alpha=1.0, zorder=2)
    ax.plot(xp, yp, color='lightsteelblue', linewidth=2.5, zorder=3)

    # Optionally plot the mass markers
    if show_mass_markers:
        Ym_x = particle_at_i.Ym_x[i]
        Ym_y = particle_at_i.Ym_y[i]
        ax.plot(Ym_x, Ym_y, 'r--', linewidth=1.5, zorder=4, label='Mass Markers')
        if i == 0:
             ax.legend()

    current_t = i * inner_steps * dt
    ax.set_title(f"Deformable Particle Simulation (Velocity Field) | Time: {current_t:.4f} s")
    ax.set_xlim(x_lims)
    ax.set_ylim(y_lims)
    ax.set_aspect('equal', adjustable='box')

# --- 3. CONFIGURE AND CREATE THE ANIMATION ---
print("Creating velocity animation...")

# --- CONTROLS ---
animation_interval = 200
video_fps = 10
x_lims = [9.0, 15] # Wider view for velocity field
y_lims = [5.0, 9.0]
show_mass_markers = False # Turn off for a cleaner velocity plot

num_frames = trajectory.particles.particles[0].xp.shape[0]
anim_vel = animation.FuncAnimation(fig, animate_velocity, frames=num_frames, interval=animation_interval, blit=False)

# Save the animation to a new MP4 file
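# Use the video_fps control defined above so the saved file matches the configured frame rate.
writer = animation.FFMpegWriter(fps=video_fps)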
anim_vel.save('deformable_particle_velocity.mp4', writer=writer)

print("Animation saved as 'deformable_particle_velocity.mp4'")

# --- 4. DISPLAY THE ANIMATION IN THE NOTEBOOK ---
plt.close()
HTML(anim_vel.to_jshtml())
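
The interpolation step in `get_velocity_at_centers` amounts to averaging the two face values on either side of each cell. The NumPy sketch below shows this in one dimension on a periodic grid. It assumes the x-velocity is stored on the right face of each cell, which is the staggered convention this notebook appears to inherit from `jax_cfd`, and it does not call `interpolation.linear` itself.

```python
import numpy as np

n = 64
h = 2 * np.pi / n
x_center = (np.arange(n) + 0.5) * h       # cell centers
x_face = (np.arange(n) + 1.0) * h         # right face of each cell

u_face = np.sin(x_face)                   # sample a periodic field on the faces

# Cell i sits between face i-1 (left) and face i (right); np.roll supplies
# the periodic wrap-around for cell 0.
u_center = 0.5 * (u_face + np.roll(u_face, 1))

error = np.abs(u_center - np.sin(x_center)).max()
print(f"max interpolation error: {error:.2e}")
```

After this averaging, both velocity components share the same locations, so the quiver arrows and the speed colormap refer to the same points.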