Skip to content

Commit

Permalink
Support auto-setting AWS credentials for storage options (#78)
Browse files Browse the repository at this point in the history
  • Loading branch information
milesgranger committed Mar 21, 2024
1 parent 05f7cf1 commit ec1c90c
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 9 deletions.
19 changes: 14 additions & 5 deletions dask_deltatable/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from packaging.version import Version
from pyarrow import dataset as pa_ds

from . import utils
from .types import Filters
from .utils import get_partition_filters

if Version(pa.__version__) >= Version("10.0.0"):
filters_to_expression = pq.filters_to_expression
Expand All @@ -44,7 +44,9 @@ def _get_pq_files(dt: DeltaTable, filter: Filters = None) -> list[str]:
list[str]
List of files matching optional filter.
"""
partition_filters = get_partition_filters(dt.metadata().partition_columns, filter)
partition_filters = utils.get_partition_filters(
dt.metadata().partition_columns, filter
)
if not partition_filters:
# can't filter
return sorted(dt.file_uris())
Expand Down Expand Up @@ -94,6 +96,9 @@ def _read_from_filesystem(
"""
Reads the list of parquet files in parallel
"""
storage_options = utils.maybe_set_aws_credentials(path, storage_options) # type: ignore
delta_storage_options = utils.maybe_set_aws_credentials(path, delta_storage_options) # type: ignore

fs, fs_token, _ = get_fs_token_paths(path, storage_options=storage_options)
dt = DeltaTable(
table_uri=path, version=version, storage_options=delta_storage_options
Expand All @@ -116,12 +121,14 @@ def _read_from_filesystem(
if columns:
meta = meta[columns]

kws = dict(meta=meta, label="read-delta-table")
if not dd._dask_expr_enabled():
# Setting token not supported in dask-expr
kws["token"] = tokenize(path, fs_token, **kwargs) # type: ignore
return dd.from_map(
partial(_read_delta_partition, fs=fs, columns=columns, schema=schema, **kwargs),
pq_files,
meta=meta,
label="read-delta-table",
token=tokenize(path, fs_token, **kwargs),
**kws,
)


Expand Down Expand Up @@ -270,6 +277,8 @@ def read_deltalake(
else:
if path is None:
raise ValueError("Please Provide Delta Table path")

delta_storage_options = utils.maybe_set_aws_credentials(path, delta_storage_options) # type: ignore
resultdf = _read_from_filesystem(
path=path,
version=version,
Expand Down
70 changes: 69 additions & 1 deletion dask_deltatable/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,78 @@
from __future__ import annotations

from typing import cast
from typing import Any, cast

from .types import Filter, Filters


def get_bucket_region(path: str):
import boto3

if not path.startswith("s3://"):
raise ValueError(f"'{path}' is not an S3 path")
bucket = path.replace("s3://", "").split("/")[0]
resp = boto3.client("s3").get_bucket_location(Bucket=bucket)
# Buckets in region 'us-east-1' results in None, b/c why not.
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/get_bucket_location.html#S3.Client.get_bucket_location
return resp["LocationConstraint"] or "us-east-1"


def maybe_set_aws_credentials(path: Any, options: dict[str, Any]) -> dict[str, Any]:
"""
Maybe set AWS credentials into ``options`` if existing AWS specific keys
not found in it and path is s3:// format.
Parameters
----------
path : Any
If it's a string, we'll check if it starts with 's3://' then determine bucket
region if the AWS credentials should be set.
options : dict[str, Any]
Options, any kwargs to be supplied to things like S3FileSystem or similar
that may accept AWS credentials set. A copy is made and returned if modified.
Returns
-------
dict
Either the original options if not modified, or a copied and updated options
with AWS credentials inserted.
"""

is_s3_path = getattr(path, "startswith", lambda _: False)("s3://")
if not is_s3_path:
return options

# Avoid overwriting already provided credentials
keys = ("AWS_ACCESS_KEY", "AWS_SECRET_ACCESS_KEY", "access_key", "secret_key")
if not any(k in (options or ()) for k in keys):
# defers installing boto3 upfront, xref _read_from_catalog
import boto3

session = boto3.session.Session()
credentials = session.get_credentials()
if credentials is None:
return options
region = get_bucket_region(path)

options = (options or {}).copy()
options.update(
# Capitalized is used in delta specific API and lowercase is for S3FileSystem
dict(
# TODO: w/o this, we need to configure a LockClient which seems to require dynamodb.
AWS_S3_ALLOW_UNSAFE_RENAME="true",
AWS_SECRET_ACCESS_KEY=credentials.secret_key,
AWS_ACCESS_KEY_ID=credentials.access_key,
AWS_SESSION_TOKEN=credentials.token,
AWS_REGION=region,
secret_key=credentials.secret_key,
access_key=credentials.access_key,
token=credentials.token,
region=region,
)
)
return options


def get_partition_filters(
partition_columns: list[str], filters: Filters
) -> list[list[Filter]] | None:
Expand Down
13 changes: 12 additions & 1 deletion dask_deltatable/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@
from dask.dataframe.core import Scalar
from dask.highlevelgraph import HighLevelGraph
from deltalake import DeltaTable

try:
from deltalake.writer import MAX_SUPPORTED_WRITER_VERSION # type: ignore
except ImportError:
from deltalake.writer import (
MAX_SUPPORTED_PYARROW_WRITER_VERSION as MAX_SUPPORTED_WRITER_VERSION,
)

from deltalake.writer import (
MAX_SUPPORTED_WRITER_VERSION,
PYARROW_MAJOR_VERSION,
AddAction,
DeltaJSONEncoder,
Expand All @@ -30,6 +37,7 @@
)
from toolz.itertoolz import pluck

from . import utils
from ._schema import pyarrow_to_deltalake, validate_compatible


Expand Down Expand Up @@ -123,6 +131,7 @@ def to_deltalake(
-------
dask.Scalar
"""
storage_options = utils.maybe_set_aws_credentials(table_or_uri, storage_options) # type: ignore
table, table_uri = try_get_table_and_table_uri(table_or_uri, storage_options)

# We need to write against the latest table version
Expand All @@ -136,6 +145,7 @@ def to_deltalake(
storage_options = table._storage_options or {}
storage_options.update(storage_options or {})

storage_options = utils.maybe_set_aws_credentials(table_uri, storage_options)
filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri, storage_options))

if isinstance(partition_by, str):
Expand Down Expand Up @@ -253,6 +263,7 @@ def _commit(
schema = validate_compatible(schemas)
assert schema
if table is None:
storage_options = utils.maybe_set_aws_credentials(table_uri, storage_options)
write_deltalake_pyarrow(
table_uri,
schema,
Expand Down
14 changes: 13 additions & 1 deletion tests/test_acceptance.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import os
import shutil
import unittest.mock as mock
from urllib.request import urlretrieve

import dask.dataframe as dd
Expand Down Expand Up @@ -42,6 +43,15 @@ def download_data():
assert os.path.exists(DATA_DIR)


@mock.patch("dask_deltatable.utils.maybe_set_aws_credentials")
def test_reader_check_aws_credentials(maybe_set_aws_credentials):
# The full functionality of maybe_set_aws_credentials tests in test_utils
# we only need to ensure it's called here when reading with a str path
maybe_set_aws_credentials.return_value = dict()
ddt.read_deltalake(f"{DATA_DIR}/all_primitive_types/delta")
maybe_set_aws_credentials.assert_called()


def test_reader_all_primitive_types():
actual_ddf = ddt.read_deltalake(f"{DATA_DIR}/all_primitive_types/delta")
expected_ddf = dd.read_parquet(
Expand All @@ -50,7 +60,9 @@ def test_reader_all_primitive_types():
# Dask and delta go through different parquet parsers which read the
# timestamp differently. This is likely a bug in arrow but the delta result
# is "more correct".
expected_ddf["timestamp"] = expected_ddf["timestamp"].astype("datetime64[us]")
expected_ddf["timestamp"] = (
expected_ddf["timestamp"].astype("datetime64[us]").dt.tz_localize("UTC")
)
assert_eq(actual_ddf, expected_ddf)


Expand Down
85 changes: 84 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from __future__ import annotations

import pathlib
import unittest.mock as mock

import pytest

from dask_deltatable.utils import get_partition_filters
from dask_deltatable.utils import (
get_bucket_region,
get_partition_filters,
maybe_set_aws_credentials,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -31,3 +38,79 @@ def test_partition_filters(cols, filters, expected):
# make sure it works with additional level of wrapping
res = get_partition_filters(cols, filters)
assert res == expected


@mock.patch("dask_deltatable.utils.get_bucket_region")
@pytest.mark.parametrize(
"options",
(
None,
dict(),
dict(AWS_ACCESS_KEY_ID="foo", AWS_SECRET_ACCESS_KEY="bar"),
dict(access_key="foo", secret_key="bar"),
),
)
@pytest.mark.parametrize("path", ("s3://path", "/another/path", pathlib.Path(".")))
def test_maybe_set_aws_credentials(
mocked_get_bucket_region,
options,
path,
):
pytest.importorskip("boto3")

mocked_get_bucket_region.return_value = "foo-region"

mock_creds = mock.MagicMock()
type(mock_creds).token = mock.PropertyMock(return_value="token")
type(mock_creds).access_key = mock.PropertyMock(return_value="access-key")
type(mock_creds).secret_key = mock.PropertyMock(return_value="secret-key")

def mock_get_credentials():
return mock_creds

with mock.patch("boto3.session.Session") as mocked_session:
session = mocked_session.return_value
session.get_credentials.side_effect = mock_get_credentials

opts = maybe_set_aws_credentials(path, options)

if options and not any(k in options for k in ("AWS_ACCESS_KEY_ID", "access_key")):
assert opts["AWS_ACCESS_KEY_ID"] == "access-key"
assert opts["AWS_SECRET_ACCESS_KEY"] == "secret-key"
assert opts["AWS_SESSION_TOKEN"] == "token"
assert opts["AWS_REGION"] == "foo-region"

assert opts["access_key"] == "access-key"
assert opts["secret_key"] == "secret-key"
assert opts["token"] == "token"
assert opts["region"] == "foo-region"

# Did not alter input options if credentials were supplied by user
elif options:
assert options == opts


@pytest.mark.parametrize("location", (None, "region-foo"))
@pytest.mark.parametrize(
"path,bucket",
(("s3://foo/bar", "foo"), ("s3://fizzbuzz", "fizzbuzz"), ("/not/s3", None)),
)
def test_get_bucket_region(location, path, bucket):
pytest.importorskip("boto3")

with mock.patch("boto3.client") as mock_client:
mock_client = mock_client.return_value
mock_client.get_bucket_location.return_value = {"LocationConstraint": location}

if not path.startswith("s3://"):
with pytest.raises(ValueError, match="is not an S3 path"):
get_bucket_region(path)
return

region = get_bucket_region(path)

# AWS returns None if bucket located in us-east-1...
location = location if location else "us-east-1"
assert region == location

mock_client.get_bucket_location.assert_has_calls([mock.call(Bucket=bucket)])
13 changes: 13 additions & 0 deletions tests/test_write.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
import unittest.mock as mock

import dask.dataframe as dd
import pandas as pd
Expand Down Expand Up @@ -61,6 +62,18 @@ def test_roundtrip(tmpdir, with_index, freq, partition_freq):
assert_eq(ddf_read, ddf_dask)


@mock.patch("dask_deltatable.utils.maybe_set_aws_credentials")
def test_writer_check_aws_credentials(maybe_set_aws_credentials, tmpdir):
# The full functionality of maybe_set_aws_credentials tests in test_utils
# we only need to ensure it's called here when writing with a str path
maybe_set_aws_credentials.return_value = dict()

df = pd.DataFrame({"col1": range(10)})
ddf = dd.from_pandas(df, npartitions=2)
to_deltalake(str(tmpdir), ddf)
maybe_set_aws_credentials.assert_called()


@pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"])
def test_datetime(tmpdir, unit):
"""Ensure we can write datetime with different resolutions,
Expand Down

0 comments on commit ec1c90c

Please sign in to comment.