Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Nov 5, 2023
1 parent 3405dba commit 2b3804f
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 115 deletions.
23 changes: 10 additions & 13 deletions src/galdynamix/dynamics/mockstream/_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ def __post_init__(self) -> None:
@abc.abstractmethod
def sample(
self,
potential: PotentialBase,
x: jt.Array,
v: jt.Array,
prog_mass: jt.Array,
i: int,
t: jt.Array,
*,
potential: PotentialBase,
seed_num: int,
) -> tuple[jt.Array, jt.Array, jt.Array, jt.Array]:
"""Sample the DF."""
Expand All @@ -48,13 +48,13 @@ class FardalStreamDF(BaseStreamDF):
@jit_method(static_argnames=("seed_num",))
def sample(
self,
potential: PotentialBase,
x: jt.Array,
v: jt.Array,
prog_mass: jt.Array,
i: int,
t: jt.Array,
*,
potential: PotentialBase,
seed_num: int,
) -> tuple[jt.Array, jt.Array, jt.Array, jt.Array]:
"""
Expand All @@ -74,15 +74,14 @@ def sample(
keyd = jax.random.PRNGKey(i * random_ints[3]) # jax.random.PRNGKey(i*3)
keye = jax.random.PRNGKey(i * random_ints[4]) # jax.random.PRNGKey(i*17)

L_close, L_far = self._lagrange_pts(
x, v, prog_mass, t, potential=potential
) # each is an xyz array
# each is an xyz array
L_close, L_far = self._lagrange_pts(potential, x, v, prog_mass, t)

omega_val = self._omega(x, v)

r = xp.linalg.norm(x)
r_hat = x / r
r_tidal = self._tidalr_mw(x, v, prog_mass, t, potential=potential)
r_tidal = self._tidalr_mw(potential, x, v, prog_mass, t)
rel_v = omega_val * r_tidal # relative velocity

# circlar_velocity
Expand Down Expand Up @@ -146,22 +145,21 @@ def sample(
@jit_method()
def _lagrange_pts(
self,
potential: PotentialBase,
x: jt.Array,
v: jt.Array,
Msat: jt.Array,
t: jt.Array,
*,
potential: PotentialBase,
) -> tuple[jt.Array, jt.Array]:
r_tidal = self._tidalr_mw(x, v, Msat, t, potential=potential)
r_tidal = self._tidalr_mw(potential, x, v, Msat, t)
r_hat = x / xp.linalg.norm(x)
L_close = x - r_hat * r_tidal
L_far = x + r_hat * r_tidal
return L_close, L_far

@jit_method()
def _d2phidr2_mw(
self, x: jt.Array, /, t: jt.Array, *, potential: PotentialBase
self, potential: PotentialBase, x: jt.Array, /, t: jt.Array
) -> jt.Array:
"""
Computes the second derivative of the potential at a position x (in the simulation frame)
Expand Down Expand Up @@ -212,13 +210,12 @@ def _omega(self, x: jt.Array, v: jt.Array) -> jt.Array:
@jit_method()
def _tidalr_mw(
self,
potential: PotentialBase,
x: jt.Array,
v: jt.Array,
/,
Msat: jt.Array,
t: jt.Array,
*,
potential: PotentialBase,
) -> jt.Array:
"""Computes the tidal radius of a cluster in the potential.
Expand All @@ -240,5 +237,5 @@ def _tidalr_mw(
return (
potential._G
* Msat
/ (self._omega(x, v) ** 2 - self._d2phidr2_mw(x, t, potential=potential))
/ (self._omega(x, v) ** 2 - self._d2phidr2_mw(potential, x, t))
) ** (1.0 / 3.0)
2 changes: 1 addition & 1 deletion src/galdynamix/dynamics/mockstream/_mockstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ def _gen_stream_ics(
def scan_fun(carry: Any, t: Any) -> Any:
i, pos_close, pos_far, vel_close, vel_far = carry
sample_outputs = self.df.sample(
self.potential,
ws_jax[i, :3],
ws_jax[i, 3:],
prog_mass,
i,
t,
potential=self.potential,
seed_num=seed_num,
)
return [i + 1, *sample_outputs], list(sample_outputs)
Expand Down
File renamed without changes.
10 changes: 0 additions & 10 deletions src/galdynamix/integrate/_builtin/__init__.py

This file was deleted.

81 changes: 0 additions & 81 deletions src/galdynamix/integrate/_builtin/leapfrog.py

This file was deleted.

13 changes: 3 additions & 10 deletions src/galdynamix/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array:

###########################################################################
# Core methods that use the above implemented functions
#

@jit_method()
def gradient(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
Expand All @@ -68,15 +67,9 @@ def acceleration(self, q: jt.Array, /, t: jt.Array) -> jt.Array:

###########################################################################

# @jit_method()
# def _jacobian_force_mw(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
# return jax.jacfwd(self.gradient)(q, t)

@jit_method()
def _velocity_acceleration(self, t: jt.Array, xv: jt.Array, args: Any) -> jt.Array:
x, v = xv[:3], xv[3:]
acceleration = -self.gradient(x, t)
return xp.hstack([v, acceleration])
def _vel_acc(self, t: jt.Array, xv: jt.Array, args: Any) -> jt.Array:
return xp.hstack([xv[3:], -self.gradient(xv[:3], t)])

@jit_method()
def integrate_orbit(
Expand All @@ -86,4 +79,4 @@ def integrate_orbit(
DiffraxIntegrator as Integrator,
)

return Integrator(self._velocity_acceleration).run(w0, t0, t1, ts)
return Integrator(self._vel_acc).run(w0, t0, t1, ts)

0 comments on commit 2b3804f

Please sign in to comment.