Skip to content

Commit

Permalink
➕ Add support for SQL Server
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil committed Mar 27, 2023
1 parent deedd13 commit 28ea805
Show file tree
Hide file tree
Showing 10 changed files with 433 additions and 10 deletions.
22 changes: 22 additions & 0 deletions .github/workflows/test-suite.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,30 @@ jobs:
- 5432:5432
options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5

mssql:
image: mcr.microsoft.com/mssql/server:2019-GA-ubuntu-16.04
env:
MSSQL_SA_PASSWORD: "mssql123mssql"
ACCEPT_EULA: "Y"
MSSQL_PID: "Developer"
ports:
- "1433:1433"

steps:
- uses: "actions/checkout@v3"
- uses: "actions/setup-python@v4"
with:
python-version: "${{ matrix.python-version }}"
- name: "Install drivers"
run: |
curl https://packages.microsoft.com/keys/microsoft.asc | apt-key add -
curl https://packages.microsoft.com/config/ubuntu/$(lsb_release -rs)/prod.list > /etc/apt/sources.list.d/mssql-release.list
sudo apt-get update -y
sudo ACCEPT_EULA=Y apt-get install -y msodbcsql17
sudo ACCEPT_EULA=Y apt-get install -y mssql-tools
echo 'export PATH="$PATH:/opt/mssql-tools/bin"' >> ~/.bashrc
source ~/.bashrc
sudo apt-get install -y unixodbc-dev
- name: "Install dependencies"
run: "scripts/install"
- name: "Run linting checks"
Expand All @@ -60,4 +79,7 @@ jobs:
postgresql://username:password@localhost:5432/testsuite,
postgresql+aiopg://username:password@127.0.0.1:5432/testsuite,
postgresql+asyncpg://username:password@localhost:5432/testsuite
mssql://sa:mssql123mssql@localhost:1433/master?driver=ODBC+Driver+17+for+SQL+Server,
mssql+pyodbc://sa:mssql123mssql@localhost:1433/master?driver=ODBC+Driver+17+for+SQL+Server,
mssql+aioodbc://sa:mssql123mssql@localhost:1433/master?driver=ODBC+Driver+17+for+SQL+Server
run: "scripts/test"
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Database drivers supported are:
* [aiomysql][aiomysql]
* [asyncmy][asyncmy]
* [aiosqlite][aiosqlite]
* [aioodbc][aioodbc]

You can install the required database drivers with:

Expand All @@ -45,9 +46,10 @@ $ pip install databases[aiopg]
$ pip install databases[aiomysql]
$ pip install databases[asyncmy]
$ pip install databases[aiosqlite]
$ pip install databases[aioodbc]
```

Note that if you are using any synchronous SQLAlchemy functions such as `engine.create_all()` or [alembic][alembic] migrations then you still have to install a synchronous DB driver: [psycopg2][psycopg2] for PostgreSQL and [pymysql][pymysql] for MySQL.
Note that if you are using any synchronous SQLAlchemy functions such as `engine.create_all()` or [alembic][alembic] migrations then you still have to install a synchronous DB driver: [psycopg2][psycopg2] for PostgreSQL, [pymysql][pymysql] for MySQL and [pyodbc][pyodbc] for SQL Server.

---

Expand Down Expand Up @@ -103,11 +105,13 @@ for examples of how to start using databases together with SQLAlchemy core expre
[alembic]: https://alembic.sqlalchemy.org/en/latest/
[psycopg2]: https://www.psycopg.org/
[pymysql]: https://github.com/PyMySQL/PyMySQL
[pyodbc]: https://github.com/mkleehammer/pyodbc
[asyncpg]: https://github.com/MagicStack/asyncpg
[aiopg]: https://github.com/aio-libs/aiopg
[aiomysql]: https://github.com/aio-libs/aiomysql
[asyncmy]: https://github.com/long2ice/asyncmy
[aiosqlite]: https://github.com/omnilib/aiosqlite
[aioodbc]: https://aioodbc.readthedocs.io/en/latest/

[starlette]: https://github.com/encode/starlette
[sanic]: https://github.com/huge-success/sanic
Expand Down
307 changes: 307 additions & 0 deletions databases/backends/mssql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,307 @@
import getpass
import logging
import typing
import uuid

import aioodbc
from sqlalchemy.dialects.mssql import pyodbc
from sqlalchemy.engine.cursor import CursorResultMetaData
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement

from databases.backends.common.records import Record, Row, create_column_maps
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import (
ConnectionBackend,
DatabaseBackend,
Record as RecordInterface,
TransactionBackend,
)

logger = logging.getLogger("databases")


class MSSQLBackend(DatabaseBackend):
def __init__(
self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
) -> None:
self._database_url = DatabaseURL(database_url)
self._options = options
self._dialect = pyodbc.dialect(paramstyle="pyformat")
self._dialect.supports_native_decimal = True
self._pool: aioodbc.Pool = None

def _get_connection_kwargs(self) -> dict:
url_options = self._database_url.options

kwargs = {}
min_size = url_options.get("min_size")
max_size = url_options.get("max_size")
pool_recycle = url_options.get("pool_recycle")
ssl = url_options.get("ssl")
driver = url_options.get("driver")
trusted_connection = url_options.get("trusted_connection", "no")

assert driver is not None, "The driver must be specified"

if min_size is not None:
kwargs["minsize"] = int(min_size)
if max_size is not None:
kwargs["maxsize"] = int(max_size)
if pool_recycle is not None:
kwargs["pool_recycle"] = int(pool_recycle)
if ssl is not None:
kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()]

kwargs["trusted_connection"] = trusted_connection.lower()
kwargs["driver"] = driver

for key, value in self._options.items():
# Coerce 'min_size' and 'max_size' for consistency.
if key == "min_size":
key = "minsize"
elif key == "max_size":
key = "maxsize"
kwargs[key] = value

return kwargs

async def connect(self) -> None:
assert self._pool is None, "DatabaseBackend is already running"
kwargs = self._get_connection_kwargs()

driver = kwargs["driver"]
database = self._database_url.database
hostname = self._database_url.hostname
port = self._database_url.port or 1433
user = self._database_url.username or getpass.getuser()
password = self._database_url.password

dsn = f"Driver={driver};Database={database};Server={hostname};UID={user};PWD={password};Port={port}"

self._pool = await aioodbc.create_pool(
dsn=dsn,
autocommit=True,
**kwargs,
)

async def disconnect(self) -> None:
assert self._pool is not None, "DatabaseBackend is not running"
self._pool.close()
await self._pool.wait_closed()
self._pool = None

def connection(self) -> "MSSQLConnection":
return MSSQLConnection(self, self._dialect)


class CompilationContext:
def __init__(self, context: ExecutionContext):
self.context = context


class MSSQLConnection(ConnectionBackend):
def __init__(self, database: MSSQLBackend, dialect: Dialect) -> None:
self._database = database
self._dialect = dialect
self._connection: typing.Optional[aioodbc.Connection] = None

async def acquire(self) -> None:
assert self._connection is None, "Connection is already acquired"
assert self._database._pool is not None, "DatabaseBackend is not running"
self._connection = await self._database._pool.acquire()

async def release(self) -> None:
assert self._connection is not None, "Connection is not acquired"
assert self._database._pool is not None, "DatabaseBackend is not running"
await self._database._pool.release(self._connection)
self._connection = None

async def fetch_all(self, query: ClauseElement) -> typing.List["RecordInterface"]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
dialect = self._dialect
cursor = await self._connection.cursor()
try:
await cursor.execute(query_str, args)
rows = await cursor.fetchall()
metadata = CursorResultMetaData(context, cursor.description)
rows = [
Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
for row in rows
]
return [Record(row, result_columns, dialect, column_maps) for row in rows]
finally:
await cursor.close()

async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
dialect = self._dialect
cursor = await self._connection.cursor()
try:
await cursor.execute(query_str, args)
row = await cursor.fetchone()
if row is None:
return None
metadata = CursorResultMetaData(context, cursor.description)
row = Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
return Record(row, result_columns, dialect, column_maps)
finally:
await cursor.close()

async def execute(self, query: ClauseElement) -> typing.Any:
assert self._connection is not None, "Connection is not acquired"
query_str, args, _, _ = self._compile(query)
cursor = await self._connection.cursor()
try:
values = await cursor.execute(query_str, args)
try:
values = await values.fetchone()
return values[0]
except Exception:
...
finally:
await cursor.close()

async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
assert self._connection is not None, "Connection is not acquired"
cursor = await self._connection.cursor()
try:
for single_query in queries:
single_query, args, _, _ = self._compile(single_query)
await cursor.execute(single_query, args)
finally:
await cursor.close()

async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
dialect = self._dialect
cursor = await self._connection.cursor()
try:
await cursor.execute(query_str, args)
metadata = CursorResultMetaData(context, cursor.description)
async for row in cursor:
record = Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
yield Record(record, result_columns, dialect, column_maps)
finally:
await cursor.close()

def transaction(self) -> TransactionBackend:
return MSSQLTransaction(self)

def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
compiled = query.compile(
dialect=self._dialect, compile_kwargs={"render_postcompile": True}
)

execution_context = self._dialect.execution_ctx_cls()
execution_context.dialect = self._dialect

if not isinstance(query, DDLElement):
compiled_params = compiled.params.items()

mapping = {key: "?" for _, (key, _) in enumerate(compiled_params, start=1)}
compiled_query = compiled.string % mapping

processors = compiled._bind_processors
args = [
processors[key](val) if key in processors else val
for key, val in compiled_params
]

execution_context.result_column_struct = (
compiled._result_columns,
compiled._ordered_columns,
compiled._textual_ordered_columns,
compiled._ad_hoc_textual,
compiled._loose_column_name_matching,
)

result_map = compiled._result_columns
else:
compiled_query = compiled.string
args = []
result_map = None

query_message = compiled_query.replace(" \n", " ").replace("\n", " ")
logger.debug(
"Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA
)
return compiled_query, args, result_map, CompilationContext(execution_context)

@property
def raw_connection(self) -> aioodbc.connection.Connection:
assert self._connection is not None, "Connection is not acquired"
return self._connection


class MSSQLTransaction(TransactionBackend):
def __init__(self, connection: MSSQLConnection):
self._connection = connection
self._is_root = False
self._savepoint_name = ""

async def start(
self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]
) -> None:
assert self._connection._connection is not None, "Connection is not acquired"
self._is_root = is_root
cursor = await self._connection._connection.cursor()
if self._is_root:
await cursor.execute("BEGIN TRANSACTION")
else:
id = str(uuid.uuid4()).replace("-", "_")[:12]
self._savepoint_name = f"STARLETTE_SAVEPOINT_{id}"
try:
await cursor.execute(f"SAVE TRANSACTION {self._savepoint_name}")
finally:
cursor.close()

async def commit(self) -> None:
assert self._connection._connection is not None, "Connection is not acquired"
cursor = await self._connection._connection.cursor()
if self._is_root:
await cursor.execute("COMMIT TRANSACTION")
else:
try:
await cursor.execute(f"COMMIT TRANSACTION {self._savepoint_name}")
finally:
cursor.close()

async def rollback(self) -> None:
assert self._connection._connection is not None, "Connection is not acquired"
cursor = await self._connection._connection.cursor()
if self._is_root:
await cursor.execute("BEGIN TRANSACTION")
await cursor.execute("ROLLBACK TRANSACTION")
else:
try:
await cursor.execute(f"ROLLBACK TRANSACTION {self._savepoint_name}")
finally:
cursor.close()
3 changes: 3 additions & 0 deletions databases/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ class Database:
"postgres": "databases.backends.postgres:PostgresBackend",
"mysql": "databases.backends.mysql:MySQLBackend",
"mysql+asyncmy": "databases.backends.asyncmy:AsyncMyBackend",
"mssql": "databases.backends.mssql:MSSQLBackend",
"mssql+pyodbc": "databases.backends.mssql:MSSQLBackend",
"mssql+aioodbc": "databases.backends.mssql:MSSQLBackend",
"sqlite": "databases.backends.sqlite:SQLiteBackend",
}

Expand Down
Loading

0 comments on commit 28ea805

Please sign in to comment.