Skip to content

Commit

Permalink
Improve passing callable function as a parameter (GalacticDynamics#33)
Browse files Browse the repository at this point in the history
* Make UserParameter func static
* handle unit consistency check.

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Dec 10, 2023
1 parent a1a5704 commit f1155f4
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 12 deletions.
16 changes: 13 additions & 3 deletions src/galdynamix/potential/_potential/param/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class ConstantParameter(AbstractParameter):
# TODO: link this shape to the return shape from __call__
value: FloatArrayAnyShape = eqx.field(converter=converter_float_array)

# This is a workaround since vectorized methods don't support kwargs.
@partial_jit()
@vectorize_method(signature="()->()")
def _call_helper(self, _: FloatOrIntScalar) -> ArrayAnyShape:
Expand Down Expand Up @@ -93,6 +94,7 @@ def __call__(

#####################################################################
# User-defined Parameter
# For passing a function as a parameter.


@runtime_checkable
Expand All @@ -118,11 +120,19 @@ def __call__(self, t: FloatScalar, **kwargs: Any) -> ArrayAnyShape:


class UserParameter(AbstractParameter):
"""User-defined Parameter."""
"""User-defined Parameter.
Parameters
----------
func : Callable[[Array[float, ()] | float | int], Array[float, (*shape,)]]
The function to use to compute the parameter value.
unit : Unit, keyword-only
The output unit of the parameter.
"""

# TODO: unit handling
func: ParameterCallable
func: ParameterCallable = eqx.field(static=True)

@partial_jit()
def __call__(self, t: FloatScalar, **kwargs: Any) -> ArrayAnyShape:
def __call__(self, t: FloatOrIntScalar, **kwargs: Any) -> FloatArrayAnyShape:
return self.func(t, **kwargs)
80 changes: 71 additions & 9 deletions src/galdynamix/potential/_potential/param/field.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
__all__ = ["ParameterField"]

from dataclasses import KW_ONLY, dataclass, field, is_dataclass
from typing import Any, cast, overload
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints, overload

import astropy.units as u
import jax.numpy as xp

from galdynamix.potential._potential.core import AbstractPotential
from galdynamix.typing import Unit

from .core import AbstractParameter, ConstantParameter, ParameterCallable, UserParameter

Expand All @@ -27,7 +28,7 @@ class ParameterField:

name: str = field(init=False)
_: KW_ONLY
dimensions: u.PhysicalType # TODO: add a converter_argument
dimensions: u.PhysicalType
equivalencies: u.Equivalency | tuple[u.Equivalency, ...] | None = None

def __post_init__(self) -> None:
Expand Down Expand Up @@ -76,20 +77,34 @@ def __get__( # TODO: use `Self` when beartype is happy

# -----------------------------

def _check_unit(self, potential: AbstractPotential, unit: Unit) -> None:
"""Check that the given unit is compatible with the parameter's."""
if not unit.is_equivalent(
potential.units[self.dimensions],
equivalencies=self.equivalencies,
):
msg = (
"Parameter function must return a value "
f"with units equivalent to {self.dimensions}"
)
raise ValueError(msg)

def __set__(
self,
potential: AbstractPotential,
value: AbstractParameter | ParameterCallable | Any,
) -> None:
# Convert
if isinstance(value, AbstractParameter):
# TODO: use the dimensions & equivalencies info to check the parameters.
# TODO: use the units on the `potential` to convert the parameter value.
pass
# TODO: this doesn't handle the correct output unit, a. la.
# potential.units[self.dimensions]
self._check_unit(potential, value.unit) # Check the unit is compatible
elif callable(value):
# TODO: use the dimensions & equivalencies info to check the parameters.
# TODO: use the units on the `potential` to convert the parameter value.
value = UserParameter(func=value)
# TODO: this only gets the existing unit, it doesn't handle the
# correct output unit, a. la. potential.units[self.dimensions]
unit = _get_unit_from_return_annotation(value)
self._check_unit(potential, unit) # Check the unit is compatible
value = UserParameter(func=value, unit=unit)
else:
# TODO: the issue here is that ``units`` hasn't necessarily been set
# on the potential yet. What is needed is to possibly bail out
Expand All @@ -100,8 +115,55 @@ def __set__(
unit = potential.units[self.dimensions]
if isinstance(value, u.Quantity):
value = value.to_value(unit, equivalencies=self.equivalencies)

value = ConstantParameter(xp.asarray(value), unit=unit)

# Set
potential.__dict__[self.name] = value


def _get_unit_from_return_annotation(func: ParameterCallable) -> Unit:
"""Get the unit from the return annotation of a Parameter function.
Parameters
----------
func : Callable[[Array[float, ()] | float | int], Array[float, (*shape,)]]
The function to use to compute the parameter value.
Returns
-------
Unit
The unit from the return annotation of the function.
"""
# Get the return annotation
type_hints = get_type_hints(func, include_extras=True)
if "return" not in type_hints:
msg = "Parameter function must have a return annotation"
raise TypeError(msg)

# Check that the return annotation might contain a unit
return_annotation = type_hints["return"]
return_origin = get_origin(return_annotation)
if return_origin is not Annotated:
msg = "Parameter function return annotation must be annotated"
raise TypeError(msg)

# Get the unit from the return annotation
return_args = get_args(return_annotation)
has_unit = False
for arg in return_args[1:]:
# Try to convert the argument to a unit
try:
unit = u.Unit(arg)
except ValueError:
continue
# Only one unit annotation is allowed
if has_unit:
msg = "function has more than one unit annotation"
raise ValueError(msg)
has_unit = True

if not has_unit:
msg = "function did not have a valid unit annotation"
raise ValueError(msg)

return unit

0 comments on commit f1155f4

Please sign in to comment.