Skip to content

Commit

Permalink
Merge pull request #24 from karllark/setup_fitable
Browse files Browse the repository at this point in the history
Setup FM90 fittable
  • Loading branch information
karllark committed Sep 18, 2017
2 parents 6750d93 + 46496bb commit eb0e8ae
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 51 deletions.
57 changes: 14 additions & 43 deletions ah_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,24 @@
use_setuptools()


# typing as a dependency for 1.6.1+ Sphinx causes issues when imported after
# initializing submodule with ah_boostrap.py
# See discussion and references in
# https://github.com/astropy/astropy-helpers/issues/302

try:
import typing # noqa
except ImportError:
pass


# Note: The following import is required as a workaround to
# https://github.com/astropy/astropy-helpers/issues/89; if we don't import this
# module now, it will get cleaned up after `run_setup` is called, but that will
# later cause the TemporaryDirectory class defined in it to stop working when
# used later on by setuptools
try:
import setuptools.py31compat
import setuptools.py31compat # noqa
except ImportError:
pass

Expand Down Expand Up @@ -702,7 +713,7 @@ def _update_submodule(self, submodule, status):
if self.offline:
cmd.append('--no-fetch')
elif status == 'U':
raise _AHBoostrapSystemExit(
raise _AHBootstrapSystemExit(
'Error: Submodule {0} contains unresolved merge conflicts. '
'Please complete or abandon any changes in the submodule so that '
'it is in a usable state, then try again.'.format(submodule))
Expand Down Expand Up @@ -763,7 +774,7 @@ def run_cmd(cmd):
msg = 'Command not found: `{0}`'.format(' '.join(cmd))
raise _CommandNotFound(msg, cmd)
else:
raise _AHBoostrapSystemExit(
raise _AHBootstrapSystemExit(
'An unexpected error occurred when running the '
'`{0}` command:\n{1}'.format(' '.join(cmd), str(e)))

Expand Down Expand Up @@ -878,46 +889,6 @@ def __init__(self, *args):
super(_AHBootstrapSystemExit, self).__init__(msg, *args[1:])


if sys.version_info[:2] < (2, 7):
# In Python 2.6 the distutils log does not log warnings, errors, etc. to
# stderr so we have to wrap it to ensure consistency at least in this
# module
import distutils

class log(object):
def __getattr__(self, attr):
return getattr(distutils.log, attr)

def warn(self, msg, *args):
self._log_to_stderr(distutils.log.WARN, msg, *args)

def error(self, msg):
self._log_to_stderr(distutils.log.ERROR, msg, *args)

def fatal(self, msg):
self._log_to_stderr(distutils.log.FATAL, msg, *args)

def log(self, level, msg, *args):
if level in (distutils.log.WARN, distutils.log.ERROR,
distutils.log.FATAL):
self._log_to_stderr(level, msg, *args)
else:
distutils.log.log(level, msg, *args)

def _log_to_stderr(self, level, msg, *args):
# This is the only truly 'public' way to get the current threshold
# of the log
current_threshold = distutils.log.set_threshold(distutils.log.WARN)
distutils.log.set_threshold(current_threshold)
if level >= current_threshold:
if args:
msg = msg % args
sys.stderr.write('%s\n' % msg)
sys.stderr.flush()

log = log()


BOOTSTRAPPER = _Bootstrapper.main()


Expand Down
93 changes: 87 additions & 6 deletions dust_extinction/dust_extinction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from scipy import interpolate

import astropy.units as u
from astropy.modeling import Model, Parameter, InputParameterError
from astropy.modeling import (Model, Fittable1DModel,
Parameter, InputParameterError)

__all__ = ['BaseExtModel','BaseExtRvModel',
__all__ = ['BaseExtModel','BaseExtRvModel', 'BaseExtAve',
'CCM89', 'FM90', 'F99',
'G03_SMCBar', 'G03_LMCAvg', 'G03_LMC2',
'G16']
Expand Down Expand Up @@ -211,6 +212,52 @@ def extinguish(self, x, Av=None, Ebv=None):
# return fractional extinction
return np.power(10.0,-0.4*axav*Av)

class BaseExtAve(Model):
"""
Base Extinction Average. Do not use.
"""
inputs = ('x',)
outputs = ('axav',)

def extinguish(self, x, Av=None, Ebv=None):
"""
Calculate the extinction as a fraction
Parameters
----------
x: float
expects either x in units of wavelengths or frequency
or assumes wavelengths in wavenumbers [1/micron]
internally wavenumbers are used
Av: float
A(V) value of dust column
Av or Ebv must be set
Ebv: float
E(B-V) value of dust column
Av or Ebv must be set
Returns
-------
frac_ext: np array (float)
fractional extinction as a function of x
"""
# get the extinction curve
axav = self(x)

# check that av or ebv is set
if (Av is None) and (Ebv is None):
raise InputParameterError('neither Av or Ebv passed, one required')

# if Av is not set and Ebv set, convert to Av
if Av is None:
Av = self.Rv*Ebv

# return fractional extinction
return np.power(10.0,-0.4*axav*Av)

class BaseExtRvModel(BaseExtModel):
"""
Base Extinction R(V)-dependent Model. Do not use.
Expand Down Expand Up @@ -371,7 +418,7 @@ def evaluate(in_x, Rv):
# return A(x)/A(V)
return a + b/Rv

class FM90(Model):
class FM90(Fittable1DModel):
"""
FM90 extinction model calculation
Expand Down Expand Up @@ -506,6 +553,40 @@ def evaluate(in_x, C1, C2, C3, C4, xo, gamma):
# return E(x-V)/E(B-V)
return exvebv

@staticmethod
def fit_deriv(in_x, C1, C2, C3, C4, xo, gamma):
"""
Derivatives of the FM90 function with respect to the parameters
"""
x = in_x

# useful quantitites
x2 = x**2
xo2 = xo**2
g2 = gamma**2
x2mxo2_2 = (x2 - xo2)**2
denom = (x2mxo2_2 - x2*g2)**2

# derivatives
d_C1 = np.full((len(x)),1.)
d_C2 = x

d_C3 = (x2/(x2mxo2_2 + x2*g2))

d_xo = (4.*C2*x2*xo*(x2 - xo2))/denom

d_gamma = (2.*C2*(x2**2)*gamma)/denom

d_C4 = np.zeros((len(x)))
fuv_indxs = np.where(x >= 5.9)
if len(fuv_indxs) > 0:
y = x[fuv_indxs] - 5.9
d_C4[fuv_indxs] = (0.5392*(y**2) + 0.05644*(y**3))

return [d_C1, d_C2, d_C3, d_C4, d_xo, d_gamma]

#fit_deriv = None

class F99(BaseExtRvModel):
"""
F99 extinction model calculation
Expand Down Expand Up @@ -616,7 +697,7 @@ def evaluate(self, in_x, Rv):
optnir_axav_x, optnir_axebv_y/Rv,
self.x_range, 'F99')

class G03_SMCBar(BaseExtModel):
class G03_SMCBar(BaseExtAve):
"""
G03 SMCBar Average Extinction Curve
Expand Down Expand Up @@ -728,7 +809,7 @@ def evaluate(self, in_x):
optnir_axav_x, optnir_axav_y,
self.x_range, 'G03')

class G03_LMCAvg(BaseExtModel):
class G03_LMCAvg(BaseExtAve):
"""
G03 LMCAvg Average Extinction Curve
Expand Down Expand Up @@ -839,7 +920,7 @@ def evaluate(self, in_x):
self.x_range, 'G03')


class G03_LMC2(BaseExtModel):
class G03_LMC2(BaseExtAve):
"""
G03 LMC2 Average Extinction Curve
Expand Down
26 changes: 25 additions & 1 deletion dust_extinction/tests/test_fm90.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

import astropy.units as u
from astropy.modeling import InputParameterError
from astropy.modeling.fitting import LevMarLSQFitter

from ..dust_extinction import FM90
from ..dust_extinction import FM90, G03_LMCAvg

x_bad = [-1.0, 0.2, 3.0, 11.0, 100.]
@pytest.mark.parametrize("x_invalid", x_bad)
Expand Down Expand Up @@ -78,3 +79,26 @@ def test_extinction_FM90_values():

# test
np.testing.assert_allclose(tmodel(x), cor_vals)


def test_FM90_fitting():

# get an observed extinction curve to fit
g03_model = G03_LMCAvg()

x = g03_model.obsdata_x
# convert to E(x-V)/E(B0V)
y = (g03_model.obsdata_axav - 1.0)*g03_model.Rv
# only fit the UV portion (FM90 only valid in UV)
gindxs, = np.where(x > 3.125)

fm90_init = FM90()
fit = LevMarLSQFitter()
g03_fit = fit(fm90_init, x[gindxs], y[gindxs])
fit_vals = [g03_fit.C1.value, g03_fit.C2.value, g03_fit.C3.value,
g03_fit.C4.value, g03_fit.xo.value, g03_fit.gamma.value]

good_vals = np.array([-0.958016797002, 1.0109751831, 2.96430606652,
0.313137860902, 4.59996300532, 0.99000982258])

np.testing.assert_allclose(good_vals, fit_vals)
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ doctest_plus = enabled
[ah_bootstrap]
auto_use = True

[pycodestyle]
[pep8]
#[pycodestyle]
# E101 - mix of tabs and spaces
# W191 - use of tabs
# W291 - trailing whitespace
Expand Down

0 comments on commit eb0e8ae

Please sign in to comment.