"""Base class for SQL-type streams."""
from __future__ import annotations
import abc
import typing as t
from functools import cached_property
import sqlalchemy as sa
import singer_sdk.helpers._catalog as catalog
from singer_sdk._singerlib import CatalogEntry, MetadataMapping
from singer_sdk.connectors import SQLConnector
from singer_sdk.streams.core import REPLICATION_INCREMENTAL, Stream
if t.TYPE_CHECKING:
from singer_sdk.connectors.sql import FullyQualifiedName
from singer_sdk.helpers.types import Context
from singer_sdk.tap_base import Tap

class SQLStream(Stream, metaclass=abc.ABCMeta):
    """Base class for SQLAlchemy-based streams."""

    connector_class = SQLConnector
    _cached_schema: dict | None = None

    supports_nulls_first: bool = False
    """Whether the database supports the NULLS FIRST/LAST syntax."""
    def __init__(
        self,
        tap: Tap,
        catalog_entry: dict,
        connector: SQLConnector | None = None,
    ) -> None:
        """Initialize the database stream.

        If `connector` is omitted, a new connector will be created.

        Args:
            tap: The parent tap object.
            catalog_entry: Catalog entry dict.
            connector: Optional connector to reuse.
        """
        self._connector: SQLConnector
        self._connector = connector or self.connector_class(dict(tap.config))
        self.catalog_entry = catalog_entry
        super().__init__(
            tap=tap,
            schema=self.schema,
            name=self.tap_stream_id,
        )
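
    # Illustrative sketch (assumption, not part of this module): several streams
    # can share one connector, and hence one SQLAlchemy engine, by passing it in.
    # `tap`, `entries`, and the concrete subclass `MyTableStream` are assumed to
    # exist in the calling code.
    #
    #     connector = SQLConnector(dict(tap.config))
    #     streams = [
    #         MyTableStream(tap, entry, connector=connector) for entry in entries
    #     ]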
    @property
    def _singer_catalog_entry(self) -> CatalogEntry:
        """Return catalog entry as specified by the Singer catalog spec.

        Returns:
            A CatalogEntry object.
        """
        return CatalogEntry.from_dict(self.catalog_entry)

    @property
    def connector(self) -> SQLConnector:
        """Return a connector object.

        Returns:
            The connector object.
        """
        return self._connector

    @property
    def metadata(self) -> MetadataMapping:
        """Return the Singer metadata.

        Metadata from an input catalog will override standard metadata.

        Returns:
            Metadata object as specified in the Singer spec.
        """
        return self._singer_catalog_entry.metadata
    @cached_property
    def schema(self) -> dict:
        """Return the schema object (dict) as specified in the Singer spec.

        Returns:
            The schema object.
        """
        return self._singer_catalog_entry.schema.to_dict()
    @property
    def tap_stream_id(self) -> str:
        """Return the unique ID used by the tap to identify this stream.

        Generally, this is the same value as in `Stream.name`.

        In rare cases, such as for database types with multi-part names,
        this may be slightly different from `Stream.name`.

        Returns:
            The unique tap stream ID as a string.
        """
        return self._singer_catalog_entry.tap_stream_id

    @property
    def primary_keys(self) -> t.Sequence[str] | None:
        """Get primary keys from the catalog entry definition.

        Returns:
            A list of primary key(s) for the stream.
        """
        return self._singer_catalog_entry.metadata.root.table_key_properties or []
    @primary_keys.setter
    def primary_keys(self, new_value: t.Sequence[str]) -> None:
        """Set or reset the primary key(s) in the stream's catalog entry.

        Args:
            new_value: A list of one or more column names.
        """
        self._singer_catalog_entry.metadata.root.table_key_properties = new_value
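
    # Illustrative: a tap may override key properties discovered from the
    # catalog, e.g. for a view with no declared primary key (`stream` is a
    # hypothetical instance of a concrete subclass):
    #
    #     stream.primary_keys = ["id"]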
    @property
    def fully_qualified_name(self) -> FullyQualifiedName:
        """Generate the fully qualified version of the table name.

        Raises:
            ValueError: If the table name cannot be detected.

        Returns:
            The fully qualified name.
        """
        catalog_entry = self._singer_catalog_entry
        if not catalog_entry.table:
            msg = f"Missing table name in catalog entry: {catalog_entry.to_dict()}"
            raise ValueError(msg)

        return self.connector.get_fully_qualified_name(
            table_name=catalog_entry.table,
            schema_name=catalog_entry.metadata.root.schema_name,
            db_name=catalog_entry.database,
        )
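
    # Illustrative: for a catalog entry with database "analytics", schema
    # "public", and table "orders", most dialects would render the fully
    # qualified name roughly as a dotted identifier:
    #
    #     stream.fully_qualified_name  # -> analytics.public.orders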
    def get_selected_schema(self) -> dict:
        """Return a copy of the Stream JSON schema, dropping any fields not selected.

        Returns:
            A dictionary containing a copy of the Stream JSON schema, filtered
            to any selection criteria.
        """
        return catalog.get_selected_schema(
            stream_name=self.name,
            schema=self.schema,
            mask=self.mask,
        )
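
    # Illustrative: given a stream schema with properties "id" and "email", and
    # stream metadata that deselects "email", only "id" survives the selection
    # mask. The exact property types here are assumptions for the example.
    #
    #     stream.get_selected_schema()
    #     # -> {"type": "object", "properties": {"id": {"type": ["integer"]}}}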
    def get_records(self, context: Context | None) -> t.Iterable[dict[str, t.Any]]:
        """Return a generator of record-type dictionary objects.

        If the stream has a replication_key value defined, records will be sorted
        by the incremental key. If the stream also has an available starting
        bookmark, the records will be filtered for values greater than or equal
        to the bookmark value.

        Args:
            context: If partition context is provided, will read specifically
                from this data slice.

        Yields:
            One dict per record.

        Raises:
            NotImplementedError: If partition is passed in context and the stream
                does not support partitioning.
        """
        if context:
            msg = f"Stream '{self.name}' does not support partitioning."
            raise NotImplementedError(msg)

        selected_column_names = self.get_selected_schema()["properties"].keys()
        table = self.connector.get_table(
            full_table_name=self.fully_qualified_name,
            column_names=selected_column_names,
        )
        query = table.select()

        if self.replication_key:
            replication_key_col = table.columns[self.replication_key]
            order_by = (
                sa.nulls_first(replication_key_col.asc())
                if self.supports_nulls_first
                else replication_key_col.asc()
            )
            query = query.order_by(order_by)

            start_val = self.get_starting_replication_key_value(context)
            if start_val:
                query = query.where(replication_key_col >= start_val)

        if self.ABORT_AT_RECORD_COUNT is not None:
            # Limit the record count to one greater than the abort threshold.
            # This ensures a `MaxRecordsLimitException` is properly raised by the
            # caller `Stream._sync_records()` if more records are available than
            # can be processed.
            query = query.limit(self.ABORT_AT_RECORD_COUNT + 1)

        with self.connector._connect() as conn:  # noqa: SLF001
            for record in conn.execute(query).mappings():
                transformed_record = self.post_process(dict(record))
                if transformed_record is None:
                    # Record filtered out during post_process().
                    continue
                yield transformed_record
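
    # Illustrative: for a replication_key of "updated_at" and a starting
    # bookmark of "2024-01-01", the query built above is roughly equivalent to
    # the following SQL (table and column names are assumptions):
    #
    #     SELECT id, updated_at
    #     FROM analytics.public.orders
    #     WHERE updated_at >= '2024-01-01'
    #     ORDER BY updated_at ASC
    #     -- plus LIMIT <ABORT_AT_RECORD_COUNT + 1> when an abort threshold is set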
    @property
    def is_sorted(self) -> bool:
        """Expect stream to be sorted.

        When `True`, incremental streams will attempt to resume if unexpectedly
        interrupted.

        Returns:
            `True` if stream is sorted. Defaults to `False`.
        """
        return self.replication_method == REPLICATION_INCREMENTAL


__all__ = ["SQLConnector", "SQLStream"]
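
# Illustrative usage sketch (assumptions: `MyTap` is hypothetical, and the
# connector's `discover_catalog_entries()` is available as in the SDK's SQL
# connector API):
#
#     from singer_sdk import Tap
#
#     class MyTap(Tap):
#         name = "tap-mydb"
#
#         def discover_streams(self) -> list[Stream]:
#             connector = SQLConnector(dict(self.config))
#             return [
#                 SQLStream(self, entry, connector=connector)
#                 for entry in connector.discover_catalog_entries()
#             ]
#
# The SDK's `SQLTap` base class provides similar wiring out of the box.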