Skip to content

Commit

Permalink
Refresh modes with dropped_tables file
Browse files Browse the repository at this point in the history
  • Loading branch information
steinitzu committed Mar 7, 2024
1 parent 24181da commit 4270770
Show file tree
Hide file tree
Showing 11 changed files with 285 additions and 23 deletions.
4 changes: 4 additions & 0 deletions dlt/common/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
TypeVar,
TypedDict,
Mapping,
Literal,
)
from typing_extensions import NotRequired

Expand Down Expand Up @@ -47,6 +48,9 @@
from dlt.common.utils import RowCounts, merge_row_counts


TRefreshMode = Literal["full", "replace"]


class _StepInfo(NamedTuple):
pipeline: "SupportsPipeline"
loads_ids: List[str]
Expand Down
15 changes: 15 additions & 0 deletions dlt/common/storages/load_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ class PackageStorage:
PACKAGE_COMPLETED_FILE_NAME = ( # completed package marker file, currently only to store data with os.stat
"package_completed.json"
)
DROPPED_TABLES_FILE_NAME = "dropped_tables.json"

def __init__(self, storage: FileStorage, initial_state: TLoadPackageState) -> None:
"""Creates storage that manages load packages with root at `storage` and initial package state `initial_state`"""
Expand Down Expand Up @@ -381,6 +382,20 @@ def save_schema_updates(self, load_id: str, schema_update: TSchemaTables) -> Non
) as f:
json.dump(schema_update, f)

def save_dropped_tables(self, load_id: str, dropped_tables: Sequence[str]) -> None:
with self.storage.open_file(
os.path.join(load_id, PackageStorage.DROPPED_TABLES_FILE_NAME), mode="wb"
) as f:
json.dump(dropped_tables, f)

def load_dropped_tables(self, load_id: str) -> List[str]:
try:
return json.loads( # type: ignore[no-any-return]
self.storage.load(os.path.join(load_id, PackageStorage.DROPPED_TABLES_FILE_NAME))
)
except FileNotFoundError:
return []

#
# Get package info
#
Expand Down
2 changes: 2 additions & 0 deletions dlt/extract/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
SupportsPipeline,
WithStepInfo,
reset_resource_state,
TRefreshMode,
)
from dlt.common.runtime import signals
from dlt.common.runtime.collector import Collector, NULL_COLLECTOR
Expand Down Expand Up @@ -299,6 +300,7 @@ def _extract_single_source(
data_tables = {t["name"]: t for t in schema.data_tables(include_incomplete=False)}
tables_by_resources = utils.group_tables_by_resource(data_tables)
for resource in source.resources.selected.values():
# Truncate when write disposition is replace or refresh = 'replace'
if (
resource.write_disposition != "replace"
or resource.name in resources_with_items
Expand Down
69 changes: 65 additions & 4 deletions dlt/load/load.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import contextlib
from functools import reduce
import datetime # noqa: 251
from typing import Dict, List, Optional, Tuple, Set, Iterator, Iterable
from typing import Dict, List, Optional, Tuple, Set, Iterator, Iterable, Sequence
from concurrent.futures import Executor
import os
from copy import deepcopy

from dlt.common import sleep, logger
from dlt.common.configuration import with_config, known_sections
from dlt.common.configuration.accessors import config
from dlt.common.pipeline import LoadInfo, LoadMetrics, SupportsPipeline, WithStepInfo
from dlt.common.pipeline import LoadInfo, LoadMetrics, SupportsPipeline, WithStepInfo, TRefreshMode
from dlt.common.schema.utils import get_top_level_table
from dlt.common.schema.typing import TTableSchema
from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState
from dlt.common.runners import TRunMetrics, Runnable, workermethod, NullExecutor
from dlt.common.runtime.collector import Collector, NULL_COLLECTOR
Expand Down Expand Up @@ -71,6 +73,7 @@ def __init__(
self.pool = NullExecutor()
self.load_storage: LoadStorage = self.create_storage(is_storage_owner)
self._loaded_packages: List[LoadPackageInfo] = []
self._refreshed_tables: Set[str] = set()
super().__init__()

def create_storage(self, is_storage_owner: bool) -> LoadStorage:
Expand Down Expand Up @@ -335,18 +338,66 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False)
f"All jobs completed, archiving package {load_id} with aborted set to {aborted}"
)

def _refresh(self, dropped_tables: Sequence[str], schema: Schema) -> Tuple[Set[str], Set[str]]:
"""When using refresh mode, drop tables if possible.
Returns a set of tables for main destination and staging destination
that could not be dropped and should be truncated instead
"""
# Exclude tables already dropped in the same load
drop_tables = set(dropped_tables) - self._refreshed_tables
if not drop_tables:
return set(), set()
# Clone schema and remove tables from it
dropped_schema = deepcopy(schema)
for table_name in drop_tables:
# pop not del: The table may not actually be in the schema if it's not being loaded again
dropped_schema.tables.pop(table_name, None)
dropped_schema.bump_version()
trunc_dest: Set[str] = set()
trunc_staging: Set[str] = set()
# Drop from destination and replace stored schema so tables will be re-created before load
with self.get_destination_client(dropped_schema) as job_client:
# TODO: SupportsSql mixin
if hasattr(job_client, "drop_tables"):
job_client.drop_tables(*drop_tables, replace_schema=True)
else:
# Tables need to be truncated instead of dropped
trunc_dest = drop_tables

if self.staging_destination:
with self.get_staging_destination_client(dropped_schema) as staging_client:
if hasattr(staging_client, "drop_tables"):
staging_client.drop_tables(*drop_tables, replace_schema=True)
else:
trunc_staging = drop_tables
self._refreshed_tables.update(drop_tables) # Don't drop table again in same load
return trunc_dest, trunc_staging

def load_single_package(self, load_id: str, schema: Schema) -> None:
new_jobs = self.get_new_jobs_info(load_id)

dropped_tables = self.load_storage.normalized_packages.load_dropped_tables(load_id)
# Drop tables before loading if refresh mode is set
truncate_dest, truncate_staging = self._refresh(dropped_tables, schema)

# initialize analytical storage ie. create dataset required by passed schema
with self.get_destination_client(schema) as job_client:
if (expected_update := self.load_storage.begin_schema_update(load_id)) is not None:
# init job client
def should_truncate(table: TTableSchema) -> bool:
# When destination doesn't support dropping refreshed tables (i.e. not SQL based) they should be truncated
return (
job_client.should_truncate_table_before_load(table)
or table["name"] in truncate_dest
)

applied_update = init_client(
job_client,
schema,
new_jobs,
expected_update,
job_client.should_truncate_table_before_load,
# job_client.should_truncate_table_before_load,
should_truncate,
(
job_client.should_load_data_to_staging_dataset
if isinstance(job_client, WithStagingDataset)
Expand All @@ -360,13 +411,23 @@ def load_single_package(self, load_id: str, schema: Schema) -> None:
f"Job client for destination {self.destination.destination_type} does not"
" implement SupportsStagingDestination"
)

def should_truncate_staging(table: TTableSchema) -> bool:
return (
job_client.should_truncate_table_before_load_on_staging_destination(
table
)
or table["name"] in truncate_staging
)

with self.get_staging_destination_client(schema) as staging_client:
init_client(
staging_client,
schema,
new_jobs,
expected_update,
job_client.should_truncate_table_before_load_on_staging_destination,
# job_client.should_truncate_table_before_load_on_staging_destination,
should_truncate_staging,
job_client.should_load_data_to_staging_dataset_on_staging_destination,
)

Expand Down
3 changes: 3 additions & 0 deletions dlt/normalize/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,9 @@ def spool_files(
self.load_storage.new_packages.save_schema_updates(
load_id, merge_schema_updates(schema_updates)
)
self.load_storage.new_packages.save_dropped_tables(
load_id, self.normalize_storage.extracted_packages.load_dropped_tables(load_id)
)
# files must be renamed and deleted together so do not attempt that when process is about to be terminated
signals.raise_if_signalled()
logger.info("Committing storage, do not kill this process")
Expand Down
3 changes: 1 addition & 2 deletions dlt/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@
from dlt.common.configuration.container import Container
from dlt.common.configuration.inject import get_orig_args, last_config
from dlt.common.destination import TLoaderFileFormat, Destination, TDestinationReferenceArg
from dlt.common.pipeline import LoadInfo, PipelineContext, get_dlt_pipelines_dir
from dlt.common.pipeline import LoadInfo, PipelineContext, get_dlt_pipelines_dir, TRefreshMode

from dlt.pipeline.configuration import PipelineConfiguration, ensure_correct_pipeline_kwargs
from dlt.pipeline.pipeline import Pipeline
from dlt.pipeline.progress import _from_name as collector_from_name, TCollectorArg, _NULL_COLLECTOR
from dlt.pipeline.warnings import credentials_argument_deprecated, full_refresh_argument_deprecated
from dlt.pipeline.typing import TRefreshMode


@overload
Expand Down
2 changes: 1 addition & 1 deletion dlt/pipeline/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dlt.common.typing import AnyFun, TSecretValue
from dlt.common.utils import digest256
from dlt.common.data_writers import TLoaderFileFormat
from dlt.pipeline.typing import TRefreshMode
from dlt.common.pipeline import TRefreshMode


@configspec
Expand Down
30 changes: 21 additions & 9 deletions dlt/pipeline/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def __init__(
state_paths: TAnyJsonPath = (),
drop_all: bool = False,
state_only: bool = False,
tables_only: bool = False,
extract_only: bool = False,
) -> None:
"""
Args:
Expand All @@ -109,7 +111,10 @@ def __init__(
state_paths: JSON path(s) relative to the source state to drop
drop_all: Drop all resources and tables in the schema (supersedes `resources` list)
state_only: Drop only state, not tables
extract_only: Only apply changes locally, but do not normalize and load to destination
"""
self.extract_only = extract_only
self.pipeline = pipeline
if isinstance(resources, str):
resources = [resources]
Expand All @@ -122,7 +127,7 @@ def __init__(
self.schema = pipeline.schemas[schema_name or pipeline.default_schema_name].clone()
self.schema_tables = self.schema.tables
self.drop_tables = not state_only
self.drop_state = True
self.drop_state = not tables_only
self.state_paths_to_drop = compile_paths(state_paths)

resources = set(resources)
Expand Down Expand Up @@ -178,13 +183,14 @@ def is_empty(self) -> bool:
and len(self.info["resource_states"]) == 0
)

def _drop_destination_tables(self) -> None:
def _drop_destination_tables(self, allow_schema_tables: bool = False) -> None:
table_names = [tbl["name"] for tbl in self.tables_to_drop]
for table_name in table_names:
assert table_name not in self.schema._schema_tables, (
f"You are dropping table {table_name} in {self.schema.name} but it is still present"
" in the schema"
)
if not allow_schema_tables:
for table_name in table_names:
assert table_name not in self.schema._schema_tables, (
f"You are dropping table {table_name} in {self.schema.name} but it is still"
" present in the schema"
)
with self.pipeline._sql_job_client(self.schema) as client:
client.drop_tables(*table_names, replace_schema=True)
# also delete staging but ignore if staging does not exist
Expand Down Expand Up @@ -241,6 +247,9 @@ def _drop_state_keys(self) -> None:
except ContextDefaultCannotBeCreated:
pass

def _save_local_schema(self) -> None:
self.pipeline.schemas.save_schema(self.schema)

def __call__(self) -> None:
if (
self.pipeline.has_pending_data
Expand All @@ -255,12 +264,15 @@ def __call__(self) -> None:

if self.drop_tables:
self._delete_pipeline_tables()
self._drop_destination_tables()
if not self.extract_only:
self._drop_destination_tables()
if self.drop_tables:
self.pipeline.schemas.save_schema(self.schema)
self._save_local_schema()
if self.drop_state:
self._drop_state_keys()
# Send updated state to destination
if self.extract_only:
return
self.pipeline.normalize()
try:
self.pipeline.load(raise_on_failed_jobs=True)
Expand Down
39 changes: 34 additions & 5 deletions dlt/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
StateInjectableContext,
TStepMetrics,
WithStepInfo,
TRefreshMode,
)
from dlt.common.schema import Schema
from dlt.common.utils import is_interactive
Expand All @@ -113,6 +114,7 @@
PipelineNotActive,
PipelineStepFailed,
SqlClientNotAvailable,
PipelineNeverRan,
)
from dlt.pipeline.trace import (
PipelineTrace,
Expand All @@ -124,7 +126,7 @@
end_trace_step,
end_trace,
)
from dlt.pipeline.typing import TPipelineStep, TRefreshMode
from dlt.pipeline.typing import TPipelineStep
from dlt.pipeline.state_sync import (
STATE_ENGINE_VERSION,
bump_version_if_modified,
Expand All @@ -135,7 +137,7 @@
json_decode_state,
)
from dlt.pipeline.warnings import credentials_argument_deprecated
from dlt.pipeline.helpers import drop as drop_command
from dlt.pipeline.helpers import DropCommand


def with_state_sync(may_extract_state: bool = False) -> Callable[[TFun], TFun]:
Expand Down Expand Up @@ -390,8 +392,6 @@ def extract(
schema_contract: TSchemaContract = None,
) -> ExtractInfo:
"""Extracts the `data` and prepare it for the normalization. Does not require destination or credentials to be configured. See `run` method for the arguments' description."""
if self.refresh == "full":
drop_command(self, drop_all=True, state_paths="*")

# create extract storage to which all the sources will be extracted
extract_step = Extract(
Expand All @@ -416,7 +416,36 @@ def extract(
):
if source.exhausted:
raise SourceExhausted(source.name)
self._extract_source(extract_step, source, max_parallel_items, workers)
dropped_tables = []
if not self.first_run:
if self.refresh == "full":
# Drop all tables from all resources and all source state paths
d = DropCommand(
self,
drop_all=True,
extract_only=True, # Only modify local state/schema, destination drop tables is done in load step
state_paths="*",
schema_name=source.schema.name,
)
dropped_tables = d.info["tables"]
d()
elif self.refresh == "replace":
# Drop tables from resources being currently extracted
d = DropCommand(
self,
resources=source.resources.extracted,
extract_only=True,
schema_name=source.schema.name,
)
dropped_tables = d.info["tables"]
d()
load_id = self._extract_source(
extract_step, source, max_parallel_items, workers
)
# Save the tables that were dropped locally (to be dropped from destination in load step)
extract_step.extract_storage.new_packages.save_dropped_tables(
load_id, dropped_tables
)
# extract state
if self.config.restore_from_destination:
# this will update state version hash so it will not be extracted again by with_state_sync
Expand Down
2 changes: 0 additions & 2 deletions dlt/pipeline/typing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Literal

TPipelineStep = Literal["sync", "extract", "normalize", "load"]

TRefreshMode = Literal["full", "replace"]

0 comments on commit 4270770

Please sign in to comment.