Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ dependencies = [
"citation-compass>=0.0.3",
"dust_extinction",
"extinction>=0.4.7", # Needed by sncosmo
"jax",
"matplotlib",
"mocpy",
"nested-pandas",
Expand All @@ -43,6 +42,7 @@ dev = [
"bilby", # Used for Bayesian inference
"fsspec[http]", # Read OpSim file from the web with utils/make_opsim_shorten.py
"jupyter", # Clears output from Jupyter notebooks
"jax",
"lsdb",
"pre-commit", # Used to run checks before finalizing a git commit
"pytest",
Expand All @@ -58,6 +58,7 @@ dev = [
# (include the single quotes)
all = [
"bilby", # Used for Bayesian inference
"jax",
"lsdb", # Used to read and write LSDB catalogs
"pzflow",
"sncosmo",
Expand Down
17 changes: 16 additions & 1 deletion src/lightcurvelynx/astro_utils/salt2_color_law.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/salt2utils.pyx
"""

import jax.numpy as jnp
import numpy as np
from citation_compass import CiteClass

Expand Down Expand Up @@ -48,6 +47,14 @@ class SALT2ColorLaw(CiteClass):
"""

def __init__(self, wave_min, wave_max, coeffs):
try:
import jax.numpy as jnp
except ImportError as err:
raise ImportError(
"JAX is required to use the SALT2ColorLaw class, please "
"install with `pip install jax` or `conda install conda-forge::jax`"
) from err

# Create the internal coefficient array. The new first entry is 1.0 minus the
# sum of the given entries. The first six given entries are then listed.
coeffs = np.array(coeffs)
Expand Down Expand Up @@ -112,6 +119,14 @@ def apply(self, wavelengths):
wavelengths : array
The wavelengths in angstroms.
"""
try:
import jax.numpy as jnp
except ImportError as err:
raise ImportError(
"JAX is required to use the SALT2ColorLaw class, please "
"install with `pip install jax` or `conda install conda-forge::jax`"
) from err

num_waves = len(wavelengths)
shifted_wave = (jnp.asarray(wavelengths) - _SALT2CL_B) * _WAVESCALE

Expand Down
69 changes: 41 additions & 28 deletions src/lightcurvelynx/math_nodes/basic_math_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,6 @@

import ast

# Disable unused import because we need all of these imported
# so they can be used during evaluation of the node.
import math # noqa: F401

import jax.numpy as jnp # noqa: F401
import numpy as np # noqa: F401

from lightcurvelynx.base_models import FunctionNode


Expand All @@ -36,14 +29,15 @@ class BasicMathNode(FunctionNode):
expression : str
The expression to evaluate.
backend : str
The math libary to use. Must be one of: math, numpy, or jax.
The math libary to use. This is auto-converted to one of (math, np, or jnp)
depending on the input parameter.

Parameters
----------
expression : str
The expression to evaluate.
backend : str
The math libary to use. Must be one of: math, numpy, or jax.
The math libary to use. Must be one of: math, numpy, np, jax, or jnp.
node_label : str, optional
An identifier (or name) for the current node.
**kwargs : dict, optional
Expand Down Expand Up @@ -114,9 +108,32 @@ class BasicMathNode(FunctionNode):
}

def __init__(self, expression, backend="numpy", node_label=None, **kwargs):
if backend not in ["jax", "math", "numpy"]:
if backend == "jax" or backend == "jnp":
try:
import jax.numpy as jnp
except ImportError as err:
raise ImportError(
"JAX is required to use the BasicMathNode with backend='jax', please "
"install with `pip install jax` or `conda install conda-forge::jax`"
) from err

self.backend = "jnp"
self.backend_lib = jnp
self.to_array = jnp.asarray
elif backend == "numpy" or backend == "np":
import numpy as np

self.backend = "np"
self.backend_lib = np
self.to_array = np.asarray
elif backend == "math":
import math

self.backend = "math"
self.backend_lib = math
self.to_array = lambda x: x # No conversion
else:
raise ValueError(f"Unsupported math backend {backend}")
self.backend = backend

# Check the expression is pure math and translate it into the correct backend.
self.expression = expression
Expand All @@ -125,20 +142,26 @@ def __init__(self, expression, backend="numpy", node_label=None, **kwargs):
# Create a function from the expression. Note the expression has
# already been sanitized and validated via _prepare().
def eval_func(**kwargs):
params = self.prepare_params(**kwargs)
params = self._prepare_params(**kwargs)
params[self.backend] = self.backend_lib

try:
return eval(self.expression, globals(), params)
except Exception as problem:
# Provide more detailed logging, including the expression and parameters
# used, when we encounter a math error like divide by zero.
new_message = f"Error during math operation '{self.expression}' with args={kwargs}"
new_message = (
f"Error during math operation '{self.expression}' with args={kwargs}. "
f"Original error: {problem}"
)
raise type(problem)(new_message) from problem

super().__init__(eval_func, node_label=node_label, **kwargs)

def eval(self, **kwargs):
"""Evaluate the expression."""
params = self.prepare_params(**kwargs)
params = self._prepare_params(**kwargs)
params[self.backend] = self.backend_lib
return eval(self.expression, globals(), params)

@staticmethod
Expand All @@ -152,7 +175,7 @@ def list_functions():
"""
return list(BasicMathNode._math_map.keys())

def prepare_params(self, **kwargs):
def _prepare_params(self, **kwargs):
"""Convert all of the incoming parameters into the correct type,
such as numpy arrays.

Expand All @@ -168,12 +191,7 @@ def prepare_params(self, **kwargs):
"""
params = {}
for name, value in kwargs.items():
if self.backend == "numpy":
params[name] = np.array(value)
elif self.backend == "jax":
params[name] = jnp.array(value)
else:
params[name] = value
params[name] = self.to_array(value)
return params

def _prepare(self, **kwargs):
Expand All @@ -186,11 +204,6 @@ def _prepare(self, **kwargs):
**kwargs : dict, optional
Any additional keyword arguments, including the variable
assignments.

Returns
-------
tree : ast.*
The root node of the parsed syntax tree.
"""
tree = ast.parse(self.expression)

Expand All @@ -211,9 +224,9 @@ def _prepare(self, **kwargs):
# This is a math function or constant. Overwrite
if self.backend == "math":
node.id = self._math_map[node.id][0]
elif self.backend == "numpy":
elif self.backend == "numpy" or self.backend == "np":
node.id = self._math_map[node.id][1]
elif self.backend == "jax":
elif self.backend == "jax" or self.backend == "jnp":
node.id = self._math_map[node.id][2]
else:
raise ValueError(
Expand Down
3 changes: 2 additions & 1 deletion src/lightcurvelynx/models/salt2_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from lightcurvelynx.astro_utils.salt2_color_law import SALT2ColorLaw
from lightcurvelynx.astro_utils.unit_utils import flam_to_fnu
from lightcurvelynx.models.physical_model import SEDModel
from lightcurvelynx.utils.bicubic_interp import BicubicInterpolator


class SALT2JaxModel(SEDModel, CiteClass):
Expand Down Expand Up @@ -100,6 +99,8 @@ def __init__(
self.add_parameter("c", c, **kwargs)

# Load the data files.
from lightcurvelynx.utils.bicubic_interp import BicubicInterpolator

model_path = Path(model_dir)
self._m0_model = BicubicInterpolator.from_grid_file(
model_path / m0_filename,
Expand Down
10 changes: 8 additions & 2 deletions src/lightcurvelynx/utils/bicubic_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,14 @@
https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/salt2utils.pyx
"""

import jax.numpy as jnp
from jax import jit, vmap
try:
import jax.numpy as jnp
from jax import jit, vmap
except ImportError as err:
raise ImportError(
"JAX is required to use the BicubicInterpolator class, please "
"install with `pip install jax` or `conda install conda-forge::jax`"
) from err

from lightcurvelynx.utils.io_utils import read_grid_data

Expand Down