Skip to content

Commit

Permalink
Mock length method (GalacticDynamics#90)
Browse files Browse the repository at this point in the history
* Add len to PhaseSpacePosition
* cleanup mockstream

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jan 25, 2024
1 parent 88a12f6 commit ac6b6a9
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
6 changes: 6 additions & 0 deletions src/galax/dynamics/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ def qp(self) -> BatchVec6:
p = xp.broadcast_to(self.p, batch_shape + component_shapes[1:2])
return xp.concat((q, p), axis=-1)

# ==========================================================================

def __len__(self) -> int:
"""Return the number of particles."""
return self.shape[0]

# ==========================================================================
# Dynamical quantities

Expand Down
7 changes: 3 additions & 4 deletions src/galax/dynamics/mockstream/_mockstream_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ def _run_vmap( # TODO: output shape depends on the input shape
Better for GPU usage.
"""
qp0_lead = mock0_lead.qp
qp0_trail = mock0_trail.qp
t_f = ts[-1] + 0.01
t_f = ts[-1] + 0.01 # TODO: not have the bump in the final time.

# TODO: make this a separated method
@jax.jit # type: ignore[misc]
Expand All @@ -119,9 +117,10 @@ def single_particle_integrate(
).qp[-1]
return qp_lead, qp_trail

qp0_lead = mock0_lead.qp
particle_ids = xp.arange(len(qp0_lead))
integrator = jax.vmap(single_particle_integrate, in_axes=(0, 0, 0))
lead_arm_qp, trail_arm_qp = integrator(particle_ids, qp0_lead, qp0_trail)
lead_arm_qp, trail_arm_qp = integrator(particle_ids, qp0_lead, mock0_trail.qp)
return lead_arm_qp, trail_arm_qp

@partial_jit(static_argnames=("seed_num", "vmapped"))
Expand Down

0 comments on commit ac6b6a9

Please sign in to comment.