Skip to content

Commit

Permalink
Merge pull request #661 from locationtech-labs/feature/rasterio-windows
Browse files Browse the repository at this point in the history
Add RasterIO tile reading capability
  • Loading branch information
Jacob Bouffard committed Jun 4, 2018
2 parents d03d95f + 7bfc952 commit b366868
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,6 @@ prof/
\#*#
*~
.#*

# Visual Studio Code
.vscode
2 changes: 2 additions & 0 deletions geopyspark/geotrellis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,7 @@ def __str__(self):
from . import constants
from . import converters
from . import geotiff
from . import rasterio
from . import histogram
from . import layer
from . import neighborhood
Expand All @@ -821,6 +822,7 @@ def __str__(self):
__all__ += ['cost_distance']
__all__ += ['euclidean_distance']
__all__ += ['geotiff']
__all__ += ['rasterio']
__all__ += ['hillshade']
__all__ += histogram.__all__
__all__ += layer.__all__
Expand Down
106 changes: 106 additions & 0 deletions geopyspark/geotrellis/rasterio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import os
import math

import geopyspark as gps
from geopyspark.geotrellis.constants import DEFAULT_MAX_TILE_SIZE

# rasterio is a hard requirement of this module: fail at import time with an
# explanatory message rather than a bare "No module named ..." error.
try:
    import rasterio
except ImportError:
    raise ImportError(
        "rasterio must be installed in order to use the features in "
        "the geopyspark.geotrellis.rasterio package"
    )


__all__ = ['get']

# Value of the GDAL_DATA environment variable at the moment this module is
# imported (``None`` when it is unset). _read_windows writes it back into
# os.environ when the variable is missing there.
# NOTE(review): the original comment read "On driver" -- presumably the value
# is meant to be captured on the Spark driver; confirm how it reaches executors.
_GDAL_DATA = os.getenv("GDAL_DATA")

def crs_to_proj4(crs):
    """Convert a ``rasterio.crs.CRS`` into a Proj4 string using the osgeo library.

    Args:
        crs (``rasterio.crs.CRS``): The ``CRS`` to convert. Only its ``wkt``
            attribute is read.

    Returns:
        The Proj4 string exported by ``osr.SpatialReference`` for the ``CRS``.

    Raises:
        ImportError: If ``osgeo`` cannot be imported.
    """

    # osgeo is imported lazily so it is only required when this converter
    # is actually called.
    try:
        from osgeo import osr
    except ImportError:
        raise ImportError("osgeo must be installed in order to use the crs_to_proj4 function")

    spatial_ref = osr.SpatialReference()
    spatial_ref.ImportFromWkt(crs.wkt)
    return spatial_ref.ExportToProj4()

def _tile_windows(width, height, xcols, ycols):
    """Yield ``((row_start, row_stop), (col_start, col_stop))`` windows tiling a raster.

    Stops are exclusive and clamped to the raster size, so windows on the
    bottom/right edge may be smaller than ``ycols`` x ``xcols``.
    """
    for row_start in range(0, height, ycols):
        row_stop = min(height, row_start + ycols)
        for col_start in range(0, width, xcols):
            col_stop = min(width, col_start + xcols)
            yield ((row_start, row_stop), (col_start, col_stop))


def _read_windows(uri, xcols, ycols, bands, crs_to_proj4):
    """Read the raster at ``uri`` window by window and yield ``(ProjectedExtent, Tile)`` pairs.

    Args:
        uri (str): Location of the raster, passed to ``rasterio.open``.
        xcols (int): Maximum window width, in pixels (columns).
        ycols (int): Maximum window height, in pixels (rows).
        bands ([int] or None): Bands handed to ``dataset.read``.
        crs_to_proj4 (callable): Receives the dataset's CRS and returns the
            Proj4 string stored in each ``ProjectedExtent``.

    Yields:
        ``(ProjectedExtent, Tile)`` for every window, in row-major order.
    """

    # Restore GDAL_DATA from the value captured at import time when the
    # current process environment does not define it.
    if "GDAL_DATA" not in os.environ and _GDAL_DATA is not None:
        os.environ["GDAL_DATA"] = _GDAL_DATA

    with rasterio.open(uri) as dataset:
        bounds = dataset.bounds
        height = dataset.height
        width = dataset.width
        proj4 = crs_to_proj4(dataset.get_crs())
        nodata = dataset.nodata

        x_span = bounds.right - bounds.left
        y_span = bounds.bottom - bounds.top

        # Rows (y) form the first pair and columns (x) the second, matching
        # both the unpacking below and the window passed to dataset.read.
        for window in _tile_windows(width, height, xcols, ycols):
            ((row_start, row_stop), (col_start, col_stop)) = window

            # Linearly interpolate the pixel offsets into map coordinates.
            left = bounds.left + x_span * (float(col_start) / width)
            right = bounds.left + x_span * (float(col_stop) / width)
            bottom = bounds.top + y_span * (float(row_stop) / height)
            top = bounds.top + y_span * (float(row_start) / height)
            extent = gps.Extent(left, bottom, right, top)
            projected_extent = gps.ProjectedExtent(extent=extent, proj4=proj4)

            data = dataset.read(bands, window=window)
            tile = gps.Tile.from_numpy_array(data, no_data_value=nodata)
            yield (projected_extent, tile)

def get(data_source,
        xcols=DEFAULT_MAX_TILE_SIZE,
        ycols=DEFAULT_MAX_TILE_SIZE,
        bands=None,
        crs_to_proj4=crs_to_proj4):
    r"""Creates an ``RDD`` of windows represented as the key value pair: ``(ProjectedExtent, Tile)``
    from URIs using rasterio.

    Args:
        data_source (str or [str] or RDD): The source of the data to be windowed.
            Can either be URI or list (or tuple) of URIs which point to where the source
            data can be found; or it can be an ``RDD`` that contains the URIs.
        xcols (int, optional): The desired tile width. If the size is smaller than
            the width of the read in tile, then that tile will be broken into smaller sections
            of the given size. Defaults to :const:`~geopyspark.geotrellis.constants.DEFAULT_MAX_TILE_SIZE`.
        ycols (int, optional): The desired tile height. If the size is smaller than
            the height of the read in tile, then that tile will be broken into smaller sections
            of the given size. Defaults to :const:`~geopyspark.geotrellis.constants.DEFAULT_MAX_TILE_SIZE`.
        bands ([int], optional): The bands from which windows should be produced given as a list
            of ``int``\s. Defaults to ``None`` which causes all bands to be read.
        crs_to_proj4 (``rasterio.crs.CRS`` => str, optional): A function that takes a
            :class:`rasterio.crs.CRS` and returns a Proj4 string. Default is
            :func:`geopyspark.geotrellis.rasterio.crs_to_proj4`.

    Returns:
        RDD
    """

    pysc = gps.get_spark_context()

    def read(uri):
        return _read_windows(uri, xcols, ycols, bands, crs_to_proj4)

    if isinstance(data_source, str):
        data_source = [data_source]

    if isinstance(data_source, (list, tuple)):
        # One partition per URI; clamp to 1 so that an empty collection does
        # not request zero partitions.
        num_slices = max(1, len(data_source))
        return pysc.parallelize(data_source, num_slices).flatMap(read)

    # Anything else is expected to be an RDD of URIs.
    return data_source.flatMap(read)
36 changes: 36 additions & 0 deletions geopyspark/tests/geotrellis/io_tests/rasterio_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import unittest
import os
import pytest
import rasterio

from geopyspark.tests.base_test_class import BaseTestClass
from geopyspark.tests.python_test_utils import file_path


class CatalogTest(BaseTestClass):
    """Tests for windowed tile reading via ``geopyspark.geotrellis.rasterio``."""

    uri = file_path("srtm_52_11.tif")

    @pytest.fixture(autouse=True)
    def tearDown(self):
        # The code after ``yield`` runs once the test using this fixture is done.
        yield
        BaseTestClass.pysc._gateway.close()

    @pytest.mark.skipif('TRAVIS' in os.environ,
                        reason="Cannot resolve depency issues in Travis for the time being")
    def test_tiles(self):
        import geopyspark as gps
        from geopyspark.geotrellis import rasterio

        # Fixed Proj4 string so this test does not go through the default converter.
        def fixed_proj4(n):
            return '+proj=longlat +datum=WGS84 +no_defs '

        windows = rasterio._read_windows(self.uri,
                                         xcols=256,
                                         ycols=256,
                                         bands=None,
                                         crs_to_proj4=fixed_proj4)
        window_count = len(list(windows))
        self.assertEqual(window_count, 144)

    @pytest.mark.skipif('TRAVIS' in os.environ,
                        reason="Cannot resolve depency issues in Travis for the time being")
    def test_layer(self):
        import geopyspark as gps
        from geopyspark.geotrellis import rasterio

        # Read with the default tile size and wrap the result in a spatial layer.
        tile_rdd = gps.rasterio.get(self.uri)
        layer = gps.RasterLayer.from_numpy_rdd(gps.LayerType.SPATIAL, tile_rdd)
        self.assertEqual(layer.count(), 144)

# Run the tests through unittest when this module is executed as a script.
if __name__ == "__main__":
    unittest.main()
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
pytest>=3.0.6
numpy>=1.8
shapely>=1.6b3
rasterio>=1.0a7
rasterio==1.0a7
setuptools
protobuf>=3.3.0
pytz
Expand Down

0 comments on commit b366868

Please sign in to comment.