Implement staging clone table in base class
steinitzu committed Jan 26, 2024
1 parent 051e9d3 commit 64b7e2e
Showing 8 changed files with 46 additions and 93 deletions.
2 changes: 2 additions & 0 deletions dlt/common/destination/capabilities.py
@@ -53,6 +53,8 @@ class DestinationCapabilitiesContext(ContainerInjectableContext):
timestamp_precision: int = 6
max_rows_per_insert: Optional[int] = None
supports_multiple_statements: bool = True
supports_clone_table: bool = False
"""Destination supports CREATE TABLE ... CLONE ... statements"""

# do not allow to create default value, destination caps must be always explicitly inserted into container
can_create_default: ClassVar[bool] = False
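For context, a destination opts in to the new flag from its `capabilities()` factory, exactly as the BigQuery, Databricks, and Snowflake modules do in the files below. A minimal sketch, not part of this commit; the destination itself and the use of `generic_capabilities()` as a starting point are assumptions for illustration:

```python
# Sketch only: a hypothetical destination advertising clone support.
# generic_capabilities() is assumed here as a convenient base; real destination
# modules in this commit build their caps the same way as bigquery/__init__.py below.
from dlt.common.destination.capabilities import DestinationCapabilitiesContext

def capabilities() -> DestinationCapabilitiesContext:
    caps = DestinationCapabilitiesContext.generic_capabilities("insert_values")
    caps.supports_clone_table = True  # destination understands CREATE TABLE ... CLONE
    return caps
```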
1 change: 1 addition & 0 deletions dlt/destinations/impl/bigquery/__init__.py
@@ -20,5 +20,6 @@ def capabilities() -> DestinationCapabilitiesContext:
caps.max_text_data_type_length = 10 * 1024 * 1024
caps.is_max_text_data_type_length_in_bytes = True
caps.supports_ddl_transactions = False
caps.supports_clone_table = True

return caps
40 changes: 5 additions & 35 deletions dlt/destinations/impl/bigquery/bigquery.py
@@ -30,6 +30,7 @@
from dlt.destinations.impl.bigquery.configuration import BigQueryClientConfiguration
from dlt.destinations.impl.bigquery.sql_client import BigQuerySqlClient, BQ_TERMINAL_REASONS
from dlt.destinations.job_client_impl import SqlJobClientWithStaging
from dlt.destinations.sql_jobs import SqlMergeJob, SqlJobParams
from dlt.destinations.job_impl import NewReferenceJob
from dlt.destinations.sql_client import SqlClientBase
from dlt.destinations.sql_jobs import SqlMergeJob, SqlStagingCopyJob, SqlJobParams
@@ -149,28 +150,6 @@ def gen_key_table_clauses(
return sql


class BigqueryStagingCopyJob(SqlStagingCopyJob):
@classmethod
def generate_sql(
cls,
table_chain: Sequence[TTableSchema],
sql_client: SqlClientBase[Any],
params: Optional[SqlJobParams] = None,
) -> List[str]:
sql: List[str] = []
for table in table_chain:
with sql_client.with_staging_dataset(staging=True):
staging_table_name = sql_client.make_qualified_table_name(table["name"])
table_name = sql_client.make_qualified_table_name(table["name"])
sql.extend(
(
f"DROP TABLE IF EXISTS {table_name};",
f"CREATE TABLE {table_name} CLONE {staging_table_name};",
)
)
return sql


class BigQueryClient(SqlJobClientWithStaging, SupportsStagingDestination):
capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities()

@@ -190,13 +169,6 @@ def __init__(self, schema: Schema, config: BigQueryClientConfiguration) -> None:
def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
return [BigQueryMergeJob.from_table_chain(table_chain, self.sql_client)]

def _create_replace_followup_jobs(
self, table_chain: Sequence[TTableSchema]
) -> List[NewLoadJob]:
if self.config.replace_strategy == "staging-optimized":
return [BigqueryStagingCopyJob.from_table_chain(table_chain, self.sql_client)]
return super()._create_replace_followup_jobs(table_chain)

def restore_file_load(self, file_path: str) -> LoadJob:
"""Returns a completed SqlLoadJob or restored BigQueryLoadJob
@@ -280,9 +252,9 @@ def _get_table_update_sql(
elif (c := partition_list[0])["data_type"] == "date":
sql[0] = f"{sql[0]}\nPARTITION BY {self.capabilities.escape_identifier(c['name'])}"
elif (c := partition_list[0])["data_type"] == "timestamp":
sql[0] = (
f"{sql[0]}\nPARTITION BY DATE({self.capabilities.escape_identifier(c['name'])})"
)
sql[
0
] = f"{sql[0]}\nPARTITION BY DATE({self.capabilities.escape_identifier(c['name'])})"
# Automatic partitioning of an INT64 type requires us to be prescriptive - we treat the column as a UNIX timestamp.
# This is due to the bounds requirement of GENERATE_ARRAY function for partitioning.
# The 10,000 partitions limit makes it infeasible to cover the entire `bigint` range.
@@ -300,9 +272,7 @@ def _get_table_update_sql(

def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
name = self.capabilities.escape_identifier(c["name"])
return (
f"{name} {self.type_mapper.to_db_type(c, table_format)} {self._gen_not_null(c.get('nullable', True))}"
)
return f"{name} {self.type_mapper.to_db_type(c, table_format)} {self._gen_not_null(c.get('nullable', True))}"

def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]:
schema_table: TTableSchemaColumns = {}
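With `BigqueryStagingCopyJob` and the `_create_replace_followup_jobs` override removed, the `staging-optimized` replace strategy on BigQuery now flows through the shared clone path in `SqlStagingCopyJob` (see `dlt/destinations/sql_jobs.py` below). A rough sketch of the statements that path emits per table; the dataset and table names are made up, and the real job qualifies and escapes identifiers via `make_qualified_table_name`:

```python
# Sketch of the per-table output of the shared clone path (hypothetical names;
# identifier qualification/escaping is simplified compared to the real job).
from typing import List

def clone_statements(dataset: str, staging_dataset: str, table: str) -> List[str]:
    table_name = f"{dataset}.{table}"
    staging_table_name = f"{staging_dataset}.{table}"
    return [
        f"DROP TABLE IF EXISTS {table_name};",
        f"CREATE TABLE {table_name} CLONE {staging_table_name};",
    ]

print(clone_statements("my_dataset", "my_dataset_staging", "events"))
# ['DROP TABLE IF EXISTS my_dataset.events;',
#  'CREATE TABLE my_dataset.events CLONE my_dataset_staging.events;']
```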
1 change: 1 addition & 0 deletions dlt/destinations/impl/databricks/__init__.py
@@ -26,4 +26,5 @@ def capabilities() -> DestinationCapabilitiesContext:
# caps.supports_transactions = False
caps.alter_add_multi_column = True
caps.supports_multiple_statements = False
caps.supports_clone_table = True
return caps
28 changes: 1 addition & 27 deletions dlt/destinations/impl/databricks/databricks.py
@@ -28,7 +28,7 @@
from dlt.destinations.impl.databricks import capabilities
from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration
from dlt.destinations.impl.databricks.sql_client import DatabricksSqlClient
from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob, SqlJobParams
from dlt.destinations.sql_jobs import SqlMergeJob, SqlJobParams
from dlt.destinations.job_impl import NewReferenceJob
from dlt.destinations.sql_client import SqlClientBase
from dlt.destinations.type_mapping import TypeMapper
@@ -195,25 +195,6 @@ def exception(self) -> str:
raise NotImplementedError()


class DatabricksStagingCopyJob(SqlStagingCopyJob):
@classmethod
def generate_sql(
cls,
table_chain: Sequence[TTableSchema],
sql_client: SqlClientBase[Any],
params: Optional[SqlJobParams] = None,
) -> List[str]:
sql: List[str] = []
for table in table_chain:
with sql_client.with_staging_dataset(staging=True):
staging_table_name = sql_client.make_qualified_table_name(table["name"])
table_name = sql_client.make_qualified_table_name(table["name"])
sql.append(f"DROP TABLE IF EXISTS {table_name};")
# recreate destination table with data cloned from staging table
sql.append(f"CREATE TABLE {table_name} CLONE {staging_table_name};")
return sql


class DatabricksMergeJob(SqlMergeJob):
@classmethod
def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str:
@@ -272,13 +253,6 @@ def _make_add_column_sql(
# Override because databricks requires multiple columns in a single ADD COLUMN clause
return ["ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c) for c in new_columns)]

def _create_replace_followup_jobs(
self, table_chain: Sequence[TTableSchema]
) -> List[NewLoadJob]:
if self.config.replace_strategy == "staging-optimized":
return [DatabricksStagingCopyJob.from_table_chain(table_chain, self.sql_client)]
return super()._create_replace_followup_jobs(table_chain)

def _get_table_update_sql(
self,
table_name: str,
1 change: 1 addition & 0 deletions dlt/destinations/impl/snowflake/__init__.py
@@ -21,4 +21,5 @@ def capabilities() -> DestinationCapabilitiesContext:
caps.is_max_text_data_type_length_in_bytes = True
caps.supports_ddl_transactions = True
caps.alter_add_multi_column = True
caps.supports_clone_table = True
return caps
34 changes: 5 additions & 29 deletions dlt/destinations/impl/snowflake/snowflake.py
@@ -27,7 +27,7 @@
from dlt.destinations.impl.snowflake import capabilities
from dlt.destinations.impl.snowflake.configuration import SnowflakeClientConfiguration
from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient
from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams
from dlt.destinations.sql_jobs import SqlJobParams
from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient
from dlt.destinations.job_impl import NewReferenceJob
from dlt.destinations.sql_client import SqlClientBase
@@ -175,13 +175,15 @@ def __init__(
f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE = TRUE,'
" AUTO_COMPRESS = FALSE"
)
client.execute_sql(f"""COPY INTO {qualified_table_name}
client.execute_sql(
f"""COPY INTO {qualified_table_name}
{from_clause}
{files_clause}
{credentials_clause}
FILE_FORMAT = {source_format}
MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE'
""")
"""
)
if stage_file_path and not keep_staged_files:
client.execute_sql(f"REMOVE {stage_file_path}")

@@ -192,25 +194,6 @@ def exception(self) -> str:
raise NotImplementedError()


class SnowflakeStagingCopyJob(SqlStagingCopyJob):
@classmethod
def generate_sql(
cls,
table_chain: Sequence[TTableSchema],
sql_client: SqlClientBase[Any],
params: Optional[SqlJobParams] = None,
) -> List[str]:
sql: List[str] = []
for table in table_chain:
with sql_client.with_staging_dataset(staging=True):
staging_table_name = sql_client.make_qualified_table_name(table["name"])
table_name = sql_client.make_qualified_table_name(table["name"])
sql.append(f"DROP TABLE IF EXISTS {table_name};")
# recreate destination table with data cloned from staging table
sql.append(f"CREATE TABLE {table_name} CLONE {staging_table_name};")
return sql


class SnowflakeClient(SqlJobClientWithStaging, SupportsStagingDestination):
capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities()

@@ -250,13 +233,6 @@ def _make_add_column_sql(
+ ",\n".join(self._get_column_def_sql(c, table_format) for c in new_columns)
]

def _create_replace_followup_jobs(
self, table_chain: Sequence[TTableSchema]
) -> List[NewLoadJob]:
if self.config.replace_strategy == "staging-optimized":
return [SnowflakeStagingCopyJob.from_table_chain(table_chain, self.sql_client)]
return super()._create_replace_followup_jobs(table_chain)

def _get_table_update_sql(
self,
table_name: str,
32 changes: 30 additions & 2 deletions dlt/destinations/sql_jobs.py
@@ -74,11 +74,28 @@ class SqlStagingCopyJob(SqlBaseJob):
failed_text: str = "Tried to generate a staging copy sql job for the following tables:"

@classmethod
def generate_sql(
def _generate_clone_sql(
cls,
table_chain: Sequence[TTableSchema],
sql_client: SqlClientBase[Any],
params: Optional[SqlJobParams] = None,
) -> List[str]:
"""Drop and clone the table for supported destinations"""
sql: List[str] = []
for table in table_chain:
with sql_client.with_staging_dataset(staging=True):
staging_table_name = sql_client.make_qualified_table_name(table["name"])
table_name = sql_client.make_qualified_table_name(table["name"])
sql.append(f"DROP TABLE IF EXISTS {table_name};")
# recreate destination table with data cloned from staging table
sql.append(f"CREATE TABLE {table_name} CLONE {staging_table_name};")
return sql

@classmethod
def _generate_insert_sql(
cls,
table_chain: Sequence[TTableSchema],
sql_client: SqlClientBase[Any],
params: SqlJobParams = None,
) -> List[str]:
sql: List[str] = []
for table in table_chain:
@@ -98,6 +115,17 @@ def generate_sql(
)
return sql

@classmethod
def generate_sql(
cls,
table_chain: Sequence[TTableSchema],
sql_client: SqlClientBase[Any],
params: SqlJobParams = None,
) -> List[str]:
if params["replace"] and sql_client.capabilities.supports_clone_table:
return cls._generate_clone_sql(table_chain, sql_client)
return cls._generate_insert_sql(table_chain, sql_client, params)


class SqlMergeJob(SqlBaseJob):
"""Generates a list of sql statements that merge the data from staging dataset into destination dataset."""
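The new `generate_sql` boils down to a two-way switch: when the load is a replace and the destination's capabilities advertise `supports_clone_table`, the clone statements above are emitted; otherwise the existing INSERT-based copy (`_generate_insert_sql`, not shown in full in this diff) runs. A condensed, self-contained sketch of that decision; the flat signature and the exact fallback statement are simplifications, not the real method:

```python
# Condensed sketch of the dispatch in SqlStagingCopyJob.generate_sql. The real method
# works on a table chain and an SqlClientBase; the fallback SQL below is an assumption
# standing in for _generate_insert_sql, which this diff elides.
from typing import List

def staging_copy_sql(
    table_name: str,
    staging_table_name: str,
    replace: bool,
    supports_clone_table: bool,
) -> List[str]:
    if replace and supports_clone_table:
        # BigQuery, Databricks and Snowflake take this path after this commit
        return [
            f"DROP TABLE IF EXISTS {table_name};",
            f"CREATE TABLE {table_name} CLONE {staging_table_name};",
        ]
    # generic fallback: copy rows from the staging table into the destination table
    return [f"INSERT INTO {table_name} SELECT * FROM {staging_table_name};"]
```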
