[Datastore] When necessary, save the bucket as part of the data source profile and use the default path if the path is not specified (#5478)

alxtkr77 committed May 5, 2024
1 parent dcbd139 commit b0dbde7
Showing 7 changed files with 173 additions and 53 deletions.
7 changes: 7 additions & 0 deletions docs/store/datastore.md
@@ -179,6 +179,8 @@ The equivalent to this parameter in environment authentication is "AZURE_STORAGE
Credential authentication:
- `credential` — TokenCredential or SAS token. The credentials with which to authenticate.
This variable is sensitive information and is kept confidential.
- `bucket` — A string representing the bucket name. When specified, it is automatically prepended to the object path, so it should not be included again in the target path.
This parameter will become mandatory in version 1.9.
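
For illustration, a minimal sketch of an Azure profile that carries the container name in `bucket` (the profile name, container name, and target path below are hypothetical; the connection string is assumed to be available in the environment):

```python
import os

from mlrun.datastore.datastore_profile import (
    DatastoreProfileAzureBlob,
    register_temporary_client_datastore_profile,
)
from mlrun.datastore.targets import ParquetTarget

# Hypothetical profile; for Azure, "bucket" holds the container name.
profile = DatastoreProfileAzureBlob(
    name="my_azure_profile",
    connection_string=os.environ["AZURE_STORAGE_CONNECTION_STRING"],
    bucket="my-container",  # prepended automatically to the object path
)
register_temporary_client_datastore_profile(profile)

# The container is not repeated in the target path:
target = ParquetTarget(path="ds://my_azure_profile/path/to/parquet.pq")
```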

## Databricks file system
### DBFS credentials and parameters
@@ -229,6 +231,8 @@ The equivalent to this parameter in environment authentication is "GOOGLE_APPLIC
- `gcp_credentials` — A JSON in a string format representing the authentication parameters required by GCS API.
For privacy reasons, it's tagged as a private attribute, and its default value is `None`.
The equivalent to this parameter in environment authentication is "GCP_CREDENTIALS".
- `bucket` — A string representing the bucket name. When specified, it is automatically prepended to the object path, so it should not be included again in the target path.
This parameter will become mandatory in version 1.9.

The code prioritizes `gcp_credentials` over `credentials_path`.
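
A minimal sketch of a GCS profile that carries the bucket (the profile name, credentials path, and bucket below are hypothetical):

```python
from mlrun.datastore.datastore_profile import (
    DatastoreProfileGCS,
    register_temporary_client_datastore_profile,
)

# Hypothetical profile; alternatively, pass gcp_credentials as a dict or JSON string.
profile = DatastoreProfileGCS(
    name="my_gcs_profile",
    credentials_path="/path/to/gcp-credentials.json",
    bucket="my-gcs-bucket",  # prepended automatically to the object path
)
register_temporary_client_datastore_profile(profile)

# Target paths then omit the bucket, e.g. "ds://my_gcs_profile/path/to/parquet.pq"
```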

Expand Down Expand Up @@ -308,6 +312,9 @@ ParquetTarget(path="ds://profile-name/aws_bucket/path/to/parquet.pq")
- `assume_role_arn` — A string representing the Amazon Resource Name (ARN) of the role to assume when interacting with the S3 service. This can be useful for granting temporary permissions. By default, it is set to `None`. The equivalent to this parameter in environment authentication is env["MLRUN_AWS_ROLE_ARN"]
- `access_key_id` — A string representing the access key used for authentication to the S3 service. It's one of the credentials parts when you're not using anonymous access or IAM roles. For privacy reasons, it's tagged as a private attribute, and its default value is `None`. The equivalent to this parameter in environment authentication is env["AWS_ACCESS_KEY_ID"].
- `secret_key` — A string representing the secret key, which pairs with the access key, used for authentication to the S3 service. It's the second part of the credentials when not using anonymous access or IAM roles. It's also tagged as private for privacy and security reasons. The default value is `None`. The equivalent to this parameter in environment authentication is env["AWS_SECRET_ACCESS_KEY"].
- `bucket` — A string representing the bucket name. When specified, it is automatically prepended to the object path, so it should not be included again in the target path.
This parameter will become mandatory in version 1.9.
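
For example, a minimal sketch of an S3 profile that carries the bucket (the profile name and credentials below are placeholders); compare with the `ParquetTarget` example above, where the bucket is part of the path:

```python
from mlrun.datastore.datastore_profile import (
    DatastoreProfileS3,
    register_temporary_client_datastore_profile,
)
from mlrun.datastore.targets import ParquetTarget

# Hypothetical profile; real credentials would come from secrets or the environment.
profile = DatastoreProfileS3(
    name="profile-with-bucket",
    access_key_id="<AWS access key id>",
    secret_key="<AWS secret access key>",
    bucket="aws_bucket",  # prepended automatically to the object path
)
register_temporary_client_datastore_profile(profile)

# The bucket no longer appears in the target path:
ParquetTarget(path="ds://profile-with-bucket/path/to/parquet.pq")
```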


## V3IO

2 changes: 1 addition & 1 deletion mlrun/config.py
@@ -554,7 +554,7 @@
"nosql": "v3io:///projects/{project}/FeatureStore/{name}/nosql",
# "authority" is optional and generalizes [userinfo "@"] host [":" port]
"redisnosql": "redis://{authority}/projects/{project}/FeatureStore/{name}/nosql",
"dsnosql": "ds://{ds_profile_name}/projects/{project}/FeatureStore/{name}/nosql",
"dsnosql": "ds://{ds_profile_name}/projects/{project}/FeatureStore/{name}/{kind}",
},
"default_targets": "parquet,nosql",
"default_job_image": "mlrun/mlrun",
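A rough illustration of how the updated `dsnosql` template resolves for a parquet target (the profile, project, and feature-set names are made up; `{kind}` is now filled from the concrete target rather than being hard-coded to `nosql`):

```python
template = "ds://{ds_profile_name}/projects/{project}/FeatureStore/{name}/{kind}"
print(
    template.format(
        ds_profile_name="my_profile",
        project="my_project",
        name="my_feature_set",
        kind="parquet",
    )
)
# -> ds://my_profile/projects/my_project/FeatureStore/my_feature_set/parquet
```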
53 changes: 50 additions & 3 deletions mlrun/datastore/datastore_profile.py
@@ -185,6 +185,17 @@ class DatastoreProfileS3(DatastoreProfile):
assume_role_arn: typing.Optional[str] = None
access_key_id: typing.Optional[str] = None
secret_key: typing.Optional[str] = None
bucket: typing.Optional[str] = None

@pydantic.validator("bucket")
def check_bucket(cls, v):
if not v:
warnings.warn(
"The 'bucket' attribute will be mandatory starting from version 1.9",
FutureWarning,
stacklevel=2,
)
return v

def secrets(self) -> dict:
res = {}
@@ -203,7 +214,13 @@ def secrets(self) -> dict:
return res

def url(self, subpath):
return f"s3:/{subpath}"
# TODO: There is an inconsistency with DatastoreProfileGCS. In DatastoreProfileGCS,
# we assume that the subpath can begin without a '/' character,
# while here we assume it always starts with one.
if self.bucket:
return f"s3://{self.bucket}{subpath}"
else:
return f"s3:/{subpath}"


class DatastoreProfileRedis(DatastoreProfile):
@@ -272,6 +289,17 @@ class DatastoreProfileGCS(DatastoreProfile):
_private_attributes = ("gcp_credentials",)
credentials_path: typing.Optional[str] = None # path to file.
gcp_credentials: typing.Optional[typing.Union[str, dict]] = None
bucket: typing.Optional[str] = None

@pydantic.validator("bucket")
def check_bucket(cls, v):
if not v:
warnings.warn(
"The 'bucket' attribute will be mandatory starting from version 1.9",
FutureWarning,
stacklevel=2,
)
return v

@pydantic.validator("gcp_credentials", pre=True, always=True)
def convert_dict_to_json(cls, v):
@@ -280,10 +308,15 @@ def convert_dict_to_json(cls, v):
return v

def url(self, subpath) -> str:
# TODO: There is an inconsistency with DatastoreProfileS3: here the subpath is assumed
# to possibly start without a '/', while in S3 it is assumed to always start with one.
if subpath.startswith("/"):
# In GCS, the path after the scheme starts with the bucket, so it should not begin with "/".
subpath = subpath[1:]
return f"gcs://{subpath}"
if self.bucket:
return f"gcs://{self.bucket}/{subpath}"
else:
return f"gcs://{subpath}"

def secrets(self) -> dict:
res = {}
@@ -311,12 +344,26 @@ class DatastoreProfileAzureBlob(DatastoreProfile):
client_secret: typing.Optional[str] = None
sas_token: typing.Optional[str] = None
credential: typing.Optional[str] = None
bucket: typing.Optional[str] = None

@pydantic.validator("bucket")
def check_bucket(cls, v):
if not v:
warnings.warn(
"The 'bucket' attribute will be mandatory starting from version 1.9",
FutureWarning,
stacklevel=2,
)
return v

def url(self, subpath) -> str:
if subpath.startswith("/"):
# In Azure, the path after the scheme starts with the container (bucket), so it should not begin with "/".
subpath = subpath[1:]
return f"az://{subpath}"
if self.bucket:
return f"az://{self.bucket}/{subpath}"
else:
return f"az://{subpath}"

def secrets(self) -> dict:
res = {}
76 changes: 41 additions & 35 deletions mlrun/datastore/targets.py
Expand Up @@ -656,6 +656,29 @@ def get_target_templated_path(self):
def _target_path_object(self):
"""return the actual/computed target path"""
is_single_file = hasattr(self, "is_single_file") and self.is_single_file()

if self._resource and self.path:
parsed_url = urlparse(self.path)
# When the URL consists only of a scheme and an endpoint, with no path,
# generate a default path for ds and redis targets.
# KafkaTarget is excluded even when it uses the ds scheme (there is no default path for KafkaTarget).
if (
not isinstance(self, KafkaTarget)
and parsed_url.scheme in ["ds", "redis", "rediss"]
and (not parsed_url.path or parsed_url.path == "/")
):
return TargetPathObject(
_get_target_path(
self,
self._resource,
self.run_id is not None,
netloc=parsed_url.netloc,
scheme=parsed_url.scheme,
),
self.run_id,
is_single_file,
)

return self.get_path() or (
TargetPathObject(
_get_target_path(self, self._resource, self.run_id is not None),
@@ -1411,39 +1434,6 @@ class RedisNoSqlTarget(NoSqlBaseTarget):
support_spark = True
writer_step_name = "RedisNoSqlTarget"

@property
def _target_path_object(self):
url = self.path or mlrun.mlconf.redis.url
if self._resource and url:
parsed_url = urlparse(url)
if not parsed_url.path or parsed_url.path == "/":
kind_prefix = (
"sets"
if self._resource.kind
== mlrun.common.schemas.ObjectKind.feature_set
else "vectors"
)
kind = self.kind
name = self._resource.metadata.name
project = (
self._resource.metadata.project or mlrun.mlconf.default_project
)
data_prefix = get_default_prefix_for_target(kind).format(
ds_profile_name=parsed_url.netloc,
authority=parsed_url.netloc,
project=project,
kind=kind,
name=name,
)
if url.startswith("rediss://"):
data_prefix = data_prefix.replace("redis://", "rediss://", 1)
if not self.run_id:
version = self._resource.metadata.tag or "latest"
name = f"{name}-{version}"
url = f"{data_prefix}/{kind_prefix}/{name}"
return TargetPathObject(url, self.run_id, False)
return super()._target_path_object

# Fetch server url from the RedisNoSqlTarget::__init__() 'path' parameter.
# If not set fetch it from 'mlrun.mlconf.redis.url' (MLRUN_REDIS__URL environment variable).
# Then look for username and password at REDIS_xxx secrets
@@ -2201,7 +2191,7 @@ def _raise_sqlalchemy_import_error(exc):
}


def _get_target_path(driver, resource, run_id_mode=False):
def _get_target_path(driver, resource, run_id_mode=False, netloc=None, scheme=""):
"""return the default target path given the resource and target kind"""
kind = driver.kind
suffix = driver.suffix
@@ -2218,11 +2208,27 @@ def _get_target_path(driver, resource, run_id_mode=False):
)
name = resource.metadata.name
project = resource.metadata.project or mlrun.mlconf.default_project
data_prefix = get_default_prefix_for_target(kind).format(

default_kind_name = kind
if scheme == "ds":
# "dsnosql" is not an actual target like Parquet or Redis; rather, it serves
# as a placeholder that can be used with any target kind
default_kind_name = "dsnosql"
if scheme == "redis" or scheme == "rediss":
default_kind_name = TargetTypes.redisnosql

netloc = netloc or ""
data_prefix = get_default_prefix_for_target(default_kind_name).format(
ds_profile_name=netloc,  # for a ds profile, netloc holds the profile name
authority=netloc,  # for redis, {authority} is replaced with netloc
project=project,
kind=kind,
name=name,
)

if scheme == "rediss":
data_prefix = data_prefix.replace("redis://", "rediss://", 1)

# todo: handle ver tag changes, may need to copy files?
if not run_id_mode:
version = resource.metadata.tag
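In effect, a target whose path names only a datastore profile (or a bare redis endpoint) now falls back to the default templated path. A rough sketch, assuming a hypothetical registered profile named `my_profile`; the new system test below exercises this end to end:

```python
from mlrun.datastore.targets import ParquetTarget

# With no path after the profile, the target resolves to a default path such as
#   ds://my_profile/projects/<project>/FeatureStore/<feature-set-name>/parquet
target = ParquetTarget(path="ds://my_profile")
```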
76 changes: 66 additions & 10 deletions tests/system/datastore/test_aws_s3.py
@@ -29,7 +29,7 @@
register_temporary_client_datastore_profile,
)
from mlrun.datastore.sources import ParquetSource
from mlrun.datastore.targets import ParquetTarget
from mlrun.datastore.targets import ParquetTarget, get_default_prefix_for_target
from tests.system.base import TestMLRunSystem

test_environment = TestMLRunSystem._get_env_from_file()
@@ -48,7 +48,6 @@
not test_environment.get("AWS_BUCKET_NAME"),
reason="AWS_BUCKET_NAME is not set",
)
@pytest.mark.parametrize("use_datastore_profile", [True, False])
class TestAwsS3(TestMLRunSystem):
project_name = "s3-system-test"

@@ -78,8 +77,14 @@ def setup_method(self, method):
"s3": self._make_target_names(
"s3://", self._bucket_name, object_dir, object_file
),
"ds": self._make_target_names(
"ds://s3ds_profile/",
"ds_with_bucket": self._make_target_names(
"ds://s3ds_profile_with_bucket",
"", # no bucket, since it is part of the ds profile
object_dir,
object_file,
),
"ds_no_bucket": self._make_target_names(
"ds://s3ds_profile_no_bucket/",
self._bucket_name,
object_dir,
object_file,
@@ -88,7 +93,15 @@

mlrun.get_or_create_project(self.project_name)
profile = DatastoreProfileS3(
name="s3ds_profile",
name="s3ds_profile_with_bucket",
access_key_id=self._access_key_id,
secret_key=self._secret_access_key,
bucket=self._bucket_name,
)
register_temporary_client_datastore_profile(profile)

profile = DatastoreProfileS3(
name="s3ds_profile_no_bucket",
access_key_id=self._access_key_id,
secret_key=self._secret_access_key,
)
@@ -105,21 +118,20 @@ def custom_teardown(self):
s3_fs.rm(file)
s3_fs.rm(full_path)

def test_ingest_with_parquet_source(self, use_datastore_profile):
@pytest.mark.parametrize("url_type", ["s3", "ds_with_bucket", "ds_no_bucket"])
def test_ingest_with_parquet_source(self, url_type):
# create source
s3_fs = fsspec.filesystem(
"s3", key=self._access_key_id, secret=self._secret_access_key
)
param = self.s3["ds"] if use_datastore_profile else self.s3["s3"]
param = self.s3[url_type]
print(f"Using URL {param['parquet_url']}\n")
data = {"Column1": [1, 2, 3], "Column2": ["A", "B", "C"]}
df = pd.DataFrame(data)
source_path = param["parquet_url"]
with tempfile.NamedTemporaryFile(mode="w+", delete=True) as temp_file:
df.to_parquet(temp_file.name)
path_only = source_path.replace("ds://s3ds_profile/", "").replace(
"s3://", ""
)
path_only = self.s3["s3"]["parquet_url"]
s3_fs.put_file(temp_file.name, path_only)
parquet_source = ParquetSource(name="test", path=source_path)

@@ -140,3 +152,47 @@ def test_ingest_with_parquet_source(self, use_datastore_profile):
assert_frame_equal(
df.sort_index(axis=1), result.sort_index(axis=1), check_like=True
)

def test_ingest_ds_default_target(self):
s3_fs = fsspec.filesystem(
"s3", key=self._access_key_id, secret=self._secret_access_key
)
param = self.s3["ds_with_bucket"]
print(f"Using URL {param['parquet_url']}\n")
data = {"Column1": [1, 2, 3], "Column2": ["A", "B", "C"]}
df = pd.DataFrame(data)
source_path = param["parquet_url"]
with tempfile.NamedTemporaryFile(mode="w+", delete=True) as temp_file:
df.to_parquet(temp_file.name)
path_only = self.s3["s3"]["parquet_url"]
s3_fs.put_file(temp_file.name, path_only)

parquet_source = ParquetSource(name="test", path=source_path)

targets = [ParquetTarget(path="ds://s3ds_profile_with_bucket")]
fset = fstore.FeatureSet(
name="test_fs",
entities=[fstore.Entity("Column1")],
)

fset.ingest(source=parquet_source, targets=targets)

expected_default_ds_data_prefix = get_default_prefix_for_target(
"dsnosql"
).format(
ds_profile_name="s3ds_profile_with_bucket",
project=fset.metadata.project,
kind=targets[0].kind,
name=fset.metadata.name,
)

assert fset.get_target_path().startswith(expected_default_ds_data_prefix)

result = ParquetSource(path=fset.get_target_path()).to_dataframe(
columns=("Column1", "Column2")
)
result.reset_index(inplace=True, drop=False)

assert_frame_equal(
df.sort_index(axis=1), result.sort_index(axis=1), check_like=True
)
6 changes: 4 additions & 2 deletions tests/system/datastore/test_azure.py
@@ -74,7 +74,7 @@ def teardown_class(cls):
def setup_before_each_test(self, use_datastore_profile):
self._object_dir = self.test_dir + "/" + f"target_directory_{uuid.uuid4()}"
self._bucket_path = (
f"ds://{self.profile_name}/{self._bucket_name}"
f"ds://{self.profile_name}"
if use_datastore_profile
else "az://" + self._bucket_name
)
@@ -87,7 +87,9 @@ def setup_before_each_test(self, use_datastore_profile):
logger.info(f"Object URL template: {self._target_url_template}")
if use_datastore_profile:
kwargs = {"connection_string": self.connection_string}
profile = DatastoreProfileAzureBlob(name=self.profile_name, **kwargs)
profile = DatastoreProfileAzureBlob(
name=self.profile_name, bucket=self._bucket_name, **kwargs
)
register_temporary_client_datastore_profile(profile)
os.environ.pop("AZURE_STORAGE_CONNECTION_STRING", None)
else:
