Skip to content

Commit

Permalink
Check for matching time resolutions in rasters.Clip (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
caspervdw committed Aug 19, 2020
1 parent ad7bd6e commit 2435d1a
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 6 deletions.
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog of dask-geomodeling
2.2.11 (unreleased)
-------------------

- Nothing changed yet.
- Check for matching time resolutions in raster.Clip.


2.2.10 (2020-07-29)
Expand Down
36 changes: 36 additions & 0 deletions dask_geomodeling/raster/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class Clip(BaseSingle):
If the 'source' raster is a boolean raster, False will result in 'no data'.
Note that the input rasters are required to have the same time resolution.
Args:
store (RasterBlock): Raster whose values are clipped
source (RasterBlock): Raster that is used as the clipping mask
Expand All @@ -50,12 +52,31 @@ class Clip(BaseSingle):
def __init__(self, store, source):
if not isinstance(source, RasterBlock):
raise TypeError("'{}' object is not allowed".format(type(store)))
# timedeltas are required to be equal
if store.timedelta != source.timedelta:
raise ValueError(
"Time resolution of the clipping mask does not match that of "
"the values raster. Consider using Snap."
)
super(Clip, self).__init__(store, source)

@property
def source(self):
return self.args[1]

def get_sources_and_requests(self, **request):
start = request.get("start", None)
stop = request.get("stop", None)

if start is not None and stop is not None:
# limit request to self.period so that resulting data is aligned
period = self.period
if period is not None:
request["start"] = max(start, period[0])
request["stop"] = min(stop, period[1])

return ((source, request) for source in self.args)

@staticmethod
def process(data, source_data):
""" Mask store_data where source_data has no data """
Expand Down Expand Up @@ -115,6 +136,21 @@ def geometry(self):
return
return result

@property
def period(self):
""" Return period datetime tuple. """
periods = [x.period for x in self.args]
if any(period is None for period in periods):
return None # None precedes

# multiple periods: return the overlapping period
start = max([p[0] for p in periods])
stop = min([p[1] for p in periods])
if stop < start:
return None # no overlap
else:
return start, stop


class Mask(BaseSingle):
"""
Expand Down
56 changes: 51 additions & 5 deletions dask_geomodeling/tests/test_raster_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@

def test_clip_attrs_store_empty(source, empty_source):
# clip should propagate the (empty) extent of the store
clip = raster.Clip(empty_source, source)
clip = raster.Clip(empty_source, raster.Snap(source, empty_source))
assert clip.extent is None
assert clip.geometry is None


def test_clip_attrs_mask_empty(source, empty_source):
# clip should propagate the (empty) extent of the clipping mask
clip = raster.Clip(source, empty_source)
clip = raster.Clip(source, raster.Snap(empty_source, source))
assert clip.extent is None
assert clip.geometry is None

Expand Down Expand Up @@ -63,7 +63,7 @@ def test_clip_attrs_with_reprojection(source, empty_source):
assert clip.geometry.GetEnvelope() == source.geometry.GetEnvelope()


def test_clip_attrs_no_intersection(source, empty_source):
def test_clip_attrs_no_intersection(source):
# create a raster in that does not overlap the store
clipping_mask = MemorySource(
data=source.data,
Expand All @@ -79,13 +79,27 @@ def test_clip_attrs_no_intersection(source, empty_source):
assert clip.geometry is None


def test_clip_matching_timedelta(source):
clip = raster.Clip(source, source == 7)
assert clip.timedelta == source.timedelta


def test_clip_unequal_timedelta(source, empty_source):
# clip checks for matching timedeltas; test that here
# NB: note that `source` is temporal and `empty_source` is not
with pytest.raises(ValueError, match=".*resolution of the clipping.*"):
clip = raster.Clip(source, empty_source)
with pytest.raises(ValueError, match=".*resolution of the clipping.*"):
clip = raster.Clip(empty_source, source)


def test_clip_empty_source(source, empty_source, vals_request):
clip = raster.Clip(empty_source, source)
clip = raster.Clip(empty_source, raster.Snap(source, empty_source))
assert clip.get_data(**vals_request) is None


def test_clip_with_empty_mask(source, empty_source, vals_request):
clip = raster.Clip(source, empty_source)
clip = raster.Clip(source, raster.Snap(empty_source, source))
assert clip.get_data(**vals_request) is None


Expand Down Expand Up @@ -118,6 +132,38 @@ def test_clip_time_request(source, vals_request, expected_time):
assert clip.get_data(**vals_request)["time"] == expected_time


def test_clip_partial_temporal_overlap(source, vals_request):
# create a clipping mask in that temporally does not overlap the store
clipping_mask = MemorySource(
data=source.data,
no_data_value=source.no_data_value,
projection=source.projection,
pixel_size=source.pixel_size,
pixel_origin=source.pixel_origin,
time_first=source.time_first + source.time_delta,
time_delta=source.time_delta,
)
clip = raster.Clip(source, clipping_mask)
assert clip.period == (clipping_mask.period[0], source.period[1])
assert clip.get_data(**vals_request)["values"][:, 0, 0].tolist() == [7, 255]


def test_clip_no_temporal_overlap(source, vals_request):
# create a clipping mask in that temporally does not overlap the store
clipping_mask = MemorySource(
data=source.data,
no_data_value=source.no_data_value,
projection=source.projection,
pixel_size=source.pixel_size,
pixel_origin=source.pixel_origin,
time_first=source.time_first + 10 * source.time_delta,
time_delta=source.time_delta,
)
clip = raster.Clip(source, clipping_mask)
assert clip.period == None
assert clip.get_data(**vals_request) is None


def test_reclassify(source, vals_request):
view = raster.Reclassify(store=source, data=[[7, 1000]])
data = view.get_data(**vals_request)
Expand Down

0 comments on commit 2435d1a

Please sign in to comment.