Skip to content

Commit

Permalink
Mockgen lead trail (GalacticDynamics#342)
Browse files Browse the repository at this point in the history
  • Loading branch information
nstarman committed Jun 11, 2024
1 parent a4ef997 commit a153b15
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 43 deletions.
43 changes: 16 additions & 27 deletions src/galax/dynamics/_dynamics/mockstream/df/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,12 @@
from galax.dynamics._dynamics.orbit import Orbit
from galax.potential import AbstractPotentialBase

Wif: TypeAlias = tuple[gt.LengthVec3, gt.LengthVec3, gt.SpeedVec3, gt.SpeedVec3]
Carry: TypeAlias = tuple[
int, PRNGKeyArray, gt.LengthVec3, gt.LengthVec3, gt.SpeedVec3, gt.SpeedVec3
]
Carry: TypeAlias = tuple[gt.LengthVec3, gt.SpeedVec3, gt.LengthVec3, gt.SpeedVec3]


class AbstractStreamDF(eqx.Module, strict=True): # type: ignore[call-arg, misc]
"""Abstract base class of Stream Distribution Functions."""

lead: bool = eqx.field(default=True, static=True)
trail: bool = eqx.field(default=True, static=True)

def __check_init__(self) -> None:
if not self.lead and not self.trail:
msg = "You must generate either leading or trailing tails (or both!)"
raise ValueError(msg)

@partial(jax.jit)
def sample(
self,
Expand Down Expand Up @@ -101,22 +90,22 @@ def sample(

# Scan over the release times to generate the stream particle initial
# conditions at each release time.
def scan_fn(carry: Carry, t: gt.FloatQScalar) -> tuple[Carry, Wif]:
i = carry[0]
rng, subrng = jr.split(carry[1], 2)
out = self._sample(subrng, pot, x[i], v[i], mprog(t), t)
return (i + 1, rng, *out), out
def scan_fn(_: Carry, inputs: tuple[int, PRNGKeyArray]) -> tuple[Carry, Carry]:
i, key = inputs
out = self._sample(key, pot, x[i], v[i], mprog(ts[i]), ts[i])
return out, out

# TODO: use ``jax.vmap`` instead of ``jax.lax.scan`` for GPU usage
# TODO: use ``jax.vmap`` instead of ``jax.lax.scan``?
init_carry = (
0,
rng,
xp.zeros_like(x[0]),
xp.zeros_like(x[0]),
xp.zeros_like(v[0]),
xp.zeros_like(x[0]),
xp.zeros_like(v[0]),
)
x_lead, x_trail, v_lead, v_trail = jax.lax.scan(scan_fn, init_carry, ts)[1]
subkeys = jr.split(rng, len(ts))
x_lead, v_lead, x_trail, v_trail = jax.lax.scan(
scan_fn, init_carry, (xp.arange(len(ts)), subkeys)
)[1]

mock_lead = MockStreamArm(
q=x_lead.to_units(pot.units["length"]),
Expand Down Expand Up @@ -144,7 +133,7 @@ def _sample(
prog_mass: gt.FloatQScalar,
t: gt.FloatQScalar,
) -> tuple[
gt.LengthBatchVec3, gt.LengthBatchVec3, gt.SpeedBatchVec3, gt.SpeedBatchVec3
gt.LengthBatchVec3, gt.SpeedBatchVec3, gt.LengthBatchVec3, gt.SpeedBatchVec3
]:
"""Generate stream particle initial conditions.
Expand All @@ -165,9 +154,9 @@ def _sample(
Returns
-------
x_lead, x_trail : Quantity[float, (*shape, 3), "length"]
Positions of the leading and trailing tails.
v_lead, v_trail : Quantity[float, (*shape, 3), "speed"]
Velocities of the leading and trailing tails.
x_lead, v_lead: Quantity[float, (*shape, 3), "length" | "speed"]
Position and velocity of the leading arm.
x_trail, v_trail : Quantity[float, (*shape, 3), "length" | "speed"]
Position and velocity of the trailing arm.
"""
...
4 changes: 2 additions & 2 deletions src/galax/dynamics/_dynamics/mockstream/df/_fardal.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _sample(
v: gt.SpeedVec3,
prog_mass: gt.FloatQScalar,
t: gt.FloatQScalar,
) -> tuple[gt.LengthVec3, gt.LengthVec3, gt.SpeedVec3, gt.SpeedVec3]:
) -> tuple[gt.LengthVec3, gt.SpeedVec3, gt.LengthVec3, gt.SpeedVec3]:
"""Generate stream particle initial conditions."""
# Random number generation
rng1, rng2, rng3, rng4 = jr.split(rng, 4)
Expand Down Expand Up @@ -90,7 +90,7 @@ def _sample(
x_lead = x - r_tidal * (kr_samp * r_hat - kz_samp * z_hat)
v_lead = v - v_circ * (kvphi_samp * phi_hat - kvz_samp * z_hat)

return x_lead, x_trail, v_lead, v_trail
return x_lead, v_lead, x_trail, v_trail


#####################################################################
Expand Down
26 changes: 12 additions & 14 deletions src/galax/dynamics/_dynamics/mockstream/mockstream_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,19 +244,17 @@ def run(
t = xp.ones_like(ts) * ts.value[-1] # TODO: ensure this time is correct

comps = {}
if self.df.lead:
comps["lead"] = MockStreamArm(
q=Quantity(lead_arm_w[:, 0:3], self.units["length"]),
p=Quantity(lead_arm_w[:, 3:6], self.units["speed"]),
t=t,
release_time=mock0["lead"].release_time,
)
if self.df.trail:
comps["trail"] = MockStreamArm(
q=Quantity(trail_arm_w[:, 0:3], self.units["length"]),
p=Quantity(trail_arm_w[:, 3:6], self.units["speed"]),
t=t,
release_time=mock0["trail"].release_time,
)
comps["lead"] = MockStreamArm(
q=Quantity(lead_arm_w[:, 0:3], self.units["length"]),
p=Quantity(lead_arm_w[:, 3:6], self.units["speed"]),
t=t,
release_time=mock0["lead"].release_time,
)
comps["trail"] = MockStreamArm(
q=Quantity(trail_arm_w[:, 0:3], self.units["length"]),
p=Quantity(trail_arm_w[:, 3:6], self.units["speed"]),
t=t,
release_time=mock0["trail"].release_time,
)

return MockStream(comps), prog_o[-1]

0 comments on commit a153b15

Please sign in to comment.