Skip to content

Commit

Permalink
Add load_id to arrow tables in extract step instead of normalize (#1449)
Browse files Browse the repository at this point in the history
* Add load_id to arrow tables in extract step instead of normalize

* Test arrow load id in extract

* Get normalize config without decorator

* Normalize load ID column name

* Load ID column goes last

* adds update_table column order tests

---------

Co-authored-by: Marcin Rudolf <rudolfix@rudolfix.org>
  • Loading branch information
steinitzu and rudolfix committed Jun 18, 2024
1 parent 14f06e4 commit b267c70
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 47 deletions.
76 changes: 68 additions & 8 deletions dlt/common/libs/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ def should_normalize_arrow_schema(
schema: pyarrow.Schema,
columns: TTableSchemaColumns,
naming: NamingConvention,
) -> Tuple[bool, Mapping[str, str], Dict[str, str], Dict[str, bool], TTableSchemaColumns]:
add_load_id: bool = False,
) -> Tuple[bool, Mapping[str, str], Dict[str, str], Dict[str, bool], bool, TTableSchemaColumns]:
rename_mapping = get_normalized_arrow_fields_mapping(schema, naming)
rev_mapping = {v: k for k, v in rename_mapping.items()}
nullable_mapping = {k: v.get("nullable", True) for k, v in columns.items()}
Expand All @@ -238,39 +239,62 @@ def should_normalize_arrow_schema(
if norm_name in nullable_mapping and field.nullable != nullable_mapping[norm_name]:
nullable_updates[norm_name] = nullable_mapping[norm_name]

dlt_tables = list(map(naming.normalize_table_identifier, ("_dlt_id", "_dlt_load_id")))
dlt_load_id_col = naming.normalize_table_identifier("_dlt_load_id")
dlt_id_col = naming.normalize_table_identifier("_dlt_id")
dlt_columns = {dlt_load_id_col, dlt_id_col}

# Do we need to add a load id column?
if add_load_id and dlt_load_id_col in columns:
try:
schema.field(dlt_load_id_col)
needs_load_id = False
except KeyError:
needs_load_id = True
else:
needs_load_id = False

# remove all columns that are dlt columns but are not present in arrow schema. we do not want to add such columns
# that should happen in the normalizer
columns = {
name: column
for name, column in columns.items()
if name not in dlt_tables or name in rev_mapping
if name not in dlt_columns or name in rev_mapping
}

# check if nothing to rename
skip_normalize = (
list(rename_mapping.keys()) == list(rename_mapping.values()) == list(columns.keys())
) and not nullable_updates
return not skip_normalize, rename_mapping, rev_mapping, nullable_updates, columns
(list(rename_mapping.keys()) == list(rename_mapping.values()) == list(columns.keys()))
and not nullable_updates
and not needs_load_id
)
return (
not skip_normalize,
rename_mapping,
rev_mapping,
nullable_updates,
needs_load_id,
columns,
)


def normalize_py_arrow_item(
item: TAnyArrowItem,
columns: TTableSchemaColumns,
naming: NamingConvention,
caps: DestinationCapabilitiesContext,
load_id: Optional[str] = None,
) -> TAnyArrowItem:
"""Normalize arrow `item` schema according to the `columns`.
1. arrow schema field names will be normalized according to `naming`
2. arrows columns will be reordered according to `columns`
3. empty columns will be inserted if they are missing, types will be generated using `caps`
4. arrow columns with different nullability than corresponding schema columns will be updated
5. Add `_dlt_load_id` column if it is missing and `load_id` is provided
"""
schema = item.schema
should_normalize, rename_mapping, rev_mapping, nullable_updates, columns = (
should_normalize_arrow_schema(schema, columns, naming)
should_normalize, rename_mapping, rev_mapping, nullable_updates, needs_load_id, columns = (
should_normalize_arrow_schema(schema, columns, naming, load_id is not None)
)
if not should_normalize:
return item
Expand Down Expand Up @@ -307,6 +331,18 @@ def normalize_py_arrow_item(
new_fields.append(schema.field(idx).with_name(column_name))
new_columns.append(item.column(idx))

if needs_load_id and load_id:
# Storage efficient type for a column with constant value
load_id_type = pyarrow.dictionary(pyarrow.int8(), pyarrow.string())
new_fields.append(
pyarrow.field(
naming.normalize_table_identifier("_dlt_load_id"),
load_id_type,
nullable=False,
)
)
new_columns.append(pyarrow.array([load_id] * item.num_rows, type=load_id_type))

# create desired type
return item.__class__.from_arrays(new_columns, schema=pyarrow.schema(new_fields))

Expand Down Expand Up @@ -383,6 +419,30 @@ def from_arrow_scalar(arrow_value: pyarrow.Scalar) -> Any:
"""Sequence of tuples: (field index, field, generating function)"""


def add_constant_column(
item: TAnyArrowItem,
name: str,
data_type: pyarrow.DataType,
value: Any = None,
nullable: bool = True,
index: int = -1,
) -> TAnyArrowItem:
"""Add column with a single value to the table.
Args:
item: Arrow table or record batch
name: The new column name
data_type: The data type of the new column
nullable: Whether the new column is nullable
value: The value to fill the new column with
index: The index at which to insert the new column. Defaults to -1 (append)
"""
field = pyarrow.field(name, pyarrow.dictionary(pyarrow.int8(), data_type), nullable=nullable)
if index == -1:
return item.append_column(field, pyarrow.array([value] * item.num_rows, type=field.type))
return item.add_column(index, field, pyarrow.array([value] * item.num_rows, type=field.type))


def pq_stream_with_new_columns(
parquet_file: TFileOrPath, columns: TNewColumns, row_groups_per_read: int = 1
) -> Iterator[pyarrow.Table]:
Expand Down
44 changes: 41 additions & 3 deletions dlt/extract/extractors.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from copy import copy
from typing import Set, Dict, Any, Optional, List

from dlt.common.configuration import known_sections, resolve_configuration, with_config
from dlt.common import logger
from dlt.common.configuration.inject import with_config
from dlt.common.configuration.specs import BaseConfiguration, configspec
from dlt.common.destination.capabilities import DestinationCapabilitiesContext
from dlt.common.exceptions import MissingDependencyException
Expand All @@ -21,6 +21,7 @@
from dlt.extract.resource import DltResource
from dlt.extract.items import TableNameMeta
from dlt.extract.storage import ExtractorItemStorage
from dlt.normalize.configuration import ItemsNormalizerConfiguration

try:
from dlt.common.libs import pyarrow
Expand Down Expand Up @@ -215,13 +216,29 @@ class ObjectExtractor(Extractor):
class ArrowExtractor(Extractor):
"""Extracts arrow data items into parquet. Normalizes arrow items column names.
Compares the arrow schema to actual dlt table schema to reorder the columns and to
insert missing columns (without data).
insert missing columns (without data). Adds _dlt_load_id column to the table if
`add_dlt_load_id` is set to True in normalizer config.
We do things that normalizer should do here so we do not need to load and save parquet
files again later.
Handles the following types:
- `pyarrow.Table`
- `pyarrow.RecordBatch`
- `pandas.DataFrame` (is converted to arrow `Table` before processing)
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._normalize_config = self._retrieve_normalize_config()

def _retrieve_normalize_config(self) -> ItemsNormalizerConfiguration:
"""Get normalizer settings that are used here"""
return resolve_configuration(
ItemsNormalizerConfiguration(),
sections=(known_sections.NORMALIZE, "parquet_normalizer"),
)

def write_items(self, resource: DltResource, items: TDataItems, meta: Any) -> None:
static_table_name = self._get_static_table_name(resource, meta)
items = [
Expand Down Expand Up @@ -294,7 +311,13 @@ def _write_item(
columns = columns or self.schema.get_table_columns(table_name)
# Note: `items` is always a list here due to the conversion in `write_table`
items = [
pyarrow.normalize_py_arrow_item(item, columns, self.naming, self._caps)
pyarrow.normalize_py_arrow_item(
item,
columns,
self.naming,
self._caps,
load_id=self.load_id if self._normalize_config.add_dlt_load_id else None,
)
for item in items
]
# write items one by one
Expand All @@ -316,8 +339,22 @@ def _compute_table(
else:
arrow_table = copy(computed_table)
arrow_table["columns"] = pyarrow.py_arrow_to_table_schema_columns(item.schema)

# Add load_id column if needed
dlt_load_id_col = self.naming.normalize_table_identifier("_dlt_load_id")
if (
self._normalize_config.add_dlt_load_id
and dlt_load_id_col not in arrow_table["columns"]
):
arrow_table["columns"][dlt_load_id_col] = {
"name": dlt_load_id_col,
"data_type": "text",
"nullable": False,
}

# normalize arrow table before merging
arrow_table = self.schema.normalize_table_identifiers(arrow_table)

# issue warnings when overriding computed with arrow
override_warn: bool = False
for col_name, column in arrow_table["columns"].items():
Expand All @@ -343,6 +380,7 @@ def _compute_table(
utils.merge_columns(
arrow_table["columns"], computed_table["columns"], merge_columns=True
)

return arrow_table

def _compute_and_update_table(
Expand Down
43 changes: 8 additions & 35 deletions dlt/normalize/items_normalizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,37 +228,13 @@ class ArrowItemsNormalizer(ItemsNormalizer):
REWRITE_ROW_GROUPS = 1

def _write_with_dlt_columns(
self, extracted_items_file: str, root_table_name: str, add_load_id: bool, add_dlt_id: bool
self, extracted_items_file: str, root_table_name: str, add_dlt_id: bool
) -> List[TSchemaUpdate]:
new_columns: List[Any] = []
schema = self.schema
load_id = self.load_id
schema_update: TSchemaUpdate = {}

if add_load_id:
table_update = schema.update_table(
{
"name": root_table_name,
"columns": {
"_dlt_load_id": {
"name": "_dlt_load_id",
"data_type": "text",
"nullable": False,
}
},
}
)
table_updates = schema_update.setdefault(root_table_name, [])
table_updates.append(table_update)
load_id_type = pa.dictionary(pa.int8(), pa.string())
new_columns.append(
(
-1,
pa.field("_dlt_load_id", load_id_type, nullable=False),
lambda batch: pa.array([load_id] * batch.num_rows, type=load_id_type),
)
)

if add_dlt_id:
table_update = schema.update_table(
{
Expand Down Expand Up @@ -292,9 +268,9 @@ def _write_with_dlt_columns(
items_count += batch.num_rows
# we may need to normalize
if is_native_arrow_writer and should_normalize is None:
should_normalize, _, _, _, _ = pyarrow.should_normalize_arrow_schema(
should_normalize = pyarrow.should_normalize_arrow_schema(
batch.schema, columns_schema, schema.naming
)
)[0]
if should_normalize:
logger.info(
f"When writing arrow table to {root_table_name} the schema requires"
Expand Down Expand Up @@ -366,25 +342,22 @@ def __call__(self, extracted_items_file: str, root_table_name: str) -> List[TSch
base_schema_update = self._fix_schema_precisions(root_table_name, arrow_schema)

add_dlt_id = self.config.parquet_normalizer.add_dlt_id
add_dlt_load_id = self.config.parquet_normalizer.add_dlt_load_id
# if we need to add any columns or the file format is not parquet, we can't just import files
must_rewrite = (
add_dlt_id or add_dlt_load_id or self.item_storage.writer_spec.file_format != "parquet"
)
must_rewrite = add_dlt_id or self.item_storage.writer_spec.file_format != "parquet"
if not must_rewrite:
# in rare cases normalization may be needed
must_rewrite, _, _, _, _ = pyarrow.should_normalize_arrow_schema(
must_rewrite = pyarrow.should_normalize_arrow_schema(
arrow_schema, self.schema.get_table_columns(root_table_name), self.schema.naming
)
)[0]
if must_rewrite:
logger.info(
f"Table {root_table_name} parquet file {extracted_items_file} must be rewritten:"
f" add_dlt_id: {add_dlt_id} add_dlt_load_id: {add_dlt_load_id} destination file"
f" add_dlt_id: {add_dlt_id} destination file"
f" format: {self.item_storage.writer_spec.file_format} or due to required"
" normalization "
)
schema_update = self._write_with_dlt_columns(
extracted_items_file, root_table_name, add_dlt_load_id, add_dlt_id
extracted_items_file, root_table_name, add_dlt_id
)
return base_schema_update + schema_update

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ add_dlt_load_id = true
add_dlt_id = true
```

Keep in mind that enabling these incurs some performance overhead because the `parquet` file needs to be read back from disk in chunks, processed and rewritten with new columns.
Keep in mind that enabling these incurs some performance overhead:

- `add_dlt_load_id` has minimal overhead since the column is added to arrow table in memory during `extract` stage, before parquet file is written to disk
- `add_dlt_id` adds the column during `normalize` stage after file has been extracted to disk. The file needs to be read back from disk in chunks, processed and rewritten with new columns

## Incremental loading with Arrow tables

Expand Down
21 changes: 21 additions & 0 deletions tests/common/schema/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,3 +565,24 @@ def test_infer_on_incomplete_column(schema: Schema) -> None:
assert i_column["x-special"] == "spec" # type: ignore[typeddict-item]
assert i_column["primary_key"] is True
assert i_column["data_type"] == "text"


def test_update_table_adds_at_end(schema: Schema) -> None:
row = {"evm": Wei(1)}
_, new_table = schema.coerce_row("eth", None, row)
schema.update_table(new_table)
schema.update_table(
{
"name": new_table["name"],
"columns": {
"_dlt_load_id": {
"name": "_dlt_load_id",
"data_type": "text",
"nullable": False,
}
},
}
)
table = schema.tables["eth"]
# place new columns at the end
assert list(table["columns"].keys()) == ["evm", "_dlt_load_id"]

0 comments on commit b267c70

Please sign in to comment.