From 5d20c6044044fcdee51926dbffc039e3c5f7cd17 Mon Sep 17 00:00:00 2001
From: Dave
Date: Thu, 24 Aug 2023 16:49:40 +0200
Subject: [PATCH] format codebase

---
 dlt/__init__.py | 17 +-
 dlt/cli/_dlt.py | 446 ++++++++++---
 dlt/cli/config_toml_writer.py | 41 +-
 dlt/cli/deploy_command.py | 314 +++++++---
 dlt/cli/deploy_command_helpers.py | 160 +++--
 dlt/cli/echo.py | 4 +-
 dlt/cli/init_command.py | 287 ++++++---
 dlt/cli/pipeline_command.py | 136 +++-
 dlt/cli/pipeline_files.py | 96 +--
 dlt/cli/requirements.py | 4 +-
 dlt/cli/source_detection.py | 54 +-
 dlt/cli/telemetry_command.py | 11 +-
 dlt/cli/utils.py | 26 +-
 dlt/common/__init__.py | 8 +-
 dlt/common/arithmetics.py | 16 +-
 dlt/common/configuration/__init__.py | 18 +-
 dlt/common/configuration/accessors.py | 17 +-
 dlt/common/configuration/container.py | 6 +-
 dlt/common/configuration/exceptions.py | 77 ++-
 dlt/common/configuration/inject.py | 41 +-
 dlt/common/configuration/paths.py | 11 +-
 .../configuration/providers/__init__.py | 16 +-
 dlt/common/configuration/providers/airflow.py | 3 +-
 dlt/common/configuration/providers/context.py | 7 +-
 .../configuration/providers/dictionary.py | 11 +-
 dlt/common/configuration/providers/environ.py | 8 +-
 .../configuration/providers/google_secrets.py | 28 +-
 .../configuration/providers/provider.py | 7 +-
 dlt/common/configuration/providers/toml.py | 64 +-
 dlt/common/configuration/resolve.py | 189 ++++--
 dlt/common/configuration/specs/__init__.py | 28 +-
 .../configuration/specs/api_credentials.py | 6 +-
 .../configuration/specs/aws_credentials.py | 21 +-
 .../configuration/specs/base_configuration.py | 76 ++-
 .../specs/config_providers_context.py | 53 +-
 .../specs/config_section_context.py | 35 +-
 .../specs/connection_string_credentials.py | 19 +-
 dlt/common/configuration/specs/exceptions.py | 33 +-
 .../configuration/specs/gcp_credentials.py | 60 +-
 .../configuration/specs/known_sections.py | 2 +-
 .../configuration/specs/run_configuration.py | 16 +-
 dlt/common/configuration/utils.py | 56 +-
 dlt/common/data_types/__init__.py | 2 +-
 dlt/common/data_types/type_helpers.py | 17 +-
 dlt/common/data_types/typing.py | 5 +-
 dlt/common/data_writers/__init__.py | 8 +-
 dlt/common/data_writers/buffered.py | 44 +-
 dlt/common/data_writers/escape.py | 15 +-
 dlt/common/data_writers/exceptions.py | 9 +-
 dlt/common/data_writers/writers.py | 87 ++-
 dlt/common/destination/__init__.py | 2 +-
 dlt/common/destination/capabilities.py | 17 +-
 dlt/common/destination/reference.py | 130 ++--
 dlt/common/exceptions.py | 83 ++-
 dlt/common/git.py | 50 +-
 dlt/common/json/__init__.py | 61 +-
 dlt/common/json/_orjson.py | 17 +-
 dlt/common/json/_simplejson.py | 31 +-
 dlt/common/jsonpath.py | 8 +-
 dlt/common/libs/pyarrow.py | 8 +-
 dlt/common/normalizers/__init__.py | 2 +-
 dlt/common/normalizers/configuration.py | 3 +-
 dlt/common/normalizers/exceptions.py | 5 +-
 dlt/common/normalizers/json/__init__.py | 11 +-
 dlt/common/normalizers/json/relational.py | 82 ++-
 dlt/common/normalizers/naming/__init__.py | 3 +-
 dlt/common/normalizers/naming/direct.py | 2 +-
 dlt/common/normalizers/naming/duck_case.py | 5 +-
 dlt/common/normalizers/naming/exceptions.py | 6 +-
 dlt/common/normalizers/naming/naming.py | 35 +-
 dlt/common/normalizers/naming/snake_case.py | 13 +-
 dlt/common/normalizers/typing.py | 2 +-
 dlt/common/normalizers/utils.py | 32 +-
 dlt/common/pendulum.py | 3 +-
 dlt/common/pipeline.py | 118 ++--
 dlt/common/reflection/function_visitor.py | 1 +
 dlt/common/reflection/spec.py | 25 +-
 dlt/common/reflection/utils.py | 19 +-
 dlt/common/runners/configuration.py | 9 +-
 dlt/common/runners/pool_runner.py | 23 +-
 dlt/common/runners/runnable.py | 6 +-
 dlt/common/runners/stdout.py | 16 +-
 dlt/common/runners/synth_pickle.py | 9 +-
 dlt/common/runners/venv.py | 17 +-
 dlt/common/runtime/__init__.py | 2 +-
 dlt/common/runtime/collector.py | 109 +++-
 dlt/common/runtime/exec_info.py | 21 +-
 dlt/common/runtime/logger.py | 23 +-
 dlt/common/runtime/prometheus.py | 6 +-
 dlt/common/runtime/segment.py | 51 +-
 dlt/common/runtime/sentry.py | 23 +-
 dlt/common/runtime/signals.py | 4 +-
 dlt/common/runtime/slack.py | 12 +-
 dlt/common/runtime/telemetry.py | 19 +-
 dlt/common/schema/__init__.py | 13 +-
 dlt/common/schema/detections.py | 3 +-
 dlt/common/schema/exceptions.py | 52 +-
 dlt/common/schema/schema.py | 200 ++++--
 dlt/common/schema/typing.py | 64 +-
 dlt/common/schema/utils.py | 281 +++++----
 dlt/common/source.py | 3 +-
 dlt/common/storages/__init__.py | 15 +-
 dlt/common/storages/configuration.py | 26 +-
 dlt/common/storages/data_item_storage.py | 22 +-
 dlt/common/storages/exceptions.py | 57 +-
 dlt/common/storages/file_storage.py | 54 +-
 dlt/common/storages/live_schema_storage.py | 9 +-
 dlt/common/storages/load_storage.py | 217 +++++--
 dlt/common/storages/normalize_storage.py | 22 +-
 dlt/common/storages/schema_storage.py | 63 +-
 dlt/common/storages/transactional_file.py | 23 +-
 dlt/common/storages/versioned_storage.py | 23 +-
 dlt/common/time.py | 26 +-
 dlt/common/typing.py | 51 +-
 dlt/common/utils.py | 82 ++-
 dlt/common/validation.py | 81 ++-
 dlt/common/wei.py | 15 +-
 dlt/destinations/athena/__init__.py | 28 +-
 dlt/destinations/athena/athena.py | 128 ++--
 dlt/destinations/athena/configuration.py | 4 +-
 dlt/destinations/bigquery/__init__.py | 25 +-
 dlt/destinations/bigquery/bigquery.py | 148 +++--
 dlt/destinations/bigquery/configuration.py | 16 +-
 dlt/destinations/bigquery/sql_client.py | 89 ++-
 dlt/destinations/duckdb/__init__.py | 23 +-
 dlt/destinations/duckdb/configuration.py | 15 +-
 dlt/destinations/duckdb/duck.py | 43 +-
 dlt/destinations/duckdb/sql_client.py | 50 +-
 dlt/destinations/dummy/__init__.py | 19 +-
 dlt/destinations/dummy/configuration.py | 6 +-
 dlt/destinations/dummy/dummy.py | 46 +-
 dlt/destinations/exceptions.py | 60 +-
 dlt/destinations/filesystem/__init__.py | 26 +-
 dlt/destinations/filesystem/configuration.py | 27 +-
 dlt/destinations/filesystem/filesystem.py | 96 ++-
 .../filesystem/filesystem_client.py | 27 +-
 dlt/destinations/insert_job_client.py | 15 +-
 dlt/destinations/job_client_impl.py | 197 ++--
 dlt/destinations/job_impl.py | 14 +-
 dlt/destinations/motherduck/__init__.py | 27 +-
 dlt/destinations/motherduck/configuration.py | 12 +-
 dlt/destinations/motherduck/motherduck.py | 10 +-
 dlt/destinations/motherduck/sql_client.py | 33 +-
 dlt/destinations/path_utils.py | 42 +-
 dlt/destinations/postgres/__init__.py | 23 +-
 dlt/destinations/postgres/configuration.py | 8 +-
 dlt/destinations/postgres/postgres.py | 55 +-
 dlt/destinations/postgres/sql_client.py | 58 +-
 dlt/destinations/redshift/__init__.py | 21 +-
 dlt/destinations/redshift/configuration.py | 5 +-
 dlt/destinations/redshift/redshift.py | 119 ++--
 dlt/destinations/snowflake/__init__.py | 24 +-
 dlt/destinations/snowflake/configuration.py | 58 +-
 dlt/destinations/snowflake/snowflake.py | 114 ++--
 dlt/destinations/snowflake/sql_client.py | 40 +-
 dlt/destinations/sql_client.py | 58 +-
 dlt/destinations/sql_jobs.py | 150 ++++-
 dlt/destinations/typing.py | 10 +-
 dlt/destinations/weaviate/__init__.py | 12 +-
 dlt/destinations/weaviate/configuration.py | 18 +-
 dlt/destinations/weaviate/exceptions.py | 2 +-
 dlt/destinations/weaviate/naming.py | 11 +-
 dlt/destinations/weaviate/weaviate_adapter.py | 10 +-
 dlt/destinations/weaviate/weaviate_client.py | 93 +--
 dlt/extract/decorators.py | 135 ++--
 dlt/extract/exceptions.py | 208 ++--
 dlt/extract/extract.py | 42 +-
 dlt/extract/incremental.py | 205 ++--
 dlt/extract/pipe.py | 194 ++--
 dlt/extract/schema.py | 75 ++-
 dlt/extract/source.py | 253 ++++--
 dlt/extract/typing.py | 20 +-
 dlt/extract/utils.py | 8 +-
 dlt/helpers/airflow_helper.py | 75 ++-
 dlt/helpers/dbt/__init__.py | 35 +-
 dlt/helpers/dbt/configuration.py | 6 +-
 dlt/helpers/dbt/dbt_utils.py | 57 +-
 dlt/helpers/dbt/exceptions.py | 6 +-
 dlt/helpers/dbt/runner.py | 102 ++-
 dlt/helpers/pandas_helper.py | 11 +-
 dlt/helpers/streamlit_helper.py | 68 +-
 dlt/load/configuration.py | 5 +-
 dlt/load/exceptions.py | 26 +-
 dlt/load/load.py | 266 ++++--
 dlt/normalize/__init__.py | 2 +-
 dlt/normalize/configuration.py | 9 +-
 dlt/normalize/normalize.py | 165 +++--
 dlt/pipeline/__init__.py | 43 +-
 dlt/pipeline/configuration.py | 4 +-
 dlt/pipeline/current.py | 5 +-
 dlt/pipeline/dbt.py | 36 +-
 dlt/pipeline/exceptions.py | 63 +-
 dlt/pipeline/helpers.py | 84 ++-
 dlt/pipeline/mark.py | 2 +-
 dlt/pipeline/pipeline.py | 552 +++++++++----
 dlt/pipeline/progress.py | 11 +-
 dlt/pipeline/state_sync.py | 68 +-
 dlt/pipeline/trace.py | 72 ++-
 dlt/pipeline/track.py | 28 +-
 dlt/reflection/names.py | 6 +-
 dlt/reflection/script_inspector.py | 49 +-
 dlt/reflection/script_visitor.py | 15 +-
 dlt/sources/__init__.py | 2 +-
 dlt/sources/credentials.py | 13 +-
 dlt/sources/helpers/requests/__init__.py | 17 +-
 dlt/sources/helpers/requests/retry.py | 68 +-
 dlt/sources/helpers/requests/session.py | 21 +-
 dlt/sources/helpers/requests/typing.py | 2 +-
 dlt/sources/helpers/transform.py | 4 +
 dlt/version.py | 5 +-
 docs/examples/_helpers.py | 11 +-
 docs/examples/chess/chess.py | 22 +-
 docs/examples/chess/chess_dbt.py | 2 +-
 docs/examples/credentials/explicit.py | 15 +-
 docs/examples/dbt_run_jaffle.py | 13 +-
 docs/examples/discord_iterator.py | 1 -
 docs/examples/google_sheets.py | 5 +-
 docs/examples/quickstart.py | 39 +-
 docs/examples/rasa_example.py | 15 +-
 docs/examples/read_table.py | 14 +-
 docs/examples/restore_pipeline.py | 2 +-
 docs/examples/singer_tap_example.py | 17 +-
 docs/examples/singer_tap_jsonl_example.py | 12 +-
 docs/examples/sources/google_sheets.py | 48 +-
 docs/examples/sources/jsonl.py | 11 +-
 docs/examples/sources/rasa/__init__.py | 2 +-
 docs/examples/sources/rasa/rasa.py | 10 +-
 docs/examples/sources/singer_tap.py | 37 +-
 docs/examples/sources/sql_query.py | 55 +-
 docs/examples/sources/stdout.py | 1 -
 docs/snippets/conftest.py | 4 +-
 docs/snippets/intro_snippet.py | 14 +-
 docs/snippets/intro_snippet_test.py | 4 +-
 docs/snippets/utils.py | 9 +-
 tests/cases.py | 166 ++---
 .../cases/deploy_pipeline/debug_pipeline.py | 13 +-
 tests/cli/common/test_cli_invoke.py | 84 ++-
 tests/cli/common/test_telemetry_command.py | 45 +-
 tests/cli/conftest.py | 2 +-
 tests/cli/test_config_toml_writer.py | 62 +-
 tests/cli/test_deploy_command.py | 106 +++-
 tests/cli/test_init_command.py | 183 ++--
 tests/cli/test_pipeline_command.py | 36 +-
 tests/cli/utils.py | 22 +-
 tests/common/cases/modules/uniq_mod_121.py | 3 +
 tests/common/configuration/test_accessors.py | 108 +++-
 .../configuration/test_configuration.py | 438 +++++++++----
 tests/common/configuration/test_container.py | 17 +-
 .../common/configuration/test_credentials.py | 56 +-
 .../configuration/test_environ_provider.py | 45 +-
 tests/common/configuration/test_inject.py | 66 +-
 tests/common/configuration/test_providers.py | 1 +
 tests/common/configuration/test_sections.py | 69 +-
 tests/common/configuration/test_spec_union.py | 84 ++-
 .../configuration/test_toml_provider.py | 158 +++--
 tests/common/configuration/utils.py | 55 +-
 .../common/normalizers/custom_normalizers.py | 7 +-
 .../normalizers/test_import_normalizers.py | 29 +-
 .../normalizers/test_json_relational.py | 406 ++++++------
 tests/common/normalizers/test_naming.py | 109 ++--
 .../normalizers/test_naming_duck_case.py | 14 +-
 .../normalizers/test_naming_snake_case.py | 10 +-
 tests/common/reflection/test_reflect_spec.py | 174 ++++-
 tests/common/runners/test_pipes.py | 41 +-
 tests/common/runners/test_runnable.py | 15 +-
 tests/common/runners/test_runners.py | 31 +-
 tests/common/runners/test_venv.py | 14 +-
 tests/common/runners/utils.py | 19 +-
 tests/common/runtime/test_collector.py | 5 +-
 tests/common/runtime/test_logging.py | 43 +-
 tests/common/runtime/test_signals.py | 7 +-
 tests/common/runtime/test_telemetry.py | 31 +-
 tests/common/runtime/utils.py | 1 +
 tests/common/schema/test_coercion.py | 100 ++-
 tests/common/schema/test_detections.py | 25 +-
 tests/common/schema/test_filtering.py | 53 +-
 tests/common/schema/test_inference.py | 77 ++-
 tests/common/schema/test_schema.py | 307 ++++---
 tests/common/schema/test_versioning.py | 3 +-
 tests/common/scripts/args.py | 2 +-
 tests/common/scripts/counter.py | 3 +-
 tests/common/scripts/cwd.py | 2 +-
 tests/common/scripts/long_lines.py | 2 +-
 tests/common/scripts/long_lines_fails.py | 2 +-
 tests/common/scripts/no_stdout_exception.py | 2 +-
 .../scripts/no_stdout_no_stderr_with_fail.py | 2 +-
 tests/common/scripts/raising_counter.py | 3 +-
 .../common/scripts/stdout_encode_exception.py | 4 +-
 tests/common/scripts/stdout_encode_result.py | 1 +
 tests/common/storages/test_file_storage.py | 45 +-
 tests/common/storages/test_loader_storage.py | 63 +-
 .../common/storages/test_normalize_storage.py | 18 +-
 tests/common/storages/test_schema_storage.py | 80 ++-
 .../storages/test_transactional_file.py | 13 +-
 .../common/storages/test_versioned_storage.py | 11 +-
 tests/common/test_arithmetics.py | 6 +-
 .../test_data_writers/test_buffered_writer.py | 51 +-
 .../test_data_writers/test_data_writers.py | 61 +-
 .../test_data_writers/test_parquet_writer.py | 84 ++-
 tests/common/test_destination.py | 74 ++-
 tests/common/test_git.py | 28 +-
 tests/common/test_json.py | 44 +-
 tests/common/test_pipeline_state.py | 8 +-
 tests/common/test_time.py | 18 +-
 tests/common/test_typing.py | 38 +-
 tests/common/test_utils.py | 89 ++-
 tests/common/test_validation.py | 43 +-
 tests/common/test_version.py | 3 +-
 tests/common/test_wei.py | 25 +-
 tests/common/utils.py | 35 +-
 tests/conftest.py | 68 +-
 tests/destinations/test_path_utils.py | 32 +-
 tests/extract/cases/eth_source/source.py | 2 +
 .../section_source/external_resources.py | 20 +-
 .../cases/section_source/named_module.py | 1 +
 tests/extract/conftest.py | 8 +-
 tests/extract/test_decorators.py | 143 +++--
 tests/extract/test_extract.py | 16 +-
 tests/extract/test_extract_pipe.py | 108 ++--
 tests/extract/test_incremental.py | 592 +++++++------
 tests/extract/test_sources.py | 165 +++--
 tests/extract/utils.py | 23 +-
 tests/helpers/airflow_tests/conftest.py | 2 +-
 .../airflow_tests/test_airflow_provider.py | 92 ++-
 .../airflow_tests/test_airflow_wrapper.py | 383 +++++----
 .../test_join_airflow_scheduler.py | 169 +++--
 tests/helpers/airflow_tests/utils.py | 11 +-
 .../helpers/dbt_tests/local/test_dbt_utils.py | 79 ++-
 .../local/test_runner_destinations.py | 91 ++-
 tests/helpers/dbt_tests/local/utils.py | 24 +-
 .../dbt_tests/test_runner_dbt_versions.py | 158 +++--
 tests/helpers/dbt_tests/utils.py | 33 +-
 .../providers/test_google_secrets_provider.py | 72 ++-
 tests/load/bigquery/test_bigquery_client.py | 132 ++--
 .../bigquery/test_bigquery_table_builder.py | 16 +-
 tests/load/cases/fake_destination.py | 2 +-
 tests/load/conftest.py | 10 +-
 tests/load/duckdb/test_duckdb_client.py | 44 +-
 .../load/duckdb/test_duckdb_table_builder.py | 20 +-
 tests/load/duckdb/test_motherduck_client.py | 15 +-
 tests/load/filesystem/test_aws_credentials.py | 34 +-
 .../load/filesystem/test_filesystem_client.py | 99 +--
 tests/load/filesystem/utils.py | 20 +-
 tests/load/pipeline/conftest.py | 9 +-
 tests/load/pipeline/test_athena.py | 119 ++--
 tests/load/pipeline/test_dbt_helper.py | 71 ++-
 tests/load/pipeline/test_drop.py | 175 ++--
 .../load/pipeline/test_filesystem_pipeline.py | 61 +-
 tests/load/pipeline/test_merge_disposition.py | 206 ++--
 tests/load/pipeline/test_pipelines.py | 368 +++++----
 .../load/pipeline/test_replace_disposition.py | 225 ++---
 tests/load/pipeline/test_restore_state.py | 196 ++--
 tests/load/pipeline/test_stage_loading.py | 139 +++-
 tests/load/pipeline/utils.py | 77 ++-
 tests/load/postgres/test_postgres_client.py | 65 +-
 .../postgres/test_postgres_table_builder.py | 26 +-
 tests/load/redshift/test_redshift_client.py | 44 +-
 .../redshift/test_redshift_table_builder.py | 37 +-
 .../snowflake/test_snowflake_configuration.py | 68 +-
 .../snowflake/test_snowflake_table_builder.py | 18 +-
 tests/load/test_dummy_client.py | 218 ++---
 tests/load/test_insert_job_client.py | 149 +++--
 tests/load/test_job_client.py | 297 ++++---
 tests/load/test_sql_client.py | 258 ++++--
 tests/load/utils.py | 294 ++++---
 tests/load/weaviate/test_naming.py | 7 +-
 tests/load/weaviate/test_pipeline.py | 43 +-
 tests/load/weaviate/test_weaviate_client.py | 42 +-
 tests/load/weaviate/utils.py | 7 +-
 tests/normalize/mock_rasa_json_normalizer.py | 17 +-
 tests/normalize/test_normalize.py | 292 ++++---
 tests/normalize/utils.py | 9 +-
 .../cases/github_pipeline/github_extract.py | 8 +-
 .../cases/github_pipeline/github_pipeline.py | 18 +-
 tests/pipeline/conftest.py | 10 +-
 tests/pipeline/test_dlt_versions.py | 78 ++-
 tests/pipeline/test_pipeline.py | 241 ++---
 .../test_pipeline_file_format_resolver.py | 18 +-
 tests/pipeline/test_pipeline_state.py | 146 +++--
 tests/pipeline/test_pipeline_trace.py | 132 ++--
 tests/pipeline/utils.py | 9 +-
 tests/reflection/module_cases/__init__.py | 4 +-
 tests/reflection/module_cases/all_imports.py | 2 +-
 .../module_cases/dlt_import_exception.py | 1 -
 .../module_cases/executes_resource.py | 3 +-
 .../reflection/module_cases/import_as_type.py | 2 +
 tests/reflection/module_cases/no_pkg.py | 2 +-
 tests/reflection/module_cases/raises.py | 3 +-
 .../module_cases/stripe_analytics/__init__.py | 2 +-
 .../stripe_analytics/stripe_analytics.py | 2 +-
 .../module_cases/stripe_analytics_pipeline.py | 4 +-
 tests/reflection/test_script_inspector.py | 18 +-
 tests/sources/helpers/test_requests.py | 122 ++--
 tests/tools/clean_redshift.py | 5 +-
 tests/tools/create_storages.py | 10 +-
 tests/utils.py | 42 +-
 396 files changed, 14979 insertions(+), 7186 deletions(-)

diff --git a/dlt/__init__.py b/dlt/__init__.py
index c0be967ce3..58a0504535 100644
--- a/dlt/__init__.py
+++ b/dlt/__init__.py
@@ -21,17 +21,20 @@
 For more detailed info, see https://dlthub.com/docs/walkthroughs
 """
-from dlt.version import __version__
+from dlt
import sources from dlt.common.configuration.accessors import config, secrets -from dlt.common.typing import TSecretValue as _TSecretValue from dlt.common.configuration.specs import CredentialsConfiguration as _CredentialsConfiguration from dlt.common.pipeline import source_state as state from dlt.common.schema import Schema - -from dlt import sources -from dlt.extract.decorators import source, resource, transformer, defer -from dlt.pipeline import pipeline as _pipeline, run, attach, Pipeline, dbt, current as _current, mark as _mark -from dlt.pipeline import progress +from dlt.common.typing import TSecretValue as _TSecretValue +from dlt.extract.decorators import defer, resource, source, transformer +from dlt.pipeline import Pipeline, attach +from dlt.pipeline import current as _current +from dlt.pipeline import dbt +from dlt.pipeline import mark as _mark +from dlt.pipeline import pipeline as _pipeline +from dlt.pipeline import progress, run +from dlt.version import __version__ pipeline = _pipeline current = _current diff --git a/dlt/cli/_dlt.py b/dlt/cli/_dlt.py index f719c30de0..b79997fa97 100644 --- a/dlt/cli/_dlt.py +++ b/dlt/cli/_dlt.py @@ -1,32 +1,52 @@ -from typing import Any, Sequence, Optional -import yaml -import os import argparse +import os +from typing import Any, Optional, Sequence + import click +import yaml -from dlt.version import __version__ +import dlt.cli.echo as fmt +from dlt.cli import utils +from dlt.cli.init_command import ( + DEFAULT_VERIFIED_SOURCES_REPO, + DLT_INIT_DOCS_URL, + init_command, + list_verified_sources_command, +) +from dlt.cli.pipeline_command import DLT_PIPELINE_COMMAND_DOCS_URL, pipeline_command +from dlt.cli.telemetry_command import ( + DLT_TELEMETRY_DOCS_URL, + change_telemetry_status_command, + telemetry_status_command, +) from dlt.common import json +from dlt.common.runners import Venv from dlt.common.schema import Schema from dlt.common.typing import DictStrAny -from dlt.common.runners import Venv - -import dlt.cli.echo as fmt -from dlt.cli import utils from dlt.pipeline.exceptions import CannotRestorePipelineException - -from dlt.cli.init_command import init_command, list_verified_sources_command, DLT_INIT_DOCS_URL, DEFAULT_VERIFIED_SOURCES_REPO -from dlt.cli.pipeline_command import pipeline_command, DLT_PIPELINE_COMMAND_DOCS_URL -from dlt.cli.telemetry_command import DLT_TELEMETRY_DOCS_URL, change_telemetry_status_command, telemetry_status_command +from dlt.version import __version__ try: from dlt.cli import deploy_command - from dlt.cli.deploy_command import PipelineWasNotRun, DLT_DEPLOY_DOCS_URL, DeploymentMethods, COMMAND_DEPLOY_REPO_LOCATION, SecretFormats + from dlt.cli.deploy_command import ( + COMMAND_DEPLOY_REPO_LOCATION, + DLT_DEPLOY_DOCS_URL, + DeploymentMethods, + PipelineWasNotRun, + SecretFormats, + ) except ModuleNotFoundError: pass @utils.track_command("init", False, "source_name", "destination_name") -def init_command_wrapper(source_name: str, destination_name: str, use_generic_template: bool, repo_location: str, branch: str) -> int: +def init_command_wrapper( + source_name: str, + destination_name: str, + use_generic_template: bool, + repo_location: str, + branch: str, +) -> int: try: init_command(source_name, destination_name, use_generic_template, repo_location, branch) except Exception as ex: @@ -48,7 +68,12 @@ def list_verified_sources_command_wrapper(repo_location: str, branch: str) -> in @utils.track_command("deploy", False, "deployment_method") -def deploy_command_wrapper(pipeline_script_path: str, deployment_method: 
str, repo_location: str, branch: Optional[str] = None, **kwargs: Any +def deploy_command_wrapper( + pipeline_script_path: str, + deployment_method: str, + repo_location: str, + branch: Optional[str] = None, + **kwargs: Any, ) -> int: try: utils.ensure_git_command("deploy") @@ -57,36 +82,42 @@ def deploy_command_wrapper(pipeline_script_path: str, deployment_method: str, re return -1 from git import InvalidGitRepositoryError, NoSuchPathError + try: deploy_command.deploy_command( pipeline_script_path=pipeline_script_path, deployment_method=deployment_method, repo_location=repo_location, branch=branch, - **kwargs + **kwargs, ) except (CannotRestorePipelineException, PipelineWasNotRun) as ex: click.secho(str(ex), err=True, fg="red") - fmt.note("You must run the pipeline locally successfully at least once in order to deploy it.") + fmt.note( + "You must run the pipeline locally successfully at least once in order to deploy it." + ) fmt.note("Please refer to %s for further assistance" % fmt.bold(DLT_DEPLOY_DOCS_URL)) return -2 except InvalidGitRepositoryError: click.secho( "No git repository found for pipeline script %s." % fmt.bold(pipeline_script_path), err=True, - fg="red" + fg="red", ) fmt.note("If you do not have a repository yet, you can do either of:") - fmt.note("- Run the following command to initialize new repository: %s" % fmt.bold("git init")) - fmt.note("- Add your local code to Github as described here: %s" % fmt.bold("https://docs.github.com/en/get-started/importing-your-projects-to-github/importing-source-code-to-github/adding-locally-hosted-code-to-github")) + fmt.note( + "- Run the following command to initialize new repository: %s" % fmt.bold("git init") + ) + fmt.note( + "- Add your local code to Github as described here: %s" + % fmt.bold( + "https://docs.github.com/en/get-started/importing-your-projects-to-github/importing-source-code-to-github/adding-locally-hosted-code-to-github" + ) + ) fmt.note("Please refer to %s for further assistance" % fmt.bold(DLT_DEPLOY_DOCS_URL)) return -3 except NoSuchPathError as path_ex: - click.secho( - "The pipeline script does not exist\n%s" % str(path_ex), - err=True, - fg="red" - ) + click.secho("The pipeline script does not exist\n%s" % str(path_ex), err=True, fg="red") return -4 except Exception as ex: click.secho(str(ex), err=True, fg="red") @@ -98,14 +129,17 @@ def deploy_command_wrapper(pipeline_script_path: str, deployment_method: str, re @utils.track_command("pipeline", True, "operation") def pipeline_command_wrapper( - operation: str, pipeline_name: str, pipelines_dir: str, verbosity: int, **command_kwargs: Any + operation: str, pipeline_name: str, pipelines_dir: str, verbosity: int, **command_kwargs: Any ) -> int: try: pipeline_command(operation, pipeline_name, pipelines_dir, verbosity, **command_kwargs) return 0 except CannotRestorePipelineException as ex: click.secho(str(ex), err=True, fg="red") - click.secho("Try command %s to restore the pipeline state from destination" % fmt.bold(f"dlt pipeline {pipeline_name} sync")) + click.secho( + "Try command %s to restore the pipeline state from destination" + % fmt.bold(f"dlt pipeline {pipeline_name} sync") + ) return 1 except Exception as ex: click.secho(str(ex), err=True, fg="red") @@ -152,21 +186,31 @@ def telemetry_change_status_command_wrapper(enabled: bool) -> int: ACTION_EXECUTED = False + def print_help(parser: argparse.ArgumentParser) -> None: if not ACTION_EXECUTED: parser.print_help() class TelemetryAction(argparse.Action): - def __init__(self, option_strings: Sequence[str], 
dest: Any = argparse.SUPPRESS, default: Any = argparse.SUPPRESS, help: str = None) -> None: # noqa + def __init__( + self, + option_strings: Sequence[str], + dest: Any = argparse.SUPPRESS, + default: Any = argparse.SUPPRESS, + help: str = None, + ) -> None: # noqa super(TelemetryAction, self).__init__( - option_strings=option_strings, - dest=dest, - default=default, - nargs=0, - help=help + option_strings=option_strings, dest=dest, default=default, nargs=0, help=help ) - def __call__(self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str = None) -> None: + + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Any, + option_string: str = None, + ) -> None: global ACTION_EXECUTED ACTION_EXECUTED = True @@ -174,114 +218,315 @@ def __call__(self, parser: argparse.ArgumentParser, namespace: argparse.Namespac class NonInteractiveAction(argparse.Action): - def __init__(self, option_strings: Sequence[str], dest: Any = argparse.SUPPRESS, default: Any = argparse.SUPPRESS, help: str = None) -> None: # noqa + def __init__( + self, + option_strings: Sequence[str], + dest: Any = argparse.SUPPRESS, + default: Any = argparse.SUPPRESS, + help: str = None, + ) -> None: # noqa super(NonInteractiveAction, self).__init__( - option_strings=option_strings, - dest=dest, - default=default, - nargs=0, - help=help + option_strings=option_strings, dest=dest, default=default, nargs=0, help=help ) - def __call__(self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str = None) -> None: + + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Any, + option_string: str = None, + ) -> None: fmt.ALWAYS_CHOOSE_DEFAULT = True def main() -> int: - parser = argparse.ArgumentParser(description="Creates, adds, inspects and deploys dlt pipelines.", formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--version', action="version", version='%(prog)s {version}'.format(version=__version__)) - parser.add_argument('--disable-telemetry', action=TelemetryAction, help="Disables telemetry before command is executed") - parser.add_argument('--enable-telemetry', action=TelemetryAction, help="Enables telemetry before command is executed") - parser.add_argument('--non-interactive', action=NonInteractiveAction, help="Non interactive mode. Default choices are automatically made for confirmations and prompts.") + parser = argparse.ArgumentParser( + description="Creates, adds, inspects and deploys dlt pipelines.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--version", action="version", version="%(prog)s {version}".format(version=__version__) + ) + parser.add_argument( + "--disable-telemetry", + action=TelemetryAction, + help="Disables telemetry before command is executed", + ) + parser.add_argument( + "--enable-telemetry", + action=TelemetryAction, + help="Enables telemetry before command is executed", + ) + parser.add_argument( + "--non-interactive", + action=NonInteractiveAction, + help=( + "Non interactive mode. Default choices are automatically made for confirmations and" + " prompts." 
+ ), + ) subparsers = parser.add_subparsers(dest="command") - init_cmd = subparsers.add_parser("init", help="Creates a pipeline project in the current folder by adding existing verified source or creating a new one from template.") - init_cmd.add_argument("--list-verified-sources", "-l", default=False, action="store_true", help="List available verified sources") - init_cmd.add_argument("source", nargs='?', help="Name of data source for which to create a pipeline. Adds existing verified source or creates a new pipeline template if verified source for your data source is not yet implemented.") - init_cmd.add_argument("destination", nargs='?', help="Name of a destination ie. bigquery or redshift") - init_cmd.add_argument("--location", default=DEFAULT_VERIFIED_SOURCES_REPO, help="Advanced. Uses a specific url or local path to verified sources repository.") - init_cmd.add_argument("--branch", default=None, help="Advanced. Uses specific branch of the init repository to fetch the template.") - init_cmd.add_argument("--generic", default=False, action="store_true", help="When present uses a generic template with all the dlt loading code present will be used. Otherwise a debug template is used that can be immediately run to get familiar with the dlt sources.") + init_cmd = subparsers.add_parser( + "init", + help=( + "Creates a pipeline project in the current folder by adding existing verified source or" + " creating a new one from template." + ), + ) + init_cmd.add_argument( + "--list-verified-sources", + "-l", + default=False, + action="store_true", + help="List available verified sources", + ) + init_cmd.add_argument( + "source", + nargs="?", + help=( + "Name of data source for which to create a pipeline. Adds existing verified source or" + " creates a new pipeline template if verified source for your data source is not yet" + " implemented." + ), + ) + init_cmd.add_argument( + "destination", nargs="?", help="Name of a destination ie. bigquery or redshift" + ) + init_cmd.add_argument( + "--location", + default=DEFAULT_VERIFIED_SOURCES_REPO, + help="Advanced. Uses a specific url or local path to verified sources repository.", + ) + init_cmd.add_argument( + "--branch", + default=None, + help="Advanced. Uses specific branch of the init repository to fetch the template.", + ) + init_cmd.add_argument( + "--generic", + default=False, + action="store_true", + help=( + "When present uses a generic template with all the dlt loading code present will be" + " used. Otherwise a debug template is used that can be immediately run to get familiar" + " with the dlt sources." + ), + ) # deploy command requires additional dependencies try: # make sure the name is defined _ = deploy_command - deploy_comm = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, add_help=False) - deploy_comm.add_argument("--location", default=COMMAND_DEPLOY_REPO_LOCATION, help="Advanced. Uses a specific url or local path to pipelines repository.") - deploy_comm.add_argument("--branch", help="Advanced. Uses specific branch of the deploy repository to fetch the template.") + deploy_comm = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, add_help=False + ) + deploy_comm.add_argument( + "--location", + default=COMMAND_DEPLOY_REPO_LOCATION, + help="Advanced. Uses a specific url or local path to pipelines repository.", + ) + deploy_comm.add_argument( + "--branch", + help="Advanced. 
Uses specific branch of the deploy repository to fetch the template.", + ) - deploy_cmd = subparsers.add_parser("deploy", help="Creates a deployment package for a selected pipeline script") - deploy_cmd.add_argument("pipeline_script_path", metavar="pipeline-script-path", help="Path to a pipeline script") + deploy_cmd = subparsers.add_parser( + "deploy", help="Creates a deployment package for a selected pipeline script" + ) + deploy_cmd.add_argument( + "pipeline_script_path", metavar="pipeline-script-path", help="Path to a pipeline script" + ) deploy_sub_parsers = deploy_cmd.add_subparsers(dest="deployment_method") # deploy github actions - deploy_github_cmd = deploy_sub_parsers.add_parser(DeploymentMethods.github_actions.value, help="Deploys the pipeline to Github Actions", parents=[deploy_comm]) - deploy_github_cmd.add_argument("--schedule", required=True, help="A schedule with which to run the pipeline, in cron format. Example: '*/30 * * * *' will run the pipeline every 30 minutes. Remember to enclose the scheduler expression in quotation marks!") - deploy_github_cmd.add_argument("--run-manually", default=True, action="store_true", help="Allows the pipeline to be run manually form Github Actions UI.") - deploy_github_cmd.add_argument("--run-on-push", default=False, action="store_true", help="Runs the pipeline with every push to the repository.") + deploy_github_cmd = deploy_sub_parsers.add_parser( + DeploymentMethods.github_actions.value, + help="Deploys the pipeline to Github Actions", + parents=[deploy_comm], + ) + deploy_github_cmd.add_argument( + "--schedule", + required=True, + help=( + "A schedule with which to run the pipeline, in cron format. Example: '*/30 * * * *'" + " will run the pipeline every 30 minutes. Remember to enclose the scheduler" + " expression in quotation marks!" 
+ ), + ) + deploy_github_cmd.add_argument( + "--run-manually", + default=True, + action="store_true", + help="Allows the pipeline to be run manually form Github Actions UI.", + ) + deploy_github_cmd.add_argument( + "--run-on-push", + default=False, + action="store_true", + help="Runs the pipeline with every push to the repository.", + ) # deploy airflow composer - deploy_airflow_cmd = deploy_sub_parsers.add_parser(DeploymentMethods.airflow_composer.value, help="Deploys the pipeline to Airflow", parents=[deploy_comm]) - deploy_airflow_cmd.add_argument("--secrets-format", default=SecretFormats.toml.value, choices=[v.value for v in SecretFormats], required=False, help="Format of the secrets") + deploy_airflow_cmd = deploy_sub_parsers.add_parser( + DeploymentMethods.airflow_composer.value, + help="Deploys the pipeline to Airflow", + parents=[deploy_comm], + ) + deploy_airflow_cmd.add_argument( + "--secrets-format", + default=SecretFormats.toml.value, + choices=[v.value for v in SecretFormats], + required=False, + help="Format of the secrets", + ) except NameError: # create placeholder command - deploy_cmd = subparsers.add_parser("deploy", help='Install additional dependencies with pip install "dlt[cli]" to create deployment packages', add_help=False) + deploy_cmd = subparsers.add_parser( + "deploy", + help=( + 'Install additional dependencies with pip install "dlt[cli]" to create deployment' + " packages" + ), + add_help=False, + ) deploy_cmd.add_argument("--help", "-h", nargs="?", const=True) - deploy_cmd.add_argument("pipeline_script_path", metavar="pipeline-script-path", nargs=argparse.REMAINDER) + deploy_cmd.add_argument( + "pipeline_script_path", metavar="pipeline-script-path", nargs=argparse.REMAINDER + ) schema = subparsers.add_parser("schema", help="Shows, converts and upgrades schemas") - schema.add_argument("file", help="Schema file name, in yaml or json format, will autodetect based on extension") - schema.add_argument("--format", choices=["json", "yaml"], default="yaml", help="Display schema in this format") - schema.add_argument("--remove-defaults", action="store_true", help="Does not show default hint values") + schema.add_argument( + "file", help="Schema file name, in yaml or json format, will autodetect based on extension" + ) + schema.add_argument( + "--format", choices=["json", "yaml"], default="yaml", help="Display schema in this format" + ) + schema.add_argument( + "--remove-defaults", action="store_true", help="Does not show default hint values" + ) - pipe_cmd = subparsers.add_parser("pipeline", help="Operations on pipelines that were ran locally") - pipe_cmd.add_argument("--list-pipelines", "-l", default=False, action="store_true", help="List local pipelines") - pipe_cmd.add_argument("pipeline_name", nargs='?', help="Pipeline name") + pipe_cmd = subparsers.add_parser( + "pipeline", help="Operations on pipelines that were ran locally" + ) + pipe_cmd.add_argument( + "--list-pipelines", "-l", default=False, action="store_true", help="List local pipelines" + ) + pipe_cmd.add_argument("pipeline_name", nargs="?", help="Pipeline name") pipe_cmd.add_argument("--pipelines-dir", help="Pipelines working directory", default=None) - pipe_cmd.add_argument("--verbose", "-v", action='count', default=0, help="Provides more information for certain commands.", dest="verbosity") + pipe_cmd.add_argument( + "--verbose", + "-v", + action="count", + default=0, + help="Provides more information for certain commands.", + dest="verbosity", + ) # pipe_cmd.add_argument("--dataset-name", help="Dataset 
name used to sync destination when local pipeline state is missing.") # pipe_cmd.add_argument("--destination", help="Destination name used to sync when local pipeline state is missing.") pipeline_subparsers = pipe_cmd.add_subparsers(dest="operation", required=False) pipe_cmd_sync_parent = argparse.ArgumentParser(add_help=False) - pipe_cmd_sync_parent.add_argument("--destination", help="Sync from this destination when local pipeline state is missing.") - pipe_cmd_sync_parent.add_argument("--dataset-name", help="Dataset name to sync from when local pipeline state is missing.") + pipe_cmd_sync_parent.add_argument( + "--destination", help="Sync from this destination when local pipeline state is missing." + ) + pipe_cmd_sync_parent.add_argument( + "--dataset-name", help="Dataset name to sync from when local pipeline state is missing." + ) - pipeline_subparsers.add_parser("info", help="Displays state of the pipeline, use -v or -vv for more info") - pipeline_subparsers.add_parser("show", help="Generates and launches Streamlit app with the loading status and dataset explorer") - pipeline_subparsers.add_parser("failed-jobs", help="Displays information on all the failed loads in all completed packages, failed jobs and associated error messages") + pipeline_subparsers.add_parser( + "info", help="Displays state of the pipeline, use -v or -vv for more info" + ) + pipeline_subparsers.add_parser( + "show", + help="Generates and launches Streamlit app with the loading status and dataset explorer", + ) + pipeline_subparsers.add_parser( + "failed-jobs", + help=( + "Displays information on all the failed loads in all completed packages, failed jobs" + " and associated error messages" + ), + ) pipeline_subparsers.add_parser( "sync", - help="Drops the local state of the pipeline and resets all the schemas and restores it from destination. The destination state, data and schemas are left intact.", - parents=[pipe_cmd_sync_parent] + help=( + "Drops the local state of the pipeline and resets all the schemas and restores it from" + " destination. The destination state, data and schemas are left intact." + ), + parents=[pipe_cmd_sync_parent], + ) + pipeline_subparsers.add_parser( + "trace", help="Displays last run trace, use -v or -vv for more info" ) - pipeline_subparsers.add_parser("trace", help="Displays last run trace, use -v or -vv for more info") pipe_cmd_schema = pipeline_subparsers.add_parser("schema", help="Displays default schema") - pipe_cmd_schema.add_argument("--format", choices=["json", "yaml"], default="yaml", help="Display schema in this format") - pipe_cmd_schema.add_argument("--remove-defaults", action="store_true", help="Does not show default hint values") + pipe_cmd_schema.add_argument( + "--format", choices=["json", "yaml"], default="yaml", help="Display schema in this format" + ) + pipe_cmd_schema.add_argument( + "--remove-defaults", action="store_true", help="Does not show default hint values" + ) pipe_cmd_drop = pipeline_subparsers.add_parser( "drop", help="Selectively drop tables and reset state", parents=[pipe_cmd_sync_parent], - epilog=f"See {DLT_PIPELINE_COMMAND_DOCS_URL}#selectively-drop-tables-and-reset-state for more info" + epilog=( + f"See {DLT_PIPELINE_COMMAND_DOCS_URL}#selectively-drop-tables-and-reset-state for more" + " info" + ), + ) + pipe_cmd_drop.add_argument( + "resources", + nargs="*", + help=( + "One or more resources to drop. Can be exact resource name(s) or regex pattern(s)." 
+ " Regex patterns must start with re:" + ), + ) + pipe_cmd_drop.add_argument( + "--drop-all", + action="store_true", + default=False, + help="Drop all resources found in schema. Supersedes [resources] argument.", + ) + pipe_cmd_drop.add_argument( + "--state-paths", nargs="*", help="State keys or json paths to drop", default=() + ) + pipe_cmd_drop.add_argument( + "--schema", + help="Schema name to drop from (if other than default schema).", + dest="schema_name", + ) + pipe_cmd_drop.add_argument( + "--state-only", + action="store_true", + help="Only wipe state for matching resources without dropping tables.", + default=False, ) - pipe_cmd_drop.add_argument("resources", nargs="*", help="One or more resources to drop. Can be exact resource name(s) or regex pattern(s). Regex patterns must start with re:") - pipe_cmd_drop.add_argument("--drop-all", action="store_true", default=False, help="Drop all resources found in schema. Supersedes [resources] argument.") - pipe_cmd_drop.add_argument("--state-paths", nargs="*", help="State keys or json paths to drop", default=()) - pipe_cmd_drop.add_argument("--schema", help="Schema name to drop from (if other than default schema).", dest="schema_name") - pipe_cmd_drop.add_argument("--state-only", action="store_true", help="Only wipe state for matching resources without dropping tables.", default=False) - pipe_cmd_package = pipeline_subparsers.add_parser("load-package", help="Displays information on load package, use -v or -vv for more info") - pipe_cmd_package.add_argument("load_id", metavar="load-id", nargs='?', help="Load id of completed or normalized package. Defaults to the most recent package.") + pipe_cmd_package = pipeline_subparsers.add_parser( + "load-package", help="Displays information on load package, use -v or -vv for more info" + ) + pipe_cmd_package.add_argument( + "load_id", + metavar="load-id", + nargs="?", + help="Load id of completed or normalized package. Defaults to the most recent package.", + ) subparsers.add_parser("telemetry", help="Shows telemetry status") args = parser.parse_args() if Venv.is_virtual_env() and not Venv.is_venv_activated(): - fmt.warning("You are running dlt installed in the global environment, however you have virtual environment activated. The dlt command will not see dependencies from virtual environment. You should uninstall the dlt from global environment and install it in the current virtual environment instead.") + fmt.warning( + "You are running dlt installed in the global environment, however you have virtual" + " environment activated. The dlt command will not see dependencies from virtual" + " environment. You should uninstall the dlt from global environment and install it in" + " the current virtual environment instead." 
+ ) if args.command == "schema": return schema_command_wrapper(args.file, args.format, args.remove_defaults) @@ -290,7 +535,7 @@ def main() -> int: return pipeline_command_wrapper("list", "-", args.pipelines_dir, args.verbosity) else: command_kwargs = dict(args._get_kwargs()) - command_kwargs['operation'] = args.operation or "info" + command_kwargs["operation"] = args.operation or "info" del command_kwargs["command"] del command_kwargs["list_pipelines"] return pipeline_command_wrapper(**command_kwargs) @@ -302,7 +547,9 @@ def main() -> int: init_cmd.print_usage() return -1 else: - return init_command_wrapper(args.source, args.destination, args.generic, args.location, args.branch) + return init_command_wrapper( + args.source, args.destination, args.generic, args.location, args.branch + ) elif args.command == "deploy": try: deploy_args = vars(args) @@ -311,12 +558,17 @@ def main() -> int: deployment_method=deploy_args.pop("deployment_method"), repo_location=deploy_args.pop("location"), branch=deploy_args.pop("branch"), - **deploy_args + **deploy_args, ) except (NameError, KeyError): - fmt.warning("Please install additional command line dependencies to use deploy command:") + fmt.warning( + "Please install additional command line dependencies to use deploy command:" + ) fmt.secho('pip install "dlt[cli]"', bold=True) - fmt.echo("We ask you to install those dependencies separately to keep our core library small and make it work everywhere.") + fmt.echo( + "We ask you to install those dependencies separately to keep our core library small" + " and make it work everywhere." + ) return -1 elif args.command == "telemetry": return telemetry_status_command_wrapper() diff --git a/dlt/cli/config_toml_writer.py b/dlt/cli/config_toml_writer.py index ca2e74fd15..57c04106c9 100644 --- a/dlt/cli/config_toml_writer.py +++ b/dlt/cli/config_toml_writer.py @@ -1,11 +1,16 @@ -from typing import Any, NamedTuple, Tuple, Iterable +from collections.abc import Sequence as C_Sequence +from typing import Any, Iterable, NamedTuple, Tuple + import tomlkit -from tomlkit.items import Table as TOMLTable from tomlkit.container import Container as TOMLContainer -from collections.abc import Sequence as C_Sequence +from tomlkit.items import Table as TOMLTable from dlt.common import pendulum -from dlt.common.configuration.specs import BaseConfiguration, is_base_configuration_inner_hint, extract_inner_hint +from dlt.common.configuration.specs import ( + BaseConfiguration, + extract_inner_hint, + is_base_configuration_inner_hint, +) from dlt.common.data_types import py_type_to_sc_type from dlt.common.typing import AnyType, is_final_type, is_optional_type @@ -53,13 +58,15 @@ def write_value( hint: AnyType, overwrite_existing: bool, default_value: Any = None, - is_default_of_interest: bool = False + is_default_of_interest: bool = False, ) -> None: # skip if table contains the name already if name in toml_table and not overwrite_existing: return # do not dump final and optional fields if they are not of special interest - if (is_final_type(hint) or is_optional_type(hint) or default_value is not None) and not is_default_of_interest: + if ( + is_final_type(hint) or is_optional_type(hint) or default_value is not None + ) and not is_default_of_interest: return # get the inner hint to generate cool examples hint = extract_inner_hint(hint) @@ -84,10 +91,19 @@ def write_spec(toml_table: TOMLTable, config: BaseConfiguration, overwrite_exist default_value = getattr(config, name, None) # check if field is of particular interest and should be 
included if it has default is_default_of_interest = name in config.__config_gen_annotations__ - write_value(toml_table, name, hint, overwrite_existing, default_value=default_value, is_default_of_interest=is_default_of_interest) + write_value( + toml_table, + name, + hint, + overwrite_existing, + default_value=default_value, + is_default_of_interest=is_default_of_interest, + ) -def write_values(toml: TOMLContainer, values: Iterable[WritableConfigValue], overwrite_existing: bool) -> None: +def write_values( + toml: TOMLContainer, values: Iterable[WritableConfigValue], overwrite_existing: bool +) -> None: for value in values: toml_table: TOMLTable = toml # type: ignore for section in value.sections: @@ -98,4 +114,11 @@ def write_values(toml: TOMLContainer, values: Iterable[WritableConfigValue], ove else: toml_table = toml_table[section] # type: ignore - write_value(toml_table, value.name, value.hint, overwrite_existing, default_value=value.default_value, is_default_of_interest=True) + write_value( + toml_table, + value.name, + value.hint, + overwrite_existing, + default_value=value.default_value, + is_default_of_interest=True, + ) diff --git a/dlt/cli/deploy_command.py b/dlt/cli/deploy_command.py index 7634f173b3..b4dab9cb61 100644 --- a/dlt/cli/deploy_command.py +++ b/dlt/cli/deploy_command.py @@ -1,26 +1,34 @@ import os -from typing import Optional, Any, Type -import yaml from enum import Enum from importlib.metadata import version as pkg_version +from typing import Any, Optional, Type -from dlt.common.configuration.providers import SECRETS_TOML, SECRETS_TOML_KEY, StringTomlProvider +import yaml + +from dlt.cli import echo as fmt +from dlt.cli import utils +from dlt.cli.deploy_command_helpers import ( + BaseDeployment, + PipelineWasNotRun, + ask_files_overwrite, + generate_pip_freeze, + get_schedule_description, + github_origin_to_url, + serialize_templated_yaml, + wrap_template_str, +) from dlt.common.configuration.paths import make_dlt_settings_path +from dlt.common.configuration.providers import SECRETS_TOML, SECRETS_TOML_KEY, StringTomlProvider from dlt.common.configuration.utils import serialize_value +from dlt.common.destination.reference import DestinationReference from dlt.common.git import is_dirty - -from dlt.cli import utils -from dlt.cli import echo as fmt -from dlt.cli.deploy_command_helpers import (PipelineWasNotRun, BaseDeployment, ask_files_overwrite, generate_pip_freeze, github_origin_to_url, serialize_templated_yaml, - wrap_template_str, get_schedule_description) - from dlt.version import DLT_PKG_NAME -from dlt.common.destination.reference import DestinationReference - REQUIREMENTS_GITHUB_ACTION = "requirements_github_action.txt" DLT_DEPLOY_DOCS_URL = "https://dlthub.com/docs/walkthroughs/deploy-a-pipeline" -DLT_AIRFLOW_GCP_DOCS_URL = "https://dlthub.com/docs/walkthroughs/deploy-a-pipeline/deploy-with-airflow-composer" +DLT_AIRFLOW_GCP_DOCS_URL = ( + "https://dlthub.com/docs/walkthroughs/deploy-a-pipeline/deploy-with-airflow-composer" +) AIRFLOW_GETTING_STARTED = "https://airflow.apache.org/docs/apache-airflow/stable/start.html" AIRFLOW_DAG_TEMPLATE_SCRIPT = "dag_template.py" AIRFLOW_CLOUDBUILD_YAML = "cloudbuild.yaml" @@ -38,9 +46,13 @@ class SecretFormats(Enum): toml = "toml" -def deploy_command(pipeline_script_path: str, deployment_method: str, repo_location: str, branch: Optional[str] = None, **kwargs: Any +def deploy_command( + pipeline_script_path: str, + deployment_method: str, + repo_location: str, + branch: Optional[str] = None, + **kwargs: Any, ) -> None: - # get 
current repo local folder deployment_class: Type[BaseDeployment] = None if deployment_method == DeploymentMethods.github_actions.value: @@ -48,10 +60,15 @@ def deploy_command(pipeline_script_path: str, deployment_method: str, repo_locat elif deployment_method == DeploymentMethods.airflow_composer.value: deployment_class = AirflowDeployment else: - raise ValueError(f"Deployment method '{deployment_method}' is not supported. Only {', '.join([m.value for m in DeploymentMethods])} are available.'") + raise ValueError( + f"Deployment method '{deployment_method}' is not supported. Only" + f" {', '.join([m.value for m in DeploymentMethods])} are available.'" + ) # command no longer needed kwargs.pop("command", None) - deployment_class(pipeline_script_path=pipeline_script_path, location=repo_location, branch=branch, **kwargs).run_deployment() + deployment_class( + pipeline_script_path=pipeline_script_path, location=repo_location, branch=branch, **kwargs + ).run_deployment() class GithubActionDeployment(BaseDeployment): @@ -77,22 +94,25 @@ def _generate_workflow(self, *args: Optional[Any]) -> None: if self.schedule_description is None: # TODO: move that check to _dlt and some intelligent help message on missing arg raise ValueError( - f"Setting 'schedule' for '{self.deployment_method}' is required! Use deploy command as 'dlt deploy chess.py {self.deployment_method} --schedule \"*/30 * * * *\"'." + f"Setting 'schedule' for '{self.deployment_method}' is required! Use deploy command" + f" as 'dlt deploy chess.py {self.deployment_method} --schedule \"*/30 * * * *\"'." ) workflow = self._create_new_workflow() serialized_workflow = serialize_templated_yaml(workflow) serialized_workflow_name = f"run_{self.state['pipeline_name']}_workflow.yml" - self.artifacts['serialized_workflow'] = serialized_workflow - self.artifacts['serialized_workflow_name'] = serialized_workflow_name + self.artifacts["serialized_workflow"] = serialized_workflow + self.artifacts["serialized_workflow_name"] = serialized_workflow_name # pip freeze special requirements file - with self.template_storage.open_file(os.path.join(self.deployment_method, "requirements_blacklist.txt")) as f: + with self.template_storage.open_file( + os.path.join(self.deployment_method, "requirements_blacklist.txt") + ) as f: requirements_blacklist = f.readlines() requirements_txt = generate_pip_freeze(requirements_blacklist, REQUIREMENTS_GITHUB_ACTION) requirements_txt_name = REQUIREMENTS_GITHUB_ACTION # if repo_storage.has_file(utils.REQUIREMENTS_TXT): - self.artifacts['requirements_txt'] = requirements_txt - self.artifacts['requirements_txt_name'] = requirements_txt_name + self.artifacts["requirements_txt"] = requirements_txt + self.artifacts["requirements_txt_name"] = requirements_txt_name def _make_modification(self) -> None: if not self.repo_storage.has_folder(utils.GITHUB_WORKFLOWS_DIR): @@ -100,15 +120,21 @@ def _make_modification(self) -> None: self.repo_storage.save( os.path.join(utils.GITHUB_WORKFLOWS_DIR, self.artifacts["serialized_workflow_name"]), - self.artifacts["serialized_workflow"] + self.artifacts["serialized_workflow"], + ) + self.repo_storage.save( + self.artifacts["requirements_txt_name"], self.artifacts["requirements_txt"] ) - self.repo_storage.save(self.artifacts["requirements_txt_name"], self.artifacts["requirements_txt"]) def _create_new_workflow(self) -> Any: - with self.template_storage.open_file(os.path.join(self.deployment_method, "run_pipeline_workflow.yml")) as f: + with self.template_storage.open_file( + 
os.path.join(self.deployment_method, "run_pipeline_workflow.yml") + ) as f: workflow = yaml.safe_load(f) # customize the workflow - workflow["name"] = f"Run {self.state['pipeline_name']} pipeline from {self.pipeline_script_path}" + workflow["name"] = ( + f"Run {self.state['pipeline_name']} pipeline from {self.pipeline_script_path}" + ) if self.run_on_push is False: del workflow["on"]["push"] if self.run_manually is False: @@ -137,51 +163,98 @@ def _create_new_workflow(self) -> Any: return workflow def _echo_instructions(self, *args: Optional[Any]) -> None: - fmt.echo("Your %s deployment for pipeline %s in script %s is ready!" % ( - fmt.bold(self.deployment_method), fmt.bold(self.state["pipeline_name"]), fmt.bold(self.pipeline_script_path) - )) + fmt.echo( + "Your %s deployment for pipeline %s in script %s is ready!" + % ( + fmt.bold(self.deployment_method), + fmt.bold(self.state["pipeline_name"]), + fmt.bold(self.pipeline_script_path), + ) + ) # It contains all relevant configurations and references to credentials that are needed to run the pipeline - fmt.echo("* A github workflow file %s was created in %s." % ( - fmt.bold(self.artifacts["serialized_workflow_name"]), fmt.bold(utils.GITHUB_WORKFLOWS_DIR) - )) - fmt.echo("* The schedule with which the pipeline is run is: %s.%s%s" % ( - fmt.bold(self.schedule_description), - " You can also run the pipeline manually." if self.run_manually else "", - " Pipeline will also run on each push to the repository." if self.run_on_push else "", - )) fmt.echo( - "* The dependencies that will be used to run the pipeline are stored in %s. If you change add more dependencies, remember to refresh your deployment by running the same 'deploy' command again." % fmt.bold( - self.artifacts['requirements_txt_name']) + "* A github workflow file %s was created in %s." + % ( + fmt.bold(self.artifacts["serialized_workflow_name"]), + fmt.bold(utils.GITHUB_WORKFLOWS_DIR), + ) + ) + fmt.echo( + "* The schedule with which the pipeline is run is: %s.%s%s" + % ( + fmt.bold(self.schedule_description), + " You can also run the pipeline manually." if self.run_manually else "", + ( + " Pipeline will also run on each push to the repository." + if self.run_on_push + else "" + ), + ) + ) + fmt.echo( + "* The dependencies that will be used to run the pipeline are stored in %s. If you" + " change add more dependencies, remember to refresh your deployment by running the same" + " 'deploy' command again." + % fmt.bold(self.artifacts["requirements_txt_name"]) ) fmt.echo() if len(self.secret_envs) == 0 and len(self.envs) == 0: fmt.echo("1. Your pipeline does not seem to need any secrets.") else: - fmt.echo("You should now add the secrets to github repository secrets, commit and push the pipeline files to github.") - fmt.echo("1. Add the following secret values (typically stored in %s): \n%s\nin %s" % ( - fmt.bold(make_dlt_settings_path(SECRETS_TOML)), - fmt.bold("\n".join(self.env_prov.get_key_name(s_v.key, *s_v.sections) for s_v in self.secret_envs)), - fmt.bold(github_origin_to_url(self.origin, "/settings/secrets/actions")) - )) + fmt.echo( + "You should now add the secrets to github repository secrets, commit and push the" + " pipeline files to github." + ) + fmt.echo( + "1. 
Add the following secret values (typically stored in %s): \n%s\nin %s" + % ( + fmt.bold(make_dlt_settings_path(SECRETS_TOML)), + fmt.bold( + "\n".join( + self.env_prov.get_key_name(s_v.key, *s_v.sections) + for s_v in self.secret_envs + ) + ), + fmt.bold(github_origin_to_url(self.origin, "/settings/secrets/actions")), + ) + ) fmt.echo() self._echo_secrets() - fmt.echo("2. Add stage deployment files to commit. Use your Git UI or the following command") - new_req_path = self.repo_storage.from_relative_path_to_wd(self.artifacts['requirements_txt_name']) - new_workflow_path = self.repo_storage.from_relative_path_to_wd(os.path.join(utils.GITHUB_WORKFLOWS_DIR, self.artifacts['serialized_workflow_name'])) - fmt.echo(fmt.bold( - f"git add {new_req_path} {new_workflow_path}")) + fmt.echo( + "2. Add stage deployment files to commit. Use your Git UI or the following command" + ) + new_req_path = self.repo_storage.from_relative_path_to_wd( + self.artifacts["requirements_txt_name"] + ) + new_workflow_path = self.repo_storage.from_relative_path_to_wd( + os.path.join(utils.GITHUB_WORKFLOWS_DIR, self.artifacts["serialized_workflow_name"]) + ) + fmt.echo(fmt.bold(f"git add {new_req_path} {new_workflow_path}")) fmt.echo() fmt.echo("3. Commit the files above. Use your Git UI or the following command") - fmt.echo(fmt.bold(f"git commit -m 'run {self.state['pipeline_name']} pipeline with github action'")) + fmt.echo( + fmt.bold( + f"git commit -m 'run {self.state['pipeline_name']} pipeline with github action'" + ) + ) if is_dirty(self.repo): - fmt.warning("You have modified files in your repository. Do not forget to push changes to your pipeline script as well!") + fmt.warning( + "You have modified files in your repository. Do not forget to push changes to your" + " pipeline script as well!" + ) fmt.echo() fmt.echo("4. Push changes to github. Use your Git UI or the following command") fmt.echo(fmt.bold("git push origin")) fmt.echo() fmt.echo("5. Your pipeline should be running! 
You can monitor it here:") - fmt.echo(fmt.bold(github_origin_to_url(self.origin, f"/actions/workflows/{self.artifacts['serialized_workflow_name']}"))) + fmt.echo( + fmt.bold( + github_origin_to_url( + self.origin, f"/actions/workflows/{self.artifacts['serialized_workflow_name']}" + ) + ) + ) class AirflowDeployment(BaseDeployment): @@ -206,11 +279,15 @@ def _generate_workflow(self, *args: Optional[Any]) -> None: dag_script_name = f"dag_{self.state['pipeline_name']}.py" self.artifacts["dag_script_name"] = dag_script_name - cloudbuild_file = self.template_storage.load(os.path.join(self.deployment_method, AIRFLOW_CLOUDBUILD_YAML)) + cloudbuild_file = self.template_storage.load( + os.path.join(self.deployment_method, AIRFLOW_CLOUDBUILD_YAML) + ) self.artifacts["cloudbuild_file"] = cloudbuild_file # TODO: rewrite dag file to at least set the schedule - dag_file = self.template_storage.load(os.path.join(self.deployment_method, AIRFLOW_DAG_TEMPLATE_SCRIPT)) + dag_file = self.template_storage.load( + os.path.join(self.deployment_method, AIRFLOW_DAG_TEMPLATE_SCRIPT) + ) self.artifacts["dag_file"] = dag_file # ask user if to overwrite the files @@ -227,61 +304,92 @@ def _make_modification(self) -> None: # save cloudbuild.yaml only if not exist to allow to run the deploy command for many different pipelines dest_cloud_build = os.path.join(utils.AIRFLOW_BUILD_FOLDER, AIRFLOW_CLOUDBUILD_YAML) if not self.repo_storage.has_file(dest_cloud_build): - self.repo_storage.save( - dest_cloud_build, - self.artifacts["cloudbuild_file"] - ) + self.repo_storage.save(dest_cloud_build, self.artifacts["cloudbuild_file"]) else: - fmt.warning(f"{AIRFLOW_CLOUDBUILD_YAML} already created. Delete the file and run the deploy command again to re-create.") + fmt.warning( + f"{AIRFLOW_CLOUDBUILD_YAML} already created. Delete the file and run the deploy" + " command again to re-create." + ) dest_dag_script = os.path.join(utils.AIRFLOW_DAGS_FOLDER, self.artifacts["dag_script_name"]) - self.repo_storage.save( - dest_dag_script, - self.artifacts["dag_file"] - ) - + self.repo_storage.save(dest_dag_script, self.artifacts["dag_file"]) def _echo_instructions(self, *args: Optional[Any]) -> None: - fmt.echo("Your %s deployment for pipeline %s is ready!" % ( - fmt.bold(self.deployment_method), fmt.bold(self.state["pipeline_name"]), - )) - fmt.echo("* The airflow %s file was created in %s." % ( - fmt.bold(AIRFLOW_CLOUDBUILD_YAML), fmt.bold(utils.AIRFLOW_BUILD_FOLDER) - )) - fmt.echo("* The %s script was created in %s." % ( - fmt.bold(self.artifacts["dag_script_name"]), fmt.bold(utils.AIRFLOW_DAGS_FOLDER) - )) + fmt.echo( + "Your %s deployment for pipeline %s is ready!" + % ( + fmt.bold(self.deployment_method), + fmt.bold(self.state["pipeline_name"]), + ) + ) + fmt.echo( + "* The airflow %s file was created in %s." + % (fmt.bold(AIRFLOW_CLOUDBUILD_YAML), fmt.bold(utils.AIRFLOW_BUILD_FOLDER)) + ) + fmt.echo( + "* The %s script was created in %s." + % (fmt.bold(self.artifacts["dag_script_name"]), fmt.bold(utils.AIRFLOW_DAGS_FOLDER)) + ) fmt.echo() fmt.echo("You must prepare your DAG first:") - fmt.echo("1. Import your sources in %s, configure the DAG ans tasks as needed." % (fmt.bold(self.artifacts["dag_script_name"]))) - fmt.echo("2. Test the DAG with Airflow locally .\nSee Airflow getting started: %s" % (fmt.bold(AIRFLOW_GETTING_STARTED))) + fmt.echo( + "1. Import your sources in %s, configure the DAG ans tasks as needed." + % (fmt.bold(self.artifacts["dag_script_name"])) + ) + fmt.echo( + "2. 
Test the DAG with Airflow locally .\nSee Airflow getting started: %s" + % (fmt.bold(AIRFLOW_GETTING_STARTED)) + ) fmt.echo() - fmt.echo("If you are planning run the pipeline with Google Cloud Composer, follow the next instructions:\n") - fmt.echo("1. Read this doc and set up the Environment: %s" % ( - fmt.bold(DLT_AIRFLOW_GCP_DOCS_URL) - )) - fmt.echo("2. Set _BUCKET_NAME up in %s/%s file. " % ( - fmt.bold(utils.AIRFLOW_BUILD_FOLDER), fmt.bold(AIRFLOW_CLOUDBUILD_YAML), - )) + fmt.echo( + "If you are planning run the pipeline with Google Cloud Composer, follow the next" + " instructions:\n" + ) + fmt.echo( + "1. Read this doc and set up the Environment: %s" % (fmt.bold(DLT_AIRFLOW_GCP_DOCS_URL)) + ) + fmt.echo( + "2. Set _BUCKET_NAME up in %s/%s file. " + % ( + fmt.bold(utils.AIRFLOW_BUILD_FOLDER), + fmt.bold(AIRFLOW_CLOUDBUILD_YAML), + ) + ) if len(self.secret_envs) == 0 and len(self.envs) == 0: fmt.echo("3. Your pipeline does not seem to need any secrets.") else: if self.secrets_format == SecretFormats.env.value: - fmt.echo("3. Add the following secret values (typically stored in %s): \n%s\n%s\nin ENVIRONMENT VARIABLES using Google Composer UI" % ( - fmt.bold(make_dlt_settings_path(SECRETS_TOML)), - fmt.bold("\n".join(self.env_prov.get_key_name(s_v.key, *s_v.sections) for s_v in self.secret_envs)), - fmt.bold("\n".join(self.env_prov.get_key_name(v.key, *v.sections) for v in self.envs)), - )) + fmt.echo( + "3. Add the following secret values (typically stored in %s): \n%s\n%s\nin" + " ENVIRONMENT VARIABLES using Google Composer UI" + % ( + fmt.bold(make_dlt_settings_path(SECRETS_TOML)), + fmt.bold( + "\n".join( + self.env_prov.get_key_name(s_v.key, *s_v.sections) + for s_v in self.secret_envs + ) + ), + fmt.bold( + "\n".join( + self.env_prov.get_key_name(v.key, *v.sections) for v in self.envs + ) + ), + ) + ) fmt.echo() # if fmt.confirm("Do you want to list the environment variables in the format suitable for Airflow?", default=True): self._echo_secrets() self._echo_envs() elif self.secrets_format == SecretFormats.toml.value: # build toml - fmt.echo(f"3. Add the following toml-string in the Google Composer UI as the {SECRETS_TOML_KEY} variable.") + fmt.echo( + "3. Add the following toml-string in the Google Composer UI as the" + f" {SECRETS_TOML_KEY} variable." + ) fmt.echo() toml_provider = StringTomlProvider("") for s_v in self.secret_envs: @@ -294,18 +402,34 @@ def _echo_instructions(self, *args: Optional[Any]) -> None: fmt.echo("4. Add dlt package below using Google Composer UI.") fmt.echo(fmt.bold(self.artifacts["requirements_txt"])) - fmt.note("You may need to add more packages ie. when your source requires additional dependencies") + fmt.note( + "You may need to add more packages ie. when your source requires additional" + " dependencies" + ) fmt.echo("5. Commit and push the pipeline files to github:") - fmt.echo("a. Add stage deployment files to commit. Use your Git UI or the following command") + fmt.echo( + "a. Add stage deployment files to commit. 
Use your Git UI or the following command" + ) - dag_script_path = self.repo_storage.from_relative_path_to_wd(os.path.join(utils.AIRFLOW_DAGS_FOLDER, self.artifacts["dag_script_name"])) - cloudbuild_path = self.repo_storage.from_relative_path_to_wd(os.path.join(utils.AIRFLOW_BUILD_FOLDER, AIRFLOW_CLOUDBUILD_YAML)) + dag_script_path = self.repo_storage.from_relative_path_to_wd( + os.path.join(utils.AIRFLOW_DAGS_FOLDER, self.artifacts["dag_script_name"]) + ) + cloudbuild_path = self.repo_storage.from_relative_path_to_wd( + os.path.join(utils.AIRFLOW_BUILD_FOLDER, AIRFLOW_CLOUDBUILD_YAML) + ) fmt.echo(fmt.bold(f"git add {dag_script_path} {cloudbuild_path}")) fmt.echo("b. Commit the files above. Use your Git UI or the following command") - fmt.echo(fmt.bold(f"git commit -m 'initiate {self.state['pipeline_name']} pipeline with Airflow'")) + fmt.echo( + fmt.bold( + f"git commit -m 'initiate {self.state['pipeline_name']} pipeline with Airflow'" + ) + ) if is_dirty(self.repo): - fmt.warning("You have modified files in your repository. Do not forget to push changes to your pipeline script as well!") + fmt.warning( + "You have modified files in your repository. Do not forget to push changes to your" + " pipeline script as well!" + ) fmt.echo("c. Push changes to github. Use your Git UI or the following command") fmt.echo(fmt.bold("git push origin")) fmt.echo("6. You should see your pipeline in Airflow.") diff --git a/dlt/cli/deploy_command_helpers.py b/dlt/cli/deploy_command_helpers.py index 81852f3ce1..d5fef310d9 100644 --- a/dlt/cli/deploy_command_helpers.py +++ b/dlt/cli/deploy_command_helpers.py @@ -1,37 +1,36 @@ -import re import abc import os -import yaml -from yaml import Dumper +import re from itertools import chain -from typing import List, Optional, Sequence, Tuple, Any, Dict -from astunparse import unparse +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import cron_descriptor + # optional dependencies import pipdeptree -import cron_descriptor +import yaml +from astunparse import unparse +from yaml import Dumper import dlt - +from dlt.cli import echo as fmt +from dlt.cli import utils +from dlt.cli.exceptions import CliCommandException from dlt.common import git from dlt.common.configuration.exceptions import LookupTrace from dlt.common.configuration.providers import ConfigTomlProvider, EnvironProvider -from dlt.common.git import get_origin, get_repo, Repo from dlt.common.configuration.specs.run_configuration import get_default_pipeline_name -from dlt.common.typing import StrAny -from dlt.common.reflection.utils import evaluate_node_literal +from dlt.common.git import Repo, get_origin, get_repo from dlt.common.pipeline import LoadInfo, TPipelineState, get_dlt_repos_dir +from dlt.common.reflection.utils import evaluate_node_literal from dlt.common.storages import FileStorage +from dlt.common.typing import StrAny from dlt.common.utils import set_working_dir - from dlt.pipeline.pipeline import Pipeline from dlt.pipeline.trace import PipelineTrace from dlt.reflection import names as n from dlt.reflection.script_visitor import PipelineScriptVisitor -from dlt.cli import utils -from dlt.cli import echo as fmt -from dlt.cli.exceptions import CliCommandException - GITHUB_URL = "https://github.com/" @@ -77,20 +76,36 @@ def _prepare_deployment(self) -> None: # make sure the repo has origin self.origin = self._get_origin() # convert to path relative to repo - self.repo_pipeline_script_path = self.repo_storage.from_wd_to_relative_path(self.pipeline_script_path) + 
self.repo_pipeline_script_path = self.repo_storage.from_wd_to_relative_path( + self.pipeline_script_path + ) # load a pipeline script and extract full_refresh and pipelines_dir args self.pipeline_script = self.repo_storage.load(self.repo_pipeline_script_path) - fmt.echo("Looking up the deployment template scripts in %s...\n" % fmt.bold(self.repo_location)) - self.template_storage = git.get_fresh_repo_files(self.repo_location, get_dlt_repos_dir(), branch=self.branch) + fmt.echo( + "Looking up the deployment template scripts in %s...\n" % fmt.bold(self.repo_location) + ) + self.template_storage = git.get_fresh_repo_files( + self.repo_location, get_dlt_repos_dir(), branch=self.branch + ) self.working_directory = os.path.split(self.pipeline_script_path)[0] def _get_origin(self) -> str: try: origin = get_origin(self.repo) if "github.com" not in origin: - raise CliCommandException("deploy", f"Your current repository origin is not set to github but to {origin}.\nYou must change it to be able to run the pipelines with github actions: https://docs.github.com/en/get-started/getting-started-with-git/managing-remote-repositories") + raise CliCommandException( + "deploy", + f"Your current repository origin is not set to github but to {origin}.\nYou" + " must change it to be able to run the pipelines with github actions:" + " https://docs.github.com/en/get-started/getting-started-with-git/managing-remote-repositories", + ) except ValueError: - raise CliCommandException("deploy", "Your current repository has no origin set. Please set it up to be able to run the pipelines with github actions: https://docs.github.com/en/get-started/importing-your-projects-to-github/importing-source-code-to-github/adding-locally-hosted-code-to-github") + raise CliCommandException( + "deploy", + "Your current repository has no origin set. Please set it up to be able to run the" + " pipelines with github actions:" + " https://docs.github.com/en/get-started/importing-your-projects-to-github/importing-source-code-to-github/adding-locally-hosted-code-to-github", + ) return origin @@ -104,14 +119,18 @@ def run_deployment(self) -> None: pipeline_name: str = None pipelines_dir: str = None - uniq_possible_pipelines = {t[0]:t for t in possible_pipelines} + uniq_possible_pipelines = {t[0]: t for t in possible_pipelines} if len(uniq_possible_pipelines) == 1: pipeline_name, pipelines_dir = possible_pipelines[0] elif len(uniq_possible_pipelines) > 1: choices = list(uniq_possible_pipelines.keys()) - choices_str = "".join([str(i+1) for i in range(len(choices))]) + choices_str = "".join([str(i + 1) for i in range(len(choices))]) choices_selection = [f"{idx+1}-{name}" for idx, name in enumerate(choices)] - sel = fmt.prompt("Several pipelines found in script, please select one: " + ", ".join(choices_selection), choices=choices_str) + sel = fmt.prompt( + "Several pipelines found in script, please select one: " + + ", ".join(choices_selection), + choices=choices_str, + ) pipeline_name, pipelines_dir = uniq_possible_pipelines[choices[int(sel) - 1]] if pipelines_dir: @@ -126,11 +145,17 @@ def run_deployment(self) -> None: self.pipeline_name = dlt.config.get("pipeline_name") if not self.pipeline_name: self.pipeline_name = get_default_pipeline_name(self.pipeline_script_path) - fmt.warning(f"Using default pipeline name {self.pipeline_name}. The pipeline name is not passed as argument to dlt.pipeline nor configured via config provides ie. config.toml") + fmt.warning( + f"Using default pipeline name {self.pipeline_name}. 
The pipeline name" + " is not passed as argument to dlt.pipeline nor configured via config" + " provides ie. config.toml" + ) # fmt.echo("Generating deployment for pipeline %s" % fmt.bold(self.pipeline_name)) # attach to pipeline name, get state and trace - pipeline = dlt.attach(pipeline_name=self.pipeline_name, pipelines_dir=self.pipelines_dir) + pipeline = dlt.attach( + pipeline_name=self.pipeline_name, pipelines_dir=self.pipelines_dir + ) self.state, trace = get_state_and_trace(pipeline) self._update_envs(trace) @@ -148,12 +173,26 @@ def _update_envs(self, trace: PipelineTrace) -> None: for resolved_value in trace.resolved_config_values: if resolved_value.is_secret_hint: # generate special forms for all secrets - self.secret_envs.append(LookupTrace(self.env_prov.name, tuple(resolved_value.sections), resolved_value.key, resolved_value.value)) + self.secret_envs.append( + LookupTrace( + self.env_prov.name, + tuple(resolved_value.sections), + resolved_value.key, + resolved_value.value, + ) + ) # fmt.echo(f"{resolved_value.key}:{resolved_value.value}{type(resolved_value.value)} in {resolved_value.sections} is SECRET") else: # move all config values that are not in config.toml into env if resolved_value.provider_name != self.config_prov.name: - self.envs.append(LookupTrace(self.env_prov.name, tuple(resolved_value.sections), resolved_value.key, resolved_value.value)) + self.envs.append( + LookupTrace( + self.env_prov.name, + tuple(resolved_value.sections), + resolved_value.key, + resolved_value.value, + ) + ) # fmt.echo(f"{resolved_value.key} in {resolved_value.sections} moved to CONFIG") def _echo_secrets(self) -> None: @@ -189,12 +228,20 @@ def get_state_and_trace(pipeline: Pipeline) -> Tuple[TPipelineState, PipelineTra # trace must exist and end with a successful loading step trace = pipeline.last_trace if trace is None or len(trace.steps) == 0: - raise PipelineWasNotRun("Pipeline run trace could not be found. Please run the pipeline at least once locally.") + raise PipelineWasNotRun( + "Pipeline run trace could not be found. Please run the pipeline at least once locally." + ) last_step = trace.steps[-1] if last_step.step_exception is not None: - raise PipelineWasNotRun(f"The last pipeline run ended with error. Please make sure that pipeline runs correctly before deployment.\n{last_step.step_exception}") + raise PipelineWasNotRun( + "The last pipeline run ended with error. Please make sure that pipeline runs correctly" + f" before deployment.\n{last_step.step_exception}" + ) if not isinstance(last_step.step_info, LoadInfo): - raise PipelineWasNotRun("The last pipeline run did not reach the load step. Please run the pipeline locally until it loads data into destination.") + raise PipelineWasNotRun( + "The last pipeline run did not reach the load step. Please run the pipeline locally" + " until it loads data into destination." 
+ ) return pipeline.state, trace @@ -202,7 +249,10 @@ def get_state_and_trace(pipeline: Pipeline) -> Tuple[TPipelineState, PipelineTra def get_visitors(pipeline_script: str, pipeline_script_path: str) -> PipelineScriptVisitor: visitor = utils.parse_init_script("deploy", pipeline_script, pipeline_script_path) if n.RUN not in visitor.known_calls: - raise CliCommandException("deploy", f"The pipeline script {pipeline_script_path} does not seem to run the pipeline.") + raise CliCommandException( + "deploy", + f"The pipeline script {pipeline_script_path} does not seem to run the pipeline.", + ) return visitor @@ -215,22 +265,40 @@ def parse_pipeline_info(visitor: PipelineScriptVisitor) -> List[Tuple[str, Optio if f_r_node: f_r_value = evaluate_node_literal(f_r_node) if f_r_value is None: - fmt.warning(f"The value of `full_refresh` in call to `dlt.pipeline` cannot be determined from {unparse(f_r_node).strip()}. We assume that you know what you are doing :)") + fmt.warning( + "The value of `full_refresh` in call to `dlt.pipeline` cannot be" + f" determined from {unparse(f_r_node).strip()}. We assume that you know" + " what you are doing :)" + ) if f_r_value is True: - if fmt.confirm("The value of 'full_refresh' is set to True. Do you want to abort to set it to False?", default=True): + if fmt.confirm( + "The value of 'full_refresh' is set to True. Do you want to abort to set it" + " to False?", + default=True, + ): raise CliCommandException("deploy", "Please set the full_refresh to False") p_d_node = call_args.arguments.get("pipelines_dir") if p_d_node: pipelines_dir = evaluate_node_literal(p_d_node) if pipelines_dir is None: - raise CliCommandException("deploy", f"The value of 'pipelines_dir' argument in call to `dlt_pipeline` cannot be determined from {unparse(p_d_node).strip()}. Pipeline working dir will be found. Pass it directly with --pipelines-dir option.") + raise CliCommandException( + "deploy", + "The value of 'pipelines_dir' argument in call to `dlt_pipeline` cannot be" + f" determined from {unparse(p_d_node).strip()}. Pipeline working dir will" + " be found. Pass it directly with --pipelines-dir option.", + ) p_n_node = call_args.arguments.get("pipeline_name") if p_n_node: pipeline_name = evaluate_node_literal(p_n_node) if pipeline_name is None: - raise CliCommandException("deploy", f"The value of 'pipeline_name' argument in call to `dlt_pipeline` cannot be determined from {unparse(p_d_node).strip()}. Pipeline working dir will be found. Pass it directly with --pipeline-name option.") + raise CliCommandException( + "deploy", + "The value of 'pipeline_name' argument in call to `dlt_pipeline` cannot be" + f" determined from {unparse(p_d_node).strip()}. Pipeline working dir will" + " be found. 
Pass it directly with --pipeline-name option.", + ) pipelines.append((pipeline_name, pipelines_dir)) return pipelines @@ -240,8 +308,8 @@ def str_representer(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode: # format multiline strings as blocks with the exception of placeholders # that will be expanded as yaml if len(data.splitlines()) > 1 and "{{ toYaml" not in data: # check for multiline string - return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') - return dumper.represent_scalar('tag:yaml.org,2002:str', data) + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + return dumper.represent_scalar("tag:yaml.org,2002:str", data) def wrap_template_str(s: str) -> str: @@ -253,17 +321,14 @@ def serialize_templated_yaml(tree: StrAny) -> str: try: yaml.add_representer(str, str_representer) # pretty serialize yaml - serialized: str = yaml.dump(tree, allow_unicode=True, default_flow_style=False, sort_keys=False) + serialized: str = yaml.dump( + tree, allow_unicode=True, default_flow_style=False, sort_keys=False + ) # removes apostrophes around the template - serialized = re.sub(r"'([\s\n]*?\${{.+?}})'", - r"\1", - serialized, - flags=re.DOTALL) + serialized = re.sub(r"'([\s\n]*?\${{.+?}})'", r"\1", serialized, flags=re.DOTALL) # print(serialized) # fix the new lines in templates ending }} - serialized = re.sub(r"(\${{.+)\n.+(}})", - r"\1 \2", - serialized) + serialized = re.sub(r"(\${{.+)\n.+(}})", r"\1 \2", serialized) return serialized finally: yaml.add_representer(str, old_representer) @@ -292,7 +357,10 @@ def generate_pip_freeze(requirements_blacklist: List[str], requirements_file_nam conflicts = pipdeptree.conflicting_deps(tree) cycles = pipdeptree.cyclic_deps(tree) if conflicts: - fmt.warning(f"Unable to create dependencies for the github action. Please edit {requirements_file_name} yourself") + fmt.warning( + "Unable to create dependencies for the github action. 
Please edit" + f" {requirements_file_name} yourself" + ) pipdeptree.render_conflicts_text(conflicts) pipdeptree.render_cycles_text(cycles) fmt.echo() diff --git a/dlt/cli/echo.py b/dlt/cli/echo.py index 41c9fc1f7f..94b24d0d9a 100644 --- a/dlt/cli/echo.py +++ b/dlt/cli/echo.py @@ -1,7 +1,7 @@ import contextlib from typing import Any, Iterable, Iterator, Optional -import click +import click ALWAYS_CHOOSE_DEFAULT = False ALWAYS_CHOOSE_VALUE: Any = None @@ -20,7 +20,6 @@ def always_choose(always_choose_default: bool, always_choose_value: Any) -> Iter ALWAYS_CHOOSE_VALUE = _always_choose_value - echo = click.echo secho = click.secho style = click.style @@ -65,5 +64,6 @@ def prompt(text: str, choices: Iterable[str], default: Optional[Any] = None) -> click_choices = click.Choice(choices) return click.prompt(text, type=click_choices, default=default) + def text_input(text: str) -> str: return click.prompt(text) # type: ignore[no-any-return] diff --git a/dlt/cli/init_command.py b/dlt/cli/init_command.py index c246ac87de..685a41b2df 100644 --- a/dlt/cli/init_command.py +++ b/dlt/cli/init_command.py @@ -1,32 +1,40 @@ -import os import ast +import os import shutil +from importlib.metadata import version as pkg_version from types import ModuleType from typing import Dict, List, Sequence, Tuple -from importlib.metadata import version as pkg_version +import dlt.reflection.names as n +from dlt.cli import echo as fmt +from dlt.cli import pipeline_files as files_ops +from dlt.cli import source_detection, utils +from dlt.cli.config_toml_writer import WritableConfigValue, write_values +from dlt.cli.exceptions import CliCommandException +from dlt.cli.pipeline_files import ( + TVerifiedSourceFileEntry, + TVerifiedSourceFileIndex, + VerifiedSourceFiles, +) +from dlt.cli.requirements import SourceRequirements from dlt.common import git from dlt.common.configuration.paths import get_dlt_settings_dir, make_dlt_settings_path +from dlt.common.configuration.providers import ( + CONFIG_TOML, + SECRETS_TOML, + ConfigTomlProvider, + SecretsTomlProvider, +) from dlt.common.configuration.specs import known_sections -from dlt.common.configuration.providers import CONFIG_TOML, SECRETS_TOML, ConfigTomlProvider, SecretsTomlProvider -from dlt.common.pipeline import get_dlt_repos_dir -from dlt.common.source import _SOURCES -from dlt.version import DLT_PKG_NAME, __version__ from dlt.common.destination import DestinationReference +from dlt.common.pipeline import get_dlt_repos_dir from dlt.common.reflection.utils import rewrite_python_script -from dlt.common.schema.utils import is_valid_schema_name from dlt.common.schema.exceptions import InvalidSchemaName +from dlt.common.schema.utils import is_valid_schema_name +from dlt.common.source import _SOURCES from dlt.common.storages.file_storage import FileStorage - -import dlt.reflection.names as n from dlt.reflection.script_inspector import inspect_pipeline_script, load_script_module - -from dlt.cli import echo as fmt, pipeline_files as files_ops, source_detection -from dlt.cli import utils -from dlt.cli.config_toml_writer import WritableConfigValue, write_values -from dlt.cli.pipeline_files import VerifiedSourceFiles, TVerifiedSourceFileEntry, TVerifiedSourceFileIndex -from dlt.cli.exceptions import CliCommandException -from dlt.cli.requirements import SourceRequirements +from dlt.version import DLT_PKG_NAME, __version__ DLT_INIT_DOCS_URL = "https://dlthub.com/docs/reference/command-line-interface#dlt-init" DEFAULT_VERIFIED_SOURCES_REPO = 
"https://github.com/dlt-hub/verified-sources.git" @@ -34,7 +42,9 @@ SOURCES_MODULE_NAME = "sources" -def _get_template_files(command_module: ModuleType, use_generic_template: bool) -> Tuple[str, List[str]]: +def _get_template_files( + command_module: ModuleType, use_generic_template: bool +) -> Tuple[str, List[str]]: template_files: List[str] = command_module.TEMPLATE_FILES pipeline_script: str = command_module.PIPELINE_SCRIPT if use_generic_template: @@ -48,22 +58,41 @@ def _select_source_files( remote_modified: Dict[str, TVerifiedSourceFileEntry], remote_deleted: Dict[str, TVerifiedSourceFileEntry], conflict_modified: Sequence[str], - conflict_deleted: Sequence[str] + conflict_deleted: Sequence[str], ) -> Tuple[str, Dict[str, TVerifiedSourceFileEntry], Dict[str, TVerifiedSourceFileEntry]]: # some files were changed and cannot be updated (or are created without index) - fmt.echo("Existing files for %s source were changed and cannot be automatically updated" % fmt.bold(source_name)) + fmt.echo( + "Existing files for %s source were changed and cannot be automatically updated" + % fmt.bold(source_name) + ) if conflict_modified: - fmt.echo("Following files are MODIFIED locally and CONFLICT with incoming changes: %s" % fmt.bold(", ".join(conflict_modified))) + fmt.echo( + "Following files are MODIFIED locally and CONFLICT with incoming changes: %s" + % fmt.bold(", ".join(conflict_modified)) + ) if conflict_deleted: - fmt.echo("Following files are DELETED locally and CONFLICT with incoming changes: %s" % fmt.bold(", ".join(conflict_deleted))) + fmt.echo( + "Following files are DELETED locally and CONFLICT with incoming changes: %s" + % fmt.bold(", ".join(conflict_deleted)) + ) can_update_files = set(remote_modified.keys()) - set(conflict_modified) can_delete_files = set(remote_deleted.keys()) - set(conflict_deleted) if len(can_update_files) > 0 or len(can_delete_files) > 0: if len(can_update_files) > 0: - fmt.echo("Following files can be automatically UPDATED: %s" % fmt.bold(", ".join(can_update_files))) + fmt.echo( + "Following files can be automatically UPDATED: %s" + % fmt.bold(", ".join(can_update_files)) + ) if len(can_delete_files) > 0: - fmt.echo("Following files can be automatically DELETED: %s" % fmt.bold(", ".join(can_delete_files))) - prompt = "Should incoming changes be Skipped, Applied (local changes will be lost) or Merged (%s UPDATED | %s DELETED | all local changes remain)?" % (fmt.bold(",".join(can_update_files)), fmt.bold(",".join(can_delete_files))) + fmt.echo( + "Following files can be automatically DELETED: %s" + % fmt.bold(", ".join(can_delete_files)) + ) + prompt = ( + "Should incoming changes be Skipped, Applied (local changes will be lost) or Merged (%s" + " UPDATED | %s DELETED | all local changes remain)?" + % (fmt.bold(",".join(can_update_files)), fmt.bold(",".join(can_delete_files))) + ) choices = "sam" else: prompt = "Should incoming changes be Skipped or Applied?" @@ -78,8 +107,8 @@ def _select_source_files( elif resolution == "m": # update what we can fmt.echo("Merging the incoming changes. 
No files with local changes were modified.") - remote_modified = {n:e for n, e in remote_modified.items() if n in can_update_files} - remote_deleted = {n:e for n, e in remote_deleted.items() if n in can_delete_files} + remote_modified = {n: e for n, e in remote_modified.items() if n in can_update_files} + remote_deleted = {n: e for n, e in remote_deleted.items() if n in can_delete_files} else: # fully overwrite, leave all files to be copied fmt.echo("Applying all incoming changes to local files.") @@ -96,7 +125,9 @@ def _get_dependency_system(dest_storage: FileStorage) -> str: return None -def _list_verified_sources(repo_location: str, branch: str = None) -> Dict[str, VerifiedSourceFiles]: +def _list_verified_sources( + repo_location: str, branch: str = None +) -> Dict[str, VerifiedSourceFiles]: clone_storage = git.get_fresh_repo_files(repo_location, get_dlt_repos_dir(), branch=branch) sources_storage = FileStorage(clone_storage.make_full_path(SOURCES_MODULE_NAME)) @@ -110,41 +141,73 @@ def _list_verified_sources(repo_location: str, branch: str = None) -> Dict[str, return sources -def _welcome_message(source_name: str, destination_name: str, source_files: VerifiedSourceFiles, dependency_system: str, is_new_source: bool) -> None: +def _welcome_message( + source_name: str, + destination_name: str, + source_files: VerifiedSourceFiles, + dependency_system: str, + is_new_source: bool, +) -> None: fmt.echo() if source_files.is_template: fmt.echo("Your new pipeline %s is ready to be customized!" % fmt.bold(source_name)) - fmt.echo("* Review and change how dlt loads your data in %s" % fmt.bold(source_files.dest_pipeline_script)) + fmt.echo( + "* Review and change how dlt loads your data in %s" + % fmt.bold(source_files.dest_pipeline_script) + ) else: if is_new_source: fmt.echo("Verified source %s was added to your project!" % fmt.bold(source_name)) - fmt.echo("* See the usage examples and code snippets to copy from %s" % fmt.bold(source_files.dest_pipeline_script)) + fmt.echo( + "* See the usage examples and code snippets to copy from %s" + % fmt.bold(source_files.dest_pipeline_script) + ) else: - fmt.echo("Verified source %s was updated to the newest version!" % fmt.bold(source_name)) + fmt.echo( + "Verified source %s was updated to the newest version!" 
% fmt.bold(source_name) + ) if is_new_source: - fmt.echo("* Add credentials for %s and other secrets in %s" % (fmt.bold(destination_name), fmt.bold(make_dlt_settings_path(SECRETS_TOML)))) + fmt.echo( + "* Add credentials for %s and other secrets in %s" + % (fmt.bold(destination_name), fmt.bold(make_dlt_settings_path(SECRETS_TOML))) + ) if dependency_system: fmt.echo("* Add the required dependencies to %s:" % fmt.bold(dependency_system)) compiled_requirements = source_files.requirements.compiled() for dep in compiled_requirements: fmt.echo(" " + fmt.bold(dep)) - fmt.echo(" If the dlt dependency is already added, make sure you install the extra for %s to it" % fmt.bold(destination_name)) + fmt.echo( + " If the dlt dependency is already added, make sure you install the extra for %s to it" + % fmt.bold(destination_name) + ) if dependency_system == utils.REQUIREMENTS_TXT: qs = "' '" - fmt.echo(" To install with pip: %s" % fmt.bold(f"pip3 install '{qs.join(compiled_requirements)}'")) + fmt.echo( + " To install with pip: %s" + % fmt.bold(f"pip3 install '{qs.join(compiled_requirements)}'") + ) elif dependency_system == utils.PYPROJECT_TOML: fmt.echo(" If you are using poetry you may issue the following command:") fmt.echo(fmt.bold(" poetry add %s -E %s" % (DLT_PKG_NAME, destination_name))) fmt.echo() else: - fmt.echo("* %s was created. Install it with:\npip3 install -r %s" % (fmt.bold(utils.REQUIREMENTS_TXT), utils.REQUIREMENTS_TXT)) + fmt.echo( + "* %s was created. Install it with:\npip3 install -r %s" + % (fmt.bold(utils.REQUIREMENTS_TXT), utils.REQUIREMENTS_TXT) + ) if is_new_source: - fmt.echo("* Read %s for more information" % fmt.bold("https://dlthub.com/docs/walkthroughs/create-a-pipeline")) + fmt.echo( + "* Read %s for more information" + % fmt.bold("https://dlthub.com/docs/walkthroughs/create-a-pipeline") + ) else: - fmt.echo("* Read %s for more information" % fmt.bold("https://dlthub.com/docs/walkthroughs/add-a-verified-source")) + fmt.echo( + "* Read %s for more information" + % fmt.bold("https://dlthub.com/docs/walkthroughs/add-a-verified-source") + ) def list_verified_sources_command(repo_location: str, branch: str = None) -> None: @@ -158,7 +221,13 @@ def list_verified_sources_command(repo_location: str, branch: str = None) -> Non fmt.echo(msg) -def init_command(source_name: str, destination_name: str, use_generic_template: bool, repo_location: str, branch: str = None) -> None: +def init_command( + source_name: str, + destination_name: str, + use_generic_template: bool, + repo_location: str, + branch: str = None, +) -> None: # try to import the destination and get config spec destination_reference = DestinationReference.from_name(destination_name) destination_spec = destination_reference.spec() @@ -192,76 +261,115 @@ def init_command(source_name: str, destination_name: str, use_generic_template: source_files = files_ops.get_verified_source_files(sources_storage, source_name) # get file index from remote verified source files being copied remote_index = files_ops.get_remote_source_index( - source_files.storage.storage_path, source_files.files, source_files.requirements.dlt_version_constraint() + source_files.storage.storage_path, + source_files.files, + source_files.requirements.dlt_version_constraint(), ) # diff local and remote index to get modified and deleted files - remote_new, remote_modified, remote_deleted = files_ops.gen_index_diff(local_index, remote_index) + remote_new, remote_modified, remote_deleted = files_ops.gen_index_diff( + local_index, remote_index + ) # find files 
that are modified locally - conflict_modified, conflict_deleted = files_ops.find_conflict_files(local_index, remote_new, remote_modified, remote_deleted, dest_storage) + conflict_modified, conflict_deleted = files_ops.find_conflict_files( + local_index, remote_new, remote_modified, remote_deleted, dest_storage + ) # add new to modified remote_modified.update(remote_new) if conflict_modified or conflict_deleted: # select source files that can be copied/updated _, remote_modified, remote_deleted = _select_source_files( - source_name, - remote_modified, - remote_deleted, - conflict_modified, - conflict_deleted + source_name, remote_modified, remote_deleted, conflict_modified, conflict_deleted ) if not remote_deleted and not remote_modified: fmt.echo("No files to update, exiting") return if remote_index["is_dirty"]: - fmt.warning(f"The verified sources repository is dirty. {source_name} source files may not update correctly in the future.") + fmt.warning( + f"The verified sources repository is dirty. {source_name} source files may not" + " update correctly in the future." + ) # add template files source_files.files.extend(template_files) else: - if not is_valid_schema_name(source_name): raise InvalidSchemaName(source_name) dest_pipeline_script = source_name + ".py" - source_files = VerifiedSourceFiles(True, init_storage, pipeline_script, dest_pipeline_script, template_files, SourceRequirements([]), "") + source_files = VerifiedSourceFiles( + True, + init_storage, + pipeline_script, + dest_pipeline_script, + template_files, + SourceRequirements([]), + "", + ) if dest_storage.has_file(dest_pipeline_script): fmt.warning("Pipeline script %s already exist, exiting" % dest_pipeline_script) return # add .dlt/*.toml files to be copied - source_files.files.extend([make_dlt_settings_path(CONFIG_TOML), make_dlt_settings_path(SECRETS_TOML)]) + source_files.files.extend( + [make_dlt_settings_path(CONFIG_TOML), make_dlt_settings_path(SECRETS_TOML)] + ) # add dlt extras line to requirements source_files.requirements.update_dlt_extras(destination_name) # Check compatibility with installed dlt if not source_files.requirements.is_installed_dlt_compatible(): - msg = f"This pipeline requires a newer version of dlt than your installed version ({source_files.requirements.current_dlt_version()}). " \ - f"Pipeline requires '{source_files.requirements.dlt_requirement_base}'" + msg = ( + "This pipeline requires a newer version of dlt than your installed version" + f" ({source_files.requirements.current_dlt_version()}). Pipeline requires" + f" '{source_files.requirements.dlt_requirement_base}'" + ) fmt.warning(msg) - if not fmt.confirm("Would you like to continue anyway? (you can update dlt after this step)", default=True): - fmt.echo(f'You can update dlt with: pip3 install -U "{source_files.requirements.dlt_requirement_base}"') + if not fmt.confirm( + "Would you like to continue anyway? 
(you can update dlt after this step)", default=True + ): + fmt.echo( + "You can update dlt with: pip3 install -U" + f' "{source_files.requirements.dlt_requirement_base}"' + ) return # read module source and parse it - visitor = utils.parse_init_script("init", source_files.storage.load(source_files.pipeline_script), source_files.pipeline_script) + visitor = utils.parse_init_script( + "init", + source_files.storage.load(source_files.pipeline_script), + source_files.pipeline_script, + ) if visitor.is_destination_imported: - raise CliCommandException("init", f"The pipeline script {source_files.pipeline_script} import a destination from dlt.destinations. You should specify destinations by name when calling dlt.pipeline or dlt.run in init scripts.") + raise CliCommandException( + "init", + f"The pipeline script {source_files.pipeline_script} import a destination from" + " dlt.destinations. You should specify destinations by name when calling dlt.pipeline" + " or dlt.run in init scripts.", + ) if n.PIPELINE not in visitor.known_calls: - raise CliCommandException("init", f"The pipeline script {source_files.pipeline_script} does not seem to initialize pipeline with dlt.pipeline. Please initialize pipeline explicitly in init scripts.") + raise CliCommandException( + "init", + f"The pipeline script {source_files.pipeline_script} does not seem to initialize" + " pipeline with dlt.pipeline. Please initialize pipeline explicitly in init scripts.", + ) # find all arguments in all calls to replace transformed_nodes = source_detection.find_call_arguments_to_replace( visitor, - [("destination", destination_name), ("pipeline_name", source_name), ("dataset_name", source_name + "_data")], - source_files.pipeline_script + [ + ("destination", destination_name), + ("pipeline_name", source_name), + ("dataset_name", source_name + "_data"), + ], + source_files.pipeline_script, ) # inspect the script inspect_pipeline_script( source_files.storage.storage_path, source_files.storage.to_relative_path(source_files.pipeline_script), - ignore_missing_imports=True + ignore_missing_imports=True, ) # detect all the required secrets and configs that should go into tomls files @@ -269,32 +377,57 @@ def init_command(source_name: str, destination_name: str, use_generic_template: # replace destination, pipeline_name and dataset_name in templates transformed_nodes = source_detection.find_call_arguments_to_replace( visitor, - [("destination", destination_name), ("pipeline_name", source_name), ("dataset_name", source_name + "_data")], - source_files.pipeline_script + [ + ("destination", destination_name), + ("pipeline_name", source_name), + ("dataset_name", source_name + "_data"), + ], + source_files.pipeline_script, ) # template sources are always in module starting with "pipeline" # for templates, place config and secrets into top level section - required_secrets, required_config, checked_sources = source_detection.detect_source_configs(_SOURCES, "pipeline", ()) + required_secrets, required_config, checked_sources = source_detection.detect_source_configs( + _SOURCES, "pipeline", () + ) # template has a strict rules where sources are placed for source_q_name, source_config in checked_sources.items(): if source_q_name not in visitor.known_sources_resources: - raise CliCommandException("init", f"The pipeline script {source_files.pipeline_script} imports a source/resource {source_config.f.__name__} from module {source_config.module.__name__}. 
In init scripts you must declare all sources and resources in single file.") + raise CliCommandException( + "init", + f"The pipeline script {source_files.pipeline_script} imports a source/resource" + f" {source_config.f.__name__} from module {source_config.module.__name__}. In" + " init scripts you must declare all sources and resources in single file.", + ) # rename sources and resources - transformed_nodes.extend(source_detection.find_source_calls_to_replace(visitor, source_name)) + transformed_nodes.extend( + source_detection.find_source_calls_to_replace(visitor, source_name) + ) else: # replace only destination for existing pipelines - transformed_nodes = source_detection.find_call_arguments_to_replace(visitor, [("destination", destination_name)], source_files.pipeline_script) + transformed_nodes = source_detection.find_call_arguments_to_replace( + visitor, [("destination", destination_name)], source_files.pipeline_script + ) # pipeline sources are in module with name starting from {pipeline_name} # for verified pipelines place in the specific source section - required_secrets, required_config, checked_sources = source_detection.detect_source_configs(_SOURCES, source_name, (known_sections.SOURCES, source_name)) + required_secrets, required_config, checked_sources = source_detection.detect_source_configs( + _SOURCES, source_name, (known_sections.SOURCES, source_name) + ) if len(checked_sources) == 0: - raise CliCommandException("init", f"The pipeline script {source_files.pipeline_script} is not creating or importing any sources or resources") + raise CliCommandException( + "init", + f"The pipeline script {source_files.pipeline_script} is not creating or importing any" + " sources or resources", + ) # add destination spec to required secrets - required_secrets["destinations:" + destination_name] = WritableConfigValue(destination_name, destination_spec, None, ("destination",)) + required_secrets["destinations:" + destination_name] = WritableConfigValue( + destination_name, destination_spec, None, ("destination",) + ) # add the global telemetry to required config - required_config["runtime.dlthub_telemetry"] = WritableConfigValue("dlthub_telemetry", bool, utils.get_telemetry_status(), ("runtime", )) + required_config["runtime.dlthub_telemetry"] = WritableConfigValue( + "dlthub_telemetry", bool, utils.get_telemetry_status(), ("runtime",) + ) # modify the script script_lines = rewrite_python_script(visitor.source_lines, transformed_nodes) @@ -305,9 +438,15 @@ def init_command(source_name: str, destination_name: str, use_generic_template: # ask for confirmation if is_new_source: if source_files.is_template: - fmt.echo("A verified source %s was not found. Using a template to create a new source and pipeline with name %s." % (fmt.bold(source_name), fmt.bold(source_name))) + fmt.echo( + "A verified source %s was not found. Using a template to create a new source and" + " pipeline with name %s." 
% (fmt.bold(source_name), fmt.bold(source_name)) + ) else: - fmt.echo("Cloning and configuring a verified source %s (%s)" % (fmt.bold(source_name), source_files.doc)) + fmt.echo( + "Cloning and configuring a verified source %s (%s)" + % (fmt.bold(source_name), source_files.doc) + ) if use_generic_template: fmt.warning("--generic parameter is meaningless if verified source is found") if not fmt.confirm("Do you want to proceed?", default=True): @@ -339,7 +478,9 @@ def init_command(source_name: str, destination_name: str, use_generic_template: for file_name in remote_deleted: if dest_storage.has_file(file_name): dest_storage.delete(file_name) - files_ops.save_verified_source_local_index(source_name, remote_index, remote_modified, remote_deleted) + files_ops.save_verified_source_local_index( + source_name, remote_index, remote_modified, remote_deleted + ) # create script if not dest_storage.has_file(source_files.dest_pipeline_script): dest_storage.save(source_files.dest_pipeline_script, dest_script_source) diff --git a/dlt/cli/pipeline_command.py b/dlt/cli/pipeline_command.py index 52a9c8ffdc..fb90054d3f 100644 --- a/dlt/cli/pipeline_command.py +++ b/dlt/cli/pipeline_command.py @@ -1,25 +1,33 @@ -import yaml from typing import Any + +import yaml + import dlt +from dlt.cli import echo as fmt from dlt.cli.exceptions import CliCommandException - from dlt.common import json -from dlt.common.pipeline import resource_state, get_dlt_pipelines_dir, TSourceState from dlt.common.destination.reference import TDestinationReferenceArg +from dlt.common.pipeline import TSourceState, get_dlt_pipelines_dir, resource_state from dlt.common.runners import Venv from dlt.common.runners.stdout import iter_stdout from dlt.common.schema.utils import group_tables_by_resource, remove_defaults from dlt.common.storages.file_storage import FileStorage from dlt.common.typing import DictStrAny -from dlt.pipeline.helpers import DropCommand from dlt.pipeline.exceptions import CannotRestorePipelineException - -from dlt.cli import echo as fmt +from dlt.pipeline.helpers import DropCommand DLT_PIPELINE_COMMAND_DOCS_URL = "https://dlthub.com/docs/reference/command-line-interface" -def pipeline_command(operation: str, pipeline_name: str, pipelines_dir: str, verbosity: int, dataset_name: str = None, destination: TDestinationReferenceArg = None, **command_kwargs: Any) -> None: +def pipeline_command( + operation: str, + pipeline_name: str, + pipelines_dir: str, + verbosity: int, + dataset_name: str = None, + destination: TDestinationReferenceArg = None, + **command_kwargs: Any, +) -> None: if operation == "list": pipelines_dir = pipelines_dir or get_dlt_pipelines_dir() storage = FileStorage(pipelines_dir) @@ -38,16 +46,26 @@ def pipeline_command(operation: str, pipeline_name: str, pipelines_dir: str, ver if operation not in {"sync", "drop"}: raise fmt.warning(str(e)) - if not fmt.confirm("Do you want to attempt to restore the pipeline state from destination?", default=False): + if not fmt.confirm( + "Do you want to attempt to restore the pipeline state from destination?", default=False + ): return - destination = destination or fmt.text_input(f"Enter destination name for pipeline {fmt.bold(pipeline_name)}") - dataset_name = dataset_name or fmt.text_input(f"Enter dataset name for pipeline {fmt.bold(pipeline_name)}") - p = dlt.pipeline(pipeline_name, pipelines_dir, destination=destination, dataset_name=dataset_name) + destination = destination or fmt.text_input( + f"Enter destination name for pipeline {fmt.bold(pipeline_name)}" + ) + 
dataset_name = dataset_name or fmt.text_input( + f"Enter dataset name for pipeline {fmt.bold(pipeline_name)}" + ) + p = dlt.pipeline( + pipeline_name, pipelines_dir, destination=destination, dataset_name=dataset_name + ) p.sync_destination() if p.first_run: # remote state was not found p._wipe_working_folder() - fmt.error(f"Pipeline {pipeline_name} was not found in dataset {dataset_name} in {destination}") + fmt.error( + f"Pipeline {pipeline_name} was not found in dataset {dataset_name} in {destination}" + ) return if operation == "sync": return # No need to sync again @@ -60,7 +78,9 @@ def pipeline_command(operation: str, pipeline_name: str, pipelines_dir: str, ver with signals.delayed_signals(): venv = Venv.restore_current() - for line in iter_stdout(venv, "streamlit", "run", streamlit_helper.__file__, pipeline_name): + for line in iter_stdout( + venv, "streamlit", "run", streamlit_helper.__file__, pipeline_name + ): fmt.echo(line) if operation == "info": @@ -88,32 +108,52 @@ def pipeline_command(operation: str, pipeline_name: str, pipelines_dir: str, ver fmt.warning("This pipeline does not have a default schema") else: is_single_schema = len(p.schema_names) == 1 - for schema_name in p.schema_names: + for schema_name in p.schema_names: fmt.echo("Resources in schema: %s" % fmt.bold(schema_name)) schema = p.schemas[schema_name] data_tables = {t["name"]: t for t in schema.data_tables()} for resource_name, tables in group_tables_by_resource(data_tables).items(): res_state_slots = 0 if sources_state: - source_state = next(iter(sources_state.items()))[1] if is_single_schema else sources_state.get(schema_name) + source_state = ( + next(iter(sources_state.items()))[1] + if is_single_schema + else sources_state.get(schema_name) + ) if source_state: resource_state_ = resource_state(resource_name, source_state) res_state_slots = len(resource_state_) - fmt.echo("%s with %s table(s) and %s resource state slot(s)" % (fmt.bold(resource_name), fmt.bold(str(len(tables))), fmt.bold(str(res_state_slots)))) + fmt.echo( + "%s with %s table(s) and %s resource state slot(s)" + % ( + fmt.bold(resource_name), + fmt.bold(str(len(tables))), + fmt.bold(str(res_state_slots)), + ) + ) fmt.echo() fmt.echo("Working dir content:") extracted_files = p.list_extracted_resources() if extracted_files: - fmt.echo("Has %s extracted files ready to be normalized" % fmt.bold(str(len(extracted_files)))) + fmt.echo( + "Has %s extracted files ready to be normalized" + % fmt.bold(str(len(extracted_files))) + ) norm_packages = p.list_normalized_load_packages() if norm_packages: - fmt.echo("Has %s load packages ready to be loaded with following load ids:" % fmt.bold(str(len(norm_packages)))) + fmt.echo( + "Has %s load packages ready to be loaded with following load ids:" + % fmt.bold(str(len(norm_packages))) + ) for load_id in norm_packages: fmt.echo(load_id) fmt.echo() loaded_packages = p.list_completed_load_packages() if loaded_packages: - fmt.echo("Has %s completed load packages with following load ids:" % fmt.bold(str(len(loaded_packages)))) + fmt.echo( + "Has %s completed load packages with following load ids:" + % fmt.bold(str(len(loaded_packages))) + ) for load_id in loaded_packages: fmt.echo(load_id) fmt.echo() @@ -121,7 +161,10 @@ def pipeline_command(operation: str, pipeline_name: str, pipelines_dir: str, ver if trace is None or len(trace.steps) == 0: fmt.echo("Pipeline does not have last run trace.") else: - fmt.echo("Pipeline has last run trace. 
Use 'dlt pipeline %s trace' to inspect " % pipeline_name) + fmt.echo( + "Pipeline has last run trace. Use 'dlt pipeline %s trace' to inspect " + % pipeline_name + ) if operation == "trace": trace = p.last_trace @@ -138,7 +181,13 @@ def pipeline_command(operation: str, pipeline_name: str, pipelines_dir: str, ver failed_jobs = p.list_failed_jobs_in_package(load_id) if failed_jobs: for failed_job in p.list_failed_jobs_in_package(load_id): - fmt.echo("JOB: %s(%s)" % (fmt.bold(failed_job.job_file_info.job_id()), fmt.bold(failed_job.job_file_info.table_name))) + fmt.echo( + "JOB: %s(%s)" + % ( + fmt.bold(failed_job.job_file_info.job_id()), + fmt.bold(failed_job.job_file_info.table_name), + ) + ) fmt.echo("JOB file type: %s" % fmt.bold(failed_job.job_file_info.file_format)) fmt.echo("JOB file path: %s" % fmt.bold(failed_job.file_path)) if verbosity > 0: @@ -148,26 +197,33 @@ def pipeline_command(operation: str, pipeline_name: str, pipelines_dir: str, ver else: fmt.echo("No failed jobs found") - if operation == "sync": - if fmt.confirm("About to drop the local state of the pipeline and reset all the schemas. The destination state, data and schemas are left intact. Proceed?", default=False): + if fmt.confirm( + "About to drop the local state of the pipeline and reset all the schemas. The" + " destination state, data and schemas are left intact. Proceed?", + default=False, + ): fmt.echo("Dropping local state") p = p.drop() fmt.echo("Restoring from destination") p.sync_destination() if operation == "load-package": - load_id = command_kwargs.get('load_id') + load_id = command_kwargs.get("load_id") if not load_id: packages = sorted(p.list_normalized_load_packages()) if not packages: packages = sorted(p.list_completed_load_packages()) if not packages: - raise CliCommandException("pipeline", "There are no load packages for that pipeline") + raise CliCommandException( + "pipeline", "There are no load packages for that pipeline" + ) load_id = packages[-1] package_info = p.get_load_package_info(load_id) - fmt.echo("Package %s found in %s" % (fmt.bold(load_id), fmt.bold(package_info.package_path))) + fmt.echo( + "Package %s found in %s" % (fmt.bold(load_id), fmt.bold(package_info.package_path)) + ) fmt.echo(package_info.asstr(verbosity)) if len(package_info.schema_update) > 0: if verbosity == 0: @@ -175,7 +231,9 @@ def pipeline_command(operation: str, pipeline_name: str, pipelines_dir: str, ver else: tables = remove_defaults({"tables": package_info.schema_update}) # type: ignore fmt.echo(fmt.bold("Schema update:")) - fmt.echo(yaml.dump(tables, allow_unicode=True, default_flow_style=False, sort_keys=False)) + fmt.echo( + yaml.dump(tables, allow_unicode=True, default_flow_style=False, sort_keys=False) + ) if operation == "schema": if not p.default_schema_name: @@ -188,7 +246,10 @@ def pipeline_command(operation: str, pipeline_name: str, pipelines_dir: str, ver if operation == "drop": drop = DropCommand(p, **command_kwargs) if drop.is_empty: - fmt.echo("Could not select any resources to drop and no resource/source state to reset. Use the command below to inspect the pipeline:") + fmt.echo( + "Could not select any resources to drop and no resource/source state to reset. 
Use" + " the command below to inspect the pipeline:" + ) fmt.echo(f"dlt pipeline -v {p.pipeline_name} info") if len(drop.info["warnings"]): fmt.echo("Additional warnings are available") @@ -196,12 +257,23 @@ def pipeline_command(operation: str, pipeline_name: str, pipelines_dir: str, ver fmt.warning(warning) return - fmt.echo("About to drop the following data in dataset %s in destination %s:" % (fmt.bold(drop.info["dataset_name"]), fmt.bold(p.destination.__name__))) + fmt.echo( + "About to drop the following data in dataset %s in destination %s:" + % (fmt.bold(drop.info["dataset_name"]), fmt.bold(p.destination.__name__)) + ) fmt.echo("%s: %s" % (fmt.style("Selected schema", fg="green"), drop.info["schema_name"])) - fmt.echo("%s: %s" % (fmt.style("Selected resource(s)", fg="green"), drop.info["resource_names"])) + fmt.echo( + "%s: %s" % (fmt.style("Selected resource(s)", fg="green"), drop.info["resource_names"]) + ) fmt.echo("%s: %s" % (fmt.style("Table(s) to drop", fg="green"), drop.info["tables"])) - fmt.echo("%s: %s" % (fmt.style("Resource(s) state to reset", fg="green"), drop.info["resource_states"])) - fmt.echo("%s: %s" % (fmt.style("Source state path(s) to reset", fg="green"), drop.info["state_paths"])) + fmt.echo( + "%s: %s" + % (fmt.style("Resource(s) state to reset", fg="green"), drop.info["resource_states"]) + ) + fmt.echo( + "%s: %s" + % (fmt.style("Source state path(s) to reset", fg="green"), drop.info["state_paths"]) + ) # for k, v in drop.info.items(): # fmt.echo("%s: %s" % (fmt.style(k, fg="green"), v)) for warning in drop.info["warnings"]: diff --git a/dlt/cli/pipeline_files.py b/dlt/cli/pipeline_files.py index acd3a95e80..73aa967627 100644 --- a/dlt/cli/pipeline_files.py +++ b/dlt/cli/pipeline_files.py @@ -1,21 +1,19 @@ import fnmatch import hashlib import os -import yaml import posixpath from pathlib import Path -from typing import Dict, NamedTuple, Sequence, Tuple, TypedDict, List -from dlt.cli.exceptions import VerifiedSourceRepoError +from typing import Dict, List, NamedTuple, Sequence, Tuple, TypedDict -from dlt.common import git -from dlt.common.configuration.paths import make_dlt_settings_path -from dlt.common.storages import FileStorage - -from dlt.common.reflection.utils import get_module_docstring +import yaml from dlt.cli import utils +from dlt.cli.exceptions import VerifiedSourceRepoError from dlt.cli.requirements import SourceRequirements - +from dlt.common import git +from dlt.common.configuration.paths import make_dlt_settings_path +from dlt.common.reflection.utils import get_module_docstring +from dlt.common.storages import FileStorage SOURCES_INIT_INFO_ENGINE_VERSION = 1 SOURCES_INIT_INFO_FILE = ".sources" @@ -65,17 +63,14 @@ def _load_dot_sources() -> TVerifiedSourcesFileIndex: raise FileNotFoundError(SOURCES_INIT_INFO_FILE) return index except FileNotFoundError: - return { - "engine_version": SOURCES_INIT_INFO_ENGINE_VERSION, - "sources": {} - } + return {"engine_version": SOURCES_INIT_INFO_ENGINE_VERSION, "sources": {}} def _merge_remote_index( local_index: TVerifiedSourceFileIndex, remote_index: TVerifiedSourceFileIndex, remote_modified: Dict[str, TVerifiedSourceFileEntry], - remote_deleted: Dict[str, TVerifiedSourceFileEntry] + remote_deleted: Dict[str, TVerifiedSourceFileEntry], ) -> TVerifiedSourceFileIndex: # update all modified files local_index["files"].update(remote_modified) @@ -92,13 +87,15 @@ def _merge_remote_index( def load_verified_sources_local_index(source_name: str) -> TVerifiedSourceFileIndex: - return 
_load_dot_sources()["sources"].get(source_name, { - "is_dirty": False, - "last_commit_sha": None, - "last_commit_timestamp": None, - "files": {}, - "dlt_version_constraint": ">=0.1.0" - } + return _load_dot_sources()["sources"].get( + source_name, + { + "is_dirty": False, + "last_commit_sha": None, + "last_commit_timestamp": None, + "files": {}, + "dlt_version_constraint": ">=0.1.0", + }, ) @@ -106,17 +103,17 @@ def save_verified_source_local_index( source_name: str, remote_index: TVerifiedSourceFileIndex, remote_modified: Dict[str, TVerifiedSourceFileEntry], - remote_deleted: Dict[str, TVerifiedSourceFileEntry] + remote_deleted: Dict[str, TVerifiedSourceFileEntry], ) -> None: - all_sources = _load_dot_sources() local_index = all_sources["sources"].setdefault(source_name, remote_index) _merge_remote_index(local_index, remote_index, remote_modified, remote_deleted) _save_dot_sources(all_sources) -def get_remote_source_index(repo_path: str, files: Sequence[str], dlt_version_constraint: str) -> TVerifiedSourceFileIndex: - +def get_remote_source_index( + repo_path: str, files: Sequence[str], dlt_version_constraint: str +) -> TVerifiedSourceFileIndex: with git.get_repo(repo_path) as repo: tree = repo.tree() commit_sha = repo.head.commit.hexsha @@ -136,7 +133,7 @@ def get_remote_source_index(repo_path: str, files: Sequence[str], dlt_version_co files_sha[file] = { "commit_sha": commit_sha, "git_sha": blob_sha3, - "sha3_256": hashlib.sha3_256(file_blob).hexdigest() + "sha3_256": hashlib.sha3_256(file_blob).hexdigest(), } return { @@ -144,26 +141,37 @@ def get_remote_source_index(repo_path: str, files: Sequence[str], dlt_version_co "last_commit_sha": commit_sha, "last_commit_timestamp": repo.head.commit.committed_datetime.isoformat(), "files": files_sha, - "dlt_version_constraint": dlt_version_constraint + "dlt_version_constraint": dlt_version_constraint, } def get_verified_source_names(sources_storage: FileStorage) -> List[str]: candidates: List[str] = [] - for name in [n for n in sources_storage.list_folder_dirs(".", to_root=False) if not any(fnmatch.fnmatch(n, ignore) for ignore in IGNORE_SOURCES)]: + for name in [ + n + for n in sources_storage.list_folder_dirs(".", to_root=False) + if not any(fnmatch.fnmatch(n, ignore) for ignore in IGNORE_SOURCES) + ]: # must contain at least one valid python script if any(f.endswith(".py") for f in sources_storage.list_folder_files(name, to_root=False)): candidates.append(name) return candidates -def get_verified_source_files(sources_storage: FileStorage, source_name: str) -> VerifiedSourceFiles: +def get_verified_source_files( + sources_storage: FileStorage, source_name: str +) -> VerifiedSourceFiles: if not sources_storage.has_folder(source_name): - raise VerifiedSourceRepoError(f"Verified source {source_name} could not be found in the repository", source_name) + raise VerifiedSourceRepoError( + f"Verified source {source_name} could not be found in the repository", source_name + ) # find example script example_script = f"{source_name}_pipeline.py" if not sources_storage.has_file(example_script): - raise VerifiedSourceRepoError(f"Pipeline example script {example_script} could not be found in the repository", source_name) + raise VerifiedSourceRepoError( + f"Pipeline example script {example_script} could not be found in the repository", + source_name, + ) # get all files recursively files: List[str] = [] for root, subdirs, _files in os.walk(sources_storage.make_full_path(source_name)): @@ -172,9 +180,15 @@ def get_verified_source_files(sources_storage: 
FileStorage, source_name: str) -> if any(fnmatch.fnmatch(subdir, ignore) for ignore in IGNORE_FILES): subdirs.remove(subdir) rel_root = sources_storage.to_relative_path(root) - files.extend([os.path.join(rel_root, file) for file in _files if all(not fnmatch.fnmatch(file, ignore) for ignore in IGNORE_FILES)]) + files.extend( + [ + os.path.join(rel_root, file) + for file in _files + if all(not fnmatch.fnmatch(file, ignore) for ignore in IGNORE_FILES) + ] + ) # read the docs - init_py = os.path.join(source_name, utils.MODULE_INIT) + init_py = os.path.join(source_name, utils.MODULE_INIT) docstring: str = "" if sources_storage.has_file(init_py): docstring = get_module_docstring(sources_storage.load(init_py)) @@ -187,14 +201,18 @@ def get_verified_source_files(sources_storage: FileStorage, source_name: str) -> else: requirements = SourceRequirements([]) # find requirements - return VerifiedSourceFiles(False, sources_storage, example_script, example_script, files, requirements, docstring) + return VerifiedSourceFiles( + False, sources_storage, example_script, example_script, files, requirements, docstring + ) def gen_index_diff( - local_index: TVerifiedSourceFileIndex, - remote_index: TVerifiedSourceFileIndex -) -> Tuple[Dict[str, TVerifiedSourceFileEntry], Dict[str, TVerifiedSourceFileEntry], Dict[str, TVerifiedSourceFileEntry]]: - + local_index: TVerifiedSourceFileIndex, remote_index: TVerifiedSourceFileIndex +) -> Tuple[ + Dict[str, TVerifiedSourceFileEntry], + Dict[str, TVerifiedSourceFileEntry], + Dict[str, TVerifiedSourceFileEntry], +]: deleted: Dict[str, TVerifiedSourceFileEntry] = {} modified: Dict[str, TVerifiedSourceFileEntry] = {} new: Dict[str, TVerifiedSourceFileEntry] = {} @@ -223,7 +241,7 @@ def find_conflict_files( remote_new: Dict[str, TVerifiedSourceFileEntry], remote_modified: Dict[str, TVerifiedSourceFileEntry], remote_deleted: Dict[str, TVerifiedSourceFileEntry], - dest_storage: FileStorage + dest_storage: FileStorage, ) -> Tuple[List[str], List[str]]: """Use files index from .sources to identify modified files via sha3 content hash""" diff --git a/dlt/cli/requirements.py b/dlt/cli/requirements.py index 79907ae01c..cf804f95fc 100644 --- a/dlt/cli/requirements.py +++ b/dlt/cli/requirements.py @@ -1,5 +1,6 @@ -from typing import Sequence, List from importlib.metadata import version as pkg_version +from typing import List, Sequence + from packaging.requirements import Requirement from dlt.version import DLT_PKG_NAME @@ -7,6 +8,7 @@ class SourceRequirements: """Helper class to parse and manipulate entries in source's requirements.txt""" + dlt_requirement: Requirement """Final dlt requirement that may be updated with destination extras""" dlt_requirement_base: Requirement diff --git a/dlt/cli/source_detection.py b/dlt/cli/source_detection.py index 369663b82f..dec0e736b0 100644 --- a/dlt/cli/source_detection.py +++ b/dlt/cli/source_detection.py @@ -1,20 +1,22 @@ import ast import inspect +from typing import Dict, List, Set, Tuple + from astunparse import unparse -from typing import Dict, Tuple, Set, List +from dlt.cli.config_toml_writer import WritableConfigValue +from dlt.cli.exceptions import CliCommandException from dlt.common.configuration import is_secret_hint from dlt.common.configuration.specs import BaseConfiguration from dlt.common.reflection.utils import creates_func_def_name_node -from dlt.common.typing import is_optional_type from dlt.common.source import SourceInfo - -from dlt.cli.config_toml_writer import WritableConfigValue -from dlt.cli.exceptions import 
CliCommandException +from dlt.common.typing import is_optional_type from dlt.reflection.script_visitor import PipelineScriptVisitor -def find_call_arguments_to_replace(visitor: PipelineScriptVisitor, replace_nodes: List[Tuple[str, str]], init_script_name: str) -> List[Tuple[ast.AST, ast.AST]]: +def find_call_arguments_to_replace( + visitor: PipelineScriptVisitor, replace_nodes: List[Tuple[str, str]], init_script_name: str +) -> List[Tuple[ast.AST, ast.AST]]: # the input tuple (call argument name, replacement value) # the returned tuple (node, replacement value, node type) transformed_nodes: List[Tuple[ast.AST, ast.AST]] = [] @@ -26,7 +28,11 @@ def find_call_arguments_to_replace(visitor: PipelineScriptVisitor, replace_nodes dn_node: ast.AST = args.arguments.get(t_arg_name) if dn_node is not None: if not isinstance(dn_node, ast.Constant) or not isinstance(dn_node.value, str): - raise CliCommandException("init", f"The pipeline script {init_script_name} must pass the {t_arg_name} as string to '{arg_name}' function in line {dn_node.lineno}") + raise CliCommandException( + "init", + f"The pipeline script {init_script_name} must pass the {t_arg_name} as" + f" string to '{arg_name}' function in line {dn_node.lineno}", + ) else: transformed_nodes.append((dn_node, ast.Constant(value=t_value, kind=None))) replaced_args.add(t_arg_name) @@ -34,27 +40,39 @@ def find_call_arguments_to_replace(visitor: PipelineScriptVisitor, replace_nodes # there was at least one replacement for t_arg_name, _ in replace_nodes: if t_arg_name not in replaced_args: - raise CliCommandException("init", f"The pipeline script {init_script_name} is not explicitly passing the '{t_arg_name}' argument to 'pipeline' or 'run' function. In init script the default and configured values are not accepted.") + raise CliCommandException( + "init", + f"The pipeline script {init_script_name} is not explicitly passing the" + f" '{t_arg_name}' argument to 'pipeline' or 'run' function. In init script the" + " default and configured values are not accepted.", + ) return transformed_nodes -def find_source_calls_to_replace(visitor: PipelineScriptVisitor, pipeline_name: str) -> List[Tuple[ast.AST, ast.AST]]: +def find_source_calls_to_replace( + visitor: PipelineScriptVisitor, pipeline_name: str +) -> List[Tuple[ast.AST, ast.AST]]: transformed_nodes: List[Tuple[ast.AST, ast.AST]] = [] for source_def in visitor.known_sources_resources.values(): # append function name to be replaced - transformed_nodes.append((creates_func_def_name_node(source_def, visitor.source_lines), ast.Name(id=pipeline_name + "_" + source_def.name))) + transformed_nodes.append( + ( + creates_func_def_name_node(source_def, visitor.source_lines), + ast.Name(id=pipeline_name + "_" + source_def.name), + ) + ) for calls in visitor.known_sources_resources_calls.values(): for call in calls: - transformed_nodes.append((call.func, ast.Name(id=pipeline_name + "_" + unparse(call.func)))) + transformed_nodes.append( + (call.func, ast.Name(id=pipeline_name + "_" + unparse(call.func))) + ) return transformed_nodes def detect_source_configs( - sources: Dict[str, SourceInfo], - module_prefix: str, - section: Tuple[str, ...] + sources: Dict[str, SourceInfo], module_prefix: str, section: Tuple[str, ...] 
) -> Tuple[Dict[str, WritableConfigValue], Dict[str, WritableConfigValue], Dict[str, SourceInfo]]: # all detected secrets with sections required_secrets: Dict[str, WritableConfigValue] = {} @@ -75,11 +93,15 @@ def detect_source_configs( if is_secret_hint(field_type): val_store = required_secrets # all configs that are required and do not have a default value must go to config.toml - elif not is_optional_type(field_type) and getattr(source_config, field_name) is None: + elif ( + not is_optional_type(field_type) and getattr(source_config, field_name) is None + ): val_store = required_config if val_store is not None: # we are sure that all resources come from single file so we can put them in single section - val_store[source_name + ":" + field_name] = WritableConfigValue(field_name, field_type, None, section) + val_store[source_name + ":" + field_name] = WritableConfigValue( + field_name, field_type, None, section + ) return required_secrets, required_config, checked_sources diff --git a/dlt/cli/telemetry_command.py b/dlt/cli/telemetry_command.py index 574005797a..d281796ff8 100644 --- a/dlt/cli/telemetry_command.py +++ b/dlt/cli/telemetry_command.py @@ -1,13 +1,12 @@ import os +from dlt.cli import echo as fmt +from dlt.cli.config_toml_writer import WritableConfigValue, write_values +from dlt.cli.utils import get_telemetry_status from dlt.common.configuration import resolve_configuration from dlt.common.configuration.container import Container from dlt.common.configuration.providers.toml import ConfigTomlProvider from dlt.common.configuration.specs import RunConfiguration - -from dlt.cli import echo as fmt -from dlt.cli.utils import get_telemetry_status -from dlt.cli.config_toml_writer import WritableConfigValue, write_values from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext from dlt.common.runtime.segment import get_anonymous_id @@ -24,7 +23,9 @@ def telemetry_status_command() -> None: def change_telemetry_status_command(enabled: bool) -> None: # value to write - telemetry_value = [WritableConfigValue("dlthub_telemetry", bool, enabled, (RunConfiguration.__section__, ))] + telemetry_value = [ + WritableConfigValue("dlthub_telemetry", bool, enabled, (RunConfiguration.__section__,)) + ] # write local config config = ConfigTomlProvider(add_global_config=False) if not config.is_empty: diff --git a/dlt/cli/utils.py b/dlt/cli/utils.py index 996770b40d..934c9d1816 100644 --- a/dlt/cli/utils.py +++ b/dlt/cli/utils.py @@ -3,19 +3,16 @@ import tempfile from typing import Callable +from dlt.cli.exceptions import CliCommandException from dlt.common import git -from dlt.common.reflection.utils import set_ast_parents -from dlt.common.storages import FileStorage -from dlt.common.typing import TFun from dlt.common.configuration import resolve_configuration from dlt.common.configuration.specs import RunConfiguration +from dlt.common.reflection.utils import set_ast_parents from dlt.common.runtime.telemetry import with_telemetry - +from dlt.common.storages import FileStorage +from dlt.common.typing import TFun from dlt.reflection.script_visitor import PipelineScriptVisitor -from dlt.cli.exceptions import CliCommandException - - REQUIREMENTS_TXT = "requirements.txt" PYPROJECT_TOML = "pyproject.toml" GITHUB_WORKFLOWS_DIR = os.path.join(".github", "workflows") @@ -25,14 +22,20 @@ MODULE_INIT = "__init__.py" -def parse_init_script(command: str, script_source: str, init_script_name: str) -> PipelineScriptVisitor: +def parse_init_script( + command: str, script_source: str, 
init_script_name: str +) -> PipelineScriptVisitor: # parse the script first tree = ast.parse(source=script_source) set_ast_parents(tree) visitor = PipelineScriptVisitor(script_source) visitor.visit_passes(tree) if len(visitor.mod_aliases) == 0: - raise CliCommandException(command, f"The pipeline script {init_script_name} does not import dlt and does not seem to run any pipelines") + raise CliCommandException( + command, + f"The pipeline script {init_script_name} does not import dlt and does not seem to run" + " any pipelines", + ) return visitor @@ -45,8 +48,9 @@ def ensure_git_command(command: str) -> None: raise raise CliCommandException( command, - "'git' command is not available. Install and setup git with the following the guide %s" % "https://docs.github.com/en/get-started/quickstart/set-up-git", - imp_ex + "'git' command is not available. Install and setup git with the following the guide %s" + % "https://docs.github.com/en/get-started/quickstart/set-up-git", + imp_ex, ) from imp_ex diff --git a/dlt/common/__init__.py b/dlt/common/__init__.py index 222cb3d5d6..3ef723ecf9 100644 --- a/dlt/common/__init__.py +++ b/dlt/common/__init__.py @@ -1,6 +1,6 @@ from dlt.common.arithmetics import Decimal # noqa: F401 -from dlt.common.wei import Wei -from dlt.common.pendulum import pendulum # noqa: F401 from dlt.common.json import json # noqa: F401, I251 -from dlt.common.runtime.signals import sleep # noqa: F401 -from dlt.common.runtime import logger # noqa: F401 \ No newline at end of file +from dlt.common.pendulum import pendulum # noqa: F401 +from dlt.common.runtime import logger # noqa: F401 +from dlt.common.runtime.signals import sleep # noqa: F401 +from dlt.common.wei import Wei diff --git a/dlt/common/arithmetics.py b/dlt/common/arithmetics.py index 5277acad4f..2d6115b2b3 100644 --- a/dlt/common/arithmetics.py +++ b/dlt/common/arithmetics.py @@ -1,8 +1,18 @@ -import decimal # noqa: I251 +import decimal # noqa: I251 from contextlib import contextmanager +from decimal import ( # noqa: I251 + ROUND_HALF_UP, + Context, + ConversionSyntax, + Decimal, + DefaultContext, + DivisionByZero, + Inexact, + InvalidOperation, + Subnormal, + localcontext, +) from typing import Iterator -from decimal import ROUND_HALF_UP, Decimal, Inexact, DivisionByZero, DefaultContext, InvalidOperation, localcontext, Context, Subnormal, ConversionSyntax # noqa: I251 - DEFAULT_NUMERIC_PRECISION = 38 DEFAULT_NUMERIC_SCALE = 9 diff --git a/dlt/common/configuration/__init__.py b/dlt/common/configuration/__init__.py index 2f22314ff4..6237eaaac1 100644 --- a/dlt/common/configuration/__init__.py +++ b/dlt/common/configuration/__init__.py @@ -1,11 +1,15 @@ -from .specs.base_configuration import configspec, is_valid_hint, is_secret_hint, resolve_type # noqa: F401 -from .specs import known_sections # noqa: F401 -from .resolve import resolve_configuration, inject_section # noqa: F401 -from .inject import with_config, last_config, get_fun_spec # noqa: F401 - from .exceptions import ( # noqa: F401 ConfigFieldMissingException, - ConfigValueCannotBeCoercedException, ConfigFileNotFoundException, - ConfigurationValueError + ConfigurationValueError, + ConfigValueCannotBeCoercedException, +) +from .inject import get_fun_spec, last_config, with_config # noqa: F401 +from .resolve import inject_section, resolve_configuration # noqa: F401 +from .specs import known_sections # noqa: F401 +from .specs.base_configuration import ( # noqa: F401 + configspec, + is_secret_hint, + is_valid_hint, + resolve_type, ) diff --git 
a/dlt/common/configuration/accessors.py b/dlt/common/configuration/accessors.py index cf71db7030..c5215b0b2a 100644 --- a/dlt/common/configuration/accessors.py +++ b/dlt/common/configuration/accessors.py @@ -1,22 +1,23 @@ import abc import contextlib -import tomlkit from typing import Any, ClassVar, List, Sequence, Tuple, Type, TypeVar +import tomlkit + from dlt.common.configuration.container import Container from dlt.common.configuration.exceptions import ConfigFieldMissingException, LookupTrace from dlt.common.configuration.providers.provider import ConfigProvider from dlt.common.configuration.specs import BaseConfiguration, is_base_configuration_inner_hint -from dlt.common.configuration.utils import deserialize_value, log_traces, auto_cast from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext +from dlt.common.configuration.utils import auto_cast, deserialize_value, log_traces from dlt.common.typing import AnyType, ConfigValue, TSecretValue DLT_SECRETS_VALUE = "secrets.value" DLT_CONFIG_VALUE = "config.value" TConfigAny = TypeVar("TConfigAny", bound=Any) -class _Accessor(abc.ABC): +class _Accessor(abc.ABC): def __getitem__(self, field: str) -> Any: value, traces = self._get_value(field) if value is None: @@ -100,7 +101,11 @@ def default_type(self) -> AnyType: @property def writable_provider(self) -> ConfigProvider: """find first writable provider that does not support secrets - should be config.toml""" - return next(p for p in self._get_providers_from_context() if p.is_writable and not p.supports_secrets) + return next( + p + for p in self._get_providers_from_context() + if p.is_writable and not p.supports_secrets + ) value: ClassVar[None] = ConfigValue "A placeholder that tells dlt to replace it with actual config value during the call to a source or resource decorated function." @@ -121,7 +126,9 @@ def default_type(self) -> AnyType: @property def writable_provider(self) -> ConfigProvider: """find first writable provider that supports secrets - should be secrets.toml""" - return next(p for p in self._get_providers_from_context() if p.is_writable and p.supports_secrets) + return next( + p for p in self._get_providers_from_context() if p.is_writable and p.supports_secrets + ) value: ClassVar[None] = ConfigValue "A placeholder that tells dlt to replace it with actual secret during the call to a source or resource decorated function." 
diff --git a/dlt/common/configuration/container.py b/dlt/common/configuration/container.py index 46d64f7a37..f5b733c77c 100644 --- a/dlt/common/configuration/container.py +++ b/dlt/common/configuration/container.py @@ -1,8 +1,11 @@ from contextlib import contextmanager from typing import Dict, Iterator, Type, TypeVar +from dlt.common.configuration.exceptions import ( + ContainerInjectableContextMangled, + ContextDefaultCannotBeCreated, +) from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext -from dlt.common.configuration.exceptions import ContainerInjectableContextMangled, ContextDefaultCannotBeCreated TConfiguration = TypeVar("TConfiguration", bound=ContainerInjectableContext) @@ -60,7 +63,6 @@ def __delitem__(self, spec: Type[TConfiguration]) -> None: def __contains__(self, spec: Type[TConfiguration]) -> bool: return spec in self.contexts - @contextmanager def injectable_context(self, config: TConfiguration) -> Iterator[TConfiguration]: """A context manager that will insert `config` into the container and restore the previous value when it gets out of scope.""" diff --git a/dlt/common/configuration/exceptions.py b/dlt/common/configuration/exceptions.py index cae666dab1..e97efd05ff 100644 --- a/dlt/common/configuration/exceptions.py +++ b/dlt/common/configuration/exceptions.py @@ -1,4 +1,4 @@ -from typing import Any, Mapping, Type, Tuple, NamedTuple, Sequence +from typing import Any, Mapping, NamedTuple, Sequence, Tuple, Type from dlt.common.exceptions import DltException, TerminalException @@ -20,17 +20,22 @@ class ConfigurationValueError(ConfigurationException, ValueError): class ContainerException(DltException): """base exception for all exceptions related to injectable container""" + pass class ConfigProviderException(ConfigurationException): """base exceptions for all exceptions raised by config providers""" + pass class ConfigurationWrongTypeException(ConfigurationException): def __init__(self, _typ: type) -> None: - super().__init__(f"Invalid configuration instance type {_typ}. Configuration instances must derive from BaseConfiguration.") + super().__init__( + f"Invalid configuration instance type {_typ}. Configuration instances must derive from" + " BaseConfiguration." 
+ ) class ConfigFieldMissingException(KeyError, ConfigurationException): @@ -43,34 +48,45 @@ def __init__(self, spec_name: str, traces: Mapping[str, Sequence[LookupTrace]]) super().__init__(spec_name) def __str__(self) -> str: - msg = f"Following fields are missing: {str(self.fields)} in configuration with spec {self.spec_name}\n" + msg = ( + f"Following fields are missing: {str(self.fields)} in configuration with spec" + f" {self.spec_name}\n" + ) for f, field_traces in self.traces.items(): msg += f'\tfor field "{f}" config providers and keys were tried in following order:\n' for tr in field_traces: - msg += f'\t\tIn {tr.provider} key {tr.key} was not found.\n' - msg += "Please refer to https://dlthub.com/docs/general-usage/credentials for more information\n" + msg += f"\t\tIn {tr.provider} key {tr.key} was not found.\n" + msg += ( + "Please refer to https://dlthub.com/docs/general-usage/credentials for more" + " information\n" + ) return msg class UnmatchedConfigHintResolversException(ConfigurationException): """Raised when using `@resolve_type` on a field that doesn't exist in the spec""" + def __init__(self, spec_name: str, field_names: Sequence[str]) -> None: self.field_names = field_names self.spec_name = spec_name - example = f">>> class {spec_name}(BaseConfiguration)\n" + "\n".join(f">>> {name}: Any" for name in field_names) + example = f">>> class {spec_name}(BaseConfiguration)\n" + "\n".join( + f">>> {name}: Any" for name in field_names + ) msg = ( - f"The config spec {spec_name} has dynamic type resolvers for fields: {field_names} " - "but these fields are not defined in the spec.\n" - "When using @resolve_type() decorator, Add the fields with 'Any' or another common type hint, example:\n" - f"\n{example}" + f"The config spec {spec_name} has dynamic type resolvers for fields: {field_names} but" + " these fields are not defined in the spec.\nWhen using @resolve_type() decorator, Add" + f" the fields with 'Any' or another common type hint, example:\n\n{example}" ) super().__init__(msg) class FinalConfigFieldException(ConfigurationException): """rises when field was annotated as final ie Final[str] and the value is modified by config provider""" + def __init__(self, spec_name: str, field: str) -> None: - super().__init__(f"Field {field} in spec {spec_name} is final but is being changed by a config provider") + super().__init__( + f"Field {field} in spec {spec_name} is final but is being changed by a config provider" + ) class ConfigValueCannotBeCoercedException(ConfigurationValueError): @@ -80,7 +96,9 @@ def __init__(self, field_name: str, field_value: Any, hint: type) -> None: self.field_name = field_name self.field_value = field_value self.hint = hint - super().__init__('Configured value for field %s cannot be coerced into type %s' % (field_name, str(hint))) + super().__init__( + "Configured value for field %s cannot be coerced into type %s" % (field_name, str(hint)) + ) # class ConfigIntegrityException(ConfigurationException): @@ -106,7 +124,9 @@ class ConfigFieldMissingTypeHintException(ConfigurationException): def __init__(self, field_name: str, spec: Type[Any]) -> None: self.field_name = field_name self.typ_ = spec - super().__init__(f"Field {field_name} on configspec {spec} does not provide required type hint") + super().__init__( + f"Field {field_name} on configspec {spec} does not provide required type hint" + ) class ConfigFieldTypeHintNotSupported(ConfigurationException): @@ -115,25 +135,39 @@ class ConfigFieldTypeHintNotSupported(ConfigurationException): def 
__init__(self, field_name: str, spec: Type[Any], typ_: Type[Any]) -> None: self.field_name = field_name self.typ_ = spec - super().__init__(f"Field {field_name} on configspec {spec} has hint with unsupported type {typ_}") + super().__init__( + f"Field {field_name} on configspec {spec} has hint with unsupported type {typ_}" + ) class ValueNotSecretException(ConfigurationException): def __init__(self, provider_name: str, key: str) -> None: self.provider_name = provider_name self.key = key - super().__init__(f"Provider {provider_name} cannot hold secret values but key {key} with secret value is present") + super().__init__( + f"Provider {provider_name} cannot hold secret values but key {key} with secret value is" + " present" + ) class InvalidNativeValue(ConfigurationException): - def __init__(self, spec: Type[Any], native_value_type: Type[Any], embedded_sections: Tuple[str, ...], inner_exception: Exception) -> None: + def __init__( + self, + spec: Type[Any], + native_value_type: Type[Any], + embedded_sections: Tuple[str, ...], + inner_exception: Exception, + ) -> None: self.spec = spec self.native_value_type = native_value_type self.embedded_sections = embedded_sections self.inner_exception = inner_exception inner_msg = f" {self.inner_exception}" if inner_exception is not ValueError else "" super().__init__( - f"{spec.__name__} cannot parse the configuration value provided. The value is of type {native_value_type.__name__} and comes from the {embedded_sections} section(s).{inner_msg}") + f"{spec.__name__} cannot parse the configuration value provided. The value is of type" + f" {native_value_type.__name__} and comes from the" + f" {embedded_sections} section(s).{inner_msg}" + ) class ContainerInjectableContextMangled(ContainerException): @@ -141,7 +175,10 @@ def __init__(self, spec: Type[Any], existing_config: Any, expected_config: Any) self.spec = spec self.existing_config = existing_config self.expected_config = expected_config - super().__init__(f"When restoring context {spec.__name__}, instance {expected_config} was expected, instead instance {existing_config} was found.") + super().__init__( + f"When restoring context {spec.__name__}, instance {expected_config} was expected," + f" instead instance {existing_config} was found." 
+ ) class ContextDefaultCannotBeCreated(ContainerException, KeyError): @@ -153,4 +190,6 @@ def __init__(self, spec: Type[Any]) -> None: class DuplicateConfigProviderException(ConfigProviderException): def __init__(self, provider_name: str) -> None: self.provider_name = provider_name - super().__init__(f"Provider with name {provider_name} already present in ConfigProvidersContext") + super().__init__( + f"Provider with name {provider_name} already present in ConfigProvidersContext" + ) diff --git a/dlt/common/configuration/inject.py b/dlt/common/configuration/inject.py index d7419ee378..6714c4bb42 100644 --- a/dlt/common/configuration/inject.py +++ b/dlt/common/configuration/inject.py @@ -1,14 +1,13 @@ import inspect from functools import wraps -from typing import Callable, Dict, Type, Any, Optional, Tuple, TypeVar, overload -from inspect import Signature, Parameter +from inspect import Parameter, Signature +from typing import Any, Callable, Dict, Optional, Tuple, Type, TypeVar, overload -from dlt.common.typing import DictStrAny, StrAny, TFun, AnyFun -from dlt.common.configuration.resolve import resolve_configuration, inject_section +from dlt.common.configuration.resolve import inject_section, resolve_configuration from dlt.common.configuration.specs.base_configuration import BaseConfiguration from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.reflection.spec import spec_from_signature - +from dlt.common.typing import AnyFun, DictStrAny, StrAny, TFun _LAST_DLT_CONFIG = "_dlt_config" _ORIGINAL_ARGS = "_dlt_orig_args" @@ -30,8 +29,8 @@ def with_config( sections: Tuple[str, ...] = (), sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming, auto_pipeline_section: bool = False, - include_defaults: bool = True -) -> TFun: + include_defaults: bool = True, +) -> TFun: ... @@ -43,8 +42,8 @@ def with_config( sections: Tuple[str, ...] = (), sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming, auto_pipeline_section: bool = False, - include_defaults: bool = True -) -> Callable[[TFun], TFun]: + include_defaults: bool = True, +) -> Callable[[TFun], TFun]: ... @@ -55,8 +54,8 @@ def with_config( sections: Tuple[str, ...] = (), sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming, auto_pipeline_section: bool = False, - include_defaults: bool = True -) -> Callable[[TFun], TFun]: + include_defaults: bool = True, +) -> Callable[[TFun], TFun]: """Injects values into decorated function arguments following the specification in `spec` or by deriving one from function's signature. The synthesized spec contains the arguments marked with `dlt.secrets.value` and `dlt.config.value` which are required to be injected at runtime. 
@@ -81,7 +80,9 @@ def with_config( def decorator(f: TFun) -> TFun: SPEC: Type[BaseConfiguration] = None sig: Signature = inspect.signature(f) - kwargs_arg = next((p for p in sig.parameters.values() if p.kind == Parameter.VAR_KEYWORD), None) + kwargs_arg = next( + (p for p in sig.parameters.values() if p.kind == Parameter.VAR_KEYWORD), None + ) spec_arg: Parameter = None pipeline_name_arg: Parameter = None section_context = ConfigSectionContext(sections=sections, merge_style=sections_merge_style) @@ -106,7 +107,6 @@ def decorator(f: TFun) -> TFun: pipeline_name_arg = p pipeline_name_arg_default = None if p.default == Parameter.empty else p.default - @wraps(f) def _wrap(*args: Any, **kwargs: Any) -> Any: # bind parameters to signature @@ -119,7 +119,7 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: # if section derivation function was provided then call it nonlocal sections if section_f: - section_context.sections = (section_f(bound_args.arguments), ) + section_context.sections = (section_f(bound_args.arguments),) # sections may be a string if isinstance(sections, str): section_context.sections = (sections,) @@ -129,10 +129,14 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: config = bound_args.arguments.get(spec_arg.name, None) # resolve SPEC, also provide section_context with pipeline_name if pipeline_name_arg: - section_context.pipeline_name = bound_args.arguments.get(pipeline_name_arg.name, pipeline_name_arg_default) + section_context.pipeline_name = bound_args.arguments.get( + pipeline_name_arg.name, pipeline_name_arg_default + ) with inject_section(section_context): # print(f"RESOLVE CONF in inject: {f.__name__}: {section_context.sections} vs {sections}") - config = resolve_configuration(config or SPEC(), explicit_value=bound_args.arguments) + config = resolve_configuration( + config or SPEC(), explicit_value=bound_args.arguments + ) resolved_params = dict(config) # overwrite or add resolved params for p in sig.parameters.values(): @@ -162,7 +166,10 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: return decorator if not callable(func): - raise ValueError("First parameter to the with_config must be callable ie. by using it as function decorator") + raise ValueError( + "First parameter to the with_config must be callable ie. by using it as function" + " decorator" + ) # We're called as @with_config without parens. return decorator(func) diff --git a/dlt/common/configuration/paths.py b/dlt/common/configuration/paths.py index f773a779f8..780896ab4b 100644 --- a/dlt/common/configuration/paths.py +++ b/dlt/common/configuration/paths.py @@ -27,11 +27,11 @@ def make_dlt_settings_path(path: str) -> str: def get_dlt_data_dir() -> str: - """ Gets default directory where pipelines' data will be stored - 1. in user home directory: ~/.dlt/ - 2. if current user is root: in /var/dlt/ - 3. if current user does not have a home directory: in /tmp/dlt/ - 4. if DLT_DATA_DIR is set in env then it is used + """Gets default directory where pipelines' data will be stored + 1. in user home directory: ~/.dlt/ + 2. if current user is root: in /var/dlt/ + 3. if current user does not have a home directory: in /tmp/dlt/ + 4. 
if DLT_DATA_DIR is set in env then it is used """ if "DLT_DATA_DIR" in os.environ: return os.environ["DLT_DATA_DIR"] @@ -49,5 +49,6 @@ def get_dlt_data_dir() -> str: # if home directory is available use ~/.dlt/pipelines return os.path.join(home, DOT_DLT) + def _get_user_home_dir() -> str: return os.path.expanduser("~") diff --git a/dlt/common/configuration/providers/__init__.py b/dlt/common/configuration/providers/__init__.py index f79aff223e..86e6ac1527 100644 --- a/dlt/common/configuration/providers/__init__.py +++ b/dlt/common/configuration/providers/__init__.py @@ -1,6 +1,14 @@ -from .provider import ConfigProvider -from .environ import EnvironProvider +from .context import ContextProvider from .dictionary import DictionaryProvider -from .toml import SecretsTomlProvider, ConfigTomlProvider, TomlFileProvider, CONFIG_TOML, SECRETS_TOML, StringTomlProvider, SECRETS_TOML_KEY +from .environ import EnvironProvider from .google_secrets import GoogleSecretsProvider -from .context import ContextProvider \ No newline at end of file +from .provider import ConfigProvider +from .toml import ( + CONFIG_TOML, + SECRETS_TOML, + SECRETS_TOML_KEY, + ConfigTomlProvider, + SecretsTomlProvider, + StringTomlProvider, + TomlFileProvider, +) diff --git a/dlt/common/configuration/providers/airflow.py b/dlt/common/configuration/providers/airflow.py index 04b0b18be7..ca2f71e728 100644 --- a/dlt/common/configuration/providers/airflow.py +++ b/dlt/common/configuration/providers/airflow.py @@ -7,11 +7,12 @@ def __init__(self, only_secrets: bool = False, only_toml_fragments: bool = False @property def name(self) -> str: - return 'Airflow Secrets TOML Provider' + return "Airflow Secrets TOML Provider" def _look_vault(self, full_key: str, hint: type) -> str: """Get Airflow Variable with given `full_key`, return None if not found""" from airflow.models import Variable + return Variable.get(full_key, default_var=None) # type: ignore @property diff --git a/dlt/common/configuration/providers/context.py b/dlt/common/configuration/providers/context.py index 84e26923a3..c0b727dd62 100644 --- a/dlt/common/configuration/providers/context.py +++ b/dlt/common/configuration/providers/context.py @@ -1,5 +1,5 @@ import contextlib -from typing import Any, ClassVar, Optional, Type, Tuple +from typing import Any, ClassVar, Optional, Tuple, Type from dlt.common.configuration.container import Container from dlt.common.configuration.specs import ContainerInjectableContext @@ -8,7 +8,6 @@ class ContextProvider(ConfigProvider): - NAME: ClassVar[str] = "Injectable Context" def __init__(self) -> None: @@ -18,7 +17,9 @@ def __init__(self) -> None: def name(self) -> str: return ContextProvider.NAME - def get_value(self, key: str, hint: Type[Any], pipeline_name: str = None, *sections: str) -> Tuple[Optional[Any], str]: + def get_value( + self, key: str, hint: Type[Any], pipeline_name: str = None, *sections: str + ) -> Tuple[Optional[Any], str]: assert sections == () # only context is a valid hint diff --git a/dlt/common/configuration/providers/dictionary.py b/dlt/common/configuration/providers/dictionary.py index 40a51eeb72..5329233045 100644 --- a/dlt/common/configuration/providers/dictionary.py +++ b/dlt/common/configuration/providers/dictionary.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import Any, ClassVar, Iterator, Optional, Type, Tuple +from typing import Any, ClassVar, Iterator, Optional, Tuple, Type from dlt.common.typing import StrAny @@ -7,7 +7,6 @@ class DictionaryProvider(ConfigProvider): - NAME: 
ClassVar[str] = "Dictionary Provider" def __init__(self) -> None: @@ -17,14 +16,16 @@ def __init__(self) -> None: def name(self) -> str: return self.NAME - def get_value(self, key: str, hint: Type[Any], pipeline_name: str, *sections: str) -> Tuple[Optional[Any], str]: + def get_value( + self, key: str, hint: Type[Any], pipeline_name: str, *sections: str + ) -> Tuple[Optional[Any], str]: full_path = sections + (key,) if pipeline_name: - full_path = (pipeline_name, ) + full_path + full_path = (pipeline_name,) + full_path full_key = get_key_name(key, "__", pipeline_name, *sections) node = self._values try: - for k in full_path: + for k in full_path: if not isinstance(node, dict): raise KeyError(k) node = node[k] diff --git a/dlt/common/configuration/providers/environ.py b/dlt/common/configuration/providers/environ.py index 7406a1207b..0f74f29482 100644 --- a/dlt/common/configuration/providers/environ.py +++ b/dlt/common/configuration/providers/environ.py @@ -1,6 +1,6 @@ from os import environ from os.path import isdir -from typing import Any, Optional, Type, Tuple +from typing import Any, Optional, Tuple, Type from dlt.common.typing import TSecretValue @@ -8,8 +8,8 @@ SECRET_STORAGE_PATH: str = "/run/secrets/%s" -class EnvironProvider(ConfigProvider): +class EnvironProvider(ConfigProvider): @staticmethod def get_key_name(key: str, *sections: str) -> str: return get_key_name(key, "__", *sections).upper() @@ -18,7 +18,9 @@ def get_key_name(key: str, *sections: str) -> str: def name(self) -> str: return "Environment Variables" - def get_value(self, key: str, hint: Type[Any], pipeline_name: str, *sections: str) -> Tuple[Optional[Any], str]: + def get_value( + self, key: str, hint: Type[Any], pipeline_name: str, *sections: str + ) -> Tuple[Optional[Any], str]: # apply section to the key key = self.get_key_name(key, pipeline_name, *sections) if hint is TSecretValue: diff --git a/dlt/common/configuration/providers/google_secrets.py b/dlt/common/configuration/providers/google_secrets.py index ccf891a575..757c56a002 100644 --- a/dlt/common/configuration/providers/google_secrets.py +++ b/dlt/common/configuration/providers/google_secrets.py @@ -4,12 +4,17 @@ from dlt.common.configuration.specs import GcpServiceAccountCredentials from dlt.common.exceptions import MissingDependencyException -from .toml import VaultTomlProvider from .provider import get_key_name +from .toml import VaultTomlProvider class GoogleSecretsProvider(VaultTomlProvider): - def __init__(self, credentials: GcpServiceAccountCredentials, only_secrets: bool = True, only_toml_fragments: bool = True) -> None: + def __init__( + self, + credentials: GcpServiceAccountCredentials, + only_secrets: bool = True, + only_toml_fragments: bool = True, + ) -> None: self.credentials = credentials super().__init__(only_secrets, only_toml_fragments) @@ -26,7 +31,11 @@ def _look_vault(self, full_key: str, hint: type) -> str: from googleapiclient.discovery import build from googleapiclient.errors import HttpError except ModuleNotFoundError: - raise MissingDependencyException("GoogleSecretsProvider", ["google-api-python-client"], "We need google-api-python-client to build client for secretmanager v1") + raise MissingDependencyException( + "GoogleSecretsProvider", + ["google-api-python-client"], + "We need google-api-python-client to build client for secretmanager v1", + ) from dlt.common import logger resource_name = f"projects/{self.credentials.project_id}/secrets/{full_key}/versions/latest" @@ -42,10 +51,17 @@ def _look_vault(self, full_key: str, hint: 
type) -> str: # logger.warning(f"{self.credentials.client_email} has roles/secretmanager.secretAccessor role but {full_key} not found in Google Secrets: {error_doc['message']}[{error_doc['status']}]") return None elif error.resp.status == 403: - logger.warning(f"{self.credentials.client_email} does not have roles/secretmanager.secretAccessor role. It also does not have read permission to {full_key} or the key is not found in Google Secrets: {error_doc['message']}[{error_doc['status']}]") + logger.warning( + f"{self.credentials.client_email} does not have" + " roles/secretmanager.secretAccessor role. It also does not have read" + f" permission to {full_key} or the key is not found in Google Secrets:" + f" {error_doc['message']}[{error_doc['status']}]" + ) return None elif error.resp.status == 400: - logger.warning(f"Unable to read {full_key} : {error_doc['message']}[{error_doc['status']}]") + logger.warning( + f"Unable to read {full_key} : {error_doc['message']}[{error_doc['status']}]" + ) return None raise @@ -68,4 +84,4 @@ def _look_vault(self, full_key: str, hint: type) -> str: # has_required_role = True # break # if not has_required_role: - # print("no secrets read access") \ No newline at end of file + # print("no secrets read access") diff --git a/dlt/common/configuration/providers/provider.py b/dlt/common/configuration/providers/provider.py index c6bfea5dc3..634ecb53e3 100644 --- a/dlt/common/configuration/providers/provider.py +++ b/dlt/common/configuration/providers/provider.py @@ -1,13 +1,14 @@ import abc -from typing import Any, Tuple, Type, Optional +from typing import Any, Optional, Tuple, Type from dlt.common.configuration.exceptions import ConfigurationException class ConfigProvider(abc.ABC): - @abc.abstractmethod - def get_value(self, key: str, hint: Type[Any], pipeline_name: str, *sections: str) -> Tuple[Optional[Any], str]: + def get_value( + self, key: str, hint: Type[Any], pipeline_name: str, *sections: str + ) -> Tuple[Optional[Any], str]: pass def set_value(self, key: str, value: Any, pipeline_name: str, *sections: str) -> None: diff --git a/dlt/common/configuration/providers/toml.py b/dlt/common/configuration/providers/toml.py index 19374187fb..3b9918a888 100644 --- a/dlt/common/configuration/providers/toml.py +++ b/dlt/common/configuration/providers/toml.py @@ -1,25 +1,25 @@ -import os import abc -import tomlkit import contextlib -from tomlkit.items import Item as TOMLItem -from tomlkit.container import Container as TOMLContainer +import os from typing import Any, Dict, Optional, Tuple, Type, Union +import tomlkit +from tomlkit.container import Container as TOMLContainer +from tomlkit.items import Item as TOMLItem + from dlt.common import pendulum -from dlt.common.configuration.paths import get_dlt_settings_dir, get_dlt_data_dir -from dlt.common.configuration.utils import auto_cast +from dlt.common.configuration.paths import get_dlt_data_dir, get_dlt_settings_dir from dlt.common.configuration.specs import known_sections from dlt.common.configuration.specs.base_configuration import is_secret_hint -from dlt.common.utils import update_dict_nested - +from dlt.common.configuration.utils import auto_cast from dlt.common.typing import AnyType +from dlt.common.utils import update_dict_nested from .provider import ConfigProvider, ConfigProviderException, get_key_name CONFIG_TOML = "config.toml" SECRETS_TOML = "secrets.toml" -SECRETS_TOML_KEY = 'dlt_secrets_toml' +SECRETS_TOML_KEY = "dlt_secrets_toml" class BaseTomlProvider(ConfigProvider): @@ -30,10 +30,12 @@ def 
__init__(self, toml_document: TOMLContainer) -> None: def get_key_name(key: str, *sections: str) -> str: return get_key_name(key, ".", *sections) - def get_value(self, key: str, hint: Type[Any], pipeline_name: str, *sections: str) -> Tuple[Optional[Any], str]: + def get_value( + self, key: str, hint: Type[Any], pipeline_name: str, *sections: str + ) -> Tuple[Optional[Any], str]: full_path = sections + (key,) if pipeline_name: - full_path = (pipeline_name, ) + full_path + full_path = (pipeline_name,) + full_path full_key = self.get_key_name(key, pipeline_name, *sections) node: Union[TOMLContainer, TOMLItem] = self._toml try: @@ -48,7 +50,7 @@ def get_value(self, key: str, hint: Type[Any], pipeline_name: str, *sections: st def set_value(self, key: str, value: Any, pipeline_name: str, *sections: str) -> None: if pipeline_name: - sections = (pipeline_name, ) + sections + sections = (pipeline_name,) + sections if isinstance(value, TOMLContainer): if key is None: @@ -85,7 +87,6 @@ def is_empty(self) -> bool: class StringTomlProvider(BaseTomlProvider): - def __init__(self, toml_string: str) -> None: super().__init__(StringTomlProvider.loads(toml_string)) @@ -138,12 +139,13 @@ def __init__(self, only_secrets: bool, only_toml_fragments: bool) -> None: super().__init__(tomlkit.document()) self._update_from_vault(SECRETS_TOML_KEY, None, AnyType, None, ()) - def get_value(self, key: str, hint: type, pipeline_name: str, *sections: str) -> Tuple[Optional[Any], str]: + def get_value( + self, key: str, hint: type, pipeline_name: str, *sections: str + ) -> Tuple[Optional[Any], str]: full_key = self.get_key_name(key, pipeline_name, *sections) value, _ = super().get_value(key, hint, pipeline_name, *sections) if value is None: - # only secrets hints are handled if self.only_secrets and not is_secret_hint(hint) and hint is not AnyType: return None, full_key @@ -153,7 +155,6 @@ def get_value(self, key: str, hint: type, pipeline_name: str, *sections: str) -> lookup_fk = self.get_key_name(SECRETS_TOML_KEY, pipeline_name) self._update_from_vault(lookup_fk, "", AnyType, pipeline_name, ()) - # generate auxiliary paths to get from vault for known_section in [known_sections.SOURCES, known_sections.DESTINATION]: @@ -161,7 +162,9 @@ def _look_at_idx(idx: int, full_path: Tuple[str, ...], pipeline_name: str) -> No lookup_key = full_path[idx] lookup_sections = full_path[:idx] lookup_fk = self.get_key_name(lookup_key, *lookup_sections) - self._update_from_vault(lookup_fk, lookup_key, AnyType, pipeline_name, lookup_sections) + self._update_from_vault( + lookup_fk, lookup_key, AnyType, pipeline_name, lookup_sections + ) def _lookup_paths(pipeline_name_: str, known_section_: str) -> None: with contextlib.suppress(ValueError): @@ -177,7 +180,9 @@ def _lookup_paths(pipeline_name_: str, known_section_: str) -> None: # first query the shortest paths so the longer paths can override it _lookup_paths(None, known_section) # check sources and sources. if pipeline_name: - _lookup_paths(pipeline_name, known_section) # check .sources and .sources. + _lookup_paths( + pipeline_name, known_section + ) # check .sources and .sources. 
value, _ = super().get_value(key, hint, pipeline_name, *sections) # skip checking the exact path if we check only toml fragments @@ -200,7 +205,9 @@ def supports_secrets(self) -> bool: def _look_vault(self, full_key: str, hint: type) -> str: pass - def _update_from_vault(self, full_key: str, key: str, hint: type, pipeline_name: str, sections: Tuple[str, ...]) -> None: + def _update_from_vault( + self, full_key: str, key: str, hint: type, pipeline_name: str, sections: Tuple[str, ...] + ) -> None: if full_key in self._vault_lookups: return # print(f"tries '{key}' {pipeline_name} | {sections} at '{full_key}'") @@ -213,8 +220,11 @@ def _update_from_vault(self, full_key: str, key: str, hint: type, pipeline_name: def is_empty(self) -> bool: return False + class TomlFileProvider(BaseTomlProvider): - def __init__(self, file_name: str, project_dir: str = None, add_global_config: bool = False) -> None: + def __init__( + self, file_name: str, project_dir: str = None, add_global_config: bool = False + ) -> None: """Creates config provider from a `toml` file The provider loads the `toml` file with specified name and from specified folder. If `add_global_config` flags is specified, @@ -233,7 +243,9 @@ def __init__(self, file_name: str, project_dir: str = None, add_global_config: b toml_document = self._read_toml_file(file_name, project_dir, add_global_config) super().__init__(toml_document) - def _read_toml_file(self, file_name: str, project_dir: str = None, add_global_config: bool = False) -> tomlkit.TOMLDocument: + def _read_toml_file( + self, file_name: str, project_dir: str = None, add_global_config: bool = False + ) -> tomlkit.TOMLDocument: self._file_name = file_name self._toml_path = os.path.join(project_dir or get_dlt_settings_dir(), file_name) self._add_global_config = add_global_config @@ -251,7 +263,9 @@ def global_config_path() -> str: return get_dlt_data_dir() def write_toml(self) -> None: - assert not self._add_global_config, "Will not write configs when `add_global_config` flag was set" + assert ( + not self._add_global_config + ), "Will not write configs when `add_global_config` flag was set" with open(self._toml_path, "w", encoding="utf-8") as f: tomlkit.dump(self._toml, f) @@ -266,7 +280,6 @@ def _read_toml(toml_path: str) -> tomlkit.TOMLDocument: class ConfigTomlProvider(TomlFileProvider): - def __init__(self, project_dir: str = None, add_global_config: bool = False) -> None: super().__init__(CONFIG_TOML, project_dir=project_dir, add_global_config=add_global_config) @@ -284,7 +297,6 @@ def is_writable(self) -> bool: class SecretsTomlProvider(TomlFileProvider): - def __init__(self, project_dir: str = None, add_global_config: bool = False) -> None: super().__init__(SECRETS_TOML, project_dir=project_dir, add_global_config=add_global_config) @@ -302,7 +314,9 @@ def is_writable(self) -> bool: class TomlProviderReadException(ConfigProviderException): - def __init__(self, provider_name: str, file_name: str, full_path: str, toml_exception: str) -> None: + def __init__( + self, provider_name: str, file_name: str, full_path: str, toml_exception: str + ) -> None: self.file_name = file_name self.full_path = full_path msg = f"A problem encountered when loading {provider_name} from {full_path}:\n" diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py index 68421d7d4b..9353a77761 100644 --- a/dlt/common/configuration/resolve.py +++ b/dlt/common/configuration/resolve.py @@ -1,24 +1,51 @@ import itertools from collections.abc import Mapping as C_Mapping -from typing 
import Any, Dict, ContextManager, List, Optional, Sequence, Tuple, Type, TypeVar +from typing import Any, ContextManager, Dict, List, Optional, Sequence, Tuple, Type, TypeVar +from dlt.common.configuration.container import Container +from dlt.common.configuration.exceptions import ( + ConfigFieldMissingException, + ConfigurationWrongTypeException, + FinalConfigFieldException, + InvalidNativeValue, + LookupTrace, + UnmatchedConfigHintResolversException, + ValueNotSecretException, +) from dlt.common.configuration.providers.provider import ConfigProvider -from dlt.common.typing import AnyType, StrAny, TSecretValue, get_all_types_of_class_in_union, is_final_type, is_optional_type, is_union - -from dlt.common.configuration.specs.base_configuration import BaseConfiguration, CredentialsConfiguration, is_secret_hint, extract_inner_hint, is_context_inner_hint, is_base_configuration_inner_hint, is_valid_hint +from dlt.common.configuration.specs.base_configuration import ( + BaseConfiguration, + CredentialsConfiguration, + extract_inner_hint, + is_base_configuration_inner_hint, + is_context_inner_hint, + is_secret_hint, + is_valid_hint, +) +from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.configuration.specs.exceptions import NativeValueError -from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext -from dlt.common.configuration.container import Container -from dlt.common.configuration.utils import log_traces, deserialize_value -from dlt.common.configuration.exceptions import ( - FinalConfigFieldException, LookupTrace, ConfigFieldMissingException, ConfigurationWrongTypeException, - ValueNotSecretException, InvalidNativeValue, UnmatchedConfigHintResolversException) +from dlt.common.configuration.utils import deserialize_value, log_traces +from dlt.common.typing import ( + AnyType, + StrAny, + TSecretValue, + get_all_types_of_class_in_union, + is_final_type, + is_optional_type, + is_union, +) TConfiguration = TypeVar("TConfiguration", bound=BaseConfiguration) -def resolve_configuration(config: TConfiguration, *, sections: Tuple[str, ...] = (), explicit_value: Any = None, accept_partial: bool = False) -> TConfiguration: +def resolve_configuration( + config: TConfiguration, + *, + sections: Tuple[str, ...] = (), + explicit_value: Any = None, + accept_partial: bool = False +) -> TConfiguration: if not isinstance(config, BaseConfiguration): raise ConfigurationWrongTypeException(type(config)) @@ -26,7 +53,9 @@ def resolve_configuration(config: TConfiguration, *, sections: Tuple[str, ...] 
= # allows, for example, to store connection string or service.json in their native form in single env variable or under single vault key if config.__section__ and explicit_value is None: initial_hint = TSecretValue if isinstance(config, CredentialsConfiguration) else AnyType - explicit_value, traces = _resolve_single_value(config.__section__, initial_hint, AnyType, None, sections, ()) + explicit_value, traces = _resolve_single_value( + config.__section__, initial_hint, AnyType, None, sections, () + ) if isinstance(explicit_value, C_Mapping): # mappings cannot be used as explicit values, we want to enumerate mappings and request the fields' values one by one explicit_value = None @@ -63,7 +92,9 @@ def initialize_credentials(hint: Any, initial_value: Any) -> CredentialsConfigur return hint(initial_value) # type: ignore -def inject_section(section_context: ConfigSectionContext, merge_existing: bool = True) -> ContextManager[ConfigSectionContext]: +def inject_section( + section_context: ConfigSectionContext, merge_existing: bool = True +) -> ContextManager[ConfigSectionContext]: """Context manager that sets section specified in `section_context` to be used during configuration resolution. Optionally merges the context already in the container with the one provided Args: @@ -84,9 +115,14 @@ def inject_section(section_context: ConfigSectionContext, merge_existing: bool = return container.injectable_context(section_context) -def _maybe_parse_native_value(config: TConfiguration, explicit_value: Any, embedded_sections: Tuple[str, ...]) -> Any: + +def _maybe_parse_native_value( + config: TConfiguration, explicit_value: Any, embedded_sections: Tuple[str, ...] +) -> Any: # use initial value to resolve the whole configuration. if explicit value is a mapping it will be applied field by field later - if explicit_value and (not isinstance(explicit_value, C_Mapping) or isinstance(explicit_value, BaseConfiguration)): + if explicit_value and ( + not isinstance(explicit_value, C_Mapping) or isinstance(explicit_value, BaseConfiguration) + ): # print(f"TRYING TO PARSE NATIVE from {explicit_value}") try: config.parse_native_representation(explicit_value) @@ -99,13 +135,14 @@ def _maybe_parse_native_value(config: TConfiguration, explicit_value: Any, embed explicit_value = None return explicit_value + def _resolve_configuration( - config: TConfiguration, - explicit_sections: Tuple[str, ...], - embedded_sections: Tuple[str, ...], - explicit_value: Any, - accept_partial: bool - ) -> TConfiguration: + config: TConfiguration, + explicit_sections: Tuple[str, ...], + embedded_sections: Tuple[str, ...], + explicit_value: Any, + accept_partial: bool, +) -> TConfiguration: # do not resolve twice if config.is_resolved(): return config @@ -116,7 +153,9 @@ def _resolve_configuration( explicit_value = _maybe_parse_native_value(config, explicit_value, embedded_sections) # if native representation didn't fully resolve the config, we try to resolve field by field if not config.is_resolved(): - _resolve_config_fields(config, explicit_value, explicit_sections, embedded_sections, accept_partial) + _resolve_config_fields( + config, explicit_value, explicit_sections, embedded_sections, accept_partial + ) # full configuration was resolved config.resolve() except ConfigFieldMissingException as cm_ex: @@ -136,13 +175,12 @@ def _resolve_configuration( def _resolve_config_fields( - config: BaseConfiguration, - explicit_values: StrAny, - explicit_sections: Tuple[str, ...], - embedded_sections: Tuple[str, ...], - accept_partial: bool - ) -> 
None: - + config: BaseConfiguration, + explicit_values: StrAny, + explicit_sections: Tuple[str, ...], + embedded_sections: Tuple[str, ...], + accept_partial: bool, +) -> None: fields = config.get_resolvable_fields() unresolved_fields: Dict[str, Sequence[LookupTrace]] = {} @@ -169,7 +207,11 @@ def _resolve_config_fields( if is_union(hint): # print(f"HINT UNION?: {key}:{hint}") # if union contains a type of explicit value which is not a valid hint, return it as current value - if explicit_value and not is_valid_hint(type(explicit_value)) and get_all_types_of_class_in_union(hint, type(explicit_value)): + if ( + explicit_value + and not is_valid_hint(type(explicit_value)) + and get_all_types_of_class_in_union(hint, type(explicit_value)) + ): current_value, traces = explicit_value, [] else: specs_in_union = get_all_types_of_class_in_union(hint, BaseConfiguration) @@ -187,7 +229,7 @@ def _resolve_config_fields( config.__section__, explicit_sections, embedded_sections, - accept_partial + accept_partial, ) break except ConfigFieldMissingException as cfm_ex: @@ -208,7 +250,7 @@ def _resolve_config_fields( config.__section__, explicit_sections, embedded_sections, - accept_partial + accept_partial, ) # check if hint optional @@ -236,17 +278,16 @@ def _resolve_config_fields( def _resolve_config_field( - key: str, - hint: Type[Any], - default_value: Any, - explicit_value: Any, - config: BaseConfiguration, - config_sections: str, - explicit_sections: Tuple[str, ...], - embedded_sections: Tuple[str, ...], - accept_partial: bool - ) -> Tuple[Any, List[LookupTrace]]: - + key: str, + hint: Type[Any], + default_value: Any, + explicit_value: Any, + config: BaseConfiguration, + config_sections: str, + explicit_sections: Tuple[str, ...], + embedded_sections: Tuple[str, ...], + accept_partial: bool, +) -> Tuple[Any, List[LookupTrace]]: inner_hint = extract_inner_hint(hint) if explicit_value is not None: @@ -254,7 +295,9 @@ def _resolve_config_field( traces: List[LookupTrace] = [] else: # resolve key value via active providers passing the original hint ie. 
to preserve TSecretValue - value, traces = _resolve_single_value(key, hint, inner_hint, config_sections, explicit_sections, embedded_sections) + value, traces = _resolve_single_value( + key, hint, inner_hint, config_sections, explicit_sections, embedded_sections + ) log_traces(config, key, hint, value, default_value, traces) # contexts must be resolved as a whole if is_context_inner_hint(inner_hint): @@ -283,23 +326,44 @@ def _resolve_config_field( # only config with sections may look for initial values if embedded_config.__section__ and value is None: # config section becomes the key if the key does not start with, otherwise it keeps its original value - initial_key, initial_embedded = _apply_embedded_sections_to_config_sections(embedded_config.__section__, embedded_sections + (key,)) + initial_key, initial_embedded = _apply_embedded_sections_to_config_sections( + embedded_config.__section__, embedded_sections + (key,) + ) # it must be a secret value is config is credentials - initial_hint = TSecretValue if isinstance(embedded_config, CredentialsConfiguration) else AnyType - value, initial_traces = _resolve_single_value(initial_key, initial_hint, AnyType, None, explicit_sections, initial_embedded) + initial_hint = ( + TSecretValue + if isinstance(embedded_config, CredentialsConfiguration) + else AnyType + ) + value, initial_traces = _resolve_single_value( + initial_key, initial_hint, AnyType, None, explicit_sections, initial_embedded + ) if isinstance(value, C_Mapping): # mappings are not passed as initials value = None else: traces.extend(initial_traces) - log_traces(config, initial_key, type(embedded_config), value, default_value, initial_traces) + log_traces( + config, + initial_key, + type(embedded_config), + value, + default_value, + initial_traces, + ) # check if hint optional is_optional = is_optional_type(hint) # accept partial becomes True if type if optional so we do not fail on optional configs that do not resolve fully accept_partial = accept_partial or is_optional # create new instance and pass value from the provider as initial, add key to sections - value = _resolve_configuration(embedded_config, explicit_sections, embedded_sections + (key,), default_value if value is None else value, accept_partial) + value = _resolve_configuration( + embedded_config, + explicit_sections, + embedded_sections + (key,), + default_value if value is None else value, + accept_partial, + ) if value.is_partial() and is_optional: # do not return partially resolved optional embeds value = None @@ -314,14 +378,13 @@ def _resolve_config_field( def _resolve_single_value( - key: str, - hint: Type[Any], - inner_hint: Type[Any], - config_section: str, - explicit_sections: Tuple[str, ...], - embedded_sections: Tuple[str, ...] 
- ) -> Tuple[Optional[Any], List[LookupTrace]]: - + key: str, + hint: Type[Any], + inner_hint: Type[Any], + config_section: str, + explicit_sections: Tuple[str, ...], + embedded_sections: Tuple[str, ...], +) -> Tuple[Optional[Any], List[LookupTrace]]: traces: List[LookupTrace] = [] value = None @@ -338,7 +401,9 @@ def _resolve_single_value( return value, traces # resolve a field of the config - config_section, embedded_sections = _apply_embedded_sections_to_config_sections(config_section, embedded_sections) + config_section, embedded_sections = _apply_embedded_sections_to_config_sections( + config_section, embedded_sections + ) providers = providers_context.providers # get additional sections to look in from container sections_context = container[ConfigSectionContext] @@ -359,7 +424,7 @@ def look_sections(pipeline_name: str = None) -> Any: config_section, # if explicit sections are provided, ignore the injected context explicit_sections or sections_context.sections, - embedded_sections + embedded_sections, ) traces.extend(provider_traces) if value is not None: @@ -385,7 +450,7 @@ def resolve_single_provider_value( pipeline_name: str = None, config_section: str = None, explicit_sections: Tuple[str, ...] = (), - embedded_sections: Tuple[str, ...] = () + embedded_sections: Tuple[str, ...] = (), ) -> Tuple[Optional[Any], List[LookupTrace]]: traces: List[LookupTrace] = [] @@ -432,7 +497,9 @@ def resolve_single_provider_value( return value, traces -def _apply_embedded_sections_to_config_sections(config_section: str, embedded_sections: Tuple[str, ...]) -> Tuple[str, Tuple[str, ...]]: +def _apply_embedded_sections_to_config_sections( + config_section: str, embedded_sections: Tuple[str, ...] +) -> Tuple[str, Tuple[str, ...]]: # for the configurations that have __section__ (config_section) defined and are embedded in other configurations, # the innermost embedded section replaces config_section if embedded_sections: diff --git a/dlt/common/configuration/specs/__init__.py b/dlt/common/configuration/specs/__init__.py index 675b0a0bec..4989a9b709 100644 --- a/dlt/common/configuration/specs/__init__.py +++ b/dlt/common/configuration/specs/__init__.py @@ -1,12 +1,24 @@ -from .run_configuration import RunConfiguration # noqa: F401 -from .base_configuration import BaseConfiguration, CredentialsConfiguration, CredentialsWithDefault, ContainerInjectableContext, extract_inner_hint, is_base_configuration_inner_hint, configspec # noqa: F401 -from .config_section_context import ConfigSectionContext # noqa: F401 - -from .gcp_credentials import GcpServiceAccountCredentialsWithoutDefaults, GcpServiceAccountCredentials, GcpOAuthCredentialsWithoutDefaults, GcpOAuthCredentials, GcpCredentials # noqa: F401 -from .connection_string_credentials import ConnectionStringCredentials # noqa: F401 from .api_credentials import OAuth2Credentials # noqa: F401 from .aws_credentials import AwsCredentials, AwsCredentialsWithoutDefaults # noqa: F401 - +from .base_configuration import ( # noqa: F401 + BaseConfiguration, + ContainerInjectableContext, + CredentialsConfiguration, + CredentialsWithDefault, + configspec, + extract_inner_hint, + is_base_configuration_inner_hint, +) +from .config_section_context import ConfigSectionContext # noqa: F401 +from .connection_string_credentials import ConnectionStringCredentials # noqa: F401 # backward compatibility for service account credentials -from .gcp_credentials import GcpServiceAccountCredentialsWithoutDefaults as GcpClientCredentials, GcpServiceAccountCredentials as 
GcpClientCredentialsWithDefault # noqa: F401 +from .gcp_credentials import GcpServiceAccountCredentialsWithoutDefaults # noqa: F401 +from .gcp_credentials import ( # noqa: F401 + GcpCredentials, + GcpOAuthCredentials, + GcpOAuthCredentialsWithoutDefaults, +) +from .gcp_credentials import GcpServiceAccountCredentials +from .gcp_credentials import GcpServiceAccountCredentials as GcpClientCredentialsWithDefault +from .run_configuration import RunConfiguration # noqa: F401 diff --git a/dlt/common/configuration/specs/api_credentials.py b/dlt/common/configuration/specs/api_credentials.py index 6a06a42713..8f80025061 100644 --- a/dlt/common/configuration/specs/api_credentials.py +++ b/dlt/common/configuration/specs/api_credentials.py @@ -1,7 +1,7 @@ -from typing import ClassVar, List, Union, Optional +from typing import ClassVar, List, Optional, Union -from dlt.common.typing import TSecretValue from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec +from dlt.common.typing import TSecretValue @configspec @@ -17,7 +17,6 @@ class OAuth2Credentials(CredentialsConfiguration): # add refresh_token when generating config samples __config_gen_annotations__: ClassVar[List[str]] = ["refresh_token"] - def auth(self, scopes: Union[str, List[str]] = None, redirect_url: str = None) -> None: """Authorizes the client using the available credentials @@ -44,4 +43,3 @@ def add_scopes(self, scopes: Union[str, List[str]]) -> None: self.scopes += [scopes] elif scopes: self.scopes = list(set(self.scopes + scopes)) - diff --git a/dlt/common/configuration/specs/aws_credentials.py b/dlt/common/configuration/specs/aws_credentials.py index 3839bee91b..5d8acc8b43 100644 --- a/dlt/common/configuration/specs/aws_credentials.py +++ b/dlt/common/configuration/specs/aws_credentials.py @@ -1,10 +1,14 @@ -from typing import Optional, Dict, Any +from typing import Any, Dict, Optional +from dlt import version +from dlt.common.configuration.specs import ( + CredentialsConfiguration, + CredentialsWithDefault, + configspec, +) +from dlt.common.configuration.specs.exceptions import InvalidBoto3Session from dlt.common.exceptions import MissingDependencyException from dlt.common.typing import TSecretStrValue -from dlt.common.configuration.specs import CredentialsConfiguration, CredentialsWithDefault, configspec -from dlt.common.configuration.specs.exceptions import InvalidBoto3Session -from dlt import version @configspec @@ -22,7 +26,7 @@ def to_s3fs_credentials(self) -> Dict[str, Optional[str]]: key=self.aws_access_key_id, secret=self.aws_secret_access_key, token=self.aws_session_token, - profile=self.profile_name + profile=self.profile_name, ) def to_native_representation(self) -> Dict[str, Optional[str]]: @@ -32,7 +36,6 @@ def to_native_representation(self) -> Dict[str, Optional[str]]: @configspec class AwsCredentials(AwsCredentialsWithoutDefaults, CredentialsWithDefault): - def on_partial(self) -> None: # Try get default credentials session = self._to_session() @@ -43,12 +46,15 @@ def _to_session(self) -> Any: try: import boto3 except ModuleNotFoundError: - raise MissingDependencyException(self.__class__.__name__, [f"{version.DLT_PKG_NAME}[s3]"]) + raise MissingDependencyException( + self.__class__.__name__, [f"{version.DLT_PKG_NAME}[s3]"] + ) return boto3.Session(**self.to_native_representation()) # type: ignore def _from_session(self, session: Any) -> Any: """Sets the credentials properties from boto3 `session` and return session's credentials if found""" import boto3 + assert 
isinstance(session, boto3.Session) # NOTE: we do not set profile name from boto3 session # we either pass it explicitly in `_to_session` so we know it is identical @@ -69,6 +75,7 @@ def parse_native_representation(self, native_value: Any) -> None: """Import external boto session""" try: import boto3 + if isinstance(native_value, boto3.Session): if self._from_session(native_value): self.__is_resolved__ = True diff --git a/dlt/common/configuration/specs/base_configuration.py b/dlt/common/configuration/specs/base_configuration.py index d4af2e8555..8513a9e542 100644 --- a/dlt/common/configuration/specs/base_configuration.py +++ b/dlt/common/configuration/specs/base_configuration.py @@ -1,20 +1,44 @@ -import copy -import inspect import contextlib +import copy import dataclasses +import inspect from collections.abc import Mapping as C_Mapping -from typing import Callable, List, Optional, Union, Any, Dict, Iterator, MutableMapping, Type, TYPE_CHECKING, get_args, get_origin, overload, ClassVar, TypeVar from functools import wraps +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + Iterator, + List, + MutableMapping, + Optional, + Type, + TypeVar, + Union, + get_args, + get_origin, + overload, +) if TYPE_CHECKING: TDtcField = dataclasses.Field[Any] else: TDtcField = dataclasses.Field -from dlt.common.typing import TAnyClass, TSecretValue, extract_inner_type, is_optional_type, is_union +from dlt.common.configuration.exceptions import ( + ConfigFieldMissingTypeHintException, + ConfigFieldTypeHintNotSupported, +) from dlt.common.data_types import py_type_to_sc_type -from dlt.common.configuration.exceptions import ConfigFieldMissingTypeHintException, ConfigFieldTypeHintNotSupported - +from dlt.common.typing import ( + TAnyClass, + TSecretValue, + extract_inner_type, + is_optional_type, + is_union, +) # forward class declaration _F_BaseConfiguration: Any = type(object) @@ -68,7 +92,7 @@ def extract_inner_hint(hint: Type[Any], preserve_new_types: bool = False) -> Typ def is_secret_hint(hint: Type[Any]) -> bool: - is_secret = False + is_secret = False if hasattr(hint, "__name__"): is_secret = hint.__name__ == "TSecretValue" if not is_secret: @@ -91,7 +115,9 @@ def configspec(cls: None = ...) -> Callable[[Type[TAnyClass]], Type[TAnyClass]]: ... -def configspec(cls: Optional[Type[Any]] = None) -> Union[Type[TAnyClass], Callable[[Type[TAnyClass]], Type[TAnyClass]]]: +def configspec( + cls: Optional[Type[Any]] = None, +) -> Union[Type[TAnyClass], Callable[[Type[TAnyClass]], Type[TAnyClass]]]: """Converts (via derivation) any decorated class to a Python dataclass that may be used as a spec to resolve configurations In comparison the Python dataclass, a spec implements full dictionary interface for its attributes, allows instance creation from ie. strings @@ -99,6 +125,7 @@ def configspec(cls: Optional[Type[Any]] = None) -> Union[Type[TAnyClass], Callab more information. 
""" + def wrap(cls: Type[TAnyClass]) -> Type[TAnyClass]: cls.__hint_resolvers__ = {} # type: ignore[attr-defined] is_context = issubclass(cls, _F_ContainerInjectableContext) @@ -106,8 +133,11 @@ def wrap(cls: Type[TAnyClass]) -> Type[TAnyClass]: with contextlib.suppress(NameError): if not issubclass(cls, BaseConfiguration): # keep the original module and keep defaults for fields listed in annotations - fields = {"__module__": cls.__module__, "__annotations__": getattr(cls, "__annotations__", {})} - for key in fields['__annotations__'].keys(): # type: ignore[union-attr] + fields = { + "__module__": cls.__module__, + "__annotations__": getattr(cls, "__annotations__", {}), + } + for key in fields["__annotations__"].keys(): # type: ignore[union-attr] if key in cls.__dict__: fields[key] = cls.__dict__[key] cls = type(cls.__name__, (cls, _F_BaseConfiguration), fields) @@ -129,7 +159,9 @@ def wrap(cls: Type[TAnyClass]) -> Type[TAnyClass]: except NameError: # Dealing with BaseConfiguration itself before it is defined continue - if not att_name.startswith(("__", "_abc_impl")) and not isinstance(att_value, (staticmethod, classmethod, property)): + if not att_name.startswith(("__", "_abc_impl")) and not isinstance( + att_value, (staticmethod, classmethod, property) + ): if att_name not in cls.__annotations__: raise ConfigFieldMissingTypeHintException(att_name, cls) hint = cls.__annotations__[att_name] @@ -142,8 +174,8 @@ def wrap(cls: Type[TAnyClass]) -> Type[TAnyClass]: # blocking mutable defaults def default_factory(att_value=att_value): # type: ignore[no-untyped-def] return att_value.copy() - setattr(cls, att_name, dataclasses.field(default_factory=default_factory)) + setattr(cls, att_name, dataclasses.field(default_factory=default_factory)) # We don't want to overwrite user's __init__ method # Create dataclass init only when not defined in the class @@ -168,12 +200,11 @@ def default_factory(att_value=att_value): # type: ignore[no-untyped-def] @configspec class BaseConfiguration(MutableMapping[str, Any]): - - __is_resolved__: bool = dataclasses.field(default = False, init=False, repr=False) + __is_resolved__: bool = dataclasses.field(default=False, init=False, repr=False) """True when all config fields were resolved and have a specified value type""" - __section__: str = dataclasses.field(default = None, init=False, repr=False) + __section__: str = dataclasses.field(default=None, init=False, repr=False) """Obligatory section used by config providers when searching for keys, always present in the search path""" - __exception__: Exception = dataclasses.field(default = None, init=False, repr=False) + __exception__: Exception = dataclasses.field(default=None, init=False, repr=False) """Holds the exception that prevented the full resolution""" __config_gen_annotations__: ClassVar[List[str]] = [] """Additional annotations for config generator, currently holds a list of fields of interest that have defaults""" @@ -181,7 +212,6 @@ class BaseConfiguration(MutableMapping[str, Any]): """Typing for dataclass fields""" __hint_resolvers__: ClassVar[Dict[str, Callable[["BaseConfiguration"], Type[Any]]]] = {} - def parse_native_representation(self, native_value: Any) -> None: """Initialize the configuration fields by parsing the `native_value` which should be a native representation of the configuration or credentials, for example database connection string or JSON serialized GCP service credentials file. 
@@ -212,7 +242,7 @@ def _get_resolvable_dataclass_fields(cls) -> Iterator[TDtcField]: # Sort dynamic type hint fields last because they depend on other values yield from sorted( (f for f in cls.__dataclass_fields__.values() if not f.name.startswith("__")), - key=lambda f: f.name in cls.__hint_resolvers__ + key=lambda f: f.name in cls.__hint_resolvers__, ) @classmethod @@ -229,7 +259,9 @@ def is_partial(self) -> bool: return False # check if all resolvable fields have value return any( - field for field, hint in self.get_resolvable_fields().items() if getattr(self, field) is None and not is_optional_type(hint) + field + for field, hint in self.get_resolvable_fields().items() + if getattr(self, field) is None and not is_optional_type(hint) ) def resolve(self) -> None: @@ -330,7 +362,7 @@ def to_native_credentials(self) -> Any: return self.to_native_representation() def __str__(self) -> str: - """Get string representation of credentials to be displayed, with all secret parts removed """ + """Get string representation of credentials to be displayed, with all secret parts removed""" return super().__str__() @@ -367,11 +399,15 @@ def add_extras(self) -> None: TSpec = TypeVar("TSpec", bound=BaseConfiguration) THintResolver = Callable[[TSpec], Type[Any]] + def resolve_type(field_name: str) -> Callable[[THintResolver[TSpec]], THintResolver[TSpec]]: def decorator(func: THintResolver[TSpec]) -> THintResolver[TSpec]: func.__hint_for_field__ = field_name # type: ignore[attr-defined] + @wraps(func) def wrapper(self: TSpec) -> Type[Any]: return func(self) + return wrapper + return decorator diff --git a/dlt/common/configuration/specs/config_providers_context.py b/dlt/common/configuration/specs/config_providers_context.py index 02a7397472..2410909833 100644 --- a/dlt/common/configuration/specs/config_providers_context.py +++ b/dlt/common/configuration/specs/config_providers_context.py @@ -1,10 +1,23 @@ import contextlib import io from typing import List + from dlt.common.configuration.exceptions import DuplicateConfigProviderException -from dlt.common.configuration.providers import ConfigProvider, EnvironProvider, ContextProvider, SecretsTomlProvider, ConfigTomlProvider, GoogleSecretsProvider +from dlt.common.configuration.providers import ( + ConfigProvider, + ConfigTomlProvider, + ContextProvider, + EnvironProvider, + GoogleSecretsProvider, + SecretsTomlProvider, +) +from dlt.common.configuration.specs import ( + BaseConfiguration, + GcpServiceAccountCredentials, + configspec, + known_sections, +) from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext -from dlt.common.configuration.specs import GcpServiceAccountCredentials, BaseConfiguration, configspec, known_sections from dlt.common.runtime.exec_info import is_airflow_installed @@ -21,6 +34,7 @@ class ConfigProvidersConfiguration(BaseConfiguration): @configspec class ConfigProvidersContext(ContainerInjectableContext): """Injectable list of providers used by the configuration `resolve` module""" + providers: List[ConfigProvider] context_provider: ConfigProvider @@ -70,27 +84,36 @@ def _initial_providers() -> List[ConfigProvider]: providers = [ EnvironProvider(), SecretsTomlProvider(add_global_config=True), - ConfigTomlProvider(add_global_config=True) + ConfigTomlProvider(add_global_config=True), ] return providers def _extra_providers() -> List[ConfigProvider]: from dlt.common.configuration.resolve import resolve_configuration + providers_config = resolve_configuration(ConfigProvidersConfiguration()) extra_providers = 
[] if providers_config.enable_airflow_secrets: extra_providers.extend(_airflow_providers()) if providers_config.enable_google_secrets: - extra_providers.append(_google_secrets_provider(only_toml_fragments=providers_config.only_toml_fragments)) + extra_providers.append( + _google_secrets_provider(only_toml_fragments=providers_config.only_toml_fragments) + ) return extra_providers -def _google_secrets_provider(only_secrets: bool = True, only_toml_fragments: bool = True) -> ConfigProvider: +def _google_secrets_provider( + only_secrets: bool = True, only_toml_fragments: bool = True +) -> ConfigProvider: from dlt.common.configuration.resolve import resolve_configuration - c = resolve_configuration(GcpServiceAccountCredentials(), sections=(known_sections.PROVIDERS, "google_secrets")) - return GoogleSecretsProvider(c, only_secrets=only_secrets, only_toml_fragments=only_toml_fragments) + c = resolve_configuration( + GcpServiceAccountCredentials(), sections=(known_sections.PROVIDERS, "google_secrets") + ) + return GoogleSecretsProvider( + c, only_secrets=only_secrets, only_toml_fragments=only_toml_fragments + ) def _airflow_providers() -> List[ConfigProvider]: @@ -112,10 +135,13 @@ def _airflow_providers() -> List[ConfigProvider]: # hide stdio. airflow typically dumps tons of warnings and deprecations to stdout and stderr with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()): # try to get dlt secrets variable. many broken Airflow installations break here. in that case do not create - from airflow.models import Variable # noqa + from airflow.models import Variable # noqa + from dlt.common.configuration.providers.airflow import AirflowSecretsTomlProvider + # probe if Airflow variable containing all secrets is present from dlt.common.configuration.providers.toml import SECRETS_TOML_KEY + secrets_toml_var = Variable.get(SECRETS_TOML_KEY, default_var=None) # providers can be returned - mind that AirflowSecretsTomlProvider() requests the variable above immediately @@ -123,13 +149,18 @@ def _airflow_providers() -> List[ConfigProvider]: # check if we are in task context and provide more info from airflow.operators.python import get_current_context # noqa + ti = get_current_context()["ti"] # log outside of stderr/out redirect if secrets_toml_var is None: - message = f"Airflow variable '{SECRETS_TOML_KEY}' was not found. " + \ - "This Airflow variable is a recommended place to hold the content of secrets.toml." + \ - "If you do not use Airflow variables to hold dlt configuration or use variables with other names you can ignore this warning." + message = ( + f"Airflow variable '{SECRETS_TOML_KEY}' was not found. " + + "This Airflow variable is a recommended place to hold the content of" + " secrets.toml." + + "If you do not use Airflow variables to hold dlt configuration or use variables" + " with other names you can ignore this warning." 
+ ) ti.log.warning(message) except Exception: diff --git a/dlt/common/configuration/specs/config_section_context.py b/dlt/common/configuration/specs/config_section_context.py index e251b4f01b..ea1e7d0d94 100644 --- a/dlt/common/configuration/specs/config_section_context.py +++ b/dlt/common/configuration/specs/config_section_context.py @@ -1,11 +1,11 @@ -from typing import Callable, List, Optional, Tuple, TYPE_CHECKING -from dlt.common.configuration.specs import known_sections +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple +from dlt.common.configuration.specs import known_sections from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext, configspec + @configspec class ConfigSectionContext(ContainerInjectableContext): - TMergeFunc = Callable[["ConfigSectionContext", "ConfigSectionContext"], None] pipeline_name: Optional[str] @@ -13,7 +13,6 @@ class ConfigSectionContext(ContainerInjectableContext): merge_style: TMergeFunc = None source_state_key: str = None - def merge(self, existing: "ConfigSectionContext") -> None: """Merges existing context into incoming using a merge style function""" merge_style_f = self.merge_style or self.prefer_incoming @@ -40,15 +39,20 @@ def prefer_incoming(incoming: "ConfigSectionContext", existing: "ConfigSectionCo @staticmethod def prefer_existing(incoming: "ConfigSectionContext", existing: "ConfigSectionContext") -> None: """Prefer existing section context when merging this context before injecting""" - incoming.pipeline_name = existing.pipeline_name or incoming.pipeline_name - incoming.sections = existing.sections or incoming.sections - incoming.source_state_key = existing.source_state_key or incoming.source_state_key + incoming.pipeline_name = existing.pipeline_name or incoming.pipeline_name + incoming.sections = existing.sections or incoming.sections + incoming.source_state_key = existing.source_state_key or incoming.source_state_key @staticmethod - def resource_merge_style(incoming: "ConfigSectionContext", existing: "ConfigSectionContext") -> None: + def resource_merge_style( + incoming: "ConfigSectionContext", existing: "ConfigSectionContext" + ) -> None: """If top level section is same and there are 3 sections it replaces second element (source module) from existing and keeps the 3rd element (name)""" incoming.pipeline_name = incoming.pipeline_name or existing.pipeline_name - if len(incoming.sections) == 3 == len(existing.sections) and incoming.sections[0] == existing.sections[0]: + if ( + len(incoming.sections) == 3 == len(existing.sections) + and incoming.sections[0] == existing.sections[0] + ): incoming.sections = (incoming.sections[0], existing.sections[1], incoming.sections[2]) incoming.source_state_key = existing.source_state_key or incoming.source_state_key else: @@ -56,9 +60,18 @@ def resource_merge_style(incoming: "ConfigSectionContext", existing: "ConfigSect incoming.source_state_key = incoming.source_state_key or existing.source_state_key def __str__(self) -> str: - return super().__str__() + f": {self.pipeline_name} {self.sections}@{self.merge_style} state['{self.source_state_key}']" + return ( + super().__str__() + + f": {self.pipeline_name} {self.sections}@{self.merge_style} state['{self.source_state_key}']" + ) if TYPE_CHECKING: # provide __init__ signature when type checking - def __init__(self, pipeline_name:str = None, sections: Tuple[str, ...] 
= (), merge_style: TMergeFunc = None, source_state_key: str = None) -> None: + def __init__( + self, + pipeline_name: str = None, + sections: Tuple[str, ...] = (), + merge_style: TMergeFunc = None, + source_state_key: str = None, + ) -> None: ... diff --git a/dlt/common/configuration/specs/connection_string_credentials.py b/dlt/common/configuration/specs/connection_string_credentials.py index 386535122b..3df091e9ca 100644 --- a/dlt/common/configuration/specs/connection_string_credentials.py +++ b/dlt/common/configuration/specs/connection_string_credentials.py @@ -1,9 +1,10 @@ from typing import Any, ClassVar, Dict, List, Optional + from sqlalchemy.engine import URL, make_url -from dlt.common.configuration.specs.exceptions import InvalidConnectionString -from dlt.common.typing import TSecretValue from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec +from dlt.common.configuration.specs.exceptions import InvalidConnectionString +from dlt.common.typing import TSecretValue @configspec @@ -24,9 +25,7 @@ def parse_native_representation(self, native_value: Any) -> None: try: url = make_url(native_value) # update only values that are not None - self.update( - {k: v for k,v in url._asdict().items() if v is not None} - ) + self.update({k: v for k, v in url._asdict().items() if v is not None}) if self.query is not None: self.query = dict(self.query) except Exception: @@ -40,7 +39,15 @@ def to_native_representation(self) -> str: return self.to_url().render_as_string(hide_password=False) def to_url(self) -> URL: - return URL.create(self.drivername, self.username, self.password, self.host, self.port, self.database, self.query) + return URL.create( + self.drivername, + self.username, + self.password, + self.host, + self.port, + self.database, + self.query, + ) def __str__(self) -> str: return self.to_url().render_as_string(hide_password=True) diff --git a/dlt/common/configuration/specs/exceptions.py b/dlt/common/configuration/specs/exceptions.py index 054d21c78c..f1cb528be2 100644 --- a/dlt/common/configuration/specs/exceptions.py +++ b/dlt/common/configuration/specs/exceptions.py @@ -1,4 +1,5 @@ from typing import Any, Type + from dlt.common.configuration.exceptions import ConfigurationException @@ -9,7 +10,10 @@ class SpecException(ConfigurationException): class OAuth2ScopesRequired(SpecException): def __init__(self, spec: type) -> None: self.spec = spec - super().__init__("Scopes are required to retrieve refresh_token. Use 'openid' scope for a token without any permissions to resources.") + super().__init__( + "Scopes are required to retrieve refresh_token. Use 'openid' scope for a token without" + " any permissions to resources." + ) class NativeValueError(SpecException, ValueError): @@ -22,29 +26,46 @@ def __init__(self, spec: Type[Any], native_value: str, msg: str) -> None: class InvalidConnectionString(NativeValueError): def __init__(self, spec: Type[Any], native_value: str, driver: str): driver = driver or "driver" - msg = f"The expected representation for {spec.__name__} is a standard database connection string with the following format: {driver}://username:password@host:port/database." + msg = ( + f"The expected representation for {spec.__name__} is a standard database connection" + f" string with the following format: {driver}://username:password@host:port/database." 
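
A connection string of exactly that shape is what ConnectionStringCredentials, reformatted just above, accepts as its native representation. A short sketch with made-up credentials:

    from dlt.common.configuration.specs import ConnectionStringCredentials

    creds = ConnectionStringCredentials()
    creds.parse_native_representation("postgresql://loader:secret@localhost:5432/dlt_data")
    assert creds.drivername == "postgresql" and creds.port == 5432
    print(str(creds))                        # renders the URL with the password hidden
    print(creds.to_native_representation())  # full connection string, password included
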
+ ) super().__init__(spec, native_value, msg) class InvalidGoogleNativeCredentialsType(NativeValueError): def __init__(self, spec: Type[Any], native_value: Any): - msg = f"Credentials {spec.__name__} accept a string with serialized credentials json file or an instance of Credentials object from google.* namespace. The value passed is of type {type(native_value)}" + msg = ( + f"Credentials {spec.__name__} accept a string with serialized credentials json file or" + " an instance of Credentials object from google.* namespace. The value passed is of" + f" type {type(native_value)}" + ) super().__init__(spec, native_value, msg) class InvalidGoogleServicesJson(NativeValueError): def __init__(self, spec: Type[Any], native_value: Any): - msg = f"The expected representation for {spec.__name__} is a string with serialized service account credentials, where at least 'project_id', 'private_key' and 'client_email` keys are present" + msg = ( + f"The expected representation for {spec.__name__} is a string with serialized service" + " account credentials, where at least 'project_id', 'private_key' and 'client_email`" + " keys are present" + ) super().__init__(spec, native_value, msg) class InvalidGoogleOauth2Json(NativeValueError): def __init__(self, spec: Type[Any], native_value: Any): - msg = f"The expected representation for {spec.__name__} is a string with serialized oauth2 user info and may be wrapped in 'install'/'web' node - depending of oauth2 app type." + msg = ( + f"The expected representation for {spec.__name__} is a string with serialized oauth2" + " user info and may be wrapped in 'install'/'web' node - depending of oauth2 app type." + ) super().__init__(spec, native_value, msg) class InvalidBoto3Session(NativeValueError): def __init__(self, spec: Type[Any], native_value: Any): - msg = f"The expected representation for {spec.__name__} is and instance of boto3.Session containing credentials" + msg = ( + f"The expected representation for {spec.__name__} is and instance of boto3.Session" + " containing credentials" + ) super().__init__(spec, native_value, msg) diff --git a/dlt/common/configuration/specs/gcp_credentials.py b/dlt/common/configuration/specs/gcp_credentials.py index f96c1d44f5..d8637b7741 100644 --- a/dlt/common/configuration/specs/gcp_credentials.py +++ b/dlt/common/configuration/specs/gcp_credentials.py @@ -1,13 +1,24 @@ import sys from typing import Any, Final, List, Tuple, Union + from deprecated import deprecated from dlt.common import json, pendulum from dlt.common.configuration.specs.api_credentials import OAuth2Credentials -from dlt.common.configuration.specs.exceptions import InvalidGoogleNativeCredentialsType, InvalidGoogleOauth2Json, InvalidGoogleServicesJson, NativeValueError, OAuth2ScopesRequired +from dlt.common.configuration.specs.base_configuration import ( + CredentialsConfiguration, + CredentialsWithDefault, + configspec, +) +from dlt.common.configuration.specs.exceptions import ( + InvalidGoogleNativeCredentialsType, + InvalidGoogleOauth2Json, + InvalidGoogleServicesJson, + NativeValueError, + OAuth2ScopesRequired, +) from dlt.common.exceptions import MissingDependencyException -from dlt.common.typing import DictStrAny, TSecretValue, StrAny -from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, CredentialsWithDefault, configspec +from dlt.common.typing import DictStrAny, StrAny, TSecretValue from dlt.common.utils import is_interactive @@ -18,7 +29,9 @@ class GcpCredentials(CredentialsConfiguration): project_id: str = None - 
location: str = "US" # DEPRECATED! and present only for backward compatibility. please set bigquery location in BigQuery configuration + location: str = ( # DEPRECATED! and present only for backward compatibility. please set bigquery location in BigQuery configuration + "US" + ) def parse_native_representation(self, native_value: Any) -> None: if not isinstance(native_value, str): @@ -49,12 +62,13 @@ def parse_native_representation(self, native_value: Any) -> None: service_dict: DictStrAny = None try: from google.oauth2.service_account import Credentials as ServiceAccountCredentials + if isinstance(native_value, ServiceAccountCredentials): # extract credentials service_dict = { "project_id": native_value.project_id, "client_email": native_value.service_account_email, - "private_key": native_value # keep native credentials in private key + "private_key": native_value, # keep native credentials in private key } self.__is_resolved__ = True except ImportError: @@ -84,6 +98,7 @@ def to_native_credentials(self) -> Any: """Returns google.oauth2.service_account.Credentials""" from google.oauth2.service_account import Credentials as ServiceAccountCredentials + if isinstance(self.private_key, ServiceAccountCredentials): # private key holds the native instance if it was passed to parse_native_representation return self.private_key @@ -105,6 +120,7 @@ def parse_native_representation(self, native_value: Any) -> None: oauth_dict: DictStrAny = None try: from google.oauth2.credentials import Credentials as GoogleOAuth2Credentials + if isinstance(native_value, GoogleOAuth2Credentials): # extract credentials, project id may not be present oauth_dict = { @@ -113,7 +129,7 @@ def parse_native_representation(self, native_value: Any) -> None: "client_secret": native_value.client_secret, "refresh_token": native_value.refresh_token, "scopes": native_value.scopes, - "token": native_value.token + "token": native_value.token, } # if token is present, we are logged in self.__is_resolved__ = native_value.token is not None @@ -141,8 +157,12 @@ def auth(self, scopes: Union[str, List[str]] = None, redirect_url: str = None) - self.add_scopes(scopes) if not self.scopes: raise OAuth2ScopesRequired(self.__class__) - assert sys.stdin.isatty() or is_interactive(), "Must have a tty or interactive mode for web flow" - self.refresh_token, self.token = self._get_refresh_token(redirect_url or "http://localhost") + assert ( + sys.stdin.isatty() or is_interactive() + ), "Must have a tty or interactive mode for web flow" + self.refresh_token, self.token = self._get_refresh_token( + redirect_url or "http://localhost" + ) else: # if scopes or redirect_url: # logger.warning("Please note that scopes and redirect_url are ignored when getting access token") @@ -164,11 +184,10 @@ def _get_access_token(self) -> TSecretValue: raise MissingDependencyException("GcpOAuthCredentials", ["requests_oauthlib"]) google = OAuth2Session(client_id=self.client_id, scope=self.scopes) - extra = { - "client_id": self.client_id, - "client_secret": self.client_secret - } - token = google.refresh_token(token_url=self.token_uri, refresh_token=self.refresh_token, **extra)["access_token"] + extra = {"client_id": self.client_id, "client_secret": self.client_secret} + token = google.refresh_token( + token_url=self.token_uri, refresh_token=self.refresh_token, **extra + )["access_token"] return TSecretValue(token) def _get_refresh_token(self, redirect_url: str) -> Tuple[TSecretValue, TSecretValue]: @@ -191,9 +210,7 @@ def to_native_credentials(self) -> Any: return 
credentials def _installed_dict(self, redirect_url: str = "http://localhost") -> StrAny: - installed_dict = { - self.client_type: self._info_dict() - } + installed_dict = {self.client_type: self._info_dict()} if redirect_url: installed_dict[self.client_type]["redirect_uris"] = [redirect_url] @@ -211,13 +228,13 @@ def __str__(self) -> str: @configspec class GcpDefaultCredentials(CredentialsWithDefault, GcpCredentials): - _LAST_FAILED_DEFAULT: float = 0.0 def parse_native_representation(self, native_value: Any) -> None: """Accepts google credentials as native value""" try: from google.auth.credentials import Credentials as GoogleCredentials + if isinstance(native_value, GoogleCredentials): self.project_id = self.project_id or native_value.quota_project_id self._set_default_credentials(native_value) @@ -226,11 +243,12 @@ def parse_native_representation(self, native_value: Any) -> None: return except ImportError: pass - raise NativeValueError(self.__class__, native_value, "Default Google Credentials not present") + raise NativeValueError( + self.__class__, native_value, "Default Google Credentials not present" + ) @staticmethod def _get_default_credentials(retry_timeout_s: float = 600.0) -> Tuple[Any, str]: - now = pendulum.now().timestamp() if now - GcpDefaultCredentials._LAST_FAILED_DEFAULT < retry_timeout_s: return None, None @@ -268,7 +286,9 @@ def to_native_credentials(self) -> Any: @configspec -class GcpServiceAccountCredentials(GcpDefaultCredentials, GcpServiceAccountCredentialsWithoutDefaults): +class GcpServiceAccountCredentials( + GcpDefaultCredentials, GcpServiceAccountCredentialsWithoutDefaults +): def parse_native_representation(self, native_value: Any) -> None: try: GcpDefaultCredentials.parse_native_representation(self, native_value) diff --git a/dlt/common/configuration/specs/known_sections.py b/dlt/common/configuration/specs/known_sections.py index 3663bb0d19..7107099072 100644 --- a/dlt/common/configuration/specs/known_sections.py +++ b/dlt/common/configuration/specs/known_sections.py @@ -19,5 +19,5 @@ DATA_WRITER = "data_writer" """default section holding BufferedDataWriter settings""" -DBT_PACKAGE_RUNNER = "dbt_package_runner" +DBT_PACKAGE_RUNNER = "dbt_package_runner" """dbt package runner configuration (DBTRunnerConfiguration)""" diff --git a/dlt/common/configuration/specs/run_configuration.py b/dlt/common/configuration/specs/run_configuration.py index 2ec3648dbe..6b6ca06214 100644 --- a/dlt/common/configuration/specs/run_configuration.py +++ b/dlt/common/configuration/specs/run_configuration.py @@ -1,12 +1,12 @@ import binascii from os.path import isfile, join from pathlib import Path -from typing import Any, Optional, Tuple, IO -from dlt.common.typing import TSecretStrValue +from typing import IO, Any, Optional, Tuple -from dlt.common.utils import encoding_for_mode, main_module_file_path, reveal_pseudo_secret -from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec from dlt.common.configuration.exceptions import ConfigFileNotFoundException +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec +from dlt.common.typing import TSecretStrValue +from dlt.common.utils import encoding_for_mode, main_module_file_path, reveal_pseudo_secret @configspec @@ -16,7 +16,9 @@ class RunConfiguration(BaseConfiguration): slack_incoming_hook: Optional[TSecretStrValue] = None dlthub_telemetry: bool = True # enable or disable dlthub telemetry dlthub_telemetry_segment_write_key: str = "a1F2gc6cNYw2plyAt02sZouZcsRjG7TD" 
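
For the gcp_credentials.py hunks above: the service-account spec takes the serialized service-account JSON as its native form. A hedged sketch with placeholder values, not part of this patch:

    from dlt.common.configuration.specs import GcpServiceAccountCredentialsWithoutDefaults

    # hypothetical, heavily truncated service-account JSON; real files carry more keys
    service_json = """{
      "project_id": "my-project",
      "private_key": "-----BEGIN PRIVATE KEY-----\\n...\\n-----END PRIVATE KEY-----\\n",
      "client_email": "loader@my-project.iam.gserviceaccount.com"
    }"""

    creds = GcpServiceAccountCredentialsWithoutDefaults()
    # the InvalidGoogleServicesJson message above lists project_id, private_key and
    # client_email as the minimum keys expected in this representation
    creds.parse_native_representation(service_json)
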
- log_format: str = '{asctime}|[{levelname:<21}]|{process}|{name}|{filename}|{funcName}:{lineno}|{message}' + log_format: str = ( + "{asctime}|[{levelname:<21}]|{process}|{name}|{filename}|{funcName}:{lineno}|{message}" + ) log_level: str = "WARNING" request_timeout: float = 60 """Timeout for http requests""" @@ -38,7 +40,9 @@ def on_resolved(self) -> None: # it may be obfuscated base64 value # TODO: that needs to be removed ASAP try: - self.slack_incoming_hook = TSecretStrValue(reveal_pseudo_secret(self.slack_incoming_hook, b"dlt-runtime-2022")) + self.slack_incoming_hook = TSecretStrValue( + reveal_pseudo_secret(self.slack_incoming_hook, b"dlt-runtime-2022") + ) except binascii.Error: # just keep the original value pass diff --git a/dlt/common/configuration/utils.py b/dlt/common/configuration/utils.py index 4841c8e3fa..4d22534c23 100644 --- a/dlt/common/configuration/utils.py +++ b/dlt/common/configuration/utils.py @@ -1,16 +1,20 @@ -import os import ast import contextlib -import tomlkit -from typing import Any, Dict, Mapping, NamedTuple, Optional, Tuple, Type, Sequence +import os from collections.abc import Mapping as C_Mapping +from typing import Any, Dict, Mapping, NamedTuple, Optional, Sequence, Tuple, Type + +import tomlkit from dlt.common import json -from dlt.common.typing import AnyType, TAny -from dlt.common.data_types import coerce_value, py_type_to_sc_type -from dlt.common.configuration.providers import EnvironProvider from dlt.common.configuration.exceptions import ConfigValueCannotBeCoercedException, LookupTrace -from dlt.common.configuration.specs.base_configuration import BaseConfiguration, is_base_configuration_inner_hint +from dlt.common.configuration.providers import EnvironProvider +from dlt.common.configuration.specs.base_configuration import ( + BaseConfiguration, + is_base_configuration_inner_hint, +) +from dlt.common.data_types import coerce_value, py_type_to_sc_type +from dlt.common.typing import AnyType, TAny class ResolvedValueTrace(NamedTuple): @@ -111,40 +115,56 @@ def auto_cast(value: str) -> Any: return value - -def log_traces(config: Optional[BaseConfiguration], key: str, hint: Type[Any], value: Any, default_value: Any, traces: Sequence[LookupTrace]) -> None: +def log_traces( + config: Optional[BaseConfiguration], + key: str, + hint: Type[Any], + value: Any, + default_value: Any, + traces: Sequence[LookupTrace], +) -> None: from dlt.common import logger # if logger.is_logging() and logger.log_level() == "DEBUG" and config: # logger.debug(f"Field {key} with type {hint} in {type(config).__name__} {'NOT RESOLVED' if value is None else 'RESOLVED'}") - # print(f"Field {key} with type {hint} in {type(config).__name__} {'NOT RESOLVED' if value is None else 'RESOLVED'}") - # for tr in traces: - # # print(str(tr)) - # logger.debug(str(tr)) + # print(f"Field {key} with type {hint} in {type(config).__name__} {'NOT RESOLVED' if value is None else 'RESOLVED'}") + # for tr in traces: + # # print(str(tr)) + # logger.debug(str(tr)) # store all traces with resolved values resolved_trace = next((trace for trace in traces if trace.value is not None), None) if resolved_trace is not None: path = f'{".".join(resolved_trace.sections)}.{key}' - _RESOLVED_TRACES[path] = ResolvedValueTrace(key, resolved_trace.value, default_value, hint, resolved_trace.sections, resolved_trace.provider, config) + _RESOLVED_TRACES[path] = ResolvedValueTrace( + key, + resolved_trace.value, + default_value, + hint, + resolved_trace.sections, + resolved_trace.provider, + config, + ) def 
get_resolved_traces() -> Dict[str, ResolvedValueTrace]: return _RESOLVED_TRACES -def add_config_to_env(config: BaseConfiguration, sections: Tuple[str, ...] = ()) -> None: +def add_config_to_env(config: BaseConfiguration, sections: Tuple[str, ...] = ()) -> None: """Writes values in configuration back into environment using the naming convention of EnvironProvider. Will descend recursively if embedded BaseConfiguration instances are found""" if config.__section__: - sections += (config.__section__, ) + sections += (config.__section__,) return add_config_dict_to_env(dict(config), sections, overwrite_keys=True) -def add_config_dict_to_env(dict_: Mapping[str, Any], sections: Tuple[str, ...] = (), overwrite_keys: bool = False) -> None: +def add_config_dict_to_env( + dict_: Mapping[str, Any], sections: Tuple[str, ...] = (), overwrite_keys: bool = False +) -> None: """Writes values in dict_ back into environment using the naming convention of EnvironProvider. Applies `sections` if specified. Does not overwrite existing keys by default""" for k, v in dict_.items(): if isinstance(v, BaseConfiguration): if not v.__section__: - embedded_sections = sections + (k, ) + embedded_sections = sections + (k,) else: embedded_sections = sections add_config_to_env(v, embedded_sections) diff --git a/dlt/common/data_types/__init__.py b/dlt/common/data_types/__init__.py index 9ad4df37fe..8a12815479 100644 --- a/dlt/common/data_types/__init__.py +++ b/dlt/common/data_types/__init__.py @@ -1,2 +1,2 @@ from dlt.common.data_types.type_helpers import coerce_value, py_type_to_sc_type -from dlt.common.data_types.typing import TDataType, DATA_TYPES \ No newline at end of file +from dlt.common.data_types.typing import DATA_TYPES, TDataType diff --git a/dlt/common/data_types/type_helpers.py b/dlt/common/data_types/type_helpers.py index d040471a46..6bcd8a6c2a 100644 --- a/dlt/common/data_types/type_helpers.py +++ b/dlt/common/data_types/type_helpers.py @@ -1,15 +1,16 @@ -import binascii import base64 +import binascii import datetime # noqa: I251 -from collections.abc import Mapping as C_Mapping, Sequence as C_Sequence -from typing import Any, Type, Literal, Union, cast +from collections.abc import Mapping as C_Mapping +from collections.abc import Sequence as C_Sequence +from typing import Any, Literal, Type, Union, cast -from dlt.common import pendulum, json, Decimal, Wei -from dlt.common.json import custom_pua_remove -from dlt.common.json._simplejson import custom_encode as json_custom_encode +from dlt.common import Decimal, Wei, json, pendulum from dlt.common.arithmetics import InvalidOperation from dlt.common.data_types.typing import TDataType -from dlt.common.time import ensure_pendulum_datetime, parse_iso_like_datetime, ensure_pendulum_date +from dlt.common.json import custom_pua_remove +from dlt.common.json._simplejson import custom_encode as json_custom_encode +from dlt.common.time import ensure_pendulum_date, ensure_pendulum_datetime, parse_iso_like_datetime from dlt.common.utils import map_nested_in_place, str2bool @@ -137,7 +138,7 @@ def coerce_value(to_type: TDataType, from_type: TDataType, value: Any) -> Any: except binascii.Error: raise ValueError(value) if from_type == "bigint": - return value.to_bytes((value.bit_length() + 7) // 8, 'little') + return value.to_bytes((value.bit_length() + 7) // 8, "little") if to_type == "bigint": if from_type in ["wei", "decimal", "double"]: diff --git a/dlt/common/data_types/typing.py b/dlt/common/data_types/typing.py index 727be6fa58..ab612d08a2 100644 --- 
a/dlt/common/data_types/typing.py +++ b/dlt/common/data_types/typing.py @@ -1,5 +1,6 @@ from typing import Literal, Set, get_args - -TDataType = Literal["text", "double", "bool", "timestamp", "bigint", "binary", "complex", "decimal", "wei", "date"] +TDataType = Literal[ + "text", "double", "bool", "timestamp", "bigint", "binary", "complex", "decimal", "wei", "date" +] DATA_TYPES: Set[TDataType] = set(get_args(TDataType)) diff --git a/dlt/common/data_writers/__init__.py b/dlt/common/data_writers/__init__.py index 89d4607c90..1eaca3b152 100644 --- a/dlt/common/data_writers/__init__.py +++ b/dlt/common/data_writers/__init__.py @@ -1,3 +1,7 @@ -from dlt.common.data_writers.writers import DataWriter, TLoaderFileFormat from dlt.common.data_writers.buffered import BufferedDataWriter -from dlt.common.data_writers.escape import escape_redshift_literal, escape_redshift_identifier, escape_bigquery_identifier \ No newline at end of file +from dlt.common.data_writers.escape import ( + escape_bigquery_identifier, + escape_redshift_identifier, + escape_redshift_literal, +) +from dlt.common.data_writers.writers import DataWriter, TLoaderFileFormat diff --git a/dlt/common/data_writers/buffered.py b/dlt/common/data_writers/buffered.py index 067631b935..c1f85cc9a8 100644 --- a/dlt/common/data_writers/buffered.py +++ b/dlt/common/data_writers/buffered.py @@ -1,19 +1,22 @@ import gzip -from typing import List, IO, Any, Optional, Type +from typing import IO, Any, List, Optional, Type -from dlt.common.utils import uniq_id -from dlt.common.typing import TDataItem, TDataItems +from dlt.common.configuration import configspec, known_sections, with_config +from dlt.common.configuration.specs import BaseConfiguration from dlt.common.data_writers import TLoaderFileFormat -from dlt.common.data_writers.exceptions import BufferedDataWriterClosed, DestinationCapabilitiesRequired, InvalidFileNameTemplateException +from dlt.common.data_writers.exceptions import ( + BufferedDataWriterClosed, + DestinationCapabilitiesRequired, + InvalidFileNameTemplateException, +) from dlt.common.data_writers.writers import DataWriter -from dlt.common.schema.typing import TTableSchemaColumns -from dlt.common.configuration import with_config, known_sections, configspec -from dlt.common.configuration.specs import BaseConfiguration from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.schema.typing import TTableSchemaColumns +from dlt.common.typing import TDataItem, TDataItems +from dlt.common.utils import uniq_id class BufferedDataWriter: - @configspec class BufferedDataWriterConfiguration(BaseConfiguration): buffer_max_items: int = 5000 @@ -24,7 +27,6 @@ class BufferedDataWriterConfiguration(BaseConfiguration): __section__ = known_sections.DATA_WRITER - @with_config(spec=BufferedDataWriterConfiguration) def __init__( self, @@ -50,7 +52,11 @@ def __init__( self.file_max_bytes = file_max_bytes self.file_max_items = file_max_items # the open function is either gzip.open or open - self.open = gzip.open if self._file_format_spec.supports_compression and not disable_compression else open + self.open = ( + gzip.open + if self._file_format_spec.supports_compression and not disable_compression + else open + ) self._current_columns: TTableSchemaColumns = None self._file_name: str = None @@ -67,7 +73,11 @@ def write_data_item(self, item: TDataItems, columns: TTableSchemaColumns) -> Non self._ensure_open() # rotate file if columns changed and writer does not allow for that # as the only allowed change is to add new column (no 
updates/deletes), we detect the change by comparing lengths - if self._writer and not self._writer.data_format().supports_schema_changes and len(columns) != len(self._current_columns): + if ( + self._writer + and not self._writer.data_format().supports_schema_changes + and len(columns) != len(self._current_columns) + ): assert len(columns) > len(self._current_columns) self._rotate_file() # until the first chunk is written we can change the columns schema freely @@ -112,7 +122,9 @@ def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb def _rotate_file(self) -> None: self._flush_and_close_file() - self._file_name = self.file_name_template % uniq_id(5) + "." + self._file_format_spec.file_extension + self._file_name = ( + self.file_name_template % uniq_id(5) + "." + self._file_format_spec.file_extension + ) def _flush_items(self, allow_empty_file: bool = False) -> None: if len(self._buffered_items) > 0 or allow_empty_file: @@ -120,10 +132,12 @@ def _flush_items(self, allow_empty_file: bool = False) -> None: if not self._writer: # create new writer and write header if self._file_format_spec.is_binary_format: - self._file = self.open(self._file_name, "wb") # type: ignore + self._file = self.open(self._file_name, "wb") # type: ignore else: - self._file = self.open(self._file_name, "wt", encoding="utf-8") # type: ignore - self._writer = DataWriter.from_file_format(self.file_format, self._file, caps=self._caps) + self._file = self.open(self._file_name, "wt", encoding="utf-8") # type: ignore + self._writer = DataWriter.from_file_format( + self.file_format, self._file, caps=self._caps + ) self._writer.write_header(self._current_columns) # write buffer if self._buffered_items: diff --git a/dlt/common/data_writers/escape.py b/dlt/common/data_writers/escape.py index c8e07ea45a..e04ccbc3d3 100644 --- a/dlt/common/data_writers/escape.py +++ b/dlt/common/data_writers/escape.py @@ -1,16 +1,19 @@ -import re import base64 -from typing import Any +import re from datetime import date, datetime # noqa: I251 +from typing import Any from dlt.common.json import json # use regex to escape characters in single pass SQL_ESCAPE_DICT = {"'": "''", "\\": "\\\\", "\n": "\\n", "\r": "\\r"} -SQL_ESCAPE_RE = re.compile("|".join([re.escape(k) for k in sorted(SQL_ESCAPE_DICT, key=len, reverse=True)]), flags=re.DOTALL) +SQL_ESCAPE_RE = re.compile( + "|".join([re.escape(k) for k in sorted(SQL_ESCAPE_DICT, key=len, reverse=True)]), + flags=re.DOTALL, +) -def _escape_extended(v: str, prefix:str = "E'") -> str: +def _escape_extended(v: str, prefix: str = "E'") -> str: return "{}{}{}".format(prefix, SQL_ESCAPE_RE.sub(lambda x: SQL_ESCAPE_DICT[x.group(0)], v), "'") @@ -25,7 +28,7 @@ def escape_redshift_literal(v: Any) -> Any: if isinstance(v, (datetime, date)): return f"'{v.isoformat()}'" if isinstance(v, (list, dict)): - return "json_parse(%s)" % _escape_extended(json.dumps(v), prefix='\'') + return "json_parse(%s)" % _escape_extended(json.dumps(v), prefix="'") return str(v) @@ -68,7 +71,7 @@ def escape_redshift_identifier(v: str) -> str: def escape_bigquery_identifier(v: str) -> str: # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical - return "`" + v.replace("\\", "\\\\").replace("`","\\`") + "`" + return "`" + v.replace("\\", "\\\\").replace("`", "\\`") + "`" def escape_snowflake_identifier(v: str) -> str: diff --git a/dlt/common/data_writers/exceptions.py b/dlt/common/data_writers/exceptions.py index a86bd9440e..d3a073cf4e 100644 --- a/dlt/common/data_writers/exceptions.py +++ 
b/dlt/common/data_writers/exceptions.py @@ -9,7 +9,10 @@ class DataWriterException(DltException): class InvalidFileNameTemplateException(DataWriterException, ValueError): def __init__(self, file_name_template: str): self.file_name_template = file_name_template - super().__init__(f"Wrong file name template {file_name_template}. File name template must contain exactly one %s formatter") + super().__init__( + f"Wrong file name template {file_name_template}. File name template must contain" + " exactly one %s formatter" + ) class BufferedDataWriterClosed(DataWriterException): @@ -21,4 +24,6 @@ def __init__(self, file_name: str): class DestinationCapabilitiesRequired(DataWriterException, ValueError): def __init__(self, file_format: TLoaderFileFormat): self.file_format = file_format - super().__init__(f"Writer for {file_format} requires destination capabilities which were not provided.") + super().__init__( + f"Writer for {file_format} requires destination capabilities which were not provided." + ) diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index fdffd3dc30..2f20321e89 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -1,14 +1,14 @@ import abc - from dataclasses import dataclass -from typing import Any, Dict, Sequence, IO, Type, Optional, List, cast +from typing import IO, Any, Dict, List, Optional, Sequence, Type, cast from dlt.common import json -from dlt.common.typing import StrAny -from dlt.common.schema.typing import TTableSchemaColumns -from dlt.common.destination import TLoaderFileFormat, DestinationCapabilitiesContext -from dlt.common.configuration import with_config, known_sections, configspec +from dlt.common.configuration import configspec, known_sections, with_config from dlt.common.configuration.specs import BaseConfiguration +from dlt.common.destination import DestinationCapabilitiesContext, TLoaderFileFormat +from dlt.common.schema.typing import TTableSchemaColumns +from dlt.common.typing import StrAny + @dataclass class TFileFormatSpec: @@ -42,18 +42,21 @@ def write_all(self, columns_schema: TTableSchemaColumns, rows: Sequence[Any]) -> self.write_data(rows) self.write_footer() - @classmethod @abc.abstractmethod def data_format(cls) -> TFileFormatSpec: pass @classmethod - def from_file_format(cls, file_format: TLoaderFileFormat, f: IO[Any], caps: DestinationCapabilitiesContext = None) -> "DataWriter": + def from_file_format( + cls, file_format: TLoaderFileFormat, f: IO[Any], caps: DestinationCapabilitiesContext = None + ) -> "DataWriter": return cls.class_factory(file_format)(f, caps) @classmethod - def from_destination_capabilities(cls, caps: DestinationCapabilitiesContext, f: IO[Any]) -> "DataWriter": + def from_destination_capabilities( + cls, caps: DestinationCapabilitiesContext, f: IO[Any] + ) -> "DataWriter": return cls.class_factory(caps.preferred_loader_file_format)(f, caps) @classmethod @@ -75,7 +78,6 @@ def class_factory(file_format: TLoaderFileFormat) -> Type["DataWriter"]: class JsonlWriter(DataWriter): - def write_header(self, columns_schema: TTableSchemaColumns) -> None: pass @@ -100,7 +102,6 @@ def data_format(cls) -> TFileFormatSpec: class JsonlListPUAEncodeWriter(JsonlWriter): - def write_data(self, rows: Sequence[Any]) -> None: # skip JsonlWriter when calling super super(JsonlWriter, self).write_data(rows) @@ -121,7 +122,6 @@ def data_format(cls) -> TFileFormatSpec: class InsertValuesWriter(DataWriter): - def __init__(self, f: IO[Any], caps: DestinationCapabilitiesContext = None) 
-> None: super().__init__(f, caps) self._chunks_written = 0 @@ -143,7 +143,7 @@ def write_data(self, rows: Sequence[Any]) -> None: def write_row(row: StrAny) -> None: output = ["NULL"] * len(self._headers_lookup) - for n,v in row.items(): + for n, v in row.items(): output[self._headers_lookup[n]] = self._caps.escape_literal(v) self._f.write("(") self._f.write(",".join(output)) @@ -188,18 +188,19 @@ class ParquetDataWriterConfiguration(BaseConfiguration): __section__: str = known_sections.DATA_WRITER -class ParquetDataWriter(DataWriter): +class ParquetDataWriter(DataWriter): @with_config(spec=ParquetDataWriterConfiguration) - def __init__(self, - f: IO[Any], - caps: DestinationCapabilitiesContext = None, - *, - flavor: str = "spark", - version: str = "2.4", - data_page_size: int = 1024 * 1024, - timestamp_timezone: str = "UTC" - ) -> None: + def __init__( + self, + f: IO[Any], + caps: DestinationCapabilitiesContext = None, + *, + flavor: str = "spark", + version: str = "2.4", + data_page_size: int = 1024 * 1024, + timestamp_timezone: str = "UTC" + ) -> None: super().__init__(f, caps) from dlt.common.libs.pyarrow import pyarrow @@ -212,20 +213,32 @@ def __init__(self, self.timestamp_timezone = timestamp_timezone def write_header(self, columns_schema: TTableSchemaColumns) -> None: - from dlt.common.libs.pyarrow import pyarrow, get_py_arrow_datatype + from dlt.common.libs.pyarrow import get_py_arrow_datatype, pyarrow # build schema self.schema = pyarrow.schema( - [pyarrow.field( - name, - get_py_arrow_datatype(schema_item["data_type"], self._caps, self.timestamp_timezone), - nullable=schema_item["nullable"] - ) for name, schema_item in columns_schema.items()] + [ + pyarrow.field( + name, + get_py_arrow_datatype( + schema_item["data_type"], self._caps, self.timestamp_timezone + ), + nullable=schema_item["nullable"], + ) + for name, schema_item in columns_schema.items() + ] ) # find row items that are of the complex type (could be abstracted out for use in other writers?) 
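As a standalone illustration of what the reformatted write_header above does (not part of this patch), the sketch below builds a comparable pyarrow schema by hand and opens a ParquetWriter with the same options the writer forwards. The simplified type map and the output file name are assumptions for the example; dlt itself resolves arrow types via get_py_arrow_datatype() with destination capabilities and a timestamp timezone.

import pyarrow as pa
import pyarrow.parquet as pq

# assumed, simplified mapping for the example; dlt derives arrow types via get_py_arrow_datatype()
SIMPLE_TYPE_MAP = {"bigint": pa.int64(), "text": pa.string(), "bool": pa.bool_()}

# a dlt-style columns schema: column name -> {"data_type": ..., "nullable": ...}
columns_schema = {
    "id": {"data_type": "bigint", "nullable": False},
    "name": {"data_type": "text", "nullable": True},
}

# build the arrow schema field by field, as write_header() does
schema = pa.schema(
    [
        pa.field(name, SIMPLE_TYPE_MAP[col["data_type"]], nullable=col["nullable"])
        for name, col in columns_schema.items()
    ]
)

# open a writer with the flavor/version/page size options the patch passes through
with pq.ParquetWriter(
    "example.parquet", schema, flavor="spark", version="2.4", data_page_size=1024 * 1024
) as writer:
    writer.write_table(pa.Table.from_pylist([{"id": 1, "name": "a"}], schema=schema))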
- self.complex_indices = [i for i, field in columns_schema.items() if field["data_type"] == "complex"] - self.writer = pyarrow.parquet.ParquetWriter(self._f, self.schema, flavor=self.parquet_flavor, version=self.parquet_version, data_page_size=self.parquet_data_page_size) - + self.complex_indices = [ + i for i, field in columns_schema.items() if field["data_type"] == "complex" + ] + self.writer = pyarrow.parquet.ParquetWriter( + self._f, + self.schema, + flavor=self.parquet_flavor, + version=self.parquet_version, + data_page_size=self.parquet_data_page_size, + ) def write_data(self, rows: Sequence[Any]) -> None: super().write_data(rows) @@ -245,7 +258,13 @@ def write_footer(self) -> None: self.writer.close() self.writer = None - @classmethod def data_format(cls) -> TFileFormatSpec: - return TFileFormatSpec("parquet", "parquet", True, False, requires_destination_capabilities=True, supports_compression=False) + return TFileFormatSpec( + "parquet", + "parquet", + True, + False, + requires_destination_capabilities=True, + supports_compression=False, + ) diff --git a/dlt/common/destination/__init__.py b/dlt/common/destination/__init__.py index d4e91acdad..8ccff628a8 100644 --- a/dlt/common/destination/__init__.py +++ b/dlt/common/destination/__init__.py @@ -1,2 +1,2 @@ from dlt.common.destination.capabilities import DestinationCapabilitiesContext, TLoaderFileFormat -from dlt.common.destination.reference import DestinationReference, TDestinationReferenceArg \ No newline at end of file +from dlt.common.destination.reference import DestinationReference, TDestinationReferenceArg diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index c5bc3050de..46188ba998 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -1,12 +1,10 @@ -from typing import Any, Callable, ClassVar, List, Literal, Optional, Tuple, Set, get_args +from typing import Any, Callable, ClassVar, List, Literal, Optional, Set, Tuple, get_args -from dlt.common.configuration.utils import serialize_value +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.common.configuration import configspec from dlt.common.configuration.specs import ContainerInjectableContext +from dlt.common.configuration.utils import serialize_value from dlt.common.utils import identity - -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - from dlt.common.wei import EVM_DECIMAL_PRECISION # known loader file formats @@ -18,12 +16,15 @@ # file formats used internally by dlt INTERNAL_LOADER_FILE_FORMATS: Set[TLoaderFileFormat] = {"puae-jsonl", "sql", "reference"} # file formats that may be chosen by the user -EXTERNAL_LOADER_FILE_FORMATS: Set[TLoaderFileFormat] = set(get_args(TLoaderFileFormat)) - INTERNAL_LOADER_FILE_FORMATS +EXTERNAL_LOADER_FILE_FORMATS: Set[TLoaderFileFormat] = ( + set(get_args(TLoaderFileFormat)) - INTERNAL_LOADER_FILE_FORMATS +) @configspec class DestinationCapabilitiesContext(ContainerInjectableContext): """Injectable destination capabilities required for many Pipeline stages ie. 
normalize""" + preferred_loader_file_format: TLoaderFileFormat supported_loader_file_formats: List[TLoaderFileFormat] preferred_staging_file_format: Optional[TLoaderFileFormat] @@ -50,7 +51,9 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): can_create_default: ClassVar[bool] = False @staticmethod - def generic_capabilities(preferred_loader_file_format: TLoaderFileFormat = None) -> "DestinationCapabilitiesContext": + def generic_capabilities( + preferred_loader_file_format: TLoaderFileFormat = None, + ) -> "DestinationCapabilitiesContext": caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = preferred_loader_file_format caps.supported_loader_file_formats = ["jsonl", "insert_values", "parquet"] diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 2408cf5882..c8c80215bb 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -1,26 +1,49 @@ from abc import ABC, abstractmethod -from importlib import import_module -from types import TracebackType, ModuleType -from typing import ClassVar, Final, Optional, Literal, Sequence, Iterable, Type, Protocol, Union, TYPE_CHECKING, cast, List, ContextManager from contextlib import contextmanager +from importlib import import_module +from types import ModuleType, TracebackType +from typing import ( + TYPE_CHECKING, + ClassVar, + ContextManager, + Final, + Iterable, + List, + Literal, + Optional, + Protocol, + Sequence, + Type, + Union, + cast, +) from dlt.common import logger -from dlt.common.exceptions import IdentifierTooLongException, InvalidDestinationReference, UnknownDestinationModule -from dlt.common.schema import Schema, TTableSchema, TSchemaTables -from dlt.common.schema.typing import TWriteDisposition -from dlt.common.schema.exceptions import InvalidDatasetName from dlt.common.configuration import configspec -from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration from dlt.common.configuration.accessors import config +from dlt.common.configuration.specs import ( + AwsCredentialsWithoutDefaults, + BaseConfiguration, + CredentialsConfiguration, + GcpCredentials, +) from dlt.common.destination.capabilities import DestinationCapabilitiesContext +from dlt.common.exceptions import ( + IdentifierTooLongException, + InvalidDestinationReference, + UnknownDestinationModule, +) +from dlt.common.schema import Schema, TSchemaTables, TTableSchema +from dlt.common.schema.exceptions import InvalidDatasetName +from dlt.common.schema.typing import TWriteDisposition from dlt.common.schema.utils import is_complete_column from dlt.common.storages import FileStorage from dlt.common.storages.load_storage import ParsedLoadJobFileName from dlt.common.utils import get_module_name -from dlt.common.configuration.specs import GcpCredentials, AwsCredentialsWithoutDefaults TLoaderReplaceStrategy = Literal["truncate-and-insert", "insert-from-staging", "staging-optimized"] + @configspec class DestinationClientConfiguration(BaseConfiguration): destination_name: str = None # which destination to load data to @@ -35,8 +58,12 @@ def __str__(self) -> str: return str(self.credentials) if TYPE_CHECKING: - def __init__(self, destination_name: str = None, credentials: Optional[CredentialsConfiguration] = None -) -> None: + + def __init__( + self, + destination_name: str = None, + credentials: Optional[CredentialsConfiguration] = None, + ) -> None: ... 
@@ -54,7 +81,7 @@ class DestinationClientDwhConfiguration(DestinationClientConfiguration): def normalize_dataset_name(self, schema: Schema) -> str: """Builds full db dataset (schema) name out of configured dataset name and schema name: {dataset_name}_{schema.name}. The resulting name is normalized. - If default schema name equals schema.name, the schema suffix is skipped. + If default schema name equals schema.name, the schema suffix is skipped. """ if not schema.name: raise ValueError("schema_name is None or empty") @@ -62,11 +89,18 @@ def normalize_dataset_name(self, schema: Schema) -> str: # if default schema is None then suffix is not added if self.default_schema_name is not None and schema.name != self.default_schema_name: # also normalize schema name. schema name is Python identifier and here convention may be different - return schema.naming.normalize_table_identifier((self.dataset_name or "") + "_" + schema.name) + return schema.naming.normalize_table_identifier( + (self.dataset_name or "") + "_" + schema.name + ) - return self.dataset_name if not self.dataset_name else schema.naming.normalize_table_identifier(self.dataset_name) + return ( + self.dataset_name + if not self.dataset_name + else schema.naming.normalize_table_identifier(self.dataset_name) + ) if TYPE_CHECKING: + def __init__( self, destination_name: str = None, @@ -76,18 +110,21 @@ def __init__( ) -> None: ... + @configspec class DestinationClientStagingConfiguration(DestinationClientDwhConfiguration): """Configuration of a staging destination, able to store files with desired `layout` at `bucket_url`. - Also supports datasets and can act as standalone destination. + Also supports datasets and can act as standalone destination. """ + as_staging: bool = False bucket_url: str = None # layout of the destination files layout: str = "{table_name}/{load_id}.{file_id}.{ext}" if TYPE_CHECKING: + def __init__( self, destination_name: str = None, @@ -96,23 +133,26 @@ def __init__( default_schema_name: Optional[str] = None, as_staging: bool = False, bucket_url: str = None, - layout: str = None + layout: str = None, ) -> None: ... + @configspec class DestinationClientDwhWithStagingConfiguration(DestinationClientDwhConfiguration): """Configuration of a destination that can take data from staging destination""" + staging_config: Optional[DestinationClientStagingConfiguration] = None """configuration of the staging, if present, injected at runtime""" if TYPE_CHECKING: + def __init__( self, destination_name: str = None, credentials: Optional[CredentialsConfiguration] = None, dataset_name: str = None, default_schema_name: Optional[str] = None, - staging_config: Optional[DestinationClientStagingConfiguration] = None + staging_config: Optional[DestinationClientStagingConfiguration] = None, ) -> None: ... @@ -123,14 +163,15 @@ def __init__( class LoadJob: """Represents a job that loads a single file - Each job starts in "running" state and ends in one of terminal states: "retry", "failed" or "completed". - Each job is uniquely identified by a file name. The file is guaranteed to exist in "running" state. In terminal state, the file may not be present. - In "running" state, the loader component periodically gets the state via `status()` method. When terminal state is reached, load job is discarded and not called again. - `exception` method is called to get error information in "failed" and "retry" states. + Each job starts in "running" state and ends in one of terminal states: "retry", "failed" or "completed". 
+ Each job is uniquely identified by a file name. The file is guaranteed to exist in "running" state. In terminal state, the file may not be present. + In "running" state, the loader component periodically gets the state via `status()` method. When terminal state is reached, load job is discarded and not called again. + `exception` method is called to get error information in "failed" and "retry" states. - The `__init__` method is responsible to put the Job in "running" state. It may raise `LoadClientTerminalException` and `LoadClientTransientException` to - immediately transition job into "failed" or "retry" state respectively. + The `__init__` method is responsible to put the Job in "running" state. It may raise `LoadClientTerminalException` and `LoadClientTransientException` to + immediately transition job into "failed" or "retry" state respectively. """ + def __init__(self, file_name: str) -> None: """ File name is also a job id (or job id is deterministically derived) so it must be globally unique @@ -173,12 +214,12 @@ def new_file_path(self) -> str: class FollowupJob: """Adds a trait that allows to create a followup job""" + def create_followup_jobs(self, next_state: str) -> List[NewLoadJob]: return [] class JobClientBase(ABC): - capabilities: ClassVar[DestinationCapabilitiesContext] = None def __init__(self, schema: Schema, config: DestinationClientConfiguration) -> None: @@ -187,8 +228,7 @@ def __init__(self, schema: Schema, config: DestinationClientConfiguration) -> No @abstractmethod def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: - """Prepares storage to be used ie. creates database schema or file system folder. Truncates requested tables. - """ + """Prepares storage to be used ie. creates database schema or file system folder. Truncates requested tables.""" pass @abstractmethod @@ -196,7 +236,9 @@ def is_storage_initialized(self) -> bool: """Returns if storage is ready to be read/written.""" pass - def update_storage_schema(self, only_tables: Iterable[str] = None, expected_update: TSchemaTables = None) -> Optional[TSchemaTables]: + def update_storage_schema( + self, only_tables: Iterable[str] = None, expected_update: TSchemaTables = None + ) -> Optional[TSchemaTables]: """Updates storage to the current schema. Implementations should not assume that `expected_update` is the exact difference between destination state and the self.schema. 
This is only the case if @@ -225,7 +267,9 @@ def get_truncate_destination_table_dispositions(self) -> List[TWriteDisposition] # in the base job, all replace strategies are treated the same, see filesystem for example return ["replace"] - def create_table_chain_completed_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def create_table_chain_completed_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[NewLoadJob]: """Creates a list of followup jobs that should be executed after a table chain is completed""" return [] @@ -239,7 +283,9 @@ def __enter__(self) -> "JobClientBase": pass @abstractmethod - def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None: + def __exit__( + self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType + ) -> None: pass def _verify_schema(self) -> None: @@ -252,17 +298,27 @@ def _verify_schema(self) -> None: for table in self.schema.data_tables(): table_name = table["name"] if len(table_name) > self.capabilities.max_identifier_length: - raise IdentifierTooLongException(self.config.destination_name, "table", table_name, self.capabilities.max_identifier_length) + raise IdentifierTooLongException( + self.config.destination_name, + "table", + table_name, + self.capabilities.max_identifier_length, + ) for column_name, column in dict(table["columns"]).items(): if len(column_name) > self.capabilities.max_column_identifier_length: raise IdentifierTooLongException( self.config.destination_name, "column", f"{table_name}.{column_name}", - self.capabilities.max_column_identifier_length + self.capabilities.max_column_identifier_length, ) if not is_complete_column(column): - logger.warning(f"A column {column_name} in table {table_name} in schema {self.schema.name} is incomplete. It was not bound to the data during normalizations stage and its data type is unknown. Did you add this column manually in code ie. as a merge key?") + logger.warning( + f"A column {column_name} in table {table_name} in schema" + f" {self.schema.name} is incomplete. It was not bound to the data during" + " normalization stage and its data type is unknown. Did you add this" + " column manually in code ie. as a merge key?" + ) class WithStagingDataset: @@ -274,7 +330,7 @@ def get_stage_dispositions(self) -> List[TWriteDisposition]: return [] @abstractmethod - def with_staging_dataset(self)-> ContextManager["JobClientBase"]: + def with_staging_dataset(self) -> ContextManager["JobClientBase"]: """Executes job client methods on staging dataset""" return self # type: ignore @@ -289,7 +345,9 @@ class DestinationReference(Protocol): def capabilities(self) -> DestinationCapabilitiesContext: """Destination capabilities ie.
supported loader file formats, identifier name lengths, naming conventions, escape function etc.""" - def client(self, schema: Schema, initial_config: DestinationClientConfiguration = config.value) -> "JobClientBase": + def client( + self, schema: Schema, initial_config: DestinationClientConfiguration = config.value + ) -> "JobClientBase": """A job client responsible for starting and resuming load jobs""" def spec(self) -> Type[DestinationClientConfiguration]: @@ -308,7 +366,9 @@ def from_name(destination: TDestinationReferenceArg) -> "DestinationReference": destination_ref = cast(DestinationReference, import_module(destination)) else: # from known location - destination_ref = cast(DestinationReference, import_module(f"dlt.destinations.{destination}")) + destination_ref = cast( + DestinationReference, import_module(f"dlt.destinations.{destination}") + ) except ImportError: if "." in destination: raise UnknownDestinationModule(destination) diff --git a/dlt/common/exceptions.py b/dlt/common/exceptions.py index f07c41f4dc..ed0a6a25b7 100644 --- a/dlt/common/exceptions.py +++ b/dlt/common/exceptions.py @@ -1,4 +1,4 @@ -from typing import Any, AnyStr, List, Sequence, Optional, Iterable +from typing import Any, AnyStr, Iterable, List, Optional, Sequence class DltException(Exception): @@ -6,10 +6,14 @@ def __reduce__(self) -> Any: """Enables exceptions with parametrized constructor to be pickled""" return type(self).__new__, (type(self), *self.args), self.__dict__ + class UnsupportedProcessStartMethodException(DltException): def __init__(self, method: str) -> None: self.method = method - super().__init__(f"Process pool supports only fork start method, {method} not supported. Switch the pool type to threading") + super().__init__( + f"Process pool supports only fork start method, {method} not supported. Switch the pool" + " type to threading" + ) class CannotInstallDependency(DltException): @@ -20,7 +24,9 @@ def __init__(self, dependency: str, interpreter: str, output: AnyStr) -> None: str_output = output.decode("utf-8") else: str_output = output - super().__init__(f"Cannot install dependency {dependency} with {interpreter} and pip:\n{str_output}\n") + super().__init__( + f"Cannot install dependency {dependency} with {interpreter} and pip:\n{str_output}\n" + ) class VenvNotFound(DltException): @@ -49,6 +55,7 @@ class TerminalValueError(ValueError, TerminalException): class SignalReceivedException(KeyboardInterrupt, TerminalException): """Raises when signal comes. 
Derives from `BaseException` to not be caught in regular exception handlers.""" + def __init__(self, signal_code: int) -> None: self.signal_code = signal_code super().__init__(f"Signal {signal_code} received") @@ -87,7 +94,7 @@ def _get_msg(self, appendix: str) -> str: return msg def _to_pip_install(self) -> str: - return "\n".join([f"pip install \"{d}\"" for d in self.dependencies]) + return "\n".join([f'pip install "{d}"' for d in self.dependencies]) class DestinationException(DltException): @@ -128,11 +135,13 @@ def __init__(self, destination: str) -> None: self.destination = destination super().__init__(f"Destination {destination} does not support loading via staging.") + class DestinationLoadingWithoutStagingNotSupported(DestinationTerminalException): def __init__(self, destination: str) -> None: self.destination = destination super().__init__(f"Destination {destination} does not support loading without staging.") + class DestinationNoStagingMode(DestinationTerminalException): def __init__(self, destination: str) -> None: self.destination = destination @@ -140,7 +149,9 @@ def __init__(self, destination: str) -> None: class DestinationIncompatibleLoaderFileFormatException(DestinationTerminalException): - def __init__(self, destination: str, staging: str, file_format: str, supported_formats: Iterable[str]) -> None: + def __init__( + self, destination: str, staging: str, file_format: str, supported_formats: Iterable[str] + ) -> None: self.destination = destination self.staging = staging self.file_format = file_format @@ -148,21 +159,41 @@ def __init__(self, destination: str, staging: str, file_format: str, supported_f supported_formats_str = ", ".join(supported_formats) if self.staging: if not supported_formats: - msg = f"Staging {staging} cannot be used with destination {destination} because they have no file formats in common." + msg = ( + f"Staging {staging} cannot be used with destination {destination} because they" + " have no file formats in common." + ) else: - msg = f"Unsupported file format {file_format} for destination {destination} in combination with staging destination {staging}. Supported formats: {supported_formats_str}" + msg = ( + f"Unsupported file format {file_format} for destination {destination} in" + f" combination with staging destination {staging}. Supported formats:" + f" {supported_formats_str}" + ) else: - msg = f"Unsupported file format {file_format} destination {destination}. Supported formats: {supported_formats_str}. Check the staging option in the dlt.pipeline for additional formats." + msg = ( + f"Unsupported file format {file_format} for destination {destination}. Supported" + f" formats: {supported_formats_str}. Check the staging option in the dlt.pipeline" + " for additional formats."
+ ) super().__init__(msg) class IdentifierTooLongException(DestinationTerminalException): - def __init__(self, destination_name: str, identifier_type: str, identifier_name: str, max_identifier_length: int) -> None: + def __init__( + self, + destination_name: str, + identifier_type: str, + identifier_name: str, + max_identifier_length: int, + ) -> None: self.destination_name = destination_name self.identifier_type = identifier_type self.identifier_name = identifier_name self.max_identifier_length = max_identifier_length - super().__init__(f"The length of {identifier_type} {identifier_name} exceeds {max_identifier_length} allowed for {destination_name}") + super().__init__( + f"The length of {identifier_type} {identifier_name} exceeds" + f" {max_identifier_length} allowed for {destination_name}" + ) class DestinationHasFailedJobs(DestinationTerminalException): @@ -170,7 +201,9 @@ def __init__(self, destination_name: str, load_id: str, failed_jobs: List[Any]) self.destination_name = destination_name self.load_id = load_id self.failed_jobs = failed_jobs - super().__init__(f"Destination {destination_name} has failed jobs in load package {load_id}") + super().__init__( + f"Destination {destination_name} has failed jobs in load package {load_id}" + ) class PipelineException(DltException): @@ -183,21 +216,37 @@ def __init__(self, pipeline_name: str, msg: str) -> None: class PipelineStateNotAvailable(PipelineException): def __init__(self, source_state_key: Optional[str] = None) -> None: if source_state_key: - msg = f"The source {source_state_key} requested the access to pipeline state but no pipeline is active right now." + msg = ( + f"The source {source_state_key} requested the access to pipeline state but no" + " pipeline is active right now." + ) else: - msg = "The resource you called requested the access to pipeline state but no pipeline is active right now." - msg += " Call dlt.pipeline(...) before you call the @dlt.source or @dlt.resource decorated function." + msg = ( + "The resource you called requested the access to pipeline state but no pipeline is" + " active right now." + ) + msg += ( + " Call dlt.pipeline(...) before you call the @dlt.source or @dlt.resource decorated" + " function." + ) self.source_state_key = source_state_key super().__init__(None, msg) class ResourceNameNotAvailable(PipelineException): def __init__(self) -> None: - super().__init__(None, - "A resource state was requested but no active extract pipe context was found. Resource state may be only requested from @dlt.resource decorated function or with explicit resource name.") + super().__init__( + None, + "A resource state was requested but no active extract pipe context was found. Resource" + " state may be only requested from @dlt.resource decorated function or with explicit" + " resource name.", + ) class SourceSectionNotAvailable(PipelineException): def __init__(self) -> None: - msg = "Access to state was requested without source section active. State should be requested from within the @dlt.source and @dlt.resource decorated function." + msg = ( + "Access to state was requested without source section active. State should be requested" + " from within the @dlt.source and @dlt.resource decorated function." 
+ ) super().__init__(None, msg) diff --git a/dlt/common/git.py b/dlt/common/git.py index 602e889a36..551291947f 100644 --- a/dlt/common/git.py +++ b/dlt/common/git.py @@ -1,13 +1,13 @@ import os import tempfile -import giturlparse -from typing import Iterator, Optional, TYPE_CHECKING from contextlib import contextmanager +from typing import TYPE_CHECKING, Iterator, Optional + +import giturlparse from dlt.common.storages import FileStorage -from dlt.common.utils import uniq_id from dlt.common.typing import Any - +from dlt.common.utils import uniq_id # NOTE: never import git module directly as it performs a check if the git command is available and raises ImportError if TYPE_CHECKING: @@ -15,6 +15,7 @@ else: Repo = Any + @contextmanager def git_custom_key_command(private_key: Optional[str]) -> Iterator[str]: if private_key: @@ -24,7 +25,9 @@ def git_custom_key_command(private_key: Optional[str]) -> Iterator[str]: try: # permissions so SSH does not complain os.chmod(key_file, 0o600) - yield 'ssh -o "StrictHostKeyChecking accept-new" -i "%s"' % key_file.replace("\\", "\\\\") + yield 'ssh -o "StrictHostKeyChecking accept-new" -i "%s"' % key_file.replace( + "\\", "\\\\" + ) finally: os.remove(key_file) else: @@ -46,6 +49,7 @@ def is_dirty(repo: Repo) -> bool: status: str = repo.git.status("--short") return len(status.strip()) > 0 + # def is_dirty(repo: Repo) -> bool: # # get branch status # status: str = repo.git.status("--short", "--branch") @@ -53,7 +57,9 @@ def is_dirty(repo: Repo) -> bool: # return len(status.splitlines()) > 1 -def ensure_remote_head(repo_path: str, branch: Optional[str] = None, with_git_command: Optional[str] = None) -> None: +def ensure_remote_head( + repo_path: str, branch: Optional[str] = None, with_git_command: Optional[str] = None +) -> None: from git import Repo, RepositoryDirtyError # update remotes and check if heads are same. ignores locally modified files @@ -70,7 +76,12 @@ def ensure_remote_head(repo_path: str, branch: Optional[str] = None, with_git_co raise RepositoryDirtyError(repo, status) -def clone_repo(repository_url: str, clone_path: str, branch: Optional[str] = None, with_git_command: Optional[str] = None) -> Repo: +def clone_repo( + repository_url: str, + clone_path: str, + branch: Optional[str] = None, + with_git_command: Optional[str] = None, +) -> Repo: from git import Repo repo = Repo.clone_from(repository_url, clone_path, env=dict(GIT_SSH_COMMAND=with_git_command)) @@ -79,7 +90,13 @@ def clone_repo(repository_url: str, clone_path: str, branch: Optional[str] = Non return repo -def force_clone_repo(repo_url: str, repo_storage: FileStorage, repo_name: str, branch: Optional[str] = None, with_git_command: Optional[str] = None) -> None: +def force_clone_repo( + repo_url: str, + repo_storage: FileStorage, + repo_name: str, + branch: Optional[str] = None, + with_git_command: Optional[str] = None, +) -> None: """Deletes the working directory repo_storage.root/repo_name and clones the `repo_url` into it. 
Will checkout `branch` if provided""" try: # delete repo folder @@ -89,7 +106,7 @@ def force_clone_repo(repo_url: str, repo_storage: FileStorage, repo_name: str, b repo_url, repo_storage.make_full_path(repo_name), branch=branch, - with_git_command=with_git_command + with_git_command=with_git_command, ).close() except Exception: # delete folder so we start clean next time @@ -98,7 +115,12 @@ def force_clone_repo(repo_url: str, repo_storage: FileStorage, repo_name: str, b raise -def get_fresh_repo_files(repo_location: str, working_dir: str = None, branch: Optional[str] = None, with_git_command: Optional[str] = None) -> FileStorage: +def get_fresh_repo_files( + repo_location: str, + working_dir: str = None, + branch: Optional[str] = None, + with_git_command: Optional[str] = None, +) -> FileStorage: """Returns a file storage leading to the newest repository files. If `repo_location` is url, file will be checked out into `working_dir/repo_name`""" from git import GitError @@ -113,7 +135,13 @@ def get_fresh_repo_files(repo_location: str, working_dir: str = None, branch: Op try: ensure_remote_head(repo_path, branch=branch, with_git_command=with_git_command) except GitError: - force_clone_repo(repo_location, FileStorage(working_dir, makedirs=True), repo_name, branch=branch, with_git_command=with_git_command) + force_clone_repo( + repo_location, + FileStorage(working_dir, makedirs=True), + repo_name, + branch=branch, + with_git_command=with_git_command, + ) return FileStorage(repo_path) diff --git a/dlt/common/json/__init__.py b/dlt/common/json/__init__.py index f38d95a4c1..c201291173 100644 --- a/dlt/common/json/__init__.py +++ b/dlt/common/json/__init__.py @@ -1,16 +1,16 @@ -import os import base64 import dataclasses +import os from datetime import date, datetime # noqa: I251 -from typing import Any, Callable, List, Protocol, IO, Union +from typing import IO, Any, Callable, List, Protocol, Union from uuid import UUID -from hexbytes import HexBytes +from hexbytes import HexBytes from dlt.common.arithmetics import Decimal -from dlt.common.wei import Wei -from dlt.common.utils import map_nested_in_place from dlt.common.time import parse_iso_like_datetime +from dlt.common.utils import map_nested_in_place +from dlt.common.wei import Wei class SupportsJson(Protocol): @@ -19,10 +19,10 @@ class SupportsJson(Protocol): _impl_name: str """Implementation name""" - def dump(self, obj: Any, fp: IO[bytes], sort_keys: bool = False, pretty:bool = False) -> None: + def dump(self, obj: Any, fp: IO[bytes], sort_keys: bool = False, pretty: bool = False) -> None: ... - def typed_dump(self, obj: Any, fp: IO[bytes], pretty:bool = False) -> None: + def typed_dump(self, obj: Any, fp: IO[bytes], pretty: bool = False) -> None: ... def typed_dumps(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: @@ -37,10 +37,10 @@ def typed_dumpb(self, obj: Any, sort_keys: bool = False, pretty: bool = False) - def typed_loadb(self, s: Union[bytes, bytearray, memoryview]) -> Any: ... - def dumps(self, obj: Any, sort_keys: bool = False, pretty:bool = False) -> str: + def dumps(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: ... - def dumpb(self, obj: Any, sort_keys: bool = False, pretty:bool = False) -> bytes: + def dumpb(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> bytes: ... 
def load(self, fp: IO[bytes]) -> Any: @@ -64,8 +64,8 @@ def custom_encode(obj: Any) -> str: # leave microseconds alone # if obj.microsecond: # r = r[:23] + r[26:] - if r.endswith('+00:00'): - r = r[:-6] + 'Z' + if r.endswith("+00:00"): + r = r[:-6] + "Z" return r elif isinstance(obj, date): return obj.isoformat() @@ -74,10 +74,10 @@ def custom_encode(obj: Any) -> str: elif isinstance(obj, HexBytes): return obj.hex() elif isinstance(obj, bytes): - return base64.b64encode(obj).decode('ascii') - elif hasattr(obj, 'asdict'): + return base64.b64encode(obj).decode("ascii") + elif hasattr(obj, "asdict"): return obj.asdict() # type: ignore - elif hasattr(obj, '_asdict'): + elif hasattr(obj, "_asdict"): return obj._asdict() # type: ignore elif dataclasses.is_dataclass(obj): return dataclasses.asdict(obj) # type: ignore @@ -85,13 +85,13 @@ def custom_encode(obj: Any) -> str: # use PUA range to encode additional types -_DECIMAL = '\uF026' -_DATETIME = '\uF027' -_DATE = '\uF028' -_UUIDT = '\uF029' -_HEXBYTES = '\uF02A' -_B64BYTES = '\uF02B' -_WEI = '\uF02C' +_DECIMAL = "\uf026" +_DATETIME = "\uf027" +_DATE = "\uf028" +_UUIDT = "\uf029" +_HEXBYTES = "\uf02a" +_B64BYTES = "\uf02b" +_WEI = "\uf02c" DECODERS: List[Callable[[Any], Any]] = [ Decimal, @@ -100,7 +100,7 @@ def custom_encode(obj: Any) -> str: UUID, HexBytes, base64.b64decode, - Wei + Wei, ] @@ -113,8 +113,8 @@ def custom_pua_encode(obj: Any) -> str: # this works both for standard datetime and pendulum elif isinstance(obj, datetime): r = obj.isoformat() - if r.endswith('+00:00'): - r = r[:-6] + 'Z' + if r.endswith("+00:00"): + r = r[:-6] + "Z" return _DATETIME + r elif isinstance(obj, date): return _DATE + obj.isoformat() @@ -123,10 +123,10 @@ def custom_pua_encode(obj: Any) -> str: elif isinstance(obj, HexBytes): return _HEXBYTES + obj.hex() elif isinstance(obj, bytes): - return _B64BYTES + base64.b64encode(obj).decode('ascii') - elif hasattr(obj, 'asdict'): + return _B64BYTES + base64.b64encode(obj).decode("ascii") + elif hasattr(obj, "asdict"): return obj.asdict() # type: ignore - elif hasattr(obj, '_asdict'): + elif hasattr(obj, "_asdict"): return obj._asdict() # type: ignore elif dataclasses.is_dataclass(obj): return dataclasses.asdict(obj) # type: ignore @@ -137,7 +137,7 @@ def custom_pua_decode(obj: Any) -> Any: if isinstance(obj, str) and len(obj) > 1: c = ord(obj[0]) - 0xF026 # decode only the PUA space defined in DECODERS - if c >=0 and c <= 6: + if c >= 0 and c <= 6: return DECODERS[c](obj[1:]) return obj @@ -155,7 +155,7 @@ def custom_pua_remove(obj: Any) -> Any: if isinstance(obj, str) and len(obj) > 1: c = ord(obj[0]) - 0xF026 # decode only the PUA space defined in DECODERS - if c >=0 and c <= 6: + if c >= 0 and c <= 6: return obj[1:] return obj @@ -164,11 +164,14 @@ def custom_pua_remove(obj: Any) -> Any: json: SupportsJson = None if os.environ.get("DLT_USE_JSON") == "simplejson": from dlt.common.json import _simplejson as _json_d + json = _json_d else: try: from dlt.common.json import _orjson as _json_or + json = _json_or except ImportError: from dlt.common.json import _simplejson as _json_simple + json = _json_simple diff --git a/dlt/common/json/_orjson.py b/dlt/common/json/_orjson.py index ada91cbb1b..9b11dd312e 100644 --- a/dlt/common/json/_orjson.py +++ b/dlt/common/json/_orjson.py @@ -1,13 +1,16 @@ from typing import IO, Any, Union + import orjson -from dlt.common.json import custom_pua_encode, custom_pua_decode_nested, custom_encode +from dlt.common.json import custom_encode, custom_pua_decode_nested, custom_pua_encode from 
dlt.common.typing import AnyFun _impl_name = "orjson" -def _dumps(obj: Any, sort_keys: bool, pretty:bool, default:AnyFun = custom_encode, options: int = 0) -> bytes: +def _dumps( + obj: Any, sort_keys: bool, pretty: bool, default: AnyFun = custom_encode, options: int = 0 +) -> bytes: options = options | orjson.OPT_UTC_Z | orjson.OPT_NON_STR_KEYS if pretty: options |= orjson.OPT_INDENT_2 @@ -16,11 +19,11 @@ def _dumps(obj: Any, sort_keys: bool, pretty:bool, default:AnyFun = custom_encod return orjson.dumps(obj, default=default, option=options) -def dump(obj: Any, fp: IO[bytes], sort_keys: bool = False, pretty:bool = False) -> None: +def dump(obj: Any, fp: IO[bytes], sort_keys: bool = False, pretty: bool = False) -> None: fp.write(_dumps(obj, sort_keys, pretty)) -def typed_dump(obj: Any, fp: IO[bytes], pretty:bool = False) -> None: +def typed_dump(obj: Any, fp: IO[bytes], pretty: bool = False) -> None: fp.write(typed_dumpb(obj, pretty=pretty)) @@ -29,7 +32,7 @@ def typed_dumpb(obj: Any, sort_keys: bool = False, pretty: bool = False) -> byte def typed_dumps(obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: - return typed_dumpb(obj, sort_keys, pretty).decode('utf-8') + return typed_dumpb(obj, sort_keys, pretty).decode("utf-8") def typed_loads(s: str) -> Any: @@ -40,11 +43,11 @@ def typed_loadb(s: Union[bytes, bytearray, memoryview]) -> Any: return custom_pua_decode_nested(loadb(s)) -def dumps(obj: Any, sort_keys: bool = False, pretty:bool = False) -> str: +def dumps(obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: return _dumps(obj, sort_keys, pretty).decode("utf-8") -def dumpb(obj: Any, sort_keys: bool = False, pretty:bool = False) -> bytes: +def dumpb(obj: Any, sort_keys: bool = False, pretty: bool = False) -> bytes: return _dumps(obj, sort_keys, pretty) diff --git a/dlt/common/json/_simplejson.py b/dlt/common/json/_simplejson.py index c670717527..6f65927736 100644 --- a/dlt/common/json/_simplejson.py +++ b/dlt/common/json/_simplejson.py @@ -1,10 +1,10 @@ import codecs +import platform from typing import IO, Any, Union import simplejson -import platform -from dlt.common.json import custom_pua_encode, custom_pua_decode_nested, custom_encode +from dlt.common.json import custom_encode, custom_pua_decode_nested, custom_pua_encode if platform.python_implementation() == "PyPy": # disable speedups on PyPy, it can be actually faster than Python C @@ -15,7 +15,7 @@ _impl_name = "simplejson" -def dump(obj: Any, fp: IO[bytes], sort_keys: bool = False, pretty:bool = False) -> None: +def dump(obj: Any, fp: IO[bytes], sort_keys: bool = False, pretty: bool = False) -> None: if pretty: indent = 2 else: @@ -28,13 +28,13 @@ def dump(obj: Any, fp: IO[bytes], sort_keys: bool = False, pretty:bool = False) default=custom_encode, encoding=None, ensure_ascii=False, - separators=(',', ':'), + separators=(",", ":"), sort_keys=sort_keys, - indent=indent + indent=indent, ) -def typed_dump(obj: Any, fp: IO[bytes], pretty:bool = False) -> None: +def typed_dump(obj: Any, fp: IO[bytes], pretty: bool = False) -> None: if pretty: indent = 2 else: @@ -47,10 +47,11 @@ def typed_dump(obj: Any, fp: IO[bytes], pretty:bool = False) -> None: default=custom_pua_encode, encoding=None, ensure_ascii=False, - separators=(',', ':'), - indent=indent + separators=(",", ":"), + indent=indent, ) + def typed_dumps(obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: indent = 2 if pretty else None return simplejson.dumps( @@ -59,8 +60,8 @@ def typed_dumps(obj: Any, sort_keys: bool = False, pretty: 
bool = False) -> str: default=custom_pua_encode, encoding=None, ensure_ascii=False, - separators=(',', ':'), - indent=indent + separators=(",", ":"), + indent=indent, ) @@ -69,14 +70,14 @@ def typed_loads(s: str) -> Any: def typed_dumpb(obj: Any, sort_keys: bool = False, pretty: bool = False) -> bytes: - return typed_dumps(obj, sort_keys, pretty).encode('utf-8') + return typed_dumps(obj, sort_keys, pretty).encode("utf-8") def typed_loadb(s: Union[bytes, bytearray, memoryview]) -> Any: return custom_pua_decode_nested(loadb(s)) -def dumps(obj: Any, sort_keys: bool = False, pretty:bool = False) -> str: +def dumps(obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: if pretty: indent = 2 else: @@ -87,13 +88,13 @@ def dumps(obj: Any, sort_keys: bool = False, pretty:bool = False) -> str: default=custom_encode, encoding=None, ensure_ascii=False, - separators=(',', ':'), + separators=(",", ":"), sort_keys=sort_keys, - indent=indent + indent=indent, ) -def dumpb(obj: Any, sort_keys: bool = False, pretty:bool = False) -> bytes: +def dumpb(obj: Any, sort_keys: bool = False, pretty: bool = False) -> bytes: return dumps(obj, sort_keys, pretty).encode("utf-8") diff --git a/dlt/common/jsonpath.py b/dlt/common/jsonpath.py index f5922d5d16..7004d662f8 100644 --- a/dlt/common/jsonpath.py +++ b/dlt/common/jsonpath.py @@ -1,10 +1,10 @@ -from typing import Iterable, Union, List, Any from itertools import chain +from typing import Any, Iterable, List, Union -from dlt.common.typing import DictStrAny - -from jsonpath_ng import parse as _parse, JSONPath +from jsonpath_ng import JSONPath +from jsonpath_ng import parse as _parse +from dlt.common.typing import DictStrAny TJsonPath = Union[str, JSONPath] # Jsonpath compiled or str TAnyJsonPath = Union[TJsonPath, Iterable[TJsonPath]] # A single or multiple jsonpaths diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index 35d8017070..4cd716527c 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -1,14 +1,16 @@ from typing import Any, Tuple -from dlt import version -from dlt.common.exceptions import MissingDependencyException +from dlt import version from dlt.common.destination.capabilities import DestinationCapabilitiesContext +from dlt.common.exceptions import MissingDependencyException try: import pyarrow import pyarrow.parquet except ModuleNotFoundError: - raise MissingDependencyException("DLT parquet Helpers", [f"{version.DLT_PKG_NAME}[parquet]"], "DLT Helpers for for parquet.") + raise MissingDependencyException( + "DLT parquet Helpers", [f"{version.DLT_PKG_NAME}[parquet]"], "DLT Helpers for parquet."
+ ) def get_py_arrow_datatype(column_type: str, caps: DestinationCapabilitiesContext, tz: str) -> Any: diff --git a/dlt/common/normalizers/__init__.py b/dlt/common/normalizers/__init__.py index cfe1f5beb3..8b0cf45709 100644 --- a/dlt/common/normalizers/__init__.py +++ b/dlt/common/normalizers/__init__.py @@ -1,3 +1,3 @@ from dlt.common.normalizers.configuration import NormalizersConfiguration from dlt.common.normalizers.typing import TJSONNormalizer, TNormalizersConfig -from dlt.common.normalizers.utils import explicit_normalizers, import_normalizers \ No newline at end of file +from dlt.common.normalizers.utils import explicit_normalizers, import_normalizers diff --git a/dlt/common/normalizers/configuration.py b/dlt/common/normalizers/configuration.py index 2c13367abd..c34af046d0 100644 --- a/dlt/common/normalizers/configuration.py +++ b/dlt/common/normalizers/configuration.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from dlt.common.configuration import configspec from dlt.common.configuration.specs import BaseConfiguration @@ -24,5 +24,6 @@ def on_resolved(self) -> None: self.naming = self.destination_capabilities.naming_convention if TYPE_CHECKING: + def __init__(self, naming: str = None, json_normalizer: TJSONNormalizer = None) -> None: ... diff --git a/dlt/common/normalizers/exceptions.py b/dlt/common/normalizers/exceptions.py index b8ad4baed3..248aecc7fe 100644 --- a/dlt/common/normalizers/exceptions.py +++ b/dlt/common/normalizers/exceptions.py @@ -9,4 +9,7 @@ class InvalidJsonNormalizer(NormalizerException): def __init__(self, required_normalizer: str, present_normalizer: str) -> None: self.required_normalizer = required_normalizer self.present_normalizer = present_normalizer - super().__init__(f"Operation requires {required_normalizer} normalizer while {present_normalizer} normalizer is present") + super().__init__( + f"Operation requires {required_normalizer} normalizer while" + f" {present_normalizer} normalizer is present" + ) diff --git a/dlt/common/normalizers/json/__init__.py b/dlt/common/normalizers/json/__init__.py index 949c9cf4b3..6e5bd460d8 100644 --- a/dlt/common/normalizers/json/__init__.py +++ b/dlt/common/normalizers/json/__init__.py @@ -1,7 +1,8 @@ import abc -from typing import Any, Generic, Type, Iterator, Tuple, Callable, Protocol, TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, Protocol, Tuple, Type, TypeVar + +from dlt.common.typing import DictStrAny, StrAny, TDataItem -from dlt.common.typing import DictStrAny, TDataItem, StrAny if TYPE_CHECKING: from dlt.common.schema import Schema else: @@ -15,14 +16,16 @@ # type var for data item normalizer config TNormalizerConfig = TypeVar("TNormalizerConfig", bound=Any) -class DataItemNormalizer(abc.ABC, Generic[TNormalizerConfig]): +class DataItemNormalizer(abc.ABC, Generic[TNormalizerConfig]): @abc.abstractmethod def __init__(self, schema: Schema) -> None: pass @abc.abstractmethod - def normalize_data_item(self, item: TDataItem, load_id: str, table_name: str) -> TNormalizedRowIterator: + def normalize_data_item( + self, item: TDataItem, load_id: str, table_name: str + ) -> TNormalizedRowIterator: pass @abc.abstractmethod diff --git a/dlt/common/normalizers/json/relational.py b/dlt/common/normalizers/json/relational.py index ce23e3dd58..c065dfcbc1 100644 --- a/dlt/common/normalizers/json/relational.py +++ b/dlt/common/normalizers/json/relational.py @@ -1,19 +1,21 @@ -from typing import Dict, 
List, Mapping, Optional, Sequence, Tuple, cast, TypedDict, Any +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, TypedDict, cast + from dlt.common.data_types.typing import TDataType from dlt.common.normalizers.exceptions import InvalidJsonNormalizer +from dlt.common.normalizers.json import DataItemNormalizer as DataItemNormalizerBase +from dlt.common.normalizers.json import TNormalizedRowIterator, wrap_in_dict from dlt.common.normalizers.typing import TJSONNormalizer - -from dlt.common.typing import DictStrAny, DictStrStr, TDataItem, StrAny from dlt.common.schema import Schema -from dlt.common.schema.typing import TColumnSchema, TColumnName, TSimpleRegex +from dlt.common.schema.typing import TColumnName, TColumnSchema, TSimpleRegex from dlt.common.schema.utils import column_name_validator +from dlt.common.typing import DictStrAny, DictStrStr, StrAny, TDataItem from dlt.common.utils import digest128, uniq_id_base64, update_dict_nested -from dlt.common.normalizers.json import TNormalizedRowIterator, wrap_in_dict, DataItemNormalizer as DataItemNormalizerBase from dlt.common.validation import validate_dict EMPTY_KEY_IDENTIFIER = "_empty" # replace empty keys with this DLT_ID_LENGTH_BYTES = 10 + class TDataItemRow(TypedDict, total=False): _dlt_id: str # unique id of current row @@ -62,7 +64,7 @@ def _reset(self) -> None: # for those paths the complex nested objects should be left in place def _is_complex_type(self, table_name: str, field_name: str, _r_lvl: int) -> bool: # turn everything at the recursion level into complex type - max_nesting = self.max_nesting + max_nesting = self.max_nesting schema = self.schema assert _r_lvl <= max_nesting @@ -81,14 +83,9 @@ def _is_complex_type(self, table_name: str, field_name: str, _r_lvl: int) -> boo data_type = column["data_type"] return data_type == "complex" - def _flatten( - self, - table: str, - dict_row: TDataItemRow, - _r_lvl: int + self, table: str, dict_row: TDataItemRow, _r_lvl: int ) -> Tuple[TDataItemRow, Dict[Tuple[str, ...], Sequence[Any]]]: - out_rec_row: DictStrAny = {} out_rec_list: Dict[Tuple[str, ...], Sequence[Any]] = {} schema_naming = self.schema.naming @@ -102,7 +99,9 @@ def norm_row_dicts(dict_row: StrAny, __r_lvl: int, path: Tuple[str, ...] 
= ()) - norm_k = EMPTY_KEY_IDENTIFIER # if norm_k != k: # print(f"{k} -> {norm_k}") - child_name = norm_k if path == () else schema_naming.shorten_fragments(*path, norm_k) + child_name = ( + norm_k if path == () else schema_naming.shorten_fragments(*path, norm_k) + ) # for lists and dicts we must check if type is possibly complex if isinstance(v, (dict, list)): if not self._is_complex_type(table, child_name, __r_lvl): @@ -129,7 +128,6 @@ def _get_child_row_hash(parent_row_id: str, child_table: str, list_idx: int) -> # and all child tables must be lists return digest128(f"{parent_row_id}_{child_table}_{list_idx}", DLT_ID_LENGTH_BYTES) - @staticmethod def _link_row(row: TDataItemRowChild, parent_row_id: str, list_idx: int) -> TDataItemRowChild: assert parent_row_id @@ -142,7 +140,9 @@ def _link_row(row: TDataItemRowChild, parent_row_id: str, list_idx: int) -> TDat def _extend_row(extend: DictStrAny, row: TDataItemRow) -> None: row.update(extend) # type: ignore - def _add_row_id(self, table: str, row: TDataItemRow, parent_row_id: str, pos: int, _r_lvl: int) -> str: + def _add_row_id( + self, table: str, row: TDataItemRow, parent_row_id: str, pos: int, _r_lvl: int + ) -> str: # row_id is always random, no matter if primary_key is present or not row_id = uniq_id_base64(DLT_ID_LENGTH_BYTES) if _r_lvl > 0: @@ -184,19 +184,22 @@ def _normalize_list( ident_path: Tuple[str, ...], parent_path: Tuple[str, ...], parent_row_id: Optional[str] = None, - _r_lvl: int = 0 + _r_lvl: int = 0, ) -> TNormalizedRowIterator: - v: TDataItemRowChild = None table = self.schema.naming.shorten_fragments(*parent_path, *ident_path) for idx, v in enumerate(seq): # yield child table row if isinstance(v, dict): - yield from self._normalize_row(v, extend, ident_path, parent_path, parent_row_id, idx, _r_lvl) + yield from self._normalize_row( + v, extend, ident_path, parent_path, parent_row_id, idx, _r_lvl + ) elif isinstance(v, list): # to normalize lists of lists, we must create a tracking intermediary table by creating a mock row - yield from self._normalize_row({"list": v}, extend, ident_path, parent_path, parent_row_id, idx, _r_lvl + 1) + yield from self._normalize_row( + {"list": v}, extend, ident_path, parent_path, parent_row_id, idx, _r_lvl + 1 + ) else: # list of simple types child_row_hash = DataItemNormalizer._get_child_row_hash(parent_row_id, table, idx) @@ -214,9 +217,8 @@ def _normalize_row( parent_path: Tuple[str, ...] 
= (), parent_row_id: Optional[str] = None, pos: Optional[int] = None, - _r_lvl: int = 0 + _r_lvl: int = 0, ) -> TNormalizedRowIterator: - schema = self.schema table = schema.naming.shorten_fragments(*parent_path, *ident_path) @@ -230,18 +232,22 @@ def _normalize_row( row_id = self._add_row_id(table, flattened_row, parent_row_id, pos, _r_lvl) # find fields to propagate to child tables in config - extend.update(self._get_propagated_values(table, flattened_row, _r_lvl )) + extend.update(self._get_propagated_values(table, flattened_row, _r_lvl)) # yield parent table first yield (table, schema.naming.shorten_fragments(*parent_path)), flattened_row # normalize and yield lists for list_path, list_content in lists.items(): - yield from self._normalize_list(list_content, extend, list_path, parent_path + ident_path, row_id, _r_lvl + 1) + yield from self._normalize_list( + list_content, extend, list_path, parent_path + ident_path, row_id, _r_lvl + 1 + ) def extend_schema(self) -> None: # validate config - config = cast(RelationalNormalizerConfig, self.schema._normalizers_config["json"].get("config") or {}) + config = cast( + RelationalNormalizerConfig, self.schema._normalizers_config["json"].get("config") or {} + ) DataItemNormalizer._validate_normalizer_config(self.schema, config) # quick check to see if hints are applied @@ -252,16 +258,21 @@ def extend_schema(self) -> None: self.schema.merge_hints( { "not_null": [ - TSimpleRegex("_dlt_id"), TSimpleRegex("_dlt_root_id"), TSimpleRegex("_dlt_parent_id"), - TSimpleRegex("_dlt_list_idx"), TSimpleRegex("_dlt_load_id") - ], + TSimpleRegex("_dlt_id"), + TSimpleRegex("_dlt_root_id"), + TSimpleRegex("_dlt_parent_id"), + TSimpleRegex("_dlt_list_idx"), + TSimpleRegex("_dlt_load_id"), + ], "foreign_key": [TSimpleRegex("_dlt_parent_id")], "root_key": [TSimpleRegex("_dlt_root_id")], - "unique": [TSimpleRegex("_dlt_id")] + "unique": [TSimpleRegex("_dlt_id")], } ) - def normalize_data_item(self, item: TDataItem, load_id: str, table_name: str) -> TNormalizedRowIterator: + def normalize_data_item( + self, item: TDataItem, load_id: str, table_name: str + ) -> TNormalizedRowIterator: # wrap items that are not dictionaries in dictionary, otherwise they cannot be processed by the JSON normalizer if not isinstance(item, dict): item = wrap_in_dict(item) @@ -269,7 +280,11 @@ def normalize_data_item(self, item: TDataItem, load_id: str, table_name: str) -> row = cast(TDataItemRowRoot, item) # identify load id if loaded data must be processed after loading incrementally row["_dlt_load_id"] = load_id - yield from self._normalize_row(cast(TDataItemRowChild, row), {}, (self.schema.naming.normalize_table_identifier(table_name),)) + yield from self._normalize_row( + cast(TDataItemRowChild, row), + {}, + (self.schema.naming.normalize_table_identifier(table_name),), + ) @classmethod def ensure_this_normalizer(cls, norm_config: TJSONNormalizer) -> None: @@ -296,4 +311,9 @@ def get_normalizer_config(cls, schema: Schema) -> RelationalNormalizerConfig: @staticmethod def _validate_normalizer_config(schema: Schema, config: RelationalNormalizerConfig) -> None: - validate_dict(RelationalNormalizerConfig, config, "./normalizers/json/config", validator_f=column_name_validator(schema.naming)) + validate_dict( + RelationalNormalizerConfig, + config, + "./normalizers/json/config", + validator_f=column_name_validator(schema.naming), + ) diff --git a/dlt/common/normalizers/naming/__init__.py b/dlt/common/normalizers/naming/__init__.py index 28f3ae5bd4..88ee3a20b9 100644 --- 
a/dlt/common/normalizers/naming/__init__.py +++ b/dlt/common/normalizers/naming/__init__.py @@ -1,2 +1 @@ -from .naming import SupportsNamingConvention, NamingConvention - +from .naming import NamingConvention, SupportsNamingConvention diff --git a/dlt/common/normalizers/naming/direct.py b/dlt/common/normalizers/naming/direct.py index 3a973106fe..09403d9e53 100644 --- a/dlt/common/normalizers/naming/direct.py +++ b/dlt/common/normalizers/naming/direct.py @@ -17,4 +17,4 @@ def make_path(self, *identifiers: Any) -> str: return self.PATH_SEPARATOR.join(filter(lambda x: x.strip(), identifiers)) def break_path(self, path: str) -> Sequence[str]: - return [ident for ident in path.split(self.PATH_SEPARATOR) if ident.strip()] \ No newline at end of file + return [ident for ident in path.split(self.PATH_SEPARATOR) if ident.strip()] diff --git a/dlt/common/normalizers/naming/duck_case.py b/dlt/common/normalizers/naming/duck_case.py index 7c59b4daa4..6a4d2026d0 100644 --- a/dlt/common/normalizers/naming/duck_case.py +++ b/dlt/common/normalizers/naming/duck_case.py @@ -5,7 +5,6 @@ class NamingConvention(BaseNamingConvention): - _RE_NON_ALPHANUMERIC = re.compile(r"[^a-zA-Z\d_+-]+") _REDUCE_ALPHABET = ("*@|", "xal") _TR_REDUCE_ALPHABET = str.maketrans(_REDUCE_ALPHABET[0], _REDUCE_ALPHABET[1]) @@ -20,7 +19,5 @@ def _normalize_identifier(identifier: str, max_length: int) -> str: # shorten identifier return NamingConvention.shorten_identifier( - NamingConvention._to_snake_case(normalized_ident), - identifier, - max_length + NamingConvention._to_snake_case(normalized_ident), identifier, max_length ) diff --git a/dlt/common/normalizers/naming/exceptions.py b/dlt/common/normalizers/naming/exceptions.py index b76362962e..572fc7e0d0 100644 --- a/dlt/common/normalizers/naming/exceptions.py +++ b/dlt/common/normalizers/naming/exceptions.py @@ -1,4 +1,3 @@ - from dlt.common.exceptions import DltException @@ -19,5 +18,8 @@ def __init__(self, naming_module: str) -> None: class InvalidNamingModule(NormalizersException): def __init__(self, naming_module: str) -> None: self.naming_module = naming_module - msg = f"Naming module {naming_module} does not implement required SupportsNamingConvention protocol" + msg = ( + f"Naming module {naming_module} does not implement required SupportsNamingConvention" + " protocol" + ) super().__init__(msg) diff --git a/dlt/common/normalizers/naming/naming.py b/dlt/common/normalizers/naming/naming.py index 80130bace6..c89b27d543 100644 --- a/dlt/common/normalizers/naming/naming.py +++ b/dlt/common/normalizers/naming/naming.py @@ -1,13 +1,12 @@ import base64 -from abc import abstractmethod, ABC -from functools import lru_cache -import math import hashlib +import math +from abc import ABC, abstractmethod +from functools import lru_cache from typing import Any, List, Protocol, Sequence, Type class NamingConvention(ABC): - _TR_TABLE = bytes.maketrans(b"/+", b"ab") _DEFAULT_COLLISION_PROB = 0.001 @@ -46,7 +45,9 @@ def normalize_path(self, path: str) -> str: def normalize_tables_path(self, path: str) -> str: """Breaks path of table identifiers, normalizes components, reconstitutes and shortens the path""" - normalized_idents = [self.normalize_table_identifier(ident) for ident in self.break_path(path)] + normalized_idents = [ + self.normalize_table_identifier(ident) for ident in self.break_path(path) + ] # shorten the whole path return self.shorten_identifier(self.make_path(*normalized_idents), path, self.max_length) @@ -59,7 +60,12 @@ def shorten_fragments(self, *normalized_idents: str) -> 
str: @staticmethod @lru_cache(maxsize=None) - def shorten_identifier(normalized_ident: str, identifier: str, max_length: int, collision_prob: float = _DEFAULT_COLLISION_PROB) -> str: + def shorten_identifier( + normalized_ident: str, + identifier: str, + max_length: int, + collision_prob: float = _DEFAULT_COLLISION_PROB, + ) -> str: """Shortens the `name` to `max_length` and adds a tag to it to make it unique. Tag may be placed in the middle or at the end""" if max_length and len(normalized_ident) > max_length: # use original identifier to compute tag @@ -72,9 +78,14 @@ def shorten_identifier(normalized_ident: str, identifier: str, max_length: int, def _compute_tag(identifier: str, collision_prob: float) -> str: # assume that shake_128 has perfect collision resistance 2^N/2 then collision prob is 1/resistance: prob = 1/2^N/2, solving for prob # take into account that we are case insensitive in base64 so we need ~1.5x more bits (2+1) - tl_bytes = int(((2+1)*math.log2(1/(collision_prob)) // 8) + 1) - tag = base64.b64encode(hashlib.shake_128(identifier.encode("utf-8")).digest(tl_bytes) - ).rstrip(b"=").translate(NamingConvention._TR_TABLE).lower().decode("ascii") + tl_bytes = int(((2 + 1) * math.log2(1 / (collision_prob)) // 8) + 1) + tag = ( + base64.b64encode(hashlib.shake_128(identifier.encode("utf-8")).digest(tl_bytes)) + .rstrip(b"=") + .translate(NamingConvention._TR_TABLE) + .lower() + .decode("ascii") + ) return tag @staticmethod @@ -82,7 +93,11 @@ def _trim_and_tag(identifier: str, tag: str, max_length: int) -> str: assert len(tag) <= max_length remaining_length = max_length - len(tag) remaining_overflow = remaining_length % 2 - identifier = identifier[:remaining_length // 2 + remaining_overflow] + tag + identifier[len(identifier) - remaining_length // 2:] + identifier = ( + identifier[: remaining_length // 2 + remaining_overflow] + + tag + + identifier[len(identifier) - remaining_length // 2 :] + ) assert len(identifier) == max_length return identifier diff --git a/dlt/common/normalizers/naming/snake_case.py b/dlt/common/normalizers/naming/snake_case.py index 67c9fdd30e..21a54d1ca2 100644 --- a/dlt/common/normalizers/naming/snake_case.py +++ b/dlt/common/normalizers/naming/snake_case.py @@ -1,12 +1,11 @@ import re -from typing import Any, List, Sequence from functools import lru_cache +from typing import Any, List, Sequence from dlt.common.normalizers.naming.naming import NamingConvention as BaseNamingConvention class NamingConvention(BaseNamingConvention): - _RE_UNDERSCORES = re.compile("__+") _RE_LEADING_DIGITS = re.compile(r"^\d+") # _RE_ENDING_UNDERSCORES = re.compile(r"_+$") @@ -41,16 +40,14 @@ def _normalize_identifier(identifier: str, max_length: int) -> str: # shorten identifier return NamingConvention.shorten_identifier( - NamingConvention._to_snake_case(normalized_ident), - identifier, - max_length + NamingConvention._to_snake_case(normalized_ident), identifier, max_length ) @staticmethod def _to_snake_case(identifier: str) -> str: # then convert to snake case - identifier = NamingConvention._SNAKE_CASE_BREAK_1.sub(r'\1_\2', identifier) - identifier = NamingConvention._SNAKE_CASE_BREAK_2.sub(r'\1_\2', identifier).lower() + identifier = NamingConvention._SNAKE_CASE_BREAK_1.sub(r"\1_\2", identifier) + identifier = NamingConvention._SNAKE_CASE_BREAK_2.sub(r"\1_\2", identifier).lower() # leading digits will be prefixed if NamingConvention._RE_LEADING_DIGITS.match(identifier): @@ -63,4 +60,4 @@ def _to_snake_case(identifier: str) -> str: # identifier = 
NamingConvention._RE_ENDING_UNDERSCORES.sub("x", identifier) # replace consecutive underscores with single one to prevent name clashes with PATH_SEPARATOR - return NamingConvention._RE_UNDERSCORES.sub("_", stripped_ident) \ No newline at end of file + return NamingConvention._RE_UNDERSCORES.sub("_", stripped_ident) diff --git a/dlt/common/normalizers/typing.py b/dlt/common/normalizers/typing.py index 93920fda1b..599426259f 100644 --- a/dlt/common/normalizers/typing.py +++ b/dlt/common/normalizers/typing.py @@ -11,4 +11,4 @@ class TJSONNormalizer(TypedDict, total=False): class TNormalizersConfig(TypedDict, total=False): names: str detections: Optional[List[str]] - json: TJSONNormalizer \ No newline at end of file + json: TJSONNormalizer diff --git a/dlt/common/normalizers/utils.py b/dlt/common/normalizers/utils.py index f8c08bd910..7533f3501a 100644 --- a/dlt/common/normalizers/utils.py +++ b/dlt/common/normalizers/utils.py @@ -1,13 +1,13 @@ from importlib import import_module -from typing import Any, Type, Tuple, cast +from typing import Any, Tuple, Type, cast import dlt from dlt.common.configuration.inject import with_config from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.normalizers.configuration import NormalizersConfiguration -from dlt.common.normalizers.json import SupportsDataItemNormalizer, DataItemNormalizer +from dlt.common.normalizers.json import DataItemNormalizer, SupportsDataItemNormalizer from dlt.common.normalizers.naming import NamingConvention, SupportsNamingConvention -from dlt.common.normalizers.naming.exceptions import UnknownNamingModule, InvalidNamingModule +from dlt.common.normalizers.naming.exceptions import InvalidNamingModule, UnknownNamingModule from dlt.common.normalizers.typing import TJSONNormalizer, TNormalizersConfig DEFAULT_NAMING_MODULE = "dlt.common.normalizers.naming.snake_case" @@ -15,8 +15,7 @@ @with_config(spec=NormalizersConfiguration) def explicit_normalizers( - naming: str = dlt.config.value , - json_normalizer: TJSONNormalizer = dlt.config.value + naming: str = dlt.config.value, json_normalizer: TJSONNormalizer = dlt.config.value ) -> TNormalizersConfig: """Gets explicitly configured normalizers - via config or destination caps. May return None as naming or normalizer""" return {"names": naming, "json": json_normalizer} @@ -25,15 +24,17 @@ def explicit_normalizers( @with_config def import_normalizers( normalizers_config: TNormalizersConfig, - destination_capabilities: DestinationCapabilitiesContext = None + destination_capabilities: DestinationCapabilitiesContext = None, ) -> Tuple[TNormalizersConfig, NamingConvention, Type[DataItemNormalizer[Any]]]: """Imports the normalizers specified in `normalizers_config` or taken from defaults. Returns the updated config and imported modules. - `destination_capabilities` are used to get max length of the identifier. + `destination_capabilities` are used to get max length of the identifier. """ # add defaults to normalizer_config normalizers_config["names"] = names = normalizers_config["names"] or "snake_case" - normalizers_config["json"] = item_normalizer = normalizers_config["json"] or {"module": "dlt.common.normalizers.json.relational"} + normalizers_config["json"] = item_normalizer = normalizers_config["json"] or { + "module": "dlt.common.normalizers.json.relational" + } try: if "." in names: # TODO: bump schema engine version and migrate schema. 
also change the name in TNormalizersConfig from names to naming @@ -43,16 +44,25 @@ def import_normalizers( naming_module = cast(SupportsNamingConvention, import_module(names)) else: # from known location - naming_module = cast(SupportsNamingConvention, import_module(f"dlt.common.normalizers.naming.{names}")) + naming_module = cast( + SupportsNamingConvention, import_module(f"dlt.common.normalizers.naming.{names}") + ) except ImportError: raise UnknownNamingModule(names) if not hasattr(naming_module, "NamingConvention"): raise InvalidNamingModule(names) # get max identifier length if destination_capabilities: - max_length = min(destination_capabilities.max_identifier_length, destination_capabilities.max_column_identifier_length) + max_length = min( + destination_capabilities.max_identifier_length, + destination_capabilities.max_column_identifier_length, + ) else: max_length = None json_module = cast(SupportsDataItemNormalizer, import_module(item_normalizer["module"])) - return normalizers_config, naming_module.NamingConvention(max_length), json_module.DataItemNormalizer + return ( + normalizers_config, + naming_module.NamingConvention(max_length), + json_module.DataItemNormalizer, + ) diff --git a/dlt/common/pendulum.py b/dlt/common/pendulum.py index 3d1c784488..9b6b2548de 100644 --- a/dlt/common/pendulum.py +++ b/dlt/common/pendulum.py @@ -1,8 +1,9 @@ from datetime import timedelta # noqa: I251 + import pendulum # noqa: I251 # force UTC as the local timezone to prevent local dates to be written to dbs -pendulum.set_local_timezone(pendulum.timezone('UTC')) +pendulum.set_local_timezone(pendulum.timezone("UTC")) def __utcnow() -> pendulum.DateTime: diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index 7a0a4aacfe..8c23835670 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -1,27 +1,44 @@ -import os +import contextlib import datetime # noqa: 251 +import os +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + List, + NamedTuple, + Optional, + Protocol, + Sequence, + Tuple, + TypedDict, +) + import humanize -import contextlib -from typing import Any, Callable, ClassVar, Dict, List, NamedTuple, Optional, Protocol, Sequence, TYPE_CHECKING, Tuple, TypedDict -from dlt.common import pendulum, logger -from dlt.common.configuration import configspec -from dlt.common.configuration import known_sections +from dlt.common import logger, pendulum +from dlt.common.configuration import configspec, known_sections from dlt.common.configuration.container import Container from dlt.common.configuration.exceptions import ContextDefaultCannotBeCreated -from dlt.common.configuration.specs import ContainerInjectableContext -from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.configuration.paths import get_dlt_data_dir -from dlt.common.configuration.specs import RunConfiguration +from dlt.common.configuration.specs import ContainerInjectableContext, RunConfiguration +from dlt.common.configuration.specs.config_section_context import ConfigSectionContext +from dlt.common.data_writers.writers import TLoaderFileFormat from dlt.common.destination import DestinationReference, TDestinationReferenceArg -from dlt.common.exceptions import DestinationHasFailedJobs, PipelineStateNotAvailable, ResourceNameNotAvailable, SourceSectionNotAvailable +from dlt.common.exceptions import ( + DestinationHasFailedJobs, + PipelineStateNotAvailable, + ResourceNameNotAvailable, + SourceSectionNotAvailable, +) +from dlt.common.jsonpath 
import TAnyJsonPath, delete_matches from dlt.common.schema import Schema from dlt.common.schema.typing import TColumnNames, TColumnSchema, TWriteDisposition from dlt.common.source import get_current_pipe_name from dlt.common.storages.load_storage import LoadPackageInfo from dlt.common.typing import DictStrAny, REPattern -from dlt.common.jsonpath import delete_matches, TAnyJsonPath -from dlt.common.data_writers.writers import TLoaderFileFormat class ExtractDataInfo(TypedDict): @@ -46,6 +63,7 @@ def __str__(self) -> str: class NormalizeInfo(NamedTuple): """A tuple holding information on normalized data items. Returned by pipeline `normalize` method.""" + def asdict(self) -> DictStrAny: return {} @@ -58,6 +76,7 @@ def __str__(self) -> str: class LoadInfo(NamedTuple): """A tuple holding the information on recently loaded packages. Returned by pipeline `run` and `load` methods""" + pipeline: "SupportsPipeline" destination_name: str destination_displayable_credentials: str @@ -75,9 +94,7 @@ class LoadInfo(NamedTuple): def asdict(self) -> DictStrAny: """A dictionary representation of LoadInfo that can be loaded with `dlt`""" d = self._asdict() - d["pipeline"] = { - "pipeline_name": self.pipeline.pipeline_name - } + d["pipeline"] = {"pipeline_name": self.pipeline.pipeline_name} d["load_packages"] = [package.asdict() for package in self.load_packages] return d @@ -88,11 +105,20 @@ def asstr(self, verbosity: int = 0) -> str: msg += humanize.precisedelta(elapsed) else: msg += "---" - msg += f"\n{len(self.loads_ids)} load package(s) were loaded to destination {self.destination_name} and into dataset {self.dataset_name}\n" + msg += ( + f"\n{len(self.loads_ids)} load package(s) were loaded to destination" + f" {self.destination_name} and into dataset {self.dataset_name}\n" + ) if self.staging_name: - msg += f"The {self.staging_name} staging destination used {self.staging_displayable_credentials} location to stage data\n" - - msg += f"The {self.destination_name} destination used {self.destination_displayable_credentials} location to store data" + msg += ( + f"The {self.staging_name} staging destination used" + f" {self.staging_displayable_credentials} location to stage data\n" + ) + + msg += ( + f"The {self.destination_name} destination used" + f" {self.destination_displayable_credentials} location to store data" + ) for load_package in self.load_packages: cstr = load_package.state.upper() if load_package.completed_at else "NOT COMPLETED" # now enumerate all complete loads if we have any failed packages @@ -102,7 +128,9 @@ def asstr(self, verbosity: int = 0) -> str: msg += f"\nLoad package {load_package.load_id} is {cstr} and contains {jobs_str}" if verbosity > 0: for failed_job in failed_jobs: - msg += f"\n\t[{failed_job.job_file_info.job_id()}]: {failed_job.failed_message}\n" + msg += ( + f"\n\t[{failed_job.job_file_info.job_id()}]: {failed_job.failed_message}\n" + ) if verbosity > 1: msg += "\nPackage details:\n" msg += load_package.asstr() + "\n" @@ -121,11 +149,14 @@ def raise_on_failed_jobs(self) -> None: for load_package in self.load_packages: failed_jobs = load_package.jobs["failed_jobs"] if len(failed_jobs): - raise DestinationHasFailedJobs(self.destination_name, load_package.load_id, failed_jobs) + raise DestinationHasFailedJobs( + self.destination_name, load_package.load_id, failed_jobs + ) def __str__(self) -> str: return self.asstr(verbosity=1) + class TPipelineLocalState(TypedDict, total=False): first_run: bool """Indicates a first run of the pipeline, where run ends with successful loading 
of data""" @@ -135,6 +166,7 @@ class TPipelineLocalState(TypedDict, total=False): class TPipelineState(TypedDict, total=False): """Schema for a pipeline state that is stored within the pipeline working directory""" + pipeline_name: str dataset_name: str default_schema_name: Optional[str] @@ -157,6 +189,7 @@ class TSourceState(TPipelineState): class SupportsPipeline(Protocol): """A protocol with core pipeline operations that lets high level abstractions ie. sources to access pipeline methods and properties""" + pipeline_name: str """Name of the pipeline""" default_schema_name: str @@ -196,8 +229,8 @@ def run( columns: Sequence[TColumnSchema] = None, primary_key: TColumnNames = None, schema: Schema = None, - loader_file_format: TLoaderFileFormat = None - ) -> LoadInfo: + loader_file_format: TLoaderFileFormat = None, + ) -> LoadInfo: ... def _set_context(self, is_active: bool) -> None: @@ -218,7 +251,7 @@ def __call__( write_disposition: TWriteDisposition = None, columns: Sequence[TColumnSchema] = None, schema: Schema = None, - loader_file_format: TLoaderFileFormat = None + loader_file_format: TLoaderFileFormat = None, ) -> LoadInfo: ... @@ -265,17 +298,20 @@ class StateInjectableContext(ContainerInjectableContext): can_create_default: ClassVar[bool] = False if TYPE_CHECKING: + def __init__(self, state: TPipelineState = None) -> None: ... -def pipeline_state(container: Container, initial_default: TPipelineState = None) -> Tuple[TPipelineState, bool]: +def pipeline_state( + container: Container, initial_default: TPipelineState = None +) -> Tuple[TPipelineState, bool]: """Gets value of the state from context or active pipeline, if none found returns `initial_default` - Injected state is called "writable": it is injected by the `Pipeline` class and all the changes will be persisted. - The state coming from pipeline context or `initial_default` is called "read only" and all the changes to it will be discarded + Injected state is called "writable": it is injected by the `Pipeline` class and all the changes will be persisted. + The state coming from pipeline context or `initial_default` is called "read only" and all the changes to it will be discarded - Returns tuple (state, writable) + Returns tuple (state, writable) """ try: # get injected state if present. injected state is typically "managed" so changes will be persisted @@ -347,7 +383,9 @@ def source_state() -> DictStrAny: _last_full_state: TPipelineState = None -def _delete_source_state_keys(key: TAnyJsonPath, source_state_: Optional[DictStrAny] = None, /) -> None: +def _delete_source_state_keys( + key: TAnyJsonPath, source_state_: Optional[DictStrAny] = None, / +) -> None: """Remove one or more key from the source state. The `key` can be any number of keys and/or json paths to be removed. """ @@ -355,7 +393,9 @@ def _delete_source_state_keys(key: TAnyJsonPath, source_state_: Optional[DictStr delete_matches(key, state_) -def resource_state(resource_name: str = None, source_state_: Optional[DictStrAny] = None, /) -> DictStrAny: +def resource_state( + resource_name: str = None, source_state_: Optional[DictStrAny] = None, / +) -> DictStrAny: """Returns a dictionary with the resource-scoped state. Resource-scoped state is visible only to resource requesting the access. Dlt state is preserved across pipeline runs and may be used to implement incremental loads. Note that this function accepts the resource name as optional argument. 
There are rare cases when `dlt` is not able to resolve resource name due to requesting function @@ -405,10 +445,12 @@ def resource_state(resource_name: str = None, source_state_: Optional[DictStrAny resource_name = get_current_pipe_name() if not resource_name: raise ResourceNameNotAvailable() - return state_.setdefault('resources', {}).setdefault(resource_name, {}) # type: ignore + return state_.setdefault("resources", {}).setdefault(resource_name, {}) # type: ignore -def _reset_resource_state(resource_name: str, source_state_: Optional[DictStrAny] = None, /) -> None: +def _reset_resource_state( + resource_name: str, source_state_: Optional[DictStrAny] = None, / +) -> None: """Alpha version of the resource state. Resets the resource state Args: @@ -420,7 +462,9 @@ def _reset_resource_state(resource_name: str, source_state_: Optional[DictStrAny state_["resources"].pop(resource_name) -def _get_matching_resources(pattern: REPattern, source_state_: Optional[DictStrAny] = None, /) -> List[str]: +def _get_matching_resources( + pattern: REPattern, source_state_: Optional[DictStrAny] = None, / +) -> List[str]: """Get all resource names in state matching the regex pattern""" state_ = source_state() if source_state_ is None else source_state_ if "resources" not in state_: @@ -429,10 +473,10 @@ def _get_matching_resources(pattern: REPattern, source_state_: Optional[DictStrA def get_dlt_pipelines_dir() -> str: - """ Gets default directory where pipelines' data will be stored - 1. in user home directory ~/.dlt/pipelines/ - 2. if current user is root in /var/dlt/pipelines - 3. if current user does not have a home directory in /tmp/dlt/pipelines + """Gets default directory where pipelines' data will be stored + 1. in user home directory ~/.dlt/pipelines/ + 2. if current user is root in /var/dlt/pipelines + 3. if current user does not have a home directory in /tmp/dlt/pipelines """ return os.path.join(get_dlt_data_dir(), "pipelines") diff --git a/dlt/common/reflection/function_visitor.py b/dlt/common/reflection/function_visitor.py index 3b89403745..6cb6016a7f 100644 --- a/dlt/common/reflection/function_visitor.py +++ b/dlt/common/reflection/function_visitor.py @@ -2,6 +2,7 @@ from ast import NodeVisitor from typing import Any + class FunctionVisitor(NodeVisitor): def __init__(self, source: str): self.source = source diff --git a/dlt/common/reflection/spec.py b/dlt/common/reflection/spec.py index 58a75fb53e..219f4dcad9 100644 --- a/dlt/common/reflection/spec.py +++ b/dlt/common/reflection/spec.py @@ -1,13 +1,13 @@ -import re import inspect -from typing import Dict, List, Type, Any, Optional, NewType -from inspect import Signature, Parameter +import re +from inspect import Parameter, Signature +from typing import Any, Dict, List, NewType, Optional, Type -from dlt.common.typing import AnyType, AnyFun, TSecretValue -from dlt.common.configuration import configspec, is_valid_hint, is_secret_hint -from dlt.common.configuration.specs import BaseConfiguration +from dlt.common.configuration import configspec, is_secret_hint, is_valid_hint from dlt.common.configuration.accessors import DLT_CONFIG_VALUE, DLT_SECRETS_VALUE +from dlt.common.configuration.specs import BaseConfiguration from dlt.common.reflection.utils import get_func_def_node, get_literal_defaults +from dlt.common.typing import AnyFun, AnyType, TSecretValue from dlt.common.utils import get_callable_name # [^.^_]+ splits by . 
or _ @@ -15,7 +15,9 @@ def _get_spec_name_from_f(f: AnyFun) -> str: - func_name = get_callable_name(f, "__qualname__").replace(".", "") # func qual name contains position in the module, separated by dots + func_name = get_callable_name(f, "__qualname__").replace( + ".", "" + ) # func qual name contains position in the module, separated by dots def _first_up(s: str) -> str: return s[0].upper() + s[1:] @@ -23,7 +25,9 @@ def _first_up(s: str) -> str: return "".join(map(_first_up, _SLEEPING_CAT_SPLIT.findall(func_name))) + "Configuration" -def spec_from_signature(f: AnyFun, sig: Signature, include_defaults: bool = True) -> Type[BaseConfiguration]: +def spec_from_signature( + f: AnyFun, sig: Signature, include_defaults: bool = True +) -> Type[BaseConfiguration]: name = _get_spec_name_from_f(f) module = inspect.getmodule(f) @@ -60,7 +64,10 @@ def dlt_config_literal_to_type(arg_name: str) -> AnyType: for p in sig.parameters.values(): # skip *args and **kwargs, skip typical method params - if p.kind not in (Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL) and p.name not in ["self", "cls"]: + if p.kind not in (Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL) and p.name not in [ + "self", + "cls", + ]: field_type = AnyType if p.annotation == Parameter.empty else p.annotation # only valid hints and parameters with defaults are eligible if is_valid_hint(field_type) and p.default != Parameter.empty: diff --git a/dlt/common/reflection/utils.py b/dlt/common/reflection/utils.py index c9c1ad92ed..c699bb1f37 100644 --- a/dlt/common/reflection/utils.py +++ b/dlt/common/reflection/utils.py @@ -1,8 +1,9 @@ import ast import inspect -import astunparse from typing import Any, Dict, List, Optional, Sequence, Tuple +import astunparse + from dlt.common.typing import AnyFun @@ -68,12 +69,16 @@ def creates_func_def_name_node(func_def: ast.FunctionDef, source_lines: Sequence """Recreate function name as a ast.Name with known source code location""" func_name = ast.Name(func_def.name) func_name.lineno = func_name.end_lineno = func_def.lineno - func_name.col_offset = source_lines[func_name.lineno - 1].index(func_def.name) # find where function name starts + func_name.col_offset = source_lines[func_name.lineno - 1].index( + func_def.name + ) # find where function name starts func_name.end_col_offset = func_name.col_offset + len(func_def.name) return func_name -def rewrite_python_script(source_script_lines: List[str], transformed_nodes: List[Tuple[ast.AST, ast.AST]]) -> List[str]: +def rewrite_python_script( + source_script_lines: List[str], transformed_nodes: List[Tuple[ast.AST, ast.AST]] +) -> List[str]: """Replaces all the nodes present in `transformed_nodes` in the `script_lines`. 
The `transformed_nodes` is a tuple where the first element is must be a node with full location information created out of `script_lines`""" script_lines: List[str] = [] @@ -87,12 +92,12 @@ def rewrite_python_script(source_script_lines: List[str], transformed_nodes: Lis if last_offset >= 0: script_lines.append(source_script_lines[last_line][last_offset:]) # add all new lines from previous line to current - script_lines.extend(source_script_lines[last_line+1:node.lineno-1]) + script_lines.extend(source_script_lines[last_line + 1 : node.lineno - 1]) # add trailing characters until node in current line starts - script_lines.append(source_script_lines[node.lineno-1][:node.col_offset]) + script_lines.append(source_script_lines[node.lineno - 1][: node.col_offset]) elif last_offset >= 0: # no line change, add the characters from the end of previous node to the current - script_lines.append(source_script_lines[last_line][last_offset:node.col_offset]) + script_lines.append(source_script_lines[last_line][last_offset : node.col_offset]) # replace node value script_lines.append(astunparse.unparse(t_value).strip()) @@ -102,7 +107,7 @@ def rewrite_python_script(source_script_lines: List[str], transformed_nodes: Lis # add all that was missing if last_offset >= 0: script_lines.append(source_script_lines[last_line][last_offset:]) - script_lines.extend(source_script_lines[last_line+1:]) + script_lines.extend(source_script_lines[last_line + 1 :]) return script_lines diff --git a/dlt/common/runners/configuration.py b/dlt/common/runners/configuration.py index 3231f83807..57a4885e31 100644 --- a/dlt/common/runners/configuration.py +++ b/dlt/common/runners/configuration.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Literal, Optional from dlt.common.configuration import configspec from dlt.common.configuration.specs import BaseConfiguration @@ -13,9 +13,6 @@ class PoolRunnerConfiguration(BaseConfiguration): run_sleep: float = 0.1 # how long to sleep between runs with workload, seconds if TYPE_CHECKING: - def __init__( - self, - pool_type: TPoolType = None, - workers: int = None - ) -> None: + + def __init__(self, pool_type: TPoolType = None, workers: int = None) -> None: ... 
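Aside on the `if TYPE_CHECKING:` constructor stubs reformatted above (as in `PoolRunnerConfiguration`): the stub body is never executed; it only gives type checkers and IDEs a keyword signature, while the real `__init__` is synthesized at runtime from the declared fields. A minimal sketch of the general pattern using plain `dataclasses` (illustrative only, with made-up names, not dlt's actual `@configspec` machinery):

from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, Optional

TPoolType = Literal["process", "thread", "none"]


@dataclass
class PoolConfig:
    pool_type: Optional[TPoolType] = None  # "process", "thread" or "none"
    workers: Optional[int] = None          # number of processes/threads in the pool
    run_sleep: float = 0.1                 # seconds to sleep between runs with workload

    if TYPE_CHECKING:
        # Skipped at runtime (TYPE_CHECKING is False), so @dataclass still
        # generates __init__ from the fields above; mypy and IDEs read this
        # stub instead and accept keyword construction.
        def __init__(self, pool_type: Optional[TPoolType] = None, workers: Optional[int] = None) -> None:
            ...


config = PoolConfig(pool_type="thread", workers=4)  # type-checks and runs
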
diff --git a/dlt/common/runners/pool_runner.py b/dlt/common/runners/pool_runner.py index 4263236008..c0e33da406 100644 --- a/dlt/common/runners/pool_runner.py +++ b/dlt/common/runners/pool_runner.py @@ -1,21 +1,24 @@ import multiprocessing +from multiprocessing.pool import Pool, ThreadPool from typing import Callable, Union, cast -from multiprocessing.pool import ThreadPool, Pool from dlt.common import logger, sleep -from dlt.common.runtime import init -from dlt.common.runners.runnable import Runnable, TPool +from dlt.common.exceptions import SignalReceivedException from dlt.common.runners.configuration import PoolRunnerConfiguration +from dlt.common.runners.runnable import Runnable, TPool from dlt.common.runners.typing import TRunMetrics -from dlt.common.runtime import signals -from dlt.common.exceptions import SignalReceivedException +from dlt.common.runtime import init, signals def create_pool(config: PoolRunnerConfiguration) -> Pool: if config.pool_type == "process": # if not fork method, provide initializer for logs and configuration if multiprocessing.get_start_method() != "fork" and init._INITIALIZED: - return Pool(processes=config.workers, initializer=init.initialize_runtime, initargs=(init._RUN_CONFIGURATION, )) + return Pool( + processes=config.workers, + initializer=init.initialize_runtime, + initargs=(init._RUN_CONFIGURATION,), + ) else: return Pool(processes=config.workers) elif config.pool_type == "thread": @@ -24,10 +27,14 @@ def create_pool(config: PoolRunnerConfiguration) -> Pool: return None -def run_pool(config: PoolRunnerConfiguration, run_f: Union[Runnable[TPool], Callable[[TPool], TRunMetrics]]) -> int: +def run_pool( + config: PoolRunnerConfiguration, run_f: Union[Runnable[TPool], Callable[[TPool], TRunMetrics]] +) -> int: # validate the run function if not isinstance(run_f, Runnable) and not callable(run_f): - raise ValueError(run_f, "Pool runner entry point must be a function f(pool: TPool) or Runnable") + raise ValueError( + run_f, "Pool runner entry point must be a function f(pool: TPool) or Runnable" + ) # start pool pool = create_pool(config) diff --git a/dlt/common/runners/runnable.py b/dlt/common/runners/runnable.py index d168c5482b..aea78d8172 100644 --- a/dlt/common/runners/runnable.py +++ b/dlt/common/runners/runnable.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod from functools import wraps -from typing import Any, Dict, Type, TypeVar, TYPE_CHECKING, Union, Generic from multiprocessing.pool import Pool +from typing import TYPE_CHECKING, Any, Dict, Generic, Type, TypeVar, Union from weakref import WeakValueDictionary -from dlt.common.typing import TFun from dlt.common.runners.typing import TRunMetrics +from dlt.common.typing import TFun TPool = TypeVar("TPool", bound=Pool) @@ -49,6 +49,7 @@ def workermethod(f: TFun) -> TFun: Returns: TFun: wrapped worker function """ + @wraps(f) def _wrap(rid: Union[int, Runnable[TPool]], *args: Any, **kwargs: Any) -> Any: if isinstance(rid, int): @@ -94,4 +95,3 @@ def _wrap(rid: Union[int, Runnable[TPool]], *args: Any, **kwargs: Any) -> Any: # return f(config, *args, **kwargs) # return _wrap # type: ignore - diff --git a/dlt/common/runners/stdout.py b/dlt/common/runners/stdout.py index a9f4ab1438..07c15c3274 100644 --- a/dlt/common/runners/stdout.py +++ b/dlt/common/runners/stdout.py @@ -4,8 +4,8 @@ from threading import Thread from typing import Any, Generator, Iterator, List +from dlt.common.runners.synth_pickle import decode_last_obj, decode_obj, encode_obj from dlt.common.runners.venv import Venv -from 
dlt.common.runners.synth_pickle import decode_obj, decode_last_obj, encode_obj from dlt.common.typing import AnyFun @@ -26,14 +26,16 @@ def exec_to_stdout(f: AnyFun) -> Iterator[Any]: def iter_stdout(venv: Venv, command: str, *script_args: Any) -> Iterator[str]: # start a process in virtual environment, assume that text comes from stdout - with venv.start_command(command, *script_args, stdout=PIPE, stderr=PIPE, bufsize=1, text=True) as process: + with venv.start_command( + command, *script_args, stdout=PIPE, stderr=PIPE, bufsize=1, text=True + ) as process: exit_code: int = None line = "" stderr: List[str] = [] def _r_stderr() -> None: nonlocal stderr - for line in iter(process.stderr.readline, ''): + for line in iter(process.stderr.readline, ""): stderr.append(line) # read stderr with a thread, selectors do not work on windows @@ -41,7 +43,7 @@ def _r_stderr() -> None: t.start() # read stdout with - for line in iter(process.stdout.readline, ''): + for line in iter(process.stdout.readline, ""): if line.endswith("\n"): yield line[:-1] else: @@ -57,9 +59,11 @@ def _r_stderr() -> None: raise CalledProcessError(exit_code, command, output=line, stderr="".join(stderr)) -def iter_stdout_with_result(venv: Venv, command: str, *script_args: Any) -> Generator[str, None, Any]: +def iter_stdout_with_result( + venv: Venv, command: str, *script_args: Any +) -> Generator[str, None, Any]: """Yields stdout lines coming from remote process and returns the last result decoded with decode_obj. In case of exit code != 0 if exception is decoded - it will be raised, otherwise CalledProcessError is raised""" + it will be raised, otherwise CalledProcessError is raised""" last_result: Any = None try: for line in iter_stdout(venv, command, *script_args): diff --git a/dlt/common/runners/synth_pickle.py b/dlt/common/runners/synth_pickle.py index 420e89a74a..0de141d52f 100644 --- a/dlt/common/runners/synth_pickle.py +++ b/dlt/common/runners/synth_pickle.py @@ -1,8 +1,8 @@ -import io -import sys +import base64 import binascii +import io import pickle -import base64 +import sys from typing import Any, Sequence from dlt.common.utils import digest128b @@ -15,6 +15,7 @@ def __init__(*args: Any, **kwargs: Any) -> None: class SynthesizingUnpickler(pickle.Unpickler): """Unpickler that synthesizes missing types instead of raising""" + def find_class(self, module: str, name: str) -> Any: if module not in sys.modules: module_obj = sys.modules[__name__] @@ -24,7 +25,7 @@ def find_class(self, module: str, name: str) -> Any: return getattr(module_obj, name) except Exception: # synthesize type - t = type(name, (MissingUnpickledType, ), {"__module__": module}) + t = type(name, (MissingUnpickledType,), {"__module__": module}) setattr(module_obj, name, t) return t diff --git a/dlt/common/runners/venv.py b/dlt/common/runners/venv.py index e4e8532248..538047428d 100644 --- a/dlt/common/runners/venv.py +++ b/dlt/common/runners/venv.py @@ -1,9 +1,9 @@ -import sys import os import shutil -import venv -import types import subprocess +import sys +import types +import venv from typing import Any, List, Type from dlt.common.exceptions import CannotInstallDependency, VenvNotFound @@ -19,7 +19,7 @@ def post_setup(self, context: types.SimpleNamespace) -> None: self.context = context -class Venv(): +class Venv: """Creates and wraps the Python Virtual Environment to allow for code execution""" def __init__(self, context: types.SimpleNamespace, current: bool = False) -> None: @@ -59,6 +59,7 @@ def restore_current(cls) -> "Venv": venv = 
cls.restore(os.environ["VIRTUAL_ENV"], current=True) except KeyError: import sys + bin_path, _ = os.path.split(sys.executable) context = types.SimpleNamespace(bin_path=bin_path, env_exe=sys.executable) venv = cls(context, current=True) @@ -69,7 +70,9 @@ def __enter__(self) -> "Venv": raise NotImplementedError("Context manager does not work with current venv") return self - def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: types.TracebackType) -> None: + def __exit__( + self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: types.TracebackType + ) -> None: self.delete_environment() def delete_environment(self) -> None: @@ -80,7 +83,9 @@ def delete_environment(self) -> None: if self.context.env_dir and os.path.isdir(self.context.env_dir): shutil.rmtree(self.context.env_dir) - def start_command(self, entry_point: str, *script_args: Any, **popen_kwargs: Any) -> "subprocess.Popen[str]": + def start_command( + self, entry_point: str, *script_args: Any, **popen_kwargs: Any + ) -> "subprocess.Popen[str]": command = os.path.join(self.context.bin_path, entry_point) cmd = [command, *script_args] return subprocess.Popen(cmd, **popen_kwargs) diff --git a/dlt/common/runtime/__init__.py b/dlt/common/runtime/__init__.py index 65f5c11696..5272e427f7 100644 --- a/dlt/common/runtime/__init__.py +++ b/dlt/common/runtime/__init__.py @@ -1 +1 @@ -from .init import initialize_runtime \ No newline at end of file +from .init import initialize_runtime diff --git a/dlt/common/runtime/collector.py b/dlt/common/runtime/collector.py index 850f5099f9..a1dd6b8db4 100644 --- a/dlt/common/runtime/collector.py +++ b/dlt/common/runtime/collector.py @@ -1,15 +1,30 @@ +import logging import os import sys -import logging import time from abc import ABC, abstractmethod from collections import defaultdict -from typing import Any, ContextManager, Dict, Type, TYPE_CHECKING, DefaultDict, NamedTuple, Optional, Union, TextIO, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + ContextManager, + DefaultDict, + Dict, + NamedTuple, + Optional, + TextIO, + Type, + TypeVar, + Union, +) + if TYPE_CHECKING: - from tqdm import tqdm import enlighten - from enlighten import Counter as EnlCounter, StatusBar as EnlStatusBar, Manager as EnlManager from alive_progress import alive_bar + from enlighten import Counter as EnlCounter + from enlighten import Manager as EnlManager + from enlighten import StatusBar as EnlStatusBar + from tqdm import tqdm else: tqdm = EnlCounter = EnlStatusBar = EnlManager = Any @@ -20,11 +35,12 @@ class Collector(ABC): - step: str @abstractmethod - def update(self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = None) -> None: + def update( + self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = None + ) -> None: """Creates or updates a counter This function updates a counter `name` with a value `inc`. If counter does not exist, it is created with optional total value of `total`. 
@@ -65,7 +81,9 @@ def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb class NullCollector(Collector): """A default counter that does not count anything.""" - def update(self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = None) -> None: + def update( + self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = None + ) -> None: pass def _start(self, step: str) -> None: @@ -81,7 +99,9 @@ class DictCollector(Collector): def __init__(self) -> None: self.counters: DefaultDict[str, int] = None - def update(self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = None) -> None: + def update( + self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = None + ) -> None: assert not label, "labels not supported in dict collector" self.counters[name] += inc @@ -103,7 +123,13 @@ class CounterInfo(NamedTuple): start_time: float total: Optional[int] - def __init__(self, log_period: float = 1.0, logger: Union[logging.Logger, TextIO] = sys.stdout, log_level: int = logging.INFO, dump_system_stats: bool = True) -> None: + def __init__( + self, + log_period: float = 1.0, + logger: Union[logging.Logger, TextIO] = sys.stdout, + log_level: int = logging.INFO, + dump_system_stats: bool = True, + ) -> None: """ Collector writing to a `logger` every `log_period` seconds. The logger can be a Python logger instance, text stream, or None that will attach `dlt` logger @@ -123,12 +149,19 @@ def __init__(self, log_period: float = 1.0, logger: Union[logging.Logger, TextIO try: import psutil except ImportError: - self._log(logging.WARNING, "psutil dependency is not installed and mem stats will not be available. add psutil to your environment or pass dump_system_stats argument as False to disable warning.") + self._log( + logging.WARNING, + "psutil dependency is not installed and mem stats will not be available. 
add" + " psutil to your environment or pass dump_system_stats argument as False to" + " disable warning.", + ) dump_system_stats = False self.dump_system_stats = dump_system_stats self.last_log_time: float = None - def update(self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = None) -> None: + def update( + self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = None + ) -> None: counter_key = f"{name}_{label}" if label else name if counter_key not in self.counters: @@ -169,7 +202,10 @@ def dump_counters(self) -> None: items_per_second_str = f"{items_per_second:.2f}/s" message = f"[{self.messages[name]}]" if self.messages[name] is not None else "" - counter_line = f"{info.description}: {progress} {percentage} | Time: {elapsed_time_str} | Rate: {items_per_second_str} {message}" + counter_line = ( + f"{info.description}: {progress} {percentage} | Time: {elapsed_time_str} | Rate:" + f" {items_per_second_str} {message}" + ) log_lines.append(counter_line.strip()) if self.dump_system_stats: @@ -177,10 +213,13 @@ def dump_counters(self) -> None: process = psutil.Process(os.getpid()) mem_info = process.memory_info() - current_mem = mem_info.rss / (1024 ** 2) # Convert to MB + current_mem = mem_info.rss / (1024**2) # Convert to MB mem_percent = psutil.virtual_memory().percent cpu_percent = process.cpu_percent() - log_lines.append(f"Memory usage: {current_mem:.2f} MB ({mem_percent:.2f}%) | CPU usage: {cpu_percent:.2f}%") + log_lines.append( + f"Memory usage: {current_mem:.2f} MB ({mem_percent:.2f}%) | CPU usage:" + f" {cpu_percent:.2f}%" + ) log_lines.append("") log_message = "\n".join(log_lines) @@ -218,12 +257,16 @@ def __init__(self, single_bar: bool = False, **tqdm_kwargs: Any) -> None: global tqdm from tqdm import tqdm except ModuleNotFoundError: - raise MissingDependencyException("TqdmCollector", ["tqdm"], "We need tqdm to display progress bars.") + raise MissingDependencyException( + "TqdmCollector", ["tqdm"], "We need tqdm to display progress bars." 
+ ) self.single_bar = single_bar self._bars: Dict[str, tqdm] = {} self.tqdm_kwargs = tqdm_kwargs or {} - def update(self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = "") -> None: + def update( + self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = "" + ) -> None: key = f"{name}_{label}" bar = self._bars.get(key) if bar is None: @@ -263,13 +306,19 @@ def __init__(self, single_bar: bool = True, **alive_kwargs: Any) -> None: from alive_progress import alive_bar except ModuleNotFoundError: - raise MissingDependencyException("AliveCollector", ["alive-progress"], "We need alive-progress to display progress bars.") + raise MissingDependencyException( + "AliveCollector", + ["alive-progress"], + "We need alive-progress to display progress bars.", + ) self.single_bar = single_bar self._bars: Dict[str, Any] = {} self._bars_contexts: Dict[str, ContextManager[Any]] = {} self.alive_kwargs = alive_kwargs or {} - def update(self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = "") -> None: + def update( + self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = "" + ) -> None: key = f"{name}_{label}" bar = self._bars.get(key) if bar is None: @@ -313,13 +362,21 @@ def __init__(self, single_bar: bool = False, **enlighten_kwargs: Any) -> None: global enlighten import enlighten - from enlighten import Counter as EnlCounter, StatusBar as EnlStatusBar, Manager as EnlManager + from enlighten import Counter as EnlCounter + from enlighten import Manager as EnlManager + from enlighten import StatusBar as EnlStatusBar except ModuleNotFoundError: - raise MissingDependencyException("EnlightenCollector", ["enlighten"], "We need enlighten to display progress bars with a space for log messages.") + raise MissingDependencyException( + "EnlightenCollector", + ["enlighten"], + "We need enlighten to display progress bars with a space for log messages.", + ) self.single_bar = single_bar self.enlighten_kwargs = enlighten_kwargs - def update(self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = "") -> None: + def update( + self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = "" + ) -> None: key = f"{name}_{label}" bar = self._bars.get(key) if bar is None: @@ -328,7 +385,9 @@ def update(self, name: str, inc: int = 1, total: int = None, message: str = Non if len(self._bars) > 0 and self.single_bar: # do not add any more counters return - bar = self._manager.counter(desc=name, total=total, leave=True, force=True, **self.enlighten_kwargs) + bar = self._manager.counter( + desc=name, total=total, leave=True, force=True, **self.enlighten_kwargs + ) bar.refresh() self._bars[key] = bar bar.update(inc) @@ -336,7 +395,9 @@ def update(self, name: str, inc: int = 1, total: int = None, message: str = Non def _start(self, step: str) -> None: self._bars = {} self._manager = enlighten.get_manager(enabled=True) - self._status = self._manager.status_bar(leave=True, justify=enlighten.Justify.CENTER, fill="=") + self._status = self._manager.status_bar( + leave=True, justify=enlighten.Justify.CENTER, fill="=" + ) self._status.update(step) def _stop(self) -> None: @@ -352,4 +413,4 @@ def _stop(self) -> None: self._status = None -NULL_COLLECTOR = NullCollector() +NULL_COLLECTOR = NullCollector() diff --git a/dlt/common/runtime/exec_info.py b/dlt/common/runtime/exec_info.py index ecb8376aa7..51f87745eb 100644 --- a/dlt/common/runtime/exec_info.py +++ 
b/dlt/common/runtime/exec_info.py @@ -1,13 +1,22 @@ +import contextlib import io import os -import contextlib -from dlt.common.typing import StrStr, StrAny, Literal, List +from dlt.common.typing import List, Literal, StrAny, StrStr from dlt.common.utils import filter_env_vars from dlt.version import __version__ - -TExecInfoNames = Literal["kubernetes", "docker", "codespaces", "github_actions", "airflow", "notebook", "colab","aws_lambda","gcp_cloud_function"] +TExecInfoNames = Literal[ + "kubernetes", + "docker", + "codespaces", + "github_actions", + "airflow", + "notebook", + "colab", + "aws_lambda", + "gcp_cloud_function", +] # if one of these environment variables is set, we assume to be running in CI env CI_ENVIRONMENT_TELL = [ "bamboo.buildKey", @@ -100,7 +109,7 @@ def is_running_in_airflow_task() -> bool: from airflow.operators.python import get_current_context context = get_current_context() - return context is not None and 'ti' in context + return context is not None and "ti" in context except Exception: return False @@ -163,4 +172,4 @@ def is_aws_lambda() -> bool: def is_gcp_cloud_function() -> bool: "Return True if the process is running in the serverless platform GCP Cloud Functions" - return os.environ.get("FUNCTION_NAME") is not None \ No newline at end of file + return os.environ.get("FUNCTION_NAME") is not None diff --git a/dlt/common/runtime/logger.py b/dlt/common/runtime/logger.py index 14fa238d90..5ecc4628fb 100644 --- a/dlt/common/runtime/logger.py +++ b/dlt/common/runtime/logger.py @@ -1,14 +1,15 @@ import contextlib import logging -import json_logging import traceback -from logging import LogRecord, Logger +from logging import Logger, LogRecord from typing import Any, Iterator, Protocol +import json_logging + +from dlt.common.configuration.specs import RunConfiguration from dlt.common.json import json from dlt.common.runtime.exec_info import dlt_version_info from dlt.common.typing import StrAny, StrStr -from dlt.common.configuration.specs import RunConfiguration DLT_LOGGER_NAME = "dlt" LOGGER: Logger = None @@ -21,6 +22,7 @@ def __call__(self, msg: str, *args: Any, **kwds: Any) -> None: def __getattr__(name: str) -> LogMethod: """Forwards log method calls (debug, info, error etc.) 
to LOGGER""" + def wrapper(msg: str, *args: Any, **kwargs: Any) -> None: if LOGGER: # skip stack frames when displaying log so the original logging frame is displayed @@ -29,6 +31,7 @@ def wrapper(msg: str, *args: Any, **kwargs: Any) -> None: # exception has one more frame stacklevel = 3 getattr(LOGGER, name)(msg, *args, **kwargs, stacklevel=stacklevel) + return wrapper @@ -51,11 +54,8 @@ def init_logging(config: RunConfiguration) -> None: version = dlt_version_info(config.pipeline_name) LOGGER = _init_logging( - DLT_LOGGER_NAME, - config.log_level, - config.log_format, - config.pipeline_name, - version) + DLT_LOGGER_NAME, config.log_level, config.log_format, config.pipeline_name, version + ) def is_logging() -> bool: @@ -86,7 +86,6 @@ def format(self, record: LogRecord) -> str: # noqa: A003 class _CustomJsonFormatter(json_logging.JSONLogFormatter): - version: StrStr = None def _format_log_object(self, record: LogRecord, request_util: Any) -> Any: @@ -96,7 +95,9 @@ def _format_log_object(self, record: LogRecord, request_util: Any) -> Any: return json_log_object -def _init_logging(logger_name: str, level: str, fmt: str, component: str, version: StrStr) -> Logger: +def _init_logging( + logger_name: str, level: str, fmt: str, component: str, version: StrStr +) -> Logger: if logger_name == "root": logging.basicConfig(level=level) handler = logging.getLogger().handlers[0] @@ -121,6 +122,6 @@ def _init_logging(logger_name: str, level: str, fmt: str, component: str, versio if logger_name == "root": json_logging.config_root_logger() else: - handler.setFormatter(_MetricsFormatter(fmt=fmt, style='{')) + handler.setFormatter(_MetricsFormatter(fmt=fmt, style="{")) return logger diff --git a/dlt/common/runtime/prometheus.py b/dlt/common/runtime/prometheus.py index 0634670a5a..f1640f0d99 100644 --- a/dlt/common/runtime/prometheus.py +++ b/dlt/common/runtime/prometheus.py @@ -1,4 +1,5 @@ from typing import Iterable + from prometheus_client import Gauge from prometheus_client.metrics import MetricWrapperBase @@ -7,7 +8,6 @@ from dlt.common.runtime.exec_info import dlt_version_info from dlt.common.typing import DictStrAny, StrAny - # def init_prometheus(config: RunConfiguration) -> None: # from prometheus_client import start_http_server, Info @@ -23,7 +23,9 @@ def get_metrics_from_prometheus(gauges: Iterable[MetricWrapperBase]) -> StrAny: name = g._name if g._is_parent(): # for gauges containing many label values, enumerate all - metrics.update(get_metrics_from_prometheus([g.labels(*label) for label in g._metrics.keys()])) + metrics.update( + get_metrics_from_prometheus([g.labels(*label) for label in g._metrics.keys()]) + ) continue # for gauges with labels: add the label to the name and enumerate samples if g._labelvalues: diff --git a/dlt/common/runtime/segment.py b/dlt/common/runtime/segment.py index b8d533cccb..67a8c95454 100644 --- a/dlt/common/runtime/segment.py +++ b/dlt/common/runtime/segment.py @@ -1,24 +1,24 @@ """dltHub telemetry using Segment""" # several code fragments come from https://github.com/RasaHQ/rasa/blob/main/rasa/telemetry.py -import os -import sys -import multiprocessing import atexit import base64 -import requests +import multiprocessing +import os import platform +import sys from concurrent.futures import ThreadPoolExecutor from typing import Literal, Optional -from dlt.common.configuration.paths import get_dlt_data_dir -from dlt.common.runtime import logger +import requests +from dlt.common.configuration.paths import get_dlt_data_dir from dlt.common.configuration.specs import 
RunConfiguration +from dlt.common.runtime import logger from dlt.common.runtime.exec_info import exec_info_names, in_continuous_integration from dlt.common.typing import DictStrAny, StrAny from dlt.common.utils import uniq_id -from dlt.version import __version__, DLT_PKG_NAME +from dlt.version import DLT_PKG_NAME, __version__ TEventCategory = Literal["pipeline", "command", "helper"] @@ -31,7 +31,9 @@ def init_segment(config: RunConfiguration) -> None: - assert config.dlthub_telemetry_segment_write_key, "dlthub_telemetry_segment_write_key not present in RunConfiguration" + assert ( + config.dlthub_telemetry_segment_write_key + ), "dlthub_telemetry_segment_write_key not present in RunConfiguration" # create thread pool to send telemetry to segment global _THREAD_POOL, _WRITE_KEY, _SESSION @@ -51,11 +53,7 @@ def disable_segment() -> None: _at_exit_cleanup() -def track( - event_category: TEventCategory, - event_name: str, - properties: DictStrAny -) -> None: +def track(event_category: TEventCategory, event_name: str, properties: DictStrAny) -> None: """Tracks a telemetry event. The segment event name will be created as "{event_category}_{event_name} @@ -68,10 +66,7 @@ def track( if properties is None: properties = {} - properties.update({ - "event_category": event_category, - "event_name": event_name - }) + properties.update({"event_category": event_category, "event_name": event_name}) try: _send_event(f"{event_category}_{event_name}", properties, _default_context_fields()) @@ -127,11 +122,7 @@ def get_anonymous_id() -> str: return anonymous_id -def _segment_request_payload( - event_name: str, - properties: StrAny, - context: StrAny -) -> DictStrAny: +def _segment_request_payload(event_name: str, properties: StrAny, context: StrAny) -> DictStrAny: """Compose a valid payload for the segment API. Args: @@ -167,7 +158,7 @@ def _default_context_fields() -> DictStrAny: "python": sys.version.split(" ")[0], "library": {"name": DLT_PKG_NAME, "version": __version__}, "cpu": multiprocessing.cpu_count(), - "exec_info": exec_info_names() + "exec_info": exec_info_names(), } # avoid returning the cached dict --> caller could modify the dictionary... @@ -176,11 +167,7 @@ def _default_context_fields() -> DictStrAny: return _SEGMENT_CONTEXT.copy() -def _send_event( - event_name: str, - properties: StrAny, - context: StrAny -) -> None: +def _send_event(event_name: str, properties: StrAny, context: StrAny) -> None: """Report the contents segment of an event to the /track Segment endpoint. Args: @@ -205,7 +192,9 @@ def _send_event( def _future_send() -> None: # import time # start_ts = time.time() - resp = _SESSION.post(_SEGMENT_ENDPOINT, headers=headers, json=payload, timeout=_SEGMENT_REQUEST_TIMEOUT) + resp = _SESSION.post( + _SEGMENT_ENDPOINT, headers=headers, json=payload, timeout=_SEGMENT_REQUEST_TIMEOUT + ) # print(f"SENDING TO Segment done {resp.status_code} {time.time() - start_ts} {base64.b64decode(_WRITE_KEY)}") # handle different failure cases if resp.status_code != 200: @@ -216,8 +205,6 @@ def _future_send() -> None: else: data = resp.json() if not data.get("success"): - logger.debug( - f"Segment telemetry request returned a failure. Response: {data}" - ) + logger.debug(f"Segment telemetry request returned a failure. 
Response: {data}") _THREAD_POOL.submit(_future_send) diff --git a/dlt/common/runtime/sentry.py b/dlt/common/runtime/sentry.py index 8bc70e46cf..25ac624b9a 100644 --- a/dlt/common/runtime/sentry.py +++ b/dlt/common/runtime/sentry.py @@ -5,14 +5,18 @@ try: import sentry_sdk - from sentry_sdk.transport import HttpTransport from sentry_sdk.integrations.logging import LoggingIntegration + from sentry_sdk.transport import HttpTransport except ModuleNotFoundError: - raise MissingDependencyException("sentry telemetry", ["sentry-sdk"], "Please install sentry-sdk if you have `sentry_dsn` set in your RuntimeConfiguration") + raise MissingDependencyException( + "sentry telemetry", + ["sentry-sdk"], + "Please install sentry-sdk if you have `sentry_dsn` set in your RuntimeConfiguration", + ) -from dlt.common.typing import DictStrAny, Any, StrAny from dlt.common.configuration.specs import RunConfiguration -from dlt.common.runtime.exec_info import dlt_version_info, kube_pod_info, github_info +from dlt.common.runtime.exec_info import dlt_version_info, github_info, kube_pod_info +from dlt.common.typing import Any, DictStrAny, StrAny def init_sentry(config: RunConfiguration) -> None: @@ -27,10 +31,10 @@ def init_sentry(config: RunConfiguration) -> None: before_send=before_send, traces_sample_rate=1.0, # disable tornado, boto3, sql alchemy etc. - auto_enabling_integrations = False, + auto_enabling_integrations=False, integrations=[_get_sentry_log_level(config)], release=release, - transport=_SentryHttpTransport + transport=_SentryHttpTransport, ) # add version tags for k, v in version.items(): @@ -58,12 +62,11 @@ def before_send(event: DictStrAny, _unused_hint: Optional[StrAny] = None) -> Opt class _SentryHttpTransport(HttpTransport): - timeout: float = 0 def _get_pool_options(self, *a: Any, **kw: Any) -> DictStrAny: rv = HttpTransport._get_pool_options(self, *a, **kw) - rv['timeout'] = self.timeout + rv["timeout"] = self.timeout return rv @@ -71,6 +74,6 @@ def _get_sentry_log_level(config: RunConfiguration) -> LoggingIntegration: log_level = logging._nameToLevel[config.log_level] event_level = logging.WARNING if log_level <= logging.WARNING else log_level return LoggingIntegration( - level=logging.INFO, # Capture info and above as breadcrumbs - event_level=event_level # Send errors as events + level=logging.INFO, # Capture info and above as breadcrumbs + event_level=event_level, # Send errors as events ) diff --git a/dlt/common/runtime/signals.py b/dlt/common/runtime/signals.py index 2a5cc75135..c835710019 100644 --- a/dlt/common/runtime/signals.py +++ b/dlt/common/runtime/signals.py @@ -1,8 +1,8 @@ -import threading import signal +import threading from contextlib import contextmanager from threading import Event -from typing import Any, TYPE_CHECKING, Iterator +from typing import TYPE_CHECKING, Any, Iterator from dlt.common.exceptions import SignalReceivedException diff --git a/dlt/common/runtime/slack.py b/dlt/common/runtime/slack.py index ce5e90b300..f279468d4a 100644 --- a/dlt/common/runtime/slack.py +++ b/dlt/common/runtime/slack.py @@ -1,16 +1,14 @@ import requests + from dlt.common import json, logger def send_slack_message(incoming_hook: str, message: str, is_markdown: bool = True) -> None: """Sends a `message` to Slack `incoming_hook`, by default formatted as markdown.""" - r = requests.post(incoming_hook, - data = json.dumps({ - "text": message, - "mrkdwn": is_markdown - } - ).encode("utf-8"), - headers={'Content-Type': 'application/json;charset=utf-8'} + r = requests.post( + incoming_hook, + 
data=json.dumps({"text": message, "mrkdwn": is_markdown}).encode("utf-8"), + headers={"Content-Type": "application/json;charset=utf-8"}, ) if r.status_code >= 400: logger.warning(f"Could not post the notification to slack: {r.status_code}") diff --git a/dlt/common/runtime/telemetry.py b/dlt/common/runtime/telemetry.py index 86b3355985..2883d1ba2a 100644 --- a/dlt/common/runtime/telemetry.py +++ b/dlt/common/runtime/telemetry.py @@ -1,12 +1,12 @@ -import time import contextlib import inspect +import time from typing import Any, Callable +from dlt.common.configuration import resolve_configuration from dlt.common.configuration.specs import RunConfiguration +from dlt.common.runtime.segment import TEventCategory, disable_segment, init_segment, track from dlt.common.typing import TFun -from dlt.common.configuration import resolve_configuration -from dlt.common.runtime.segment import TEventCategory, init_segment, disable_segment, track _TELEMETRY_STARTED = False @@ -21,6 +21,7 @@ def start_telemetry(config: RunConfiguration) -> None: if config.sentry_dsn: # may raise if sentry is not installed from dlt.common.runtime.sentry import init_sentry + init_sentry(config) if config.dlthub_telemetry: @@ -36,6 +37,7 @@ def stop_telemetry() -> None: try: from dlt.common.runtime.sentry import disable_sentry + disable_sentry() except ImportError: pass @@ -49,14 +51,18 @@ def is_telemetry_started() -> bool: return _TELEMETRY_STARTED -def with_telemetry(category: TEventCategory, command: str, track_before: bool, *args: str) -> Callable[[TFun], TFun]: +def with_telemetry( + category: TEventCategory, command: str, track_before: bool, *args: str +) -> Callable[[TFun], TFun]: """Adds telemetry to f: TFun and add optional f *args values to `properties` of telemetry event""" + def decorator(f: TFun) -> TFun: sig: inspect.Signature = inspect.signature(f) + def _wrap(*f_args: Any, **f_kwargs: Any) -> Any: # look for additional arguments bound_args = sig.bind(*f_args, **f_kwargs) - props = {p:bound_args.arguments[p] for p in args if p in bound_args.arguments} + props = {p: bound_args.arguments[p] for p in args if p in bound_args.arguments} start_ts = time.time() def _track(success: bool) -> None: @@ -88,4 +94,5 @@ def _track(success: bool) -> None: raise return _wrap # type: ignore - return decorator \ No newline at end of file + + return decorator diff --git a/dlt/common/schema/__init__.py b/dlt/common/schema/__init__.py index a574d9baf3..cf18606b50 100644 --- a/dlt/common/schema/__init__.py +++ b/dlt/common/schema/__init__.py @@ -1,4 +1,13 @@ -from dlt.common.schema.typing import TSchemaUpdate, TSchemaTables, TTableSchema, TStoredSchema, TTableSchemaColumns, TColumnHint, TColumnSchema, TColumnSchemaBase # noqa: F401 -from dlt.common.schema.typing import COLUMN_HINTS # noqa: F401 from dlt.common.schema.schema import Schema # noqa: F401 +from dlt.common.schema.typing import COLUMN_HINTS # noqa: F401 +from dlt.common.schema.typing import ( # noqa: F401 + TColumnHint, + TColumnSchema, + TColumnSchemaBase, + TSchemaTables, + TSchemaUpdate, + TStoredSchema, + TTableSchema, + TTableSchemaColumns, +) from dlt.common.schema.utils import add_missing_hints, verify_schema_hash # noqa: F401 diff --git a/dlt/common/schema/detections.py b/dlt/common/schema/detections.py index 574cb44c93..31e4ecc711 100644 --- a/dlt/common/schema/detections.py +++ b/dlt/common/schema/detections.py @@ -3,11 +3,10 @@ from hexbytes import HexBytes -from dlt.common import pendulum, Wei +from dlt.common import Wei, pendulum from dlt.common.data_types 
import TDataType from dlt.common.time import parse_iso_like_datetime - _NOW_TS: float = pendulum.now().timestamp() _FLOAT_TS_RANGE = 5 * 31536000.0 # seconds in year diff --git a/dlt/common/schema/exceptions.py b/dlt/common/schema/exceptions.py index 2245a77b61..5f63b20459 100644 --- a/dlt/common/schema/exceptions.py +++ b/dlt/common/schema/exceptions.py @@ -1,7 +1,7 @@ from typing import Any -from dlt.common.exceptions import DltException from dlt.common.data_types import TDataType +from dlt.common.exceptions import DltException class SchemaException(DltException): @@ -13,7 +13,12 @@ class InvalidSchemaName(ValueError, SchemaException): def __init__(self, name: str) -> None: self.name = name - super().__init__(f"{name} is an invalid schema/source name. The source or schema name must be a valid Python identifier ie. a snake case function name and have maximum {self.MAXIMUM_SCHEMA_NAME_LENGTH} characters. Ideally should contain only small letters, numbers and underscores.") + super().__init__( + f"{name} is an invalid schema/source name. The source or schema name must be a valid" + " Python identifier ie. a snake case function name and have maximum" + f" {self.MAXIMUM_SCHEMA_NAME_LENGTH} characters. Ideally should contain only small" + " letters, numbers and underscores." + ) # class InvalidDatasetName(ValueError, SchemaException): @@ -21,20 +26,34 @@ def __init__(self, name: str) -> None: # self.name = name # super().__init__(f"{name} is an invalid dataset name. The dataset name must conform to wide range of destinations and ideally should contain only small letters, numbers and underscores. Try {normalized_name} instead as suggested by current naming module.") + class InvalidDatasetName(ValueError, SchemaException): def __init__(self, destination_name: str) -> None: self.destination_name = destination_name - super().__init__(f"Destination {destination_name} does not accept empty datasets. Please pass the dataset name to the destination configuration ie. via dlt pipeline.") + super().__init__( + f"Destination {destination_name} does not accept empty datasets. Please pass the" + " dataset name to the destination configuration ie. via dlt pipeline." 
+ ) class CannotCoerceColumnException(SchemaException): - def __init__(self, table_name: str, column_name: str, from_type: TDataType, to_type: TDataType, coerced_value: Any) -> None: + def __init__( + self, + table_name: str, + column_name: str, + from_type: TDataType, + to_type: TDataType, + coerced_value: Any, + ) -> None: self.table_name = table_name self.column_name = column_name self.from_type = from_type self.to_type = to_type self.coerced_value = coerced_value - super().__init__(f"Cannot coerce type in table {table_name} column {column_name} existing type {from_type} coerced type {to_type} value: {coerced_value}") + super().__init__( + f"Cannot coerce type in table {table_name} column {column_name} existing type" + f" {from_type} coerced type {to_type} value: {coerced_value}" + ) class TablePropertiesConflictException(SchemaException): @@ -43,19 +62,27 @@ def __init__(self, table_name: str, prop_name: str, val1: str, val2: str): self.prop_name = prop_name self.val1 = val1 self.val2 = val2 - super().__init__(f"Cannot merge partial tables for {table_name} due to property {prop_name}: {val1} != {val2}") + super().__init__( + f"Cannot merge partial tables for {table_name} due to property {prop_name}: {val1} !=" + f" {val2}" + ) class ParentTableNotFoundException(SchemaException): def __init__(self, table_name: str, parent_table_name: str, explanation: str = "") -> None: self.table_name = table_name self.parent_table_name = parent_table_name - super().__init__(f"Parent table {parent_table_name} for {table_name} was not found in the schema.{explanation}") + super().__init__( + f"Parent table {parent_table_name} for {table_name} was not found in the" + f" schema.{explanation}" + ) class CannotCoerceNullException(SchemaException): def __init__(self, table_name: str, column_name: str) -> None: - super().__init__(f"Cannot coerce NULL in table {table_name} column {column_name} which is not nullable") + super().__init__( + f"Cannot coerce NULL in table {table_name} column {column_name} which is not nullable" + ) class SchemaCorruptedException(SchemaException): @@ -63,9 +90,14 @@ class SchemaCorruptedException(SchemaException): class SchemaEngineNoUpgradePathException(SchemaException): - def __init__(self, schema_name: str, init_engine: int, from_engine: int, to_engine: int) -> None: + def __init__( + self, schema_name: str, init_engine: int, from_engine: int, to_engine: int + ) -> None: self.schema_name = schema_name self.init_engine = init_engine self.from_engine = from_engine self.to_engine = to_engine - super().__init__(f"No engine upgrade path in schema {schema_name} from {init_engine} to {to_engine}, stopped at {from_engine}") + super().__init__( + f"No engine upgrade path in schema {schema_name} from {init_engine} to {to_engine}," + f" stopped at {from_engine}" + ) diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py index 256afdfe0d..36a9806f08 100644 --- a/dlt/common/schema/schema.py +++ b/dlt/common/schema/schema.py @@ -1,18 +1,48 @@ -import yaml from copy import copy, deepcopy -from typing import ClassVar, Dict, List, Mapping, Optional, Sequence, Tuple, Any, cast -from dlt.common import json +from typing import Any, ClassVar, Dict, List, Mapping, Optional, Sequence, Tuple, cast + +import yaml -from dlt.common.typing import DictStrAny, StrAny, REPattern, SupportsVariant, VARIANT_FIELD_FORMAT, TDataItem +from dlt.common import json +from dlt.common.data_types import TDataType, coerce_value, py_type_to_sc_type from dlt.common.normalizers import TNormalizersConfig, 
explicit_normalizers, import_normalizers -from dlt.common.normalizers.naming import NamingConvention from dlt.common.normalizers.json import DataItemNormalizer, TNormalizedRowIterator +from dlt.common.normalizers.naming import NamingConvention from dlt.common.schema import utils -from dlt.common.data_types import py_type_to_sc_type, coerce_value, TDataType -from dlt.common.schema.typing import (COLUMN_HINTS, SCHEMA_ENGINE_VERSION, LOADS_TABLE_NAME, VERSION_TABLE_NAME, TColumnSchemaBase, TPartialTableSchema, TSchemaSettings, TSimpleRegex, TStoredSchema, - TSchemaTables, TTableSchema, TTableSchemaColumns, TColumnSchema, TColumnProp, TColumnHint, TTypeDetections, TWriteDisposition) -from dlt.common.schema.exceptions import (CannotCoerceColumnException, CannotCoerceNullException, InvalidSchemaName, - ParentTableNotFoundException, SchemaCorruptedException) +from dlt.common.schema.exceptions import ( + CannotCoerceColumnException, + CannotCoerceNullException, + InvalidSchemaName, + ParentTableNotFoundException, + SchemaCorruptedException, +) +from dlt.common.schema.typing import ( + COLUMN_HINTS, + LOADS_TABLE_NAME, + SCHEMA_ENGINE_VERSION, + VERSION_TABLE_NAME, + TColumnHint, + TColumnProp, + TColumnSchema, + TColumnSchemaBase, + TPartialTableSchema, + TSchemaSettings, + TSchemaTables, + TSimpleRegex, + TStoredSchema, + TTableSchema, + TTableSchemaColumns, + TTypeDetections, + TWriteDisposition, +) +from dlt.common.typing import ( + VARIANT_FIELD_FORMAT, + DictStrAny, + REPattern, + StrAny, + SupportsVariant, + TDataItem, +) from dlt.common.validation import validate_dict @@ -29,7 +59,6 @@ class Schema: loads_table_name: str """Normalized name of the loads table""" - _schema_name: str _dlt_tables_prefix: str _stored_version: int # version at load/creation time @@ -37,7 +66,7 @@ class Schema: _imported_version_hash: str # version hash of recently imported schema _schema_description: str # optional schema description _schema_tables: TSchemaTables - _settings: TSchemaSettings # schema settings to hold default hints, preferred types and other settings + _settings: TSchemaSettings # schema settings to hold default hints, preferred types and other settings # list of preferred types: map regex on columns into types _compiled_preferred_types: List[Tuple[REPattern, TDataType]] @@ -72,7 +101,9 @@ def from_dict(cls, d: DictStrAny) -> "Schema": @classmethod def from_stored_schema(cls, stored_schema: TStoredSchema) -> "Schema": # create new instance from dict - self: Schema = cls(stored_schema["name"], normalizers=stored_schema.get("normalizers", None)) + self: Schema = cls( + stored_schema["name"], normalizers=stored_schema.get("normalizers", None) + ) self._from_stored_schema(stored_schema) return self @@ -88,7 +119,7 @@ def to_dict(self, remove_defaults: bool = False) -> TStoredSchema: "name": self._schema_name, "tables": self._schema_tables, "settings": self._settings, - "normalizers": self._normalizers_config + "normalizers": self._normalizers_config, } if self._imported_version_hash and not remove_defaults: stored_schema["imported_version_hash"] = self._imported_version_hash @@ -102,7 +133,9 @@ def to_dict(self, remove_defaults: bool = False) -> TStoredSchema: utils.remove_defaults(stored_schema) return stored_schema - def normalize_data_item(self, item: TDataItem, load_id: str, table_name: str) -> TNormalizedRowIterator: + def normalize_data_item( + self, item: TDataItem, load_id: str, table_name: str + ) -> TNormalizedRowIterator: return self.data_item_normalizer.normalize_data_item(item, load_id, 
table_name) def filter_row(self, table_name: str, row: StrAny) -> StrAny: @@ -119,7 +152,9 @@ def filter_row(self, table_name: str, row: StrAny) -> StrAny: # most of the schema do not use them return row - def _exclude(path: str, excludes: Sequence[REPattern], includes: Sequence[REPattern]) -> bool: + def _exclude( + path: str, excludes: Sequence[REPattern], includes: Sequence[REPattern] + ) -> bool: is_included = False is_excluded = any(exclude.search(path) for exclude in excludes) if is_excluded: @@ -148,7 +183,9 @@ def _exclude(path: str, excludes: Sequence[REPattern], includes: Sequence[REPatt break return row - def coerce_row(self, table_name: str, parent_table: str, row: StrAny) -> Tuple[DictStrAny, TPartialTableSchema]: + def coerce_row( + self, table_name: str, parent_table: str, row: StrAny + ) -> Tuple[DictStrAny, TPartialTableSchema]: # get existing or create a new table updated_table_partial: TPartialTableSchema = None table = self._schema_tables.get(table_name) @@ -163,7 +200,9 @@ def coerce_row(self, table_name: str, parent_table: str, row: StrAny) -> Tuple[D # just check if column is nullable if it exists self._coerce_null_value(table_columns, table_name, col_name) else: - new_col_name, new_col_def, new_v = self._coerce_non_null_value(table_columns, table_name, col_name, v) + new_col_name, new_col_def, new_v = self._coerce_non_null_value( + table_columns, table_name, col_name, v + ) new_row[new_col_name] = new_v if new_col_def: if not updated_table_partial: @@ -181,9 +220,12 @@ def update_schema(self, partial_table: TPartialTableSchema) -> TPartialTableSche if parent_table_name is not None: if self._schema_tables.get(parent_table_name) is None: raise ParentTableNotFoundException( - table_name, parent_table_name, - f" This may be due to misconfigured excludes filter that fully deletes content of the {parent_table_name}. Add includes that will preserve the parent table." - ) + table_name, + parent_table_name, + " This may be due to misconfigured excludes filter that fully deletes content" + f" of the {parent_table_name}. 
Add includes that will preserve the parent" + " table.", + ) table = self._schema_tables.get(table_name) if table is None: # add the whole new table to SchemaTables @@ -213,7 +255,9 @@ def filter_row_with_hint(self, table_name: str, hint_type: TColumnHint, row: Str for column_name in table: if column_name in row: hint_value = table[column_name][column_prop] - if (hint_value and column_prop != "nullable") or (column_prop == "nullable" and not hint_value): + if (hint_value and column_prop != "nullable") or ( + column_prop == "nullable" and not hint_value + ): rv_row[column_name] = row[column_name] except KeyError: for k, v in row.items(): @@ -225,7 +269,12 @@ def filter_row_with_hint(self, table_name: str, hint_type: TColumnHint, row: Str def merge_hints(self, new_hints: Mapping[TColumnHint, Sequence[TSimpleRegex]]) -> None: # validate regexes - validate_dict(TSchemaSettings, {"default_hints": new_hints}, ".", validator_f=utils.simple_regex_validator) + validate_dict( + TSchemaSettings, + {"default_hints": new_hints}, + ".", + validator_f=utils.simple_regex_validator, + ) # prepare hints to be added default_hints = self._settings.setdefault("default_hints", {}) # add `new_hints` to existing hints @@ -249,10 +298,15 @@ def normalize_table_identifiers(self, table: TTableSchema) -> TTableSchema: for c in columns.values(): c["name"] = self.naming.normalize_path(c["name"]) # re-index columns as the name changed - table["columns"] = {c["name"]:c for c in columns.values()} + table["columns"] = {c["name"]: c for c in columns.values()} return table - def get_new_table_columns(self, table_name: str, exiting_columns: TTableSchemaColumns, include_incomplete: bool = False) -> List[TColumnSchema]: + def get_new_table_columns( + self, + table_name: str, + exiting_columns: TTableSchemaColumns, + include_incomplete: bool = False, + ) -> List[TColumnSchema]: """Gets new columns to be added to `exiting_columns` to bring them up to date with `table_name` schema. Optionally includes incomplete columns (without data type)""" diff_c: List[TColumnSchema] = [] s_t = self.get_table_columns(table_name, include_incomplete=include_incomplete) @@ -264,20 +318,33 @@ def get_new_table_columns(self, table_name: str, exiting_columns: TTableSchemaCo def get_table(self, table_name: str) -> TTableSchema: return self._schema_tables[table_name] - def get_table_columns(self, table_name: str, include_incomplete: bool = False) -> TTableSchemaColumns: - """Gets columns of `table_name`. Optionally includes incomplete columns """ + def get_table_columns( + self, table_name: str, include_incomplete: bool = False + ) -> TTableSchemaColumns: + """Gets columns of `table_name`. Optionally includes incomplete columns""" if include_incomplete: return self._schema_tables[table_name]["columns"] else: - return {k:v for k, v in self._schema_tables[table_name]["columns"].items() if utils.is_complete_column(v)} + return { + k: v + for k, v in self._schema_tables[table_name]["columns"].items() + if utils.is_complete_column(v) + } def data_tables(self, include_incomplete: bool = False) -> List[TTableSchema]: """Gets list of all tables, that hold the loaded data. Excludes dlt tables. Excludes incomplete tables (ie. 
without columns)""" - return [t for t in self._schema_tables.values() if not t["name"].startswith(self._dlt_tables_prefix) and (len(t["columns"]) > 0 or include_incomplete)] + return [ + t + for t in self._schema_tables.values() + if not t["name"].startswith(self._dlt_tables_prefix) + and (len(t["columns"]) > 0 or include_incomplete) + ] def dlt_tables(self) -> List[TTableSchema]: """Gets dlt tables""" - return [t for t in self._schema_tables.values() if t["name"].startswith(self._dlt_tables_prefix)] + return [ + t for t in self._schema_tables.values() if t["name"].startswith(self._dlt_tables_prefix) + ] def get_preferred_type(self, col_name: str) -> Optional[TDataType]: return next((m[1] for m in self._compiled_preferred_types if m[0].search(col_name)), None) @@ -349,11 +416,13 @@ def update_normalizers(self) -> None: normalizers["json"] = normalizers["json"] or self._normalizers_config["json"] self._configure_normalizers(normalizers) - def _infer_column(self, k: str, v: Any, data_type: TDataType = None, is_variant: bool = False) -> TColumnSchema: - column_schema = TColumnSchema( + def _infer_column( + self, k: str, v: Any, data_type: TDataType = None, is_variant: bool = False + ) -> TColumnSchema: + column_schema = TColumnSchema( name=k, data_type=data_type or self._infer_column_type(v, k), - nullable=not self._infer_hint("not_null", v, k) + nullable=not self._infer_hint("not_null", v, k), ) for hint in COLUMN_HINTS: column_schema[utils.hint_to_column_prop(hint)] = self._infer_hint(hint, v, k) @@ -362,14 +431,23 @@ def _infer_column(self, k: str, v: Any, data_type: TDataType = None, is_variant: column_schema["variant"] = is_variant return column_schema - def _coerce_null_value(self, table_columns: TTableSchemaColumns, table_name: str, col_name: str) -> None: + def _coerce_null_value( + self, table_columns: TTableSchemaColumns, table_name: str, col_name: str + ) -> None: """Raises when column is explicitly not nullable""" if col_name in table_columns: existing_column = table_columns[col_name] if not existing_column.get("nullable", True): raise CannotCoerceNullException(table_name, col_name) - def _coerce_non_null_value(self, table_columns: TTableSchemaColumns, table_name: str, col_name: str, v: Any, is_variant: bool = False) -> Tuple[str, TColumnSchema, Any]: + def _coerce_non_null_value( + self, + table_columns: TTableSchemaColumns, + table_name: str, + col_name: str, + v: Any, + is_variant: bool = False, + ) -> Tuple[str, TColumnSchema, Any]: new_column: TColumnSchema = None existing_column = table_columns.get(col_name) # if column exist but is incomplete then keep it as new column @@ -378,7 +456,11 @@ def _coerce_non_null_value(self, table_columns: TTableSchemaColumns, table_name: existing_column = None # infer type or get it from existing table - col_type = existing_column["data_type"] if existing_column else self._infer_column_type(v, col_name, skip_preferred=is_variant) + col_type = ( + existing_column["data_type"] + if existing_column + else self._infer_column_type(v, col_name, skip_preferred=is_variant) + ) # get data type of value py_type = py_type_to_sc_type(type(v)) # and coerce type if inference changed the python type @@ -387,12 +469,18 @@ def _coerce_non_null_value(self, table_columns: TTableSchemaColumns, table_name: except (ValueError, SyntaxError): if is_variant: # this is final call: we cannot generate any more auto-variants - raise CannotCoerceColumnException(table_name, col_name, py_type, table_columns[col_name]["data_type"], v) + raise CannotCoerceColumnException( + 
table_name, col_name, py_type, table_columns[col_name]["data_type"], v + ) # otherwise we must create variant extension to the table # pass final=True so no more auto-variants can be created recursively # TODO: generate callback so DLT user can decide what to do - variant_col_name = self.naming.shorten_fragments(col_name, VARIANT_FIELD_FORMAT % py_type) - return self._coerce_non_null_value(table_columns, table_name, variant_col_name, v, is_variant=True) + variant_col_name = self.naming.shorten_fragments( + col_name, VARIANT_FIELD_FORMAT % py_type + ) + return self._coerce_non_null_value( + table_columns, table_name, variant_col_name, v, is_variant=True + ) # if coerced value is variant, then extract variant value # note: checking runtime protocols with isinstance(coerced_v, SupportsVariant): is extremely slow so we check if callable as every variant is callable @@ -400,11 +488,17 @@ def _coerce_non_null_value(self, table_columns: TTableSchemaColumns, table_name: coerced_v = coerced_v() if isinstance(coerced_v, tuple): # variant recovered so call recursively with variant column name and variant value - variant_col_name = self.naming.shorten_fragments(col_name, VARIANT_FIELD_FORMAT % coerced_v[0]) - return self._coerce_non_null_value(table_columns, table_name, variant_col_name, coerced_v[1], is_variant=True) + variant_col_name = self.naming.shorten_fragments( + col_name, VARIANT_FIELD_FORMAT % coerced_v[0] + ) + return self._coerce_non_null_value( + table_columns, table_name, variant_col_name, coerced_v[1], is_variant=True + ) if not existing_column: - inferred_column = self._infer_column(col_name, v, data_type=col_type, is_variant=is_variant) + inferred_column = self._infer_column( + col_name, v, data_type=col_type, is_variant=is_variant + ) # if there's partial new_column then merge it with inferred column if new_column: new_column = utils.merge_columns(new_column, inferred_column, merge_defaults=True) @@ -433,8 +527,12 @@ def _infer_hint(self, hint_type: TColumnHint, _: Any, col_name: str) -> bool: return False def _add_standard_tables(self) -> None: - self._schema_tables[self.version_table_name] = self.normalize_table_identifiers(utils.version_table()) - self._schema_tables[self.loads_table_name] = self.normalize_table_identifiers(utils.load_table()) + self._schema_tables[self.version_table_name] = self.normalize_table_identifiers( + utils.version_table() + ) + self._schema_tables[self.loads_table_name] = self.normalize_table_identifiers( + utils.load_table() + ) def _add_standard_hints(self) -> None: default_hints = utils.standard_hints() @@ -446,14 +544,16 @@ def _add_standard_hints(self) -> None: def _configure_normalizers(self, normalizers: TNormalizersConfig) -> None: # import desired modules - self._normalizers_config, naming_module, item_normalizer_class = import_normalizers(normalizers) + self._normalizers_config, naming_module, item_normalizer_class = import_normalizers( + normalizers + ) # print(f"{self.name}: {type(self.naming)} {type(naming_module)}") if self.naming and type(self.naming) is not type(naming_module): self.naming = naming_module for table in self._schema_tables.values(): self.normalize_table_identifiers(table) # re-index the table names - self._schema_tables = {t["name"]:t for t in self._schema_tables.values()} + self._schema_tables = {t["name"]: t for t in self._schema_tables.values()} # name normalization functions self.naming = naming_module @@ -529,9 +629,13 @@ def _compile_settings(self) -> None: for table in self._schema_tables.values(): if "filters" in 
table: if "excludes" in table["filters"]: - self._compiled_excludes[table["name"]] = list(map(utils.compile_simple_regex, table["filters"]["excludes"])) + self._compiled_excludes[table["name"]] = list( + map(utils.compile_simple_regex, table["filters"]["excludes"]) + ) if "includes" in table["filters"]: - self._compiled_includes[table["name"]] = list(map(utils.compile_simple_regex, table["filters"]["includes"])) + self._compiled_includes[table["name"]] = list( + map(utils.compile_simple_regex, table["filters"]["includes"]) + ) # look for auto-detections in settings and then normalizer self._type_detections = self._settings.get("detections") or self._normalizers_config.get("detections") or [] # type: ignore diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index 42a5a1771d..d4e3fed9ee 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -1,4 +1,18 @@ -from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Set, Type, TypedDict, NewType, Union, get_args +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + NewType, + Optional, + Sequence, + Set, + Type, + TypedDict, + Union, + get_args, +) from dlt.common.data_types import TDataType from dlt.common.normalizers.typing import TNormalizersConfig @@ -10,21 +24,57 @@ VERSION_TABLE_NAME = "_dlt_version" LOADS_TABLE_NAME = "_dlt_loads" -TColumnHint = Literal["not_null", "partition", "cluster", "primary_key", "foreign_key", "sort", "unique", "root_key", "merge_key"] -TColumnProp = Literal["name", "data_type", "nullable", "partition", "cluster", "primary_key", "foreign_key", "sort", "unique", "merge_key", "root_key"] +TColumnHint = Literal[ + "not_null", + "partition", + "cluster", + "primary_key", + "foreign_key", + "sort", + "unique", + "root_key", + "merge_key", +] +TColumnProp = Literal[ + "name", + "data_type", + "nullable", + "partition", + "cluster", + "primary_key", + "foreign_key", + "sort", + "unique", + "merge_key", + "root_key", +] TWriteDisposition = Literal["skip", "append", "replace", "merge"] -TTypeDetections = Literal["timestamp", "iso_timestamp", "large_integer", "hexbytes_to_text", "wei_to_double"] +TTypeDetections = Literal[ + "timestamp", "iso_timestamp", "large_integer", "hexbytes_to_text", "wei_to_double" +] TTypeDetectionFunc = Callable[[Type[Any], Any], Optional[TDataType]] TColumnNames = Union[str, Sequence[str]] """A string representing a column name or a list of""" COLUMN_PROPS: Set[TColumnProp] = set(get_args(TColumnProp)) -COLUMN_HINTS: Set[TColumnHint] = set(["partition", "cluster", "primary_key", "foreign_key", "sort", "unique", "merge_key", "root_key"]) +COLUMN_HINTS: Set[TColumnHint] = set( + [ + "partition", + "cluster", + "primary_key", + "foreign_key", + "sort", + "unique", + "merge_key", + "root_key", + ] +) WRITE_DISPOSITIONS: Set[TWriteDisposition] = set(get_args(TWriteDisposition)) class TColumnSchemaBase(TypedDict, total=False): """TypedDict that defines basic properties of a column: name, data type and nullable""" + name: Optional[str] data_type: Optional[TDataType] nullable: Optional[bool] @@ -32,6 +82,7 @@ class TColumnSchemaBase(TypedDict, total=False): class TColumnSchema(TColumnSchemaBase, total=False): """TypedDict that defines additional column hints""" + description: Optional[str] partition: Optional[bool] cluster: Optional[bool] @@ -58,6 +109,7 @@ class TRowFilters(TypedDict, total=True): class TTableSchema(TypedDict, total=False): """TypedDict that defines properties of a table""" + name: Optional[str] description: 
Optional[str] write_disposition: Optional[TWriteDisposition] @@ -75,6 +127,7 @@ class TPartialTableSchema(TTableSchema): TSchemaTables = Dict[str, TTableSchema] TSchemaUpdate = Dict[str, List[TPartialTableSchema]] + class TSchemaSettings(TypedDict, total=False): schema_sealed: Optional[bool] detections: Optional[List[TTypeDetections]] @@ -84,6 +137,7 @@ class TSchemaSettings(TypedDict, total=False): class TStoredSchema(TypedDict, total=False): """TypeDict defining the schema representation in storage""" + version: int version_hash: str imported_version_hash: Optional[str] diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index 546cc918e8..2c12fb3b1a 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -1,9 +1,8 @@ -import re import base64 import hashlib - +import re from copy import deepcopy -from typing import Dict, List, Sequence, Tuple, Type, Any, cast, Iterable, Optional +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, cast from dlt.common import json from dlt.common.data_types import TDataType @@ -11,16 +10,39 @@ from dlt.common.normalizers import explicit_normalizers from dlt.common.normalizers.naming import NamingConvention from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCase +from dlt.common.normalizers.utils import import_normalizers +from dlt.common.schema import detections +from dlt.common.schema.exceptions import ( + CannotCoerceColumnException, + InvalidSchemaName, + ParentTableNotFoundException, + SchemaEngineNoUpgradePathException, + SchemaException, + TablePropertiesConflictException, +) +from dlt.common.schema.typing import ( + LOADS_TABLE_NAME, + SCHEMA_ENGINE_VERSION, + SIMPLE_REGEX_PREFIX, + VERSION_TABLE_NAME, + TColumnHint, + TColumnName, + TColumnProp, + TColumnSchema, + TColumnSchemaBase, + TPartialTableSchema, + TSchemaTables, + TSchemaUpdate, + TSimpleRegex, + TStoredSchema, + TTableSchema, + TTableSchemaColumns, + TTypeDetectionFunc, + TTypeDetections, + TWriteDisposition, +) from dlt.common.typing import DictStrAny, REPattern from dlt.common.validation import TCustomValidator, validate_dict, validate_dict_ignoring_xkeys -from dlt.common.schema import detections -from dlt.common.schema.typing import (SCHEMA_ENGINE_VERSION, LOADS_TABLE_NAME, SIMPLE_REGEX_PREFIX, VERSION_TABLE_NAME, TColumnName, TPartialTableSchema, TSchemaTables, TSchemaUpdate, - TSimpleRegex, TStoredSchema, TTableSchema, TTableSchemaColumns, TColumnSchemaBase, TColumnSchema, TColumnProp, - TColumnHint, TTypeDetectionFunc, TTypeDetections, TWriteDisposition) -from dlt.common.schema.exceptions import (CannotCoerceColumnException, ParentTableNotFoundException, SchemaEngineNoUpgradePathException, SchemaException, - TablePropertiesConflictException, InvalidSchemaName) - -from dlt.common.normalizers.utils import import_normalizers RE_NON_ALPHANUMERIC_UNDERSCORE = re.compile(r"[^a-zA-Z\d_]") DEFAULT_WRITE_DISPOSITION: TWriteDisposition = "append" @@ -28,7 +50,11 @@ def is_valid_schema_name(name: str) -> bool: """Schema name must be a valid python identifier and have max len of 64""" - return name is not None and name.isidentifier() and len(name) <= InvalidSchemaName.MAXIMUM_SCHEMA_NAME_LENGTH + return ( + name is not None + and name.isidentifier() + and len(name) <= InvalidSchemaName.MAXIMUM_SCHEMA_NAME_LENGTH + ) def normalize_schema_name(name: str) -> str: @@ -45,8 +71,8 @@ def apply_defaults(stored_schema: TStoredSchema) -> None: if table.get("parent") is None: if table.get("write_disposition") is 
None: table["write_disposition"] = DEFAULT_WRITE_DISPOSITION - if table.get('resource') is None: - table['resource'] = table_name + if table.get("resource") is None: + table["resource"] = table_name # add missing hints to columns for column_name in table["columns"]: # add default hints to tables @@ -61,8 +87,8 @@ def remove_defaults(stored_schema: TStoredSchema) -> TStoredSchema: clean_tables = deepcopy(stored_schema["tables"]) for table_name, t in clean_tables.items(): del t["name"] - if t.get('resource') == table_name: - del t['resource'] + if t.get("resource") == table_name: + del t["resource"] for c in t["columns"].values(): # do not save names del c["name"] @@ -106,10 +132,12 @@ def generate_version_hash(stored_schema: TStoredSchema) -> str: # add column names to hash in order for cn in (t.get("columns") or {}).keys(): h.update(cn.encode("utf-8")) - return base64.b64encode(h.digest()).decode('ascii') + return base64.b64encode(h.digest()).decode("ascii") -def verify_schema_hash(loaded_schema_dict: DictStrAny, verifies_if_not_migrated: bool = False) -> bool: +def verify_schema_hash( + loaded_schema_dict: DictStrAny, verifies_if_not_migrated: bool = False +) -> bool: # generates content hash and compares with existing engine_version: str = loaded_schema_dict.get("engine_version") # if upgrade is needed, the hash cannot be compared @@ -125,16 +153,32 @@ def simple_regex_validator(path: str, pk: str, pv: Any, t: Any) -> bool: # custom validator on type TSimpleRegex if t is TSimpleRegex: if not isinstance(pv, str): - raise DictValidationException(f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__} while str is expected", path, pk, pv) + raise DictValidationException( + f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__} while str" + " is expected", + path, + pk, + pv, + ) if pv.startswith(SIMPLE_REGEX_PREFIX): # check if regex try: re.compile(pv[3:]) except Exception as e: - raise DictValidationException(f"In {path}: field {pk} value {pv[3:]} does not compile as regex: {str(e)}", path, pk, pv) + raise DictValidationException( + f"In {path}: field {pk} value {pv[3:]} does not compile as regex: {str(e)}", + path, + pk, + pv, + ) else: if RE_NON_ALPHANUMERIC_UNDERSCORE.match(pv): - raise DictValidationException(f"In {path}: field {pk} value {pv} looks like a regex, please prefix with re:", path, pk, pv) + raise DictValidationException( + f"In {path}: field {pk} value {pv} looks like a regex, please prefix with re:", + path, + pk, + pv, + ) # we know how to validate that type return True else: @@ -143,16 +187,25 @@ def simple_regex_validator(path: str, pk: str, pv: Any, t: Any) -> bool: def column_name_validator(naming: NamingConvention) -> TCustomValidator: - def validator(path: str, pk: str, pv: Any, t: Any) -> bool: if t is TColumnName: if not isinstance(pv, str): - raise DictValidationException(f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__} while str is expected", path, pk, pv) + raise DictValidationException( + f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__} while" + " str is expected", + path, + pk, + pv, + ) try: if naming.normalize_path(pv) != pv: - raise DictValidationException(f"In {path}: field {pk}: {pv} is not a valid column name", path, pk, pv) + raise DictValidationException( + f"In {path}: field {pk}: {pv} is not a valid column name", path, pk, pv + ) except ValueError: - raise DictValidationException(f"In {path}: field {pk}: {pv} is not a valid column name", path, pk, pv) + raise 
DictValidationException( + f"In {path}: field {pk}: {pv} is not a valid column name", path, pk, pv + ) return True else: return False @@ -174,7 +227,7 @@ def compile_simple_regex(r: TSimpleRegex) -> REPattern: def compile_simple_regexes(r: Iterable[TSimpleRegex]) -> REPattern: """Compile multiple patterns as one""" - pattern = '|'.join(f"({_prepare_simple_regex(p)})" for p in r) + pattern = "|".join(f"({_prepare_simple_regex(p)})" for p in r) if not pattern: # Don't create an empty pattern that matches everything raise ValueError("Cannot create a regex pattern from empty sequence") return re.compile(pattern) @@ -183,10 +236,7 @@ def compile_simple_regexes(r: Iterable[TSimpleRegex]) -> REPattern: def validate_stored_schema(stored_schema: TStoredSchema) -> None: # use lambda to verify only non extra fields validate_dict_ignoring_xkeys( - spec=TStoredSchema, - doc=stored_schema, - path=".", - validator_f=simple_regex_validator + spec=TStoredSchema, doc=stored_schema, path=".", validator_f=simple_regex_validator ) # check child parent relationships for table_name, table in stored_schema["tables"].items(): @@ -210,12 +260,8 @@ def migrate_schema(schema_dict: DictStrAny, from_engine: int, to_engine: int) -> # add default normalizers and root hash propagation current["normalizers"], _, _ = import_normalizers(explicit_normalizers()) current["normalizers"]["json"]["config"] = { - "propagation": { - "root": { - "_dlt_id": "_dlt_root_id" - } - } - } + "propagation": {"root": {"_dlt_id": "_dlt_root_id"}} + } # move settings, convert strings to simple regexes d_h: Dict[TColumnHint, List[TSimpleRegex]] = schema_dict.pop("hints", {}) for h_k, h_l in d_h.items(): @@ -252,8 +298,8 @@ def migrate_filters(group: str, filters: List[str]) -> None: # existing filter were always defined at the root table. find this table and move filters for f in filters: # skip initial ^ - root = f[1:f.find("__")] - path = f[f.find("__") + 2:] + root = f[1 : f.find("__")] + path = f[f.find("__") + 2 :] t = current["tables"].get(root) if t is None: # must add new table to hold filters @@ -284,7 +330,9 @@ def migrate_filters(group: str, filters: List[str]) -> None: schema_dict["engine_version"] = from_engine if from_engine != to_engine: - raise SchemaEngineNoUpgradePathException(schema_dict["name"], schema_dict["engine_version"], from_engine, to_engine) + raise SchemaEngineNoUpgradePathException( + schema_dict["name"], schema_dict["engine_version"], from_engine, to_engine + ) return cast(TStoredSchema, schema_dict) @@ -301,9 +349,9 @@ def add_missing_hints(column: TColumnSchemaBase) -> TColumnSchema: "primary_key": False, "foreign_key": False, "root_key": False, - "merge_key": False + "merge_key": False, }, - **column + **column, } @@ -330,7 +378,9 @@ def compare_complete_columns(a: TColumnSchema, b: TColumnSchema) -> bool: return a["data_type"] == b["data_type"] and a["name"] == b["name"] -def merge_columns(col_a: TColumnSchema, col_b: TColumnSchema, merge_defaults: bool = False) -> TColumnSchema: +def merge_columns( + col_a: TColumnSchema, col_b: TColumnSchema, merge_defaults: bool = False +) -> TColumnSchema: """Merges `col_b` into `col_a`. 
if `merge_defaults` is True, only hints not present in `col_a` will be set.""" # print(f"MERGE ({merge_defaults}) {col_b} into {col_a}") for n, v in col_b.items(): @@ -339,7 +389,9 @@ def merge_columns(col_a: TColumnSchema, col_b: TColumnSchema, merge_defaults: bo return col_a -def diff_tables(tab_a: TTableSchema, tab_b: TPartialTableSchema, ignore_table_name: bool = True) -> TPartialTableSchema: +def diff_tables( + tab_a: TTableSchema, tab_b: TPartialTableSchema, ignore_table_name: bool = True +) -> TPartialTableSchema: """Creates a partial table that contains properties found in `tab_b` that are not present in `tab_a` or that can be updated. Raises SchemaException if tables cannot be merged """ @@ -349,7 +401,9 @@ def diff_tables(tab_a: TTableSchema, tab_b: TPartialTableSchema, ignore_table_na # check if table properties can be merged if tab_a.get("parent") != tab_b.get("parent"): - raise TablePropertiesConflictException(table_name, "parent", tab_a.get("parent"), tab_b.get("parent")) + raise TablePropertiesConflictException( + table_name, "parent", tab_a.get("parent"), tab_b.get("parent") + ) # get new columns, changes in the column data type or other properties are not allowed tab_a_columns = tab_a["columns"] @@ -361,13 +415,18 @@ def diff_tables(tab_a: TTableSchema, tab_b: TPartialTableSchema, ignore_table_na if is_complete_column(col_a) and is_complete_column(col_b): if not compare_complete_columns(tab_a_columns[col_b_name], col_b): # attempt to update to incompatible columns - raise CannotCoerceColumnException(table_name, col_b_name, col_b["data_type"], tab_a_columns[col_b_name]["data_type"], None) + raise CannotCoerceColumnException( + table_name, + col_b_name, + col_b["data_type"], + tab_a_columns[col_b_name]["data_type"], + None, + ) # else: new_columns.append(merge_columns(col_a, col_b)) else: new_columns.append(col_b) - # return partial table containing only name and properties that differ (column, filters etc.) 
partial_table = new_table(table_name, columns=new_columns) partial_table["write_disposition"] = None @@ -378,7 +437,6 @@ def diff_tables(tab_a: TTableSchema, tab_b: TPartialTableSchema, ignore_table_na return partial_table - def compare_tables(tab_a: TTableSchema, tab_b: TTableSchema) -> bool: try: diff_table = diff_tables(tab_a, tab_b, ignore_table_name=False) @@ -398,8 +456,8 @@ def merge_tables(table: TTableSchema, partial_table: TPartialTableSchema) -> TPa partial_w_d = partial_table.get("write_disposition") if partial_w_d: table["write_disposition"] = partial_w_d - if table.get('parent') is None and (resource := partial_table.get('resource')): - table['resource'] = resource + if table.get("parent") is None and (resource := partial_table.get("resource")): + table["resource"] = resource return diff_table @@ -410,10 +468,16 @@ def hint_to_column_prop(h: TColumnHint) -> TColumnProp: return h -def get_columns_names_with_prop(table: TTableSchema, column_prop: TColumnProp, include_incomplete: bool = False) -> List[str]: +def get_columns_names_with_prop( + table: TTableSchema, column_prop: TColumnProp, include_incomplete: bool = False +) -> List[str]: # column_prop: TColumnProp = hint_to_column_prop(hint_type) # default = column_prop != "nullable" # default is true, only for nullable false - return [c["name"] for c in table["columns"].values() if bool(c.get(column_prop, False)) is True and (include_incomplete or is_complete_column(c))] + return [ + c["name"] + for c in table["columns"].values() + if bool(c.get(column_prop, False)) is True and (include_incomplete or is_complete_column(c)) + ] def merge_schema_updates(schema_updates: Sequence[TSchemaUpdate]) -> TSchemaTables: @@ -470,54 +534,40 @@ def _child(t: TTableSchema) -> None: return chain -def group_tables_by_resource(tables: TSchemaTables, pattern: Optional[REPattern] = None) -> Dict[str, List[TTableSchema]]: +def group_tables_by_resource( + tables: TSchemaTables, pattern: Optional[REPattern] = None +) -> Dict[str, List[TTableSchema]]: """Create a dict of resources and their associated tables and descendant tables If `pattern` is supplied, the result is filtered to only resource names matching the pattern. 
""" result: Dict[str, List[TTableSchema]] = {} for table in tables.values(): - resource = table.get('resource') + resource = table.get("resource") if resource and (pattern is None or pattern.match(resource)): resource_tables = result.setdefault(resource, []) - resource_tables.extend(get_child_tables(tables, table['name'])) + resource_tables.extend(get_child_tables(tables, table["name"])) return result def version_table() -> TTableSchema: # NOTE: always add new columns at the end of the table so we have identical layout # after an update of existing tables (always at the end) - table = new_table(VERSION_TABLE_NAME, columns=[ - add_missing_hints({ - "name": "version", - "data_type": "bigint", - "nullable": False, - }), - add_missing_hints({ - "name": "engine_version", - "data_type": "bigint", - "nullable": False - }), - add_missing_hints({ - "name": "inserted_at", - "data_type": "timestamp", - "nullable": False - }), - add_missing_hints({ - "name": "schema_name", - "data_type": "text", - "nullable": False - }), - add_missing_hints({ - "name": "version_hash", - "data_type": "text", - "nullable": False - }), - add_missing_hints({ - "name": "schema", - "data_type": "text", - "nullable": False - }) - ] + table = new_table( + VERSION_TABLE_NAME, + columns=[ + add_missing_hints( + { + "name": "version", + "data_type": "bigint", + "nullable": False, + } + ), + add_missing_hints({"name": "engine_version", "data_type": "bigint", "nullable": False}), + add_missing_hints({"name": "inserted_at", "data_type": "timestamp", "nullable": False}), + add_missing_hints({"name": "schema_name", "data_type": "text", "nullable": False}), + add_missing_hints({"name": "version_hash", "data_type": "text", "nullable": False}), + add_missing_hints({"name": "schema", "data_type": "text", "nullable": False}), + ], ) table["write_disposition"] = "skip" table["description"] = "Created by DLT. Tracks schema updates" @@ -527,33 +577,21 @@ def version_table() -> TTableSchema: def load_table() -> TTableSchema: # NOTE: always add new columns at the end of the table so we have identical layout # after an update of existing tables (always at the end) - table = new_table(LOADS_TABLE_NAME, columns=[ - add_missing_hints({ - "name": "load_id", - "data_type": "text", - "nullable": False - }), - add_missing_hints({ - "name": "schema_name", - "data_type": "text", - "nullable": True - }), - add_missing_hints({ - "name": "status", - "data_type": "bigint", - "nullable": False - }), - add_missing_hints({ - "name": "inserted_at", - "data_type": "timestamp", - "nullable": False - }), - add_missing_hints({ - "name": "schema_version_hash", - "data_type": "text", - "nullable": True, - }), - ] + table = new_table( + LOADS_TABLE_NAME, + columns=[ + add_missing_hints({"name": "load_id", "data_type": "text", "nullable": False}), + add_missing_hints({"name": "schema_name", "data_type": "text", "nullable": True}), + add_missing_hints({"name": "status", "data_type": "bigint", "nullable": False}), + add_missing_hints({"name": "inserted_at", "data_type": "timestamp", "nullable": False}), + add_missing_hints( + { + "name": "schema_version_hash", + "data_type": "text", + "nullable": True, + } + ), + ], ) table["write_disposition"] = "skip" table["description"] = "Created by DLT. 
Tracks completed loads" @@ -566,12 +604,11 @@ def new_table( write_disposition: TWriteDisposition = None, columns: Sequence[TColumnSchema] = None, validate_schema: bool = False, - resource: str = None + resource: str = None, ) -> TTableSchema: - table: TTableSchema = { "name": table_name, - "columns": {} if columns is None else {c["name"]: add_missing_hints(c) for c in columns} + "columns": {} if columns is None else {c["name"]: add_missing_hints(c) for c in columns}, } if parent_table_name: table["parent"] = parent_table_name @@ -591,11 +628,13 @@ def new_table( return table -def new_column(column_name: str, data_type: TDataType = None, nullable: bool = True, validate_schema: bool = False) -> TColumnSchema: - column = add_missing_hints({ - "name": column_name, - "nullable": nullable - }) +def new_column( + column_name: str, + data_type: TDataType = None, + nullable: bool = True, + validate_schema: bool = False, +) -> TColumnSchema: + column = add_missing_hints({"name": column_name, "nullable": nullable}) if data_type: column["data_type"] = data_type if validate_schema: diff --git a/dlt/common/source.py b/dlt/common/source.py index a75c2dd948..249d54b4c5 100644 --- a/dlt/common/source.py +++ b/dlt/common/source.py @@ -10,6 +10,7 @@ class SourceInfo(NamedTuple): """Runtime information on the source/resource""" + SPEC: Type[BaseConfiguration] f: AnyFun module: ModuleType @@ -44,4 +45,4 @@ def _get_source_for_inner_function(f: AnyFun) -> Optional[SourceInfo]: # find source function parts = get_callable_name(f, "__qualname__").split(".") parent_fun = ".".join(parts[:-2]) - return _SOURCES.get(parent_fun) \ No newline at end of file + return _SOURCES.get(parent_fun) diff --git a/dlt/common/storages/__init__.py b/dlt/common/storages/__init__.py index 7b4260e9d5..0ce5497cdf 100644 --- a/dlt/common/storages/__init__.py +++ b/dlt/common/storages/__init__.py @@ -1,8 +1,13 @@ +from .configuration import ( # noqa: F401 + LoadStorageConfiguration, + NormalizeStorageConfiguration, + SchemaStorageConfiguration, + TSchemaFileFormat, +) +from .data_item_storage import DataItemStorage # noqa: F401 from .file_storage import FileStorage # noqa: F401 -from .versioned_storage import VersionedStorage # noqa: F401 -from .schema_storage import SchemaStorage # noqa: F401 from .live_schema_storage import LiveSchemaStorage # noqa: F401 -from .normalize_storage import NormalizeStorage # noqa: F401 from .load_storage import LoadStorage # noqa: F401 -from .data_item_storage import DataItemStorage # noqa: F401 -from .configuration import LoadStorageConfiguration, NormalizeStorageConfiguration, SchemaStorageConfiguration, TSchemaFileFormat # noqa: F401 +from .normalize_storage import NormalizeStorage # noqa: F401 +from .schema_storage import SchemaStorage # noqa: F401 +from .versioned_storage import VersionedStorage # noqa: F401 diff --git a/dlt/common/storages/configuration.py b/dlt/common/storages/configuration.py index 5fe4ce4a7c..0fdaf9bfb8 100644 --- a/dlt/common/storages/configuration.py +++ b/dlt/common/storages/configuration.py @@ -12,10 +12,18 @@ class SchemaStorageConfiguration(BaseConfiguration): import_schema_path: Optional[str] = None # import schema from external location export_schema_path: Optional[str] = None # export schema to external location external_schema_format: TSchemaFileFormat = "yaml" # format in which to expect external schema - external_schema_format_remove_defaults: bool = True # remove default values when exporting schema + external_schema_format_remove_defaults: bool = ( + True # remove 
default values when exporting schema + ) if TYPE_CHECKING: - def __init__(self, schema_volume_path: str = None, import_schema_path: str = None, export_schema_path: str = None) -> None: + + def __init__( + self, + schema_volume_path: str = None, + import_schema_path: str = None, + export_schema_path: str = None, + ) -> None: ... @@ -24,15 +32,23 @@ class NormalizeStorageConfiguration(BaseConfiguration): normalize_volume_path: str = None # path to volume where normalized loader files will be stored if TYPE_CHECKING: + def __init__(self, normalize_volume_path: str = None) -> None: ... @configspec class LoadStorageConfiguration(BaseConfiguration): - load_volume_path: str = None # path to volume where files to be loaded to analytical storage are stored - delete_completed_jobs: bool = False # if set to true the folder with completed jobs will be deleted + load_volume_path: str = ( + None # path to volume where files to be loaded to analytical storage are stored + ) + delete_completed_jobs: bool = ( + False # if set to true the folder with completed jobs will be deleted + ) if TYPE_CHECKING: - def __init__(self, load_volume_path: str = None, delete_completed_jobs: bool = None) -> None: + + def __init__( + self, load_volume_path: str = None, delete_completed_jobs: bool = None + ) -> None: ... diff --git a/dlt/common/storages/data_item_storage.py b/dlt/common/storages/data_item_storage.py index 140549de46..809c68f422 100644 --- a/dlt/common/storages/data_item_storage.py +++ b/dlt/common/storages/data_item_storage.py @@ -1,10 +1,10 @@ -from typing import Dict, Any, List from abc import ABC, abstractmethod +from typing import Any, Dict, List from dlt.common import logger +from dlt.common.data_writers import BufferedDataWriter, TLoaderFileFormat from dlt.common.schema import TTableSchemaColumns from dlt.common.typing import TDataItems -from dlt.common.data_writers import TLoaderFileFormat, BufferedDataWriter class DataItemStorage(ABC): @@ -24,12 +24,21 @@ def get_writer(self, load_id: str, schema_name: str, table_name: str) -> Buffere self.buffered_writers[writer_id] = writer return writer - def write_data_item(self, load_id: str, schema_name: str, table_name: str, item: TDataItems, columns: TTableSchemaColumns) -> None: + def write_data_item( + self, + load_id: str, + schema_name: str, + table_name: str, + item: TDataItems, + columns: TTableSchemaColumns, + ) -> None: writer = self.get_writer(load_id, schema_name, table_name) # write item(s) writer.write_data_item(item, columns) - def write_empty_file(self, load_id: str, schema_name: str, table_name: str, columns: TTableSchemaColumns) -> None: + def write_empty_file( + self, load_id: str, schema_name: str, table_name: str, columns: TTableSchemaColumns + ) -> None: writer = self.get_writer(load_id, schema_name, table_name) writer.write_empty_file(columns) @@ -37,7 +46,10 @@ def close_writers(self, extract_id: str) -> None: # flush and close all files for name, writer in self.buffered_writers.items(): if name.startswith(extract_id): - logger.debug(f"Closing writer for {name} with file {writer._file} and actual name {writer._file_name}") + logger.debug( + f"Closing writer for {name} with file {writer._file} and actual name" + f" {writer._file_name}" + ) writer.close() def closed_files(self) -> List[str]: diff --git a/dlt/common/storages/exceptions.py b/dlt/common/storages/exceptions.py index cab149c22c..3076183f9f 100644 --- a/dlt/common/storages/exceptions.py +++ b/dlt/common/storages/exceptions.py @@ -1,8 +1,9 @@ -import semver from typing import 
Iterable -from dlt.common.exceptions import DltException +import semver + from dlt.common.data_writers import TLoaderFileFormat +from dlt.common.exceptions import DltException class StorageException(DltException): @@ -11,20 +12,36 @@ def __init__(self, msg: str) -> None: class NoMigrationPathException(StorageException): - def __init__(self, storage_path: str, initial_version: semver.VersionInfo, migrated_version: semver.VersionInfo, target_version: semver.VersionInfo) -> None: + def __init__( + self, + storage_path: str, + initial_version: semver.VersionInfo, + migrated_version: semver.VersionInfo, + target_version: semver.VersionInfo, + ) -> None: self.storage_path = storage_path self.initial_version = initial_version self.migrated_version = migrated_version self.target_version = target_version - super().__init__(f"Could not find migration path for {storage_path} from v {initial_version} to {target_version}, stopped at {migrated_version}") + super().__init__( + f"Could not find migration path for {storage_path} from v {initial_version} to" + f" {target_version}, stopped at {migrated_version}" + ) class WrongStorageVersionException(StorageException): - def __init__(self, storage_path: str, initial_version: semver.VersionInfo, target_version: semver.VersionInfo) -> None: + def __init__( + self, + storage_path: str, + initial_version: semver.VersionInfo, + target_version: semver.VersionInfo, + ) -> None: self.storage_path = storage_path self.initial_version = initial_version self.target_version = target_version - super().__init__(f"Expected storage {storage_path} with v {target_version} but found {initial_version}") + super().__init__( + f"Expected storage {storage_path} with v {target_version} but found {initial_version}" + ) class LoadStorageException(StorageException): @@ -32,11 +49,16 @@ class LoadStorageException(StorageException): class JobWithUnsupportedWriterException(LoadStorageException): - def __init__(self, load_id: str, expected_file_formats: Iterable[TLoaderFileFormat], wrong_job: str) -> None: + def __init__( + self, load_id: str, expected_file_formats: Iterable[TLoaderFileFormat], wrong_job: str + ) -> None: self.load_id = load_id self.expected_file_formats = expected_file_formats self.wrong_job = wrong_job - super().__init__(f"Job {wrong_job} for load id {load_id} requires loader file format that is not one of {expected_file_formats}") + super().__init__( + f"Job {wrong_job} for load id {load_id} requires loader file format that is not one of" + f" {expected_file_formats}" + ) class LoadPackageNotFound(LoadStorageException, FileNotFoundError): @@ -51,12 +73,22 @@ class SchemaStorageException(StorageException): class InStorageSchemaModified(SchemaStorageException): def __init__(self, schema_name: str, storage_path: str) -> None: - msg = f"Schema {schema_name} in {storage_path} was externally modified. This is not allowed as that would prevent correct version tracking. Use import/export capabilities of DLT to provide external changes." + msg = ( + f"Schema {schema_name} in {storage_path} was externally modified. This is not allowed" + " as that would prevent correct version tracking. Use import/export capabilities of" + " DLT to provide external changes." 
+ ) super().__init__(msg) class SchemaNotFoundError(SchemaStorageException, FileNotFoundError, KeyError): - def __init__(self, schema_name: str, storage_path: str, import_path: str = None, import_format: str = None) -> None: + def __init__( + self, + schema_name: str, + storage_path: str, + import_path: str = None, + import_format: str = None, + ) -> None: msg = f"Schema {schema_name} in {storage_path} could not be found." if import_path: msg += f"Import from {import_path} and format {import_format} failed." @@ -65,4 +97,7 @@ def __init__(self, schema_name: str, storage_path: str, import_path: str = None, class UnexpectedSchemaName(SchemaStorageException, ValueError): def __init__(self, schema_name: str, storage_path: str, stored_name: str) -> None: - super().__init__(f"A schema file name '{schema_name}' in {storage_path} does not correspond to the name of schema in the file {stored_name}") + super().__init__( + f"A schema file name '{schema_name}' in {storage_path} does not correspond to the name" + f" of schema in the file {stored_name}" + ) diff --git a/dlt/common/storages/file_storage.py b/dlt/common/storages/file_storage.py index 046116c82a..80f51537fe 100644 --- a/dlt/common/storages/file_storage.py +++ b/dlt/common/storages/file_storage.py @@ -1,24 +1,22 @@ +import errno import gzip import os import re +import shutil import stat -import errno import tempfile -import shutil +from typing import IO, Any, List, Optional, cast + import pathvalidate -from typing import IO, Any, Optional, List, cast -from dlt.common.typing import AnyFun +from dlt.common.typing import AnyFun from dlt.common.utils import encoding_for_mode, uniq_id - FILE_COMPONENT_INVALID_CHARACTERS = re.compile(r"[.%{}]") + class FileStorage: - def __init__(self, - storage_path: str, - file_type: str = "t", - makedirs: bool = False) -> None: + def __init__(self, storage_path: str, file_type: str = "t", makedirs: bool = False) -> None: # make it absolute path self.storage_path = os.path.realpath(storage_path) # os.path.join(, '') self.file_type = file_type @@ -31,7 +29,9 @@ def save(self, relative_path: str, data: Any) -> str: @staticmethod def save_atomic(storage_path: str, relative_path: str, data: Any, file_type: str = "t") -> str: mode = "w" + file_type - with tempfile.NamedTemporaryFile(dir=storage_path, mode=mode, delete=False, encoding=encoding_for_mode(mode)) as f: + with tempfile.NamedTemporaryFile( + dir=storage_path, mode=mode, delete=False, encoding=encoding_for_mode(mode) + ) as f: tmp_path = f.name f.write(data) try: @@ -75,7 +75,9 @@ def delete(self, relative_path: str) -> None: else: raise FileNotFoundError(file_path) - def delete_folder(self, relative_path: str, recursively: bool = False, delete_ro: bool = False) -> None: + def delete_folder( + self, relative_path: str, recursively: bool = False, delete_ro: bool = False + ) -> None: folder_path = self.make_full_path(relative_path) if os.path.isdir(folder_path): if recursively: @@ -98,7 +100,9 @@ def open_file(self, relative_path: str, mode: str = "r") -> IO[Any]: def open_temp(self, delete: bool = False, mode: str = "w", file_type: str = None) -> IO[Any]: mode = mode + file_type or self.file_type - return tempfile.NamedTemporaryFile(dir=self.storage_path, mode=mode, delete=delete, encoding=encoding_for_mode(mode)) + return tempfile.NamedTemporaryFile( + dir=self.storage_path, mode=mode, delete=delete, encoding=encoding_for_mode(mode) + ) def has_file(self, relative_path: str) -> bool: return os.path.isfile(self.make_full_path(relative_path)) @@ -119,7 +123,9 
@@ def list_folder_files(self, relative_path: str, to_root: bool = True) -> List[st scan_path = self.make_full_path(relative_path) if to_root: # list files in relative path, returning paths relative to storage root - return [os.path.join(relative_path, e.name) for e in os.scandir(scan_path) if e.is_file()] + return [ + os.path.join(relative_path, e.name) for e in os.scandir(scan_path) if e.is_file() + ] else: # or to the folder return [e.name for e in os.scandir(scan_path) if e.is_file()] @@ -129,7 +135,9 @@ def list_folder_dirs(self, relative_path: str, to_root: bool = True) -> List[str scan_path = self.make_full_path(relative_path) if to_root: # list folders in relative path, returning paths relative to storage root - return [os.path.join(relative_path, e.name) for e in os.scandir(scan_path) if e.is_dir()] + return [ + os.path.join(relative_path, e.name) for e in os.scandir(scan_path) if e.is_dir() + ] else: # or to the folder return [e.name for e in os.scandir(scan_path) if e.is_dir()] @@ -139,10 +147,7 @@ def create_folder(self, relative_path: str, exists_ok: bool = False) -> None: def link_hard(self, from_relative_path: str, to_relative_path: str) -> None: # note: some interesting stuff on links https://lightrun.com/answers/conan-io-conan-research-investigate-symlinks-and-hard-links - os.link( - self.make_full_path(from_relative_path), - self.make_full_path(to_relative_path) - ) + os.link(self.make_full_path(from_relative_path), self.make_full_path(to_relative_path)) def atomic_rename(self, from_relative_path: str, to_relative_path: str) -> None: """Renames a path using os.rename which is atomic on POSIX, Windows and NFS v4. @@ -153,10 +158,7 @@ def atomic_rename(self, from_relative_path: str, to_relative_path: str) -> None: 3. All buckets mapped with FUSE are not atomic """ - os.rename( - self.make_full_path(from_relative_path), - self.make_full_path(to_relative_path) - ) + os.rename(self.make_full_path(from_relative_path), self.make_full_path(to_relative_path)) def rename_tree(self, from_relative_path: str, to_relative_path: str) -> None: """Renames a tree using os.rename if possible making it atomic @@ -197,7 +199,9 @@ def rename_tree_files(self, from_relative_path: str, to_relative_path: str) -> N def atomic_import(self, external_file_path: str, to_folder: str) -> str: """Moves a file at `external_file_path` into the `to_folder` effectively importing file into storage""" - return self.to_relative_path(FileStorage.copy_atomic(external_file_path, self.make_full_path(to_folder))) + return self.to_relative_path( + FileStorage.copy_atomic(external_file_path, self.make_full_path(to_folder)) + ) # file_name = FileStorage.get_file_name_from_file_path(external_path) # os.rename(external_path, os.path.join(self.make_full_path(to_folder), file_name)) @@ -244,7 +248,9 @@ def validate_file_name_component(name: str) -> None: pathvalidate.validate_filename(name, platform="Universal") # component cannot contain "." if FILE_COMPONENT_INVALID_CHARACTERS.search(name): - raise pathvalidate.error.InvalidCharError(description="Component name cannot contain the following characters: . % { }") + raise pathvalidate.error.InvalidCharError( + description="Component name cannot contain the following characters: . 
% { }" + ) @staticmethod def rmtree_del_ro(action: AnyFun, name: str, exc: Any) -> Any: diff --git a/dlt/common/storages/live_schema_storage.py b/dlt/common/storages/live_schema_storage.py index c482d5e7ea..3616f91c0b 100644 --- a/dlt/common/storages/live_schema_storage.py +++ b/dlt/common/storages/live_schema_storage.py @@ -1,14 +1,15 @@ from typing import Dict -from dlt.common.schema.schema import Schema from dlt.common.configuration.accessors import config -from dlt.common.storages.schema_storage import SchemaStorage +from dlt.common.schema.schema import Schema from dlt.common.storages.configuration import SchemaStorageConfiguration +from dlt.common.storages.schema_storage import SchemaStorage class LiveSchemaStorage(SchemaStorage): - - def __init__(self, config: SchemaStorageConfiguration = config.value, makedirs: bool = False) -> None: + def __init__( + self, config: SchemaStorageConfiguration = config.value, makedirs: bool = False + ) -> None: self.live_schemas: Dict[str, Schema] = {} super().__init__(config, makedirs) diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py index 95170ac46c..55872edfb0 100644 --- a/dlt/common/storages/load_storage.py +++ b/dlt/common/storages/load_storage.py @@ -1,29 +1,40 @@ import contextlib -from copy import deepcopy -import os import datetime # noqa: 251 -import humanize +import os +from copy import deepcopy from os.path import join from pathlib import Path +from typing import ( + Dict, + Iterable, + List, + Literal, + NamedTuple, + Optional, + Sequence, + Set, + cast, + get_args, +) + +import humanize from pendulum.datetime import DateTime -from typing import Dict, Iterable, List, NamedTuple, Literal, Optional, Sequence, Set, get_args, cast from dlt.common import json, pendulum from dlt.common.configuration import known_sections -from dlt.common.configuration.inject import with_config -from dlt.common.typing import DictStrAny, StrAny -from dlt.common.storages.file_storage import FileStorage -from dlt.common.data_writers import TLoaderFileFormat, DataWriter from dlt.common.configuration.accessors import config +from dlt.common.configuration.inject import with_config +from dlt.common.data_writers import DataWriter, TLoaderFileFormat from dlt.common.exceptions import TerminalValueError from dlt.common.schema import Schema, TSchemaTables, TTableSchemaColumns from dlt.common.storages.configuration import LoadStorageConfiguration -from dlt.common.storages.versioned_storage import VersionedStorage from dlt.common.storages.data_item_storage import DataItemStorage from dlt.common.storages.exceptions import JobWithUnsupportedWriterException, LoadPackageNotFound +from dlt.common.storages.file_storage import FileStorage +from dlt.common.storages.versioned_storage import VersionedStorage +from dlt.common.typing import DictStrAny, StrAny from dlt.common.utils import flatten_list_or_items - # folders to manage load jobs in a single load package TJobState = Literal["new_jobs", "failed_jobs", "started_jobs", "completed_jobs"] WORKING_FOLDERS = set(get_args(TJobState)) @@ -46,7 +57,9 @@ def parse(file_name: str) -> "ParsedLoadJobFileName": if len(parts) != 4: raise TerminalValueError(parts) - return ParsedLoadJobFileName(parts[0], parts[1], int(parts[2]), cast(TLoaderFileFormat, parts[3])) + return ParsedLoadJobFileName( + parts[0], parts[1], int(parts[2]), cast(TLoaderFileFormat, parts[3]) + ) class LoadJobInfo(NamedTuple): @@ -66,10 +79,22 @@ def asdict(self) -> DictStrAny: return d def asstr(self, verbosity: int = 0) -> str: - 
failed_msg = "The job FAILED TERMINALLY and cannot be restarted." if self.failed_message else "" - elapsed_msg = humanize.precisedelta(pendulum.duration(seconds=self.elapsed)) if self.elapsed else "---" - msg = f"Job: {self.job_file_info.job_id()}, table: {self.job_file_info.table_name} in {self.state}. " - msg += f"File type: {self.job_file_info.file_format}, size: {humanize.naturalsize(self.file_size, binary=True, gnu=True)}. " + failed_msg = ( + "The job FAILED TERMINALLY and cannot be restarted." if self.failed_message else "" + ) + elapsed_msg = ( + humanize.precisedelta(pendulum.duration(seconds=self.elapsed)) + if self.elapsed + else "---" + ) + msg = ( + f"Job: {self.job_file_info.job_id()}, table: {self.job_file_info.table_name} in" + f" {self.state}. " + ) + msg += ( + f"File type: {self.job_file_info.file_format}, size:" + f" {humanize.naturalsize(self.file_size, binary=True, gnu=True)}. " + ) msg += f"Started on: {self.created_at} and completed in {elapsed_msg}." if failed_msg: msg += "\nThe job FAILED TERMINALLY and cannot be restarted." @@ -112,8 +137,16 @@ def asdict(self) -> DictStrAny: return d def asstr(self, verbosity: int = 0) -> str: - completed_msg = f"The package was {self.state.upper()} at {self.completed_at}" if self.completed_at else "The package is being PROCESSED" - msg = f"The package with load id {self.load_id} for schema {self.schema_name} is in {self.state} state. It updated schema for {len(self.schema_update)} tables. {completed_msg}.\n" + completed_msg = ( + f"The package was {self.state.upper()} at {self.completed_at}" + if self.completed_at + else "The package is being PROCESSED" + ) + msg = ( + f"The package with load id {self.load_id} for schema {self.schema_name} is in" + f" {self.state} state. It updated schema for {len(self.schema_update)} tables." 
+ f" {completed_msg}.\n" + ) msg += "Jobs details:\n" msg += "\n".join(job.asstr(verbosity) for job in flatten_list_or_items(iter(self.jobs.values()))) # type: ignore return msg @@ -123,7 +156,6 @@ def __str__(self) -> str: class LoadStorage(DataItemStorage, VersionedStorage): - STORAGE_VERSION = "1.0.0" NORMALIZED_FOLDER = "normalized" # folder within the volume where load packages are stored LOADED_FOLDER = "loaded" # folder to keep the loads that were completely processed @@ -133,10 +165,16 @@ class LoadStorage(DataItemStorage, VersionedStorage): STARTED_JOBS_FOLDER: TJobState = "started_jobs" COMPLETED_JOBS_FOLDER: TJobState = "completed_jobs" - SCHEMA_UPDATES_FILE_NAME = "schema_updates.json" # updates to the tables in schema created by normalizer - APPLIED_SCHEMA_UPDATES_FILE_NAME = "applied_" + "schema_updates.json" # updates applied to the destination + SCHEMA_UPDATES_FILE_NAME = ( # updates to the tables in schema created by normalizer + "schema_updates.json" + ) + APPLIED_SCHEMA_UPDATES_FILE_NAME = ( + "applied_" + "schema_updates.json" + ) # updates applied to the destination SCHEMA_FILE_NAME = "schema.json" # package schema - PACKAGE_COMPLETED_FILE_NAME = "package_completed.json" # completed package marker file, currently only to store data with os.stat + PACKAGE_COMPLETED_FILE_NAME = ( # completed package marker file, currently only to store data with os.stat + "package_completed.json" + ) ALL_SUPPORTED_FILE_FORMATS: Set[TLoaderFileFormat] = set(get_args(TLoaderFileFormat)) @@ -146,7 +184,7 @@ def __init__( is_owner: bool, preferred_file_format: TLoaderFileFormat, supported_file_formats: Iterable[TLoaderFileFormat], - config: LoadStorageConfiguration = config.value + config: LoadStorageConfiguration = config.value, ) -> None: if not LoadStorage.ALL_SUPPORTED_FILE_FORMATS.issuperset(supported_file_formats): raise TerminalValueError(supported_file_formats) @@ -157,7 +195,8 @@ def __init__( super().__init__( preferred_file_format, LoadStorage.STORAGE_VERSION, - is_owner, FileStorage(config.load_volume_path, "t", makedirs=is_owner) + is_owner, + FileStorage(config.load_volume_path, "t", makedirs=is_owner), ) if is_owner: self.initialize_storage() @@ -181,8 +220,19 @@ def _get_data_item_path_template(self, load_id: str, _: str, table_name: str) -> file_name = self.build_job_file_name(table_name, "%s", with_extension=False) return self.storage.make_full_path(join(load_id, LoadStorage.NEW_JOBS_FOLDER, file_name)) - def write_temp_job_file(self, load_id: str, table_name: str, table: TTableSchemaColumns, file_id: str, rows: Sequence[StrAny]) -> str: - file_name = self._get_data_item_path_template(load_id, None, table_name) % file_id + "." + self.loader_file_format + def write_temp_job_file( + self, + load_id: str, + table_name: str, + table: TTableSchemaColumns, + file_id: str, + rows: Sequence[StrAny], + ) -> str: + file_name = ( + self._get_data_item_path_template(load_id, None, table_name) % file_id + + "." 
+ + self.loader_file_format + ) format_spec = DataWriter.data_format_from_file_format(self.loader_file_format) mode = "wb" if format_spec.is_binary_format else "w" with self.storage.open_file(file_name, mode=mode) as f: @@ -206,7 +256,9 @@ def save_temp_schema(self, schema: Schema, load_id: str) -> str: return self.storage.save(join(load_id, LoadStorage.SCHEMA_FILE_NAME), dump) def save_temp_schema_updates(self, load_id: str, schema_update: TSchemaTables) -> None: - with self.storage.open_file(join(load_id, LoadStorage.SCHEMA_UPDATES_FILE_NAME), mode="wb") as f: + with self.storage.open_file( + join(load_id, LoadStorage.SCHEMA_UPDATES_FILE_NAME), mode="wb" + ) as f: json.dump(schema_update, f) def commit_temp_load_package(self, load_id: str) -> None: @@ -223,36 +275,57 @@ def list_completed_packages(self) -> Sequence[str]: return sorted(loads) def list_new_jobs(self, load_id: str) -> Sequence[str]: - new_jobs = self.storage.list_folder_files(self._get_job_folder_path(load_id, LoadStorage.NEW_JOBS_FOLDER)) + new_jobs = self.storage.list_folder_files( + self._get_job_folder_path(load_id, LoadStorage.NEW_JOBS_FOLDER) + ) # make sure all jobs have supported writers - wrong_job = next((j for j in new_jobs if LoadStorage.parse_job_file_name(j).file_format not in self.supported_file_formats), None) + wrong_job = next( + ( + j + for j in new_jobs + if LoadStorage.parse_job_file_name(j).file_format not in self.supported_file_formats + ), + None, + ) if wrong_job is not None: raise JobWithUnsupportedWriterException(load_id, self.supported_file_formats, wrong_job) return new_jobs def list_started_jobs(self, load_id: str) -> Sequence[str]: - return self.storage.list_folder_files(self._get_job_folder_path(load_id, LoadStorage.STARTED_JOBS_FOLDER)) + return self.storage.list_folder_files( + self._get_job_folder_path(load_id, LoadStorage.STARTED_JOBS_FOLDER) + ) def list_failed_jobs(self, load_id: str) -> Sequence[str]: - return self.storage.list_folder_files(self._get_job_folder_path(load_id, LoadStorage.FAILED_JOBS_FOLDER)) + return self.storage.list_folder_files( + self._get_job_folder_path(load_id, LoadStorage.FAILED_JOBS_FOLDER) + ) def list_jobs_for_table(self, load_id: str, table_name: str) -> Sequence[LoadJobInfo]: info = self.get_load_package_info(load_id) return [job for job in flatten_list_or_items(iter(info.jobs.values())) if job.job_file_info.table_name == table_name] # type: ignore def list_completed_failed_jobs(self, load_id: str) -> Sequence[str]: - return self.storage.list_folder_files(self._get_job_folder_completed_path(load_id, LoadStorage.FAILED_JOBS_FOLDER)) + return self.storage.list_folder_files( + self._get_job_folder_completed_path(load_id, LoadStorage.FAILED_JOBS_FOLDER) + ) def list_failed_jobs_in_completed_package(self, load_id: str) -> Sequence[LoadJobInfo]: """List all failed jobs and associated error messages for a completed load package with `load_id`""" failed_jobs: List[LoadJobInfo] = [] package_path = self.get_completed_package_path(load_id) package_created_at = pendulum.from_timestamp( - os.path.getmtime(self.storage.make_full_path(join(package_path, LoadStorage.PACKAGE_COMPLETED_FILE_NAME))) + os.path.getmtime( + self.storage.make_full_path( + join(package_path, LoadStorage.PACKAGE_COMPLETED_FILE_NAME) + ) + ) ) for file in self.list_completed_failed_jobs(load_id): if not file.endswith(".exception"): - failed_jobs.append(self._read_job_file_info("failed_jobs", file, package_created_at)) + failed_jobs.append( + self._read_job_file_info("failed_jobs", file, 
package_created_at) + ) return failed_jobs def get_load_package_info(self, load_id: str) -> LoadPackageInfo: @@ -266,10 +339,14 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo: package_path = self.get_completed_package_path(load_id) if not self.storage.has_folder(package_path): raise LoadPackageNotFound(load_id) - completed_file_path = self.storage.make_full_path(join(package_path, LoadStorage.PACKAGE_COMPLETED_FILE_NAME)) + completed_file_path = self.storage.make_full_path( + join(package_path, LoadStorage.PACKAGE_COMPLETED_FILE_NAME) + ) package_created_at = pendulum.from_timestamp(os.path.getmtime(completed_file_path)) package_state = self.storage.load(completed_file_path) - applied_schema_update_file = join(package_path, LoadStorage.APPLIED_SCHEMA_UPDATES_FILE_NAME) + applied_schema_update_file = join( + package_path, LoadStorage.APPLIED_SCHEMA_UPDATES_FILE_NAME + ) if self.storage.has_file(applied_schema_update_file): applied_update = json.loads(self.storage.load(applied_schema_update_file)) schema = self._load_schema(join(package_path, LoadStorage.SCHEMA_FILE_NAME)) @@ -284,7 +361,15 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo: jobs.append(self._read_job_file_info(state, file, package_created_at)) all_jobs[state] = jobs - return LoadPackageInfo(load_id, self.storage.make_full_path(package_path), package_state, schema.name, applied_update, package_created_at, all_jobs) + return LoadPackageInfo( + load_id, + self.storage.make_full_path(package_path), + package_state, + schema.name, + applied_update, + package_created_at, + all_jobs, + ) def begin_schema_update(self, load_id: str) -> Optional[TSchemaTables]: package_path = self.get_package_path(load_id) @@ -307,32 +392,53 @@ def commit_schema_update(self, load_id: str, applied_update: TSchemaTables) -> N # save applied update self.storage.save(processed_schema_update_file, json.dumps(applied_update)) - def add_new_job(self, load_id: str, job_file_path: str, job_state: TJobState = "new_jobs") -> None: + def add_new_job( + self, load_id: str, job_file_path: str, job_state: TJobState = "new_jobs" + ) -> None: """Adds new job by moving the `job_file_path` into `new_jobs` of package `load_id`""" self.storage.atomic_import(job_file_path, self._get_job_folder_path(load_id, job_state)) def start_job(self, load_id: str, file_name: str) -> str: - return self._move_job(load_id, LoadStorage.NEW_JOBS_FOLDER, LoadStorage.STARTED_JOBS_FOLDER, file_name) + return self._move_job( + load_id, LoadStorage.NEW_JOBS_FOLDER, LoadStorage.STARTED_JOBS_FOLDER, file_name + ) def fail_job(self, load_id: str, file_name: str, failed_message: Optional[str]) -> str: # save the exception to failed jobs if failed_message: self.storage.save( - self._get_job_file_path(load_id, LoadStorage.FAILED_JOBS_FOLDER, file_name + ".exception"), - failed_message + self._get_job_file_path( + load_id, LoadStorage.FAILED_JOBS_FOLDER, file_name + ".exception" + ), + failed_message, ) # move to failed jobs - return self._move_job(load_id, LoadStorage.STARTED_JOBS_FOLDER, LoadStorage.FAILED_JOBS_FOLDER, file_name) + return self._move_job( + load_id, LoadStorage.STARTED_JOBS_FOLDER, LoadStorage.FAILED_JOBS_FOLDER, file_name + ) def retry_job(self, load_id: str, file_name: str) -> str: # when retrying job we must increase the retry count source_fn = ParsedLoadJobFileName.parse(file_name) - dest_fn = ParsedLoadJobFileName(source_fn.table_name, source_fn.file_id, source_fn.retry_count + 1, source_fn.file_format) + dest_fn = ParsedLoadJobFileName( + 
source_fn.table_name, + source_fn.file_id, + source_fn.retry_count + 1, + source_fn.file_format, + ) # move it directly to new file name - return self._move_job(load_id, LoadStorage.STARTED_JOBS_FOLDER, LoadStorage.NEW_JOBS_FOLDER, file_name, dest_fn.job_id()) + return self._move_job( + load_id, + LoadStorage.STARTED_JOBS_FOLDER, + LoadStorage.NEW_JOBS_FOLDER, + file_name, + dest_fn.job_id(), + ) def complete_job(self, load_id: str, file_name: str) -> str: - return self._move_job(load_id, LoadStorage.STARTED_JOBS_FOLDER, LoadStorage.COMPLETED_JOBS_FOLDER, file_name) + return self._move_job( + load_id, LoadStorage.STARTED_JOBS_FOLDER, LoadStorage.COMPLETED_JOBS_FOLDER, file_name + ) def complete_load_package(self, load_id: str, aborted: bool) -> None: load_path = self.get_package_path(load_id) @@ -341,7 +447,8 @@ def complete_load_package(self, load_id: str, aborted: bool) -> None: if self.config.delete_completed_jobs and not has_failed_jobs: self.storage.delete_folder( self._get_job_folder_path(load_id, LoadStorage.COMPLETED_JOBS_FOLDER), - recursively=True) + recursively=True, + ) # save marker file completed_state: TLoadPackageState = "aborted" if aborted else "loaded" self.storage.save(join(load_path, LoadStorage.PACKAGE_COMPLETED_FILE_NAME), completed_state) @@ -376,7 +483,14 @@ def _load_schema(self, schema_path: str) -> Schema: stored_schema: DictStrAny = json.loads(self.storage.load(schema_path)) return Schema.from_dict(stored_schema) - def _move_job(self, load_id: str, source_folder: TJobState, dest_folder: TJobState, file_name: str, new_file_name: str = None) -> str: + def _move_job( + self, + load_id: str, + source_folder: TJobState, + dest_folder: TJobState, + file_name: str, + new_file_name: str = None, + ) -> str: # ensure we move file names, not paths assert file_name == FileStorage.get_file_name_from_file_path(file_name) load_path = self.get_package_path(load_id) @@ -408,10 +522,17 @@ def _read_job_file_info(self, state: TJobState, file: str, now: DateTime = None) pendulum.from_timestamp(st.st_mtime), self.job_elapsed_time_seconds(full_path, now.timestamp() if now else None), self.parse_job_file_name(file), - failed_message + failed_message, ) - def build_job_file_name(self, table_name: str, file_id: str, retry_count: int = 0, validate_components: bool = True, with_extension: bool = True) -> str: + def build_job_file_name( + self, + table_name: str, + file_id: str, + retry_count: int = 0, + validate_components: bool = True, + with_extension: bool = True, + ) -> str: if validate_components: FileStorage.validate_file_name_component(table_name) # FileStorage.validate_file_name_component(file_id) diff --git a/dlt/common/storages/normalize_storage.py b/dlt/common/storages/normalize_storage.py index 45f541f5ec..4d711b36ed 100644 --- a/dlt/common/storages/normalize_storage.py +++ b/dlt/common/storages/normalize_storage.py @@ -1,13 +1,14 @@ -from typing import ClassVar, Sequence, NamedTuple from itertools import groupby from pathlib import Path +from typing import ClassVar, NamedTuple, Sequence -from dlt.common.configuration import with_config, known_sections +from dlt.common.configuration import known_sections, with_config from dlt.common.configuration.accessors import config -from dlt.common.storages.file_storage import FileStorage from dlt.common.storages.configuration import NormalizeStorageConfiguration +from dlt.common.storages.file_storage import FileStorage from dlt.common.storages.versioned_storage import VersionedStorage + class TParsedNormalizeFileName(NamedTuple): 
schema_name: str table_name: str @@ -15,13 +16,20 @@ class TParsedNormalizeFileName(NamedTuple): class NormalizeStorage(VersionedStorage): - STORAGE_VERSION: ClassVar[str] = "1.0.0" - EXTRACTED_FOLDER: ClassVar[str] = "extracted" # folder within the volume where extracted files to be normalized are stored + EXTRACTED_FOLDER: ClassVar[str] = ( + "extracted" # folder within the volume where extracted files to be normalized are stored + ) @with_config(spec=NormalizeStorageConfiguration, sections=(known_sections.NORMALIZE,)) - def __init__(self, is_owner: bool, config: NormalizeStorageConfiguration = config.value) -> None: - super().__init__(NormalizeStorage.STORAGE_VERSION, is_owner, FileStorage(config.normalize_volume_path, "t", makedirs=is_owner)) + def __init__( + self, is_owner: bool, config: NormalizeStorageConfiguration = config.value + ) -> None: + super().__init__( + NormalizeStorage.STORAGE_VERSION, + is_owner, + FileStorage(config.normalize_volume_path, "t", makedirs=is_owner), + ) self.config = config if is_owner: self.initialize_storage() diff --git a/dlt/common/storages/schema_storage.py b/dlt/common/storages/schema_storage.py index a9fee71531..0c2a5e74eb 100644 --- a/dlt/common/storages/schema_storage.py +++ b/dlt/common/storages/schema_storage.py @@ -1,24 +1,33 @@ -import yaml from typing import Iterator, List, Mapping, Tuple +import yaml + from dlt.common import json, logger from dlt.common.configuration import with_config from dlt.common.configuration.accessors import config -from dlt.common.storages.configuration import SchemaStorageConfiguration, TSchemaFileFormat, SchemaFileExtensions -from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import Schema, verify_schema_hash +from dlt.common.storages.configuration import ( + SchemaFileExtensions, + SchemaStorageConfiguration, + TSchemaFileFormat, +) +from dlt.common.storages.exceptions import ( + InStorageSchemaModified, + SchemaNotFoundError, + UnexpectedSchemaName, +) +from dlt.common.storages.file_storage import FileStorage from dlt.common.typing import DictStrAny -from dlt.common.storages.exceptions import InStorageSchemaModified, SchemaNotFoundError, UnexpectedSchemaName - class SchemaStorage(Mapping[str, Schema]): - SCHEMA_FILE_NAME = "schema.%s" NAMED_SCHEMA_FILE_PATTERN = f"%s.{SCHEMA_FILE_NAME}" @with_config(spec=SchemaStorageConfiguration, sections=("schema",)) - def __init__(self, config: SchemaStorageConfiguration = config.value, makedirs: bool = False) -> None: + def __init__( + self, config: SchemaStorageConfiguration = config.value, makedirs: bool = False + ) -> None: self.config = config self.storage = FileStorage(config.schema_volume_path, makedirs=makedirs) @@ -97,7 +106,11 @@ def _maybe_import_schema(self, name: str, storage_schema: DictStrAny = None) -> # if schema was imported, overwrite storage schema rv_schema._imported_version_hash = rv_schema.version_hash self._save_schema(rv_schema) - logger.info(f"Schema {name} not present in {self.storage.storage_path} and got imported with version {rv_schema.stored_version} and imported hash {rv_schema._imported_version_hash}") + logger.info( + f"Schema {name} not present in {self.storage.storage_path} and got imported" + f" with version {rv_schema.stored_version} and imported hash" + f" {rv_schema._imported_version_hash}" + ) else: # import schema when imported schema was modified from the last import sc = Schema.from_dict(storage_schema) @@ -108,14 +121,23 @@ def _maybe_import_schema(self, name: str, storage_schema: DictStrAny = None) 
-> rv_schema._imported_version_hash = rv_schema.version_hash # if schema was imported, overwrite storage schema self._save_schema(rv_schema) - logger.info(f"Schema {name} was present in {self.storage.storage_path} but is overwritten with imported schema version {rv_schema.stored_version} and imported hash {rv_schema._imported_version_hash}") + logger.info( + f"Schema {name} was present in {self.storage.storage_path} but is" + f" overwritten with imported schema version {rv_schema.stored_version} and" + f" imported hash {rv_schema._imported_version_hash}" + ) else: # use storage schema as nothing changed rv_schema = sc except FileNotFoundError: # no schema to import -> skip silently and return the original if storage_schema is None: - raise SchemaNotFoundError(name, self.config.schema_volume_path, self.config.import_schema_path, self.config.external_schema_format) + raise SchemaNotFoundError( + name, + self.config.schema_volume_path, + self.config.import_schema_path, + self.config.external_schema_format, + ) rv_schema = Schema.from_dict(storage_schema) assert rv_schema is not None @@ -124,20 +146,29 @@ def _maybe_import_schema(self, name: str, storage_schema: DictStrAny = None) -> def _load_import_schema(self, name: str) -> DictStrAny: import_storage = FileStorage(self.config.import_schema_path, makedirs=False) schema_file = self._file_name_in_store(name, self.config.external_schema_format) - return self._parse_schema_str(import_storage.load(schema_file), self.config.external_schema_format) + return self._parse_schema_str( + import_storage.load(schema_file), self.config.external_schema_format + ) def _export_schema(self, schema: Schema, export_path: str) -> None: if self.config.external_schema_format == "json": - exported_schema_s = schema.to_pretty_json(remove_defaults=self.config.external_schema_format_remove_defaults) + exported_schema_s = schema.to_pretty_json( + remove_defaults=self.config.external_schema_format_remove_defaults + ) elif self.config.external_schema_format == "yaml": - exported_schema_s = schema.to_pretty_yaml(remove_defaults=self.config.external_schema_format_remove_defaults) + exported_schema_s = schema.to_pretty_yaml( + remove_defaults=self.config.external_schema_format_remove_defaults + ) else: raise ValueError(self.config.external_schema_format) export_storage = FileStorage(export_path, makedirs=True) schema_file = self._file_name_in_store(schema.name, self.config.external_schema_format) export_storage.save(schema_file, exported_schema_s) - logger.info(f"Schema {schema.name} exported to {export_path} with version {schema.stored_version} as {self.config.external_schema_format}") + logger.info( + f"Schema {schema.name} exported to {export_path} with version" + f" {schema.stored_version} as {self.config.external_schema_format}" + ) def _save_schema(self, schema: Schema) -> str: # save a schema to schema store @@ -145,7 +176,9 @@ def _save_schema(self, schema: Schema) -> str: return self.storage.save(schema_file, schema.to_pretty_json(remove_defaults=False)) @staticmethod - def load_schema_file(path: str, name: str, extensions: Tuple[TSchemaFileFormat, ...]=SchemaFileExtensions) -> Schema: + def load_schema_file( + path: str, name: str, extensions: Tuple[TSchemaFileFormat, ...] 
= SchemaFileExtensions + ) -> Schema: storage = FileStorage(path) for extension in extensions: file = SchemaStorage._file_name_in_store(name, extension) diff --git a/dlt/common/storages/transactional_file.py b/dlt/common/storages/transactional_file.py index f059a3a0c7..2cf6b74676 100644 --- a/dlt/common/storages/transactional_file.py +++ b/dlt/common/storages/transactional_file.py @@ -5,18 +5,19 @@ cloud storage. The lock can be used to operate on entire directories by creating a lock file that resolves to an agreed upon path across processes. """ +import posixpath import random import string import time import typing as t -from pathlib import Path -import posixpath from contextlib import contextmanager -from dlt.common.pendulum import pendulum, timedelta +from pathlib import Path from threading import Timer import fsspec +from dlt.common.pendulum import pendulum, timedelta + def lock_id(k: int = 4) -> str: """Generate a time based random id. @@ -33,6 +34,7 @@ def lock_id(k: int = 4) -> str: class Heartbeat(Timer): """A thread designed to periodically execute a fn.""" + daemon = True def run(self) -> None: @@ -72,7 +74,9 @@ def __init__(self, path: str, fs: fsspec.AbstractFileSystem) -> None: parsed_path = Path(path) if not parsed_path.is_absolute(): - raise ValueError(f"{path} is not absolute. Please pass only absolute paths to TransactionalFile") + raise ValueError( + f"{path} is not absolute. Please pass only absolute paths to TransactionalFile" + ) self.path = path if proto == "file": # standardize path separator to POSIX. fsspec always uses POSIX. Windows may use either. @@ -114,7 +118,7 @@ def _sync_locks(self) -> t.List[str]: # Purge stale locks mtime = self.extract_mtime(lock) if now - mtime > timedelta(seconds=TransactionalFile.LOCK_TTL_SECONDS): - try: # Janitors can race, so we ignore errors + try: # Janitors can race, so we ignore errors self._fs.rm_file(name) except OSError: pass @@ -122,7 +126,10 @@ def _sync_locks(self) -> t.List[str]: # The name is timestamp + random suffix and is time sortable output.append(name) if not output: - raise RuntimeError(f"When syncing locks for path {self.path} and lock {self.lock_path} no lock file was found") + raise RuntimeError( + f"When syncing locks for path {self.path} and lock {self.lock_path} no lock file" + " was found" + ) return output def read(self) -> t.Optional[bytes]: @@ -148,7 +155,9 @@ def rollback(self) -> None: elif self._fs.isfile(self.path): self._fs.rm_file(self.path) - def acquire_lock(self, blocking: bool = True, timeout: float = -1, jitter_mean: float = 0) -> bool: + def acquire_lock( + self, blocking: bool = True, timeout: float = -1, jitter_mean: float = 0 + ) -> bool: """Acquires a lock on a path. Mimics the stdlib's `threading.Lock` interface. Acquire a lock, blocking or non-blocking. 
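Note (not part of this patch): the hunks above only reformat the TransactionalFile lock, so a minimal usage sketch may help readers of the diff. It is built solely from the pieces visible in this section -- the (path, fs) constructor, acquire_lock() mimicking threading.Lock, and read() -- and the local fsspec filesystem and /tmp path are assumptions for illustration; the write/release side of the API is not shown here and is therefore omitted.

import fsspec

from dlt.common.storages.transactional_file import TransactionalFile

# any fsspec.AbstractFileSystem should work; a local filesystem is used purely for illustration
fs = fsspec.filesystem("file")

# the constructor requires an absolute path (it raises ValueError otherwise, as shown above)
tf = TransactionalFile("/tmp/state.json", fs)

# block for up to 30 seconds; lock file names are time-sortable ids, and locks older than
# TransactionalFile.LOCK_TTL_SECONDS are purged while syncing (see _sync_locks above)
if tf.acquire_lock(blocking=True, timeout=30):
    current = tf.read()  # Optional[bytes]
    # ... operate on the file while holding the lock ...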
diff --git a/dlt/common/storages/versioned_storage.py b/dlt/common/storages/versioned_storage.py index c61156c540..0e5f5ceb68 100644 --- a/dlt/common/storages/versioned_storage.py +++ b/dlt/common/storages/versioned_storage.py @@ -1,11 +1,10 @@ import semver -from dlt.common.storages.file_storage import FileStorage from dlt.common.storages.exceptions import NoMigrationPathException, WrongStorageVersionException +from dlt.common.storages.file_storage import FileStorage class VersionedStorage: - VERSION_FILE = ".version" def __init__(self, version: semver.VersionInfo, is_owner: bool, storage: FileStorage) -> None: @@ -16,24 +15,34 @@ def __init__(self, version: semver.VersionInfo, is_owner: bool, storage: FileSto if existing_version != version: if existing_version > version: # version cannot be downgraded - raise NoMigrationPathException(storage.storage_path, existing_version, existing_version, version) + raise NoMigrationPathException( + storage.storage_path, existing_version, existing_version, version + ) if is_owner: # only owner can migrate storage self.migrate_storage(existing_version, version) # storage should be migrated to desired version migrated_version = self._load_version() if version != migrated_version: - raise NoMigrationPathException(storage.storage_path, existing_version, migrated_version, version) + raise NoMigrationPathException( + storage.storage_path, existing_version, migrated_version, version + ) else: # we cannot use storage and we must wait for owner to upgrade it - raise WrongStorageVersionException(storage.storage_path, existing_version, version) + raise WrongStorageVersionException( + storage.storage_path, existing_version, version + ) else: if is_owner: self._save_version(version) else: - raise WrongStorageVersionException(storage.storage_path, semver.VersionInfo.parse("0.0.0"), version) + raise WrongStorageVersionException( + storage.storage_path, semver.VersionInfo.parse("0.0.0"), version + ) - def migrate_storage(self, from_version: semver.VersionInfo, to_version: semver.VersionInfo) -> None: + def migrate_storage( + self, from_version: semver.VersionInfo, to_version: semver.VersionInfo + ) -> None: # migration example: # # semver lib supports comparing both to string and other semvers # if from_version == "1.0.0" and from_version < to_version: diff --git a/dlt/common/time.py b/dlt/common/time.py index 6195716710..f35dda2a5a 100644 --- a/dlt/common/time.py +++ b/dlt/common/time.py @@ -1,22 +1,28 @@ import contextlib -from typing import Any, Optional, Union, overload # noqa import datetime # noqa: I251 +from typing import Any, Optional, Union, overload # noqa -from dlt.common.pendulum import pendulum, timedelta -from dlt.common.typing import TimedeltaSeconds, TAnyDateTime -from pendulum.parsing import parse_iso8601, _parse_common as parse_datetime_common +from pendulum.parsing import _parse_common as parse_datetime_common +from pendulum.parsing import parse_iso8601 from pendulum.tz import UTC +from dlt.common.pendulum import pendulum, timedelta +from dlt.common.typing import TAnyDateTime, TimedeltaSeconds + PAST_TIMESTAMP: float = 0.0 FUTURE_TIMESTAMP: float = 9999999999.0 DAY_DURATION_SEC: float = 24 * 60 * 60.0 -def timestamp_within(timestamp: float, min_exclusive: Optional[float], max_inclusive: Optional[float]) -> bool: +def timestamp_within( + timestamp: float, min_exclusive: Optional[float], max_inclusive: Optional[float] +) -> bool: """ check if timestamp within range uniformly treating none and range inclusiveness """ - return timestamp > (min_exclusive 
or PAST_TIMESTAMP) and timestamp <= (max_inclusive or FUTURE_TIMESTAMP) + return timestamp > (min_exclusive or PAST_TIMESTAMP) and timestamp <= ( + max_inclusive or FUTURE_TIMESTAMP + ) def timestamp_before(timestamp: float, max_inclusive: Optional[float]) -> bool: @@ -48,7 +54,7 @@ def parse_iso_like_datetime(value: Any) -> Union[pendulum.DateTime, pendulum.Dat dtv.minute, dtv.second, dtv.microsecond, - tz=dtv.tzinfo or UTC # type: ignore + tz=dtv.tzinfo or UTC, # type: ignore ) # no typings for pendulum return dtv # type: ignore @@ -122,5 +128,7 @@ def ensure_pendulum_datetime(value: TAnyDateTime) -> pendulum.DateTime: raise TypeError(f"Cannot coerce {value} to a pendulum.DateTime object.") -def reduce_pendulum_datetime_precision(value: pendulum.DateTime, microsecond_precision: int) -> pendulum.DateTime: - return value.set(microsecond=value.microsecond // 10**(6 - microsecond_precision) * 10**(6 - microsecond_precision)) # type: ignore +def reduce_pendulum_datetime_precision( + value: pendulum.DateTime, microsecond_precision: int +) -> pendulum.DateTime: + return value.set(microsecond=value.microsecond // 10 ** (6 - microsecond_precision) * 10 ** (6 - microsecond_precision)) # type: ignore diff --git a/dlt/common/typing.py b/dlt/common/typing.py index 86fa1635df..6571832bec 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -1,19 +1,44 @@ -from collections.abc import Mapping as C_Mapping, Sequence as C_Sequence -from datetime import datetime, date # noqa: I251 import inspect +from collections.abc import Mapping as C_Mapping +from collections.abc import Sequence as C_Sequence +from datetime import date, datetime # noqa: I251 from re import Pattern as _REPattern -from typing import Callable, Dict, Any, Final, Literal, List, Mapping, NewType, Optional, Tuple, Type, TypeVar, Generic, Protocol, TYPE_CHECKING, Union, runtime_checkable, get_args, get_origin -from typing_extensions import TypeAlias, ParamSpec, Concatenate - -from dlt.common.pendulum import timedelta, pendulum +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Final, + Generic, + List, + Literal, + Mapping, + NewType, + Optional, + Protocol, + Tuple, + Type, + TypeVar, + Union, + get_args, + get_origin, + runtime_checkable, +) + +from typing_extensions import Concatenate, ParamSpec, TypeAlias + +from dlt.common.pendulum import pendulum, timedelta if TYPE_CHECKING: - from _typeshed import StrOrBytesPath from typing import _TypedDict + + from _typeshed import StrOrBytesPath + REPattern = _REPattern[str] else: StrOrBytesPath = Any from typing import _TypedDictMeta as _TypedDict + REPattern = _REPattern AnyType: TypeAlias = Any @@ -45,13 +70,15 @@ TVariantRV = Tuple[str, Any] VARIANT_FIELD_FORMAT = "v_%s" + @runtime_checkable class SupportsVariant(Protocol, Generic[TVariantBase]): """Defines variant type protocol that should be recognized by normalizers - Variant types behave like TVariantBase type (ie. Decimal) but also implement the protocol below that is used to extract the variant value from it. - See `Wei` type declaration which returns Decimal or str for values greater than supported by destination warehouse. + Variant types behave like TVariantBase type (ie. Decimal) but also implement the protocol below that is used to extract the variant value from it. + See `Wei` type declaration which returns Decimal or str for values greater than supported by destination warehouse. """ + def __call__(self) -> Union[TVariantBase, TVariantRV]: ... 
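Note (illustrative only, not part of this patch): the SupportsVariant protocol above is merely reformatted here, so the sketch below shows how it is meant to be satisfied. It mirrors the Wei class that appears further down in this patch -- return the value itself while it fits, otherwise return a ("str", value) TVariantRV tuple -- but the class name and numeric bound are invented for the example.

from decimal import Decimal
from typing import Union

from dlt.common.typing import SupportsVariant, TVariantRV


class BoundedDecimal(Decimal, SupportsVariant[Decimal]):
    # hypothetical limit; Wei uses the BigQuery BIGDECIMAL bounds instead
    MAX_ABS = Decimal(10) ** 38

    def __call__(self) -> Union["BoundedDecimal", TVariantRV]:
        if abs(self) >= self.MAX_ABS:
            # too large for the base representation -> emit a string variant; the "str" tag
            # pairs with VARIANT_FIELD_FORMAT ("v_%s") above, presumably used by the
            # normalizer to derive the variant column name
            return ("str", str(self))
        return self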
@@ -134,7 +161,9 @@ def get_all_types_of_class_in_union(hint: Type[Any], cls: Type[TAny]) -> List[Ty return [t for t in get_args(hint) if inspect.isclass(t) and issubclass(t, cls)] -def get_generic_type_argument_from_instance(instance: Any, sample_value: Optional[Any]) -> Type[Any]: +def get_generic_type_argument_from_instance( + instance: Any, sample_value: Optional[Any] +) -> Type[Any]: """Infers type argument of a Generic class from an `instance` of that class using optional `sample_value` of the argument type Inference depends on the presence of __orig_class__ attribute in instance, if not present - sample_Value will be used @@ -151,4 +180,4 @@ def get_generic_type_argument_from_instance(instance: Any, sample_value: Optiona orig_param_type = get_args(instance.__orig_class__)[0] if orig_param_type is Any and sample_value is not None: orig_param_type = type(sample_value) - return orig_param_type # type: ignore \ No newline at end of file + return orig_param_type # type: ignore diff --git a/dlt/common/utils.py b/dlt/common/utils.py index 5df552c5e2..235bbee10d 100644 --- a/dlt/common/utils.py +++ b/dlt/common/utils.py @@ -1,20 +1,32 @@ -import os -from pathlib import Path -import sys import base64 import hashlib +import os import secrets +import sys +import zlib +from collections.abc import Mapping as C_Mapping from contextlib import contextmanager from functools import wraps from os import environ +from pathlib import Path from types import ModuleType -import zlib - -from typing import Any, ContextManager, Dict, Iterator, Optional, Sequence, Set, Tuple, TypeVar, Mapping, List, Union, Counter -from collections.abc import Mapping as C_Mapping - -from dlt.common.typing import AnyFun, StrAny, DictStrAny, StrStr, TAny, TFun - +from typing import ( + Any, + ContextManager, + Counter, + Dict, + Iterator, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, + TypeVar, + Union, +) + +from dlt.common.typing import AnyFun, DictStrAny, StrAny, StrStr, TAny, TFun T = TypeVar("T") TDict = TypeVar("TDict", bound=DictStrAny) @@ -22,7 +34,7 @@ def chunks(seq: Sequence[T], n: int) -> Iterator[Sequence[T]]: for i in range(0, len(seq), n): - yield seq[i:i + n] + yield seq[i : i + n] def uniq_id(len_: int = 16) -> str: @@ -32,34 +44,38 @@ def uniq_id(len_: int = 16) -> str: def uniq_id_base64(len_: int = 16) -> str: """Returns a base64 encoded crypto-grade string of random bytes with desired len_""" - return base64.b64encode(secrets.token_bytes(len_)).decode('ascii').rstrip("=") + return base64.b64encode(secrets.token_bytes(len_)).decode("ascii").rstrip("=") def digest128(v: str, len_: int = 15) -> str: """Returns a base64 encoded shake128 hash of str `v` with digest of length `len_` (default: 15 bytes = 20 characters length)""" - return base64.b64encode(hashlib.shake_128(v.encode("utf-8")).digest(len_)).decode('ascii').rstrip("=") + return ( + base64.b64encode(hashlib.shake_128(v.encode("utf-8")).digest(len_)) + .decode("ascii") + .rstrip("=") + ) def digest128b(v: bytes, len_: int = 15) -> str: """Returns a base64 encoded shake128 hash of bytes `v` with digest of length `len_` (default: 15 bytes = 20 characters length)""" - enc_v = base64.b64encode(hashlib.shake_128(v).digest(len_)).decode('ascii') + enc_v = base64.b64encode(hashlib.shake_128(v).digest(len_)).decode("ascii") return enc_v.rstrip("=") def digest256(v: str) -> str: digest = hashlib.sha3_256(v.encode("utf-8")).digest() - return base64.b64encode(digest).decode('ascii') + return base64.b64encode(digest).decode("ascii") def str2bool(v: str) 
-> bool: if isinstance(v, bool): return v - if v.lower() in ('yes', 'true', 't', 'y', '1'): + if v.lower() in ("yes", "true", "t", "y", "1"): return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): + elif v.lower() in ("no", "false", "f", "n", "0"): return False else: - raise ValueError('Boolean value expected.') + raise ValueError("Boolean value expected.") # def flatten_list_of_dicts(dicts: Sequence[StrAny]) -> StrAny: @@ -82,7 +98,7 @@ def flatten_list_of_str_or_dicts(seq: Sequence[Union[StrAny, str]]) -> StrAny: o: DictStrAny = {} for e in seq: if isinstance(e, dict): - for k,v in e.items(): + for k, v in e.items(): if k in o: raise KeyError(f"Cannot flatten with duplicate key {k}") o[k] = v @@ -163,7 +179,9 @@ def concat_strings_with_limit(strings: List[str], separator: str, limit: int) -> sep_len = len(separator) for i in range(1, len(strings)): - if current_length + len(strings[i]) + sep_len > limit: # accounts for the length of separator + if ( + current_length + len(strings[i]) + sep_len > limit + ): # accounts for the length of separator yield separator.join(strings[start:i]) start = i current_length = len(strings[i]) @@ -173,7 +191,9 @@ def concat_strings_with_limit(strings: List[str], separator: str, limit: int) -> yield separator.join(strings[start:]) -def graph_edges_to_nodes(edges: Sequence[Tuple[TAny, TAny]], directed: bool = True) -> Dict[TAny, Set[TAny]]: +def graph_edges_to_nodes( + edges: Sequence[Tuple[TAny, TAny]], directed: bool = True +) -> Dict[TAny, Set[TAny]]: """Converts a directed graph represented as a sequence of edges to a graph represented as a mapping from nodes a set of connected nodes. Isolated nodes are represented as edges to itself. If `directed` is `False`, each edge is duplicated but going in opposite direction. @@ -207,7 +227,6 @@ def dfs(node: TAny, current_component: Set[TAny]) -> None: for neighbor in undag[node]: dfs(neighbor, current_component) - for node in undag: if node not in visited: component: Set[TAny] = set() @@ -287,9 +306,10 @@ def is_interactive() -> bool: bool: True if interactive (e.g., REPL, IPython, Jupyter Notebook), False if running as a script. """ import __main__ as main + # When running as a script, the __main__ module has a __file__ attribute. # In an interactive environment, the __file__ attribute is absent. 
- return not hasattr(main, '__file__') + return not hasattr(main, "__file__") def dict_remove_nones_in_place(d: Dict[Any, Any]) -> Dict[Any, Any]: @@ -317,7 +337,6 @@ def custom_environ(env: StrStr) -> Iterator[None]: def with_custom_environ(f: TFun) -> TFun: - @wraps(f) def _wrap(*args: Any, **kwargs: Any) -> Any: saved_environ = os.environ.copy() @@ -390,11 +409,20 @@ def is_inner_callable(f: AnyFun) -> bool: def obfuscate_pseudo_secret(pseudo_secret: str, pseudo_key: bytes) -> str: - return base64.b64encode(bytes([_a ^ _b for _a, _b in zip(pseudo_secret.encode("utf-8"), pseudo_key*250)])).decode() + return base64.b64encode( + bytes([_a ^ _b for _a, _b in zip(pseudo_secret.encode("utf-8"), pseudo_key * 250)]) + ).decode() def reveal_pseudo_secret(obfuscated_secret: str, pseudo_key: bytes) -> str: - return bytes([_a ^ _b for _a, _b in zip(base64.b64decode(obfuscated_secret.encode("ascii"), validate=True), pseudo_key*250)]).decode("utf-8") + return bytes( + [ + _a ^ _b + for _a, _b in zip( + base64.b64decode(obfuscated_secret.encode("ascii"), validate=True), pseudo_key * 250 + ) + ] + ).decode("utf-8") def get_module_name(m: ModuleType) -> str: @@ -414,7 +442,7 @@ def derives_from_class_of_name(o: object, name: str) -> bool: def compressed_b64encode(value: bytes) -> str: """Compress and b64 encode the given bytestring""" - return base64.b64encode(zlib.compress(value, level=9)).decode('ascii') + return base64.b64encode(zlib.compress(value, level=9)).decode("ascii") def compressed_b64decode(value: str) -> bytes: diff --git a/dlt/common/validation.py b/dlt/common/validation.py index f1900c1b0e..48494720da 100644 --- a/dlt/common/validation.py +++ b/dlt/common/validation.py @@ -1,15 +1,29 @@ import functools -from typing import Callable, Any, Type, get_type_hints, get_args +from typing import Any, Callable, Type, get_args, get_type_hints from dlt.common.exceptions import DictValidationException -from dlt.common.typing import StrAny, extract_optional_type, is_literal_type, is_optional_type, is_typeddict, is_list_generic_type, is_dict_generic_type, _TypedDict - +from dlt.common.typing import ( + StrAny, + _TypedDict, + extract_optional_type, + is_dict_generic_type, + is_list_generic_type, + is_literal_type, + is_optional_type, + is_typeddict, +) TFilterFunc = Callable[[str], bool] TCustomValidator = Callable[[str, str, Any, Any], bool] -def validate_dict(spec: Type[_TypedDict], doc: StrAny, path: str, filter_f: TFilterFunc = None, validator_f: TCustomValidator = None) -> None: +def validate_dict( + spec: Type[_TypedDict], + doc: StrAny, + path: str, + filter_f: TFilterFunc = None, + validator_f: TCustomValidator = None, +) -> None: """Validate the `doc` dictionary based on the given typed dictionary specification `spec`. 
Args: @@ -42,11 +56,15 @@ def validate_dict(spec: Type[_TypedDict], doc: StrAny, path: str, filter_f: TFil # check missing props missing = set(required_props.keys()).difference(props.keys()) if len(missing): - raise DictValidationException(f"In {path}: following required fields are missing {missing}", path) + raise DictValidationException( + f"In {path}: following required fields are missing {missing}", path + ) # check unknown props unexpected = set(props.keys()).difference(allowed_props.keys()) if len(unexpected): - raise DictValidationException(f"In {path}: following fields are unexpected {unexpected}", path) + raise DictValidationException( + f"In {path}: following fields are unexpected {unexpected}", path + ) def verify_prop(pk: str, pv: Any, t: Any) -> None: if is_optional_type(t): @@ -55,36 +73,68 @@ def verify_prop(pk: str, pv: Any, t: Any) -> None: if is_literal_type(t): a_l = get_args(t) if pv not in a_l: - raise DictValidationException(f"In {path}: field {pk} value {pv} not in allowed {a_l}", path, pk, pv) + raise DictValidationException( + f"In {path}: field {pk} value {pv} not in allowed {a_l}", path, pk, pv + ) elif t in [int, bool, str, float]: if not isinstance(pv, t): - raise DictValidationException(f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__} while {t.__name__} is expected", path, pk, pv) + raise DictValidationException( + f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__} while" + f" {t.__name__} is expected", + path, + pk, + pv, + ) elif is_typeddict(t): if not isinstance(pv, dict): - raise DictValidationException(f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__} while dict is expected", path, pk, pv) + raise DictValidationException( + f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__} while" + " dict is expected", + path, + pk, + pv, + ) validate_dict(t, pv, path + "/" + pk, filter_f, validator_f) elif is_list_generic_type(t): if not isinstance(pv, list): - raise DictValidationException(f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__} while list is expected", path, pk, pv) + raise DictValidationException( + f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__} while" + " list is expected", + path, + pk, + pv, + ) # get list element type from generic and process each list element l_t = get_args(t)[0] for i, l_v in enumerate(pv): verify_prop(pk + f"[{i}]", l_v, l_t) elif is_dict_generic_type(t): if not isinstance(pv, dict): - raise DictValidationException(f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__} while dict is expected", path, pk, pv) + raise DictValidationException( + f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__} while" + " dict is expected", + path, + pk, + pv, + ) # get dict key and value type from generic and process each k: v of the dict _, d_v_t = get_args(t) for d_k, d_v in pv.items(): if not isinstance(d_k, str): - raise DictValidationException(f"In {path}: field {pk} key {d_k} must be a string", path, pk, d_k) + raise DictValidationException( + f"In {path}: field {pk} key {d_k} must be a string", path, pk, d_k + ) verify_prop(pk + f"[{d_k}]", d_v, d_v_t) elif t is Any: # pass everything with any type pass else: if not validator_f(path, pk, pv, t): - raise DictValidationException(f"In {path}: field {pk} has expected type {t.__name__} which lacks validator", path, pk) + raise DictValidationException( + f"In {path}: field {pk} has expected type {t.__name__} which lacks validator", 
+ path, + pk, + ) # check allowed props for pk, pv in props.items(): @@ -92,6 +142,5 @@ def verify_prop(pk: str, pv: Any, t: Any) -> None: validate_dict_ignoring_xkeys = functools.partial( - validate_dict, - filter_f=lambda k: not k.startswith("x-") -) \ No newline at end of file + validate_dict, filter_f=lambda k: not k.startswith("x-") +) diff --git a/dlt/common/wei.py b/dlt/common/wei.py index 218e5eee3a..60a5e259ef 100644 --- a/dlt/common/wei.py +++ b/dlt/common/wei.py @@ -1,7 +1,7 @@ from typing import Union -from dlt.common.typing import TVariantRV, SupportsVariant -from dlt.common.arithmetics import default_context, decimal, Decimal +from dlt.common.arithmetics import Decimal, decimal, default_context +from dlt.common.typing import SupportsVariant, TVariantRV # default scale of EVM based blockchain WEI_SCALE = 18 @@ -11,8 +11,7 @@ WEI_SCALE_POW = 10**18 -class Wei(Decimal,SupportsVariant[Decimal]): - +class Wei(Decimal, SupportsVariant[Decimal]): ctx = default_context(decimal.getcontext().copy(), EVM_DECIMAL_PRECISION) @classmethod @@ -29,11 +28,13 @@ def from_int256(cls, value: int, decimals: int = 0) -> "Wei": def __call__(self) -> Union["Wei", TVariantRV]: # TODO: this should look into DestinationCapabilitiesContext to get maximum Decimal value. # this is BigQuery BIGDECIMAL max - if self > 578960446186580977117854925043439539266 or self < -578960446186580977117854925043439539267: - return ("str", str(self)) + if ( + self > 578960446186580977117854925043439539266 + or self < -578960446186580977117854925043439539267 + ): + return ("str", str(self)) else: return self - def __repr__(self) -> str: return f"Wei('{str(self)}')" diff --git a/dlt/destinations/athena/__init__.py b/dlt/destinations/athena/__init__.py index 531744f6e6..845557b3c6 100644 --- a/dlt/destinations/athena/__init__.py +++ b/dlt/destinations/athena/__init__.py @@ -1,20 +1,27 @@ from typing import Type -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.configuration import with_config, known_sections +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.configuration import known_sections, with_config from dlt.common.configuration.accessors import config -from dlt.common.schema.schema import Schema from dlt.common.data_writers.escape import escape_athena_identifier -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.destination.reference import DestinationClientConfiguration, JobClientBase +from dlt.common.schema.schema import Schema from dlt.common.wei import EVM_DECIMAL_PRECISION - from dlt.destinations.athena.configuration import AthenaClientConfiguration -from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration -@with_config(spec=AthenaClientConfiguration, sections=(known_sections.DESTINATION, "athena",)) + +@with_config( + spec=AthenaClientConfiguration, + sections=( + known_sections.DESTINATION, + "athena", + ), +) def _configure(config: AthenaClientConfiguration = config.value) -> AthenaClientConfiguration: return config + def capabilities() -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() # athena only supports loading from staged files on s3 for now @@ -39,11 +46,14 @@ def capabilities() -> DestinationCapabilitiesContext: return caps -def client(schema: Schema, initial_config: DestinationClientConfiguration = config.value) -> JobClientBase: +def client( 
+ schema: Schema, initial_config: DestinationClientConfiguration = config.value +) -> JobClientBase: # import client when creating instance so capabilities and config specs can be accessed without dependencies installed from dlt.destinations.athena.athena import AthenaClient + return AthenaClient(schema, _configure(initial_config)) # type: ignore def spec() -> Type[DestinationClientConfiguration]: - return AthenaClientConfiguration \ No newline at end of file + return AthenaClientConfiguration diff --git a/dlt/destinations/athena/athena.py b/dlt/destinations/athena/athena.py index 53b41fddb0..f8bc60afa6 100644 --- a/dlt/destinations/athena/athena.py +++ b/dlt/destinations/athena/athena.py @@ -1,36 +1,58 @@ -from typing import Optional, ClassVar, Iterator, Any, AnyStr, Sequence, Tuple, List, Dict, Callable, Iterable, Type -from copy import deepcopy import re - from contextlib import contextmanager -from pendulum.datetime import DateTime, Date +from copy import deepcopy from datetime import datetime # noqa: I251 +from typing import ( + Any, + AnyStr, + Callable, + ClassVar, + Dict, + Iterable, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, +) import pyathena +from pendulum.datetime import Date, DateTime from pyathena import connect from pyathena.connection import Connection -from pyathena.error import OperationalError, DatabaseError, ProgrammingError, IntegrityError, Error -from pyathena.formatter import DefaultParameterFormatter, _DEFAULT_FORMATTERS, Formatter, _format_date +from pyathena.error import DatabaseError, Error, IntegrityError, OperationalError, ProgrammingError +from pyathena.formatter import ( + _DEFAULT_FORMATTERS, + DefaultParameterFormatter, + Formatter, + _format_date, +) from dlt.common import logger from dlt.common.data_types import TDataType -from dlt.common.schema import TColumnSchema, Schema -from dlt.common.schema.typing import TTableSchema +from dlt.common.data_writers.escape import escape_bigquery_identifier from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import LoadJob -from dlt.common.destination.reference import TLoadJobState +from dlt.common.destination.reference import LoadJob, TLoadJobState +from dlt.common.schema import Schema, TColumnSchema +from dlt.common.schema.typing import TTableSchema from dlt.common.storages import FileStorage -from dlt.common.data_writers.escape import escape_bigquery_identifier - - -from dlt.destinations.typing import DBApi, DBTransaction -from dlt.destinations.exceptions import DatabaseTerminalException, DatabaseTransientException, DatabaseUndefinedRelation +from dlt.destinations import path_utils from dlt.destinations.athena import capabilities -from dlt.destinations.sql_client import SqlClientBase, DBApiCursorImpl, raise_database_error, raise_open_connection_error -from dlt.destinations.typing import DBApiCursor -from dlt.destinations.job_client_impl import SqlJobClientBase, StorageSchemaInfo from dlt.destinations.athena.configuration import AthenaClientConfiguration -from dlt.destinations import path_utils +from dlt.destinations.exceptions import ( + DatabaseTerminalException, + DatabaseTransientException, + DatabaseUndefinedRelation, +) +from dlt.destinations.job_client_impl import SqlJobClientBase, StorageSchemaInfo +from dlt.destinations.sql_client import ( + DBApiCursorImpl, + SqlClientBase, + raise_database_error, + raise_open_connection_error, +) +from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction SCT_TO_HIVET: Dict[TDataType, str] = 
{ "complex": "string", @@ -41,7 +63,7 @@ "timestamp": "timestamp", "bigint": "bigint", "binary": "binary", - "decimal": "decimal(%i,%i)" + "decimal": "decimal(%i,%i)", } HIVET_TO_SCT: Dict[str, TDataType] = { @@ -53,7 +75,7 @@ "bigint": "bigint", "binary": "binary", "varbinary": "binary", - "decimal": "decimal" + "decimal": "decimal", } @@ -68,7 +90,6 @@ def _format_pendulum_datetime(formatter: Formatter, escaper: Callable[[str], str class DLTAthenaFormatter(DefaultParameterFormatter): - _INSTANCE: ClassVar["DLTAthenaFormatter"] = None def __new__(cls: Type["DLTAthenaFormatter"]) -> "DLTAthenaFormatter": @@ -76,7 +97,6 @@ def __new__(cls: Type["DLTAthenaFormatter"]) -> "DLTAthenaFormatter": return cls._INSTANCE return super().__new__(cls) - def __init__(self) -> None: if DLTAthenaFormatter._INSTANCE: return @@ -85,9 +105,7 @@ def __init__(self) -> None: formatters[datetime] = _format_pendulum_datetime formatters[Date] = _format_date - super(DefaultParameterFormatter, self).__init__( - mappings=formatters, default=None - ) + super(DefaultParameterFormatter, self).__init__(mappings=formatters, default=None) DLTAthenaFormatter._INSTANCE = self @@ -105,8 +123,8 @@ def exception(self) -> str: # this part of code should be never reached raise NotImplementedError() -class AthenaSQLClient(SqlClientBase[Connection]): +class AthenaSQLClient(SqlClientBase[Connection]): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() dbapi: ClassVar[DBApi] = pyathena @@ -123,7 +141,8 @@ def open_connection(self) -> Connection: schema_name=self.dataset_name, s3_staging_dir=self.config.query_result_bucket, work_group=self.config.athena_work_group, - **native_credentials) + **native_credentials, + ) return self._conn def close_connection(self) -> None: @@ -157,18 +176,24 @@ def drop_dataset(self) -> None: self.execute_sql(f"DROP DATABASE {self.fully_qualified_ddl_dataset_name()} CASCADE;") def fully_qualified_dataset_name(self, escape: bool = True) -> str: - return self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name + return ( + self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name + ) def drop_tables(self, *tables: str) -> None: if not tables: return - statements = [f"DROP TABLE IF EXISTS {self.make_qualified_ddl_table_name(table)};" for table in tables] + statements = [ + f"DROP TABLE IF EXISTS {self.make_qualified_ddl_table_name(table)};" for table in tables + ] self.execute_fragments(statements) @contextmanager @raise_database_error def begin_transaction(self) -> Iterator[DBTransaction]: - logger.warning("Athena does not support transactions! Each SQL statement is auto-committed separately.") + logger.warning( + "Athena does not support transactions! Each SQL statement is auto-committed separately." 
+ ) yield self @raise_database_error @@ -197,7 +222,9 @@ def _make_database_exception(ex: Exception) -> Exception: return DatabaseTransientException(ex) return ex - def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]: + def execute_sql( + self, sql: AnyStr, *args: Any, **kwargs: Any + ) -> Optional[Sequence[Sequence[Any]]]: with self.execute_query(sql, *args, **kwargs) as curr: if curr.description is None: return None @@ -206,13 +233,17 @@ def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequen return f @staticmethod - def _convert_to_old_pyformat(new_style_string: str, args: Tuple[Any, ...]) -> Tuple[str, Dict[str, Any]]: + def _convert_to_old_pyformat( + new_style_string: str, args: Tuple[Any, ...] + ) -> Tuple[str, Dict[str, Any]]: # create a list of keys - keys = ["arg"+str(i) for i, _ in enumerate(args)] + keys = ["arg" + str(i) for i, _ in enumerate(args)] # create an old style string and replace placeholders - old_style_string, count = re.subn(r"%s", lambda _: "%(" + keys.pop(0) + ")s", new_style_string) + old_style_string, count = re.subn( + r"%s", lambda _: "%(" + keys.pop(0) + ")s", new_style_string + ) # create a dictionary mapping keys to args - mapping = dict(zip(["arg"+str(i) for i, _ in enumerate(args)], args)) + mapping = dict(zip(["arg" + str(i) for i, _ in enumerate(args)], args)) # raise if there is a mismatch between args and string if count != len(args): raise DatabaseTransientException(OperationalError()) @@ -247,19 +278,17 @@ def has_dataset(self) -> bool: class AthenaClient(SqlJobClientBase): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: AthenaClientConfiguration) -> None: # verify if staging layout is valid for Athena # this will raise if the table prefix is not properly defined # we actually that {table_name} is first, no {schema_name} is allowed - self.table_prefix_layout = path_utils.get_table_prefix_layout(config.staging_config.layout, []) - - sql_client = AthenaSQLClient( - config.normalize_dataset_name(schema), - config + self.table_prefix_layout = path_utils.get_table_prefix_layout( + config.staging_config.layout, [] ) + + sql_client = AthenaSQLClient(config.normalize_dataset_name(schema), config) super().__init__(schema, config, sql_client) self.sql_client: AthenaSQLClient = sql_client # type: ignore self.config: AthenaClientConfiguration = config @@ -277,17 +306,22 @@ def _to_db_type(cls, sc_t: TDataType) -> str: return SCT_TO_HIVET[sc_t] @classmethod - def _from_db_type(cls, hive_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType: + def _from_db_type( + cls, hive_t: str, precision: Optional[int], scale: Optional[int] + ) -> TDataType: for key, val in HIVET_TO_SCT.items(): if hive_t.startswith(key): return val return None def _get_column_def_sql(self, c: TColumnSchema) -> str: - return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self._to_db_type(c['data_type'])}" - - def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool) -> List[str]: + return ( + f"{self.sql_client.escape_ddl_identifier(c['name'])} {self._to_db_type(c['data_type'])}" + ) + def _get_table_update_sql( + self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool + ) -> List[str]: bucket = self.config.staging_config.bucket_url dataset = self.sql_client.dataset_name sql: List[str] = [] @@ -326,4 +360,4 @@ def start_file_load(self, table: 
TTableSchema, file_path: str, load_id: str) -> @staticmethod def is_dbapi_exception(ex: Exception) -> bool: - return isinstance(ex, Error) \ No newline at end of file + return isinstance(ex, Error) diff --git a/dlt/destinations/athena/configuration.py b/dlt/destinations/athena/configuration.py index d6ba5e3814..d15f01fd0e 100644 --- a/dlt/destinations/athena/configuration.py +++ b/dlt/destinations/athena/configuration.py @@ -1,8 +1,8 @@ from typing import ClassVar, Final, List, Optional from dlt.common.configuration import configspec +from dlt.common.configuration.specs import AwsCredentials from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration -from dlt.common.configuration.specs import AwsCredentials @configspec @@ -13,4 +13,4 @@ class AthenaClientConfiguration(DestinationClientDwhWithStagingConfiguration): athena_work_group: Optional[str] = None aws_data_catalog: Optional[str] = "awsdatacatalog" - __config_gen_annotations__: ClassVar[List[str]] = ["athena_work_group"] \ No newline at end of file + __config_gen_annotations__: ClassVar[List[str]] = ["athena_work_group"] diff --git a/dlt/destinations/bigquery/__init__.py b/dlt/destinations/bigquery/__init__.py index 3d97e9a929..09fab7188b 100644 --- a/dlt/destinations/bigquery/__init__.py +++ b/dlt/destinations/bigquery/__init__.py @@ -1,17 +1,22 @@ from typing import Type -from dlt.common.data_writers.escape import escape_bigquery_identifier -from dlt.common.schema.schema import Schema -from dlt.common.configuration import with_config, known_sections +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.configuration import known_sections, with_config from dlt.common.configuration.accessors import config +from dlt.common.data_writers.escape import escape_bigquery_identifier from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - +from dlt.common.destination.reference import DestinationClientConfiguration, JobClientBase +from dlt.common.schema.schema import Schema from dlt.destinations.bigquery.configuration import BigQueryClientConfiguration -@with_config(spec=BigQueryClientConfiguration, sections=(known_sections.DESTINATION, "bigquery",)) +@with_config( + spec=BigQueryClientConfiguration, + sections=( + known_sections.DESTINATION, + "bigquery", + ), +) def _configure(config: BigQueryClientConfiguration = config.value) -> BigQueryClientConfiguration: return config @@ -37,7 +42,9 @@ def capabilities() -> DestinationCapabilitiesContext: return caps -def client(schema: Schema, initial_config: DestinationClientConfiguration = config.value) -> JobClientBase: +def client( + schema: Schema, initial_config: DestinationClientConfiguration = config.value +) -> JobClientBase: # import client when creating instance so capabilities and config specs can be accessed without dependencies installed from dlt.destinations.bigquery.bigquery import BigQueryClient @@ -45,4 +52,4 @@ def client(schema: Schema, initial_config: DestinationClientConfiguration = conf def spec() -> Type[DestinationClientConfiguration]: - return BigQueryClientConfiguration \ No newline at end of file + return BigQueryClientConfiguration diff --git a/dlt/destinations/bigquery/bigquery.py b/dlt/destinations/bigquery/bigquery.py index 302788e250..c07e23fcc1 100644 --- a/dlt/destinations/bigquery/bigquery.py +++ 
b/dlt/destinations/bigquery/bigquery.py @@ -1,29 +1,33 @@ import os from pathlib import Path -from typing import ClassVar, Dict, Optional, Sequence, Tuple, List, cast, Type, Any +from typing import Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, cast + import google.cloud.bigquery as bigquery # noqa: I250 -from google.cloud import exceptions as gcp_exceptions from google.api_core import exceptions as api_core_exceptions +from google.cloud import exceptions as gcp_exceptions from dlt.common import json, logger -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import FollowupJob, NewLoadJob, TLoadJobState, LoadJob from dlt.common.data_types import TDataType -from dlt.common.storages.file_storage import FileStorage -from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.destination.reference import FollowupJob, LoadJob, NewLoadJob, TLoadJobState +from dlt.common.schema import Schema, TColumnSchema, TTableSchemaColumns from dlt.common.schema.typing import TTableSchema - -from dlt.destinations.job_client_impl import SqlJobClientWithStaging -from dlt.destinations.exceptions import DestinationSchemaWillNotUpdate, DestinationTransientException, LoadJobNotExistsException, LoadJobTerminalException, LoadJobUnknownTableException - +from dlt.common.schema.utils import table_schema_has_type +from dlt.common.storages.file_storage import FileStorage from dlt.destinations.bigquery import capabilities from dlt.destinations.bigquery.configuration import BigQueryClientConfiguration -from dlt.destinations.bigquery.sql_client import BigQuerySqlClient, BQ_TERMINAL_REASONS -from dlt.destinations.sql_jobs import SqlMergeJob, SqlStagingCopyJob +from dlt.destinations.bigquery.sql_client import BQ_TERMINAL_REASONS, BigQuerySqlClient +from dlt.destinations.exceptions import ( + DestinationSchemaWillNotUpdate, + DestinationTransientException, + LoadJobNotExistsException, + LoadJobTerminalException, + LoadJobUnknownTableException, +) +from dlt.destinations.job_client_impl import SqlJobClientWithStaging from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_client import SqlClientBase - -from dlt.common.schema.utils import table_schema_has_type +from dlt.destinations.sql_jobs import SqlMergeJob, SqlStagingCopyJob SCT_TO_BQT: Dict[TDataType, str] = { "complex": "JSON", @@ -35,7 +39,7 @@ "bigint": "INTEGER", "binary": "BYTES", "decimal": "NUMERIC(%i,%i)", - "wei": "BIGNUMERIC" # non parametrized should hold wei values + "wei": "BIGNUMERIC", # non parametrized should hold wei values } BQT_TO_SCT: Dict[str, TDataType] = { @@ -48,16 +52,17 @@ "BYTES": "binary", "NUMERIC": "decimal", "BIGNUMERIC": "decimal", - "JSON": "complex" + "JSON": "complex", } + class BigQueryLoadJob(LoadJob, FollowupJob): def __init__( self, file_name: str, bq_load_job: bigquery.LoadJob, http_timeout: float, - retry_deadline: float + retry_deadline: float, ) -> None: self.bq_load_job = bq_load_job self.default_retry = bigquery.DEFAULT_RETRY.with_deadline(retry_deadline) @@ -77,7 +82,10 @@ def state(self) -> TLoadJobState: # the job permanently failed for the reason above return "failed" elif reason in ["internalError"]: - logger.warning(f"Got reason {reason} for job {self.file_name}, job considered still running. ({self.bq_load_job.error_result})") + logger.warning( + f"Got reason {reason} for job {self.file_name}, job considered still" + f" running. 
({self.bq_load_job.error_result})" + ) # status of the job could not be obtained, job still running return "running" else: @@ -90,13 +98,15 @@ def job_id(self) -> str: return BigQueryLoadJob.get_job_id_from_file_path(super().job_id()) def exception(self) -> str: - exception: str = json.dumps({ - "error_result": self.bq_load_job.error_result, - "errors": self.bq_load_job.errors, - "job_start": self.bq_load_job.started, - "job_end": self.bq_load_job.ended, - "job_id": self.bq_load_job.job_id - }) + exception: str = json.dumps( + { + "error_result": self.bq_load_job.error_result, + "errors": self.bq_load_job.errors, + "job_start": self.bq_load_job.started, + "job_end": self.bq_load_job.ended, + "job_id": self.bq_load_job.job_id, + } + ) return exception @staticmethod @@ -105,19 +115,29 @@ def get_job_id_from_file_path(file_path: str) -> str: class BigQueryMergeJob(SqlMergeJob): - @classmethod - def gen_key_table_clauses(cls, root_table_name: str, staging_root_table_name: str, key_clauses: Sequence[str], for_delete: bool) -> List[str]: + def gen_key_table_clauses( + cls, + root_table_name: str, + staging_root_table_name: str, + key_clauses: Sequence[str], + for_delete: bool, + ) -> List[str]: # generate several clauses: BigQuery does not support OR nor unions sql: List[str] = [] for clause in key_clauses: - sql.append(f"FROM {root_table_name} AS d WHERE EXISTS (SELECT 1 FROM {staging_root_table_name} AS s WHERE {clause.format(d='d', s='s')})") + sql.append( + f"FROM {root_table_name} AS d WHERE EXISTS (SELECT 1 FROM" + f" {staging_root_table_name} AS s WHERE {clause.format(d='d', s='s')})" + ) return sql -class BigqueryStagingCopyJob(SqlStagingCopyJob): +class BigqueryStagingCopyJob(SqlStagingCopyJob): @classmethod - def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]: + def generate_sql( + cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any] + ) -> List[str]: sql: List[str] = [] for table in table_chain: with sql_client.with_staging_dataset(staging=True): @@ -129,8 +149,8 @@ def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClient sql.append(f"CREATE TABLE {table_name} CLONE {staging_table_name};") return sql -class BigQueryClient(SqlJobClientWithStaging): +class BigQueryClient(SqlJobClientWithStaging): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: BigQueryClientConfiguration) -> None: @@ -139,7 +159,7 @@ def __init__(self, schema: Schema, config: BigQueryClientConfiguration) -> None: config.credentials, config.get_location(), config.http_timeout, - config.retry_deadline + config.retry_deadline, ) super().__init__(schema, config, sql_client) self.config: BigQueryClientConfiguration = config @@ -169,7 +189,7 @@ def restore_file_load(self, file_path: str) -> LoadJob: FileStorage.get_file_name_from_file_path(file_path), self._retrieve_load_job(file_path), self.config.http_timeout, - self.config.retry_deadline + self.config.retry_deadline, ) except api_core_exceptions.GoogleAPICallError as gace: reason = BigQuerySqlClient._get_reason_from_errors(gace) @@ -190,7 +210,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> FileStorage.get_file_name_from_file_path(file_path), self._create_load_job(table, file_path), self.config.http_timeout, - self.config.retry_deadline + self.config.retry_deadline, ) except api_core_exceptions.GoogleAPICallError as gace: reason = 
BigQuerySqlClient._get_reason_from_errors(gace) @@ -207,17 +227,31 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> raise DestinationTransientException(gace) return job - def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool, separate_alters: bool = False) -> List[str]: + def _get_table_update_sql( + self, + table_name: str, + new_columns: Sequence[TColumnSchema], + generate_alter: bool, + separate_alters: bool = False, + ) -> List[str]: sql = super()._get_table_update_sql(table_name, new_columns, generate_alter) canonical_name = self.sql_client.make_qualified_table_name(table_name) - cluster_list = [self.capabilities.escape_identifier(c["name"]) for c in new_columns if c.get("cluster")] - partition_list = [self.capabilities.escape_identifier(c["name"]) for c in new_columns if c.get("partition")] + cluster_list = [ + self.capabilities.escape_identifier(c["name"]) for c in new_columns if c.get("cluster") + ] + partition_list = [ + self.capabilities.escape_identifier(c["name"]) + for c in new_columns + if c.get("partition") + ] # partition by must be added first if len(partition_list) > 0: if len(partition_list) > 1: - raise DestinationSchemaWillNotUpdate(canonical_name, partition_list, "Partition requested for more than one column") + raise DestinationSchemaWillNotUpdate( + canonical_name, partition_list, "Partition requested for more than one column" + ) else: sql[0] = sql[0] + f"\nPARTITION BY DATE({partition_list[0]})" if len(cluster_list) > 0: @@ -235,7 +269,7 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] table = self.sql_client.native_connection.get_table( self.sql_client.make_qualified_table_name(table_name, escape=False), retry=self.sql_client._default_retry, - timeout=self.config.http_timeout + timeout=self.config.http_timeout, ) partition_field = table.time_partitioning.field if table.time_partitioning else None for c in table.schema: @@ -248,7 +282,7 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] "primary_key": False, "foreign_key": False, "cluster": c.name in (table.clustering_fields or []), - "partition": c.name == partition_field + "partition": c.name == partition_field, } schema_table[c.name] = schema_c return True, schema_table @@ -272,7 +306,10 @@ def _create_load_job(self, table: TTableSchema, file_path: str) -> bigquery.Load if ext == "parquet": # if table contains complex types, we cannot load with parquet if table_schema_has_type(table, "complex"): - raise LoadJobTerminalException(file_path, "Bigquery cannot load into JSON data type from parquet. Use jsonl instead.") + raise LoadJobTerminalException( + file_path, + "Bigquery cannot load into JSON data type from parquet. 
Use jsonl instead.", + ) source_format = bigquery.SourceFormat.PARQUET # parquet needs NUMERIC type autodetection decimal_target_types = ["NUMERIC", "BIGNUMERIC"] @@ -285,25 +322,26 @@ def _create_load_job(self, table: TTableSchema, file_path: str) -> bigquery.Load source_format=source_format, decimal_target_types=decimal_target_types, ignore_unknown_values=False, - max_bad_records=0) + max_bad_records=0, + ) if bucket_path: return self.sql_client.native_connection.load_table_from_uri( - bucket_path, - self.sql_client.make_qualified_table_name(table_name, escape=False), - job_id=job_id, - job_config=job_config, - timeout=self.config.file_upload_timeout - ) + bucket_path, + self.sql_client.make_qualified_table_name(table_name, escape=False), + job_id=job_id, + job_config=job_config, + timeout=self.config.file_upload_timeout, + ) with open(file_path, "rb") as f: return self.sql_client.native_connection.load_table_from_file( - f, - self.sql_client.make_qualified_table_name(table_name, escape=False), - job_id=job_id, - job_config=job_config, - timeout=self.config.file_upload_timeout - ) + f, + self.sql_client.make_qualified_table_name(table_name, escape=False), + job_id=job_id, + job_config=job_config, + timeout=self.config.file_upload_timeout, + ) def _retrieve_load_job(self, file_path: str) -> bigquery.LoadJob: job_id = BigQueryLoadJob.get_job_id_from_file_path(file_path) @@ -321,5 +359,3 @@ def _from_db_type(cls, bq_t: str, precision: Optional[int], scale: Optional[int] if precision is None: # biggest numeric possible return "wei" return BQT_TO_SCT.get(bq_t, "text") - - diff --git a/dlt/destinations/bigquery/configuration.py b/dlt/destinations/bigquery/configuration.py index 146e137475..713f4688b2 100644 --- a/dlt/destinations/bigquery/configuration.py +++ b/dlt/destinations/bigquery/configuration.py @@ -3,9 +3,8 @@ from dlt.common.configuration import configspec from dlt.common.configuration.specs import GcpServiceAccountCredentials -from dlt.common.utils import digest128 - from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.utils import digest128 @configspec @@ -16,7 +15,9 @@ class BigQueryClientConfiguration(DestinationClientDwhWithStagingConfiguration): http_timeout: float = 15.0 # connection timeout for http request to BigQuery api file_upload_timeout: float = 30 * 60.0 # a timeout for file upload when loading local files - retry_deadline: float = 60.0 # how long to retry the operation in case of error, the backoff 60s + retry_deadline: float = ( + 60.0 # how long to retry the operation in case of error, the backoff 60s + ) __config_gen_annotations__: ClassVar[List[str]] = ["location"] @@ -25,7 +26,10 @@ def get_location(self) -> str: return self.location # default was changed in credentials, emit deprecation message if self.credentials.location != "US": - warnings.warn("Setting BigQuery location in the credentials is deprecated. Please set the location directly in bigquery section ie. destinations.bigquery.location='EU'") + warnings.warn( + "Setting BigQuery location in the credentials is deprecated. Please set the" + " location directly in bigquery section ie. 
destinations.bigquery.location='EU'" + ) return self.credentials.location def fingerprint(self) -> str: @@ -35,6 +39,7 @@ def fingerprint(self) -> str: return "" if TYPE_CHECKING: + def __init__( self, destination_name: str = None, @@ -44,7 +49,6 @@ def __init__( location: str = "US", http_timeout: float = 15.0, file_upload_timeout: float = 30 * 60.0, - retry_deadline: float = 60.0 + retry_deadline: float = 60.0, ) -> None: ... - diff --git a/dlt/destinations/bigquery/sql_client.py b/dlt/destinations/bigquery/sql_client.py index 3d6eb19833..c16ce13ce8 100644 --- a/dlt/destinations/bigquery/sql_client.py +++ b/dlt/destinations/bigquery/sql_client.py @@ -1,32 +1,48 @@ - from contextlib import contextmanager from typing import Any, AnyStr, ClassVar, Iterator, List, Optional, Sequence, Type import google.cloud.bigquery as bigquery # noqa: I250 -from google.cloud.bigquery import dbapi as bq_dbapi -from google.cloud.bigquery.dbapi import Connection as DbApiConnection, Cursor as BQDbApiCursor +from google.api_core import exceptions as api_core_exceptions from google.cloud import exceptions as gcp_exceptions +from google.cloud.bigquery import dbapi as bq_dbapi +from google.cloud.bigquery.dbapi import Connection as DbApiConnection +from google.cloud.bigquery.dbapi import Cursor as BQDbApiCursor from google.cloud.bigquery.dbapi import exceptions as dbapi_exceptions -from google.api_core import exceptions as api_core_exceptions from dlt.common.configuration.specs import GcpServiceAccountCredentialsWithoutDefaults from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.typing import StrAny - -from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction, DataFrame -from dlt.destinations.exceptions import DatabaseTerminalException, DatabaseTransientException, DatabaseUndefinedRelation -from dlt.destinations.sql_client import DBApiCursorImpl, SqlClientBase, raise_database_error, raise_open_connection_error - from dlt.destinations.bigquery import capabilities +from dlt.destinations.exceptions import ( + DatabaseTerminalException, + DatabaseTransientException, + DatabaseUndefinedRelation, +) +from dlt.destinations.sql_client import ( + DBApiCursorImpl, + SqlClientBase, + raise_database_error, + raise_open_connection_error, +) +from dlt.destinations.typing import DataFrame, DBApi, DBApiCursor, DBTransaction # terminal reasons as returned in BQ gRPC error response # https://cloud.google.com/bigquery/docs/error-messages -BQ_TERMINAL_REASONS = ["billingTierLimitExceeded", "duplicate", "invalid", "notFound", "notImplemented", "stopped", "tableUnavailable"] +BQ_TERMINAL_REASONS = [ + "billingTierLimitExceeded", + "duplicate", + "invalid", + "notFound", + "notImplemented", + "stopped", + "tableUnavailable", +] # invalidQuery is an transient error -> must be fixed by programmer class BigQueryDBApiCursorImpl(DBApiCursorImpl): """Use native BigQuery data frame support if available""" + native_cursor: BQDbApiCursor # type: ignore def df(self, chunk_size: int = None, **kwargs: Any) -> DataFrame: @@ -43,7 +59,6 @@ def df(self, chunk_size: int = None, **kwargs: Any) -> DataFrame: class BigQuerySqlClient(SqlClientBase[bigquery.Client], DBTransaction): - dbapi: ClassVar[DBApi] = bq_dbapi capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() @@ -53,7 +68,7 @@ def __init__( credentials: GcpServiceAccountCredentialsWithoutDefaults, location: str = "US", http_timeout: float = 15.0, - retry_deadline: float = 60.0 + retry_deadline: float = 60.0, ) -> None: self._client: 
bigquery.Client = None self.credentials: GcpServiceAccountCredentialsWithoutDefaults = credentials @@ -62,16 +77,17 @@ def __init__( super().__init__(credentials.project_id, dataset_name) self._default_retry = bigquery.DEFAULT_RETRY.with_deadline(retry_deadline) - self._default_query = bigquery.QueryJobConfig(default_dataset=self.fully_qualified_dataset_name(escape=False)) + self._default_query = bigquery.QueryJobConfig( + default_dataset=self.fully_qualified_dataset_name(escape=False) + ) self._session_query: bigquery.QueryJobConfig = None - @raise_open_connection_error def open_connection(self) -> bigquery.Client: self._client = bigquery.Client( self.credentials.project_id, credentials=self.credentials.to_native_credentials(), - location=self.location + location=self.location, ) # patch the client query so our defaults are used @@ -81,7 +97,7 @@ def query_patch( query: str, retry: Any = self._default_retry, timeout: Any = self.http_timeout, - **kwargs: Any + **kwargs: Any, ) -> Any: return query_orig(query, retry=retry, timeout=timeout, **kwargs) @@ -105,8 +121,8 @@ def begin_transaction(self) -> Iterator[DBTransaction]: "BEGIN TRANSACTION;", job_config=bigquery.QueryJobConfig( create_session=True, - default_dataset=self.fully_qualified_dataset_name(escape=False) - ) + default_dataset=self.fully_qualified_dataset_name(escape=False), + ), ) self._session_query = bigquery.QueryJobConfig( create_session=False, @@ -115,7 +131,7 @@ def begin_transaction(self) -> Iterator[DBTransaction]: bigquery.query.ConnectionProperty( key="session_id", value=job.session_info.session_id ) - ] + ], ) try: job.result() @@ -124,7 +140,9 @@ def begin_transaction(self) -> Iterator[DBTransaction]: self._session_query = None raise else: - raise dbapi_exceptions.ProgrammingError("Nested transactions not supported on BigQuery") + raise dbapi_exceptions.ProgrammingError( + "Nested transactions not supported on BigQuery" + ) yield self self.commit_transaction() except Exception: @@ -150,7 +168,11 @@ def native_connection(self) -> bigquery.Client: def has_dataset(self) -> bool: try: - self._client.get_dataset(self.fully_qualified_dataset_name(escape=False), retry=self._default_retry, timeout=self.http_timeout) + self._client.get_dataset( + self.fully_qualified_dataset_name(escape=False), + retry=self._default_retry, + timeout=self.http_timeout, + ) return True except gcp_exceptions.NotFound: return False @@ -160,7 +182,7 @@ def create_dataset(self) -> None: self.fully_qualified_dataset_name(escape=False), exists_ok=False, retry=self._default_retry, - timeout=self.http_timeout + timeout=self.http_timeout, ) def drop_dataset(self) -> None: @@ -169,10 +191,12 @@ def drop_dataset(self) -> None: not_found_ok=True, delete_contents=True, retry=self._default_retry, - timeout=self.http_timeout + timeout=self.http_timeout, ) - def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]: + def execute_sql( + self, sql: AnyStr, *args: Any, **kwargs: Any + ) -> Optional[Sequence[Sequence[Any]]]: with self.execute_query(sql, *args, **kwargs) as curr: if not curr.description: return None @@ -187,7 +211,7 @@ def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequen @contextmanager @raise_database_error - def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DBApiCursor]: + def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DBApiCursor]: conn: DbApiConnection = None curr: DBApiCursor = None db_args = args if args else kwargs 
if kwargs else None @@ -226,11 +250,17 @@ def _make_database_exception(cls, ex: Exception) -> Exception: return DatabaseUndefinedRelation(ex) if reason == "invalidQuery" and "was not found" in str(ex) and "Dataset" in str(ex): return DatabaseUndefinedRelation(ex) - if reason == "invalidQuery" and "Not found" in str(ex) and ("Dataset" in str(ex) or "Table" in str(ex)): + if ( + reason == "invalidQuery" + and "Not found" in str(ex) + and ("Dataset" in str(ex) or "Table" in str(ex)) + ): return DatabaseUndefinedRelation(ex) if reason == "accessDenied" and "Dataset" in str(ex) and "not exist" in str(ex): return DatabaseUndefinedRelation(ex) - if reason == "invalidQuery" and ("Unrecognized name" in str(ex) or "cannot be null" in str(ex)): + if reason == "invalidQuery" and ( + "Unrecognized name" in str(ex) or "cannot be null" in str(ex) + ): # unknown column, inserting NULL into required field return DatabaseTerminalException(ex) if reason in BQ_TERMINAL_REASONS: @@ -253,4 +283,7 @@ def is_dbapi_exception(ex: Exception) -> bool: class TransactionsNotImplementedError(NotImplementedError): def __init__(self) -> None: - super().__init__("BigQuery does not support transaction management. Instead you may wrap your SQL script in BEGIN TRANSACTION; ... COMMIT TRANSACTION;") + super().__init__( + "BigQuery does not support transaction management. Instead you may wrap your SQL script" + " in BEGIN TRANSACTION; ... COMMIT TRANSACTION;" + ) diff --git a/dlt/destinations/duckdb/__init__.py b/dlt/destinations/duckdb/__init__.py index c3dfd02db7..4e1fc83abe 100644 --- a/dlt/destinations/duckdb/__init__.py +++ b/dlt/destinations/duckdb/__init__.py @@ -1,17 +1,22 @@ from typing import Type -from dlt.common.schema.schema import Schema -from dlt.common.configuration import with_config, known_sections +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.configuration import known_sections, with_config from dlt.common.configuration.accessors import config -from dlt.common.data_writers.escape import escape_postgres_identifier, escape_duckdb_literal +from dlt.common.data_writers.escape import escape_duckdb_literal, escape_postgres_identifier from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - +from dlt.common.destination.reference import DestinationClientConfiguration, JobClientBase +from dlt.common.schema.schema import Schema from dlt.destinations.duckdb.configuration import DuckDbClientConfiguration -@with_config(spec=DuckDbClientConfiguration, sections=(known_sections.DESTINATION, "duckdb",)) +@with_config( + spec=DuckDbClientConfiguration, + sections=( + known_sections.DESTINATION, + "duckdb", + ), +) def _configure(config: DuckDbClientConfiguration = config.value) -> DuckDbClientConfiguration: return config @@ -40,7 +45,9 @@ def capabilities() -> DestinationCapabilitiesContext: return caps -def client(schema: Schema, initial_config: DestinationClientConfiguration = config.value) -> JobClientBase: +def client( + schema: Schema, initial_config: DestinationClientConfiguration = config.value +) -> JobClientBase: # import client when creating instance so capabilities and config specs can be accessed without dependencies installed from dlt.destinations.duckdb.duck import DuckDbClient diff --git a/dlt/destinations/duckdb/configuration.py b/dlt/destinations/duckdb/configuration.py 
index fc1142cabb..503574f8e6 100644 --- a/dlt/destinations/duckdb/configuration.py +++ b/dlt/destinations/duckdb/configuration.py @@ -1,8 +1,9 @@ import os import threading -from pathvalidate import is_valid_filepath from typing import Any, ClassVar, Final, List, Optional, Tuple +from pathvalidate import is_valid_filepath + from dlt.common import logger from dlt.common.configuration import configspec from dlt.common.configuration.specs import ConnectionStringCredentials @@ -58,6 +59,7 @@ def parse_native_representation(self, native_value: Any) -> None: try: # check if database was passed as explicit connection import duckdb + if isinstance(native_value, duckdb.DuckDBPyConnection): self._conn = native_value self._conn_owner = False @@ -128,7 +130,6 @@ def _path_in_pipeline(self, rel_path: str) -> str: return os.path.join(context.pipeline().working_dir, rel_path) return None - def _path_to_pipeline(self, abspath: str) -> None: from dlt.common.configuration.container import Container from dlt.common.pipeline import PipelineContext @@ -164,7 +165,11 @@ def _path_from_pipeline(self, default_path: str) -> Tuple[str, bool]: pipeline_path = pipeline.get_local_state_val(LOCAL_STATE_KEY) # make sure that path exists if not os.path.exists(pipeline_path): - logger.warning(f"Duckdb attached to pipeline {pipeline.pipeline_name} in path {os.path.relpath(pipeline_path)} was deleted. Attaching to duckdb database '{default_path}' in current folder.") + logger.warning( + f"Duckdb attached to pipeline {pipeline.pipeline_name} in path" + f" {os.path.relpath(pipeline_path)} was deleted. Attaching to duckdb" + f" database '{default_path}' in current folder." + ) else: return pipeline_path, False except KeyError: @@ -179,4 +184,6 @@ class DuckDbClientConfiguration(DestinationClientDwhWithStagingConfiguration): destination_name: Final[str] = "duckdb" # type: ignore credentials: DuckDbCredentials - create_indexes: bool = False # should unique indexes be created, this slows loading down massively + create_indexes: bool = ( + False # should unique indexes be created, this slows loading down massively + ) diff --git a/dlt/destinations/duckdb/duck.py b/dlt/destinations/duckdb/duck.py index 8fd9deba52..3dbf72216d 100644 --- a/dlt/destinations/duckdb/duck.py +++ b/dlt/destinations/duckdb/duck.py @@ -1,18 +1,15 @@ from typing import ClassVar, Dict, Optional -from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.data_types import TDataType -from dlt.common.schema import TColumnSchema, TColumnHint, Schema -from dlt.common.destination.reference import LoadJob, FollowupJob, TLoadJobState +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.destination.reference import FollowupJob, LoadJob, TLoadJobState +from dlt.common.schema import Schema, TColumnHint, TColumnSchema from dlt.common.schema.typing import TTableSchema from dlt.common.storages.file_storage import FileStorage - -from dlt.destinations.insert_job_client import InsertValuesJobClient - from dlt.destinations.duckdb import capabilities -from dlt.destinations.duckdb.sql_client import DuckDbSqlClient from dlt.destinations.duckdb.configuration import DuckDbClientConfiguration - +from dlt.destinations.duckdb.sql_client import DuckDbSqlClient +from dlt.destinations.insert_job_client import InsertValuesJobClient SCT_TO_PGT: Dict[TDataType, str] = { "complex": "JSON", @@ -23,7 +20,7 @@ "timestamp": "TIMESTAMP WITH TIME ZONE", "bigint": "BIGINT", "binary": "BLOB", - "decimal": "DECIMAL(%i,%i)" + "decimal": 
"DECIMAL(%i,%i)", } PGT_TO_SCT: Dict[str, TDataType] = { @@ -35,12 +32,10 @@ "TIMESTAMP WITH TIME ZONE": "timestamp", "BIGINT": "bigint", "BLOB": "binary", - "DECIMAL": "decimal" + "DECIMAL": "decimal", } -HINT_TO_POSTGRES_ATTR: Dict[TColumnHint, str] = { - "unique": "UNIQUE" -} +HINT_TO_POSTGRES_ATTR: Dict[TColumnHint, str] = {"unique": "UNIQUE"} class DuckDbCopyJob(LoadJob, FollowupJob): @@ -56,8 +51,9 @@ def __init__(self, table_name: str, file_path: str, sql_client: DuckDbSqlClient) raise ValueError(file_path) qualified_table_name = sql_client.make_qualified_table_name(table_name) with sql_client.begin_transaction(): - sql_client.execute_sql(f"COPY {qualified_table_name} FROM '{file_path}' ( FORMAT {source_format} );") - + sql_client.execute_sql( + f"COPY {qualified_table_name} FROM '{file_path}' ( FORMAT {source_format} );" + ) def state(self) -> TLoadJobState: return "completed" @@ -65,15 +61,12 @@ def state(self) -> TLoadJobState: def exception(self) -> str: raise NotImplementedError() -class DuckDbClient(InsertValuesJobClient): +class DuckDbClient(InsertValuesJobClient): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: DuckDbClientConfiguration) -> None: - sql_client = DuckDbSqlClient( - config.normalize_dataset_name(schema), - config.credentials - ) + sql_client = DuckDbSqlClient(config.normalize_dataset_name(schema), config.credentials) super().__init__(schema, config, sql_client) self.config: DuckDbClientConfiguration = config self.sql_client: DuckDbSqlClient = sql_client # type: ignore @@ -86,9 +79,15 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> return job def _get_column_def_sql(self, c: TColumnSchema) -> str: - hints_str = " ".join(self.active_hints.get(h, "") for h in self.active_hints.keys() if c.get(h, False) is True) + hints_str = " ".join( + self.active_hints.get(h, "") + for h in self.active_hints.keys() + if c.get(h, False) is True + ) column_name = self.capabilities.escape_identifier(c["name"]) - return f"{column_name} {self._to_db_type(c['data_type'])} {hints_str} {self._gen_not_null(c['nullable'])}" + return ( + f"{column_name} {self._to_db_type(c['data_type'])} {hints_str} {self._gen_not_null(c['nullable'])}" + ) @classmethod def _to_db_type(cls, sc_t: TDataType) -> str: diff --git a/dlt/destinations/duckdb/sql_client.py b/dlt/destinations/duckdb/sql_client.py index 94f9cb38d2..f6dbadd298 100644 --- a/dlt/destinations/duckdb/sql_client.py +++ b/dlt/destinations/duckdb/sql_client.py @@ -1,19 +1,28 @@ -import duckdb - from contextlib import contextmanager from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.destinations.exceptions import DatabaseTerminalException, DatabaseTransientException, DatabaseUndefinedRelation -from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction, DataFrame -from dlt.destinations.sql_client import SqlClientBase, DBApiCursorImpl, raise_database_error, raise_open_connection_error +import duckdb +from dlt.common.destination import DestinationCapabilitiesContext from dlt.destinations.duckdb import capabilities from dlt.destinations.duckdb.configuration import DuckDbBaseCredentials +from dlt.destinations.exceptions import ( + DatabaseTerminalException, + DatabaseTransientException, + DatabaseUndefinedRelation, +) +from dlt.destinations.sql_client import ( + DBApiCursorImpl, + SqlClientBase, + raise_database_error, + 
raise_open_connection_error, +) +from dlt.destinations.typing import DataFrame, DBApi, DBApiCursor, DBTransaction class DuckDBDBApiCursorImpl(DBApiCursorImpl): """Use native BigQuery data frame support if available""" + native_cursor: duckdb.DuckDBPyConnection # type: ignore vector_size: ClassVar[int] = 2048 @@ -21,7 +30,9 @@ def df(self, chunk_size: int = None, **kwargs: Any) -> DataFrame: if chunk_size is None: return self.native_cursor.df(**kwargs) else: - multiple = chunk_size // self.vector_size + (0 if self.vector_size % chunk_size == 0 else 1) + multiple = chunk_size // self.vector_size + ( + 0 if self.vector_size % chunk_size == 0 else 1 + ) df = self.native_cursor.fetch_df_chunk(multiple, **kwargs) if df.shape[0] == 0: return None @@ -30,7 +41,6 @@ def df(self, chunk_size: int = None, **kwargs: Any) -> DataFrame: class DuckDbSqlClient(SqlClientBase[duckdb.DuckDBPyConnection], DBTransaction): - dbapi: ClassVar[DBApi] = duckdb capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() @@ -44,11 +54,11 @@ def open_connection(self) -> duckdb.DuckDBPyConnection: self._conn = self.credentials.borrow_conn(read_only=self.credentials.read_only) # TODO: apply config settings from credentials self._conn.execute("PRAGMA enable_checkpoint_on_shutdown;") - config={ + config = { "search_path": self.fully_qualified_dataset_name(), "TimeZone": "UTC", - "checkpoint_threshold": "1gb" - } + "checkpoint_threshold": "1gb", + } if config: for k, v in config.items(): try: @@ -91,7 +101,9 @@ def rollback_transaction(self) -> None: def native_connection(self) -> duckdb.DuckDBPyConnection: return self._conn - def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]: + def execute_sql( + self, sql: AnyStr, *args: Any, **kwargs: Any + ) -> Optional[Sequence[Sequence[Any]]]: with self.execute_query(sql, *args, **kwargs) as curr: if curr.description is None: return None @@ -130,7 +142,9 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB # return None def fully_qualified_dataset_name(self, escape: bool = True) -> str: - return self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name + return ( + self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name + ) @classmethod def _make_database_exception(cls, ex: Exception) -> Exception: @@ -141,7 +155,15 @@ def _make_database_exception(cls, ex: Exception) -> Exception: raise DatabaseUndefinedRelation(ex) # duckdb raises TypeError on malformed query parameters return DatabaseTransientException(duckdb.ProgrammingError(ex)) - elif isinstance(ex, (duckdb.OperationalError, duckdb.InternalError, duckdb.SyntaxException, duckdb.ParserException)): + elif isinstance( + ex, + ( + duckdb.OperationalError, + duckdb.InternalError, + duckdb.SyntaxException, + duckdb.ParserException, + ), + ): term = cls._maybe_make_terminal_exception_from_data_error(ex) if term: return term diff --git a/dlt/destinations/dummy/__init__.py b/dlt/destinations/dummy/__init__.py index 7131f0109a..53b77a7ae3 100644 --- a/dlt/destinations/dummy/__init__.py +++ b/dlt/destinations/dummy/__init__.py @@ -1,15 +1,20 @@ from typing import Type -from dlt.common.schema.schema import Schema -from dlt.common.configuration import with_config, known_sections +from dlt.common.configuration import known_sections, with_config from dlt.common.configuration.accessors import config from dlt.common.destination import DestinationCapabilitiesContext -from 
dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration - +from dlt.common.destination.reference import DestinationClientConfiguration, JobClientBase +from dlt.common.schema.schema import Schema from dlt.destinations.dummy.configuration import DummyClientConfiguration -@with_config(spec=DummyClientConfiguration, sections=(known_sections.DESTINATION, "dummy",)) +@with_config( + spec=DummyClientConfiguration, + sections=( + known_sections.DESTINATION, + "dummy", + ), +) def _configure(config: DummyClientConfiguration = config.value) -> DummyClientConfiguration: return config @@ -32,7 +37,9 @@ def capabilities() -> DestinationCapabilitiesContext: return caps -def client(schema: Schema, initial_config: DestinationClientConfiguration = config.value) -> JobClientBase: +def client( + schema: Schema, initial_config: DestinationClientConfiguration = config.value +) -> JobClientBase: # import client when creating instance so capabilities and config specs can be accessed without dependencies installed from dlt.destinations.dummy.dummy import DummyClient diff --git a/dlt/destinations/dummy/configuration.py b/dlt/destinations/dummy/configuration.py index 91a47f0b52..4c80da9659 100644 --- a/dlt/destinations/dummy/configuration.py +++ b/dlt/destinations/dummy/configuration.py @@ -1,11 +1,13 @@ from dlt.common.configuration import configspec from dlt.common.destination import TLoaderFileFormat -from dlt.common.destination.reference import DestinationClientConfiguration, CredentialsConfiguration +from dlt.common.destination.reference import ( + CredentialsConfiguration, + DestinationClientConfiguration, +) @configspec class DummyClientCredentials(CredentialsConfiguration): - def __str__(self) -> str: return "/dev/null" diff --git a/dlt/destinations/dummy/dummy.py b/dlt/destinations/dummy/dummy.py index 9162f5c733..4f60d04036 100644 --- a/dlt/destinations/dummy/dummy.py +++ b/dlt/destinations/dummy/dummy.py @@ -1,20 +1,28 @@ import random from copy import copy from types import TracebackType -from typing import ClassVar, Dict, Optional, Sequence, Type, Iterable, List +from typing import ClassVar, Dict, Iterable, List, Optional, Sequence, Type from dlt.common import pendulum -from dlt.common.schema import Schema, TTableSchema, TSchemaTables +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.destination.reference import ( + FollowupJob, + JobClientBase, + LoadJob, + NewLoadJob, + TLoadJobState, +) +from dlt.common.schema import Schema, TSchemaTables, TTableSchema from dlt.common.schema.typing import TWriteDisposition from dlt.common.storages import FileStorage -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import FollowupJob, NewLoadJob, TLoadJobState, LoadJob, JobClientBase - -from dlt.destinations.exceptions import (LoadJobNotExistsException, LoadJobInvalidStateTransitionException, - DestinationTerminalException, DestinationTransientException) - from dlt.destinations.dummy import capabilities from dlt.destinations.dummy.configuration import DummyClientConfiguration +from dlt.destinations.exceptions import ( + DestinationTerminalException, + DestinationTransientException, + LoadJobInvalidStateTransitionException, + LoadJobNotExistsException, +) class LoadDummyJob(LoadJob, FollowupJob): @@ -31,7 +39,6 @@ def __init__(self, file_name: str, config: DummyClientConfiguration) -> None: if s == "retry": raise DestinationTransientException(self._exception) - def state(self) -> TLoadJobState: # 
this should poll the server for a job status, here we simulate various outcomes if self._status == "running": @@ -84,10 +91,14 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: def is_storage_initialized(self) -> bool: return True - def update_storage_schema(self, only_tables: Iterable[str] = None, expected_update: TSchemaTables = None) -> Optional[TSchemaTables]: + def update_storage_schema( + self, only_tables: Iterable[str] = None, expected_update: TSchemaTables = None + ) -> Optional[TSchemaTables]: applied_update = super().update_storage_schema(only_tables, expected_update) if self.config.fail_schema_update: - raise DestinationTransientException("Raise on schema update due to fail_schema_update config flag") + raise DestinationTransientException( + "Raise on schema update due to fail_schema_update config flag" + ) return applied_update def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: @@ -109,7 +120,9 @@ def restore_file_load(self, file_path: str) -> LoadJob: raise LoadJobNotExistsException(job_id) return JOBS[job_id] - def create_table_chain_completed_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def create_table_chain_completed_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[NewLoadJob]: """Creates a list of followup jobs that should be executed after a table chain is completed""" return [] @@ -119,11 +132,10 @@ def complete_load(self, load_id: str) -> None: def __enter__(self) -> "DummyClient": return self - def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None: + def __exit__( + self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType + ) -> None: pass def _create_job(self, job_id: str) -> LoadDummyJob: - return LoadDummyJob( - job_id, - config=self.config - ) + return LoadDummyJob(job_id, config=self.config) diff --git a/dlt/destinations/exceptions.py b/dlt/destinations/exceptions.py index f0fe32f950..c4bd92d6bb 100644 --- a/dlt/destinations/exceptions.py +++ b/dlt/destinations/exceptions.py @@ -1,6 +1,12 @@ from typing import Sequence -from dlt.common.exceptions import DestinationTerminalException, DestinationTransientException, DestinationUndefinedEntity, DestinationException + from dlt.common.destination.reference import TLoadJobState +from dlt.common.exceptions import ( + DestinationException, + DestinationTerminalException, + DestinationTransientException, + DestinationUndefinedEntity, +) class DatabaseException(DestinationException): @@ -25,32 +31,49 @@ def __init__(self, dbapi_exception: Exception) -> None: class DestinationConnectionError(DestinationTransientException): - def __init__(self, client_type: str, dataset_name: str, reason: str, inner_exc: Exception) -> None: + def __init__( + self, client_type: str, dataset_name: str, reason: str, inner_exc: Exception + ) -> None: self.client_type = client_type self.dataset_name = dataset_name self.inner_exc = inner_exc - super().__init__(f"Connection with {client_type} to dataset name {dataset_name} failed. Please check if you configured the credentials at all and provided the right credentials values. You can be also denied access or your internet connection may be down. The actual reason given is: {reason}") + super().__init__( + f"Connection with {client_type} to dataset name {dataset_name} failed. Please check if" + " you configured the credentials at all and provided the right credentials values. 
You" + " can be also denied access or your internet connection may be down. The actual reason" + f" given is: {reason}" + ) + class LoadClientNotConnected(DestinationTransientException): def __init__(self, client_type: str, dataset_name: str) -> None: self.client_type = client_type self.dataset_name = dataset_name - super().__init__(f"Connection with {client_type} to dataset {dataset_name} is closed. Open the connection with 'client.open_connection' or with the 'with client:' statement") + super().__init__( + f"Connection with {client_type} to dataset {dataset_name} is closed. Open the" + " connection with 'client.open_connection' or with the 'with client:' statement" + ) class DestinationSchemaWillNotUpdate(DestinationTerminalException): def __init__(self, table_name: str, columns: Sequence[str], msg: str) -> None: self.table_name = table_name self.columns = columns - super().__init__(f"Schema for table {table_name} column(s) {columns} will not update: {msg}") + super().__init__( + f"Schema for table {table_name} column(s) {columns} will not update: {msg}" + ) class DestinationSchemaTampered(DestinationTerminalException): def __init__(self, schema_name: str, version_hash: str, stored_version_hash: str) -> None: self.version_hash = version_hash self.stored_version_hash = stored_version_hash - super().__init__(f"Schema {schema_name} content was changed - by a loader or by destination code - from the moment it was retrieved by load package. " - f"Such schema cannot reliably be updated or saved. Current version hash: {version_hash} != stored version hash {stored_version_hash}") + super().__init__( + f"Schema {schema_name} content was changed - by a loader or by destination code - from" + " the moment it was retrieved by load package. Such schema cannot reliably be updated" + f" or saved. Current version hash: {version_hash} != stored version hash" + f" {stored_version_hash}" + ) class LoadJobNotExistsException(DestinationTerminalException): @@ -60,7 +83,9 @@ def __init__(self, job_id: str) -> None: class LoadJobTerminalException(DestinationTerminalException): def __init__(self, file_path: str, message: str) -> None: - super().__init__(f"Job with id/file name {file_path} encountered unrecoverable problem: {message}") + super().__init__( + f"Job with id/file name {file_path} encountered unrecoverable problem: {message}" + ) class LoadJobUnknownTableException(DestinationTerminalException): @@ -78,17 +103,28 @@ def __init__(self, from_state: TLoadJobState, to_state: TLoadJobState) -> None: class LoadJobFileTooBig(DestinationTerminalException): def __init__(self, file_name: str, max_size: int) -> None: - super().__init__(f"File {file_name} exceeds {max_size} and cannot be loaded. Split the file and try again.") + super().__init__( + f"File {file_name} exceeds {max_size} and cannot be loaded. Split the file and try" + " again." + ) class MergeDispositionException(DestinationTerminalException): - def __init__(self, dataset_name: str, staging_dataset_name: str, tables: Sequence[str], reason: str) -> None: + def __init__( + self, dataset_name: str, staging_dataset_name: str, tables: Sequence[str], reason: str + ) -> None: self.dataset_name = dataset_name self.staging_dataset_name = staging_dataset_name self.tables = tables self.reason = reason - msg = f"Merge sql job for dataset name {dataset_name}, staging dataset name {staging_dataset_name} COULD NOT BE GENERATED. Merge will not be performed. " - msg += f"Data for the following tables ({tables}) is loaded to staging dataset. 
You may need to write your own materialization. The reason is:\n" + msg = ( + f"Merge sql job for dataset name {dataset_name}, staging dataset name" + f" {staging_dataset_name} COULD NOT BE GENERATED. Merge will not be performed. " + ) + msg += ( + f"Data for the following tables ({tables}) is loaded to staging dataset. You may need" + " to write your own materialization. The reason is:\n" + ) msg += reason super().__init__(msg) diff --git a/dlt/destinations/filesystem/__init__.py b/dlt/destinations/filesystem/__init__.py index 2b0f7bf6a2..c73e7e7484 100644 --- a/dlt/destinations/filesystem/__init__.py +++ b/dlt/destinations/filesystem/__init__.py @@ -1,16 +1,26 @@ from typing import Type -from dlt.common.schema.schema import Schema -from dlt.common.configuration import with_config, known_sections +from dlt.common.configuration import known_sections, with_config from dlt.common.configuration.accessors import config from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import JobClientBase, DestinationClientDwhWithStagingConfiguration - +from dlt.common.destination.reference import ( + DestinationClientDwhWithStagingConfiguration, + JobClientBase, +) +from dlt.common.schema.schema import Schema from dlt.destinations.filesystem.configuration import FilesystemClientConfiguration -@with_config(spec=FilesystemClientConfiguration, sections=(known_sections.DESTINATION, "filesystem",)) -def _configure(config: FilesystemClientConfiguration = config.value) -> FilesystemClientConfiguration: +@with_config( + spec=FilesystemClientConfiguration, + sections=( + known_sections.DESTINATION, + "filesystem", + ), +) +def _configure( + config: FilesystemClientConfiguration = config.value, +) -> FilesystemClientConfiguration: return config @@ -18,7 +28,9 @@ def capabilities() -> DestinationCapabilitiesContext: return DestinationCapabilitiesContext.generic_capabilities("jsonl") -def client(schema: Schema, initial_config: DestinationClientDwhWithStagingConfiguration = config.value) -> JobClientBase: +def client( + schema: Schema, initial_config: DestinationClientDwhWithStagingConfiguration = config.value +) -> JobClientBase: # import client when creating instance so capabilities and config specs can be accessed without dependencies installed from dlt.destinations.filesystem.filesystem import FilesystemClient diff --git a/dlt/destinations/filesystem/configuration.py b/dlt/destinations/filesystem/configuration.py index 0d43f23341..88f13741c0 100644 --- a/dlt/destinations/filesystem/configuration.py +++ b/dlt/destinations/filesystem/configuration.py @@ -1,19 +1,24 @@ +from typing import TYPE_CHECKING, Final, Optional, Type, Union from urllib.parse import urlparse -from typing import Final, Type, Optional, Union, TYPE_CHECKING - from dlt.common.configuration import configspec, resolve_type -from dlt.common.destination.reference import CredentialsConfiguration, DestinationClientStagingConfiguration -from dlt.common.configuration.specs import GcpServiceAccountCredentials, AwsCredentials, GcpOAuthCredentials -from dlt.common.utils import digest128 from dlt.common.configuration.exceptions import ConfigurationValueError - +from dlt.common.configuration.specs import ( + AwsCredentials, + GcpOAuthCredentials, + GcpServiceAccountCredentials, +) +from dlt.common.destination.reference import ( + CredentialsConfiguration, + DestinationClientStagingConfiguration, +) +from dlt.common.utils import digest128 PROTOCOL_CREDENTIALS = { "gs": Union[GcpServiceAccountCredentials, 
GcpOAuthCredentials], "gcs": Union[GcpServiceAccountCredentials, GcpOAuthCredentials], "gdrive": GcpOAuthCredentials, - "s3": AwsCredentials + "s3": AwsCredentials, } @@ -31,13 +36,16 @@ def protocol(self) -> str: def on_resolved(self) -> None: url = urlparse(self.bucket_url) if not url.path and not url.netloc: - raise ConfigurationValueError("File path or netloc missing. Field bucket_url of FilesystemClientConfiguration must contain valid url with a path or host:password component.") + raise ConfigurationValueError( + "File path or netloc missing. Field bucket_url of FilesystemClientConfiguration" + " must contain valid url with a path or host:password component." + ) # this is just a path in local file system if url.path == self.bucket_url: url = url._replace(scheme="file") self.bucket_url = url.geturl() - @resolve_type('credentials') + @resolve_type("credentials") def resolve_credentials_type(self) -> Type[CredentialsConfiguration]: # use known credentials or empty credentials for unknown protocol return PROTOCOL_CREDENTIALS.get(self.protocol) or Optional[CredentialsConfiguration] # type: ignore[return-value] @@ -60,6 +68,7 @@ def __str__(self) -> str: return self.bucket_url if TYPE_CHECKING: + def __init__( self, destination_name: str = None, diff --git a/dlt/destinations/filesystem/filesystem.py b/dlt/destinations/filesystem/filesystem.py index 45b3504bb6..b755b61404 100644 --- a/dlt/destinations/filesystem/filesystem.py +++ b/dlt/destinations/filesystem/filesystem.py @@ -1,57 +1,72 @@ -import posixpath import os +import posixpath from types import TracebackType -from typing import ClassVar, List, Optional, Type, Iterable, Set +from typing import ClassVar, Iterable, List, Optional, Set, Type + from fsspec import AbstractFileSystem from dlt.common import logger -from dlt.common.schema import Schema, TSchemaTables, TTableSchema -from dlt.common.storages import FileStorage from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import NewLoadJob, TLoadJobState, LoadJob, JobClientBase, FollowupJob -from dlt.destinations.job_impl import EmptyLoadJob +from dlt.common.destination.reference import ( + FollowupJob, + JobClientBase, + LoadJob, + NewLoadJob, + TLoadJobState, +) +from dlt.common.schema import Schema, TSchemaTables, TTableSchema +from dlt.common.storages import FileStorage, LoadStorage +from dlt.destinations import path_utils from dlt.destinations.filesystem import capabilities from dlt.destinations.filesystem.configuration import FilesystemClientConfiguration from dlt.destinations.filesystem.filesystem_client import client_from_config -from dlt.common.storages import LoadStorage -from dlt.destinations.job_impl import NewReferenceJob -from dlt.destinations import path_utils +from dlt.destinations.job_impl import EmptyLoadJob, NewReferenceJob class LoadFilesystemJob(LoadJob): def __init__( - self, - local_path: str, - dataset_path: str, - *, - config: FilesystemClientConfiguration, - schema_name: str, - load_id: str + self, + local_path: str, + dataset_path: str, + *, + config: FilesystemClientConfiguration, + schema_name: str, + load_id: str, ) -> None: file_name = FileStorage.get_file_name_from_file_path(local_path) self.config = config self.dataset_path = dataset_path - self.destination_file_name = LoadFilesystemJob.make_destination_filename(config.layout, file_name, schema_name, load_id) + self.destination_file_name = LoadFilesystemJob.make_destination_filename( + config.layout, file_name, schema_name, load_id + ) 
super().__init__(file_name) fs_client, _ = client_from_config(config) - self.destination_file_name = LoadFilesystemJob.make_destination_filename(config.layout, file_name, schema_name, load_id) + self.destination_file_name = LoadFilesystemJob.make_destination_filename( + config.layout, file_name, schema_name, load_id + ) item = self.make_remote_path() logger.info("PUT file {item}") fs_client.put_file(local_path, item) @staticmethod - def make_destination_filename(layout: str, file_name: str, schema_name: str, load_id: str) -> str: + def make_destination_filename( + layout: str, file_name: str, schema_name: str, load_id: str + ) -> str: job_info = LoadStorage.parse_job_file_name(file_name) - return path_utils.create_path(layout, - schema_name=schema_name, - table_name=job_info.table_name, - load_id=load_id, - file_id=job_info.file_id, - ext=job_info.file_format) + return path_utils.create_path( + layout, + schema_name=schema_name, + table_name=job_info.table_name, + load_id=load_id, + file_id=job_info.file_id, + ext=job_info.file_format, + ) def make_remote_path(self) -> str: - return f"{self.config.protocol}://{posixpath.join(self.dataset_path, self.destination_file_name)}" + return ( + f"{self.config.protocol}://{posixpath.join(self.dataset_path, self.destination_file_name)}" + ) def state(self) -> TLoadJobState: return "completed" @@ -64,7 +79,9 @@ class FollowupFilesystemJob(FollowupJob, LoadFilesystemJob): def create_followup_jobs(self, next_state: str) -> List[NewLoadJob]: jobs = super().create_followup_jobs(next_state) if next_state == "completed": - ref_job = NewReferenceJob(file_name=self.file_name(), status="running", remote_path=self.make_remote_path()) + ref_job = NewReferenceJob( + file_name=self.file_name(), status="running", remote_path=self.make_remote_path() + ) jobs.append(ref_job) return jobs @@ -98,7 +115,9 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: # print(f"TRUNCATE {truncated_dirs}") truncate_prefixes: Set[str] = set() for table in truncate_tables: - table_prefix = self.table_prefix_layout.format(schema_name=self.schema.name, table_name=table) + table_prefix = self.table_prefix_layout.format( + schema_name=self.schema.name, table_name=table + ) truncate_prefixes.add(posixpath.join(self.dataset_path, table_prefix)) # print(f"TRUNCATE PREFIXES {truncate_prefixes}") @@ -120,9 +139,14 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: # print(f"DEL {item}") self.fs_client.rm_file(item) except FileNotFoundError: - logger.info(f"Directory or path to truncate tables {truncate_dir} does not exist but it should be created previously!") - - def update_storage_schema(self, only_tables: Iterable[str] = None, expected_update: TSchemaTables = None) -> TSchemaTables: + logger.info( + f"Directory or path to truncate tables {truncate_dir} does not exist but it" + " should be created previously!" 
+ ) + + def update_storage_schema( + self, only_tables: Iterable[str] = None, expected_update: TSchemaTables = None + ) -> TSchemaTables: # create destination dirs for all tables dirs_to_create = self._get_table_dirs(only_tables or self.schema.tables.keys()) for directory in dirs_to_create: @@ -133,7 +157,9 @@ def _get_table_dirs(self, table_names: Iterable[str]) -> Set[str]: """Gets unique directories where table data is stored.""" table_dirs: Set[str] = set() for table_name in table_names: - table_prefix = self.table_prefix_layout.format(schema_name=self.schema.name, table_name=table_name) + table_prefix = self.table_prefix_layout.format( + schema_name=self.schema.name, table_name=table_name + ) destination_dir = posixpath.join(self.dataset_path, table_prefix) # extract the path component table_dirs.add(os.path.dirname(destination_dir)) @@ -149,7 +175,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> self.dataset_path, config=self.config, schema_name=self.schema.name, - load_id=load_id + load_id=load_id, ) def restore_file_load(self, file_path: str) -> LoadJob: @@ -164,5 +190,7 @@ def complete_load(self, load_id: str) -> None: def __enter__(self) -> "FilesystemClient": return self - def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None: + def __exit__( + self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType + ) -> None: pass diff --git a/dlt/destinations/filesystem/filesystem_client.py b/dlt/destinations/filesystem/filesystem_client.py index e454d5408e..ed6a4a431e 100644 --- a/dlt/destinations/filesystem/filesystem_client.py +++ b/dlt/destinations/filesystem/filesystem_client.py @@ -1,16 +1,14 @@ -from typing import cast, Tuple +from typing import Tuple, cast -from fsspec.core import url_to_fs from fsspec import AbstractFileSystem +from fsspec.core import url_to_fs +from dlt import version +from dlt.common.configuration.specs import AwsCredentials, CredentialsWithDefault, GcpCredentials from dlt.common.exceptions import MissingDependencyException from dlt.common.typing import DictStrAny -from dlt.common.configuration.specs import CredentialsWithDefault, GcpCredentials, AwsCredentials - from dlt.destinations.filesystem.configuration import FilesystemClientConfiguration -from dlt import version - def client_from_config(config: FilesystemClientConfiguration) -> Tuple[AbstractFileSystem, str]: proto = config.protocol @@ -18,15 +16,20 @@ def client_from_config(config: FilesystemClientConfiguration) -> Tuple[AbstractF if proto == "s3": credentials = cast(AwsCredentials, config.credentials) fs_kwargs.update(credentials.to_s3fs_credentials()) - elif proto in ['gcs', 'gs']: + elif proto in ["gcs", "gs"]: assert isinstance(config.credentials, GcpCredentials) # Default credentials are handled by gcsfs - if isinstance(config.credentials, CredentialsWithDefault) and config.credentials.has_default_credentials(): - fs_kwargs['token'] = None + if ( + isinstance(config.credentials, CredentialsWithDefault) + and config.credentials.has_default_credentials() + ): + fs_kwargs["token"] = None else: - fs_kwargs['token'] = dict(config.credentials) - fs_kwargs['project'] = config.credentials.project_id + fs_kwargs["token"] = dict(config.credentials) + fs_kwargs["project"] = config.credentials.project_id try: return url_to_fs(config.bucket_url, **fs_kwargs) # type: ignore[no-any-return] except ModuleNotFoundError as e: - raise MissingDependencyException("filesystem destination", 
[f"{version.DLT_PKG_NAME}[{proto}]"]) from e + raise MissingDependencyException( + "filesystem destination", [f"{version.DLT_PKG_NAME}[{proto}]"] + ) from e diff --git a/dlt/destinations/insert_job_client.py b/dlt/destinations/insert_job_client.py index 5b736dab94..42dd81f79b 100644 --- a/dlt/destinations/insert_job_client.py +++ b/dlt/destinations/insert_job_client.py @@ -1,14 +1,13 @@ -import os import abc +import os from typing import Any, Iterator, List -from dlt.common.destination.reference import LoadJob, FollowupJob, TLoadJobState +from dlt.common.destination.reference import FollowupJob, LoadJob, TLoadJobState from dlt.common.schema.typing import TTableSchema from dlt.common.storages import FileStorage - -from dlt.destinations.sql_client import SqlClientBase -from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.job_client_impl import SqlJobClientWithStaging +from dlt.destinations.job_impl import EmptyLoadJob +from dlt.destinations.sql_client import SqlClientBase class InsertValuesLoadJob(LoadJob, FollowupJob): @@ -17,7 +16,9 @@ def __init__(self, table_name: str, file_path: str, sql_client: SqlClientBase[An self._sql_client = sql_client # insert file content immediately with self._sql_client.begin_transaction(): - for fragments in self._insert(sql_client.make_qualified_table_name(table_name), file_path): + for fragments in self._insert( + sql_client.make_qualified_table_name(table_name), file_path + ): self._sql_client.execute_fragments(fragments) def state(self) -> TLoadJobState: @@ -67,7 +68,6 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st class InsertValuesJobClient(SqlJobClientWithStaging): - def restore_file_load(self, file_path: str) -> LoadJob: """Returns a completed SqlLoadJob or InsertValuesJob @@ -101,4 +101,3 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> # def _get_out_table_constrains_sql(self, t: TTableSchema) -> str: # # set non unique indexes # pass - diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 38adeb1bc5..ba62c5914d 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -1,30 +1,62 @@ -import os -from abc import abstractmethod import base64 import binascii import contextlib -from copy import copy import datetime # noqa: 251 -from types import TracebackType -from typing import Any, ClassVar, List, NamedTuple, Optional, Sequence, Tuple, Type, Iterable, Iterator, ContextManager -import zlib +import os import re - -from dlt.common import json, pendulum, logger +import zlib +from abc import abstractmethod +from copy import copy +from types import TracebackType +from typing import ( + Any, + ClassVar, + ContextManager, + Iterable, + Iterator, + List, + NamedTuple, + Optional, + Sequence, + Tuple, + Type, +) + +from dlt.common import json, logger, pendulum from dlt.common.data_types import TDataType -from dlt.common.schema.typing import COLUMN_HINTS, TColumnSchemaBase, TTableSchema, TWriteDisposition +from dlt.common.destination.reference import ( + CredentialsConfiguration, + DestinationClientConfiguration, + DestinationClientDwhConfiguration, + DestinationClientDwhWithStagingConfiguration, + FollowupJob, + JobClientBase, + LoadJob, + NewLoadJob, + TLoadJobState, + WithStagingDataset, +) +from dlt.common.schema import Schema, TColumnSchema, TSchemaTables, TTableSchemaColumns +from dlt.common.schema.typing import ( + COLUMN_HINTS, + LOADS_TABLE_NAME, + VERSION_TABLE_NAME, + TColumnSchemaBase, + 
TTableSchema, + TWriteDisposition, +) from dlt.common.schema.utils import add_missing_hints from dlt.common.storages import FileStorage -from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables -from dlt.common.destination.reference import DestinationClientConfiguration, DestinationClientDwhConfiguration, DestinationClientDwhWithStagingConfiguration, NewLoadJob, WithStagingDataset, TLoadJobState, LoadJob, JobClientBase, FollowupJob, CredentialsConfiguration from dlt.common.utils import concat_strings_with_limit -from dlt.destinations.exceptions import DatabaseUndefinedRelation, DestinationSchemaTampered, DestinationSchemaWillNotUpdate +from dlt.destinations.exceptions import ( + DatabaseUndefinedRelation, + DestinationSchemaTampered, + DestinationSchemaWillNotUpdate, +) from dlt.destinations.job_impl import EmptyLoadJobWithoutFollowup, NewReferenceJob +from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.sql_jobs import SqlMergeJob, SqlStagingCopyJob -from dlt.common.schema.typing import LOADS_TABLE_NAME, VERSION_TABLE_NAME - from dlt.destinations.typing import TNativeConn -from dlt.destinations.sql_client import SqlClientBase class StorageSchemaInfo(NamedTuple): @@ -35,12 +67,10 @@ class StorageSchemaInfo(NamedTuple): inserted_at: datetime.datetime schema: str + # this should suffice for now -DDL_COMMANDS = [ - "ALTER", - "CREATE", - "DROP" -] +DDL_COMMANDS = ["ALTER", "CREATE", "DROP"] + class SqlLoadJob(LoadJob): """A job executing sql statement, without followup trait""" @@ -52,7 +82,10 @@ def __init__(self, file_path: str, sql_client: SqlClientBase[Any]) -> None: sql = f.read() # if we detect ddl transactions, only execute transaction if supported by client - if not self._string_containts_ddl_queries(sql) or sql_client.capabilities.supports_ddl_transactions: + if ( + not self._string_containts_ddl_queries(sql) + or sql_client.capabilities.supports_ddl_transactions + ): # with sql_client.begin_transaction(): sql_client.execute_sql(sql) else: @@ -78,7 +111,13 @@ def is_sql_job(file_path: str) -> bool: class CopyRemoteFileLoadJob(LoadJob, FollowupJob): - def __init__(self, table: TTableSchema, file_path: str, sql_client: SqlClientBase[Any], staging_credentials: Optional[CredentialsConfiguration] = None) -> None: + def __init__( + self, + table: TTableSchema, + file_path: str, + sql_client: SqlClientBase[Any], + staging_credentials: Optional[CredentialsConfiguration] = None, + ) -> None: super().__init__(FileStorage.get_file_name_from_file_path(file_path)) self._sql_client = sql_client self._staging_credentials = staging_credentials @@ -95,10 +134,16 @@ def state(self) -> TLoadJobState: class SqlJobClientBase(JobClientBase): - - VERSION_TABLE_SCHEMA_COLUMNS: ClassVar[str] = "version_hash, schema_name, version, engine_version, inserted_at, schema" - - def __init__(self, schema: Schema, config: DestinationClientConfiguration, sql_client: SqlClientBase[TNativeConn]) -> None: + VERSION_TABLE_SCHEMA_COLUMNS: ClassVar[str] = ( + "version_hash, schema_name, version, engine_version, inserted_at, schema" + ) + + def __init__( + self, + schema: Schema, + config: DestinationClientConfiguration, + sql_client: SqlClientBase[TNativeConn], + ) -> None: super().__init__(schema, config) self.sql_client = sql_client assert isinstance(config, DestinationClientDwhConfiguration) @@ -115,17 +160,25 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: def is_storage_initialized(self) -> bool: return self.sql_client.has_dataset() - 
def update_storage_schema(self, only_tables: Iterable[str] = None, expected_update: TSchemaTables = None) -> Optional[TSchemaTables]: + def update_storage_schema( + self, only_tables: Iterable[str] = None, expected_update: TSchemaTables = None + ) -> Optional[TSchemaTables]: super().update_storage_schema(only_tables, expected_update) applied_update: TSchemaTables = {} schema_info = self.get_schema_by_hash(self.schema.stored_version_hash) if schema_info is None: - logger.info(f"Schema with hash {self.schema.stored_version_hash} not found in the storage. upgrading") + logger.info( + f"Schema with hash {self.schema.stored_version_hash} not found in the storage." + " upgrading" + ) with self.maybe_ddl_transaction(): applied_update = self._execute_schema_update_sql(only_tables) else: - logger.info(f"Schema with hash {self.schema.stored_version_hash} inserted at {schema_info.inserted_at} found in storage, no upgrade required") + logger.info( + f"Schema with hash {self.schema.stored_version_hash} inserted at" + f" {schema_info.inserted_at} found in storage, no upgrade required" + ) return applied_update def drop_tables(self, *tables: str, replace_schema: bool = True) -> None: @@ -161,13 +214,17 @@ def _create_staging_copy_job(self, table_chain: Sequence[TTableSchema]) -> NewLo def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: return self._create_staging_copy_job(table_chain) - def create_table_chain_completed_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def create_table_chain_completed_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[NewLoadJob]: jobs = super().create_table_chain_completed_followup_jobs(table_chain) """Creates a list of followup jobs that should be executed after a table chain is completed""" write_disposition = table_chain[0]["write_disposition"] if write_disposition == "merge": jobs.append(self._create_merge_job(table_chain)) - elif write_disposition == "replace" and self.config.replace_strategy == "insert-from-staging": + elif ( + write_disposition == "replace" and self.config.replace_strategy == "insert-from-staging" + ): jobs.append(self._create_staging_copy_job(table_chain)) elif write_disposition == "replace" and self.config.replace_strategy == "staging-optimized": jobs.append(self._create_optimized_replace_job(table_chain)) @@ -200,19 +257,25 @@ def complete_load(self, load_id: str) -> None: name = self.sql_client.make_qualified_table_name(self.schema.loads_table_name) now_ts = pendulum.now() self.sql_client.execute_sql( - f"INSERT INTO {name}(load_id, schema_name, status, inserted_at, schema_version_hash) VALUES(%s, %s, %s, %s, %s);", - load_id, self.schema.name, 0, now_ts, self.schema.version_hash + f"INSERT INTO {name}(load_id, schema_name, status, inserted_at, schema_version_hash)" + " VALUES(%s, %s, %s, %s, %s);", + load_id, + self.schema.name, + 0, + now_ts, + self.schema.version_hash, ) def __enter__(self) -> "SqlJobClientBase": self.sql_client.open_connection() return self - def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None: + def __exit__( + self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType + ) -> None: self.sql_client.close_connection() def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: - def _null_to_bool(v: str) -> bool: if v == "NO": return False @@ -223,7 +286,9 @@ def _null_to_bool(v: str) -> bool: fields = ["column_name", "data_type", 
"is_nullable"] if self.capabilities.schema_supports_numeric_precision: fields += ["numeric_precision", "numeric_scale"] - db_params = self.sql_client.make_qualified_table_name(table_name, escape=False).split(".", 3) + db_params = self.sql_client.make_qualified_table_name(table_name, escape=False).split( + ".", 3 + ) query = f""" SELECT {",".join(fields)} FROM INFORMATION_SCHEMA.COLUMNS @@ -240,7 +305,9 @@ def _null_to_bool(v: str) -> bool: return False, schema_table # TODO: pull more data to infer indexes, PK and uniques attributes/constraints for c in rows: - numeric_precision = c[3] if self.capabilities.schema_supports_numeric_precision else None + numeric_precision = ( + c[3] if self.capabilities.schema_supports_numeric_precision else None + ) numeric_scale = c[4] if self.capabilities.schema_supports_numeric_precision else None schema_c: TColumnSchemaBase = { "name": c[0], @@ -257,12 +324,17 @@ def _to_db_type(cls, schema_type: TDataType) -> str: @classmethod @abstractmethod - def _from_db_type(cls, db_type: str, precision: Optional[int], scale: Optional[int]) -> TDataType: + def _from_db_type( + cls, db_type: str, precision: Optional[int], scale: Optional[int] + ) -> TDataType: pass def get_newest_schema_from_storage(self) -> StorageSchemaInfo: name = self.sql_client.make_qualified_table_name(self.schema.version_table_name) - query = f"SELECT {self.VERSION_TABLE_SCHEMA_COLUMNS} FROM {name} WHERE schema_name = %s ORDER BY inserted_at DESC;" + query = ( + f"SELECT {self.VERSION_TABLE_SCHEMA_COLUMNS} FROM {name} WHERE schema_name = %s ORDER" + " BY inserted_at DESC;" + ) return self._row_to_schema_info(query, self.schema.name) def get_schema_by_hash(self, version_hash: str) -> StorageSchemaInfo: @@ -274,12 +346,16 @@ def _execute_schema_update_sql(self, only_tables: Iterable[str]) -> TSchemaTable sql_scripts, schema_update = self._build_schema_update_sql(only_tables) # stay within max query size when doing DDL. some db backends use bytes not characters so decrease limit by half # assuming that most of the characters in DDL encode into single bytes - for sql_fragment in concat_strings_with_limit(sql_scripts, "\n", self.capabilities.max_query_length // 2): + for sql_fragment in concat_strings_with_limit( + sql_scripts, "\n", self.capabilities.max_query_length // 2 + ): self.sql_client.execute_sql(sql_fragment) self._update_schema_in_storage(self.schema) return schema_update - def _build_schema_update_sql(self, only_tables: Iterable[str]) -> Tuple[List[str], TSchemaTables]: + def _build_schema_update_sql( + self, only_tables: Iterable[str] + ) -> Tuple[List[str], TSchemaTables]: """Generates CREATE/ALTER sql for tables that differ between the destination and in client's Schema. This method compares all or `only_tables` defined in self.schema to the respective tables in the destination. It detects only new tables and new columns. 
@@ -315,7 +391,9 @@ def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema]) -> List[str """Make one or more ADD COLUMN sql clauses to be joined in ALTER TABLE statement(s)""" return [f"ADD COLUMN {self._get_column_def_sql(c)}" for c in new_columns] - def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool) -> List[str]: + def _get_table_update_sql( + self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool + ) -> List[str]: # build sql canonical_name = self.sql_client.make_qualified_table_name(table_name) sql_result: List[str] = [] @@ -333,15 +411,23 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSc sql_result.append(sql_base + column_sql.join(add_column_statements)) else: # build ALTER as separate statement for each column (redshift limitation) - sql_result.extend([sql_base + col_statement for col_statement in add_column_statements]) + sql_result.extend( + [sql_base + col_statement for col_statement in add_column_statements] + ) # scan columns to get hints if generate_alter: # no hints may be specified on added columns for hint in COLUMN_HINTS: if any(c.get(hint, False) is True for c in new_columns): - hint_columns = [self.capabilities.escape_identifier(c["name"]) for c in new_columns if c.get(hint, False)] - raise DestinationSchemaWillNotUpdate(canonical_name, hint_columns, f"{hint} requested after table was created") + hint_columns = [ + self.capabilities.escape_identifier(c["name"]) + for c in new_columns + if c.get(hint, False) + ] + raise DestinationSchemaWillNotUpdate( + canonical_name, hint_columns, f"{hint} requested after table was created" + ) return sql_result @abstractmethod @@ -352,14 +438,16 @@ def _get_column_def_sql(self, c: TColumnSchema) -> str: def _gen_not_null(v: bool) -> str: return "NOT NULL" if not v else "" - def _create_table_update(self, table_name: str, storage_columns: TTableSchemaColumns) -> Sequence[TColumnSchema]: + def _create_table_update( + self, table_name: str, storage_columns: TTableSchemaColumns + ) -> Sequence[TColumnSchema]: # compare table with stored schema and produce delta updates = self.schema.get_new_table_columns(table_name, storage_columns) logger.info(f"Found {len(updates)} updates for {table_name} in {self.schema.name}") return updates def _row_to_schema_info(self, query: str, *args: Any) -> StorageSchemaInfo: - row: Tuple[Any,...] = None + row: Tuple[Any, ...] 
= None # if there's no dataset/schema return none info with contextlib.suppress(DatabaseUndefinedRelation): with self.sql_client.execute_query(query, *args) as cur: @@ -387,9 +475,7 @@ def _replace_schema_in_storage(self, schema: Schema) -> None: Save the given schema in storage and remove all previous versions with the same name """ name = self.sql_client.make_qualified_table_name(self.schema.version_table_name) - self.sql_client.execute_sql( - f"DELETE FROM {name} WHERE schema_name = %s;", schema.name - ) + self.sql_client.execute_sql(f"DELETE FROM {name} WHERE schema_name = %s;", schema.name) self._update_schema_in_storage(schema) def _update_schema_in_storage(self, schema: Schema) -> None: @@ -411,13 +497,20 @@ def _commit_schema_update(self, schema: Schema, schema_str: str) -> None: name = self.sql_client.make_qualified_table_name(self.schema.version_table_name) # values = schema.version_hash, schema.name, schema.version, schema.ENGINE_VERSION, str(now_ts), schema_str self.sql_client.execute_sql( - f"INSERT INTO {name}({self.VERSION_TABLE_SCHEMA_COLUMNS}) VALUES (%s, %s, %s, %s, %s, %s);", schema.stored_version_hash, schema.name, schema.version, schema.ENGINE_VERSION, now_ts, schema_str + f"INSERT INTO {name}({self.VERSION_TABLE_SCHEMA_COLUMNS}) VALUES (%s, %s, %s, %s, %s," + " %s);", + schema.stored_version_hash, + schema.name, + schema.version, + schema.ENGINE_VERSION, + now_ts, + schema_str, ) class SqlJobClientWithStaging(SqlJobClientBase, WithStagingDataset): @contextlib.contextmanager - def with_staging_dataset(self)-> Iterator["SqlJobClientBase"]: + def with_staging_dataset(self) -> Iterator["SqlJobClientBase"]: with self.sql_client.with_staging_dataset(True): yield self diff --git a/dlt/destinations/job_impl.py b/dlt/destinations/job_impl.py index fb3ba48b6d..acc7d08903 100644 --- a/dlt/destinations/job_impl.py +++ b/dlt/destinations/job_impl.py @@ -1,11 +1,11 @@ import os import tempfile # noqa: 251 +from dlt.common.destination.reference import FollowupJob, LoadJob, NewLoadJob, TLoadJobState from dlt.common.storages import FileStorage - -from dlt.common.destination.reference import NewLoadJob, FollowupJob, TLoadJobState, LoadJob from dlt.common.storages.load_storage import ParsedLoadJobFileName + class EmptyLoadJobWithoutFollowup(LoadJob): def __init__(self, file_name: str, status: TLoadJobState, exception: str = None) -> None: self._status = status @@ -13,7 +13,9 @@ def __init__(self, file_name: str, status: TLoadJobState, exception: str = None) super().__init__(file_name) @classmethod - def from_file_path(cls, file_path: str, status: TLoadJobState, message: str = None) -> "EmptyLoadJobWithoutFollowup": + def from_file_path( + cls, file_path: str, status: TLoadJobState, message: str = None + ) -> "EmptyLoadJobWithoutFollowup": return cls(FileStorage.get_file_name_from_file_path(file_path), status, exception=message) def state(self) -> TLoadJobState: @@ -38,9 +40,11 @@ def new_file_path(self) -> str: """Path to a newly created temporary job file""" return self._new_file_path -class NewReferenceJob(NewLoadJobImpl): - def __init__(self, file_name: str, status: TLoadJobState, exception: str = None, remote_path: str = None) -> None: +class NewReferenceJob(NewLoadJobImpl): + def __init__( + self, file_name: str, status: TLoadJobState, exception: str = None, remote_path: str = None + ) -> None: file_name = os.path.splitext(file_name)[0] + ".reference" super().__init__(file_name, status, exception) self._remote_path = remote_path diff --git a/dlt/destinations/motherduck/__init__.py 
b/dlt/destinations/motherduck/__init__.py index 493cd9834b..0200bd55c6 100644 --- a/dlt/destinations/motherduck/__init__.py +++ b/dlt/destinations/motherduck/__init__.py @@ -1,18 +1,25 @@ from typing import Type -from dlt.common.schema.schema import Schema -from dlt.common.configuration import with_config, known_sections +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.configuration import known_sections, with_config from dlt.common.configuration.accessors import config -from dlt.common.data_writers.escape import escape_postgres_identifier, escape_duckdb_literal +from dlt.common.data_writers.escape import escape_duckdb_literal, escape_postgres_identifier from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - +from dlt.common.destination.reference import DestinationClientConfiguration, JobClientBase +from dlt.common.schema.schema import Schema from dlt.destinations.motherduck.configuration import MotherDuckClientConfiguration -@with_config(spec=MotherDuckClientConfiguration, sections=(known_sections.DESTINATION, "motherduck",)) -def _configure(config: MotherDuckClientConfiguration = config.value) -> MotherDuckClientConfiguration: +@with_config( + spec=MotherDuckClientConfiguration, + sections=( + known_sections.DESTINATION, + "motherduck", + ), +) +def _configure( + config: MotherDuckClientConfiguration = config.value, +) -> MotherDuckClientConfiguration: return config @@ -38,7 +45,9 @@ def capabilities() -> DestinationCapabilitiesContext: return caps -def client(schema: Schema, initial_config: DestinationClientConfiguration = config.value) -> JobClientBase: +def client( + schema: Schema, initial_config: DestinationClientConfiguration = config.value +) -> JobClientBase: # import client when creating instance so capabilities and config specs can be accessed without dependencies installed from dlt.destinations.motherduck.motherduck import MotherDuckClient diff --git a/dlt/destinations/motherduck/configuration.py b/dlt/destinations/motherduck/configuration.py index 6f95bfce50..1243ac43c6 100644 --- a/dlt/destinations/motherduck/configuration.py +++ b/dlt/destinations/motherduck/configuration.py @@ -1,11 +1,10 @@ from typing import Any, ClassVar, Final, List from dlt.common.configuration import configspec +from dlt.common.configuration.exceptions import ConfigurationValueError from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration from dlt.common.typing import TSecretValue from dlt.common.utils import digest128 -from dlt.common.configuration.exceptions import ConfigurationValueError - from dlt.destinations.duckdb.configuration import DuckDbBaseCredentials MOTHERDUCK_DRIVERNAME = "md" @@ -35,7 +34,10 @@ def parse_native_representation(self, native_value: Any) -> None: def on_resolved(self) -> None: self._token_to_password() if self.drivername == MOTHERDUCK_DRIVERNAME and not self.password: - raise ConfigurationValueError("Motherduck schema 'md' was specified without corresponding token or password. The required format of connection string is: md:///?token=") + raise ConfigurationValueError( + "Motherduck schema 'md' was specified without corresponding token or password. 
The" + " required format of connection string is: md:///?token=" + ) @configspec @@ -43,7 +45,9 @@ class MotherDuckClientConfiguration(DestinationClientDwhWithStagingConfiguration destination_name: Final[str] = "motherduck" # type: ignore credentials: MotherDuckCredentials - create_indexes: bool = False # should unique indexes be created, this slows loading down massively + create_indexes: bool = ( + False # should unique indexes be created, this slows loading down massively + ) def fingerprint(self) -> str: """Returns a fingerprint of user access token""" diff --git a/dlt/destinations/motherduck/motherduck.py b/dlt/destinations/motherduck/motherduck.py index 93c0ed163b..bcb0d84862 100644 --- a/dlt/destinations/motherduck/motherduck.py +++ b/dlt/destinations/motherduck/motherduck.py @@ -2,23 +2,17 @@ from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.schema import Schema - - from dlt.destinations.duckdb.duck import DuckDbClient from dlt.destinations.motherduck import capabilities -from dlt.destinations.motherduck.sql_client import MotherDuckSqlClient from dlt.destinations.motherduck.configuration import MotherDuckClientConfiguration +from dlt.destinations.motherduck.sql_client import MotherDuckSqlClient class MotherDuckClient(DuckDbClient): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: MotherDuckClientConfiguration) -> None: super().__init__(schema, config) # type: ignore - sql_client = MotherDuckSqlClient( - config.normalize_dataset_name(schema), - config.credentials - ) + sql_client = MotherDuckSqlClient(config.normalize_dataset_name(schema), config.credentials) self.config: MotherDuckClientConfiguration = config # type: ignore self.sql_client: MotherDuckSqlClient = sql_client diff --git a/dlt/destinations/motherduck/sql_client.py b/dlt/destinations/motherduck/sql_client.py index 2fc664a2e8..db52afa19a 100644 --- a/dlt/destinations/motherduck/sql_client.py +++ b/dlt/destinations/motherduck/sql_client.py @@ -1,20 +1,27 @@ -import duckdb - from contextlib import contextmanager from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.destinations.exceptions import DatabaseTerminalException, DatabaseTransientException, DatabaseUndefinedRelation -from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction, DataFrame -from dlt.destinations.sql_client import SqlClientBase, DBApiCursorImpl, raise_database_error, raise_open_connection_error +import duckdb -from dlt.destinations.duckdb.sql_client import DuckDbSqlClient, DuckDBDBApiCursorImpl +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.destinations.duckdb.sql_client import DuckDBDBApiCursorImpl, DuckDbSqlClient +from dlt.destinations.exceptions import ( + DatabaseTerminalException, + DatabaseTransientException, + DatabaseUndefinedRelation, +) from dlt.destinations.motherduck import capabilities from dlt.destinations.motherduck.configuration import MotherDuckCredentials +from dlt.destinations.sql_client import ( + DBApiCursorImpl, + SqlClientBase, + raise_database_error, + raise_open_connection_error, +) +from dlt.destinations.typing import DataFrame, DBApi, DBApiCursor, DBTransaction class MotherDuckSqlClient(DuckDbSqlClient): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, dataset_name: str, credentials: MotherDuckCredentials) -> None: @@ -22,6 +29,12 @@ def 
__init__(self, dataset_name: str, credentials: MotherDuckCredentials) -> Non self.database_name = credentials.database def fully_qualified_dataset_name(self, escape: bool = True) -> str: - database_name = self.capabilities.escape_identifier(self.database_name) if escape else self.database_name - dataset_name = self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name + database_name = ( + self.capabilities.escape_identifier(self.database_name) + if escape + else self.database_name + ) + dataset_name = ( + self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name + ) return f"{database_name}.{dataset_name}" diff --git a/dlt/destinations/path_utils.py b/dlt/destinations/path_utils.py index a6cf634452..a4b616879a 100644 --- a/dlt/destinations/path_utils.py +++ b/dlt/destinations/path_utils.py @@ -1,24 +1,15 @@ # this can probably go some other place, but it is shared by destinations, so for now it is here +import re from typing import List, Sequence, Tuple import pendulum -import re -from dlt.destinations.exceptions import InvalidFilesystemLayout, CantExtractTablePrefix +from dlt.destinations.exceptions import CantExtractTablePrefix, InvalidFilesystemLayout # TODO: ensure layout only has supported placeholders -SUPPORTED_PLACEHOLDERS = { - "schema_name", - "table_name", - "load_id", - "file_id", - "ext", - "curr_date" -} +SUPPORTED_PLACEHOLDERS = {"schema_name", "table_name", "load_id", "file_id", "ext", "curr_date"} -SUPPORTED_TABLE_NAME_PREFIX_PLACEHOLDERS = ( - "schema_name", -) +SUPPORTED_TABLE_NAME_PREFIX_PLACEHOLDERS = ("schema_name",) def check_layout(layout: str) -> List[str]: @@ -28,11 +19,14 @@ def check_layout(layout: str) -> List[str]: raise InvalidFilesystemLayout(invalid_placeholders) return placeholders + def get_placeholders(layout: str) -> List[str]: - return re.findall(r'\{(.*?)\}', layout) + return re.findall(r"\{(.*?)\}", layout) -def create_path(layout: str, schema_name: str, table_name: str, load_id: str, file_id: str, ext: str) -> str: +def create_path( + layout: str, schema_name: str, table_name: str, load_id: str, file_id: str, ext: str +) -> str: """create a filepath from the layout and our default params""" placeholders = check_layout(layout) path = layout.format( @@ -41,7 +35,7 @@ def create_path(layout: str, schema_name: str, table_name: str, load_id: str, fi load_id=load_id, file_id=file_id, ext=ext, - curr_date=str(pendulum.today()) + curr_date=str(pendulum.today()), ) # if extension is not defined, we append it at the end if "ext" not in placeholders: @@ -51,11 +45,11 @@ def create_path(layout: str, schema_name: str, table_name: str, load_id: str, fi def get_table_prefix_layout( layout: str, - supported_prefix_placeholders: Sequence[str] = SUPPORTED_TABLE_NAME_PREFIX_PLACEHOLDERS + supported_prefix_placeholders: Sequence[str] = SUPPORTED_TABLE_NAME_PREFIX_PLACEHOLDERS, ) -> str: """get layout fragment that defines positions of the table, cutting other placeholders - allowed `supported_prefix_placeholders` that may appear before table. + allowed `supported_prefix_placeholders` that may appear before table. """ placeholders = get_placeholders(layout) @@ -67,14 +61,20 @@ def get_table_prefix_layout( # fail if any other prefix is defined before table_name if [p for p in placeholders[:table_name_index] if p not in supported_prefix_placeholders]: if len(supported_prefix_placeholders) == 0: - details = "No other placeholders are allowed before {table_name} but you have %s present. 
" % placeholders[:table_name_index] + details = ( + "No other placeholders are allowed before {table_name} but you have %s present. " + % placeholders[:table_name_index] + ) else: - details = "Only %s are allowed before {table_name} but you have %s present. " % (supported_prefix_placeholders, placeholders[:table_name_index]) + details = "Only %s are allowed before {table_name} but you have %s present. " % ( + supported_prefix_placeholders, + placeholders[:table_name_index], + ) raise CantExtractTablePrefix(layout, details) # we include the char after the table_name here, this should be a separator not a new placeholder # this is to prevent selecting tables that have the same starting name - prefix = layout[:layout.index("{table_name}") + 13] + prefix = layout[: layout.index("{table_name}") + 13] if prefix[-1] == "{": raise CantExtractTablePrefix(layout, "A separator is required after a {table_name}. ") diff --git a/dlt/destinations/postgres/__init__.py b/dlt/destinations/postgres/__init__.py index e8904c075f..58e6268254 100644 --- a/dlt/destinations/postgres/__init__.py +++ b/dlt/destinations/postgres/__init__.py @@ -1,18 +1,23 @@ from typing import Type -from dlt.common.schema.schema import Schema -from dlt.common.configuration import with_config, known_sections +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.configuration import known_sections, with_config from dlt.common.configuration.accessors import config from dlt.common.data_writers.escape import escape_postgres_identifier, escape_postgres_literal from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.destination.reference import DestinationClientConfiguration, JobClientBase +from dlt.common.schema.schema import Schema from dlt.common.wei import EVM_DECIMAL_PRECISION - from dlt.destinations.postgres.configuration import PostgresClientConfiguration -@with_config(spec=PostgresClientConfiguration, sections=(known_sections.DESTINATION, "postgres",)) +@with_config( + spec=PostgresClientConfiguration, + sections=( + known_sections.DESTINATION, + "postgres", + ), +) def _configure(config: PostgresClientConfiguration = config.value) -> PostgresClientConfiguration: return config @@ -27,7 +32,7 @@ def capabilities() -> DestinationCapabilitiesContext: caps.escape_identifier = escape_postgres_identifier caps.escape_literal = escape_postgres_literal caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (2*EVM_DECIMAL_PRECISION, EVM_DECIMAL_PRECISION) + caps.wei_precision = (2 * EVM_DECIMAL_PRECISION, EVM_DECIMAL_PRECISION) caps.max_identifier_length = 63 caps.max_column_identifier_length = 63 caps.max_query_length = 32 * 1024 * 1024 @@ -39,7 +44,9 @@ def capabilities() -> DestinationCapabilitiesContext: return caps -def client(schema: Schema, initial_config: DestinationClientConfiguration = config.value) -> JobClientBase: +def client( + schema: Schema, initial_config: DestinationClientConfiguration = config.value +) -> JobClientBase: # import client when creating instance so capabilities and config specs can be accessed without dependencies installed from dlt.destinations.postgres.postgres import PostgresClient diff --git a/dlt/destinations/postgres/configuration.py b/dlt/destinations/postgres/configuration.py index e3e3af17d6..4bdeb19172 100644 
--- a/dlt/destinations/postgres/configuration.py +++ b/dlt/destinations/postgres/configuration.py @@ -1,12 +1,12 @@ -from typing import Final, ClassVar, Any, List +from typing import Any, ClassVar, Final, List + from sqlalchemy.engine import URL from dlt.common.configuration import configspec from dlt.common.configuration.specs import ConnectionStringCredentials -from dlt.common.utils import digest128 -from dlt.common.typing import TSecretValue - from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.typing import TSecretValue +from dlt.common.utils import digest128 @configspec diff --git a/dlt/destinations/postgres/postgres.py b/dlt/destinations/postgres/postgres.py index 454b2b80d4..751bbcc5cf 100644 --- a/dlt/destinations/postgres/postgres.py +++ b/dlt/destinations/postgres/postgres.py @@ -1,21 +1,17 @@ -from typing import ClassVar, Dict, Optional, Sequence, List, Any +from typing import Any, ClassVar, Dict, List, Optional, Sequence -from dlt.common.wei import EVM_DECIMAL_PRECISION -from dlt.common.destination.reference import NewLoadJob -from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.data_types import TDataType -from dlt.common.schema import TColumnSchema, TColumnHint, Schema +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.destination.reference import NewLoadJob +from dlt.common.schema import Schema, TColumnHint, TColumnSchema from dlt.common.schema.typing import TTableSchema - -from dlt.destinations.sql_jobs import SqlStagingCopyJob - +from dlt.common.wei import EVM_DECIMAL_PRECISION from dlt.destinations.insert_job_client import InsertValuesJobClient - from dlt.destinations.postgres import capabilities -from dlt.destinations.postgres.sql_client import Psycopg2SqlClient from dlt.destinations.postgres.configuration import PostgresClientConfiguration +from dlt.destinations.postgres.sql_client import Psycopg2SqlClient from dlt.destinations.sql_client import SqlClientBase - +from dlt.destinations.sql_jobs import SqlStagingCopyJob SCT_TO_PGT: Dict[TDataType, str] = { "complex": "jsonb", @@ -26,7 +22,7 @@ "date": "date", "bigint": "bigint", "binary": "bytea", - "decimal": "numeric(%i,%i)" + "decimal": "numeric(%i,%i)", } PGT_TO_SCT: Dict[str, TDataType] = { @@ -38,17 +34,17 @@ "date": "date", "bigint": "bigint", "bytea": "binary", - "numeric": "decimal" + "numeric": "decimal", } -HINT_TO_POSTGRES_ATTR: Dict[TColumnHint, str] = { - "unique": "UNIQUE" -} +HINT_TO_POSTGRES_ATTR: Dict[TColumnHint, str] = {"unique": "UNIQUE"} -class PostgresStagingCopyJob(SqlStagingCopyJob): +class PostgresStagingCopyJob(SqlStagingCopyJob): @classmethod - def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]: + def generate_sql( + cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any] + ) -> List[str]: sql: List[str] = [] for table in table_chain: with sql_client.with_staging_dataset(staging=True): @@ -57,29 +53,35 @@ def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClient # drop destination table sql.append(f"DROP TABLE IF EXISTS {table_name};") # moving staging table to destination schema - sql.append(f"ALTER TABLE {staging_table_name} SET SCHEMA {sql_client.fully_qualified_dataset_name()};") + sql.append( + f"ALTER TABLE {staging_table_name} SET SCHEMA" + f" {sql_client.fully_qualified_dataset_name()};" + ) # recreate staging table sql.append(f"CREATE TABLE {staging_table_name} (like {table_name} 
including all);") return sql -class PostgresClient(InsertValuesJobClient): +class PostgresClient(InsertValuesJobClient): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: PostgresClientConfiguration) -> None: - sql_client = Psycopg2SqlClient( - config.normalize_dataset_name(schema), - config.credentials - ) + sql_client = Psycopg2SqlClient(config.normalize_dataset_name(schema), config.credentials) super().__init__(schema, config, sql_client) self.config: PostgresClientConfiguration = config self.sql_client = sql_client self.active_hints = HINT_TO_POSTGRES_ATTR if self.config.create_indexes else {} def _get_column_def_sql(self, c: TColumnSchema) -> str: - hints_str = " ".join(self.active_hints.get(h, "") for h in self.active_hints.keys() if c.get(h, False) is True) + hints_str = " ".join( + self.active_hints.get(h, "") + for h in self.active_hints.keys() + if c.get(h, False) is True + ) column_name = self.capabilities.escape_identifier(c["name"]) - return f"{column_name} {self._to_db_type(c['data_type'])} {hints_str} {self._gen_not_null(c['nullable'])}" + return ( + f"{column_name} {self._to_db_type(c['data_type'])} {hints_str} {self._gen_not_null(c['nullable'])}" + ) def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: return PostgresStagingCopyJob.from_table_chain(table_chain, self.sql_client) @@ -101,4 +103,3 @@ def _from_db_type(cls, pq_t: str, precision: Optional[int], scale: Optional[int] if (precision, scale) == cls.capabilities.wei_precision: return "wei" return PGT_TO_SCT.get(pq_t, "text") - diff --git a/dlt/destinations/postgres/sql_client.py b/dlt/destinations/postgres/sql_client.py index dd0ce0f24a..35c21d0c69 100644 --- a/dlt/destinations/postgres/sql_client.py +++ b/dlt/destinations/postgres/sql_client.py @@ -4,7 +4,7 @@ if platform.python_implementation() == "PyPy": import psycopg2cffi as psycopg2 - from psycopg2cffi.sql import SQL, Composed, Composable + from psycopg2cffi.sql import SQL, Composable, Composed else: import psycopg2 from psycopg2.sql import SQL, Composed, Composable @@ -12,15 +12,23 @@ from contextlib import contextmanager from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence -from dlt.destinations.exceptions import DatabaseTerminalException, DatabaseTransientException, DatabaseUndefinedRelation +from dlt.destinations.exceptions import ( + DatabaseTerminalException, + DatabaseTransientException, + DatabaseUndefinedRelation, +) +from dlt.destinations.postgres import capabilities +from dlt.destinations.postgres.configuration import PostgresCredentials +from dlt.destinations.sql_client import ( + DBApiCursorImpl, + SqlClientBase, + raise_database_error, + raise_open_connection_error, +) from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction -from dlt.destinations.sql_client import DBApiCursorImpl, SqlClientBase, raise_database_error, raise_open_connection_error -from dlt.destinations.postgres.configuration import PostgresCredentials -from dlt.destinations.postgres import capabilities class Psycopg2SqlClient(SqlClientBase["psycopg2.connection"], DBTransaction): - dbapi: ClassVar[DBApi] = psycopg2 capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() @@ -31,9 +39,9 @@ def __init__(self, dataset_name: str, credentials: PostgresCredentials) -> None: def open_connection(self) -> "psycopg2.connection": self._conn = psycopg2.connect( - dsn=self.credentials.to_native_representation(), - options=f"-c 
search_path={self.fully_qualified_dataset_name()},public" - ) + dsn=self.credentials.to_native_representation(), + options=f"-c search_path={self.fully_qualified_dataset_name()},public", + ) # we'll provide explicit transactions see _reset self._reset_connection() return self._conn @@ -69,7 +77,9 @@ def native_connection(self) -> "psycopg2.connection": return self._conn # @raise_database_error - def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]: + def execute_sql( + self, sql: AnyStr, *args: Any, **kwargs: Any + ) -> Optional[Sequence[Sequence[Any]]]: with self.execute_query(sql, *args, **kwargs) as curr: if curr.description is None: return None @@ -94,13 +104,17 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB self.open_connection() raise outer - def execute_fragments(self, fragments: Sequence[AnyStr], *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]: + def execute_fragments( + self, fragments: Sequence[AnyStr], *args: Any, **kwargs: Any + ) -> Optional[Sequence[Sequence[Any]]]: # compose the statements using psycopg2 library - composed = Composed(sql if isinstance(sql, Composable) else SQL(sql) for sql in fragments) + composed = Composed(sql if isinstance(sql, Composable) else SQL(sql) for sql in fragments) return self.execute_sql(composed, *args, **kwargs) def fully_qualified_dataset_name(self, escape: bool = True) -> str: - return self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name + return ( + self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name + ) def _reset_connection(self) -> None: # self._conn.autocommit = True @@ -111,13 +125,23 @@ def _reset_connection(self) -> None: def _make_database_exception(cls, ex: Exception) -> Exception: if isinstance(ex, (psycopg2.errors.UndefinedTable, psycopg2.errors.InvalidSchemaName)): raise DatabaseUndefinedRelation(ex) - elif isinstance(ex, (psycopg2.OperationalError, psycopg2.InternalError, psycopg2.errors.SyntaxError, psycopg2.errors.UndefinedFunction)): + elif isinstance( + ex, + ( + psycopg2.OperationalError, + psycopg2.InternalError, + psycopg2.errors.SyntaxError, + psycopg2.errors.UndefinedFunction, + ), + ): term = cls._maybe_make_terminal_exception_from_data_error(ex) if term: return term else: return DatabaseTransientException(ex) - elif isinstance(ex, (psycopg2.DataError, psycopg2.ProgrammingError, psycopg2.IntegrityError)): + elif isinstance( + ex, (psycopg2.DataError, psycopg2.ProgrammingError, psycopg2.IntegrityError) + ): return DatabaseTerminalException(ex) elif isinstance(ex, TypeError): # psycopg2 raises TypeError on malformed query parameters @@ -128,7 +152,9 @@ def _make_database_exception(cls, ex: Exception) -> Exception: return ex @staticmethod - def _maybe_make_terminal_exception_from_data_error(pg_ex: psycopg2.DataError) -> Optional[Exception]: + def _maybe_make_terminal_exception_from_data_error( + pg_ex: psycopg2.DataError, + ) -> Optional[Exception]: return None @staticmethod diff --git a/dlt/destinations/redshift/__init__.py b/dlt/destinations/redshift/__init__.py index 96741e86cd..14ab76f7b6 100644 --- a/dlt/destinations/redshift/__init__.py +++ b/dlt/destinations/redshift/__init__.py @@ -1,17 +1,22 @@ from typing import Type -from dlt.common.schema.schema import Schema -from dlt.common.configuration import with_config, known_sections +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.configuration 
import known_sections, with_config from dlt.common.configuration.accessors import config from dlt.common.data_writers.escape import escape_redshift_identifier, escape_redshift_literal from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - +from dlt.common.destination.reference import DestinationClientConfiguration, JobClientBase +from dlt.common.schema.schema import Schema from dlt.destinations.redshift.configuration import RedshiftClientConfiguration -@with_config(spec=RedshiftClientConfiguration, sections=(known_sections.DESTINATION, "redshift",)) +@with_config( + spec=RedshiftClientConfiguration, + sections=( + known_sections.DESTINATION, + "redshift", + ), +) def _configure(config: RedshiftClientConfiguration = config.value) -> RedshiftClientConfiguration: return config @@ -38,7 +43,9 @@ def capabilities() -> DestinationCapabilitiesContext: return caps -def client(schema: Schema, initial_config: DestinationClientConfiguration = config.value) -> JobClientBase: +def client( + schema: Schema, initial_config: DestinationClientConfiguration = config.value +) -> JobClientBase: # import client when creating instance so capabilities and config specs can be accessed without dependencies installed from dlt.destinations.redshift.redshift import RedshiftClient diff --git a/dlt/destinations/redshift/configuration.py b/dlt/destinations/redshift/configuration.py index 7cb13b996f..310349b813 100644 --- a/dlt/destinations/redshift/configuration.py +++ b/dlt/destinations/redshift/configuration.py @@ -1,10 +1,9 @@ from typing import Final, Optional -from dlt.common.typing import TSecretValue from dlt.common.configuration import configspec +from dlt.common.typing import TSecretValue from dlt.common.utils import digest128 - -from dlt.destinations.postgres.configuration import PostgresCredentials, PostgresClientConfiguration +from dlt.destinations.postgres.configuration import PostgresClientConfiguration, PostgresCredentials @configspec diff --git a/dlt/destinations/redshift/redshift.py b/dlt/destinations/redshift/redshift.py index 3ba4ed9b97..59c554f7bd 100644 --- a/dlt/destinations/redshift/redshift.py +++ b/dlt/destinations/redshift/redshift.py @@ -1,36 +1,34 @@ -import platform import os +import platform +from dlt.common.schema.utils import table_schema_has_type from dlt.destinations.postgres.sql_client import Psycopg2SqlClient -from dlt.common.schema.utils import table_schema_has_type if platform.python_implementation() == "PyPy": import psycopg2cffi as psycopg2 + # from psycopg2cffi.sql import SQL, Composed else: import psycopg2 + # from psycopg2.sql import SQL, Composed -from typing import ClassVar, Dict, List, Optional, Sequence, Any +from typing import Any, ClassVar, Dict, List, Optional, Sequence -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import NewLoadJob, CredentialsConfiguration +from dlt.common.configuration.specs import AwsCredentialsWithoutDefaults from dlt.common.data_types import TDataType -from dlt.common.schema import TColumnSchema, TColumnHint, Schema +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.destination.reference import CredentialsConfiguration, NewLoadJob +from dlt.common.schema import Schema, TColumnHint, TColumnSchema from dlt.common.schema.typing import TTableSchema -from 
dlt.common.configuration.specs import AwsCredentialsWithoutDefaults - -from dlt.destinations.insert_job_client import InsertValuesJobClient -from dlt.destinations.sql_jobs import SqlMergeJob from dlt.destinations.exceptions import DatabaseTerminalException, LoadJobTerminalException +from dlt.destinations.insert_job_client import InsertValuesJobClient from dlt.destinations.job_client_impl import CopyRemoteFileLoadJob, LoadJob - +from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.redshift import capabilities from dlt.destinations.redshift.configuration import RedshiftClientConfiguration -from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_client import SqlClientBase - - +from dlt.destinations.sql_jobs import SqlMergeJob SCT_TO_PGT: Dict[TDataType, str] = { "complex": "super", @@ -41,7 +39,7 @@ "timestamp": "timestamp with time zone", "bigint": "bigint", "binary": "varbinary", - "decimal": "numeric(%i,%i)" + "decimal": "numeric(%i,%i)", } PGT_TO_SCT: Dict[str, TDataType] = { @@ -53,23 +51,24 @@ "timestamp with time zone": "timestamp", "bigint": "bigint", "binary varying": "binary", - "numeric": "decimal" + "numeric": "decimal", } HINT_TO_REDSHIFT_ATTR: Dict[TColumnHint, str] = { "cluster": "DISTKEY", # it is better to not enforce constraints in redshift # "primary_key": "PRIMARY KEY", - "sort": "SORTKEY" + "sort": "SORTKEY", } class RedshiftSqlClient(Psycopg2SqlClient): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() @staticmethod - def _maybe_make_terminal_exception_from_data_error(pg_ex: psycopg2.DataError) -> Optional[Exception]: + def _maybe_make_terminal_exception_from_data_error( + pg_ex: psycopg2.DataError, + ) -> Optional[Exception]: if "Cannot insert a NULL value into column" in pg_ex.pgerror: # NULL violations is internal error, probably a redshift thing return DatabaseTerminalException(pg_ex) @@ -79,26 +78,33 @@ def _maybe_make_terminal_exception_from_data_error(pg_ex: psycopg2.DataError) -> return DatabaseTerminalException(pg_ex) return None -class RedshiftCopyFileLoadJob(CopyRemoteFileLoadJob): - def __init__(self, table: TTableSchema, - file_path: str, - sql_client: SqlClientBase[Any], - staging_credentials: Optional[CredentialsConfiguration] = None, - staging_iam_role: str = None) -> None: +class RedshiftCopyFileLoadJob(CopyRemoteFileLoadJob): + def __init__( + self, + table: TTableSchema, + file_path: str, + sql_client: SqlClientBase[Any], + staging_credentials: Optional[CredentialsConfiguration] = None, + staging_iam_role: str = None, + ) -> None: self._staging_iam_role = staging_iam_role super().__init__(table, file_path, sql_client, staging_credentials) def execute(self, table: TTableSchema, bucket_path: str) -> None: - # we assume s3 credentials where provided for the staging credentials = "" if self._staging_iam_role: credentials = f"IAM_ROLE '{self._staging_iam_role}'" - elif self._staging_credentials and isinstance(self._staging_credentials, AwsCredentialsWithoutDefaults): + elif self._staging_credentials and isinstance( + self._staging_credentials, AwsCredentialsWithoutDefaults + ): aws_access_key = self._staging_credentials.aws_access_key_id aws_secret_key = self._staging_credentials.aws_secret_access_key - credentials = f"CREDENTIALS 'aws_access_key_id={aws_access_key};aws_secret_access_key={aws_secret_key}'" + credentials = ( + "CREDENTIALS" + f" 'aws_access_key_id={aws_access_key};aws_secret_access_key={aws_secret_key}'" + ) table_name = table["name"] # get format @@ -108,7 +114,11 @@ 
def execute(self, table: TTableSchema, bucket_path: str) -> None: compression = "" if ext == "jsonl": if table_schema_has_type(table, "binary"): - raise LoadJobTerminalException(self.file_name(), "Redshift cannot load VARBYTE columns from json files. Switch to parquet to load binaries.") + raise LoadJobTerminalException( + self.file_name(), + "Redshift cannot load VARBYTE columns from json files. Switch to parquet to" + " load binaries.", + ) file_type = "FORMAT AS JSON 'auto'" dateformat = "dateformat 'auto' timeformat 'auto'" compression = "GZIP" @@ -136,28 +146,36 @@ def exception(self) -> str: # this part of code should be never reached raise NotImplementedError() -class RedshiftMergeJob(SqlMergeJob): +class RedshiftMergeJob(SqlMergeJob): @classmethod - def gen_key_table_clauses(cls, root_table_name: str, staging_root_table_name: str, key_clauses: Sequence[str], for_delete: bool) -> List[str]: + def gen_key_table_clauses( + cls, + root_table_name: str, + staging_root_table_name: str, + key_clauses: Sequence[str], + for_delete: bool, + ) -> List[str]: """Generate sql clauses that may be used to select or delete rows in root table of destination dataset - A list of clauses may be returned for engines that do not support OR in subqueries. Like BigQuery + A list of clauses may be returned for engines that do not support OR in subqueries. Like BigQuery """ if for_delete: - return [f"FROM {root_table_name} WHERE EXISTS (SELECT 1 FROM {staging_root_table_name} WHERE {' OR '.join([c.format(d=root_table_name,s=staging_root_table_name) for c in key_clauses])})"] - return SqlMergeJob.gen_key_table_clauses(root_table_name, staging_root_table_name, key_clauses, for_delete) + return [ + f"FROM {root_table_name} WHERE EXISTS (SELECT 1 FROM" + f" {staging_root_table_name} WHERE" + f" {' OR '.join([c.format(d=root_table_name,s=staging_root_table_name) for c in key_clauses])})" + ] + return SqlMergeJob.gen_key_table_clauses( + root_table_name, staging_root_table_name, key_clauses, for_delete + ) class RedshiftClient(InsertValuesJobClient): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: RedshiftClientConfiguration) -> None: - sql_client = RedshiftSqlClient ( - config.normalize_dataset_name(schema), - config.credentials - ) + sql_client = RedshiftSqlClient(config.normalize_dataset_name(schema), config.credentials) super().__init__(schema, config, sql_client) self.sql_client = sql_client self.config: RedshiftClientConfiguration = config @@ -166,16 +184,30 @@ def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: return RedshiftMergeJob.from_table_chain(table_chain, self.sql_client) def _get_column_def_sql(self, c: TColumnSchema) -> str: - hints_str = " ".join(HINT_TO_REDSHIFT_ATTR.get(h, "") for h in HINT_TO_REDSHIFT_ATTR.keys() if c.get(h, False) is True) + hints_str = " ".join( + HINT_TO_REDSHIFT_ATTR.get(h, "") + for h in HINT_TO_REDSHIFT_ATTR.keys() + if c.get(h, False) is True + ) column_name = self.capabilities.escape_identifier(c["name"]) - return f"{column_name} {self._to_db_type(c['data_type'])} {hints_str} {self._gen_not_null(c['nullable'])}" + return ( + f"{column_name} {self._to_db_type(c['data_type'])} {hints_str} {self._gen_not_null(c['nullable'])}" + ) def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: """Starts SqlLoadJob for files ending with .sql or returns None to let derived classes to handle their specific jobs""" job = super().start_file_load(table, 
file_path, load_id) if not job: - assert NewReferenceJob.is_reference_job(file_path), "Redshift must use staging to load files" - job = RedshiftCopyFileLoadJob(table, file_path, self.sql_client, staging_credentials=self.config.staging_config.credentials, staging_iam_role=self.config.staging_iam_role) + assert NewReferenceJob.is_reference_job( + file_path + ), "Redshift must use staging to load files" + job = RedshiftCopyFileLoadJob( + table, + file_path, + self.sql_client, + staging_credentials=self.config.staging_config.credentials, + staging_iam_role=self.config.staging_iam_role, + ) return job @classmethod @@ -192,4 +224,3 @@ def _from_db_type(cls, pq_t: str, precision: Optional[int], scale: Optional[int] if (precision, scale) == cls.capabilities.wei_precision: return "wei" return PGT_TO_SCT.get(pq_t, "text") - diff --git a/dlt/destinations/snowflake/__init__.py b/dlt/destinations/snowflake/__init__.py index 5d32bc41fd..029e4c5863 100644 --- a/dlt/destinations/snowflake/__init__.py +++ b/dlt/destinations/snowflake/__init__.py @@ -1,18 +1,22 @@ from typing import Type -from dlt.common.data_writers.escape import escape_bigquery_identifier -from dlt.common.schema.schema import Schema -from dlt.common.configuration import with_config, known_sections +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.configuration import known_sections, with_config from dlt.common.configuration.accessors import config +from dlt.common.data_writers.escape import escape_bigquery_identifier, escape_snowflake_identifier from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration -from dlt.common.data_writers.escape import escape_snowflake_identifier -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - +from dlt.common.destination.reference import DestinationClientConfiguration, JobClientBase +from dlt.common.schema.schema import Schema from dlt.destinations.snowflake.configuration import SnowflakeClientConfiguration -@with_config(spec=SnowflakeClientConfiguration, sections=(known_sections.DESTINATION, "snowflake",)) +@with_config( + spec=SnowflakeClientConfiguration, + sections=( + known_sections.DESTINATION, + "snowflake", + ), +) def _configure(config: SnowflakeClientConfiguration = config.value) -> SnowflakeClientConfiguration: return config @@ -37,7 +41,9 @@ def capabilities() -> DestinationCapabilitiesContext: return caps -def client(schema: Schema, initial_config: DestinationClientConfiguration = config.value) -> JobClientBase: +def client( + schema: Schema, initial_config: DestinationClientConfiguration = config.value +) -> JobClientBase: # import client when creating instance so capabilities and config specs can be accessed without dependencies installed from dlt.destinations.snowflake.snowflake import SnowflakeClient diff --git a/dlt/destinations/snowflake/configuration.py b/dlt/destinations/snowflake/configuration.py index a27116f3eb..a8f0cc8698 100644 --- a/dlt/destinations/snowflake/configuration.py +++ b/dlt/destinations/snowflake/configuration.py @@ -1,35 +1,38 @@ -from typing import Final, Optional, Any, Dict, ClassVar, List +from typing import Any, ClassVar, Dict, Final, List, Optional from sqlalchemy.engine import URL from dlt import version -from dlt.common.exceptions import MissingDependencyException -from dlt.common.typing import TSecretStrValue -from dlt.common.configuration.specs import 
ConnectionStringCredentials -from dlt.common.configuration.exceptions import ConfigurationValueError from dlt.common.configuration import configspec +from dlt.common.configuration.exceptions import ConfigurationValueError +from dlt.common.configuration.specs import ConnectionStringCredentials from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.exceptions import MissingDependencyException +from dlt.common.typing import TSecretStrValue from dlt.common.utils import digest128 def _read_private_key(private_key: str, password: Optional[str] = None) -> bytes: - """Load an encrypted or unencrypted private key from string. - """ + """Load an encrypted or unencrypted private key from string.""" try: from cryptography.hazmat.backends import default_backend - from cryptography.hazmat.primitives.asymmetric import rsa - from cryptography.hazmat.primitives.asymmetric import dsa from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import dsa, rsa except ModuleNotFoundError as e: - raise MissingDependencyException("SnowflakeCredentials with private key", dependencies=[f"{version.DLT_PKG_NAME}[snowflake]"]) from e + raise MissingDependencyException( + "SnowflakeCredentials with private key", + dependencies=[f"{version.DLT_PKG_NAME}[snowflake]"], + ) from e pkey = serialization.load_pem_private_key( - private_key.encode(), password.encode() if password is not None else None, backend=default_backend() + private_key.encode(), + password.encode() if password is not None else None, + backend=default_backend(), ) return pkey.private_bytes( encoding=serialization.Encoding.DER, format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption() + encryption_algorithm=serialization.NoEncryption(), ) @@ -48,24 +51,35 @@ class SnowflakeCredentials(ConnectionStringCredentials): def parse_native_representation(self, native_value: Any) -> None: super().parse_native_representation(native_value) - self.warehouse = self.query.get('warehouse') - self.role = self.query.get('role') - self.private_key = self.query.get('private_key') # type: ignore - self.private_key_passphrase = self.query.get('private_key_passphrase') # type: ignore + self.warehouse = self.query.get("warehouse") + self.role = self.query.get("role") + self.private_key = self.query.get("private_key") # type: ignore + self.private_key_passphrase = self.query.get("private_key_passphrase") # type: ignore if not self.is_partial() and (self.password or self.private_key): self.resolve() def on_resolved(self) -> None: if not self.password and not self.private_key: - raise ConfigurationValueError("Please specify password or private_key. SnowflakeCredentials supports password and private key authentication and one of those must be specified.") + raise ConfigurationValueError( + "Please specify password or private_key. SnowflakeCredentials supports password and" + " private key authentication and one of those must be specified." 
+ ) def to_url(self) -> URL: query = dict(self.query or {}) - if self.warehouse and 'warehouse' not in query: - query['warehouse'] = self.warehouse - if self.role and 'role' not in query: - query['role'] = self.role - return URL.create(self.drivername, self.username, self.password, self.host, self.port, self.database, query) + if self.warehouse and "warehouse" not in query: + query["warehouse"] = self.warehouse + if self.role and "role" not in query: + query["role"] = self.role + return URL.create( + self.drivername, + self.username, + self.password, + self.host, + self.port, + self.database, + query, + ) def to_connector_params(self) -> Dict[str, Any]: private_key: Optional[bytes] = None diff --git a/dlt/destinations/snowflake/snowflake.py b/dlt/destinations/snowflake/snowflake.py index 3169bd2c06..a5a61d0ab0 100644 --- a/dlt/destinations/snowflake/snowflake.py +++ b/dlt/destinations/snowflake/snowflake.py @@ -1,26 +1,27 @@ -from typing import ClassVar, Dict, Optional, Sequence, Tuple, List, Any +from typing import Any, ClassVar, Dict, List, Optional, Sequence, Tuple from urllib.parse import urlparse -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import FollowupJob, NewLoadJob, TLoadJobState, LoadJob, CredentialsConfiguration from dlt.common.configuration.specs import AwsCredentialsWithoutDefaults from dlt.common.data_types import TDataType -from dlt.common.storages.file_storage import FileStorage -from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.destination.reference import ( + CredentialsConfiguration, + FollowupJob, + LoadJob, + NewLoadJob, + TLoadJobState, +) +from dlt.common.schema import Schema, TColumnSchema, TTableSchemaColumns from dlt.common.schema.typing import TTableSchema - - -from dlt.destinations.job_client_impl import SqlJobClientWithStaging -from dlt.destinations.job_impl import EmptyLoadJob +from dlt.common.storages.file_storage import FileStorage from dlt.destinations.exceptions import LoadJobTerminalException - +from dlt.destinations.job_client_impl import SqlJobClientWithStaging +from dlt.destinations.job_impl import EmptyLoadJob, NewReferenceJob from dlt.destinations.snowflake import capabilities from dlt.destinations.snowflake.configuration import SnowflakeClientConfiguration from dlt.destinations.snowflake.sql_client import SnowflakeSqlClient -from dlt.destinations.sql_jobs import SqlStagingCopyJob -from dlt.destinations.snowflake.sql_client import SnowflakeSqlClient -from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_client import SqlClientBase +from dlt.destinations.sql_jobs import SqlStagingCopyJob BIGINT_PRECISION = 19 MAX_NUMERIC_PRECISION = 38 @@ -45,14 +46,20 @@ "DATE": "date", "TIMESTAMP_TZ": "timestamp", "BINARY": "binary", - "VARIANT": "complex" + "VARIANT": "complex", } class SnowflakeLoadJob(LoadJob, FollowupJob): def __init__( - self, file_path: str, table_name: str, load_id: str, client: SnowflakeSqlClient, - stage_name: Optional[str] = None, keep_staged_files: bool = True, staging_credentials: Optional[CredentialsConfiguration] = None + self, + file_path: str, + table_name: str, + load_id: str, + client: SnowflakeSqlClient, + stage_name: Optional[str] = None, + keep_staged_files: bool = True, + staging_credentials: Optional[CredentialsConfiguration] = None, ) -> None: file_name = FileStorage.get_file_name_from_file_path(file_path) super().__init__(file_name) 
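A minimal sketch of what the reformatted SnowflakeCredentials.to_url() call shape produces, assuming SQLAlchemy is installed; the account, user, warehouse and role values below are invented for illustration and only the positional URL.create() call mirrors the code in this patch:

from sqlalchemy.engine import URL

# hypothetical connection settings, not taken from the patch
query = {"warehouse": "COMPUTE_WH", "role": "LOADER"}
url = URL.create(
    "snowflake",        # drivername
    "loader",           # username
    "secret-password",  # password
    "kgiotue-wn98417",  # host (snowflake account identifier)
    None,               # port
    "dlt_data",         # database
    query,              # extra query params carry warehouse and role
)
# warehouse and role end up as query parameters of the rendered connection string
print(url.render_as_string(hide_password=False))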
@@ -60,8 +67,14 @@ def __init__( qualified_table_name = client.make_qualified_table_name(table_name) # extract and prepare some vars - bucket_path = NewReferenceJob.resolve_reference(file_path) if NewReferenceJob.is_reference_job(file_path) else "" - file_name = FileStorage.get_file_name_from_file_path(bucket_path) if bucket_path else file_name + bucket_path = ( + NewReferenceJob.resolve_reference(file_path) + if NewReferenceJob.is_reference_job(file_path) + else "" + ) + file_name = ( + FileStorage.get_file_name_from_file_path(bucket_path) if bucket_path else file_name + ) from_clause = "" credentials_clause = "" files_clause = "" @@ -69,7 +82,11 @@ def __init__( if bucket_path: # s3 credentials case - if bucket_path.startswith("s3://") and staging_credentials and isinstance(staging_credentials, AwsCredentialsWithoutDefaults): + if ( + bucket_path.startswith("s3://") + and staging_credentials + and isinstance(staging_credentials, AwsCredentialsWithoutDefaults) + ): credentials_clause = f"""CREDENTIALS=(AWS_KEY_ID='{staging_credentials.aws_access_key_id}' AWS_SECRET_KEY='{staging_credentials.aws_secret_access_key}')""" from_clause = f"FROM '{bucket_path}'" else: @@ -77,14 +94,19 @@ def __init__( bucket_path = bucket_path.replace("gs://", "gcs://") if not stage_name: # when loading from bucket stage must be given - raise LoadJobTerminalException(file_path, f"Cannot load from bucket path {bucket_path} without a stage name. See https://dlthub.com/docs/dlt-ecosystem/destinations/snowflake for instructions on setting up the `stage_name`") + raise LoadJobTerminalException( + file_path, + f"Cannot load from bucket path {bucket_path} without a stage name. See" + " https://dlthub.com/docs/dlt-ecosystem/destinations/snowflake for" + " instructions on setting up the `stage_name`", + ) from_clause = f"FROM @{stage_name}/" files_clause = f"FILES = ('{urlparse(bucket_path).path.lstrip('/')}')" else: # this means we have a local file if not stage_name: # Use implicit table stage by default: "SCHEMA_NAME"."%TABLE_NAME" - stage_name = client.make_qualified_table_name('%'+table_name) + stage_name = client.make_qualified_table_name("%" + table_name) stage_file_path = f'@{stage_name}/"{load_id}"/{file_name}' from_clause = f"FROM {stage_file_path}" @@ -96,19 +118,19 @@ def __init__( with client.begin_transaction(): # PUT and COPY in one tx if local file, otherwise only copy if not bucket_path: - client.execute_sql(f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE = TRUE, AUTO_COMPRESS = FALSE') - client.execute_sql( - f"""COPY INTO {qualified_table_name} + client.execute_sql( + f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE = TRUE,' + " AUTO_COMPRESS = FALSE" + ) + client.execute_sql(f"""COPY INTO {qualified_table_name} {from_clause} {files_clause} {credentials_clause} FILE_FORMAT = {source_format} MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE' - """ - ) + """) if stage_file_path and not keep_staged_files: - client.execute_sql(f'REMOVE {stage_file_path}') - + client.execute_sql(f"REMOVE {stage_file_path}") def state(self) -> TLoadJobState: return "completed" @@ -116,10 +138,12 @@ def state(self) -> TLoadJobState: def exception(self) -> str: raise NotImplementedError() -class SnowflakeStagingCopyJob(SqlStagingCopyJob): +class SnowflakeStagingCopyJob(SqlStagingCopyJob): @classmethod - def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]: + def generate_sql( + cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any] + ) -> 
List[str]: sql: List[str] = [] for table in table_chain: with sql_client.with_staging_dataset(staging=True): @@ -133,14 +157,10 @@ def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClient class SnowflakeClient(SqlJobClientWithStaging): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: SnowflakeClientConfiguration) -> None: - sql_client = SnowflakeSqlClient( - config.normalize_dataset_name(schema), - config.credentials - ) + sql_client = SnowflakeSqlClient(config.normalize_dataset_name(schema), config.credentials) super().__init__(schema, config, sql_client) self.config: SnowflakeClientConfiguration = config self.sql_client: SnowflakeSqlClient = sql_client # type: ignore @@ -151,12 +171,14 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> if not job: job = SnowflakeLoadJob( file_path, - table['name'], + table["name"], load_id, self.sql_client, stage_name=self.config.stage_name, keep_staged_files=self.config.keep_staged_files, - staging_credentials=self.config.staging_config.credentials if self.config.staging_config else None + staging_credentials=( + self.config.staging_config.credentials if self.config.staging_config else None + ), ) return job @@ -170,10 +192,18 @@ def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema]) -> List[str def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: return SnowflakeStagingCopyJob.from_table_chain(table_chain, self.sql_client) - def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool, separate_alters: bool = False) -> List[str]: + def _get_table_update_sql( + self, + table_name: str, + new_columns: Sequence[TColumnSchema], + generate_alter: bool, + separate_alters: bool = False, + ) -> List[str]: sql = super()._get_table_update_sql(table_name, new_columns, generate_alter) - cluster_list = [self.capabilities.escape_identifier(c['name']) for c in new_columns if c.get('cluster')] + cluster_list = [ + self.capabilities.escape_identifier(c["name"]) for c in new_columns if c.get("cluster") + ] if cluster_list: sql[0] = sql[0] + "\nCLUSTER BY (" + ",".join(cluster_list) + ")" @@ -192,10 +222,10 @@ def _to_db_type(cls, sc_t: TDataType) -> str: def _from_db_type(cls, bq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType: if bq_t == "NUMBER": if precision == BIGINT_PRECISION and scale == 0: - return 'bigint' + return "bigint" elif (precision, scale) == cls.capabilities.wei_precision: - return 'wei' - return 'decimal' + return "wei" + return "decimal" return SNOW_TO_SCT.get(bq_t, "text") def _get_column_def_sql(self, c: TColumnSchema) -> str: diff --git a/dlt/destinations/snowflake/sql_client.py b/dlt/destinations/snowflake/sql_client.py index 40cdc990a0..e05515d140 100644 --- a/dlt/destinations/snowflake/sql_client.py +++ b/dlt/destinations/snowflake/sql_client.py @@ -1,14 +1,24 @@ from contextlib import contextmanager, suppress -from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence, List +from typing import Any, AnyStr, ClassVar, Iterator, List, Optional, Sequence import snowflake.connector as snowflake_lib from dlt.common.destination import DestinationCapabilitiesContext -from dlt.destinations.exceptions import DatabaseTerminalException, DatabaseTransientException, DatabaseUndefinedRelation -from dlt.destinations.sql_client import DBApiCursorImpl, SqlClientBase, raise_database_error, 
raise_open_connection_error -from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction, DataFrame -from dlt.destinations.snowflake.configuration import SnowflakeCredentials +from dlt.destinations.exceptions import ( + DatabaseTerminalException, + DatabaseTransientException, + DatabaseUndefinedRelation, +) from dlt.destinations.snowflake import capabilities +from dlt.destinations.snowflake.configuration import SnowflakeCredentials +from dlt.destinations.sql_client import ( + DBApiCursorImpl, + SqlClientBase, + raise_database_error, + raise_open_connection_error, +) +from dlt.destinations.typing import DataFrame, DBApi, DBApiCursor, DBTransaction + class SnowflakeCursorImpl(DBApiCursorImpl): native_cursor: snowflake_lib.cursor.SnowflakeCursor # type: ignore[assignment] @@ -20,7 +30,6 @@ def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: class SnowflakeSqlClient(SqlClientBase[snowflake_lib.SnowflakeConnection], DBTransaction): - dbapi: ClassVar[DBApi] = snowflake_lib capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() @@ -36,8 +45,7 @@ def open_connection(self) -> snowflake_lib.SnowflakeConnection: if "timezone" not in conn_params: conn_params["timezone"] = "UTC" self._conn = snowflake_lib.connect( - schema=self.fully_qualified_dataset_name(), - **conn_params + schema=self.fully_qualified_dataset_name(), **conn_params ) return self._conn @@ -77,7 +85,9 @@ def drop_tables(self, *tables: str) -> None: with suppress(DatabaseUndefinedRelation): super().drop_tables(*tables) - def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]: + def execute_sql( + self, sql: AnyStr, *args: Any, **kwargs: Any + ) -> Optional[Sequence[Sequence[Any]]]: with self.execute_query(sql, *args, **kwargs) as curr: if curr.description is None: return None @@ -115,7 +125,7 @@ def _reset_connection(self) -> None: @classmethod def _make_database_exception(cls, ex: Exception) -> Exception: if isinstance(ex, snowflake_lib.errors.ProgrammingError): - if ex.sqlstate == 'P0000' and ex.errno == 100132: + if ex.sqlstate == "P0000" and ex.errno == 100132: # Error in a multi statement execution. 
These don't show the original error codes msg = str(ex) if "NULL result in a non-nullable column" in msg: @@ -124,11 +134,11 @@ def _make_database_exception(cls, ex: Exception) -> Exception: return DatabaseUndefinedRelation(ex) else: return DatabaseTransientException(ex) - if ex.sqlstate in {'42S02', '02000'}: + if ex.sqlstate in {"42S02", "02000"}: return DatabaseUndefinedRelation(ex) - elif ex.sqlstate == '22023': # Adding non-nullable no-default column + elif ex.sqlstate == "22023": # Adding non-nullable no-default column return DatabaseTerminalException(ex) - elif ex.sqlstate == '42000' and ex.errno == 904: # Invalid identifier + elif ex.sqlstate == "42000" and ex.errno == 904: # Invalid identifier return DatabaseTerminalException(ex) elif ex.sqlstate == "22000": return DatabaseTerminalException(ex) @@ -152,7 +162,9 @@ def _make_database_exception(cls, ex: Exception) -> Exception: return ex @staticmethod - def _maybe_make_terminal_exception_from_data_error(snowflake_ex: snowflake_lib.DatabaseError) -> Optional[Exception]: + def _maybe_make_terminal_exception_from_data_error( + snowflake_ex: snowflake_lib.DatabaseError, + ) -> Optional[Exception]: return None @staticmethod diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py index 68fb39af09..daf637b3d4 100644 --- a/dlt/destinations/sql_client.py +++ b/dlt/destinations/sql_client.py @@ -1,19 +1,29 @@ +import inspect from abc import ABC, abstractmethod from contextlib import contextmanager from functools import wraps -import inspect from types import TracebackType -from typing import Any, ClassVar, ContextManager, Generic, Iterator, Optional, Sequence, Tuple, Type, AnyStr, List +from typing import ( + Any, + AnyStr, + ClassVar, + ContextManager, + Generic, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, +) -from dlt.common.typing import TFun from dlt.common.destination import DestinationCapabilitiesContext - +from dlt.common.typing import TFun from dlt.destinations.exceptions import DestinationConnectionError, LoadClientNotConnected -from dlt.destinations.typing import DBApi, TNativeConn, DBApiCursor, DataFrame, DBTransaction +from dlt.destinations.typing import DataFrame, DBApi, DBApiCursor, DBTransaction, TNativeConn class SqlClientBase(ABC, Generic[TNativeConn]): - dbapi: ClassVar[DBApi] = None capabilities: ClassVar[DestinationCapabilitiesContext] = None @@ -45,7 +55,9 @@ def __enter__(self) -> "SqlClientBase[TNativeConn]": self.open_connection() return self - def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None: + def __exit__( + self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType + ) -> None: self.close_connection() @property @@ -78,20 +90,27 @@ def truncate_tables(self, *tables: str) -> None: def drop_tables(self, *tables: str) -> None: if not tables: return - statements = [f"DROP TABLE IF EXISTS {self.make_qualified_table_name(table)};" for table in tables] + statements = [ + f"DROP TABLE IF EXISTS {self.make_qualified_table_name(table)};" for table in tables + ] self.execute_fragments(statements) @abstractmethod - def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]: + def execute_sql( + self, sql: AnyStr, *args: Any, **kwargs: Any + ) -> Optional[Sequence[Sequence[Any]]]: pass @abstractmethod - def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> ContextManager[DBApiCursor]: + def execute_query( + self, query: AnyStr, *args: Any, **kwargs: Any + ) 
-> ContextManager[DBApiCursor]: pass - def execute_fragments(self, fragments: Sequence[AnyStr], *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]: - """Executes several SQL fragments as efficiently as possible to prevent data copying. Default implementation just joins the strings and executes them together. - """ + def execute_fragments( + self, fragments: Sequence[AnyStr], *args: Any, **kwargs: Any + ) -> Optional[Sequence[Sequence[Any]]]: + """Executes several SQL fragments as efficiently as possible to prevent data copying. Default implementation just joins the strings and executes them together.""" return self.execute_sql("".join(fragments), *args, **kwargs) # type: ignore @abstractmethod @@ -109,7 +128,9 @@ def escape_column_name(self, column_name: str, escape: bool = True) -> str: return column_name @contextmanager - def with_alternative_dataset_name(self, dataset_name: str) -> Iterator["SqlClientBase[TNativeConn]"]: + def with_alternative_dataset_name( + self, dataset_name: str + ) -> Iterator["SqlClientBase[TNativeConn]"]: """Sets the `dataset_name` as the default dataset during the lifetime of the context. Does not modify any search paths in the existing connection.""" current_dataset_name = self.dataset_name try: @@ -119,7 +140,9 @@ def with_alternative_dataset_name(self, dataset_name: str) -> Iterator["SqlClien # restore previous dataset name self.dataset_name = current_dataset_name - def with_staging_dataset(self, staging: bool = False)-> ContextManager["SqlClientBase[TNativeConn]"]: + def with_staging_dataset( + self, staging: bool = False + ) -> ContextManager["SqlClientBase[TNativeConn]"]: dataset_name = self.dataset_name if staging: dataset_name = SqlClientBase.make_staging_dataset_name(dataset_name) @@ -127,7 +150,7 @@ def with_staging_dataset(self, staging: bool = False)-> ContextManager["SqlClien def _ensure_native_conn(self) -> None: if not self.native_connection: - raise LoadClientNotConnected(type(self).__name__ , self.dataset_name) + raise LoadClientNotConnected(type(self).__name__, self.dataset_name) @staticmethod @abstractmethod @@ -156,6 +179,7 @@ def _truncate_table_sql(self, qualified_table_name: str) -> str: class DBApiCursorImpl(DBApiCursor): """A DBApi Cursor wrapper with dataframes reading functionality""" + def __init__(self, curr: DBApiCursor) -> None: self.native_cursor = curr @@ -187,7 +211,6 @@ def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: def raise_database_error(f: TFun) -> TFun: - @wraps(f) def _wrap_gen(self: SqlClientBase[Any], *args: Any, **kwargs: Any) -> Any: try: @@ -211,7 +234,6 @@ def _wrap(self: SqlClientBase[Any], *args: Any, **kwargs: Any) -> Any: def raise_open_connection_error(f: TFun) -> TFun: - @wraps(f) def _wrap(self: SqlClientBase[Any], *args: Any, **kwargs: Any) -> Any: try: diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index 769cabc571..e974263c81 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -1,8 +1,8 @@ from typing import Any, Callable, List, Sequence, Tuple, cast import yaml -from dlt.common.runtime.logger import pretty_format_exception +from dlt.common.runtime.logger import pretty_format_exception from dlt.common.schema.typing import TTableSchema from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages.load_storage import ParsedLoadJobFileName @@ -14,10 +14,13 @@ class SqlBaseJob(NewLoadJobImpl): """Sql base job for jobs that rely on the whole tablechain""" + failed_text: str = "" 
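A standalone sketch of the dataset-swapping pattern used by SqlClientBase.with_alternative_dataset_name above; TinySqlClient is a made-up stand-in, and only the try/finally restore mirrors the real method:

from contextlib import contextmanager
from typing import Iterator

class TinySqlClient:
    def __init__(self, dataset_name: str) -> None:
        self.dataset_name = dataset_name

    @contextmanager
    def with_alternative_dataset_name(self, dataset_name: str) -> Iterator["TinySqlClient"]:
        current_dataset_name = self.dataset_name
        try:
            self.dataset_name = dataset_name
            yield self
        finally:
            # restore the previous dataset name even if the body raised
            self.dataset_name = current_dataset_name

client = TinySqlClient("event_data")
with client.with_alternative_dataset_name("event_data_staging"):
    assert client.dataset_name == "event_data_staging"
assert client.dataset_name == "event_data"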
@classmethod - def from_table_chain(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> NewLoadJobImpl: + def from_table_chain( + cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any] + ) -> NewLoadJobImpl: """Generates a list of sql statements, that will be executed by the sql client when the job is executed in the loader. The `table_chain` contains a list schemas of a tables with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list). @@ -28,43 +31,63 @@ def from_table_chain(cls, table_chain: Sequence[TTableSchema], sql_client: SqlCl try: # Remove line breaks from multiline statements and write one SQL statement per line in output file # to support clients that need to execute one statement at a time (i.e. snowflake) - sql = [' '.join(stmt.splitlines()) for stmt in cls.generate_sql(table_chain, sql_client)] + sql = [ + " ".join(stmt.splitlines()) for stmt in cls.generate_sql(table_chain, sql_client) + ] job = cls(file_info.job_id(), "running") job._save_text_file("\n".join(sql)) except Exception: # return failed job - tables_str = yaml.dump(table_chain, allow_unicode=True, default_flow_style=False, sort_keys=False) + tables_str = yaml.dump( + table_chain, allow_unicode=True, default_flow_style=False, sort_keys=False + ) job = cls(file_info.job_id(), "failed", pretty_format_exception()) job._save_text_file("\n".join([cls.failed_text, tables_str])) return job @classmethod - def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]: + def generate_sql( + cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any] + ) -> List[str]: pass class SqlStagingCopyJob(SqlBaseJob): """Generates a list of sql statements that copy the data from staging dataset into destination dataset.""" + failed_text: str = "Tried to generate a staging copy sql job for the following tables:" @classmethod - def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]: + def generate_sql( + cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any] + ) -> List[str]: sql: List[str] = [] for table in table_chain: with sql_client.with_staging_dataset(staging=True): staging_table_name = sql_client.make_qualified_table_name(table["name"]) table_name = sql_client.make_qualified_table_name(table["name"]) - columns = ", ".join(map(sql_client.capabilities.escape_identifier, get_columns_names_with_prop(table, "name"))) + columns = ", ".join( + map( + sql_client.capabilities.escape_identifier, + get_columns_names_with_prop(table, "name"), + ) + ) sql.append(sql_client._truncate_table_sql(table_name)) - sql.append(f"INSERT INTO {table_name}({columns}) SELECT {columns} FROM {staging_table_name};") + sql.append( + f"INSERT INTO {table_name}({columns}) SELECT {columns} FROM {staging_table_name};" + ) return sql + class SqlMergeJob(SqlBaseJob): """Generates a list of sql statements that merge the data from staging dataset into destination dataset.""" + failed_text: str = "Tried to generate a merge sql job for the following tables:" @classmethod - def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]: + def generate_sql( + cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any] + ) -> List[str]: """Generates a list of sql statements that merge the data in staging dataset with the data in destination dataset. 
The `table_chain` contains a list schemas of a tables with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list). @@ -77,39 +100,61 @@ def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClient return cls.gen_merge_sql(table_chain, sql_client) @classmethod - def _gen_key_table_clauses(cls, primary_keys: Sequence[str], merge_keys: Sequence[str])-> List[str]: + def _gen_key_table_clauses( + cls, primary_keys: Sequence[str], merge_keys: Sequence[str] + ) -> List[str]: """Generate sql clauses to select rows to delete via merge and primary key. Return select all clause if no keys defined.""" clauses: List[str] = [] if primary_keys or merge_keys: if primary_keys: - clauses.append(" AND ".join(["%s.%s = %s.%s" % ("{d}", c, "{s}", c) for c in primary_keys])) + clauses.append( + " AND ".join(["%s.%s = %s.%s" % ("{d}", c, "{s}", c) for c in primary_keys]) + ) if merge_keys: - clauses.append(" AND ".join(["%s.%s = %s.%s" % ("{d}", c, "{s}", c) for c in merge_keys])) + clauses.append( + " AND ".join(["%s.%s = %s.%s" % ("{d}", c, "{s}", c) for c in merge_keys]) + ) return clauses or ["1=1"] @classmethod - def gen_key_table_clauses(cls, root_table_name: str, staging_root_table_name: str, key_clauses: Sequence[str], for_delete: bool) -> List[str]: + def gen_key_table_clauses( + cls, + root_table_name: str, + staging_root_table_name: str, + key_clauses: Sequence[str], + for_delete: bool, + ) -> List[str]: """Generate sql clauses that may be used to select or delete rows in root table of destination dataset - A list of clauses may be returned for engines that do not support OR in subqueries. Like BigQuery + A list of clauses may be returned for engines that do not support OR in subqueries. Like BigQuery """ - return [f"FROM {root_table_name} as d WHERE EXISTS (SELECT 1 FROM {staging_root_table_name} as s WHERE {' OR '.join([c.format(d='d',s='s') for c in key_clauses])})"] + return [ + f"FROM {root_table_name} as d WHERE EXISTS (SELECT 1 FROM {staging_root_table_name} as" + f" s WHERE {' OR '.join([c.format(d='d',s='s') for c in key_clauses])})" + ] @classmethod - def gen_delete_temp_table_sql(cls, unique_column: str, key_table_clauses: Sequence[str]) -> Tuple[List[str], str]: + def gen_delete_temp_table_sql( + cls, unique_column: str, key_table_clauses: Sequence[str] + ) -> Tuple[List[str], str]: """Generate sql that creates delete temp table and inserts `unique_column` from root table for all records to delete. May return several statements. - Returns temp table name for cases where special names are required like SQLServer. + Returns temp table name for cases where special names are required like SQLServer. 
""" sql: List[str] = [] temp_table_name = f"delete_{uniq_id()}" - sql.append(f"CREATE TEMP TABLE {temp_table_name} AS SELECT d.{unique_column} {key_table_clauses[0]};") + sql.append( + f"CREATE TEMP TABLE {temp_table_name} AS SELECT" + f" d.{unique_column} {key_table_clauses[0]};" + ) for clause in key_table_clauses[1:]: sql.append(f"INSERT INTO {temp_table_name} SELECT {unique_column} {clause};") return sql, temp_table_name @classmethod - def gen_insert_temp_table_sql(cls, staging_root_table_name: str, primary_keys: Sequence[str], unique_column: str) -> Tuple[List[str], str]: + def gen_insert_temp_table_sql( + cls, staging_root_table_name: str, primary_keys: Sequence[str], unique_column: str + ) -> Tuple[List[str], str]: sql: List[str] = [] temp_table_name = f"insert_{uniq_id()}" sql.append(f"""CREATE TEMP TABLE {temp_table_name} AS @@ -121,7 +166,9 @@ def gen_insert_temp_table_sql(cls, staging_root_table_name: str, primary_keys: S return sql, temp_table_name @classmethod - def gen_merge_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]: + def gen_merge_sql( + cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any] + ) -> List[str]: sql: List[str] = [] root_table = table_chain[0] @@ -130,22 +177,35 @@ def gen_merge_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClien with sql_client.with_staging_dataset(staging=True): staging_root_table_name = sql_client.make_qualified_table_name(root_table["name"]) # get merge and primary keys from top level - primary_keys = list(map(sql_client.capabilities.escape_identifier, get_columns_names_with_prop(root_table, "primary_key"))) - merge_keys = list(map(sql_client.capabilities.escape_identifier, get_columns_names_with_prop(root_table, "merge_key"))) + primary_keys = list( + map( + sql_client.capabilities.escape_identifier, + get_columns_names_with_prop(root_table, "primary_key"), + ) + ) + merge_keys = list( + map( + sql_client.capabilities.escape_identifier, + get_columns_names_with_prop(root_table, "merge_key"), + ) + ) key_clauses = cls._gen_key_table_clauses(primary_keys, merge_keys) unique_column: str = None root_key_column: str = None insert_temp_table_sql: str = None - if len(table_chain) == 1: - key_table_clauses = cls.gen_key_table_clauses(root_table_name, staging_root_table_name, key_clauses, for_delete=True) + key_table_clauses = cls.gen_key_table_clauses( + root_table_name, staging_root_table_name, key_clauses, for_delete=True + ) # if no child tables, just delete data from top table for clause in key_table_clauses: sql.append(f"DELETE {clause};") else: - key_table_clauses = cls.gen_key_table_clauses(root_table_name, staging_root_table_name, key_clauses, for_delete=False) + key_table_clauses = cls.gen_key_table_clauses( + root_table_name, staging_root_table_name, key_clauses, for_delete=False + ) # use unique hint to create temp table with all identifiers to delete unique_columns = get_columns_names_with_prop(root_table, "unique") if not unique_columns: @@ -153,15 +213,21 @@ def gen_merge_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClien sql_client.fully_qualified_dataset_name(), staging_root_table_name, [t["name"] for t in table_chain], - f"There is no unique column (ie _dlt_id) in top table {root_table['name']} so it is not possible to link child tables to it." 
+ f"There is no unique column (ie _dlt_id) in top table {root_table['name']} so" + " it is not possible to link child tables to it.", ) # get first unique column unique_column = sql_client.capabilities.escape_identifier(unique_columns[0]) # create temp table with unique identifier - create_delete_temp_table_sql, delete_temp_table_sql = cls.gen_delete_temp_table_sql(unique_column, key_table_clauses) + create_delete_temp_table_sql, delete_temp_table_sql = cls.gen_delete_temp_table_sql( + unique_column, key_table_clauses + ) sql.extend(create_delete_temp_table_sql) # delete top table - sql.append(f"DELETE FROM {root_table_name} WHERE {unique_column} IN (SELECT * FROM {delete_temp_table_sql});") + sql.append( + f"DELETE FROM {root_table_name} WHERE {unique_column} IN (SELECT * FROM" + f" {delete_temp_table_sql});" + ) # delete other tables for table in table_chain[1:]: table_name = sql_client.make_qualified_table_name(table["name"]) @@ -171,13 +237,20 @@ def gen_merge_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClien sql_client.fully_qualified_dataset_name(), staging_root_table_name, [t["name"] for t in table_chain], - f"There is no root foreign key (ie _dlt_root_id) in child table {table['name']} so it is not possible to refer to top level table {root_table['name']} unique column {unique_column}" + "There is no root foreign key (ie _dlt_root_id) in child table" + f" {table['name']} so it is not possible to refer to top level table" + f" {root_table['name']} unique column {unique_column}", ) root_key_column = sql_client.capabilities.escape_identifier(root_key_columns[0]) - sql.append(f"DELETE FROM {table_name} WHERE {root_key_column} IN (SELECT * FROM {delete_temp_table_sql});") + sql.append( + f"DELETE FROM {table_name} WHERE {root_key_column} IN (SELECT * FROM" + f" {delete_temp_table_sql});" + ) # create temp table used to deduplicate, only when we have primary keys if primary_keys: - create_insert_temp_table_sql, insert_temp_table_sql = cls.gen_insert_temp_table_sql(staging_root_table_name, primary_keys, unique_column) + create_insert_temp_table_sql, insert_temp_table_sql = cls.gen_insert_temp_table_sql( + staging_root_table_name, primary_keys, unique_column + ) sql.extend(create_insert_temp_table_sql) # insert from staging to dataset, truncate staging table @@ -185,8 +258,15 @@ def gen_merge_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClien table_name = sql_client.make_qualified_table_name(table["name"]) with sql_client.with_staging_dataset(staging=True): staging_table_name = sql_client.make_qualified_table_name(table["name"]) - columns = ", ".join(map(sql_client.capabilities.escape_identifier, get_columns_names_with_prop(table, "name"))) - insert_sql = f"INSERT INTO {table_name}({columns}) SELECT {columns} FROM {staging_table_name}" + columns = ", ".join( + map( + sql_client.capabilities.escape_identifier, + get_columns_names_with_prop(table, "name"), + ) + ) + insert_sql = ( + f"INSERT INTO {table_name}({columns}) SELECT {columns} FROM {staging_table_name}" + ) if len(primary_keys) > 0: if len(table_chain) == 1: insert_sql = f"""INSERT INTO {table_name}({columns}) @@ -197,7 +277,9 @@ def gen_merge_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClien SELECT {columns} FROM _dlt_dedup_numbered WHERE _dlt_dedup_rn = 1;""" else: uniq_column = unique_column if table.get("parent") is None else root_key_column - insert_sql += f" WHERE {uniq_column} IN (SELECT * FROM {insert_temp_table_sql});" + insert_sql += ( + f" WHERE {uniq_column} IN (SELECT * 
FROM {insert_temp_table_sql});" + ) if insert_sql[-1].strip() != ";": insert_sql += ";" diff --git a/dlt/destinations/typing.py b/dlt/destinations/typing.py index 7edf69d2ea..32aaa52e60 100644 --- a/dlt/destinations/typing.py +++ b/dlt/destinations/typing.py @@ -1,4 +1,5 @@ -from typing import Any, AnyStr, List, Type, Optional, Protocol, Tuple, TypeVar +from typing import Any, AnyStr, List, Optional, Protocol, Tuple, Type, TypeVar + try: from pandas import DataFrame except ImportError: @@ -7,6 +8,7 @@ # native connection TNativeConn = TypeVar("TNativeConn", bound=Any) + class DBTransaction(Protocol): def commit_transaction(self) -> None: ... @@ -23,6 +25,7 @@ class DBApi(Protocol): class DBApiCursor(Protocol): """Protocol for DBAPI cursor""" + description: Tuple[Any, ...] native_cursor: "DBApiCursor" @@ -30,12 +33,16 @@ class DBApiCursor(Protocol): def execute(self, query: AnyStr, *args: Any, **kwargs: Any) -> None: ... + def fetchall(self) -> List[Tuple[Any, ...]]: ... + def fetchmany(self, size: int = ...) -> List[Tuple[Any, ...]]: ... + def fetchone(self) -> Optional[Tuple[Any, ...]]: ... + def close(self) -> None: ... @@ -54,4 +61,3 @@ def df(self, chunk_size: int = None, **kwargs: None) -> Optional[DataFrame]: Optional[DataFrame]: A data frame with query results. If chunk_size > 0, None will be returned if there is no more data in results """ ... - diff --git a/dlt/destinations/weaviate/__init__.py b/dlt/destinations/weaviate/__init__.py index ebd87aea0c..f31c292ba6 100644 --- a/dlt/destinations/weaviate/__init__.py +++ b/dlt/destinations/weaviate/__init__.py @@ -1,16 +1,12 @@ from typing import Type -from dlt.common.schema.schema import Schema -from dlt.common.configuration import with_config, known_sections +from dlt.common.configuration import known_sections, with_config from dlt.common.configuration.accessors import config -from dlt.common.destination.reference import ( - JobClientBase, - DestinationClientConfiguration, -) from dlt.common.destination import DestinationCapabilitiesContext - -from dlt.destinations.weaviate.weaviate_adapter import weaviate_adapter +from dlt.common.destination.reference import DestinationClientConfiguration, JobClientBase +from dlt.common.schema.schema import Schema from dlt.destinations.weaviate.configuration import WeaviateClientConfiguration +from dlt.destinations.weaviate.weaviate_adapter import weaviate_adapter @with_config( diff --git a/dlt/destinations/weaviate/configuration.py b/dlt/destinations/weaviate/configuration.py index 3f4c5f0a09..3ea258c8bd 100644 --- a/dlt/destinations/weaviate/configuration.py +++ b/dlt/destinations/weaviate/configuration.py @@ -1,5 +1,5 @@ -from typing import Dict, Literal, Optional, Final from dataclasses import field +from typing import Dict, Final, Literal, Optional from urllib.parse import urlparse from dlt.common.configuration import configspec @@ -33,18 +33,20 @@ class WeaviateClientConfiguration(DestinationClientDwhConfiguration): batch_consistency: TWeaviateBatchConsistency = "ONE" batch_retries: int = 5 conn_timeout: int = 10 - read_timeout: int = 3*60 + read_timeout: int = 3 * 60 dataset_separator: str = "_" credentials: WeaviateCredentials vectorizer: str = "text2vec-openai" - module_config: Dict[str, Dict[str, str]] = field(default_factory=lambda: { - "text2vec-openai": { - "model": "ada", - "modelVersion": "002", - "type": "text", + module_config: Dict[str, Dict[str, str]] = field( + default_factory=lambda: { + "text2vec-openai": { + "model": "ada", + "modelVersion": "002", + "type": "text", + } } - }) 
+ ) def fingerprint(self) -> str: """Returns a fingerprint of host part of a connection string""" diff --git a/dlt/destinations/weaviate/exceptions.py b/dlt/destinations/weaviate/exceptions.py index 0177f64c51..b55152a8fd 100644 --- a/dlt/destinations/weaviate/exceptions.py +++ b/dlt/destinations/weaviate/exceptions.py @@ -2,4 +2,4 @@ class WeaviateBatchError(DestinationException): - pass \ No newline at end of file + pass diff --git a/dlt/destinations/weaviate/naming.py b/dlt/destinations/weaviate/naming.py index 1c43e0b58a..c6c9299d2e 100644 --- a/dlt/destinations/weaviate/naming.py +++ b/dlt/destinations/weaviate/naming.py @@ -7,11 +7,7 @@ class NamingConvention(SnakeCaseNamingConvention): """Normalizes identifiers according to Weaviate documentation: https://weaviate.io/developers/weaviate/config-refs/schema#class""" - RESERVED_PROPERTIES = { - "id": "__id", - "_id": "___id", - "_additional": "__additional" - } + RESERVED_PROPERTIES = {"id": "__id", "_id": "___id", "_additional": "__additional"} _RE_UNDERSCORES = re.compile("([^_])__+") _STARTS_DIGIT = re.compile("^[0-9]") _STARTS_NON_LETTER = re.compile("^[0-9_]") @@ -32,7 +28,10 @@ def normalize_table_identifier(self, identifier: str) -> str: identifier = BaseNamingConvention.normalize_identifier(self, identifier) norm_identifier = self._base_normalize(identifier) # norm_identifier = norm_identifier.strip("_") - norm_identifier = "".join(s[1:2].upper() + s[2:] if s and s[0] == "_" else s for s in self._SPLIT_UNDERSCORE_NON_CAP.split(norm_identifier)) + norm_identifier = "".join( + s[1:2].upper() + s[2:] if s and s[0] == "_" else s + for s in self._SPLIT_UNDERSCORE_NON_CAP.split(norm_identifier) + ) norm_identifier = norm_identifier[0].upper() + norm_identifier[1:] if self._STARTS_NON_LETTER.match(norm_identifier): norm_identifier = "C" + norm_identifier diff --git a/dlt/destinations/weaviate/weaviate_adapter.py b/dlt/destinations/weaviate/weaviate_adapter.py index 048c7e44b7..472bd861b5 100644 --- a/dlt/destinations/weaviate/weaviate_adapter.py +++ b/dlt/destinations/weaviate/weaviate_adapter.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, Literal, Set, get_args +from typing import Any, Dict, Literal, Set, get_args from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns from dlt.extract.decorators import resource as make_resource @@ -70,8 +70,7 @@ def weaviate_adapter( vectorize = [vectorize] if not isinstance(vectorize, list): raise ValueError( - "vectorize must be a list of column names or a single " - "column name as a string" + "vectorize must be a list of column names or a single column name as a string" ) # create weaviate-specific vectorize hints for column_name in vectorize: @@ -84,7 +83,10 @@ def weaviate_adapter( for column_name, method in tokenization.items(): if method not in TOKENIZATION_METHODS: allowed_methods = ", ".join(TOKENIZATION_METHODS) - raise ValueError(f"Tokenization type {method} for column {column_name} is invalid. Allowed methods are: {allowed_methods}") + raise ValueError( + f"Tokenization type {method} for column {column_name} is invalid. 
Allowed" + f" methods are: {allowed_methods}" + ) if column_name in column_hints: column_hints[column_name][TOKENIZATION_HINT] = method # type: ignore else: diff --git a/dlt/destinations/weaviate/weaviate_client.py b/dlt/destinations/weaviate/weaviate_client.py index 8d418680b1..f40b49175c 100644 --- a/dlt/destinations/weaviate/weaviate_client.py +++ b/dlt/destinations/weaviate/weaviate_client.py @@ -1,51 +1,32 @@ from functools import wraps from types import TracebackType -from typing import ( - ClassVar, - Optional, - Sequence, - List, - Dict, - Type, - Iterable, - Any, - IO, - Tuple, - cast, -) - -from dlt.common.exceptions import ( - DestinationUndefinedEntity, - DestinationTransientException, - DestinationTerminalException, -) +from typing import IO, Any, ClassVar, Dict, Iterable, List, Optional, Sequence, Tuple, Type, cast import weaviate from weaviate.gql.get import GetBuilder from weaviate.util import generate_uuid5 -from dlt.common import json, pendulum, logger -from dlt.common.typing import StrAny, TFun -from dlt.common.time import ensure_pendulum_datetime -from dlt.common.schema import Schema, TTableSchema, TSchemaTables, TTableSchemaColumns -from dlt.common.schema.typing import TColumnSchema -from dlt.common.schema.utils import get_columns_names_with_prop +from dlt.common import json, logger, pendulum +from dlt.common.data_types import TDataType from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import ( - TLoadJobState, - LoadJob, - JobClientBase, +from dlt.common.destination.reference import JobClientBase, LoadJob, TLoadJobState +from dlt.common.exceptions import ( + DestinationTerminalException, + DestinationTransientException, + DestinationUndefinedEntity, ) -from dlt.common.data_types import TDataType +from dlt.common.schema import Schema, TSchemaTables, TTableSchema, TTableSchemaColumns +from dlt.common.schema.typing import TColumnSchema +from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages import FileStorage - -from dlt.destinations.weaviate.weaviate_adapter import VECTORIZE_HINT, TOKENIZATION_HINT - -from dlt.destinations.job_impl import EmptyLoadJob +from dlt.common.time import ensure_pendulum_datetime +from dlt.common.typing import StrAny, TFun from dlt.destinations.job_client_impl import StorageSchemaInfo +from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.weaviate import capabilities from dlt.destinations.weaviate.configuration import WeaviateClientConfiguration from dlt.destinations.weaviate.exceptions import WeaviateBatchError +from dlt.destinations.weaviate.weaviate_adapter import TOKENIZATION_HINT, VECTORIZE_HINT SCT_TO_WT: Dict[TDataType, str] = { "text": "text", @@ -108,12 +89,8 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: message = errors["error"][0]["message"] # TODO: actually put the job in failed/retry state and prepare exception message with full info on failing item if "invalid" in message and "property" in message and "on class" in message: - raise DestinationTerminalException( - f"Batch failed {errors} AND WILL BE RETRIED" - ) - raise DestinationTransientException( - f"Batch failed {errors} AND WILL BE RETRIED" - ) + raise DestinationTerminalException(f"Batch failed {errors} AND WILL BE RETRIED") + raise DestinationTransientException(f"Batch failed {errors} AND WILL BE RETRIED") except Exception: raise DestinationTransientException("Batch failed AND WILL BE RETRIED") @@ -173,9 +150,7 @@ def check_batch_result(results: List[StrAny]) -> 
None: weaviate_error_retries=weaviate.WeaviateErrorRetryConf( self.client_config.batch_retries ), - consistency_level=weaviate.ConsistencyLevel[ - self.client_config.batch_consistency - ], + consistency_level=weaviate.ConsistencyLevel[self.client_config.batch_consistency], num_workers=self.client_config.batch_workers, callback=check_batch_result, ) as batch: @@ -189,9 +164,7 @@ def check_batch_result(results: List[StrAny]) -> None: if key in data: data[key] = str(ensure_pendulum_datetime(data[key])) if self.unique_identifiers: - uuid = self.generate_uuid( - data, self.unique_identifiers, self.class_name - ) + uuid = self.generate_uuid(data, self.unique_identifiers, self.class_name) else: uuid = None @@ -263,9 +236,7 @@ def make_full_name(self, table_name: str) -> str: def get_class_schema(self, table_name: str) -> Dict[str, Any]: """Get the Weaviate class schema for a table.""" - return cast( - Dict[str, Any], self.db_client.schema.get(self.make_full_name(table_name)) - ) + return cast(Dict[str, Any], self.db_client.schema.get(self.make_full_name(table_name))) def create_class( self, class_schema: Dict[str, Any], full_class_name: Optional[str] = None @@ -288,18 +259,14 @@ def create_class( self.db_client.schema.create_class(updated_schema) - def create_class_property( - self, class_name: str, prop_schema: Dict[str, Any] - ) -> None: + def create_class_property(self, class_name: str, prop_schema: Dict[str, Any]) -> None: """Create a Weaviate class property. Args: class_name: The name of the class to create the property on. prop_schema: The property schema to create. """ - self.db_client.schema.property.create( - self.make_full_name(class_name), prop_schema - ) + self.db_client.schema.property.create(self.make_full_name(class_name), prop_schema) def delete_class(self, class_name: str) -> None: """Delete a Weaviate class. @@ -405,14 +372,14 @@ def update_storage_schema( if schema_info is None: logger.info( f"Schema with hash {self.schema.stored_version_hash} " - f"not found in the storage. upgrading" + "not found in the storage. 
upgrading" ) self._execute_schema_update(only_tables) else: logger.info( f"Schema with hash {self.schema.stored_version_hash} " f"inserted at {schema_info.inserted_at} found " - f"in storage, no upgrade required" + "in storage, no upgrade required" ) return applied_update @@ -420,12 +387,8 @@ def update_storage_schema( def _execute_schema_update(self, only_tables: Iterable[str]) -> None: for table_name in only_tables or self.schema.tables: exists, existing_columns = self.get_storage_table(table_name) - new_columns = self.schema.get_new_table_columns( - table_name, existing_columns - ) - logger.info( - f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}" - ) + new_columns = self.schema.get_new_table_columns(table_name, existing_columns) + logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") if len(new_columns) > 0: if exists: for column in new_columns: @@ -559,9 +522,7 @@ def _make_non_vectorized_class_schema(self, table_name: str) -> Dict[str, Any]: }, } - def start_file_load( - self, table: TTableSchema, file_path: str, load_id: str - ) -> LoadJob: + def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: return LoadWeaviateJob( self.schema, table, diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py index 1e652c84d0..475306f2b1 100644 --- a/dlt/extract/decorators.py +++ b/dlt/extract/decorators.py @@ -1,10 +1,24 @@ -import os import inspect -from types import ModuleType +import os from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterator, List, Optional, Tuple, Type, TypeVar, Union, cast, overload - -from dlt.common.configuration import with_config, get_fun_spec, known_sections, configspec +from types import ModuleType +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Iterator, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) + +from dlt.common.configuration import configspec, get_fun_spec, known_sections, with_config from dlt.common.configuration.container import Container from dlt.common.configuration.exceptions import ContextDefaultCannotBeCreated from dlt.common.configuration.resolve import inject_section @@ -12,32 +26,42 @@ from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.exceptions import ArgumentsOverloadException from dlt.common.pipeline import PipelineContext -from dlt.common.source import _SOURCES, SourceInfo from dlt.common.schema.schema import Schema from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns, TWriteDisposition +from dlt.common.source import _SOURCES, SourceInfo from dlt.common.storages.exceptions import SchemaNotFoundError from dlt.common.storages.schema_storage import SchemaStorage -from dlt.common.typing import AnyFun, ParamSpec, Concatenate, TDataItem, TDataItems +from dlt.common.typing import AnyFun, Concatenate, ParamSpec, TDataItem, TDataItems from dlt.common.utils import get_callable_name, get_module_name, is_inner_callable -from dlt.extract.exceptions import InvalidTransformerDataTypeGeneratorFunctionRequired, ResourceFunctionExpected, ResourceInnerCallableConfigWrapDisallowed, SourceDataIsNone, SourceIsAClassTypeError, ExplicitSourceNameInvalid, SourceNotAFunction, SourceSchemaNotAvailable +from dlt.extract.exceptions import ( + ExplicitSourceNameInvalid, + InvalidTransformerDataTypeGeneratorFunctionRequired, + ResourceFunctionExpected, + ResourceInnerCallableConfigWrapDisallowed, + 
SourceDataIsNone, + SourceIsAClassTypeError, + SourceNotAFunction, + SourceSchemaNotAvailable, +) from dlt.extract.incremental import IncrementalResourceWrapper - -from dlt.extract.typing import TTableHintTemplate from dlt.extract.source import DltResource, DltSource, TUnboundDltResource - +from dlt.extract.typing import TTableHintTemplate @configspec class SourceSchemaInjectableContext(ContainerInjectableContext): """A context containing the source schema, present when decorated function is executed""" + schema: Schema can_create_default: ClassVar[bool] = False if TYPE_CHECKING: + def __init__(self, schema: Schema = None) -> None: ... + TSourceFunParams = ParamSpec("TSourceFunParams") TResourceFunParams = ParamSpec("TResourceFunParams") @@ -51,10 +75,11 @@ def source( max_table_nesting: int = None, root_key: bool = False, schema: Schema = None, - spec: Type[BaseConfiguration] = None + spec: Type[BaseConfiguration] = None, ) -> Callable[TSourceFunParams, DltSource]: ... + @overload def source( func: None = ..., @@ -64,10 +89,11 @@ def source( max_table_nesting: int = None, root_key: bool = False, schema: Schema = None, - spec: Type[BaseConfiguration] = None + spec: Type[BaseConfiguration] = None, ) -> Callable[[Callable[TSourceFunParams, Any]], Callable[TSourceFunParams, DltSource]]: ... + def source( func: Optional[AnyFun] = None, /, @@ -76,7 +102,7 @@ def source( max_table_nesting: int = None, root_key: bool = False, schema: Schema = None, - spec: Type[BaseConfiguration] = None + spec: Type[BaseConfiguration] = None, ) -> Any: """A decorator that transforms a function returning one or more `dlt resources` into a `dlt source` in order to load it with `dlt`. @@ -117,7 +143,9 @@ def source( """ if name and schema: - raise ArgumentsOverloadException("'name' has no effect when `schema` argument is present", source.__name__) + raise ArgumentsOverloadException( + "'name' has no effect when `schema` argument is present", source.__name__ + ) def decorator(f: Callable[TSourceFunParams, Any]) -> Callable[TSourceFunParams, DltSource]: nonlocal schema, name @@ -155,7 +183,11 @@ def _wrap(*args: Any, **kwargs: Any) -> DltSource: # configurations will be accessed in this section in the source proxy = Container()[PipelineContext] pipeline_name = None if not proxy.is_active() else proxy.pipeline().pipeline_name - with inject_section(ConfigSectionContext(pipeline_name=pipeline_name, sections=source_sections, source_state_key=name)): + with inject_section( + ConfigSectionContext( + pipeline_name=pipeline_name, sections=source_sections, source_state_key=name + ) + ): rv = conf_f(*args, **kwargs) if rv is None: raise SourceDataIsNone(name) @@ -172,7 +204,6 @@ def _wrap(*args: Any, **kwargs: Any) -> DltSource: s.root_key = root_key return s - # get spec for wrapped function SPEC = get_fun_spec(conf_f) # store the source information @@ -200,10 +231,11 @@ def resource( primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, selected: bool = True, - spec: Type[BaseConfiguration] = None + spec: Type[BaseConfiguration] = None, ) -> Callable[TResourceFunParams, DltResource]: ... + @overload def resource( data: None = ..., @@ -215,10 +247,11 @@ def resource( primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, selected: bool = True, - spec: Type[BaseConfiguration] = None + spec: Type[BaseConfiguration] = None, ) -> Callable[[Callable[TResourceFunParams, Any]], DltResource]: ... 
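A minimal usage sketch of the decorators whose signatures are reformatted above, assuming only the public entry points dlt.resource, dlt.source and dlt.pipeline (the "demo" names and the duckdb destination are illustrative):

import dlt

@dlt.resource(table_name="users", write_disposition="replace")
def users():
    # each yielded mapping becomes a row of the "users" table
    yield from [{"id": 1, "name": "anna"}, {"id": 2, "name": "bob"}]

@dlt.source(name="demo")
def demo_source():
    # a source groups one or more resources under a shared schema and config section
    return users()

pipeline = dlt.pipeline(pipeline_name="demo", destination="duckdb", dataset_name="demo_data")
pipeline.run(demo_source())

This mirrors the source docstring and the _wrap proxy shown above: the decorated function returns resources, and the proxy injects the pipeline's config section (ConfigSectionContext) before the source is evaluated.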
+ @overload def resource( data: Union[List[Any], Tuple[Any], Iterator[Any]], @@ -230,7 +263,7 @@ def resource( primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, selected: bool = True, - spec: Type[BaseConfiguration] = None + spec: Type[BaseConfiguration] = None, ) -> DltResource: ... @@ -303,22 +336,36 @@ def resource( Returns: DltResource instance which may be loaded, iterated or combined with other resources into a pipeline. """ - def make_resource(_name: str, _section: str, _data: Any, incremental: IncrementalResourceWrapper = None) -> DltResource: + + def make_resource( + _name: str, _section: str, _data: Any, incremental: IncrementalResourceWrapper = None + ) -> DltResource: table_template = DltResource.new_table_template( table_name or _name, write_disposition=write_disposition, columns=columns, primary_key=primary_key, - merge_key=merge_key + merge_key=merge_key, + ) + return DltResource.from_data( + _data, + _name, + _section, + table_template, + selected, + cast(DltResource, depends_on), + incremental=incremental, ) - return DltResource.from_data(_data, _name, _section, table_template, selected, cast(DltResource, depends_on), incremental=incremental) - - def decorator(f: Callable[TResourceFunParams, Any]) -> Callable[TResourceFunParams, DltResource]: + def decorator( + f: Callable[TResourceFunParams, Any] + ) -> Callable[TResourceFunParams, DltResource]: if not callable(f): if depends_on: # raise more descriptive exception if we construct transformer - raise InvalidTransformerDataTypeGeneratorFunctionRequired(name or "", f, type(f)) + raise InvalidTransformerDataTypeGeneratorFunctionRequired( + name or "", f, type(f) + ) raise ResourceFunctionExpected(name or "", f, type(f)) resource_name = name or get_callable_name(f) @@ -340,7 +387,10 @@ def decorator(f: Callable[TResourceFunParams, Any]) -> Callable[TResourceFunPara # this lets the source to override those values and provide common section for all config values for resources present in that source conf_f = with_config( incr_f, - spec=spec, sections=resource_sections, sections_merge_style=ConfigSectionContext.resource_merge_style, include_defaults=False + spec=spec, + sections=resource_sections, + sections_merge_style=ConfigSectionContext.resource_merge_style, + include_defaults=False, ) is_inner_resource = is_inner_callable(f) if conf_f != incr_f and is_inner_resource: @@ -384,10 +434,14 @@ def transformer( primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, selected: bool = True, - spec: Type[BaseConfiguration] = None -) -> Callable[[Callable[Concatenate[TDataItem, TResourceFunParams], Any]], Callable[TResourceFunParams, DltResource]]: + spec: Type[BaseConfiguration] = None, +) -> Callable[ + [Callable[Concatenate[TDataItem, TResourceFunParams], Any]], + Callable[TResourceFunParams, DltResource], +]: ... + @overload def transformer( f: Callable[Concatenate[TDataItem, TResourceFunParams], Any], @@ -400,10 +454,11 @@ def transformer( primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, selected: bool = True, - spec: Type[BaseConfiguration] = None + spec: Type[BaseConfiguration] = None, ) -> Callable[TResourceFunParams, DltResource]: ... 
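The transformer overloads that follow reuse the same hint parameters as resource; a short sketch of the transformer pattern, assuming the public @dlt.resource and @dlt.transformer decorators (the enrichment payload is illustrative):

import dlt

@dlt.resource
def user_ids():
    yield from [1, 2, 3]

@dlt.transformer(data_from=user_ids)
def user_details(user_id):
    # called with every item produced by user_ids
    yield {"id": user_id, "details": f"details-for-user-{user_id}"}

# a transformer bound via data_from evaluates like any other resource
print(list(user_details))

As the exception messages further down note, the parent can also be attached with the | operator instead of the data_from argument.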
+ def transformer( # type: ignore f: Optional[Callable[Concatenate[TDataItem, TResourceFunParams], Any]] = None, /, @@ -415,8 +470,11 @@ def transformer( # type: ignore primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, selected: bool = True, - spec: Type[BaseConfiguration] = None -) -> Callable[[Callable[Concatenate[TDataItem, TResourceFunParams], Any]], Callable[TResourceFunParams, DltResource]]: + spec: Type[BaseConfiguration] = None, +) -> Callable[ + [Callable[Concatenate[TDataItem, TResourceFunParams], Any]], + Callable[TResourceFunParams, DltResource], +]: """A form of `dlt resource` that takes input from other resources via `data_from` argument in order to enrich or transform the data. The decorated function `f` must take at least one argument of type TDataItems (a single item or list of items depending on the resource `data_from`). `dlt` will pass @@ -470,7 +528,10 @@ def transformer( # type: ignore spec (Type[BaseConfiguration], optional): A specification of configuration and secret values required by the source. """ if isinstance(f, DltResource): - raise ValueError("Please pass `data_from=` argument as keyword argument. The only positional argument to transformer is the decorated function") + raise ValueError( + "Please pass `data_from=` argument as keyword argument. The only positional argument to" + " transformer is the decorated function" + ) return resource( # type: ignore f, @@ -482,7 +543,7 @@ def transformer( # type: ignore merge_key=merge_key, selected=selected, spec=spec, - depends_on=data_from + depends_on=data_from, ) @@ -521,12 +582,14 @@ def get_source_schema() -> Schema: TDeferredFunParams = ParamSpec("TDeferredFunParams") -def defer(f: Callable[TDeferredFunParams, TBoundItems]) -> Callable[TDeferredFunParams, TDeferred[TBoundItems]]: - +def defer( + f: Callable[TDeferredFunParams, TBoundItems] +) -> Callable[TDeferredFunParams, TDeferred[TBoundItems]]: @wraps(f) def _wrap(*args: Any, **kwargs: Any) -> TDeferred[TBoundItems]: def _curry() -> TBoundItems: return f(*args, **kwargs) + return _curry return _wrap diff --git a/dlt/extract/exceptions.py b/dlt/extract/exceptions.py index 15750fa9c7..1dbaba7d2d 100644 --- a/dlt/extract/exceptions.py +++ b/dlt/extract/exceptions.py @@ -41,7 +41,11 @@ def __init__(self, pipe_name: str, has_parent: bool) -> None: self.pipe_name = pipe_name self.has_parent = has_parent if has_parent: - msg = f"A pipe created from transformer {pipe_name} is unbound or its parent is unbound or empty. Provide a resource in `data_from` argument or bind resources with | operator." + msg = ( + f"A pipe created from transformer {pipe_name} is unbound or its parent is unbound" + " or empty. Provide a resource in `data_from` argument or bind resources with |" + " operator." + ) else: msg = "Pipe is empty and does not have a resource at its head" super().__init__(pipe_name, msg) @@ -51,21 +55,37 @@ class InvalidStepFunctionArguments(PipeException): def __init__(self, pipe_name: str, func_name: str, sig: Signature, call_error: str) -> None: self.func_name = func_name self.sig = sig - super().__init__(pipe_name, f"Unable to call {func_name}: {call_error}. The mapping/filtering function {func_name} requires first argument to take data item and optional second argument named 'meta', but the signature is {sig}") + super().__init__( + pipe_name, + f"Unable to call {func_name}: {call_error}. 
The mapping/filtering function" + f" {func_name} requires first argument to take data item and optional second argument" + f" named 'meta', but the signature is {sig}", + ) class ResourceExtractionError(PipeException): def __init__(self, pipe_name: str, gen: Any, msg: str, kind: str) -> None: self.msg = msg self.kind = kind - self.func_name = gen.__name__ if isgenerator(gen) else get_callable_name(gen) if callable(gen) else str(gen) - super().__init__(pipe_name, f"extraction of resource {pipe_name} in {kind} {self.func_name} caused an exception: {msg}") + self.func_name = ( + gen.__name__ + if isgenerator(gen) + else get_callable_name(gen) if callable(gen) else str(gen) + ) + super().__init__( + pipe_name, + f"extraction of resource {pipe_name} in {kind} {self.func_name} caused an exception:" + f" {msg}", + ) class ResourceNameMissing(DltResourceException): def __init__(self) -> None: - super().__init__(None, """Resource name is missing. If you create a resource directly from data ie. from a list you must pass the name explicitly in `name` argument. - Please note that for resources created from functions or generators, the name is the function name by default.""") + super().__init__( + None, + """Resource name is missing. If you create a resource directly from data ie. from a list you must pass the name explicitly in `name` argument. + Please note that for resources created from functions or generators, the name is the function name by default.""", + ) # class DependentResourceIsNotCallable(DltResourceException): @@ -74,42 +94,78 @@ def __init__(self) -> None: class ResourceNotFoundError(DltResourceException, KeyError): - def __init__(self, resource_name: str, context: str) -> None: - self.resource_name = resource_name - super().__init__(resource_name, f"Resource with a name {resource_name} could not be found. {context}") + def __init__(self, resource_name: str, context: str) -> None: + self.resource_name = resource_name + super().__init__( + resource_name, f"Resource with a name {resource_name} could not be found. {context}" + ) class InvalidResourceDataType(DltResourceException): def __init__(self, resource_name: str, item: Any, _typ: Type[Any], msg: str) -> None: self.item = item self._typ = _typ - super().__init__(resource_name, f"Cannot create resource {resource_name} from specified data. " + msg) + super().__init__( + resource_name, f"Cannot create resource {resource_name} from specified data. " + msg + ) class InvalidResourceDataTypeAsync(InvalidResourceDataType): - def __init__(self, resource_name: str, item: Any,_typ: Type[Any]) -> None: - super().__init__(resource_name, item, _typ, "Async iterators and generators are not valid resources. Please use standard iterators and generators that yield Awaitables instead (for example by yielding from async function without await") + def __init__(self, resource_name: str, item: Any, _typ: Type[Any]) -> None: + super().__init__( + resource_name, + item, + _typ, + "Async iterators and generators are not valid resources. Please use standard iterators" + " and generators that yield Awaitables instead (for example by yielding from async" + " function without await", + ) class InvalidResourceDataTypeBasic(InvalidResourceDataType): - def __init__(self, resource_name: str, item: Any,_typ: Type[Any]) -> None: - super().__init__(resource_name, item, _typ, f"Resources cannot be strings or dictionaries but {_typ.__name__} was provided. Please pass your data in a list or as a function yielding items. 
If you want to process just one data item, enclose it in a list.") + def __init__(self, resource_name: str, item: Any, _typ: Type[Any]) -> None: + super().__init__( + resource_name, + item, + _typ, + f"Resources cannot be strings or dictionaries but {_typ.__name__} was provided. Please" + " pass your data in a list or as a function yielding items. If you want to process" + " just one data item, enclose it in a list.", + ) class InvalidResourceDataTypeFunctionNotAGenerator(InvalidResourceDataType): - def __init__(self, resource_name: str, item: Any,_typ: Type[Any]) -> None: - super().__init__(resource_name, item, _typ, "Please make sure that function decorated with @dlt.resource uses 'yield' to return the data.") + def __init__(self, resource_name: str, item: Any, _typ: Type[Any]) -> None: + super().__init__( + resource_name, + item, + _typ, + "Please make sure that function decorated with @dlt.resource uses 'yield' to return the" + " data.", + ) class InvalidResourceDataTypeMultiplePipes(InvalidResourceDataType): - def __init__(self, resource_name: str, item: Any,_typ: Type[Any]) -> None: - super().__init__(resource_name, item, _typ, "Resources with multiple parallel data pipes are not yet supported. This problem most often happens when you are creating a source with @dlt.source decorator that has several resources with the same name.") + def __init__(self, resource_name: str, item: Any, _typ: Type[Any]) -> None: + super().__init__( + resource_name, + item, + _typ, + "Resources with multiple parallel data pipes are not yet supported. This problem most" + " often happens when you are creating a source with @dlt.source decorator that has" + " several resources with the same name.", + ) class InvalidTransformerDataTypeGeneratorFunctionRequired(InvalidResourceDataType): def __init__(self, resource_name: str, item: Any, _typ: Type[Any]) -> None: - super().__init__(resource_name, item, _typ, - "Transformer must be a function decorated with @dlt.transformer that takes data item as its first argument. Only first argument may be 'positional only'.") + super().__init__( + resource_name, + item, + _typ, + "Transformer must be a function decorated with @dlt.transformer that takes data item as" + " its first argument. Only first argument may be 'positional only'.", + ) class InvalidTransformerGeneratorFunction(DltResourceException): @@ -131,29 +187,57 @@ def __init__(self, resource_name: str, func_name: str, sig: Signature, code: int class ResourceInnerCallableConfigWrapDisallowed(DltResourceException): def __init__(self, resource_name: str, section: str) -> None: self.section = section - msg = f"Resource {resource_name} in section {section} is defined over an inner function and requests config/secrets in its arguments. Requesting secret and config values via 'dlt.secrets.values' or 'dlt.config.value' is disallowed for resources that are inner functions. Use the dlt.source to get the required configuration and pass them explicitly to your source." + msg = ( + f"Resource {resource_name} in section {section} is defined over an inner function and" + " requests config/secrets in its arguments. Requesting secret and config values via" + " 'dlt.secrets.values' or 'dlt.config.value' is disallowed for resources that are" + " inner functions. Use the dlt.source to get the required configuration and pass them" + " explicitly to your source." 
+ ) super().__init__(resource_name, msg) class InvalidResourceDataTypeIsNone(InvalidResourceDataType): def __init__(self, resource_name: str, item: Any, _typ: Type[Any]) -> None: - super().__init__(resource_name, item, _typ, "Resource data missing. Did you forget the return statement in @dlt.resource decorated function?") + super().__init__( + resource_name, + item, + _typ, + "Resource data missing. Did you forget the return statement in @dlt.resource decorated" + " function?", + ) class ResourceFunctionExpected(InvalidResourceDataType): def __init__(self, resource_name: str, item: Any, _typ: Type[Any]) -> None: - super().__init__(resource_name, item, _typ, f"Expected function or callable as first parameter to resource {resource_name} but {_typ.__name__} found. Please decorate a function with @dlt.resource") + super().__init__( + resource_name, + item, + _typ, + f"Expected function or callable as first parameter to resource {resource_name} but" + f" {_typ.__name__} found. Please decorate a function with @dlt.resource", + ) class InvalidParentResourceDataType(InvalidResourceDataType): - def __init__(self, resource_name: str, item: Any,_typ: Type[Any]) -> None: - super().__init__(resource_name, item, _typ, f"A parent resource of {resource_name} is of type {_typ.__name__}. Did you forget to use '@dlt.resource` decorator or `resource` function?") + def __init__(self, resource_name: str, item: Any, _typ: Type[Any]) -> None: + super().__init__( + resource_name, + item, + _typ, + f"A parent resource of {resource_name} is of type {_typ.__name__}. Did you forget to" + " use '@dlt.resource` decorator or `resource` function?", + ) class InvalidParentResourceIsAFunction(DltResourceException): def __init__(self, resource_name: str, func_name: str) -> None: self.func_name = func_name - super().__init__(resource_name, f"A data source {func_name} of a transformer {resource_name} is an undecorated function. Please decorate it with '@dlt.resource' or pass to 'resource' function.") + super().__init__( + resource_name, + f"A data source {func_name} of a transformer {resource_name} is an undecorated" + " function. Please decorate it with '@dlt.resource' or pass to 'resource' function.", + ) class DeletingResourcesNotSupported(DltResourceException): @@ -162,10 +246,16 @@ def __init__(self, source_name: str, resource_name: str) -> None: class ParametrizedResourceUnbound(DltResourceException): - def __init__(self, resource_name: str, func_name: str, sig: Signature, kind: str, error: str) -> None: + def __init__( + self, resource_name: str, func_name: str, sig: Signature, kind: str, error: str + ) -> None: self.func_name = func_name self.sig = sig - msg = f"The {kind} {resource_name} is parametrized and expects following arguments: {sig}. Did you forget to bind the {func_name} function? For example from `source.{resource_name}.bind(...)" + msg = ( + f"The {kind} {resource_name} is parametrized and expects following arguments: {sig}." + f" Did you forget to bind the {func_name} function? For example from" + f" `source.{resource_name}.bind(...)" + ) if error: msg += f" .Details: {error}" super().__init__(resource_name, msg) @@ -178,7 +268,9 @@ def __init__(self, resource_name: str, msg: str) -> None: class TableNameMissing(DltSourceException): def __init__(self) -> None: - super().__init__("""Table name is missing in table template. Please provide a string or a function that takes a data item as an argument""") + super().__init__( + """Table name is missing in table template. 
Please provide a string or a function that takes a data item as an argument""" + ) class InconsistentTableTemplate(DltSourceException): @@ -189,29 +281,43 @@ def __init__(self, reason: str) -> None: class DataItemRequiredForDynamicTableHints(DltResourceException): def __init__(self, resource_name: str) -> None: - super().__init__(resource_name, f"""An instance of resource's data required to generate table schema in resource {resource_name}. - One of table hints for that resource (typically table name) is a function and hint is computed separately for each instance of data extracted from that resource.""") + super().__init__( + resource_name, + f"""An instance of resource's data required to generate table schema in resource {resource_name}. + One of table hints for that resource (typically table name) is a function and hint is computed separately for each instance of data extracted from that resource.""", + ) class SourceDataIsNone(DltSourceException): def __init__(self, source_name: str) -> None: self.source_name = source_name - super().__init__(f"No data returned or yielded from source function {source_name}. Did you forget the return statement?") + super().__init__( + f"No data returned or yielded from source function {source_name}. Did you forget the" + " return statement?" + ) class SourceExhausted(DltSourceException): def __init__(self, source_name: str) -> None: self.source_name = source_name - super().__init__(f"Source {source_name} is exhausted or has active iterator. You can iterate or pass the source to dlt pipeline only once.") + super().__init__( + f"Source {source_name} is exhausted or has active iterator. You can iterate or pass the" + " source to dlt pipeline only once." + ) class ResourcesNotFoundError(DltSourceException): - def __init__(self, source_name: str, available_resources: Set[str], requested_resources: Set[str]) -> None: + def __init__( + self, source_name: str, available_resources: Set[str], requested_resources: Set[str] + ) -> None: self.source_name = source_name self.available_resources = available_resources self.requested_resources = requested_resources self.not_found_resources = requested_resources.difference(available_resources) - msg = f"The following resources could not be found in source {source_name}: {self.not_found_resources}. Available resources are: {available_resources}" + msg = ( + f"The following resources could not be found in source {source_name}:" + f" {self.not_found_resources}. Available resources are: {available_resources}" + ) super().__init__(msg) @@ -220,28 +326,48 @@ def __init__(self, source_name: str, item: Any, _typ: Type[Any]) -> None: self.source_name = source_name self.item = item self.typ = _typ - super().__init__(f"First parameter to the source {source_name} must be a function or callable but is {_typ.__name__}. Please decorate a function with @dlt.source") + super().__init__( + f"First parameter to the source {source_name} must be a function or callable but is" + f" {_typ.__name__}. Please decorate a function with @dlt.source" + ) class SourceIsAClassTypeError(DltSourceException): - def __init__(self, source_name: str, _typ: Type[Any]) -> None: + def __init__(self, source_name: str, _typ: Type[Any]) -> None: self.source_name = source_name self.typ = _typ - super().__init__(f"First parameter to the source {source_name} is a class {_typ.__name__}. Do not decorate classes with @dlt.source. 
Instead implement __call__ in your class and pass instance of such class to dlt.source() directly") + super().__init__( + f"First parameter to the source {source_name} is a class {_typ.__name__}. Do not" + " decorate classes with @dlt.source. Instead implement __call__ in your class and pass" + " instance of such class to dlt.source() directly" + ) class SourceSchemaNotAvailable(DltSourceException): def __init__(self) -> None: - super().__init__("Current source schema is available only when called from a function decorated with dlt.source or dlt.resource") + super().__init__( + "Current source schema is available only when called from a function decorated with" + " dlt.source or dlt.resource" + ) class ExplicitSourceNameInvalid(DltSourceException): def __init__(self, source_name: str, schema_name: str) -> None: self.source_name = source_name self.schema_name = schema_name - super().__init__(f"Your explicit source name {source_name} is not a valid schema name. Please use a valid schema name ie. '{schema_name}'.") + super().__init__( + f"Your explicit source name {source_name} is not a valid schema name. Please use a" + f" valid schema name ie. '{schema_name}'." + ) class IncrementalUnboundError(DltResourceException): def __init__(self, cursor_path: str) -> None: - super().__init__("", f"The incremental definition with cursor path {cursor_path} is used without being bound to the resource. This most often happens when you create dynamic resource from a generator function that uses incremental. See https://dlthub.com/docs/general-usage/incremental-loading#incremental-loading-with-last-value for an example.") + super().__init__( + "", + f"The incremental definition with cursor path {cursor_path} is used without being bound" + " to the resource. This most often happens when you create dynamic resource from a" + " generator function that uses incremental. 
See" + " https://dlthub.com/docs/general-usage/incremental-loading#incremental-loading-with-last-value" + " for an example.", + ) diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py index d060c080d1..a10c92b234 100644 --- a/dlt/extract/extract.py +++ b/dlt/extract/extract.py @@ -4,17 +4,15 @@ from dlt.common.configuration.container import Container from dlt.common.configuration.resolve import inject_section +from dlt.common.configuration.specs import known_sections from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.pipeline import _reset_resource_state - from dlt.common.runtime import signals -from dlt.common.runtime.collector import Collector, NULL_COLLECTOR +from dlt.common.runtime.collector import NULL_COLLECTOR, Collector +from dlt.common.schema import Schema, TSchemaUpdate, utils +from dlt.common.storages import DataItemStorage, NormalizeStorage, NormalizeStorageConfiguration +from dlt.common.typing import TDataItem, TDataItems from dlt.common.utils import uniq_id -from dlt.common.typing import TDataItems, TDataItem -from dlt.common.schema import Schema, utils, TSchemaUpdate -from dlt.common.storages import NormalizeStorageConfiguration, NormalizeStorage, DataItemStorage -from dlt.common.configuration.specs import known_sections - from dlt.extract.decorators import SourceSchemaInjectableContext from dlt.extract.exceptions import DataItemRequiredForDynamicTableHints from dlt.extract.pipe import PipeIterator @@ -64,9 +62,8 @@ def extract( *, max_parallel_items: int = None, workers: int = None, - futures_poll_interval: float = None + futures_poll_interval: float = None, ) -> TSchemaUpdate: - dynamic_tables: TSchemaUpdate = {} schema = source.schema resources_with_items: Set[str] = set() @@ -111,11 +108,15 @@ def _write_static_table(resource: DltResource, table_name: str) -> None: dynamic_tables[table_name] = [static_table] # yield from all selected pipes - with PipeIterator.from_pipes(source.resources.selected_pipes, max_parallel_items=max_parallel_items, workers=workers, futures_poll_interval=futures_poll_interval) as pipes: + with PipeIterator.from_pipes( + source.resources.selected_pipes, + max_parallel_items=max_parallel_items, + workers=workers, + futures_poll_interval=futures_poll_interval, + ) as pipes: left_gens = total_gens = len(pipes._sources) collector.update("Resources", 0, total_gens) for pipe_item in pipes: - curr_gens = len(pipes._sources) if left_gens > curr_gens: delta = left_gens - curr_gens @@ -174,24 +175,35 @@ def extract_with_schema( schema: Schema, collector: Collector, max_parallel_items: int, - workers: int + workers: int, ) -> str: # generate extract_id to be able to commit all the sources together later extract_id = storage.create_extract_id() with Container().injectable_context(SourceSchemaInjectableContext(schema)): # inject the config section with the current source name - with inject_section(ConfigSectionContext(sections=(known_sections.SOURCES, source.section, source.name), source_state_key=source.name)): + with inject_section( + ConfigSectionContext( + sections=(known_sections.SOURCES, source.section, source.name), + source_state_key=source.name, + ) + ): # reset resource states for resource in source.resources.extracted.values(): with contextlib.suppress(DataItemRequiredForDynamicTableHints): if resource.write_disposition == "replace": _reset_resource_state(resource._name) - extractor = extract(extract_id, source, storage, collector, max_parallel_items=max_parallel_items, workers=workers) + 
extractor = extract( + extract_id, + source, + storage, + collector, + max_parallel_items=max_parallel_items, + workers=workers, + ) # iterate over all items in the pipeline and update the schema if dynamic table hints were present for _, partials in extractor.items(): for partial in partials: schema.update_schema(schema.normalize_table_identifiers(partial)) return extract_id - diff --git a/dlt/extract/incremental.py b/dlt/extract/incremental.py index 049e2c947d..9d33ac2342 100644 --- a/dlt/extract/incremental.py +++ b/dlt/extract/incremental.py @@ -1,26 +1,47 @@ -import os -from typing import Generic, TypeVar, Any, Optional, Callable, List, TypedDict, get_args, get_origin, Sequence, Type import inspect -from functools import wraps +import os from datetime import datetime # noqa: I251 +from functools import wraps +from typing import ( + Any, + Callable, + Generic, + List, + Optional, + Sequence, + Type, + TypedDict, + TypeVar, + get_args, + get_origin, +) import dlt -from dlt.common import pendulum, logger -from dlt.common.json import json -from dlt.common.jsonpath import compile_path, find_values, JSONPath -from dlt.common.typing import TDataItem, TDataItems, TFun, extract_inner_type, get_generic_type_argument_from_instance, is_optional_type -from dlt.common.schema.typing import TColumnNames -from dlt.common.configuration import configspec, ConfigurationValueError +from dlt.common import logger, pendulum +from dlt.common.configuration import ConfigurationValueError, configspec from dlt.common.configuration.specs import BaseConfiguration +from dlt.common.data_types.type_helpers import ( + coerce_from_date_types, + coerce_value, + py_type_to_sc_type, +) +from dlt.common.json import json +from dlt.common.jsonpath import JSONPath, compile_path, find_values from dlt.common.pipeline import resource_state +from dlt.common.schema.typing import TColumnNames +from dlt.common.typing import ( + TDataItem, + TDataItems, + TFun, + extract_inner_type, + get_generic_type_argument_from_instance, + is_optional_type, +) from dlt.common.utils import digest128 -from dlt.common.data_types.type_helpers import coerce_from_date_types, coerce_value, py_type_to_sc_type - from dlt.extract.exceptions import IncrementalUnboundError, PipeException from dlt.extract.pipe import Pipe -from dlt.extract.utils import resolve_column_value from dlt.extract.typing import FilterItem, SupportsPipe, TTableHintTemplate - +from dlt.extract.utils import resolve_column_value TCursorValue = TypeVar("TCursorValue", bound=Any) LastValueFunc = Callable[[Sequence[TCursorValue]], Any] @@ -36,7 +57,11 @@ class IncrementalCursorPathMissing(PipeException): def __init__(self, pipe_name: str, json_path: str, item: TDataItem) -> None: self.json_path = json_path self.item = item - msg = f"Cursor element with JSON path {json_path} was not found in extracted data item. All data items must contain this path. Use the same names of fields as in your JSON document - if those are different from the names you see in database." + msg = ( + f"Cursor element with JSON path {json_path} was not found in extracted data item. All" + " data items must contain this path. Use the same names of fields as in your JSON" + " document - if those are different from the names you see in database." 
+ ) super().__init__(pipe_name, msg) @@ -44,7 +69,11 @@ class IncrementalPrimaryKeyMissing(PipeException): def __init__(self, pipe_name: str, primary_key_column: str, item: TDataItem) -> None: self.primary_key_column = primary_key_column self.item = item - msg = f"Primary key column {primary_key_column} was not found in extracted data item. All data items must contain this column. Use the same names of fields as in your JSON document." + msg = ( + f"Primary key column {primary_key_column} was not found in extracted data item. All" + " data items must contain this column. Use the same names of fields as in your JSON" + " document." + ) super().__init__(pipe_name, msg) @@ -86,19 +115,20 @@ class Incremental(FilterItem, BaseConfiguration, Generic[TCursorValue]): The values passed explicitly to Incremental will be ignored. Note that if logical "end date" is present then also "end_value" will be set which means that resource state is not used and exactly this range of date will be loaded """ + cursor_path: str = None # TODO: Support typevar here initial_value: Optional[Any] = None end_value: Optional[Any] = None def __init__( - self, - cursor_path: str = dlt.config.value, - initial_value: Optional[TCursorValue]=None, - last_value_func: Optional[LastValueFunc[TCursorValue]]=max, - primary_key: Optional[TTableHintTemplate[TColumnNames]] = None, - end_value: Optional[TCursorValue] = None, - allow_external_schedulers: bool = False + self, + cursor_path: str = dlt.config.value, + initial_value: Optional[TCursorValue] = None, + last_value_func: Optional[LastValueFunc[TCursorValue]] = max, + primary_key: Optional[TTableHintTemplate[TColumnNames]] = None, + end_value: Optional[TCursorValue] = None, + allow_external_schedulers: bool = False, ) -> None: self.cursor_path = cursor_path if self.cursor_path: @@ -123,7 +153,9 @@ def __init__( """Becomes true on the first item that is out of range of `initial_value`. I.e. 
when using `max` this is a value that is lower than `initial_value`""" @classmethod - def from_existing_state(cls, resource_name: str, cursor_path: str) -> "Incremental[TCursorValue]": + def from_existing_state( + cls, resource_name: str, cursor_path: str + ) -> "Incremental[TCursorValue]": """Create Incremental instance from existing state.""" state = Incremental._get_state(resource_name, cursor_path) i = cls(cursor_path, state["initial_value"]) @@ -139,7 +171,7 @@ def copy(self) -> "Incremental[TCursorValue]": last_value_func=self.last_value_func, primary_key=self.primary_key, end_value=self.end_value, - allow_external_schedulers=self.allow_external_schedulers + allow_external_schedulers=self.allow_external_schedulers, ) def merge(self, other: "Incremental[TCursorValue]") -> "Incremental[TCursorValue]": @@ -154,32 +186,45 @@ def merge(self, other: "Incremental[TCursorValue]") -> "Incremental[TCursorValue """ kwargs = dict(self, last_value_func=self.last_value_func, primary_key=self.primary_key) for key, value in dict( - other, - last_value_func=other.last_value_func, primary_key=other.primary_key).items(): + other, last_value_func=other.last_value_func, primary_key=other.primary_key + ).items(): if value is not None: kwargs[key] = value # preserve Generic param information if hasattr(self, "__orig_class__"): constructor = self.__orig_class__ else: - constructor = other.__orig_class__ if hasattr(other, "__orig_class__") else other.__class__ + constructor = ( + other.__orig_class__ if hasattr(other, "__orig_class__") else other.__class__ + ) return constructor(**kwargs) # type: ignore def on_resolved(self) -> None: self.cursor_path_p = compile_path(self.cursor_path) if self.end_value is not None and self.initial_value is None: raise ConfigurationValueError( - "Incremental 'end_value' was specified without 'initial_value'. 'initial_value' is required when using 'end_value'." + "Incremental 'end_value' was specified without 'initial_value'. 'initial_value' is" + " required when using 'end_value'." ) # Ensure end value is "higher" than initial value - if self.end_value is not None and self.last_value_func([self.end_value, self.initial_value]) != self.end_value: + if ( + self.end_value is not None + and self.last_value_func([self.end_value, self.initial_value]) != self.end_value + ): if self.last_value_func in (min, max): - adject = 'higher' if self.last_value_func is max else 'lower' - msg = f"Incremental 'initial_value' ({self.initial_value}) is {adject} than 'end_value` ({self.end_value}). 'end_value' must be {adject} than 'initial_value'" + adject = "higher" if self.last_value_func is max else "lower" + msg = ( + f"Incremental 'initial_value' ({self.initial_value}) is {adject} than" + f" 'end_value` ({self.end_value}). 'end_value' must be {adject} than" + " 'initial_value'" + ) else: msg = ( - f"Incremental 'initial_value' ({self.initial_value}) is greater than 'end_value' ({self.end_value}) as determined by the custom 'last_value_func'. " - f"The result of '{self.last_value_func.__name__}([end_value, initial_value])' must equal 'end_value'" + f"Incremental 'initial_value' ({self.initial_value}) is greater than" + f" 'end_value' ({self.end_value}) as determined by the custom" + " 'last_value_func'. The result of" + f" '{self.last_value_func.__name__}([end_value, initial_value])' must equal" + " 'end_value'" ) raise ConfigurationValueError(msg) @@ -205,9 +250,9 @@ def get_state(self) -> IncrementalColumnState: if self.end_value is not None: # End value uses mock state. 
We don't want to write it. return { - 'initial_value': self.initial_value, - 'last_value': self.initial_value, - 'unique_hashes': [] + "initial_value": self.initial_value, + "last_value": self.initial_value, + "unique_hashes": [], } self._cached_state = Incremental._get_state(self.resource_name, self.cursor_path) @@ -217,26 +262,30 @@ def get_state(self) -> IncrementalColumnState: { "initial_value": self.initial_value, "last_value": self.initial_value, - 'unique_hashes': [] + "unique_hashes": [], } ) return self._cached_state @staticmethod def _get_state(resource_name: str, cursor_path: str) -> IncrementalColumnState: - state: IncrementalColumnState = resource_state(resource_name).setdefault('incremental', {}).setdefault(cursor_path, {}) + state: IncrementalColumnState = ( + resource_state(resource_name).setdefault("incremental", {}).setdefault(cursor_path, {}) + ) # if state params is empty return state @property def last_value(self) -> Optional[TCursorValue]: s = self.get_state() - return s['last_value'] # type: ignore + return s["last_value"] # type: ignore def unique_value(self, row: TDataItem) -> str: try: if self.primary_key: - return digest128(json.dumps(resolve_column_value(self.primary_key, row), sort_keys=True)) + return digest128( + json.dumps(resolve_column_value(self.primary_key, row), sort_keys=True) + ) elif self.primary_key is None: return digest128(json.dumps(row, sort_keys=True)) else: @@ -259,33 +308,36 @@ def transform(self, row: TDataItem) -> bool: row_value = pendulum.instance(row_value) incremental_state = self._cached_state - last_value = incremental_state['last_value'] + last_value = incremental_state["last_value"] last_value_func = self.last_value_func # Check whether end_value has been reached # Filter end value ranges exclusively, so in case of "max" function we remove values >= end_value if self.end_value is not None and ( - last_value_func((row_value, self.end_value)) != self.end_value or last_value_func((row_value, )) == self.end_value + last_value_func((row_value, self.end_value)) != self.end_value + or last_value_func((row_value,)) == self.end_value ): self.end_out_of_range = True return False - check_values = (row_value,) + ((last_value, ) if last_value is not None else ()) + check_values = (row_value,) + ((last_value,) if last_value is not None else ()) new_value = last_value_func(check_values) if last_value == new_value: - processed_row_value = last_value_func((row_value, )) + processed_row_value = last_value_func((row_value,)) # we store row id for all records with the current "last_value" in state and use it to deduplicate if processed_row_value == last_value: unique_value = self.unique_value(row) # if unique value exists then use it to deduplicate if unique_value: - if unique_value in incremental_state['unique_hashes']: + if unique_value in incremental_state["unique_hashes"]: return False # add new hash only if the record row id is same as current last value - incremental_state['unique_hashes'].append(unique_value) + incremental_state["unique_hashes"].append(unique_value) return True # skip the record that is not a last_value or new_value: that record was already processed - check_values = (row_value,) + ((self.start_value,) if self.start_value is not None else ()) + check_values = (row_value,) + ( + (self.start_value,) if self.start_value is not None else () + ) new_value = last_value_func(check_values) # Include rows == start_value but exclude "lower" if new_value == self.start_value and processed_row_value != self.start_value: @@ -307,8 +359,8 @@ 
def get_incremental_value_type(self) -> Type[Any]: def _join_external_scheduler(self) -> None: """Detects existence of external scheduler from which `start_value` and `end_value` are taken. Detects Airflow and environment variables. - The logical "start date" coming from external scheduler will set the `initial_value` in incremental. if additionally logical "end date" is - present then also "end_value" will be set which means that resource state is not used and exactly this range of date will be loaded + The logical "start date" coming from external scheduler will set the `initial_value` in incremental. if additionally logical "end date" is + present then also "end_value" will be set which means that resource state is not used and exactly this range of date will be loaded """ # fit the pendulum into incremental type param_type = self.get_incremental_value_type() @@ -317,14 +369,22 @@ def _join_external_scheduler(self) -> None: if param_type is not Any: data_type = py_type_to_sc_type(param_type) except Exception as ex: - logger.warning(f"Specified Incremental last value type {param_type} is not supported. Please use DateTime, Date, float, int or str to join external schedulers.({ex})") + logger.warning( + f"Specified Incremental last value type {param_type} is not supported. Please use" + f" DateTime, Date, float, int or str to join external schedulers.({ex})" + ) if param_type is Any: - logger.warning("Could not find the last value type of Incremental class participating in external schedule. " - "Please add typing when declaring incremental argument in your resource or pass initial_value from which the type can be inferred.") + logger.warning( + "Could not find the last value type of Incremental class participating in external" + " schedule. Please add typing when declaring incremental argument in your resource" + " or pass initial_value from which the type can be inferred." 
+ ) return - def _ensure_airflow_end_date(start_date: pendulum.DateTime, end_date: pendulum.DateTime) -> Optional[pendulum.DateTime]: + def _ensure_airflow_end_date( + start_date: pendulum.DateTime, end_date: pendulum.DateTime + ) -> Optional[pendulum.DateTime]: """if end_date is in the future or same as start date (manual run), set it to None so dlt state is used for incremental loading""" now = pendulum.now() if end_date is None or end_date > now or start_date == end_date: @@ -334,6 +394,7 @@ def _ensure_airflow_end_date(start_date: pendulum.DateTime, end_date: pendulum.D try: # we can move it to separate module when we have more of those from airflow.operators.python import get_current_context # noqa + context = get_current_context() start_date = context["data_interval_start"] end_date = _ensure_airflow_end_date(start_date, context["data_interval_end"]) @@ -342,10 +403,17 @@ def _ensure_airflow_end_date(start_date: pendulum.DateTime, end_date: pendulum.D self.end_value = coerce_from_date_types(data_type, end_date) else: self.end_value = None - logger.info(f"Found Airflow scheduler: initial value: {self.initial_value} from data_interval_start {context['data_interval_start']}, end value: {self.end_value} from data_interval_end {context['data_interval_end']}") + logger.info( + f"Found Airflow scheduler: initial value: {self.initial_value} from" + f" data_interval_start {context['data_interval_start']}, end value:" + f" {self.end_value} from data_interval_end {context['data_interval_end']}" + ) return except TypeError as te: - logger.warning(f"Could not coerce Airflow execution dates into the last value type {param_type}. ({te})") + logger.warning( + f"Could not coerce Airflow execution dates into the last value type {param_type}." + f" ({te})" + ) except Exception: pass @@ -368,20 +436,28 @@ def bind(self, pipe: SupportsPipe) -> "Incremental[TCursorValue]": self._join_external_scheduler() # set initial value from last value, in case of a new state those are equal self.start_value = self.last_value - logger.info(f"Bind incremental on {self.resource_name} with initial_value: {self.initial_value}, start_value: {self.start_value}, end_value: {self.end_value}") + logger.info( + f"Bind incremental on {self.resource_name} with initial_value: {self.initial_value}," + f" start_value: {self.start_value}, end_value: {self.end_value}" + ) # cache state self._cached_state = self.get_state() return self def __str__(self) -> str: - return f"Incremental at {id(self)} for resource {self.resource_name} with cursor path: {self.cursor_path} initial {self.initial_value} lv_func {self.last_value_func}" + return ( + f"Incremental at {id(self)} for resource {self.resource_name} with cursor path:" + f" {self.cursor_path} initial {self.initial_value} lv_func {self.last_value_func}" + ) class IncrementalResourceWrapper(FilterItem): _incremental: Optional[Incremental[Any]] = None """Keeps the injectable incremental""" - def __init__(self, resource_name: str, primary_key: Optional[TTableHintTemplate[TColumnNames]] = None) -> None: + def __init__( + self, resource_name: str, primary_key: Optional[TTableHintTemplate[TColumnNames]] = None + ) -> None: """Creates a wrapper over a resource function that accepts Incremental instance in its argument to perform incremental loading. The wrapper delays instantiation of the Incremental to the moment of actual execution and is currently used by `dlt.resource` decorator. 
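A sketch of how the Incremental machinery above is typically consumed from a resource, assuming the public alias dlt.sources.incremental for the Incremental class; fetch_events is a hypothetical stand-in for a real API call:

import dlt

def fetch_events(since):
    # hypothetical helper: would return only rows changed after `since`
    return [{"id": 1, "updated_at": "2023-06-01T00:00:00Z"}]

@dlt.resource(primary_key="id")
def events(
    updated_at=dlt.sources.incremental("updated_at", initial_value="2023-01-01T00:00:00Z")
):
    # updated_at.last_value returns the cursor stored in resource state
    # (see Incremental.last_value above)
    yield from fetch_events(since=updated_at.last_value)

dlt.pipeline(pipeline_name="events_demo", destination="duckdb").run(events)

On later runs IncrementalResourceWrapper rebuilds the Incremental from resource state, so with the default last_value_func=max transform() drops rows whose cursor falls below the stored last value and deduplicates rows equal to it via the primary-key hash.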
@@ -407,14 +483,15 @@ def get_incremental_arg(sig: inspect.Signature) -> Optional[inspect.Parameter]: for p in sig.parameters.values(): annotation = extract_inner_type(p.annotation) annotation = get_origin(annotation) or annotation - if (inspect.isclass(annotation) and issubclass(annotation, Incremental)) or isinstance(p.default, Incremental): + if (inspect.isclass(annotation) and issubclass(annotation, Incremental)) or isinstance( + p.default, Incremental + ): incremental_param = p break return incremental_param def wrap(self, sig: inspect.Signature, func: TFun) -> TFun: - """Wrap the callable to inject an `Incremental` object configured for the resource. - """ + """Wrap the callable to inject an `Incremental` object configured for the resource.""" incremental_param = self.get_incremental_arg(sig) assert incremental_param, "Please use `should_wrap` to decide if to call this function" @@ -447,7 +524,11 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: return func(*bound_args.args, **bound_args.kwargs) raise ValueError(f"{p.name} Incremental has no default") # pass Generic information from annotation to new_incremental - if not hasattr(new_incremental, "__orig_class__") and p.annotation and get_args(p.annotation): + if ( + not hasattr(new_incremental, "__orig_class__") + and p.annotation + and get_args(p.annotation) + ): new_incremental.__orig_class__ = p.annotation # type: ignore # set the incremental only if not yet set or if it was passed explicitly diff --git a/dlt/extract/pipe.py b/dlt/extract/pipe.py index 62bca76e17..94bb0a837c 100644 --- a/dlt/extract/pipe.py +++ b/dlt/extract/pipe.py @@ -1,24 +1,52 @@ +import asyncio import inspect import types -import asyncio -import makefun from asyncio import Future from concurrent.futures import ThreadPoolExecutor from copy import copy from threading import Thread -from typing import Any, ContextManager, Optional, Sequence, Union, Callable, Iterable, Iterator, List, NamedTuple, Awaitable, Tuple, Type, TYPE_CHECKING, Literal +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + ContextManager, + Iterable, + Iterator, + List, + Literal, + NamedTuple, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +import makefun from dlt.common import sleep from dlt.common.configuration import configspec +from dlt.common.configuration.container import Container from dlt.common.configuration.inject import with_config from dlt.common.configuration.specs import BaseConfiguration, ContainerInjectableContext -from dlt.common.configuration.container import Container from dlt.common.exceptions import PipelineException -from dlt.common.source import unset_current_pipe_name, set_current_pipe_name +from dlt.common.source import set_current_pipe_name, unset_current_pipe_name from dlt.common.typing import AnyFun, AnyType, TDataItems from dlt.common.utils import get_callable_name - -from dlt.extract.exceptions import CreatePipeException, DltSourceException, ExtractorException, InvalidResourceDataTypeFunctionNotAGenerator, InvalidStepFunctionArguments, InvalidTransformerGeneratorFunction, ParametrizedResourceUnbound, PipeException, PipeItemProcessingError, PipeNotBoundToData, ResourceExtractionError +from dlt.extract.exceptions import ( + CreatePipeException, + DltSourceException, + ExtractorException, + InvalidResourceDataTypeFunctionNotAGenerator, + InvalidStepFunctionArguments, + InvalidTransformerGeneratorFunction, + ParametrizedResourceUnbound, + PipeException, + PipeItemProcessingError, + PipeNotBoundToData, + ResourceExtractionError, +) from 
dlt.extract.typing import DataItemWithMeta, ItemTransform, SupportsPipe, TPipedDataItems if TYPE_CHECKING: @@ -62,7 +90,7 @@ class SourcePipeItem(NamedTuple): Iterator[TPipedDataItems], Callable[[TDataItems, Optional[Any]], TPipedDataItems], Callable[[TDataItems, Optional[Any]], Iterator[TPipedDataItems]], - Callable[[TDataItems, Optional[Any]], Iterator[ResolvablePipeItem]] + Callable[[TDataItems, Optional[Any]], Iterator[ResolvablePipeItem]], ] TPipeNextItemMode = Union[Literal["fifo"], Literal["round_robin"]] @@ -107,7 +135,12 @@ def __init__(self, name: str, steps: List[TPipeStep] = None, parent: "Pipe" = No self.append_step(step) @classmethod - def from_data(cls, name: str, gen: Union[Iterable[TPipedDataItems], Iterator[TPipedDataItems], AnyFun], parent: "Pipe" = None) -> "Pipe": + def from_data( + cls, + name: str, + gen: Union[Iterable[TPipedDataItems], Iterator[TPipedDataItems], AnyFun], + parent: "Pipe" = None, + ) -> "Pipe": return cls(name, [gen], parent=parent) @property @@ -143,7 +176,7 @@ def steps(self) -> List[TPipeStep]: def find(self, *step_type: AnyType) -> int: """Finds a step with object of type `step_type`""" - return next((i for i,v in enumerate(self._steps) if type(v) in step_type), -1) + return next((i for i, v in enumerate(self._steps) if type(v) in step_type), -1) def __getitem__(self, i: int) -> TPipeStep: return self._steps[i] @@ -181,7 +214,11 @@ def insert_step(self, step: TPipeStep, index: int) -> "Pipe": return self.append_step(step) if index == 0: if not self.has_parent: - raise CreatePipeException(self.name, "You cannot insert a step before head of the resource that is not a transformer") + raise CreatePipeException( + self.name, + "You cannot insert a step before head of the resource that is not a" + " transformer", + ) step = self._wrap_transform_step_meta(index, step) # actually insert in the list self._steps.insert(index, step) @@ -193,7 +230,10 @@ def insert_step(self, step: TPipeStep, index: int) -> "Pipe": def remove_step(self, index: int) -> None: """Removes steps at a given index. Gen step cannot be removed""" if index == self._gen_idx: - raise CreatePipeException(self.name, f"Step at index {index} holds a data generator for this pipe and cannot be removed") + raise CreatePipeException( + self.name, + f"Step at index {index} holds a data generator for this pipe and cannot be removed", + ) self._steps.pop(index) if index < self._gen_idx: self._gen_idx -= 1 @@ -234,7 +274,13 @@ def ensure_gen_bound(self) -> None: sig.bind() except TypeError as ex: callable_name = get_callable_name(head) - raise ParametrizedResourceUnbound(self.name, callable_name, sig.replace(parameters=list(sig.parameters.values())[1:]), "resource", str(ex)) + raise ParametrizedResourceUnbound( + self.name, + callable_name, + sig.replace(parameters=list(sig.parameters.values())[1:]), + "resource", + str(ex), + ) def evaluate_gen(self) -> None: """Lazily evaluate gen of the pipe when creating PipeIterator. 
Allows creating multiple use pipes from generator functions and lists""" @@ -248,7 +294,13 @@ def evaluate_gen(self) -> None: # must be parameter-less callable or parameters must have defaults self.replace_gen(gen()) # type: ignore except TypeError as ex: - raise ParametrizedResourceUnbound(self.name, get_callable_name(gen), inspect.signature(gen), "resource", str(ex)) + raise ParametrizedResourceUnbound( + self.name, + get_callable_name(gen), + inspect.signature(gen), + "resource", + str(ex), + ) # otherwise it must be an iterator if isinstance(gen, Iterable): self.replace_gen(iter(gen)) @@ -294,7 +346,9 @@ def _wrap_gen(self, *args: Any, **kwargs: Any) -> Any: sig = inspect.signature(head) # simulate the call to the underlying callable if args or kwargs: - skip_items_arg = 1 if self.has_parent else 0 # skip the data item argument for transformers + skip_items_arg = ( + 1 if self.has_parent else 0 + ) # skip the data item argument for transformers no_item_sig = sig.replace(parameters=list(sig.parameters.values())[skip_items_arg:]) try: no_item_sig.bind(*args, **kwargs) @@ -303,7 +357,6 @@ def _wrap_gen(self, *args: Any, **kwargs: Any) -> Any: # create wrappers with partial if self.has_parent: - if len(sig.parameters) == 2 and "meta" in sig.parameters: return head @@ -337,30 +390,44 @@ def _partial() -> Any: def _verify_head_step(self, step: TPipeStep) -> None: # first element must be Iterable, Iterator or Callable in resource pipe if not isinstance(step, (Iterable, Iterator)) and not callable(step): - raise CreatePipeException(self.name, "A head of a resource pipe must be Iterable, Iterator or a Callable") + raise CreatePipeException( + self.name, "A head of a resource pipe must be Iterable, Iterator or a Callable" + ) def _wrap_transform_step_meta(self, step_no: int, step: TPipeStep) -> TPipeStep: # step must be a callable: a transformer or a transformation if isinstance(step, (Iterable, Iterator)) and not callable(step): if self.has_parent: - raise CreatePipeException(self.name, "Iterable or Iterator cannot be a step in transformer pipe") + raise CreatePipeException( + self.name, "Iterable or Iterator cannot be a step in transformer pipe" + ) else: - raise CreatePipeException(self.name, "Iterable or Iterator can only be a first step in resource pipe") + raise CreatePipeException( + self.name, "Iterable or Iterator can only be a first step in resource pipe" + ) if not callable(step): - raise CreatePipeException(self.name, "Pipe step must be a callable taking one data item as argument and optional second meta argument") + raise CreatePipeException( + self.name, + "Pipe step must be a callable taking one data item as argument and optional second" + " meta argument", + ) else: # check the signature sig = inspect.signature(step) sig_arg_count = len(sig.parameters) callable_name = get_callable_name(step) if sig_arg_count == 0: - raise InvalidStepFunctionArguments(self.name, callable_name, sig, "Function takes no arguments") + raise InvalidStepFunctionArguments( + self.name, callable_name, sig, "Function takes no arguments" + ) # see if meta is present in kwargs meta_arg = next((p for p in sig.parameters.values() if p.name == "meta"), None) if meta_arg is not None: if meta_arg.kind not in (meta_arg.KEYWORD_ONLY, meta_arg.POSITIONAL_OR_KEYWORD): - raise InvalidStepFunctionArguments(self.name, callable_name, sig, "'meta' cannot be pos only argument '") + raise InvalidStepFunctionArguments( + self.name, callable_name, sig, "'meta' cannot be pos only argument '" + ) elif meta_arg is None: # add 
meta parameter when not present orig_step = step @@ -371,16 +438,17 @@ def _partial(*args: Any, **kwargs: Any) -> Any: return orig_step(*args, **kwargs) step = makefun.wraps( - step, - append_args=inspect.Parameter("meta", inspect._ParameterKind.KEYWORD_ONLY, default=None) - )(_partial) + step, + append_args=inspect.Parameter( + "meta", inspect._ParameterKind.KEYWORD_ONLY, default=None + ), + )(_partial) # verify the step callable, gen may be parametrized and will be evaluated at run time if not self.is_empty: self._ensure_transform_step(step_no, step) return step - def _ensure_transform_step(self, step_no: int, step: TPipeStep) -> None: """Verifies that `step` is a valid callable to be a transform step of the pipeline""" assert callable(step), f"{step} must be callable" @@ -397,7 +465,13 @@ def _ensure_transform_step(self, step_no: int, step: TPipeStep) -> None: raise InvalidTransformerGeneratorFunction(self.name, callable_name, sig, code=1) else: # show the sig without first argument - raise ParametrizedResourceUnbound(self.name, callable_name, sig.replace(parameters=list(sig.parameters.values())[1:]), "transformer", str(ty_ex)) + raise ParametrizedResourceUnbound( + self.name, + callable_name, + sig.replace(parameters=list(sig.parameters.values())[1:]), + "transformer", + str(ty_ex), + ) else: raise InvalidStepFunctionArguments(self.name, callable_name, sig, str(ty_ex)) @@ -415,11 +489,13 @@ def __repr__(self) -> str: bound_str = " data bound to " + repr(self.parent) else: bound_str = "" - return f"Pipe {self.name} ({self._pipe_id})[steps: {len(self._steps)}] at {id(self)}{bound_str}" + return ( + f"Pipe {self.name} ({self._pipe_id})[steps: {len(self._steps)}] at" + f" {id(self)}{bound_str}" + ) class PipeIterator(Iterator[PipeItem]): - @configspec class PipeIteratorConfiguration(BaseConfiguration): max_parallel_items: int = 20 @@ -430,7 +506,13 @@ class PipeIteratorConfiguration(BaseConfiguration): __section__ = "extract" - def __init__(self, max_parallel_items: int, workers: int, futures_poll_interval: float, next_item_mode: TPipeNextItemMode) -> None: + def __init__( + self, + max_parallel_items: int, + workers: int, + futures_poll_interval: float, + next_item_mode: TPipeNextItemMode, + ) -> None: self.max_parallel_items = max_parallel_items self.workers = workers self.futures_poll_interval = futures_poll_interval @@ -446,7 +528,15 @@ def __init__(self, max_parallel_items: int, workers: int, futures_poll_interval: @classmethod @with_config(spec=PipeIteratorConfiguration) - def from_pipe(cls, pipe: Pipe, *, max_parallel_items: int = 20, workers: int = 5, futures_poll_interval: float = 0.01, next_item_mode: TPipeNextItemMode = "fifo") -> "PipeIterator": + def from_pipe( + cls, + pipe: Pipe, + *, + max_parallel_items: int = 20, + workers: int = 5, + futures_poll_interval: float = 0.01, + next_item_mode: TPipeNextItemMode = "fifo", + ) -> "PipeIterator": # join all dependent pipes if pipe.parent: pipe = pipe.full_pipe() @@ -473,15 +563,13 @@ def from_pipes( workers: int = 5, futures_poll_interval: float = 0.01, copy_on_fork: bool = False, - next_item_mode: TPipeNextItemMode = "fifo" + next_item_mode: TPipeNextItemMode = "fifo", ) -> "PipeIterator": - # print(f"max_parallel_items: {max_parallel_items} workers: {workers}") extract = cls(max_parallel_items, workers, futures_poll_interval, next_item_mode) # clone all pipes before iterating (recursively) as we will fork them (this add steps) and evaluate gens pipes = PipeIterator.clone_pipes(pipes) - def _fork_pipeline(pipe: Pipe) -> None: if 
pipe.parent: # fork the parent pipe @@ -536,7 +624,9 @@ def __next__(self) -> PipeItem: # if item is iterator, then add it as a new source if isinstance(item, Iterator): # print(f"adding iterable {item}") - self._sources.append(SourcePipeItem(item, pipe_item.step, pipe_item.pipe, pipe_item.meta)) + self._sources.append( + SourcePipeItem(item, pipe_item.step, pipe_item.pipe, pipe_item.meta) + ) pipe_item = None continue @@ -563,7 +653,11 @@ def __next__(self) -> PipeItem: # must be resolved if isinstance(item, (Iterator, Awaitable)) or callable(item): raise PipeItemProcessingError( - pipe_item.pipe.name, f"Pipe item at step {pipe_item.step} was not fully evaluated and is of type {type(pipe_item.item).__name__}. This is internal error or you are yielding something weird from resources ie. functions or awaitables.") + pipe_item.pipe.name, + f"Pipe item at step {pipe_item.step} was not fully evaluated and is of type" + f" {type(pipe_item.item).__name__}. This is an internal error or you are" + " yielding something unexpected from resources, i.e. functions or awaitables.", + ) # mypy not able to figure out that item was resolved return pipe_item # type: ignore @@ -578,14 +672,23 @@ def __next__(self) -> PipeItem: next_item = next_item.data except TypeError as ty_ex: assert callable(step) - raise InvalidStepFunctionArguments(pipe_item.pipe.name, get_callable_name(step), inspect.signature(step), str(ty_ex)) + raise InvalidStepFunctionArguments( + pipe_item.pipe.name, + get_callable_name(step), + inspect.signature(step), + str(ty_ex), + ) except (PipelineException, ExtractorException, DltSourceException, PipeException): raise except Exception as ex: - raise ResourceExtractionError(pipe_item.pipe.name, step, str(ex), "transform") from ex + raise ResourceExtractionError( + pipe_item.pipe.name, step, str(ex), "transform" + ) from ex # create next pipe item if a value was returned.
A None means that item was consumed/filtered out and should not be further processed if next_item is not None: - pipe_item = ResolvablePipeItem(next_item, pipe_item.step + 1, pipe_item.pipe, next_meta) + pipe_item = ResolvablePipeItem( + next_item, pipe_item.step + 1, pipe_item.pipe, next_meta + ) else: pipe_item = None @@ -629,7 +732,9 @@ def start_background_loop(loop: asyncio.AbstractEventLoop) -> None: loop.run_forever() self._async_pool = asyncio.new_event_loop() - self._async_pool_thread = Thread(target=start_background_loop, args=(self._async_pool,), daemon=True) + self._async_pool_thread = Thread( + target=start_background_loop, args=(self._async_pool,), daemon=True + ) self._async_pool_thread.start() # start or return async pool @@ -646,7 +751,9 @@ def _ensure_thread_pool(self) -> ThreadPoolExecutor: def __enter__(self) -> "PipeIterator": return self - def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: types.TracebackType) -> None: + def __exit__( + self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: types.TracebackType + ) -> None: self.close() def _next_future(self) -> int: @@ -671,7 +778,9 @@ def _resolve_futures(self) -> ResolvablePipeItem: if future.exception(): ex = future.exception() - if isinstance(ex, (PipelineException, ExtractorException, DltSourceException, PipeException)): + if isinstance( + ex, (PipelineException, ExtractorException, DltSourceException, PipeException) + ): raise ex raise ResourceExtractionError(pipe.name, future, str(ex), "future") from ex @@ -787,6 +896,7 @@ def clone_pipes(pipes: Sequence[Pipe]) -> List[Pipe]: class ManagedPipeIterator(PipeIterator): """A version of the pipe iterator that gets closed automatically on an exception in _next_""" + _ctx: List[ContainerInjectableContext] = None _container: Container = None diff --git a/dlt/extract/schema.py b/dlt/extract/schema.py index 41846b1a7d..736d1e012f 100644 --- a/dlt/extract/schema.py +++ b/dlt/extract/schema.py @@ -1,15 +1,25 @@ -from copy import copy, deepcopy from collections.abc import Mapping as C_Mapping -from typing import List, TypedDict, cast, Any - +from copy import copy, deepcopy +from typing import Any, List, TypedDict, cast + +from dlt.common.schema.typing import ( + TColumnNames, + TColumnProp, + TColumnSchema, + TPartialTableSchema, + TTableSchemaColumns, + TWriteDisposition, +) from dlt.common.schema.utils import DEFAULT_WRITE_DISPOSITION, merge_columns, new_column, new_table -from dlt.common.schema.typing import TColumnNames, TColumnProp, TColumnSchema, TPartialTableSchema, TTableSchemaColumns, TWriteDisposition from dlt.common.typing import TDataItem from dlt.common.validation import validate_dict_ignoring_xkeys - +from dlt.extract.exceptions import ( + DataItemRequiredForDynamicTableHints, + InconsistentTableTemplate, + TableNameMissing, +) from dlt.extract.incremental import Incremental from dlt.extract.typing import TFunHintTemplate, TTableHintTemplate -from dlt.extract.exceptions import DataItemRequiredForDynamicTableHints, InconsistentTableTemplate, TableNameMissing class TTableSchemaTemplate(TypedDict, total=False): @@ -42,7 +52,10 @@ def table_name(self) -> str: @property def write_disposition(self) -> TWriteDisposition: - if self._table_schema_template is None or self._table_schema_template.get("write_disposition") is None: + if ( + self._table_schema_template is None + or self._table_schema_template.get("write_disposition") is None + ): return DEFAULT_WRITE_DISPOSITION w_d = 
self._table_schema_template.get("write_disposition") if callable(w_d): @@ -50,7 +63,7 @@ def write_disposition(self) -> TWriteDisposition: else: return w_d - def table_schema(self, item: TDataItem = None) -> TPartialTableSchema: + def table_schema(self, item: TDataItem = None) -> TPartialTableSchema: """Computes the table schema based on hints and column definitions passed during resource creation. `item` parameter is used to resolve table hints based on data""" if not self._table_schema_template: return new_table(self._name, resource=self._name) @@ -82,24 +95,26 @@ def apply_hints( columns: TTableHintTemplate[TTableSchemaColumns] = None, primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, - incremental: Incremental[Any] = None + incremental: Incremental[Any] = None, ) -> None: """Creates or modifies existing table schema by setting provided hints. Accepts both static and dynamic hints based on data. - This method accepts the same table hints arguments as `dlt.resource` decorator with the following additions. - Skip the argument or pass None to leave the existing hint. - Pass empty value (for particular type ie "" for a string) to remove hint + This method accepts the same table hints arguments as the `dlt.resource` decorator, with the following additions. + Skip the argument or pass None to leave the existing hint. + Pass an empty value (for the particular type, e.g. "" for a string) to remove the hint - parent_table_name (str, optional): A name of parent table if foreign relation is defined. Please note that if you use merge you must define `root_key` columns explicitly - incremental (Incremental, optional): Enables the incremental loading for a resource. + parent_table_name (str, optional): The name of the parent table if a foreign relation is defined. Please note that if you use merge you must define `root_key` columns explicitly + incremental (Incremental, optional): Enables incremental loading for a resource. - Please note that for efficient incremental loading, the resource must be aware of the Incremental by accepting it as one if its arguments and then using is to skip already loaded data. - In non-aware resources, `dlt` will filter out the loaded values, however the resource will yield all the values again. + Please note that for efficient incremental loading, the resource must be aware of the Incremental by accepting it as one of its arguments and then using it to skip already loaded data. + In non-aware resources, `dlt` will filter out the already loaded values; however, the resource will yield all the values again.
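A minimal sketch of how these hints are typically applied through the public API; the resource and column names below are invented, and `dlt.sources.incremental` is assumed to be the public incremental helper:

import dlt

@dlt.resource
def orders():
    yield {"order_id": 1, "updated_at": "2023-08-01T00:00:00Z", "status": "paid"}

# skip an argument (or pass None) to keep an existing hint; pass an empty value to drop it
orders.apply_hints(
    write_disposition="merge",
    primary_key="order_id",
    incremental=dlt.sources.incremental("updated_at"),
)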
""" t = None if not self._table_schema_template: # if there's no template yet, create and set new one - t = self.new_table_template(table_name, parent_table_name, write_disposition, columns, primary_key, merge_key) + t = self.new_table_template( + table_name, parent_table_name, write_disposition, columns, primary_key, merge_key + ) else: # set single hints t = deepcopy(self._table_schema_template) @@ -132,16 +147,18 @@ def set_template(self, table_schema_template: TTableSchemaTemplate) -> None: else: self._table_name_hint_fun = None # check if any other hints in the table template should be inferred from data - self._table_has_other_dynamic_hints = any(callable(v) for k, v in table_schema_template.items() if k != "name") + self._table_has_other_dynamic_hints = any( + callable(v) for k, v in table_schema_template.items() if k != "name" + ) self._table_schema_template = table_schema_template @staticmethod def _resolve_hint(item: TDataItem, hint: TTableHintTemplate[Any]) -> Any: - """Calls each dynamic hint passing a data item""" - if callable(hint): - return hint(item) - else: - return hint + """Calls each dynamic hint passing a data item""" + if callable(hint): + return hint(item) + else: + return hint @staticmethod def _merge_key(hint: TColumnProp, keys: TColumnNames, partial: TPartialTableSchema) -> None: @@ -174,8 +191,8 @@ def new_table_template( write_disposition: TTableHintTemplate[TWriteDisposition] = None, columns: TTableHintTemplate[TTableSchemaColumns] = None, primary_key: TTableHintTemplate[TColumnNames] = None, - merge_key: TTableHintTemplate[TColumnNames] = None - ) -> TTableSchemaTemplate: + merge_key: TTableHintTemplate[TColumnNames] = None, + ) -> TTableSchemaTemplate: if not table_name: raise TableNameMissing() @@ -194,6 +211,10 @@ def new_table_template( if merge_key: new_template["merge_key"] = merge_key # if any of the hints is a function then name must be as well - if any(callable(v) for k, v in new_template.items() if k != "name") and not callable(table_name): - raise InconsistentTableTemplate(f"Table name {table_name} must be a function if any other table hint is a function") + if any(callable(v) for k, v in new_template.items() if k != "name") and not callable( + table_name + ): + raise InconsistentTableTemplate( + f"Table name {table_name} must be a function if any other table hint is a function" + ) return new_template diff --git a/dlt/extract/source.py b/dlt/extract/source.py index 53e54649b0..dc4923823b 100644 --- a/dlt/extract/source.py +++ b/dlt/extract/source.py @@ -1,30 +1,83 @@ -import warnings import contextlib -from copy import copy -import makefun import inspect -from typing import AsyncIterable, AsyncIterator, ClassVar, Callable, ContextManager, Dict, Iterable, Iterator, List, Sequence, Tuple, Union, Any import types +import warnings +from copy import copy +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Callable, + ClassVar, + ContextManager, + Dict, + Iterable, + Iterator, + List, + Sequence, + Tuple, + Union, +) + +import makefun +from dlt.common.configuration.container import Container from dlt.common.configuration.resolve import inject_section from dlt.common.configuration.specs import known_sections from dlt.common.configuration.specs.config_section_context import ConfigSectionContext -from dlt.common.normalizers.json.relational import DataItemNormalizer as RelationalNormalizer, RelationalNormalizerConfigPropagation +from dlt.common.normalizers.json.relational import DataItemNormalizer as RelationalNormalizer +from 
dlt.common.normalizers.json.relational import RelationalNormalizerConfigPropagation +from dlt.common.pipeline import ( + PipelineContext, + StateInjectableContext, + SupportsPipelineRun, + pipeline_state, + resource_state, + source_state, +) from dlt.common.schema import Schema from dlt.common.schema.typing import TColumnName -from dlt.common.typing import AnyFun, StrAny, TDataItem, TDataItems, NoneType -from dlt.common.configuration.container import Container -from dlt.common.pipeline import PipelineContext, StateInjectableContext, SupportsPipelineRun, resource_state, source_state, pipeline_state -from dlt.common.utils import graph_find_scc_nodes, flatten_list_or_items, get_callable_name, graph_edges_to_nodes, multi_context_manager, uniq_id - -from dlt.extract.typing import DataItemWithMeta, ItemTransformFunc, ItemTransformFunctionWithMeta, TDecompositionStrategy, TableNameMeta, FilterItem, MapItem, YieldMapItem -from dlt.extract.pipe import Pipe, ManagedPipeIterator, TPipeStep -from dlt.extract.schema import DltResourceSchema, TTableSchemaTemplate -from dlt.extract.incremental import Incremental, IncrementalResourceWrapper +from dlt.common.typing import AnyFun, NoneType, StrAny, TDataItem, TDataItems +from dlt.common.utils import ( + flatten_list_or_items, + get_callable_name, + graph_edges_to_nodes, + graph_find_scc_nodes, + multi_context_manager, + uniq_id, +) from dlt.extract.exceptions import ( - InvalidTransformerDataTypeGeneratorFunctionRequired, InvalidParentResourceDataType, InvalidParentResourceIsAFunction, InvalidResourceDataType, InvalidResourceDataTypeFunctionNotAGenerator, InvalidResourceDataTypeIsNone, InvalidTransformerGeneratorFunction, - DataItemRequiredForDynamicTableHints, InvalidResourceDataTypeAsync, InvalidResourceDataTypeBasic, - InvalidResourceDataTypeMultiplePipes, ParametrizedResourceUnbound, ResourceNameMissing, ResourceNotATransformer, ResourcesNotFoundError, SourceExhausted, DeletingResourcesNotSupported) + DataItemRequiredForDynamicTableHints, + DeletingResourcesNotSupported, + InvalidParentResourceDataType, + InvalidParentResourceIsAFunction, + InvalidResourceDataType, + InvalidResourceDataTypeAsync, + InvalidResourceDataTypeBasic, + InvalidResourceDataTypeFunctionNotAGenerator, + InvalidResourceDataTypeIsNone, + InvalidResourceDataTypeMultiplePipes, + InvalidTransformerDataTypeGeneratorFunctionRequired, + InvalidTransformerGeneratorFunction, + ParametrizedResourceUnbound, + ResourceNameMissing, + ResourceNotATransformer, + ResourcesNotFoundError, + SourceExhausted, +) +from dlt.extract.incremental import Incremental, IncrementalResourceWrapper +from dlt.extract.pipe import ManagedPipeIterator, Pipe, TPipeStep +from dlt.extract.schema import DltResourceSchema, TTableSchemaTemplate +from dlt.extract.typing import ( + DataItemWithMeta, + FilterItem, + ItemTransformFunc, + ItemTransformFunctionWithMeta, + MapItem, + TableNameMeta, + TDecompositionStrategy, + YieldMapItem, +) def with_table_name(item: TDataItems, table_name: str) -> DataItemWithMeta: @@ -33,7 +86,6 @@ def with_table_name(item: TDataItems, table_name: str) -> DataItemWithMeta: class DltResource(Iterable[TDataItem], DltResourceSchema): - Empty: ClassVar["DltResource"] = None source_name: str """Name of the source that contains this instance of the source, set when added to DltResourcesDict""" @@ -44,7 +96,7 @@ def __init__( table_schema_template: TTableSchemaTemplate, selected: bool, incremental: IncrementalResourceWrapper = None, - section: str = None + section: str = None, ) -> None: self._name 
= pipe.name self.section = section @@ -65,7 +117,7 @@ def from_data( table_schema_template: TTableSchemaTemplate = None, selected: bool = True, depends_on: Union["DltResource", Pipe] = None, - incremental: IncrementalResourceWrapper = None + incremental: IncrementalResourceWrapper = None, ) -> "DltResource": if data is None: raise InvalidResourceDataTypeIsNone(name, data, NoneType) # type: ignore @@ -74,7 +126,9 @@ def from_data( return data if isinstance(data, Pipe): - return cls(data, table_schema_template, selected, incremental=incremental, section=section) + return cls( + data, table_schema_template, selected, incremental=incremental, section=section + ) if callable(data): name = name or get_callable_name(data) @@ -102,10 +156,14 @@ def from_data( # create resource from iterator, iterable or generator function if isinstance(data, (Iterable, Iterator)) or callable(data): pipe = Pipe.from_data(name, data, parent=parent_pipe) - return cls(pipe, table_schema_template, selected, incremental=incremental, section=section) + return cls( + pipe, table_schema_template, selected, incremental=incremental, section=section + ) else: # some other data type that is not supported - raise InvalidResourceDataType(name, data, type(data), f"The data type is {type(data).__name__}") + raise InvalidResourceDataType( + name, data, type(data), f"The data type is {type(data).__name__}" + ) @property def name(self) -> str: @@ -140,7 +198,9 @@ def pipe_data_from(self, data_from: Union["DltResource", Pipe]) -> None: if self.is_transformer: DltResource._ensure_valid_transformer_resource(self._name, self._pipe.gen) else: - raise ResourceNotATransformer(self._name, "Cannot pipe data into resource that is not a transformer.") + raise ResourceNotATransformer( + self._name, "Cannot pipe data into resource that is not a transformer." + ) parent_pipe = self._get_parent_pipe(self._name, data_from) self._pipe.parent = parent_pipe @@ -152,8 +212,9 @@ def add_pipe(self, data: Any) -> None: def select_tables(self, *table_names: Iterable[str]) -> "DltResource": """For resources that dynamically dispatch data to several tables allows to select tables that will receive data, effectively filtering out other data items. - Both `with_table_name` marker and data-based (function) table name hints are supported. + Both `with_table_name` marker and data-based (function) table name hints are supported. 
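The marker-based dispatch that `select_tables` filters on can be sketched with the module-level `with_table_name` helper defined earlier in this file; the resource, field and table names here are illustrative only:

import dlt
from dlt.extract.source import with_table_name

@dlt.resource
def events():
    # route each item to a table chosen at runtime via the meta marker
    yield with_table_name({"id": 1, "kind": "click"}, "click_events")
    yield with_table_name({"id": 2, "kind": "view"}, "view_events")

# keep only the items dispatched to "click_events"; everything else is filtered out
clicks_only = events().select_tables("click_events")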
""" + def _filter(item: TDataItem, meta: Any = None) -> bool: is_in_meta = isinstance(meta, TableNameMeta) and meta.table_name in table_names is_in_dyn = self._table_name_hint_fun and self._table_name_hint_fun(item) in table_names @@ -163,7 +224,9 @@ def _filter(item: TDataItem, meta: Any = None) -> bool: self.add_filter(_filter) return self - def add_map(self, item_map: ItemTransformFunc[TDataItem], insert_at: int = None) -> "DltResource": # noqa: A003 + def add_map( + self, item_map: ItemTransformFunc[TDataItem], insert_at: int = None + ) -> "DltResource": # noqa: A003 """Adds mapping function defined in `item_map` to the resource pipe at position `inserted_at` `item_map` receives single data items, `dlt` will enumerate any lists of data items automatically @@ -181,7 +244,9 @@ def add_map(self, item_map: ItemTransformFunc[TDataItem], insert_at: int = None) self._pipe.insert_step(MapItem(item_map), insert_at) return self - def add_yield_map(self, item_map: ItemTransformFunc[Iterator[TDataItem]], insert_at: int = None) -> "DltResource": # noqa: A003 + def add_yield_map( + self, item_map: ItemTransformFunc[Iterator[TDataItem]], insert_at: int = None + ) -> "DltResource": # noqa: A003 """Adds generating function defined in `item_map` to the resource pipe at position `inserted_at` `item_map` receives single data items, `dlt` will enumerate any lists of data items automatically. It may yield 0 or more data items and be used to @@ -200,7 +265,9 @@ def add_yield_map(self, item_map: ItemTransformFunc[Iterator[TDataItem]], insert self._pipe.insert_step(YieldMapItem(item_map), insert_at) return self - def add_filter(self, item_filter: ItemTransformFunc[bool], insert_at: int = None) -> "DltResource": # noqa: A003 + def add_filter( + self, item_filter: ItemTransformFunc[bool], insert_at: int = None + ) -> "DltResource": # noqa: A003 """Adds filter defined in `item_filter` to the resource pipe at position `inserted_at` `item_filter` receives single data items, `dlt` will enumerate any lists of data items automatically @@ -228,6 +295,7 @@ def add_limit(self, max_items: int) -> "DltResource": # noqa: A003 Returns: "DltResource": returns self """ + def _gen_wrap(gen: TPipeStep) -> TPipeStep: """Wrap a generator to take the first `max_items` records""" nonlocal max_items @@ -244,12 +312,15 @@ def _gen_wrap(gen: TPipeStep) -> TPipeStep: if inspect.isgenerator(gen): gen.close() return + # transformers should be limited by their input, so we only limit non-transformers if not self.is_transformer: self._pipe.replace_gen(_gen_wrap(self._pipe.gen)) return self - def add_step(self, item_transform: ItemTransformFunctionWithMeta[TDataItems], insert_at: int = None) -> "DltResource": # noqa: A003 + def add_step( + self, item_transform: ItemTransformFunctionWithMeta[TDataItems], insert_at: int = None + ) -> "DltResource": # noqa: A003 if insert_at is None: self._pipe.append_step(item_transform) else: @@ -309,7 +380,9 @@ def clone(self, clone_pipe: bool = True, keep_pipe_id: bool = True) -> "DltResou if self._pipe and not self._pipe.is_empty and clone_pipe: pipe = pipe._clone(keep_pipe_id=keep_pipe_id) # incremental and parent are already in the pipe (if any) - return DltResource(pipe, self._table_schema_template, selected=self.selected, section=self.section) + return DltResource( + pipe, self._table_schema_template, selected=self.selected, section=self.section + ) def __call__(self, *args: Any, **kwargs: Any) -> "DltResource": """Binds the parametrized resources to passed arguments. 
Creates and returns a bound resource. Generators and iterators are not evaluated.""" @@ -335,7 +408,7 @@ def __or__(self, transform: Union["DltResource", AnyFun]) -> "DltResource": def __iter__(self) -> Iterator[TDataItem]: """Opens iterator that yields the data items from the resources in the same order as in Pipeline class. - A read-only state is provided, initialized from active pipeline state. The state is discarded after the iterator is closed. + A read-only state is provided, initialized from active pipeline state. The state is discarded after the iterator is closed. """ # use the same state dict when opening iterator and when iterator is iterated container = Container() @@ -367,8 +440,12 @@ def _get_config_section_context(self) -> ConfigSectionContext: default_schema_name = pipeline._make_schema_with_default_name().name return ConfigSectionContext( pipeline_name=pipeline_name, - sections=(known_sections.SOURCES, self.section or default_schema_name or uniq_id(), self.source_name or default_schema_name or self._name), - source_state_key=self.source_name or default_schema_name or self.section or uniq_id() + sections=( + known_sections.SOURCES, + self.section or default_schema_name or uniq_id(), + self.source_name or default_schema_name or self._name, + ), + source_state_key=self.source_name or default_schema_name or self.section or uniq_id(), ) def __str__(self) -> str: @@ -381,14 +458,24 @@ def __str__(self) -> str: info += ":" if self.is_transformer: - info += f"\nThis resource is a transformer and takes data items from {self._pipe.parent.name}" + info += ( + "\nThis resource is a transformer and takes data items from" + f" {self._pipe.parent.name}" + ) else: if self._pipe.is_data_bound: if self.requires_binding: head_sig = inspect.signature(self._pipe.gen) # type: ignore - info += f"\nThis resource is parametrized and takes the following arguments {head_sig}. You must call this resource before loading." + info += ( + "\nThis resource is parametrized and takes the following arguments" + f" {head_sig}. You must call this resource before loading." + ) else: - info += "\nIf you want to see the data items in the resource you must iterate it or convert to list ie. list(resource). Note that, like any iterator, you can iterate the resource only once." + info += ( + "\nIf you want to see the data items in the resource you must iterate it or" + " convert to list ie. list(resource). Note that, like any iterator, you can" + " iterate the resource only once." 
+ ) else: info += "\nThis resource is not bound to the data" info += f"\nInstance: info: (data pipe id:{self._pipe._pipe_id}) at {id(self)}" @@ -400,7 +487,9 @@ def _ensure_valid_transformer_resource(name: str, data: Any) -> None: if callable(data): valid_code = DltResource.validate_transformer_generator_function(data) if valid_code != 0: - raise InvalidTransformerGeneratorFunction(name, get_callable_name(data), inspect.signature(data), valid_code) + raise InvalidTransformerGeneratorFunction( + name, get_callable_name(data), inspect.signature(data), valid_code + ) else: raise InvalidTransformerDataTypeGeneratorFunctionRequired(name, data, type(data)) @@ -453,7 +542,7 @@ def __init__(self, source_name: str, source_section: str) -> None: @property def selected(self) -> Dict[str, DltResource]: """Returns a subset of all resources that will be extracted and loaded to the destination.""" - return {k:v for k,v in self.items() if v.selected} + return {k: v for k, v in self.items() if v.selected} @property def extracted(self) -> Dict[str, DltResource]: @@ -469,7 +558,12 @@ def extracted(self) -> Dict[str, DltResource]: resource = self.find_by_pipe(pipe) except KeyError: # resource for pipe not found: return mock resource - mock_template = DltResourceSchema.new_table_template(pipe.name, write_disposition=resource._table_schema_template.get("write_disposition")) + mock_template = DltResourceSchema.new_table_template( + pipe.name, + write_disposition=resource._table_schema_template.get( + "write_disposition" + ), + ) resource = DltResource(pipe, mock_template, False, section=resource.section) resource.source_name = resource.source_name extracted[resource._name] = resource @@ -511,7 +605,9 @@ def select(self, *resource_names: str) -> Dict[str, DltResource]: for name in resource_names: if name not in self: # if any key is missing, display the full info - raise ResourcesNotFoundError(self.source_name, set(self.keys()), set(resource_names)) + raise ResourcesNotFoundError( + self.source_name, set(self.keys()), set(resource_names) + ) # set the selected flags for resource in self.values(): self[resource._name].selected = resource._name in resource_names @@ -523,12 +619,16 @@ def find_by_pipe(self, pipe: Pipe) -> DltResource: if pipe._pipe_id in self._known_pipes: return self._known_pipes[pipe._pipe_id] try: - return self._known_pipes.setdefault(pipe._pipe_id, next(r for r in self.values() if r._pipe._pipe_id == pipe._pipe_id)) + return self._known_pipes.setdefault( + pipe._pipe_id, next(r for r in self.values() if r._pipe._pipe_id == pipe._pipe_id) + ) except StopIteration: raise KeyError(pipe) def clone_new_pipes(self) -> None: - cloned_pipes = ManagedPipeIterator.clone_pipes([r._pipe for r in self.values() if r in self._recently_added]) + cloned_pipes = ManagedPipeIterator.clone_pipes( + [r._pipe for r in self.values() if r in self._recently_added] + ) # replace pipes in resources, the cloned_pipes preserve parent connections for cloned in cloned_pipes: self.find_by_pipe(cloned)._pipe = cloned @@ -560,7 +660,10 @@ class DltSource(Iterable[TDataItem]): * You can use a `run` method to load the data with a default instance of dlt pipeline. 
* You can get source read only state for the currently active Pipeline instance """ - def __init__(self, name: str, section: str, schema: Schema, resources: Sequence[DltResource] = None) -> None: + + def __init__( + self, name: str, section: str, schema: Schema, resources: Sequence[DltResource] = None + ) -> None: self.name = name self.section = section """Tells if iterator associated with a source is exhausted""" @@ -569,7 +672,10 @@ def __init__(self, name: str, section: str, schema: Schema, resources: Sequence[ if self.name != schema.name: # raise ValueError(f"Schema name {schema.name} differs from source name {name}! The explicit source name argument is deprecated and will be soon removed.") - warnings.warn(f"Schema name {schema.name} differs from source name {name}! The explicit source name argument is deprecated and will be soon removed.") + warnings.warn( + f"Schema name {schema.name} differs from source name {name}! The explicit source" + " name argument is deprecated and will be soon removed." + ) if resources: for resource in resources: @@ -616,21 +722,28 @@ def exhausted(self) -> bool: def root_key(self) -> bool: """Enables merging on all resources by propagating root foreign key to child tables. This option is most useful if you plan to change write disposition of a resource to disable/enable merge""" config = RelationalNormalizer.get_normalizer_config(self._schema).get("propagation") - return config is not None and "root" in config and "_dlt_id" in config["root"] and config["root"]["_dlt_id"] == "_dlt_root_id" + return ( + config is not None + and "root" in config + and "_dlt_id" in config["root"] + and config["root"]["_dlt_id"] == "_dlt_root_id" + ) @root_key.setter def root_key(self, value: bool) -> None: if value is True: propagation_config: RelationalNormalizerConfigPropagation = { - "root": { - "_dlt_id": TColumnName("_dlt_root_id") - }, - "tables": {} + "root": {"_dlt_id": TColumnName("_dlt_root_id")}, + "tables": {}, } - RelationalNormalizer.update_normalizer_config(self._schema, {"propagation": propagation_config}) + RelationalNormalizer.update_normalizer_config( + self._schema, {"propagation": propagation_config} + ) else: if self.root_key: - propagation_config = RelationalNormalizer.get_normalizer_config(self._schema)["propagation"] + propagation_config = RelationalNormalizer.get_normalizer_config(self._schema)[ + "propagation" + ] propagation_config["root"].pop("_dlt_id") # type: ignore @property @@ -671,8 +784,8 @@ def with_resources(self, *resource_names: str) -> "DltSource": def decompose(self, strategy: TDecompositionStrategy) -> List["DltSource"]: """Decomposes source into a list of sources with a given strategy. - "none" will return source as is - "scc" will decompose the dag of selected pipes and their parent into strongly connected components + "none" will return source as is + "scc" will decompose the dag of selected pipes and their parent into strongly connected components """ if strategy == "none": return [self] @@ -703,7 +816,9 @@ def add_limit(self, max_items: int) -> "DltSource": # noqa: A003 @property def run(self) -> SupportsPipelineRun: """A convenience method that will call `run` run on the currently active `dlt` pipeline. 
If pipeline instance is not found, one with default settings will be created.""" - self_run: SupportsPipelineRun = makefun.partial(Container()[PipelineContext].pipeline().run, *(), data=self) + self_run: SupportsPipelineRun = makefun.partial( + Container()[PipelineContext].pipeline().run, *(), data=self + ) return self_run @property @@ -715,14 +830,16 @@ def state(self) -> StrAny: def clone(self) -> "DltSource": """Creates a deep copy of the source where copies of schema, resources and pipes are created""" # mind that resources and pipes are cloned when added to the DltResourcesDict in the source constructor - return DltSource(self.name, self.section, self.schema.clone(), list(self._resources.values())) + return DltSource( + self.name, self.section, self.schema.clone(), list(self._resources.values()) + ) def __iter__(self) -> Iterator[TDataItem]: """Opens iterator that yields the data items from all the resources within the source in the same order as in Pipeline class. - A read-only state is provided, initialized from active pipeline state. The state is discarded after the iterator is closed. + A read-only state is provided, initialized from active pipeline state. The state is discarded after the iterator is closed. - A source config section is injected to allow secrets/config injection as during regular extraction. + A source config section is injected to allow secrets/config injection as during regular extraction. """ # use the same state dict when opening iterator and when iterator is iterated mock_state, _ = pipeline_state(Container(), {}) @@ -742,7 +859,7 @@ def _get_config_section_context(self) -> ConfigSectionContext: return ConfigSectionContext( pipeline_name=pipeline_name, sections=(known_sections.SOURCES, self.section, self.name), - source_state_key=self.name + source_state_key=self.name, ) def _add_resource(self, name: str, resource: DltResource) -> None: @@ -768,17 +885,29 @@ def __setattr__(self, name: str, value: Any) -> None: super().__setattr__(name, value) def __str__(self) -> str: - info = f"DltSource {self.name} section {self.section} contains {len(self.resources)} resource(s) of which {len(self.selected_resources)} are selected" + info = ( + f"DltSource {self.name} section {self.section} contains" + f" {len(self.resources)} resource(s) of which {len(self.selected_resources)} are" + " selected" + ) for r in self.resources.values(): selected_info = "selected" if r.selected else "not selected" if r.is_transformer: - info += f"\ntransformer {r._name} is {selected_info} and takes data from {r._pipe.parent.name}" + info += ( + f"\ntransformer {r._name} is {selected_info} and takes data from" + f" {r._pipe.parent.name}" + ) else: info += f"\nresource {r._name} is {selected_info}" if self.exhausted: - info += "\nSource is already iterated and cannot be used again ie. to display or load data." + info += ( + "\nSource is already iterated and cannot be used again ie. to display or load data." + ) else: - info += "\nIf you want to see the data items in this source you must iterate it or convert to list ie. list(source)." + info += ( + "\nIf you want to see the data items in this source you must iterate it or convert" + " to list ie. list(source)." + ) info += " Note that, like any iterator, you can iterate the source only once." 
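For context, the `run` convenience property and the single-use iteration behaviour described above correspond to this kind of usage; pipeline, source and dataset names are placeholders:

import dlt

@dlt.source
def demo_source():
    @dlt.resource
    def items():
        yield from [{"n": 1}, {"n": 2}]

    return items

source = demo_source()
# list(source) would also show the data, but it exhausts the source,
# which can then no longer be loaded

pipeline = dlt.pipeline(pipeline_name="demo", destination="duckdb", dataset_name="demo_data")
load_info = pipeline.run(source)  # equivalent to source.run() with an active pipeline
print(load_info)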
info += f"\ninstance id: {id(self)}" return info diff --git a/dlt/extract/typing.py b/dlt/extract/typing.py index a8608021ba..0a4dca32a9 100644 --- a/dlt/extract/typing.py +++ b/dlt/extract/typing.py @@ -1,10 +1,20 @@ import inspect from abc import ABC, abstractmethod -from typing import Any, Callable, Generic, Iterator, Literal, Optional, Protocol, TypeVar, Union, Awaitable +from typing import ( + Any, + Awaitable, + Callable, + Generic, + Iterator, + Literal, + Optional, + Protocol, + TypeVar, + Union, +) from dlt.common.typing import TAny, TDataItem, TDataItems - TDecompositionStrategy = Literal["none", "scc"] TDeferredDataItems = Callable[[], TDataItems] TAwaitableDataItems = Awaitable[TDataItems] @@ -37,6 +47,7 @@ def __init__(self, table_name: str) -> None: class SupportsPipe(Protocol): """A protocol with the core Pipe properties and operations""" + name: str """Pipe name which is inherited by a resource""" @@ -45,6 +56,7 @@ class SupportsPipe(Protocol): ItemTransformFunctionNoMeta = Callable[[TDataItem], TAny] ItemTransformFunc = Union[ItemTransformFunctionWithMeta[TAny], ItemTransformFunctionNoMeta[TAny]] + class ItemTransform(ABC, Generic[TAny]): _f_meta: ItemTransformFunctionWithMeta[TAny] = None _f: ItemTransformFunctionNoMeta[TAny] = None @@ -108,7 +120,7 @@ def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]: class YieldMapItem(ItemTransform[Iterator[TDataItem]]): - # mypy needs those to type correctly + # mypy needs those to type correctly _f_meta: ItemTransformFunctionWithMeta[TDataItem] _f: ItemTransformFunctionNoMeta[TDataItem] @@ -123,4 +135,4 @@ def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]: if self._f_meta: yield from self._f_meta(item, meta) else: - yield from self._f(item) \ No newline at end of file + yield from self._f(item) diff --git a/dlt/extract/utils.py b/dlt/extract/utils.py index 16f7fabdf7..86dae08999 100644 --- a/dlt/extract/utils.py +++ b/dlt/extract/utils.py @@ -1,10 +1,12 @@ -from typing import Union, List, Any +from typing import Any, List, Union -from dlt.extract.typing import TTableHintTemplate, TDataItem from dlt.common.schema.typing import TColumnNames +from dlt.extract.typing import TDataItem, TTableHintTemplate -def resolve_column_value(column_hint: TTableHintTemplate[TColumnNames], item: TDataItem) -> Union[Any, List[Any]]: +def resolve_column_value( + column_hint: TTableHintTemplate[TColumnNames], item: TDataItem +) -> Union[Any, List[Any]]: """Extract values from the data item given a column hint. Returns either a single value or list of values when hint is a composite. 
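Assuming `resolve_column_value` looks the hinted column(s) up in the data item (its body is not shown in this hunk), its behaviour would be roughly:

from dlt.extract.utils import resolve_column_value

item = {"user_id": 7, "region": "eu"}

# a single column hint resolves to one value, a composite hint to a list of values
single = resolve_column_value("user_id", item)                 # 7, under the assumption above
composite = resolve_column_value(["user_id", "region"], item)  # [7, "eu"]
# the hint may also be a callable taking the item, e.g. lambda i: ["user_id", "region"]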
""" diff --git a/dlt/helpers/airflow_helper.py b/dlt/helpers/airflow_helper.py index 9bffbe1a37..b76171a1f7 100644 --- a/dlt/helpers/airflow_helper.py +++ b/dlt/helpers/airflow_helper.py @@ -1,7 +1,14 @@ import os from tempfile import gettempdir from typing import Any, Callable, List, Literal, Optional, Sequence, Tuple -from tenacity import retry_if_exception, wait_exponential, stop_after_attempt, Retrying, RetryCallState + +from tenacity import ( + RetryCallState, + Retrying, + retry_if_exception, + stop_after_attempt, + wait_exponential, +) from dlt.common import pendulum from dlt.common.exceptions import MissingDependencyException @@ -9,30 +16,29 @@ try: from airflow.configuration import conf + from airflow.operators.python import PythonOperator, get_current_context from airflow.utils.task_group import TaskGroup - from airflow.operators.python import PythonOperator - from airflow.operators.python import get_current_context except ModuleNotFoundError: raise MissingDependencyException("Airflow", ["airflow>=2.0.0"]) import dlt from dlt.common import logger -from dlt.common.schema.typing import TWriteDisposition -from dlt.common.utils import uniq_id from dlt.common.configuration.container import Container from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext from dlt.common.runtime.collector import NULL_COLLECTOR - +from dlt.common.schema.typing import TWriteDisposition +from dlt.common.utils import uniq_id from dlt.extract.source import DltSource from dlt.pipeline.helpers import retry_load from dlt.pipeline.pipeline import Pipeline from dlt.pipeline.progress import log from dlt.pipeline.typing import TPipelineStep - DEFAULT_RETRY_NO_RETRY = Retrying(stop=stop_after_attempt(1), reraise=True) -DEFAULT_RETRY_BACKOFF = Retrying(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1.5, min=4, max=10), reraise=True) +DEFAULT_RETRY_BACKOFF = Retrying( + stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1.5, min=4, max=10), reraise=True +) class PipelineTasksGroup(TaskGroup): @@ -49,13 +55,13 @@ def __init__( log_progress_period: float = 30.0, buffer_max_items: int = 1000, retry_policy: Retrying = DEFAULT_RETRY_NO_RETRY, - retry_pipeline_steps: Sequence[TPipelineStep] = ("load", ), + retry_pipeline_steps: Sequence[TPipelineStep] = ("load",), fail_task_if_any_job_failed: bool = True, abort_task_if_any_job_failed: bool = False, wipe_local_data: bool = True, save_load_info: bool = False, save_trace_info: bool = False, - **kwargs: Any + **kwargs: Any, ) -> None: """Creates a task group to which you can add pipeline runs @@ -102,7 +108,7 @@ def __init__( self.save_trace_info = save_trace_info # reload providers so config.toml in dags folder is included - dags_folder = conf.get('core', 'dags_folder') + dags_folder = conf.get("core", "dags_folder") # set the dlt project folder to dags os.environ["DLT_PROJECT_DIR"] = dags_folder @@ -128,7 +134,7 @@ def add_run( decompose: Literal["none", "serialize"] = "none", table_name: str = None, write_disposition: TWriteDisposition = None, - **kwargs: Any + **kwargs: Any, ) -> List[PythonOperator]: """Creates a task or a group of tasks to run `data` with `pipeline` @@ -152,7 +158,10 @@ def add_run( # make sure that pipeline was created after dag was initialized if not pipeline.pipelines_dir.startswith(os.environ["DLT_DATA_DIR"]): - raise ValueError("Please create your Pipeline instance after AirflowTasks are created. 
The dlt pipelines directory is not set correctly") + raise ValueError( + "Please create your Pipeline instance after AirflowTasks are created. The dlt" + " pipelines directory is not set correctly" + ) def task_name(pipeline: Pipeline, data: Any) -> str: task_name = pipeline.pipeline_name @@ -164,11 +173,9 @@ def task_name(pipeline: Pipeline, data: Any) -> str: return task_name with self: - # use factory function to make test, in order to parametrize it. passing arguments to task function (_run) is serializing them and # running template engine on them def make_task(pipeline: Pipeline, data: Any) -> PythonOperator: - def _run() -> None: # activate pipeline pipeline.activate() @@ -181,7 +188,10 @@ def _run() -> None: logger.LOGGER = ti.log # set global number of buffered items - if dlt.config.get("data_writer.buffer_max_items") is None and self.buffer_max_items > 0: + if ( + dlt.config.get("data_writer.buffer_max_items") is None + and self.buffer_max_items > 0 + ): dlt.config["data_writer.buffer_max_items"] = self.buffer_max_items logger.info(f"Set data_writer.buffer_max_items to {self.buffer_max_items}") @@ -191,24 +201,36 @@ def _run() -> None: logger.info("Set load.abort_task_if_any_job_failed to True") if self.log_progress_period > 0 and task_pipeline.collector == NULL_COLLECTOR: - task_pipeline.collector = log(log_period=self.log_progress_period, logger=logger.LOGGER) + task_pipeline.collector = log( + log_period=self.log_progress_period, logger=logger.LOGGER + ) logger.info(f"Enabled log progress with period {self.log_progress_period}") logger.info(f"Pipeline data in {task_pipeline.working_dir}") def log_after_attempt(retry_state: RetryCallState) -> None: if not retry_state.retry_object.stop(retry_state): - logger.error("Retrying pipeline run due to exception: %s", retry_state.outcome.exception()) + logger.error( + "Retrying pipeline run due to exception: %s", + retry_state.outcome.exception(), + ) try: # retry with given policy on selected pipeline steps for attempt in self.retry_policy.copy( - retry=retry_if_exception(retry_load(retry_on_pipeline_steps=self.retry_pipeline_steps)), - after=log_after_attempt + retry=retry_if_exception( + retry_load(retry_on_pipeline_steps=self.retry_pipeline_steps) + ), + after=log_after_attempt, ): with attempt: - logger.info("Running the pipeline, attempt=%s" % attempt.retry_state.attempt_number) - load_info = task_pipeline.run(data, table_name=table_name, write_disposition=write_disposition) + logger.info( + "Running the pipeline, attempt=%s" + % attempt.retry_state.attempt_number + ) + load_info = task_pipeline.run( + data, table_name=table_name, write_disposition=write_disposition + ) logger.info(str(load_info)) # save load and trace if self.save_load_info: @@ -216,7 +238,9 @@ def log_after_attempt(retry_state: RetryCallState) -> None: task_pipeline.run([load_info], table_name="_load_info") if self.save_trace_info: logger.info("Saving the trace in the destination") - task_pipeline.run([task_pipeline.last_trace], table_name="_trace") + task_pipeline.run( + [task_pipeline.last_trace], table_name="_trace" + ) # raise on failed jobs if requested if self.fail_task_if_any_job_failed: load_info.raise_on_failed_jobs() @@ -227,9 +251,7 @@ def log_after_attempt(retry_state: RetryCallState) -> None: task_pipeline._wipe_working_folder() return PythonOperator( - task_id=task_name(pipeline, data), - python_callable=_run, - **kwargs + task_id=task_name(pipeline, data), python_callable=_run, **kwargs ) if decompose == "none": @@ -262,6 +284,7 @@ def 
airflow_get_execution_dates() -> Tuple[pendulum.DateTime, Optional[pendulum. # prefer logging to task logger try: from airflow.operators.python import get_current_context # noqa + context = get_current_context() return context["data_interval_start"], context["data_interval_end"] except Exception: diff --git a/dlt/helpers/dbt/__init__.py b/dlt/helpers/dbt/__init__.py index d7369672ec..aa0f8830aa 100644 --- a/dlt/helpers/dbt/__init__.py +++ b/dlt/helpers/dbt/__init__.py @@ -1,23 +1,20 @@ import contextlib from typing import List + import pkg_resources import semver -from dlt.common.runners import Venv -from dlt.common.destination.reference import DestinationClientDwhConfiguration from dlt.common.configuration.specs import CredentialsWithDefault +from dlt.common.destination.reference import DestinationClientDwhConfiguration +from dlt.common.runners import Venv from dlt.common.typing import TSecretValue +from dlt.helpers.dbt.runner import DBTPackageRunner, create_runner from dlt.version import get_installed_requirement_string -from dlt.helpers.dbt.runner import create_runner, DBTPackageRunner - DEFAULT_DBT_VERSION = ">=1.1,<1.6" # a map of destination names to dbt package names in case they don't match the pure destination name -DBT_DESTINATION_MAP = { - "athena": "athena-community", - "motherduck": "duckdb" -} +DBT_DESTINATION_MAP = {"athena": "athena-community", "motherduck": "duckdb"} def _default_profile_name(credentials: DestinationClientDwhConfiguration) -> str: @@ -26,14 +23,16 @@ def _default_profile_name(credentials: DestinationClientDwhConfiguration) -> str if isinstance(credentials.credentials, CredentialsWithDefault): if credentials.credentials.has_default_credentials(): profile_name += "_default" - elif profile_name == 'snowflake': - if getattr(credentials.credentials, 'private_key', None): + elif profile_name == "snowflake": + if getattr(credentials.credentials, "private_key", None): # snowflake with private key is a separate profile - profile_name += '_pkey' + profile_name += "_pkey" return profile_name -def _create_dbt_deps(destination_names: List[str], dbt_version: str = DEFAULT_DBT_VERSION) -> List[str]: +def _create_dbt_deps( + destination_names: List[str], dbt_version: str = DEFAULT_DBT_VERSION +) -> List[str]: if dbt_version: # if parses as version use "==" operator with contextlib.suppress(ValueError): @@ -56,13 +55,17 @@ def _create_dbt_deps(destination_names: List[str], dbt_version: str = DEFAULT_DB return all_packages + [dlt_requirement] -def restore_venv(venv_dir: str, destination_names: List[str], dbt_version: str = DEFAULT_DBT_VERSION) -> Venv: +def restore_venv( + venv_dir: str, destination_names: List[str], dbt_version: str = DEFAULT_DBT_VERSION +) -> Venv: venv = Venv.restore(venv_dir) venv.add_dependencies(_create_dbt_deps(destination_names, dbt_version)) return venv -def create_venv(venv_dir: str, destination_names: List[str], dbt_version: str = DEFAULT_DBT_VERSION) -> Venv: +def create_venv( + venv_dir: str, destination_names: List[str], dbt_version: str = DEFAULT_DBT_VERSION +) -> Venv: return Venv.create(venv_dir, _create_dbt_deps(destination_names, dbt_version)) @@ -73,7 +76,7 @@ def package_runner( package_location: str, package_repository_branch: str = None, package_repository_ssh_key: TSecretValue = TSecretValue(""), # noqa - auto_full_refresh_when_out_of_sync: bool = None + auto_full_refresh_when_out_of_sync: bool = None, ) -> DBTPackageRunner: default_profile_name = _default_profile_name(destination_configuration) return create_runner( @@ -84,5 +87,5 
@@ def package_runner( package_repository_branch=package_repository_branch, package_repository_ssh_key=package_repository_ssh_key, package_profile_name=default_profile_name, - auto_full_refresh_when_out_of_sync=auto_full_refresh_when_out_of_sync + auto_full_refresh_when_out_of_sync=auto_full_refresh_when_out_of_sync, ) diff --git a/dlt/helpers/dbt/configuration.py b/dlt/helpers/dbt/configuration.py index d21266196e..de70d3795e 100644 --- a/dlt/helpers/dbt/configuration.py +++ b/dlt/helpers/dbt/configuration.py @@ -1,16 +1,18 @@ import os from typing import Optional, Sequence -from dlt.common.typing import StrAny, TSecretValue from dlt.common.configuration import configspec from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration +from dlt.common.typing import StrAny, TSecretValue @configspec class DBTRunnerConfiguration(BaseConfiguration): package_location: str = None package_repository_branch: Optional[str] = None - package_repository_ssh_key: TSecretValue = TSecretValue("") # the default is empty value which will disable custom SSH KEY + package_repository_ssh_key: TSecretValue = TSecretValue( + "" + ) # the default is empty value which will disable custom SSH KEY package_profiles_dir: Optional[str] = None package_profile_name: Optional[str] = None auto_full_refresh_when_out_of_sync: bool = True diff --git a/dlt/helpers/dbt/dbt_utils.py b/dlt/helpers/dbt/dbt_utils.py index 64ffc49a39..505fe3ba9e 100644 --- a/dlt/helpers/dbt/dbt_utils.py +++ b/dlt/helpers/dbt/dbt_utils.py @@ -1,25 +1,29 @@ -import os import logging -from typing import Any, Sequence, Optional, Union +import os import warnings +from typing import Any, Optional, Sequence, Union from dlt.common import json, logger from dlt.common.exceptions import MissingDependencyException from dlt.common.typing import StrAny - -from dlt.helpers.dbt.exceptions import DBTProcessingError, DBTNodeResult, IncrementalSchemaOutOfSyncError +from dlt.helpers.dbt.exceptions import ( + DBTNodeResult, + DBTProcessingError, + IncrementalSchemaOutOfSyncError, +) try: # block disabling root logger import logbook.compat - logbook.compat.redirect_logging = lambda : None + + logbook.compat.redirect_logging = lambda: None # can only import DBT after redirect is disabled # https://stackoverflow.com/questions/48619517/call-a-click-command-from-code import dbt.logger - from dbt.events import functions from dbt.contracts import results as dbt_results + from dbt.events import functions except ModuleNotFoundError: raise MissingDependencyException("DBT Core", ["dbt-core"]) @@ -78,9 +82,12 @@ def set_path_wrapper(self: dbt.logger.LogManager, path: str) -> None: def is_incremental_schema_out_of_sync_error(error: Any) -> bool: - def _check_single_item(error_: dbt_results.RunResult) -> bool: - return error_.status == dbt_results.RunStatus.Error and "The source and target schemas on this incremental model are out of sync" in error_.message + return ( + error_.status == dbt_results.RunStatus.Error + and "The source and target schemas on this incremental model are out of sync" + in error_.message + ) if isinstance(error, dbt_results.RunResult): return _check_single_item(error) @@ -102,18 +109,20 @@ def parse_dbt_execution_results(results: Any) -> Sequence[DBTNodeResult]: return None return [ - DBTNodeResult(res.node.name, res.message, res.execution_time, str(res.status)) for res in results if isinstance(res, dbt_results.NodeResult) - ] + DBTNodeResult(res.node.name, res.message, res.execution_time, str(res.status)) + for res in results + if 
isinstance(res, dbt_results.NodeResult) + ] def run_dbt_command( - package_path: str, - command: str, - profiles_dir: str, - profile_name: Optional[str] = None, - global_args: Sequence[str] = None, - command_args: Sequence[str] = None, - package_vars: StrAny = None + package_path: str, + command: str, + profiles_dir: str, + profile_name: Optional[str] = None, + global_args: Sequence[str] = None, + command_args: Sequence[str] = None, + package_vars: StrAny = None, ) -> Union[Sequence[DBTNodeResult], dbt_results.ExecutionResult]: args = ["--profiles-dir", profiles_dir] # add profile name if provided @@ -133,7 +142,7 @@ def run_dbt_command( success: bool = None # dbt uses logbook which does not run on python 10. below is a hack that allows that warnings.filterwarnings("ignore", category=DeprecationWarning, module="logbook") - runner_args = (global_args or []) + [command] + args # type: ignore + runner_args = (global_args or []) + [command] + args # type: ignore with dbt.logger.log_manager.applicationbound(): try: @@ -177,8 +186,16 @@ def init_logging_and_run_dbt_command( profiles_dir: str, profile_name: Optional[str] = None, command_args: Sequence[str] = None, - package_vars: StrAny = None + package_vars: StrAny = None, ) -> Union[Sequence[DBTNodeResult], dbt_results.ExecutionResult]: # initialize dbt logging, returns global parameters to dbt command dbt_global_args = initialize_dbt_logging(log_level, is_json_logging) - return run_dbt_command(package_path, command, profiles_dir, profile_name, dbt_global_args, command_args, package_vars) + return run_dbt_command( + package_path, + command, + profiles_dir, + profile_name, + dbt_global_args, + command_args, + package_vars, + ) diff --git a/dlt/helpers/dbt/exceptions.py b/dlt/helpers/dbt/exceptions.py index 3a9d6f9c80..bf6a35fb0d 100644 --- a/dlt/helpers/dbt/exceptions.py +++ b/dlt/helpers/dbt/exceptions.py @@ -1,4 +1,4 @@ -from typing import Any, Sequence, NamedTuple +from typing import Any, NamedTuple, Sequence from dlt.common.exceptions import DltException @@ -23,7 +23,9 @@ class DBTNodeResult(NamedTuple): class DBTProcessingError(DBTRunnerException): - def __init__(self, command: str, run_results: Sequence[DBTNodeResult], dbt_results: Any) -> None: + def __init__( + self, command: str, run_results: Sequence[DBTNodeResult], dbt_results: Any + ) -> None: self.command = command self.run_results = run_results # the results from DBT may be anything diff --git a/dlt/helpers/dbt/runner.py b/dlt/helpers/dbt/runner.py index 2e857b2256..0fc687c1cf 100644 --- a/dlt/helpers/dbt/runner.py +++ b/dlt/helpers/dbt/runner.py @@ -1,25 +1,29 @@ import os from subprocess import CalledProcessError -import giturlparse from typing import Sequence +import giturlparse + import dlt from dlt.common import logger -from dlt.common.configuration import with_config, known_sections +from dlt.common.configuration import known_sections, with_config from dlt.common.configuration.utils import add_config_to_env from dlt.common.destination.reference import DestinationClientDwhConfiguration +from dlt.common.git import ensure_remote_head, force_clone_repo, git_custom_key_command from dlt.common.runners import Venv from dlt.common.runners.stdout import iter_stdout_with_result -from dlt.common.typing import StrAny, TSecretValue from dlt.common.runtime.logger import is_json_logging +from dlt.common.runtime.telemetry import with_telemetry from dlt.common.storages import FileStorage -from dlt.common.git import git_custom_key_command, ensure_remote_head, force_clone_repo +from 
dlt.common.typing import StrAny, TSecretValue from dlt.common.utils import with_custom_environ - from dlt.helpers.dbt.configuration import DBTRunnerConfiguration -from dlt.helpers.dbt.exceptions import IncrementalSchemaOutOfSyncError, PrerequisitesException, DBTNodeResult, DBTProcessingError - -from dlt.common.runtime.telemetry import with_telemetry +from dlt.helpers.dbt.exceptions import ( + DBTNodeResult, + DBTProcessingError, + IncrementalSchemaOutOfSyncError, + PrerequisitesException, +) class DBTPackageRunner: @@ -31,12 +35,13 @@ class DBTPackageRunner: passed via DBTRunnerConfiguration instance """ - def __init__(self, + def __init__( + self, venv: Venv, credentials: DestinationClientDwhConfiguration, working_dir: str, source_dataset_name: str, - config: DBTRunnerConfiguration + config: DBTRunnerConfiguration, ) -> None: self.venv = venv self.credentials = credentials @@ -62,7 +67,9 @@ def _setup_location(self) -> None: self.cloned_package_name = url.name self.package_path = os.path.join(self.working_dir, self.cloned_package_name) - def _get_package_vars(self, additional_vars: StrAny = None, destination_dataset_name: str = None) -> StrAny: + def _get_package_vars( + self, additional_vars: StrAny = None, destination_dataset_name: str = None + ) -> StrAny: if self.config.package_additional_vars: package_vars = dict(self.config.package_additional_vars) else: @@ -82,7 +89,9 @@ def _log_dbt_run_results(self, results: Sequence[DBTNodeResult]) -> None: if res.status == "error": logger.error(f"Model {res.model_name} error! Error: {res.message}") else: - logger.info(f"Model {res.model_name} {res.status} in {res.time} seconds with {res.message}") + logger.info( + f"Model {res.model_name} {res.status} in {res.time} seconds with {res.message}" + ) def ensure_newest_package(self) -> None: """Clones or brings the dbt package at `package_location` up to date.""" @@ -90,19 +99,37 @@ def ensure_newest_package(self) -> None: with git_custom_key_command(self.config.package_repository_ssh_key) as ssh_command: try: - ensure_remote_head(self.package_path, branch=self.config.package_repository_branch, with_git_command=ssh_command) + ensure_remote_head( + self.package_path, + branch=self.config.package_repository_branch, + with_git_command=ssh_command, + ) except GitError as err: # cleanup package folder logger.info(f"Package will be cloned due to {type(err).__name__}:{str(err)}") - logger.info(f"Will clone {self.config.package_location} head {self.config.package_repository_branch} into {self.package_path}") - force_clone_repo(self.config.package_location, self.repo_storage, self.cloned_package_name, self.config.package_repository_branch, with_git_command=ssh_command) + logger.info( + f"Will clone {self.config.package_location} head" + f" {self.config.package_repository_branch} into {self.package_path}" + ) + force_clone_repo( + self.config.package_location, + self.repo_storage, + self.cloned_package_name, + self.config.package_repository_branch, + with_git_command=ssh_command, + ) @with_custom_environ - def _run_dbt_command(self, command: str, command_args: Sequence[str] = None, package_vars: StrAny = None) -> Sequence[DBTNodeResult]: - logger.info(f"Exec dbt command: {command} {command_args} {package_vars} on profile {self.config.package_profile_name}") + def _run_dbt_command( + self, command: str, command_args: Sequence[str] = None, package_vars: StrAny = None + ) -> Sequence[DBTNodeResult]: + logger.info( + f"Exec dbt command: {command} {command_args} {package_vars} on profile" + f" 
{self.config.package_profile_name}" + ) # write credentials to environ to pass them to dbt, add DLT__ prefix if self.credentials: - add_config_to_env(self.credentials, ("dlt", )) + add_config_to_env(self.credentials, ("dlt",)) args = [ self.config.runtime.log_level, is_json_logging(self.config.runtime.log_format), @@ -111,7 +138,7 @@ def _run_dbt_command(self, command: str, command_args: Sequence[str] = None, pac self.config.package_profiles_dir, self.config.package_profile_name, command_args, - package_vars + package_vars, ] script = f""" from functools import partial @@ -134,7 +161,12 @@ def _run_dbt_command(self, command: str, command_args: Sequence[str] = None, pac print(cpe.stderr) raise - def run(self, cmd_params: Sequence[str] = ("--fail-fast", ), additional_vars: StrAny = None, destination_dataset_name: str = None) -> Sequence[DBTNodeResult]: + def run( + self, + cmd_params: Sequence[str] = ("--fail-fast",), + additional_vars: StrAny = None, + destination_dataset_name: str = None, + ) -> Sequence[DBTNodeResult]: """Runs `dbt` package Executes `dbt run` on previously cloned package. @@ -151,12 +183,15 @@ def run(self, cmd_params: Sequence[str] = ("--fail-fast", ), additional_vars: St DBTProcessingError: `run` command failed. Contains a list of models with their execution statuses and error messages """ return self._run_dbt_command( - "run", - cmd_params, - self._get_package_vars(additional_vars, destination_dataset_name) + "run", cmd_params, self._get_package_vars(additional_vars, destination_dataset_name) ) - def test(self, cmd_params: Sequence[str] = None, additional_vars: StrAny = None, destination_dataset_name: str = None) -> Sequence[DBTNodeResult]: + def test( + self, + cmd_params: Sequence[str] = None, + additional_vars: StrAny = None, + destination_dataset_name: str = None, + ) -> Sequence[DBTNodeResult]: """Tests `dbt` package Executes `dbt test` on previously cloned package. @@ -173,12 +208,12 @@ def test(self, cmd_params: Sequence[str] = None, additional_vars: StrAny = None, DBTProcessingError: `test` command failed. 
Contains a list of models with their execution statuses and error messages """ return self._run_dbt_command( - "test", - cmd_params, - self._get_package_vars(additional_vars, destination_dataset_name) + "test", cmd_params, self._get_package_vars(additional_vars, destination_dataset_name) ) - def _run_db_steps(self, run_params: Sequence[str], package_vars: StrAny, source_tests_selector: str) -> Sequence[DBTNodeResult]: + def _run_db_steps( + self, run_params: Sequence[str], package_vars: StrAny, source_tests_selector: str + ) -> Sequence[DBTNodeResult]: if self.repo_storage: # make sure we use package from the remote head self.ensure_newest_package() @@ -209,8 +244,9 @@ def _run_db_steps(self, run_params: Sequence[str], package_vars: StrAny, source_ else: raise - def run_all(self, - run_params: Sequence[str] = ("--fail-fast", ), + def run_all( + self, + run_params: Sequence[str] = ("--fail-fast",), additional_vars: StrAny = None, source_tests_selector: str = None, destination_dataset_name: str = None, @@ -244,7 +280,7 @@ def run_all(self, results = self._run_db_steps( run_params, self._get_package_vars(additional_vars, destination_dataset_name), - source_tests_selector + source_tests_selector, ) self._log_dbt_run_results(results) return results @@ -270,6 +306,6 @@ def create_runner( package_profiles_dir: str = None, package_profile_name: str = None, auto_full_refresh_when_out_of_sync: bool = None, - config: DBTRunnerConfiguration = None - ) -> DBTPackageRunner: + config: DBTRunnerConfiguration = None, +) -> DBTPackageRunner: return DBTPackageRunner(venv, credentials, working_dir, credentials.dataset_name, config) diff --git a/dlt/helpers/pandas_helper.py b/dlt/helpers/pandas_helper.py index 9c077d49a9..3b2099d78d 100644 --- a/dlt/helpers/pandas_helper.py +++ b/dlt/helpers/pandas_helper.py @@ -14,7 +14,12 @@ @deprecated(reason="Use `df` method on cursor returned from client.execute_query") def query_results_to_df( - client: SqlClientBase[Any], query: str, index_col: Any = None, coerce_float: bool = True, parse_dates: Any = None, dtype: Any = None + client: SqlClientBase[Any], + query: str, + index_col: Any = None, + coerce_float: bool = True, + parse_dates: Any = None, + dtype: Any = None, ) -> pd.DataFrame: """ A helper function that executes a query in the destination and returns the result as Pandas `DataFrame` @@ -52,5 +57,7 @@ def query_results_to_df( columns = [c[0] for c in curr.description] # use existing panda function that converts results to data frame # TODO: we may use `_wrap_iterator` to prevent loading the full result to memory first - pf: pd.DataFrame = _wrap_result(curr.fetchall(), columns, index_col, coerce_float, parse_dates, dtype) + pf: pd.DataFrame = _wrap_result( + curr.fetchall(), columns, index_col, coerce_float, parse_dates, dtype + ) return pf diff --git a/dlt/helpers/streamlit_helper.py b/dlt/helpers/streamlit_helper.py index 37bc97b240..dd7fea3d43 100644 --- a/dlt/helpers/streamlit_helper.py +++ b/dlt/helpers/streamlit_helper.py @@ -1,13 +1,12 @@ import sys from typing import Dict, List -import humanize +import humanize from dlt.common import pendulum -from dlt.common.typing import AnyFun from dlt.common.configuration.exceptions import ConfigFieldMissingException from dlt.common.exceptions import MissingDependencyException - +from dlt.common.typing import AnyFun from dlt.helpers.pandas_helper import pd from dlt.pipeline import Pipeline from dlt.pipeline.exceptions import CannotRestorePipelineException, SqlClientNotAvailable @@ -15,9 +14,14 @@ try: import 
streamlit as st + # from streamlit import SECRETS_FILE_LOC, secrets except ModuleNotFoundError: - raise MissingDependencyException("DLT Streamlit Helpers", ["streamlit"], "DLT Helpers for Streamlit should be run within a streamlit app.") + raise MissingDependencyException( + "DLT Streamlit Helpers", + ["streamlit"], + "DLT Helpers for Streamlit should be run within a streamlit app.", + ) # use right caching function to disable deprecation message @@ -126,11 +130,17 @@ def _query_data_live(query: str, schema_name: str = None) -> pd.DataFrame: st.header("Last load info") col1, col2, col3 = st.columns(3) loads_df = _query_data_live( - f"SELECT load_id, inserted_at FROM {pipeline.default_schema.loads_table_name} WHERE status = 0 ORDER BY inserted_at DESC LIMIT 101 " + f"SELECT load_id, inserted_at FROM {pipeline.default_schema.loads_table_name} WHERE" + " status = 0 ORDER BY inserted_at DESC LIMIT 101 " ) loads_no = loads_df.shape[0] if loads_df.shape[0] > 0: - rel_time = humanize.naturaldelta(pendulum.now() - pendulum.from_timestamp(loads_df.iloc[0, 1].timestamp())) + " ago" + rel_time = ( + humanize.naturaldelta( + pendulum.now() - pendulum.from_timestamp(loads_df.iloc[0, 1].timestamp()) + ) + + " ago" + ) last_load_id = loads_df.iloc[0, 0] if loads_no > 100: loads_no = "> " + str(loads_no) @@ -151,7 +161,10 @@ def _query_data_live(query: str, schema_name: str = None) -> pd.DataFrame: if "parent" in table: continue table_name = table["name"] - query_parts.append(f"SELECT '{table_name}' as table_name, COUNT(1) As rows_count FROM {table_name} WHERE _dlt_load_id = '{selected_load_id}'") + query_parts.append( + f"SELECT '{table_name}' as table_name, COUNT(1) As rows_count FROM" + f" {table_name} WHERE _dlt_load_id = '{selected_load_id}'" + ) query_parts.append("UNION ALL") query_parts.pop() rows_counts_df = _query_data("\n".join(query_parts)) @@ -164,8 +177,9 @@ def _query_data_live(query: str, schema_name: str = None) -> pd.DataFrame: st.header("Schema updates") schemas_df = _query_data_live( - f"SELECT schema_name, inserted_at, version, version_hash FROM {pipeline.default_schema.version_table_name} ORDER BY inserted_at DESC LIMIT 101 " - ) + "SELECT schema_name, inserted_at, version, version_hash FROM" + f" {pipeline.default_schema.version_table_name} ORDER BY inserted_at DESC LIMIT 101 " + ) st.markdown("**100 recent schema updates**") st.dataframe(schemas_df) @@ -184,14 +198,19 @@ def _query_data_live(query: str, schema_name: str = None) -> pd.DataFrame: col2.metric("Remote state version", remote_state_version) if remote_state_version != local_state["_state_version"]: - st.warning("Looks like that local state is not yet synchronized or synchronization is disabled") + st.warning( + "Looks like that local state is not yet synchronized or synchronization is disabled" + ) except CannotRestorePipelineException as restore_ex: st.error("Seems like the pipeline does not exist. Did you run it at least once?") st.exception(restore_ex) except ConfigFieldMissingException as cf_ex: - st.error("Pipeline credentials/configuration is missing. This most often happen when you run the streamlit app from different folder than the `.dlt` with `toml` files resides.") + st.error( + "Pipeline credentials/configuration is missing. This most often happen when you run the" + " streamlit app from different folder than the `.dlt` with `toml` files resides." 
+ ) st.text(str(cf_ex)) except Exception as ex: @@ -199,8 +218,13 @@ def _query_data_live(query: str, schema_name: str = None) -> pd.DataFrame: st.exception(ex) - -def write_data_explorer_page(pipeline: Pipeline, schema_name: str = None, show_dlt_tables: bool = False, example_query: str = "", show_charts: bool = True) -> None: +def write_data_explorer_page( + pipeline: Pipeline, + schema_name: str = None, + show_dlt_tables: bool = False, + example_query: str = "", + show_charts: bool = True, +) -> None: """Writes Streamlit app page with a schema and live data preview. ### Args: @@ -223,7 +247,6 @@ def _query_data(query: str, chunk_size: int = None) -> pd.DataFrame: except SqlClientNotAvailable: st.error("Cannot load data - SqlClient not available") - if schema_name: schema = pipeline.schemas[schema_name] else: @@ -246,7 +269,9 @@ def _query_data(query: str, chunk_size: int = None) -> pd.DataFrame: st.markdown(" | ".join(table_hints)) # table schema contains various hints (like clustering or partition options) that we do not want to show in basic view - essentials_f = lambda c: {k:v for k, v in c.items() if k in ["name", "data_type", "nullable"]} + essentials_f = lambda c: { + k: v for k, v in c.items() if k in ["name", "data_type", "nullable"] + } st.table(map(essentials_f, table["columns"].values())) # add a button that when pressed will show the full content of a table @@ -281,7 +306,6 @@ def _query_data(query: str, chunk_size: int = None) -> pd.DataFrame: # try barchart st.bar_chart(df) if df.dtypes.shape[0] == 2 and show_charts: - # try to import altair charts try: import altair as alt @@ -289,13 +313,17 @@ def _query_data(query: str, chunk_size: int = None) -> pd.DataFrame: raise MissingDependencyException( "DLT Streamlit Helpers", ["altair"], - "DLT Helpers for Streamlit should be run within a streamlit app." + "DLT Helpers for Streamlit should be run within a streamlit" + " app.", ) # try altair - bar_chart = alt.Chart(df).mark_bar().encode( - x=f'{df.columns[1]}:Q', - y=alt.Y(f'{df.columns[0]}:N', sort='-x') + bar_chart = ( + alt.Chart(df) + .mark_bar() + .encode( + x=f"{df.columns[1]}:Q", y=alt.Y(f"{df.columns[0]}:N", sort="-x") + ) ) st.altair_chart(bar_chart, use_container_width=True) except Exception as ex: diff --git a/dlt/load/configuration.py b/dlt/load/configuration.py index b378692c28..2b0bdb8c8d 100644 --- a/dlt/load/configuration.py +++ b/dlt/load/configuration.py @@ -1,8 +1,8 @@ from typing import TYPE_CHECKING from dlt.common.configuration import configspec -from dlt.common.storages import LoadStorageConfiguration from dlt.common.runners.configuration import PoolRunnerConfiguration, TPoolType +from dlt.common.storages import LoadStorageConfiguration @configspec @@ -17,11 +17,12 @@ class LoaderConfiguration(PoolRunnerConfiguration): _load_storage_config: LoadStorageConfiguration = None if TYPE_CHECKING: + def __init__( self, pool_type: TPoolType = "thread", workers: int = None, raise_on_failed_jobs: bool = False, - _load_storage_config: LoadStorageConfiguration = None + _load_storage_config: LoadStorageConfiguration = None, ) -> None: ... 
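The dlt/load/configuration.py hunk above only reorders imports and reflows the constructor stub that configspec classes keep under `if TYPE_CHECKING:`, so that type checkers and IDEs see an explicit `__init__` signature (at runtime the guard is False, and the real `__init__` is presumably supplied by the `@configspec` machinery). A minimal, self-contained sketch of that pattern, using a plain dataclass as a stand-in for dlt's `@configspec` and hypothetical field names:

    from dataclasses import dataclass
    from typing import TYPE_CHECKING, Optional

    @dataclass
    class ExampleLoaderConfiguration:
        # hypothetical fields standing in for the real configspec fields
        pool_type: str = "thread"
        workers: Optional[int] = None
        raise_on_failed_jobs: bool = False

        if TYPE_CHECKING:
            # seen only by type checkers; at runtime TYPE_CHECKING is False,
            # so the generated dataclass __init__ is used instead
            def __init__(
                self,
                pool_type: str = "thread",
                workers: Optional[int] = None,
                raise_on_failed_jobs: bool = False,
            ) -> None: ...

    cfg = ExampleLoaderConfiguration(workers=4)  # the generated __init__ accepts the same arguments

The formatter's contribution in that hunk is purely cosmetic: a blank line after `if TYPE_CHECKING:` and a trailing comma after the last parameter.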
diff --git a/dlt/load/exceptions.py b/dlt/load/exceptions.py index 93d4ef76e1..715d771b2f 100644 --- a/dlt/load/exceptions.py +++ b/dlt/load/exceptions.py @@ -1,6 +1,6 @@ from typing import Sequence -from dlt.destinations.exceptions import DestinationTerminalException, DestinationTransientException +from dlt.destinations.exceptions import DestinationTerminalException, DestinationTransientException # class LoadException(DltException): # def __init__(self, msg: str) -> None: @@ -12,7 +12,10 @@ def __init__(self, load_id: str, job_id: str, failed_message: str) -> None: self.load_id = load_id self.job_id = job_id self.failed_message = failed_message - super().__init__(f"Job for {job_id} failed terminally in load {load_id} with message {failed_message}. The package is aborted and cannot be retried.") + super().__init__( + f"Job for {job_id} failed terminally in load {load_id} with message {failed_message}." + " The package is aborted and cannot be retried." + ) class LoadClientJobRetry(DestinationTransientException): @@ -21,15 +24,23 @@ def __init__(self, load_id: str, job_id: str, retry_count: int, max_retry_count: self.job_id = job_id self.retry_count = retry_count self.max_retry_count = max_retry_count - super().__init__(f"Job for {job_id} had {retry_count} retries which a multiple of {max_retry_count}. Exiting retry loop. You can still rerun the load package to retry this job.") + super().__init__( + f"Job for {job_id} had {retry_count} retries, which is a multiple of {max_retry_count}." + " Exiting retry loop. You can still rerun the load package to retry this job." + ) class LoadClientUnsupportedFileFormats(DestinationTerminalException): - def __init__(self, file_format: str, supported_file_format: Sequence[str], file_path: str) -> None: + def __init__( + self, file_format: str, supported_file_format: Sequence[str], file_path: str + ) -> None: self.file_format = file_format self.supported_types = supported_file_format self.file_path = file_path - super().__init__(f"Loader does not support writer {file_format} in file {file_path}. Supported writers: {supported_file_format}") + super().__init__( + f"Loader does not support writer {file_format} in file {file_path}. 
Supported writers:" + f" {supported_file_format}" + ) class LoadClientUnsupportedWriteDisposition(DestinationTerminalException): @@ -37,4 +48,7 @@ def __init__(self, table_name: str, write_disposition: str, file_name: str) -> N self.table_name = table_name self.write_disposition = write_disposition self.file_name = file_name - super().__init__(f"Loader does not support {write_disposition} in table {table_name} when loading file {file_name}") + super().__init__( + f"Loader does not support {write_disposition} in table {table_name} when loading file" + f" {file_name}" + ) diff --git a/dlt/load/load.py b/dlt/load/load.py index 3746ea4526..c9aa44eb27 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -1,36 +1,52 @@ import contextlib +import datetime # noqa: 251 +import os from copy import copy from functools import reduce -import datetime # noqa: 251 -from typing import Dict, List, Optional, Tuple, Set, Iterator from multiprocessing.pool import ThreadPool -import os +from typing import Dict, Iterator, List, Optional, Set, Tuple -from dlt.common import sleep, logger -from dlt.common.configuration import with_config, known_sections +from dlt.common import logger, sleep +from dlt.common.configuration import known_sections, with_config from dlt.common.configuration.accessors import config +from dlt.common.destination.reference import ( + DestinationClientConfiguration, + DestinationClientDwhConfiguration, + DestinationReference, + FollowupJob, + JobClientBase, + LoadJob, + NewLoadJob, + TLoadJobState, + WithStagingDataset, +) +from dlt.common.exceptions import ( + DestinationTerminalException, + DestinationTransientException, + TerminalValueError, +) from dlt.common.pipeline import LoadInfo, SupportsPipeline -from dlt.common.schema.utils import get_child_tables, get_top_level_table, get_write_disposition -from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState -from dlt.common.typing import StrAny -from dlt.common.runners import TRunMetrics, Runnable, workermethod -from dlt.common.runtime.collector import Collector, NULL_COLLECTOR +from dlt.common.runners import Runnable, TRunMetrics, workermethod +from dlt.common.runtime.collector import NULL_COLLECTOR, Collector from dlt.common.runtime.logger import pretty_format_exception -from dlt.common.exceptions import TerminalValueError, DestinationTerminalException, DestinationTransientException from dlt.common.schema import Schema from dlt.common.schema.typing import TTableSchema, TWriteDisposition +from dlt.common.schema.utils import get_child_tables, get_top_level_table, get_write_disposition from dlt.common.storages import LoadStorage -from dlt.common.destination.reference import DestinationClientDwhConfiguration, FollowupJob, JobClientBase, WithStagingDataset, DestinationReference, LoadJob, NewLoadJob, TLoadJobState, DestinationClientConfiguration - -from dlt.destinations.job_impl import EmptyLoadJob +from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState +from dlt.common.typing import StrAny from dlt.destinations.exceptions import LoadJobUnknownTableException - +from dlt.destinations.job_impl import EmptyLoadJob from dlt.load.configuration import LoaderConfiguration -from dlt.load.exceptions import LoadClientJobFailed, LoadClientJobRetry, LoadClientUnsupportedWriteDisposition, LoadClientUnsupportedFileFormats +from dlt.load.exceptions import ( + LoadClientJobFailed, + LoadClientJobRetry, + LoadClientUnsupportedFileFormats, + LoadClientUnsupportedWriteDisposition, +) 
class Load(Runnable[ThreadPool]): - @with_config(spec=LoaderConfiguration, sections=(known_sections.LOAD,)) def __init__( self, @@ -40,7 +56,7 @@ def __init__( is_storage_owner: bool = False, config: LoaderConfiguration = config.value, initial_client_config: DestinationClientConfiguration = config.value, - initial_staging_client_config: DestinationClientConfiguration = config.value + initial_staging_client_config: DestinationClientConfiguration = config.value, ) -> None: self.config = config self.collector = collector @@ -54,18 +70,20 @@ def __init__( self._processed_load_ids: Dict[str, str] = {} """Load ids to dataset name""" - def create_storage(self, is_storage_owner: bool) -> LoadStorage: supported_file_formats = self.capabilities.supported_loader_file_formats if self.staging_destination: - supported_file_formats = self.staging_destination.capabilities().supported_loader_file_formats + ["reference"] + supported_file_formats = ( + self.staging_destination.capabilities().supported_loader_file_formats + + ["reference"] + ) if isinstance(self.get_destination_client(Schema("test")), WithStagingDataset): supported_file_formats += ["sql"] load_storage = LoadStorage( is_storage_owner, self.capabilities.preferred_loader_file_format, supported_file_formats, - config=self.config._load_storage_config + config=self.config._load_storage_config, ) return load_storage @@ -89,12 +107,21 @@ def get_staging_destination_client(self, schema: Schema) -> JobClientBase: return self.staging_destination.client(schema, self.initial_staging_client_config) def is_staging_destination_job(self, file_path: str) -> bool: - return self.staging_destination is not None and os.path.splitext(file_path)[1][1:] in self.staging_destination.capabilities().supported_loader_file_formats + return ( + self.staging_destination is not None + and os.path.splitext(file_path)[1][1:] + in self.staging_destination.capabilities().supported_loader_file_formats + ) @contextlib.contextmanager - def maybe_with_staging_dataset(self, job_client: JobClientBase, table: TTableSchema) -> Iterator[None]: + def maybe_with_staging_dataset( + self, job_client: JobClientBase, table: TTableSchema + ) -> Iterator[None]: """Executes job client methods in context of staging dataset if `table` has `write_disposition` that requires it""" - if isinstance(job_client, WithStagingDataset) and table["write_disposition"] in job_client.get_stage_dispositions(): + if ( + isinstance(job_client, WithStagingDataset) + and table["write_disposition"] in job_client.get_stage_dispositions() + ): with job_client.with_staging_dataset(): yield else: @@ -102,21 +129,35 @@ def maybe_with_staging_dataset(self, job_client: JobClientBase, table: TTableSch @staticmethod @workermethod - def w_spool_job(self: "Load", file_path: str, load_id: str, schema: Schema) -> Optional[LoadJob]: + def w_spool_job( + self: "Load", file_path: str, load_id: str, schema: Schema + ) -> Optional[LoadJob]: job: LoadJob = None try: # if we have a staging destination and the file is not a reference, send to staging - job_client = self.get_staging_destination_client(schema) if self.is_staging_destination_job(file_path) else self.get_destination_client(schema) + job_client = ( + self.get_staging_destination_client(schema) + if self.is_staging_destination_job(file_path) + else self.get_destination_client(schema) + ) with job_client as job_client: job_info = self.load_storage.parse_job_file_name(file_path) if job_info.file_format not in self.load_storage.supported_file_formats: - raise 
LoadClientUnsupportedFileFormats(job_info.file_format, self.capabilities.supported_loader_file_formats, file_path) + raise LoadClientUnsupportedFileFormats( + job_info.file_format, + self.capabilities.supported_loader_file_formats, + file_path, + ) logger.info(f"Will load file {file_path} with table name {job_info.table_name}") table = self.get_load_table(schema, file_path) if table["write_disposition"] not in ["append", "replace", "merge"]: - raise LoadClientUnsupportedWriteDisposition(job_info.table_name, table["write_disposition"], file_path) + raise LoadClientUnsupportedWriteDisposition( + job_info.table_name, table["write_disposition"], file_path + ) with self.maybe_with_staging_dataset(job_client, table): - job = job_client.start_file_load(table, self.load_storage.storage.make_full_path(file_path), load_id) + job = job_client.start_file_load( + table, self.load_storage.storage.make_full_path(file_path), load_id + ) except (DestinationTerminalException, TerminalValueError): # if job irreversibly cannot be started, mark it as failed logger.exception(f"Terminal problem when adding job {file_path}") @@ -134,7 +175,7 @@ def spool_new_jobs(self, load_id: str, schema: Schema) -> Tuple[int, List[LoadJo # use thread based pool as jobs processing is mostly I/O and we do not want to pickle jobs # TODO: combine files by providing a list of files pertaining to same table into job, so job must be # extended to accept a list - load_files = self.load_storage.list_new_jobs(load_id)[:self.config.workers] + load_files = self.load_storage.list_new_jobs(load_id)[: self.config.workers] file_count = len(load_files) if file_count == 0: logger.info(f"No new jobs found in {load_id}") @@ -147,7 +188,9 @@ def spool_new_jobs(self, load_id: str, schema: Schema) -> Tuple[int, List[LoadJo # remove None jobs and check the rest return file_count, [job for job in jobs if job is not None] - def retrieve_jobs(self, client: JobClientBase, load_id: str, staging_client: JobClientBase = None) -> Tuple[int, List[LoadJob]]: + def retrieve_jobs( + self, client: JobClientBase, load_id: str, staging_client: JobClientBase = None + ) -> Tuple[int, List[LoadJob]]: jobs: List[LoadJob] = [] # list all files that were started but not yet completed @@ -173,18 +216,29 @@ def retrieve_jobs(self, client: JobClientBase, load_id: str, staging_client: Job return len(jobs), jobs - def get_new_jobs_info(self, load_id: str, schema: Schema, dispositions: List[TWriteDisposition] = None) -> List[ParsedLoadJobFileName]: + def get_new_jobs_info( + self, load_id: str, schema: Schema, dispositions: List[TWriteDisposition] = None + ) -> List[ParsedLoadJobFileName]: jobs_info: List[ParsedLoadJobFileName] = [] new_job_files = self.load_storage.list_new_jobs(load_id) for job_file in new_job_files: - if dispositions is None or self.get_load_table(schema, job_file)["write_disposition"] in dispositions: + if ( + dispositions is None + or self.get_load_table(schema, job_file)["write_disposition"] in dispositions + ): jobs_info.append(LoadStorage.parse_job_file_name(job_file)) return jobs_info - def get_completed_table_chain(self, load_id: str, schema: Schema, top_merged_table: TTableSchema, being_completed_job_id: str = None) -> List[TTableSchema]: + def get_completed_table_chain( + self, + load_id: str, + schema: Schema, + top_merged_table: TTableSchema, + being_completed_job_id: str = None, + ) -> List[TTableSchema]: """Gets a table chain starting from the `top_merged_table` containing only tables with completed/failed jobs. 
None is returned if there's any job that is not completed - Optionally `being_completed_job_id` can be passed that is considered to be completed before job itself moves in storage + Optionally `being_completed_job_id` can be passed that is considered to be completed before job itself moves in storage """ # returns ordered list of tables from parent to child leaf tables table_chain: List[TTableSchema] = [] @@ -192,17 +246,23 @@ def get_completed_table_chain(self, load_id: str, schema: Schema, top_merged_tab for table in get_child_tables(schema.tables, top_merged_table["name"]): table_jobs = self.load_storage.list_jobs_for_table(load_id, table["name"]) # all jobs must be completed in order for merge to be created - if any(job.state not in ("failed_jobs", "completed_jobs") and job.job_file_info.job_id() != being_completed_job_id for job in table_jobs): + if any( + job.state not in ("failed_jobs", "completed_jobs") + and job.job_file_info.job_id() != being_completed_job_id + for job in table_jobs + ): return None # if there are no jobs for the table, skip it, unless the write disposition is replace, as we need to create and clear the child tables if not table_jobs and top_merged_table["write_disposition"] != "replace": - continue + continue table_chain.append(table) # there must be at least table assert len(table_chain) > 0 return table_chain - def create_followup_jobs(self, load_id: str, state: TLoadJobState, starting_job: LoadJob, schema: Schema) -> List[NewLoadJob]: + def create_followup_jobs( + self, load_id: str, state: TLoadJobState, starting_job: LoadJob, schema: Schema + ) -> List[NewLoadJob]: jobs: List[NewLoadJob] = [] if isinstance(starting_job, FollowupJob): # check for merge jobs only for jobs executing on the destination, the staging destination jobs must be excluded @@ -210,10 +270,16 @@ def create_followup_jobs(self, load_id: str, state: TLoadJobState, starting_job: starting_job_file_name = starting_job.file_name() if state == "completed" and not self.is_staging_destination_job(starting_job_file_name): client = self.destination.client(schema, self.initial_client_config) - top_job_table = get_top_level_table(schema.tables, self.get_load_table(schema, starting_job_file_name)["name"]) + top_job_table = get_top_level_table( + schema.tables, self.get_load_table(schema, starting_job_file_name)["name"] + ) # if all tables of chain completed, create follow up jobs - if table_chain := self.get_completed_table_chain(load_id, schema, top_job_table, starting_job.job_file_info().job_id()): - if follow_up_jobs := client.create_table_chain_completed_followup_jobs(table_chain): + if table_chain := self.get_completed_table_chain( + load_id, schema, top_job_table, starting_job.job_file_info().job_id() + ): + if follow_up_jobs := client.create_table_chain_completed_followup_jobs( + table_chain + ): jobs = jobs + follow_up_jobs jobs = jobs + starting_job.create_followup_jobs(state) return jobs @@ -233,22 +299,34 @@ def complete_jobs(self, load_id: str, jobs: List[LoadJob], schema: Schema) -> Li # try to get exception message from job failed_message = job.exception() self.load_storage.fail_job(load_id, job.file_name(), failed_message) - logger.error(f"Job for {job.job_id()} failed terminally in load {load_id} with message {failed_message}") + logger.error( + f"Job for {job.job_id()} failed terminally in load {load_id} with message" + f" {failed_message}" + ) elif state == "retry": # try to get exception message from job retry_message = job.exception() # move back to new folder to try again 
self.load_storage.retry_job(load_id, job.file_name()) - logger.warning(f"Job for {job.job_id()} retried in load {load_id} with message {retry_message}") + logger.warning( + f"Job for {job.job_id()} retried in load {load_id} with message {retry_message}" + ) elif state == "completed": # create followup jobs followup_jobs = self.create_followup_jobs(load_id, state, job, schema) for followup_job in followup_jobs: # running should be moved into "new jobs", other statuses into started - folder: TJobState = "new_jobs" if followup_job.state() == "running" else "started_jobs" + folder: TJobState = ( + "new_jobs" if followup_job.state() == "running" else "started_jobs" + ) # save all created jobs - self.load_storage.add_new_job(load_id, followup_job.new_file_path(), job_state=folder) - logger.info(f"Job {job.job_id()} CREATED a new FOLLOWUP JOB {followup_job.new_file_path()} placed in {folder}") + self.load_storage.add_new_job( + load_id, followup_job.new_file_path(), job_state=folder + ) + logger.info( + f"Job {job.job_id()} CREATED a new FOLLOWUP JOB" + f" {followup_job.new_file_path()} placed in {folder}" + ) # if followup job is not "running" place it in current queue to be finalized if not followup_job.state() == "running": remaining_jobs.append(followup_job) @@ -260,7 +338,9 @@ def complete_jobs(self, load_id: str, jobs: List[LoadJob], schema: Schema) -> Li if state in ["failed", "completed"]: self.collector.update("Jobs") if state == "failed": - self.collector.update("Jobs", 1, message="WARNING: Some of the jobs failed!", label="Failed") + self.collector.update( + "Jobs", 1, message="WARNING: Some of the jobs failed!", label="Failed" + ) return remaining_jobs @@ -272,18 +352,26 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) # TODO: Load must provide a clear interface to get last loads and metrics # TODO: get more info ie. was package aborted, schema name etc. 
if isinstance(job_client.config, DestinationClientDwhConfiguration): - self._processed_load_ids[load_id] = job_client.config.normalize_dataset_name(schema) + self._processed_load_ids[load_id] = job_client.config.normalize_dataset_name( + schema + ) else: self._processed_load_ids[load_id] = None self.load_storage.complete_load_package(load_id, aborted) - logger.info(f"All jobs completed, archiving package {load_id} with aborted set to {aborted}") + logger.info( + f"All jobs completed, archiving package {load_id} with aborted set to {aborted}" + ) - def get_table_chain_tables_for_write_disposition(self, load_id: str, schema: Schema, dispositions: List[TWriteDisposition]) -> Set[str]: + def get_table_chain_tables_for_write_disposition( + self, load_id: str, schema: Schema, dispositions: List[TWriteDisposition] + ) -> Set[str]: """Get all jobs for tables with given write disposition and resolve the table chain""" result: Set[str] = set() table_jobs = self.get_new_jobs_info(load_id, schema, dispositions) for job in table_jobs: - top_job_table = get_top_level_table(schema.tables, self.get_load_table(schema, job.job_id())["name"]) + top_job_table = get_top_level_table( + schema.tables, self.get_load_table(schema, job.job_id())["name"] + ) table_chain = get_child_tables(schema.tables, top_job_table["name"]) for table in table_chain: existing_jobs = self.load_storage.list_jobs_for_table(load_id, table["name"]) @@ -300,30 +388,57 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: expected_update = self.load_storage.begin_schema_update(load_id) if expected_update is not None: # update the default dataset - logger.info(f"Client for {job_client.config.destination_name} will start initialize storage") + logger.info( + f"Client for {job_client.config.destination_name} will start initialize storage" + ) job_client.initialize_storage() - logger.info(f"Client for {job_client.config.destination_name} will update schema to package schema") + logger.info( + f"Client for {job_client.config.destination_name} will update schema to package" + " schema" + ) all_jobs = self.get_new_jobs_info(load_id, schema) all_tables = set(job.table_name for job in all_jobs) dlt_tables = set(t["name"] for t in schema.dlt_tables()) # only update tables that are present in the load package - applied_update = job_client.update_storage_schema(only_tables=all_tables | dlt_tables, expected_update=expected_update) - truncate_tables = self.get_table_chain_tables_for_write_disposition(load_id, schema, job_client.get_truncate_destination_table_dispositions()) + applied_update = job_client.update_storage_schema( + only_tables=all_tables | dlt_tables, expected_update=expected_update + ) + truncate_tables = self.get_table_chain_tables_for_write_disposition( + load_id, schema, job_client.get_truncate_destination_table_dispositions() + ) job_client.initialize_storage(truncate_tables=truncate_tables) # initialize staging storage if needed if self.staging_destination: with self.get_staging_destination_client(schema) as staging_client: - truncate_tables = self.get_table_chain_tables_for_write_disposition(load_id, schema, staging_client.get_truncate_destination_table_dispositions()) + truncate_tables = self.get_table_chain_tables_for_write_disposition( + load_id, + schema, + staging_client.get_truncate_destination_table_dispositions(), + ) staging_client.initialize_storage(truncate_tables) # update the staging dataset if client supports this if isinstance(job_client, WithStagingDataset): - if staging_tables := 
self.get_table_chain_tables_for_write_disposition(load_id, schema, job_client.get_stage_dispositions()): + if staging_tables := self.get_table_chain_tables_for_write_disposition( + load_id, schema, job_client.get_stage_dispositions() + ): with job_client.with_staging_dataset(): - logger.info(f"Client for {job_client.config.destination_name} will start initialize STAGING storage") + logger.info( + f"Client for {job_client.config.destination_name} will start" + " initialize STAGING storage" + ) job_client.initialize_storage() - logger.info(f"Client for {job_client.config.destination_name} will UPDATE STAGING SCHEMA to package schema") - job_client.update_storage_schema(only_tables=staging_tables | {schema.version_table_name}, expected_update=expected_update) - logger.info(f"Client for {job_client.config.destination_name} will TRUNCATE STAGING TABLES: {staging_tables}") + logger.info( + f"Client for {job_client.config.destination_name} will UPDATE" + " STAGING SCHEMA to package schema" + ) + job_client.update_storage_schema( + only_tables=staging_tables | {schema.version_table_name}, + expected_update=expected_update, + ) + logger.info( + f"Client for {job_client.config.destination_name} will TRUNCATE" + f" STAGING TABLES: {staging_tables}" + ) job_client.initialize_storage(truncate_tables=staging_tables) self.load_storage.commit_schema_update(load_id, applied_update) # initialize staging destination and spool or retrieve unfinished jobs @@ -347,7 +462,9 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: no_completed_jobs = len(package_info.jobs["completed_jobs"]) + no_failed_jobs self.collector.update("Jobs", no_completed_jobs, total_jobs) if no_failed_jobs > 0: - self.collector.update("Jobs", no_failed_jobs, message="WARNING: Some of the jobs failed!", label="Failed") + self.collector.update( + "Jobs", no_failed_jobs, message="WARNING: Some of the jobs failed!", label="Failed" + ) # loop until all jobs are processed while True: try: @@ -359,13 +476,22 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: if self.config.raise_on_failed_jobs: if package_info.jobs["failed_jobs"]: failed_job = package_info.jobs["failed_jobs"][0] - raise LoadClientJobFailed(load_id, failed_job.job_file_info.job_id(), failed_job.failed_message) + raise LoadClientJobFailed( + load_id, + failed_job.job_file_info.job_id(), + failed_job.failed_message, + ) # possibly raise on too many retires if self.config.raise_on_max_retries: for new_job in package_info.jobs["new_jobs"]: r_c = new_job.job_file_info.retry_count if r_c > 0 and r_c % self.config.raise_on_max_retries == 0: - raise LoadClientJobRetry(load_id, new_job.job_file_info.job_id(), r_c, self.config.raise_on_max_retries) + raise LoadClientJobRetry( + load_id, + new_job.job_file_info.job_id(), + r_c, + self.config.raise_on_max_retries, + ) break # process remaining jobs again jobs = remaining_jobs @@ -401,7 +527,9 @@ def run(self, pool: ThreadPool) -> TRunMetrics: return TRunMetrics(False, len(self.load_storage.list_packages())) - def get_load_info(self, pipeline: SupportsPipeline, started_at: datetime.datetime = None) -> LoadInfo: + def get_load_info( + self, pipeline: SupportsPipeline, started_at: datetime.datetime = None + ) -> LoadInfo: # TODO: LoadInfo should hold many datasets load_ids = list(self._processed_load_ids.keys()) load_packages: List[LoadPackageInfo] = [] @@ -414,12 +542,16 @@ def get_load_info(self, pipeline: SupportsPipeline, started_at: datetime.datetim pipeline, 
self.initial_client_config.destination_name, str(self.initial_client_config), - self.initial_staging_client_config.destination_name if self.initial_staging_client_config else None, + ( + self.initial_staging_client_config.destination_name + if self.initial_staging_client_config + else None + ), str(self.initial_staging_client_config) if self.initial_staging_client_config else None, self.initial_client_config.fingerprint(), _dataset_name, list(load_ids), load_packages, started_at, - pipeline.first_run + pipeline.first_run, ) diff --git a/dlt/normalize/__init__.py b/dlt/normalize/__init__.py index a40a5eaa7e..e76bae8155 100644 --- a/dlt/normalize/__init__.py +++ b/dlt/normalize/__init__.py @@ -1 +1 @@ -from .normalize import Normalize \ No newline at end of file +from .normalize import Normalize diff --git a/dlt/normalize/configuration.py b/dlt/normalize/configuration.py index c4ed7aa89a..10cf8e3be6 100644 --- a/dlt/normalize/configuration.py +++ b/dlt/normalize/configuration.py @@ -3,7 +3,11 @@ from dlt.common.configuration import configspec from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.runners.configuration import PoolRunnerConfiguration, TPoolType -from dlt.common.storages import LoadStorageConfiguration, NormalizeStorageConfiguration, SchemaStorageConfiguration +from dlt.common.storages import ( + LoadStorageConfiguration, + NormalizeStorageConfiguration, + SchemaStorageConfiguration, +) @configspec @@ -15,12 +19,13 @@ class NormalizeConfiguration(PoolRunnerConfiguration): _load_storage_config: LoadStorageConfiguration if TYPE_CHECKING: + def __init__( self, pool_type: TPoolType = "process", workers: int = None, _schema_storage_config: SchemaStorageConfiguration = None, _normalize_storage_config: NormalizeStorageConfiguration = None, - _load_storage_config: LoadStorageConfiguration = None + _load_storage_config: LoadStorageConfiguration = None, ) -> None: ... 
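The normalize configuration above, like the loader configuration earlier, is only reflowed; the bulk of the surrounding hunks instead apply two mechanical rules: long (f-)strings are split into adjacent literals that Python concatenates, and multi-argument calls receive a trailing comma so the formatter keeps one argument per line. A small before/after sketch with a hypothetical logger call (not code taken from dlt) shows that the rendered message does not change:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("example")

    load_id, table_name = "1692888000.123", "events"  # hypothetical values

    # before: a single call that exceeds the line-length limit
    logger.info(f"Will load package {load_id} with root table {table_name} using the preferred loader file format")

    # after: Python joins adjacent string literals (f-strings included),
    # so the message is identical while each physical line stays short
    logger.info(
        f"Will load package {load_id} with root table"
        f" {table_name} using the preferred loader file format"
    )

Note the leading space carried onto the continuation literal; the wrapped strings in the hunks around this point follow the same convention so that no separator is lost when the pieces are joined.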
diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 1ca08c6e47..f9c1de2ba5 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -1,39 +1,51 @@ import os -from typing import Any, Callable, List, Dict, Sequence, Tuple, Set -from multiprocessing.pool import AsyncResult, Pool as ProcessPool +from multiprocessing.pool import AsyncResult +from multiprocessing.pool import Pool as ProcessPool +from typing import Any, Callable, Dict, List, Sequence, Set, Tuple -from dlt.common import pendulum, json, logger, sleep -from dlt.common.configuration import with_config, known_sections +from dlt.common import json, logger, pendulum, sleep +from dlt.common.configuration import known_sections, with_config from dlt.common.configuration.accessors import config from dlt.common.configuration.container import Container from dlt.common.destination import DestinationCapabilitiesContext, TLoaderFileFormat from dlt.common.json import custom_pua_decode -from dlt.common.runners import TRunMetrics, Runnable +from dlt.common.runners import Runnable, TRunMetrics from dlt.common.runtime import signals -from dlt.common.runtime.collector import Collector, NULL_COLLECTOR +from dlt.common.runtime.collector import NULL_COLLECTOR, Collector +from dlt.common.schema import Schema, TSchemaUpdate +from dlt.common.schema.exceptions import CannotCoerceColumnException from dlt.common.schema.typing import TStoredSchema, TTableSchemaColumns from dlt.common.schema.utils import merge_schema_updates +from dlt.common.storages import ( + LoadStorage, + LoadStorageConfiguration, + NormalizeStorage, + NormalizeStorageConfiguration, + SchemaStorage, +) from dlt.common.storages.exceptions import SchemaNotFoundError -from dlt.common.storages import NormalizeStorage, SchemaStorage, LoadStorage, LoadStorageConfiguration, NormalizeStorageConfiguration from dlt.common.typing import TDataItem -from dlt.common.schema import TSchemaUpdate, Schema -from dlt.common.schema.exceptions import CannotCoerceColumnException from dlt.common.utils import chunks - from dlt.normalize.configuration import NormalizeConfiguration # normalize worker wrapping function (map_parallel, map_single) return type TMapFuncRV = Sequence[TSchemaUpdate] # normalize worker wrapping function signature -TMapFuncType = Callable[[Schema, str, Sequence[str]], TMapFuncRV] # input parameters: (schema name, load_id, list of files to process) +TMapFuncType = Callable[ + [Schema, str, Sequence[str]], TMapFuncRV +] # input parameters: (schema name, load_id, list of files to process) # tuple returned by the worker TWorkerRV = Tuple[List[TSchemaUpdate], int, List[str]] class Normalize(Runnable[ProcessPool]): - @with_config(spec=NormalizeConfiguration, sections=(known_sections.NORMALIZE,)) - def __init__(self, collector: Collector = NULL_COLLECTOR, schema_storage: SchemaStorage = None, config: NormalizeConfiguration = config.value) -> None: + def __init__( + self, + collector: Collector = NULL_COLLECTOR, + schema_storage: SchemaStorage = None, + config: NormalizeConfiguration = config.value, + ) -> None: self.config = config self.collector = collector self.pool: ProcessPool = None @@ -44,20 +56,31 @@ def __init__(self, collector: Collector = NULL_COLLECTOR, schema_storage: Schema # setup storages self.create_storages() # create schema storage with give type - self.schema_storage = schema_storage or SchemaStorage(self.config._schema_storage_config, makedirs=True) + self.schema_storage = schema_storage or SchemaStorage( + 
self.config._schema_storage_config, makedirs=True + ) def create_storages(self) -> None: # pass initial normalize storage config embedded in normalize config - self.normalize_storage = NormalizeStorage(True, config=self.config._normalize_storage_config) + self.normalize_storage = NormalizeStorage( + True, config=self.config._normalize_storage_config + ) # normalize saves in preferred format but can read all supported formats - self.load_storage = LoadStorage(True, self.config.destination_capabilities.preferred_loader_file_format, LoadStorage.ALL_SUPPORTED_FILE_FORMATS, config=self.config._load_storage_config) + self.load_storage = LoadStorage( + True, + self.config.destination_capabilities.preferred_loader_file_format, + LoadStorage.ALL_SUPPORTED_FILE_FORMATS, + config=self.config._load_storage_config, + ) @staticmethod def load_or_create_schema(schema_storage: SchemaStorage, schema_name: str) -> Schema: try: schema = schema_storage.load_schema(schema_name) schema.update_normalizers() - logger.info(f"Loaded schema with name {schema_name} with version {schema.stored_version}") + logger.info( + f"Loaded schema with name {schema_name} with version {schema.stored_version}" + ) except SchemaNotFoundError: schema = Schema(schema_name) logger.info(f"Created new schema with name {schema_name}") @@ -65,20 +88,24 @@ def load_or_create_schema(schema_storage: SchemaStorage, schema_name: str) -> Sc @staticmethod def w_normalize_files( - normalize_storage_config: NormalizeStorageConfiguration, - loader_storage_config: LoadStorageConfiguration, - destination_caps: DestinationCapabilitiesContext, - stored_schema: TStoredSchema, - load_id: str, - extracted_items_files: Sequence[str], - ) -> TWorkerRV: - + normalize_storage_config: NormalizeStorageConfiguration, + loader_storage_config: LoadStorageConfiguration, + destination_caps: DestinationCapabilitiesContext, + stored_schema: TStoredSchema, + load_id: str, + extracted_items_files: Sequence[str], + ) -> TWorkerRV: schema_updates: List[TSchemaUpdate] = [] total_items = 0 # process all files with data items and write to buffered item storage with Container().injectable_context(destination_caps): schema = Schema.from_stored_schema(stored_schema) - load_storage = LoadStorage(False, destination_caps.preferred_loader_file_format, LoadStorage.ALL_SUPPORTED_FILE_FORMATS, loader_storage_config) + load_storage = LoadStorage( + False, + destination_caps.preferred_loader_file_format, + LoadStorage.ALL_SUPPORTED_FILE_FORMATS, + loader_storage_config, + ) normalize_storage = NormalizeStorage(False, normalize_storage_config) try: @@ -86,22 +113,35 @@ def w_normalize_files( populated_root_tables: Set[str] = set() for extracted_items_file in extracted_items_files: line_no: int = 0 - root_table_name = NormalizeStorage.parse_normalize_file_name(extracted_items_file).table_name + root_table_name = NormalizeStorage.parse_normalize_file_name( + extracted_items_file + ).table_name root_tables.add(root_table_name) - logger.debug(f"Processing extracted items in {extracted_items_file} in load_id {load_id} with table name {root_table_name} and schema {schema.name}") + logger.debug( + f"Processing extracted items in {extracted_items_file} in load_id" + f" {load_id} with table name {root_table_name} and schema {schema.name}" + ) with normalize_storage.storage.open_file(extracted_items_file) as f: # enumerate jsonl file line by line items_count = 0 for line_no, line in enumerate(f): items: List[TDataItem] = json.loads(line) - partial_update, items_count = 
Normalize._w_normalize_chunk(load_storage, schema, load_id, root_table_name, items) + partial_update, items_count = Normalize._w_normalize_chunk( + load_storage, schema, load_id, root_table_name, items + ) schema_updates.append(partial_update) total_items += items_count - logger.debug(f"Processed {line_no} items from file {extracted_items_file}, items {items_count} of total {total_items}") + logger.debug( + f"Processed {line_no} items from file {extracted_items_file}, items" + f" {items_count} of total {total_items}" + ) # if any item found in the file if items_count > 0: populated_root_tables.add(root_table_name) - logger.debug(f"Processed total {line_no + 1} lines from file {extracted_items_file}, total items {total_items}") + logger.debug( + f"Processed total {line_no + 1} lines from file" + f" {extracted_items_file}, total items {total_items}" + ) # write empty jobs for tables without items if table exists in schema for table_name in root_tables - populated_root_tables: if table_name not in schema.tables: @@ -110,7 +150,9 @@ def w_normalize_files( columns = schema.get_table_columns(table_name) load_storage.write_empty_file(load_id, schema.name, table_name, columns) except Exception: - logger.exception(f"Exception when processing file {extracted_items_file}, line {line_no}") + logger.exception( + f"Exception when processing file {extracted_items_file}, line {line_no}" + ) raise finally: load_storage.close_writers(load_id) @@ -120,14 +162,24 @@ def w_normalize_files( return schema_updates, total_items, load_storage.closed_files() @staticmethod - def _w_normalize_chunk(load_storage: LoadStorage, schema: Schema, load_id: str, root_table_name: str, items: List[TDataItem]) -> Tuple[TSchemaUpdate, int]: - column_schemas: Dict[str, TTableSchemaColumns] = {} # quick access to column schema for writers below + def _w_normalize_chunk( + load_storage: LoadStorage, + schema: Schema, + load_id: str, + root_table_name: str, + items: List[TDataItem], + ) -> Tuple[TSchemaUpdate, int]: + column_schemas: Dict[str, TTableSchemaColumns] = ( + {} + ) # quick access to column schema for writers below schema_update: TSchemaUpdate = {} schema_name = schema.name items_count = 0 for item in items: - for (table_name, parent_table), row in schema.normalize_data_item(item, load_id, root_table_name): + for (table_name, parent_table), row in schema.normalize_data_item( + item, load_id, root_table_name + ): # filter row, may eliminate some or all fields row = schema.filter_row(table_name, row) # do not process empty rows @@ -161,7 +213,9 @@ def _w_normalize_chunk(load_storage: LoadStorage, schema: Schema, load_id: str, def update_schema(self, schema: Schema, schema_updates: List[TSchemaUpdate]) -> None: for schema_update in schema_updates: for table_name, table_updates in schema_update.items(): - logger.info(f"Updating schema for table {table_name} with {len(table_updates)} deltas") + logger.info( + f"Updating schema for table {table_name} with {len(table_updates)} deltas" + ) for partial_table in table_updates: # merge columns schema.update_schema(partial_table) @@ -179,7 +233,7 @@ def group_worker_files(files: Sequence[str], no_groups: int) -> List[Sequence[st while remainder_l > 0: for idx, file in enumerate(reversed(chunk_files.pop())): chunk_files[-l_idx - idx - remainder_l].append(file) # type: ignore - remainder_l -=1 + remainder_l -= 1 l_idx = idx + 1 return chunk_files @@ -187,7 +241,12 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TM workers = self.pool._processes # type: 
ignore chunk_files = self.group_worker_files(files, workers) schema_dict: TStoredSchema = schema.to_dict() - config_tuple = (self.normalize_storage.config, self.load_storage.config, self.config.destination_capabilities, schema_dict) + config_tuple = ( + self.normalize_storage.config, + self.load_storage.config, + self.config.destination_capabilities, + schema_dict, + ) param_chunk = [[*config_tuple, load_id, files] for files in chunk_files] tasks: List[Tuple[AsyncResult[TWorkerRV], List[Any]]] = [] @@ -196,7 +255,9 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TM # push all tasks to queue for params in param_chunk: - pending: AsyncResult[TWorkerRV] = self.pool.apply_async(Normalize.w_normalize_files, params) + pending: AsyncResult[TWorkerRV] = self.pool.apply_async( + Normalize.w_normalize_files, params + ) tasks.append((pending, params)) while len(tasks) > 0: @@ -216,7 +277,9 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TM self.collector.update("Items", result[1]) except CannotCoerceColumnException as exc: # schema conflicts resulting from parallel executing - logger.warning(f"Parallel schema update conflict, retrying task ({str(exc)}") + logger.warning( + f"Parallel schema update conflict, retrying task ({str(exc)}" + ) # delete all files produced by the task for file in result[2]: os.remove(file) @@ -224,7 +287,9 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TM schema_dict = schema.to_dict() # TODO: it's time for a named tuple params[3] = schema_dict - retry_pending: AsyncResult[TWorkerRV] = self.pool.apply_async(Normalize.w_normalize_files, params) + retry_pending: AsyncResult[TWorkerRV] = self.pool.apply_async( + Normalize.w_normalize_files, params + ) tasks.append((retry_pending, params)) # remove finished tasks tasks.remove(task) @@ -249,14 +314,18 @@ def map_single(self, schema: Schema, load_id: str, files: Sequence[str]) -> TMap self.collector.update("Items", result[1]) return result[0] - def spool_files(self, schema_name: str, load_id: str, map_f: TMapFuncType, files: Sequence[str]) -> None: + def spool_files( + self, schema_name: str, load_id: str, map_f: TMapFuncType, files: Sequence[str] + ) -> None: schema = Normalize.load_or_create_schema(self.schema_storage, schema_name) # process files in parallel or in single thread, depending on map_f schema_updates = map_f(schema, load_id, files) # logger.metrics("Normalize metrics", extra=get_logging_extras([self.schema_version_gauge.labels(schema_name)])) if len(schema_updates) > 0: - logger.info(f"Saving schema {schema_name} with version {schema.version}, writing manifest files") + logger.info( + f"Saving schema {schema_name} with version {schema.version}, writing manifest files" + ) # schema is updated, save it to schema volume self.schema_storage.save_schema(schema) # save schema to temp load folder @@ -270,7 +339,7 @@ def spool_files(self, schema_name: str, load_id: str, map_f: TMapFuncType, files self.load_storage.commit_temp_load_package(load_id) # delete item files to complete commit for file in files: - self.normalize_storage.storage.delete(file) + self.normalize_storage.storage.delete(file) # log and update metrics logger.info(f"Chunk {load_id} processed") @@ -287,7 +356,9 @@ def spool_schema_files(self, load_id: str, schema_name: str, files: Sequence[str self.spool_files(schema_name, load_id, map_parallel_f, files) except CannotCoerceColumnException as exc: # schema conflicts resulting from parallel executing - 
logger.warning(f"Parallel schema update conflict, switching to single thread ({str(exc)}") + logger.warning( + f"Parallel schema update conflict, switching to single thread ({str(exc)}" + ) # start from scratch self.load_storage.create_temp_load_package(load_id) self.spool_files(schema_name, load_id, self.map_single, files) @@ -307,7 +378,9 @@ def run(self, pool: ProcessPool) -> TRunMetrics: for schema_name, files_iter in self.normalize_storage.group_by_schema(files): schema_files = list(files_iter) load_id = str(pendulum.now().timestamp()) - logger.info(f"Found {len(schema_files)} files in schema {schema_name} load_id {load_id}") + logger.info( + f"Found {len(schema_files)} files in schema {schema_name} load_id {load_id}" + ) with self.collector(f"Normalize {schema_name} in {load_id}"): self.collector.update("Files", 0, len(schema_files)) self.collector.update("Items", 0) diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py index df4314cf0d..26e55d96a4 100644 --- a/dlt/pipeline/__init__.py +++ b/dlt/pipeline/__init__.py @@ -1,19 +1,18 @@ from typing import Sequence, cast, overload -from dlt.common.schema import Schema -from dlt.common.schema.typing import TColumnSchema, TWriteDisposition - -from dlt.common.typing import TSecretValue, Any from dlt.common.configuration import with_config from dlt.common.configuration.container import Container from dlt.common.configuration.inject import get_orig_args, last_config +from dlt.common.data_writers import TLoaderFileFormat from dlt.common.destination.reference import DestinationReference, TDestinationReferenceArg from dlt.common.pipeline import LoadInfo, PipelineContext, get_dlt_pipelines_dir -from dlt.common.data_writers import TLoaderFileFormat - +from dlt.common.schema import Schema +from dlt.common.schema.typing import TColumnSchema, TWriteDisposition +from dlt.common.typing import Any, TSecretValue from dlt.pipeline.configuration import PipelineConfiguration, ensure_correct_pipeline_kwargs from dlt.pipeline.pipeline import Pipeline -from dlt.pipeline.progress import _from_name as collector_from_name, TCollectorArg, _NULL_COLLECTOR +from dlt.pipeline.progress import _NULL_COLLECTOR, TCollectorArg +from dlt.pipeline.progress import _from_name as collector_from_name @overload @@ -118,7 +117,11 @@ def pipeline( pipelines_dir = get_dlt_pipelines_dir() destination = DestinationReference.from_name(destination or kwargs["destination_name"]) - staging = DestinationReference.from_name(staging or kwargs.get("staging_name", None)) if staging is not None else None + staging = ( + DestinationReference.from_name(staging or kwargs.get("staging_name", None)) + if staging is not None + else None + ) progress = collector_from_name(progress) # create new pipeline instance @@ -136,7 +139,8 @@ def pipeline( progress, False, last_config(**kwargs), - kwargs["runtime"]) + kwargs["runtime"], + ) # set it as current pipeline p.activate() return p @@ -159,7 +163,22 @@ def attach( pipelines_dir = get_dlt_pipelines_dir() progress = collector_from_name(progress) # create new pipeline instance - p = Pipeline(pipeline_name, pipelines_dir, pipeline_salt, None, None, None, credentials, None, None, full_refresh, progress, True, last_config(**kwargs), kwargs["runtime"]) + p = Pipeline( + pipeline_name, + pipelines_dir, + pipeline_salt, + None, + None, + None, + credentials, + None, + None, + full_refresh, + progress, + True, + last_config(**kwargs), + kwargs["runtime"], + ) # set it as current pipeline p.activate() return p @@ -235,11 +254,13 @@ def run( 
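# --- Illustrative usage sketch; not part of the patch ---
# Why: the hunks above reformat the `pipeline()` and `attach()` factories; this shows
# how they are typically called. The pipeline/dataset names and the "duckdb"
# destination are assumptions made for this example only.
import dlt

# create (or reuse) a pipeline working directory and bind it to a destination
pipeline = dlt.pipeline(
    pipeline_name="chess_pipeline",
    destination="duckdb",
    dataset_name="chess_data",
)

# later, re-attach to the same working directory without repeating the arguments
restored = dlt.attach(pipeline_name="chess_pipeline")
# --- end sketch ---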
table_name=table_name, write_disposition=write_disposition, columns=columns, - schema=schema + schema=schema, ) + # plug default tracking module from dlt.pipeline import trace, track + trace.TRACKING_MODULE = track # setup default pipeline in the container diff --git a/dlt/pipeline/configuration.py b/dlt/pipeline/configuration.py index 3d0c70f4b1..f46e868158 100644 --- a/dlt/pipeline/configuration.py +++ b/dlt/pipeline/configuration.py @@ -1,10 +1,10 @@ from typing import Any, Optional from dlt.common.configuration import configspec -from dlt.common.configuration.specs import RunConfiguration, BaseConfiguration +from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration +from dlt.common.data_writers import TLoaderFileFormat from dlt.common.typing import AnyFun, TSecretValue from dlt.common.utils import digest256 -from dlt.common.data_writers import TLoaderFileFormat @configspec diff --git a/dlt/pipeline/current.py b/dlt/pipeline/current.py index f915a30932..b339c37ae3 100644 --- a/dlt/pipeline/current.py +++ b/dlt/pipeline/current.py @@ -1,8 +1,9 @@ """Easy access to active pipelines, state, sources and schemas""" -from dlt.common.pipeline import source_state as _state, resource_state -from dlt.pipeline import pipeline as _pipeline +from dlt.common.pipeline import resource_state +from dlt.common.pipeline import source_state as _state from dlt.extract.decorators import get_source_schema +from dlt.pipeline import pipeline as _pipeline pipeline = _pipeline """Alias for dlt.pipeline""" diff --git a/dlt/pipeline/dbt.py b/dlt/pipeline/dbt.py index 70bd425f12..ddbe49b9fc 100644 --- a/dlt/pipeline/dbt.py +++ b/dlt/pipeline/dbt.py @@ -1,16 +1,22 @@ -import os import contextlib +import os + from dlt.common.exceptions import VenvNotFound from dlt.common.runners import Venv from dlt.common.schema import Schema -from dlt.common.typing import TSecretValue from dlt.common.schema.utils import normalize_schema_name - -from dlt.helpers.dbt import create_venv as _create_venv, package_runner as _package_runner, DBTPackageRunner, DEFAULT_DBT_VERSION as _DEFAULT_DBT_VERSION, restore_venv as _restore_venv +from dlt.common.typing import TSecretValue +from dlt.helpers.dbt import DEFAULT_DBT_VERSION as _DEFAULT_DBT_VERSION +from dlt.helpers.dbt import DBTPackageRunner +from dlt.helpers.dbt import create_venv as _create_venv +from dlt.helpers.dbt import package_runner as _package_runner +from dlt.helpers.dbt import restore_venv as _restore_venv from dlt.pipeline.pipeline import Pipeline -def get_venv(pipeline: Pipeline, venv_path: str = "dbt", dbt_version: str = _DEFAULT_DBT_VERSION) -> Venv: +def get_venv( + pipeline: Pipeline, venv_path: str = "dbt", dbt_version: str = _DEFAULT_DBT_VERSION +) -> Venv: """Creates or restores a virtual environment in which the `dbt` packages are executed. 
The recommended way to execute dbt package is to use a separate virtual environment where only the dbt-core @@ -42,12 +48,12 @@ def get_venv(pipeline: Pipeline, venv_path: str = "dbt", dbt_version: str = _DEF def package( - pipeline: Pipeline, - package_location: str, - package_repository_branch: str = None, - package_repository_ssh_key: TSecretValue = TSecretValue(""), # noqa - auto_full_refresh_when_out_of_sync: bool = None, - venv: Venv = None + pipeline: Pipeline, + package_location: str, + package_repository_branch: str = None, + package_repository_ssh_key: TSecretValue = TSecretValue(""), # noqa + auto_full_refresh_when_out_of_sync: bool = None, + venv: Venv = None, ) -> DBTPackageRunner: """Creates a Python wrapper over `dbt` package present at specified location, that allows to control it (ie. run and test) from Python code. @@ -70,7 +76,11 @@ def package( Returns: DBTPackageRunner: A configured and authenticated Python `dbt` wrapper """ - schema = pipeline.default_schema if pipeline.default_schema_name else Schema(normalize_schema_name(pipeline.dataset_name)) + schema = ( + pipeline.default_schema + if pipeline.default_schema_name + else Schema(normalize_schema_name(pipeline.dataset_name)) + ) job_client = pipeline._sql_job_client(schema) if not venv: venv = Venv.restore_current() @@ -81,5 +91,5 @@ def package( package_location, package_repository_branch, package_repository_ssh_key, - auto_full_refresh_when_out_of_sync + auto_full_refresh_when_out_of_sync, ) diff --git a/dlt/pipeline/exceptions.py b/dlt/pipeline/exceptions.py index 4b283a17e7..aa83d81f70 100644 --- a/dlt/pipeline/exceptions.py +++ b/dlt/pipeline/exceptions.py @@ -1,4 +1,5 @@ from typing import Any + from dlt.common.exceptions import PipelineException from dlt.common.pipeline import SupportsPipeline from dlt.pipeline.typing import TPipelineStep @@ -6,14 +7,24 @@ class InvalidPipelineName(PipelineException, ValueError): def __init__(self, pipeline_name: str, details: str) -> None: - super().__init__(pipeline_name, f"The pipeline name {pipeline_name} contains invalid characters. The pipeline name is used to create a pipeline working directory and must be a valid directory name. The actual error is: {details}") + super().__init__( + pipeline_name, + f"The pipeline name {pipeline_name} contains invalid characters. The pipeline name is" + " used to create a pipeline working directory and must be a valid directory name. 
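# --- Illustrative usage sketch; not part of the patch ---
# Why: the dbt helper module reformatted above exposes `get_venv` and `package`; this
# shows one way to combine them. The package path, pipeline settings and the
# `run_all()` call on DBTPackageRunner are assumptions for this example; only
# `get_venv` and `package` appear in the patch itself.
import dlt
from dlt.pipeline.dbt import get_venv, package

pipeline = dlt.pipeline(
    pipeline_name="chess_pipeline", destination="duckdb", dataset_name="chess_data"
)
venv = get_venv(pipeline)  # create or restore the dedicated dbt virtual environment
dbt_runner = package(pipeline, "path/to/dbt_package", venv=venv)  # wrap the dbt package
models = dbt_runner.run_all()  # assumed runner method; executes the package models
# --- end sketch ---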
The" + f" actual error is: {details}", + ) class PipelineConfigMissing(PipelineException): - def __init__(self, pipeline_name: str, config_elem: str, step: TPipelineStep, _help: str = None) -> None: + def __init__( + self, pipeline_name: str, config_elem: str, step: TPipelineStep, _help: str = None + ) -> None: self.config_elem = config_elem self.step = step - msg = f"Configuration element {config_elem} was not provided and {step} step cannot be executed" + msg = ( + f"Configuration element {config_elem} was not provided and {step} step cannot be" + " executed" + ) if _help: msg += f"\n{_help}\n" super().__init__(pipeline_name, msg) @@ -21,41 +32,67 @@ def __init__(self, pipeline_name: str, config_elem: str, step: TPipelineStep, _h class CannotRestorePipelineException(PipelineException): def __init__(self, pipeline_name: str, pipelines_dir: str, reason: str) -> None: - msg = f"Pipeline with name {pipeline_name} in working directory {pipelines_dir} could not be restored: {reason}" + msg = ( + f"Pipeline with name {pipeline_name} in working directory {pipelines_dir} could not be" + f" restored: {reason}" + ) super().__init__(pipeline_name, msg) class SqlClientNotAvailable(PipelineException): - def __init__(self, pipeline_name: str,destination_name: str) -> None: - super().__init__(pipeline_name, f"SQL Client not available for destination {destination_name} in pipeline {pipeline_name}") + def __init__(self, pipeline_name: str, destination_name: str) -> None: + super().__init__( + pipeline_name, + f"SQL Client not available for destination {destination_name} in pipeline" + f" {pipeline_name}", + ) class PipelineStepFailed(PipelineException): - def __init__(self, pipeline: SupportsPipeline, step: TPipelineStep, exception: BaseException, step_info: Any = None) -> None: + def __init__( + self, + pipeline: SupportsPipeline, + step: TPipelineStep, + exception: BaseException, + step_info: Any = None, + ) -> None: self.pipeline = pipeline self.step = step self.exception = exception self.step_info = step_info - super().__init__(pipeline.pipeline_name, f"Pipeline execution failed at stage {step} with exception:\n\n{type(exception)}\n{exception}") + super().__init__( + pipeline.pipeline_name, + f"Pipeline execution failed at stage {step} with" + f" exception:\n\n{type(exception)}\n{exception}", + ) class PipelineStateEngineNoUpgradePathException(PipelineException): - def __init__(self, pipeline_name: str, init_engine: int, from_engine: int, to_engine: int) -> None: + def __init__( + self, pipeline_name: str, init_engine: int, from_engine: int, to_engine: int + ) -> None: self.init_engine = init_engine self.from_engine = from_engine self.to_engine = to_engine - super().__init__(pipeline_name, f"No engine upgrade path for state in pipeline {pipeline_name} from {init_engine} to {to_engine}, stopped at {from_engine}") + super().__init__( + pipeline_name, + f"No engine upgrade path for state in pipeline {pipeline_name} from {init_engine} to" + f" {to_engine}, stopped at {from_engine}", + ) class PipelineHasPendingDataException(PipelineException): def __init__(self, pipeline_name: str, pipelines_dir: str) -> None: msg = ( - f" Operation failed because pipeline with name {pipeline_name} in working directory {pipelines_dir} contains pending extracted files or load packages. " - "Use `dlt pipeline sync` to reset the local state then run this operation again." 
+ f" Operation failed because pipeline with name {pipeline_name} in working directory" + f" {pipelines_dir} contains pending extracted files or load packages. Use `dlt pipeline" + " sync` to reset the local state then run this operation again." ) super().__init__(pipeline_name, msg) class PipelineNotActive(PipelineException): def __init__(self, pipeline_name: str) -> None: - super().__init__(pipeline_name, f"Pipeline {pipeline_name} is not active so it cannot be deactivated") + super().__init__( + pipeline_name, f"Pipeline {pipeline_name} is not active so it cannot be deactivated" + ) diff --git a/dlt/pipeline/helpers.py b/dlt/pipeline/helpers.py index d913464aa6..a9238b97a9 100644 --- a/dlt/pipeline/helpers.py +++ b/dlt/pipeline/helpers.py @@ -1,22 +1,33 @@ import contextlib -from typing import Callable, Sequence, Iterable, Optional, Any, List, Dict, Tuple, Union, TypedDict from itertools import chain +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypedDict, Union -from dlt.common.jsonpath import resolve_paths, TAnyJsonPath, compile_paths +from dlt.common.destination.reference import WithStagingDataset from dlt.common.exceptions import TerminalException -from dlt.common.schema.utils import group_tables_by_resource, compile_simple_regexes, compile_simple_regex +from dlt.common.jsonpath import TAnyJsonPath, compile_paths, resolve_paths +from dlt.common.pipeline import ( + TSourceState, + _delete_source_state_keys, + _get_matching_resources, + _reset_resource_state, + _sources_state, +) from dlt.common.schema.typing import TSimpleRegex +from dlt.common.schema.utils import ( + compile_simple_regex, + compile_simple_regexes, + group_tables_by_resource, +) from dlt.common.typing import REPattern -from dlt.common.pipeline import TSourceState, _reset_resource_state, _sources_state, _delete_source_state_keys, _get_matching_resources -from dlt.common.destination.reference import WithStagingDataset - from dlt.destinations.exceptions import DatabaseUndefinedRelation -from dlt.pipeline.exceptions import PipelineStepFailed, PipelineHasPendingDataException -from dlt.pipeline.typing import TPipelineStep from dlt.pipeline import Pipeline +from dlt.pipeline.exceptions import PipelineHasPendingDataException, PipelineStepFailed +from dlt.pipeline.typing import TPipelineStep -def retry_load(retry_on_pipeline_steps: Sequence[TPipelineStep] = ("load",)) -> Callable[[BaseException], bool]: +def retry_load( + retry_on_pipeline_steps: Sequence[TPipelineStep] = ("load",) +) -> Callable[[BaseException], bool]: """A retry strategy for Tenacity that, with default setting, will repeat `load` step for all exceptions that are not terminal Use this condition with tenacity `retry_if_exception`. Terminal exceptions are exceptions that will not go away when operations is repeated. @@ -31,12 +42,15 @@ def retry_load(retry_on_pipeline_steps: Sequence[TPipelineStep] = ("load",)) -> retry_on_pipeline_steps (Tuple[TPipelineStep, ...], optional): which pipeline steps are allowed to be repeated. 
Default: "load" """ + def _retry_load(ex: BaseException) -> bool: # do not retry in normalize or extract stages if isinstance(ex, PipelineStepFailed) and ex.step not in retry_on_pipeline_steps: return False # do not retry on terminal exceptions - if isinstance(ex, TerminalException) or (ex.__context__ is not None and isinstance(ex.__context__, TerminalException)): + if isinstance(ex, TerminalException) or ( + ex.__context__ is not None and isinstance(ex.__context__, TerminalException) + ): return False return True @@ -80,14 +94,16 @@ def __init__( resources = set(resources) resource_names = [] if drop_all: - self.resource_pattern = compile_simple_regex(TSimpleRegex('re:.*')) # Match everything + self.resource_pattern = compile_simple_regex(TSimpleRegex("re:.*")) # Match everything elif resources: self.resource_pattern = compile_simple_regexes(TSimpleRegex(r) for r in resources) else: self.resource_pattern = None if self.resource_pattern: - data_tables = {t["name"]: t for t in self.schema.data_tables()} # Don't remove _dlt tables + data_tables = { + t["name"]: t for t in self.schema.data_tables() + } # Don't remove _dlt tables resource_tables = group_tables_by_resource(data_tables, pattern=self.resource_pattern) if self.drop_tables: self.tables_to_drop = list(chain.from_iterable(resource_tables.values())) @@ -102,25 +118,34 @@ def __init__( self.drop_all = drop_all self.info: _DropInfo = dict( - tables=[t['name'] for t in self.tables_to_drop], resource_states=[], state_paths=[], + tables=[t["name"] for t in self.tables_to_drop], + resource_states=[], + state_paths=[], resource_names=resource_names, - schema_name=self.schema.name, dataset_name=self.pipeline.dataset_name, + schema_name=self.schema.name, + dataset_name=self.pipeline.dataset_name, drop_all=drop_all, resource_pattern=self.resource_pattern, - warnings=[] + warnings=[], ) if self.resource_pattern and not resource_tables: - self.info['warnings'].append( - f"Specified resource(s) {str(resources)} did not select any table(s) in schema {self.schema.name}. Possible resources are: {list(group_tables_by_resource(data_tables).keys())}" + self.info["warnings"].append( + f"Specified resource(s) {str(resources)} did not select any table(s) in schema" + f" {self.schema.name}. 
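# --- Illustrative usage sketch; not part of the patch ---
# Why: the `retry_load` docstring above points at tenacity's `retry_if_exception`;
# this shows that wiring. The retry/backoff settings, pipeline configuration and
# sample data are arbitrary values chosen for this example.
from tenacity import Retrying, retry_if_exception, stop_after_attempt, wait_exponential

import dlt
from dlt.pipeline.helpers import retry_load

pipeline = dlt.pipeline(
    pipeline_name="chess_pipeline", destination="duckdb", dataset_name="chess_data"
)
data = [{"id": 1}, {"id": 2}]

# repeat only the `load` step on transient errors; terminal exceptions stop retrying
for attempt in Retrying(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=10),
    retry=retry_if_exception(retry_load(("load",))),
    reraise=True,
):
    with attempt:
        pipeline.run(data, table_name="items")
# --- end sketch ---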
Possible resources are:" + f" {list(group_tables_by_resource(data_tables).keys())}" ) self._new_state = self._create_modified_state() @property def is_empty(self) -> bool: - return len(self.info['tables']) == 0 and len(self.info["state_paths"]) == 0 and len(self.info["resource_states"]) == 0 + return ( + len(self.info["tables"]) == 0 + and len(self.info["state_paths"]) == 0 + and len(self.info["resource_states"]) == 0 + ) def _drop_destination_tables(self) -> None: - table_names = [tbl['name'] for tbl in self.tables_to_drop] + table_names = [tbl["name"] for tbl in self.tables_to_drop] with self.pipeline._sql_job_client(self.schema) as client: client.drop_tables(*table_names) # also delete staging but ignore if staging does not exist @@ -131,7 +156,7 @@ def _drop_destination_tables(self) -> None: def _delete_pipeline_tables(self) -> None: for tbl in self.tables_to_drop: - del self.schema_tables[tbl['name']] + del self.schema_tables[tbl["name"]] self.schema.bump_version() def _list_state_paths(self, source_state: Dict[str, Any]) -> List[str]: @@ -145,13 +170,16 @@ def _create_modified_state(self) -> Dict[str, Any]: for source_name, source_state in source_states: if self.drop_state: for key in _get_matching_resources(self.resource_pattern, source_state): - self.info['resource_states'].append(key) + self.info["resource_states"].append(key) _reset_resource_state(key, source_state) resolved_paths = resolve_paths(self.state_paths_to_drop, source_state) if self.state_paths_to_drop and not resolved_paths: - self.info['warnings'].append(f"State paths {self.state_paths_to_drop} did not select any paths in source {source_name}") + self.info["warnings"].append( + f"State paths {self.state_paths_to_drop} did not select any paths in source" + f" {source_name}" + ) _delete_source_state_keys(resolved_paths, source_state) - self.info['state_paths'].extend(f"{source_name}.{p}" for p in resolved_paths) + self.info["state_paths"].extend(f"{source_name}.{p}" for p in resolved_paths) return state # type: ignore[return-value] def _drop_state_keys(self) -> None: @@ -161,8 +189,12 @@ def _drop_state_keys(self) -> None: state.update(self._new_state) def __call__(self) -> None: - if self.pipeline.has_pending_data: # Raise when there are pending extracted/load files to prevent conflicts - raise PipelineHasPendingDataException(self.pipeline.pipeline_name, self.pipeline.pipelines_dir) + if ( + self.pipeline.has_pending_data + ): # Raise when there are pending extracted/load files to prevent conflicts + raise PipelineHasPendingDataException( + self.pipeline.pipeline_name, self.pipeline.pipelines_dir + ) self.pipeline.sync_destination() if not self.drop_state and not self.drop_tables: @@ -193,6 +225,6 @@ def drop( schema_name: str = None, state_paths: TAnyJsonPath = (), drop_all: bool = False, - state_only: bool = False + state_only: bool = False, ) -> None: return DropCommand(pipeline, resources, schema_name, state_paths, drop_all, state_only)() diff --git a/dlt/pipeline/mark.py b/dlt/pipeline/mark.py index 5f880d8711..e1dbe6e70f 100644 --- a/dlt/pipeline/mark.py +++ b/dlt/pipeline/mark.py @@ -1,2 +1,2 @@ """Module with market functions that make data to be specially processed""" -from dlt.extract.source import with_table_name \ No newline at end of file +from dlt.extract.source import with_table_name diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index b72feb4888..9cf6fe36b3 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -1,60 +1,126 @@ import contextlib -import os import 
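# --- Illustrative usage sketch; not part of the patch ---
# Why: the hunks above reformat DropCommand and the `drop` helper; this shows the
# helper as a user would call it. The resource name and state path are invented for
# this example; `drop` wraps DropCommand, which removes the selected destination
# tables and/or the matching resource state.
import dlt
from dlt.pipeline.helpers import drop

pipeline = dlt.attach(pipeline_name="chess_pipeline")

# drop the destination tables and the stored state of a single resource
drop(pipeline, resources=["players_games"])

# or only clear a specific state path, keeping the tables
drop(pipeline, state_paths=["players_games.checkpoint"], state_only=True)
# --- end sketch ---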
datetime # noqa: 251 +import os +from collections.abc import Sequence as C_Sequence from contextlib import contextmanager from functools import wraps -from collections.abc import Sequence as C_Sequence -from typing import Any, Callable, ClassVar, List, Iterator, Optional, Sequence, Tuple, cast, get_type_hints, ContextManager +from typing import ( + Any, + Callable, + ClassVar, + ContextManager, + Iterator, + List, + Optional, + Sequence, + Tuple, + cast, + get_type_hints, +) from dlt import version from dlt.common import json, logger, pendulum from dlt.common.configuration import inject_section, known_sections -from dlt.common.configuration.specs import RunConfiguration, CredentialsConfiguration from dlt.common.configuration.container import Container -from dlt.common.configuration.exceptions import ConfigFieldMissingException, ContextDefaultCannotBeCreated -from dlt.common.configuration.specs.config_section_context import ConfigSectionContext +from dlt.common.configuration.exceptions import ( + ConfigFieldMissingException, + ContextDefaultCannotBeCreated, +) from dlt.common.configuration.resolve import initialize_credentials -from dlt.common.exceptions import (DestinationLoadingViaStagingNotSupported, DestinationLoadingWithoutStagingNotSupported, DestinationNoStagingMode, - MissingDependencyException, DestinationUndefinedEntity, DestinationIncompatibleLoaderFileFormatException) +from dlt.common.configuration.specs import CredentialsConfiguration, RunConfiguration +from dlt.common.configuration.specs.config_section_context import ConfigSectionContext +from dlt.common.data_writers import TLoaderFileFormat +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.destination.capabilities import INTERNAL_LOADER_FILE_FORMATS +from dlt.common.destination.reference import ( + DestinationClientConfiguration, + DestinationClientDwhConfiguration, + DestinationClientDwhWithStagingConfiguration, + DestinationClientStagingConfiguration, + DestinationReference, + JobClientBase, + TDestinationReferenceArg, +) +from dlt.common.exceptions import ( + DestinationIncompatibleLoaderFileFormatException, + DestinationLoadingViaStagingNotSupported, + DestinationLoadingWithoutStagingNotSupported, + DestinationNoStagingMode, + DestinationUndefinedEntity, + MissingDependencyException, +) from dlt.common.normalizers import explicit_normalizers, import_normalizers -from dlt.common.runtime import signals, initialize_runtime +from dlt.common.pipeline import ( + ExtractInfo, + LoadInfo, + NormalizeInfo, + PipelineContext, + StateInjectableContext, + SupportsPipeline, + TPipelineLocalState, + TPipelineState, +) +from dlt.common.runners import pool_runner as runner +from dlt.common.runtime import initialize_runtime, signals +from dlt.common.schema import Schema from dlt.common.schema.typing import TColumnNames, TColumnSchema, TSchemaTables, TWriteDisposition +from dlt.common.schema.utils import normalize_schema_name +from dlt.common.storages import ( + FileStorage, + LiveSchemaStorage, + LoadStorage, + LoadStorageConfiguration, + NormalizeStorage, + NormalizeStorageConfiguration, + SchemaStorage, + SchemaStorageConfiguration, +) from dlt.common.storages.load_storage import LoadJobInfo, LoadPackageInfo from dlt.common.typing import TFun, TSecretValue, is_optional_type -from dlt.common.runners import pool_runner as runner -from dlt.common.storages import LiveSchemaStorage, NormalizeStorage, LoadStorage, SchemaStorage, FileStorage, NormalizeStorageConfiguration, SchemaStorageConfiguration, 
LoadStorageConfiguration -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import (DestinationClientDwhConfiguration, DestinationReference, JobClientBase, DestinationClientConfiguration, - TDestinationReferenceArg, DestinationClientStagingConfiguration, DestinationClientStagingConfiguration, - DestinationClientDwhWithStagingConfiguration) -from dlt.common.destination.capabilities import INTERNAL_LOADER_FILE_FORMATS -from dlt.common.pipeline import ExtractInfo, LoadInfo, NormalizeInfo, PipelineContext, SupportsPipeline, TPipelineLocalState, TPipelineState, StateInjectableContext -from dlt.common.schema import Schema from dlt.common.utils import is_interactive -from dlt.common.data_writers import TLoaderFileFormat - +from dlt.destinations.job_client_impl import SqlJobClientBase +from dlt.destinations.sql_client import SqlClientBase from dlt.extract.exceptions import DataItemRequiredForDynamicTableHints, SourceExhausted from dlt.extract.extract import ExtractorStorage, extract_with_schema from dlt.extract.source import DltResource, DltSource +from dlt.load import Load +from dlt.load.configuration import LoaderConfiguration from dlt.normalize import Normalize from dlt.normalize.configuration import NormalizeConfiguration -from dlt.destinations.sql_client import SqlClientBase -from dlt.destinations.job_client_impl import SqlJobClientBase -from dlt.load.configuration import LoaderConfiguration -from dlt.load import Load - from dlt.pipeline.configuration import PipelineConfiguration -from dlt.pipeline.progress import _Collector, _NULL_COLLECTOR -from dlt.pipeline.exceptions import CannotRestorePipelineException, InvalidPipelineName, PipelineConfigMissing, PipelineNotActive, PipelineStepFailed, SqlClientNotAvailable -from dlt.pipeline.trace import PipelineTrace, PipelineStepTrace, load_trace, merge_traces, start_trace, start_trace_step, end_trace_step, end_trace, describe_extract_data +from dlt.pipeline.exceptions import ( + CannotRestorePipelineException, + InvalidPipelineName, + PipelineConfigMissing, + PipelineNotActive, + PipelineStepFailed, + SqlClientNotAvailable, +) +from dlt.pipeline.progress import _NULL_COLLECTOR, _Collector +from dlt.pipeline.state_sync import ( + STATE_ENGINE_VERSION, + json_decode_state, + json_encode_state, + load_state_from_destination, + merge_state_if_changed, + migrate_state, + state_resource, +) +from dlt.pipeline.trace import ( + PipelineStepTrace, + PipelineTrace, + describe_extract_data, + end_trace, + end_trace_step, + load_trace, + merge_traces, + start_trace, + start_trace_step, +) from dlt.pipeline.typing import TPipelineStep -from dlt.pipeline.state_sync import STATE_ENGINE_VERSION, load_state_from_destination, merge_state_if_changed, migrate_state, state_resource, json_encode_state, json_decode_state - -from dlt.common.schema.utils import normalize_schema_name def with_state_sync(may_extract_state: bool = False) -> Callable[[TFun], TFun]: - def decorator(f: TFun) -> TFun: @wraps(f) def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: @@ -73,7 +139,6 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: def with_schemas_sync(f: TFun) -> TFun: - @wraps(f) def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: for name in self._schema_storage.live_schemas: @@ -88,7 +153,6 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: def with_runtime_trace(f: TFun) -> TFun: - @wraps(f) def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: trace: PipelineTrace 
= self._trace @@ -116,12 +180,16 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: # if there was a step, finish it end_trace_step(self._trace, trace_step, self, step_info) if is_new_trace: - assert trace is self._trace, f"Messed up trace reference {id(self._trace)} vs {id(trace)}" + assert ( + trace is self._trace + ), f"Messed up trace reference {id(self._trace)} vs {id(trace)}" end_trace(trace, self, self._pipeline_storage.storage_path) finally: # always end trace if is_new_trace: - assert self._trace == trace, f"Messed up trace reference {id(self._trace)} vs {id(trace)}" + assert ( + self._trace == trace + ), f"Messed up trace reference {id(self._trace)} vs {id(trace)}" # if we end new trace that had only 1 step, add it to previous trace # this way we combine several separate calls to extract, normalize, load as single trace # the trace of "run" has many steps and will not be merged @@ -132,13 +200,13 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: def with_config_section(sections: Tuple[str, ...]) -> Callable[[TFun], TFun]: - def decorator(f: TFun) -> TFun: - @wraps(f) def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: # add section context to the container to be used by all configuration without explicit sections resolution - with inject_section(ConfigSectionContext(pipeline_name=self.pipeline_name, sections=sections)): + with inject_section( + ConfigSectionContext(pipeline_name=self.pipeline_name, sections=sections) + ): return f(self, *args, **kwargs) return _wrap # type: ignore @@ -147,7 +215,6 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: class Pipeline(SupportsPipeline): - STATE_FILE: ClassVar[str] = "state.json" STATE_PROPS: ClassVar[List[str]] = list(get_type_hints(TPipelineState).keys()) LOCAL_STATE_PROPS: ClassVar[List[str]] = list(get_type_hints(TPipelineLocalState).keys()) @@ -178,22 +245,22 @@ class Pipeline(SupportsPipeline): runtime_config: RunConfiguration def __init__( - self, - pipeline_name: str, - pipelines_dir: str, - pipeline_salt: TSecretValue, - destination: DestinationReference, - staging: DestinationReference, - dataset_name: str, - credentials: Any, - import_schema_path: str, - export_schema_path: str, - full_refresh: bool, - progress: _Collector, - must_attach_to_local_pipeline: bool, - config: PipelineConfiguration, - runtime: RunConfiguration, - ) -> None: + self, + pipeline_name: str, + pipelines_dir: str, + pipeline_salt: TSecretValue, + destination: DestinationReference, + staging: DestinationReference, + dataset_name: str, + credentials: Any, + import_schema_path: str, + export_schema_path: str, + full_refresh: bool, + progress: _Collector, + must_attach_to_local_pipeline: bool, + config: PipelineConfiguration, + runtime: RunConfiguration, + ) -> None: """Initializes the Pipeline class which implements `dlt` pipeline. Please use `pipeline` function in `dlt` module to create a new Pipeline instance.""" self.pipeline_salt = pipeline_salt self.config = config @@ -248,7 +315,7 @@ def drop(self) -> "Pipeline": self.collector, False, self.config, - self.runtime_config + self.runtime_config, ) @with_runtime_trace @@ -266,7 +333,7 @@ def extract( primary_key: TColumnNames = None, schema: Schema = None, max_parallel_items: int = None, - workers: int = None + workers: int = None, ) -> ExtractInfo: """Extracts the `data` and prepare it for the normalization. Does not require destination or credentials to be configured. 
See `run` method for the arguments' description.""" # create extract storage to which all the sources will be extracted @@ -275,7 +342,15 @@ def extract( try: with self._maybe_destination_capabilities(): # extract all sources - for source in self._data_to_sources(data, schema, table_name, parent_table_name, write_disposition, columns, primary_key): + for source in self._data_to_sources( + data, + schema, + table_name, + parent_table_name, + write_disposition, + columns, + primary_key, + ): if source.exhausted: raise SourceExhausted(source.name) # TODO: merge infos for all the sources @@ -289,15 +364,21 @@ def extract( return ExtractInfo(describe_extract_data(data)) except Exception as exc: # TODO: provide metrics from extractor - raise PipelineStepFailed(self, "extract", exc, ExtractInfo(describe_extract_data(data))) from exc + raise PipelineStepFailed( + self, "extract", exc, ExtractInfo(describe_extract_data(data)) + ) from exc @with_runtime_trace @with_schemas_sync @with_config_section((known_sections.NORMALIZE,)) - def normalize(self, workers: int = 1, loader_file_format: TLoaderFileFormat = None) -> NormalizeInfo: + def normalize( + self, workers: int = 1, loader_file_format: TLoaderFileFormat = None + ) -> NormalizeInfo: """Normalizes the data prepared with `extract` method, infers the schema and creates load packages for the `load` method. Requires `destination` to be known.""" if is_interactive() and workers > 1: - raise NotImplementedError("Do not use normalize workers in interactive mode ie. in notebook") + raise NotImplementedError( + "Do not use normalize workers in interactive mode ie. in notebook" + ) if loader_file_format and loader_file_format in INTERNAL_LOADER_FILE_FORMATS: raise ValueError(f"{loader_file_format} is one of internal dlt file formats.") # check if any schema is present, if not then no data was extracted @@ -312,12 +393,16 @@ def normalize(self, workers: int = 1, loader_file_format: TLoaderFileFormat = No pool_type="none" if workers == 1 else "process", _schema_storage_config=self._schema_storage_config, _normalize_storage_config=self._normalize_storage_config, - _load_storage_config=self._load_storage_config + _load_storage_config=self._load_storage_config, ) # run with destination context with self._maybe_destination_capabilities(loader_file_format=loader_file_format): # shares schema storage with the pipeline so we do not need to install - normalize = Normalize(collector=self.collector, config=normalize_config, schema_storage=self._schema_storage) + normalize = Normalize( + collector=self.collector, + config=normalize_config, + schema_storage=self._schema_storage, + ) try: with signals.delayed_signals(): runner.run_pool(normalize.config, normalize) @@ -336,7 +421,7 @@ def load( credentials: Any = None, *, workers: int = 20, - raise_on_failed_jobs: bool = False + raise_on_failed_jobs: bool = False, ) -> LoadInfo: """Loads the packages prepared by `normalize` method into the `dataset_name` at `destination`, using provided `credentials`""" # set destination and default dataset if provided @@ -355,7 +440,7 @@ def load( load_config = LoaderConfiguration( workers=workers, raise_on_failed_jobs=raise_on_failed_jobs, - _load_storage_config=self._load_storage_config + _load_storage_config=self._load_storage_config, ) load = Load( self.destination, @@ -364,7 +449,7 @@ def load( is_storage_owner=False, config=load_config, initial_client_config=client.config, - initial_staging_client_config=staging_client.config if staging_client else None + 
initial_staging_client_config=staging_client.config if staging_client else None, ) try: with signals.delayed_signals(): @@ -390,7 +475,7 @@ def run( columns: Sequence[TColumnSchema] = None, primary_key: TColumnNames = None, schema: Schema = None, - loader_file_format: TLoaderFileFormat = None + loader_file_format: TLoaderFileFormat = None, ) -> LoadInfo: """Loads the data from `data` argument into the destination specified in `destination` and dataset specified in `dataset_name`. @@ -450,7 +535,12 @@ def run( self._set_dataset_name(dataset_name) # sync state with destination - if self.config.restore_from_destination and not self.full_refresh and not self._state_restored and (self.destination or destination): + if ( + self.config.restore_from_destination + and not self.full_refresh + and not self._state_restored + and (self.destination or destination) + ): self.sync_destination(destination, staging, dataset_name) # sync only once self._state_restored = True @@ -461,19 +551,35 @@ def run( if self.list_normalized_load_packages(): # if there were any pending loads, load them and **exit** if data is not None: - logger.warn("The pipeline `run` method will now load the pending load packages. The data you passed to the run function will not be loaded. In order to do that you must run the pipeline again") + logger.warn( + "The pipeline `run` method will now load the pending load packages. The data" + " you passed to the run function will not be loaded. In order to do that you" + " must run the pipeline again" + ) return self.load(destination, dataset_name, credentials=credentials) # extract from the source if data is not None: - self.extract(data, table_name=table_name, write_disposition=write_disposition, columns=columns, primary_key=primary_key, schema=schema) + self.extract( + data, + table_name=table_name, + write_disposition=write_disposition, + columns=columns, + primary_key=primary_key, + schema=schema, + ) self.normalize(loader_file_format=loader_file_format) return self.load(destination, dataset_name, credentials=credentials) else: return None @with_schemas_sync - def sync_destination(self, destination: TDestinationReferenceArg = None, staging: TDestinationReferenceArg = None, dataset_name: str = None) -> None: + def sync_destination( + self, + destination: TDestinationReferenceArg = None, + staging: TDestinationReferenceArg = None, + dataset_name: str = None, + ) -> None: """Synchronizes pipeline state with the `destination`'s state kept in `dataset_name` ### Summary @@ -500,7 +606,9 @@ def sync_destination(self, destination: TDestinationReferenceArg = None, staging # print(f'REMOTE STATE: {(remote_state or {}).get("_state_version")} >= {state["_state_version"]}') if remote_state and remote_state["_state_version"] >= state["_state_version"]: # compare changes and updates local state - merged_state = merge_state_if_changed(state, remote_state, increase_version=False) + merged_state = merge_state_if_changed( + state, remote_state, increase_version=False + ) # print(f"MERGED STATE: {bool(merged_state)}") if merged_state: # see if state didn't change the pipeline name @@ -508,15 +616,20 @@ def sync_destination(self, destination: TDestinationReferenceArg = None, staging raise CannotRestorePipelineException( state["pipeline_name"], self.pipelines_dir, - f"destination state contains state for pipeline with name {remote_state['pipeline_name']}" + "destination state contains state for pipeline with name" + f" {remote_state['pipeline_name']}", ) # if state was modified force get all schemas - 
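# --- Illustrative usage sketch; not part of the patch ---
# Why: the hunks above touch `run()` and the explicit `extract` / `normalize` / `load`
# steps it orchestrates; this shows both call styles side by side. Pipeline settings
# and sample data are assumptions for this example.
import dlt

pipeline = dlt.pipeline(
    pipeline_name="chess_pipeline", destination="duckdb", dataset_name="chess_data"
)
data = [{"id": 1, "name": "magnus"}]

# one-shot: extract + normalize + load
load_info = pipeline.run(data, table_name="players", write_disposition="append")

# or step by step, e.g. to control normalize workers separately
pipeline.extract(data, table_name="players", write_disposition="append")
pipeline.normalize(workers=1)
load_info = pipeline.load()
# --- end sketch ---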
restored_schemas = self._get_schemas_from_destination(merged_state["schema_names"], always_download=True) + restored_schemas = self._get_schemas_from_destination( + merged_state["schema_names"], always_download=True + ) # TODO: we should probably wipe out pipeline here # if we didn't full refresh schemas, get only missing schemas if restored_schemas is None: - restored_schemas = self._get_schemas_from_destination(state["schema_names"], always_download=False) + restored_schemas = self._get_schemas_from_destination( + state["schema_names"], always_download=False + ) # commit all the changes locally if merged_state: # set the pipeline props from merged state @@ -546,8 +659,11 @@ def sync_destination(self, destination: TDestinationReferenceArg = None, staging # reset pipeline self._wipe_working_folder() state = self._get_state() - self._configure(self._schema_storage_config.export_schema_path, self._schema_storage_config.import_schema_path, False) - + self._configure( + self._schema_storage_config.export_schema_path, + self._schema_storage_config.import_schema_path, + False, + ) # write the state back state = merged_state or state @@ -590,7 +706,12 @@ def deactivate(self) -> None: @property def has_data(self) -> bool: """Tells if the pipeline contains any data: schemas, extracted files, load packages or loaded packages in the destination""" - return not self.first_run or bool(self.schema_names) or len(self.list_extracted_resources()) > 0 or len(self.list_normalized_load_packages()) > 0 + return ( + not self.first_run + or bool(self.schema_names) + or len(self.list_extracted_resources()) > 0 + or len(self.list_normalized_load_packages()) > 0 + ) @property def has_pending_data(self) -> bool: @@ -641,7 +762,13 @@ def list_failed_jobs_in_package(self, load_id: str) -> Sequence[LoadJobInfo]: def sync_schema(self, schema_name: str = None, credentials: Any = None) -> TSchemaTables: """Synchronizes the schema `schema_name` with the destination. If no name is provided, the default schema will be synchronized.""" if not schema_name and not self.default_schema_name: - raise PipelineConfigMissing(self.pipeline_name, "default_schema_name", "load", "Pipeline contains no schemas. Please extract any data with `extract` or `run` methods.") + raise PipelineConfigMissing( + self.pipeline_name, + "default_schema_name", + "load", + "Pipeline contains no schemas. 
Please extract any data with `extract` or `run`" + " methods.", + ) schema = self.schemas[schema_name] if schema_name else self.default_schema client_config = self._get_destination_client_initial_config(credentials) @@ -667,7 +794,7 @@ def get_local_state_val(self, key: str) -> Any: state = self._container[StateInjectableContext].state except ContextDefaultCannotBeCreated: state = self._get_state() - return state["_local"][key] # type: ignore + return state["_local"][key] # type: ignore def sql_client(self, schema_name: str = None, credentials: Any = None) -> SqlClientBase[Any]: """Returns a sql connection configured to query/change the destination and dataset that were used to load the data.""" @@ -681,22 +808,32 @@ def sql_client(self, schema_name: str = None, credentials: Any = None) -> SqlCli if schema_name: schema = self.schemas[schema_name] else: - schema = self.default_schema if self.default_schema_name else Schema(normalize_schema_name(self.pipeline_name)) + schema = ( + self.default_schema + if self.default_schema_name + else Schema(normalize_schema_name(self.pipeline_name)) + ) return self._sql_job_client(schema, credentials).sql_client - def _destination_client(self, schema_name: str = None, credentials: Any = None) -> JobClientBase: + def _destination_client( + self, schema_name: str = None, credentials: Any = None + ) -> JobClientBase: """Get the destination job client for the configured destination""" # TODO: duplicated code from self.sql_client() ... if schema_name: schema = self.schemas[schema_name] else: - schema = self.default_schema if self.default_schema_name else Schema(normalize_schema_name(self.pipeline_name)) + schema = ( + self.default_schema + if self.default_schema_name + else Schema(normalize_schema_name(self.pipeline_name)) + ) client_config = self._get_destination_client_initial_config(credentials) return self._get_destination_clients(schema, client_config)[0] def _sql_job_client(self, schema: Schema, credentials: Any = None) -> SqlJobClientBase: client_config = self._get_destination_client_initial_config(credentials) - client = self._get_destination_clients(schema , client_config)[0] + client = self._get_destination_clients(schema, client_config)[0] if isinstance(client, SqlJobClientBase): return client else: @@ -707,7 +844,12 @@ def _get_normalize_storage(self) -> NormalizeStorage: def _get_load_storage(self) -> LoadStorage: caps = self._get_destination_capabilities() - return LoadStorage(True, caps.preferred_loader_file_format, caps.supported_loader_file_formats, self._load_storage_config) + return LoadStorage( + True, + caps.preferred_loader_file_format, + caps.supported_loader_file_formats, + self._load_storage_config, + ) def _init_working_dir(self, pipeline_name: str, pipelines_dir: str) -> None: self.pipeline_name = pipeline_name @@ -721,21 +863,31 @@ def _init_working_dir(self, pipeline_name: str, pipelines_dir: str) -> None: if self.full_refresh: self._wipe_working_folder() - def _configure(self, import_schema_path: str, export_schema_path: str, must_attach_to_local_pipeline: bool) -> None: + def _configure( + self, import_schema_path: str, export_schema_path: str, must_attach_to_local_pipeline: bool + ) -> None: # create schema storage and folders self._schema_storage_config = SchemaStorageConfiguration( schema_volume_path=os.path.join(self.working_dir, "schemas"), import_schema_path=import_schema_path, - export_schema_path=export_schema_path + export_schema_path=export_schema_path, ) # create default configs - self._normalize_storage_config = 
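# --- Illustrative usage sketch; not part of the patch ---
# Why: the hunk above reformats the `sql_client()` accessor; this shows it used to
# query the loaded dataset. The table name and query are assumptions, and the
# cursor-style access follows the SqlClientBase interface referenced in the patch.
import dlt

pipeline = dlt.attach(pipeline_name="chess_pipeline")

with pipeline.sql_client() as client:
    with client.execute_query("SELECT COUNT(*) FROM players") as cursor:
        print(cursor.fetchall())
# --- end sketch ---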
NormalizeStorageConfiguration(normalize_volume_path=os.path.join(self.working_dir, "normalize")) - self._load_storage_config = LoadStorageConfiguration(load_volume_path=os.path.join(self.working_dir, "load"),) + self._normalize_storage_config = NormalizeStorageConfiguration( + normalize_volume_path=os.path.join(self.working_dir, "normalize") + ) + self._load_storage_config = LoadStorageConfiguration( + load_volume_path=os.path.join(self.working_dir, "load"), + ) # are we running again? has_state = self._pipeline_storage.has_file(Pipeline.STATE_FILE) if must_attach_to_local_pipeline and not has_state: - raise CannotRestorePipelineException(self.pipeline_name, self.pipelines_dir, f"the pipeline was not found in {self.working_dir}.") + raise CannotRestorePipelineException( + self.pipeline_name, + self.pipelines_dir, + f"the pipeline was not found in {self.working_dir}.", + ) self.must_attach_to_local_pipeline = must_attach_to_local_pipeline # attach to pipeline if folder exists and contains state @@ -763,26 +915,32 @@ def _wipe_working_folder(self) -> None: def _attach_pipeline(self) -> None: pass - def _data_to_sources(self, + def _data_to_sources( + self, data: Any, schema: Schema, table_name: str = None, parent_table_name: str = None, write_disposition: TWriteDisposition = None, columns: Sequence[TColumnSchema] = None, - primary_key: TColumnNames = None + primary_key: TColumnNames = None, ) -> List[DltSource]: - def apply_hint_args(resource: DltResource) -> None: columns_dict = None if columns: - columns_dict = {c["name"]:c for c in columns} + columns_dict = {c["name"]: c for c in columns} # apply hints only if any of the hints is present, table_name must be always present if table_name or parent_table_name or write_disposition or columns or primary_key: resource_table_name: str = None with contextlib.suppress(DataItemRequiredForDynamicTableHints): resource_table_name = resource.table_name - resource.apply_hints(table_name or resource_table_name or resource.name, parent_table_name, write_disposition, columns_dict, primary_key) + resource.apply_hints( + table_name or resource_table_name or resource.name, + parent_table_name, + write_disposition, + columns_dict, + primary_key, + ) def choose_schema() -> Schema: """Except of explicitly passed schema, use a clone that will get discarded if extraction fails""" @@ -812,12 +970,19 @@ def append_data(data_item: Any) -> None: # apply hints apply_hint_args(data_item) sources.append( - DltSource(effective_schema.name, data_item.section or self.pipeline_name, effective_schema, [data_item]) + DltSource( + effective_schema.name, + data_item.section or self.pipeline_name, + effective_schema, + [data_item], ) + ) else: # iterator/iterable/generator # create resource first without table template - resource = DltResource.from_data(data_item, name=table_name, section=self.pipeline_name) + resource = DltResource.from_data( + data_item, name=table_name, section=self.pipeline_name + ) # apply hints apply_hint_args(resource) resources.append(resource) @@ -834,16 +999,22 @@ def append_data(data_item: Any) -> None: if resources: # add all the appended resources in one source - sources.append(DltSource(effective_schema.name, self.pipeline_name, effective_schema, resources)) + sources.append( + DltSource(effective_schema.name, self.pipeline_name, effective_schema, resources) + ) return sources - def _extract_source(self, storage: ExtractorStorage, source: DltSource, max_parallel_items: int, workers: int) -> str: + def _extract_source( + self, storage: 
ExtractorStorage, source: DltSource, max_parallel_items: int, workers: int + ) -> str: # discover the schema from source source_schema = source.schema source_schema.update_normalizers() - extract_id = extract_with_schema(storage, source, source_schema, self.collector, max_parallel_items, workers) + extract_id = extract_with_schema( + storage, source, source_schema, self.collector, max_parallel_items, workers + ) # if source schema does not exist in the pipeline if source_schema.name not in self._schema_storage: @@ -867,14 +1038,20 @@ def _extract_source(self, storage: ExtractorStorage, source: DltSource, max_para return extract_id - def _get_destination_client_initial_config(self, destination: DestinationReference = None, credentials: Any = None, as_staging: bool = False) -> DestinationClientConfiguration: + def _get_destination_client_initial_config( + self, + destination: DestinationReference = None, + credentials: Any = None, + as_staging: bool = False, + ) -> DestinationClientConfiguration: destination = destination or self.destination if not destination: raise PipelineConfigMissing( self.pipeline_name, "destination", "load", - "Please provide `destination` argument to `pipeline`, `run` or `load` method directly or via .dlt config.toml file or environment variable." + "Please provide `destination` argument to `pipeline`, `run` or `load` method" + " directly or via .dlt config.toml file or environment variable.", ) # create initial destination client config client_spec = destination.spec() @@ -885,27 +1062,41 @@ def _get_destination_client_initial_config(self, destination: DestinationReferen if credentials is not None and not isinstance(credentials, CredentialsConfiguration): # use passed credentials as initial value. initial value may resolve credentials credentials = initialize_credentials( - client_spec.get_resolvable_fields()["credentials"], - credentials + client_spec.get_resolvable_fields()["credentials"], credentials ) # this client support many schemas and datasets if issubclass(client_spec, DestinationClientDwhConfiguration): if not self.dataset_name and self.full_refresh: - logger.warning("Full refresh may not work if dataset name is not set. Please set the dataset_name argument in dlt.pipeline or run method") + logger.warning( + "Full refresh may not work if dataset name is not set. 
Please set the" + " dataset_name argument in dlt.pipeline or run method" + ) # set default schema name to load all incoming data to a single dataset, no matter what is the current schema name - default_schema_name = None if self.config.use_single_dataset else self.default_schema_name + default_schema_name = ( + None if self.config.use_single_dataset else self.default_schema_name + ) if issubclass(client_spec, DestinationClientStagingConfiguration): - return client_spec(dataset_name=self.dataset_name, default_schema_name=default_schema_name, credentials=credentials, as_staging=as_staging) - return client_spec(dataset_name=self.dataset_name, default_schema_name=default_schema_name, credentials=credentials) + return client_spec( + dataset_name=self.dataset_name, + default_schema_name=default_schema_name, + credentials=credentials, + as_staging=as_staging, + ) + return client_spec( + dataset_name=self.dataset_name, + default_schema_name=default_schema_name, + credentials=credentials, + ) return client_spec(credentials=credentials) - def _get_destination_clients(self, + def _get_destination_clients( + self, schema: Schema, initial_config: DestinationClientConfiguration = None, - initial_staging_config: DestinationClientConfiguration = None + initial_staging_config: DestinationClientConfiguration = None, ) -> Tuple[JobClientBase, JobClientBase]: try: # resolve staging config in order to pass it to destination client config @@ -913,14 +1104,20 @@ def _get_destination_clients(self, if self.staging: if not initial_staging_config: # this is just initial config - without user configuration injected - initial_staging_config = self._get_destination_client_initial_config(self.staging, as_staging=True) + initial_staging_config = self._get_destination_client_initial_config( + self.staging, as_staging=True + ) # create the client - that will also resolve the config staging_client = self.staging.client(schema, initial_staging_config) if not initial_config: # config is not provided then get it with injected credentials initial_config = self._get_destination_client_initial_config(self.destination) # attach the staging client config to destination client config - if its type supports it - if self.staging and isinstance(initial_config, DestinationClientDwhWithStagingConfiguration) and isinstance(staging_client.config ,DestinationClientStagingConfiguration): + if ( + self.staging + and isinstance(initial_config, DestinationClientDwhWithStagingConfiguration) + and isinstance(staging_client.config, DestinationClientStagingConfiguration) + ): initial_config.staging_config = staging_client.config # create instance with initial_config properly set client = self.destination.client(schema, initial_config) @@ -930,17 +1127,18 @@ def _get_destination_clients(self, raise MissingDependencyException( f"{client_spec.destination_name} destination", [f"{version.DLT_PKG_NAME}[{client_spec.destination_name}]"], - "Dependencies for specific destinations are available as extras of dlt" + "Dependencies for specific destinations are available as extras of dlt", ) def _get_destination_capabilities(self) -> DestinationCapabilitiesContext: if not self.destination: - raise PipelineConfigMissing( - self.pipeline_name, - "destination", - "normalize", - "Please provide `destination` argument to `pipeline`, `run` or `load` method directly or via .dlt config.toml file or environment variable." 
- ) + raise PipelineConfigMissing( + self.pipeline_name, + "destination", + "normalize", + "Please provide `destination` argument to `pipeline`, `run` or `load` method" + " directly or via .dlt config.toml file or environment variable.", + ) return self.destination.capabilities() def _get_staging_capabilities(self) -> DestinationCapabilitiesContext: @@ -966,23 +1164,36 @@ def _set_context(self, is_active: bool) -> None: # set destination context on activation if self.destination: # inject capabilities context - self._container[DestinationCapabilitiesContext] = self._get_destination_capabilities() + self._container[DestinationCapabilitiesContext] = ( + self._get_destination_capabilities() + ) else: # remove destination context on deactivation if DestinationCapabilitiesContext in self._container: del self._container[DestinationCapabilitiesContext] - def _set_destinations(self, destination: TDestinationReferenceArg, staging: TDestinationReferenceArg) -> None: + def _set_destinations( + self, destination: TDestinationReferenceArg, staging: TDestinationReferenceArg + ) -> None: destination_mod = DestinationReference.from_name(destination) self.destination = destination_mod or self.destination - if destination and not self.destination.capabilities().supported_loader_file_formats and not staging: - logger.warning(f"The destination {destination_mod.__name__} requires the filesystem staging destination to be set, but it was not provided. Setting it to 'filesystem'.") + if ( + destination + and not self.destination.capabilities().supported_loader_file_formats + and not staging + ): + logger.warning( + f"The destination {destination_mod.__name__} requires the filesystem staging" + " destination to be set, but it was not provided. Setting it to 'filesystem'." + ) staging = "filesystem" if staging: staging_module = DestinationReference.from_name(staging) - if staging_module and not issubclass(staging_module.spec(), DestinationClientStagingConfiguration): + if staging_module and not issubclass( + staging_module.spec(), DestinationClientStagingConfiguration + ): raise DestinationNoStagingMode(staging_module.__name__) self.staging = staging_module or self.staging @@ -991,7 +1202,9 @@ def _set_destinations(self, destination: TDestinationReferenceArg, staging: TDes self._set_default_normalizers() @contextmanager - def _maybe_destination_capabilities(self, loader_file_format: TLoaderFileFormat = None) -> Iterator[DestinationCapabilitiesContext]: + def _maybe_destination_capabilities( + self, loader_file_format: TLoaderFileFormat = None + ) -> Iterator[DestinationCapabilitiesContext]: try: caps: DestinationCapabilitiesContext = None injected_caps: ContextManager[DestinationCapabilitiesContext] = None @@ -1004,7 +1217,10 @@ def _maybe_destination_capabilities(self, loader_file_format: TLoaderFileFormat caps.preferred_loader_file_format = self._resolve_loader_file_format( DestinationReference.to_name(self.destination), DestinationReference.to_name(self.staging) if self.staging else None, - destination_caps, stage_caps, loader_file_format) + destination_caps, + stage_caps, + loader_file_format, + ) yield caps finally: if injected_caps: @@ -1012,17 +1228,21 @@ def _maybe_destination_capabilities(self, loader_file_format: TLoaderFileFormat @staticmethod def _resolve_loader_file_format( - destination: str, - staging: str, - dest_caps: DestinationCapabilitiesContext, - stage_caps: DestinationCapabilitiesContext, - file_format: TLoaderFileFormat) -> TLoaderFileFormat: - + destination: str, + staging: str, + 
dest_caps: DestinationCapabilitiesContext, + stage_caps: DestinationCapabilitiesContext, + file_format: TLoaderFileFormat, + ) -> TLoaderFileFormat: possible_file_formats = dest_caps.supported_loader_file_formats if stage_caps: if not dest_caps.supported_staging_file_formats: raise DestinationLoadingViaStagingNotSupported(destination) - possible_file_formats = [f for f in dest_caps.supported_staging_file_formats if f in stage_caps.supported_loader_file_formats] + possible_file_formats = [ + f + for f in dest_caps.supported_staging_file_formats + if f in stage_caps.supported_loader_file_formats + ] if not file_format: if not stage_caps: if not dest_caps.preferred_loader_file_format: @@ -1033,7 +1253,12 @@ def _resolve_loader_file_format( else: file_format = possible_file_formats[0] if len(possible_file_formats) > 0 else None if file_format not in possible_file_formats: - raise DestinationIncompatibleLoaderFileFormatException(destination, staging, file_format, set(possible_file_formats) - INTERNAL_LOADER_FILE_FORMATS) + raise DestinationIncompatibleLoaderFileFormatException( + destination, + staging, + file_format, + set(possible_file_formats) - INTERNAL_LOADER_FILE_FORMATS, + ) return file_format def _set_default_normalizers(self) -> None: @@ -1047,7 +1272,9 @@ def _set_dataset_name(self, new_dataset_name: str) -> None: fields = self.destination.spec().get_resolvable_fields() dataset_name_type = fields.get("dataset_name") # if dataset is required (default!) we create a default dataset name - destination_needs_dataset = dataset_name_type is not None and not is_optional_type(dataset_name_type) + destination_needs_dataset = dataset_name_type is not None and not is_optional_type( + dataset_name_type + ) # if destination is not specified - generate dataset if not self.destination or destination_needs_dataset: new_dataset_name = self.pipeline_name + self.DEFAULT_DATASET_SUFFIX @@ -1090,14 +1317,14 @@ def _get_load_info(self, load: Load) -> LoadInfo: def _get_state(self) -> TPipelineState: try: state = json_decode_state(self._pipeline_storage.load(Pipeline.STATE_FILE)) - return migrate_state(self.pipeline_name, state, state["_state_engine_version"], STATE_ENGINE_VERSION) + return migrate_state( + self.pipeline_name, state, state["_state_engine_version"], STATE_ENGINE_VERSION + ) except FileNotFoundError: return { "_state_version": 0, "_state_engine_version": STATE_ENGINE_VERSION, - "_local": { - "first_run": True - } + "_local": {"first_run": True}, } def _optional_sql_job_client(self, schema_name: str) -> Optional[SqlJobClientBase]: @@ -1114,7 +1341,9 @@ def _optional_sql_job_client(self, schema_name: str) -> Optional[SqlJobClientBas logger.info("Client not available due to missing credentials") return None - def _restore_state_from_destination(self, raise_on_connection_error: bool = True) -> Optional[TPipelineState]: + def _restore_state_from_destination( + self, raise_on_connection_error: bool = True + ) -> Optional[TPipelineState]: # if state is not present locally, take the state from the destination dataset_name = self.dataset_name use_single_dataset = self.config.use_single_dataset @@ -1139,9 +1368,15 @@ def _restore_state_from_destination(self, raise_on_connection_error: bool = True job_client.sql_client.close_connection() if state is None: - logger.info(f"The state was not found in the destination {self.destination.__name__}:{dataset_name}") + logger.info( + "The state was not found in the destination" + f" {self.destination.__name__}:{dataset_name}" + ) else: - logger.info(f"The state 
was restored from the destination {self.destination.__name__}:{dataset_name}") + logger.info( + "The state was restored from the destination" + f" {self.destination.__name__}:{dataset_name}" + ) return state else: @@ -1150,7 +1385,9 @@ def _restore_state_from_destination(self, raise_on_connection_error: bool = True # restore the use_single_dataset option self.config.use_single_dataset = use_single_dataset - def _get_schemas_from_destination(self, schema_names: Sequence[str], always_download: bool = False) -> Sequence[Schema]: + def _get_schemas_from_destination( + self, schema_names: Sequence[str], always_download: bool = False + ) -> Sequence[Schema]: # check which schemas are present in the pipeline and restore missing schemas restored_schemas: List[Schema] = [] for schema_name in schema_names: @@ -1159,13 +1396,20 @@ def _get_schemas_from_destination(self, schema_names: Sequence[str], always_down with self._optional_sql_job_client(schema_name) as job_client: schema_info = job_client.get_newest_schema_from_storage() if schema_info is None: - logger.info(f"The schema {schema_name} was not found in the destination {self.destination.__name__}:{job_client.sql_client.dataset_name}.") + logger.info( + f"The schema {schema_name} was not found in the destination" + f" {self.destination.__name__}:{job_client.sql_client.dataset_name}." + ) # try to import schema with contextlib.suppress(FileNotFoundError): self._schema_storage.load_schema(schema_name) else: schema = Schema.from_dict(json.loads(schema_info.schema)) - logger.info(f"The schema {schema_name} version {schema.version} hash {schema.stored_version_hash} was restored from the destination {self.destination.__name__}:{job_client.sql_client.dataset_name}") + logger.info( + f"The schema {schema_name} version {schema.version} hash" + f" {schema.stored_version_hash} was restored from the destination" + f" {self.destination.__name__}:{job_client.sql_client.dataset_name}" + ) restored_schemas.append(schema) return restored_schemas @@ -1222,7 +1466,10 @@ def _state_to_props(self, state: TPipelineState) -> None: if prop in state["_local"] and not prop.startswith("_"): setattr(self, prop, state["_local"][prop]) # type: ignore if "destination" in state: - self._set_destinations(DestinationReference.from_name(self.destination), DestinationReference.from_name(self.staging) if "staging" in state else None ) + self._set_destinations( + DestinationReference.from_name(self.destination), + DestinationReference.from_name(self.staging) if "staging" in state else None, + ) def _props_to_state(self, state: TPipelineState) -> None: """Write pipeline props to `state`""" @@ -1244,9 +1491,16 @@ def _save_state(self, state: TPipelineState) -> None: def _extract_state(self, state: TPipelineState) -> TPipelineState: # this will extract the state into current load package and update the schema with the _dlt_pipeline_state table # note: the schema will be persisted because the schema saving decorator is over the state manager decorator for extract - state_source = DltSource(self.default_schema.name, self.pipeline_name, self.default_schema, [state_resource(state)]) + state_source = DltSource( + self.default_schema.name, + self.pipeline_name, + self.default_schema, + [state_resource(state)], + ) storage = ExtractorStorage(self._normalize_storage_config) - extract_id = extract_with_schema(storage, state_source, self.default_schema, _NULL_COLLECTOR, 1, 1) + extract_id = extract_with_schema( + storage, state_source, self.default_schema, _NULL_COLLECTOR, 1, 1 + ) 
storage.commit_extract_files(extract_id) return state diff --git a/dlt/pipeline/progress.py b/dlt/pipeline/progress.py index 90fc192bb1..50822fcec0 100644 --- a/dlt/pipeline/progress.py +++ b/dlt/pipeline/progress.py @@ -1,12 +1,17 @@ """Measure the extract, normalize and load progress""" -from typing import Union, Literal +from typing import Literal, Union -from dlt.common.runtime.collector import TqdmCollector as tqdm, LogCollector as log, EnlightenCollector as enlighten, AliveCollector as alive_progress -from dlt.common.runtime.collector import Collector as _Collector, NULL_COLLECTOR as _NULL_COLLECTOR +from dlt.common.runtime.collector import NULL_COLLECTOR as _NULL_COLLECTOR +from dlt.common.runtime.collector import AliveCollector as alive_progress +from dlt.common.runtime.collector import Collector as _Collector +from dlt.common.runtime.collector import EnlightenCollector as enlighten +from dlt.common.runtime.collector import LogCollector as log +from dlt.common.runtime.collector import TqdmCollector as tqdm TSupportedCollectors = Literal["tqdm", "enlighten", "log", "alive_progress"] TCollectorArg = Union[_Collector, TSupportedCollectors] + def _from_name(collector: TCollectorArg) -> _Collector: """Create default collector by name""" if collector is None: diff --git a/dlt/pipeline/state_sync.py b/dlt/pipeline/state_sync.py index ade6a83ddc..9677b3602c 100644 --- a/dlt/pipeline/state_sync.py +++ b/dlt/pipeline/state_sync.py @@ -1,22 +1,17 @@ import binascii from typing import Any, Optional, cast -import binascii import pendulum import dlt - from dlt.common import json from dlt.common.pipeline import TPipelineState -from dlt.common.typing import DictStrAny from dlt.common.schema.typing import LOADS_TABLE_NAME, TTableSchemaColumns - +from dlt.common.typing import DictStrAny +from dlt.common.utils import compressed_b64decode, compressed_b64encode from dlt.destinations.sql_client import SqlClientBase from dlt.extract.source import DltResource - from dlt.pipeline.exceptions import PipelineStateEngineNoUpgradePathException -from dlt.common.utils import compressed_b64decode, compressed_b64encode - # allows to upgrade state when restored with a new version of state logic/schema STATE_ENGINE_VERSION = 2 @@ -24,31 +19,11 @@ STATE_TABLE_NAME = "_dlt_pipeline_state" # state table columns STATE_TABLE_COLUMNS: TTableSchemaColumns = { - "version": { - "name": "version", - "data_type": "bigint", - "nullable": False - }, - "engine_version": { - "name": "engine_version", - "data_type": "bigint", - "nullable": False - }, - "pipeline_name": { - "name": "pipeline_name", - "data_type": "text", - "nullable": False - }, - "state": { - "name": "state", - "data_type": "text", - "nullable": False - }, - "created_at": { - "name": "created_at", - "data_type": "timestamp", - "nullable": False - } + "version": {"name": "version", "data_type": "bigint", "nullable": False}, + "engine_version": {"name": "engine_version", "data_type": "bigint", "nullable": False}, + "pipeline_name": {"name": "pipeline_name", "data_type": "text", "nullable": False}, + "state": {"name": "state", "data_type": "text", "nullable": False}, + "created_at": {"name": "created_at", "data_type": "timestamp", "nullable": False}, } @@ -73,7 +48,9 @@ def decompress_state(state_str: str) -> DictStrAny: return json.typed_loadb(state_bytes) # type: ignore[no-any-return] -def merge_state_if_changed(old_state: TPipelineState, new_state: TPipelineState, increase_version: bool = True) -> Optional[TPipelineState]: +def merge_state_if_changed( + old_state: 
TPipelineState, new_state: TPipelineState, increase_version: bool = True +) -> Optional[TPipelineState]: # we may want to compare hashes like we do with schemas if json.dumps(old_state, sort_keys=True) == json.dumps(new_state, sort_keys=True): return None @@ -90,17 +67,24 @@ def state_resource(state: TPipelineState) -> DltResource: "version": state["_state_version"], "engine_version": state["_state_engine_version"], "pipeline_name": state["pipeline_name"], - "state": state_str, - "created_at": pendulum.now() + "state": state_str, + "created_at": pendulum.now(), } - return dlt.resource([state_doc], name=STATE_TABLE_NAME, write_disposition="append", columns=STATE_TABLE_COLUMNS) + return dlt.resource( + [state_doc], name=STATE_TABLE_NAME, write_disposition="append", columns=STATE_TABLE_COLUMNS + ) -def load_state_from_destination(pipeline_name: str, sql_client: SqlClientBase[Any]) -> TPipelineState: +def load_state_from_destination( + pipeline_name: str, sql_client: SqlClientBase[Any] +) -> TPipelineState: # NOTE: if dataset or table holding state does not exist, the sql_client will raise DestinationUndefinedEntity. caller must handle this # TODO: this must go into job client and STATE_TABLE_NAME + LOADS_TABLE_NAME must get normalized before using in the query - query = f"SELECT state FROM {STATE_TABLE_NAME} AS s JOIN {LOADS_TABLE_NAME} AS l ON l.load_id = s._dlt_load_id WHERE pipeline_name = %s AND l.status = 0 ORDER BY created_at DESC" + query = ( + f"SELECT state FROM {STATE_TABLE_NAME} AS s JOIN {LOADS_TABLE_NAME} AS l ON l.load_id =" + " s._dlt_load_id WHERE pipeline_name = %s AND l.status = 0 ORDER BY created_at DESC" + ) with sql_client.execute_query(query, pipeline_name) as cur: row = cur.fetchone() if not row: @@ -110,7 +94,9 @@ def load_state_from_destination(pipeline_name: str, sql_client: SqlClientBase[An return migrate_state(pipeline_name, s, s["_state_engine_version"], STATE_ENGINE_VERSION) -def migrate_state(pipeline_name: str, state: DictStrAny, from_engine: int, to_engine: int) -> TPipelineState: +def migrate_state( + pipeline_name: str, state: DictStrAny, from_engine: int, to_engine: int +) -> TPipelineState: if from_engine == to_engine: return cast(TPipelineState, state) if from_engine == 1 and to_engine > 1: @@ -120,6 +106,8 @@ def migrate_state(pipeline_name: str, state: DictStrAny, from_engine: int, to_en # check state engine state["_state_engine_version"] = from_engine if from_engine != to_engine: - raise PipelineStateEngineNoUpgradePathException(pipeline_name, state["_state_engine_version"], from_engine, to_engine) + raise PipelineStateEngineNoUpgradePathException( + pipeline_name, state["_state_engine_version"], from_engine, to_engine + ) return cast(TPipelineState, state) diff --git a/dlt/pipeline/trace.py b/dlt/pipeline/trace.py index 53a1c20a45..5a7bca203c 100644 --- a/dlt/pipeline/trace.py +++ b/dlt/pipeline/trace.py @@ -1,30 +1,31 @@ +import dataclasses +import datetime # noqa: 251 import os import pickle -import datetime # noqa: 251 -import dataclasses from collections.abc import Sequence as C_Sequence -from typing import Any, List, Tuple, NamedTuple, Optional, Protocol, Sequence +from typing import Any, List, NamedTuple, Optional, Protocol, Sequence, Tuple + import humanize from dlt.common import pendulum -from dlt.common.runtime.logger import suppress_and_warn from dlt.common.configuration import is_secret_hint from dlt.common.configuration.utils import _RESOLVED_TRACES from dlt.common.pipeline import ExtractDataInfo, SupportsPipeline +from
dlt.common.runtime.logger import suppress_and_warn from dlt.common.typing import StrAny from dlt.common.utils import uniq_id - from dlt.extract.source import DltResource, DltSource -from dlt.pipeline.typing import TPipelineStep from dlt.pipeline.exceptions import PipelineStepFailed - +from dlt.pipeline.typing import TPipelineStep TRACE_ENGINE_VERSION = 1 TRACE_FILE_NAME = "trace.pickle" + # @dataclasses.dataclass(init=True) class SerializableResolvedValueTrace(NamedTuple): """Information on resolved secret and config values""" + key: str value: Any default_value: Any @@ -35,7 +36,7 @@ class SerializableResolvedValueTrace(NamedTuple): def asdict(self) -> StrAny: """A dictionary representation that is safe to load.""" - return {k:v for k,v in self._asdict().items() if k not in ("value", "default_value")} + return {k: v for k, v in self._asdict().items() if k not in ("value", "default_value")} def asstr(self, verbosity: int = 0) -> str: return f"{self.key}->{self.value} in {'.'.join(self.sections)} by {self.provider_name}" @@ -47,6 +48,7 @@ def __str__(self) -> str: @dataclasses.dataclass(init=True) class PipelineStepTrace: """Trace of particular pipeline step, contains timing information, the step outcome info or exception in case of failing step""" + span_id: str step: TPipelineStep started_at: datetime.datetime @@ -81,6 +83,7 @@ def __str__(self) -> str: @dataclasses.dataclass(init=True) class PipelineTrace: """Pipeline runtime trace containing data on "extract", "normalize" and "load" steps and resolved config and secret values.""" + transaction_id: str started_at: datetime.datetime steps: List[PipelineStepTrace] @@ -98,7 +101,10 @@ def asstr(self, verbosity: int = 0) -> str: elapsed_str = humanize.precisedelta(elapsed) else: elapsed_str = "---" - msg = f"Run started at {self.started_at} and {completed_str} in {elapsed_str} with {len(self.steps)} steps." + msg = ( + f"Run started at {self.started_at} and {completed_str} in {elapsed_str} with" + f" {len(self.steps)} steps." + ) if verbosity > 0 and len(self.resolved_config_values) > 0: msg += "\nFollowing config and secret values were resolved:\n" msg += "\n".join([s.asstr(verbosity) for s in self.resolved_config_values]) @@ -112,13 +118,23 @@ def __str__(self) -> str: class SupportsTracking(Protocol): - def on_start_trace(self, trace: PipelineTrace, step: TPipelineStep, pipeline: SupportsPipeline) -> None: + def on_start_trace( + self, trace: PipelineTrace, step: TPipelineStep, pipeline: SupportsPipeline + ) -> None: ... - def on_start_trace_step(self, trace: PipelineTrace, step: TPipelineStep, pipeline: SupportsPipeline) -> None: + def on_start_trace_step( + self, trace: PipelineTrace, step: TPipelineStep, pipeline: SupportsPipeline + ) -> None: ... - def on_end_trace_step(self, trace: PipelineTrace, step: PipelineStepTrace, pipeline: SupportsPipeline, step_info: Any) -> None: + def on_end_trace_step( + self, + trace: PipelineTrace, + step: PipelineStepTrace, + pipeline: SupportsPipeline, + step_info: Any, + ) -> None: ... 
def on_end_trace(self, trace: PipelineTrace, pipeline: SupportsPipeline) -> None: @@ -137,14 +153,18 @@ def start_trace(step: TPipelineStep, pipeline: SupportsPipeline) -> PipelineTrac return trace -def start_trace_step(trace: PipelineTrace, step: TPipelineStep, pipeline: SupportsPipeline) -> PipelineStepTrace: +def start_trace_step( + trace: PipelineTrace, step: TPipelineStep, pipeline: SupportsPipeline +) -> PipelineStepTrace: trace_step = PipelineStepTrace(uniq_id(), step, pendulum.now()) with suppress_and_warn(): TRACKING_MODULE.on_start_trace_step(trace, step, pipeline) return trace_step -def end_trace_step(trace: PipelineTrace, step: PipelineStepTrace, pipeline: SupportsPipeline, step_info: Any) -> None: +def end_trace_step( + trace: PipelineTrace, step: PipelineStepTrace, pipeline: SupportsPipeline, step_info: Any +) -> None: # saves runtime trace of the pipeline if isinstance(step_info, PipelineStepFailed): step_exception = str(step_info) @@ -162,15 +182,18 @@ def end_trace_step(trace: PipelineTrace, step: PipelineStepTrace, pipeline: Supp step.step_exception = step_exception step.step_info = step_info - resolved_values = map(lambda v: SerializableResolvedValueTrace( + resolved_values = map( + lambda v: SerializableResolvedValueTrace( v.key, v.value, v.default_value, is_secret_hint(v.hint), v.sections, v.provider_name, - str(type(v.config).__qualname__) - ) , _RESOLVED_TRACES.values()) + str(type(v.config).__qualname__), + ), + _RESOLVED_TRACES.values(), + ) trace.resolved_config_values = list(resolved_values) trace.steps.append(step) @@ -222,17 +245,16 @@ def describe_extract_data(data: Any) -> List[ExtractDataInfo]: def add_item(item: Any) -> bool: if isinstance(item, (DltResource, DltSource)): # record names of sources/resources - data_info.append({ - "name": item.name, - "data_type": "resource" if isinstance(item, DltResource) else "source" - }) + data_info.append( + { + "name": item.name, + "data_type": "resource" if isinstance(item, DltResource) else "source", + } + ) return False else: # anything else - data_info.append({ - "name": "", - "data_type": type(item).__name__ - }) + data_info.append({"name": "", "data_type": type(item).__name__}) return True item: Any = data diff --git a/dlt/pipeline/track.py b/dlt/pipeline/track.py index 8d3e9bfb98..621a946311 100644 --- a/dlt/pipeline/track.py +++ b/dlt/pipeline/track.py @@ -1,18 +1,17 @@ """Implements SupportsTracking""" import contextlib from typing import Any + import humanize -from dlt.common import pendulum -from dlt.common import logger +from dlt.common import logger, pendulum +from dlt.common.destination import DestinationReference +from dlt.common.pipeline import ExtractInfo, LoadInfo, SupportsPipeline from dlt.common.runtime.exec_info import github_info from dlt.common.runtime.segment import track as dlthub_telemetry_track from dlt.common.runtime.slack import send_slack_message -from dlt.common.pipeline import LoadInfo, ExtractInfo, SupportsPipeline -from dlt.common.destination import DestinationReference - +from dlt.pipeline.trace import PipelineStepTrace, PipelineTrace from dlt.pipeline.typing import TPipelineStep -from dlt.pipeline.trace import PipelineTrace, PipelineStepTrace try: from sentry_sdk import Hub @@ -24,6 +23,7 @@ def _add_sentry_tags(span: Span, pipeline: SupportsPipeline) -> None: span.set_tag("destination", pipeline.destination.__name__) if pipeline.dataset_name: span.set_tag("dataset_name", pipeline.dataset_name) + except ImportError: # sentry is optional dependency and enabled only when 
RuntimeConfiguration.sentry_dsn is set pass @@ -67,7 +67,9 @@ def on_start_trace(trace: PipelineTrace, step: TPipelineStep, pipeline: Supports transaction.__enter__() -def on_start_trace_step(trace: PipelineTrace, step: TPipelineStep, pipeline: SupportsPipeline) -> None: +def on_start_trace_step( + trace: PipelineTrace, step: TPipelineStep, pipeline: SupportsPipeline +) -> None: if pipeline.runtime_config.sentry_dsn: # print(f"START SENTRY SPAN {trace.transaction_id}:{trace_step.span_id} SCOPE: {Hub.current.scope}") span = Hub.current.scope.span.start_child(description=step, op=step).__enter__() @@ -75,7 +77,9 @@ def on_start_trace_step(trace: PipelineTrace, step: TPipelineStep, pipeline: Sup _add_sentry_tags(span, pipeline) -def on_end_trace_step(trace: PipelineTrace, step: PipelineStepTrace, pipeline: SupportsPipeline, step_info: Any) -> None: +def on_end_trace_step( + trace: PipelineTrace, step: PipelineStepTrace, pipeline: SupportsPipeline, step_info: Any +) -> None: if pipeline.runtime_config.sentry_dsn: # print(f"---END SENTRY SPAN {trace.transaction_id}:{step.span_id}: {step} SCOPE: {Hub.current.scope}") with contextlib.suppress(Exception): @@ -87,8 +91,10 @@ def on_end_trace_step(trace: PipelineTrace, step: PipelineStepTrace, pipeline: S props = { "elapsed": (step.finished_at - trace.started_at).total_seconds(), "success": step.step_exception is None, - "destination_name": DestinationReference.to_name(pipeline.destination) if pipeline.destination else None, - "transaction_id": trace.transaction_id + "destination_name": ( + DestinationReference.to_name(pipeline.destination) if pipeline.destination else None + ), + "transaction_id": trace.transaction_id, } # disable automatic slack messaging until we can configure messages themselves if step.step == "extract" and step_info: @@ -104,4 +110,4 @@ def on_end_trace(trace: PipelineTrace, pipeline: SupportsPipeline) -> None: if pipeline.runtime_config.sentry_dsn: # print(f"---END SENTRY TX: {trace.transaction_id} SCOPE: {Hub.current.scope}") with contextlib.suppress(Exception): - Hub.current.scope.span.__exit__(None, None, None) \ No newline at end of file + Hub.current.scope.span.__exit__(None, None, None) diff --git a/dlt/reflection/names.py b/dlt/reflection/names.py index 1aee6df52b..1778d70e13 100644 --- a/dlt/reflection/names.py +++ b/dlt/reflection/names.py @@ -2,7 +2,7 @@ import dlt import dlt.destinations -from dlt import pipeline, attach, run, source, resource +from dlt import attach, pipeline, resource, run, source DLT = dlt.__name__ DESTINATIONS = dlt.destinations.__name__ @@ -18,5 +18,5 @@ ATTACH: inspect.signature(attach), RUN: inspect.signature(run), SOURCE: inspect.signature(source), - RESOURCE: inspect.signature(resource) -} \ No newline at end of file + RESOURCE: inspect.signature(resource), +} diff --git a/dlt/reflection/script_inspector.py b/dlt/reflection/script_inspector.py index 204135dcd7..733b8d2d95 100644 --- a/dlt/reflection/script_inspector.py +++ b/dlt/reflection/script_inspector.py @@ -1,18 +1,17 @@ +import builtins import os import sys -import builtins +from importlib import import_module from pathlib import Path from types import ModuleType, SimpleNamespace -from typing import Any, Tuple, List, Mapping, Sequence +from typing import Any, List, Mapping, Sequence, Tuple from unittest.mock import patch -from importlib import import_module from dlt.common import logger from dlt.common.exceptions import DltException, MissingDependencyException from dlt.common.typing import DictStrAny - -from dlt.pipeline import 
Pipeline from dlt.extract.source import DltSource, ManagedPipeIterator +from dlt.pipeline import Pipeline def patch__init__(self: Any, *args: Any, **kwargs: Any) -> None: @@ -21,6 +20,7 @@ def patch__init__(self: Any, *args: Any, **kwargs: Any) -> None: class DummyModule(ModuleType): """A dummy module from which you can import anything""" + def __getattr__(self, key: str) -> Any: if key[0].isupper(): # if imported name is capitalized, import type @@ -28,13 +28,20 @@ def __getattr__(self, key: str) -> Any: else: # otherwise import instance return SimpleNamespace() - __all__: List[Any] = [] # support wildcard imports + + __all__: List[Any] = [] # support wildcard imports def _import_module(name: str, missing_modules: Tuple[str, ...] = ()) -> ModuleType: """Module importer that ignores missing modules by importing a dummy module""" - def _try_import(name: str, _globals: Mapping[str, Any] = None, _locals: Mapping[str, Any] = None, fromlist: Sequence[str] = (), level:int = 0) -> ModuleType: + def _try_import( + name: str, + _globals: Mapping[str, Any] = None, + _locals: Mapping[str, Any] = None, + fromlist: Sequence[str] = (), + level: int = 0, + ) -> ModuleType: """This function works as follows: on ImportError it raises. This import error is then next caught in the main function body and the name is added to exceptions. Next time if the name is on exception list or name is a package on exception list we return DummyModule and do not reraise This excepts only the modules that bubble up ImportError up until our code so any handled import errors are not excepted @@ -62,7 +69,7 @@ def _try_import(name: str, _globals: Mapping[str, Any] = None, _locals: Mapping[ # print(f"ADD {ie.name} {ie.path} vs {name} vs {str(ie)}") if ie.name in missing_modules: raise - missing_modules += (ie.name, ) + missing_modules += (ie.name,) except MissingDependencyException as me: if isinstance(me.__context__, ImportError): if me.__context__.name is None: @@ -71,14 +78,16 @@ def _try_import(name: str, _globals: Mapping[str, Any] = None, _locals: Mapping[ # print(f"{me.__context__.name} IN :/") raise # print(f"ADD {me.__context__.name}") - missing_modules += (me.__context__.name, ) + missing_modules += (me.__context__.name,) else: raise finally: builtins.__import__ = real_import -def load_script_module(module_path:str, script_relative_path: str, ignore_missing_imports: bool = False) -> ModuleType: +def load_script_module( + module_path: str, script_relative_path: str, ignore_missing_imports: bool = False +) -> ModuleType: """Loads a module in `script_relative_path` by splitting it into a script module (file part) and package (folders). `module_path` is added to sys.path Optionally, missing imports will be ignored by importing a dummy module instead. 
""" @@ -110,12 +119,24 @@ def load_script_module(module_path:str, script_relative_path: str, ignore_missin sys.path.remove(sys_path) -def inspect_pipeline_script(module_path:str, script_relative_path: str, ignore_missing_imports: bool = False) -> ModuleType: +def inspect_pipeline_script( + module_path: str, script_relative_path: str, ignore_missing_imports: bool = False +) -> ModuleType: # patch entry points to pipeline, sources and resources to prevent pipeline from running - with patch.object(Pipeline, '__init__', patch__init__), patch.object(DltSource, '__init__', patch__init__), patch.object(ManagedPipeIterator, '__init__', patch__init__): - return load_script_module(module_path, script_relative_path, ignore_missing_imports=ignore_missing_imports) + with patch.object(Pipeline, "__init__", patch__init__), patch.object( + DltSource, "__init__", patch__init__ + ), patch.object(ManagedPipeIterator, "__init__", patch__init__): + return load_script_module( + module_path, script_relative_path, ignore_missing_imports=ignore_missing_imports + ) class PipelineIsRunning(DltException): def __init__(self, obj: object, args: Tuple[str, ...], kwargs: DictStrAny) -> None: - super().__init__(f"The pipeline script instantiates the pipeline on import. Did you forget to use if __name__ == 'main':? in {obj.__class__.__name__}", obj, args, kwargs) + super().__init__( + "The pipeline script instantiates the pipeline on import. Did you forget to use if" + f" __name__ == 'main':? in {obj.__class__.__name__}", + obj, + args, + kwargs, + ) diff --git a/dlt/reflection/script_visitor.py b/dlt/reflection/script_visitor.py index 7d4e0ea2cd..9d761a71f5 100644 --- a/dlt/reflection/script_visitor.py +++ b/dlt/reflection/script_visitor.py @@ -1,16 +1,15 @@ -import inspect import ast -import astunparse +import inspect from ast import NodeVisitor from typing import Any, Dict, List -from dlt.common.reflection.utils import find_outer_func_def +import astunparse import dlt.reflection.names as n +from dlt.common.reflection.utils import find_outer_func_def class PipelineScriptVisitor(NodeVisitor): - def __init__(self, source: str): self.source = source self.source_lines: List[str] = ast._splitlines_no_ff(source) # type: ignore @@ -73,7 +72,9 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: elif isinstance(deco, ast.Call): alias_name = astunparse.unparse(deco.func).strip() else: - raise ValueError(self.source_segment(deco), type(deco), "Unknown decorator form") + raise ValueError( + self.source_segment(deco), type(deco), "Unknown decorator form" + ) fn = self.func_aliases.get(alias_name) if fn == n.SOURCE: self.known_sources[str(node.name)] = node @@ -96,7 +97,9 @@ def visit_Call(self, node: ast.Call) -> Any: sig = n.SIGNATURES[fn] try: # bind the signature where the argument values are the corresponding ast nodes - bound_args = sig.bind(*node.args, **{str(kwd.arg):kwd.value for kwd in node.keywords}) + bound_args = sig.bind( + *node.args, **{str(kwd.arg): kwd.value for kwd in node.keywords} + ) bound_args.apply_defaults() # print(f"ALIAS: {alias_name} of {self.func_aliases.get(alias_name)} with {bound_args}") fun_calls = self.known_calls.setdefault(fn, []) diff --git a/dlt/sources/__init__.py b/dlt/sources/__init__.py index 2a79dca8bd..eeed405448 100644 --- a/dlt/sources/__init__.py +++ b/dlt/sources/__init__.py @@ -1,2 +1,2 @@ """Module with built in sources and source building blocks""" -from dlt.extract.incremental import Incremental as incremental \ No newline at end of file +from dlt.extract.incremental 
import Incremental as incremental diff --git a/dlt/sources/credentials.py b/dlt/sources/credentials.py index a4cc38da88..99006e3e5f 100644 --- a/dlt/sources/credentials.py +++ b/dlt/sources/credentials.py @@ -1,4 +1,9 @@ -from dlt.common.configuration.specs import GcpServiceAccountCredentials, GcpOAuthCredentials, GcpCredentials -from dlt.common.configuration.specs import ConnectionStringCredentials -from dlt.common.configuration.specs import OAuth2Credentials -from dlt.common.configuration.specs import CredentialsConfiguration, configspec \ No newline at end of file +from dlt.common.configuration.specs import ( + ConnectionStringCredentials, + CredentialsConfiguration, + GcpCredentials, + GcpOAuthCredentials, + GcpServiceAccountCredentials, + OAuth2Credentials, + configspec, +) diff --git a/dlt/sources/helpers/requests/__init__.py b/dlt/sources/helpers/requests/__init__.py index 1d0a37b6e5..af311d456f 100644 --- a/dlt/sources/helpers/requests/__init__.py +++ b/dlt/sources/helpers/requests/__init__.py @@ -1,25 +1,34 @@ -from tenacity import RetryError from requests import ( - Request, Response, ConnectionError, ConnectTimeout, FileModeWarning, HTTPError, ReadTimeout, + Request, RequestException, + Response, Timeout, TooManyRedirects, URLRequired, ) from requests.exceptions import ChunkedEncodingError +from tenacity import RetryError + +from dlt.common.configuration.specs import RunConfiguration from dlt.sources.helpers.requests.retry import Client from dlt.sources.helpers.requests.session import Session -from dlt.common.configuration.specs import RunConfiguration client = Client() get, post, put, patch, delete, options, head, request = ( - client.get, client.post, client.put, client.patch, client.delete, client.options, client.head, client.request + client.get, + client.post, + client.put, + client.patch, + client.delete, + client.options, + client.head, + client.request, ) diff --git a/dlt/sources/helpers/requests/retry.py b/dlt/sources/helpers/requests/retry.py index d1e7a1a7f3..92c649eb23 100644 --- a/dlt/sources/helpers/requests/retry.py +++ b/dlt/sources/helpers/requests/retry.py @@ -1,21 +1,40 @@ -from email.utils import parsedate_tz, mktime_tz import re import time -from typing import Optional, cast, Callable, Type, Union, Sequence, Tuple, List, TYPE_CHECKING, Any, Dict +from email.utils import mktime_tz, parsedate_tz from threading import local - -from requests import Response, HTTPError, Session as BaseSession -from requests.exceptions import ConnectionError, Timeout, ChunkedEncodingError +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) + +from requests import HTTPError, Response +from requests import Session as BaseSession from requests.adapters import HTTPAdapter -from tenacity import Retrying, retry_if_exception_type, stop_after_attempt, RetryCallState, retry_any, wait_exponential +from requests.exceptions import ChunkedEncodingError, ConnectionError, Timeout +from tenacity import ( + RetryCallState, + Retrying, + retry_any, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) from tenacity.retry import retry_base -from dlt.sources.helpers.requests.session import Session, DEFAULT_TIMEOUT -from dlt.sources.helpers.requests.typing import TRequestTimeout -from dlt.common.typing import TimedeltaSeconds -from dlt.common.configuration.specs import RunConfiguration from dlt.common.configuration import with_config - +from dlt.common.configuration.specs import RunConfiguration +from 
dlt.common.typing import TimedeltaSeconds +from dlt.sources.helpers.requests.session import DEFAULT_TIMEOUT, Session +from dlt.sources.helpers.requests.typing import TRequestTimeout DEFAULT_RETRY_STATUS = (429, *range(500, 600)) DEFAULT_RETRY_EXCEPTIONS = (ConnectionError, Timeout, ChunkedEncodingError) @@ -96,7 +115,7 @@ def _make_retry( backoff_factor: float, respect_retry_after_header: bool, max_delay: TimedeltaSeconds, -)-> Retrying: +) -> Retrying: retry_conds = [retry_if_status(status_codes), retry_if_exception_type(tuple(exceptions))] if condition is not None: if callable(condition): @@ -148,12 +167,15 @@ class Client: respect_retry_after_header: Whether to use the `Retry-After` response header (when available) to determine the retry delay session_attrs: Extra attributes that will be set on the session instance, e.g. `{headers: {'Authorization': 'api-key'}}` (see `requests.sessions.Session` for possible attributes) """ + _session_attrs: Dict[str, Any] @with_config(spec=RunConfiguration) def __init__( self, - request_timeout: Optional[Union[TimedeltaSeconds, Tuple[TimedeltaSeconds, TimedeltaSeconds]]] = DEFAULT_TIMEOUT, + request_timeout: Optional[ + Union[TimedeltaSeconds, Tuple[TimedeltaSeconds, TimedeltaSeconds]] + ] = DEFAULT_TIMEOUT, max_connections: int = 50, raise_for_status: bool = True, status_codes: Sequence[int] = DEFAULT_RETRY_STATUS, @@ -175,7 +197,7 @@ def __init__( condition=retry_condition, backoff_factor=request_backoff_factor, respect_retry_after_header=respect_retry_after_header, - max_delay=request_max_retry_delay + max_delay=request_max_retry_delay, ) self._session_attrs = session_attrs or {} @@ -198,29 +220,31 @@ def __init__( self.options = lambda *a, **kw: self.session.options(*a, **kw) self.request = lambda *a, **kw: self.session.request(*a, **kw) - self._config_version: int = 0 # Incrementing marker to ensure per-thread sessions are recreated on config changes + self._config_version: int = ( + 0 # Incrementing marker to ensure per-thread sessions are recreated on config changes + ) def update_from_config(self, config: RunConfiguration) -> None: """Update session/retry settings from RunConfiguration""" - self._session_kwargs['timeout'] = config.request_timeout - self._retry_kwargs['backoff_factor'] = config.request_backoff_factor - self._retry_kwargs['max_delay'] = config.request_max_retry_delay - self._retry_kwargs['max_attempts'] = config.request_max_attempts + self._session_kwargs["timeout"] = config.request_timeout + self._retry_kwargs["backoff_factor"] = config.request_backoff_factor + self._retry_kwargs["max_delay"] = config.request_max_retry_delay + self._retry_kwargs["max_attempts"] = config.request_max_attempts self._config_version += 1 def _make_session(self) -> Session: session = Session(**self._session_kwargs) # type: ignore[arg-type] for key, value in self._session_attrs.items(): setattr(session, key, value) - session.mount('http://', self._adapter) - session.mount('https://', self._adapter) + session.mount("http://", self._adapter) + session.mount("https://", self._adapter) retry = _make_retry(**self._retry_kwargs) session.request = retry.wraps(session.request) # type: ignore[method-assign] return session @property def session(self) -> Session: - session: Optional[Session] = getattr(self._local, 'session', None) + session: Optional[Session] = getattr(self._local, "session", None) version = self._config_version if session is not None: version = self._local.config_version diff --git a/dlt/sources/helpers/requests/session.py 
b/dlt/sources/helpers/requests/session.py index 9455f5698e..3ebeffc602 100644 --- a/dlt/sources/helpers/requests/session.py +++ b/dlt/sources/helpers/requests/session.py @@ -1,11 +1,11 @@ +from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Type, TypeVar, Union + from requests import Session as BaseSession from tenacity import Retrying, retry_if_exception_type -from typing import Optional, TYPE_CHECKING, Sequence, Union, Tuple, Type, TypeVar -from dlt.sources.helpers.requests.typing import TRequestTimeout -from dlt.common.typing import TimedeltaSeconds from dlt.common.time import to_seconds - +from dlt.common.typing import TimedeltaSeconds +from dlt.sources.helpers.requests.typing import TRequestTimeout TSession = TypeVar("TSession", bound=BaseSession) @@ -14,7 +14,11 @@ def _timeout_to_seconds(timeout: TRequestTimeout) -> Optional[Union[Tuple[float, float], float]]: - return (to_seconds(timeout[0]), to_seconds(timeout[1])) if isinstance(timeout, tuple) else to_seconds(timeout) + return ( + (to_seconds(timeout[0]), to_seconds(timeout[1])) + if isinstance(timeout, tuple) + else to_seconds(timeout) + ) class Session(BaseSession): @@ -25,9 +29,12 @@ class Session(BaseSession): May be a single value or a tuple for separate (connect, read) timeout. raise_for_status: Whether to raise exception on error status codes (using `response.raise_for_status()`) """ + def __init__( self, - timeout: Optional[Union[TimedeltaSeconds, Tuple[TimedeltaSeconds, TimedeltaSeconds]]] = DEFAULT_TIMEOUT, + timeout: Optional[ + Union[TimedeltaSeconds, Tuple[TimedeltaSeconds, TimedeltaSeconds]] + ] = DEFAULT_TIMEOUT, raise_for_status: bool = True, ) -> None: super().__init__() @@ -38,7 +45,7 @@ def __init__( request = BaseSession.request def request(self, *args, **kwargs): # type: ignore[no-untyped-def,no-redef] - kwargs.setdefault('timeout', self.timeout) + kwargs.setdefault("timeout", self.timeout) resp = super().request(*args, **kwargs) if self.raise_for_status: resp.raise_for_status() diff --git a/dlt/sources/helpers/requests/typing.py b/dlt/sources/helpers/requests/typing.py index 8595e65b95..73d48746f7 100644 --- a/dlt/sources/helpers/requests/typing.py +++ b/dlt/sources/helpers/requests/typing.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union, Optional +from typing import Optional, Tuple, Union from dlt.common.typing import TimedeltaSeconds diff --git a/dlt/sources/helpers/transform.py b/dlt/sources/helpers/transform.py index 0c2f7c5e39..1975c20586 100644 --- a/dlt/sources/helpers/transform.py +++ b/dlt/sources/helpers/transform.py @@ -5,18 +5,22 @@ def take_first(max_items: int) -> ItemTransformFunctionNoMeta[bool]: """A filter that takes only first `max_items` from a resource""" count: int = 0 + def _filter(_: TDataItem) -> bool: nonlocal count count += 1 return count <= max_items + return _filter def skip_first(max_items: int) -> ItemTransformFunctionNoMeta[bool]: """A filter that skips first `max_items` from a resource""" count: int = 0 + def _filter(_: TDataItem) -> bool: nonlocal count count += 1 return count > max_items + return _filter diff --git a/dlt/version.py b/dlt/version.py index f8ca3cb873..d2159cc5f6 100644 --- a/dlt/version.py +++ b/dlt/version.py @@ -1,6 +1,7 @@ -from importlib.metadata import version as pkg_version, distribution as pkg_distribution -from urllib.request import url2pathname +from importlib.metadata import distribution as pkg_distribution +from importlib.metadata import version as pkg_version from urllib.parse import urlparse +from urllib.request import url2pathname 
DLT_IMPORT_NAME = "dlt" DLT_PKG_NAME = "dlt" diff --git a/docs/examples/_helpers.py b/docs/examples/_helpers.py index 95913d1be1..15789d1f35 100644 --- a/docs/examples/_helpers.py +++ b/docs/examples/_helpers.py @@ -1,7 +1,6 @@ # here we provide the credentials for our public project import base64 - _bigquery_credentials = { "type": "service_account", "project_id": "zinc-mantra-353207", @@ -10,6 +9,12 @@ } # we do not want to have this key verbatim in repo so we decode it here -_bigquery_credentials["private_key"] = bytes([_a ^ _b for _a, _b in zip(base64.b64decode(_bigquery_credentials["private_key"]), b"quickstart-sv"*150)]).decode("utf-8") +_bigquery_credentials["private_key"] = bytes( + [ + _a ^ _b + for _a, _b in zip( + base64.b64decode(_bigquery_credentials["private_key"]), b"quickstart-sv" * 150 + ) + ] +).decode("utf-8") pub_bigquery_credentials = _bigquery_credentials - diff --git a/docs/examples/chess/chess.py b/docs/examples/chess/chess.py index f136e49a0a..9ec3306816 100644 --- a/docs/examples/chess/chess.py +++ b/docs/examples/chess/chess.py @@ -3,15 +3,19 @@ from typing import Any, Iterator import dlt - from dlt.common import sleep from dlt.common.typing import StrAny, TDataItems from dlt.sources.helpers.requests import client @dlt.source -def chess(chess_url: str = dlt.config.value, title: str = "GM", max_players: int = 2, year: int = 2022, month: int = 10) -> Any: - +def chess( + chess_url: str = dlt.config.value, + title: str = "GM", + max_players: int = 2, + year: int = 2022, + month: int = 10, +) -> Any: def _get_data_with_retry(path: str) -> StrAny: r = client.get(f"{chess_url}{path}") return r.json() # type: ignore @@ -29,7 +33,7 @@ def players() -> Iterator[TDataItems]: @dlt.defer def players_profiles(username: Any) -> TDataItems: print(f"getting {username} profile via thread {threading.current_thread().name}") - sleep(1) # add some latency to show parallel runs + sleep(1) # add some latency to show parallel runs return _get_data_with_retry(f"player/{username}") # this resource takes data from players and returns games for the last month if not specified otherwise @@ -41,6 +45,7 @@ def players_games(username: Any) -> Iterator[TDataItems]: return players(), players_profiles, players_games + if __name__ == "__main__": print("You must run this from the docs/examples/chess folder") assert os.getcwd().endswith("chess") @@ -48,12 +53,7 @@ def players_games(username: Any) -> Iterator[TDataItems]: # look for parallel run configuration in `config.toml`! 
# mind the full_refresh: it makes the pipeline to load to a distinct dataset each time it is run and always is resetting the schema and state info = dlt.pipeline( - pipeline_name="chess_games", - destination="postgres", - dataset_name="chess", - full_refresh=True - ).run( - chess(max_players=5, month=9) - ) + pipeline_name="chess_games", destination="postgres", dataset_name="chess", full_refresh=True + ).run(chess(max_players=5, month=9)) # display where the data went print(info) diff --git a/docs/examples/chess/chess_dbt.py b/docs/examples/chess/chess_dbt.py index 4ee51f6b50..d5fe1576a9 100644 --- a/docs/examples/chess/chess_dbt.py +++ b/docs/examples/chess/chess_dbt.py @@ -1,4 +1,5 @@ import os + import dlt # from chess import chess @@ -21,4 +22,3 @@ # run all the tests tests = transforms.test() print(tests) - diff --git a/docs/examples/credentials/explicit.py b/docs/examples/credentials/explicit.py index 6233140459..a4827f0663 100644 --- a/docs/examples/credentials/explicit.py +++ b/docs/examples/credentials/explicit.py @@ -1,10 +1,13 @@ import os from typing import Iterator + import dlt @dlt.resource -def simple_data(api_url: str = dlt.config.value, api_secret: dlt.TSecretValue = dlt.secrets.value) -> Iterator[str]: +def simple_data( + api_url: str = dlt.config.value, api_secret: dlt.TSecretValue = dlt.secrets.value +) -> Iterator[str]: # just yield api_url and api_secret to show what was configured in the example yield api_url yield api_secret @@ -29,13 +32,17 @@ def simple_data(api_url: str = dlt.config.value, api_secret: dlt.TSecretValue = print(list(data)) # you are free to pass credentials from custom location to destination -pipeline = dlt.pipeline(destination="postgres", credentials=dlt.secrets["custom.destination.credentials"]) +pipeline = dlt.pipeline( + destination="postgres", credentials=dlt.secrets["custom.destination.credentials"] +) # see nice credentials object print(pipeline.credentials) # you can also pass credentials partially, only the password comes from the secrets or environment -pipeline = dlt.pipeline(destination="postgres", credentials="postgres://loader@localhost:5432/dlt_data") +pipeline = dlt.pipeline( + destination="postgres", credentials="postgres://loader@localhost:5432/dlt_data" +) # now lets compare it with default location for config and credentials data = simple_data() -print(list(data)) \ No newline at end of file +print(list(data)) diff --git a/docs/examples/dbt_run_jaffle.py b/docs/examples/dbt_run_jaffle.py index ad059dcd6d..098b35fff8 100644 --- a/docs/examples/dbt_run_jaffle.py +++ b/docs/examples/dbt_run_jaffle.py @@ -2,7 +2,9 @@ pipeline = dlt.pipeline(destination="duckdb", dataset_name="jaffle_jaffle") -print("create or restore virtual environment in which dbt is installed, use the newest version of dbt") +print( + "create or restore virtual environment in which dbt is installed, use the newest version of dbt" +) venv = dlt.dbt.get_venv(pipeline) print("get runner, optionally pass the venv") @@ -11,13 +13,18 @@ print("run the package (clone/pull repo, deps, seed, source tests, run)") models = dbt.run_all() for m in models: - print(f"Model {m.model_name} materialized in {m.time} with status {m.status} and message {m.message}") + print( + f"Model {m.model_name} materialized in {m.time} with status {m.status} and message" + f" {m.message}" + ) print("") print("test the model") models = dbt.test() for m in models: - print(f"Test {m.model_name} executed in {m.time} with status {m.status} and message {m.message}") + print( + f"Test {m.model_name} 
executed in {m.time} with status {m.status} and message {m.message}" + ) print("") print("get and display data frame with customers") diff --git a/docs/examples/discord_iterator.py b/docs/examples/discord_iterator.py index a3c59ed2c5..44cbe3b5b1 100644 --- a/docs/examples/discord_iterator.py +++ b/docs/examples/discord_iterator.py @@ -1,4 +1,3 @@ - # from dlt.common import json # from dlt.common.schema import Schema # from dlt.common.typing import DictStrAny diff --git a/docs/examples/google_sheets.py b/docs/examples/google_sheets.py index 93c5658233..5e3164e951 100644 --- a/docs/examples/google_sheets.py +++ b/docs/examples/google_sheets.py @@ -1,10 +1,11 @@ import dlt - from sources.google_sheets import google_spreadsheet dlt.pipeline(destination="bigquery", full_refresh=False) # see example.secrets.toml to where to put credentials # "2022-05", "model_metadata" -info = google_spreadsheet("11G95oVZjieRhyGqtQMQqlqpxyvWkRXowKE8CtdLtFaU", ["named range", "Second_Copy!1:2"]) +info = google_spreadsheet( + "11G95oVZjieRhyGqtQMQqlqpxyvWkRXowKE8CtdLtFaU", ["named range", "Second_Copy!1:2"] +) print(list(info)) diff --git a/docs/examples/quickstart.py b/docs/examples/quickstart.py index e55e9f6049..6e49f1af7a 100644 --- a/docs/examples/quickstart.py +++ b/docs/examples/quickstart.py @@ -9,9 +9,9 @@ """ # 1. configuration: name your dataset, table, pass credentials -dataset_name = 'dlt_quickstart' -pipeline_name = 'dlt_quickstart' -table_name = 'my_json_doc' +dataset_name = "dlt_quickstart" +pipeline_name = "dlt_quickstart" +table_name = "my_json_doc" gcp_credentials_json = { "type": "service_account", @@ -24,7 +24,14 @@ destination_name = "duckdb" if destination_name == "bigquery": # we do not want to have this key verbatim in repo so we decode it here - gcp_credentials_json["private_key"] = bytes([_a ^ _b for _a, _b in zip(base64.b64decode(gcp_credentials_json["private_key"]), b"quickstart-sv"*150)]).decode("utf-8") + gcp_credentials_json["private_key"] = bytes( + [ + _a ^ _b + for _a, _b in zip( + base64.b64decode(gcp_credentials_json["private_key"]), b"quickstart-sv" * 150 + ) + ] + ).decode("utf-8") credentials: Any = gcp_credentials_json elif destination_name == "redshift": credentials = db_dsn @@ -41,20 +48,26 @@ dataset_name=dataset_name, credentials=credentials, export_schema_path=export_schema_path, - full_refresh=True + full_refresh=True, ) # 3. Pass the data to the pipeline and give it a table name. Optionally normalize and handle schema. 
-rows = [{"name": "Ana", "age": 30, "id": 456, "children": [{"name": "Bill", "id": 625}, - {"name": "Elli", "id": 591} - ]}, - - {"name": "Bob", "age": 30, "id": 455, "children": [{"name": "Bill", "id": 625}, - {"name": "Dave", "id": 621} - ]} - ] +rows = [ + { + "name": "Ana", + "age": 30, + "id": 456, + "children": [{"name": "Bill", "id": 625}, {"name": "Elli", "id": 591}], + }, + { + "name": "Bob", + "age": 30, + "id": 455, + "children": [{"name": "Bill", "id": 625}, {"name": "Dave", "id": 621}], + }, +] load_info = pipeline.run(rows, table_name=table_name, write_disposition="replace") diff --git a/docs/examples/rasa_example.py b/docs/examples/rasa_example.py index 3dbd61c692..e3b9ed2a98 100644 --- a/docs/examples/rasa_example.py +++ b/docs/examples/rasa_example.py @@ -1,12 +1,11 @@ import os -import dlt -from dlt.destinations import bigquery, postgres - +from docs.examples._helpers import pub_bigquery_credentials from docs.examples.sources.jsonl import jsonl_files from docs.examples.sources.rasa import rasa -from docs.examples._helpers import pub_bigquery_credentials +import dlt +from dlt.destinations import bigquery, postgres # let's load to bigquery, here we provide the credentials for our public project # credentials = pub_bigquery_credentials @@ -24,9 +23,11 @@ destination=postgres, # export_schema_path=... # uncomment to see the final schema in the folder you want ).run( - rasa(event_files, store_last_timestamp=True), # also store last timestamp so we have no duplicate events - credentials=credentials # if you skip this parameter, the credentials will be injected by the config providers - ) + rasa( + event_files, store_last_timestamp=True + ), # also store last timestamp so we have no duplicate events + credentials=credentials, # if you skip this parameter, the credentials will be injected by the config providers +) print(info) diff --git a/docs/examples/read_table.py b/docs/examples/read_table.py index 291c27bde4..4ccebc232d 100644 --- a/docs/examples/read_table.py +++ b/docs/examples/read_table.py @@ -1,15 +1,17 @@ +from docs.examples.sources.sql_query import query_sql, query_table + import dlt -from dlt.destinations import postgres from dlt.common.data_types.type_helpers import py_type_to_sc_type - -from docs.examples.sources.sql_query import query_table, query_sql +from dlt.destinations import postgres # the connection string to redshift instance holding some ethereum data # the connection string does not contain the password element and you should provide it in environment variable: SOURCES__CREDENTIALS__PASSWORD source_dsn = "redshift+redshift_connector://loader@chat-analytics.czwteevq7bpe.eu-central-1.redshift.amazonaws.com:5439/chat_analytics_rasa" # get data from table, we preserve method signature from pandas -items = query_table("blocks__transactions", source_dsn, table_schema_name="mainnet_2_ethereum", coerce_float=False) +items = query_table( + "blocks__transactions", source_dsn, table_schema_name="mainnet_2_ethereum", coerce_float=False +) # the data is also an iterator for i in items: @@ -25,5 +27,7 @@ # you can find a docker compose file that spins up required instance in tests/load/postgres # note: run the script without required env variables to see info on possible secret configurations that were tried -info = dlt.pipeline().run(items, destination=postgres, dataset_name="ethereum", table_name="transactions") +info = dlt.pipeline().run( + items, destination=postgres, dataset_name="ethereum", table_name="transactions" +) print(info) diff --git 
a/docs/examples/restore_pipeline.py b/docs/examples/restore_pipeline.py index f3c013e85b..fc1f92a4c0 100644 --- a/docs/examples/restore_pipeline.py +++ b/docs/examples/restore_pipeline.py @@ -18,4 +18,4 @@ # print(pipeline.list_extracted_loads()) # # just finalize -# pipeline.flush() \ No newline at end of file +# pipeline.flush() diff --git a/docs/examples/singer_tap_example.py b/docs/examples/singer_tap_example.py index d03182339c..9767f803b0 100644 --- a/docs/examples/singer_tap_example.py +++ b/docs/examples/singer_tap_example.py @@ -1,17 +1,20 @@ import os from tempfile import mkdtemp +from docs.examples.sources.singer_tap import tap + import dlt from dlt.common.runners import Venv -from docs.examples.sources.singer_tap import tap - # create Venv with desired dependencies, in this case csv tap # venv creation costs time so it should be created only once and reused # here we use context manager to automatically delete venv after example was run # the dependency is meltano version of csv tap -print("Spawning virtual environment to run singer and installing csv tap from git+https://github.com/MeltanoLabs/tap-csv.git") +print( + "Spawning virtual environment to run singer and installing csv tap from" + " git+https://github.com/MeltanoLabs/tap-csv.git" +) # WARNING: on MACOS you need to have working gcc to use tap-csv, otherwise dependency will not be installed with Venv.create(mkdtemp(), ["git+https://github.com/MeltanoLabs/tap-csv.git"]) as venv: # prep singer config for tap-csv @@ -20,13 +23,13 @@ { "entity": "annotations_202205", "path": os.path.abspath("examples/data/singer_taps/model_annotations.csv"), - "keys": [ - "message id" - ] + "keys": ["message id"], } ] } print("running tap-csv") tap_source = tap(venv, "tap-csv", csv_tap_config, "examples/data/singer_taps/csv_catalog.json") - info = dlt.pipeline("meltano_csv", destination="postgres").run(tap_source, credentials="postgres://loader@localhost:5432/dlt_data") + info = dlt.pipeline("meltano_csv", destination="postgres").run( + tap_source, credentials="postgres://loader@localhost:5432/dlt_data" + ) print(info) diff --git a/docs/examples/singer_tap_jsonl_example.py b/docs/examples/singer_tap_jsonl_example.py index fff64bdb1d..f59a831d73 100644 --- a/docs/examples/singer_tap_jsonl_example.py +++ b/docs/examples/singer_tap_jsonl_example.py @@ -1,17 +1,15 @@ -import dlt - -from dlt.common.storages.schema_storage import SchemaStorage - -from docs.examples.sources.singer_tap import singer_raw_stream from docs.examples.sources.jsonl import jsonl_file +from docs.examples.sources.singer_tap import singer_raw_stream +import dlt +from dlt.common.storages.schema_storage import SchemaStorage # load hubspot schema stub - it converts all field names with `timestamp` into timestamp type -schema = SchemaStorage.load_schema_file("docs/examples/schemas/", "hubspot", ("yaml", )) +schema = SchemaStorage.load_schema_file("docs/examples/schemas/", "hubspot", ("yaml",)) p = dlt.pipeline(destination="postgres", full_refresh=True) # now load a pipeline created from jsonl resource that feeds messages into singer tap transformer pipe = jsonl_file("docs/examples/data/singer_taps/tap_hubspot.jsonl") | singer_raw_stream() # provide hubspot schema info = p.run(pipe, schema=schema, credentials="postgres://loader@localhost:5432/dlt_data") -print(info) \ No newline at end of file +print(info) diff --git a/docs/examples/sources/google_sheets.py b/docs/examples/sources/google_sheets.py index 8a3d6b1d1c..d712133977 100644 --- a/docs/examples/sources/google_sheets.py 
+++ b/docs/examples/sources/google_sheets.py @@ -1,9 +1,9 @@ from typing import Any, Iterator, Sequence, Union, cast import dlt -from dlt.common.configuration.specs import GcpServiceAccountCredentials, GcpOAuthCredentials -from dlt.common.typing import DictStrAny, StrAny +from dlt.common.configuration.specs import GcpOAuthCredentials, GcpServiceAccountCredentials from dlt.common.exceptions import MissingDependencyException +from dlt.common.typing import DictStrAny, StrAny try: from apiclient.discovery import build @@ -16,38 +16,52 @@ # TODO: consider using https://github.com/burnash/gspread for spreadsheet discovery -def _initialize_sheets(credentials: Union[GcpOAuthCredentials, GcpServiceAccountCredentials]) -> Any: +def _initialize_sheets( + credentials: Union[GcpOAuthCredentials, GcpServiceAccountCredentials] +) -> Any: # Build the service object. - service = build('sheets', 'v4', credentials=credentials.to_native_credentials()) + service = build("sheets", "v4", credentials=credentials.to_native_credentials()) return service @dlt.source -def google_spreadsheet(spreadsheet_id: str, sheet_names: Sequence[str], credentials: Union[GcpServiceAccountCredentials, GcpOAuthCredentials, str, StrAny] = dlt.secrets.value) -> Any: - +def google_spreadsheet( + spreadsheet_id: str, + sheet_names: Sequence[str], + credentials: Union[ + GcpServiceAccountCredentials, GcpOAuthCredentials, str, StrAny + ] = dlt.secrets.value, +) -> Any: sheets = _initialize_sheets(cast(GcpServiceAccountCredentials, credentials)) # import pprint # meta = sheets.spreadsheets().get(spreadsheetId=spreadsheet_id, ranges=sheet_names, includeGridData=True).execute() # pprint.pprint(meta) def get_sheet(sheet_name: str) -> Iterator[DictStrAny]: - # get list of list of typed values - result = sheets.spreadsheets().values().get( - spreadsheetId=spreadsheet_id, - range=sheet_name, - # unformatted returns typed values - valueRenderOption="UNFORMATTED_VALUE", - # will return formatted dates - dateTimeRenderOption="FORMATTED_STRING" - ).execute() + result = ( + sheets.spreadsheets() + .values() + .get( + spreadsheetId=spreadsheet_id, + range=sheet_name, + # unformatted returns typed values + valueRenderOption="UNFORMATTED_VALUE", + # will return formatted dates + dateTimeRenderOption="FORMATTED_STRING", + ) + .execute() + ) # pprint.pprint(result) - values = result.get('values') + values = result.get("values") # yield dicts assuming row 0 contains headers and following rows values and all rows have identical length for v in values[1:]: yield {h: v for h, v in zip(values[0], v)} # create resources from supplied sheet names - return [dlt.resource(get_sheet(name), name=name, write_disposition="replace") for name in sheet_names] + return [ + dlt.resource(get_sheet(name), name=name, write_disposition="replace") + for name in sheet_names + ] diff --git a/docs/examples/sources/jsonl.py b/docs/examples/sources/jsonl.py index 282966d00a..5989d2054f 100644 --- a/docs/examples/sources/jsonl.py +++ b/docs/examples/sources/jsonl.py @@ -7,8 +7,9 @@ from dlt.common.typing import StrAny, StrOrBytesPath -def chunk_jsonl(path: StrOrBytesPath, chunk_size: int = 20) -> Union[Iterator[StrAny], Iterator[List[StrAny]]]: - +def chunk_jsonl( + path: StrOrBytesPath, chunk_size: int = 20 +) -> Union[Iterator[StrAny], Iterator[List[StrAny]]]: with open(path, "rb") as f: def _iter() -> Iterator[StrAny]: @@ -24,9 +25,13 @@ def _iter() -> Iterator[StrAny]: else: break + jsonl_file = dlt.resource(chunk_jsonl, name="jsonl", spec=BaseConfiguration) + 
@dlt.resource(name="jsonl") -def jsonl_files(paths: Sequence[StrOrBytesPath], chunk_size: int = 20) -> Union[Iterator[StrAny], Iterator[List[StrAny]]]: +def jsonl_files( + paths: Sequence[StrOrBytesPath], chunk_size: int = 20 +) -> Union[Iterator[StrAny], Iterator[List[StrAny]]]: for path in paths: yield from chunk_jsonl(path, chunk_size) diff --git a/docs/examples/sources/rasa/__init__.py b/docs/examples/sources/rasa/__init__.py index acd214368a..3a274af671 100644 --- a/docs/examples/sources/rasa/__init__.py +++ b/docs/examples/sources/rasa/__init__.py @@ -1 +1 @@ -from .rasa import rasa \ No newline at end of file +from .rasa import rasa diff --git a/docs/examples/sources/rasa/rasa.py b/docs/examples/sources/rasa/rasa.py index aa31b3c482..bb5699fad5 100644 --- a/docs/examples/sources/rasa/rasa.py +++ b/docs/examples/sources/rasa/rasa.py @@ -1,8 +1,8 @@ from typing import Any, Iterator import dlt -from dlt.common.typing import StrAny, TDataItem, TDataItems from dlt.common.time import timestamp_within +from dlt.common.typing import StrAny, TDataItem, TDataItems from dlt.extract.source import DltResource @@ -13,7 +13,7 @@ def rasa( source_env: str = None, initial_timestamp: float = None, end_timestamp: float = None, - store_last_timestamp: bool = True + store_last_timestamp: bool = True, ) -> Any: """Transforms the base resource provided in `data_from` into a rasa tracker store raw dataset where each event type get it's own table. The resource is a stream resource and it generates tables dynamically from data. The source uses `rasa.schema.yaml` file to initialize the schema @@ -34,7 +34,9 @@ def rasa( def events(source_events: TDataItems) -> Iterator[TDataItem]: # recover start_timestamp from state if given if store_last_timestamp: - start_timestamp = max(initial_timestamp or 0, dlt.current.source_state().get("start_timestamp", 0)) + start_timestamp = max( + initial_timestamp or 0, dlt.current.source_state().get("start_timestamp", 0) + ) # we expect tracker store events here last_timestamp: int = None @@ -51,7 +53,7 @@ def _proc_event(source_event: TDataItem) -> Iterator[TDataItem]: event = { "sender_id": source_event["sender_id"], "timestamp": last_timestamp, - "event": event_type + "event": event_type, } if source_env: event["source"] = source_env diff --git a/docs/examples/sources/singer_tap.py b/docs/examples/sources/singer_tap.py index 65c9b76e0b..f7e4988735 100644 --- a/docs/examples/sources/singer_tap.py +++ b/docs/examples/sources/singer_tap.py @@ -1,6 +1,8 @@ import os import tempfile -from typing import Any, Iterator, TypedDict, cast, Union +from typing import Any, Iterator, TypedDict, Union, cast + +from docs.examples.sources.stdout import json_stdout as singer_process_pipe import dlt from dlt.common import json @@ -8,10 +10,9 @@ from dlt.common.runners.venv import Venv from dlt.common.typing import DictStrAny, StrAny, StrOrBytesPath, TDataItem, TDataItems -from docs.examples.sources.stdout import json_stdout as singer_process_pipe - FilePathOrDict = Union[StrAny, StrOrBytesPath] + class SingerMessage(TypedDict): type: str # noqa: A003 @@ -24,6 +25,7 @@ class SingerRecord(SingerMessage): class SingerState(SingerMessage): value: DictStrAny + # try: # from singer import parse_message_from_obj, Message, RecordMessage, StateMessage # except ImportError: @@ -33,7 +35,9 @@ class SingerState(SingerMessage): # pip install ../singer/singer-python # https://github.com/datamill-co/singer-runner/tree/master/singer_runner # 
https://techgaun.github.io/active-forks/index.html#singer-io/singer-python -def get_source_from_stream(singer_messages: Iterator[SingerMessage], state: DictStrAny = None) -> Iterator[TDataItem]: +def get_source_from_stream( + singer_messages: Iterator[SingerMessage], state: DictStrAny = None +) -> Iterator[TDataItem]: last_state = {} for msg in singer_messages: if msg["type"] == "RECORD": @@ -57,7 +61,13 @@ def singer_raw_stream(singer_messages: TDataItems, use_state: bool = True) -> It @dlt.source(spec=BaseConfiguration) # use BaseConfiguration spec to prevent injections -def tap(venv: Venv, tap_name: str, config_file: FilePathOrDict, catalog_file: FilePathOrDict, use_state: bool = True) -> Any: +def tap( + venv: Venv, + tap_name: str, + config_file: FilePathOrDict, + catalog_file: FilePathOrDict, + use_state: bool = True, +) -> Any: # TODO: generate append/replace dispositions and some table/column hints from catalog files def as_config_file(config: FilePathOrDict) -> StrOrBytesPath: @@ -87,14 +97,15 @@ def singer_messages() -> Iterator[TDataItem]: else: state_params = () # type: ignore - pipe_iterator = singer_process_pipe(venv, - tap_name, - "--config", - os.path.abspath(config_file_path), - "--catalog", - os.path.abspath(catalog_file_path), - *state_params - ) + pipe_iterator = singer_process_pipe( + venv, + tap_name, + "--config", + os.path.abspath(config_file_path), + "--catalog", + os.path.abspath(catalog_file_path), + *state_params + ) yield from get_source_from_stream(pipe_iterator, state) # type: ignore return singer_messages diff --git a/docs/examples/sources/sql_query.py b/docs/examples/sources/sql_query.py index effa8740d5..b4568d2d2e 100644 --- a/docs/examples/sources/sql_query.py +++ b/docs/examples/sources/sql_query.py @@ -1,34 +1,39 @@ -from typing import Iterator, List, Any, Union from functools import partial +from typing import Any, Iterator, List, Union import dlt - from dlt.common.configuration.specs import ConnectionStringCredentials -from dlt.common.typing import AnyFun, DictStrAny, StrAny, TDataItem from dlt.common.exceptions import MissingDependencyException - +from dlt.common.typing import AnyFun, DictStrAny, StrAny, TDataItem try: # import gracefully and produce nice exception that explains the user what to do import pandas except ImportError: - raise MissingDependencyException("SQL Query Source", ["pandas"], "SQL Query Source temporarily uses pandas as DB interface") + raise MissingDependencyException( + "SQL Query Source", ["pandas"], "SQL Query Source temporarily uses pandas as DB interface" + ) try: from sqlalchemy.exc import NoSuchModuleError except ImportError: - raise MissingDependencyException("SQL Query Source", ["sqlalchemy"], "SQL Query Source temporarily uses pandas as DB interface") - + raise MissingDependencyException( + "SQL Query Source", + ["sqlalchemy"], + "SQL Query Source temporarily uses pandas as DB interface", + ) -def _query_data( - f: AnyFun -) -> Iterator[DictStrAny]: +def _query_data(f: AnyFun) -> Iterator[DictStrAny]: try: items = f() except NoSuchModuleError as m_exc: if "redshift.redshift_connector" in str(m_exc): - raise MissingDependencyException("SQL Query Source", ["sqlalchemy-redshift", "redshift_connector"], "Redshift dialect support for SqlAlchemy") + raise MissingDependencyException( + "SQL Query Source", + ["sqlalchemy-redshift", "redshift_connector"], + "Redshift dialect support for SqlAlchemy", + ) raise for i in items: @@ -46,11 +51,21 @@ def query_table( coerce_float: bool = True, parse_dates: Any = None, columns: 
List[str] = None, - chunk_size: int = 1000 + chunk_size: int = 1000, ) -> Any: print(credentials) assert isinstance(credentials, ConnectionStringCredentials) - f = partial(pandas.read_sql_table, table_name, credentials.to_native_representation(), table_schema_name, None, coerce_float, parse_dates, columns, chunksize=chunk_size) + f = partial( + pandas.read_sql_table, + table_name, + credentials.to_native_representation(), + table_schema_name, + None, + coerce_float, + parse_dates, + columns, + chunksize=chunk_size, + ) # if resource is returned from decorator function, it will override the hints from decorator return dlt.resource(_query_data(f), name=table_name) @@ -62,8 +77,18 @@ def query_sql( coerce_float: bool = True, parse_dates: Any = None, chunk_size: int = 1000, - dtype: Any = None + dtype: Any = None, ) -> Iterator[TDataItem]: assert isinstance(credentials, ConnectionStringCredentials) - f = partial(pandas.read_sql_query, sql, credentials.to_native_representation(), None, coerce_float, None, parse_dates, chunk_size, dtype) + f = partial( + pandas.read_sql_query, + sql, + credentials.to_native_representation(), + None, + coerce_float, + None, + parse_dates, + chunk_size, + dtype, + ) yield from _query_data(f) diff --git a/docs/examples/sources/stdout.py b/docs/examples/sources/stdout.py index 502d1073d6..fa51c8ce7a 100644 --- a/docs/examples/sources/stdout.py +++ b/docs/examples/sources/stdout.py @@ -1,7 +1,6 @@ from typing import Any, Iterator import dlt - from dlt.common import json from dlt.common.configuration.specs import BaseConfiguration from dlt.common.runners import Venv diff --git a/docs/snippets/conftest.py b/docs/snippets/conftest.py index a76ecb5360..fb0d47cf1c 100644 --- a/docs/snippets/conftest.py +++ b/docs/snippets/conftest.py @@ -1,3 +1 @@ -from tests.utils import patch_home_dir, autouse_test_storage, preserve_environ - - +from tests.utils import autouse_test_storage, patch_home_dir, preserve_environ diff --git a/docs/snippets/intro_snippet.py b/docs/snippets/intro_snippet.py index 8133c60080..8b1cc72455 100644 --- a/docs/snippets/intro_snippet.py +++ b/docs/snippets/intro_snippet.py @@ -1,20 +1,20 @@ # @@@SNIPSTART intro_snippet -import dlt import requests + +import dlt + # Create a dlt pipeline that will load # chess player data to the DuckDB destination pipeline = dlt.pipeline( - pipeline_name='chess_pipeline', - destination='duckdb', - dataset_name='player_data' + pipeline_name="chess_pipeline", destination="duckdb", dataset_name="player_data" ) # Grab some player data from Chess.com API data = [] -for player in ['magnuscarlsen', 'rpragchess']: - response = requests.get(f'https://api.chess.com/pub/player/{player}') +for player in ["magnuscarlsen", "rpragchess"]: + response = requests.get(f"https://api.chess.com/pub/player/{player}") response.raise_for_status() data.append(response.json()) # Extract, normalize, and load the data -info = pipeline.run(data, table_name='player') +info = pipeline.run(data, table_name="player") print(info) # @@@SNIPEND diff --git a/docs/snippets/intro_snippet_test.py b/docs/snippets/intro_snippet_test.py index f5265a36c1..b07cae23be 100644 --- a/docs/snippets/intro_snippet_test.py +++ b/docs/snippets/intro_snippet_test.py @@ -1,6 +1,6 @@ - -from tests.pipeline.utils import assert_load_info from docs.snippets.utils import run_snippet +from tests.pipeline.utils import assert_load_info + def test_intro_snippet() -> None: variables = run_snippet("intro_snippet") diff --git a/docs/snippets/utils.py b/docs/snippets/utils.py index 
4958ecc57e..bf0ab83339 100644 --- a/docs/snippets/utils.py +++ b/docs/snippets/utils.py @@ -1,14 +1,15 @@ -from typing import Dict, Any - -from dlt.common.utils import set_working_dir +from typing import Any, Dict from tests.utils import TEST_STORAGE_ROOT, test_storage +from dlt.common.utils import set_working_dir + BASEPATH = "docs/snippets" + def run_snippet(filename: str) -> Dict[str, Any]: with set_working_dir(BASEPATH): code = open(f"{filename}.py", encoding="utf-8").read() variables: Dict[str, Any] = {} exec(code, variables) - return variables \ No newline at end of file + return variables diff --git a/tests/cases.py b/tests/cases.py index f87c329f36..b8d863d1ff 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -1,26 +1,28 @@ -from typing import Dict, List, Any import base64 +from typing import Any, Dict, List + from hexbytes import HexBytes -from dlt.common import Decimal, pendulum, json +from dlt.common import Decimal, json, pendulum from dlt.common.data_types import TDataType +from dlt.common.schema import TColumnSchema, TTableSchemaColumns +from dlt.common.time import ensure_pendulum_datetime, reduce_pendulum_datetime_precision from dlt.common.typing import StrAny from dlt.common.wei import Wei -from dlt.common.time import ensure_pendulum_datetime, reduce_pendulum_datetime_precision -from dlt.common.schema import TColumnSchema, TTableSchemaColumns - # _UUID = "c8209ee7-ee95-4b90-8c9f-f7a0f8b51014" JSON_TYPED_DICT: StrAny = { "str": "string", "decimal": Decimal("21.37"), - "big_decimal": Decimal("115792089237316195423570985008687907853269984665640564039457584007913129639935.1"), + "big_decimal": Decimal( + "115792089237316195423570985008687907853269984665640564039457584007913129639935.1" + ), "datetime": pendulum.parse("2005-04-02T20:37:37.358236Z"), "date": pendulum.parse("2022-02-02").date(), # "uuid": UUID(_UUID), "hexbytes": HexBytes("0x2137"), - "bytes": b'2137', - "wei": Wei.from_int256(2137, decimals=2) + "bytes": b"2137", + "wei": Wei.from_int256(2137, decimals=2), } JSON_TYPED_DICT_TYPES: Dict[str, TDataType] = { @@ -32,132 +34,56 @@ # "uuid": "text", "hexbytes": "binary", "bytes": "binary", - "wei": "wei" + "wei": "wei", } JSON_TYPED_DICT_NESTED = { "dict": dict(JSON_TYPED_DICT), "list_dicts": [dict(JSON_TYPED_DICT), dict(JSON_TYPED_DICT)], "list": list(JSON_TYPED_DICT.values()), - **JSON_TYPED_DICT + **JSON_TYPED_DICT, } TABLE_UPDATE: List[TColumnSchema] = [ - { - "name": "col1", - "data_type": "bigint", - "nullable": False - }, - { - "name": "col2", - "data_type": "double", - "nullable": False - }, - { - "name": "col3", - "data_type": "bool", - "nullable": False - }, - { - "name": "col4", - "data_type": "timestamp", - "nullable": False - }, - { - "name": "col5", - "data_type": "text", - "nullable": False - }, - { - "name": "col6", - "data_type": "decimal", - "nullable": False - }, - { - "name": "col7", - "data_type": "binary", - "nullable": False - }, - { - "name": "col8", - "data_type": "wei", - "nullable": False - }, - { - "name": "col9", - "data_type": "complex", - "nullable": False, - "variant": True - }, - { - "name": "col10", - "data_type": "date", - "nullable": False - }, - { - "name": "col1_null", - "data_type": "bigint", - "nullable": True - }, - { - "name": "col2_null", - "data_type": "double", - "nullable": True - }, - { - "name": "col3_null", - "data_type": "bool", - "nullable": True - }, - { - "name": "col4_null", - "data_type": "timestamp", - "nullable": True - }, - { - "name": "col5_null", - "data_type": "text", - "nullable": True - }, - { - "name": 
"col6_null", - "data_type": "decimal", - "nullable": True - }, - { - "name": "col7_null", - "data_type": "binary", - "nullable": True - }, - { - "name": "col8_null", - "data_type": "wei", - "nullable": True - }, - { - "name": "col9_null", - "data_type": "complex", - "nullable": True, - "variant": True - }, - { - "name": "col10_null", - "data_type": "date", - "nullable": True - } + {"name": "col1", "data_type": "bigint", "nullable": False}, + {"name": "col2", "data_type": "double", "nullable": False}, + {"name": "col3", "data_type": "bool", "nullable": False}, + {"name": "col4", "data_type": "timestamp", "nullable": False}, + {"name": "col5", "data_type": "text", "nullable": False}, + {"name": "col6", "data_type": "decimal", "nullable": False}, + {"name": "col7", "data_type": "binary", "nullable": False}, + {"name": "col8", "data_type": "wei", "nullable": False}, + {"name": "col9", "data_type": "complex", "nullable": False, "variant": True}, + {"name": "col10", "data_type": "date", "nullable": False}, + {"name": "col1_null", "data_type": "bigint", "nullable": True}, + {"name": "col2_null", "data_type": "double", "nullable": True}, + {"name": "col3_null", "data_type": "bool", "nullable": True}, + {"name": "col4_null", "data_type": "timestamp", "nullable": True}, + {"name": "col5_null", "data_type": "text", "nullable": True}, + {"name": "col6_null", "data_type": "decimal", "nullable": True}, + {"name": "col7_null", "data_type": "binary", "nullable": True}, + {"name": "col8_null", "data_type": "wei", "nullable": True}, + {"name": "col9_null", "data_type": "complex", "nullable": True, "variant": True}, + {"name": "col10_null", "data_type": "date", "nullable": True}, ] -TABLE_UPDATE_COLUMNS_SCHEMA: TTableSchemaColumns = {t["name"]:t for t in TABLE_UPDATE} +TABLE_UPDATE_COLUMNS_SCHEMA: TTableSchemaColumns = {t["name"]: t for t in TABLE_UPDATE} -TABLE_ROW_ALL_DATA_TYPES = { +TABLE_ROW_ALL_DATA_TYPES = { "col1": 989127831, "col2": 898912.821982, "col3": True, "col4": "2022-05-23T13:26:45.176451+00:00", "col5": "string data \n \r \x8e 🦆", "col6": Decimal("2323.34"), - "col7": b'binary data \n \r \x8e', + "col7": b"binary data \n \r \x8e", "col8": 2**56 + 92093890840, - "col9": {"complex":[1,2,3,"a"], "link": "?commen\ntU\nrn=urn%3Ali%3Acomment%3A%28acti\012 \6 \\vity%3A69'08444473\n\n551163392%2C6n \r \x8e9085"}, + "col9": { + "complex": [1, 2, 3, "a"], + "link": ( + "?commen\ntU\nrn=urn%3Ali%3Acomment%3A%28acti\012 \6" + " \\vity%3A69'08444473\n\n551163392%2C6n \r \x8e9085" + ), + }, "col10": "2023-02-27", "col1_null": None, "col2_null": None, @@ -168,7 +94,7 @@ "col7_null": None, "col8_null": None, "col9_null": None, - "col10_null": None + "col10_null": None, } @@ -176,7 +102,7 @@ def assert_all_data_types_row( db_row: List[Any], parse_complex_strings: bool = False, allow_base64_binary: bool = False, - timestamp_precision:int = 6 + timestamp_precision: int = 6, ) -> None: # content must equal # print(db_row) @@ -184,7 +110,9 @@ def assert_all_data_types_row( expected_rows = list(TABLE_ROW_ALL_DATA_TYPES.values()) parsed_date = pendulum.instance(db_row[3]) db_row[3] = reduce_pendulum_datetime_precision(parsed_date, timestamp_precision) - expected_rows[3] = reduce_pendulum_datetime_precision(ensure_pendulum_datetime(expected_rows[3]), timestamp_precision) + expected_rows[3] = reduce_pendulum_datetime_precision( + ensure_pendulum_datetime(expected_rows[3]), timestamp_precision + ) if isinstance(db_row[6], str): try: diff --git a/tests/cli/cases/deploy_pipeline/debug_pipeline.py 
b/tests/cli/cases/deploy_pipeline/debug_pipeline.py index 8d87c8ac3d..c49e8b524d 100644 --- a/tests/cli/cases/deploy_pipeline/debug_pipeline.py +++ b/tests/cli/cases/deploy_pipeline/debug_pipeline.py @@ -7,14 +7,17 @@ def example_resource(api_url=dlt.config.value, api_key=dlt.secrets.value, last_i @dlt.source -def example_source(api_url=dlt.config.value, api_key=dlt.secrets.value, last_id = 0): +def example_source(api_url=dlt.config.value, api_key=dlt.secrets.value, last_id=0): # return all the resources to be loaded return example_resource(api_url, api_key, last_id) -if __name__ == '__main__': - p = dlt.pipeline(pipeline_name="debug_pipeline", destination="postgres", dataset_name="debug_pipeline_data", full_refresh=False) - load_info = p.run( - example_source(last_id=819273998) +if __name__ == "__main__": + p = dlt.pipeline( + pipeline_name="debug_pipeline", + destination="postgres", + dataset_name="debug_pipeline_data", + full_refresh=False, ) + load_info = p.run(example_source(last_id=819273998)) print(load_info) diff --git a/tests/cli/common/test_cli_invoke.py b/tests/cli/common/test_cli_invoke.py index 99f34eeaa7..0977ec736d 100644 --- a/tests/cli/common/test_cli_invoke.py +++ b/tests/cli/common/test_cli_invoke.py @@ -1,50 +1,50 @@ import os -from pytest_console_scripts import ScriptRunner from unittest.mock import patch +from pytest_console_scripts import ScriptRunner +from tests.cli.utils import cloned_init_repo, echo_default_choice, repo_dir +from tests.utils import TEST_STORAGE_ROOT, patch_home_dir + from dlt.common.configuration.paths import get_dlt_data_dir -from dlt.common.utils import custom_environ, set_working_dir from dlt.common.pipeline import get_dlt_pipelines_dir - -from tests.cli.utils import echo_default_choice, repo_dir, cloned_init_repo -from tests.utils import TEST_STORAGE_ROOT, patch_home_dir +from dlt.common.utils import custom_environ, set_working_dir BASE_COMMANDS = ["init", "deploy", "pipeline", "telemetry", "schema"] def test_invoke_basic(script_runner: ScriptRunner) -> None: - result = script_runner.run(['dlt', '--version']) + result = script_runner.run(["dlt", "--version"]) assert result.returncode == 0 assert result.stdout.startswith("dlt ") - assert result.stderr == '' + assert result.stderr == "" - result = script_runner.run(['dlt', '--version'], shell=True) + result = script_runner.run(["dlt", "--version"], shell=True) assert result.returncode == 0 assert result.stdout.startswith("dlt ") - assert result.stderr == '' + assert result.stderr == "" for command in BASE_COMMANDS: - result = script_runner.run(['dlt', command, '--help']) + result = script_runner.run(["dlt", command, "--help"]) assert result.returncode == 0 assert result.stdout.startswith(f"usage: dlt {command}") - result = script_runner.run(['dlt', "N/A", '--help']) + result = script_runner.run(["dlt", "N/A", "--help"]) assert result.returncode != 0 def test_invoke_list_pipelines(script_runner: ScriptRunner) -> None: - result = script_runner.run(['dlt', 'pipeline', '--list-pipelines']) + result = script_runner.run(["dlt", "pipeline", "--list-pipelines"]) # directory does not exist (we point to TEST_STORAGE) assert result.returncode == 1 # create empty os.makedirs(get_dlt_pipelines_dir()) - result = script_runner.run(['dlt', 'pipeline', '--list-pipelines']) + result = script_runner.run(["dlt", "pipeline", "--list-pipelines"]) assert result.returncode == 0 assert "No pipelines found in" in result.stdout # info on non existing pipeline - result = script_runner.run(['dlt', 'pipeline', 
'debug_pipeline', 'info']) + result = script_runner.run(["dlt", "pipeline", "debug_pipeline", "info"]) assert result.returncode == 1 assert "the pipeline was not found in" in result.stderr @@ -53,17 +53,17 @@ def test_invoke_init_chess_and_template(script_runner: ScriptRunner) -> None: with set_working_dir(TEST_STORAGE_ROOT): # store dlt data in test storage (like patch_home_dir) with custom_environ({"DLT_DATA_DIR": get_dlt_data_dir()}): - result = script_runner.run(['dlt', 'init', 'chess', 'dummy']) + result = script_runner.run(["dlt", "init", "chess", "dummy"]) assert "Verified source chess was added to your project!" in result.stdout assert result.returncode == 0 - result = script_runner.run(['dlt', 'init', 'debug_pipeline', 'dummy']) + result = script_runner.run(["dlt", "init", "debug_pipeline", "dummy"]) assert "Your new pipeline debug_pipeline is ready to be customized!" in result.stdout assert result.returncode == 0 def test_invoke_list_verified_sources(script_runner: ScriptRunner) -> None: known_sources = ["chess", "sql_database", "google_sheets", "pipedrive"] - result = script_runner.run(['dlt', 'init', '--list-verified-sources']) + result = script_runner.run(["dlt", "init", "--list-verified-sources"]) assert result.returncode == 0 for known_source in known_sources: assert known_source in result.stdout @@ -73,25 +73,31 @@ def test_invoke_deploy_project(script_runner: ScriptRunner) -> None: with set_working_dir(TEST_STORAGE_ROOT): # store dlt data in test storage (like patch_home_dir) with custom_environ({"DLT_DATA_DIR": get_dlt_data_dir()}): - result = script_runner.run(['dlt', 'deploy', 'debug_pipeline.py', 'github-action', '--schedule', '@daily']) + result = script_runner.run( + ["dlt", "deploy", "debug_pipeline.py", "github-action", "--schedule", "@daily"] + ) assert result.returncode == -4 assert "The pipeline script does not exist" in result.stderr - result = script_runner.run(['dlt', 'deploy', 'debug_pipeline.py', 'airflow-composer']) + result = script_runner.run(["dlt", "deploy", "debug_pipeline.py", "airflow-composer"]) assert result.returncode == -4 assert "The pipeline script does not exist" in result.stderr # now init - result = script_runner.run(['dlt', 'init', 'chess', 'dummy']) + result = script_runner.run(["dlt", "init", "chess", "dummy"]) assert result.returncode == 0 - result = script_runner.run(['dlt', 'deploy', 'chess_pipeline.py', 'github-action', '--schedule', '@daily']) + result = script_runner.run( + ["dlt", "deploy", "chess_pipeline.py", "github-action", "--schedule", "@daily"] + ) assert "NOTE: You must run the pipeline locally" in result.stdout - result = script_runner.run(['dlt', 'deploy', 'chess_pipeline.py', 'airflow-composer']) + result = script_runner.run(["dlt", "deploy", "chess_pipeline.py", "airflow-composer"]) assert "NOTE: You must run the pipeline locally" in result.stdout def test_invoke_deploy_mock(script_runner: ScriptRunner) -> None: # NOTE: you can mock only once per test with ScriptRunner !! 
with patch("dlt.cli.deploy_command.deploy_command") as _deploy_command: - script_runner.run(['dlt', 'deploy', 'debug_pipeline.py', 'github-action', '--schedule', '@daily']) + script_runner.run( + ["dlt", "deploy", "debug_pipeline.py", "github-action", "--schedule", "@daily"] + ) assert _deploy_command.called assert _deploy_command.call_args[1] == { "pipeline_script_path": "debug_pipeline.py", @@ -101,11 +107,25 @@ def test_invoke_deploy_mock(script_runner: ScriptRunner) -> None: "command": "deploy", "schedule": "@daily", "run_manually": True, - "run_on_push": False + "run_on_push": False, } _deploy_command.reset_mock() - script_runner.run(['dlt', 'deploy', 'debug_pipeline.py', 'github-action', '--schedule', '@daily', '--location', 'folder', '--branch', 'branch', '--run-on-push']) + script_runner.run( + [ + "dlt", + "deploy", + "debug_pipeline.py", + "github-action", + "--schedule", + "@daily", + "--location", + "folder", + "--branch", + "branch", + "--run-on-push", + ] + ) assert _deploy_command.called assert _deploy_command.call_args[1] == { "pipeline_script_path": "debug_pipeline.py", @@ -115,17 +135,17 @@ def test_invoke_deploy_mock(script_runner: ScriptRunner) -> None: "command": "deploy", "schedule": "@daily", "run_manually": True, - "run_on_push": True + "run_on_push": True, } # no schedule fails _deploy_command.reset_mock() - result = script_runner.run(['dlt', 'deploy', 'debug_pipeline.py', 'github-action']) + result = script_runner.run(["dlt", "deploy", "debug_pipeline.py", "github-action"]) assert not _deploy_command.called assert result.returncode != 0 assert "the following arguments are required: --schedule" in result.stderr # airflow without schedule works _deploy_command.reset_mock() - result = script_runner.run(['dlt', 'deploy', 'debug_pipeline.py', 'airflow-composer']) + result = script_runner.run(["dlt", "deploy", "debug_pipeline.py", "airflow-composer"]) assert _deploy_command.called assert result.returncode == 0 assert _deploy_command.call_args[1] == { @@ -134,11 +154,13 @@ def test_invoke_deploy_mock(script_runner: ScriptRunner) -> None: "repo_location": "https://github.com/dlt-hub/dlt-deploy-template.git", "branch": None, "command": "deploy", - 'secrets_format': 'toml' + "secrets_format": "toml", } # env secrets format _deploy_command.reset_mock() - result = script_runner.run(['dlt', 'deploy', 'debug_pipeline.py', 'airflow-composer', "--secrets-format", "env"]) + result = script_runner.run( + ["dlt", "deploy", "debug_pipeline.py", "airflow-composer", "--secrets-format", "env"] + ) assert _deploy_command.called assert result.returncode == 0 assert _deploy_command.call_args[1] == { @@ -147,5 +169,5 @@ def test_invoke_deploy_mock(script_runner: ScriptRunner) -> None: "repo_location": "https://github.com/dlt-hub/dlt-deploy-template.git", "branch": None, "command": "deploy", - 'secrets_format': 'env' + "secrets_format": "env", } diff --git a/tests/cli/common/test_telemetry_command.py b/tests/cli/common/test_telemetry_command.py index 4a3a0f4be1..cc9bb0f534 100644 --- a/tests/cli/common/test_telemetry_command.py +++ b/tests/cli/common/test_telemetry_command.py @@ -1,23 +1,22 @@ -import pytest +import contextlib import io import os -import contextlib from typing import Any from unittest.mock import patch +import pytest +from tests.utils import patch_random_home_dir, start_test_telemetry, test_storage + +from dlt.cli.telemetry_command import change_telemetry_status_command, telemetry_status_command +from dlt.cli.utils import track_command from dlt.common.configuration.container 
import Container from dlt.common.configuration.paths import DOT_DLT -from dlt.common.configuration.providers import ConfigTomlProvider, CONFIG_TOML +from dlt.common.configuration.providers import CONFIG_TOML, ConfigTomlProvider from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext from dlt.common.storages import FileStorage from dlt.common.typing import DictStrAny from dlt.common.utils import set_working_dir -from dlt.cli.utils import track_command -from dlt.cli.telemetry_command import telemetry_status_command, change_telemetry_status_command - -from tests.utils import patch_random_home_dir, start_test_telemetry, test_storage - def test_main_telemetry_command(test_storage: FileStorage) -> None: # home dir is patched to TEST_STORAGE, create project dir @@ -30,7 +29,12 @@ def _initial_providers(): glob_ctx = ConfigProvidersContext() glob_ctx.providers = _initial_providers() - with set_working_dir(test_storage.make_full_path("project")), Container().injectable_context(glob_ctx), patch("dlt.common.configuration.specs.config_providers_context.ConfigProvidersContext.initial_providers", _initial_providers): + with set_working_dir(test_storage.make_full_path("project")), Container().injectable_context( + glob_ctx + ), patch( + "dlt.common.configuration.specs.config_providers_context.ConfigProvidersContext.initial_providers", + _initial_providers, + ): # no config files: status is ON with io.StringIO() as buf, contextlib.redirect_stdout(buf): telemetry_status_command() @@ -75,7 +79,6 @@ def _initial_providers(): def test_command_instrumentation() -> None: - @track_command("instrument_ok", False, "in_ok_param", "in_ok_param_2") def instrument_ok(in_ok_param: str, in_ok_param_2: int) -> int: return 0 @@ -126,7 +129,15 @@ def instrument_raises_2(in_raises_2: bool) -> int: def test_instrumentation_wrappers() -> None: - from dlt.cli._dlt import init_command_wrapper, list_verified_sources_command_wrapper, DEFAULT_VERIFIED_SOURCES_REPO, pipeline_command_wrapper, deploy_command_wrapper, COMMAND_DEPLOY_REPO_LOCATION, DeploymentMethods + from dlt.cli._dlt import ( + COMMAND_DEPLOY_REPO_LOCATION, + DEFAULT_VERIFIED_SOURCES_REPO, + DeploymentMethods, + deploy_command_wrapper, + init_command_wrapper, + list_verified_sources_command_wrapper, + pipeline_command_wrapper, + ) from dlt.common.exceptions import UnknownDestinationModule with patch("dlt.common.runtime.segment.before_send", _mock_before_send): @@ -155,16 +166,22 @@ def test_instrumentation_wrappers() -> None: # assert msg["properties"]["operation"] == "list" SENT_ITEMS.clear() - deploy_command_wrapper("list.py", DeploymentMethods.github_actions.value, COMMAND_DEPLOY_REPO_LOCATION, schedule="* * * * *") + deploy_command_wrapper( + "list.py", + DeploymentMethods.github_actions.value, + COMMAND_DEPLOY_REPO_LOCATION, + schedule="* * * * *", + ) msg = SENT_ITEMS[0] assert msg["event"] == "command_deploy" assert msg["properties"]["deployment_method"] == DeploymentMethods.github_actions.value assert msg["properties"]["success"] is False - SENT_ITEMS = [] + + def _mock_before_send(event: DictStrAny, _unused_hint: Any = None) -> DictStrAny: SENT_ITEMS.append(event) # do not send this - return None \ No newline at end of file + return None diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py index e3a47f6202..df6a8846fa 100644 --- a/tests/cli/conftest.py +++ b/tests/cli/conftest.py @@ -1 +1 @@ -from tests.utils import preserve_environ, autouse_test_storage, unload_modules, wipe_pipeline \ No newline at end of file +from 
tests.utils import autouse_test_storage, preserve_environ, unload_modules, wipe_pipeline diff --git a/tests/cli/test_config_toml_writer.py b/tests/cli/test_config_toml_writer.py index 5d08b23c05..8a769f6307 100644 --- a/tests/cli/test_config_toml_writer.py +++ b/tests/cli/test_config_toml_writer.py @@ -1,8 +1,9 @@ -from typing import Optional, Final +from typing import Final, Optional + import pytest import tomlkit -from dlt.cli.config_toml_writer import write_value, WritableConfigValue, write_values +from dlt.cli.config_toml_writer import WritableConfigValue, write_value, write_values EXAMPLE_COMMENT = "# please set me up!" @@ -15,10 +16,24 @@ def example_toml(): def test_write_value(example_toml): toml_table = example_toml - write_value(toml_table, "species", str, overwrite_existing=True, default_value="Homo sapiens", is_default_of_interest=True) + write_value( + toml_table, + "species", + str, + overwrite_existing=True, + default_value="Homo sapiens", + is_default_of_interest=True, + ) assert toml_table["species"] == "Homo sapiens" - write_value(toml_table, "species", str, overwrite_existing=False, default_value="Mus musculus", is_default_of_interest=True) + write_value( + toml_table, + "species", + str, + overwrite_existing=False, + default_value="Mus musculus", + is_default_of_interest=True, + ) assert toml_table["species"] == "Homo sapiens" # Test with is_default_of_interest=True and non-optional, non-final hint @@ -26,24 +41,42 @@ def test_write_value(example_toml): assert toml_table["species"] == "species" # Test with is_default_of_interest=False and non-optional, non-final hint, and no default - write_value(toml_table, "population", int, overwrite_existing=True, is_default_of_interest=False) + write_value( + toml_table, "population", int, overwrite_existing=True, is_default_of_interest=False + ) # non default get typed example value assert "population" in toml_table # Test with optional hint - write_value(toml_table, "habitat", Optional[str], overwrite_existing=True, is_default_of_interest=False) + write_value( + toml_table, "habitat", Optional[str], overwrite_existing=True, is_default_of_interest=False + ) assert "habitat" not in toml_table # test with optional hint of interest - write_value(toml_table, "habitat", Optional[str], overwrite_existing=True, is_default_of_interest=True) + write_value( + toml_table, "habitat", Optional[str], overwrite_existing=True, is_default_of_interest=True + ) assert "habitat" in toml_table # Test with final hint - write_value(toml_table, "immutable_trait", Final[str], overwrite_existing=True, is_default_of_interest=False) + write_value( + toml_table, + "immutable_trait", + Final[str], + overwrite_existing=True, + is_default_of_interest=False, + ) assert "immutable_trait" not in toml_table # Test with final hint of interest - write_value(toml_table, "immutable_trait", Final[str], overwrite_existing=True, is_default_of_interest=True) + write_value( + toml_table, + "immutable_trait", + Final[str], + overwrite_existing=True, + is_default_of_interest=True, + ) assert "immutable_trait" in toml_table @@ -61,7 +94,9 @@ def test_write_values(example_toml): new_values = [ WritableConfigValue("species", str, "Canis lupus", ("taxonomy", "genus")), - WritableConfigValue("species", str, "Canis lupus familiaris", ("taxonomy", "genus", "subgenus")), + WritableConfigValue( + "species", str, "Canis lupus familiaris", ("taxonomy", "genus", "subgenus") + ), WritableConfigValue("genome_size", float, 2.8, ("genomic_info",)), ] write_values(example_toml, new_values, 
overwrite_existing=False) @@ -118,7 +153,10 @@ def test_write_values_without_defaults(example_toml): assert example_toml["animal_info"]["is_animal"] is True assert example_toml["genomic_info"]["chromosome_data"]["chromosomes"] == ["a", "b", "c"] - assert example_toml["genomic_info"]["chromosome_data"]["chromosomes"].trivia.comment == EXAMPLE_COMMENT + assert ( + example_toml["genomic_info"]["chromosome_data"]["chromosomes"].trivia.comment + == EXAMPLE_COMMENT + ) assert example_toml["genomic_info"]["gene_data"]["genes"] == {"key": "value"} - assert example_toml["genomic_info"]["gene_data"]["genes"].trivia.comment == EXAMPLE_COMMENT \ No newline at end of file + assert example_toml["genomic_info"]["gene_data"]["genes"].trivia.comment == EXAMPLE_COMMENT diff --git a/tests/cli/test_deploy_command.py b/tests/cli/test_deploy_command.py index 3fedd2daed..e4b1c81327 100644 --- a/tests/cli/test_deploy_command.py +++ b/tests/cli/test_deploy_command.py @@ -1,51 +1,62 @@ -import os -import io import contextlib +import io +import os import shutil import tempfile from subprocess import CalledProcessError -from git import InvalidGitRepositoryError, NoSuchPathError + import pytest +from git import InvalidGitRepositoryError, NoSuchPathError +from tests.utils import TEST_STORAGE_ROOT, test_storage import dlt - +from dlt.cli import _dlt, deploy_command, echo +from dlt.cli.deploy_command_helpers import get_schedule_description +from dlt.cli.exceptions import CliCommandException from dlt.common.runners import Venv from dlt.common.storages.file_storage import FileStorage from dlt.common.typing import StrAny from dlt.common.utils import set_working_dir - -from dlt.cli import deploy_command, _dlt, echo -from dlt.cli.exceptions import CliCommandException from dlt.pipeline.exceptions import CannotRestorePipelineException -from dlt.cli.deploy_command_helpers import get_schedule_description - -from tests.utils import TEST_STORAGE_ROOT, test_storage - DEPLOY_PARAMS = [ ("github-action", {"schedule": "*/30 * * * *", "run_on_push": True, "run_manually": True}), ("airflow-composer", {"secrets_format": "toml"}), ("airflow-composer", {"secrets_format": "env"}), - ] +] @pytest.mark.parametrize("deployment_method,deployment_args", DEPLOY_PARAMS) -def test_deploy_command_no_repo(test_storage: FileStorage, deployment_method: str, deployment_args: StrAny) -> None: +def test_deploy_command_no_repo( + test_storage: FileStorage, deployment_method: str, deployment_args: StrAny +) -> None: pipeline_wf = tempfile.mkdtemp() shutil.copytree("tests/cli/cases/deploy_pipeline", pipeline_wf, dirs_exist_ok=True) with set_working_dir(pipeline_wf): # we do not have repo with pytest.raises(InvalidGitRepositoryError): - deploy_command.deploy_command("debug_pipeline.py", deployment_method, deploy_command.COMMAND_DEPLOY_REPO_LOCATION, **deployment_args) + deploy_command.deploy_command( + "debug_pipeline.py", + deployment_method, + deploy_command.COMMAND_DEPLOY_REPO_LOCATION, + **deployment_args + ) # test wrapper - rc = _dlt.deploy_command_wrapper("debug_pipeline.py", deployment_method, deploy_command.COMMAND_DEPLOY_REPO_LOCATION, **deployment_args) + rc = _dlt.deploy_command_wrapper( + "debug_pipeline.py", + deployment_method, + deploy_command.COMMAND_DEPLOY_REPO_LOCATION, + **deployment_args + ) assert rc == -3 @pytest.mark.parametrize("deployment_method,deployment_args", DEPLOY_PARAMS) -def test_deploy_command(test_storage: FileStorage, deployment_method: str, deployment_args: StrAny) -> None: +def test_deploy_command( + test_storage: 
FileStorage, deployment_method: str, deployment_args: StrAny +) -> None: # drop pipeline p = dlt.pipeline(pipeline_name="debug_pipeline") p._wipe_working_folder() @@ -53,22 +64,42 @@ def test_deploy_command(test_storage: FileStorage, deployment_method: str, deplo shutil.copytree("tests/cli/cases/deploy_pipeline", TEST_STORAGE_ROOT, dirs_exist_ok=True) with set_working_dir(TEST_STORAGE_ROOT): - from git import Repo, Remote + from git import Remote, Repo # we have a repo without git origin with Repo.init(".") as repo: # test no origin with pytest.raises(CliCommandException) as py_ex: - deploy_command.deploy_command("debug_pipeline.py", deployment_method, deploy_command.COMMAND_DEPLOY_REPO_LOCATION, **deployment_args) + deploy_command.deploy_command( + "debug_pipeline.py", + deployment_method, + deploy_command.COMMAND_DEPLOY_REPO_LOCATION, + **deployment_args + ) assert "Your current repository has no origin set" in py_ex.value.args[0] - rc = _dlt.deploy_command_wrapper("debug_pipeline.py", deployment_method, deploy_command.COMMAND_DEPLOY_REPO_LOCATION, **deployment_args) + rc = _dlt.deploy_command_wrapper( + "debug_pipeline.py", + deployment_method, + deploy_command.COMMAND_DEPLOY_REPO_LOCATION, + **deployment_args + ) assert rc == -5 # we have a repo that was never run Remote.create(repo, "origin", "git@github.com:rudolfix/dlt-cmd-test-2.git") with pytest.raises(CannotRestorePipelineException): - deploy_command.deploy_command("debug_pipeline.py", deployment_method, deploy_command.COMMAND_DEPLOY_REPO_LOCATION, **deployment_args) - rc = _dlt.deploy_command_wrapper("debug_pipeline.py", deployment_method, deploy_command.COMMAND_DEPLOY_REPO_LOCATION, **deployment_args) + deploy_command.deploy_command( + "debug_pipeline.py", + deployment_method, + deploy_command.COMMAND_DEPLOY_REPO_LOCATION, + **deployment_args + ) + rc = _dlt.deploy_command_wrapper( + "debug_pipeline.py", + deployment_method, + deploy_command.COMMAND_DEPLOY_REPO_LOCATION, + **deployment_args + ) assert rc == -2 # run the script with wrong credentials (it is postgres there) @@ -80,9 +111,19 @@ def test_deploy_command(test_storage: FileStorage, deployment_method: str, deplo venv.run_script("debug_pipeline.py") # print(py_ex.value.output) with pytest.raises(deploy_command.PipelineWasNotRun) as py_ex: - deploy_command.deploy_command("debug_pipeline.py", deployment_method, deploy_command.COMMAND_DEPLOY_REPO_LOCATION, **deployment_args) + deploy_command.deploy_command( + "debug_pipeline.py", + deployment_method, + deploy_command.COMMAND_DEPLOY_REPO_LOCATION, + **deployment_args + ) assert "The last pipeline run ended with error" in py_ex.value.args[0] - rc = _dlt.deploy_command_wrapper("debug_pipeline.py", deployment_method, deploy_command.COMMAND_DEPLOY_REPO_LOCATION, **deployment_args) + rc = _dlt.deploy_command_wrapper( + "debug_pipeline.py", + deployment_method, + deploy_command.COMMAND_DEPLOY_REPO_LOCATION, + **deployment_args + ) assert rc == -2 os.environ["DESTINATION__POSTGRES__CREDENTIALS"] = pg_credentials @@ -103,8 +144,8 @@ def test_deploy_command(test_storage: FileStorage, deployment_method: str, deplo _out = buf.getvalue() print(_out) # make sure our secret and config values are all present - assert 'api_key_9x3ehash' in _out - assert 'dlt_data' in _out + assert "api_key_9x3ehash" in _out + assert "dlt_data" in _out if "schedule" in deployment_args: assert get_schedule_description(deployment_args["schedule"]) secrets_format = deployment_args.get("secrets_format", "env") @@ -115,8 +156,17 @@ def 
test_deploy_command(test_storage: FileStorage, deployment_method: str, deplo # non existing script name with pytest.raises(NoSuchPathError): - deploy_command.deploy_command("no_pipeline.py", deployment_method, deploy_command.COMMAND_DEPLOY_REPO_LOCATION, **deployment_args) + deploy_command.deploy_command( + "no_pipeline.py", + deployment_method, + deploy_command.COMMAND_DEPLOY_REPO_LOCATION, + **deployment_args + ) with echo.always_choose(False, always_choose_value=True): - rc = _dlt.deploy_command_wrapper("no_pipeline.py", deployment_method, deploy_command.COMMAND_DEPLOY_REPO_LOCATION, **deployment_args) + rc = _dlt.deploy_command_wrapper( + "no_pipeline.py", + deployment_method, + deploy_command.COMMAND_DEPLOY_REPO_LOCATION, + **deployment_args + ) assert rc == -4 - diff --git a/tests/cli/test_init_command.py b/tests/cli/test_init_command.py index 36bfdc37a8..c8c7d8d59b 100644 --- a/tests/cli/test_init_command.py +++ b/tests/cli/test_init_command.py @@ -1,38 +1,42 @@ -import io -from copy import deepcopy +import contextlib import hashlib +import io import os -import contextlib +import re +from copy import deepcopy from subprocess import CalledProcessError -from typing import Any, List, Tuple, Optional -from hexbytes import HexBytes -import pytest +from typing import Any, List, Optional, Tuple from unittest import mock -import re -from packaging.requirements import Requirement +import pytest +from hexbytes import HexBytes +from packaging.requirements import Requirement +from tests.cli.utils import ( + cloned_init_repo, + echo_default_choice, + get_project_files, + get_repo_dir, + project_files, + repo_dir, +) +from tests.common.utils import modify_and_commit_file +from tests.utils import IMPLEMENTED_DESTINATIONS, clean_test_storage import dlt - +from dlt.cli import echo, init_command +from dlt.cli.exceptions import CliCommandException +from dlt.cli.init_command import SOURCES_MODULE_NAME, _select_source_files, files_ops +from dlt.cli.init_command import utils as cli_utils +from dlt.cli.requirements import SourceRequirements from dlt.common import git from dlt.common.configuration.paths import make_dlt_settings_path from dlt.common.configuration.providers import CONFIG_TOML, SECRETS_TOML, SecretsTomlProvider from dlt.common.runners import Venv -from dlt.common.storages.file_storage import FileStorage from dlt.common.source import _SOURCES +from dlt.common.storages.file_storage import FileStorage from dlt.common.utils import set_working_dir - - -from dlt.cli import init_command, echo -from dlt.cli.init_command import SOURCES_MODULE_NAME, utils as cli_utils, files_ops, _select_source_files -from dlt.cli.exceptions import CliCommandException -from dlt.cli.requirements import SourceRequirements -from dlt.reflection.script_visitor import PipelineScriptVisitor from dlt.reflection import names as n - -from tests.cli.utils import echo_default_choice, repo_dir, project_files, cloned_init_repo, get_repo_dir, get_project_files -from tests.common.utils import modify_and_commit_file -from tests.utils import IMPLEMENTED_DESTINATIONS, clean_test_storage +from dlt.reflection.script_visitor import PipelineScriptVisitor def get_verified_source_candidates(repo_dir: str) -> List[str]: @@ -83,7 +87,9 @@ def test_init_command_chess_verified_source(repo_dir: str, project_files: FileSt print(e) # now run the pipeline - os.environ.pop("DESTINATION__DUCKDB__CREDENTIALS", None) # settings from local project (secrets.toml etc.) 
+ os.environ.pop( + "DESTINATION__DUCKDB__CREDENTIALS", None + ) # settings from local project (secrets.toml etc.) venv = Venv.restore_current() try: print(venv.run_script("chess_pipeline.py")) @@ -105,7 +111,9 @@ def test_init_list_verified_pipelines(repo_dir: str, project_files: FileStorage) init_command.list_verified_sources_command(repo_dir) -def test_init_list_verified_pipelines_update_warning(repo_dir: str, project_files: FileStorage) -> None: +def test_init_list_verified_pipelines_update_warning( + repo_dir: str, project_files: FileStorage +) -> None: """Sources listed include a warning if a different dlt version is required""" with mock.patch.object(SourceRequirements, "current_dlt_version", return_value="0.0.1"): with io.StringIO() as buf, contextlib.redirect_stdout(buf): @@ -121,7 +129,7 @@ def test_init_list_verified_pipelines_update_warning(repo_dir: str, project_file assert match # Try parsing the printed requiremnt string to verify it's valid parsed_requirement = Requirement(match.group(1)) - assert '0.0.1' not in parsed_requirement.specifier + assert "0.0.1" not in parsed_requirement.specifier def test_init_all_verified_sources_together(repo_dir: str, project_files: FileStorage) -> None: @@ -166,8 +174,10 @@ def test_init_all_verified_sources_isolated(cloned_init_repo: FileStorage) -> No assert_index_version_constraint(files, candidate) -@pytest.mark.parametrize('destination_name', IMPLEMENTED_DESTINATIONS) -def test_init_all_destinations(destination_name: str, project_files: FileStorage, repo_dir: str) -> None: +@pytest.mark.parametrize("destination_name", IMPLEMENTED_DESTINATIONS) +def test_init_all_destinations( + destination_name: str, project_files: FileStorage, repo_dir: str +) -> None: pipeline_name = f"generic_{destination_name}_pipeline" init_command.init_command(pipeline_name, destination_name, True, repo_dir) assert_init_files(project_files, pipeline_name, destination_name) @@ -189,7 +199,9 @@ def test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) sources_storage.delete(del_file_path) source_files = files_ops.get_verified_source_files(sources_storage, "pipedrive") - remote_index = files_ops.get_remote_source_index(sources_storage.storage_path, source_files.files, ">=0.3.5") + remote_index = files_ops.get_remote_source_index( + sources_storage.storage_path, source_files.files, ">=0.3.5" + ) assert mod_file_path in remote_index["files"] assert remote_index["is_dirty"] is True assert remote_index["files"][mod_file_path]["sha3_256"] == new_content_hash @@ -200,7 +212,7 @@ def test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) new, modified, deleted = files_ops.gen_index_diff(local_index, remote_index) # remote file entry in new assert new[new_file_path] == remote_index["files"][new_file_path] - #no git sha yet + # no git sha yet assert new[new_file_path]["git_sha"] is None # remote file entry in modified assert modified[mod_file_path] == remote_index["files"][mod_file_path] @@ -210,7 +222,9 @@ def test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) assert deleted[del_file_path] == local_index["files"][del_file_path] # get conflicts - conflict_modified, conflict_deleted = files_ops.find_conflict_files(local_index, new, modified, deleted, project_files) + conflict_modified, conflict_deleted = files_ops.find_conflict_files( + local_index, new, modified, deleted, project_files + ) assert conflict_modified == [] assert conflict_deleted == [] @@ -231,30 +245,40 @@ def 
test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) sources_storage.save(mod_file_path_2, local_content) local_index = files_ops.load_verified_sources_local_index("pipedrive") source_files = files_ops.get_verified_source_files(sources_storage, "pipedrive") - remote_index = files_ops.get_remote_source_index(sources_storage.storage_path, source_files.files, ">=0.3.5") + remote_index = files_ops.get_remote_source_index( + sources_storage.storage_path, source_files.files, ">=0.3.5" + ) new, modified, deleted = files_ops.gen_index_diff(local_index, remote_index) assert mod_file_path_2 in new - conflict_modified, conflict_deleted = files_ops.find_conflict_files(local_index, new, modified, deleted, project_files) + conflict_modified, conflict_deleted = files_ops.find_conflict_files( + local_index, new, modified, deleted, project_files + ) assert set(conflict_modified) == set([mod_file_path, new_file_path]) assert set(conflict_deleted) == set([del_file_path]) modified.update(new) # resolve conflicts in three different ways # skip option (the default) - res, sel_modified, sel_deleted = _select_source_files("pipedrive", deepcopy(modified), deepcopy(deleted), conflict_modified, conflict_deleted) + res, sel_modified, sel_deleted = _select_source_files( + "pipedrive", deepcopy(modified), deepcopy(deleted), conflict_modified, conflict_deleted + ) # noting is written, including non-conflicting file assert res == "s" assert sel_modified == {} assert sel_deleted == {} # Apply option - local changes will be lost with echo.always_choose(False, "a"): - res, sel_modified, sel_deleted = _select_source_files("pipedrive", deepcopy(modified), deepcopy(deleted), conflict_modified, conflict_deleted) + res, sel_modified, sel_deleted = _select_source_files( + "pipedrive", deepcopy(modified), deepcopy(deleted), conflict_modified, conflict_deleted + ) assert res == "a" assert sel_modified == modified assert sel_deleted == deleted # merge only non conflicting changes are applied with echo.always_choose(False, "m"): - res, sel_modified, sel_deleted = _select_source_files("pipedrive", deepcopy(modified), deepcopy(deleted), conflict_modified, conflict_deleted) + res, sel_modified, sel_deleted = _select_source_files( + "pipedrive", deepcopy(modified), deepcopy(deleted), conflict_modified, conflict_deleted + ) assert res == "m" assert len(sel_modified) == 1 and mod_file_path_2 in sel_modified assert sel_deleted == {} @@ -264,18 +288,26 @@ def test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) sources_storage.save(mod_file_path, local_content) project_files.delete(del_file_path) source_files = files_ops.get_verified_source_files(sources_storage, "pipedrive") - remote_index = files_ops.get_remote_source_index(sources_storage.storage_path, source_files.files, ">=0.3.5") + remote_index = files_ops.get_remote_source_index( + sources_storage.storage_path, source_files.files, ">=0.3.5" + ) new, modified, deleted = files_ops.gen_index_diff(local_index, remote_index) - conflict_modified, conflict_deleted = files_ops.find_conflict_files(local_index, new, modified, deleted, project_files) + conflict_modified, conflict_deleted = files_ops.find_conflict_files( + local_index, new, modified, deleted, project_files + ) assert conflict_modified == [] assert conflict_deleted == [] # generate a conflict by deleting file locally that is modified on remote project_files.delete(mod_file_path) source_files = files_ops.get_verified_source_files(sources_storage, "pipedrive") - remote_index = 
files_ops.get_remote_source_index(sources_storage.storage_path, source_files.files, ">=0.3.5") + remote_index = files_ops.get_remote_source_index( + sources_storage.storage_path, source_files.files, ">=0.3.5" + ) new, modified, deleted = files_ops.gen_index_diff(local_index, remote_index) - conflict_modified, conflict_deleted = files_ops.find_conflict_files(local_index, new, modified, deleted, project_files) + conflict_modified, conflict_deleted = files_ops.find_conflict_files( + local_index, new, modified, deleted, project_files + ) assert conflict_modified == [mod_file_path] @@ -306,8 +338,14 @@ def test_init_code_update_no_conflict(repo_dir: str, project_files: FileStorage) assert new_local_index["is_dirty"] is False assert new_local_index["last_commit_sha"] == commit.hexsha assert new_local_index["files"][mod_local_path]["commit_sha"] == commit.hexsha - assert new_local_index["files"][mod_local_path]["sha3_256"] == hashlib.sha3_256(bytes(new_content, encoding="ascii")).hexdigest() - assert new_local_index["files"][mod_local_path]["git_sha"] != local_index["files"][mod_local_path]["git_sha"] + assert ( + new_local_index["files"][mod_local_path]["sha3_256"] + == hashlib.sha3_256(bytes(new_content, encoding="ascii")).hexdigest() + ) + assert ( + new_local_index["files"][mod_local_path]["git_sha"] + != local_index["files"][mod_local_path]["git_sha"] + ) # all the other files must keep the old hashes for old_f, new_f in zip(local_index["files"].items(), new_local_index["files"].items()): # assert new_f[1]["commit_sha"] == commit.hexsha @@ -349,7 +387,9 @@ def test_init_code_update_no_conflict(repo_dir: str, project_files: FileStorage) @pytest.mark.parametrize("resolution", ["s", "a", "m"]) -def test_init_code_update_conflict(repo_dir: str, project_files: FileStorage, resolution: str) -> None: +def test_init_code_update_conflict( + repo_dir: str, project_files: FileStorage, resolution: str +) -> None: init_command.init_command("pipedrive", "duckdb", False, repo_dir) repo_storage = FileStorage(repo_dir) mod_local_path = os.path.join("pipedrive", "__init__.py") @@ -406,12 +446,16 @@ def test_init_requirements_text(repo_dir: str, project_files: FileStorage) -> No assert "pip3 install" in _out -def test_pipeline_template_sources_in_single_file(repo_dir: str, project_files: FileStorage) -> None: +def test_pipeline_template_sources_in_single_file( + repo_dir: str, project_files: FileStorage +) -> None: init_command.init_command("debug_pipeline", "bigquery", False, repo_dir) # _SOURCES now contains the sources from pipeline.py which simulates loading from two places with pytest.raises(CliCommandException) as cli_ex: init_command.init_command("generic_pipeline", "redshift", True, repo_dir) - assert "In init scripts you must declare all sources and resources in single file." in str(cli_ex.value) + assert "In init scripts you must declare all sources and resources in single file." in str( + cli_ex.value + ) def test_incompatible_dlt_version_warning(repo_dir: str, project_files: FileStorage) -> None: @@ -420,11 +464,18 @@ def test_incompatible_dlt_version_warning(repo_dir: str, project_files: FileStor init_command.init_command("facebook_ads", "bigquery", False, repo_dir) _out = buf.getvalue() - assert "WARNING: This pipeline requires a newer version of dlt than your installed version (0.1.1)." in _out + assert ( + "WARNING: This pipeline requires a newer version of dlt than your installed version" + " (0.1.1)." 
+ in _out + ) def assert_init_files( - project_files: FileStorage, pipeline_name: str, destination_name: str, dependency_destination: Optional[str] = None + project_files: FileStorage, + pipeline_name: str, + destination_name: str, + dependency_destination: Optional[str] = None, ) -> PipelineScriptVisitor: visitor, _ = assert_common_files(project_files, pipeline_name + ".py", destination_name) assert not project_files.has_folder(pipeline_name) @@ -437,7 +488,9 @@ def assert_requirements_txt(project_files: FileStorage, destination_name: str) - assert project_files.has_file(cli_utils.REQUIREMENTS_TXT) assert "dlt" in project_files.load(cli_utils.REQUIREMENTS_TXT) # dlt dependency specifies destination_name as extra - source_requirements = SourceRequirements.from_string(project_files.load(cli_utils.REQUIREMENTS_TXT)) + source_requirements = SourceRequirements.from_string( + project_files.load(cli_utils.REQUIREMENTS_TXT) + ) assert destination_name in source_requirements.dlt_requirement.extras # Check that atleast some version range is specified assert len(source_requirements.dlt_requirement.specifier) >= 1 @@ -447,11 +500,23 @@ def assert_index_version_constraint(project_files: FileStorage, source_name: str # check dlt version constraint in .sources index for given source matches the one in requirements.txt local_index = files_ops.load_verified_sources_local_index(source_name) index_constraint = local_index["dlt_version_constraint"] - assert index_constraint == SourceRequirements.from_string(project_files.load(cli_utils.REQUIREMENTS_TXT)).dlt_version_constraint() - - -def assert_source_files(project_files: FileStorage, source_name: str, destination_name: str, has_source_section: bool = True) -> Tuple[PipelineScriptVisitor, SecretsTomlProvider]: - visitor, secrets = assert_common_files(project_files, source_name + "_pipeline.py", destination_name) + assert ( + index_constraint + == SourceRequirements.from_string( + project_files.load(cli_utils.REQUIREMENTS_TXT) + ).dlt_version_constraint() + ) + + +def assert_source_files( + project_files: FileStorage, + source_name: str, + destination_name: str, + has_source_section: bool = True, +) -> Tuple[PipelineScriptVisitor, SecretsTomlProvider]: + visitor, secrets = assert_common_files( + project_files, source_name + "_pipeline.py", destination_name + ) assert project_files.has_folder(source_name) source_secrets = secrets.get_value(source_name, Any, None, source_name) if has_source_section: @@ -472,7 +537,9 @@ def assert_source_files(project_files: FileStorage, source_name: str, destinatio return visitor, secrets -def assert_common_files(project_files: FileStorage, pipeline_script: str, destination_name: str) -> Tuple[PipelineScriptVisitor, SecretsTomlProvider]: +def assert_common_files( + project_files: FileStorage, pipeline_script: str, destination_name: str +) -> Tuple[PipelineScriptVisitor, SecretsTomlProvider]: # cwd must be project files - otherwise assert won't work assert os.getcwd() == project_files.storage_path assert project_files.has_file(make_dlt_settings_path(SECRETS_TOML)) @@ -480,7 +547,9 @@ def assert_common_files(project_files: FileStorage, pipeline_script: str, destin assert project_files.has_file(".gitignore") assert project_files.has_file(pipeline_script) # inspect script - visitor = cli_utils.parse_init_script("test", project_files.load(pipeline_script), pipeline_script) + visitor = cli_utils.parse_init_script( + "test", project_files.load(pipeline_script), pipeline_script + ) # check destinations for args in 
visitor.known_calls[n.PIPELINE]: assert args.arguments["destination"].value == destination_name @@ -490,7 +559,13 @@ def assert_common_files(project_files: FileStorage, pipeline_script: str, destin # destination is there assert secrets.get_value(destination_name, Any, None, "destination") is not None # certain values are never there - for not_there in ["dataset_name", "destination_name", "default_schema_name", "as_staging", "staging_config"]: + for not_there in [ + "dataset_name", + "destination_name", + "default_schema_name", + "as_staging", + "staging_config", + ]: assert secrets.get_value(not_there, Any, None, "destination", destination_name)[0] is None return visitor, secrets diff --git a/tests/cli/test_pipeline_command.py b/tests/cli/test_pipeline_command.py index 1ffc0c66aa..eb0b44ceeb 100644 --- a/tests/cli/test_pipeline_command.py +++ b/tests/cli/test_pipeline_command.py @@ -1,16 +1,22 @@ +import contextlib import io import os -import contextlib from subprocess import CalledProcessError +from tests.cli.utils import ( + cloned_init_repo, + echo_default_choice, + get_project_files, + get_repo_dir, + project_files, + repo_dir, +) + import dlt +from dlt.cli import echo, init_command, pipeline_command from dlt.common.runners.venv import Venv from dlt.common.storages.file_storage import FileStorage -from dlt.cli import echo, init_command, pipeline_command - -from tests.cli.utils import echo_default_choice, repo_dir, project_files, cloned_init_repo, get_repo_dir, get_project_files - def test_pipeline_command_operations(repo_dir: str, project_files: FileStorage) -> None: init_command.init_command("chess", "duckdb", False, repo_dir) @@ -23,7 +29,9 @@ def test_pipeline_command_operations(repo_dir: str, project_files: FileStorage) print(e) # now run the pipeline - os.environ.pop("DESTINATION__DUCKDB__CREDENTIALS", None) # settings from local project (secrets.toml etc.) + os.environ.pop( + "DESTINATION__DUCKDB__CREDENTIALS", None + ) # settings from local project (secrets.toml etc.) 
venv = Venv.restore_current() try: print(venv.run_script("chess_pipeline.py")) @@ -113,7 +121,9 @@ def test_pipeline_command_operations(repo_dir: str, project_files: FileStorage) with io.StringIO() as buf, contextlib.redirect_stdout(buf): with echo.always_choose(False, True): - pipeline_command.pipeline_command("drop", "chess_pipeline", None, 0, resources=["players_games"]) + pipeline_command.pipeline_command( + "drop", "chess_pipeline", None, 0, resources=["players_games"] + ) _out = buf.getvalue() assert "Selected resource(s): ['players_games']" in _out @@ -124,9 +134,17 @@ def test_pipeline_command_operations(repo_dir: str, project_files: FileStorage) with io.StringIO() as buf, contextlib.redirect_stdout(buf): # Test sync destination and drop when local state is missing - pipeline._pipeline_storage.delete_folder('', recursively=True) + pipeline._pipeline_storage.delete_folder("", recursively=True) with echo.always_choose(False, True): - pipeline_command.pipeline_command("drop", "chess_pipeline", None, 0, destination=pipeline.destination, dataset_name=pipeline.dataset_name, resources=["players_profiles"]) + pipeline_command.pipeline_command( + "drop", + "chess_pipeline", + None, + 0, + destination=pipeline.destination, + dataset_name=pipeline.dataset_name, + resources=["players_profiles"], + ) _out = buf.getvalue() assert "could not be restored: the pipeline was not found in " in _out diff --git a/tests/cli/utils.py b/tests/cli/utils.py index 6490d09407..fd18220296 100644 --- a/tests/cli/utils.py +++ b/tests/cli/utils.py @@ -1,19 +1,17 @@ import os -import pytest import shutil +import pytest +from tests.utils import TEST_STORAGE_ROOT + +from dlt.cli import echo +from dlt.cli.init_command import DEFAULT_VERIFIED_SOURCES_REPO from dlt.common import git from dlt.common.pipeline import get_dlt_repos_dir -from dlt.common.storages.file_storage import FileStorage from dlt.common.source import _SOURCES +from dlt.common.storages.file_storage import FileStorage from dlt.common.utils import set_working_dir, uniq_id -from dlt.cli import echo -from dlt.cli.init_command import DEFAULT_VERIFIED_SOURCES_REPO - -from tests.utils import TEST_STORAGE_ROOT - - INIT_REPO_LOCATION = DEFAULT_VERIFIED_SOURCES_REPO INIT_REPO_BRANCH = "master" PROJECT_DIR = os.path.join(TEST_STORAGE_ROOT, "project") @@ -29,7 +27,9 @@ def echo_default_choice() -> None: @pytest.fixture(scope="module") def cloned_init_repo() -> FileStorage: - return git.get_fresh_repo_files(INIT_REPO_LOCATION, get_dlt_repos_dir(), branch=INIT_REPO_BRANCH) + return git.get_fresh_repo_files( + INIT_REPO_LOCATION, get_dlt_repos_dir(), branch=INIT_REPO_BRANCH + ) @pytest.fixture @@ -45,7 +45,9 @@ def project_files() -> FileStorage: def get_repo_dir(cloned_init_repo: FileStorage) -> str: - repo_dir = os.path.abspath(os.path.join(TEST_STORAGE_ROOT, f"verified_sources_repo_{uniq_id()}")) + repo_dir = os.path.abspath( + os.path.join(TEST_STORAGE_ROOT, f"verified_sources_repo_{uniq_id()}") + ) # copy the whole repo into TEST_STORAGE_ROOT shutil.copytree(cloned_init_repo.storage_path, repo_dir) return repo_dir diff --git a/tests/common/cases/modules/uniq_mod_121.py b/tests/common/cases/modules/uniq_mod_121.py index 893d08d178..c5fc15ef62 100644 --- a/tests/common/cases/modules/uniq_mod_121.py +++ b/tests/common/cases/modules/uniq_mod_121.py @@ -1,8 +1,11 @@ import inspect + from dlt.common.utils import get_module_name + def find_my_module(): pass + if __name__ == "__main__": print(get_module_name(inspect.getmodule(find_my_module))) diff --git 
a/tests/common/configuration/test_accessors.py b/tests/common/configuration/test_accessors.py index 6c01f66d97..55a7463112 100644 --- a/tests/common/configuration/test_accessors.py +++ b/tests/common/configuration/test_accessors.py @@ -1,23 +1,28 @@ import datetime # noqa: 251 from typing import Any + import pytest +from tests.common.configuration.utils import environment, toml_providers +from tests.utils import preserve_environ import dlt from dlt.common import json from dlt.common.configuration.exceptions import ConfigFieldMissingException - -from dlt.common.configuration.providers import EnvironProvider, ConfigTomlProvider, SecretsTomlProvider +from dlt.common.configuration.providers import ( + ConfigTomlProvider, + EnvironProvider, + SecretsTomlProvider, +) from dlt.common.configuration.resolve import resolve_configuration -from dlt.common.configuration.specs import GcpServiceAccountCredentialsWithoutDefaults, ConnectionStringCredentials +from dlt.common.configuration.specs import ( + ConnectionStringCredentials, + GcpServiceAccountCredentialsWithoutDefaults, +) from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext -from dlt.common.configuration.utils import get_resolved_traces, ResolvedValueTrace +from dlt.common.configuration.utils import ResolvedValueTrace, get_resolved_traces from dlt.common.runners.configuration import PoolRunnerConfiguration from dlt.common.typing import AnyType, TSecretValue - -from tests.utils import preserve_environ -from tests.common.configuration.utils import environment, toml_providers - RESOLVED_TRACES = get_resolved_traces() @@ -39,19 +44,29 @@ def test_getter_accessor(toml_providers: ConfigProvidersContext, environment: An environment["VALUE"] = "{SET" assert dlt.config["value"] == "{SET" - assert RESOLVED_TRACES[".value"] == ResolvedValueTrace("value", "{SET", None, AnyType, [], EnvironProvider().name, None) + assert RESOLVED_TRACES[".value"] == ResolvedValueTrace( + "value", "{SET", None, AnyType, [], EnvironProvider().name, None + ) assert dlt.secrets["value"] == "{SET" - assert RESOLVED_TRACES[".value"] == ResolvedValueTrace("value", "{SET", None, TSecretValue, [], EnvironProvider().name, None) + assert RESOLVED_TRACES[".value"] == ResolvedValueTrace( + "value", "{SET", None, TSecretValue, [], EnvironProvider().name, None + ) # get sectioned values assert dlt.config["typecheck.str_val"] == "test string" - assert RESOLVED_TRACES["typecheck.str_val"] == ResolvedValueTrace("str_val", "test string", None, AnyType, ["typecheck"], ConfigTomlProvider().name, None) + assert RESOLVED_TRACES["typecheck.str_val"] == ResolvedValueTrace( + "str_val", "test string", None, AnyType, ["typecheck"], ConfigTomlProvider().name, None + ) environment["DLT__THIS__VALUE"] = "embedded" assert dlt.config["dlt.this.value"] == "embedded" - assert RESOLVED_TRACES["dlt.this.value"] == ResolvedValueTrace("value", "embedded", None, AnyType, ["dlt", "this"], EnvironProvider().name, None) + assert RESOLVED_TRACES["dlt.this.value"] == ResolvedValueTrace( + "value", "embedded", None, AnyType, ["dlt", "this"], EnvironProvider().name, None + ) assert dlt.secrets["dlt.this.value"] == "embedded" - assert RESOLVED_TRACES["dlt.this.value"] == ResolvedValueTrace("value", "embedded", None, TSecretValue, ["dlt", "this"], EnvironProvider().name, None) + assert RESOLVED_TRACES["dlt.this.value"] == ResolvedValueTrace( + "value", "embedded", None, TSecretValue, ["dlt", "this"], EnvironProvider().name, None + ) def test_getter_auto_cast(toml_providers: 
ConfigProvidersContext, environment: Any) -> None: @@ -83,7 +98,7 @@ def test_getter_auto_cast(toml_providers: ConfigProvidersContext, environment: A assert dlt.config["value"] == {"a": 1} assert dlt.config["value"]["a"] == 1 # if not dict or list then original string must be returned, null is a JSON -> None - environment["VALUE"] = 'null' + environment["VALUE"] = "null" assert dlt.config["value"] == "null" # typed values are returned as they are @@ -91,11 +106,32 @@ def test_getter_auto_cast(toml_providers: ConfigProvidersContext, environment: A # access dict from toml services_json_dict = dlt.secrets["destination.bigquery"] - assert dlt.secrets["destination.bigquery"]["client_email"] == "loader@a7513.iam.gserviceaccount.com" - assert RESOLVED_TRACES["destination.bigquery"] == ResolvedValueTrace("bigquery", services_json_dict, None, TSecretValue, ["destination"], SecretsTomlProvider().name, None) + assert ( + dlt.secrets["destination.bigquery"]["client_email"] + == "loader@a7513.iam.gserviceaccount.com" + ) + assert RESOLVED_TRACES["destination.bigquery"] == ResolvedValueTrace( + "bigquery", + services_json_dict, + None, + TSecretValue, + ["destination"], + SecretsTomlProvider().name, + None, + ) # equivalent - assert dlt.secrets["destination.bigquery.client_email"] == "loader@a7513.iam.gserviceaccount.com" - assert RESOLVED_TRACES["destination.bigquery.client_email"] == ResolvedValueTrace("client_email", "loader@a7513.iam.gserviceaccount.com", None, TSecretValue, ["destination", "bigquery"], SecretsTomlProvider().name, None) + assert ( + dlt.secrets["destination.bigquery.client_email"] == "loader@a7513.iam.gserviceaccount.com" + ) + assert RESOLVED_TRACES["destination.bigquery.client_email"] == ResolvedValueTrace( + "client_email", + "loader@a7513.iam.gserviceaccount.com", + None, + TSecretValue, + ["destination", "bigquery"], + SecretsTomlProvider().name, + None, + ) def test_getter_accessor_typed(toml_providers: ConfigProvidersContext, environment: Any) -> None: @@ -104,7 +140,9 @@ def test_getter_accessor_typed(toml_providers: ConfigProvidersContext, environme # the typed version coerces the value into desired type, in this case "dict" -> "str" assert dlt.secrets.get("credentials", str) == credentials_str # note that trace keeps original value of "credentials" which was of dictionary type - assert RESOLVED_TRACES[".credentials"] == ResolvedValueTrace("credentials", json.loads(credentials_str), None, str, [], SecretsTomlProvider().name, None) + assert RESOLVED_TRACES[".credentials"] == ResolvedValueTrace( + "credentials", json.loads(credentials_str), None, str, [], SecretsTomlProvider().name, None + ) # unchanged type assert isinstance(dlt.secrets.get("credentials"), dict) # fail on type coercion @@ -117,7 +155,15 @@ def test_getter_accessor_typed(toml_providers: ConfigProvidersContext, environme credentials_str = "databricks+connector://token:@:443/?conn_timeout=15&search_path=a,b,c" c = dlt.secrets.get("databricks.credentials", ConnectionStringCredentials) # as before: the value in trace is the value coming from the provider (as is) - assert RESOLVED_TRACES["databricks.credentials"] == ResolvedValueTrace("credentials", credentials_str, None, ConnectionStringCredentials, ["databricks"], SecretsTomlProvider().name, ConnectionStringCredentials) + assert RESOLVED_TRACES["databricks.credentials"] == ResolvedValueTrace( + "credentials", + credentials_str, + None, + ConnectionStringCredentials, + ["databricks"], + SecretsTomlProvider().name, + ConnectionStringCredentials, + ) assert 
c.drivername == "databricks+connector" c = dlt.secrets.get("destination.credentials", GcpServiceAccountCredentialsWithoutDefaults) assert c.client_email == "loader@a7513.iam.gserviceaccount.com" @@ -142,12 +188,14 @@ def test_setter(toml_providers: ConfigProvidersContext, environment: Any) -> Non dlt.secrets["pipeline.new.credentials"] = {"api_key": "skjo87a7nnAAaa"} assert dlt.secrets["pipeline.new.credentials"] == {"api_key": "skjo87a7nnAAaa"} # check the toml directly - assert dlt.secrets.writable_provider._toml["pipeline"]["new"]["credentials"] == {"api_key": "skjo87a7nnAAaa"} + assert dlt.secrets.writable_provider._toml["pipeline"]["new"]["credentials"] == { + "api_key": "skjo87a7nnAAaa" + } # mod the config and use it to resolve the configuration dlt.config["pool"] = {"pool_type": "process", "workers": 21} - c = resolve_configuration(PoolRunnerConfiguration(), sections=("pool", )) - assert dict(c) == {"pool_type": "process", "workers": 21, 'run_sleep': 0.1} + c = resolve_configuration(PoolRunnerConfiguration(), sections=("pool",)) + assert dict(c) == {"pool_type": "process", "workers": 21, "run_sleep": 0.1} def test_secrets_separation(toml_providers: ConfigProvidersContext) -> None: @@ -161,13 +209,19 @@ def test_secrets_separation(toml_providers: ConfigProvidersContext) -> None: def test_access_injection(toml_providers: ConfigProvidersContext) -> None: - @dlt.source - def the_source(api_type=dlt.config.value, credentials: GcpServiceAccountCredentialsWithoutDefaults=dlt.secrets.value, databricks_creds: ConnectionStringCredentials=dlt.secrets.value): + def the_source( + api_type=dlt.config.value, + credentials: GcpServiceAccountCredentialsWithoutDefaults = dlt.secrets.value, + databricks_creds: ConnectionStringCredentials = dlt.secrets.value, + ): assert api_type == "REST" assert credentials.client_email == "loader@a7513.iam.gserviceaccount.com" assert databricks_creds.drivername == "databricks+connector" - return dlt.resource([1,2,3], name="data") + return dlt.resource([1, 2, 3], name="data") # inject first argument, the rest pass explicitly - the_source(credentials=dlt.secrets["destination.credentials"], databricks_creds=dlt.secrets["databricks.credentials"]) + the_source( + credentials=dlt.secrets["destination.credentials"], + databricks_creds=dlt.secrets["databricks.credentials"], + ) diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index 0e09e93285..11d12cbc34 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -1,52 +1,88 @@ -import pytest import datetime # noqa: I251 -from unittest.mock import patch from typing import Any, Dict, Final, List, Mapping, MutableMapping, NewType, Optional, Type, Union +from unittest.mock import patch -from dlt.common import json, pendulum, Decimal, Wei -from dlt.common.configuration.providers.provider import ConfigProvider -from dlt.common.configuration.specs.gcp_credentials import GcpServiceAccountCredentialsWithoutDefaults -from dlt.common.utils import custom_environ -from dlt.common.typing import AnyType, DictStrAny, StrAny, TSecretValue, extract_inner_type -from dlt.common.configuration.exceptions import ( - ConfigFieldMissingTypeHintException, ConfigFieldTypeHintNotSupported, - InvalidNativeValue, LookupTrace, ValueNotSecretException, UnmatchedConfigHintResolversException +import pytest +from tests.common.configuration.utils import ( + COERCIONS, + CoercionTestConfiguration, + MockProvider, + SecretConfiguration, + 
SecretCredentials, + SectionedConfiguration, + WithCredentialsConfiguration, + WrongConfiguration, + env_provider, + environment, + mock_provider, + reset_resolved_traces, ) -from dlt.common.configuration import configspec, ConfigFieldMissingException, ConfigValueCannotBeCoercedException, resolve, is_valid_hint, resolve_type -from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration, ConnectionStringCredentials -from dlt.common.configuration.providers import environ as environ_provider, toml -from dlt.common.configuration.utils import get_resolved_traces, ResolvedValueTrace, serialize_value, deserialize_value, add_config_dict_to_env, add_config_to_env - from tests.utils import preserve_environ -from tests.common.configuration.utils import ( - MockProvider, CoercionTestConfiguration, COERCIONS, SecretCredentials, WithCredentialsConfiguration, WrongConfiguration, SecretConfiguration, - SectionedConfiguration, environment, mock_provider, env_provider, reset_resolved_traces) + +from dlt.common import Decimal, Wei, json, pendulum +from dlt.common.configuration import ( + ConfigFieldMissingException, + ConfigValueCannotBeCoercedException, + configspec, + is_valid_hint, + resolve, + resolve_type, +) +from dlt.common.configuration.exceptions import ( + ConfigFieldMissingTypeHintException, + ConfigFieldTypeHintNotSupported, + InvalidNativeValue, + LookupTrace, + UnmatchedConfigHintResolversException, + ValueNotSecretException, +) +from dlt.common.configuration.providers import environ as environ_provider +from dlt.common.configuration.providers import toml +from dlt.common.configuration.providers.provider import ConfigProvider +from dlt.common.configuration.specs import ( + BaseConfiguration, + ConnectionStringCredentials, + RunConfiguration, +) +from dlt.common.configuration.specs.gcp_credentials import ( + GcpServiceAccountCredentialsWithoutDefaults, +) +from dlt.common.configuration.utils import ( + ResolvedValueTrace, + add_config_dict_to_env, + add_config_to_env, + deserialize_value, + get_resolved_traces, + serialize_value, +) +from dlt.common.typing import AnyType, DictStrAny, StrAny, TSecretValue, extract_inner_type +from dlt.common.utils import custom_environ INVALID_COERCIONS = { # 'STR_VAL': 'test string', # string always OK - 'int_val': "a12345", - 'bool_val': "not_bool", # bool overridden by string - that is the most common problem - 'list_val': {"2": 1, "3": 3.0}, - 'dict_val': "{'a': 1, 'b', '2'}", - 'bytes_val': 'Hello World!', - 'float_val': "invalid", + "int_val": "a12345", + "bool_val": "not_bool", # bool overridden by string - that is the most common problem + "list_val": {"2": 1, "3": 3.0}, + "dict_val": "{'a': 1, 'b', '2'}", + "bytes_val": "Hello World!", + "float_val": "invalid", "tuple_val": "{1:2}", "date_val": "01 May 2022", - "dec_val": True + "dec_val": True, } EXCEPTED_COERCIONS = { # allows to use int for float - 'float_val': 10, + "float_val": 10, # allows to use float for str - 'str_val': 10.0 + "str_val": 10.0, } COERCED_EXCEPTIONS = { # allows to use int for float - 'float_val': 10.0, + "float_val": 10.0, # allows to use float for str - 'str_val': "10.0" + "str_val": "10.0", } @@ -150,26 +186,26 @@ class ConfigWithDynamicType(BaseConfiguration): discriminator: str embedded_config: BaseConfiguration - @resolve_type('embedded_config') + @resolve_type("embedded_config") def resolve_embedded_type(self) -> Type[BaseConfiguration]: - if self.discriminator == 'a': + if self.discriminator == "a": return DynamicConfigA - elif self.discriminator == 'b': 
+ elif self.discriminator == "b": return DynamicConfigB return BaseConfiguration @configspec class ConfigWithInvalidDynamicType(BaseConfiguration): - @resolve_type('a') + @resolve_type("a") def resolve_a_type(self) -> Type[BaseConfiguration]: return DynamicConfigA - @resolve_type('b') + @resolve_type("b") def resolve_b_type(self) -> Type[BaseConfiguration]: return DynamicConfigB - @resolve_type('c') + @resolve_type("c") def resolve_c_type(self) -> Type[BaseConfiguration]: return DynamicConfigC @@ -179,13 +215,13 @@ class SubclassConfigWithDynamicType(ConfigWithDynamicType): is_number: bool dynamic_type_field: Any - @resolve_type('embedded_config') + @resolve_type("embedded_config") def resolve_embedded_type(self) -> Type[BaseConfiguration]: - if self.discriminator == 'c': + if self.discriminator == "c": return DynamicConfigC return super().resolve_embedded_type() - @resolve_type('dynamic_type_field') + @resolve_type("dynamic_type_field") def resolve_dynamic_type_field(self) -> Type[Union[int, str]]: if self.is_number: return int @@ -209,7 +245,9 @@ def test_initial_config_state() -> None: def test_set_default_config_value(environment: Any) -> None: # set from init method - c = resolve.resolve_configuration(InstrumentedConfiguration(head="h", tube=["a", "b"], heels="he")) + c = resolve.resolve_configuration( + InstrumentedConfiguration(head="h", tube=["a", "b"], heels="he") + ) assert c.to_native_representation() == "h>a>b>he" # set from native form c = resolve.resolve_configuration(InstrumentedConfiguration(), explicit_value="h>a>b>he") @@ -217,7 +255,10 @@ def test_set_default_config_value(environment: Any) -> None: assert c.tube == ["a", "b"] assert c.heels == "he" # set from dictionary - c = resolve.resolve_configuration(InstrumentedConfiguration(), explicit_value={"head": "h", "tube": ["tu", "be"], "heels": "xhe"}) + c = resolve.resolve_configuration( + InstrumentedConfiguration(), + explicit_value={"head": "h", "tube": ["tu", "be"], "heels": "xhe"}, + ) assert c.to_native_representation() == "h>tu>be>xhe" @@ -226,9 +267,14 @@ def test_explicit_values(environment: Any) -> None: environment["PIPELINE_NAME"] = "env name" environment["CREATED_VAL"] = "12837" # set explicit values and allow partial config - c = resolve.resolve_configuration(CoercionTestConfiguration(), - explicit_value={"pipeline_name": "initial name", "none_val": type(environment), "bytes_val": b"str"}, - accept_partial=True + c = resolve.resolve_configuration( + CoercionTestConfiguration(), + explicit_value={ + "pipeline_name": "initial name", + "none_val": type(environment), + "bytes_val": b"str", + }, + accept_partial=True, ) # explicit assert c.pipeline_name == "initial name" @@ -237,13 +283,17 @@ def test_explicit_values(environment: Any) -> None: assert c.none_val == type(environment) # unknown field in explicit value dict is ignored - c = resolve.resolve_configuration(CoercionTestConfiguration(), explicit_value={"created_val": "3343"}, accept_partial=True) + c = resolve.resolve_configuration( + CoercionTestConfiguration(), explicit_value={"created_val": "3343"}, accept_partial=True + ) assert "created_val" not in c def test_explicit_values_false_when_bool() -> None: # values like 0, [], "" all coerce to bool False - c = resolve.resolve_configuration(InstrumentedConfiguration(), explicit_value={"head": "", "tube": [], "heels": ""}) + c = resolve.resolve_configuration( + InstrumentedConfiguration(), explicit_value={"head": "", "tube": [], "heels": ""} + ) assert c.head == "" assert c.tube == [] assert c.heels == "" 
@@ -268,7 +318,6 @@ def test_default_values(environment: Any) -> None: def test_raises_on_final_value_change(environment: Any) -> None: - @configspec class FinalConfiguration(BaseConfiguration): pipeline_name: Final[str] = "comp" @@ -301,7 +350,10 @@ def test_explicit_native_always_skips_resolve(environment: Any) -> None: # explicit representation environment["INS"] = "h>a>b>he" - c = resolve.resolve_configuration(InstrumentedConfiguration(), explicit_value={"head": "h", "tube": ["tu", "be"], "heels": "uhe"}) + c = resolve.resolve_configuration( + InstrumentedConfiguration(), + explicit_value={"head": "h", "tube": ["tu", "be"], "heels": "uhe"}, + ) assert c.heels == "uhe" # also the native explicit value @@ -324,7 +376,10 @@ def test_skip_lookup_native_config_value_if_no_config_section(environment: Any) # the INSTRUMENTED is not looked up because InstrumentedConfiguration has no section with custom_environ({"INSTRUMENTED": "he>tu>u>be>h"}): with pytest.raises(ConfigFieldMissingException) as py_ex: - resolve.resolve_configuration(EmbeddedConfiguration(), explicit_value={"default": "set", "sectioned": {"password": "pwd"}}) + resolve.resolve_configuration( + EmbeddedConfiguration(), + explicit_value={"default": "set", "sectioned": {"password": "pwd"}}, + ) assert py_ex.value.spec_name == "InstrumentedConfiguration" assert py_ex.value.fields == ["head", "tube", "heels"] @@ -348,14 +403,28 @@ def test_on_resolved(environment: Any) -> None: def test_embedded_config(environment: Any) -> None: # resolve all embedded config, using explicit value for instrumented config and explicit dict for sectioned config - C = resolve.resolve_configuration(EmbeddedConfiguration(), explicit_value={"default": "set", "instrumented": "h>tu>be>xhe", "sectioned": {"password": "pwd"}}) + C = resolve.resolve_configuration( + EmbeddedConfiguration(), + explicit_value={ + "default": "set", + "instrumented": "h>tu>be>xhe", + "sectioned": {"password": "pwd"}, + }, + ) assert C.default == "set" assert C.instrumented.to_native_representation() == "h>tu>be>xhe" assert C.sectioned.password == "pwd" # resolve but providing values via env with custom_environ( - {"INSTRUMENTED__HEAD": "h", "INSTRUMENTED__TUBE": '["tu", "u", "be"]', "INSTRUMENTED__HEELS": "xhe", "SECTIONED__PASSWORD": "passwd", "DEFAULT": "DEF"}): + { + "INSTRUMENTED__HEAD": "h", + "INSTRUMENTED__TUBE": '["tu", "u", "be"]', + "INSTRUMENTED__HEELS": "xhe", + "SECTIONED__PASSWORD": "passwd", + "DEFAULT": "DEF", + } + ): C = resolve.resolve_configuration(EmbeddedConfiguration()) assert C.default == "DEF" assert C.instrumented.to_native_representation() == "h>tu>u>be>xhe" @@ -379,11 +448,23 @@ def test_embedded_config(environment: Any) -> None: with patch.object(InstrumentedConfiguration, "__section__", "instrumented"): with custom_environ({"INSTRUMENTED": "he>tu>u>be>h"}): with pytest.raises(RuntimeError): - resolve.resolve_configuration(EmbeddedConfiguration(), explicit_value={"default": "set", "sectioned": {"password": "pwd"}}) + resolve.resolve_configuration( + EmbeddedConfiguration(), + explicit_value={"default": "set", "sectioned": {"password": "pwd"}}, + ) # part via env part via explicit values - with custom_environ({"INSTRUMENTED__HEAD": "h", "INSTRUMENTED__TUBE": '["tu", "u", "be"]', "INSTRUMENTED__HEELS": "xhe"}): - C = resolve.resolve_configuration(EmbeddedConfiguration(), explicit_value={"default": "set", "sectioned": {"password": "pwd"}}) + with custom_environ( + { + "INSTRUMENTED__HEAD": "h", + "INSTRUMENTED__TUBE": '["tu", "u", "be"]', + 
"INSTRUMENTED__HEELS": "xhe", + } + ): + C = resolve.resolve_configuration( + EmbeddedConfiguration(), + explicit_value={"default": "set", "sectioned": {"password": "pwd"}}, + ) assert C.instrumented.to_native_representation() == "h>tu>u>be>xhe" @@ -392,7 +473,11 @@ def test_embedded_explicit_value_over_provider(environment: Any) -> None: with patch.object(InstrumentedConfiguration, "__section__", "instrumented"): with custom_environ({"INSTRUMENTED": "h>tu>u>be>he"}): # explicit value over the env - c = resolve.resolve_configuration(EmbeddedConfiguration(), explicit_value={"instrumented": "h>tu>be>xhe"}, accept_partial=True) + c = resolve.resolve_configuration( + EmbeddedConfiguration(), + explicit_value={"instrumented": "h>tu>be>xhe"}, + accept_partial=True, + ) assert c.instrumented.to_native_representation() == "h>tu>be>xhe" # parent configuration is not resolved assert not c.is_resolved() @@ -409,7 +494,9 @@ def test_provider_values_over_embedded_default(environment: Any) -> None: with custom_environ({"INSTRUMENTED": "h>tu>u>be>he"}): # read from env - over the default values emb = InstrumentedConfiguration().parse_native_representation("h>tu>be>xhe") - c = resolve.resolve_configuration(EmbeddedConfiguration(instrumented=emb), accept_partial=True) + c = resolve.resolve_configuration( + EmbeddedConfiguration(instrumented=emb), accept_partial=True + ) assert c.instrumented.to_native_representation() == "h>tu>u>be>he" # parent configuration is not resolved assert not c.is_resolved() @@ -426,30 +513,30 @@ def test_run_configuration_gen_name(environment: Any) -> None: def test_configuration_is_mutable_mapping(environment: Any, env_provider: ConfigProvider) -> None: - @configspec class _SecretCredentials(RunConfiguration): pipeline_name: Optional[str] = "secret" secret_value: TSecretValue = None config_files_storage_path: str = "storage" - # configurations provide full MutableMapping support # here order of items in dict matters expected_dict = { - 'pipeline_name': 'secret', - 'sentry_dsn': None, - 'slack_incoming_hook': None, - 'dlthub_telemetry': True, - 'dlthub_telemetry_segment_write_key': 'TLJiyRkGVZGCi2TtjClamXpFcxAA1rSB', - 'log_format': '{asctime}|[{levelname:<21}]|{process}|{name}|{filename}|{funcName}:{lineno}|{message}', - 'log_level': 'WARNING', - 'request_timeout': 60, - 'request_max_attempts': 5, - 'request_backoff_factor': 1, - 'request_max_retry_delay': 300, - 'config_files_storage_path': 'storage', - "secret_value": None + "pipeline_name": "secret", + "sentry_dsn": None, + "slack_incoming_hook": None, + "dlthub_telemetry": True, + "dlthub_telemetry_segment_write_key": "TLJiyRkGVZGCi2TtjClamXpFcxAA1rSB", + "log_format": ( + "{asctime}|[{levelname:<21}]|{process}|{name}|{filename}|{funcName}:{lineno}|{message}" + ), + "log_level": "WARNING", + "request_timeout": 60, + "request_max_attempts": 5, + "request_backoff_factor": 1, + "request_max_retry_delay": 300, + "config_files_storage_path": "storage", + "secret_value": None, } assert dict(_SecretCredentials()) == expected_dict @@ -513,9 +600,10 @@ def test_init_method_gen(environment: Any) -> None: def test_multi_derivation_defaults(environment: Any) -> None: - @configspec - class MultiConfiguration(SectionedConfiguration, MockProdConfiguration, ConfigurationWithOptionalTypes): + class MultiConfiguration( + SectionedConfiguration, MockProdConfiguration, ConfigurationWithOptionalTypes + ): pass # apparently dataclasses set default in reverse mro so MockProdConfiguration overwrites @@ -552,12 +640,19 @@ def 
test_raises_on_many_unresolved_fields(environment: Any, env_provider: Config resolve.resolve_configuration(CoercionTestConfiguration()) assert cf_missing_exc.value.spec_name == "CoercionTestConfiguration" # get all fields that must be set - val_fields = [f for f in CoercionTestConfiguration().get_resolvable_fields() if f.lower().endswith("_val")] + val_fields = [ + f for f in CoercionTestConfiguration().get_resolvable_fields() if f.lower().endswith("_val") + ] traces = cf_missing_exc.value.traces assert len(traces) == len(val_fields) for tr_field, exp_field in zip(traces, val_fields): assert len(traces[tr_field]) == 1 - assert traces[tr_field][0] == LookupTrace("Environment Variables", [], environ_provider.EnvironProvider.get_key_name(exp_field), None) + assert traces[tr_field][0] == LookupTrace( + "Environment Variables", + [], + environ_provider.EnvironProvider.get_key_name(exp_field), + None, + ) # assert traces[tr_field][1] == LookupTrace("secrets.toml", [], toml.TomlFileProvider.get_key_name(exp_field), None) # assert traces[tr_field][2] == LookupTrace("config.toml", [], toml.TomlFileProvider.get_key_name(exp_field), None) @@ -569,7 +664,9 @@ def test_accepts_optional_missing_fields(environment: Any) -> None: # make optional config resolve.resolve_configuration(ConfigurationWithOptionalTypes()) # make config with optional values - resolve.resolve_configuration(ProdConfigurationWithOptionalTypes(), explicit_value={"int_val": None}) + resolve.resolve_configuration( + ProdConfigurationWithOptionalTypes(), explicit_value={"int_val": None} + ) # make config with optional embedded config C = resolve.resolve_configuration(EmbeddedOptionalConfiguration()) # embedded config was not fully resolved @@ -579,14 +676,18 @@ def test_accepts_optional_missing_fields(environment: Any) -> None: def test_find_all_keys() -> None: keys = VeryWrongConfiguration().get_resolvable_fields() # assert hints and types: LOG_COLOR had it hint overwritten in derived class - assert set({'str_val': str, 'int_val': int, 'NoneConfigVar': str, 'log_color': str}.items()).issubset(keys.items()) + assert set( + {"str_val": str, "int_val": int, "NoneConfigVar": str, "log_color": str}.items() + ).issubset(keys.items()) def test_coercion_to_hint_types(environment: Any) -> None: add_config_dict_to_env(COERCIONS) C = CoercionTestConfiguration() - resolve._resolve_config_fields(C, explicit_values=None, explicit_sections=(), embedded_sections=(), accept_partial=False) + resolve._resolve_config_fields( + C, explicit_values=None, explicit_sections=(), embedded_sections=(), accept_partial=False + ) for key in COERCIONS: assert getattr(C, key) == COERCIONS[key] @@ -647,7 +748,13 @@ def test_invalid_coercions(environment: Any) -> None: add_config_dict_to_env(INVALID_COERCIONS) for key, value in INVALID_COERCIONS.items(): try: - resolve._resolve_config_fields(C, explicit_values=None, explicit_sections=(), embedded_sections=(), accept_partial=False) + resolve._resolve_config_fields( + C, + explicit_values=None, + explicit_sections=(), + embedded_sections=(), + accept_partial=False, + ) except ConfigValueCannotBeCoercedException as coerc_exc: # must fail exactly on expected value if coerc_exc.field_name != key: @@ -662,7 +769,9 @@ def test_excepted_coercions(environment: Any) -> None: C = CoercionTestConfiguration() add_config_dict_to_env(COERCIONS) add_config_dict_to_env(EXCEPTED_COERCIONS, overwrite_keys=True) - resolve._resolve_config_fields(C, explicit_values=None, explicit_sections=(), embedded_sections=(), accept_partial=False) + 
resolve._resolve_config_fields( + C, explicit_values=None, explicit_sections=(), embedded_sections=(), accept_partial=False + ) for key in EXCEPTED_COERCIONS: assert getattr(C, key) == COERCED_EXCEPTIONS[key] @@ -674,6 +783,7 @@ def test_config_with_unsupported_types_in_hints(environment: Any) -> None: class InvalidHintConfiguration(BaseConfiguration): tuple_val: tuple = None # type: ignore set_val: set = None # type: ignore + InvalidHintConfiguration() @@ -683,6 +793,7 @@ def test_config_with_no_hints(environment: Any) -> None: @configspec class NoHintConfiguration(BaseConfiguration): tuple_val = None + NoHintConfiguration() @@ -691,8 +802,8 @@ def test_config_with_non_templated_complex_hints(environment: Any) -> None: environment["TUPLE_VAL"] = "(1,2,3)" environment["DICT_VAL"] = '{"a": 1}' c = resolve.resolve_configuration(NonTemplatedComplexTypesConfiguration()) - assert c.list_val == [1,2,3] - assert c.tuple_val == (1,2,3) + assert c.list_val == [1, 2, 3] + assert c.tuple_val == (1, 2, 3) assert c.dict_val == {"a": 1} @@ -706,7 +817,7 @@ def test_resolve_configuration(environment: Any) -> None: def test_dataclass_instantiation(environment: Any) -> None: # resolve_configuration works on instances of dataclasses and types are not modified - environment['SECRET_VALUE'] = "1" + environment["SECRET_VALUE"] = "1" C = resolve.resolve_configuration(SecretConfiguration()) # auto derived type holds the value assert C.secret_value == "1" @@ -766,7 +877,6 @@ def test_is_valid_hint() -> None: def test_configspec_auto_base_config_derivation() -> None: - @configspec class AutoBaseDerivationConfiguration: auto: str @@ -857,36 +967,71 @@ def test_last_resolve_exception(environment: Any) -> None: def test_resolved_trace(environment: Any) -> None: with custom_environ( - {"INSTRUMENTED__HEAD": "h", "INSTRUMENTED__TUBE": '["tu", "u", "be"]', "INSTRUMENTED__HEELS": "xhe", "SECTIONED__PASSWORD": "passwd", "DEFAULT": "DEF"}): + { + "INSTRUMENTED__HEAD": "h", + "INSTRUMENTED__TUBE": '["tu", "u", "be"]', + "INSTRUMENTED__HEELS": "xhe", + "SECTIONED__PASSWORD": "passwd", + "DEFAULT": "DEF", + } + ): c = resolve.resolve_configuration(EmbeddedConfiguration(default="_DEFF")) traces = get_resolved_traces() prov_name = environ_provider.EnvironProvider().name - assert traces[".default"] == ResolvedValueTrace("default", "DEF", "_DEFF", str, [], prov_name, c) - assert traces["instrumented.head"] == ResolvedValueTrace("head", "h", None, str, ["instrumented"], prov_name, c.instrumented) + assert traces[".default"] == ResolvedValueTrace( + "default", "DEF", "_DEFF", str, [], prov_name, c + ) + assert traces["instrumented.head"] == ResolvedValueTrace( + "head", "h", None, str, ["instrumented"], prov_name, c.instrumented + ) # value is before casting - assert traces["instrumented.tube"] == ResolvedValueTrace("tube", '["tu", "u", "be"]', None, List[str], ["instrumented"], prov_name, c.instrumented) - assert deserialize_value("tube", traces["instrumented.tube"].value, resolve.extract_inner_hint(List[str])) == ["tu", "u", "be"] - assert traces["instrumented.heels"] == ResolvedValueTrace("heels", "xhe", None, str, ["instrumented"], prov_name, c.instrumented) - assert traces["sectioned.password"] == ResolvedValueTrace("password", "passwd", None, str, ["sectioned"], prov_name, c.sectioned) + assert traces["instrumented.tube"] == ResolvedValueTrace( + "tube", '["tu", "u", "be"]', None, List[str], ["instrumented"], prov_name, c.instrumented + ) + assert deserialize_value( + "tube", traces["instrumented.tube"].value, 
resolve.extract_inner_hint(List[str]) + ) == ["tu", "u", "be"] + assert traces["instrumented.heels"] == ResolvedValueTrace( + "heels", "xhe", None, str, ["instrumented"], prov_name, c.instrumented + ) + assert traces["sectioned.password"] == ResolvedValueTrace( + "password", "passwd", None, str, ["sectioned"], prov_name, c.sectioned + ) assert len(traces) == 5 # try to get native representation with patch.object(InstrumentedConfiguration, "__section__", "snake"): with custom_environ( - {"INSTRUMENTED": "h>t>t>t>he", "SECTIONED__PASSWORD": "pass", "DEFAULT": "UNDEF", "SNAKE": "h>t>t>t>he"}): + { + "INSTRUMENTED": "h>t>t>t>he", + "SECTIONED__PASSWORD": "pass", + "DEFAULT": "UNDEF", + "SNAKE": "h>t>t>t>he", + } + ): c = resolve.resolve_configuration(EmbeddedConfiguration()) resolve.resolve_configuration(InstrumentedConfiguration()) assert traces[".default"] == ResolvedValueTrace("default", "UNDEF", None, str, [], prov_name, c) - assert traces[".instrumented"] == ResolvedValueTrace("instrumented", "h>t>t>t>he", None, InstrumentedConfiguration, [], prov_name, c) + assert traces[".instrumented"] == ResolvedValueTrace( + "instrumented", "h>t>t>t>he", None, InstrumentedConfiguration, [], prov_name, c + ) - assert traces[".snake"] == ResolvedValueTrace("snake", "h>t>t>t>he", None, InstrumentedConfiguration, [], prov_name, None) + assert traces[".snake"] == ResolvedValueTrace( + "snake", "h>t>t>t>he", None, InstrumentedConfiguration, [], prov_name, None + ) def test_extract_inner_hint() -> None: # extracts base config from an union - assert resolve.extract_inner_hint(Union[GcpServiceAccountCredentialsWithoutDefaults, StrAny, str]) is GcpServiceAccountCredentialsWithoutDefaults - assert resolve.extract_inner_hint(Union[InstrumentedConfiguration, StrAny, str]) is InstrumentedConfiguration + assert ( + resolve.extract_inner_hint(Union[GcpServiceAccountCredentialsWithoutDefaults, StrAny, str]) + is GcpServiceAccountCredentialsWithoutDefaults + ) + assert ( + resolve.extract_inner_hint(Union[InstrumentedConfiguration, StrAny, str]) + is InstrumentedConfiguration + ) # keeps unions assert resolve.extract_inner_hint(Union[StrAny, str]) is Union # ignores specialization in list and dict, leaving origin @@ -908,7 +1053,10 @@ def test_is_secret_hint() -> None: TTestSecretNt = NewType("TTestSecretNt", GcpServiceAccountCredentialsWithoutDefaults) assert resolve.is_secret_hint(TTestSecretNt) is False # recognize unions with credentials - assert resolve.is_secret_hint(Union[GcpServiceAccountCredentialsWithoutDefaults, StrAny, str]) is True + assert ( + resolve.is_secret_hint(Union[GcpServiceAccountCredentialsWithoutDefaults, StrAny, str]) + is True + ) # we do not recognize unions if they do not contain configuration types assert resolve.is_secret_hint(Union[TSecretValue, StrAny, str]) is False assert resolve.is_secret_hint(Optional[str]) is False @@ -928,49 +1076,47 @@ def coerce_single_value(key: str, value: str, hint: Type[Any]) -> Any: def test_dynamic_type_hint(environment: Dict[str, str]) -> None: - """Test dynamic type hint using @resolve_type decorator - """ - environment['DUMMY__DISCRIMINATOR'] = 'b' - environment['DUMMY__EMBEDDED_CONFIG__FIELD_FOR_B'] = 'some_value' + """Test dynamic type hint using @resolve_type decorator""" + environment["DUMMY__DISCRIMINATOR"] = "b" + environment["DUMMY__EMBEDDED_CONFIG__FIELD_FOR_B"] = "some_value" - config = resolve.resolve_configuration(ConfigWithDynamicType(), sections=('dummy', )) + config = resolve.resolve_configuration(ConfigWithDynamicType(), 
sections=("dummy",)) assert isinstance(config.embedded_config, DynamicConfigB) - assert config.embedded_config.field_for_b == 'some_value' + assert config.embedded_config.field_for_b == "some_value" def test_dynamic_type_hint_subclass(environment: Dict[str, str]) -> None: - """Test overriding @resolve_type method in subclass - """ - environment['DUMMY__IS_NUMBER'] = 'true' - environment['DUMMY__DYNAMIC_TYPE_FIELD'] = '22' + """Test overriding @resolve_type method in subclass""" + environment["DUMMY__IS_NUMBER"] = "true" + environment["DUMMY__DYNAMIC_TYPE_FIELD"] = "22" # Test extended resolver method is applied - environment['DUMMY__DISCRIMINATOR'] = 'c' - environment['DUMMY__EMBEDDED_CONFIG__FIELD_FOR_C'] = 'some_value' + environment["DUMMY__DISCRIMINATOR"] = "c" + environment["DUMMY__EMBEDDED_CONFIG__FIELD_FOR_C"] = "some_value" - config = resolve.resolve_configuration(SubclassConfigWithDynamicType(), sections=('dummy', )) + config = resolve.resolve_configuration(SubclassConfigWithDynamicType(), sections=("dummy",)) assert isinstance(config.embedded_config, DynamicConfigC) - assert config.embedded_config.field_for_c == 'some_value' + assert config.embedded_config.field_for_c == "some_value" # Test super() call is applied correctly - environment['DUMMY__DISCRIMINATOR'] = 'b' - environment['DUMMY__EMBEDDED_CONFIG__FIELD_FOR_B'] = 'some_value' + environment["DUMMY__DISCRIMINATOR"] = "b" + environment["DUMMY__EMBEDDED_CONFIG__FIELD_FOR_B"] = "some_value" - config = resolve.resolve_configuration(SubclassConfigWithDynamicType(), sections=('dummy', )) + config = resolve.resolve_configuration(SubclassConfigWithDynamicType(), sections=("dummy",)) assert isinstance(config.embedded_config, DynamicConfigB) - assert config.embedded_config.field_for_b == 'some_value' + assert config.embedded_config.field_for_b == "some_value" # Test second dynamic field added in subclass - environment['DUMMY__IS_NUMBER'] = 'true' - environment['DUMMY__DYNAMIC_TYPE_FIELD'] = 'some' + environment["DUMMY__IS_NUMBER"] = "true" + environment["DUMMY__DYNAMIC_TYPE_FIELD"] = "some" with pytest.raises(ConfigValueCannotBeCoercedException) as e: - config = resolve.resolve_configuration(SubclassConfigWithDynamicType(), sections=('dummy', )) + config = resolve.resolve_configuration(SubclassConfigWithDynamicType(), sections=("dummy",)) - assert e.value.field_name == 'dynamic_type_field' + assert e.value.field_name == "dynamic_type_field" assert e.value.hint == int @@ -985,30 +1131,53 @@ def test_unmatched_dynamic_hint_resolvers(environment: Dict[str, str]) -> None: def test_add_config_to_env(environment: Dict[str, str]) -> None: - c = resolve.resolve_configuration(EmbeddedConfiguration(instrumented="h>tu>u>be>he", sectioned=SectionedConfiguration(password="PASS"), default="BUBA")) - add_config_to_env(c, ("dlt", )) + c = resolve.resolve_configuration( + EmbeddedConfiguration( + instrumented="h>tu>u>be>he", + sectioned=SectionedConfiguration(password="PASS"), + default="BUBA", + ) + ) + add_config_to_env(c, ("dlt",)) # must contain DLT prefix everywhere, INSTRUMENTED section taken from key and DLT_TEST taken from password - assert environment.items() >= { - 'DLT__DEFAULT': 'BUBA', - 'DLT__INSTRUMENTED__HEAD': 'h', 'DLT__INSTRUMENTED__TUBE': '["tu","u","be"]', 'DLT__INSTRUMENTED__HEELS': 'he', - 'DLT__DLT_TEST__PASSWORD': 'PASS' - }.items() + assert ( + environment.items() + >= { + "DLT__DEFAULT": "BUBA", + "DLT__INSTRUMENTED__HEAD": "h", + "DLT__INSTRUMENTED__TUBE": '["tu","u","be"]', + "DLT__INSTRUMENTED__HEELS": "he", + 
"DLT__DLT_TEST__PASSWORD": "PASS", + }.items() + ) # no dlt environment.clear() add_config_to_env(c) - assert environment.items() == { - 'DEFAULT': 'BUBA', - 'INSTRUMENTED__HEAD': 'h', 'INSTRUMENTED__TUBE': '["tu","u","be"]', 'INSTRUMENTED__HEELS': 'he', - 'DLT_TEST__PASSWORD': 'PASS' - }.items() + assert ( + environment.items() + == { + "DEFAULT": "BUBA", + "INSTRUMENTED__HEAD": "h", + "INSTRUMENTED__TUBE": '["tu","u","be"]', + "INSTRUMENTED__HEELS": "he", + "DLT_TEST__PASSWORD": "PASS", + }.items() + ) # starts with sectioned environment.clear() add_config_to_env(c.sectioned) - assert environment == {'DLT_TEST__PASSWORD': 'PASS'} + assert environment == {"DLT_TEST__PASSWORD": "PASS"} def test_configuration_copy() -> None: - c = resolve.resolve_configuration(EmbeddedConfiguration(), explicit_value={"default": "set", "instrumented": "h>tu>be>xhe", "sectioned": {"password": "pwd"}}) + c = resolve.resolve_configuration( + EmbeddedConfiguration(), + explicit_value={ + "default": "set", + "instrumented": "h>tu>be>xhe", + "sectioned": {"password": "pwd"}, + }, + ) assert c.is_resolved() copy_c = c.copy() assert copy_c.is_resolved() @@ -1021,7 +1190,9 @@ def test_configuration_copy() -> None: cred.parse_native_representation("postgresql://loader:loader@localhost:5432/dlt_data") copy_cred = cred.copy() assert dict(copy_cred) == dict(cred) - assert copy_cred.to_native_representation() == "postgresql://loader:loader@localhost:5432/dlt_data" + assert ( + copy_cred.to_native_representation() == "postgresql://loader:loader@localhost:5432/dlt_data" + ) # resolve the copy assert not copy_cred.is_resolved() resolved_cred_copy = c = resolve.resolve_configuration(copy_cred) @@ -1029,7 +1200,6 @@ def test_configuration_copy() -> None: def test_configuration_with_configuration_as_default() -> None: - instrumented_default = InstrumentedConfiguration() instrumented_default.parse_native_representation("h>a>b>he") cred = ConnectionStringCredentials() diff --git a/tests/common/configuration/test_container.py b/tests/common/configuration/test_container.py index 559e7b480a..f3db5b0bdf 100644 --- a/tests/common/configuration/test_container.py +++ b/tests/common/configuration/test_container.py @@ -1,15 +1,19 @@ -import pytest from typing import Any, ClassVar, Literal, Optional +import pytest +from tests.common.configuration.utils import environment +from tests.utils import preserve_environ + from dlt.common.configuration import configspec +from dlt.common.configuration.container import Container +from dlt.common.configuration.exceptions import ( + ConfigFieldMissingException, + ContainerInjectableContextMangled, + ContextDefaultCannotBeCreated, +) from dlt.common.configuration.providers.context import ContextProvider from dlt.common.configuration.resolve import resolve_configuration from dlt.common.configuration.specs import BaseConfiguration, ContainerInjectableContext -from dlt.common.configuration.container import Container -from dlt.common.configuration.exceptions import ConfigFieldMissingException, ContainerInjectableContextMangled, ContextDefaultCannotBeCreated - -from tests.utils import preserve_environ -from tests.common.configuration.utils import environment @configspec @@ -27,7 +31,6 @@ class EmbeddedWithInjectableContext(BaseConfiguration): @configspec class NoDefaultInjectableContext(ContainerInjectableContext): - can_create_default: ClassVar[bool] = False diff --git a/tests/common/configuration/test_credentials.py b/tests/common/configuration/test_credentials.py index adf5ac829d..45db501ed6 100644 --- 
a/tests/common/configuration/test_credentials.py +++ b/tests/common/configuration/test_credentials.py @@ -2,17 +2,29 @@ from typing import Any, Dict import pytest +from tests.common.configuration.utils import environment +from tests.common.utils import json_case_path +from tests.utils import preserve_environ + from dlt.common.configuration import resolve_configuration from dlt.common.configuration.exceptions import ConfigFieldMissingException -from dlt.common.configuration.specs import ConnectionStringCredentials, GcpServiceAccountCredentialsWithoutDefaults, GcpServiceAccountCredentials, GcpOAuthCredentialsWithoutDefaults, GcpOAuthCredentials, AwsCredentials -from dlt.common.configuration.specs.exceptions import InvalidConnectionString, InvalidGoogleNativeCredentialsType, InvalidGoogleOauth2Json, InvalidGoogleServicesJson, OAuth2ScopesRequired +from dlt.common.configuration.specs import ( + AwsCredentials, + ConnectionStringCredentials, + GcpOAuthCredentials, + GcpOAuthCredentialsWithoutDefaults, + GcpServiceAccountCredentials, + GcpServiceAccountCredentialsWithoutDefaults, +) +from dlt.common.configuration.specs.exceptions import ( + InvalidConnectionString, + InvalidGoogleNativeCredentialsType, + InvalidGoogleOauth2Json, + InvalidGoogleServicesJson, + OAuth2ScopesRequired, +) from dlt.common.configuration.specs.run_configuration import RunConfiguration -from tests.utils import preserve_environ -from tests.common.utils import json_case_path -from tests.common.configuration.utils import environment - - SERVICE_JSON = """ { "type": "service_account", @@ -155,7 +167,10 @@ def test_gcp_service_credentials_native_representation(environment) -> None: assert GcpServiceAccountCredentials.__config_gen_annotations__ == [] gcpc = GcpServiceAccountCredentials() - gcpc.parse_native_representation(SERVICE_JSON % '"private_key": "-----BEGIN PRIVATE KEY-----\\n\\n-----END PRIVATE KEY-----\\n",') + gcpc.parse_native_representation( + SERVICE_JSON + % '"private_key": "-----BEGIN PRIVATE KEY-----\\n\\n-----END PRIVATE KEY-----\\n",' + ) assert gcpc.private_key == "-----BEGIN PRIVATE KEY-----\n\n-----END PRIVATE KEY-----\n" assert gcpc.project_id == "chat-analytics" assert gcpc.client_email == "loader@iam.gserviceaccount.com" @@ -191,7 +206,6 @@ def test_gcp_service_credentials_resolved_from_native_representation(environment def test_gcp_oauth_credentials_native_representation(environment) -> None: - with pytest.raises(InvalidGoogleNativeCredentialsType): GcpOAuthCredentials().parse_native_representation(1) @@ -205,13 +219,15 @@ def test_gcp_oauth_credentials_native_representation(environment) -> None: # but is not partial - all required fields are present assert gcoauth.is_partial() is False assert gcoauth.project_id == "level-dragon-333983" - assert gcoauth.client_id == "921382012504-3mtjaj1s7vuvf53j88mgdq4te7akkjm3.apps.googleusercontent.com" + assert ( + gcoauth.client_id + == "921382012504-3mtjaj1s7vuvf53j88mgdq4te7akkjm3.apps.googleusercontent.com" + ) assert gcoauth.client_secret == "gOCSPX-XdY5znbrvjSMEG3pkpA_GHuLPPth" assert gcoauth.refresh_token == "refresh_token" assert gcoauth.token is None assert gcoauth.scopes == ["email", "service"] - # get native representation, it will also location _repr = gcoauth.to_native_representation() assert "localhost" in _repr @@ -289,16 +305,16 @@ def test_run_configuration_slack_credentials(environment: Any) -> None: def test_aws_credentials_resolved(environment: Dict[str, str]) -> None: - environment['CREDENTIALS__AWS_ACCESS_KEY_ID'] = 'fake_access_key' - 
environment['CREDENTIALS__AWS_SECRET_ACCESS_KEY'] = 'fake_secret_key' - environment['CREDENTIALS__AWS_SESSION_TOKEN'] = 'fake_session_token' - environment['CREDENTIALS__PROFILE_NAME'] = 'fake_profile' - environment['CREDENTIALS__REGION_NAME'] = 'eu-central' + environment["CREDENTIALS__AWS_ACCESS_KEY_ID"] = "fake_access_key" + environment["CREDENTIALS__AWS_SECRET_ACCESS_KEY"] = "fake_secret_key" + environment["CREDENTIALS__AWS_SESSION_TOKEN"] = "fake_session_token" + environment["CREDENTIALS__PROFILE_NAME"] = "fake_profile" + environment["CREDENTIALS__REGION_NAME"] = "eu-central" config = resolve_configuration(AwsCredentials()) - assert config.aws_access_key_id == 'fake_access_key' - assert config.aws_secret_access_key == 'fake_secret_key' - assert config.aws_session_token == 'fake_session_token' - assert config.profile_name == 'fake_profile' + assert config.aws_access_key_id == "fake_access_key" + assert config.aws_secret_access_key == "fake_secret_key" + assert config.aws_session_token == "fake_session_token" + assert config.profile_name == "fake_profile" assert config.region_name == "eu-central" diff --git a/tests/common/configuration/test_environ_provider.py b/tests/common/configuration/test_environ_provider.py index ccac6c54eb..71a1966c0f 100644 --- a/tests/common/configuration/test_environ_provider.py +++ b/tests/common/configuration/test_environ_provider.py @@ -1,13 +1,18 @@ -import pytest from typing import Any -from dlt.common.typing import TSecretValue -from dlt.common.configuration import configspec, ConfigFieldMissingException, ConfigFileNotFoundException, resolve -from dlt.common.configuration.specs import RunConfiguration, BaseConfiguration -from dlt.common.configuration.providers import environ as environ_provider - +import pytest +from tests.common.configuration.utils import SecretConfiguration, WrongConfiguration, environment from tests.utils import preserve_environ -from tests.common.configuration.utils import WrongConfiguration, SecretConfiguration, environment + +from dlt.common.configuration import ( + ConfigFieldMissingException, + ConfigFileNotFoundException, + configspec, + resolve, +) +from dlt.common.configuration.providers import environ as environ_provider +from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration +from dlt.common.typing import TSecretValue @configspec @@ -27,22 +32,25 @@ class MockProdRunConfigurationVar(RunConfiguration): pipeline_name: str = "comp" - def test_resolves_from_environ(environment: Any) -> None: environment["NONECONFIGVAR"] = "Some" C = WrongConfiguration() - resolve._resolve_config_fields(C, explicit_values=None, explicit_sections=(), embedded_sections=(), accept_partial=False) + resolve._resolve_config_fields( + C, explicit_values=None, explicit_sections=(), embedded_sections=(), accept_partial=False + ) assert not C.is_partial() assert C.NoneConfigVar == environment["NONECONFIGVAR"] def test_resolves_from_environ_with_coercion(environment: Any) -> None: - environment["RUNTIME__TEST_BOOL"] = 'yes' + environment["RUNTIME__TEST_BOOL"] = "yes" C = SimpleRunConfiguration() - resolve._resolve_config_fields(C, explicit_values=None, explicit_sections=(), embedded_sections=(), accept_partial=False) + resolve._resolve_config_fields( + C, explicit_values=None, explicit_sections=(), embedded_sections=(), accept_partial=False + ) assert not C.is_partial() # value will be coerced to bool @@ -52,13 +60,13 @@ def test_resolves_from_environ_with_coercion(environment: Any) -> None: def test_secret(environment: Any) -> None: 
with pytest.raises(ConfigFieldMissingException): resolve.resolve_configuration(SecretConfiguration()) - environment['SECRET_VALUE'] = "1" + environment["SECRET_VALUE"] = "1" C = resolve.resolve_configuration(SecretConfiguration()) assert C.secret_value == "1" # mock the path to point to secret storage # from dlt.common.configuration import config_utils path = environ_provider.SECRET_STORAGE_PATH - del environment['SECRET_VALUE'] + del environment["SECRET_VALUE"] try: # must read a secret file environ_provider.SECRET_STORAGE_PATH = "./tests/common/cases/%s" @@ -66,13 +74,13 @@ def test_secret(environment: Any) -> None: assert C.secret_value == "BANANA" # set some weird path, no secret file at all - del environment['SECRET_VALUE'] + del environment["SECRET_VALUE"] environ_provider.SECRET_STORAGE_PATH = "!C:\\PATH%s" with pytest.raises(ConfigFieldMissingException): resolve.resolve_configuration(SecretConfiguration()) # set env which is a fallback for secret not as file - environment['SECRET_VALUE'] = "1" + environment["SECRET_VALUE"] = "1" C = resolve.resolve_configuration(SecretConfiguration()) assert C.secret_value == "1" finally: @@ -87,7 +95,7 @@ def test_secret_kube_fallback(environment: Any) -> None: # all unix editors will add x10 at the end of file, it will be preserved assert C.secret_kube == "kube\n" # we propagate secrets back to environ and strip the whitespace - assert environment['SECRET_KUBE'] == "kube" + assert environment["SECRET_KUBE"] == "kube" finally: environ_provider.SECRET_STORAGE_PATH = path @@ -99,7 +107,10 @@ def test_configuration_files(environment: Any) -> None: assert C.config_files_storage_path == environment["RUNTIME__CONFIG_FILES_STORAGE_PATH"] assert C.has_configuration_file("hasn't") is False assert C.has_configuration_file("event.schema.json") is True - assert C.get_configuration_file_path("event.schema.json") == "./tests/common/cases/schemas/ev1/event.schema.json" + assert ( + C.get_configuration_file_path("event.schema.json") + == "./tests/common/cases/schemas/ev1/event.schema.json" + ) with C.open_configuration_file("event.schema.json", "r") as f: f.read() with pytest.raises(ConfigFileNotFoundException): diff --git a/tests/common/configuration/test_inject.py b/tests/common/configuration/test_inject.py index 8070a5be9c..414bac20a7 100644 --- a/tests/common/configuration/test_inject.py +++ b/tests/common/configuration/test_inject.py @@ -1,27 +1,29 @@ import os from typing import Any, Dict, Optional, Type, Union + import pytest +from tests.common.configuration.utils import environment, toml_providers +from tests.utils import preserve_environ import dlt - from dlt.common.configuration.exceptions import ConfigFieldMissingException from dlt.common.configuration.inject import get_fun_spec, last_config, with_config from dlt.common.configuration.providers import EnvironProvider from dlt.common.configuration.providers.toml import SECRETS_TOML from dlt.common.configuration.resolve import inject_section -from dlt.common.configuration.specs import BaseConfiguration, GcpServiceAccountCredentialsWithoutDefaults, ConnectionStringCredentials +from dlt.common.configuration.specs import ( + BaseConfiguration, + ConnectionStringCredentials, + GcpServiceAccountCredentialsWithoutDefaults, +) from dlt.common.configuration.specs.base_configuration import is_secret_hint from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.reflection.spec 
import _get_spec_name_from_f from dlt.common.typing import StrAny, TSecretValue, is_newtype_type -from tests.utils import preserve_environ -from tests.common.configuration.utils import environment, toml_providers - def test_arguments_are_explicit(environment: Any) -> None: - @with_config def f_var(user=dlt.config.value, path=dlt.config.value): # explicit args "survive" the injection: they have precedence over env @@ -43,7 +45,6 @@ def f_var_env(user=dlt.config.value, path=dlt.config.value): def test_default_values_are_resolved(environment: Any) -> None: - @with_config def f_var(user=dlt.config.value, path="a/b/c"): assert user == "env user" @@ -54,7 +55,6 @@ def f_var(user=dlt.config.value, path="a/b/c"): def test_arguments_dlt_literal_defaults_are_required(environment: Any) -> None: - @with_config def f_config(user=dlt.config.value): assert user is not None @@ -84,7 +84,6 @@ def f_secret(password=dlt.secrets.value): def test_inject_from_argument_section(toml_providers: ConfigProvidersContext) -> None: - # `gcp_storage` is a key in `secrets.toml` and the default `credentials` section of GcpServiceAccountCredentialsWithoutDefaults must be replaced with it @with_config @@ -96,11 +95,12 @@ def f_credentials(gcp_storage: GcpServiceAccountCredentialsWithoutDefaults = dlt def test_inject_secret_value_secret_type(environment: Any) -> None: - @with_config - def f_custom_secret_type(_dict: Dict[str, Any] = dlt.secrets.value, _int: int = dlt.secrets.value, **kwargs: Any): + def f_custom_secret_type( + _dict: Dict[str, Any] = dlt.secrets.value, _int: int = dlt.secrets.value, **kwargs: Any + ): # secret values were coerced into types - assert _dict == {"a":1} + assert _dict == {"a": 1} assert _int == 1234 cfg = last_config(**kwargs) spec: Type[BaseConfiguration] = cfg.__class__ @@ -158,23 +158,24 @@ def test_inject_with_sections() -> None: def test_inject_with_sections_and_sections_context() -> None: - @with_config def no_sections(value=dlt.config.value): return value - @with_config(sections=("test", )) + @with_config(sections=("test",)) def test_sections(value=dlt.config.value): return value # a section context that prefers existing context - @with_config(sections=("test", ), sections_merge_style=ConfigSectionContext.prefer_existing) + @with_config(sections=("test",), sections_merge_style=ConfigSectionContext.prefer_existing) def test_sections_pref_existing(value=dlt.config.value): return value - # a section that wants context like dlt resource - @with_config(sections=("test", "module", "name"), sections_merge_style=ConfigSectionContext.resource_merge_style) + @with_config( + sections=("test", "module", "name"), + sections_merge_style=ConfigSectionContext.resource_merge_style, + ) def test_sections_like_resource(value=dlt.config.value): return value @@ -189,7 +190,7 @@ def test_sections_like_resource(value=dlt.config.value): assert test_sections_pref_existing() == "test_section" assert test_sections_like_resource() == "test_section" - with inject_section(ConfigSectionContext(sections=("injected", ))): + with inject_section(ConfigSectionContext(sections=("injected",))): # the "injected" section is applied to "no_section" func that has no sections assert no_sections() == "injected_section" # but not to "test" - it won't be overridden by section context @@ -198,7 +199,9 @@ def test_sections_like_resource(value=dlt.config.value): # this one explicitly prefers existing context assert test_sections_pref_existing() == "injected_section" - with inject_section(ConfigSectionContext(sections=("test", 
"existing_module", "existing_name"))): + with inject_section( + ConfigSectionContext(sections=("test", "existing_module", "existing_name")) + ): assert test_sections_like_resource() == "resource_style_injected" @@ -256,10 +259,13 @@ def test_initial_spec_from_arg_with_spec_type() -> None: pass -def test_use_most_specific_union_type(environment: Any, toml_providers: ConfigProvidersContext) -> None: - +def test_use_most_specific_union_type( + environment: Any, toml_providers: ConfigProvidersContext +) -> None: @with_config - def postgres_union(local_credentials: Union[ConnectionStringCredentials, str, StrAny] = dlt.secrets.value): + def postgres_union( + local_credentials: Union[ConnectionStringCredentials, str, StrAny] = dlt.secrets.value + ): return local_credentials @with_config @@ -267,7 +273,13 @@ def postgres_direct(local_credentials: ConnectionStringCredentials = dlt.secrets return local_credentials conn_str = "postgres://loader:loader@localhost:5432/dlt_data" - conn_dict = {"host": "localhost", "database": "dlt_test", "username": "loader", "password": "loader", "drivername": "postgresql"} + conn_dict = { + "host": "localhost", + "database": "dlt_test", + "username": "loader", + "password": "loader", + "drivername": "postgresql", + } conn_cred = ConnectionStringCredentials() conn_cred.parse_native_representation(conn_str) @@ -313,7 +325,6 @@ def postgres_direct(local_credentials: ConnectionStringCredentials = dlt.secrets def test_auto_derived_spec_type_name() -> None: - class AutoNameTest: @with_config def __init__(self, pos_par=dlt.secrets.value, /, kw_par=None) -> None: @@ -334,10 +345,13 @@ def stuff_test(pos_par, /, kw_par) -> None: pass # name is composed via __qualname__ of func - assert _get_spec_name_from_f(AutoNameTest.__init__) == "TestAutoDerivedSpecTypeNameAutoNameTestInitConfiguration" + assert ( + _get_spec_name_from_f(AutoNameTest.__init__) + == "TestAutoDerivedSpecTypeNameAutoNameTestInitConfiguration" + ) # synthesized spec present in current module assert "TestAutoDerivedSpecTypeNameAutoNameTestInitConfiguration" in globals() # instantiate C: BaseConfiguration = globals()["TestAutoDerivedSpecTypeNameAutoNameTestInitConfiguration"]() # pos_par converted to secrets, kw_par converted to optional - assert C.get_resolvable_fields() == {"pos_par": TSecretValue, "kw_par": Optional[Any]} \ No newline at end of file + assert C.get_resolvable_fields() == {"pos_par": TSecretValue, "kw_par": Optional[Any]} diff --git a/tests/common/configuration/test_providers.py b/tests/common/configuration/test_providers.py index 2408aae583..f8c7900c24 100644 --- a/tests/common/configuration/test_providers.py +++ b/tests/common/configuration/test_providers.py @@ -1,5 +1,6 @@ import pytest + @pytest.mark.skip("Not implemented") def test_providers_order() -> None: pass diff --git a/tests/common/configuration/test_sections.py b/tests/common/configuration/test_sections.py index 5fca251856..6c7479d476 100644 --- a/tests/common/configuration/test_sections.py +++ b/tests/common/configuration/test_sections.py @@ -1,14 +1,25 @@ -import pytest from typing import Any, Optional -from dlt.common.configuration.container import Container -from dlt.common.configuration import configspec, ConfigFieldMissingException, resolve, inject_section +import pytest +from tests.common.configuration.utils import ( + MockProvider, + SectionedConfiguration, + env_provider, + environment, + mock_provider, +) +from tests.utils import preserve_environ + +from dlt.common.configuration import ( + ConfigFieldMissingException, + 
configspec, + inject_section, + resolve, +) +from dlt.common.configuration.container import Container +from dlt.common.configuration.exceptions import LookupTrace from dlt.common.configuration.providers.provider import ConfigProvider from dlt.common.configuration.specs import BaseConfiguration, ConfigSectionContext -from dlt.common.configuration.exceptions import LookupTrace - -from tests.utils import preserve_environ -from tests.common.configuration.utils import MockProvider, SectionedConfiguration, environment, mock_provider, env_provider @configspec @@ -52,7 +63,9 @@ def test_sectioned_configuration(environment: Any, env_provider: ConfigProvider) traces = exc_val.value.traces["password"] # only one provider and section was tried assert len(traces) == 1 - assert traces[0] == LookupTrace("Environment Variables", ["DLT_TEST"], "DLT_TEST__PASSWORD", None) + assert traces[0] == LookupTrace( + "Environment Variables", ["DLT_TEST"], "DLT_TEST__PASSWORD", None + ) # assert traces[1] == LookupTrace("secrets.toml", ["DLT_TEST"], "DLT_TEST.password", None) # assert traces[2] == LookupTrace("config.toml", ["DLT_TEST"], "DLT_TEST.password", None) @@ -108,7 +121,14 @@ def test_explicit_sections_with_sectioned_config(mock_provider: MockProvider) -> assert mock_provider.last_sections == [("ns1",), (), ("ns1", "DLT_TEST"), ("DLT_TEST",)] mock_provider.reset_stats() resolve.resolve_configuration(SectionedConfiguration(), sections=("ns1", "ns2")) - assert mock_provider.last_sections == [("ns1", "ns2"), ("ns1",), (), ("ns1", "ns2", "DLT_TEST"), ("ns1", "DLT_TEST"), ("DLT_TEST",)] + assert mock_provider.last_sections == [ + ("ns1", "ns2"), + ("ns1",), + (), + ("ns1", "ns2", "DLT_TEST"), + ("ns1", "DLT_TEST"), + ("DLT_TEST",), + ] def test_overwrite_config_section_from_embedded(mock_provider: MockProvider) -> None: @@ -134,7 +154,13 @@ def test_explicit_sections_from_embedded_config(mock_provider: MockProvider) -> # embedded section inner of explicit mock_provider.reset_stats() resolve.resolve_configuration(EmbeddedConfiguration(), sections=("ns1",)) - assert mock_provider.last_sections == [("ns1", "sv_config",), ("sv_config",)] + assert mock_provider.last_sections == [ + ( + "ns1", + "sv_config", + ), + ("sv_config",), + ] def test_ignore_embedded_section_by_field_name(mock_provider: MockProvider) -> None: @@ -155,7 +181,11 @@ def test_ignore_embedded_section_by_field_name(mock_provider: MockProvider) -> N mock_provider.reset_stats() mock_provider.return_value_on = ("DLT_TEST",) resolve.resolve_configuration(EmbeddedWithIgnoredEmbeddedConfiguration()) - assert mock_provider.last_sections == [('ignored_embedded',), ('ignored_embedded', 'DLT_TEST'), ('DLT_TEST',)] + assert mock_provider.last_sections == [ + ("ignored_embedded",), + ("ignored_embedded", "DLT_TEST"), + ("DLT_TEST",), + ] def test_injected_sections(mock_provider: MockProvider) -> None: @@ -173,7 +203,12 @@ def test_injected_sections(mock_provider: MockProvider) -> None: mock_provider.reset_stats() mock_provider.return_value_on = ("DLT_TEST",) resolve.resolve_configuration(SectionedConfiguration()) - assert mock_provider.last_sections == [("inj-ns1",), (), ("inj-ns1", "DLT_TEST"), ("DLT_TEST",)] + assert mock_provider.last_sections == [ + ("inj-ns1",), + (), + ("inj-ns1", "DLT_TEST"), + ("DLT_TEST",), + ] # injected section inner of ns coming from embedded config mock_provider.reset_stats() mock_provider.return_value_on = () @@ -195,7 +230,7 @@ def test_section_context() -> None: with pytest.raises(ValueError): 
ConfigSectionContext(sections=()).source_name() with pytest.raises(ValueError): - ConfigSectionContext(sections=("sources", )).source_name() + ConfigSectionContext(sections=("sources",)).source_name() with pytest.raises(ValueError): ConfigSectionContext(sections=("sources", "modules")).source_name() @@ -220,7 +255,7 @@ def test_section_with_pipeline_name(mock_provider: MockProvider) -> None: # PIPE section is exhausted then another lookup without PIPE assert mock_provider.last_sections == [("PIPE", "ns1"), ("PIPE",), ("ns1",), ()] - mock_provider.return_value_on = ("PIPE", ) + mock_provider.return_value_on = ("PIPE",) mock_provider.reset_stats() resolve.resolve_configuration(SingleValConfiguration(), sections=("ns1",)) assert mock_provider.last_sections == [("PIPE", "ns1"), ("PIPE",)] @@ -236,10 +271,12 @@ def test_section_with_pipeline_name(mock_provider: MockProvider) -> None: mock_provider.reset_stats() resolve.resolve_configuration(SectionedConfiguration()) # first the whole SectionedConfiguration is looked under key DLT_TEST (sections: ('PIPE',), ()), then fields of SectionedConfiguration - assert mock_provider.last_sections == [('PIPE',), (), ("PIPE", "DLT_TEST"), ("DLT_TEST",)] + assert mock_provider.last_sections == [("PIPE",), (), ("PIPE", "DLT_TEST"), ("DLT_TEST",)] # with pipeline and injected sections - with container.injectable_context(ConfigSectionContext(pipeline_name="PIPE", sections=("inj-ns1",))): + with container.injectable_context( + ConfigSectionContext(pipeline_name="PIPE", sections=("inj-ns1",)) + ): mock_provider.return_value_on = () mock_provider.reset_stats() resolve.resolve_configuration(SingleValConfiguration()) diff --git a/tests/common/configuration/test_spec_union.py b/tests/common/configuration/test_spec_union.py index 3e6397f0d9..d47dc99a0a 100644 --- a/tests/common/configuration/test_spec_union.py +++ b/tests/common/configuration/test_spec_union.py @@ -1,22 +1,22 @@ import itertools import os +from typing import Any, Optional, Union + import pytest from sqlalchemy.engine import Engine, create_engine -from typing import Optional, Union, Any +from tests.common.configuration.utils import environment +from tests.utils import preserve_environ import dlt -from dlt.common.configuration.exceptions import InvalidNativeValue, ConfigFieldMissingException -from dlt.common.configuration.providers import EnvironProvider -from dlt.common.configuration.specs import CredentialsConfiguration, BaseConfiguration from dlt.common.configuration import configspec, resolve_configuration -from dlt.common.configuration.specs.gcp_credentials import GcpServiceAccountCredentials -from dlt.common.typing import TSecretValue -from dlt.common.configuration.specs.connection_string_credentials import ConnectionStringCredentials +from dlt.common.configuration.exceptions import ConfigFieldMissingException, InvalidNativeValue +from dlt.common.configuration.providers import EnvironProvider from dlt.common.configuration.resolve import initialize_credentials +from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration +from dlt.common.configuration.specs.connection_string_credentials import ConnectionStringCredentials from dlt.common.configuration.specs.exceptions import NativeValueError - -from tests.common.configuration.utils import environment -from tests.utils import preserve_environ +from dlt.common.configuration.specs.gcp_credentials import GcpServiceAccountCredentials +from dlt.common.typing import TSecretValue @configspec @@ -145,8 +145,17 @@ def 
test_unresolved_union() -> None: resolve_configuration(ZenConfig()) assert cfm_ex.value.fields == ["credentials"] # all the missing fields from all the union elements are present - checked_keys = set(t.key for t in itertools.chain(*cfm_ex.value.traces.values()) if t.provider == EnvironProvider().name) - assert checked_keys == {"CREDENTIALS__EMAIL", "CREDENTIALS__PASSWORD", "CREDENTIALS__API_KEY", "CREDENTIALS__API_SECRET"} + checked_keys = set( + t.key + for t in itertools.chain(*cfm_ex.value.traces.values()) + if t.provider == EnvironProvider().name + ) + assert checked_keys == { + "CREDENTIALS__EMAIL", + "CREDENTIALS__PASSWORD", + "CREDENTIALS__API_KEY", + "CREDENTIALS__API_SECRET", + } def test_union_decorator() -> None: @@ -154,7 +163,10 @@ def test_union_decorator() -> None: # this will generate equivalent of ZenConfig @dlt.source - def zen_source(credentials: Union[ZenApiKeyCredentials, ZenEmailCredentials, str] = dlt.secrets.value, some_option: bool = False): + def zen_source( + credentials: Union[ZenApiKeyCredentials, ZenEmailCredentials, str] = dlt.secrets.value, + some_option: bool = False, + ): # depending on what the user provides in config, ZenApiKeyCredentials or ZenEmailCredentials will be injected in credentials # both classes implement `auth` so you can always call it credentials.auth() @@ -169,16 +181,21 @@ def zen_source(credentials: Union[ZenApiKeyCredentials, ZenEmailCredentials, str # pass explicit dict assert list(zen_source(credentials={"email": "emx", "password": "pass"}))[0].email == "emx" - assert list(zen_source(credentials={"api_key": "🔑", "api_secret": ":secret:"}))[0].api_key == "🔑" + assert ( + list(zen_source(credentials={"api_key": "🔑", "api_secret": ":secret:"}))[0].api_key == "🔑" + ) # mixed credentials will not work with pytest.raises(ConfigFieldMissingException): - assert list(zen_source(credentials={"api_key": "🔑", "password": "pass"}))[0].api_key == "🔑" + assert ( + list(zen_source(credentials={"api_key": "🔑", "password": "pass"}))[0].api_key == "🔑" + ) class GoogleAnalyticsCredentialsBase(CredentialsConfiguration): """ The Base version of all the GoogleAnalyticsCredentials classes. 
""" + pass @@ -187,6 +204,7 @@ class GoogleAnalyticsCredentialsOAuth(GoogleAnalyticsCredentialsBase): """ This class is used to store credentials Google Analytics """ + client_id: str client_secret: TSecretValue project_id: TSecretValue @@ -195,23 +213,27 @@ class GoogleAnalyticsCredentialsOAuth(GoogleAnalyticsCredentialsBase): @dlt.source(max_table_nesting=2) -def google_analytics(credentials: Union[GoogleAnalyticsCredentialsOAuth, GcpServiceAccountCredentials] = dlt.secrets.value): +def google_analytics( + credentials: Union[ + GoogleAnalyticsCredentialsOAuth, GcpServiceAccountCredentials + ] = dlt.secrets.value +): yield dlt.resource([credentials], name="creds") def test_google_auth_union(environment: Any) -> None: info = { - "type" : "service_account", - "project_id" : "dlthub-analytics", - "private_key_id" : "45cbe97fbd3d756d55d4633a5a72d8530a05b993", - "private_key" : "-----BEGIN PRIVATE KEY-----\n\n-----END PRIVATE KEY-----\n", - "client_email" : "105150287833-compute@developer.gserviceaccount.com", - "client_id" : "106404499083406128146", - "auth_uri" : "https://accounts.google.com/o/oauth2/auth", - "token_uri" : "https://oauth2.googleapis.com/token", - "auth_provider_x509_cert_url" : "https://www.googleapis.com/oauth2/v1/certs", - "client_x509_cert_url" : "https://www.googleapis.com/robot/v1/metadata/x509/105150287833-compute%40developer.gserviceaccount.com" - } + "type": "service_account", + "project_id": "dlthub-analytics", + "private_key_id": "45cbe97fbd3d756d55d4633a5a72d8530a05b993", + "private_key": "-----BEGIN PRIVATE KEY-----\n\n-----END PRIVATE KEY-----\n", + "client_email": "105150287833-compute@developer.gserviceaccount.com", + "client_id": "106404499083406128146", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/105150287833-compute%40developer.gserviceaccount.com", + } credentials = list(google_analytics(credentials=info))[0] print(dict(credentials)) @@ -225,23 +247,23 @@ def sql_database(credentials: Union[ConnectionStringCredentials, Engine, str] = def test_union_concrete_type(environment: Any) -> None: # we can pass engine explicitly - engine = create_engine('sqlite:///:memory:', echo=True) + engine = create_engine("sqlite:///:memory:", echo=True) db = sql_database(credentials=engine) creds = list(db)[0] assert isinstance(creds, Engine) # we can pass valid connection string explicitly - db = sql_database(credentials='sqlite://user@/:memory:') + db = sql_database(credentials="sqlite://user@/:memory:") creds = list(db)[0] # but it is used as native value assert isinstance(creds, ConnectionStringCredentials) # pass instance of credentials - cn = ConnectionStringCredentials('sqlite://user@/:memory:') + cn = ConnectionStringCredentials("sqlite://user@/:memory:") db = sql_database(credentials=cn) # exactly that instance is returned assert list(db)[0] is cn # invalid cn with pytest.raises(InvalidNativeValue): - db = sql_database(credentials='?') + db = sql_database(credentials="?") with pytest.raises(InvalidNativeValue): db = sql_database(credentials=123) diff --git a/tests/common/configuration/test_toml_provider.py b/tests/common/configuration/test_toml_provider.py index 33582bb0a5..0eeb6a962e 100644 --- a/tests/common/configuration/test_toml_provider.py +++ b/tests/common/configuration/test_toml_provider.py @@ -1,24 +1,44 @@ +import datetime # noqa: I251 
import os +from typing import Any + import pytest import tomlkit -from typing import Any -import datetime # noqa: I251 +from tests.common.configuration.utils import ( + COERCIONS, + CoercionTestConfiguration, + SecretConfiguration, + SecretCredentials, + WithCredentialsConfiguration, + environment, + toml_providers, +) +from tests.utils import preserve_environ import dlt -from dlt.common import pendulum, Decimal -from dlt.common.configuration import configspec, ConfigFieldMissingException, resolve +from dlt.common import Decimal, pendulum +from dlt.common.configuration import ConfigFieldMissingException, configspec, resolve from dlt.common.configuration.container import Container -from dlt.common.configuration.inject import with_config from dlt.common.configuration.exceptions import LookupTrace -from dlt.common.configuration.providers.toml import SECRETS_TOML, CONFIG_TOML, BaseTomlProvider, SecretsTomlProvider, ConfigTomlProvider, StringTomlProvider, TomlProviderReadException +from dlt.common.configuration.inject import with_config +from dlt.common.configuration.providers.toml import ( + CONFIG_TOML, + SECRETS_TOML, + BaseTomlProvider, + ConfigTomlProvider, + SecretsTomlProvider, + StringTomlProvider, + TomlProviderReadException, +) +from dlt.common.configuration.specs import ( + BaseConfiguration, + ConnectionStringCredentials, + GcpServiceAccountCredentialsWithoutDefaults, +) from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext -from dlt.common.configuration.specs import BaseConfiguration, GcpServiceAccountCredentialsWithoutDefaults, ConnectionStringCredentials from dlt.common.runners.configuration import PoolRunnerConfiguration from dlt.common.typing import TSecretValue -from tests.utils import preserve_environ -from tests.common.configuration.utils import SecretCredentials, WithCredentialsConfiguration, CoercionTestConfiguration, COERCIONS, SecretConfiguration, environment, toml_providers - @configspec class EmbeddedWithGcpStorage(BaseConfiguration): @@ -31,7 +51,6 @@ class EmbeddedWithGcpCredentials(BaseConfiguration): def test_secrets_from_toml_secrets(toml_providers: ConfigProvidersContext) -> None: - # remove secret_value to trigger exception del toml_providers["secrets.toml"]._toml["secret_value"] @@ -63,10 +82,8 @@ def test_toml_types(toml_providers: ConfigProvidersContext) -> None: def test_config_provider_order(toml_providers: ConfigProvidersContext, environment: Any) -> None: - # add env provider - @with_config(sections=("api",)) def single_val(port=None): return port @@ -86,7 +103,11 @@ def test_toml_mixed_config_inject(toml_providers: ConfigProvidersContext) -> Non # get data from both providers @with_config - def mixed_val(api_type=dlt.config.value, secret_value: TSecretValue = dlt.secrets.value, typecheck: Any = dlt.config.value): + def mixed_val( + api_type=dlt.config.value, + secret_value: TSecretValue = dlt.secrets.value, + typecheck: Any = dlt.config.value, + ): return api_type, secret_value, typecheck _tup = mixed_val(None, None, None) @@ -109,13 +130,19 @@ def test_toml_sections(toml_providers: ConfigProvidersContext) -> None: def test_secrets_toml_credentials(environment: Any, toml_providers: ConfigProvidersContext) -> None: # there are credentials exactly under destination.bigquery.credentials - c = resolve.resolve_configuration(GcpServiceAccountCredentialsWithoutDefaults(), sections=("destination", "bigquery")) + c = resolve.resolve_configuration( + GcpServiceAccountCredentialsWithoutDefaults(), sections=("destination", 
"bigquery") + ) assert c.project_id.endswith("destination.bigquery.credentials") # there are no destination.gcp_storage.credentials so it will fallback to "destination"."credentials" - c = resolve.resolve_configuration(GcpServiceAccountCredentialsWithoutDefaults(), sections=("destination", "gcp_storage")) + c = resolve.resolve_configuration( + GcpServiceAccountCredentialsWithoutDefaults(), sections=("destination", "gcp_storage") + ) assert c.project_id.endswith("destination.credentials") # also explicit - c = resolve.resolve_configuration(GcpServiceAccountCredentialsWithoutDefaults(), sections=("destination",)) + c = resolve.resolve_configuration( + GcpServiceAccountCredentialsWithoutDefaults(), sections=("destination",) + ) assert c.project_id.endswith("destination.credentials") # there's "credentials" key but does not contain valid gcp credentials with pytest.raises(ConfigFieldMissingException): @@ -132,12 +159,18 @@ def test_secrets_toml_credentials(environment: Any, toml_providers: ConfigProvid resolve.resolve_configuration(c, sections=("destination", "bigquery")) -def test_secrets_toml_embedded_credentials(environment: Any, toml_providers: ConfigProvidersContext) -> None: +def test_secrets_toml_embedded_credentials( + environment: Any, toml_providers: ConfigProvidersContext +) -> None: # will try destination.bigquery.credentials - c = resolve.resolve_configuration(EmbeddedWithGcpCredentials(), sections=("destination", "bigquery")) + c = resolve.resolve_configuration( + EmbeddedWithGcpCredentials(), sections=("destination", "bigquery") + ) assert c.credentials.project_id.endswith("destination.bigquery.credentials") # will try destination.gcp_storage.credentials and fallback to destination.credentials - c = resolve.resolve_configuration(EmbeddedWithGcpCredentials(), sections=("destination", "gcp_storage")) + c = resolve.resolve_configuration( + EmbeddedWithGcpCredentials(), sections=("destination", "gcp_storage") + ) assert c.credentials.project_id.endswith("destination.credentials") # will try everything until credentials in the root where incomplete credentials are present c = EmbeddedWithGcpCredentials() @@ -150,11 +183,15 @@ def test_secrets_toml_embedded_credentials(environment: Any, toml_providers: Con assert set(py_ex.value.traces.keys()) == {"client_email", "private_key"} # embed "gcp_storage" will bubble up to the very top, never reverts to "credentials" - c = resolve.resolve_configuration(EmbeddedWithGcpStorage(), sections=("destination", "bigquery")) + c = resolve.resolve_configuration( + EmbeddedWithGcpStorage(), sections=("destination", "bigquery") + ) assert c.gcp_storage.project_id.endswith("-gcp-storage") # also explicit - c = resolve.resolve_configuration(GcpServiceAccountCredentialsWithoutDefaults(), sections=("destination",)) + c = resolve.resolve_configuration( + GcpServiceAccountCredentialsWithoutDefaults(), sections=("destination",) + ) assert c.project_id.endswith("destination.credentials") # there's "credentials" key but does not contain valid gcp credentials with pytest.raises(ConfigFieldMissingException): @@ -166,13 +203,22 @@ def test_dicts_are_not_enumerated() -> None: pass -def test_secrets_toml_credentials_from_native_repr(environment: Any, toml_providers: ConfigProvidersContext) -> None: +def test_secrets_toml_credentials_from_native_repr( + environment: Any, toml_providers: ConfigProvidersContext +) -> None: # cfg = toml_providers["secrets.toml"] # print(cfg._toml) # print(cfg._toml["source"]["credentials"]) # resolve gcp_credentials by parsing initial 
value which is str holding json doc - c = resolve.resolve_configuration(GcpServiceAccountCredentialsWithoutDefaults(), sections=("source",)) - assert c.private_key == "-----BEGIN PRIVATE KEY-----\nMIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQCNEN0bL39HmD+S\n...\n-----END PRIVATE KEY-----\n" + c = resolve.resolve_configuration( + GcpServiceAccountCredentialsWithoutDefaults(), sections=("source",) + ) + assert ( + c.private_key + == "-----BEGIN PRIVATE" + " KEY-----\nMIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQCNEN0bL39HmD+S\n...\n-----END" + " PRIVATE KEY-----\n" + ) # but project id got overridden from credentials.project_id assert c.project_id.endswith("-credentials") # also try sql alchemy url (native repr) @@ -251,19 +297,33 @@ def test_write_value(toml_providers: ConfigProvidersContext) -> None: # this will create path of tables provider.set_value("deep_int", 2137, "deep_pipeline", "deep", "deep", "deep", "deep") assert provider._toml["deep_pipeline"]["deep"]["deep"]["deep"]["deep"]["deep_int"] == 2137 - assert provider.get_value("deep_int", Any, "deep_pipeline", "deep", "deep", "deep", "deep") == (2137, "deep_pipeline.deep.deep.deep.deep.deep_int") + assert provider.get_value( + "deep_int", Any, "deep_pipeline", "deep", "deep", "deep", "deep" + ) == (2137, "deep_pipeline.deep.deep.deep.deep.deep_int") # same without the pipeline now = pendulum.now() provider.set_value("deep_date", now, None, "deep", "deep", "deep", "deep") - assert provider.get_value("deep_date", Any, None, "deep", "deep", "deep", "deep") == (now, "deep.deep.deep.deep.deep_date") + assert provider.get_value("deep_date", Any, None, "deep", "deep", "deep", "deep") == ( + now, + "deep.deep.deep.deep.deep_date", + ) # in existing path provider.set_value("deep_list", [1, 2, 3], None, "deep", "deep", "deep") - assert provider.get_value("deep_list", Any, None, "deep", "deep", "deep") == ([1, 2, 3], "deep.deep.deep.deep_list") + assert provider.get_value("deep_list", Any, None, "deep", "deep", "deep") == ( + [1, 2, 3], + "deep.deep.deep.deep_list", + ) # still there - assert provider.get_value("deep_date", Any, None, "deep", "deep", "deep", "deep") == (now, "deep.deep.deep.deep.deep_date") + assert provider.get_value("deep_date", Any, None, "deep", "deep", "deep", "deep") == ( + now, + "deep.deep.deep.deep.deep_date", + ) # overwrite value provider.set_value("deep_list", [1, 2, 3, 4], None, "deep", "deep", "deep") - assert provider.get_value("deep_list", Any, None, "deep", "deep", "deep") == ([1, 2, 3, 4], "deep.deep.deep.deep_list") + assert provider.get_value("deep_list", Any, None, "deep", "deep", "deep") == ( + [1, 2, 3, 4], + "deep.deep.deep.deep_list", + ) # invalid type with pytest.raises(ValueError): provider.set_value("deep_decimal", Decimal("1.2"), None, "deep", "deep", "deep", "deep") @@ -271,24 +331,39 @@ def test_write_value(toml_providers: ConfigProvidersContext) -> None: # write new dict to a new key test_d1 = {"key": "top", "embed": {"inner": "bottom", "inner_2": True}} provider.set_value("deep_dict", test_d1, None, "dict_test") - assert provider.get_value("deep_dict", Any, None, "dict_test") == (test_d1, "dict_test.deep_dict") + assert provider.get_value("deep_dict", Any, None, "dict_test") == ( + test_d1, + "dict_test.deep_dict", + ) # write same dict over dict provider.set_value("deep_dict", test_d1, None, "dict_test") - assert provider.get_value("deep_dict", Any, None, "dict_test") == (test_d1, "dict_test.deep_dict") + assert provider.get_value("deep_dict", Any, None, "dict_test") == ( + 
test_d1, + "dict_test.deep_dict", + ) # get a fragment - assert provider.get_value("inner_2", Any, None, "dict_test", "deep_dict", "embed") == (True, "dict_test.deep_dict.embed.inner_2") + assert provider.get_value("inner_2", Any, None, "dict_test", "deep_dict", "embed") == ( + True, + "dict_test.deep_dict.embed.inner_2", + ) # write a dict over non dict provider.set_value("deep_list", test_d1, None, "deep", "deep", "deep") - assert provider.get_value("deep_list", Any, None, "deep", "deep", "deep") == (test_d1, "deep.deep.deep.deep_list") + assert provider.get_value("deep_list", Any, None, "deep", "deep", "deep") == ( + test_d1, + "deep.deep.deep.deep_list", + ) # merge dicts test_d2 = {"key": "_top", "key2": "new2", "embed": {"inner": "_bottom", "inner_3": 2121}} provider.set_value("deep_dict", test_d2, None, "dict_test") test_m_d1_d2 = { "key": "_top", "embed": {"inner": "_bottom", "inner_2": True, "inner_3": 2121}, - "key2": "new2" + "key2": "new2", } - assert provider.get_value("deep_dict", Any, None, "dict_test") == (test_m_d1_d2, "dict_test.deep_dict") + assert provider.get_value("deep_dict", Any, None, "dict_test") == ( + test_m_d1_d2, + "dict_test.deep_dict", + ) # print(provider.get_value("deep_dict", Any, None, "dict_test")) # write configuration @@ -354,7 +429,6 @@ def test_write_toml_value(toml_providers: ConfigProvidersContext) -> None: def test_toml_string_provider() -> None: - # test basic reading provider = StringTomlProvider(""" [section1.subsection] @@ -364,8 +438,14 @@ def test_toml_string_provider() -> None: key2 = "value2" """) - assert provider.get_value("key1", "", "section1", "subsection") == ("value1", "section1.subsection.key1") - assert provider.get_value("key2", "", "section2", "subsection") == ("value2", "section2.subsection.key2") + assert provider.get_value("key1", "", "section1", "subsection") == ( + "value1", + "section1.subsection.key1", + ) + assert provider.get_value("key2", "", "section2", "subsection") == ( + "value2", + "section2.subsection.key2", + ) # test basic writing provider = StringTomlProvider("") diff --git a/tests/common/configuration/utils.py b/tests/common/configuration/utils.py index 93e8ba638e..e53d464e99 100644 --- a/tests/common/configuration/utils.py +++ b/tests/common/configuration/utils.py @@ -1,16 +1,22 @@ -import pytest -from os import environ import datetime # noqa: I251 -from typing import Any, Iterator, List, Optional, Tuple, Type, Dict, MutableMapping, Optional, Sequence +from os import environ +from typing import Any, Dict, Iterator, List, MutableMapping, Optional, Sequence, Tuple, Type + +import pytest from dlt.common import Decimal, pendulum from dlt.common.configuration import configspec -from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration from dlt.common.configuration.container import Container -from dlt.common.configuration.providers import ConfigProvider, EnvironProvider, ConfigTomlProvider, SecretsTomlProvider -from dlt.common.configuration.utils import get_resolved_traces +from dlt.common.configuration.providers import ( + ConfigProvider, + ConfigTomlProvider, + EnvironProvider, + SecretsTomlProvider, +) +from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext -from dlt.common.typing import TSecretValue, StrAny +from dlt.common.configuration.utils import get_resolved_traces +from dlt.common.typing import StrAny, TSecretValue @configspec @@ -111,7 +117,6 @@ def 
toml_providers() -> Iterator[ConfigProvidersContext]: class MockProvider(ConfigProvider): - def __init__(self) -> None: self.value: Any = None self.return_value_on: Tuple[str] = () @@ -121,9 +126,11 @@ def reset_stats(self) -> None: self.last_section: Tuple[str] = None self.last_sections: List[Tuple[str]] = [] - def get_value(self, key: str, hint: Type[Any], pipeline_name: str, *sections: str) -> Tuple[Optional[Any], str]: + def get_value( + self, key: str, hint: Type[Any], pipeline_name: str, *sections: str + ) -> Tuple[Optional[Any], str]: if pipeline_name: - sections = (pipeline_name, ) + sections + sections = (pipeline_name,) + sections self.last_section = sections self.last_sections.append(sections) if sections == self.return_value_on: @@ -152,27 +159,21 @@ def supports_secrets(self) -> bool: COERCIONS = { - 'str_val': 'test string', - 'int_val': 12345, - 'bool_val': True, - 'list_val': [1, "2", [3]], - 'dict_val': { - 'a': 1, - "b": "2" - }, - 'bytes_val': b'Hello World!', - 'float_val': 1.18927, + "str_val": "test string", + "int_val": 12345, + "bool_val": True, + "list_val": [1, "2", [3]], + "dict_val": {"a": 1, "b": "2"}, + "bytes_val": b"Hello World!", + "float_val": 1.18927, "tuple_val": (1, 2, {"1": "complicated dicts allowed in literal eval"}), - 'any_val': "function() {}", - 'none_val': "none", - 'COMPLEX_VAL': { - "_": [1440, ["*"], []], - "change-email": [560, ["*"], []] - }, + "any_val": "function() {}", + "none_val": "none", + "COMPLEX_VAL": {"_": [1440, ["*"], []], "change-email": [560, ["*"], []]}, "date_val": pendulum.now(), "dec_val": Decimal("22.38"), "sequence_val": ["A", "B", "KAPPA"], "gen_list_val": ["C", "Z", "N"], "mapping_val": {"FL": 1, "FR": {"1": 2}}, - "mutable_mapping_val": {"str": "str"} + "mutable_mapping_val": {"str": "str"}, } diff --git a/tests/common/normalizers/custom_normalizers.py b/tests/common/normalizers/custom_normalizers.py index 8e24ffab5a..3ae65c8b53 100644 --- a/tests/common/normalizers/custom_normalizers.py +++ b/tests/common/normalizers/custom_normalizers.py @@ -5,7 +5,6 @@ class NamingConvention(SnakeCaseNamingConvention): - def normalize_identifier(self, identifier: str) -> str: if identifier.startswith("column_"): return identifier @@ -13,12 +12,12 @@ def normalize_identifier(self, identifier: str) -> str: class DataItemNormalizer(RelationalNormalizer): - def extend_schema(self) -> None: json_config = self.schema._normalizers_config["json"]["config"] d_h = self.schema._settings.setdefault("default_hints", {}) d_h["not_null"] = json_config["not_null"] - - def normalize_data_item(self, source_event: TDataItem, load_id: str, table_name) -> TNormalizedRowIterator: + def normalize_data_item( + self, source_event: TDataItem, load_id: str, table_name + ) -> TNormalizedRowIterator: yield (table_name, None), source_event diff --git a/tests/common/normalizers/test_import_normalizers.py b/tests/common/normalizers/test_import_normalizers.py index 1b939dfd6e..43e0cdefee 100644 --- a/tests/common/normalizers/test_import_normalizers.py +++ b/tests/common/normalizers/test_import_normalizers.py @@ -1,35 +1,35 @@ import os import pytest +from tests.common.normalizers.custom_normalizers import ( + DataItemNormalizer as CustomRelationalNormalizer, +) +from tests.utils import preserve_environ from dlt.common.configuration.container import Container from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.normalizers import explicit_normalizers, import_normalizers from dlt.common.normalizers.json.relational import 
DataItemNormalizer as RelationalNormalizer -from dlt.common.normalizers.naming import snake_case -from dlt.common.normalizers.naming import direct +from dlt.common.normalizers.naming import direct, snake_case from dlt.common.normalizers.naming.exceptions import InvalidNamingModule, UnknownNamingModule -from tests.common.normalizers.custom_normalizers import DataItemNormalizer as CustomRelationalNormalizer -from tests.utils import preserve_environ - def test_default_normalizers() -> None: config = explicit_normalizers() - assert config['names'] is None - assert config['json'] is None + assert config["names"] is None + assert config["json"] is None # pass explicit config = explicit_normalizers("direct", {"module": "custom"}) - assert config['names'] == "direct" - assert config['json'] == {"module": "custom"} + assert config["names"] == "direct" + assert config["json"] == {"module": "custom"} # use environ os.environ["SCHEMA__NAMING"] = "direct" os.environ["SCHEMA__JSON_NORMALIZER"] = '{"module": "custom"}' config = explicit_normalizers() - assert config['names'] == "direct" - assert config['json'] == {"module": "custom"} + assert config["names"] == "direct" + assert config["json"] == {"module": "custom"} def test_default_normalizers_with_caps() -> None: @@ -38,8 +38,7 @@ def test_default_normalizers_with_caps() -> None: destination_caps.naming_convention = "direct" with Container().injectable_context(destination_caps): config = explicit_normalizers() - assert config['names'] == "direct" - + assert config["names"] == "direct" def test_import_normalizers() -> None: @@ -52,7 +51,9 @@ def test_import_normalizers() -> None: assert config["json"] == {"module": "dlt.common.normalizers.json.relational"} os.environ["SCHEMA__NAMING"] = "direct" - os.environ["SCHEMA__JSON_NORMALIZER"] = '{"module": "tests.common.normalizers.custom_normalizers"}' + os.environ["SCHEMA__JSON_NORMALIZER"] = ( + '{"module": "tests.common.normalizers.custom_normalizers"}' + ) config, naming, json_normalizer = import_normalizers(explicit_normalizers()) assert config["names"] == "direct" assert config["json"] == {"module": "tests.common.normalizers.custom_normalizers"} diff --git a/tests/common/normalizers/test_json_relational.py b/tests/common/normalizers/test_json_relational.py index e344ca28d2..93338c5f69 100644 --- a/tests/common/normalizers/test_json_relational.py +++ b/tests/common/normalizers/test_json_relational.py @@ -1,15 +1,17 @@ import pytest +from tests.utils import create_schema_with_name +from dlt.common.normalizers.json.relational import DLT_ID_LENGTH_BYTES +from dlt.common.normalizers.json.relational import DataItemNormalizer as RelationalNormalizer +from dlt.common.normalizers.json.relational import RelationalNormalizerConfigPropagation from dlt.common.normalizers.naming import NamingConvention -from dlt.common.schema.typing import TSimpleRegex -from dlt.common.utils import digest128, uniq_id from dlt.common.schema import Schema +from dlt.common.schema.typing import TSimpleRegex from dlt.common.schema.utils import new_table +from dlt.common.utils import digest128, uniq_id -from dlt.common.normalizers.json.relational import RelationalNormalizerConfigPropagation, DataItemNormalizer as RelationalNormalizer, DLT_ID_LENGTH_BYTES # _flatten, _get_child_row_hash, _normalize_row, normalize_data_item, -from tests.utils import create_schema_with_name @pytest.fixture def norm() -> RelationalNormalizer: @@ -20,15 +22,7 @@ def test_flatten_fix_field_name(norm: RelationalNormalizer) -> None: row = { "f-1": "! 
30", "f 2": [], - "f!3": { - "f4": "a", - "f-5": "b", - "f*6": { - "c": 7, - "c v": 8, - "c x": [] - } - } + "f!3": {"f4": "a", "f-5": "b", "f*6": {"c": 7, "c v": 8, "c x": []}}, } flattened_row, lists = norm._flatten("mock_table", row, 0) assert "f_1" in flattened_row @@ -40,29 +34,26 @@ def test_flatten_fix_field_name(norm: RelationalNormalizer) -> None: # assert "f_3__f_6__c_x" in flattened_row assert "f_3" not in flattened_row - assert ("f_2", ) in lists - assert ("f_3", "fx6", "c_x", ) in lists + assert ("f_2",) in lists + assert ( + "f_3", + "fx6", + "c_x", + ) in lists def test_preserve_complex_value(norm: RelationalNormalizer) -> None: # add table with complex column norm.schema.update_schema( - new_table("with_complex", - columns = [{ - "name": "value", - "data_type": "complex", - "nullable": "true" - }]) + new_table( + "with_complex", columns=[{"name": "value", "data_type": "complex", "nullable": "true"}] + ) ) - row_1 = { - "value": 1 - } + row_1 = {"value": 1} flattened_row, _ = norm._flatten("with_complex", row_1, 0) assert flattened_row["value"] == 1 - row_2 = { - "value": {"complex": True} - } + row_2 = {"value": {"complex": True}} flattened_row, _ = norm._flatten("with_complex", row_2, 0) assert flattened_row["value"] == row_2["value"] # complex value is not flattened @@ -74,15 +65,11 @@ def test_preserve_complex_value_with_hint(norm: RelationalNormalizer) -> None: norm.schema._settings.setdefault("preferred_types", {})["re:^value$"] = "complex" norm.schema._compile_settings() - row_1 = { - "value": 1 - } + row_1 = {"value": 1} flattened_row, _ = norm._flatten("any_table", row_1, 0) assert flattened_row["value"] == 1 - row_2 = { - "value": {"complex": True} - } + row_2 = {"value": {"complex": True}} flattened_row, _ = norm._flatten("any_table", row_2, 0) assert flattened_row["value"] == row_2["value"] # complex value is not flattened @@ -90,17 +77,11 @@ def test_preserve_complex_value_with_hint(norm: RelationalNormalizer) -> None: def test_child_table_linking(norm: RelationalNormalizer) -> None: - row = { - "f": [{ - "l": ["a", "b", "c"], - "v": 120, - "o": [{"a": 1}, {"a": 2}] - }] - } + row = {"f": [{"l": ["a", "b", "c"], "v": 120, "o": [{"a": 1}, {"a": 2}]}]} # request _dlt_root_id propagation add_dlt_root_id_propagation(norm) - rows = list(norm._normalize_row(row, {}, ("table", ))) + rows = list(norm._normalize_row(row, {}, ("table",))) # should have 7 entries (root + level 1 + 3 * list + 2 * object) assert len(rows) == 7 # root elem will not have a root hash if not explicitly added, "extend" is added only to child @@ -143,17 +124,12 @@ def test_child_table_linking(norm: RelationalNormalizer) -> None: def test_child_table_linking_primary_key(norm: RelationalNormalizer) -> None: row = { "id": "level0", - "f": [{ - "id": "level1", - "l": ["a", "b", "c"], - "v": 120, - "o": [{"a": 1}, {"a": 2}] - }] + "f": [{"id": "level1", "l": ["a", "b", "c"], "v": 120, "o": [{"a": 1}, {"a": 2}]}], } norm.schema.merge_hints({"primary_key": ["id"]}) norm.schema._compile_settings() - rows = list(norm._normalize_row(row, {}, ("table", ))) + rows = list(norm._normalize_row(row, {}, ("table",))) root = next(t for t in rows if t[0][0] == "table")[1] # record hash is random for primary keys, not based on their content # this is a change introduced in dlt 0.2.0a30 @@ -168,7 +144,9 @@ def test_child_table_linking_primary_key(norm: RelationalNormalizer) -> None: assert "_dlt_root_id" not in t_f list_rows = [t for t in rows if t[0][0] == "table__f__l"] - assert all(e[1]["_dlt_parent_id"] != 
digest128("level1", DLT_ID_LENGTH_BYTES) for e in list_rows) + assert all( + e[1]["_dlt_parent_id"] != digest128("level1", DLT_ID_LENGTH_BYTES) for e in list_rows + ) assert all(r[0][1] == "table__f" for r in list_rows) obj_rows = [t for t in rows if t[0][0] == "table__f__o"] assert all(e[1]["_dlt_parent_id"] != digest128("level1", DLT_ID_LENGTH_BYTES) for e in obj_rows) @@ -178,50 +156,56 @@ def test_child_table_linking_primary_key(norm: RelationalNormalizer) -> None: def test_yields_parents_first(norm: RelationalNormalizer) -> None: row = { "id": "level0", - "f": [{ - "id": "level1", - "l": ["a", "b", "c"], - "v": 120, - "o": [{"a": 1}, {"a": 2}] - }], - "g": [{ - "id": "level2_g", - "l": ["a"] - }] + "f": [{"id": "level1", "l": ["a", "b", "c"], "v": 120, "o": [{"a": 1}, {"a": 2}]}], + "g": [{"id": "level2_g", "l": ["a"]}], } - rows = list(norm._normalize_row(row, {}, ("table", ))) + rows = list(norm._normalize_row(row, {}, ("table",))) tables = list(r[0][0] for r in rows) # child tables are always yielded before parent tables - expected_tables = ['table', 'table__f', 'table__f__l', 'table__f__l', 'table__f__l', 'table__f__o', 'table__f__o', 'table__g', 'table__g__l'] + expected_tables = [ + "table", + "table__f", + "table__f__l", + "table__f__l", + "table__f__l", + "table__f__o", + "table__f__o", + "table__g", + "table__g__l", + ] assert expected_tables == tables def test_yields_parent_relation(norm: RelationalNormalizer) -> None: row = { "id": "level0", - "f": [{ - "id": "level1", - "l": ["a"], - "o": [{"a": 1}], - "b": { - "a": [ {"id": "level5"}], + "f": [ + { + "id": "level1", + "l": ["a"], + "o": [{"a": 1}], + "b": { + "a": [{"id": "level5"}], + }, } - }], + ], "d": { - "a": [ {"id": "level4"}], + "a": [{"id": "level4"}], "b": { - "a": [ {"id": "level5"}], + "a": [{"id": "level5"}], }, - "c": "x" + "c": "x", }, - "e": [{ - "o": [{"a": 1}], - "b": { - "a": [ {"id": "level5"}], + "e": [ + { + "o": [{"a": 1}], + "b": { + "a": [{"id": "level5"}], + }, } - }] + ], } - rows = list(norm._normalize_row(row, {}, ("table", ))) + rows = list(norm._normalize_row(row, {}, ("table",))) # normalizer must return parent table first and move in order of the list elements when yielding child tables # the yielding order if fully defined expected_parents = [ @@ -237,7 +221,7 @@ def test_yields_parent_relation(norm: RelationalNormalizer) -> None: # table__e is yielded it however only contains linking information ("table__e", "table"), ("table__e__o", "table__e"), - ("table__e__b__a", "table__e") + ("table__e__b__a", "table__e"), ] parents = list(r[0] for r in rows) assert parents == expected_parents @@ -279,14 +263,8 @@ def test_yields_parent_relation(norm: RelationalNormalizer) -> None: def test_list_position(norm: RelationalNormalizer) -> None: - row = { - "f": [{ - "l": ["a", "b", "c"], - "v": 120, - "lo": [{"e": "a"}, {"e": "b"}, {"e":"c"}] - }] - } - rows = list(norm._normalize_row(row, {}, ("table", ))) + row = {"f": [{"l": ["a", "b", "c"], "v": 120, "lo": [{"e": "a"}, {"e": "b"}, {"e": "c"}]}]} + rows = list(norm._normalize_row(row, {}, ("table",))) # root has no pos root = [t for t in rows if t[0][0] == "table"][0][1] assert "_dlt_list_idx" not in root @@ -325,26 +303,23 @@ def test_list_in_list() -> None: "_dlt_id": "123456", "created_at": "2023-05-12T12:34:56Z", "ended_at": "2023-05-12T13:14:32Z", - "webpath": [[ - { - "url": "https://www.website.com/", - "timestamp": "2023-05-12T12:35:01Z" - }, - { - "url": "https://www.website.com/products", - "timestamp": "2023-05-12T12:38:45Z" - }, + 
"webpath": [ + [ + {"url": "https://www.website.com/", "timestamp": "2023-05-12T12:35:01Z"}, + {"url": "https://www.website.com/products", "timestamp": "2023-05-12T12:38:45Z"}, { "url": "https://www.website.com/products/item123", - "timestamp": "2023-05-12T12:42:22Z" + "timestamp": "2023-05-12T12:42:22Z", }, - [{ - "url": "https://www.website.com/products/item1234", - "timestamp": "2023-05-12T12:42:22Z" - }] + [ + { + "url": "https://www.website.com/products/item1234", + "timestamp": "2023-05-12T12:42:22Z", + } + ], ], - [1, 2, 3] - ] + [1, 2, 3], + ], } schema = create_schema_with_name("other") # root @@ -354,12 +329,12 @@ def test_list_in_list() -> None: zen__webpath = [row for row in rows if row[0][0] == "zen__webpath"] # two rows in web__zenpath for two lists assert len(zen__webpath) == 2 - assert zen__webpath[0][0] == ('zen__webpath', 'zen') + assert zen__webpath[0][0] == ("zen__webpath", "zen") # _dlt_id was hardcoded in the original row assert zen__webpath[0][1]["_dlt_parent_id"] == "123456" - assert zen__webpath[0][1]['_dlt_list_idx'] == 0 - assert zen__webpath[1][1]['_dlt_list_idx'] == 1 - assert zen__webpath[1][0] == ('zen__webpath', 'zen') + assert zen__webpath[0][1]["_dlt_list_idx"] == 0 + assert zen__webpath[1][1]["_dlt_list_idx"] == 1 + assert zen__webpath[1][0] == ("zen__webpath", "zen") # inner lists zen__webpath__list = [row for row in rows if row[0][0] == "zen__webpath__list"] # actually both list of objects and list of number will be in the same table @@ -373,7 +348,9 @@ def test_list_in_list() -> None: zen_table = new_table("zen") schema.update_schema(zen_table) - path_table = new_table("zen__webpath", parent_table_name="zen", columns=[{"name": "list", "data_type": "complex"}]) + path_table = new_table( + "zen__webpath", parent_table_name="zen", columns=[{"name": "list", "data_type": "complex"}] + ) schema.update_schema(path_table) rows = list(schema.normalize_data_item(chats, "1762162.1212", "zen")) # both lists are complex types now @@ -387,13 +364,9 @@ def test_child_row_deterministic_hash(norm: RelationalNormalizer) -> None: # directly set record hash so it will be adopted in normalizer as top level hash row = { "_dlt_id": row_id, - "f": [{ - "l": ["a", "b", "c"], - "v": 120, - "lo": [{"e": "a"}, {"e": "b"}, {"e":"c"}] - }] + "f": [{"l": ["a", "b", "c"], "v": 120, "lo": [{"e": "a"}, {"e": "b"}, {"e": "c"}]}], } - rows = list(norm._normalize_row(row, {}, ("table", ))) + rows = list(norm._normalize_row(row, {}, ("table",))) children = [t for t in rows if t[0][0] != "table"] # all hashes must be different distinct_hashes = set([ch[1]["_dlt_id"] for ch in children]) @@ -401,7 +374,9 @@ def test_child_row_deterministic_hash(norm: RelationalNormalizer) -> None: # compute hashes for all children for (table, _), ch in children: - expected_hash = digest128(f"{ch['_dlt_parent_id']}_{table}_{ch['_dlt_list_idx']}", DLT_ID_LENGTH_BYTES) + expected_hash = digest128( + f"{ch['_dlt_parent_id']}_{table}_{ch['_dlt_list_idx']}", DLT_ID_LENGTH_BYTES + ) assert ch["_dlt_id"] == expected_hash # direct compute one of the @@ -410,54 +385,66 @@ def test_child_row_deterministic_hash(norm: RelationalNormalizer) -> None: assert f_lo_p2["_dlt_id"] == digest128(f"{el_f['_dlt_id']}_table__f__lo_2", DLT_ID_LENGTH_BYTES) # same data with same table and row_id - rows_2 = list(norm._normalize_row(row, {}, ("table", ))) + rows_2 = list(norm._normalize_row(row, {}, ("table",))) children_2 = [t for t in rows_2 if t[0][0] != "table"] # corresponding hashes must be identical assert all(ch[0][1]["_dlt_id"] 
== ch[1][1]["_dlt_id"] for ch in zip(children, children_2)) # change parent table and all child hashes must be different - rows_4 = list(norm._normalize_row(row, {}, ("other_table", ))) + rows_4 = list(norm._normalize_row(row, {}, ("other_table",))) children_4 = [t for t in rows_4 if t[0][0] != "other_table"] assert all(ch[0][1]["_dlt_id"] != ch[1][1]["_dlt_id"] for ch in zip(children, children_4)) # change parent hash and all child hashes must be different row["_dlt_id"] = uniq_id() - rows_3 = list(norm._normalize_row(row, {}, ("table", ))) + rows_3 = list(norm._normalize_row(row, {}, ("table",))) children_3 = [t for t in rows_3 if t[0][0] != "table"] assert all(ch[0][1]["_dlt_id"] != ch[1][1]["_dlt_id"] for ch in zip(children, children_3)) def test_keeps_dlt_id(norm: RelationalNormalizer) -> None: h = uniq_id() - row = { - "a": "b", - "_dlt_id": h - } - rows = list(norm._normalize_row(row, {}, ("table", ))) + row = {"a": "b", "_dlt_id": h} + rows = list(norm._normalize_row(row, {}, ("table",))) root = [t for t in rows if t[0][0] == "table"][0][1] assert root["_dlt_id"] == h def test_propagate_hardcoded_context(norm: RelationalNormalizer) -> None: row = {"level": 1, "list": ["a", "b", "c"], "comp": [{"_timestamp": "a"}]} - rows = list(norm._normalize_row(row, {"_timestamp": 1238.9, "_dist_key": "SENDER_3000"}, ("table", ))) + rows = list( + norm._normalize_row(row, {"_timestamp": 1238.9, "_dist_key": "SENDER_3000"}, ("table",)) + ) # context is not added to root element root = next(t for t in rows if t[0][0] == "table")[1] assert "_timestamp" in root assert "_dist_key" in root # the original _timestamp field will be overwritten in children and added to lists - assert all(e[1]["_timestamp"] == 1238.9 and e[1]["_dist_key"] == "SENDER_3000" for e in rows if e[0][0] != "table") + assert all( + e[1]["_timestamp"] == 1238.9 and e[1]["_dist_key"] == "SENDER_3000" + for e in rows + if e[0][0] != "table" + ) def test_propagates_root_context(norm: RelationalNormalizer) -> None: add_dlt_root_id_propagation(norm) # add timestamp propagation - norm.schema._normalizers_config["json"]["config"]["propagation"]["root"]["timestamp"] = "_partition_ts" + norm.schema._normalizers_config["json"]["config"]["propagation"]["root"][ + "timestamp" + ] = "_partition_ts" # add propagation for non existing element - norm.schema._normalizers_config["json"]["config"]["propagation"]["root"]["__not_found"] = "__not_found" + norm.schema._normalizers_config["json"]["config"]["propagation"]["root"][ + "__not_found" + ] = "__not_found" - row = {"_dlt_id": "###", "timestamp": 12918291.1212, "dependent_list":[1, 2,3], "dependent_objects": [{"vx": "ax"}]} - normalized_rows = list(norm._normalize_row(row, {}, ("table", ))) + row = { + "_dlt_id": "###", + "timestamp": 12918291.1212, + "dependent_list": [1, 2, 3], + "dependent_objects": [{"vx": "ax"}], + } + normalized_rows = list(norm._normalize_row(row, {}, ("table",))) # all non-root rows must have: non_root = [r for r in normalized_rows if r[0][1] is not None] assert all(r[1]["_dlt_root_id"] == "###" for r in non_root) @@ -466,15 +453,19 @@ def test_propagates_root_context(norm: RelationalNormalizer) -> None: @pytest.mark.parametrize("add_pk,add_dlt_id", [(False, False), (True, False), (True, True)]) -def test_propagates_table_context(norm: RelationalNormalizer, add_pk: bool, add_dlt_id: bool) -> None: +def test_propagates_table_context( + norm: RelationalNormalizer, add_pk: bool, add_dlt_id: bool +) -> None: add_dlt_root_id_propagation(norm) - prop_config: 
RelationalNormalizerConfigPropagation = norm.schema._normalizers_config["json"]["config"]["propagation"] + prop_config: RelationalNormalizerConfigPropagation = norm.schema._normalizers_config["json"][ + "config" + ]["propagation"] prop_config["root"]["timestamp"] = "_partition_ts" # for table "table__lvl1" request to propagate "vx" and "partition_ovr" as "_partition_ts" (should overwrite root) prop_config["tables"]["table__lvl1"] = { "vx": "__vx", "partition_ovr": "_partition_ts", - "__not_found": "__not_found" + "__not_found": "__not_found", } if add_pk: @@ -482,21 +473,17 @@ def test_propagates_table_context(norm: RelationalNormalizer, add_pk: bool, add_ norm.schema.merge_hints({"primary_key": [TSimpleRegex("vx")]}) row = { - "_dlt_id": "###", - "timestamp": 12918291.1212, - "lvl1": [{ - "vx": "ax", - "partition_ovr": 1283.12, - "lvl2": [{ - "_partition_ts": "overwritten" - }] - }] - } + "_dlt_id": "###", + "timestamp": 12918291.1212, + "lvl1": [ + {"vx": "ax", "partition_ovr": 1283.12, "lvl2": [{"_partition_ts": "overwritten"}]} + ], + } if add_dlt_id: # to reproduce a bug where rows with _dlt_id set were not extended row["lvl1"][0]["_dlt_id"] = "row_id_lvl1" - normalized_rows = list(norm._normalize_row(row, {}, ("table", ))) + normalized_rows = list(norm._normalize_row(row, {}, ("table",))) non_root = [r for r in normalized_rows if r[0][1] is not None] # _dlt_root_id in all non root assert all(r[1]["_dlt_root_id"] == "###" for r in non_root) @@ -505,21 +492,30 @@ def test_propagates_table_context(norm: RelationalNormalizer, add_pk: bool, add_ # _partition_ts == timestamp only at lvl1 assert all(r[1]["_partition_ts"] == 12918291.1212 for r in non_root if r[0][0] == "table__lvl1") # _partition_ts == partition_ovr and __vx only at lvl2 - assert all(r[1]["_partition_ts"] == 1283.12 and r[1]["__vx"] == "ax" for r in non_root if r[0][0] == "table__lvl1__lvl2") - assert any(r[1]["_partition_ts"] == 1283.12 and r[1]["__vx"] == "ax" for r in non_root if r[0][0] != "table__lvl1__lvl2") is False + assert all( + r[1]["_partition_ts"] == 1283.12 and r[1]["__vx"] == "ax" + for r in non_root + if r[0][0] == "table__lvl1__lvl2" + ) + assert ( + any( + r[1]["_partition_ts"] == 1283.12 and r[1]["__vx"] == "ax" + for r in non_root + if r[0][0] != "table__lvl1__lvl2" + ) + is False + ) def test_propagates_table_context_to_lists(norm: RelationalNormalizer) -> None: add_dlt_root_id_propagation(norm) - prop_config: RelationalNormalizerConfigPropagation = norm.schema._normalizers_config["json"]["config"]["propagation"] + prop_config: RelationalNormalizerConfigPropagation = norm.schema._normalizers_config["json"][ + "config" + ]["propagation"] prop_config["root"]["timestamp"] = "_partition_ts" - row = { - "_dlt_id": "###", - "timestamp": 12918291.1212, - "lvl1": [1, 2, 3, [4, 5, 6]] - } - normalized_rows = list(norm._normalize_row(row, {}, ("table", ))) + row = {"_dlt_id": "###", "timestamp": 12918291.1212, "lvl1": [1, 2, 3, [4, 5, 6]]} + normalized_rows = list(norm._normalize_row(row, {}, ("table",))) # _partition_ts == timestamp on all child tables non_root = [r for r in normalized_rows if r[0][1] is not None] assert all(r[1]["_partition_ts"] == 12918291.1212 for r in non_root) @@ -532,7 +528,7 @@ def test_removes_normalized_list(norm: RelationalNormalizer) -> None: # after normalizing the list that got normalized into child table must be deleted row = {"comp": [{"_timestamp": "a"}]} # get iterator - normalized_rows_i = norm._normalize_row(row, {}, ("table", )) + normalized_rows_i = norm._normalize_row(row, 
{}, ("table",)) # yield just one item root_row = next(normalized_rows_i) # root_row = next(r for r in normalized_rows if r[0][1] is None) @@ -543,17 +539,13 @@ def test_preserves_complex_types_list(norm: RelationalNormalizer) -> None: # the exception to test_removes_normalized_list # complex types should be left as they are # add table with complex column - norm.schema.update_schema(new_table("event_slot", - columns = [{ - "name": "value", - "data_type": "complex", - "nullable": "true" - }]) + norm.schema.update_schema( + new_table( + "event_slot", columns=[{"name": "value", "data_type": "complex", "nullable": "true"}] + ) ) - row = { - "value": ["from", {"complex": True}] - } - normalized_rows = list(norm._normalize_row(row, {}, ("event_slot", ))) + row = {"value": ["from", {"complex": True}]} + normalized_rows = list(norm._normalize_row(row, {}, ("event_slot",))) # make sure only 1 row is emitted, the list is not normalized assert len(normalized_rows) == 1 # value is kept in root row -> market as complex @@ -561,10 +553,8 @@ def test_preserves_complex_types_list(norm: RelationalNormalizer) -> None: assert root_row[1]["value"] == row["value"] # same should work for a list - row = { - "value": ["from", ["complex", True]] - } - normalized_rows = list(norm._normalize_row(row, {}, ("event_slot", ))) + row = {"value": ["from", ["complex", True]]} + normalized_rows = list(norm._normalize_row(row, {}, ("event_slot",))) # make sure only 1 row is emitted, the list is not normalized assert len(normalized_rows) == 1 # value is kept in root row -> market as complex @@ -580,7 +570,10 @@ def test_wrap_in_dict(norm: RelationalNormalizer) -> None: # wrap a list rows = list(norm.schema.normalize_data_item([1, 2, 3, 4, "A"], "load_id", "listex")) assert len(rows) == 6 - assert rows[0][0] == ("listex", None,) + assert rows[0][0] == ( + "listex", + None, + ) assert rows[1][0] == ("listex__value", "listex") assert rows[-1][1]["value"] == "A" @@ -590,15 +583,19 @@ def test_complex_types_for_recursion_level(norm: RelationalNormalizer) -> None: # if max recursion depth is set, nested elements will be kept as complex row = { "_dlt_id": "row_id", - "f": [{ - "l": ["a"], # , "b", "c" - "v": 120, - "lo": [{"e": {"v": 1}}] # , {"e": {"v": 2}}, {"e":{"v":3 }} - }] + "f": [ + { + "l": ["a"], # , "b", "c" + "v": 120, + "lo": [{"e": {"v": 1}}], # , {"e": {"v": 2}}, {"e":{"v":3 }} + } + ], } n_rows_nl = list(norm.schema.normalize_data_item(row, "load_id", "default")) # all nested elements were yielded - assert ["default", "default__f", "default__f__l", "default__f__lo"] == [r[0][0] for r in n_rows_nl] + assert ["default", "default__f", "default__f__l", "default__f__lo"] == [ + r[0][0] for r in n_rows_nl + ] # set max nesting to 0 set_max_nesting(norm, 0) @@ -643,12 +640,10 @@ def test_extract_with_table_name_meta() -> None: "flags": 0, "parent_id": None, "guild_id": "815421435900198962", - "permission_overwrites": [] + "permission_overwrites": [], } # force table name - rows = list( - create_schema_with_name("discord").normalize_data_item(row, "load_id", "channel") - ) + rows = list(create_schema_with_name("discord").normalize_data_item(row, "load_id", "channel")) # table is channel assert rows[0][0][0] == "channel" normalized_row = rows[0][1] @@ -675,13 +670,7 @@ def test_parse_with_primary_key() -> None: schema._compile_settings() add_dlt_root_id_propagation(schema.data_item_normalizer) - row = { - "id": "817949077341208606", - "w_id":[{ - "id": 9128918293891111, - "wo_id": [1, 2, 3] - }] - } + row = {"id": 
"817949077341208606", "w_id": [{"id": 9128918293891111, "wo_id": [1, 2, 3]}]} rows = list(schema.normalize_data_item(row, "load_id", "discord")) # get root root = next(t[1] for t in rows if t[0][0] == "discord") @@ -699,11 +688,15 @@ def test_parse_with_primary_key() -> None: assert "_dlt_root_id" in el_w_id # this must have deterministic child key - f_wo_id = next(t[1] for t in rows if t[0][0] == "discord__w_id__wo_id" and t[1]["_dlt_list_idx"] == 2) + f_wo_id = next( + t[1] for t in rows if t[0][0] == "discord__w_id__wo_id" and t[1]["_dlt_list_idx"] == 2 + ) assert f_wo_id["value"] == 3 assert f_wo_id["_dlt_root_id"] != digest128("817949077341208606", DLT_ID_LENGTH_BYTES) assert f_wo_id["_dlt_parent_id"] != digest128("9128918293891111", DLT_ID_LENGTH_BYTES) - assert f_wo_id["_dlt_id"] == RelationalNormalizer._get_child_row_hash(f_wo_id["_dlt_parent_id"], "discord__w_id__wo_id", 2) + assert f_wo_id["_dlt_id"] == RelationalNormalizer._get_child_row_hash( + f_wo_id["_dlt_parent_id"], "discord__w_id__wo_id", 2 + ) def test_keeps_none_values() -> None: @@ -723,16 +716,10 @@ def test_normalize_and_shorten_deterministically() -> None: data = { "short>ident:1": { - "short>ident:2": { - "short>ident:3": "a" - }, - }, - "LIST+ident:1": { - "LIST+ident:2": { - "LIST+ident:3": [1] - } + "short>ident:2": {"short>ident:3": "a"}, }, - "long+long:SO+LONG:_>16": True + "LIST+ident:1": {"LIST+ident:2": {"LIST+ident:3": [1]}}, + "long+long:SO+LONG:_>16": True, } rows = list(schema.normalize_data_item(data, "1762162.1212", "s")) # all identifiers are 16 chars or shorter @@ -746,14 +733,20 @@ def test_normalize_and_shorten_deterministically() -> None: root_data = rows[0][1] root_data_keys = list(root_data.keys()) # "short:ident:2": "a" will be flattened into root - tag = NamingConvention._compute_tag("short_ident_1__short_ident_2__short_ident_3", NamingConvention._DEFAULT_COLLISION_PROB) + tag = NamingConvention._compute_tag( + "short_ident_1__short_ident_2__short_ident_3", NamingConvention._DEFAULT_COLLISION_PROB + ) assert tag in root_data_keys[0] # long:SO+LONG:_>16 shortened on normalized name - tag = NamingConvention._compute_tag("long+long:SO+LONG:_>16", NamingConvention._DEFAULT_COLLISION_PROB) + tag = NamingConvention._compute_tag( + "long+long:SO+LONG:_>16", NamingConvention._DEFAULT_COLLISION_PROB + ) assert tag in root_data_keys[1] # table name in second row table_name = rows[1][0][0] - tag = NamingConvention._compute_tag("s__lis_txident_1__lis_txident_2__lis_txident_3", NamingConvention._DEFAULT_COLLISION_PROB) + tag = NamingConvention._compute_tag( + "s__lis_txident_1__lis_txident_2__lis_txident_3", NamingConvention._DEFAULT_COLLISION_PROB + ) assert tag in table_name @@ -774,21 +767,12 @@ def test_normalize_empty_keys() -> None: def set_max_nesting(norm: RelationalNormalizer, max_nesting: int) -> None: - RelationalNormalizer.update_normalizer_config(norm.schema, - { - "max_nesting": max_nesting - } - ) + RelationalNormalizer.update_normalizer_config(norm.schema, {"max_nesting": max_nesting}) norm._reset() def add_dlt_root_id_propagation(norm: RelationalNormalizer) -> None: - RelationalNormalizer.update_normalizer_config(norm.schema, { - "propagation": { - "root": { - "_dlt_id": "_dlt_root_id" - }, - "tables": {} - } - }) + RelationalNormalizer.update_normalizer_config( + norm.schema, {"propagation": {"root": {"_dlt_id": "_dlt_root_id"}, "tables": {}}} + ) norm._reset() diff --git a/tests/common/normalizers/test_naming.py b/tests/common/normalizers/test_naming.py index 02ff6e3c38..cdf9ef0888 
100644 --- a/tests/common/normalizers/test_naming.py +++ b/tests/common/normalizers/test_naming.py @@ -1,25 +1,31 @@ -import pytest import string from typing import List, Type +import pytest + from dlt.common.normalizers.naming import NamingConvention -from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention from dlt.common.normalizers.naming.direct import NamingConvention as DirectNamingConvention +from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention from dlt.common.typing import DictStrStr from dlt.common.utils import uniq_id - LONG_PATH = "prospects_external_data__data365_member__member__feed_activities_created_post__items__comments__items__comments__items__author_details__educations" DENSE_PATH = "__".join(string.ascii_lowercase) LONG_IDENT = 10 * string.printable IDENT_20_CHARS = "she played cello well" RAW_IDENT = ".\n'played CELLO🚧_" RAW_IDENT_W_SPACES = f" {RAW_IDENT} \t\n" -RAW_IDENT_2 = "123.\"\rhello😄!" +RAW_IDENT_2 = '123."\rhello😄!' RAW_IDENT_2_W_SPACES = f"\n {RAW_IDENT_2} \t " RAW_PATH = [RAW_IDENT, RAW_IDENT_2_W_SPACES, RAW_IDENT_2, RAW_IDENT_2_W_SPACES] EMPTY_IDENT = " \t\n " -RAW_PATH_WITH_EMPTY_IDENT = [RAW_IDENT, RAW_IDENT_2_W_SPACES, EMPTY_IDENT, RAW_IDENT_2, RAW_IDENT_2_W_SPACES] +RAW_PATH_WITH_EMPTY_IDENT = [ + RAW_IDENT, + RAW_IDENT_2_W_SPACES, + EMPTY_IDENT, + RAW_IDENT_2, + RAW_IDENT_2_W_SPACES, +] def test_tag_collisions() -> None: @@ -29,52 +35,61 @@ def test_tag_collisions() -> None: generations = 100000 collisions = 0 for _ in range(0, generations): - tag = NamingConvention._compute_tag(uniq_id(32), collision_prob=NamingConvention._DEFAULT_COLLISION_PROB) + tag = NamingConvention._compute_tag( + uniq_id(32), collision_prob=NamingConvention._DEFAULT_COLLISION_PROB + ) if tag in tags: collisions += 1 else: tags[tag] = tag - assert collisions/generations < 0.001 + assert collisions / generations < 0.001 def test_tag_generation() -> None: # is content hash content = 20 * LONG_PATH - content_tag = NamingConvention._compute_tag(content, collision_prob=NamingConvention._DEFAULT_COLLISION_PROB) + content_tag = NamingConvention._compute_tag( + content, collision_prob=NamingConvention._DEFAULT_COLLISION_PROB + ) # no randomness for _ in range(0, 20): - tag = NamingConvention._compute_tag(content, collision_prob=NamingConvention._DEFAULT_COLLISION_PROB) + tag = NamingConvention._compute_tag( + content, collision_prob=NamingConvention._DEFAULT_COLLISION_PROB + ) assert content_tag == tag fixture = [ - ('3f17271231504b8cf65690bcdc379df8a3b8aabe12efe1ea82848ec5f497cb69', 'gds0iw'), - ('58e5c351b53ffe1233e0656a532a721ae1d2ac7af71b6cfec8ceb64c63b10721', 'uyboiq'), - ('e3f34629839cedcabba95354e48a78dc80b0cd35c02ddfbbf20196ba7f968866', '51wdcg'), - ('f0f22b8e8c58389a6c21dbcc1e261ee0354704e24996a0ec541276f58d1f2f52', 'bpm7ca'), - ('0d0de95c7c12ceee919d28d22c970285d80a36dea4fe32dbdd667a888ae6d47f', 'doqcuq'), - ('4973509ea648ddfbaf6c50e1fef33c3b0a3d1c1a82dff543a8255e60b6572567', 'cl7rpq'), - ('877c89f0dcbd24b8c3f787624ddca09deb6a44e4a72f12527209d78e4d9ed247', 'xrnycg'), - ('064df58cd3a51c50dbf30e975e63961a501212ff8e8ca544ab396727f4b8a367', 'kgiizq'), - ('c8f7da1b5c44c1ca10da67c1514c4cf365e4d5912685b25a39206d5c8c1966a1', 'dj9zqq'), - ('222d42333592ea87823fd2e7868d59fb0aded20603f433319691918299513cb6', 'futp4w'), - ('757d64eb242a91b494ec9e2661a7946410d68144d33860d6f4154092d65d5009', 'wetlpg'), - ('3c7348d43478292b4c4e0689d41a536fc8ccabdbd9fb9d0dfbe757a83d34cebe', 'avxagg'), - 
('6896fac1546c201d4dc91d2c51bdcd9c820fe92fd0555947e59fdc89ca6f045d', 'wbaj3w'), - ('b4def322a4487dd90fcc4abd2f1efde0cdce81d8e0a580fd1897203ab4ebcebe', 'whojmw'), - ('07d974124b92adafc90473a3968ceb5e8329d815e0e48260473d70a781adb8ae', 'aiqcea'), - ('c67183a762e379290652cc26a786b21eff347643b1cc9012138f460901ce5d53', 'zfztpg'), - ('430976db5adef67d0009aa3cd9a2daca106829b36a7232732c5d694e7197c6d1', 'evr7rq'), - ('c1c8c0ff6933fa4e23fab5605139124b2c6cda0150a412daaea274818ee46e35', 'er0nxq'), - ('0060c538b6ce02b8d8e2c85b4e2810c58b846f4096ed7ab871fc092c45ac09d9', 'zh9xgg'), - ('4d4b99ff5d2a3d5cd076782c9cd088cd85d5c789d7de6bdc19c1d206b687d485', '2vvr5a') + ("3f17271231504b8cf65690bcdc379df8a3b8aabe12efe1ea82848ec5f497cb69", "gds0iw"), + ("58e5c351b53ffe1233e0656a532a721ae1d2ac7af71b6cfec8ceb64c63b10721", "uyboiq"), + ("e3f34629839cedcabba95354e48a78dc80b0cd35c02ddfbbf20196ba7f968866", "51wdcg"), + ("f0f22b8e8c58389a6c21dbcc1e261ee0354704e24996a0ec541276f58d1f2f52", "bpm7ca"), + ("0d0de95c7c12ceee919d28d22c970285d80a36dea4fe32dbdd667a888ae6d47f", "doqcuq"), + ("4973509ea648ddfbaf6c50e1fef33c3b0a3d1c1a82dff543a8255e60b6572567", "cl7rpq"), + ("877c89f0dcbd24b8c3f787624ddca09deb6a44e4a72f12527209d78e4d9ed247", "xrnycg"), + ("064df58cd3a51c50dbf30e975e63961a501212ff8e8ca544ab396727f4b8a367", "kgiizq"), + ("c8f7da1b5c44c1ca10da67c1514c4cf365e4d5912685b25a39206d5c8c1966a1", "dj9zqq"), + ("222d42333592ea87823fd2e7868d59fb0aded20603f433319691918299513cb6", "futp4w"), + ("757d64eb242a91b494ec9e2661a7946410d68144d33860d6f4154092d65d5009", "wetlpg"), + ("3c7348d43478292b4c4e0689d41a536fc8ccabdbd9fb9d0dfbe757a83d34cebe", "avxagg"), + ("6896fac1546c201d4dc91d2c51bdcd9c820fe92fd0555947e59fdc89ca6f045d", "wbaj3w"), + ("b4def322a4487dd90fcc4abd2f1efde0cdce81d8e0a580fd1897203ab4ebcebe", "whojmw"), + ("07d974124b92adafc90473a3968ceb5e8329d815e0e48260473d70a781adb8ae", "aiqcea"), + ("c67183a762e379290652cc26a786b21eff347643b1cc9012138f460901ce5d53", "zfztpg"), + ("430976db5adef67d0009aa3cd9a2daca106829b36a7232732c5d694e7197c6d1", "evr7rq"), + ("c1c8c0ff6933fa4e23fab5605139124b2c6cda0150a412daaea274818ee46e35", "er0nxq"), + ("0060c538b6ce02b8d8e2c85b4e2810c58b846f4096ed7ab871fc092c45ac09d9", "zh9xgg"), + ("4d4b99ff5d2a3d5cd076782c9cd088cd85d5c789d7de6bdc19c1d206b687d485", "2vvr5a"), ] for content, expected_tag in fixture: - tag = NamingConvention._compute_tag(content, collision_prob=NamingConvention._DEFAULT_COLLISION_PROB) + tag = NamingConvention._compute_tag( + content, collision_prob=NamingConvention._DEFAULT_COLLISION_PROB + ) assert len(tag) == 6 assert tag == expected_tag # print(f"('{content}', '{tag}'),") + def test_tag_placement() -> None: # tags are placed in the middle of string and that must happen deterministically tag = "123456" @@ -99,20 +114,26 @@ def test_tag_placement() -> None: def test_shorten_identifier() -> None: # no limit - long_ident = 8*LONG_PATH + long_ident = 8 * LONG_PATH assert NamingConvention.shorten_identifier(long_ident, long_ident, None) == long_ident # within limit assert NamingConvention.shorten_identifier("012345678", "xxx012345678xxx", 10) == "012345678" - assert NamingConvention.shorten_identifier("0123456789", "xxx012345678xx?", 10) == "0123456789" # max_length + assert ( + NamingConvention.shorten_identifier("0123456789", "xxx012345678xx?", 10) == "0123456789" + ) # max_length # tag based on original string placed in the middle - tag = NamingConvention._compute_tag(IDENT_20_CHARS, collision_prob=NamingConvention._DEFAULT_COLLISION_PROB) + tag = NamingConvention._compute_tag( + 
IDENT_20_CHARS, collision_prob=NamingConvention._DEFAULT_COLLISION_PROB + ) norm_ident = NamingConvention.shorten_identifier(IDENT_20_CHARS, IDENT_20_CHARS, 20) assert tag in norm_ident assert len(norm_ident) == 20 assert norm_ident == "she plauanpualo well" # the tag must be based on raw string, not normalized string, one test case with spaces for raw_content in [uniq_id(), f" {uniq_id()} "]: - tag = NamingConvention._compute_tag(raw_content, collision_prob=NamingConvention._DEFAULT_COLLISION_PROB) + tag = NamingConvention._compute_tag( + raw_content, collision_prob=NamingConvention._DEFAULT_COLLISION_PROB + ) norm_ident = NamingConvention.shorten_identifier(IDENT_20_CHARS, raw_content, 20) assert tag in norm_ident assert len(norm_ident) == 20 @@ -135,7 +156,9 @@ def test_normalize_with_shorten_identifier(convention: Type[NamingConvention]) - # force to shorten naming = convention(len(RAW_IDENT) // 2) # tag expected on stripped RAW_IDENT - tag = NamingConvention._compute_tag(RAW_IDENT, collision_prob=NamingConvention._DEFAULT_COLLISION_PROB) + tag = NamingConvention._compute_tag( + RAW_IDENT, collision_prob=NamingConvention._DEFAULT_COLLISION_PROB + ) # spaces are stripped assert naming.normalize_identifier(RAW_IDENT) == naming.normalize_identifier(RAW_IDENT_W_SPACES) assert tag in naming.normalize_identifier(RAW_IDENT) @@ -192,7 +215,11 @@ def test_normalize_path(convention: Type[NamingConvention]) -> None: norm_path_str = naming.normalize_path(raw_path_str) assert len(naming.break_path(norm_path_str)) == len(RAW_PATH) # double norm path does not change anything - assert naming.normalize_path(raw_path_str) == naming.normalize_path(norm_path_str) == naming.normalize_path(naming.normalize_path(norm_path_str)) + assert ( + naming.normalize_path(raw_path_str) + == naming.normalize_path(norm_path_str) + == naming.normalize_path(naming.normalize_path(norm_path_str)) + ) # empty element in path is ignored assert naming.make_path(*RAW_PATH_WITH_EMPTY_IDENT) == raw_path_str assert naming.normalize_path(raw_path_str) == norm_path_str @@ -200,12 +227,18 @@ def test_normalize_path(convention: Type[NamingConvention]) -> None: # preserve idents but shorten path naming = convention(len(RAW_IDENT) * 2) # give enough max length # tag computed from raw path - tag = NamingConvention._compute_tag(raw_path_str, collision_prob=NamingConvention._DEFAULT_COLLISION_PROB) + tag = NamingConvention._compute_tag( + raw_path_str, collision_prob=NamingConvention._DEFAULT_COLLISION_PROB + ) tagged_raw_path_str = naming.normalize_path(raw_path_str) # contains tag assert tag in tagged_raw_path_str # idempotent - assert tagged_raw_path_str == naming.normalize_path(tagged_raw_path_str) == naming.normalize_path(naming.normalize_path(tagged_raw_path_str)) + assert ( + tagged_raw_path_str + == naming.normalize_path(tagged_raw_path_str) + == naming.normalize_path(naming.normalize_path(tagged_raw_path_str)) + ) assert tagged_raw_path_str == naming.make_path(*naming.break_path(tagged_raw_path_str)) # also cut idents diff --git a/tests/common/normalizers/test_naming_duck_case.py b/tests/common/normalizers/test_naming_duck_case.py index b50ee64581..155b0e27b5 100644 --- a/tests/common/normalizers/test_naming_duck_case.py +++ b/tests/common/normalizers/test_naming_duck_case.py @@ -15,12 +15,20 @@ def test_normalize_identifier(naming_unlimited: NamingConvention) -> None: def test_alphabet_reduction(naming_unlimited: NamingConvention) -> None: - assert naming_unlimited.normalize_identifier(NamingConvention._REDUCE_ALPHABET[0]) == 
NamingConvention._REDUCE_ALPHABET[1] + assert ( + naming_unlimited.normalize_identifier(NamingConvention._REDUCE_ALPHABET[0]) + == NamingConvention._REDUCE_ALPHABET[1] + ) def test_duck_snake_case_compat(naming_unlimited: NamingConvention) -> None: snake_unlimited = SnakeNamingConvention() # same reduction duck -> snake - assert snake_unlimited.normalize_identifier(NamingConvention._REDUCE_ALPHABET[0]) == NamingConvention._REDUCE_ALPHABET[1] + assert ( + snake_unlimited.normalize_identifier(NamingConvention._REDUCE_ALPHABET[0]) + == NamingConvention._REDUCE_ALPHABET[1] + ) # but there are differences in the reduction - assert naming_unlimited.normalize_identifier(SnakeNamingConvention._REDUCE_ALPHABET[0]) != snake_unlimited.normalize_identifier(SnakeNamingConvention._REDUCE_ALPHABET[0]) + assert naming_unlimited.normalize_identifier( + SnakeNamingConvention._REDUCE_ALPHABET[0] + ) != snake_unlimited.normalize_identifier(SnakeNamingConvention._REDUCE_ALPHABET[0]) diff --git a/tests/common/normalizers/test_naming_snake_case.py b/tests/common/normalizers/test_naming_snake_case.py index 976c242930..d15571f2fb 100644 --- a/tests/common/normalizers/test_naming_snake_case.py +++ b/tests/common/normalizers/test_naming_snake_case.py @@ -38,7 +38,10 @@ def test_normalize_identifier(naming_unlimited: NamingConvention) -> None: def test_alphabet_reduction(naming_unlimited: NamingConvention) -> None: - assert naming_unlimited.normalize_identifier(NamingConvention._REDUCE_ALPHABET[0]) == NamingConvention._REDUCE_ALPHABET[1] + assert ( + naming_unlimited.normalize_identifier(NamingConvention._REDUCE_ALPHABET[0]) + == NamingConvention._REDUCE_ALPHABET[1] + ) def test_normalize_path(naming_unlimited: NamingConvention) -> None: @@ -74,6 +77,7 @@ def test_normalize_make_path(naming_unlimited: NamingConvention) -> None: def test_normalizes_underscores(naming_unlimited: NamingConvention) -> None: - assert naming_unlimited.normalize_identifier("event__value_value2____") == "event_value_value2xxxx" + assert ( + naming_unlimited.normalize_identifier("event__value_value2____") == "event_value_value2xxxx" + ) assert naming_unlimited.normalize_path("e_vent__value_value2___") == "e_vent__value_value2__x" - diff --git a/tests/common/reflection/test_reflect_spec.py b/tests/common/reflection/test_reflect_spec.py index 0d9bc28cc6..7f506a5ce4 100644 --- a/tests/common/reflection/test_reflect_spec.py +++ b/tests/common/reflection/test_reflect_spec.py @@ -3,24 +3,34 @@ import dlt from dlt.common import Decimal -from dlt.common.typing import TSecretValue, is_optional_type from dlt.common.configuration.inject import get_fun_spec, with_config -from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration, ConnectionStringCredentials -from dlt.common.reflection.spec import spec_from_signature, _get_spec_name_from_f +from dlt.common.configuration.specs import ( + BaseConfiguration, + ConnectionStringCredentials, + RunConfiguration, +) +from dlt.common.reflection.spec import _get_spec_name_from_f, spec_from_signature from dlt.common.reflection.utils import get_func_def_node, get_literal_defaults - +from dlt.common.typing import TSecretValue, is_optional_type _DECIMAL_DEFAULT = Decimal("0.01") _SECRET_DEFAULT = TSecretValue("PASS") _CONFIG_DEFAULT = RunConfiguration() -_CREDENTIALS_DEFAULT = ConnectionStringCredentials("postgresql://loader:loader@localhost:5432/dlt_data") +_CREDENTIALS_DEFAULT = ConnectionStringCredentials( + "postgresql://loader:loader@localhost:5432/dlt_data" +) def 
test_synthesize_spec_from_sig() -> None: - # spec from typed signature without defaults - def f_typed(p1: str = None, p2: Decimal = None, p3: Any = None, p4: Optional[RunConfiguration] = None, p5: TSecretValue = dlt.secrets.value) -> None: + def f_typed( + p1: str = None, + p2: Decimal = None, + p3: Any = None, + p4: Optional[RunConfiguration] = None, + p5: TSecretValue = dlt.secrets.value, + ) -> None: pass SPEC = spec_from_signature(f_typed, inspect.signature(f_typed)) @@ -30,11 +40,23 @@ def f_typed(p1: str = None, p2: Decimal = None, p3: Any = None, p4: Optional[Run assert SPEC.p4 is None assert SPEC.p5 is None fields = SPEC.get_resolvable_fields() - assert fields == {"p1": Optional[str], "p2": Optional[Decimal], "p3": Optional[Any], "p4": Optional[RunConfiguration], "p5": TSecretValue} + assert fields == { + "p1": Optional[str], + "p2": Optional[Decimal], + "p3": Optional[Any], + "p4": Optional[RunConfiguration], + "p5": TSecretValue, + } # spec from typed signatures with defaults - def f_typed_default(t_p1: str = "str", t_p2: Decimal = _DECIMAL_DEFAULT, t_p3: Any = _SECRET_DEFAULT, t_p4: RunConfiguration = _CONFIG_DEFAULT, t_p5: str = None) -> None: + def f_typed_default( + t_p1: str = "str", + t_p2: Decimal = _DECIMAL_DEFAULT, + t_p3: Any = _SECRET_DEFAULT, + t_p4: RunConfiguration = _CONFIG_DEFAULT, + t_p5: str = None, + ) -> None: pass SPEC = spec_from_signature(f_typed_default, inspect.signature(f_typed_default)) @@ -46,11 +68,17 @@ def f_typed_default(t_p1: str = "str", t_p2: Decimal = _DECIMAL_DEFAULT, t_p3: A fields = SPEC().get_resolvable_fields() # Any will not assume TSecretValue type because at runtime it's a str # setting default as None will convert type into optional (t_p5) - assert fields == {"t_p1": str, "t_p2": Decimal, "t_p3": str, "t_p4": RunConfiguration, "t_p5": Optional[str]} + assert fields == { + "t_p1": str, + "t_p2": Decimal, + "t_p3": str, + "t_p4": RunConfiguration, + "t_p5": Optional[str], + } # spec from untyped signature - def f_untyped(untyped_p1 = None, untyped_p2 = dlt.config.value) -> None: + def f_untyped(untyped_p1=None, untyped_p2=dlt.config.value) -> None: pass SPEC = spec_from_signature(f_untyped, inspect.signature(f_untyped)) @@ -61,11 +89,14 @@ def f_untyped(untyped_p1 = None, untyped_p2 = dlt.config.value) -> None: # spec types derived from defaults - - def f_untyped_default(untyped_p1 = "str", untyped_p2 = _DECIMAL_DEFAULT, untyped_p3 = _CREDENTIALS_DEFAULT, untyped_p4 = None) -> None: + def f_untyped_default( + untyped_p1="str", + untyped_p2=_DECIMAL_DEFAULT, + untyped_p3=_CREDENTIALS_DEFAULT, + untyped_p4=None, + ) -> None: pass - SPEC = spec_from_signature(f_untyped_default, inspect.signature(f_untyped_default)) assert SPEC.untyped_p1 == "str" assert SPEC.untyped_p2 == _DECIMAL_DEFAULT @@ -73,11 +104,23 @@ def f_untyped_default(untyped_p1 = "str", untyped_p2 = _DECIMAL_DEFAULT, untyped assert SPEC.untyped_p4 is None fields = SPEC.get_resolvable_fields() # untyped_p4 converted to Optional[Any] - assert fields == {"untyped_p1": str, "untyped_p2": Decimal, "untyped_p3": ConnectionStringCredentials, "untyped_p4": Optional[Any]} + assert fields == { + "untyped_p1": str, + "untyped_p2": Decimal, + "untyped_p3": ConnectionStringCredentials, + "untyped_p4": Optional[Any], + } # spec from signatures containing positional only and keywords only args - def f_pos_kw_only(pos_only_1=dlt.config.value, pos_only_2: str = "default", /, *, kw_only_1=None, kw_only_2: int = 2) -> None: + def f_pos_kw_only( + pos_only_1=dlt.config.value, + pos_only_2: str = 
"default", + /, + *, + kw_only_1=None, + kw_only_2: int = 2, + ) -> None: pass SPEC = spec_from_signature(f_pos_kw_only, inspect.signature(f_pos_kw_only)) @@ -86,12 +129,19 @@ def f_pos_kw_only(pos_only_1=dlt.config.value, pos_only_2: str = "default", /, * assert SPEC.kw_only_1 is None assert SPEC.kw_only_2 == 2 fields = SPEC.get_resolvable_fields() - assert fields == {"pos_only_1": Any, "pos_only_2": str, "kw_only_1": Optional[Any], "kw_only_2": int} + assert fields == { + "pos_only_1": Any, + "pos_only_2": str, + "kw_only_1": Optional[Any], + "kw_only_2": int, + } # skip arguments with defaults # deregister spec to disable cache del globals()[SPEC.__name__] - SPEC = spec_from_signature(f_pos_kw_only, inspect.signature(f_pos_kw_only), include_defaults=False) + SPEC = spec_from_signature( + f_pos_kw_only, inspect.signature(f_pos_kw_only), include_defaults=False + ) assert not hasattr(SPEC, "kw_only_1") assert not hasattr(SPEC, "kw_only_2") assert not hasattr(SPEC, "pos_only_2") @@ -111,7 +161,6 @@ def f_variadic(var_1: str = "A", *args, kw_var_1: str, **kwargs) -> None: def test_spec_none_when_no_fields() -> None: - def f_default_only(arg1, arg2=None): pass @@ -119,7 +168,9 @@ def f_default_only(arg1, arg2=None): assert SPEC is not None del globals()[SPEC.__name__] - SPEC = spec_from_signature(f_default_only, inspect.signature(f_default_only), include_defaults=False) + SPEC = spec_from_signature( + f_default_only, inspect.signature(f_default_only), include_defaults=False + ) assert SPEC is None def f_no_spec(arg1): @@ -129,20 +180,39 @@ def f_no_spec(arg1): assert SPEC is None -def f_top_kw_defaults_args(arg1, arg2 = "top", arg3 = dlt.config.value, *args, kw1, kw_lit = "12131", kw_secret_val = dlt.secrets.value, **kwargs): +def f_top_kw_defaults_args( + arg1, + arg2="top", + arg3=dlt.config.value, + *args, + kw1, + kw_lit="12131", + kw_secret_val=dlt.secrets.value, + **kwargs, +): pass def test_argument_have_dlt_config_defaults() -> None: - def f_defaults( - req_val, config_val = dlt.config.value, secret_val = dlt.secrets.value, /, - pos_cf = None, pos_cf_val = dlt.config.value, pos_secret_val = dlt.secrets.value, *, - kw_val = None, kw_cf_val = dlt.config.value, kw_secret_val = dlt.secrets.value): + req_val, + config_val=dlt.config.value, + secret_val=dlt.secrets.value, + /, + pos_cf=None, + pos_cf_val=dlt.config.value, + pos_secret_val=dlt.secrets.value, + *, + kw_val=None, + kw_cf_val=dlt.config.value, + kw_secret_val=dlt.secrets.value, + ): pass @with_config - def f_kw_defaults(*, kw1 = dlt.config.value, kw_lit = "12131", kw_secret_val = dlt.secrets.value, **kwargs): + def f_kw_defaults( + *, kw1=dlt.config.value, kw_lit="12131", kw_secret_val=dlt.secrets.value, **kwargs + ): pass # do not delete those spaces @@ -151,18 +221,42 @@ def f_kw_defaults(*, kw1 = dlt.config.value, kw_lit = "12131", kw_secret_val = d @with_config # they are part of the test - def f_kw_defaults_args(arg1, arg2 = 2, arg3 = dlt.config.value, *args, kw1, kw_lit = "12131", kw_secret_val = dlt.secrets.value, **kwargs): + def f_kw_defaults_args( + arg1, + arg2=2, + arg3=dlt.config.value, + *args, + kw1, + kw_lit="12131", + kw_secret_val=dlt.secrets.value, + **kwargs, + ): pass - node = get_func_def_node(f_defaults) assert node.name == "f_defaults" literal_defaults = get_literal_defaults(node) - assert literal_defaults == {'kw_secret_val': 'dlt.secrets.value', 'kw_cf_val': 'dlt.config.value', 'kw_val': 'None', 'pos_secret_val': 'dlt.secrets.value', 'pos_cf_val': 'dlt.config.value', 'pos_cf': 'None', 'secret_val': 
'dlt.secrets.value', 'config_val': 'dlt.config.value'} + assert literal_defaults == { + "kw_secret_val": "dlt.secrets.value", + "kw_cf_val": "dlt.config.value", + "kw_val": "None", + "pos_secret_val": "dlt.secrets.value", + "pos_cf_val": "dlt.config.value", + "pos_cf": "None", + "secret_val": "dlt.secrets.value", + "config_val": "dlt.config.value", + } SPEC = spec_from_signature(f_defaults, inspect.signature(f_defaults)) fields = SPEC.get_resolvable_fields() # fields market with dlt config are not optional, same for required fields - for arg in ["config_val", "secret_val", "pos_cf_val", "pos_secret_val", "kw_cf_val", "kw_secret_val"]: + for arg in [ + "config_val", + "secret_val", + "pos_cf_val", + "pos_secret_val", + "kw_cf_val", + "kw_secret_val", + ]: assert not is_optional_type(fields[arg]) for arg in ["pos_cf", "kw_val"]: assert is_optional_type(fields[arg]) @@ -172,7 +266,11 @@ def f_kw_defaults_args(arg1, arg2 = 2, arg3 = dlt.config.value, *args, kw1, kw_l node = get_func_def_node(f_kw_defaults) assert node.name == "f_kw_defaults" literal_defaults = get_literal_defaults(node) - assert literal_defaults == {'kw_secret_val': 'dlt.secrets.value', 'kw_lit': "'12131'", "kw1": "dlt.config.value"} + assert literal_defaults == { + "kw_secret_val": "dlt.secrets.value", + "kw_lit": "'12131'", + "kw1": "dlt.config.value", + } SPEC = spec_from_signature(f_kw_defaults, inspect.signature(f_kw_defaults)) fields = SPEC.get_resolvable_fields() assert not is_optional_type(fields["kw_lit"]) @@ -183,9 +281,19 @@ def f_kw_defaults_args(arg1, arg2 = 2, arg3 = dlt.config.value, *args, kw1, kw_l assert node.name == "f_kw_defaults_args" literal_defaults = get_literal_defaults(node) # print(literal_defaults) - assert literal_defaults == {'kw_secret_val': 'dlt.secrets.value', 'kw_lit': "'12131'", 'arg3': 'dlt.config.value', 'arg2': '2'} + assert literal_defaults == { + "kw_secret_val": "dlt.secrets.value", + "kw_lit": "'12131'", + "arg3": "dlt.config.value", + "arg2": "2", + } node = get_func_def_node(f_top_kw_defaults_args) assert node.name == "f_top_kw_defaults_args" literal_defaults = get_literal_defaults(node) - assert literal_defaults == {'kw_secret_val': 'dlt.secrets.value', 'kw_lit': "'12131'", 'arg3': 'dlt.config.value', 'arg2': "'top'"} + assert literal_defaults == { + "kw_secret_val": "dlt.secrets.value", + "kw_lit": "'12131'", + "arg3": "dlt.config.value", + "arg2": "'top'", + } diff --git a/tests/common/runners/test_pipes.py b/tests/common/runners/test_pipes.py index ec2753f7b9..933c8ee440 100644 --- a/tests/common/runners/test_pipes.py +++ b/tests/common/runners/test_pipes.py @@ -1,13 +1,13 @@ -from subprocess import CalledProcessError import tempfile +from subprocess import CalledProcessError from typing import Any, Iterator, NamedTuple + import pytest -from dlt.common.exceptions import UnsupportedProcessStartMethodException +from dlt.common.exceptions import UnsupportedProcessStartMethodException from dlt.common.runners import TRunMetrics, Venv from dlt.common.runners.stdout import iter_stdout, iter_stdout_with_result -from dlt.common.runners.synth_pickle import encode_obj, decode_obj, decode_last_obj - +from dlt.common.runners.synth_pickle import decode_last_obj, decode_obj, encode_obj from dlt.common.utils import digest128b @@ -27,6 +27,7 @@ class _TestPickler(NamedTuple): # self.s1 = s1 # self.s2 = s2 + class _TestClassUnkField: pass # def __init__(self, s1: _TestPicklex, s2: str) -> None: @@ -55,19 +56,25 @@ def test_pickle_encoder_none() -> None: def test_synth_pickler_unknown_types() -> 
None: # synth unknown tuple - obj = decode_obj("LfDoYo19lgUOtTn0Ib6JgASVQAAAAAAAAACMH3Rlc3RzLmNvbW1vbi5ydW5uZXJzLnRlc3RfcGlwZXOUjAxfVGVzdFBpY2tsZXiUk5SMA1hZWpRLe4aUgZQu") + obj = decode_obj( + "LfDoYo19lgUOtTn0Ib6JgASVQAAAAAAAAACMH3Rlc3RzLmNvbW1vbi5ydW5uZXJzLnRlc3RfcGlwZXOUjAxfVGVzdFBpY2tsZXiUk5SMA1hZWpRLe4aUgZQu" + ) assert type(obj).__name__.endswith("_TestPicklex") # this is completely different type assert not isinstance(obj, tuple) # synth unknown class containing other unknown types - obj = decode_obj("Koyo502yl4IKMqIxUTJFgASVbQAAAAAAAACMH3Rlc3RzLmNvbW1vbi5ydW5uZXJzLnRlc3RfcGlwZXOUjApfVGVzdENsYXNzlJOUKYGUfZQojAJzMZRoAIwMX1Rlc3RQaWNrbGV4lJOUjAFZlEsXhpSBlIwCczKUjAFVlIwDX3MzlEsDdWIu") + obj = decode_obj( + "Koyo502yl4IKMqIxUTJFgASVbQAAAAAAAACMH3Rlc3RzLmNvbW1vbi5ydW5uZXJzLnRlc3RfcGlwZXOUjApfVGVzdENsYXNzlJOUKYGUfZQojAJzMZRoAIwMX1Rlc3RQaWNrbGV4lJOUjAFZlEsXhpSBlIwCczKUjAFVlIwDX3MzlEsDdWIu" + ) assert type(obj).__name__.endswith("_TestClass") # tuple inside will be synthesized as well assert type(obj.s1).__name__.endswith("_TestPicklex") # known class containing unknown types - obj = decode_obj("PozhjHuf2oS7jPcRxKoagASVbQAAAAAAAACMH3Rlc3RzLmNvbW1vbi5ydW5uZXJzLnRlc3RfcGlwZXOUjBJfVGVzdENsYXNzVW5rRmllbGSUk5QpgZR9lCiMAnMxlGgAjAxfVGVzdFBpY2tsZXiUk5SMAVmUSxeGlIGUjAJzMpSMAVWUdWIu") + obj = decode_obj( + "PozhjHuf2oS7jPcRxKoagASVbQAAAAAAAACMH3Rlc3RzLmNvbW1vbi5ydW5uZXJzLnRlc3RfcGlwZXOUjBJfVGVzdENsYXNzVW5rRmllbGSUk5QpgZR9lCiMAnMxlGgAjAxfVGVzdFBpY2tsZXiUk5SMAVmUSxeGlIGUjAJzMpSMAVWUdWIu" + ) assert isinstance(obj, _TestClassUnkField) assert type(obj.s1).__name__.endswith("_TestPicklex") @@ -88,7 +95,9 @@ def test_iter_stdout() -> None: lines = list(iter_stdout(venv, "python", "tests/common/scripts/empty.py")) assert lines == [] with pytest.raises(CalledProcessError) as cpe: - list(iter_stdout(venv, "python", "tests/common/scripts/no_stdout_no_stderr_with_fail.py")) + list( + iter_stdout(venv, "python", "tests/common/scripts/no_stdout_no_stderr_with_fail.py") + ) # empty stdout assert cpe.value.output == "" assert cpe.value.stderr == "" @@ -102,7 +111,9 @@ def test_iter_stdout_raises() -> None: with Venv.create(tempfile.mkdtemp()) as venv: expected = ["0", "1", "2"] with pytest.raises(CalledProcessError) as cpe: - for i, line in enumerate(iter_stdout(venv, "python", "tests/common/scripts/raising_counter.py")): + for i, line in enumerate( + iter_stdout(venv, "python", "tests/common/scripts/raising_counter.py") + ): assert expected[i] == line assert cpe.value.returncode == 1 # the last output line is available @@ -120,7 +131,9 @@ def test_iter_stdout_raises() -> None: # three lines with 1 MB size + newline _i = -1 with pytest.raises(CalledProcessError) as cpe: - for _i, line in enumerate(iter_stdout(venv, "python", "tests/common/scripts/long_lines_fails.py")): + for _i, line in enumerate( + iter_stdout(venv, "python", "tests/common/scripts/long_lines_fails.py") + ): assert len(line) == 1024 * 1024 assert line == "a" * 1024 * 1024 # there were 3 lines @@ -158,11 +171,15 @@ def test_iter_stdout_with_result() -> None: assert iter_until_returns(i) is None # it just excepts without encoding exception with pytest.raises(CalledProcessError): - i = iter_stdout_with_result(venv, "python", "tests/common/scripts/no_stdout_no_stderr_with_fail.py") + i = iter_stdout_with_result( + venv, "python", "tests/common/scripts/no_stdout_no_stderr_with_fail.py" + ) iter_until_returns(i) # this raises a decoded exception: UnsupportedProcessStartMethodException with pytest.raises(UnsupportedProcessStartMethodException): - i = 
iter_stdout_with_result(venv, "python", "tests/common/scripts/stdout_encode_exception.py") + i = iter_stdout_with_result( + venv, "python", "tests/common/scripts/stdout_encode_exception.py" + ) iter_until_returns(i) diff --git a/tests/common/runners/test_runnable.py b/tests/common/runners/test_runnable.py index 9ba621d6fe..43ae5b4795 100644 --- a/tests/common/runners/test_runnable.py +++ b/tests/common/runners/test_runnable.py @@ -1,15 +1,20 @@ import gc -import pytest import multiprocessing -from multiprocessing.pool import Pool from multiprocessing.dummy import Pool as ThreadPool +from multiprocessing.pool import Pool -from dlt.normalize.configuration import SchemaStorageConfiguration +import pytest +from tests.common.runners.utils import ( + ALL_METHODS, + _TestRunnableWorker, + _TestRunnableWorkerMethod, + mp_method_auto, +) -from tests.common.runners.utils import _TestRunnableWorkerMethod, _TestRunnableWorker, ALL_METHODS, mp_method_auto +from dlt.normalize.configuration import SchemaStorageConfiguration -@pytest.mark.parametrize('method', ALL_METHODS) +@pytest.mark.parametrize("method", ALL_METHODS) def test_runnable_process_pool(method: str) -> None: multiprocessing.set_start_method(method, force=True) # 4 tasks diff --git a/tests/common/runners/test_runners.py b/tests/common/runners/test_runners.py index c8fbe247d3..aa82298880 100644 --- a/tests/common/runners/test_runners.py +++ b/tests/common/runners/test_runners.py @@ -1,17 +1,21 @@ -import pytest import multiprocessing from typing import Type -from dlt.common.runtime import signals -from dlt.common.configuration import resolve_configuration, configspec +import pytest +from tests.common.runners.utils import ( + ALL_METHODS, + _TestRunnableWorker, + _TestRunnableWorkerMethod, + mp_method_auto, +) +from tests.utils import init_test_logging + +from dlt.common.configuration import configspec, resolve_configuration from dlt.common.configuration.specs.run_configuration import RunConfiguration from dlt.common.exceptions import DltException, SignalReceivedException from dlt.common.runners import pool_runner as runner -from dlt.common.runtime import initialize_runtime from dlt.common.runners.configuration import PoolRunnerConfiguration, TPoolType - -from tests.common.runners.utils import _TestRunnableWorkerMethod, _TestRunnableWorker, ALL_METHODS, mp_method_auto -from tests.utils import init_test_logging +from dlt.common.runtime import initialize_runtime, signals @configspec @@ -43,6 +47,7 @@ def logger_autouse() -> None: _counter = 0 + @pytest.fixture(autouse=True) def default_args() -> None: signals._received_signal = 0 @@ -117,15 +122,12 @@ def test_single_non_idle_run() -> None: def test_runnable_with_runner() -> None: r = _TestRunnableWorkerMethod(4) - runs_count = runner.run_pool( - configure(ThreadPoolConfiguration), - r - ) + runs_count = runner.run_pool(configure(ThreadPoolConfiguration), r) assert runs_count == 1 assert [v[0] for v in r.rv] == list(range(4)) -@pytest.mark.parametrize('method', ALL_METHODS) +@pytest.mark.parametrize("method", ALL_METHODS) def test_pool_runner_process_methods(method) -> None: multiprocessing.set_start_method(method, force=True) r = _TestRunnableWorker(4) @@ -133,9 +135,6 @@ def test_pool_runner_process_methods(method) -> None: C = resolve_configuration(RunConfiguration()) initialize_runtime(C) - runs_count = runner.run_pool( - configure(ProcessPoolConfiguration), - r - ) + runs_count = runner.run_pool(configure(ProcessPoolConfiguration), r) assert runs_count == 1 assert [v[0] for v in r.rv] == 
list(range(4)) diff --git a/tests/common/runners/test_venv.py b/tests/common/runners/test_venv.py index cd4f2726a4..d3c9878711 100644 --- a/tests/common/runners/test_venv.py +++ b/tests/common/runners/test_venv.py @@ -1,16 +1,16 @@ -import sys import os -from subprocess import CalledProcessError, PIPE +import shutil +import sys import tempfile +from subprocess import PIPE, CalledProcessError + import pytest -import shutil +from tests.utils import preserve_environ from dlt.common.exceptions import CannotInstallDependency from dlt.common.runners import Venv, VenvNotFound from dlt.common.utils import custom_environ -from tests.utils import preserve_environ - def test_create_venv() -> None: with Venv.create(tempfile.mkdtemp()) as venv: @@ -235,7 +235,9 @@ def test_start_command() -> None: # custom environ with custom_environ({"_CUSTOM_ENV_VALUE": "uniq"}): - with venv.start_command("python", "tests/common/scripts/environ.py", stdout=PIPE, text=True) as process: + with venv.start_command( + "python", "tests/common/scripts/environ.py", stdout=PIPE, text=True + ) as process: output, _ = process.communicate() assert process.poll() == 0 assert "_CUSTOM_ENV_VALUE" in output diff --git a/tests/common/runners/utils.py b/tests/common/runners/utils.py index 57c92190f0..124de9a301 100644 --- a/tests/common/runners/utils.py +++ b/tests/common/runners/utils.py @@ -1,16 +1,17 @@ -import os -import pytest import multiprocessing +import os +from multiprocessing.pool import Pool from time import sleep from typing import Iterator, Tuple -from multiprocessing.pool import Pool + +import pytest from dlt.common import logger -from dlt.common.runners import TRunMetrics, Runnable, workermethod +from dlt.common.runners import Runnable, TRunMetrics, workermethod from dlt.common.utils import uniq_id # remove fork-server because it hangs the tests no CI -ALL_METHODS = set(multiprocessing.get_all_start_methods()).intersection(['fork', 'spawn']) +ALL_METHODS = set(multiprocessing.get_all_start_methods()).intersection(["fork", "spawn"]) @pytest.fixture(autouse=True) @@ -21,7 +22,6 @@ def mp_method_auto() -> Iterator[None]: class _TestRunnableWorkerMethod(Runnable): - def __init__(self, tasks: int) -> None: self.uniq = uniq_id() self.tasks = tasks @@ -37,7 +37,9 @@ def worker(self: "_TestRunnableWorkerMethod", v: int) -> Tuple[int, str, int]: def _run(self, pool: Pool) -> Iterator[Tuple[int, str, int]]: rid = id(self) assert rid in _TestRunnableWorkerMethod.RUNNING - self.rv = rv = pool.starmap(_TestRunnableWorkerMethod.worker, [(rid, i) for i in range(self.tasks)]) + self.rv = rv = pool.starmap( + _TestRunnableWorkerMethod.worker, [(rid, i) for i in range(self.tasks)] + ) assert rid in _TestRunnableWorkerMethod.RUNNING return rv @@ -47,7 +49,6 @@ def run(self, pool: Pool) -> TRunMetrics: class _TestRunnableWorker(Runnable): - def __init__(self, tasks: int) -> None: self.tasks = tasks self.rv = None @@ -60,7 +61,7 @@ def worker(v: int) -> Tuple[int, int]: return (v, os.getpid()) def _run(self, pool: Pool) -> Iterator[Tuple[int, str, int]]: - self.rv = rv = pool.starmap(_TestRunnableWorker.worker, [(i, ) for i in range(self.tasks)]) + self.rv = rv = pool.starmap(_TestRunnableWorker.worker, [(i,) for i in range(self.tasks)]) return rv def run(self, pool: Pool) -> TRunMetrics: diff --git a/tests/common/runtime/test_collector.py b/tests/common/runtime/test_collector.py index 600c3b3d4b..c4e79ec2ee 100644 --- a/tests/common/runtime/test_collector.py +++ b/tests/common/runtime/test_collector.py @@ -1,7 +1,8 @@ from collections 
import defaultdict import pytest -from dlt.common.runtime.collector import NullCollector, DictCollector, Collector + +from dlt.common.runtime.collector import Collector, DictCollector, NullCollector def test_null_collector() -> None: @@ -45,4 +46,4 @@ def test_dict_collector_reset_counters(): assert collector.counters["counter1"] == 5 with DictCollector()("test2") as collector: - assert collector.counters == defaultdict(int) \ No newline at end of file + assert collector.counters == defaultdict(int) diff --git a/tests/common/runtime/test_logging.py b/tests/common/runtime/test_logging.py index 080a7bf3e4..da686a54fe 100644 --- a/tests/common/runtime/test_logging.py +++ b/tests/common/runtime/test_logging.py @@ -1,16 +1,16 @@ -import pytest -import json_logging from importlib.metadata import version as pkg_version +import json_logging +import pytest +from tests.common.configuration.utils import environment +from tests.common.runtime.utils import mock_github_env, mock_image_env, mock_pod_env +from tests.utils import init_test_logging, preserve_environ + from dlt.common import logger -from dlt.common.runtime import exec_info -from dlt.common.typing import StrStr from dlt.common.configuration import configspec from dlt.common.configuration.specs import RunConfiguration - -from tests.common.runtime.utils import mock_image_env, mock_github_env, mock_pod_env -from tests.common.configuration.utils import environment -from tests.utils import preserve_environ, init_test_logging +from dlt.common.runtime import exec_info +from dlt.common.typing import StrStr @configspec @@ -28,11 +28,16 @@ def test_version_extract(environment: StrStr) -> None: version = exec_info.dlt_version_info("logger") # assert version["dlt_version"].startswith(code_version) lib_version = pkg_version("dlt") - assert version == {'dlt_version': lib_version, 'pipeline_name': 'logger'} + assert version == {"dlt_version": lib_version, "pipeline_name": "logger"} # mock image info available in container mock_image_env(environment) version = exec_info.dlt_version_info("logger") - assert version == {'dlt_version': lib_version, 'commit_sha': '192891', 'pipeline_name': 'logger', 'image_version': 'scale/v:112'} + assert version == { + "dlt_version": lib_version, + "commit_sha": "192891", + "pipeline_name": "logger", + "image_version": "scale/v:112", + } def test_pod_info_extract(environment: StrStr) -> None: @@ -40,17 +45,29 @@ def test_pod_info_extract(environment: StrStr) -> None: assert pod_info == {} mock_pod_env(environment) pod_info = exec_info.kube_pod_info() - assert pod_info == {'kube_node_name': 'node_name', 'kube_pod_name': 'pod_name', 'kube_pod_namespace': 'namespace'} + assert pod_info == { + "kube_node_name": "node_name", + "kube_pod_name": "pod_name", + "kube_pod_namespace": "namespace", + } def test_github_info_extract(environment: StrStr) -> None: mock_github_env(environment) github_info = exec_info.github_info() - assert github_info == {"github_user": "rudolfix", "github_repository": "dlt-hub/beginners-workshop-2022", "github_repository_owner": "dlt-hub"} + assert github_info == { + "github_user": "rudolfix", + "github_repository": "dlt-hub/beginners-workshop-2022", + "github_repository_owner": "dlt-hub", + } mock_github_env(environment) del environment["GITHUB_USER"] github_info = exec_info.github_info() - assert github_info == {"github_user": "dlt-hub", "github_repository": "dlt-hub/beginners-workshop-2022", "github_repository_owner": "dlt-hub"} + assert github_info == { + "github_user": "dlt-hub", + "github_repository": 
"dlt-hub/beginners-workshop-2022", + "github_repository_owner": "dlt-hub", + } @pytest.mark.forked diff --git a/tests/common/runtime/test_signals.py b/tests/common/runtime/test_signals.py index b4e86ad241..a28ca6cc0e 100644 --- a/tests/common/runtime/test_signals.py +++ b/tests/common/runtime/test_signals.py @@ -1,14 +1,14 @@ import os -import pytest import time from multiprocessing.dummy import Process as DummyProcess +import pytest +from tests.utils import skipifwindows + from dlt.common import sleep from dlt.common.exceptions import SignalReceivedException from dlt.common.runtime import signals -from tests.utils import skipifwindows - @pytest.fixture(autouse=True) def clear_signal() -> None: @@ -64,7 +64,6 @@ def test_delayed_signals_context_manager() -> None: def test_sleep_signal() -> None: - thread_signal = 0 def _thread() -> None: diff --git a/tests/common/runtime/test_telemetry.py b/tests/common/runtime/test_telemetry.py index 0308946553..9fba6f5a4d 100644 --- a/tests/common/runtime/test_telemetry.py +++ b/tests/common/runtime/test_telemetry.py @@ -1,25 +1,33 @@ -from typing import Any -import os -import pytest import logging +import os +from typing import Any from unittest.mock import patch +import pytest +from tests.common.configuration.utils import environment +from tests.common.runtime.utils import mock_github_env, mock_image_env, mock_pod_env +from tests.utils import ( + init_test_logging, + preserve_environ, + skipifspawn, + skipifwindows, + start_test_telemetry, +) + from dlt.common import logger -from dlt.common.runtime.segment import get_anonymous_id, track, disable_segment -from dlt.common.typing import DictStrAny, StrStr from dlt.common.configuration import configspec from dlt.common.configuration.specs import RunConfiguration +from dlt.common.runtime.segment import disable_segment, get_anonymous_id, track +from dlt.common.typing import DictStrAny, StrStr from dlt.version import DLT_PKG_NAME, __version__ -from tests.common.runtime.utils import mock_image_env, mock_github_env, mock_pod_env -from tests.common.configuration.utils import environment -from tests.utils import preserve_environ, skipifspawn, skipifwindows, init_test_logging, start_test_telemetry - @configspec class SentryLoggerConfiguration(RunConfiguration): pipeline_name: str = "logger" - sentry_dsn: str = "https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752" + sentry_dsn: str = ( + "https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752" + ) dlthub_telemetry_segment_write_key: str = "TLJiyRkGVZGCi2TtjClamXpFcxAA1rSB" @@ -30,6 +38,7 @@ class SentryLoggerCriticalConfiguration(SentryLoggerConfiguration): def test_sentry_log_level() -> None: from dlt.common.runtime.sentry import _get_sentry_log_level + sll = _get_sentry_log_level(SentryLoggerCriticalConfiguration(log_level="CRITICAL")) assert sll._handler.level == logging._nameToLevel["CRITICAL"] sll = _get_sentry_log_level(SentryLoggerCriticalConfiguration(log_level="ERROR")) @@ -88,6 +97,8 @@ def test_cleanup(environment: StrStr) -> None: SENT_ITEMS = [] + + def _mock_before_send(event: DictStrAny, _unused_hint: Any = None) -> DictStrAny: # print(event) SENT_ITEMS.append(event) diff --git a/tests/common/runtime/utils.py b/tests/common/runtime/utils.py index 2e3a679d52..9de63767f8 100644 --- a/tests/common/runtime/utils.py +++ b/tests/common/runtime/utils.py @@ -1,5 +1,6 @@ from dlt.common.typing import StrStr + def mock_image_env(environment: StrStr) -> None: environment["COMMIT_SHA"] = "192891" 
environment["IMAGE_VERSION"] = "scale/v:112" diff --git a/tests/common/schema/test_coercion.py b/tests/common/schema/test_coercion.py index 0fcdbe1393..de362e37fb 100644 --- a/tests/common/schema/test_coercion.py +++ b/tests/common/schema/test_coercion.py @@ -1,17 +1,16 @@ +import datetime # noqa: I251 from collections.abc import Mapping, MutableSequence from copy import copy from typing import Any, Type + import pytest -import datetime # noqa: I251 from hexbytes import HexBytes - from pendulum.tz import UTC +from tests.cases import JSON_TYPED_DICT, JSON_TYPED_DICT_TYPES from dlt.common import Decimal, Wei, json, pendulum +from dlt.common.data_types import TDataType, coerce_value, py_type_to_sc_type from dlt.common.json import _DATETIME, custom_pua_decode_nested -from dlt.common.data_types import coerce_value, py_type_to_sc_type, TDataType - -from tests.cases import JSON_TYPED_DICT, JSON_TYPED_DICT_TYPES def test_coerce_same_type() -> None: @@ -29,9 +28,11 @@ def test_coerce_type_to_text() -> None: # double into text assert coerce_value("text", "double", -1726.1288) == "-1726.1288" # bytes to text (base64) - assert coerce_value("text", "binary", b'binary string') == "YmluYXJ5IHN0cmluZw==" + assert coerce_value("text", "binary", b"binary string") == "YmluYXJ5IHN0cmluZw==" # HexBytes to text (hex with prefix) - assert coerce_value("text", "binary", HexBytes(b'binary string')) == "0x62696e61727920737472696e67" + assert ( + coerce_value("text", "binary", HexBytes(b"binary string")) == "0x62696e61727920737472696e67" + ) def test_coerce_type_to_bool() -> None: @@ -49,7 +50,7 @@ def test_coerce_type_to_bool() -> None: with pytest.raises(ValueError): coerce_value("bool", "complex", {"a": True}) with pytest.raises(ValueError): - coerce_value("bool", "binary", b'True') + coerce_value("bool", "binary", b"True") with pytest.raises(ValueError): coerce_value("bool", "timestamp", pendulum.now()) @@ -60,7 +61,7 @@ def test_coerce_type_to_double() -> None: # text into double if parsable assert coerce_value("double", "text", " -1726.1288 ") == -1726.1288 # hex text into double - assert coerce_value("double", "text", "0xff") == 255.0 + assert coerce_value("double", "text", "0xff") == 255.0 # wei, decimal to double assert coerce_value("double", "wei", Wei.from_int256(2137, decimals=2)) == 21.37 assert coerce_value("double", "decimal", Decimal("-1121.11")) == -1121.11 @@ -94,10 +95,7 @@ def test_coerce_type_to_bigint() -> None: coerce_value("bigint", "text", "912.12") -@pytest.mark.parametrize("dec_cls,data_type", [ - (Decimal, "decimal"), - (Wei, "wei") -]) +@pytest.mark.parametrize("dec_cls,data_type", [(Decimal, "decimal"), (Wei, "wei")]) def test_coerce_to_numeric(dec_cls: Type[Any], data_type: TDataType) -> None: v = coerce_value(data_type, "text", " -1726.839283 ") assert type(v) is dec_cls @@ -133,20 +131,36 @@ def test_coerce_type_from_hex_text() -> None: def test_coerce_type_to_timestamp() -> None: # timestamp cases - assert coerce_value("timestamp", "text", " 1580405246 ") == pendulum.parse("2020-01-30T17:27:26+00:00") + assert coerce_value("timestamp", "text", " 1580405246 ") == pendulum.parse( + "2020-01-30T17:27:26+00:00" + ) # the tenths of microseconds will be ignored - assert coerce_value("timestamp", "double", 1633344898.7415245) == pendulum.parse("2021-10-04T10:54:58.741524+00:00") + assert coerce_value("timestamp", "double", 1633344898.7415245) == pendulum.parse( + "2021-10-04T10:54:58.741524+00:00" + ) # if text is ISO string it will be coerced - assert coerce_value("timestamp", "text", 
"2022-05-10T03:41:31.466000+00:00") == pendulum.parse("2022-05-10T03:41:31.466000+00:00") - assert coerce_value("timestamp", "text", "2022-05-10T03:41:31.466+02:00") == pendulum.parse("2022-05-10T01:41:31.466Z") - assert coerce_value("timestamp", "text", "2022-05-10T03:41:31.466+0200") == pendulum.parse("2022-05-10T01:41:31.466Z") + assert coerce_value("timestamp", "text", "2022-05-10T03:41:31.466000+00:00") == pendulum.parse( + "2022-05-10T03:41:31.466000+00:00" + ) + assert coerce_value("timestamp", "text", "2022-05-10T03:41:31.466+02:00") == pendulum.parse( + "2022-05-10T01:41:31.466Z" + ) + assert coerce_value("timestamp", "text", "2022-05-10T03:41:31.466+0200") == pendulum.parse( + "2022-05-10T01:41:31.466Z" + ) # parse almost ISO compliant string - assert coerce_value("timestamp", "text", "2022-04-26 10:36+02") == pendulum.parse("2022-04-26T10:36:00+02:00") - assert coerce_value("timestamp", "text", "2022-04-26 10:36") == pendulum.parse("2022-04-26T10:36:00+00:00") + assert coerce_value("timestamp", "text", "2022-04-26 10:36+02") == pendulum.parse( + "2022-04-26T10:36:00+02:00" + ) + assert coerce_value("timestamp", "text", "2022-04-26 10:36") == pendulum.parse( + "2022-04-26T10:36:00+00:00" + ) # parse date string assert coerce_value("timestamp", "text", "2021-04-25") == pendulum.parse("2021-04-25") # from date type - assert coerce_value("timestamp", "date", datetime.date(2023, 2, 27)) == pendulum.parse("2023-02-27") + assert coerce_value("timestamp", "date", datetime.date(2023, 2, 27)) == pendulum.parse( + "2023-02-27" + ) # fails on "now" - yes pendulum by default parses "now" as .now() with pytest.raises(ValueError): @@ -185,26 +199,40 @@ def test_coerce_type_to_timestamp() -> None: def test_coerce_type_to_date() -> None: # from datetime object - assert coerce_value("date", "timestamp", pendulum.datetime(1995, 5, 6, 00, 1, 1, tz=UTC)) == pendulum.parse("1995-05-06", exact=True) + assert coerce_value( + "date", "timestamp", pendulum.datetime(1995, 5, 6, 00, 1, 1, tz=UTC) + ) == pendulum.parse("1995-05-06", exact=True) # from unix timestamp - assert coerce_value("date", "double", 1677546399.494264) == pendulum.parse("2023-02-28", exact=True) + assert coerce_value("date", "double", 1677546399.494264) == pendulum.parse( + "2023-02-28", exact=True + ) assert coerce_value("date", "text", " 1677546399 ") == pendulum.parse("2023-02-28", exact=True) # ISO date string assert coerce_value("date", "text", "2023-02-27") == pendulum.parse("2023-02-27", exact=True) # ISO datetime string - assert coerce_value("date", "text", "2022-05-10T03:41:31.466000+00:00") == pendulum.parse("2022-05-10", exact=True) - assert coerce_value("date", "text", "2022-05-10T03:41:31.466+02:00") == pendulum.parse("2022-05-10", exact=True) - assert coerce_value("date", "text", "2022-05-10T03:41:31.466+0200") == pendulum.parse("2022-05-10", exact=True) + assert coerce_value("date", "text", "2022-05-10T03:41:31.466000+00:00") == pendulum.parse( + "2022-05-10", exact=True + ) + assert coerce_value("date", "text", "2022-05-10T03:41:31.466+02:00") == pendulum.parse( + "2022-05-10", exact=True + ) + assert coerce_value("date", "text", "2022-05-10T03:41:31.466+0200") == pendulum.parse( + "2022-05-10", exact=True + ) # almost ISO compliant string - assert coerce_value("date", "text", "2022-04-26 10:36+02") == pendulum.parse("2022-04-26", exact=True) - assert coerce_value("date", "text", "2022-04-26 10:36") == pendulum.parse("2022-04-26", exact=True) + assert coerce_value("date", "text", "2022-04-26 10:36+02") == 
pendulum.parse( + "2022-04-26", exact=True + ) + assert coerce_value("date", "text", "2022-04-26 10:36") == pendulum.parse( + "2022-04-26", exact=True + ) def test_coerce_type_to_binary() -> None: # from hex string - assert coerce_value("binary", "text", "0x30") == b'0' + assert coerce_value("binary", "text", "0x30") == b"0" # from base64 - assert coerce_value("binary", "text", "YmluYXJ5IHN0cmluZw==") == b'binary string' + assert coerce_value("binary", "text", "YmluYXJ5IHN0cmluZw==") == b"binary string" # int into bytes assert coerce_value("binary", "bigint", 15) == b"\x0f" # can't into double @@ -260,8 +288,16 @@ def test_coerce_type_complex() -> None: def test_coerce_type_complex_with_pua() -> None: - v_dict = {"list": [1, Wei.from_int256(10**18), f"{_DATETIME}2022-05-10T01:41:31.466Z"], "str": "complex", "pua_date": f"{_DATETIME}2022-05-10T01:41:31.466Z"} - exp_v = {"list":[1, Wei.from_int256(10**18), "2022-05-10T01:41:31.466Z"],"str":"complex","pua_date":"2022-05-10T01:41:31.466Z"} + v_dict = { + "list": [1, Wei.from_int256(10**18), f"{_DATETIME}2022-05-10T01:41:31.466Z"], + "str": "complex", + "pua_date": f"{_DATETIME}2022-05-10T01:41:31.466Z", + } + exp_v = { + "list": [1, Wei.from_int256(10**18), "2022-05-10T01:41:31.466Z"], + "str": "complex", + "pua_date": "2022-05-10T01:41:31.466Z", + } assert coerce_value("complex", "complex", copy(v_dict)) == exp_v assert coerce_value("text", "complex", copy(v_dict)) == json.dumps(exp_v) # also decode recursively diff --git a/tests/common/schema/test_detections.py b/tests/common/schema/test_detections.py index 3a74c6f368..5c5bf4c89e 100644 --- a/tests/common/schema/test_detections.py +++ b/tests/common/schema/test_detections.py @@ -1,8 +1,16 @@ from hexbytes import HexBytes -from dlt.common import pendulum, Decimal, Wei +from dlt.common import Decimal, Wei, pendulum +from dlt.common.schema.detections import ( + _FLOAT_TS_RANGE, + _NOW_TS, + is_hexbytes_to_text, + is_iso_timestamp, + is_large_integer, + is_timestamp, + is_wei_to_double, +) from dlt.common.schema.utils import autodetect_sc_type -from dlt.common.schema.detections import is_hexbytes_to_text, is_timestamp, is_iso_timestamp, is_large_integer, is_wei_to_double, _FLOAT_TS_RANGE, _NOW_TS def test_timestamp_detection() -> None: @@ -39,12 +47,12 @@ def test_detection_large_integer() -> None: assert is_large_integer(int, 2**64 // 2) == "wei" assert is_large_integer(int, 578960446186580977117854925043439539267) == "text" assert is_large_integer(int, 2**64 // 2 - 1) is None - assert is_large_integer(int, -2**64 // 2 - 1) is None + assert is_large_integer(int, -(2**64) // 2 - 1) is None def test_detection_hexbytes_to_text() -> None: - assert is_hexbytes_to_text(bytes, b'hey') is None - assert is_hexbytes_to_text(HexBytes, b'hey') == "text" + assert is_hexbytes_to_text(bytes, b"hey") is None + assert is_hexbytes_to_text(HexBytes, b"hey") == "text" def test_wei_to_double() -> None: @@ -57,7 +65,10 @@ def test_detection_function() -> None: assert autodetect_sc_type(["iso_timestamp"], str, str(pendulum.now())) == "timestamp" assert autodetect_sc_type(["iso_timestamp"], float, str(pendulum.now())) is None assert autodetect_sc_type(["timestamp"], str, str(pendulum.now())) is None - assert autodetect_sc_type(["timestamp", "iso_timestamp"], float, pendulum.now().timestamp()) == "timestamp" + assert ( + autodetect_sc_type(["timestamp", "iso_timestamp"], float, pendulum.now().timestamp()) + == "timestamp" + ) assert autodetect_sc_type(["timestamp", "large_integer"], int, 2**64) == "wei" - assert 
autodetect_sc_type(["large_integer", "hexbytes_to_text"], HexBytes, b'hey') == "text" + assert autodetect_sc_type(["large_integer", "hexbytes_to_text"], HexBytes, b"hey") == "text" assert autodetect_sc_type(["large_integer", "wei_to_double"], Wei, Wei(10**18)) == "double" diff --git a/tests/common/schema/test_filtering.py b/tests/common/schema/test_filtering.py index 8ab9df877d..51fd8482a3 100644 --- a/tests/common/schema/test_filtering.py +++ b/tests/common/schema/test_filtering.py @@ -1,12 +1,12 @@ -import pytest from copy import deepcopy -from dlt.common.schema.exceptions import ParentTableNotFoundException -from dlt.common.typing import StrAny +import pytest +from tests.common.utils import load_json_case + from dlt.common.schema import Schema +from dlt.common.schema.exceptions import ParentTableNotFoundException from dlt.common.schema.utils import new_table - -from tests.common.utils import load_json_case +from dlt.common.typing import StrAny @pytest.fixture @@ -49,9 +49,14 @@ def test_whole_row_filter_with_exception(schema: Schema) -> None: # mind that path event_bot__custom_data__included_object was also eliminated assert filtered_case == {} # this child of the row has exception (^event_bot__custom_data__included_object__ - the __ at the end select all childern but not the parent) - filtered_case = schema.filter_row("event_bot__custom_data__included_object", deepcopy(bot_case)["custom_data"]["included_object"]) + filtered_case = schema.filter_row( + "event_bot__custom_data__included_object", + deepcopy(bot_case)["custom_data"]["included_object"], + ) assert filtered_case == bot_case["custom_data"]["included_object"] - filtered_case = schema.filter_row("event_bot__custom_data__excluded_path", deepcopy(bot_case)["custom_data"]["excluded_path"]) + filtered_case = schema.filter_row( + "event_bot__custom_data__excluded_path", deepcopy(bot_case)["custom_data"]["excluded_path"] + ) assert filtered_case == {} @@ -59,16 +64,13 @@ def test_filter_parent_table_schema_update(schema: Schema) -> None: # filter out parent table and leave just child one. 
that should break the child-parent relationship and reject schema update _add_excludes(schema) source_row = { - "metadata": [{ - "elvl1": [{ - "elvl2": [{ - "id": "level3_kept" - }], - "f": "elvl1_removed" - }], - "f": "metadata_removed" - }] - } + "metadata": [ + { + "elvl1": [{"elvl2": [{"id": "level3_kept"}], "f": "elvl1_removed"}], + "f": "metadata_removed", + } + ] + } updates = [] @@ -95,7 +97,9 @@ def test_filter_parent_table_schema_update(schema: Schema) -> None: updates.clear() schema = Schema("event") _add_excludes(schema) - schema.get_table("event_bot")["filters"]["includes"].extend(["re:^metadata___dlt_", "re:^metadata__elvl1___dlt_"]) + schema.get_table("event_bot")["filters"]["includes"].extend( + ["re:^metadata___dlt_", "re:^metadata__elvl1___dlt_"] + ) schema._compile_settings() for (t, p), row in schema.normalize_data_item(source_row, "load_id", "event_bot"): row = schema.filter_row(t, row) @@ -115,7 +119,16 @@ def test_filter_parent_table_schema_update(schema: Schema) -> None: def _add_excludes(schema: Schema) -> None: bot_table = new_table("event_bot") - bot_table.setdefault("filters", {})["excludes"] = ["re:^metadata", "re:^is_flagged$", "re:^data", "re:^custom_data"] - bot_table["filters"]["includes"] = ["re:^data__custom$", "re:^custom_data__included_object__", "re:^metadata__elvl1__elvl2__"] + bot_table.setdefault("filters", {})["excludes"] = [ + "re:^metadata", + "re:^is_flagged$", + "re:^data", + "re:^custom_data", + ] + bot_table["filters"]["includes"] = [ + "re:^data__custom$", + "re:^custom_data__included_object__", + "re:^metadata__elvl1__elvl2__", + ] schema.update_schema(bot_table) schema._compile_settings() diff --git a/tests/common/schema/test_inference.py b/tests/common/schema/test_inference.py index a5436530ce..19b7b63862 100644 --- a/tests/common/schema/test_inference.py +++ b/tests/common/schema/test_inference.py @@ -1,13 +1,19 @@ -import pytest from copy import deepcopy from typing import Any + +import pytest from hexbytes import HexBytes +from tests.common.utils import load_json_case -from dlt.common import Wei, Decimal, pendulum, json +from dlt.common import Decimal, Wei, json, pendulum from dlt.common.json import custom_pua_decode from dlt.common.schema import Schema, utils -from dlt.common.schema.exceptions import CannotCoerceColumnException, CannotCoerceNullException, ParentTableNotFoundException, TablePropertiesConflictException -from tests.common.utils import load_json_case +from dlt.common.schema.exceptions import ( + CannotCoerceColumnException, + CannotCoerceNullException, + ParentTableNotFoundException, + TablePropertiesConflictException, +) @pytest.fixture @@ -54,7 +60,6 @@ def test_map_column_preferred_type(schema: Schema) -> None: assert schema._infer_column_type("AA", "confidence", skip_preferred=True) == "text" - def test_map_column_type(schema: Schema) -> None: # default mappings assert schema._infer_column_type("18271.11", "_column_name") == "text" @@ -80,7 +85,12 @@ def test_coerce_row(schema: Schema) -> None: timestamp_float = 78172.128 timestamp_str = "1970-01-01T21:42:52.128000+00:00" # add new column with preferred - row_1 = {"timestamp": timestamp_float, "confidence": "0.1", "value": "0xFF", "number": Decimal("128.67")} + row_1 = { + "timestamp": timestamp_float, + "confidence": "0.1", + "value": "0xFF", + "number": Decimal("128.67"), + } new_row_1, new_table = schema.coerce_row("event_user", None, row_1) # convert columns to list, they must correspond to the order of fields in row_1 new_columns = 
list(new_table["columns"].values()) @@ -94,7 +104,12 @@ def test_coerce_row(schema: Schema) -> None: assert new_columns[3]["data_type"] == "decimal" assert "variant" not in new_columns[3] # also rows values should be coerced (confidence) - assert new_row_1 == {"timestamp": pendulum.parse(timestamp_str), "confidence": 0.1, "value": 255, "number": Decimal("128.67")} + assert new_row_1 == { + "timestamp": pendulum.parse(timestamp_str), + "confidence": 0.1, + "value": 255, + "number": Decimal("128.67"), + } # update schema schema.update_schema(new_table) @@ -137,7 +152,9 @@ def test_coerce_row(schema: Schema) -> None: schema.update_schema(new_table) # variant column clashes with existing column - create new_colbool_v_binary column that would be created for binary variant, but give it a type datetime - _, new_table = schema.coerce_row("event_user", None, {"new_colbool": False, "new_colbool__v_timestamp": b"not fit"}) + _, new_table = schema.coerce_row( + "event_user", None, {"new_colbool": False, "new_colbool__v_timestamp": b"not fit"} + ) schema.update_schema(new_table) with pytest.raises(CannotCoerceColumnException) as exc_val: # now pass the binary that would create binary variant - but the column is occupied by text type @@ -179,7 +196,12 @@ def test_shorten_variant_column(schema: Schema) -> None: _add_preferred_types(schema) timestamp_float = 78172.128 # add new column with preferred - row_1 = {"timestamp": timestamp_float, "confidence": "0.1", "value": "0xFF", "number": Decimal("128.67")} + row_1 = { + "timestamp": timestamp_float, + "confidence": "0.1", + "value": "0xFF", + "number": Decimal("128.67"), + } _, new_table = schema.coerce_row("event_user", None, row_1) # schema assumes that identifiers are already normalized so confidence even if it is longer than 9 chars schema.update_schema(new_table) @@ -188,7 +210,9 @@ def test_shorten_variant_column(schema: Schema) -> None: # now variant is created and this will be normalized # TODO: we should move the handling of variants to normalizer new_row_2, new_table = schema.coerce_row("event_user", None, {"confidence": False}) - tag = schema.naming._compute_tag("confidence__v_bool", collision_prob=schema.naming._DEFAULT_COLLISION_PROB) + tag = schema.naming._compute_tag( + "confidence__v_bool", collision_prob=schema.naming._DEFAULT_COLLISION_PROB + ) new_row_2_keys = list(new_row_2.keys()) assert tag in new_row_2_keys[0] assert len(new_row_2_keys[0]) == 9 @@ -252,15 +276,18 @@ def test_supports_variant_pua_decode(schema: Schema) -> None: # pua encoding still present assert normalized_row[0][1]["wad"].startswith("") # decode pua - decoded_row = {k: custom_pua_decode(v) for k,v in normalized_row[0][1].items()} + decoded_row = {k: custom_pua_decode(v) for k, v in normalized_row[0][1].items()} assert isinstance(decoded_row["wad"], Wei) c_row, new_table = schema.coerce_row("eth", None, decoded_row) - assert c_row["wad__v_str"] == str(2**256-1) + assert c_row["wad__v_str"] == str(2**256 - 1) assert new_table["columns"]["wad__v_str"]["data_type"] == "text" def test_supports_variant(schema: Schema) -> None: - rows = [{"evm": Wei.from_int256(2137*10**16, decimals=18)}, {"evm": Wei.from_int256(2**256-1)}] + rows = [ + {"evm": Wei.from_int256(2137 * 10**16, decimals=18)}, + {"evm": Wei.from_int256(2**256 - 1)}, + ] normalized_rows = [] for row in rows: normalized_rows.extend(schema.normalize_data_item(row, "128812.2131", "event")) @@ -270,7 +297,7 @@ def test_supports_variant(schema: Schema) -> None: # row 2 contains Wei assert "evm" in 
normalized_rows[1][1] assert isinstance(normalized_rows[1][1]["evm"], Wei) - assert normalized_rows[1][1]["evm"] == 2**256-1 + assert normalized_rows[1][1]["evm"] == 2**256 - 1 # coerce row c_row, new_table = schema.coerce_row("eth", None, normalized_rows[0][1]) assert isinstance(c_row["evm"], Wei) @@ -281,13 +308,12 @@ def test_supports_variant(schema: Schema) -> None: # coerce row that should expand to variant c_row, new_table = schema.coerce_row("eth", None, normalized_rows[1][1]) assert isinstance(c_row["evm__v_str"], str) - assert c_row["evm__v_str"] == str(2**256-1) + assert c_row["evm__v_str"] == str(2**256 - 1) assert new_table["columns"]["evm__v_str"]["data_type"] == "text" assert new_table["columns"]["evm__v_str"]["variant"] is True def test_supports_recursive_variant(schema: Schema) -> None: - class RecursiveVariant(int): # provide __call__ for SupportVariant def __call__(self) -> Any: @@ -296,18 +322,16 @@ def __call__(self) -> Any: else: return ("div2", RecursiveVariant(self // 2)) - row = {"rv": RecursiveVariant(8)} c_row, new_table = schema.coerce_row("rec_variant", None, row) # this variant keeps expanding until the value is 1, we start from 8 so there are log2(8) == 3 divisions - col_name = "rv" + "__v_div2"*3 + col_name = "rv" + "__v_div2" * 3 assert c_row[col_name] == 1 assert new_table["columns"][col_name]["data_type"] == "bigint" assert new_table["columns"][col_name]["variant"] is True def test_supports_variant_autovariant_conflict(schema: Schema) -> None: - class PureVariant(int): def __init__(self, v: Any) -> None: self.v = v @@ -319,7 +343,7 @@ def __call__(self) -> Any: if isinstance(self.v, float): return ("text", self.v) - assert issubclass(PureVariant,int) + assert issubclass(PureVariant, int) rows = [{"pv": PureVariant(3377)}, {"pv": PureVariant(21.37)}] normalized_rows = [] for row in rows: @@ -412,9 +436,13 @@ def test_update_schema_table_prop_conflict(schema: Schema) -> None: def test_update_schema_column_conflict(schema: Schema) -> None: - tab1 = utils.new_table("tab1", write_disposition="append", columns=[ - {"name": "col1", "data_type": "text", "nullable": False}, - ]) + tab1 = utils.new_table( + "tab1", + write_disposition="append", + columns=[ + {"name": "col1", "data_type": "text", "nullable": False}, + ], + ) schema.update_schema(tab1) tab1_u1 = deepcopy(tab1) # simulate column that had other datatype inferred @@ -466,12 +494,10 @@ def test_autodetect_convert_type(schema: Schema) -> None: # make sure variants behave the same - class AlwaysWei(Decimal): def __call__(self) -> Any: return ("up", Wei(self)) - # create new column row = {"evm2": AlwaysWei(22)} c_row, new_table = schema.coerce_row("eth", None, row) @@ -497,4 +523,3 @@ def __call__(self) -> Any: c_row, new_table = schema.coerce_row("eth", None, row) assert c_row["evm2"] == 22.2 assert isinstance(c_row["evm2"], float) - diff --git a/tests/common/schema/test_schema.py b/tests/common/schema/test_schema.py index 5d16b3f57f..b322b0be43 100644 --- a/tests/common/schema/test_schema.py +++ b/tests/common/schema/test_schema.py @@ -1,24 +1,33 @@ -from copy import deepcopy import os +from copy import deepcopy from typing import List, Sequence, cast + import pytest +from tests.common.utils import COMMON_TEST_CASES_PATH, load_json_case, load_yml_case +from tests.utils import autouse_test_storage, preserve_environ from dlt.common import pendulum from dlt.common.configuration import resolve_configuration from dlt.common.configuration.container import Container -from dlt.common.storages import 
SchemaStorageConfiguration from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.exceptions import DictValidationException -from dlt.common.normalizers.naming import snake_case, direct +from dlt.common.normalizers.naming import direct, snake_case +from dlt.common.schema import Schema, TColumnSchema, TStoredSchema, utils +from dlt.common.schema.exceptions import ( + InvalidSchemaName, + ParentTableNotFoundException, + SchemaEngineNoUpgradePathException, +) +from dlt.common.schema.typing import ( + COLUMN_HINTS, + LOADS_TABLE_NAME, + VERSION_TABLE_NAME, + TColumnName, + TSimpleRegex, +) +from dlt.common.storages import SchemaStorage, SchemaStorageConfiguration from dlt.common.typing import DictStrAny, StrAny from dlt.common.utils import uniq_id -from dlt.common.schema import TColumnSchema, Schema, TStoredSchema, utils -from dlt.common.schema.exceptions import InvalidSchemaName, ParentTableNotFoundException, SchemaEngineNoUpgradePathException -from dlt.common.schema.typing import LOADS_TABLE_NAME, VERSION_TABLE_NAME, TColumnName, TSimpleRegex, COLUMN_HINTS -from dlt.common.storages import SchemaStorage - -from tests.utils import autouse_test_storage, preserve_environ -from tests.common.utils import load_json_case, load_yml_case, COMMON_TEST_CASES_PATH SCHEMA_NAME = "event" EXPECTED_FILE_NAME = f"{SCHEMA_NAME}.schema.json" @@ -30,17 +39,15 @@ def schema_storage() -> SchemaStorage: SchemaStorageConfiguration(), explicit_value={ "import_schema_path": "tests/common/cases/schemas/rasa", - "external_schema_format": "json" - } + "external_schema_format": "json", + }, ) return SchemaStorage(C, makedirs=True) @pytest.fixture def schema_storage_no_import() -> SchemaStorage: - C = resolve_configuration( - SchemaStorageConfiguration() - ) + C = resolve_configuration(SchemaStorageConfiguration()) return SchemaStorage(C, makedirs=True) @@ -51,15 +58,16 @@ def schema() -> Schema: @pytest.fixture def cn_schema() -> Schema: - return Schema("column_default", { - "names": "tests.common.normalizers.custom_normalizers", - "json": { - "module": "tests.common.normalizers.custom_normalizers", - "config": { - "not_null": ["fake_id"] - } - } - }) + return Schema( + "column_default", + { + "names": "tests.common.normalizers.custom_normalizers", + "json": { + "module": "tests.common.normalizers.custom_normalizers", + "config": {"not_null": ["fake_id"]}, + }, + }, + ) def test_normalize_schema_name(schema: Schema) -> None: @@ -118,7 +126,9 @@ def test_simple_regex_validator() -> None: assert utils.simple_regex_validator(".", "k", "v", TSimpleRegex) is True # validate regex - assert utils.simple_regex_validator(".", "k", TSimpleRegex("re:^_record$"), TSimpleRegex) is True + assert ( + utils.simple_regex_validator(".", "k", TSimpleRegex("re:^_record$"), TSimpleRegex) is True + ) # invalid regex with pytest.raises(DictValidationException) as e: utils.simple_regex_validator(".", "k", "re:[[^_record$", TSimpleRegex) @@ -169,7 +179,7 @@ def test_schema_name() -> None: Schema("1_a") # too long with pytest.raises(InvalidSchemaName) as exc: - Schema("a"*65) + Schema("a" * 65) def test_create_schema_with_normalize_name() -> None: @@ -178,11 +188,19 @@ def test_create_schema_with_normalize_name() -> None: def test_schema_descriptions_and_annotations(schema_storage: SchemaStorage): - schema = SchemaStorage.load_schema_file(os.path.join(COMMON_TEST_CASES_PATH, "schemas/local"), "event", extensions=("yaml", )) + schema = SchemaStorage.load_schema_file( + os.path.join(COMMON_TEST_CASES_PATH, 
"schemas/local"), "event", extensions=("yaml",) + ) assert schema.tables["blocks"]["description"] == "Ethereum blocks" assert schema.tables["blocks"]["x-annotation"] == "this will be preserved on save" - assert schema.tables["blocks"]["columns"]["_dlt_load_id"]["description"] == "load id coming from the extractor" - assert schema.tables["blocks"]["columns"]["_dlt_load_id"]["x-column-annotation"] == "column annotation preserved on save" + assert ( + schema.tables["blocks"]["columns"]["_dlt_load_id"]["description"] + == "load id coming from the extractor" + ) + assert ( + schema.tables["blocks"]["columns"]["_dlt_load_id"]["x-column-annotation"] + == "column annotation preserved on save" + ) # mod and save schema.tables["blocks"]["description"] += "Saved" @@ -194,8 +212,12 @@ def test_schema_descriptions_and_annotations(schema_storage: SchemaStorage): loaded_schema = schema_storage.load_schema("event") assert loaded_schema.tables["blocks"]["description"].endswith("Saved") assert loaded_schema.tables["blocks"]["x-annotation"].endswith("Saved") - assert loaded_schema.tables["blocks"]["columns"]["_dlt_load_id"]["description"].endswith("Saved") - assert loaded_schema.tables["blocks"]["columns"]["_dlt_load_id"]["x-column-annotation"].endswith("Saved") + assert loaded_schema.tables["blocks"]["columns"]["_dlt_load_id"]["description"].endswith( + "Saved" + ) + assert loaded_schema.tables["blocks"]["columns"]["_dlt_load_id"][ + "x-column-annotation" + ].endswith("Saved") def test_replace_schema_content() -> None: @@ -211,12 +233,21 @@ def test_replace_schema_content() -> None: assert schema_eth._imported_version_hash == schema._imported_version_hash -@pytest.mark.parametrize("columns,hint,value", [ - (["_dlt_id", "_dlt_root_id", "_dlt_load_id", "_dlt_parent_id", "_dlt_list_idx"], "nullable", False), - (["_dlt_id"], "unique", True), - (["_dlt_parent_id"], "foreign_key", True), -]) -def test_relational_normalizer_schema_hints(columns: Sequence[str], hint: str, value: bool, schema_storage: SchemaStorage) -> None: +@pytest.mark.parametrize( + "columns,hint,value", + [ + ( + ["_dlt_id", "_dlt_root_id", "_dlt_load_id", "_dlt_parent_id", "_dlt_list_idx"], + "nullable", + False, + ), + (["_dlt_id"], "unique", True), + (["_dlt_parent_id"], "foreign_key", True), + ], +) +def test_relational_normalizer_schema_hints( + columns: Sequence[str], hint: str, value: bool, schema_storage: SchemaStorage +) -> None: schema = schema_storage.load_schema("event") for name in columns: # infer column hints @@ -241,7 +272,9 @@ def test_save_store_schema(schema: Schema, schema_storage: SchemaStorage) -> Non assert_new_schema_values(schema_copy) -def test_save_store_schema_custom_normalizers(cn_schema: Schema, schema_storage: SchemaStorage) -> None: +def test_save_store_schema_custom_normalizers( + cn_schema: Schema, schema_storage: SchemaStorage +) -> None: schema_storage.save_schema(cn_schema) schema_copy = schema_storage.load_schema(cn_schema.name) assert_new_schema_values_custom_normalizers(schema_copy) @@ -289,7 +322,9 @@ def test_unknown_engine_upgrade() -> None: def test_preserve_column_order(schema: Schema, schema_storage: SchemaStorage) -> None: # python dicts are ordered from v3.6, add 50 column with random names - update: List[TColumnSchema] = [schema._infer_column(uniq_id(), pendulum.now().timestamp()) for _ in range(50)] + update: List[TColumnSchema] = [ + schema._infer_column(uniq_id(), pendulum.now().timestamp()) for _ in range(50) + ] schema.update_schema(utils.new_table("event_test_order", columns=update)) def 
verify_items(table, update) -> None: @@ -304,7 +339,9 @@ def verify_items(table, update) -> None: table = loaded_schema.get_table_columns("event_test_order") verify_items(table, update) # add more columns - update2: List[TColumnSchema] = [schema._infer_column(uniq_id(), pendulum.now().timestamp()) for _ in range(50)] + update2: List[TColumnSchema] = [ + schema._infer_column(uniq_id(), pendulum.now().timestamp()) for _ in range(50) + ] loaded_schema.update_schema(utils.new_table("event_test_order", columns=update2)) table = loaded_schema.get_table_columns("event_test_order") verify_items(table, update + update2) @@ -312,7 +349,7 @@ def verify_items(table, update) -> None: schema_storage.save_schema(loaded_schema) loaded_schema = schema_storage.load_schema("event") table = loaded_schema.get_table_columns("event_test_order") - verify_items(table, update + update2) + verify_items(table, update + update2) def test_get_schema_new_exist(schema_storage: SchemaStorage) -> None: @@ -320,16 +357,35 @@ def test_get_schema_new_exist(schema_storage: SchemaStorage) -> None: schema_storage.load_schema("wrongschema") -@pytest.mark.parametrize("columns,hint,value", [ - (["timestamp", "_timestamp", "_dist_key", "_dlt_id", "_dlt_root_id", "_dlt_load_id", "_dlt_parent_id", "_dlt_list_idx", "sender_id"], "nullable", False), - (["confidence", "_sender_id"], "nullable", True), - (["timestamp", "_timestamp"], "partition", True), - (["_dist_key", "sender_id"], "cluster", True), - (["_dlt_id"], "unique", True), - (["_dlt_parent_id"], "foreign_key", True), - (["timestamp", "_timestamp"], "sort", True), -]) -def test_rasa_event_hints(columns: Sequence[str], hint: str, value: bool, schema_storage: SchemaStorage) -> None: +@pytest.mark.parametrize( + "columns,hint,value", + [ + ( + [ + "timestamp", + "_timestamp", + "_dist_key", + "_dlt_id", + "_dlt_root_id", + "_dlt_load_id", + "_dlt_parent_id", + "_dlt_list_idx", + "sender_id", + ], + "nullable", + False, + ), + (["confidence", "_sender_id"], "nullable", True), + (["timestamp", "_timestamp"], "partition", True), + (["_dist_key", "sender_id"], "cluster", True), + (["_dlt_id"], "unique", True), + (["_dlt_parent_id"], "foreign_key", True), + (["timestamp", "_timestamp"], "sort", True), + ], +) +def test_rasa_event_hints( + columns: Sequence[str], hint: str, value: bool, schema_storage: SchemaStorage +) -> None: schema = schema_storage.load_schema("event") for name in columns: # infer column hints @@ -397,10 +453,16 @@ def test_merge_hints(schema: Schema) -> None: schema._settings["default_hints"] = {} schema._compiled_hints = {} new_hints = { - "not_null": ["_dlt_id", "_dlt_root_id", "_dlt_parent_id", "_dlt_list_idx", "re:^_dlt_load_id$"], - "foreign_key": ["re:^_dlt_parent_id$"], - "unique": ["re:^_dlt_id$"] - } + "not_null": [ + "_dlt_id", + "_dlt_root_id", + "_dlt_parent_id", + "_dlt_list_idx", + "re:^_dlt_load_id$", + ], + "foreign_key": ["re:^_dlt_parent_id$"], + "unique": ["re:^_dlt_id$"], + } schema.merge_hints(new_hints) assert schema._settings["default_hints"] == new_hints @@ -411,17 +473,21 @@ def test_merge_hints(schema: Schema) -> None: assert set(new_hints[k]) == set(schema._settings["default_hints"][k]) # add new stuff - new_new_hints = { - "not_null": ["timestamp"], - "primary_key": ["id"] - } + new_new_hints = {"not_null": ["timestamp"], "primary_key": ["id"]} schema.merge_hints(new_new_hints) expected_hints = { - "not_null": ["_dlt_id", "_dlt_root_id", "_dlt_parent_id", "_dlt_list_idx", "re:^_dlt_load_id$", "timestamp"], - "foreign_key": 
["re:^_dlt_parent_id$"], - "unique": ["re:^_dlt_id$"], - "primary_key": ["id"] - } + "not_null": [ + "_dlt_id", + "_dlt_root_id", + "_dlt_parent_id", + "_dlt_list_idx", + "re:^_dlt_load_id$", + "timestamp", + ], + "foreign_key": ["re:^_dlt_parent_id$"], + "unique": ["re:^_dlt_id$"], + "primary_key": ["id"], + } assert len(expected_hints) == len(schema._settings["default_hints"]) for k in expected_hints: assert set(expected_hints[k]) == set(schema._settings["default_hints"][k]) @@ -432,8 +498,8 @@ def test_default_table_resource() -> None: eth_v5 = load_yml_case("schemas/eth/ethereum_schema_v5") tables = Schema.from_dict(eth_v5).tables - assert tables['blocks']['resource'] == 'blocks' - assert all([t.get('resource') is None for t in tables.values() if t.get('parent')]) + assert tables["blocks"]["resource"] == "blocks" + assert all([t.get("resource") is None for t in tables.values() if t.get("parent")]) def test_data_tables(schema: Schema, schema_storage: SchemaStorage) -> None: @@ -443,8 +509,10 @@ def test_data_tables(schema: Schema, schema_storage: SchemaStorage) -> None: # with tables schema = schema_storage.load_schema("event") # some of them are incomplete - assert set(schema.tables.keys()) == set([LOADS_TABLE_NAME, VERSION_TABLE_NAME, 'event_slot', 'event_user', 'event_bot']) - assert [t["name"] for t in schema.data_tables()] == ['event_slot'] + assert set(schema.tables.keys()) == set( + [LOADS_TABLE_NAME, VERSION_TABLE_NAME, "event_slot", "event_user", "event_bot"] + ) + assert [t["name"] for t in schema.data_tables()] == ["event_slot"] def test_write_disposition(schema_storage: SchemaStorage) -> None: @@ -469,28 +537,39 @@ def test_write_disposition(schema_storage: SchemaStorage) -> None: def test_compare_columns() -> None: - table = utils.new_table("test_table", columns=[ - {"name": "col1", "data_type": "text", "nullable": True}, - {"name": "col2", "data_type": "text", "nullable": False}, - {"name": "col3", "data_type": "timestamp", "nullable": True}, - {"name": "col4", "data_type": "timestamp", "nullable": True} - ]) - table2 = utils.new_table("test_table", columns=[ - {"name": "col1", "data_type": "text", "nullable": False} - ]) + table = utils.new_table( + "test_table", + columns=[ + {"name": "col1", "data_type": "text", "nullable": True}, + {"name": "col2", "data_type": "text", "nullable": False}, + {"name": "col3", "data_type": "timestamp", "nullable": True}, + {"name": "col4", "data_type": "timestamp", "nullable": True}, + ], + ) + table2 = utils.new_table( + "test_table", columns=[{"name": "col1", "data_type": "text", "nullable": False}] + ) # columns identical with self for c in table["columns"].values(): assert utils.compare_complete_columns(c, c) is True - assert utils.compare_complete_columns(table["columns"]["col3"], table["columns"]["col4"]) is False + assert ( + utils.compare_complete_columns(table["columns"]["col3"], table["columns"]["col4"]) is False + ) # data type may not differ - assert utils.compare_complete_columns(table["columns"]["col1"], table["columns"]["col3"]) is False + assert ( + utils.compare_complete_columns(table["columns"]["col1"], table["columns"]["col3"]) is False + ) # nullability may differ - assert utils.compare_complete_columns(table["columns"]["col1"], table2["columns"]["col1"]) is True + assert ( + utils.compare_complete_columns(table["columns"]["col1"], table2["columns"]["col1"]) is True + ) # any of the hints may differ for hint in COLUMN_HINTS: table["columns"]["col3"][hint] = True # name may not differ - assert 
utils.compare_complete_columns(table["columns"]["col3"], table["columns"]["col4"]) is False + assert ( + utils.compare_complete_columns(table["columns"]["col3"], table["columns"]["col4"]) is False + ) def test_normalize_table_identifiers() -> None: @@ -501,13 +580,18 @@ def test_normalize_table_identifiers() -> None: issues_table = deepcopy(schema.tables["issues"]) # this schema is already normalized so normalization is idempotent assert schema.tables["issues"] == schema.normalize_table_identifiers(issues_table) - assert schema.tables["issues"] == schema.normalize_table_identifiers(schema.normalize_table_identifiers(issues_table)) + assert schema.tables["issues"] == schema.normalize_table_identifiers( + schema.normalize_table_identifiers(issues_table) + ) def assert_new_schema_values_custom_normalizers(schema: Schema) -> None: # check normalizers config assert schema._normalizers_config["names"] == "tests.common.normalizers.custom_normalizers" - assert schema._normalizers_config["json"]["module"] == "tests.common.normalizers.custom_normalizers" + assert ( + schema._normalizers_config["json"]["module"] + == "tests.common.normalizers.custom_normalizers" + ) # check if schema was extended by json normalizer assert ["fake_id"] == schema.settings["default_hints"]["not_null"] # call normalizers @@ -529,13 +613,17 @@ def assert_new_schema_values(schema: Schema) -> None: assert schema.ENGINE_VERSION == 6 assert len(schema.settings["default_hints"]) > 0 # check settings - assert utils.standard_type_detections() == schema.settings["detections"] == schema._type_detections + assert ( + utils.standard_type_detections() == schema.settings["detections"] == schema._type_detections + ) # check normalizers config assert schema._normalizers_config["names"] == "snake_case" assert schema._normalizers_config["json"]["module"] == "dlt.common.normalizers.json.relational" assert isinstance(schema.naming, snake_case.NamingConvention) # check if schema was extended by json normalizer - assert set(["_dlt_id", "_dlt_root_id", "_dlt_parent_id", "_dlt_list_idx", "_dlt_load_id"]).issubset(schema.settings["default_hints"]["not_null"]) + assert set( + ["_dlt_id", "_dlt_root_id", "_dlt_parent_id", "_dlt_list_idx", "_dlt_load_id"] + ).issubset(schema.settings["default_hints"]["not_null"]) # call normalizers assert schema.naming.normalize_identifier("A") == "a" assert schema.naming.normalize_path("A__B") == "a__b" @@ -558,35 +646,62 @@ def test_group_tables_by_resource(schema: Schema) -> None: schema.update_schema(utils.new_table("b_events", columns=[])) schema.update_schema(utils.new_table("c_products", columns=[], resource="products")) schema.update_schema(utils.new_table("a_events__1", columns=[], parent_table_name="a_events")) - schema.update_schema(utils.new_table("a_events__1__2", columns=[], parent_table_name="a_events__1")) + schema.update_schema( + utils.new_table("a_events__1__2", columns=[], parent_table_name="a_events__1") + ) schema.update_schema(utils.new_table("b_events__1", columns=[], parent_table_name="b_events")) # All resources without filter expected_tables = { - "a_events": [schema.tables["a_events"], schema.tables["a_events__1"], schema.tables["a_events__1__2"]], + "a_events": [ + schema.tables["a_events"], + schema.tables["a_events__1"], + schema.tables["a_events__1__2"], + ], "b_events": [schema.tables["b_events"], schema.tables["b_events__1"]], "products": [schema.tables["c_products"]], "_dlt_version": [schema.tables["_dlt_version"]], - "_dlt_loads": [schema.tables["_dlt_loads"]] + "_dlt_loads": 
[schema.tables["_dlt_loads"]], } result = utils.group_tables_by_resource(schema.tables) assert result == expected_tables # With resource filter - result = utils.group_tables_by_resource(schema.tables, pattern=utils.compile_simple_regex(TSimpleRegex("re:[a-z]_events"))) + result = utils.group_tables_by_resource( + schema.tables, pattern=utils.compile_simple_regex(TSimpleRegex("re:[a-z]_events")) + ) assert result == { - "a_events": [schema.tables["a_events"], schema.tables["a_events__1"], schema.tables["a_events__1__2"]], + "a_events": [ + schema.tables["a_events"], + schema.tables["a_events__1"], + schema.tables["a_events__1__2"], + ], "b_events": [schema.tables["b_events"], schema.tables["b_events__1"]], } # With resources that has many top level tables schema.update_schema(utils.new_table("mc_products", columns=[], resource="products")) - schema.update_schema(utils.new_table("mc_products__sub", columns=[], parent_table_name="mc_products")) - result = utils.group_tables_by_resource(schema.tables, pattern=utils.compile_simple_regex(TSimpleRegex("products"))) + schema.update_schema( + utils.new_table("mc_products__sub", columns=[], parent_table_name="mc_products") + ) + result = utils.group_tables_by_resource( + schema.tables, pattern=utils.compile_simple_regex(TSimpleRegex("products")) + ) # both tables with resource "products" must be here - assert result == {'products': [ - {'columns': {}, 'name': 'c_products', 'resource': 'products', 'write_disposition': 'append'}, - {'columns': {}, 'name': 'mc_products', 'resource': 'products', 'write_disposition': 'append'}, - {'columns': {}, 'name': 'mc_products__sub', 'parent': 'mc_products'} + assert result == { + "products": [ + { + "columns": {}, + "name": "c_products", + "resource": "products", + "write_disposition": "append", + }, + { + "columns": {}, + "name": "mc_products", + "resource": "products", + "write_disposition": "append", + }, + {"columns": {}, "name": "mc_products__sub", "parent": "mc_products"}, ] } diff --git a/tests/common/schema/test_versioning.py b/tests/common/schema/test_versioning.py index b535634ef4..ab33ba7c63 100644 --- a/tests/common/schema/test_versioning.py +++ b/tests/common/schema/test_versioning.py @@ -1,13 +1,12 @@ import pytest import yaml +from tests.common.utils import load_json_case, load_yml_case from dlt.common import json from dlt.common.schema import utils from dlt.common.schema.schema import Schema from dlt.common.schema.typing import TStoredSchema -from tests.common.utils import load_json_case, load_yml_case - def test_content_hash() -> None: eth_v4: TStoredSchema = load_yml_case("schemas/eth/ethereum_schema_v4") diff --git a/tests/common/scripts/args.py b/tests/common/scripts/args.py index 627daeb76b..67c6cc651a 100644 --- a/tests/common/scripts/args.py +++ b/tests/common/scripts/args.py @@ -1,4 +1,4 @@ import sys print(len(sys.argv)) -print(sys.argv) \ No newline at end of file +print(sys.argv) diff --git a/tests/common/scripts/counter.py b/tests/common/scripts/counter.py index 99352cd1f3..afe315cf59 100644 --- a/tests/common/scripts/counter.py +++ b/tests/common/scripts/counter.py @@ -1,9 +1,8 @@ import sys from time import sleep - for i in range(5): print(i) sys.stdout.flush() sleep(0.3) -print("exit") \ No newline at end of file +print("exit") diff --git a/tests/common/scripts/cwd.py b/tests/common/scripts/cwd.py index 404cf43ada..ea065561f3 100644 --- a/tests/common/scripts/cwd.py +++ b/tests/common/scripts/cwd.py @@ -1,3 +1,3 @@ import os -print(os.getcwd()) \ No newline at end of file 
+print(os.getcwd()) diff --git a/tests/common/scripts/long_lines.py b/tests/common/scripts/long_lines.py index ca5469cd4c..0d22c692ba 100644 --- a/tests/common/scripts/long_lines.py +++ b/tests/common/scripts/long_lines.py @@ -10,4 +10,4 @@ # without new lines print(line_b, file=sys.stderr, end="") -print(line_a, end="") \ No newline at end of file +print(line_a, end="") diff --git a/tests/common/scripts/long_lines_fails.py b/tests/common/scripts/long_lines_fails.py index 0633f078e0..37e2f13e31 100644 --- a/tests/common/scripts/long_lines_fails.py +++ b/tests/common/scripts/long_lines_fails.py @@ -11,4 +11,4 @@ # without new lines print(line_b, file=sys.stderr, end="") print(line_a, end="") -exit(-1) \ No newline at end of file +exit(-1) diff --git a/tests/common/scripts/no_stdout_exception.py b/tests/common/scripts/no_stdout_exception.py index 90c71a4551..75bebd8cc7 100644 --- a/tests/common/scripts/no_stdout_exception.py +++ b/tests/common/scripts/no_stdout_exception.py @@ -1 +1 @@ -raise Exception("no stdout") \ No newline at end of file +raise Exception("no stdout") diff --git a/tests/common/scripts/no_stdout_no_stderr_with_fail.py b/tests/common/scripts/no_stdout_no_stderr_with_fail.py index 8e7ef7e83f..d0d1c88de8 100644 --- a/tests/common/scripts/no_stdout_no_stderr_with_fail.py +++ b/tests/common/scripts/no_stdout_no_stderr_with_fail.py @@ -1 +1 @@ -exit(-1) \ No newline at end of file +exit(-1) diff --git a/tests/common/scripts/raising_counter.py b/tests/common/scripts/raising_counter.py index 74c9a53b20..0212c68daa 100644 --- a/tests/common/scripts/raising_counter.py +++ b/tests/common/scripts/raising_counter.py @@ -1,11 +1,10 @@ import sys from time import sleep - for i in range(5): print(i) # sys.stdout.flush() if i == 2: raise Exception("end") sleep(0.3) -print("exit") \ No newline at end of file +print("exit") diff --git a/tests/common/scripts/stdout_encode_exception.py b/tests/common/scripts/stdout_encode_exception.py index 57658d431b..1de37c8d43 100644 --- a/tests/common/scripts/stdout_encode_exception.py +++ b/tests/common/scripts/stdout_encode_exception.py @@ -1,15 +1,15 @@ from functools import partial -from dlt.common.exceptions import UnsupportedProcessStartMethodException +from dlt.common.exceptions import UnsupportedProcessStartMethodException from dlt.common.runners import TRunMetrics from dlt.common.runners.stdout import exec_to_stdout - def worker(data1, data2): print("in func") raise UnsupportedProcessStartMethodException("this") + f = partial(worker, "this is string", TRunMetrics(True, 300)) with exec_to_stdout(f) as rv: print(rv) diff --git a/tests/common/scripts/stdout_encode_result.py b/tests/common/scripts/stdout_encode_result.py index b399734a4d..51c9b553db 100644 --- a/tests/common/scripts/stdout_encode_result.py +++ b/tests/common/scripts/stdout_encode_result.py @@ -8,6 +8,7 @@ def worker(data1, data2): print("in func") return data1, data2 + f = partial(worker, "this is string", TRunMetrics(True, 300)) with exec_to_stdout(f) as rv: print(rv) diff --git a/tests/common/storages/test_file_storage.py b/tests/common/storages/test_file_storage.py index 48e743d575..c31f7d955a 100644 --- a/tests/common/storages/test_file_storage.py +++ b/tests/common/storages/test_file_storage.py @@ -1,14 +1,14 @@ import gzip import os import stat -import pytest from pathlib import Path +import pytest +from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage, skipifnotwindows, test_storage + from dlt.common.storages.file_storage import FileStorage from dlt.common.utils 
import encoding_for_mode, set_working_dir, uniq_id -from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage, test_storage, skipifnotwindows - def test_storage_init(test_storage: FileStorage) -> None: # must be absolute path @@ -68,7 +68,10 @@ def test_in_storage(test_storage: FileStorage) -> None: assert test_storage.in_storage(".") is True assert test_storage.in_storage(os.curdir) is True assert test_storage.in_storage(os.path.realpath(os.curdir)) is False - assert test_storage.in_storage(os.path.join(os.path.realpath(os.curdir), TEST_STORAGE_ROOT)) is True + assert ( + test_storage.in_storage(os.path.join(os.path.realpath(os.curdir), TEST_STORAGE_ROOT)) + is True + ) def test_from_wd_to_relative_path(test_storage: FileStorage) -> None: @@ -128,31 +131,31 @@ def test_validate_file_name_component() -> None: @pytest.mark.parametrize("action", ("rename_tree_files", "rename_tree", "atomic_rename")) def test_rename_nested_tree(test_storage: FileStorage, action: str) -> None: - source_dir = os.path.join(test_storage.storage_path, 'source') - nested_dir_1 = os.path.join(source_dir, 'nested1') - nested_dir_2 = os.path.join(nested_dir_1, 'nested2') - empty_dir = os.path.join(source_dir, 'empty') + source_dir = os.path.join(test_storage.storage_path, "source") + nested_dir_1 = os.path.join(source_dir, "nested1") + nested_dir_2 = os.path.join(nested_dir_1, "nested2") + empty_dir = os.path.join(source_dir, "empty") os.makedirs(nested_dir_2) os.makedirs(empty_dir) - with open(os.path.join(source_dir, 'test1.txt'), 'w', encoding="utf-8") as f: - f.write('test') - with open(os.path.join(nested_dir_1, 'test2.txt'), 'w', encoding="utf-8") as f: - f.write('test') - with open(os.path.join(nested_dir_2, 'test3.txt'), 'w', encoding="utf-8") as f: - f.write('test') + with open(os.path.join(source_dir, "test1.txt"), "w", encoding="utf-8") as f: + f.write("test") + with open(os.path.join(nested_dir_1, "test2.txt"), "w", encoding="utf-8") as f: + f.write("test") + with open(os.path.join(nested_dir_2, "test3.txt"), "w", encoding="utf-8") as f: + f.write("test") - dest_dir = os.path.join(test_storage.storage_path, 'dest') + dest_dir = os.path.join(test_storage.storage_path, "dest") getattr(test_storage, action)(source_dir, dest_dir) assert not os.path.exists(source_dir) assert os.path.exists(dest_dir) - assert os.path.exists(os.path.join(dest_dir, 'nested1')) - assert os.path.exists(os.path.join(dest_dir, 'nested1', 'nested2')) - assert os.path.exists(os.path.join(dest_dir, 'empty')) - assert os.path.exists(os.path.join(dest_dir, 'test1.txt')) - assert os.path.exists(os.path.join(dest_dir, 'nested1', 'test2.txt')) - assert os.path.exists(os.path.join(dest_dir, 'nested1', 'nested2', 'test3.txt')) + assert os.path.exists(os.path.join(dest_dir, "nested1")) + assert os.path.exists(os.path.join(dest_dir, "nested1", "nested2")) + assert os.path.exists(os.path.join(dest_dir, "empty")) + assert os.path.exists(os.path.join(dest_dir, "test1.txt")) + assert os.path.exists(os.path.join(dest_dir, "nested1", "test2.txt")) + assert os.path.exists(os.path.join(dest_dir, "nested1", "nested2", "test3.txt")) @skipifnotwindows diff --git a/tests/common/storages/test_loader_storage.py b/tests/common/storages/test_loader_storage.py index 905e6cfcdb..c56cb3a9f9 100644 --- a/tests/common/storages/test_loader_storage.py +++ b/tests/common/storages/test_loader_storage.py @@ -1,19 +1,24 @@ import os -import pytest from pathlib import Path from typing import Sequence, Tuple -from dlt.common import sleep, json, pendulum -from 
dlt.common.schema import Schema, TSchemaTables -from dlt.common.storages.load_storage import LoadPackageInfo, LoadStorage, ParsedLoadJobFileName, TJobState +import pytest +from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage, write_version + +from dlt.common import json, pendulum, sleep from dlt.common.configuration import resolve_configuration +from dlt.common.schema import Schema, TSchemaTables from dlt.common.storages import LoadStorageConfiguration from dlt.common.storages.exceptions import LoadPackageNotFound, NoMigrationPathException +from dlt.common.storages.load_storage import ( + LoadPackageInfo, + LoadStorage, + ParsedLoadJobFileName, + TJobState, +) from dlt.common.typing import StrAny from dlt.common.utils import uniq_id -from tests.utils import TEST_STORAGE_ROOT, write_version, autouse_test_storage - @pytest.fixture def storage() -> LoadStorage: @@ -34,9 +39,15 @@ def test_complete_successful_package(storage: LoadStorage) -> None: assert not storage.storage.has_folder(storage.get_package_path(load_id)) # has package assert storage.storage.has_folder(storage.get_completed_package_path(load_id)) - assert storage.storage.has_file(os.path.join(storage.get_completed_package_path(load_id), LoadStorage.PACKAGE_COMPLETED_FILE_NAME)) + assert storage.storage.has_file( + os.path.join( + storage.get_completed_package_path(load_id), LoadStorage.PACKAGE_COMPLETED_FILE_NAME + ) + ) # but completed packages are deleted - assert not storage.storage.has_folder(storage._get_job_folder_completed_path(load_id, "completed_jobs")) + assert not storage.storage.has_folder( + storage._get_job_folder_completed_path(load_id, "completed_jobs") + ) assert_package_info(storage, load_id, "loaded", "completed_jobs", jobs_count=0) # delete completed package storage.delete_completed_package(load_id) @@ -50,9 +61,15 @@ def test_complete_successful_package(storage: LoadStorage) -> None: assert not storage.storage.has_folder(storage.get_package_path(load_id)) # has load preserved assert storage.storage.has_folder(storage.get_completed_package_path(load_id)) - assert storage.storage.has_file(os.path.join(storage.get_completed_package_path(load_id), LoadStorage.PACKAGE_COMPLETED_FILE_NAME)) + assert storage.storage.has_file( + os.path.join( + storage.get_completed_package_path(load_id), LoadStorage.PACKAGE_COMPLETED_FILE_NAME + ) + ) # has completed loads - assert storage.storage.has_folder(storage._get_job_folder_completed_path(load_id, "completed_jobs")) + assert storage.storage.has_folder( + storage._get_job_folder_completed_path(load_id, "completed_jobs") + ) storage.delete_completed_package(load_id) assert not storage.storage.has_folder(storage.get_completed_package_path(load_id)) @@ -78,7 +95,9 @@ def test_complete_package_failed_jobs(storage: LoadStorage) -> None: # present in completed loads folder assert storage.storage.has_folder(storage.get_completed_package_path(load_id)) # has completed loads - assert storage.storage.has_folder(storage._get_job_folder_completed_path(load_id, "completed_jobs")) + assert storage.storage.has_folder( + storage._get_job_folder_completed_path(load_id, "completed_jobs") + ) assert_package_info(storage, load_id, "loaded", "failed_jobs") # get failed jobs info @@ -109,7 +128,9 @@ def test_abort_package(storage: LoadStorage) -> None: storage.fail_job(load_id, file_name, "EXCEPTION") assert_package_info(storage, load_id, "normalized", "failed_jobs") storage.complete_load_package(load_id, True) - assert 
storage.storage.has_folder(storage._get_job_folder_completed_path(load_id, "completed_jobs")) + assert storage.storage.has_folder( + storage._get_job_folder_completed_path(load_id, "completed_jobs") + ) assert_package_info(storage, load_id, "aborted", "failed_jobs") @@ -120,8 +141,10 @@ def test_save_load_schema(storage: LoadStorage) -> None: storage.create_temp_load_package("copy") saved_file_name = storage.save_temp_schema(schema, "copy") - assert saved_file_name.endswith(os.path.join(storage.storage.storage_path, "copy", LoadStorage.SCHEMA_FILE_NAME)) - assert storage.storage.has_file(os.path.join("copy",LoadStorage.SCHEMA_FILE_NAME)) + assert saved_file_name.endswith( + os.path.join(storage.storage.storage_path, "copy", LoadStorage.SCHEMA_FILE_NAME) + ) + assert storage.storage.has_file(os.path.join("copy", LoadStorage.SCHEMA_FILE_NAME)) schema_copy = storage.load_temp_schema("copy") assert schema.stored_version == schema_copy.stored_version @@ -195,7 +218,9 @@ def test_process_schema_update(storage: LoadStorage) -> None: storage.commit_schema_update(load_id, applied_update) assert storage.begin_schema_update(load_id) is None # processed file exists - applied_update_path = os.path.join(storage.get_package_path(load_id), LoadStorage.APPLIED_SCHEMA_UPDATES_FILE_NAME) + applied_update_path = os.path.join( + storage.get_package_path(load_id), LoadStorage.APPLIED_SCHEMA_UPDATES_FILE_NAME + ) assert storage.storage.has_file(applied_update_path) is True assert json.loads(storage.storage.load(applied_update_path)) == applied_update # verify info package @@ -252,7 +277,13 @@ def start_loading_file(s: LoadStorage, content: Sequence[StrAny]) -> Tuple[str, return load_id, file_name -def assert_package_info(storage: LoadStorage, load_id: str, package_state: str, job_state: TJobState, jobs_count: int = 1) -> LoadPackageInfo: +def assert_package_info( + storage: LoadStorage, + load_id: str, + package_state: str, + job_state: TJobState, + jobs_count: int = 1, +) -> LoadPackageInfo: package_info = storage.get_load_package_info(load_id) # make sure it is serializable json.dumps(package_info) diff --git a/tests/common/storages/test_normalize_storage.py b/tests/common/storages/test_normalize_storage.py index 678e1e49fe..c41c44354c 100644 --- a/tests/common/storages/test_normalize_storage.py +++ b/tests/common/storages/test_normalize_storage.py @@ -1,11 +1,10 @@ import pytest +from tests.utils import autouse_test_storage, write_version -from dlt.common.utils import uniq_id from dlt.common.storages import NormalizeStorage, NormalizeStorageConfiguration from dlt.common.storages.exceptions import NoMigrationPathException from dlt.common.storages.normalize_storage import TParsedNormalizeFileName - -from tests.utils import write_version, autouse_test_storage +from dlt.common.utils import uniq_id @pytest.mark.skip() @@ -16,13 +15,20 @@ def test_load_events_and_group_by_sender() -> None: def test_build_extracted_file_name() -> None: load_id = uniq_id() - name = NormalizeStorage.build_extracted_file_stem("event", "table_with_parts__many", load_id) + ".jsonl" + name = ( + NormalizeStorage.build_extracted_file_stem("event", "table_with_parts__many", load_id) + + ".jsonl" + ) assert NormalizeStorage.get_schema_name(name) == "event" - assert NormalizeStorage.parse_normalize_file_name(name) == TParsedNormalizeFileName("event", "table_with_parts__many", load_id) + assert NormalizeStorage.parse_normalize_file_name(name) == TParsedNormalizeFileName( + "event", "table_with_parts__many", load_id + ) # empty schema should be 
supported name = NormalizeStorage.build_extracted_file_stem("", "table", load_id) + ".jsonl" - assert NormalizeStorage.parse_normalize_file_name(name) == TParsedNormalizeFileName("", "table", load_id) + assert NormalizeStorage.parse_normalize_file_name(name) == TParsedNormalizeFileName( + "", "table", load_id + ) def test_full_migration_path() -> None: diff --git a/tests/common/storages/test_schema_storage.py b/tests/common/storages/test_schema_storage.py index 078af856cb..6dbe21f72b 100644 --- a/tests/common/storages/test_schema_storage.py +++ b/tests/common/storages/test_schema_storage.py @@ -1,17 +1,31 @@ import os import shutil + import pytest import yaml -from dlt.common import json +from tests.common.utils import ( + COMMON_TEST_CASES_PATH, + IMPORTED_VERSION_HASH_ETH_V6, + load_yml_case, + yml_case_path, +) +from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage +from dlt.common import json from dlt.common.schema.schema import Schema from dlt.common.schema.typing import TStoredSchema from dlt.common.schema.utils import explicit_normalizers -from dlt.common.storages.exceptions import InStorageSchemaModified, SchemaNotFoundError, UnexpectedSchemaName -from dlt.common.storages import SchemaStorageConfiguration, SchemaStorage, LiveSchemaStorage, FileStorage - -from tests.utils import autouse_test_storage, TEST_STORAGE_ROOT -from tests.common.utils import load_yml_case, yml_case_path, COMMON_TEST_CASES_PATH, IMPORTED_VERSION_HASH_ETH_V6 +from dlt.common.storages import ( + FileStorage, + LiveSchemaStorage, + SchemaStorage, + SchemaStorageConfiguration, +) +from dlt.common.storages.exceptions import ( + InStorageSchemaModified, + SchemaNotFoundError, + UnexpectedSchemaName, +) @pytest.fixture @@ -22,13 +36,23 @@ def storage() -> SchemaStorage: @pytest.fixture def synced_storage() -> SchemaStorage: # will be created in /schemas - return init_storage(SchemaStorageConfiguration(import_schema_path=TEST_STORAGE_ROOT + "/import", export_schema_path=TEST_STORAGE_ROOT + "/import")) + return init_storage( + SchemaStorageConfiguration( + import_schema_path=TEST_STORAGE_ROOT + "/import", + export_schema_path=TEST_STORAGE_ROOT + "/import", + ) + ) @pytest.fixture def ie_storage() -> SchemaStorage: # will be created in /schemas - return init_storage(SchemaStorageConfiguration(import_schema_path=TEST_STORAGE_ROOT + "/import", export_schema_path=TEST_STORAGE_ROOT + "/export")) + return init_storage( + SchemaStorageConfiguration( + import_schema_path=TEST_STORAGE_ROOT + "/import", + export_schema_path=TEST_STORAGE_ROOT + "/export", + ) + ) def init_storage(C: SchemaStorageConfiguration) -> SchemaStorage: @@ -49,7 +73,9 @@ def test_load_non_existing(storage: SchemaStorage) -> None: def test_load_schema_with_upgrade() -> None: # point the storage root to v4 schema google_spreadsheet_v3.schema - storage = LiveSchemaStorage(SchemaStorageConfiguration(COMMON_TEST_CASES_PATH + "schemas/sheets")) + storage = LiveSchemaStorage( + SchemaStorageConfiguration(COMMON_TEST_CASES_PATH + "schemas/sheets") + ) # the hash when computed on the schema does not match the version_hash in the file so it should raise InStorageSchemaModified # but because the version upgrade is required, the check is skipped and the load succeeds storage.load_schema("google_spreadsheet_v4") @@ -64,7 +90,9 @@ def test_import_initial(synced_storage: SchemaStorage, storage: SchemaStorage) - assert_schema_imported(synced_storage, storage) -def test_import_overwrites_existing_if_modified(synced_storage: SchemaStorage, storage: 
SchemaStorage) -> None: +def test_import_overwrites_existing_if_modified( + synced_storage: SchemaStorage, storage: SchemaStorage +) -> None: schema = Schema("ethereum") storage.save_schema(schema) # now import schema that wil overwrite schema in storage as it is not linked to external schema @@ -235,28 +263,43 @@ def test_save_store_schema(storage: SchemaStorage) -> None: d_n["names"] = "tests.common.normalizers.custom_normalizers" schema = Schema("column_event", normalizers=d_n) storage.save_schema(schema) - assert storage.storage.has_file(SchemaStorage.NAMED_SCHEMA_FILE_PATTERN % ("column_event", "json")) + assert storage.storage.has_file( + SchemaStorage.NAMED_SCHEMA_FILE_PATTERN % ("column_event", "json") + ) loaded_schema = storage.load_schema("column_event") # also tables gets normalized inside so custom_ is added - assert loaded_schema.to_dict()["tables"]["column__dlt_loads"] == schema.to_dict()["tables"]["column__dlt_loads"] + assert ( + loaded_schema.to_dict()["tables"]["column__dlt_loads"] + == schema.to_dict()["tables"]["column__dlt_loads"] + ) assert loaded_schema.to_dict() == schema.to_dict() def test_schema_from_file() -> None: # json has precedence - schema = SchemaStorage.load_schema_file(os.path.join(COMMON_TEST_CASES_PATH, "schemas/local"), "event") + schema = SchemaStorage.load_schema_file( + os.path.join(COMMON_TEST_CASES_PATH, "schemas/local"), "event" + ) assert schema.name == "event" - schema = SchemaStorage.load_schema_file(os.path.join(COMMON_TEST_CASES_PATH, "schemas/local"), "event", extensions=("yaml",)) + schema = SchemaStorage.load_schema_file( + os.path.join(COMMON_TEST_CASES_PATH, "schemas/local"), "event", extensions=("yaml",) + ) assert schema.name == "event" assert "blocks" in schema.tables with pytest.raises(SchemaNotFoundError): - SchemaStorage.load_schema_file(os.path.join(COMMON_TEST_CASES_PATH, "schemas/local"), "eth", extensions=("yaml",)) + SchemaStorage.load_schema_file( + os.path.join(COMMON_TEST_CASES_PATH, "schemas/local"), "eth", extensions=("yaml",) + ) # file name and schema content mismatch with pytest.raises(UnexpectedSchemaName): - SchemaStorage.load_schema_file(os.path.join(COMMON_TEST_CASES_PATH, "schemas/local"), "name_mismatch", extensions=("yaml",)) + SchemaStorage.load_schema_file( + os.path.join(COMMON_TEST_CASES_PATH, "schemas/local"), + "name_mismatch", + extensions=("yaml",), + ) # def test_save_empty_schema_name(storage: SchemaStorage) -> None: @@ -269,7 +312,10 @@ def test_schema_from_file() -> None: def prepare_import_folder(storage: SchemaStorage) -> None: - shutil.copy(yml_case_path("schemas/eth/ethereum_schema_v6"), os.path.join(storage.storage.storage_path, "../import/ethereum.schema.yaml")) + shutil.copy( + yml_case_path("schemas/eth/ethereum_schema_v6"), + os.path.join(storage.storage.storage_path, "../import/ethereum.schema.yaml"), + ) def assert_schema_imported(synced_storage: SchemaStorage, storage: SchemaStorage) -> Schema: diff --git a/tests/common/storages/test_transactional_file.py b/tests/common/storages/test_transactional_file.py index 9d3d735b9c..97b347a8df 100644 --- a/tests/common/storages/test_transactional_file.py +++ b/tests/common/storages/test_transactional_file.py @@ -5,11 +5,10 @@ import fsspec import pytest +from tests.utils import skipifwindows from dlt.common.storages.transactional_file import TransactionalFile -from tests.utils import skipifwindows - @pytest.fixture(scope="session") def fs() -> fsspec.AbstractFileSystem: @@ -107,7 +106,9 @@ def test_file_transaction_multiple_writers(fs: 
fsspec.AbstractFileSystem, file_n assert writer_2.read() == b"test 4" -def test_file_transaction_multiple_writers_with_races(fs: fsspec.AbstractFileSystem, file_name: str): +def test_file_transaction_multiple_writers_with_races( + fs: fsspec.AbstractFileSystem, file_name: str +): writer_1 = TransactionalFile(file_name, fs) time.sleep(0.5) writer_2 = TransactionalFile(file_name, fs) @@ -127,8 +128,10 @@ def test_file_transaction_simultaneous(fs: fsspec.AbstractFileSystem): pool = ThreadPoolExecutor(max_workers=40) results = pool.map( - lambda _: TransactionalFile( - "/bucket/test_123", fs).acquire_lock(blocking=False, jitter_mean=0.3), range(200) + lambda _: TransactionalFile("/bucket/test_123", fs).acquire_lock( + blocking=False, jitter_mean=0.3 + ), + range(200), ) assert sum(results) == 1 diff --git a/tests/common/storages/test_versioned_storage.py b/tests/common/storages/test_versioned_storage.py index ff23480a48..06858f6d87 100644 --- a/tests/common/storages/test_versioned_storage.py +++ b/tests/common/storages/test_versioned_storage.py @@ -1,15 +1,16 @@ import pytest import semver +from tests.utils import test_storage, write_version -from dlt.common.storages.file_storage import FileStorage from dlt.common.storages.exceptions import NoMigrationPathException, WrongStorageVersionException +from dlt.common.storages.file_storage import FileStorage from dlt.common.storages.versioned_storage import VersionedStorage -from tests.utils import write_version, test_storage - class MigratedStorage(VersionedStorage): - def migrate_storage(self, from_version: semver.VersionInfo, to_version: semver.VersionInfo) -> None: + def migrate_storage( + self, from_version: semver.VersionInfo, to_version: semver.VersionInfo + ) -> None: # migration example: if from_version == "1.0.0" and from_version < to_version: from_version = semver.VersionInfo.parse("1.1.0") @@ -56,4 +57,4 @@ def test_downgrade_not_possible(test_storage: FileStorage) -> None: write_version(test_storage, "1.2.0") with pytest.raises(NoMigrationPathException) as wmpe: MigratedStorage("1.1.0", True, test_storage) - assert wmpe.value.migrated_version == "1.2.0" \ No newline at end of file + assert wmpe.value.migrated_version == "1.2.0" diff --git a/tests/common/test_arithmetics.py b/tests/common/test_arithmetics.py index 4912d976eb..e3f4f43581 100644 --- a/tests/common/test_arithmetics.py +++ b/tests/common/test_arithmetics.py @@ -1,6 +1,7 @@ import pytest + from dlt.common import Decimal -from dlt.common.arithmetics import numeric_default_context, numeric_default_quantize, Inexact +from dlt.common.arithmetics import Inexact, numeric_default_context, numeric_default_quantize def test_default_numeric_quantize() -> None: @@ -18,7 +19,6 @@ def test_default_numeric_quantize() -> None: scale_18 = Decimal("0.5327010784") assert str(numeric_default_quantize(scale_18)) == "0.532701078" - # less than 0 digits scale_5 = Decimal("0.4") assert str(numeric_default_quantize(scale_5)) == "0.400000000" @@ -27,7 +27,7 @@ def test_default_numeric_quantize() -> None: def test_numeric_context() -> None: # we reach (38,9) numeric with numeric_default_context(): - v = Decimal(10**29-1) + Decimal("0.532701079") + v = Decimal(10**29 - 1) + Decimal("0.532701079") assert str(v) == "99999999999999999999999999999.532701079" assert numeric_default_quantize(v) == v diff --git a/tests/common/test_data_writers/test_buffered_writer.py b/tests/common/test_data_writers/test_buffered_writer.py index e05011b1e8..a2ab5af696 100644 --- 
a/tests/common/test_data_writers/test_buffered_writer.py +++ b/tests/common/test_data_writers/test_buffered_writer.py @@ -1,24 +1,33 @@ +import datetime # noqa: 251 import os + import pytest -from dlt.common.arithmetics import Decimal +from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage, write_version +from dlt.common.arithmetics import Decimal from dlt.common.data_writers.buffered import BufferedDataWriter from dlt.common.data_writers.exceptions import BufferedDataWriterClosed -from dlt.common.destination import TLoaderFileFormat, DestinationCapabilitiesContext +from dlt.common.destination import DestinationCapabilitiesContext, TLoaderFileFormat from dlt.common.schema.utils import new_column from dlt.common.storages.file_storage import FileStorage - from dlt.common.typing import DictStrAny -from tests.utils import TEST_STORAGE_ROOT, write_version, autouse_test_storage -import datetime # noqa: 251 - -def get_insert_writer(_format: TLoaderFileFormat = "insert_values", buffer_max_items: int = 10, disable_compression: bool = False) -> BufferedDataWriter: +def get_insert_writer( + _format: TLoaderFileFormat = "insert_values", + buffer_max_items: int = 10, + disable_compression: bool = False, +) -> BufferedDataWriter: caps = DestinationCapabilitiesContext.generic_capabilities() caps.preferred_loader_file_format = _format file_template = os.path.join(TEST_STORAGE_ROOT, f"{_format}.%s") - return BufferedDataWriter(_format, file_template, buffer_max_items=buffer_max_items, disable_compression=disable_compression, _caps=caps) + return BufferedDataWriter( + _format, + file_template, + buffer_max_items=buffer_max_items, + disable_compression=disable_compression, + _caps=caps, + ) def test_write_no_item() -> None: @@ -31,9 +40,10 @@ def test_write_no_item() -> None: assert writer.closed_files == [] -@pytest.mark.parametrize("disable_compression", [True, False], ids=["no_compression", "compression"]) +@pytest.mark.parametrize( + "disable_compression", [True, False], ids=["no_compression", "compression"] +) def test_rotation_on_schema_change(disable_compression: bool) -> None: - c1 = new_column("col1", "bigint") c2 = new_column("col2", "bigint") c3 = new_column("col3", "text") @@ -46,7 +56,7 @@ def c1_doc(count: int) -> DictStrAny: return map(lambda x: {"col1": x}, range(0, count)) def c2_doc(count: int) -> DictStrAny: - return map(lambda x: {"col1": x, "col2": x*2+1}, range(0, count)) + return map(lambda x: {"col1": x, "col2": x * 2 + 1}, range(0, count)) def c3_doc(count: int) -> DictStrAny: return map(lambda x: {"col3": "col3_value"}, range(0, count)) @@ -122,7 +132,9 @@ def c3_doc(count: int) -> DictStrAny: assert "(col3_value" in content[-1] -@pytest.mark.parametrize("disable_compression", [True, False], ids=["no_compression", "compression"]) +@pytest.mark.parametrize( + "disable_compression", [True, False], ids=["no_compression", "compression"] +) def test_NO_rotation_on_schema_change(disable_compression: bool) -> None: c1 = new_column("col1", "bigint") c2 = new_column("col2", "bigint") @@ -134,7 +146,7 @@ def c1_doc(count: int) -> DictStrAny: return map(lambda x: {"col1": x}, range(0, count)) def c2_doc(count: int) -> DictStrAny: - return map(lambda x: {"col1": x, "col2": x*2+1}, range(0, count)) + return map(lambda x: {"col1": x, "col2": x * 2 + 1}, range(0, count)) # change schema before file first flush with get_insert_writer(_format="jsonl", disable_compression=disable_compression) as writer: @@ -152,7 +164,9 @@ def c2_doc(count: int) -> DictStrAny: assert content[-1] == 
'{"col1":1,"col2":3}\n' -@pytest.mark.parametrize("disable_compression", [True, False], ids=["no_compression", "compression"]) +@pytest.mark.parametrize( + "disable_compression", [True, False], ids=["no_compression", "compression"] +) def test_writer_requiring_schema(disable_compression: bool) -> None: # assertion on flushing with pytest.raises(AssertionError): @@ -166,9 +180,10 @@ def test_writer_requiring_schema(disable_compression: bool) -> None: writer.write_data_item([{"col1": 1}], t1) -@pytest.mark.parametrize("disable_compression", [True, False], ids=["no_compression", "compression"]) +@pytest.mark.parametrize( + "disable_compression", [True, False], ids=["no_compression", "compression"] +) def test_writer_optional_schema(disable_compression: bool) -> None: with get_insert_writer(_format="jsonl", disable_compression=disable_compression) as writer: - writer.write_data_item([{"col1": 1}], None) - writer.write_data_item([{"col1": 1}], None) - + writer.write_data_item([{"col1": 1}], None) + writer.write_data_item([{"col1": 1}], None) diff --git a/tests/common/test_data_writers/test_data_writers.py b/tests/common/test_data_writers/test_data_writers.py index 36b6e4b6ec..acbebffd04 100644 --- a/tests/common/test_data_writers/test_data_writers.py +++ b/tests/common/test_data_writers/test_data_writers.py @@ -1,15 +1,27 @@ import io -import pytest from typing import Iterator -from dlt.common import pendulum, json +import pytest +from tests.common.utils import load_json_case, row_to_column_schemas + +from dlt.common import json, pendulum +from dlt.common.data_writers.escape import ( + escape_bigquery_identifier, + escape_duckdb_literal, + escape_postgres_literal, + escape_redshift_identifier, + escape_redshift_literal, +) +from dlt.common.data_writers.writers import ( + DataWriter, + InsertValuesWriter, + JsonlWriter, + ParquetDataWriter, +) from dlt.common.typing import AnyFun + # from dlt.destinations.postgres import capabilities from dlt.destinations.redshift import capabilities as redshift_caps -from dlt.common.data_writers.escape import escape_redshift_identifier, escape_bigquery_identifier, escape_redshift_literal, escape_postgres_literal, escape_duckdb_literal -from dlt.common.data_writers.writers import DataWriter, InsertValuesWriter, JsonlWriter, ParquetDataWriter - -from tests.common.utils import load_json_case, row_to_column_schemas ALL_LITERAL_ESCAPE = [escape_redshift_literal, escape_postgres_literal, escape_duckdb_literal] @@ -41,7 +53,7 @@ def test_simple_jsonl_writer(jsonl_writer: DataWriter) -> None: jsonl_writer.write_all(None, rows) # remove b'' at the end lines = jsonl_writer._f.getvalue().split(b"\n") - assert lines[-1] == b'' + assert lines[-1] == b"" assert len(lines) == 3 @@ -86,13 +98,22 @@ def test_string_literal_escape() -> None: assert escape_redshift_literal(", NULL'); DROP TABLE --") == "', NULL''); DROP TABLE --'" assert escape_redshift_literal(", NULL');\n DROP TABLE --") == "', NULL'');\\n DROP TABLE --'" assert escape_redshift_literal(", NULL);\n DROP TABLE --") == "', NULL);\\n DROP TABLE --'" - assert escape_redshift_literal(", NULL);\\n DROP TABLE --\\") == "', NULL);\\\\n DROP TABLE --\\\\'" + assert ( + escape_redshift_literal(", NULL);\\n DROP TABLE --\\") + == "', NULL);\\\\n DROP TABLE --\\\\'" + ) # assert escape_redshift_literal(b'hello_word') == "\\x68656c6c6f5f776f7264" @pytest.mark.parametrize("escaper", ALL_LITERAL_ESCAPE) def test_string_complex_escape(escaper: AnyFun) -> None: - doc = {"complex":[1,2,3,"a"], "link": 
"?commen\ntU\nrn=urn%3Ali%3Acomment%3A%28acti\0xA \0x0 \\vity%3A69'08444473\n\n551163392%2C6n \r \x8e9085"} + doc = { + "complex": [1, 2, 3, "a"], + "link": ( + "?commen\ntU\nrn=urn%3Ali%3Acomment%3A%28acti\0xA \0x0" + " \\vity%3A69'08444473\n\n551163392%2C6n \r \x8e9085" + ), + } escaped = escaper(doc) # should be same as string escape if escaper == escape_redshift_literal: @@ -102,16 +123,28 @@ def test_string_complex_escape(escaper: AnyFun) -> None: def test_identifier_escape() -> None: - assert escape_redshift_identifier(", NULL'); DROP TABLE\" -\\-") == '", NULL\'); DROP TABLE"" -\\\\-"' + assert ( + escape_redshift_identifier(", NULL'); DROP TABLE\" -\\-") + == '", NULL\'); DROP TABLE"" -\\\\-"' + ) def test_identifier_escape_bigquery() -> None: - assert escape_bigquery_identifier(", NULL'); DROP TABLE\"` -\\-") == '`, NULL\'); DROP TABLE"\\` -\\\\-`' + assert ( + escape_bigquery_identifier(", NULL'); DROP TABLE\"` -\\-") + == "`, NULL'); DROP TABLE\"\\` -\\\\-`" + ) def test_string_literal_escape_unicode() -> None: # test on some unicode characters assert escape_redshift_literal(", NULL);\n DROP TABLE --") == "', NULL);\\n DROP TABLE --'" - assert escape_redshift_literal("イロハニホヘト チリヌルヲ ワカヨタレソ ツネナラム") == "'イロハニホヘト チリヌルヲ ワカヨタレソ ツネナラム'" - assert escape_redshift_identifier("ąćł\"") == '"ąćł"""' - assert escape_redshift_identifier("イロハニホヘト チリヌルヲ \"ワカヨタレソ ツネナラム") == '"イロハニホヘト チリヌルヲ ""ワカヨタレソ ツネナラム"' + assert ( + escape_redshift_literal("イロハニホヘト チリヌルヲ ワカヨタレソ ツネナラム") + == "'イロハニホヘト チリヌルヲ ワカヨタレソ ツネナラム'" + ) + assert escape_redshift_identifier('ąćł"') == '"ąćł"""' + assert ( + escape_redshift_identifier('イロハニホヘト チリヌルヲ "ワカヨタレソ ツネナラム') + == '"イロハニホヘト チリヌルヲ ""ワカヨタレソ ツネナラム"' + ) diff --git a/tests/common/test_data_writers/test_parquet_writer.py b/tests/common/test_data_writers/test_parquet_writer.py index 674567e95e..14a9a0593a 100644 --- a/tests/common/test_data_writers/test_parquet_writer.py +++ b/tests/common/test_data_writers/test_parquet_writer.py @@ -1,30 +1,38 @@ +import datetime # noqa: 251 import os + import pyarrow as pa import pyarrow.parquet as pq -import datetime # noqa: 251 +from tests.cases import TABLE_ROW_ALL_DATA_TYPES, TABLE_UPDATE_COLUMNS_SCHEMA +from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage, preserve_environ, write_version -from dlt.common import pendulum, Decimal +from dlt.common import Decimal, pendulum from dlt.common.configuration import inject_section +from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.data_writers.buffered import BufferedDataWriter -from dlt.common.destination import TLoaderFileFormat, DestinationCapabilitiesContext +from dlt.common.destination import DestinationCapabilitiesContext, TLoaderFileFormat from dlt.common.schema.utils import new_column -from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.time import ensure_pendulum_date, ensure_pendulum_datetime -from tests.cases import TABLE_UPDATE_COLUMNS_SCHEMA, TABLE_ROW_ALL_DATA_TYPES -from tests.utils import TEST_STORAGE_ROOT, write_version, autouse_test_storage, preserve_environ - def get_writer( _format: TLoaderFileFormat = "insert_values", buffer_max_items: int = 10, file_max_items: int = 10, file_max_bytes: int = None, - _caps: DestinationCapabilitiesContext = None) -> BufferedDataWriter: + _caps: DestinationCapabilitiesContext = None, +) -> BufferedDataWriter: caps = _caps or DestinationCapabilitiesContext.generic_capabilities() caps.preferred_loader_file_format = _format 
file_template = os.path.join(TEST_STORAGE_ROOT, f"{_format}.%s") - return BufferedDataWriter(_format, file_template, buffer_max_items=buffer_max_items, _caps=caps, file_max_items=file_max_items, file_max_bytes=file_max_bytes) + return BufferedDataWriter( + _format, + file_template, + buffer_max_items=buffer_max_items, + _caps=caps, + file_max_items=file_max_items, + file_max_bytes=file_max_bytes, + ) def test_parquet_writer_schema_evolution_with_big_buffer() -> None: @@ -34,8 +42,13 @@ def test_parquet_writer_schema_evolution_with_big_buffer() -> None: c4 = new_column("col4", "text") with get_writer("parquet") as writer: - writer.write_data_item([{"col1": 1, "col2": 2, "col3": "3"}], {"col1": c1, "col2": c2, "col3": c3}) - writer.write_data_item([{"col1": 1, "col2": 2, "col3": "3", "col4": "4", "col5": {"hello": "marcin"}}], {"col1": c1, "col2": c2, "col3": c3, "col4": c4}) + writer.write_data_item( + [{"col1": 1, "col2": 2, "col3": "3"}], {"col1": c1, "col2": c2, "col3": c3} + ) + writer.write_data_item( + [{"col1": 1, "col2": 2, "col3": "3", "col4": "4", "col5": {"hello": "marcin"}}], + {"col1": c1, "col2": c2, "col3": c3, "col4": c4}, + ) with open(writer.closed_files[0], "rb") as f: table = pq.read_table(f) @@ -53,9 +66,14 @@ def test_parquet_writer_schema_evolution_with_small_buffer() -> None: with get_writer("parquet", buffer_max_items=4, file_max_items=50) as writer: for _ in range(0, 20): - writer.write_data_item([{"col1": 1, "col2": 2, "col3": "3"}], {"col1": c1, "col2": c2, "col3": c3}) + writer.write_data_item( + [{"col1": 1, "col2": 2, "col3": "3"}], {"col1": c1, "col2": c2, "col3": c3} + ) for _ in range(0, 20): - writer.write_data_item([{"col1": 1, "col2": 2, "col3": "3", "col4": "4", "col5": {"hello": "marcin"}}], {"col1": c1, "col2": c2, "col3": c3, "col4": c4}) + writer.write_data_item( + [{"col1": 1, "col2": 2, "col3": "3", "col4": "4", "col5": {"hello": "marcin"}}], + {"col1": c1, "col2": c2, "col3": c3, "col4": c4}, + ) assert len(writer.closed_files) == 2 @@ -74,20 +92,34 @@ def test_parquet_writer_json_serialization() -> None: c3 = new_column("col3", "complex") with get_writer("parquet") as writer: - writer.write_data_item([{"col1": 1, "col2": 2, "col3": {"hello":"dave"}}], {"col1": c1, "col2": c2, "col3": c3}) - writer.write_data_item([{"col1": 1, "col2": 2, "col3": {"hello":"marcin"}}], {"col1": c1, "col2": c2, "col3": c3}) - writer.write_data_item([{"col1": 1, "col2": 2, "col3": {}}], {"col1": c1, "col2": c2, "col3": c3}) - writer.write_data_item([{"col1": 1, "col2": 2, "col3": []}], {"col1": c1, "col2": c2, "col3": c3}) + writer.write_data_item( + [{"col1": 1, "col2": 2, "col3": {"hello": "dave"}}], + {"col1": c1, "col2": c2, "col3": c3}, + ) + writer.write_data_item( + [{"col1": 1, "col2": 2, "col3": {"hello": "marcin"}}], + {"col1": c1, "col2": c2, "col3": c3}, + ) + writer.write_data_item( + [{"col1": 1, "col2": 2, "col3": {}}], {"col1": c1, "col2": c2, "col3": c3} + ) + writer.write_data_item( + [{"col1": 1, "col2": 2, "col3": []}], {"col1": c1, "col2": c2, "col3": c3} + ) with open(writer.closed_files[0], "rb") as f: table = pq.read_table(f) assert table.column("col1").to_pylist() == [1, 1, 1, 1] assert table.column("col2").to_pylist() == [2, 2, 2, 2] - assert table.column("col3").to_pylist() == ["""{"hello":"dave"}""","""{"hello":"marcin"}""","""{}""","""[]"""] + assert table.column("col3").to_pylist() == [ + """{"hello":"dave"}""", + """{"hello":"marcin"}""", + """{}""", + """[]""", + ] def test_parquet_writer_all_data_fields() -> None: - data = 
dict(TABLE_ROW_ALL_DATA_TYPES) # fix dates to use pendulum data["col4"] = ensure_pendulum_datetime(data["col4"]) @@ -137,15 +169,17 @@ def test_parquet_writer_size_file_rotation() -> None: def test_parquet_writer_config() -> None: - os.environ["NORMALIZE__DATA_WRITER__VERSION"] = "2.0" os.environ["NORMALIZE__DATA_WRITER__DATA_PAGE_SIZE"] = str(1024 * 512) os.environ["NORMALIZE__DATA_WRITER__TIMESTAMP_TIMEZONE"] = "America/New York" - with inject_section(ConfigSectionContext(pipeline_name=None, sections=("normalize", ))): + with inject_section(ConfigSectionContext(pipeline_name=None, sections=("normalize",))): with get_writer("parquet", file_max_bytes=2**8, buffer_max_items=2) as writer: for i in range(0, 5): - writer.write_data_item([{"col1": i, "col2": pendulum.now()}], {"col1": new_column("col1", "bigint"), "col2": new_column("col2", "timestamp")}) + writer.write_data_item( + [{"col1": i, "col2": pendulum.now()}], + {"col1": new_column("col1", "bigint"), "col2": new_column("col2", "timestamp")}, + ) # force the parquet writer to be created writer._flush_items() @@ -169,7 +203,11 @@ def test_parquet_writer_schema_from_caps() -> None: for _ in range(0, 5): writer.write_data_item( [{"col1": Decimal("2617.27"), "col2": pendulum.now(), "col3": Decimal(2**250)}], - {"col1": new_column("col1", "decimal"), "col2": new_column("col2", "timestamp"), "col3": new_column("col3", "wei")} + { + "col1": new_column("col1", "decimal"), + "col2": new_column("col2", "timestamp"), + "col3": new_column("col3", "wei"), + }, ) # force the parquet writer to be created writer._flush_items() diff --git a/tests/common/test_destination.py b/tests/common/test_destination.py index 7afa10ed68..37790655c9 100644 --- a/tests/common/test_destination.py +++ b/tests/common/test_destination.py @@ -1,12 +1,11 @@ import pytest +from tests.utils import ACTIVE_DESTINATIONS from dlt.common.destination.reference import DestinationClientDwhConfiguration, DestinationReference from dlt.common.exceptions import InvalidDestinationReference, UnknownDestinationModule from dlt.common.schema import Schema from dlt.common.schema.exceptions import InvalidDatasetName -from tests.utils import ACTIVE_DESTINATIONS - def test_import_unknown_destination() -> None: # standard destination @@ -34,34 +33,83 @@ def test_import_all_destinations() -> None: def test_normalize_dataset_name() -> None: # with schema name appended - assert DestinationClientDwhConfiguration(dataset_name="ban_ana_dataset", default_schema_name="default").normalize_dataset_name(Schema("banana")) == "ban_ana_dataset_banana" + assert ( + DestinationClientDwhConfiguration( + dataset_name="ban_ana_dataset", default_schema_name="default" + ).normalize_dataset_name(Schema("banana")) + == "ban_ana_dataset_banana" + ) # without schema name appended - assert DestinationClientDwhConfiguration(dataset_name="ban_ana_dataset", default_schema_name="default").normalize_dataset_name(Schema("default")) == "ban_ana_dataset" + assert ( + DestinationClientDwhConfiguration( + dataset_name="ban_ana_dataset", default_schema_name="default" + ).normalize_dataset_name(Schema("default")) + == "ban_ana_dataset" + ) # dataset name will be normalized (now it is up to destination to normalize this) - assert DestinationClientDwhConfiguration(dataset_name="BaNaNa", default_schema_name="default").normalize_dataset_name(Schema("banana")) == "ba_na_na_banana" + assert ( + DestinationClientDwhConfiguration( + dataset_name="BaNaNa", default_schema_name="default" + ).normalize_dataset_name(Schema("banana")) + == 
"ba_na_na_banana" + ) # empty schemas are invalid with pytest.raises(ValueError): - DestinationClientDwhConfiguration(dataset_name="banana_dataset", default_schema_name=None).normalize_dataset_name(Schema(None)) + DestinationClientDwhConfiguration( + dataset_name="banana_dataset", default_schema_name=None + ).normalize_dataset_name(Schema(None)) with pytest.raises(ValueError): - DestinationClientDwhConfiguration(dataset_name="banana_dataset", default_schema_name="").normalize_dataset_name(Schema("")) + DestinationClientDwhConfiguration( + dataset_name="banana_dataset", default_schema_name="" + ).normalize_dataset_name(Schema("")) # empty dataset name is valid! - assert DestinationClientDwhConfiguration(dataset_name="", default_schema_name="ban_schema").normalize_dataset_name(Schema("schema_ana")) == "_schema_ana" + assert ( + DestinationClientDwhConfiguration( + dataset_name="", default_schema_name="ban_schema" + ).normalize_dataset_name(Schema("schema_ana")) + == "_schema_ana" + ) # empty dataset name is valid! - assert DestinationClientDwhConfiguration(dataset_name="", default_schema_name="schema_ana").normalize_dataset_name(Schema("schema_ana")) == "" + assert ( + DestinationClientDwhConfiguration( + dataset_name="", default_schema_name="schema_ana" + ).normalize_dataset_name(Schema("schema_ana")) + == "" + ) # None dataset name is valid! - assert DestinationClientDwhConfiguration(dataset_name=None, default_schema_name="ban_schema").normalize_dataset_name(Schema("schema_ana")) == "_schema_ana" + assert ( + DestinationClientDwhConfiguration( + dataset_name=None, default_schema_name="ban_schema" + ).normalize_dataset_name(Schema("schema_ana")) + == "_schema_ana" + ) # None dataset name is valid! - assert DestinationClientDwhConfiguration(dataset_name=None, default_schema_name="schema_ana").normalize_dataset_name(Schema("schema_ana")) is None + assert ( + DestinationClientDwhConfiguration( + dataset_name=None, default_schema_name="schema_ana" + ).normalize_dataset_name(Schema("schema_ana")) + is None + ) # now mock the schema name to make sure that it is normalized schema = Schema("barbapapa") schema._schema_name = "BarbaPapa" - assert DestinationClientDwhConfiguration(dataset_name="set", default_schema_name="default").normalize_dataset_name(schema) == "set_barba_papa" + assert ( + DestinationClientDwhConfiguration( + dataset_name="set", default_schema_name="default" + ).normalize_dataset_name(schema) + == "set_barba_papa" + ) def test_normalize_dataset_name_none_default_schema() -> None: # if default schema is None, suffix is not added - assert DestinationClientDwhConfiguration(dataset_name="ban_ana_dataset", default_schema_name=None).normalize_dataset_name(Schema("default")) == "ban_ana_dataset" + assert ( + DestinationClientDwhConfiguration( + dataset_name="ban_ana_dataset", default_schema_name=None + ).normalize_dataset_name(Schema("default")) + == "ban_ana_dataset" + ) diff --git a/tests/common/test_git.py b/tests/common/test_git.py index 96a5f33d94..f925a56d15 100644 --- a/tests/common/test_git.py +++ b/tests/common/test_git.py @@ -1,13 +1,20 @@ import os -from git import GitCommandError, RepositoryDirtyError, GitError -import pytest - -from dlt.common.storages import FileStorage -from dlt.common.git import clone_repo, ensure_remote_head, git_custom_key_command, get_fresh_repo_files, get_repo, is_dirty, is_clean_and_synced -from tests.utils import test_storage, skipifwindows +import pytest +from git import GitCommandError, GitError, RepositoryDirtyError from tests.common.utils import 
load_secret, modify_and_commit_file, restore_secret_storage_path - +from tests.utils import skipifwindows, test_storage + +from dlt.common.git import ( + clone_repo, + ensure_remote_head, + get_fresh_repo_files, + get_repo, + git_custom_key_command, + is_clean_and_synced, + is_dirty, +) +from dlt.common.storages import FileStorage AWESOME_REPO = "https://github.com/sindresorhus/awesome.git" JAFFLE_SHOP_REPO = "https://github.com/dbt-labs/jaffle_shop.git" @@ -42,7 +49,12 @@ def test_clone(test_storage: FileStorage) -> None: def test_clone_with_commit_id(test_storage: FileStorage) -> None: repo_path = test_storage.make_full_path("awesome_repo") # clone a small public repo - clone_repo(AWESOME_REPO, repo_path, with_git_command=None, branch="7f88000be2d4f265c83465fec4b0b3613af347dd").close() + clone_repo( + AWESOME_REPO, + repo_path, + with_git_command=None, + branch="7f88000be2d4f265c83465fec4b0b3613af347dd", + ).close() assert test_storage.has_folder("awesome_repo") # cannot pull detached head with pytest.raises(GitError): diff --git a/tests/common/test_json.py b/tests/common/test_json.py index a001584198..0a2b0c8ec3 100644 --- a/tests/common/test_json.py +++ b/tests/common/test_json.py @@ -1,16 +1,16 @@ import io import os -from typing import List, NamedTuple from dataclasses import dataclass -import pytest - -from dlt.common import json, Decimal, pendulum -from dlt.common.arithmetics import numeric_default_context -from dlt.common.json import _DECIMAL, _WEI, custom_pua_decode, _orjson, _simplejson, SupportsJson +from typing import List, NamedTuple -from tests.utils import autouse_test_storage, TEST_STORAGE_ROOT +import pytest from tests.cases import JSON_TYPED_DICT, JSON_TYPED_DICT_NESTED from tests.common.utils import json_case_path, load_json_case +from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage + +from dlt.common import Decimal, json, pendulum +from dlt.common.arithmetics import numeric_default_context +from dlt.common.json import _DECIMAL, _WEI, SupportsJson, _orjson, _simplejson, custom_pua_decode class NamedTupleTest(NamedTuple): @@ -158,7 +158,10 @@ def test_json_decimals(json_impl: SupportsJson) -> None: # serialize out of local context s = json_impl.dumps(doc) # full precision. 
you need to quantize yourself if you need it - assert s == '{"decimal":"99999999999999999999999999999999999999999999999999999999999999999999999999.999"}' + assert ( + s + == '{"decimal":"99999999999999999999999999999999999999999999999999999999999999999999999999.999"}' + ) @pytest.mark.parametrize("json_impl", _JSON_IMPL) @@ -192,18 +195,27 @@ def test_json_pendulum(json_impl: SupportsJson) -> None: @pytest.mark.parametrize("json_impl", _JSON_IMPL) def test_json_named_tuple(json_impl: SupportsJson) -> None: - assert json_impl.dumps(NamedTupleTest("STR", Decimal("1.3333"))) == '{"str_field":"STR","dec_field":"1.3333"}' + assert ( + json_impl.dumps(NamedTupleTest("STR", Decimal("1.3333"))) + == '{"str_field":"STR","dec_field":"1.3333"}' + ) with io.BytesIO() as b: json_impl.typed_dump(NamedTupleTest("STR", Decimal("1.3333")), b) - assert b.getvalue().decode("utf-8") == '{"str_field":"STR","dec_field":"\uF0261.3333"}' + assert b.getvalue().decode("utf-8") == '{"str_field":"STR","dec_field":"\uf0261.3333"}' @pytest.mark.parametrize("json_impl", _JSON_IMPL) def test_data_class(json_impl: SupportsJson) -> None: - assert json_impl.dumps(DataClassTest(str_field="AAA")) == '{"str_field":"AAA","int_field":5,"dec_field":"0.5"}' + assert ( + json_impl.dumps(DataClassTest(str_field="AAA")) + == '{"str_field":"AAA","int_field":5,"dec_field":"0.5"}' + ) with io.BytesIO() as b: json_impl.typed_dump(DataClassTest(str_field="AAA"), b) - assert b.getvalue().decode("utf-8") == '{"str_field":"AAA","int_field":5,"dec_field":"\uF0260.5"}' + assert ( + b.getvalue().decode("utf-8") + == '{"str_field":"AAA","int_field":5,"dec_field":"\uf0260.5"}' + ) @pytest.mark.parametrize("json_impl", _JSON_IMPL) @@ -239,7 +251,7 @@ def test_json_typed_encode(json_impl: SupportsJson) -> None: assert d["decimal"][0] == _DECIMAL assert d["wei"][0] == _WEI # decode all - d_d = {k: custom_pua_decode(v) for k,v in d.items()} + d_d = {k: custom_pua_decode(v) for k, v in d.items()} assert d_d == JSON_TYPED_DICT @@ -253,6 +265,6 @@ def test_load_and_compare_all_impls() -> None: # same docs, same output for idx in range(0, len(docs) - 1): - assert docs[idx] == docs[idx+1] - assert dump_s[idx] == dump_s[idx+1] - assert dump_b[idx] == dump_b[idx+1] + assert docs[idx] == docs[idx + 1] + assert dump_s[idx] == dump_s[idx + 1] + assert dump_b[idx] == dump_b[idx + 1] diff --git a/tests/common/test_pipeline_state.py b/tests/common/test_pipeline_state.py index cce610839f..f9cebffda2 100644 --- a/tests/common/test_pipeline_state.py +++ b/tests/common/test_pipeline_state.py @@ -1,6 +1,6 @@ import re -from typing import Dict, Any from copy import deepcopy +from typing import Any, Dict from unittest import mock from dlt.common import pipeline as ps @@ -11,7 +11,7 @@ def test_delete_source_state_keys() -> None: "a": {"b": {"c": 1}}, "x": {"y": {"c": 2}}, "y": {"x": {"a": 3}}, - "resources": {"some_data": {"incremental": {"last_value": 123}}} + "resources": {"some_data": {"incremental": {"last_value": 123}}}, } state = deepcopy(_fake_source_state) @@ -54,12 +54,12 @@ def test_get_matching_resources() -> None: # with state argument results = ps._get_matching_resources(pattern, _fake_source_state) - assert sorted(results) == ['events_a', 'events_b'] + assert sorted(results) == ["events_a", "events_b"] # with state context with mock.patch.object(ps, "source_state", autospec=True, return_value=_fake_source_state): results = ps._get_matching_resources(pattern, _fake_source_state) - assert sorted(results) == ['events_a', 'events_b'] + assert sorted(results) 
== ["events_a", "events_b"] # no resources key results = ps._get_matching_resources(pattern, {}) diff --git a/tests/common/test_time.py b/tests/common/test_time.py index 56c6849ab8..af637f5510 100644 --- a/tests/common/test_time.py +++ b/tests/common/test_time.py @@ -1,9 +1,15 @@ +from datetime import date, datetime, timedelta, timezone # noqa: I251 + import pytest -from datetime import datetime, date, timezone, timedelta # noqa: I251 from pendulum.tz import UTC from dlt.common import pendulum -from dlt.common.time import timestamp_before, timestamp_within, ensure_pendulum_datetime, ensure_pendulum_date +from dlt.common.time import ( + ensure_pendulum_date, + ensure_pendulum_datetime, + timestamp_before, + timestamp_within, +) from dlt.common.typing import TAnyDateTime @@ -72,9 +78,7 @@ def test_before() -> None: @pytest.mark.parametrize("date_value, expected", test_params) -def test_ensure_pendulum_datetime( - date_value: TAnyDateTime, expected: pendulum.DateTime -) -> None: +def test_ensure_pendulum_datetime(date_value: TAnyDateTime, expected: pendulum.DateTime) -> None: dt = ensure_pendulum_datetime(date_value) assert dt == expected # always UTC @@ -86,4 +90,6 @@ def test_ensure_pendulum_datetime( def test_ensure_pendulum_date_utc() -> None: # when converting from datetimes make sure to shift to UTC before doing date assert ensure_pendulum_date("2021-01-01T00:00:00+05:00") == pendulum.date(2020, 12, 31) - assert ensure_pendulum_date(datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=8)))) == pendulum.date(2020, 12, 31) \ No newline at end of file + assert ensure_pendulum_date( + datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=8))) + ) == pendulum.date(2020, 12, 31) diff --git a/tests/common/test_typing.py b/tests/common/test_typing.py index 278eb38973..9c887b033c 100644 --- a/tests/common/test_typing.py +++ b/tests/common/test_typing.py @@ -1,10 +1,33 @@ +from typing import ( + List, + Literal, + Mapping, + MutableMapping, + MutableSequence, + NewType, + Optional, + Sequence, + TypedDict, + TypeVar, + Union, +) -from typing import List, Literal, Mapping, MutableMapping, MutableSequence, NewType, Sequence, TypeVar, TypedDict, Optional, Union -from dlt.common.configuration.specs.base_configuration import BaseConfiguration, get_config_if_union_hint from dlt.common.configuration.specs import GcpServiceAccountCredentialsWithoutDefaults - -from dlt.common.typing import StrAny, extract_inner_type, extract_optional_type, is_dict_generic_type, is_list_generic_type, is_literal_type, is_newtype_type, is_optional_type, is_typeddict - +from dlt.common.configuration.specs.base_configuration import ( + BaseConfiguration, + get_config_if_union_hint, +) +from dlt.common.typing import ( + StrAny, + extract_inner_type, + extract_optional_type, + is_dict_generic_type, + is_list_generic_type, + is_literal_type, + is_newtype_type, + is_optional_type, + is_typeddict, +) class TTestTyDi(TypedDict): @@ -76,4 +99,7 @@ def test_get_config_if_union() -> None: assert get_config_if_union_hint(Union[BaseException, str, StrAny]) is None assert get_config_if_union_hint(Union[BaseConfiguration, str, StrAny]) is BaseConfiguration assert get_config_if_union_hint(Union[str, BaseConfiguration, StrAny]) is BaseConfiguration - assert get_config_if_union_hint(Union[GcpServiceAccountCredentialsWithoutDefaults, StrAny, str]) is GcpServiceAccountCredentialsWithoutDefaults + assert ( + get_config_if_union_hint(Union[GcpServiceAccountCredentialsWithoutDefaults, StrAny, str]) + is 
GcpServiceAccountCredentialsWithoutDefaults + ) diff --git a/tests/common/test_utils.py b/tests/common/test_utils.py index c4541365ff..1b32403267 100644 --- a/tests/common/test_utils.py +++ b/tests/common/test_utils.py @@ -1,11 +1,21 @@ -import itertools -import inspect import binascii +import inspect +import itertools + import pytest from dlt.common.runners import Venv -from dlt.common.utils import (graph_find_scc_nodes, flatten_list_of_str_or_dicts, digest128, graph_edges_to_nodes, map_nested_in_place, - reveal_pseudo_secret, obfuscate_pseudo_secret, get_module_name, concat_strings_with_limit) +from dlt.common.utils import ( + concat_strings_with_limit, + digest128, + flatten_list_of_str_or_dicts, + get_module_name, + graph_edges_to_nodes, + graph_find_scc_nodes, + map_nested_in_place, + obfuscate_pseudo_secret, + reveal_pseudo_secret, +) def test_flatten_list_of_str_or_dicts() -> None: @@ -19,30 +29,27 @@ def test_flatten_list_of_str_or_dicts() -> None: def test_digest128_length() -> None: - assert len(digest128("hash it")) == 120/6 + assert len(digest128("hash it")) == 120 / 6 def test_map_dicts_in_place() -> None: - _d = { - "a": "1", - "b": ["a", "b", ["a", "b"], {"a": "c"}], - "c": { - "d": "e", - "e": ["a", 2] - } + _d = {"a": "1", "b": ["a", "b", ["a", "b"], {"a": "c"}], "c": {"d": "e", "e": ["a", 2]}} + exp_d = { + "a": "11", + "b": ["aa", "bb", ["aa", "bb"], {"a": "cc"}], + "c": {"d": "ee", "e": ["aa", 4]}, } - exp_d = {'a': '11', 'b': ['aa', 'bb', ['aa', 'bb'], {'a': 'cc'}], 'c': {'d': 'ee', 'e': ['aa', 4]}} - assert map_nested_in_place(lambda v: v*2, _d) == exp_d + assert map_nested_in_place(lambda v: v * 2, _d) == exp_d # in place assert _d == exp_d _l = ["a", "b", ["a", "b"], {"a": "c"}] exp_l = ["aa", "bb", ["aa", "bb"], {"a": "cc"}] - assert map_nested_in_place(lambda v: v*2, _l) == exp_l + assert map_nested_in_place(lambda v: v * 2, _l) == exp_l assert _l == exp_l with pytest.raises(ValueError): - map_nested_in_place(lambda v: v*2, "a") + map_nested_in_place(lambda v: v * 2, "a") def test_pseudo_obfuscation() -> None: @@ -77,9 +84,25 @@ def test_concat_strings_with_limit() -> None: assert list(concat_strings_with_limit(philosopher, ";\n", 15)) == ["Bertrand Russell"] # only two strings will be merged (22 chars total) - philosophers = ["Bertrand Russell", "Ludwig Wittgenstein", "G.E. Moore", "J.L. Mackie", "Alfred Tarski"] - moore_merged = ['Bertrand Russell', 'Ludwig Wittgenstein', 'G.E. Moore J.L. Mackie', 'Alfred Tarski'] - moore_merged_2 = ['Bertrand Russell', 'Ludwig Wittgenstein', 'G.E. Moore;\nJ.L. Mackie', 'Alfred Tarski'] + philosophers = [ + "Bertrand Russell", + "Ludwig Wittgenstein", + "G.E. Moore", + "J.L. Mackie", + "Alfred Tarski", + ] + moore_merged = [ + "Bertrand Russell", + "Ludwig Wittgenstein", + "G.E. Moore J.L. Mackie", + "Alfred Tarski", + ] + moore_merged_2 = [ + "Bertrand Russell", + "Ludwig Wittgenstein", + "G.E. Moore;\nJ.L. 
Mackie", + "Alfred Tarski", + ] assert list(concat_strings_with_limit(philosophers, " ", 22)) == moore_merged # none will be merged assert list(concat_strings_with_limit(philosophers, ";\n", 22)) == philosophers @@ -92,7 +115,7 @@ def test_concat_strings_with_limit() -> None: def test_find_scc_nodes() -> None: - edges = [('A', 'B'), ('B', 'C'), ('D', 'E'), ('F', 'G'), ('G', 'H'), ('I', 'I'), ('J', 'J')] + edges = [("A", "B"), ("B", "C"), ("D", "E"), ("F", "G"), ("G", "H"), ("I", "I"), ("J", "J")] def _comp(s): return sorted([tuple(sorted(c)) for c in s]) @@ -111,8 +134,28 @@ def _comp(s): def test_graph_edges_to_nodes() -> None: - edges = [('A', 'B'), ('A', 'C'), ('B', 'C'), ('D', 'E'), ('F', 'G'), ('G', 'H'), ('I', 'I'), ('J', 'J')] - graph = {"A": {"B", "C"}, "B": {"C"}, "C": set(), "D": {"E"}, "E": set(), "F": {"G"}, "G": {"H"}, "H": set(), "I": set(), "J": set()} + edges = [ + ("A", "B"), + ("A", "C"), + ("B", "C"), + ("D", "E"), + ("F", "G"), + ("G", "H"), + ("I", "I"), + ("J", "J"), + ] + graph = { + "A": {"B", "C"}, + "B": {"C"}, + "C": set(), + "D": {"E"}, + "E": set(), + "F": {"G"}, + "G": {"H"}, + "H": set(), + "I": set(), + "J": set(), + } g1 = graph_edges_to_nodes(edges) for perm_edges in itertools.permutations(edges): @@ -124,4 +167,4 @@ def test_graph_edges_to_nodes() -> None: # test a few edge cases assert graph_edges_to_nodes([]) == {} # ignores double edge - assert graph_edges_to_nodes([('A', 'B'), ('A', 'B')]) == {'A': {'B'}, 'B': set()} + assert graph_edges_to_nodes([("A", "B"), ("A", "B")]) == {"A": {"B"}, "B": set()} diff --git a/tests/common/test_validation.py b/tests/common/test_validation.py index d4885ccd67..08460a82ea 100644 --- a/tests/common/test_validation.py +++ b/tests/common/test_validation.py @@ -1,11 +1,12 @@ from copy import deepcopy +from typing import Dict, List, Literal, Mapping, Optional, Sequence, TypedDict + import pytest import yaml -from typing import Dict, List, Literal, Mapping, Sequence, TypedDict, Optional from dlt.common import json from dlt.common.exceptions import DictValidationException -from dlt.common.schema.typing import TStoredSchema, TColumnSchema +from dlt.common.schema.typing import TColumnSchema, TStoredSchema from dlt.common.schema.utils import simple_regex_validator from dlt.common.typing import DictStrStr, StrStr from dlt.common.validation import validate_dict, validate_dict_ignoring_xkeys @@ -33,28 +34,12 @@ class TTestRecord(TypedDict): f_seq_literal: Sequence[Optional[TLiteral]] -TEST_COL = { - "name": "col1", - "data_type": "bigint", - "nullable": False - } +TEST_COL = {"name": "col1", "data_type": "bigint", "nullable": False} TEST_COL_LIST = [ - { - "name": "col1", - "data_type": "bigint", - "nullable": False - }, - { - "name": "col2", - "data_type": "double", - "nullable": False - }, - { - "name": "col3", - "data_type": "bool", - "nullable": False - } + {"name": "col1", "data_type": "bigint", "nullable": False}, + {"name": "col2", "data_type": "double", "nullable": False}, + {"name": "col3", "data_type": "bool", "nullable": False}, ] TEST_DOC: TTestRecord = { @@ -67,30 +52,30 @@ class TTestRecord(TypedDict): "f_seq_simple": ["x", "y"], "f_seq_optional_str": ["opt1", "opt2"], "f_seq_of_optional_int": [1, 2, 3], - "f_list_of_dict": TEST_COL_LIST, + "f_list_of_dict": TEST_COL_LIST, "f_dict_simple": {"col1": "map_me"}, "f_map_simple": {"col1": "map_me"}, "f_map_of_dict": {"col1": deepcopy(TEST_COL)}, "f_column": deepcopy(TEST_COL), "f_literal": "uno", "f_literal_optional": "dos", - "f_seq_literal": ["uno", "dos", "tres"] + 
"f_seq_literal": ["uno", "dos", "tres"], } + @pytest.fixture def test_doc() -> TTestRecord: return deepcopy(TEST_DOC) def test_validate_schema_cases() -> None: - with open("tests/common/cases/schemas/eth/ethereum_schema_v4.yml", mode="r", encoding="utf-8") as f: + with open( + "tests/common/cases/schemas/eth/ethereum_schema_v4.yml", mode="r", encoding="utf-8" + ) as f: schema_dict: TStoredSchema = yaml.safe_load(f) validate_dict_ignoring_xkeys( - spec=TStoredSchema, - doc=schema_dict, - path=".", - validator_f=simple_regex_validator + spec=TStoredSchema, doc=schema_dict, path=".", validator_f=simple_regex_validator ) # with open("tests/common/cases/schemas/rasa/event.schema.json") as f: diff --git a/tests/common/test_version.py b/tests/common/test_version.py index 765d690ede..645e51c82e 100644 --- a/tests/common/test_version.py +++ b/tests/common/test_version.py @@ -1,7 +1,8 @@ import os -import pytest from importlib.metadata import PackageNotFoundError +import pytest + from dlt.version import get_installed_requirement_string diff --git a/tests/common/test_wei.py b/tests/common/test_wei.py index 8ee47d11c0..631eb0ff6d 100644 --- a/tests/common/test_wei.py +++ b/tests/common/test_wei.py @@ -1,5 +1,5 @@ from dlt.common.typing import SupportsVariant -from dlt.common.wei import Wei, Decimal +from dlt.common.wei import Decimal, Wei def test_init() -> None: @@ -7,9 +7,12 @@ def test_init() -> None: assert Wei.from_int256(10**18, decimals=18) == 1 # make sure the wei scale is supported assert Wei.from_int256(1, decimals=18) == Decimal("0.000000000000000001") - assert Wei.from_int256(2**256-1) == 2**256-1 - assert str(Wei.from_int256(2**256-1, decimals=18)) == "115792089237316195423570985008687907853269984665640564039457.584007913129639935" - assert str(Wei.from_int256(2**256-1)) == str(2**256-1) + assert Wei.from_int256(2**256 - 1) == 2**256 - 1 + assert ( + str(Wei.from_int256(2**256 - 1, decimals=18)) + == "115792089237316195423570985008687907853269984665640564039457.584007913129639935" + ) + assert str(Wei.from_int256(2**256 - 1)) == str(2**256 - 1) assert type(Wei.from_int256(1)) is Wei @@ -30,6 +33,14 @@ def test_wei_variant() -> None: # we get variant value when we call Wei assert Wei(578960446186580977117854925043439539266)() == 578960446186580977117854925043439539266 - assert Wei(578960446186580977117854925043439539267)() == ("str", "578960446186580977117854925043439539267") - assert Wei(-578960446186580977117854925043439539267)() == -578960446186580977117854925043439539267 - assert Wei(-578960446186580977117854925043439539268)() == ("str", "-578960446186580977117854925043439539268") + assert Wei(578960446186580977117854925043439539267)() == ( + "str", + "578960446186580977117854925043439539267", + ) + assert ( + Wei(-578960446186580977117854925043439539267)() == -578960446186580977117854925043439539267 + ) + assert Wei(-578960446186580977117854925043439539268)() == ( + "str", + "-578960446186580977117854925043439539268", + ) diff --git a/tests/common/utils.py b/tests/common/utils.py index 7a49a80efb..1d8ff45ae9 100644 --- a/tests/common/utils.py +++ b/tests/common/utils.py @@ -1,23 +1,25 @@ -import pytest +import datetime # noqa: 251 import os -import yaml -from git import Repo, Commit from pathlib import Path from typing import Mapping, Tuple, cast -import datetime # noqa: 251 + +import pytest +import yaml +from git import Commit, Repo from dlt.common import json -from dlt.common.typing import StrAny +from dlt.common.configuration.providers import environ as environ_provider from 
dlt.common.schema import utils from dlt.common.schema.typing import TTableSchemaColumns -from dlt.common.configuration.providers import environ as environ_provider - +from dlt.common.typing import StrAny COMMON_TEST_CASES_PATH = "./tests/common/cases/" # for import schema tests, change when upgrading the schema version IMPORTED_VERSION_HASH_ETH_V6 = "++bJOVuScYYoVUFtjmZMBV+cxsWs8irYHIMV8J1xD5g=" # test sentry DSN -TEST_SENTRY_DSN = "https://797678dd0af64b96937435326c7d30c1@o1061158.ingest.sentry.io/4504306172821504" +TEST_SENTRY_DSN = ( + "https://797678dd0af64b96937435326c7d30c1@o1061158.ingest.sentry.io/4504306172821504" +) # preserve secrets path to be able to restore it SECRET_STORAGE_PATH = environ_provider.SECRET_STORAGE_PATH @@ -41,11 +43,10 @@ def yml_case_path(name: str) -> str: def row_to_column_schemas(row: StrAny) -> TTableSchemaColumns: - return {k: utils.add_missing_hints({ - "name": k, - "data_type": "text", - "nullable": False - }) for k in row.keys()} + return { + k: utils.add_missing_hints({"name": k, "data_type": "text", "nullable": False}) + for k in row.keys() + } @pytest.fixture(autouse=True) @@ -55,13 +56,17 @@ def restore_secret_storage_path() -> None: def load_secret(name: str) -> str: environ_provider.SECRET_STORAGE_PATH = "./tests/common/cases/secrets/%s" - secret, _ = environ_provider.EnvironProvider().get_value(name, environ_provider.TSecretValue, None) + secret, _ = environ_provider.EnvironProvider().get_value( + name, environ_provider.TSecretValue, None + ) if not secret: raise FileNotFoundError(environ_provider.SECRET_STORAGE_PATH % name) return secret -def modify_and_commit_file(repo_path: str, file_name: str, content: str = "NEW README CONTENT") -> Tuple[str, Commit]: +def modify_and_commit_file( + repo_path: str, file_name: str, content: str = "NEW README CONTENT" +) -> Tuple[str, Commit]: file_path = os.path.join(repo_path, file_name) with open(file_path, "w", encoding="utf-8") as f: diff --git a/tests/conftest.py b/tests/conftest.py index d084e3f3af..5f86321bea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,29 @@ -import os import dataclasses import logging +import os from typing import List # patch which providers to enable -from dlt.common.configuration.providers import ConfigProvider, EnvironProvider, SecretsTomlProvider, ConfigTomlProvider -from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext, ConfigProvidersConfiguration +from dlt.common.configuration.providers import ( + ConfigProvider, + ConfigTomlProvider, + EnvironProvider, + SecretsTomlProvider, +) +from dlt.common.configuration.specs.config_providers_context import ( + ConfigProvidersConfiguration, + ConfigProvidersContext, +) + def initial_providers() -> List[ConfigProvider]: # do not read the global config - return [EnvironProvider(), SecretsTomlProvider(project_dir="tests/.dlt", add_global_config=False), ConfigTomlProvider(project_dir="tests/.dlt", add_global_config=False)] + return [ + EnvironProvider(), + SecretsTomlProvider(project_dir="tests/.dlt", add_global_config=False), + ConfigTomlProvider(project_dir="tests/.dlt", add_global_config=False), + ] + ConfigProvidersContext.initial_providers = initial_providers # also disable extras @@ -26,29 +40,47 @@ def pytest_configure(config): from dlt.common.storages import configuration as storage_configuration test_storage_root = "_storage" - run_configuration.RunConfiguration.config_files_storage_path = os.path.join(test_storage_root, "config/") - 
run_configuration.RunConfiguration.dlthub_telemetry_segment_write_key = "TLJiyRkGVZGCi2TtjClamXpFcxAA1rSB" + run_configuration.RunConfiguration.config_files_storage_path = os.path.join( + test_storage_root, "config/" + ) + run_configuration.RunConfiguration.dlthub_telemetry_segment_write_key = ( + "TLJiyRkGVZGCi2TtjClamXpFcxAA1rSB" + ) delattr(run_configuration.RunConfiguration, "__init__") run_configuration.RunConfiguration = dataclasses.dataclass(run_configuration.RunConfiguration, init=True, repr=False) # type: ignore # push telemetry to CI - storage_configuration.LoadStorageConfiguration.load_volume_path = os.path.join(test_storage_root, "load") + storage_configuration.LoadStorageConfiguration.load_volume_path = os.path.join( + test_storage_root, "load" + ) delattr(storage_configuration.LoadStorageConfiguration, "__init__") - storage_configuration.LoadStorageConfiguration = dataclasses.dataclass(storage_configuration.LoadStorageConfiguration, init=True, repr=False) + storage_configuration.LoadStorageConfiguration = dataclasses.dataclass( + storage_configuration.LoadStorageConfiguration, init=True, repr=False + ) - storage_configuration.NormalizeStorageConfiguration.normalize_volume_path = os.path.join(test_storage_root, "normalize") + storage_configuration.NormalizeStorageConfiguration.normalize_volume_path = os.path.join( + test_storage_root, "normalize" + ) # delete __init__, otherwise it will not be recreated by dataclass delattr(storage_configuration.NormalizeStorageConfiguration, "__init__") - storage_configuration.NormalizeStorageConfiguration = dataclasses.dataclass(storage_configuration.NormalizeStorageConfiguration, init=True, repr=False) + storage_configuration.NormalizeStorageConfiguration = dataclasses.dataclass( + storage_configuration.NormalizeStorageConfiguration, init=True, repr=False + ) - storage_configuration.SchemaStorageConfiguration.schema_volume_path = os.path.join(test_storage_root, "schemas") + storage_configuration.SchemaStorageConfiguration.schema_volume_path = os.path.join( + test_storage_root, "schemas" + ) delattr(storage_configuration.SchemaStorageConfiguration, "__init__") - storage_configuration.SchemaStorageConfiguration = dataclasses.dataclass(storage_configuration.SchemaStorageConfiguration, init=True, repr=False) - - - assert run_configuration.RunConfiguration.config_files_storage_path == os.path.join(test_storage_root, "config/") - assert run_configuration.RunConfiguration().config_files_storage_path == os.path.join(test_storage_root, "config/") + storage_configuration.SchemaStorageConfiguration = dataclasses.dataclass( + storage_configuration.SchemaStorageConfiguration, init=True, repr=False + ) + assert run_configuration.RunConfiguration.config_files_storage_path == os.path.join( + test_storage_root, "config/" + ) + assert run_configuration.RunConfiguration().config_files_storage_path == os.path.join( + test_storage_root, "config/" + ) # path pipeline instance id up to millisecond from dlt.common import pendulum @@ -59,7 +91,9 @@ def _create_pipeline_instance_id(self) -> str: Pipeline._create_pipeline_instance_id = _create_pipeline_instance_id # push sentry to ci - os.environ["RUNTIME__SENTRY_DSN"] = "https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752" + os.environ["RUNTIME__SENTRY_DSN"] = ( + "https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752" + ) # disable sqlfluff logging for log in ["sqlfluff.parser", "sqlfluff.linter", "sqlfluff.templater", "sqlfluff.lexer"]: diff --git 
a/tests/destinations/test_path_utils.py b/tests/destinations/test_path_utils.py index 9c01759d1f..ef9e081b90 100644 --- a/tests/destinations/test_path_utils.py +++ b/tests/destinations/test_path_utils.py @@ -1,7 +1,7 @@ import pytest from dlt.destinations import path_utils -from dlt.destinations.exceptions import InvalidFilesystemLayout, CantExtractTablePrefix +from dlt.destinations.exceptions import CantExtractTablePrefix, InvalidFilesystemLayout def test_layout_validity() -> None: @@ -18,9 +18,11 @@ def test_create_path() -> None: "table_name": "table_name", "load_id": "load_id", "file_id": "file_id", - "ext": "ext" + "ext": "ext", } - path = path_utils.create_path("{schema_name}/{table_name}/{load_id}.{file_id}.{ext}", **path_vars) + path = path_utils.create_path( + "{schema_name}/{table_name}/{load_id}.{file_id}.{ext}", **path_vars + ) assert path == "schema_name/table_name/load_id.file_id.ext" # extension gets added automatically @@ -29,14 +31,23 @@ def test_create_path() -> None: def test_get_table_prefix_layout() -> None: - - prefix_layout = path_utils.get_table_prefix_layout("{schema_name}/{table_name}/{load_id}.{file_id}.{ext}") + prefix_layout = path_utils.get_table_prefix_layout( + "{schema_name}/{table_name}/{load_id}.{file_id}.{ext}" + ) assert prefix_layout == "{schema_name}/{table_name}/" - assert prefix_layout.format(schema_name="my_schema", table_name="my_table") == "my_schema/my_table/" + assert ( + prefix_layout.format(schema_name="my_schema", table_name="my_table") + == "my_schema/my_table/" + ) - prefix_layout = path_utils.get_table_prefix_layout("some_random{schema_name}/stuff_in_between/{table_name}/{load_id}") + prefix_layout = path_utils.get_table_prefix_layout( + "some_random{schema_name}/stuff_in_between/{table_name}/{load_id}" + ) assert prefix_layout == "some_random{schema_name}/stuff_in_between/{table_name}/" - assert prefix_layout.format(schema_name="my_schema", table_name="my_table") == "some_randommy_schema/stuff_in_between/my_table/" + assert ( + prefix_layout.format(schema_name="my_schema", table_name="my_table") + == "some_randommy_schema/stuff_in_between/my_table/" + ) # disallow missing table_name with pytest.raises(CantExtractTablePrefix): @@ -48,7 +59,10 @@ def test_get_table_prefix_layout() -> None: # disallow any placeholders before table name (ie. 
Athena) with pytest.raises(CantExtractTablePrefix): - path_utils.get_table_prefix_layout("{schema_name}some_random{table_name}/stuff_in_between/", supported_prefix_placeholders=[]) + path_utils.get_table_prefix_layout( + "{schema_name}some_random{table_name}/stuff_in_between/", + supported_prefix_placeholders=[], + ) # disallow table_name without following separator with pytest.raises(CantExtractTablePrefix): diff --git a/tests/extract/cases/eth_source/source.py b/tests/extract/cases/eth_source/source.py index 08adb79a22..4ea233526a 100644 --- a/tests/extract/cases/eth_source/source.py +++ b/tests/extract/cases/eth_source/source.py @@ -1,6 +1,8 @@ from typing import Any + import dlt + @dlt.source def ethereum() -> Any: # this just tests if the schema "ethereum" was loaded diff --git a/tests/extract/cases/section_source/external_resources.py b/tests/extract/cases/section_source/external_resources.py index 0a991d7438..cfa8640321 100644 --- a/tests/extract/cases/section_source/external_resources.py +++ b/tests/extract/cases/section_source/external_resources.py @@ -1,24 +1,32 @@ -import dlt - from tests.extract.cases.section_source import init_resource_f_2 from tests.extract.cases.section_source.named_module import resource_f_2 +import dlt + @dlt.source def with_external(source_val: str = dlt.config.value): - @dlt.resource def inner_resource(val): yield val - return dlt.resource([source_val], name="source_val"), inner_resource(source_val), init_resource_f_2, resource_f_2 + return ( + dlt.resource([source_val], name="source_val"), + inner_resource(source_val), + init_resource_f_2, + resource_f_2, + ) @dlt.source def with_bound_external(source_val: str = dlt.config.value): - @dlt.resource def inner_resource(val): yield val - return dlt.resource([source_val], name="source_val"), inner_resource(source_val), init_resource_f_2(), resource_f_2() \ No newline at end of file + return ( + dlt.resource([source_val], name="source_val"), + inner_resource(source_val), + init_resource_f_2(), + resource_f_2(), + ) diff --git a/tests/extract/cases/section_source/named_module.py b/tests/extract/cases/section_source/named_module.py index 4a46ad0e19..c7580982b6 100644 --- a/tests/extract/cases/section_source/named_module.py +++ b/tests/extract/cases/section_source/named_module.py @@ -7,6 +7,7 @@ def source_f_1(val: str = dlt.config.value): return dlt.resource([val], name="f_1") + @dlt.resource def resource_f_2(val: str = dlt.config.value): yield [val] diff --git a/tests/extract/conftest.py b/tests/extract/conftest.py index f5dc47f54b..4729f0baec 100644 --- a/tests/extract/conftest.py +++ b/tests/extract/conftest.py @@ -1 +1,7 @@ -from tests.utils import duckdb_pipeline_location, autouse_test_storage, preserve_environ, patch_home_dir, wipe_pipeline \ No newline at end of file +from tests.utils import ( + autouse_test_storage, + duckdb_pipeline_location, + patch_home_dir, + preserve_environ, + wipe_pipeline, +) diff --git a/tests/extract/test_decorators.py b/tests/extract/test_decorators.py index 21ac047547..126a91f9e1 100644 --- a/tests/extract/test_decorators.py +++ b/tests/extract/test_decorators.py @@ -1,7 +1,10 @@ import os + import pytest +from tests.common.utils import IMPORTED_VERSION_HASH_ETH_V6 import dlt +from dlt.cli.source_detection import detect_source_configs from dlt.common.configuration import known_sections from dlt.common.configuration.container import Container from dlt.common.configuration.inject import get_fun_spec @@ -9,16 +12,24 @@ from dlt.common.configuration.specs.config_section_context 
import ConfigSectionContext from dlt.common.exceptions import ArgumentsOverloadException, DictValidationException from dlt.common.pipeline import StateInjectableContext, TPipelineState -from dlt.common.source import _SOURCES from dlt.common.schema import Schema +from dlt.common.schema.exceptions import InvalidSchemaName from dlt.common.schema.utils import new_table - -from dlt.cli.source_detection import detect_source_configs -from dlt.extract.exceptions import ExplicitSourceNameInvalid, InvalidResourceDataTypeFunctionNotAGenerator, InvalidResourceDataTypeIsNone, ParametrizedResourceUnbound, PipeNotBoundToData, ResourceFunctionExpected, ResourceInnerCallableConfigWrapDisallowed, SourceDataIsNone, SourceIsAClassTypeError, SourceNotAFunction, SourceSchemaNotAvailable +from dlt.common.source import _SOURCES +from dlt.extract.exceptions import ( + ExplicitSourceNameInvalid, + InvalidResourceDataTypeFunctionNotAGenerator, + InvalidResourceDataTypeIsNone, + ParametrizedResourceUnbound, + PipeNotBoundToData, + ResourceFunctionExpected, + ResourceInnerCallableConfigWrapDisallowed, + SourceDataIsNone, + SourceIsAClassTypeError, + SourceNotAFunction, + SourceSchemaNotAvailable, +) from dlt.extract.source import DltResource, DltSource -from dlt.common.schema.exceptions import InvalidSchemaName - -from tests.common.utils import IMPORTED_VERSION_HASH_ETH_V6 def test_none_returning_source() -> None: @@ -31,12 +42,10 @@ def empty() -> None: with pytest.raises(SourceDataIsNone): dlt.source(empty)() - @dlt.source def deco_empty() -> None: pass - with pytest.raises(SourceDataIsNone): deco_empty() @@ -69,7 +78,6 @@ def test_load_schema_for_callable() -> None: def test_unbound_parametrized_transformer() -> None: - empty_pipe = DltResource.Empty._pipe assert empty_pipe.is_empty assert not empty_pipe.is_data_bound @@ -113,7 +121,7 @@ def test_transformer_no_parens() -> None: bound_r = dlt.resource([1, 2, 3], name="data") @dlt.transformer - def empty_t_1(item, meta = None): + def empty_t_1(item, meta=None): yield "a" * item assert list(bound_r | empty_t_1) == ["a", "aa", "aaa"] @@ -127,7 +135,6 @@ def empty_t_2(item, _meta): def test_source_name_is_invalid_schema_name() -> None: - def camelCase(): return dlt.resource([1, 2, 3], name="resource") @@ -152,10 +159,13 @@ def camelCase(): def test_resource_name_is_invalid_table_name_and_columns() -> None: - @dlt.source def camelCase(): - return dlt.resource([1, 2, 3], name="Resource !", columns={"KA!AX": {"name": "DIF!", "nullable": False, "data_type": "text"}}) + return dlt.resource( + [1, 2, 3], + name="Resource !", + columns={"KA!AX": {"name": "DIF!", "nullable": False, "data_type": "text"}}, + ) s = camelCase() assert s.resources["Resource !"].selected @@ -170,10 +180,11 @@ def camelCase(): def test_columns_argument() -> None: - - @dlt.resource(name="user", columns={"tags": {"data_type": "complex", "x-extra": "x-annotation"}}) + @dlt.resource( + name="user", columns={"tags": {"data_type": "complex", "x-extra": "x-annotation"}} + ) def get_users(): - yield {"u": "u", "tags": [1, 2 ,3]} + yield {"u": "u", "tags": [1, 2, 3]} t = get_users().table_schema() # nullable is added @@ -204,12 +215,12 @@ def some_data(): def test_source_sections() -> None: # source in __init__.py of module - from tests.extract.cases.section_source import init_source_f_1, init_resource_f_2 + from tests.extract.cases.section_source import init_resource_f_2, init_source_f_1 + # source in file module with name override - from tests.extract.cases.section_source.named_module import source_f_1, 
resource_f_2 + from tests.extract.cases.section_source.named_module import resource_f_2, source_f_1 # we crawl the sections from the most general (no section) to full path - # values without section os.environ["VAL"] = "TOP LEVEL" assert list(init_source_f_1()) == ["TOP LEVEL"] @@ -234,21 +245,27 @@ def test_source_sections() -> None: assert list(resource_f_2()) == ["NAME OVERRIDDEN LEVEL"] # values in function name section - os.environ[f"{known_sections.SOURCES.upper()}__SECTION_SOURCE__INIT_SOURCE_F_1__VAL"] = "SECTION INIT_SOURCE_F_1 LEVEL" + os.environ[f"{known_sections.SOURCES.upper()}__SECTION_SOURCE__INIT_SOURCE_F_1__VAL"] = ( + "SECTION INIT_SOURCE_F_1 LEVEL" + ) assert list(init_source_f_1()) == ["SECTION INIT_SOURCE_F_1 LEVEL"] - os.environ[f"{known_sections.SOURCES.upper()}__SECTION_SOURCE__INIT_RESOURCE_F_2__VAL"] = "SECTION INIT_RESOURCE_F_2 LEVEL" + os.environ[f"{known_sections.SOURCES.upper()}__SECTION_SOURCE__INIT_RESOURCE_F_2__VAL"] = ( + "SECTION INIT_RESOURCE_F_2 LEVEL" + ) assert list(init_resource_f_2()) == ["SECTION INIT_RESOURCE_F_2 LEVEL"] - os.environ[f"{known_sections.SOURCES.upper()}__NAME_OVERRIDDEN__SOURCE_F_1__VAL"] = "NAME SOURCE_F_1 LEVEL" + os.environ[f"{known_sections.SOURCES.upper()}__NAME_OVERRIDDEN__SOURCE_F_1__VAL"] = ( + "NAME SOURCE_F_1 LEVEL" + ) assert list(source_f_1()) == ["NAME SOURCE_F_1 LEVEL"] - os.environ[f"{known_sections.SOURCES.upper()}__NAME_OVERRIDDEN__RESOURCE_F_2__VAL"] = "NAME RESOURCE_F_2 LEVEL" + os.environ[f"{known_sections.SOURCES.upper()}__NAME_OVERRIDDEN__RESOURCE_F_2__VAL"] = ( + "NAME RESOURCE_F_2 LEVEL" + ) assert list(resource_f_2()) == ["NAME RESOURCE_F_2 LEVEL"] def test_source_explicit_section() -> None: - @dlt.source(section="custom_section", schema=Schema("custom_section")) def with_section(secret=dlt.secrets.value): - @dlt.resource def mod_state(): dlt.current.source_state()["val"] = secret @@ -265,7 +282,6 @@ def mod_state(): def test_resource_section() -> None: - r = dlt.resource([1, 2, 3], name="T") assert r.name == "T" assert r.section is None @@ -278,14 +294,23 @@ def _inner_gen(): assert r.section == "test_decorators" from tests.extract.cases.section_source.external_resources import init_resource_f_2 + assert init_resource_f_2.name == "init_resource_f_2" assert init_resource_f_2.section == "section_source" def test_resources_injected_sections() -> None: - from tests.extract.cases.section_source.external_resources import with_external, with_bound_external, init_resource_f_2, resource_f_2 + from tests.extract.cases.section_source.external_resources import ( + init_resource_f_2, + resource_f_2, + with_bound_external, + with_external, + ) + # standalone resources must accept the injected sections for lookups - os.environ["SOURCES__EXTERNAL_RESOURCES__SOURCE_VAL"] = "SOURCES__EXTERNAL_RESOURCES__SOURCE_VAL" + os.environ["SOURCES__EXTERNAL_RESOURCES__SOURCE_VAL"] = ( + "SOURCES__EXTERNAL_RESOURCES__SOURCE_VAL" + ) os.environ["SOURCES__EXTERNAL_RESOURCES__VAL"] = "SOURCES__EXTERNAL_RESOURCES__VAL" os.environ["SOURCES__SECTION_SOURCE__VAL"] = "SOURCES__SECTION_SOURCE__VAL" os.environ["SOURCES__NAME_OVERRIDDEN__VAL"] = "SOURCES__NAME_OVERRIDDEN__VAL" @@ -300,44 +325,59 @@ def test_resources_injected_sections() -> None: "SOURCES__EXTERNAL_RESOURCES__SOURCE_VAL", "SOURCES__EXTERNAL_RESOURCES__SOURCE_VAL", "SOURCES__EXTERNAL_RESOURCES__VAL", - "SOURCES__EXTERNAL_RESOURCES__VAL" + "SOURCES__EXTERNAL_RESOURCES__VAL", ] # this source will bind external resources before returning them (that is: calling them and obtaining 
generators) # the iterator in the source will force its sections so external resource sections are not used s = with_bound_external() - assert list(s) == list([ - "SOURCES__EXTERNAL_RESOURCES__SOURCE_VAL", - "SOURCES__EXTERNAL_RESOURCES__SOURCE_VAL", - "SOURCES__EXTERNAL_RESOURCES__VAL", - "SOURCES__EXTERNAL_RESOURCES__VAL" - ]) + assert list(s) == list( + [ + "SOURCES__EXTERNAL_RESOURCES__SOURCE_VAL", + "SOURCES__EXTERNAL_RESOURCES__SOURCE_VAL", + "SOURCES__EXTERNAL_RESOURCES__VAL", + "SOURCES__EXTERNAL_RESOURCES__VAL", + ] + ) # inject the source sections like the Pipeline object would s = with_external() assert s.name == "with_external" assert s.section == "external_resources" # from module name hosting the function - with inject_section(ConfigSectionContext(pipeline_name="injected_external", sections=("sources", s.section, s.name))): + with inject_section( + ConfigSectionContext( + pipeline_name="injected_external", sections=("sources", s.section, s.name) + ) + ): # now the external sources must adopt the injected namespace - assert(list(s)) == [ + assert (list(s)) == [ "SOURCES__EXTERNAL_RESOURCES__SOURCE_VAL", "SOURCES__EXTERNAL_RESOURCES__SOURCE_VAL", "SOURCES__EXTERNAL_RESOURCES__VAL", - "SOURCES__EXTERNAL_RESOURCES__VAL" + "SOURCES__EXTERNAL_RESOURCES__VAL", ] # now with environ values that specify source/resource name: the module of the source, the name of the resource - os.environ["SOURCES__EXTERNAL_RESOURCES__INIT_RESOURCE_F_2__VAL"] = "SOURCES__EXTERNAL_RESOURCES__INIT_RESOURCE_F_2__VAL" - os.environ["SOURCES__EXTERNAL_RESOURCES__RESOURCE_F_2__VAL"] = "SOURCES__EXTERNAL_RESOURCES__RESOURCE_F_2__VAL" + os.environ["SOURCES__EXTERNAL_RESOURCES__INIT_RESOURCE_F_2__VAL"] = ( + "SOURCES__EXTERNAL_RESOURCES__INIT_RESOURCE_F_2__VAL" + ) + os.environ["SOURCES__EXTERNAL_RESOURCES__RESOURCE_F_2__VAL"] = ( + "SOURCES__EXTERNAL_RESOURCES__RESOURCE_F_2__VAL" + ) s = with_external() - with inject_section(ConfigSectionContext(pipeline_name="injected_external", sections=("sources", s.section, s.name))): + with inject_section( + ConfigSectionContext( + pipeline_name="injected_external", sections=("sources", s.section, s.name) + ) + ): # now the external sources must adopt the injected namespace - assert(list(s)) == [ + assert (list(s)) == [ "SOURCES__EXTERNAL_RESOURCES__SOURCE_VAL", "SOURCES__EXTERNAL_RESOURCES__SOURCE_VAL", "SOURCES__EXTERNAL_RESOURCES__INIT_RESOURCE_F_2__VAL", - "SOURCES__EXTERNAL_RESOURCES__RESOURCE_F_2__VAL" + "SOURCES__EXTERNAL_RESOURCES__RESOURCE_F_2__VAL", ] + def test_source_schema_context() -> None: import dlt @@ -388,7 +428,6 @@ def created_global(): def test_source_state_context() -> None: - @dlt.resource(selected=False) def main(): state = dlt.current.state() @@ -396,14 +435,14 @@ def main(): # increase the multiplier each time state is obtained state["mark"] *= 2 yield [1, 2, 3] - assert dlt.state()["mark"] == mark*2 + assert dlt.state()["mark"] == mark * 2 @dlt.transformer(data_from=main) def feeding(item): # we must have state assert dlt.current.source_state()["mark"] > 1 mark = dlt.current.source_state()["mark"] - yield from map(lambda i: i*mark, item) + yield from map(lambda i: i * mark, item) @dlt.source def pass_the_state(): @@ -419,7 +458,6 @@ def pass_the_state(): def test_source_schema_modified() -> None: - @dlt.source def schema_test(): return dlt.resource(["A", "B"], name="alpha") @@ -437,13 +475,12 @@ def standalone_resource(secret=dlt.secrets.value, config=dlt.config.value, opt: def test_spec_generation() -> None: - # inner resource cannot take 
configuration with pytest.raises(ResourceInnerCallableConfigWrapDisallowed) as py_ex: @dlt.resource(write_disposition="merge", primary_key="id") - def inner_resource(initial_id = dlt.config.value): + def inner_resource(initial_id=dlt.config.value): yield [{"id": 1, "name": "row1"}, {"id": 1, "name": "row2"}] assert py_ex.value.resource_name == "inner_resource" @@ -472,7 +509,6 @@ def not_args_r(): def test_sources_no_arguments() -> None: - @dlt.source def no_args(): return dlt.resource([1, 2], name="data") @@ -501,7 +537,6 @@ def not_args_r_i(): def test_resource_sets_invalid_write_disposition() -> None: - @dlt.resource(write_disposition="xxxx") def invalid_disposition(): yield from [1, 2, 3] @@ -513,7 +548,6 @@ def invalid_disposition(): def test_class_source() -> None: - class _Source: def __init__(self, elems: int) -> None: self.elems = elems @@ -527,10 +561,11 @@ def __call__(self, more: int = 1): schema = s.discover_schema() assert schema.name == "_Source" assert "_list" in schema.tables - assert list(s) == ['A', 'V', 'A', 'V', 'A', 'V', 'A', 'V'] + assert list(s) == ["A", "V", "A", "V", "A", "V", "A", "V"] # CAN'T decorate classes themselves with pytest.raises(SourceIsAClassTypeError): + @dlt.source(name="planB") class _SourceB: def __init__(self, elems: int) -> None: diff --git a/tests/extract/test_extract.py b/tests/extract/test_extract.py index 530a089f1c..5c4926cd37 100644 --- a/tests/extract/test_extract.py +++ b/tests/extract/test_extract.py @@ -1,15 +1,14 @@ +from tests.extract.utils import expect_extracted_file +from tests.utils import clean_test_storage + import dlt from dlt.common import json from dlt.common.storages import NormalizeStorageConfiguration from dlt.extract.extract import ExtractorStorage, extract from dlt.extract.source import DltResource, DltSource -from tests.utils import clean_test_storage -from tests.extract.utils import expect_extracted_file - def test_extract_select_tables() -> None: - def expect_tables(resource: DltResource) -> dlt.Schema: # delete files clean_test_storage() @@ -30,9 +29,8 @@ def expect_tables(resource: DltResource) -> dlt.Schema: storage.commit_extract_files(extract_id) # check resulting files assert len(storage.list_files_to_normalize_sorted()) == 2 - expect_extracted_file(storage, "selectables", "odd_table", json.dumps([1,3,5,7,9])) - expect_extracted_file(storage, "selectables", "even_table", json.dumps([0,2,4,6,8])) - + expect_extracted_file(storage, "selectables", "odd_table", json.dumps([1, 3, 5, 7, 9])) + expect_extracted_file(storage, "selectables", "even_table", json.dumps([0, 2, 4, 6, 8])) # delete files clean_test_storage() @@ -49,7 +47,7 @@ def expect_tables(resource: DltResource) -> dlt.Schema: assert len(partials) == 1 storage.commit_extract_files(extract_id) assert len(storage.list_files_to_normalize_sorted()) == 1 - expect_extracted_file(storage, "selectables", "odd_table", json.dumps([1,3,5,7,9])) + expect_extracted_file(storage, "selectables", "odd_table", json.dumps([1, 3, 5, 7, 9])) return schema @@ -68,7 +66,7 @@ def table_with_name_selectable(_range): @dlt.resource(table_name=n_f) def table_name_with_lambda(_range): - yield list(range(_range)) + yield list(range(_range)) schema = expect_tables(table_name_with_lambda) assert "table_name_with_lambda" not in schema.tables diff --git a/tests/extract/test_extract_pipe.py b/tests/extract/test_extract_pipe.py index a1e31730e7..cb6040c0b6 100644 --- a/tests/extract/test_extract_pipe.py +++ b/tests/extract/test_extract_pipe.py @@ -1,8 +1,8 @@ -import os import asyncio 
import inspect -from typing import List, Sequence +import os import time +from typing import List, Sequence import pytest @@ -10,12 +10,11 @@ from dlt.common import sleep from dlt.common.typing import TDataItems from dlt.extract.exceptions import CreatePipeException, ResourceExtractionError -from dlt.extract.typing import DataItemWithMeta, FilterItem, MapItem, YieldMapItem from dlt.extract.pipe import ManagedPipeIterator, Pipe, PipeItem, PipeIterator +from dlt.extract.typing import DataItemWithMeta, FilterItem, MapItem, YieldMapItem def test_next_item_mode() -> None: - def nested_gen_level_2(): yield from [88, None, 89] @@ -23,25 +22,25 @@ def nested_gen(): yield from [55, 56, None, 77, nested_gen_level_2()] def source_gen1(): - yield from [1, 2, nested_gen(), 3,4] + yield from [1, 2, nested_gen(), 3, 4] def source_gen2(): yield from range(11, 16) def source_gen3(): - yield from range(20,22) + yield from range(20, 22) def get_pipes(): return [ Pipe.from_data("data1", source_gen1()), Pipe.from_data("data2", source_gen2()), Pipe.from_data("data3", source_gen3()), - ] + ] # default mode is "fifo" _l = list(PipeIterator.from_pipes(get_pipes(), next_item_mode="fifo")) # items will be in order of the pipes, nested iterator items appear inline - assert [pi.item for pi in _l] == [1, 2, 55, 56, 77, 88, 89, 3, 4, 11, 12, 13, 14, 15, 20, 21] + assert [pi.item for pi in _l] == [1, 2, 55, 56, 77, 88, 89, 3, 4, 11, 12, 13, 14, 15, 20, 21] # round robin mode _l = list(PipeIterator.from_pipes(get_pipes(), next_item_mode="round_robin")) @@ -50,7 +49,6 @@ def get_pipes(): def test_rotation_on_none() -> None: - global started global gen_1_started global gen_2_started @@ -86,7 +84,7 @@ def get_pipes(): Pipe.from_data("data1", source_gen1()), Pipe.from_data("data2", source_gen2()), Pipe.from_data("data3", source_gen3()), - ] + ] # round robin mode _l = list(PipeIterator.from_pipes(get_pipes(), next_item_mode="round_robin")) @@ -96,10 +94,6 @@ def get_pipes(): assert time.time() - started < 0.8 - - - - def test_add_step() -> None: data = [1, 2, 3] data_iter = iter(data) @@ -141,7 +135,7 @@ def test_insert_remove_step() -> None: pp = Pipe.from_data("data", data) def tx(item): - yield item*2 + yield item * 2 # create pipe with transformer p = Pipe.from_data("tx", tx, parent=pp) @@ -193,7 +187,7 @@ def item_meta_step(item, meta): p.remove_step(0) assert p._gen_idx == 0 _l = list(PipeIterator.from_pipe(p)) - assert [pi.item for pi in _l] == [0.5, 1, 3/2] + assert [pi.item for pi in _l] == [0.5, 1, 3 / 2] # remove all remaining txs p.remove_step(1) pp.remove_step(1) @@ -215,7 +209,7 @@ def item_meta_step(item, meta): def tx_minus(item, meta): assert meta is None - yield item*-4 + yield item * -4 p.replace_gen(tx_minus) _l = list(PipeIterator.from_pipe(p)) @@ -238,8 +232,8 @@ def test_pipe_propagate_meta() -> None: p = Pipe.from_data("data", iter(meta_data)) def item_meta_step(item: int, meta): - assert _meta[item-1] == meta - return item*2 + assert _meta[item - 1] == meta + return item * 2 p.append_step(item_meta_step) _l = list(PipeIterator.from_pipe(p)) @@ -252,19 +246,19 @@ def item_meta_step(item: int, meta): # does not take meta def transformer(item): - yield item*item + yield item * item def item_meta_step_trans(item: int, meta): # reverse all transformations on item - meta_idx = int(item**0.5//2) - assert _meta[meta_idx-1] == meta - return item*2 + meta_idx = int(item**0.5 // 2) + assert _meta[meta_idx - 1] == meta + return item * 2 t = Pipe("tran", [transformer], parent=p) t.append_step(item_meta_step_trans) 
_l = list(PipeIterator.from_pipe(t)) # item got propagated through transformation -> transformer -> transformation - assert [int((pi.item//2)**0.5//2) for pi in _l] == data + assert [int((pi.item // 2) ** 0.5 // 2) for pi in _l] == data assert [pi.meta for pi in _l] == _meta # same but with the fork step @@ -275,7 +269,7 @@ def item_meta_step_trans(item: int, meta): # do not yield parents _l = list(PipeIterator.from_pipes([p, t], yield_parents=False)) # same result - assert [int((pi.item//2)**0.5//2) for pi in _l] == data + assert [int((pi.item // 2) ** 0.5 // 2) for pi in _l] == data assert [pi.meta for pi in _l] == _meta # same but yield parents @@ -286,11 +280,11 @@ def item_meta_step_trans(item: int, meta): _l = list(PipeIterator.from_pipes([p, t], yield_parents=True)) # same result for transformer tran_l = [pi for pi in _l if pi.pipe._pipe_id == t._pipe_id] - assert [int((pi.item//2)**0.5//2) for pi in tran_l] == data + assert [int((pi.item // 2) ** 0.5 // 2) for pi in tran_l] == data assert [pi.meta for pi in tran_l] == _meta data_l = [pi for pi in _l if pi.pipe._pipe_id == p._pipe_id] # data pipe went only through one transformation - assert [int(pi.item//2) for pi in data_l] == data + assert [int(pi.item // 2) for pi in data_l] == data assert [pi.meta for pi in data_l] == _meta @@ -302,9 +296,9 @@ def test_pipe_transformation_changes_meta() -> None: p = Pipe.from_data("data", iter(meta_data)) def item_meta_step(item: int, meta): - assert _meta[item-1] == meta + assert _meta[item - 1] == meta # return meta, it should overwrite existing one - return DataItemWithMeta("X" + str(item), item*2) + return DataItemWithMeta("X" + str(item), item * 2) p.append_step(item_meta_step) _l = list(PipeIterator.from_pipe(p)) @@ -314,10 +308,10 @@ def item_meta_step(item: int, meta): # also works for deferred transformations @dlt.defer def item_meta_step_defer(item: int, meta): - assert _meta[item-1] == meta + assert _meta[item - 1] == meta sleep(item * 0.2) # return meta, it should overwrite existing one - return DataItemWithMeta("X" + str(item), item*2) + return DataItemWithMeta("X" + str(item), item * 2) p = Pipe.from_data("data", iter(meta_data)) p.append_step(item_meta_step_defer) @@ -327,9 +321,9 @@ def item_meta_step_defer(item: int, meta): # also works for yielding transformations def item_meta_step_flat(item: int, meta): - assert _meta[item-1] == meta + assert _meta[item - 1] == meta # return meta, it should overwrite existing one - yield DataItemWithMeta("X" + str(item), item*2) + yield DataItemWithMeta("X" + str(item), item * 2) p = Pipe.from_data("data", iter(meta_data)) p.append_step(item_meta_step_flat) @@ -339,10 +333,10 @@ def item_meta_step_flat(item: int, meta): # also works for async async def item_meta_step_async(item: int, meta): - assert _meta[item-1] == meta + assert _meta[item - 1] == meta await asyncio.sleep(item * 0.2) # this returns awaitable - return DataItemWithMeta("X" + str(item), item*2) + return DataItemWithMeta("X" + str(item), item * 2) p = Pipe.from_data("data", iter(meta_data)) p.append_step(item_meta_step_async) @@ -353,7 +347,7 @@ async def item_meta_step_async(item: int, meta): # also lets the transformer return meta def transformer(item: int): - yield DataItemWithMeta("X" + str(item), item*2) + yield DataItemWithMeta("X" + str(item), item * 2) p = Pipe.from_data("data", iter(meta_data)) t = Pipe("tran", [transformer], parent=p) @@ -451,14 +445,30 @@ def test_yield_map_step() -> None: p = Pipe.from_data("data", [1, 2, 3]) # this creates number of rows as passed 
by the data p.append_step(YieldMapItem(lambda item: (yield from [f"item_{x}" for x in range(item)]))) - assert _f_items(list(PipeIterator.from_pipe(p))) == ["item_0", "item_0", "item_1", "item_0", "item_1", "item_2"] + assert _f_items(list(PipeIterator.from_pipe(p))) == [ + "item_0", + "item_0", + "item_1", + "item_0", + "item_1", + "item_2", + ] data = [1, 2, 3] meta = ["A", "B", "C"] # package items into meta wrapper meta_data = [DataItemWithMeta(m, d) for m, d in zip(meta, data)] p = Pipe.from_data("data", meta_data) - p.append_step(YieldMapItem(lambda item, meta: (yield from [f"item_{meta}_{x}" for x in range(item)]))) - assert _f_items(list(PipeIterator.from_pipe(p))) == ["item_A_0", "item_B_0", "item_B_1", "item_C_0", "item_C_1", "item_C_2"] + p.append_step( + YieldMapItem(lambda item, meta: (yield from [f"item_{meta}_{x}" for x in range(item)])) + ) + assert _f_items(list(PipeIterator.from_pipe(p))) == [ + "item_A_0", + "item_B_0", + "item_B_1", + "item_C_0", + "item_C_1", + "item_C_2", + ] def test_pipe_copy_on_fork() -> None: @@ -481,9 +491,8 @@ def test_pipe_copy_on_fork() -> None: def test_clone_pipes() -> None: - def pass_gen(item, meta): - yield item*2 + yield item * 2 data = [1, 2, 3] p1 = Pipe("p1", [data]) @@ -509,7 +518,6 @@ def pass_gen(item, meta): # try circular deps - def assert_cloned_pipes(pipes: List[Pipe], cloned_pipes: List[Pipe]): # clones pipes must be separate instances but must preserve pipe id and names for pipe, cloned_pipe in zip(pipes, cloned_pipes): @@ -527,13 +535,14 @@ def assert_cloned_pipes(pipes: List[Pipe], cloned_pipes: List[Pipe]): # must yield same data for pipe, cloned_pipe in zip(pipes, cloned_pipes): - assert _f_items(list(PipeIterator.from_pipe(pipe))) == _f_items(list(PipeIterator.from_pipe(cloned_pipe))) + assert _f_items(list(PipeIterator.from_pipe(pipe))) == _f_items( + list(PipeIterator.from_pipe(cloned_pipe)) + ) def test_circular_deps() -> None: - def pass_gen(item, meta): - yield item*2 + yield item * 2 c_p1_p3 = Pipe("c_p1_p3", [pass_gen]) c_p1_p4 = Pipe("c_p1_p4", [pass_gen], parent=c_p1_p3) @@ -609,7 +618,6 @@ def raise_gen(item: int): def test_close_on_sync_exception() -> None: - def long_gen(): global close_pipe_got_exit, close_pipe_yielding @@ -636,7 +644,9 @@ def assert_pipes_closed(raise_gen, long_gen) -> None: close_pipe_yielding = False pit: PipeIterator = None - with PipeIterator.from_pipe(Pipe.from_data("failing", raise_gen, parent=Pipe.from_data("endless", long_gen()))) as pit: + with PipeIterator.from_pipe( + Pipe.from_data("failing", raise_gen, parent=Pipe.from_data("endless", long_gen())) + ) as pit: with pytest.raises(ResourceExtractionError) as py_ex: list(pit) assert isinstance(py_ex.value.__cause__, RuntimeError) @@ -648,7 +658,9 @@ def assert_pipes_closed(raise_gen, long_gen) -> None: close_pipe_got_exit = False close_pipe_yielding = False - pit = ManagedPipeIterator.from_pipe(Pipe.from_data("failing", raise_gen, parent=Pipe.from_data("endless", long_gen()))) + pit = ManagedPipeIterator.from_pipe( + Pipe.from_data("failing", raise_gen, parent=Pipe.from_data("endless", long_gen())) + ) with pytest.raises(ResourceExtractionError) as py_ex: list(pit) assert isinstance(py_ex.value.__cause__, RuntimeError) diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py index 3160a2a1ee..6dfc7cd4e8 100644 --- a/tests/extract/test_incremental.py +++ b/tests/extract/test_incremental.py @@ -1,74 +1,74 @@ import os -from time import sleep -from typing import Optional, Any from datetime import datetime # 
noqa: I251 from itertools import chain +from time import sleep +from typing import Any, Optional import duckdb import pytest +from tests.extract.utils import AssertItems import dlt -from dlt.common.configuration.container import Container -from dlt.common.configuration.specs.base_configuration import configspec, BaseConfiguration from dlt.common.configuration import ConfigurationValueError +from dlt.common.configuration.container import Container +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec +from dlt.common.json import json from dlt.common.pendulum import pendulum, timedelta from dlt.common.pipeline import StateInjectableContext, resource_state from dlt.common.schema.schema import Schema -from dlt.common.utils import uniq_id, digest128, chunks -from dlt.common.json import json - +from dlt.common.utils import chunks, digest128, uniq_id +from dlt.extract.incremental import IncrementalCursorPathMissing, IncrementalPrimaryKeyMissing from dlt.extract.source import DltSource from dlt.sources.helpers.transform import take_first -from dlt.extract.incremental import IncrementalCursorPathMissing, IncrementalPrimaryKeyMissing - -from tests.extract.utils import AssertItems def test_single_items_last_value_state_is_updated() -> None: @dlt.resource - def some_data(created_at=dlt.sources.incremental('created_at')): - yield {'created_at': 425} - yield {'created_at': 426} + def some_data(created_at=dlt.sources.incremental("created_at")): + yield {"created_at": 425} + yield {"created_at": 426} p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) - s = some_data.state['incremental']['created_at'] - assert s['last_value'] == 426 + s = some_data.state["incremental"]["created_at"] + assert s["last_value"] == 426 def test_single_items_last_value_state_is_updated_transformer() -> None: @dlt.transformer - def some_data(item, created_at=dlt.sources.incremental('created_at')): - yield {'created_at': 425} - yield {'created_at': 426} + def some_data(item, created_at=dlt.sources.incremental("created_at")): + yield {"created_at": 425} + yield {"created_at": 426} p = dlt.pipeline(pipeline_name=uniq_id()) - p.extract(dlt.resource([1,2,3], name="table") | some_data()) + p.extract(dlt.resource([1, 2, 3], name="table") | some_data()) - s = some_data().state['incremental']['created_at'] - assert s['last_value'] == 426 + s = some_data().state["incremental"]["created_at"] + assert s["last_value"] == 426 def test_batch_items_last_value_state_is_updated() -> None: @dlt.resource - def some_data(created_at=dlt.sources.incremental('created_at')): - yield [{'created_at': i} for i in range(5)] - yield [{'created_at': i} for i in range(5, 10)] + def some_data(created_at=dlt.sources.incremental("created_at")): + yield [{"created_at": i} for i in range(5)] + yield [{"created_at": i} for i in range(5, 10)] p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) - s = p.state["sources"][p.default_schema_name]['resources']['some_data']['incremental']['created_at'] - assert s['last_value'] == 9 + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "created_at" + ] + assert s["last_value"] == 9 def test_last_value_access_in_resource() -> None: values = [] @dlt.resource - def some_data(created_at=dlt.sources.incremental('created_at')): + def some_data(created_at=dlt.sources.incremental("created_at")): values.append(created_at.last_value) - yield [{'created_at': i} for i in range(6)] + yield [{"created_at": i} for i in range(6)] p = 
dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) @@ -78,22 +78,24 @@ def some_data(created_at=dlt.sources.incremental('created_at')): def test_unique_keys_are_deduplicated() -> None: - @dlt.resource(primary_key='id') - def some_data(created_at=dlt.sources.incremental('created_at')): + @dlt.resource(primary_key="id") + def some_data(created_at=dlt.sources.incremental("created_at")): if created_at.last_value is None: - yield {'created_at': 1, 'id': 'a'} - yield {'created_at': 2, 'id': 'b'} - yield {'created_at': 3, 'id': 'c'} - yield {'created_at': 3, 'id': 'd'} - yield {'created_at': 3, 'id': 'e'} + yield {"created_at": 1, "id": "a"} + yield {"created_at": 2, "id": "b"} + yield {"created_at": 3, "id": "c"} + yield {"created_at": 3, "id": "d"} + yield {"created_at": 3, "id": "e"} else: - yield {'created_at': 3, 'id': 'c'} - yield {'created_at': 3, 'id': 'd'} - yield {'created_at': 3, 'id': 'e'} - yield {'created_at': 3, 'id': 'f'} - yield {'created_at': 4, 'id': 'g'} - - p = dlt.pipeline(pipeline_name=uniq_id(), destination='duckdb', credentials=duckdb.connect(':memory:')) + yield {"created_at": 3, "id": "c"} + yield {"created_at": 3, "id": "d"} + yield {"created_at": 3, "id": "e"} + yield {"created_at": 3, "id": "f"} + yield {"created_at": 4, "id": "g"} + + p = dlt.pipeline( + pipeline_name=uniq_id(), destination="duckdb", credentials=duckdb.connect(":memory:") + ) p.run(some_data()) p.run(some_data()) @@ -102,26 +104,28 @@ def some_data(created_at=dlt.sources.incremental('created_at')): with c.execute_query("SELECT created_at, id FROM some_data order by created_at, id") as cur: rows = cur.fetchall() - assert rows == [(1, 'a'), (2, 'b'), (3, 'c'), (3, 'd'), (3, 'e'), (3, 'f'), (4, 'g')] + assert rows == [(1, "a"), (2, "b"), (3, "c"), (3, "d"), (3, "e"), (3, "f"), (4, "g")] def test_unique_rows_by_hash_are_deduplicated() -> None: @dlt.resource - def some_data(created_at=dlt.sources.incremental('created_at')): + def some_data(created_at=dlt.sources.incremental("created_at")): if created_at.last_value is None: - yield {'created_at': 1, 'id': 'a'} - yield {'created_at': 2, 'id': 'b'} - yield {'created_at': 3, 'id': 'c'} - yield {'created_at': 3, 'id': 'd'} - yield {'created_at': 3, 'id': 'e'} + yield {"created_at": 1, "id": "a"} + yield {"created_at": 2, "id": "b"} + yield {"created_at": 3, "id": "c"} + yield {"created_at": 3, "id": "d"} + yield {"created_at": 3, "id": "e"} else: - yield {'created_at': 3, 'id': 'c'} - yield {'created_at': 3, 'id': 'd'} - yield {'created_at': 3, 'id': 'e'} - yield {'created_at': 3, 'id': 'f'} - yield {'created_at': 4, 'id': 'g'} - - p = dlt.pipeline(pipeline_name=uniq_id(), destination='duckdb', credentials=duckdb.connect(':memory:')) + yield {"created_at": 3, "id": "c"} + yield {"created_at": 3, "id": "d"} + yield {"created_at": 3, "id": "e"} + yield {"created_at": 3, "id": "f"} + yield {"created_at": 4, "id": "g"} + + p = dlt.pipeline( + pipeline_name=uniq_id(), destination="duckdb", credentials=duckdb.connect(":memory:") + ) p.run(some_data()) p.run(some_data()) @@ -129,61 +133,70 @@ def some_data(created_at=dlt.sources.incremental('created_at')): with c.execute_query("SELECT created_at, id FROM some_data order by created_at, id") as cur: rows = cur.fetchall() - assert rows == [(1, 'a'), (2, 'b'), (3, 'c'), (3, 'd'), (3, 'e'), (3, 'f'), (4, 'g')] + assert rows == [(1, "a"), (2, "b"), (3, "c"), (3, "d"), (3, "e"), (3, "f"), (4, "g")] def test_nested_cursor_path() -> None: @dlt.resource - def 
some_data(created_at=dlt.sources.incremental('data.items[0].created_at')): - yield {'data': {'items': [{'created_at': 2}]}} + def some_data(created_at=dlt.sources.incremental("data.items[0].created_at")): + yield {"data": {"items": [{"created_at": 2}]}} p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) - s = p.state["sources"][p.default_schema_name]['resources']['some_data']['incremental']['data.items[0].created_at'] - assert s['last_value'] == 2 + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "data.items[0].created_at" + ] + assert s["last_value"] == 2 def test_explicit_initial_value() -> None: @dlt.resource - def some_data(created_at=dlt.sources.incremental('created_at')): - yield {'created_at': created_at.last_value} + def some_data(created_at=dlt.sources.incremental("created_at")): + yield {"created_at": created_at.last_value} p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data(created_at=4242)) - s = p.state["sources"][p.default_schema_name]['resources']['some_data']['incremental']['created_at'] - assert s['last_value'] == 4242 + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "created_at" + ] + assert s["last_value"] == 4242 def test_explicit_incremental_instance() -> None: - @dlt.resource(primary_key='some_uq') - def some_data(incremental=dlt.sources.incremental('created_at', initial_value=0)): - assert incremental.cursor_path == 'inserted_at' + @dlt.resource(primary_key="some_uq") + def some_data(incremental=dlt.sources.incremental("created_at", initial_value=0)): + assert incremental.cursor_path == "inserted_at" assert incremental.initial_value == 241 - yield {'inserted_at': 242, 'some_uq': 444} + yield {"inserted_at": 242, "some_uq": 444} p = dlt.pipeline(pipeline_name=uniq_id()) - p.extract(some_data(incremental=dlt.sources.incremental('inserted_at', initial_value=241))) + p.extract(some_data(incremental=dlt.sources.incremental("inserted_at", initial_value=241))) @dlt.resource -def some_data_from_config(call_no: int, created_at: Optional[dlt.sources.incremental] = dlt.secrets.value): - assert created_at.cursor_path == 'created_at' +def some_data_from_config( + call_no: int, created_at: Optional[dlt.sources.incremental] = dlt.secrets.value +): + assert created_at.cursor_path == "created_at" # start value will update to the last_value on next call if call_no == 1: - assert created_at.initial_value == '2022-02-03T00:00:00Z' - assert created_at.start_value == '2022-02-03T00:00:00Z' + assert created_at.initial_value == "2022-02-03T00:00:00Z" + assert created_at.start_value == "2022-02-03T00:00:00Z" if call_no == 2: - assert created_at.initial_value == '2022-02-03T00:00:00Z' - assert created_at.start_value == '2022-02-03T00:00:01Z' - yield {'created_at': '2022-02-03T00:00:01Z'} + assert created_at.initial_value == "2022-02-03T00:00:00Z" + assert created_at.start_value == "2022-02-03T00:00:01Z" + yield {"created_at": "2022-02-03T00:00:01Z"} def test_optional_incremental_from_config() -> None: - - os.environ['SOURCES__TEST_INCREMENTAL__SOME_DATA_FROM_CONFIG__CREATED_AT__CURSOR_PATH'] = 'created_at' - os.environ['SOURCES__TEST_INCREMENTAL__SOME_DATA_FROM_CONFIG__CREATED_AT__INITIAL_VALUE'] = '2022-02-03T00:00:00Z' + os.environ["SOURCES__TEST_INCREMENTAL__SOME_DATA_FROM_CONFIG__CREATED_AT__CURSOR_PATH"] = ( + "created_at" + ) + os.environ["SOURCES__TEST_INCREMENTAL__SOME_DATA_FROM_CONFIG__CREATED_AT__INITIAL_VALUE"] = ( + "2022-02-03T00:00:00Z" + ) p = 
dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data_from_config(1)) @@ -192,15 +205,17 @@ def test_optional_incremental_from_config() -> None: @configspec class SomeDataOverrideConfiguration: - created_at: dlt.sources.incremental = dlt.sources.incremental('created_at', initial_value='2022-02-03T00:00:00Z') + created_at: dlt.sources.incremental = dlt.sources.incremental( + "created_at", initial_value="2022-02-03T00:00:00Z" + ) # provide what to inject via spec. the spec contain the default @dlt.resource(spec=SomeDataOverrideConfiguration) def some_data_override_config(created_at: dlt.sources.incremental = dlt.config.value): - assert created_at.cursor_path == 'created_at' - assert created_at.initial_value == '2000-02-03T00:00:00Z' - yield {'created_at': '2023-03-03T00:00:00Z'} + assert created_at.cursor_path == "created_at" + assert created_at.initial_value == "2000-02-03T00:00:00Z" + yield {"created_at": "2023-03-03T00:00:00Z"} def test_optional_incremental_not_passed() -> None: @@ -208,7 +223,7 @@ def test_optional_incremental_not_passed() -> None: @dlt.resource def some_data(created_at: Optional[dlt.sources.incremental] = None): - yield [1,2,3] + yield [1, 2, 3] assert list(some_data()) == [1, 2, 3] @@ -219,7 +234,9 @@ class OptionalIncrementalConfig(BaseConfiguration): @dlt.resource(spec=OptionalIncrementalConfig) -def optional_incremental_arg_resource(incremental: Optional[dlt.sources.incremental[Any]] = None) -> Any: +def optional_incremental_arg_resource( + incremental: Optional[dlt.sources.incremental[Any]] = None, +) -> Any: assert incremental is None yield [1, 2, 3] @@ -232,7 +249,7 @@ def test_optional_arg_from_spec_not_passed() -> None: def test_override_initial_value_from_config() -> None: # use the shortest possible config version # os.environ['SOURCES__TEST_INCREMENTAL__SOME_DATA_OVERRIDE_CONFIG__CREATED_AT__INITIAL_VALUE'] = '2000-02-03T00:00:00Z' - os.environ['CREATED_AT__INITIAL_VALUE'] = '2000-02-03T00:00:00Z' + os.environ["CREATED_AT__INITIAL_VALUE"] = "2000-02-03T00:00:00Z" p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data_override_config()) @@ -240,57 +257,70 @@ def test_override_initial_value_from_config() -> None: def test_override_primary_key_in_pipeline() -> None: - """Primary key hint passed to pipeline is propagated through apply_hints - """ - @dlt.resource(primary_key='id') - def some_data(created_at=dlt.sources.incremental('created_at')): + """Primary key hint passed to pipeline is propagated through apply_hints""" + + @dlt.resource(primary_key="id") + def some_data(created_at=dlt.sources.incremental("created_at")): # TODO: this only works because incremental instance is shared across many copies of the resource - assert some_data.incremental.primary_key == ['id', 'other_id'] + assert some_data.incremental.primary_key == ["id", "other_id"] - yield {'created_at': 22, 'id': 2, 'other_id': 5} - yield {'created_at': 22, 'id': 2, 'other_id': 6} + yield {"created_at": 22, "id": 2, "other_id": 5} + yield {"created_at": 22, "id": 2, "other_id": 6} p = dlt.pipeline(pipeline_name=uniq_id()) - p.extract(some_data, primary_key=['id', 'other_id']) + p.extract(some_data, primary_key=["id", "other_id"]) def test_composite_primary_key() -> None: - @dlt.resource(primary_key=['isrc', 'market']) - def some_data(created_at=dlt.sources.incremental('created_at')): - yield {'created_at': 1, 'isrc': 'AAA', 'market': 'DE'} - yield {'created_at': 2, 'isrc': 'BBB', 'market': 'DE'} - yield {'created_at': 2, 'isrc': 'CCC', 'market': 'US'} - yield {'created_at': 2, 'isrc': 
'AAA', 'market': 'DE'} - yield {'created_at': 2, 'isrc': 'CCC', 'market': 'DE'} - yield {'created_at': 2, 'isrc': 'DDD', 'market': 'DE'} - yield {'created_at': 2, 'isrc': 'CCC', 'market': 'DE'} - - p = dlt.pipeline(pipeline_name=uniq_id(), destination='duckdb', credentials=duckdb.connect(':memory:')) + @dlt.resource(primary_key=["isrc", "market"]) + def some_data(created_at=dlt.sources.incremental("created_at")): + yield {"created_at": 1, "isrc": "AAA", "market": "DE"} + yield {"created_at": 2, "isrc": "BBB", "market": "DE"} + yield {"created_at": 2, "isrc": "CCC", "market": "US"} + yield {"created_at": 2, "isrc": "AAA", "market": "DE"} + yield {"created_at": 2, "isrc": "CCC", "market": "DE"} + yield {"created_at": 2, "isrc": "DDD", "market": "DE"} + yield {"created_at": 2, "isrc": "CCC", "market": "DE"} + + p = dlt.pipeline( + pipeline_name=uniq_id(), destination="duckdb", credentials=duckdb.connect(":memory:") + ) p.run(some_data()) with p.sql_client() as c: - with c.execute_query("SELECT created_at, isrc, market FROM some_data order by created_at, isrc, market") as cur: + with c.execute_query( + "SELECT created_at, isrc, market FROM some_data order by created_at, isrc, market" + ) as cur: rows = cur.fetchall() - assert rows == [(1, 'AAA', 'DE'), (2, 'AAA', 'DE'), (2, 'BBB', 'DE'), (2, 'CCC', 'DE'), (2, 'CCC', 'US'), (2, 'DDD', 'DE')] + assert rows == [ + (1, "AAA", "DE"), + (2, "AAA", "DE"), + (2, "BBB", "DE"), + (2, "CCC", "DE"), + (2, "CCC", "US"), + (2, "DDD", "DE"), + ] def test_last_value_func_min() -> None: @dlt.resource - def some_data(created_at=dlt.sources.incremental('created_at', last_value_func=min)): - yield {'created_at': 10} - yield {'created_at': 11} - yield {'created_at': 9} - yield {'created_at': 10} - yield {'created_at': 8} - yield {'created_at': 22} + def some_data(created_at=dlt.sources.incremental("created_at", last_value_func=min)): + yield {"created_at": 10} + yield {"created_at": 11} + yield {"created_at": 9} + yield {"created_at": 10} + yield {"created_at": 8} + yield {"created_at": 22} p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) - s = p.state["sources"][p.default_schema_name]['resources']['some_data']['incremental']['created_at'] + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "created_at" + ] - assert s['last_value'] == 8 + assert s["last_value"] == 8 def test_last_value_func_custom() -> None: @@ -298,52 +328,59 @@ def last_value(values): return max(values) + 1 @dlt.resource - def some_data(created_at=dlt.sources.incremental('created_at', last_value_func=last_value)): - yield {'created_at': 9} - yield {'created_at': 10} + def some_data(created_at=dlt.sources.incremental("created_at", last_value_func=last_value)): + yield {"created_at": 9} + yield {"created_at": 10} p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) - s = p.state["sources"][p.default_schema_name]['resources']['some_data']['incremental']['created_at'] - assert s['last_value'] == 11 + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "created_at" + ] + assert s["last_value"] == 11 def test_cursor_datetime_type() -> None: initial_value = pendulum.now() @dlt.resource - def some_data(created_at=dlt.sources.incremental('created_at', initial_value)): - yield {'created_at': initial_value + timedelta(minutes=1)} - yield {'created_at': initial_value + timedelta(minutes=3)} - yield {'created_at': initial_value + timedelta(minutes=2)} - yield {'created_at': initial_value + 
timedelta(minutes=4)} - yield {'created_at': initial_value + timedelta(minutes=2)} + def some_data(created_at=dlt.sources.incremental("created_at", initial_value)): + yield {"created_at": initial_value + timedelta(minutes=1)} + yield {"created_at": initial_value + timedelta(minutes=3)} + yield {"created_at": initial_value + timedelta(minutes=2)} + yield {"created_at": initial_value + timedelta(minutes=4)} + yield {"created_at": initial_value + timedelta(minutes=2)} p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) - s = p.state["sources"][p.default_schema_name]['resources']['some_data']['incremental']['created_at'] - assert s['last_value'] == initial_value + timedelta(minutes=4) + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "created_at" + ] + assert s["last_value"] == initial_value + timedelta(minutes=4) def test_descending_order_unique_hashes() -> None: """Resource returns items in descending order but using `max` last value function. Only hash matching last_value are stored. """ + @dlt.resource - def some_data(created_at=dlt.sources.incremental('created_at', 20)): + def some_data(created_at=dlt.sources.incremental("created_at", 20)): for i in reversed(range(15, 25)): - yield {'created_at': i} + yield {"created_at": i} p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) - s = p.state["sources"][p.default_schema_name]['resources']['some_data']['incremental']['created_at'] + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "created_at" + ] - last_hash = digest128(json.dumps({'created_at': 24})) + last_hash = digest128(json.dumps({"created_at": 24})) - assert s['unique_hashes'] == [last_hash] + assert s["unique_hashes"] == [last_hash] # make sure nothing is returned on a next run, source will use state from the active pipeline assert list(some_data()) == [] @@ -351,6 +388,7 @@ def some_data(created_at=dlt.sources.incremental('created_at', 20)): def test_unique_keys_json_identifiers() -> None: """Uses primary key name that is matching the name of the JSON element in the original namespace but gets converted into destination namespace""" + @dlt.resource(primary_key="DelTa") def some_data(last_timestamp=dlt.sources.incremental("item.ts")): for i in range(-10, 10): @@ -359,7 +397,7 @@ def some_data(last_timestamp=dlt.sources.incremental("item.ts")): p = dlt.pipeline(pipeline_name=uniq_id()) p.run(some_data, destination="duckdb") # check if default schema contains normalized PK - assert p.default_schema.tables["some_data"]['columns']["del_ta"]['primary_key'] is True + assert p.default_schema.tables["some_data"]["columns"]["del_ta"]["primary_key"] is True with p.sql_client() as c: with c.execute_query("SELECT del_ta FROM some_data") as cur: rows = cur.fetchall() @@ -371,14 +409,15 @@ def some_data(last_timestamp=dlt.sources.incremental("item.ts")): # something got loaded = wee create 20 elements starting from now. 
so one element will be in the future comparing to previous 20 elements assert len(load_info.loads_ids) == 1 with p.sql_client() as c: - with c.execute_query("SELECT del_ta FROM some_data WHERE _dlt_load_id = %s", load_info.loads_ids[0]) as cur: + with c.execute_query( + "SELECT del_ta FROM some_data WHERE _dlt_load_id = %s", load_info.loads_ids[0] + ) as cur: rows = cur.fetchall() assert len(rows) == 1 assert rows[0][0] == 9 def test_missing_primary_key() -> None: - @dlt.resource(primary_key="DELTA") def some_data(last_timestamp=dlt.sources.incremental("item.ts")): for i in range(-10, 10): @@ -390,7 +429,6 @@ def some_data(last_timestamp=dlt.sources.incremental("item.ts")): def test_missing_cursor_field() -> None: - @dlt.resource def some_data(last_timestamp=dlt.sources.incremental("item.timestamp")): for i in range(-10, 10): @@ -421,7 +459,13 @@ def test_filter_processed_items() -> None: assert all(v["delta"] >= 0 for v in values) # provide the initial value, use min function - values = list(standalone_some_data(last_timestamp=dlt.sources.incremental("item.timestamp", pendulum.now().timestamp(), min))) + values = list( + standalone_some_data( + last_timestamp=dlt.sources.incremental( + "item.timestamp", pendulum.now().timestamp(), min + ) + ) + ) assert len(values) == 10 # the minimum element assert values[0]["delta"] == -10 @@ -440,14 +484,22 @@ def some_data(step, last_timestamp=dlt.sources.incremental("item.ts")): else: # print(last_timestamp.initial_value) # print(now.add(days=step-1).timestamp()) - assert last_timestamp.start_value == last_timestamp.last_value == now.add(days=step-1).timestamp() + assert ( + last_timestamp.start_value + == last_timestamp.last_value + == now.add(days=step - 1).timestamp() + ) for i in range(-10, 10): yield {"delta": i, "item": {"ts": now.add(days=i).timestamp()}} # after all yielded if step == -10: assert last_timestamp.start_value is None else: - assert last_timestamp.start_value == now.add(days=step-1).timestamp() != last_timestamp.last_value + assert ( + last_timestamp.start_value + == now.add(days=step - 1).timestamp() + != last_timestamp.last_value + ) for i in range(-10, 10): r = some_data(i) @@ -469,6 +521,7 @@ def test_replace_resets_state() -> None: assert len(info.loads_ids) == 1 parent_r = standalone_some_data(now) + @dlt.transformer(data_from=parent_r, write_disposition="append") def child(item): state = resource_state("child") @@ -491,14 +544,15 @@ def child(item): assert len(info.loads_ids) == 1 info = p.run(s) # state was reset - assert 'child' not in s.state['resources'] + assert "child" not in s.state["resources"] # there will be a load package to reset the state but also a load package to update the child table - assert len(info.load_packages[0].jobs['completed_jobs']) == 2 - assert {job.job_file_info.table_name for job in info.load_packages[0].jobs['completed_jobs'] } == {"_dlt_pipeline_state", "child"} + assert len(info.load_packages[0].jobs["completed_jobs"]) == 2 + assert { + job.job_file_info.table_name for job in info.load_packages[0].jobs["completed_jobs"] + } == {"_dlt_pipeline_state", "child"} def test_incremental_as_transform() -> None: - now = pendulum.now().timestamp() @dlt.resource @@ -512,7 +566,9 @@ def some_data(): for i in range(-10, 10): yield {"delta": i, "item": {"ts": pendulum.now().add(days=i).timestamp()}} - r = some_data().add_step(dlt.sources.incremental("item.ts", initial_value=now, primary_key="delta")) + r = some_data().add_step( + dlt.sources.incremental("item.ts", initial_value=now, primary_key="delta") 
+ ) p = dlt.pipeline(pipeline_name=uniq_id()) info = p.run(r, destination="duckdb") assert len(info.loads_ids) == 1 @@ -543,12 +599,11 @@ def some_data(last_timestamp=dlt.sources.incremental("item.ts", primary_key=())) def test_apply_hints_incremental() -> None: - p = dlt.pipeline(pipeline_name=uniq_id()) @dlt.resource def some_data(created_at: Optional[dlt.sources.incremental] = None): - yield [1,2,3] + yield [1, 2, 3] # the incremental wrapper is created for a resource and the incremental value is provided via apply hints r = some_data() @@ -569,8 +624,8 @@ def some_data(created_at: Optional[dlt.sources.incremental] = None): assert r.state["incremental"]["$"]["last_value"] == 1 @dlt.resource - def some_data_w_default(created_at = dlt.sources.incremental("$", last_value_func=min)): - yield [1,2,3] + def some_data_w_default(created_at=dlt.sources.incremental("$", last_value_func=min)): + yield [1, 2, 3] # default is overridden by apply hints p = p.drop() @@ -595,12 +650,12 @@ def some_data_no_incremental(): def test_last_value_func_on_dict() -> None: - """Test last value which is a dictionary""" + def by_event_type(event): last_value = None if len(event) == 1: - item, = event + (item,) = event else: item, last_value = event @@ -609,12 +664,18 @@ def by_event_type(event): else: last_value = dict(last_value) item_type = item["type"] - last_value[item_type] = max(item["created_at"], last_value.get(item_type, "1970-01-01T00:00:00Z")) + last_value[item_type] = max( + item["created_at"], last_value.get(item_type, "1970-01-01T00:00:00Z") + ) return last_value - @dlt.resource(primary_key="id", table_name=lambda i: i['type']) - def _get_shuffled_events(last_created_at = dlt.sources.incremental("$", last_value_func=by_event_type)): - with open("tests/normalize/cases/github.events.load_page_1_duck.json", "r", encoding="utf-8") as f: + @dlt.resource(primary_key="id", table_name=lambda i: i["type"]) + def _get_shuffled_events( + last_created_at=dlt.sources.incremental("$", last_value_func=by_event_type) + ): + with open( + "tests/normalize/cases/github.events.load_page_1_duck.json", "r", encoding="utf-8" + ) as f: yield json.load(f) with Container().injectable_context(StateInjectableContext(state={})): @@ -638,8 +699,15 @@ def test_timezone_naive_datetime() -> None: pendulum_start_dt = pendulum.instance(start_dt) # With timezone @dlt.resource - def some_data(updated_at: dlt.sources.incremental[pendulum.DateTime] = dlt.sources.incremental('updated_at', pendulum_start_dt)): - yield [{'updated_at': start_dt + timedelta(hours=1)}, {'updated_at': start_dt + timedelta(hours=2)}] + def some_data( + updated_at: dlt.sources.incremental[pendulum.DateTime] = dlt.sources.incremental( + "updated_at", pendulum_start_dt + ) + ): + yield [ + {"updated_at": start_dt + timedelta(hours=1)}, + {"updated_at": start_dt + timedelta(hours=2)}, + ] pipeline = dlt.pipeline(pipeline_name=uniq_id()) pipeline.extract(some_data()) @@ -647,19 +715,21 @@ def some_data(updated_at: dlt.sources.incremental[pendulum.DateTime] = dlt.sourc @dlt.resource def endless_sequence( - updated_at: dlt.sources.incremental[int] = dlt.sources.incremental('updated_at', initial_value=1) + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( + "updated_at", initial_value=1 + ) ) -> Any: max_values = 20 start = updated_at.last_value for i in range(start, start + max_values): - yield {'updated_at': i} + yield {"updated_at": i} def test_chunked_ranges() -> None: """Load chunked ranges with end value along with incremental""" - pipeline = 
dlt.pipeline(pipeline_name='incremental_' + uniq_id(), destination='duckdb') + pipeline = dlt.pipeline(pipeline_name="incremental_" + uniq_id(), destination="duckdb") chunks = [ # Load some start/end ranges in and out of order @@ -678,76 +748,99 @@ def test_chunked_ranges() -> None: for start, end in chunks: pipeline.run( - endless_sequence(updated_at=dlt.sources.incremental(initial_value=start, end_value=end)), - write_disposition='append' + endless_sequence( + updated_at=dlt.sources.incremental(initial_value=start, end_value=end) + ), + write_disposition="append", ) - expected_range = list(chain( - range(10, 20), - range(20, 30), - range(40, 50), - range(50, 60), - range(60, 61), - range(62, 70), - range(70, 89), - range(89, 109), - )) + expected_range = list( + chain( + range(10, 20), + range(20, 30), + range(40, 50), + range(50, 60), + range(60, 61), + range(62, 70), + range(70, 89), + range(89, 109), + ) + ) with pipeline.sql_client() as client: - items = [row[0] for row in client.execute_sql("SELECT updated_at FROM endless_sequence ORDER BY updated_at")] + items = [ + row[0] + for row in client.execute_sql( + "SELECT updated_at FROM endless_sequence ORDER BY updated_at" + ) + ] assert items == expected_range def test_end_value_with_batches() -> None: """Ensure incremental with end_value works correctly when resource yields lists instead of single items""" + @dlt.resource def batched_sequence( - updated_at: dlt.sources.incremental[int] = dlt.sources.incremental('updated_at', initial_value=1) + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( + "updated_at", initial_value=1 + ) ) -> Any: start = updated_at.last_value - yield [{'updated_at': i} for i in range(start, start + 12)] - yield [{'updated_at': i} for i in range(start+12, start + 20)] + yield [{"updated_at": i} for i in range(start, start + 12)] + yield [{"updated_at": i} for i in range(start + 12, start + 20)] - pipeline = dlt.pipeline(pipeline_name='incremental_' + uniq_id(), destination='duckdb') + pipeline = dlt.pipeline(pipeline_name="incremental_" + uniq_id(), destination="duckdb") pipeline.run( batched_sequence(updated_at=dlt.sources.incremental(initial_value=1, end_value=10)), - write_disposition='append' + write_disposition="append", ) with pipeline.sql_client() as client: - items = [row[0] for row in client.execute_sql("SELECT updated_at FROM batched_sequence ORDER BY updated_at")] + items = [ + row[0] + for row in client.execute_sql( + "SELECT updated_at FROM batched_sequence ORDER BY updated_at" + ) + ] assert items == list(range(1, 10)) pipeline.run( batched_sequence(updated_at=dlt.sources.incremental(initial_value=10, end_value=14)), - write_disposition='append' + write_disposition="append", ) with pipeline.sql_client() as client: - items = [row[0] for row in client.execute_sql("SELECT updated_at FROM batched_sequence ORDER BY updated_at")] + items = [ + row[0] + for row in client.execute_sql( + "SELECT updated_at FROM batched_sequence ORDER BY updated_at" + ) + ] assert items == list(range(1, 14)) def test_load_with_end_value_does_not_write_state() -> None: - """When loading chunk with initial/end value range. The resource state is untouched. - """ - pipeline = dlt.pipeline(pipeline_name='incremental_' + uniq_id(), destination='duckdb') + """When loading chunk with initial/end value range. 
The resource state is untouched.""" + pipeline = dlt.pipeline(pipeline_name="incremental_" + uniq_id(), destination="duckdb") - pipeline.extract(endless_sequence(updated_at=dlt.sources.incremental(initial_value=20, end_value=30))) + pipeline.extract( + endless_sequence(updated_at=dlt.sources.incremental(initial_value=20, end_value=30)) + ) - assert pipeline.state.get('sources') is None + assert pipeline.state.get("sources") is None def test_end_value_initial_value_errors() -> None: @dlt.resource def some_data( - updated_at: dlt.sources.incremental[int] = dlt.sources.incremental('updated_at') + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental("updated_at"), ) -> Any: - yield {'updated_at': 1} + yield {"updated_at": 1} # end_value without initial_value with pytest.raises(ConfigurationValueError) as ex: @@ -759,32 +852,54 @@ def some_data( with pytest.raises(ConfigurationValueError) as ex: list(some_data(updated_at=dlt.sources.incremental(initial_value=42, end_value=22))) - assert str(ex.value).startswith("Incremental 'initial_value' (42) is higher than 'end_value` (22)") + assert str(ex.value).startswith( + "Incremental 'initial_value' (42) is higher than 'end_value` (22)" + ) # max function and end_value higher than initial_value with pytest.raises(ConfigurationValueError) as ex: - list(some_data(updated_at=dlt.sources.incremental(initial_value=22, end_value=42, last_value_func=min))) + list( + some_data( + updated_at=dlt.sources.incremental( + initial_value=22, end_value=42, last_value_func=min + ) + ) + ) - assert str(ex.value).startswith("Incremental 'initial_value' (22) is lower than 'end_value` (42).") + assert str(ex.value).startswith( + "Incremental 'initial_value' (22) is lower than 'end_value` (42)." + ) def custom_last_value(items): # type: ignore[no-untyped-def] return max(items) # custom function which evaluates end_value lower than initial with pytest.raises(ConfigurationValueError) as ex: - list(some_data(updated_at=dlt.sources.incremental(initial_value=42, end_value=22, last_value_func=custom_last_value))) + list( + some_data( + updated_at=dlt.sources.incremental( + initial_value=42, end_value=22, last_value_func=custom_last_value + ) + ) + ) - assert "The result of 'custom_last_value([end_value, initial_value])' must equal 'end_value'" in str(ex.value) + assert ( + "The result of 'custom_last_value([end_value, initial_value])' must equal 'end_value'" + in str(ex.value) + ) def test_out_of_range_flags() -> None: """Test incremental.start_out_of_range / end_out_of_range flags are set when items are filtered out""" + @dlt.resource def descending( - updated_at: dlt.sources.incremental[int] = dlt.sources.incremental('updated_at', initial_value=10) + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( + "updated_at", initial_value=10 + ) ) -> Any: for chunk in chunks(list(reversed(range(48))), 10): - yield [{'updated_at': i} for i in chunk] + yield [{"updated_at": i} for i in chunk] # Assert flag is set only on the first item < initial_value if all(item > 9 for item in chunk): assert updated_at.start_out_of_range is False @@ -794,10 +909,12 @@ def descending( @dlt.resource def ascending( - updated_at: dlt.sources.incremental[int] = dlt.sources.incremental('updated_at', initial_value=22, end_value=45) + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( + "updated_at", initial_value=22, end_value=45 + ) ) -> Any: for chunk in chunks(list(range(22, 500)), 10): - yield [{'updated_at': i} for i in chunk] + yield [{"updated_at": i} for 
i in chunk] # Flag is set only when end_value is reached if all(item < 45 for item in chunk): assert updated_at.end_out_of_range is False @@ -805,13 +922,14 @@ def ascending( assert updated_at.end_out_of_range is True return - @dlt.resource def descending_single_item( - updated_at: dlt.sources.incremental[int] = dlt.sources.incremental('updated_at', initial_value=10) + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( + "updated_at", initial_value=10 + ) ) -> Any: for i in reversed(range(14)): - yield {'updated_at': i} + yield {"updated_at": i} if i >= 10: assert updated_at.start_out_of_range is False else: @@ -820,17 +938,19 @@ def descending_single_item( @dlt.resource def ascending_single_item( - updated_at: dlt.sources.incremental[int] = dlt.sources.incremental('updated_at', initial_value=10, end_value=22) + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( + "updated_at", initial_value=10, end_value=22 + ) ) -> Any: for i in range(10, 500): - yield {'updated_at': i} + yield {"updated_at": i} if i < 22: assert updated_at.end_out_of_range is False else: assert updated_at.end_out_of_range is True return - pipeline = dlt.pipeline(pipeline_name='incremental_' + uniq_id(), destination='duckdb') + pipeline = dlt.pipeline(pipeline_name="incremental_" + uniq_id(), destination="duckdb") pipeline.extract(descending()) @@ -846,13 +966,23 @@ def test_get_incremental_value_type() -> None: assert dlt.sources.incremental("id", initial_value=0).get_incremental_value_type() is int assert dlt.sources.incremental("id", initial_value=None).get_incremental_value_type() is Any assert dlt.sources.incremental[int]("id").get_incremental_value_type() is int - assert dlt.sources.incremental[pendulum.DateTime]("id").get_incremental_value_type() is pendulum.DateTime + assert ( + dlt.sources.incremental[pendulum.DateTime]("id").get_incremental_value_type() + is pendulum.DateTime + ) # typing has precedence - assert dlt.sources.incremental[pendulum.DateTime]("id", initial_value=1).get_incremental_value_type() is pendulum.DateTime + assert ( + dlt.sources.incremental[pendulum.DateTime]( + "id", initial_value=1 + ).get_incremental_value_type() + is pendulum.DateTime + ) # pass default value @dlt.resource - def test_type(updated_at = dlt.sources.incremental[str]("updated_at", allow_external_schedulers=True)): # noqa: B008 + def test_type( + updated_at=dlt.sources.incremental[str]("updated_at", allow_external_schedulers=True) + ): # noqa: B008 yield [{"updated_at": d} for d in [1, 2, 3]] r = test_type() @@ -861,7 +991,11 @@ def test_type(updated_at = dlt.sources.incremental[str]("updated_at", allow_exte # use annotation @dlt.resource - def test_type_2(updated_at: dlt.sources.incremental[int] = dlt.sources.incremental("updated_at", allow_external_schedulers=True)): + def test_type_2( + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( + "updated_at", allow_external_schedulers=True + ) + ): yield [{"updated_at": d} for d in [1, 2, 3]] r = test_type_2() @@ -879,7 +1013,9 @@ def test_type_3(updated_at: dlt.sources.incremental): # pass explicit value overriding default that is typed @dlt.resource - def test_type_4(updated_at = dlt.sources.incremental("updated_at", allow_external_schedulers=True)): + def test_type_4( + updated_at=dlt.sources.incremental("updated_at", allow_external_schedulers=True) + ): yield [{"updated_at": d} for d in [1, 2, 3]] r = test_type_4(dlt.sources.incremental[str]("updated_at", allow_external_schedulers=True)) @@ -888,7 +1024,9 @@ def 
test_type_4(updated_at = dlt.sources.incremental("updated_at", allow_externa # no generic type information @dlt.resource - def test_type_5(updated_at = dlt.sources.incremental("updated_at", allow_external_schedulers=True)): + def test_type_5( + updated_at=dlt.sources.incremental("updated_at", allow_external_schedulers=True) + ): yield [{"updated_at": d} for d in [1, 2, 3]] r = test_type_5(dlt.sources.incremental("updated_at")) @@ -898,27 +1036,35 @@ def test_type_5(updated_at = dlt.sources.incremental("updated_at", allow_externa def test_join_env_scheduler() -> None: @dlt.resource - def test_type_2(updated_at: dlt.sources.incremental[int] = dlt.sources.incremental("updated_at", allow_external_schedulers=True)): + def test_type_2( + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( + "updated_at", allow_external_schedulers=True + ) + ): yield [{"updated_at": d} for d in [1, 2, 3]] - assert list(test_type_2()) == [{'updated_at': 1}, {'updated_at': 2}, {'updated_at': 3}] + assert list(test_type_2()) == [{"updated_at": 1}, {"updated_at": 2}, {"updated_at": 3}] # set start and end values os.environ["DLT_START_VALUE"] = "2" - assert list(test_type_2()) == [{'updated_at': 2}, {'updated_at': 3}] + assert list(test_type_2()) == [{"updated_at": 2}, {"updated_at": 3}] os.environ["DLT_END_VALUE"] = "3" - assert list(test_type_2()) == [{'updated_at': 2}] + assert list(test_type_2()) == [{"updated_at": 2}] def test_join_env_scheduler_pipeline() -> None: @dlt.resource - def test_type_2(updated_at: dlt.sources.incremental[int] = dlt.sources.incremental("updated_at", allow_external_schedulers=True)): + def test_type_2( + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( + "updated_at", allow_external_schedulers=True + ) + ): yield [{"updated_at": d} for d in [1, 2, 3]] - pip_1_name = 'incremental_' + uniq_id() - pipeline = dlt.pipeline(pipeline_name=pip_1_name, destination='duckdb') + pip_1_name = "incremental_" + uniq_id() + pipeline = dlt.pipeline(pipeline_name=pip_1_name, destination="duckdb") r = test_type_2() - r.add_step(AssertItems([{'updated_at': 2}, {'updated_at': 3}])) + r.add_step(AssertItems([{"updated_at": 2}, {"updated_at": 3}])) os.environ["DLT_START_VALUE"] = "2" pipeline.extract(r) # state is saved next extract has no items @@ -929,17 +1075,19 @@ def test_type_2(updated_at: dlt.sources.incremental[int] = dlt.sources.increment # setting end value will stop using state os.environ["DLT_END_VALUE"] = "3" r = test_type_2() - r.add_step(AssertItems([{'updated_at': 2}])) + r.add_step(AssertItems([{"updated_at": 2}])) pipeline.extract(r) r = test_type_2() os.environ["DLT_START_VALUE"] = "1" - r.add_step(AssertItems([{'updated_at': 1}, {'updated_at': 2}])) + r.add_step(AssertItems([{"updated_at": 1}, {"updated_at": 2}])) pipeline.extract(r) def test_allow_external_schedulers() -> None: @dlt.resource() - def test_type_2(updated_at: dlt.sources.incremental[int] = dlt.sources.incremental("updated_at")): + def test_type_2( + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental("updated_at"), + ): yield [{"updated_at": d} for d in [1, 2, 3]] # does not participate diff --git a/tests/extract/test_sources.py b/tests/extract/test_sources.py index 36e7737415..f8ab7b4e49 100644 --- a/tests/extract/test_sources.py +++ b/tests/extract/test_sources.py @@ -1,4 +1,5 @@ import itertools + import pytest import dlt @@ -7,10 +8,17 @@ from dlt.common.pipeline import StateInjectableContext, source_state from dlt.common.schema import Schema from dlt.common.typing import 
TDataItems -from dlt.extract.exceptions import InvalidParentResourceDataType, InvalidParentResourceIsAFunction, InvalidTransformerDataTypeGeneratorFunctionRequired, InvalidTransformerGeneratorFunction, ParametrizedResourceUnbound, ResourcesNotFoundError +from dlt.extract.exceptions import ( + InvalidParentResourceDataType, + InvalidParentResourceIsAFunction, + InvalidTransformerDataTypeGeneratorFunctionRequired, + InvalidTransformerGeneratorFunction, + ParametrizedResourceUnbound, + ResourcesNotFoundError, +) from dlt.extract.pipe import Pipe -from dlt.extract.typing import FilterItem, MapItem from dlt.extract.source import DltResource, DltSource +from dlt.extract.typing import FilterItem, MapItem def test_call_data_resource() -> None: @@ -19,8 +27,7 @@ def test_call_data_resource() -> None: def test_parametrized_resource() -> None: - - def parametrized(p1, /, p2, *, p3 = None): + def parametrized(p1, /, p2, *, p3=None): assert p1 == "p1" assert p2 == 1 assert p3 is None @@ -65,8 +72,7 @@ def parametrized(p1, /, p2, *, p3 = None): def test_parametrized_transformer() -> None: - - def good_transformer(item, /, p1, p2, *, p3 = None): + def good_transformer(item, /, p1, p2, *, p3=None): assert p1 == "p1" assert p2 == 2 assert p3 is None @@ -134,9 +140,9 @@ def bad_transformer_3(*, item): def assert_items(_items: TDataItems) -> None: # 2 items yielded * p2=2 - assert len(_items) == 2*2 - assert _items[0] == {'wrap': 'itemX', 'mark': 'p1', 'iter': 0} - assert _items[3] == {'wrap': 'itemY', 'mark': 'p1', 'iter': 1} + assert len(_items) == 2 * 2 + assert _items[0] == {"wrap": "itemX", "mark": "p1", "iter": 0} + assert _items[3] == {"wrap": "itemY", "mark": "p1", "iter": 1} assert_items(items) @@ -148,7 +154,6 @@ def assert_items(_items: TDataItems) -> None: def test_resource_bind_when_in_source() -> None: - @dlt.resource def parametrized(_range: int): yield list(range(_range)) @@ -185,7 +190,6 @@ def test_source(): def test_resource_bind_call_forms() -> None: - @dlt.resource def returns_res(_input): # resource returning resource @@ -228,7 +232,6 @@ def regular(_input): b_returns_pipe = returns_pipe("ABCA") assert len(b_returns_pipe._pipe) == 1 - @dlt.source def test_source(): return returns_res, returns_pipe, regular @@ -241,7 +244,7 @@ def test_source(): assert s.regular._pipe is not regular._pipe # will repeat each string 3 times - s.regular.add_map(lambda i: i*3) + s.regular.add_map(lambda i: i * 3) assert len(regular._pipe) == 2 assert len(s.regular._pipe) == 3 @@ -252,14 +255,14 @@ def test_source(): assert list(s.regular) == ["AAA", "AAA", "AAA"] # binding resource that returns resource will replace the object content, keeping the object id - s.returns_res.add_map(lambda i: i*3) + s.returns_res.add_map(lambda i: i * 3) s.returns_res.bind(["X", "Y", "Z"]) # got rid of all mapping and filter functions assert len(s.returns_res._pipe) == 1 assert list(s.returns_res) == ["X", "Y", "Z"] # same for resource returning pipe - s.returns_pipe.add_map(lambda i: i*3) + s.returns_pipe.add_map(lambda i: i * 3) s.returns_pipe.bind(["X", "Y", "M"]) # got rid of all mapping and filter functions assert len(s.returns_pipe._pipe) == 1 @@ -267,12 +270,11 @@ def test_source(): # s.regular is exhausted so set it again # add lambda that after filtering for A, will multiply it by 4 - s.resources["regular"] = regular.add_map(lambda i: i*4)(["A", "Y"]) - assert list(s) == ['X', 'Y', 'Z', 'X', 'Y', 'M', 'AAAA'] + s.resources["regular"] = regular.add_map(lambda i: i * 4)(["A", "Y"]) + assert list(s) == ["X", "Y", "Z", 
"X", "Y", "M", "AAAA"] def test_call_clone_separate_pipe() -> None: - all_yields = [] def some_data_gen(param: str): @@ -293,14 +295,13 @@ def some_data(param: str): def test_resource_bind_lazy_eval() -> None: - @dlt.resource def needs_param(param): yield from range(param) @dlt.transformer(data_from=needs_param(3)) def tx_form(item, multi): - yield item*multi + yield item * multi @dlt.transformer(data_from=tx_form(2)) def tx_form_fin(item, div): @@ -308,7 +309,7 @@ def tx_form_fin(item, div): @dlt.transformer(data_from=needs_param) def tx_form_dir(item, multi): - yield item*multi + yield item * multi # tx_form takes data from needs_param(3) which is lazily evaluated assert list(tx_form(2)) == [0, 2, 4] @@ -316,8 +317,8 @@ def tx_form_dir(item, multi): assert list(tx_form(2)) == [0, 2, 4] # same for tx_form_fin - assert list(tx_form_fin(3)) == [0, 2/3, 4/3] - assert list(tx_form_fin(3)) == [0, 2/3, 4/3] + assert list(tx_form_fin(3)) == [0, 2 / 3, 4 / 3] + assert list(tx_form_fin(3)) == [0, 2 / 3, 4 / 3] # binding `needs_param`` in place will not affect the tx_form and tx_form_fin (they operate on copies) needs_param.bind(4) @@ -331,7 +332,6 @@ def tx_form_dir(item, multi): def test_transformer_preliminary_step() -> None: - def yield_twice(item): yield item.upper() yield item.upper() @@ -340,17 +340,23 @@ def yield_twice(item): # filter out small caps and insert this before the head tx_stage.add_filter(FilterItem(lambda letter: letter.isupper()), 0) # be got filtered out before duplication - assert list(dlt.resource(["A", "b", "C"], name="data") | tx_stage) == ['A', 'A', 'C', 'C'] + assert list(dlt.resource(["A", "b", "C"], name="data") | tx_stage) == ["A", "A", "C", "C"] # filter after duplication tx_stage = dlt.transformer()(yield_twice)() tx_stage.add_filter(FilterItem(lambda letter: letter.isupper())) # nothing is filtered out: on duplicate we also capitalize so filter does not trigger - assert list(dlt.resource(["A", "b", "C"], name="data") | tx_stage) == ['A', 'A', 'B', 'B', 'C', 'C'] + assert list(dlt.resource(["A", "b", "C"], name="data") | tx_stage) == [ + "A", + "A", + "B", + "B", + "C", + "C", + ] def test_select_resources() -> None: - @dlt.source def test_source(no_resources): for i in range(no_resources): @@ -376,7 +382,11 @@ def test_source(no_resources): s_sel = s.with_resources("resource_1", "resource_7") # returns a clone assert s is not s_sel - assert list(s_sel.selected_resources) == ["resource_1", "resource_7"] == list(s_sel.resources.selected) + assert ( + list(s_sel.selected_resources) + == ["resource_1", "resource_7"] + == list(s_sel.resources.selected) + ) assert list(s_sel.resources) == all_resource_names info = str(s_sel) assert "resource resource_0 is not selected" in info @@ -394,7 +404,6 @@ def test_source(no_resources): def test_clone_source() -> None: @dlt.source def test_source(no_resources): - def _gen(i): yield "A" * i @@ -413,7 +422,7 @@ def _gen(i): # but we keep pipe ids assert s.resources[name]._pipe._pipe_id == clone_s.resources[name]._pipe._pipe_id - assert list(s) == ['', 'A', 'AA', 'AAA'] + assert list(s) == ["", "A", "AA", "AAA"] # we expired generators assert list(clone_s) == [] @@ -421,7 +430,6 @@ def _gen(i): @dlt.source def test_source(no_resources): - def _gen(i): yield "A" * i @@ -436,15 +444,13 @@ def _gen(i): clone_s.resources[name].bind(idx) # now thanks to late eval both sources evaluate separately - assert list(s) == ['', 'A', 'AA', 'AAA'] - assert list(clone_s) == ['', 'A', 'AA', 'AAA'] + assert list(s) == ["", "A", "AA", "AAA"] + assert 
list(clone_s) == ["", "A", "AA", "AAA"] def test_multiple_parametrized_transformers() -> None: - @dlt.source def _source(test_set: int = 1): - @dlt.resource(selected=False) def _r1(): yield ["a", "b", "c"] @@ -455,7 +461,7 @@ def _t1(items, suffix): @dlt.transformer(data_from=_t1) def _t2(items, mul): - yield items*mul + yield items * mul if test_set == 1: return _r1, _t1, _t2 @@ -468,8 +474,7 @@ def _t2(items, mul): # true pipelining fun return _r1() | _t1("2") | _t2(2) - - expected_data = ['a_2', 'b_2', 'c_2', 'a_2', 'b_2', 'c_2'] + expected_data = ["a_2", "b_2", "c_2", "a_2", "b_2", "c_2"] # this s contains all resources s = _source(1) @@ -530,7 +535,6 @@ def _t2(items, mul): def test_extracted_resources_selector() -> None: @dlt.source def _source(test_set: int = 1): - @dlt.resource(selected=False, write_disposition="append") def _r1(): yield ["a", "b", "c"] @@ -541,7 +545,7 @@ def _t1(items, suffix): @dlt.transformer(data_from=_r1, write_disposition="merge") def _t2(items, mul): - yield items*mul + yield items * mul if test_set == 1: return _r1, _t1, _t2 @@ -579,10 +583,8 @@ def _t2(items, mul): def test_source_decompose() -> None: - @dlt.source def _source(): - @dlt.resource(selected=True) def _r_init(): yield ["-", "x", "!"] @@ -597,18 +599,18 @@ def _t1(items, suffix): @dlt.transformer(data_from=_r1) def _t2(items, mul): - yield items*mul + yield items * mul @dlt.transformer(data_from=_r1) def _t3(items, mul): for item in items: - yield item.upper()*mul + yield item.upper() * mul # add something to init @dlt.transformer(data_from=_r_init) def _t_init_post(items): for item in items: - yield item*2 + yield item * 2 @dlt.resource def _r_isolee(): @@ -631,7 +633,14 @@ def _r_isolee(): # keeps order of resources inside # here we didn't eliminate (_r_init, _r_init) as this not impacts decomposition, however this edge is not necessary - assert _source().resources.selected_dag == [("_r_init", "_r_init"), ("_r_init", "_t_init_post"), ('_r1', '_t1'), ('_r1', '_t2'), ('_r1', '_t3'), ('_r_isolee', '_r_isolee')] + assert _source().resources.selected_dag == [ + ("_r_init", "_r_init"), + ("_r_init", "_t_init_post"), + ("_r1", "_t1"), + ("_r1", "_t2"), + ("_r1", "_t3"), + ("_r_isolee", "_r_isolee"), + ] components = _source().decompose("scc") # first element contains _r_init assert "_r_init" in components[0].resources.selected.keys() @@ -660,10 +669,8 @@ def _r1(): assert "Bound DltResource" in str(py_ex.value) - @dlt.resource def res_in_res(table_name, w_d): - def _gen(s): yield from s @@ -671,7 +678,6 @@ def _gen(s): def test_resource_returning_resource() -> None: - @dlt.source def source_r_in_r(): yield res_in_res @@ -703,13 +709,19 @@ def test_source(no_resources): def test_add_transform_steps() -> None: # add all step types, using indexes. 
final steps # gen -> map that converts to str and multiplies character -> filter str of len 2 -> yield all characters in str separately - r = dlt.resource([1, 2, 3, 4], name="all").add_limit(3).add_yield_map(lambda i: (yield from i)).add_map(lambda i: str(i) * i, 1).add_filter(lambda i: len(i) == 2, 2) + r = ( + dlt.resource([1, 2, 3, 4], name="all") + .add_limit(3) + .add_yield_map(lambda i: (yield from i)) + .add_map(lambda i: str(i) * i, 1) + .add_filter(lambda i: len(i) == 2, 2) + ) assert list(r) == ["2", "2"] def test_add_transform_steps_pipe() -> None: r = dlt.resource([1, 2, 3], name="all") | (lambda i: str(i) * i) | (lambda i: (yield from i)) - assert list(r) == ['1', '2', '2', '3', '3', '3'] + assert list(r) == ["1", "2", "2", "3", "3", "3"] def test_limit_infinite_counter() -> None: @@ -718,7 +730,6 @@ def test_limit_infinite_counter() -> None: def test_limit_source() -> None: - def mul_c(item): yield from "A" * (item + 2) @@ -730,11 +741,10 @@ def infinite_source(): yield r | dlt.transformer(name=f"mul_c_{idx}")(mul_c) # transformer is not limited to 2 elements, infinite resource is, we have 3 resources - assert list(infinite_source().add_limit(2)) == ['A', 'A', 0, 'A', 'A', 'A', 1] * 3 + assert list(infinite_source().add_limit(2)) == ["A", "A", 0, "A", "A", "A", 1] * 3 def test_source_state() -> None: - @dlt.source def test_source(expected_state): assert source_state() == expected_state @@ -744,17 +754,16 @@ def test_source(expected_state): test_source({}).state dlt.pipeline(full_refresh=True) - assert test_source({}).state == {} + assert test_source({}).state == {} # inject state to see if what we write in state is there with Container().injectable_context(StateInjectableContext(state={})) as state: test_source({}).state["value"] = 1 test_source({"value": 1}) - assert state.state == {'sources': {'test_source': {'value': 1}}} + assert state.state == {"sources": {"test_source": {"value": 1}}} def test_resource_state() -> None: - @dlt.resource def test_resource(): yield [1, 2, 3] @@ -785,10 +794,14 @@ def test_source(): # resource section is current module print(state.state) # the resource that is a part of the source will create a resource state key in the source state key - assert state.state["sources"]["schema_section"] == {'resources': {'test_resource': {'in-source': True}}} - assert s.state == {'resources': {'test_resource': {'in-source': True}}} + assert state.state["sources"]["schema_section"] == { + "resources": {"test_resource": {"in-source": True}} + } + assert s.state == {"resources": {"test_resource": {"in-source": True}}} # the standalone resource will create key which is default schema name - assert state.state["sources"][p._make_schema_with_default_name().name] == {'resources': {'test_resource': {'direct': True}}} + assert state.state["sources"][p._make_schema_with_default_name().name] == { + "resources": {"test_resource": {"direct": True}} + } # def test_add_resources_to_source_simple() -> None: @@ -805,7 +818,6 @@ def test_resource_dict() -> None: def test_source_multiple_iterations() -> None: - def some_data(): yield [1, 2, 3] yield [1, 2, 3] @@ -820,23 +832,33 @@ def some_data(): def test_exhausted_property() -> None: - # this example will be exhausted after iteration def open_generator_data(): yield from [1, 2, 3, 4] + s = DltSource("source", "module", Schema("source"), [dlt.resource(open_generator_data())]) assert s.exhausted is False assert next(iter(s)) == 1 assert s.exhausted is True # lists will not exhaust - s = DltSource("source", "module", 
Schema("source"), [dlt.resource([1, 2, 3, 4], table_name="table", name="resource")]) + s = DltSource( + "source", + "module", + Schema("source"), + [dlt.resource([1, 2, 3, 4], table_name="table", name="resource")], + ) assert s.exhausted is False assert next(iter(s)) == 1 assert s.exhausted is False # iterators will not exhaust - s = DltSource("source", "module", Schema("source"), [dlt.resource(iter([1, 2, 3, 4]), table_name="table", name="resource")]) + s = DltSource( + "source", + "module", + Schema("source"), + [dlt.resource(iter([1, 2, 3, 4]), table_name="table", name="resource")], + ) assert s.exhausted is False assert next(iter(s)) == 1 assert s.exhausted is False @@ -844,21 +866,30 @@ def open_generator_data(): # having on exhausted generator resource will make the whole source exhausted def open_generator_data(): yield from [1, 2, 3, 4] - s = DltSource("source", "module", Schema("source"), [ dlt.resource([1, 2, 3, 4], table_name="table", name="resource"), dlt.resource(open_generator_data())]) + + s = DltSource( + "source", + "module", + Schema("source"), + [ + dlt.resource([1, 2, 3, 4], table_name="table", name="resource"), + dlt.resource(open_generator_data()), + ], + ) assert s.exhausted is False # execute the whole source list(s) assert s.exhausted is True - # source with transformers also exhausts @dlt.source def mysource(): r = dlt.resource(itertools.count(start=1), name="infinity").add_limit(5) yield r yield r | dlt.transformer(name="double")(lambda x: x * 2) + s = mysource() assert s.exhausted is False - assert next(iter(s)) == 2 # transformer is returned befor resource + assert next(iter(s)) == 2 # transformer is returned befor resource assert s.exhausted is True diff --git a/tests/extract/utils.py b/tests/extract/utils.py index 3cf7e6373c..cdce03c1f4 100644 --- a/tests/extract/utils.py +++ b/tests/extract/utils.py @@ -1,16 +1,23 @@ -from typing import Any, Optional -import pytest from itertools import zip_longest +from typing import Any, Optional -from dlt.common.typing import TDataItem, TDataItems, TAny +import pytest +from dlt.common.typing import TAny, TDataItem, TDataItems from dlt.extract.extract import ExtractorStorage from dlt.extract.typing import ItemTransform, ItemTransformFunc -def expect_extracted_file(storage: ExtractorStorage, schema_name: str, table_name: str, content: str) -> None: +def expect_extracted_file( + storage: ExtractorStorage, schema_name: str, table_name: str, content: str +) -> None: files = storage.list_files_to_normalize_sorted() - gen = (file for file in files if storage.get_schema_name(file) == schema_name and storage.parse_normalize_file_name(file).table_name == table_name) + gen = ( + file + for file in files + if storage.get_schema_name(file) == schema_name + and storage.parse_normalize_file_name(file).table_name == table_name + ) file = next(gen, None) if file is None: raise FileNotFoundError(storage.build_extracted_file_stem(schema_name, table_name, "***")) @@ -27,9 +34,9 @@ def expect_extracted_file(storage: ExtractorStorage, schema_name: str, table_nam class AssertItems(ItemTransform[TDataItem]): - def __init__(self, expected_items: Any) -> None: - self.expected_items = expected_items + def __init__(self, expected_items: Any) -> None: + self.expected_items = expected_items - def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]: + def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]: assert item == self.expected_items return item diff --git a/tests/helpers/airflow_tests/conftest.py 
b/tests/helpers/airflow_tests/conftest.py index 023aab88c2..4883319bd3 100644 --- a/tests/helpers/airflow_tests/conftest.py +++ b/tests/helpers/airflow_tests/conftest.py @@ -1,2 +1,2 @@ from tests.helpers.airflow_tests.utils import initialize_airflow_db -from tests.utils import preserve_environ, autouse_test_storage, TEST_STORAGE_ROOT, patch_home_dir \ No newline at end of file +from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage, patch_home_dir, preserve_environ diff --git a/tests/helpers/airflow_tests/test_airflow_provider.py b/tests/helpers/airflow_tests/test_airflow_provider.py index 447006932b..5d9eed2333 100644 --- a/tests/helpers/airflow_tests/test_airflow_provider.py +++ b/tests/helpers/airflow_tests/test_airflow_provider.py @@ -1,18 +1,18 @@ from airflow import DAG -from airflow.decorators import task, dag -from airflow.operators.python import PythonOperator -from airflow.models.variable import Variable +from airflow.decorators import dag, task from airflow.models.taskinstance import TaskInstance -from airflow.utils.state import State, DagRunState +from airflow.models.variable import Variable +from airflow.operators.python import PythonOperator +from airflow.utils.state import DagRunState, State from airflow.utils.types import DagRunType import dlt from dlt.common import pendulum from dlt.common.configuration.container import Container -from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext from dlt.common.configuration.providers.toml import SECRETS_TOML_KEY +from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext -DEFAULT_DATE = pendulum.datetime(2023, 4, 18, tz='Europe/Berlin') +DEFAULT_DATE = pendulum.datetime(2023, 4, 18, tz="Europe/Berlin") # Test data SECRETS_TOML_CONTENT = """ [sources] @@ -21,7 +21,6 @@ def test_airflow_secrets_toml_provider() -> None: - @dag(start_date=DEFAULT_DATE) def test_dag(): from dlt.common.configuration.providers.airflow import AirflowSecretsTomlProvider @@ -33,18 +32,17 @@ def test_dag(): @task() def test_task(): - provider = AirflowSecretsTomlProvider() - api_key, _ = provider.get_value('api_key', str, None, 'sources') + api_key, _ = provider.get_value("api_key", str, None, "sources") # There's no pytest context here in the task, so we need to return # the results as a dict and assert them in the test function. # See ti.xcom_pull() below. 
return { - 'name': provider.name, - 'supports_secrets': provider.supports_secrets, - 'api_key_from_provider': api_key, + "name": provider.name, + "supports_secrets": provider.supports_secrets, + "api_key_from_provider": api_key, } test_task() @@ -61,12 +59,12 @@ def test_task(): ti.run() # print(task_def.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)) - result = ti.xcom_pull(task_ids='test_task') + result = ti.xcom_pull(task_ids="test_task") assert ti.state == State.SUCCESS - assert result['name'] == 'Airflow Secrets TOML Provider' - assert result['supports_secrets'] - assert result['api_key_from_provider'] == 'test_value' + assert result["name"] == "Airflow Secrets TOML Provider" + assert result["supports_secrets"] + assert result["api_key_from_provider"] == "test_value" def test_airflow_secrets_toml_provider_import_dlt_dag() -> None: @@ -86,7 +84,7 @@ def test_dag(): @task() def test_task(): return { - 'api_key_from_provider': api_key, + "api_key_from_provider": api_key, } test_task() @@ -103,10 +101,10 @@ def test_task(): ti.run() # print(task_def.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)) - result = ti.xcom_pull(task_ids='test_task') + result = ti.xcom_pull(task_ids="test_task") assert ti.state == State.SUCCESS - assert result['api_key_from_provider'] == 'test_value' + assert result["api_key_from_provider"] == "test_value" def test_airflow_secrets_toml_provider_import_dlt_task() -> None: @@ -114,7 +112,6 @@ def test_airflow_secrets_toml_provider_import_dlt_task() -> None: @dag(start_date=DEFAULT_DATE) def test_dag(): - @task() def test_task(): Variable.set(SECRETS_TOML_KEY, SECRETS_TOML_CONTENT) @@ -125,7 +122,7 @@ def test_task(): api_key = secrets["sources.api_key"] return { - 'api_key_from_provider': api_key, + "api_key_from_provider": api_key, } test_task() @@ -142,14 +139,14 @@ def test_task(): ti.run() # print(task_def.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)) - result = ti.xcom_pull(task_ids='test_task') + result = ti.xcom_pull(task_ids="test_task") assert ti.state == State.SUCCESS - assert result['api_key_from_provider'] == 'test_value' + assert result["api_key_from_provider"] == "test_value" def test_airflow_secrets_toml_provider_is_loaded(): - dag = DAG(dag_id='test_dag', start_date=DEFAULT_DATE) + dag = DAG(dag_id="test_dag", start_date=DEFAULT_DATE) def test_task(): from dlt.common.configuration.providers.airflow import AirflowSecretsTomlProvider @@ -177,13 +174,11 @@ def test_task(): # the results as a dict and assert them in the test function. # See ti.xcom_pull() below. 
return { - 'airflow_secrets_toml_provider_is_loaded': astp_is_loaded, - 'api_key_from_provider': api_key, + "airflow_secrets_toml_provider_is_loaded": astp_is_loaded, + "api_key_from_provider": api_key, } - task = PythonOperator( - task_id='test_task', python_callable=test_task, dag=dag - ) + task = PythonOperator(task_id="test_task", python_callable=test_task, dag=dag) dag.create_dagrun( state=DagRunState.RUNNING, @@ -196,31 +191,31 @@ def test_task(): ti.run() - result = ti.xcom_pull(task_ids='test_task') + result = ti.xcom_pull(task_ids="test_task") assert ti.state == State.SUCCESS - assert result['airflow_secrets_toml_provider_is_loaded'] - assert result['api_key_from_provider'] == 'test_value' + assert result["airflow_secrets_toml_provider_is_loaded"] + assert result["api_key_from_provider"] == "test_value" def test_airflow_secrets_toml_provider_missing_variable(): - dag = DAG(dag_id='test_dag', start_date=DEFAULT_DATE) + dag = DAG(dag_id="test_dag", start_date=DEFAULT_DATE) def test_task(): - from dlt.common.configuration.specs import config_providers_context from dlt.common.configuration.providers.airflow import AirflowSecretsTomlProvider + from dlt.common.configuration.specs import config_providers_context # Make sure the variable is not set Variable.delete(SECRETS_TOML_KEY) providers = config_providers_context._extra_providers() - provider = next(provider for provider in providers if isinstance(provider, AirflowSecretsTomlProvider)) + provider = next( + provider for provider in providers if isinstance(provider, AirflowSecretsTomlProvider) + ) return { - 'airflow_secrets_toml': provider._toml.as_string(), + "airflow_secrets_toml": provider._toml.as_string(), } - task = PythonOperator( - task_id='test_task', python_callable=test_task, dag=dag - ) + task = PythonOperator(task_id="test_task", python_callable=test_task, dag=dag) dag.create_dagrun( state=DagRunState.RUNNING, @@ -233,20 +228,21 @@ def test_task(): ti.run() - result = ti.xcom_pull(task_ids='test_task') + result = ti.xcom_pull(task_ids="test_task") assert ti.state == State.SUCCESS - assert result['airflow_secrets_toml'] == "" + assert result["airflow_secrets_toml"] == "" def test_airflow_secrets_toml_provider_invalid_content(): - dag = DAG(dag_id='test_dag', start_date=DEFAULT_DATE) + dag = DAG(dag_id="test_dag", start_date=DEFAULT_DATE) def test_task(): import tomlkit + from dlt.common.configuration.providers.airflow import AirflowSecretsTomlProvider - Variable.set(SECRETS_TOML_KEY, 'invalid_content') + Variable.set(SECRETS_TOML_KEY, "invalid_content") # There's no pytest context here in the task, so we need # to catch the exception manually and return the result @@ -258,12 +254,10 @@ def test_task(): exception_raised = True return { - 'exception_raised': exception_raised, + "exception_raised": exception_raised, } - task = PythonOperator( - task_id='test_task', python_callable=test_task, dag=dag - ) + task = PythonOperator(task_id="test_task", python_callable=test_task, dag=dag) dag.create_dagrun( state=DagRunState.RUNNING, @@ -276,7 +270,7 @@ def test_task(): ti.run() - result = ti.xcom_pull(task_ids='test_task') + result = ti.xcom_pull(task_ids="test_task") assert ti.state == State.SUCCESS - assert result['exception_raised'] + assert result["exception_raised"] diff --git a/tests/helpers/airflow_tests/test_airflow_wrapper.py b/tests/helpers/airflow_tests/test_airflow_wrapper.py index da801b2c7b..e5db92c668 100644 --- a/tests/helpers/airflow_tests/test_airflow_wrapper.py +++ 
b/tests/helpers/airflow_tests/test_airflow_wrapper.py @@ -1,38 +1,36 @@ import os -import pytest from typing import List + +import pytest from airflow import DAG from airflow.decorators import dag -from airflow.operators.python import PythonOperator from airflow.models import TaskInstance +from airflow.operators.python import PythonOperator from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType +from tests.load.pipeline.utils import load_table_counts +from tests.utils import TEST_STORAGE_ROOT import dlt from dlt.common import pendulum from dlt.common.utils import uniq_id -from dlt.helpers.airflow_helper import PipelineTasksGroup, DEFAULT_RETRY_BACKOFF +from dlt.helpers.airflow_helper import DEFAULT_RETRY_BACKOFF, PipelineTasksGroup from dlt.pipeline.exceptions import CannotRestorePipelineException, PipelineStepFailed -from tests.load.pipeline.utils import load_table_counts -from tests.utils import TEST_STORAGE_ROOT - - -DEFAULT_DATE = pendulum.datetime(2023, 4, 18, tz='Europe/Berlin') +DEFAULT_DATE = pendulum.datetime(2023, 4, 18, tz="Europe/Berlin") default_args = { - 'owner': 'airflow', - 'depends_on_past': False, - 'email_on_failure': False, - 'email_on_retry': False, - 'retries': 0, - 'max_active_runs': 1 + "owner": "airflow", + "depends_on_past": False, + "email_on_failure": False, + "email_on_retry": False, + "retries": 0, + "max_active_runs": 1, } @dlt.source def mock_data_source(): - @dlt.resource(selected=True) def _r_init(): yield ["-", "x", "!"] @@ -47,18 +45,18 @@ def _t1(items, suffix): @dlt.transformer(data_from=_r1) def _t2(items, mul): - yield items*mul + yield items * mul @dlt.transformer(data_from=_r1) def _t3(items, mul): for item in items: - yield item.upper()*mul + yield item.upper() * mul # add something to init @dlt.transformer(data_from=_r_init) def _t_init_post(items): for item in items: - yield item*2 + yield item * 2 @dlt.resource def _r_isolee(): @@ -69,7 +67,6 @@ def _r_isolee(): @dlt.source(section="mock_data_source_state") def mock_data_source_state(): - @dlt.resource(selected=True) def _r_init(): dlt.current.source_state()["counter"] = 1 @@ -94,7 +91,7 @@ def _t2(items, mul): dlt.current.source_state()["counter"] += 1 dlt.current.resource_state("_r1")["counter"] += 1 dlt.current.resource_state()["counter"] = 1 - yield items*mul + yield items * mul @dlt.transformer(data_from=_r1) def _t3(items, mul): @@ -102,13 +99,13 @@ def _t3(items, mul): dlt.current.resource_state("_r1")["counter"] += 1 dlt.current.resource_state()["counter"] = 1 for item in items: - yield item.upper()*mul + yield item.upper() * mul # add something to init @dlt.transformer(data_from=_r_init) def _t_init_post(items): for item in items: - yield item*2 + yield item * 2 @dlt.resource def _r_isolee(): @@ -121,53 +118,83 @@ def _r_isolee(): def test_regular_run() -> None: # run the pipeline normally pipeline_standalone = dlt.pipeline( - pipeline_name="pipeline_standalone", dataset_name="mock_data_" + uniq_id(), destination="duckdb", credentials=":pipeline:") + pipeline_name="pipeline_standalone", + dataset_name="mock_data_" + uniq_id(), + destination="duckdb", + credentials=":pipeline:", + ) pipeline_standalone.run(mock_data_source()) - pipeline_standalone_counts = load_table_counts(pipeline_standalone, *[t["name"] for t in pipeline_standalone.default_schema.data_tables()]) + pipeline_standalone_counts = load_table_counts( + pipeline_standalone, *[t["name"] for t in pipeline_standalone.default_schema.data_tables()] + ) tasks_list: List[PythonOperator] = None - 
@dag( - schedule=None, - start_date=DEFAULT_DATE, - catchup=False, - default_args=default_args - ) + + @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) def dag_regular(): nonlocal tasks_list - tasks = PipelineTasksGroup("pipeline_dag_regular", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False) + tasks = PipelineTasksGroup( + "pipeline_dag_regular", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False + ) pipeline_dag_regular = dlt.pipeline( - pipeline_name="pipeline_dag_regular", dataset_name="mock_data_" + uniq_id(), destination="duckdb", credentials=":pipeline:") - tasks_list = tasks.add_run(pipeline_dag_regular, mock_data_source(), decompose="none", trigger_rule="all_done", retries=0, provide_context=True) + pipeline_name="pipeline_dag_regular", + dataset_name="mock_data_" + uniq_id(), + destination="duckdb", + credentials=":pipeline:", + ) + tasks_list = tasks.add_run( + pipeline_dag_regular, + mock_data_source(), + decompose="none", + trigger_rule="all_done", + retries=0, + provide_context=True, + ) dag_def: DAG = dag_regular() assert len(tasks_list) == 1 # composite task name - assert tasks_list[0].task_id == "pipeline_dag_regular.mock_data_source__r_init-_t_init_post-_t1-_t2-2-more" + assert ( + tasks_list[0].task_id + == "pipeline_dag_regular.mock_data_source__r_init-_t_init_post-_t1-_t2-2-more" + ) dag_def.test() # we should be able to attach to pipeline state created within Airflow pipeline_dag_regular = dlt.attach(pipeline_name="pipeline_dag_regular") - pipeline_dag_regular_counts = load_table_counts(pipeline_dag_regular, *[t["name"] for t in pipeline_dag_regular.default_schema.data_tables()]) + pipeline_dag_regular_counts = load_table_counts( + pipeline_dag_regular, + *[t["name"] for t in pipeline_dag_regular.default_schema.data_tables()], + ) # same data should be loaded assert pipeline_dag_regular_counts == pipeline_standalone_counts quackdb_path = os.path.join(TEST_STORAGE_ROOT, "pipeline_dag_decomposed.duckdb") - @dag( - schedule=None, - start_date=DEFAULT_DATE, - catchup=False, - default_args=default_args - ) + + @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) def dag_decomposed(): nonlocal tasks_list - tasks = PipelineTasksGroup("pipeline_dag_decomposed", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False) + tasks = PipelineTasksGroup( + "pipeline_dag_decomposed", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False + ) # set duckdb to be outside of pipeline folder which is dropped on each task pipeline_dag_decomposed = dlt.pipeline( - pipeline_name="pipeline_dag_decomposed", dataset_name="mock_data_" + uniq_id(), destination="duckdb", credentials=quackdb_path) - tasks_list = tasks.add_run(pipeline_dag_decomposed, mock_data_source(), decompose="serialize", trigger_rule="all_done", retries=0, provide_context=True) + pipeline_name="pipeline_dag_decomposed", + dataset_name="mock_data_" + uniq_id(), + destination="duckdb", + credentials=quackdb_path, + ) + tasks_list = tasks.add_run( + pipeline_dag_decomposed, + mock_data_source(), + decompose="serialize", + trigger_rule="all_done", + retries=0, + provide_context=True, + ) dag_def: DAG = dag_decomposed() assert len(tasks_list) == 3 @@ -177,7 +204,10 @@ def dag_decomposed(): assert tasks_list[2].task_id == "pipeline_dag_decomposed.mock_data_source__r_isolee" dag_def.test() pipeline_dag_decomposed = dlt.attach(pipeline_name="pipeline_dag_decomposed") - pipeline_dag_decomposed_counts = 
load_table_counts(pipeline_dag_decomposed, *[t["name"] for t in pipeline_dag_decomposed.default_schema.data_tables()]) + pipeline_dag_decomposed_counts = load_table_counts( + pipeline_dag_decomposed, + *[t["name"] for t in pipeline_dag_decomposed.default_schema.data_tables()], + ) assert pipeline_dag_decomposed_counts == pipeline_standalone_counts @@ -200,7 +230,6 @@ def dag_decomposed(): def test_run_with_retry() -> None: - retries = 2 now = pendulum.now() @@ -212,19 +241,22 @@ def _fail_3(): raise Exception(f"Failed on retry #{retries}") yield from "ABC" - @dag( - schedule=None, - start_date=DEFAULT_DATE, - catchup=False, - default_args=default_args - ) + @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) def dag_fail_3(): # by default we do not retry so this will fail - tasks = PipelineTasksGroup("pipeline_fail_3", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False) + tasks = PipelineTasksGroup( + "pipeline_fail_3", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False + ) pipeline_fail_3 = dlt.pipeline( - pipeline_name="pipeline_fail_3", dataset_name="mock_data_" + uniq_id(), destination="duckdb", credentials=":pipeline:") - tasks.add_run(pipeline_fail_3, _fail_3, trigger_rule="all_done", retries=0, provide_context=True) + pipeline_name="pipeline_fail_3", + dataset_name="mock_data_" + uniq_id(), + destination="duckdb", + credentials=":pipeline:", + ) + tasks.add_run( + pipeline_fail_3, _fail_3, trigger_rule="all_done", retries=0, provide_context=True + ) dag_def: DAG = dag_fail_3() ti = get_task_run(dag_def, "pipeline_fail_3.pipeline_fail_3", now) @@ -233,19 +265,25 @@ def dag_fail_3(): ti._run_raw_task() assert pip_ex.value.step == "extract" - @dag( - schedule=None, - start_date=DEFAULT_DATE, - catchup=False, - default_args=default_args - ) + @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) def dag_fail_4(): # by default we do not retry extract so we fail - tasks = PipelineTasksGroup("pipeline_fail_3", retry_policy=DEFAULT_RETRY_BACKOFF, local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False) + tasks = PipelineTasksGroup( + "pipeline_fail_3", + retry_policy=DEFAULT_RETRY_BACKOFF, + local_data_folder=TEST_STORAGE_ROOT, + wipe_local_data=False, + ) pipeline_fail_3 = dlt.pipeline( - pipeline_name="pipeline_fail_3", dataset_name="mock_data_" + uniq_id(), destination="duckdb", credentials=":pipeline:") - tasks.add_run(pipeline_fail_3, _fail_3, trigger_rule="all_done", retries=0, provide_context=True) + pipeline_name="pipeline_fail_3", + dataset_name="mock_data_" + uniq_id(), + destination="duckdb", + credentials=":pipeline:", + ) + tasks.add_run( + pipeline_fail_3, _fail_3, trigger_rule="all_done", retries=0, provide_context=True + ) dag_def: DAG = dag_fail_4() ti = get_task_run(dag_def, "pipeline_fail_3.pipeline_fail_3", now) @@ -255,19 +293,26 @@ def dag_fail_4(): ti._run_raw_task() assert pip_ex.value.step == "extract" - @dag( - schedule=None, - start_date=DEFAULT_DATE, - catchup=False, - default_args=default_args - ) + @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) def dag_fail_5(): # this will retry - tasks = PipelineTasksGroup("pipeline_fail_3", retry_policy=DEFAULT_RETRY_BACKOFF, retry_pipeline_steps=("load", "extract"), local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False) + tasks = PipelineTasksGroup( + "pipeline_fail_3", + retry_policy=DEFAULT_RETRY_BACKOFF, + retry_pipeline_steps=("load", "extract"), + local_data_folder=TEST_STORAGE_ROOT, + 
wipe_local_data=False, + ) pipeline_fail_3 = dlt.pipeline( - pipeline_name="pipeline_fail_3", dataset_name="mock_data_" + uniq_id(), destination="duckdb", credentials=":pipeline:") - tasks.add_run(pipeline_fail_3, _fail_3, trigger_rule="all_done", retries=0, provide_context=True) + pipeline_name="pipeline_fail_3", + dataset_name="mock_data_" + uniq_id(), + destination="duckdb", + credentials=":pipeline:", + ) + tasks.add_run( + pipeline_fail_3, _fail_3, trigger_rule="all_done", retries=0, provide_context=True + ) dag_def: DAG = dag_fail_5() ti = get_task_run(dag_def, "pipeline_fail_3.pipeline_fail_3", now) @@ -277,22 +322,30 @@ def dag_fail_5(): def test_run_decomposed_with_state_wipe() -> None: - dataset_name = "mock_data_" + uniq_id() pipeline_name = "pipeline_dag_regular_" + uniq_id() - @dag( - schedule=None, - start_date=DEFAULT_DATE, - catchup=False, - default_args=default_args - ) + @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) def dag_regular(): - tasks = PipelineTasksGroup(pipeline_name, local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=True, save_load_info=True, save_trace_info=True) + tasks = PipelineTasksGroup( + pipeline_name, + local_data_folder=TEST_STORAGE_ROOT, + wipe_local_data=True, + save_load_info=True, + save_trace_info=True, + ) pipeline_dag_regular = dlt.pipeline( - pipeline_name=pipeline_name, dataset_name=dataset_name, destination="duckdb") - tasks.add_run(pipeline_dag_regular, mock_data_source_state(), decompose="serialize", trigger_rule="all_done", retries=0, provide_context=True) + pipeline_name=pipeline_name, dataset_name=dataset_name, destination="duckdb" + ) + tasks.add_run( + pipeline_dag_regular, + mock_data_source_state(), + decompose="serialize", + trigger_rule="all_done", + retries=0, + provide_context=True, + ) dag_def: DAG = dag_regular() dag_def.test() @@ -302,7 +355,8 @@ def dag_regular(): dlt.attach(pipeline_name=pipeline_name) pipeline_dag_regular = dlt.pipeline( - pipeline_name=pipeline_name, dataset_name=dataset_name, destination="duckdb") + pipeline_name=pipeline_name, dataset_name=dataset_name, destination="duckdb" + ) pipeline_dag_regular.sync_destination() # print(pipeline_dag_regular.state) # now source can attach to state in the pipeline @@ -311,9 +365,9 @@ def dag_regular(): # end state was increased twice (in init and in isolee at the end) assert post_source.state["end_counter"] == 2 # the source counter was increased in init, _r1 and in 3 transformers * 3 items - assert post_source.state["counter"] == 1 + 1 + 3*3 + assert post_source.state["counter"] == 1 + 1 + 3 * 3 # resource counter _r1 - assert post_source._r1.state["counter"] == 1 + 3*3 + assert post_source._r1.state["counter"] == 1 + 3 * 3 # each transformer has a counter assert post_source._t1.state["counter"] == 1 assert post_source._t2.state["counter"] == 1 @@ -324,68 +378,114 @@ def test_run_multiple_sources() -> None: dataset_name = "mock_data_" + uniq_id() pipeline_name = "pipeline_dag_regular_" + uniq_id() - @dag( - schedule=None, - start_date=DEFAULT_DATE, - catchup=False, - default_args=default_args - ) + @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) def dag_serialize(): - tasks = PipelineTasksGroup(pipeline_name, local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=True) + tasks = PipelineTasksGroup( + pipeline_name, local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=True + ) pipeline_dag_regular = dlt.pipeline( - pipeline_name=pipeline_name, dataset_name=dataset_name, 
destination="duckdb") - st_tasks = tasks.add_run(pipeline_dag_regular, mock_data_source_state(), decompose="serialize", trigger_rule="all_done", retries=0, provide_context=True) - nst_tasks = tasks.add_run(pipeline_dag_regular, mock_data_source(), decompose="serialize", trigger_rule="all_done", retries=0, provide_context=True) + pipeline_name=pipeline_name, dataset_name=dataset_name, destination="duckdb" + ) + st_tasks = tasks.add_run( + pipeline_dag_regular, + mock_data_source_state(), + decompose="serialize", + trigger_rule="all_done", + retries=0, + provide_context=True, + ) + nst_tasks = tasks.add_run( + pipeline_dag_regular, + mock_data_source(), + decompose="serialize", + trigger_rule="all_done", + retries=0, + provide_context=True, + ) # connect end of first run to a head of a second st_tasks[-1] >> nst_tasks[0] - dag_def: DAG = dag_serialize() dag_def.test() pipeline_dag_serial = dlt.pipeline( - pipeline_name=pipeline_name, dataset_name=dataset_name, destination="duckdb") + pipeline_name=pipeline_name, dataset_name=dataset_name, destination="duckdb" + ) pipeline_dag_serial.sync_destination() # we should have two schemas - assert set(pipeline_dag_serial.schema_names) == {'mock_data_source_state', 'mock_data_source'} - counters_st_tasks = load_table_counts(pipeline_dag_serial, *[t["name"] for t in pipeline_dag_serial.schemas['mock_data_source_state'].data_tables()]) - counters_nst_tasks = load_table_counts(pipeline_dag_serial, *[t["name"] for t in pipeline_dag_serial.schemas['mock_data_source'].data_tables()]) + assert set(pipeline_dag_serial.schema_names) == {"mock_data_source_state", "mock_data_source"} + counters_st_tasks = load_table_counts( + pipeline_dag_serial, + *[t["name"] for t in pipeline_dag_serial.schemas["mock_data_source_state"].data_tables()], + ) + counters_nst_tasks = load_table_counts( + pipeline_dag_serial, + *[t["name"] for t in pipeline_dag_serial.schemas["mock_data_source"].data_tables()], + ) # print(counters_st_tasks) # print(counters_nst_tasks) # this state is confirmed in other test - assert pipeline_dag_serial.state["sources"]["mock_data_source_state"] == {'counter': 11, 'end_counter': 2, 'resources': {'_r1': {'counter': 10}, '_t3': {'counter': 1}, '_t2': {'counter': 1}, '_t1': {'counter': 1}}} + assert pipeline_dag_serial.state["sources"]["mock_data_source_state"] == { + "counter": 11, + "end_counter": 2, + "resources": { + "_r1": {"counter": 10}, + "_t3": {"counter": 1}, + "_t2": {"counter": 1}, + "_t1": {"counter": 1}, + }, + } # next DAG does not connect subgraphs dataset_name = "mock_data_" + uniq_id() pipeline_name = "pipeline_dag_regular_" + uniq_id() - @dag( - schedule=None, - start_date=DEFAULT_DATE, - catchup=False, - default_args=default_args - ) + @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) def dag_parallel(): - tasks = PipelineTasksGroup(pipeline_name, local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=True) + tasks = PipelineTasksGroup( + pipeline_name, local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=True + ) pipeline_dag_regular = dlt.pipeline( - pipeline_name=pipeline_name, dataset_name=dataset_name, destination="duckdb") - tasks.add_run(pipeline_dag_regular, mock_data_source_state(), decompose="serialize", trigger_rule="all_done", retries=0, provide_context=True) - tasks.add_run(pipeline_dag_regular, mock_data_source(), decompose="serialize", trigger_rule="all_done", retries=0, provide_context=True) + pipeline_name=pipeline_name, dataset_name=dataset_name, destination="duckdb" + ) + 
tasks.add_run( + pipeline_dag_regular, + mock_data_source_state(), + decompose="serialize", + trigger_rule="all_done", + retries=0, + provide_context=True, + ) + tasks.add_run( + pipeline_dag_regular, + mock_data_source(), + decompose="serialize", + trigger_rule="all_done", + retries=0, + provide_context=True, + ) # do not connect graph dag_def: DAG = dag_parallel() dag_def.test() pipeline_dag_parallel = dlt.pipeline( - pipeline_name=pipeline_name, dataset_name=dataset_name, destination="duckdb") + pipeline_name=pipeline_name, dataset_name=dataset_name, destination="duckdb" + ) pipeline_dag_parallel.sync_destination() # we should have two schemas - assert set(pipeline_dag_parallel.schema_names) == {'mock_data_source_state', 'mock_data_source'} - counters_st_tasks_par = load_table_counts(pipeline_dag_parallel, *[t["name"] for t in pipeline_dag_parallel.schemas['mock_data_source_state'].data_tables()]) - counters_nst_tasks_par = load_table_counts(pipeline_dag_parallel, *[t["name"] for t in pipeline_dag_parallel.schemas['mock_data_source'].data_tables()]) + assert set(pipeline_dag_parallel.schema_names) == {"mock_data_source_state", "mock_data_source"} + counters_st_tasks_par = load_table_counts( + pipeline_dag_parallel, + *[t["name"] for t in pipeline_dag_parallel.schemas["mock_data_source_state"].data_tables()], + ) + counters_nst_tasks_par = load_table_counts( + pipeline_dag_parallel, + *[t["name"] for t in pipeline_dag_parallel.schemas["mock_data_source"].data_tables()], + ) assert counters_st_tasks == counters_st_tasks_par assert counters_nst_tasks == counters_nst_tasks_par assert pipeline_dag_serial.state["sources"] == pipeline_dag_parallel.state["sources"] @@ -395,19 +495,31 @@ def dag_parallel(): dataset_name = "mock_data_" + uniq_id() pipeline_name = "pipeline_dag_regular_" + uniq_id() - @dag( - schedule=None, - start_date=DEFAULT_DATE, - catchup=False, - default_args=default_args - ) + @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) def dag_mixed(): - tasks = PipelineTasksGroup(pipeline_name, local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=True) + tasks = PipelineTasksGroup( + pipeline_name, local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=True + ) pipeline_dag_regular = dlt.pipeline( - pipeline_name=pipeline_name, dataset_name=dataset_name, destination="duckdb") - pd_tasks = tasks.add_run(pipeline_dag_regular, mock_data_source_state(), decompose="serialize", trigger_rule="all_done", retries=0, provide_context=True) - hb_tasks = tasks.add_run(pipeline_dag_regular, mock_data_source(), decompose="serialize", trigger_rule="all_done", retries=0, provide_context=True) + pipeline_name=pipeline_name, dataset_name=dataset_name, destination="duckdb" + ) + pd_tasks = tasks.add_run( + pipeline_dag_regular, + mock_data_source_state(), + decompose="serialize", + trigger_rule="all_done", + retries=0, + provide_context=True, + ) + hb_tasks = tasks.add_run( + pipeline_dag_regular, + mock_data_source(), + decompose="serialize", + trigger_rule="all_done", + retries=0, + provide_context=True, + ) # create almost randomly connected tasks across two runs for pd_t, hb_t in zip(pd_tasks, hb_tasks): pd_t >> hb_t @@ -416,12 +528,19 @@ def dag_mixed(): dag_def.test() pipeline_dag_mixed = dlt.pipeline( - pipeline_name=pipeline_name, dataset_name=dataset_name, destination="duckdb") + pipeline_name=pipeline_name, dataset_name=dataset_name, destination="duckdb" + ) pipeline_dag_mixed.sync_destination() # we should have two schemas - assert 
set(pipeline_dag_mixed.schema_names) == {'mock_data_source_state', 'mock_data_source'} - counters_st_tasks_par = load_table_counts(pipeline_dag_mixed, *[t["name"] for t in pipeline_dag_mixed.schemas['mock_data_source_state'].data_tables()]) - counters_nst_tasks_par = load_table_counts(pipeline_dag_mixed, *[t["name"] for t in pipeline_dag_mixed.schemas['mock_data_source'].data_tables()]) + assert set(pipeline_dag_mixed.schema_names) == {"mock_data_source_state", "mock_data_source"} + counters_st_tasks_par = load_table_counts( + pipeline_dag_mixed, + *[t["name"] for t in pipeline_dag_mixed.schemas["mock_data_source_state"].data_tables()], + ) + counters_nst_tasks_par = load_table_counts( + pipeline_dag_mixed, + *[t["name"] for t in pipeline_dag_mixed.schemas["mock_data_source"].data_tables()], + ) assert counters_st_tasks == counters_st_tasks_par assert counters_nst_tasks == counters_nst_tasks_par assert pipeline_dag_serial.state["sources"] == pipeline_dag_mixed.state["sources"] @@ -434,8 +553,8 @@ def get_task_run(dag_def: DAG, task_name: str, now: pendulum.DateTime) -> TaskIn state=DagRunState.RUNNING, execution_date=now, run_type=DagRunType.MANUAL, - data_interval=(now, now) + data_interval=(now, now), ) dag_def.run(start_date=now, run_at_least_once=True) task_def = dag_def.task_dict[task_name] - return TaskInstance(task=task_def, execution_date=now) \ No newline at end of file + return TaskInstance(task=task_def, execution_date=now) diff --git a/tests/helpers/airflow_tests/test_join_airflow_scheduler.py b/tests/helpers/airflow_tests/test_join_airflow_scheduler.py index 0dc31a89ce..316f848fe1 100644 --- a/tests/helpers/airflow_tests/test_join_airflow_scheduler.py +++ b/tests/helpers/airflow_tests/test_join_airflow_scheduler.py @@ -1,44 +1,50 @@ import datetime -from pendulum.tz import UTC + from airflow import DAG from airflow.decorators import dag, task from airflow.models import DagRun from airflow.models.taskinstance import TaskInstance from airflow.operators.python import get_current_context # noqa -from airflow.utils.state import State, DagRunState +from airflow.utils.state import DagRunState, State from airflow.utils.types import DagRunType +from pendulum.tz import UTC import dlt from dlt.common import pendulum -from dlt.common.utils import uniq_id from dlt.common.time import ensure_pendulum_date +from dlt.common.utils import uniq_id # flake8: noqa: B008 CATCHUP_BEGIN = pendulum.datetime(2023, 1, 1, tz="Europe/Berlin") default_args = { - 'owner': 'airflow', - 'depends_on_past': False, - 'email_on_failure': False, - 'email_on_retry': False, - 'retries': 0, + "owner": "airflow", + "depends_on_past": False, + "email_on_failure": False, + "email_on_retry": False, + "retries": 0, } + @dlt.resource() -def existing_incremental(updated_at: dlt.sources.incremental[pendulum.DateTime] = dlt.sources.incremental("updated_at", allow_external_schedulers=True)): +def existing_incremental( + updated_at: dlt.sources.incremental[pendulum.DateTime] = dlt.sources.incremental( + "updated_at", allow_external_schedulers=True + ) +): yield {"updated_at": CATCHUP_BEGIN, "state": updated_at.get_state()} def test_date_coercion() -> None: - @dag(schedule_interval='@daily', + @dag( + schedule_interval="@daily", start_date=CATCHUP_BEGIN, catchup=False, max_active_runs=1, - default_args=default_args + default_args=default_args, ) def dag_regular(): - @task def scheduled() -> None: context = get_current_context() @@ -50,49 +56,78 @@ def scheduled() -> None: assert state["updated_at"] == CATCHUP_BEGIN assert 
"Europe/Berlin" in str(state["updated_at"].tz) # must have UTC timezone - assert state["state"]["initial_value"] == CATCHUP_BEGIN == context["data_interval_start"] + assert ( + state["state"]["initial_value"] == CATCHUP_BEGIN == context["data_interval_start"] + ) assert state["state"]["initial_value"].tz == UTC assert state["state"]["last_value"] == CATCHUP_BEGIN == context["data_interval_start"] assert state["state"]["last_value"].tz == UTC # end date assert r.incremental._incremental.end_value == context["data_interval_end"] assert r.incremental._incremental.end_value.tz == UTC - assert (r.incremental._incremental.end_value - state["state"]["initial_value"]) == datetime.timedelta(hours=24) + assert ( + r.incremental._incremental.end_value - state["state"]["initial_value"] + ) == datetime.timedelta(hours=24) # datetime.datetime coercion must be pendulum anyway @dlt.resource() - def incremental_datetime(updated_at = dlt.sources.incremental[datetime.datetime]("updated_at", allow_external_schedulers=True)): + def incremental_datetime( + updated_at=dlt.sources.incremental[datetime.datetime]( + "updated_at", allow_external_schedulers=True + ) + ): yield {"updated_at": CATCHUP_BEGIN, "state": updated_at.get_state()} r = incremental_datetime() state = list(r)[0] # must have UTC timezone - assert state["state"]["initial_value"] == CATCHUP_BEGIN == context["data_interval_start"] + assert ( + state["state"]["initial_value"] == CATCHUP_BEGIN == context["data_interval_start"] + ) assert state["state"]["initial_value"].tz == UTC # datetime.date coercion also works @dlt.resource() - def incremental_datetime(updated_at = dlt.sources.incremental[datetime.date]("updated_at", allow_external_schedulers=True)): - yield {"updated_at": ensure_pendulum_date(CATCHUP_BEGIN), "state": updated_at.get_state()} + def incremental_datetime( + updated_at=dlt.sources.incremental[datetime.date]( + "updated_at", allow_external_schedulers=True + ) + ): + yield { + "updated_at": ensure_pendulum_date(CATCHUP_BEGIN), + "state": updated_at.get_state(), + } r = incremental_datetime() state = list(r)[0] - assert state["state"]["initial_value"] == ensure_pendulum_date(context["data_interval_start"]) + assert state["state"]["initial_value"] == ensure_pendulum_date( + context["data_interval_start"] + ) assert isinstance(state["state"]["initial_value"], datetime.date) # coerce to int @dlt.resource() - def incremental_datetime(updated_at = dlt.sources.incremental[int]("updated_at", allow_external_schedulers=True)): + def incremental_datetime( + updated_at=dlt.sources.incremental[int]( + "updated_at", allow_external_schedulers=True + ) + ): yield {"updated_at": CATCHUP_BEGIN.int_timestamp, "state": updated_at.get_state()} r = incremental_datetime() state = list(r)[0] assert state["state"]["initial_value"] == context["data_interval_start"].int_timestamp - assert r.incremental._incremental.end_value == context["data_interval_end"].int_timestamp + assert ( + r.incremental._incremental.end_value == context["data_interval_end"].int_timestamp + ) # coerce to float @dlt.resource() - def incremental_datetime(updated_at = dlt.sources.incremental[float]("updated_at", allow_external_schedulers=True)): + def incremental_datetime( + updated_at=dlt.sources.incremental[float]( + "updated_at", allow_external_schedulers=True + ) + ): yield {"updated_at": CATCHUP_BEGIN.timestamp(), "state": updated_at.get_state()} r = incremental_datetime() @@ -102,14 +137,27 @@ def incremental_datetime(updated_at = dlt.sources.incremental[float]("updated_at # coerce to 
str @dlt.resource() - def incremental_datetime(updated_at = dlt.sources.incremental[str]("updated_at", allow_external_schedulers=True)): - yield {"updated_at": CATCHUP_BEGIN.in_tz("UTC").isoformat(), "state": updated_at.get_state()} + def incremental_datetime( + updated_at=dlt.sources.incremental[str]( + "updated_at", allow_external_schedulers=True + ) + ): + yield { + "updated_at": CATCHUP_BEGIN.in_tz("UTC").isoformat(), + "state": updated_at.get_state(), + } r = incremental_datetime() state = list(r)[0] # must have UTC timezone - assert state["state"]["initial_value"] == context["data_interval_start"].in_tz("UTC").isoformat() - assert r.incremental._incremental.end_value == context["data_interval_end"].in_tz("UTC").isoformat() + assert ( + state["state"]["initial_value"] + == context["data_interval_start"].in_tz("UTC").isoformat() + ) + assert ( + r.incremental._incremental.end_value + == context["data_interval_end"].in_tz("UTC").isoformat() + ) scheduled() @@ -122,11 +170,12 @@ def incremental_datetime(updated_at = dlt.sources.incremental[str]("updated_at", def test_no_next_execution_date() -> None: now = pendulum.now() - @dag(schedule=None, + @dag( + schedule=None, catchup=False, start_date=CATCHUP_BEGIN, default_args=default_args, - max_active_runs=1 + max_active_runs=1, ) def dag_no_schedule(): @task @@ -134,8 +183,15 @@ def unscheduled(): context = get_current_context() @dlt.resource() - def incremental_datetime(updated_at = dlt.sources.incremental[datetime.datetime]("updated_at", allow_external_schedulers=True)): - yield {"updated_at": context["data_interval_start"], "state": updated_at.get_state()} + def incremental_datetime( + updated_at=dlt.sources.incremental[datetime.datetime]( + "updated_at", allow_external_schedulers=True + ) + ): + yield { + "updated_at": context["data_interval_start"], + "state": updated_at.get_state(), + } r = incremental_datetime() state = list(r)[0] @@ -151,8 +207,15 @@ def incremental_datetime(updated_at = dlt.sources.incremental[datetime.datetime] # will be filtered out (now earlier than data_interval_start) @dlt.resource() - def incremental_datetime(updated_at = dlt.sources.incremental[datetime.datetime]("updated_at", allow_external_schedulers=True)): - yield {"updated_at": now.subtract(hours=1, seconds=1), "state": updated_at.get_state()} + def incremental_datetime( + updated_at=dlt.sources.incremental[datetime.datetime]( + "updated_at", allow_external_schedulers=True + ) + ): + yield { + "updated_at": now.subtract(hours=1, seconds=1), + "state": updated_at.get_state(), + } r = incremental_datetime() assert len(list(r)) == 0 @@ -172,18 +235,27 @@ def incremental_datetime(updated_at = dlt.sources.incremental[datetime.datetime] ti.run() assert ti.state == State.SUCCESS - @dag(schedule_interval='@daily', + @dag( + schedule_interval="@daily", start_date=CATCHUP_BEGIN, catchup=True, - default_args=default_args + default_args=default_args, ) def dag_daily_schedule(): @task def scheduled(): context = get_current_context() + @dlt.resource() - def incremental_datetime(updated_at = dlt.sources.incremental[datetime.datetime]("updated_at", allow_external_schedulers=True)): - yield {"updated_at": context["data_interval_start"], "state": updated_at.get_state()} + def incremental_datetime( + updated_at=dlt.sources.incremental[datetime.datetime]( + "updated_at", allow_external_schedulers=True + ) + ): + yield { + "updated_at": context["data_interval_start"], + "state": updated_at.get_state(), + } r = incremental_datetime() state = list(r)[0] @@ -208,7 +280,7 @@ def 
incremental_datetime(updated_at = dlt.sources.incremental[datetime.datetime] state=DagRunState.RUNNING, execution_date=now, run_type=DagRunType.MANUAL, - data_interval=(now, now) + data_interval=(now, now), ) dag_def.run(start_date=now, run_at_least_once=True) task_def = dag_def.task_dict["scheduled"] @@ -219,16 +291,20 @@ def incremental_datetime(updated_at = dlt.sources.incremental[datetime.datetime] def test_scheduler_pipeline_state() -> None: pipeline = dlt.pipeline( - pipeline_name="pipeline_dag_regular", dataset_name="mock_data_" + uniq_id(), destination="duckdb", credentials=":pipeline:") + pipeline_name="pipeline_dag_regular", + dataset_name="mock_data_" + uniq_id(), + destination="duckdb", + credentials=":pipeline:", + ) now = pendulum.now() - @dag(schedule_interval='@daily', + @dag( + schedule_interval="@daily", start_date=CATCHUP_BEGIN, catchup=False, - default_args=default_args + default_args=default_args, ) def dag_regular(): - @task def scheduled() -> None: r = existing_incremental() @@ -252,7 +328,7 @@ def scheduled() -> None: state=DagRunState.RUNNING, execution_date=now, run_type=DagRunType.MANUAL, - data_interval=(now, now) + data_interval=(now, now), ) dag_def.run(start_date=now, run_at_least_once=True) task_def = dag_def.task_dict["scheduled"] @@ -261,20 +337,13 @@ def scheduled() -> None: assert ti.state == State.SUCCESS assert "sources" not in pipeline.state - pipeline = pipeline.drop() dag_def.test(execution_date=CATCHUP_BEGIN) assert "sources" not in pipeline.state - @dag( - schedule=None, - start_date=CATCHUP_BEGIN, - catchup=False, - default_args=default_args - ) + @dag(schedule=None, start_date=CATCHUP_BEGIN, catchup=False, default_args=default_args) def dag_no_schedule(): - @task def unscheduled() -> None: r = existing_incremental() diff --git a/tests/helpers/airflow_tests/utils.py b/tests/helpers/airflow_tests/utils.py index c4b1fbb6e3..761c306eee 100644 --- a/tests/helpers/airflow_tests/utils.py +++ b/tests/helpers/airflow_tests/utils.py @@ -1,16 +1,17 @@ -import os import argparse +import os + import pytest from airflow.cli.commands.db_command import resetdb from airflow.configuration import conf from airflow.models.variable import Variable from dlt.common.configuration.container import Container -from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext from dlt.common.configuration.providers.toml import SECRETS_TOML_KEY +from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext -@pytest.fixture(scope='function', autouse=True) +@pytest.fixture(scope="function", autouse=True) def initialize_airflow_db(): setup_airflow() # backup context providers @@ -28,10 +29,10 @@ def initialize_airflow_db(): def setup_airflow() -> None: # Disable loading examples - conf.set('core', 'load_examples', 'False') + conf.set("core", "load_examples", "False") # Prepare the arguments for the initdb function args = argparse.Namespace() - args.backend = conf.get(section='core', key='sql_alchemy_conn') + args.backend = conf.get(section="core", key="sql_alchemy_conn") # Run Airflow resetdb before running any tests args.yes = True diff --git a/tests/helpers/dbt_tests/local/test_dbt_utils.py b/tests/helpers/dbt_tests/local/test_dbt_utils.py index 71e570bd69..aa7d1400dc 100644 --- a/tests/helpers/dbt_tests/local/test_dbt_utils.py +++ b/tests/helpers/dbt_tests/local/test_dbt_utils.py @@ -1,23 +1,38 @@ import os import shutil + import pytest +from tests.helpers.dbt_tests.utils import clone_jaffle_repo, load_test_case +from 
tests.utils import preserve_environ, test_storage + from dlt.common.configuration.resolve import resolve_configuration from dlt.common.configuration.utils import add_config_to_env from dlt.common.runners.synth_pickle import decode_obj from dlt.common.storages import FileStorage from dlt.common.utils import uniq_id - from dlt.destinations.postgres.configuration import PostgresCredentials -from dlt.helpers.dbt.dbt_utils import DBTProcessingError, initialize_dbt_logging, run_dbt_command, is_incremental_schema_out_of_sync_error - -from tests.utils import test_storage, preserve_environ -from tests.helpers.dbt_tests.utils import clone_jaffle_repo, load_test_case +from dlt.helpers.dbt.dbt_utils import ( + DBTProcessingError, + initialize_dbt_logging, + is_incremental_schema_out_of_sync_error, + run_dbt_command, +) def test_is_incremental_schema_out_of_sync_error() -> None: # in case of --fail-fast detect on a single run result - assert is_incremental_schema_out_of_sync_error(decode_obj(load_test_case("run_result_incremental_fail.pickle.hex"))) is True - assert is_incremental_schema_out_of_sync_error(decode_obj(load_test_case("run_execution_incremental_fail.pickle.hex"))) is True + assert ( + is_incremental_schema_out_of_sync_error( + decode_obj(load_test_case("run_result_incremental_fail.pickle.hex")) + ) + is True + ) + assert ( + is_incremental_schema_out_of_sync_error( + decode_obj(load_test_case("run_execution_incremental_fail.pickle.hex")) + ) + is True + ) assert is_incremental_schema_out_of_sync_error("AAA") is False @@ -27,24 +42,36 @@ def test_dbt_commands(test_storage: FileStorage) -> None: dbt_vars = {"dbt_schema": schema_name} # extract postgres creds from env, parse and emit - credentials = resolve_configuration(PostgresCredentials(), sections=("destination", "postgres")) + credentials = resolve_configuration(PostgresCredentials(), sections=("destination", "postgres")) add_config_to_env(credentials, ("dlt",)) repo_path = clone_jaffle_repo(test_storage) # copy profile - shutil.copy("./tests/helpers/dbt_tests/cases/profiles_invalid_credentials.yml", os.path.join(repo_path, "profiles.yml")) + shutil.copy( + "./tests/helpers/dbt_tests/cases/profiles_invalid_credentials.yml", + os.path.join(repo_path, "profiles.yml"), + ) # initialize logging global_args = initialize_dbt_logging("ERROR", False) # run deps, results are None assert run_dbt_command(repo_path, "deps", ".", global_args=global_args) is None # run list, results are list of strings - results = run_dbt_command(repo_path, "list", ".", global_args=global_args, package_vars=dbt_vars) + results = run_dbt_command( + repo_path, "list", ".", global_args=global_args, package_vars=dbt_vars + ) assert isinstance(results, list) assert len(results) == 28 assert "jaffle_shop.not_null_orders_amount" in results # run list for specific selector - results = run_dbt_command(repo_path, "list", ".", global_args=global_args, command_args=["-s", "jaffle_shop.not_null_orders_amount"], package_vars=dbt_vars) + results = run_dbt_command( + repo_path, + "list", + ".", + global_args=global_args, + command_args=["-s", "jaffle_shop.not_null_orders_amount"], + package_vars=dbt_vars, + ) assert len(results) == 1 assert results[0] == "jaffle_shop.not_null_orders_amount" # run debug, that will fail @@ -61,26 +88,46 @@ def test_dbt_commands(test_storage: FileStorage) -> None: # same for run with pytest.raises(DBTProcessingError) as dbt_err: - run_dbt_command(repo_path, "run", ".", global_args=global_args, package_vars=dbt_vars, command_args=["--fail-fast", 
"--full-refresh"]) + run_dbt_command( + repo_path, + "run", + ".", + global_args=global_args, + package_vars=dbt_vars, + command_args=["--fail-fast", "--full-refresh"], + ) # in that case test results are bool, not list of tests runs assert dbt_err.value.command == "run" # copy a correct profile - shutil.copy("./tests/helpers/dbt_tests/cases/profiles.yml", os.path.join(repo_path, "profiles.yml")) + shutil.copy( + "./tests/helpers/dbt_tests/cases/profiles.yml", os.path.join(repo_path, "profiles.yml") + ) - results = run_dbt_command(repo_path, "seed", ".", global_args=global_args, package_vars=dbt_vars) + results = run_dbt_command( + repo_path, "seed", ".", global_args=global_args, package_vars=dbt_vars + ) assert isinstance(results, list) assert len(results) == 3 assert results[0].model_name == "raw_customers" assert results[0].status == "success" - results = run_dbt_command(repo_path, "run", ".", global_args=global_args, package_vars=dbt_vars, command_args=["--fail-fast", "--full-refresh"]) + results = run_dbt_command( + repo_path, + "run", + ".", + global_args=global_args, + package_vars=dbt_vars, + command_args=["--fail-fast", "--full-refresh"], + ) assert isinstance(results, list) assert len(results) == 5 assert results[-1].model_name == "orders" assert results[-1].status == "success" - results = run_dbt_command(repo_path, "test", ".", global_args=global_args, package_vars=dbt_vars) + results = run_dbt_command( + repo_path, "test", ".", global_args=global_args, package_vars=dbt_vars + ) assert isinstance(results, list) assert len(results) == 20 assert results[-1].status == "pass" diff --git a/tests/helpers/dbt_tests/local/test_runner_destinations.py b/tests/helpers/dbt_tests/local/test_runner_destinations.py index 5e8d8d754a..4e258c90fd 100644 --- a/tests/helpers/dbt_tests/local/test_runner_destinations.py +++ b/tests/helpers/dbt_tests/local/test_runner_destinations.py @@ -1,20 +1,25 @@ import os from typing import Any -from git import GitCommandError + import pytest +from git import GitCommandError +from tests.common.utils import load_secret, modify_and_commit_file +from tests.helpers.dbt_tests.local.utils import ( + DBTDestinationInfo, + setup_rasa_runner, + setup_rasa_runner_client, +) +from tests.helpers.dbt_tests.utils import find_run_result +from tests.utils import TEST_STORAGE_ROOT, clean_test_storage, preserve_environ from dlt.common.utils import uniq_id - from dlt.helpers.dbt.dbt_utils import DBTProcessingError from dlt.helpers.dbt.exceptions import PrerequisitesException -from tests.helpers.dbt_tests.utils import find_run_result - -from tests.utils import TEST_STORAGE_ROOT, clean_test_storage, preserve_environ -from tests.common.utils import modify_and_commit_file, load_secret -from tests.helpers.dbt_tests.local.utils import setup_rasa_runner_client, setup_rasa_runner, DBTDestinationInfo DESTINATION_DATASET_NAME = "test_" + uniq_id() -ALL_DBT_DESTINATIONS = [DBTDestinationInfo("bigquery", "CREATE TABLE", "MERGE")] # DBTDestinationInfo("redshift", "SELECT", "INSERT") +ALL_DBT_DESTINATIONS = [ + DBTDestinationInfo("bigquery", "CREATE TABLE", "MERGE") +] # DBTDestinationInfo("redshift", "SELECT", "INSERT") ALL_DBT_DESTINATIONS_NAMES = ["bigquery"] # "redshift", @@ -27,29 +32,36 @@ def destination_info(request: Any) -> DBTDestinationInfo: def test_setup_dbt_runner() -> None: - runner = setup_rasa_runner("redshift", "carbon_bot_3", override_values={ - "package_additional_vars": {"add_var_name": "add_var_value"}, - "runtime": { - "log_format": "JSON", - "log_level": "INFO" - } - }) 
+ runner = setup_rasa_runner( + "redshift", + "carbon_bot_3", + override_values={ + "package_additional_vars": {"add_var_name": "add_var_value"}, + "runtime": {"log_format": "JSON", "log_level": "INFO"}, + }, + ) assert runner.package_path.endswith("rasa_semantic_schema") assert runner.config.package_profile_name == "redshift" assert runner.config.package_additional_vars == {"add_var_name": "add_var_value"} - assert runner._get_package_vars() == {"source_dataset_name": "carbon_bot_3", "add_var_name": "add_var_value"} + assert runner._get_package_vars() == { + "source_dataset_name": "carbon_bot_3", + "add_var_name": "add_var_value", + } assert runner.source_dataset_name == "carbon_bot_3" assert runner.cloned_package_name == "rasa_semantic_schema" assert runner.working_dir == TEST_STORAGE_ROOT def test_initialize_package_wrong_key() -> None: - runner = setup_rasa_runner("redshift", override_values={ - # private repo - "package_location": "git@github.com:dlt-hub/rasa_bot_experiments.git", - "package_repository_branch": None, - "package_repository_ssh_key": load_secret("DEPLOY_KEY") - }) + runner = setup_rasa_runner( + "redshift", + override_values={ + # private repo + "package_location": "git@github.com:dlt-hub/rasa_bot_experiments.git", + "package_repository_branch": None, + "package_repository_ssh_key": load_secret("DEPLOY_KEY"), + }, + ) with pytest.raises(GitCommandError) as gce: runner.run_all() @@ -60,12 +72,17 @@ def test_reinitialize_package() -> None: runner = setup_rasa_runner("redshift") runner.ensure_newest_package() # mod the package - readme_path, _ = modify_and_commit_file(runner.package_path, "README.md", content=runner.config.package_profiles_dir) + readme_path, _ = modify_and_commit_file( + runner.package_path, "README.md", content=runner.config.package_profiles_dir + ) assert os.path.isfile(readme_path) # this will wipe out old package and clone again runner.ensure_newest_package() # we have old file back - assert runner.repo_storage.load(f"{runner.cloned_package_name}/README.md") != runner.config.package_profiles_dir + assert ( + runner.repo_storage.load(f"{runner.cloned_package_name}/README.md") + != runner.config.package_profiles_dir + ) def test_dbt_test_no_raw_schema(destination_info: DBTDestinationInfo) -> None: @@ -76,7 +93,7 @@ def test_dbt_test_no_raw_schema(destination_info: DBTDestinationInfo) -> None: runner.run_all( destination_dataset_name=DESTINATION_DATASET_NAME, run_params=["--fail-fast", "--full-refresh"], - source_tests_selector="tag:prerequisites" + source_tests_selector="tag:prerequisites", ) assert isinstance(prq_ex.value.args[0], DBTProcessingError) @@ -89,16 +106,21 @@ def test_dbt_run_full_refresh(destination_info: DBTDestinationInfo) -> None: destination_dataset_name=DESTINATION_DATASET_NAME, run_params=["--fail-fast", "--full-refresh"], additional_vars={"user_id": "metadata__user_id"}, - source_tests_selector="tag:prerequisites" + source_tests_selector="tag:prerequisites", ) assert all(r.message.startswith(destination_info.replace_strategy) for r in run_results) is True assert find_run_result(run_results, "_loads") is not None # all models must be SELECT as we do full refresh - assert find_run_result(run_results, "_loads").message.startswith(destination_info.replace_strategy) + assert find_run_result(run_results, "_loads").message.startswith( + destination_info.replace_strategy + ) assert all(m.message.startswith(destination_info.replace_strategy) for m in run_results) is True # all tests should pass - 
runner.test(destination_dataset_name=DESTINATION_DATASET_NAME, additional_vars={"user_id": "metadata__user_id"}) + runner.test( + destination_dataset_name=DESTINATION_DATASET_NAME, + additional_vars={"user_id": "metadata__user_id"}, + ) def test_dbt_run_error_via_additional_vars(destination_info: DBTDestinationInfo) -> None: @@ -110,8 +132,11 @@ def test_dbt_run_error_via_additional_vars(destination_info: DBTDestinationInfo) runner.run_all( destination_dataset_name=DESTINATION_DATASET_NAME, run_params=["--fail-fast", "--full-refresh"], - additional_vars={"user_id": "metadata__user_id", "external_session_id": "metadata__sess_id"}, - source_tests_selector="tag:prerequisites" + additional_vars={ + "user_id": "metadata__user_id", + "external_session_id": "metadata__sess_id", + }, + source_tests_selector="tag:prerequisites", ) stg_interactions = find_run_result(dbt_err.value.run_results, "stg_interactions") assert "metadata__sess_id" in stg_interactions.message @@ -127,7 +152,7 @@ def test_dbt_incremental_schema_out_of_sync_error(destination_info: DBTDestinati run_params=["--fail-fast", "--model", "+interactions"], # remove all counter metrics additional_vars={"count_metrics": []}, - source_tests_selector="tag:prerequisites" + source_tests_selector="tag:prerequisites", ) # generate schema error on incremental load @@ -140,7 +165,9 @@ def test_dbt_incremental_schema_out_of_sync_error(destination_info: DBTDestinati ) # metrics: StrStr = get_metrics_from_prometheus([runner.model_exec_info])["dbtrunner_model_status_info"] # full refresh on interactions - assert find_run_result(results, "interactions").message.startswith(destination_info.replace_strategy) + assert find_run_result(results, "interactions").message.startswith( + destination_info.replace_strategy + ) # now incremental load should happen results = runner.run( diff --git a/tests/helpers/dbt_tests/local/utils.py b/tests/helpers/dbt_tests/local/utils.py index 2993753a0c..62c9e8c4ff 100644 --- a/tests/helpers/dbt_tests/local/utils.py +++ b/tests/helpers/dbt_tests/local/utils.py @@ -1,19 +1,16 @@ - import contextlib from typing import Iterator, NamedTuple +from tests.load.utils import cm_yield_client, delete_dataset +from tests.utils import TEST_STORAGE_ROOT, init_test_logging + from dlt.common.configuration.utils import add_config_to_env from dlt.common.destination.reference import DestinationClientDwhConfiguration from dlt.common.runners import Venv from dlt.common.typing import StrAny - from dlt.helpers.dbt.configuration import DBTRunnerConfiguration from dlt.helpers.dbt.runner import DBTPackageRunner, create_runner -from tests.load.utils import cm_yield_client, delete_dataset -from tests.utils import TEST_STORAGE_ROOT, init_test_logging - - FIXTURES_DATASET_NAME = "test_fixture_carbon_bot_session_cases" @@ -23,10 +20,13 @@ class DBTDestinationInfo(NamedTuple): incremental_strategy: str -def setup_rasa_runner(profile_name: str, dataset_name: str = None, override_values: StrAny = None) -> DBTPackageRunner: - +def setup_rasa_runner( + profile_name: str, dataset_name: str = None, override_values: StrAny = None +) -> DBTPackageRunner: C = DBTRunnerConfiguration() - C.package_location = "https://github.com/scale-vector/rasa_semantic_schema.git" # "/home/rudolfix/src/dbt/rasa_semantic_schema" + C.package_location = ( # "/home/rudolfix/src/dbt/rasa_semantic_schema" + "https://github.com/scale-vector/rasa_semantic_schema.git" + ) C.package_repository_branch = "dlt-dbt-runner-ci-do-not-delete" # override values including the defaults above @@ -41,7 
+41,7 @@ def setup_rasa_runner(profile_name: str, dataset_name: str = None, override_valu DestinationClientDwhConfiguration(dataset_name=dataset_name or FIXTURES_DATASET_NAME), TEST_STORAGE_ROOT, package_profile_name=profile_name, - config=C + config=C, ) # now C is resolved init_test_logging(C.runtime) @@ -49,7 +49,9 @@ def setup_rasa_runner(profile_name: str, dataset_name: str = None, override_valu @contextlib.contextmanager -def setup_rasa_runner_client(destination_name: str, destination_dataset_name: str) -> Iterator[None]: +def setup_rasa_runner_client( + destination_name: str, destination_dataset_name: str +) -> Iterator[None]: with cm_yield_client(destination_name, FIXTURES_DATASET_NAME) as client: # emit environ so credentials are passed to dbt profile add_config_to_env(client.config, ("DLT",)) diff --git a/tests/helpers/dbt_tests/test_runner_dbt_versions.py b/tests/helpers/dbt_tests/test_runner_dbt_versions.py index 9a43489688..81d5b5bea4 100644 --- a/tests/helpers/dbt_tests/test_runner_dbt_versions.py +++ b/tests/helpers/dbt_tests/test_runner_dbt_versions.py @@ -1,29 +1,37 @@ import os import shutil import tempfile -from typing import Any, Iterator, List from functools import partial -from typing import Tuple +from typing import Any, Iterator, List, Tuple + import pytest -from dlt.common import json +from tests.helpers.dbt_tests.utils import ( + JAFFLE_SHOP_REPO, + assert_jaffle_completed, + clone_jaffle_repo, + find_run_result, +) +from tests.load.utils import cm_yield_client_with_storage, yield_client_with_storage +from tests.utils import preserve_environ, test_storage +from dlt.common import json from dlt.common.configuration import resolve_configuration -from dlt.common.configuration.specs import GcpServiceAccountCredentials, CredentialsWithDefault -from dlt.common.storages.file_storage import FileStorage +from dlt.common.configuration.specs import CredentialsWithDefault, GcpServiceAccountCredentials from dlt.common.runners import Venv from dlt.common.runners.synth_pickle import decode_obj, encode_obj +from dlt.common.storages.file_storage import FileStorage from dlt.common.typing import AnyFun - -from dlt.destinations.postgres.postgres import PostgresClient from dlt.destinations.bigquery import BigQueryClientConfiguration +from dlt.destinations.postgres.postgres import PostgresClient +from dlt.helpers.dbt import ( + DEFAULT_DBT_VERSION, + _create_dbt_deps, + _default_profile_name, + create_venv, + package_runner, +) from dlt.helpers.dbt.configuration import DBTRunnerConfiguration -from dlt.helpers.dbt.exceptions import PrerequisitesException, DBTProcessingError -from dlt.helpers.dbt import package_runner, create_venv, _create_dbt_deps, _default_profile_name, DEFAULT_DBT_VERSION - -from tests.helpers.dbt_tests.utils import JAFFLE_SHOP_REPO, assert_jaffle_completed, clone_jaffle_repo, find_run_result - -from tests.utils import test_storage, preserve_environ -from tests.load.utils import yield_client_with_storage, cm_yield_client_with_storage +from dlt.helpers.dbt.exceptions import DBTProcessingError, PrerequisitesException @pytest.fixture(scope="function") @@ -40,14 +48,14 @@ def client() -> Iterator[PostgresClient]: ("postgres", None), ("snowflake", "1.4.0"), ("snowflake", "1.5.2"), - ("snowflake", None) + ("snowflake", None), ] PACKAGE_IDS = [ - f"{destination}-venv-{version}" - if version else f"{destination}-local" + f"{destination}-venv-{version}" if version else f"{destination}-local" for destination, version in PACKAGE_PARAMS ] + @pytest.fixture(scope="module", 
params=PACKAGE_PARAMS, ids=PACKAGE_IDS) def dbt_package_f(request: Any) -> Iterator[Tuple[str, AnyFun]]: destination_name, version = request.param @@ -60,7 +68,11 @@ def dbt_package_f(request: Any) -> Iterator[Tuple[str, AnyFun]]: def test_infer_venv_deps() -> None: requirements = _create_dbt_deps(["postgres", "bigquery"]) - assert requirements[:3] == [f"dbt-postgres{DEFAULT_DBT_VERSION}", f"dbt-bigquery{DEFAULT_DBT_VERSION}", f"dbt-core{DEFAULT_DBT_VERSION}"] + assert requirements[:3] == [ + f"dbt-postgres{DEFAULT_DBT_VERSION}", + f"dbt-bigquery{DEFAULT_DBT_VERSION}", + f"dbt-core{DEFAULT_DBT_VERSION}", + ] # should lead to here assert os.path.isdir(requirements[-1]) # provide exact version @@ -88,7 +100,10 @@ def test_dbt_configuration() -> None: # check names normalized C: DBTRunnerConfiguration = resolve_configuration( DBTRunnerConfiguration(), - explicit_value={"package_repository_ssh_key": "---NO NEWLINE---", "package_location": "/var/local"} + explicit_value={ + "package_repository_ssh_key": "---NO NEWLINE---", + "package_location": "/var/local", + }, ) assert C.package_repository_ssh_key == "---NO NEWLINE---\n" assert C.package_additional_vars is None @@ -97,7 +112,11 @@ def test_dbt_configuration() -> None: C = resolve_configuration( DBTRunnerConfiguration(), - explicit_value={"package_repository_ssh_key": "---WITH NEWLINE---\n", "package_location": "/var/local", "package_additional_vars": {"a": 1}} + explicit_value={ + "package_repository_ssh_key": "---WITH NEWLINE---\n", + "package_location": "/var/local", + "package_additional_vars": {"a": 1}, + }, ) assert C.package_repository_ssh_key == "---WITH NEWLINE---\n" assert C.package_additional_vars == {"a": 1} @@ -107,9 +126,9 @@ def test_dbt_run_exception_pickle() -> None: obj = decode_obj( encode_obj( DBTProcessingError("test", "A", "B"), # type: ignore[arg-type] - ignore_pickle_errors=False + ignore_pickle_errors=False, ), - ignore_pickle_errors=False + ignore_pickle_errors=False, ) assert obj.command == "test" assert obj.run_results == "A" @@ -118,12 +137,21 @@ def test_dbt_run_exception_pickle() -> None: def test_runner_setup(client: PostgresClient, test_storage: FileStorage) -> None: - add_vars = {"source_dataset_name": "overwritten", "destination_dataset_name": "destination", "schema_name": "this_Schema"} + add_vars = { + "source_dataset_name": "overwritten", + "destination_dataset_name": "destination", + "schema_name": "this_Schema", + } os.environ["DBT_PACKAGE_RUNNER__PACKAGE_ADDITIONAL_VARS"] = json.dumps(add_vars) os.environ["AUTO_FULL_REFRESH_WHEN_OUT_OF_SYNC"] = "False" os.environ["DBT_PACKAGE_RUNNER__RUNTIME__LOG_LEVEL"] = "CRITICAL" test_storage.create_folder("jaffle") - r = package_runner(Venv.restore_current(), client.config, test_storage.make_full_path("jaffle"), JAFFLE_SHOP_REPO) + r = package_runner( + Venv.restore_current(), + client.config, + test_storage.make_full_path("jaffle"), + JAFFLE_SHOP_REPO, + ) # runner settings assert r.credentials is client.config assert r.working_dir == test_storage.make_full_path("jaffle") @@ -139,55 +167,76 @@ def test_runner_setup(client: PostgresClient, test_storage: FileStorage) -> None assert r.config.runtime.log_level == "CRITICAL" assert r.config.auto_full_refresh_when_out_of_sync is False - assert r._get_package_vars() == {"source_dataset_name": client.config.dataset_name, "destination_dataset_name": "destination", "schema_name": "this_Schema"} - assert r._get_package_vars(destination_dataset_name="dest_test_123") == {"source_dataset_name": client.config.dataset_name, 
"destination_dataset_name": "dest_test_123", "schema_name": "this_Schema"} + assert r._get_package_vars() == { + "source_dataset_name": client.config.dataset_name, + "destination_dataset_name": "destination", + "schema_name": "this_Schema", + } + assert r._get_package_vars(destination_dataset_name="dest_test_123") == { + "source_dataset_name": client.config.dataset_name, + "destination_dataset_name": "dest_test_123", + "schema_name": "this_Schema", + } assert r._get_package_vars(additional_vars={"add": 1, "schema_name": "ovr"}) == { - "source_dataset_name": client.config.dataset_name, - "destination_dataset_name": "destination", "schema_name": "ovr", - "add": 1 - } + "source_dataset_name": client.config.dataset_name, + "destination_dataset_name": "destination", + "schema_name": "ovr", + "add": 1, + } -def test_runner_dbt_destinations(test_storage: FileStorage, dbt_package_f: Tuple[str, AnyFun]) -> None: +def test_runner_dbt_destinations( + test_storage: FileStorage, dbt_package_f: Tuple[str, AnyFun] +) -> None: destination_name, dbt_func = dbt_package_f with cm_yield_client_with_storage(destination_name) as client: - jaffle_base_dir = 'jaffle_' + destination_name + jaffle_base_dir = "jaffle_" + destination_name test_storage.create_folder(jaffle_base_dir) results = dbt_func( client.config, test_storage.make_full_path(jaffle_base_dir), JAFFLE_SHOP_REPO ).run_all(["--fail-fast", "--full-refresh"]) - assert_jaffle_completed(test_storage, results, destination_name, jaffle_dir=jaffle_base_dir + '/jaffle_shop') + assert_jaffle_completed( + test_storage, results, destination_name, jaffle_dir=jaffle_base_dir + "/jaffle_shop" + ) -def test_run_jaffle_from_folder_incremental(test_storage: FileStorage, dbt_package_f: Tuple[str, AnyFun]) -> None: +def test_run_jaffle_from_folder_incremental( + test_storage: FileStorage, dbt_package_f: Tuple[str, AnyFun] +) -> None: destination_name, dbt_func = dbt_package_f with cm_yield_client_with_storage(destination_name) as client: repo_path = clone_jaffle_repo(test_storage) # copy model with error into package to force run error in model - shutil.copy("./tests/helpers/dbt_tests/cases/jaffle_customers_incremental.sql", os.path.join(repo_path, "models", "customers.sql")) + shutil.copy( + "./tests/helpers/dbt_tests/cases/jaffle_customers_incremental.sql", + os.path.join(repo_path, "models", "customers.sql"), + ) results = dbt_func(client.config, None, repo_path).run_all(run_params=None) assert_jaffle_completed(test_storage, results, destination_name, jaffle_dir="jaffle_shop") results = dbt_func(client.config, None, repo_path).run_all() # out of 100 records 0 was inserted customers = find_run_result(results, "customers") - assert customers.message in JAFFLE_MESSAGES_INCREMENTAL[destination_name]['customers'] + assert customers.message in JAFFLE_MESSAGES_INCREMENTAL[destination_name]["customers"] # change the column name. that will force dbt to fail (on_schema_change='fail'). 
the runner should do a full refresh - shutil.copy("./tests/helpers/dbt_tests/cases/jaffle_customers_incremental_new_column.sql", os.path.join(repo_path, "models", "customers.sql")) + shutil.copy( + "./tests/helpers/dbt_tests/cases/jaffle_customers_incremental_new_column.sql", + os.path.join(repo_path, "models", "customers.sql"), + ) results = dbt_func(client.config, None, repo_path).run_all(run_params=None) assert_jaffle_completed(test_storage, results, destination_name, jaffle_dir="jaffle_shop") -def test_run_jaffle_fail_prerequisites(test_storage: FileStorage, dbt_package_f: Tuple[str, AnyFun]) -> None: +def test_run_jaffle_fail_prerequisites( + test_storage: FileStorage, dbt_package_f: Tuple[str, AnyFun] +) -> None: destination_name, dbt_func = dbt_package_f with cm_yield_client_with_storage(destination_name) as client: test_storage.create_folder("jaffle") # we run all the tests before tables are materialized with pytest.raises(PrerequisitesException) as pr_exc: dbt_func( - client.config, - test_storage.make_full_path("jaffle"), - JAFFLE_SHOP_REPO - ).run_all(["--fail-fast", "--full-refresh"], source_tests_selector="*") + client.config, test_storage.make_full_path("jaffle"), JAFFLE_SHOP_REPO + ).run_all(["--fail-fast", "--full-refresh"], source_tests_selector="*") proc_err = pr_exc.value.args[0] assert isinstance(proc_err, DBTProcessingError) customers = find_run_result(proc_err.run_results, "unique_customers_customer_id") @@ -196,23 +245,32 @@ def test_run_jaffle_fail_prerequisites(test_storage: FileStorage, dbt_package_f: assert all(r.status == "error" for r in proc_err.run_results) -def test_run_jaffle_invalid_run_args(test_storage: FileStorage, dbt_package_f: Tuple[str, AnyFun]) -> None: +def test_run_jaffle_invalid_run_args( + test_storage: FileStorage, dbt_package_f: Tuple[str, AnyFun] +) -> None: destination_name, dbt_func = dbt_package_f with cm_yield_client_with_storage(destination_name) as client: test_storage.create_folder("jaffle") # we run all the tests before tables are materialized with pytest.raises(DBTProcessingError) as pr_exc: - dbt_func(client.config, test_storage.make_full_path("jaffle"), JAFFLE_SHOP_REPO).run_all(["--wrong_flag"]) + dbt_func( + client.config, test_storage.make_full_path("jaffle"), JAFFLE_SHOP_REPO + ).run_all(["--wrong_flag"]) # dbt < 1.5 raises systemexit, dbt >= 1.5 just returns success False assert isinstance(pr_exc.value.dbt_results, SystemExit) or pr_exc.value.dbt_results is None -def test_run_jaffle_failed_run(test_storage: FileStorage, dbt_package_f: Tuple[str, AnyFun]) -> None: +def test_run_jaffle_failed_run( + test_storage: FileStorage, dbt_package_f: Tuple[str, AnyFun] +) -> None: destination_name, dbt_func = dbt_package_f with cm_yield_client_with_storage(destination_name) as client: repo_path = clone_jaffle_repo(test_storage) # copy model with error into package to force run error in model - shutil.copy("./tests/helpers/dbt_tests/cases/jaffle_customers_with_error.sql", os.path.join(repo_path, "models", "customers.sql")) + shutil.copy( + "./tests/helpers/dbt_tests/cases/jaffle_customers_with_error.sql", + os.path.join(repo_path, "models", "customers.sql"), + ) with pytest.raises(DBTProcessingError) as pr_exc: dbt_func(client.config, None, repo_path).run_all(run_params=None) assert len(pr_exc.value.run_results) == 5 @@ -221,11 +279,9 @@ def test_run_jaffle_failed_run(test_storage: FileStorage, dbt_package_f: Tuple[s JAFFLE_MESSAGES_INCREMENTAL = { - 'snowflake': { + "snowflake": { # Different message per version - 'customers': ('SUCCESS 
1', 'SUCCESS 100'), + "customers": ("SUCCESS 1", "SUCCESS 100"), }, - 'postgres': { - 'customers': ("INSERT 0 100", ) - } + "postgres": {"customers": ("INSERT 0 100",)}, } diff --git a/tests/helpers/dbt_tests/utils.py b/tests/helpers/dbt_tests/utils.py index 65e0eae2cb..3045a263c3 100644 --- a/tests/helpers/dbt_tests/utils.py +++ b/tests/helpers/dbt_tests/utils.py @@ -1,23 +1,23 @@ import os from typing import List, Sequence -from dlt.common.storages import FileStorage from dlt.common.git import clone_repo +from dlt.common.storages import FileStorage from dlt.helpers.dbt.exceptions import DBTNodeResult JAFFLE_SHOP_REPO = "https://github.com/dbt-labs/jaffle_shop.git" TEST_CASES_PATH = "./tests/helpers/dbt_tests/cases/" JAFFLE_RESULT_MESSAGES = { - 'postgres': { - 'stg_orders': 'CREATE VIEW', - 'customers': 'SELECT 100', + "postgres": { + "stg_orders": "CREATE VIEW", + "customers": "SELECT 100", }, # Snowflake only returns generic success messages - 'snowflake': { - 'stg_orders': 'SUCCESS 1', - 'customers': 'SUCCESS 1', - } + "snowflake": { + "stg_orders": "SUCCESS 1", + "customers": "SUCCESS 1", + }, } @@ -33,17 +33,24 @@ def find_run_result(results: Sequence[DBTNodeResult], model_name: str) -> DBTNod def clone_jaffle_repo(test_storage: FileStorage) -> str: repo_path = test_storage.make_full_path("jaffle_shop") # clone jaffle shop for dbt 1.0.0 - clone_repo(JAFFLE_SHOP_REPO, repo_path, with_git_command=None, branch="main").close() # core-v1.0.0 + clone_repo( + JAFFLE_SHOP_REPO, repo_path, with_git_command=None, branch="main" + ).close() # core-v1.0.0 return repo_path -def assert_jaffle_completed(test_storage: FileStorage, results: List[DBTNodeResult], destination_name: str, jaffle_dir: str = "jaffle/jaffle_shop") -> None: +def assert_jaffle_completed( + test_storage: FileStorage, + results: List[DBTNodeResult], + destination_name: str, + jaffle_dir: str = "jaffle/jaffle_shop", +) -> None: assert len(results) == 5 assert all(r.status == "success" for r in results) - stg_orders = find_run_result(results, 'stg_orders') - assert stg_orders.message == JAFFLE_RESULT_MESSAGES[destination_name]['stg_orders'] + stg_orders = find_run_result(results, "stg_orders") + assert stg_orders.message == JAFFLE_RESULT_MESSAGES[destination_name]["stg_orders"] customers = find_run_result(results, "customers") - assert customers.message == JAFFLE_RESULT_MESSAGES[destination_name]['customers'] + assert customers.message == JAFFLE_RESULT_MESSAGES[destination_name]["customers"] # `run_dbt` has injected credentials into environ. 
make sure that credentials were removed assert "CREDENTIALS__PASSWORD" not in os.environ # make sure jaffle_shop was cloned into right dir diff --git a/tests/helpers/providers/test_google_secrets_provider.py b/tests/helpers/providers/test_google_secrets_provider.py index 814b995f5e..26bc593029 100644 --- a/tests/helpers/providers/test_google_secrets_provider.py +++ b/tests/helpers/providers/test_google_secrets_provider.py @@ -1,18 +1,16 @@ import dlt from dlt import TSecretValue from dlt.common import logger -from dlt.common.configuration.specs import GcpServiceAccountCredentials -from dlt.common.configuration.providers import GoogleSecretsProvider from dlt.common.configuration.accessors import secrets +from dlt.common.configuration.providers import GoogleSecretsProvider +from dlt.common.configuration.resolve import resolve_configuration +from dlt.common.configuration.specs import GcpServiceAccountCredentials, known_sections from dlt.common.configuration.specs.config_providers_context import _google_secrets_provider from dlt.common.configuration.specs.run_configuration import RunConfiguration -from dlt.common.configuration.specs import GcpServiceAccountCredentials, known_sections from dlt.common.typing import AnyType from dlt.common.utils import custom_environ -from dlt.common.configuration.resolve import resolve_configuration - -DLT_SECRETS_TOML_CONTENT=""" +DLT_SECRETS_TOML_CONTENT = """ secret_value=2137 api.secret_key="ABCD" @@ -26,7 +24,9 @@ def test_regular_keys() -> None: logger.init_logging(RunConfiguration()) # copy bigquery credentials into providers credentials - c = resolve_configuration(GcpServiceAccountCredentials(), sections=(known_sections.DESTINATION, "bigquery")) + c = resolve_configuration( + GcpServiceAccountCredentials(), sections=(known_sections.DESTINATION, "bigquery") + ) secrets[f"{known_sections.PROVIDERS}.google_secrets.credentials"] = dict(c) # c = secrets.get("destination.credentials", GcpServiceAccountCredentials) # print(c) @@ -37,22 +37,46 @@ def test_regular_keys() -> None: # load secrets toml per pipeline provider.get_value("secret_key", AnyType, "pipeline", "api") - assert provider.get_value("secret_key", AnyType, "pipeline", "api") == ("ABCDE", "pipeline-api-secret_key") - assert provider.get_value("credentials", AnyType, "pipeline") == ({"project_id": "mock-credentials-pipeline"}, "pipeline-credentials") + assert provider.get_value("secret_key", AnyType, "pipeline", "api") == ( + "ABCDE", + "pipeline-api-secret_key", + ) + assert provider.get_value("credentials", AnyType, "pipeline") == ( + {"project_id": "mock-credentials-pipeline"}, + "pipeline-credentials", + ) # load source test_source which should also load "sources", "pipeline-sources", "sources-test_source" and "pipeline-sources-test_source" - assert provider.get_value("only_pipeline", AnyType, "pipeline", "sources", "test_source") == ("ONLY", "pipeline-sources-test_source-only_pipeline") + assert provider.get_value("only_pipeline", AnyType, "pipeline", "sources", "test_source") == ( + "ONLY", + "pipeline-sources-test_source-only_pipeline", + ) # we set sources.test_source.secret_prop_1="OVR_A" in pipeline-sources to override value in sources - assert provider.get_value("secret_prop_1", AnyType, None, "sources", "test_source") == ("OVR_A", "sources-test_source-secret_prop_1") + assert provider.get_value("secret_prop_1", AnyType, None, "sources", "test_source") == ( + "OVR_A", + "sources-test_source-secret_prop_1", + ) # get element unique to pipeline-sources - assert 
provider.get_value("only_pipeline_top", AnyType, "pipeline", "sources") == ("TOP", "pipeline-sources-only_pipeline_top") + assert provider.get_value("only_pipeline_top", AnyType, "pipeline", "sources") == ( + "TOP", + "pipeline-sources-only_pipeline_top", + ) # get element unique to sources - assert provider.get_value("all_sources_present", AnyType, None, "sources") == (True, "sources-all_sources_present") + assert provider.get_value("all_sources_present", AnyType, None, "sources") == ( + True, + "sources-all_sources_present", + ) # get element unique to sources-test_source - assert provider.get_value("secret_prop_2", AnyType, None, "sources", "test_source") == ("B", "sources-test_source-secret_prop_2") + assert provider.get_value("secret_prop_2", AnyType, None, "sources", "test_source") == ( + "B", + "sources-test_source-secret_prop_2", + ) # this destination will not be found - assert provider.get_value("url", AnyType, "pipeline", "destination", "filesystem") == (None, "pipeline-destination-filesystem-url") + assert provider.get_value("url", AnyType, "pipeline", "destination", "filesystem") == ( + None, + "pipeline-destination-filesystem-url", + ) # try a single secret value assert provider.get_value("secret", TSecretValue, "pipeline") == (None, "pipeline-secret") @@ -63,7 +87,10 @@ def test_regular_keys() -> None: assert provider.get_value("secret", str, "pipeline") == (None, "pipeline-secret") provider.only_secrets = False # non secrets allowed - assert provider.get_value("secret", str, "pipeline") == ("THIS IS SECRET VALUE", "pipeline-secret") + assert provider.get_value("secret", str, "pipeline") == ( + "THIS IS SECRET VALUE", + "pipeline-secret", + ) # request json # print(provider._toml.as_string()) @@ -73,12 +100,12 @@ def test_regular_keys() -> None: # def test_special_sections() -> None: # pass - # with custom_environ({"GOOGLE_APPLICATION_CREDENTIALS": "_secrets/pipelines-ci-secrets-65c0517a9b30.json"}): - # provider = _google_secrets_provider() - # print(provider.get_value("credentials", GcpServiceAccountCredentials, None, "destination", "bigquery")) - # print(provider._toml.as_string()) - # print(provider.get_value("subdomain", AnyType, None, "sources", "zendesk", "credentials")) - # print(provider._toml.as_string()) +# with custom_environ({"GOOGLE_APPLICATION_CREDENTIALS": "_secrets/pipelines-ci-secrets-65c0517a9b30.json"}): +# provider = _google_secrets_provider() +# print(provider.get_value("credentials", GcpServiceAccountCredentials, None, "destination", "bigquery")) +# print(provider._toml.as_string()) +# print(provider.get_value("subdomain", AnyType, None, "sources", "zendesk", "credentials")) +# print(provider._toml.as_string()) # def test_provider_insertion() -> None: @@ -88,4 +115,3 @@ def test_regular_keys() -> None: # }): # # - diff --git a/tests/load/bigquery/test_bigquery_client.py b/tests/load/bigquery/test_bigquery_client.py index 3af9234c89..85a35e2dbf 100644 --- a/tests/load/bigquery/test_bigquery_client.py +++ b/tests/load/bigquery/test_bigquery_client.py @@ -1,26 +1,36 @@ -import os import base64 +import os from copy import copy from typing import Any, Iterator, Tuple + import pytest +from tests.common.configuration.utils import environment +from tests.common.utils import json_case_path as common_json_case_path +from tests.load.utils import ( + cm_yield_client_with_storage, + expect_load_file, + prepare_table, + yield_client_with_storage, +) +from tests.utils import TEST_STORAGE_ROOT, delete_test_storage, preserve_environ -from dlt.common import json, 
pendulum, Decimal +from dlt.common import Decimal, json, pendulum from dlt.common.arithmetics import numeric_default_context from dlt.common.configuration.exceptions import ConfigFieldMissingException from dlt.common.configuration.resolve import resolve_configuration -from dlt.common.configuration.specs import GcpServiceAccountCredentials, GcpServiceAccountCredentialsWithoutDefaults, GcpOAuthCredentials, GcpOAuthCredentialsWithoutDefaults -from dlt.common.configuration.specs import gcp_credentials +from dlt.common.configuration.specs import ( + GcpOAuthCredentials, + GcpOAuthCredentialsWithoutDefaults, + GcpServiceAccountCredentials, + GcpServiceAccountCredentialsWithoutDefaults, + gcp_credentials, +) from dlt.common.configuration.specs.exceptions import InvalidGoogleNativeCredentialsType from dlt.common.storages import FileStorage -from dlt.common.utils import digest128, uniq_id, custom_environ - +from dlt.common.utils import custom_environ, digest128, uniq_id from dlt.destinations.bigquery.bigquery import BigQueryClient, BigQueryClientConfiguration from dlt.destinations.exceptions import LoadJobNotExistsException, LoadJobTerminalException -from tests.utils import TEST_STORAGE_ROOT, delete_test_storage, preserve_environ -from tests.common.utils import json_case_path as common_json_case_path -from tests.common.configuration.utils import environment -from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage, cm_yield_client_with_storage @pytest.fixture(scope="module") def client() -> Iterator[BigQueryClient]: @@ -42,7 +52,7 @@ def test_service_credentials_with_default(environment: Any) -> None: # resolve will miss values and try to find default credentials on the machine with pytest.raises(ConfigFieldMissingException) as py_ex: resolve_configuration(gcpc) - assert py_ex.value.fields == ['project_id', 'private_key', 'client_email'] + assert py_ex.value.fields == ["project_id", "private_key", "client_email"] # prepare real service.json services_str, dest_path = prepare_service_json() @@ -106,7 +116,7 @@ def test_oauth_credentials_with_default(environment: Any) -> None: # resolve will miss values and try to find default credentials on the machine with pytest.raises(ConfigFieldMissingException) as py_ex: resolve_configuration(gcoauth) - assert py_ex.value.fields == ['client_id', 'client_secret', 'refresh_token', 'project_id'] + assert py_ex.value.fields == ["client_id", "client_secret", "refresh_token", "project_id"] # prepare real service.json oauth_str, _ = prepare_oauth_json() @@ -180,7 +190,9 @@ def test_get_oauth_access_token() -> None: def test_bigquery_configuration() -> None: - config = resolve_configuration(BigQueryClientConfiguration(dataset_name="dataset"), sections=("destination", "bigquery")) + config = resolve_configuration( + BigQueryClientConfiguration(dataset_name="dataset"), sections=("destination", "bigquery") + ) assert config.location == "US" assert config.get_location() == "US" assert config.http_timeout == 15.0 @@ -190,16 +202,22 @@ def test_bigquery_configuration() -> None: # credentials location is deprecated os.environ["CREDENTIALS__LOCATION"] = "EU" - config = resolve_configuration(BigQueryClientConfiguration(dataset_name="dataset"), sections=("destination", "bigquery")) + config = resolve_configuration( + BigQueryClientConfiguration(dataset_name="dataset"), sections=("destination", "bigquery") + ) assert config.location == "US" assert config.credentials.location == "EU" # but if it is set, we propagate it to the config assert 
config.get_location() == "EU" os.environ["LOCATION"] = "ATLANTIS" - config = resolve_configuration(BigQueryClientConfiguration(dataset_name="dataset"), sections=("destination", "bigquery")) + config = resolve_configuration( + BigQueryClientConfiguration(dataset_name="dataset"), sections=("destination", "bigquery") + ) assert config.get_location() == "ATLANTIS" os.environ["DESTINATION__FILE_UPLOAD_TIMEOUT"] = "20000" - config = resolve_configuration(BigQueryClientConfiguration(dataset_name="dataset"), sections=("destination", "bigquery")) + config = resolve_configuration( + BigQueryClientConfiguration(dataset_name="dataset"), sections=("destination", "bigquery") + ) assert config.file_upload_timeout == 20000.0 # default fingerprint is empty @@ -230,30 +248,40 @@ def test_bigquery_job_errors(client: BigQueryClient, file_storage: FileStorage) load_json = { "_dlt_id": uniq_id(), "_dlt_root_id": uniq_id(), - "sender_id":'90238094809sajlkjxoiewjhduuiuehd', - "timestamp": str(pendulum.now()) + "sender_id": "90238094809sajlkjxoiewjhduuiuehd", + "timestamp": str(pendulum.now()), } job = expect_load_file(client, file_storage, json.dumps(load_json), user_table_name) # start a job from the same file. it should fallback to retrieve job silently - r_job = client.start_file_load(client.schema.get_table(user_table_name), file_storage.make_full_path(job.file_name()), uniq_id()) + r_job = client.start_file_load( + client.schema.get_table(user_table_name), + file_storage.make_full_path(job.file_name()), + uniq_id(), + ) assert r_job.state() == "completed" -@pytest.mark.parametrize('location', ["US", "EU"]) +@pytest.mark.parametrize("location", ["US", "EU"]) def test_bigquery_location(location: str, file_storage: FileStorage) -> None: - with cm_yield_client_with_storage("bigquery", default_config_values={"location": location}) as client: + with cm_yield_client_with_storage( + "bigquery", default_config_values={"location": location} + ) as client: user_table_name = prepare_table(client) load_json = { "_dlt_id": uniq_id(), "_dlt_root_id": uniq_id(), - "sender_id": '90238094809sajlkjxoiewjhduuiuehd', - "timestamp": str(pendulum.now()) + "sender_id": "90238094809sajlkjxoiewjhduuiuehd", + "timestamp": str(pendulum.now()), } job = expect_load_file(client, file_storage, json.dumps(load_json), user_table_name) # start a job from the same file. 
it should fallback to retrieve job silently - client.start_file_load(client.schema.get_table(user_table_name), file_storage.make_full_path(job.file_name()), uniq_id()) + client.start_file_load( + client.schema.get_table(user_table_name), + file_storage.make_full_path(job.file_name()), + uniq_id(), + ) canonical_name = client.sql_client.make_qualified_table_name(user_table_name, escape=False) t = client.sql_client.native_connection.get_table(canonical_name) assert t.location == location @@ -265,58 +293,84 @@ def test_loading_errors(client: BigQueryClient, file_storage: FileStorage) -> No load_json = { "_dlt_id": uniq_id(), "_dlt_root_id": uniq_id(), - "sender_id":'90238094809sajlkjxoiewjhduuiuehd', - "timestamp": str(pendulum.now()) + "sender_id": "90238094809sajlkjxoiewjhduuiuehd", + "timestamp": str(pendulum.now()), } insert_json = copy(load_json) insert_json["_unk_"] = None - job = expect_load_file(client, file_storage, json.dumps(insert_json), user_table_name, status="failed") + job = expect_load_file( + client, file_storage, json.dumps(insert_json), user_table_name, status="failed" + ) assert "No such field: _unk_" in job.exception() # insert null value insert_json = copy(load_json) insert_json["timestamp"] = None - job = expect_load_file(client, file_storage, json.dumps(insert_json), user_table_name, status="failed") + job = expect_load_file( + client, file_storage, json.dumps(insert_json), user_table_name, status="failed" + ) assert "Only optional fields can be set to NULL. Field: timestamp;" in job.exception() # insert wrong type insert_json = copy(load_json) insert_json["timestamp"] = "AA" - job = expect_load_file(client, file_storage, json.dumps(insert_json), user_table_name, status="failed") + job = expect_load_file( + client, file_storage, json.dumps(insert_json), user_table_name, status="failed" + ) assert "Couldn't convert value to timestamp:" in job.exception() # numeric overflow on bigint insert_json = copy(load_json) # 2**64//2 - 1 is a maximum bigint value - insert_json["metadata__rasa_x_id"] = 2**64//2 - job = expect_load_file(client, file_storage, json.dumps(insert_json), user_table_name, status="failed") + insert_json["metadata__rasa_x_id"] = 2**64 // 2 + job = expect_load_file( + client, file_storage, json.dumps(insert_json), user_table_name, status="failed" + ) assert "Could not convert value" in job.exception() # numeric overflow on NUMERIC insert_json = copy(load_json) # default decimal is (38, 9) (128 bit), use local context to generate decimals with 38 precision with numeric_default_context(): - below_limit = Decimal(10**29) - Decimal('0.001') + below_limit = Decimal(10**29) - Decimal("0.001") above_limit = Decimal(10**29) # this will pass insert_json["parse_data__intent__id"] = below_limit - job = expect_load_file(client, file_storage, json.dumps(insert_json), user_table_name, status="completed") + job = expect_load_file( + client, file_storage, json.dumps(insert_json), user_table_name, status="completed" + ) # this will fail insert_json["parse_data__intent__id"] = above_limit - job = expect_load_file(client, file_storage, json.dumps(insert_json), user_table_name, status="failed") - assert "Invalid NUMERIC value: 100000000000000000000000000000 Field: parse_data__intent__id;" in job.exception() + job = expect_load_file( + client, file_storage, json.dumps(insert_json), user_table_name, status="failed" + ) + assert ( + "Invalid NUMERIC value: 100000000000000000000000000000 Field: parse_data__intent__id;" + in job.exception() + ) # max bigquery decimal is (76, 76) 
(256 bit) = 5.7896044618658097711785492504343953926634992332820282019728792003956564819967E+38 insert_json = copy(load_json) - insert_json["parse_data__metadata__rasa_x_id"] = Decimal("5.7896044618658097711785492504343953926634992332820282019728792003956564819968E+38") - job = expect_load_file(client, file_storage, json.dumps(insert_json), user_table_name, status="failed") - assert "Invalid BIGNUMERIC value: 578960446186580977117854925043439539266.34992332820282019728792003956564819968 Field: parse_data__metadata__rasa_x_id;" in job.exception() + insert_json["parse_data__metadata__rasa_x_id"] = Decimal( + "5.7896044618658097711785492504343953926634992332820282019728792003956564819968E+38" + ) + job = expect_load_file( + client, file_storage, json.dumps(insert_json), user_table_name, status="failed" + ) + assert ( + "Invalid BIGNUMERIC value:" + " 578960446186580977117854925043439539266.34992332820282019728792003956564819968 Field:" + " parse_data__metadata__rasa_x_id;" + in job.exception() + ) def prepare_oauth_json() -> Tuple[str, str]: # prepare real service.json storage = FileStorage("_secrets", makedirs=True) - with open(common_json_case_path("oauth_client_secret_929384042504"), mode="r", encoding="utf-8") as f: + with open( + common_json_case_path("oauth_client_secret_929384042504"), mode="r", encoding="utf-8" + ) as f: oauth_str = f.read() dest_path = storage.save("oauth_client_secret_929384042504.json", oauth_str) return oauth_str, dest_path diff --git a/tests/load/bigquery/test_bigquery_table_builder.py b/tests/load/bigquery/test_bigquery_table_builder.py index 0ab691aa2f..e3f6cc25ba 100644 --- a/tests/load/bigquery/test_bigquery_table_builder.py +++ b/tests/load/bigquery/test_bigquery_table_builder.py @@ -1,18 +1,18 @@ +from copy import deepcopy + import pytest import sqlfluff -from copy import deepcopy +from tests.load.utils import TABLE_UPDATE -from dlt.common.utils import custom_environ, uniq_id -from dlt.common.schema import Schema -from dlt.common.schema.utils import new_table from dlt.common.configuration import resolve_configuration from dlt.common.configuration.specs import GcpServiceAccountCredentialsWithoutDefaults - +from dlt.common.schema import Schema +from dlt.common.schema.utils import new_table +from dlt.common.utils import custom_environ, uniq_id from dlt.destinations.bigquery.bigquery import BigQueryClient from dlt.destinations.bigquery.configuration import BigQueryClientConfiguration from dlt.destinations.exceptions import DestinationSchemaWillNotUpdate -from tests.load.utils import TABLE_UPDATE @pytest.fixture def schema() -> Schema: @@ -35,7 +35,9 @@ def gcp_client(schema: Schema) -> BigQueryClient: # return client without opening connection creds = GcpServiceAccountCredentialsWithoutDefaults() creds.project_id = "test_project_id" - return BigQueryClient(schema, BigQueryClientConfiguration(dataset_name="test_" + uniq_id(), credentials=creds)) + return BigQueryClient( + schema, BigQueryClientConfiguration(dataset_name="test_" + uniq_id(), credentials=creds) + ) def test_create_table(gcp_client: BigQueryClient) -> None: diff --git a/tests/load/cases/fake_destination.py b/tests/load/cases/fake_destination.py index 152b2db918..298ba81e41 100644 --- a/tests/load/cases/fake_destination.py +++ b/tests/load/cases/fake_destination.py @@ -1 +1 @@ -# module that is used to test wrong destination references \ No newline at end of file +# module that is used to test wrong destination references diff --git a/tests/load/conftest.py b/tests/load/conftest.py index 
e44d154b04..c5e97903a1 100644 --- a/tests/load/conftest.py +++ b/tests/load/conftest.py @@ -1,15 +1,13 @@ import os -import pytest from typing import Iterator +import pytest from tests.load.utils import ALL_BUCKETS from tests.utils import preserve_environ -@pytest.fixture(scope='function', params=ALL_BUCKETS) +@pytest.fixture(scope="function", params=ALL_BUCKETS) def all_buckets_env(request) -> Iterator[str]: # type: ignore[no-untyped-def] - """Parametrized fixture to configure filesystem destination bucket in env for each test bucket - """ - os.environ['DESTINATION__FILESYSTEM__BUCKET_URL'] = request.param + """Parametrized fixture to configure filesystem destination bucket in env for each test bucket""" + os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = request.param yield request.param - diff --git a/tests/load/duckdb/test_duckdb_client.py b/tests/load/duckdb/test_duckdb_client.py index 1eac8c1fe4..19f0e83667 100644 --- a/tests/load/duckdb/test_duckdb_client.py +++ b/tests/load/duckdb/test_duckdb_client.py @@ -1,14 +1,19 @@ import os + import pytest +from tests.load.pipeline.utils import assert_table, drop_pipeline +from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage, patch_home_dir, preserve_environ import dlt from dlt.common.configuration.resolve import resolve_configuration from dlt.common.configuration.utils import get_resolved_traces +from dlt.destinations.duckdb.configuration import ( + DEFAULT_DUCK_DB_NAME, + DUCK_DB_NAME, + DuckDbClientConfiguration, + DuckDbCredentials, +) -from dlt.destinations.duckdb.configuration import DUCK_DB_NAME, DuckDbClientConfiguration, DuckDbCredentials, DEFAULT_DUCK_DB_NAME - -from tests.load.pipeline.utils import drop_pipeline, assert_table -from tests.utils import patch_home_dir, autouse_test_storage, preserve_environ, TEST_STORAGE_ROOT @pytest.fixture(autouse=True) def delete_default_duckdb_credentials() -> None: @@ -66,7 +71,9 @@ def test_duckdb_database_path() -> None: os.unlink(db_path) # test special :pipeline: path to create in pipeline folder - c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset", credentials=":pipeline:")) + c = resolve_configuration( + DuckDbClientConfiguration(dataset_name="test_dataset", credentials=":pipeline:") + ) db_path = os.path.abspath(os.path.join(p.working_dir, DEFAULT_DUCK_DB_NAME)) assert c.credentials.database.lower() == db_path.lower() # connect @@ -78,7 +85,11 @@ def test_duckdb_database_path() -> None: # provide relative path db_path = "_storage/test_quack.duckdb" - c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset", credentials="duckdb:///_storage/test_quack.duckdb")) + c = resolve_configuration( + DuckDbClientConfiguration( + dataset_name="test_dataset", credentials="duckdb:///_storage/test_quack.duckdb" + ) + ) assert c.credentials.database.lower() == os.path.abspath(db_path).lower() conn = c.credentials.borrow_conn(read_only=False) c.credentials.return_conn(conn) @@ -87,7 +98,9 @@ def test_duckdb_database_path() -> None: # provide absolute path db_path = os.path.abspath("_storage/abs_test_quack.duckdb") - c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset", credentials=f"duckdb:///{db_path}")) + c = resolve_configuration( + DuckDbClientConfiguration(dataset_name="test_dataset", credentials=f"duckdb:///{db_path}") + ) assert os.path.isabs(c.credentials.database) assert c.credentials.database.lower() == db_path.lower() conn = c.credentials.borrow_conn(read_only=False) @@ -97,7 +110,9 @@ def 
test_duckdb_database_path() -> None: # set just path as credentials db_path = "_storage/path_test_quack.duckdb" - c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset", credentials=db_path)) + c = resolve_configuration( + DuckDbClientConfiguration(dataset_name="test_dataset", credentials=db_path) + ) assert c.credentials.database.lower() == os.path.abspath(db_path).lower() conn = c.credentials.borrow_conn(read_only=False) c.credentials.return_conn(conn) @@ -105,7 +120,9 @@ def test_duckdb_database_path() -> None: p = p.drop() db_path = os.path.abspath("_storage/abs_path_test_quack.duckdb") - c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset", credentials=db_path)) + c = resolve_configuration( + DuckDbClientConfiguration(dataset_name="test_dataset", credentials=db_path) + ) assert os.path.isabs(c.credentials.database) assert c.credentials.database.lower() == db_path.lower() conn = c.credentials.borrow_conn(read_only=False) @@ -117,7 +134,9 @@ def test_duckdb_database_path() -> None: import duckdb with pytest.raises(duckdb.IOException): - c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset", credentials=TEST_STORAGE_ROOT)) + c = resolve_configuration( + DuckDbClientConfiguration(dataset_name="test_dataset", credentials=TEST_STORAGE_ROOT) + ) conn = c.credentials.borrow_conn(read_only=False) @@ -202,7 +221,9 @@ def test_external_duckdb_database() -> None: # pass explicit in memory database conn = duckdb.connect(":memory:") - c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset", credentials=conn)) + c = resolve_configuration( + DuckDbClientConfiguration(dataset_name="test_dataset", credentials=conn) + ) assert c.credentials._conn_borrows == 0 assert c.credentials._conn is conn int_conn = c.credentials.borrow_conn(read_only=False) @@ -214,6 +235,7 @@ def test_external_duckdb_database() -> None: assert hasattr(c.credentials, "_conn") conn.close() + def test_default_duckdb_dataset_name() -> None: # Check if dataset_name does not collide with pipeline_name data = ["a", "b", "c"] diff --git a/tests/load/duckdb/test_duckdb_table_builder.py b/tests/load/duckdb/test_duckdb_table_builder.py index 921e27c4b2..0feaa0de1d 100644 --- a/tests/load/duckdb/test_duckdb_table_builder.py +++ b/tests/load/duckdb/test_duckdb_table_builder.py @@ -1,14 +1,13 @@ -import pytest from copy import deepcopy + +import pytest import sqlfluff +from tests.load.utils import TABLE_UPDATE -from dlt.common.utils import uniq_id from dlt.common.schema import Schema - -from dlt.destinations.duckdb.duck import DuckDbClient +from dlt.common.utils import uniq_id from dlt.destinations.duckdb.configuration import DuckDbClientConfiguration - -from tests.load.utils import TABLE_UPDATE +from dlt.destinations.duckdb.duck import DuckDbClient @pytest.fixture @@ -29,7 +28,7 @@ def test_create_table_with_hints(client: DuckDbClient) -> None: mod_update[0]["sort"] = True mod_update[1]["unique"] = True mod_update[4]["foreign_key"] = True - sql = ';'.join(client._get_table_update_sql("event_test_table", mod_update, False)) + sql = ";".join(client._get_table_update_sql("event_test_table", mod_update, False)) assert '"col1" BIGINT NOT NULL' in sql assert '"col2" DOUBLE NOT NULL' in sql assert '"col5" VARCHAR ' in sql @@ -39,7 +38,10 @@ def test_create_table_with_hints(client: DuckDbClient) -> None: assert '"col4" TIMESTAMP WITH TIME ZONE NOT NULL' in sql # same thing with indexes - client = DuckDbClient(client.schema, 
DuckDbClientConfiguration(dataset_name="test_" + uniq_id(), create_indexes=True)) + client = DuckDbClient( + client.schema, + DuckDbClientConfiguration(dataset_name="test_" + uniq_id(), create_indexes=True), + ) sql = client._get_table_update_sql("event_test_table", mod_update, False)[0] sqlfluff.parse(sql) assert '"col2" DOUBLE UNIQUE NOT NULL' in sql @@ -47,7 +49,7 @@ def test_create_table_with_hints(client: DuckDbClient) -> None: def test_alter_table(client: DuckDbClient) -> None: # existing table has no columns - sql = ';'.join(client._get_table_update_sql("event_test_table", TABLE_UPDATE, True)) + sql = ";".join(client._get_table_update_sql("event_test_table", TABLE_UPDATE, True)) sqlfluff.parse(sql) assert sql.startswith("ALTER TABLE") assert sql.count("ALTER TABLE") == len(TABLE_UPDATE) diff --git a/tests/load/duckdb/test_motherduck_client.py b/tests/load/duckdb/test_motherduck_client.py index 4a167fa016..80291fdac7 100644 --- a/tests/load/duckdb/test_motherduck_client.py +++ b/tests/load/duckdb/test_motherduck_client.py @@ -1,14 +1,17 @@ import os + import pytest +from tests.utils import patch_home_dir, preserve_environ, skip_if_not_active from dlt.common.configuration.resolve import resolve_configuration - -from dlt.destinations.motherduck.configuration import MotherDuckCredentials, MotherDuckClientConfiguration - -from tests.utils import patch_home_dir, preserve_environ, skip_if_not_active +from dlt.destinations.motherduck.configuration import ( + MotherDuckClientConfiguration, + MotherDuckCredentials, +) skip_if_not_active("motherduck") + def test_motherduck_database() -> None: # set HOME env otherwise some internal components in ducdkb (HTTPS) do not initialize os.environ["HOME"] = "/tmp" @@ -20,7 +23,9 @@ def test_motherduck_database() -> None: cred.parse_native_representation("md:///?token=TOKEN") assert cred.password == "TOKEN" - config = resolve_configuration(MotherDuckClientConfiguration(dataset_name="test"), sections=("destination", "motherduck")) + config = resolve_configuration( + MotherDuckClientConfiguration(dataset_name="test"), sections=("destination", "motherduck") + ) # connect con = config.credentials.borrow_conn(read_only=False) con.sql("SHOW DATABASES") diff --git a/tests/load/filesystem/test_aws_credentials.py b/tests/load/filesystem/test_aws_credentials.py index 67f6759639..a1ef8a802f 100644 --- a/tests/load/filesystem/test_aws_credentials.py +++ b/tests/load/filesystem/test_aws_credentials.py @@ -1,26 +1,28 @@ -import pytest from typing import Dict +import pytest +from tests.common.configuration.utils import environment +from tests.load.utils import ALL_FILESYSTEM_DRIVERS +from tests.utils import autouse_test_storage, preserve_environ from dlt.common.configuration import resolve_configuration from dlt.common.configuration.specs.aws_credentials import AwsCredentials from dlt.common.configuration.specs.exceptions import InvalidBoto3Session -from tests.common.configuration.utils import environment -from tests.load.utils import ALL_FILESYSTEM_DRIVERS -from tests.utils import preserve_environ, autouse_test_storage -@pytest.mark.skipif('s3' not in ALL_FILESYSTEM_DRIVERS, reason='s3 filesystem driver not configured') +@pytest.mark.skipif( + "s3" not in ALL_FILESYSTEM_DRIVERS, reason="s3 filesystem driver not configured" +) def test_aws_credentials_resolved_from_default(environment: Dict[str, str]) -> None: - environment['AWS_ACCESS_KEY_ID'] = 'fake_access_key' - environment['AWS_SECRET_ACCESS_KEY'] = 'fake_secret_key' - environment['AWS_SESSION_TOKEN'] = 
'fake_session_token' + environment["AWS_ACCESS_KEY_ID"] = "fake_access_key" + environment["AWS_SECRET_ACCESS_KEY"] = "fake_secret_key" + environment["AWS_SESSION_TOKEN"] = "fake_session_token" config = resolve_configuration(AwsCredentials()) - assert config.aws_access_key_id == 'fake_access_key' - assert config.aws_secret_access_key == 'fake_secret_key' - assert config.aws_session_token == 'fake_session_token' + assert config.aws_access_key_id == "fake_access_key" + assert config.aws_secret_access_key == "fake_secret_key" + assert config.aws_session_token == "fake_session_token" # we do not set the profile assert config.profile_name is None @@ -36,11 +38,13 @@ def test_aws_credentials_resolved_from_default(environment: Dict[str, str]) -> N # assert config.profile_name == "default" -@pytest.mark.skipif('s3' not in ALL_FILESYSTEM_DRIVERS, reason='s3 filesystem driver not configured') +@pytest.mark.skipif( + "s3" not in ALL_FILESYSTEM_DRIVERS, reason="s3 filesystem driver not configured" +) def test_aws_credentials_from_boto3(environment: Dict[str, str]) -> None: - environment['AWS_ACCESS_KEY_ID'] = 'fake_access_key' - environment['AWS_SECRET_ACCESS_KEY'] = 'fake_secret_key' - environment['AWS_SESSION_TOKEN'] = 'fake_session_token' + environment["AWS_ACCESS_KEY_ID"] = "fake_access_key" + environment["AWS_SECRET_ACCESS_KEY"] = "fake_secret_key" + environment["AWS_SESSION_TOKEN"] = "fake_session_token" import boto3 diff --git a/tests/load/filesystem/test_filesystem_client.py b/tests/load/filesystem/test_filesystem_client.py index bbcd011338..6193df2b3c 100644 --- a/tests/load/filesystem/test_filesystem_client.py +++ b/tests/load/filesystem/test_filesystem_client.py @@ -1,16 +1,19 @@ -import posixpath import os +import posixpath import pytest +from tests.load.filesystem.utils import perform_load +from tests.utils import ( + autouse_test_storage, + clean_test_storage, + init_test_logging, + preserve_environ, +) +from dlt.common.storages import FileStorage, LoadStorage from dlt.common.utils import digest128, uniq_id -from dlt.common.storages import LoadStorage, FileStorage - -from dlt.destinations.filesystem.filesystem import LoadFilesystemJob, FilesystemClientConfiguration +from dlt.destinations.filesystem.filesystem import FilesystemClientConfiguration, LoadFilesystemJob -from tests.load.filesystem.utils import perform_load -from tests.utils import clean_test_storage, init_test_logging -from tests.utils import preserve_environ, autouse_test_storage @pytest.fixture(autouse=True) def storage() -> FileStorage: @@ -24,34 +27,38 @@ def logger_autouse() -> None: NORMALIZED_FILES = [ "event_user.839c6e6b514e427687586ccc65bf133f.0.jsonl", - "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl" + "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl", ] ALL_LAYOUTS = ( None, - "{schema_name}/{table_name}/{load_id}.{file_id}.{ext}", # new default layout with schema - "{schema_name}.{table_name}.{load_id}.{file_id}.{ext}", # classic layout - "{table_name}88{load_id}-u-{file_id}.{ext}" # default layout with strange separators + "{schema_name}/{table_name}/{load_id}.{file_id}.{ext}", # new default layout with schema + "{schema_name}.{table_name}.{load_id}.{file_id}.{ext}", # classic layout + "{table_name}88{load_id}-u-{file_id}.{ext}", # default layout with strange separators ) def test_filesystem_configuration() -> None: assert FilesystemClientConfiguration().fingerprint() == "" - assert FilesystemClientConfiguration(bucket_url="s3://cool").fingerprint() == digest128("s3://cool") + 
assert FilesystemClientConfiguration(bucket_url="s3://cool").fingerprint() == digest128( + "s3://cool" + ) -@pytest.mark.parametrize('write_disposition', ('replace', 'append', 'merge')) -@pytest.mark.parametrize('layout', ALL_LAYOUTS) +@pytest.mark.parametrize("write_disposition", ("replace", "append", "merge")) +@pytest.mark.parametrize("layout", ALL_LAYOUTS) def test_successful_load(write_disposition: str, layout: str, all_buckets_env: str) -> None: """Test load is successful with an empty destination dataset""" if layout: - os.environ['DESTINATION__FILESYSTEM__LAYOUT'] = layout + os.environ["DESTINATION__FILESYSTEM__LAYOUT"] = layout else: os.environ.pop("DESTINATION__FILESYSTEM__LAYOUT", None) - dataset_name = 'test_' + uniq_id() + dataset_name = "test_" + uniq_id() - with perform_load(dataset_name, NORMALIZED_FILES, write_disposition=write_disposition) as load_info: + with perform_load( + dataset_name, NORMALIZED_FILES, write_disposition=write_disposition + ) as load_info: client, jobs, _, load_id = load_info layout = client.config.layout dataset_path = posixpath.join(client.fs_path, client.config.dataset_name) @@ -62,77 +69,99 @@ def test_successful_load(write_disposition: str, layout: str, all_buckets_env: s # Sanity check, there are jobs assert jobs for job in jobs: - assert job.state() == 'completed' + assert job.state() == "completed" job_info = LoadStorage.parse_job_file_name(job.file_name()) destination_path = posixpath.join( dataset_path, - layout.format(schema_name=client.schema.name, table_name=job_info.table_name, load_id=load_id, file_id=job_info.file_id, ext=job_info.file_format) + layout.format( + schema_name=client.schema.name, + table_name=job_info.table_name, + load_id=load_id, + file_id=job_info.file_id, + ext=job_info.file_format, + ), ) # File is created with correct filename and path assert client.fs_client.isfile(destination_path) -@pytest.mark.parametrize('layout', ALL_LAYOUTS) +@pytest.mark.parametrize("layout", ALL_LAYOUTS) def test_replace_write_disposition(layout: str, all_buckets_env: str) -> None: if layout: - os.environ['DESTINATION__FILESYSTEM__LAYOUT'] = layout + os.environ["DESTINATION__FILESYSTEM__LAYOUT"] = layout else: os.environ.pop("DESTINATION__FILESYSTEM__LAYOUT", None) - dataset_name = 'test_' + uniq_id() + dataset_name = "test_" + uniq_id() # NOTE: context manager will delete the dataset at the end so keep it open until the end - with perform_load(dataset_name, NORMALIZED_FILES, write_disposition='replace') as load_info: + with perform_load(dataset_name, NORMALIZED_FILES, write_disposition="replace") as load_info: client, _, root_path, load_id1 = load_info layout = client.config.layout # this path will be kept after replace job_2_load_1_path = posixpath.join( root_path, - LoadFilesystemJob.make_destination_filename(layout, NORMALIZED_FILES[1], client.schema.name, load_id1) + LoadFilesystemJob.make_destination_filename( + layout, NORMALIZED_FILES[1], client.schema.name, load_id1 + ), ) - with perform_load(dataset_name, [NORMALIZED_FILES[0]], write_disposition='replace') as load_info: + with perform_load( + dataset_name, [NORMALIZED_FILES[0]], write_disposition="replace" + ) as load_info: client, _, root_path, load_id2 = load_info # this one we expect to be replaced with job_1_load_2_path = posixpath.join( root_path, - LoadFilesystemJob.make_destination_filename(layout, NORMALIZED_FILES[0], client.schema.name, load_id2) + LoadFilesystemJob.make_destination_filename( + layout, NORMALIZED_FILES[0], client.schema.name, load_id2 + ), ) # First file 
from load1 remains, second file is replaced by load2 # assert that only these two files are in the destination folder paths = [] - for basedir, _dirs, files in client.fs_client.walk(client.dataset_path, detail=False, refresh=True): + for basedir, _dirs, files in client.fs_client.walk( + client.dataset_path, detail=False, refresh=True + ): for f in files: paths.append(posixpath.join(basedir, f)) ls = set(paths) assert ls == {job_2_load_1_path, job_1_load_2_path} -@pytest.mark.parametrize('layout', ALL_LAYOUTS) +@pytest.mark.parametrize("layout", ALL_LAYOUTS) def test_append_write_disposition(layout: str, all_buckets_env: str) -> None: """Run load twice with append write_disposition and assert that there are two copies of each file in destination""" if layout: - os.environ['DESTINATION__FILESYSTEM__LAYOUT'] = layout + os.environ["DESTINATION__FILESYSTEM__LAYOUT"] = layout else: os.environ.pop("DESTINATION__FILESYSTEM__LAYOUT", None) - dataset_name = 'test_' + uniq_id() + dataset_name = "test_" + uniq_id() # NOTE: context manager will delete the dataset at the end so keep it open until the end - with perform_load(dataset_name, NORMALIZED_FILES, write_disposition='append') as load_info: + with perform_load(dataset_name, NORMALIZED_FILES, write_disposition="append") as load_info: client, jobs1, root_path, load_id1 = load_info - with perform_load(dataset_name, NORMALIZED_FILES, write_disposition='append') as load_info: + with perform_load(dataset_name, NORMALIZED_FILES, write_disposition="append") as load_info: client, jobs2, root_path, load_id2 = load_info layout = client.config.layout expected_files = [ - LoadFilesystemJob.make_destination_filename(layout, job.file_name(), client.schema.name, load_id1) for job in jobs1 + LoadFilesystemJob.make_destination_filename( + layout, job.file_name(), client.schema.name, load_id1 + ) + for job in jobs1 ] + [ - LoadFilesystemJob.make_destination_filename(layout, job.file_name(), client.schema.name, load_id2) for job in jobs2 + LoadFilesystemJob.make_destination_filename( + layout, job.file_name(), client.schema.name, load_id2 + ) + for job in jobs2 ] expected_files = sorted([posixpath.join(root_path, fn) for fn in expected_files]) paths = [] - for basedir, _dirs, files in client.fs_client.walk(client.dataset_path, detail=False, refresh=True): + for basedir, _dirs, files in client.fs_client.walk( + client.dataset_path, detail=False, refresh=True + ): for f in files: paths.append(posixpath.join(basedir, f)) assert list(sorted(paths)) == expected_files diff --git a/tests/load/filesystem/utils.py b/tests/load/filesystem/utils.py index 489a9995a4..281d84d064 100644 --- a/tests/load/filesystem/utils.py +++ b/tests/load/filesystem/utils.py @@ -1,31 +1,29 @@ import posixpath -from typing import Iterator, List, Sequence, Tuple from contextlib import contextmanager +from typing import Iterator, List, Sequence, Tuple + +from tests.load.utils import prepare_load_package -from dlt.load import Load from dlt.common.configuration.container import Container from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.destination.reference import DestinationReference, LoadJob from dlt.destinations import filesystem from dlt.destinations.filesystem.filesystem import FilesystemClient from dlt.destinations.job_impl import EmptyLoadJob -from tests.load.utils import prepare_load_package +from dlt.load import Load def setup_loader(dataset_name: str) -> Load: destination: DestinationReference = filesystem # type: ignore[assignment] 
config = filesystem.spec()(dataset_name=dataset_name) # setup loader - with Container().injectable_context(ConfigSectionContext(sections=('filesystem',))): - return Load( - destination, - initial_client_config=config - ) + with Container().injectable_context(ConfigSectionContext(sections=("filesystem",))): + return Load(destination, initial_client_config=config) @contextmanager def perform_load( - dataset_name: str, cases: Sequence[str], write_disposition: str='append' + dataset_name: str, cases: Sequence[str], write_disposition: str = "append" ) -> Iterator[Tuple[FilesystemClient, List[LoadJob], str, str]]: load = setup_loader(dataset_name) load_id, schema = prepare_load_package(load.load_storage, cases, write_disposition) @@ -33,9 +31,9 @@ def perform_load( # for the replace disposition in the loader we truncate the tables, so do this here truncate_tables = [] - if write_disposition == 'replace': + if write_disposition == "replace": for item in cases: - parts = item.split('.') + parts = item.split(".") truncate_tables.append(parts[0]) client.initialize_storage(truncate_tables=truncate_tables) diff --git a/tests/load/pipeline/conftest.py b/tests/load/pipeline/conftest.py index 76dc74a555..639c0460b5 100644 --- a/tests/load/pipeline/conftest.py +++ b/tests/load/pipeline/conftest.py @@ -1,3 +1,8 @@ -from tests.utils import patch_home_dir, preserve_environ, autouse_test_storage, duckdb_pipeline_location -from tests.pipeline.utils import drop_dataset_from_env from tests.load.pipeline.utils import drop_pipeline +from tests.pipeline.utils import drop_dataset_from_env +from tests.utils import ( + autouse_test_storage, + duckdb_pipeline_location, + patch_home_dir, + preserve_environ, +) diff --git a/tests/load/pipeline/test_athena.py b/tests/load/pipeline/test_athena.py index 56b63524b0..2ea4d53fca 100644 --- a/tests/load/pipeline/test_athena.py +++ b/tests/load/pipeline/test_athena.py @@ -1,19 +1,30 @@ -import pytest -from copy import deepcopy import datetime # noqa: I251 +from copy import deepcopy + +import pytest +from tests.load.pipeline.utils import ( + DestinationTestConfiguration, + destinations_configs, + load_table_counts, +) +from tests.load.utils import ( + TABLE_ROW_ALL_DATA_TYPES, + TABLE_UPDATE_COLUMNS_SCHEMA, + assert_all_data_types_row, +) +from tests.pipeline.utils import assert_load_info import dlt from dlt.common import pendulum from dlt.common.utils import uniq_id -from tests.load.pipeline.utils import load_table_counts -from tests.load.utils import TABLE_UPDATE_COLUMNS_SCHEMA, TABLE_ROW_ALL_DATA_TYPES, assert_all_data_types_row -from tests.pipeline.utils import assert_load_info -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True, subset=["athena"]), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["athena"]), + ids=lambda x: x.name, +) def test_athena_destinations(destination_config: DestinationTestConfiguration) -> None: - pipeline = destination_config.setup_pipeline("athena_" + uniq_id(), full_refresh=True) @dlt.resource(name="items", write_disposition="append") @@ -21,19 +32,15 @@ def items(): yield { "id": 1, "name": "item", - "sub_items": [{ - "id": 101, - "name": "sub item 101" - },{ - "id": 101, - "name": "sub item 102" - }] + "sub_items": [{"id": 101, "name": "sub item 101"}, {"id": 101, "name": "sub item 102"}], } pipeline.run(items) # see if we have athena 
tables with items - table_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values() ]) + table_counts = load_table_counts( + pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values()] + ) assert table_counts["items"] == 1 assert table_counts["items__sub_items"] == 2 assert table_counts["_dlt_loads"] == 1 @@ -45,25 +52,37 @@ def items2(): "id": 1, "name": "item", "new_field": "hello", - "sub_items": [{ - "id": 101, - "name": "sub item 101", - "other_new_field": "hello 101", - },{ - "id": 101, - "name": "sub item 102", - "other_new_field": "hello 102", - }] + "sub_items": [ + { + "id": 101, + "name": "sub item 101", + "other_new_field": "hello 101", + }, + { + "id": 101, + "name": "sub item 102", + "other_new_field": "hello 102", + }, + ], } + pipeline.run(items2) - table_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values()]) + table_counts = load_table_counts( + pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values()] + ) assert table_counts["items"] == 2 assert table_counts["items__sub_items"] == 4 assert table_counts["_dlt_loads"] == 2 -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True, subset=["athena"]), ids=lambda x: x.name) -def test_athena_all_datatypes_and_timestamps(destination_config: DestinationTestConfiguration) -> None: +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["athena"]), + ids=lambda x: x.name, +) +def test_athena_all_datatypes_and_timestamps( + destination_config: DestinationTestConfiguration, +) -> None: pipeline = destination_config.setup_pipeline("athena_" + uniq_id(), full_refresh=True) data_types = deepcopy(TABLE_ROW_ALL_DATA_TYPES) column_schemas = deepcopy(TABLE_UPDATE_COLUMNS_SCHEMA) @@ -72,7 +91,7 @@ def test_athena_all_datatypes_and_timestamps(destination_config: DestinationTest @dlt.resource(table_name="data_types", write_disposition="append", columns=column_schemas) def my_resource(): nonlocal data_types - yield [data_types]*10 + yield [data_types] * 10 @dlt.source(max_table_nesting=0) def my_source(): @@ -86,35 +105,59 @@ def my_source(): assert len(db_rows) == 10 db_row = list(db_rows[0]) # content must equal - assert_all_data_types_row(db_row[:-2], parse_complex_strings=True, timestamp_precision=sql_client.capabilities.timestamp_precision) + assert_all_data_types_row( + db_row[:-2], + parse_complex_strings=True, + timestamp_precision=sql_client.capabilities.timestamp_precision, + ) # now let's query the data with timestamps and dates. 
# https://docs.aws.amazon.com/athena/latest/ug/engine-versions-reference-0003.html#engine-versions-reference-0003-timestamp-changes # use string representation TIMESTAMP(2) - db_rows = sql_client.execute_sql("SELECT * FROM data_types WHERE col4 = TIMESTAMP '2022-05-23 13:26:45.176'") + db_rows = sql_client.execute_sql( + "SELECT * FROM data_types WHERE col4 = TIMESTAMP '2022-05-23 13:26:45.176'" + ) assert len(db_rows) == 10 # no rows - TIMESTAMP(6) not supported - db_rows = sql_client.execute_sql("SELECT * FROM data_types WHERE col4 = TIMESTAMP '2022-05-23 13:26:45.176145'") + db_rows = sql_client.execute_sql( + "SELECT * FROM data_types WHERE col4 = TIMESTAMP '2022-05-23 13:26:45.176145'" + ) assert len(db_rows) == 0 # use pendulum # that will pass - db_rows = sql_client.execute_sql("SELECT * FROM data_types WHERE col4 = %s", pendulum.datetime(2022, 5, 23, 13, 26, 45, 176000)) + db_rows = sql_client.execute_sql( + "SELECT * FROM data_types WHERE col4 = %s", + pendulum.datetime(2022, 5, 23, 13, 26, 45, 176000), + ) assert len(db_rows) == 10 # that will return empty list - db_rows = sql_client.execute_sql("SELECT * FROM data_types WHERE col4 = %s", pendulum.datetime(2022, 5, 23, 13, 26, 45, 176145)) + db_rows = sql_client.execute_sql( + "SELECT * FROM data_types WHERE col4 = %s", + pendulum.datetime(2022, 5, 23, 13, 26, 45, 176145), + ) assert len(db_rows) == 0 # use datetime - db_rows = sql_client.execute_sql("SELECT * FROM data_types WHERE col4 = %s", datetime.datetime(2022, 5, 23, 13, 26, 45, 176000)) + db_rows = sql_client.execute_sql( + "SELECT * FROM data_types WHERE col4 = %s", + datetime.datetime(2022, 5, 23, 13, 26, 45, 176000), + ) assert len(db_rows) == 10 - db_rows = sql_client.execute_sql("SELECT * FROM data_types WHERE col4 = %s", datetime.datetime(2022, 5, 23, 13, 26, 45, 176145)) + db_rows = sql_client.execute_sql( + "SELECT * FROM data_types WHERE col4 = %s", + datetime.datetime(2022, 5, 23, 13, 26, 45, 176145), + ) assert len(db_rows) == 0 # check date db_rows = sql_client.execute_sql("SELECT * FROM data_types WHERE col10 = DATE '2023-02-27'") assert len(db_rows) == 10 - db_rows = sql_client.execute_sql("SELECT * FROM data_types WHERE col10 = %s", pendulum.date(2023, 2, 27)) + db_rows = sql_client.execute_sql( + "SELECT * FROM data_types WHERE col10 = %s", pendulum.date(2023, 2, 27) + ) assert len(db_rows) == 10 - db_rows = sql_client.execute_sql("SELECT * FROM data_types WHERE col10 = %s", datetime.date(2023, 2, 27)) + db_rows = sql_client.execute_sql( + "SELECT * FROM data_types WHERE col10 = %s", datetime.date(2023, 2, 27) + ) assert len(db_rows) == 10 diff --git a/tests/load/pipeline/test_dbt_helper.py b/tests/load/pipeline/test_dbt_helper.py index e55f5b2964..125da2d48b 100644 --- a/tests/load/pipeline/test_dbt_helper.py +++ b/tests/load/pipeline/test_dbt_helper.py @@ -1,7 +1,14 @@ import os +import tempfile from typing import Iterator + import pytest -import tempfile +from tests.load.pipeline.utils import ( + DestinationTestConfiguration, + destinations_configs, + select_data, +) +from tests.utils import ACTIVE_SQL_DESTINATIONS import dlt from dlt.common.runners import Venv @@ -10,10 +17,6 @@ from dlt.helpers.dbt import create_venv from dlt.helpers.dbt.exceptions import DBTProcessingError, PrerequisitesException -from tests.load.pipeline.utils import select_data -from tests.utils import ACTIVE_SQL_DESTINATIONS -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration - # uncomment add motherduck tests # NOTE: the tests are passing but 
we disable them due to frequent ATTACH DATABASE timeouts # ACTIVE_DESTINATIONS += ["motherduck"] @@ -27,10 +30,16 @@ def dbt_venv() -> Iterator[Venv]: yield venv -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) -def test_run_jaffle_package(destination_config: DestinationTestConfiguration, dbt_venv: Venv) -> None: +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) +def test_run_jaffle_package( + destination_config: DestinationTestConfiguration, dbt_venv: Venv +) -> None: if destination_config.destination == "athena": - pytest.skip("dbt-athena requires database to be created and we don't do it in case of Jaffle") + pytest.skip( + "dbt-athena requires database to be created and we don't do it in case of Jaffle" + ) pipeline = destination_config.setup_pipeline("jaffle_jaffle", full_refresh=True) # get runner, pass the env from fixture dbt = dlt.dbt.package(pipeline, "https://github.com/dbt-labs/jaffle_shop.git", venv=dbt_venv) @@ -55,14 +64,18 @@ def test_run_jaffle_package(destination_config: DestinationTestConfiguration, db assert len(orders) == 99 -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_run_chess_dbt(destination_config: DestinationTestConfiguration, dbt_venv: Venv) -> None: from docs.examples.chess.chess import chess # provide chess url via environ os.environ["CHESS_URL"] = "https://api.chess.com/pub/" - pipeline = destination_config.setup_pipeline("chess_games", dataset_name="chess_dbt_test", full_refresh=True) + pipeline = destination_config.setup_pipeline( + "chess_games", dataset_name="chess_dbt_test", full_refresh=True + ) assert pipeline.default_schema_name is None # get the runner for the "dbt_transform" package transforms = dlt.dbt.package(pipeline, "docs/examples/chess/dbt_transform", venv=dbt_venv) @@ -79,27 +92,41 @@ def test_run_chess_dbt(destination_config: DestinationTestConfiguration, dbt_ven transforms.run_all(source_tests_selector="source:*") # run all the tests transforms.test() - load_ids = select_data(pipeline, "SELECT load_id, schema_name, status FROM _dlt_loads ORDER BY status") + load_ids = select_data( + pipeline, "SELECT load_id, schema_name, status FROM _dlt_loads ORDER BY status" + ) assert len(load_ids) == 2 - view_player_games = select_data(pipeline, "SELECT * FROM view_player_games ORDER BY username, uuid") + view_player_games = select_data( + pipeline, "SELECT * FROM view_player_games ORDER BY username, uuid" + ) assert len(view_player_games) > 0 # run again transforms.run() # no new load ids - no new data in view table - new_load_ids = select_data(pipeline, "SELECT load_id, schema_name, status FROM _dlt_loads ORDER BY status") - new_view_player_games = select_data(pipeline, "SELECT * FROM view_player_games ORDER BY username, uuid") + new_load_ids = select_data( + pipeline, "SELECT load_id, schema_name, status FROM _dlt_loads ORDER BY status" + ) + new_view_player_games = select_data( + pipeline, "SELECT * FROM view_player_games ORDER BY username, uuid" + ) assert load_ids == new_load_ids assert view_player_games == new_view_player_games -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) -def test_run_chess_dbt_to_other_dataset(destination_config: 
DestinationTestConfiguration, dbt_venv: Venv) -> None: +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) +def test_run_chess_dbt_to_other_dataset( + destination_config: DestinationTestConfiguration, dbt_venv: Venv +) -> None: from docs.examples.chess.chess import chess # provide chess url via environ os.environ["CHESS_URL"] = "https://api.chess.com/pub/" - pipeline = destination_config.setup_pipeline("chess_games", dataset_name="chess_dbt_test", full_refresh=True) + pipeline = destination_config.setup_pipeline( + "chess_games", dataset_name="chess_dbt_test", full_refresh=True + ) # load each schema in separate dataset pipeline.config.use_single_dataset = False # assert pipeline.default_schema_name is None @@ -122,12 +149,18 @@ def test_run_chess_dbt_to_other_dataset(destination_config: DestinationTestConfi # run tests on destination dataset where transformations actually are transforms.test(destination_dataset_name=info.dataset_name + "_" + test_suffix) # get load ids from the source dataset - load_ids = select_data(pipeline, "SELECT load_id, schema_name, status FROM _dlt_loads ORDER BY status") + load_ids = select_data( + pipeline, "SELECT load_id, schema_name, status FROM _dlt_loads ORDER BY status" + ) assert len(load_ids) == 1 # status is 0, no more entries assert load_ids[0][2] == 0 # get from destination dataset - load_ids = select_data(pipeline, "SELECT load_id, schema_name, status FROM _dlt_loads ORDER BY status", schema_name=test_suffix) + load_ids = select_data( + pipeline, + "SELECT load_id, schema_name, status FROM _dlt_loads ORDER BY status", + schema_name=test_suffix, + ) # TODO: the package is not finished, both results should be here assert len(load_ids) == 1 # status is 1, no more entries diff --git a/tests/load/pipeline/test_drop.py b/tests/load/pipeline/test_drop.py index 43030aa5d3..5e6b6f6689 100644 --- a/tests/load/pipeline/test_drop.py +++ b/tests/load/pipeline/test_drop.py @@ -1,58 +1,63 @@ -from typing import Any, Iterator, Dict, Any, List -from unittest import mock from itertools import chain +from typing import Any, Dict, Iterator, List +from unittest import mock import pytest +from tests.load.pipeline.utils import DestinationTestConfiguration, destinations_configs import dlt -from dlt.extract.source import DltResource from dlt.common.utils import uniq_id -from dlt.pipeline import helpers, state_sync, Pipeline +from dlt.destinations.job_client_impl import SqlJobClientBase +from dlt.extract.source import DltResource from dlt.load import Load +from dlt.pipeline import Pipeline, helpers, state_sync from dlt.pipeline.exceptions import PipelineStepFailed -from dlt.destinations.job_client_impl import SqlJobClientBase - -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration def _attach(pipeline: Pipeline) -> Pipeline: return dlt.attach(pipeline.pipeline_name, pipeline.pipelines_dir) -@dlt.source(section='droppable', name='droppable') +@dlt.source(section="droppable", name="droppable") def droppable_source() -> List[DltResource]: @dlt.resource - def droppable_a(a: dlt.sources.incremental[int]=dlt.sources.incremental('a', 0)) -> Iterator[Dict[str, Any]]: + def droppable_a( + a: dlt.sources.incremental[int] = dlt.sources.incremental("a", 0) + ) -> Iterator[Dict[str, Any]]: yield dict(a=1, b=2, c=3) yield dict(a=4, b=23, c=24) - @dlt.resource - def droppable_b(asd: dlt.sources.incremental[int]=dlt.sources.incremental('asd', 0)) -> Iterator[Dict[str, Any]]: + def 
droppable_b( + asd: dlt.sources.incremental[int] = dlt.sources.incremental("asd", 0) + ) -> Iterator[Dict[str, Any]]: # Child table yield dict(asd=2323, qe=555, items=[dict(m=1, n=2), dict(m=3, n=4)]) - @dlt.resource - def droppable_c(qe: dlt.sources.incremental[int] = dlt.sources.incremental('qe')) -> Iterator[Dict[str, Any]]: + def droppable_c( + qe: dlt.sources.incremental[int] = dlt.sources.incremental("qe"), + ) -> Iterator[Dict[str, Any]]: # Grandchild table - yield dict(asdasd=2424, qe=111, items=[ - dict(k=2, r=2, labels=[dict(name='abc'), dict(name='www')]) - ]) + yield dict( + asdasd=2424, qe=111, items=[dict(k=2, r=2, labels=[dict(name="abc"), dict(name="www")])] + ) @dlt.resource - def droppable_d(o: dlt.sources.incremental[int] = dlt.sources.incremental('o')) -> Iterator[List[Dict[str, Any]]]: - dlt.state()['data_from_d'] = {'foo1': {'bar': 1}, 'foo2': {'bar': 2}} + def droppable_d( + o: dlt.sources.incremental[int] = dlt.sources.incremental("o"), + ) -> Iterator[List[Dict[str, Any]]]: + dlt.state()["data_from_d"] = {"foo1": {"bar": 1}, "foo2": {"bar": 2}} yield [dict(o=55), dict(o=22)] return [droppable_a(), droppable_b(), droppable_c(), droppable_d()] RESOURCE_TABLES = dict( - droppable_a=['droppable_a'], - droppable_b=['droppable_b', 'droppable_b__items'], - droppable_c=['droppable_c', 'droppable_c__items', 'droppable_c__items__labels'], - droppable_d=['droppable_d'] + droppable_a=["droppable_a"], + droppable_b=["droppable_b", "droppable_b__items"], + droppable_c=["droppable_c", "droppable_c__items", "droppable_c__items__labels"], + droppable_d=["droppable_d"], ) @@ -60,12 +65,13 @@ def assert_dropped_resources(pipeline: Pipeline, resources: List[str]) -> None: assert_dropped_resource_tables(pipeline, resources) assert_dropped_resource_states(pipeline, resources) + def assert_dropped_resource_tables(pipeline: Pipeline, resources: List[str]) -> None: # Verify only requested resource tables are removed from pipeline schema all_tables = set(chain.from_iterable(RESOURCE_TABLES.values())) dropped_tables = set(chain.from_iterable(RESOURCE_TABLES[r] for r in resources)) expected_tables = all_tables - dropped_tables - result_tables = set(t['name'] for t in pipeline.default_schema.data_tables()) + result_tables = set(t["name"] for t in pipeline.default_schema.data_tables()) assert result_tables == expected_tables # Verify requested tables are dropped from destination @@ -85,124 +91,146 @@ def assert_dropped_resource_states(pipeline: Pipeline, resources: List[str]) -> # Verify only requested resource keys are removed from state all_resources = set(RESOURCE_TABLES.keys()) expected_keys = all_resources - set(resources) - sources_state = pipeline.state['sources'] # type: ignore[typeddict-item] - result_keys = set(sources_state['droppable']['resources'].keys()) + sources_state = pipeline.state["sources"] # type: ignore[typeddict-item] + result_keys = set(sources_state["droppable"]["resources"].keys()) assert result_keys == expected_keys def assert_destination_state_loaded(pipeline: Pipeline) -> None: """Verify stored destination state matches the local pipeline state""" with pipeline.sql_client() as sql_client: - destination_state = state_sync.load_state_from_destination(pipeline.pipeline_name, sql_client) + destination_state = state_sync.load_state_from_destination( + pipeline.pipeline_name, sql_client + ) pipeline_state = dict(pipeline.state) - del pipeline_state['_local'] + del pipeline_state["_local"] assert pipeline_state == destination_state 
-@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_drop_command_resources_and_state(destination_config: DestinationTestConfiguration) -> None: """Test the drop command with resource and state path options and verify correct data is deleted from destination and locally""" source = droppable_source() - pipeline = destination_config.setup_pipeline('drop_test_' + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) pipeline.run(source) attached = _attach(pipeline) - helpers.drop(attached, resources=['droppable_c', 'droppable_d'], state_paths='data_from_d.*.bar') + helpers.drop( + attached, resources=["droppable_c", "droppable_d"], state_paths="data_from_d.*.bar" + ) attached = _attach(pipeline) - assert_dropped_resources(attached, ['droppable_c', 'droppable_d']) + assert_dropped_resources(attached, ["droppable_c", "droppable_d"]) # Verify extra json paths are removed from state - sources_state = pipeline.state['sources'] # type: ignore[typeddict-item] - assert sources_state['droppable']['data_from_d'] == {'foo1': {}, 'foo2': {}} + sources_state = pipeline.state["sources"] # type: ignore[typeddict-item] + assert sources_state["droppable"]["data_from_d"] == {"foo1": {}, "foo2": {}} assert_destination_state_loaded(pipeline) -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_drop_destination_tables_fails(destination_config: DestinationTestConfiguration) -> None: """Fail on drop tables. Command runs again.""" source = droppable_source() - pipeline = destination_config.setup_pipeline('drop_test_' + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) pipeline.run(source) attached = _attach(pipeline) - with mock.patch.object(helpers.DropCommand, '_drop_destination_tables', side_effect=RuntimeError("Something went wrong")): + with mock.patch.object( + helpers.DropCommand, + "_drop_destination_tables", + side_effect=RuntimeError("Something went wrong"), + ): with pytest.raises(RuntimeError): - helpers.drop(attached, resources=('droppable_a', 'droppable_b')) + helpers.drop(attached, resources=("droppable_a", "droppable_b")) attached = _attach(pipeline) - helpers.drop(attached, resources=('droppable_a', 'droppable_b')) + helpers.drop(attached, resources=("droppable_a", "droppable_b")) - assert_dropped_resources(attached, ['droppable_a', 'droppable_b']) + assert_dropped_resources(attached, ["droppable_a", "droppable_b"]) assert_destination_state_loaded(attached) -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_fail_after_drop_tables(destination_config: DestinationTestConfiguration) -> None: """Fail directly after drop tables. 
Command runs again ignoring destination tables missing.""" source = droppable_source() - pipeline = destination_config.setup_pipeline('drop_test_' + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) pipeline.run(source) attached = _attach(pipeline) - with mock.patch.object(helpers.DropCommand, '_drop_state_keys', side_effect=RuntimeError("Something went wrong")): + with mock.patch.object( + helpers.DropCommand, "_drop_state_keys", side_effect=RuntimeError("Something went wrong") + ): with pytest.raises(RuntimeError): - helpers.drop(attached, resources=('droppable_a', 'droppable_b')) + helpers.drop(attached, resources=("droppable_a", "droppable_b")) attached = _attach(pipeline) - helpers.drop(attached, resources=('droppable_a', 'droppable_b')) + helpers.drop(attached, resources=("droppable_a", "droppable_b")) - assert_dropped_resources(attached, ['droppable_a', 'droppable_b']) + assert_dropped_resources(attached, ["droppable_a", "droppable_b"]) assert_destination_state_loaded(attached) -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_load_step_fails(destination_config: DestinationTestConfiguration) -> None: """Test idempotency. pipeline.load() fails. Command can be run again successfully""" source = droppable_source() - pipeline = destination_config.setup_pipeline('drop_test_' + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) pipeline.run(source) attached = _attach(pipeline) - with mock.patch.object(Load, 'run', side_effect=RuntimeError("Something went wrong")): + with mock.patch.object(Load, "run", side_effect=RuntimeError("Something went wrong")): with pytest.raises(PipelineStepFailed) as e: - helpers.drop(attached, resources=('droppable_a', 'droppable_b')) + helpers.drop(attached, resources=("droppable_a", "droppable_b")) assert isinstance(e.value.exception, RuntimeError) attached = _attach(pipeline) - helpers.drop(attached, resources=('droppable_a', 'droppable_b')) + helpers.drop(attached, resources=("droppable_a", "droppable_b")) - assert_dropped_resources(attached, ['droppable_a', 'droppable_b']) + assert_dropped_resources(attached, ["droppable_a", "droppable_b"]) assert_destination_state_loaded(attached) -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_resource_regex(destination_config: DestinationTestConfiguration) -> None: source = droppable_source() - pipeline = destination_config.setup_pipeline('drop_test_' + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) pipeline.run(source) attached = _attach(pipeline) - helpers.drop(attached, resources=['re:.+_b', 're:.+_a']) + helpers.drop(attached, resources=["re:.+_b", "re:.+_a"]) attached = _attach(pipeline) - assert_dropped_resources(attached, ['droppable_a', 'droppable_b']) + assert_dropped_resources(attached, ["droppable_a", "droppable_b"]) assert_destination_state_loaded(attached) -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + 
"destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_drop_nothing(destination_config: DestinationTestConfiguration) -> None: """No resources, no state keys. Nothing is changed.""" source = droppable_source() - pipeline = destination_config.setup_pipeline('drop_test_' + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) pipeline.run(source) attached = _attach(pipeline) @@ -214,13 +242,17 @@ def test_drop_nothing(destination_config: DestinationTestConfiguration) -> None: assert previous_state == attached.state -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_drop_all_flag(destination_config: DestinationTestConfiguration) -> None: """Using drop_all flag. Destination dataset and all local state is deleted""" source = droppable_source() - pipeline = destination_config.setup_pipeline('drop_test_' + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) pipeline.run(source) - dlt_tables = [t['name'] for t in pipeline.default_schema.dlt_tables()] # Original _dlt tables to check for + dlt_tables = [ + t["name"] for t in pipeline.default_schema.dlt_tables() + ] # Original _dlt tables to check for attached = _attach(pipeline) @@ -237,15 +269,17 @@ def test_drop_all_flag(destination_config: DestinationTestConfiguration) -> None assert exists -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_run_pipeline_after_partial_drop(destination_config: DestinationTestConfiguration) -> None: """Pipeline can be run again after dropping some resources""" - pipeline = destination_config.setup_pipeline('drop_test_' + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) pipeline.run(droppable_source()) attached = _attach(pipeline) - helpers.drop(attached, resources='droppable_a') + helpers.drop(attached, resources="droppable_a") attached = _attach(pipeline) @@ -254,23 +288,26 @@ def test_run_pipeline_after_partial_drop(destination_config: DestinationTestConf attached.load(raise_on_failed_jobs=True) -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_drop_state_only(destination_config: DestinationTestConfiguration) -> None: """Pipeline can be run again after dropping some resources""" - pipeline = destination_config.setup_pipeline('drop_test_' + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) pipeline.run(droppable_source()) attached = _attach(pipeline) - helpers.drop(attached, resources=('droppable_a', 'droppable_b'), state_only=True) + helpers.drop(attached, resources=("droppable_a", "droppable_b"), state_only=True) attached = _attach(pipeline) assert_dropped_resource_tables(attached, []) # No tables dropped - assert_dropped_resource_states(attached, ['droppable_a', 'droppable_b']) + 
assert_dropped_resource_states(attached, ["droppable_a", "droppable_b"]) assert_destination_state_loaded(attached) -if __name__ == '__main__': +if __name__ == "__main__": import pytest - pytest.main(['-k', 'drop_all', 'tests/load/pipeline/test_drop.py', '--pdb', '-s']) + + pytest.main(["-k", "drop_all", "tests/load/pipeline/test_drop.py", "--pdb", "-s"]) diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index 82be81f337..6df83ba198 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -1,21 +1,26 @@ +import os import posixpath from pathlib import Path -import dlt, os -from dlt.common.utils import uniq_id +import pyarrow.parquet as pq + +import dlt +from dlt.common.schema.typing import LOADS_TABLE_NAME from dlt.common.storages.load_storage import LoadJobInfo +from dlt.common.utils import uniq_id from dlt.destinations.filesystem.filesystem import FilesystemClient, LoadFilesystemJob -from dlt.common.schema.typing import LOADS_TABLE_NAME - -import pyarrow.parquet as pq -def assert_file_matches(layout: str, job: LoadJobInfo, load_id: str, client: FilesystemClient) -> None: +def assert_file_matches( + layout: str, job: LoadJobInfo, load_id: str, client: FilesystemClient +) -> None: """Verify file contents of load job are identical to the corresponding file in destination""" local_path = Path(job.file_path) filename = local_path.name - destination_fn = LoadFilesystemJob.make_destination_filename(layout, filename, client.schema.name, load_id) + destination_fn = LoadFilesystemJob.make_destination_filename( + layout, filename, client.schema.name, load_id + ) destination_path = posixpath.join(client.dataset_path, destination_fn) assert local_path.read_bytes() == client.fs_client.read_bytes(destination_path) @@ -25,11 +30,15 @@ def test_pipeline_merge_write_disposition(all_buckets_env: str) -> None: """Run pipeline twice with merge write disposition Resource with primary key falls back to append. Resource without keys falls back to replace. 
""" - pipeline = dlt.pipeline(pipeline_name='test_' + uniq_id(), destination="filesystem", dataset_name='test_' + uniq_id()) + pipeline = dlt.pipeline( + pipeline_name="test_" + uniq_id(), + destination="filesystem", + dataset_name="test_" + uniq_id(), + ) - @dlt.resource(primary_key='id') + @dlt.resource(primary_key="id") def some_data(): # type: ignore[no-untyped-def] - yield [{'id': 1}, {'id': 2}, {'id': 3}] + yield [{"id": 1}, {"id": 2}, {"id": 3}] @dlt.resource def other_data(): # type: ignore[no-untyped-def] @@ -39,8 +48,8 @@ def other_data(): # type: ignore[no-untyped-def] def some_source(): # type: ignore[no-untyped-def] return [some_data(), other_data()] - info1 = pipeline.run(some_source(), write_disposition='merge') - info2 = pipeline.run(some_source(), write_disposition='merge') + info1 = pipeline.run(some_source(), write_disposition="merge") + info2 = pipeline.run(some_source(), write_disposition="merge") client: FilesystemClient = pipeline._destination_client() # type: ignore[assignment] layout = client.config.layout @@ -67,10 +76,9 @@ def some_source(): # type: ignore[no-untyped-def] # Verify file contents assert info2.load_packages for pkg in info2.load_packages: - assert pkg.jobs['completed_jobs'] - for job in pkg.jobs['completed_jobs']: - assert_file_matches(layout, job, pkg.load_id, client) - + assert pkg.jobs["completed_jobs"] + for job in pkg.jobs["completed_jobs"]: + assert_file_matches(layout, job, pkg.load_id, client) complete_fn = f"{client.schema.name}.{LOADS_TABLE_NAME}.%s" @@ -79,7 +87,7 @@ def some_source(): # type: ignore[no-untyped-def] assert client.fs_client.isfile(posixpath.join(client.dataset_path, complete_fn % load_id2)) # Force replace - pipeline.run(some_source(), write_disposition='replace') + pipeline.run(some_source(), write_disposition="replace") append_files = client.fs_client.ls(append_glob, detail=False, refresh=True) replace_files = client.fs_client.ls(replace_glob, detail=False, refresh=True) assert len(append_files) == 1 @@ -87,14 +95,17 @@ def some_source(): # type: ignore[no-untyped-def] def test_pipeline_parquet_filesystem_destination() -> None: - # store locally - os.environ['DESTINATION__FILESYSTEM__BUCKET_URL'] = "file://_storage" - pipeline = dlt.pipeline(pipeline_name='parquet_test_' + uniq_id(), destination="filesystem", dataset_name='parquet_test_' + uniq_id()) - - @dlt.resource(primary_key='id') + os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "file://_storage" + pipeline = dlt.pipeline( + pipeline_name="parquet_test_" + uniq_id(), + destination="filesystem", + dataset_name="parquet_test_" + uniq_id(), + ) + + @dlt.resource(primary_key="id") def some_data(): # type: ignore[no-untyped-def] - yield [{'id': 1}, {'id': 2}, {'id': 3}] + yield [{"id": 1}, {"id": 2}, {"id": 3}] @dlt.resource def other_data(): # type: ignore[no-untyped-def] @@ -113,8 +124,8 @@ def some_source(): # type: ignore[no-untyped-def] assert len(package_info.jobs["completed_jobs"]) == 3 client: FilesystemClient = pipeline._destination_client() # type: ignore[assignment] - some_data_glob = posixpath.join(client.dataset_path, 'some_data/*') - other_data_glob = posixpath.join(client.dataset_path, 'other_data/*') + some_data_glob = posixpath.join(client.dataset_path, "some_data/*") + other_data_glob = posixpath.join(client.dataset_path, "other_data/*") some_data_files = client.fs_client.glob(some_data_glob) other_data_files = client.fs_client.glob(other_data_glob) diff --git a/tests/load/pipeline/test_merge_disposition.py 
b/tests/load/pipeline/test_merge_disposition.py index b2bcc7942b..bcbb0cdb41 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -1,13 +1,19 @@ -from copy import copy -import pytest import itertools import random +from copy import copy from typing import List + import pytest import yaml +from tests.load.pipeline.utils import ( + DestinationTestConfiguration, + destinations_configs, + load_table_counts, + select_data, +) +from tests.pipeline.utils import assert_load_info import dlt - from dlt.common import json, pendulum from dlt.common.configuration.container import Container from dlt.common.pipeline import StateInjectableContext @@ -16,23 +22,25 @@ from dlt.extract.source import DltResource from dlt.sources.helpers.transform import skip_first, take_first -from tests.pipeline.utils import assert_load_info -from tests.load.pipeline.utils import load_table_counts, select_data -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration - # uncomment add motherduck tests # NOTE: the tests are passing but we disable them due to frequent ATTACH DATABASE timeouts # ACTIVE_DESTINATIONS += ["motherduck"] -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_merge_on_keys_in_schema(destination_config: DestinationTestConfiguration) -> None: p = destination_config.setup_pipeline("eth_2", full_refresh=True) with open("tests/common/cases/schemas/eth/ethereum_schema_v5.yml", "r", encoding="utf-8") as f: schema = dlt.Schema.from_dict(yaml.safe_load(f)) - with open("tests/normalize/cases/ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2.json", "r", encoding="utf-8") as f: + with open( + "tests/normalize/cases/ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2.json", + "r", + encoding="utf-8", + ) as f: data = json.load(f) # take only the first block. the first block does not have uncles so this table should not be created and merged @@ -42,7 +50,10 @@ def test_merge_on_keys_in_schema(destination_config: DestinationTestConfiguratio # we load a single block assert eth_1_counts["blocks"] == 1 # check root key propagation - assert p.default_schema.tables["blocks__transactions"]["columns"]["_dlt_root_id"]["root_key"] is True + assert ( + p.default_schema.tables["blocks__transactions"]["columns"]["_dlt_root_id"]["root_key"] + is True + ) # now we load the whole dataset. 
blocks should be created which adds columns to blocks # if the table would be created before the whole load would fail because new columns have hints info = p.run(data, table_name="blocks", write_disposition="merge", schema=schema) @@ -59,11 +70,15 @@ def test_merge_on_keys_in_schema(destination_config: DestinationTestConfiguratio assert eth_2_counts == eth_3_counts -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_merge_on_ad_hoc_primary_key(destination_config: DestinationTestConfiguration) -> None: p = destination_config.setup_pipeline("github_1", full_refresh=True) - with open("tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8") as f: + with open( + "tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8" + ) as f: data = json.load(f) # note: NodeId will be normalized to "node_id" which exists in the schema info = p.run(data[:17], table_name="issues", write_disposition="merge", primary_key="NodeId") @@ -89,17 +104,27 @@ def test_merge_on_ad_hoc_primary_key(destination_config: DestinationTestConfigur @dlt.source(root_key=True) def github(): - - @dlt.resource(table_name="issues", write_disposition="merge", primary_key="id", merge_key=("node_id", "url")) + @dlt.resource( + table_name="issues", + write_disposition="merge", + primary_key="id", + merge_key=("node_id", "url"), + ) def load_issues(): - with open("tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8") as f: + with open( + "tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8" + ) as f: yield from json.load(f) return load_issues -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) -def test_merge_source_compound_keys_and_changes(destination_config: DestinationTestConfiguration) -> None: +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) +def test_merge_source_compound_keys_and_changes( + destination_config: DestinationTestConfiguration, +) -> None: p = destination_config.setup_pipeline("github_3", full_refresh=True) info = p.run(github()) @@ -108,9 +133,18 @@ def test_merge_source_compound_keys_and_changes(destination_config: DestinationT # 100 issues total assert github_1_counts["issues"] == 100 # check keys created - assert p.default_schema.tables["issues"]["columns"]["node_id"].items() > {"merge_key": True, "data_type": "text", "nullable": False}.items() - assert p.default_schema.tables["issues"]["columns"]["url"].items() > {"merge_key": True, "data_type": "text", "nullable": False}.items() - assert p.default_schema.tables["issues"]["columns"]["id"].items() > {"primary_key": True, "data_type": "bigint", "nullable": False}.items() + assert ( + p.default_schema.tables["issues"]["columns"]["node_id"].items() + > {"merge_key": True, "data_type": "text", "nullable": False}.items() + ) + assert ( + p.default_schema.tables["issues"]["columns"]["url"].items() + > {"merge_key": True, "data_type": "text", "nullable": False}.items() + ) + assert ( + p.default_schema.tables["issues"]["columns"]["id"].items() + > {"primary_key": True, "data_type": "bigint", "nullable": False}.items() + ) # append load_issues resource info = p.run(github().load_issues, write_disposition="append") @@ -118,10 +152,10 
@@ def test_merge_source_compound_keys_and_changes(destination_config: DestinationT assert p.default_schema.tables["issues"]["write_disposition"] == "append" # the counts of all tables must be double github_2_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) - assert {k:v*2 for k, v in github_1_counts.items()} == github_2_counts + assert {k: v * 2 for k, v in github_1_counts.items()} == github_2_counts # now replace all resources - info = p.run(github(), write_disposition="replace" ) + info = p.run(github(), write_disposition="replace") assert_load_info(info) assert p.default_schema.tables["issues"]["write_disposition"] == "replace" # assert p.default_schema.tables["issues__labels"]["write_disposition"] == "replace" @@ -130,7 +164,9 @@ def test_merge_source_compound_keys_and_changes(destination_config: DestinationT assert github_1_counts == github_3_counts -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_merge_no_child_tables(destination_config: DestinationTestConfiguration) -> None: p = destination_config.setup_pipeline("github_3", full_refresh=True) github_data = github() @@ -161,7 +197,9 @@ def test_merge_no_child_tables(destination_config: DestinationTestConfiguration) assert github_2_counts["issues"] == 100 if destination_config.supports_merge else 115 -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_merge_no_merge_keys(destination_config: DestinationTestConfiguration) -> None: p = destination_config.setup_pipeline("github_3", full_refresh=True) github_data = github() @@ -187,19 +225,24 @@ def test_merge_no_merge_keys(destination_config: DestinationTestConfiguration) - assert github_1_counts["issues"] == 10 if destination_config.supports_merge else 100 - 45 -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_merge_keys_non_existing_columns(destination_config: DestinationTestConfiguration) -> None: p = destination_config.setup_pipeline("github_3", full_refresh=True) github_data = github() # set keys names that do not exist in the data - github_data.load_issues.apply_hints(merge_key=("mA1", "Ma2"), primary_key=("123-x", )) + github_data.load_issues.apply_hints(merge_key=("mA1", "Ma2"), primary_key=("123-x",)) # skip first 45 rows github_data.load_issues.add_filter(skip_first(45)) info = p.run(github_data) assert_load_info(info) github_1_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) assert github_1_counts["issues"] == 100 - 45 - assert p.default_schema.tables["issues"]["columns"]["m_a1"].items() > {"merge_key": True, "nullable": False}.items() + assert ( + p.default_schema.tables["issues"]["columns"]["m_a1"].items() + > {"merge_key": True, "nullable": False}.items() + ) # for non merge destinations we just check that the run passes if not destination_config.supports_merge: @@ -207,7 +250,7 @@ def test_merge_keys_non_existing_columns(destination_config: DestinationTestConf # all the keys are invalid so the merge falls back to 
replace github_data = github() - github_data.load_issues.apply_hints(merge_key=("mA1", "Ma2"), primary_key=("123-x", )) + github_data.load_issues.apply_hints(merge_key=("mA1", "Ma2"), primary_key=("123-x",)) github_data.load_issues.add_filter(take_first(1)) info = p.run(github_data) assert_load_info(info) @@ -219,7 +262,11 @@ def test_merge_keys_non_existing_columns(destination_config: DestinationTestConf assert "m_a1" not in table_schema # unbound columns were not created -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True, subset=["duckdb", "snowflake", "bigquery"]), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["duckdb", "snowflake", "bigquery"]), + ids=lambda x: x.name, +) def test_pipeline_load_parquet(destination_config: DestinationTestConfiguration) -> None: p = destination_config.setup_pipeline("github_3", full_refresh=True) github_data = github() @@ -227,7 +274,9 @@ def test_pipeline_load_parquet(destination_config: DestinationTestConfiguration) github_data.max_table_nesting = 2 github_data_copy = github() github_data_copy.max_table_nesting = 2 - info = p.run([github_data, github_data_copy], loader_file_format="parquet", write_disposition="merge") + info = p.run( + [github_data, github_data_copy], loader_file_format="parquet", write_disposition="merge" + ) assert_load_info(info) # make sure it was parquet or sql transforms files = p.get_load_package_info(p.list_completed_load_packages()[0]).jobs["completed_jobs"] @@ -250,22 +299,34 @@ def test_pipeline_load_parquet(destination_config: DestinationTestConfiguration) assert github_1_counts["issues"] == 100 - -@dlt.transformer(name="github_repo_events", primary_key="id", write_disposition="merge", table_name=lambda i: i['type']) -def github_repo_events(page: List[StrAny], last_created_at = dlt.sources.incremental("created_at", "1970-01-01T00:00:00Z")): - """A transformer taking a stream of github events and dispatching them to tables named by event type. Deduplicates be 'id'. Loads incrementally by 'created_at' """ +@dlt.transformer( + name="github_repo_events", + primary_key="id", + write_disposition="merge", + table_name=lambda i: i["type"], +) +def github_repo_events( + page: List[StrAny], + last_created_at=dlt.sources.incremental("created_at", "1970-01-01T00:00:00Z"), +): + """A transformer taking a stream of github events and dispatching them to tables named by event type. Deduplicates be 'id'. Loads incrementally by 'created_at'""" yield page @dlt.transformer(name="github_repo_events", primary_key="id", write_disposition="merge") -def github_repo_events_table_meta(page: List[StrAny], last_created_at = dlt.sources.incremental("created_at", "1970-01-01T00:00:00Z")): - """A transformer taking a stream of github events and dispatching them to tables using table meta. Deduplicates be 'id'. Loads incrementally by 'created_at' """ - yield from [dlt.mark.with_table_name(p, p['type']) for p in page] +def github_repo_events_table_meta( + page: List[StrAny], + last_created_at=dlt.sources.incremental("created_at", "1970-01-01T00:00:00Z"), +): + """A transformer taking a stream of github events and dispatching them to tables using table meta. Deduplicates be 'id'. 
Loads incrementally by 'created_at'""" + yield from [dlt.mark.with_table_name(p, p["type"]) for p in page] @dlt.resource def _get_shuffled_events(shuffle: bool = dlt.secrets.value): - with open("tests/normalize/cases/github.events.load_page_1_duck.json", "r", encoding="utf-8") as f: + with open( + "tests/normalize/cases/github.events.load_page_1_duck.json", "r", encoding="utf-8" + ) as f: issues = json.load(f) # random order if shuffle: @@ -273,17 +334,22 @@ def _get_shuffled_events(shuffle: bool = dlt.secrets.value): yield issues - -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) -@pytest.mark.parametrize("github_resource",[github_repo_events, github_repo_events_table_meta]) -def test_merge_with_dispatch_and_incremental(destination_config: DestinationTestConfiguration, github_resource: DltResource) -> None: - newest_issues = list(sorted(_get_shuffled_events(True), key = lambda x: x["created_at"], reverse=True)) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) +@pytest.mark.parametrize("github_resource", [github_repo_events, github_repo_events_table_meta]) +def test_merge_with_dispatch_and_incremental( + destination_config: DestinationTestConfiguration, github_resource: DltResource +) -> None: + newest_issues = list( + sorted(_get_shuffled_events(True), key=lambda x: x["created_at"], reverse=True) + ) newest_issue = newest_issues[0] @dlt.resource def _new_event(node_id): new_i = copy(newest_issue) - new_i["id"] = str(random.randint(0, 2^32)) + new_i["id"] = str(random.randint(0, 2 ^ 32)) new_i["created_at"] = pendulum.now().isoformat() new_i["node_id"] = node_id # yield pages @@ -301,21 +367,33 @@ def _updated_event(node_id): with Container().injectable_context(StateInjectableContext(state={})): assert len(list(_get_shuffled_events(True) | github_resource)) == 100 incremental_state = github_resource.state - assert incremental_state["incremental"]["created_at"]["last_value"] == newest_issue["created_at"] - assert incremental_state["incremental"]["created_at"]["unique_hashes"] == [digest128(f'"{newest_issue["id"]}"')] + assert ( + incremental_state["incremental"]["created_at"]["last_value"] + == newest_issue["created_at"] + ) + assert incremental_state["incremental"]["created_at"]["unique_hashes"] == [ + digest128(f'"{newest_issue["id"]}"') + ] # subsequent load will skip all elements assert len(list(_get_shuffled_events(True) | github_resource)) == 0 # add one more issue assert len(list(_new_event("new_node") | github_resource)) == 1 - assert incremental_state["incremental"]["created_at"]["last_value"] > newest_issue["created_at"] - assert incremental_state["incremental"]["created_at"]["unique_hashes"] != [digest128(str(newest_issue["id"]))] + assert ( + incremental_state["incremental"]["created_at"]["last_value"] + > newest_issue["created_at"] + ) + assert incremental_state["incremental"]["created_at"]["unique_hashes"] != [ + digest128(str(newest_issue["id"])) + ] # load to destination p = destination_config.setup_pipeline("github_3", full_refresh=True) info = p.run(_get_shuffled_events(True) | github_resource) assert_load_info(info) # get top tables - counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables() if t.get("parent") is None]) + counts = load_table_counts( + p, *[t["name"] for t in p.default_schema.data_tables() if t.get("parent") is None] + ) # total number of events in all top tables == 100 assert 
sum(counts.values()) == 100 # this should skip all events due to incremental load @@ -326,10 +404,12 @@ def _updated_event(node_id): # load one more event with a new id info = p.run(_new_event("new_node") | github_resource) assert_load_info(info) - counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables() if t.get("parent") is None]) + counts = load_table_counts( + p, *[t["name"] for t in p.default_schema.data_tables() if t.get("parent") is None] + ) assert sum(counts.values()) == 101 # all the columns have primary keys and merge disposition derived from resource - for table in p.default_schema.data_tables(): + for table in p.default_schema.data_tables(): if table.get("parent") is None: assert table["write_disposition"] == "merge" assert table["columns"]["id"]["primary_key"] is True @@ -338,7 +418,9 @@ def _updated_event(node_id): info = p.run(_updated_event("new_node_X") | github_resource) assert_load_info(info) # still 101 - counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables() if t.get("parent") is None]) + counts = load_table_counts( + p, *[t["name"] for t in p.default_schema.data_tables() if t.get("parent") is None] + ) assert sum(counts.values()) == 101 if destination_config.supports_merge else 102 # for non merge destinations we just check that the run passes if not destination_config.supports_merge: @@ -349,13 +431,18 @@ def _updated_event(node_id): assert len(list(q.fetchall())) == 1 -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_deduplicate_single_load(destination_config: DestinationTestConfiguration) -> None: p = destination_config.setup_pipeline("abstract", full_refresh=True) @dlt.resource(write_disposition="merge", primary_key="id") def duplicates(): - yield [{"id": 1, "name": "row1", "child": [1, 2, 3]}, {"id": 1, "name": "row2", "child": [4, 5, 6]}] + yield [ + {"id": 1, "name": "row1", "child": [1, 2, 3]}, + {"id": 1, "name": "row2", "child": [4, 5, 6]}, + ] info = p.run(duplicates()) assert_load_info(info) @@ -364,7 +451,6 @@ def duplicates(): assert counts["duplicates__child"] == 3 if destination_config.supports_merge else 6 select_data(p, "SELECT * FROM duplicates")[0] - @dlt.resource(write_disposition="merge", primary_key=("id", "subkey")) def duplicates_no_child(): yield [{"id": 1, "subkey": "AX", "name": "row1"}, {"id": 1, "subkey": "AX", "name": "row2"}] @@ -375,13 +461,18 @@ def duplicates_no_child(): assert counts["duplicates_no_child"] == 1 if destination_config.supports_merge else 2 -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_no_deduplicate_only_merge_key(destination_config: DestinationTestConfiguration) -> None: p = destination_config.setup_pipeline("abstract", full_refresh=True) @dlt.resource(write_disposition="merge", merge_key="id") def duplicates(): - yield [{"id": 1, "name": "row1", "child": [1, 2, 3]}, {"id": 1, "name": "row2", "child": [4, 5, 6]}] + yield [ + {"id": 1, "name": "row1", "child": [1, 2, 3]}, + {"id": 1, "name": "row2", "child": [4, 5, 6]}, + ] info = p.run(duplicates()) assert_load_info(info) @@ -389,7 +480,6 @@ def duplicates(): assert counts["duplicates"] == 2 assert 
counts["duplicates__child"] == 6 - @dlt.resource(write_disposition="merge", merge_key=("id", "subkey")) def duplicates_no_child(): yield [{"id": 1, "subkey": "AX", "name": "row1"}, {"id": 1, "subkey": "AX", "name": "row2"}] diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 4e01ca0f82..b56c8bddf3 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -1,33 +1,54 @@ -from copy import deepcopy import gzip import os +from copy import deepcopy from typing import Any, Callable, Iterator, Tuple + import pytest +from tests.load.pipeline.utils import ( + DestinationTestConfiguration, + assert_query_data, + assert_table, + destinations_configs, + drop_active_pipeline_data, + load_table_counts, + select_data, +) +from tests.load.utils import ( + TABLE_ROW_ALL_DATA_TYPES, + TABLE_UPDATE_COLUMNS_SCHEMA, + assert_all_data_types_row, + delete_dataset, +) +from tests.pipeline.utils import assert_load_info +from tests.utils import TEST_STORAGE_ROOT, preserve_environ import dlt - from dlt.common import json, sleep from dlt.common.destination.reference import DestinationReference +from dlt.common.exceptions import DestinationHasFailedJobs +from dlt.common.schema.exceptions import CannotCoerceColumnException from dlt.common.schema.schema import Schema from dlt.common.schema.typing import VERSION_TABLE_NAME from dlt.common.typing import TDataItem from dlt.common.utils import uniq_id from dlt.extract.exceptions import ResourceNameMissing from dlt.extract.source import DltSource -from dlt.pipeline.exceptions import CannotRestorePipelineException, PipelineConfigMissing, PipelineStepFailed -from dlt.common.schema.exceptions import CannotCoerceColumnException -from dlt.common.exceptions import DestinationHasFailedJobs - -from tests.utils import TEST_STORAGE_ROOT, preserve_environ -from tests.pipeline.utils import assert_load_info -from tests.load.utils import TABLE_ROW_ALL_DATA_TYPES, TABLE_UPDATE_COLUMNS_SCHEMA, assert_all_data_types_row, delete_dataset -from tests.load.pipeline.utils import drop_active_pipeline_data, assert_query_data, assert_table, load_table_counts, select_data -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration - - -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True, all_buckets_filesystem_configs=True), ids=lambda x: x.name) -@pytest.mark.parametrize('use_single_dataset', [True, False]) -def test_default_pipeline_names(use_single_dataset: bool, destination_config: DestinationTestConfiguration) -> None: +from dlt.pipeline.exceptions import ( + CannotRestorePipelineException, + PipelineConfigMissing, + PipelineStepFailed, +) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, all_buckets_filesystem_configs=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("use_single_dataset", [True, False]) +def test_default_pipeline_names( + use_single_dataset: bool, destination_config: DestinationTestConfiguration +) -> None: destination_config.setup() p = dlt.pipeline() p.config.use_single_dataset = use_single_dataset @@ -66,8 +87,12 @@ def data_fun() -> Iterator[Any]: with p.managed_state(): p._set_destinations( DestinationReference.from_name(destination_config.destination), - DestinationReference.from_name(destination_config.staging) if destination_config.staging else None - ) + ( + DestinationReference.from_name(destination_config.staging) + if destination_config.staging + 
else None + ), + ) # does not reset the dataset name assert p.dataset_name in possible_dataset_names # never do that in production code @@ -91,13 +116,23 @@ def data_fun() -> Iterator[Any]: assert_table(p, "data_fun", data, schema_name="names", info=info) -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True, all_buckets_filesystem_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, all_buckets_filesystem_configs=True), + ids=lambda x: x.name, +) def test_default_schema_name(destination_config: DestinationTestConfiguration) -> None: destination_config.setup() dataset_name = "dataset_" + uniq_id() data = ["a", "b", "c"] - p = dlt.pipeline("test_default_schema_name", TEST_STORAGE_ROOT, destination=destination_config.destination, staging=destination_config.staging, dataset_name=dataset_name) + p = dlt.pipeline( + "test_default_schema_name", + TEST_STORAGE_ROOT, + destination=destination_config.destination, + staging=destination_config.staging, + dataset_name=dataset_name, + ) p.extract(data, table_name="test", schema=Schema("default")) p.normalize() info = p.load() @@ -110,9 +145,12 @@ def test_default_schema_name(destination_config: DestinationTestConfiguration) - assert_table(p, "test", data, info=info) -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True, all_buckets_filesystem_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, all_buckets_filesystem_configs=True), + ids=lambda x: x.name, +) def test_attach_pipeline(destination_config: DestinationTestConfiguration) -> None: - # load data and then restore the pipeline and see if data is still there data = ["a", "b", "c"] @@ -122,7 +160,12 @@ def _data(): yield d destination_config.setup() - info = dlt.run(_data(), destination=destination_config.destination, staging=destination_config.staging, dataset_name="specific" + uniq_id()) + info = dlt.run( + _data(), + destination=destination_config.destination, + staging=destination_config.staging, + dataset_name="specific" + uniq_id(), + ) with pytest.raises(CannotRestorePipelineException): dlt.attach("unknown") @@ -143,9 +186,12 @@ def _data(): assert_table(p, "data_table", data, info=info) -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) -def test_skip_sync_schema_for_tables_without_columns(destination_config: DestinationTestConfiguration) -> None: - +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) +def test_skip_sync_schema_for_tables_without_columns( + destination_config: DestinationTestConfiguration, +) -> None: # load data and then restore the pipeline and see if data is still there data = ["a", "b", "c"] @@ -172,7 +218,11 @@ def _data(): assert not exists -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True, all_buckets_filesystem_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, all_buckets_filesystem_configs=True), + ids=lambda x: x.name, +) def test_run_full_refresh(destination_config: DestinationTestConfiguration) -> None: data = ["a", ["a", "b", "c"], ["a", "b", "c"]] destination_config.setup() @@ -185,7 +235,12 @@ def _data(): return dlt.resource(d(), name="lists", 
write_disposition="replace") p = dlt.pipeline(full_refresh=True) - info = p.run(_data(), destination=destination_config.destination, staging=destination_config.staging, dataset_name="iteration" + uniq_id()) + info = p.run( + _data(), + destination=destination_config.destination, + staging=destination_config.staging, + dataset_name="iteration" + uniq_id(), + ) assert info.dataset_name == p.dataset_name assert info.dataset_name.endswith(p._pipeline_instance_id) # print(p.default_schema.to_pretty_yaml()) @@ -201,23 +256,18 @@ def _data(): assert_table(p, "lists__value", sorted(data[1] + data[2])) - -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_evolve_schema(destination_config: DestinationTestConfiguration) -> None: dataset_name = "d" + uniq_id() row = { "id": "level0", - "f": [{ - "id": "level1", - "l": ["a", "b", "c"], - "v": 120, - "o": [{"a": 1}, {"a": 2}] - }] + "f": [{"id": "level1", "l": ["a", "b", "c"], "v": 120, "o": [{"a": 1}, {"a": 2}]}], } @dlt.source(name="parallel") def source(top_elements: int): - @dlt.defer def get_item(no: int) -> TDataItem: # the test will not last 10 seconds but 2 (there are 5 working threads by default) @@ -226,23 +276,38 @@ def get_item(no: int) -> TDataItem: data["id"] = "level" + str(no) return data - @dlt.resource(columns={"id": {"name": "id", "nullable": False, "data_type": "text", "unique": True, "sort": True}}) + @dlt.resource( + columns={ + "id": { + "name": "id", + "nullable": False, + "data_type": "text", + "unique": True, + "sort": True, + } + } + ) def simple_rows(): for no in range(top_elements): # yield deferred items resolved in threads yield get_item(no) - @dlt.resource(table_name="simple_rows", columns={"new_column": {"nullable": True, "data_type": "decimal"}}) + @dlt.resource( + table_name="simple_rows", + columns={"new_column": {"nullable": True, "data_type": "decimal"}}, + ) def extended_rows(): for no in range(top_elements): # yield deferred items resolved in threads - yield get_item(no+100) + yield get_item(no + 100) return simple_rows(), extended_rows(), dlt.resource(["a", "b", "c"], name="simple") import_schema_path = os.path.join(TEST_STORAGE_ROOT, "schemas", "import") export_schema_path = os.path.join(TEST_STORAGE_ROOT, "schemas", "export") - p = destination_config.setup_pipeline("my_pipeline", import_schema_path=import_schema_path, export_schema_path=export_schema_path) + p = destination_config.setup_pipeline( + "my_pipeline", import_schema_path=import_schema_path, export_schema_path=export_schema_path + ) p.extract(source(10).with_resources("simple_rows")) # print(p.default_schema.to_pretty_yaml()) @@ -283,18 +348,30 @@ def extended_rows(): # TODO: test export and import schema # test data - id_data = sorted(["level" + str(n) for n in range(10)] + ["level" + str(n) for n in range(100, 110)]) + id_data = sorted( + ["level" + str(n) for n in range(10)] + ["level" + str(n) for n in range(100, 110)] + ) assert_query_data(p, "SELECT * FROM simple_rows ORDER BY id", id_data) - assert_query_data(p, "SELECT schema_version_hash FROM _dlt_loads ORDER BY inserted_at", version_history) - - -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True, all_buckets_filesystem_configs=True), ids=lambda x: x.name) -@pytest.mark.parametrize('disable_compression', [True, False]) -def 
test_pipeline_data_writer_compression(disable_compression: bool, destination_config: DestinationTestConfiguration) -> None: + assert_query_data( + p, "SELECT schema_version_hash FROM _dlt_loads ORDER BY inserted_at", version_history + ) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, all_buckets_filesystem_configs=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("disable_compression", [True, False]) +def test_pipeline_data_writer_compression( + disable_compression: bool, destination_config: DestinationTestConfiguration +) -> None: # Ensure pipeline works without compression data = ["a", "b", "c"] - dataset_name = "compression_data_"+ uniq_id() - dlt.config["data_writer"] = {"disable_compression": disable_compression} # not sure how else to set this + dataset_name = "compression_data_" + uniq_id() + dlt.config["data_writer"] = { + "disable_compression": disable_compression + } # not sure how else to set this p = destination_config.setup_pipeline("compression_test", dataset_name=dataset_name) p.extract(dlt.resource(data, name="data")) s = p._get_normalize_storage() @@ -308,27 +385,24 @@ def test_pipeline_data_writer_compression(disable_compression: bool, destination assert_table(info.pipeline, "data", data, info=info) -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_source_max_nesting(destination_config: DestinationTestConfiguration) -> None: destination_config.setup() - complex_part = { - "l": [1, 2, 3], - "c": { - "a": 1, - "b": 12.3 - } - } + complex_part = {"l": [1, 2, 3], "c": {"a": 1, "b": 12.3}} @dlt.source(name="complex", max_table_nesting=0) def complex_data(): - return dlt.resource([ - { - "idx": 1, - "cn": complex_part - } - ], name="complex_cn") - info = dlt.run(complex_data(), destination=destination_config.destination, staging=destination_config.staging, dataset_name="ds_" + uniq_id()) + return dlt.resource([{"idx": 1, "cn": complex_part}], name="complex_cn") + + info = dlt.run( + complex_data(), + destination=destination_config.destination, + staging=destination_config.staging, + dataset_name="ds_" + uniq_id(), + ) print(info) rows = select_data(dlt.pipeline(), "SELECT cn FROM complex_cn") assert len(rows) == 1 @@ -338,7 +412,9 @@ def complex_data(): assert cn_val == complex_part -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_dataset_name_change(destination_config: DestinationTestConfiguration) -> None: destination_config.setup() # standard name @@ -378,11 +454,18 @@ def test_dataset_name_change(destination_config: DestinationTestConfiguration) - # do not remove - it allows us to filter tests by destination -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True, subset=["postgres"]), ids=lambda x: x.name) -def test_pipeline_explicit_destination_credentials(destination_config: DestinationTestConfiguration) -> None: - +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["postgres"]), + ids=lambda x: x.name, +) +def test_pipeline_explicit_destination_credentials( + destination_config: DestinationTestConfiguration, +) -> None: 
# explicit credentials resolved - p = dlt.pipeline(destination="postgres", credentials="postgresql://loader:loader@localhost:5432/dlt_data") + p = dlt.pipeline( + destination="postgres", credentials="postgresql://loader:loader@localhost:5432/dlt_data" + ) c = p._get_destination_clients(Schema("s"), p._get_destination_client_initial_config())[0] assert c.config.credentials.host == "localhost" @@ -391,7 +474,9 @@ def test_pipeline_explicit_destination_credentials(destination_config: Destinati os.environ.pop("DESTINATION__POSTGRES__CREDENTIALS", None) # explicit credentials resolved ignoring the config providers os.environ["DESTINATION__POSTGRES__CREDENTIALS__HOST"] = "HOST" - p = dlt.pipeline(destination="postgres", credentials="postgresql://loader:loader@localhost:5432/dlt_data") + p = dlt.pipeline( + destination="postgres", credentials="postgresql://loader:loader@localhost:5432/dlt_data" + ) c = p._get_destination_clients(Schema("s"), p._get_destination_client_initial_config())[0] assert c.config.credentials.host == "localhost" @@ -413,14 +498,18 @@ def test_pipeline_explicit_destination_credentials(destination_config: Destinati # do not remove - it allows us to filter tests by destination -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True, subset=["postgres"]), ids=lambda x: x.name) -def test_pipeline_with_sources_sharing_schema(destination_config: DestinationTestConfiguration) -> None: - +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["postgres"]), + ids=lambda x: x.name, +) +def test_pipeline_with_sources_sharing_schema( + destination_config: DestinationTestConfiguration, +) -> None: schema = Schema("shared") @dlt.source(schema=schema, max_table_nesting=1) def source_1(): - @dlt.resource(primary_key="user_id") def gen1(): dlt.current.source_state()["source_1"] = True @@ -435,7 +524,6 @@ def conflict(): @dlt.source(schema=schema, max_table_nesting=2) def source_2(): - @dlt.resource(primary_key="id") def gen1(): dlt.current.source_state()["source_2"] = True @@ -478,15 +566,15 @@ def conflict(): p.load() table_names = [t["name"] for t in default_schema.data_tables()] counts = load_table_counts(p, *table_names) - assert counts == {'gen1': 2, 'gen2': 3, 'conflict': 1} + assert counts == {"gen1": 2, "gen2": 3, "conflict": 1} # both sources share the same state assert p.state["sources"] == { - 'shared': { - 'source_1': True, - 'resources': {'gen1': {'source_1': True, 'source_2': True}}, - 'source_2': True - } + "shared": { + "source_1": True, + "resources": {"gen1": {"source_1": True, "source_2": True}}, + "source_2": True, } + } drop_active_pipeline_data() # same pipeline but enable conflict @@ -497,13 +585,16 @@ def conflict(): # do not remove - it allows us to filter tests by destination -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True, subset=["postgres"]), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["postgres"]), + ids=lambda x: x.name, +) def test_many_pipelines_single_dataset(destination_config: DestinationTestConfiguration) -> None: schema = Schema("shared") @dlt.source(schema=schema, max_table_nesting=1) def source_1(): - @dlt.resource(primary_key="user_id") def gen1(): dlt.current.source_state()["source_1"] = True @@ -514,7 +605,6 @@ def gen1(): @dlt.source(schema=schema, max_table_nesting=2) def source_2(): - @dlt.resource(primary_key="id") def gen1(): 
dlt.current.source_state()["source_2"] = True @@ -527,44 +617,68 @@ def gen2(): return gen2, gen1 # load source_1 to common dataset - p = dlt.pipeline(pipeline_name="source_1_pipeline", destination="duckdb", dataset_name="shared_dataset") + p = dlt.pipeline( + pipeline_name="source_1_pipeline", destination="duckdb", dataset_name="shared_dataset" + ) p.run(source_1(), credentials="duckdb:///_storage/test_quack.duckdb") counts = load_table_counts(p, *p.default_schema.tables.keys()) - assert counts.items() >= {'gen1': 1, '_dlt_pipeline_state': 1, "_dlt_loads": 1}.items() + assert counts.items() >= {"gen1": 1, "_dlt_pipeline_state": 1, "_dlt_loads": 1}.items() p._wipe_working_folder() p.deactivate() - p = dlt.pipeline(pipeline_name="source_2_pipeline", destination="duckdb", dataset_name="shared_dataset") + p = dlt.pipeline( + pipeline_name="source_2_pipeline", destination="duckdb", dataset_name="shared_dataset" + ) p.run(source_2(), credentials="duckdb:///_storage/test_quack.duckdb") # table_names = [t["name"] for t in p.default_schema.data_tables()] counts = load_table_counts(p, *p.default_schema.tables.keys()) # gen1: one record comes from source_1, 1 record from source_2 - assert counts.items() >= {'gen1': 2, '_dlt_pipeline_state': 2, "_dlt_loads": 2}.items() + assert counts.items() >= {"gen1": 2, "_dlt_pipeline_state": 2, "_dlt_loads": 2}.items() # assert counts == {'gen1': 2, 'gen2': 3} p._wipe_working_folder() p.deactivate() # restore from destination, check state - p = dlt.pipeline(pipeline_name="source_1_pipeline", destination="duckdb", dataset_name="shared_dataset", credentials="duckdb:///_storage/test_quack.duckdb") + p = dlt.pipeline( + pipeline_name="source_1_pipeline", + destination="duckdb", + dataset_name="shared_dataset", + credentials="duckdb:///_storage/test_quack.duckdb", + ) p.sync_destination() # we have our separate state - assert p.state["sources"]["shared"] == {'source_1': True, 'resources': {'gen1': {'source_1': True}}} + assert p.state["sources"]["shared"] == { + "source_1": True, + "resources": {"gen1": {"source_1": True}}, + } # but the schema was common so we have the earliest one assert "gen2" in p.default_schema.tables p._wipe_working_folder() p.deactivate() - p = dlt.pipeline(pipeline_name="source_2_pipeline", destination="duckdb", dataset_name="shared_dataset", credentials="duckdb:///_storage/test_quack.duckdb") + p = dlt.pipeline( + pipeline_name="source_2_pipeline", + destination="duckdb", + dataset_name="shared_dataset", + credentials="duckdb:///_storage/test_quack.duckdb", + ) p.sync_destination() # we have our separate state - assert p.state["sources"]["shared"] == {'source_2': True, 'resources': {'gen1': {'source_2': True}}} + assert p.state["sources"]["shared"] == { + "source_2": True, + "resources": {"gen1": {"source_2": True}}, + } # do not remove - it allows us to filter tests by destination -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True, subset=["snowflake"]), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["snowflake"]), + ids=lambda x: x.name, +) def test_snowflake_custom_stage(destination_config: DestinationTestConfiguration) -> None: """Using custom stage name instead of the table stage""" - os.environ['DESTINATION__SNOWFLAKE__STAGE_NAME'] = 'my_non_existing_stage' + os.environ["DESTINATION__SNOWFLAKE__STAGE_NAME"] = "my_non_existing_stage" pipeline, data = simple_nested_pipeline(destination_config, 
f"custom_stage_{uniq_id()}", False) info = pipeline.run(data()) with pytest.raises(DestinationHasFailedJobs) as f_jobs: @@ -576,8 +690,8 @@ def test_snowflake_custom_stage(destination_config: DestinationTestConfiguration # NOTE: this stage must be created in DLT_DATA database for this test to pass! # CREATE STAGE MY_CUSTOM_LOCAL_STAGE; # GRANT READ, WRITE ON STAGE DLT_DATA.PUBLIC.MY_CUSTOM_LOCAL_STAGE TO ROLE DLT_LOADER_ROLE; - stage_name = 'PUBLIC.MY_CUSTOM_LOCAL_STAGE' - os.environ['DESTINATION__SNOWFLAKE__STAGE_NAME'] = stage_name + stage_name = "PUBLIC.MY_CUSTOM_LOCAL_STAGE" + os.environ["DESTINATION__SNOWFLAKE__STAGE_NAME"] = stage_name pipeline, data = simple_nested_pipeline(destination_config, f"custom_stage_{uniq_id()}", False) info = pipeline.run(data()) assert_load_info(info) @@ -590,16 +704,22 @@ def test_snowflake_custom_stage(destination_config: DestinationTestConfiguration assert len(staged_files) == 3 # check data of one table to ensure copy was done successfully tbl_name = client.make_qualified_table_name("lists") - assert_query_data(pipeline, f"SELECT value FROM {tbl_name}", ['a', None, None]) + assert_query_data(pipeline, f"SELECT value FROM {tbl_name}", ["a", None, None]) # do not remove - it allows us to filter tests by destination -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True, subset=["snowflake"]), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["snowflake"]), + ids=lambda x: x.name, +) def test_snowflake_delete_file_after_copy(destination_config: DestinationTestConfiguration) -> None: """Using keep_staged_files = false option to remove staged files after copy""" - os.environ['DESTINATION__SNOWFLAKE__KEEP_STAGED_FILES'] = 'FALSE' + os.environ["DESTINATION__SNOWFLAKE__KEEP_STAGED_FILES"] = "FALSE" - pipeline, data = simple_nested_pipeline(destination_config, f"delete_staged_files_{uniq_id()}", False) + pipeline, data = simple_nested_pipeline( + destination_config, f"delete_staged_files_{uniq_id()}", False + ) info = pipeline.run(data()) assert_load_info(info) @@ -608,26 +728,32 @@ def test_snowflake_delete_file_after_copy(destination_config: DestinationTestCon with pipeline.sql_client() as client: # no files are left in table stage - stage_name = client.make_qualified_table_name('%lists') + stage_name = client.make_qualified_table_name("%lists") staged_files = client.execute_sql(f'LIST @{stage_name}/"{load_id}"') assert len(staged_files) == 0 # ensure copy was done tbl_name = client.make_qualified_table_name("lists") - assert_query_data(pipeline, f"SELECT value FROM {tbl_name}", ['a', None, None]) + assert_query_data(pipeline, f"SELECT value FROM {tbl_name}", ["a", None, None]) # do not remove - it allows us to filter tests by destination -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True, subset=["bigquery", "snowflake", "duckdb"]), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["bigquery", "snowflake", "duckdb"]), + ids=lambda x: x.name, +) def test_parquet_loading(destination_config: DestinationTestConfiguration) -> None: """Run pipeline twice with merge write disposition Resource with primary key falls back to append. Resource without keys falls back to replace. 
""" - pipeline = destination_config.setup_pipeline('parquet_test_' + uniq_id(), dataset_name='parquet_test_' + uniq_id()) + pipeline = destination_config.setup_pipeline( + "parquet_test_" + uniq_id(), dataset_name="parquet_test_" + uniq_id() + ) - @dlt.resource(primary_key='id') + @dlt.resource(primary_key="id") def some_data(): # type: ignore[no-untyped-def] - yield [{'id': 1}, {'id': 2}, {'id': 3}] + yield [{"id": 1}, {"id": 2}, {"id": 3}] @dlt.resource(write_disposition="replace") def other_data(): # type: ignore[no-untyped-def] @@ -644,7 +770,7 @@ def other_data(): # type: ignore[no-untyped-def] @dlt.resource(table_name="data_types", write_disposition="merge", columns=column_schemas) def my_resource(): nonlocal data_types - yield [data_types]*10 + yield [data_types] * 10 @dlt.source(max_table_nesting=0) def some_source(): # type: ignore[no-untyped-def] @@ -659,16 +785,27 @@ def some_source(): # type: ignore[no-untyped-def] client = pipeline._destination_client() # type: ignore[assignment] with client.sql_client as sql_client: - assert [row[0] for row in sql_client.execute_sql("SELECT * FROM other_data")] == [1, 2, 3, 4, 5] + assert [row[0] for row in sql_client.execute_sql("SELECT * FROM other_data")] == [ + 1, + 2, + 3, + 4, + 5, + ] assert [row[0] for row in sql_client.execute_sql("SELECT * FROM some_data")] == [1, 2, 3] db_rows = sql_client.execute_sql("SELECT * FROM data_types") assert len(db_rows) == 10 db_row = list(db_rows[0]) # "snowflake" and "bigquery" do not parse JSON form parquet string so double parse - assert_all_data_types_row(db_row[:-2], parse_complex_strings=destination_config.destination in ["snowflake", "bigquery"]) + assert_all_data_types_row( + db_row[:-2], + parse_complex_strings=destination_config.destination in ["snowflake", "bigquery"], + ) -def simple_nested_pipeline(destination_config: DestinationTestConfiguration, dataset_name: str, full_refresh: bool) -> Tuple[dlt.Pipeline, Callable[[], DltSource]]: +def simple_nested_pipeline( + destination_config: DestinationTestConfiguration, dataset_name: str, full_refresh: bool +) -> Tuple[dlt.Pipeline, Callable[[], DltSource]]: data = ["a", ["a", "b", "c"], ["a", "b", "c"]] def d(): @@ -678,6 +815,11 @@ def d(): def _data(): return dlt.resource(d(), name="lists", write_disposition="append") - p = dlt.pipeline(pipeline_name=f"pipeline_{dataset_name}", full_refresh=full_refresh, destination=destination_config.destination, staging=destination_config.staging, dataset_name=dataset_name) + p = dlt.pipeline( + pipeline_name=f"pipeline_{dataset_name}", + full_refresh=full_refresh, + destination=destination_config.destination, + staging=destination_config.staging, + dataset_name=dataset_name, + ) return p, _data - diff --git a/tests/load/pipeline/test_replace_disposition.py b/tests/load/pipeline/test_replace_disposition.py index 3176e1ce95..6eaf3352de 100644 --- a/tests/load/pipeline/test_replace_disposition.py +++ b/tests/load/pipeline/test_replace_disposition.py @@ -1,40 +1,68 @@ +import os from typing import Dict + +import pytest import yaml -import dlt, os, pytest +from tests.load.pipeline.utils import ( + DestinationTestConfiguration, + destinations_configs, + drop_active_pipeline_data, + load_table_counts, + load_tables_to_dicts, +) +from tests.pipeline.utils import assert_load_info + +import dlt from dlt.common.utils import uniq_id -from tests.pipeline.utils import assert_load_info -from tests.load.pipeline.utils import drop_active_pipeline_data, load_table_counts, load_tables_to_dicts -from tests.load.pipeline.utils 
import destinations_configs, DestinationTestConfiguration - REPLACE_STRATEGIES = ["truncate-and-insert", "insert-from-staging", "staging-optimized"] -@pytest.mark.parametrize("destination_config", destinations_configs(local_filesystem_configs=True, default_staging_configs=True, default_sql_configs=True), ids=lambda x: x.name) -@pytest.mark.parametrize("replace_strategy", REPLACE_STRATEGIES) -def test_replace_disposition(destination_config: DestinationTestConfiguration, replace_strategy: str) -> None: +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + local_filesystem_configs=True, default_staging_configs=True, default_sql_configs=True + ), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("replace_strategy", REPLACE_STRATEGIES) +def test_replace_disposition( + destination_config: DestinationTestConfiguration, replace_strategy: str +) -> None: if not destination_config.supports_merge and replace_strategy != "truncate-and-insert": - pytest.skip(f"Destination {destination_config.name} does not support merge and thus {replace_strategy}") + pytest.skip( + f"Destination {destination_config.name} does not support merge and thus" + f" {replace_strategy}" + ) # only allow 40 items per file - os.environ['DATA_WRITER__FILE_MAX_ITEMS'] = "40" + os.environ["DATA_WRITER__FILE_MAX_ITEMS"] = "40" # use staging tables for replace - os.environ['DESTINATION__REPLACE_STRATEGY'] = replace_strategy + os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy # make duckdb to reuse database in working folder os.environ["DESTINATION__DUCKDB__CREDENTIALS"] = "duckdb:///test_replace_disposition.duckdb" # TODO: start storing _dlt_loads with right json content increase_loads = lambda x: x if destination_config.destination == "filesystem" else x + 1 - increase_state_loads = lambda info: len([job for job in info.load_packages[0].jobs["completed_jobs"] if job.job_file_info.table_name == "_dlt_pipeline_state" and job.job_file_info.file_format != "reference"]) + increase_state_loads = lambda info: len( + [ + job + for job in info.load_packages[0].jobs["completed_jobs"] + if job.job_file_info.table_name == "_dlt_pipeline_state" + and job.job_file_info.file_format != "reference" + ] + ) # filesystem does not have versions and child tables def norm_table_counts(counts: Dict[str, int], *child_tables: str) -> Dict[str, int]: if destination_config.destination != "filesystem": return counts - return {**{"_dlt_version": 0}, **{t:0 for t in child_tables}, **counts} + return {**{"_dlt_version": 0}, **{t: 0 for t in child_tables}, **counts} dataset_name = "test_replace_strategies_ds" + uniq_id() - pipeline = destination_config.setup_pipeline("test_replace_strategies", dataset_name=dataset_name) + pipeline = destination_config.setup_pipeline( + "test_replace_strategies", dataset_name=dataset_name + ) global offset offset = 1000 @@ -46,35 +74,39 @@ def load_items(): # 6 jobs for the sub_items # 3 jobs for the sub_sub_items global offset - for _, index in enumerate(range(offset, offset+120), 1): + for _, index in enumerate(range(offset, offset + 120), 1): yield { "id": index, "name": f"item {index}", - "sub_items": [{ - "id": index + 1000, - "name": f"sub item {index + 1000}" - },{ - "id": index + 2000, - "name": f"sub item {index + 2000}", - "sub_sub_items": [{ - "id": index + 3000, - "name": f"sub item {index + 3000}", - }] - }] - } + "sub_items": [ + {"id": index + 1000, "name": f"sub item {index + 1000}"}, + { + "id": index + 2000, + "name": f"sub item {index + 2000}", + "sub_sub_items": [ + { + 
"id": index + 3000, + "name": f"sub item {index + 3000}", + } + ], + }, + ], + } # append resource to see if we do not drop any tables @dlt.resource(write_disposition="append") def append_items(): global offset - for _, index in enumerate(range(offset, offset+12), 1): + for _, index in enumerate(range(offset, offset + 12), 1): yield { "id": index, "name": f"item {index}", } # first run with offset 0 - info = pipeline.run([load_items, append_items], loader_file_format=destination_config.file_format) + info = pipeline.run( + [load_items, append_items], loader_file_format=destination_config.file_format + ) assert_load_info(info) # count state records that got extracted state_records = increase_state_loads(info) @@ -83,7 +115,9 @@ def append_items(): # second run with higher offset so we can check the results offset = 1000 - info = pipeline.run([load_items, append_items], loader_file_format=destination_config.file_format) + info = pipeline.run( + [load_items, append_items], loader_file_format=destination_config.file_format + ) assert_load_info(info) state_records += increase_state_loads(info) dlt_loads = increase_loads(dlt_loads) @@ -97,43 +131,59 @@ def append_items(): "items__sub_items__sub_sub_items": 120, "_dlt_pipeline_state": state_records, "_dlt_loads": dlt_loads, - "_dlt_version": dlt_versions + "_dlt_version": dlt_versions, } # check we really have the replaced data in our destination - table_dicts = load_tables_to_dicts(pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()]) - assert {x for i,x in enumerate(range(1000, 1120), 1)} == {int(x["id"]) for x in table_dicts["items"]} - assert {x for i,x in enumerate(range(2000, 2000+120), 1)}.union({x for i,x in enumerate(range(3000, 3000+120), 1)}) == {int(x["id"]) for x in table_dicts["items__sub_items"]} - assert {x for i,x in enumerate(range(4000, 4120), 1)} == {int(x["id"]) for x in table_dicts["items__sub_items__sub_sub_items"]} + table_dicts = load_tables_to_dicts( + pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] + ) + assert {x for i, x in enumerate(range(1000, 1120), 1)} == { + int(x["id"]) for x in table_dicts["items"] + } + assert {x for i, x in enumerate(range(2000, 2000 + 120), 1)}.union( + {x for i, x in enumerate(range(3000, 3000 + 120), 1)} + ) == {int(x["id"]) for x in table_dicts["items__sub_items"]} + assert {x for i, x in enumerate(range(4000, 4120), 1)} == { + int(x["id"]) for x in table_dicts["items__sub_items__sub_sub_items"] + } # we need to test that destination tables including child tables are cleared when we yield none from the resource @dlt.resource(name="items", write_disposition="replace", primary_key="id") def load_items_none(): yield - info = pipeline.run([load_items_none, append_items], loader_file_format=destination_config.file_format) + info = pipeline.run( + [load_items_none, append_items], loader_file_format=destination_config.file_format + ) assert_load_info(info) state_records += increase_state_loads(info) dlt_loads = increase_loads(dlt_loads) # table and child tables should be cleared table_counts = load_table_counts(pipeline, *pipeline.default_schema.tables.keys()) - assert norm_table_counts(table_counts, "items__sub_items", "items__sub_items__sub_sub_items") == { + assert norm_table_counts( + table_counts, "items__sub_items", "items__sub_items__sub_sub_items" + ) == { "append_items": 36, "items": 0, "items__sub_items": 0, "items__sub_items__sub_sub_items": 0, "_dlt_pipeline_state": state_records, "_dlt_loads": dlt_loads, - "_dlt_version": dlt_versions + 
"_dlt_version": dlt_versions, } # drop and deactivate existing pipeline # drop_active_pipeline_data() # create a pipeline with different name but loading to the same dataset as above - this is to provoke truncating non existing tables - pipeline_2 = destination_config.setup_pipeline("test_replace_strategies_2", dataset_name=dataset_name) - info = pipeline_2.run(load_items, table_name="items_copy", loader_file_format=destination_config.file_format) + pipeline_2 = destination_config.setup_pipeline( + "test_replace_strategies_2", dataset_name=dataset_name + ) + info = pipeline_2.run( + load_items, table_name="items_copy", loader_file_format=destination_config.file_format + ) assert_load_info(info) new_state_records = increase_state_loads(info) assert new_state_records == 1 @@ -155,55 +205,61 @@ def load_items_none(): "items_copy__sub_items__sub_sub_items": 120, "_dlt_pipeline_state": state_records + 1, "_dlt_loads": dlt_loads, - "_dlt_version": increase_loads(dlt_versions) + "_dlt_version": increase_loads(dlt_versions), } # old pipeline -> shares completed loads and versions table table_counts = load_table_counts(pipeline, *pipeline.default_schema.tables.keys()) - assert norm_table_counts(table_counts, "items__sub_items", "items__sub_items__sub_sub_items") == { + assert norm_table_counts( + table_counts, "items__sub_items", "items__sub_items__sub_sub_items" + ) == { "append_items": 48, "items": 0, "items__sub_items": 0, "items__sub_items__sub_sub_items": 0, "_dlt_pipeline_state": state_records + 1, "_dlt_loads": dlt_loads, # next load - "_dlt_version": increase_loads(dlt_versions) # new table name -> new schema + "_dlt_version": increase_loads(dlt_versions), # new table name -> new schema } -@pytest.mark.parametrize("destination_config", destinations_configs(local_filesystem_configs=True, default_staging_configs=True, default_sql_configs=True), ids=lambda x: x.name) + +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + local_filesystem_configs=True, default_staging_configs=True, default_sql_configs=True + ), + ids=lambda x: x.name, +) @pytest.mark.parametrize("replace_strategy", REPLACE_STRATEGIES) -def test_replace_table_clearing(destination_config: DestinationTestConfiguration,replace_strategy: str) -> None: +def test_replace_table_clearing( + destination_config: DestinationTestConfiguration, replace_strategy: str +) -> None: if not destination_config.supports_merge and replace_strategy != "truncate-and-insert": - pytest.skip(f"Destination {destination_config.name} does not support merge and thus {replace_strategy}") + pytest.skip( + f"Destination {destination_config.name} does not support merge and thus" + f" {replace_strategy}" + ) # use staging tables for replace - os.environ['DESTINATION__REPLACE_STRATEGY'] = replace_strategy + os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy - pipeline = destination_config.setup_pipeline("test_replace_table_clearing", dataset_name="test_replace_table_clearing", full_refresh=True) + pipeline = destination_config.setup_pipeline( + "test_replace_table_clearing", dataset_name="test_replace_table_clearing", full_refresh=True + ) @dlt.resource(name="main_resource", write_disposition="replace", primary_key="id") def items_with_subitems(): data = { "id": 1, "name": "item", - "sub_items": [{ - "id": 101, - "name": "sub item 101" - },{ - "id": 101, - "name": "sub item 102" - }] + "sub_items": [{"id": 101, "name": "sub item 101"}, {"id": 101, "name": "sub item 102"}], } yield dlt.mark.with_table_name(data, "items") yield 
dlt.mark.with_table_name(data, "other_items") @dlt.resource(name="main_resource", write_disposition="replace", primary_key="id") def items_without_subitems(): - data = [{ - "id": 1, - "name": "item", - "sub_items": [] - }] + data = [{"id": 1, "name": "item", "sub_items": []}] yield dlt.mark.with_table_name(data, "items") yield dlt.mark.with_table_name(data, "other_items") @@ -211,17 +267,16 @@ def items_without_subitems(): def items_with_subitems_yield_none(): yield None yield None - data = [{ - "id": 1, - "name": "item", - "sub_items": [{ - "id": 101, - "name": "sub item 101" - },{ - "id": 101, - "name": "sub item 102" - }] - }] + data = [ + { + "id": 1, + "name": "item", + "sub_items": [ + {"id": 101, "name": "sub item 101"}, + {"id": 101, "name": "sub item 102"}, + ], + } + ] yield dlt.mark.with_table_name(data, "items") yield dlt.mark.with_table_name(data, "other_items") yield None @@ -232,13 +287,7 @@ def static_items(): yield { "id": 1, "name": "item", - "sub_items": [{ - "id": 101, - "name": "sub item 101" - },{ - "id": 101, - "name": "sub item 102" - }] + "sub_items": [{"id": 101, "name": "sub item 101"}, {"id": 101, "name": "sub item 102"}], } @dlt.resource(name="main_resource", write_disposition="replace", primary_key="id") @@ -246,8 +295,12 @@ def yield_none(): yield # regular call - pipeline.run([items_with_subitems, static_items], loader_file_format=destination_config.file_format) - table_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()]) + pipeline.run( + [items_with_subitems, static_items], loader_file_format=destination_config.file_format + ) + table_counts = load_table_counts( + pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] + ) assert table_counts["items"] == 1 assert table_counts["items__sub_items"] == 2 assert table_counts["other_items"] == 1 @@ -257,7 +310,9 @@ def yield_none(): # see if child table gets cleared pipeline.run(items_without_subitems, loader_file_format=destination_config.file_format) - table_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()]) + table_counts = load_table_counts( + pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] + ) assert table_counts["items"] == 1 assert table_counts.get("items__sub_items", 0) == 0 assert table_counts["other_items"] == 1 @@ -268,7 +323,9 @@ def yield_none(): # see if yield none clears everything pipeline.run(items_with_subitems, loader_file_format=destination_config.file_format) pipeline.run(yield_none, loader_file_format=destination_config.file_format) - table_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()]) + table_counts = load_table_counts( + pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] + ) assert table_counts.get("items", 0) == 0 assert table_counts.get("items__sub_items", 0) == 0 assert table_counts.get("other_items", 0) == 0 @@ -278,7 +335,9 @@ def yield_none(): # see if yielding something next to other none entries still goes into db pipeline.run(items_with_subitems_yield_none, loader_file_format=destination_config.file_format) - table_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()]) + table_counts = load_table_counts( + pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] + ) assert table_counts["items"] == 1 assert table_counts["items__sub_items"] == 2 assert table_counts["other_items"] == 1 diff --git 
a/tests/load/pipeline/test_restore_state.py b/tests/load/pipeline/test_restore_state.py index a3c11f4048..72631bab10 100644 --- a/tests/load/pipeline/test_restore_state.py +++ b/tests/load/pipeline/test_restore_state.py @@ -2,24 +2,33 @@ import os import shutil from typing import Any, Dict + import pytest +from tests.cases import JSON_TYPED_DICT +from tests.common.configuration.utils import environment +from tests.common.utils import IMPORTED_VERSION_HASH_ETH_V6 +from tests.common.utils import yml_case_path as common_yml_case_path +from tests.load.pipeline.utils import ( + DestinationTestConfiguration, + assert_query_data, + destinations_configs, + drop_active_pipeline_data, +) +from tests.utils import TEST_STORAGE_ROOT import dlt from dlt.common import pendulum +from dlt.common.exceptions import DestinationUndefinedEntity from dlt.common.schema.schema import Schema, utils from dlt.common.schema.typing import LOADS_TABLE_NAME, VERSION_TABLE_NAME from dlt.common.utils import custom_environ, uniq_id -from dlt.common.exceptions import DestinationUndefinedEntity - from dlt.pipeline.pipeline import Pipeline -from dlt.pipeline.state_sync import STATE_TABLE_COLUMNS, STATE_TABLE_NAME, load_state_from_destination, state_resource - -from tests.utils import TEST_STORAGE_ROOT -from tests.cases import JSON_TYPED_DICT -from tests.common.utils import IMPORTED_VERSION_HASH_ETH_V6, yml_case_path as common_yml_case_path -from tests.common.configuration.utils import environment -from tests.load.pipeline.utils import assert_query_data, drop_active_pipeline_data -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration +from dlt.pipeline.state_sync import ( + STATE_TABLE_COLUMNS, + STATE_TABLE_NAME, + load_state_from_destination, + state_resource, +) @pytest.fixture(autouse=True) @@ -29,10 +38,15 @@ def duckdb_pipeline_location() -> None: del os.environ["DESTINATION__DUCKDB__CREDENTIALS"] -@pytest.mark.parametrize("destination_config", destinations_configs(default_staging_configs=True, default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_staging_configs=True, default_sql_configs=True), + ids=lambda x: x.name, +) def test_restore_state_utils(destination_config: DestinationTestConfiguration) -> None: - - p = destination_config.setup_pipeline(pipeline_name="pipe_" + uniq_id(), dataset_name="state_test_" + uniq_id()) + p = destination_config.setup_pipeline( + pipeline_name="pipe_" + uniq_id(), dataset_name="state_test_" + uniq_id() + ) schema = Schema("state") # inject schema into pipeline, don't do it in production @@ -53,11 +67,17 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - initial_state["_local"]["_last_extracted_at"] = pendulum.now() # add _dlt_id and _dlt_load_id resource = state_resource(initial_state) - resource.apply_hints(columns={ - "_dlt_id": utils.add_missing_hints({"name": "_dlt_id", "data_type": "text", "nullable": False}), - "_dlt_load_id": utils.add_missing_hints({"name": "_dlt_load_id", "data_type": "text", "nullable": False}), - **STATE_TABLE_COLUMNS - }) + resource.apply_hints( + columns={ + "_dlt_id": utils.add_missing_hints( + {"name": "_dlt_id", "data_type": "text", "nullable": False} + ), + "_dlt_load_id": utils.add_missing_hints( + {"name": "_dlt_load_id", "data_type": "text", "nullable": False} + ), + **STATE_TABLE_COLUMNS, + } + ) schema.update_schema(resource.table_schema()) # do not bump version here or in sync_schema, dlt won't recognize 
that schema changed and it won't update it in storage # so dlt in normalize stage infers _state_version table again but with different column order and the column order in schema is different @@ -134,21 +154,33 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - assert new_stored_state["_state_version"] + 1 == new_stored_state_2["_state_version"] -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) -def test_silently_skip_on_invalid_credentials(destination_config: DestinationTestConfiguration, environment: Any) -> None: +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) +def test_silently_skip_on_invalid_credentials( + destination_config: DestinationTestConfiguration, environment: Any +) -> None: environment["CREDENTIALS"] = "postgres://loader:password@localhost:5432/dlt_data" - environment["DESTINATION__BIGQUERY__CREDENTIALS"] = '{"project_id": "chat-analytics-","client_email": "loader@chat-analytics-317513","private_key": "-----BEGIN PRIVATE KEY-----\\nMIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQCNEN0bL39HmD"}' + environment["DESTINATION__BIGQUERY__CREDENTIALS"] = ( + '{"project_id": "chat-analytics-","client_email":' + ' "loader@chat-analytics-317513","private_key": "-----BEGIN PRIVATE' + ' KEY-----\\nMIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQCNEN0bL39HmD"}' + ) pipeline_name = "pipe_" + uniq_id() - dataset_name="state_test_" + uniq_id() + dataset_name = "state_test_" + uniq_id() # NOTE: we are not restoring the state in __init__ anymore but the test should stay: init should not fail on lack of credentials destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) -@pytest.mark.parametrize('use_single_dataset', [True, False]) -def test_get_schemas_from_destination(destination_config: DestinationTestConfiguration, use_single_dataset: bool) -> None: +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) +@pytest.mark.parametrize("use_single_dataset", [True, False]) +def test_get_schemas_from_destination( + destination_config: DestinationTestConfiguration, use_single_dataset: bool +) -> None: pipeline_name = "pipe_" + uniq_id() - dataset_name="state_test_" + uniq_id() + dataset_name = "state_test_" + uniq_id() def _make_dn_name(schema_name: str) -> str: if use_single_dataset: @@ -214,11 +246,13 @@ def _make_dn_name(schema_name: str) -> str: assert len(restored_schemas) == 3 -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_restore_state_pipeline(destination_config: DestinationTestConfiguration) -> None: os.environ["RESTORE_FROM_DESTINATION"] = "True" pipeline_name = "pipe_" + uniq_id() - dataset_name="state_test_" + uniq_id() + dataset_name = "state_test_" + uniq_id() p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) def some_data_gen(param: str) -> Any: @@ -280,7 +314,10 @@ def some_data(): assert p.default_schema_name == "default" assert set(p.schema_names) == set(["default", "two", "three", "four"]) assert p.state["sources"] == { - "default": {'state1': 
'state1', 'state2': 'state2'}, "two": {'state3': 'state3'}, "three": {'state4': 'state4'}, "four": {"state5": JSON_TYPED_DICT} + "default": {"state1": "state1", "state2": "state2"}, + "two": {"state3": "state3"}, + "three": {"state4": "state4"}, + "four": {"state5": JSON_TYPED_DICT}, } for schema in p.schemas.values(): assert "some_data" in schema.tables @@ -290,7 +327,9 @@ def some_data(): # full refresh will not restore pipeline even if requested p._wipe_working_folder() - p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name, full_refresh=True) + p = destination_config.setup_pipeline( + pipeline_name=pipeline_name, dataset_name=dataset_name, full_refresh=True + ) p.run() assert p.default_schema_name is None drop_active_pipeline_data() @@ -310,11 +349,15 @@ def some_data(): assert restored_state["_state_version"] == orig_state["_state_version"] # second run will not restore - p._inject_schema(Schema("second")) # this will modify state, run does not sync if states are identical + p._inject_schema( + Schema("second") + ) # this will modify state, run does not sync if states are identical assert p.state["_state_version"] > orig_state["_state_version"] # print(p.state) p.run() - assert set(p.schema_names) == set(["default", "two", "three", "second", "four"]) # we keep our local copy + assert set(p.schema_names) == set( + ["default", "two", "three", "second", "four"] + ) # we keep our local copy # clear internal flag and decrease state version so restore triggers state = p.state state["_state_version"] -= 1 @@ -324,10 +367,12 @@ def some_data(): assert set(p.schema_names) == set(["default", "two", "three", "four"]) -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_ignore_state_unfinished_load(destination_config: DestinationTestConfiguration) -> None: pipeline_name = "pipe_" + uniq_id() - dataset_name="state_test_" + uniq_id() + dataset_name = "state_test_" + uniq_id() p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) @dlt.resource @@ -340,24 +385,30 @@ def some_data(param: str) -> Any: state = load_state_from_destination(pipeline_name, job_client.sql_client) assert state is not None # delete load id - job_client.sql_client.execute_sql(f"DELETE FROM {LOADS_TABLE_NAME} WHERE load_id = %s", next(iter(info.loads_ids))) + job_client.sql_client.execute_sql( + f"DELETE FROM {LOADS_TABLE_NAME} WHERE load_id = %s", next(iter(info.loads_ids)) + ) # state without completed load id is not visible state = load_state_from_destination(pipeline_name, job_client.sql_client) assert state is None -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) -def test_restore_schemas_while_import_schemas_exist(destination_config: DestinationTestConfiguration) -> None: +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) +def test_restore_schemas_while_import_schemas_exist( + destination_config: DestinationTestConfiguration, +) -> None: # restored schema should attach itself to imported schema and it should not get overwritten import_schema_path = os.path.join(TEST_STORAGE_ROOT, "schemas", "import") export_schema_path = os.path.join(TEST_STORAGE_ROOT, "schemas", "export") pipeline_name = "pipe_" + uniq_id() - 
dataset_name="state_test_" + uniq_id() + dataset_name = "state_test_" + uniq_id() p = destination_config.setup_pipeline( pipeline_name=pipeline_name, dataset_name=dataset_name, import_schema_path=import_schema_path, - export_schema_path=export_schema_path + export_schema_path=export_schema_path, ) prepare_import_folder(p) # make sure schema got imported @@ -382,10 +433,14 @@ def test_restore_schemas_while_import_schemas_exist(destination_config: Destinat p = dlt.pipeline( pipeline_name=pipeline_name, import_schema_path=import_schema_path, - export_schema_path=export_schema_path + export_schema_path=export_schema_path, ) # use run to get changes - p.run(destination=destination_config.destination, staging=destination_config.staging, dataset_name=dataset_name) + p.run( + destination=destination_config.destination, + staging=destination_config.staging, + dataset_name=dataset_name, + ) schema = p.schemas["ethereum"] assert "labels" in schema.tables assert "annotations" in schema.tables @@ -406,10 +461,12 @@ def test_restore_change_dataset_and_destination(destination_name: str) -> None: pass -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_restore_state_parallel_changes(destination_config: DestinationTestConfiguration) -> None: pipeline_name = "pipe_" + uniq_id() - dataset_name="state_test_" + uniq_id() + dataset_name = "state_test_" + uniq_id() destination_config.setup() p = dlt.pipeline(pipeline_name=pipeline_name) @@ -421,15 +478,25 @@ def some_data(param: str) -> Any: # extract two resources that modify the state data1 = some_data("state1") data1._name = "state1_data" - p.run([data1, some_data("state2")], schema=Schema("default"), destination=destination_config.destination, staging=destination_config.staging, dataset_name=dataset_name) + p.run( + [data1, some_data("state2")], + schema=Schema("default"), + destination=destination_config.destination, + staging=destination_config.staging, + dataset_name=dataset_name, + ) orig_state = p.state # create a production pipeline in separate pipelines_dir production_p = dlt.pipeline(pipeline_name=pipeline_name, pipelines_dir=TEST_STORAGE_ROOT) - production_p.sync_destination(destination=destination_config.destination, staging=destination_config.staging, dataset_name=dataset_name) + production_p.sync_destination( + destination=destination_config.destination, + staging=destination_config.staging, + dataset_name=dataset_name, + ) assert production_p.default_schema_name == "default" prod_state = production_p.state - assert prod_state["sources"] == {"default": {'state1': 'state1', 'state2': 'state2'}} + assert prod_state["sources"] == {"default": {"state1": "state1", "state2": "state2"}} assert prod_state["_state_version"] == orig_state["_state_version"] # generate data on production that modifies the schema but not state data2 = some_data("state1") @@ -482,13 +549,21 @@ def some_data(param: str) -> Any: assert ra_production_p.state == prod_state # get all the states, notice version 4 twice (one from production, the other from local) - assert_query_data(p, f"SELECT version, _dlt_load_id FROM {STATE_TABLE_NAME} ORDER BY created_at", [2, 3, 4, 4, 5]) + assert_query_data( + p, + f"SELECT version, _dlt_load_id FROM {STATE_TABLE_NAME} ORDER BY created_at", + [2, 3, 4, 4, 5], + ) -@pytest.mark.parametrize("destination_config", 
destinations_configs(default_sql_configs=True), ids=lambda x: x.name) -def test_reset_pipeline_on_deleted_dataset(destination_config: DestinationTestConfiguration) -> None: +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) +def test_reset_pipeline_on_deleted_dataset( + destination_config: DestinationTestConfiguration, +) -> None: pipeline_name = "pipe_" + uniq_id() - dataset_name="state_test_" + uniq_id() + dataset_name = "state_test_" + uniq_id() destination_config.setup() p = dlt.pipeline(pipeline_name=pipeline_name) @@ -499,7 +574,13 @@ def some_data(param: str) -> Any: data4 = some_data("state4") data4.apply_hints(table_name="state1_data4") - p.run(data4, schema=Schema("sch1"), destination=destination_config.destination, staging=destination_config.staging, dataset_name=dataset_name) + p.run( + data4, + schema=Schema("sch1"), + destination=destination_config.destination, + staging=destination_config.staging, + dataset_name=dataset_name, + ) data5 = some_data("state4") data5.apply_hints(table_name="state1_data5") p.run(data5, schema=Schema("sch2")) @@ -522,7 +603,13 @@ def some_data(param: str) -> Any: p.config.restore_from_destination = False data4 = some_data("state4") data4.apply_hints(table_name="state1_data4") - p.run(data4, schema=Schema("sch1"), destination=destination_config.destination, staging=destination_config.staging, dataset_name=dataset_name) + p.run( + data4, + schema=Schema("sch1"), + destination=destination_config.destination, + staging=destination_config.staging, + dataset_name=dataset_name, + ) assert p.first_run is False assert p.state["_local"]["first_run"] is False # attach again to make the `run` method check the destination @@ -538,4 +625,7 @@ def some_data(param: str) -> Any: def prepare_import_folder(p: Pipeline) -> None: os.makedirs(p._schema_storage.config.import_schema_path, exist_ok=True) - shutil.copy(common_yml_case_path("schemas/eth/ethereum_schema_v5"), os.path.join(p._schema_storage.config.import_schema_path, "ethereum.schema.yaml")) + shutil.copy( + common_yml_case_path("schemas/eth/ethereum_schema_v5"), + os.path.join(p._schema_storage.config.import_schema_path, "ethereum.schema.yaml"), + ) diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py index f06c6be44a..5a41cf43cc 100644 --- a/tests/load/pipeline/test_stage_loading.py +++ b/tests/load/pipeline/test_stage_loading.py @@ -1,21 +1,33 @@ -import pytest -from typing import Dict, Any - -import dlt, os -from dlt.common import json, sleep +import os from copy import deepcopy -from dlt.common.utils import uniq_id +from typing import Any, Dict +import pytest from tests.load.pipeline.test_merge_disposition import github -from tests.load.pipeline.utils import load_table_counts -from tests.pipeline.utils import assert_load_info -from tests.load.utils import TABLE_ROW_ALL_DATA_TYPES, TABLE_UPDATE_COLUMNS_SCHEMA, assert_all_data_types_row -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration +from tests.load.pipeline.utils import ( + DestinationTestConfiguration, + destinations_configs, + load_table_counts, +) +from tests.load.utils import ( + TABLE_ROW_ALL_DATA_TYPES, + TABLE_UPDATE_COLUMNS_SCHEMA, + assert_all_data_types_row, +) +from tests.pipeline.utils import assert_load_info + +import dlt +from dlt.common import json, sleep +from dlt.common.utils import uniq_id -@dlt.resource(table_name="issues", write_disposition="merge", primary_key="id", 
merge_key=("node_id", "url")) +@dlt.resource( + table_name="issues", write_disposition="merge", primary_key="id", merge_key=("node_id", "url") +) def load_modified_issues(): - with open("tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8") as f: + with open( + "tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8" + ) as f: issues = json.load(f) # change 2 issues @@ -28,10 +40,13 @@ def load_modified_issues(): yield from issues -@pytest.mark.parametrize("destination_config", destinations_configs(all_staging_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(all_staging_configs=True), ids=lambda x: x.name +) def test_staging_load(destination_config: DestinationTestConfiguration) -> None: - - pipeline = destination_config.setup_pipeline(pipeline_name='test_stage_loading_5', dataset_name="test_staging_load" + uniq_id()) + pipeline = destination_config.setup_pipeline( + pipeline_name="test_stage_loading_5", dataset_name="test_staging_load" + uniq_id() + ) info = pipeline.run(github(), loader_file_format=destination_config.file_format) assert_load_info(info) @@ -42,12 +57,41 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: # we have 4 parquet and 4 reference jobs plus one merge job num_jobs = 4 + 4 + 1 if destination_config.supports_merge else 4 + 4 assert len(package_info.jobs["completed_jobs"]) == num_jobs - assert len([x for x in package_info.jobs["completed_jobs"] if x.job_file_info.file_format == "reference"]) == 4 - assert len([x for x in package_info.jobs["completed_jobs"] if x.job_file_info.file_format == destination_config.file_format]) == 4 + assert ( + len( + [ + x + for x in package_info.jobs["completed_jobs"] + if x.job_file_info.file_format == "reference" + ] + ) + == 4 + ) + assert ( + len( + [ + x + for x in package_info.jobs["completed_jobs"] + if x.job_file_info.file_format == destination_config.file_format + ] + ) + == 4 + ) if destination_config.supports_merge: - assert len([x for x in package_info.jobs["completed_jobs"] if x.job_file_info.file_format == "sql"]) == 1 + assert ( + len( + [ + x + for x in package_info.jobs["completed_jobs"] + if x.job_file_info.file_format == "sql" + ] + ) + == 1 + ) - initial_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()]) + initial_counts = load_table_counts( + pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] + ) assert initial_counts["issues"] == 100 # check item of first row in db @@ -60,37 +104,58 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: info = pipeline.run(load_modified_issues, loader_file_format=destination_config.file_format) assert_load_info(info) assert pipeline.default_schema.tables["issues"]["write_disposition"] == "merge" - merge_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()]) + merge_counts = load_table_counts( + pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] + ) assert merge_counts == initial_counts # check changes were merged in with pipeline._get_destination_clients(pipeline.default_schema)[0] as client: - rows = client.sql_client.execute_sql("SELECT number FROM issues WHERE id = 1232152492 LIMIT 1") + rows = client.sql_client.execute_sql( + "SELECT number FROM issues WHERE id = 1232152492 LIMIT 1" + ) assert rows[0][0] == 105 - rows = client.sql_client.execute_sql("SELECT number FROM issues 
WHERE id = 1142699354 LIMIT 1") + rows = client.sql_client.execute_sql( + "SELECT number FROM issues WHERE id = 1142699354 LIMIT 1" + ) assert rows[0][0] == 300 # test append - info = pipeline.run(github().load_issues, write_disposition="append", loader_file_format=destination_config.file_format) + info = pipeline.run( + github().load_issues, + write_disposition="append", + loader_file_format=destination_config.file_format, + ) assert_load_info(info) assert pipeline.default_schema.tables["issues"]["write_disposition"] == "append" # the counts of all tables must be doubled - append_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()]) - assert {k:v*2 for k, v in initial_counts.items()} == append_counts + append_counts = load_table_counts( + pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] + ) + assert {k: v * 2 for k, v in initial_counts.items()} == append_counts # test replace - info = pipeline.run(github().load_issues, write_disposition="replace", loader_file_format=destination_config.file_format) + info = pipeline.run( + github().load_issues, + write_disposition="replace", + loader_file_format=destination_config.file_format, + ) assert_load_info(info) assert pipeline.default_schema.tables["issues"]["write_disposition"] == "replace" # the counts of all tables must equal the initial counts - replace_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()]) + replace_counts = load_table_counts( + pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] + ) assert replace_counts == initial_counts -@pytest.mark.parametrize("destination_config", destinations_configs(all_staging_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(all_staging_configs=True), ids=lambda x: x.name +) def test_all_data_types(destination_config: DestinationTestConfiguration) -> None: - - pipeline = destination_config.setup_pipeline('test_stage_loading', dataset_name="test_all_data_types" + uniq_id()) + pipeline = destination_config.setup_pipeline( + "test_stage_loading", dataset_name="test_all_data_types" + uniq_id() + ) data_types = deepcopy(TABLE_ROW_ALL_DATA_TYPES) column_schemas = deepcopy(TABLE_UPDATE_COLUMNS_SCHEMA) @@ -110,7 +175,7 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non @dlt.resource(table_name="data_types", write_disposition="merge", columns=column_schemas) def my_resource(): nonlocal data_types - yield [data_types]*10 + yield [data_types] * 10 @dlt.source(max_table_nesting=0) def my_source(): @@ -124,12 +189,18 @@ def my_source(): assert len(db_rows) == 10 db_row = list(db_rows[0]) # parquet is not really good at inserting json, best we get are strings in JSON columns - parse_complex_strings = destination_config.file_format == "parquet" and destination_config.destination in ["redshift", "bigquery", "snowflake"] - allow_base64_binary = destination_config.file_format == "jsonl" and destination_config.destination in ["redshift"] + parse_complex_strings = ( + destination_config.file_format == "parquet" + and destination_config.destination in ["redshift", "bigquery", "snowflake"] + ) + allow_base64_binary = ( + destination_config.file_format == "jsonl" + and destination_config.destination in ["redshift"] + ) # content must equal assert_all_data_types_row( db_row[:-2], parse_complex_strings=parse_complex_strings, allow_base64_binary=allow_base64_binary, - 
timestamp_precision=sql_client.capabilities.timestamp_precision + timestamp_precision=sql_client.capabilities.timestamp_precision, ) diff --git a/tests/load/pipeline/utils.py b/tests/load/pipeline/utils.py index 9b89fc943f..80332d244f 100644 --- a/tests/load/pipeline/utils.py +++ b/tests/load/pipeline/utils.py @@ -1,23 +1,24 @@ -import posixpath, os -from typing import Any, Iterator, List, Sequence, TYPE_CHECKING, Optional, Tuple, Dict +import os +import posixpath +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Sequence, Tuple + import pytest +from tests.load.utils import DestinationTestConfiguration, destinations_configs import dlt -from dlt.pipeline.pipeline import Pipeline - from dlt.common import json from dlt.common.configuration.container import Container +from dlt.common.destination.reference import WithStagingDataset from dlt.common.pipeline import LoadInfo, PipelineContext +from dlt.common.schema.typing import LOADS_TABLE_NAME from dlt.common.typing import DictStrAny from dlt.pipeline.exceptions import SqlClientNotAvailable -from dlt.common.schema.typing import LOADS_TABLE_NAME -from dlt.common.destination.reference import WithStagingDataset - -from tests.load.utils import DestinationTestConfiguration, destinations_configs +from dlt.pipeline.pipeline import Pipeline if TYPE_CHECKING: from dlt.destinations.filesystem.filesystem import FilesystemClient + @pytest.fixture(autouse=True) def drop_pipeline() -> Iterator[None]: yield @@ -75,19 +76,39 @@ def _drop_dataset_sql(schema_name: str) -> None: def _is_filesystem(p: dlt.Pipeline) -> bool: if not p.destination: return False - return p.destination.__name__.rsplit('.', 1)[-1] == 'filesystem' + return p.destination.__name__.rsplit(".", 1)[-1] == "filesystem" -def assert_table(p: dlt.Pipeline, table_name: str, table_data: List[Any], schema_name: str = None, info: LoadInfo = None) -> None: +def assert_table( + p: dlt.Pipeline, + table_name: str, + table_data: List[Any], + schema_name: str = None, + info: LoadInfo = None, +) -> None: func = _assert_table_fs if _is_filesystem(p) else _assert_table_sql func(p, table_name, table_data, schema_name, info) -def _assert_table_sql(p: dlt.Pipeline, table_name: str, table_data: List[Any], schema_name: str = None, info: LoadInfo = None) -> None: - assert_query_data(p, f"SELECT * FROM {table_name} ORDER BY 1 NULLS FIRST", table_data, schema_name, info) - - -def _assert_table_fs(p: dlt.Pipeline, table_name: str, table_data: List[Any], schema_name: str = None, info: LoadInfo = None) -> None: +def _assert_table_sql( + p: dlt.Pipeline, + table_name: str, + table_data: List[Any], + schema_name: str = None, + info: LoadInfo = None, +) -> None: + assert_query_data( + p, f"SELECT * FROM {table_name} ORDER BY 1 NULLS FIRST", table_data, schema_name, info + ) + + +def _assert_table_fs( + p: dlt.Pipeline, + table_name: str, + table_data: List[Any], + schema_name: str = None, + info: LoadInfo = None, +) -> None: """Assert table is loaded to filesystem destination""" client: FilesystemClient = p._destination_client(schema_name) # type: ignore[assignment] # get table directory @@ -107,7 +128,9 @@ def select_data(p: dlt.Pipeline, sql: str, schema_name: str = None) -> List[Sequ return list(cur.fetchall()) -def assert_query_data(p: dlt.Pipeline, sql: str, table_data: List[Any], schema_name: str = None, info: LoadInfo = None) -> None: +def assert_query_data( + p: dlt.Pipeline, sql: str, table_data: List[Any], schema_name: str = None, info: LoadInfo = None +) -> None: """Asserts that query 
selecting single column of values matches `table_data`. If `info` is provided, second column must contain one of load_ids in `info`""" rows = select_data(p, sql, schema_name) assert len(rows) == len(table_data) @@ -162,6 +185,7 @@ def load_file(path: str, file: str) -> Tuple[str, List[Dict[str, Any]]]: # load parquet elif ext == "parquet": import pyarrow.parquet as pq + with open(full_path, "rb") as f: table = pq.read_table(f) cols = table.column_names @@ -184,7 +208,9 @@ def load_files(p: dlt.Pipeline, *table_names: str) -> Dict[str, List[Dict[str, A """For now this will expect the standard layout in the filesystem destination, if changed the results will not be correct""" client: FilesystemClient = p._destination_client() # type: ignore[assignment] result = {} - for basedir, _dirs, files in client.fs_client.walk(client.dataset_path, detail=False, refresh=True): + for basedir, _dirs, files in client.fs_client.walk( + client.dataset_path, detail=False, refresh=True + ): for file in files: table_name, items = load_file(basedir, file) if table_name not in table_names: @@ -206,7 +232,9 @@ def load_table_counts(p: dlt.Pipeline, *table_names: str) -> DictStrAny: # try sql, could be other destination though try: - query = "\nUNION ALL\n".join([f"SELECT '{name}' as name, COUNT(1) as c FROM {name}" for name in table_names]) + query = "\nUNION ALL\n".join( + [f"SELECT '{name}' as name, COUNT(1) as c FROM {name}" for name in table_names] + ) with p.sql_client() as c: with c.execute_query(query) as cur: rows = list(cur.fetchall()) @@ -223,7 +251,6 @@ def load_table_counts(p: dlt.Pipeline, *table_names: str) -> DictStrAny: def load_tables_to_dicts(p: dlt.Pipeline, *table_names: str) -> Dict[str, List[Dict[str, Any]]]: - # try sql, could be other destination though try: result = {} @@ -246,9 +273,17 @@ def load_tables_to_dicts(p: dlt.Pipeline, *table_names: str) -> Dict[str, List[D # try files return load_files(p, *table_names) -def load_table_distinct_counts(p: dlt.Pipeline, distinct_column: str, *table_names: str) -> DictStrAny: + +def load_table_distinct_counts( + p: dlt.Pipeline, distinct_column: str, *table_names: str +) -> DictStrAny: """Returns counts of distinct values for column `distinct_column` for `table_names` as dict""" - query = "\nUNION ALL\n".join([f"SELECT '{name}' as name, COUNT(DISTINCT {distinct_column}) as c FROM {name}" for name in table_names]) + query = "\nUNION ALL\n".join( + [ + f"SELECT '{name}' as name, COUNT(DISTINCT {distinct_column}) as c FROM {name}" + for name in table_names + ] + ) with p.sql_client() as c: with c.execute_query(query) as cur: rows = list(cur.fetchall()) diff --git a/tests/load/postgres/test_postgres_client.py b/tests/load/postgres/test_postgres_client.py index 269d2df301..95710ff6f3 100644 --- a/tests/load/postgres/test_postgres_client.py +++ b/tests/load/postgres/test_postgres_client.py @@ -1,20 +1,19 @@ import os from typing import Iterator + import pytest +from tests.common.configuration.utils import environment +from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage +from tests.utils import TEST_STORAGE_ROOT, delete_test_storage, preserve_environ, skipifpypy -from dlt.common import pendulum, Wei -from dlt.common.configuration.resolve import resolve_configuration, ConfigFieldMissingException +from dlt.common import Wei, pendulum +from dlt.common.configuration.resolve import ConfigFieldMissingException, resolve_configuration from dlt.common.storages import FileStorage from dlt.common.utils import uniq_id - from 
dlt.destinations.postgres.configuration import PostgresCredentials from dlt.destinations.postgres.postgres import PostgresClient from dlt.destinations.postgres.sql_client import psycopg2 -from tests.utils import TEST_STORAGE_ROOT, delete_test_storage, skipifpypy, preserve_environ -from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage -from tests.common.configuration.utils import environment - @pytest.fixture def file_storage() -> FileStorage: @@ -43,14 +42,20 @@ def test_postgres_credentials_defaults() -> None: def test_postgres_credentials_native_value(environment) -> None: with pytest.raises(ConfigFieldMissingException): - resolve_configuration(PostgresCredentials(), explicit_value="postgres://loader@localhost/dlt_data") + resolve_configuration( + PostgresCredentials(), explicit_value="postgres://loader@localhost/dlt_data" + ) # set password via env os.environ["CREDENTIALS__PASSWORD"] = "pass" - c = resolve_configuration(PostgresCredentials(), explicit_value="postgres://loader@localhost/dlt_data") + c = resolve_configuration( + PostgresCredentials(), explicit_value="postgres://loader@localhost/dlt_data" + ) assert c.is_resolved() assert c.password == "pass" # but if password is specified - it is final - c = resolve_configuration(PostgresCredentials(), explicit_value="postgres://loader:loader@localhost/dlt_data") + c = resolve_configuration( + PostgresCredentials(), explicit_value="postgres://loader:loader@localhost/dlt_data" + ) assert c.is_resolved() assert c.password == "loader" @@ -68,14 +73,32 @@ def test_wei_value(client: PostgresClient, file_storage: FileStorage) -> None: user_table_name = prepare_table(client) # postgres supports EVM precisions - insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp, parse_data__metadata__rasa_x_id)\nVALUES\n" - insert_values = f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd', '{str(pendulum.now())}', {Wei.from_int256(2*256-1)});" - expect_load_file(client, file_storage, insert_sql+insert_values, user_table_name) - - insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp, parse_data__metadata__rasa_x_id)\nVALUES\n" - insert_values = f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd', '{str(pendulum.now())}', {Wei.from_int256(2*256-1, 18)});" - expect_load_file(client, file_storage, insert_sql+insert_values, user_table_name) - - insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp, parse_data__metadata__rasa_x_id)\nVALUES\n" - insert_values = f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd', '{str(pendulum.now())}', {Wei.from_int256(2*256-1, 78)});" - expect_load_file(client, file_storage, insert_sql+insert_values, user_table_name) + insert_sql = ( + "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp," + " parse_data__metadata__rasa_x_id)\nVALUES\n" + ) + insert_values = ( + f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," + f" '{str(pendulum.now())}', {Wei.from_int256(2*256-1)});" + ) + expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) + + insert_sql = ( + "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp," + " parse_data__metadata__rasa_x_id)\nVALUES\n" + ) + insert_values = ( + f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," + f" '{str(pendulum.now())}', {Wei.from_int256(2*256-1, 18)});" + ) + expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) + + insert_sql = ( + "INSERT 
INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp," + " parse_data__metadata__rasa_x_id)\nVALUES\n" + ) + insert_values = ( + f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," + f" '{str(pendulum.now())}', {Wei.from_int256(2*256-1, 78)});" + ) + expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) diff --git a/tests/load/postgres/test_postgres_table_builder.py b/tests/load/postgres/test_postgres_table_builder.py index 50d7d3f245..aeb2231229 100644 --- a/tests/load/postgres/test_postgres_table_builder.py +++ b/tests/load/postgres/test_postgres_table_builder.py @@ -1,14 +1,14 @@ -import pytest from copy import deepcopy + +import pytest import sqlfluff +from tests.load.utils import TABLE_UPDATE -from dlt.common.utils import uniq_id from dlt.common.schema import Schema - -from dlt.destinations.postgres.postgres import PostgresClient +from dlt.common.utils import uniq_id from dlt.destinations.postgres.configuration import PostgresClientConfiguration, PostgresCredentials +from dlt.destinations.postgres.postgres import PostgresClient -from tests.load.utils import TABLE_UPDATE @pytest.fixture def schema() -> Schema: @@ -18,7 +18,12 @@ def schema() -> Schema: @pytest.fixture def client(schema: Schema) -> PostgresClient: # return client without opening connection - return PostgresClient(schema, PostgresClientConfiguration(dataset_name="test_" + uniq_id(), credentials=PostgresCredentials())) + return PostgresClient( + schema, + PostgresClientConfiguration( + dataset_name="test_" + uniq_id(), credentials=PostgresCredentials() + ), + ) def test_create_table(client: PostgresClient) -> None: @@ -75,7 +80,14 @@ def test_create_table_with_hints(client: PostgresClient) -> None: assert '"col4" timestamp with time zone NOT NULL' in sql # same thing without indexes - client = PostgresClient(client.schema, PostgresClientConfiguration(dataset_name="test_" + uniq_id(), create_indexes=False, credentials=PostgresCredentials())) + client = PostgresClient( + client.schema, + PostgresClientConfiguration( + dataset_name="test_" + uniq_id(), + create_indexes=False, + credentials=PostgresCredentials(), + ), + ) sql = client._get_table_update_sql("event_test_table", mod_update, False)[0] sqlfluff.parse(sql, dialect="postgres") assert '"col2" double precision NOT NULL' in sql diff --git a/tests/load/redshift/test_redshift_client.py b/tests/load/redshift/test_redshift_client.py index b9a2d9d6e7..f79c8e1ba3 100644 --- a/tests/load/redshift/test_redshift_client.py +++ b/tests/load/redshift/test_redshift_client.py @@ -1,25 +1,23 @@ import base64 import os from typing import Iterator -import pytest from unittest.mock import patch +import pytest +from tests.common.utils import COMMON_TEST_CASES_PATH +from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage +from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage, skipifpypy + from dlt.common import json, pendulum from dlt.common.configuration.resolve import resolve_configuration from dlt.common.schema.typing import VERSION_TABLE_NAME from dlt.common.storages import FileStorage from dlt.common.storages.schema_storage import SchemaStorage from dlt.common.utils import uniq_id - from dlt.destinations.exceptions import DatabaseTerminalException from dlt.destinations.redshift.configuration import RedshiftCredentials from dlt.destinations.redshift.redshift import RedshiftClient, psycopg2 -from tests.common.utils import COMMON_TEST_CASES_PATH -from tests.utils import TEST_STORAGE_ROOT, 
autouse_test_storage, skipifpypy -from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage - - @pytest.fixture def file_storage() -> FileStorage: @@ -50,13 +48,13 @@ def test_text_too_long(client: RedshiftClient, file_storage: FileStorage) -> Non # try some unicode value - redshift checks the max length based on utf-8 representation, not the number of characters # max_len_str = 'उ' * (65535 // 3) + 1 -> does not fit # max_len_str = 'a' * 65535 + 1 -> does not fit - max_len_str = 'उ' * ((caps["max_text_data_type_length"] // 3) + 1) + max_len_str = "उ" * ((caps["max_text_data_type_length"] // 3) + 1) # max_len_str_b = max_len_str.encode("utf-8") # print(len(max_len_str_b)) row_id = uniq_id() insert_values = f"('{row_id}', '{uniq_id()}', '{max_len_str}' , '{str(pendulum.now())}');" with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql+insert_values, user_table_name) + expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) assert type(exv.value.dbapi_exception) is psycopg2.errors.StringDataRightTruncation @@ -64,25 +62,36 @@ def test_wei_value(client: RedshiftClient, file_storage: FileStorage) -> None: user_table_name = prepare_table(client) # max redshift decimal is (38, 0) (128 bit) = 10**38 - 1 - insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp, parse_data__metadata__rasa_x_id)\nVALUES\n" - insert_values = f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd', '{str(pendulum.now())}', {10**38});" + insert_sql = ( + "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp," + " parse_data__metadata__rasa_x_id)\nVALUES\n" + ) + insert_values = ( + f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," + f" '{str(pendulum.now())}', {10**38});" + ) with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql+insert_values, user_table_name) + expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) assert type(exv.value.dbapi_exception) is psycopg2.errors.InternalError_ def test_schema_string_exceeds_max_text_length(client: RedshiftClient) -> None: client.update_storage_schema() # schema should be compressed and stored as base64 - schema = SchemaStorage.load_schema_file(os.path.join(COMMON_TEST_CASES_PATH, "schemas/ev1"), "event", ("json",)) + schema = SchemaStorage.load_schema_file( + os.path.join(COMMON_TEST_CASES_PATH, "schemas/ev1"), "event", ("json",) + ) schema_str = json.dumps(schema.to_dict()) assert len(schema_str.encode("utf-8")) > client.capabilities.max_text_data_type_length client._update_schema_in_storage(schema) schema_info = client.get_newest_schema_from_storage() assert schema_info.schema == schema_str # take base64 from db - with client.sql_client.execute_query(f"SELECT schema FROM {VERSION_TABLE_NAME} WHERE version_hash = '{schema.stored_version_hash}'") as cur: - row = cur.fetchone() + with client.sql_client.execute_query( + f"SELECT schema FROM {VERSION_TABLE_NAME} WHERE version_hash =" + f" '{schema.stored_version_hash}'" + ) as cur: + row = cur.fetchone() # decode base base64.b64decode(row[0], validate=True) @@ -99,7 +108,10 @@ def test_maximum_query_size(client: RedshiftClient, file_storage: FileStorage) - insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" insert_values = "('{}', '{}', '90238094809sajlkjxoiewjhduuiuehd', '{}'){}" - insert_sql = insert_sql + insert_values.format(uniq_id(), uniq_id(), 
str(pendulum.now()), ",\n") * 150000 + insert_sql = ( + insert_sql + + insert_values.format(uniq_id(), uniq_id(), str(pendulum.now()), ",\n") * 150000 + ) insert_sql += insert_values.format(uniq_id(), uniq_id(), str(pendulum.now()), ";") user_table_name = prepare_table(client) diff --git a/tests/load/redshift/test_redshift_table_builder.py b/tests/load/redshift/test_redshift_table_builder.py index 16ef6f8a76..8a6f00fafa 100644 --- a/tests/load/redshift/test_redshift_table_builder.py +++ b/tests/load/redshift/test_redshift_table_builder.py @@ -1,16 +1,16 @@ +from copy import deepcopy + import pytest import sqlfluff -from copy import deepcopy +from tests.load.utils import TABLE_UPDATE -from dlt.common.utils import uniq_id, custom_environ, digest128 -from dlt.common.schema import Schema from dlt.common.configuration import resolve_configuration - +from dlt.common.schema import Schema +from dlt.common.utils import custom_environ, digest128, uniq_id from dlt.destinations.exceptions import DestinationSchemaWillNotUpdate -from dlt.destinations.redshift.redshift import RedshiftClient from dlt.destinations.redshift.configuration import RedshiftClientConfiguration, RedshiftCredentials +from dlt.destinations.redshift.redshift import RedshiftClient -from tests.load.utils import TABLE_UPDATE @pytest.fixture def schema() -> Schema: @@ -20,12 +20,22 @@ def schema() -> Schema: @pytest.fixture def client(schema: Schema) -> RedshiftClient: # return client without opening connection - return RedshiftClient(schema, RedshiftClientConfiguration(dataset_name="test_" + uniq_id(), credentials=RedshiftCredentials())) + return RedshiftClient( + schema, + RedshiftClientConfiguration( + dataset_name="test_" + uniq_id(), credentials=RedshiftCredentials() + ), + ) def test_redshift_configuration() -> None: # check names normalized - with custom_environ({"DESTINATION__REDSHIFT__CREDENTIALS__DATABASE": "UPPER_CASE_DATABASE", "DESTINATION__REDSHIFT__CREDENTIALS__PASSWORD": " pass\n"}): + with custom_environ( + { + "DESTINATION__REDSHIFT__CREDENTIALS__DATABASE": "UPPER_CASE_DATABASE", + "DESTINATION__REDSHIFT__CREDENTIALS__PASSWORD": " pass\n", + } + ): C = resolve_configuration(RedshiftCredentials(), sections=("destination", "redshift")) assert C.database == "upper_case_database" assert C.password == "pass" @@ -33,13 +43,16 @@ def test_redshift_configuration() -> None: # check fingerprint assert RedshiftClientConfiguration().fingerprint() == "" # based on host - c = resolve_configuration(RedshiftCredentials(), explicit_value="postgres://user1:pass@host1/db1?warehouse=warehouse1&role=role1") + c = resolve_configuration( + RedshiftCredentials(), + explicit_value="postgres://user1:pass@host1/db1?warehouse=warehouse1&role=role1", + ) assert RedshiftClientConfiguration(credentials=c).fingerprint() == digest128("host1") def test_create_table(client: RedshiftClient) -> None: # non existing table - sql = ';'.join(client._get_table_update_sql("event_test_table", TABLE_UPDATE, False)) + sql = ";".join(client._get_table_update_sql("event_test_table", TABLE_UPDATE, False)) sqlfluff.parse(sql, dialect="redshift") assert "event_test_table" in sql assert '"col1" bigint NOT NULL' in sql @@ -56,7 +69,7 @@ def test_create_table(client: RedshiftClient) -> None: def test_alter_table(client: RedshiftClient) -> None: # existing table has no columns - sql = ';'.join(client._get_table_update_sql("event_test_table", TABLE_UPDATE, True)) + sql = ";".join(client._get_table_update_sql("event_test_table", TABLE_UPDATE, True)) sqlfluff.parse(sql, 
dialect="redshift") canonical_name = client.sql_client.make_qualified_table_name("event_test_table") # must have several ALTER TABLE statements @@ -81,7 +94,7 @@ def test_create_table_with_hints(client: RedshiftClient) -> None: mod_update[0]["sort"] = True mod_update[1]["cluster"] = True mod_update[4]["cluster"] = True - sql = ';'.join(client._get_table_update_sql("event_test_table", mod_update, False)) + sql = ";".join(client._get_table_update_sql("event_test_table", mod_update, False)) sqlfluff.parse(sql, dialect="redshift") # PRIMARY KEY will not be present https://heap.io/blog/redshift-pitfalls-avoid assert '"col1" bigint SORTKEY NOT NULL' in sql diff --git a/tests/load/snowflake/test_snowflake_configuration.py b/tests/load/snowflake/test_snowflake_configuration.py index 7214574f2d..a88725658e 100644 --- a/tests/load/snowflake/test_snowflake_configuration.py +++ b/tests/load/snowflake/test_snowflake_configuration.py @@ -1,17 +1,20 @@ import os -import pytest from pathlib import Path + +import pytest from sqlalchemy.engine import make_url pytest.importorskip("snowflake") -from dlt.common.configuration.resolve import resolve_configuration +from tests.common.configuration.utils import environment + from dlt.common.configuration.exceptions import ConfigurationValueError +from dlt.common.configuration.resolve import resolve_configuration from dlt.common.utils import digest128 - -from dlt.destinations.snowflake.configuration import SnowflakeClientConfiguration, SnowflakeCredentials - -from tests.common.configuration.utils import environment +from dlt.destinations.snowflake.configuration import ( + SnowflakeClientConfiguration, + SnowflakeCredentials, +) def test_connection_string_with_all_params() -> None: @@ -36,49 +39,61 @@ def test_connection_string_with_all_params() -> None: def test_to_connector_params() -> None: - pkey_str = Path('./tests/common/cases/secrets/encrypted-private-key').read_text('utf8') + pkey_str = Path("./tests/common/cases/secrets/encrypted-private-key").read_text("utf8") creds = SnowflakeCredentials() creds.private_key = pkey_str # type: ignore[assignment] - creds.private_key_passphrase = '12345' # type: ignore[assignment] - creds.username = 'user1' - creds.database = 'db1' - creds.host = 'host1' - creds.warehouse = 'warehouse1' - creds.role = 'role1' + creds.private_key_passphrase = "12345" # type: ignore[assignment] + creds.username = "user1" + creds.database = "db1" + creds.host = "host1" + creds.warehouse = "warehouse1" + creds.role = "role1" params = creds.to_connector_params() - assert isinstance(params['private_key'], bytes) - params.pop('private_key') + assert isinstance(params["private_key"], bytes) + params.pop("private_key") assert params == dict( - user='user1', - database='db1', - account='host1', + user="user1", + database="db1", + account="host1", password=None, - warehouse='warehouse1', - role='role1', + warehouse="warehouse1", + role="role1", ) def test_snowflake_credentials_native_value(environment) -> None: with pytest.raises(ConfigurationValueError): - resolve_configuration(SnowflakeCredentials(), explicit_value="snowflake://user1@host1/db1?warehouse=warehouse1&role=role1") + resolve_configuration( + SnowflakeCredentials(), + explicit_value="snowflake://user1@host1/db1?warehouse=warehouse1&role=role1", + ) # set password via env os.environ["CREDENTIALS__PASSWORD"] = "pass" - c = resolve_configuration(SnowflakeCredentials(), explicit_value="snowflake://user1@host1/db1?warehouse=warehouse1&role=role1") + c = resolve_configuration( + 
SnowflakeCredentials(), + explicit_value="snowflake://user1@host1/db1?warehouse=warehouse1&role=role1", + ) assert c.is_resolved() assert c.password == "pass" # but if password is specified - it is final - c = resolve_configuration(SnowflakeCredentials(), explicit_value="snowflake://user1:pass1@host1/db1?warehouse=warehouse1&role=role1") + c = resolve_configuration( + SnowflakeCredentials(), + explicit_value="snowflake://user1:pass1@host1/db1?warehouse=warehouse1&role=role1", + ) assert c.is_resolved() assert c.password == "pass1" # set PK via env del os.environ["CREDENTIALS__PASSWORD"] os.environ["CREDENTIALS__PRIVATE_KEY"] = "pk" - c = resolve_configuration(SnowflakeCredentials(), explicit_value="snowflake://user1@host1/db1?warehouse=warehouse1&role=role1") + c = resolve_configuration( + SnowflakeCredentials(), + explicit_value="snowflake://user1@host1/db1?warehouse=warehouse1&role=role1", + ) assert c.is_resolved() assert c.private_key == "pk" @@ -87,5 +102,8 @@ def test_snowflake_configuration() -> None: # default config produces an empty fingerprint assert SnowflakeClientConfiguration().fingerprint() == "" # based on host - c = resolve_configuration(SnowflakeCredentials(), explicit_value="snowflake://user1:pass@host1/db1?warehouse=warehouse1&role=role1") + c = resolve_configuration( + SnowflakeCredentials(), + explicit_value="snowflake://user1:pass@host1/db1?warehouse=warehouse1&role=role1", + ) assert SnowflakeClientConfiguration(credentials=c).fingerprint() == digest128("host1") diff --git a/tests/load/snowflake/test_snowflake_table_builder.py b/tests/load/snowflake/test_snowflake_table_builder.py index efbd478089..138d4de1e6 100644 --- a/tests/load/snowflake/test_snowflake_table_builder.py +++ b/tests/load/snowflake/test_snowflake_table_builder.py @@ -2,14 +2,16 @@ import pytest import sqlfluff +from tests.load.utils import TABLE_UPDATE -from dlt.common.utils import uniq_id from dlt.common.schema import Schema -from dlt.destinations.snowflake.snowflake import SnowflakeClient -from dlt.destinations.snowflake.configuration import SnowflakeClientConfiguration, SnowflakeCredentials +from dlt.common.utils import uniq_id from dlt.destinations.exceptions import DestinationSchemaWillNotUpdate - -from tests.load.utils import TABLE_UPDATE +from dlt.destinations.snowflake.configuration import ( + SnowflakeClientConfiguration, + SnowflakeCredentials, +) +from dlt.destinations.snowflake.snowflake import SnowflakeClient @pytest.fixture @@ -21,14 +23,16 @@ def schema() -> Schema: def snowflake_client(schema: Schema) -> SnowflakeClient: # return client without opening connection creds = SnowflakeCredentials() - return SnowflakeClient(schema, SnowflakeClientConfiguration(dataset_name="test_" + uniq_id(), credentials=creds)) + return SnowflakeClient( + schema, SnowflakeClientConfiguration(dataset_name="test_" + uniq_id(), credentials=creds) + ) def test_create_table(snowflake_client: SnowflakeClient) -> None: statements = snowflake_client._get_table_update_sql("event_test_table", TABLE_UPDATE, False) assert len(statements) == 1 sql = statements[0] - sqlfluff.parse(sql, dialect='snowflake') + sqlfluff.parse(sql, dialect="snowflake") assert sql.strip().startswith("CREATE TABLE") assert "EVENT_TEST_TABLE" in sql diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index e5d0186509..84cc0e4810 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -1,39 +1,43 @@ -import shutil import os +import shutil from multiprocessing.pool import ThreadPool from time import sleep 
from typing import List, Sequence, Tuple -import pytest from unittest.mock import patch +import pytest +from tests.load.utils import prepare_load_package +from tests.utils import ( + TEST_DICT_CONFIG_PROVIDER, + clean_test_storage, + init_test_logging, + preserve_environ, + skip_if_not_active, +) + +from dlt.common.destination.reference import DestinationReference, LoadJob from dlt.common.exceptions import TerminalException, TerminalValueError from dlt.common.schema import Schema +from dlt.common.schema.utils import get_top_level_table from dlt.common.storages import FileStorage, LoadStorage from dlt.common.storages.load_storage import JobWithUnsupportedWriterException from dlt.common.utils import uniq_id -from dlt.common.destination.reference import DestinationReference, LoadJob - -from dlt.load import Load -from dlt.destinations.job_impl import EmptyLoadJob - from dlt.destinations import dummy from dlt.destinations.dummy import dummy as dummy_impl from dlt.destinations.dummy.configuration import DummyClientConfiguration +from dlt.destinations.job_impl import EmptyLoadJob +from dlt.load import Load from dlt.load.exceptions import LoadClientJobFailed, LoadClientJobRetry -from dlt.common.schema.utils import get_top_level_table - -from tests.utils import clean_test_storage, init_test_logging, TEST_DICT_CONFIG_PROVIDER, preserve_environ -from tests.load.utils import prepare_load_package -from tests.utils import skip_if_not_active skip_if_not_active("dummy") NORMALIZED_FILES = [ "event_user.839c6e6b514e427687586ccc65bf133f.0.jsonl", - "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl" + "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl", ] + @pytest.fixture(autouse=True) def storage() -> FileStorage: return clean_test_storage(init_normalize=True, init_loader=True) @@ -47,10 +51,7 @@ def logger_autouse() -> None: def test_spool_job_started() -> None: # default config keeps the job always running load = setup_loader() - load_id, schema = prepare_load_package( - load.load_storage, - NORMALIZED_FILES - ) + load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) files = load.load_storage.list_new_jobs(load_id) assert len(files) == 2 jobs: List[LoadJob] = [] @@ -58,7 +59,11 @@ def test_spool_job_started() -> None: job = Load.w_spool_job(load, f, load_id, schema) assert type(job) is dummy_impl.LoadDummyJob assert job.state() == "running" - assert load.load_storage.storage.has_file(load.load_storage._get_job_file_path(load_id, LoadStorage.STARTED_JOBS_FOLDER, job.file_name())) + assert load.load_storage.storage.has_file( + load.load_storage._get_job_file_path( + load_id, LoadStorage.STARTED_JOBS_FOLDER, job.file_name() + ) + ) jobs.append(job) # still running remaining_jobs = load.complete_jobs(load_id, jobs, schema) @@ -68,8 +73,7 @@ def test_spool_job_started() -> None: def test_unsupported_writer_type() -> None: load = setup_loader() load_id, _ = prepare_load_package( - load.load_storage, - ["event_bot.181291798a78198.0.unsupported_format"] + load.load_storage, ["event_bot.181291798a78198.0.unsupported_format"] ) with pytest.raises(TerminalValueError): load.load_storage.list_new_jobs(load_id) @@ -77,10 +81,7 @@ def test_unsupported_writer_type() -> None: def test_unsupported_write_disposition() -> None: load = setup_loader() - load_id, schema = prepare_load_package( - load.load_storage, - [NORMALIZED_FILES[0]] - ) + load_id, schema = prepare_load_package(load.load_storage, [NORMALIZED_FILES[0]]) # mock unsupported disposition 
schema.get_table("event_user")["write_disposition"] = "skip" # write back schema @@ -88,16 +89,15 @@ def test_unsupported_write_disposition() -> None: with ThreadPool() as pool: load.run(pool) # job with unsupported write disp. is failed - exception = [f for f in load.load_storage.list_failed_jobs(load_id) if f.endswith(".exception")][0] + exception = [ + f for f in load.load_storage.list_failed_jobs(load_id) if f.endswith(".exception") + ][0] assert "LoadClientUnsupportedWriteDisposition" in load.load_storage.storage.load(exception) def test_get_new_jobs_info() -> None: load = setup_loader() - load_id, schema = prepare_load_package( - load.load_storage, - NORMALIZED_FILES - ) + load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) # no write disposition specified - get all new jobs assert len(load.get_new_jobs_info(load_id, schema)) == 2 @@ -108,60 +108,83 @@ def test_get_new_jobs_info() -> None: assert len(load.get_new_jobs_info(load_id, schema, ["replace"])) == 0 assert len(load.get_new_jobs_info(load_id, schema, ["replace", "append"])) == 2 - load.load_storage.start_job(load_id, "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl") + load.load_storage.start_job( + load_id, "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl" + ) assert len(load.get_new_jobs_info(load_id, schema, ["replace", "append"])) == 1 def test_get_completed_table_chain_single_job_per_table() -> None: load = setup_loader() - load_id, schema = prepare_load_package( - load.load_storage, - NORMALIZED_FILES - ) + load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) top_job_table = get_top_level_table(schema.tables, "event_user") assert load.get_completed_table_chain(load_id, schema, top_job_table) is None # fake being completed - assert len(load.get_completed_table_chain(load_id, schema, top_job_table, "event_user.839c6e6b514e427687586ccc65bf133f.0.jsonl")) == 1 + assert ( + len( + load.get_completed_table_chain( + load_id, + schema, + top_job_table, + "event_user.839c6e6b514e427687586ccc65bf133f.0.jsonl", + ) + ) + == 1 + ) # actually complete loop_top_job_table = get_top_level_table(schema.tables, "event_loop_interrupted") - load.load_storage.start_job(load_id, "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl") + load.load_storage.start_job( + load_id, "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl" + ) assert load.get_completed_table_chain(load_id, schema, loop_top_job_table) is None - load.load_storage.complete_job(load_id, "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl") - assert load.get_completed_table_chain(load_id, schema, loop_top_job_table) == [schema.get_table("event_loop_interrupted")] - assert load.get_completed_table_chain(load_id, schema, loop_top_job_table, "event_user.839c6e6b514e427687586ccc65bf133f.0.jsonl") == [schema.get_table("event_loop_interrupted")] + load.load_storage.complete_job( + load_id, "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl" + ) + assert load.get_completed_table_chain(load_id, schema, loop_top_job_table) == [ + schema.get_table("event_loop_interrupted") + ] + assert load.get_completed_table_chain( + load_id, schema, loop_top_job_table, "event_user.839c6e6b514e427687586ccc65bf133f.0.jsonl" + ) == [schema.get_table("event_loop_interrupted")] def test_spool_job_failed() -> None: # this config fails job on start load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0)) - load_id, schema = prepare_load_package( - 
load.load_storage, - NORMALIZED_FILES - ) + load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) files = load.load_storage.list_new_jobs(load_id) jobs: List[LoadJob] = [] for f in files: job = Load.w_spool_job(load, f, load_id, schema) assert type(job) is EmptyLoadJob assert job.state() == "failed" - assert load.load_storage.storage.has_file(load.load_storage._get_job_file_path(load_id, LoadStorage.STARTED_JOBS_FOLDER, job.file_name())) + assert load.load_storage.storage.has_file( + load.load_storage._get_job_file_path( + load_id, LoadStorage.STARTED_JOBS_FOLDER, job.file_name() + ) + ) jobs.append(job) # complete files remaining_jobs = load.complete_jobs(load_id, jobs, schema) assert len(remaining_jobs) == 0 for job in jobs: - assert load.load_storage.storage.has_file(load.load_storage._get_job_file_path(load_id, LoadStorage.FAILED_JOBS_FOLDER, job.file_name())) - assert load.load_storage.storage.has_file(load.load_storage._get_job_file_path(load_id, LoadStorage.FAILED_JOBS_FOLDER, job.file_name() + ".exception")) + assert load.load_storage.storage.has_file( + load.load_storage._get_job_file_path( + load_id, LoadStorage.FAILED_JOBS_FOLDER, job.file_name() + ) + ) + assert load.load_storage.storage.has_file( + load.load_storage._get_job_file_path( + load_id, LoadStorage.FAILED_JOBS_FOLDER, job.file_name() + ".exception" + ) + ) started_files = load.load_storage.list_started_jobs(load_id) assert len(started_files) == 0 # test the whole flow load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0)) - load_id, schema = prepare_load_package( - load.load_storage, - NORMALIZED_FILES - ) + load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) run_all(load) package_info = load.load_storage.get_load_package_info(load_id) assert package_info.state == "loaded" @@ -174,10 +197,7 @@ def test_spool_job_failed_exception_init() -> None: os.environ["LOAD__RAISE_ON_FAILED_JOBS"] = "true" os.environ["FAIL_IN_INIT"] = "true" load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0)) - load_id, _ = prepare_load_package( - load.load_storage, - NORMALIZED_FILES - ) + load_id, _ = prepare_load_package(load.load_storage, NORMALIZED_FILES) with patch.object(dummy_impl.DummyClient, "complete_load") as complete_load: with pytest.raises(LoadClientJobFailed) as py_ex: run_all(load) @@ -196,10 +216,7 @@ def test_spool_job_failed_exception_complete() -> None: os.environ["RAISE_ON_FAILED_JOBS"] = "true" os.environ["FAIL_IN_INIT"] = "false" load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0)) - load_id, _ = prepare_load_package( - load.load_storage, - NORMALIZED_FILES - ) + load_id, _ = prepare_load_package(load.load_storage, NORMALIZED_FILES) with pytest.raises(LoadClientJobFailed) as py_ex: run_all(load) assert py_ex.value.load_id == load_id @@ -213,22 +230,17 @@ def test_spool_job_failed_exception_complete() -> None: def test_spool_job_retry_new() -> None: # this config retries job on start (transient fail) load = setup_loader(client_config=DummyClientConfiguration(retry_prob=1.0)) - load_id, schema = prepare_load_package( - load.load_storage, - NORMALIZED_FILES - ) + load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) files = load.load_storage.list_new_jobs(load_id) for f in files: job = Load.w_spool_job(load, f, load_id, schema) assert job.state() == "retry" + def test_spool_job_retry_spool_new() -> None: # this config retries job on start (transient fail) load = 
setup_loader(client_config=DummyClientConfiguration(retry_prob=1.0)) - load_id, schema = prepare_load_package( - load.load_storage, - NORMALIZED_FILES - ) + load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) # call higher level function that returns jobs and counts with ThreadPool() as pool: load.pool = pool @@ -241,17 +253,18 @@ def test_spool_job_retry_started() -> None: # this config keeps the job always running load = setup_loader() # dummy_impl.CLIENT_CONFIG = DummyClientConfiguration - load_id, schema = prepare_load_package( - load.load_storage, - NORMALIZED_FILES - ) + load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) files = load.load_storage.list_new_jobs(load_id) jobs: List[LoadJob] = [] for f in files: job = Load.w_spool_job(load, f, load_id, schema) assert type(job) is dummy_impl.LoadDummyJob assert job.state() == "running" - assert load.load_storage.storage.has_file(load.load_storage._get_job_file_path(load_id, LoadStorage.STARTED_JOBS_FOLDER, job.file_name())) + assert load.load_storage.storage.has_file( + load.load_storage._get_job_file_path( + load_id, LoadStorage.STARTED_JOBS_FOLDER, job.file_name() + ) + ) # mock job config to make it retry job.config.retry_prob = 1.0 jobs.append(job) @@ -275,10 +288,7 @@ def test_spool_job_retry_started() -> None: def test_try_retrieve_job() -> None: load = setup_loader() - load_id, schema = prepare_load_package( - load.load_storage, - NORMALIZED_FILES - ) + load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) # manually move jobs to started files = load.load_storage.list_new_jobs(load_id) for f in files: @@ -291,10 +301,7 @@ def test_try_retrieve_job() -> None: for j in jobs: assert j.state() == "failed" # new load package - load_id, schema = prepare_load_package( - load.load_storage, - NORMALIZED_FILES - ) + load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) load.pool = ThreadPool() jobs_count, jobs = load.spool_new_jobs(load_id, schema) assert jobs_count == 2 @@ -313,7 +320,9 @@ def test_completed_loop() -> None: def test_failed_loop() -> None: # ask to delete completed - load = setup_loader(delete_completed_jobs=True, client_config=DummyClientConfiguration(fail_prob=1.0)) + load = setup_loader( + delete_completed_jobs=True, client_config=DummyClientConfiguration(fail_prob=1.0) + ) # actually not deleted because one of the jobs failed assert_complete_job(load, load.load_storage.storage, should_delete_completed=False) @@ -328,10 +337,7 @@ def test_completed_loop_with_delete_completed() -> None: def test_retry_on_new_loop() -> None: # test job that retries sitting in new jobs load = setup_loader(client_config=DummyClientConfiguration(retry_prob=1.0)) - load_id, schema = prepare_load_package( - load.load_storage, - NORMALIZED_FILES - ) + load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) with ThreadPool() as pool: # 1st retry load.run(pool) @@ -352,17 +358,16 @@ def test_retry_on_new_loop() -> None: assert not load.load_storage.storage.has_folder(load.load_storage.get_package_path(load_id)) # parse the completed job names completed_path = load.load_storage.get_completed_package_path(load_id) - for fn in load.load_storage.storage.list_folder_files(os.path.join(completed_path, LoadStorage.COMPLETED_JOBS_FOLDER)): + for fn in load.load_storage.storage.list_folder_files( + os.path.join(completed_path, LoadStorage.COMPLETED_JOBS_FOLDER) + ): # we update a retry count in each case assert 
LoadStorage.parse_job_file_name(fn).retry_count == 2 def test_retry_exceptions() -> None: load = setup_loader(client_config=DummyClientConfiguration(retry_prob=1.0)) - prepare_load_package( - load.load_storage, - NORMALIZED_FILES - ) + prepare_load_package(load.load_storage, NORMALIZED_FILES) with ThreadPool() as pool: # 1st retry with pytest.raises(LoadClientJobRetry) as py_ex: @@ -383,8 +388,10 @@ def test_wrong_writer_type() -> None: load = setup_loader() load_id, _ = prepare_load_package( load.load_storage, - ["event_bot.b1d32c6660b242aaabbf3fc27245b7e6.0.insert_values", - "event_user.b1d32c6660b242aaabbf3fc27245b7e6.0.insert_values"] + [ + "event_bot.b1d32c6660b242aaabbf3fc27245b7e6.0.insert_values", + "event_user.b1d32c6660b242aaabbf3fc27245b7e6.0.insert_values", + ], ) with ThreadPool() as pool: with pytest.raises(JobWithUnsupportedWriterException) as exv: @@ -401,22 +408,28 @@ def test_terminal_exceptions() -> None: raise AssertionError() -def assert_complete_job(load: Load, storage: FileStorage, should_delete_completed: bool = False) -> None: - load_id, _ = prepare_load_package( - load.load_storage, - NORMALIZED_FILES - ) +def assert_complete_job( + load: Load, storage: FileStorage, should_delete_completed: bool = False +) -> None: + load_id, _ = prepare_load_package(load.load_storage, NORMALIZED_FILES) # will complete all jobs with patch.object(dummy_impl.DummyClient, "complete_load") as complete_load: with ThreadPool() as pool: load.run(pool) # did process schema update - assert storage.has_file(os.path.join(load.load_storage.get_package_path(load_id), LoadStorage.APPLIED_SCHEMA_UPDATES_FILE_NAME)) + assert storage.has_file( + os.path.join( + load.load_storage.get_package_path(load_id), + LoadStorage.APPLIED_SCHEMA_UPDATES_FILE_NAME, + ) + ) # will finalize the whole package load.run(pool) # moved to loaded assert not storage.has_folder(load.load_storage.get_package_path(load_id)) - completed_path = load.load_storage._get_job_folder_completed_path(load_id, "completed_jobs") + completed_path = load.load_storage._get_job_folder_completed_path( + load_id, "completed_jobs" + ) if should_delete_completed: # package was deleted assert not storage.has_folder(completed_path) @@ -436,7 +449,9 @@ def run_all(load: Load) -> None: sleep(0.1) -def setup_loader(delete_completed_jobs: bool = False, client_config: DummyClientConfiguration = None) -> Load: +def setup_loader( + delete_completed_jobs: bool = False, client_config: DummyClientConfiguration = None +) -> Load: # reset jobs for a test dummy_impl.JOBS = {} destination: DestinationReference = dummy @@ -446,7 +461,4 @@ def setup_loader(delete_completed_jobs: bool = False, client_config: DummyClient # setup loader with TEST_DICT_CONFIG_PROVIDER().values({"delete_completed_jobs": delete_completed_jobs}): - return Load( - destination, - initial_client_config=client_config - ) + return Load(destination, initial_client_config=client_config) diff --git a/tests/load/test_insert_job_client.py b/tests/load/test_insert_job_client.py index 874cd91d4a..56b1794aa0 100644 --- a/tests/load/test_insert_job_client.py +++ b/tests/load/test_insert_job_client.py @@ -1,37 +1,51 @@ from typing import Iterator, List -import pytest from unittest.mock import patch -from dlt.common import pendulum, Decimal +import pytest +from tests.load.pipeline.utils import DestinationTestConfiguration, destinations_configs +from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage +from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage, 
skipifpypy + +from dlt.common import Decimal, pendulum from dlt.common.arithmetics import numeric_default_context from dlt.common.storages import FileStorage from dlt.common.utils import uniq_id - -from dlt.destinations.exceptions import DatabaseTerminalException, DatabaseTransientException, DatabaseUndefinedRelation +from dlt.destinations.exceptions import ( + DatabaseTerminalException, + DatabaseTransientException, + DatabaseUndefinedRelation, +) from dlt.destinations.insert_job_client import InsertValuesJobClient -from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage, skipifpypy -from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration - DEFAULT_SUBSET = ["duckdb", "redshift", "postgres"] + @pytest.fixture def file_storage() -> FileStorage: return FileStorage(TEST_STORAGE_ROOT, file_type="b", makedirs=True) + @pytest.fixture(scope="function") def client(request) -> InsertValuesJobClient: yield from yield_client_with_storage(request.param.destination) -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True, subset=DEFAULT_SUBSET), indirect=True, ids=lambda x: x.name) + +@pytest.mark.parametrize( + "client", + destinations_configs(default_sql_configs=True, subset=DEFAULT_SUBSET), + indirect=True, + ids=lambda x: x.name, +) def test_simple_load(client: InsertValuesJobClient, file_storage: FileStorage) -> None: user_table_name = prepare_table(client) canonical_name = client.sql_client.make_qualified_table_name(user_table_name) # create insert insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" - insert_values = f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd', '{str(pendulum.now())}')" - expect_load_file(client, file_storage, insert_sql+insert_values+";", user_table_name) + insert_values = ( + f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," + f" '{str(pendulum.now())}')" + ) + expect_load_file(client, file_storage, insert_sql + insert_values + ";", user_table_name) rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0] assert rows_count == 1 # insert 100 more rows @@ -41,17 +55,26 @@ def test_simple_load(client: InsertValuesJobClient, file_storage: FileStorage) - assert rows_count == 101 # insert null value insert_sql_nc = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp, text)\nVALUES\n" - insert_values_nc = f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd', '{str(pendulum.now())}', NULL);" - expect_load_file(client, file_storage, insert_sql_nc+insert_values_nc, user_table_name) + insert_values_nc = ( + f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," + f" '{str(pendulum.now())}', NULL);" + ) + expect_load_file(client, file_storage, insert_sql_nc + insert_values_nc, user_table_name) rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0] assert rows_count == 102 @skipifpypy -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True, subset=DEFAULT_SUBSET), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", + destinations_configs(default_sql_configs=True, subset=DEFAULT_SUBSET), + indirect=True, + ids=lambda x: x.name, +) def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage) -> None: # test expected dbiapi exceptions for supported destinations import duckdb + from 
dlt.destinations.postgres.sql_client import psycopg2 TNotNullViolation = psycopg2.errors.NotNullViolation @@ -66,75 +89,102 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage TNotNullViolation = duckdb.ConstraintException TNumericValueOutOfRange = TDatatypeMismatch = duckdb.ConversionException - user_table_name = prepare_table(client) # insert into unknown column insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp, _unk_)\nVALUES\n" - insert_values = f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd', '{str(pendulum.now())}', NULL);" + insert_values = ( + f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," + f" '{str(pendulum.now())}', NULL);" + ) with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql+insert_values, user_table_name) + expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) assert type(exv.value.dbapi_exception) is TUndefinedColumn # insert null value insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" insert_values = f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd', NULL);" with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql+insert_values, user_table_name) + expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) assert type(exv.value.dbapi_exception) is TNotNullViolation # insert wrong type insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" insert_values = f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd', TRUE);" with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql+insert_values, user_table_name) + expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) assert type(exv.value.dbapi_exception) is TDatatypeMismatch # numeric overflow on bigint - insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp, metadata__rasa_x_id)\nVALUES\n" + insert_sql = ( + "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp, metadata__rasa_x_id)\nVALUES\n" + ) # 2**64//2 - 1 is a maximum bigint value - insert_values = f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd', '{str(pendulum.now())}', {2**64//2});" + insert_values = ( + f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," + f" '{str(pendulum.now())}', {2**64//2});" + ) with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql+insert_values, user_table_name) - assert type(exv.value.dbapi_exception) in (TNumericValueOutOfRange, ) + expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) + assert type(exv.value.dbapi_exception) in (TNumericValueOutOfRange,) # numeric overflow on NUMERIC - insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp, parse_data__intent__id)\nVALUES\n" + insert_sql = ( + "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp," + " parse_data__intent__id)\nVALUES\n" + ) # default decimal is (38, 9) (128 bit), use local context to generate decimals with 38 precision with numeric_default_context(): - below_limit = Decimal(10**29) - Decimal('0.001') + below_limit = Decimal(10**29) - Decimal("0.001") above_limit = Decimal(10**29) # this will pass - insert_values = f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd', 
'{str(pendulum.now())}', {below_limit});" - expect_load_file(client, file_storage, insert_sql+insert_values, user_table_name) + insert_values = ( + f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," + f" '{str(pendulum.now())}', {below_limit});" + ) + expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) # this will raise - insert_values = f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd', '{str(pendulum.now())}', {above_limit});" + insert_values = ( + f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," + f" '{str(pendulum.now())}', {above_limit});" + ) with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql+insert_values, user_table_name) - assert type(exv.value.dbapi_exception) in (TNumericValueOutOfRange, psycopg2.errors.InternalError_) - - - -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True, subset=DEFAULT_SUBSET), indirect=True, ids=lambda x: x.name) + expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) + assert type(exv.value.dbapi_exception) in ( + TNumericValueOutOfRange, + psycopg2.errors.InternalError_, + ) + + +@pytest.mark.parametrize( + "client", + destinations_configs(default_sql_configs=True, subset=DEFAULT_SUBSET), + indirect=True, + ids=lambda x: x.name, +) def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) -> None: mocked_caps = client.sql_client.__class__.capabilities insert_sql = prepare_insert_statement(10) # this guarantees that we execute inserts line by line - with patch.object(mocked_caps, "max_query_length", 2), patch.object(client.sql_client, "execute_fragments") as mocked_fragments: + with patch.object(mocked_caps, "max_query_length", 2), patch.object( + client.sql_client, "execute_fragments" + ) as mocked_fragments: user_table_name = prepare_table(client) expect_load_file(client, file_storage, insert_sql, user_table_name) # print(mocked_fragments.mock_calls) # split in 10 lines assert mocked_fragments.call_count == 10 for idx, call in enumerate(mocked_fragments.call_args_list): - fragment:List[str] = call.args[0] + fragment: List[str] = call.args[0] # last elem of fragment is a data list, first element is id, and must end with ;\n assert fragment[-1].startswith(f"'{idx}'") assert fragment[-1].endswith(");") assert_load_with_max_query(client, file_storage, 10, 2) start_idx = insert_sql.find("S\n(") - idx = insert_sql.find("),\n", len(insert_sql)//2) + idx = insert_sql.find("),\n", len(insert_sql) // 2) # set query length so it reads data until "," (followed by \n) query_length = (idx - start_idx - 1) * 2 - with patch.object(mocked_caps, "max_query_length", query_length), patch.object(client.sql_client, "execute_fragments") as mocked_fragments: + with patch.object(mocked_caps, "max_query_length", query_length), patch.object( + client.sql_client, "execute_fragments" + ) as mocked_fragments: user_table_name = prepare_table(client) expect_load_file(client, file_storage, insert_sql, user_table_name) # split in 2 on ',' @@ -142,7 +192,9 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - # so it reads until "\n" query_length = (idx - start_idx) * 2 - with patch.object(mocked_caps, "max_query_length", query_length), patch.object(client.sql_client, "execute_fragments") as mocked_fragments: + with patch.object(mocked_caps, "max_query_length", query_length), patch.object( + client.sql_client, "execute_fragments" + ) 
as mocked_fragments: user_table_name = prepare_table(client) expect_load_file(client, file_storage, insert_sql, user_table_name) # split in 2 on ',' @@ -150,14 +202,21 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - # so it reads till the last ; query_length = (len(insert_sql) - start_idx - 3) * 2 - with patch.object(mocked_caps, "max_query_length", query_length), patch.object(client.sql_client, "execute_fragments") as mocked_fragments: + with patch.object(mocked_caps, "max_query_length", query_length), patch.object( + client.sql_client, "execute_fragments" + ) as mocked_fragments: user_table_name = prepare_table(client) expect_load_file(client, file_storage, insert_sql, user_table_name) # split in 2 on ',' assert mocked_fragments.call_count == 1 -def assert_load_with_max_query(client: InsertValuesJobClient, file_storage: FileStorage, insert_lines: int, max_query_length: int) -> None: +def assert_load_with_max_query( + client: InsertValuesJobClient, + file_storage: FileStorage, + insert_lines: int, + max_query_length: int, +) -> None: # load and check for real mocked_caps = client.sql_client.__class__.capabilities with patch.object(mocked_caps, "max_query_length", max_query_length): @@ -167,7 +226,9 @@ def assert_load_with_max_query(client: InsertValuesJobClient, file_storage: File rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {user_table_name}")[0][0] assert rows_count == insert_lines # get all uniq ids in order - with client.sql_client.execute_query(f"SELECT _dlt_id FROM {user_table_name} ORDER BY timestamp ASC;") as c: + with client.sql_client.execute_query( + f"SELECT _dlt_id FROM {user_table_name} ORDER BY timestamp ASC;" + ) as c: rows = list(c.fetchall()) v_ids = list(map(lambda i: i[0], rows)) assert list(map(str, range(0, insert_lines))) == v_ids @@ -177,7 +238,7 @@ def assert_load_with_max_query(client: InsertValuesJobClient, file_storage: File def prepare_insert_statement(lines: int) -> str: insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" insert_values = "('{}', '{}', '90238094809sajlkjxoiewjhduuiuehd', '{}')" - #ids = [] + # ids = [] for i in range(lines): # id_ = uniq_id() # ids.append(id_) @@ -187,4 +248,4 @@ def prepare_insert_statement(lines: int) -> str: else: insert_sql += ";" # print(insert_sql) - return insert_sql \ No newline at end of file + return insert_sql diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 5b72fcde74..47d806dd7c 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -1,46 +1,66 @@ import contextlib +import datetime # noqa: I251 +import io +import os from copy import deepcopy -import io, os from time import sleep +from typing import Iterator from unittest.mock import patch + import pytest -import datetime # noqa: I251 -from typing import Iterator +from tests.common.utils import load_json_case +from tests.load.pipeline.utils import DestinationTestConfiguration, destinations_configs +from tests.load.utils import ( + TABLE_ROW_ALL_DATA_TYPES, + TABLE_UPDATE, + TABLE_UPDATE_COLUMNS_SCHEMA, + assert_all_data_types_row, + cm_yield_client_with_storage, + expect_load_file, + load_table, + prepare_table, + write_dataset, + yield_client_with_storage, +) +from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage from dlt.common import json, pendulum -from dlt.common.schema import Schema +from dlt.common.destination.reference import WithStagingDataset +from dlt.common.schema import Schema, 
TTableSchemaColumns from dlt.common.schema.typing import LOADS_TABLE_NAME, VERSION_TABLE_NAME -from dlt.common.schema.utils import new_table, new_column +from dlt.common.schema.utils import new_column, new_table from dlt.common.storages import FileStorage -from dlt.common.schema import TTableSchemaColumns from dlt.common.utils import uniq_id -from dlt.destinations.exceptions import DatabaseException, DatabaseTerminalException, DatabaseUndefinedRelation - +from dlt.destinations.exceptions import ( + DatabaseException, + DatabaseTerminalException, + DatabaseUndefinedRelation, +) from dlt.destinations.job_client_impl import SqlJobClientBase -from dlt.common.destination.reference import WithStagingDataset - -from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage -from tests.common.utils import load_json_case -from tests.load.utils import (TABLE_UPDATE, TABLE_UPDATE_COLUMNS_SCHEMA, TABLE_ROW_ALL_DATA_TYPES, assert_all_data_types_row , expect_load_file, load_table, yield_client_with_storage, - cm_yield_client_with_storage, write_dataset, prepare_table) -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration @pytest.fixture def file_storage() -> FileStorage: return FileStorage(TEST_STORAGE_ROOT, file_type="b", makedirs=True) + @pytest.fixture(scope="function") def client(request) -> SqlJobClientBase: yield from yield_client_with_storage(request.param.destination) + @pytest.mark.order(1) -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_initialize_storage(client: SqlJobClientBase) -> None: pass + @pytest.mark.order(2) -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_get_schema_on_empty_storage(client: SqlJobClientBase) -> None: # test getting schema on empty dataset without any tables exists, _ = client.get_storage_table(VERSION_TABLE_NAME) @@ -50,8 +70,11 @@ def test_get_schema_on_empty_storage(client: SqlJobClientBase) -> None: schema_info = client.get_schema_by_hash("8a0298298823928939") assert schema_info is None + @pytest.mark.order(3) -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_get_update_basic_schema(client: SqlJobClientBase) -> None: schema = client.schema schema_update = client.update_storage_schema() @@ -104,7 +127,7 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: client._update_schema_in_storage(first_schema) this_schema = client.get_schema_by_hash(first_schema.version_hash) newest_schema = client.get_newest_schema_from_storage() - assert this_schema == newest_schema # error + assert this_schema == newest_schema # error assert this_schema.version == first_schema.version == 2 assert this_schema.version_hash == first_schema.stored_version_hash @@ -127,7 +150,9 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: assert this_schema == newest_schema -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( 
+ "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_complete_load(client: SqlJobClientBase) -> None: client.update_storage_schema() load_id = "182879721.182912" @@ -139,18 +164,28 @@ def test_complete_load(client: SqlJobClientBase) -> None: assert load_rows[0][1] == client.schema.name assert load_rows[0][2] == 0 import datetime # noqa: I251 + assert type(load_rows[0][3]) is datetime.datetime assert load_rows[0][4] == client.schema.version_hash # make sure that hash in loads exists in schema versions table versions_table = client.sql_client.make_qualified_table_name(VERSION_TABLE_NAME) - version_rows = list(client.sql_client.execute_sql(f"SELECT * FROM {versions_table} WHERE version_hash = %s", load_rows[0][4])) + version_rows = list( + client.sql_client.execute_sql( + f"SELECT * FROM {versions_table} WHERE version_hash = %s", load_rows[0][4] + ) + ) assert len(version_rows) == 1 client.complete_load("load2") load_rows = list(client.sql_client.execute_sql(f"SELECT * FROM {load_table}")) assert len(load_rows) == 2 -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True, subset=["redshift", "postgres", "duckdb"]), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", + destinations_configs(default_sql_configs=True, subset=["redshift", "postgres", "duckdb"]), + indirect=True, + ids=lambda x: x.name, +) def test_schema_update_create_table_redshift(client: SqlJobClientBase) -> None: # infer typical rasa event schema schema = client.schema @@ -160,7 +195,7 @@ def test_schema_update_create_table_redshift(client: SqlJobClientBase) -> None: assert timestamp["sort"] is True # this will be destkey sender_id = schema._infer_column("sender_id", "982398490809324") - assert sender_id["cluster"] is True + assert sender_id["cluster"] is True # this will be not null record_hash = schema._infer_column("_dlt_id", "m,i0392903jdlkasjdlk") assert record_hash["unique"] is True @@ -176,7 +211,12 @@ def test_schema_update_create_table_redshift(client: SqlJobClientBase) -> None: assert exists is True -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True, subset=["bigquery"]), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", + destinations_configs(default_sql_configs=True, subset=["bigquery"]), + indirect=True, + ids=lambda x: x.name, +) def test_schema_update_create_table_bigquery(client: SqlJobClientBase) -> None: # infer typical rasa event schema schema = client.schema @@ -203,7 +243,9 @@ def test_schema_update_create_table_bigquery(client: SqlJobClientBase) -> None: assert storage_table["version"]["cluster"] is False -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_schema_update_alter_table(client: SqlJobClientBase) -> None: # force to update schema in chunks by setting the max query size to 10 bytes/chars with patch.object(client.capabilities, "max_query_length", new=10): @@ -241,34 +283,36 @@ def test_schema_update_alter_table(client: SqlJobClientBase) -> None: assert storage_table["col4"]["data_type"] == "timestamp" -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: 
x.name +) def test_drop_tables(client: SqlJobClientBase) -> None: schema = client.schema # Add columns in all tables - schema.tables['event_user']['columns'] = dict(schema.tables['event_slot']['columns']) - schema.tables['event_bot']['columns'] = dict(schema.tables['event_slot']['columns']) + schema.tables["event_user"]["columns"] = dict(schema.tables["event_slot"]["columns"]) + schema.tables["event_bot"]["columns"] = dict(schema.tables["event_slot"]["columns"]) schema.bump_version() client.update_storage_schema() # Create a second schema with 2 hashes sd = schema.to_dict() - sd['name'] = 'event_2' + sd["name"] = "event_2" schema_2 = Schema.from_dict(sd).clone() # type: ignore[arg-type] for tbl_name in list(schema_2.tables): - if tbl_name.startswith('_dlt'): + if tbl_name.startswith("_dlt"): continue - schema_2.tables[tbl_name + '_2'] = schema_2.tables.pop(tbl_name) + schema_2.tables[tbl_name + "_2"] = schema_2.tables.pop(tbl_name) client.schema = schema_2 client.schema.bump_version() client.update_storage_schema() - client.schema.tables['event_slot_2']['columns']['value']['nullable'] = False + client.schema.tables["event_slot_2"]["columns"]["value"]["nullable"] = False client.schema.bump_version() client.update_storage_schema() # Drop tables from the first schema client.schema = schema - tables_to_drop = ['event_slot', 'event_user'] + tables_to_drop = ["event_slot", "event_user"] for tbl in tables_to_drop: del schema.tables[tbl] schema.bump_version() @@ -291,16 +335,22 @@ def test_drop_tables(client: SqlJobClientBase) -> None: # Verify _dlt_version schema is updated and old versions deleted table_name = client.sql_client.make_qualified_table_name(VERSION_TABLE_NAME) - rows = client.sql_client.execute_sql(f"SELECT version_hash FROM {table_name} WHERE schema_name = %s", schema.name) + rows = client.sql_client.execute_sql( + f"SELECT version_hash FROM {table_name} WHERE schema_name = %s", schema.name + ) assert len(rows) == 1 assert rows[0][0] == schema.version_hash # Other schema is not replaced - rows = client.sql_client.execute_sql(f"SELECT version_hash FROM {table_name} WHERE schema_name = %s", schema_2.name) + rows = client.sql_client.execute_sql( + f"SELECT version_hash FROM {table_name} WHERE schema_name = %s", schema_2.name + ) assert len(rows) == 2 -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_get_storage_table_with_all_types(client: SqlJobClientBase) -> None: schema = client.schema table_name = "event_test_table" + uniq_id() @@ -328,11 +378,14 @@ def test_get_storage_table_with_all_types(client: SqlJobClientBase) -> None: assert c["data_type"] == expected_c["data_type"] -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_preserve_column_order(client: SqlJobClientBase) -> None: schema = client.schema table_name = "event_test_table" + uniq_id() import random + columns = deepcopy(TABLE_UPDATE) random.shuffle(columns) print(columns) @@ -351,13 +404,15 @@ def _assert_columns_order(sql_: str) -> None: idx = sql_.find(col_name, idx) assert idx > 0, f"column {col_name} not found in script" - sql = ';'.join(client._get_table_update_sql(table_name, columns, generate_alter=False)) + sql 
= ";".join(client._get_table_update_sql(table_name, columns, generate_alter=False)) _assert_columns_order(sql) - sql = ';'.join(client._get_table_update_sql(table_name, columns, generate_alter=True)) + sql = ";".join(client._get_table_update_sql(table_name, columns, generate_alter=True)) _assert_columns_order(sql) -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_data_writer_load(client: SqlJobClientBase, file_storage: FileStorage) -> None: if not client.capabilities.preferred_loader_file_format: pytest.skip("preferred loader file format not set, destination will only work with staging") @@ -376,12 +431,16 @@ def test_data_writer_load(client: SqlJobClientBase, file_storage: FileStorage) - write_dataset(client, f, [rows[1]], client.schema.get_table(table_name)["columns"]) query = f.getvalue().decode() expect_load_file(client, file_storage, query, table_name) - db_row = client.sql_client.execute_sql(f"SELECT * FROM {canonical_name} WHERE f_int = {rows[1]['f_int']}")[0] + db_row = client.sql_client.execute_sql( + f"SELECT * FROM {canonical_name} WHERE f_int = {rows[1]['f_int']}" + )[0] assert db_row[3] is None assert db_row[5] is None -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_data_writer_string_escape(client: SqlJobClientBase, file_storage: FileStorage) -> None: if not client.capabilities.preferred_loader_file_format: pytest.skip("preferred loader file format not set, destination will only work with staging") @@ -399,8 +458,12 @@ def test_data_writer_string_escape(client: SqlJobClientBase, file_storage: FileS assert list(db_row) == list(row.values()) -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) -def test_data_writer_string_escape_edge(client: SqlJobClientBase, file_storage: FileStorage) -> None: +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) +def test_data_writer_string_escape_edge( + client: SqlJobClientBase, file_storage: FileStorage +) -> None: if not client.capabilities.preferred_loader_file_format: pytest.skip("preferred loader file format not set, destination will only work with staging") rows, table_name = prepare_schema(client, "weird_rows") @@ -409,19 +472,25 @@ def test_data_writer_string_escape_edge(client: SqlJobClientBase, file_storage: write_dataset(client, f, rows, client.schema.get_table(table_name)["columns"]) query = f.getvalue().decode() expect_load_file(client, file_storage, query, table_name) - for i in range(1,len(rows) + 1): + for i in range(1, len(rows) + 1): db_row = client.sql_client.execute_sql(f"SELECT str FROM {canonical_name} WHERE idx = {i}") - assert db_row[0][0] == rows[i-1]["str"] + assert db_row[0][0] == rows[i - 1]["str"] -@pytest.mark.parametrize('write_disposition', ["append", "replace"]) -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) -def test_load_with_all_types(client: SqlJobClientBase, write_disposition: str, file_storage: FileStorage) -> None: +@pytest.mark.parametrize("write_disposition", ["append", "replace"]) 
+@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) +def test_load_with_all_types( + client: SqlJobClientBase, write_disposition: str, file_storage: FileStorage +) -> None: if not client.capabilities.preferred_loader_file_format: pytest.skip("preferred loader file format not set, destination will only work with staging") table_name = "event_test_table" + uniq_id() # we should have identical content with all disposition types - client.schema.update_schema(new_table(table_name, write_disposition=write_disposition, columns=TABLE_UPDATE)) + client.schema.update_schema( + new_table(table_name, write_disposition=write_disposition, columns=TABLE_UPDATE) + ) client.schema.bump_version() client.update_storage_schema() @@ -431,7 +500,9 @@ def test_load_with_all_types(client: SqlJobClientBase, write_disposition: str, f client.initialize_storage() client.update_storage_schema() - with client.sql_client.with_staging_dataset(write_disposition in client.get_stage_dispositions()): + with client.sql_client.with_staging_dataset( + write_disposition in client.get_stage_dispositions() + ): canonical_name = client.sql_client.make_qualified_table_name(table_name) # write row with io.BytesIO() as f: @@ -442,28 +513,39 @@ def test_load_with_all_types(client: SqlJobClientBase, write_disposition: str, f # content must equal assert_all_data_types_row(db_row) -@pytest.mark.parametrize('write_disposition,replace_strategy', [ - ("append", ""), - ("merge", ""), - ("replace", "truncate-and-insert"), - ("replace", "insert-from-staging"), - ("replace", "staging-optimized") - ]) -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) -def test_write_dispositions(client: SqlJobClientBase, write_disposition: str, replace_strategy: str, file_storage: FileStorage) -> None: + +@pytest.mark.parametrize( + "write_disposition,replace_strategy", + [ + ("append", ""), + ("merge", ""), + ("replace", "truncate-and-insert"), + ("replace", "insert-from-staging"), + ("replace", "staging-optimized"), + ], +) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) +def test_write_dispositions( + client: SqlJobClientBase, + write_disposition: str, + replace_strategy: str, + file_storage: FileStorage, +) -> None: if not client.capabilities.preferred_loader_file_format: pytest.skip("preferred loader file format not set, destination will only work with staging") - os.environ['DESTINATION__REPLACE_STRATEGY'] = replace_strategy + os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy table_name = "event_test_table" + uniq_id() client.schema.update_schema( new_table(table_name, write_disposition=write_disposition, columns=TABLE_UPDATE) - ) + ) child_table = client.schema.naming.make_path(table_name, "child") # add child table without write disposition so it will be inferred from the parent client.schema.update_schema( new_table(child_table, columns=TABLE_UPDATE, parent_table_name=table_name) - ) + ) client.schema.bump_version() client.update_storage_schema() @@ -495,7 +577,12 @@ def test_write_dispositions(client: SqlJobClientBase, write_disposition: str, re else: # load directly on other expect_load_file(client, file_storage, query, t) - db_rows = list(client.sql_client.execute_sql(f"SELECT * FROM {client.sql_client.make_qualified_table_name(t)} ORDER BY col1 ASC")) + db_rows = list( + client.sql_client.execute_sql( + f"SELECT * FROM 
{client.sql_client.make_qualified_table_name(t)} ORDER BY" + " col1 ASC" + ) + ) # in case of merge if write_disposition == "append": # we append 1 row to tables in each iteration @@ -508,13 +595,20 @@ def test_write_dispositions(client: SqlJobClientBase, write_disposition: str, re assert len(db_rows) == 0 # check staging with client.sql_client.with_staging_dataset(staging=True): - db_rows = list(client.sql_client.execute_sql(f"SELECT * FROM {client.sql_client.make_qualified_table_name(t)} ORDER BY col1 ASC")) + db_rows = list( + client.sql_client.execute_sql( + f"SELECT * FROM {client.sql_client.make_qualified_table_name(t)} ORDER" + " BY col1 ASC" + ) + ) assert len(db_rows) == idx + 1 # last row must have our last idx - make sure we append and overwrite assert db_rows[-1][0] == idx -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_retrieve_job(client: SqlJobClientBase, file_storage: FileStorage) -> None: if not client.capabilities.preferred_loader_file_format: pytest.skip("preferred loader file format not set, destination will only work with staging") @@ -522,8 +616,8 @@ def test_retrieve_job(client: SqlJobClientBase, file_storage: FileStorage) -> No load_json = { "_dlt_id": uniq_id(), "_dlt_root_id": uniq_id(), - "sender_id":'90238094809sajlkjxoiewjhduuiuehd', - "timestamp": str(pendulum.now()) + "sender_id": "90238094809sajlkjxoiewjhduuiuehd", + "timestamp": str(pendulum.now()), } with io.BytesIO() as f: write_dataset(client, f, [load_json], client.schema.get_table(user_table_name)["columns"]) @@ -538,30 +632,50 @@ def test_retrieve_job(client: SqlJobClientBase, file_storage: FileStorage) -> No assert r_job.state() == "completed" -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_default_schema_name_init_storage(destination_config: DestinationTestConfiguration) -> None: - with cm_yield_client_with_storage(destination_config.destination, default_config_values={ - "default_schema_name": "event" # pass the schema that is a default schema. that should create dataset with the name `dataset_name` - }) as client: + with cm_yield_client_with_storage( + destination_config.destination, + default_config_values={ + "default_schema_name": ( # pass the schema that is a default schema. that should create dataset with the name `dataset_name` + "event" + ) + }, + ) as client: assert client.sql_client.dataset_name == client.config.dataset_name assert client.sql_client.has_dataset() - with cm_yield_client_with_storage(destination_config.destination, default_config_values={ - "default_schema_name": None # no default_schema. that should create dataset with the name `dataset_name` - }) as client: + with cm_yield_client_with_storage( + destination_config.destination, + default_config_values={ + "default_schema_name": ( + None # no default_schema. that should create dataset with the name `dataset_name` + ) + }, + ) as client: assert client.sql_client.dataset_name == client.config.dataset_name assert client.sql_client.has_dataset() - with cm_yield_client_with_storage(destination_config.destination, default_config_values={ - "default_schema_name": "event_2" # the default schema is not event schema . 
that should create dataset with the name `dataset_name` with schema suffix - }) as client: + with cm_yield_client_with_storage( + destination_config.destination, + default_config_values={ + "default_schema_name": ( # the default schema is not event schema . that should create dataset with the name `dataset_name` with schema suffix + "event_2" + ) + }, + ) as client: assert client.sql_client.dataset_name == client.config.dataset_name + "_event" assert client.sql_client.has_dataset() -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) -def test_many_schemas_single_dataset(destination_config: DestinationTestConfiguration, file_storage: FileStorage) -> None: - +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) +def test_many_schemas_single_dataset( + destination_config: DestinationTestConfiguration, file_storage: FileStorage +) -> None: def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None: # load something to event:user_table user_row = { @@ -570,7 +684,7 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None: # "_dlt_load_id": "load_id", "event": "user", "sender_id": "sender_id", - "timestamp": str(pendulum.now()) + "timestamp": str(pendulum.now()), } with io.BytesIO() as f: write_dataset(_client, f, [user_row], _client.schema.tables["event_user"]["columns"]) @@ -579,11 +693,14 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None: db_rows = list(_client.sql_client.execute_sql("SELECT * FROM event_user")) assert len(db_rows) == expected_rows - with cm_yield_client_with_storage(destination_config.destination, default_config_values={"default_schema_name": None}) as client: - + with cm_yield_client_with_storage( + destination_config.destination, default_config_values={"default_schema_name": None} + ) as client: # event schema with event table if not client.capabilities.preferred_loader_file_format: - pytest.skip("preferred loader file format not set, destination will only work with staging") + pytest.skip( + "preferred loader file format not set, destination will only work with staging" + ) user_table = load_table("event_user")["event_user"] client.schema.update_schema(new_table("event_user", columns=user_table.values())) @@ -627,11 +744,17 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None: _load_something(client, 3) # adding new non null column will generate sync error - event_3_schema.tables["event_user"]["columns"]["mandatory_column"] = new_column("mandatory_column", "text", nullable=False) + event_3_schema.tables["event_user"]["columns"]["mandatory_column"] = new_column( + "mandatory_column", "text", nullable=False + ) client.schema.bump_version() with pytest.raises(DatabaseException) as py_ex: client.update_storage_schema() - assert "mandatory_column" in str(py_ex.value).lower() or "NOT NULL" in str(py_ex.value) or "Adding columns with constraints not yet supported" in str(py_ex.value) + assert ( + "mandatory_column" in str(py_ex.value).lower() + or "NOT NULL" in str(py_ex.value) + or "Adding columns with constraints not yet supported" in str(py_ex.value) + ) def prepare_schema(client: SqlJobClientBase, case: str) -> None: diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index d498845e62..a3549174d6 100644 --- a/tests/load/test_sql_client.py +++ b/tests/load/test_sql_client.py @@ -1,76 +1,98 @@ -import pytest import datetime # noqa: 
I251 -from typing import Iterator -from threading import Thread, Event +from threading import Event, Thread from time import sleep +from typing import Iterator -from dlt.common import pendulum, Decimal +import pytest +from tests.load.pipeline.utils import destinations_configs +from tests.load.utils import AWS_BUCKET, prepare_table, yield_client_with_storage +from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage + +from dlt.common import Decimal, pendulum from dlt.common.exceptions import IdentifierTooLongException from dlt.common.schema.typing import LOADS_TABLE_NAME, VERSION_TABLE_NAME from dlt.common.storages import FileStorage +from dlt.common.time import ensure_pendulum_datetime from dlt.common.utils import derives_from_class_of_name, uniq_id -from dlt.destinations.exceptions import DatabaseException, DatabaseTerminalException, DatabaseTransientException, DatabaseUndefinedRelation - -from dlt.destinations.sql_client import DBApiCursor, SqlClientBase +from dlt.destinations.exceptions import ( + DatabaseException, + DatabaseTerminalException, + DatabaseTransientException, + DatabaseUndefinedRelation, +) from dlt.destinations.job_client_impl import SqlJobClientBase -from dlt.common.time import ensure_pendulum_datetime - -from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage -from tests.load.utils import yield_client_with_storage, prepare_table, AWS_BUCKET -from tests.load.pipeline.utils import destinations_configs +from dlt.destinations.sql_client import DBApiCursor, SqlClientBase @pytest.fixture def file_storage() -> FileStorage: return FileStorage(TEST_STORAGE_ROOT, file_type="b", makedirs=True) + @pytest.fixture(scope="function") def client(request) -> SqlJobClientBase: yield from yield_client_with_storage(request.param.destination) -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) + +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_sql_client_default_dataset_unqualified(client: SqlJobClientBase) -> None: client.update_storage_schema() load_id = "182879721.182912" client.complete_load(load_id) curr: DBApiCursor # get data from unqualified name - with client.sql_client.execute_query(f"SELECT * FROM {LOADS_TABLE_NAME} ORDER BY inserted_at") as curr: + with client.sql_client.execute_query( + f"SELECT * FROM {LOADS_TABLE_NAME} ORDER BY inserted_at" + ) as curr: columns = [c[0] for c in curr.description] data = curr.fetchall() assert len(data) > 0 # get data from qualified name load_table = client.sql_client.make_qualified_table_name(LOADS_TABLE_NAME) - with client.sql_client.execute_query(f"SELECT * FROM {load_table} ORDER BY inserted_at") as curr: + with client.sql_client.execute_query( + f"SELECT * FROM {load_table} ORDER BY inserted_at" + ) as curr: assert [c[0] for c in curr.description] == columns assert curr.fetchall() == data -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_malformed_query_parameters(client: SqlJobClientBase) -> None: client.update_storage_schema() # parameters for placeholder will not be provided. 
the placeholder remains in query with pytest.raises(DatabaseTransientException) as term_ex: - with client.sql_client.execute_query(f"SELECT * FROM {LOADS_TABLE_NAME} WHERE inserted_at = %s"): + with client.sql_client.execute_query( + f"SELECT * FROM {LOADS_TABLE_NAME} WHERE inserted_at = %s" + ): pass assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) # too many parameters with pytest.raises(DatabaseTransientException) as term_ex: - with client.sql_client.execute_query(f"SELECT * FROM {LOADS_TABLE_NAME} WHERE inserted_at = %s", pendulum.now(), 10): + with client.sql_client.execute_query( + f"SELECT * FROM {LOADS_TABLE_NAME} WHERE inserted_at = %s", pendulum.now(), 10 + ): pass assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) # unknown named parameter if client.sql_client.dbapi.paramstyle == "pyformat": with pytest.raises(DatabaseTransientException) as term_ex: - with client.sql_client.execute_query(f"SELECT * FROM {LOADS_TABLE_NAME} WHERE inserted_at = %(date)s"): + with client.sql_client.execute_query( + f"SELECT * FROM {LOADS_TABLE_NAME} WHERE inserted_at = %(date)s" + ): pass assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_malformed_execute_parameters(client: SqlJobClientBase) -> None: client.update_storage_schema() # parameters for placeholder will not be provided. the placeholder remains in query @@ -80,26 +102,36 @@ def test_malformed_execute_parameters(client: SqlJobClientBase) -> None: # too many parameters with pytest.raises(DatabaseTransientException) as term_ex: - client.sql_client.execute_sql(f"SELECT * FROM {LOADS_TABLE_NAME} WHERE inserted_at = %s", pendulum.now(), 10) + client.sql_client.execute_sql( + f"SELECT * FROM {LOADS_TABLE_NAME} WHERE inserted_at = %s", pendulum.now(), 10 + ) assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) # unknown named parameter if client.sql_client.dbapi.paramstyle == "pyformat": with pytest.raises(DatabaseTransientException) as term_ex: - client.sql_client.execute_sql(f"SELECT * FROM {LOADS_TABLE_NAME} WHERE inserted_at = %(date)s") + client.sql_client.execute_sql( + f"SELECT * FROM {LOADS_TABLE_NAME} WHERE inserted_at = %(date)s" + ) assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_execute_sql(client: SqlJobClientBase) -> None: client.update_storage_schema() # ask with datetime # no_rows = client.sql_client.execute_sql(f"SELECT schema_name, inserted_at FROM {VERSION_TABLE_NAME} WHERE inserted_at = %s", pendulum.now().add(seconds=1)) # assert len(no_rows) == 0 - rows = client.sql_client.execute_sql(f"SELECT schema_name, inserted_at FROM {VERSION_TABLE_NAME}") + rows = client.sql_client.execute_sql( + f"SELECT schema_name, inserted_at FROM {VERSION_TABLE_NAME}" + ) assert len(rows) == 1 assert rows[0][0] == "event" - rows = client.sql_client.execute_sql(f"SELECT schema_name, inserted_at FROM {VERSION_TABLE_NAME} WHERE schema_name = %s", "event") + rows = client.sql_client.execute_sql( + f"SELECT schema_name, 
inserted_at FROM {VERSION_TABLE_NAME} WHERE schema_name = %s", "event" + ) assert len(rows) == 1 # print(rows) assert rows[0][0] == "event" @@ -108,18 +140,31 @@ def test_execute_sql(client: SqlJobClientBase) -> None: # print(rows[0][1]) # print(type(rows[0][1])) # convert to pendulum to make sure it is supported by dbapi - rows = client.sql_client.execute_sql(f"SELECT schema_name, inserted_at FROM {VERSION_TABLE_NAME} WHERE inserted_at = %s", ensure_pendulum_datetime(rows[0][1])) + rows = client.sql_client.execute_sql( + f"SELECT schema_name, inserted_at FROM {VERSION_TABLE_NAME} WHERE inserted_at = %s", + ensure_pendulum_datetime(rows[0][1]), + ) assert len(rows) == 1 # use rows in subsequent test if client.sql_client.dbapi.paramstyle == "pyformat": - rows = client.sql_client.execute_sql(f"SELECT schema_name, inserted_at FROM {VERSION_TABLE_NAME} WHERE inserted_at = %(date)s", date=rows[0][1]) + rows = client.sql_client.execute_sql( + f"SELECT schema_name, inserted_at FROM {VERSION_TABLE_NAME} WHERE inserted_at =" + " %(date)s", + date=rows[0][1], + ) assert len(rows) == 1 assert rows[0][0] == "event" - rows = client.sql_client.execute_sql(f"SELECT schema_name, inserted_at FROM {VERSION_TABLE_NAME} WHERE inserted_at = %(date)s", date=pendulum.now().add(seconds=1)) + rows = client.sql_client.execute_sql( + f"SELECT schema_name, inserted_at FROM {VERSION_TABLE_NAME} WHERE inserted_at =" + " %(date)s", + date=pendulum.now().add(seconds=1), + ) assert len(rows) == 0 -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_execute_ddl(client: SqlJobClientBase) -> None: uniq_suffix = uniq_id() client.update_storage_schema() @@ -129,37 +174,57 @@ def test_execute_ddl(client: SqlJobClientBase) -> None: assert rows[0][0] == Decimal("1.0") # create view, note that bigquery will not let you execute a view that does not have fully qualified table names. 
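    # make_qualified_table_name() returns the dataset-qualified name, so the view body below
    # meets that requirement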
f_q_table_name = client.sql_client.make_qualified_table_name(table_name) - client.sql_client.execute_sql(f"CREATE OR REPLACE VIEW view_tmp_{uniq_suffix} AS (SELECT * FROM {f_q_table_name});") + client.sql_client.execute_sql( + f"CREATE OR REPLACE VIEW view_tmp_{uniq_suffix} AS (SELECT * FROM {f_q_table_name});" + ) rows = client.sql_client.execute_sql(f"SELECT * FROM view_tmp_{uniq_suffix}") assert rows[0][0] == Decimal("1.0") -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_execute_query(client: SqlJobClientBase) -> None: client.update_storage_schema() - with client.sql_client.execute_query(f"SELECT schema_name, inserted_at FROM {VERSION_TABLE_NAME}") as curr: + with client.sql_client.execute_query( + f"SELECT schema_name, inserted_at FROM {VERSION_TABLE_NAME}" + ) as curr: rows = curr.fetchall() assert len(rows) == 1 assert rows[0][0] == "event" - with client.sql_client.execute_query(f"SELECT schema_name, inserted_at FROM {VERSION_TABLE_NAME} WHERE schema_name = %s", "event") as curr: + with client.sql_client.execute_query( + f"SELECT schema_name, inserted_at FROM {VERSION_TABLE_NAME} WHERE schema_name = %s", "event" + ) as curr: rows = curr.fetchall() assert len(rows) == 1 assert rows[0][0] == "event" assert isinstance(rows[0][1], datetime.datetime) - with client.sql_client.execute_query(f"SELECT schema_name, inserted_at FROM {VERSION_TABLE_NAME} WHERE inserted_at = %s", rows[0][1]) as curr: + with client.sql_client.execute_query( + f"SELECT schema_name, inserted_at FROM {VERSION_TABLE_NAME} WHERE inserted_at = %s", + rows[0][1], + ) as curr: rows = curr.fetchall() assert len(rows) == 1 assert rows[0][0] == "event" - with client.sql_client.execute_query(f"SELECT schema_name, inserted_at FROM {VERSION_TABLE_NAME} WHERE inserted_at = %s", pendulum.now().add(seconds=1)) as curr: + with client.sql_client.execute_query( + f"SELECT schema_name, inserted_at FROM {VERSION_TABLE_NAME} WHERE inserted_at = %s", + pendulum.now().add(seconds=1), + ) as curr: rows = curr.fetchall() assert len(rows) == 0 if client.sql_client.dbapi.paramstyle == "pyformat": - with client.sql_client.execute_query(f"SELECT schema_name, inserted_at FROM {VERSION_TABLE_NAME} WHERE inserted_at = %(date)s", date=pendulum.now().add(seconds=1)) as curr: + with client.sql_client.execute_query( + f"SELECT schema_name, inserted_at FROM {VERSION_TABLE_NAME} WHERE inserted_at =" + " %(date)s", + date=pendulum.now().add(seconds=1), + ) as curr: rows = curr.fetchall() assert len(rows) == 0 -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_execute_df(client: SqlJobClientBase) -> None: if client.config.destination_name == "bigquery": chunk_size = 50 @@ -194,7 +259,9 @@ def test_execute_df(client: SqlJobClientBase) -> None: assert df_3 is None -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_database_exceptions(client: SqlJobClientBase) -> None: client.update_storage_schema() # invalid table @@ -211,11 +278,15 @@ def 
test_database_exceptions(client: SqlJobClientBase) -> None: pass assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) with pytest.raises(DatabaseUndefinedRelation) as term_ex: - with client.sql_client.execute_query("DELETE FROM TABLE_XXX WHERE 1=1;DELETE FROM ticket_forms__ticket_field_ids WHERE 1=1;"): + with client.sql_client.execute_query( + "DELETE FROM TABLE_XXX WHERE 1=1;DELETE FROM ticket_forms__ticket_field_ids WHERE 1=1;" + ): pass assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) with pytest.raises(DatabaseUndefinedRelation) as term_ex: - with client.sql_client.execute_query("DROP TABLE TABLE_XXX;DROP TABLE ticket_forms__ticket_field_ids;"): + with client.sql_client.execute_query( + "DROP TABLE TABLE_XXX;DROP TABLE ticket_forms__ticket_field_ids;" + ): pass # invalid syntax @@ -225,7 +296,9 @@ def test_database_exceptions(client: SqlJobClientBase) -> None: assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) # invalid column with pytest.raises(DatabaseTerminalException) as term_ex: - with client.sql_client.execute_query(f"SELECT * FROM {LOADS_TABLE_NAME} ORDER BY column_XXX"): + with client.sql_client.execute_query( + f"SELECT * FROM {LOADS_TABLE_NAME} ORDER BY column_XXX" + ): pass assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) # invalid parameters to dbapi @@ -237,7 +310,9 @@ def test_database_exceptions(client: SqlJobClientBase) -> None: with client.sql_client.with_alternative_dataset_name("UNKNOWN"): qualified_name = client.sql_client.make_qualified_table_name(LOADS_TABLE_NAME) with pytest.raises(DatabaseUndefinedRelation) as term_ex: - with client.sql_client.execute_query(f"SELECT * FROM {qualified_name} ORDER BY inserted_at"): + with client.sql_client.execute_query( + f"SELECT * FROM {qualified_name} ORDER BY inserted_at" + ): pass assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) with pytest.raises(DatabaseUndefinedRelation) as term_ex: @@ -250,27 +325,37 @@ def test_database_exceptions(client: SqlJobClientBase) -> None: assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_commit_transaction(client: SqlJobClientBase) -> None: table_name = prepare_temp_table(client) with client.sql_client.begin_transaction(): client.sql_client.execute_sql(f"INSERT INTO {table_name} VALUES (%s)", Decimal("1.0")) # check row still in transaction - rows = client.sql_client.execute_sql(f"SELECT col FROM {table_name} WHERE col = %s", Decimal("1.0")) + rows = client.sql_client.execute_sql( + f"SELECT col FROM {table_name} WHERE col = %s", Decimal("1.0") + ) assert len(rows) == 1 # check row after commit - rows = client.sql_client.execute_sql(f"SELECT col FROM {table_name} WHERE col = %s", Decimal("1.0")) + rows = client.sql_client.execute_sql( + f"SELECT col FROM {table_name} WHERE col = %s", Decimal("1.0") + ) assert len(rows) == 1 assert rows[0][0] == 1.0 with client.sql_client.begin_transaction() as tx: client.sql_client.execute_sql(f"DELETE FROM {table_name} WHERE col = %s", Decimal("1.0")) # explicit commit tx.commit_transaction() - rows = client.sql_client.execute_sql(f"SELECT col FROM {table_name} WHERE col = %s", Decimal("1.0")) + rows = client.sql_client.execute_sql( + f"SELECT col FROM 
{table_name} WHERE col = %s", Decimal("1.0") + ) assert len(rows) == 0 -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_rollback_transaction(client: SqlJobClientBase) -> None: if client.capabilities.supports_transactions is False: pytest.skip("Destination does not support tx") @@ -279,11 +364,15 @@ def test_rollback_transaction(client: SqlJobClientBase) -> None: with pytest.raises(RuntimeError): with client.sql_client.begin_transaction(): client.sql_client.execute_sql(f"INSERT INTO {table_name} VALUES (%s)", Decimal("1.0")) - rows = client.sql_client.execute_sql(f"SELECT col FROM {table_name} WHERE col = %s", Decimal("1.0")) + rows = client.sql_client.execute_sql( + f"SELECT col FROM {table_name} WHERE col = %s", Decimal("1.0") + ) assert len(rows) == 1 # python exception triggers rollback raise RuntimeError("ROLLBACK") - rows = client.sql_client.execute_sql(f"SELECT col FROM {table_name} WHERE col = %s", Decimal("1.0")) + rows = client.sql_client.execute_sql( + f"SELECT col FROM {table_name} WHERE col = %s", Decimal("1.0") + ) assert len(rows) == 0 # test rollback on invalid query @@ -291,15 +380,21 @@ def test_rollback_transaction(client: SqlJobClientBase) -> None: with client.sql_client.begin_transaction(): client.sql_client.execute_sql(f"INSERT INTO {table_name} VALUES (%s)", Decimal("1.0")) # table does not exist - client.sql_client.execute_sql(f"SELECT col FROM {table_name}_X WHERE col = %s", Decimal("1.0")) - rows = client.sql_client.execute_sql(f"SELECT col FROM {table_name} WHERE col = %s", Decimal("1.0")) + client.sql_client.execute_sql( + f"SELECT col FROM {table_name}_X WHERE col = %s", Decimal("1.0") + ) + rows = client.sql_client.execute_sql( + f"SELECT col FROM {table_name} WHERE col = %s", Decimal("1.0") + ) assert len(rows) == 0 # test explicit rollback with client.sql_client.begin_transaction() as tx: client.sql_client.execute_sql(f"INSERT INTO {table_name} VALUES (%s)", Decimal("1.0")) tx.rollback_transaction() - rows = client.sql_client.execute_sql(f"SELECT col FROM {table_name} WHERE col = %s", Decimal("1.0")) + rows = client.sql_client.execute_sql( + f"SELECT col FROM {table_name} WHERE col = %s", Decimal("1.0") + ) assert len(rows) == 0 # test double rollback - behavior inconsistent across databases (some raise some not) @@ -310,7 +405,9 @@ def test_rollback_transaction(client: SqlJobClientBase) -> None: # tx.rollback_transaction() -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_transaction_isolation(client: SqlJobClientBase) -> None: if client.capabilities.supports_transactions is False: pytest.skip("Destination does not support tx") @@ -320,7 +417,9 @@ def test_transaction_isolation(client: SqlJobClientBase) -> None: def test_thread(thread_id: Decimal) -> None: # make a copy of the sql_client - thread_client = client.sql_client.__class__(client.sql_client.dataset_name, client.sql_client.credentials) + thread_client = client.sql_client.__class__( + client.sql_client.dataset_name, client.sql_client.credentials + ) with thread_client: with thread_client.begin_transaction(): thread_client.execute_sql(f"INSERT INTO {table_name} VALUES (%s)", thread_id) @@ -348,11 
+447,18 @@ def test_thread(thread_id: Decimal) -> None: assert rows[0][0] == Decimal("2.0") -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_max_table_identifier_length(client: SqlJobClientBase) -> None: if client.capabilities.max_identifier_length >= 65536: - pytest.skip(f"destination {client.config.destination_name} has no table name length restriction") - table_name = 8 * "prospects_external_data__data365_member__member__feed_activities_created_post__items__comments__items__comments__items__author_details__educations" + pytest.skip( + f"destination {client.config.destination_name} has no table name length restriction" + ) + table_name = ( + 8 + * "prospects_external_data__data365_member__member__feed_activities_created_post__items__comments__items__comments__items__author_details__educations" + ) with pytest.raises(IdentifierTooLongException) as py_ex: prepare_table(client, "long_table_name", table_name, make_uniq_table=False) assert py_ex.value.identifier_type == "table" @@ -371,12 +477,19 @@ def test_max_table_identifier_length(client: SqlJobClientBase) -> None: # assert exists is True -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_max_column_identifier_length(client: SqlJobClientBase) -> None: if client.capabilities.max_column_identifier_length >= 65536: - pytest.skip(f"destination {client.config.destination_name} has no column name length restriction") + pytest.skip( + f"destination {client.config.destination_name} has no column name length restriction" + ) table_name = "prospects_external_data__data365_member__member" - column_name = 7 * "prospects_external_data__data365_member__member__feed_activities_created_post__items__comments__items__comments__items__author_details__educations__school_name" + column_name = ( + 7 + * "prospects_external_data__data365_member__member__feed_activities_created_post__items__comments__items__comments__items__author_details__educations__school_name" + ) with pytest.raises(IdentifierTooLongException) as py_ex: prepare_table(client, "long_column_name", table_name, make_uniq_table=False) assert py_ex.value.identifier_type == "column" @@ -388,7 +501,9 @@ def test_max_column_identifier_length(client: SqlJobClientBase) -> None: # assert long_column_name[:client.capabilities.max_column_identifier_length] in table_def -@pytest.mark.parametrize("client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) def test_recover_on_explicit_tx(client: SqlJobClientBase) -> None: if client.capabilities.supports_transactions is False: pytest.skip("Destination does not support tx") @@ -415,7 +530,11 @@ def test_recover_on_explicit_tx(client: SqlJobClientBase) -> None: assert_load_id(client.sql_client, "EFG") # wrong value inserted - statements = ["BEGIN TRANSACTION;", f"INSERT INTO {version_table}(version) VALUES(1);", "COMMIT;"] + statements = [ + "BEGIN TRANSACTION;", + f"INSERT INTO {version_table}(version) VALUES(1);", + "COMMIT;", + ] # cannot insert NULL value with 
pytest.raises(DatabaseTerminalException): client.sql_client.execute_fragments(statements) @@ -426,11 +545,13 @@ def test_recover_on_explicit_tx(client: SqlJobClientBase) -> None: assert_load_id(client.sql_client, "HJK") -def assert_load_id(sql_client:SqlClientBase, load_id: str) -> None: +def assert_load_id(sql_client: SqlClientBase, load_id: str) -> None: # and data is actually committed when connection reopened sql_client.close_connection() sql_client.open_connection() - rows = sql_client.execute_sql(f"SELECT load_id FROM {LOADS_TABLE_NAME} WHERE load_id = %s", load_id) + rows = sql_client.execute_sql( + f"SELECT load_id FROM {LOADS_TABLE_NAME} WHERE load_id = %s", load_id + ) assert len(rows) == 1 @@ -440,7 +561,12 @@ def prepare_temp_table(client: SqlJobClientBase) -> str: iceberg_table_suffix = "" coltype = "numeric" if client.config.destination_name == "athena": - iceberg_table_suffix = f"LOCATION '{AWS_BUCKET}/ci/{table_name}' TBLPROPERTIES ('table_type'='ICEBERG', 'format'='parquet');" + iceberg_table_suffix = ( + f"LOCATION '{AWS_BUCKET}/ci/{table_name}' TBLPROPERTIES ('table_type'='ICEBERG'," + " 'format'='parquet');" + ) coltype = "bigint" - client.sql_client.execute_sql(f"CREATE TABLE {table_name} (col {coltype}) {iceberg_table_suffix};") + client.sql_client.execute_sql( + f"CREATE TABLE {table_name} (col {coltype}) {iceberg_table_suffix};" + ) return table_name diff --git a/tests/load/utils.py b/tests/load/utils.py index ef8035f67f..9d3f3dc7b4 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -1,32 +1,43 @@ -import contextlib -from importlib import import_module import codecs +import contextlib import os -from typing import Any, Iterator, List, Sequence, cast, IO, Tuple, Optional import shutil -from pathlib import Path from dataclasses import dataclass +from importlib import import_module +from pathlib import Path +from typing import IO, Any, Iterator, List, Optional, Sequence, Tuple, cast + +from tests.cases import ( + TABLE_ROW_ALL_DATA_TYPES, + TABLE_UPDATE, + TABLE_UPDATE_COLUMNS_SCHEMA, + assert_all_data_types_row, +) +from tests.utils import ACTIVE_DESTINATIONS, IMPLEMENTED_DESTINATIONS, SQL_DESTINATIONS import dlt from dlt.common import json, sleep from dlt.common.configuration import resolve_configuration from dlt.common.configuration.container import Container from dlt.common.configuration.specs.config_section_context import ConfigSectionContext -from dlt.common.destination.reference import DestinationClientDwhConfiguration, DestinationReference, JobClientBase, LoadJob, DestinationClientStagingConfiguration, WithStagingDataset from dlt.common.data_writers import DataWriter -from dlt.common.schema import TColumnSchema, TTableSchemaColumns, Schema -from dlt.common.storages import SchemaStorage, FileStorage, SchemaStorageConfiguration +from dlt.common.destination.reference import ( + DestinationClientDwhConfiguration, + DestinationClientStagingConfiguration, + DestinationReference, + JobClientBase, + LoadJob, + WithStagingDataset, +) +from dlt.common.schema import Schema, TColumnSchema, TTableSchemaColumns from dlt.common.schema.utils import new_table -from dlt.common.storages.load_storage import ParsedLoadJobFileName, LoadStorage +from dlt.common.storages import FileStorage, SchemaStorage, SchemaStorageConfiguration +from dlt.common.storages.load_storage import LoadStorage, ParsedLoadJobFileName from dlt.common.typing import StrAny from dlt.common.utils import uniq_id - -from dlt.load import Load -from dlt.destinations.sql_client import SqlClientBase from 
dlt.destinations.job_client_impl import SqlJobClientBase - -from tests.utils import ACTIVE_DESTINATIONS, IMPLEMENTED_DESTINATIONS, SQL_DESTINATIONS -from tests.cases import TABLE_UPDATE_COLUMNS_SCHEMA, TABLE_UPDATE, TABLE_ROW_ALL_DATA_TYPES, assert_all_data_types_row +from dlt.destinations.sql_client import SqlClientBase +from dlt.load import Load # bucket urls AWS_BUCKET = dlt.config.get("tests.bucket_url_s3", str) @@ -34,15 +45,22 @@ FILE_BUCKET = dlt.config.get("tests.bucket_url_file", str) MEMORY_BUCKET = dlt.config.get("tests.memory", str) -ALL_FILESYSTEM_DRIVERS = dlt.config.get("ALL_FILESYSTEM_DRIVERS", list) or ["s3", "gs", "file", "memory"] +ALL_FILESYSTEM_DRIVERS = dlt.config.get("ALL_FILESYSTEM_DRIVERS", list) or [ + "s3", + "gs", + "file", + "memory", +] # Filter out buckets not in all filesystem drivers ALL_BUCKETS = [GCS_BUCKET, AWS_BUCKET, FILE_BUCKET, MEMORY_BUCKET] -ALL_BUCKETS = [bucket for bucket in ALL_BUCKETS if bucket.split(':')[0] in ALL_FILESYSTEM_DRIVERS] +ALL_BUCKETS = [bucket for bucket in ALL_BUCKETS if bucket.split(":")[0] in ALL_FILESYSTEM_DRIVERS] + @dataclass class DestinationTestConfiguration: """Class for defining test setup for one destination.""" + destination: str staging: Optional[str] = None file_format: Optional[str] = None @@ -54,7 +72,7 @@ class DestinationTestConfiguration: @property def name(self) -> str: - name: str = self.destination + name: str = self.destination if self.file_format: name += f"-{self.file_format}" if not self.staging: @@ -67,30 +85,38 @@ def name(self) -> str: def setup(self) -> None: """Sets up environment variables for this destination configuration""" - os.environ['DESTINATION__FILESYSTEM__BUCKET_URL'] = self.bucket_url or "" - os.environ['DESTINATION__STAGE_NAME'] = self.stage_name or "" - os.environ['DESTINATION__STAGING_IAM_ROLE'] = self.staging_iam_role or "" + os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = self.bucket_url or "" + os.environ["DESTINATION__STAGE_NAME"] = self.stage_name or "" + os.environ["DESTINATION__STAGING_IAM_ROLE"] = self.staging_iam_role or "" """For the filesystem destinations we disable compression to make analyzing the result easier""" if self.destination == "filesystem": - os.environ['DATA_WRITER__DISABLE_COMPRESSION'] = "True" - + os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" - def setup_pipeline(self, pipeline_name: str, dataset_name: str = None, full_refresh: bool = False, **kwargs) -> dlt.Pipeline: + def setup_pipeline( + self, pipeline_name: str, dataset_name: str = None, full_refresh: bool = False, **kwargs + ) -> dlt.Pipeline: """Convenience method to setup pipeline with this configuration""" self.setup() - pipeline = dlt.pipeline(pipeline_name=pipeline_name, destination=self.destination, staging=self.staging, dataset_name=dataset_name or pipeline_name, full_refresh=full_refresh, **kwargs) + pipeline = dlt.pipeline( + pipeline_name=pipeline_name, + destination=self.destination, + staging=self.staging, + dataset_name=dataset_name or pipeline_name, + full_refresh=full_refresh, + **kwargs, + ) return pipeline def destinations_configs( - default_sql_configs: bool = False, - default_staging_configs: bool = False, - all_staging_configs: bool = False, - local_filesystem_configs: bool = False, - all_buckets_filesystem_configs: bool = False, - subset: List[str] = "") -> Iterator[DestinationTestConfiguration]: - + default_sql_configs: bool = False, + default_staging_configs: bool = False, + all_staging_configs: bool = False, + local_filesystem_configs: bool = False, + 
all_buckets_filesystem_configs: bool = False, + subset: List[str] = "", +) -> Iterator[DestinationTestConfiguration]: # sanity check for item in subset: assert item in IMPLEMENTED_DESTINATIONS, f"Destination {item} is not implemented" @@ -100,39 +126,122 @@ def destinations_configs( # default non staging sql based configs, one per destination if default_sql_configs: - destination_configs += [DestinationTestConfiguration(destination=destination) for destination in SQL_DESTINATIONS if destination != "athena"] + destination_configs += [ + DestinationTestConfiguration(destination=destination) + for destination in SQL_DESTINATIONS + if destination != "athena" + ] # athena needs filesystem staging, which will be automatically set, we have to supply a bucket url though - destination_configs += [DestinationTestConfiguration(destination="athena", supports_merge=False, bucket_url=AWS_BUCKET)] + destination_configs += [ + DestinationTestConfiguration( + destination="athena", supports_merge=False, bucket_url=AWS_BUCKET + ) + ] if default_staging_configs or all_staging_configs: destination_configs += [ - DestinationTestConfiguration(destination="athena", staging="filesystem", file_format="parquet", bucket_url=AWS_BUCKET, supports_merge=False), - DestinationTestConfiguration(destination="redshift", staging="filesystem", file_format="parquet", bucket_url=AWS_BUCKET, staging_iam_role="arn:aws:iam::267388281016:role/redshift_s3_read", extra_info="s3-role"), - DestinationTestConfiguration(destination="bigquery", staging="filesystem", file_format="parquet", bucket_url=GCS_BUCKET, extra_info="gcs-authorization"), - DestinationTestConfiguration(destination="snowflake", staging="filesystem", file_format="jsonl", bucket_url=GCS_BUCKET, stage_name="PUBLIC.dlt_gcs_stage", extra_info="gcs-integration"), - DestinationTestConfiguration(destination="snowflake", staging="filesystem", file_format="jsonl", bucket_url=AWS_BUCKET, stage_name="PUBLIC.dlt_s3_stage", extra_info="s3-integration") + DestinationTestConfiguration( + destination="athena", + staging="filesystem", + file_format="parquet", + bucket_url=AWS_BUCKET, + supports_merge=False, + ), + DestinationTestConfiguration( + destination="redshift", + staging="filesystem", + file_format="parquet", + bucket_url=AWS_BUCKET, + staging_iam_role="arn:aws:iam::267388281016:role/redshift_s3_read", + extra_info="s3-role", + ), + DestinationTestConfiguration( + destination="bigquery", + staging="filesystem", + file_format="parquet", + bucket_url=GCS_BUCKET, + extra_info="gcs-authorization", + ), + DestinationTestConfiguration( + destination="snowflake", + staging="filesystem", + file_format="jsonl", + bucket_url=GCS_BUCKET, + stage_name="PUBLIC.dlt_gcs_stage", + extra_info="gcs-integration", + ), + DestinationTestConfiguration( + destination="snowflake", + staging="filesystem", + file_format="jsonl", + bucket_url=AWS_BUCKET, + stage_name="PUBLIC.dlt_s3_stage", + extra_info="s3-integration", + ), ] if all_staging_configs: destination_configs += [ - DestinationTestConfiguration(destination="redshift", staging="filesystem", file_format="parquet", bucket_url=AWS_BUCKET, extra_info="credential-forwarding"), - DestinationTestConfiguration(destination="snowflake", staging="filesystem", file_format="parquet", bucket_url=AWS_BUCKET, extra_info="credential-forwarding"), - DestinationTestConfiguration(destination="redshift", staging="filesystem", file_format="jsonl", bucket_url=AWS_BUCKET, extra_info="credential-forwarding"), - DestinationTestConfiguration(destination="bigquery", 
staging="filesystem", file_format="jsonl", bucket_url=GCS_BUCKET, extra_info="gcs-authorization"), + DestinationTestConfiguration( + destination="redshift", + staging="filesystem", + file_format="parquet", + bucket_url=AWS_BUCKET, + extra_info="credential-forwarding", + ), + DestinationTestConfiguration( + destination="snowflake", + staging="filesystem", + file_format="parquet", + bucket_url=AWS_BUCKET, + extra_info="credential-forwarding", + ), + DestinationTestConfiguration( + destination="redshift", + staging="filesystem", + file_format="jsonl", + bucket_url=AWS_BUCKET, + extra_info="credential-forwarding", + ), + DestinationTestConfiguration( + destination="bigquery", + staging="filesystem", + file_format="jsonl", + bucket_url=GCS_BUCKET, + extra_info="gcs-authorization", + ), ] # add local filesystem destinations if requested if local_filesystem_configs: - destination_configs += [DestinationTestConfiguration(destination="filesystem", bucket_url=FILE_BUCKET, file_format="insert_values")] - destination_configs += [DestinationTestConfiguration(destination="filesystem", bucket_url=FILE_BUCKET, file_format="parquet")] - destination_configs += [DestinationTestConfiguration(destination="filesystem", bucket_url=FILE_BUCKET, file_format="jsonl")] + destination_configs += [ + DestinationTestConfiguration( + destination="filesystem", bucket_url=FILE_BUCKET, file_format="insert_values" + ) + ] + destination_configs += [ + DestinationTestConfiguration( + destination="filesystem", bucket_url=FILE_BUCKET, file_format="parquet" + ) + ] + destination_configs += [ + DestinationTestConfiguration( + destination="filesystem", bucket_url=FILE_BUCKET, file_format="jsonl" + ) + ] if all_buckets_filesystem_configs: for bucket in ALL_BUCKETS: - destination_configs += [DestinationTestConfiguration(destination="filesystem", bucket_url=bucket, extra_info=bucket)] + destination_configs += [ + DestinationTestConfiguration( + destination="filesystem", bucket_url=bucket, extra_info=bucket + ) + ] # filter out non active destinations - destination_configs = [conf for conf in destination_configs if conf.destination in ACTIVE_DESTINATIONS] + destination_configs = [ + conf for conf in destination_configs if conf.destination in ACTIVE_DESTINATIONS + ] # filter out destinations not in subset if subset: @@ -145,19 +254,33 @@ def load_table(name: str) -> TTableSchemaColumns: with open(f"./tests/load/cases/{name}.json", "rb") as f: return cast(TTableSchemaColumns, json.load(f)) -def expect_load_file(client: JobClientBase, file_storage: FileStorage, query: str, table_name: str, status = "completed") -> LoadJob: - file_name = ParsedLoadJobFileName(table_name, uniq_id(), 0, client.capabilities.preferred_loader_file_format).job_id() + +def expect_load_file( + client: JobClientBase, + file_storage: FileStorage, + query: str, + table_name: str, + status="completed", +) -> LoadJob: + file_name = ParsedLoadJobFileName( + table_name, uniq_id(), 0, client.capabilities.preferred_loader_file_format + ).job_id() file_storage.save(file_name, query.encode("utf-8")) table = Load.get_load_table(client.schema, file_name) job = client.start_file_load(table, file_storage.make_full_path(file_name), uniq_id()) while job.state() == "running": sleep(0.5) assert job.file_name() == file_name - assert job.state() == status + assert job.state() == status return job -def prepare_table(client: JobClientBase, case_name: str = "event_user", table_name: str = "event_user", make_uniq_table: bool = True) -> None: +def prepare_table( + client: JobClientBase, + 
case_name: str = "event_user", + table_name: str = "event_user", + make_uniq_table: bool = True, +) -> None: client.schema.bump_version() client.update_storage_schema() user_table = load_table(case_name)[table_name] @@ -170,11 +293,12 @@ def prepare_table(client: JobClientBase, case_name: str = "event_user", table_na client.update_storage_schema() return user_table_name + def yield_client( destination_name: str, dataset_name: str = None, default_config_values: StrAny = None, - schema_name: str = "event" + schema_name: str = "event", ) -> Iterator[SqlJobClientBase]: os.environ.pop("DATASET_NAME", None) # import destination reference by name @@ -190,9 +314,10 @@ def yield_client( # also apply to config dest_config.update(default_config_values) # get event default schema - storage_config = resolve_configuration(SchemaStorageConfiguration(), explicit_value={ - "schema_volume_path": "tests/common/cases/schemas/rasa" - }) + storage_config = resolve_configuration( + SchemaStorageConfiguration(), + explicit_value={"schema_volume_path": "tests/common/cases/schemas/rasa"}, + ) schema_storage = SchemaStorage(storage_config) schema = schema_storage.load_schema(schema_name) # create client and dataset @@ -204,35 +329,42 @@ def yield_client( destination_name="fake-stage", dataset_name=dest_config.dataset_name, default_schema_name=dest_config.default_schema_name, - bucket_url=AWS_BUCKET + bucket_url=AWS_BUCKET, ) dest_config.staging_config = staging_config # lookup for credentials in the section that is destination name - with Container().injectable_context(ConfigSectionContext(sections=("destination", destination_name,))): + with Container().injectable_context( + ConfigSectionContext( + sections=( + "destination", + destination_name, + ) + ) + ): with destination.client(schema, dest_config) as client: yield client + @contextlib.contextmanager def cm_yield_client( destination_name: str, dataset_name: str, default_config_values: StrAny = None, - schema_name: str = "event" + schema_name: str = "event", ) -> Iterator[SqlJobClientBase]: return yield_client(destination_name, dataset_name, default_config_values, schema_name) def yield_client_with_storage( - destination_name: str, - default_config_values: StrAny = None, - schema_name: str = "event" + destination_name: str, default_config_values: StrAny = None, schema_name: str = "event" ) -> Iterator[SqlJobClientBase]: - # create dataset with random name dataset_name = "test_" + uniq_id() - with cm_yield_client(destination_name, dataset_name, default_config_values, schema_name) as client: + with cm_yield_client( + destination_name, dataset_name, default_config_values, schema_name + ) as client: client.initialize_storage() yield client # print(dataset_name) @@ -253,40 +385,46 @@ def delete_dataset(client: SqlClientBase[Any], normalized_dataset_name: str) -> @contextlib.contextmanager def cm_yield_client_with_storage( - destination_name: str, - default_config_values: StrAny = None, - schema_name: str = "event" + destination_name: str, default_config_values: StrAny = None, schema_name: str = "event" ) -> Iterator[SqlJobClientBase]: return yield_client_with_storage(destination_name, default_config_values, schema_name) -def write_dataset(client: JobClientBase, f: IO[bytes], rows: List[StrAny], columns_schema: TTableSchemaColumns) -> None: - data_format = DataWriter.data_format_from_file_format(client.capabilities.preferred_loader_file_format) +def write_dataset( + client: JobClientBase, f: IO[bytes], rows: List[StrAny], columns_schema: TTableSchemaColumns +) -> 
None: + data_format = DataWriter.data_format_from_file_format( + client.capabilities.preferred_loader_file_format + ) # adapt bytes stream to text file format if not data_format.is_binary_format and isinstance(f.read(0), bytes): f = codecs.getwriter("utf-8")(f) writer = DataWriter.from_destination_capabilities(client.capabilities, f) # remove None values for idx, row in enumerate(rows): - rows[idx] = {k:v for k, v in row.items() if v is not None} + rows[idx] = {k: v for k, v in row.items() if v is not None} writer.write_all(columns_schema, rows) -def prepare_load_package(load_storage: LoadStorage, cases: Sequence[str], write_disposition: str='append') -> Tuple[str, Schema]: +def prepare_load_package( + load_storage: LoadStorage, cases: Sequence[str], write_disposition: str = "append" +) -> Tuple[str, Schema]: load_id = uniq_id() load_storage.create_temp_load_package(load_id) for case in cases: path = f"./tests/load/cases/loading/{case}" - shutil.copy(path, load_storage.storage.make_full_path(f"{load_id}/{LoadStorage.NEW_JOBS_FOLDER}")) + shutil.copy( + path, load_storage.storage.make_full_path(f"{load_id}/{LoadStorage.NEW_JOBS_FOLDER}") + ) schema_path = Path("./tests/load/cases/loading/schema.json") - data = json.loads(schema_path.read_text(encoding='utf8')) - for name, table in data['tables'].items(): - if name.startswith('_dlt'): + data = json.loads(schema_path.read_text(encoding="utf8")) + for name, table in data["tables"].items(): + if name.startswith("_dlt"): continue - table['write_disposition'] = write_disposition - Path( - load_storage.storage.make_full_path(load_id) - ).joinpath(schema_path.name).write_text(json.dumps(data), encoding='utf8') + table["write_disposition"] = write_disposition + Path(load_storage.storage.make_full_path(load_id)).joinpath(schema_path.name).write_text( + json.dumps(data), encoding="utf8" + ) schema_update_path = "./tests/load/cases/loading/schema_updates.json" shutil.copy(schema_update_path, load_storage.storage.make_full_path(load_id)) diff --git a/tests/load/weaviate/test_naming.py b/tests/load/weaviate/test_naming.py index e09620e91b..32d70ff7ef 100644 --- a/tests/load/weaviate/test_naming.py +++ b/tests/load/weaviate/test_naming.py @@ -1,12 +1,13 @@ -import dlt, pytest +import pytest +from tests.common.utils import load_yml_case +import dlt from dlt.destinations.weaviate.naming import NamingConvention -from tests.common.utils import load_yml_case @dlt.source def small(): - return dlt.resource([1,2,3], name="table") + return dlt.resource([1, 2, 3], name="table") def test_table_name_normalization() -> None: diff --git a/tests/load/weaviate/test_pipeline.py b/tests/load/weaviate/test_pipeline.py index 26b28366c2..3088e49f6b 100644 --- a/tests/load/weaviate/test_pipeline.py +++ b/tests/load/weaviate/test_pipeline.py @@ -1,16 +1,16 @@ import pytest +from tests.pipeline.utils import assert_load_info import dlt from dlt.common import json from dlt.common.utils import uniq_id - from dlt.destinations.weaviate import weaviate_adapter -from dlt.destinations.weaviate.weaviate_adapter import VECTORIZE_HINT, TOKENIZATION_HINT +from dlt.destinations.weaviate.weaviate_adapter import TOKENIZATION_HINT, VECTORIZE_HINT from dlt.destinations.weaviate.weaviate_client import WeaviateClient -from tests.pipeline.utils import assert_load_info from .utils import assert_class, delete_classes, drop_active_pipeline_data + @pytest.fixture(autouse=True) def drop_weaviate_schema() -> None: yield @@ -144,16 +144,14 @@ def test_pipeline_merge() -> None: "doc_id": 1, "title": "The 
Shawshank Redemption", "description": ( - "Two imprisoned men find redemption through acts " - "of decency over the years." + "Two imprisoned men find redemption through acts of decency over the years." ), }, { "doc_id": 2, "title": "The Godfather", "description": ( - "A crime dynasty's aging patriarch transfers " - "control to his reluctant son." + "A crime dynasty's aging patriarch transfers control to his reluctant son." ), }, { @@ -258,20 +256,39 @@ def test_merge_github_nested() -> None: p = dlt.pipeline(destination="weaviate", dataset_name="github1", full_refresh=True) assert p.dataset_name.startswith("github1_202") - with open("tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8") as f: + with open( + "tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8" + ) as f: data = json.load(f) info = p.run( - weaviate_adapter(data[:17], vectorize=["title", "body"], tokenization={"user__login": "lowercase"}), + weaviate_adapter( + data[:17], vectorize=["title", "body"], tokenization={"user__login": "lowercase"} + ), table_name="issues", write_disposition="merge", - primary_key="id" + primary_key="id", ) assert_load_info(info) # assert if schema contains tables with right names - assert set(p.default_schema.tables.keys()) == {'DltVersion', 'DltLoads', 'Issues', 'DltPipelineState', 'Issues__Labels', 'Issues__Assignees'} - assert set([t["name"] for t in p.default_schema.data_tables()]) == {'Issues', 'Issues__Labels', 'Issues__Assignees'} - assert set([t["name"] for t in p.default_schema.dlt_tables()]) == {'DltVersion', 'DltLoads', 'DltPipelineState'} + assert set(p.default_schema.tables.keys()) == { + "DltVersion", + "DltLoads", + "Issues", + "DltPipelineState", + "Issues__Labels", + "Issues__Assignees", + } + assert set([t["name"] for t in p.default_schema.data_tables()]) == { + "Issues", + "Issues__Labels", + "Issues__Assignees", + } + assert set([t["name"] for t in p.default_schema.dlt_tables()]) == { + "DltVersion", + "DltLoads", + "DltPipelineState", + } issues = p.default_schema.tables["Issues"] # make sure that both "id" column and "primary_key" were changed to __id assert issues["columns"]["__id"]["primary_key"] is True diff --git a/tests/load/weaviate/test_weaviate_client.py b/tests/load/weaviate/test_weaviate_client.py index 266a00f914..0c6365ad12 100644 --- a/tests/load/weaviate/test_weaviate_client.py +++ b/tests/load/weaviate/test_weaviate_client.py @@ -1,23 +1,28 @@ import io -import pytest from typing import Iterator -from dlt.common.schema import Schema +import pytest +from tests.load.utils import ( + TABLE_ROW_ALL_DATA_TYPES, + TABLE_UPDATE, + TABLE_UPDATE_COLUMNS_SCHEMA, + expect_load_file, + write_dataset, +) +from tests.utils import TEST_STORAGE_ROOT + from dlt.common.configuration.container import Container from dlt.common.configuration.specs.config_section_context import ConfigSectionContext +from dlt.common.schema import Schema +from dlt.common.schema.utils import new_table +from dlt.common.storages.file_storage import FileStorage from dlt.common.utils import uniq_id - from dlt.destinations import weaviate from dlt.destinations.weaviate.weaviate_client import WeaviateClient -from dlt.common.storages.file_storage import FileStorage -from dlt.common.schema.utils import new_table -from tests.load.utils import TABLE_ROW_ALL_DATA_TYPES, TABLE_UPDATE, TABLE_UPDATE_COLUMNS_SCHEMA, expect_load_file, write_dataset - -from tests.utils import TEST_STORAGE_ROOT - from .utils import drop_active_pipeline_data + 
@pytest.fixture(autouse=True) def drop_weaviate_schema() -> None: yield @@ -27,16 +32,13 @@ def drop_weaviate_schema() -> None: def get_client_instance(schema: Schema) -> WeaviateClient: config = weaviate.spec()() config.dataset_name = "ClientTest" + uniq_id() - with Container().injectable_context(ConfigSectionContext(sections=('destination', 'weaviate'))): + with Container().injectable_context(ConfigSectionContext(sections=("destination", "weaviate"))): return weaviate.client(schema, config) # type: ignore[return-value] -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def client() -> Iterator[WeaviateClient]: - schema = Schema('test_schema', { - 'names': "dlt.destinations.weaviate.naming", - 'json': None - }) + schema = Schema("test_schema", {"names": "dlt.destinations.weaviate.naming", "json": None}) _client = get_client_instance(schema) try: yield _client @@ -49,11 +51,15 @@ def file_storage() -> FileStorage: return FileStorage(TEST_STORAGE_ROOT, file_type="b", makedirs=True) -@pytest.mark.parametrize('write_disposition', ["append", "replace", "merge"]) -def test_all_data_types(client: WeaviateClient, write_disposition: str, file_storage: FileStorage) -> None: +@pytest.mark.parametrize("write_disposition", ["append", "replace", "merge"]) +def test_all_data_types( + client: WeaviateClient, write_disposition: str, file_storage: FileStorage +) -> None: class_name = "AllTypes" # we should have identical content with all disposition types - client.schema.update_schema(new_table(class_name, write_disposition=write_disposition, columns=TABLE_UPDATE)) + client.schema.update_schema( + new_table(class_name, write_disposition=write_disposition, columns=TABLE_UPDATE) + ) client.schema.bump_version() client.update_storage_schema() diff --git a/tests/load/weaviate/utils.py b/tests/load/weaviate/utils.py index 20c4d93edb..86c61a234a 100644 --- a/tests/load/weaviate/utils.py +++ b/tests/load/weaviate/utils.py @@ -1,12 +1,10 @@ -import dlt from typing import Any, List import dlt -from dlt.common.pipeline import PipelineContext from dlt.common.configuration.container import Container - +from dlt.common.pipeline import PipelineContext +from dlt.destinations.weaviate.weaviate_adapter import TOKENIZATION_HINT, VECTORIZE_HINT from dlt.destinations.weaviate.weaviate_client import WeaviateClient -from dlt.destinations.weaviate.weaviate_adapter import VECTORIZE_HINT, TOKENIZATION_HINT def assert_unordered_list_equal(list1: List[Any], list2: List[Any]) -> None: @@ -69,6 +67,7 @@ def delete_classes(p, class_list): for class_name in class_list: db_client.schema.delete_class(class_name) + def drop_active_pipeline_data() -> None: def schema_has_classes(client): schema = client.db_client.schema.get() diff --git a/tests/normalize/mock_rasa_json_normalizer.py b/tests/normalize/mock_rasa_json_normalizer.py index fd8f0b1731..c0c1fa0386 100644 --- a/tests/normalize/mock_rasa_json_normalizer.py +++ b/tests/normalize/mock_rasa_json_normalizer.py @@ -1,18 +1,25 @@ -from dlt.common.normalizers.json import TNormalizedRowIterator, DataItemNormalizer +from dlt.common.normalizers.json import DataItemNormalizer, TNormalizedRowIterator from dlt.common.normalizers.json.relational import DataItemNormalizer as RelationalNormalizer from dlt.common.schema import Schema from dlt.common.typing import TDataItem class DataItemNormalizer(RelationalNormalizer): - - def normalize_data_item(self, source_event: TDataItem, load_id: str, table_name: str) -> TNormalizedRowIterator: + def normalize_data_item( + self, 
source_event: TDataItem, load_id: str, table_name: str + ) -> TNormalizedRowIterator: if self.schema.name == "event": # this emulates rasa parser on standard parser - event = {"sender_id": source_event["sender_id"], "timestamp": source_event["timestamp"], "type": source_event["event"]} + event = { + "sender_id": source_event["sender_id"], + "timestamp": source_event["timestamp"], + "type": source_event["event"], + } yield from super().normalize_data_item(event, load_id, table_name) # add table name which is "event" field in RASA OSS - yield from super().normalize_data_item(source_event, load_id, table_name + "_" + source_event["event"]) + yield from super().normalize_data_item( + source_event, load_id, table_name + "_" + source_event["event"] + ) else: # will generate tables properly yield from super().normalize_data_item(source_event, load_id, table_name) diff --git a/tests/normalize/test_normalize.py b/tests/normalize/test_normalize.py index 6aecdea59d..9f277b0f5c 100644 --- a/tests/normalize/test_normalize.py +++ b/tests/normalize/test_normalize.py @@ -1,25 +1,35 @@ -import pytest from fnmatch import fnmatch -from typing import Dict, Iterator, List, Sequence, Tuple -from multiprocessing import get_start_method, Pool +from multiprocessing import Pool, get_start_method from multiprocessing.dummy import Pool as ThreadPool +from typing import Dict, Iterator, List, Sequence, Tuple + +import pytest +from tests.cases import JSON_TYPED_DICT, JSON_TYPED_DICT_TYPES +from tests.normalize.utils import ( + ALL_CAPABILITIES, + DEFAULT_CAPS, + INSERT_CAPS, + JSONL_CAPS, + json_case_path, +) +from tests.utils import ( + TEST_DICT_CONFIG_PROVIDER, + assert_no_dict_key_starts_with, + clean_test_storage, + init_test_logging, +) from dlt.common import json -from dlt.common.schema.schema import Schema -from dlt.common.utils import uniq_id -from dlt.common.typing import StrAny +from dlt.common.configuration.container import Container from dlt.common.data_types import TDataType -from dlt.common.storages import NormalizeStorage, LoadStorage from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.configuration.container import Container - +from dlt.common.schema.schema import Schema +from dlt.common.storages import LoadStorage, NormalizeStorage +from dlt.common.typing import StrAny +from dlt.common.utils import uniq_id from dlt.extract.extract import ExtractorStorage from dlt.normalize import Normalize -from tests.cases import JSON_TYPED_DICT, JSON_TYPED_DICT_TYPES -from tests.utils import TEST_DICT_CONFIG_PROVIDER, assert_no_dict_key_starts_with, clean_test_storage, init_test_logging -from tests.normalize.utils import json_case_path, INSERT_CAPS, JSONL_CAPS, DEFAULT_CAPS, ALL_CAPABILITIES - @pytest.fixture(scope="module", autouse=True) def default_caps() -> Iterator[DestinationCapabilitiesContext]: @@ -56,7 +66,9 @@ def rasa_normalize() -> Normalize: def init_normalize(default_schemas_path: str = None) -> Iterator[Normalize]: clean_test_storage() # pass schema config fields to schema storage via dict config provider - with TEST_DICT_CONFIG_PROVIDER().values({"import_schema_path": default_schemas_path, "external_schema_format": "json"}): + with TEST_DICT_CONFIG_PROVIDER().values( + {"import_schema_path": default_schemas_path, "external_schema_format": "json"} + ): # inject the destination capabilities n = Normalize() yield n @@ -73,8 +85,12 @@ def test_initialize(rasa_normalize: Normalize) -> None: @pytest.mark.parametrize("caps", JSONL_CAPS, indirect=True) -def 
test_normalize_single_user_event_jsonl(caps: DestinationCapabilitiesContext, raw_normalize: Normalize) -> None: - expected_tables, load_files = normalize_event_user(raw_normalize, "event.event.user_load_1", EXPECTED_USER_TABLES) +def test_normalize_single_user_event_jsonl( + caps: DestinationCapabilitiesContext, raw_normalize: Normalize +) -> None: + expected_tables, load_files = normalize_event_user( + raw_normalize, "event.event.user_load_1", EXPECTED_USER_TABLES + ) # load, parse and verify jsonl for expected_table in expected_tables: get_line_from_file(raw_normalize.load_storage, load_files[expected_table]) @@ -85,7 +101,11 @@ def test_normalize_single_user_event_jsonl(caps: DestinationCapabilitiesContext, assert event_json["event"] == "user" assert event_json["parse_data__intent__name"] == "greet" assert event_json["text"] == "hello" - event_text, lines = get_line_from_file(raw_normalize.load_storage, load_files["event__parse_data__response_selector__default__ranking"], 9) + event_text, lines = get_line_from_file( + raw_normalize.load_storage, + load_files["event__parse_data__response_selector__default__ranking"], + 9, + ) assert lines == 10 event_json = json.loads(event_text) assert "id" in event_json @@ -94,31 +114,47 @@ def test_normalize_single_user_event_jsonl(caps: DestinationCapabilitiesContext, @pytest.mark.parametrize("caps", INSERT_CAPS, indirect=True) -def test_normalize_single_user_event_insert(caps: DestinationCapabilitiesContext, raw_normalize: Normalize) -> None: +def test_normalize_single_user_event_insert( + caps: DestinationCapabilitiesContext, raw_normalize: Normalize +) -> None: # mock_destination_caps(raw_normalize, caps) - expected_tables, load_files = normalize_event_user(raw_normalize, "event.event.user_load_1", EXPECTED_USER_TABLES) + expected_tables, load_files = normalize_event_user( + raw_normalize, "event.event.user_load_1", EXPECTED_USER_TABLES + ) # verify values line for expected_table in expected_tables: get_line_from_file(raw_normalize.load_storage, load_files[expected_table]) # return first values line from event_user file event_text, lines = get_line_from_file(raw_normalize.load_storage, load_files["event"], 2) assert lines == 3 - assert "'user'" in event_text + assert "'user'" in event_text assert "'greet'" in event_text assert "'hello'" in event_text - event_text, lines = get_line_from_file(raw_normalize.load_storage, load_files["event__parse_data__response_selector__default__ranking"], 11) + event_text, lines = get_line_from_file( + raw_normalize.load_storage, + load_files["event__parse_data__response_selector__default__ranking"], + 11, + ) assert lines == 12 assert "(7005479104644416710," in event_text @pytest.mark.parametrize("caps", JSONL_CAPS, indirect=True) -def test_normalize_filter_user_event(caps: DestinationCapabilitiesContext, rasa_normalize: Normalize) -> None: +def test_normalize_filter_user_event( + caps: DestinationCapabilitiesContext, rasa_normalize: Normalize +) -> None: load_id = extract_and_normalize_cases(rasa_normalize, ["event.event.user_load_v228_1"]) _, load_files = expect_load_package( rasa_normalize.load_storage, load_id, - ["event", "event_user", "event_user__metadata__user_nicknames", - "event_user__parse_data__entities", "event_user__parse_data__entities__processors", "event_user__parse_data__intent_ranking"] + [ + "event", + "event_user", + "event_user__metadata__user_nicknames", + "event_user__parse_data__entities", + "event_user__parse_data__entities__processors", + "event_user__parse_data__intent_ranking", + ], ) 
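    # the expected tables above follow the normalizer naming scheme: the rasa mock routes each
    # event into event_<event type> (here event_user), and the relational normalizer writes
    # nested dicts and lists to child tables named <parent>__<field path>, for example
    # event_user__parse_data__entities__processors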
event_text, lines = get_line_from_file(rasa_normalize.load_storage, load_files["event_user"], 0) assert lines == 1 @@ -129,9 +165,15 @@ def test_normalize_filter_user_event(caps: DestinationCapabilitiesContext, rasa_ @pytest.mark.parametrize("caps", JSONL_CAPS, indirect=True) -def test_normalize_filter_bot_event(caps: DestinationCapabilitiesContext, rasa_normalize: Normalize) -> None: - load_id = extract_and_normalize_cases(rasa_normalize, ["event.event.bot_load_metadata_2987398237498798"]) - _, load_files = expect_load_package(rasa_normalize.load_storage, load_id, ["event", "event_bot"]) +def test_normalize_filter_bot_event( + caps: DestinationCapabilitiesContext, rasa_normalize: Normalize +) -> None: + load_id = extract_and_normalize_cases( + rasa_normalize, ["event.event.bot_load_metadata_2987398237498798"] + ) + _, load_files = expect_load_package( + rasa_normalize.load_storage, load_id, ["event", "event_bot"] + ) event_text, lines = get_line_from_file(rasa_normalize.load_storage, load_files["event_bot"], 0) assert lines == 1 filtered_row = json.loads(event_text) @@ -140,35 +182,41 @@ def test_normalize_filter_bot_event(caps: DestinationCapabilitiesContext, rasa_n @pytest.mark.parametrize("caps", JSONL_CAPS, indirect=True) -def test_preserve_slot_complex_value_json_l(caps: DestinationCapabilitiesContext, rasa_normalize: Normalize) -> None: +def test_preserve_slot_complex_value_json_l( + caps: DestinationCapabilitiesContext, rasa_normalize: Normalize +) -> None: load_id = extract_and_normalize_cases(rasa_normalize, ["event.event.slot_session_metadata_1"]) - _, load_files = expect_load_package(rasa_normalize.load_storage, load_id, ["event", "event_slot"]) + _, load_files = expect_load_package( + rasa_normalize.load_storage, load_id, ["event", "event_slot"] + ) event_text, lines = get_line_from_file(rasa_normalize.load_storage, load_files["event_slot"], 0) assert lines == 1 filtered_row = json.loads(event_text) assert type(filtered_row["value"]) is dict - assert filtered_row["value"] == { - "user_id": "world", - "mitter_id": "hello" - } + assert filtered_row["value"] == {"user_id": "world", "mitter_id": "hello"} @pytest.mark.parametrize("caps", INSERT_CAPS, indirect=True) -def test_preserve_slot_complex_value_insert(caps: DestinationCapabilitiesContext, rasa_normalize: Normalize) -> None: +def test_preserve_slot_complex_value_insert( + caps: DestinationCapabilitiesContext, rasa_normalize: Normalize +) -> None: load_id = extract_and_normalize_cases(rasa_normalize, ["event.event.slot_session_metadata_1"]) - _, load_files = expect_load_package(rasa_normalize.load_storage, load_id, ["event", "event_slot"]) + _, load_files = expect_load_package( + rasa_normalize.load_storage, load_id, ["event", "event_slot"] + ) event_text, lines = get_line_from_file(rasa_normalize.load_storage, load_files["event_slot"], 2) assert lines == 3 - c_val = json.dumps({ - "user_id": "world", - "mitter_id": "hello" - }) + c_val = json.dumps({"user_id": "world", "mitter_id": "hello"}) assert c_val in event_text @pytest.mark.parametrize("caps", INSERT_CAPS, indirect=True) -def test_normalize_many_events_insert(caps: DestinationCapabilitiesContext, rasa_normalize: Normalize) -> None: - load_id = extract_and_normalize_cases(rasa_normalize, ["event.event.many_load_2", "event.event.user_load_1"]) +def test_normalize_many_events_insert( + caps: DestinationCapabilitiesContext, rasa_normalize: Normalize +) -> None: + load_id = extract_and_normalize_cases( + rasa_normalize, ["event.event.many_load_2", 
"event.event.user_load_1"] + ) expected_tables = EXPECTED_USER_TABLES_RASA_NORMALIZER + ["event_bot", "event_action"] _, load_files = expect_load_package(rasa_normalize.load_storage, load_id, expected_tables) # return first values line from event_user file @@ -179,8 +227,12 @@ def test_normalize_many_events_insert(caps: DestinationCapabilitiesContext, rasa @pytest.mark.parametrize("caps", JSONL_CAPS, indirect=True) -def test_normalize_many_events(caps: DestinationCapabilitiesContext, rasa_normalize: Normalize) -> None: - load_id = extract_and_normalize_cases(rasa_normalize, ["event.event.many_load_2", "event.event.user_load_1"]) +def test_normalize_many_events( + caps: DestinationCapabilitiesContext, rasa_normalize: Normalize +) -> None: + load_id = extract_and_normalize_cases( + rasa_normalize, ["event.event.many_load_2", "event.event.user_load_1"] + ) expected_tables = EXPECTED_USER_TABLES_RASA_NORMALIZER + ["event_bot", "event_action"] _, load_files = expect_load_package(rasa_normalize.load_storage, load_id, expected_tables) # return first values line from event_user file @@ -191,22 +243,32 @@ def test_normalize_many_events(caps: DestinationCapabilitiesContext, rasa_normal @pytest.mark.parametrize("caps", ALL_CAPABILITIES, indirect=True) -def test_normalize_raw_no_type_hints(caps: DestinationCapabilitiesContext, raw_normalize: Normalize) -> None: +def test_normalize_raw_no_type_hints( + caps: DestinationCapabilitiesContext, raw_normalize: Normalize +) -> None: normalize_event_user(raw_normalize, "event.event.user_load_1", EXPECTED_USER_TABLES) assert_timestamp_data_type(raw_normalize.load_storage, "double") @pytest.mark.parametrize("caps", ALL_CAPABILITIES, indirect=True) -def test_normalize_raw_type_hints(caps: DestinationCapabilitiesContext, rasa_normalize: Normalize) -> None: +def test_normalize_raw_type_hints( + caps: DestinationCapabilitiesContext, rasa_normalize: Normalize +) -> None: extract_and_normalize_cases(rasa_normalize, ["event.event.user_load_1"]) assert_timestamp_data_type(rasa_normalize.load_storage, "timestamp") @pytest.mark.parametrize("caps", ALL_CAPABILITIES, indirect=True) -def test_normalize_many_schemas(caps: DestinationCapabilitiesContext, rasa_normalize: Normalize) -> None: +def test_normalize_many_schemas( + caps: DestinationCapabilitiesContext, rasa_normalize: Normalize +) -> None: extract_cases( rasa_normalize.normalize_storage, - ["event.event.many_load_2", "event.event.user_load_1", "ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2"] + [ + "event.event.many_load_2", + "event.event.user_load_1", + "ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2", + ], ) # use real process pool in tests with Pool(processes=4) as p: @@ -224,12 +286,16 @@ def test_normalize_many_schemas(caps: DestinationCapabilitiesContext, rasa_norma expected_tables = EXPECTED_USER_TABLES_RASA_NORMALIZER + ["event_bot", "event_action"] expect_load_package(rasa_normalize.load_storage, load_id, expected_tables) if schema.name == "ethereum": - expect_load_package(rasa_normalize.load_storage, load_id, EXPECTED_ETH_TABLES, full_schema_update=False) + expect_load_package( + rasa_normalize.load_storage, load_id, EXPECTED_ETH_TABLES, full_schema_update=False + ) assert set(schemas) == set(["ethereum", "event"]) @pytest.mark.parametrize("caps", ALL_CAPABILITIES, indirect=True) -def test_normalize_typed_json(caps: DestinationCapabilitiesContext, raw_normalize: Normalize) -> None: +def test_normalize_typed_json( + caps: DestinationCapabilitiesContext, raw_normalize: Normalize +) -> None: 
extract_items(raw_normalize.normalize_storage, [JSON_TYPED_DICT], "special", "special") with ThreadPool(processes=1) as pool: raw_normalize.run(pool) @@ -282,16 +348,30 @@ def test_schema_changes(caps: DestinationCapabilitiesContext, raw_normalize: Nor assert len(table_files["doc__comp"]) == 1 s: Schema = raw_normalize.load_or_create_schema(raw_normalize.schema_storage, "evolution") doc_table = s.get_table("doc") - assert {"_dlt_load_id", "_dlt_id", "str", "int", "bool", "int__v_text"} == set(doc_table["columns"].keys()) + assert {"_dlt_load_id", "_dlt_id", "str", "int", "bool", "int__v_text"} == set( + doc_table["columns"].keys() + ) doc__comp_table = s.get_table("doc__comp") assert doc__comp_table["parent"] == "doc" - assert {"_dlt_id", "_dlt_list_idx", "_dlt_parent_id", "str", "int", "bool", "int__v_text"} == set(doc__comp_table["columns"].keys()) + assert { + "_dlt_id", + "_dlt_list_idx", + "_dlt_parent_id", + "str", + "int", + "bool", + "int__v_text", + } == set(doc__comp_table["columns"].keys()) @pytest.mark.parametrize("caps", ALL_CAPABILITIES, indirect=True) -def test_normalize_twice_with_flatten(caps: DestinationCapabilitiesContext, raw_normalize: Normalize) -> None: +def test_normalize_twice_with_flatten( + caps: DestinationCapabilitiesContext, raw_normalize: Normalize +) -> None: load_id = extract_and_normalize_cases(raw_normalize, ["github.issues.load_page_5_duck"]) - _, table_files = expect_load_package(raw_normalize.load_storage, load_id, ["issues", "issues__labels", "issues__assignees"]) + _, table_files = expect_load_package( + raw_normalize.load_storage, load_id, ["issues", "issues__labels", "issues__assignees"] + ) assert len(table_files["issues"]) == 1 _, lines = get_line_from_file(raw_normalize.load_storage, table_files["issues"], 0) # insert writer adds 2 lines @@ -314,7 +394,12 @@ def assert_schema(_schema: Schema): assert_schema(schema) load_id = extract_and_normalize_cases(raw_normalize, ["github.issues.load_page_5_duck"]) - _, table_files = expect_load_package(raw_normalize.load_storage, load_id, ["issues", "issues__labels", "issues__assignees"], full_schema_update=False) + _, table_files = expect_load_package( + raw_normalize.load_storage, + load_id, + ["issues", "issues__labels", "issues__assignees"], + full_schema_update=False, + ) assert len(table_files["issues"]) == 1 _, lines = get_line_from_file(raw_normalize.load_storage, table_files["issues"], 0) # insert writer adds 2 lines @@ -324,35 +409,74 @@ def assert_schema(_schema: Schema): def test_group_worker_files() -> None: - files = ["f%03d" % idx for idx in range(0, 100)] assert Normalize.group_worker_files([], 4) == [] assert Normalize.group_worker_files(["f001"], 1) == [["f001"]] assert Normalize.group_worker_files(["f001"], 100) == [["f001"]] assert Normalize.group_worker_files(files[:4], 4) == [["f000"], ["f001"], ["f002"], ["f003"]] - assert Normalize.group_worker_files(files[:5], 4) == [["f000"], ["f001"], ["f002"], ["f003", "f004"]] - assert Normalize.group_worker_files(files[:8], 4) == [["f000", "f001"], ["f002", "f003"], ["f004", "f005"], ["f006", "f007"]] - assert Normalize.group_worker_files(files[:8], 3) == [["f000", "f001"], ["f002", "f003", "f006"], ["f004", "f005", "f007"]] - assert Normalize.group_worker_files(files[:5], 3) == [["f000"], ["f001", "f003"], ["f002", "f004"]] + assert Normalize.group_worker_files(files[:5], 4) == [ + ["f000"], + ["f001"], + ["f002"], + ["f003", "f004"], + ] + assert Normalize.group_worker_files(files[:8], 4) == [ + ["f000", "f001"], + ["f002", "f003"], + 
["f004", "f005"], + ["f006", "f007"], + ] + assert Normalize.group_worker_files(files[:8], 3) == [ + ["f000", "f001"], + ["f002", "f003", "f006"], + ["f004", "f005", "f007"], + ] + assert Normalize.group_worker_files(files[:5], 3) == [ + ["f000"], + ["f001", "f003"], + ["f002", "f004"], + ] # check if sorted files = ["tab1.1", "chd.3", "tab1.2", "chd.4", "tab1.3"] - assert Normalize.group_worker_files(files, 3) == [["chd.3"], ["chd.4", "tab1.2"], ["tab1.1", "tab1.3"]] - - -EXPECTED_ETH_TABLES = ["blocks", "blocks__transactions", "blocks__transactions__logs", "blocks__transactions__logs__topics", - "blocks__uncles", "blocks__transactions__access_list", "blocks__transactions__access_list__storage_keys"] - -EXPECTED_USER_TABLES_RASA_NORMALIZER = ["event", "event_user", "event_user__parse_data__intent_ranking"] - - -EXPECTED_USER_TABLES = ["event", "event__parse_data__intent_ranking", "event__parse_data__response_selector__all_retrieval_intents", - "event__parse_data__response_selector__default__ranking", "event__parse_data__response_selector__default__response__response_templates", - "event__parse_data__response_selector__default__response__responses"] - - -def extract_items(normalize_storage: NormalizeStorage, items: Sequence[StrAny], schema_name: str, table_name: str) -> None: + assert Normalize.group_worker_files(files, 3) == [ + ["chd.3"], + ["chd.4", "tab1.2"], + ["tab1.1", "tab1.3"], + ] + + +EXPECTED_ETH_TABLES = [ + "blocks", + "blocks__transactions", + "blocks__transactions__logs", + "blocks__transactions__logs__topics", + "blocks__uncles", + "blocks__transactions__access_list", + "blocks__transactions__access_list__storage_keys", +] + +EXPECTED_USER_TABLES_RASA_NORMALIZER = [ + "event", + "event_user", + "event_user__parse_data__intent_ranking", +] + + +EXPECTED_USER_TABLES = [ + "event", + "event__parse_data__intent_ranking", + "event__parse_data__response_selector__all_retrieval_intents", + "event__parse_data__response_selector__default__ranking", + "event__parse_data__response_selector__default__response__response_templates", + "event__parse_data__response_selector__default__response__responses", +] + + +def extract_items( + normalize_storage: NormalizeStorage, items: Sequence[StrAny], schema_name: str, table_name: str +) -> None: extractor = ExtractorStorage(normalize_storage.config) extract_id = extractor.create_extract_id() extractor.write_data_item(extract_id, schema_name, table_name, items, None) @@ -360,7 +484,9 @@ def extract_items(normalize_storage: NormalizeStorage, items: Sequence[StrAny], extractor.commit_extract_files(extract_id) -def normalize_event_user(normalize: Normalize, case: str, expected_user_tables: List[str] = None) -> None: +def normalize_event_user( + normalize: Normalize, case: str, expected_user_tables: List[str] = None +) -> None: expected_user_tables = expected_user_tables or EXPECTED_USER_TABLES_RASA_NORMALIZER load_id = extract_and_normalize_cases(normalize, [case]) return expect_load_package(normalize.load_storage, load_id, expected_user_tables) @@ -390,12 +516,20 @@ def extract_cases(normalize_storage: NormalizeStorage, cases: Sequence[str]) -> extract_items(normalize_storage, items, schema_name, table_name) -def expect_load_package(load_storage: LoadStorage, load_id: str, expected_tables: Sequence[str], full_schema_update: bool = True) -> Dict[str, str]: +def expect_load_package( + load_storage: LoadStorage, + load_id: str, + expected_tables: Sequence[str], + full_schema_update: bool = True, +) -> Dict[str, str]: # normalize tables as paths 
(original json is snake case so we may do it without real lineage info) schema = load_storage.load_package_schema(load_id) # we are still in destination caps context so schema contains length assert schema.naming.max_length > 0 - expected_tables = [schema.naming.shorten_fragments(*schema.naming.break_path(table)) for table in expected_tables] + expected_tables = [ + schema.naming.shorten_fragments(*schema.naming.break_path(table)) + for table in expected_tables + ] # find jobs and processed files files = load_storage.list_new_jobs(load_id) @@ -419,7 +553,9 @@ def expect_load_package(load_storage: LoadStorage, load_id: str, expected_tables return expected_tables, ofl -def get_line_from_file(load_storage: LoadStorage, loaded_files: List[str], return_line: int = 0) -> Tuple[str, int]: +def get_line_from_file( + load_storage: LoadStorage, loaded_files: List[str], return_line: int = 0 +) -> Tuple[str, int]: lines = [] for file in loaded_files: with load_storage.storage.open_file(file) as f: diff --git a/tests/normalize/utils.py b/tests/normalize/utils.py index 505440cbba..3bbaeb2624 100644 --- a/tests/normalize/utils.py +++ b/tests/normalize/utils.py @@ -1,12 +1,11 @@ from typing import Mapping, cast from dlt.common import json - -from dlt.destinations.duckdb import capabilities as duck_insert_caps -from dlt.destinations.redshift import capabilities as rd_insert_caps -from dlt.destinations.postgres import capabilities as pg_insert_caps from dlt.destinations.bigquery import capabilities as jsonl_caps +from dlt.destinations.duckdb import capabilities as duck_insert_caps from dlt.destinations.filesystem import capabilities as filesystem_caps +from dlt.destinations.postgres import capabilities as pg_insert_caps +from dlt.destinations.redshift import capabilities as rd_insert_caps def filesystem_caps_jsonl_adapter(): @@ -27,4 +26,4 @@ def load_json_case(name: str) -> Mapping: def json_case_path(name: str) -> str: - return f"./tests/normalize/cases/{name}.json" \ No newline at end of file + return f"./tests/normalize/cases/{name}.json" diff --git a/tests/pipeline/cases/github_pipeline/github_extract.py b/tests/pipeline/cases/github_pipeline/github_extract.py index 5aa4fec156..30a171abf2 100644 --- a/tests/pipeline/cases/github_pipeline/github_extract.py +++ b/tests/pipeline/cases/github_pipeline/github_extract.py @@ -1,11 +1,13 @@ import sys -import dlt - from github_pipeline import github +import dlt + if __name__ == "__main__": - p = dlt.pipeline("dlt_github_pipeline", destination="duckdb", dataset_name="github_3", full_refresh=False) + p = dlt.pipeline( + "dlt_github_pipeline", destination="duckdb", dataset_name="github_3", full_refresh=False + ) github_source = github() if len(sys.argv) > 1: # load only N issues diff --git a/tests/pipeline/cases/github_pipeline/github_pipeline.py b/tests/pipeline/cases/github_pipeline/github_pipeline.py index 6d19709947..b8bca410d9 100644 --- a/tests/pipeline/cases/github_pipeline/github_pipeline.py +++ b/tests/pipeline/cases/github_pipeline/github_pipeline.py @@ -1,23 +1,31 @@ import sys import dlt - from dlt.common import json + @dlt.source(root_key=True) def github(): - - @dlt.resource(table_name="issues", write_disposition="merge", primary_key="id", merge_key=("node_id", "url")) + @dlt.resource( + table_name="issues", + write_disposition="merge", + primary_key="id", + merge_key=("node_id", "url"), + ) def load_issues(): # we should be in TEST_STORAGE folder - with open("../tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8") as 
f: + with open( + "../tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8" + ) as f: yield from json.load(f) return load_issues if __name__ == "__main__": - p = dlt.pipeline("dlt_github_pipeline", destination="duckdb", dataset_name="github_3", full_refresh=False) + p = dlt.pipeline( + "dlt_github_pipeline", destination="duckdb", dataset_name="github_3", full_refresh=False + ) github_source = github() if len(sys.argv) > 1: # load only N issues diff --git a/tests/pipeline/conftest.py b/tests/pipeline/conftest.py index a9a94230a2..e0faaa9030 100644 --- a/tests/pipeline/conftest.py +++ b/tests/pipeline/conftest.py @@ -1,2 +1,8 @@ -from tests.utils import preserve_environ, autouse_test_storage, patch_home_dir, wipe_pipeline, duckdb_pipeline_location -from tests.pipeline.utils import drop_dataset_from_env \ No newline at end of file +from tests.pipeline.utils import drop_dataset_from_env +from tests.utils import ( + autouse_test_storage, + duckdb_pipeline_location, + patch_home_dir, + preserve_environ, + wipe_pipeline, +) diff --git a/tests/pipeline/test_dlt_versions.py b/tests/pipeline/test_dlt_versions.py index 8ae9c01026..8b7a597cae 100644 --- a/tests/pipeline/test_dlt_versions.py +++ b/tests/pipeline/test_dlt_versions.py @@ -1,22 +1,21 @@ import os +import shutil import tempfile + import pytest -import shutil +from tests.utils import TEST_STORAGE_ROOT, test_storage import dlt from dlt.common import json -from dlt.common.runners import Venv -from dlt.common.utils import custom_environ, set_working_dir from dlt.common.configuration.paths import get_dlt_data_dir -from dlt.common.storages import FileStorage -from dlt.common.schema.typing import LOADS_TABLE_NAME, VERSION_TABLE_NAME, TStoredSchema from dlt.common.configuration.resolve import resolve_configuration +from dlt.common.runners import Venv +from dlt.common.schema.typing import LOADS_TABLE_NAME, VERSION_TABLE_NAME, TStoredSchema +from dlt.common.storages import FileStorage +from dlt.common.utils import custom_environ, set_working_dir from dlt.destinations.duckdb.configuration import DuckDbClientConfiguration from dlt.destinations.duckdb.sql_client import DuckDbSqlClient -from tests.utils import TEST_STORAGE_ROOT, test_storage - - GITHUB_PIPELINE_NAME = "dlt_github_pipeline" GITHUB_DATASET = "github_3" @@ -29,18 +28,34 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: # store dlt data in test storage (like patch_home_dir) with custom_environ({"DLT_DATA_DIR": get_dlt_data_dir()}): # save database outside of pipeline dir - with custom_environ({"DESTINATION__DUCKDB__CREDENTIALS": "duckdb:///test_github_3.duckdb"}): + with custom_environ( + {"DESTINATION__DUCKDB__CREDENTIALS": "duckdb:///test_github_3.duckdb"} + ): # create virtual env with (0.3.0) before the current schema upgrade with Venv.create(tempfile.mkdtemp(), ["dlt[duckdb]==0.3.0"]) as venv: # load 20 issues - print(venv.run_script("../tests/pipeline/cases/github_pipeline/github_pipeline.py", "20")) + print( + venv.run_script( + "../tests/pipeline/cases/github_pipeline/github_pipeline.py", "20" + ) + ) # load schema and check _dlt_loads definition - github_schema: TStoredSchema = json.loads(test_storage.load(f".dlt/pipelines/{GITHUB_PIPELINE_NAME}/schemas/github.schema.json")) + github_schema: TStoredSchema = json.loads( + test_storage.load( + f".dlt/pipelines/{GITHUB_PIPELINE_NAME}/schemas/github.schema.json" + ) + ) # print(github_schema["tables"][LOADS_TABLE_NAME]) assert github_schema["engine_version"] == 5 - assert 
"schema_version_hash" not in github_schema["tables"][LOADS_TABLE_NAME]["columns"] + assert ( + "schema_version_hash" + not in github_schema["tables"][LOADS_TABLE_NAME]["columns"] + ) # check loads table without attaching to pipeline - duckdb_cfg = resolve_configuration(DuckDbClientConfiguration(dataset_name=GITHUB_DATASET), sections=("destination", "duckdb")) + duckdb_cfg = resolve_configuration( + DuckDbClientConfiguration(dataset_name=GITHUB_DATASET), + sections=("destination", "duckdb"), + ) with DuckDbSqlClient(GITHUB_DATASET, duckdb_cfg.credentials) as client: rows = client.execute_sql(f"SELECT * FROM {LOADS_TABLE_NAME}") # make sure we have just 4 columns @@ -53,11 +68,17 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: # load all issues print(venv.run_script("../tests/pipeline/cases/github_pipeline/github_pipeline.py")) # hash hash in schema - github_schema = json.loads(test_storage.load(f".dlt/pipelines/{GITHUB_PIPELINE_NAME}/schemas/github.schema.json")) + github_schema = json.loads( + test_storage.load( + f".dlt/pipelines/{GITHUB_PIPELINE_NAME}/schemas/github.schema.json" + ) + ) assert github_schema["engine_version"] == 6 assert "schema_version_hash" in github_schema["tables"][LOADS_TABLE_NAME]["columns"] with DuckDbSqlClient(GITHUB_DATASET, duckdb_cfg.credentials) as client: - rows = client.execute_sql(f"SELECT * FROM {LOADS_TABLE_NAME} ORDER BY inserted_at") + rows = client.execute_sql( + f"SELECT * FROM {LOADS_TABLE_NAME} ORDER BY inserted_at" + ) # we have two loads assert len(rows) == 2 assert len(rows[0]) == 5 @@ -93,23 +114,40 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: # store dlt data in test storage (like patch_home_dir) with custom_environ({"DLT_DATA_DIR": get_dlt_data_dir()}): # save database outside of pipeline dir - with custom_environ({"DESTINATION__DUCKDB__CREDENTIALS": "duckdb:///test_github_3.duckdb"}): + with custom_environ( + {"DESTINATION__DUCKDB__CREDENTIALS": "duckdb:///test_github_3.duckdb"} + ): # create virtual env with (0.3.0) before the current schema upgrade with Venv.create(tempfile.mkdtemp(), ["dlt[duckdb]==0.3.0"]) as venv: # extract and normalize on old version but DO NOT LOAD - print(venv.run_script("../tests/pipeline/cases/github_pipeline/github_extract.py", "70")) + print( + venv.run_script( + "../tests/pipeline/cases/github_pipeline/github_extract.py", "70" + ) + ) # switch to current version and make sure the load package loads and schema migrates venv = Venv.restore_current() print(venv.run_script("../tests/pipeline/cases/github_pipeline/github_load.py")) - duckdb_cfg = resolve_configuration(DuckDbClientConfiguration(dataset_name=GITHUB_DATASET), sections=("destination", "duckdb")) + duckdb_cfg = resolve_configuration( + DuckDbClientConfiguration(dataset_name=GITHUB_DATASET), + sections=("destination", "duckdb"), + ) with DuckDbSqlClient(GITHUB_DATASET, duckdb_cfg.credentials) as client: rows = client.execute_sql("SELECT * FROM issues") assert len(rows) == 70 - github_schema = json.loads(test_storage.load(f".dlt/pipelines/{GITHUB_PIPELINE_NAME}/schemas/github.schema.json")) + github_schema = json.loads( + test_storage.load( + f".dlt/pipelines/{GITHUB_PIPELINE_NAME}/schemas/github.schema.json" + ) + ) # attach to existing pipeline pipeline = dlt.attach(GITHUB_PIPELINE_NAME, credentials=duckdb_cfg.credentials) # get the schema from schema storage before we sync - github_schema = json.loads(test_storage.load(f".dlt/pipelines/{GITHUB_PIPELINE_NAME}/schemas/github.schema.json")) + 
github_schema = json.loads( + test_storage.load( + f".dlt/pipelines/{GITHUB_PIPELINE_NAME}/schemas/github.schema.json" + ) + ) pipeline = pipeline.drop() pipeline.sync_destination() assert pipeline.default_schema.ENGINE_VERSION == 6 diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index afaffafa2f..bdba03e5ac 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -2,17 +2,36 @@ import os import random from typing import Any -from tenacity import retry_if_exception, Retrying, stop_after_attempt import pytest +from tenacity import Retrying, retry_if_exception, stop_after_attempt +from tests.common.configuration.utils import environment +from tests.common.utils import TEST_SENTRY_DSN +from tests.extract.utils import expect_extracted_file +from tests.load.pipeline.utils import DestinationTestConfiguration, destinations_configs +from tests.pipeline.utils import airtable_emojis, assert_load_info +from tests.utils import TEST_STORAGE_ROOT import dlt from dlt.common import json, sleep from dlt.common.configuration.container import Container +from dlt.common.configuration.specs.aws_credentials import AwsCredentials +from dlt.common.configuration.specs.exceptions import NativeValueError +from dlt.common.configuration.specs.gcp_credentials import GcpOAuthCredentials from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.exceptions import DestinationHasFailedJobs, DestinationTerminalException, PipelineStateNotAvailable, UnknownDestinationModule +from dlt.common.exceptions import ( + DestinationHasFailedJobs, + DestinationTerminalException, + PipelineStateNotAvailable, + UnknownDestinationModule, +) from dlt.common.pipeline import PipelineContext -from dlt.common.runtime.collector import AliveCollector, EnlightenCollector, LogCollector, TqdmCollector +from dlt.common.runtime.collector import ( + AliveCollector, + EnlightenCollector, + LogCollector, + TqdmCollector, +) from dlt.common.schema.exceptions import InvalidDatasetName from dlt.common.utils import uniq_id from dlt.extract.exceptions import SourceExhausted @@ -22,17 +41,6 @@ from dlt.pipeline.exceptions import InvalidPipelineName, PipelineNotActive, PipelineStepFailed from dlt.pipeline.helpers import retry_load from dlt.pipeline.state_sync import STATE_TABLE_NAME -from dlt.common.configuration.specs.exceptions import NativeValueError -from dlt.common.configuration.specs.aws_credentials import AwsCredentials -from dlt.common.configuration.specs.gcp_credentials import GcpOAuthCredentials -from tests.common.utils import TEST_SENTRY_DSN - -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration - -from tests.utils import TEST_STORAGE_ROOT -from tests.common.configuration.utils import environment -from tests.extract.utils import expect_extracted_file -from tests.pipeline.utils import assert_load_info, airtable_emojis def test_default_pipeline() -> None: @@ -187,51 +195,86 @@ def test_deterministic_salt(environment) -> None: assert p.pipeline_salt != p3.pipeline_salt -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) def test_create_pipeline_all_destinations(destination_config: DestinationTestConfiguration) -> None: # create pipelines, extract and normalize. 
that should be possible without installing any dependencies - p = dlt.pipeline(pipeline_name=destination_config.destination + "_pipeline", destination=destination_config.destination, staging=destination_config.staging) + p = dlt.pipeline( + pipeline_name=destination_config.destination + "_pipeline", + destination=destination_config.destination, + staging=destination_config.staging, + ) # are capabilities injected caps = p._container[DestinationCapabilitiesContext] # are right naming conventions created - assert p._default_naming.max_length == min(caps.max_column_identifier_length, caps.max_identifier_length) + assert p._default_naming.max_length == min( + caps.max_column_identifier_length, caps.max_identifier_length + ) p.extract([1, "2", 3], table_name="data") # is default schema with right naming convention - assert p.default_schema.naming.max_length == min(caps.max_column_identifier_length, caps.max_identifier_length) + assert p.default_schema.naming.max_length == min( + caps.max_column_identifier_length, caps.max_identifier_length + ) p.normalize() - assert p.default_schema.naming.max_length == min(caps.max_column_identifier_length, caps.max_identifier_length) + assert p.default_schema.naming.max_length == min( + caps.max_column_identifier_length, caps.max_identifier_length + ) def test_destination_explicit_credentials(environment: Any) -> None: # test redshift - p = dlt.pipeline(pipeline_name="postgres_pipeline", destination="redshift", credentials="redshift://loader:loader@localhost:5432/dlt_data") + p = dlt.pipeline( + pipeline_name="postgres_pipeline", + destination="redshift", + credentials="redshift://loader:loader@localhost:5432/dlt_data", + ) config = p._get_destination_client_initial_config() assert config.credentials.is_resolved() # with staging - p = dlt.pipeline(pipeline_name="postgres_pipeline", staging="filesystem", destination="redshift", credentials="redshift://loader:loader@localhost:5432/dlt_data") + p = dlt.pipeline( + pipeline_name="postgres_pipeline", + staging="filesystem", + destination="redshift", + credentials="redshift://loader:loader@localhost:5432/dlt_data", + ) config = p._get_destination_client_initial_config(p.destination) assert config.credentials.is_resolved() config = p._get_destination_client_initial_config(p.staging, as_staging=True) assert config.credentials is None p._wipe_working_folder() # try filesystem which uses union of credentials that requires bucket_url to resolve - p = dlt.pipeline(pipeline_name="postgres_pipeline", destination="filesystem", credentials={"aws_access_key_id": "key_id", "aws_secret_access_key": "key"}) + p = dlt.pipeline( + pipeline_name="postgres_pipeline", + destination="filesystem", + credentials={"aws_access_key_id": "key_id", "aws_secret_access_key": "key"}, + ) config = p._get_destination_client_initial_config(p.destination) assert isinstance(config.credentials, AwsCredentials) assert config.credentials.is_resolved() # resolve gcp oauth - p = dlt.pipeline(pipeline_name="postgres_pipeline", destination="filesystem", credentials={"project_id": "pxid", "refresh_token": "123token", "client_id": "cid", "client_secret": "s"}) + p = dlt.pipeline( + pipeline_name="postgres_pipeline", + destination="filesystem", + credentials={ + "project_id": "pxid", + "refresh_token": "123token", + "client_id": "cid", + "client_secret": "s", + }, + ) config = p._get_destination_client_initial_config(p.destination) assert isinstance(config.credentials, GcpOAuthCredentials) assert config.credentials.is_resolved() # if string cannot be parsed - 
p = dlt.pipeline(pipeline_name="postgres_pipeline", destination="filesystem", credentials="PR8BLEM") + p = dlt.pipeline( + pipeline_name="postgres_pipeline", destination="filesystem", credentials="PR8BLEM" + ) # with pytest.raises(NativeValueError) as ne_x: p._get_destination_client_initial_config(p.destination) def test_extract_source_twice() -> None: - def some_data(): yield [1, 2, 3] yield [1, 2, 3] @@ -267,8 +310,18 @@ def some_data(): def test_extract_multiple_sources() -> None: - s1 = DltSource("default", "module", dlt.Schema("default"), [dlt.resource([1, 2, 3], name="resource_1"), dlt.resource([3, 4, 5], name="resource_2")]) - s2 = DltSource("default_2", "module", dlt.Schema("default_2"), [dlt.resource([6, 7, 8], name="resource_3"), dlt.resource([9, 10, 0], name="resource_4")]) + s1 = DltSource( + "default", + "module", + dlt.Schema("default"), + [dlt.resource([1, 2, 3], name="resource_1"), dlt.resource([3, 4, 5], name="resource_2")], + ) + s2 = DltSource( + "default_2", + "module", + dlt.Schema("default_2"), + [dlt.resource([6, 7, 8], name="resource_3"), dlt.resource([9, 10, 0], name="resource_4")], + ) p = dlt.pipeline(destination="dummy") p.config.restore_from_destination = False @@ -287,11 +340,21 @@ def test_extract_multiple_sources() -> None: def i_fail(): raise NotImplementedError() - s3 = DltSource("default_3", "module", dlt.Schema("default_3"), [dlt.resource([1, 2, 3], name="resource_1"), dlt.resource([3, 4, 5], name="resource_2")]) - s4 = DltSource("default_4", "module", dlt.Schema("default_4"), [dlt.resource([6, 7, 8], name="resource_3"), i_fail]) + s3 = DltSource( + "default_3", + "module", + dlt.Schema("default_3"), + [dlt.resource([1, 2, 3], name="resource_1"), dlt.resource([3, 4, 5], name="resource_2")], + ) + s4 = DltSource( + "default_4", + "module", + dlt.Schema("default_4"), + [dlt.resource([6, 7, 8], name="resource_3"), i_fail], + ) with pytest.raises(PipelineStepFailed): - p.extract([s3, s4]) + p.extract([s3, s4]) # nothing to normalize assert len(storage.list_files_to_normalize_sorted()) == 0 @@ -391,7 +454,7 @@ def test_sentry_tracing() -> None: def r_check_sentry(): assert sentry_sdk.Hub.current.scope.span.op == "extract" assert sentry_sdk.Hub.current.scope.span.containing_transaction.name == "run" - yield [1,2,3] + yield [1, 2, 3] p.run(r_check_sentry) assert sentry_sdk.Hub.current.scope.span is None @@ -416,12 +479,10 @@ def r_fail(): assert sentry_sdk.Hub.current.scope.span is None - def test_pipeline_state_on_extract_exception() -> None: pipeline_name = "pipe_" + uniq_id() p = dlt.pipeline(pipeline_name=pipeline_name, destination="dummy") - @dlt.resource def data_piece_1(): yield [1, 2, 3] @@ -543,7 +604,6 @@ def test_run_load_pending() -> None: pipeline_name = "pipe_" + uniq_id() p = dlt.pipeline(pipeline_name=pipeline_name, destination="dummy") - def some_data(): yield from [1, 2, 3] @@ -588,7 +648,11 @@ def fail_extract(): attempt = None - for attempt in Retrying(stop=stop_after_attempt(3), retry=retry_if_exception(retry_load(("load", "extract"))), reraise=True): + for attempt in Retrying( + stop=stop_after_attempt(3), + retry=retry_if_exception(retry_load(("load", "extract"))), + reraise=True, + ): with attempt: p.run(fail_extract()) # it retried @@ -597,7 +661,9 @@ def fail_extract(): # now it fails (extract is terminal exception) retry_count = 2 with pytest.raises(PipelineStepFailed) as py_ex: - for attempt in Retrying(stop=stop_after_attempt(3), retry=retry_if_exception(retry_load(())), reraise=True): + for attempt in Retrying( + 
stop=stop_after_attempt(3), retry=retry_if_exception(retry_load(())), reraise=True + ): with attempt: p.run(fail_extract()) assert isinstance(py_ex.value, PipelineStepFailed) @@ -607,7 +673,11 @@ def fail_extract(): os.environ["RAISE_ON_FAILED_JOBS"] = "true" os.environ["FAIL_PROB"] = "1.0" with pytest.raises(PipelineStepFailed) as py_ex: - for attempt in Retrying(stop=stop_after_attempt(3), retry=retry_if_exception(retry_load(("load", "extract"))), reraise=True): + for attempt in Retrying( + stop=stop_after_attempt(3), + retry=retry_if_exception(retry_load(("load", "extract"))), + reraise=True, + ): with attempt: p.run(fail_extract()) assert isinstance(py_ex.value, PipelineStepFailed) @@ -637,6 +707,7 @@ def test_set_get_local_value() -> None: assert p.state["_local"][value] == value new_val = uniq_id() + # check in context manager @dlt.resource def _w_local_state(): @@ -667,26 +738,32 @@ def resource_1(): assert p.default_schema.get_table("resource_1")["write_disposition"] == "replace" -@dlt.transformer(name="github_repo_events", primary_key="id", write_disposition="merge", table_name=lambda i: i['type']) +@dlt.transformer( + name="github_repo_events", + primary_key="id", + write_disposition="merge", + table_name=lambda i: i["type"], +) def github_repo_events(page): yield page @dlt.transformer(name="github_repo_events", primary_key="id", write_disposition="merge") def github_repo_events_table_meta(page): - yield from [dlt.mark.with_table_name(p, p['type']) for p in page] + yield from [dlt.mark.with_table_name(p, p["type"]) for p in page] @dlt.resource def _get_shuffled_events(): - with open("tests/normalize/cases/github.events.load_page_1_duck.json", "r", encoding="utf-8") as f: + with open( + "tests/normalize/cases/github.events.load_page_1_duck.json", "r", encoding="utf-8" + ) as f: issues = json.load(f) yield issues -@pytest.mark.parametrize('github_resource', (github_repo_events_table_meta, github_repo_events)) +@pytest.mark.parametrize("github_resource", (github_repo_events_table_meta, github_repo_events)) def test_dispatch_rows_to_tables(github_resource: DltResource): - os.environ["COMPLETED_PROB"] = "1.0" pipeline_name = "pipe_" + uniq_id() p = dlt.pipeline(pipeline_name=pipeline_name, destination="dummy") @@ -696,49 +773,53 @@ def test_dispatch_rows_to_tables(github_resource: DltResource): # get all expected tables events = list(_get_shuffled_events) - expected_tables = set(map(lambda e: p.default_schema.naming.normalize_identifier(e["type"]), events)) + expected_tables = set( + map(lambda e: p.default_schema.naming.normalize_identifier(e["type"]), events) + ) # all the tables present - assert expected_tables.intersection([t["name"] for t in p.default_schema.data_tables()]) == expected_tables + assert ( + expected_tables.intersection([t["name"] for t in p.default_schema.data_tables()]) + == expected_tables + ) # all the columns have primary keys and merge disposition derived from resource - for table in p.default_schema.data_tables(): + for table in p.default_schema.data_tables(): if table.get("parent") is None: assert table["write_disposition"] == "merge" assert table["columns"]["id"]["primary_key"] is True def test_resource_name_in_schema() -> None: - @dlt.resource(table_name='some_table') + @dlt.resource(table_name="some_table") def static_data(): - yield {'a': 1, 'b': 2} + yield {"a": 1, "b": 2} - @dlt.resource(table_name=lambda x: 'dynamic_func_table') + @dlt.resource(table_name=lambda x: "dynamic_func_table") def dynamic_func_data(): - yield {'a': 1, 'b': 2} + yield {"a": 
1, "b": 2} @dlt.resource def dynamic_mark_data(): - yield dlt.mark.with_table_name({'a': 1, 'b': 2}, 'dynamic_mark_table') + yield dlt.mark.with_table_name({"a": 1, "b": 2}, "dynamic_mark_table") - @dlt.resource(table_name='parent_table') + @dlt.resource(table_name="parent_table") def nested_data(): - yield {'a': 1, 'items': [{'c': 2}, {'c': 3}, {'c': 4}]} + yield {"a": 1, "items": [{"c": 2}, {"c": 3}, {"c": 4}]} @dlt.source def some_source(): return [static_data(), dynamic_func_data(), dynamic_mark_data(), nested_data()] - source = some_source() - p = dlt.pipeline(pipeline_name=uniq_id(), destination='dummy') + p = dlt.pipeline(pipeline_name=uniq_id(), destination="dummy") p.run(source) - assert source.schema.tables['some_table']['resource'] == 'static_data' - assert source.schema.tables['dynamic_func_table']['resource'] == 'dynamic_func_data' - assert source.schema.tables['dynamic_mark_table']['resource'] == 'dynamic_mark_data' - assert source.schema.tables['parent_table']['resource'] == 'nested_data' - assert 'resource' not in source.schema.tables['parent_table__items'] + assert source.schema.tables["some_table"]["resource"] == "static_data" + assert source.schema.tables["dynamic_func_table"]["resource"] == "dynamic_func_data" + assert source.schema.tables["dynamic_mark_table"]["resource"] == "dynamic_mark_data" + assert source.schema.tables["parent_table"]["resource"] == "nested_data" + assert "resource" not in source.schema.tables["parent_table__items"] def test_preserve_fields_order() -> None: @@ -762,12 +843,23 @@ def reverse_order(item): p.extract(ordered_dict().add_map(reverse_order)) p.normalize() - assert list(p.default_schema.tables["order_1"]["columns"].keys()) == ["col_1", "col_2", "col_3", '_dlt_load_id', '_dlt_id'] - assert list(p.default_schema.tables["order_2"]["columns"].keys()) == ["col_3", "col_2", "col_1", '_dlt_load_id', '_dlt_id'] + assert list(p.default_schema.tables["order_1"]["columns"].keys()) == [ + "col_1", + "col_2", + "col_3", + "_dlt_load_id", + "_dlt_id", + ] + assert list(p.default_schema.tables["order_2"]["columns"].keys()) == [ + "col_3", + "col_2", + "col_1", + "_dlt_load_id", + "_dlt_id", + ] def run_deferred(iters): - @dlt.defer def item(n): sleep(random.random() / 2) @@ -785,7 +877,6 @@ def many_delayed(many, iters): @pytest.mark.parametrize("progress", ["tqdm", "enlighten", "log", "alive_progress"]) def test_pipeline_progress(progress: str) -> None: - os.environ["TIMEOUT"] = "3.0" p = dlt.pipeline(destination="dummy", progress=progress) @@ -813,11 +904,12 @@ def test_pipeline_progress(progress: str) -> None: def test_pipeline_log_progress() -> None: - os.environ["TIMEOUT"] = "3.0" # will attach dlt logger - p = dlt.pipeline(destination="dummy", progress=dlt.progress.log(0.5, logger=None, log_level=logging.WARNING)) + p = dlt.pipeline( + destination="dummy", progress=dlt.progress.log(0.5, logger=None, log_level=logging.WARNING) + ) # collector was created before pipeline so logger is not attached assert p.collector.logger is None p.extract(many_delayed(2, 10)) @@ -831,7 +923,6 @@ def test_pipeline_log_progress() -> None: def test_pipeline_source_state_activation() -> None: - appendix_yielded = None @dlt.source @@ -850,7 +941,7 @@ def appendix(): def writes_state(): dlt.current.source_state()["appendix"] = source_st dlt.current.resource_state()["RX"] = resource_st - yield from [1,2,3] + yield from [1, 2, 3] yield writes_state @@ -861,8 +952,11 @@ def writes_state(): assert s_appendix.state == {} # create state by running extract 
p_appendix.extract(s_appendix) - assert s_appendix.state == {'appendix': 'appendix', 'resources': {'writes_state': {'RX': 'r_appendix'}}} - assert s_appendix.writes_state.state == {'RX': 'r_appendix'} + assert s_appendix.state == { + "appendix": "appendix", + "resources": {"writes_state": {"RX": "r_appendix"}}, + } + assert s_appendix.writes_state.state == {"RX": "r_appendix"} # change the active pipeline p_postfix = dlt.pipeline(pipeline_name="postfix_p") @@ -870,7 +964,7 @@ def writes_state(): assert s_appendix.state == {} # and back p_appendix.activate() - assert s_appendix.writes_state.state == {'RX': 'r_appendix'} + assert s_appendix.writes_state.state == {"RX": "r_appendix"} # create another source s_w_appendix = reads_state("appendix", "r_appendix") @@ -947,7 +1041,12 @@ def test_emojis_resource_names() -> None: table = info.load_packages[0].schema_update["_schedule"] assert table["resource"] == "📆 Schedule" # only schedule is added - assert set(info.load_packages[0].schema_update.keys()) == {"_dlt_version", "_dlt_loads", "_schedule", "_dlt_pipeline_state"} + assert set(info.load_packages[0].schema_update.keys()) == { + "_dlt_version", + "_dlt_loads", + "_schedule", + "_dlt_pipeline_state", + } info = pipeline.run(airtable_emojis()) assert_load_info(info) # here we add _peacock with has primary_key (so at least single column) diff --git a/tests/pipeline/test_pipeline_file_format_resolver.py b/tests/pipeline/test_pipeline_file_format_resolver.py index 0a9ecacd2a..cf60fd18dd 100644 --- a/tests/pipeline/test_pipeline_file_format_resolver.py +++ b/tests/pipeline/test_pipeline_file_format_resolver.py @@ -1,13 +1,19 @@ +import pytest import dlt -import pytest +from dlt.common.exceptions import ( + DestinationIncompatibleLoaderFileFormatException, + DestinationLoadingViaStagingNotSupported, + DestinationNoStagingMode, +) -from dlt.common.exceptions import DestinationIncompatibleLoaderFileFormatException, DestinationLoadingViaStagingNotSupported, DestinationNoStagingMode -def test_file_format_resolution() -> None: +def test_file_format_resolution() -> None: # raise on destinations that does not support staging with pytest.raises(DestinationLoadingViaStagingNotSupported): - p = dlt.pipeline(pipeline_name="managed_state_pipeline", destination="postgres", staging="filesystem") + p = dlt.pipeline( + pipeline_name="managed_state_pipeline", destination="postgres", staging="filesystem" + ) # raise on staging that does not support staging interface with pytest.raises(DestinationNoStagingMode): @@ -15,7 +21,7 @@ def test_file_format_resolution() -> None: p = dlt.pipeline(pipeline_name="managed_state_pipeline") - class cp(): + class cp: def __init__(self) -> None: self.preferred_loader_file_format = None self.supported_loader_file_formats = [] @@ -59,4 +65,4 @@ def __init__(self) -> None: destcp.preferred_staging_file_format = "csv" stagecp.supported_loader_file_formats = ["jsonl", "parquet"] with pytest.raises(DestinationIncompatibleLoaderFileFormatException): - p._resolve_loader_file_format("some", "some", destcp, stagecp, None) \ No newline at end of file + p._resolve_loader_file_format("some", "some", destcp, stagecp, None) diff --git a/tests/pipeline/test_pipeline_state.py b/tests/pipeline/test_pipeline_state.py index c9abbdad59..102c374b8b 100644 --- a/tests/pipeline/test_pipeline_state.py +++ b/tests/pipeline/test_pipeline_state.py @@ -1,40 +1,43 @@ import os import shutil + import pytest +from tests.pipeline.utils import airtable_emojis, json_case_path, load_json_case +from tests.utils import 
test_storage import dlt - +from dlt.common import pipeline as state_module from dlt.common.exceptions import PipelineStateNotAvailable, ResourceNameNotAvailable from dlt.common.schema import Schema from dlt.common.source import get_current_pipe_name from dlt.common.storages import FileStorage -from dlt.common import pipeline as state_module from dlt.common.utils import uniq_id - from dlt.pipeline.exceptions import PipelineStateEngineNoUpgradePathException, PipelineStepFailed from dlt.pipeline.pipeline import Pipeline -from dlt.pipeline.state_sync import migrate_state, STATE_ENGINE_VERSION - -from tests.utils import test_storage -from tests.pipeline.utils import json_case_path, load_json_case, airtable_emojis +from dlt.pipeline.state_sync import STATE_ENGINE_VERSION, migrate_state @dlt.resource() def some_data(): last_value = dlt.current.source_state().get("last_value", 0) - yield [1,2,3] + yield [1, 2, 3] dlt.current.source_state()["last_value"] = last_value + 1 @dlt.resource() def some_data_resource_state(): last_value = dlt.current.resource_state().get("last_value", 0) - yield [1,2,3] + yield [1, 2, 3] dlt.current.resource_state()["last_value"] = last_value + 1 def test_restore_state_props() -> None: - p = dlt.pipeline(pipeline_name="restore_state_props", destination="redshift", staging="filesystem", dataset_name="the_dataset") + p = dlt.pipeline( + pipeline_name="restore_state_props", + destination="redshift", + staging="filesystem", + dataset_name="the_dataset", + ) p.extract(some_data()) state = p.state assert state["dataset_name"] == "the_dataset" @@ -77,7 +80,9 @@ def some_source(): sources_state = p.state["sources"] # the source name is the source state key assert sources_state[s.name]["last_value"] == 1 - assert sources_state["managed_state"]["last_value"] == 2 # the state for standalone resource not affected + assert ( + sources_state["managed_state"]["last_value"] == 2 + ) # the state for standalone resource not affected @dlt.source def source_same_section(): @@ -114,7 +119,6 @@ def test_no_active_pipeline_required_for_resource() -> None: def test_active_pipeline_required_for_source() -> None: - @dlt.source def some_source(): dlt.current.source_state().get("last_value", 0) @@ -134,6 +138,7 @@ def some_source(): p.deactivate() list(s) + def test_source_state_iterator(): os.environ["COMPLETED_PROB"] = "1.0" pipeline_name = "pipe_" + uniq_id() @@ -146,7 +151,7 @@ def main(): # increase the multiplier each time state is obtained state["mark"] *= 2 yield [1, 2, 3] - assert dlt.current.source_state()["mark"] == mark*2 + assert dlt.current.source_state()["mark"] == mark * 2 @dlt.transformer(data_from=main) def feeding(item): @@ -154,7 +159,7 @@ def feeding(item): assert dlt.current.source_state()["mark"] > 1 print(f"feeding state {dlt.current.source_state()}") mark = dlt.current.source_state()["mark"] - yield from map(lambda i: i*mark, item) + yield from map(lambda i: i * mark, item) @dlt.source def pass_the_state(): @@ -188,6 +193,7 @@ def test_unmanaged_state() -> None: def _gen_inner(): dlt.state()["gen"] = True yield 1 + list(dlt.resource(_gen_inner)) list(dlt.resource(_gen_inner())) assert state_module._last_full_state["sources"]["unmanaged"]["gen"] is True @@ -236,7 +242,12 @@ def _gen_inner(): def test_resource_state_write() -> None: r = some_data_resource_state() assert list(r) == [1, 2, 3] - assert state_module._last_full_state["sources"]["test_pipeline_state"]["resources"]["some_data_resource_state"]["last_value"] == 1 + assert ( + 
state_module._last_full_state["sources"]["test_pipeline_state"]["resources"][ + "some_data_resource_state" + ]["last_value"] + == 1 + ) with pytest.raises(ResourceNameNotAvailable): get_current_pipe_name() @@ -247,7 +258,12 @@ def _gen_inner(): p = dlt.pipeline() r = dlt.resource(_gen_inner(), name="name_ovrd") assert list(r) == [1] - assert state_module._last_full_state["sources"][p._make_schema_with_default_name().name]["resources"]["name_ovrd"]["gen"] is True + assert ( + state_module._last_full_state["sources"][p._make_schema_with_default_name().name][ + "resources" + ]["name_ovrd"]["gen"] + is True + ) with pytest.raises(ResourceNameNotAvailable): get_current_pipe_name() @@ -267,20 +283,29 @@ def _gen_inner(tv="df"): r = dlt.resource(_gen_inner("gen_tf"), name="name_ovrd") p.extract(r) assert r.state["gen"] == "gen_tf" - assert state_module._last_full_state["sources"][p.default_schema_name]["resources"]["name_ovrd"]["gen"] == "gen_tf" + assert ( + state_module._last_full_state["sources"][p.default_schema_name]["resources"]["name_ovrd"][ + "gen" + ] + == "gen_tf" + ) with pytest.raises(ResourceNameNotAvailable): get_current_pipe_name() r = dlt.resource(_gen_inner, name="pure_function") p.extract(r) assert r.state["gen"] == "df" - assert state_module._last_full_state["sources"][p.default_schema_name]["resources"]["pure_function"]["gen"] == "df" + assert ( + state_module._last_full_state["sources"][p.default_schema_name]["resources"][ + "pure_function" + ]["gen"] + == "df" + ) with pytest.raises(ResourceNameNotAvailable): get_current_pipe_name() # get resource state in defer function def _gen_inner_defer(tv="df"): - @dlt.defer def _run(): dlt.current.resource_state()["gen"] = tv @@ -296,7 +321,6 @@ def _run(): # get resource state in defer explicitly def _gen_inner_defer_explicit_name(resource_name, tv="df"): - @dlt.defer def _run(): dlt.current.resource_state(resource_name)["gen"] = tv @@ -307,11 +331,15 @@ def _run(): r = dlt.resource(_gen_inner_defer_explicit_name, name="defer_function_explicit") p.extract(r("defer_function_explicit", "expl")) assert r.state["gen"] == "expl" - assert state_module._last_full_state["sources"][p.default_schema_name]["resources"]["defer_function_explicit"]["gen"] == "expl" + assert ( + state_module._last_full_state["sources"][p.default_schema_name]["resources"][ + "defer_function_explicit" + ]["gen"] + == "expl" + ) # get resource state in yielding defer (which btw is invalid and will be resolved in main thread) def _gen_inner_defer_yielding(tv="yielding"): - @dlt.defer def _run(): dlt.current.resource_state()["gen"] = tv @@ -322,11 +350,15 @@ def _run(): r = dlt.resource(_gen_inner_defer_yielding, name="defer_function_yielding") p.extract(r) assert r.state["gen"] == "yielding" - assert state_module._last_full_state["sources"][p.default_schema_name]["resources"]["defer_function_yielding"]["gen"] == "yielding" + assert ( + state_module._last_full_state["sources"][p.default_schema_name]["resources"][ + "defer_function_yielding" + ]["gen"] + == "yielding" + ) # get resource state in async function def _gen_inner_async(tv="async"): - async def _run(): dlt.current.resource_state()["gen"] = tv return 1 @@ -351,8 +383,18 @@ def _gen_inner(item): # p = dlt.pipeline() # p.extract(dlt.transformer(_gen_inner, data_from=r, name="tx_other_name")) assert list(dlt.transformer(_gen_inner, data_from=r, name="tx_other_name")) == [2, 4, 6] - assert state_module._last_full_state["sources"]["test_pipeline_state"]["resources"]["some_data_resource_state"]["last_value"] == 1 - 
assert state_module._last_full_state["sources"]["test_pipeline_state"]["resources"]["tx_other_name"]["gen"] is True + assert ( + state_module._last_full_state["sources"]["test_pipeline_state"]["resources"][ + "some_data_resource_state" + ]["last_value"] + == 1 + ) + assert ( + state_module._last_full_state["sources"]["test_pipeline_state"]["resources"][ + "tx_other_name" + ]["gen"] + is True + ) # returning transformer def _gen_inner_rv(item): @@ -360,8 +402,20 @@ def _gen_inner_rv(item): return item * 2 r = some_data_resource_state() - assert list(dlt.transformer(_gen_inner_rv, data_from=r, name="tx_other_name_rv")) == [1, 2, 3, 1, 2, 3] - assert state_module._last_full_state["sources"]["test_pipeline_state"]["resources"]["tx_other_name_rv"]["gen"] is True + assert list(dlt.transformer(_gen_inner_rv, data_from=r, name="tx_other_name_rv")) == [ + 1, + 2, + 3, + 1, + 2, + 3, + ] + assert ( + state_module._last_full_state["sources"]["test_pipeline_state"]["resources"][ + "tx_other_name_rv" + ]["gen"] + is True + ) # deferred transformer @dlt.defer @@ -390,8 +444,17 @@ async def _gen_inner_rv_async_name(item, r_name): return item r = some_data_resource_state() - assert list(dlt.transformer(_gen_inner_rv_async_name, data_from=r, name="tx_other_name_async")("tx_other_name_async")) == [1, 2, 3] - assert state_module._last_full_state["sources"]["test_pipeline_state"]["resources"]["tx_other_name_async"]["gen"] is True + assert list( + dlt.transformer(_gen_inner_rv_async_name, data_from=r, name="tx_other_name_async")( + "tx_other_name_async" + ) + ) == [1, 2, 3] + assert ( + state_module._last_full_state["sources"]["test_pipeline_state"]["resources"][ + "tx_other_name_async" + ]["gen"] + is True + ) def test_transform_function_state_write() -> None: @@ -400,29 +463,41 @@ def test_transform_function_state_write() -> None: # transform executed within the same thread def transform(item): dlt.current.resource_state()["form"] = item - return item*2 + return item * 2 r.add_map(transform) assert list(r) == [2, 4, 6] - assert state_module._last_full_state["sources"]["test_pipeline_state"]["resources"]["some_data_resource_state"]["form"] == 3 + assert ( + state_module._last_full_state["sources"]["test_pipeline_state"]["resources"][ + "some_data_resource_state" + ]["form"] + == 3 + ) def test_migrate_state(test_storage: FileStorage) -> None: state_v1 = load_json_case("state/state.v1") - state = migrate_state("test_pipeline", state_v1, state_v1["_state_engine_version"], STATE_ENGINE_VERSION) + state = migrate_state( + "test_pipeline", state_v1, state_v1["_state_engine_version"], STATE_ENGINE_VERSION + ) assert state["_state_engine_version"] == STATE_ENGINE_VERSION assert "_local" in state with pytest.raises(PipelineStateEngineNoUpgradePathException) as py_ex: state_v1 = load_json_case("state/state.v1") - migrate_state("test_pipeline", state_v1, state_v1["_state_engine_version"], STATE_ENGINE_VERSION + 1) + migrate_state( + "test_pipeline", state_v1, state_v1["_state_engine_version"], STATE_ENGINE_VERSION + 1 + ) assert py_ex.value.init_engine == state_v1["_state_engine_version"] assert py_ex.value.from_engine == STATE_ENGINE_VERSION assert py_ex.value.to_engine == STATE_ENGINE_VERSION + 1 # also test pipeline init where state is old test_storage.create_folder("debug_pipeline") - shutil.copy(json_case_path("state/state.v1"), test_storage.make_full_path(f"debug_pipeline/{Pipeline.STATE_FILE}")) + shutil.copy( + json_case_path("state/state.v1"), + 
test_storage.make_full_path(f"debug_pipeline/{Pipeline.STATE_FILE}"), + ) p = dlt.attach(pipeline_name="debug_pipeline", pipelines_dir=test_storage.storage_path) assert p.dataset_name == "debug_pipeline_data" assert p.default_schema_name == "example_source" @@ -438,6 +513,7 @@ def test_resource_state_name_not_normalized() -> None: # get state from destination from dlt.pipeline.state_sync import load_state_from_destination + with pipeline.sql_client() as client: state = load_state_from_destination(pipeline.pipeline_name, client) assert "airtable_emojis" in state["sources"] diff --git a/tests/pipeline/test_pipeline_trace.py b/tests/pipeline/test_pipeline_trace.py index da450720e9..d0277905e2 100644 --- a/tests/pipeline/test_pipeline_trace.py +++ b/tests/pipeline/test_pipeline_trace.py @@ -1,37 +1,43 @@ -import io -import os import asyncio import datetime # noqa: 251 +import io +import os from typing import Any, List from unittest.mock import patch + import pytest import requests_mock +from tests.common.configuration.utils import environment, toml_providers +from tests.utils import start_test_telemetry import dlt - from dlt.common import json from dlt.common.configuration.specs import CredentialsConfiguration from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext from dlt.common.pipeline import ExtractInfo -from dlt.common.schema import Schema from dlt.common.runtime.telemetry import stop_telemetry +from dlt.common.schema import Schema from dlt.common.typing import DictStrAny, StrStr, TSecretValue - +from dlt.extract.pipe import Pipe +from dlt.extract.source import DltResource, DltSource from dlt.pipeline.exceptions import PipelineStepFailed from dlt.pipeline.pipeline import Pipeline -from dlt.pipeline.trace import PipelineTrace, SerializableResolvedValueTrace, describe_extract_data, load_trace +from dlt.pipeline.trace import ( + PipelineTrace, + SerializableResolvedValueTrace, + describe_extract_data, + load_trace, +) from dlt.pipeline.track import slack_notify_load_success -from dlt.extract.source import DltResource, DltSource -from dlt.extract.pipe import Pipe -from tests.utils import start_test_telemetry -from tests.common.configuration.utils import toml_providers, environment def test_create_trace(toml_providers: ConfigProvidersContext) -> None: - @dlt.source - def inject_tomls(api_type = dlt.config.value, credentials: CredentialsConfiguration = dlt.secrets.value, secret_value: TSecretValue = "123"): - + def inject_tomls( + api_type=dlt.config.value, + credentials: CredentialsConfiguration = dlt.secrets.value, + secret_value: TSecretValue = "123", + ): @dlt.resource def data(): yield [1, 2, 3] @@ -78,7 +84,6 @@ def data(): # extract with exception @dlt.source def async_exception(max_range=1): - async def get_val(v): await asyncio.sleep(0.1) if v % 3 == 0: @@ -87,7 +92,7 @@ async def get_val(v): @dlt.resource def data(): - yield from [get_val(v) for v in range(1,max_range)] + yield from [get_val(v) for v in range(1, max_range)] return data() @@ -148,7 +153,7 @@ def data(): def test_save_load_trace() -> None: os.environ["COMPLETED_PROB"] = "1.0" - info = dlt.pipeline().run([1,2,3], table_name="data", destination="dummy") + info = dlt.pipeline().run([1, 2, 3], table_name="data", destination="dummy") pipeline = dlt.pipeline() # will get trace from working dir trace = pipeline.last_trace @@ -190,20 +195,26 @@ def data(): def test_disable_trace(environment: StrStr) -> None: environment["ENABLE_RUNTIME_TRACE"] = "false" environment["COMPLETED_PROB"] = "1.0" - 
dlt.pipeline().run([1,2,3], table_name="data", destination="dummy") + dlt.pipeline().run([1, 2, 3], table_name="data", destination="dummy") assert dlt.pipeline().last_trace is None def test_trace_on_restore_state(environment: StrStr) -> None: environment["COMPLETED_PROB"] = "1.0" - def _sync_destination_patch(self: Pipeline, destination: str = None, staging: str = None, dataset_name: str = None): + def _sync_destination_patch( + self: Pipeline, destination: str = None, staging: str = None, dataset_name: str = None + ): # just wipe the pipeline simulating deleted dataset self._wipe_working_folder() - self._configure(self._schema_storage_config.export_schema_path, self._schema_storage_config.import_schema_path, False) - - with patch.object(Pipeline, 'sync_destination', _sync_destination_patch): - dlt.pipeline().run([1,2,3], table_name="data", destination="dummy") + self._configure( + self._schema_storage_config.export_schema_path, + self._schema_storage_config.import_schema_path, + False, + ) + + with patch.object(Pipeline, "sync_destination", _sync_destination_patch): + dlt.pipeline().run([1, 2, 3], table_name="data", destination="dummy") assert len(dlt.pipeline().last_trace.steps) == 4 @@ -213,13 +224,15 @@ def test_load_none_trace() -> None: def test_trace_telemetry() -> None: - with patch("dlt.common.runtime.sentry.before_send", _mock_sentry_before_send), patch("dlt.common.runtime.segment.before_send", _mock_segment_before_send): + with patch("dlt.common.runtime.sentry.before_send", _mock_sentry_before_send), patch( + "dlt.common.runtime.segment.before_send", _mock_segment_before_send + ): start_test_telemetry() SEGMENT_SENT_ITEMS.clear() SENTRY_SENT_ITEMS.clear() # default dummy fails all files - dlt.pipeline().run([1,2,3], table_name="data", destination="dummy") + dlt.pipeline().run([1, 2, 3], table_name="data", destination="dummy") # we should have 4 segment items assert len(SEGMENT_SENT_ITEMS) == 4 expected_steps = ["extract", "normalize", "load", "run"] @@ -256,32 +269,48 @@ def data(): assert isinstance(event["properties"]["elapsed"], float) # check extract info if step == "extract": - assert event["properties"]["extract_data"] == [{"name": "data", "data_type": "resource"}] + assert event["properties"]["extract_data"] == [ + {"name": "data", "data_type": "resource"} + ] # we didn't log any errors assert len(SENTRY_SENT_ITEMS) == 0 def test_extract_data_describe() -> None: schema = Schema("test") - assert describe_extract_data(DltSource("sss_extract", "sect", schema)) == [{"name": "sss_extract", "data_type": "source"}] - assert describe_extract_data(DltResource(Pipe("rrr_extract"), None, False)) == [{"name": "rrr_extract", "data_type": "resource"}] - assert describe_extract_data([DltSource("sss_extract", "sect", schema)]) == [{"name": "sss_extract", "data_type": "source"}] - assert describe_extract_data([DltResource(Pipe("rrr_extract"), None, False)]) == [{"name": "rrr_extract", "data_type": "resource"}] + assert describe_extract_data(DltSource("sss_extract", "sect", schema)) == [ + {"name": "sss_extract", "data_type": "source"} + ] + assert describe_extract_data(DltResource(Pipe("rrr_extract"), None, False)) == [ + {"name": "rrr_extract", "data_type": "resource"} + ] + assert describe_extract_data([DltSource("sss_extract", "sect", schema)]) == [ + {"name": "sss_extract", "data_type": "source"} + ] + assert describe_extract_data([DltResource(Pipe("rrr_extract"), None, False)]) == [ + {"name": "rrr_extract", "data_type": "resource"} + ] assert describe_extract_data( 
[DltResource(Pipe("rrr_extract"), None, False), DltSource("sss_extract", "sect", schema)] - ) == [ - {"name": "rrr_extract", "data_type": "resource"}, {"name": "sss_extract", "data_type": "source"} - ] + ) == [ + {"name": "rrr_extract", "data_type": "resource"}, + {"name": "sss_extract", "data_type": "source"}, + ] assert describe_extract_data([{"a": "b"}]) == [{"name": "", "data_type": "dict"}] from pandas import DataFrame + # we assume that List content has same type - assert describe_extract_data([DataFrame(), {"a": "b"}]) == [{"name": "", "data_type": "DataFrame"}] + assert describe_extract_data([DataFrame(), {"a": "b"}]) == [ + {"name": "", "data_type": "DataFrame"} + ] # first unnamed element in the list breaks checking info assert describe_extract_data( - [DltResource(Pipe("rrr_extract"), None, False), DataFrame(), DltSource("sss_extract", "sect", schema)] - ) == [ - {"name": "rrr_extract", "data_type": "resource"}, {"name": "", "data_type": "DataFrame"} + [ + DltResource(Pipe("rrr_extract"), None, False), + DataFrame(), + DltSource("sss_extract", "sect", schema), ] + ) == [{"name": "rrr_extract", "data_type": "resource"}, {"name": "", "data_type": "DataFrame"}] def test_slack_hook(environment: StrStr) -> None: @@ -293,8 +322,15 @@ def test_slack_hook(environment: StrStr) -> None: environment["RUNTIME__SLACK_INCOMING_HOOK"] = hook_url with requests_mock.mock() as m: m.post(hook_url, json={}) - load_info = dlt.pipeline().run([1,2,3], table_name="data", destination="dummy") - assert slack_notify_load_success(load_info.pipeline.runtime_config.slack_incoming_hook, load_info, load_info.pipeline.last_trace) == 200 + load_info = dlt.pipeline().run([1, 2, 3], table_name="data", destination="dummy") + assert ( + slack_notify_load_success( + load_info.pipeline.runtime_config.slack_incoming_hook, + load_info, + load_info.pipeline.last_trace, + ) + == 200 + ) assert m.called message = m.last_request.json() assert "rudolfix" in message["text"] @@ -304,9 +340,16 @@ def test_slack_hook(environment: StrStr) -> None: def test_broken_slack_hook(environment: StrStr) -> None: environment["COMPLETED_PROB"] = "1.0" environment["RUNTIME__SLACK_INCOMING_HOOK"] = "http://localhost:22" - load_info = dlt.pipeline().run([1,2,3], table_name="data", destination="dummy") + load_info = dlt.pipeline().run([1, 2, 3], table_name="data", destination="dummy") # connection error - assert slack_notify_load_success(load_info.pipeline.runtime_config.slack_incoming_hook, load_info, load_info.pipeline.last_trace) == -1 + assert ( + slack_notify_load_success( + load_info.pipeline.runtime_config.slack_incoming_hook, + load_info, + load_info.pipeline.last_trace, + ) + == -1 + ) # pipeline = dlt.pipeline() # assert pipeline.last_trace is not None # assert pipeline._trace is None @@ -317,21 +360,28 @@ def test_broken_slack_hook(environment: StrStr) -> None: # assert run_step.step_exception is None -def _find_resolved_value(resolved: List[SerializableResolvedValueTrace], key: str, sections: List[str]) -> SerializableResolvedValueTrace: +def _find_resolved_value( + resolved: List[SerializableResolvedValueTrace], key: str, sections: List[str] +) -> SerializableResolvedValueTrace: return next((v for v in resolved if v.key == key and v.sections == sections), None) SEGMENT_SENT_ITEMS = [] + + def _mock_segment_before_send(event: DictStrAny) -> DictStrAny: SEGMENT_SENT_ITEMS.append(event) return event SENTRY_SENT_ITEMS = [] + + def _mock_sentry_before_send(event: DictStrAny, _unused_hint: Any = None) -> DictStrAny: 
     SENTRY_SENT_ITEMS.append(event)
     return event
 
+
 def assert_trace_printable(trace: PipelineTrace) -> None:
     str(trace)
     trace.asstr(0)
diff --git a/tests/pipeline/utils.py b/tests/pipeline/utils.py
index ee60099dfb..d2ac25d6a6 100644
--- a/tests/pipeline/utils.py
+++ b/tests/pipeline/utils.py
@@ -1,12 +1,12 @@
-import pytest
 from os import environ
 
+import pytest
+from tests.utils import TEST_STORAGE_ROOT
+
 import dlt
 from dlt.common import json
 from dlt.common.pipeline import LoadInfo, PipelineContext
 
-from tests.utils import TEST_STORAGE_ROOT
-
 
 PIPELINE_TEST_CASES_PATH = "./tests/pipeline/cases/"
 
@@ -36,10 +36,8 @@ def load_json_case(name: str) -> dict:
 
 @dlt.source
 def airtable_emojis():
-
     @dlt.resource(name="📆 Schedule")
     def schedule():
-
         yield [1, 2, 3]
 
     @dlt.resource(name="💰Budget", primary_key=("🔑book_id", "asset_id"))
@@ -56,5 +54,4 @@ def peacock():
     def wide_peacock():
         yield [{"peacock": [1, 2, 3]}]
 
-
     return budget, schedule, peacock, wide_peacock
diff --git a/tests/reflection/module_cases/__init__.py b/tests/reflection/module_cases/__init__.py
index 851514132d..1eae3becf1 100644
--- a/tests/reflection/module_cases/__init__.py
+++ b/tests/reflection/module_cases/__init__.py
@@ -1,4 +1,4 @@
 import xxx.absolutely
-
 from xxx.absolutely import a1, a3
-from dlt.common.utils import uniq_id
\ No newline at end of file
+
+from dlt.common.utils import uniq_id
diff --git a/tests/reflection/module_cases/all_imports.py b/tests/reflection/module_cases/all_imports.py
index 0cfde3a9a1..32ca48ec6f 100644
--- a/tests/reflection/module_cases/all_imports.py
+++ b/tests/reflection/module_cases/all_imports.py
@@ -1 +1 @@
-from dlt.common.utils import uniq_id
\ No newline at end of file
+from dlt.common.utils import uniq_id
diff --git a/tests/reflection/module_cases/dlt_import_exception.py b/tests/reflection/module_cases/dlt_import_exception.py
index a2ae17f418..81d320a29c 100644
--- a/tests/reflection/module_cases/dlt_import_exception.py
+++ b/tests/reflection/module_cases/dlt_import_exception.py
@@ -1,6 +1,5 @@
 from dlt.common.exceptions import MissingDependencyException
 
-
 try:
     from xxx.no import e
 except ImportError:
diff --git a/tests/reflection/module_cases/executes_resource.py b/tests/reflection/module_cases/executes_resource.py
index a2024398fc..3049eb51f9 100644
--- a/tests/reflection/module_cases/executes_resource.py
+++ b/tests/reflection/module_cases/executes_resource.py
@@ -1,9 +1,10 @@
 import dlt
 
+
 @dlt.resource
 def aleph(n: int):
     for i in range(0, n):
         yield i
 
 
-print(list(aleph(10)))
\ No newline at end of file
+print(list(aleph(10)))
diff --git a/tests/reflection/module_cases/import_as_type.py b/tests/reflection/module_cases/import_as_type.py
index 500a1bf8a0..38604304ba 100644
--- a/tests/reflection/module_cases/import_as_type.py
+++ b/tests/reflection/module_cases/import_as_type.py
@@ -1,6 +1,8 @@
 from xxx.aa import Tx
 
+
 def create_tx() -> Tx:
     return Tx()
 
+
 tx = Tx()
diff --git a/tests/reflection/module_cases/no_pkg.py b/tests/reflection/module_cases/no_pkg.py
index 62e3377048..497740970c 100644
--- a/tests/reflection/module_cases/no_pkg.py
+++ b/tests/reflection/module_cases/no_pkg.py
@@ -1 +1 @@
-from . import uniq_id
\ No newline at end of file
+from . import uniq_id
diff --git a/tests/reflection/module_cases/raises.py b/tests/reflection/module_cases/raises.py
index 2c4cc4daa1..06fa4e7831 100644
--- a/tests/reflection/module_cases/raises.py
+++ b/tests/reflection/module_cases/raises.py
@@ -1,4 +1,5 @@
 from xxx.absolutely import a1, a3
+
 from dlt.common.utils import uniq_id
 
-raise NotImplementedError("empty module")
\ No newline at end of file
+raise NotImplementedError("empty module")
diff --git a/tests/reflection/module_cases/stripe_analytics/__init__.py b/tests/reflection/module_cases/stripe_analytics/__init__.py
index 6877ef5475..05bc244fe0 100644
--- a/tests/reflection/module_cases/stripe_analytics/__init__.py
+++ b/tests/reflection/module_cases/stripe_analytics/__init__.py
@@ -1,2 +1,2 @@
+from .helpers import HELPERS_VALUE
 from .stripe_analytics import VALUE
-from .helpers import HELPERS_VALUE
\ No newline at end of file
diff --git a/tests/reflection/module_cases/stripe_analytics/stripe_analytics.py b/tests/reflection/module_cases/stripe_analytics/stripe_analytics.py
index d41cb0c51a..6ee95e6bf8 100644
--- a/tests/reflection/module_cases/stripe_analytics/stripe_analytics.py
+++ b/tests/reflection/module_cases/stripe_analytics/stripe_analytics.py
@@ -1,3 +1,3 @@
 import stripe
 
-VALUE = 1
\ No newline at end of file
+VALUE = 1
diff --git a/tests/reflection/module_cases/stripe_analytics_pipeline.py b/tests/reflection/module_cases/stripe_analytics_pipeline.py
index 7cb84c9e6e..edb0467df6 100644
--- a/tests/reflection/module_cases/stripe_analytics_pipeline.py
+++ b/tests/reflection/module_cases/stripe_analytics_pipeline.py
@@ -1,4 +1,4 @@
-from stripe_analytics import VALUE, HELPERS_VALUE
+from stripe_analytics import HELPERS_VALUE, VALUE
 
 print(VALUE)
-print(HELPERS_VALUE)
\ No newline at end of file
+print(HELPERS_VALUE)
diff --git a/tests/reflection/test_script_inspector.py b/tests/reflection/test_script_inspector.py
index 291c823357..d18272e022 100644
--- a/tests/reflection/test_script_inspector.py
+++ b/tests/reflection/test_script_inspector.py
@@ -1,12 +1,18 @@
 from types import SimpleNamespace
 
-import pytest
-
-from dlt.reflection.script_inspector import load_script_module, inspect_pipeline_script, DummyModule, PipelineIsRunning
+import pytest
 from tests.utils import unload_modules
 
+from dlt.reflection.script_inspector import (
+    DummyModule,
+    PipelineIsRunning,
+    inspect_pipeline_script,
+    load_script_module,
+)
+
 MODULE_CASES = "./tests/reflection/module_cases"
 
+
 def test_import_init_module() -> None:
     with pytest.raises(ModuleNotFoundError):
         load_script_module("./tests/reflection/", "module_cases", ignore_missing_imports=False)
@@ -27,7 +33,9 @@ def test_import_module() -> None:
     with pytest.raises(ImportError):
         load_script_module(MODULE_CASES, "no_pkg", ignore_missing_imports=True)
     # but with package name in module name it will work
-    m = load_script_module("./tests/reflection/", "module_cases.no_pkg", ignore_missing_imports=True)
+    m = load_script_module(
+        "./tests/reflection/", "module_cases.no_pkg", ignore_missing_imports=True
+    )
     # uniq_id got imported
     assert isinstance(m.uniq_id(), str)
 
@@ -58,4 +66,4 @@ def test_package_dummy_clash() -> None:
     m = load_script_module(MODULE_CASES, "stripe_analytics_pipeline", ignore_missing_imports=True)
     # and those would fails
     assert m.VALUE == 1
-    assert m.HELPERS_VALUE == 3
\ No newline at end of file
+    assert m.HELPERS_VALUE == 3
diff --git a/tests/sources/helpers/test_requests.py b/tests/sources/helpers/test_requests.py
index ab86aad240..21ce48f014 100644
--- a/tests/sources/helpers/test_requests.py
+++ b/tests/sources/helpers/test_requests.py
@@ -1,32 +1,38 @@
-from contextlib import contextmanager
-from typing import Iterator, Any, cast, Type
-from unittest import mock
-from email.utils import format_datetime
 import os
 import random
+from contextlib import contextmanager
+from email.utils import format_datetime
+from typing import Any, Iterator, Type, cast
+from unittest import mock
 
 import pytest
 import requests
 import requests_mock
-from tenacity import wait_exponential, RetryCallState, RetryError
-
+from tenacity import RetryCallState, RetryError, wait_exponential
 from tests.utils import preserve_environ
+
 import dlt
 from dlt.common.configuration.specs import RunConfiguration
-from dlt.sources.helpers.requests import Session, Client, client as default_client
+from dlt.sources.helpers.requests import Client, Session
+from dlt.sources.helpers.requests import client as default_client
 from dlt.sources.helpers.requests.retry import (
-    DEFAULT_RETRY_EXCEPTIONS, DEFAULT_RETRY_STATUS, retry_if_status, retry_any, Retrying, wait_exponential_retry_after
+    DEFAULT_RETRY_EXCEPTIONS,
+    DEFAULT_RETRY_STATUS,
+    Retrying,
+    retry_any,
+    retry_if_status,
+    wait_exponential_retry_after,
 )
 
 
-@pytest.fixture(scope='function', autouse=True)
+@pytest.fixture(scope="function", autouse=True)
 def mock_sleep() -> Iterator[mock.MagicMock]:
-    with mock.patch('time.sleep') as m:
+    with mock.patch("time.sleep") as m:
         yield m
 
 
 def test_default_session_retry_settings() -> None:
-    retry: Retrying = Client().session.request.retry # type: ignore
+    retry: Retrying = Client().session.request.retry  # type: ignore
     assert retry.stop.max_attempt_number == 5  # type: ignore
     assert isinstance(retry.retry, retry_any)
     retries = retry.retry.retries
@@ -36,7 +42,7 @@ def test_default_session_retry_settings() -> None:
     assert retry.wait.multiplier == 1
 
 
-@pytest.mark.parametrize('respect_retry_after_header', (True, False))
+@pytest.mark.parametrize("respect_retry_after_header", (True, False))
 def test_custom_session_retry_settings(respect_retry_after_header: bool) -> None:
     def custom_retry_cond(response, exception):  # type: ignore
         return True
@@ -52,14 +58,14 @@ def custom_retry_cond(response, exception): # type: ignore
     assert retry.stop.max_attempt_number == 14  # type: ignore
     assert isinstance(retry.retry, retry_any)
     retries = retry.retry.retries
-    assert retries[2].predicate == custom_retry_cond # type: ignore
+    assert retries[2].predicate == custom_retry_cond  # type: ignore
     assert isinstance(retry.wait, wait_exponential)
     assert retry.wait.multiplier == 2
 
 
 def test_retry_on_status_all_fails(mock_sleep: mock.MagicMock) -> None:
     session = Client().session
-    url = 'https://example.com/data'
+    url = "https://example.com/data"
 
     with requests_mock.mock(session=session) as m:
         m.get(url, status_code=503)
@@ -68,16 +74,16 @@ def test_retry_on_status_all_fails(mock_sleep: mock.MagicMock) -> None:
 
     assert m.call_count == RunConfiguration.request_max_attempts
 
+
 def test_retry_on_status_success_after_2(mock_sleep: mock.MagicMock) -> None:
-    """Test successful request after 2 retries
-    """
+    """Test successful request after 2 retries"""
     session = Client().session
-    url = 'https://example.com/data'
+    url = "https://example.com/data"
 
     responses = [
-        dict(text='error', status_code=503),
-        dict(text='error', status_code=503),
-        dict(text='error', status_code=200)
+        dict(text="error", status_code=503),
+        dict(text="error", status_code=503),
+        dict(text="error", status_code=200),
     ]
 
     with requests_mock.mock(session=session) as m:
@@ -87,8 +93,9 @@ def test_retry_on_status_success_after_2(mock_sleep: mock.MagicMock) -> None:
     assert resp.status_code == 200
     assert m.call_count == 3
 
+
 def test_retry_on_status_without_raise_for_status(mock_sleep: mock.MagicMock) -> None:
-    url = 'https://example.com/data'
+    url = "https://example.com/data"
     session = Client(raise_for_status=False).session
 
     with requests_mock.mock(session=session) as m:
@@ -98,10 +105,16 @@ def test_retry_on_status_without_raise_for_status(mock_sleep: mock.MagicMock) ->
 
     assert m.call_count == RunConfiguration.request_max_attempts
 
-@pytest.mark.parametrize('exception_class', [requests.ConnectionError, requests.ConnectTimeout, requests.exceptions.ChunkedEncodingError])
-def test_retry_on_exception_all_fails(exception_class: Type[Exception], mock_sleep: mock.MagicMock) -> None:
+
+@pytest.mark.parametrize(
+    "exception_class",
+    [requests.ConnectionError, requests.ConnectTimeout, requests.exceptions.ChunkedEncodingError],
+)
+def test_retry_on_exception_all_fails(
+    exception_class: Type[Exception], mock_sleep: mock.MagicMock
+) -> None:
     session = Client().session
-    url = 'https://example.com/data'
+    url = "https://example.com/data"
 
     with requests_mock.mock(session=session) as m:
         m.get(url, exc=exception_class)
@@ -110,41 +123,44 @@ def test_retry_on_exception_all_fails(exception_class: Type[Exception], mock_sle
 
     assert m.call_count == RunConfiguration.request_max_attempts
 
+
 def test_retry_on_custom_condition(mock_sleep: mock.MagicMock) -> None:
     def retry_on(response: requests.Response, exception: BaseException) -> bool:
-        return response.text == 'error'
+        return response.text == "error"
 
     session = Client(retry_condition=retry_on).session
-    url = 'https://example.com/data'
+    url = "https://example.com/data"
 
     with requests_mock.mock(session=session) as m:
-        m.get(url, text='error')
+        m.get(url, text="error")
         response = session.get(url)
 
     assert response.content == b"error"
     assert m.call_count == RunConfiguration.request_max_attempts
 
+
 def test_retry_on_custom_condition_success_after_2(mock_sleep: mock.MagicMock) -> None:
     def retry_on(response: requests.Response, exception: BaseException) -> bool:
-        return response.text == 'error'
+        return response.text == "error"
 
     session = Client(retry_condition=retry_on).session
-    url = 'https://example.com/data'
-    responses = [dict(text='error'), dict(text='error'), dict(text='success')]
+    url = "https://example.com/data"
+    responses = [dict(text="error"), dict(text="error"), dict(text="success")]
 
     with requests_mock.mock(session=session) as m:
         m.get(url, responses)
         resp = session.get(url)
 
-    assert resp.text == 'success'
+    assert resp.text == "success"
     assert m.call_count == 3
 
+
 def test_wait_retry_after_int(mock_sleep: mock.MagicMock) -> None:
     session = Client(request_backoff_factor=0).session
-    url = 'https://example.com/data'
+    url = "https://example.com/data"
     responses = [
-        dict(text='error', headers={'retry-after': '4'}, status_code=429),
-        dict(text='success')
+        dict(text="error", headers={"retry-after": "4"}, status_code=429),
+        dict(text="success"),
     ]
 
     with requests_mock.mock(session=session) as m:
@@ -155,46 +171,46 @@ def test_wait_retry_after_int(mock_sleep: mock.MagicMock) -> None:
     assert 4 <= mock_sleep.call_args[0][0] <= 5  # Adds jitter up to 1s
 
 
-@pytest.mark.parametrize('existing_session', (False, True))
+@pytest.mark.parametrize("existing_session", (False, True))
 def test_init_default_client(existing_session: bool) -> None:
     """Test that the default client config is updated from runtime configuration. Run twice.
 
     1. Clean start with no existing session attached.
     2. With session in thread local (session is updated)
     """
     cfg = {
-        'RUNTIME__REQUEST_TIMEOUT': random.randrange(1, 100),
-        'RUNTIME__REQUEST_MAX_ATTEMPTS': random.randrange(1, 100),
-        'RUNTIME__REQUEST_BACKOFF_FACTOR': random.randrange(1, 100),
-        'RUNTIME__REQUEST_MAX_RETRY_DELAY': random.randrange(1, 100),
+        "RUNTIME__REQUEST_TIMEOUT": random.randrange(1, 100),
+        "RUNTIME__REQUEST_MAX_ATTEMPTS": random.randrange(1, 100),
+        "RUNTIME__REQUEST_BACKOFF_FACTOR": random.randrange(1, 100),
+        "RUNTIME__REQUEST_MAX_RETRY_DELAY": random.randrange(1, 100),
     }
 
     os.environ.update({key: str(value) for key, value in cfg.items()})
 
-    dlt.pipeline(pipeline_name='dummy_pipeline')
+    dlt.pipeline(pipeline_name="dummy_pipeline")
 
     session = default_client.session
-    assert session.timeout == cfg['RUNTIME__REQUEST_TIMEOUT']
+    assert session.timeout == cfg["RUNTIME__REQUEST_TIMEOUT"]
     retry = session.request.retry  # type: ignore[attr-defined]
-    assert retry.wait.multiplier == cfg['RUNTIME__REQUEST_BACKOFF_FACTOR']
-    assert retry.stop.max_attempt_number == cfg['RUNTIME__REQUEST_MAX_ATTEMPTS']
-    assert retry.wait.max == cfg['RUNTIME__REQUEST_MAX_RETRY_DELAY']
+    assert retry.wait.multiplier == cfg["RUNTIME__REQUEST_BACKOFF_FACTOR"]
+    assert retry.stop.max_attempt_number == cfg["RUNTIME__REQUEST_MAX_ATTEMPTS"]
+    assert retry.wait.max == cfg["RUNTIME__REQUEST_MAX_RETRY_DELAY"]
 
 
-@pytest.mark.parametrize('existing_session', (False, True))
+@pytest.mark.parametrize("existing_session", (False, True))
 def test_client_instance_with_config(existing_session: bool) -> None:
     cfg = {
-        'RUNTIME__REQUEST_TIMEOUT': random.randrange(1, 100),
-        'RUNTIME__REQUEST_MAX_ATTEMPTS': random.randrange(1, 100),
-        'RUNTIME__REQUEST_BACKOFF_FACTOR': random.randrange(1, 100),
-        'RUNTIME__REQUEST_MAX_RETRY_DELAY': random.randrange(1, 100),
+        "RUNTIME__REQUEST_TIMEOUT": random.randrange(1, 100),
+        "RUNTIME__REQUEST_MAX_ATTEMPTS": random.randrange(1, 100),
+        "RUNTIME__REQUEST_BACKOFF_FACTOR": random.randrange(1, 100),
+        "RUNTIME__REQUEST_MAX_RETRY_DELAY": random.randrange(1, 100),
     }
     os.environ.update({key: str(value) for key, value in cfg.items()})
 
     client = Client()
     session = client.session
-    assert session.timeout == cfg['RUNTIME__REQUEST_TIMEOUT']
+    assert session.timeout == cfg["RUNTIME__REQUEST_TIMEOUT"]
     retry = session.request.retry  # type: ignore[attr-defined]
-    assert retry.wait.multiplier == cfg['RUNTIME__REQUEST_BACKOFF_FACTOR']
-    assert retry.stop.max_attempt_number == cfg['RUNTIME__REQUEST_MAX_ATTEMPTS']
-    assert retry.wait.max == cfg['RUNTIME__REQUEST_MAX_RETRY_DELAY']
+    assert retry.wait.multiplier == cfg["RUNTIME__REQUEST_BACKOFF_FACTOR"]
+    assert retry.stop.max_attempt_number == cfg["RUNTIME__REQUEST_MAX_ATTEMPTS"]
+    assert retry.wait.max == cfg["RUNTIME__REQUEST_MAX_RETRY_DELAY"]
diff --git a/tests/tools/clean_redshift.py b/tests/tools/clean_redshift.py
index 7dea0ba3e1..92a2d400df 100644
--- a/tests/tools/clean_redshift.py
+++ b/tests/tools/clean_redshift.py
@@ -1,9 +1,10 @@
-from dlt.destinations.postgres.postgres import PostgresClient, psycopg2
 from psycopg2.errors import InsufficientPrivilege, InternalError_, SyntaxError
 
+from dlt.destinations.postgres.postgres import PostgresClient, psycopg2
+
 CONNECTION_STRING = ""
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     # connect
     connection = psycopg2.connect(CONNECTION_STRING)
     connection.set_isolation_level(0)
diff --git a/tests/tools/create_storages.py b/tests/tools/create_storages.py
index 4f0abe3512..43ee846313 100644
--- a/tests/tools/create_storages.py
+++ b/tests/tools/create_storages.py
@@ -1,5 +1,11 @@
-from dlt.common.storages import NormalizeStorage, LoadStorage, SchemaStorage, NormalizeStorageConfiguration, LoadStorageConfiguration, SchemaStorageConfiguration
-
+from dlt.common.storages import (
+    LoadStorage,
+    LoadStorageConfiguration,
+    NormalizeStorage,
+    NormalizeStorageConfiguration,
+    SchemaStorage,
+    SchemaStorageConfiguration,
+)
 
 # NormalizeStorage(True, NormalizeVolumeConfiguration)
 # LoadStorage(True, LoadVolumeConfiguration, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS)
diff --git a/tests/utils.py b/tests/utils.py
index 01cecae04e..99f924d79c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,32 +1,44 @@
-import os
-import sys
 import multiprocessing
+import os
 import platform
-import requests
-import pytest
+import sys
 from os import environ
 from typing import Iterator, List
 from unittest.mock import patch
 
+import pytest
+import requests
+
 import dlt
 from dlt.common.configuration.container import Container
 from dlt.common.configuration.providers import DictionaryProvider
 from dlt.common.configuration.resolve import resolve_configuration
 from dlt.common.configuration.specs import RunConfiguration
 from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext
+from dlt.common.pipeline import PipelineContext
 from dlt.common.runtime.logger import init_logging
 from dlt.common.runtime.telemetry import start_telemetry, stop_telemetry
-from dlt.common.storages import FileStorage
 from dlt.common.schema import Schema
+from dlt.common.storages import FileStorage
 from dlt.common.storages.versioned_storage import VersionedStorage
 from dlt.common.typing import StrAny
 from dlt.common.utils import custom_environ, uniq_id
-from dlt.common.pipeline import PipelineContext
 
 TEST_STORAGE_ROOT = "_storage"
 
 # destination constants
-IMPLEMENTED_DESTINATIONS = {"athena", "duckdb", "bigquery", "redshift", "postgres", "snowflake", "filesystem", "weaviate", "dummy", "motherduck"}
+IMPLEMENTED_DESTINATIONS = {
+    "athena",
+    "duckdb",
+    "bigquery",
+    "redshift",
+    "postgres",
+    "snowflake",
+    "filesystem",
+    "weaviate",
+    "dummy",
+    "motherduck",
+}
 NON_SQL_DESTINATIONS = {"filesystem", "weaviate", "dummy", "motherduck"}
 SQL_DESTINATIONS = IMPLEMENTED_DESTINATIONS - NON_SQL_DESTINATIONS
 
@@ -47,6 +59,7 @@ for destination in ACTIVE_DESTINATIONS:
     assert destination in IMPLEMENTED_DESTINATIONS, f"Unknown active destination {destination}"
 
 
+
 def TEST_DICT_CONFIG_PROVIDER():
     # add test dictionary provider
     providers_context = Container()[ConfigProvidersContext]
@@ -57,7 +70,8 @@ def TEST_DICT_CONFIG_PROVIDER():
     providers_context.add_provider(provider)
     return provider
 
-class MockHttpResponse():
+
+class MockHttpResponse:
     def __init__(self, status_code: int) -> None:
         self.status_code = status_code
 
@@ -153,15 +167,19 @@ def start_test_telemetry(c: RunConfiguration = None):
     start_telemetry(c)
 
 
-def clean_test_storage(init_normalize: bool = False, init_loader: bool = False, mode: str = "t") -> FileStorage:
+def clean_test_storage(
+    init_normalize: bool = False, init_loader: bool = False, mode: str = "t"
+) -> FileStorage:
     storage = FileStorage(TEST_STORAGE_ROOT, mode, makedirs=True)
     storage.delete_folder("", recursively=True, delete_ro=True)
     storage.create_folder(".")
     if init_normalize:
         from dlt.common.storages import NormalizeStorage
+
         NormalizeStorage(True)
     if init_loader:
         from dlt.common.storages import LoadStorage
+
         LoadStorage(True, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS)
     return storage
 
@@ -174,11 +192,13 @@ def create_schema_with_name(schema_name) -> Schema:
 def assert_no_dict_key_starts_with(d: StrAny, key_prefix: str) -> None:
     assert all(not key.startswith(key_prefix) for key in d.keys())
 
+
 def skip_if_not_active(destination: str) -> None:
     assert destination in IMPLEMENTED_DESTINATIONS, f"Unknown skipped destination {destination}"
     if destination not in ACTIVE_DESTINATIONS:
         pytest.skip(f"{destination} not in ACTIVE_DESTINATIONS", allow_module_level=True)
 
+
 skipifspawn = pytest.mark.skipif(
     multiprocessing.get_start_method() != "fork", reason="process fork not supported"
 )
@@ -187,9 +207,7 @@ def skip_if_not_active(destination: str) -> None:
     platform.python_implementation() == "PyPy", reason="won't run in PyPy interpreter"
 )
 
-skipifnotwindows = pytest.mark.skipif(
-    platform.system() != "Windows", reason="runs only on windows"
-)
+skipifnotwindows = pytest.mark.skipif(platform.system() != "Windows", reason="runs only on windows")
 
 skipifwindows = pytest.mark.skipif(
     platform.system() == "Windows", reason="does not runs on windows"