
Use drop schema in init_client (TODO: error)
steinitzu committed Apr 6, 2024
1 parent 187d1b0 commit a52b45f
Showing 5 changed files with 120 additions and 62 deletions.
16 changes: 14 additions & 2 deletions dlt/extract/extract.py
@@ -394,10 +394,22 @@ def extract(
state_paths="*" if self.refresh == "drop_dataset" else [],
)
_state.update(new_state)
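# build a reduced schema containing only the dropped tables plus the dlt tables,
# so the load step knows exactly what to drop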
drop_schema = source.schema.clone()
if drop_info["tables"]:
drop_tables = {
key: table
for key, table in source.schema.tables.items()
if table["name"] in drop_info["tables"]
or table["name"] in drop_schema.dlt_table_names()
}

drop_schema.tables.clear()
drop_schema.tables.update(drop_tables)
load_package.state["drop_schema"] = drop_schema.to_dict()
source.schema.tables.clear()
source.schema.tables.update(new_schema.tables)
dropped_tables = load_package.state.setdefault("dropped_tables", [])
dropped_tables.extend(drop_info["tables"])
# dropped_tables = load_package.state.setdefault("dropped_tables", [])
# dropped_tables.extend(drop_info["tables"])

# reset resource states, the `extracted` list contains all the explicit resources and all their parents
for resource in source.resources.extracted.values():
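For orientation, here is a minimal sketch of the table filter introduced above, run against an invented schema dict; `drop_info`, the table names, and `dlt_table_names` are stand-ins for the real objects produced during extraction.

```python
# Invented data: the real schema and drop_info come from the extract step.
source_tables = {
    "users": {"name": "users"},
    "orders": {"name": "orders"},
    "_dlt_loads": {"name": "_dlt_loads"},      # dlt internal table
    "_dlt_version": {"name": "_dlt_version"},  # dlt internal table
}
drop_info = {"tables": ["orders"]}
dlt_table_names = {"_dlt_loads", "_dlt_version"}

# Same filter as in extract(): keep only the dropped tables and the dlt tables
drop_tables = {
    key: table
    for key, table in source_tables.items()
    if table["name"] in drop_info["tables"] or table["name"] in dlt_table_names
}
assert set(drop_tables) == {"orders", "_dlt_loads", "_dlt_version"}
```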
122 changes: 64 additions & 58 deletions dlt/load/load.py
@@ -68,6 +68,7 @@ def __init__(
config: LoaderConfiguration = config.value,
initial_client_config: DestinationClientConfiguration = config.value,
initial_staging_client_config: DestinationClientConfiguration = config.value,
refresh: Optional[TRefreshMode] = None,
) -> None:
self.config = config
self.collector = collector
@@ -79,7 +80,7 @@ def __init__(
self.pool = NullExecutor()
self.load_storage: LoadStorage = self.create_storage(is_storage_owner)
self._loaded_packages: List[LoadPackageInfo] = []
self._refreshed_tables: Set[str] = set()
self.refresh = refresh
super().__init__()

def create_storage(self, is_storage_owner: bool) -> LoadStorage:
@@ -344,71 +345,75 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False)
f"All jobs completed, archiving package {load_id} with aborted set to {aborted}"
)

def _refresh(self, dropped_tables: Sequence[str], schema: Schema) -> Tuple[Set[str], Set[str]]:
"""When using refresh mode, drop tables if possible.
Returns a set of tables for main destination and staging destination
that could not be dropped and should be truncated instead
"""
# Exclude tables already dropped in the same load
drop_tables = set(dropped_tables) - self._refreshed_tables
if not drop_tables:
return set(), set()
# Clone schema and remove tables from it
dropped_schema = deepcopy(schema)
for table_name in drop_tables:
# pop not del: The table may not actually be in the schema if it's not being loaded again
dropped_schema.tables.pop(table_name, None)
dropped_schema._bump_version()
trunc_dest: Set[str] = set()
trunc_staging: Set[str] = set()
# Drop from destination and replace stored schema so tables will be re-created before load
with self.get_destination_client(dropped_schema) as job_client:
# TODO: SupportsSql mixin
if hasattr(job_client, "drop_tables"):
job_client.drop_tables(*drop_tables, replace_schema=True)
else:
# Tables need to be truncated instead of dropped
trunc_dest = drop_tables

if self.staging_destination:
with self.get_staging_destination_client(dropped_schema) as staging_client:
if hasattr(staging_client, "drop_tables"):
staging_client.drop_tables(*drop_tables, replace_schema=True)
else:
trunc_staging = drop_tables
self._refreshed_tables.update(drop_tables) # Don't drop table again in same load
return trunc_dest, trunc_staging
# def _refresh(self, dropped_tables: Sequence[str], schema: Schema) -> Tuple[Set[str], Set[str]]:
# """When using refresh mode, drop tables if possible.
# Returns a set of tables for main destination and staging destination
# that could not be dropped and should be truncated instead
# """
# # Exclude tables already dropped in the same load
# drop_tables = set(dropped_tables) - self._refreshed_tables
# if not drop_tables:
# return set(), set()
# # Clone schema and remove tables from it
# dropped_schema = deepcopy(schema)
# for table_name in drop_tables:
# # pop not del: The table may not actually be in the schema if it's not being loaded again
# dropped_schema.tables.pop(table_name, None)
# dropped_schema._bump_version()
# trunc_dest: Set[str] = set()
# trunc_staging: Set[str] = set()
# # Drop from destination and replace stored schema so tables will be re-created before load
# with self.get_destination_client(dropped_schema) as job_client:
# # TODO: SupportsSql mixin
# if hasattr(job_client, "drop_tables"):
# job_client.drop_tables(*drop_tables, replace_schema=True)
# else:
# # Tables need to be truncated instead of dropped
# trunc_dest = drop_tables

# if self.staging_destination:
# with self.get_staging_destination_client(dropped_schema) as staging_client:
# if hasattr(staging_client, "drop_tables"):
# staging_client.drop_tables(*drop_tables, replace_schema=True)
# else:
# trunc_staging = drop_tables
# self._refreshed_tables.update(drop_tables) # Don't drop table again in same load
# return trunc_dest, trunc_staging

def load_single_package(self, load_id: str, schema: Schema) -> None:
new_jobs = self.get_new_jobs_info(load_id)

dropped_tables = current_load_package()["state"].get("dropped_tables", [])
# dropped_tables = current_load_package()["state"].get("dropped_tables", [])
# Drop tables before loading if refresh mode is set
truncate_dest, truncate_staging = self._refresh(dropped_tables, schema)
# truncate_dest, truncate_staging = self._refresh(dropped_tables, schema)
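# prefer the drop schema recorded by the extract step (if any) when initializing the destination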
drop_schema_dict = current_load_package()["state"].get("drop_schema")
drop_schema = Schema.from_dict(drop_schema_dict) if drop_schema_dict else None
init_schema = drop_schema if drop_schema else schema

# initialize analytical storage, i.e. create the dataset required by the passed schema
with self.get_destination_client(schema) as job_client:
with self.get_destination_client(init_schema) as job_client:
if (expected_update := self.load_storage.begin_schema_update(load_id)) is not None:
# init job client
def should_truncate(table: TTableSchema) -> bool:
# When destination doesn't support dropping refreshed tables (i.e. not SQL based) they should be truncated
return (
job_client.should_truncate_table_before_load(table)
or table["name"] in truncate_dest
)
# def should_truncate(table: TTableSchema) -> bool:
# # When destination doesn't support dropping refreshed tables (i.e. not SQL based) they should be truncated
# return (
# job_client.should_truncate_table_before_load(table)
# or table["name"] in truncate_dest
# )

applied_update = init_client(
job_client,
schema,
init_schema,
new_jobs,
expected_update,
# job_client.should_truncate_table_before_load,
should_truncate,
job_client.should_truncate_table_before_load,
# should_truncate,
(
job_client.should_load_data_to_staging_dataset
if isinstance(job_client, WithStagingDataset)
else None
),
refresh=self.refresh,
)

# init staging client
@@ -418,23 +423,24 @@ def should_truncate(table: TTableSchema) -> bool:
" implement SupportsStagingDestination"
)

def should_truncate_staging(table: TTableSchema) -> bool:
return (
job_client.should_truncate_table_before_load_on_staging_destination(
table
)
or table["name"] in truncate_staging
)
# def should_truncate_staging(table: TTableSchema) -> bool:
# return (
# job_client.should_truncate_table_before_load_on_staging_destination(
# table
# )
# or table["name"] in truncate_staging
# )

with self.get_staging_destination_client(schema) as staging_client:
with self.get_staging_destination_client(init_schema) as staging_client:
init_client(
staging_client,
schema,
init_schema,
new_jobs,
expected_update,
# job_client.should_truncate_table_before_load_on_staging_destination,
should_truncate_staging,
job_client.should_truncate_table_before_load_on_staging_destination,
# should_truncate_staging,
job_client.should_load_data_to_staging_dataset_on_staging_destination,
refresh=self.refresh,
)

self.load_storage.commit_schema_update(load_id, applied_update)
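The net effect of the load.py changes is that the destination is initialized with the recorded drop schema when one exists. A rough, self-contained sketch of that selection, with a plain dict standing in for `current_load_package()["state"]`:

```python
from dlt.common.schema import Schema

def pick_init_schema(package_state: dict, schema: Schema) -> Schema:
    # Mirrors load_single_package(): use the drop schema if the extract step stored one,
    # otherwise fall back to the full load schema.
    drop_schema_dict = package_state.get("drop_schema")
    return Schema.from_dict(drop_schema_dict) if drop_schema_dict else schema

# With no drop schema recorded, the original schema is used as-is.
schema = Schema("example")
assert pick_init_schema({}, schema) is schema
```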
26 changes: 24 additions & 2 deletions dlt/load/utils.py
@@ -1,4 +1,4 @@
from typing import List, Set, Iterable, Callable
from typing import List, Set, Iterable, Callable, Optional

from dlt.common import logger
from dlt.common.storages.load_package import LoadJobInfo, PackageStorage
@@ -15,6 +15,7 @@
JobClientBase,
WithStagingDataset,
)
from dlt.common.pipeline import TRefreshMode


def get_completed_table_chain(
@@ -66,6 +67,7 @@ def init_client(
expected_update: TSchemaTables,
truncate_filter: Callable[[TTableSchema], bool],
load_staging_filter: Callable[[TTableSchema], bool],
refresh: Optional[TRefreshMode] = None,
) -> TSchemaTables:
"""Initializes destination storage including staging dataset if supported
@@ -84,6 +86,8 @@
"""
# get dlt/internal tables
dlt_tables = set(schema.dlt_table_names())

all_tables = set(schema.tables.keys())
# tables without data (TODO: normalizer removes such jobs, write tests and remove the line below)
tables_no_data = set(
table["name"] for table in schema.data_tables() if not has_table_seen_data(table)
@@ -92,12 +96,23 @@
tables_with_jobs = set(job.table_name for job in new_jobs) - tables_no_data

# get tables to truncate by extending tables with jobs with all their child tables
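# drop_data: truncate every table with jobs; drop_dataset/drop_tables: collect non-dlt tables to drop instead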
if refresh == "drop_data":
truncate_filter = lambda t: True
truncate_tables = set(
_extend_tables_with_table_chain(schema, tables_with_jobs, tables_with_jobs, truncate_filter)
)

if refresh in ("drop_dataset", "drop_tables"):
drop_tables = all_tables - dlt_tables - tables_no_data
else:
drop_tables = set()

applied_update = _init_dataset_and_update_schema(
job_client, expected_update, tables_with_jobs | dlt_tables, truncate_tables
job_client,
expected_update,
tables_with_jobs | dlt_tables,
truncate_tables,
drop_tables=drop_tables,
)

# update the staging dataset if client supports this
@@ -128,6 +143,7 @@ def _init_dataset_and_update_schema(
update_tables: Iterable[str],
truncate_tables: Iterable[str] = None,
staging_info: bool = False,
drop_tables: Optional[Iterable[str]] = None,
) -> TSchemaTables:
staging_text = "for staging dataset" if staging_info else ""
logger.info(
@@ -146,6 +162,12 @@
f"Client for {job_client.config.destination_type} will truncate tables {staging_text}"
)
job_client.initialize_storage(truncate_tables=truncate_tables)
if drop_tables:
if hasattr(job_client, "drop_tables"):
logger.info(
f"Client for {job_client.config.destination_type} will drop tables {staging_text}"
)
job_client.drop_tables(*drop_tables)
return applied_update


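To make the new branch in `init_client` concrete, here is a toy rerun of the set arithmetic with invented table names; in the real function these sets are derived from the schema and the job list, and the collected tables are only actually dropped when the job client implements `drop_tables`.

```python
# Invented names; in init_client() these sets come from the schema and new_jobs.
all_tables = {"users", "orders", "_dlt_loads", "_dlt_version"}
dlt_tables = {"_dlt_loads", "_dlt_version"}
tables_no_data = {"orders"}  # defined in the schema but never loaded

refresh = "drop_tables"
if refresh in ("drop_dataset", "drop_tables"):
    # every data table that has seen data gets dropped and re-created
    drop_tables = all_tables - dlt_tables - tables_no_data
else:
    drop_tables = set()

assert drop_tables == {"users"}
# with refresh == "drop_data" the truncate filter is forced to `lambda t: True`,
# so every table with jobs is truncated instead of dropped
```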
16 changes: 16 additions & 0 deletions dlt/pipeline/drop.py
@@ -69,6 +69,22 @@ def drop_resources(
drop_all: bool = False,
state_only: bool = False,
) -> Tuple[Schema, TPipelineState, _DropInfo]:
"""Generate a new schema and pipeline state with the requested resources removed.
Args:
schema: The schema to modify.
state: The pipeline state to modify.
resources: Resource name(s) or regex pattern(s) matching resource names to drop.
If empty, no resources will be dropped unless `drop_all` is True.
state_paths: JSON path(s) relative to the source state to drop.
drop_all: If True, all resources will be dropped (supersedes `resources`).
state_only: If True, only modify the pipeline state, not the schema.
Returns:
A 3-part tuple containing the new schema, the new pipeline state, and a dictionary
containing information about the drop operation.
"""

if isinstance(resources, str):
resources = [resources]
resources = list(resources)
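The docstring above allows regex patterns for `resources`. A small hypothetical helper showing how such patterns could select resource names; the actual matching logic lives elsewhere in drop.py and is not part of this diff.

```python
import re
from typing import Sequence, Union

def select_resources(resource_names: Sequence[str], patterns: Union[str, Sequence[str]]) -> set:
    # Hypothetical helper mirroring the docstring: accept a single name/pattern or a list,
    # and treat each entry as a regex matched against resource names.
    if isinstance(patterns, str):
        patterns = [patterns]
    compiled = [re.compile(p) for p in patterns]
    return {name for name in resource_names if any(c.match(name) for c in compiled)}

# e.g. drop every resource whose name starts with "raw_"
assert select_resources(["raw_users", "raw_orders", "clean_users"], "raw_.*") == {"raw_users", "raw_orders"}
```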
2 changes: 2 additions & 0 deletions dlt/pipeline/pipeline.py
@@ -540,6 +540,7 @@ def load(
config=load_config,
initial_client_config=client.config,
initial_staging_client_config=staging_client.config if staging_client else None,
refresh=self.refresh if not self.first_run else None,
)
try:
with signals.delayed_signals():
@@ -548,6 +549,7 @@
self.first_run = False
return info
except Exception as l_ex:
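# NOTE: this bare re-raise leaves the PipelineStepFailed wrapping below unreachable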
raise
step_info = self._get_step_info(load_step)
raise PipelineStepFailed(
self, "load", load_step.current_load_id, l_ex, step_info
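For context, a hedged usage sketch of the refresh modes this change threads through to the load step. Whether `refresh` is accepted by `dlt.pipeline()` as assumed here, or exposed elsewhere on the public API, is not shown in this diff; only the mode names are taken from the commit.

```python
import dlt

# Assumption: `refresh` is exposed on dlt.pipeline(); the mode names
# ("drop_dataset", "drop_tables", "drop_data") are the ones used in this commit.
pipeline = dlt.pipeline(
    pipeline_name="refresh_demo",
    destination="duckdb",
    refresh="drop_tables",
)
# On runs after the first one, the Load step receives the refresh mode and
# drops or truncates the affected destination tables before loading.
pipeline.run([{"id": 1, "name": "alice"}], table_name="users")
```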
