Skip to content
This repository has been archived by the owner on May 17, 2024. It is now read-only.

Commit

Permalink
Merge pull request #323 from datafold/shared_conn
Browse files Browse the repository at this point in the history
connect(): Added support for shared connection; Database.is_closed property
  • Loading branch information
erezsh committed Dec 1, 2022
2 parents a0c7efe + 9b42e9b commit 2e36969
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 15 deletions.
5 changes: 5 additions & 0 deletions data_diff/sqeleton/abcs/database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,11 @@ def close(self):
"Close connection(s) to the database instance. Querying will stop functioning."
...

@property
@abstractmethod
def is_closed(self) -> bool:
"Return whether or not the connection has been closed"

@abstractmethod
def _normalize_table_path(self, path: DbPath) -> DbPath:
...
Expand Down
6 changes: 6 additions & 0 deletions data_diff/sqeleton/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ class Database(AbstractDatabase):
CONNECT_URI_KWPARAMS = []

_interactive = False
is_closed = False

@property
def name(self):
Expand Down Expand Up @@ -440,6 +441,10 @@ def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> lis
callback = partial(self._query_cursor, c)
return apply_query(callback, sql_code)

def close(self):
self.is_closed = True
return super().close()


class ThreadedDatabase(Database):
"""Access the database through singleton threads.
Expand Down Expand Up @@ -476,6 +481,7 @@ def create_connection(self):
...

def close(self):
super().close()
self._queue.shutdown()

@property
Expand Down
1 change: 1 addition & 0 deletions data_diff/sqeleton/databases/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def _query(self, sql_code: Union[str, ThreadLocalInterpreter]):
return apply_query(self._query_atom, sql_code)

def close(self):
super().close()
self._client.close()

def select_table_schema(self, path: DbPath) -> str:
Expand Down
26 changes: 21 additions & 5 deletions data_diff/sqeleton/databases/connect.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Type, List, Optional, Union, Dict
from itertools import zip_longest
import dsnparse
from contextlib import suppress

from runtype import dataclass

from ..utils import WeakCache
from .base import Database, ThreadedDatabase
from .postgresql import PostgreSQL
from .mysql import MySQL
Expand All @@ -19,12 +21,13 @@
from .duckdb import DuckDB



@dataclass
class MatchUriPath:
database_cls: Type[Database]
params: List[str]
kwparams: List[str] = []
help_str: str
help_str: str = "<unspecified>"

def __post_init__(self):
assert self.params == self.database_cls.CONNECT_URI_PARAMS, self.params
Expand Down Expand Up @@ -101,6 +104,7 @@ def __init__(self, database_by_scheme: Dict[str, Database]):
name: MatchUriPath(cls, cls.CONNECT_URI_PARAMS, cls.CONNECT_URI_KWPARAMS, help_str=cls.CONNECT_URI_HELP)
for name, cls in database_by_scheme.items()
}
self.conn_cache = WeakCache()

def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Database:
"""Connect to the given database uri
Expand Down Expand Up @@ -200,7 +204,7 @@ def _connection_created(self, db):
"Nop function to be overridden by subclasses."
return db

def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1) -> Database:
def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, shared: bool = True) -> Database:
"""Connect to a database using the given database configuration.
Configuration can be given either as a URI string, or as a dict of {option: value}.
Expand All @@ -213,6 +217,7 @@ def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1) -
Parameters:
db_conf (str | dict): The configuration for the database to connect. URI or dict.
thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1)
shared (bool): Whether to cache and return the same connection for the same db_conf. (default: True)
Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck.
Expand All @@ -235,8 +240,19 @@ def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1) -
>>> connect({"driver": "mysql", "host": "localhost", "database": "db"})
<data_diff.databases.mysql.MySQL object at 0x0000025DB3F94820>
"""
if shared:
with suppress(KeyError):
conn = self.conn_cache.get(db_conf)
if not conn.is_closed:
return conn

if isinstance(db_conf, str):
return self.connect_to_uri(db_conf, thread_count)
conn = self.connect_to_uri(db_conf, thread_count)
elif isinstance(db_conf, dict):
return self.connect_with_dict(db_conf, thread_count)
raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.")
conn = self.connect_with_dict(db_conf, thread_count)
else:
raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.")

if shared:
self.conn_cache.add(db_conf, conn)
return conn
1 change: 1 addition & 0 deletions data_diff/sqeleton/databases/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def _query(self, sql_code: Union[str, ThreadLocalInterpreter]):
return self._query_conn(self._conn, sql_code)

def close(self):
super().close()
self._conn.close()

def create_connection(self):
Expand Down
1 change: 1 addition & 0 deletions data_diff/sqeleton/databases/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def _query(self, sql_code: str) -> list:
return query_cursor(c, sql_code)

def close(self):
super().close()
self._conn.close()

def select_table_schema(self, path: DbPath) -> str:
Expand Down
1 change: 1 addition & 0 deletions data_diff/sqeleton/databases/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def __init__(self, *, schema: str, **kw):
self.default_schema = schema

def close(self):
super().close()
self._conn.close()

def _query(self, sql_code: Union[str, ThreadLocalInterpreter]):
Expand Down
26 changes: 26 additions & 0 deletions data_diff/sqeleton/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Union, Dict, Any, Hashable
from weakref import ref
from typing import TypeVar
from typing import Iterable, Iterator, MutableMapping, Union, Any, Sequence, Dict
from abc import abstractmethod
Expand All @@ -9,6 +11,30 @@
# -- Common --


class WeakCache:
def __init__(self):
self._cache = {}

def _hashable_key(self, k: Union[dict, Hashable]) -> Hashable:
if isinstance(k, dict):
return tuple(k.items())
return k

def add(self, key: Union[dict, Hashable], value: Any):
key = self._hashable_key(key)
self._cache[key] = ref(value)

def get(self, key: Union[dict, Hashable]) -> Any:
key = self._hashable_key(key)

value = self._cache[key]()
if value is None:
del self._cache[key]
raise KeyError(f"Key {key} not found, or no longer a valid reference")

return value


def join_iter(joiner: Any, iterable: Iterable) -> Iterable:
it = iter(iterable)
try:
Expand Down
18 changes: 8 additions & 10 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import arrow
from datetime import datetime

from data_diff import diff_tables, connect_to_table
from data_diff import diff_tables, connect_to_table, Algorithm
from data_diff.databases import MySQL
from data_diff.sqeleton.queries import table, commit

Expand Down Expand Up @@ -36,13 +36,17 @@ def setUp(self) -> None:
)

def test_api(self):
# test basic
t1 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_src_name)
t2 = connect_to_table(TEST_MYSQL_CONN_STRING, (self.table_dst_name,))
diff = list(diff_tables(t1, t2))
diff = list(diff_tables(t1, t2, algorithm=Algorithm.JOINDIFF))
assert len(diff) == 1

t1.database.close()
t2.database.close()
# test algorithm
# (also tests shared connection on connect_to_table)
for algo in (Algorithm.HASHDIFF, Algorithm.JOINDIFF):
diff = list(diff_tables(t1, t2, algorithm=algo))
assert len(diff) == 1

# test where
diff_id = diff[0][1][0]
Expand All @@ -53,9 +57,6 @@ def test_api(self):
diff = list(diff_tables(t1, t2))
assert len(diff) == 0

t1.database.close()
t2.database.close()

def test_api_get_stats_dict(self):
# XXX Likely to change in the future
expected_dict = {
Expand All @@ -76,6 +77,3 @@ def test_api_get_stats_dict(self):
self.assertEqual(expected_dict, output)
self.assertIsNotNone(diff)
assert len(list(diff)) == 1

t1.database.close()
t2.database.close()

0 comments on commit 2e36969

Please sign in to comment.