Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Refactor SQLConnector connection handling #1394

Merged
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

84 changes: 59 additions & 25 deletions singer_sdk/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from __future__ import annotations

import logging
import warnings
from contextlib import contextmanager
qbatten marked this conversation as resolved.
Show resolved Hide resolved
from datetime import datetime
from functools import lru_cache
from typing import Any, Iterable, cast
from typing import Any, Iterable, Iterator, cast

import sqlalchemy
from sqlalchemy.engine import Engine
Expand Down Expand Up @@ -34,6 +36,7 @@ class SQLConnector:
allow_column_alter: bool = False # Whether altering column types is supported.
allow_merge_upsert: bool = False # Whether MERGE UPSERT is supported.
allow_temp_tables: bool = True # Whether temp tables are supported.
_cached_engine: Engine | None = None

def __init__(
self, config: dict | None = None, sqlalchemy_url: str | None = None
Expand All @@ -46,7 +49,6 @@ def __init__(
"""
self._config: dict[str, Any] = config or {}
self._sqlalchemy_url: str | None = sqlalchemy_url or None
self._connection: sqlalchemy.engine.Connection | None = None

@property
def config(self) -> dict:
Expand All @@ -66,8 +68,17 @@ def logger(self) -> logging.Logger:
"""
return logging.getLogger("sqlconnector")

@contextmanager
def _connect(self) -> Iterator[sqlalchemy.engine.Connection]:
with self._engine.connect().execution_options(stream_results=True) as conn:
yield conn

def create_sqlalchemy_connection(self) -> sqlalchemy.engine.Connection:
"""Return a new SQLAlchemy connection using the provided config.
"""(DEPRECATED) Return a new SQLAlchemy connection using the provided config.

Do not use the SQLConnector's connection directly. Instead, if you need
to execute something that isn't available on the connector currently,
make a child class and add a method on that connector.

By default this will create using the sqlalchemy `stream_results=True` option
described here:
Expand All @@ -81,34 +92,49 @@ def create_sqlalchemy_connection(self) -> sqlalchemy.engine.Connection:
Returns:
A newly created SQLAlchemy connection object.
"""
return (
self.create_sqlalchemy_engine()
.connect()
.execution_options(stream_results=True)
warnings.warn(
"`create_sqlalchemy_connection` is deprecated. "
qbatten marked this conversation as resolved.
Show resolved Hide resolved
"If you need to execute something that isn't available "
"on the connector currently, make a child class and "
"add your required method on that connector.",
DeprecationWarning,
)
return self._engine.connect().execution_options(stream_results=True)
qbatten marked this conversation as resolved.
Show resolved Hide resolved

def create_sqlalchemy_engine(self) -> sqlalchemy.engine.Engine:
"""Return a new SQLAlchemy engine using the provided config.
def create_sqlalchemy_engine(self) -> Engine:
"""(DEPRECATED) Return a new SQLAlchemy engine using the provided config.

Developers can generally override just one of the following:
`sqlalchemy_engine`, `sqlalchemy_url`.

Returns:
A newly created SQLAlchemy engine object.
"""
return sqlalchemy.create_engine(self.sqlalchemy_url, echo=False)
warnings.warn(
"`create_sqlalchemy_engine` is deprecated. Override"
qbatten marked this conversation as resolved.
Show resolved Hide resolved
"`_engine` or sqlalchemy_url` instead.",
DeprecationWarning,
)
return self._engine
qbatten marked this conversation as resolved.
Show resolved Hide resolved

@property
def connection(self) -> sqlalchemy.engine.Connection:
"""Return or set the SQLAlchemy connection object.
"""(DEPRECATED) Return or set the SQLAlchemy connection object.

Do not use the SQLConnector's connection directly. Instead, if you need
to execute something that isn't available on the connector currently,
make a child class and add a method on that connector.

Returns:
The active SQLAlchemy connection object.
"""
if not self._connection:
self._connection = self.create_sqlalchemy_connection()

return self._connection
warnings.warn(
"`connection` is deprecated. If you need to execute something "
qbatten marked this conversation as resolved.
Show resolved Hide resolved
"that isn't available on the connector currently, make a child "
"class and add your required method on that connector.",
DeprecationWarning,
)
return self.create_sqlalchemy_connection()
qbatten marked this conversation as resolved.
Show resolved Hide resolved

@property
def sqlalchemy_url(self) -> str:
Expand Down Expand Up @@ -249,16 +275,20 @@ def _dialect(self) -> sqlalchemy.engine.Dialect:
Returns:
The dialect object.
"""
return cast(sqlalchemy.engine.Dialect, self.connection.engine.dialect)
return cast(sqlalchemy.engine.Dialect, self._engine.dialect)

@property
def _engine(self) -> sqlalchemy.engine.Engine:
"""Return the dialect object.
def _engine(self) -> Engine:
"""Return the engine object.

Returns:
The dialect object.
The SQLAlchemy Engine that's attached to this SQLConnector instance.
"""
return cast(sqlalchemy.engine.Engine, self.connection.engine)
if not self._cached_engine:
self._cached_engine = sqlalchemy.create_engine(
self.sqlalchemy_url, echo=False
)
return cast(Engine, self._cached_engine)

def quote(self, name: str) -> str:
"""Quote a name if it needs quoting, using '.' as a name-part delimiter.
Expand Down Expand Up @@ -421,7 +451,7 @@ def discover_catalog_entries(self) -> list[dict]:
The discovered catalog entries as a list.
"""
result: list[dict] = []
engine = self.create_sqlalchemy_engine()
engine = self._engine
inspected = sqlalchemy.inspect(engine)
for schema_name in self.get_schema_names(engine, inspected):
# Iterate through each table and view
Expand Down Expand Up @@ -562,7 +592,8 @@ def create_schema(self, schema_name: str) -> None:
Args:
schema_name: The target schema to create.
"""
self._engine.execute(sqlalchemy.schema.CreateSchema(schema_name))
with self._connect() as conn:
conn.execute(sqlalchemy.schema.CreateSchema(schema_name))

def create_empty_table(
self,
Expand Down Expand Up @@ -635,7 +666,8 @@ def _create_empty_column(
column_add_ddl = self.get_column_add_ddl(
table_name=full_table_name, column_name=column_name, column_type=sql_type
)
self.connection.execute(column_add_ddl)
with self._connect() as conn:
conn.execute(column_add_ddl)

def prepare_schema(self, schema_name: str) -> None:
"""Create the target database schema.
Expand Down Expand Up @@ -723,7 +755,8 @@ def rename_column(self, full_table_name: str, old_name: str, new_name: str) -> N
column_rename_ddl = self.get_column_rename_ddl(
table_name=full_table_name, column_name=old_name, new_column_name=new_name
)
self.connection.execute(column_rename_ddl)
with self._connect() as conn:
conn.execute(column_rename_ddl)

def merge_sql_types(
self, sql_types: list[sqlalchemy.types.TypeEngine]
Expand Down Expand Up @@ -1027,4 +1060,5 @@ def _adapt_column_type(
column_name=column_name,
column_type=compatible_sql_type,
)
self.connection.execute(alter_column_ddl)
with self._connect() as conn:
conn.execute(alter_column_ddl)