[db io managers] connection refactor (#12258)
jamiedemaria committed Feb 17, 2023
1 parent e660cd8 commit 074ae45
Showing 11 changed files with 187 additions and 152 deletions.
34 changes: 22 additions & 12 deletions python_modules/dagster/dagster/_core/storage/db_io_manager.py
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
+from contextlib import contextmanager
from typing import (
Any,
Dict,
@@ -50,12 +51,12 @@ class TableSlice(NamedTuple):
class DbTypeHandler(ABC, Generic[T]):
@abstractmethod
def handle_output(
-        self, context: OutputContext, table_slice: TableSlice, obj: T
+        self, context: OutputContext, table_slice: TableSlice, obj: T, connection
) -> Optional[Mapping[str, RawMetadataValue]]:
"""Stores the given object at the given table in the given schema."""

@abstractmethod
-    def load_input(self, context: InputContext, table_slice: TableSlice) -> T:
+    def load_input(self, context: InputContext, table_slice: TableSlice, connection) -> T:
"""Loads the contents of the given table in the given schema."""

@property
@@ -67,7 +68,7 @@ def supported_types(self) -> Sequence[Type[object]]:
class DbClient:
@staticmethod
@abstractmethod
-    def delete_table_slice(context: OutputContext, table_slice: TableSlice) -> None:
+    def delete_table_slice(context: OutputContext, table_slice: TableSlice, connection) -> None:
...

@staticmethod
@@ -76,7 +77,13 @@ def get_select_statement(table_slice: TableSlice) -> str:
...

@staticmethod
-    def ensure_schema_exists(context: OutputContext, table_slice: TableSlice) -> None:
+    @abstractmethod
+    def ensure_schema_exists(context: OutputContext, table_slice: TableSlice, connection) -> None:
        ...
+
+    @staticmethod
+    @contextmanager
+    def connect(context: Union[OutputContext, InputContext], table_slice: TableSlice):
+        ...
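
The new connect hook makes connection lifecycle the client's responsibility rather than each handler's. As orientation, not part of this commit, here is a minimal sketch of how a concrete client might satisfy the contract with a generator-based context manager; the DuckDB-flavored body is an assumption, and only the decorator stack and signature come from the abstract method above:

from contextlib import contextmanager

class MyDuckDbClient(DbClient):  # hypothetical concrete client
    @staticmethod
    @contextmanager
    def connect(context, table_slice):
        # open one cursor per output/input operation and always clean it up
        conn = _connect_duckdb(context).cursor()  # helper assumed from dagster_duckdb
        try:
            yield conn
        finally:
            conn.close()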


@@ -115,12 +122,14 @@ def handle_output(self, context: OutputContext, obj: object) -> None:
obj_type = type(obj)
self._check_supported_type(obj_type)

-            self._db_client.delete_table_slice(context, table_slice)
+            with self._db_client.connect(context, table_slice) as conn:
+                self._db_client.ensure_schema_exists(context, table_slice, conn)
+                self._db_client.delete_table_slice(context, table_slice, conn)

-            self._db_client.ensure_schema_exists(context, table_slice)
-            handler_metadata = (
-                self._handlers_by_type[obj_type].handle_output(context, table_slice, obj) or {}
-            )
+                handler_metadata = (
+                    self._handlers_by_type[obj_type].handle_output(context, table_slice, obj, conn)
+                    or {}
+                )
else:
check.invariant(
context.dagster_type.is_nothing,
@@ -140,9 +149,10 @@ def load_input(self, context: InputContext) -> object:
obj_type = context.dagster_type.typing_type
self._check_supported_type(obj_type)

-        return self._handlers_by_type[obj_type].load_input(
-            context, self._get_table_slice(context, cast(OutputContext, context.upstream_output))
-        )
+        table_slice = self._get_table_slice(context, cast(OutputContext, context.upstream_output))
+
+        with self._db_client.connect(context, table_slice) as conn:
+            return self._handlers_by_type[obj_type].load_input(context, table_slice, conn)

def _get_partition_value(
self, partition_def: PartitionsDefinition, partition_key: str
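Taken together, the hunks above move connection ownership into the io manager: handle_output and load_input each open exactly one connection and thread it through ensure_schema_exists, delete_table_slice, and the type handler, where previously each client call and handler connected on its own. A type handler written against the new contract therefore just uses the connection it is handed; a minimal sketch, in which the FloatHandler class and its SQL are hypothetical and only the signatures come from this diff:

class FloatHandler(DbTypeHandler[float]):
    def handle_output(self, context: OutputContext, table_slice: TableSlice, obj: float, connection):
        # the io manager passes an already-open connection; no connect call here
        connection.execute(
            f"insert into {table_slice.schema}.{table_slice.table} values ({obj})"
        )

    def load_input(self, context: InputContext, table_slice: TableSlice, connection) -> float:
        row = connection.execute(
            f"select * from {table_slice.schema}.{table_slice.table}"
        ).fetchone()
        return row[0]

    @property
    def supported_types(self):
        return [float]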
@@ -31,10 +31,10 @@ def __init__(self):
self.handle_input_calls = []
self.handle_output_calls = []

-    def handle_output(self, context: OutputContext, table_slice: TableSlice, obj: int):
+    def handle_output(self, context: OutputContext, table_slice: TableSlice, obj: int, connection):
self.handle_output_calls.append((context, table_slice, obj))

-    def load_input(self, context: InputContext, table_slice: TableSlice) -> int:
+    def load_input(self, context: InputContext, table_slice: TableSlice, connection) -> int:
self.handle_input_calls.append((context, table_slice))
return 7

@@ -48,10 +48,10 @@ def __init__(self):
self.handle_input_calls = []
self.handle_output_calls = []

-    def handle_output(self, context: OutputContext, table_slice: TableSlice, obj: str):
+    def handle_output(self, context: OutputContext, table_slice: TableSlice, obj: str, connection):
self.handle_output_calls.append((context, table_slice, obj))

-    def load_input(self, context: InputContext, table_slice: TableSlice) -> str:
+    def load_input(self, context: InputContext, table_slice: TableSlice, connection) -> str:
self.handle_input_calls.append((context, table_slice))
return "8"

@@ -73,7 +73,10 @@ def build_db_io_manager(type_handlers, db_client, resource_config_override=None)

def test_asset_out():
handler = IntHandler()
-    db_client = MagicMock(spec=DbClient, get_select_statement=MagicMock(return_value=""))
+    connect_mock = MagicMock()
+    db_client = MagicMock(
+        spec=DbClient, get_select_statement=MagicMock(return_value=""), connect=connect_mock
+    )
manager = build_db_io_manager(type_handlers=[handler], db_client=db_client)
asset_key = AssetKey(["schema1", "table1"])
output_context = build_output_context(asset_key=asset_key, resource_config=resource_config)
@@ -93,15 +96,20 @@ def test_asset_out():
database="database_abc", schema="schema1", table="table1", partition_dimensions=[]
)
assert handler.handle_output_calls[0][1:] == (table_slice, 5)
-    db_client.delete_table_slice.assert_called_once_with(output_context, table_slice)
+    db_client.delete_table_slice.assert_called_once_with(
+        output_context, table_slice, connect_mock().__enter__()
+    )

assert len(handler.handle_input_calls) == 1
assert handler.handle_input_calls[0][1] == table_slice


def test_asset_out_columns():
handler = IntHandler()
-    db_client = MagicMock(spec=DbClient, get_select_statement=MagicMock(return_value=""))
+    connect_mock = MagicMock()
+    db_client = MagicMock(
+        spec=DbClient, get_select_statement=MagicMock(return_value=""), connect=connect_mock
+    )
manager = build_db_io_manager(type_handlers=[handler], db_client=db_client)
asset_key = AssetKey(["schema1", "table1"])
output_context = build_output_context(asset_key=asset_key, resource_config=resource_config)
@@ -121,7 +129,9 @@ def test_asset_out_columns():
database="database_abc", schema="schema1", table="table1", partition_dimensions=[]
)
assert handler.handle_output_calls[0][1:] == (table_slice, 5)
-    db_client.delete_table_slice.assert_called_once_with(output_context, table_slice)
+    db_client.delete_table_slice.assert_called_once_with(
+        output_context, table_slice, connect_mock().__enter__()
+    )

assert len(handler.handle_input_calls) == 1
assert handler.handle_input_calls[0][1] == TableSlice(
@@ -135,7 +145,10 @@

def test_asset_out_partitioned():
handler = IntHandler()
-    db_client = MagicMock(spec=DbClient, get_select_statement=MagicMock(return_value=""))
+    connect_mock = MagicMock()
+    db_client = MagicMock(
+        spec=DbClient, get_select_statement=MagicMock(return_value=""), connect=connect_mock
+    )
manager = build_db_io_manager(type_handlers=[handler], db_client=db_client)
asset_key = AssetKey(["schema1", "table1"])
partitions_def = DailyPartitionsDefinition(start_date="2020-01-02")
@@ -176,15 +189,20 @@ def test_asset_out_partitioned():
],
)
assert handler.handle_output_calls[0][1:] == (table_slice, 5)
-    db_client.delete_table_slice.assert_called_once_with(output_context, table_slice)
+    db_client.delete_table_slice.assert_called_once_with(
+        output_context, table_slice, connect_mock().__enter__()
+    )

assert len(handler.handle_input_calls) == 1
assert handler.handle_input_calls[0][1] == table_slice


def test_asset_out_static_partitioned():
handler = IntHandler()
-    db_client = MagicMock(spec=DbClient, get_select_statement=MagicMock(return_value=""))
+    connect_mock = MagicMock()
+    db_client = MagicMock(
+        spec=DbClient, get_select_statement=MagicMock(return_value=""), connect=connect_mock
+    )
manager = build_db_io_manager(type_handlers=[handler], db_client=db_client)
asset_key = AssetKey(["schema1", "table1"])
partitions_def = StaticPartitionsDefinition(["red", "yellow", "blue"])
@@ -220,7 +238,9 @@ def test_asset_out_static_partitioned():
],
)
assert handler.handle_output_calls[0][1:] == (table_slice, 5)
-    db_client.delete_table_slice.assert_called_once_with(output_context, table_slice)
+    db_client.delete_table_slice.assert_called_once_with(
+        output_context, table_slice, connect_mock().__enter__()
+    )

assert len(handler.handle_input_calls) == 1
assert handler.handle_input_calls[0][1] == table_slice
@@ -229,7 +249,10 @@ def test_asset_out_static_partitioned():
def test_different_output_and_input_types():
int_handler = IntHandler()
str_handler = StringHandler()
-    db_client = MagicMock(spec=DbClient, get_select_statement=MagicMock(return_value=""))
+    connect_mock = MagicMock()
+    db_client = MagicMock(
+        spec=DbClient, get_select_statement=MagicMock(return_value=""), connect=connect_mock
+    )
manager = build_db_io_manager(type_handlers=[int_handler, str_handler], db_client=db_client)
asset_key = AssetKey(["schema1", "table1"])
output_context = build_output_context(asset_key=asset_key, resource_config=resource_config)
@@ -240,7 +263,9 @@ def test_different_output_and_input_types():
database="database_abc", schema="schema1", table="table1", partition_dimensions=[]
)
assert int_handler.handle_output_calls[0][1:] == (table_slice, 5)
-    db_client.delete_table_slice.assert_called_once_with(output_context, table_slice)
+    db_client.delete_table_slice.assert_called_once_with(
+        output_context, table_slice, connect_mock().__enter__()
+    )

input_context = MagicMock(
asset_key=asset_key,
@@ -259,7 +284,10 @@

def test_non_asset_out():
handler = IntHandler()
-    db_client = MagicMock(spec=DbClient, get_select_statement=MagicMock(return_value=""))
+    connect_mock = MagicMock()
+    db_client = MagicMock(
+        spec=DbClient, get_select_statement=MagicMock(return_value=""), connect=connect_mock
+    )
manager = build_db_io_manager(type_handlers=[handler], db_client=db_client)
output_context = build_output_context(
name="table1", metadata={"schema": "schema1"}, resource_config=resource_config
@@ -280,7 +308,9 @@ def test_non_asset_out():
database="database_abc", schema="schema1", table="table1", partition_dimensions=[]
)
assert handler.handle_output_calls[0][1:] == (table_slice, 5)
-    db_client.delete_table_slice.assert_called_once_with(output_context, table_slice)
+    db_client.delete_table_slice.assert_called_once_with(
+        output_context, table_slice, connect_mock().__enter__()
+    )

assert len(handler.handle_input_calls) == 1
assert handler.handle_input_calls[0][1] == table_slice
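A note on the connect_mock().__enter__() assertions in these tests: calling a MagicMock returns the same child mock every time, and entering that child as a context manager returns the same __enter__ return value every time, so the test can recover exactly the conn object the io manager bound inside its with block. A standalone check of that behavior:

from unittest.mock import MagicMock

connect_mock = MagicMock()
with connect_mock("any", "args") as conn:
    pass  # conn is connect_mock.return_value.__enter__.return_value

# later calls yield the same objects, which is why the assertions can match
assert conn is connect_mock().__enter__()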
@@ -1,7 +1,7 @@
import pandas as pd
from dagster import InputContext, MetadataValue, OutputContext, TableColumn, TableSchema
from dagster._core.storage.db_io_manager import DbTypeHandler, TableSlice
-from dagster_duckdb.io_manager import DuckDbClient, _connect_duckdb, build_duckdb_io_manager
+from dagster_duckdb.io_manager import DuckDbClient, build_duckdb_io_manager


class DuckDBPandasTypeHandler(DbTypeHandler[pd.DataFrame]):
@@ -30,17 +30,19 @@ def my_repo():
"""

-    def handle_output(self, context: OutputContext, table_slice: TableSlice, obj: pd.DataFrame):
+    def handle_output(
+        self, context: OutputContext, table_slice: TableSlice, obj: pd.DataFrame, connection
+    ):
"""Stores the pandas DataFrame in duckdb."""
-        conn = _connect_duckdb(context).cursor()
-
-        conn.execute(
+        connection.execute(
f"create table if not exists {table_slice.schema}.{table_slice.table} as select * from"
" obj;"
)
-        if not conn.fetchall():
+        if not connection.fetchall():
            # table was not created, meaning it already exists; insert the data instead
conn.execute(f"insert into {table_slice.schema}.{table_slice.table} select * from obj")
connection.execute(
f"insert into {table_slice.schema}.{table_slice.table} select * from obj"
)

context.add_output_metadata(
{
@@ -56,10 +58,11 @@ def handle_output(self, context: OutputContext, table_slice: TableSlice, obj: pd
}
)

-    def load_input(self, context: InputContext, table_slice: TableSlice) -> pd.DataFrame:
+    def load_input(
+        self, context: InputContext, table_slice: TableSlice, connection
+    ) -> pd.DataFrame:
        """Loads the input as a Pandas DataFrame."""
-        conn = _connect_duckdb(context).cursor()
-        return conn.execute(DuckDbClient.get_select_statement(table_slice)).fetchdf()
+        return connection.execute(DuckDbClient.get_select_statement(table_slice)).fetchdf()

@property
def supported_types(self):
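The "select * from obj" statements in this handler work because DuckDB's replacement scans resolve unknown table names against Python variables in the calling scope, so a local pandas DataFrame can be queried by name. A standalone illustration (the table name t is arbitrary):

import duckdb
import pandas as pd

obj = pd.DataFrame({"a": [1, 2, 3]})
conn = duckdb.connect()  # in-memory database
# "obj" in the SQL resolves to the local DataFrame via a replacement scan
conn.execute("create table if not exists t as select * from obj")
print(conn.execute("select count(*) from t").fetchone())  # (3,)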
@@ -2,7 +2,7 @@
import pyspark.sql
from dagster import InputContext, MetadataValue, OutputContext, TableColumn, TableSchema
from dagster._core.storage.db_io_manager import DbTypeHandler, TableSlice
-from dagster_duckdb.io_manager import DuckDbClient, _connect_duckdb, build_duckdb_io_manager
+from dagster_duckdb.io_manager import DuckDbClient, build_duckdb_io_manager
from pyspark.sql import SparkSession


@@ -34,19 +34,21 @@ def my_repo():
"""

def handle_output(
-        self, context: OutputContext, table_slice: TableSlice, obj: pyspark.sql.DataFrame
+        self,
+        context: OutputContext,
+        table_slice: TableSlice,
+        obj: pyspark.sql.DataFrame,
+        connection,
):
"""Stores the given object at the provided filepath."""
-        conn = _connect_duckdb(context).cursor()
-
pd_df = obj.toPandas() # noqa: F841
-        conn.execute(
+        connection.execute(
f"create table if not exists {table_slice.schema}.{table_slice.table} as select * from"
" pd_df;"
)
-        if not conn.fetchall():
+        if not connection.fetchall():
            # table was not created, meaning it already exists; insert the data instead
-            conn.execute(
+            connection.execute(
f"insert into {table_slice.schema}.{table_slice.table} select * from pd_df"
)

@@ -63,10 +65,11 @@ def handle_output(
}
)

-    def load_input(self, context: InputContext, table_slice: TableSlice) -> pyspark.sql.DataFrame:
+    def load_input(
+        self, context: InputContext, table_slice: TableSlice, connection
+    ) -> pyspark.sql.DataFrame:
"""Loads the return of the query as the correct type."""
-        conn = _connect_duckdb(context).cursor()
-        pd_df = conn.execute(DuckDbClient.get_select_statement(table_slice)).fetchdf()
+        pd_df = connection.execute(DuckDbClient.get_select_statement(table_slice)).fetchdf()
spark = SparkSession.builder.getOrCreate()
return spark.createDataFrame(pd_df)

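Note the shape of the pyspark handler: it never talks to DuckDB in Spark terms. Writes bridge pyspark to pandas via toPandas() so DuckDB can scan the local DataFrame, and reads bridge pandas back to pyspark via createDataFrame() on the result of fetchdf(). A minimal round trip showing both bridges:

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
pd_df = pd.DataFrame({"x": [1, 2]})      # stand-in for what fetchdf() returns
spark_df = spark.createDataFrame(pd_df)  # load path: pandas -> pyspark
back = spark_df.toPandas()               # store path: pyspark -> pandas
assert list(back["x"]) == [1, 2]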
