Type annotations in dagster-graphql (#10005)
smackesey committed Oct 27, 2022
1 parent 67313e8 commit cf84896
Showing 20 changed files with 365 additions and 231 deletions.
27 changes: 19 additions & 8 deletions python_modules/dagster-graphql/dagster_graphql/cli.py
@@ -1,8 +1,10 @@
+from typing import Mapping, Optional, cast
from urllib.parse import urljoin, urlparse

import click
import requests
from graphql import graphql
+from graphql.execution import ExecutionResult

import dagster._check as check
import dagster._seven as seven
@@ -26,7 +28,11 @@ def create_dagster_graphql_cli():
return ui


-def execute_query(workspace_process_context, query, variables=None):
+def execute_query(
+    workspace_process_context: WorkspaceProcessContext,
+    query: str,
+    variables: Optional[Mapping[str, object]] = None,
+):
check.inst_param(
workspace_process_context, "workspace_process_context", WorkspaceProcessContext
)
@@ -37,11 +43,14 @@ def execute_query(workspace_process_context, query, variables=None):

context = workspace_process_context.create_request_context()

-    result = graphql(
-        request_string=query,
-        schema=create_schema(),
-        context_value=context,
-        variable_values=variables,
+    result = cast(
+        ExecutionResult,
+        graphql(
+            request_string=query,
+            schema=create_schema(),
+            context_value=context,
+            variable_values=variables,
+        ),
    )

result_dict = result.to_dict()
@@ -54,8 +63,10 @@ def execute_query(workspace_process_context, query, variables=None):
# in the 'stack_trace' property of each error to ease debugging

if "errors" in result_dict:
-        check.invariant(len(result_dict["errors"]) == len(result.errors))
-        for python_error, error_dict in zip(result.errors, result_dict["errors"]):
+        result_dict_errors = check.list_elem(result_dict, "errors", of_type=Exception)
+        result_errors = check.is_list(result.errors, of_type=Exception)
+        check.invariant(len(result_dict_errors) == len(result_errors))
+        for python_error, error_dict in zip(result_errors, result_dict_errors):
if hasattr(python_error, "original_error") and python_error.original_error:
error_dict["stack_trace"] = get_stack_trace_array(python_error.original_error)

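A note on the `cast` in `execute_query` above: the `graphql()` entrypoint in the graphql-core line used here is not declared to return a plain `ExecutionResult` (it can also hand back a promise when run asynchronously), so the commit narrows the result for the type checker with `typing.cast`. A minimal sketch of that narrowing idiom, using a hypothetical union-returning function rather than dagster's code:

# Sketch of the cast-narrowing idiom (illustrative, not dagster's code).
# `fetch` stands in for a call like graphql(): its declared return type is a
# union, but this particular call site knows which member it will receive.
from typing import Union, cast


def fetch(async_mode: bool) -> Union[str, int]:
    # Hypothetical helper: an int ticket in async mode, a str result otherwise.
    return 42 if async_mode else "done"


result = cast(str, fetch(async_mode=False))
print(result.upper())  # the checker now treats result as str

Note that `cast` performs no runtime check; it is only as trustworthy as the invariant the call site relies on, which is why the surrounding code still validates the result it gets back.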
@@ -1,4 +1,5 @@
from math import isnan
+from typing import Any, Iterator, Sequence, cast, no_type_check

from dagster_graphql.schema.table import GrapheneTable, GrapheneTableSchema

@@ -28,7 +29,7 @@
MIN_INT = -2147483648


-def iterate_metadata_entries(metadata_entries):
+def iterate_metadata_entries(metadata_entries: Sequence[MetadataEntry]) -> Iterator[Any]:
from ..schema.metadata import (
GrapheneAssetMetadataEntry,
GrapheneBoolMetadataEntry,
@@ -45,7 +46,7 @@ def iterate_metadata_entries(metadata_entries):
GrapheneUrlMetadataEntry,
)

-    check.list_param(metadata_entries, "metadata_entries", of_type=MetadataEntry)
+    check.sequence_param(metadata_entries, "metadata_entries", of_type=MetadataEntry)
for metadata_entry in metadata_entries:
if isinstance(metadata_entry.entry_data, PathMetadataValue):
yield GraphenePathMetadataEntry(
@@ -88,7 +89,7 @@ def iterate_metadata_entries(metadata_entries):
float_val = metadata_entry.entry_data.value

# coerce NaN to null
-        if isnan(float_val):
+        if float_val is not None and isnan(float_val):
float_val = None

yield GrapheneFloatMetadataEntry(
@@ -99,7 +100,7 @@ def iterate_metadata_entries(metadata_entries):
elif isinstance(metadata_entry.entry_data, IntMetadataValue):
# coerce > 32 bit ints to null
int_val = None
-        if MIN_INT <= metadata_entry.entry_data.value <= MAX_INT:
+        if MIN_INT <= cast(int, metadata_entry.entry_data.value) <= MAX_INT:
int_val = metadata_entry.entry_data.value

yield GrapheneIntMetadataEntry(
@@ -155,11 +156,14 @@ def iterate_metadata_entries(metadata_entries):
)


-def _to_metadata_entries(metadata_entries):
+def _to_metadata_entries(metadata_entries: Sequence[MetadataEntry]) -> Sequence[Any]:
return list(iterate_metadata_entries(metadata_entries) or [])


-def from_dagster_event_record(event_record, pipeline_name):
+# We don't typecheck this due to the excessive number of type errors resulting from the
+# non-type-checker legible relationship between `event_type` and the class of `event_specific_data`.
+@no_type_check
+def from_dagster_event_record(event_record: EventLogEntry, pipeline_name: str) -> Any:
from ..schema.errors import GraphenePythonError
from ..schema.logs.events import (
GrapheneAlertFailureEvent,
@@ -206,7 +210,7 @@ def from_dagster_event_record(event_record, pipeline_name):
check.param_invariant(event_record.is_dagster_event, "event_record")
check.str_param(pipeline_name, "pipeline_name")

-    dagster_event = event_record.dagster_event
+    dagster_event = check.not_none(event_record.dagster_event)
basic_params = construct_basic_params(event_record)
if dagster_event.event_type == DagsterEventType.STEP_START:
return GrapheneExecutionStepStartEvent(**basic_params)
@@ -223,7 +227,7 @@ def from_dagster_event_record(event_record, pipeline_name):
elif dagster_event.event_type == DagsterEventType.STEP_SUCCESS:
return GrapheneExecutionStepSuccessEvent(**basic_params)
elif dagster_event.event_type == DagsterEventType.STEP_INPUT:
-        input_data = dagster_event.event_specific_data
+        input_data = check.not_none(dagster_event.event_specific_data)
return GrapheneExecutionStepInputEvent(
input_name=input_data.input_name,
type_check=input_data.type_check_data,
@@ -322,7 +326,7 @@ def from_dagster_event_record(event_record, pipeline_name):
output_name=dagster_event.event_specific_data.output_name,
manager_key=dagster_event.event_specific_data.manager_key,
metadataEntries=_to_metadata_entries(
-                dagster_event.event_specific_data.metadata_entries
+                dagster_event.event_specific_data.metadata_entries  # type: ignore
),
**basic_params,
)
@@ -333,7 +337,7 @@ def from_dagster_event_record(event_record, pipeline_name):
upstream_output_name=dagster_event.event_specific_data.upstream_output_name,
upstream_step_key=dagster_event.event_specific_data.upstream_step_key,
metadataEntries=_to_metadata_entries(
-                dagster_event.event_specific_data.metadata_entries
+                dagster_event.event_specific_data.metadata_entries  # type: ignore
),
**basic_params,
)
@@ -410,7 +414,7 @@ def from_dagster_event_record(event_record, pipeline_name):
)


-def from_event_record(event_record, pipeline_name):
+def from_event_record(event_record: EventLogEntry, pipeline_name: str) -> Any:
from ..schema.logs.events import GrapheneLogMessageEvent

check.inst_param(event_record, "event_record", EventLogEntry)
@@ -422,20 +426,21 @@ def from_event_record(event_record, pipeline_name):
return GrapheneLogMessageEvent(**construct_basic_params(event_record))


-def construct_basic_params(event_record):
+def construct_basic_params(event_record: EventLogEntry) -> Any:
from ..schema.logs.log_level import GrapheneLogLevel

check.inst_param(event_record, "event_record", EventLogEntry)
+    dagster_event = event_record.dagster_event
return {
"runId": event_record.run_id,
"message": event_record.message,
"timestamp": int(event_record.timestamp * 1000),
"level": GrapheneLogLevel.from_level(event_record.level),
"eventType": event_record.dagster_event.event_type
if (event_record.dagster_event and event_record.dagster_event.event_type)
"eventType": dagster_event.event_type
if (dagster_event and dagster_event.event_type)
else None,
"stepKey": event_record.step_key,
"solidHandleID": event_record.dagster_event.solid_handle.to_string()
if event_record.is_dagster_event and event_record.dagster_event.solid_handle
"solidHandleID": event_record.dagster_event.solid_handle.to_string() # type: ignore
if dagster_event and dagster_event.solid_handle
else None,
}
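Several hunks above swap bare attribute access for `check.not_none(...)`. dagster's `_check.not_none` returns its argument unchanged but raises if it is `None`, and because its return type is the unwrapped type, it doubles as a type-narrowing assertion. A hand-rolled sketch of the same pattern (illustrative, not dagster's implementation):

# Illustrative equivalent of the not_none narrowing helper.
from typing import Optional, TypeVar

T = TypeVar("T")


def not_none(value: Optional[T], desc: str = "value") -> T:
    # Fail fast at runtime; the bare T return type strips Optional for the checker.
    if value is None:
        raise ValueError(f"Expected non-None {desc}")
    return value


maybe_event: Optional[str] = "STEP_START"
event: str = not_none(maybe_event)  # checker sees str, not Optional[str]

This is why lines like `dagster_event = check.not_none(event_record.dagster_event)` satisfy the type checker where the old bare access produced `Optional` errors.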
@@ -1,17 +1,33 @@
+from __future__ import annotations

import sys
+from typing import TYPE_CHECKING, Mapping, Optional, Sequence, Union

from graphene import ResolveInfo

import dagster._check as check
from dagster._config import validate_config_from_snap
from dagster._core.execution.plan.state import KnownExecutionState
from dagster._core.host_representation import ExternalPipeline, PipelineSelector, RepositorySelector
-from dagster._core.workspace.context import BaseWorkspaceRequestContext
+from dagster._core.host_representation.external import ExternalExecutionPlan
+from dagster._core.workspace.context import BaseWorkspaceRequestContext, WorkspaceRequestContext
from dagster._utils.error import serializable_error_info_from_exc_info

from .utils import UserFacingGraphQLError, capture_error

+if TYPE_CHECKING:
+    from dagster_graphql.schema.errors import GrapheneRepositoryNotFoundError
+    from dagster_graphql.schema.external import (
+        GrapheneRepository,
+        GrapheneRepositoryConnection,
+        GrapheneWorkspace,
+    )
+    from dagster_graphql.schema.util import HasContext


-def get_full_external_job_or_raise(graphene_info, selector):
+def get_full_external_job_or_raise(
+    graphene_info: HasContext, selector: PipelineSelector
+) -> ExternalPipeline:
from ..schema.errors import GraphenePipelineNotFoundError

check.inst_param(graphene_info, "graphene_info", ResolveInfo)
@@ -23,10 +39,11 @@ def get_full_external_job_or_raise(graphene_info, selector):
return graphene_info.context.get_full_external_job(selector)


-def get_external_pipeline_or_raise(graphene_info, selector):
+def get_external_pipeline_or_raise(
+    graphene_info: HasContext, selector: PipelineSelector
+) -> ExternalPipeline:
check.inst_param(graphene_info, "graphene_info", ResolveInfo)
check.inst_param(selector, "selector", PipelineSelector)

full_pipeline = get_full_external_job_or_raise(graphene_info, selector)

if selector.solid_selection is None and selector.asset_selection is None:
@@ -35,7 +52,9 @@ def get_external_pipeline_or_raise(graphene_info, selector):
return get_subset_external_pipeline(graphene_info.context, selector)


-def get_subset_external_pipeline(context, selector):
+def get_subset_external_pipeline(
+    context: WorkspaceRequestContext, selector: PipelineSelector
+) -> ExternalPipeline:
from ..schema.errors import GrapheneInvalidSubsetError
from ..schema.pipelines.pipeline import GraphenePipeline

@@ -62,7 +81,9 @@ def get_subset_external_pipeline(context, selector):
return external_pipeline


-def ensure_valid_config(external_pipeline, mode, run_config):
+def ensure_valid_config(
+    external_pipeline: ExternalPipeline, mode: Optional[str], run_config: object
+) -> object:
from ..schema.pipelines.config import GrapheneRunConfigValidationInvalid

check.inst_param(external_pipeline, "external_pipeline", ExternalPipeline)
@@ -71,7 +92,7 @@ def ensure_valid_config(external_pipeline, mode, run_config):

validated_config = validate_config_from_snap(
config_schema_snapshot=external_pipeline.config_schema_snapshot,
-        config_type_key=external_pipeline.root_config_key_for_mode(mode),
+        config_type_key=check.not_none(external_pipeline.root_config_key_for_mode(mode)),
config_value=run_config,
)

@@ -87,25 +108,25 @@


def get_external_execution_plan_or_raise(
-    graphene_info,
-    external_pipeline,
-    mode,
-    run_config,
-    step_keys_to_execute,
-    known_state,
-):
+    graphene_info: HasContext,
+    external_pipeline: ExternalPipeline,
+    mode: Optional[str],
+    run_config: Mapping[str, object],
+    step_keys_to_execute: Sequence[str],
+    known_state: KnownExecutionState,
+) -> ExternalExecutionPlan:

return graphene_info.context.get_external_execution_plan(
external_pipeline=external_pipeline,
run_config=run_config,
-        mode=mode,
+        mode=check.not_none(mode),
step_keys_to_execute=step_keys_to_execute,
known_state=known_state,
)


@capture_error
-def fetch_repositories(graphene_info):
+def fetch_repositories(graphene_info: HasContext) -> GrapheneRepositoryConnection:
from ..schema.external import GrapheneRepository, GrapheneRepositoryConnection

check.inst_param(graphene_info, "graphene_info", ResolveInfo)
@@ -123,7 +144,9 @@ def fetch_repositories(graphene_info):


@capture_error
-def fetch_repository(graphene_info, repository_selector):
+def fetch_repository(
+    graphene_info: HasContext, repository_selector: RepositorySelector
+) -> Union[GrapheneRepository, GrapheneRepositoryNotFoundError]:
from ..schema.errors import GrapheneRepositoryNotFoundError
from ..schema.external import GrapheneRepository

@@ -145,7 +168,7 @@ def fetch_repository(graphene_info, repository_selector):


@capture_error
-def fetch_workspace(workspace_request_context):
+def fetch_workspace(workspace_request_context: WorkspaceRequestContext) -> GrapheneWorkspace:
from ..schema.external import GrapheneWorkspace, GrapheneWorkspaceLocationEntry

check.inst_param(
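The import block added at the top of this file is the standard recipe for annotating with types that must not be imported at runtime; here the graphene schema classes would create an import cycle with this module. `TYPE_CHECKING` is `False` at runtime, and `from __future__ import annotations` makes all annotations lazy, so the gated names only need to resolve when a type checker runs. A small sketch under those assumptions, with a hypothetical module name:

# Sketch of TYPE_CHECKING-gated imports; my_schema_module is hypothetical.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by the type checker, never executed, so no runtime import cycle.
    from my_schema_module import SchemaType


def resolve(node: SchemaType) -> None:
    # With lazy annotations this is just the string "SchemaType" at runtime;
    # only a type checker ever evaluates it against the real class.
    ...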
