dagster-dbt types (#7878)
smackesey committed May 13, 2022
1 parent 21b6a5e commit 78d4b57
Showing 11 changed files with 51 additions and 42 deletions.
@@ -7,7 +7,6 @@
 MYPY_EXCLUDES = [
     "python_modules/automation",
     "python_modules/libraries/dagster-databricks",
-    "python_modules/libraries/dagster-dbt",
     "python_modules/libraries/dagster-docker",
     "examples/docs_snippets",
 ]
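
For context, removing dagster-dbt from MYPY_EXCLUDES opts the package into type checking. A hedged sketch of how such an exclude list might be fed to mypy (hypothetical invocation; the repo's actual runner may differ):

import subprocess

MYPY_EXCLUDES = [
    "python_modules/automation",
    "python_modules/libraries/dagster-databricks",
]

# mypy's --exclude flag takes a single regex; alternation skips each listed path.
subprocess.run(["mypy", "--exclude", "|".join(MYPY_EXCLUDES), "python_modules"], check=False)
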
@@ -120,8 +120,8 @@ def with_retry(self, num_retries):
 
     def on_queue(self, queue):
         assert BuildkiteQueue.contains(queue)
 
-        self._step["agents"]["queue"] = queue.value
+        agents = self._step["agents"]  # type: ignore
+        agents["queue"] = queue.value
         return self
 
     def depends_on(self, step_keys):
@@ -13,6 +13,7 @@
 )
 
 import dagster._check as check
+from dagster.config.config_schema import ConfigSchemaType
 from dagster.core.decorator_utils import format_docstring_for_description
 from dagster.core.definitions.config import is_callable_valid_config_arg
 from dagster.core.definitions.configurable import AnonymousConfigurableDefinition
@@ -80,7 +81,7 @@ class ResourceDefinition(AnonymousConfigurableDefinition):
     def __init__(
         self,
         resource_fn: Callable[["InitResourceContext"], Any],
-        config_schema: Optional[Union[Any, IDefinitionConfigSchema]] = None,
+        config_schema: Optional[Union[Any, ConfigSchemaType]] = None,
         description: Optional[str] = None,
         required_resource_keys: Optional[AbstractSet[str]] = None,
         version: Optional[str] = None,
@@ -270,7 +271,7 @@ def resource(config_schema=Callable[["InitResourceContext"], Any]) -> ResourceDefinition:
 
 @overload
 def resource(
-    config_schema: Optional[Union[IDefinitionConfigSchema, Dict[str, Any]]] = ...,
+    config_schema: Optional[ConfigSchemaType] = ...,
     description: Optional[str] = ...,
     required_resource_keys: Optional[AbstractSet[str]] = ...,
    version: Optional[str] = ...,
@@ -279,9 +280,7 @@ def resource(
 
 
 def resource(
-    config_schema: Optional[
-        Union[Callable[["InitResourceContext"], Any], IDefinitionConfigSchema, Dict[str, Any]]
-    ] = None,
+    config_schema: Union[Callable[["InitResourceContext"], Any], Optional[ConfigSchemaType]] = None,
     description: Optional[str] = None,
     required_resource_keys: Optional[AbstractSet[str]] = None,
     version: Optional[str] = None,
@@ -312,7 +311,7 @@ def resource(
     # This case is for when decorator is used bare, without arguments.
     # E.g. @resource versus @resource()
     if callable(config_schema) and not is_callable_valid_config_arg(config_schema):
-        return _ResourceDecoratorCallable()(config_schema)
+        return _ResourceDecoratorCallable()(config_schema)  # type: ignore
 
     def _wrap(resource_fn: Callable[["InitResourceContext"], Any]) -> "ResourceDefinition":
         return _ResourceDecoratorCallable(
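
The `# type: ignore` on the bare-decorator branch reflects how `resource` dispatches on whether it was used bare (`@resource`) or called (`@resource(...)`). A self-contained sketch of that pattern, with hypothetical names rather than dagster's real implementation:

from typing import Any, Callable

def resource(config_schema: Any = None) -> Any:
    if callable(config_schema):
        # Bare use: `@resource` passes the decorated function itself here.
        return {"resource_fn": config_schema, "config_schema": None}

    def _wrap(fn: Callable[..., Any]) -> Any:
        # Called use: `@resource(config_schema={...})` returns this wrapper.
        return {"resource_fn": fn, "config_schema": config_schema}

    return _wrap

@resource                    # bare: decorates directly
def my_resource(_ctx): ...

@resource(config_schema={})  # called: returns _wrap, which then decorates
def my_other_resource(_ctx): ...
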
@@ -22,6 +22,7 @@
 from dagster import _check as check
 from dagster import get_dagster_logger
 from dagster.core.asset_defs import AssetsDefinition, multi_asset
+from dagster.core.definitions.metadata import RawMetadataValue
 
 
 def _load_manifest_for_project(
@@ -67,7 +68,7 @@ def _dbt_nodes_to_assets(
     select: str,
     selected_unique_ids: AbstractSet[str],
     runtime_metadata_fn: Optional[
-        Callable[[SolidExecutionContext, Mapping[str, Any]], Mapping[str, Any]]
+        Callable[[SolidExecutionContext, Mapping[str, Any]], Mapping[str, RawMetadataValue]]
     ] = None,
     io_manager_key: Optional[str] = None,
     node_info_to_asset_key: Callable[[Mapping[str, Any]], AssetKey] = _get_node_asset_key,
@@ -103,7 +104,6 @@ def _dbt_nodes_to_assets(
 
         node_name = node_info["name"]
         outs[node_name] = Out(
-            dagster_type=None,
             asset_key=node_info_to_asset_key(node_info),
             description=description,
             io_manager_key=io_manager_key,
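
With the tightened signature, a `runtime_metadata_fn` returns plain metadata values (strings, numbers, dicts, and similar) keyed by entry name. A hedged usage sketch; the field names are illustrative only:

def my_runtime_metadata_fn(context, node_info):
    # Values are raw metadata (str/int/float/...), not wrapped entry objects.
    return {"dbt_model": node_info["name"], "run_id": context.run_id}
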
@@ -61,7 +61,7 @@ def make_request(
         endpoint: str,
         data: Optional[Dict[str, Any]] = None,
         return_text: bool = False,
-    ) -> Dict[str, Any]:
+    ) -> Any:
         """
         Creates and sends a request to the desired dbt Cloud API endpoint.
@@ -181,7 +181,7 @@ def get_runs(
         order_by: Optional[str] = "-id",
         offset: int = 0,
         limit: int = 100,
-    ) -> List[Dict[str, any]]:
+    ) -> List[Dict[str, object]]:
         """
         Returns a list of runs from dbt Cloud. This can be optionally filtered to a specific job
         using the job_definition_id. It supports pagination using offset and limit as well and
@@ -381,8 +381,9 @@ def poll_run(
         See: https://docs.getdbt.com/dbt-cloud/api-v2#operation/getRunById for schema.
         """
 
-        if not href:
+        if href is None:
             href = self.get_run(run_id).get("href")
+        assert isinstance(href, str), "Run must have an href"
 
         poll_start = datetime.datetime.now()
         while True:
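
The change from `if not href:` to `if href is None:` is a narrowing fix: `not href` is also true for the empty string, and only the explicit `is None` check plus the assert convinces a type checker that `href` is a `str` afterward. A minimal standalone sketch:

from typing import Optional

def resolve_href(href: Optional[str], fallback: Optional[str]) -> str:
    if href is None:  # narrows Optional[str]; an empty string would pass through
        href = fallback
    assert isinstance(href, str), "expected an href"  # runtime and static guarantee
    return href
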
24 changes: 14 additions & 10 deletions python_modules/libraries/dagster-dbt/dagster_dbt/ops.py
@@ -34,7 +34,11 @@ def my_dbt_rpc_job():
     """
 
 
-@op(
+# NOTE: mypy fails to properly track the type of `_DEFAULT_OP_PROPS` items when they are
+# double-splatted, so we type-ignore the below op declarations.
+
+
+@op(  # type: ignore
     **_DEFAULT_OP_PROPS,
     config_schema={
         "yield_asset_events": Field(
@@ -67,7 +71,7 @@ def dbt_build_op(context):
     yield Output(dbt_output)
 
 
-@op(
+@op(  # type: ignore
     **_DEFAULT_OP_PROPS,
     config_schema={
         "yield_materializations": Field(
@@ -97,37 +101,37 @@ def dbt_run_op(context):
     yield Output(dbt_output)
 
 
-@op(**_DEFAULT_OP_PROPS)
+@op(**_DEFAULT_OP_PROPS)  # type: ignore
 def dbt_compile_op(context):
     return context.resources.dbt.compile()
 
 
-@op(**_DEFAULT_OP_PROPS)
+@op(**_DEFAULT_OP_PROPS)  # type: ignore
 def dbt_ls_op(context):
     return context.resources.dbt.ls()
 
 
-@op(**_DEFAULT_OP_PROPS)
+@op(**_DEFAULT_OP_PROPS)  # type: ignore
 def dbt_test_op(context):
     return context.resources.dbt.test()
 
 
-@op(**_DEFAULT_OP_PROPS)
+@op(**_DEFAULT_OP_PROPS)  # type: ignore
 def dbt_snapshot_op(context):
     return context.resources.dbt.snapshot()
 
 
-@op(**_DEFAULT_OP_PROPS)
+@op(**_DEFAULT_OP_PROPS)  # type: ignore
 def dbt_seed_op(context):
     return context.resources.dbt.seed()
 
 
-@op(**_DEFAULT_OP_PROPS)
+@op(**_DEFAULT_OP_PROPS)  # type: ignore
 def dbt_docs_generate_op(context):
     return context.resources.dbt.generate_docs()
 
 
-for op, cmd in [
+for dbt_op, cmd in [
     (dbt_build_op, "build"),
     (dbt_run_op, "run"),
     (dbt_compile_op, "compile"),
@@ -137,4 +141,4 @@ def dbt_docs_generate_op(context):
     (dbt_seed_op, "seed"),
     (dbt_docs_generate_op, "docs generate"),
 ]:
-    op.__doc__ = _get_doc(op.name, cmd)
+    dbt_op.__doc__ = _get_doc(dbt_op.name, cmd)
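
The NOTE in this file points at a general mypy limitation: when a dict's values are inferred as a union, every double-splatted argument carries that union type rather than its per-key type. A standalone illustration with hypothetical names:

from typing import Dict, Set, Union

props: Dict[str, Union[str, Set[str]]] = {
    "name": "dbt_build",
    "required_resource_keys": {"dbt"},
}

def make_op(name: str, required_resource_keys: Set[str]) -> None:
    ...

# mypy flags the splat: the value type is the union, not "str" -- hence the ignores above.
make_op(**props)
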
@@ -1,4 +1,5 @@
 import json
+import logging
 import platform
 import sys
 import time
@@ -140,8 +141,8 @@ def jsonrpc_version(self) -> str:
         return self._jsonrpc_version
 
     @property
-    def logger(self) -> Optional[Any]:
-        """Optional[Any]: A property for injecting a logger dependency."""
+    def logger(self) -> logging.Logger:
+        """logging.Logger: A property for injecting a logger dependency."""
         return self._logger
 
     @property
@@ -518,10 +519,11 @@ def __init__(
         self.poll_interval = poll_interval
 
     def _get_result(self, data: Optional[str] = None) -> DbtRpcOutput:
-        """Sends a request to the dbt RPC server and continuously polls for the status of a request until the state is ``success``."""
+        """Sends a request to the dbt RPC server and continuously polls for the status of a request
+        until the state is ``success``."""
 
         out = super()._get_result(data)
-        request_token = out.result.get("request_token")
+        request_token: str = check.not_none(out.result.get("request_token"))
 
         logs_start = 0
 
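
`check.not_none` both fails fast and narrows the static type of the `.get(...)` result. A rough local equivalent, not dagster's exact implementation:

from typing import Optional, TypeVar

T = TypeVar("T")

def not_none(value: Optional[T]) -> T:
    # Raise early on None and hand the type checker a non-Optional result.
    if value is None:
        raise ValueError("expected a non-None value")
    return value

request_token: str = not_none({"request_token": "abc123"}.get("request_token"))
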
@@ -907,7 +907,7 @@ def create_dbt_rpc_run_sql_solid(
         tags={"kind": "dbt"},
         **kwargs,
     )
-    def _dbt_rpc_run_sql(context: SolidExecutionContext, sql: String) -> DataFrame:
+    def _dbt_rpc_run_sql(context: SolidExecutionContext, sql: String) -> pd.DataFrame:
         out = context.resources.dbt_rpc.run_sql(sql=sql, name=context.solid_config["name"])
         context.log.debug(out.response.text)
         raise_for_rpc_error(context, out.response)
4 changes: 2 additions & 2 deletions python_modules/libraries/dagster-dbt/dagster_dbt/rpc/utils.py
@@ -10,10 +10,10 @@
 from dagster.core.execution.context.compute import SolidExecutionContext
 
 
-def fmt_rpc_logs(logs: List[Dict]) -> Dict[int, str]:
+def fmt_rpc_logs(logs: List[Dict[str, str]]) -> Dict[int, str]:
     d = defaultdict(list)
     for log in logs:
-        levelname = log.get("levelname")
+        levelname = log["levelname"]
         d[getattr(logging, levelname)].append(
             f"{log.get('timestamp')} - {levelname} - {log.get('message')}"
         )
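
Indexing instead of `.get` is a typing fix as much as anything: `.get` returns `Optional[str]`, which `getattr(logging, ...)` cannot accept, while `log["levelname"]` returns `str` and raises `KeyError` if the key is absent. A small sketch:

import logging
from typing import Dict, Optional

log: Dict[str, str] = {"levelname": "INFO"}
maybe_level: Optional[str] = log.get("levelname")  # Optional -- needs a None check
level_name: str = log["levelname"]                 # str -- may raise KeyError
print(getattr(logging, level_name))                # 20
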
23 changes: 13 additions & 10 deletions python_modules/libraries/dagster-dbt/dagster_dbt/utils.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union
+from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union, cast
 
 import dateutil
 
@@ -29,7 +29,7 @@ def _node_result_to_metadata(node_result: Dict[str, Any]) -> Mapping[str, RawMetadataValue]:
 
 
 def _timing_to_metadata(timings: List[Dict[str, Any]]) -> Mapping[str, RawMetadataValue]:
-    metadata = {}
+    metadata: Dict[str, RawMetadataValue] = {}
     for timing in timings:
         if timing["name"] == "execute":
             desc = "Execution"
@@ -38,8 +38,9 @@ def _timing_to_metadata(timings: List[Dict[str, Any]]) -> Mapping[str, RawMetadataValue]:
         else:
             continue
 
-        started_at = dateutil.parser.isoparse(timing["started_at"])
-        completed_at = dateutil.parser.isoparse(timing["completed_at"])
+        # dateutil does not properly expose its modules to static checkers
+        started_at = dateutil.parser.isoparse(timing["started_at"])  # type: ignore
+        completed_at = dateutil.parser.isoparse(timing["completed_at"])  # type: ignore
         duration = completed_at - started_at
         metadata.update(
             {
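
A hedged alternative to the `# type: ignore`s above is importing the submodule explicitly, which static checkers can resolve:

from dateutil import parser  # resolvable by mypy, unlike attribute access on `dateutil`

started_at = parser.isoparse("2022-05-13T00:00:00+00:00")
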
@@ -131,11 +132,13 @@ def generate_events(
     """
 
     for result in dbt_output.result["results"]:
-        yield from result_to_events(
-            result,
-            docs_url=dbt_output.docs_url,
-            node_info_to_asset_key=node_info_to_asset_key,
-            manifest_json=manifest_json,
+        yield from check.not_none(
+            result_to_events(
+                result,
+                docs_url=dbt_output.docs_url,
+                node_info_to_asset_key=node_info_to_asset_key,
+                manifest_json=manifest_json,
+            )
         )
 
 
@@ -185,4 +188,4 @@ def my_dbt_rpc_job():
                 asset_key_prefix + info["unique_id"].split(".")
             ),
         ):
-            yield check.inst(event, AssetMaterialization)
+            yield check.inst(cast(AssetMaterialization, event), AssetMaterialization)
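
The added `cast` is static-only; pairing it with `check.inst` preserves the runtime guarantee. The general cast-then-check pattern, sketched with stdlib pieces:

from typing import Any, cast

def require_str(value: Any) -> str:
    narrowed = cast(str, value)        # informs the checker; no runtime effect
    assert isinstance(narrowed, str)   # runtime enforcement, as check.inst provides
    return narrowed
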
@@ -2,6 +2,7 @@
 import json
 import subprocess
 import time
+from typing import Set
 from urllib import request
 from urllib.error import URLError
 
@@ -30,7 +31,7 @@ def get_rpc_server_status():
         return json.load(resp)
 
 
-all_subprocs = set()
+all_subprocs: Set[subprocess.Popen] = set()
 
 
 def kill_all_subprocs():
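
Annotating `all_subprocs` is necessary because an empty `set()` literal gives mypy no element type to infer from. A minimal sketch:

import subprocess
from typing import Set

# Without the annotation, mypy reports "Need type annotation" for the empty set.
procs: Set[subprocess.Popen] = set()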
