Add load_id to arrow tables in extract step instead of normalize
steinitzu committed Jun 7, 2024
1 parent 1c1ce7e commit 536f0f2
Showing 4 changed files with 100 additions and 44 deletions.
64 changes: 58 additions & 6 deletions dlt/common/libs/pyarrow.py
@@ -226,7 +226,8 @@ def should_normalize_arrow_schema(
schema: pyarrow.Schema,
columns: TTableSchemaColumns,
naming: NamingConvention,
-) -> Tuple[bool, Mapping[str, str], Dict[str, str], Dict[str, bool], TTableSchemaColumns]:
+    add_load_id: bool = False,
+) -> Tuple[bool, Mapping[str, str], Dict[str, str], Dict[str, bool], bool, TTableSchemaColumns]:
rename_mapping = get_normalized_arrow_fields_mapping(schema, naming)
rev_mapping = {v: k for k, v in rename_mapping.items()}
nullable_mapping = {k: v.get("nullable", True) for k, v in columns.items()}
@@ -240,6 +241,16 @@

dlt_tables = list(map(naming.normalize_table_identifier, ("_dlt_id", "_dlt_load_id")))

# Do we need to add a load id column?
if add_load_id and "_dlt_load_id" in columns:
try:
schema.field("_dlt_load_id")
needs_load_id = False
except KeyError:
needs_load_id = True
else:
needs_load_id = False

# remove all columns that are dlt columns but are not present in arrow schema. we do not want to add such columns
# that should happen in the normalizer
columns = {
@@ -250,27 +261,38 @@

# check if nothing to rename
skip_normalize = (
-        list(rename_mapping.keys()) == list(rename_mapping.values()) == list(columns.keys())
-    ) and not nullable_updates
-    return not skip_normalize, rename_mapping, rev_mapping, nullable_updates, columns
+        (list(rename_mapping.keys()) == list(rename_mapping.values()) == list(columns.keys()))
+        and not nullable_updates
+        and not needs_load_id
+    )
+    return (
+        not skip_normalize,
+        rename_mapping,
+        rev_mapping,
+        nullable_updates,
+        needs_load_id,
+        columns,
+    )


def normalize_py_arrow_item(
item: TAnyArrowItem,
columns: TTableSchemaColumns,
naming: NamingConvention,
caps: DestinationCapabilitiesContext,
load_id: Optional[str] = None,
) -> TAnyArrowItem:
"""Normalize arrow `item` schema according to the `columns`.
1. arrow schema field names will be normalized according to `naming`
2. arrows columns will be reordered according to `columns`
3. empty columns will be inserted if they are missing, types will be generated using `caps`
4. arrow columns with different nullability than corresponding schema columns will be updated
5. Add `_dlt_load_id` column if it is missing and `load_id` is provided
"""
schema = item.schema
-    should_normalize, rename_mapping, rev_mapping, nullable_updates, columns = (
-        should_normalize_arrow_schema(schema, columns, naming)
+    should_normalize, rename_mapping, rev_mapping, nullable_updates, needs_load_id, columns = (
+        should_normalize_arrow_schema(schema, columns, naming, load_id is not None)
)
if not should_normalize:
return item
@@ -307,6 +329,12 @@ def normalize_py_arrow_item(
new_fields.append(schema.field(idx).with_name(column_name))
new_columns.append(item.column(idx))

if needs_load_id and load_id:
# Storage efficient type for a column with constant value
load_id_type = pyarrow.dictionary(pyarrow.int8(), pyarrow.string())
new_fields.append(pyarrow.field("_dlt_load_id", load_id_type, nullable=False))
new_columns.append(pyarrow.array([load_id] * item.num_rows, type=load_id_type))

# create desired type
return item.__class__.from_arrays(new_columns, schema=pyarrow.schema(new_fields))
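
For context, a minimal self-contained sketch (not part of the commit) of what step 5 of the docstring amounts to: appending a dictionary-encoded `_dlt_load_id` column that holds a single constant value. The sample table and load id below are invented for illustration.

```python
import pyarrow

# Invented sample data; in dlt the load_id comes from the extract step
item = pyarrow.table({"value": [1, 2, 3]})
load_id = "1717700000.123456"

# Dictionary encoding stores the constant string once plus one int8 index per row
load_id_type = pyarrow.dictionary(pyarrow.int8(), pyarrow.string())
item = item.append_column(
    pyarrow.field("_dlt_load_id", load_id_type, nullable=False),
    pyarrow.array([load_id] * item.num_rows, type=load_id_type),
)
print(item.column("_dlt_load_id").to_pylist())  # ['1717700000.123456', ...]
```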

@@ -383,6 +411,30 @@ def from_arrow_scalar(arrow_value: pyarrow.Scalar) -> Any:
"""Sequence of tuples: (field index, field, generating function)"""


def add_constant_column(
item: TAnyArrowItem,
name: str,
data_type: pyarrow.DataType,
value: Any = None,
nullable: bool = True,
index: int = -1,
) -> TAnyArrowItem:
"""Add column with a single value to the table.
Args:
item: Arrow table or record batch
name: The new column name
data_type: The data type of the new column
nullable: Whether the new column is nullable
value: The value to fill the new column with
index: The index at which to insert the new column. Defaults to -1 (append)
"""
field = pyarrow.field(name, pyarrow.dictionary(pyarrow.int8(), data_type), nullable=nullable)
if index == -1:
return item.append_column(field, pyarrow.array([value] * item.num_rows, type=field.type))
return item.add_column(index, field, pyarrow.array([value] * item.num_rows, type=field.type))
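
A short usage sketch for the new helper (assuming it is imported from `dlt.common.libs.pyarrow`; the table and value are made up):

```python
import pyarrow
from dlt.common.libs.pyarrow import add_constant_column

table = pyarrow.table({"value": [1, 2, 3]})

# Appends a dictionary-encoded "source" column filled with one constant value
table = add_constant_column(table, "source", pyarrow.string(), value="api", nullable=False)
print(table.column("source").to_pylist())  # ['api', 'api', 'api']
```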


def pq_stream_with_new_columns(
parquet_file: TFileOrPath, columns: TNewColumns, row_groups_per_read: int = 1
) -> Iterator[pyarrow.Table]:
32 changes: 30 additions & 2 deletions dlt/extract/extractors.py
@@ -1,6 +1,7 @@
from copy import copy
from typing import Set, Dict, Any, Optional, List

from dlt.common.configuration import known_sections
from dlt.common import logger
from dlt.common.configuration.inject import with_config
from dlt.common.configuration.specs import BaseConfiguration, configspec
@@ -21,6 +22,7 @@
from dlt.extract.resource import DltResource
from dlt.extract.items import TableNameMeta
from dlt.extract.storage import ExtractorItemStorage
from dlt.normalize.configuration import ItemsNormalizerConfiguration

try:
from dlt.common.libs import pyarrow
@@ -215,13 +217,26 @@ class ObjectExtractor(Extractor):
class ArrowExtractor(Extractor):
"""Extracts arrow data items into parquet. Normalizes arrow items column names.
Compares the arrow schema to actual dlt table schema to reorder the columns and to
-    insert missing columns (without data).
+    insert missing columns (without data). Adds _dlt_load_id column to the table if
+    `add_dlt_load_id` is set to True in normalizer config.
We do things that normalizer should do here so we do not need to load and save parquet
files again later.
Handles the following types:
- `pyarrow.Table`
- `pyarrow.RecordBatch`
- `pandas.DataFrame` (is converted to arrow `Table` before processing)
"""

# Inject the parts of normalize configuration that are used here
@with_config(
spec=ItemsNormalizerConfiguration, sections=(known_sections.NORMALIZE, "parquet_normalizer")
)
def __init__(self, *args: Any, add_dlt_load_id: bool = False, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.add_dlt_load_id = add_dlt_load_id
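
Since the configuration spec is resolved from the `normalize.parquet_normalizer` section (per the `sections` tuple above), the flag can be switched on without code changes. A sketch assuming dlt's usual environment-variable configuration resolution with double-underscore section separators; the pipeline below is hypothetical:

```python
import os

# Assumed env-var spelling of normalize.parquet_normalizer.add_dlt_load_id
os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_LOAD_ID"] = "true"

import dlt

# With the flag on, ArrowExtractor is created with add_dlt_load_id=True and the
# _dlt_load_id column is added already during the extract step
pipeline = dlt.pipeline(pipeline_name="arrow_demo", destination="duckdb")
```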

def write_items(self, resource: DltResource, items: TDataItems, meta: Any) -> None:
static_table_name = self._get_static_table_name(resource, meta)
items = [
@@ -294,7 +309,13 @@ def _write_item(
columns = columns or self.schema.get_table_columns(table_name)
# Note: `items` is always a list here due to the conversion in `write_table`
items = [
-            pyarrow.normalize_py_arrow_item(item, columns, self.naming, self._caps)
+            pyarrow.normalize_py_arrow_item(
+                item,
+                columns,
+                self.naming,
+                self._caps,
+                load_id=self.load_id if self.add_dlt_load_id else None,
+            )
for item in items
]
# write items one by one
@@ -318,6 +339,13 @@ def _compute_table(
arrow_table["columns"] = pyarrow.py_arrow_to_table_schema_columns(item.schema)
# normalize arrow table before merging
arrow_table = self.schema.normalize_table_identifiers(arrow_table)
# Add load_id column
if self.add_dlt_load_id and "_dlt_load_id" not in arrow_table["columns"]:
arrow_table["columns"]["_dlt_load_id"] = {
"name": "_dlt_load_id",
"data_type": "text",
"nullable": False,
}
# issue warnings when overriding computed with arrow
override_warn: bool = False
for col_name, column in arrow_table["columns"].items():
43 changes: 8 additions & 35 deletions dlt/normalize/items_normalizers.py
@@ -228,37 +228,13 @@ class ArrowItemsNormalizer(ItemsNormalizer):
REWRITE_ROW_GROUPS = 1

def _write_with_dlt_columns(
-        self, extracted_items_file: str, root_table_name: str, add_load_id: bool, add_dlt_id: bool
+        self, extracted_items_file: str, root_table_name: str, add_dlt_id: bool
) -> List[TSchemaUpdate]:
new_columns: List[Any] = []
schema = self.schema
load_id = self.load_id
schema_update: TSchemaUpdate = {}

-        if add_load_id:
-            table_update = schema.update_table(
-                {
-                    "name": root_table_name,
-                    "columns": {
-                        "_dlt_load_id": {
-                            "name": "_dlt_load_id",
-                            "data_type": "text",
-                            "nullable": False,
-                        }
-                    },
-                }
-            )
-            table_updates = schema_update.setdefault(root_table_name, [])
-            table_updates.append(table_update)
-            load_id_type = pa.dictionary(pa.int8(), pa.string())
-            new_columns.append(
-                (
-                    -1,
-                    pa.field("_dlt_load_id", load_id_type, nullable=False),
-                    lambda batch: pa.array([load_id] * batch.num_rows, type=load_id_type),
-                )
-            )

if add_dlt_id:
table_update = schema.update_table(
{
@@ -292,9 +268,9 @@ def _write_with_dlt_columns(
items_count += batch.num_rows
# we may need to normalize
if is_native_arrow_writer and should_normalize is None:
-                should_normalize, _, _, _, _ = pyarrow.should_normalize_arrow_schema(
+                should_normalize = pyarrow.should_normalize_arrow_schema(
batch.schema, columns_schema, schema.naming
-                )
+                )[0]
if should_normalize:
logger.info(
f"When writing arrow table to {root_table_name} the schema requires"
@@ -366,25 +342,22 @@ def __call__(self, extracted_items_file: str, root_table_name: str) -> List[TSchemaUpdate]:
base_schema_update = self._fix_schema_precisions(root_table_name, arrow_schema)

add_dlt_id = self.config.parquet_normalizer.add_dlt_id
-        add_dlt_load_id = self.config.parquet_normalizer.add_dlt_load_id
# if we need to add any columns or the file format is not parquet, we can't just import files
-        must_rewrite = (
-            add_dlt_id or add_dlt_load_id or self.item_storage.writer_spec.file_format != "parquet"
-        )
+        must_rewrite = add_dlt_id or self.item_storage.writer_spec.file_format != "parquet"
if not must_rewrite:
# in rare cases normalization may be needed
-            must_rewrite, _, _, _, _ = pyarrow.should_normalize_arrow_schema(
+            must_rewrite = pyarrow.should_normalize_arrow_schema(
arrow_schema, self.schema.get_table_columns(root_table_name), self.schema.naming
-            )
+            )[0]
if must_rewrite:
logger.info(
f"Table {root_table_name} parquet file {extracted_items_file} must be rewritten:"
f" add_dlt_id: {add_dlt_id} add_dlt_load_id: {add_dlt_load_id} destination file"
f" add_dlt_id: {add_dlt_id} destination file"
f" format: {self.item_storage.writer_spec.file_format} or due to required"
" normalization "
)
schema_update = self._write_with_dlt_columns(
-                extracted_items_file, root_table_name, add_dlt_load_id, add_dlt_id
+                extracted_items_file, root_table_name, add_dlt_id
)
return base_schema_update + schema_update

@@ -84,7 +84,10 @@ add_dlt_load_id = true
add_dlt_id = true
```

-Keep in mind that enabling these incurs some performance overhead because the `parquet` file needs to be read back from disk in chunks, processed and rewritten with new columns.
+Keep in mind that enabling these incurs some performance overhead:
+
+- `add_dlt_load_id` has minimal overhead since the column is added to the arrow table in memory during the `extract` stage, before the parquet file is written to disk (see the sketch below)
+- `add_dlt_id` adds the column during the `normalize` stage, after the file has been extracted to disk. The file needs to be read back from disk in chunks, processed, and rewritten with new columns
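
A minimal sketch (not part of the docs change) of why the in-memory load id column is cheap: pyarrow dictionary encoding stores the constant string once plus one small index per row. The row count and load id value are illustrative.

```python
import pyarrow

load_id = "1717700000.123456"  # illustrative value
n_rows = 1_000_000

plain = pyarrow.array([load_id] * n_rows, type=pyarrow.string())
encoded = pyarrow.array([load_id] * n_rows, type=pyarrow.dictionary(pyarrow.int8(), pyarrow.string()))

print(plain.nbytes)    # roughly 21 MB: the string bytes plus offsets for every row
print(encoded.nbytes)  # roughly 1 MB: one int8 index per row plus a tiny dictionary
```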

## Incremental loading with Arrow tables

