Make Windows Conform To Segments, Pack Partitions
James McClain authored and echeipesh committed Oct 11, 2017
1 parent 9eabf32 commit 4c24f24
Showing 4 changed files with 85 additions and 31 deletions.
28 changes: 15 additions & 13 deletions s3/src/main/scala/geotrellis/spark/io/s3/S3GeoTiffRDD.scala
@@ -132,28 +132,29 @@ object S3GeoTiffRDD extends LazyLogging {
     lazy val sourceGeoTiffInfo = S3GeoTiffInfoReader(bucket, prefix, options)

     (options.maxTileSize, options.partitionBytes) match {
-      case (None, Some(partitionBytes)) =>
-        val segments: RDD[(String, Array[GridBounds])] =
-          sourceGeoTiffInfo.segmentsByPartitionBytes(partitionBytes, windowSize)
+      case (_, Some(partitionBytes)) => {
+        val windows: RDD[(String, Array[GridBounds])] =
+          sourceGeoTiffInfo.windowsByBytes(partitionBytes, options.maxTileSize.getOrElse(1<<10))

-        segments.persist() // StorageLevel.MEMORY_ONLY by default
-        val segmentsCount = segments.count.toInt
+        windows.persist()

-        logger.info(s"repartition into ${segmentsCount} partitions.")
+        val windowCount = windows.count.toInt
+
+        logger.info(s"Repartition into ${windowCount} partitions.")

         val repartition =
-          if(segmentsCount > segments.partitions.length) segments.repartition(segmentsCount)
-          else segments
+          if (windowCount > windows.partitions.length) windows.repartition(windowCount)
+          else windows

-        val result = repartition.flatMap { case (path, segmentBounds) =>
-          rr.readWindows(segmentBounds, sourceGeoTiffInfo.getGeoTiffInfo(path), options).map { case (k, v) =>
+        val result = repartition.flatMap { case (path, windowBounds) =>
+          rr.readWindows(windowBounds, sourceGeoTiffInfo.getGeoTiffInfo(path), options).map { case (k, v) =>
             uriToKey(new URI(path), k) -> v
           }
         }

-        segments.unpersist()
+        windows.unpersist()
         result
+      }
       case (Some(_), _) =>
         val objectRequestsToDimensions: RDD[(GetObjectRequest, (Int, Int))] =
           sc.newAPIHadoopRDD(
@@ -210,7 +211,8 @@ object S3GeoTiffRDD extends LazyLogging {
           val layout = sourceGeoTiffInfo.getGeoTiffInfo(s"s3://$bucket/$key").segmentLayout.tileLayout

           RasterReader
-            .listWindows(cols, rows, options.maxTileSize, layout.tileCols, layout.tileRows)
+            .listWindows(cols, rows, options.maxTileSize.getOrElse(1<<10), layout.tileCols, layout.tileRows)
+            ._3
             .map((objectRequest, _))
         }
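
A minimal sketch (not GeoTrellis code) of the repartitioning pattern used in the partitionBytes branch above: each element of the windows RDD already holds one partition's worth of window bounds, so the RDD is counted and repartitioned to roughly one element, and therefore one read task, per Spark partition. The object and method names below are illustrative, not library API.

import org.apache.spark.rdd.RDD

object RepartitionPerGroup {
  /** Spread pre-packed work units across Spark partitions, roughly one unit per partition. */
  def spreadOut[T](groups: RDD[T]): RDD[T] = {
    groups.persist()                       // cache so count() and later processing reuse the listing
    val n = groups.count.toInt             // number of pre-packed groups
    if (n > groups.partitions.length) groups.repartition(n) else groups
  }
}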

59 changes: 57 additions & 2 deletions spark/src/main/scala/geotrellis/spark/io/GeoTiffInfoReader.scala
@@ -20,6 +20,7 @@ import geotrellis.raster.io.geotiff.reader.GeoTiffReader
 import geotrellis.raster.io.geotiff.reader.GeoTiffReader.GeoTiffInfo
 import geotrellis.util.LazyLogging
 import geotrellis.raster.GridBounds
+import geotrellis.raster._

 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
@@ -59,15 +60,69 @@ private [geotrellis] trait GeoTiffInfoReader extends LazyLogging {
     }
   }

+  def windowsByBytes(
+    partitionBytes: Long,
+    maxSize: Int
+  )(implicit sc: SparkContext): RDD[(String, Array[GridBounds])] = {
+    geoTiffInfoRdd.flatMap({ uri =>
+      val md = getGeoTiffInfo(uri)
+      val cols = md.segmentLayout.totalCols
+      val rows = md.segmentLayout.totalRows
+      val segCols = md.segmentLayout.tileLayout.tileCols
+      val segRows = md.segmentLayout.tileLayout.tileRows
+      val cellType = md.cellType
+      val depth = {
+        cellType match {
+          case BitCellType | ByteCellType | UByteCellType | ByteConstantNoDataCellType | ByteUserDefinedNoDataCellType(_) | UByteConstantNoDataCellType | UByteUserDefinedNoDataCellType(_) => 1
+          case ShortCellType | UShortCellType | ShortConstantNoDataCellType | ShortUserDefinedNoDataCellType(_) | UShortConstantNoDataCellType | UShortUserDefinedNoDataCellType(_) => 2
+          case IntCellType | IntConstantNoDataCellType | IntUserDefinedNoDataCellType(_) => 4
+          case FloatCellType | FloatConstantNoDataCellType | FloatUserDefinedNoDataCellType(_) => 4
+          case DoubleCellType | DoubleConstantNoDataCellType | DoubleUserDefinedNoDataCellType(_) => 8
+        }
+      }
+      val (tileCols, tileRows, fileWindows) =
+        RasterReader.listWindows(cols, rows, maxSize, segCols, segRows)
+      val windowBytes = tileCols * tileRows * depth
+
+      var currentBytes = 0
+      val currentPartition = mutable.ArrayBuffer.empty[GridBounds]
+      val allPartitions = mutable.ArrayBuffer.empty[Array[GridBounds]]
+
+      fileWindows.foreach({ gb =>
+        // Add the window to the present partition
+        if (currentBytes + windowBytes <= partitionBytes) {
+          currentPartition.append(gb)
+          currentBytes += windowBytes
+        }
+        // The window is small enough to fit into some partition,
+        // but not this partition; start a new partition.
+        else if ((currentBytes + windowBytes > partitionBytes) && (windowBytes < partitionBytes)) {
+          allPartitions.append(currentPartition.toArray)
+          currentPartition.clear
+          currentPartition.append(gb)
+          currentBytes = windowBytes
+        }
+        // The window is too large to fit into any partition.
+        else {
+          allPartitions.append(Array(gb))
+        }
+      })
+      allPartitions.append(currentPartition.toArray)
+      allPartitions.toArray.map({ array => (uri, array) })
+    })
+  }
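
A standalone sketch, in plain Scala without Spark or GeoTrellis types, of the greedy packing that windowsByBytes performs above: equally sized windows are appended to the current partition until the byte budget would be exceeded, then a new partition is started, and a window larger than the budget gets a partition of its own. PackWindows and pack are illustrative names, not library API.

import scala.collection.mutable

object PackWindows {
  /** Greedily pack `items`, each costing `itemBytes`, into groups of at most `budgetBytes`. */
  def pack[A](items: Seq[A], itemBytes: Long, budgetBytes: Long): Seq[Seq[A]] = {
    val groups = mutable.ArrayBuffer.empty[Seq[A]]
    val current = mutable.ArrayBuffer.empty[A]
    var currentBytes = 0L
    items.foreach { item =>
      if (currentBytes + itemBytes <= budgetBytes) {
        // fits in the current group
        current.append(item)
        currentBytes += itemBytes
      } else if (itemBytes < budgetBytes) {
        // does not fit here, but fits in a fresh group
        groups.append(current.toList)
        current.clear()
        current.append(item)
        currentBytes = itemBytes
      } else {
        // larger than the budget: give it a group of its own
        groups.append(Seq(item))
      }
    }
    groups.append(current.toList)   // flush the last group, mirroring the code above
    groups.toList
  }
}

For example, PackWindows.pack(1 to 10, itemBytes = 3, budgetBytes = 7) yields groups of two: List(List(1, 2), List(3, 4), List(5, 6), List(7, 8), List(9, 10)).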

   /**
     * Calculates a split of segments that minimizes segment reads.
     *
     * Returns an RDD of pairs (URI, Array[GridBounds]),
     * where each GridBounds gives the grid bounds of a particular segment;
     * each segment is assigned to exactly one partition.
     */
-  def segmentsByPartitionBytes(partitionBytes: Long = Long.MaxValue, maxTileSize: Option[Int] = None)
-                              (implicit sc: SparkContext): RDD[(String, Array[GridBounds])] = {
+  def segmentsByPartitionBytes(
+    partitionBytes: Long = Long.MaxValue,
+    maxTileSize: Option[Int] = None
+  )(implicit sc: SparkContext): RDD[(String, Array[GridBounds])] = {
     geoTiffInfoRdd.flatMap { uri =>
       val bufferKey = uri

26 changes: 11 additions & 15 deletions spark/src/main/scala/geotrellis/spark/io/RasterReader.scala
@@ -72,31 +72,27 @@ object RasterReader {
   }

   def listWindows(
-    cols: Int, rows: Int, maxTileSize: Option[Int],
+    cols: Int, rows: Int, maxSize: Int,
     segCols: Int, segRows: Int
-  ): Array[GridBounds] = {
+  ): (Int, Int, Array[GridBounds]) = {
+    val colSize: Int = if (maxSize >= segCols) segCols; else best(maxSize, segCols)
+    val rowSize: Int = if (maxSize >= segRows) segRows; else best(maxSize, segRows)
+    val windows = listWindows(cols, rows, colSize, rowSize)

-    maxTileSize match {
-      case Some(maxSize) =>
-        val maxColSize: Int = if (maxSize >= segCols) segCols; else best(maxSize, segCols)
-        val maxRowSize: Int = if (maxSize >= segRows) segRows; else best(maxSize, segRows)
-        listWindows(cols, rows, maxColSize, maxRowSize)
-      case None =>
-        listWindows(cols, rows, None)
-    }
+    (colSize, rowSize, windows)
   }

   /** List all pixel windows that cover a grid of given size */
-  def listWindows(cols: Int, rows: Int, maxColSize: Int, maxRowSize: Int): Array[GridBounds] = {
+  def listWindows(cols: Int, rows: Int, colSize: Int, rowSize: Int): Array[GridBounds] = {
     val result = scala.collection.mutable.ArrayBuffer[GridBounds]()
-    cfor(0)(_ < cols, _ + maxColSize) { col =>
-      cfor(0)(_ < rows, _ + maxRowSize) { row =>
+    cfor(0)(_ < cols, _ + colSize) { col =>
+      cfor(0)(_ < rows, _ + rowSize) { row =>
         result +=
           GridBounds(
             col,
             row,
-            math.min(col + maxColSize - 1, cols - 1),
-            math.min(row + maxRowSize - 1, rows - 1)
+            math.min(col + colSize - 1, cols - 1),
+            math.min(row + rowSize - 1, rows - 1)
           )
       }
     }
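
A small usage sketch of the four-argument listWindows above (assuming it is callable on the RasterReader object): the grid is traversed column by column, and windows on the right and bottom edges are clipped to the raster extent.

import geotrellis.spark.io.RasterReader

object ListWindowsExample {
  def main(args: Array[String]): Unit = {
    val windows = RasterReader.listWindows(cols = 100, rows = 80, colSize = 40, rowSize = 50)
    // GridBounds(0,0,39,49),  GridBounds(0,50,39,79),
    // GridBounds(40,0,79,49), GridBounds(40,50,79,79),
    // GridBounds(80,0,99,49), GridBounds(80,50,99,79)
    windows.foreach(println)
  }
}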
@@ -147,7 +147,8 @@ object HadoopGeoTiffRDD {
           val layout = GeoTiffReader.readGeoTiffInfo(rangeReader, false, true).segmentLayout.tileLayout

           RasterReader
-            .listWindows(cols, rows, options.maxTileSize, layout.tileCols, layout.tileRows)
+            .listWindows(cols, rows, options.maxTileSize.getOrElse(1<<10), layout.tileCols, layout.tileRows)
+            ._3
             .map((objectRequest, _))
         }
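
Both GeoTiffRDD call sites above use the same fallback: when options.maxTileSize is unset, windows are capped at 1 << 10 = 1024 pixels per side, and only the third element of the (colSize, rowSize, windows) tuple returned by listWindows is consumed. A trivial self-contained check of that default (names here are illustrative, not part of the commit):

object DefaultWindowSize {
  def effectiveMaxSize(maxTileSize: Option[Int]): Int = maxTileSize.getOrElse(1 << 10)

  def main(args: Array[String]): Unit = {
    println(effectiveMaxSize(None))      // 1024
    println(effectiveMaxSize(Some(256))) // 256
  }
}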


