TODO:

1. Parallelize Sim
2. Hook Sim up to a controller

In [21]:
import functools
from typing import List, Sequence, Tuple

import jax
import jax.numpy as jnp
import numpy as np

from drawing import draw_lines


def cross_product2d(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
  """Return the scalar 2-D cross product ``a_x * b_y - a_y * b_x``.

  The last axis of both inputs holds the (x, y) components; leading axes are
  broadcast against each other. Plain rank-1 vectors of shape ``(2,)`` are
  accepted as well as batched ``(..., n, 2)`` arrays.

  If either batched input has zero vectors along its second-to-last axis, an
  empty array of shape ``(*a.shape[:-2], 0)`` is returned instead of
  attempting to broadcast an empty operand against a non-empty one.
  """
  # Only rank >= 2 inputs have a "number of vectors" axis to inspect; indexing
  # shape[-2] on a rank-1 vector would raise IndexError.
  a_is_empty = a.ndim >= 2 and a.shape[-2] == 0
  b_is_empty = b.ndim >= 2 and b.shape[-2] == 0
  if a_is_empty or b_is_empty:
    return jnp.zeros((*a.shape[:-2], 0), dtype=a.dtype)
  return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]


@functools.partial(jax.jit, static_argnums=(0, 10))
def step_fn(num_segments, origins, velocities, angles, angular_velocities, connection_points1, connection_points2, connections, collision_segments, collision_points, batch_size, control):
    """Advance every simulation in the batch by one (unit) time step.

    Applies, in order: gravity, joint restoring forces, quadratic angular
    damping, control torques, ground contact with Coulomb friction and a
    positional correction for ground penetration, then integrates velocities
    and positions.

    Args:
        num_segments: Number of rigid segments. Static under jit; retained for
            interface compatibility (array shapes are used instead).
        origins: (batch, segments, 2) segment origin positions.
        velocities: (batch, segments, 2) linear velocities.
        angles: (batch, segments) segment orientations in radians.
        angular_velocities: (batch, segments) angular velocities.
        connection_points1: (connections, 2) joint offset from the first
            segment's origin, in that segment's unrotated frame.
        connection_points2: (connections, 2) same joint, relative to the
            second segment.
        connections: (connections, 2) integer indices of the joined segments.
        collision_segments: (points,) integer index of the segment owning each
            collision point.
        collision_points: (points, 2) collision point offsets from the owning
            segment's origin, in its unrotated frame.
        batch_size: Number of parallel simulations. Static under jit; retained
            for interface compatibility.
        control: Joint torque per connection, broadcastable to
            (batch, connections). ``-control`` is applied to the first segment
            of each connection and ``+control`` to the second.

    Returns:
        Tuple ``(origins, velocities, angles, angular_velocities)`` after the
        step, with the same shapes as the inputs.
    """
    masses = 1
    moments_of_inertia = 0.1
    connection_restoring_force_factor = 0.05
    connection_restoring_force_limit = 0.00005
    angular_damping_factor = 0.1
    ground_friction_coefficient = 0.4
    ground_height = -1.5
    gravity_per_step = 0.000001

    # gravity
    velocities = velocities.at[:, :, 1].add(-gravity_per_step)

    # Rotation matrices [[cos, -sin], [sin, cos]] with shape (batch, segments, 2, 2).
    cos_a = jnp.cos(angles)
    sin_a = jnp.sin(angles)
    rotation_matrices = jnp.stack(
        (jnp.stack((cos_a, -sin_a), axis=-1), jnp.stack((sin_a, cos_a), axis=-1)),
        axis=-2,
    )

    # World-frame joint offsets as seen from each of the two connected segments.
    connection_points1_rot = jnp.matmul(rotation_matrices[:, connections[:, 0]], connection_points1[:, :, None]).squeeze(-1)
    connection_points2_rot = jnp.matmul(rotation_matrices[:, connections[:, 1]], connection_points2[:, :, None]).squeeze(-1)
    connection_discrepancy = (
      (connection_points2_rot + origins[:, connections[:, 1]])
      - (connection_points1_rot + origins[:, connections[:, 0]])
    )

    # Connection restoring forces: magnitude grows with the square of the gap,
    # clipped per component so a badly separated joint cannot blow up the step.
    force_vectors1 = connection_restoring_force_factor * jnp.linalg.norm(connection_discrepancy, axis=-1, keepdims=True) * connection_discrepancy
    force_vectors1 = jnp.clip(force_vectors1, -connection_restoring_force_limit, connection_restoring_force_limit)
    force_vectors2 = -force_vectors1

    total_linear_forces = jnp.zeros_like(origins)
    total_linear_forces = total_linear_forces.at[:, connections[:, 0]].add(force_vectors1)
    total_linear_forces = total_linear_forces.at[:, connections[:, 1]].add(force_vectors2)

    total_angular_forces = jnp.zeros_like(angular_velocities)
    total_angular_forces = total_angular_forces.at[:, connections[:, 0]].add(cross_product2d(connection_points1_rot, force_vectors1))
    total_angular_forces = total_angular_forces.at[:, connections[:, 1]].add(cross_product2d(connection_points2_rot, force_vectors2))

    # angular damping (quadratic in angular velocity)
    total_angular_forces -= angular_damping_factor * jnp.abs(angular_velocities) * angular_velocities

    # Control torques act as equal and opposite torques on the two segments of
    # each connection. (Applying both signs to the same segment cancels out.)
    total_angular_forces = total_angular_forces.at[:, connections[:, 0]].add(-control)
    total_angular_forces = total_angular_forces.at[:, connections[:, 1]].add(control)

    # ---- Ground contact -------------------------------------------------
    # r: world-frame lever arm from segment origin to collision point, (batch, points, 2).
    r = jnp.matmul(rotation_matrices[:, collision_segments], collision_points[:, :, None]).squeeze(-1)
    rx = r[..., 0]
    ry = r[..., 1]

    # Derivation of the contact impulse f = (fx, fy) applied at the point.
    # An impulse changes the point velocity by
    #   dv = f / m + w x r,  with  w = (rx * fy - ry * fx) / I,  w x r = w * [-ry, rx]
    # and va = vo + w * [-ry, rx] is the current velocity of the contact point.
    #
    # Sticking contact (|fx| <= mu * |fy|): stop the point completely.
    #   -vax = fx / m - (rx * fy - ry * fx) / I * ry
    #   -vay = fy / m + (rx * fy - ry * fx) / I * rx
    #   => fx * rx = -fy * ry - m * (vay * ry - vax * rx)
    #   => fy = -(vay + ry * m * (vay * ry - vax * rx) / I) / (1 / m + (rx^2 + ry^2) / I)
    #
    # Sliding contact (|fx| > mu * |fy|): friction is capped, only vay is cancelled.
    #   fx = abs(mu * fy) * -sign(vax)
    #   -vay = fy * (1 / m + rx^2 / I + rx * ry * mu * sign(vax) / I)
    #   => fy = -vay / (1 / m + rx^2 / I + rx * ry * mu * sign(vax) / I)
    distance_below_ground = jnp.maximum(ground_height - (r + origins[:, collision_segments])[..., 1], 0)
    point_angular_velocity = angular_velocities[:, collision_segments]
    vax = velocities[:, collision_segments, 0] - point_angular_velocity * ry
    vay = velocities[:, collision_segments, 1] + point_angular_velocity * rx
    # A point collides only if it is below the ground and still moving downwards.
    collision_mask = (distance_below_ground > 0) & (vay < 0)

    # Sticking solution, computed for every point and masked afterwards so
    # shapes stay static under jit.
    m_vay_ry_vax_rx = masses * (vay * ry - vax * rx)
    fy = -(vay + ry * m_vay_ry_vax_rx / moments_of_inertia) / (1 / masses + (jnp.square(rx) + jnp.square(ry)) / moments_of_inertia)
    fx_rx = -fy * ry - m_vay_ry_vax_rx
    # Divide by a safe denominator: when rx == 0 the quotient is never selected
    # (fx_rx == 0 yields 0, otherwise the sliding branch is taken), but an
    # inf/NaN there would still poison gradients.
    safe_rx = jnp.where(rx == 0, 1.0, rx)
    fx = jnp.where(fx_rx == 0, 0.0, fx_rx / safe_rx)

    # Sliding solution. The coupling term uses the lever arm (rx * ry), as in
    # the derivation above, and the friction force is built from the sliding
    # normal force so that |fx| == mu * |fy| holds for the applied impulse.
    slide_direction = jnp.sign(vax)
    fy_friction_capped = -vay / (
        1 / masses
        + jnp.square(rx) / moments_of_inertia
        + rx * ry * ground_friction_coefficient * slide_direction / moments_of_inertia
    )
    fx_friction_capped = -slide_direction * ground_friction_coefficient * jnp.abs(fy_friction_capped)
    # Compared in the "times rx" domain to avoid dividing by rx.
    friction_capped = jnp.abs(fx_rx) > jnp.abs(rx * ground_friction_coefficient * fy)

    collision_forces = jnp.stack(
        (
            jnp.where(friction_capped, fx_friction_capped, fx),
            jnp.where(friction_capped, fy_friction_capped, fy),
        ),
        axis=-1,
    )
    # Select rather than multiply by the mask: inf/NaN values computed for
    # non-colliding points would otherwise survive as NaN (nan * 0 == nan).
    collision_forces = jnp.where(collision_mask[..., None], collision_forces, 0.0)
    ground_correction = jnp.where(collision_mask, distance_below_ground, 0.0)

    # Scatter-add per owning segment; several points on one segment accumulate.
    total_linear_forces = total_linear_forces.at[:, collision_segments].add(collision_forces)
    total_angular_forces = total_angular_forces.at[:, collision_segments].add(cross_product2d(r, collision_forces))

    # Push penetrating segments back up by the penetration depth of each
    # colliding point (depths of multiple points on one segment are summed).
    origins = origins.at[:, collision_segments, 1].add(ground_correction)

    # Integrate with a unit time step.
    velocities = velocities + total_linear_forces / masses
    angular_velocities = angular_velocities + total_angular_forces / moments_of_inertia
    origins = origins + velocities
    angles = angles + angular_velocities

    return origins, velocities, angles, angular_velocities


class Sim:
  class Data:
    def __init__(self) -> None:
      pass

    @property
    def pos(self) -> jnp.ndarray:
      pass

    @property
    def theta(self) -> jnp.ndarray:
      pass

    @property
    def v(self) -> jnp.ndarray:
      pass

    @property
    def w(self) -> jnp.ndarray:
      pass

    @property
    def rotation_matrices(self) -> jnp.ndarray:
      pass
    
  def __init__(
    self,
    positions: np.ndarray,
    segments: List[Tuple[int, int, int]],  # center of mass, render point 1, render point 2
    connections: List[Tuple[int, int, int]],
    collision_points: List[Tuple[int, int]],  # segment, point
    batch_size: int = 1,
    device: str = 'cuda' if jax.devices()[0].platform == 'gpu' else 'cpu'
  ):
    self.device = device
    self.batch_size = batch_size
    
    positions = jnp.array(positions, dtype=jnp.float32)
    origins = jnp.stack([positions[origin] for origin, _, __ in segments])
    self.origins = jnp.tile(origins[None, :, :], (batch_size, 1, 1))
    self.velocities = jnp.zeros((batch_size, len(segments), 2))
    self.angles = jnp.zeros((batch_size, len(segments)))
    self.angular_velocities = jnp.zeros((batch_size, len(segments)))
    
    connections = jnp.array(connections, dtype=jnp.int32)
    self.connections = connections[:, :2]
    self.connection_points1 = positions[connections[:, 2]] - origins[connections[:, 0]]
    self.connection_points2 = positions[connections[:, 2]] - origins[connections[:, 1]]
    
    self.render_points1 = positions[jnp.array([pos1 for _, pos1, __ in segments])] - origins
    self.render_points2 = positions[jnp.array([pos2 for _, __, pos2 in segments])] - origins
    
    self.collision_segments = jnp.array([segment for segment, _ in collision_points], dtype=jnp.int32)
    self.collision_points = positions[jnp.array([point for _, point in collision_points])] - origins[self.collision_segments]

  def step(self, control: jax.Array | None = None):
    num_segments = self.origins.shape[1]
    if control is None:
      control = jnp.zeros(self.connections.shape[0])
    elif len(control.shape) == 2:
      control = control.reshape((self.batch_size, *control.shape))
    self.origins, self.velocities, self.angles, self.angular_velocities = step_fn(
        num_segments, self.origins, self.velocities, self.angles, self.angular_velocities,
        self.connection_points1, self.connection_points2, self.connections,
        self.collision_segments, self.collision_points, self.batch_size, control
    )

  def render(self, batch_idx: int, origin: Tuple[float, float], resolution: Tuple[int, int], pixels_per_unit: float) -> np.ndarray:
    rotation_matrices = self._rotation_matrices()[batch_idx]
    points1 = jnp.matmul(rotation_matrices, self.render_points1[:, :, None]).squeeze(-1) + self.origins[batch_idx]
    points2 = jnp.matmul(rotation_matrices, self.render_points2[:, :, None]).squeeze(-1) + self.origins[batch_idx]
    
    points1 = np.array(points1)
    points2 = np.array(points2)

    def cvt(points: np.ndarray):
      points = (points - origin) * pixels_per_unit
      points[:, 1] *= -1
      points = np.flip(points, axis=-1)
      points += np.array(resolution) / 2
      return points

    canvases = np.zeros((1, *resolution))
    draw_lines(canvases, cvt(np.array([[-5.0, -1.5]])), cvt(np.array([[5.0, -1.5]])), sample_count=20, width=0.1 * pixels_per_unit, width_pass_count=int(0.2 * pixels_per_unit))
    for start, end in zip(cvt(points1), cvt(points2)):
      draw_lines(canvases, start[None, :], end[None, :], sample_count=100, width=0.1 * pixels_per_unit, width_pass_count=int(0.2 * pixels_per_unit))
    return canvases[0]

  def _rotation_matrices(self) -> jnp.ndarray:
    rotation_matrices = jnp.zeros((self.batch_size, self.origins.shape[1], 2, 2))
    rotation_matrices = rotation_matrices.at[:, :, 0, 0].set(jnp.cos(self.angles))
    rotation_matrices = rotation_matrices.at[:, :, 1, 0].set(jnp.sin(self.angles))
    rotation_matrices = rotation_matrices.at[:, :, 0, 1].set(-jnp.sin(self.angles))
    rotation_matrices = rotation_matrices.at[:, :, 1, 1].set(jnp.cos(self.angles))
    return rotation_matrices

    
# Smoke test: two collinear unit-half-length segments lying on the x axis and
# joined at the shared point (0, 0).
sim = Sim(
  # Five named points along the x axis, referenced by index below.
  positions=np.array([[-2.0, 0.0], [-1.0, 0.0], [0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]),
  # (origin, render point 1, render point 2) per segment.
  segments=[(1, 0, 2), (3, 2, 4)],
  # Segment 0 and segment 1 are joined at point 2.
  connections=[(0, 1, 2)],
  # The outer end of each segment can hit the ground.
  collision_points=[(0, 0), (1, 4)]
)
# The joint offset as seen from each segment's origin.
print(sim.connection_points1, sim.connection_points2)

[[1. 0.]] [[-1.  0.]]


In [53]:
# Disabled experiment: displace the first origin before stepping.
# sim.origins[0] = np.array([-0.9, 1.0])
import matplotlib.pyplot as plt
# Show the state before the step is taken.
print('origins:', sim.origins)
print('angles:', sim.angles)

# Advance one step with no control input, then draw batch entry 0 on a
# 500x500 canvas centred on the world origin at 50 pixels per unit.
sim.step()
plt.imshow(sim.render(0, (0, 0), (500, 500), 50))

origins: [[[-1.  0.]
  [ 1.  0.]]]
angles: [[0. 0.]]


<matplotlib.image.AxesImage at 0x2c2ca103e00>

In [52]:
# Two-segment body with the segments initially counter-rotating.
sim = Sim(
  positions=np.array([[-2.0, 0.0], [-1.0, 0.0], [0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]),
  segments=[(1, 0, 2), (3, 2, 4)],
  connections=[(0, 1, 2)],
  collision_points=[(0, 0), (1, 4)]
)
# angular_velocities has shape (batch, segments), so the segment is selected on
# axis 1. Indexing with .at[0] / .at[1] addresses the batch axis instead: it
# sets both segments to 0.01 and the out-of-range .at[1] update is dropped.
sim.angular_velocities = sim.angular_velocities.at[:, 0].set(0.01)
sim.angular_velocities = sim.angular_velocities.at[:, 1].set(-0.01)
# sim.velocities[0] = np.array([0.01, 0.01])

In [None]:
%matplotlib qt

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

# Three-segment "arch": two vertical legs joined by a horizontal bar on top.
sim = Sim(
  positions=np.array([[-0.5, -1.0], [-0.5, 0.0], [-0.5, 1.0], [0.0, 1.0], [0.5, 1.0], [0.5, 0.0], [0.5, -1.0]]),
  segments=[(1, 0, 2), (3, 2, 4), (5, 4, 6)],
  connections=[(0, 1, 2), (1, 2, 4)],
  # Both feet and both ends of the top bar can touch the ground.
  collision_points=[(0, 0), (2, 6), (1, 2), (1, 4)],
  batch_size=4096*4,
  device='cuda'
)
# Give the top bar a small initial push in +x in every batch entry.
sim.velocities = sim.velocities.at[:, 1, 0].set(0.0001)
# Alternative two-segment setup, kept for reference. The assignments below use
# in-place item assignment and would have to be rewritten with `.at[...].set`.
# sim = Sim(
#   positions=np.array([[-2.0, 0.0], [-1.0, 0.0], [0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]),
#   segments=[(1, 0, 2), (3, 2, 4)],
#   connections=[(0, 1, 2)],
#   collision_points=[(0, 0), (1, 4)]
# )
# sim.angular_velocities[0] = 0.01
# sim.angular_velocities[1] = -0.01
# sim.velocities[0] = np.array([0.01, 0.01])

# Show the initial frame of batch entry 0; `im` is updated in place by update().
fig, ax = plt.subplots()
im = ax.imshow(sim.render(0, (0, 0), (500, 500), 50))
ax.axis('off')

# True while the space bar is held; toggled by the key handlers below.
space_pressed = False
interval_ms = 50  # Update interval in milliseconds

def on_key_press(event):
    """Key-press handler: mark the space bar as held down."""
    global space_pressed
    if event.key != ' ':
        return  # any other key is ignored
    space_pressed = True

def on_key_release(event):
    """Key-release handler: mark the space bar as no longer held."""
    global space_pressed
    if event.key != ' ':
        return  # any other key is ignored
    space_pressed = False

def update(frame):
    """Animation callback: while space is held, run 100 sim steps and redraw.

    `frame` is supplied by the animation driver and is not used.
    """
    if not space_pressed:
        return
    steps_per_frame = 100
    for _ in range(steps_per_frame):
        sim.step()
    # Re-render batch entry 0 into the existing image artist.
    frame_image = sim.render(0, (0, 0), (500, 500), 50)
    im.set_data(frame_image)
    fig.canvas.draw_idle()

# Track the space bar through matplotlib's key events.
fig.canvas.mpl_connect('key_press_event', on_key_press)
fig.canvas.mpl_connect('key_release_event', on_key_release)

# Call update() every interval_ms. The result is bound to a name so the
# animation object stays referenced while the window is open.
ani = FuncAnimation(fig, update, interval=interval_ms, blit=False, cache_frame_data=False)

plt.show()

In [172]:
# Inspect the segment origins; the output recorded below is all NaN, i.e. the
# simulation state had diverged by the time this cell ran.
print(sim.origins)

[[nan nan]
 [nan nan]
 [nan nan]]


In [8]:
# Take ten steps, printing the origins before each one.
for _ in range(10):
  # sim.velocities[:, :, 1] -= 0.0001
  # NOTE(review): step_fn already subtracts the same 0.000001 from the y
  # velocity on every step, so gravity is applied twice per iteration here --
  # confirm whether that is intended.
  sim.velocities = sim.velocities.at[:, :, 1].add(-0.000001)
  print(sim.origins)
  sim.step()
# Low-resolution preview of batch entry 0 (100x100 canvas, 10 pixels per unit).
plt.imshow(sim.render(0, (0, 0), (100, 100), 10))

[[[-0.4990816  -0.00504933]
  [ 0.00816317  0.9949486 ]
  [ 0.50091815 -0.00504933]]]
[[[-0.49906757 -0.00515028]
  [ 0.0082351   0.99484754]
  [ 0.50093216 -0.00515028]]]
[[[-0.49905375 -0.00525223]
  [ 0.00830742  0.99474543]
  [ 0.500946   -0.00525223]]]
[[[-0.49904013 -0.00535517]
  [ 0.00838017  0.9946423 ]
  [ 0.50095963 -0.00535517]]]
[[[-0.49902675 -0.00545911]
  [ 0.00845338  0.9945382 ]
  [ 0.50097305 -0.00545911]]]
[[[-0.4990136  -0.00556405]
  [ 0.00852709  0.99443305]
  [ 0.5009862  -0.00556405]]]
[[[-0.49900073 -0.00566998]
  [ 0.00860131  0.9943269 ]
  [ 0.5009991  -0.00566998]]]
[[[-0.49898812 -0.00577691]
  [ 0.00867608  0.9942197 ]
  [ 0.5010117  -0.00577691]]]
[[[-0.4989758  -0.00588483]
  [ 0.00875144  0.99411154]
  [ 0.50102407 -0.00588483]]]
[[[-0.4989638  -0.00599374]
  [ 0.00882741  0.99400234]
  [ 0.5010361  -0.00599374]]]


<matplotlib.image.AxesImage at 0x2c2b70ff590>

In [36]:
from typing import Callable, TypeVar
import ipywidgets as widgets
from PIL import Image
import io
import threading
import time

T = TypeVar('T')
def simple_viewer(reset: Callable[[], T], advance: Callable[[T], T], render: Callable[[T], np.ndarray], fps=10):
    """
    Build a small ipywidgets viewer with Reset / Play / Next controls.

    Args:
        reset: Creates and returns a fresh state. Called once immediately and
            again on every click of "Reset".
        advance: Takes the current state and returns the state to show next.
        render: Converts a state into an image array for display.
        fps: Intended playback rate. Currently NOT enforced: while playing,
            frames are produced as fast as `advance` and `render` allow.

    Returns:
        A `widgets.VBox` holding the buttons, a timestep label and the image.
        The container is not displayed by this function; make it the last
        expression of a cell (or pass it to `display`) to show it.
    """
    
    # Shared state: frame counter, "keep playing" flag, the user's state
    # object, and the background playback thread (if any).
    state = {'index': 0, 'playing': threading.Event(), 'state': reset(), 'thread': None}
    
    def to_png(img):
        # Encode an image array as PNG bytes for widgets.Image.
        img = np.asarray(img)
        if img.max() <= 1.0:
            # Treat as normalised intensities in [0, 1].
            img = img * 255
        # Always hand PIL 8-bit data: floating-point images (e.g. a float
        # canvas whose values exceed 1) cannot be written as PNG.
        img = np.clip(img, 0, 255).astype(np.uint8)
        pil_img = Image.fromarray(img)
        buf = io.BytesIO()
        pil_img.save(buf, format='PNG')
        return buf.getvalue()
    
    # Widgets
    img_widget = widgets.Image(value=to_png(render(state['state'])), format='png', width=600)
    label = widgets.Label(value='Timestep 0')
    reset_btn = widgets.Button(description='Reset', button_style='primary')
    play_btn = widgets.Button(description='▶ Play', button_style='success')
    next_btn = widgets.Button(description='Next', button_style='primary')

    def show_play_button(playing):
        # While playing, the button offers "Pause"; otherwise "Play".
        play_btn.description = '⏸ Pause' if playing else '▶ Play'
        play_btn.button_style = 'warning' if playing else 'success'

    def stop_playback():
        # Ask the worker to stop and wait for it, so that at most one thread
        # ever mutates `state` and a stale worker cannot be revived by a quick
        # pause/play sequence.
        state['playing'].clear()
        worker = state['thread']
        if worker is not None and worker is not threading.current_thread():
            worker.join()
        state['thread'] = None

    def refresh():
        img_widget.value = to_png(render(state['state']))
        label.value = f"Timestep {state['index']}"
    
    def on_reset(b):
        stop_playback()
        show_play_button(playing=False)
        state['state'] = reset()
        state['index'] = 0
        refresh()

    def update():
        state['state'] = advance(state['state'])
        refresh()
    
    def on_next(b):
        state['index'] += 1
        update()
    
    def on_play(b):
        if state['playing'].is_set():
            # Currently playing -> pause.
            stop_playback()
            show_play_button(playing=False)
        else:
            # Currently paused -> start a fresh worker thread.
            stop_playback()
            state['playing'].set()
            worker = threading.Thread(target=play_loop, daemon=True, args=(state['playing'],))
            state['thread'] = worker
            show_play_button(playing=True)
            worker.start()
    
    def play_loop(flag):
        # Runs on the worker thread until the flag is cleared. Frame pacing by
        # `fps` is intentionally disabled (the sleep was commented out).
        while flag.is_set():
            on_next(None)
    
    reset_btn.on_click(on_reset)
    next_btn.on_click(on_next)
    play_btn.on_click(on_play)
    
    container = widgets.VBox([
        widgets.HBox([reset_btn, play_btn, next_btn, label]),
        img_widget
    ])
    
    # display(container)
    return container

def create_sim():
  """Create the three-segment arch simulation with a small sideways push.

  NOTE(review): batch_size is 16386, whereas the earlier animation cell uses
  4096*4 (= 16384) -- confirm which value is intended.
  """
  points = np.array([[-0.5, -1.0], [-0.5, 0.0], [-0.5, 1.0], [0.0, 1.0], [0.5, 1.0], [0.5, 0.0], [0.5, -1.0]])
  # Left leg, top bar, right leg: (origin, render point 1, render point 2).
  segment_defs = [(1, 0, 2), (3, 2, 4), (5, 4, 6)]
  # The top bar is joined to each leg at the leg's upper end.
  joint_defs = [(0, 1, 2), (1, 2, 4)]
  # Both feet plus both ends of the top bar.
  contact_defs = [(0, 0), (2, 6), (1, 2), (1, 4)]
  arch = Sim(
    positions=points,
    segments=segment_defs,
    connections=joint_defs,
    collision_points=contact_defs,
    batch_size=16386,
    device='cuda'
  )
  # Initial +x velocity for the top bar in every batch entry.
  arch.velocities = arch.velocities.at[:, 1, 0].set(0.0001)
  return arch
def step_sim(sim: Sim):
    """Advance `sim` by 100 physics steps (mutating it) and return it."""
    steps_left = 100
    while steps_left > 0:
      sim.step()
      steps_left -= 1
    return sim

# Build the interactive viewer; as the last expression of the cell, the
# returned container is what the notebook displays.
simple_viewer(
  reset=create_sim,
  advance=step_sim,
  # Draw batch entry 0 on a 500x500 canvas at 50 pixels per unit.
  render=lambda sim: sim.render(0, (0, 0), (500, 500), 50)
)

VBox(children=(HBox(children=(Button(button_style='primary', description='Reset', style=ButtonStyle()), Button…