Adjustments/condensing of dagster.check (#6751)
* improve type annotations
* rename matrix to two_dim_list
* condense/rename type/subclass check funcs
smackesey committed Feb 25, 2022
1 parent 7b3dde3 commit 1a3f5c1
Showing 15 changed files with 410 additions and 338 deletions.
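
For reference, the call-site migration this commit performs, assembled from the hunks below (parameter values here are illustrative; this is a sketch of the renames, not an exhaustive listing of the new check API):

# Before (pre-#6751)
check.type_param(python_type, "python_type")
check.subclass_param(error_cls, "error_cls", DagsterUserCodeExecutionError)
check.opt_subclass_param(serializer, "serializer", Serializer)
check.matrix_param([[1, 2], [2, 3]], "rows")
check.is_dict(data, value_type=(str, float, int, bool, type(None)), desc="...")

# After
check.class_param(python_type, "python_type")
check.class_param(error_cls, "error_cls", superclass=DagsterUserCodeExecutionError)
check.opt_class_param(serializer, "serializer", superclass=Serializer)
check.two_dim_list_param([[1, 2], [2, 3]], "rows")
check.dict_param(data, "data", value_type=(str, float, int, bool, type(None)), additional_message="...")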
627 changes: 345 additions & 282 deletions python_modules/dagster/dagster/check/__init__.py

Large diffs are not rendered by default.
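
Because the large diff above is not rendered, the implementation of the condensed helpers is not visible on this page. The following is a minimal sketch of what class_param / opt_class_param plausibly enforce, inferred only from the call sites and updated tests later in this commit; the internals and exact messages are assumptions, not the real dagster.check code (only the "must be a class" wording is confirmed by an updated test expectation below).

import inspect


class CheckError(Exception):
    """Stand-in for dagster's CheckError in this sketch."""


def class_param(obj, param_name, superclass=None):
    # The value must be a class; if `superclass` is given, it must also subclass it.
    if not inspect.isclass(obj):
        raise CheckError(f'Param "{param_name}" must be a class. Got {obj!r}.')
    if superclass is not None and not issubclass(obj, superclass):
        raise CheckError(
            f'Param "{param_name}" must be a subclass of {superclass.__name__}. Got {obj!r}.'
        )
    return obj


def opt_class_param(obj, param_name, default=None, superclass=None):
    # None is accepted; per the updated tests below, the third positional argument is the
    # default returned in that case (check.opt_class_param(None, "foo", Foo) is Foo).
    if obj is None:
        return default
    return class_param(obj, param_name, superclass=superclass)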

2 changes: 1 addition & 1 deletion python_modules/dagster/dagster/config/config_type.py
@@ -394,7 +394,7 @@ def __init__(
 ConfigIntInstance = Int()
 ConfigStringInstance = String()

-_CONFIG_MAP: Dict[check.Type, ConfigType] = {
+_CONFIG_MAP: Dict[check.TypeOrTupleOfTypes, ConfigType] = {
     BuiltinEnum.ANY: ConfigAnyInstance,
     BuiltinEnum.BOOL: ConfigBoolInstance,
     BuiltinEnum.FLOAT: ConfigFloatInstance,
8 changes: 4 additions & 4 deletions python_modules/dagster/dagster/core/definitions/events.py
@@ -292,10 +292,10 @@ def __new__(
         elif isinstance(asset_key, str):
             asset_key = AssetKey(parse_asset_key_string(asset_key))
         elif isinstance(asset_key, list):
-            check.is_list(asset_key, of_type=str)
+            check.list_param(asset_key, "asset_key", of_type=str)
             asset_key = AssetKey(asset_key)
         else:
-            check.is_tuple(asset_key, of_type=str)
+            check.tuple_param(asset_key, "asset_key", of_type=str)
             asset_key = AssetKey(asset_key)

         metadata = check.opt_dict_param(metadata, "metadata", key_type=str)
@@ -372,10 +372,10 @@ def __new__(
         elif isinstance(asset_key, str):
             asset_key = AssetKey(parse_asset_key_string(asset_key))
         elif isinstance(asset_key, list):
-            check.is_list(asset_key, of_type=str)
+            check.list_param(asset_key, "asset_key", of_type=str)
             asset_key = AssetKey(asset_key)
         else:
-            check.is_tuple(asset_key, of_type=str)
+            check.tuple_param(asset_key, "asset_key", of_type=str)
             asset_key = AssetKey(asset_key)

         if tags:
@@ -29,10 +29,11 @@ class TableRecord(NamedTuple("TableRecord", [("data", Dict[str, Union[str, int,
     """

     def __new__(cls, **data):
-        check.is_dict(
+        check.dict_param(
             data,
+            "data",
             value_type=(str, float, int, bool, type(None)),
-            desc="Record fields must be one of types: (str, float, int, bool)",
+            additional_message="Record fields must be one of types: (str, float, int, bool)",
         )
         return super(TableRecord, cls).__new__(cls, data=data)

2 changes: 1 addition & 1 deletion python_modules/dagster/dagster/core/errors.py
@@ -176,7 +176,7 @@ def user_code_error_boundary(error_cls, msg_fn, log_manager=None, **kwargs):
     """
     check.callable_param(msg_fn, "msg_fn")
-    check.subclass_param(error_cls, "error_cls", DagsterUserCodeExecutionError)
+    check.class_param(error_cls, "error_cls", superclass=DagsterUserCodeExecutionError)

     with raise_execution_interrupts():
         if log_manager:
@@ -35,7 +35,7 @@ def solid_execution_error_boundary(error_cls, msg_fn, step_context, **kwargs):
     from dagster.core.execution.context.system import StepExecutionContext

     check.callable_param(msg_fn, "msg_fn")
-    check.subclass_param(error_cls, "error_cls", DagsterUserCodeExecutionError)
+    check.class_param(error_cls, "error_cls", superclass=DagsterUserCodeExecutionError)
     check.inst_param(step_context, "step_context", StepExecutionContext)

     with raise_execution_interrupts():
2 changes: 1 addition & 1 deletion python_modules/dagster/dagster/core/instance/__init__.py
@@ -1316,7 +1316,7 @@ def report_engine_event(
         from dagster.core.events import EngineEventData, DagsterEvent, DagsterEventType
         from dagster.core.events.log import EventLogEntry

-        check.class_param(cls, "cls")
+        check.opt_class_param(cls, "cls")
         check.str_param(message, "message")
         check.opt_inst_param(pipeline_run, "pipeline_run", PipelineRun)
         check.opt_str_param(run_id, "run_id")
2 changes: 1 addition & 1 deletion python_modules/dagster/dagster/core/types/dagster_type.py
@@ -519,7 +519,7 @@ def __init__(
             typing_type = t.Union[python_type] # type: ignore

         else:
-            self.python_type = check.type_param(python_type, "python_type") # type: ignore
+            self.python_type = check.class_param(python_type, "python_type") # type: ignore
             self.type_str = cast(str, python_type.__name__)
             typing_type = self.python_type # type: ignore
         name = check.opt_str_param(name, "name", self.type_str)
2 changes: 1 addition & 1 deletion python_modules/dagster/dagster/core/types/decorator.py
@@ -60,7 +60,7 @@ def s3_path(self):
     """

     def _with_args(bare_cls):
-        check.type_param(bare_cls, "bare_cls")
+        check.class_param(bare_cls, "bare_cls")
         new_name = name if name else bare_cls.__name__

         make_python_type_usable_as_dagster_type(
2 changes: 1 addition & 1 deletion python_modules/dagster/dagster/serdes/serdes.py
@@ -158,7 +158,7 @@ def whitelist_for_serdes(
         check.class_param(__cls, "__cls")
         return _whitelist_for_serdes(whitelist_map=_WHITELIST_MAP)(__cls)
     else: # decorator passed params
-        check.opt_subclass_param(serializer, "serializer", Serializer)
+        check.opt_class_param(serializer, "serializer", superclass=Serializer)
         serializer = cast(Type[Serializer], serializer)
         return _whitelist_for_serdes(
             whitelist_map=_WHITELIST_MAP, serializer=serializer, storage_name=storage_name
2 changes: 1 addition & 1 deletion python_modules/dagster/dagster/utils/__init__.py
@@ -448,7 +448,7 @@ def __init__(
         require_object: Optional[bool] = True,
     ):
         self.generator = check.generator(generator)
-        self.object_cls: Type[GeneratedContext] = check.type_param(object_cls, "object_cls")
+        self.object_cls: Type[GeneratedContext] = check.class_param(object_cls, "object_cls")
         self.require_object = check.bool_param(require_object, "require_object")
         self.object: Optional[GeneratedContext] = None
         self.did_setup = False
@@ -928,26 +928,26 @@ def test_tuple_param():
     check.is_tuple((3, 4), of_shape=(int, int), of_type=int)


-def test_matrix_param():
-    assert check.matrix_param([[1, 2], [2, 3]], "something")
+def test_two_dim_list_param():
+    assert check.two_dim_list_param([[1, 2], [2, 3]], "something")

     with pytest.raises(CheckError):
-        assert check.matrix_param(None, "something")
+        assert check.two_dim_list_param(None, "something")

     with pytest.raises(CheckError):
-        assert check.matrix_param([1, 2, 4], "something")
+        assert check.two_dim_list_param([1, 2, 4], "something")

     with pytest.raises(CheckError):
-        assert check.matrix_param([], "something")
+        assert check.two_dim_list_param([], "something")

     with pytest.raises(CheckError):
-        assert check.matrix_param([[1, 2], 3], "soemthing")
+        assert check.two_dim_list_param([[1, 2], 3], "soemthing")

     with pytest.raises(CheckError):
-        assert check.matrix_param([[1, 2], [3.0, 4.1]], "something", of_type=int)
+        assert check.two_dim_list_param([[1, 2], [3.0, 4.1]], "something", of_type=int)

     with pytest.raises(CheckError):
-        assert check.matrix_param([[1, 2], [2, 3, 4]], "something")
+        assert check.two_dim_list_param([[1, 2], [2, 3, 4]], "something")


 def test_opt_tuple_param():
@@ -992,53 +992,51 @@ def test_opt_tuple_param():
     check.is_tuple((3, 4), of_shape=(int, int), of_type=int)


-def test_opt_type_param():
+def test_opt_class_param():
     class Foo:
         pass

-    assert check.opt_type_param(int, "foo")
-    assert check.opt_type_param(Foo, "foo")
+    assert check.opt_class_param(int, "foo")
+    assert check.opt_class_param(Foo, "foo")

-    assert check.opt_type_param(None, "foo") is None
-    assert check.opt_type_param(None, "foo", Foo) is Foo
+    assert check.opt_class_param(None, "foo") is None
+    assert check.opt_class_param(None, "foo", Foo) is Foo

     with pytest.raises(CheckError):
-        check.opt_type_param(check, "foo")
+        check.opt_class_param(check, "foo")

     with pytest.raises(CheckError):
-        check.opt_type_param(234, "foo")
+        check.opt_class_param(234, "foo")

     with pytest.raises(CheckError):
-        check.opt_type_param("bar", "foo")
+        check.opt_class_param("bar", "foo")

     with pytest.raises(CheckError):
-        check.opt_type_param(Foo(), "foo")
+        check.opt_class_param(Foo(), "foo")


-def test_type_param():
+def test_class_param():
     class Bar:
         pass

-    assert check.type_param(int, "foo")
-    assert check.type_param(Bar, "foo")
+    assert check.class_param(int, "foo")
+    assert check.class_param(Bar, "foo")

     with pytest.raises(CheckError):
-        check.type_param(None, "foo")
+        check.class_param(None, "foo")

     with pytest.raises(CheckError):
-        check.type_param(check, "foo")
+        check.class_param(check, "foo")

     with pytest.raises(CheckError):
-        check.type_param(234, "foo")
+        check.class_param(234, "foo")

     with pytest.raises(CheckError):
-        check.type_param("bar", "foo")
+        check.class_param("bar", "foo")

     with pytest.raises(CheckError):
-        check.type_param(Bar(), "foo")
+        check.class_param(Bar(), "foo")

-
-def test_subclass_param():
     class Super:
         pass

@@ -1048,22 +1046,22 @@ class Sub(Super):
     class Alone:
         pass

-    assert check.subclass_param(Sub, "foo", Super)
+    assert check.class_param(Sub, "foo", superclass=Super)

     with pytest.raises(CheckError):
-        assert check.subclass_param(Alone, "foo", Super)
+        assert check.class_param(Alone, "foo", superclass=Super)

     with pytest.raises(CheckError):
-        assert check.subclass_param("value", "foo", Super)
+        assert check.class_param("value", "foo", superclass=Super)

-    assert check.opt_subclass_param(Sub, "foo", Super)
-    assert check.opt_subclass_param(None, "foo", Super) is None
+    assert check.opt_class_param(Sub, "foo", superclass=Super)
+    assert check.opt_class_param(None, "foo", superclass=Super) is None

     with pytest.raises(CheckError):
-        assert check.opt_subclass_param(Alone, "foo", Super)
+        assert check.opt_class_param(Alone, "foo", superclass=Super)

     with pytest.raises(CheckError):
-        assert check.opt_subclass_param("value", "foo", Super)
+        assert check.opt_class_param("value", "foo", superclass=Super)


 @contextmanager
@@ -37,7 +37,7 @@ def basic_generator():
     with pytest.raises(CheckError, match="Not a generator"):
         EventGenerationManager(None, int)

-    with pytest.raises(CheckError, match="was supposed to be a type"):
+    with pytest.raises(CheckError, match="must be a class"):
         EventGenerationManager(basic_generator(), None)

     with pytest.raises(CheckError, match="Called `get_object` before `generate_setup_events`"):
@@ -187,7 +187,7 @@ def _make_airflow_dag(
     dag_description = check.opt_str_param(
         dag_description, "dag_description", _make_dag_description(job_name)
     )
-    check.subclass_param(operator, "operator", BaseOperator)
+    check.class_param(operator, "operator", superclass=BaseOperator)

     dag_kwargs = dict(
         {"default_args": DEFAULT_ARGS},
@@ -361,7 +361,7 @@ def make_airflow_dag_for_operator(
        (airflow.models.DAG, List[airflow.models.BaseOperator]): The generated Airflow DAG, and a
        list of its constituent tasks.
     """
-    check.subclass_param(operator, "operator", BaseOperator)
+    check.class_param(operator, "operator", superclass=BaseOperator)

     job_name = canonicalize_backcompat_args(
         new_val=job_name,
@@ -66,7 +66,9 @@ def execute_query(self, query, fetch_results=False, cursor_factory=None, error_c
         """
         check.str_param(query, "query")
         check.bool_param(fetch_results, "fetch_results")
-        check.opt_subclass_param(cursor_factory, "cursor_factory", psycopg2.extensions.cursor)
+        check.opt_class_param(
+            cursor_factory, "cursor_factory", superclass=psycopg2.extensions.cursor
+        )
         check.opt_callable_param(error_callback, "error_callback")

         with self._get_conn() as conn:
@@ -124,7 +126,9 @@ def execute_queries(
         """
         check.list_param(queries, "queries", of_type=str)
         check.bool_param(fetch_results, "fetch_results")
-        check.opt_subclass_param(cursor_factory, "cursor_factory", psycopg2.extensions.cursor)
+        check.opt_class_param(
+            cursor_factory, "cursor_factory", superclass=psycopg2.extensions.cursor
+        )
         check.opt_callable_param(error_callback, "error_callback")

         results = []
@@ -174,7 +178,9 @@ def _get_conn(self):

     @contextmanager
     def _get_cursor(self, conn, cursor_factory=None):
-        check.opt_subclass_param(cursor_factory, "cursor_factory", psycopg2.extensions.cursor)
+        check.opt_class_param(
+            cursor_factory, "cursor_factory", superclass=psycopg2.extensions.cursor
+        )

         # Could be none, in which case we should respect the connection default. Otherwise
         # explicitly set to true/false.
@@ -217,7 +223,9 @@ def execute_query(self, query, fetch_results=False, cursor_factory=None, error_c
         """
         check.str_param(query, "query")
         check.bool_param(fetch_results, "fetch_results")
-        check.opt_subclass_param(cursor_factory, "cursor_factory", psycopg2.extensions.cursor)
+        check.opt_class_param(
+            cursor_factory, "cursor_factory", superclass=psycopg2.extensions.cursor
+        )
         check.opt_callable_param(error_callback, "error_callback")

         self.log.info("Executing query '{query}'".format(query=query))
@@ -248,7 +256,9 @@ def execute_queries(
         """
         check.list_param(queries, "queries", of_type=str)
         check.bool_param(fetch_results, "fetch_results")
-        check.opt_subclass_param(cursor_factory, "cursor_factory", psycopg2.extensions.cursor)
+        check.opt_class_param(
+            cursor_factory, "cursor_factory", superclass=psycopg2.extensions.cursor
+        )
         check.opt_callable_param(error_callback, "error_callback")

         for query in queries:
