diff --git a/gammapy/spectrum/models.py b/gammapy/spectrum/models.py index 6731f6845e..6ddaffcc43 100644 --- a/gammapy/spectrum/models.py +++ b/gammapy/spectrum/models.py @@ -42,7 +42,10 @@ def __call__(self, energy): """Call evaluate method of derived classes""" kwargs = dict() for par in self.parameters.parameters: - kwargs[par.name] = par.quantity + quantity = par.quantity + if quantity.unit.physical_type == "energy": + quantity = quantity.to(energy.unit) + kwargs[par.name] = quantity return self.evaluate(energy, **kwargs) diff --git a/gammapy/spectrum/tests/test_models.py b/gammapy/spectrum/tests/test_models.py index a62c2e29ef..b762466a89 100644 --- a/gammapy/spectrum/tests/test_models.py +++ b/gammapy/spectrum/tests/test_models.py @@ -229,6 +229,12 @@ def test_models(spectrum): assert_quantity_allclose(val[0], spectrum["val_at_2TeV"]) +def test_model_unit(): + pwl = PowerLaw() + value = pwl(500 * u.MeV) + assert value.unit == "cm-2 s-1 TeV-1" + + @requires_dependency("matplotlib") @requires_data("gammapy-extra") def test_table_model_from_file():