# JAX Simulation

Pure-Python, per-particle loops are too slow for this SPH fluid simulation, so we use JAX (`jax.jit` to compile and `jax.vmap` to vectorize) to speed it up.

In [None]:
import jax.numpy as jnp
import jax
import math

def kernel_function_poly6(pts_a, pts_b, h):
    """
    Poly6 smoothing kernel W(r, h) evaluated between pairs of points.

    pts_a: [..., 2] array of particle positions
    pts_b: [..., 2] array of particle positions; it is subtracted from pts_a
        element-wise, so the two shapes must broadcast (e.g. [n, 2] with
        [n, 2] or with [1, 2])
    h: float, smoothing length (kernel support radius)
    Returns: array with the broadcast shape of (pts_a - pts_b) minus the last
        axis (e.g. [n, ]) holding the kernel values; 0 where |r| > h.

    Note: the normalisation constant 315 / (64 * pi * h**9) is the 3D poly6
    constant. A kernel normalised over the 2D plane would use
    4 / (pi * h**8) instead. Changing it rescales every density estimate and
    therefore the meaning of REST_DENSITY / PRESSURE_MULTIPLIER, so it is
    kept as-is here.
    """
    diffs = pts_a - pts_b
    r2 = jnp.sum(diffs**2, axis=-1)
    h2 = h**2

    # Outside the support the kernel is zero. Clamp r2 first so the cubic
    # below is evaluated on an in-range value for every element.
    outside = r2 > h2
    r2 = jnp.where(outside, 0.0, r2)

    norm = 315.0 / (64.0 * math.pi * h**9)
    kernel = norm * (h2 - r2)**3

    return jnp.where(outside, 0.0, kernel)
    
def kernel_function_gradient_spiky(pts_a, pts_b, h):
    """
    Gradient of the 2D spiky kernel with respect to pts_a.

    pts_a: [..., 2] array of particle positions
    pts_b: [..., 2] array of particle positions, broadcast against pts_a
    h: float, smoothing length (kernel support radius)
    Returns: [..., 2] array of gradient vectors
        -30 / (pi h^5) * (h - r)^2 * (pts_a - pts_b) / r,
        which is 0 where |r| > h and also 0 for coincident points.
    """
    offset = pts_a - pts_b
    dist_sq = jnp.sum(offset**2, axis=-1)
    outside = dist_sq > h**2

    # Out-of-support entries are treated as r = 0 so the arithmetic below
    # stays finite; their result is overwritten with 0 at the end.
    dist = jnp.sqrt(jnp.where(outside, 0.0, dist_sq))

    # Guard the 1/r term: a zero distance is replaced by 1. The matching
    # offset is the zero vector, so the resulting gradient is still zero.
    safe_dist = jnp.where(dist == 0.0, 1.0, dist)

    spiky_const = -30.0 / (math.pi * h**5)
    scale = spiky_const * (h - safe_dist)**2 / safe_dist

    gradient = scale[..., None] * offset
    return jnp.where(outside[..., None], 0.0, gradient)

def kernel_function_viscosity_laplacian(pts_a, pts_b, h):
    """
    Cone-shaped weighting kernel used for the viscosity term.

    Computes 3 / (pi * h^2) * (1 - r / h) for r = |pts_a - pts_b| <= h and 0
    outside. This is a linear "cone" over the disk of radius h whose integral
    over the plane is 1.

    pts_a: [..., 2] array of particle positions
    pts_b: [..., 2] array of particle positions, broadcast against pts_a
    h: float, smoothing length (kernel support radius)
    Returns: array with the broadcast shape of (pts_a - pts_b) minus the last
        axis (e.g. [n, ]) holding the kernel values.
    """
    diffs = pts_a - pts_b
    r2 = jnp.sum(diffs**2, axis=-1)

    # Outside the kernel radius the kernel is zero; clamp r2 so sqrt below
    # only sees in-range values.
    outside = r2 > h**2
    r2 = jnp.where(outside, 0.0, r2)

    r = jnp.sqrt(r2)

    # A cone with base radius h and peak height H has volume (1/3)*pi*h^2*H.
    # Requiring volume == 1 gives H = 3 / (pi * h^2); the cone falls off
    # linearly from H at r = 0 to 0 at r = h.
    peak_height = 3 / (math.pi * h ** 2)
    falloff = 1 - (r / h)

    cone = peak_height * falloff

    return jnp.where(outside, 0.0, cone)

# Quick visual sanity check of the three kernels: sample points along the
# x axis and evaluate each kernel against a particle at the origin.
num_samples = 1000
xs = jnp.linspace(-1.25, 1.25, num_samples)
# Build the [num_samples, 2] arrays with vectorised ops instead of Python
# list comprehensions over JAX scalars.
arr = jnp.stack([xs, jnp.zeros_like(xs)], axis=-1)   # points (x, 0)
test_point = jnp.zeros((num_samples, 2))             # the origin, repeated
h = 1.0
kernel_function_poly6_jit = jax.jit(kernel_function_poly6)
kernel_function_gradient_spiky_jit = jax.jit(kernel_function_gradient_spiky)
kernel_function_viscosity_laplacian_jit = jax.jit(kernel_function_viscosity_laplacian)

kernel_values = kernel_function_poly6_jit(arr, test_point, h)
print(kernel_values.shape)
kernel_gradient_values = kernel_function_gradient_spiky_jit(arr, test_point, h)
print(kernel_gradient_values.shape)
kernel_viscosity_laplacian_values = kernel_function_viscosity_laplacian_jit(arr, test_point, h)
print(kernel_viscosity_laplacian_values.shape)

import matplotlib.pyplot as plt

fig, axs = plt.subplots(1, 3, figsize=(12, 4))
axs[0].set_title("Kernel Density Function")
axs[0].plot(arr[:, 0], kernel_values, label="Kernel Function", color="blue")
axs[1].set_title("Kernel Gradient Function")
# All samples lie on the x axis, so only the x component of the gradient is non-zero.
axs[1].plot(arr[:, 0], kernel_gradient_values[:, 0], label="Kernel Gradient", color="orange")
axs[2].set_title("Kernel Viscosity Laplacian Function")
axs[2].plot(arr[:, 0], kernel_viscosity_laplacian_values, label="Kernel Viscosity Laplacian", color="green")
for axis in axs:
    axis.set_xlabel("x")
    axis.legend()

In [None]:
from functools import partial

def approximate_density(pts_a, pt_b, h):
    """
    SPH density estimate at a single location.

    pts_a: [n, 2] array of particle positions (the samples being smoothed)
    pt_b: [2] array, the location at which the density is evaluated
    h: float, smoothing length
    Returns: scalar density at pt_b, i.e. the sum over all rows of pts_a of
        the poly6 kernel between that particle and pt_b. No per-particle mass
        factor is applied (equivalent to unit mass).
    """
    # Add a leading axis so pt_b broadcasts against every row of pts_a.
    pts_b = jnp.expand_dims(pt_b, axis=0)
    kernel_values = kernel_function_poly6(pts_a, pts_b, h)  # [n, ]
    return jnp.sum(kernel_values, axis=-1)

def get_particle_pressure_force(pt, pts, densities, pressures, pressure, h):
    """
    Symmetrised SPH pressure force acting on one particle.

    Computes  f = -sum_j (p + p_j) / (2 * rho_j) * gradW_spiky(pt - pts_j, h),
    where p is this particle's pressure and p_j, rho_j are the neighbours'.

    pt: [2] array, position of the particle the force acts on
        (a [1, 2] array broadcasts the same way)
    pts: [n, 2] array of all particle positions
    densities: [n, ] array of particle densities
    pressures: [n, ] array of particle pressures
    pressure: float, pressure of the particle at pt
    h: float, smoothing length
    Returns: [2] array, the pressure force on the particle
    """
    grad_kernel = kernel_function_gradient_spiky(pt, pts, h)  # [n, 2]

    # Using the sum of both particles' pressures (not just one of them)
    # makes the pairwise contribution symmetric in the two pressures.
    pair_pressure = pressure + pressures  # [n, ]
    contributions = -grad_kernel * pair_pressure[:, None] / (2 * densities)[:, None]  # [n, 2]

    return jnp.sum(contributions, axis=0)

def get_particle_viscosity_force(pt, pts, velocities, velocity, densities, h):
    """
    SPH viscosity force acting on one particle.

    Computes  f = VISCOSITY * sum_j W_visc(pt - pts_j, h) * (v_j - v) / (2 * rho_j),
    where W_visc is kernel_function_viscosity_laplacian.

    pt: [2] array, position of the particle the force acts on
    pts: [n, 2] array of all particle positions
    velocities: [n, 2] array of all particle velocities
    velocity: [2] array, velocity of the particle at pt
    densities: [n, ] array of particle densities
    h: float, smoothing length
    Returns: [2] array, the viscosity force on the particle

    VISCOSITY is read from module scope. Under jax.jit, global values are
    captured when the function is traced, so reassigning VISCOSITY later may
    not affect an already-compiled version.
    """
    # Reference form:
    #   VISCOSITY * viscosity_force_mult * (neighbor.velocity - particle.velocity) / neighbor.density
    # NOTE(review): the code below divides by 2 * densities, which the
    # reference form does not; this halves the effective VISCOSITY. Kept to
    # preserve the current behaviour -- confirm whether it is intended.
    weights = kernel_function_viscosity_laplacian(pt, pts, h)  # [n, ] kernel weights
    viscosity_forces = (VISCOSITY * weights)[:, None] * (velocities - velocity) / (2 * densities)[:, None]  # [n, 2]

    return jnp.sum(viscosity_forces, axis=0)

    

def get_density_map(pts, h):
    """
    Sample the SPH density field on a regular grid spanning the domain.

    pts: [n, 2] array of particle positions
    h: float, smoothing length
    Returns: [200, 200] array of densities. Rows follow y and columns follow x
        (jnp.meshgrid default 'xy' indexing), matching imshow(origin='lower').
    """
    resolution = 200
    # The domain is square, so the same samples serve both axes.
    axis_samples = jnp.linspace(0, DOMAIN_SIZE, resolution)
    gx, gy = jnp.meshgrid(axis_samples, axis_samples)

    samples = jnp.stack([gx.ravel(), gy.ravel()], axis=-1)  # [resolution**2, 2]

    # Evaluate the density at every grid sample in one vectorised call.
    density_at = jax.vmap(approximate_density, in_axes=(None, 0, None))
    flat_density = density_at(pts, samples, h)

    return flat_density.reshape(resolution, resolution)

DOMAIN_SIZE = 60.0
KERNEL_RADIUS = 5.0

# Generate random points
import numpy as np
from tqdm import tqdm  # used for the progress bar of the simulation loop below
np.random.seed(42)
num_points = 100
points = np.random.rand(num_points, 2) * DOMAIN_SIZE
velocities = np.zeros_like(points)  # Initialize velocities to zero

approximate_density_jit = jax.jit(approximate_density)
get_density_map_jit = jax.jit(get_density_map)
density_map = get_density_map_jit(points, KERNEL_RADIUS)

PRESSURE_MULTIPLIER = 1000.0
REST_DENSITY = 0.05

# Density at every particle (each particle is evaluated against all of them).
density_function_vmap = jax.vmap(approximate_density_jit, in_axes=(None, 0, None))
densities = density_function_vmap(points, points, KERNEL_RADIUS)

# Pressure force on every particle. Each particle is mapped together with
# its OWN pressure (in_axes=0 for the `pressure` argument); passing a shared
# constant such as PRESSURE_MULTIPLIER would give the wrong pairwise term.
get_particle_pressure_force_jit = jax.jit(get_particle_pressure_force)
pressures = PRESSURE_MULTIPLIER * (densities - REST_DENSITY)
pressure_force_vmap = jax.vmap(get_particle_pressure_force_jit, in_axes=(0, None, None, None, 0, None))
pressure_forces = pressure_force_vmap(points, points, densities, pressures, pressures, KERNEL_RADIUS)

# Viscosity force on every particle, each paired with its own velocity.
# VISCOSITY must exist before get_particle_viscosity_force is first traced.
VISCOSITY = 0.1
get_particle_viscosity_force_jit = jax.jit(get_particle_viscosity_force)
viscosity_force_vmap = jax.vmap(get_particle_viscosity_force_jit, in_axes=(0, None, None, 0, None, None))
viscosity_forces = viscosity_force_vmap(points, points, velocities, velocities, densities, KERNEL_RADIUS)

pressure_np = np.asarray(pressure_forces)
viscosity_np = np.asarray(viscosity_forces)

fig, ax = plt.subplots(figsize=(8, 8))
ax.set_title("Density Map")
cbar = plt.colorbar(ax.imshow(density_map, extent=(0, DOMAIN_SIZE, 0, DOMAIN_SIZE), origin='lower', cmap='hot'))
cbar.set_label('Density')
ax.scatter(points[:, 0], points[:, 1], c='red', s=60, label='Particles')
ax.quiver(points[:, 0], points[:, 1], 0.001 * pressure_np[:, 0], 0.001 * pressure_np[:, 1], color='blue', scale=50, label='Pressure Forces')
ax.quiver(points[:, 0], points[:, 1], 0.001 * viscosity_np[:, 0], 0.001 * viscosity_np[:, 1], color='green', scale=50, label='Viscosity Forces')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.legend(loc='upper right')

In [None]:
# Simulation parameters read by `step` below.
PRESSURE_MULTIPLIER = 1000.0  # stiffness: pressure = PRESSURE_MULTIPLIER * (density - REST_DENSITY)
REST_DENSITY = 0.05           # density at which a particle's pressure is zero
GRAVITY = 1.0                 # magnitude of the constant -y term added to every particle's net force
BOUNDS_FACTOR = -0.6          # multiplies the wall-normal velocity when a particle is clamped to a wall (negative => reverses and damps it)
DELTA_TIME = 0.01             # integration time step
VISCOSITY = 0.1               # coefficient read by get_particle_viscosity_force

def step(positions, velocities):
    """
    Advance the SPH simulation by one time step of DELTA_TIME.

    positions: [n, 2] array of particle positions
    velocities: [n, 2] array of particle velocities
    Returns: (positions, velocities, pressure_forces, viscosity_forces), each
        [n, 2]. The forces are those computed from the input state and used
        for this update.

    The inputs are not modified; new arrays are returned.
    """
    # Densities and pressures at every particle.
    density_function_vmap = jax.vmap(approximate_density_jit, in_axes=(None, 0, None))
    densities = density_function_vmap(positions, positions, KERNEL_RADIUS)
    pressures = PRESSURE_MULTIPLIER * (densities - REST_DENSITY)

    # Pressure force: each particle is paired with its own pressure (axis 0).
    pressure_force_vmap = jax.vmap(get_particle_pressure_force_jit, in_axes=(0, None, None, None, 0, None))
    pressure_forces = pressure_force_vmap(positions, positions, densities, pressures, pressures, KERNEL_RADIUS)

    # Viscosity force: each particle is paired with its own velocity (axis 0).
    viscosity_force_vmap = jax.vmap(get_particle_viscosity_force_jit, in_axes=(0, None, None, 0, None, None))
    viscosity_forces = viscosity_force_vmap(positions, positions, velocities, velocities, densities, KERNEL_RADIUS)

    # A single [2] gravity vector broadcasts over all rows, so tracing does not
    # unroll a per-particle Python loop.
    gravity = jnp.array([0.0, -GRAVITY])
    # NOTE(review): gravity is added as a force and then divided by density, so
    # the resulting acceleration is -GRAVITY / density rather than a
    # density-independent -GRAVITY. Kept to preserve the current tuning --
    # confirm intent.
    net_forces = pressure_forces + viscosity_forces + gravity
    accelerations = net_forces / densities[:, None]

    # Semi-implicit Euler: the updated velocity drives the position update.
    velocities = velocities + accelerations * DELTA_TIME
    positions = positions + velocities * DELTA_TIME

    # Enforce bounds on both axes at once. Column 0 of `below`/`above` is the
    # left/right wall, column 1 the bottom/top wall. Masks are taken from the
    # unclamped positions. Assumes KERNEL_RADIUS < DOMAIN_SIZE / 2, so at most
    # one wall per axis can be hit.
    lower = KERNEL_RADIUS
    upper = DOMAIN_SIZE - KERNEL_RADIUS
    below = positions < lower
    above = positions > upper

    positions = jnp.where(below, lower, positions)
    positions = jnp.where(above, upper, positions)
    # Only the velocity component normal to the wall that was hit is scaled.
    velocities = jnp.where(below | above, BOUNDS_FACTOR * velocities, velocities)

    return positions, velocities, pressure_forces, viscosity_forces

import os

num_points = 100
positions = np.random.rand(num_points, 2) * DOMAIN_SIZE
velocities = jnp.zeros_like(positions)

NUM_STEPS = 3_000
FRAME_EVERY = 10       # save one frame every FRAME_EVERY steps
SPAWN_EVERY = 1_000    # inject new particles every SPAWN_EVERY steps (never at step 0)
NUM_SPAWNED = 100      # particles injected per spawn
OUTPUT_DIR = "output"
# savefig does not create directories, so make sure the frame folder exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

step_jit = jax.jit(step)
frame = 0
for time_step in tqdm(range(NUM_STEPS)):
    positions, velocities, pressure_forces, viscosity_forces = step_jit(positions, velocities)

    if time_step % FRAME_EVERY == 0:
        frame += 1
        density_map = get_density_map_jit(positions, KERNEL_RADIUS)

        # Copy the state to host memory once per frame instead of indexing
        # device arrays element by element inside the plotting code.
        pos_np = np.asarray(positions)
        vel_np = np.asarray(velocities)
        pressure_np = np.asarray(pressure_forces)
        viscosity_np = np.asarray(viscosity_forces)
        px, py = pos_np[:, 0], pos_np[:, 1]

        fig, ax = plt.subplots(2, 2, figsize=(8, 8))
        fig.suptitle(f"Time {time_step*DELTA_TIME:.2f}")

        for axis in ax.flat:
            axis.set_xlim(0, DOMAIN_SIZE)
            axis.set_ylim(0, DOMAIN_SIZE)
            axis.set_aspect('equal')

        # Top Left: Smoothed map
        ax[0, 0].set_title("Water Map")
        ax[0, 0].imshow(density_map, extent=(0, DOMAIN_SIZE, 0, DOMAIN_SIZE), vmin=0, vmax=REST_DENSITY, origin='lower', cmap='Blues')

        # Bottom Left: Density map
        ax[1, 0].set_title("Density Map")
        density_img = ax[1, 0].imshow(density_map, extent=(0, DOMAIN_SIZE, 0, DOMAIN_SIZE), vmin=0, vmax=REST_DENSITY*2, origin='lower', cmap='hot')
        fig.colorbar(density_img, ax=ax[1, 0])

        # Top Right: water droplets; Bottom Right: free body diagram outlines.
        ax[0, 1].set_title("Water Droplets")
        ax[1, 1].set_title("Free Body Diagram")
        for center in pos_np:
            ax[0, 1].add_patch(plt.Circle(center, 0.5, color='blue', fill=True))
            ax[1, 1].add_patch(plt.Circle(center, 0.5, color='black', fill=False))

        # One quiver call per vector field rather than one per particle.
        ax[0, 1].quiver(px, py, 0.1 * vel_np[:, 0], 0.1 * vel_np[:, 1], color='black', scale=50)
        ax[1, 1].quiver(px, py, 0.1 * pressure_np[:, 0], 0.1 * pressure_np[:, 1], color='red', scale=50)
        ax[1, 1].quiver(px, py, 0.1 * viscosity_np[:, 0], 0.1 * viscosity_np[:, 1], color='green', scale=50)
        ax[1, 1].quiver(px, py, np.zeros_like(px), np.full_like(py, -0.1 * GRAVITY), color='blue', scale=50)

        fig.tight_layout()
        fig.savefig(os.path.join(OUTPUT_DIR, f"frame_{frame:04d}.png"))
        plt.close(fig)

    # Inject new particles in the upper half of the domain. The guard is on
    # time_step itself so nothing is spawned at step 0.
    if time_step % SPAWN_EVERY == 0 and 0 < time_step < NUM_STEPS:
        new_particles = np.random.rand(NUM_SPAWNED, 2) * DOMAIN_SIZE
        new_particles[:, 1] = new_particles[:, 1] / 2 + (DOMAIN_SIZE / 2)

        positions = jnp.concatenate([positions, new_particles], axis=0)
        velocities = jnp.concatenate([velocities, jnp.zeros((NUM_SPAWNED, 2))], axis=0)


In [None]:
!ffmpeg -framerate 10 -y -i output/frame_%04d.png -c:v libx264 -profile:v high -crf 20 -pix_fmt yuv420p output/output.mp4

# Visualization with VisPy

In [None]:
from PyQt5 import QtWidgets
from vispy import scene
from vispy import app
import sys
import numpy as np

# Qt allows only one QApplication per process; reuse the existing one when
# this cell is executed again instead of constructing a second instance.
# NOTE: this rebinds the name `app` (imported from vispy just above) to the
# Qt application object.
app = QtWidgets.QApplication.instance()
if app is None:
    app = QtWidgets.QApplication(sys.argv)

In [None]:
import vispy


class Canvas(scene.SceneCanvas):
    """VisPy window that animates the JAX SPH simulation.

    Each timer tick advances the simulation by one `step_jit` call and
    redraws the particles as markers.
    """

    # Marker appearance shared by the initial draw and every update.
    _PARTICLE_MARKER_STYLE = dict(edge_color=None, face_color=(0, 0.8, 1, 1), size=5)

    def __init__(self, num_particles=250):
        """Create the window, seed the particles and start the animation timer.

        num_particles: number of particles placed uniformly at random in
            [0, DOMAIN_SIZE) x [0, DOMAIN_SIZE) (default 250).
        """
        scene.SceneCanvas.__init__(self, keys='interactive', size=(512, 512), show=True)
        self.unfreeze()  # Unfreeze to add new attributes

        self.positions = np.random.uniform(0, DOMAIN_SIZE, size=(num_particles, 2))
        self.velocities = np.zeros_like(self.positions)  # start at rest

        self.view = self.central_widget.add_view()
        self.points = scene.visuals.Markers()
        self._redraw_points()
        self.view.add(self.points)
        self.view.camera = 'panzoom'
        self.view.camera.set_range((0, DOMAIN_SIZE), (0, DOMAIN_SIZE))

        # Warm up the JIT compilation for this particle count so the first
        # timer tick does not stall; the result is intentionally discarded.
        step_jit(self.positions, self.velocities)

        self.timer = vispy.app.Timer('auto', connect=self.on_timer, start=True)

    def _redraw_points(self):
        """Push the current particle positions to the marker visual."""
        self.points.set_data(self.positions, **self._PARTICLE_MARKER_STYLE)

    def on_timer(self, event):
        """Advance the simulation by one step and repaint the scene."""
        self.positions, self.velocities, _, _ = step_jit(self.positions, self.velocities)
        # Convert the JAX result to a host NumPy array before handing it to VisPy.
        self.positions = np.array(self.positions)
        self._redraw_points()

        self.update()  # Repaint the scene

# Create the simulation window; its constructor also starts the animation timer.
canvas = Canvas()
# The canvas was constructed with show=True, so this call is likely redundant.
canvas.show()
# `app` is the QApplication created in the setup cell (that cell rebinds the
# name away from `vispy.app`); exec_() runs the Qt event loop, which drives
# the timer callbacks.
app.exec_()