Skip to content

Commit

Permalink
Fix primitive decoder when evaluating Promise (#1432)
Browse files Browse the repository at this point in the history
Signed-off-by: Samhita Alla <aallasamhita@gmail.com>
  • Loading branch information
samhita-alla committed Feb 6, 2023
1 parent d006df6 commit a5c9970
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
22 changes: 10 additions & 12 deletions flytekit/core/promise.py
Expand Up @@ -69,7 +69,6 @@ def extract_value(
val_type: type,
flyte_literal_type: _type_models.LiteralType,
) -> _literal_models.Literal:

if isinstance(input_val, list):
lt = flyte_literal_type
python_type = val_type
Expand Down Expand Up @@ -142,17 +141,16 @@ def extract_value(


def get_primitive_val(prim: Primitive) -> Any:
if prim.integer:
return prim.integer
if prim.datetime:
return prim.datetime
if prim.boolean:
return prim.boolean
if prim.duration:
return prim.duration
if prim.string_value:
return prim.string_value
return prim.float_value
for value in [
prim.integer,
prim.float_value,
prim.string_value,
prim.boolean,
prim.datetime,
prim.duration,
]:
if value is not None:
return value


class ConjunctionOps(Enum):
Expand Down
16 changes: 16 additions & 0 deletions tests/flytekit/unit/core/test_conditions.py
Expand Up @@ -71,6 +71,22 @@ def multiplier_2(my_input: float) -> float:
multiplier_2(my_input=10.0)


def test_condition_else_int():
@workflow
def multiplier_3(my_input: int) -> float:
return (
conditional("fractions")
.if_((my_input >= 0) & (my_input < 1.0))
.then(double(n=my_input))
.elif_((my_input > 1.0) & (my_input < 10.0))
.then(square(n=my_input))
.else_()
.fail("The input must be between 0 and 10")
)

assert multiplier_3(my_input=0) == 0


def test_condition_sub_workflows():
@task
def sum_div_sub(a: int, b: int) -> typing.NamedTuple("Outputs", sum=int, div=int, sub=int):
Expand Down

0 comments on commit a5c9970

Please sign in to comment.