Add py.typed to dagster and all extension libs (#7561)
smackesey committed May 2, 2022
1 parent dbbeadf commit 7b54c30
Showing 118 changed files with 308 additions and 188 deletions.
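For context: py.typed is the marker file defined by PEP 561. Its presence inside a package tells type checkers such as mypy and pyright to use the package's inline annotations, which is why this commit adds empty py.typed files, lists them in each MANIFEST.in, and then fixes the Optional-handling issues that stricter checking surfaces in the examples (check.not_none narrowing, a few targeted # type: ignore comments). Below is a minimal packaging sketch for shipping the marker with setuptools; the package name and setup.py layout are illustrative, not copied from this commit.

from setuptools import find_packages, setup

# Hypothetical package used only for illustration; py.typed itself is an empty file
# placed inside the package directory.
setup(
    name="example_lib",
    packages=find_packages(exclude=["example_lib_tests*"]),
    # Ship the marker in wheels; a MANIFEST.in line ("include example_lib/py.typed") covers sdists.
    package_data={"example_lib": ["py.typed"]},
)
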
3 changes: 1 addition & 2 deletions examples/hacker_news/hacker_news/ops/id_range_for_time.py
@@ -1,7 +1,6 @@
from datetime import datetime, timezone
from typing import Tuple

from dagster import Out, Output, check, op
from dagster import Out, Output, Tuple, check, op


def binary_search_nearest_left(get_value, start, end, min_target):
22 changes: 11 additions & 11 deletions examples/hacker_news/hacker_news/resources/snowflake_io_manager.py
@@ -11,6 +11,7 @@
from snowflake.sqlalchemy import URL # pylint: disable=no-name-in-module,import-error
from sqlalchemy import create_engine

import dagster.check as check
from dagster import AssetKey, IOManager, InputContext, MetadataEntry, OutputContext, io_manager


@@ -75,12 +76,11 @@ def __init__(self, config):
self._config = config

def handle_output(self, context: OutputContext, obj: Union[PandasDataFrame, SparkDataFrame]):
schema, table = context.metadata["table"].split(".")
metadata = check.not_none(context.metadata)
schema, table = metadata["table"].split(".")

partition_bounds = (
context.resources.partition_bounds
if context.metadata.get("partitioned") is True
else None
context.resources.partition_bounds if metadata.get("partitioned") is True else None
)
with connect_snowflake(config=self._config, schema=schema) as con:
con.execute(self._get_cleanup_statement(table, schema, partition_bounds))
@@ -95,9 +95,7 @@ def handle_output(self, context: OutputContext, obj: Union[PandasDataFrame, Spar
)

yield MetadataEntry.text(
self._get_select_statement(
table, schema, context.metadata.get("columns"), partition_bounds
),
self._get_select_statement(table, schema, metadata.get("columns"), partition_bounds),
"Query",
)

@@ -147,10 +145,10 @@ def _get_cleanup_statement(
def load_input(self, context: InputContext) -> PandasDataFrame:
if context.upstream_output is not None:
# loading from an upstream output
metadata = context.upstream_output.metadata
metadata = check.not_none(context.upstream_output.metadata)
else:
# loading as a root input
metadata = context.metadata
metadata = check.not_none(context.metadata)

schema, table = metadata["table"].split(".")
with connect_snowflake(config=self._config) as con:
@@ -188,10 +186,12 @@ def _partition_where_clause(self, partition_bounds: Mapping[str, str]) -> str:
return f"""WHERE TO_TIMESTAMP(time::INT) BETWEEN '{partition_bounds["start"]}' AND '{partition_bounds["end"]}'"""

def get_output_asset_key(self, context: OutputContext) -> AssetKey:
return AssetKey(["snowflake", *context.metadata["table"].split(".")])
metadata = check.not_none(context.metadata)
return AssetKey(["snowflake", *metadata["table"].split(".")])

def get_output_asset_partitions(self, context: OutputContext):
if context.metadata.get("partitioned") is True:
metadata = check.not_none(context.metadata)
if metadata.get("partitioned") is True:
return [context.resources.partition_bounds["start"]]
else:
return None
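The recurring change in this file is narrowing Optional metadata with dagster's check.not_none before subscripting it. A small illustrative sketch of the pattern follows (not code from this commit); it assumes check.not_none raises a CheckError when passed None and otherwise returns its argument unchanged, which is what lets the type checker treat the result as non-Optional.

from typing import Mapping, Optional, Tuple

from dagster import check


def split_table_metadata(metadata: Optional[Mapping[str, str]]) -> Tuple[str, str]:
    # Fails fast if metadata is missing, instead of a TypeError on the subscript;
    # the narrowed value is typed as Mapping[str, str].
    narrowed = check.not_none(metadata)
    schema, table = narrowed["table"].split(".")
    return schema, table
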
3 changes: 2 additions & 1 deletion examples/hacker_news/hacker_news_tests/test_schedules.py
@@ -3,6 +3,7 @@

from hacker_news.jobs.hacker_news_api_download import download_prod_job, download_staging_job

import dagster.check as check
from dagster import Partition
from dagster.core.definitions import JobDefinition
from dagster.core.execution.api import create_execution_plan
@@ -13,7 +14,7 @@ def assert_partitioned_schedule_builds(
start: datetime,
end: datetime,
):
partition_set = job_def.get_partition_set_def()
partition_set = check.not_none(job_def.get_partition_set_def())
run_config = partition_set.run_config_for_partition(Partition((start, end)))
create_execution_plan(job_def, run_config=run_config)

@@ -53,7 +53,7 @@ def load_input(self, context) -> Union[pyspark.sql.DataFrame, str]:
)

def _get_path(self, context: OutputContext):
key = context.asset_key.path[-1]
key = context.asset_key.path[-1] # type: ignore

if context.has_asset_partitions:
start, end = context.asset_partitions_time_window
@@ -12,7 +12,7 @@
from snowflake.sqlalchemy import URL # pylint: disable=no-name-in-module,import-error
from sqlalchemy import create_engine

from dagster import IOManager, InputContext, MetadataEntry, OutputContext, io_manager
from dagster import IOManager, InputContext, MetadataEntry, OutputContext, check, io_manager

SNOWFLAKE_DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"
DB_SCHEMA = "hackernews"
@@ -73,7 +73,7 @@ def __init__(self, config):
self._config = config

def handle_output(self, context: OutputContext, obj: Union[PandasDataFrame, SparkDataFrame]):
schema, table = DB_SCHEMA, context.asset_key.path[-1]
schema, table = DB_SCHEMA, context.asset_key.path[-1] # type: ignore

time_window = context.asset_partitions_time_window if context.has_asset_partitions else None
with connect_snowflake(config=self._config, schema=schema) as con:
@@ -151,7 +151,8 @@ def _get_cleanup_statement(
return f"DELETE FROM {schema}.{table}"

def load_input(self, context: InputContext) -> PandasDataFrame:
asset_key = context.upstream_output.asset_key
upstream_output = check.not_none(context.upstream_output)
asset_key = check.not_none(upstream_output.asset_key)

schema, table = DB_SCHEMA, asset_key.path[-1]
with connect_snowflake(config=self._config) as con:
@@ -5,6 +5,7 @@
"""
# pylint: disable=print-call
import random
from typing import Any, Dict

import numpy as np
import pandas as pd
@@ -20,13 +21,24 @@
N_ORDERS = 10000


def _safe_request(
client: AirbyteResource, endpoint: str, data: Dict[str, object]
) -> Dict[str, Any]:
response = client.make_request(endpoint, data)
assert response, "Request returned null response"
return response


def _create_ab_source(client: AirbyteResource) -> str:
workspace_id = client.make_request("/workspaces/list", data={})["workspaces"][0]["workspaceId"]
workspace_id = _safe_request(client, "/workspaces/list", data={})["workspaces"][0][
"workspaceId"
]

# get latest available Postgres source definition
source_defs = client.make_request(
"/source_definitions/list_latest", data={"workspaceId": workspace_id}
)
assert source_defs
postgres_definitions = [
sd for sd in source_defs["sourceDefinitions"] if sd["name"] == "Postgres"
]
@@ -35,7 +47,8 @@ def _create_ab_source(client: AirbyteResource) -> str:
source_definition_id = postgres_definitions[0]["sourceDefinitionId"]

# create Postgres source
source_id = client.make_request(
source_id = _safe_request(
client,
"/sources/create",
data={
"sourceDefinitionId": source_definition_id,
@@ -49,11 +62,13 @@ def _create_ab_source(client: AirbyteResource) -> str:


def _create_ab_destination(client: AirbyteResource) -> str:
workspace_id = client.make_request("/workspaces/list", data={})["workspaces"][0]["workspaceId"]
workspace_id = _safe_request(client, "/workspaces/list", data={})["workspaces"][0][
"workspaceId"
]

# get the latest available Postgres destination definition
destination_defs = client.make_request(
"/destination_definitions/list_latest", data={"workspaceId": workspace_id}
destination_defs = _safe_request(
client, "/destination_definitions/list_latest", data={"workspaceId": workspace_id}
)
postgres_definitions = [
dd for dd in destination_defs["destinationDefinitions"] if dd["name"] == "Postgres"
@@ -63,7 +78,8 @@ def _create_ab_destination(client: AirbyteResource) -> str:
destination_definition_id = postgres_definitions[0]["destinationDefinitionId"]

# create Postgres destination
destination_id = client.make_request(
destination_id = _safe_request(
client,
"/destinations/create",
data={
"destinationDefinitionId": destination_definition_id,
@@ -81,12 +97,13 @@ def setup_airbyte():
source_id = _create_ab_source(client)
destination_id = _create_ab_destination(client)

source_catalog = client.make_request("/sources/discover_schema", data={"sourceId": source_id})[
"catalog"
]
source_catalog = _safe_request(
client, "/sources/discover_schema", data={"sourceId": source_id}
)["catalog"]

# create a connection between the new source and destination
connection_id = client.make_request(
connection_id = _safe_request(
client,
"/connections/create",
data={
"name": "Example Connection",
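The new _safe_request helper uses assert both to fail fast and to narrow the Optional response from make_request for the type checker. Below is an alternative sketch that also survives python -O (which strips assert statements); it assumes, as the helper above implies, that AirbyteResource.make_request returns an Optional mapping.

from typing import Any, Dict

from dagster_airbyte import AirbyteResource


def safe_request(client: AirbyteResource, endpoint: str, data: Dict[str, object]) -> Dict[str, Any]:
    # Same narrowing effect as the assert-based helper, but via an explicit raise.
    response = client.make_request(endpoint, data)
    if response is None:
        raise RuntimeError(f"Airbyte request to {endpoint} returned no response")
    return response
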
2 changes: 1 addition & 1 deletion examples/nyt-feed/nyt_feed/nyt_feed_job.py
@@ -32,7 +32,7 @@ def handle_output(self, context: OutputContext, obj: pd.DataFrame):
obj.to_csv(self._get_path(context), index=False)

def load_input(self, context: InputContext) -> pd.DataFrame:
return pd.read_csv(self._get_path(context.upstream_output))
return pd.read_csv(self._get_path(context.upstream_output)) # type: ignore


@io_manager(config_schema={"base_dir": str})
@@ -29,7 +29,7 @@
}


def assert_run_success(client, run_id: int):
def assert_run_success(client, run_id):
start_time = time.time()
while True:
if time.time() - start_time > MAX_TIMEOUT_SECONDS:
1 change: 1 addition & 0 deletions python_modules/dagit/MANIFEST.in
@@ -1,5 +1,6 @@
include README.rst
include LICENSE
include dagit/py.typed
recursive-include dagit/graphql-playground *
recursive-include dagit/webapp/build *
recursive-include dagit/templates/assets *
Empty file.
15 changes: 10 additions & 5 deletions python_modules/dagit/dagit/webserver.py
@@ -2,7 +2,7 @@
import io
import uuid
from os import path
from typing import List
from typing import Generic, List, TypeVar

import nbformat
from dagster_graphql import __version__ as dagster_graphql_version
@@ -29,7 +29,7 @@
from dagster import check
from dagster.core.debug import DebugRunPayload
from dagster.core.storage.compute_log_manager import ComputeIOType
from dagster.core.workspace.context import WorkspaceProcessContext, WorkspaceRequestContext
from dagster.core.workspace.context import BaseWorkspaceRequestContext, IWorkspaceProcessContext
from dagster.seven import json
from dagster.utils import Counter, traced_counter

@@ -48,9 +48,14 @@
"/robots.txt",
]

T_IWorkspaceProcessContext = TypeVar("T_IWorkspaceProcessContext", bound=IWorkspaceProcessContext)

class DagitWebserver(GraphQLServer):
def __init__(self, process_context: WorkspaceProcessContext, app_path_prefix: str = ""):

class DagitWebserver(GraphQLServer, Generic[T_IWorkspaceProcessContext]):

_process_context: T_IWorkspaceProcessContext

def __init__(self, process_context: T_IWorkspaceProcessContext, app_path_prefix: str = ""):
self._process_context = process_context
super().__init__(app_path_prefix)

@@ -63,7 +68,7 @@ def build_graphql_middleware(self) -> list:
def relative_path(self, rel: str) -> str:
return path.join(path.dirname(__file__), rel)

def make_request_context(self, conn: HTTPConnection) -> WorkspaceRequestContext:
def make_request_context(self, conn: HTTPConnection) -> BaseWorkspaceRequestContext:
return self._process_context.create_request_context(conn)

def build_middleware(self) -> List[Middleware]:
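Making DagitWebserver generic over a TypeVar bound to IWorkspaceProcessContext lets a subclass pin a more specific process-context type while the base class relies only on the interface. The following is a self-contained sketch of that pattern with placeholder names, not dagit code.

from typing import Generic, TypeVar


class BaseContext:
    def create_request_context(self) -> str:
        return "request-context"


class CloudContext(BaseContext):
    def extra(self) -> str:
        return "only-on-cloud"


T_Context = TypeVar("T_Context", bound=BaseContext)


class Webserver(Generic[T_Context]):
    _process_context: T_Context

    def __init__(self, process_context: T_Context):
        self._process_context = process_context


class CloudWebserver(Webserver[CloudContext]):
    def cloud_only(self) -> str:
        # The type checker knows _process_context is a CloudContext here.
        return self._process_context.extra()
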
1 change: 1 addition & 0 deletions python_modules/dagster-graphql/MANIFEST.in
@@ -1 +1,2 @@
include LICENSE
include dagster_graphql/py.typed
@@ -1,7 +1,7 @@
from collections import defaultdict
from enum import Enum
from functools import lru_cache
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple

from dagster import DagsterInstance, check
from dagster.core.definitions.events import AssetKey
@@ -278,13 +278,16 @@ def __init__(self, instance: DagsterInstance, asset_keys: Iterable[AssetKey]):
self._instance = instance
self._asset_keys: List[AssetKey] = list(asset_keys)
self._fetched = False
self._materializations: Dict[AssetKey, EventLogEntry] = {}
self._materializations: Mapping[AssetKey, Optional[EventLogEntry]] = {}

def get_latest_materialization_for_asset_key(self, asset_key: AssetKey) -> EventLogEntry:
def get_latest_materialization_for_asset_key(
self, asset_key: AssetKey
) -> Optional[EventLogEntry]:
if asset_key not in self._asset_keys:
check.failed(
f"Asset key {asset_key} not recognized for this loader. Expected one of: {self._asset_keys}"
)

if self._materializations.get(asset_key) is None:
self._fetch()
return self._materializations.get(asset_key)
@@ -394,7 +397,7 @@ def _build_cross_repo_deps(

def get_sink_asset(self, asset_key: AssetKey) -> ExternalAssetNode:
sink_assets, _ = self._build_cross_repo_deps()
return sink_assets.get(asset_key)
return sink_assets[asset_key]

def get_cross_repo_dependent_assets(
self, repository_location_name: str, repository_name: str, asset_key: AssetKey
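The loader changes above follow from the stricter types: get_latest_materialization_for_asset_key is now declared Optional, while get_sink_asset switches from dict.get to indexing so its non-Optional return type holds. A tiny illustration of that distinction (not repository code):

from typing import Dict, Optional


def maybe_lookup(table: Dict[str, int], key: str) -> Optional[int]:
    # .get returns Optional[int]; a type checker flags it against a non-Optional return type.
    return table.get(key)


def lookup(table: Dict[str, int], key: str) -> int:
    # Indexing keeps the value type and surfaces a missing key as KeyError.
    return table[key]
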
Empty file.
