
Commit

Fixes for asset graphql tests to make them easier to generalize and call with other graphql context fixtures (#8199)
gibsondan committed Jun 9, 2022
1 parent b6cf5d2 commit d421a1b
Showing 12 changed files with 127 additions and 92 deletions.
@@ -1,12 +1,15 @@
# pylint: disable=missing-graphene-docstring
import json
+import sys

import graphene
from graphene.types.generic import GenericScalar

import dagster._check as check
+from dagster.utils.error import serializable_error_info_from_exc_info

from ..implementation.fetch_runs import get_runs, get_runs_count
+from ..implementation.utils import UserFacingGraphQLError
from .errors import (
GrapheneInvalidPipelineRunsFilterError,
GraphenePythonError,
@@ -150,7 +153,9 @@ def parse_run_config_input(run_config, raise_on_error: bool):
return json.loads(run_config)
except json.JSONDecodeError:
if raise_on_error:
-raise
+raise UserFacingGraphQLError(
+GraphenePythonError(serializable_error_info_from_exc_info(sys.exc_info()))
+)
# Pass the config through as a string so that it will return a useful validation error
return run_config
return run_config
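For orientation, here is a hedged sketch of the behavior this hunk introduces: with raise_on_error=True, malformed JSON run config now surfaces as a UserFacingGraphQLError wrapping a GraphenePythonError instead of a bare JSONDecodeError. The module paths below are assumptions inferred from the relative imports above; they are not spelled out in this diff.

# Sketch only; import paths are assumed, not taken from this commit.
from dagster_graphql.implementation.utils import UserFacingGraphQLError
from dagster_graphql.schema.runs import parse_run_config_input

try:
    parse_run_config_input("{not valid json", raise_on_error=True)
except UserFacingGraphQLError as exc:
    # exc wraps a GraphenePythonError built from sys.exc_info(), so the GraphQL
    # layer can return a readable error rather than a raw JSONDecodeError.
    print(type(exc).__name__)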
3 changes: 2 additions & 1 deletion python_modules/dagster-graphql/dagster_graphql/test/utils.py
@@ -5,6 +5,7 @@

import dagster._check as check
from dagster.core.instance import DagsterInstance
+from dagster.core.test_utils import wait_for_runs_to_finish
from dagster.core.workspace import WorkspaceProcessContext
from dagster.core.workspace.load_target import PythonFileTarget

@@ -40,7 +41,7 @@ def execute_dagster_graphql_and_finish_runs(context, query, variables=None):

def execute_dagster_graphql_and_finish_runs(context, query, variables=None):
result = execute_dagster_graphql(context, query, variables)
-context.instance.run_launcher.join()
+wait_for_runs_to_finish(context.instance, timeout=30)
return result


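A hedged sketch of a call site for the helper above, written as a test in one of the generated suites; the job name and query constant are placeholders, and graphql_context is the fixture those suites provide. When the helper returns, wait_for_runs_to_finish has already blocked until the instance reported the launched runs finished (or the 30-second timeout elapsed).

# Placeholder test: RUN_LAUNCH_MUTATION and "my_job" are illustrative only;
# infer_pipeline_selector is assumed to live in dagster_graphql.test.utils,
# the same module changed in the hunk above.
from dagster_graphql.test.utils import (
    execute_dagster_graphql_and_finish_runs,
    infer_pipeline_selector,
)

def test_launch_waits_for_run(graphql_context):
    selector = infer_pipeline_selector(graphql_context, "my_job")
    result = execute_dagster_graphql_and_finish_runs(
        graphql_context,
        RUN_LAUNCH_MUTATION,  # a GraphQL document defined by the test module
        variables={"executionParams": {"selector": selector, "mode": "default"}},
    )
    # The launched run has reached a terminal state by the time assertions run.
    assert not result.errors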
@@ -328,7 +328,7 @@ def _sqlite_asset_instance():

class EnvironmentManagers:
@staticmethod
-def managed_grpc(target=None):
+def managed_grpc(target=None, location_name="test"):
@contextmanager
def _mgr_fn(instance, read_only):
"""Goes out of process via grpc"""
@@ -340,14 +340,14 @@ def _mgr_fn(instance, read_only):
python_file=loadable_target_origin.python_file,
attribute=loadable_target_origin.attribute,
working_directory=loadable_target_origin.working_directory,
location_name="test",
location_name=location_name,
)
if loadable_target_origin.python_file
else ModuleTarget(
module_name=loadable_target_origin.module_name,
attribute=loadable_target_origin.attribute,
working_directory=loadable_target_origin.working_directory,
location_name="test",
location_name=location_name,
)
),
version="",
@@ -358,7 +358,7 @@ def _mgr_fn(instance, read_only):
return MarkedManager(_mgr_fn, [Marks.managed_grpc_env])

@staticmethod
-def deployed_grpc(target=None):
+def deployed_grpc(target=None, location_name="test"):
@contextmanager
def _mgr_fn(instance, read_only):
server_process = GrpcServerProcess(
@@ -374,7 +374,7 @@ def _mgr_fn(instance, read_only):
port=api_client.port,
socket=api_client.socket,
host=api_client.host,
location_name="test",
location_name=location_name,
),
version="",
read_only=read_only,
@@ -543,10 +543,10 @@ def sqlite_with_queued_run_coordinator_managed_grpc_env():
)

@staticmethod
-def sqlite_with_default_run_launcher_managed_grpc_env(target=None):
+def sqlite_with_default_run_launcher_managed_grpc_env(target=None, location_name="test"):
return GraphQLContextVariant(
InstanceManagers.sqlite_instance_with_default_run_launcher(),
-EnvironmentManagers.managed_grpc(target),
+EnvironmentManagers.managed_grpc(target, location_name),
test_id="sqlite_with_default_run_launcher_managed_grpc_env",
)

@@ -560,26 +560,26 @@ def sqlite_read_only_with_default_run_launcher_managed_grpc_env():
)

@staticmethod
-def sqlite_with_default_run_launcher_deployed_grpc_env(target=None):
+def sqlite_with_default_run_launcher_deployed_grpc_env(target=None, location_name="test"):
return GraphQLContextVariant(
InstanceManagers.sqlite_instance_with_default_run_launcher(),
-EnvironmentManagers.deployed_grpc(target),
+EnvironmentManagers.deployed_grpc(target, location_name),
test_id="sqlite_with_default_run_launcher_deployed_grpc_env",
)

@staticmethod
-def postgres_with_default_run_launcher_managed_grpc_env(target=None):
+def postgres_with_default_run_launcher_managed_grpc_env(target=None, location_name="test"):
return GraphQLContextVariant(
InstanceManagers.postgres_instance_with_default_run_launcher(),
-EnvironmentManagers.managed_grpc(target),
+EnvironmentManagers.managed_grpc(target, location_name),
test_id="postgres_with_default_run_launcher_managed_grpc_env",
)

@staticmethod
-def postgres_with_default_run_launcher_deployed_grpc_env(target=None):
+def postgres_with_default_run_launcher_deployed_grpc_env(target=None, location_name="test"):
return GraphQLContextVariant(
InstanceManagers.postgres_instance_with_default_run_launcher(),
-EnvironmentManagers.deployed_grpc(target),
+EnvironmentManagers.deployed_grpc(target, location_name),
test_id="postgres_with_default_run_launcher_deployed_grpc_env",
)

@@ -700,12 +700,20 @@ def all_variants():
]

@staticmethod
-def all_executing_variants(target=None):
+def all_executing_variants(target=None, location_name="test"):
return [
-GraphQLContextVariant.sqlite_with_default_run_launcher_managed_grpc_env(target),
-GraphQLContextVariant.sqlite_with_default_run_launcher_deployed_grpc_env(target),
-GraphQLContextVariant.postgres_with_default_run_launcher_managed_grpc_env(target),
-GraphQLContextVariant.postgres_with_default_run_launcher_deployed_grpc_env(target),
+GraphQLContextVariant.sqlite_with_default_run_launcher_managed_grpc_env(
+target, location_name
+),
+GraphQLContextVariant.sqlite_with_default_run_launcher_deployed_grpc_env(
+target, location_name
+),
+GraphQLContextVariant.postgres_with_default_run_launcher_managed_grpc_env(
+target, location_name
+),
+GraphQLContextVariant.postgres_with_default_run_launcher_deployed_grpc_env(
+target, location_name
+),
]

@staticmethod
@@ -862,5 +870,7 @@ def execute(self, gql_query, variable_values=None):
python_file=file_relative_path(__file__, "cross_repo_asset_deps.py"),
)
AllRepositoryGraphQLContextTestMatrix = make_graphql_context_test_suite(
-context_variants=GraphQLContextVariant.all_executing_variants(target=all_repos_loadable_target)
+context_variants=GraphQLContextVariant.all_executing_variants(
+target=all_repos_loadable_target, location_name="cross_asset_repos"
+)
)
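This is the call pattern the commit is after: a test module can now build the same execution matrix against its own repository file under a distinct location name. A hedged sketch with placeholder names follows; the import paths and the LoadableTargetOrigin type are assumptions (the corresponding lines are collapsed out of this diff view), chosen to mirror the all_repos_loadable_target usage above.

# Placeholder example: the repo file, location name, and import paths are illustrative.
from dagster.core.types.loadable_target_origin import LoadableTargetOrigin
from dagster.utils import file_relative_path

from dagster_graphql_tests.graphql.graphql_context_test_suite import (
    GraphQLContextVariant,
    make_graphql_context_test_suite,
)

my_repos_target = LoadableTargetOrigin(
    python_file=file_relative_path(__file__, "my_asset_repos.py"),
)

# Every variant in the matrix loads my_asset_repos.py under the
# "my_asset_location" repository location instead of the default "test".
MyRepoGraphQLContextTestMatrix = make_graphql_context_test_suite(
    context_variants=GraphQLContextVariant.all_executing_variants(
        target=my_repos_target, location_name="my_asset_location"
    )
)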
@@ -12,7 +12,8 @@
infer_repository_selector,
)

-from dagster import AssetKey, DagsterEventType
+from dagster import AssetKey, DagsterEventType, PipelineRunStatus
+from dagster.core.test_utils import poll_for_finished_run
from dagster.utils import safe_tempfile_path

# from .graphql_context_test_suite import GraphQLContextVariant, make_graphql_context_test_suite
@@ -303,8 +304,9 @@ def _create_run(
variables={"executionParams": {"selector": selector, "mode": mode, "stepKeys": step_keys}},
)
assert result.data["launchPipelineExecution"]["__typename"] == "LaunchRunSuccess"
-graphql_context.instance.run_launcher.join()
-return result.data["launchPipelineExecution"]["run"]["runId"]
+run_id = result.data["launchPipelineExecution"]["run"]["runId"]
+poll_for_finished_run(graphql_context.instance, run_id)
+return run_id


def _get_sorted_materialization_events(graphql_context, run_id):
@@ -369,8 +371,9 @@ def test_get_asset_key_lineage(self, graphql_context, snapshot):
variables={"executionParams": {"selector": selector, "mode": "default"}},
)
assert result.data["launchPipelineExecution"]["__typename"] == "LaunchRunSuccess"
+run_id = result.data["launchPipelineExecution"]["run"]["runId"]

-graphql_context.instance.run_launcher.join()
+poll_for_finished_run(graphql_context.instance, run_id)

result = execute_dagster_graphql(
graphql_context,
@@ -388,8 +391,8 @@ def test_get_partitioned_asset_key_lineage(self, graphql_context, snapshot):
variables={"executionParams": {"selector": selector, "mode": "default"}},
)
assert result.data["launchPipelineExecution"]["__typename"] == "LaunchRunSuccess"

-graphql_context.instance.run_launcher.join()
+run_id = result.data["launchPipelineExecution"]["run"]["runId"]
+poll_for_finished_run(graphql_context.instance, run_id)

result = execute_dagster_graphql(
graphql_context,
@@ -828,6 +831,7 @@ def test_asset_selection_in_run(self, graphql_context):
run_id = _create_run(graphql_context, "foo_job", asset_selection=[{"path": ["bar"]}])
run = graphql_context.instance.get_run_by_id(run_id)
assert run.is_finished
+assert run.status == PipelineRunStatus.SUCCESS
assert run.asset_selection == {AssetKey("bar")}

def test_execute_pipeline_subset(self, graphql_context):
@@ -912,8 +916,9 @@ def test_reexecute_subset(self, graphql_context):
},
},
)
-graphql_context.instance.run_launcher.join()
run_id = result.data["launchPipelineReexecution"]["run"]["runId"]
+poll_for_finished_run(graphql_context.instance, run_id)

run = graphql_context.instance.get_run_by_id(run_id)
assert run.is_finished
events = _get_sorted_materialization_events(graphql_context, run_id)
@@ -925,7 +930,6 @@
class TestPersistentInstanceAssetInProgress(ExecutingGraphQLContextTestMatrix):
def test_asset_in_progress(self, graphql_context):
selector = infer_pipeline_selector(graphql_context, "hanging_job")
run_id = "foo"

with safe_tempfile_path() as path:
result = execute_dagster_graphql(
@@ -938,14 +942,15 @@ def test_asset_in_progress(self, graphql_context):
"runConfigData": {
"resources": {"hanging_asset_resource": {"config": {"file": path}}}
},
"executionMetadata": {"runId": run_id},
}
},
)

assert not result.errors
assert result.data

+run_id = result.data["launchPipelineExecution"]["run"]["runId"]

# ensure the execution has happened
while not os.path.exists(path):
time.sleep(0.1)
@@ -972,23 +977,22 @@ def test_asset_in_progress(self, graphql_context):
assert len(assets_live_info) == 3

assert assets_live_info[0]["assetKey"]["path"] == ["first_asset"]
-assert assets_live_info[0]["latestMaterialization"]["runId"] == "foo"
+assert assets_live_info[0]["latestMaterialization"]["runId"] == run_id
assert assets_live_info[0]["unstartedRunIds"] == []
assert assets_live_info[0]["inProgressRunIds"] == []

assert assets_live_info[1]["assetKey"]["path"] == ["hanging_asset"]
assert assets_live_info[1]["latestMaterialization"] == None
assert assets_live_info[1]["unstartedRunIds"] == []
-assert assets_live_info[1]["inProgressRunIds"] == ["foo"]
+assert assets_live_info[1]["inProgressRunIds"] == [run_id]

assert assets_live_info[2]["assetKey"]["path"] == ["never_runs_asset"]
assert assets_live_info[2]["latestMaterialization"] == None
-assert assets_live_info[2]["unstartedRunIds"] == ["foo"]
+assert assets_live_info[2]["unstartedRunIds"] == [run_id]
assert assets_live_info[2]["inProgressRunIds"] == []

def test_graph_asset_in_progress(self, graphql_context):
selector = infer_pipeline_selector(graphql_context, "hanging_graph_asset_job")
run_id = "foo"

with safe_tempfile_path() as path:
result = execute_dagster_graphql(
@@ -1001,14 +1005,15 @@ def test_graph_asset_in_progress(self, graphql_context):
"runConfigData": {
"resources": {"hanging_asset_resource": {"config": {"file": path}}}
},
"executionMetadata": {"runId": run_id},
}
},
)

assert not result.errors
assert result.data

+run_id = result.data["launchPipelineExecution"]["run"]["runId"]

# ensure the execution has happened
while not os.path.exists(path):
time.sleep(0.1)
@@ -1036,17 +1041,17 @@ def test_graph_asset_in_progress(self, graphql_context):
assert assets_live_info[1]["assetKey"]["path"] == ["hanging_graph"]
assert assets_live_info[1]["latestMaterialization"] == None
assert assets_live_info[1]["unstartedRunIds"] == []
-assert assets_live_info[1]["inProgressRunIds"] == ["foo"]
+assert assets_live_info[1]["inProgressRunIds"] == [run_id]

assert assets_live_info[0]["assetKey"]["path"] == ["downstream_asset"]
assert assets_live_info[0]["latestMaterialization"] == None
-assert assets_live_info[0]["unstartedRunIds"] == ["foo"]
+assert assets_live_info[0]["unstartedRunIds"] == [run_id]
assert assets_live_info[0]["inProgressRunIds"] == []


class TestCrossRepoAssetDependedBy(AllRepositoryGraphQLContextTestMatrix):
def test_cross_repo_assets(self, graphql_context):
-repository_location = graphql_context.get_repository_location("test")
+repository_location = graphql_context.get_repository_location("cross_asset_repos")
repository = repository_location.get_repository("upstream_assets_repository")

selector = {
