Skip to content

Commit

Permalink
refactor: units in fardal (GalacticDynamics#236)
Browse files Browse the repository at this point in the history
* refactor: units in fardal

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Mar 27, 2024
1 parent aed8a6b commit 72d5efc
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 139 deletions.
74 changes: 51 additions & 23 deletions src/galax/dynamics/_dynamics/mockstream/df/_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""galax: Galactic Dynamix in Jax."""

__all__ = ["AbstractStreamDF", "ProgenitorMassCallable", "ConstantMassProtenitor"]
__all__ = ["AbstractStreamDF"]

import abc
from functools import partial
Expand All @@ -9,7 +9,9 @@
import equinox as eqx
import jax
import quax.examples.prng as jr
from plum import convert

import coordinax as cx
import quaxed.array_api as xp
from unxt import Quantity

Expand All @@ -19,8 +21,20 @@
from galax.dynamics._dynamics.orbit import Orbit
from galax.potential._potential.base import AbstractPotentialBase

Wif: TypeAlias = tuple[gt.Vec3, gt.Vec3, gt.Vec3, gt.Vec3]
Carry: TypeAlias = tuple[int, jr.PRNG, gt.Vec3, gt.Vec3, gt.Vec3, gt.Vec3]
Wif: TypeAlias = tuple[
gt.LengthVec3,
gt.LengthVec3,
gt.SpeedVec3,
gt.SpeedVec3,
]
Carry: TypeAlias = tuple[
int,
jr.PRNG,
gt.LengthVec3,
gt.LengthVec3,
gt.SpeedVec3,
gt.SpeedVec3,
]


class AbstractStreamDF(eqx.Module, strict=True): # type: ignore[call-arg, misc]
Expand Down Expand Up @@ -66,8 +80,9 @@ def sample(
"""
# Progenitor positions and times. The orbit times are used as the
# release times for the mock stream.
prog_w = prog_orbit.w(units=pot.units) # TODO: keep as PSP
x, v = prog_w[..., 0:3], prog_w[..., 3:6]
prog_orbit = prog_orbit.represent_as(cx.Cartesian3DVector)
x = convert(prog_orbit.q, Quantity)
v = convert(prog_orbit.p, Quantity)
ts = prog_orbit.t

mprog: ProgenitorMassCallable = (
Expand All @@ -85,20 +100,27 @@ def scan_fn(carry: Carry, t: gt.FloatQScalar) -> tuple[Carry, Wif]:
return (i + 1, rng, *out), out

# TODO: use ``jax.vmap`` instead of ``jax.lax.scan`` for GPU usage
init_carry = (0, rng, xp.zeros(3), xp.zeros(3), xp.zeros(3), xp.zeros(3))
init_carry = (
0,
rng,
xp.zeros_like(x[0]),
xp.zeros_like(x[0]),
xp.zeros_like(v[0]),
xp.zeros_like(v[0]),
)
x_lead, x_trail, v_lead, v_trail = jax.lax.scan(scan_fn, init_carry, ts)[1]

mock_lead = MockStream(
q=Quantity(x_lead, pot.units["length"]),
p=Quantity(v_lead, pot.units["speed"]),
t=ts,
release_time=ts,
q=x_lead.to(pot.units["length"]),
p=v_lead.to(pot.units["speed"]),
t=ts.to(pot.units["time"]),
release_time=ts.to(pot.units["time"]),
)
mock_trail = MockStream(
q=Quantity(x_trail, pot.units["length"]),
p=Quantity(v_trail, pot.units["speed"]),
t=ts,
release_time=ts,
q=x_trail.to(pot.units["length"]),
p=v_trail.to(pot.units["speed"]),
t=ts.to(pot.units["time"]),
release_time=ts.to(pot.units["time"]),
)

return mock_lead, mock_trail
Expand All @@ -109,11 +131,13 @@ def _sample(
self,
rng: jr.PRNG,
pot: AbstractPotentialBase,
x: gt.Vec3,
v: gt.Vec3,
x: gt.LengthVec3,
v: gt.SpeedVec3,
prog_mass: gt.FloatQScalar,
t: gt.FloatQScalar,
) -> tuple[gt.BatchVec3, gt.BatchVec3, gt.BatchVec3, gt.BatchVec3]:
) -> tuple[
gt.LengthBatchVec3, gt.LengthBatchVec3, gt.SpeedBatchVec3, gt.SpeedBatchVec3
]:
"""Generate stream particle initial conditions.
Parameters
Expand All @@ -122,16 +146,20 @@ def _sample(
Pseudo-random number generator.
pot : AbstractPotentialBase
The potential of the host galaxy.
w : Array
6d position (x, y, z) [kpc], (v_x, v_y, v_z) [kpc/Myr]
prog_mass : Numeric
Mass of the progenitor in [Msol]
x : Quantity[float, (3,), "length"]
3d position (x, y, z)
v : Quantity[float, (3,), "speed"]
3d velocity (v_x, v_y, v_z)
prog_mass : Quantity[float, (), "mass"]
Mass of the progenitor.
t : Quantity[float, (), "time"]
The release time of the stream particles.
Returns
-------
x_lead, x_trail, v_lead, v_trail : Array
Positions and velocities of the leading and trailing tails.
x_lead, x_trail : Quantity[float, (*shape, 3), "length"]
Positions of the leading and trailing tails.
v_lead, v_trail : Quantity[float, (*shape, 3), "speed"]
Velocities of the leading and trailing tails.
"""
...
Loading

0 comments on commit 72d5efc

Please sign in to comment.