Add upsert merge strategy #1294

Closed
2 changes: 2 additions & 0 deletions dlt/common/destination/capabilities.py
@@ -54,6 +54,7 @@ class DestinationCapabilitiesContext(ContainerInjectableContext):
supports_multiple_statements: bool = True
supports_clone_table: bool = False
"""Destination supports CREATE TABLE ... CLONE ... statements"""
supports_temp_table: bool = True
"""Destination supports CREATE TEMP TABLE ... AS SELECT statements"""
max_table_nesting: Optional[int] = None # destination can overwrite max table nesting

# do not allow to create default value, destination caps must be always explicitly inserted into container
@@ -84,6 +85,7 @@ def generic_capabilities(
caps.supports_ddl_transactions = True
caps.supports_transactions = True
caps.supports_multiple_statements = True
caps.supports_temp_table = True
return caps


62 changes: 41 additions & 21 deletions dlt/common/normalizers/json/relational.py
@@ -8,13 +8,17 @@
from dlt.common.typing import DictStrAny, DictStrStr, TDataItem, StrAny
from dlt.common.schema import Schema
from dlt.common.schema.typing import (
TTableSchema,
TLoaderMergeStrategy,
TColumnSchema,
TColumnName,
TSimpleRegex,
DLT_NAME_PREFIX,
)
from dlt.common.schema.utils import column_name_validator, get_validity_column_names
from dlt.common.schema.utils import (
column_name_validator,
get_validity_column_names,
get_columns_names_with_prop,
)
from dlt.common.schema.exceptions import ColumnNameConflictException
from dlt.common.utils import digest128, update_dict_nested
from dlt.common.normalizers.json import (
@@ -136,14 +140,16 @@ def norm_row_dicts(dict_row: StrAny, __r_lvl: int, path: Tuple[str, ...] = ()) -
return cast(TDataItemRow, out_rec_row), out_rec_list

@staticmethod
def get_row_hash(row: Dict[str, Any]) -> str:
def get_row_hash(row: Dict[str, Any], subset: Optional[List[str]] = None) -> str:
"""Returns hash of row.

Hash includes column names and values and is ordered by column name.
Excludes dlt system columns.
Can be used as deterministic row identifier.
If `subset` is provided, only those columns are included in the hash.
"""
row_filtered = {k: v for k, v in row.items() if not k.startswith(DLT_NAME_PREFIX)}
if subset is not None:
row_filtered = {k: v for k, v in row.items() if k in subset}
row_str = json.dumps(row_filtered, sort_keys=True)
return digest128(row_str, DLT_ID_LENGTH_BYTES)
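
For orientation on how the new `subset` parameter ties into the upsert strategy introduced below: when a table uses upsert, the normalizer hashes only the primary key columns, so the resulting `_dlt_id` stays stable across reloads of the same key. A minimal sketch, assuming the `DataItemNormalizer` class in this module and made-up column names and values:

```python
from dlt.common.normalizers.json.relational import DataItemNormalizer

row_v1 = {"id": 1, "name": "alice", "updated_at": "2024-01-01"}
row_v2 = {"id": 1, "name": "alice (renamed)", "updated_at": "2024-02-01"}

# full-row hash (used for scd2): changes whenever any non-dlt column changes
assert DataItemNormalizer.get_row_hash(row_v1) != DataItemNormalizer.get_row_hash(row_v2)

# key hash (used for upsert): pinned to the primary key column(s), so both versions
# produce the same _dlt_id and the second load overwrites the first
assert DataItemNormalizer.get_row_hash(row_v1, subset=["id"]) == DataItemNormalizer.get_row_hash(
    row_v2, subset=["id"]
)
```
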

@@ -240,28 +246,34 @@ def _normalize_row(
parent_row_id: Optional[str] = None,
pos: Optional[int] = None,
_r_lvl: int = 0,
row_hash: bool = False,
row_id_type: Optional[str] = None,
) -> TNormalizedRowIterator:
schema = self.schema
table = schema.naming.shorten_fragments(*parent_path, *ident_path)
# compute row hash and set as row id
if row_hash:
table_name = schema.naming.shorten_fragments(*parent_path, *ident_path)
if row_id_type == "key_hash":
primary_key = self._get_primary_key(schema, table_name)
key_hash = self.get_row_hash(dict_row, subset=primary_key) # type: ignore[arg-type]
dict_row["_dlt_id"] = key_hash
elif row_id_type == "row_hash":
row_id = self.get_row_hash(dict_row) # type: ignore[arg-type]
dict_row["_dlt_id"] = row_id
# flatten current row and extract all lists to recur into
flattened_row, lists = self._flatten(table, dict_row, _r_lvl)
flattened_row, lists = self._flatten(table_name, dict_row, _r_lvl)
# always extend row
DataItemNormalizer._extend_row(extend, flattened_row)
# infer record hash or leave existing primary key if present
row_id = flattened_row.get("_dlt_id", None)
if not row_id:
row_id = self._add_row_id(table, flattened_row, parent_row_id, pos, _r_lvl)
row_id = self._add_row_id(table_name, flattened_row, parent_row_id, pos, _r_lvl)

# find fields to propagate to child tables in config
extend.update(self._get_propagated_values(table, flattened_row, _r_lvl))
extend.update(self._get_propagated_values(table_name, flattened_row, _r_lvl))

# yield parent table first
should_descend = yield (table, schema.naming.shorten_fragments(*parent_path)), flattened_row
should_descend = (
yield (table_name, schema.naming.shorten_fragments(*parent_path)),
flattened_row,
)
if should_descend is False:
return

@@ -320,18 +332,21 @@ def normalize_data_item(
row = cast(TDataItemRowRoot, item)
# identify load id if loaded data must be processed after loading incrementally
row["_dlt_load_id"] = load_id
# determine if row hash should be used as dlt id
row_hash = False
if self._is_scd2_table(self.schema, table_name):
row_hash = self._dlt_id_is_row_hash(self.schema, table_name)
# determine type of dlt id
row_id_type = None
if self._get_merge_strategy(self.schema, table_name) == "upsert":
row_id_type = "key_hash"
elif self._get_merge_strategy(self.schema, table_name) == "scd2":
if self._dlt_id_is_row_hash(self.schema, table_name):
row_id_type = "row_hash"
self._validate_validity_column_names(
self._get_validity_column_names(self.schema, table_name), item
)
yield from self._normalize_row(
cast(TDataItemRowChild, row),
{},
(self.schema.naming.normalize_table_identifier(table_name),),
row_hash=row_hash,
row_id_type=row_id_type,
)

@classmethod
@@ -368,11 +383,16 @@ def _validate_normalizer_config(schema: Schema, config: RelationalNormalizerConf

@staticmethod
@lru_cache(maxsize=None)
def _is_scd2_table(schema: Schema, table_name: str) -> bool:
if table_name in schema.data_table_names():
if schema.get_table(table_name).get("x-merge-strategy") == "scd2":
return True
return False
def _get_merge_strategy(schema: Schema, table_name: str) -> Optional[TLoaderMergeStrategy]:
if table_name in schema.data_table_names(include_incomplete=True):
return schema.get_table(table_name).get("x-merge-strategy") # type: ignore[return-value]
return None

@staticmethod
@lru_cache(maxsize=None)
def _get_primary_key(schema: Schema, table_name: str) -> List[str]:
table = schema.get_table(table_name)
return get_columns_names_with_prop(table, "primary_key", include_incomplete=True)

@staticmethod
@lru_cache(maxsize=None)
4 changes: 2 additions & 2 deletions dlt/common/schema/schema.py
@@ -553,9 +553,9 @@ def data_tables(
)
]

def data_table_names(self) -> List[str]:
def data_table_names(self, include_incomplete: bool = False) -> List[str]:
"""Returns list of table table names. Excludes dlt table names."""
return [t["name"] for t in self.data_tables()]
return [t["name"] for t in self.data_tables(include_incomplete=include_incomplete)]

def dlt_tables(self) -> List[TTableSchema]:
"""Gets dlt tables"""
2 changes: 1 addition & 1 deletion dlt/common/schema/typing.py
@@ -155,7 +155,7 @@ class NormalizerInfo(TypedDict, total=True):


TWriteDisposition = Literal["skip", "append", "replace", "merge"]
TLoaderMergeStrategy = Literal["delete-insert", "scd2"]
TLoaderMergeStrategy = Literal["delete-insert", "scd2", "upsert"]


WRITE_DISPOSITIONS: Set[TWriteDisposition] = set(get_args(TWriteDisposition))
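
The new literal is what resource configuration ultimately resolves to. Below is a hedged sketch of how the strategy might be selected on a resource, assuming the dict form of `write_disposition` already used for `scd2` accepts the new value and that a `primary_key` hint is supplied so the normalizer has something to hash into `_dlt_id`:

```python
import dlt

@dlt.resource(
    primary_key="id",  # hypothetical key column; upsert deduplicates on it
    write_disposition={"disposition": "merge", "strategy": "upsert"},
)
def customers():
    yield [{"id": 1, "name": "alice"}, {"id": 2, "name": "bob"}]

# The strategy lands on the table schema as the "x-merge-strategy" hint,
# which the normalizer and the merge followup job inspect elsewhere in this PR.
```
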
1 change: 1 addition & 0 deletions dlt/destinations/impl/athena/__init__.py
@@ -25,4 +25,5 @@ def capabilities() -> DestinationCapabilitiesContext:
caps.schema_supports_numeric_precision = False
caps.timestamp_precision = 3
caps.supports_truncate_command = False
caps.supports_temp_table = False
return caps
3 changes: 2 additions & 1 deletion dlt/destinations/impl/athena/athena.py
@@ -415,7 +415,8 @@ def _create_replace_followup_jobs(
return super()._create_replace_followup_jobs(table_chain)

def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
# fall back to append jobs for merge
if table_chain[0]["x-merge-strategy"] == "upsert": # type: ignore[typeddict-item]
return super()._create_merge_followup_jobs(table_chain)
return self._create_append_followup_jobs(table_chain)

def _is_iceberg_table(self, table: TTableSchema) -> bool:
6 changes: 6 additions & 0 deletions dlt/destinations/sql_client.py
@@ -146,6 +146,12 @@ def make_qualified_table_name(self, table_name: str, escape: bool = True) -> str
table_name = self.capabilities.escape_identifier(table_name)
return f"{self.fully_qualified_dataset_name(escape=escape)}.{table_name}"

def get_qualified_table_names(self, table_name: str, escape: bool = True) -> Tuple[str, str]:
"""Returns qualified names for table and corresponding staging table as tuple."""
with self.with_staging_dataset(staging=True):
staging_table_name = self.make_qualified_table_name(table_name, escape)
return self.make_qualified_table_name(table_name, escape), staging_table_name

def escape_column_name(self, column_name: str, escape: bool = True) -> str:
if escape:
return self.capabilities.escape_identifier(column_name)
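
A quick illustration of what `get_qualified_table_names` returns, since the upsert job in the next file uses the pair as MERGE target and source. Dataset and table names are invented; the `_staging` suffix is just the usual default layout:

```python
from typing import Any, Tuple

from dlt.destinations.sql_client import SqlClientBase


def upsert_target_and_source(sql_client: SqlClientBase[Any], table_name: str) -> Tuple[str, str]:
    # For a client bound to a "my_dataset" dataset this might return
    # ('"my_dataset"."customers"', '"my_dataset_staging"."customers"')
    return sql_client.get_qualified_table_names(table_name)
```
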
65 changes: 64 additions & 1 deletion dlt/destinations/sql_jobs.py
@@ -159,6 +159,8 @@ def generate_sql( # type: ignore[return]
merge_strategy = table_chain[0].get("x-merge-strategy", DEFAULT_MERGE_STRATEGY)
if merge_strategy == "delete-insert":
return cls.gen_merge_sql(table_chain, sql_client)
elif merge_strategy == "upsert":
return cls.gen_upsert_sql(table_chain, sql_client)
elif merge_strategy == "scd2":
return cls.gen_scd2_sql(table_chain, sql_client)

@@ -338,6 +340,67 @@ def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str:
"""
return f"CREATE TEMP TABLE {temp_table_name} AS {select_sql};"

@classmethod
def gen_upsert_sql(
cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]
) -> List[str]:
sql: List[str] = []
root_table = table_chain[0]
root_table_name, staging_root_table_name = sql_client.get_qualified_table_names(
root_table["name"]
)
caps = sql_client.capabilities
escape_id = caps.escape_identifier
root_unique_key = escape_id(get_first_column_name_with_prop(root_table, "unique"))

# generate merge statement for root table
col_str = ", ".join(["{alias}" + escape_id(c) for c in root_table["columns"]])
update_str = ", ".join(
[escape_id(c) + " = " + "s." + escape_id(c) for c in root_table["columns"]]
)
sql.append(f"""
MERGE INTO {root_table_name} d USING {staging_root_table_name} s
ON d.{root_unique_key} = s.{root_unique_key}
WHEN MATCHED
THEN UPDATE SET {update_str}
WHEN NOT MATCHED
THEN INSERT ({col_str.format(alias="")}) VALUES ({col_str.format(alias="s.")});
""")

# generate statements for child tables if they exist
child_tables = table_chain[1:]
if child_tables:
temp_table_name = None  # created only when several child tables exist and temp tables are supported
if len(child_tables) > 1 and caps.supports_temp_table:
# store unique keys in temp table for efficiency
temp_table_name = cls._new_temp_table_name("delete", sql_client)
select_statement = f"SELECT {root_unique_key} FROM {staging_root_table_name}"
sql.append(cls._to_temp_table(select_statement, temp_table_name))

for table in child_tables:
unique_key = escape_id(get_first_column_name_with_prop(table, "unique"))
root_key = escape_id(get_first_column_name_with_prop(table, "root_key"))
table_name, staging_table_name = sql_client.get_qualified_table_names(table["name"])

# delete records for elements no longer in the list
sql.append(f"""
DELETE FROM {table_name}
WHERE {root_key} IN (SELECT {root_unique_key} FROM {temp_table_name or staging_root_table_name})
AND {unique_key} NOT IN (SELECT {unique_key} FROM {staging_table_name});
""")

# insert records for new elements in the list
col_str = ", ".join(
["{alias}" + escape_id(c) for c in list(table["columns"].keys())]
)
sql.append(f"""
MERGE INTO {table_name} d USING {staging_table_name} s
ON d.{unique_key} = s.{unique_key}
WHEN NOT MATCHED
THEN INSERT ({col_str.format(alias="")}) VALUES ({col_str.format(alias="s.")});
""")

return sql
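
To make the generated statements concrete, here is a rough sketch of what `gen_upsert_sql` would emit for a hypothetical `customers` root table with a single `customers__orders` child table. Dataset, table and column names are invented, and quoting and the temp-table shortcut vary by destination. The root table is merged on its unique `_dlt_id` column, which for upsert tables holds the primary-key hash, so existing keys are updated and new keys inserted; child rows whose parent is being reloaded but which no longer appear in staging are deleted, and new child rows are insert-only merged:

```python
# Illustrative output only (not produced by running this snippet as-is).
expected_statements = [
    # 1) root table: update on key match, insert otherwise
    """
    MERGE INTO "ds"."customers" d USING "ds_staging"."customers" s
    ON d."_dlt_id" = s."_dlt_id"
    WHEN MATCHED
        THEN UPDATE SET "id" = s."id", "name" = s."name",
            "_dlt_load_id" = s."_dlt_load_id", "_dlt_id" = s."_dlt_id"
    WHEN NOT MATCHED
        THEN INSERT ("id", "name", "_dlt_load_id", "_dlt_id")
        VALUES (s."id", s."name", s."_dlt_load_id", s."_dlt_id");
    """,
    # 2) child table: drop rows that belong to reloaded parents but are absent from staging
    """
    DELETE FROM "ds"."customers__orders"
    WHERE "_dlt_root_id" IN (SELECT "_dlt_id" FROM "ds_staging"."customers")
    AND "_dlt_id" NOT IN (SELECT "_dlt_id" FROM "ds_staging"."customers__orders");
    """,
    # 3) child table: insert-only merge for new child rows
    """
    MERGE INTO "ds"."customers__orders" d USING "ds_staging"."customers__orders" s
    ON d."_dlt_id" = s."_dlt_id"
    WHEN NOT MATCHED
        THEN INSERT ("order_id", "_dlt_parent_id", "_dlt_list_idx", "_dlt_root_id", "_dlt_id")
        VALUES (s."order_id", s."_dlt_parent_id", s."_dlt_list_idx", s."_dlt_root_id", s."_dlt_id");
    """,
]
```
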

@classmethod
def gen_merge_sql(
cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]
@@ -437,7 +500,7 @@ def gen_merge_sql(
)
)

# delete from top table now that child tables have been processed
# delete from top table now that child tables have been processed
sql.append(
cls.gen_delete_from_sql(
root_table_name, unique_column, delete_temp_table_name, unique_column
8 changes: 1 addition & 7 deletions tests/load/pipeline/test_scd2.py
@@ -22,6 +22,7 @@
destinations_configs,
DestinationTestConfiguration,
load_tables_to_dicts,
assert_records_as_set,
)
from tests.utils import TPythonTableFormat

@@ -74,13 +75,6 @@ def get_table(
)


def assert_records_as_set(actual: List[Dict[str, Any]], expected: List[Dict[str, Any]]) -> None:
"""Compares two lists of dicts regardless of order"""
actual_set = set(frozenset(dict_.items()) for dict_ in actual)
expected_set = set(frozenset(dict_.items()) for dict_ in expected)
assert actual_set == expected_set


@pytest.mark.essential
@pytest.mark.parametrize(
"destination_config,simple,validity_column_names",