Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for netcdf for RasterAsGridReader. #556

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ case class MosaicRasterGDAL(
result
}
if (spatialRef == null) {
// Avoids null-CRS rasters
raster.SetSpatialRef(MosaicGDAL.WSG84)
MosaicGDAL.WSG84
} else {
spatialRef
Expand Down Expand Up @@ -480,7 +482,9 @@ case class MosaicRasterGDAL(
tmpPath
}
}
val byteArray = FileUtils.readBytes(readPath)
// For corrupted files, return empty byte array
// We will have the reason for corruption in the last_error field
val byteArray = Try(FileUtils.readBytes(readPath)).getOrElse(Array.empty[Byte])
if (dispose) RasterCleaner.dispose(this)
if (readPath != PathUtils.getCleanPath(parentPath)) {
Files.deleteIfExists(Paths.get(readPath))
Expand Down Expand Up @@ -608,7 +612,7 @@ case class MosaicRasterGDAL(
.delete()

val outputRaster = gdal.Open(resultRasterPath, GF_Write)

for (bandIndex <- 1 to this.numBands) {
val band = this.getBand(bandIndex)
val outputBand = outputRaster.GetRasterBand(bandIndex)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL
import org.gdal.gdal.{WarpOptions, gdal}

import java.nio.file.{Files, Paths}
import scala.util.Try

/** GDALWarp is a wrapper for the GDAL Warp command. */
object GDALWarp {
Expand All @@ -29,7 +30,7 @@ object GDALWarp {
val result = gdal.Warp(outputPath, rasters.map(_.getRaster).toArray, warpOptions)
// Format will always be the same as the first raster
val errorMsg = gdal.GetLastErrorMsg
val size = Files.size(Paths.get(outputPath))
val size = Try(Files.size(Paths.get(outputPath))).getOrElse(-1L)
val createInfo = Map(
"path" -> outputPath,
"parentPath" -> rasters.head.getParentPath,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,15 @@ object RasterProject {
// Note that Null is the right value here
val authName = destCRS.GetAuthorityName(null)
val authCode = destCRS.GetAuthorityCode(null)


val srcAuthName = raster.getSpatialReference.GetAuthorityName(null)
val srcAuthCode = raster.getSpatialReference.GetAuthorityCode(null)

// There is no need to translate if the CRSs match
if (authName == srcAuthName && authCode == srcAuthCode) {
return raster
}

val result = GDALWarp.executeWarp(
resultFileName,
Seq(raster),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,20 @@ object SeparateBands {
def separate(
tile: => MosaicRasterTile
): Seq[MosaicRasterTile] = {
val raster = tile.getRaster
val raster = if (tile.getRaster.getWriteOptions.format == "Zarr") {
zarrToNetCDF(tile).getRaster
} else {
tile.getRaster
}
val tiles = for (i <- 0 until raster.numBands) yield {
val fileExtension = raster.getRasterFileExtension
val rasterPath = PathUtils.createTmpFilePath(fileExtension)
val shortDriver = raster.getDriversShortName
val outOptions = raster.getWriteOptions

val result = GDALTranslate.executeTranslate(
rasterPath,
raster,
command = s"gdal_translate -of $shortDriver -b ${i + 1}",
command = s"gdal_translate -b ${i + 1}",
writeOptions = outOptions
)

Expand All @@ -49,8 +52,36 @@ object SeparateBands {

val (_, valid) = tiles.partition(_._1)

if (tile.getRaster.getWriteOptions.format == "Zarr") dispose(raster)

for (elem <- valid) { elem._2.raster.SetSpatialRef(raster.getSpatialReference) }
valid.map(t => new MosaicRasterTile(null, t._2))

}

/**
  * Runs gdal_translate on the raster of the provided tile, using the
  * raster's write options with the format overridden to "NetCDF", and wraps
  * the translated raster in a new tile that keeps the input tile's index.
  *
  * @param tile
  *   The tile whose raster is translated. It is a by-name parameter, so it
  *   is bound to a local value once instead of being re-evaluated on every
  *   reference.
  * @return
  *   A new MosaicRasterTile carrying the input tile's index and the
  *   translated raster.
  */
def zarrToNetCDF(
    tile: => MosaicRasterTile
): MosaicRasterTile = {
    // A by-name argument is re-evaluated each time it is referenced; evaluate it exactly once.
    val sourceTile = tile
    val raster = sourceTile.getRaster
    val fileExtension = "nc"
    val rasterPath = PathUtils.createTmpFilePath(fileExtension)
    val outOptions = raster.getWriteOptions.copy(
      format = "NetCDF"
    )

    val result = GDALTranslate.executeTranslate(
      rasterPath,
      raster,
      command = "gdal_translate",
      writeOptions = outOptions
    )
    // Copy the source raster's spatial reference onto the translated dataset, then flush it.
    result.raster.SetSpatialRef(raster.getSpatialReference)
    result.raster.FlushCache()

    // NOTE(review): an empty result is disposed here but is still wrapped and
    // returned below - confirm that callers expect a disposed raster in that case.
    if (result.isEmpty) dispose(result)

    new MosaicRasterTile(sourceTile.index, result)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead
private val mc = MosaicContext.context()
import mc.functions._

def getNPartitions(config: Map[String, String]): Int = {
private def getNPartitions(config: Map[String, String]): Int = {
val shufflePartitions = sparkSession.conf.get("spark.sql.shuffle.partitions")
val nPartitions = config.getOrElse("nPartitions", shufflePartitions).toInt
nPartitions
Expand Down Expand Up @@ -75,7 +75,11 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead
.agg(rst_combineavg_agg(col("tile")).alias("tile"))
.withColumn(
"grid_measures",
rasterToGridCombiner(col("tile"))
// When tessellation fails, last_error will be populated.
// We should surface the error, but we can't aggregate,
// so we force a null value.
when(col("tile.metadata.last_error").isNotNull, lit(null))
.otherwise(rasterToGridCombiner(col("tile")))
)
.select(
"grid_measures",
Expand Down Expand Up @@ -148,13 +152,17 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead
val readSubdataset = config("readSubdataset").toBoolean
val subdatasetName = config("subdatasetName")

if (readSubdataset) {
pathsDf
.withColumn("subdatasets", rst_subdatasets(col("tile")))
.withColumn("tile", rst_getsubdataset(col("tile"), lit(subdatasetName)))
} else {
pathsDf.select(col("tile"))
}
val resolved =
if (readSubdataset) {
pathsDf
.withColumn("subdatasets", rst_subdatasets(col("tile")))
.withColumn("tile", rst_getsubdataset(col("tile"), lit(subdatasetName)))
} else {
pathsDf.select(col("tile"))
}
resolved
.withColumn("tile", rst_separatebands(col("tile")))
.where(rst_pixelcount(col("tile")).getItem(0) > 0)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ trait SharedSparkSessionGDAL extends SharedSparkSession {

override def createSparkSession: TestSparkSession = {
val conf = sparkConf
conf.set(MOSAIC_RASTER_CHECKPOINT, FileUtils.createMosaicTempDir(prefix = "/mnt/"))
conf.set(MOSAIC_RASTER_CHECKPOINT, FileUtils.createMosaicTempDir(prefix = "/tmp/tmp"))
SparkSession.cleanupAnyExistingSession()
val session = new MosaicTestSparkSession(conf)
session.sparkContext.setLogLevel("FATAL")
Expand Down
Loading