Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/firebolt/async_db/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ class _InternalType(Enum):
"""Enum of all internal firebolt types except for array."""

# INT, INTEGER
Int8 = "Int8"
UInt8 = "UInt8"
Int16 = "Int16"
UInt16 = "UInt16"
Int32 = "Int32"
UInt32 = "UInt32"
Expand Down Expand Up @@ -125,7 +127,9 @@ class _InternalType(Enum):
def python_type(self) -> type:
"""Convert internal type to python type."""
types = {
_InternalType.Int8: int,
_InternalType.UInt8: int,
_InternalType.Int16: int,
_InternalType.UInt16: int,
_InternalType.Int32: int,
_InternalType.UInt32: int,
Expand Down
12 changes: 5 additions & 7 deletions src/firebolt/db/connection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

from asyncio import new_event_loop
from functools import wraps
from inspect import cleandoc
from types import TracebackType
from typing import Any
from warnings import warn

from readerwriterlock.rwlock import RWLockWrite

Expand Down Expand Up @@ -38,7 +38,7 @@ class Connection(AsyncBaseConnection):
"""
)

__slots__ = AsyncBaseConnection.__slots__ + ("_closing_lock", "_loop")
__slots__ = AsyncBaseConnection.__slots__ + ("_closing_lock",)

cursor_class = Cursor

Expand All @@ -47,7 +47,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
# Holding this lock for write means that connection is closing itself.
# cursor() should hold this lock for read to read/write state
self._closing_lock = RWLockWrite()
self._loop = new_event_loop()

@wraps(AsyncBaseConnection.cursor)
def cursor(self) -> Cursor:
Expand All @@ -59,9 +58,7 @@ def cursor(self) -> Cursor:
@wraps(AsyncBaseConnection._aclose)
def close(self) -> None:
with self._closing_lock.gen_wlock():
if not self.closed:
self._loop.run_until_complete(self._aclose())
self._loop.close()
async_to_sync(self._aclose)()

# Context manager support
def __enter__(self) -> Connection:
Expand All @@ -75,7 +72,8 @@ def __exit__(
self.close()

def __del__(self) -> None:
self.close()
if not self.closed:
warn(f"Unclosed {self!r}", UserWarning)


connect = async_to_sync(async_connect_factory(Connection))
20 changes: 12 additions & 8 deletions tests/integration/dbapi/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,22 @@ def api_endpoint() -> str:
@fixture
def all_types_query() -> str:
return (
"select 1 as uint8, 258 as uint16, 80000 as uint32, -30000 as int32,"
"30000000000 as uint64, -30000000000 as int64, cast(1.23 AS FLOAT) as float32,"
" 1.2345678901234 as float64, 'text' as \"string\", "
"CAST('2021-03-28' AS DATE) as \"date\", "
'CAST(\'2019-07-31 01:01:01\' AS DATETIME) as "datetime", true as "bool",'
'[1,2,3,4] as "array", cast(null as int) as nullable'
"select 1 as uint8, -1 as int8, 257 as uint16, -257 as int16, 80000 as uint32,"
" -80000 as int32, 30000000000 as uint64, -30000000000 as int64, cast(1.23 AS"
" FLOAT) as float32, 1.2345678901234 as float64, 'text' as \"string\","
" CAST('2021-03-28' AS DATE) as \"date\", CAST('2019-07-31 01:01:01' AS"
' DATETIME) as "datetime", true as "bool",[1,2,3,4] as "array", cast(null as'
" int) as nullable"
)


@fixture
def all_types_query_description() -> List[Column]:
return [
Column("uint8", int, None, None, None, None, None),
Column("int8", int, None, None, None, None, None),
Column("uint16", int, None, None, None, None, None),
Column("int16", int, None, None, None, None, None),
Column("uint32", int, None, None, None, None, None),
Column("int32", int, None, None, None, None, None),
Column("uint64", int, None, None, None, None, None),
Expand All @@ -104,9 +106,11 @@ def all_types_query_response() -> List[ColType]:
return [
[
1,
258,
-1,
257,
-257,
80000,
-30000,
-80000,
30000000000,
-30000000000,
1.23,
Expand Down
8 changes: 6 additions & 2 deletions tests/integration/dbapi/sync/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
def connection(
engine_url: str, database_name: str, username: str, password: str, api_endpoint: str
) -> Connection:
return connect(
connection = connect(
engine_url=engine_url,
database=database_name,
username=username,
password=password,
api_endpoint=api_endpoint,
)
yield connection
connection.close()


@fixture
Expand All @@ -24,10 +26,12 @@ def connection_engine_name(
password: str,
api_endpoint: str,
) -> Connection:
return connect(
connection = connect(
engine_name=engine_name,
database=database_name,
username=username,
password=password,
api_endpoint=api_endpoint,
)
yield connection
connection.close()
48 changes: 24 additions & 24 deletions tests/integration/dbapi/sync/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,34 @@ def test_invalid_credentials(
engine_url: str, database_name: str, username: str, password: str, api_endpoint: str
) -> None:
"""Connection properly reacts to invalid credentials error"""
connection = connect(
with connect(
engine_url=engine_url,
database=database_name,
username=username + "_",
password=password + "_",
api_endpoint=api_endpoint,
)
with raises(AuthenticationError) as exc_info:
connection.cursor().execute("show tables")
) as connection:
with raises(AuthenticationError) as exc_info:
connection.cursor().execute("show tables")

assert str(exc_info.value).startswith(
"Failed to authenticate"
), "Invalid authentication error message"
assert str(exc_info.value).startswith(
"Failed to authenticate"
), "Invalid authentication error message"


def test_engine_url_not_exists(
engine_url: str, database_name: str, username: str, password: str, api_endpoint: str
) -> None:
"""Connection properly reacts to invalid engine url error"""
connection = connect(
with connect(
engine_url=engine_url + "_",
database=database_name,
username=username,
password=password,
api_endpoint=api_endpoint,
)
with raises(ConnectError):
connection.cursor().execute("show tables")
) as connection:
with raises(ConnectError):
connection.cursor().execute("show tables")


def test_engine_name_not_exists(
Expand All @@ -54,14 +54,14 @@ def test_engine_name_not_exists(
) -> None:
"""Connection properly reacts to invalid engine name error"""
with raises(FireboltEngineError):
connection = connect(
with connect(
engine_name=engine_name + "_________",
database=database_name,
username=username,
password=password,
api_endpoint=api_endpoint,
)
connection.cursor().execute("show tables")
) as connection:
connection.cursor().execute("show tables")


def test_engine_stopped(
Expand All @@ -73,34 +73,34 @@ def test_engine_stopped(
) -> None:
"""Connection properly reacts to engine not running error"""
with raises(EngineNotRunningError):
connection = connect(
with connect(
engine_url=stopped_engine_url,
database=database_name,
username=username,
password=password,
api_endpoint=api_endpoint,
)
connection.cursor().execute("show tables")
) as connection:
connection.cursor().execute("show tables")


def test_database_not_exists(
engine_url: str, database_name: str, username: str, password: str, api_endpoint: str
) -> None:
"""Connection properly reacts to invalid database error"""
new_db_name = database_name + "_"
connection = connect(
with connect(
engine_url=engine_url,
database=new_db_name,
username=username,
password=password,
api_endpoint=api_endpoint,
)
with raises(FireboltDatabaseError) as exc_info:
connection.cursor().execute("show tables")
) as connection:
with raises(FireboltDatabaseError) as exc_info:
connection.cursor().execute("show tables")

assert (
str(exc_info.value) == f"Database {new_db_name} does not exist"
), "Invalid database name error message"
assert (
str(exc_info.value) == f"Database {new_db_name} does not exist"
), "Invalid database name error message"


def test_sql_error(connection: Connection) -> None:
Expand Down
36 changes: 25 additions & 11 deletions tests/unit/db/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Callable, List

from httpx import codes
from pytest import raises
from pytest import raises, warns
from pytest_httpx import HTTPXMock

from firebolt.async_db._types import ColType
Expand Down Expand Up @@ -54,24 +54,28 @@ def test_cursor_initialized(
httpx_mock.add_callback(query_callback, url=query_url)

for url in (settings.server, f"https://{settings.server}"):
connection = connect(
with connect(
engine_url=url,
database=db_name,
username="u",
password="p",
api_endpoint=settings.server,
)
) as connection:

cursor = connection.cursor()
assert cursor.connection == connection, "Invalid cursor connection attribute"
assert cursor._client == connection._client, "Invalid cursor _client attribute"
cursor = connection.cursor()
assert (
cursor.connection == connection
), "Invalid cursor connection attribute"
assert (
cursor._client == connection._client
), "Invalid cursor _client attribute"

assert cursor.execute("select*") == len(python_query_data)
assert cursor.execute("select*") == len(python_query_data)

cursor.close()
assert (
cursor not in connection._cursors
), "Cursor wasn't removed from connection after close"
cursor.close()
assert (
cursor not in connection._cursors
), "Cursor wasn't removed from connection after close"


def test_connect_empty_parameters():
Expand Down Expand Up @@ -154,3 +158,13 @@ def test_connect_engine_name(
api_endpoint=settings.server,
) as connection:
assert connection.cursor().execute("select*") == len(python_query_data)


def test_connection_unclosed_warnings():
c = Connection("", "", "", "", "")
with warns(UserWarning) as winfo:
del c

assert "Unclosed" in str(
winfo.list[0].message
), "Invalid unclosed connection warning"
4 changes: 2 additions & 2 deletions tests/unit/service/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ def test_get_connection(
manager = ResourceManager(settings=settings)
engine = manager.engines.create(name=engine_name)

connection = engine.get_connection()
assert connection
with engine.get_connection() as connection:
assert connection


def test_attach_to_database(
Expand Down