Skip to content

Commit

Permalink
Added CollectNeighbors
Browse files Browse the repository at this point in the history
  • Loading branch information
lossyrob committed Dec 1, 2016
1 parent 641fb14 commit 4e03093
Show file tree
Hide file tree
Showing 7 changed files with 368 additions and 80 deletions.
134 changes: 63 additions & 71 deletions spark/src/main/scala/geotrellis/spark/buffer/BufferTiles.scala
Expand Up @@ -29,19 +29,11 @@ import scala.reflect.ClassTag
import scala.collection.mutable.ArrayBuffer

object BufferTiles {
sealed trait Direction

case object Center extends Direction
case object Top extends Direction
case object TopRight extends Direction
case object Right extends Direction
case object BottomRight extends Direction
case object Bottom extends Direction
case object BottomLeft extends Direction
case object Left extends Direction
case object TopLeft extends Direction

def collectWithNeighbors[K: SpatialComponent, V <: CellGrid: (? => CropMethods[V])](

/** Collects tile neighbors by slicing the neighboring tiles to the given
* buffer size
*/
def collectWithTileNeighbors[K: SpatialComponent, V <: CellGrid: (? => CropMethods[V])](
key: K,
tile: V,
includeKey: SpatialKey => Boolean,
Expand All @@ -60,33 +52,33 @@ object BufferTiles {

val part: V =
direction match {
case Center => tile
case Right => tile.crop(0, 0, bufferSizes.right - 1, rows - 1, Crop.Options(force = true))
case Left => tile.crop(cols - bufferSizes.left, 0, cols - 1, rows - 1, Crop.Options(force = true))
case Top => tile.crop(0, rows - bufferSizes.top, cols - 1, rows - 1, Crop.Options(force = true))
case Bottom => tile.crop(0, 0, cols - 1, bufferSizes.bottom - 1, Crop.Options(force = true))
case TopLeft => tile.crop(cols - bufferSizes.left, rows - bufferSizes.top, cols - 1, rows - 1, Crop.Options(force = true))
case TopRight => tile.crop(0, rows - bufferSizes.top, bufferSizes.right - 1, rows - 1, Crop.Options(force = true))
case BottomLeft => tile.crop(cols - bufferSizes.left, 0, cols - 1, bufferSizes.bottom - 1, Crop.Options(force = true))
case BottomRight => tile.crop(0, 0, bufferSizes.right - 1, bufferSizes.bottom - 1, Crop.Options(force = true))
case CenterDirection => tile
case RightDirection => tile.crop(0, 0, bufferSizes.right - 1, rows - 1, Crop.Options(force = true))
case LeftDirection => tile.crop(cols - bufferSizes.left, 0, cols - 1, rows - 1, Crop.Options(force = true))
case TopDirection => tile.crop(0, rows - bufferSizes.top, cols - 1, rows - 1, Crop.Options(force = true))
case BottomDirection => tile.crop(0, 0, cols - 1, bufferSizes.bottom - 1, Crop.Options(force = true))
case TopLeftDirection => tile.crop(cols - bufferSizes.left, rows - bufferSizes.top, cols - 1, rows - 1, Crop.Options(force = true))
case TopRightDirection => tile.crop(0, rows - bufferSizes.top, bufferSizes.right - 1, rows - 1, Crop.Options(force = true))
case BottomLeftDirection => tile.crop(cols - bufferSizes.left, 0, cols - 1, bufferSizes.bottom - 1, Crop.Options(force = true))
case BottomRightDirection => tile.crop(0, 0, bufferSizes.right - 1, bufferSizes.bottom - 1, Crop.Options(force = true))
}

parts += ( (key.setComponent(spatialKey), (direction, part)) )
}
}

// ex: A tile that contributes to the top (tile above it) will give up it's top slice, which will be placed at the bottom of the target focal window
addSlice(SpatialKey(col,row), Center)
addSlice(SpatialKey(col,row), CenterDirection)

addSlice(SpatialKey(col-1, row), Right)
addSlice(SpatialKey(col+1, row), Left)
addSlice(SpatialKey(col, row-1), Bottom)
addSlice(SpatialKey(col, row+1), Top)
addSlice(SpatialKey(col-1, row), RightDirection)
addSlice(SpatialKey(col+1, row), LeftDirection)
addSlice(SpatialKey(col, row-1), BottomDirection)
addSlice(SpatialKey(col, row+1), TopDirection)

addSlice(SpatialKey(col-1, row-1), BottomRight)
addSlice(SpatialKey(col+1, row-1), BottomLeft)
addSlice(SpatialKey(col+1, row+1), TopLeft)
addSlice(SpatialKey(col-1, row+1), TopRight)
addSlice(SpatialKey(col-1, row-1), BottomRightDirection)
addSlice(SpatialKey(col+1, row-1), BottomLeftDirection)
addSlice(SpatialKey(col+1, row+1), TopLeftDirection)
addSlice(SpatialKey(col-1, row+1), TopRightDirection)

parts
}
Expand All @@ -97,20 +89,20 @@ object BufferTiles {
](rdd: RDD[(K, Iterable[(Direction, V)])]): RDD[(K, BufferedTile[V])] = {
val r = rdd
.flatMapValues { neighbors =>
neighbors.find( _._1 == Center) map { case (_, centerTile) =>
neighbors.find( _._1 == CenterDirection) map { case (_, centerTile) =>

val bufferSizes =
neighbors.foldLeft(BufferSizes(0, 0, 0, 0)) { (acc, tup) =>
val (direction, slice) = tup
direction match {
case Left => acc.copy(left = slice.cols)
case Right => acc.copy(right = slice.cols)
case Top => acc.copy(top = slice.rows)
case Bottom => acc.copy(bottom = slice.rows)
case BottomRight => acc.copy(bottom = slice.rows, right = slice.cols)
case BottomLeft => acc.copy(bottom = slice.rows, left = slice.cols)
case TopRight => acc.copy(top = slice.rows, right = slice.cols)
case TopLeft => acc.copy(top = slice.rows, left = slice.cols)
case LeftDirection => acc.copy(left = slice.cols)
case RightDirection => acc.copy(right = slice.cols)
case TopDirection => acc.copy(top = slice.rows)
case BottomDirection => acc.copy(bottom = slice.rows)
case BottomRightDirection => acc.copy(bottom = slice.rows, right = slice.cols)
case BottomLeftDirection => acc.copy(bottom = slice.rows, left = slice.cols)
case TopRightDirection => acc.copy(top = slice.rows, right = slice.cols)
case TopLeftDirection => acc.copy(top = slice.rows, left = slice.cols)
case _ => acc
}
}
Expand All @@ -119,15 +111,15 @@ object BufferTiles {
neighbors.map { case (direction, slice) =>
val (updateColMin, updateRowMin) =
direction match {
case Center => (bufferSizes.left, bufferSizes.top)
case Left => (0, bufferSizes.top)
case Right => (bufferSizes.left + centerTile.cols, bufferSizes.top)
case Top => (bufferSizes.left, 0)
case Bottom => (bufferSizes.left, bufferSizes.top + centerTile.rows)
case TopLeft => (0, 0)
case TopRight => (bufferSizes.left + centerTile.cols, 0)
case BottomLeft => (0, bufferSizes.top + centerTile.rows)
case BottomRight => (bufferSizes.left + centerTile.cols, bufferSizes.top + centerTile.rows)
case CenterDirection => (bufferSizes.left, bufferSizes.top)
case LeftDirection => (0, bufferSizes.top)
case RightDirection => (bufferSizes.left + centerTile.cols, bufferSizes.top)
case TopDirection => (bufferSizes.left, 0)
case BottomDirection => (bufferSizes.left, bufferSizes.top + centerTile.rows)
case TopLeftDirection => (0, 0)
case TopRightDirection => (bufferSizes.left + centerTile.cols, 0)
case BottomLeftDirection => (0, bufferSizes.top + centerTile.rows)
case BottomRightDirection => (bufferSizes.left + centerTile.cols, bufferSizes.top + centerTile.rows)
}

(slice, (updateColMin, updateRowMin))
Expand All @@ -150,20 +142,20 @@ object BufferTiles {
](seq: Seq[(K, Seq[(Direction, V)])]): Seq[(K, BufferedTile[V])] = {
seq
.flatMap { case (key, neighbors) =>
val opt = neighbors.find(_._1 == Center).map { case (_, centerTile) =>
val opt = neighbors.find(_._1 == CenterDirection).map { case (_, centerTile) =>

val bufferSizes =
neighbors.foldLeft(BufferSizes(0, 0, 0, 0)) { (acc, tup) =>
val (direction, slice) = tup
direction match {
case Left => acc.copy(left = slice.cols)
case Right => acc.copy(right = slice.cols)
case Top => acc.copy(top = slice.rows)
case Bottom => acc.copy(bottom = slice.rows)
case BottomRight => acc.copy(bottom = slice.rows, right = slice.cols)
case BottomLeft => acc.copy(bottom = slice.rows, left = slice.cols)
case TopRight => acc.copy(top = slice.rows, right = slice.cols)
case TopLeft => acc.copy(top = slice.rows, left = slice.cols)
case LeftDirection => acc.copy(left = slice.cols)
case RightDirection => acc.copy(right = slice.cols)
case TopDirection => acc.copy(top = slice.rows)
case BottomDirection => acc.copy(bottom = slice.rows)
case BottomRightDirection => acc.copy(bottom = slice.rows, right = slice.cols)
case BottomLeftDirection => acc.copy(bottom = slice.rows, left = slice.cols)
case TopRightDirection => acc.copy(top = slice.rows, right = slice.cols)
case TopLeftDirection => acc.copy(top = slice.rows, left = slice.cols)
case _ => acc
}
}
Expand All @@ -172,15 +164,15 @@ object BufferTiles {
neighbors.map { case (direction, slice) =>
val (updateColMin, updateRowMin) =
direction match {
case Center => (bufferSizes.left, bufferSizes.top)
case Left => (0, bufferSizes.top)
case Right => (bufferSizes.left + centerTile.cols, bufferSizes.top)
case Top => (bufferSizes.left, 0)
case Bottom => (bufferSizes.left, bufferSizes.top + centerTile.rows)
case TopLeft => (0, 0)
case TopRight => (bufferSizes.left + centerTile.cols, 0)
case BottomLeft => (0, bufferSizes.top + centerTile.rows)
case BottomRight => (bufferSizes.left + centerTile.cols, bufferSizes.top + centerTile.rows)
case CenterDirection => (bufferSizes.left, bufferSizes.top)
case LeftDirection => (0, bufferSizes.top)
case RightDirection => (bufferSizes.left + centerTile.cols, bufferSizes.top)
case TopDirection => (bufferSizes.left, 0)
case BottomDirection => (bufferSizes.left, bufferSizes.top + centerTile.rows)
case TopLeftDirection => (0, 0)
case TopRightDirection => (bufferSizes.left + centerTile.cols, 0)
case BottomLeftDirection => (0, bufferSizes.top + centerTile.rows)
case BottomRightDirection => (bufferSizes.left + centerTile.cols, bufferSizes.top + centerTile.rows)
}

(slice, (updateColMin, updateRowMin))
Expand Down Expand Up @@ -254,7 +246,7 @@ object BufferTiles {
val tilesAndSlivers =
rdd
.flatMap { case (key, tile) =>
collectWithNeighbors(key, tile, { key => layerBounds.contains(key.col, key.row) }, { key => bufferSizes })
collectWithTileNeighbors(key, tile, { key => layerBounds.contains(key.col, key.row) }, { key => bufferSizes })
}

val grouped =
Expand Down Expand Up @@ -357,7 +349,7 @@ object BufferTiles {
rdd
.join(surroundingBufferSizes)
.flatMap { case (key, (tile, bufferSizesMap)) =>
collectWithNeighbors(key, tile, bufferSizesMap.contains _, bufferSizesMap)
collectWithTileNeighbors(key, tile, bufferSizesMap.contains _, bufferSizesMap)
}

val grouped =
Expand Down Expand Up @@ -409,7 +401,7 @@ object BufferTiles {

val grouped: Seq[(K, Seq[(Direction, V)])] =
seq.zip(surroundingBufferSizes).flatMap { case ((key, tile), (k2, bufferSizesMap)) =>
collectWithNeighbors(key, tile, bufferSizesMap.contains _, bufferSizesMap)
collectWithTileNeighbors(key, tile, bufferSizesMap.contains _, bufferSizesMap)
}.groupBy(_._1).mapValues(_.map(_._2)).toSeq

bufferWithNeighbors(grouped)
Expand Down Expand Up @@ -437,7 +429,7 @@ object BufferTiles {
val grouped: Seq[(K, Seq[(Direction, V)])] =
seq
.flatMap { case (key, tile) =>
collectWithNeighbors(key, tile, { key => layerBounds.contains(key.col, key.row) }, { key => bufferSizes })
collectWithTileNeighbors(key, tile, { key => layerBounds.contains(key.col, key.row) }, { key => bufferSizes })
}.groupBy(_._1).mapValues { _.map(_._2) }.toSeq

bufferWithNeighbors(grouped)
Expand Down
@@ -0,0 +1,72 @@
/*
* Copyright 2016 Azavea
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package geotrellis.spark.buffer

import geotrellis.spark._
import geotrellis.raster._
import geotrellis.raster.crop._
import geotrellis.raster.stitch._
import geotrellis.util._

import org.apache.spark.rdd._
import org.apache.spark.storage.StorageLevel

import scala.reflect.ClassTag
import scala.collection.mutable.ArrayBuffer

object CollectNeighbors {

/** Collects tile neighbors by slicing the neighboring tiles to the given
* buffer size
*/
def apply[K: SpatialComponent: ClassTag, V](rdd: RDD[(K, V)]): RDD[(K, Map[Direction, (K, V)])] = {
val neighbored: RDD[(K, (Direction, (K, V)))] =
rdd
.flatMap { case (key, value) =>
val SpatialKey(col, row) = key

Seq(
(key, (CenterDirection, (key, value))),

(key.setComponent(SpatialKey(col-1, row)), (RightDirection, (key, value))),
(key.setComponent(SpatialKey(col+1, row)), (LeftDirection, (key, value))),
(key.setComponent(SpatialKey(col, row-1)), (BottomDirection, (key, value))),
(key.setComponent(SpatialKey(col, row+1)), (TopDirection, (key, value))),

(key.setComponent(SpatialKey(col-1, row-1)), (BottomRightDirection, (key, value))),
(key.setComponent(SpatialKey(col+1, row-1)), (BottomLeftDirection, (key, value))),
(key.setComponent(SpatialKey(col-1, row+1)), (TopRightDirection, (key, value))),
(key.setComponent(SpatialKey(col+1, row+1)), (TopLeftDirection, (key, value)))
)
}

val grouped: RDD[(K, Iterable[(Direction, (K, V))])] =
rdd.partitioner match {
case Some(partitioner) => neighbored.groupByKey(partitioner)
case None => neighbored.groupByKey
}

grouped
.filter { case (_, values) =>
values.find {
case (CenterDirection, _) => true
case _ => false
}.isDefined
}
.mapValues(_.toMap)
}
}
@@ -0,0 +1,30 @@
/*
* Copyright 2016 Azavea
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package geotrellis.spark.buffer

import geotrellis.spark._
import geotrellis.util.MethodExtensions

import org.apache.spark.rdd.RDD

import scala.reflect.ClassTag

abstract class CollectNeighborsMethods[K: SpatialComponent: ClassTag, V](val self: RDD[(K, V)])
extends MethodExtensions[RDD[(K, V)]] {
def collectNeighbors(): RDD[(K, Map[Direction, (K, V)])] =
CollectNeighbors(self)
}
29 changes: 29 additions & 0 deletions spark/src/main/scala/geotrellis/spark/buffer/Direction.scala
@@ -0,0 +1,29 @@
/*
* Copyright 2016 Azavea
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package geotrellis.spark.buffer

sealed trait Direction

case object CenterDirection extends Direction
case object TopDirection extends Direction
case object TopRightDirection extends Direction
case object RightDirection extends Direction
case object BottomRightDirection extends Direction
case object BottomDirection extends Direction
case object BottomLeftDirection extends Direction
case object LeftDirection extends Direction
case object TopLeftDirection extends Direction
5 changes: 5 additions & 0 deletions spark/src/main/scala/geotrellis/spark/buffer/Implicits.scala
Expand Up @@ -28,6 +28,11 @@ import scala.reflect.ClassTag
object Implicits extends Implicits

trait Implicits {
implicit class withCollectNeighborsMethodsWrapper[
K: SpatialComponent: ClassTag,
V
](self: RDD[(K, V)]) extends CollectNeighborsMethods[K, V](self)

implicit class withBufferTilesMethodsWrapper[
K: SpatialComponent: ClassTag,
V <: CellGrid: Stitcher: ClassTag: (? => CropMethods[V])
Expand Down
Expand Up @@ -140,15 +140,15 @@ class KryoRegistrator extends SparkKryoRegistrator {
kryo.register(classOf[geotrellis.proj4.CRS])

// UnmodifiableCollectionsSerializer.registerSerializers(kryo)
kryo.register(geotrellis.spark.buffer.BufferTiles.Center.getClass)
kryo.register(geotrellis.spark.buffer.BufferTiles.Top.getClass)
kryo.register(geotrellis.spark.buffer.BufferTiles.Bottom.getClass)
kryo.register(geotrellis.spark.buffer.BufferTiles.Left.getClass)
kryo.register(geotrellis.spark.buffer.BufferTiles.Right.getClass)
kryo.register(geotrellis.spark.buffer.BufferTiles.TopLeft.getClass)
kryo.register(geotrellis.spark.buffer.BufferTiles.TopRight.getClass)
kryo.register(geotrellis.spark.buffer.BufferTiles.BottomLeft.getClass)
kryo.register(geotrellis.spark.buffer.BufferTiles.BottomRight.getClass)
kryo.register(geotrellis.spark.buffer.CenterDirection.getClass)
kryo.register(geotrellis.spark.buffer.TopDirection.getClass)
kryo.register(geotrellis.spark.buffer.BottomDirection.getClass)
kryo.register(geotrellis.spark.buffer.LeftDirection.getClass)
kryo.register(geotrellis.spark.buffer.RightDirection.getClass)
kryo.register(geotrellis.spark.buffer.TopLeftDirection.getClass)
kryo.register(geotrellis.spark.buffer.TopRightDirection.getClass)
kryo.register(geotrellis.spark.buffer.BottomLeftDirection.getClass)
kryo.register(geotrellis.spark.buffer.BottomRightDirection.getClass)

/* Exhaustive Registration */
kryo.register(classOf[Array[Double]])
Expand Down

0 comments on commit 4e03093

Please sign in to comment.