Skip to content

Commit

Permalink
Moved wrapper classes into .wrappers sub-module.
Browse files Browse the repository at this point in the history
Also move the array_index-related methods to the base classes since they can be
effectively be the same for all implementations
  • Loading branch information
astrofrog authored and Cadair committed May 4, 2020
1 parent b0b5d7e commit ec55022
Show file tree
Hide file tree
Showing 14 changed files with 473 additions and 321 deletions.
6 changes: 0 additions & 6 deletions astropy/visualization/wcsaxes/tests/test_wcsapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,16 +381,10 @@ def pixel_to_world_values(self, *pixel_arrays):
pixel_arrays = (list(pixel_arrays) * 3)[:-1] # make list have 5 elements
return [np.asarray(pix) * scale for pix, scale in zip(pixel_arrays, [10, 0.2, 0.4, 0.39, 2])]

def array_index_to_world_values(self, *index_arrays):
return self.pixel_to_world_values(index_arrays[::-1])[::-1]

def world_to_pixel_values(self, *world_arrays):
world_arrays = world_arrays[:2] # make list have 2 elements
return [np.asarray(world) / scale for world, scale in zip(world_arrays, [10, 0.2])]

def world_to_array_index_values(self, *world_arrays):
return np.round(self.world_to_array_index_values(world_arrays[::-1])[::-1]).astype(int)

@property
def world_axis_object_components(self):
return [('freq', 0, 'value'),
Expand Down
10 changes: 5 additions & 5 deletions astropy/wcs/wcsapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .low_level_api import *
from .high_level_api import *
from .high_level_wcs_wrapper import *
from .utils import *
from .sliced_low_level_wcs import *
from .low_level_api import * # noqa
from .high_level_api import * # noqa
from .high_level_wcs_wrapper import * # noqa
from .utils import * # noqa
from .wrappers import * # noqa
175 changes: 175 additions & 0 deletions astropy/wcs/wcsapi/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import pytest
import numpy as np

from astropy.coordinates import SkyCoord
from astropy.units import Quantity
from astropy.wcs import WCS
from astropy.wcs.wcsapi import BaseLowLevelWCS


@pytest.fixture
def spectral_1d_fitswcs():
wcs = WCS(naxis=1)
wcs.wcs.ctype = 'FREQ',
wcs.wcs.cunit = 'Hz',
wcs.wcs.cdelt = 3.e9,
wcs.wcs.crval = 4.e9,
wcs.wcs.crpix = 11.,
wcs.wcs.cname = 'Frequency',
return wcs


@pytest.fixture
def time_1d_fitswcs():
wcs = WCS(naxis=1)
wcs.wcs.ctype = 'TIME',
wcs.wcs.mjdref = 30042,
wcs.wcs.crval = 3.,
wcs.wcs.crpix = 11.,
wcs.wcs.cname = 'Frequency',
return wcs


@pytest.fixture
def celestial_2d_fitswcs():
wcs = WCS(naxis=2)
wcs.wcs.ctype = 'RA---CAR', 'DEC--CAR'
wcs.wcs.cunit = 'deg', 'deg'
wcs.wcs.cdelt = -2., 2.
wcs.wcs.crval = 4., 0.
wcs.wcs.crpix = 6., 7.
wcs.wcs.cname = 'Right Ascension', 'Declination'
wcs.pixel_shape = (6, 7)
wcs.pixel_bounds = [(-1, 5), (1, 7)]
return wcs


@pytest.fixture
def spectral_cube_3d_fitswcs():
wcs = WCS(naxis=3)
wcs.wcs.ctype = 'RA---CAR', 'DEC--CAR', 'FREQ'
wcs.wcs.cunit = 'deg', 'deg', 'Hz'
wcs.wcs.cdelt = -2., 2., 3.e9
wcs.wcs.crval = 4., 0., 4.e9
wcs.wcs.crpix = 6., 7., 11.
wcs.wcs.cname = 'Right Ascension', 'Declination', 'Frequency'
wcs.pixel_shape = (6, 7, 3)
wcs.pixel_bounds = [(-1, 5), (1, 7), (1, 2.5)]
return wcs



class Spectral1DLowLevelWCS(BaseLowLevelWCS):

@property
def pixel_n_dim(self):
return 1

@property
def world_n_dim(self):
return 1

@property
def world_axis_physical_types(self):
return 'em.freq',

@property
def world_axis_units(self):
return 'Hz',

@property
def world_axis_names(self):
return 'Frequency',

_pixel_shape = None

@property
def pixel_shape(self):
return self._pixel_shape

@pixel_shape.setter
def pixel_shape(self, value):
self._pixel_shape = value

_pixel_bounds = None

@property
def pixel_bounds(self):
return self._pixel_bounds

@pixel_bounds.setter
def pixel_bounds(self, value):
self._pixel_bounds = value

def pixel_to_world_values(self, pixel_array):
return np.asarray(pixel_array - 10) * 3e9 + 4e9

def world_to_pixel_values(self, world_array):
return np.asarray(world_array - 4e9) / 3e9 + 10

@property
def world_axis_object_components(self):
return ('test', 0, 'value'),

@property
def world_axis_object_classes(self):
return {'test': (Quantity, (), {'unit': 'Hz'})}



@pytest.fixture
def spectral_1d_ape14_wcs():
return Spectral1DLowLevelWCS()


class Celestial2DLowLevelWCS(BaseLowLevelWCS):

@property
def pixel_n_dim(self):
return 2

@property
def world_n_dim(self):
return 2

@property
def world_axis_physical_types(self):
return 'pos.eq.ra', 'pos.eq.dec'

@property
def world_axis_units(self):
return 'deg', 'deg'

@property
def world_axis_names(self):
return 'Right Ascension', 'Declination'

@property
def pixel_shape(self):
return (6, 7)

@property
def pixel_bounds(self):
return (-1, 5), (1, 7)

def pixel_to_world_values(self, px, py):
return (-(np.asarray(px) - 5.) * 2 + 4.,
(np.asarray(py) - 6.) * 2)

def world_to_pixel_values(self, wx, wy):
return (-(np.asarray(wx) - 4.) / 2 + 5.,
np.asarray(wy) / 2 + 6.)

@property
def world_axis_object_components(self):
return [('test', 0, 'spherical.lon.degree'),
('test', 1, 'spherical.lat.degree')]

@property
def world_axis_object_classes(self):
return {'test': (SkyCoord, (), {'unit': 'deg'})}


@pytest.fixture
def celestial_2d_ape14_wcs():
return Celestial2DLowLevelWCS()
23 changes: 5 additions & 18 deletions astropy/wcs/wcsapi/fitswcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from .low_level_api import BaseLowLevelWCS
from .high_level_api import HighLevelWCSMixin
from .sliced_low_level_wcs import SlicedLowLevelWCS
from .wrappers import SlicedLowLevelWCS

__all__ = ['custom_ctype_to_ucd_mapping', 'SlicedFITSWCS', 'FITSWCSAPIMixin']

Expand Down Expand Up @@ -200,21 +200,17 @@ def world_n_dim(self):

@property
def array_shape(self):
if self._naxis == [0, 0]:
if self.pixel_shape is None:
return None
else:
return tuple(self._naxis[::-1])
return self.pixel_shape[::-1]

@array_shape.setter
def array_shape(self, value):
if value is None:
self._naxis = [0, 0]
self.pixel_shape = None
else:
if len(value) != self.naxis:
raise ValueError("The number of data axes, "
"{}, does not equal the "
"shape {}.".format(self.naxis, len(value)))
self._naxis = list(value)[::-1]
self.pixel_shape = value[::-1]

@property
def pixel_shape(self):
Expand Down Expand Up @@ -317,19 +313,10 @@ def pixel_to_world_values(self, *pixel_arrays):
world = self.all_pix2world(*pixel_arrays, 0)
return world[0] if self.world_n_dim == 1 else tuple(world)

def array_index_to_world_values(self, *indices):
world = self.all_pix2world(*indices[::-1], 0)
return world[0] if self.world_n_dim == 1 else tuple(world)

def world_to_pixel_values(self, *world_arrays):
pixel = self.all_world2pix(*world_arrays, 0)
return pixel[0] if self.pixel_n_dim == 1 else tuple(pixel)

def world_to_array_index_values(self, *world_arrays):
pixel_arrays = self.all_world2pix(*world_arrays, 0)[::-1]
array_indices = tuple(np.asarray(np.floor(pixel + 0.5), dtype=np.int_) for pixel in pixel_arrays)
return array_indices[0] if self.pixel_n_dim == 1 else array_indices

@property
def world_axis_object_components(self):
return self._get_components_and_classes()[0]
Expand Down
16 changes: 5 additions & 11 deletions astropy/wcs/wcsapi/high_level_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def pixel_to_world(self, *pixel_arrays):
indexing and ordering conventions.
"""

@abc.abstractmethod
def array_index_to_world(self, *index_arrays):
"""
Convert array indices to world coordinates (represented by Astropy
Expand All @@ -64,6 +63,7 @@ def array_index_to_world(self, *index_arrays):
`~astropy.wcs.wcsapi.BaseLowLevelWCS.array_index_to_world_values` for
pixel indexing and ordering conventions.
"""
return self.pixel_to_world(*index_arrays[::-1])

@abc.abstractmethod
def world_to_pixel(self, *world_objects):
Expand All @@ -78,7 +78,6 @@ def world_to_pixel(self, *world_objects):
indexing and ordering conventions.
"""

@abc.abstractmethod
def world_to_array_index(self, *world_objects):
"""
Convert world coordinates (represented by Astropy objects) to array
Expand All @@ -91,6 +90,10 @@ def world_to_array_index(self, *world_objects):
pixel indexing and ordering conventions. The indices should be returned
as rounded integers.
"""
if self.pixel_n_dim == 1:
return np.round(self.world_to_pixel(*world_objects)).astype(int)
else:
return tuple(np.round(self.world_to_pixel(*world_objects)[::-1]).astype(int).tolist())


class HighLevelWCSMixin(BaseHighLevelWCS):
Expand Down Expand Up @@ -255,12 +258,3 @@ def pixel_to_world(self, *pixel_arrays):
return result[0]
else:
return result

def array_index_to_world(self, *index_arrays):
return self.pixel_to_world(*index_arrays[::-1])

def world_to_array_index(self, *world_objects):
if self.pixel_n_dim == 1:
return np.round(self.world_to_pixel(*world_objects)).astype(int)
else:
return tuple(np.round(self.world_to_pixel(*world_objects)[::-1]).astype(int).tolist())
15 changes: 12 additions & 3 deletions astropy/wcs/wcsapi/low_level_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def pixel_to_world_values(self, *pixel_arrays):
arrays is returned.
"""

@abc.abstractmethod
def array_index_to_world_values(self, *index_arrays):
"""
Convert array indices to world coordinates.
Expand All @@ -88,6 +87,7 @@ def array_index_to_world_values(self, *index_arrays):
method returns a single scalar or array, otherwise a tuple of scalars or
arrays is returned.
"""
return self.pixel_to_world_values(*index_arrays[::-1])

@abc.abstractmethod
def world_to_pixel_values(self, *world_arrays):
Expand All @@ -108,7 +108,6 @@ def world_to_pixel_values(self, *world_arrays):
arrays is returned.
"""

@abc.abstractmethod
def world_to_array_index_values(self, *world_arrays):
"""
Convert world coordinates to array indices.
Expand All @@ -123,6 +122,13 @@ def world_to_array_index_values(self, *world_arrays):
method returns a single scalar or array, otherwise a tuple of scalars or
arrays is returned.
"""
pixel_arrays = self.world_to_pixel_values(*world_arrays)
if self.pixel_n_dim == 1:
pixel_arrays = (pixel_arrays,)
else:
pixel_arrays = pixel_arrays[::-1]
array_indices = tuple(np.asarray(np.floor(pixel + 0.5), dtype=np.int_) for pixel in pixel_arrays)
return array_indices[0] if self.pixel_n_dim == 1 else array_indices

@property
@abc.abstractmethod
Expand Down Expand Up @@ -229,7 +235,10 @@ def array_shape(self):
objects. This is an optional property, and it should return `None`
if a shape is not known or relevant.
"""
return None
if self.pixel_shape is None:
return None
else:
return self.pixel_shape[::-1]

@property
def pixel_shape(self):
Expand Down

0 comments on commit ec55022

Please sign in to comment.