Skip to content

Commit

Permalink
Resolve #1520 and #1521
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean Gillies committed Oct 25, 2018
1 parent 98126f2 commit 8ae79a4
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 49 deletions.
5 changes: 5 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ Changes

Bug fixes:

- Delegate test of the environment for existing session credentials to the
session class to generalize credentialization of GDAL to cloud providers
other than AWS (#1520). The env.hascreds function is no longer used in
Rasterio and has been marked as deprecated.
- Switch to use of botocore Credentials.get_frozen_credentials (#1521).
- Numpy masked arrays with the normal Numpy mask sense (True == invalid) are
now supported as input for feature.shapes(). The mask keyword argument of the
function keeps to the GDAL sense of masks (nonzero == invalid) and the
Expand Down
34 changes: 12 additions & 22 deletions rasterio/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,11 @@
import threading
import warnings

import rasterio
from rasterio._env import (
GDALEnv, del_gdal_config, get_gdal_config, set_gdal_config)
from rasterio._env import GDALEnv, get_gdal_config, set_gdal_config
from rasterio.compat import string_types, getargspec
from rasterio.dtypes import check_dtype
from rasterio.errors import (
EnvError, GDALVersionError, RasterioDeprecationWarning)
from rasterio.path import parse_path, UnparsedPath, ParsedPath
from rasterio.session import Session, AWSSession, DummySession
from rasterio.transform import guard_transform


class ThreadEnv(threading.local):
Expand Down Expand Up @@ -226,16 +221,6 @@ def from_defaults(cls, *args, **kwargs):
options.update(**kwargs)
return Env(*args, **options)

@property
def is_credentialized(self):
"""Test for existence of cloud credentials
Returns
-------
bool
"""
return hascreds()

def credentialize(self):
"""Get credentials and configure GDAL
Expand All @@ -247,7 +232,7 @@ def credentialize(self):
None
"""
if hascreds():
if self.session.hascreds(getenv()):
pass
else:
cred_opts = self.session.get_credential_options()
Expand Down Expand Up @@ -338,6 +323,7 @@ def setenv(**options):


def hascreds():
warnings.warn("Please use Env.session.hascreds() instead", RasterioDeprecationWarning)
return local._env is not None and all(key in local._env.get_config_options() for key in ['AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY'])


Expand Down Expand Up @@ -397,12 +383,16 @@ def wrapper(*args, **kwds):
else:
env_ctor = Env.from_defaults

if hascreds():
session = DummySession()
elif isinstance(args[0], str):
session = Session.from_path(args[0])
if isinstance(args[0], str):
session_cls = Session.cls_from_path(args[0])

if local._env and session_cls.hascreds(getenv()):
session_cls = DummySession

session = session_cls()

else:
session = Session.from_path(None)
session = DummySession()

with env_ctor(session=session):
return f(*args, **kwds)
Expand Down
94 changes: 70 additions & 24 deletions rasterio/session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Abstraction for sessions in various clouds."""


from rasterio.path import parse_path, UnparsedPath, ParsedPath
from rasterio.path import parse_path, UnparsedPath


class Session(object):
Expand All @@ -18,6 +18,10 @@ class Session(object):
"""

@classmethod
def hascreds(cls, config):
return NotImplementedError

def get_credential_options(self):
"""Get credentials as GDAL configuration options
Expand Down Expand Up @@ -50,40 +54,73 @@ def from_foreign_session(session, cls=None):
return cls(session)

@staticmethod
def from_path(path, *args, **kwargs):
"""Create a session object suited to the data at `path`.
def cls_from_path(path):
"""Find the session class suited to the data at `path`.
Parameters
----------
path : str
A dataset path or identifier.
args : sequence
Positional arguments for the foreign session constructor.
kwargs : dict
Keyword arguments for the foreign session constructor.
Returns
-------
Session
class
"""
if not path:
return DummySession()
return DummySession

path = parse_path(path)

if isinstance(path, UnparsedPath) or path.is_local:
return DummySession()
return DummySession

elif path.scheme == "s3" or "amazonaws.com" in path.path:
return AWSSession(*args, **kwargs)
return AWSSession

# This factory can be extended to other cloud providers here.
# elif path.scheme == "cumulonimbus": # for example.
# return CumulonimbusSession(*args, **kwargs)

else:
return DummySession()
return DummySession

@staticmethod
def from_path(path, *args, **kwargs):
"""Create a session object suited to the data at `path`.
Parameters
----------
path : str
A dataset path or identifier.
args : sequence
Positional arguments for the foreign session constructor.
kwargs : dict
Keyword arguments for the foreign session constructor.
Returns
-------
Session
"""
return Session.cls_from_path(path)(*args, **kwargs)
# if not path:
# return DummySession()
#
# path = parse_path(path)
#
# if isinstance(path, UnparsedPath) or path.is_local:
# return DummySession()
#
# elif path.scheme == "s3" or "amazonaws.com" in path.path:
# return AWSSession(*args, **kwargs)
#
# # This factory can be extended to other cloud providers here.
# # elif path.scheme == "cumulonimbus": # for example.
# # return CumulonimbusSession(*args, **kwargs)
#
# else:
# return DummySession()


class DummySession(Session):
Expand All @@ -100,6 +137,10 @@ def __init__(self, *args, **kwargs):
self._session = None
self.credentials = {}

@classmethod
def hascreds(cls, config):
return True

def get_credential_options(self):
"""Get credentials as GDAL configuration options
Expand Down Expand Up @@ -155,24 +196,29 @@ def __init__(

self.requester_pays = requester_pays
self.unsigned = aws_unsigned
self._creds = self._session._session.get_credentials()

@classmethod
def hascreds(cls, config):
return 'AWS_ACCESS_KEY_ID' in config and 'AWS_SECRET_ACCESS_KEY' in config

@property
def credentials(self):
"""The session credentials as a dict"""
creds = {}
if self._creds:
if self._creds.access_key: # pragma: no branch
creds['aws_access_key_id'] = self._creds.access_key
if self._creds.secret_key: # pragma: no branch
creds['aws_secret_access_key'] = self._creds.secret_key
if self._creds.token:
creds['aws_session_token'] = self._creds.token
res = {}
creds = self._session._session.get_credentials()
if creds:
creds_set = creds.get_frozen_credentials()
if creds_set.access_key: # pragma: no branch
res['aws_access_key_id'] = creds_set.access_key
if creds_set.secret_key: # pragma: no branch
res['aws_secret_access_key'] = creds_set.secret_key
if creds_set.token:
res['aws_session_token'] = creds_set.token
if self._session.region_name:
creds['aws_region'] = self._session.region_name
res['aws_region'] = self._session.region_name
if self.requester_pays:
creds['aws_request_payer'] = 'requester'
return creds
res['aws_request_payer'] = 'requester'
return res

def get_credential_options(self):
"""Get credentials as GDAL configuration options
Expand Down
6 changes: 3 additions & 3 deletions tests/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@ def test_aws_session(gdalenv):
aws_access_key_id='id', aws_secret_access_key='key',
aws_session_token='token', region_name='null-island-1')
with rasterio.env.Env(session=aws_session) as s:
assert s.session._creds.access_key == 'id'
assert s.session._creds.secret_key == 'key'
assert s.session._creds.token == 'token'
assert s.session._session.get_credentials().get_frozen_credentials().access_key == 'id'
assert s.session._session.get_credentials().get_frozen_credentials().secret_key == 'key'
assert s.session._session.get_credentials().get_frozen_credentials().token == 'token'
assert s.session._session.region_name == 'null-island-1'


Expand Down

0 comments on commit 8ae79a4

Please sign in to comment.