Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,36 @@ for item in result.fetchall():
print(item)
```

### [AsyncIO](https://docs.sqlalchemy.org/en/14/orm/extensions/asyncio.html) extension

```python
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine
from firebolt_db.firebolt_async_dialect import AsyncFireboltDialect
from sqlalchemy.dialects import registry

registry.register("firebolt", "src.firebolt_db.firebolt_async_dialect", "AsyncFireboltDialect")
engine = create_async_engine("firebolt://email@domain:password@sample_database/sample_engine")

async with engine.connect() as conn:

await conn.execute(
text(f"INSERT INTO example(dummy) VALUES (11)")
)

result = await conn.execute(
text(f"SELECT * FROM example")
)
print(result.fetchall())

await engine.dispose()
```


## Limitations

1. Transactions are not supported since Firebolt database does not support them at this time.
1. [AsyncIO](https://docs.sqlalchemy.org/en/14/orm/extensions/asyncio.html) is not yet implemented.
1. Parametrised calls to execute and executemany are not implemented.

## Contributing

Expand Down
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ where = src
[options.extras_require]
dev =
devtools==0.7.0
mock==4.0.3
mypy==0.910
pre-commit==2.15.0
pytest==6.2.5
sqlalchemy-stubs
sqlalchemy-stubs==0.4

[mypy]
disallow_untyped_defs = True
Expand Down
14 changes: 0 additions & 14 deletions src/firebolt_db/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,3 @@
from firebolt.common.exception import (
DatabaseError,
DataError,
Error,
IntegrityError,
InterfaceError,
InternalError,
NotSupportedError,
OperationalError,
ProgrammingError,
Warning,
)
from firebolt.db import connect

__all__ = [
"connect",
"apilevel",
Expand Down
157 changes: 157 additions & 0 deletions src/firebolt_db/firebolt_async_dialect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
from __future__ import annotations

from asyncio import Lock
from types import ModuleType
from typing import Any, Iterator, List, Optional, Tuple

import firebolt.async_db as async_dbapi
from firebolt.async_db import Connection

# Ignoring type since sqlalchemy-stubs doesn't cover AdaptedConnection
from sqlalchemy.engine import AdaptedConnection # type: ignore[attr-defined]
from sqlalchemy.util.concurrency import await_only

from firebolt_db.firebolt_dialect import FireboltDialect


class AsyncCursorWrapper:
__slots__ = (
"_adapt_connection",
"_connection",
"await_",
"_cursor",
"_rows",
)

server_side = False

def __init__(self, adapt_connection: AsyncConnectionWrapper):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
self._rows: List[List] = []
self._cursor = self._connection.cursor()

def close(self) -> None:
self._rows[:] = []
self._cursor.close()

@property
def description(self) -> str:
return self._cursor.description

@property
def arraysize(self) -> int:
return self._cursor.arraysize

@arraysize.setter
def arraysize(self, value: int) -> None:
self._cursor.arraysize = value

@property
def rowcount(self) -> int:
return self._cursor.rowcount

def execute(self, operation: str, parameters: Optional[Tuple] = None) -> None:
self.await_(self._execute(operation, parameters))

async def _execute(
self, operation: str, parameters: Optional[Tuple] = None
) -> None:
async with self._adapt_connection._execute_mutex:
await self._cursor.execute(operation, parameters)
if self._cursor.description:
self._rows = await self._cursor.fetchall()
else:
self._rows = []

def executemany(self, operation: str, seq_of_parameters: List[Tuple]) -> None:
raise NotImplementedError("executemany is not supported yet")

def __iter__(self) -> Iterator[List]:
while self._rows:
yield self._rows.pop(0)

def fetchone(self) -> Optional[List]:
if self._rows:
return self._rows.pop(0)
else:
return None

def fetchmany(self, size: int = None) -> List[List]:
if size is None:
size = self._cursor.arraysize

retval = self._rows[0:size]
self._rows[:] = self._rows[size:]
return retval

def fetchall(self) -> List[List]:
retval = self._rows[:]
self._rows[:] = []
return retval


class AsyncConnectionWrapper(AdaptedConnection):
await_ = staticmethod(await_only)
__slots__ = ("dbapi", "_connection", "_execute_mutex")

def __init__(self, dbapi: AsyncAPIWrapper, connection: Connection):
self.dbapi = dbapi
self._connection = connection
self._execute_mutex = Lock()

def cursor(self) -> AsyncCursorWrapper:
return AsyncCursorWrapper(self)

def rollback(self) -> None:
pass

def commit(self) -> None:
self._connection.commit()

def close(self) -> None:
self.await_(self._connection._aclose())


class AsyncAPIWrapper(ModuleType):
"""Wrapper around Firebolt async dbapi that returns a similar wrapper for
Cursor on connect()"""

def __init__(self, dbapi: ModuleType):
self.dbapi = dbapi
self.paramstyle = dbapi.paramstyle # type: ignore[attr-defined] # noqa: F821
self._init_dbapi_attributes()

def _init_dbapi_attributes(self) -> None:
for name in (
"DatabaseError",
"Error",
"IntegrityError",
"NotSupportedError",
"OperationalError",
"ProgrammingError",
):
setattr(self, name, getattr(self.dbapi, name))

def connect(self, *arg: Any, **kw: Any) -> AsyncConnectionWrapper:

connection = await_only(self.dbapi.connect(*arg, **kw)) # type: ignore[attr-defined] # noqa: F821,E501
return AsyncConnectionWrapper(
self,
connection,
)


class AsyncFireboltDialect(FireboltDialect):
driver = "firebolt_aio"
supports_statement_cache: bool = False
supports_server_side_cursors: bool = False
is_async: bool = True

@classmethod
def dbapi(cls) -> AsyncAPIWrapper:
return AsyncAPIWrapper(async_dbapi)


dialect = AsyncFireboltDialect
16 changes: 7 additions & 9 deletions src/firebolt_db/firebolt_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from types import ModuleType
from typing import Any, Dict, List, Optional, Tuple, Union

import firebolt.db as dbapi
import sqlalchemy.types as sqltypes
from sqlalchemy.engine import Connection as AlchemyConnection
from sqlalchemy.engine import ExecutionContext, default
from sqlalchemy.engine.url import URL
from sqlalchemy.sql import compiler
from sqlalchemy.sql import compiler, text
from sqlalchemy.types import (
BIGINT,
BOOLEAN,
Expand All @@ -19,8 +20,6 @@
VARCHAR,
)

import firebolt_db


class ARRAY(sqltypes.TypeEngine):
__visit_name__ = "ARRAY"
Expand Down Expand Up @@ -97,7 +96,7 @@ def __init__(

@classmethod
def dbapi(cls) -> ModuleType:
return firebolt_db
return dbapi

# Build firebolt-sdk compatible connection arguments.
# URL format : firebolt://username:password@host:port/db_name
Expand All @@ -117,7 +116,7 @@ def get_schema_names(
self, connection: AlchemyConnection, **kwargs: Any
) -> List[str]:
query = "select schema_name from information_schema.databases"
result = connection.execute(query)
result = connection.execute(text(query))
return [row.schema_name for row in result]

def has_table(
Expand All @@ -133,8 +132,7 @@ def has_table(
""".format(
table_name=table_name
)

result = connection.execute(query)
result = connection.execute(text(query))
return result.fetchone().exists_

def get_table_names(
Expand All @@ -146,7 +144,7 @@ def get_table_names(
query=query, schema=schema
)

result = connection.execute(query)
result = connection.execute(text(query))
return [row.table_name for row in result]

def get_view_names(
Expand Down Expand Up @@ -184,7 +182,7 @@ def get_columns(
query=query, schema=schema
)

result = connection.execute(query)
result = connection.execute(text(query))

return [
{
Expand Down
Loading