Skip to content

Commit

Permalink
RasterizeWKT (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
caspervdw committed Jan 30, 2020
1 parent 9f559da commit 81a90e8
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 19 deletions.
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Changelog of dask-geomodeling
2.2.1 (unreleased)
------------------

- Nothing changed yet.
- Implemented raster.RasterizeWKT


2.2.0 (2019-12-20)
Expand Down
151 changes: 134 additions & 17 deletions dask_geomodeling/raster/misc.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,31 @@
"""
Module containing miscellaneous raster blocks.
"""
from osgeo import ogr
import numpy as np
from geopandas import GeoSeries

import shapely
from shapely.geometry import box
from shapely.errors import WKTReadingError
from shapely.wkt import loads as load_wkt

from dask import config
from dask_geomodeling.geometry import GeometryBlock
from dask_geomodeling.utils import (
get_uint_dtype,
get_dtype_max,
get_index,
rasterize_geoseries,
)
from dask_geomodeling import utils

from .base import RasterBlock, BaseSingle


__all__ = ["Clip", "Classify", "Reclassify", "Mask", "MaskBelow", "Step", "Rasterize"]
__all__ = [
"Clip",
"Classify",
"Reclassify",
"Mask",
"MaskBelow",
"Step",
"Rasterize",
"RasterizeWKT",
]


class Clip(BaseSingle):
Expand Down Expand Up @@ -134,7 +142,9 @@ def process(data, value):
if data is None or "values" not in data:
return data

index = get_index(values=data["values"], no_data_value=data["no_data_value"])
index = utils.get_index(
values=data["values"], no_data_value=data["no_data_value"]
)

fillvalue = 1 if value == 0 else 0
dtype = "float32" if isinstance(value, float) else "uint8"
Expand Down Expand Up @@ -284,20 +294,20 @@ def right(self):
def dtype(self):
# with 254 bin edges, we have 255 bins, and we need 256 possible values
# to include no_data
return get_uint_dtype(len(self.bins) + 2)
return utils.get_uint_dtype(len(self.bins) + 2)

@property
def fillvalue(self):
return get_dtype_max(self.dtype)
return utils.get_dtype_max(self.dtype)

@staticmethod
def process(data, bins, right):
if data is None or "values" not in data:
return data

values = data["values"]
dtype = get_uint_dtype(len(bins) + 2)
fillvalue = get_dtype_max(dtype)
dtype = utils.get_uint_dtype(len(bins) + 2)
fillvalue = utils.get_dtype_max(dtype)

result_values = np.digitize(values, bins, right).astype(dtype)
result_values[values == data["no_data_value"]] = fillvalue
Expand Down Expand Up @@ -366,7 +376,7 @@ def dtype(self):

@property
def fillvalue(self):
return get_dtype_max(self.dtype)
return utils.get_dtype_max(self.dtype)

def get_sources_and_requests(self, **request):
process_kwargs = {
Expand Down Expand Up @@ -474,7 +484,7 @@ def dtype(self):

@property
def fillvalue(self):
return None if self.dtype == np.bool else get_dtype_max(self.dtype)
return None if self.dtype == np.bool else utils.get_dtype_max(self.dtype)

@property
def period(self):
Expand Down Expand Up @@ -528,7 +538,7 @@ def get_sources_and_requests(self, **request):

geom_request = {
"mode": "intersects",
"geometry": shapely.geometry.box(*request["bbox"]),
"geometry": box(*request["bbox"]),
"projection": request["projection"],
"min_size": min_size,
"limit": limit,
Expand Down Expand Up @@ -579,7 +589,7 @@ def process(data, process_kwargs):
values = np.full((1, height, width), no_data_value, dtype=dtype)
return {"values": values, "no_data_value": no_data_value}

result = rasterize_geoseries(
result = utils.rasterize_geoseries(
geoseries=f["geometry"] if "geometry" in f else None,
values=values,
bbox=process_kwargs["bbox"],
Expand All @@ -598,3 +608,110 @@ def process(data, process_kwargs):
cast_values[values == result["no_data_value"]] = no_data_value

return {"values": cast_values, "no_data_value": no_data_value}


class RasterizeWKT(RasterBlock):
"""Converts a single geometry to a raster mask
Args:
wkt (string): the WKT representation of a geometry
projection (string): the projection of the geometry
Returns:
RasterBlock with True for cells that are inside the geometry.
"""

def __init__(self, wkt, projection):
if not isinstance(wkt, str):
raise TypeError("'{}' object is not allowed".format(type(wkt)))
if not isinstance(projection, str):
raise TypeError("'{}' object is not allowed".format(type(projection)))
try:
load_wkt(wkt)
except WKTReadingError:
raise ValueError("The provided geometry is not a valid WKT")
try:
utils.get_sr(projection)
except TypeError:
raise ValueError("The provided projection is not a valid WKT")
super().__init__(wkt, projection)

@property
def wkt(self):
return self.args[0]

@property
def projection(self):
return self.args[1]

@property
def dtype(self):
return np.dtype("bool")

@property
def fillvalue(self):
return None

@property
def period(self):
return (self.DEFAULT_ORIGIN,) * 2

@property
def extent(self):
return tuple(
utils.shapely_transform(load_wkt(self.wkt), self.projection, "EPSG:4326").bounds
)

@property
def timedelta(self):
return None

@property
def geometry(self):
return ogr.CreateGeometryFromWkt(self.wkt, utils.get_sr(self.projection))

@property
def geo_transform(self):
return None

def get_sources_and_requests(self, **request):
# first handle the 'time' and 'meta' requests
mode = request["mode"]
if mode == "time":
data = self.period[-1]
elif mode == "meta":
data = None
elif mode == "vals":
data = {"wkt": self.wkt, "projection": self.projection}
else:
raise ValueError("Unknown mode '{}'".format(mode))
return [(data, None), (request, None)]

@staticmethod
def process(data, request):
mode = request["mode"]
if mode == "time":
return {"time": [data]}
elif mode == "meta":
return {"meta": [None]}
# load the geometry and transform it into the requested projection
geometry = load_wkt(data["wkt"])
if data["projection"] != request["projection"]:
geometry = utils.shapely_transform(
geometry, data["projection"], request["projection"]
)
# take a shortcut when the geometry does not intersect the bbox
if not geometry.intersects(box(*request["bbox"])):
return {
"values": np.full(
(1, request["height"], request["width"]), False, dtype=np.bool
),
"no_data_value": None,
}
return utils.rasterize_geoseries(
geoseries=GeoSeries([geometry]) if not geometry.is_empty else None,
bbox=request["bbox"],
projection=request["projection"],
height=request["height"],
width=request["width"],
)
78 changes: 77 additions & 1 deletion dask_geomodeling/tests/test_raster_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import numpy as np
import pytest
from numpy.testing import assert_equal
from numpy.testing import assert_equal, assert_almost_equal
from shapely.geometry import box

from dask_geomodeling import raster
from dask_geomodeling.utils import shapely_transform, get_sr
from dask_geomodeling.raster.sources import MemorySource


Expand Down Expand Up @@ -71,6 +73,22 @@ def vals_request():
}


@pytest.fixture
def point_request():
bands = 3
time_first = datetime(2000, 1, 1)
time_delta = timedelta(hours=1)
yield {
"mode": "vals",
"start": time_first,
"stop": time_first + bands * time_delta,
"width": 1,
"height": 1,
"bbox": (135001, 455999, 135001, 455999),
"projection": "EPSG:28992",
}


@pytest.fixture
def vals_request_none():
bands = 3
Expand Down Expand Up @@ -277,3 +295,61 @@ def test_reclassify_time_request(source, vals_request, expected_time):
view = raster.Reclassify(store=source, data=[[7, 1000]])
vals_request["mode"] = "time"
assert view.get_data(**vals_request)["time"] == expected_time


@pytest.mark.parametrize("projection", ["EPSG:28992", "EPSG:4326", "EPSG:3857"])
def test_rasterize_wkt_vals(vals_request, projection):
# vals_request has width=4, height=6 and cell size of 0.5
# we place a rectangle of 2 x 3 with corner at x=1, y=2
view = raster.RasterizeWKT(
shapely_transform(
box(135000.5, 455998, 135001.5, 455999.5), "EPSG:28992", projection
).wkt,
projection,
)
vals_request["start"] = vals_request["stop"] = None
actual = view.get_data(**vals_request)
assert actual["values"][0].astype(int).tolist() == [
[0, 0, 0, 0],
[0, 1, 1, 0],
[0, 1, 1, 0],
[0, 1, 1, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
]


def test_rasterize_wkt_vals_no_intersection(vals_request):
view = raster.RasterizeWKT(box(135004, 455995, 135004.5, 455996).wkt, "EPSG:28992")
vals_request["start"] = vals_request["stop"] = None
actual = view.get_data(**vals_request)
assert ~actual["values"].any()


@pytest.mark.parametrize(
"bbox,expected",
[
[(135000.5, 455998, 135001.5, 455999.5), True],
[(135000.5, 455998, 135000.9, 455998.9), False],
],
)
def test_rasterize_wkt_point(point_request, bbox, expected):
view = raster.RasterizeWKT(box(*bbox).wkt, "EPSG:28992")
point_request["start"] = point_request["stop"] = None
actual = view.get_data(**point_request)
assert actual["values"].tolist() == [[[expected]]]


def test_rasterize_wkt_attrs():
geom = box(135004, 455995, 135004.5, 455996)
view = raster.RasterizeWKT(geom.wkt, "EPSG:28992")
assert view.projection == "EPSG:28992"
assert_almost_equal(view.geometry.GetEnvelope(), [135004, 135004.5, 455995, 455996])
assert view.geometry.GetSpatialReference().IsSame(get_sr("EPSG:28992"))
assert view.dtype == np.bool
assert view.fillvalue is None
assert_almost_equal(
view.extent, shapely_transform(geom, "EPSG:28992", "EPSG:4326").bounds
)
assert view.timedelta is None
assert view.period == (datetime(1970, 1, 1), datetime(1970, 1, 1))

0 comments on commit 81a90e8

Please sign in to comment.