Skip to content

Commit

Permalink
Allow foreign keys to tables with a composite primary key
Browse files Browse the repository at this point in the history
  • Loading branch information
gvx committed Jan 28, 2021
1 parent 5b9d57e commit 6008098
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 53 deletions.
8 changes: 2 additions & 6 deletions docs/source/wurm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,6 @@ type for each column has to be one of the following:
* A type registered with :func:`wurm.register_type`.
* A previously defined :class:`wurm.Table` or
:class:`wurm.WithoutRowid` subclass.

.. note:: Only tables with a non-composite primary key may currently
be referenced this way. This means that zero fields can be
marked with :data:`wurm.Primary` for rowid tables, and
exactly one field has to be marked with :data:`wurm.Primary`
for WITHOUT ROWID tables.
* :samp:`wurm.Primary[{T}]` or :samp:`wurm.Unique[{T}]`, where
:samp:`{T}` is one of the types mentioned above.

Expand Down Expand Up @@ -111,6 +105,8 @@ type for each column has to be one of the following:

.. autofunction:: wurm.register_type

.. autofunction:: wurm.register_dataclass

-------------
Queries
-------------
Expand Down
33 changes: 26 additions & 7 deletions tests/test_wurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,23 @@ class OneToOneForeignKeyTest(wurm.Table):
class PrimaryForeignKeyTest(wurm.WithoutRowid):
point: wurm.Primary[Point]

@wurm.register_dataclass
@dataclass
class Color:
r: float
g: float
b: float
a: float

@dataclass
class MultiColumnFields(wurm.Table):
color: Color

@dataclass
class ForeignKeyTest2(wurm.Table):
point: CompositeKey


@pytest.fixture
def connection():
wurm.setup_connection(sqlite3.connect(":memory:"))
Expand Down Expand Up @@ -324,13 +341,10 @@ def test_foreign_keys_1(connection):
assert isinstance(ForeignKeyTest.query(rowid=1).one().point, Point)

def test_foreign_keys_2():
@dataclass
class ForeignKeyTest2(wurm.Table):
point: CompositeKey
with pytest.raises(NotImplementedError):
wurm.setup_connection(sqlite3.connect(":memory:"))
# prevent the table from being created again
ForeignKeyTest2.__abstract__ = True
c = CompositeKey(1, 2)
c.insert()
ForeignKeyTest2(c).insert()
assert ForeignKeyTest2.query().one().point.part_one == 1

def test_foreign_keys_3(connection):
p = Point(1, 1)
Expand All @@ -353,3 +367,8 @@ def test_foreign_keys_4(connection):
dup = PrimaryForeignKeyTest(p)
with pytest.raises(wurm.WurmError):
dup.insert()

def test_multicolumnfields(connection):
MultiColumnFields(Color(0., 2., 8., 1.)).insert()
o, = MultiColumnFields
assert o.color == Color(0., 2., 8., 1.)
6 changes: 3 additions & 3 deletions wurm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
__version__ = '0.1.0'

from .typemaps import register_type, Unique, Primary
from .typemaps import register_type, register_dataclass, Unique, Primary
from .tables import WithoutRowid, Table
from .connection import WurmError, setup_connection
from .queries import lt, gt, le, ge, ne, eq, Query

__all__ = ['register_type', 'Unique', 'Primary', 'WurmError',
'WithoutRowid', 'Table', 'setup_connection', 'lt', 'gt', 'le',
__all__ = ['register_type', 'register_dataclass', 'Unique', 'Primary',
'WurmError', 'WithoutRowid', 'Table', 'setup_connection', 'lt', 'gt', 'le',
'ge', 'ne', 'eq', 'Query']
25 changes: 14 additions & 11 deletions wurm/queries.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from dataclasses import dataclass, fields
from dataclasses import dataclass
from typing import Any

from . import sql
from .connection import execute, WurmError
from .typemaps import from_stored, to_stored
from .typemaps import from_stored, to_stored, columns_for

@dataclass(frozen=True)
class Comparison:
Expand Down Expand Up @@ -82,16 +82,19 @@ def ensure_comparison(value):
return eq(value)

def encode_query_value(table, fieldname, value):
for field in fields(table):
if field.name == fieldname:
return to_stored(fieldname, field.type, value)
for field, ty in table.__fields_info__.items():
if field == fieldname:
return to_stored(fieldname, ty, value)
raise WurmError(f'invalid query: {table.__name__}.{fieldname} does not exist')

def decode_row(table, row):
values = {name: from_stored(ty, stored_value)
for stored_value, (name, ty)
in zip(row, table.__fields_info__.items())
}
values = {}
for name, ty in table.__fields_info__.items():
columns = len(list(columns_for(name, ty)))
values[name] = from_stored(row[:columns], ty)
row = row[columns:]
assert not row

if 'rowid' in values:
rowid = values.pop('rowid')
else:
Expand All @@ -104,7 +107,7 @@ def decode_row(table, row):
class Query:
"""Represents one or more queries on a specified table.
:samp:`Query({table}, filters)` is equivalent to :samp:`{table}.query(**filters)``"""
:samp:`Query({table}, filters)` is equivalent to :samp:`{table}.query(**filters)`"""
def __init__(self, table: type, filters: dict):
self.table = table
self.filters = {key: ensure_comparison(value) for key, value
Expand All @@ -128,7 +131,7 @@ def __len__(self):
def select_with_limit(self, limit=None):
"""Create an iterator over the results of this query.
This accesses the database.
.. note:: This method accesses the connected database.
:param limit: The number of results to limit this query to.
:type limit: int or None
Expand Down
12 changes: 11 additions & 1 deletion wurm/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@ def create_primary_key(table):
return ", ".join(column for field in table.__primary_key__
for column in columns_for(field, field_info[field]))

def get_foreign_keys(table):
from .tables import BaseTable
for name, ty in table.__fields_info__.items():
if issubclass(ty, BaseTable):
yield ', '.join(columns_for(name, ty)), ty.__table_name__

def create_foreign_keys(table):
return ''.join(f', foreign key ({cols}) references {tblname}'
for cols, tblname in get_foreign_keys(table))

def create_indexes(table):
table_name = table.__table_name__
field_info = table.__fields_info__
Expand All @@ -19,7 +29,7 @@ def create_indexes(table):

def create(table):
return (f'create table if not exists {table.__table_name__}'
f'({create_fields(table)}, '
f'({create_fields(table)}{create_foreign_keys(table)}, '
f'PRIMARY KEY ({create_primary_key(table)}))')

def count(table, where=None):
Expand Down
73 changes: 48 additions & 25 deletions wurm/typemaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,46 @@ def from_string(cls, string):
sql_equiv = SQL_EQUIVALENTS[sql_type]
TYPE_MAPPING[python_type] = StoredValueTypeMap(sql_equiv, encode, decode)

def register_dataclass(dclass):
'''Registers a dataclass for use in model fields.
This is a convenience function that can optionally be used as a
decorator. Given::
@dataclasses.dataclass
class Color:
r: float
g: float
b: float
then the following::
register_dataclass(Color)
is equivalent to::
register_type(Color, dict(r=float, g=float, b=float),
encode=dataclasses.astuple, decode=Color)
In either case, the model::
class MyTable(Table):
color: Color
will have the fields ``color_r``, ``color_g`` and ``color_b``, which
will transparently be converted to and from ``Color`` objects.
:param dclass: The dataclass to register
:returns: The registered dataclass
'''
from dataclasses import astuple, is_dataclass, fields
assert is_dataclass(dclass)
register_type(dclass,
{field.name: field.type for field in fields(dclass)},
encode=astuple, decode=dclass)
return dclass

def columns_for(field_name, python_type):
python_type = unwrap_type(python_type)
from .tables import BaseTable
if issubclass(python_type, BaseTable):
return (f'{field_name}_{pk}'
Expand All @@ -89,7 +127,6 @@ def columns_for(field_name, python_type):
return field_name,

def to_stored(field_name, python_type, value):
python_type = unwrap_type(python_type)
if value is None:
return dict.fromkeys(columns_for(field_name, python_type))

Expand All @@ -112,44 +149,30 @@ def unwrap_type(ty):
return get_args(ty)[0]
return ty

def from_stored(python_type, value):
if value is None:
def from_stored(stored_tuple, python_type):
if all(v is None for v in stored_tuple):
return None
if get_origin(python_type) is Annotated:
python_type = get_args(python_type)[0]
from .tables import BaseTable
from .queries import Query
if issubclass(python_type, BaseTable):
# FIXME: this is problematic for recursive foreign keys
# not to mention the n + 1 problem
pk, = python_type.__primary_key__
return Query(python_type, {pk: value}).one()
return TYPE_MAPPING[python_type].decode(value)
return Query(python_type, dict(zip(python_type.__primary_key__, stored_tuple))).one()
return TYPE_MAPPING[python_type].decode(*stored_tuple)

def sql_type_for(fieldname, python_type):
postfix = ['']
if get_origin(python_type) is Annotated:
python_type, *rest = get_args(python_type)
if any(_UniqueMarker is arg for arg in rest):
postfix.append('UNIQUE') # FIXME: UNIQUE needs to be
# defined like primary keys are, for when the field is
# mapped to multiple columns
from .tables import BaseTable
if issubclass(python_type, BaseTable):
if len(python_type.__primary_key__) > 1:
raise NotImplementedError('composite foreign keys are not '
'yet supported')
postfix.append('REFERENCES')
postfix.append(python_type.__table_name__)
pk, = python_type.__primary_key__
python_type = python_type.__fields_info__[pk]
return sql_type_for(f'{fieldname}_{pk}', python_type) + ' '.join(postfix)
pk = python_type.__primary_key__
info = python_type.__fields_info__
return ', '.join(sql_type_for(f'{fieldname}_{key}', info[key])
for key in pk)

mapped_type = TYPE_MAPPING[python_type].sql_type
if isinstance(mapped_type, MappingProxyType):
return ', '.join(f'{fieldname}_{key} {ty}' for key, ty
in mapped_type.items())
return f'{fieldname} {mapped_type}' + ' '.join(postfix)
return f'{fieldname} {mapped_type}'

register_type(str, str, encode=passthrough, decode=passthrough)
register_type(bytes, bytes, encode=passthrough, decode=passthrough)
Expand Down

0 comments on commit 6008098

Please sign in to comment.