Skip to content

Commit

Permalink
remove Hamiltonian
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Nov 5, 2023
1 parent 2bb917c commit 3405dba
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 73 deletions.
29 changes: 13 additions & 16 deletions src/galdynamix/dynamics/mockstream/_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import jax.numpy as xp
import jax.typing as jt

from galdynamix.potential._hamiltonian import Hamiltonian
from galdynamix.potential._potential.base import PotentialBase
from galdynamix.utils import jit_method


Expand All @@ -34,7 +34,7 @@ def sample(
i: int,
t: jt.Array,
*,
hamiltonian: Hamiltonian,
potential: PotentialBase,
seed_num: int,
) -> tuple[jt.Array, jt.Array, jt.Array, jt.Array]:
"""Sample the DF."""
Expand All @@ -54,7 +54,7 @@ def sample(
i: int,
t: jt.Array,
*,
hamiltonian: Hamiltonian,
potential: PotentialBase,
seed_num: int,
) -> tuple[jt.Array, jt.Array, jt.Array, jt.Array]:
"""
Expand All @@ -75,18 +75,18 @@ def sample(
keye = jax.random.PRNGKey(i * random_ints[4]) # jax.random.PRNGKey(i*17)

L_close, L_far = self._lagrange_pts(
x, v, prog_mass, t, hamiltonian=hamiltonian
x, v, prog_mass, t, potential=potential
) # each is an xyz array

omega_val = self._omega(x, v)

r = xp.linalg.norm(x)
r_hat = x / r
r_tidal = self._tidalr_mw(x, v, prog_mass, t, hamiltonian=hamiltonian)
r_tidal = self._tidalr_mw(x, v, prog_mass, t, potential=potential)
rel_v = omega_val * r_tidal # relative velocity

# circlar_velocity
dphi_dr = xp.sum(hamiltonian.potential.gradient(x, t) * r_hat)
dphi_dr = xp.sum(potential.gradient(x, t) * r_hat)
v_circ = rel_v ##xp.sqrt( r*dphi_dr )

L_vec = xp.cross(x, v)
Expand Down Expand Up @@ -151,17 +151,17 @@ def _lagrange_pts(
Msat: jt.Array,
t: jt.Array,
*,
hamiltonian: Hamiltonian,
potential: PotentialBase,
) -> tuple[jt.Array, jt.Array]:
r_tidal = self._tidalr_mw(x, v, Msat, t, hamiltonian=hamiltonian)
r_tidal = self._tidalr_mw(x, v, Msat, t, potential=potential)
r_hat = x / xp.linalg.norm(x)
L_close = x - r_hat * r_tidal
L_far = x + r_hat * r_tidal
return L_close, L_far

@jit_method()
def _d2phidr2_mw(
self, x: jt.Array, /, t: jt.Array, *, hamiltonian: Hamiltonian
self, x: jt.Array, /, t: jt.Array, *, potential: PotentialBase
) -> jt.Array:
"""
Computes the second derivative of the potential at a position x (in the simulation frame)
Expand All @@ -181,7 +181,7 @@ def _d2phidr2_mw(
"""
rad = xp.linalg.norm(x)
r_hat = x / rad
dphi_dr_func = lambda x: xp.sum(hamiltonian.potential.gradient(x, t) * r_hat) # noqa: E731
dphi_dr_func = lambda x: xp.sum(potential.gradient(x, t) * r_hat) # noqa: E731
return xp.sum(jax.grad(dphi_dr_func)(x) * r_hat)

@jit_method()
Expand Down Expand Up @@ -218,7 +218,7 @@ def _tidalr_mw(
Msat: jt.Array,
t: jt.Array,
*,
hamiltonian: Hamiltonian,
potential: PotentialBase,
) -> jt.Array:
"""Computes the tidal radius of a cluster in the potential.
Expand All @@ -238,10 +238,7 @@ def _tidalr_mw(
>>> _tidalr_mw(x=xp.array([8.0, 0.0, 0.0]), v=xp.array([8.0, 0.0, 0.0]), Msat=1e4)
"""
return (
hamiltonian.potential._G
potential._G
* Msat
/ (
self._omega(x, v) ** 2
- self._d2phidr2_mw(x, t, hamiltonian=hamiltonian)
)
/ (self._omega(x, v) ** 2 - self._d2phidr2_mw(x, t, potential=potential))
) ** (1.0 / 3.0)
21 changes: 15 additions & 6 deletions src/galdynamix/dynamics/mockstream/_mockstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import jax.numpy as xp
import jax.typing as jt

from galdynamix.potential._hamiltonian import Hamiltonian
from galdynamix.potential._potential.base import PotentialBase
from galdynamix.utils import jit_method

Expand All @@ -21,7 +20,7 @@

class MockStreamGenerator(eqx.Module):
df: BaseStreamDF
hamiltonian: Hamiltonian
potential: PotentialBase
progenitor_potential: PotentialBase | None = None

def __post_init__(self) -> None:
Expand All @@ -33,12 +32,20 @@ def __post_init__(self) -> None:
def _gen_stream_ics(
self, ts: jt.Array, prog_w0: jt.Array, prog_mass: jt.Array, *, seed_num: int
) -> jt.Array:
ws_jax = self.hamiltonian.integrate_orbit(prog_w0, t0=xp.min(ts), t1=xp.max(ts), ts=ts)
ws_jax = self.potential.integrate_orbit(
prog_w0, t0=xp.min(ts), t1=xp.max(ts), ts=ts
)

def scan_fun(carry: Any, t: Any) -> Any:
i, pos_close, pos_far, vel_close, vel_far = carry
sample_outputs = self.df.sample(
ws_jax[i, :3], ws_jax[i, 3:], prog_mass, i, t, hamiltonian=self.hamiltonian, seed_num=seed_num
ws_jax[i, :3],
ws_jax[i, 3:],
prog_mass,
i,
t,
potential=self.potential,
seed_num=seed_num,
)
return [i + 1, *sample_outputs], list(sample_outputs)

Expand Down Expand Up @@ -74,11 +81,13 @@ def scan_fun(carry: Any, particle_idx: Any) -> Any:
minval, maxval = ts[i], ts[-1]

def integrate_different_ics(ics: jt.Array) -> jt.Array:
return self.hamiltonian.integrate_orbit(ics, minval, maxval, None)[0]
return self.potential.integrate_orbit(ics, minval, maxval, None)[0]

w_particle_close, w_particle_far = jax.vmap(
integrate_different_ics, in_axes=(0,)
)(w0_lead_trail) # vmap over leading and trailing arm
)(
w0_lead_trail
) # vmap over leading and trailing arm

return [
i + 1,
Expand Down
4 changes: 1 addition & 3 deletions src/galdynamix/potential/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@

from __future__ import annotations

from . import _hamiltonian, _potential
from ._hamiltonian import *
from . import _potential
from ._potential import *

__all__: list[str] = []
__all__ += _potential.__all__
__all__ += _hamiltonian.__all__
28 changes: 0 additions & 28 deletions src/galdynamix/potential/_hamiltonian.py

This file was deleted.

24 changes: 4 additions & 20 deletions src/galdynamix/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,24 +82,8 @@ def _velocity_acceleration(self, t: jt.Array, xv: jt.Array, args: Any) -> jt.Arr
def integrate_orbit(
self, w0: jt.Array, t0: jt.Array, t1: jt.Array, ts: jt.Array | None
) -> jt.Array:
# from galdynamix.integrate._builtin.diffrax import DiffraxIntegrator
# from galdynamix.potential._hamiltonian import Hamiltonian

# return Hamiltonian(self).integrate_orbit(
# w0, t0, t1, ts, Integrator=DiffraxIntegrator
# )
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve

solution = diffeqsolve(
terms=ODETerm(self._velocity_acceleration),
solver=Dopri5(),
t0=t0,
t1=t1,
y0=w0,
dt0=None,
saveat=SaveAt(t0=False, t1=True, ts=ts, dense=False),
stepsize_controller=PIDController(rtol=1e-7, atol=1e-7),
discrete_terminating_event=None,
max_steps=None,
from galdynamix.integrate._builtin.diffrax import (
DiffraxIntegrator as Integrator,
)
return solution.ys

return Integrator(self._velocity_acceleration).run(w0, t0, t1, ts)

0 comments on commit 3405dba

Please sign in to comment.