Skip to content

Commit

Permalink
Merge pull request #91 from adl1995/HipsDrawResult
Browse files Browse the repository at this point in the history
Introduce class HipsDrawResult in simple.py
  • Loading branch information
cdeil committed Jul 26, 2017
2 parents a357718 + 821f335 commit 1a2c763
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 21 deletions.
24 changes: 18 additions & 6 deletions docs/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,34 @@ To make a sky image with the `hips` package, follow the following three steps:


3. Call the `~hips.make_sky_image` function to fetch the HiPS data
and draw it, returning the sky image pixel data as a Numpy array::
and draw it, returning an object of `~hips.HipsDrawResult`::

from hips import make_sky_image

data = make_sky_image(geometry, hips_survey, 'fits')
result = make_sky_image(geometry, hips_survey, 'fits')


That's it. Go ahead and try it out for your favourite sky region and survey.

Now you can then save the sky image to local disk e.g. FITS file format::

from astropy.io import fits
hdu = fits.PrimaryHDU(data=data, header=geometry.fits_header)
hdu.writeto('my_image.fits')
result.write_image('my_image.fits')

or plot and analyse the sky image however you like.
The ``result`` object also contains other useful information, such as::

result.image

will return a NumPy array containing pixel data, you can also get the WCS information using::

result.geometry

If you want, you could also print out information about the ``result``::

print(result)

or plot and analyse the sky image using::

result.plot()

If you execute the example above, you will get this sky image which was plotted using `astropy.visualization.wcsaxes`

Expand Down
6 changes: 3 additions & 3 deletions docs/plot_fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
width=2000, height=1000, fov="3 deg",
coordsys='galactic', projection='AIT',
)
image = make_sky_image(geometry=geometry, hips_survey=hips_survey, tile_format='fits')
result = make_sky_image(geometry=geometry, hips_survey=hips_survey, tile_format='fits')

# Draw the sky image
import matplotlib.pyplot as plt
from astropy.visualization.mpl_normalize import simple_norm
ax = plt.subplot(projection=geometry.wcs)
norm = simple_norm(image, 'sqrt', min_percent=1, max_percent=99)
ax.imshow(image, origin='lower', norm=norm, cmap='gray')
norm = simple_norm(result.image, 'sqrt', min_percent=1, max_percent=99)
ax.imshow(result.image, origin='lower', norm=norm, cmap='gray')
4 changes: 2 additions & 2 deletions docs/plot_jpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
width=2000, height=1000, fov="3 deg",
coordsys='galactic', projection='AIT',
)
image = make_sky_image(geometry=geometry, hips_survey=hips_survey, tile_format='jpg')
result = make_sky_image(geometry=geometry, hips_survey=hips_survey, tile_format='jpg')

# Draw the sky image
import matplotlib.pyplot as plt
ax = plt.subplot(projection=geometry.wcs)
ax.imshow(image, origin='lower')
ax.imshow(result.image, origin='lower')
76 changes: 74 additions & 2 deletions hips/draw/simple.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""HiPS tile drawing -- simple method."""
import numpy as np
from PIL import Image
from astropy.io import fits
from typing import List, Tuple
from astropy.wcs.utils import proj_plane_pixel_scales
from skimage.transform import ProjectiveTransform, warp
Expand Down Expand Up @@ -122,6 +124,11 @@ def tiles(self) -> List[HipsTile]:

return self._tiles

@property
def result(self) -> 'HipsDrawResult':
"""Return an object of `~hips.HipsDrawResult` class."""
return HipsDrawResult(self.image, self.geometry, self.tile_format, self.tiles)

def warp_image(self, tile: HipsTile) -> np.ndarray:
"""Warp a HiPS tile and a sky image."""
return warp(
Expand Down Expand Up @@ -169,6 +176,71 @@ def plot_mpl_hips_tile_grid(self) -> None:
ax.imshow(self.image, origin='lower')


class HipsDrawResult:
"""Container class for reporting information related with fetching / drawing of HiPS tiles.
Parameters
----------
image: `~numpy.ndarray`
Container for HiPS tile data
geometry : `~hips.utils.WCSGeometry`
An object of WCSGeometry
tile_format : {'fits', 'jpg', 'png'}
Format of HiPS tile
tiles: List[HipsTile]
"""

def __init__(self, image: np.ndarray, geometry: WCSGeometry, tile_format: str, tiles: List[HipsTile]) -> None:
self.image = image
self.geometry = geometry
self.tile_format = tile_format
self.tiles = tiles

def __str__(self):
return (
'HiPS draw result:\n'
f'Sky image: shape={self.image.shape}, dtype={self.image.dtype}\n'
f'WCS geometry: {self.geometry}\n'
)

def __repr__(self):
return (
'HipsDrawResult('
f'width={self.image.shape[0]}, '
f'height={self.image.shape[1]}, '
f'channels={self.image.ndim}, '
f'dtype={self.image.dtype}, '
f'format={self.tile_format}'
')'
)

def write_image(self, filename: str) -> None:
"""Write image to file.
Parameters
----------
filename : str
Filename
"""
if self.tile_format == 'fits':
hdu = fits.PrimaryHDU(data=self.image, header=self.geometry.fits_header)
hdu.writeto(filename)
else:
image = Image.fromarray(self.image)
image.save(filename)

def plot(self) -> None:
"""Plot the all sky image using `astropy.visualization.wcsaxes` and showing the HEALPix grid."""
import matplotlib.pyplot as plt
for tile in self.tiles:
corners = tile.meta.skycoord_corners.transform_to(self.geometry.celestial_frame)
ax = plt.subplot(projection=self.geometry.wcs)
opts = dict(color='red', lw=1, )
ax.plot(corners.data.lon.deg, corners.data.lat.deg,
transform=ax.get_transform('world'), **opts)
ax.imshow(self.image, origin='lower')


def measure_tile_shape(corners: tuple) -> Tuple[List[float]]:
"""Compute length of tile edges and diagonals."""
x, y = corners
Expand Down Expand Up @@ -251,7 +323,7 @@ def plot_mpl_single_tile(geometry: WCSGeometry, tile: HipsTile, image: np.ndarra
ax.imshow(image, origin='lower')


def make_sky_image(geometry: WCSGeometry, hips_survey: HipsSurveyProperties, tile_format: str) -> np.ndarray:
def make_sky_image(geometry: WCSGeometry, hips_survey: HipsSurveyProperties, tile_format: str) -> 'HipsDrawResult':
"""Make sky image: fetch tiles and draw.
The example for this can be found on the :ref:`gs` page.
Expand All @@ -274,4 +346,4 @@ def make_sky_image(geometry: WCSGeometry, hips_survey: HipsSurveyProperties, til
painter = SimpleTilePainter(geometry, hips_survey, tile_format)
painter.run()

return painter.image
return painter.result
21 changes: 13 additions & 8 deletions hips/draw/tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
data_2=2296,
data_sum=8756493140,
dtype='>i2',
repr='HipsDrawResult(width=1000, height=2000, channels=2, dtype=>i2, format=fits)'
),
dict(
file_format='jpg',
Expand All @@ -28,6 +29,7 @@
data_2=[137, 116, 114],
data_sum=828908873,
dtype='uint8',
repr='HipsDrawResult(width=1000, height=2000, channels=3, dtype=uint8, format=jpg)'
),
dict(
file_format='png',
Expand All @@ -37,22 +39,25 @@
data_2=[227, 217, 205, 255],
data_sum=1635622838,
dtype='uint8',
repr='HipsDrawResult(width=1000, height=2000, channels=3, dtype=uint8, format=png)'
),
]


@remote_data
@pytest.mark.parametrize('pars', make_sky_image_pars)
def test_make_sky_image(pars):
def test_make_sky_image(tmpdir, pars):
hips_survey = HipsSurveyProperties.fetch(url=pars['url'])
geometry = make_test_wcs_geometry()
image = make_sky_image(geometry=geometry, hips_survey=hips_survey, tile_format=pars['file_format'])
assert image.shape == pars['shape']
assert image.dtype == pars['dtype']
assert_allclose(np.sum(image), pars['data_sum'])
assert_allclose(image[200, 994], pars['data_1'])
assert_allclose(image[200, 995], pars['data_2'])

result = make_sky_image(geometry=geometry, hips_survey=hips_survey, tile_format=pars['file_format'])
assert result.image.shape == pars['shape']
assert result.image.dtype == pars['dtype']
assert repr(result) == pars['repr']
assert_allclose(np.sum(result.image), pars['data_sum'])
assert_allclose(result.image[200, 994], pars['data_1'])
assert_allclose(result.image[200, 995], pars['data_2'])
result.write_image(str(tmpdir / 'test.' + pars['file_format']))
result.plot()

@remote_data
class TestSimpleTilePainter:
Expand Down
7 changes: 7 additions & 0 deletions hips/utils/wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ def __init__(self, wcs: WCS, width: int, height: int) -> None:
self.wcs = wcs
self.shape = Shape(width=width, height=height)

def __str__(self):
return (
'WCSGeometry data:\n'
f'WCS: {self.wcs}\n'
f'Shape: {self.shape}\n'
)

@property
def center_pix(self) -> Tuple[float, float]:
"""Image center in pixel coordinates (tuple of x, y)."""
Expand Down

0 comments on commit 1a2c763

Please sign in to comment.