Skip to content

Commit

Permalink
Improve error message for missing return (#2551)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <pingsutw@apache.org>
  • Loading branch information
pingsutw committed Jul 15, 2024
1 parent 9bd0a10 commit 7a6fe0a
Show file tree
Hide file tree
Showing 18 changed files with 142 additions and 79 deletions.
Empty file added flytekit/_ast/__init__.py
Empty file.
25 changes: 25 additions & 0 deletions flytekit/_ast/parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import ast
import inspect
import typing


def get_function_param_location(func: typing.Callable, param_name: str) -> (int, int):
"""
Get the line and column number of the parameter in the source code of the function definition.
"""
# Get source code of the function
source_lines, start_line = inspect.getsourcelines(func)
source_code = "".join(source_lines)

# Parse the source code into an AST
module = ast.parse(source_code)

# Traverse the AST to find the function definition
for node in ast.walk(module):
if isinstance(node, ast.FunctionDef) and node.name == func.__name__:
for i, arg in enumerate(node.args.args):
if arg.arg == param_name:
# Calculate the line and column number of the parameter
line_number = start_line + node.lineno - 1
column_offset = arg.col_offset
return line_number, column_offset
27 changes: 12 additions & 15 deletions flytekit/clis/sdk_in_container/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

from flytekit.core.constants import SOURCE_CODE
from flytekit.exceptions.base import FlyteException
from flytekit.exceptions.user import FlyteInvalidInputException
from flytekit.exceptions.user import FlyteCompilationException, FlyteInvalidInputException
from flytekit.exceptions.utils import annotate_exception_with_code
from flytekit.loggers import get_level_from_cli_verbosity, logger

project_option = click.Option(
Expand Down Expand Up @@ -130,12 +131,14 @@ def pretty_print_traceback(e: Exception, verbosity: int = 1):
else:
raise ValueError(f"Verbosity level must be between 0 and 2. Got {verbosity}")

if hasattr(e, SOURCE_CODE):
# TODO: Use other way to check if the background is light or dark
theme = "emacs" if "LIGHT_BACKGROUND" in os.environ else "monokai"
syntax = Syntax(getattr(e, SOURCE_CODE), "python", theme=theme, background_color="default")
panel = Panel(syntax, border_style="red", title=type(e).__name__, title_align="left")
console.print(panel, no_wrap=False)
if isinstance(e, FlyteCompilationException):
e = annotate_exception_with_code(e, e.fn, e.param_name)
if hasattr(e, SOURCE_CODE):
# TODO: Use other way to check if the background is light or dark
theme = "emacs" if "LIGHT_BACKGROUND" in os.environ else "monokai"
syntax = Syntax(getattr(e, SOURCE_CODE), "python", theme=theme, background_color="default")
panel = Panel(syntax, border_style="red", title=e._ERROR_CODE, title_align="left")
console.print(panel, no_wrap=False)


def pretty_print_exception(e: Exception, verbosity: int = 1):
Expand All @@ -161,20 +164,14 @@ def pretty_print_exception(e: Exception, verbosity: int = 1):
pretty_print_grpc_error(cause)
else:
pretty_print_traceback(e, verbosity)
else:
pretty_print_traceback(e, verbosity)
return

if isinstance(e, grpc.RpcError):
pretty_print_grpc_error(e)
return

if isinstance(e, AssertionError):
click.secho(f"Assertion Error: {e}", fg="red")
return

if isinstance(e, ValueError):
click.secho(f"Value Error: {e}", fg="red")
return

pretty_print_traceback(e, verbosity)


Expand Down
24 changes: 20 additions & 4 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,16 @@

from flytekit.core import context_manager
from flytekit.core.artifact import Artifact, ArtifactIDSpecification, ArtifactQuery
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.docstring import Docstring
from flytekit.core.sentinel import DYNAMIC_INPUT_BINDING
from flytekit.core.type_engine import TypeEngine, UnionTransformer
from flytekit.exceptions.user import FlyteValidationException
from flytekit.exceptions.utils import annotate_exception_with_code
from flytekit.core.utils import has_return_statement
from flytekit.exceptions.user import (
FlyteMissingReturnValueException,
FlyteMissingTypeException,
FlyteValidationException,
)
from flytekit.loggers import developer_logger, logger
from flytekit.models import interface as _interface_models
from flytekit.models.literals import Literal, Scalar, Void
Expand Down Expand Up @@ -375,15 +380,26 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc
signature = inspect.signature(fn)
return_annotation = type_hints.get("return", None)

ctx = FlyteContextManager.current_context()
# Only check if the task/workflow has a return statement at compile time locally.
if (
ctx.execution_state
and ctx.execution_state.mode is None
and return_annotation
and type(None) not in get_args(return_annotation)
and return_annotation is not type(None)
and has_return_statement(fn) is False
):
raise FlyteMissingReturnValueException(fn=fn)

outputs = extract_return_annotation(return_annotation)
for k, v in outputs.items():
outputs[k] = v # type: ignore
inputs: Dict[str, Tuple[Type, Any]] = OrderedDict()
for k, v in signature.parameters.items(): # type: ignore
annotation = type_hints.get(k, None)
if annotation is None:
err_msg = f"'{k}' has no type. Please add a type annotation to the input parameter."
raise annotate_exception_with_code(TypeError(err_msg), fn, k)
raise FlyteMissingTypeException(fn=fn, param_name=k)
default = v.default if v.default is not inspect.Parameter.empty else None
# Inputs with default values are currently ignored, we may want to look into that in the future
inputs[k] = (annotation, default) # type: ignore
Expand Down
12 changes: 12 additions & 0 deletions flytekit/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import datetime
import inspect
import os
import shutil
import tempfile
import time
import typing
from abc import ABC, abstractmethod
from functools import wraps
from hashlib import sha224 as _sha224
Expand Down Expand Up @@ -381,3 +383,13 @@ def get_extra_config(self):
Get the config of the decorator.
"""
pass


def has_return_statement(func: typing.Callable) -> bool:
source_lines = inspect.getsourcelines(func)[0]
for line in source_lines:
if "return" in line.strip():
return True
if "yield" in line.strip():
return True
return False
22 changes: 22 additions & 0 deletions flytekit/exceptions/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,25 @@ def __init__(self, request: typing.Any):

class FlytePromiseAttributeResolveException(FlyteAssertion):
_ERROR_CODE = "USER:PromiseAttributeResolveError"


class FlyteCompilationException(FlyteUserException):
_ERROR_CODE = "USER:CompileError"

def __init__(self, fn: typing.Callable, param_name: typing.Optional[str] = None):
self.fn = fn
self.param_name = param_name


class FlyteMissingTypeException(FlyteCompilationException):
_ERROR_CODE = "USER:MissingTypeError"

def __str__(self):
return f"'{self.param_name}' has no type. Please add a type annotation to the input parameter."


class FlyteMissingReturnValueException(FlyteCompilationException):
_ERROR_CODE = "USER:MissingReturnValueError"

def __str__(self):
return f"{self.fn.__name__} function must return a value. Please add a return statement at the end of the function."
34 changes: 9 additions & 25 deletions flytekit/exceptions/utils.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,28 @@
import ast
import inspect
import typing

from flytekit._ast.parser import get_function_param_location
from flytekit.core.constants import SOURCE_CODE
from flytekit.exceptions.user import FlyteUserException


def get_function_param_location(func: typing.Callable, param_name: str) -> (int, int):
"""
Get the line and column number of the parameter in the source code of the function definition.
"""
# Get source code of the function
source_lines, start_line = inspect.getsourcelines(func)
source_code = "".join(source_lines)

# Parse the source code into an AST
module = ast.parse(source_code)

# Traverse the AST to find the function definition
for node in ast.walk(module):
if isinstance(node, ast.FunctionDef) and node.name == func.__name__:
for i, arg in enumerate(node.args.args):
if arg.arg == param_name:
# Calculate the line and column number of the parameter
line_number = start_line + node.lineno - 1
column_offset = arg.col_offset
return line_number, column_offset


def get_source_code_from_fn(fn: typing.Callable, param_name: str) -> (str, int):
def get_source_code_from_fn(fn: typing.Callable, param_name: typing.Optional[str] = None) -> (str, int):
"""
Get the source code of the function and the column offset of the parameter defined in the input signature.
"""
lines, start_line = inspect.getsourcelines(fn)
if param_name is None:
return "".join(f"{start_line + i} {lines[i]}" for i in range(len(lines))), 0

target_line_no, column_offset = get_function_param_location(fn, param_name)
line_index = target_line_no - start_line
source_code = "".join(f"{start_line + i} {lines[i]}" for i in range(line_index + 1))
return source_code, column_offset


def annotate_exception_with_code(exception: Exception, fn: typing.Callable, param_name: str) -> Exception:
def annotate_exception_with_code(
exception: FlyteUserException, fn: typing.Callable, param_name: typing.Optional[str] = None
) -> FlyteUserException:
"""
Annotate the exception with the source code, and will be printed in the rich panel.
@param exception: The exception to be annotated.
Expand Down
10 changes: 6 additions & 4 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,14 +390,16 @@ def test_fetch_not_exist_launch_plan(register):


def test_execute_reference_task(register):
nt = typing.NamedTuple("OutputsBC", [("t1_int_output", int), ("c", str)])

@reference_task(
project=PROJECT,
domain=DOMAIN,
name="basic.basic_workflow.t1",
version=VERSION,
)
def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
...
def t1(a: int) -> nt:
return nt(t1_int_output=a + 2, c="world")

remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
execution = remote.execute(
Expand All @@ -424,7 +426,7 @@ def test_execute_reference_workflow(register):
version=VERSION,
)
def my_wf(a: int, b: str) -> (int, str):
...
return a + 2, b + "world"

remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
execution = remote.execute(
Expand All @@ -451,7 +453,7 @@ def test_execute_reference_launchplan(register):
version=VERSION,
)
def my_wf(a: int, b: str) -> (int, str):
...
return 3, "world"

remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
execution = remote.execute(
Expand Down
2 changes: 2 additions & 0 deletions tests/flytekit/unit/bin/test_python_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def test_dispatch_execute_user_error_non_recov(mock_write_to_file, mock_upload_d
def t1(a: int) -> str:
# Should be interpreted as a non-recoverable user error
raise ValueError(f"some exception {a}")
return "hello"

ctx = context_manager.FlyteContext.current_context()
with context_manager.FlyteContextManager.with_context(
Expand Down Expand Up @@ -242,6 +243,7 @@ def t1(a: int) -> str:
def my_subwf(a: int) -> typing.List[str]:
# This also tests the dynamic/compile path
raise user_exceptions.FlyteRecoverableException(f"recoverable {a}")
return ["1", "2"]

ctx = context_manager.FlyteContext.current_context()
with context_manager.FlyteContextManager.with_context(
Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/core/flyte_functools/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ def test_unwrapped_task():
error = completed_process.stderr
error_str = ""
for line in error.strip().split("\n"):
if line.startswith("TypeError"):
if line.startswith("FlyteMissingTypeException"):
error_str += line
assert error_str != ""
assert error_str.startswith("TypeError: 'args' has no type. Please add a type annotation to the input")
assert "'args' has no type. Please add a type annotation" in error_str


@pytest.mark.parametrize("script", ["nested_function.py", "nested_wrapped_function.py"])
Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def test_basic_option_a3():

@task
def t3(b_value: str) -> Annotated[pd.DataFrame, a3]:
...
return pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]})

entities = OrderedDict()
t3_s = get_serializable(entities, serialization_settings, t3)
Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/core/test_imperative.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def ref_t1(
dataframe: pd.DataFrame,
imputation_method: str = "median",
) -> pd.DataFrame:
...
return dataframe

@reference_task(
project="flytesnacks",
Expand All @@ -340,7 +340,7 @@ def ref_t2(
split_mask: int,
num_features: int,
) -> pd.DataFrame:
...
return dataframe

wb = ImperativeWorkflow(name="core.feature_engineering.workflow.fe_wf")
wb.add_workflow_input("sqlite_archive", FlyteFile[typing.TypeVar("sqlite")])
Expand Down
Loading

0 comments on commit 7a6fe0a

Please sign in to comment.