Skip to content

Commit

Permalink
Merge d8f3d9d into dc0307c
Browse files Browse the repository at this point in the history
  • Loading branch information
sgillies committed Feb 8, 2015
2 parents dc0307c + d8f3d9d commit 6efe688
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 0 deletions.
121 changes: 121 additions & 0 deletions rasterio/rio/calc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Calc command.

import logging
import math
import os.path
import re
import sys
import warnings

import click
from cligj import files_inout_arg

import rasterio
from rasterio.rio.cli import cli


@cli.command(short_help="Raster data calculator.")
@click.argument('command')
@files_inout_arg
@click.option('--dtype',
type=click.Choice(
['uint8', 'uint16', 'int16', 'float32', 'float64']),
help="Output data type.")
@click.pass_context
def calc(ctx, command, files, dtype):
"""A raster data calculator
Applies one or more commands to a set of input datasets and writes the
results to a new dataset.
Command syntax is a work in progress.
"""
import numpy as np

verbosity = (ctx.obj and ctx.obj.get('verbosity')) or 1
logger = logging.getLogger('rio')

try:
with rasterio.drivers(CPL_DEBUG=verbosity>2):
output = files[-1]
files = files[:-1]

with rasterio.open(files[0]) as first:
kwargs = first.meta
kwargs['transform'] = kwargs.pop('affine')

sources = [rasterio.open(path).read() for path in files]


# TODO: implement a real parser for calc expressions,
# perhaps using numexpr's parser as a guide, instead
# eval'ing any string.

parts = command.split(';')
if len(parts) == 1:

# Translates, eg, '{1}' to 'sources[0]'.
cmd = re.sub(
r'{(\d)}',
lambda m: 'sources[%d]' % (int(m.group(1))-1),
parts.pop())

logger.debug("Translated cmd: %r", cmd)

results = eval(cmd)

# Using the class method instead of instance method.
# Latter raises
# TypeError: astype() got an unexpected keyword argument 'copy'
# Possibly something to do with the instance being a masked
# array.
results = np.ndarray.astype(
results, dtype or 'float64', copy=False)

# Write results.
if len(results.shape) == 3:
kwargs.update(
count=results.shape[0],
dtype=results.dtype.type)
with rasterio.open(output, 'w', **kwargs) as dst:
dst.write(results)

elif len(results.shape) == 2:
kwargs.update(
count=1,
dtype=results.dtype.type)
with rasterio.open(output, 'w', **kwargs) as dst:
dst.write(results, 1)

else:
parts = list(filter(lambda p: p.strip(), parts))
kwargs['count'] = len(parts)

results = []
#with rasterio.open(output, 'w', **kwargs) as dst:

for part in parts:
cmd = re.sub(
r'{(\d)\s*,\s*(\d)}',
lambda m: 'sources[%d][%d]' % (
int(m.group(1))-1, int(m.group(2))-1),
part)

logger.debug("Translated cmd: %r", cmd)

res = eval(cmd)
res = np.ndarray.astype(
res, dtype or 'float64', copy=False)
results.append(res)

results = np.asanyarray(results)
kwargs.update(
count=results.shape[0],
dtype=results.dtype.type)
with rasterio.open(output, 'w', **kwargs) as dst:
dst.write(results)

sys.exit(0)
except Exception:
logger.exception("Failed. Exception caught")
sys.exit(1)
1 change: 1 addition & 0 deletions rasterio/rio/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python

from rasterio.rio.calc import calc
from rasterio.rio.cli import cli
from rasterio.rio.bands import stack
from rasterio.rio.features import shapes, rasterize
Expand Down
57 changes: 57 additions & 0 deletions tests/test_rio_calc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import sys
import logging

from click.testing import CliRunner

import rasterio
from rasterio.rio.calc import calc


logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)


def test_multiband_calc(tmpdir):
outfile = str(tmpdir.join('out.tif'))
runner = CliRunner()
result = runner.invoke(calc, [
'0.10*{1} + 125', 'tests/data/shade.tif', outfile],
catch_exceptions=False)
assert result.exit_code == 0
with rasterio.open(outfile) as src:
assert src.meta['dtype'] == 'float64'
data = src.read()
assert data.min() == 125


def test_singleband_calc(tmpdir):
outfile = str(tmpdir.join('out.tif'))
runner = CliRunner()
result = runner.invoke(calc, [
'0.10*{1,1} + 125;', 'tests/data/shade.tif', outfile],
catch_exceptions=False)
assert result.exit_code == 0
with rasterio.open(outfile) as src:
assert src.meta['dtype'] == 'float64'
data = src.read()
assert data.min() == 125


def test_parts_calc(tmpdir):
# Producing an RGB output from the hill shade.
# Red band has bumped up values. Other bands are unchanged.
outfile = str(tmpdir.join('out.tif'))
runner = CliRunner()
result = runner.invoke(calc, [
'{1,1} + 125; {1,1}; {1,1}',
'--dtype', 'uint8',
'tests/data/shade.tif',
outfile],
catch_exceptions=False)
assert result.exit_code == 0
with rasterio.open(outfile) as src:
assert src.count == 3
assert src.meta['dtype'] == 'uint8'
data = src.read()
assert data[0].min() == 125
assert data[1].min() == 0
assert data[2].min() == 0

0 comments on commit 6efe688

Please sign in to comment.