Skip to content

Commit

Permalink
Geoblock mask int values cannot be higher then uint8 (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
caspervdw committed Jun 23, 2020
1 parent cfcdd50 commit 773a0e0
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ Changelog of dask-geomodeling

- Implemented `RasterTiler`.

- Let raster.Mask accomodate int values larger than uint8.


2.2.8 (2020-06-12)
------------------
Expand Down
13 changes: 11 additions & 2 deletions dask_geomodeling/raster/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,16 @@ def fillvalue(self):

@property
def dtype(self):
return "float32" if isinstance(self.value, float) else "uint8"
return self._dtype_from_value(self.value)

@staticmethod
def _dtype_from_value(value):
if isinstance(value, float):
return np.dtype("float32")
elif value >= 0:
return utils.get_uint_dtype(value)
else:
return utils.get_int_dtype(value)

@staticmethod
def process(data, value):
Expand All @@ -156,7 +165,7 @@ def process(data, value):
)

fillvalue = 1 if value == 0 else 0
dtype = "float32" if isinstance(value, float) else "uint8"
dtype = Mask._dtype_from_value(value)

values = np.full_like(data["values"], fillvalue, dtype=dtype)
values[index] = value
Expand Down
18 changes: 18 additions & 0 deletions dask_geomodeling/tests/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1663,6 +1663,24 @@ def test_mask(self):
self.assertEqual(view.get_data(**self.meta_request)["meta"], self.expected_meta)
self.assertEqual(view.get_data(**self.time_request)["time"], self.expected_time)

# the 'value' determines the dtype. 1000 becomes uint16.
view = raster.Mask(store=self.raster, value=1000)
data = view.get_data(**self.vals_request)
self.assertEqual(str(view.dtype), "uint16")
assert_equal(data["values"], 1000)

# -1000 becomes int16.
view = raster.Mask(store=self.raster, value=-1000)
data = view.get_data(**self.vals_request)
self.assertEqual(str(view.dtype), "int16")
assert_equal(data["values"], -1000)

# 3.14159 becomes float32.
view = raster.Mask(store=self.raster, value=3.14159)
data = view.get_data(**self.vals_request)
self.assertEqual(str(view.dtype), "float32")
assert_equal(data["values"], 3.14159)

def test_mask_below(self):
# filled result
view = raster.MaskBelow(store=self.raster, value=0)
Expand Down
13 changes: 13 additions & 0 deletions dask_geomodeling/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,19 @@ def test_get_dtype_min(self):
self.assertIsInstance(utils.get_dtype_min("f4"), float)
self.assertIsInstance(utils.get_dtype_min("u4"), int)

def test_get_int_dtype(self):
for dtype in ["i1", "i2", "i4", "i8"]:
hi = np.iinfo(dtype).max
lo = np.iinfo(dtype).max
self.assertEqual(utils.get_int_dtype(hi - 1), dtype)
self.assertEqual(utils.get_int_dtype(lo), dtype)

def test_get_uint_dtype(self):
self.assertRaises(ValueError, utils.get_uint_dtype, -1)
for dtype in ["u1", "u2", "u4", "u8"]:
hi = np.iinfo(dtype).max
self.assertEqual(utils.get_uint_dtype(hi - 1), dtype)

def test_get_projection(self):
projection_rd = str("EPSG:28992")
projection_wgs = str("EPSG:4326")
Expand Down
16 changes: 14 additions & 2 deletions dask_geomodeling/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,24 @@ def get_dtype_min(dtype):
return np.iinfo(d).min


def get_int_dtype(n):
"""Get the smallest int dtype that accomodates 'n' values, leaving space
for a no data value."""
for dtype in ("i1", "i2", "i4", "i8"):
if (n - 1 <= np.iinfo(dtype).max) and (n >= np.iinfo(dtype).min):
return np.dtype(dtype)
raise ValueError("Value does not fit in int dtype ({})".format(n))


def get_uint_dtype(n):
"""Get the smallest uint dtype that accomodates 'n' values"""
"""Get the smallest uint dtype that accomodates 'n' values, leaving space
for a no data value."""
if n < 0:
raise ValueError("Value does not fit in uint dtype ({})".format(n))
for dtype in ("u1", "u2", "u4", "u8"):
if n - 1 <= np.iinfo(dtype).max:
return np.dtype(dtype)
raise ValueError("Too many values for uint dtype ({})".format(n))
raise ValueError("Value does not fit in uint dtype ({})".format(n))


def get_rounded_repr(obj, significant=4, fmt="{} (rounded)"):
Expand Down

0 comments on commit 773a0e0

Please sign in to comment.