Permalink
Browse files

Merge pull request #1685 from adonath/fix_issue_#1646

Add quantity support to MapCoord and MapCoord based methods
  • Loading branch information...
adonath committed Aug 9, 2018
2 parents fb23f44 + a02e0dc commit a58e939dea4a44a116722f6fb92e3bd3fcd44de9
Showing with 98 additions and 73 deletions.
  1. +0 −1 gammapy/maps/base.py
  2. +37 −38 gammapy/maps/geom.py
  3. +6 −2 gammapy/maps/tests/test_geom.py
  4. +32 −0 gammapy/maps/tests/test_wcsnd.py
  5. +22 −31 gammapy/maps/wcs.py
  6. +1 −1 gammapy/maps/wcsnd.py
View
@@ -621,7 +621,6 @@ def get_image_by_coord(self, coords, copy=True):
idx = []
for axis, value in zip(self.geom.axes, coords.values()):
value = u.Quantity(value).to(axis.unit).value
idx.append(axis.coord_to_idx(value))
return self.get_image_by_idx(idx, copy=copy)
View
@@ -211,7 +211,7 @@ def pix_tuple_to_idx(pix, copy=False):
"""
idx = []
for p in pix:
p = np.array(p, copy=copy, ndmin=1)
p = np.array(p, ndmin=1)
if np.issubdtype(p.dtype, np.integer):
idx += [p]
else:
@@ -281,7 +281,7 @@ def fn(t):
interp_fn = interp1d(
fn(edges),
np.arange(len(edges)).astype(float),
np.arange(len(edges), dtype=float),
fill_value='extrapolate',
)
@@ -311,7 +311,7 @@ def fn1(t):
raise ValueError('Invalid interp: {}'.format(interp))
interp_fn = interp1d(
np.arange(len(edges)).astype(float),
np.arange(len(edges), dtype=float),
fn0(edges),
fill_value='extrapolate',
)
@@ -373,10 +373,7 @@ def __init__(self, nodes, interp='lin', name='',
else:
raise ValueError('Invalid node type: {}'.format(node_type))
pix = np.arange(nbin, dtype=float)
self._center = self.pix_to_coord(pix)
pix = np.arange(nbin + 1, dtype=float) - 0.5
self._bin_edges = self.pix_to_coord(pix)
self._nbin = nbin
def __eq__(self, other):
if isinstance(other, self.__class__):
@@ -403,17 +400,19 @@ def name(self, val):
@property
def edges(self):
"""Return array of bin edges."""
return self._bin_edges
pix = np.arange(self.nbin + 1, dtype=float) - 0.5
return self.pix_to_coord(pix)
@property
def center(self):
"""Return array of bin centers."""
return self._center
pix = np.arange(self.nbin, dtype=float)
return self.pix_to_coord(pix)
@property
def nbin(self):
"""Return number of bins."""
return len(self._bin_edges) - 1
return self._nbin
@property
def node_type(self):
@@ -527,7 +526,7 @@ def pix_to_coord(self, pix):
Array of axis coordinate values.
"""
pix = pix - self._pix_offset
return pix_to_coord(self._nodes, pix, interp=self._interp)
return pix_to_coord(self._nodes, pix, interp=self._interp)
def coord_to_pix(self, coord):
"""Transform from axis to pixel coordinates.
@@ -542,6 +541,7 @@ def coord_to_pix(self, coord):
pix : `~numpy.ndarray`
Array of pixel coordinate values.
"""
coord = u.Quantity(coord, self.unit).value
pix = coord_to_pix(self._nodes, coord, interp=self._interp)
return np.array(pix + self._pix_offset, ndmin=1)
@@ -562,6 +562,7 @@ def coord_to_idx(self, coord, clip=False):
idx : `~numpy.ndarray`
Array of bin indices.
"""
coord = u.Quantity(coord, self.unit).value
return coord_to_idx(self.edges, coord, clip)
def coord_to_idx_interp(self, coord):
@@ -572,6 +573,7 @@ def coord_to_idx_interp(self, coord):
coord : `~numpy.ndarray`
Array of axis coordinate values.
"""
coord = u.Quantity(coord, self.unit).value
return (coord_to_idx(self.center[:-1], coord, clip=True),
coord_to_idx(self.center[:-1], coord, clip=True) + 1,)
@@ -625,25 +627,25 @@ class MapCoord(object):
coordsys : {'CEL', 'GAL', None}
Spatial coordinate system. If None then the coordinate system
will be set to the native coordinate system of the geometry.
copy : bool
Make copies of the input arrays?
If False then this object will store views.
match_by_name : bool
Match coordinates to axes by name?
If false coordinates will be matched by index.
"""
def __init__(self, data, coordsys=None, copy=False, match_by_name=True):
def __init__(self, data, coordsys=None, match_by_name=True):
if 'lon' not in data or 'lat' not in data:
raise ValueError("data dictionary must contain axes named 'lon' and 'lat'.")
self._data = OrderedDict([
(k, np.array(v, ndmin=1, copy=copy))
if issubclass(data['lon'].__class__, u.Quantity) or issubclass(data['lat'].__class__, u.Quantity):
raise ValueError('No quantities supported.')
data = OrderedDict([
(k, np.atleast_1d(np.asanyarray(v)))
for k, v in data.items()
])
vals = np.broadcast_arrays(*self._data.values())
self._data = OrderedDict(zip(self._data.keys(), vals))
vals = np.broadcast_arrays(*data.values(), subok=True)
self._data = OrderedDict(zip(data.keys(), vals))
self._coordsys = coordsys
self._match_by_name = match_by_name
@@ -698,7 +700,7 @@ def skycoord(self):
frame=coordsys_to_frame(self.coordsys))
@classmethod
def _from_lonlat(cls, coords, coordsys=None, copy=False):
def _from_lonlat(cls, coords, coordsys=None):
"""Create a `~MapCoord` from a tuple of coordinate vectors.
The first two elements of the tuple should be longitude and latitude in degrees.
@@ -721,11 +723,11 @@ def _from_lonlat(cls, coords, coordsys=None, copy=False):
else:
raise ValueError('Unrecognized input type.')
return cls(coords_dict, coordsys=coordsys, copy=copy,
return cls(coords_dict, coordsys=coordsys,
match_by_name=False)
@classmethod
def _from_skycoord(cls, coords, coordsys=None, copy=False):
def _from_skycoord(cls, coords, coordsys=None):
"""Create from vector of `~astropy.coordinates.SkyCoord`.
Parameters
@@ -741,10 +743,10 @@ def _from_skycoord(cls, coords, coordsys=None, copy=False):
skycoord = coords[0]
if skycoord.frame.name in ['icrs', 'fk5']:
coords = (skycoord.ra.deg, skycoord.dec.deg) + coords[1:]
coords = cls._from_lonlat(coords, coordsys='CEL', copy=copy)
coords = cls._from_lonlat(coords, coordsys='CEL')
elif skycoord.frame.name in ['galactic']:
coords = (skycoord.l.deg, skycoord.b.deg) + coords[1:]
coords = cls._from_lonlat(coords, coordsys='GAL', copy=copy)
coords = cls._from_lonlat(coords, coordsys='GAL')
else:
raise ValueError(
'Unrecognized coordinate frame: {}'.format(skycoord.frame.name))
@@ -755,20 +757,20 @@ def _from_skycoord(cls, coords, coordsys=None, copy=False):
return coords.to_coordsys(coordsys)
@classmethod
def _from_tuple(cls, coords, coordsys=None, copy=False):
def _from_tuple(cls, coords, coordsys=None):
"""Create from tuple of coordinate vectors."""
if isinstance(coords[0], (list, np.ndarray)) or np.isscalar(coords[0]):
return cls._from_lonlat(coords, coordsys=coordsys, copy=copy)
return cls._from_lonlat(coords, coordsys=coordsys)
elif isinstance(coords[0], SkyCoord):
return cls._from_skycoord(coords, coordsys=coordsys, copy=copy)
return cls._from_skycoord(coords, coordsys=coordsys)
else:
raise TypeError('Type not supported: {}'.format(type(coords)))
@classmethod
def _from_dict(cls, coords, coordsys=None, copy=False):
def _from_dict(cls, coords, coordsys=None):
"""Create from a dictionary of coordinate vectors."""
if 'lon' in coords and 'lat' in coords:
return cls(coords, coordsys=coordsys, copy=copy)
return cls(coords, coordsys=coordsys)
elif 'skycoord' in coords:
coords_dict = OrderedDict()
lon, lat, frame = skycoord_to_lonlat(
@@ -779,13 +781,13 @@ def _from_dict(cls, coords, coordsys=None, copy=False):
if k == 'skycoord':
continue
coords_dict[k] = v
return cls(coords_dict, coordsys=coordsys, copy=copy)
return cls(coords_dict, coordsys=coordsys)
else:
raise ValueError("Dictionary must contain axes named 'lon'/'lat'"
"or 'skycoord'.")
@classmethod
def create(cls, data, coordsys=None, copy=False):
def create(cls, data, coordsys=None):
"""Create a new `~MapCoord` object.
This method can be used to create either unnamed (with tuple input)
@@ -796,12 +798,9 @@ def create(cls, data, coordsys=None, copy=False):
data : `tuple`, `dict`, `~MapCoord` or `~astropy.coordinates.SkyCoord`
Object containing coordinate arrays.
coordsys : {'CEL', 'GAL', None}, optional
Set the coordinate system for longitude and latitude. If
Set the coordinate system for longitude and latitude. If
None longitude and latitude will be assumed to be in
the coordinate system native to a given map geometry.
copy : bool
Make copies of the input coordinate arrays. If False this
object will store views.
Examples
--------
@@ -824,11 +823,11 @@ def create(cls, data, coordsys=None, copy=False):
else:
return data.to_coordsys(coordsys)
elif isinstance(data, dict):
return cls._from_dict(data, coordsys=coordsys, copy=copy)
return cls._from_dict(data, coordsys=coordsys)
elif isinstance(data, (list, tuple)):
return cls._from_tuple(data, coordsys=coordsys, copy=copy)
return cls._from_tuple(data, coordsys=coordsys)
elif isinstance(data, SkyCoord):
return cls._from_skycoord((data,), coordsys=coordsys, copy=copy)
return cls._from_skycoord((data,), coordsys=coordsys)
else:
raise TypeError('Unsupported input type: {}'.format(type(data)))
@@ -4,6 +4,7 @@
from collections import OrderedDict
import numpy as np
from numpy.testing import assert_allclose
from astropy import units as u
from astropy.coordinates import SkyCoord
from ..geom import MapAxis, MapCoord
@@ -175,8 +176,7 @@ def test_mapcoords_create():
assert coords.ndim == 3
# 3D OrderedDict w/ vectors
coords = MapCoord.create(OrderedDict([('energy', energy),
('lat', lat), ('lon', lon)]))
coords = MapCoord.create(dict(energy=energy, lat=lat, lon=lon))
assert_allclose(coords.lon, lon)
assert_allclose(coords.lat, lat)
assert_allclose(coords['energy'], energy)
@@ -185,6 +185,10 @@ def test_mapcoords_create():
assert_allclose(coords[2], lon)
assert coords.ndim == 3
# Quantities
coords = MapCoord.create(dict(energy=energy * u.TeV, lat=lat, lon=lon))
assert coords['energy'].unit == 'TeV'
def test_mapcoords_to_coordsys():
lon, lat = np.array([0.0, 1.0]), np.array([2.0, 3.0])
@@ -192,6 +192,22 @@ def test_wcsndmap_set_get_by_coord(npix, binsz, coordsys, proj, skydir, axes):
assert_allclose(coords[0], m.get_by_coord(map_coords))
def test_set_get_by_coord_quantities():
ax = MapAxis(np.logspace(0., 3., 3), interp='log', name='energy', unit='TeV')
geom = WcsGeom.create(binsz=0.1, npix=(3, 4), axes=[ax])
m = WcsNDMap(geom)
coords_dict = {
'lon': 0,
'lat': 0,
'energy': 1000 * u.GeV
}
m.set_by_coord(coords_dict, 42)
coords_dict['energy'] = 1 * u.TeV
assert_allclose(42, m.get_by_coord(coords_dict))
@pytest.mark.parametrize(('npix', 'binsz', 'coordsys', 'proj', 'skydir', 'axes'),
wcs_test_geoms)
def test_wcsndmap_fill_by_coord(npix, binsz, coordsys, proj, skydir, axes):
@@ -247,6 +263,22 @@ def test_wcsndmap_interp_by_coord(npix, binsz, coordsys, proj, skydir, axes):
assert_allclose(coords[1], m.interp_by_coord(coords, interp='cubic'))
def test_interp_by_coord_quantities():
ax = MapAxis(np.logspace(0., 3., 3), interp='log', name='energy', unit='TeV')
geom = WcsGeom.create(binsz=0.1, npix=(3, 3), axes=[ax])
m = WcsNDMap(geom)
coords_dict = {
'lon': 0,
'lat': 0,
'energy': 1000 * u.GeV
}
m.set_by_coord(coords_dict, 42)
coords_dict['energy'] = 1 * u.TeV
assert_allclose(42, m.interp_by_coord(coords_dict, interp='nearest'))
def test_wcsndmap_interp_by_coord_fill_value():
# Introduced in https://github.com/gammapy/gammapy/pull/1559/files
m = Map.create(npix=(20, 10))
Oops, something went wrong.

0 comments on commit a58e939

Please sign in to comment.