Extend custom destination #1107

Merged on Mar 20, 2024 (18 commits)
Changes from 10 commits
1 change: 1 addition & 0 deletions dlt/common/destination/capabilities.py
@@ -55,6 +55,7 @@ class DestinationCapabilitiesContext(ContainerInjectableContext):
insert_values_writer_type: str = "default"
supports_multiple_statements: bool = True
supports_clone_table: bool = False
"""Destination supports CREATE TABLE ... CLONE ... statements"""

max_table_nesting: Optional[int] = None  # destination can overwrite max table nesting

# do not allow to create default value, destination caps must be always explicitly inserted into container
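As an aside, a hedged sketch of how a destination can surface this new capability; generic_capabilities is the same helper the custom destination uses further down in this PR, and the value 0 is only an illustrative choice:

from dlt.common.destination import DestinationCapabilitiesContext

# illustrative only: start from generic capabilities and cap nesting entirely,
# so nested fields stay as JSON instead of being split into child tables
caps = DestinationCapabilitiesContext.generic_capabilities("puae-jsonl")
caps.max_table_nesting = 0
print(caps.max_table_nesting)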
21 changes: 21 additions & 0 deletions dlt/common/destination/reference.py
@@ -260,6 +260,27 @@ def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]:
return []


class DoNothingJob(LoadJob):
"""The most lazy class of dlt"""

def __init__(self, file_path: str) -> None:
super().__init__(FileStorage.get_file_name_from_file_path(file_path))

def state(self) -> TLoadJobState:
# this job is always done
return "completed"

def exception(self) -> str:
# this part of code should be never reached
raise NotImplementedError()


class DoNothingFollowupJob(DoNothingJob, FollowupJob):
"""The second most lazy class of dlt"""

pass


class JobClientBase(ABC):
capabilities: ClassVar[DestinationCapabilitiesContext] = None

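For illustration, a hedged sketch of the relocated class in action, assuming a load file name that follows dlt's table.file_id.retry.ext convention (the path below is made up): a job for an internal dlt table can be acknowledged as completed without writing anything.

from dlt.common.destination.reference import DoNothingJob

# hypothetical file path; the job reports "completed" immediately and no data is moved
job = DoNothingJob("/tmp/_dlt_pipeline_state.1234.0.jsonl")
assert job.state() == "completed"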
14 changes: 12 additions & 2 deletions dlt/common/normalizers/configuration.py
@@ -5,7 +5,7 @@
from dlt.common.configuration.specs import BaseConfiguration
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.normalizers.typing import TJSONNormalizer
from dlt.common.typing import StrAny
from dlt.common.typing import DictStrAny


@configspec
@@ -14,14 +14,24 @@ class NormalizersConfiguration(BaseConfiguration):
__section__: str = "schema"

naming: Optional[str] = None
json_normalizer: Optional[StrAny] = None
json_normalizer: Optional[DictStrAny] = None
destination_capabilities: Optional[DestinationCapabilitiesContext] = None # injectable

def on_resolved(self) -> None:
# get naming from capabilities if not present
if self.naming is None:
if self.destination_capabilities:
self.naming = self.destination_capabilities.naming_convention
# if max_table_nesting is set, we need to set the max_table_nesting in the json_normalizer
if (
self.destination_capabilities
and self.destination_capabilities.max_table_nesting is not None
):
self.json_normalizer = self.json_normalizer or {}
self.json_normalizer.setdefault("config", {})
self.json_normalizer["config"][
"max_nesting"
] = self.destination_capabilities.max_table_nesting

Collaborator: this is the best we can do now. If we have more normalizers with incompatible configs, then we'll need to look for something better.

if TYPE_CHECKING:

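To make the propagation concrete, a small hedged sketch using plain dicts (no dlt imports) of what on_resolved does to the json_normalizer entry when the injected capabilities set max_table_nesting:

# pre-existing user config is kept; the capability value only adds "max_nesting"
json_normalizer = {"config": {"propagation": {}}}
max_table_nesting = 0  # taken from destination capabilities in the real code

json_normalizer = json_normalizer or {}
json_normalizer.setdefault("config", {})
json_normalizer["config"]["max_nesting"] = max_table_nesting

assert json_normalizer == {"config": {"propagation": {}, "max_nesting": 0}}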
8 changes: 5 additions & 3 deletions dlt/common/normalizers/utils.py
@@ -34,9 +34,11 @@ def import_normalizers(
"""
# add defaults to normalizer_config
normalizers_config["names"] = names = normalizers_config["names"] or "snake_case"
normalizers_config["json"] = item_normalizer = normalizers_config["json"] or {
"module": "dlt.common.normalizers.json.relational"
}
# set default json normalizer module
normalizers_config["json"] = item_normalizer = normalizers_config.get("json") or {}
if "module" not in item_normalizer:
item_normalizer["module"] = "dlt.common.normalizers.json.relational"

try:
if "." in names:
# TODO: bump schema engine version and migrate schema. also change the name in TNormalizersConfig from names to naming
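The point of the change above is that a partially specified "json" entry is no longer replaced wholesale; a stand-alone sketch of the new defaulting behaviour, with a made-up config:

# config that carries only a normalizer config, no explicit module
normalizers_config = {"json": {"config": {"max_nesting": 1}}}

item_normalizer = normalizers_config.get("json") or {}
if "module" not in item_normalizer:
    item_normalizer["module"] = "dlt.common.normalizers.json.relational"
normalizers_config["json"] = item_normalizer

# the user config survives and only the missing module is defaulted
assert normalizers_config["json"] == {
    "config": {"max_nesting": 1},
    "module": "dlt.common.normalizers.json.relational",
}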
4 changes: 4 additions & 0 deletions dlt/destinations/decorators.py
@@ -23,6 +23,8 @@ def destination(
batch_size: int = 10,
name: str = None,
naming_convention: str = "direct",
Collaborator: good! Please add this to our docs: with the default settings, data arrives at the sink without changed identifiers, un-nested and with the dlt identifiers removed, which makes it a good fit for pushing data to queues and REST APIs.

skip_dlt_columns_and_tables: bool = True,
max_table_nesting: int = 0,
spec: Type[GenericDestinationClientConfiguration] = GenericDestinationClientConfiguration,
) -> Callable[
[Callable[Concatenate[Union[TDataItems, str], TTableSchema, TDestinationCallableParams], Any]],
@@ -49,6 +51,8 @@ def wrapper(
batch_size=batch_size,
destination_name=name,
naming_convention=naming_convention,
skip_dlt_columns_and_tables=skip_dlt_columns_and_tables,
max_table_nesting=max_table_nesting,
**kwargs, # type: ignore
)

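Picking up the reviewer's docs request, a hedged sketch of what these defaults buy a custom sink: identifiers pass through unchanged (naming_convention="direct"), data arrives un-nested (max_table_nesting=0) and dlt columns are stripped (skip_dlt_columns_and_tables=True), which suits pushing rows to a queue or REST API. The endpoint URL and the requests dependency are assumptions, not part of this PR:

import dlt
import requests  # assumption: any HTTP client would do

@dlt.destination(batch_size=25)  # defaults: direct naming, no nesting, no _dlt columns
def rest_api_sink(items, table) -> None:
    # items arrive as a list of plain dicts with their original field names
    requests.post(
        "https://example.com/ingest",  # hypothetical endpoint
        json={"table": table["name"], "rows": items},
        timeout=30,
    )

# usage sketch: any resource can be piped straight into the sink, e.g.
# dlt.pipeline("push_to_api", destination=rest_api_sink).run([{"id": 1}], table_name="events")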
23 changes: 1 addition & 22 deletions dlt/destinations/impl/athena/athena.py
@@ -37,7 +37,7 @@
from dlt.common.schema.typing import TTableSchema, TColumnType, TWriteDisposition, TTableFormat
from dlt.common.schema.utils import table_schema_has_type, get_table_format
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import LoadJob, FollowupJob
from dlt.common.destination.reference import LoadJob, DoNothingFollowupJob, DoNothingJob
from dlt.common.destination.reference import TLoadJobState, NewLoadJob, SupportsStagingDestination
from dlt.common.storages import FileStorage
from dlt.common.data_writers.escape import escape_bigquery_identifier
@@ -149,27 +149,6 @@ def __init__(self) -> None:
DLTAthenaFormatter._INSTANCE = self


class DoNothingJob(LoadJob):
"""The most lazy class of dlt"""

def __init__(self, file_path: str) -> None:
super().__init__(FileStorage.get_file_name_from_file_path(file_path))

def state(self) -> TLoadJobState:
# this job is always done
return "completed"

def exception(self) -> str:
# this part of code should be never reached
raise NotImplementedError()


class DoNothingFollowupJob(DoNothingJob, FollowupJob):
"""The second most lazy class of dlt"""

pass


class AthenaSQLClient(SqlClientBase[Connection]):
capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities()
dbapi: ClassVar[DBApi] = pyathena
3 changes: 3 additions & 0 deletions dlt/destinations/impl/destination/__init__.py
@@ -1,14 +1,17 @@
from typing import Optional
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.data_writers import TLoaderFileFormat


def capabilities(
preferred_loader_file_format: TLoaderFileFormat = "puae-jsonl",
naming_convention: str = "direct",
max_table_nesting: Optional[int] = 0,
) -> DestinationCapabilitiesContext:
caps = DestinationCapabilitiesContext.generic_capabilities(preferred_loader_file_format)
caps.supported_loader_file_formats = ["puae-jsonl", "parquet"]
caps.supports_ddl_transactions = False
caps.supports_transactions = False
caps.naming_convention = naming_convention
caps.max_table_nesting = max_table_nesting
return caps
2 changes: 2 additions & 0 deletions dlt/destinations/impl/destination/configuration.py
@@ -20,6 +20,8 @@ class GenericDestinationClientConfiguration(DestinationClientConfiguration):
destination_callable: Optional[Union[str, TDestinationCallable]] = None # noqa: A003
loader_file_format: TLoaderFileFormat = "puae-jsonl"
batch_size: int = 10
skip_dlt_columns_and_tables: bool = True
max_table_nesting: int = 0

if TYPE_CHECKING:

30 changes: 28 additions & 2 deletions dlt/destinations/impl/destination/destination.py
@@ -1,7 +1,9 @@
from abc import ABC, abstractmethod
from types import TracebackType
from typing import ClassVar, Dict, Optional, Type, Iterable, Iterable, cast, Dict
from typing import ClassVar, Dict, Optional, Type, Iterable, Iterable, cast, Dict, List
from copy import deepcopy

from dlt.common.destination.reference import LoadJob
from dlt.destinations.job_impl import EmptyLoadJob
from dlt.common.typing import TDataItems, AnyFun
from dlt.common import json
@@ -18,6 +20,7 @@
from dlt.common.destination.reference import (
TLoadJobState,
LoadJob,
DoNothingJob,
JobClientBase,
)

@@ -37,6 +40,7 @@ def __init__(
schema: Schema,
destination_state: Dict[str, int],
destination_callable: TDestinationCallable,
skipped_columns: List[str],
) -> None:
super().__init__(FileStorage.get_file_name_from_file_path(file_path))
self._file_path = file_path
@@ -47,6 +51,7 @@ def __init__(
self._callable = destination_callable
self._state: TLoadJobState = "running"
self._storage_id = f"{self._parsed_file_name.table_name}.{self._parsed_file_name.file_id}"
self.skipped_columns = skipped_columns
try:
if self._config.batch_size == 0:
# on batch size zero we only call the callable with the filename
@@ -93,9 +98,14 @@ def run(self, start_index: int) -> Iterable[TDataItems]:
start_index % self._config.batch_size
) == 0, "Batch size was changed during processing of one load package"

# on record batches we cannot drop columns, we need to
# select the ones we want to keep
keep_columns = list(self._table["columns"].keys())
start_batch = start_index / self._config.batch_size
with pyarrow.parquet.ParquetFile(self._file_path) as reader:
for record_batch in reader.iter_batches(batch_size=self._config.batch_size):
for record_batch in reader.iter_batches(
batch_size=self._config.batch_size, columns=keep_columns
):
if start_batch > 0:
start_batch -= 1
continue
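Since columns cannot be dropped from a record batch after the fact, the parquet job selects them at read time; a minimal stand-alone pyarrow sketch of that pattern (file name and column list are made up):

import pyarrow.parquet

# only the columns still present in the table schema are read,
# so skipped _dlt_* columns never reach the sink callable
keep_columns = ["id", "name"]
with pyarrow.parquet.ParquetFile("items.parquet") as reader:
    for record_batch in reader.iter_batches(batch_size=10, columns=keep_columns):
        print(record_batch.to_pylist())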
@@ -115,6 +125,9 @@ def run(self, start_index: int) -> Iterable[TDataItems]:
if start_index > 0:
start_index -= 1
continue
# skip internal columns
for column in self.skipped_columns:
item.pop(column, None)
current_batch.append(item)
if len(current_batch) == self._config.batch_size:
yield current_batch
@@ -150,6 +163,17 @@ def update_stored_schema(
return super().update_stored_schema(only_tables, expected_update)

def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
# skip internal tables and remove columns from schema if so configured
skipped_columns: List[str] = []
if self.config.skip_dlt_columns_and_tables:
if table["name"].startswith(self.schema._dlt_tables_prefix):
return DoNothingJob(file_path)
table = deepcopy(table)
for column in list(table["columns"].keys()):
if column.startswith(self.schema._dlt_tables_prefix):
table["columns"].pop(column)
skipped_columns.append(column)

# save our state in destination name scope
load_state = destination_state()
if file_path.endswith("parquet"):
@@ -160,6 +184,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) ->
self.schema,
load_state,
self.destination_callable,
skipped_columns,
)
if file_path.endswith("jsonl"):
return DestinationJsonlLoadJob(
@@ -169,6 +194,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) ->
self.schema,
load_state,
self.destination_callable,
skipped_columns,
)
return None

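For the jsonl path the same filtering happens per item; a hedged sketch with made-up column names showing what the callable receives when skip_dlt_columns_and_tables stays at its default:

# hypothetical row as read from a jsonl load file
item = {"id": 1, "name": "alice", "_dlt_id": "GbDGrkwGqaUrRQ", "_dlt_load_id": "1709999999.1"}

# in the real code the skipped names come from the table schema; here we just match the prefix
skipped_columns = [name for name in item if name.startswith("_dlt")]
for column in skipped_columns:
    item.pop(column, None)

assert item == {"id": 1, "name": "alice"}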
5 changes: 3 additions & 2 deletions dlt/destinations/impl/destination/factory.py
@@ -36,8 +36,9 @@ class DestinationInfo(t.NamedTuple):
class destination(Destination[GenericDestinationClientConfiguration, "DestinationClient"]):
def capabilities(self) -> DestinationCapabilitiesContext:
return capabilities(
self.config_params.get("loader_file_format", "puae-jsonl"),
self.config_params.get("naming_convention", "direct"),
preferred_loader_file_format=self.config_params.get("loader_file_format", "puae-jsonl"),
naming_convention=self.config_params.get("naming_convention", "direct"),
max_table_nesting=self.config_params.get("max_table_nesting", None),
)

@property
1 change: 1 addition & 0 deletions dlt/pipeline/pipeline.py
@@ -457,6 +457,7 @@ def normalize(

# make sure destination capabilities are available
self._get_destination_capabilities()

# create default normalize config
normalize_config = NormalizeConfiguration(
workers=workers,
@@ -0,0 +1,9 @@
# you can just paste the contents of services.json as credentials
[destination.bigquery.credentials]
client_email = ""
private_key = ""
project_id = ""
token_uri = ""
refresh_token = ""
client_id = ""
client_secret = ""
@@ -0,0 +1,71 @@
import dlt
import pandas as pd
import pyarrow as pa
from google.cloud import bigquery

from dlt.common.configuration.specs import GcpServiceAccountCredentials

# constants
OWID_DISASTERS_URL = (
"https://raw.githubusercontent.com/owid/owid-datasets/master/datasets/"
"Natural%20disasters%20from%201900%20to%202019%20-%20EMDAT%20(2020)/"
"Natural%20disasters%20from%201900%20to%202019%20-%20EMDAT%20(2020).csv"
)
# this table needs to be manually created in your gc account
# format: "your-project.your_dataset.your_table"
BIGQUERY_TABLE_ID = "chat-analytics-rasa-ci.ci_streaming_insert.natural-disasters"

# dlt sources
@dlt.resource(name="natural_disasters")
def resource(url: str):
# load pyarrow table with pandas
table = pa.Table.from_pandas(pd.read_csv(url))
# we add a list type column to demonstrate bigquery lists
table = table.append_column(
"tags",
pa.array(
[["disasters", "earthquakes", "floods", "tsunamis"]] * len(table),
pa.list_(pa.string()),
),
)
# we add a struct type column to demonstrate bigquery structs
table = table.append_column(
"meta",
pa.array(
[{"loaded_by": "dlt"}] * len(table),
pa.struct([("loaded_by", pa.string())]),
),
)
yield table

# dlt bigquery custom destination
# we can use the dlt provided credentials class
# to retrieve the gcp credentials from the secrets
@dlt.destination(name="bigquery", loader_file_format="parquet", batch_size=0)
def bigquery_insert(
items, table, credentials: GcpServiceAccountCredentials = dlt.secrets.value
) -> None:
client = bigquery.Client(
credentials.project_id, credentials.to_native_credentials(), location="US"
)
job_config = bigquery.LoadJobConfig(
autodetect=True,
source_format=bigquery.SourceFormat.PARQUET,
schema_update_options=bigquery.SchemaUpdateOption.ALLOW_FIELD_ADDITION,
)
# since we have set the batch_size to 0, we get a filepath and can load the file directly
with open(items, "rb") as f:
load_job = client.load_table_from_file(f, BIGQUERY_TABLE_ID, job_config=job_config)
load_job.result() # Waits for the job to complete.

if __name__ == "__main__":
# run the pipeline and print load results
pipeline = dlt.pipeline(
pipeline_name="csv_to_bigquery_insert",
destination=bigquery_insert,
dataset_name="mydata",
full_refresh=True,
)
load_info = pipeline.run(resource(url=OWID_DISASTERS_URL))

print(load_info)