Commit

[db io managers] add default_load_type (#12356)

jamiedemaria committed Feb 17, 2023
1 parent 9a3a8e2 commit c39e007
Showing 8 changed files with 121 additions and 17 deletions.
18 changes: 16 additions & 2 deletions python_modules/dagster/dagster/_core/storage/db_io_manager.py
@@ -96,6 +96,7 @@ def __init__(
        database: str,
        schema: Optional[str] = None,
        io_manager_name: Optional[str] = None,
+        default_load_type: Optional[Type] = None,
    ):
        self._handlers_by_type: Dict[Optional[Type], DbTypeHandler] = {}
        self._io_manager_name = io_manager_name or self.__class__.__name__
@@ -114,6 +115,14 @@ def __init__(
        self._db_client = db_client
        self._database = database
        self._schema = schema
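+        # With exactly one handler that supports exactly one type, that type is
+        # the only sensible default, so infer it when none is given explicitly.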
+        if (
+            default_load_type is None
+            and len(type_handlers) == 1
+            and len(type_handlers[0].supported_types) == 1
+        ):
+            self._default_load_type = type_handlers[0].supported_types[0]
+        else:
+            self._default_load_type = default_load_type

    def handle_output(self, context: OutputContext, obj: object) -> None:
        table_slice = self._get_table_slice(context, context)
@@ -147,12 +156,17 @@ def handle_output(self, context: OutputContext, obj: object) -> None:

    def load_input(self, context: InputContext) -> object:
        obj_type = context.dagster_type.typing_type
-        self._check_supported_type(obj_type)
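+        # Inputs without a type annotation surface here as typing.Any; route
+        # those to the configured default load type instead.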
+        if obj_type is Any and self._default_load_type is not None:
+            load_type = self._default_load_type
+        else:
+            load_type = obj_type
+
+        self._check_supported_type(load_type)

        table_slice = self._get_table_slice(context, cast(OutputContext, context.upstream_output))

        with self._db_client.connect(context, table_slice) as conn:
-            return self._handlers_by_type[obj_type].load_input(context, table_slice, conn)
+            return self._handlers_by_type[load_type].load_input(context, table_slice, conn)

def _get_partition_value(
self, partition_def: PartitionsDefinition, partition_key: str
@@ -2,7 +2,7 @@
from unittest.mock import MagicMock

import pytest
-from dagster import AssetKey, InputContext, OutputContext, build_output_context
+from dagster import AssetKey, InputContext, OutputContext, asset, build_output_context
from dagster._check import CheckError
from dagster._core.definitions.partition import StaticPartitionsDefinition
from dagster._core.definitions.time_window_partitions import DailyPartitionsDefinition, TimeWindow
@@ -436,3 +436,66 @@ def test_non_supported_type():
        CheckError, match="DbIOManager does not have a handler for type '<class 'str'>'"
    ):
        manager.handle_output(output_context, "a_string")


+def test_default_load_type():
+    handler = IntHandler()
+    db_client = MagicMock(spec=DbClient, get_select_statement=MagicMock(return_value=""))
+    manager = DbIOManager(
+        type_handlers=[handler],
+        database=resource_config["database"],
+        db_client=db_client,
+        default_load_type=int,
+    )
+    asset_key = AssetKey(["schema1", "table1"])
+    output_context = build_output_context(asset_key=asset_key, resource_config=resource_config)
+
+    @asset
+    def asset1():
+        ...
+
+    input_context = MagicMock(
+        upstream_output=output_context,
+        resource_config=resource_config,
+        dagster_type=asset1.op.outs["result"].dagster_type,
+        asset_key=asset_key,
+        has_asset_partitions=False,
+        metadata=None,
+    )
+
+    manager.handle_output(output_context, 1)
+    assert len(handler.handle_output_calls) == 1
+
+    assert manager.load_input(input_context) == 7
+
+    assert len(handler.handle_input_calls) == 1
+
+    assert handler.handle_input_calls[0][1] == TableSlice(
+        database="database_abc", schema="schema1", table="table1", partition_dimensions=[]
+    )
+
+
+def test_default_load_type_determination():
+    int_handler = IntHandler()
+    string_handler = StringHandler()
+    db_client = MagicMock(spec=DbClient, get_select_statement=MagicMock(return_value=""))
+
+    manager = DbIOManager(
+        type_handlers=[int_handler], database=resource_config["database"], db_client=db_client
+    )
+    assert manager._default_load_type == int
+
+    manager = DbIOManager(
+        type_handlers=[int_handler, string_handler],
+        database=resource_config["database"],
+        db_client=db_client,
+    )
+    assert manager._default_load_type is None
+
+    manager = DbIOManager(
+        type_handlers=[int_handler, string_handler],
+        database=resource_config["database"],
+        db_client=db_client,
+        default_load_type=int,
+    )
+    assert manager._default_load_type == int
@@ -69,9 +69,13 @@ def supported_types(self):
        return [pd.DataFrame]


-duckdb_pandas_io_manager = build_duckdb_io_manager([DuckDBPandasTypeHandler()])
+duckdb_pandas_io_manager = build_duckdb_io_manager(
+    [DuckDBPandasTypeHandler()], default_load_type=pd.DataFrame
+)
duckdb_pandas_io_manager.__doc__ = """
-An IO manager definition that reads inputs from and writes pandas dataframes to DuckDB.
+An IO manager definition that reads inputs from and writes Pandas DataFrames to DuckDB. When
+using the duckdb_pandas_io_manager, any inputs and outputs without type annotations will be loaded
+as Pandas DataFrames.
Returns:
IOManagerDefinition
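
For context, a minimal sketch of the behavior this default enables; the asset names and database path below are illustrative, not part of this commit:

    import pandas as pd

    from dagster import Definitions, asset
    from dagster_duckdb_pandas import duckdb_pandas_io_manager

    @asset
    def raw_data() -> pd.DataFrame:
        return pd.DataFrame({"id": [1, 2, 3], "value": [10, 20, 30]})

    @asset
    def summary(raw_data):
        # raw_data has no type annotation, so it is loaded as a pd.DataFrame
        # via the io manager's default_load_type
        return raw_data.describe()

    defs = Definitions(
        assets=[raw_data, summary],
        resources={
            "io_manager": duckdb_pandas_io_manager.configured({"database": "example.duckdb"})
        },
    )

Before this change, an untyped input like raw_data surfaced as Any and would fail the supported-type check.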
@@ -78,9 +78,13 @@ def supported_types(self):
        return [pyspark.sql.DataFrame]


-duckdb_pyspark_io_manager = build_duckdb_io_manager([DuckDBPySparkTypeHandler()])
+duckdb_pyspark_io_manager = build_duckdb_io_manager(
+    [DuckDBPySparkTypeHandler()], default_load_type=pyspark.sql.DataFrame
+)
duckdb_pyspark_io_manager.__doc__ = """
-An IO manager definition that reads inputs from and writes PySpark DataFrames to DuckDB.
+An IO manager definition that reads inputs from and writes PySpark DataFrames to DuckDB. When
+using the duckdb_pyspark_io_manager, any inputs and outputs without type annotations will be loaded
+as PySpark DataFrames.
Returns:
IOManagerDefinition
@@ -1,5 +1,5 @@
from contextlib import contextmanager
-from typing import Sequence, cast
+from typing import Optional, Sequence, Type, cast

import duckdb
from dagster import Field, IOManagerDefinition, OutputContext, StringSource, io_manager
@@ -16,13 +16,17 @@
DUCKDB_DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"


-def build_duckdb_io_manager(type_handlers: Sequence[DbTypeHandler]) -> IOManagerDefinition:
+def build_duckdb_io_manager(
+    type_handlers: Sequence[DbTypeHandler], default_load_type: Optional[Type] = None
+) -> IOManagerDefinition:
"""
Builds an IO manager definition that reads inputs from and writes outputs to DuckDB.
Args:
type_handlers (Sequence[DbTypeHandler]): Each handler defines how to translate between
-DuckDB tables and an in-memory type - e.g. a Pandas DataFrame.
+DuckDB tables and an in-memory type - e.g. a Pandas DataFrame. If only
+one DbTypeHandler is provided and it supports a single type, that type will be
+used as the default_load_type.
+default_load_type (Type): When an input has no type annotation, load it as this type.
Returns:
IOManagerDefinition
@@ -97,6 +101,7 @@ def duckdb_io_manager(init_context):
io_manager_name="DuckDBIOManager",
database=init_context.resource_config["database"],
schema=init_context.resource_config.get("schema"),
default_load_type=default_load_type,
)

    return duckdb_io_manager
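
As a sketch, the builder can now be called with the new argument directly; with a single one-type handler, passing default_load_type is optional because it is inferred. The variable name below is illustrative:

    import pandas as pd

    from dagster_duckdb import build_duckdb_io_manager
    from dagster_duckdb_pandas import DuckDBPandasTypeHandler

    # Explicit default_load_type; equivalent to omitting it here, because the
    # single DuckDBPandasTypeHandler supports exactly one type (pd.DataFrame).
    my_duckdb_io_manager = build_duckdb_io_manager(
        [DuckDBPandasTypeHandler()], default_load_type=pd.DataFrame
    )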
@@ -98,9 +98,14 @@ def supported_types(self):
        return [pd.DataFrame]


-snowflake_pandas_io_manager = build_snowflake_io_manager([SnowflakePandasTypeHandler()])
+snowflake_pandas_io_manager = build_snowflake_io_manager(
+    [SnowflakePandasTypeHandler()], default_load_type=pd.DataFrame
+)
snowflake_pandas_io_manager.__doc__ = """
-An IO manager definition that reads inputs from and writes pandas dataframes to Snowflake.
+An IO manager definition that reads inputs from and writes Pandas DataFrames to Snowflake. When
+using the snowflake_pandas_io_manager, any inputs and outputs without type annotations will be loaded
+as Pandas DataFrames.
Returns:
IOManagerDefinition
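
A minimal sketch of the same pattern with the Snowflake variant; the connection values are placeholders, not from this commit:

    import pandas as pd

    from dagster import Definitions, asset
    from dagster_snowflake_pandas import snowflake_pandas_io_manager

    @asset
    def events() -> pd.DataFrame:
        return pd.DataFrame({"user": ["a", "b", "a"], "clicks": [1, 2, 3]})

    @asset
    def clicks_per_user(events):
        # Untyped input: loaded as a pd.DataFrame via default_load_type.
        return events.groupby("user").sum()

    defs = Definitions(
        assets=[events, clicks_per_user],
        resources={
            "io_manager": snowflake_pandas_io_manager.configured(
                {
                    "account": {"env": "SNOWFLAKE_ACCOUNT"},
                    "user": {"env": "SNOWFLAKE_USER"},
                    "password": {"env": "SNOWFLAKE_PASSWORD"},
                    "database": "EXAMPLE_DB",
                    "schema": "PUBLIC",
                }
            )
        },
    )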
@@ -97,9 +97,13 @@ def supported_types(self):
        return [DataFrame]


-snowflake_pyspark_io_manager = build_snowflake_io_manager([SnowflakePySparkTypeHandler()])
+snowflake_pyspark_io_manager = build_snowflake_io_manager(
+    [SnowflakePySparkTypeHandler()], default_load_type=DataFrame
+)
snowflake_pyspark_io_manager.__doc__ = """
-An IO manager definition that reads inputs from and writes PySpark DataFrames to Snowflake.
+An IO manager definition that reads inputs from and writes PySpark DataFrames to Snowflake. When
+using the snowflake_pyspark_io_manager, any inputs and outputs without type annotations will be loaded
+as PySpark DataFrames.
Returns:
IOManagerDefinition
@@ -1,5 +1,5 @@
from contextlib import contextmanager
-from typing import Mapping, Sequence, cast
+from typing import Mapping, Optional, Sequence, Type, cast

from dagster import Field, IOManagerDefinition, OutputContext, StringSource, io_manager
from dagster._core.definitions.time_window_partitions import TimeWindow
@@ -17,13 +17,17 @@
SNOWFLAKE_DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"


-def build_snowflake_io_manager(type_handlers: Sequence[DbTypeHandler]) -> IOManagerDefinition:
+def build_snowflake_io_manager(
+    type_handlers: Sequence[DbTypeHandler], default_load_type: Optional[Type] = None
+) -> IOManagerDefinition:
"""
Builds an IO manager definition that reads inputs from and writes outputs to Snowflake.
Args:
type_handlers (Sequence[DbTypeHandler]): Each handler defines how to translate between
-slices of Snowflake tables and an in-memory type - e.g. a Pandas DataFrame.
+slices of Snowflake tables and an in-memory type - e.g. a Pandas DataFrame. If only
+one DbTypeHandler is provided and it supports a single type, that type will be
+used as the default_load_type.
+default_load_type (Type): When an input has no type annotation, load it as this type.
Returns:
IOManagerDefinition
@@ -134,6 +138,7 @@ def snowflake_io_manager(init_context):
io_manager_name="SnowflakeIOManager",
database=init_context.resource_config["database"],
schema=init_context.resource_config.get("schema"),
default_load_type=default_load_type,
)

    return snowflake_io_manager
