Skip to content

Commit

Permalink
Fix s3 client kwargs (#3316)
Browse files Browse the repository at this point in the history
* Fix s3 client kwargs

* Add type annotation

* Fix UT

* Fix UT

Co-authored-by: 刘宝 <po.lb@antgroup.com>
  • Loading branch information
fyrestone and 刘宝 committed Jan 10, 2023
1 parent 996ce47 commit aa1b261
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 3 deletions.
6 changes: 6 additions & 0 deletions mars/lib/filesystem/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,9 @@ def parse_from_path(uri: str):
if parsed_uri.password:
options["password"] = parsed_uri.password
return options

@classmethod
def get_storage_options(cls, storage_options: Dict, uri: str) -> Dict:
    """
    Combine caller-supplied storage options with the options parsed from ``uri``.

    Parameters
    ----------
    storage_options : dict
        Options supplied by the caller. ``None`` is treated as an empty
        dict. The passed dict is never modified.
    uri : str
        The path / URI handed to ``cls.parse_from_path``.

    Returns
    -------
    dict
        A new dict holding ``storage_options`` updated with the options
        parsed from ``uri``; on a key conflict the value parsed from
        ``uri`` wins.
    """
    # Copy first so options parsed from the URI (which may include a
    # password) do not leak back into the dict owned by the caller.
    merged = dict(storage_options) if storage_options else {}
    merged.update(cls.parse_from_path(uri))
    return merged
5 changes: 3 additions & 2 deletions mars/lib/filesystem/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ def get_fs(path: path_type, storage_options: Dict = None) -> FileSystem:
# local file systems are singletons.
return file_system_type.get_instance()
else:
options = file_system_type.parse_from_path(path)
storage_options.update(options)
storage_options = file_system_type.get_storage_options(
storage_options, path
)
return file_system_type(**storage_options)
elif scheme in _scheme_to_dependencies: # pragma: no cover
dependencies = ", ".join(_scheme_to_dependencies[scheme])
Expand Down
14 changes: 13 additions & 1 deletion mars/lib/filesystem/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict

"""
An example to read csv from s3
Expand All @@ -26,8 +27,9 @@
>>> "endpoint_url": "http://192.168.1.12:9000",
>>> "aws_access_key_id": "<s3 access id>",
>>> "aws_secret_access_key": "<s3 access key>",
>>> "aws_session_token": "<s3 session token>",
>>> }})
>>> # Export environment vars AWS_ENDPOINT_URL / AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY.
>>> # Export environment vars AWS_ENDPOINT_URL / AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_SESSION_TOKEN.
>>> mdf = md.read_csv("s3://bucket/example.csv", index_col=0)
>>> r = mdf.head(1000).execute()
>>> print(r)
Expand Down Expand Up @@ -62,6 +64,16 @@ def parse_from_path(uri: str):
client_kwargs = {k: v for k, v in client_kwargs.items() if v is not None}
return {"client_kwargs": client_kwargs}

@classmethod
def get_storage_options(cls, storage_options: Dict, uri: str) -> Dict:
    """
    Layer caller-supplied storage options over the options parsed from ``uri``.

    Parameters
    ----------
    storage_options : dict
        Options supplied by the caller. ``None`` is treated as empty.
        The passed dict is never modified.
    uri : str
        The path / URI handed to ``cls.parse_from_path``.

    Returns
    -------
    dict
        The dict produced by ``cls.parse_from_path(uri)`` with the caller's
        options applied on top. ``client_kwargs`` is merged key by key, so
        a caller may override a single client argument while keeping the
        parsed defaults for the rest; every other key is replaced
        wholesale by the caller's value.
    """
    options = cls.parse_from_path(uri)
    for key, value in (storage_options or {}).items():
        if key == "client_kwargs":
            # ``client_kwargs=None`` means "nothing to override": keep the
            # parsed defaults instead of failing in ``dict.update``.
            if value:
                options.setdefault("client_kwargs", {}).update(value)
        else:
            options[key] = value
    return options

register_filesystem("s3", S3FileSystem)
else:
S3FileSystem = None
98 changes: 98 additions & 0 deletions mars/lib/filesystem/tests/test_s3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest

from ....dataframe import read_csv
from ..core import register_filesystem
from ..s3 import S3FileSystem


class KwArgsException(Exception):
    """Exception that carries a keyword-argument dict on its ``kwargs`` attribute."""

    def __init__(self, kwargs):
        # Keep ``args`` populated the standard way, then expose the payload
        # under a descriptive attribute for assertions.
        super().__init__(kwargs)
        self.kwargs = kwargs


if S3FileSystem is not None:

    class TestS3FileSystem(S3FileSystem):
        """
        ``S3FileSystem`` subclass that reports its constructor arguments.

        Construction always ends by raising ``KwArgsException`` carrying the
        keyword arguments received, so a test can inspect exactly what the
        filesystem was instantiated with.
        """

        # The ``Test`` name prefix matches pytest's default class discovery;
        # this is a helper, not a test case, so opt out of collection.
        __test__ = False

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            raise KwArgsException(kwargs)

else:
    TestS3FileSystem = None


@pytest.mark.skipif(S3FileSystem is None, reason="S3 is not supported")
def test_client_kwargs():
    """
    Check which ``client_kwargs`` reach the S3 filesystem constructor.

    Three scenarios are covered:

    1. all client kwargs passed explicitly, no AWS_* environment variables
       set by the test;
    2. all client kwargs passed explicitly while the AWS_* environment
       variables are set: the explicit values must win;
    3. a single client kwarg passed explicitly while the AWS_* environment
       variables are set: the explicit value overrides only its own key and
       the remaining keys come from the environment.
    """
    test_kwargs = {
        "endpoint_url": "http://192.168.1.12:9000",
        "aws_access_key_id": "test_id",
        "aws_secret_access_key": "test_key",
        "aws_session_token": "test_session_token",
    }
    test_env = {
        "AWS_ENDPOINT_URL": "a",
        "AWS_ACCESS_KEY_ID": "b",
        "AWS_SECRET_ACCESS_KEY": "c",
        "AWS_SESSION_TOKEN": "d",
    }

    def _assert_true():
        # Pass endpoint_url / aws_access_key_id / aws_secret_access_key / aws_session_token to read_csv.
        with pytest.raises(KwArgsException) as e:
            read_csv(
                "s3://bucket/example.csv",
                index_col=0,
                storage_options={"client_kwargs": test_kwargs},
            )
        assert e.value.kwargs == {
            "client_kwargs": {
                "endpoint_url": "http://192.168.1.12:9000",
                "aws_access_key_id": "test_id",
                "aws_secret_access_key": "test_key",
                "aws_session_token": "test_session_token",
            }
        }

    # Remember any pre-existing values so the test restores them afterwards
    # instead of deleting them.
    saved_env = {name: os.environ.get(name) for name in test_env}

    register_filesystem("s3", TestS3FileSystem)
    try:
        _assert_true()

        os.environ.update(test_env)
        _assert_true()

        for k, v in test_kwargs.items():
            with pytest.raises(KwArgsException) as e:
                read_csv(
                    "s3://bucket/example.csv",
                    index_col=0,
                    storage_options={"client_kwargs": {k: v}},
                )
            expect = {
                "endpoint_url": "a",
                "aws_access_key_id": "b",
                "aws_secret_access_key": "c",
                "aws_session_token": "d",
            }
            expect[k] = v
            assert e.value.kwargs == {"client_kwargs": expect}
    finally:
        for name, previous in saved_env.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous
        # Put the real filesystem back so later tests using "s3" are not
        # routed to the always-raising helper class.
        register_filesystem("s3", S3FileSystem)

0 comments on commit aa1b261

Please sign in to comment.