TypeAnnotation (#759)
* feat:  support for annotated simple + list
* feat: addition of annotation att to
* feat: core  obj
* feat:  proto model
* feat: testing suite
* fix: more stable typing introspection
* fix: strip legacy
* fix: explicitly allow only one annotation
* feat: direct type transformer tests
* fix: there and back test
* fix: typing_extensions for get_origin
* fix: more semantic list generic unwrap
* fix: tmp requirements file with custom idl
* fix: nits
* feat: semantic error for unsupported complex literals
* fix: but
* feat: more tests ;)
* fix: imports
* fix: complex annotations
* fix: temp requirements files for unit tests
* fix: lint bug
* fix: tmp setup.py
* fix: use typing_extensions
* fix: typing_extensions for annotated
* fix: typing_ext
* fix: plugin tmp requirements
* fix: bump requirements
* fix: doc requirements
* fix: whitespace
* fix: bump flytekit
* fix: numpy version
* fix: lint
* fix: pandas version
* fix: bump requirements
* fix: test import
* fix: flake8 lint
* fix: merge
* fix: requirements
* fix: requirements
* fix: lint
* fix: papermill req
* fix: req

Signed-off-by: Kenny Workman <kennyworkman@sbcglobal.net>
kennyworkman committed Feb 8, 2022
1 parent a3a9684 commit d249cb2
Showing 28 changed files with 300 additions and 139 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -3,7 +3,7 @@
 *.pyt
 *.pytc
 *.egg-info
-.*.swp
+.*.sw*
 .DS_Store
 venv/
 .venv/
1 change: 0 additions & 1 deletion dev-requirements.txt
@@ -32,7 +32,6 @@ certifi==2021.10.8
     #   requests
 cffi==1.15.0
     # via
-    #   -c requirements.txt
     #   bcrypt
     #   cryptography
     #   pynacl
30 changes: 30 additions & 0 deletions flytekit/core/annotation.py
@@ -0,0 +1,30 @@
+from typing import Any, Dict
+
+
+class FlyteAnnotation:
+    """A core object to add arbitrary annotations to flyte types.
+
+    This metadata is ingested as a python dictionary and will be serialized
+    into fields on the flyteidl type literals. This data is not accessible at
+    runtime but rather can be retrieved from flyteadmin for custom presentation
+    of typed parameters.
+
+    Flytekit expects to receive a maximum of one `FlyteAnnotation` object
+    within each typehint.
+
+    For a task definition:
+
+    .. code-block:: python
+
+        @task
+        def x(a: typing.Annotated[int, FlyteAnnotation({"foo": {"bar": 1}})]):
+            return
+
+    """
+
+    def __init__(self, data: Dict[str, Any]):
+        self._data = data
+
+    @property
+    def data(self):
+        return self._data
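
To make the docstring's task example concrete, here is a minimal sketch of how metadata attached through `Annotated` can be read back from a function signature using only the standard library. It assumes Python 3.9+ (the commit messages mention `typing_extensions`, which is not used here), and `FlyteAnnotation` below is a local stand-in copied from the class above rather than an import from flytekit.

```python
from typing import Annotated, Any, Dict, get_args, get_origin, get_type_hints


class FlyteAnnotation:
    """Local stand-in mirroring the class added in this commit."""

    def __init__(self, data: Dict[str, Any]):
        self._data = data

    @property
    def data(self) -> Dict[str, Any]:
        return self._data


def x(a: Annotated[int, FlyteAnnotation({"foo": {"bar": 1}})]) -> None:
    return None


# include_extras=True keeps the Annotated wrapper instead of stripping it.
hint = get_type_hints(x, include_extras=True)["a"]
assert get_origin(hint) is Annotated

# The first argument is the underlying type; the rest is attached metadata.
base_type, *extras = get_args(hint)
```

Here `base_type` is the plain `int` and `extras` holds the single `FlyteAnnotation` instance, whose `data` property returns the original dictionary.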
74 changes: 67 additions & 7 deletions flytekit/core/type_engine.py
@@ -25,12 +25,14 @@
 from marshmallow_enum import EnumField, LoadDumpOptions
 from marshmallow_jsonschema import JSONSchema

+from flytekit.core.annotation import FlyteAnnotation
 from flytekit.core.context_manager import FlyteContext
 from flytekit.core.type_helpers import load_type_from_tag
 from flytekit.exceptions import user as user_exceptions
 from flytekit.loggers import logger
 from flytekit.models import interface as _interface_models
 from flytekit.models import types as _type_models
+from flytekit.models.annotation import TypeAnnotation as TypeAnnotationModel
 from flytekit.models.core import types as _core_types
 from flytekit.models.literals import (
     Blob,
@@ -235,6 +237,12 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
         Extracts the Literal type definition for a Dataclass and returns a type Struct.
         If possible also extracts the JSONSchema for the dataclass.
         """
+        if get_origin(t) is Annotated:
+            raise ValueError(
+                "Flytekit does not currently have support for FlyteAnnotations applied to Dataclass. "
+                f"Type {t} cannot be parsed."
+            )
+
         if not issubclass(t, DataClassJsonMixin):
             raise AssertionError(
                 f"Dataclass {t} should be decorated with @dataclass_json to be " f"serialized correctly"
@@ -551,6 +559,7 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
         TODO lets make this deterministic by using an ordered dict
         """
+
         # Step 1
         if get_origin(python_type) is Annotated:
             python_type = get_args(python_type)[0]
@@ -560,8 +569,14 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:

         # Step 2
         if hasattr(python_type, "__origin__"):
+            # Handling of annotated generics, eg:
+            # Annotated[typing.List[int], 'foo']
+            if get_origin(python_type) is Annotated:
+                return cls.get_transformer(get_args(python_type)[0])
+
             if python_type.__origin__ in cls._REGISTRY:
                 return cls._REGISTRY[python_type.__origin__]
+
             raise ValueError(f"Generic Type {python_type.__origin__} not supported currently in Flytekit.")

         # Step 3
@@ -590,7 +605,30 @@ def to_literal_type(cls, python_type: Type) -> LiteralType:
         Converts a python type into a flyte specific ``LiteralType``
         """
         transformer = cls.get_transformer(python_type)
-        return transformer.get_literal_type(python_type)
+        res = transformer.get_literal_type(python_type)
+        data = None
+        if get_origin(python_type) is Annotated:
+            for x in get_args(python_type)[1:]:
+                if not isinstance(x, FlyteAnnotation):
+                    continue
+                if data is not None:
+                    raise ValueError(
+                        f"More than one FlyteAnnotation used within {python_type} typehint. Flytekit requires a max of one."
+                    )
+                data = x.data
+        if data is not None:
+            idl_type_annotation = TypeAnnotationModel(annotations=data)
+            return LiteralType(
+                simple=res.simple,
+                schema=res.schema,
+                collection_type=res.collection_type,
+                map_value_type=res.map_value_type,
+                blob=res.blob,
+                enum_type=res.enum_type,
+                metadata=res.metadata,
+                annotation=idl_type_annotation,
+            )
+        return res

     @classmethod
     def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type, expected: LiteralType) -> Literal:
@@ -718,9 +756,16 @@ def get_sub_type(t: Type[T]) -> Type[T]:
         """
         Return the generic Type T of the List
         """
-        if hasattr(t, "__origin__") and t.__origin__ is list:  # type: ignore
-            if hasattr(t, "__args__"):
-                return t.__args__[0]  # type: ignore
+
+        if hasattr(t, "__origin__"):
+            # Handle annotation on list generic, eg:
+            # Annotated[typing.List[int], 'foo']
+            if get_origin(t) is Annotated:
+                return ListTransformer.get_sub_type(get_args(t)[0])
+
+            if t.__origin__ is list and hasattr(t, "__args__"):
+                return t.__args__[0]
+
         raise ValueError("Only generic univariate typing.List[T] type is supported.")

     def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:
@@ -763,9 +808,17 @@ def get_dict_types(t: Optional[Type[dict]]) -> typing.Tuple[Optional[type], Opti
         """
         Return the generic Type T of the Dict
         """
-        if hasattr(t, "__origin__") and t.__origin__ is dict:  # type: ignore
-            if hasattr(t, "__args__"):
-                return t.__args__  # type: ignore
+        _origin = get_origin(t)
+        _args = get_args(t)
+        if _origin is not None:
+            if _origin is Annotated:
+                raise ValueError(
+                    f"Flytekit does not currently have support \
+                        for FlyteAnnotations applied to dicts. {t} cannot be \
+                        parsed."
+                )
+            if _origin is dict and _args is not None:
+                return _args
         return None, None

     @staticmethod
@@ -913,6 +966,13 @@ def __init__(self):
         super().__init__(name="DefaultEnumTransformer", t=enum.Enum)

     def get_literal_type(self, t: Type[T]) -> LiteralType:
+        if get_origin(t) is Annotated:
+            raise ValueError(
+                f"Flytekit does not currently have support \
+                    for FlyteAnnotations applied to enums. {t} cannot be \
+                    parsed."
+            )
+
         values = [v.value for v in t]  # type: ignore
         if not isinstance(values[0], str):
             raise AssertionError("Only EnumTypes with value of string are supported")
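
The two ideas the `type_engine.py` changes rely on, scanning `Annotated` metadata for at most one `FlyteAnnotation` and unwrapping `Annotated` before inspecting a list generic, can be sketched in isolation as follows. This is an illustrative standalone version assuming Python 3.9+, not the flytekit implementation: `FlyteAnnotation` is a local stand-in, and the helper names `extract_annotation_data` and `list_sub_type` are hypothetical.

```python
import typing
from typing import Annotated, Any, Dict, Optional, get_args, get_origin


class FlyteAnnotation:
    """Local stand-in for flytekit.core.annotation.FlyteAnnotation."""

    def __init__(self, data: Dict[str, Any]):
        self.data = data


def extract_annotation_data(hint: Any) -> Optional[Dict[str, Any]]:
    """Return the data of the single FlyteAnnotation on a hint, if any."""
    if get_origin(hint) is not Annotated:
        return None
    data = None
    # get_args(...)[0] is the underlying type; metadata starts at index 1.
    for extra in get_args(hint)[1:]:
        if not isinstance(extra, FlyteAnnotation):
            continue  # other metadata, such as plain strings, is ignored
        if data is not None:
            raise ValueError(f"More than one FlyteAnnotation used within {hint}.")
        data = extra.data
    return data


def list_sub_type(hint: Any) -> Any:
    """Return T for List[T], looking through an Annotated wrapper."""
    if get_origin(hint) is Annotated:
        return list_sub_type(get_args(hint)[0])
    if get_origin(hint) is list and get_args(hint):
        return get_args(hint)[0]
    raise ValueError("Only generic univariate typing.List[T] type is supported.")


single = extract_annotation_data(Annotated[int, "note", FlyteAnnotation({"a": 1})])
inner = list_sub_type(Annotated[typing.List[int], "foo"])
```

Non-`FlyteAnnotation` metadata is skipped rather than rejected, which matches the `continue` branch in the hunk for `to_literal_type`.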
43 changes: 43 additions & 0 deletions flytekit/models/annotation.py
@@ -0,0 +1,43 @@
+import json as _json
+from typing import Any, Dict
+
+from flyteidl.core import types_pb2 as _types_pb2
+from google.protobuf import json_format as _json_format
+from google.protobuf import struct_pb2 as _struct
+
+
+class TypeAnnotation:
+    """Python class representation of the flyteidl TypeAnnotation message."""
+
+    def __init__(self, annotations: Dict[str, Any]):
+        self._annotations = annotations
+
+    @property
+    def annotations(self) -> Dict[str, Any]:
+        """
+        :rtype: dict[str, Any]
+        """
+        return self._annotations
+
+    def to_flyte_idl(self) -> _types_pb2.TypeAnnotation:
+        """
+        :rtype: flyteidl.core.types_pb2.TypeAnnotation
+        """
+
+        if self._annotations is not None:
+            annotations = _json_format.Parse(_json.dumps(self.annotations), _struct.Struct())
+        else:
+            annotations = None
+
+        return _types_pb2.TypeAnnotation(
+            annotations=annotations,
+        )
+
+    @classmethod
+    def from_flyte_idl(cls, proto):
+        """
+        :param flyteidl.core.types_pb2.TypeAnnotation proto:
+        :rtype: TypeAnnotation
+        """
+
+        return cls(annotations=_json_format.MessageToDict(proto.annotations))
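
Because `to_flyte_idl` passes the annotation dictionary through `json.dumps` before parsing it into a protobuf `Struct`, the data presumably has to be JSON-serializable. A minimal stdlib-only sketch of a pre-flight check follows; the helper name `check_json_serializable` is hypothetical and not part of flytekit.

```python
import json
from typing import Any, Dict


def check_json_serializable(data: Dict[str, Any]) -> bool:
    """Return True if json.dumps accepts the data, False otherwise."""
    try:
        json.dumps(data)
    except (TypeError, ValueError):
        # TypeError: unsupported value types; ValueError: circular references.
        return False
    return True


ok = check_json_serializable({"foo": {"bar": 1}})
bad = check_json_serializable({"foo": {1, 2, 3}})  # a set has no JSON form
```

Nested dictionaries, lists, strings, numbers, booleans and `None` pass the check; values such as sets or arbitrary objects do not.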
20 changes: 20 additions & 0 deletions flytekit/models/types.py
@@ -6,6 +6,7 @@
 from google.protobuf import struct_pb2 as _struct

 from flytekit.models import common as _common
+from flytekit.models.annotation import TypeAnnotation as TypeAnnotationModel
 from flytekit.models.core import types as _core_types


@@ -195,6 +196,7 @@ def __init__(
         enum_type=None,
         structured_dataset_type=None,
         metadata=None,
+        annotation=None,
     ):
         """
         This is a oneof message, only one of the kwargs may be set, representing one of the Flyte types.
@@ -208,6 +210,8 @@ def __init__(
         :param flytekit.models.core.types.EnumType enum_type: For enum objects, describes an enum
         :param flytekit.models.core.types.StructuredDatasetType structured_dataset_type: structured dataset
         :param dict[Text, T] metadata: Additional data describing the type
+        :param flytekit.models.annotation.TypeAnnotation annotation: Additional data
+            describing the type _intended to be saturated by the client_
         """
         self._simple = simple
         self._schema = schema
@@ -217,6 +221,7 @@ def __init__(
         self._enum_type = enum_type
         self._structured_dataset_type = structured_dataset_type
         self._metadata = metadata
+        self._annotation = annotation

     @property
     def simple(self) -> SimpleType:
@@ -259,18 +264,31 @@ def metadata(self):
         """
         return self._metadata

+    @property
+    def annotation(self) -> TypeAnnotationModel:
+        """
+        :rtype: flytekit.models.annotation.TypeAnnotation
+        """
+        return self._annotation
+
     @metadata.setter
     def metadata(self, value):
         self._metadata = value

+    @annotation.setter
+    def annotation(self, value):
+        self._annotation = value
+
     def to_flyte_idl(self):
         """
         :rtype: flyteidl.core.types_pb2.LiteralType
         """
+
         if self.metadata is not None:
             metadata = _json_format.Parse(_json.dumps(self.metadata), _struct.Struct())
         else:
             metadata = None
+
         t = _types_pb2.LiteralType(
             simple=self.simple if self.simple is not None else None,
             schema=self.schema.to_flyte_idl() if self.schema is not None else None,
@@ -282,6 +300,7 @@ def to_flyte_idl(self):
             if self.structured_dataset_type
             else None,
             metadata=metadata,
+            annotation=self.annotation.to_flyte_idl() if self.annotation else None,
         )
         return t

@@ -308,6 +327,7 @@ def from_flyte_idl(cls, proto):
             if proto.HasField("structured_dataset_type")
             else None,
             metadata=_json_format.MessageToDict(proto.metadata) or None,
+            annotation=TypeAnnotationModel.from_flyte_idl(proto.annotation) if proto.HasField("annotation") else None,
         )


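
The `annotation` accessor on `LiteralType` pairs a property getter with a setter over a private attribute. A minimal sketch of that pattern, using a hypothetical `Holder` class that is not part of flytekit: the setter assigns to the underscore-prefixed backing attribute, because assigning to the property name inside its own setter would invoke the setter again and recurse.

```python
from typing import Any, Optional


class Holder:
    """Hypothetical minimal class illustrating a property with a setter."""

    def __init__(self, annotation: Optional[Any] = None):
        self._annotation = annotation

    @property
    def annotation(self) -> Optional[Any]:
        return self._annotation

    @annotation.setter
    def annotation(self, value: Any) -> None:
        # Write to the private backing attribute; writing to
        # self.annotation here would call this setter again.
        self._annotation = value


holder = Holder()
holder.annotation = {"foo": "bar"}
```

After the assignment, reading `holder.annotation` goes through the getter and returns the stored dictionary.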
8 changes: 0 additions & 8 deletions plugins/flytekit-aws-athena/requirements.txt
@@ -12,8 +12,6 @@ binaryornot==0.4.4
     # via cookiecutter
 certifi==2021.10.8
     # via requests
-cffi==1.15.0
-    # via cryptography
 chardet==4.0.0
     # via binaryornot
 charset-normalizer==2.0.11
@@ -54,10 +52,6 @@ idna==3.3
     # via requests
 importlib-metadata==4.10.1
     # via keyring
-jeepney==0.7.1
-    # via
-    #   keyring
-    #   secretstorage
 jinja2==3.0.3
     # via
     #   cookiecutter
@@ -126,8 +120,6 @@ responses==0.18.0
     # via flytekit
 retry==0.9.2
     # via flytekit
-secretstorage==3.3.1
-    # via keyring
 six==1.16.0
     # via
     #   cookiecutter
6 changes: 0 additions & 6 deletions plugins/flytekit-aws-sagemaker/requirements.txt
@@ -73,10 +73,6 @@ importlib-metadata==4.10.1
     # via keyring
 inotify-simple==1.2.1
     # via sagemaker-training
-jeepney==0.7.1
-    # via
-    #   keyring
-    #   secretstorage
 jinja2==3.0.3
     # via
     #   cookiecutter
@@ -167,8 +163,6 @@ sagemaker-training==3.9.2
     # via flytekitplugins-awssagemaker
 scipy==1.8.0
     # via sagemaker-training
-secretstorage==3.3.1
-    # via keyring
 six==1.16.0
     # via
     #   bcrypt
8 changes: 0 additions & 8 deletions plugins/flytekit-data-fsspec/requirements.txt
@@ -12,8 +12,6 @@ binaryornot==0.4.4
     # via cookiecutter
 certifi==2021.10.8
     # via requests
-cffi==1.15.0
-    # via cryptography
 chardet==4.0.0
     # via binaryornot
 charset-normalizer==2.0.11
@@ -56,10 +54,6 @@ idna==3.3
     # via requests
 importlib-metadata==4.10.1
     # via keyring
-jeepney==0.7.1
-    # via
-    #   keyring
-    #   secretstorage
 jinja2==3.0.3
     # via
     #   cookiecutter
@@ -128,8 +122,6 @@ responses==0.18.0
     # via flytekit
 retry==0.9.2
     # via flytekit
-secretstorage==3.3.1
-    # via keyring
 six==1.16.0
     # via
     #   cookiecutter
