diff --git a/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala b/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala index be0a511e1..9f7101318 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala @@ -405,6 +405,9 @@ trait RasterFunctions { /** Cellwise inequality comparison between a tile and a scalar. */ def rf_local_unequal[T: Numeric](tileCol: Column, value: T): Column = Unequal(tileCol, value) + /** Test if each cell value is in provided array */ + def rf_local_is_in(tileCol: Column, arrayCol: Column) = IsIn(tileCol, arrayCol) + /** Return a tile with ones where the input is NoData, otherwise zero */ def rf_local_no_data(tileCol: Column): Column = Undefined(tileCol) diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/IsIn.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/IsIn.scala new file mode 100644 index 000000000..84008acbd --- /dev/null +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/localops/IsIn.scala @@ -0,0 +1,88 @@ +/* + * This software is licensed under the Apache 2 license, quoted below. + * + * Copyright 2019 Astraea, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * [http://www.apache.org/licenses/LICENSE-2.0] + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.locationtech.rasterframes.expressions.localops
+
+import geotrellis.raster.Tile
+import geotrellis.raster.mapalgebra.local.IfCell
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.types.{ArrayType, DataType}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionDescription}
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.rf.TileUDT
+import org.locationtech.rasterframes.encoders.CatalystSerializer._
+import org.locationtech.rasterframes.expressions.DynamicExtractors._
+import org.locationtech.rasterframes.expressions._
+
+@ExpressionDescription(
+  usage = "_FUNC_(tile, rhs) - In each cell of `tile`, return 1 if the value is in `rhs`, otherwise 0.",
+  arguments = """
+  Arguments:
+    * tile - tile column whose cell values are tested for membership
+    * rhs - array to test against
+  """,
+  examples = """
+  Examples:
+    > SELECT _FUNC_(tile, array(33, 66, 99));
+       ..."""
+)
+case class IsIn(left: Expression, right: Expression) extends BinaryExpression with CodegenFallback {
+  override val nodeName: String = "rf_local_is_in"
+
+  override def dataType: DataType = left.dataType
+
+  @transient private lazy val elementType: DataType = right.dataType.asInstanceOf[ArrayType].elementType
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    if(!tileExtractor.isDefinedAt(left.dataType)) {
+      TypeCheckFailure(s"Input type '${left.dataType}' does not conform to a raster type.")
+    } else right.dataType match {
+      case _: ArrayType ⇒ TypeCheckSuccess
+      case _ ⇒ TypeCheckFailure(s"Input type '${right.dataType}' does not conform to ArrayType.")
+    }
+
+  override protected def nullSafeEval(input1: Any, input2: Any): Any = {
+    
implicit val tileSer = TileUDT.tileSerializer + val (childTile, childCtx) = tileExtractor(left.dataType)(row(input1)) + + val arr = input2.asInstanceOf[ArrayData].toArray[AnyRef](elementType) + + childCtx match { + case Some(ctx) => ctx.toProjectRasterTile(op(childTile, arr)).toInternalRow + case None => op(childTile, arr).toInternalRow + } + + } + + protected def op(left: Tile, right: IndexedSeq[AnyRef]): Tile = { + def fn(i: Int): Boolean = right.contains(i) + IfCell(left, fn(_), 1, 0) + } + +} + +object IsIn { + def apply(left: Column, right: Column): Column = + new Column(IsIn(left.expr, right.expr)) +} diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala index f97fea6f6..d289242bc 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala @@ -86,6 +86,7 @@ package object expressions { registry.registerExpression[GreaterEqual]("rf_local_greater_equal") registry.registerExpression[Equal]("rf_local_equal") registry.registerExpression[Unequal]("rf_local_unequal") + registry.registerExpression[IsIn]("rf_local_is_in") registry.registerExpression[Undefined]("rf_local_no_data") registry.registerExpression[Defined]("rf_local_data") registry.registerExpression[Sum]("rf_tile_sum") diff --git a/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala index f5256a32f..b424a730f 100644 --- a/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala +++ b/core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala @@ -972,4 +972,28 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers { val dResult = df.select($"ld").as[Tile].first() dResult should be (randNDPRT.localDefined()) } + + it("should check 
values isin"){ + checkDocs("rf_local_is_in") + + // tile is 3 by 3 with values, 1 to 9 + val df = Seq(byteArrayTile).toDF("t") + .withColumn("one", lit(1)) + .withColumn("five", lit(5)) + .withColumn("ten", lit(10)) + .withColumn("in_expect_2", rf_local_is_in($"t", array($"one", $"five"))) + .withColumn("in_expect_1", rf_local_is_in($"t", array($"ten", $"five"))) + .withColumn("in_expect_0", rf_local_is_in($"t", array($"ten"))) + + val e2Result = df.select(rf_tile_sum($"in_expect_2")).as[Double].first() + e2Result should be (2.0) + + val e1Result = df.select(rf_tile_sum($"in_expect_1")).as[Double].first() + e1Result should be (1.0) + + val e0Result = df.select($"in_expect_0").as[Tile].first() + e0Result.toArray() should contain only (0) + +// lazy val invalid = df.select(rf_local_is_in($"t", lit("foobar"))).as[Tile].first() + } } diff --git a/pyrasterframes/src/main/python/docs/reference.pymd b/docs/src/main/paradox/reference.md similarity index 97% rename from pyrasterframes/src/main/python/docs/reference.pymd rename to docs/src/main/paradox/reference.md index 061c6b9f4..728c21ff6 100644 --- a/pyrasterframes/src/main/python/docs/reference.pymd +++ b/docs/src/main/paradox/reference.md @@ -192,7 +192,7 @@ Parameters `tile_columns` and `tile_rows` are literals, not column expressions. Tile rf_array_to_tile(Array arrayCol, Int numCols, Int numRows) -Python only. Create a `tile` from a Spark SQL [Array](http://spark.apache.org/docs/2.3.2/api/python/pyspark.sql.html#pyspark.sql.types.ArrayType), filling values in row-major order. +Python only. Create a `tile` from a Spark SQL [Array][Array], filling values in row-major order. ### rf_assemble_tile @@ -383,6 +383,13 @@ Returns a `tile` column containing the element-wise equality of `tile1` and `rhs Returns a `tile` column containing the element-wise inequality of `tile1` and `rhs`. 
+### rf_local_is_in + + Tile rf_local_is_in(Tile tile, Array array) + Tile rf_local_is_in(Tile tile, list l) + +Returns a `tile` column with cell values of 1 where the `tile` cell value is in the provided array or list. The `array` is a Spark SQL [Array][Array]. A python `list` of numeric values can also be passed. + ### rf_round Tile rf_round(Tile tile) @@ -630,13 +637,13 @@ Python only. As with @ref:[`rf_explode_tiles`](reference.md#rf-explode-tiles), b Array rf_tile_to_array_int(Tile tile) -Convert Tile column to Spark SQL [Array](http://spark.apache.org/docs/2.3.2/api/python/pyspark.sql.html#pyspark.sql.types.ArrayType), in row-major order. Float cell types will be coerced to integral type by flooring. +Convert Tile column to Spark SQL [Array][Array], in row-major order. Float cell types will be coerced to integral type by flooring. ### rf_tile_to_array_double Array rf_tile_to_arry_double(Tile tile) -Convert tile column to Spark [Array](http://spark.apache.org/docs/2.3.2/api/python/pyspark.sql.html#pyspark.sql.types.ArrayType), in row-major order. Integral cell types will be coerced to floats. +Convert tile column to Spark [Array][Array], in row-major order. Integral cell types will be coerced to floats. ### rf_render_ascii @@ -666,3 +673,4 @@ Runs [`rf_rgb_composite`](reference.md#rf-rgb-composite) on the given tile colum [RasterFunctions]: org.locationtech.rasterframes.RasterFunctions [scaladoc]: latest/api/index.html +[Array]: http://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.types.ArrayType diff --git a/docs/src/main/paradox/release-notes.md b/docs/src/main/paradox/release-notes.md index bb02af4cf..5a9c70c5b 100644 --- a/docs/src/main/paradox/release-notes.md +++ b/docs/src/main/paradox/release-notes.md @@ -6,6 +6,7 @@ * _Breaking_ (potentially): removed `GeoTiffCollectionRelation` due to usage limitation and overlap with `RasterSourceDataSource` functionality. 
* Upgraded to Spark 2.4.4 + * Add `rf_local_is_in` raster function ### 0.8.3 diff --git a/pyrasterframes/src/main/python/docs/nodata-handling.pymd b/pyrasterframes/src/main/python/docs/nodata-handling.pymd index c9fffe390..df7c30804 100644 --- a/pyrasterframes/src/main/python/docs/nodata-handling.pymd +++ b/pyrasterframes/src/main/python/docs/nodata-handling.pymd @@ -105,32 +105,23 @@ Drawing on @ref:[local map algebra](local-algebra.md) techniques, we will create ```python, def_mask from pyspark.sql.functions import lit -mask_part = unmasked.withColumn('nodata', rf_local_equal('scl', lit(0))) \ - .withColumn('defect', rf_local_equal('scl', lit(1))) \ - .withColumn('cloud8', rf_local_equal('scl', lit(8))) \ - .withColumn('cloud9', rf_local_equal('scl', lit(9))) \ - .withColumn('cirrus', rf_local_equal('scl', lit(10))) - -one_mask = mask_part.withColumn('mask', rf_local_add('nodata', 'defect')) \ - .withColumn('mask', rf_local_add('mask', 'cloud8')) \ - .withColumn('mask', rf_local_add('mask', 'cloud9')) \ - .withColumn('mask', rf_local_add('mask', 'cirrus')) - -cell_types = one_mask.select(rf_cell_type('mask')).distinct() +mask = unmasked.withColumn('mask', rf_local_is_in('scl', [0, 1, 8, 9, 10])) + +cell_types = mask.select(rf_cell_type('mask')).distinct() cell_types ``` Because there is not a NoData already defined, we will choose one. In this particular example, the minimum value is greater than zero, so we can use 0 as the NoData value. ```python, pick_nd -blue_min = one_mask.agg(rf_agg_stats('blue').min.alias('blue_min')) +blue_min = mask.agg(rf_agg_stats('blue').min.alias('blue_min')) blue_min ``` We can now construct the cell type string for our blue band's cell type, designating 0 as NoData. 
```python, get_ct_string -blue_ct = one_mask.select(rf_cell_type('blue')).distinct().first()[0][0] +blue_ct = mask.select(rf_cell_type('blue')).distinct().first()[0][0] masked_blue_ct = CellType(blue_ct).with_no_data_value(0) masked_blue_ct.cell_type_name ``` @@ -139,9 +130,8 @@ Now we will use the @ref:[`rf_mask_by_value`](reference.md#rf-mask-by-value) to ```python, mask_blu with_nd = rf_convert_cell_type('blue', masked_blue_ct) -masked = one_mask.withColumn('blue_masked', - rf_mask_by_value(with_nd, 'mask', lit(1))) \ - .drop('nodata', 'defect', 'cloud8', 'cloud9', 'cirrus', 'blue') +masked = mask.withColumn('blue_masked', + rf_mask_by_value(with_nd, 'mask', lit(1))) ``` We can verify that the number of NoData cells in the resulting `blue_masked` column matches the total of the boolean `mask` _tile_ to ensure our logic is correct. diff --git a/pyrasterframes/src/main/python/docs/supervised-learning.pymd b/pyrasterframes/src/main/python/docs/supervised-learning.pymd index c66697032..81a81f634 100644 --- a/pyrasterframes/src/main/python/docs/supervised-learning.pymd +++ b/pyrasterframes/src/main/python/docs/supervised-learning.pymd @@ -32,7 +32,8 @@ catalog_df = pd.DataFrame([ {b: uri_base.format(b) for b in cols} ]) -df = spark.read.raster(catalog_df, catalog_col_names=cols, tile_dimensions=(128, 128)) \ +tile_size = 256 +df = spark.read.raster(catalog_df, catalog_col_names=cols, tile_dimensions=(tile_size, tile_size)) \ .repartition(100) df = df.select( @@ -91,23 +92,12 @@ To filter only for good quality pixels, we follow roughly the same procedure as ```python, make_mask from pyspark.sql.functions import lit -mask_part = df_labeled \ - .withColumn('nodata', rf_local_equal('scl', lit(0))) \ - .withColumn('defect', rf_local_equal('scl', lit(1))) \ - .withColumn('cloud8', rf_local_equal('scl', lit(8))) \ - .withColumn('cloud9', rf_local_equal('scl', lit(9))) \ - .withColumn('cirrus', rf_local_equal('scl', lit(10))) - -df_mask_inv = mask_part \ - 
.withColumn('mask', rf_local_add('nodata', 'defect')) \ - .withColumn('mask', rf_local_add('mask', 'cloud8')) \ - .withColumn('mask', rf_local_add('mask', 'cloud9')) \ - .withColumn('mask', rf_local_add('mask', 'cirrus')) \ - .drop('nodata', 'defect', 'cloud8', 'cloud9', 'cirrus') - +df_labeled = df_labeled \ + .withColumn('mask', rf_local_is_in('scl', [0, 1, 8, 9, 10])) + # at this point the mask contains 0 for good cells and 1 for defect, etc # convert cell type and set value 1 to NoData -df_mask = df_mask_inv.withColumn('mask', +df_mask = df_labeled.withColumn('mask', rf_with_no_data(rf_convert_cell_type('mask', 'uint8'), 1.0) ) @@ -204,29 +194,35 @@ scored = model.transform(df_mask.drop('label')) retiled = scored \ .groupBy('extent', 'crs') \ .agg( - rf_assemble_tile('column_index', 'row_index', 'prediction', 128, 128).alias('prediction'), - rf_assemble_tile('column_index', 'row_index', 'B04', 128, 128).alias('red'), - rf_assemble_tile('column_index', 'row_index', 'B03', 128, 128).alias('grn'), - rf_assemble_tile('column_index', 'row_index', 'B02', 128, 128).alias('blu') + rf_assemble_tile('column_index', 'row_index', 'prediction', tile_size, tile_size).alias('prediction'), + rf_assemble_tile('column_index', 'row_index', 'B04', tile_size, tile_size).alias('red'), + rf_assemble_tile('column_index', 'row_index', 'B03', tile_size, tile_size).alias('grn'), + rf_assemble_tile('column_index', 'row_index', 'B02', tile_size, tile_size).alias('blu') ) retiled.printSchema() ``` Take a look at a sample of the resulting output and the corresponding area's red-green-blue composite image. +Recall the label coding: 1 is forest (purple), 2 is cropland (green) and 3 is developed areas(yellow). 
```python, display_rgb sample = retiled \ - .select('prediction', rf_rgb_composite('red', 'grn', 'blu').alias('rgb')) \ + .select('prediction', 'red', 'grn', 'blu') \ .sort(-rf_tile_sum(rf_local_equal('prediction', lit(3.0)))) \ .first() -sample_rgb = sample['rgb'] -mins = np.nanmin(sample_rgb.cells, axis=(0,1)) -plt.imshow((sample_rgb.cells - mins) / (np.nanmax(sample_rgb.cells, axis=(0,1)) - mins)) -``` +sample_rgb = np.concatenate([sample['red'].cells[:, :, None], + sample['grn'].cells[ :, :, None], + sample['blu'].cells[ :, :, None]], axis=2) +# plot scaled RGB +scaling_quantiles = np.nanpercentile(sample_rgb, [3.00, 97.00], axis=(0,1)) +scaled = np.clip(sample_rgb, scaling_quantiles[0, :], scaling_quantiles[1, :]) +scaled -= scaling_quantiles[0, :] +scaled /= (scaling_quantiles[1, : ] - scaling_quantiles[0, :]) -Recall the label coding: 1 is forest (purple), 2 is cropland (green) and 3 is developed areas(yellow). +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5)) +ax1.imshow(scaled) -```python, display_prediction -display(sample['prediction']) +# display prediction +ax2.imshow(sample['prediction'].cells) ``` diff --git a/pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py b/pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py index 9f9d5225f..6848a304c 100644 --- a/pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py +++ b/pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py @@ -260,14 +260,24 @@ def rf_local_unequal_int(tile_col, scalar): """Return a Tile with values equal 1 if the cell is not equal to a scalar, otherwise 0""" return _apply_scalar_to_tile('rf_local_unequal_int', tile_col, scalar) + def rf_local_no_data(tile_col): """Return a tile with ones where the input is NoData, otherwise zero.""" return _apply_column_function('rf_local_no_data', tile_col) + def rf_local_data(tile_col): """Return a tile with zeros where the input is NoData, otherwise one.""" return 
_apply_column_function('rf_local_data', tile_col) +def rf_local_is_in(tile_col, array): + """Return a tile with cell values of 1 where the `tile_col` cell is in the provided array.""" + from pyspark.sql.functions import array as sql_array, lit + if isinstance(array, list): + array = sql_array([lit(v) for v in array]) + + return _apply_column_function('rf_local_is_in', tile_col, array) + def _apply_column_function(name, *args): jfcn = RFContext.active().lookup(name) jcols = [_to_java_column(arg) for arg in args] diff --git a/pyrasterframes/src/main/python/tests/PyRasterFramesTests.py b/pyrasterframes/src/main/python/tests/PyRasterFramesTests.py index 7cda3b997..3bb2ce491 100644 --- a/pyrasterframes/src/main/python/tests/PyRasterFramesTests.py +++ b/pyrasterframes/src/main/python/tests/PyRasterFramesTests.py @@ -131,7 +131,7 @@ def test_tile_udt_serialization(self): cells[1][1] = nd a_tile = Tile(cells, ct.with_no_data_value(nd)) round_trip = udt.fromInternal(udt.toInternal(a_tile)) - self.assertEquals(a_tile, round_trip, "round-trip serialization for " + str(ct)) + self.assertEqual(a_tile, round_trip, "round-trip serialization for " + str(ct)) schema = StructType([StructField("tile", TileUDT(), False)]) df = self.spark.createDataFrame([{"tile": a_tile}], schema) diff --git a/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py b/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py index e81b95594..f6fabf7b0 100644 --- a/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py +++ b/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py @@ -344,8 +344,11 @@ def test_rf_local_data_and_no_data(self): import numpy as np from numpy.testing import assert_equal - t = Tile(np.array([[1, 3, 4], [5, 0, 3]]), CellType.uint8().with_no_data_value(5)) - #note the convert is due to issue #188 + nd = 5 + t = Tile( + np.array([[1, 3, 4], [nd, 0, 3]]), + CellType.uint8().with_no_data_value(nd)) + # note the convert is due to issue #188 df = 
self.spark.createDataFrame([Row(t=t)])\ .withColumn('lnd', rf_convert_cell_type(rf_local_no_data('t'), 'uint8')) \ .withColumn('ld', rf_convert_cell_type(rf_local_data('t'), 'uint8')) @@ -357,6 +360,37 @@ def test_rf_local_data_and_no_data(self): result_d = result['ld'] assert_equal(result_d.cells, np.invert(t.cells.mask)) + def test_rf_local_is_in(self): + from pyspark.sql.functions import lit, array, col + from pyspark.sql import Row + import numpy as np + from numpy.testing import assert_equal + + nd = 5 + t = Tile( + np.array([[1, 3, 4], [nd, 0, 3]]), + CellType.uint8().with_no_data_value(nd)) + # note the convert is due to issue #188 + df = self.spark.createDataFrame([Row(t=t)]) \ + .withColumn('a', array(lit(3), lit(4))) \ + .withColumn('in2', rf_convert_cell_type( + rf_local_is_in(col('t'), array(lit(0), lit(4))), + 'uint8')) \ + .withColumn('in3', rf_convert_cell_type(rf_local_is_in('t', 'a'), 'uint8')) \ + .withColumn('in4', rf_convert_cell_type( + rf_local_is_in('t', array(lit(0), lit(4), lit(3))), + 'uint8')) \ + .withColumn('in_list', rf_convert_cell_type(rf_local_is_in(col('t'), [4, 1]), 'uint8')) + + result = df.first() + self.assertEqual(result['in2'].cells.sum(), 2) + assert_equal(result['in2'].cells, np.isin(t.cells, np.array([0, 4]))) + self.assertEqual(result['in3'].cells.sum(), 3) + self.assertEqual(result['in4'].cells.sum(), 4) + self.assertEqual(result['in_list'].cells.sum(), 2, + "Tile value {} should contain two 1s as: [[1, 0, 1],[0, 0, 0]]" + .format(result['in_list'].cells)) + def test_rf_spatial_index(self): from pyspark.sql.functions import min as F_min result_one_arg = self.df.select(rf_spatial_index('tile').alias('ix')) \