Skip to content

Commit

Permalink
Merge pull request #537 from jbouffard/bug-fix/stitch
Browse files Browse the repository at this point in the history
stitch and saveStitched Now Work With MultibandTiles
  • Loading branch information
Jacob Bouffard committed Nov 8, 2017
2 parents 0c98889 + 114ca3a commit ecfdf36
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,8 @@ class SpatialTiledRasterLayer(
def mask(geometries: Seq[MultiPolygon]): TiledRasterLayer[SpatialKey] =
SpatialTiledRasterLayer(zoomLevel, Mask(rdd, geometries, Mask.Options.DEFAULT))

def stitch: Array[Byte] = {
val contextRDD = ContextRDD(
rdd.mapValues({ v => v.band(0) }),
rdd.metadata
)

PythonTranslator.toPython[Tile, ProtoTile](contextRDD.stitch.tile)
}
def stitch: Array[Byte] =
PythonTranslator.toPython[MultibandTile, ProtoMultibandTile](ContextRDD(rdd, rdd.metadata).stitch.tile)

def saveStitched(path: String): Unit =
saveStitched(path, None, None)
Expand All @@ -199,12 +193,9 @@ class SpatialTiledRasterLayer(
cropBounds: Option[java.util.Map[String, Double]],
cropDimensions: Option[ArrayList[Int]]
): Unit = {
val contextRDD = ContextRDD(
rdd.map({ case (k, v) => (k, v.band(0)) }),
rdd.metadata
)
val contextRDD = ContextRDD(rdd, rdd.metadata)

val stitched: Raster[Tile] = contextRDD.stitch()
val stitched: Raster[MultibandTile] = contextRDD.stitch()

val adjusted = {
val cropped =
Expand Down
2 changes: 1 addition & 1 deletion geopyspark/geotrellis/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,7 +1329,7 @@ def stitch(self):
raise ValueError("Only TiledRasterLayers with a layer_type of Spatial can use stitch()")

value = self.srdd.stitch()
ser = ProtoBufSerializer.create_value_serializer("Tile")
ser = ProtoBufSerializer.create_value_serializer("MultibandTile")
return ser.loads(value)[0]

def save_stitched(self, path, crop_bounds=None, crop_dimensions=None):
Expand Down
10 changes: 5 additions & 5 deletions geopyspark/tests/tiled_layer_tests/stitch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ class StitchTest(BaseTestClass):
[1.0, 1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0, 0.0]]])

layer = [(SpatialKey(0, 0), Tile(cells, 'FLOAT', -1.0)),
(SpatialKey(1, 0), Tile(cells, 'FLOAT', -1.0,)),
(SpatialKey(0, 1), Tile(cells, 'FLOAT', -1.0,)),
(SpatialKey(1, 1), Tile(cells, 'FLOAT', -1.0,))]
layer = [(SpatialKey(0, 0), Tile(np.array([cells, cells]), 'FLOAT', -1.0)),
(SpatialKey(1, 0), Tile(np.array([cells, cells]), 'FLOAT', -1.0,)),
(SpatialKey(0, 1), Tile(np.array([cells, cells]), 'FLOAT', -1.0,)),
(SpatialKey(1, 1), Tile(np.array([cells, cells]), 'FLOAT', -1.0,))]
rdd = BaseTestClass.pysc.parallelize(layer)

extent = {'xmin': 0.0, 'ymin': 0.0, 'xmax': 33.0, 'ymax': 33.0}
Expand All @@ -45,7 +45,7 @@ def tearDown(self):

def test_stitch(self):
result = self.raster_rdd.stitch()
self.assertTrue(result.cells.shape == (1, 10, 10))
self.assertTrue(result.cells.shape == (2, 10, 10))


if __name__ == "__main__":
Expand Down

0 comments on commit ecfdf36

Please sign in to comment.