Skip to content

Commit

Permalink
add AWS credential overrides for S3 stac (#613)
Browse files Browse the repository at this point in the history
* add AWS credential overrides for S3 stac

* catch warnings
  • Loading branch information
vincentsarago committed Jun 1, 2023
1 parent 495e264 commit ef87347
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 37 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

- handle internal and user provided `nodata` values in `rio_tiler.io.XarrayReader` to create mask

- add `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_SESSION_TOKEN`, `AWS_PROFILE` and `AWS_REGION` environnement overrides for `rio_tiler.io.stac.aws_get_object` function

**breaking changes**

- remove support for non-binary mask values (e.g non-binary alpha bands, ref: [rasterio/rasterio#1721](https://github.com/rasterio/rasterio/issues/1721#issuecomment-586547617))
Expand Down
28 changes: 25 additions & 3 deletions rio_tiler/io/stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,41 @@ def aws_get_object(
assert boto3_session is not None, "'boto3' must be installed to use s3:// urls"

if not client:
session = boto3_session()
if profile_name := os.environ.get("AWS_PROFILE", None):
session = boto3_session(profile_name=profile_name)

else:
access_key = os.environ.get("AWS_ACCESS_KEY_ID", None)
secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY", None)
access_token = os.environ.get("AWS_SESSION_TOKEN", None)

# AWS_REGION is GDAL specific. Later overloaded by standard AWS_DEFAULT_REGION
region_name = os.environ.get(
"AWS_DEFAULT_REGION", os.environ.get("AWS_REGION", None)
)

session = boto3_session(
aws_access_key_id=access_key,
aws_secret_access_key=secret_access_key,
aws_session_token=access_token,
region_name=region_name or None,
)

# AWS_S3_ENDPOINT and AWS_HTTPS are GDAL config options of vsis3 driver
# https://gdal.org/user/virtual_file_systems.html#vsis3-aws-s3-files
endpoint_url = os.environ.get("AWS_S3_ENDPOINT", None)
if endpoint_url is not None:
if endpoint_url:
use_https = os.environ.get("AWS_HTTPS", "YES")
if use_https.upper() in ["YES", "TRUE", "ON"]:
endpoint_url = "https://" + endpoint_url

else:
endpoint_url = "http://" + endpoint_url

client = session.client("s3", endpoint_url=endpoint_url)
client = session.client(
"s3",
endpoint_url=endpoint_url or "s3.amazonaws.com",
)

params = {"Bucket": bucket, "Key": key}
if request_pays or os.environ.get("AWS_REQUEST_PAYER", "").lower() == "requester":
Expand Down
13 changes: 11 additions & 2 deletions rio_tiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from rasterio.crs import CRS
from rasterio.dtypes import _gdal_typename
from rasterio.enums import ColorInterp, MaskFlags, Resampling
from rasterio.errors import NotGeoreferencedWarning
from rasterio.features import is_valid_geom
from rasterio.io import DatasetReader, DatasetWriter, MemoryFile
from rasterio.rio.helpers import coords
Expand Down Expand Up @@ -448,7 +449,11 @@ def render(
output_profile.update(creation_options)

with warnings.catch_warnings():
warnings.simplefilter("ignore", rasterio.errors.NotGeoreferencedWarning)
warnings.filterwarnings(
"ignore",
category=NotGeoreferencedWarning,
module="rasterio",
)
with MemoryFile() as memfile:
with memfile.open(**output_profile) as dst:
dst.write(data, indexes=list(range(1, count + 1)))
Expand Down Expand Up @@ -611,7 +616,11 @@ def resize_array(

datasetname = _array_gdal_name(data)
with warnings.catch_warnings():
warnings.simplefilter("ignore", rasterio.errors.NotGeoreferencedWarning)
warnings.filterwarnings(
"ignore",
category=NotGeoreferencedWarning,
module="rasterio",
)
with rasterio.open(datasetname, "r+") as src:
# if a 2D array is passed, using indexes=1 makes sure we return an 2D array
indexes = 1 if len(data.shape) == 2 else None
Expand Down
5 changes: 4 additions & 1 deletion tests/test_expression.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""test rio_tiler.expression functions."""

import warnings

import numpy
import pytest

Expand Down Expand Up @@ -48,7 +50,8 @@ def test_parse_cast(expr, expected):
)
def test_get_blocks(expr, expected):
"""test get_expression_blocks."""
with pytest.warns(None):
with warnings.catch_warnings():
warnings.simplefilter("error")
assert get_expression_blocks(expr) == expected


Expand Down
9 changes: 3 additions & 6 deletions tests/test_io_image.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Tests ImageReader."""

import os
import warnings

import numpy
import pytest
Expand All @@ -18,14 +17,12 @@

def test_non_geo_image():
"""Test ImageReader usage with Non-Geo Images."""
with pytest.warns() as w:
with pytest.warns((NotGeoreferencedWarning)):
with ImageReader(NO_GEO) as src:
assert src.minzoom == 0
assert src.maxzoom == 3
assert len(w) == 1
assert issubclass(w[0].category, NotGeoreferencedWarning)

with warnings.catch_warnings():
with pytest.warns((NotGeoreferencedWarning)):
with ImageReader(NO_GEO) as src:
assert list(src.tms.xy_bounds(0, 0, 3)) == [0, 256, 256, 0]
assert list(src.tms.xy_bounds(0, 0, 2)) == [0, 512, 512, 0]
Expand Down Expand Up @@ -92,7 +89,7 @@ def test_non_geo_image():
im = src.feature(poly)
assert im.data.shape == (3, 1100, 1100)

with warnings.catch_warnings():
with pytest.warns((NotGeoreferencedWarning)):
with ImageReader(NO_GEO_PORTRAIT) as src:
img = src.tile(5, 2, 3)
assert not img.mask.all()
Expand Down
53 changes: 29 additions & 24 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test rio_tiler.models."""

import warnings
from io import BytesIO

import numpy
Expand All @@ -18,16 +19,16 @@ def test_imageData_AutoRescaling():
ImageData(numpy.zeros((1, 256, 256), dtype="float32")).render(img_format="PNG")
assert len(w.list) == 1

with pytest.warns(None) as w:
with warnings.catch_warnings():
warnings.simplefilter("error")
ImageData(numpy.zeros((1, 256, 256), dtype="uint8")).render(img_format="PNG")
assert len(w.list) == 0

with pytest.warns(InvalidDatatypeWarning) as w:
ImageData(numpy.zeros((1, 256, 256), dtype="int8")).render(img_format="PNG")

with pytest.warns(None) as w:
with warnings.catch_warnings():
warnings.simplefilter("error")
ImageData(numpy.zeros((1, 256, 256), dtype="uint16")).render(img_format="GTiff")
assert len(w.list) == 0

with pytest.warns(InvalidDatatypeWarning) as w:
ImageData(numpy.zeros((1, 256, 256), dtype="uint16")).render(img_format="jpeg")
Expand All @@ -41,12 +42,12 @@ def test_imageData_AutoRescaling():
)

# Make sure that we do not rescale uint16 data when there is a colormap
with pytest.warns(None) as w:
with warnings.catch_warnings():
warnings.simplefilter("error")
cm = {1: (0, 0, 0, 255), 1000: (255, 255, 255, 255)}
ImageData(numpy.zeros((1, 256, 256), dtype="uint16")).render(
img_format="JPEG", colormap=cm
)
assert len(w.list) == 0


@pytest.mark.parametrize(
Expand All @@ -55,7 +56,8 @@ def test_imageData_AutoRescaling():
)
def test_imageData_AutoRescalingAllTypes(dtype):
"""Test ImageData auto rescaling."""
with pytest.warns(None):
with warnings.catch_warnings():
warnings.simplefilter("ignore") # Some InvalidDatatypeWarning will be emitted
ImageData(numpy.zeros((1, 256, 256), dtype=dtype)).render(img_format="PNG")
ImageData(numpy.zeros((1, 256, 256), dtype=dtype)).render(img_format="JPEG")
ImageData(numpy.zeros((3, 256, 256), dtype=dtype)).render(img_format="WEBP")
Expand All @@ -69,7 +71,7 @@ def test_16bit_PNG():
mask = numpy.zeros((1, 256, 256), dtype="bool")
mask[0:10, 0:10] = True

with pytest.warns(None):
with warnings.catch_warnings():
arr = numpy.ma.MaskedArray(numpy.zeros((1, 256, 256), dtype="uint16"))
arr.mask = mask.copy()
img = ImageData(arr).render(img_format="PNG")
Expand All @@ -83,7 +85,7 @@ def test_16bit_PNG():
assert (arr[0:10, 0:10] == 0).all()
assert (arr[11:, 11:] == 65535).all()

with pytest.warns(None):
with warnings.catch_warnings():
arr = numpy.ma.MaskedArray(numpy.zeros((3, 256, 256), dtype="uint16"))
arr.mask = mask.copy()
img = ImageData(arr).render(img_format="PNG")
Expand Down Expand Up @@ -113,11 +115,11 @@ def test_merge_with_diffsize():
assert img.width == 256
assert img.height == 256

with pytest.warns(None) as w:
with warnings.catch_warnings():
warnings.simplefilter("error")
img1 = ImageData(numpy.zeros((1, 256, 256)))
img2 = ImageData(numpy.zeros((1, 256, 256)))
img = ImageData.create_from_list([img1, img2])
assert len(w) == 0


def test_apply_expression():
Expand Down Expand Up @@ -159,19 +161,22 @@ def test_dataset_statistics():
data[0, 0:10, 0:10] = 0
data[0, 10:11, 10:11] = 1

img = ImageData(data, dataset_statistics=[(0, 1)]).render(img_format="PNG")
with MemoryFile(img) as mem:
with mem.open() as dst:
arr = dst.read(indexes=1)
assert arr.min() == 0
assert arr.max() == 255

img = ImageData(data).render(img_format="PNG")
with MemoryFile(img) as mem:
with mem.open() as dst:
arr = dst.read(indexes=1)
assert not arr.min() == 0
assert not arr.max() == 255
with pytest.warns(InvalidDatatypeWarning):
img = ImageData(data, dataset_statistics=[(0, 1)]).render(img_format="PNG")
with MemoryFile(img) as mem:
with mem.open() as dst:
arr = dst.read(indexes=1)
assert arr.min() == 0
assert arr.max() == 255

with pytest.warns(InvalidDatatypeWarning):
img = ImageData(data).render(img_format="PNG")

with MemoryFile(img) as mem:
with mem.open() as dst:
arr = dst.read(indexes=1)
assert not arr.min() == 0
assert not arr.max() == 255


def test_resize():
Expand Down
5 changes: 4 additions & 1 deletion tests/test_mosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,9 @@ class aClass(object):
# test with preview
# NOTE: We need to have fix output width and height because each preview could have different size
# Also because the 2 assets cover different bbox, getting the preview merged together doesn't make real sense
(t, m), _ = mosaic.mosaic_reader(assets, _read_preview, width=256, height=256)
(t, m), _ = mosaic.mosaic_reader(
assets, _read_preview, width=256, height=256, max_size=None
)
assert t.shape == (3, 256, 256)
assert m.shape == (256, 256)
assert t.dtype == "uint16"
Expand Down Expand Up @@ -445,6 +447,7 @@ def test_mosaic_tiler_with_imageDataClass():
_read_preview,
width=256,
height=256,
max_size=None,
pixel_selection=defaults.LowestMethod(),
)
assert img.data.shape == (3, 256, 256)
Expand Down

0 comments on commit ef87347

Please sign in to comment.