make it possible to log a dataset without loading anything
Signed-off-by: chenmoneygithub <chen.qian@databricks.com>
chenmoneygithub committed Feb 16, 2024
1 parent cd90acf commit 2a8d839
Showing 9 changed files with 119 additions and 113 deletions.
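In short: after this change, a `Dataset` can be constructed from a `DatasetSource` alone and logged as run input metadata, without downloading or materializing any data. A minimal usage sketch (mirroring the new tests below; the URL is illustrative):

```python
import mlflow
from mlflow.data.dataset import Dataset
from mlflow.data.http_dataset_source import HTTPDatasetSource

# No data is loaded; the digest falls back to the new default
# _compute_digest, which hashes the dataset's metadata.
source = HTTPDatasetSource(url="http://example.com/data.csv")
dataset = Dataset(source=source)

with mlflow.start_run():
    mlflow.log_input(dataset, context="train")
```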
44 changes: 24 additions & 20 deletions mlflow/data/dataset.py
@@ -1,3 +1,4 @@
import hashlib
import json
from abc import abstractmethod
from typing import Any, Dict, Optional
@@ -30,28 +31,36 @@ def __init__(

@abstractmethod
def _compute_digest(self) -> str:
    """Computes a digest for the dataset. Called if the user doesn't supply
    a digest when constructing the dataset.
    """Computes a digest for the dataset.

    Called if the user doesn't supply a digest when constructing the dataset. Users can override
    this method in subclasses to provide custom digest computation logic.

    Returns:
        A string digest for the dataset. We recommend a maximum digest length
        of 10 characters with an ideal length of 8 characters.
    """
    config = {
        "name": self.name,
        "source": self.source.to_json(),
        "source_type": self.source._get_source_type(),
    }
    return hashlib.md5(json.dumps(config).encode("utf-8")).hexdigest()[:8]

Check failure on line 49 in mlflow/data/dataset.py (GitHub Actions / lint): Probable use of insecure hash functions in `hashlib`: `md5`. See https://docs.astral.sh/ruff/rules/S324 for how to fix this error.
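The flagged call is not security-sensitive; one possible fix, as a sketch (the `usedforsecurity` keyword exists on Python 3.9+ and marks the hash as non-cryptographic; whether the project adopts this exact fix is not shown in this commit):

```python
return hashlib.md5(
    json.dumps(config).encode("utf-8"), usedforsecurity=False
).hexdigest()[:8]
```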

@abstractmethod
def _to_dict(self, base_dict: Dict[str, str]) -> Dict[str, str]:
    """
    Args:
        base_dict: A string dictionary of base information about the
            dataset, including: name, digest, source, and source
            type.
def to_dict(self) -> Dict[str, str]:
    """Create config dictionary for the dataset.

    Returns:
        A string dictionary containing the following fields: name,
        digest, source, source type, schema (optional), profile
        (optional).

    Subclasses should override this method to provide additional fields in the config dict,
    e.g., schema, profile, etc.
    """
    return {
        "name": self.name,
        "digest": self.digest,
        "source": self.source.to_json(),
        "source_type": self.source._get_source_type(),
    }

def to_json(self) -> str:
    """
@@ -61,13 +70,8 @@ def to_json(self) -> str:
    Returns:
        A JSON string representation of the :py:class:`Dataset <mlflow.data.dataset.Dataset>`.
    """
    base_dict = {
        "name": self.name,
        "digest": self.digest,
        "source": self._source.to_json(),
        "source_type": self._source._get_source_type(),
    }
    return json.dumps(self._to_dict(base_dict))

    return json.dumps(self.to_dict())

@property
def name(self) -> str:
@@ -115,7 +119,7 @@ def schema(self) -> Optional[Any]:
def _to_mlflow_entity(self) -> DatasetEntity:
    """
    Returns:
        A DatasetEntity instance representing the dataset.
        A `mlflow.entities.Dataset` instance representing the dataset.
    """
    dataset_json = json.loads(self.to_json())
    return DatasetEntity(
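Each of the five concrete dataset classes below migrates from the removed `_to_dict(base_dict)` hook to overriding `to_dict` directly. Distilled, the new pattern looks like this (a sketch; `MyDataset` and its `profile` payload are hypothetical):

```python
import json
from typing import Dict

from mlflow.data.dataset import Dataset


class MyDataset(Dataset):
    def to_dict(self) -> Dict[str, str]:
        # Start from the base fields (name, digest, source, source_type),
        # then layer on subclass-specific entries.
        config = super().to_dict()
        config.update({"profile": json.dumps({"num_rows": 100})})
        return config
```

This inverts the old flow: the base class no longer builds a `base_dict` to pass down; subclasses pull it up via `super().to_dict()` and extend it.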
30 changes: 11 additions & 19 deletions mlflow/data/huggingface_dataset.py
@@ -69,25 +69,17 @@ def _compute_digest(self) -> str:
    )
    return compute_pandas_digest(df)

def _to_dict(self, base_dict: Dict[str, str]) -> Dict[str, str]:
    """
    Args:
        base_dict: A string dictionary of base information about the
            dataset, including: name, digest, source, and source
            type.

    Returns:
        A string dictionary containing the following fields: name,
        digest, source, source type, schema (optional), profile
        (optional).
    """
    return {
        **base_dict,
        "schema": json.dumps({"mlflow_colspec": self.schema.to_dict()})
        if self.schema
        else None,
        "profile": json.dumps(self.profile),
    }
def to_dict(self) -> Dict[str, str]:
    """Create config dictionary for the dataset."""
    schema = json.dumps({"mlflow_colspec": self.schema.to_dict()}) if self.schema else None
    config = super().to_dict()
    config.update(
        {
            "schema": schema,
            "profile": json.dumps(self.profile),
        }
    )
    return config

@property
def ds(self) -> "datasets.Dataset":
27 changes: 11 additions & 16 deletions mlflow/data/numpy_dataset.py
@@ -53,22 +53,17 @@ def _compute_digest(self) -> str:
"""
return compute_numpy_digest(self._features, self._targets)

def _to_dict(self, base_dict: Dict[str, str]) -> Dict[str, str]:
"""
Args:
base_dict: A string dictionary of base information about the
dataset, including: name, digest, source, and source type.
Returns:
A string dictionary containing the following fields: name,
digest, source, source type, schema (optional), profile
(optional).
"""
return {
**base_dict,
"schema": json.dumps(self.schema.to_dict()) if self.schema else None,
"profile": json.dumps(self.profile),
}
def to_dict(self) -> Dict[str, str]:
"""Create config dictionary for the dataset."""
schema = json.dumps({"mlflow_colspec": self.schema.to_dict()}) if self.schema else None
config = super().to_dict()
config.update(
{
"schema": schema,
"profile": json.dumps(self.profile),
}
)
return config

@property
def source(self) -> DatasetSource:
29 changes: 11 additions & 18 deletions mlflow/data/pandas_dataset.py
@@ -70,24 +70,17 @@ def _compute_digest(self) -> str:
"""
return compute_pandas_digest(self._df)

def _to_dict(self, base_dict: Dict[str, str]) -> Dict[str, str]:
"""
Args:
base_dict: A string dictionary of base information about the
dataset, including: name, digest, source, and source type.
Returns:
A string dictionary containing the following fields: name,
digest, source, source type, schema (optional), profile
(optional).
"""
return {
**base_dict,
"schema": json.dumps({"mlflow_colspec": self.schema.to_dict()})
if self.schema
else None,
"profile": json.dumps(self.profile),
}
def to_dict(self) -> Dict[str, str]:
"""Create config dictionary for the dataset."""
schema = json.dumps({"mlflow_colspec": self.schema.to_dict()}) if self.schema else None
config = super().to_dict()
config.update(
{
"schema": schema,
"profile": json.dumps(self.profile),
}
)
return config

@property
def df(self) -> pd.DataFrame:
29 changes: 11 additions & 18 deletions mlflow/data/spark_dataset.py
@@ -57,24 +57,17 @@ def _compute_digest(self) -> str:
    # and deterministic than hashing DataFrame records
    return compute_spark_df_digest(self._df)

def _to_dict(self, base_dict: Dict[str, str]) -> Dict[str, str]:
    """
    Args:
        base_dict: A string dictionary of base information about the
            dataset, including: name, digest, source, and source type.

    Returns:
        A string dictionary containing the following fields: name,
        digest, source, source type, schema (optional), profile
        (optional).
    """
    return {
        **base_dict,
        "schema": json.dumps({"mlflow_colspec": self.schema.to_dict()})
        if self.schema
        else None,
        "profile": json.dumps(self.profile),
    }
def to_dict(self) -> Dict[str, str]:
    """Create config dictionary for the dataset."""
    schema = json.dumps({"mlflow_colspec": self.schema.to_dict()}) if self.schema else None
    config = super().to_dict()
    config.update(
        {
            "schema": schema,
            "profile": json.dumps(self.profile),
        }
    )
    return config

@property
def df(self):
27 changes: 11 additions & 16 deletions mlflow/data/tensorflow_dataset.py
@@ -88,22 +88,17 @@ def _compute_digest(self) -> str:
        else compute_tensor_digest(self._features, self._targets)
    )

def _to_dict(self, base_dict: Dict[str, str]) -> Dict[str, str]:
    """
    Args:
        base_dict: A string dictionary of base information about the
            dataset, including: name, digest, source, and source type.

    Returns:
        A string dictionary containing the following fields: name,
        digest, source, source type, schema (optional), profile
        (optional).
    """
    return {
        **base_dict,
        "schema": json.dumps(self.schema.to_dict()) if self.schema else None,
        "profile": json.dumps(self.profile),
    }
def to_dict(self) -> Dict[str, str]:
    """Create config dictionary for the dataset."""
    schema = json.dumps(self.schema.to_dict()) if self.schema else None
    config = super().to_dict()
    config.update(
        {
            "schema": schema,
            "profile": json.dumps(self.profile),
        }
    )
    return config

@property
def data(self):
15 changes: 15 additions & 0 deletions tests/data/test_dataset.py
@@ -1,5 +1,7 @@
import json

from mlflow.data.dataset import Dataset
from mlflow.data.http_dataset_source import HTTPDatasetSource
from mlflow.types.schema import Schema

from tests.resources.data.dataset import SampleDataset
@@ -40,3 +42,16 @@ def test_expected_name_is_used():

    dataset_with_name = SampleDataset(data_list=[1, 2, 3], source=source, name="testname")
    assert dataset_with_name.name == "testname"


def test_create_dataset_from_only_source():
    source_uri = "test:/my/test/uri"
    source = HTTPDatasetSource(url=source_uri)
    dataset = Dataset(source=source)

    json_str = dataset.to_json()
    parsed_json = json.loads(json_str)

    assert parsed_json["digest"] != None

Check failure on line 55 in tests/data/test_dataset.py (GitHub Actions / lint): Comparison to `None` should be `cond is not None`. See https://docs.astral.sh/ruff/rules/E711 for how to fix this error.
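For reference, the lint-preferred form of this assertion (the same fix applies to the `digest != None` check in tests/tracking/fluent/test_fluent.py below) is:

```python
assert parsed_json["digest"] is not None
```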
assert parsed_json["source"] == '{"url": "test:/my/test/uri"}'
assert parsed_json["source_type"] == "http"
15 changes: 9 additions & 6 deletions tests/resources/data/dataset.py
@@ -34,7 +34,7 @@ def _compute_digest(self) -> str:
        hash_md5.update(hash_part)
    return base64.b64encode(hash_md5.digest()).decode("ascii")

def _to_dict(self, base_dict: Dict[str, str]) -> Dict[str, str]:
def to_dict(self) -> Dict[str, str]:
    """
    Args:
        base_dict: A string dictionary of base information about the
@@ -46,11 +46,14 @@ def _to_dict(self, base_dict: Dict[str, str]) -> Dict[str, str]:
        digest, source, source type, schema (optional), profile
        (optional).
    """
    return {
        **base_dict,
        "schema": json.dumps({"mlflow_colspec": self.schema.to_dict()}),
        "profile": json.dumps(self.profile),
    }
    config = super().to_dict()
    config.update(
        {
            "schema": json.dumps({"mlflow_colspec": self.schema.to_dict()}),
            "profile": json.dumps(self.profile),
        }
    )
    return config

@property
def data_list(self) -> List[int]:
16 changes: 16 additions & 0 deletions tests/tracking/fluent/test_fluent.py
@@ -15,6 +15,7 @@
import mlflow.tracking.context.registry
import mlflow.tracking.fluent
from mlflow import MlflowClient
from mlflow.data.http_dataset_source import HTTPDatasetSource
from mlflow.data.pandas_dataset import from_pandas
from mlflow.entities import (
    LifecycleStage,
@@ -1273,6 +1274,21 @@ def test_log_input(tmp_path):
    assert dataset_inputs[0].tags[0].value == "train"


def test_log_input_metadata_only():
    source_uri = "test:/my/test/uri"
    source = HTTPDatasetSource(url=source_uri)
    dataset = mlflow.data.Dataset(source=source)

    with start_run() as run:
        mlflow.log_input(dataset, "train")
    dataset_inputs = MlflowClient().get_run(run.info.run_id).inputs.dataset_inputs
    assert len(dataset_inputs) == 1
    assert dataset_inputs[0].dataset.name == "dataset"
    assert dataset_inputs[0].dataset.digest != None

Check failure on line 1287 in tests/tracking/fluent/test_fluent.py (GitHub Actions / lint): Comparison to `None` should be `cond is not None`. See https://docs.astral.sh/ruff/rules/E711 for how to fix this error.
    assert dataset_inputs[0].dataset.source_type == "http"
    assert json.loads(dataset_inputs[0].dataset.source) == {"url": source_uri}


def test_get_parent_run():
    with mlflow.start_run() as parent:
        mlflow.log_param("a", 1)