diff --git a/src/galdynamix/potential/_potential/core.py b/src/galdynamix/potential/_potential/core.py index f870c9f6..4c89e3e5 100644 --- a/src/galdynamix/potential/_potential/core.py +++ b/src/galdynamix/potential/_potential/core.py @@ -15,9 +15,7 @@ class AbstractPotential(AbstractPotentialBase): _: KW_ONLY - units: UnitSystem = eqx.field( - default=None, converter=converter_to_usys, static=True - ) + units: UnitSystem = eqx.field(converter=converter_to_usys, static=True) _G: float = eqx.field(init=False, static=True, repr=False, converter=float) def __post_init__(self) -> None: diff --git a/src/galdynamix/potential/_potential/param/field.py b/src/galdynamix/potential/_potential/param/field.py index cb390854..560def80 100644 --- a/src/galdynamix/potential/_potential/param/field.py +++ b/src/galdynamix/potential/_potential/param/field.py @@ -79,6 +79,12 @@ def __get__( # TODO: use `Self` when beartype is happy def _check_unit(self, potential: AbstractPotential, unit: Unit) -> None: """Check that the given unit is compatible with the parameter's.""" + # When the potential is being constructed, the units may not have been + # set yet, so we don't check the unit. + if not hasattr(potential, "units"): + return + + # Check the unit is compatible if not unit.is_equivalent( potential.units[self.dimensions], equivalencies=self.equivalencies, @@ -98,7 +104,8 @@ def __set__( if isinstance(value, AbstractParameter): # TODO: this doesn't handle the correct output unit, a. la. # potential.units[self.dimensions] - self._check_unit(potential, value.unit) # Check the unit is compatible + # Check the unit is compatible + self._check_unit(potential, value.unit) elif callable(value): # TODO: this only gets the existing unit, it doesn't handle the # correct output unit, a. la. potential.units[self.dimensions]