In [5]:
import jax
import jax.numpy as jnp
import tree_math

In [92]:
@tree_math.struct 
class HeapState:
    """Fixed-capacity, array-backed binary min-heap stored as a pytree.

    ``heap[:current_index]`` holds the heap in the usual implicit layout
    (children of slot i at 2i+1 and 2i+2). Unused slots are expected to hold
    ``inf`` -- ``create_heap`` fills the whole buffer with ``inf``.
    """
    heap: jnp.ndarray  # backing buffer; its length is the heap capacity
    current_index: int  # number of occupied slots == next free slot; presumably a JAX array after a jitted op (tree_math.struct is assumed to treat every field as a leaf)
    max_index: int  # capacity recorded by create_heap; informational -- the heap operations do not consult it


def create_heap(max_size=1000):
    """Build an empty min-heap able to hold ``max_size`` elements.

    Every slot of the backing buffer starts at ``inf`` so unused slots never
    compare smaller than a real element.
    """
    storage = jnp.full(max_size, jnp.inf)
    return HeapState(max_index=max_size, current_index=0, heap=storage)

@jax.jit
def parent(i):
    """Index of the parent of slot ``i`` in an implicit binary heap (-1 for the root)."""
    return (i + 1) // 2 - 1

@jax.jit
def left_child(i):
    """Index of the left child of slot ``i`` in an implicit binary heap."""
    return 1 + i * 2

@jax.jit
def right_child(i):
    """Index of the right child of slot ``i`` in an implicit binary heap."""
    return 2 * (i + 1)

@jax.jit
def swap(heap, i, j):
    """Return a copy of ``heap`` with the entries at ``i`` and ``j`` exchanged.

    Both values are read from the input before any write, so ``i == j`` is a no-op.
    """
    at_i, at_j = heap[i], heap[j]
    return heap.at[i].set(at_j).at[j].set(at_i)

@jax.jit
def insert(heap_state, value):
    """Push ``value`` onto the min-heap.

    Args:
      heap_state: ``HeapState`` whose first ``current_index`` slots form a
        valid min-heap.
      value: scalar to insert.

    Returns:
      A new ``HeapState`` with ``value`` appended and sifted up. ``max_index``
      is passed through unchanged. If the heap is already full
      (``current_index >= len(heap)``) the value is dropped and the state is
      returned unchanged -- an exception cannot be raised from inside ``jit``,
      and blindly incrementing the size would desynchronise it from the buffer.
    """
    heap = heap_state.heap
    size = heap_state.current_index
    has_room = size < heap.shape[0]  # capacity is static under jit

    def _push(h):
        # Append at the first free slot, then bubble it up while it is
        # strictly smaller than its parent.
        h = h.at[size].set(value)

        def cond_fun(carry):
            i, h = carry
            # At i == 0, parent(i) is -1; the `i > 0` term discards that read.
            return (i > 0) & (h[i] < h[parent(i)])

        def body_fun(carry):
            i, h = carry
            p = parent(i)
            return p, swap(h, i, p)

        _, h = jax.lax.while_loop(cond_fun, body_fun, (size, h))
        return h

    heap = jax.lax.cond(has_room, _push, lambda h: h, heap)
    new_size = jnp.where(has_room, size + 1, size)
    return HeapState(heap, new_size, heap_state.max_index)


@jax.jit
def extract_min(heap_state):
    """Pop the smallest element from the min-heap.

    Args:
      heap_state: ``HeapState`` whose first ``current_index`` slots form a
        valid min-heap.

    Returns:
      ``(new_state, min_val)`` where ``min_val`` is ``heap[0]`` before the pop.
      ``max_index`` is passed through unchanged. On an empty heap nothing is
      removed and ``current_index`` stays at 0 (it used to go negative, after
      which inserts wrote to the wrapped last slot); ``min_val`` is then
      whatever sits in slot 0 -- ``inf`` when unused slots are kept at ``inf``.
    """
    heap = heap_state.heap
    size = heap_state.current_index
    min_val = heap[0]

    def _sift_down(h, n):
        # Restore the heap property from the root within the first n slots.

        def smallest_of_family(i, h):
            left = left_child(i)
            right = right_child(i)
            # `&` does not short-circuit, so h[left]/h[right] may be read past
            # `n` (JAX clamps out-of-range reads); the `< n` test discards them.
            best = jnp.where((left < n) & (h[left] < h[i]), left, i)
            best = jnp.where((right < n) & (h[right] < h[best]), right, best)
            return best

        def cond_fun(carry):
            i, _ = carry
            return i < n

        def body_fun(carry):
            i, h = carry
            best = smallest_of_family(i, h)
            h = swap(h, i, best)  # no-op when best == i
            # Once the node is in place, jump to n to end the loop.
            return jnp.where(best == i, n, best), h

        _, h = jax.lax.while_loop(cond_fun, body_fun, (0, h))
        return h

    def _pop(h):
        # Move the last element to the root and reset its old slot to inf.
        last = size - 1
        h = h.at[0].set(h[last])
        h = h.at[last].set(jnp.inf)
        return _sift_down(h, last)

    heap = jax.lax.cond(size > 0, _pop, lambda h: h, heap)
    new_size = jnp.maximum(size - 1, 0)
    return HeapState(heap, new_size, heap_state.max_index), min_val

In [100]:
# Smoke test: push a handful of values, then pop the six smallest in order.
max_size = 100
heap = create_heap(max_size=max_size)

for value in (5.3, 3.1, 1.1, 10.2, 2.2, 100.2, 0.2, 0.1):
    heap = insert(heap, value)

popped = []
for _ in range(6):
    heap, min_val = extract_min(heap)
    print(min_val)
    popped.append(float(min_val))

# Min-heap invariant: values must come out in non-decreasing order.
assert popped == sorted(popped), popped

0.1
0.2
1.1
2.2
3.1
5.3


In [102]:
import jraph

In [None]:
@tree_math.struct
class NodeState:
    """Per-particle (node) features of the Hooke's-law graph; every field defaults to None."""
    mass : jnp.ndarray = None  # mass per particle; build_hookes_particle_state_graph draws shape (num_particles,)
    position : jnp.ndarray = None  # position per particle; build_hookes_particle_state_graph uses 2-d vectors
    momentum : jnp.ndarray = None  # momentum per particle; same shape as position in build_hookes_particle_state_graph
    kinetic_energy : jnp.ndarray = None  # left None by the builder; filled in by hookes_hamiltonian_from_graph_fn

@tree_math.struct
class EdgeState:
    """Per-spring (edge) features of the Hooke's-law graph; every field defaults to None."""
    spring_constant : jnp.ndarray = None  # spring stiffness per edge
    rest_length : jnp.ndarray = None  # NOTE(review): not set by the builder and not read by the Hamiltonian fn
    hookes_potential : jnp.ndarray = None  # filled in by hookes_hamiltonian_from_graph_fn


@tree_math.struct
class GlobalState:
    """Per-graph (global) features of the Hooke's-law graph."""
    hamiltonian : jnp.ndarray = None  # total energy per graph, set by hookes_hamiltonian_from_graph_fn


def hookes_hamiltonian_from_graph_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Computes the Hamiltonian of particles connected by springs.

    Runs one ``jraph.GraphNetwork`` pass that
      * stores ``0.5 * k * |x_sender - x_receiver|**2`` on every edge,
      * stores ``|p|**2 / (2 m)`` on every node,
      * sets the global ``hamiltonian`` to the per-graph sum of both.

    Args:
      graph: graph whose nodes are ``NodeState`` (``mass``, ``position``,
        ``momentum``) and whose edges are ``EdgeState`` (``spring_constant``).

    Returns:
      The graph produced by the GraphNetwork, with ``kinetic_energy``,
      ``hookes_potential`` and ``GlobalState.hamiltonian`` filled in.
    """

    def update_edge_fn(edges, senders, receivers, globals_):
        del globals_
        # Distance per edge: reduce over the coordinate axis only. Without
        # axis=-1 the norm would collapse all edges into a single scalar.
        distance = jnp.linalg.norm(senders.position - receivers.position, axis=-1)
        # NOTE(review): rest_length is carried through but not used, so this is
        # a zero-rest-length spring. When a graph contains both (i, j) and
        # (j, i) edges, each spring is counted twice in the global sum --
        # confirm which convention is intended.
        hookes_potential_per_edge = 0.5 * edges.spring_constant * distance ** 2
        return EdgeState(
            spring_constant=edges.spring_constant,
            rest_length=edges.rest_length,
            hookes_potential=hookes_potential_per_edge
        )

    def update_node_fn(nodes, sent_edges, received_edges, globals_):
        del sent_edges, received_edges, globals_
        # Momentum magnitude per node (again reduce only the coordinate axis).
        momentum_norm = jnp.linalg.norm(nodes.momentum, axis=-1)
        kinetic_energy_per_node = momentum_norm ** 2 / (2 * nodes.mass)
        return NodeState(
            mass=nodes.mass,
            position=nodes.position,
            momentum=nodes.momentum,
            kinetic_energy=kinetic_energy_per_node
        )

    def update_global_fn(nodes, edges, globals_):
        del globals_
        # At this point we will receive node and edge features aggregated (summed)
        # for all nodes and edges in each graph.
        hamiltonian_per_graph = nodes.kinetic_energy + edges.hookes_potential
        return GlobalState(hamiltonian=hamiltonian_per_graph)

    gn = jraph.GraphNetwork(
        update_edge_fn=update_edge_fn,
        update_node_fn=update_node_fn,
        update_global_fn=update_global_fn)

    return gn(graph)

In [None]:
import numpy as np
from typing import Tuple, Callable


def get_random_uniform_norm2d_vectors(
    min_norm: float, max_norm: float, num_particles: int,
    rng: "np.random.Generator | None" = None) -> np.ndarray:
  """Returns 2-d vectors with uniformly random norms and directions.

  Args:
    min_norm: lower bound of the uniform distribution for the norms.
    max_norm: upper bound of the uniform distribution for the norms.
    num_particles: number of vectors to draw.
    rng: optional ``np.random.Generator`` for reproducible draws. When omitted
      the legacy global ``np.random`` state is used, as before.

  Returns:
    Array of shape ``(num_particles, 2)``.
  """
  # np.random (module) and np.random.Generator share the uniform(low, high, size) API.
  sampler = np.random if rng is None else rng
  norm = sampler.uniform(min_norm, max_norm, [num_particles, 1])
  angle = sampler.uniform(0, 2 * np.pi, [num_particles])
  return norm * np.stack([np.cos(angle), np.sin(angle)], axis=-1)


def get_fully_connected_senders_and_receivers(
    num_particles: int, self_edges: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray]:
  """Returns senders and receivers for fully connected particles.

  Edges are ordered receiver-major: every edge into particle 0 first, then
  every edge into particle 1, and so on. Self-edges (i -> i) are dropped
  unless ``self_edges`` is True.
  """
  ids = np.arange(num_particles)
  count = ids.size
  # Row-major flattening of np.meshgrid(ids, ids): the sender index cycles fastest.
  all_senders = np.tile(ids, count)
  all_receivers = np.repeat(ids, count)
  if self_edges:
    return all_senders, all_receivers
  keep = all_senders != all_receivers
  return all_senders[keep], all_receivers[keep]



def build_hookes_particle_state_graph(num_particles: int) -> jraph.GraphsTuple:
  """Generates a graph representing a Hooke's system in a random state.

  Draws random masses, velocities, positions and a symmetric matrix of spring
  constants (in that order, from the global ``np.random`` state), then
  connects every particle to every other particle in both directions, without
  self-edges.

  Args:
    num_particles: number of particles (nodes).

  Returns:
    A single-graph ``jraph.GraphsTuple`` with ``NodeState`` nodes (mass,
    position, momentum) and ``EdgeState`` edges (spring_constant).
  """
  # NOTE(review): mass ~ U(0, 5) can be arbitrarily close to 0, which makes the
  # kinetic energy p**2 / (2 m) very large -- consider a positive lower bound.
  mass = np.random.uniform(0, 5, [num_particles])
  velocity = get_random_uniform_norm2d_vectors(0, 0.1, num_particles)
  position = get_random_uniform_norm2d_vectors(0, 1, num_particles)
  momentum = velocity * np.expand_dims(mass, axis=-1)
  # Remove average momentum, so center of mass does not move.
  momentum = momentum - momentum.mean(0, keepdims=True)

  # Generate a symmetric random matrix of spring constants: keep the lower
  # triangle (including the diagonal) and mirror its strictly-lower part.
  spring_matrix = np.random.uniform(
      1e-2, 1e-1, [num_particles, num_particles])
  spring_matrix = np.tril(spring_matrix) + np.tril(spring_matrix, -1).T

  # Connect all particles to all other particles (no self-interactions) and
  # look up the spring constant of each (receiver, sender) pair.
  senders, receivers = get_fully_connected_senders_and_receivers(num_particles)
  spring_constants = spring_matrix[receivers, senders]
  num_interactions = receivers.shape[0]

  return jraph.GraphsTuple(
      n_node=np.asarray([num_particles]),
      n_edge=np.asarray([num_interactions]),
      nodes=NodeState(
          mass=mass,  # Scalar mass for each particle.
          position=position,  # 2d position for each particle.
          momentum=momentum,  # 2d momentum for each particle.
      ),
      edges=EdgeState(
         spring_constant=spring_constants,
      ),
      globals={},
      senders=senders,
      receivers=receivers)