Skip to content

Commit

Permalink
make parameters time dependent (#2)
Browse files Browse the repository at this point in the history
* make parameters time dependent
* nox tests

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Nov 7, 2023
1 parent 51fcbb4 commit 3956150
Show file tree
Hide file tree
Showing 23 changed files with 828 additions and 523 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ repos:
hooks:
- id: mypy
files: src|tests
args: []
args: ["--enable-incomplete-feature=Unpack"]
additional_dependencies:
- pytest

Expand Down
20 changes: 10 additions & 10 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

DIR = Path(__file__).parent.resolve()

nox.options.sessions = ["lint", "pylint", "tests"]
nox.options.sessions = ["lint", "tests"]


@nox.session
Expand All @@ -22,15 +22,15 @@ def lint(session: nox.Session) -> None:
)


@nox.session
def pylint(session: nox.Session) -> None:
"""
Run PyLint.
"""
# This needs to be installed into the package environment, and is slower
# than a pre-commit check
session.install(".", "pylint")
session.run("pylint", "galdynamix", *session.posargs)
# @nox.session
# def pylint(session: nox.Session) -> None:
# """
# Run PyLint.
# """
# # This needs to be installed into the package environment, and is slower
# # than a pre-commit check
# session.install(".", "pylint")
# session.run("pylint", "galdynamix", *session.posargs)


@nox.session
Expand Down
56 changes: 26 additions & 30 deletions src/galdynamix/dynamics/mockstream/_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
import jax.numpy as xp
import jax.typing as jt

from galdynamix.potential._potential.base import PotentialBase
from galdynamix.utils import jit_method
from galdynamix.potential._potential.base import AbstractPotentialBase
from galdynamix.utils import partial_jit


class BaseStreamDF(eqx.Module):
class BaseStreamDF(eqx.Module): # type: ignore[misc]
lead: bool = eqx.field(default=True, static=True)
trail: bool = eqx.field(default=True, static=True)

Expand All @@ -28,7 +28,7 @@ def __post_init__(self) -> None:
@abc.abstractmethod
def sample(
self,
potential: PotentialBase,
potential: AbstractPotentialBase,
x: jt.Array,
v: jt.Array,
prog_mass: jt.Array,
Expand All @@ -45,10 +45,10 @@ def sample(


class FardalStreamDF(BaseStreamDF):
@jit_method(static_argnames=("seed_num",))
@partial_jit(static_argnames=("seed_num",))
def sample(
self,
potential: PotentialBase,
potential: AbstractPotentialBase,
x: jt.Array,
v: jt.Array,
prog_mass: jt.Array,
Expand All @@ -74,9 +74,6 @@ def sample(
keyd = jax.random.PRNGKey(i * random_ints[3]) # jax.random.PRNGKey(i*3)
keye = jax.random.PRNGKey(i * random_ints[4]) # jax.random.PRNGKey(i*17)

# each is an xyz array
L_close, L_far = self._lagrange_pts(potential, x, v, prog_mass, t)

omega_val = self._omega(x, v)

r = xp.linalg.norm(x)
Expand All @@ -85,7 +82,6 @@ def sample(
rel_v = omega_val * r_tidal # relative velocity

# circlar_velocity
dphi_dr = xp.sum(potential.gradient(x, t) * r_hat)
v_circ = rel_v ##xp.sqrt( r*dphi_dr )

L_vec = xp.cross(x, v)
Expand Down Expand Up @@ -142,31 +138,32 @@ def sample(

return pos_lead, pos_trail, v_lead, v_trail

@jit_method()
@partial_jit()
def _lagrange_pts(
self,
potential: PotentialBase,
potential: AbstractPotentialBase,
x: jt.Array,
v: jt.Array,
Msat: jt.Array,
prog_mass: jt.Array,
t: jt.Array,
) -> tuple[jt.Array, jt.Array]:
r_tidal = self._tidalr_mw(potential, x, v, Msat, t)
r_tidal = self._tidalr_mw(potential, x, v, prog_mass, t)
r_hat = x / xp.linalg.norm(x)
L_close = x - r_hat * r_tidal
L_far = x + r_hat * r_tidal
return L_close, L_far

@jit_method()
@partial_jit()
def _d2phidr2_mw(
self, potential: PotentialBase, x: jt.Array, /, t: jt.Array
self, potential: AbstractPotentialBase, q: jt.Array, t: jt.Array
) -> jt.Array:
"""
Computes the second derivative of the potential at a position x (in the simulation frame)
Parameters
----------
x: 3d position (x, y, z) in [kpc]
x: Array
3d position (x, y, z) in [kpc]
Returns
-------
Expand All @@ -177,13 +174,12 @@ def _d2phidr2_mw(
--------
>>> _d2phidr2_mw(x=xp.array([8.0, 0.0, 0.0]))
"""
rad = xp.linalg.norm(x)
r_hat = x / rad
r_hat = q / xp.linalg.norm(q)
dphi_dr_func = lambda x: xp.sum(potential.gradient(x, t) * r_hat) # noqa: E731
return xp.sum(jax.grad(dphi_dr_func)(x) * r_hat)
return xp.sum(jax.grad(dphi_dr_func)(q) * r_hat)

@jit_method()
def _omega(self, x: jt.Array, v: jt.Array) -> jt.Array:
@partial_jit()
def _omega(self, q: jt.Array, v: jt.Array) -> jt.Array:
"""
Computes the magnitude of the angular momentum in the simulation frame
Expand All @@ -203,18 +199,18 @@ def _omega(self, x: jt.Array, v: jt.Array) -> jt.Array:
--------
>>> _omega(x=xp.array([8.0, 0.0, 0.0]), v=xp.array([8.0, 0.0, 0.0]))
"""
rad = xp.sqrt(x[0] ** 2 + x[1] ** 2 + x[2] ** 2)
omega_vec = xp.cross(x, v) / (rad**2)
r = xp.sqrt(q[0] ** 2 + q[1] ** 2 + q[2] ** 2) # TODO: use norm
omega_vec = xp.cross(q, v) / r**2
return xp.linalg.norm(omega_vec)

@jit_method()
@partial_jit()
def _tidalr_mw(
self,
potential: PotentialBase,
potential: AbstractPotentialBase,
x: jt.Array,
v: jt.Array,
/,
Msat: jt.Array,
prog_mass: jt.Array,
t: jt.Array,
) -> jt.Array:
"""Computes the tidal radius of a cluster in the potential.
Expand All @@ -223,7 +219,7 @@ def _tidalr_mw(
----------
x: 3d position (x, y, z) in [kpc]
v: 3d velocity (v_x, v_y, v_z) in [kpc/Myr]
Msat: Cluster mass in [Msol]
prog_mass: Cluster mass in [Msol]
Returns
-------
Expand All @@ -232,10 +228,10 @@ def _tidalr_mw(
Examples
--------
>>> _tidalr_mw(x=xp.array([8.0, 0.0, 0.0]), v=xp.array([8.0, 0.0, 0.0]), Msat=1e4)
>>> _tidalr_mw(x=xp.array([8.0, 0.0, 0.0]), v=xp.array([8.0, 0.0, 0.0]), prog_mass=1e4)
"""
return (
potential._G
* Msat
* prog_mass
/ (self._omega(x, v) ** 2 - self._d2phidr2_mw(potential, x, t))
) ** (1.0 / 3.0)
Loading

0 comments on commit 3956150

Please sign in to comment.