Skip to content

Commit

Permalink
add NamespacedTagSet api (#22036)
Browse files Browse the repository at this point in the history
## Summary

Largely clones `NamespacedMetadataSet` into a new `NamespacedTagSet` ABC which can be used to define a set of tags which will be logically set together in code, and which have a namespace prefix.

A bit simpler, in that all values must be strings.

## Test Plan

New little unit test suite.
  • Loading branch information
benpankow committed May 24, 2024
1 parent 71b0e8e commit eebaa94
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 36 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ Pipfile.lock
.mypy_cache/

tags
!python_modules/dagster/dagster/_core/definitions/tags

.pytest_cache
.DS_Store
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@

from .metadata_value import MetadataValue, TableColumnLineage, TableSchema

T_NamespacedMetadataSet = TypeVar("T_NamespacedMetadataSet", bound="NamespacedMetadataSet")


# Python types that have a MetadataValue types that directly wraps them
DIRECTLY_WRAPPED_METADATA_TYPES = {
str,
Expand All @@ -25,39 +22,20 @@
type(None),
}

T_NamespacedMetadataSet = TypeVar("T_NamespacedMetadataSet", bound="NamespacedMetadataSet")
T_NamespacedKVSet = TypeVar("T_NamespacedKVSet", bound="NamespacedKVSet")


def is_raw_metadata_type(t: Type) -> bool:
return issubclass(t, MetadataValue) or t in DIRECTLY_WRAPPED_METADATA_TYPES


class NamespacedMetadataSet(ABC, DagsterModel):
"""Extend this class to define a set of metadata fields in the same namespace.
class NamespacedKVSet(ABC, DagsterModel):
"""Base class for defining a set of key-value pairs in the same namespace.
Supports splatting to a dictionary that can be placed inside a metadata argument along with
other dictionary-structured metadata.
.. code-block:: python
Includes shared behavior between NamespacedMetadataSet and NamespacedTagSet.
my_metadata: NamespacedMetadataSet = ...
return MaterializeResult(metadata={**my_metadata, ...})
"""

def __init__(self, *args, **kwargs):
for field_name in model_fields(self).keys():
annotation_types = self._get_accepted_types_for_field(field_name)
invalid_annotation_types = {
annotation_type
for annotation_type in annotation_types
if not is_raw_metadata_type(annotation_type)
}
if invalid_annotation_types:
check.failed(
f"Type annotation for field '{field_name}' includes invalid metadata type(s): {invalid_annotation_types}"
)
super().__init__(*args, **kwargs)

@classmethod
@abstractmethod
def namespace(cls) -> str:
Expand All @@ -84,29 +62,27 @@ def __getitem__(self, key: str) -> Any:
return getattr(self, self._strip_namespace_from_key(key))

@classmethod
def extract(
cls: Type[T_NamespacedMetadataSet], metadata: Mapping[str, Any]
) -> T_NamespacedMetadataSet:
"""Extracts entries from the provided metadata dictionary into an instance of this class.
def extract(cls: Type[T_NamespacedKVSet], values: Mapping[str, Any]) -> T_NamespacedKVSet:
"""Extracts entries from the provided dictionary into an instance of this class.
Ignores any entries in the metadata dictionary whose keys don't correspond to fields on this
Ignores any entries in the dictionary whose keys don't correspond to fields on this
class.
In general, the following should always pass:
.. code-block:: python
class MyMetadataSet(NamedspacedMetadataSet):
class MyKVSet(NamespacedKVSet):
...
metadata: MyMetadataSet = ...
assert MyMetadataSet.extract(dict(metadata)) == metadata
metadata: MyKVSet = ...
assert MyKVSet.extract(dict(metadata)) == metadata
Args:
metadata (Mapping[str, Any]): A dictionary of metadata entries.
values (Mapping[str, Any]): A dictionary of entries to extract.
"""
kwargs = {}
for namespaced_key, value in metadata.items():
for namespaced_key, value in values.items():
splits = namespaced_key.split("/")
if len(splits) == 2:
namespace, key = splits
Expand All @@ -115,6 +91,39 @@ class MyMetadataSet(NamedspacedMetadataSet):

return cls(**kwargs)

@classmethod
@abstractmethod
def _extract_value(cls, field_name: str, value: Any) -> Any:
"""Based on type annotation, potentially coerce the value to the expected type."""
...


class NamespacedMetadataSet(NamespacedKVSet):
"""Extend this class to define a set of metadata fields in the same namespace.
Supports splatting to a dictionary that can be placed inside a metadata argument along with
other dictionary-structured metadata.
.. code-block:: python
my_metadata: NamespacedMetadataSet = ...
return MaterializeResult(metadata={**my_metadata, ...})
"""

def __init__(self, *args, **kwargs) -> None:
for field_name in model_fields(self).keys():
annotation_types = self._get_accepted_types_for_field(field_name)
invalid_annotation_types = {
annotation_type
for annotation_type in annotation_types
if not is_raw_metadata_type(annotation_type)
}
if invalid_annotation_types:
check.failed(
f"Type annotation for field '{field_name}' includes invalid metadata type(s): {invalid_annotation_types}"
)
super().__init__(*args, **kwargs)

@classmethod
def _extract_value(cls, field_name: str, value: Any) -> Any:
"""Based on type annotation, potentially coerce the metadata value to its inner value.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .tag_set import NamespacedTagSet as NamespacedTagSet
45 changes: 45 additions & 0 deletions python_modules/dagster/dagster/_core/definitions/tags/tag_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Any, cast

from typing_extensions import TypeVar, get_args

from dagster import _check as check
from dagster._core.definitions.metadata.metadata_set import NamespacedKVSet
from dagster._model.pydantic_compat_layer import model_fields
from dagster._utils.typing_api import is_closed_python_optional_type

T_NamespacedTagSet = TypeVar("T_NamespacedTagSet", bound="NamespacedTagSet")


class NamespacedTagSet(NamespacedKVSet):
"""Extend this class to define a set of tags in the same namespace.
Supports splatting to a dictionary that can be placed inside a tags argument along with
other tags.
.. code-block:: python
my_tags: NamespacedTagsSet = ...
@asset(
tags={**my_tags}
)
def my_asset():
pass
"""

def __init__(self, *args, **kwargs) -> None:
for field_name, field in model_fields(self).items():
annotation_type = field.annotation

is_optional_str = is_closed_python_optional_type(annotation_type) and str in get_args(
annotation_type
)
if not (is_optional_str or annotation_type is str):
check.failed(
f"Type annotation for field '{field_name}' is not str or Optional[str]"
)
super().__init__(*args, **kwargs)

@classmethod
def _extract_value(cls, field_name: str, value: Any) -> str:
"""Since all tag values are strings, we don't need to do any type coercion."""
return cast(str, value)
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Optional

import pydantic
import pytest
from dagster._check import CheckError
from dagster._core.definitions.tags import NamespacedTagSet


def test_invalid_tag_set() -> None:
class MyJunkTagSet(NamespacedTagSet):
junk: int

@classmethod
def namespace(cls) -> str:
return "dagster"

with pytest.raises(
CheckError, match="Type annotation for field 'junk' is not str or Optional\\[str\\]"
):
MyJunkTagSet(junk=1)


def test_basic_tag_set_validation() -> None:
class MyValidTagSet(NamespacedTagSet):
foo: Optional[str] = None
bar: Optional[str] = None

@classmethod
def namespace(cls) -> str:
return "dagster"

with pytest.raises(pydantic.ValidationError):
MyValidTagSet(foo="lorem", bar=lambda x: x) # type: ignore

with pytest.raises(pydantic.ValidationError):
MyValidTagSet(foo="lorem", baz="ipsum") # type: ignore


def test_basic_tag_set_functionality() -> None:
class MyValidTagSet(NamespacedTagSet):
foo: Optional[str] = None
bar: Optional[str] = None

@classmethod
def namespace(cls) -> str:
return "dagster"

tag_set = MyValidTagSet(foo="lorem", bar="ipsum")
assert tag_set.foo == "lorem"
assert tag_set.bar == "ipsum"

assert tag_set.extract((dict(tag_set))) == tag_set
assert tag_set.extract({}) == MyValidTagSet(foo=None, bar=None)

assert dict(tag_set) == {"dagster/foo": "lorem", "dagster/bar": "ipsum"}

0 comments on commit eebaa94

Please sign in to comment.