# Differentiable Simulation of a Deformable Flapping Body

This notebook demonstrates a full Fluid-Structure Interaction (FSI) simulation of a deformable body using the JAX-IB framework.

**Key Differences from the Original Rigid Body Model:**
1.  **Dynamic vs. Kinematic:** The particle's motion is no longer prescribed by a mathematical function. Instead, it is *dynamically calculated* at each time step based on physical forces (internal elasticity, fluid pressure, viscosity, etc.).
2.  **Deformable Physics:** The particle is modeled as an elastic body with mass, stiffness, and surface tension, allowing it to deform in response to fluid forces. This is based on the penalty Immersed Boundary Method from Sustiel & Grier.
3.  **Stateful Simulation:** The core data structures have been rewritten to be *stateful*, holding the current positions and velocities of the particle markers. The entire simulation state is a JAX PyTree that is evolved over time.
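The dynamic coupling in points 1 and 2 can be illustrated with a minimal NumPy sketch of a generic penalty-IB update. It assumes the standard penalty form. A stiff spring with constant `Kp` ties each fluid marker X to its mass marker Y. The spring force `Kp * (Y - X)` is spread onto the fluid. The equal and opposite force, plus gravity (taken here along -y), drives the mass markers, and the markers are advanced with semi-implicit Euler. The function name `penalty_ib_step` is hypothetical. This is not the code in `particle_motion.update_massive_deformable_particle`, and it leaves out surface tension and the fluid solve.

```python
import numpy as np

def penalty_ib_step(X, Y, V, Kp, m, g, dt):
    """One step of a generic penalty-IB mass-marker update (illustrative sketch only).

    X: (N, 2) fluid-interacting marker positions (advected by the fluid elsewhere)
    Y: (N, 2) mass-carrying marker positions
    V: (N, 2) mass-marker velocities
    Kp: penalty spring constant, m: mass per marker, g: gravity magnitude (acts along -y)
    Returns (force spread onto the fluid at X, updated Y, updated V).
    """
    F_fluid = Kp * (Y - X)                       # spring pulls each fluid marker toward its mass marker
    F_mass = -F_fluid + m * np.array([0.0, -g])  # equal and opposite reaction, plus gravity
    V_new = V + dt * F_mass / m                  # semi-implicit Euler: update velocity first,
    Y_new = Y + dt * V_new                       # then move the mass markers with the new velocity
    return F_fluid, Y_new, V_new

# A body at rest whose markers coincide (X == Y) only feels gravity on the first step.
N = 4
X0 = np.zeros((N, 2))
F, Y1, V1 = penalty_ib_step(X0, X0.copy(), np.zeros((N, 2)), Kp=5e4, m=1.0 / N, g=9.81, dt=5e-4)
print(F.max(), V1[0], Y1[0])
```

As soon as the fluid displaces X away from Y, the spring term grows and transmits force in both directions, which is what makes the motion "dynamic" rather than prescribed.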

This first cell handles the initial setup: cloning the repository from GitHub and installing it as a Python package in the Colab environment.

In [None]:
# --- SETUP AND INSTALLATION ---

# 1. (Optional) Clean up any old repository versions to ensure a fresh start
!rm -rf Hashim

# 2. Clone the latest version of the code from your GitHub repository
!git clone https://github.com/nurmaton/Hashim.git

# 3. Change directory into the cloned repository
%cd Hashim

# 4. Verify that the project files (like pyproject.toml and the jax_ib folder) are present
!ls

# 5. Install the project as a Python package in editable mode.
# The '-e' flag is for development; it means any changes you make to the source .py files
# will be immediately available without needing to reinstall.
# pip will automatically read the dependencies from your pyproject.toml file.
!pip install -e '.[data]'

### Imports and Compatibility Fixes

This cell imports all the necessary modules from JAX and the `jax_ib` library.

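It also includes a **monkey-patch** for `jax.random.KeyArray`, a crucial compatibility fix. The original `jax_ib` library was written against an older JAX release that exposed a dedicated `jax.random.KeyArray` type. That name has been removed from recent JAX releases, where random keys are ordinary arrays (`jax.Array`, of which `jnp.ndarray` is an alias). The patch adds the old name back, pointing it to the array type, so the library code runs without modification. Keep the patch ahead of the `jax_ib` imports, as in the cell below, in case the library touches the name at import time (for example in type annotations).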
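If the same notebook has to run on both older and newer JAX versions, a slightly more defensive form of the patch adds the alias only when it is missing, so an existing `KeyArray` is left untouched. This is a sketch that assumes a JAX version recent enough to provide `jax.Array`.

```python
import jax

# Add the legacy alias only when this JAX release no longer provides it,
# so an existing KeyArray (in older JAX) is left untouched.
if not hasattr(jax.random, "KeyArray"):
    jax.random.KeyArray = jax.Array

# Legacy-style keys are ordinary jax.Array values in current JAX.
key = jax.random.PRNGKey(0)
print(isinstance(key, jax.Array))
```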

In [2]:
import jax
import jax.numpy as jnp

# --- COMPATIBILITY FIX ---
# Monkey-patch the missing 'KeyArray' attribute that the older jax_ib library expects.
# In modern JAX, a random key is just a regular JAX array.
jax.random.KeyArray = jnp.ndarray
# --- END FIX ---

# --- Core JAX-IB and JAX-CFD Imports ---
import jax_ib
from jax_ib.base import particle_class as pc
import jax_ib.base as ib
import jax_cfd.base as cfd # Using utilities from jax_cfd
import jax_ib.MD as MD
from jax import random
from jax_md import space, quantity
from jax_ib.base import grids
from jax_ib.base import boundaries
from jax_ib.base import advection, finite_differences

# --- NEW: Import the modules specific to the deformable body physics ---
from jax_ib.base import IBM_Force, convolution_functions, particle_motion

## Flow Problem Setup

This cell sets up the **Eulerian** part of the simulation: the fluid properties, the grid, boundary conditions, and the initial state of the fluid.

In [3]:
#-- Flow conditions --
density = 1.0
viscosity = 0.05
dt = 5e-4
inner_steps = 1000 # Number of solver steps per animation frame
outer_steps = 20   # Total number of animation frames to generate

# --- NEW: Physical Parameters for the Deformable Particle ---
# These parameters define the physical nature of the immersed body.
# They were not needed in the old kinematic model.
stiffness = 5e4   # Spring stiffness (Kp from the paper). This is a crucial parameter to tune.
total_mass = 1   # Total mass of the immersed object.
sigma = 1.0        # Surface tension coefficient. Set to 0.0 to turn this force off.
gravity = 9.81     # Gravitational acceleration, for simulating sedimentation.

#-- Domain and Grid Setup --
domain = ((0, 15.), (0, 15.0))
size = (600, 600)
grid = grids.Grid(size, domain=domain)

#-- Boundary Conditions --
# NOTE: For this simulation, we use simple, static periodic boundaries.
# The old model had complex functions for time-varying walls, which are not needed here.
velocity_bc = (boundaries.periodic_boundary_conditions(ndim=2),
               boundaries.periodic_boundary_conditions(ndim=2))

#-- Convection Scheme --
# Use a simple and stable upwind scheme for advection.
def convect(v):
  return tuple(advection.advect_upwind(u, v, dt) for u in v)

#-- Initial velocity profile (Fluid at Rest) --
vx_fn = lambda x, y: jnp.zeros_like(x + y)
vy_fn = lambda x, y: jnp.zeros_like(x + y)

velocity_fns = (vx_fn, vy_fn)
v0 = tuple(grid.eval_on_mesh(v_fn, offset) for v_fn, offset in zip(velocity_fns, grid.cell_faces))

# Wrap the raw velocity arrays in GridVariable objects, which attach the boundary conditions.
v0 = tuple(
      grids.GridVariable(u, bc) for u, bc in zip(v0, velocity_bc))

#-- Initial Pressure Profile --
# Start with zero pressure everywhere.
pressure0 = grids.GridVariable(grids.GridArray(jnp.zeros_like(v0[0].data), grid.cell_center, grid), boundaries.get_pressure_bc_from_velocity(v0))
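Before running the solver, it can be worth checking a few derived quantities. The sketch below computes the grid spacing, the explicit-diffusion number ν·dt/dx², the simulated time between animation frames, and ω·dt for the penalty spring, where ω = √(Kp/m) and m is the mass per marker. The helper name `flow_parameter_check` is hypothetical, and the thresholds in the comments are textbook limits (forward-Euler diffusion in 2-D, and a symplectic-Euler mass-spring oscillator). They are rough guides only, since how the library actually integrates these terms is an assumption here.

```python
import math

def flow_parameter_check(domain_length, n_cells, dt, viscosity, inner_steps, outer_steps,
                         stiffness, total_mass, n_markers):
    """Derived quantities for a quick stability/scale sanity check (rough guide only)."""
    dx = domain_length / n_cells
    frame_dt = inner_steps * dt
    # ASSUMPTION: each mass marker behaves like an undamped mass-spring pair (mass M/N, constant Kp).
    omega = math.sqrt(stiffness / (total_mass / n_markers))
    return {
        "dx": dx,
        "diffusion_number": viscosity * dt / dx**2,  # <= 0.25 if viscosity is stepped explicitly in 2-D
        "frame_dt": frame_dt,                        # simulated time between saved frames
        "total_time": outer_steps * frame_dt,        # simulated time at the end of the rollout
        "omega_dt": omega * dt,                      # symplectic Euler needs < 2; smaller is safer
    }

# Values copied from the cells in this notebook (150 markers are set in the shape function).
checks = flow_parameter_check(domain_length=15.0, n_cells=600, dt=5e-4, viscosity=0.05,
                              inner_steps=1000, outer_steps=20,
                              stiffness=5e4, total_mass=1.0, n_markers=150)
for name, value in checks.items():
    print(f"{name:>16}: {value:.4g}")
```

With these settings ω·dt is above 1, so the penalty spring is far less comfortably resolved than viscous diffusion. That is one concrete reason `stiffness` is flagged above as a parameter to tune.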

## Immersed Boundary Setup

This cell sets up the **Lagrangian** part of the simulation: the deformable particle.

**This is the biggest conceptual change from the rigid body model.** Instead of defining mathematical functions for motion (`kinematics`), we now define the particle's initial physical state: its shape, position, mass, stiffness, etc. This state is then stored in the `particle` object, which will be updated by the physics solver at every time step.

In [None]:
#-- Immersed objects --

# --- 1. Define a function to generate the particle's shape ---
# This function returns the coordinates of an ellipse whose points are evenly spaced
# in the parametric angle (not in arc length).
def foil_XY_ELLIPSE(geometry_param, theta_dummy):
    A = geometry_param[0]  # Semi-major axis
    B = geometry_param[1]  # Semi-minor axis
    ntheta = 150 # Number of Lagrangian marker points on the boundary (keep in sync with Grid_p below)

    angle = jnp.linspace(0, 2 * jnp.pi, ntheta, endpoint=False)
    x = A * jnp.cos(angle)
    y = B * jnp.sin(angle)
    return x, y

# --- 2. NEW INITIALIZATION LOGIC FOR THE DEFORMABLE PARTICLE ---
particle_geometry_param = jnp.array([[0.5, 0.1]]) # Semi-axes [A, B] of the ellipse
particle_center_position = jnp.array([[domain[0][1]*0.75, domain[1][1]/2],]) # Initial [x, y] center
Shape_fn = foil_XY_ELLIPSE

# Generate the initial shape coordinates in the particle's local reference frame (centered at origin)
Grid_p = pc.Grid1d(shape=150, domain=(0, 2*jnp.pi)) # A dummy grid for the shape function
xp0_body, yp0_body = Shape_fn(particle_geometry_param[0], Grid_p.mesh())

# Calculate the initial positions of the marker points in the global simulation frame.
# In the penalty method, the fluid-interacting markers (X) and mass-carrying markers (Y)
# start at the same location.
initial_xp = xp0_body + particle_center_position[0, 0]
initial_yp = yp0_body + particle_center_position[0, 1]

# The particle starts from rest, so initial velocities for the mass markers are zero.
initial_Vm_x = jnp.zeros_like(initial_xp)
initial_Vm_y = jnp.zeros_like(initial_xp)

# Calculate the mass of each individual marker point.
num_markers = len(initial_xp)
mass_per_marker = total_mass / num_markers

# --- 3. Create the final particle object with its full dynamic state ---
# This object holds all the stateful information that will be updated by the solver.
deformable_particle = pc.particle(
    xp=initial_xp, yp=initial_yp, # Current fluid-interacting marker positions (X in paper)
    Ym_x=initial_xp.copy(), Ym_y=initial_yp.copy(), # Mass-carrying marker positions (Y in paper)
    Vm_x=initial_Vm_x, Vm_y=initial_Vm_y, # Mass-carrying marker velocities
    mass_per_marker=mass_per_marker,
    stiffness=stiffness, # The penalty spring constant, Kp
    sigma=sigma, # The surface tension coefficient
    particle_center=particle_center_position, # Static initial center
    geometry_param=particle_geometry_param, # Static shape parameters
    Grid=Grid_p, # Static grid object
    shape=Shape_fn # Static shape function
)

# This container holds the list of all particles in the simulation (in this case, just one).
# It is a JAX PyTree, which is crucial for JIT compilation.
particles_container = pc.particle_lista(particles=[deformable_particle])
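Because the markers are evenly spaced in angle rather than in arc length, their spacing varies around the ellipse and is largest near the ends of the minor axis. A common immersed-boundary rule of thumb is to keep neighbouring markers no farther apart than about one grid cell (often closer to half a cell) so that fluid cannot leak between them. The NumPy sketch below checks this for an ellipse with the semi-axes in `particle_geometry_param`. The helper `max_marker_spacing` is a hypothetical name, not part of `jax_ib`.

```python
import numpy as np

def max_marker_spacing(A, B, n):
    """Largest distance between neighbouring markers on the closed parametric ellipse."""
    angle = np.linspace(0, 2 * np.pi, n, endpoint=False)
    x, y = A * np.cos(angle), B * np.sin(angle)
    # np.roll pairs the last marker with the first, closing the curve.
    return np.max(np.hypot(np.roll(x, -1) - x, np.roll(y, -1) - y))

dx_grid = 15.0 / 600                        # Eulerian grid spacing used in this notebook
ds_max = max_marker_spacing(0.5, 0.1, 150)  # semi-axes and marker count from the cell above
print(f"max marker spacing = {ds_max:.4f}, grid spacing = {dx_grid:.4f}, ratio = {ds_max / dx_grid:.2f}")
```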

## Simulation Setup

This cell assembles the complete time-stepping function. It takes all the individual physics components we've defined (fluid convection, IBM force, particle motion) and combines them into a single `step_fn` that can advance the entire simulation state by one time step.

In [None]:
#-- Setup Pytree Variable to be looped over during integration steps --

Intermediate_calcs = [0] # Placeholder for any post-processing calculations
Step_counter = 0
MD_state = [0] # Placeholder for Brownian motion state (not used here)

# The `All_Variables` object is the master PyTree that contains the ENTIRE simulation state.
# This is the object that will be passed into and out of the main JAX loop.
all_variables = pc.All_Variables(particles_container, v0, pressure0, Intermediate_calcs, Step_counter, MD_state)

#-- IB force coupling functions for the Deformable Model --
def internal_post_processing(all_variables, dt):
    # This function can be used for any calculations you want to perform inside the time-step loop.
    return all_variables

# This is the discrete delta function kernel used for all spreading/interpolation operations.
discrete_delta = lambda dist, center, width: convolution_functions.delta_approx_logistjax(dist, center, width)

# --- NEW, UPDATED IBM FORCING AND PARTICLE UPDATE FUNCTIONS ---
# These lambda functions adapt our specific physics modules to the generic API required by the main solver.

# The IBM forcing function now calculates the physical penalty/tension force from the particle's state.
# The `v` argument passed by the solver is the entire `all_variables` pytree.
IBM_forcing = lambda v, dt: IBM_Force.calc_IBM_force_NEW_MULTIPLE(v, discrete_delta, dt)

# The particle update function now calls the new stepper for a massive (mass-carrying), deformable particle.
# It solves the equations of motion for the particle markers.
Update_position = lambda v, dt: particle_motion.update_massive_deformable_particle(v, dt, gravity_g=gravity)

# Note: We don't use a general forcing (e.g. gravity) term for the fluid here.

# --- The main solver call now uses the new dynamic functions ---
# This assembles all the components into a single function that performs `inner_steps` of the simulation.
step_fn = cfd.funcutils.repeated(
    ib.equations.semi_implicit_navier_stokes_timeBC(
        density=density,
        viscosity=viscosity,
        dt=dt,
        grid=grid,
        convect=convect,
        pressure_solve= ib.pressure.solve_fast_diag,
        forcing=None, # No general forcing on the fluid
        time_stepper= ib.time_stepping.forward_euler_updated,
        IBM_forcing = IBM_forcing,          # Using the new penalty/tension force function
        Updating_Position = Update_position,  # Using the new deformable update function
        Drag_fn = internal_post_processing,
        ),
    steps=inner_steps)

# --- EXECUTE THE SIMULATION USING THE ROLLOUT FUNCTION ---
# The rollout function compiles the entire loop into a single, highly efficient operation using jax.lax.scan.
# This is often the most performant method, but does not allow for a progress bar.
rollout_fn = cfd.funcutils.trajectory(
        step_fn, outer_steps, start_with_input=True)

# Run the simulation and move the results from the device (GPU/TPU) back to the host (CPU) memory.
final_result, trajectory = jax.device_get(rollout_fn(all_variables))

print("Simulation complete.")
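It is easy to miss how `repeated` and `trajectory` combine. Each saved frame is `inner_steps` solver steps apart. With `start_with_input=True`, frame `i` holds the state *before* the `i`-th block of steps, i.e. at time `i * inner_steps * dt` (the formula the animation titles below use), while `final_result` holds the state after all `outer_steps` blocks. The plain-Python functions below sketch those semantics only. jax_cfd implements them with JAX loop primitives (`jax.lax.scan` for the trajectory), and the names `repeated_sketch` and `trajectory_sketch` are hypothetical.

```python
def repeated_sketch(f, steps):
    """Compose f with itself `steps` times."""
    def f_repeated(x):
        for _ in range(steps):
            x = f(x)
        return x
    return f_repeated

def trajectory_sketch(step_fn, steps, start_with_input=False):
    """Return a function mapping x0 -> (final_state, list_of_saved_frames)."""
    def run(x):
        frames = []
        for _ in range(steps):
            x_next = step_fn(x)
            # Save either this block's input or its output, as jax_cfd's trajectory does.
            frames.append(x if start_with_input else x_next)
            x = x_next
        return x, frames
    return run

# Toy solver whose state is just the number of inner steps taken so far.
inner, outer = 1000, 4
block = repeated_sketch(lambda n: n + 1, inner)
final_n, saved = trajectory_sketch(block, outer, start_with_input=True)(0)
print(saved)    # [0, 1000, 2000, 3000]
print(final_n)  # 4000
```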

## Visualization

The following cells visualize the simulation results in several ways, each giving a different perspective on the dynamics of the deformable body. The rendering cells include progress bars because they can take a while to run.

*   **Static Vorticity Plots:** A series of snapshot images of the vorticity field, each panned automatically to keep the particle centered.
*   **Vorticity Animation (Tracking Camera):** A full MP4 video of the vorticity field with a camera that follows the particle.
*   **Vorticity Animation (Intelligent Framing):** A second vorticity video that uses a single fixed window containing the entire motion, so the particle appears to move across the screen.
*   **Velocity Animation (Bonus):** An animation of the velocity field (speed and direction), again using the tracking camera.

### Static Vorticity Plots with Tracking

This cell generates a series of individual snapshot images, one for each specified frame of the simulation. Each subplot shows the vorticity field and the particle's deformed shape at that instant.

The view for each subplot automatically **pans to keep the moving particle centered in the frame**. This is achieved by recalculating the `xlim` and `ylim` for each iteration of the loop based on the particle's current center of mass.
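The same camera logic appears in several cells below, so a minimal sketch of it as a reusable helper follows. `tracking_limits` is a hypothetical name. Using `np.isfinite` (which rejects both NaN and Inf) is a deliberate choice, because Matplotlib refuses non-finite axis limits.

```python
import numpy as np

def tracking_limits(mass_x, mass_y, width, height):
    """Axis limits for a window of the given size centred on the mass markers' mean position.

    Returns None when the centre is NaN or Inf (e.g. empty arrays or a diverged run),
    so the caller can skip that frame.
    """
    cx, cy = np.mean(mass_x), np.mean(mass_y)
    if not (np.isfinite(cx) and np.isfinite(cy)):
        return None
    return (cx - width / 2, cx + width / 2), (cy - height / 2, cy + height / 2)

print(tracking_limits([1.0, 3.0], [2.0, 2.0], 3.5, 3.5))
```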

In [None]:
import matplotlib.pyplot as plt
import seaborn
import numpy as np
# NEW: Import the correct progress bar utility
from tqdm.notebook import tqdm

# Grid mesh for plotting
X,Y = grid.mesh(grid.cell_center)

def calc_vorticity(trajectory, idx):
    """Calculates the vorticity at a specific time step from the trajectory."""
    # --- FIX 1: Create a simple periodic BC for post-processing ---
    # We no longer need the old, non-existent variables like vx_bc or bc_fns.
    vel_bc = (boundaries.periodic_boundary_conditions(ndim=2),
              boundaries.periodic_boundary_conditions(ndim=2))

    # Reconstruct the GridVariable for this specific time step's velocity data
    velocity = (grids.GridVariable(grids.GridArray(trajectory.velocity[0].data[idx], trajectory.velocity[0].offset, grid), vel_bc[0]),
                grids.GridVariable(grids.GridArray(trajectory.velocity[1].data[idx], trajectory.velocity[1].offset, grid), vel_bc[1]))

    # Calculate vorticity using finite differences
    return finite_differences.central_difference(velocity[1], axis=0).data - finite_differences.central_difference(velocity[0], axis=1).data

# --- DEFINE THE FRAMES TO PLOT ---
# This is the single source of truth for the plotting logic.
frames_to_plot = np.arange(0, outer_steps, 2)
num_plots = len(frames_to_plot)
# --- END ---

# --- NEW: TOGGLE FOR MASS MARKER VISIBILITY ---
# Set this to True to see the mass markers or False to hide them.
show_mass_markers = True
# --- END ---

# --- Main Static Plotting Loop ---
# The number of rows is now directly derived from the number of frames we want to plot.
fig,ax = plt.subplots(figsize=(12, 6 * num_plots), nrows=num_plots)

# Ensure 'ax' is an array even if nrows is 1 to avoid indexing errors.
if num_plots == 1:
    ax = [ax]

# This controls the size of the "camera" window around the particle.
plot_window_width = 3.5
plot_window_height = 3.5

# --- Use `tqdm` to wrap the iterable for a progress bar ---
for counter, idx in enumerate(tqdm(frames_to_plot, desc="Generating Static Plots")):
    # Get the particle's full trajectory; each field is indexed by the frame 'idx' below.
    particle_state_at_idx = trajectory.particles.particles[0]

    # Calculate the center of the particle at the current time step.
    center_x = np.mean(particle_state_at_idx.Ym_x[idx])
    center_y = np.mean(particle_state_at_idx.Ym_y[idx])

    # --- ADDED FIX: SKIP FRAMES WITH MISSING OR NON-FINITE PARTICLE DATA (NaN/Inf WOULD CRASH set_xlim) ---
    if not (np.isfinite(center_x) and np.isfinite(center_y)):
        print(f"Warning: Skipping frame {idx} due to missing or non-finite particle data.")
        continue

    # Set the x and y limits based on the particle's center and the desired window size.
    ax[counter].set_xlim([center_x - plot_window_width / 2, center_x + plot_window_width / 2])
    ax[counter].set_ylim([center_y - plot_window_height / 2, center_y + plot_window_height / 2])

    # Plot contours of the vorticity field computed by calc_vorticity().
    ax[counter].contour(X, Y, calc_vorticity(trajectory, idx), cmap=seaborn.color_palette("vlag", as_cmap=True), levels=np.linspace(-10, 10, 10))

    # Get the current, deformed coordinates of the fluid-interacting markers (X markers).
    xp = particle_state_at_idx.xp[idx]
    yp = particle_state_at_idx.yp[idx]

    # Plot the deformed particle shape (the fluid boundary).
    ax[counter].plot(xp, yp, 'k-', linewidth=2.0, label='Fluid Markers (xp, yp)')

    # --- PLOT MASS MARKERS IF TOGGLE IS TRUE ---
    if show_mass_markers:
        Ym_x = particle_state_at_idx.Ym_x[idx]
        Ym_y = particle_state_at_idx.Ym_y[idx]
        ax[counter].plot(Ym_x, Ym_y, 'r--', linewidth=1.0, label='Mass Markers (Ym)')
        ax[counter].legend()
    # --- END ---

    ax[counter].set_aspect('equal', adjustable='box')

plt.tight_layout()
plt.show()
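`calc_vorticity` applies a central difference to each staggered velocity component, so ∂v/∂x and ∂u/∂y are evaluated at points half a cell apart before being subtracted. That is usually harmless for plots. On a MAC-staggered grid, forward differences place both derivatives on the same cell corner instead. The NumPy sketch below shows that alternative on a periodic grid. The layout it assumes (axis 0 is x, `u` on x-faces at offset (1, 0.5), `v` on y-faces at offset (0.5, 1)) matches the usual jax_cfd convention but should be checked against `grid.cell_faces`, and `corner_vorticity` is a hypothetical helper.

```python
import numpy as np

def corner_vorticity(u, v, dx, dy):
    """omega = dv/dx - du/dy on cell corners, for periodic MAC-staggered u and v.

    Assumed layout (axis 0 = x, axis 1 = y):
      u[i, j] at ((i + 1) dx, (j + 0.5) dy)   (x-faces)
      v[i, j] at ((i + 0.5) dx, (j + 1) dy)   (y-faces)
    Both forward differences below are centred on the corner ((i + 1) dx, (j + 1) dy).
    """
    dv_dx = (np.roll(v, -1, axis=0) - v) / dx  # np.roll(v, -1, axis=0)[i] == v[i + 1], wrapping periodically
    du_dy = (np.roll(u, -1, axis=1) - u) / dy
    return dv_dx - du_dy

# Check on a periodic test field: u = sin(y), v = sin(x)  =>  omega = cos(x) - cos(y).
n = 64
h = 2 * np.pi / n
k = np.arange(n)
xu, yu = np.meshgrid((k + 1) * h, (k + 0.5) * h, indexing="ij")  # u locations
xv, yv = np.meshgrid((k + 0.5) * h, (k + 1) * h, indexing="ij")  # v locations
xc, yc = np.meshgrid((k + 1) * h, (k + 1) * h, indexing="ij")    # cell corners
omega = corner_vorticity(np.sin(yu), np.sin(xv), h, h)
err = np.max(np.abs(omega - (np.cos(xc) - np.cos(yc))))
print(f"max error = {err:.1e}")  # second-order: shrinks roughly 4x when n doubles
```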

### Vorticity Animation with Tracking Camera

This cell creates a full MP4 video and an interactive HTML animation of the simulation, visualizing the **vorticity field**.

Similar to the static plots above, this animation uses a **"tracking camera"** that follows the particle as it moves and deforms. The particle will appear relatively stationary in the center of the frame, while the surrounding fluid appears to flow past it. This is useful for closely observing the particle's deformation and the vortex shedding at its boundary.

In [None]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
import seaborn
import numpy as np
# NEW: Import the progress bar utility for notebooks
from tqdm.notebook import tqdm

# --- 1. SETUP THE FIGURE AND HELPER FUNCTIONS ---
# Create the figure and axis object that will be used as the canvas for our animation.
fig, ax = plt.subplots(figsize=(10, 8))
# Set a clean white background for both the figure and the plotting area.
fig.patch.set_facecolor('white')
ax.set_facecolor('white')

# Create a mesh of coordinates for the fluid grid, which is needed for contour plots.
X, Y = grid.mesh(grid.cell_center)

# This helper function calculates the vorticity field for a single frame of the simulation.
def calc_vorticity_for_frame(trajectory, frame_index):
    """
    Calculates the vorticity at a specific time step from the trajectory.
    NOTE: The trajectory object stores raw data arrays. To perform calculations
    that respect boundary conditions, we must first reconstruct the full
    GridVariable objects.
    """
    # Reconstruct the simple periodic BC object for post-processing.
    vel_bc = (boundaries.periodic_boundary_conditions(ndim=2),
              boundaries.periodic_boundary_conditions(ndim=2))

    # Re-wrap the raw velocity data from the specified trajectory frame into GridVariable objects.
    velocity_at_t = (
        grids.GridVariable(grids.GridArray(trajectory.velocity[0].data[frame_index], trajectory.velocity[0].offset, grid), vel_bc[0]),
        grids.GridVariable(grids.GridArray(trajectory.velocity[1].data[frame_index], trajectory.velocity[1].offset, grid), vel_bc[1])
    )

    # Calculate vorticity using the formula ω = dv/dx - du/dy, approximated with finite differences.
    vorticity = (finite_differences.central_difference(velocity_at_t[1], axis=0).data -
                 finite_differences.central_difference(velocity_at_t[0], axis=1).data)
    return vorticity

# --- 2. DEFINE THE ANIMATION FUNCTION (CALLED FOR EACH FRAME) ---

def animate(i):
    """
    This function is the core of the animation. Matplotlib's FuncAnimation will
    call this function repeatedly, once for each frame, with `i` as the frame index.
    """
    ax.clear() # Clear the contents of the previous frame.
    ax.set_facecolor('white')

    # Get the particle's full trajectory; its fields are indexed by the frame 'i' below.
    particle_at_i = trajectory.particles.particles[0]

    # --- ADDED FIX: HANDLE FRAMES WITH MISSING DATA ---
    # Attempt to calculate the particle's center.
    center_x = np.mean(particle_at_i.Ym_x[i])
    center_y = np.mean(particle_at_i.Ym_y[i])

    # Check for NaN or Inf (e.g. empty coordinate arrays or a diverged run); either would crash set_xlim.
    if not (np.isfinite(center_x) and np.isfinite(center_y)):
        print(f"Warning: Skipping animation frame {i} due to missing or non-finite particle data.")
        # Return early, leaving a blank frame. This prevents the animation from crashing.
        return
    # --- END FIX ---

    # If data is valid, proceed with plotting.
    vorticity = calc_vorticity_for_frame(trajectory, i)

    # Plot the fluid vorticity as a contour plot.
    ax.contour(X, Y, vorticity, cmap=seaborn.color_palette("vlag", as_cmap=True), levels=np.linspace(-10, 10, 20))

    # Get particle shape coordinates.
    xp = particle_at_i.xp[i]
    yp = particle_at_i.yp[i]

    # Dynamically set plot limits to track the particle.
    ax.set_xlim([center_x - plot_window_width / 2, center_x + plot_window_width / 2])
    ax.set_ylim([center_y - plot_window_height / 2, center_y + plot_window_height / 2])

    # Draw the opaque filled shape of the particle (fluid-marker boundary).
    ax.fill(xp, yp, 'steelblue', alpha=1.0, zorder=2)
    # Draw the outline of the particle for a clean look.
    ax.plot(xp, yp, color='lightsteelblue', linewidth=2.5, zorder=3)

    # Optionally plot the internal mass markers.
    if show_mass_markers:
        Ym_x = particle_at_i.Ym_x[i]
        Ym_y = particle_at_i.Ym_y[i]
        ax.plot(Ym_x, Ym_y, 'r--', linewidth=1.5, zorder=4, label='Mass Markers')
        # ax.clear() removes the legend on every frame, so redraw it each time.
        ax.legend()

    # Set titles and aspect ratio for a clean plot.
    current_t = i * inner_steps * dt
    ax.set_title(f"Deformable Particle Simulation | Time: {current_t:.4f} s")
    ax.set_aspect('equal', adjustable='box')


# --- 3. CONFIGURE AND CREATE THE ANIMATION ---
print("Creating animation... This may take a moment.")

# --- CONTROLS ---
animation_interval = 200      # milliseconds between frames in the interactive display.
video_fps = 10                # frames per second for the saved MP4 file.
plot_window_width = 3.5       # The width of the plotting window around the particle.
plot_window_height = 3.5      # The height of the plotting window around the particle.
show_mass_markers = True      # Set to True to see the internal mass skeleton, or False to hide it.

# Get the total number of frames from the trajectory data.
num_frames = trajectory.particles.particles[0].xp.shape[0]

# Create the main animation object.
anim = animation.FuncAnimation(fig, animate, frames=num_frames, interval=animation_interval, blit=False)

# --- NEW: ADD PROGRESS BAR TO THE SAVING PROCESS ---
# Create a tqdm progress bar instance.
with tqdm(total=num_frames, desc="Rendering Animation") as pbar:
    # Define the callback function that matplotlib will call for each frame.
    def progress_update(current_frame, total_frames):
        pbar.update(1) # Increment the progress bar by one step.

    # Set up the writer and save the animation to an MP4 file.
    # Crucially, we pass our new function to the `progress_callback` argument.
    writer = animation.FFMpegWriter(fps=video_fps)
    anim.save('deformable_particle.mp4', writer=writer, progress_callback=progress_update)
# --- END NEW ---

print("Animation saved as 'deformable_particle.mp4'")

# --- 4. DISPLAY THE ANIMATION IN THE NOTEBOOK ---
plt.close() # Close the figure so its last frame is not also shown as a static image below the animation.
HTML(anim.to_jshtml()) # Convert the animation to JS/HTML for embedding in the notebook output.
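Matplotlib calls `progress_callback(current_frame, total_frames)` once for every frame it writes, with `total_frames` set to `None` when the frame count cannot be determined; the `tqdm` bar above relies on exactly this. The self-contained sketch below records those calls for a tiny five-frame animation. It writes a GIF with `PillowWriter` into a temporary directory so that it does not need FFmpeg, and the figure, file, and names (`demo_fig`, `record_progress`) are throwaway illustrations.

```python
import os
import tempfile

import matplotlib.pyplot as plt
from matplotlib import animation

demo_fig, demo_ax = plt.subplots()
(line,) = demo_ax.plot([], [])
demo_ax.set_xlim(0, 1)
demo_ax.set_ylim(0, 5)

def draw(i):
    line.set_data([0, 1], [0, i])  # a line whose slope grows with the frame index
    return (line,)

demo = animation.FuncAnimation(demo_fig, draw, frames=5, blit=False)

calls = []
def record_progress(current_frame, total_frames):
    calls.append((current_frame, total_frames))

out_path = os.path.join(tempfile.mkdtemp(), "progress_demo.gif")
demo.save(out_path, writer=animation.PillowWriter(fps=5), progress_callback=record_progress)
plt.close(demo_fig)
print(calls)
```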

### Vorticity Animation with Intelligent Framing

This cell creates another vorticity animation, but uses a different and more robust visualization technique: **"intelligent framing"**. It first analyzes the particle's entire path to determine its total range of motion, then creates a single **fixed window** that is guaranteed to contain all of it.

Instead of panning, the **particle moves across a static set of axes**, rather than the axes moving with the particle. This gives a better sense of the particle's overall travel through the domain.
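Because the trajectory is produced by a JAX scan, every frame's marker arrays have the same length, so the history can be treated as a single `(frames, markers)` array and filtered with one boolean mask. A compact sketch of the bounds calculation under that assumption follows. `fixed_frame_limits` is a hypothetical helper, and the padding and fallback values are the caller's choice.

```python
import numpy as np

def fixed_frame_limits(hist_x, hist_y, padding, fallback):
    """Padded (xlim, ylim) covering every finite marker position in a (frames, markers) history."""
    x = np.asarray(hist_x)[np.isfinite(hist_x)]  # drop NaN/Inf entries, flattening the history
    y = np.asarray(hist_y)[np.isfinite(hist_y)]
    if x.size == 0 or y.size == 0:
        return fallback  # nothing usable, e.g. the run diverged from the start
    return [x.min() - padding, x.max() + padding], [y.min() - padding, y.max() + padding]

xl, yl = fixed_frame_limits([[1.0, 2.0], [np.nan, 4.0]], [[0.0, 1.0], [np.inf, 2.0]],
                            padding=1.5, fallback=([0.0, 15.0], [0.0, 15.0]))
print(xl, yl)
```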

In [None]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
import seaborn
import numpy as np
from tqdm.notebook import tqdm

# --- 1. SETUP THE FIGURE AND HELPER FUNCTIONS ---
fig, ax = plt.subplots(figsize=(10, 8))
fig.patch.set_facecolor('white')
ax.set_facecolor('white')

X, Y = grid.mesh(grid.cell_center)

def calc_vorticity_for_frame(trajectory, frame_index):
    """Calculates the vorticity at a specific time step from the trajectory."""
    vel_bc = (boundaries.periodic_boundary_conditions(ndim=2),
              boundaries.periodic_boundary_conditions(ndim=2))
    velocity_at_t = (
        grids.GridVariable(grids.GridArray(trajectory.velocity[0].data[frame_index], trajectory.velocity[0].offset, grid), vel_bc[0]),
        grids.GridVariable(grids.GridArray(trajectory.velocity[1].data[frame_index], trajectory.velocity[1].offset, grid), vel_bc[1])
    )
    vorticity = (finite_differences.central_difference(velocity_at_t[1], axis=0).data -
                 finite_differences.central_difference(velocity_at_t[0], axis=1).data)
    return vorticity

# --- 2. ROBUSTLY DETERMINE PLOT LIMITS FROM TRAJECTORY ---

# --- DEFINITIVE FIX: CLEAN DATA BEFORE CALCULATING BOUNDS ---
try:
    # Get the full history of the particle's mass marker positions.
    all_Ym_x = trajectory.particles.particles[0].Ym_x
    all_Ym_y = trajectory.particles.particles[0].Ym_y

    # First, create a single, flat array of all coordinates, skipping empty frames.
    all_valid_x = np.concatenate([arr for arr in all_Ym_x if arr.size > 0])
    all_valid_y = np.concatenate([arr for arr in all_Ym_y if arr.size > 0])

    # NEXT, CRUCIALLY, FILTER OUT ANY NaN OR Inf VALUES.
    all_finite_x = all_valid_x[np.isfinite(all_valid_x)]
    all_finite_y = all_valid_y[np.isfinite(all_valid_y)]

    # If no finite data points exist after cleaning, raise an error to use the fallback.
    if all_finite_x.size == 0 or all_finite_y.size == 0:
        raise ValueError("No finite particle data found in the entire trajectory.")

    # Now it is completely safe to find the min and max coordinates.
    min_x, max_x = np.min(all_finite_x), np.max(all_finite_x)
    min_y, max_y = np.min(all_finite_y), np.max(all_finite_y)

    # Define a padding factor to add some space around the particle's path.
    padding = 1.5

    # Set the final, fixed plot limits.
    x_lims = [min_x - padding, max_x + padding]
    y_lims = [min_y - padding, max_y + padding]

except (ValueError, IndexError) as e:
    # This block runs if the data is empty, contains only NaNs, or has other issues.
    print(f"Warning: Could not determine animation bounds automatically ({e}). Using default limits.")
    x_lims = list(domain[0]) # Fall back to the full simulation domain.
    y_lims = list(domain[1])
# --- END FIX ---


# --- 3. DEFINE THE ANIMATION FUNCTION (STILL CONTAINS A FIX) ---

def animate(i):
    """
    This function draws a single frame of the animation.
    """
    ax.clear()
    ax.set_facecolor('white')
    vorticity = calc_vorticity_for_frame(trajectory, i)

    # Plot the fluid vorticity as a contour plot.
    ax.contour(X, Y, vorticity, cmap=seaborn.color_palette("vlag", as_cmap=True), levels=np.linspace(-10, 10, 20))

    # Get the particle's full trajectory; its fields are indexed by the frame 'i' below.
    particle_at_i = trajectory.particles.particles[0]

    # This per-frame fix is still necessary to prevent drawing errors.
    xp = particle_at_i.xp[i]
    yp = particle_at_i.yp[i]

    # Only attempt to draw the particle if its coordinate arrays are not empty.
    if xp.size > 0 and yp.size > 0:
        ax.fill(xp, yp, 'steelblue', alpha=1.0, zorder=2)
        ax.plot(xp, yp, color='lightsteelblue', linewidth=2.5, zorder=3)

    if show_mass_markers:
        Ym_x = particle_at_i.Ym_x[i]
        Ym_y = particle_at_i.Ym_y[i]
        if Ym_x.size > 0 and Ym_y.size > 0:
            ax.plot(Ym_x, Ym_y, 'r--', linewidth=1.5, zorder=4, label='Mass Markers')
            # ax.clear() removes the legend on every frame, so redraw it each time.
            ax.legend()

    current_t = i * inner_steps * dt
    ax.set_title(f"Deformable Particle Simulation | Time: {current_t:.4f} s")

    # Use our safely calculated (or default) fixed axis limits for every frame.
    ax.set_xlim(x_lims)
    ax.set_ylim(y_lims)
    ax.set_aspect('equal', adjustable='box')


# --- 4. CONFIGURE AND CREATE THE ANIMATION ---
print("Creating animation... This may take a moment.")

# --- CONTROLS ---
animation_interval = 200
video_fps = 10
show_mass_markers = True

num_frames = trajectory.particles.particles[0].xp.shape[0]
anim = animation.FuncAnimation(fig, animate, frames=num_frames, interval=animation_interval, blit=False)

with tqdm(total=num_frames, desc="Rendering Animation") as pbar:
    def progress_update(current_frame, total_frames):
        pbar.update(1)
    writer = animation.FFMpegWriter(fps=video_fps)
    anim.save('deformable_particle_fixed_frame.mp4', writer=writer, progress_callback=progress_update)

print("Animation saved as 'deformable_particle_fixed_frame.mp4'")

# --- 5. DISPLAY THE ANIMATION IN THE NOTEBOOK ---
plt.close()
HTML(anim.to_jshtml())

### Bonus: Velocity Field Animation with Tracking Camera

This cell creates an alternative animation showing the **velocity field** (fluid speed as a colormap and flow direction as quiver arrows) instead of vorticity. This gives a complementary view of the flow dynamics.

This version uses the **"tracking camera"** visualization technique. The axes limits are updated at every frame to keep the particle centered, providing a close-up view of the fluid velocity patterns immediately surrounding the moving and deforming body.
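On the uniform staggered grid used here, linear interpolation of a face-centred component to the cell centre amounts to averaging the two faces on either side of the cell. The NumPy sketch below spells that out for a periodic grid and computes the speed shown in the colormap. It assumes the same face layout as jax_cfd's default (axis 0 is x; `u` on the right x-face and `v` on the top y-face of each cell) and uses the hypothetical name `velocity_at_centers`, so it is an illustration rather than a replacement for `interpolation.linear`.

```python
import numpy as np

def velocity_at_centers(u, v):
    """Average periodic MAC face velocities onto cell centres and return (u_c, v_c, speed).

    Assumes u[i, j] lies on the right face of cell (i, j) and v[i, j] on its top face,
    so the opposite face of the same cell is index i - 1 (for u) or j - 1 (for v).
    """
    u_c = 0.5 * (u + np.roll(u, 1, axis=0))  # np.roll(u, 1, axis=0)[i] == u[i - 1]
    v_c = 0.5 * (v + np.roll(v, 1, axis=1))
    return u_c, v_c, np.sqrt(u_c**2 + v_c**2)

# A uniform flow must come out unchanged, with unit speed everywhere.
u_c, v_c, speed = velocity_at_centers(np.full((8, 8), 0.6), np.full((8, 8), 0.8))
print(speed.min(), speed.max())
```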

In [None]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
import seaborn
import numpy as np
from jax_ib.base import interpolation # Need interpolation for this plot
# NEW: Import the progress bar utility
from tqdm.notebook import tqdm

# --- 1. SETUP THE FIGURE AND HELPER FUNCTIONS ---
# Create the figure and axis object that will be used for the animation.
fig, ax = plt.subplots(figsize=(12, 8))
fig.patch.set_facecolor('white')
ax.set_facecolor('white')

# Create a mesh of coordinates for the fluid grid centers.
X, Y = grid.mesh(grid.cell_center)

# This helper function is needed because velocity is defined on a staggered grid
# (at the faces of cells), but for a quiver plot, we want all vectors to originate
# from the cell centers. This function performs that interpolation.
def get_velocity_at_centers(trajectory, frame_index):
    """Reconstructs and interpolates the velocity field to cell centers for a given frame."""
    # Reconstruct the simple periodic BC object for post-processing.
    vel_bc = (boundaries.periodic_boundary_conditions(ndim=2),
              boundaries.periodic_boundary_conditions(ndim=2))

    # Re-wrap the raw, staggered velocity data from the trajectory into GridVariable objects.
    u_staggered = grids.GridVariable(grids.GridArray(trajectory.velocity[0].data[frame_index], trajectory.velocity[0].offset, grid), vel_bc[0])
    v_staggered = grids.GridVariable(grids.GridArray(trajectory.velocity[1].data[frame_index], trajectory.velocity[1].offset, grid), vel_bc[1])

    # Perform linear interpolation to move the u and v components to the cell centers.
    u_centered = interpolation.linear(u_staggered, grid.cell_center)
    v_centered = interpolation.linear(v_staggered, grid.cell_center)

    return u_centered.data, v_centered.data

# --- 2. DEFINE THE ANIMATION FUNCTION (CALLED FOR EACH FRAME) ---

def animate_velocity(i):
    """
    This function is called for each frame of the animation to draw the velocity field.
    """
    ax.clear()
    ax.set_facecolor('white')

    # Get the particle's full trajectory; its fields are indexed by the frame 'i' below.
    particle_at_i = trajectory.particles.particles[0]

    # --- ADDED FIX: HANDLE FRAMES WITH MISSING DATA ---
    # Attempt to calculate the particle's center to pan the camera.
    center_x = np.mean(particle_at_i.Ym_x[i])
    center_y = np.mean(particle_at_i.Ym_y[i])

    # Check if the calculation resulted in NaN (due to empty coordinate arrays).
    if np.isnan(center_x) or np.isnan(center_y):
        print(f"Warning: Skipping velocity animation frame {i} due to missing particle data.")
        # Return early, leaving a blank frame. This prevents the animation from crashing.
        return
    # --- END FIX ---

    # If data is valid, proceed with plotting.
    u_data, v_data = get_velocity_at_centers(trajectory, i)
    speed = np.sqrt(u_data**2 + v_data**2)

    # Plot the fluid speed as a background color map (pcolormesh).
    ax.pcolormesh(X, Y, speed, cmap='Blues', shading='gouraud')

    # Overlay quiver arrows to show velocity direction.
    skip = 30
    ax.quiver(X[::skip, ::skip], Y[::skip, ::skip],
              u_data[::skip, ::skip], v_data[::skip, ::skip],
              color='black', scale=25)

    # Plot the particle's shape using the saved trajectory data.
    xp = particle_at_i.xp[i]
    yp = particle_at_i.yp[i]

    # Set the plot limits dynamically using the now-validated center coordinates.
    ax.set_xlim([center_x - plot_window_width / 2, center_x + plot_window_width / 2])
    ax.set_ylim([center_y - plot_window_height / 2, center_y + plot_window_height / 2])

    ax.fill(xp, yp, 'steelblue', alpha=1.0, zorder=2)
    ax.plot(xp, yp, color='lightsteelblue', linewidth=2.5, zorder=3)

    # Optionally plot the mass markers.
    if show_mass_markers:
        Ym_x = particle_at_i.Ym_x[i]
        Ym_y = particle_at_i.Ym_y[i]
        ax.plot(Ym_x, Ym_y, 'r--', linewidth=1.5, zorder=4, label='Mass Markers')
        # ax.clear() removes the legend on every frame, so redraw it each time.
        ax.legend()

    current_t = i * inner_steps * dt
    ax.set_title(f"Deformable Particle Simulation (Velocity Field) | Time: {current_t:.4f} s")
    ax.set_aspect('equal', adjustable='box')


# --- 3. CONFIGURE AND CREATE THE ANIMATION ---
print("Creating velocity animation...")

# --- CONTROLS ---
animation_interval = 200
video_fps = 10
plot_window_width = 5.0
plot_window_height = 4.0
show_mass_markers = False

num_frames = trajectory.particles.particles[0].xp.shape[0]
anim_vel = animation.FuncAnimation(fig, animate_velocity, frames=num_frames, interval=animation_interval, blit=False)

# --- SAVE THE MP4 WITH A PROGRESS BAR (FFMpegWriter requires ffmpeg) ---
with tqdm(total=num_frames, desc="Rendering Velocity Animation") as pbar:
    def progress_update(current_frame, total_frames):
        pbar.update(1)
    writer = animation.FFMpegWriter(fps=video_fps)
    anim_vel.save('deformable_particle_velocity.mp4', writer=writer, progress_callback=progress_update)

print("Animation saved as 'deformable_particle_velocity.mp4'")

# --- 4. DISPLAY THE ANIMATION IN THE NOTEBOOK ---
plt.close()
HTML(anim_vel.to_jshtml())
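For intuition about what `get_velocity_at_centers` does, here is a minimal NumPy sketch of linear interpolation from a staggered (MAC) grid to cell centers on a periodic domain. It assumes `'ij'` indexing (axis 0 is x), with `u` stored on the right x-faces (offset `(1.0, 0.5)`) and `v` on the top y-faces (offset `(0.5, 1.0)`); those offsets and the hypothetical name `staggered_to_centers` are illustrative assumptions, not the `jax_ib` API. Under those assumptions each center value is simply the average of the two faces on either side of it.

```python
import numpy as np

def staggered_to_centers(u, v):
    """Average periodic staggered face velocities onto cell centers.

    Assumption: 'ij' indexing (axis 0 = x, axis 1 = y), u stored at
    offset (1.0, 0.5) and v at offset (0.5, 1.0).
    """
    # Cell center i lies midway between x-faces i-1 and i, so average each
    # u value with its left neighbour; np.roll wraps around periodically.
    u_c = 0.5 * (u + np.roll(u, 1, axis=0))
    # Same idea along y: average each v value with the face below it.
    v_c = 0.5 * (v + np.roll(v, 1, axis=1))
    return u_c, v_c

# Usage: a uniform flow is unchanged by the averaging.
u = np.full((4, 3), 2.0)
v = np.full((4, 3), -1.0)
u_c, v_c = staggered_to_centers(u, v)
speed = np.sqrt(u_c**2 + v_c**2)  # same quantity the pcolormesh background shows
```

With periodic boundaries, `np.roll` supplies the wrap-around neighbour, which is why the averaging needs no special case at the domain edge.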

### Bonus: Velocity Animation with Intelligent Framing

This final cell creates another velocity animation, this time using **"intelligent framing"**. Instead of recentring the view on every frame, it first scans the particle's entire trajectory, finds the minimum and maximum finite mass-marker coordinates, and adds a fixed margin to define a single **fixed window** that encloses every recorded position. Frames with missing particle data still show the fluid; only the particle is skipped.

Because the axes no longer pan, the **particle moves across a static set of axes** instead of the axes moving with the particle. This gives a better sense of the particle's overall travel through the domain.
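The framing step can be sketched in plain NumPy. The hypothetical helper `fixed_axis_limits` below takes per-frame coordinate arrays, drops empty frames and NaN/Inf entries, and pads the overall min/max by an absolute margin. The default `padding=2.5` mirrors the cell below; the `fallback` box is an arbitrary placeholder, not something derived from the simulation.

```python
import numpy as np

def fixed_axis_limits(frames_x, frames_y, padding=2.5,
                      fallback=((0.0, 16.0), (0.0, 16.0))):
    """Return one (x_lims, y_lims) pair enclosing every finite marker position.

    frames_x / frames_y: sequences of per-frame coordinate arrays.
    padding: absolute margin in domain units, added on every side.
    fallback: placeholder limits used when no finite data remains.
    """
    # Flatten all non-empty frames into one array per axis.
    xs = [np.ravel(a) for a in frames_x if np.size(a) > 0]
    ys = [np.ravel(a) for a in frames_y if np.size(a) > 0]
    if not xs or not ys:
        return list(fallback[0]), list(fallback[1])
    x = np.concatenate(xs)
    y = np.concatenate(ys)
    # Drop NaN/Inf so min/max are always finite.
    x = x[np.isfinite(x)]
    y = y[np.isfinite(y)]
    if x.size == 0 or y.size == 0:
        return list(fallback[0]), list(fallback[1])
    return ([x.min() - padding, x.max() + padding],
            [y.min() - padding, y.max() + padding])

# Usage: one empty frame and some non-finite entries are ignored.
frames_x = [np.array([1.0, 2.0]), np.array([]), np.array([3.0, np.nan])]
frames_y = [np.array([5.0, 6.0]), np.array([]), np.array([np.inf, 7.0])]
x_lims, y_lims = fixed_axis_limits(frames_x, frames_y, padding=0.5)
# x_lims == [0.5, 3.5], y_lims == [4.5, 7.5]
```

As in the cell below, x and y are filtered independently, so a frame whose x coordinates are finite but whose y coordinates are not still contributes to the x range.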

In [None]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
import seaborn
import numpy as np
from jax_ib.base import boundaries, grids, interpolation  # Needed to rebuild and interpolate the velocity field
from tqdm.notebook import tqdm # Import the progress bar utility

# --- 1. SETUP THE FIGURE AND HELPER FUNCTIONS ---
fig, ax = plt.subplots(figsize=(12, 8))
fig.patch.set_facecolor('white')
ax.set_facecolor('white')

X, Y = grid.mesh(grid.cell_center)

def get_velocity_at_centers(trajectory, frame_index):
    """Reconstructs and interpolates the velocity field to cell centers for a given frame."""
    vel_bc = (boundaries.periodic_boundary_conditions(ndim=2),
              boundaries.periodic_boundary_conditions(ndim=2))
    u_staggered = grids.GridVariable(grids.GridArray(trajectory.velocity[0].data[frame_index], trajectory.velocity[0].offset, grid), vel_bc[0])
    v_staggered = grids.GridVariable(grids.GridArray(trajectory.velocity[1].data[frame_index], trajectory.velocity[1].offset, grid), vel_bc[1])
    u_centered = interpolation.linear(u_staggered, grid.cell_center)
    v_centered = interpolation.linear(v_staggered, grid.cell_center)
    return u_centered.data, v_centered.data

# --- 2. ROBUSTLY DETERMINE PLOT LIMITS FROM TRAJECTORY ---

# --- CLEAN THE DATA BEFORE CALCULATING BOUNDS ---
try:
    # Get the full history of the particle's mass marker positions.
    all_Ym_x = trajectory.particles.particles[0].Ym_x
    all_Ym_y = trajectory.particles.particles[0].Ym_y

    # Step 1: Create a single, flat array of all coordinates, skipping empty frames.
    # (np.concatenate raises ValueError if every frame is empty.)
    all_valid_x = np.concatenate([arr for arr in all_Ym_x if arr.size > 0])
    all_valid_y = np.concatenate([arr for arr in all_Ym_y if arr.size > 0])

    # Step 2: Filter out any non-finite values (NaN or Inf) from the flattened arrays.
    all_finite_x = all_valid_x[np.isfinite(all_valid_x)]
    all_finite_y = all_valid_y[np.isfinite(all_valid_y)]

    # Step 3: Check if any valid, finite data remains. If not, raise an error.
    if all_finite_x.size == 0 or all_finite_y.size == 0:
        raise ValueError("No finite particle data found in the entire trajectory.")

    # The min and max are now guaranteed to be finite.
    min_x, max_x = np.min(all_finite_x), np.max(all_finite_x)
    min_y, max_y = np.min(all_finite_y), np.max(all_finite_y)

    # Absolute margin (in domain length units) added on every side for a wider view.
    padding = 2.5

    # Set the final, fixed plot limits.
    x_lims = [min_x - padding, max_x + padding]
    y_lims = [min_y - padding, max_y + padding]

except (ValueError, IndexError) as e:
    # Runs if the trajectory is empty, contains only NaN/Inf values, or is otherwise malformed.
    print(f"Warning: Could not determine animation bounds automatically ({e}). Falling back to the full grid domain.")
    # grid.domain holds the (lower, upper) extent of each spatial dimension.
    x_lims = list(grid.domain[0])
    y_lims = list(grid.domain[1])
# --- END DATA CLEANING ---


# --- 3. DEFINE THE ANIMATION FUNCTION (CALLED FOR EACH FRAME) ---

def animate_velocity(i):
    """This function is called for each frame of the animation to draw the velocity field."""
    ax.clear()
    ax.set_facecolor('white')
    u_data, v_data = get_velocity_at_centers(trajectory, i)
    speed = np.sqrt(u_data**2 + v_data**2)

    # Plot the fluid speed as a background color map.
    ax.pcolormesh(X, Y, speed, cmap='Blues', shading='gouraud')

    # Overlay quiver arrows to show velocity direction.
    skip = 30
    ax.quiver(X[::skip, ::skip], Y[::skip, ::skip],
              u_data[::skip, ::skip], v_data[::skip, ::skip],
              color='black', scale=25)

    # Fetch the particle's saved history; each field is indexed by frame 'i' below.
    particle_at_i = trajectory.particles.particles[0]

    # --- PER-FRAME GUARD: SKIP DRAWING THE PARTICLE IF ITS DATA IS MISSING ---
    xp = particle_at_i.xp[i]
    yp = particle_at_i.yp[i]

    # Draw the particle only if its boundary coordinates are non-empty and finite.
    if xp.size > 0 and yp.size > 0 and np.all(np.isfinite(xp)) and np.all(np.isfinite(yp)):
        ax.fill(xp, yp, 'steelblue', alpha=1.0, zorder=2)
        ax.plot(xp, yp, color='lightsteelblue', linewidth=2.5, zorder=3)

    # Optionally plot the internal mass markers, also with a safety check.
    if show_mass_markers:
        Ym_x = particle_at_i.Ym_x[i]
        Ym_y = particle_at_i.Ym_y[i]
        if Ym_x.size > 0 and Ym_y.size > 0:
            ax.plot(Ym_x, Ym_y, 'r--', linewidth=1.5, zorder=4, label='Mass Markers')
            # ax.clear() removes the legend on every frame, so redraw it each time.
            ax.legend()
    # --- END GUARD ---

    current_t = i * inner_steps * dt
    ax.set_title(f"Deformable Particle Simulation (Velocity Field) | Time: {current_t:.4f} s")

    # Use our safely calculated (or default) fixed axis limits for every frame.
    ax.set_xlim(x_lims)
    ax.set_ylim(y_lims)
    ax.set_aspect('equal', adjustable='box')


# --- 4. CONFIGURE AND CREATE THE ANIMATION ---
print("Creating velocity animation...")

# --- CONTROLS ---
animation_interval = 200
video_fps = 10
show_mass_markers = False

num_frames = trajectory.particles.particles[0].xp.shape[0]
anim_vel = animation.FuncAnimation(fig, animate_velocity, frames=num_frames, interval=animation_interval, blit=False)

# Use a distinct filename so this video does not overwrite the camera-following one.
output_file = 'deformable_particle_velocity_fixed_frame.mp4'

with tqdm(total=num_frames, desc="Rendering Velocity Animation") as pbar:
    def progress_update(current_frame, total_frames):
        pbar.update(1)
    writer = animation.FFMpegWriter(fps=video_fps)
    anim_vel.save(output_file, writer=writer, progress_callback=progress_update)

print(f"Animation saved as '{output_file}'")

# --- 5. DISPLAY THE ANIMATION IN THE NOTEBOOK ---
plt.close()
HTML(anim_vel.to_jshtml())
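The two animation cells guard against unusable particle data in slightly different ways, and both convert a frame index to simulated time with `i * inner_steps * dt`. A minimal sketch of both checks as standalone helpers follows; the names `frame_is_drawable` and `frame_time` are hypothetical, and `frame_time` assumes frames were saved every `inner_steps` solver steps starting at t = 0.

```python
import numpy as np

def frame_is_drawable(x, y):
    """True if a frame has at least one marker and every coordinate is finite."""
    x = np.asarray(x)
    y = np.asarray(y)
    # Rejects empty arrays (missing data) as well as NaN/Inf (e.g. a blown-up step).
    return bool(x.size > 0 and y.size > 0
                and np.all(np.isfinite(x)) and np.all(np.isfinite(y)))

def frame_time(i, inner_steps, dt):
    """Simulated time of saved frame i.

    Assumption: one frame is stored every `inner_steps` solver steps of size
    `dt`, with frame 0 at t = 0.
    """
    return i * inner_steps * dt

print(frame_is_drawable([0.0, 1.0], [2.0, 3.0]))    # True
print(frame_is_drawable([], []))                    # False
print(frame_is_drawable([0.0, np.nan], [1.0, 1.0])) # False
```

Checking finiteness directly, rather than inferring it from a NaN centroid, also catches Inf coordinates, which `np.mean` would turn into an Inf rather than a NaN.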