Skip to content

Commit

Permalink
Revert "Make Output generic (#7202)" (#7715)
Browse files Browse the repository at this point in the history
This reverts commit 671cea6.
  • Loading branch information
dpeng817 committed May 3, 2022
1 parent 439b4d5 commit 8566057
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 154 deletions.
10 changes: 2 additions & 8 deletions python_modules/dagster/dagster/core/definitions/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@
Any,
Callable,
Dict,
Generic,
List,
Mapping,
NamedTuple,
Optional,
Sequence,
TypeVar,
Union,
cast,
)
Expand Down Expand Up @@ -178,9 +176,6 @@ def __new__(cls, asset_key, partitions=None):
return super(AssetLineageInfo, cls).__new__(cls, asset_key=asset_key, partitions=partitions)


T = TypeVar("T")


class Output(
NamedTuple(
"_Output",
Expand All @@ -189,8 +184,7 @@ class Output(
("output_name", str),
("metadata_entries", List[Union[PartitionMetadataEntry, MetadataEntry]]),
],
),
Generic[T],
)
):
"""Event corresponding to one of a op's outputs.
Expand All @@ -216,7 +210,7 @@ class Output(

def __new__(
cls,
value: T,
value: Any,
output_name: Optional[str] = DEFAULT_OUTPUT,
metadata_entries: Optional[Sequence[Union[MetadataEntry, PartitionMetadataEntry]]] = None,
metadata: Optional[Dict[str, RawMetadataValue]] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def def_from_pointer(
# if its a function invoke it - otherwise we are pointing to a
# artifact in module scope, likely decorator output

if seven.get_arg_names(target):
if seven.get_args(target):
raise DagsterInvariantViolationError(
"Error invoking function at {target} with no arguments. "
"Reconstructable target must be callable with no arguments".format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def _coerce_solid_compute_fn_to_iterator(fn, output_defs, context, context_arg_p


def _validate_and_coerce_solid_result_to_iterator(result, context, output_defs):
from dagster.core.definitions.events import DEFAULT_OUTPUT

if isinstance(result, (AssetMaterialization, Materialization, ExpectationResult)):
raise DagsterInvariantViolationError(
Expand Down Expand Up @@ -109,32 +108,9 @@ def _validate_and_coerce_solid_result_to_iterator(result, context, output_defs):
f"returned a tuple with {len(result)} elements"
)

for position, (output_def, element) in enumerate(zip(output_defs, result)):
# If an output object was provided directly, ensure that it matches
# with expected order from provided output definitions.
if isinstance(element, Output):
# If a name was explicitly provided on the output object, and
# that name does not match the name expected at this position,
# then throw an error.
if (
not element.output_name == DEFAULT_OUTPUT
and not element.output_name == output_def.name
):
raise DagsterInvariantViolationError(
f"Bad state: Received a tuple of outputs. An output was "
f"explicitly named '{element.output_name}', which does "
"not match the output definition specified for "
f"position {position}: '{output_def.name}'."
)
yield Output(
output_name=output_def.name,
value=element.value,
metadata_entries=element.metadata_entries,
)
else:
# If an output object was not returned, then construct one from any metadata that has been logged within the op's body.
metadata = context.get_output_metadata(output_def.name)
yield Output(output_name=output_def.name, value=element, metadata=metadata)
for output_def, element in zip(output_defs, result):
metadata = context.get_output_metadata(output_def.name)
yield Output(output_name=output_def.name, value=element, metadata=metadata)
elif result is not None:
if not output_defs:
raise DagsterInvariantViolationError(
Expand Down
14 changes: 5 additions & 9 deletions python_modules/dagster/dagster/core/types/dagster_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dagster.builtins import BuiltinEnum
from dagster.config.config_type import Array, ConfigType
from dagster.config.config_type import Noneable as ConfigNoneable
from dagster.core.definitions.events import Output, TypeCheck
from dagster.core.definitions.events import TypeCheck
from dagster.core.definitions.metadata import MetadataEntry, RawMetadataValue, normalize_metadata
from dagster.core.errors import DagsterInvalidDefinitionError, DagsterInvariantViolationError
from dagster.serdes import whitelist_for_serdes
Expand Down Expand Up @@ -227,9 +227,9 @@ def get_inner_type_for_fan_in(self) -> "DagsterType":


def _validate_type_check_fn(fn: t.Callable, name: t.Optional[str]) -> bool:
from dagster.seven import get_arg_names
from dagster.seven import get_args

args = get_arg_names(fn)
args = get_args(fn)

# py2 doesn't filter out self
if len(args) >= 1 and args[0] == "self":
Expand Down Expand Up @@ -804,7 +804,6 @@ def resolve_dagster_type(dagster_type: object) -> DagsterType:
is_supported_runtime_python_builtin,
remap_python_builtin_for_runtime,
)
from dagster.seven.typing import get_args, get_origin
from dagster.utils.typing_api import is_typing_type

from .python_dict import Dict, PythonDict
Expand All @@ -822,13 +821,10 @@ def resolve_dagster_type(dagster_type: object) -> DagsterType:
"Do not pass runtime type classes. Got {}".format(dagster_type),
)

# First, check to see if we're using Dagster's generic output type to do the type catching.
if get_origin(dagster_type) == Output:
dagster_type = get_args(dagster_type)[0]

# Then, check to see if it is part of python's typing library
# First check to see if it is part of python's typing library
if is_typing_type(dagster_type):
dagster_type = transform_typing_type(dagster_type)

if isinstance(dagster_type, DagsterType):
return dagster_type

Expand Down
2 changes: 1 addition & 1 deletion python_modules/dagster/dagster/seven/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def is_ascii(str_):
time_fn = time.perf_counter


def get_arg_names(callable_):
def get_args(callable_):
return [
parameter.name
for parameter in inspect.signature(callable_).parameters.values()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# type: ignore[return-value]
import time
from typing import Dict, Generator, List, Tuple

Expand Down Expand Up @@ -796,105 +795,3 @@ def the_graph_provides_inputs():

result = the_graph_provides_inputs.execute_in_process()
assert result.success


def test_generic_output_op():
@op
def the_op() -> Output[str]:
return Output("foo")

assert the_op.output_def_named("result").dagster_type.key == "String"

result = execute_op_in_graph(the_op)
assert result.success
assert result.output_for_node("the_op") == "foo"

@op
def the_op_bad_type_match() -> Output[int]:
return Output("foo")

with pytest.raises(
DagsterTypeCheckDidNotPass,
match='Type check failed for step output "result" - expected type '
'"Int". Description: Value "foo" of python type "str" must be a int.',
):
execute_op_in_graph(the_op_bad_type_match)


def test_output_generic_correct_inner_type():
@op
def the_op_not_using_output() -> Output[int]:
return 42

result = execute_op_in_graph(the_op_not_using_output)
assert result.success

@op
def the_op_annotation_not_using_output() -> int:
return Output(42)

result = execute_op_in_graph(the_op_annotation_not_using_output)
assert result.success


def test_output_generic_type_mismatches():
@op
def the_op_annotation_type_mismatch() -> int:
return Output("foo")

with pytest.raises(
DagsterTypeCheckDidNotPass,
match='Type check failed for step output "result" - expected type "Int". Description: Value "foo" of python type "str" must be a int.',
):
execute_op_in_graph(the_op_annotation_type_mismatch)

@op
def the_op_output_annotation_type_mismatch() -> Output[int]:
return "foo"

with pytest.raises(
DagsterTypeCheckDidNotPass,
match='Type check failed for step output "result" - expected type "Int". Description: Value "foo" of python type "str" must be a int.',
):
execute_op_in_graph(the_op_output_annotation_type_mismatch)


def test_generic_output_tuple_op():
@op(out={"out1": Out(), "out2": Out()})
def the_op() -> Tuple[Output[str], Output[int]]:
return (Output("foo"), Output(5))

result = execute_op_in_graph(the_op)
assert result.success

@op(out={"out1": Out(), "out2": Out()})
def the_op_bad_type_match() -> Tuple[Output[str], Output[int]]:
return (Output("foo"), Output("foo"))

with pytest.raises(
DagsterTypeCheckDidNotPass,
match='Type check failed for step output "out2" - expected type "Int". '
'Description: Value "foo" of python type "str" must be a int.',
):
execute_op_in_graph(the_op_bad_type_match)


def test_generic_output_tuple_complex_types():
@op(out={"out1": Out(), "out2": Out()})
def the_op() -> Tuple[Output[List[str]], Output[Dict[str, str]]]:
return (Output(["foo"]), Output({"foo": "bar"}))

result = execute_op_in_graph(the_op)
assert result.success


def test_generic_output_name_mismatch():
@op(out={"out1": Out(), "out2": Out()})
def the_op() -> Tuple[Output[int], Output[str]]:
return (Output("foo", output_name="out2"), Output(42, output_name="out1"))

with pytest.raises(
DagsterInvariantViolationError,
match="Bad state: Received a tuple of outputs. An output was explicitly named 'out2', which does not match the output definition specified for position 0: 'out1'.",
):
execute_op_in_graph(the_op)
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ def test_tempdir():
assert not seven.temp_dir.get_system_temp_directory().startswith("/var")


def test_get_arg_names():
def test_get_args():
def foo(one, two=2, three=None): # pylint: disable=unused-argument
pass

assert len(seven.get_arg_names(foo)) == 3
assert "one" in seven.get_arg_names(foo)
assert "two" in seven.get_arg_names(foo)
assert "three" in seven.get_arg_names(foo)
assert len(seven.get_args(foo)) == 3
assert "one" in seven.get_args(foo)
assert "two" in seven.get_args(foo)
assert "three" in seven.get_args(foo)


def test_is_lambda():
Expand Down

0 comments on commit 8566057

Please sign in to comment.