Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 121 additions & 3 deletions examples/dbapi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
"outputs": [],
"source": [
"from firebolt.db import connect\n",
"from firebolt.client import DEFAULT_API_URL"
"from firebolt.client import DEFAULT_API_URL\n",
"from datetime import datetime"
]
},
{
Expand All @@ -36,7 +37,6 @@
"source": [
"# Only one of these two parameters should be specified\n",
"engine_url = \"\"\n",
"engine_name = \"\"\n",
"assert bool(engine_url) != bool(\n",
" engine_name\n",
"), \"Specify only one of engine_name and engine_url\"\n",
Expand Down Expand Up @@ -98,8 +98,32 @@
" \"insert into test_table values (1, 'hello', '2021-01-01 01:01:01'),\"\n",
" \"(2, 'world', '2022-02-02 02:02:02'),\"\n",
" \"(3, '!', '2023-03-03 03:03:03')\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "b356295a",
"metadata": {},
"source": [
"### Parameterized query"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "929f5221",
"metadata": {},
"outputs": [],
"source": [
"cursor.execute(\n",
" \"insert into test_table values (?, ?, ?)\",\n",
" (3, \"single parameter set\", datetime.now()),\n",
")\n",
"cursor.execute(\"select * from test_table\")"
"cursor.executemany(\n",
" \"insert into test_table values (?, ?, ?)\",\n",
" ((4, \"multiple\", datetime.now()), (5, \"parameter sets\", datetime.fromtimestamp(0))),\n",
")"
]
},
{
Expand All @@ -117,6 +141,7 @@
"metadata": {},
"outputs": [],
"source": [
"cursor.execute(\"select * from test_table\")\n",
"print(\"Description: \", cursor.description)\n",
"print(\"Rowcount: \", cursor.rowcount)"
]
Expand All @@ -141,6 +166,67 @@
"print(cursor.fetchall())"
]
},
{
"cell_type": "markdown",
"id": "efc4ff0a",
"metadata": {},
"source": [
"## Multi-statement queries"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "744817b1",
"metadata": {},
"outputs": [],
"source": [
"cursor.execute(\n",
" \"\"\"\n",
" select * from test_table where id < 4;\n",
" select * from test_table where id > 2;\n",
"\"\"\"\n",
")\n",
"print(cursor._row_sets[0][2])\n",
"print(cursor._row_sets[1][2])\n",
"print(cursor._rows)\n",
"# print(\"First query: \", cursor.fetchall())\n",
"assert cursor.nextset()\n",
"print(cursor._rows)\n",
"# print(\"Secont query: \", cursor.fetchall())\n",
"assert cursor.nextset() is None"
]
},
{
"cell_type": "markdown",
"id": "02e5db2f",
"metadata": {},
"source": [
"### Error handling\n",
"If one query fails during the execution, all remaining queries are canceled.\n",
"However, you still can fetch results for successful queries"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "888500a9",
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" cursor.execute(\n",
" \"\"\"\n",
" select * from test_table where id < 4;\n",
" select * from test_table where wrong_field > 2;\n",
" select * from test_table\n",
" \"\"\"\n",
" )\n",
"except:\n",
" pass\n",
"cursor.fetchall()"
]
},
{
"cell_type": "markdown",
"id": "b1cd4ff2",
Expand Down Expand Up @@ -286,6 +372,38 @@
"source": [
"await print_results(async_cursor)"
]
},
{
"cell_type": "markdown",
"id": "da36dd3f",
"metadata": {},
"source": [
"### Closing connection"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "83fc1686",
"metadata": {},
"outputs": [],
"source": [
"# manually\n",
"connection.close()\n",
"\n",
"# using context manager\n",
"with connect(\n",
" engine_url=engine_url,\n",
" engine_name=engine_name,\n",
" database=database_name,\n",
" username=username,\n",
" password=password,\n",
" api_endpoint=api_endpoint,\n",
") as conn:\n",
" # create cursors, perform database queries\n",
" pass\n",
"conn.closed"
]
}
],
"metadata": {
Expand Down
8 changes: 4 additions & 4 deletions src/firebolt/async_db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import socket
from json import JSONDecodeError
from types import TracebackType
from typing import Callable, List, Optional, Type
from typing import Any, Callable, List, Optional, Type

from httpcore.backends.auto import AutoBackend
from httpcore.backends.base import AsyncNetworkStream
Expand Down Expand Up @@ -207,15 +207,15 @@ def __init__(
self._cursors: List[BaseCursor] = []
self._is_closed = False

def cursor(self) -> BaseCursor:
def _cursor(self, **kwargs: Any) -> BaseCursor:
"""
Create new cursor object.
"""

if self.closed:
raise ConnectionClosedError("Unable to create cursor: connection closed")

c = self.cursor_class(self._client, self)
c = self.cursor_class(self._client, self, **kwargs)
self._cursors.append(c)
return c

Expand Down Expand Up @@ -279,7 +279,7 @@ class Connection(BaseConnection):
aclose = BaseConnection._aclose

def cursor(self) -> Cursor:
c = super().cursor()
c = super()._cursor()
assert isinstance(c, Cursor) # typecheck
return c

Expand Down
68 changes: 63 additions & 5 deletions src/firebolt/common/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
from asyncio import get_event_loop, new_event_loop, set_event_loop
from asyncio import (
AbstractEventLoop,
get_event_loop,
new_event_loop,
set_event_loop,
)
from functools import lru_cache, wraps
from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar
from threading import Thread
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Optional,
Type,
TypeVar,
)

T = TypeVar("T")

Expand Down Expand Up @@ -37,15 +51,59 @@ def fix_url_schema(url: str) -> str:
return url if url.startswith("http") else f"https://{url}"


def async_to_sync(f: Callable) -> Callable:
class AsyncJobThread:
"""
Thread runner that allows running async tasks syncronously in a separate thread.
Caches loop to be reused in all threads
It allows running async functions syncronously inside a running event loop.
Since nesting loops is not allowed, we create a separate thread for a new event loop
"""

def __init__(self) -> None:
self.loop: Optional[AbstractEventLoop] = None
self.result: Optional[Any] = None
self.exception: Optional[BaseException] = None

def _initialize_loop(self) -> None:
if not self.loop:
try:
# despite the docs, this function fails if no loop is set
self.loop = get_event_loop()
except RuntimeError:
self.loop = new_event_loop()
set_event_loop(self.loop)

def run(self, coro: Coroutine) -> None:
try:
self._initialize_loop()
assert self.loop is not None
self.result = self.loop.run_until_complete(coro)
except BaseException as e:
self.exception = e

def execute(self, coro: Coroutine) -> Any:
thread = Thread(target=self.run, args=[coro])
thread.start()
thread.join()
if self.exception:
raise self.exception
return self.result


def async_to_sync(f: Callable, async_job_thread: AsyncJobThread = None) -> Callable:
@wraps(f)
def sync(*args: Any, **kwargs: Any) -> Any:
try:
loop = get_event_loop()
except RuntimeError:
loop = new_event_loop()
set_event_loop(loop)
res = loop.run_until_complete(f(*args, **kwargs))
return res
# We are inside a running loop
if loop.is_running():
nonlocal async_job_thread
if not async_job_thread:
async_job_thread = AsyncJobThread()
return async_job_thread.execute(f(*args, **kwargs))
return loop.run_until_complete(f(*args, **kwargs))

return sync
10 changes: 5 additions & 5 deletions src/firebolt/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from firebolt.async_db.connection import BaseConnection as AsyncBaseConnection
from firebolt.async_db.connection import async_connect_factory
from firebolt.common.exception import ConnectionClosedError
from firebolt.common.util import async_to_sync
from firebolt.common.util import AsyncJobThread, async_to_sync
from firebolt.db.cursor import Cursor

DEFAULT_TIMEOUT_SECONDS: int = 5
Expand All @@ -33,7 +33,7 @@ class Connection(AsyncBaseConnection):
are not implemented.
"""

__slots__ = AsyncBaseConnection.__slots__ + ("_closing_lock",)
__slots__ = AsyncBaseConnection.__slots__ + ("_closing_lock", "_async_job_thread")

cursor_class = Cursor

Expand All @@ -42,18 +42,18 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
# Holding this lock for write means that connection is closing itself.
# cursor() should hold this lock for read to read/write state
self._closing_lock = RWLockWrite()
self._async_job_thread = AsyncJobThread()

@wraps(AsyncBaseConnection.cursor)
def cursor(self) -> Cursor:
with self._closing_lock.gen_rlock():
c = super().cursor()
c = super()._cursor(async_job_thread=self._async_job_thread)
assert isinstance(c, Cursor) # typecheck
return c

@wraps(AsyncBaseConnection._aclose)
def close(self) -> None:
with self._closing_lock.gen_wlock():
async_to_sync(self._aclose)()
async_to_sync(self._aclose, self._async_job_thread)()

# Context manager support
def __enter__(self) -> Connection:
Expand Down
17 changes: 13 additions & 4 deletions src/firebolt/db/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
check_not_closed,
check_query_executed,
)
from firebolt.common.util import async_to_sync
from firebolt.common.util import AsyncJobThread, async_to_sync


class Cursor(AsyncBaseCursor):
Expand All @@ -31,11 +31,16 @@ class Cursor(AsyncBaseCursor):
with :py:func:`fetchmany` method
"""

__slots__ = AsyncBaseCursor.__slots__ + ("_query_lock", "_idx_lock")
__slots__ = AsyncBaseCursor.__slots__ + (
"_query_lock",
"_idx_lock",
"_async_job_thread",
)

def __init__(self, *args: Any, **kwargs: Any) -> None:
self._query_lock = RWLockWrite()
self._idx_lock = Lock()
self._async_job_thread: AsyncJobThread = kwargs.pop("async_job_thread")
super().__init__(*args, **kwargs)

@wraps(AsyncBaseCursor.execute)
Expand All @@ -46,14 +51,18 @@ def execute(
set_parameters: Optional[Dict] = None,
) -> int:
with self._query_lock.gen_wlock():
return async_to_sync(super().execute)(query, parameters, set_parameters)
return async_to_sync(super().execute, self._async_job_thread)(
query, parameters, set_parameters
)

@wraps(AsyncBaseCursor.executemany)
def executemany(
self, query: str, parameters_seq: Sequence[Sequence[ParameterType]]
) -> int:
with self._query_lock.gen_wlock():
return async_to_sync(super().executemany)(query, parameters_seq)
return async_to_sync(super().executemany, self._async_job_thread)(
query, parameters_seq
)

@wraps(AsyncBaseCursor._get_next_range)
def _get_next_range(self, size: int) -> Tuple[int, int]:
Expand Down
Loading