diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py
index 10d09d52b3..b891d4b31f 100644
--- a/dlt/common/destination/capabilities.py
+++ b/dlt/common/destination/capabilities.py
@@ -53,6 +53,8 @@ class DestinationCapabilitiesContext(ContainerInjectableContext):
     timestamp_precision: int = 6
     max_rows_per_insert: Optional[int] = None
     supports_multiple_statements: bool = True
+    supports_clone_table: bool = False
+    """Destination supports CREATE TABLE ... CLONE ... statements"""
 
     # do not allow to create default value, destination caps must be always explicitly inserted into container
     can_create_default: ClassVar[bool] = False
diff --git a/dlt/destinations/impl/bigquery/__init__.py b/dlt/destinations/impl/bigquery/__init__.py
index 1304bd72bb..6d1491817a 100644
--- a/dlt/destinations/impl/bigquery/__init__.py
+++ b/dlt/destinations/impl/bigquery/__init__.py
@@ -20,5 +20,6 @@ def capabilities() -> DestinationCapabilitiesContext:
     caps.max_text_data_type_length = 10 * 1024 * 1024
     caps.is_max_text_data_type_length_in_bytes = True
     caps.supports_ddl_transactions = False
+    caps.supports_clone_table = True
 
     return caps
diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py
index 254184b96d..1058b1d2c9 100644
--- a/dlt/destinations/impl/bigquery/bigquery.py
+++ b/dlt/destinations/impl/bigquery/bigquery.py
@@ -30,6 +30,7 @@
 from dlt.destinations.impl.bigquery.configuration import BigQueryClientConfiguration
 from dlt.destinations.impl.bigquery.sql_client import BigQuerySqlClient, BQ_TERMINAL_REASONS
 from dlt.destinations.job_client_impl import SqlJobClientWithStaging
+from dlt.destinations.sql_jobs import SqlMergeJob, SqlJobParams
 from dlt.destinations.job_impl import NewReferenceJob
 from dlt.destinations.sql_client import SqlClientBase
 from dlt.destinations.sql_jobs import SqlMergeJob, SqlStagingCopyJob, SqlJobParams
@@ -149,28 +150,6 @@ def gen_key_table_clauses(
         return sql
 
 
-class BigqueryStagingCopyJob(SqlStagingCopyJob):
-    @classmethod
-    def generate_sql(
-        cls,
-        table_chain: Sequence[TTableSchema],
-        sql_client: SqlClientBase[Any],
-        params: Optional[SqlJobParams] = None,
-    ) -> List[str]:
-        sql: List[str] = []
-        for table in table_chain:
-            with sql_client.with_staging_dataset(staging=True):
-                staging_table_name = sql_client.make_qualified_table_name(table["name"])
-            table_name = sql_client.make_qualified_table_name(table["name"])
-            sql.extend(
-                (
-                    f"DROP TABLE IF EXISTS {table_name};",
-                    f"CREATE TABLE {table_name} CLONE {staging_table_name};",
-                )
-            )
-        return sql
-
-
 class BigQueryClient(SqlJobClientWithStaging, SupportsStagingDestination):
     capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities()
 
@@ -190,13 +169,6 @@ def __init__(self, schema: Schema, config: BigQueryClientConfiguration) -> None:
     def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
         return [BigQueryMergeJob.from_table_chain(table_chain, self.sql_client)]
 
-    def _create_replace_followup_jobs(
-        self, table_chain: Sequence[TTableSchema]
-    ) -> List[NewLoadJob]:
-        if self.config.replace_strategy == "staging-optimized":
-            return [BigqueryStagingCopyJob.from_table_chain(table_chain, self.sql_client)]
-        return super()._create_replace_followup_jobs(table_chain)
-
     def restore_file_load(self, file_path: str) -> LoadJob:
         """Returns a completed SqlLoadJob or restored BigQueryLoadJob
 
@@ -280,9 +252,9 @@ def _get_table_update_sql(
             elif (c := partition_list[0])["data_type"] == "date":
                 sql[0] = f"{sql[0]}\nPARTITION BY {self.capabilities.escape_identifier(c['name'])}"
             elif (c := partition_list[0])["data_type"] == "timestamp":
-                sql[0] = (
-                    f"{sql[0]}\nPARTITION BY DATE({self.capabilities.escape_identifier(c['name'])})"
-                )
+                sql[
+                    0
+                ] = f"{sql[0]}\nPARTITION BY DATE({self.capabilities.escape_identifier(c['name'])})"
             # Automatic partitioning of an INT64 type requires us to be prescriptive - we treat the column as a UNIX timestamp.
             # This is due to the bounds requirement of GENERATE_ARRAY function for partitioning.
             # The 10,000 partitions limit makes it infeasible to cover the entire `bigint` range.
@@ -300,9 +272,7 @@ def _get_table_update_sql(
 
     def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
         name = self.capabilities.escape_identifier(c["name"])
-        return (
-            f"{name} {self.type_mapper.to_db_type(c, table_format)} {self._gen_not_null(c.get('nullable', True))}"
-        )
+        return f"{name} {self.type_mapper.to_db_type(c, table_format)} {self._gen_not_null(c.get('nullable', True))}"
 
     def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]:
         schema_table: TTableSchemaColumns = {}
diff --git a/dlt/destinations/impl/databricks/__init__.py b/dlt/destinations/impl/databricks/__init__.py
index 97836a8ce2..b2e79279d6 100644
--- a/dlt/destinations/impl/databricks/__init__.py
+++ b/dlt/destinations/impl/databricks/__init__.py
@@ -26,4 +26,5 @@ def capabilities() -> DestinationCapabilitiesContext:
     # caps.supports_transactions = False
     caps.alter_add_multi_column = True
     caps.supports_multiple_statements = False
+    caps.supports_clone_table = True
     return caps
diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py
index 4141ff33df..384daf82b0 100644
--- a/dlt/destinations/impl/databricks/databricks.py
+++ b/dlt/destinations/impl/databricks/databricks.py
@@ -28,7 +28,7 @@
 from dlt.destinations.impl.databricks import capabilities
 from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration
 from dlt.destinations.impl.databricks.sql_client import DatabricksSqlClient
-from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob, SqlJobParams
+from dlt.destinations.sql_jobs import SqlMergeJob, SqlJobParams
 from dlt.destinations.job_impl import NewReferenceJob
 from dlt.destinations.sql_client import SqlClientBase
 from dlt.destinations.type_mapping import TypeMapper
@@ -195,25 +195,6 @@ def exception(self) -> str:
         raise NotImplementedError()
 
 
-class DatabricksStagingCopyJob(SqlStagingCopyJob):
-    @classmethod
-    def generate_sql(
-        cls,
-        table_chain: Sequence[TTableSchema],
-        sql_client: SqlClientBase[Any],
-        params: Optional[SqlJobParams] = None,
-    ) -> List[str]:
-        sql: List[str] = []
-        for table in table_chain:
-            with sql_client.with_staging_dataset(staging=True):
-                staging_table_name = sql_client.make_qualified_table_name(table["name"])
-            table_name = sql_client.make_qualified_table_name(table["name"])
-            sql.append(f"DROP TABLE IF EXISTS {table_name};")
-            # recreate destination table with data cloned from staging table
-            sql.append(f"CREATE TABLE {table_name} CLONE {staging_table_name};")
-        return sql
-
-
 class DatabricksMergeJob(SqlMergeJob):
     @classmethod
     def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str:
@@ -272,13 +253,6 @@ def _make_add_column_sql(
         # Override because databricks requires multiple columns in a single ADD COLUMN clause
         return ["ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c) for c in new_columns)]
 
-    def _create_replace_followup_jobs(
-        self, table_chain: Sequence[TTableSchema]
-    ) -> List[NewLoadJob]:
-        if self.config.replace_strategy == "staging-optimized":
-            return [DatabricksStagingCopyJob.from_table_chain(table_chain, self.sql_client)]
-        return super()._create_replace_followup_jobs(table_chain)
-
     def _get_table_update_sql(
         self,
         table_name: str,
diff --git a/dlt/destinations/impl/snowflake/__init__.py b/dlt/destinations/impl/snowflake/__init__.py
index d6bebd3fdd..dde4d5a382 100644
--- a/dlt/destinations/impl/snowflake/__init__.py
+++ b/dlt/destinations/impl/snowflake/__init__.py
@@ -21,4 +21,5 @@ def capabilities() -> DestinationCapabilitiesContext:
     caps.is_max_text_data_type_length_in_bytes = True
     caps.supports_ddl_transactions = True
     caps.alter_add_multi_column = True
+    caps.supports_clone_table = True
     return caps
diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py
index 67df78c138..fb51ab9d36 100644
--- a/dlt/destinations/impl/snowflake/snowflake.py
+++ b/dlt/destinations/impl/snowflake/snowflake.py
@@ -27,7 +27,7 @@
 from dlt.destinations.impl.snowflake import capabilities
 from dlt.destinations.impl.snowflake.configuration import SnowflakeClientConfiguration
 from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient
-from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams
+from dlt.destinations.sql_jobs import SqlJobParams
 from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient
 from dlt.destinations.job_impl import NewReferenceJob
 from dlt.destinations.sql_client import SqlClientBase
@@ -175,13 +175,15 @@ def __init__(
                     f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE = TRUE,'
                     " AUTO_COMPRESS = FALSE"
                 )
-            client.execute_sql(f"""COPY INTO {qualified_table_name}
+            client.execute_sql(
+                f"""COPY INTO {qualified_table_name}
                 {from_clause}
                 {files_clause}
                 {credentials_clause}
                 FILE_FORMAT = {source_format}
                 MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE'
-                """)
+                """
+            )
             if stage_file_path and not keep_staged_files:
                 client.execute_sql(f"REMOVE {stage_file_path}")
 
@@ -192,25 +194,6 @@ def exception(self) -> str:
         raise NotImplementedError()
 
 
-class SnowflakeStagingCopyJob(SqlStagingCopyJob):
-    @classmethod
-    def generate_sql(
-        cls,
-        table_chain: Sequence[TTableSchema],
-        sql_client: SqlClientBase[Any],
-        params: Optional[SqlJobParams] = None,
-    ) -> List[str]:
-        sql: List[str] = []
-        for table in table_chain:
-            with sql_client.with_staging_dataset(staging=True):
-                staging_table_name = sql_client.make_qualified_table_name(table["name"])
-            table_name = sql_client.make_qualified_table_name(table["name"])
-            sql.append(f"DROP TABLE IF EXISTS {table_name};")
-            # recreate destination table with data cloned from staging table
-            sql.append(f"CREATE TABLE {table_name} CLONE {staging_table_name};")
-        return sql
-
-
 class SnowflakeClient(SqlJobClientWithStaging, SupportsStagingDestination):
     capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities()
 
@@ -250,13 +233,6 @@ def _make_add_column_sql(
             + ",\n".join(self._get_column_def_sql(c, table_format) for c in new_columns)
         ]
 
-    def _create_replace_followup_jobs(
-        self, table_chain: Sequence[TTableSchema]
-    ) -> List[NewLoadJob]:
-        if self.config.replace_strategy == "staging-optimized":
-            return [SnowflakeStagingCopyJob.from_table_chain(table_chain, self.sql_client)]
-        return super()._create_replace_followup_jobs(table_chain)
-
     def _get_table_update_sql(
         self,
         table_name: str,
diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py
index 899947313d..d0911d0bea 100644
--- a/dlt/destinations/sql_jobs.py
+++ b/dlt/destinations/sql_jobs.py
@@ -74,11 +74,28 @@ class SqlStagingCopyJob(SqlBaseJob):
     failed_text: str = "Tried to generate a staging copy sql job for the following tables:"
 
     @classmethod
-    def generate_sql(
+    def _generate_clone_sql(
         cls,
         table_chain: Sequence[TTableSchema],
         sql_client: SqlClientBase[Any],
-        params: Optional[SqlJobParams] = None,
+    ) -> List[str]:
+        """Drop and clone the table for supported destinations"""
+        sql: List[str] = []
+        for table in table_chain:
+            with sql_client.with_staging_dataset(staging=True):
+                staging_table_name = sql_client.make_qualified_table_name(table["name"])
+            table_name = sql_client.make_qualified_table_name(table["name"])
+            sql.append(f"DROP TABLE IF EXISTS {table_name};")
+            # recreate destination table with data cloned from staging table
+            sql.append(f"CREATE TABLE {table_name} CLONE {staging_table_name};")
+        return sql
+
+    @classmethod
+    def _generate_insert_sql(
+        cls,
+        table_chain: Sequence[TTableSchema],
+        sql_client: SqlClientBase[Any],
+        params: SqlJobParams = None,
     ) -> List[str]:
         sql: List[str] = []
         for table in table_chain:
@@ -98,6 +115,17 @@ def generate_sql(
                 )
         return sql
 
+    @classmethod
+    def generate_sql(
+        cls,
+        table_chain: Sequence[TTableSchema],
+        sql_client: SqlClientBase[Any],
+        params: SqlJobParams = None,
+    ) -> List[str]:
+        if params["replace"] and sql_client.capabilities.supports_clone_table:
+            return cls._generate_clone_sql(table_chain, sql_client)
+        return cls._generate_insert_sql(table_chain, sql_client, params)
+
 
 class SqlMergeJob(SqlBaseJob):
     """Generates a list of sql statements that merge the data from staging dataset into destination dataset."""
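
Net effect of the patch: the three destination-specific staging copy jobs are removed, and the shared `SqlStagingCopyJob.generate_sql` now dispatches on `params["replace"]` plus the new `supports_clone_table` capability, which the BigQuery, Databricks and Snowflake `capabilities()` functions enable. The sketch below only illustrates that dispatch and the SQL the clone path emits; `FakeCaps`, `replace_sql` and the table names are hypothetical stand-ins, and the INSERT fallback is simplified relative to `_generate_insert_sql`.

# Illustration only: hypothetical glue around the behavior added by this patch.
from typing import List


class FakeCaps:
    """Stand-in for DestinationCapabilitiesContext (not part of the patch)."""

    supports_clone_table = True  # what bigquery/databricks/snowflake now set


def replace_sql(table_name: str, staging_table_name: str, caps: FakeCaps) -> List[str]:
    # Mirrors the new dispatch: clone when the destination supports it,
    # otherwise fall back to an INSERT-based copy (simplified here).
    if caps.supports_clone_table:
        return [
            f"DROP TABLE IF EXISTS {table_name};",
            f"CREATE TABLE {table_name} CLONE {staging_table_name};",
        ]
    return [f"INSERT INTO {table_name} SELECT * FROM {staging_table_name};"]


print(replace_sql('"dataset"."events"', '"dataset_staging"."events"', FakeCaps()))

With the branch living in the shared job, a destination opts into `CREATE TABLE ... CLONE` by flipping a single capability flag instead of overriding `_create_replace_followup_jobs`.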