Skip to content

Commit

Permalink
#3026: add type datetime.date (#1786)
Browse files Browse the repository at this point in the history
Signed-off-by: troychiu <y.troychiu@gmail.com>
  • Loading branch information
troychiu authored and Fabio Grätz committed Aug 14, 2023
1 parent 2010be6 commit f520dc1
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
12 changes: 12 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1663,6 +1663,18 @@ def _register_default_type_transformers():
)
)

TypeEngine.register(
SimpleTransformer(
"date",
_datetime.date,
_type_models.LiteralType(simple=_type_models.SimpleType.DATETIME),
lambda x: Literal(
scalar=Scalar(primitive=Primitive(datetime=_datetime.datetime.combine(x, _datetime.time.min)))
), # convert datetime to date
lambda x: x.scalar.primitive.datetime.date(), # get date from datetime
)
)

TypeEngine.register(
SimpleTransformer(
"none",
Expand Down
2 changes: 2 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def test_type_resolution():
assert type(TypeEngine.get_transformer(dict)) == DictTransformer

assert type(TypeEngine.get_transformer(int)) == SimpleTransformer
assert type(TypeEngine.get_transformer(datetime.date)) == SimpleTransformer

assert type(TypeEngine.get_transformer(os.PathLike)) == FlyteFilePathTransformer
assert type(TypeEngine.get_transformer(FlytePickle)) == FlytePickleTransformer
Expand Down Expand Up @@ -323,6 +324,7 @@ def recursive_assert(lit: LiteralType, expected: LiteralType, expected_depth: in
recursive_assert(d.get_literal_type(typing.Dict[str, int]), LiteralType(simple=SimpleType.INTEGER))
recursive_assert(d.get_literal_type(typing.Dict[str, datetime.datetime]), LiteralType(simple=SimpleType.DATETIME))
recursive_assert(d.get_literal_type(typing.Dict[str, datetime.timedelta]), LiteralType(simple=SimpleType.DURATION))
recursive_assert(d.get_literal_type(typing.Dict[str, datetime.date]), LiteralType(simple=SimpleType.DATETIME))
recursive_assert(d.get_literal_type(typing.Dict[str, dict]), LiteralType(simple=SimpleType.STRUCT))
recursive_assert(
d.get_literal_type(typing.Dict[str, typing.Dict[str, str]]),
Expand Down

0 comments on commit f520dc1

Please sign in to comment.