Skip to content

Commit

Permalink
Merge 0380b56 into 744b38b
Browse files Browse the repository at this point in the history
  • Loading branch information
sgillies committed May 15, 2019
2 parents 744b38b + 0380b56 commit 9c74729
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 42 deletions.
116 changes: 80 additions & 36 deletions rasterio/rio/calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,55 @@

from collections import OrderedDict
from distutils.version import LooseVersion
import math

import click
import snuggs

import rasterio
from rasterio.features import sieve
from rasterio.fill import fillnodata
from rasterio.windows import Window
from rasterio.rio import options
from rasterio.rio.helpers import resolve_inout


def get_bands(inputs, d, i=None):
def _get_bands(inputs, sources, d, i=None):
    """Get rasterio.Band object(s) from calc's inputs.

    Parameters
    ----------
    inputs : sequence of (name, path) pairs
        The calc inputs; ``name`` is None for positional inputs.
    sources : sequence
        Opened datasets, in the same order as ``inputs``.
    d : str or int
        Either the name of a named input or the 1-based position of an
        input.
    i : int, optional
        Band index. When omitted (or falsy), all bands of the dataset
        are returned.

    Returns
    -------
    A single ``rasterio.band(src, i)`` result when ``i`` is given,
    otherwise a list with one such result per index in ``src.indexes``.
    """
    names = [item[0] for item in inputs]
    # ``sources`` is a positional sequence parallel to ``inputs``, so a
    # named input must be translated to its position before indexing;
    # using the name itself as the index would raise TypeError.
    if d is not None and d in names:
        idx = names.index(d)
    else:
        idx = int(d) - 1
    # Reuse the already opened dataset; opening the path again here
    # would leave an extra dataset handle that nothing closes.
    src = sources[idx]
    if i:
        return rasterio.band(src, i)
    return [rasterio.band(src, j) for j in src.indexes]


def read_array(ix, subix=None, dtype=None):
def _read_array(ix, subix=None, dtype=None):
    """Look up an array in the snuggs context, optionally casting it.

    The array is fetched with ``snuggs._ctx.lookup(ix, subix)``. When a
    truthy ``dtype`` is given the array is converted with ``astype``;
    otherwise it is returned as found.
    """
    found = snuggs._ctx.lookup(ix, subix)
    return found.astype(dtype) if dtype else found


def _get_work_windows(width, height, count, itemsize, mem_limit=1):
"""Get windows for work"""
element_limit = mem_limit * 1.0e+6 / itemsize
pixel_limit = element_limit / count
work_window_size = int(math.floor(math.sqrt(pixel_limit)))
num_windows_cols = int(math.ceil(width / work_window_size))
num_windows_rows = int(math.ceil(height / work_window_size))
work_windows = []

for col in range(num_windows_cols):
col_offset = col * work_window_size
w = min(work_window_size, width - col_offset)
for row in range(num_windows_rows):
row_offset = row * work_window_size
h = min(work_window_size, height - row_offset)
work_windows.append(((row, col), Window(col_offset, row_offset, w, h)))

return work_windows


@click.command(short_help="Raster data calculator.")
@click.argument('command')
@options.files_inout_arg
Expand All @@ -40,10 +62,10 @@ def read_array(ix, subix=None, dtype=None):
@options.dtype_opt
@options.masked_opt
@options.overwrite_opt
@click.option("--mem-limit", type=int, default=64, help="Limit on size of scratch space, in MB.")
@options.creation_options
@click.pass_context
def calc(ctx, command, files, output, name, dtype, masked, overwrite,
creation_options):
def calc(ctx, command, files, output, name, dtype, masked, overwrite, mem_limit, creation_options):
"""A raster data calculator
Evaluates an expression using input datasets and writes the result
Expand Down Expand Up @@ -89,19 +111,36 @@ def calc(ctx, command, files, output, name, dtype, masked, overwrite,
with ctx.obj['env']:
output, files = resolve_inout(files=files, output=output,
overwrite=overwrite)

inputs = ([tuple(n.split('=')) for n in name] +
[(None, n) for n in files])
sources = [rasterio.open(path) for name, path in inputs]

first = sources[0]
kwargs = first.profile
kwargs.update(**creation_options)
dtype = dtype or first.meta['dtype']
kwargs['dtype'] = dtype

# Extend snuggs.
snuggs.func_map['read'] = _read_array
snuggs.func_map['band'] = lambda d, i: _get_bands(inputs, sources, d, i)
snuggs.func_map['bands'] = lambda d: _get_bands(inputs, sources, d)
snuggs.func_map['fillnodata'] = lambda *args: fillnodata(*args)
snuggs.func_map['sieve'] = lambda *args: sieve(*args)

with rasterio.open(inputs[0][1]) as first:
kwargs = first.meta
kwargs.update(**creation_options)
dtype = dtype or first.meta['dtype']
kwargs['dtype'] = dtype
dst = None

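# The list of windows starts with a single small sample window.
# The actual work windows are appended to this list during the
# first iteration of the loop, once the output dataset has been
# opened, and are processed in the iterations that follow.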
work_windows = [(None, Window(0, 0, 16, 16))]

for ij, window in work_windows:

ctxkwds = OrderedDict()

for i, ((name, path), src) in enumerate(zip(inputs, sources)):

ctxkwds = OrderedDict()
for i, (name, path) in enumerate(inputs):
with rasterio.open(path) as src:
# Using the class method instead of instance
# method. Latter raises
#
Expand All @@ -110,36 +149,41 @@ def calc(ctx, command, files, output, name, dtype, masked, overwrite,
#
# possibly something to do with the instance being
# a masked array.
ctxkwds[name or '_i%d' % (i + 1)] = src.read(masked=masked)
ctxkwds[name or '_i%d' % (i + 1)] = src.read(masked=masked, window=window)

# Extend snuggs.
snuggs.func_map['read'] = read_array
snuggs.func_map['band'] = lambda d, i: get_bands(inputs, d, i)
snuggs.func_map['bands'] = lambda d: get_bands(inputs, d)
snuggs.func_map['fillnodata'] = lambda *args: fillnodata(*args)
snuggs.func_map['sieve'] = lambda *args: sieve(*args)
res = snuggs.eval(command, **ctxkwds)

res = snuggs.eval(command, ctxkwds)
if (isinstance(res, np.ma.core.MaskedArray) and (
tuple(LooseVersion(np.__version__).version) < (1, 9) or
tuple(LooseVersion(np.__version__).version) > (1, 10))):
res = res.filled(kwargs['nodata'])

if (isinstance(res, np.ma.core.MaskedArray) and (
tuple(LooseVersion(np.__version__).version) < (1, 9) or
tuple(LooseVersion(np.__version__).version) > (1, 10))):
res = res.filled(kwargs['nodata'])
if len(res.shape) == 3:
results = np.ndarray.astype(res, dtype, copy=False)
else:
results = np.asanyarray(
[np.ndarray.astype(res, dtype, copy=False)])

if len(res.shape) == 3:
results = np.ndarray.astype(res, dtype, copy=False)
else:
results = np.asanyarray(
[np.ndarray.astype(res, dtype, copy=False)])
# The first iteration is only to get sample results and from them
# compute some properties of the output dataset.
if dst is None:
kwargs['count'] = results.shape[0]
dst = rasterio.open(output, 'w', **kwargs)
work_windows.extend(_get_work_windows(dst.width, dst.height, dst.count, np.dtype(dst.dtypes[0]).itemsize, mem_limit=mem_limit))

kwargs['count'] = results.shape[0]

with rasterio.open(output, 'w', **kwargs) as dst:
dst.write(results)
# In subsequent iterations we write results.
else:
dst.write(results, window=window)

except snuggs.ExpressionError as err:
click.echo("Expression Error:")
click.echo(' %s' % err.text)
click.echo(' ' + ' ' * err.offset + "^")
click.echo(err)
raise click.Abort()

finally:
if dst:
dst.close()
for src in sources:
src.close()
22 changes: 16 additions & 6 deletions tests/test_rio_calc.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import sys
import logging

from click.testing import CliRunner
import pytest

import rasterio
from rasterio.rio.calc import _get_work_windows
from rasterio.rio.main import main_group


logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)


def test_err(tmpdir):
outfile = str(tmpdir.join('out.tif'))
runner = CliRunner()
Expand Down Expand Up @@ -138,6 +134,7 @@ def test_fillnodata_map(tmpdir):
assert round(data.mean(), 1) == 58.6
assert data[0][60][60] > 0


def test_sieve_band(tmpdir):
outfile = str(tmpdir.join('out.tif'))
runner = CliRunner()
Expand Down Expand Up @@ -183,3 +180,16 @@ def test_positional_calculation_byindex(tmpdir):

with rasterio.open(outfile) as src:
assert src.read(1, window=window) == answer


@pytest.mark.parametrize('width', [10, 791, 3000])
@pytest.mark.parametrize('height', [8, 718, 4000])
@pytest.mark.parametrize('count', [1, 3, 4])
@pytest.mark.parametrize('itemsize', [1, 2, 8])
@pytest.mark.parametrize('mem_limit', [1, 16, 64, 512])
def test_get_work_windows(width, height, count, itemsize, mem_limit):
    """Summed window sizes must match the grid size times the window counts.

    The widths of all windows must add up to ``width`` once per row of
    windows, and the heights to ``height`` once per column of windows.
    """
    tiles = _get_work_windows(
        width, height, count, itemsize, mem_limit=mem_limit)

    # Grid indexes are (row, col); the largest index + 1 is the count.
    row_indexes = [index[0] for index, _ in tiles]
    col_indexes = [index[1] for index, _ in tiles]
    n_rows = max(row_indexes) + 1
    n_cols = max(col_indexes) + 1

    total_width = sum(win.width for _, win in tiles)
    total_height = sum(win.height for _, win in tiles)

    assert total_width == width * n_rows
    assert total_height == height * n_cols

0 comments on commit 9c74729

Please sign in to comment.