Merge pull request #656 from chasegarner/add-aws-credentials-options-to-GeoTiffRDD

Add ability to provide S3 credentials directly to geotiff.get()
Jacob Bouffard committed Jun 14, 2018
2 parents 95a9d7c + 341b9bc commit 9191461
Showing 7 changed files with 287 additions and 26 deletions.
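For context, a minimal sketch of the call this change enables, with placeholder bucket, path, and credential values (the API names come from the diff below):

from geopyspark.geotrellis.constants import LayerType
from geopyspark.geotrellis.geotiff import get
from geopyspark.geotrellis.s3 import Credentials

# Placeholder URI and keys; any s3://, s3a://, or s3n:// scheme is accepted.
# The credentials are applied to the SparkConf only for the duration of the
# read, after which the previous values are restored.
layer = get(
    LayerType.SPATIAL,
    's3a://my-bucket/imagery/scene.tif',
    s3_credentials=Credentials(access_key='AKIA...', secret_key='...')
)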
@@ -8,6 +8,8 @@ object Constants {
final val TEMPORALPROJECTEDEXTENT = "TemporalProjectedExtent"

final val S3 = "s3"
final val S3A = "s3a"
final val S3N = "s3n"

final val FLOAT = "float"
final val ZOOM = "zoom"
@@ -64,4 +66,7 @@ object Constants {
final val HASH = "HashPartitioner"
final val SPATIAL = "SpatialPartitioner"
final val SPACETIME = "SpaceTimePartitioner"

final val DEFAULT = "default"
final val MOCK = "mock"
}
@@ -13,6 +13,8 @@ import scala.collection.JavaConversions._
import java.net.URI
import scala.reflect._

import com.amazonaws.auth.BasicAWSCredentials

import org.apache.spark._
import org.apache.hadoop.fs.Path

@@ -50,6 +52,8 @@ object GeoTiffRDD {
def default = S3GeoTiffRDD.Options.DEFAULT

def setValues(
conf: SparkConf,
scheme: String,
intMap: Map[String, Int],
stringMap: Map[String, String],
partitionBytes: Option[Long]
@@ -60,17 +64,19 @@
else
None

val getS3Client: () => S3Client =
stringMap.get("s3_client") match {
case Some(client) =>
if (client == "default")
default.getS3Client
else if (client == "mock")
() => new MockS3Client()
else
throw new Error(s"Could not find the given S3Client, $client")
case None => default.getS3Client
val getS3Client: () => S3Client = {
val accessKey = conf.get(s"spark.hadoop.fs.${scheme}.access.key", "")
val secretKey = conf.get(s"spark.hadoop.fs.${scheme}.secret.key", "")

stringMap.getOrElse("s3_client", DEFAULT) match {
case DEFAULT if accessKey.isEmpty => default.getS3Client
case DEFAULT =>
() => AmazonS3Client(new BasicAWSCredentials(accessKey, secretKey), S3Client.defaultConfiguration)
case MOCK =>
() => new MockS3Client()
case client: String => throw new Error(s"Could not find the given S3Client, $client")
}
}

S3GeoTiffRDD.Options(
crs = crs,
@@ -96,16 +102,17 @@
val uris = paths.map{ path => new URI(path) }
val (stringMap, intMap) = GeoTrellisUtils.convertToScalaMap(options)
val bytes = Some(partitionBytes.toLong)
lazy val conf = sc.getConf

uris
.map { uri =>
uri.getScheme match {
case S3 =>
case (S3 | S3A | S3N) =>
getS3GeoTiffRDD(
sc,
keyType,
uri,
S3GeoTiffRDDOptions.setValues(intMap, stringMap, bytes)
S3GeoTiffRDDOptions.setValues(conf, uri.getScheme, intMap, stringMap, bytes)
)
case _ =>
getHadoopGeoTiffRDD(
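From the Python side, these are the per-scheme SparkConf properties the new setValues consults; a minimal sketch, assuming an active SparkContext (the key values are placeholders):

from geopyspark import get_spark_context

# The scheme segment of the property name comes from the GeoTiff URI
# (s3, s3a, or s3n).
conf = get_spark_context()._conf
conf.set('spark.hadoop.fs.s3a.access.key', 'AKIA...')
conf.set('spark.hadoop.fs.s3a.secret.key', 'abc123')

# With a non-empty access key, setValues wraps BasicAWSCredentials in an
# AmazonS3Client; with an empty one it falls back to the default client.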
3 changes: 3 additions & 0 deletions geopyspark/geotrellis/__init__.py
@@ -799,6 +799,7 @@ def __str__(self):
from . import histogram
from . import layer
from . import neighborhood
from . import s3
from . import tms

from .catalog import *
@@ -812,6 +813,7 @@ def __str__(self):
from .layer import *
from .neighborhood import *
from .rasterize import *
from .s3 import *
from .tms import *
from .union import *
from .combine_bands import *
@@ -828,6 +830,7 @@ def __str__(self):
__all__ += layer.__all__
__all__ += neighborhood.__all__
__all__ += ['rasterize', 'rasterize_features']
__all__ += s3.__all__
__all__ += tms.__all__
__all__ += ['union']
__all__ += ['combine_bands']
51 changes: 37 additions & 14 deletions geopyspark/geotrellis/geotiff.py
@@ -9,6 +9,7 @@
DEFAULT_GEOTIFF_TIME_FORMAT,
DEFAULT_S3_CLIENT)
from geopyspark.geotrellis.layer import RasterLayer
from geopyspark.geotrellis.s3 import Credentials, is_s3_uri, set_s3_credentials


__all__ = ['get']
@@ -24,7 +25,8 @@ def get(layer_type,
time_tag=DEFAULT_GEOTIFF_TIME_TAG,
time_format=DEFAULT_GEOTIFF_TIME_FORMAT,
delimiter=None,
s3_client=DEFAULT_S3_CLIENT):
s3_client=DEFAULT_S3_CLIENT,
s3_credentials=None):
"""Creates a ``RasterLayer`` from GeoTiffs that are located on the local file system, ``HDFS``,
or ``S3``.
@@ -68,17 +70,22 @@
Note:
This parameter will only be used when reading from S3.
s3_client (str, optional): Which ``S3Cleint`` to use when reading
s3_client (str, optional): Which ``S3Client`` to use when reading
GeoTiffs from S3. There are currently two options: ``default`` and
``mock``. Defaults to :const:`~geopyspark.geotrellis.constants.DEFAULT_S3_CLIENT`.
Note:
``mock`` should only be used in unit tests and debugging.
s3_credentials (:class:`~geopyspark.geotrellis.s3.Credentials`, optional): Alternative Amazon S3
credentials to use when accessing the tile(s).
Returns:
:class:`~geopyspark.geotrellis.layer.RasterLayer`
"""
Raises:
RuntimeError: ``s3_credentials`` were supplied, but the given ``uri`` was not S3-based.
"""
inputs = {k:v for k, v in locals().items() if v is not None}

pysc = get_spark_context()
@@ -87,17 +94,33 @@
key = LayerType(inputs.pop('layer_type'))._key_name(False)
partition_bytes = str(inputs.pop('partition_bytes'))

if isinstance(uri, list):
srdd = geotiff_rdd.get(pysc._jsc.sc(),
key,
inputs.pop('uri'),
inputs,
partition_bytes)
uri = inputs.pop('uri')
uris = (uri if isinstance(uri, list) else [uri])

try:
s3_credentials = inputs.pop('s3_credentials')
except KeyError:
s3_credentials = None
else:
srdd = geotiff_rdd.get(pysc._jsc.sc(),
key,
[inputs.pop('uri')],
inputs,
partition_bytes)
_validate_s3_credentials(uri, s3_credentials)

uri_type = uri.split(":")[0]

with set_s3_credentials(s3_credentials, uri_type):
srdd = geotiff_rdd.get(
pysc._jsc.sc(),
key,
uris,
inputs,
partition_bytes
)

return RasterLayer(layer_type, srdd)


def _validate_s3_credentials(uri, credentials):
if credentials and not is_s3_uri(uri):
raise RuntimeError(
'S3 credentials were provided to geotiff.get, but an AWS S3 URI '
'was not specified (URI: {})'.format(uri)
)
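The try/except/else in the new get() is the subtle part: the else branch runs only when the pop() succeeded, so the URI check fires exactly when credentials were actually supplied. A standalone sketch of the idiom, with a placeholder payload:

inputs = {'s3_credentials': ('123', 'abc')}  # placeholder

try:
    s3_credentials = inputs.pop('s3_credentials')
except KeyError:
    s3_credentials = None  # absent: keep whatever the session already has
else:
    # Reached only when the pop succeeded, i.e. credentials were supplied;
    # this is where get() calls _validate_s3_credentials(uri, s3_credentials).
    assert s3_credentials == ('123', 'abc')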
98 changes: 98 additions & 0 deletions geopyspark/geotrellis/s3.py
@@ -0,0 +1,98 @@
"""Utilities relating to Amazon S3"""

from collections import namedtuple
from contextlib import contextmanager

from geopyspark import get_spark_context


__all__ = ['Credentials']


S3A_FS_CONSTANT = 'org.apache.hadoop.fs.s3a.S3AFileSystem'
S3A_IMPL_PATH = 'spark.hadoop.fs.s3a.impl'

_S3_ACCESS_KEY_PATH_TEMPLATE = 'spark.hadoop.fs.{prefix}.access.key'
_S3_SECRET_KEY_PATH_TEMPLATE = 'spark.hadoop.fs.{prefix}.secret.key'
_S3_URI_PREFIXES = frozenset(['s3', 's3a', 's3n'])


Credentials = namedtuple(
'Credentials',
['access_key', 'secret_key']
)
Credentials.__doc__ = (
"""Credentials for Amazon S3 buckets.
Attributes:
access_key (str): The access key for the S3 bucket.
secret_key (str): The secret key for the S3 bucket.
"""
)

@contextmanager
def set_s3_credentials(credentials, uri_type):
"""Temporarily updates the session's Amazon S3 credentials for the
duration of the context.
Args:
credentials (Credentials): The access and secret keys used to access
Amazon S3 resources.
uri_type (str): The URI type. 's3', 's3a', or 's3n'.
"""
if credentials:
if uri_type not in _S3_URI_PREFIXES:
raise RuntimeError(
'Cannot set S3 credentials for unrecognized URI type '
'{}'.format(uri_type)
)
configuration = get_spark_context()._conf
with _set_s3_credentials(credentials, configuration, uri_type):
yield
else:
yield


@contextmanager
def _set_s3_credentials(credentials, configuration, uri_type):
access_key_path = _S3_ACCESS_KEY_PATH_TEMPLATE.format(prefix=uri_type)
secret_key_path = _S3_SECRET_KEY_PATH_TEMPLATE.format(prefix=uri_type)

old_s3_access_key = configuration.get(access_key_path)
old_s3_secret_key = configuration.get(secret_key_path)

try:
configuration.set(access_key_path, credentials.access_key)
configuration.set(secret_key_path, credentials.secret_key)
if uri_type == 's3a':
with _set_s3a_path(configuration):
yield
else:
yield
finally:
configuration.set(access_key_path, old_s3_access_key)
configuration.set(secret_key_path, old_s3_secret_key)


@contextmanager
def _set_s3a_path(configuration):
old_s3a_impl = configuration.get(S3A_IMPL_PATH)
try:
configuration.set(S3A_IMPL_PATH, S3A_FS_CONSTANT)
yield
finally:
configuration.set(S3A_IMPL_PATH, old_s3a_impl)


def is_s3_uri(uri):
"""Determines if the specified ``uri`` points to Amazon S3.
Args:
uri (str): The Universal Resource Identifier to examine.
Returns:
bool: True if ``uri`` is an Amazon S3 URI.
"""
return any(
uri.startswith(prefix + '://') for prefix in _S3_URI_PREFIXES
)
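The context manager can also be used directly; a minimal sketch, assuming an active SparkContext and placeholder values:

from geopyspark.geotrellis.s3 import Credentials, is_s3_uri, set_s3_credentials

uri = 's3a://my-bucket/scene.tif'  # placeholder
assert is_s3_uri(uri)

uri_type = uri.split(':')[0]  # 's3a', mirroring what geotiff.get derives

# Inside the block the SparkConf carries the supplied keys (and, for s3a,
# the S3AFileSystem impl); the previous values are restored on exit, even
# if the body raises.
with set_s3_credentials(Credentials('AKIA...', '...'), uri_type):
    pass  # S3 reads happen here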
20 changes: 20 additions & 0 deletions geopyspark/tests/geotrellis/geotiff_test.py
@@ -0,0 +1,20 @@
"""Miscellaneous geotiff tests"""

import pytest

from geopyspark.geotrellis.constants import LayerType
from geopyspark.geotrellis.geotiff import get
from geopyspark.geotrellis.s3 import Credentials
from geopyspark.tests.python_test_utils import file_path


def test_s3_uri_type_and_credential_mismatch():
local_path = file_path("all-ones.tif")

with pytest.raises(RuntimeError):
get(
LayerType.SPATIAL,
local_path,
max_tile_size=256,
s3_credentials=Credentials('123', 'abc')
)
