Skip to content

Commit

Permalink
AirbyteLib: Require stream selection (airbytehq#34979)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronsteers authored and jatinyadav-cc committed Feb 26, 2024
1 parent 3732d0f commit 0b78ca0
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 25 deletions.
27 changes: 27 additions & 0 deletions airbyte-lib/airbyte_lib/_factories/connector_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import shutil
import warnings
from pathlib import Path
from typing import Any

Expand All @@ -11,6 +12,32 @@
from airbyte_lib.source import Source


def get_connector(
    name: str,
    config: dict[str, Any] | None = None,
    *,
    version: str | None = None,
    pip_url: str | None = None,
    local_executable: Path | str | None = None,
    install_if_missing: bool = True,
) -> Source:
    """Deprecated. Use `get_source()` instead.

    This thin compatibility wrapper emits a DeprecationWarning and forwards
    all arguments unchanged to `get_source()`.
    """
    warnings.warn(
        # NOTE: a trailing space is required before the implicit string
        # concatenation; without it the message reads "version.Please".
        "The `get_connector()` function is deprecated and will be removed in a future version. "
        "Please use `get_source()` instead.",
        DeprecationWarning,
        stacklevel=2,  # point the warning at the caller, not this wrapper
    )
    return get_source(
        name=name,
        config=config,
        version=version,
        pip_url=pip_url,
        local_executable=local_executable,
        install_if_missing=install_if_missing,
    )


def get_source(
name: str,
config: dict[str, Any] | None = None,
Expand Down
12 changes: 12 additions & 0 deletions airbyte-lib/airbyte_lib/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,18 @@ class AirbyteLibInputError(AirbyteError, ValueError):
input_value: str | None = None


@dataclass
class AirbyteLibNoStreamsSelectedError(AirbyteLibInputError):
    """No streams were selected for the source."""

    # Actionable guidance surfaced with the error text.
    guidance = (
        "Please call `select_streams()` to select at least one stream from the list provided. "
        "You can also call `select_all_streams()` to select all available streams for this source."
    )
    # Name of the connector that raised the error, if known.
    connector_name: str | None = None
    # Streams the source advertises, so the user can pick valid names.
    available_streams: list[str] | None = None


# AirbyteLib Cache Errors


Expand Down
57 changes: 40 additions & 17 deletions airbyte-lib/airbyte_lib/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import json
import tempfile
import warnings
from contextlib import contextmanager, suppress
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -90,20 +91,34 @@ def __init__(
self._last_log_messages: list[str] = []
self._discovered_catalog: AirbyteCatalog | None = None
self._spec: ConnectorSpecification | None = None
self._selected_stream_names: list[str] | None = None
self._selected_stream_names: list[str] = []
if config is not None:
self.set_config(config, validate=validate)
if streams is not None:
self.set_streams(streams)

def set_streams(self, streams: list[str]) -> None:
    """Deprecated. See select_streams()."""
    deprecation_message = (
        "The 'set_streams' method is deprecated and will be removed in a future version. "
        "Please use the 'select_streams' method instead."
    )
    # Warn at the caller's frame, then delegate to the replacement API.
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
    self.select_streams(streams)

Currently, if this is not set, all streams will be read.
def select_all_streams(self) -> None:
    """Select all streams.

    This is a more streamlined equivalent to:
    > source.select_streams(source.get_available_streams()).
    """
    # Mark every discoverable stream as selected.
    available = self.get_available_streams()
    self._selected_stream_names = available

def select_streams(self, streams: list[str]) -> None:
"""Select the stream names that should be read from the connector.
Currently, if this is not set, all streams will be read.
"""
available_streams = self.get_available_streams()
for stream in streams:
Expand All @@ -118,12 +133,9 @@ def set_streams(self, streams: list[str]) -> None:
def get_selected_streams(self) -> list[str]:
    """Get the selected streams.

    If no streams are selected, return an empty list.
    """
    # No fallback to all available streams: selection is now explicit.
    return self._selected_stream_names

def set_config(
self,
Expand Down Expand Up @@ -252,19 +264,24 @@ def configured_catalog(self) -> ConfiguredAirbyteCatalog:
disable by default. (For instance, streams that require a premium license are sometimes
disabled by default within the connector.)
"""
_ = self.discovered_catalog # Ensure discovered catalog is cached before we start
streams_filter: list[str] | None = self._selected_stream_names
# Ensure discovered catalog is cached before we start
_ = self.discovered_catalog

# Filter for selected streams if set, otherwise use all available streams:
streams_filter: list[str] = self._selected_stream_names or self.get_available_streams()

return ConfiguredAirbyteCatalog(
streams=[
# TODO: Set sync modes and primary key to a sensible adaptive default
ConfiguredAirbyteStream(
stream=stream,
sync_mode=SyncMode.incremental,
destination_sync_mode=DestinationSyncMode.overwrite,
primary_key=stream.source_defined_primary_key,
# TODO: The below assumes all sources can coalesce from incremental sync to
# full_table as needed. CDK supports this, so it might be safe:
sync_mode=SyncMode.incremental,
)
for stream in self.discovered_catalog.streams
if streams_filter is None or stream.name in streams_filter
if stream.name in streams_filter
],
)

Expand Down Expand Up @@ -530,10 +547,16 @@ def read(
},
) from None

if not self._selected_stream_names:
raise exc.AirbyteLibNoStreamsSelectedError(
connector_name=self.name,
available_streams=self.get_available_streams(),
)

cache.register_source(
source_name=self.name,
incoming_source_catalog=self.configured_catalog,
stream_names=set(self.get_selected_streams()),
stream_names=set(self._selected_stream_names),
)
state = cache.get_state() if not force_full_refresh else None
print(f"Started `{self.name}` read operation at {pendulum.now().format('HH:mm:ss')}...")
Expand Down
44 changes: 38 additions & 6 deletions airbyte-lib/docs/generated/airbyte_lib.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion airbyte-lib/examples/run_faker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)
print("Faker source installed.")
source.check()
source.set_streams(["products", "users", "purchases"])
source.select_streams(["products", "users", "purchases"])

result = source.read()

Expand Down
2 changes: 1 addition & 1 deletion airbyte-lib/examples/run_github.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
}
)
source.check()
source.set_streams(["issues", "pull_requests", "commits"])
source.select_streams(["issues", "pull_requests", "commits"])

result = source.read()

Expand Down
25 changes: 25 additions & 0 deletions airbyte-lib/tests/integration_tests/test_source_test_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def test_file_write_and_cleanup() -> None:
cache_wo_cleanup = ab.new_local_cache(cache_dir=temp_dir_2, cleanup=False)

source = ab.get_source("source-test", config={"apiKey": "test"})
source.select_all_streams()

_ = source.read(cache_w_cleanup)
_ = source.read(cache_wo_cleanup)
Expand All @@ -207,6 +208,8 @@ def assert_cache_data(expected_test_stream_data: dict[str, list[dict[str, str |

def test_sync_to_duckdb(expected_test_stream_data: dict[str, list[dict[str, str | int]]]):
source = ab.get_source("source-test", config={"apiKey": "test"})
source.select_all_streams()

cache = ab.new_local_cache()

result: ReadResult = source.read(cache)
Expand All @@ -217,6 +220,7 @@ def test_sync_to_duckdb(expected_test_stream_data: dict[str, list[dict[str, str

def test_read_result_mapping():
source = ab.get_source("source-test", config={"apiKey": "test"})
source.select_all_streams()
result: ReadResult = source.read(ab.new_local_cache())
assert len(result) == 2
assert isinstance(result, Mapping)
Expand All @@ -228,6 +232,8 @@ def test_read_result_mapping():

def test_dataset_list_and_len(expected_test_stream_data):
source = ab.get_source("source-test", config={"apiKey": "test"})
source.select_all_streams()

result: ReadResult = source.read(ab.new_local_cache())
stream_1 = result["stream1"]
assert len(stream_1) == 2
Expand All @@ -250,6 +256,8 @@ def test_read_from_cache(expected_test_stream_data: dict[str, list[dict[str, str
"""
cache_name = str(ulid.ULID())
source = ab.get_source("source-test", config={"apiKey": "test"})
source.select_all_streams()

cache = ab.new_local_cache(cache_name)

source.read(cache)
Expand All @@ -268,6 +276,7 @@ def test_read_isolated_by_prefix(expected_test_stream_data: dict[str, list[dict[
cache_name = str(ulid.ULID())
db_path = Path(f"./.cache/{cache_name}.duckdb")
source = ab.get_source("source-test", config={"apiKey": "test"})
source.select_all_streams()
cache = ab.DuckDBCache(config=ab.DuckDBCacheConfig(db_path=db_path, table_prefix="prefix_"))

source.read(cache)
Expand Down Expand Up @@ -325,6 +334,8 @@ def test_merge_streams_in_cache(expected_test_stream_data: dict[str, list[dict[s

def test_read_result_as_list(expected_test_stream_data: dict[str, list[dict[str, str | int]]]):
source = ab.get_source("source-test", config={"apiKey": "test"})
source.select_all_streams()

cache = ab.new_local_cache()

result: ReadResult = source.read(cache)
Expand Down Expand Up @@ -354,6 +365,8 @@ def test_sync_with_merge_to_duckdb(expected_test_stream_data: dict[str, list[dic
# TODO: Add a check with a primary key to ensure that the merge strategy works as expected.
"""
source = ab.get_source("source-test", config={"apiKey": "test"})
source.select_all_streams()

cache = ab.new_local_cache()

# Read twice to test merge strategy
Expand All @@ -373,6 +386,8 @@ def test_cached_dataset(
expected_test_stream_data: dict[str, list[dict[str, str | int]]],
) -> None:
source = ab.get_source("source-test", config={"apiKey": "test"})
source.select_all_streams()

result: ReadResult = source.read(ab.new_local_cache())

stream_name = "stream1"
Expand Down Expand Up @@ -435,6 +450,8 @@ def test_cached_dataset(

def test_cached_dataset_filter():
source = ab.get_source("source-test", config={"apiKey": "test"})
source.select_all_streams()

result: ReadResult = source.read(ab.new_local_cache())

stream_name = "stream1"
Expand Down Expand Up @@ -532,6 +549,8 @@ def test_sync_with_merge_to_postgres(new_pg_cache_config: PostgresCacheConfig, e
# TODO: Add a check with a primary key to ensure that the merge strategy works as expected.
"""
source = ab.get_source("source-test", config={"apiKey": "test"})
source.select_all_streams()

cache = PostgresCache(config=new_pg_cache_config)

# Read twice to test merge strategy
Expand Down Expand Up @@ -582,6 +601,8 @@ def test_tracking(mock_datetime: Mock, mock_requests: Mock, raises: bool, api_ke
mock_requests.post = mock_post

source = ab.get_source("source-test", config={"apiKey": api_key})
source.select_all_streams()

cache = ab.new_local_cache()

if request_call_fails:
Expand Down Expand Up @@ -635,6 +656,8 @@ def test_tracking(mock_datetime: Mock, mock_requests: Mock, raises: bool, api_ke

def test_sync_to_postgres(new_pg_cache_config: PostgresCacheConfig, expected_test_stream_data: dict[str, list[dict[str, str | int]]]):
source = ab.get_source("source-test", config={"apiKey": "test"})
source.select_all_streams()

cache = PostgresCache(config=new_pg_cache_config)

result: ReadResult = source.read(cache)
Expand All @@ -649,6 +672,8 @@ def test_sync_to_postgres(new_pg_cache_config: PostgresCacheConfig, expected_tes

def test_sync_to_snowflake(snowflake_config: SnowflakeCacheConfig, expected_test_stream_data: dict[str, list[dict[str, str | int]]]):
source = ab.get_source("source-test", config={"apiKey": "test"})
source.select_all_streams()

cache = SnowflakeSQLCache(config=snowflake_config)

with cache.get_sql_connection() as con:
Expand Down
1 change: 1 addition & 0 deletions docs/using-airbyte/airbyte-lib/getting-started.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ source = ab.get_source(
install_if_missing=True,
)
source.check()
source.select_all_streams()
result = source.read()

for name, records in result.streams.items():
Expand Down

0 comments on commit 0b78ca0

Please sign in to comment.