[dagster-snowflake-pyspark] fix bug loading partitions (#12472)
jamiedemaria authored and clairelin135 committed Feb 22, 2023
1 parent c4f158d commit 9fa10cb
Showing 3 changed files with 152 additions and 57 deletions.
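
The fix has two halves. The partitioned-asset tests gain a `downstream_partitioned` asset (backed by `fs_io_manager`) whose input is loaded back through the Snowflake I/O manager, pinning down that a partitioned load returns only the current partition's rows. The PySpark type handler itself stops putting `dbtable` into the shared Snowflake options: the old `load_input` read the table named by `dbtable` wholesale, ignoring the partition, while the new code sets `dbtable` only when writing and reads through a partition-scoped `query` built by `SnowflakeDbClient.get_select_statement`. A minimal sketch of the corrected read path, assuming a shared options mapping that no longer contains `dbtable` (the helper name is hypothetical, not part of the diff):

from pyspark.sql import DataFrame, SparkSession
from dagster_snowflake.snowflake_io_manager import SnowflakeDbClient

SNOWFLAKE_CONNECTOR = "net.snowflake.spark.snowflake"

def load_partition(options, table_slice) -> DataFrame:
    # Sketch of the corrected read path: "options" holds the shared Snowflake
    # connection settings (sfURL, sfUser, ..., without "dbtable"); the partition
    # filter arrives through the generated SELECT statement.
    spark = SparkSession.builder.getOrCreate()
    return (
        spark.read.format(SNOWFLAKE_CONNECTOR)
        .options(**options)
        .option("query", SnowflakeDbClient.get_select_statement(table_slice))
        .load()
    )
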
@@ -8,6 +8,7 @@
import pandas
import pytest
from dagster import (
AssetIn,
DailyPartitionsDefinition,
DynamicPartitionsDefinition,
IOManagerDefinition,
@@ -21,6 +22,7 @@
asset,
build_input_context,
build_output_context,
fs_io_manager,
instance_for_test,
job,
materialize,
@@ -256,9 +258,10 @@ def test_time_window_partitioned_asset():
db_name="TEST_SNOWFLAKE_IO_MANAGER",
column_str="TIME TIMESTAMP_NTZ(9), A string, B int",
) as table_name:
partitions_def = DailyPartitionsDefinition(start_date="2022-01-01")

@asset(
partitions_def=DailyPartitionsDefinition(start_date="2022-01-01"),
partitions_def=partitions_def,
metadata={"partition_expr": "time"},
config_schema={"value": str},
key_prefix="SNOWFLAKE_IO_MANAGER_SCHEMA",
@@ -276,6 +279,16 @@ def daily_partitioned(context) -> DataFrame:
}
)

@asset(
partitions_def=partitions_def,
key_prefix="SNOWFLAKE_IO_MANAGER_SCHEMA",
ins={"df": AssetIn(["SNOWFLAKE_IO_MANAGER_SCHEMA", table_name])},
io_manager_key="fs_io",
)
def downstream_partitioned(df) -> None:
# assert that we only get the rows for the current partition of daily_partitioned
assert len(df.index) == 3

asset_full_name = f"SNOWFLAKE_IO_MANAGER_SCHEMA__{table_name}"
snowflake_table_path = f"SNOWFLAKE_IO_MANAGER_SCHEMA.{table_name}"

@@ -288,10 +301,10 @@ def daily_partitioned(context) -> DataFrame:
)

snowflake_io_manager = snowflake_pandas_io_manager.configured(snowflake_config)
resource_defs = {"io_manager": snowflake_io_manager}
resource_defs = {"io_manager": snowflake_io_manager, "fs_io": fs_io_manager}

materialize(
[daily_partitioned],
[daily_partitioned, downstream_partitioned],
partition_key="2022-01-01",
resources=resource_defs,
run_config={"ops": {asset_full_name: {"config": {"value": "1"}}}},
@@ -303,7 +316,7 @@ def daily_partitioned(context) -> DataFrame:
assert out_df["A"].tolist() == ["1", "1", "1"]

materialize(
[daily_partitioned],
[daily_partitioned, downstream_partitioned],
partition_key="2022-01-02",
resources=resource_defs,
run_config={"ops": {asset_full_name: {"config": {"value": "2"}}}},
@@ -315,7 +328,7 @@ def daily_partitioned(context) -> DataFrame:
assert sorted(out_df["A"].tolist()) == ["1", "1", "1", "2", "2", "2"]

materialize(
[daily_partitioned],
[daily_partitioned, downstream_partitioned],
partition_key="2022-01-01",
resources=resource_defs,
run_config={"ops": {asset_full_name: {"config": {"value": "3"}}}},
@@ -334,9 +347,10 @@ def test_static_partitioned_asset():
db_name="TEST_SNOWFLAKE_IO_MANAGER",
column_str=" COLOR string, A string, B int",
) as table_name:
partitions_def = StaticPartitionsDefinition(["red", "yellow", "blue"])

@asset(
partitions_def=StaticPartitionsDefinition(["red", "yellow", "blue"]),
partitions_def=partitions_def,
key_prefix=["SNOWFLAKE_IO_MANAGER_SCHEMA"],
metadata={"partition_expr": "color"},
config_schema={"value": str},
@@ -353,6 +367,16 @@ def static_partitioned(context) -> DataFrame:
}
)

@asset(
partitions_def=partitions_def,
key_prefix="SNOWFLAKE_IO_MANAGER_SCHEMA",
ins={"df": AssetIn(["SNOWFLAKE_IO_MANAGER_SCHEMA", table_name])},
io_manager_key="fs_io",
)
def downstream_partitioned(df) -> None:
# assert that we only get the rows for the current partition of static_partitioned
assert len(df.index) == 3

asset_full_name = f"SNOWFLAKE_IO_MANAGER_SCHEMA__{table_name}"
snowflake_table_path = f"SNOWFLAKE_IO_MANAGER_SCHEMA.{table_name}"

@@ -365,9 +389,9 @@ def static_partitioned(context) -> DataFrame:
)

snowflake_io_manager = snowflake_pandas_io_manager.configured(snowflake_config)
resource_defs = {"io_manager": snowflake_io_manager}
resource_defs = {"io_manager": snowflake_io_manager, "fs_io": fs_io_manager}
materialize(
[static_partitioned],
[static_partitioned, downstream_partitioned],
partition_key="red",
resources=resource_defs,
run_config={"ops": {asset_full_name: {"config": {"value": "1"}}}},
@@ -379,7 +403,7 @@ def static_partitioned(context) -> DataFrame:
assert out_df["A"].tolist() == ["1", "1", "1"]

materialize(
[static_partitioned],
[static_partitioned, downstream_partitioned],
partition_key="blue",
resources=resource_defs,
run_config={"ops": {asset_full_name: {"config": {"value": "2"}}}},
@@ -391,7 +415,7 @@ def static_partitioned(context) -> DataFrame:
assert sorted(out_df["A"].tolist()) == ["1", "1", "1", "2", "2", "2"]

materialize(
[static_partitioned],
[static_partitioned, downstream_partitioned],
partition_key="red",
resources=resource_defs,
run_config={"ops": {asset_full_name: {"config": {"value": "3"}}}},
@@ -410,14 +434,15 @@ def test_multi_partitioned_asset():
db_name="TEST_SNOWFLAKE_IO_MANAGER",
column_str=" COLOR string, TIME TIMESTAMP_NTZ(9), A string",
) as table_name:
partitions_def = MultiPartitionsDefinition(
{
"time": DailyPartitionsDefinition(start_date="2022-01-01"),
"color": StaticPartitionsDefinition(["red", "yellow", "blue"]),
}
)

@asset(
partitions_def=MultiPartitionsDefinition(
{
"time": DailyPartitionsDefinition(start_date="2022-01-01"),
"color": StaticPartitionsDefinition(["red", "yellow", "blue"]),
}
),
partitions_def=partitions_def,
key_prefix=["SNOWFLAKE_IO_MANAGER_SCHEMA"],
metadata={"partition_expr": {"time": "CAST(time as TIMESTAMP)", "color": "color"}},
config_schema={"value": str},
@@ -434,6 +459,16 @@ def multi_partitioned(context) -> DataFrame:
}
)

@asset(
partitions_def=partitions_def,
key_prefix="SNOWFLAKE_IO_MANAGER_SCHEMA",
ins={"df": AssetIn(["SNOWFLAKE_IO_MANAGER_SCHEMA", table_name])},
io_manager_key="fs_io",
)
def downstream_partitioned(df) -> None:
# assert that we only get the rows for the current partition of multi_partitioned
assert len(df.index) == 3

asset_full_name = f"SNOWFLAKE_IO_MANAGER_SCHEMA__{table_name}"
snowflake_table_path = f"SNOWFLAKE_IO_MANAGER_SCHEMA.{table_name}"

@@ -446,10 +481,10 @@ def multi_partitioned(context) -> DataFrame:
)

snowflake_io_manager = snowflake_pandas_io_manager.configured(snowflake_config)
resource_defs = {"io_manager": snowflake_io_manager}
resource_defs = {"io_manager": snowflake_io_manager, "fs_io": fs_io_manager}

materialize(
[multi_partitioned],
[multi_partitioned, downstream_partitioned],
partition_key=MultiPartitionKey({"time": "2022-01-01", "color": "red"}),
resources=resource_defs,
run_config={"ops": {asset_full_name: {"config": {"value": "1"}}}},
@@ -461,7 +496,7 @@ def multi_partitioned(context) -> DataFrame:
assert out_df["A"].tolist() == ["1", "1", "1"]

materialize(
[multi_partitioned],
[multi_partitioned, downstream_partitioned],
partition_key=MultiPartitionKey({"time": "2022-01-01", "color": "blue"}),
resources=resource_defs,
run_config={"ops": {asset_full_name: {"config": {"value": "2"}}}},
@@ -473,7 +508,7 @@ def multi_partitioned(context) -> DataFrame:
assert sorted(out_df["A"].tolist()) == ["1", "1", "1", "2", "2", "2"]

materialize(
[multi_partitioned],
[multi_partitioned, downstream_partitioned],
partition_key=MultiPartitionKey({"time": "2022-01-02", "color": "red"}),
resources=resource_defs,
run_config={"ops": {asset_full_name: {"config": {"value": "3"}}}},
@@ -485,7 +520,7 @@ def multi_partitioned(context) -> DataFrame:
assert sorted(out_df["A"].tolist()) == ["1", "1", "1", "2", "2", "2", "3", "3", "3"]

materialize(
[multi_partitioned],
[multi_partitioned, downstream_partitioned],
partition_key=MultiPartitionKey({"time": "2022-01-01", "color": "red"}),
resources=resource_defs,
run_config={"ops": {asset_full_name: {"config": {"value": "4"}}}},
@@ -523,6 +558,16 @@ def dynamic_partitioned(context) -> DataFrame:
}
)

@asset(
partitions_def=dynamic_fruits,
key_prefix="SNOWFLAKE_IO_MANAGER_SCHEMA",
ins={"df": AssetIn(["SNOWFLAKE_IO_MANAGER_SCHEMA", table_name])},
io_manager_key="fs_io",
)
def downstream_partitioned(df) -> None:
# assert that we only get the rows for the current partition of dynamic_partitioned
assert len(df.index) == 3

asset_full_name = f"SNOWFLAKE_IO_MANAGER_SCHEMA__{table_name}"
snowflake_table_path = f"SNOWFLAKE_IO_MANAGER_SCHEMA.{table_name}"

@@ -535,13 +580,13 @@ def dynamic_partitioned(context) -> DataFrame:
)

snowflake_io_manager = snowflake_pandas_io_manager.configured(snowflake_config)
resource_defs = {"io_manager": snowflake_io_manager}
resource_defs = {"io_manager": snowflake_io_manager, "fs_io": fs_io_manager}

with instance_for_test() as instance:
dynamic_fruits.add_partitions(["apple"], instance)

materialize(
[dynamic_partitioned],
[dynamic_partitioned, downstream_partitioned],
partition_key="apple",
resources=resource_defs,
instance=instance,
@@ -556,7 +601,7 @@ def dynamic_partitioned(context) -> DataFrame:
dynamic_fruits.add_partitions(["orange"], instance)

materialize(
[dynamic_partitioned],
[dynamic_partitioned, downstream_partitioned],
partition_key="orange",
resources=resource_defs,
instance=instance,
@@ -569,7 +614,7 @@ def dynamic_partitioned(context) -> DataFrame:
assert sorted(out_df["A"].tolist()) == ["1", "1", "1", "2", "2", "2"]

materialize(
[dynamic_partitioned],
[dynamic_partitioned, downstream_partitioned],
partition_key="apple",
resources=resource_defs,
instance=instance,
@@ -5,6 +5,7 @@
from dagster._core.definitions.metadata import RawMetadataValue
from dagster._core.storage.db_io_manager import DbTypeHandler, TableSlice
from dagster_snowflake import build_snowflake_io_manager
from dagster_snowflake.snowflake_io_manager import SnowflakeDbClient
from pyspark.sql import DataFrame, SparkSession

SNOWFLAKE_CONNECTOR = "net.snowflake.spark.snowflake"
@@ -23,7 +24,6 @@ def _get_snowflake_options(config, table_slice: TableSlice) -> Mapping[str, str]
"sfDatabase": config["database"],
"sfSchema": table_slice.schema,
"sfWarehouse": config["warehouse"],
"dbtable": table_slice.table,
}
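# "dbtable" is intentionally no longer part of the shared options: handle_output
# now sets it explicitly on the writer, and load_input passes a partition-scoped
# "query" instead, so reads stop pulling the entire table.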

return conf
@@ -69,9 +69,9 @@ def handle_output(

with_uppercase_cols = obj.toDF(*[c.upper() for c in obj.columns])

with_uppercase_cols.write.format(SNOWFLAKE_CONNECTOR).options(**options).mode(
"append"
).save()
with_uppercase_cols.write.format(SNOWFLAKE_CONNECTOR).options(**options).option(
"dbtable", table_slice.table
).mode("append").save()

return {
"dataframe_columns": MetadataValue.table_schema(
@@ -88,7 +88,12 @@ def load_input(self, context: InputContext, table_slice: TableSlice, _) -> DataFrame:
options = _get_snowflake_options(context.resource_config, table_slice)

spark = SparkSession.builder.getOrCreate()
df = spark.read.format(SNOWFLAKE_CONNECTOR).options(**options).load()
df = (
spark.read.format(SNOWFLAKE_CONNECTOR)
.options(**options)
.option("query", SnowflakeDbClient.get_select_statement(table_slice))
.load()
)

return df.toDF(*[c.lower() for c in df.columns])
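
Because the read now goes through the `query` option, the Spark connector only fetches the rows selected by `get_select_statement`, which carries the partition filter derived from the asset's `partition_expr` metadata. For the daily-partitioned test asset above, the generated SQL would look roughly as follows; the exact statement is an assumption for illustration, not taken from this commit:

# Assumed shape of the SQL passed via .option("query", ...) for partition
# key "2022-01-01" and metadata={"partition_expr": "time"}:
#
#   SELECT * FROM SNOWFLAKE_IO_MANAGER_SCHEMA.<table_name> WHERE
#   time >= '2022-01-01 00:00:00' AND time < '2022-01-02 00:00:00'
#
# Only that partition's three rows come back, which is exactly what the new
# downstream_partitioned assertions (len(df.index) == 3) check.
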

