Fix some column types being parsed twice (#582)
* Fix JSON and enum type columns
* Add time to reparsing check, add date and time tests
* Make processed types inclusive rather than exclusive, limit to just DIALECT_EXCLUDE
pmdevita authored Feb 22, 2024 · 1 parent 1e40ad1 · commit 0fc16b2
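
The last bullet of the commit message is the core of the fix: rather than running a column's result processor for every dialect and excluding a few raw types, the record lookup now runs it only for dialects in DIALECT_EXCLUDE and only while the raw value is still a plain scalar, so values the driver has already decoded are not parsed a second time. Below is a minimal, self-contained sketch of that decision, based on the `__getitem__` diff further down; the `DIALECT_EXCLUDE` membership, the function name, and the `json.loads` processor are illustrative placeholders rather than the library's actual definitions.

```python
import json
from datetime import datetime

# Illustrative placeholder only: the real module defines its own
# DIALECT_EXCLUDE set, and `processor` comes from the column type's
# _cached_result_processor(dialect, None).
DIALECT_EXCLUDE = {"sqlite"}


def resolve_column_value(dialect_name, raw, processor):
    """Stand-in for the reworked check in Record.__getitem__.

    The check is inclusive: the processor runs only for dialects in
    DIALECT_EXCLUDE and only when the raw value is still a plain scalar
    (int, str, float). Values the driver already decoded (dicts from JSON
    columns, datetime/date/time objects, Enum members) pass through
    unchanged, so they are never parsed twice.
    """
    if dialect_name in DIALECT_EXCLUDE:
        if processor is not None and isinstance(raw, (int, str, float)):
            return processor(raw)
    return raw


# A JSON payload that comes back as text is decoded once...
assert resolve_column_value("sqlite", '{"a": 1}', json.loads) == {"a": 1}
# ...while an already-decoded dict or datetime is left untouched.
assert resolve_column_value("postgresql", {"a": 1}, json.loads) == {"a": 1}
now = datetime(2024, 2, 22, 12, 0)
assert resolve_column_value("postgresql", now, str) is now
```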
Showing 2 changed files with 143 additions and 8 deletions.
13 changes: 6 additions & 7 deletions databases/backends/common/records.py
@@ -1,11 +1,12 @@
import json
import enum
import typing
from datetime import date, datetime
from datetime import date, datetime, time

from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.engine.row import Row as SQLRow
from sqlalchemy.sql.compiler import _CompileLabel
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.sqltypes import JSON
from sqlalchemy.types import TypeEngine

from databases.interfaces import Record as RecordInterface
@@ -62,12 +63,10 @@ def __getitem__(self, key: typing.Any) -> typing.Any:
raw = self._row[idx]
processor = datatype._cached_result_processor(self._dialect, None)

if self._dialect.name not in DIALECT_EXCLUDE:
if isinstance(raw, dict):
raw = json.dumps(raw)
if self._dialect.name in DIALECT_EXCLUDE:
if processor is not None and isinstance(raw, (int, str, float)):
return processor(raw)

if processor is not None and (not isinstance(raw, (datetime, date))):
return processor(raw)
return raw

def __iter__(self) -> typing.Iterator:
138 changes: 137 additions & 1 deletion tests/test_databases.py
@@ -1,6 +1,7 @@
import asyncio
import datetime
import decimal
import enum
import functools
import gc
import itertools
@@ -55,6 +56,47 @@ def process_result_value(self, value, dialect):
sqlalchemy.Column("published", sqlalchemy.DateTime),
)

# Used to test Date
events = sqlalchemy.Table(
"events",
metadata,
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
sqlalchemy.Column("date", sqlalchemy.Date),
)


# Used to test Time
daily_schedule = sqlalchemy.Table(
"daily_schedule",
metadata,
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
sqlalchemy.Column("time", sqlalchemy.Time),
)


class TshirtSize(enum.Enum):
SMALL = "SMALL"
MEDIUM = "MEDIUM"
LARGE = "LARGE"
XL = "XL"


class TshirtColor(enum.Enum):
BLUE = 0
GREEN = 1
YELLOW = 2
RED = 3


# Used to test Enum
tshirt_size = sqlalchemy.Table(
"tshirt_size",
metadata,
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
sqlalchemy.Column("size", sqlalchemy.Enum(TshirtSize)),
sqlalchemy.Column("color", sqlalchemy.Enum(TshirtColor)),
)

# Used to test JSON
session = sqlalchemy.Table(
"session",
@@ -928,6 +970,52 @@ async def test_datetime_field(database_url):
assert results[0]["published"] == now


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_date_field(database_url):
"""
Test Date columns, to ensure records are coerced to/from proper Python types.
"""

async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
now = datetime.date.today()

# execute()
query = events.insert()
values = {"date": now}
await database.execute(query, values)

# fetch_all()
query = events.select()
results = await database.fetch_all(query=query)
assert len(results) == 1
assert results[0]["date"] == now


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_time_field(database_url):
"""
Test Time columns, to ensure records are coerced to/from proper Python types.
"""

async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
now = datetime.datetime.now().time().replace(microsecond=0)

# execute()
query = daily_schedule.insert()
values = {"time": now}
await database.execute(query, values)

# fetch_all()
query = daily_schedule.select()
results = await database.fetch_all(query=query)
assert len(results) == 1
assert results[0]["time"] == now


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_decimal_field(database_url):
@@ -957,7 +1045,32 @@ async def test_decimal_field(database_url):

@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_json_field(database_url):
async def test_enum_field(database_url):
"""
Test enum columns, to ensure correct cross-database support.
"""

async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
# execute()
size = TshirtSize.SMALL
color = TshirtColor.GREEN
values = {"size": size, "color": color}
query = tshirt_size.insert()
await database.execute(query, values)

# fetch_all()
query = tshirt_size.select()
results = await database.fetch_all(query=query)

assert len(results) == 1
assert results[0]["size"] == size
assert results[0]["color"] == color


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_json_dict_field(database_url):
"""
Test JSON columns, to ensure correct cross-database support.
"""
@@ -978,6 +1091,29 @@ async def test_json_field(database_url):
assert results[0]["data"] == {"text": "hello", "boolean": True, "int": 1}


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_json_list_field(database_url):
"""
Test JSON columns, to ensure correct cross-database support.
"""

async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
# execute()
data = ["lemon", "raspberry", "lime", "pumice"]
values = {"data": data}
query = session.insert()
await database.execute(query, values)

# fetch_all()
query = session.select()
results = await database.fetch_all(query=query)

assert len(results) == 1
assert results[0]["data"] == ["lemon", "raspberry", "lime", "pumice"]


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_custom_field(database_url):
