-
Notifications
You must be signed in to change notification settings - Fork 59
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #661 from locationtech-labs/feature/rasterio-windows
Add RasterIO tile reading capability
- Loading branch information
Showing
5 changed files
with
148 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -109,3 +109,6 @@ prof/ | |
\#*# | ||
*~ | ||
.#* | ||
|
||
# Visual Studio Code | ||
.vscode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import os | ||
import math | ||
|
||
import geopyspark as gps | ||
from geopyspark.geotrellis.constants import DEFAULT_MAX_TILE_SIZE | ||
|
||
try: | ||
import rasterio | ||
except ImportError: | ||
raise ImportError("rasterio must be installed in order to use the features in the geopyspark.geotrellis.rasterio package") | ||
|
||
|
||
__all__ = ['get'] | ||
|
||
# On driver | ||
_GDAL_DATA = os.environ.get("GDAL_DATA") | ||
|
||
def crs_to_proj4(crs):
    """Return the Proj4 string for a ``rasterio.crs.CRS`` via the osgeo library.

    Args:
        crs (``rasterio.crs.CRS``): The ``CRS`` to convert. Its ``wkt``
            attribute is what gets handed to osgeo.

    Returns:
        Proj4 str of the ``CRS``.

    Raises:
        ImportError: If ``osgeo`` cannot be imported.
    """

    # osgeo is imported lazily so that only callers of this function need it.
    try:
        from osgeo import osr
    except ImportError:
        raise ImportError("osgeo must be installed in order to use the crs_to_proj4 function")

    spatial_ref = osr.SpatialReference()
    spatial_ref.ImportFromWkt(crs.wkt)
    return spatial_ref.ExportToProj4()
|
||
def _read_windows(uri, xcols, ycols, bands, crs_to_proj4):
    """Yield ``(ProjectedExtent, Tile)`` pairs for every window of a raster.

    The raster at ``uri`` is opened with rasterio and walked as a grid of
    windows that are at most ``xcols`` pixels wide and ``ycols`` pixels tall.
    Windows along the right and bottom edges are clipped to the raster size,
    so every pixel is covered exactly once.

    Args:
        uri (str): Location of the raster, passed to ``rasterio.open``.
        xcols (int): Maximum window width, in pixels.
        ycols (int): Maximum window height, in pixels.
        bands ([int] or None): Band indexes passed to ``dataset.read``.
        crs_to_proj4 (``rasterio.crs.CRS`` => str): Function that turns the
            dataset's ``CRS`` into a Proj4 string.

    Yields:
        (ProjectedExtent, Tile): The extent covered by a window and the pixel
        data read from it.
    """

    # _GDAL_DATA is read from the environment when this module is imported.
    # If this process has no GDAL_DATA of its own, fall back to that value
    # (presumably so that workers inherit the driver's setting -- confirm).
    if "GDAL_DATA" not in os.environ and _GDAL_DATA is not None:
        os.environ["GDAL_DATA"] = _GDAL_DATA

    with rasterio.open(uri) as dataset:
        bounds = dataset.bounds
        height = dataset.height
        width = dataset.width
        proj4 = crs_to_proj4(dataset.get_crs())
        nodata = dataset.nodata

        x_span = bounds.right - bounds.left
        y_span = bounds.bottom - bounds.top

        # rasterio windows are ((row_start, row_stop), (col_start, col_stop))
        # with exclusive stops: rows run over the height in steps of ycols,
        # columns over the width in steps of xcols, and the stops are clamped
        # to the full height/width so the last row and column are kept.
        for row_start in range(0, height, ycols):
            row_stop = min(height, row_start + ycols)

            # Row offsets are measured downwards from the top edge.
            top = bounds.top + y_span * (float(row_start) / height)
            bottom = bounds.top + y_span * (float(row_stop) / height)

            for col_start in range(0, width, xcols):
                col_stop = min(width, col_start + xcols)

                left = bounds.left + x_span * (float(col_start) / width)
                right = bounds.left + x_span * (float(col_stop) / width)

                extent = gps.Extent(left, bottom, right, top)
                projected_extent = gps.ProjectedExtent(extent=extent, proj4=proj4)

                window = ((row_start, row_stop), (col_start, col_stop))
                data = dataset.read(bands, window=window)
                tile = gps.Tile.from_numpy_array(data, no_data_value=nodata)
                yield (projected_extent, tile)
|
||
def get(data_source,
        xcols=DEFAULT_MAX_TILE_SIZE,
        ycols=DEFAULT_MAX_TILE_SIZE,
        bands=None,
        crs_to_proj4=crs_to_proj4):
    """Creates an ``RDD`` of windows represented as the key value pair: ``(ProjectedExtent, Tile)``
    from URIs using rasterio.

    Args:
        data_source (str or [str] or RDD): The source of the data to be windowed.
            Can either be a URI or a list (or tuple) of URIs which point to where the source
            data can be found; or it can be an ``RDD`` that contains the URIs.
        xcols (int, optional): The desired tile width. If the size is smaller than
            the width of the read in tile, then that tile will be broken into smaller sections
            of the given size. Defaults to :const:`~geopyspark.geotrellis.constants.DEFAULT_MAX_TILE_SIZE`.
        ycols (int, optional): The desired tile height. If the size is smaller than
            the height of the read in tile, then that tile will be broken into smaller sections
            of the given size. Defaults to :const:`~geopyspark.geotrellis.constants.DEFAULT_MAX_TILE_SIZE`.
        bands ([int], optional): The bands from which windows should be produced given as a list
            of ``int`` values. Defaults to ``None`` which causes all bands to be read.
        crs_to_proj4 (``rasterio.crs.CRS`` => str, optional): A function that takes a
            :class:`rasterio.crs.CRS` and returns a Proj4 string. Default is
            :func:`geopyspark.geotrellis.rasterio.crs_to_proj4`.

    Returns:
        RDD
    """

    def read_windows(uri):
        return _read_windows(uri, xcols, ycols, bands, crs_to_proj4)

    # A single URI is handled as a one-element collection.
    if isinstance(data_source, str):
        data_source = [data_source]

    if isinstance(data_source, (list, tuple)):
        uris = list(data_source)
        pysc = gps.get_spark_context()
        # One partition per URI, but never zero: an empty collection must
        # still produce a valid (empty) RDD instead of asking for 0 slices.
        return pysc.parallelize(uris, max(1, len(uris))).flatMap(read_windows)

    # Anything else is treated as an RDD of URIs.
    return data_source.flatMap(read_windows)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import unittest | ||
import os | ||
import pytest | ||
import rasterio | ||
|
||
from geopyspark.tests.base_test_class import BaseTestClass | ||
from geopyspark.tests.python_test_utils import file_path | ||
|
||
|
||
class CatalogTest(BaseTestClass):
    """Checks window reading of the ``srtm_52_11.tif`` fixture through rasterio."""

    uri = file_path("srtm_52_11.tif")

    @pytest.fixture(autouse=True)
    def tearDown(self):
        # Everything after the yield runs once the test body has finished:
        # close the gateway held by the shared context.
        yield
        BaseTestClass.pysc._gateway.close()

    @pytest.mark.skipif('TRAVIS' in os.environ,
                        reason="Cannot resolve depency issues in Travis for the time being")
    def test_tiles(self):
        """Reading 256x256 windows straight from the file yields 144 of them."""
        import geopyspark as gps
        from geopyspark.geotrellis import rasterio
        windows = rasterio._read_windows(
            self.uri,
            xcols=256,
            ycols=256,
            bands=None,
            crs_to_proj4=lambda n: '+proj=longlat +datum=WGS84 +no_defs ')
        window_count = len(list(windows))
        self.assertEqual(window_count, 144)

    @pytest.mark.skipif('TRAVIS' in os.environ,
                        reason="Cannot resolve depency issues in Travis for the time being")
    def test_layer(self):
        """A RasterLayer built from the windowed RDD holds 144 records."""
        import geopyspark as gps
        from geopyspark.geotrellis import rasterio
        window_rdd = gps.rasterio.get(self.uri)
        layer = gps.RasterLayer.from_numpy_rdd(gps.LayerType.SPATIAL, window_rdd)
        self.assertEqual(layer.count(), 144)
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
pytest>=3.0.6 | ||
numpy>=1.8 | ||
shapely>=1.6b3 | ||
rasterio>=1.0a7 | ||
rasterio==1.0a7 | ||
setuptools | ||
protobuf>=3.3.0 | ||
pytz | ||
|