Skip to content

Commit

Permalink
Merge pull request #362 from s22s/fix/360
Browse files Browse the repository at this point in the history
Fixed handling of aggregate extent and image size on geotiff write.
  • Loading branch information
metasim committed Sep 23, 2019
2 parents 8bac84a + 1b7d35f commit e313542
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 52 deletions.
Expand Up @@ -48,7 +48,7 @@ class RasterRefIT extends TestEnvironment {
rf_crs($"red"), rf_extent($"red"), rf_tile($"red"), rf_tile($"green"), rf_tile($"blue"))
.toDF

val raster = TileRasterizerAggregate(df, redScene.crs, None, None)
val raster = TileRasterizerAggregate.collect(df, redScene.crs, None, None)

forEvery(raster.tile.statisticsDouble) { stats =>
stats should be ('defined)
Expand Down
Expand Up @@ -77,6 +77,10 @@ class ProjectedLayerMetadataAggregate(destCRS: CRS, destDims: TileDimensions) ex
import org.locationtech.rasterframes.encoders.CatalystSerializer._
val buf = buffer.to[BufferRecord]

if (buf.isEmpty) {
throw new IllegalArgumentException("Can not collect metadata from empty data frame.")
}

val re = RasterExtent(buf.extent, buf.cellSize)
val layout = LayoutDefinition(re, destDims.cols, destDims.rows)

Expand Down Expand Up @@ -152,6 +156,8 @@ object ProjectedLayerMetadataAggregate {
buffer(i) = encoded(i)
}
}

// True when the aggregation buffer has never been updated (no input rows seen):
// all accumulated fields are still in their initial null state.
def isEmpty: Boolean = extent == null || cellType == null || cellSize == null
}

private[expressions]
Expand Down
Expand Up @@ -119,7 +119,7 @@ object TileRasterizerAggregate {
new TileRasterizerAggregate(prd)(crsCol, extentCol, tileCol).as(nodeName).as[Raster[Tile]]
}

def apply(df: DataFrame, destCRS: CRS, destExtent: Option[Extent], rasterDims: Option[TileDimensions]): ProjectedRaster[MultibandTile] = {
def collect(df: DataFrame, destCRS: CRS, destExtent: Option[Extent], rasterDims: Option[TileDimensions]): ProjectedRaster[MultibandTile] = {
val tileCols = WithDataFrameMethods(df).tileColumns
require(tileCols.nonEmpty, "need at least one tile column")
// Select the anchoring Tile, Extent and CRS columns
Expand Down
Expand Up @@ -34,6 +34,7 @@ import org.apache.spark.sql.functions.{asc, udf => sparkUdf}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import org.locationtech.geomesa.curve.Z2SFC
import org.locationtech.rasterframes.StandardColumns
import org.locationtech.rasterframes.encoders.serialized_literal

/**
* RasterFrameLayer extension methods associated with adding spatially descriptive columns.
Expand Down Expand Up @@ -71,6 +72,15 @@ trait RFSpatialColumnMethods extends MethodExtensions[RasterFrameLayer] with Sta
val key2Extent = sparkUdf(keyCol2Extent)
self.withColumn(colName, key2Extent(self.spatialKeyColumn)).certify
}
/**
 * Append a column containing the CRS of the layer.
 *
 * The layer carries a single CRS in its metadata; it is embedded in each row
 * as a serialized literal so row-wise operations (e.g. rasterization) can read it.
 *
 * @param colName name of column to append. Defaults to "crs"
 * @return updated RasterFrameLayer
 */
def withCRS(colName: String = CRS_COLUMN.columnName): RasterFrameLayer = {
  val crsLiteral = serialized_literal(self.crs)
  self.withColumn(colName, crsLiteral).certify
}

/**
* Append a column containing the bounds of the row's spatial key.
Expand Down
Expand Up @@ -67,27 +67,21 @@ class GeoTiffDataSource

require(tileCols.nonEmpty, "Could not find any tile columns.")

val raster = if (df.isAlreadyLayer) {
val layer = df.certify
val tlm = layer.tileLayerMetadata.merge

// If no desired image size is given, write at full size.
val TileDimensions(cols, rows) = parameters.rasterDimensions
.getOrElse {
val actualSize = tlm.layout.toRasterExtent().gridBoundsFor(tlm.extent)
TileDimensions(actualSize.width, actualSize.height)
}

// Should we really play traffic cop here?
if (cols.toDouble * rows * 64.0 > Runtime.getRuntime.totalMemory() * 0.5)
logger.warn(
s"You've asked for the construction of a very large image ($cols x $rows), destined for ${path}. Out of memory error likely.")

layer.toMultibandRaster(tileCols, cols.toInt, rows.toInt)
} else {
require(parameters.crs.nonEmpty, "A destination CRS must be provided")
TileRasterizerAggregate(df, parameters.crs.get, None, parameters.rasterDimensions)
}
val destCRS = parameters.crs.orElse(df.asLayerSafely.map(_.crs)).getOrElse(
throw new IllegalArgumentException("A destination CRS must be provided")
)

val input = df.asLayerSafely.map(layer =>
(layer.crsColumns.isEmpty, layer.extentColumns.isEmpty) match {
case (true, true) => layer.withExtent().withCRS()
case (true, false) => layer.withCRS()
case (false, true) => layer.withExtent()
case _ => layer
}).getOrElse(df)

val raster = TileRasterizerAggregate.collect(input, destCRS, None, parameters.rasterDimensions)

val tags = Tags(
RFBuildInfo.toMap.filter(_._1.toLowerCase() == "version").mapValues(_.toString),
Expand Down
Expand Up @@ -20,14 +20,16 @@
*/
package org.locationtech.rasterframes.datasource.geotiff

import java.nio.file.Paths
import java.nio.file.{Path, Paths}

import geotrellis.proj4._
import geotrellis.raster.CellType
import geotrellis.raster.io.geotiff.{MultibandGeoTiff, SinglebandGeoTiff}
import geotrellis.vector.Extent
import org.locationtech.rasterframes._
import org.apache.spark.sql.functions._
import org.locationtech.rasterframes.TestEnvironment
import org.locationtech.rasterframes.datasource.raster._

/**
* @since 1/14/18
Expand Down Expand Up @@ -89,6 +91,15 @@ class GeoTiffDataSourceSpec

describe("GeoTiff writing") {

/** Read `file` back as a single-band GeoTIFF and assert its dimensions and
  * extent (and, when given, cell type) match the expected values. */
def checkTiff(file: Path, cols: Int, rows: Int, extent: Extent, cellType: Option[CellType] = None) = {
  val tiff = SinglebandGeoTiff(file.toString)
  tiff.tile.dimensions should be ((cols, rows))
  tiff.extent should be (extent)
  for (ct <- cellType)
    tiff.cellType should be (ct)
}

it("should write GeoTIFF RF to parquet") {
val rf = spark.read.format("geotiff").load(cogPath.toASCIIString).asLayer
assert(write(rf))
Expand All @@ -104,6 +115,9 @@ class GeoTiffDataSourceSpec
noException shouldBe thrownBy {
rf.write.format("geotiff").save(out.toString)
}
val extent = rf.tileLayerMetadata.merge.extent

checkTiff(out, 1028, 989, extent)
}

it("should write unstructured raster") {
Expand All @@ -116,10 +130,10 @@ class GeoTiffDataSourceSpec

val crs = df.select(rf_crs($"proj_raster")).first()

val out = Paths.get("target", "unstructured.tif").toString
val out = Paths.get("target", "unstructured.tif")

noException shouldBe thrownBy {
df.write.geotiff.withCRS(crs).save(out)
df.write.geotiff.withCRS(crs).save(out.toString)
}

val (inCols, inRows) = {
Expand All @@ -129,11 +143,7 @@ class GeoTiffDataSourceSpec
inCols should be (774)
inRows should be (500) //from gdalinfo

val outputTif = SinglebandGeoTiff(out)
outputTif.imageData.cols should be (inCols)
outputTif.imageData.rows should be (inRows)

// TODO check datatype, extent.
checkTiff(out, inCols, inRows, Extent(431902.5, 4313647.5, 443512.5, 4321147.5))
}

it("should round trip unstructured raster from COG"){
Expand Down Expand Up @@ -163,30 +173,26 @@ class GeoTiffDataSourceSpec

dfExtent shouldBe resourceExtent

val out = Paths.get("target", "unstructured_cog.tif").toString
val out = Paths.get("target", "unstructured_cog.tif")

noException shouldBe thrownBy {
df.write.geotiff.withCRS(crs).save(out)
df.write.geotiff.withCRS(crs).save(out.toString)
}

val (inCols, inRows, inExtent, inCellType) = {
val tif = readSingleband("LC08_B7_Memphis_COG.tiff")
val id = tif.imageData
(id.cols, id.rows, tif.extent, tif.cellType)
}
inCols should be (963)
inRows should be (754) //from gdalinfo
inCols should be (resourceCols)
inRows should be (resourceRows) //from gdalinfo
inExtent should be (resourceExtent)

val outputTif = SinglebandGeoTiff(out)
outputTif.imageData.cols should be (inCols)
outputTif.imageData.rows should be (inRows)
outputTif.extent should be (resourceExtent)
outputTif.cellType should be (inCellType)
checkTiff(out, inCols, inRows, resourceExtent, Some(inCellType))
}

it("should write GeoTIFF without layer") {
import org.locationtech.rasterframes.datasource.raster._

val pr = col("proj_raster_b0")
val rf = spark.read.raster.withBandIndexes(0, 1, 2).load(rgbCogSamplePath.toASCIIString)

Expand Down Expand Up @@ -217,6 +223,42 @@ class GeoTiffDataSourceSpec
.save(out.toString)
}
}

checkTiff(out, 128, 128,
Extent(-76.52586750038186, 36.85907177863949, -76.17461216980891, 37.1303690755922))
}

it("should produce the correct subregion from layer") {
  import spark.implicits._
  // Tile the test COG into a 128x128 layer, keeping per-row extents.
  val layer = SinglebandGeoTiff(TestData.singlebandCogPath.getPath)
    .projectedRaster.toLayer(128, 128).withExtent()

  val out = Paths.get("target", "example3-geotiff.tif")
  logger.info(s"Writing to $out")

  // Keep only the tile at the layer origin and write it as a GeoTIFF.
  val subset = layer.filter($"spatial_key.col" === 0 && $"spatial_key.row" === 0)
  val expectedExtent = subset.select($"extent".as[Extent]).first()
  subset.write.geotiff.save(out.toString)

  // The written image should cover exactly the selected tile's extent.
  checkTiff(out, 128, 128, expectedExtent)
}

it("should produce the correct subregion without layer") {
  import spark.implicits._

  // Read the test COG as an unstructured raster, tiled 128x128.
  val raster = spark.read.raster
    .withTileDimensions(128, 128)
    .load(TestData.singlebandCogPath.toASCIIString)

  val out = Paths.get("target", "example3-geotiff.tif")
  logger.info(s"Writing to $out")

  // Select the single tile intersecting a known point, then write it out
  // with the CRS taken from that same row (no layer metadata available).
  val subset = raster.filter(st_intersects(st_makePoint(754245, 3893385), rf_geometry($"proj_raster")))
  val expectedExtent = subset.select(rf_extent($"proj_raster")).first()
  val crs = subset.select(rf_crs($"proj_raster")).first()
  subset.write.geotiff.withCRS(crs).save(out.toString)

  // The written image should cover exactly the selected tile's extent.
  checkTiff(out, 128, 128, expectedExtent)
}

def s(band: Int): String =
Expand Down
1 change: 1 addition & 0 deletions docs/src/main/paradox/release-notes.md
Expand Up @@ -4,6 +4,7 @@

### 0.8.2

* Fixed handling of aggregate extent and image size on GeoTIFF writing. ([#362](https://github.com/locationtech/rasterframes/issues/362))
* Fixed issue with `RasterSourceDataSource` swallowing exceptions. ([#267](https://github.com/locationtech/rasterframes/issues/267))
* Fixed SparkML memory pressure issue caused by unnecessary reevaluation, overallocation, and primitive boxing. ([#343](https://github.com/locationtech/rasterframes/issues/343))
* Fixed Parquet serialization issue with `RasterRef`s ([#338](https://github.com/locationtech/rasterframes/issues/338))
Expand Down
Expand Up @@ -108,12 +108,13 @@ class L8CatalogRelationTest extends TestEnvironment {
stats.mean shouldBe > (10000.0)
}

ignore("should construct an RGB composite") {
val aoi = Extent(31.115, 29.963, 31.148, 29.99)
it("should construct an RGB composite") {
val aoiLL = Extent(31.115, 29.963, 31.148, 29.99)

val scene = catalog
.where(
to_date($"acquisition_date") === to_date(lit("2019-07-03")) &&
st_intersects(st_geometry($"bounds_wgs84"), geomLit(aoi.jtsGeom))
st_intersects(st_geometry($"bounds_wgs84"), geomLit(aoiLL.jtsGeom))
)
.orderBy("cloud_cover_pct")
.limit(1)
Expand All @@ -122,19 +123,13 @@ class L8CatalogRelationTest extends TestEnvironment {
.fromCatalog(scene, "B4", "B3", "B2")
.withTileDimensions(256, 256)
.load()
.where(st_contains(rf_geometry($"B4"), st_reproject(geomLit(aoi.jtsGeom), lit("EPSG:4326"), rf_crs($"B4"))))

.limit(1)

noException should be thrownBy {
val raster = TileRasterizerAggregate(df, LatLng, Some(aoi), None)
println(raster)
val raster = TileRasterizerAggregate.collect(df, LatLng, Some(aoiLL), None)
raster.tile.bandCount should be (3)
raster.extent.area > 0
}

// import geotrellis.raster.io.geotiff.{GeoTiffOptions, MultibandGeoTiff, Tiled}
// import geotrellis.raster.io.geotiff.compression.{DeflateCompression}
// import geotrellis.raster.io.geotiff.tags.codes.ColorSpace
// val tiffOptions = GeoTiffOptions(Tiled, DeflateCompression, ColorSpace.RGB)
// MultibandGeoTiff(raster, raster.crs, tiffOptions).write("target/composite.tif")
}
}
}

0 comments on commit e313542

Please sign in to comment.