Skip to content

Commit

Permalink
chore: Standardized record and context types (#2415)
Browse files Browse the repository at this point in the history
chore: Standardize `record` and `context` types
  • Loading branch information
edgarrmondragon committed May 7, 2024
1 parent 4eb93e4 commit 86e249d
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 42 deletions.
74 changes: 44 additions & 30 deletions singer_sdk/streams/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import copy
import datetime
import json
import sys
import typing as t
from os import PathLike
from pathlib import Path
Expand Down Expand Up @@ -50,6 +51,11 @@
from singer_sdk.helpers._util import utc_now
from singer_sdk.mapper import RemoveRecordTransform, SameRecordTransform, StreamMap

if sys.version_info < (3, 10):
from typing_extensions import TypeAlias
else:
from typing import TypeAlias # noqa: ICN003

if t.TYPE_CHECKING:
import logging

Expand All @@ -62,6 +68,8 @@
REPLICATION_LOG_BASED = "LOG_BASED"

FactoryType = t.TypeVar("FactoryType", bound="Stream")
Record: TypeAlias = t.Dict[str, t.Any]
Context: TypeAlias = t.Dict


class Stream(metaclass=abc.ABCMeta): # noqa: PLR0904
Expand Down Expand Up @@ -227,7 +235,7 @@ def is_timestamp_replication_key(self) -> bool:

def get_starting_replication_key_value(
self,
context: dict | None,
context: Context | None,
) -> t.Any | None: # noqa: ANN401
"""Get starting replication key.
Expand All @@ -252,7 +260,9 @@ def get_starting_replication_key_value(
else None
)

def get_starting_timestamp(self, context: dict | None) -> datetime.datetime | None:
def get_starting_timestamp(
self, context: Context | None
) -> datetime.datetime | None:
"""Get starting replication timestamp.
Will return the value of the stream's replication key when `--state` is passed.
Expand Down Expand Up @@ -330,7 +340,7 @@ def descendent_streams(self) -> list[Stream]:

def _write_replication_key_signpost(
self,
context: dict | None,
context: Context | None,
value: datetime.datetime | str | int | float,
) -> None:
"""Write the signpost value, if available.
Expand Down Expand Up @@ -371,7 +381,7 @@ def compare_start_date(self, value: str, start_date_value: str) -> str:

return value

def _write_starting_replication_value(self, context: dict | None) -> None:
def _write_starting_replication_value(self, context: Context | None) -> None:
"""Write the starting replication value, if available.
Args:
Expand Down Expand Up @@ -399,7 +409,7 @@ def _write_starting_replication_value(self, context: dict | None) -> None:

def get_replication_key_signpost(
self,
context: dict | None, # noqa: ARG002
context: Context | None, # noqa: ARG002
) -> datetime.datetime | t.Any | None: # noqa: ANN401
"""Get the replication signpost.
Expand Down Expand Up @@ -646,7 +656,7 @@ def tap_state(self) -> dict:
"""
return self._tap_state

def get_context_state(self, context: dict | None) -> dict:
def get_context_state(self, context: Context | None) -> dict:
"""Return a writable state dict for the given context.
Gives a partitioned context state if applicable; else returns stream state.
Expand Down Expand Up @@ -701,7 +711,7 @@ def stream_state(self) -> dict:
# Partitions

@property
def partitions(self) -> list[dict] | None:
def partitions(self) -> list[Context] | None:
"""Get stream partitions.
Developers may override this property to provide a default partitions list.
Expand All @@ -724,9 +734,9 @@ def partitions(self) -> list[dict] | None:

def _increment_stream_state(
self,
latest_record: dict[str, t.Any],
latest_record: Record,
*,
context: dict | None = None,
context: Context | None = None,
) -> None:
"""Update state of stream or partition with data from the provided record.
Expand Down Expand Up @@ -817,7 +827,7 @@ def mask(self) -> singer.SelectionMask:

def _generate_record_messages(
self,
record: dict,
record: Record,
) -> t.Generator[singer.RecordMessage, None, None]:
"""Write out a RECORD message.
Expand Down Expand Up @@ -846,7 +856,7 @@ def _generate_record_messages(
time_extracted=utc_now(),
)

def _write_record_message(self, record: dict) -> None:
def _write_record_message(self, record: Record) -> None:
"""Write out a RECORD message.
Args:
Expand Down Expand Up @@ -963,7 +973,7 @@ def reset_state_progress_markers(self, state: dict | None = None) -> None:
state: State object to promote progress markers with.
"""
if state is None or state == {}:
context: dict | None
context: Context | None
for context in self.partitions or [{}]:
state = self.get_context_state(context or None)
reset_state_progress_markers(state)
Expand Down Expand Up @@ -992,7 +1002,7 @@ def finalize_state_progress_markers(self, state: dict | None = None) -> None:
for child_stream in self.child_streams or []:
child_stream.finalize_state_progress_markers()

context: dict | None
context: Context | None
for context in self.partitions or [{}]:
state = self.get_context_state(context or None)
self._finalize_state(state)
Expand All @@ -1005,9 +1015,9 @@ def finalize_state_progress_markers(self, state: dict | None = None) -> None:

def _process_record(
self,
record: dict,
child_context: dict | None = None,
partition_context: dict | None = None,
record: Record,
child_context: Context | None = None,
partition_context: Context | None = None,
) -> None:
"""Process a record.
Expand All @@ -1032,7 +1042,7 @@ def _process_record(

def _sync_records( # noqa: C901
self,
context: dict | None = None,
context: Context | None = None,
*,
write_messages: bool = True,
) -> t.Generator[dict, t.Any, t.Any]:
Expand All @@ -1054,7 +1064,7 @@ def _sync_records( # noqa: C901
timer = metrics.sync_timer(self.name)

record_index = 0
context_element: dict | None
context_element: Context | None
context_list: list[dict] | None
context_list = [context] if context is not None else self.partitions
selected = self.selected
Expand All @@ -1070,7 +1080,7 @@ def _sync_records( # noqa: C901
current_context,
)
self._write_starting_replication_value(current_context)
child_context: dict | None = (
child_context: Context | None = (
None if current_context is None else copy.copy(current_context)
)

Expand Down Expand Up @@ -1131,7 +1141,7 @@ def _sync_records( # noqa: C901
def _sync_batches(
self,
batch_config: BatchConfig,
context: dict | None = None,
context: Context | None = None,
) -> None:
"""Sync batches, emitting BATCH messages.
Expand All @@ -1148,7 +1158,7 @@ def _sync_batches(
# Public methods ("final", not recommended to be overridden)

@t.final
def sync(self, context: dict | None = None) -> None:
def sync(self, context: Context | None = None) -> None:
"""Sync this stream.
This method is internal to the SDK and should not need to be overridden.
Expand Down Expand Up @@ -1188,7 +1198,7 @@ def sync(self, context: dict | None = None) -> None:
)
raise

def _sync_children(self, child_context: dict | None) -> None:
def _sync_children(self, child_context: Context | None) -> None:
if child_context is None:
self.logger.warning(
"Context for child streams of '%s' is null, "
Expand Down Expand Up @@ -1223,7 +1233,7 @@ def apply_catalog(self, catalog: singer.Catalog) -> None:
if catalog_entry.replication_method:
self.forced_replication_method = catalog_entry.replication_method

def _get_state_partition_context(self, context: dict | None) -> dict | None:
def _get_state_partition_context(self, context: Context | None) -> dict | None:
"""Override state handling if Stream.state_partitioning_keys is specified.
Args:
Expand All @@ -1240,7 +1250,11 @@ def _get_state_partition_context(self, context: dict | None) -> dict | None:

return {k: v for k, v in context.items() if k in self.state_partitioning_keys}

def get_child_context(self, record: dict, context: dict | None) -> dict | None:
def get_child_context(
self,
record: Record,
context: Context | None,
) -> dict | None:
"""Return a child context object from the record and optional provided context.
By default, will return context if provided and otherwise the record dict.
Expand Down Expand Up @@ -1281,8 +1295,8 @@ def get_child_context(self, record: dict, context: dict | None) -> dict | None:

def generate_child_contexts(
self,
record: dict,
context: dict | None,
record: Record,
context: Context | None,
) -> t.Iterable[dict | None]:
"""Generate child contexts.
Expand All @@ -1300,7 +1314,7 @@ def generate_child_contexts(
@abc.abstractmethod
def get_records(
self,
context: dict | None,
context: Context | None,
) -> t.Iterable[dict | tuple[dict, dict | None]]:
"""Abstract record generator function. Must be overridden by the child class.
Expand Down Expand Up @@ -1346,7 +1360,7 @@ def get_batch_config(self, config: t.Mapping) -> BatchConfig | None: # noqa: PL
def get_batches(
self,
batch_config: BatchConfig,
context: dict | None = None,
context: Context | None = None,
) -> t.Iterable[tuple[BaseBatchFileEncoding, list[str]]]:
"""Batch generator function.
Expand All @@ -1371,8 +1385,8 @@ def get_batches(

def post_process( # noqa: PLR6301
self,
row: dict,
context: dict | None = None, # noqa: ARG002
row: Record,
context: Context | None = None, # noqa: ARG002
) -> dict | None:
"""As needed, append or transform raw data to match expected structure.
Expand Down
5 changes: 4 additions & 1 deletion singer_sdk/streams/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from singer_sdk.helpers._classproperty import classproperty
from singer_sdk.streams.rest import RESTStream

if t.TYPE_CHECKING:
from singer_sdk.streams.core import Context

_TToken = t.TypeVar("_TToken")


Expand Down Expand Up @@ -44,7 +47,7 @@ def query(self) -> str:

def prepare_request_payload(
self,
context: dict | None,
context: Context | None,
next_page_token: _TToken | None,
) -> dict | None:
"""Prepare the data payload for the GraphQL API request.
Expand Down
21 changes: 11 additions & 10 deletions singer_sdk/streams/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from backoff.types import Details

from singer_sdk._singerlib import Schema
from singer_sdk.streams.core import Context
from singer_sdk.tap_base import Tap

if sys.version_info >= (3, 10):
Expand Down Expand Up @@ -110,7 +111,7 @@ def _url_encode(val: str | datetime | bool | int | list[str]) -> str: # noqa: F
"""
return val.replace("/", "%2F") if isinstance(val, str) else str(val)

def get_url(self, context: dict | None) -> str:
def get_url(self, context: Context | None) -> str:
"""Get stream entity URL.
Developers override this method to perform dynamic URL generation.
Expand Down Expand Up @@ -245,7 +246,7 @@ def request_decorator(self, func: t.Callable) -> t.Callable:
def _request(
self,
prepared_request: requests.PreparedRequest,
context: dict | None,
context: Context | None,
) -> requests.Response:
"""TODO.
Expand All @@ -271,7 +272,7 @@ def _request(

def get_url_params( # noqa: PLR6301
self,
context: dict | None, # noqa: ARG002
context: Context | None, # noqa: ARG002
next_page_token: _TToken | None, # noqa: ARG002
) -> dict[str, t.Any] | str:
"""Return a dictionary or string of URL query parameters.
Expand Down Expand Up @@ -325,7 +326,7 @@ def build_prepared_request(

def prepare_request(
self,
context: dict | None,
context: Context | None,
next_page_token: _TToken | None,
) -> requests.PreparedRequest:
"""Prepare a request object for this stream.
Expand Down Expand Up @@ -357,7 +358,7 @@ def prepare_request(
json=request_data,
)

def request_records(self, context: dict | None) -> t.Iterable[dict]:
def request_records(self, context: Context | None) -> t.Iterable[dict]:
"""Request records from REST endpoint(s), returning response records.
If pagination is detected, pages will be recursed automatically.
Expand Down Expand Up @@ -403,7 +404,7 @@ def _write_request_duration_log(
self,
endpoint: str,
response: requests.Response,
context: dict | None,
context: Context | None,
extra_tags: dict | None,
) -> None:
"""TODO.
Expand Down Expand Up @@ -440,7 +441,7 @@ def update_sync_costs(
self,
request: requests.PreparedRequest,
response: requests.Response,
context: dict | None,
context: Context | None,
) -> dict[str, int]:
"""Update internal calculation of Sync costs.
Expand All @@ -465,7 +466,7 @@ def calculate_sync_cost( # noqa: PLR6301
self,
request: requests.PreparedRequest, # noqa: ARG002
response: requests.Response, # noqa: ARG002
context: dict | None, # noqa: ARG002
context: Context | None, # noqa: ARG002
) -> dict[str, int]:
"""Calculate the cost of the last API call made.
Expand Down Expand Up @@ -494,7 +495,7 @@ def calculate_sync_cost( # noqa: PLR6301

def prepare_request_payload(
self,
context: dict | None,
context: Context | None,
next_page_token: _TToken | None,
) -> dict | None:
"""Prepare the data payload for the REST API request.
Expand Down Expand Up @@ -560,7 +561,7 @@ def timeout(self) -> int:

# Records iterator

def get_records(self, context: dict | None) -> t.Iterable[dict[str, t.Any]]:
def get_records(self, context: Context | None) -> t.Iterable[dict[str, t.Any]]:
"""Return a generator of record-type dictionary objects.
Each record emitted should be a dictionary of property names to their values.
Expand Down
Loading

0 comments on commit 86e249d

Please sign in to comment.