Skip to content

Commit

Permalink
Merge pull request #503 from jbouffard/feature/merge
Browse files Browse the repository at this point in the history
RasterLayer.merge Method
  • Loading branch information
Jacob Bouffard committed Sep 25, 2017
2 parents 722db90 + bea52c4 commit 6c005a8
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,10 @@ abstract class RasterLayer[K](implicit ev0: ClassTag[K], ev1: Component[K, Proje
protected def reproject(targetCRS: String, layoutType: LayoutType, resampleMethod: ResampleMethod): TiledRasterLayer[_]
protected def reproject(targetCRS: String, layoutDefinition: LayoutDefinition, resampleMethod: ResampleMethod): TiledRasterLayer[_]
protected def withRDD(result: RDD[(K, MultibandTile)]): RasterLayer[K]

/** Merges the tiles that share a key so the resulting layer holds one tile per key.
  *
  * @param numPartitions number of partitions to hash-partition the merged layer into.
  *                      May be `null`, in which case the merge runs without an explicit
  *                      partitioner. (NOTE(review): the Python wrapper passes its
  *                      `num_partitions=None` default straight through, which presumably
  *                      arrives here as `null` -- confirm.)
  * @return a new layer wrapping the merged RDD
  */
def merge(numPartitions: Integer): RasterLayer[K] = {
  // Option(...) turns a null argument into None, so no pattern has to match on null itself.
  val merged = Option(numPartitions) match {
    case Some(n) => rdd.merge(Some(new HashPartitioner(n)))
    case None => rdd.merge()
  }

  withRDD(merged)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,12 @@ class TemporalRasterLayer(val rdd: RDD[(TemporalProjectedExtent, MultibandTile)]
rdd
.filter { case (key, _) => key.instant == instant }
.map { x => (x._1.projectedExtent, x._2) }
.merge()

ProjectedRasterLayer(spatialRDD)
}

def toSpatialLayer(): ProjectedRasterLayer =
ProjectedRasterLayer(rdd.map { x => (x._1.projectedExtent, x._2) }.merge())
ProjectedRasterLayer(rdd.map { x => (x._1.projectedExtent, x._2) })

/** Collects every key of this layer and converts them with `PythonTranslator.toPython`.
  *
  * @return the converted keys, one byte array per key
  */
def collectKeys(): java.util.ArrayList[Array[Byte]] = {
  // Pull all of the keys back from the RDD before translating them.
  val keys = rdd.keys.collect
  PythonTranslator.toPython[TemporalProjectedExtent, ProtoTemporalProjectedExtent](keys)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,6 @@ class TemporalTiledRasterLayer(
rdd
.filter { case (key, _) => key.instant == instant }
.map { x => (x._1.spatialKey, x._2) }
.merge()

val (minKey, maxKey) = (spatialRDD.keys.min(), spatialRDD.keys.max())

Expand All @@ -377,7 +376,8 @@ class TemporalTiledRasterLayer(
}

def toSpatialLayer(): SpatialTiledRasterLayer = {
val spatialRDD = rdd.map { x => (x._1.spatialKey, x._2) }.merge()
val spatialRDD = rdd.map { x => (x._1.spatialKey, x._2) }

val bounds = rdd.metadata.bounds.get
val spatialMetadata =
rdd.metadata.copy(bounds = Bounds(bounds.minKey.spatialKey, bounds.maxKey.spatialKey))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,19 @@ abstract class TiledRasterLayer[K: SpatialComponent: JsonFormat: ClassTag: Bound
withRDD(result.mapValues { tiles => MultibandTile(tiles) } )
}

/** Merges the tiles that share a key so the resulting layer holds one tile per key.
  *
  * The layer's metadata is carried over unchanged onto the merged tiles.
  *
  * @param numPartitions number of partitions to hash-partition the merged layer into.
  *                      May be `null`, in which case the merge runs without an explicit
  *                      partitioner. (NOTE(review): the Python wrapper passes its
  *                      `num_partitions=None` default straight through, which presumably
  *                      arrives here as `null` -- confirm.)
  * @return a new tiled layer wrapping the merged tiles and the existing metadata
  */
def merge(numPartitions: Integer): TiledRasterLayer[K] = {
  // NOTE(review): the cast presumably exists so that the `merge` extension methods resolve
  // against the plain keyed RDD rather than the metadata-carrying type -- confirm.
  val tiles = rdd.asInstanceOf[RDD[(K, MultibandTile)]]

  // Option(...) turns a null argument into None, so no pattern has to match on null itself.
  val merged = Option(numPartitions) match {
    case Some(n) => tiles.merge(Some(new HashPartitioner(n)))
    case None => tiles.merge()
  }

  // Wrap once, after the branch, instead of repeating the ContextRDD construction per case.
  withRDD(ContextRDD(merged, rdd.metadata))
}

/** Reports whether the cell type recorded in this layer's metadata is floating point. */
def isFloatingPointLayer(): Boolean = {
  val cellType = rdd.metadata.cellType
  cellType.isFloatingPoint
}

protected def withRDD(result: RDD[(K, MultibandTile)]): TiledRasterLayer[K]
Expand Down
52 changes: 52 additions & 0 deletions geopyspark/geotrellis/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,32 @@ def collect_keys(self):
else:
return [temporal_projected_extent_decoder(key) for key in self.srdd.collectKeys()]

def merge(self, num_partitions=None):
"""Merges the ``Tile`` of each ``K`` together to produce a single ``Tile``.
This method will reduce each value by its key within the layer to produce a single
``(K, V)`` for every ``K``. In order to achieve this, each ``Tile`` that shares a
``K`` is merged together to form a single ``Tile``. This is done by replacing
one ``Tile``\'s cells with another's. Not all cells, if any, may be replaced, however.
The following steps are taken to determine if a cell's value should be replaced:
1. If the cell contains a ``NoData`` value, then it will be replaced.
2. If no ``NoData`` value is set, then a cell with a value of 0 will be replaced.
3. If neither of the above is true, then the cell retains its value.
Args:
num_partitions (int, optional): The number of partitions that the resulting
layer should be partitioned with. If ``None``, then the ``num_partitions``
will be the number of partitions the layer currently has.
Returns:
:class:`~geopyspark.geotrellis.layer.RasterLayer`
"""

result = self.srdd.merge(num_partitions)

return RasterLayer(self.layer_type, result)

def collect_metadata(self, layout=LocalLayout()):
"""Iterate over the RDD records and generates layer metadata desribing the contained
rasters.
Expand Down Expand Up @@ -863,6 +889,32 @@ def collect_keys(self):
else:
return [space_time_key_decoder(key) for key in self.srdd.collectKeys()]

def merge(self, num_partitions=None):
"""Merges the ``Tile`` of each ``K`` together to produce a single ``Tile``.
This method will reduce each value by its key within the layer to produce a single
``(K, V)`` for every ``K``. In order to achieve this, each ``Tile`` that shares a
``K`` is merged together to form a single ``Tile``. This is done by replacing
one ``Tile``\'s cells with another's. Not all cells, if any, may be replaced, however.
The following steps are taken to determine if a cell's value should be replaced:
1. If the cell contains a ``NoData`` value, then it will be replaced.
2. If no ``NoData`` value is set, then a cell with a value of 0 will be replaced.
3. If neither of the above is true, then the cell retains its value.
Args:
num_partitions (int, optional): The number of partitions that the resulting
layer should be partitioned with. If ``None``, then the ``num_partitions``
will be the number of partitions the layer currently has.
Returns:
:class:`~geopyspark.geotrellis.layer.TiledRasterLayer`
"""

result = self.srdd.merge(num_partitions)

return TiledRasterLayer(self.layer_type, result)

def bands(self, band):
"""Select a subsection of bands from the ``Tile``\s within the layer.
Expand Down
157 changes: 157 additions & 0 deletions geopyspark/tests/merge_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import datetime
import unittest
import pytest
import numpy as np

from geopyspark.geotrellis import (ProjectedExtent,
Extent,
TemporalProjectedExtent,
SpatialKey,
SpaceTimeKey,
Tile,
Bounds,
TileLayout,
LayoutDefinition,
Metadata)
from geopyspark.tests.base_test_class import BaseTestClass
from geopyspark.geotrellis.layer import RasterLayer, TiledRasterLayer
from geopyspark.geotrellis.constants import LayerType


class MergeTest(BaseTestClass):
    """Exercises ``merge`` on raster layers built from each of the four key types.

    Every test builds a four-record layer in which each of two keys appears twice,
    merges it, and expects two records whose cells all equal ``arr_2`` (all ones).
    """

    # Single-band 4x4 arrays: one of zeros, one of ones.
    arr_1 = np.zeros((1, 4, 4))
    arr_2 = np.ones((1, 4, 4))

    tile_1 = Tile.from_numpy_array(arr_1)
    tile_2 = Tile.from_numpy_array(arr_2)

    crs = 4326
    time = datetime.datetime.strptime("2016-08-24T09:00:00Z", '%Y-%m-%dT%H:%M:%SZ')

    extents = [
        Extent(0.0, 0.0, 4.0, 4.0),
        Extent(0.0, 4.0, 4.0, 8.0),
    ]

    # Metadata ingredients shared by the two tiled-layer tests.
    extent = Extent(0.0, 0.0, 8.0, 8.0)
    layout = TileLayout(2, 2, 5, 5)

    ct = 'float32ud-1.0'
    md_proj = '+proj=longlat +datum=WGS84 +no_defs '
    ld = LayoutDefinition(extent, layout)

    @pytest.fixture(autouse=True)
    def tearDown(self):
        # Runs around every test: yield to the test body, then close the gateway.
        yield
        BaseTestClass.pysc._gateway.close()

    def _metadata_for(self, keys):
        """Builds layer ``Metadata`` whose bounds span ``keys[0]`` to ``keys[1]``."""
        return Metadata(bounds=Bounds(keys[0], keys[1]),
                        crs=self.md_proj,
                        cell_type=self.ct,
                        extent=self.extent,
                        layout_definition=self.ld)

    def _assert_merged_to_ones(self, layer):
        """Merges ``layer`` and checks for two records whose cells all equal ``arr_2``."""
        merged = layer.merge()

        self.assertEqual(merged.srdd.rdd().count(), 2)

        for _, tile in merged.to_numpy_rdd().collect():
            self.assertTrue((tile.cells == self.arr_2).all())

    def test_projected_extent(self):
        keys = [ProjectedExtent(extent=ext, epsg=self.crs) for ext in self.extents]

        # Key-major order: both tiles for the first key, then both for the second.
        records = [(key, tile) for key in keys for tile in (self.tile_1, self.tile_2)]

        layer = RasterLayer.from_numpy_rdd(LayerType.SPATIAL, self.pysc.parallelize(records))

        self._assert_merged_to_ones(layer)

    def test_temporal_projected_extent(self):
        keys = [
            TemporalProjectedExtent(extent=ext, epsg=self.crs, instant=self.time)
            for ext in self.extents
        ]

        # Tile-major order: the zeros tile for both keys, then the ones tile for both.
        records = [(key, tile) for tile in (self.tile_1, self.tile_2) for key in keys]

        layer = RasterLayer.from_numpy_rdd(LayerType.SPACETIME, self.pysc.parallelize(records))

        self._assert_merged_to_ones(layer)

    def test_spatial_keys(self):
        keys = [SpatialKey(0, 0), SpatialKey(0, 1)]

        # Tile-major order: the zeros tile for both keys, then the ones tile for both.
        records = [(key, tile) for tile in (self.tile_1, self.tile_2) for key in keys]

        layer = TiledRasterLayer.from_numpy_rdd(LayerType.SPATIAL,
                                                self.pysc.parallelize(records),
                                                self._metadata_for(keys))

        self._assert_merged_to_ones(layer)

    def test_space_time_keys(self):
        keys = [
            SpaceTimeKey(0, 0, instant=self.time),
            SpaceTimeKey(0, 1, instant=self.time)
        ]

        # Unlike the other tests, every record here carries the ones tile.
        records = [(key, tile) for tile in (self.tile_2, self.tile_2) for key in keys]

        layer = TiledRasterLayer.from_numpy_rdd(LayerType.SPACETIME,
                                                self.pysc.parallelize(records),
                                                self._metadata_for(keys))

        self._assert_merged_to_ones(layer)

if __name__ == "__main__":
unittest.main()

0 comments on commit 6c005a8

Please sign in to comment.