diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index 28f3ddb598..8a6dc68078 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -226,7 +226,8 @@ def should_normalize_arrow_schema( schema: pyarrow.Schema, columns: TTableSchemaColumns, naming: NamingConvention, -) -> Tuple[bool, Mapping[str, str], Dict[str, str], Dict[str, bool], TTableSchemaColumns]: + add_load_id: bool = False, +) -> Tuple[bool, Mapping[str, str], Dict[str, str], Dict[str, bool], bool, TTableSchemaColumns]: rename_mapping = get_normalized_arrow_fields_mapping(schema, naming) rev_mapping = {v: k for k, v in rename_mapping.items()} nullable_mapping = {k: v.get("nullable", True) for k, v in columns.items()} @@ -238,21 +239,42 @@ def should_normalize_arrow_schema( if norm_name in nullable_mapping and field.nullable != nullable_mapping[norm_name]: nullable_updates[norm_name] = nullable_mapping[norm_name] - dlt_tables = list(map(naming.normalize_table_identifier, ("_dlt_id", "_dlt_load_id"))) + dlt_load_id_col = naming.normalize_table_identifier("_dlt_load_id") + dlt_id_col = naming.normalize_table_identifier("_dlt_id") + dlt_columns = {dlt_load_id_col, dlt_id_col} + + # Do we need to add a load id column? + if add_load_id and dlt_load_id_col in columns: + try: + schema.field(dlt_load_id_col) + needs_load_id = False + except KeyError: + needs_load_id = True + else: + needs_load_id = False # remove all columns that are dlt columns but are not present in arrow schema. we do not want to add such columns # that should happen in the normalizer columns = { name: column for name, column in columns.items() - if name not in dlt_tables or name in rev_mapping + if name not in dlt_columns or name in rev_mapping } # check if nothing to rename skip_normalize = ( - list(rename_mapping.keys()) == list(rename_mapping.values()) == list(columns.keys()) - ) and not nullable_updates - return not skip_normalize, rename_mapping, rev_mapping, nullable_updates, columns + (list(rename_mapping.keys()) == list(rename_mapping.values()) == list(columns.keys())) + and not nullable_updates + and not needs_load_id + ) + return ( + not skip_normalize, + rename_mapping, + rev_mapping, + nullable_updates, + needs_load_id, + columns, + ) def normalize_py_arrow_item( @@ -260,6 +282,7 @@ def normalize_py_arrow_item( columns: TTableSchemaColumns, naming: NamingConvention, caps: DestinationCapabilitiesContext, + load_id: Optional[str] = None, ) -> TAnyArrowItem: """Normalize arrow `item` schema according to the `columns`. @@ -267,10 +290,11 @@ def normalize_py_arrow_item( 2. arrows columns will be reordered according to `columns` 3. empty columns will be inserted if they are missing, types will be generated using `caps` 4. arrow columns with different nullability than corresponding schema columns will be updated + 5. 
Add `_dlt_load_id` column if it is missing and `load_id` is provided """ schema = item.schema - should_normalize, rename_mapping, rev_mapping, nullable_updates, columns = ( - should_normalize_arrow_schema(schema, columns, naming) + should_normalize, rename_mapping, rev_mapping, nullable_updates, needs_load_id, columns = ( + should_normalize_arrow_schema(schema, columns, naming, load_id is not None) ) if not should_normalize: return item @@ -307,6 +331,18 @@ def normalize_py_arrow_item( new_fields.append(schema.field(idx).with_name(column_name)) new_columns.append(item.column(idx)) + if needs_load_id and load_id: + # Storage efficient type for a column with constant value + load_id_type = pyarrow.dictionary(pyarrow.int8(), pyarrow.string()) + new_fields.append( + pyarrow.field( + naming.normalize_table_identifier("_dlt_load_id"), + load_id_type, + nullable=False, + ) + ) + new_columns.append(pyarrow.array([load_id] * item.num_rows, type=load_id_type)) + # create desired type return item.__class__.from_arrays(new_columns, schema=pyarrow.schema(new_fields)) @@ -383,6 +419,30 @@ def from_arrow_scalar(arrow_value: pyarrow.Scalar) -> Any: """Sequence of tuples: (field index, field, generating function)""" +def add_constant_column( + item: TAnyArrowItem, + name: str, + data_type: pyarrow.DataType, + value: Any = None, + nullable: bool = True, + index: int = -1, +) -> TAnyArrowItem: + """Add column with a single value to the table. + + Args: + item: Arrow table or record batch + name: The new column name + data_type: The data type of the new column + nullable: Whether the new column is nullable + value: The value to fill the new column with + index: The index at which to insert the new column. Defaults to -1 (append) + """ + field = pyarrow.field(name, pyarrow.dictionary(pyarrow.int8(), data_type), nullable=nullable) + if index == -1: + return item.append_column(field, pyarrow.array([value] * item.num_rows, type=field.type)) + return item.add_column(index, field, pyarrow.array([value] * item.num_rows, type=field.type)) + + def pq_stream_with_new_columns( parquet_file: TFileOrPath, columns: TNewColumns, row_groups_per_read: int = 1 ) -> Iterator[pyarrow.Table]: diff --git a/dlt/extract/extractors.py b/dlt/extract/extractors.py index 8f95211aa0..48f0d6968e 100644 --- a/dlt/extract/extractors.py +++ b/dlt/extract/extractors.py @@ -1,8 +1,8 @@ from copy import copy from typing import Set, Dict, Any, Optional, List +from dlt.common.configuration import known_sections, resolve_configuration, with_config from dlt.common import logger -from dlt.common.configuration.inject import with_config from dlt.common.configuration.specs import BaseConfiguration, configspec from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.exceptions import MissingDependencyException @@ -21,6 +21,7 @@ from dlt.extract.resource import DltResource from dlt.extract.items import TableNameMeta from dlt.extract.storage import ExtractorItemStorage +from dlt.normalize.configuration import ItemsNormalizerConfiguration try: from dlt.common.libs import pyarrow @@ -215,13 +216,29 @@ class ObjectExtractor(Extractor): class ArrowExtractor(Extractor): """Extracts arrow data items into parquet. Normalizes arrow items column names. Compares the arrow schema to actual dlt table schema to reorder the columns and to - insert missing columns (without data). + insert missing columns (without data). Adds _dlt_load_id column to the table if + `add_dlt_load_id` is set to True in normalizer config. 
We do things that normalizer should do here so we do not need to load and save parquet files again later. + Handles the following types: + - `pyarrow.Table` + - `pyarrow.RecordBatch` + - `pandas.DataFrame` (is converted to arrow `Table` before processing) """ + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._normalize_config = self._retrieve_normalize_config() + + def _retrieve_normalize_config(self) -> ItemsNormalizerConfiguration: + """Get normalizer settings that are used here""" + return resolve_configuration( + ItemsNormalizerConfiguration(), + sections=(known_sections.NORMALIZE, "parquet_normalizer"), + ) + def write_items(self, resource: DltResource, items: TDataItems, meta: Any) -> None: static_table_name = self._get_static_table_name(resource, meta) items = [ @@ -294,7 +311,13 @@ def _write_item( columns = columns or self.schema.get_table_columns(table_name) # Note: `items` is always a list here due to the conversion in `write_table` items = [ - pyarrow.normalize_py_arrow_item(item, columns, self.naming, self._caps) + pyarrow.normalize_py_arrow_item( + item, + columns, + self.naming, + self._caps, + load_id=self.load_id if self._normalize_config.add_dlt_load_id else None, + ) for item in items ] # write items one by one @@ -316,8 +339,22 @@ def _compute_table( else: arrow_table = copy(computed_table) arrow_table["columns"] = pyarrow.py_arrow_to_table_schema_columns(item.schema) + + # Add load_id column if needed + dlt_load_id_col = self.naming.normalize_table_identifier("_dlt_load_id") + if ( + self._normalize_config.add_dlt_load_id + and dlt_load_id_col not in arrow_table["columns"] + ): + arrow_table["columns"][dlt_load_id_col] = { + "name": dlt_load_id_col, + "data_type": "text", + "nullable": False, + } + # normalize arrow table before merging arrow_table = self.schema.normalize_table_identifiers(arrow_table) + # issue warnings when overriding computed with arrow override_warn: bool = False for col_name, column in arrow_table["columns"].items(): @@ -343,6 +380,7 @@ def _compute_table( utils.merge_columns( arrow_table["columns"], computed_table["columns"], merge_columns=True ) + return arrow_table def _compute_and_update_table( diff --git a/dlt/normalize/items_normalizers.py b/dlt/normalize/items_normalizers.py index eed98d7563..6678f6edee 100644 --- a/dlt/normalize/items_normalizers.py +++ b/dlt/normalize/items_normalizers.py @@ -228,37 +228,13 @@ class ArrowItemsNormalizer(ItemsNormalizer): REWRITE_ROW_GROUPS = 1 def _write_with_dlt_columns( - self, extracted_items_file: str, root_table_name: str, add_load_id: bool, add_dlt_id: bool + self, extracted_items_file: str, root_table_name: str, add_dlt_id: bool ) -> List[TSchemaUpdate]: new_columns: List[Any] = [] schema = self.schema load_id = self.load_id schema_update: TSchemaUpdate = {} - if add_load_id: - table_update = schema.update_table( - { - "name": root_table_name, - "columns": { - "_dlt_load_id": { - "name": "_dlt_load_id", - "data_type": "text", - "nullable": False, - } - }, - } - ) - table_updates = schema_update.setdefault(root_table_name, []) - table_updates.append(table_update) - load_id_type = pa.dictionary(pa.int8(), pa.string()) - new_columns.append( - ( - -1, - pa.field("_dlt_load_id", load_id_type, nullable=False), - lambda batch: pa.array([load_id] * batch.num_rows, type=load_id_type), - ) - ) - if add_dlt_id: table_update = schema.update_table( { @@ -292,9 +268,9 @@ def _write_with_dlt_columns( items_count += batch.num_rows # we may need to normalize if 
is_native_arrow_writer and should_normalize is None: - should_normalize, _, _, _, _ = pyarrow.should_normalize_arrow_schema( + should_normalize = pyarrow.should_normalize_arrow_schema( batch.schema, columns_schema, schema.naming - ) + )[0] if should_normalize: logger.info( f"When writing arrow table to {root_table_name} the schema requires" @@ -366,25 +342,22 @@ def __call__(self, extracted_items_file: str, root_table_name: str) -> List[TSch base_schema_update = self._fix_schema_precisions(root_table_name, arrow_schema) add_dlt_id = self.config.parquet_normalizer.add_dlt_id - add_dlt_load_id = self.config.parquet_normalizer.add_dlt_load_id # if we need to add any columns or the file format is not parquet, we can't just import files - must_rewrite = ( - add_dlt_id or add_dlt_load_id or self.item_storage.writer_spec.file_format != "parquet" - ) + must_rewrite = add_dlt_id or self.item_storage.writer_spec.file_format != "parquet" if not must_rewrite: # in rare cases normalization may be needed - must_rewrite, _, _, _, _ = pyarrow.should_normalize_arrow_schema( + must_rewrite = pyarrow.should_normalize_arrow_schema( arrow_schema, self.schema.get_table_columns(root_table_name), self.schema.naming - ) + )[0] if must_rewrite: logger.info( f"Table {root_table_name} parquet file {extracted_items_file} must be rewritten:" - f" add_dlt_id: {add_dlt_id} add_dlt_load_id: {add_dlt_load_id} destination file" + f" add_dlt_id: {add_dlt_id} destination file" f" format: {self.item_storage.writer_spec.file_format} or due to required" " normalization " ) schema_update = self._write_with_dlt_columns( - extracted_items_file, root_table_name, add_dlt_load_id, add_dlt_id + extracted_items_file, root_table_name, add_dlt_id ) return base_schema_update + schema_update diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md b/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md index 426c090f94..f9ceb99a90 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md @@ -84,7 +84,10 @@ add_dlt_load_id = true add_dlt_id = true ``` -Keep in mind that enabling these incurs some performance overhead because the `parquet` file needs to be read back from disk in chunks, processed and rewritten with new columns. +Keep in mind that enabling these incurs some performance overhead: + +- `add_dlt_load_id` has minimal overhead since the column is added to arrow table in memory during `extract` stage, before parquet file is written to disk +- `add_dlt_id` adds the column during `normalize` stage after file has been extracted to disk. 
The file needs to be read back from disk in chunks, processed and rewritten with new columns ## Incremental loading with Arrow tables diff --git a/tests/common/schema/test_inference.py b/tests/common/schema/test_inference.py index da5c809827..0a40953f53 100644 --- a/tests/common/schema/test_inference.py +++ b/tests/common/schema/test_inference.py @@ -565,3 +565,24 @@ def test_infer_on_incomplete_column(schema: Schema) -> None: assert i_column["x-special"] == "spec" # type: ignore[typeddict-item] assert i_column["primary_key"] is True assert i_column["data_type"] == "text" + + +def test_update_table_adds_at_end(schema: Schema) -> None: + row = {"evm": Wei(1)} + _, new_table = schema.coerce_row("eth", None, row) + schema.update_table(new_table) + schema.update_table( + { + "name": new_table["name"], + "columns": { + "_dlt_load_id": { + "name": "_dlt_load_id", + "data_type": "text", + "nullable": False, + } + }, + } + ) + table = schema.tables["eth"] + # place new columns at the end + assert list(table["columns"].keys()) == ["evm", "_dlt_load_id"] diff --git a/tests/pipeline/test_arrow_sources.py b/tests/pipeline/test_arrow_sources.py index 667f26476b..0c03a8209d 100644 --- a/tests/pipeline/test_arrow_sources.py +++ b/tests/pipeline/test_arrow_sources.py @@ -505,3 +505,61 @@ def test_empty_arrow(item_type: TPythonTableFormat) -> None: assert len(pipeline.list_extracted_resources()) == 1 norm_info = pipeline.normalize() assert norm_info.row_counts["items"] == 0 + + +@pytest.mark.parametrize("item_type", ["pandas", "arrow-table", "arrow-batch"]) +def test_extract_adds_dlt_load_id(item_type: TPythonTableFormat) -> None: + os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_LOAD_ID"] = "True" + os.environ["DESTINATION__LOADER_FILE_FORMAT"] = "parquet" + + item, _, _ = arrow_table_all_data_types(item_type, num_rows=5432) + + @dlt.resource + def some_data(): + yield item + + pipeline: dlt.Pipeline = dlt.pipeline("arrow_" + uniq_id(), destination="duckdb") + info = pipeline.extract(some_data()) + + load_id = info.loads_ids[0] + jobs = info.load_packages[0].jobs["new_jobs"] + extracted_file = [job for job in jobs if "some_data" in job.file_path][0].file_path + + with pa.parquet.ParquetFile(extracted_file) as pq: + tbl = pq.read() + assert len(tbl) == 5432 + + # Extracted file has _dlt_load_id + assert pa.compute.all(pa.compute.equal(tbl["_dlt_load_id"], load_id)).as_py() + + # Load ID in both schema and arrow tbl should be the last column + assert tbl.schema.names[-1] == "_dlt_load_id" + cols = list(pipeline.default_schema.tables["some_data"]["columns"]) + assert cols[-1] == "_dlt_load_id" + + +def test_extract_json_normalize_parquet_adds_dlt_load_id(): + """Extract jsonl data that gets written to parquet in normalizer. Check that _dlt_load_id is added.""" + os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_LOAD_ID"] = "True" + + rows, _, _ = arrow_table_all_data_types("object", num_rows=1001) + + @dlt.resource + def some_data(): + yield rows + + pipeline: dlt.Pipeline = dlt.pipeline("arrow_" + uniq_id(), destination="duckdb") + + pipeline.extract(some_data()) + n_info = pipeline.normalize(loader_file_format="parquet") + + load_id = n_info.loads_ids[0] + jobs = n_info.load_packages[0].jobs["new_jobs"] + normalized_file = [job for job in jobs if "some_data" in job.file_path][0].file_path + + with pa.parquet.ParquetFile(normalized_file) as pq: + tbl = pq.read() + assert len(tbl) == 1001 + + # Normalized file has _dlt_load_id + assert pa.compute.all(pa.compute.equal(tbl["_dlt_load_id"], load_id)).as_py()
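
For illustration, the storage trick the patch uses for the constant `_dlt_load_id` column (a `dictionary<int8, string>` array, as in `add_constant_column` and the new branch in `normalize_py_arrow_item` above) can be reproduced with plain `pyarrow`. This is a minimal sketch, not code from the patch; the helper name, table, and load id value are illustrative:

```python
import pyarrow as pa


def append_constant_column(table: pa.Table, name: str, value: str) -> pa.Table:
    # dictionary<int8, string> keeps a repeated constant compact:
    # one string in the dictionary plus a tiny int8 index per row
    col_type = pa.dictionary(pa.int8(), pa.string())
    field = pa.field(name, col_type, nullable=False)
    return table.append_column(field, pa.array([value] * table.num_rows, type=col_type))


tbl = pa.table({"id": [1, 2, 3]})
tbl = append_constant_column(tbl, "_dlt_load_id", "1234567890.123456")  # illustrative load id
print(tbl.schema)  # id: int64, _dlt_load_id: dictionary<values=string, indices=int8>
```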
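
And a hedged end-to-end sketch of enabling the new behaviour from user code, mirroring the configuration shown in the docs change and in the tests above; the pipeline name, resource, and destination are illustrative, not part of the patch:

```python
import os

import dlt
import pyarrow as pa

# same switch the tests set; equivalently, add_dlt_load_id = true under
# [normalize.parquet_normalizer] in config.toml (see the docs change above)
os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_LOAD_ID"] = "True"


@dlt.resource
def events():
    # any arrow table or record batch works; the values are illustrative
    yield pa.table({"id": [1, 2, 3]})


pipeline = dlt.pipeline("arrow_load_id_demo", destination="duckdb")
info = pipeline.run(events(), loader_file_format="parquet")

# with this patch the _dlt_load_id column is already present in the extracted
# parquet file, so the normalize step does not need to rewrite the file for it
print(info.loads_ids)
```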