Skip to content

Commit

Permalink
low-code: Do not apply transforms on AirbyteLogMessages and AirbyteTraceMessages (airbytehq#25290)
Browse files Browse the repository at this point in the history

* Check the input type before applying transformations

* format

* remove debug prints
  • Loading branch information
girarda authored and marcosmarxm committed Jun 8, 2023
1 parent 4331ef7 commit 315e68f
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import InitVar, dataclass, field
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union

from airbyte_cdk.models import SyncMode
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteTraceMessage, SyncMode
from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever
from airbyte_cdk.sources.declarative.schema import DefaultSchemaLoader
Expand Down Expand Up @@ -100,12 +100,21 @@ def read_records(
for record in self.retriever.read_records(sync_mode, cursor_field, stream_slice, stream_state):
yield self._apply_transformations(record, self.config, stream_slice)

def _apply_transformations(self, record: Mapping[str, Any], config: Config, stream_slice: StreamSlice):
output_record = record
def _apply_transformations(
self,
message_or_record_data: Union[AirbyteMessage, AirbyteLogMessage, AirbyteTraceMessage, Mapping[str, Any]],
config: Config,
stream_slice: StreamSlice,
):
# If the input is an AirbyteRecord, transform the record's data
# If the input is another type of Airbyte Message, return it as is
# If the input is a dict, transform it
if isinstance(message_or_record_data, AirbyteLogMessage) or isinstance(message_or_record_data, AirbyteTraceMessage):
return message_or_record_data
for transformation in self.transformations:
output_record = transformation.transform(record, config=config, stream_state=self.state, stream_slice=stream_slice)
transformation.transform(message_or_record_data, config=config, stream_state=self.state, stream_slice=stream_slice)

return output_record
return message_or_record_data

def get_json_schema(self) -> Mapping[str, Any]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from airbyte_cdk.models import AirbyteLogMessage, AirbyteTraceMessage, Level, SyncMode, TraceType
from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
from airbyte_cdk.sources.declarative.transformations import RecordTransformation
from airbyte_cdk.sources.declarative.transformations import AddFields, RecordTransformation
from airbyte_cdk.sources.declarative.transformations.add_fields import AddedFieldDefinition


def test_declarative_stream():
Expand Down Expand Up @@ -63,8 +64,66 @@ def test_declarative_stream():
assert stream.cursor_field == cursor_field
assert stream.stream_slices(sync_mode=SyncMode.incremental, cursor_field=cursor_field, stream_state=None) == stream_slices
for transformation in transformations:
assert len(transformation.transform.call_args_list) == len(records)
expected_calls = [
call(record, config=config, stream_slice=input_slice, stream_state=state) for record in records if isinstance(record, dict)
]
assert len(transformation.transform.call_args_list) == len(expected_calls)
transformation.transform.assert_has_calls(expected_calls, any_order=False)


def test_declarative_stream_with_add_fields_transform():
    """
    Read from a stream that is configured with an AddFields transformation.

    The mocked retriever emits two dict records followed by a log message and a
    trace message. The dict records are expected to come back with the added key,
    while the log and trace messages are expected to compare equal to the ones
    the retriever emitted.
    """
    stream_name = "stream"
    cursor_field = "created_at"
    config = {"api_key": "open_sesame"}

    # Schema loader stub that always hands back the same schema.
    expected_schema = {"name": {"type": "string"}}
    schema_loader = MagicMock()
    schema_loader.get_json_schema.return_value = expected_schema

    state = MagicMock()
    stream_slices = [
        {"date": "2021-01-01"},
        {"date": "2021-01-02"},
        {"date": "2021-01-03"},
    ]

    # Retriever stub: two plain records plus two non-record messages.
    retriever = MagicMock()
    retriever.state = state
    retriever.stream_slices.return_value = stream_slices
    retriever.read_records.return_value = [
        {"pk": 1234, "field": "value"},
        {"pk": 4567, "field": "different_value"},
        AirbyteLogMessage(level=Level.INFO, message="This is a log message"),
        AirbyteTraceMessage(type=TraceType.ERROR, emitted_at=12345),
    ]

    expected_records = [
        {"pk": 1234, "field": "value", "added_key": "added_value"},
        {"pk": 4567, "field": "different_value", "added_key": "added_value"},
        AirbyteLogMessage(level=Level.INFO, message="This is a log message"),
        AirbyteTraceMessage(type=TraceType.ERROR, emitted_at=12345),
    ]

    # A single AddFields transformation that adds a constant key/value pair.
    added_field = AddedFieldDefinition(path=["added_key"], value="added_value", parameters={})
    transformations = [AddFields(fields=[added_field], parameters={})]

    stream = DeclarativeStream(
        name=stream_name,
        primary_key="pk",
        stream_cursor_field="{{ parameters['cursor_field'] }}",
        schema_loader=schema_loader,
        retriever=retriever,
        config=config,
        transformations=transformations,
        parameters={"cursor_field": "created_at"},
    )

    assert stream.name == stream_name
    assert stream.get_json_schema() == expected_schema
    assert stream.state == state

    # Read the first slice and compare everything the stream yields.
    actual_records = list(stream.read_records(SyncMode.full_refresh, cursor_field, stream_slices[0], state))
    assert actual_records == expected_records
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,102 @@ def _create_page(response_body):
(_create_page({"rates": [{"ABC": 0}, {"AED": 1}],"_metadata": {"next": "next"}}), _create_page({"rates": [{"USD": 2}],"_metadata": {"next": "next"}})) * 10,
[{"ABC": 0}, {"AED": 1}],
[call({}, {}, None)]),
("test_read_manifest_with_added_fields",
{
"version": "0.34.2",
"type": "DeclarativeSource",
"check": {
"type": "CheckStream",
"stream_names": [
"Rates"
]
},
"streams": [
{
"type": "DeclarativeStream",
"name": "Rates",
"primary_key": [],
"schema_loader": {
"type": "InlineSchemaLoader",
"schema": {
"$schema": "http://json-schema.org/schema#",
"properties": {
"ABC": {
"type": "number"
},
"AED": {
"type": "number"
},
},
"type": "object"
}
},
"transformations": [
{
"type": "AddFields",
"fields": [
{
"type": "AddedFieldDefinition",
"path": ["added_field_key"],
"value": "added_field_value"
}
]
}
],
"retriever": {
"type": "SimpleRetriever",
"requester": {
"type": "HttpRequester",
"url_base": "https://api.apilayer.com",
"path": "/exchangerates_data/latest",
"http_method": "GET",
"request_parameters": {},
"request_headers": {},
"request_body_json": {},
"authenticator": {
"type": "ApiKeyAuthenticator",
"header": "apikey",
"api_token": "{{ config['api_key'] }}"
}
},
"record_selector": {
"type": "RecordSelector",
"extractor": {
"type": "DpathExtractor",
"field_path": [
"rates"
]
}
},
"paginator": {
"type": "NoPagination"
}
}
}
],
"spec": {
"connection_specification": {
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"required": [
"api_key"
],
"properties": {
"api_key": {
"type": "string",
"title": "API Key",
"airbyte_secret": True
}
},
"additionalProperties": True
},
"documentation_url": "https://example.org",
"type": "Spec"
}
},
(_create_page({"rates": [{"ABC": 0}, {"AED": 1}],"_metadata": {"next": "next"}}), _create_page({"rates": [{"USD": 2}],"_metadata": {"next": "next"}})) * 10,
[{"ABC": 0, "added_field_key": "added_field_value"}, {"AED": 1, "added_field_key": "added_field_value"}],
[call({}, {}, None)]),
("test_read_with_pagination_no_partitions",
{
"version": "0.34.2",
Expand Down

0 comments on commit 315e68f

Please sign in to comment.