refactor(databricks): add types to databricks.py (#12364)
### Summary & Motivation
Wrangle some of the complexity into types.

### How I Tested These Changes
pyright
rexledesma committed Feb 16, 2023
1 parent b592861 commit fbd6a8f
Showing 5 changed files with 91 additions and 77 deletions.
dagster_databricks/databricks.py
@@ -1,5 +1,7 @@
import base64
import logging
import time
from typing import Any, Mapping, Optional

import dagster
import dagster._check as check
@@ -12,9 +14,9 @@
import dagster_databricks

from .types import (
DATABRICKS_RUN_TERMINATED_STATES,
DatabricksRunLifeCycleState,
DatabricksRunResultState,
DatabricksRunState,
)

# wait at most 24 hours by default for run execution
@@ -28,7 +30,7 @@ class DatabricksError(Exception):
class DatabricksClient:
"""A thin wrapper over the Databricks REST API."""

def __init__(self, host, token, workspace_id=None):
def __init__(self, host: str, token: str, workspace_id: Optional[str] = None):
self.host = host
self.workspace_id = workspace_id

@@ -77,37 +79,38 @@ def op1(context):
"""
return self._api_client

def submit_run(self, *args, **kwargs):
def submit_run(self, *args, **kwargs) -> int:
"""Submit a run directly to the 'Runs Submit' API."""
return self.client.jobs.submit_run(*args, **kwargs)["run_id"] # pylint: disable=no-member
return self.client.jobs.submit_run(*args, **kwargs)["run_id"]

def read_file(self, dbfs_path, block_size=1024**2):
def read_file(self, dbfs_path: str, block_size: int = 1024**2) -> bytes:
"""Read a file from DBFS to a **byte string**."""
if dbfs_path.startswith("dbfs://"):
dbfs_path = dbfs_path[7:]

data = b""
bytes_read = 0
jdoc = self.client.dbfs.read(path=dbfs_path, length=block_size) # pylint: disable=no-member

jdoc = self.client.dbfs.read(path=dbfs_path, length=block_size)
data += base64.b64decode(jdoc["data"])
while jdoc["bytes_read"] == block_size:
bytes_read += jdoc["bytes_read"]
jdoc = self.client.dbfs.read( # pylint: disable=no-member
path=dbfs_path, offset=bytes_read, length=block_size
)
jdoc = self.client.dbfs.read(path=dbfs_path, offset=bytes_read, length=block_size)
data += base64.b64decode(jdoc["data"])

return data
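
As a usage sketch of the blocked read above (the DBFS path is a made-up example; host and token are the placeholder values used in the tests further down):

```python
from dagster_databricks.databricks import DatabricksClient

client = DatabricksClient(host="https://uksouth.azuredatabricks.net", token="super-secret-token")

# read_file strips an optional "dbfs://" prefix, then keeps requesting
# block_size-sized chunks until the API returns a short block.
payload: bytes = client.read_file("dbfs://tmp/output/results.json", block_size=1024**2)
print(f"read {len(payload)} bytes")
```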

def put_file(self, file_obj, dbfs_path, overwrite=False, block_size=1024**2):
def put_file(
self, file_obj, dbfs_path: str, overwrite: bool = False, block_size: int = 1024**2
) -> None:
"""Upload an arbitrary large file to DBFS.
This doesn't use the DBFS `Put` API because that endpoint is limited to 1MB.
"""
if dbfs_path.startswith("dbfs://"):
dbfs_path = dbfs_path[7:]
create_response = self.client.dbfs.create( # pylint: disable=no-member
path=dbfs_path, overwrite=overwrite
)

create_response = self.client.dbfs.create(path=dbfs_path, overwrite=overwrite)
handle = create_response["handle"]

block = file_obj.read(block_size)
@@ -118,52 +121,39 @@ def put_file(self, file_obj, dbfs_path, overwrite=False, block_size=1024**2):

self.client.dbfs.close(handle=handle) # pylint: disable=no-member
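
The middle of put_file is collapsed above, but the endpoint sequence it wraps is the standard DBFS streaming upload: create a handle, append base64-encoded blocks, close the handle. A rough sketch of that pattern, assuming databricks-cli's DbfsService method names (create/add_block/close):

```python
import base64
from typing import BinaryIO


def streaming_upload(dbfs, file_obj: BinaryIO, dbfs_path: str, block_size: int = 1024**2) -> None:
    # Open a streaming handle; each add_block call appends one base64-encoded chunk.
    handle = dbfs.create(path=dbfs_path, overwrite=True)["handle"]
    block = file_obj.read(block_size)
    while block:
        dbfs.add_block(handle=handle, data=base64.b64encode(block).decode("utf-8"))
        block = file_obj.read(block_size)
    dbfs.close(handle=handle)
```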

def get_run(self, databricks_run_id):
return self.client.jobs.get_run(databricks_run_id) # pylint: disable=no-member
def get_run(self, databricks_run_id: int) -> Mapping[str, Any]:
return self.client.jobs.get_run(databricks_run_id)

def get_run_state(self, databricks_run_id):
"""Get the state of a run by Databricks run ID (_not_ dagster run ID).
def get_run_state(self, databricks_run_id: int) -> "DatabricksRunState":
"""Get the state of a run by Databricks run ID.
Return a `DatabricksRunState` object. Note that the `result_state`
attribute may be `None` if the run hasn't yet terminated.
"""
run = self.get_run(databricks_run_id)
state = run["state"]
result_state = state.get("result_state")
if result_state:
result_state = DatabricksRunResultState(result_state)
result_state = (
DatabricksRunResultState(state.get("result_state"))
if state.get("result_state")
else None
)

return DatabricksRunState(
life_cycle_state=DatabricksRunLifeCycleState(state["life_cycle_state"]),
result_state=result_state,
state_message=state["state_message"],
)
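
To make the Optional result_state concrete, a typical polling consumer of get_run_state looks like this sketch (the run id and sleep interval are illustrative, not from the commit):

```python
import time

from dagster_databricks.databricks import DatabricksClient, DatabricksError

client = DatabricksClient(host="https://uksouth.azuredatabricks.net", token="super-secret-token")

run_state = client.get_run_state(databricks_run_id=42)
while not run_state.has_terminated():
    time.sleep(5)
    run_state = client.get_run_state(databricks_run_id=42)

# result_state is only populated once the run has terminated.
if not run_state.is_successful():
    raise DatabricksError(f"Run failed: {run_state.state_message}")
```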


class DatabricksRunState:
"""Represents the state of a Databricks job run."""

def __init__(self, life_cycle_state, result_state, state_message):
self.life_cycle_state = life_cycle_state
self.result_state = result_state
self.state_message = state_message

def has_terminated(self):
"""Has the job terminated?"""
return self.life_cycle_state in DATABRICKS_RUN_TERMINATED_STATES

def is_successful(self):
"""Was the job successful?"""
return self.result_state == DatabricksRunResultState.Success

def __repr__(self):
return str(self.__dict__)


class DatabricksJobRunner:
"""Submits jobs created using Dagster config to Databricks, and monitors their progress."""

def __init__(
self, host, token, poll_interval_sec=5, max_wait_time_sec=DEFAULT_RUN_MAX_WAIT_TIME_SEC
self,
host: str,
token: str,
poll_interval_sec: float = 5,
max_wait_time_sec: int = DEFAULT_RUN_MAX_WAIT_TIME_SEC,
):
"""Args:
@@ -175,14 +165,14 @@ def __init__(
self.poll_interval_sec = check.numeric_param(poll_interval_sec, "poll_interval_sec")
self.max_wait_time_sec = check.int_param(max_wait_time_sec, "max_wait_time_sec")

self._client = DatabricksClient(host=self.host, token=self.token)
self._client: DatabricksClient = DatabricksClient(host=self.host, token=self.token)

@property
def client(self):
def client(self) -> DatabricksClient:
"""Return the underlying `DatabricksClient` object."""
return self._client

def submit_run(self, run_config, task):
def submit_run(self, run_config: Mapping[str, Any], task: Mapping[str, Any]) -> int:
"""Submit a new run using the 'Runs submit' API."""
existing_cluster_id = run_config["cluster"].get("existing")

@@ -264,7 +254,7 @@ def submit_run(self, run_config, task):
)
return self.client.submit_run(**config)

def retrieve_logs_for_run_id(self, log, databricks_run_id):
def retrieve_logs_for_run_id(self, log: logging.Logger, databricks_run_id: int):
"""Retrieve the stdout and stderr logs for a run."""
api_client = self.client.client
run = api_client.jobs.get_run(databricks_run_id) # pylint: disable=no-member
@@ -290,8 +280,14 @@ def retrieve_logs_for_run_id(self, log, databricks_run_id):
return stdout, stderr

def wait_for_dbfs_logs(
self, log, prefix, cluster_id, filename, waiter_delay=10, waiter_max_attempts=10
):
self,
log: logging.Logger,
prefix,
cluster_id,
filename,
waiter_delay: int = 10,
waiter_max_attempts: int = 10,
) -> Optional[str]:
"""Attempt up to `waiter_max_attempts` attempts to get logs from DBFS."""
path = "/".join([prefix, cluster_id, "driver", filename])
log.info("Retrieving logs from {}".format(path))
@@ -305,7 +301,9 @@ def wait_for_dbfs_logs(
time.sleep(waiter_delay)
log.warn("Could not retrieve cluster logs!")
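
The retry loop between the log line above and the final warning is collapsed, but its shape is the usual bounded poll: try to read the log file, sleep waiter_delay seconds on failure, give up after waiter_max_attempts tries. A sketch under that assumption (read_file comes from the DatabricksClient shown earlier; the helper name here is hypothetical):

```python
import logging
import time
from typing import Optional

from dagster_databricks.databricks import DatabricksClient


def fetch_dbfs_log(
    client: DatabricksClient,
    log: logging.Logger,
    path: str,
    waiter_delay: int = 10,
    waiter_max_attempts: int = 10,
) -> Optional[str]:
    # Bounded retries: DBFS log delivery lags cluster activity, so poll for a while.
    for _ in range(waiter_max_attempts):
        try:
            return client.read_file(path).decode("utf-8")
        except Exception:
            time.sleep(waiter_delay)
    log.warning("Could not retrieve cluster logs!")
    return None
```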

def wait_for_run_to_complete(self, log, databricks_run_id, verbose_logs=True):
def wait_for_run_to_complete(
self, log: logging.Logger, databricks_run_id: int, verbose_logs: bool = True
):
return wait_for_run_to_complete(
self.client,
log,
@@ -317,8 +315,8 @@ def wait_for_run_to_complete(self, log, databricks_run_id, verbose_logs=True):


def poll_run_state(
client,
log,
client: DatabricksClient,
log: logging.Logger,
start_poll_time: float,
databricks_run_id: int,
max_wait_time_sec: float,
@@ -350,8 +348,13 @@ def poll_run_state(


def wait_for_run_to_complete(
client, log, databricks_run_id, poll_interval_sec, max_wait_time_sec, verbose_logs=True
):
client: DatabricksClient,
log: logging.Logger,
databricks_run_id: int,
poll_interval_sec: float,
max_wait_time_sec: int,
verbose_logs: bool = True,
) -> None:
"""Wait for a Databricks run to complete."""
check.int_param(databricks_run_id, "databricks_run_id")
log.info("Waiting for Databricks run %s to complete..." % databricks_run_id)
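
Taken together, the runner-level API is used roughly like the sketch below. The cluster/existing key comes from submit_run above; the remaining run_config and task payloads are abbreviated assumptions, not the package's full config schema:

```python
import logging

from dagster_databricks.databricks import DatabricksJobRunner

log = logging.getLogger(__name__)

runner = DatabricksJobRunner(
    host="https://uksouth.azuredatabricks.net",
    token="super-secret-token",
    poll_interval_sec=10,
)

# Hypothetical payloads: "cluster" -> "existing" is read by submit_run above;
# the task mapping loosely follows the Databricks Runs Submit request shape.
run_id = runner.submit_run(
    run_config={"run_name": "example-run", "cluster": {"existing": "1234-567890-abcde123"}},
    task={"spark_python_task": {"python_file": "dbfs:/scripts/job.py"}},
)
runner.wait_for_run_to_complete(log, run_id, verbose_logs=True)
```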
dagster_databricks/types.py
@@ -1,21 +1,10 @@
"""Types returned by the Databricks API.
"""
from enum import Enum
from typing import NamedTuple, Optional

from enum import Enum as PyEnum


class DatabricksRunResultState(PyEnum):
"""The result state of the run.
If life_cycle_state = TERMINATED: if the run had a task, the result is guaranteed to be
available, and it indicates the result of the task.
If life_cycle_state = PENDING, RUNNING, or SKIPPED, the result state is not available.
If life_cycle_state = TERMINATING or life_cycle_state = INTERNAL_ERROR: the result state
is available if the run had a task and managed to start it.
Once available, the result state never changes.
See https://docs.databricks.com/dev-tools/api/latest/jobs.html#runresultstate.
class DatabricksRunResultState(str, Enum):
"""
See https://docs.databricks.com/dev-tools/api/2.0/jobs.html#runresultstate.
"""

Success = "SUCCESS"
@@ -24,10 +13,9 @@ class DatabricksRunResultState(PyEnum):
Canceled = "CANCELED"


class DatabricksRunLifeCycleState(PyEnum):
"""The life cycle state of a run.
See https://docs.databricks.com/dev-tools/api/latest/jobs.html#runlifecyclestate.
class DatabricksRunLifeCycleState(str, Enum):
"""
See https://docs.databricks.com/dev-tools/api/2.0/jobs.html#jobsrunlifecyclestate.
"""

Pending = "PENDING"
@@ -43,3 +31,19 @@
DatabricksRunLifeCycleState.Terminated,
DatabricksRunLifeCycleState.InternalError,
]
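
One effect of mixing in str (the new definitions above inherit from both str and Enum) is that members can be constructed from, and compare equal to, the raw strings the REST API returns. A quick sketch:

```python
from dagster_databricks.types import DatabricksRunLifeCycleState, DatabricksRunResultState

# Construction from an API string still round-trips to the enum member...
assert DatabricksRunLifeCycleState("TERMINATED") is DatabricksRunLifeCycleState.Terminated

# ...and the str mixin additionally makes direct string comparison hold.
assert DatabricksRunLifeCycleState.Terminated == "TERMINATED"
assert DatabricksRunResultState("SUCCESS") == DatabricksRunResultState.Success
```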


class DatabricksRunState(NamedTuple):
"""Represents the state of a Databricks job run."""

life_cycle_state: "DatabricksRunLifeCycleState"
result_state: Optional["DatabricksRunResultState"]
state_message: str

def has_terminated(self) -> bool:
"""Has the job terminated?"""
return self.life_cycle_state in DATABRICKS_RUN_TERMINATED_STATES

def is_successful(self) -> bool:
"""Was the job successful?"""
return self.result_state == DatabricksRunResultState.Success
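
Because the state is now a NamedTuple, it can be constructed directly in tests, roughly as the updated test imports below suggest; a minimal sketch:

```python
from dagster_databricks.types import (
    DatabricksRunLifeCycleState,
    DatabricksRunResultState,
    DatabricksRunState,
)

state = DatabricksRunState(
    life_cycle_state=DatabricksRunLifeCycleState.Terminated,
    result_state=DatabricksRunResultState.Success,
    state_message="",
)
assert state.has_terminated()
assert state.is_successful()
```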
@@ -5,8 +5,12 @@
import dagster_pyspark
import pytest
from dagster._utils.test import create_test_pipeline_execution_context
from dagster_databricks.databricks import DatabricksError, DatabricksJobRunner, DatabricksRunState
from dagster_databricks.types import DatabricksRunLifeCycleState, DatabricksRunResultState
from dagster_databricks.databricks import DatabricksError, DatabricksJobRunner
from dagster_databricks.types import (
DatabricksRunLifeCycleState,
DatabricksRunResultState,
DatabricksRunState,
)

HOST = "https://uksouth.azuredatabricks.net"
TOKEN = "super-secret-token"
@@ -3,9 +3,12 @@
import pytest
from dagster import job
from dagster_databricks import create_databricks_job_op, databricks_client
from dagster_databricks.databricks import DatabricksRunState
from dagster_databricks.ops import create_ui_url
from dagster_databricks.types import DatabricksRunLifeCycleState, DatabricksRunResultState
from dagster_databricks.types import (
DatabricksRunLifeCycleState,
DatabricksRunResultState,
DatabricksRunState,
)


@pytest.mark.parametrize("job_creator", [create_databricks_job_op])
@@ -14,7 +14,7 @@
DatabricksRunResultState,
databricks_pyspark_step_launcher,
)
from dagster_databricks.databricks import DatabricksRunState
from dagster_databricks.types import DatabricksRunState
from dagster_pyspark import DataFrame, pyspark_resource
from pyspark.sql import Row
from pyspark.sql.types import IntegerType, StringType, StructField, StructType