Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
* small code cleanup

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jan 24, 2024
1 parent 6ee100d commit 8e08e62
Show file tree
Hide file tree
Showing 24 changed files with 188 additions and 172 deletions.
7 changes: 2 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,12 @@ python_version = "3.11"
warn_unused_configs = true
strict = true
show_error_codes = true
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
warn_unreachable = true
disallow_untyped_defs = false
disallow_incomplete_defs = false
disallow_untyped_defs = true
disallow_incomplete_defs = true

[[tool.mypy.overrides]]
module = "galax.*"
disallow_untyped_defs = true
disallow_incomplete_defs = true
disable_error_code = ["name-defined"] # <- jaxtyping

[[tool.mypy.overrides]]
Expand Down
11 changes: 6 additions & 5 deletions src/galax/dynamics/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from typing import TYPE_CHECKING, final

import equinox as eqx
import jax.numpy as xp
import jax.experimental.array_api as xp
import jax.numpy as jnp
from jaxtyping import Array, Float

from galax.typing import BatchFloatScalar, BatchVec3, BatchVec6, BatchVec7
Expand Down Expand Up @@ -54,7 +55,7 @@ def qp(self) -> BatchVec6:
batch_shape, component_shapes = self._shape_tuple
q = xp.broadcast_to(self.q, batch_shape + component_shapes[0:1])
p = xp.broadcast_to(self.p, batch_shape + component_shapes[1:2])
return xp.concatenate((q, p), axis=-1)
return xp.concat((q, p), axis=-1)

# ==========================================================================
# Dynamical quantities
Expand Down Expand Up @@ -85,7 +86,7 @@ def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int, int]]:
qbatch, qshape = batched_shape(self.q, expect_ndim=1)
pbatch, pshape = batched_shape(self.p, expect_ndim=1)
tbatch, _ = batched_shape(self.t, expect_ndim=0)
batch_shape: tuple[int, ...] = xp.broadcast_shapes(qbatch, pbatch, tbatch)
batch_shape: tuple[int, ...] = jnp.broadcast_shapes(qbatch, pbatch, tbatch)
array_shape: tuple[int, int, int] = qshape + pshape + (1,)
return batch_shape, array_shape

Expand All @@ -102,7 +103,7 @@ def w(self) -> BatchVec7:
t = xp.broadcast_to(
atleast_batched(self.t), batch_shape + component_shapes[2:3]
)
return xp.concatenate((q, p, t), axis=-1)
return xp.concat((q, p, t), axis=-1)

@property
@partial_jit()
Expand Down Expand Up @@ -132,7 +133,7 @@ def angular_momentum(self) -> BatchVec3:
Array([0. , 0. , 6.28318531], dtype=float64)
"""
# TODO: when q, p are not Cartesian.
return xp.cross(self.q, self.p)
return jnp.cross(self.q, self.p)

# ==========================================================================
# Dynamical quantities
Expand Down
7 changes: 4 additions & 3 deletions src/galax/dynamics/mockstream/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
__all__ = ["MockStream"]

import equinox as eqx
import jax.numpy as xp
import jax.experimental.array_api as xp
import jax.numpy as jnp
from jaxtyping import Array, Float

from galax.dynamics._core import AbstractPhaseSpacePositionBase
Expand Down Expand Up @@ -39,7 +40,7 @@ def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int, int]]:
qbatch, qshape = batched_shape(self.q, expect_ndim=1)
pbatch, pshape = batched_shape(self.p, expect_ndim=1)
tbatch, _ = batched_shape(self.release_time, expect_ndim=0)
batch_shape = xp.broadcast_shapes(qbatch, pbatch, tbatch)
batch_shape = jnp.broadcast_shapes(qbatch, pbatch, tbatch)
return batch_shape, qshape + pshape + (1,)

@property
Expand All @@ -52,4 +53,4 @@ def w(self) -> BatchVec7:
t = xp.broadcast_to(
atleast_batched(self.release_time), batch_shape + component_shapes[2:3]
)
return xp.concatenate((q, p, t), axis=-1)
return xp.concat((q, p, t), axis=-1)
25 changes: 12 additions & 13 deletions src/galax/dynamics/mockstream/_df/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import equinox as eqx
import jax
import jax.numpy as xp
import jax.experimental.array_api as xp

from galax.dynamics._orbit import Orbit
from galax.dynamics.mockstream._core import MockStream
Expand Down Expand Up @@ -62,27 +62,26 @@ def sample(
mock_lead, mock_trail : MockStream
Positions and velocities of the leading and trailing tails.
"""
# Progenitor positions and times. The orbit times are the release times
# for the mock stream.
prog_qps = prog_orbit.qp
ts = prog_orbit.t

# Scan over the release times to generate the stream particle initial
# conditions at each release time.
def scan_fn(carry: Carry, t: FloatScalar) -> tuple[Carry, Wif]:
i = carry[0]
output = self._sample(
potential,
prog_qps[i],
prog_mass,
t,
i=i,
seed_num=seed_num,
out = self._sample(
potential, prog_qps[i], prog_mass, t, i=i, seed_num=seed_num
)
return (i + 1, *output), tuple(output)
return (i + 1, *out), out

init_carry = (
0,
xp.array([0.0, 0.0, 0.0]),
xp.array([0.0, 0.0, 0.0]),
xp.array([0.0, 0.0, 0.0]),
xp.array([0.0, 0.0, 0.0]),
xp.asarray([0.0, 0.0, 0.0]),
xp.asarray([0.0, 0.0, 0.0]),
xp.asarray([0.0, 0.0, 0.0]),
xp.asarray([0.0, 0.0, 0.0]),
)
x_lead, x_trail, v_lead, v_trail = jax.lax.scan(scan_fn, init_carry, ts[1:])[1]

Expand Down
63 changes: 31 additions & 32 deletions src/galax/dynamics/mockstream/_df/fardal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
__all__ = ["FardalStreamDF"]


import jax
import jax.numpy as xp
import jax.experimental.array_api as xp
import jax.numpy as jnp
from jax import grad, random

from galax.potential._potential.base import AbstractPotentialBase
from galax.typing import (
Expand Down Expand Up @@ -42,34 +43,32 @@ def _sample(
"""Generate stream particle initial conditions."""
# Random number generation
# TODO: change random key handling... need to do all of the sampling up front...
key_master = jax.random.PRNGKey(seed_num)
random_ints = jax.random.randint(
key=key_master, shape=(4,), minval=0, maxval=1000
)
keya = jax.random.PRNGKey(i * random_ints[0])
keyb = jax.random.PRNGKey(i * random_ints[1])
keyc = jax.random.PRNGKey(i * random_ints[2])
keyd = jax.random.PRNGKey(i * random_ints[3])
key_master = random.PRNGKey(seed_num)
random_ints = random.randint(key=key_master, shape=(4,), minval=0, maxval=1000)
keya = random.PRNGKey(i * random_ints[0])
keyb = random.PRNGKey(i * random_ints[1])
keyc = random.PRNGKey(i * random_ints[2])
keyd = random.PRNGKey(i * random_ints[3])

# ---------------------------

x, v = qp[0:3], qp[3:6]

omega_val = orbital_angular_velocity_mag(x, v)

r = xp.linalg.norm(x)
r = xp.linalg.vector_norm(x)
r_hat = x / r
r_tidal = tidal_radius(potential, x, v, prog_mass, t)
rel_v = omega_val * r_tidal # relative velocity

# circlar_velocity
v_circ = rel_v

L_vec = xp.cross(x, v)
z_hat = L_vec / xp.linalg.norm(L_vec)
L_vec = jnp.cross(x, v)
z_hat = L_vec / xp.linalg.vector_norm(L_vec)

phi_vec = v - xp.sum(v * r_hat) * r_hat
phi_hat = phi_vec / xp.linalg.norm(phi_vec)
phi_hat = phi_vec / xp.linalg.vector_norm(phi_vec)

kr_bar = 2.0
kvphi_bar = 0.3
Expand All @@ -82,12 +81,12 @@ def _sample(
sigma_kz = 0.5
sigma_kvz = 0.5

kr_samp = kr_bar + jax.random.normal(keya, shape=(1,)) * sigma_kr
kr_samp = kr_bar + random.normal(keya, shape=(1,)) * sigma_kr
kvphi_samp = kr_samp * (
kvphi_bar + jax.random.normal(keyb, shape=(1,)) * sigma_kvphi
kvphi_bar + random.normal(keyb, shape=(1,)) * sigma_kvphi
)
kz_samp = kz_bar + jax.random.normal(keyc, shape=(1,)) * sigma_kz
kvz_samp = kvz_bar + jax.random.normal(keyd, shape=(1,)) * sigma_kvz
kz_samp = kz_bar + random.normal(keyc, shape=(1,)) * sigma_kz
kvz_samp = kvz_bar + random.normal(keyd, shape=(1,)) * sigma_kvz

# Trailing arm
x_trail = (
Expand Down Expand Up @@ -134,7 +133,7 @@ def dphidr(potential: AbstractPotentialBase, x: Vec3, t: FloatScalar) -> Vec3:
Array:
Derivative of potential in [1/Myr]
"""
r_hat = x / xp.linalg.norm(x)
r_hat = x / xp.linalg.vector_norm(x)
return xp.sum(potential.gradient(x, t) * r_hat)


Expand Down Expand Up @@ -165,12 +164,12 @@ def d2phidr2(
>>> from galax.potential import NFWPotential
>>> from galax.units import galactic
>>> pot = NFWPotential(m=1e12, r_s=20.0, units=galactic)
>>> d2phidr2(pot, xp.array([8.0, 0.0, 0.0]), t=0)
>>> d2phidr2(pot, xp.asarray([8.0, 0.0, 0.0]), t=0)
Array(-0.00017469, dtype=float64)
"""
r_hat = x / xp.linalg.norm(x)
r_hat = x / xp.linalg.vector_norm(x)
dphi_dr_func = lambda x: xp.sum(potential.gradient(x, t) * r_hat) # noqa: E731
return xp.sum(jax.grad(dphi_dr_func)(x) * r_hat)
return xp.sum(grad(dphi_dr_func)(x) * r_hat)


@partial_jit()
Expand All @@ -191,13 +190,13 @@ def orbital_angular_velocity(x: Vec3, v: Vec3, /) -> Vec3:
Examples:
--------
>>> x = xp.array([8.0, 0.0, 0.0])
>>> v = xp.array([8.0, 0.0, 0.0])
>>> x = xp.asarray([8.0, 0.0, 0.0])
>>> v = xp.asarray([8.0, 0.0, 0.0])
>>> orbital_angular_velocity(x, v)
Array([0., 0., 0.], dtype=float64)
"""
r = xp.linalg.norm(x)
return xp.cross(x, v) / r**2
r = xp.linalg.vector_norm(x)
return jnp.cross(x, v) / r**2


@partial_jit()
Expand All @@ -218,12 +217,12 @@ def orbital_angular_velocity_mag(x: Vec3, v: Vec3, /) -> FloatScalar:
Examples:
--------
>>> x = xp.array([8.0, 0.0, 0.0])
>>> v = xp.array([8.0, 0.0, 0.0])
>>> x = xp.asarray([8.0, 0.0, 0.0])
>>> v = xp.asarray([8.0, 0.0, 0.0])
>>> orbital_angular_velocity_mag(x, v)
Array(0., dtype=float64)
"""
return xp.linalg.norm(orbital_angular_velocity(x, v))
return xp.linalg.vector_norm(orbital_angular_velocity(x, v))


@partial_jit()
Expand Down Expand Up @@ -260,8 +259,8 @@ def tidal_radius(
>>> from galax.potential import NFWPotential
>>> from galax.units import galactic
>>> pot = NFWPotential(m=1e12, r_s=20.0, units=galactic)
>>> x=xp.array([8.0, 0.0, 0.0])
>>> v=xp.array([8.0, 0.0, 0.0])
>>> x=xp.asarray([8.0, 0.0, 0.0])
>>> v=xp.asarray([8.0, 0.0, 0.0])
>>> tidal_radius(pot, x, v, prog_mass=1e4, t=0)
Array(0.06362136, dtype=float64)
"""
Expand Down Expand Up @@ -296,7 +295,7 @@ def lagrange_points(
Time.
"""
r_t = tidal_radius(potential, x, v, prog_mass, t)
r_hat = x / xp.linalg.norm(x)
r_hat = x / xp.linalg.vector_norm(x)
L_1 = x - r_hat * r_t # close
L_2 = x + r_hat * r_t # far
return L_1, L_2
10 changes: 6 additions & 4 deletions src/galax/dynamics/mockstream/_mockstream_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

import equinox as eqx
import jax
import jax.numpy as xp
import jax.experimental.array_api as xp
import jax.numpy as jnp
from jax.lib.xla_bridge import get_backend

from galax.dynamics._orbit import Orbit
Expand Down Expand Up @@ -71,7 +72,7 @@ def _run_scan( # TODO: output shape depends on the input shape

def scan_fn(carry: Carry, _: IntScalar) -> tuple[Carry, tuple[VecN, VecN]]:
i, qp0_lead_i, qp0_trail_i = carry
qp0_lead_trail = xp.vstack([qp0_lead_i, qp0_trail_i])
qp0_lead_trail = jnp.vstack([qp0_lead_i, qp0_trail_i]) # TODO: xp.stack
tstep = xp.asarray([ts[i], ts[-1]])

def integ_ics(ics: Vec6) -> VecN:
Expand Down Expand Up @@ -167,12 +168,13 @@ def run(
# Parse vmapped
use_vmap = get_backend().platform == "gpu" if vmapped is None else vmapped

# Integrate the progenitor orbit to the stripping times
# Integrate the progenitor orbit, evaluating at the stripping times
prog_o = self.potential.integrate_orbit(
prog_w0, ts, integrator=self.progenitor_integrator
)

# Generate stream initial conditions along the integrated progenitor orbit
# Generate stream initial conditions along the integrated progenitor
# orbit. The release times are stripping times.
mock0_lead, mock0_trail = self.df.sample(
self.potential, prog_o, prog_mass, seed_num=seed_num
)
Expand Down
4 changes: 2 additions & 2 deletions src/galax/integrate/_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any

import equinox as eqx
import jax.numpy as xp
import jax.experimental.array_api as xp
from diffrax import (
AbstractSolver,
AbstractStepSizeController,
Expand Down Expand Up @@ -60,4 +60,4 @@ def __call__(
**self.diffeq_kw,
)
ts = solution.ts[:, None] if solution.ts.ndim == 1 else solution.ts
return xp.concatenate((solution.ys, ts), axis=1)
return xp.concat((solution.ys, ts), axis=1)
7 changes: 4 additions & 3 deletions src/galax/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import astropy.units as u
import equinox as eqx
import jax.numpy as xp
import jax.experimental.array_api as xp
import jax.numpy as jnp
from astropy.constants import G as _G # pylint: disable=no-name-in-module
from jax import grad, hessian, jacfwd
from jaxtyping import Array, Float
Expand Down Expand Up @@ -171,7 +172,7 @@ def gradient(self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike) -> BatchVe
def _density(self, q: Vec3, /, t: FloatOrIntScalar) -> FloatScalar:
"""See ``density``."""
# Note: trace(jacobian(gradient)) is faster than trace(hessian(energy))
lap = xp.trace(jacfwd(self.gradient)(q, t))
lap = jnp.trace(jacfwd(self.gradient)(q, t))
return lap / (4 * xp.pi * self._G)

def density(
Expand Down Expand Up @@ -290,7 +291,7 @@ def _integrator_F(
self, t: FloatScalar, qp: Vec6, args: tuple[Any, ...] # pylint: disable=W0613
) -> Vec6:
"""Return the derivative of the phase-space position."""
return xp.hstack([qp[3:6], self.acceleration(qp[0:3], t)]) # v, a
return jnp.hstack([qp[3:6], self.acceleration(qp[0:3], t)]) # v, a

@partial_jit(static_argnames=("integrator",))
def integrate_orbit(
Expand Down
Loading

0 comments on commit 8e08e62

Please sign in to comment.