Skip to content

Commit

Permalink
unify tag/metadata set baseclass
Browse files Browse the repository at this point in the history
  • Loading branch information
benpankow committed May 22, 2024
1 parent 973f3ac commit f076dd7
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@

from .metadata_value import MetadataValue, TableColumnLineage, TableSchema

T_NamespacedMetadataSet = TypeVar("T_NamespacedMetadataSet", bound="NamespacedMetadataSet")


# Python types that have a MetadataValue type that directly wraps them
DIRECTLY_WRAPPED_METADATA_TYPES = {
str,
Expand All @@ -25,39 +22,20 @@
type(None),
}

T_NamespacedMetadataSet = TypeVar("T_NamespacedMetadataSet", bound="NamespacedMetadataSet")
T_NamespacedKVSet = TypeVar("T_NamespacedKVSet", bound="NamespacedKVSet")


def is_raw_metadata_type(t: Type) -> bool:
    """Return whether ``t`` is acceptable as a raw metadata field type.

    A type qualifies when it is a subclass of ``MetadataValue`` or is one of the
    entries of ``DIRECTLY_WRAPPED_METADATA_TYPES``.

    Args:
        t: The candidate annotation type.

    Returns:
        True if ``t`` qualifies, False otherwise.
    """
    try:
        if issubclass(t, MetadataValue):
            return True
    except TypeError:
        # issubclass() rejects anything that is not a class (for example a
        # parameterized generic annotation). Such an annotation is simply not a
        # MetadataValue subclass, so fall through to the membership test rather
        # than surfacing an opaque TypeError to the caller.
        pass
    return t in DIRECTLY_WRAPPED_METADATA_TYPES


class NamespacedMetadataSet(ABC, DagsterModel):
"""Extend this class to define a set of metadata fields in the same namespace.
class NamespacedKVSet(ABC, DagsterModel):
"""Base class for defining a set of key-value pairs in the same namespace.
Supports splatting to a dictionary that can be placed inside a metadata argument along with
other dictionary-structured metadata.
.. code-block:: python
Includes shared behavior between NamespacedMetadataSet and NamespacedTagSet.
my_metadata: NamespacedMetadataSet = ...
return MaterializeResult(metadata={**my_metadata, ...})
"""

def __init__(self, *args, **kwargs):
for field_name in model_fields(self).keys():
annotation_types = self._get_accepted_types_for_field(field_name)
invalid_annotation_types = {
annotation_type
for annotation_type in annotation_types
if not is_raw_metadata_type(annotation_type)
}
if invalid_annotation_types:
check.failed(
f"Type annotation for field '{field_name}' includes invalid metadata type(s): {invalid_annotation_types}"
)
super().__init__(*args, **kwargs)

@classmethod
@abstractmethod
def namespace(cls) -> str:
Expand All @@ -84,9 +62,7 @@ def __getitem__(self, key: str) -> Any:
return getattr(self, self._strip_namespace_from_key(key))

@classmethod
def extract(
cls: Type[T_NamespacedMetadataSet], metadata: Mapping[str, Any]
) -> T_NamespacedMetadataSet:
def extract(cls: Type[T_NamespacedKVSet], metadata: Mapping[str, Any]) -> T_NamespacedKVSet:
"""Extracts entries from the provided metadata dictionary into an instance of this class.
Ignores any entries in the metadata dictionary whose keys don't correspond to fields on this
Expand Down Expand Up @@ -115,6 +91,39 @@ class MyMetadataSet(NamedspacedMetadataSet):

return cls(**kwargs)

@classmethod
@abstractmethod
def _extract_value(cls, field_name: str, value: Any) -> Any:
    """Coerce ``value`` to the type expected for ``field_name``, if needed.

    Implementations decide, based on the field's type annotation, whether any
    conversion is required.
    """


class NamespacedMetadataSet(NamespacedKVSet):
"""Extend this class to define a set of metadata fields in the same namespace.
Supports splatting to a dictionary that can be placed inside a metadata argument along with
other dictionary-structured metadata.
.. code-block:: python
my_metadata: NamespacedMetadataSet = ...
return MaterializeResult(metadata={**my_metadata, ...})
"""

def __init__(self, *args, **kwargs) -> None:
    """Validate every field's type annotation, then initialize the model.

    For each model field, the accepted types of its annotation are collected and
    any that are not raw metadata types are reported through ``check.failed``.
    All positional and keyword arguments are forwarded to the parent
    initializer afterwards.
    """
    for field_name in model_fields(self).keys():
        accepted = self._get_accepted_types_for_field(field_name)
        # Keep only the types that fail the raw-metadata-type test.
        invalid_annotation_types = {
            candidate for candidate in accepted if not is_raw_metadata_type(candidate)
        }
        if not invalid_annotation_types:
            continue
        check.failed(
            f"Type annotation for field '{field_name}' includes invalid metadata type(s): {invalid_annotation_types}"
        )
    super().__init__(*args, **kwargs)

@classmethod
def _extract_value(cls, field_name: str, value: Any) -> Any:
"""Based on type annotation, potentially coerce the metadata value to its inner value.
Expand Down
87 changes: 14 additions & 73 deletions python_modules/dagster/dagster/_core/definitions/tags/tag_set.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import AbstractSet, Any, Mapping, Type
from typing import Any

from typing_extensions import TypeVar
from typing_extensions import TypeVar, get_args

from dagster import _check as check
from dagster._model import DagsterModel
from dagster._core.definitions.metadata.metadata_set import NamespacedKVSet
from dagster._model.pydantic_compat_layer import model_fields
from dagster._utils.typing_api import flatten_unions
from dagster._utils.typing_api import is_closed_python_optional_type

T_NamespacedTagSet = TypeVar("T_NamespacedTagSet", bound="NamespacedTagSet")


class NamespacedTagSet(ABC, DagsterModel):
class NamespacedTagSet(NamespacedKVSet):
"""Extend this class to define a set of tags fields in the same namespace.
Supports splatting to a dictionary that can be placed inside a tags argument along with
Expand All @@ -29,76 +27,19 @@ def my_asset():
"""

def __init__(self, *args, **kwargs) -> None:
for field_name in model_fields(self).keys():
annotation_types = self._get_accepted_types_for_field(field_name)
invalid_annotation_types = {
for field_name, field in model_fields(self).items():
annotation_type = field.annotation

is_optional_str = is_closed_python_optional_type(annotation_type) and str in get_args(
annotation_type
for annotation_type in annotation_types
if annotation_type not in (str, type(None))
}
if invalid_annotation_types:
)
if not (is_optional_str or annotation_type is str):
check.failed(
f"Type annotation for field '{field_name}' is not str or Optional[str]"
)
super().__init__(*args, **kwargs)

@classmethod
@abstractmethod
def namespace(cls) -> str:
raise NotImplementedError()

@classmethod
def _namespaced_key(cls, key: str) -> str:
return f"{cls.namespace()}/{key}"

@staticmethod
def _strip_namespace_from_key(key: str) -> str:
return key.split("/", 1)[1]

def keys(self) -> AbstractSet[str]:
return {
self._namespaced_key(key)
for key in model_fields(self).keys()
# getattr returns the pydantic property on the subclass
if getattr(self, key) is not None
}

def __getitem__(self, key: str) -> Any:
# getattr returns the pydantic property on the subclass
return getattr(self, self._strip_namespace_from_key(key))

@classmethod
def extract(cls: Type[T_NamespacedTagSet], tags: Mapping[str, str]) -> T_NamespacedTagSet:
"""Extracts entries from the provided tags dictionary into an instance of this class.
Ignores any entries in the tags dictionary whose keys don't correspond to fields on this
class.
In general, the following should always pass:
.. code-block:: python
class MyTagSet(NamespacedTagSet):
...
tags: MyTagSet = ...
assert MyTagSet.extract(dict(tags)) == tags
Args:
tags (Mapping[str, str]): A dictionary of tags.
"""
kwargs = {}
for namespaced_key, value in tags.items():
splits = namespaced_key.split("/")
if len(splits) == 2:
namespace, key = splits
if namespace == cls.namespace() and key in model_fields(cls):
kwargs[key] = value

return cls(**kwargs)

@classmethod
@lru_cache(maxsize=None) # this avoids wastefully recomputing this once per instance
def _get_accepted_types_for_field(cls, field_name: str) -> AbstractSet[Type]:
annotation = model_fields(cls)[field_name].annotation
return flatten_unions(annotation)
def _extract_value(cls, field_name: str, value: Any) -> Any:
    """Return ``value`` unchanged.

    Tag values are always strings, so no type coercion is needed.
    """
    return value

0 comments on commit f076dd7

Please sign in to comment.