Skip to content

Commit

Permalink
Extend abnormal state test to include cursor format check (airbytehq#…
Browse files Browse the repository at this point in the history
  • Loading branch information
roman-yermilov-gl authored and kabeer27 committed Jun 11, 2024
1 parent 7471240 commit 6792d86
Show file tree
Hide file tree
Showing 6 changed files with 360 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,29 @@ class FullRefreshConfig(BaseConfig):
)


class FutureStateCursorFormatStreamConfiguration(BaseConfig):
name: str
format: Optional[str] = Field(default=None, description="Expected format of the cursor value")


class FutureStateCursorFormatConfiguration(BaseConfig):
format: Optional[str] = Field(
default=None,
description="The default format of the cursor value will be used for all streams except those defined in the streams section",
)
streams: List[FutureStateCursorFormatStreamConfiguration] = Field(
default_factory=list, description="Expected cursor value format for a particular stream"
)


class FutureStateConfig(BaseConfig):
future_state_path: Optional[str] = Field(description="Path to a state file with values in far future")
missing_streams: List[EmptyStreamConfiguration] = Field(default=[], description="List of missing streams with valid bypass reasons.")
bypass_reason: Optional[str]
cursor_format: Optional[FutureStateCursorFormatConfiguration] = Field(
default_factory=FutureStateCursorFormatConfiguration,
description=("Expected cursor format"),
)


class IncrementalConfig(BaseConfig):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
#

import json
import re
from pathlib import Path
from typing import Any, Dict, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Mapping, MutableMapping, Optional, Tuple, Union

import pytest
from airbyte_protocol.models import (
Expand All @@ -13,6 +14,7 @@
AirbyteStateStats,
AirbyteStateType,
ConfiguredAirbyteCatalog,
ConfiguredAirbyteStream,
SyncMode,
Type,
)
Expand All @@ -24,6 +26,18 @@

MIN_BATCHES_TO_TEST: int = 5

SCHEMA_TYPES_MAPPING = {
"str": str,
"string": str,
"int": int,
"integer": int,
"int32": int,
"int64": int,
"float": float,
"double": float,
"number": float,
}


@pytest.fixture(name="future_state_configuration")
def future_state_configuration_fixture(inputs, base_path, test_strictness_level) -> Tuple[Path, List[EmptyStreamConfiguration]]:
Expand Down Expand Up @@ -236,7 +250,7 @@ async def test_read_sequential_slices(
# ), f"Records for subsequent reads with new state should be different.\n\n records_1: {records_1} \n\n state: {state_input} \n\n records_{idx + 1}: {records_N} \n\n diff: {diff}"

async def test_state_with_abnormally_large_values(
self, connector_config, configured_catalog, future_state, docker_runner: ConnectorRunner
self, inputs: IncrementalConfig, connector_config, configured_catalog, future_state, docker_runner: ConnectorRunner
):
configured_catalog = incremental_only_catalog(configured_catalog)
output = await docker_runner.call_read_with_state(config=connector_config, catalog=configured_catalog, state=future_state)
Expand All @@ -248,6 +262,83 @@ async def test_state_with_abnormally_large_values(
), f"The sync should produce no records when run with the state with abnormally large values {records[0].record.stream}"
assert states, "The sync should produce at least one STATE message"

cursor_fields_per_stream = {
stream.stream.name: self._get_cursor_field(stream)
for stream in configured_catalog.streams
if stream.sync_mode == SyncMode.incremental
}
actual_state_cursor_values_per_stream = {
state.state.stream.stream_descriptor.name: self._get_cursor_values_from_states_by_cursor(
state.state.stream.stream_state.dict(), cursor_fields_per_stream[state.state.stream.stream_descriptor.name]
)
for state in states
}
future_state_cursor_values_per_stream = {
state["stream"]["stream_descriptor"]["name"]: self._get_cursor_values_from_states_by_cursor(
state["stream"]["stream_state"], cursor_fields_per_stream[state["stream"]["stream_descriptor"]["name"]]
)
for state in future_state
if state["stream"]["stream_descriptor"]["name"] in cursor_fields_per_stream
}

assert all(future_state_cursor_values_per_stream.values()), "Future state must be set up for all given streams"

expected_cursor_value_schema_per_stream = {
# TODO: Check if cursor value may be a nested property. If so, then should I use ._get_cursor_values_from_states ?
stream.stream.name: stream.stream.json_schema["properties"][cursor_fields_per_stream[stream.stream.name]]
for stream in configured_catalog.streams
}

future_state_formatrs_per_stream = {stream.name: stream for stream in inputs.future_state.cursor_format.streams}
for stream in configured_catalog.streams:
pattern = future_state_formatrs_per_stream.get(stream.stream.name, inputs.future_state.cursor_format).format

# All streams must be defined in the abnormal_state.json file due to the high test strictness level rule.
# However, a state may not be present in the output if a stream was unavailable during sync.
# Ideally, this should not be the case, but in reality, it often happens.
# It is not the purpose of this test to check for this, so we just skip it here.
if stream.stream.name not in actual_state_cursor_values_per_stream:
continue

actual_cursor_values = actual_state_cursor_values_per_stream[stream.stream.name]
future_state_cursor_values = future_state_cursor_values_per_stream[stream.stream.name]

expected_types = self._get_cursor_value_types(expected_cursor_value_schema_per_stream[stream.stream.name]["type"])

for actual_cursor_value, future_state_cursor_value in zip(actual_cursor_values, future_state_cursor_values):

for _type in expected_types:

if actual_cursor_value:
assert isinstance(
actual_cursor_value, _type
), f"Cursor value {actual_cursor_value} is not of type {_type}. Expected {_type}, got {type(actual_cursor_value)}"

if future_state_cursor_value:
assert isinstance(
future_state_cursor_value, _type
), f"Cursor value {future_state_cursor_value} is not of type {_type}. Expected {_type}, got {type(future_state_cursor_value)}"

if not (actual_cursor_value and future_state_cursor_value):
continue

# If the cursor value is numeric and the type check has passed, it means the format is correct
if isinstance(actual_cursor_value, (int, float)):
continue

# When the data is of string type, we need to ensure the format is correct for both cursor values
if pattern:
assert self._check_cursor_by_regex_match(
actual_cursor_value, pattern
), f"Actual cursor value {actual_cursor_value} does not match pattern: {pattern}"
assert self._check_cursor_by_regex_match(
future_state_cursor_value, pattern
), f"Future cursor value {future_state_cursor_value} does not match pattern: {pattern}"
else:
assert self._check_cursor_by_char_types(
actual_cursor_value, future_state_cursor_value
), "Actual and future state formats do not match. Actual cursor value: {actual_cursor_value}, future cursor value: {future_state_cursor_value}"

def get_next_state_input(
self, state_message: AirbyteStateMessage, stream_name_to_per_stream_state: MutableMapping
) -> Tuple[Union[List[MutableMapping], MutableMapping], MutableMapping]:
Expand All @@ -266,6 +357,66 @@ def get_next_state_input(
]
return state_input, stream_name_to_per_stream_state

@staticmethod
def _get_cursor_values_from_states_by_cursor(states: Union[list, dict], cursor_field: str) -> List[Union[str, int]]:
values = []
nodes_to_visit = [states]

while nodes_to_visit:
current_node = nodes_to_visit.pop()

if isinstance(current_node, dict):
for key, value in current_node.items():
if key == cursor_field:
values.append(value)
nodes_to_visit.append(value)
elif isinstance(current_node, list):
nodes_to_visit.extend(current_node)

return values

@staticmethod
def _check_cursor_by_char_types(actual_cursor: str, expected_cursor: str) -> bool:
if len(actual_cursor) != len(expected_cursor):
return False

for char1, char2 in zip(actual_cursor, expected_cursor):
if char1.isalpha() and char2.isalpha():
continue
elif char1.isdigit() and char2.isdigit():
continue
elif not char1.isalnum() and not char2.isalnum() and char1 == char2:
continue
else:
return False

return True

@staticmethod
def _check_cursor_by_regex_match(cursor: str, pattern: str) -> bool:
return bool(re.match(pattern, cursor))

@staticmethod
def _get_cursor_field(stream: ConfiguredAirbyteStream) -> Optional[str]:
cursor_field = stream.cursor_field or stream.stream.default_cursor_field
if cursor_field:
return next(iter(cursor_field))

@staticmethod
def _get_cursor_value_types(schema_type: Union[list, str]) -> List[Callable[..., Any]]:
if isinstance(schema_type, str):
schema_type = [schema_type]
types = []
for _type in schema_type:
if _type == "null":
continue

if _type not in SCHEMA_TYPES_MAPPING:
pytest.fail(f"Unsupported type: {_type}. Update SCHEMA_TYPES_MAPPING with the {_type} and its corresponding function")

types.append(SCHEMA_TYPES_MAPPING[_type])
return types

@staticmethod
def _get_state(airbyte_message: AirbyteMessage) -> AirbyteStateMessage:
if not airbyte_message.state.stream:
Expand Down
Loading

0 comments on commit 6792d86

Please sign in to comment.