Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Refactor SQLConnector connection handling #1394

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 76 additions & 25 deletions singer_sdk/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from __future__ import annotations

import logging
import warnings
from contextlib import contextmanager
qbatten marked this conversation as resolved.
Show resolved Hide resolved
from datetime import datetime
from functools import lru_cache
from typing import Any, Iterable, cast
from typing import Any, Iterable, Iterator, cast

import sqlalchemy
from sqlalchemy.engine import Engine
Expand Down Expand Up @@ -34,6 +36,7 @@ class SQLConnector:
allow_column_alter: bool = False # Whether altering column types is supported.
allow_merge_upsert: bool = False # Whether MERGE UPSERT is supported.
allow_temp_tables: bool = True # Whether temp tables are supported.
_cached_engine: Engine | None = None

def __init__(
self, config: dict | None = None, sqlalchemy_url: str | None = None
Expand All @@ -46,7 +49,6 @@ def __init__(
"""
self._config: dict[str, Any] = config or {}
self._sqlalchemy_url: str | None = sqlalchemy_url or None
self._connection: sqlalchemy.engine.Connection | None = None

@property
def config(self) -> dict:
Expand All @@ -66,8 +68,17 @@ def logger(self) -> logging.Logger:
"""
return logging.getLogger("sqlconnector")

@contextmanager
def _connect(self) -> Iterator[sqlalchemy.engine.Connection]:
    """Open a connection from the connector's engine for a `with` block.

    The connection is obtained from `self._engine`, has the
    `stream_results=True` execution option applied, and is entered as a
    context manager; it is exited again when the caller's block ends.

    Yields:
        The SQLAlchemy connection to use inside the `with` block.
    """
    streaming_connection = self._engine.connect().execution_options(
        stream_results=True
    )
    with streaming_connection as conn:
        yield conn

def create_sqlalchemy_connection(self) -> sqlalchemy.engine.Connection:
"""Return a new SQLAlchemy connection using the provided config.
"""(DEPRECATED) Return a new SQLAlchemy connection using the provided config.

Do not use the SQLConnector's connection directly. Instead, if you need
to execute something that isn't available on the connector currently,
make a child class and add a method on that connector.

By default this will create using the sqlalchemy `stream_results=True` option
described here:
Expand All @@ -81,34 +92,49 @@ def create_sqlalchemy_connection(self) -> sqlalchemy.engine.Connection:
Returns:
A newly created SQLAlchemy connection object.
"""
return (
self.create_sqlalchemy_engine()
.connect()
.execution_options(stream_results=True)
warnings.warn(
"`SQLConnector.create_sqlalchemy_connection` is deprecated. "
"If you need to execute something that isn't available "
"on the connector currently, make a child class and "
"add your required method on that connector.",
DeprecationWarning,
)
return self._engine.connect().execution_options(stream_results=True)
qbatten marked this conversation as resolved.
Show resolved Hide resolved

def create_sqlalchemy_engine(self) -> Engine:
    """(DEPRECATED) Return the SQLAlchemy engine attached to this connector.

    This no longer builds an engine itself; it emits a deprecation warning
    and hands back whatever `self._engine` returns.

    Developers can generally override just one of the following:
    `_engine`, `sqlalchemy_url`.

    Returns:
        The engine exposed by `self._engine`.
    """
    warnings.warn(
        # The trailing space keeps the two adjacent literals from running
        # together as "Override`_engine`".
        "`SQLConnector.create_sqlalchemy_engine` is deprecated. Override "
        "`_engine` or `sqlalchemy_url` instead.",
        DeprecationWarning,
        # Attribute the warning to the caller, not to this method.
        stacklevel=2,
    )
    return self._engine
qbatten marked this conversation as resolved.
Show resolved Hide resolved

@property
def connection(self) -> sqlalchemy.engine.Connection:
    """(DEPRECATED) Return a SQLAlchemy connection object.

    Do not use the SQLConnector's connection directly. Instead, if you need
    to execute something that isn't available on the connector currently,
    make a child class and add a method on that connector.

    Every access delegates to `create_sqlalchemy_connection()`; no connection
    is stored on the connector by this property.

    Returns:
        The connection returned by `create_sqlalchemy_connection()`.
    """
    warnings.warn(
        "`SQLConnector.connection` is deprecated. If you need to execute something "
        "that isn't available on the connector currently, make a child "
        "class and add your required method on that connector.",
        DeprecationWarning,
        # Attribute the warning to the code that read `.connection`.
        stacklevel=2,
    )
    return self.create_sqlalchemy_connection()
qbatten marked this conversation as resolved.
Show resolved Hide resolved

@property
def sqlalchemy_url(self) -> str:
Expand Down Expand Up @@ -249,16 +275,37 @@ def _dialect(self) -> sqlalchemy.engine.Dialect:
Returns:
The dialect object.
"""
return cast(sqlalchemy.engine.Dialect, self.connection.engine.dialect)
return cast(sqlalchemy.engine.Dialect, self._engine.dialect)

@property
def _engine(self) -> Engine:
    """Return the engine object attached to this connector.

    This is the correct way to access the Connector's engine, if needed
    (e.g. to inspect tables). The engine is built with `create_engine()`
    the first time it is needed and kept on `_cached_engine`; later
    accesses return that stored engine.

    Returns:
        The SQLAlchemy Engine that's attached to this SQLConnector instance.
    """
    existing = self._cached_engine
    if existing:
        return cast(Engine, existing)

    # Nothing cached yet: build the engine once and remember it.
    self._cached_engine = self.create_engine()
    return cast(Engine, self._cached_engine)

def create_engine(self) -> Engine:
    """Build and return a brand-new SQLAlchemy engine.

    NOTE: This is an override hook, not an accessor. It is meant to be
    called only from the `_engine` property; code that needs the
    connector's engine should read `self._engine` instead.

    Tap/target developers can override this method on their SQLConnector
    subclass to perform custom engine creation logic.

    Returns:
        A new SQLAlchemy Engine created from `self.sqlalchemy_url`.
    """
    url = self.sqlalchemy_url
    return sqlalchemy.create_engine(url, echo=False)

def quote(self, name: str) -> str:
"""Quote a name if it needs quoting, using '.' as a name-part delimiter.
Expand Down Expand Up @@ -421,7 +468,7 @@ def discover_catalog_entries(self) -> list[dict]:
The discovered catalog entries as a list.
"""
result: list[dict] = []
engine = self.create_sqlalchemy_engine()
engine = self._engine
inspected = sqlalchemy.inspect(engine)
for schema_name in self.get_schema_names(engine, inspected):
# Iterate through each table and view
Expand Down Expand Up @@ -562,7 +609,8 @@ def create_schema(self, schema_name: str) -> None:
Args:
schema_name: The target schema to create.
"""
self._engine.execute(sqlalchemy.schema.CreateSchema(schema_name))
with self._connect() as conn:
conn.execute(sqlalchemy.schema.CreateSchema(schema_name))

def create_empty_table(
self,
Expand Down Expand Up @@ -635,7 +683,8 @@ def _create_empty_column(
column_add_ddl = self.get_column_add_ddl(
table_name=full_table_name, column_name=column_name, column_type=sql_type
)
self.connection.execute(column_add_ddl)
with self._connect() as conn:
conn.execute(column_add_ddl)

def prepare_schema(self, schema_name: str) -> None:
"""Create the target database schema.
Expand Down Expand Up @@ -723,7 +772,8 @@ def rename_column(self, full_table_name: str, old_name: str, new_name: str) -> N
column_rename_ddl = self.get_column_rename_ddl(
table_name=full_table_name, column_name=old_name, new_column_name=new_name
)
self.connection.execute(column_rename_ddl)
with self._connect() as conn:
conn.execute(column_rename_ddl)

def merge_sql_types(
self, sql_types: list[sqlalchemy.types.TypeEngine]
Expand Down Expand Up @@ -1027,4 +1077,5 @@ def _adapt_column_type(
column_name=column_name,
column_type=compatible_sql_type,
)
self.connection.execute(alter_column_ddl)
with self._connect() as conn:
conn.execute(alter_column_ddl)
61 changes: 60 additions & 1 deletion tests/core/test_connector_sql.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from unittest import mock

import pytest
import sqlalchemy
from sqlalchemy.dialects import sqlite

from singer_sdk.connectors import SQLConnector
from singer_sdk.exceptions import ConfigValidationError


def stringify(in_dict):
Expand All @@ -14,7 +17,7 @@ class TestConnectorSQL:

@pytest.fixture()
def connector(self):
    """Provide a SQLConnector configured with the `sqlite:///` URL."""
    sqlite_config = {"sqlalchemy_url": "sqlite:///"}
    return SQLConnector(config=sqlite_config)

@pytest.mark.parametrize(
"method_name,kwargs,context,unrendered_statement,rendered_statement",
Expand Down Expand Up @@ -130,3 +133,59 @@ def test_update_collation_non_text_type(self):
assert not hasattr(compatible_type, "collation")
# Check that we get the same type we put in
assert str(compatible_type) == "INTEGER"

def test_create_engine_returns_new_engine(self, connector):
    """Two `create_engine()` calls must produce two distinct objects."""
    first = connector.create_engine()
    second = connector.create_engine()
    assert second is not first

def test_engine_creates_and_returns_cached_engine(self, connector):
    """Reading `_engine` populates `_cached_engine` with the same object."""
    # Nothing is cached before the property is read for the first time.
    assert not connector._cached_engine
    from_property = connector._engine
    assert from_property is connector._cached_engine
qbatten marked this conversation as resolved.
Show resolved Hide resolved

def test_deprecated_functions_warn(self, connector):
    """Each deprecated accessor must emit a deprecation warning."""
    deprecated_accessors = (
        connector.create_sqlalchemy_engine,
        connector.create_sqlalchemy_connection,
        # `connection` is a property, so defer the attribute read.
        lambda: connector.connection,
    )
    for access in deprecated_accessors:
        with pytest.deprecated_call():
            access()

def test_connect_calls_engine_property(self, connector):
    """`_connect()` must take its connection from the `_engine` property.

    Renamed from `test_connect_calls_engine`: a later test in this class
    uses that exact name, which shadowed this one so it was never run.
    """
    # Replace the class-level `_engine` property with a mock engine.
    with mock.patch.object(SQLConnector, "_engine") as mock_engine:
        with connector._connect():
            mock_engine.connect.assert_called_once()

def test_connect_calls_engine(self, connector):
    """`_connect()` must call `connect()` on the connector's own engine."""
    engine = connector._engine
    with mock.patch.object(engine, "connect") as patched_connect:
        with connector._connect():
            patched_connect.assert_called_once()

def test_connect_raises_on_operational_failure(self, connector):
    """Errors raised by the database must propagate out of `_connect()`."""
    with pytest.raises(sqlalchemy.exc.OperationalError):
        with connector._connect() as conn:
            # Wrap the SQL in `text()`: SQLAlchemy 2.0 no longer accepts a
            # raw string here and would raise a different exception.
            conn.execute(sqlalchemy.text("SELECT * FROM fake_table"))

def test_rename_column_uses_connect_correctly(self, connector):
    """`rename_column()` must go through `_connect()` and the attached engine."""
    rename_args = ("fake_table", "old_name", "new_name")
    engine = connector._engine

    # Ends up using the attached engine
    with mock.patch.object(engine, "connect") as patched_engine_connect:
        connector.rename_column(*rename_args)
        patched_engine_connect.assert_called_once()

    # Uses the _connect method
    with mock.patch.object(connector, "_connect") as patched_connect_method:
        connector.rename_column(*rename_args)
        patched_connect_method.assert_called_once()

def test_get_sqlalchemy_url_raises_if_not_in_config(self, connector):
    """`get_sqlalchemy_url()` must reject an empty config."""
    with pytest.raises(ConfigValidationError):
        connector.get_sqlalchemy_url({})

def test_dialect_uses_engine(self, connector):
    """`_dialect` must equal the `dialect` of the attached engine."""
    engine = connector._engine
    # Swap in a mock dialect so equality can only hold if `_dialect`
    # really reads it from this engine.
    with mock.patch.object(engine, "dialect"):
        assert connector._dialect == engine.dialect