Skip to content

Commit

Permalink
File-based CDK: make incremental syncs concurrent (airbytehq#34540)
Browse files Browse the repository at this point in the history
  • Loading branch information
clnoll authored and jatinyadav-cc committed Feb 26, 2024
1 parent 5dcec71 commit c5dd63d
Show file tree
Hide file tree
Showing 21 changed files with 3,937 additions and 124 deletions.
2 changes: 1 addition & 1 deletion airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def read(
# TODO assert all streams exist in the connector
# get the streams once in case the connector needs to make any queries to generate them
stream_instances = {s.name: s for s in self.streams(config)}
state_manager = ConnectorStateManager(stream_instance_map=stream_instances, state=state)
state_manager = ConnectorStateManager(stream_instance_map={s.stream.name: s.stream for s in catalog.streams}, state=state)
self._stream_to_instance_map = stream_instances

stream_name_to_exception: MutableMapping[str, AirbyteTracedException] = {}
Expand Down
16 changes: 13 additions & 3 deletions airbyte-cdk/python/airbyte_cdk/sources/connector_state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
import copy
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union

from airbyte_cdk.models import AirbyteMessage, AirbyteStateBlob, AirbyteStateMessage, AirbyteStateType, AirbyteStreamState, StreamDescriptor
from airbyte_cdk.models import (
AirbyteMessage,
AirbyteStateBlob,
AirbyteStateMessage,
AirbyteStateType,
AirbyteStream,
AirbyteStreamState,
StreamDescriptor,
)
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.sources.streams import Stream
from pydantic import Extra
Expand All @@ -29,7 +37,9 @@ class ConnectorStateManager:
"""

def __init__(
self, stream_instance_map: Mapping[str, Stream], state: Optional[Union[List[AirbyteStateMessage], MutableMapping[str, Any]]] = None
self,
stream_instance_map: Mapping[str, AirbyteStream],
state: Optional[Union[List[AirbyteStateMessage], MutableMapping[str, Any]]] = None,
):
shared_state, per_stream_states = self._extract_from_state_message(state, stream_instance_map)

Expand Down Expand Up @@ -97,7 +107,7 @@ def create_state_message(self, stream_name: str, namespace: Optional[str], send_

@classmethod
def _extract_from_state_message(
cls, state: Optional[Union[List[AirbyteStateMessage], MutableMapping[str, Any]]], stream_instance_map: Mapping[str, Stream]
cls, state: Optional[Union[List[AirbyteStateMessage], MutableMapping[str, Any]]], stream_instance_map: Mapping[str, AirbyteStream]
) -> Tuple[Optional[AirbyteStateBlob], MutableMapping[HashableStreamDescriptor, Optional[AirbyteStateBlob]]]:
"""
Takes an incoming list of state messages or the legacy state format and extracts state attributes according to type
Expand Down
122 changes: 87 additions & 35 deletions airbyte-cdk/python/airbyte_cdk/sources/file_based/file_based_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from airbyte_cdk.models import (
AirbyteMessage,
AirbyteStateMessage,
AirbyteStream,
ConfiguredAirbyteCatalog,
ConnectorSpecification,
FailureType,
Expand All @@ -20,6 +21,7 @@
)
from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource
from airbyte_cdk.sources.concurrent_source.concurrent_source_adapter import ConcurrentSourceAdapter
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.file_based.availability_strategy import AbstractFileBasedAvailabilityStrategy, DefaultFileBasedAvailabilityStrategy
from airbyte_cdk.sources.file_based.config.abstract_file_based_spec import AbstractFileBasedSpec
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig, ValidationPolicy
Expand All @@ -31,12 +33,15 @@
from airbyte_cdk.sources.file_based.schema_validation_policies import DEFAULT_SCHEMA_VALIDATION_POLICIES, AbstractSchemaValidationPolicy
from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream, DefaultFileBasedStream
from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamFacade
from airbyte_cdk.sources.file_based.stream.concurrent.cursor import FileBasedNoopCursor
from airbyte_cdk.sources.file_based.stream.concurrent.cursor import (
AbstractConcurrentFileBasedCursor,
FileBasedConcurrentCursor,
FileBasedNoopCursor,
)
from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor
from airbyte_cdk.sources.file_based.stream.cursor.default_file_based_cursor import DefaultFileBasedCursor
from airbyte_cdk.sources.message.repository import InMemoryMessageRepository, MessageRepository
from airbyte_cdk.sources.source import TState
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.concurrent.cursor import CursorField
from airbyte_cdk.utils.analytics_message import create_analytics_message
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
from pydantic.error_wrappers import ValidationError
Expand All @@ -56,12 +61,12 @@ def __init__(
spec_class: Type[AbstractFileBasedSpec],
catalog: Optional[ConfiguredAirbyteCatalog],
config: Optional[Mapping[str, Any]],
state: Optional[TState],
state: Optional[MutableMapping[str, Any]],
availability_strategy: Optional[AbstractFileBasedAvailabilityStrategy] = None,
discovery_policy: AbstractDiscoveryPolicy = DefaultDiscoveryPolicy(),
parsers: Mapping[Type[Any], FileTypeParser] = default_parsers,
validation_policies: Mapping[ValidationPolicy, AbstractSchemaValidationPolicy] = DEFAULT_SCHEMA_VALIDATION_POLICIES,
cursor_cls: Type[AbstractFileBasedCursor] = DefaultFileBasedCursor,
cursor_cls: Type[Union[AbstractConcurrentFileBasedCursor, AbstractFileBasedCursor]] = FileBasedConcurrentCursor,
):
self.stream_reader = stream_reader
self.spec_class = spec_class
Expand Down Expand Up @@ -137,52 +142,99 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
"""
Return a list of this source's streams.
"""
file_based_streams = self._get_file_based_streams(config)

configured_streams: List[Stream] = []

for stream in file_based_streams:
sync_mode = self._get_sync_mode_from_catalog(stream)
if sync_mode == SyncMode.full_refresh and hasattr(self, "_concurrency_level") and self._concurrency_level is not None:
configured_streams.append(
FileBasedStreamFacade.create_from_stream(stream, self, self.logger, None, FileBasedNoopCursor(stream.config))
)
else:
configured_streams.append(stream)

return configured_streams
if self.catalog:
state_manager = ConnectorStateManager(
stream_instance_map={s.stream.name: s.stream for s in self.catalog.streams},
state=self.state,
)
else:
# During `check` operations we don't have a catalog so cannot create a state manager.
# Since the state manager is only required for incremental syncs, this is fine.
state_manager = None

def _get_file_based_streams(self, config: Mapping[str, Any]) -> List[AbstractFileBasedStream]:
try:
parsed_config = self._get_parsed_config(config)
self.stream_reader.config = parsed_config
streams: List[AbstractFileBasedStream] = []
streams: List[Stream] = []
for stream_config in parsed_config.streams:
# Like state_manager, `catalog_stream` may be None during `check`
catalog_stream = self._get_stream_from_catalog(stream_config)
stream_state = (
state_manager.get_stream_state(catalog_stream.name, catalog_stream.namespace)
if (state_manager and catalog_stream)
else None
)
self._validate_input_schema(stream_config)
streams.append(
DefaultFileBasedStream(
config=stream_config,
catalog_schema=self.stream_schemas.get(stream_config.name),
stream_reader=self.stream_reader,
availability_strategy=self.availability_strategy,
discovery_policy=self.discovery_policy,
parsers=self.parsers,
validation_policy=self._validate_and_get_validation_policy(stream_config),
cursor=self.cursor_cls(stream_config),
errors_collector=self.errors_collector,

sync_mode = self._get_sync_mode_from_catalog(stream_config.name)

if sync_mode == SyncMode.full_refresh and hasattr(self, "_concurrency_level") and self._concurrency_level is not None:
cursor = FileBasedNoopCursor(stream_config)
stream = FileBasedStreamFacade.create_from_stream(
self._make_default_stream(stream_config, cursor), self, self.logger, stream_state, cursor
)
)

elif (
sync_mode == SyncMode.incremental
and issubclass(self.cursor_cls, AbstractConcurrentFileBasedCursor)
and hasattr(self, "_concurrency_level")
and self._concurrency_level is not None
):
assert (
state_manager is not None
), "No ConnectorStateManager was created, but it is required for incremental syncs. This is unexpected. Please contact Support."

cursor = self.cursor_cls(
stream_config,
stream_config.name,
None,
stream_state,
self.message_repository,
state_manager,
CursorField(DefaultFileBasedStream.ab_last_mod_col),
)
stream = FileBasedStreamFacade.create_from_stream(
self._make_default_stream(stream_config, cursor), self, self.logger, stream_state, cursor
)
else:
cursor = self.cursor_cls(stream_config)
stream = self._make_default_stream(stream_config, cursor)

streams.append(stream)
return streams

except ValidationError as exc:
raise ConfigValidationError(FileBasedSourceError.CONFIG_VALIDATION_ERROR) from exc

def _get_sync_mode_from_catalog(self, stream: Stream) -> Optional[SyncMode]:
def _make_default_stream(
    self, stream_config: FileBasedStreamConfig, cursor: Optional[AbstractFileBasedCursor]
) -> AbstractFileBasedStream:
    """
    Build the `DefaultFileBasedStream` for one configured stream.

    Every collaborator except the cursor is taken from this source instance, so all
    call sites construct streams the same way and differ only in the cursor they pass.

    :param stream_config: configuration of the stream to build; its `name` is also
        used to look up the stream's schema in `self.stream_schemas`.
    :param cursor: cursor handed to the stream unchanged (annotated as optional).
    :return: the newly constructed stream.
    """
    return DefaultFileBasedStream(
        config=stream_config,
        # `.get` is used so a stream without an entry in `stream_schemas` does not raise here.
        catalog_schema=self.stream_schemas.get(stream_config.name),
        stream_reader=self.stream_reader,
        availability_strategy=self.availability_strategy,
        discovery_policy=self.discovery_policy,
        parsers=self.parsers,
        validation_policy=self._validate_and_get_validation_policy(stream_config),
        errors_collector=self.errors_collector,
        cursor=cursor,
    )

def _get_stream_from_catalog(self, stream_config: FileBasedStreamConfig) -> Optional[AirbyteStream]:
    """
    Find the catalog stream whose name matches `stream_config.name`.

    :param stream_config: configuration of the stream to look up.
    :return: the `stream` attribute of the first catalog entry with a matching name,
        or None when there is no catalog, the catalog has no streams, or nothing matches.
    """
    if not self.catalog:
        return None

    # `streams` may be falsy (None or empty); treat that the same as "no streams".
    for configured_stream in self.catalog.streams or []:
        if configured_stream.stream.name == stream_config.name:
            return configured_stream.stream
    return None

def _get_sync_mode_from_catalog(self, stream_name: str) -> Optional[SyncMode]:
    """
    Return the sync mode configured in the catalog for the stream named `stream_name`.

    :param stream_name: name of the stream to look up.
    :return: the `sync_mode` of the first catalog entry whose stream name matches,
        or None when there is no catalog or no entry matches.
    """
    if self.catalog:
        for catalog_stream in self.catalog.streams:
            if stream_name == catalog_stream.stream.name:
                return catalog_stream.sync_mode
        # Only warn when a catalog exists but has no entry for this stream; with no
        # catalog at all there is nothing to look up.
        self.logger.warning(f"No sync mode was found for {stream_name}.")
    return None

def read(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.schema_validation_policies import AbstractSchemaValidationPolicy
from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor
from airbyte_cdk.sources.file_based.types import StreamSlice
from airbyte_cdk.sources.streams import Stream

Expand Down Expand Up @@ -45,6 +46,7 @@ def __init__(
parsers: Dict[Type[Any], FileTypeParser],
validation_policy: AbstractSchemaValidationPolicy,
errors_collector: FileBasedErrorsCollector,
cursor: AbstractFileBasedCursor,
):
super().__init__()
self.config = config
Expand All @@ -55,6 +57,7 @@ def __init__(
self._availability_strategy = availability_strategy
self._parsers = parsers
self.errors_collector = errors_collector
self._cursor = cursor

@property
@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import copy
import logging
from functools import lru_cache
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union
from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, MutableMapping, Optional, Union

from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, SyncMode, Type
from airbyte_cdk.sources import AbstractSource
Expand All @@ -19,6 +19,7 @@
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream
from airbyte_cdk.sources.file_based.stream.concurrent.cursor import FileBasedNoopCursor
from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor
from airbyte_cdk.sources.file_based.types import StreamSlice
from airbyte_cdk.sources.message import MessageRepository
from airbyte_cdk.sources.streams.concurrent.abstract_stream_facade import AbstractStreamFacade
Expand All @@ -33,6 +34,9 @@
from airbyte_cdk.sources.utils.slice_logger import SliceLogger
from deprecated.classic import deprecated

if TYPE_CHECKING:
from airbyte_cdk.sources.file_based.stream.concurrent.cursor import AbstractConcurrentFileBasedCursor

"""
This module contains adapters that help enable concurrency on File-based Stream objects without needing to migrate to AbstractStream
"""
Expand All @@ -47,13 +51,14 @@ def create_from_stream(
source: AbstractSource,
logger: logging.Logger,
state: Optional[MutableMapping[str, Any]],
cursor: FileBasedNoopCursor,
cursor: "AbstractConcurrentFileBasedCursor",
) -> "FileBasedStreamFacade":
"""
Create a ConcurrentStream from a FileBasedStream object.
"""
pk = get_primary_key_from_stream(stream.primary_key)
cursor_field = get_cursor_field_from_stream(stream)
stream._cursor = cursor

if not source.message_repository:
raise ValueError(
Expand All @@ -62,7 +67,7 @@ def create_from_stream(

message_repository = source.message_repository
return FileBasedStreamFacade(
DefaultStream( # type: ignore
DefaultStream(
partition_generator=FileBasedStreamPartitionGenerator(
stream,
message_repository,
Expand Down Expand Up @@ -90,14 +95,13 @@ def __init__(
self,
stream: DefaultStream,
legacy_stream: AbstractFileBasedStream,
cursor: FileBasedNoopCursor,
cursor: AbstractFileBasedCursor,
slice_logger: SliceLogger,
logger: logging.Logger,
):
"""
:param stream: The underlying AbstractStream
"""
# super().__init__(stream, legacy_stream, cursor, slice_logger, logger)
self._abstract_stream = stream
self._legacy_stream = legacy_stream
self._cursor = cursor
Expand Down Expand Up @@ -216,7 +220,7 @@ def __init__(
sync_mode: SyncMode,
cursor_field: Optional[List[str]],
state: Optional[MutableMapping[str, Any]],
cursor: FileBasedNoopCursor,
cursor: "AbstractConcurrentFileBasedCursor",
):
self._stream = stream
self._slice = _slice
Expand Down Expand Up @@ -292,7 +296,7 @@ def __init__(
sync_mode: SyncMode,
cursor_field: Optional[List[str]],
state: Optional[MutableMapping[str, Any]],
cursor: FileBasedNoopCursor,
cursor: "AbstractConcurrentFileBasedCursor",
):
self._stream = stream
self._message_repository = message_repository
Expand All @@ -305,19 +309,17 @@ def generate(self) -> Iterable[FileBasedStreamPartition]:
pending_partitions = []
for _slice in self._stream.stream_slices(sync_mode=self._sync_mode, cursor_field=self._cursor_field, stream_state=self._state):
if _slice is not None:
pending_partitions.extend(
[
for file in _slice.get("files", []):
pending_partitions.append(
FileBasedStreamPartition(
self._stream,
{"files": [copy.deepcopy(f)]},
{"files": [copy.deepcopy(file)]},
self._message_repository,
self._sync_mode,
self._cursor_field,
self._state,
self._cursor,
)
for f in _slice.get("files", [])
]
)
)
self._cursor.set_pending_partitions(pending_partitions)
yield from pending_partitions
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .abstract_concurrent_file_based_cursor import AbstractConcurrentFileBasedCursor
from .file_based_noop_cursor import FileBasedNoopCursor
from .file_based_concurrent_cursor import FileBasedConcurrentCursor

__all__ = ["AbstractConcurrentFileBasedCursor", "FileBasedConcurrentCursor", "FileBasedNoopCursor"]
Loading

0 comments on commit c5dd63d

Please sign in to comment.