Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix primitive decoder when evaluating Promise #1432

Merged
merged 1 commit into from Feb 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 10 additions & 12 deletions flytekit/core/promise.py
Expand Up @@ -69,7 +69,6 @@ def extract_value(
val_type: type,
flyte_literal_type: _type_models.LiteralType,
) -> _literal_models.Literal:

if isinstance(input_val, list):
lt = flyte_literal_type
python_type = val_type
Expand Down Expand Up @@ -142,17 +141,16 @@ def extract_value(


def get_primitive_val(prim: Primitive) -> Any:
if prim.integer:
return prim.integer
if prim.datetime:
return prim.datetime
if prim.boolean:
return prim.boolean
if prim.duration:
return prim.duration
if prim.string_value:
return prim.string_value
return prim.float_value
for value in [
prim.integer,
prim.float_value,
prim.string_value,
prim.boolean,
prim.datetime,
prim.duration,
]:
if value is not None:
return value


class ConjunctionOps(Enum):
Expand Down
16 changes: 16 additions & 0 deletions tests/flytekit/unit/core/test_conditions.py
Expand Up @@ -71,6 +71,22 @@ def multiplier_2(my_input: float) -> float:
multiplier_2(my_input=10.0)


def test_condition_else_int():
@workflow
def multiplier_3(my_input: int) -> float:
return (
conditional("fractions")
.if_((my_input >= 0) & (my_input < 1.0))
.then(double(n=my_input))
.elif_((my_input > 1.0) & (my_input < 10.0))
.then(square(n=my_input))
.else_()
.fail("The input must be between 0 and 10")
)

assert multiplier_3(my_input=0) == 0


def test_condition_sub_workflows():
@task
def sum_div_sub(a: int, b: int) -> typing.NamedTuple("Outputs", sum=int, div=int, sub=int):
Expand Down