diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 8966564b28..f1203c7fc7 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1663,6 +1663,18 @@ def _register_default_type_transformers(): ) ) + TypeEngine.register( + SimpleTransformer( + "date", + _datetime.date, + _type_models.LiteralType(simple=_type_models.SimpleType.DATETIME), + lambda x: Literal( + scalar=Scalar(primitive=Primitive(datetime=_datetime.datetime.combine(x, _datetime.time.min))) + ), # convert datetime to date + lambda x: x.scalar.primitive.datetime.date(), # get date from datetime + ) + ) + TypeEngine.register( SimpleTransformer( "none", diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 7332d01631..411ffa85dc 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -90,6 +90,7 @@ def test_type_resolution(): assert type(TypeEngine.get_transformer(dict)) == DictTransformer assert type(TypeEngine.get_transformer(int)) == SimpleTransformer + assert type(TypeEngine.get_transformer(datetime.date)) == SimpleTransformer assert type(TypeEngine.get_transformer(os.PathLike)) == FlyteFilePathTransformer assert type(TypeEngine.get_transformer(FlytePickle)) == FlytePickleTransformer @@ -323,6 +324,7 @@ def recursive_assert(lit: LiteralType, expected: LiteralType, expected_depth: in recursive_assert(d.get_literal_type(typing.Dict[str, int]), LiteralType(simple=SimpleType.INTEGER)) recursive_assert(d.get_literal_type(typing.Dict[str, datetime.datetime]), LiteralType(simple=SimpleType.DATETIME)) recursive_assert(d.get_literal_type(typing.Dict[str, datetime.timedelta]), LiteralType(simple=SimpleType.DURATION)) + recursive_assert(d.get_literal_type(typing.Dict[str, datetime.date]), LiteralType(simple=SimpleType.DATETIME)) recursive_assert(d.get_literal_type(typing.Dict[str, dict]), LiteralType(simple=SimpleType.STRUCT)) recursive_assert( d.get_literal_type(typing.Dict[str, typing.Dict[str, str]]),