Skip to content

Commit

Permalink
Merge pull request #2317 from pomadchin/fix/s3if-conditions
Browse files Browse the repository at this point in the history
Fix S3GeoTiffRDD behavior with some options.
  • Loading branch information
echeipesh committed Aug 7, 2017
2 parents 4a83f22 + 035bd03 commit 8e9de3d
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 10 deletions.
Expand Up @@ -195,5 +195,17 @@ class S3GeoTiffRDDSpec

MockS3Client.lastListObjectsRequest.get.getDelimiter should be ("/")
}

it("should read with num partitions and window size options set") {
  // Stage the fixture GeoTiff into the mock S3 bucket under a known key.
  val objectKey = "geoTiff/all-ones.tif"
  val localPath = Paths.get("spark/src/test/resources/all-ones.tif")
  mockClient.putObject(bucket, objectKey, Files.readAllBytes(localPath))

  // Setting both maxTileSize and numPartitions used to misconfigure the
  // input format; reading with both must still produce records.
  val options =
    S3GeoTiffRDD.Options(maxTileSize = 512, numPartitions = 32, getS3Client = () => new MockS3Client)
  val rdd = S3GeoTiffRDD.spatial(bucket, objectKey, options)

  rdd.count.toInt should be > 0
}
}
}
17 changes: 14 additions & 3 deletions s3/src/main/scala/geotrellis/spark/io/s3/S3GeoTiffRDD.scala
Expand Up @@ -89,20 +89,31 @@ object S3GeoTiffRDD extends LazyLogging {

/**
 * Create Configuration for [[S3InputFormat]] based on parameters and options.
 *
 * Important: `partitionBytes` is never passed into the hadoop configuration
 * when the `numPartitions` option is set — `numPartitions` wins.
 *
 * @param bucket Name of the bucket on S3 where the files are kept.
 * @param prefix Prefix of all of the keys on S3 that are to be read in.
 * @param options An instance of [[Options]] that contains any user defined or default settings.
 */
private def configuration(bucket: String, prefix: String, options: S3GeoTiffRDD.Options)(implicit sc: SparkContext): Configuration = {
  if (options.numPartitions.isDefined && options.partitionBytes.isDefined)
    logger.warn("Both numPartitions and partitionBytes options are set. " +
      "Only numPartitions would be passed into hadoop configuration.")

  val conf = sc.hadoopConfiguration
  S3InputFormat.setBucket(conf, bucket)
  S3InputFormat.setPrefix(conf, prefix)
  S3InputFormat.setExtensions(conf, options.tiffExtensions)
  S3InputFormat.setCreateS3Client(conf, options.getS3Client)

  // numPartitions and partitionBytes are mutually exclusive; explicitly clear
  // whichever setting is unused so a stale value left in the (shared)
  // SparkContext configuration cannot leak into this job.
  (options.numPartitions, options.partitionBytes) match {
    case (Some(n), _) =>
      S3InputFormat.setPartitionCount(conf, n)
      S3InputFormat.removePartitionBytes(conf)
    case (None, Some(b)) =>
      S3InputFormat.removePartitionCount(conf)
      S3InputFormat.setPartitionBytes(conf, b)
    case (None, None) =>
      S3InputFormat.removePartitionCount(conf)
      S3InputFormat.removePartitionBytes(conf)
  }

  options.delimiter.fold(S3InputFormat.removeDelimiter(conf))(S3InputFormat.setDelimiter(conf, _))
  conf
}

Expand Down
47 changes: 40 additions & 7 deletions s3/src/main/scala/geotrellis/spark/io/s3/S3InputFormat.scala
Expand Up @@ -218,6 +218,9 @@ object S3InputFormat {
case None => S3Client.DEFAULT
}

/** Removes the S3 client factory entry from the Hadoop configuration. */
def removeCreateS3Client(conf: Configuration): Unit =
  conf.unset(CREATE_S3CLIENT)

/** Set S3N url to use, may include AWS Id and Key.
  * Convenience overload that writes into the job's underlying configuration.
  */
def setUrl(job: Job, url: String): Unit =
  setUrl(job.getConfiguration, url)
Expand All @@ -233,18 +236,31 @@ object S3InputFormat {
conf.set(PREFIX, prefix)
}

/** Clears every setting written by `setUrl`: AWS credentials, bucket and prefix. */
def removeUrl(conf: Configuration): Unit = {
  // Credentials first, then the bucket/prefix pair parsed out of the url.
  List(AWS_ID, AWS_KEY).foreach(conf.unset)
  removeBucket(conf)
  removePrefix(conf)
}

/** Set the name of the S3 bucket to read from (Job overload). */
def setBucket(job: Job, bucket: String): Unit =
  setBucket(job.getConfiguration, bucket)

/** Set the name of the S3 bucket to read from. */
def setBucket(conf: Configuration, bucket: String): Unit =
  conf.set(BUCKET, bucket)

/** Removes the bucket name from the configuration. */
def removeBucket(conf: Configuration): Unit =
  conf.unset(BUCKET)

/** Set the key prefix used to select S3 objects to read (Job overload). */
def setPrefix(job: Job, prefix: String): Unit =
  setPrefix(job.getConfiguration, prefix)

/** Set the key prefix used to select S3 objects to read. */
def setPrefix(conf: Configuration, prefix: String): Unit =
  conf.set(PREFIX, prefix)

/** Removes the key prefix from the configuration. */
def removePrefix(conf: Configuration): Unit =
  conf.unset(PREFIX)

/** Set desired partition count (Job overload). */
def setPartitionCount(job: Job, limit: Int): Unit =
  setPartitionCount(job.getConfiguration, limit)
Expand All @@ -253,12 +269,19 @@ object S3InputFormat {
/** Set desired partition count; stored as a string under PARTITION_COUNT. */
def setPartitionCount(conf: Configuration, limit: Int): Unit =
  conf.set(PARTITION_COUNT, limit.toString)

/** Removes partition count from the configuration. */
def removePartitionCount(conf: Configuration): Unit =
  conf.unset(PARTITION_COUNT)

/** Set the AWS region to use for S3 requests (Job overload). */
def setRegion(job: Job, region: String): Unit =
  setRegion(job.getConfiguration, region)

/** Set the AWS region to use for S3 requests. */
def setRegion(conf: Configuration, region: String): Unit =
  conf.set(REGION, region)

/** Removes the AWS region from the configuration. */
def removeRegion(conf: Configuration): Unit =
  conf.unset(REGION)

/** Force anonymous access, bypass all key discovery (Job overload). */
def setAnonymous(job: Job): Unit =
  setAnonymous(job.getConfiguration)
Expand All @@ -267,6 +290,9 @@ object S3InputFormat {
/** Force anonymous access, bypassing all credential discovery. */
def setAnonymous(conf: Configuration): Unit =
  conf.set(ANONYMOUS, "true")

/** Removes the anonymous-access flag from the configuration. */
def removeAnonymous(conf: Configuration): Unit =
  conf.unset(ANONYMOUS)

/** Set desired partition size in bytes, at least one item per partition will be assigned (Job overload). */
def setPartitionBytes(job: Job, bytes: Long): Unit =
  setPartitionBytes(job.getConfiguration, bytes)
Expand All @@ -275,16 +301,26 @@ object S3InputFormat {
/** Set desired partition size in bytes; stored as a string under PARTITION_BYTES. */
def setPartitionBytes(conf: Configuration, bytes: Long): Unit =
  conf.set(PARTITION_BYTES, bytes.toString)

/** Removes partition size in bytes from the configuration. */
def removePartitionBytes(conf: Configuration): Unit =
  conf.unset(PARTITION_BYTES)

/** Set the read chunk size (Job overload). */
def setChunkSize(job: Job, chunkSize: Int): Unit =
  setChunkSize(job.getConfiguration, chunkSize)

/** Set the read chunk size; stored as a string under CHUNK_SIZE. */
def setChunkSize(conf: Configuration, chunkSize: Int): Unit =
  conf.set(CHUNK_SIZE, chunkSize.toString)

/** Removes the chunk size from the configuration. */
def removeChunkSize(conf: Configuration): Unit =
  conf.unset(CHUNK_SIZE)

/** Set valid key extensions filter; extensions are stored comma-joined under EXTENSIONS. */
def setExtensions(conf: Configuration, extensions: Seq[String]): Unit =
  conf.set(EXTENSIONS, extensions.mkString(","))

/** Removes the key extensions filter from the configuration. */
def removeExtensions(conf: Configuration): Unit =
  conf.unset(EXTENSIONS)

/** Set delimiter for S3 object listing requests (Job overload). */
def setDelimiter(job: Job, delimiter: String): Unit =
  setDelimiter(job.getConfiguration, delimiter)
Expand All @@ -296,12 +332,9 @@ object S3InputFormat {
/** Get the S3 listing delimiter from the job context, if one has been set. */
def getDelimiter(job: JobContext): Option[String] =
  getDelimiter(job.getConfiguration)

/** Get the S3 listing delimiter, if one has been set.
  * `Option(...)` maps Hadoop's null-for-missing convention to `None`.
  */
def getDelimiter(conf: Configuration): Option[String] =
  Option(conf.get(DELIMITER))

/** Removes the S3 listing delimiter from the configuration. */
def removeDelimiter(conf: Configuration): Unit =
  conf.unset(DELIMITER)
}

0 comments on commit 8e9de3d

Please sign in to comment.