Skip to content

Commit

Permalink
Fix invocation on ops that use generic dynamic outputs (#8133)
Browse files Browse the repository at this point in the history
* Add invocation tests for generic outputs

* Fix dynamic output testing
  • Loading branch information
dpeng817 committed May 31, 2022
1 parent dc922f9 commit a3cf6c9
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import inspect
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, List, Optional, Union, cast

import dagster._check as check
from dagster.core.errors import (
Expand All @@ -9,6 +9,7 @@
)

from .events import (
DEFAULT_OUTPUT,
AssetMaterialization,
AssetObservation,
DynamicOutput,
Expand Down Expand Up @@ -355,7 +356,33 @@ def _type_check_output(

op_label = context.describe_op()

if isinstance(output, (Output, DynamicOutput)):
if isinstance(output, list) and all([isinstance(inner, DynamicOutput) for inner in output]):
dagster_type = output_def.dagster_type
output_list = cast(List[DynamicOutput], output)
for dyn_output in output_list:
if (
not dyn_output.output_name == DEFAULT_OUTPUT
and dyn_output.output_name != output_def.name
):
raise DagsterInvariantViolationError(
f"Received dynamic output with name '{dyn_output.output_name}' that does not exist."
)
type_check = do_type_check(
context.for_type(dagster_type), dagster_type, dyn_output.value
)
if not type_check.success:
raise DagsterTypeCheckDidNotPass(
description=(
f'Type check failed for {op_label} dynamic output "{dyn_output.output_name}" with mapping key "{dyn_output.mapping_key}" - '
f'expected type "{dagster_type.display_name}". '
f"Description: {type_check.description}"
),
metadata_entries=type_check.metadata_entries,
dagster_type=dagster_type,
)
context.observe_output(output_def.name, dyn_output.mapping_key)
return output
elif isinstance(output, (Output, DynamicOutput)):
dagster_type = output_def.dagster_type
type_check = do_type_check(context.for_type(dagster_type), dagster_type, output.value)
if not type_check.success:
Expand All @@ -370,7 +397,7 @@ def _type_check_output(
)

context.observe_output(
output.output_name, output.mapping_key if isinstance(output, DynamicOutput) else None
output_def.name, output.mapping_key if isinstance(output, DynamicOutput) else None
)
return output
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,10 @@ def the_op() -> Output[str]:
assert result.success
assert result.output_for_node("the_op") == "foo"

result = the_op()
assert isinstance(result, Output)
assert result.value == "foo"

@op
def the_op_bad_type_match() -> Output[int]:
return Output("foo")
Expand All @@ -822,6 +826,13 @@ def the_op_bad_type_match() -> Output[int]:
):
execute_op_in_graph(the_op_bad_type_match)

with pytest.raises(
DagsterTypeCheckDidNotPass,
match='Type check failed for op "the_op_bad_type_match" output "result" - expected type '
'"Int". Description: Value "foo" of python type "str" must be a int.',
):
the_op_bad_type_match()


def test_output_generic_correct_inner_type():
@op
Expand All @@ -831,13 +842,19 @@ def the_op_not_using_output() -> Output[int]:
result = execute_op_in_graph(the_op_not_using_output)
assert result.success

assert the_op_not_using_output() == 42

@op
def the_op_annotation_not_using_output() -> int:
return Output(42)

result = execute_op_in_graph(the_op_annotation_not_using_output)
assert result.success

result = the_op_annotation_not_using_output()
assert isinstance(result, Output)
assert result.value == 42


def test_output_generic_type_mismatches():
@op
Expand All @@ -850,6 +867,12 @@ def the_op_annotation_type_mismatch() -> int:
):
execute_op_in_graph(the_op_annotation_type_mismatch)

with pytest.raises(
DagsterTypeCheckDidNotPass,
match='Type check failed for op "the_op_annotation_type_mismatch" output "result" - expected type "Int". Description: Value "foo" of python type "str" must be a int.',
):
the_op_annotation_type_mismatch()

@op
def the_op_output_annotation_type_mismatch() -> Output[int]:
return "foo"
Expand All @@ -860,6 +883,12 @@ def the_op_output_annotation_type_mismatch() -> Output[int]:
):
execute_op_in_graph(the_op_output_annotation_type_mismatch)

with pytest.raises(
DagsterTypeCheckDidNotPass,
match='Type check failed for op "the_op_output_annotation_type_mismatch" output "result" - expected type "Int". Description: Value "foo" of python type "str" must be a int.',
):
the_op_output_annotation_type_mismatch()


def test_generic_output_tuple_op():
@op(out={"out1": Out(), "out2": Out()})
Expand All @@ -869,6 +898,12 @@ def the_op() -> Tuple[Output[str], Output[int]]:
result = execute_op_in_graph(the_op)
assert result.success

result1, result2 = the_op()
assert isinstance(result1, Output)
assert result1.value == "foo"
assert isinstance(result2, Output)
assert result2.value == 5

@op(out={"out1": Out(), "out2": Out()})
def the_op_bad_type_match() -> Tuple[Output[str], Output[int]]:
return (Output("foo"), Output("foo"))
Expand All @@ -880,6 +915,12 @@ def the_op_bad_type_match() -> Tuple[Output[str], Output[int]]:
):
execute_op_in_graph(the_op_bad_type_match)

with pytest.raises(
DagsterTypeCheckDidNotPass,
match='Type check failed for op "the_op_bad_type_match" output "result" - expected type "Int". Description: Value "foo" of python type "str" must be a int.',
):
the_op_bad_type_match()


def test_generic_output_tuple_complex_types():
@op(out={"out1": Out(), "out2": Out()})
Expand All @@ -889,6 +930,13 @@ def the_op() -> Tuple[Output[List[str]], Output[Dict[str, str]]]:
result = execute_op_in_graph(the_op)
assert result.success

result1, result2 = the_op()
assert isinstance(result1, Output)
assert isinstance(result2, Output)

assert result1.value == ["foo"]
assert result2.value == {"foo": "bar"}


def test_generic_output_name_mismatch():
@op(out={"out1": Out(), "out2": Out()})
Expand All @@ -901,6 +949,12 @@ def the_op() -> Tuple[Output[int], Output[str]]:
):
execute_op_in_graph(the_op)

with pytest.raises(
DagsterTypeCheckDidNotPass,
match='Type check failed for op "the_op" output "out2" - expected type "Int". Description: Value "foo" of python type "str" must be a int.',
):
the_op()


def test_generic_dynamic_output():
@op
Expand All @@ -911,6 +965,12 @@ def basic() -> List[DynamicOutput[int]]:
assert result.success
assert result.output_for_node("basic") == {"1": 1, "2": 2}

result = basic()
assert len(result) == 2
out1, out2 = result
assert out1.value == 1
assert out2.value == 2


def test_generic_dynamic_output_type_mismatch():
@op
Expand All @@ -923,6 +983,12 @@ def basic() -> List[DynamicOutput[int]]:
):
execute_op_in_graph(basic)

with pytest.raises(
DagsterTypeCheckDidNotPass,
match='Type check failed for op "basic" dynamic output "result" with mapping key "2" - expected type "Int". Description: Value "2" of python type "str" must be a int.',
):
basic()


def test_generic_dynamic_output_mix_with_regular():
@op(out={"regular": Out(), "dynamic": DynamicOut()})
Expand All @@ -941,6 +1007,14 @@ def basic() -> Tuple[Output[int], List[DynamicOutput[str]]]:
assert result.output_for_node("basic", "regular") == 5
assert result.output_for_node("basic", "dynamic") == {"1": "foo", "2": "bar"}

non_dynamic, dynamic = basic()
assert isinstance(non_dynamic, Output)
assert non_dynamic.value == 5
assert isinstance(dynamic, list)
d_out1, d_out2 = dynamic
assert d_out1.value == "foo"
assert d_out2.value == "bar"


def test_generic_dynamic_output_mix_with_regular_type_mismatch():
@op(out={"regular": Out(), "dynamic": DynamicOut()})
Expand All @@ -959,6 +1033,12 @@ def basic() -> Tuple[Output[int], List[DynamicOutput[str]]]:
):
execute_op_in_graph(basic)

with pytest.raises(
DagsterTypeCheckDidNotPass,
match='Type check failed for op "basic" dynamic output "result" with mapping key "2" - expected type "String". Description: Value "5" of python type "int" must be a string.',
):
basic()


def test_generic_dynamic_output_name_not_provided():
@op
Expand All @@ -971,6 +1051,12 @@ def basic() -> List[DynamicOutput[int]]:
):
execute_op_in_graph(basic)

with pytest.raises(
DagsterInvariantViolationError,
match="Received dynamic output with name 'blah' that does not exist.",
):
basic()


def test_generic_dynamic_output_name_mismatch():
@op(out={"the_name": DynamicOut()})
Expand All @@ -983,6 +1069,12 @@ def basic() -> List[DynamicOutput[int]]:
):
execute_op_in_graph(basic)

with pytest.raises(
DagsterInvariantViolationError,
match="Received dynamic output with name 'bad_name' that does not exist.",
):
basic()


def test_generic_dynamic_output_bare_list():
@op
Expand All @@ -993,6 +1085,10 @@ def basic() -> List[DynamicOutput]:
assert result.success
assert result.output_for_node("basic") == {"1": 4}

result = basic()
assert isinstance(result, list)
assert result[0].value == 4


def test_generic_dynamic_output_bare():

Expand Down Expand Up @@ -1029,6 +1125,9 @@ def basic() -> List[DynamicOutput]:
):
result.output_for_node("basic")

result = basic()
assert isinstance(result, list)

# This behavior isn't exactly correct - we should be erroring when a
# required dynamic output yields no outputs.
# https://github.com/dagster-io/dagster/issues/5948#issuecomment-997037163
Expand All @@ -1039,6 +1138,9 @@ def basic_yield():
result = execute_op_in_graph(basic_yield)
assert result.success

# Ensure that invocation behavior matches
basic_yield()


def test_generic_dynamic_output_empty_with_type():
@op
Expand All @@ -1058,6 +1160,9 @@ def basic_yield():
result = execute_op_in_graph(basic_yield)
assert result.success

# Ensure that invocation behavior matches
basic()


def test_generic_dynamic_multiple_outputs_empty():
@op(out={"out1": Out(), "out2": DynamicOut()})
Expand All @@ -1072,3 +1177,7 @@ def basic() -> Tuple[Output, List[DynamicOutput]]:
match="No outputs found for output 'out2' from node 'basic'.",
):
result.output_for_node("basic", "out2")

out1, out2 = basic()
assert isinstance(out1, Output)
assert isinstance(out2, list)

0 comments on commit a3cf6c9

Please sign in to comment.