diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py
index 28f3ddb598..6df073f97c 100644
--- a/dlt/common/libs/pyarrow.py
+++ b/dlt/common/libs/pyarrow.py
@@ -226,7 +226,8 @@ def should_normalize_arrow_schema(
     schema: pyarrow.Schema,
     columns: TTableSchemaColumns,
     naming: NamingConvention,
-) -> Tuple[bool, Mapping[str, str], Dict[str, str], Dict[str, bool], TTableSchemaColumns]:
+    add_load_id: bool = False,
+) -> Tuple[bool, Mapping[str, str], Dict[str, str], Dict[str, bool], bool, TTableSchemaColumns]:
     rename_mapping = get_normalized_arrow_fields_mapping(schema, naming)
     rev_mapping = {v: k for k, v in rename_mapping.items()}
     nullable_mapping = {k: v.get("nullable", True) for k, v in columns.items()}
@@ -240,6 +241,16 @@ def should_normalize_arrow_schema(
 
     dlt_tables = list(map(naming.normalize_table_identifier, ("_dlt_id", "_dlt_load_id")))
 
+    # Do we need to add a load id column?
+    if add_load_id and "_dlt_load_id" in columns:
+        try:
+            schema.field("_dlt_load_id")
+            needs_load_id = False
+        except KeyError:
+            needs_load_id = True
+    else:
+        needs_load_id = False
+
     # remove all columns that are dlt columns but are not present in arrow schema. we do not want to add such columns
     # that should happen in the normalizer
     columns = {
@@ -250,9 +261,18 @@
 
     # check if nothing to rename
     skip_normalize = (
-        list(rename_mapping.keys()) == list(rename_mapping.values()) == list(columns.keys())
-    ) and not nullable_updates
-    return not skip_normalize, rename_mapping, rev_mapping, nullable_updates, columns
+        (list(rename_mapping.keys()) == list(rename_mapping.values()) == list(columns.keys()))
+        and not nullable_updates
+        and not needs_load_id
+    )
+    return (
+        not skip_normalize,
+        rename_mapping,
+        rev_mapping,
+        nullable_updates,
+        needs_load_id,
+        columns,
+    )
 
 
 def normalize_py_arrow_item(
@@ -260,6 +280,7 @@
     columns: TTableSchemaColumns,
     naming: NamingConvention,
     caps: DestinationCapabilitiesContext,
+    load_id: Optional[str] = None,
 ) -> TAnyArrowItem:
     """Normalize arrow `item` schema according to the `columns`.
 
@@ -267,10 +288,11 @@
     2. arrows columns will be reordered according to `columns`
     3. empty columns will be inserted if they are missing, types will be generated using `caps`
     4. arrow columns with different nullability than corresponding schema columns will be updated
+    5. Add `_dlt_load_id` column if it is missing and `load_id` is provided
     """
     schema = item.schema
-    should_normalize, rename_mapping, rev_mapping, nullable_updates, columns = (
-        should_normalize_arrow_schema(schema, columns, naming)
+    should_normalize, rename_mapping, rev_mapping, nullable_updates, needs_load_id, columns = (
+        should_normalize_arrow_schema(schema, columns, naming, load_id is not None)
     )
     if not should_normalize:
         return item
@@ -307,6 +329,12 @@ def normalize_py_arrow_item(
             new_fields.append(schema.field(idx).with_name(column_name))
             new_columns.append(item.column(idx))
 
+    if needs_load_id and load_id:
+        # Storage efficient type for a column with constant value
+        load_id_type = pyarrow.dictionary(pyarrow.int8(), pyarrow.string())
+        new_fields.append(pyarrow.field("_dlt_load_id", load_id_type, nullable=False))
+        new_columns.append(pyarrow.array([load_id] * item.num_rows, type=load_id_type))
+
     # create desired type
     return item.__class__.from_arrays(new_columns, schema=pyarrow.schema(new_fields))
 
@@ -383,6 +411,30 @@ def from_arrow_scalar(arrow_value: pyarrow.Scalar) -> Any:
 """Sequence of tuples: (field index, field, generating function)"""
 
 
+def add_constant_column(
+    item: TAnyArrowItem,
+    name: str,
+    data_type: pyarrow.DataType,
+    value: Any = None,
+    nullable: bool = True,
+    index: int = -1,
+) -> TAnyArrowItem:
+    """Add column with a single value to the table.
+
+    Args:
+        item: Arrow table or record batch
+        name: The new column name
+        data_type: The data type of the new column
+        nullable: Whether the new column is nullable
+        value: The value to fill the new column with
+        index: The index at which to insert the new column. Defaults to -1 (append)
+    """
+    field = pyarrow.field(name, pyarrow.dictionary(pyarrow.int8(), data_type), nullable=nullable)
+    if index == -1:
+        return item.append_column(field, pyarrow.array([value] * item.num_rows, type=field.type))
+    return item.add_column(index, field, pyarrow.array([value] * item.num_rows, type=field.type))
+
+
 def pq_stream_with_new_columns(
     parquet_file: TFileOrPath, columns: TNewColumns, row_groups_per_read: int = 1
 ) -> Iterator[pyarrow.Table]:
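
Note on the storage trick used above: since `_dlt_load_id` has a single value per load, both the new branch in `normalize_py_arrow_item` and the `add_constant_column` helper dictionary-encode the column (one int8 index per row plus a single dictionary entry) instead of materializing the string for every row. A minimal standalone sketch of that technique with plain pyarrow; the table contents and load id value here are made up:

```python
import pyarrow as pa

table = pa.table({"id": [1, 2, 3]})
load_id = "1721234567.123456"  # placeholder; real load ids come from the load package

# dictionary encoding stores the constant string once plus an int8 index per row
load_id_type = pa.dictionary(pa.int8(), pa.string())
load_id_array = pa.array([load_id] * table.num_rows, type=load_id_type)
table = table.append_column(
    pa.field("_dlt_load_id", load_id_type, nullable=False), load_id_array
)

print(table.schema)  # last field is a dictionary<values=string, indices=int8> column
```
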
diff --git a/dlt/extract/extractors.py b/dlt/extract/extractors.py
index 8f95211aa0..e3ed223837 100644
--- a/dlt/extract/extractors.py
+++ b/dlt/extract/extractors.py
@@ -1,6 +1,7 @@
 from copy import copy
 from typing import Set, Dict, Any, Optional, List
 
+from dlt.common.configuration import known_sections
 from dlt.common import logger
 from dlt.common.configuration.inject import with_config
 from dlt.common.configuration.specs import BaseConfiguration, configspec
@@ -21,6 +22,7 @@
 from dlt.extract.resource import DltResource
 from dlt.extract.items import TableNameMeta
 from dlt.extract.storage import ExtractorItemStorage
+from dlt.normalize.configuration import ItemsNormalizerConfiguration
 
 try:
     from dlt.common.libs import pyarrow
@@ -215,13 +217,26 @@ class ObjectExtractor(Extractor):
 class ArrowExtractor(Extractor):
     """Extracts arrow data items into parquet. Normalizes arrow items column names.
     Compares the arrow schema to actual dlt table schema to reorder the columns and to
-    insert missing columns (without data).
+    insert missing columns (without data). Adds _dlt_load_id column to the table if
+    `add_dlt_load_id` is set to True in normalizer config.
     We do things that normalizer should do here so we do not need to load and save parquet
     files again later.
+
+    Handles the following types:
+    - `pyarrow.Table`
+    - `pyarrow.RecordBatch`
+    - `pandas.DataFrame` (is converted to arrow `Table` before processing)
     """
 
+    # Inject the parts of normalize configuration that are used here
+    @with_config(
+        spec=ItemsNormalizerConfiguration, sections=(known_sections.NORMALIZE, "parquet_normalizer")
+    )
+    def __init__(self, *args: Any, add_dlt_load_id: bool = False, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self.add_dlt_load_id = add_dlt_load_id
+
     def write_items(self, resource: DltResource, items: TDataItems, meta: Any) -> None:
         static_table_name = self._get_static_table_name(resource, meta)
         items = [
@@ -294,7 +309,13 @@ def _write_item(
         columns = columns or self.schema.get_table_columns(table_name)
         # Note: `items` is always a list here due to the conversion in `write_table`
         items = [
-            pyarrow.normalize_py_arrow_item(item, columns, self.naming, self._caps)
+            pyarrow.normalize_py_arrow_item(
+                item,
+                columns,
+                self.naming,
+                self._caps,
+                load_id=self.load_id if self.add_dlt_load_id else None,
+            )
             for item in items
         ]
         # write items one by one
@@ -318,6 +339,13 @@ def _compute_table(
             arrow_table["columns"] = pyarrow.py_arrow_to_table_schema_columns(item.schema)
             # normalize arrow table before merging
             arrow_table = self.schema.normalize_table_identifiers(arrow_table)
+            # Add load_id column
+            if self.add_dlt_load_id and "_dlt_load_id" not in arrow_table["columns"]:
+                arrow_table["columns"]["_dlt_load_id"] = {
+                    "name": "_dlt_load_id",
+                    "data_type": "text",
+                    "nullable": False,
+                }
             # issue warnings when overriding computed with arrow
             override_warn: bool = False
             for col_name, column in arrow_table["columns"].items():
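
With the config flag on, `ArrowExtractor._write_item` forwards the current load id to `normalize_py_arrow_item`, which appends the column while the data is still in memory during extract. A rough standalone sketch of that call; the example table and column schema are invented, and using `Schema("example").naming` and `DestinationCapabilitiesContext.generic_capabilities()` to obtain a naming convention and a capabilities object is an assumption for illustration, not part of this change:

```python
import pyarrow as pa

from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.libs.pyarrow import normalize_py_arrow_item
from dlt.common.schema import Schema

item = pa.table({"id": [1, 2, 3]})
# simplified table schema: `_dlt_load_id` is declared but missing from the arrow data
columns = {
    "id": {"name": "id", "data_type": "bigint", "nullable": True},
    "_dlt_load_id": {"name": "_dlt_load_id", "data_type": "text", "nullable": False},
}

normalized = normalize_py_arrow_item(
    item,
    columns,
    Schema("example").naming,  # assumed way to get a naming convention instance
    DestinationCapabilitiesContext.generic_capabilities(),  # assumed generic capabilities
    load_id="1721234567.123456",
)
print(normalized.column_names)  # `_dlt_load_id` should now appear alongside `id`
```
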
diff --git a/dlt/normalize/items_normalizers.py b/dlt/normalize/items_normalizers.py
index eed98d7563..6678f6edee 100644
--- a/dlt/normalize/items_normalizers.py
+++ b/dlt/normalize/items_normalizers.py
@@ -228,37 +228,13 @@ class ArrowItemsNormalizer(ItemsNormalizer):
     REWRITE_ROW_GROUPS = 1
 
     def _write_with_dlt_columns(
-        self, extracted_items_file: str, root_table_name: str, add_load_id: bool, add_dlt_id: bool
+        self, extracted_items_file: str, root_table_name: str, add_dlt_id: bool
     ) -> List[TSchemaUpdate]:
         new_columns: List[Any] = []
         schema = self.schema
         load_id = self.load_id
         schema_update: TSchemaUpdate = {}
 
-        if add_load_id:
-            table_update = schema.update_table(
-                {
-                    "name": root_table_name,
-                    "columns": {
-                        "_dlt_load_id": {
-                            "name": "_dlt_load_id",
-                            "data_type": "text",
-                            "nullable": False,
-                        }
-                    },
-                }
-            )
-            table_updates = schema_update.setdefault(root_table_name, [])
-            table_updates.append(table_update)
-            load_id_type = pa.dictionary(pa.int8(), pa.string())
-            new_columns.append(
-                (
-                    -1,
-                    pa.field("_dlt_load_id", load_id_type, nullable=False),
-                    lambda batch: pa.array([load_id] * batch.num_rows, type=load_id_type),
-                )
-            )
-
         if add_dlt_id:
             table_update = schema.update_table(
                 {
@@ -292,9 +268,9 @@ def _write_with_dlt_columns(
                 items_count += batch.num_rows
                 # we may need to normalize
                 if is_native_arrow_writer and should_normalize is None:
-                    should_normalize, _, _, _, _ = pyarrow.should_normalize_arrow_schema(
+                    should_normalize = pyarrow.should_normalize_arrow_schema(
                         batch.schema, columns_schema, schema.naming
-                    )
+                    )[0]
                     if should_normalize:
                         logger.info(
                             f"When writing arrow table to {root_table_name} the schema requires"
@@ -366,25 +342,22 @@ def __call__(self, extracted_items_file: str, root_table_name: str) -> List[TSch
         base_schema_update = self._fix_schema_precisions(root_table_name, arrow_schema)
 
         add_dlt_id = self.config.parquet_normalizer.add_dlt_id
-        add_dlt_load_id = self.config.parquet_normalizer.add_dlt_load_id
         # if we need to add any columns or the file format is not parquet, we can't just import files
-        must_rewrite = (
-            add_dlt_id or add_dlt_load_id or self.item_storage.writer_spec.file_format != "parquet"
-        )
+        must_rewrite = add_dlt_id or self.item_storage.writer_spec.file_format != "parquet"
         if not must_rewrite:
             # in rare cases normalization may be needed
-            must_rewrite, _, _, _, _ = pyarrow.should_normalize_arrow_schema(
+            must_rewrite = pyarrow.should_normalize_arrow_schema(
                 arrow_schema, self.schema.get_table_columns(root_table_name), self.schema.naming
-            )
+            )[0]
         if must_rewrite:
             logger.info(
                 f"Table {root_table_name} parquet file {extracted_items_file} must be rewritten:"
-                f" add_dlt_id: {add_dlt_id} add_dlt_load_id: {add_dlt_load_id} destination file"
+                f" add_dlt_id: {add_dlt_id} destination file"
                 f" format: {self.item_storage.writer_spec.file_format} or due to required"
                 " normalization "
             )
             schema_update = self._write_with_dlt_columns(
                 extracted_items_file, root_table_name, add_dlt_id
            )
             return base_schema_update + schema_update
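
`should_normalize_arrow_schema` now returns six values (the new `needs_load_id` flag sits before the filtered columns), which is why the hunks above index `[0]` when only the boolean answer is needed. A small hedged sketch of both call styles; the arrow schema and table schema are invented, and `Schema("example").naming` is again only an assumed way to obtain a naming convention:

```python
import pyarrow as pa

from dlt.common.libs.pyarrow import should_normalize_arrow_schema
from dlt.common.schema import Schema

naming = Schema("example").naming  # assumed source of a naming convention
arrow_schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.string())])
columns = {
    "id": {"name": "id", "data_type": "bigint", "nullable": True},
    "value": {"name": "value", "data_type": "text", "nullable": True},
    "_dlt_load_id": {"name": "_dlt_load_id", "data_type": "text", "nullable": False},
}

# full unpack: `needs_load_id` reports whether `_dlt_load_id` must still be added
should_normalize, rename, rev, nullable_updates, needs_load_id, filtered_columns = (
    should_normalize_arrow_schema(arrow_schema, columns, naming, add_load_id=True)
)

# callers that only care about the flag can index the tuple, as the normalizer does
must_rewrite = should_normalize_arrow_schema(arrow_schema, columns, naming)[0]
```
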
diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md b/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md
index 426c090f94..f9ceb99a90 100644
--- a/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md
+++ b/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md
@@ -84,7 +84,10 @@ add_dlt_load_id = true
 add_dlt_id = true
 ```
 
-Keep in mind that enabling these incurs some performance overhead because the `parquet` file needs to be read back from disk in chunks, processed and rewritten with new columns.
+Keep in mind that enabling these incurs some performance overhead:
+
+- `add_dlt_load_id` has minimal overhead since the column is added to arrow table in memory during `extract` stage, before parquet file is written to disk
+- `add_dlt_id` adds the column during `normalize` stage after file has been extracted to disk. The file needs to be read back from disk in chunks, processed and rewritten with new columns
 
 ## Incremental loading with Arrow tables
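
Taken together, `_dlt_load_id` is now added during extract and is driven purely by configuration. A hedged end-to-end sketch, assuming the `duckdb` destination is installed and that the environment variable below follows dlt's usual `SECTION__SUBSECTION__KEY` spelling of the `[normalize.parquet_normalizer]` section shown in the docs change; pipeline, dataset and table names are made up:

```python
import os

import dlt
import pyarrow as pa

# equivalent of `[normalize.parquet_normalizer] add_dlt_load_id = true` in config.toml
os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_LOAD_ID"] = "true"

pipeline = dlt.pipeline(
    pipeline_name="arrow_load_id_demo", destination="duckdb", dataset_name="demo"
)
info = pipeline.run(pa.table({"id": [1, 2, 3]}), table_name="items")
print(info)

# `_dlt_load_id` should now be part of the table schema even though the input had no such column
print(pipeline.default_schema.get_table_columns("items").keys())
```
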