Skip to content

Commit

Permalink
Add missing method to SnowflakePandasTypeHandler (#8051)
Browse files Browse the repository at this point in the history
Summary:
Make the missing method abstract to prevent future missing implementations

Test Plan:
New test case
  • Loading branch information
gibsondan committed May 25, 2022
1 parent 4e1a53f commit 10b9225
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ def k8s_extra_cmds(version: str, _) -> List[str]:
unsupported_python_versions=[
AvailablePythonVersion.V3_6,
],
env_vars=["SNOWFLAKE_ACCOUNT", "SNOWFLAKE_BUILDKITE_PASSWORD"],
),
PackageSpec("python_modules/libraries/dagster-postgres", pytest_extra_cmds=postgres_extra_cmds),
PackageSpec(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def _connect_snowflake(context: Union[InputContext, OutputContext], table_slice:
**cast(Mapping[str, str], context.resource_config),
),
context.log,
).get_connection()
).get_connection(raw_conn=False)


class SnowflakePandasTypeHandler(DbTypeHandler[DataFrame]):
Expand Down Expand Up @@ -72,3 +72,7 @@ def load_input(self, context: InputContext, table_slice: TableSlice) -> DataFram
result = read_sql(sql=SnowflakeDbClient.get_select_statement(table_slice), con=con)
result.columns = map(str.lower, result.columns)
return result

@property
def supported_types(self):
return [DataFrame]
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
import os
from unittest.mock import patch

import pandas
import pytest
from dagster_snowflake import build_snowflake_io_manager
from dagster_snowflake.snowflake_io_manager import TableSlice
from dagster_snowflake_pandas import SnowflakePandasTypeHandler
from pandas import DataFrame

from dagster import (
MetadataValue,
Out,
TableColumn,
TableSchema,
build_input_context,
build_output_context,
job,
op,
)

resource_config = {
Expand All @@ -20,6 +27,8 @@
"warehouse": "warehouse_abc",
}

IS_BUILDKITE = os.getenv("BUILDKITE") is not None


def test_handle_output():
with patch("dagster_snowflake_pandas.snowflake_pandas_type_handler._connect_snowflake"):
Expand Down Expand Up @@ -67,3 +76,43 @@ def test_load_input():
)
assert mock_read_sql.call_args_list[0][1]["sql"] == "SELECT * FROM my_db.my_schema.my_table"
assert df.equals(DataFrame([{"col1": "a", "col2": 1}]))


@op(out=Out(io_manager_key="snowflake", metadata={"schema": "snowflake_io_manager_schema"}))
def emit_pandas_df(_):
return pandas.DataFrame({"foo": ["bar", "baz"], "quux": [1, 2]})


@op
def read_pandas_df(df: pandas.DataFrame):
assert set(df.columns) == {"foo", "quux"}


snowflake_io_manager = build_snowflake_io_manager([SnowflakePandasTypeHandler()])


@job(
resource_defs={"snowflake": snowflake_io_manager},
config={
"resources": {
"snowflake": {
"config": {
"account": {"env": "SNOWFLAKE_ACCOUNT"},
"user": "BUILDKITE",
"password": {
"env": "SNOWFLAKE_BUILDKITE_PASSWORD",
},
"database": "TEST_SNOWFLAKE_IO_MANAGER",
}
}
}
},
)
def io_manager_test_pipeline():
read_pandas_df(emit_pandas_df())


@pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE snowflake DB")
def test_io_manager_with_snowflake_pandas():
res = io_manager_test_pipeline.execute_in_process()
assert res.success
2 changes: 1 addition & 1 deletion python_modules/libraries/dagster-snowflake-pandas/tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ envlist = py{38,37,36}-{unix,windows},mypy,pylint

[testenv]
usedevelop = true
passenv = CI_* COVERALLS_REPO_TOKEN BUILDKITE
passenv = CI_* COVERALLS_REPO_TOKEN BUILDKITE SNOWFLAKE_BUILDKITE_PASSWORD SNOWFLAKE_ACCOUNT
deps =
-e ../../dagster[mypy,test]
-e ../dagster-snowflake
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ def load_input(self, context: InputContext, table_slice: TableSlice) -> T:
"""Loads the contents of the given table in the given schema."""

@property
@abstractmethod
def supported_types(self) -> Sequence[Type]:
...
pass


class DbClient:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def delete_table_slice(context: OutputContext, table_slice: TableSlice) -> None:
with SnowflakeConnection(
dict(**(context.resource_config or {}), schema=table_slice.schema), context.log
).get_connection() as con:
con.execute(_get_cleanup_statement(table_slice))
con.execute_string(_get_cleanup_statement(table_slice))

@staticmethod
def get_select_statement(table_slice: TableSlice) -> str:
Expand Down

0 comments on commit 10b9225

Please sign in to comment.