diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py
index df221ec703..0ae4148846 100644
--- a/dlt/common/pipeline.py
+++ b/dlt/common/pipeline.py
@@ -19,6 +19,7 @@
     TypeVar,
     TypedDict,
     Mapping,
+    Literal,
 )
 
 from typing_extensions import NotRequired
@@ -47,6 +48,9 @@
 from dlt.common.utils import RowCounts, merge_row_counts
 
 
+TRefreshMode = Literal["full", "replace"]
+
+
 class _StepInfo(NamedTuple):
     pipeline: "SupportsPipeline"
     loads_ids: List[str]
diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py
index 63409aa878..c53d576aaa 100644
--- a/dlt/common/storages/load_package.py
+++ b/dlt/common/storages/load_package.py
@@ -201,6 +201,7 @@ class PackageStorage:
     PACKAGE_COMPLETED_FILE_NAME = (  # completed package marker file, currently only to store data with os.stat
         "package_completed.json"
     )
+    DROPPED_TABLES_FILE_NAME = "dropped_tables.json"
 
     def __init__(self, storage: FileStorage, initial_state: TLoadPackageState) -> None:
         """Creates storage that manages load packages with root at `storage` and initial package state `initial_state`"""
@@ -381,6 +382,20 @@ def save_schema_updates(self, load_id: str, schema_update: TSchemaTables) -> Non
         ) as f:
             json.dump(schema_update, f)
 
+    def save_dropped_tables(self, load_id: str, dropped_tables: Sequence[str]) -> None:
+        with self.storage.open_file(
+            os.path.join(load_id, PackageStorage.DROPPED_TABLES_FILE_NAME), mode="wb"
+        ) as f:
+            json.dump(dropped_tables, f)
+
+    def load_dropped_tables(self, load_id: str) -> List[str]:
+        try:
+            return json.loads(  # type: ignore[no-any-return]
+                self.storage.load(os.path.join(load_id, PackageStorage.DROPPED_TABLES_FILE_NAME))
+            )
+        except FileNotFoundError:
+            return []
+
     #
     # Get package info
     #
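Note: a minimal sketch of the round trip that the new `save_dropped_tables` / `load_dropped_tables` methods implement above, using plain files instead of dlt's `FileStorage`; the `package_root` layout and the standalone helper names here are illustrative assumptions, not dlt API.

import json
import os
import tempfile
from typing import List, Sequence

DROPPED_TABLES_FILE_NAME = "dropped_tables.json"


def save_dropped_tables(package_root: str, load_id: str, dropped_tables: Sequence[str]) -> None:
    # one JSON array of table names per load package
    path = os.path.join(package_root, load_id, DROPPED_TABLES_FILE_NAME)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(list(dropped_tables), f)


def load_dropped_tables(package_root: str, load_id: str) -> List[str]:
    # a missing marker file simply means nothing was dropped for this package
    try:
        with open(os.path.join(package_root, load_id, DROPPED_TABLES_FILE_NAME), encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return []


if __name__ == "__main__":
    root = tempfile.mkdtemp()
    save_dropped_tables(root, "1700000000.101", ["some_data_3"])
    assert load_dropped_tables(root, "1700000000.101") == ["some_data_3"]
    assert load_dropped_tables(root, "1700000000.102") == []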
diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py
index 2ff813a2de..76bfaece0d 100644
--- a/dlt/extract/extract.py
+++ b/dlt/extract/extract.py
@@ -17,6 +17,7 @@
     SupportsPipeline,
     WithStepInfo,
     reset_resource_state,
+    TRefreshMode,
 )
 from dlt.common.runtime import signals
 from dlt.common.runtime.collector import Collector, NULL_COLLECTOR
@@ -299,6 +300,7 @@ def _extract_single_source(
         data_tables = {t["name"]: t for t in schema.data_tables(include_incomplete=False)}
         tables_by_resources = utils.group_tables_by_resource(data_tables)
         for resource in source.resources.selected.values():
+            # Truncate when write disposition is replace or refresh = 'replace'
             if (
                 resource.write_disposition != "replace"
                 or resource.name in resources_with_items
diff --git a/dlt/load/load.py b/dlt/load/load.py
index 050e7bce67..a09eb4aead 100644
--- a/dlt/load/load.py
+++ b/dlt/load/load.py
@@ -1,15 +1,17 @@
 import contextlib
 from functools import reduce
 import datetime  # noqa: 251
-from typing import Dict, List, Optional, Tuple, Set, Iterator, Iterable
+from typing import Dict, List, Optional, Tuple, Set, Iterator, Iterable, Sequence
 from concurrent.futures import Executor
 import os
+from copy import deepcopy
 
 from dlt.common import sleep, logger
 from dlt.common.configuration import with_config, known_sections
 from dlt.common.configuration.accessors import config
-from dlt.common.pipeline import LoadInfo, LoadMetrics, SupportsPipeline, WithStepInfo
+from dlt.common.pipeline import LoadInfo, LoadMetrics, SupportsPipeline, WithStepInfo, TRefreshMode
 from dlt.common.schema.utils import get_top_level_table
+from dlt.common.schema.typing import TTableSchema
 from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState
 from dlt.common.runners import TRunMetrics, Runnable, workermethod, NullExecutor
 from dlt.common.runtime.collector import Collector, NULL_COLLECTOR
@@ -71,6 +73,7 @@ def __init__(
         self.pool = NullExecutor()
         self.load_storage: LoadStorage = self.create_storage(is_storage_owner)
         self._loaded_packages: List[LoadPackageInfo] = []
+        self._refreshed_tables: Set[str] = set()
         super().__init__()
 
     def create_storage(self, is_storage_owner: bool) -> LoadStorage:
@@ -335,18 +338,66 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False)
             f"All jobs completed, archiving package {load_id} with aborted set to {aborted}"
         )
 
+    def _refresh(self, dropped_tables: Sequence[str], schema: Schema) -> Tuple[Set[str], Set[str]]:
+        """When using refresh mode, drop tables if possible.
+        Returns a set of tables for main destination and staging destination
+        that could not be dropped and should be truncated instead
+        """
+        # Exclude tables already dropped in the same load
+        drop_tables = set(dropped_tables) - self._refreshed_tables
+        if not drop_tables:
+            return set(), set()
+        # Clone schema and remove tables from it
+        dropped_schema = deepcopy(schema)
+        for table_name in drop_tables:
+            # pop not del: The table may not actually be in the schema if it's not being loaded again
+            dropped_schema.tables.pop(table_name, None)
+        dropped_schema.bump_version()
+        trunc_dest: Set[str] = set()
+        trunc_staging: Set[str] = set()
+        # Drop from destination and replace stored schema so tables will be re-created before load
+        with self.get_destination_client(dropped_schema) as job_client:
+            # TODO: SupportsSql mixin
+            if hasattr(job_client, "drop_tables"):
+                job_client.drop_tables(*drop_tables, replace_schema=True)
+            else:
+                # Tables need to be truncated instead of dropped
+                trunc_dest = drop_tables
+
+        if self.staging_destination:
+            with self.get_staging_destination_client(dropped_schema) as staging_client:
+                if hasattr(staging_client, "drop_tables"):
+                    staging_client.drop_tables(*drop_tables, replace_schema=True)
+                else:
+                    trunc_staging = drop_tables
+        self._refreshed_tables.update(drop_tables)  # Don't drop table again in same load
+        return trunc_dest, trunc_staging
+
     def load_single_package(self, load_id: str, schema: Schema) -> None:
         new_jobs = self.get_new_jobs_info(load_id)
+
+        dropped_tables = self.load_storage.normalized_packages.load_dropped_tables(load_id)
+        # Drop tables before loading if refresh mode is set
+        truncate_dest, truncate_staging = self._refresh(dropped_tables, schema)
+
         # initialize analytical storage ie. create dataset required by passed schema
         with self.get_destination_client(schema) as job_client:
             if (expected_update := self.load_storage.begin_schema_update(load_id)) is not None:
                 # init job client
+                def should_truncate(table: TTableSchema) -> bool:
+                    # When destination doesn't support dropping refreshed tables (i.e. not SQL based) they should be truncated
+                    return (
+                        job_client.should_truncate_table_before_load(table)
+                        or table["name"] in truncate_dest
+                    )
+
                 applied_update = init_client(
                     job_client,
                     schema,
                     new_jobs,
                     expected_update,
-                    job_client.should_truncate_table_before_load,
+                    # job_client.should_truncate_table_before_load,
+                    should_truncate,
                     (
                         job_client.should_load_data_to_staging_dataset
                         if isinstance(job_client, WithStagingDataset)
@@ -360,13 +411,23 @@ def load_single_package(self, load_id: str, schema: Schema) -> None:
                         f"Job client for destination {self.destination.destination_type} does not"
                         " implement SupportsStagingDestination"
                     )
+
+                    def should_truncate_staging(table: TTableSchema) -> bool:
+                        return (
+                            job_client.should_truncate_table_before_load_on_staging_destination(
+                                table
+                            )
+                            or table["name"] in truncate_staging
+                        )
+
                     with self.get_staging_destination_client(schema) as staging_client:
                         init_client(
                             staging_client,
                             schema,
                             new_jobs,
                             expected_update,
-                            job_client.should_truncate_table_before_load_on_staging_destination,
+                            # job_client.should_truncate_table_before_load_on_staging_destination,
+                            should_truncate_staging,
                             job_client.should_load_data_to_staging_dataset_on_staging_destination,
                         )
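Note: the core of `_refresh` above is a capability check. A distilled sketch of just that decision follows; the stub client classes are made up for illustration, the real job clients come from dlt's destination implementations.

from typing import Iterable, Set


class SqlClientStub:
    """Stands in for a SQL-based job client that can drop tables directly."""

    def __init__(self) -> None:
        self.dropped: Set[str] = set()

    def drop_tables(self, *tables: str, replace_schema: bool = True) -> None:
        self.dropped.update(tables)


class BucketClientStub:
    """Stands in for a destination without drop_tables; its tables must be truncated instead."""


def drop_or_mark_for_truncate(client: object, tables: Iterable[str]) -> Set[str]:
    """Drop `tables` when the client supports it; return the ones to truncate instead."""
    tables = set(tables)
    if not tables:
        return set()
    if hasattr(client, "drop_tables"):
        client.drop_tables(*tables, replace_schema=True)
        return set()
    return tables


assert drop_or_mark_for_truncate(SqlClientStub(), ["some_data_1"]) == set()
assert drop_or_mark_for_truncate(BucketClientStub(), ["some_data_1"]) == {"some_data_1"}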
diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py
index c5762af680..b2965b860f 100644
--- a/dlt/normalize/normalize.py
+++ b/dlt/normalize/normalize.py
@@ -320,6 +320,9 @@ def spool_files(
         self.load_storage.new_packages.save_schema_updates(
             load_id, merge_schema_updates(schema_updates)
         )
+        self.load_storage.new_packages.save_dropped_tables(
+            load_id, self.normalize_storage.extracted_packages.load_dropped_tables(load_id)
+        )
         # files must be renamed and deleted together so do not attempt that when process is about to be terminated
         signals.raise_if_signalled()
         logger.info("Committing storage, do not kill this process")
diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py
index 0281ce31ef..16314c0ec5 100644
--- a/dlt/pipeline/__init__.py
+++ b/dlt/pipeline/__init__.py
@@ -8,13 +8,12 @@
 from dlt.common.configuration.container import Container
 from dlt.common.configuration.inject import get_orig_args, last_config
 from dlt.common.destination import TLoaderFileFormat, Destination, TDestinationReferenceArg
-from dlt.common.pipeline import LoadInfo, PipelineContext, get_dlt_pipelines_dir
+from dlt.common.pipeline import LoadInfo, PipelineContext, get_dlt_pipelines_dir, TRefreshMode
 
 from dlt.pipeline.configuration import PipelineConfiguration, ensure_correct_pipeline_kwargs
 from dlt.pipeline.pipeline import Pipeline
 from dlt.pipeline.progress import _from_name as collector_from_name, TCollectorArg, _NULL_COLLECTOR
 from dlt.pipeline.warnings import credentials_argument_deprecated, full_refresh_argument_deprecated
-from dlt.pipeline.typing import TRefreshMode
 
 
 @overload
diff --git a/dlt/pipeline/configuration.py b/dlt/pipeline/configuration.py
index 361f627dfc..0e79e462f3 100644
--- a/dlt/pipeline/configuration.py
+++ b/dlt/pipeline/configuration.py
@@ -5,7 +5,7 @@
 from dlt.common.typing import AnyFun, TSecretValue
 from dlt.common.utils import digest256
 from dlt.common.data_writers import TLoaderFileFormat
-from dlt.pipeline.typing import TRefreshMode
+from dlt.common.pipeline import TRefreshMode
 
 
 @configspec
diff --git a/dlt/pipeline/helpers.py b/dlt/pipeline/helpers.py
index e5c2e25cf8..5ee25808d8 100644
--- a/dlt/pipeline/helpers.py
+++ b/dlt/pipeline/helpers.py
@@ -100,6 +100,8 @@ def __init__(
         state_paths: TAnyJsonPath = (),
         drop_all: bool = False,
         state_only: bool = False,
+        tables_only: bool = False,
+        extract_only: bool = False,
     ) -> None:
         """
         Args:
@@ -109,7 +111,10 @@ def __init__(
             state_paths: JSON path(s) relative to the source state to drop
             drop_all: Drop all resources and tables in the schema (supersedes `resources` list)
             state_only: Drop only state, not tables
+            extract_only: Only apply changes locally, but do not normalize and load to destination
+
         """
+        self.extract_only = extract_only
         self.pipeline = pipeline
         if isinstance(resources, str):
             resources = [resources]
@@ -122,7 +127,7 @@ def __init__(
         self.schema = pipeline.schemas[schema_name or pipeline.default_schema_name].clone()
         self.schema_tables = self.schema.tables
         self.drop_tables = not state_only
-        self.drop_state = True
+        self.drop_state = not tables_only
         self.state_paths_to_drop = compile_paths(state_paths)
 
         resources = set(resources)
@@ -178,13 +183,14 @@ def is_empty(self) -> bool:
             and len(self.info["resource_states"]) == 0
         )
 
-    def _drop_destination_tables(self) -> None:
+    def _drop_destination_tables(self, allow_schema_tables: bool = False) -> None:
         table_names = [tbl["name"] for tbl in self.tables_to_drop]
-        for table_name in table_names:
-            assert table_name not in self.schema._schema_tables, (
-                f"You are dropping table {table_name} in {self.schema.name} but it is still present"
-                " in the schema"
-            )
+        if not allow_schema_tables:
+            for table_name in table_names:
+                assert table_name not in self.schema._schema_tables, (
+                    f"You are dropping table {table_name} in {self.schema.name} but it is still"
+                    " present in the schema"
+                )
         with self.pipeline._sql_job_client(self.schema) as client:
             client.drop_tables(*table_names, replace_schema=True)
             # also delete staging but ignore if staging does not exist
@@ -241,6 +247,9 @@ def _drop_state_keys(self) -> None:
         except ContextDefaultCannotBeCreated:
             pass
 
+    def _save_local_schema(self) -> None:
+        self.pipeline.schemas.save_schema(self.schema)
+
     def __call__(self) -> None:
         if (
             self.pipeline.has_pending_data
@@ -255,12 +264,15 @@ def __call__(self) -> None:
 
         if self.drop_tables:
             self._delete_pipeline_tables()
-            self._drop_destination_tables()
+            if not self.extract_only:
+                self._drop_destination_tables()
         if self.drop_tables:
-            self.pipeline.schemas.save_schema(self.schema)
+            self._save_local_schema()
         if self.drop_state:
             self._drop_state_keys()
         # Send updated state to destination
+        if self.extract_only:
+            return
         self.pipeline.normalize()
         try:
             self.pipeline.load(raise_on_failed_jobs=True)
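Note: a sketch of how the new `extract_only` flag and the relaxed `drop_state` are meant to be combined, mirroring the calls that `Pipeline.extract()` makes further down in this diff; `pipeline` is assumed to be an existing, previously-run `dlt.Pipeline`, and the wrapper function name is made up for illustration.

from typing import List, Optional, Sequence

import dlt
from dlt.pipeline.helpers import DropCommand


def refresh_locally(
    pipeline: dlt.Pipeline, mode: str, extracted_resources: Optional[Sequence[str]] = None
) -> List[str]:
    if mode == "full":
        # wipe all tables and all source state paths, but only in the local schema/state
        d = DropCommand(pipeline, drop_all=True, extract_only=True, state_paths="*")
    else:  # "replace"
        # wipe only the resources being extracted right now; source-level state is kept
        d = DropCommand(pipeline, resources=extracted_resources or [], extract_only=True)
    dropped_tables = d.info["tables"]
    d()  # applies changes locally; destination tables are dropped/truncated later, in the load step
    return dropped_tables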
diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py
index 87cdb33727..db59cb4ad8 100644
--- a/dlt/pipeline/pipeline.py
+++ b/dlt/pipeline/pipeline.py
@@ -88,6 +88,7 @@
     StateInjectableContext,
     TStepMetrics,
     WithStepInfo,
+    TRefreshMode,
 )
 from dlt.common.schema import Schema
 from dlt.common.utils import is_interactive
@@ -113,6 +114,7 @@
     PipelineNotActive,
     PipelineStepFailed,
     SqlClientNotAvailable,
+    PipelineNeverRan,
 )
 from dlt.pipeline.trace import (
     PipelineTrace,
@@ -124,7 +126,7 @@
     end_trace_step,
     end_trace,
 )
-from dlt.pipeline.typing import TPipelineStep, TRefreshMode
+from dlt.pipeline.typing import TPipelineStep
 from dlt.pipeline.state_sync import (
     STATE_ENGINE_VERSION,
     bump_version_if_modified,
@@ -135,7 +137,7 @@
     json_decode_state,
 )
 from dlt.pipeline.warnings import credentials_argument_deprecated
-from dlt.pipeline.helpers import drop as drop_command
+from dlt.pipeline.helpers import DropCommand
 
 
 def with_state_sync(may_extract_state: bool = False) -> Callable[[TFun], TFun]:
@@ -390,8 +392,6 @@ def extract(
         schema_contract: TSchemaContract = None,
     ) -> ExtractInfo:
         """Extracts the `data` and prepare it for the normalization. Does not require destination or credentials to be configured. See `run` method for the arguments' description."""
-        if self.refresh == "full":
-            drop_command(self, drop_all=True, state_paths="*")
 
         # create extract storage to which all the sources will be extracted
         extract_step = Extract(
@@ -416,7 +416,36 @@ def extract(
                 ):
                     if source.exhausted:
                         raise SourceExhausted(source.name)
-                    self._extract_source(extract_step, source, max_parallel_items, workers)
+                    dropped_tables = []
+                    if not self.first_run:
+                        if self.refresh == "full":
+                            # Drop all tables from all resources and all source state paths
+                            d = DropCommand(
+                                self,
+                                drop_all=True,
+                                extract_only=True,  # Only modify local state/schema; destination tables are dropped in the load step
+                                state_paths="*",
+                                schema_name=source.schema.name,
+                            )
+                            dropped_tables = d.info["tables"]
+                            d()
+                        elif self.refresh == "replace":
+                            # Drop tables from resources being currently extracted
+                            d = DropCommand(
+                                self,
+                                resources=source.resources.extracted,
+                                extract_only=True,
+                                schema_name=source.schema.name,
+                            )
+                            dropped_tables = d.info["tables"]
+                            d()
+                    load_id = self._extract_source(
+                        extract_step, source, max_parallel_items, workers
+                    )
+                    # Save the tables that were dropped locally (to be dropped from destination in load step)
+                    extract_step.extract_storage.new_packages.save_dropped_tables(
+                        load_id, dropped_tables
+                    )
                     # extract state
                     if self.config.restore_from_destination:
                         # this will update state version hash so it will not be extracted again by with_state_sync
diff --git a/dlt/pipeline/typing.py b/dlt/pipeline/typing.py
index ec0eca4685..f0192a504d 100644
--- a/dlt/pipeline/typing.py
+++ b/dlt/pipeline/typing.py
@@ -1,5 +1,3 @@
 from typing import Literal
 
 TPipelineStep = Literal["sync", "extract", "normalize", "load"]
-
-TRefreshMode = Literal["full", "replace"]
"resource_value_5" + yield {"id": 5, "name": "Jack"} + yield {"id": 6, "name": "Jill"} + + return [some_data_1, some_data_2, some_data_3] + + # First run pipeline with load to destination so tables are created + pipeline = dlt.pipeline( + "refresh_full_test", destination="duckdb", refresh="full", dataset_name="refresh_full_test" + ) + + info = pipeline.run(my_source()) + assert_load_info(info) + + # Second run of pipeline with only selected resources + first_run = False + info = pipeline.run(my_source().with_resources("some_data_1", "some_data_2")) + + # Confirm resource tables not selected on second run got wiped + with pytest.raises(DatabaseUndefinedRelation): + with pipeline.sql_client() as client: + result = client.execute_sql("SELECT * FROM some_data_3") + + with pipeline.sql_client() as client: + result = client.execute_sql("SELECT id FROM some_data_1 ORDER BY id") + assert result == [(1,), (2,)] + + +def test_refresh_replace(): + first_run = True + + @dlt.source + def my_source(): + @dlt.resource + def some_data_1(): + # Set some source and resource state + state = dlt.state() + if first_run: + dlt.state()["source_key_1"] = "source_value_1" + resource_state("some_data_1")["resource_key_1"] = "resource_value_1" + resource_state("some_data_1")["resource_key_2"] = "resource_value_2" + else: + # State is cleared for all resources on second run + assert "source_key_1" in dlt.state() + assert "source_key_2" in dlt.state() + assert "source_key_3" in dlt.state() + # Resource 3 is not wiped + assert dlt.state()["resources"] == { + "some_data_3": {"resource_key_5": "resource_value_5"} + } + yield {"id": 1, "name": "John"} + yield {"id": 2, "name": "Jane"} + + @dlt.resource + def some_data_2(): + if first_run: + dlt.state()["source_key_2"] = "source_value_2" + resource_state("some_data_2")["resource_key_3"] = "resource_value_3" + resource_state("some_data_2")["resource_key_4"] = "resource_value_4" + yield {"id": 3, "name": "Joe"} + yield {"id": 4, "name": "Jill"} + + @dlt.resource + def some_data_3(): + if first_run: + dlt.state()["source_key_3"] = "source_value_3" + resource_state("some_data_3")["resource_key_5"] = "resource_value_5" + yield {"id": 5, "name": "Jack"} + yield {"id": 6, "name": "Jill"} + + return [some_data_1, some_data_2, some_data_3] + + # First run pipeline with load to destination so tables are created + pipeline = dlt.pipeline( + "refresh_full_test", + destination="duckdb", + refresh="replace", + dataset_name="refresh_full_test", + ) + + info = pipeline.run(my_source()) + assert_load_info(info) + + # Second run of pipeline with only selected resources + first_run = False + info = pipeline.run(my_source().with_resources("some_data_1", "some_data_2")) + + # Confirm resource tables not selected on second run got wiped + with pipeline.sql_client() as client: + result = client.execute_sql("SELECT id FROM some_data_3 ORDER BY id") + assert result == [(5,), (6,)] + + with pipeline.sql_client() as client: + result = client.execute_sql("SELECT id FROM some_data_1 ORDER BY id") + assert result == [(1,), (2,)]