Skip to content

Commit

Permalink
F20 tests
Browse files Browse the repository at this point in the history
  • Loading branch information
karllark committed Jun 11, 2019
1 parent 6aa86c4 commit af90d1d
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 91 deletions.
78 changes: 44 additions & 34 deletions dust_extinction/baseclasses.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from __future__ import (absolute_import, print_function, division)
from __future__ import absolute_import, print_function, division

import numpy as np

from astropy.modeling import (Model, Parameter, InputParameterError)
from astropy.modeling import Model, Parameter, InputParameterError

__all__ = ['BaseExtModel', 'BaseExtAveModel',
'BaseExtRvModel', 'BaseExtRvAfAModel']
__all__ = ["BaseExtModel", "BaseExtAveModel", "BaseExtRvModel", "BaseExtRvAfAModel"]


class BaseExtModel(Model):
"""
Base Extinction Model. Do not use.
"""
inputs = ('x',)
outputs = ('axav',)

inputs = ("x",)
outputs = ("axav",)

def extinguish(self, x, Av=None, Ebv=None):
"""
Expand Down Expand Up @@ -45,22 +45,23 @@ def extinguish(self, x, Av=None, Ebv=None):

# check that av or ebv is set
if (Av is None) and (Ebv is None):
raise InputParameterError('neither Av or Ebv passed, one required')
raise InputParameterError("neither Av or Ebv passed, one required")

# if Av is not set and Ebv set, convert to Av
if Av is None:
Av = self.Rv*Ebv
Av = self.Rv * Ebv

# return fractional extinction
return np.power(10.0, -0.4*axav*Av)
return np.power(10.0, -0.4 * axav * Av)


class BaseExtAveModel(Model):
"""
Base Extinction Average. Do not use.
"""
inputs = ('x',)
outputs = ('axav',)

inputs = ("x",)
outputs = ("axav",)

def extinguish(self, x, Av=None, Ebv=None):
"""
Expand Down Expand Up @@ -92,23 +93,25 @@ def extinguish(self, x, Av=None, Ebv=None):

# check that av or ebv is set
if (Av is None) and (Ebv is None):
raise InputParameterError('neither Av or Ebv passed, one required')
raise InputParameterError("neither Av or Ebv passed, one required")

# if Av is not set and Ebv set, convert to Av
if Av is None:
Av = self.Rv*Ebv
Av = self.Rv * Ebv

# return fractional extinction
return np.power(10.0, -0.4*axav*Av)
return np.power(10.0, -0.4 * axav * Av)


class BaseExtRvModel(BaseExtModel):
"""
Base Extinction R(V)-dependent Model. Do not use.
"""
Rv = Parameter(description="R(V) = A(V)/E(B-V) = "
+ "total-to-selective extinction",
default=3.1)

Rv = Parameter(
description="R(V) = A(V)/E(B-V) = " + "total-to-selective extinction",
default=3.1,
)

@Rv.validator
def Rv(self, value):
Expand All @@ -126,22 +129,25 @@ def Rv(self, value):
Input Rv values outside of defined range
"""
if not (self.Rv_range[0] <= value <= self.Rv_range[1]):
raise InputParameterError("parameter Rv must be between "
+ str(self.Rv_range[0])
+ " and "
+ str(self.Rv_range[1]))
raise InputParameterError(
"parameter Rv must be between "
+ str(self.Rv_range[0])
+ " and "
+ str(self.Rv_range[1])
)


class BaseExtRvAfAModel(BaseExtModel):
"""
Base Extinction R(V)_A, f_A -dependent Model. Do not use.
"""

RvA = Parameter(description="R_A(V) = A(V)/E(B-V) = "
+ "total-to-selective extinction of component A",
default=3.1)
fA = Parameter(description="f_A = mixture coefficent of component A",
default=1.0)
RvA = Parameter(
description="R_A(V) = A(V)/E(B-V) = "
+ "total-to-selective extinction of component A",
default=3.1,
)
fA = Parameter(description="f_A = mixture coefficent of component A", default=1.0)

@RvA.validator
def RvA(self, value):
Expand All @@ -159,10 +165,12 @@ def RvA(self, value):
Input R_A(V) values outside of defined range
"""
if not (self.RvA_range[0] <= value <= self.RvA_range[1]):
raise InputParameterError("parameter RvA must be between "
+ str(self.RvA_range[0])
+ " and "
+ str(self.RvA_range[1]))
raise InputParameterError(
"parameter RvA must be between "
+ str(self.RvA_range[0])
+ " and "
+ str(self.RvA_range[1])
)

@fA.validator
def fA(self, value):
Expand All @@ -180,7 +188,9 @@ def fA(self, value):
Input fA values outside of defined range
"""
if not (self.fA_range[0] <= value <= self.fA_range[1]):
raise InputParameterError("parameter fA must be between "
+ str(self.fA_range[0])
+ " and "
+ str(self.fA_range[1]))
raise InputParameterError(
"parameter fA must be between "
+ str(self.fA_range[0])
+ " and "
+ str(self.fA_range[1])
)
10 changes: 4 additions & 6 deletions dust_extinction/parameter_averages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,16 +1217,14 @@ def evaluate(self, in_x, Rv):
"""
# convert to wavenumbers (1/micron) if x input in units
# otherwise, assume x in appropriate wavenumber units
with u.add_enabled_equivalencies(u.spectral()):
x_quant = u.Quantity(in_x, 1.0 / u.micron, dtype=np.float64)

# strip the quantity to avoid needing to add units to all the
# polynomical coefficients
x = x_quant.value
x = _get_x_in_wavenumbers(in_x)

# check that the wavenumbers are within the defined range
_test_valid_x_range(x, x_range_F20, "F20")

# just in case someone calls evaluate explicitly
Rv = np.atleast_1d(Rv)

# ensure Rv is a single element, not numpy array
Rv = Rv[0]

Expand Down
74 changes: 74 additions & 0 deletions dust_extinction/tests/test_f20.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import numpy as np
import pytest

import astropy.units as u

from ..parameter_averages import F20
from .helpers import _invalid_x_range


x_bad = [-1.0, 0.1, 12.0, 100.0]


@pytest.mark.parametrize("x_invalid", x_bad)
def test_invalid_wavenumbers(x_invalid):
_invalid_x_range(x_invalid, F20(), "F20")


@pytest.mark.parametrize("x_invalid_wavenumber", x_bad / u.micron)
def test_invalid_wavenumbers_imicron(x_invalid_wavenumber):
_invalid_x_range(x_invalid_wavenumber, F20(), "F20")


@pytest.mark.parametrize("x_invalid_micron", u.micron / x_bad)
def test_invalid_micron(x_invalid_micron):
_invalid_x_range(x_invalid_micron, F20(), "F20")


@pytest.mark.parametrize("x_invalid_angstrom", u.angstrom * 1e4 / x_bad)
def test_invalid_angstrom(x_invalid_angstrom):
_invalid_x_range(x_invalid_angstrom, F20(), "F20")


def get_axav_cor_vals():
# use x values from Fitzpatrick et al. (2000) Table 3
x = np.array([1.0, 1.5, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
cor_vals = np.array(
[-1.757, -0.629, 0.438, 2.090, 4.139, 5.704, 4.904, 5.684, 7.150]
)
tolerance = 2e-3

# convert from E(x-V)/E(B-V) to A(x)/A(V)
cor_vals = (cor_vals + 3.1) / 3.1

# add units
x = x / u.micron

return (x, cor_vals, tolerance)


def test_extinction_F20_values():
# get the correct values
x, cor_vals, tolerance = get_axav_cor_vals()

# initialize extinction model
tmodel = F20()

# test
np.testing.assert_allclose(tmodel(x), cor_vals, rtol=tolerance)


x_vals, axav_vals, tolerance = get_axav_cor_vals()
test_vals = zip(x_vals, axav_vals, np.full(len(x_vals), tolerance))


@pytest.mark.parametrize("xtest_vals", test_vals)
def test_extinction_F20_single_values(xtest_vals):
x, cor_val, tolerance = xtest_vals

# initialize extinction model
tmodel = F20()

# test
np.testing.assert_allclose(tmodel(x), cor_val, rtol=tolerance)
np.testing.assert_allclose(tmodel.evaluate(x, 3.1), cor_val, rtol=tolerance)
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ auto_use = True

[flake8]
exclude = extern,sphinx,*parsetab.py
ignore = E501, W503

[pycodestyle]
exclude = extern,sphinx,*parsetab.py
Expand Down

0 comments on commit af90d1d

Please sign in to comment.