Skip to content

Commit

Permalink
Add identity map and strict relations
Browse files Browse the repository at this point in the history
  • Loading branch information
gvx committed Feb 6, 2021
1 parent 7fcc6bd commit b858970
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 76 deletions.
35 changes: 35 additions & 0 deletions tests/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ class Parent6(wurm.Table):
class Child6(wurm.Table):
parent: Parent6

@dataclass
class Parent7(wurm.Table):
children = wurm.relation('Child7.parent', lazy='strict')

@dataclass
class Child7(wurm.Table):
parent: Parent7

def test_relation_1(connection):
p = Parent()
Expand Down Expand Up @@ -103,3 +110,31 @@ def test_access_relation_twice(connection):

def test_relation_3(connection):
assert isinstance(Parent6().children, wurm.Query)

def test_relation_4(connection):
p = Parent7()
p.insert()
Child7(parent=p).insert()
Child7(parent=p).insert()
Child7(parent=p).insert()
assert len(p.children) == 0

def test_relation_5(connection):
p = Parent7()
p.insert()
p, = Parent7
Child7(parent=p).insert()
Child7(parent=p).insert()
Child7(parent=p).insert()
assert len(p.children) == 0


def test_relation_6(connection):
p = Parent7()
p.insert()
Parent7.del_object(p) # remove object from identity mapping
p, = Parent7
Child7(parent=p).insert()
Child7(parent=p).insert()
Child7(parent=p).insert()
assert len(p.children) == 0
5 changes: 3 additions & 2 deletions tests/test_wurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,15 @@ def test_model_get(connection):
def test_model_get_doesnt_cache(connection):
p1 = Point(10, 20)
p1.insert()
assert Point.query(rowid=p1.rowid).one() is not p1
assert Point.query(rowid=p1.rowid).one() is p1

def test_model_update(connection):
p1 = Point(10, 20)
p1.insert()
p1.y = 1000
p1.commit()
assert Point.query(rowid=p1.rowid).one().y == 1000
del p1
assert Point.query().one().y == 1000

def test_model_delete1(connection):
p1 = Point(10, 20)
Expand Down
16 changes: 6 additions & 10 deletions wurm/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,20 +89,16 @@ def encode_query_value(table, fieldname, value):

def decode_row(table, row):
values = {}
pk = ()
for name, ty in table.__fields_info__.items():
columns = len(list(columns_for(name, ty)))
values[name] = from_stored(row[:columns], ty)
segment = row[:columns]
if name in table.__primary_key__:
pk += segment
values[name] = from_stored(segment, ty)
row = row[columns:]
assert not row

if 'rowid' in values:
rowid = values.pop('rowid')
else:
rowid = ...
item = table(**values)
if rowid is not ...:
item.rowid = rowid
return item
return table.get_object(pk, values)

T = TypeVar('T')

Expand Down
167 changes: 103 additions & 64 deletions wurm/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
from typing import ClassVar, Tuple, Dict, get_type_hints
from types import MappingProxyType
from weakref import WeakValueDictionary

from .typemaps import (to_stored, Primary, is_primary, unwrap_type,
is_unique)
Expand All @@ -26,6 +27,81 @@ def indexes(fields):
return tuple((key, True) for key, value in fields.items()
if is_unique(value))

def primary_key_columns(item):
for fieldname, ty in item.__fields_info__.items():
if fieldname in item.__primary_key__:
yield from to_stored(fieldname, ty,
getattr(item, fieldname)).values()

class relation:
"""Describe a relationship between two tables.
:param str target: The name of either the target table, or the
specific field being referenced.
:param str lazy: How the relationship is loaded. Possible options:
``'select'`` (the default): the relationship is loaded lazily,
as a list of target objects; ``'query'``: the relationship is
loaded lazily, as a query on the target model; ``'strict'``: the
relationship is loaded as a list when objects of the current
model are loaded."""
def __init__(self, target: str, lazy: str = 'select'):
self.target = target
assert lazy in {'select', 'query', 'strict'}
self.lazy = lazy
def _find_target(self, owner):
search_path = self.target.split('.')
ns = self.namespace
for name in search_path[:-1]:
ns = getattr(ns, name)
if isinstance(ns, TableMeta):
target_attr = search_path[-1]
else:
ns = getattr(ns, search_path[-1])
target_attr = None
if not isinstance(ns, TableMeta):
raise TypeError(f'invalid target {self.target} for '
f'{owner.__name__}.{self.name}')
target_table = ns
if target_attr is None:
possible_attrs = [fieldname
for fieldname, ty
in target_table.__fields_info__.items()
if ty is owner]
if not possible_attrs:
raise TypeError(f'Model {target_table.__name__} does '
f'not have a {owner.__name__} field, so the '
f'relation {owner.__name__}.{self.name} is invalid.')
if len(possible_attrs) > 1:
raise TypeError(f'Model {target_table.__name__} has '
f'multiple {owner.__name__} fields: '
f'{", ".join(possible_attrs[:-1])} and '
f'{possible_attrs[-1]}. Specify the right field for '
f'the relation {owner.__name__}.{self.name}.')
target_attr, = possible_attrs
elif target_table.__fields_info__[target_attr] is not owner:
raise TypeError(f'{self.target} is not {owner.__name__}, '
f'so the relation {owner.__name__}.{self.name} is invalid.')
self.target_table = target_table
self.target_attr = target_attr
def __set_name__(self, owner, name):
self.name = name
self.namespace = sys.modules[owner.__module__]
def __get__(self, instance, owner=None):
if instance is None:
return self # FIXME: relation on class?
if not hasattr(self, 'target_table'):
self._find_target(owner)
q = Query(self.target_table, {self.target_attr: instance})
if self.lazy == 'query':
return q
else:
return list(q)

def strict_relations(classdict):
for key, value in classdict.items():
if isinstance(value, relation) and value.lazy == 'strict':
yield key

class TableMeta(type):
def __new__(cls, clsname, bases, classdict, name=None, abstract=False):
if not all(getattr(base, '__abstract__', True) for base in bases):
Expand All @@ -41,6 +117,8 @@ def __new__(cls, clsname, bases, classdict, name=None, abstract=False):
t.__primary_key__ = primary_key_fields(fields)
t.__data_fields__ = data_fields(fields)
t.__indexes__ = indexes(fields)
t.__strict_relations__ = tuple(strict_relations(classdict))
t.__id_map__ = WeakValueDictionary()
return t
def __iter__(self):
"""Iterate over all the objects in the table.
Expand Down Expand Up @@ -76,6 +154,22 @@ def query(self, **kwargs):
:returns: A query for this table.
:rtype: Query"""
return Query(self, kwargs)
def get_object(self, pk, values):
if pk not in self.__id_map__:
if 'rowid' in values:
rowid = values.pop('rowid')
else:
rowid = ...
self.__id_map__[pk] = item = self(**values)
if rowid is not ...:
item.rowid = rowid
for rel in self.__strict_relations__:
item.__dict__[rel] = getattr(item, rel)
return self.__id_map__[pk]
def add_object(self, item):
self.__id_map__[tuple(primary_key_columns(item))] = item
def del_object(self, item):
del self.__id_map__[tuple(primary_key_columns(item))]

@dataclass
class BaseTable(metaclass=TableMeta, abstract=True):
Expand Down Expand Up @@ -108,24 +202,31 @@ def display_owner(self):
``HasOwner`` will have a field called ``owner`` and a method
called ``display_owner``.
"""
__id_map__: ClassVar[dict]
__fields_info__: ClassVar[Dict[str, type]]
__data_fields__: ClassVar[Tuple[str, ...]]
__primary_key__: ClassVar[Tuple[str, ...]]
__indexes__: ClassVar[Tuple[Tuple[str, bool], ...]]
__strict_relations__: ClassVar[Tuple[str, ...]]
__abstract__: ClassVar[bool]
__table_name__: ClassVar[str]
def __new__(cls, *args, **kwargs):
if cls.__abstract__:
raise TypeError('cannot instantiate abstract table')
return super().__new__(cls)
def __post_init__(self):
for rel in self.__strict_relations__:
self.__dict__[rel] = []
def insert(self):
"""Insert a new object into the database.
.. note:: This method accesses the connected database.
"""
cursor = execute(sql.insert(type(self)), self._encode_row())
self.rowid = cursor.lastrowid
if 'rowid' in self.__primary_key__:
self.rowid = cursor.lastrowid
type(self).add_object(self)
def commit(self):
"""Commits any changes to the object to the database.
Expand All @@ -143,6 +244,7 @@ def delete(self):
"""
Query(type(self), self._primary_key()).delete()
type(self).del_object(self)
def _primary_key(self):
return {key: getattr(self, key) for key in self.__primary_key__}
def _encode_row(self):
Expand Down Expand Up @@ -181,66 +283,3 @@ def create_tables(tbl, conn):
execute(create_index_query, conn=conn)
create_tables(table, conn)

class relation:
"""Describe a relationship between two tables.
:param str target: The name of either the target table, or the
specific field being referenced.
:param str lazy: How the relationship is loaded. Possible options:
``'select'`` (the default): the relationship is loaded lazily,
as a list of target objects; ``'query'``: the relationship is
loaded lazily, as a query on the target model; ``'strict'``: the
relationship is loaded as a list when objects of the current
model are loaded."""
def __init__(self, target: str, lazy: str = 'select'):
self.target = target
assert lazy in {'select', 'query', 'strict'}
self.lazy = lazy
def _find_target(self, owner):
search_path = self.target.split('.')
ns = self.namespace
for name in search_path[:-1]:
ns = getattr(ns, name)
if isinstance(ns, TableMeta):
target_attr = search_path[-1]
else:
ns = getattr(ns, search_path[-1])
target_attr = None
if not isinstance(ns, TableMeta):
raise TypeError(f'invalid target {self.target} for '
f'{owner.__name__}.{self.name}')
target_table = ns
if target_attr is None:
possible_attrs = [fieldname
for fieldname, ty
in target_table.__fields_info__.items()
if ty is owner]
if not possible_attrs:
raise TypeError(f'Model {target_table.__name__} does '
f'not have a {owner.__name__} field, so the '
f'relation {owner.__name__}.{self.name} is invalid.')
if len(possible_attrs) > 1:
raise TypeError(f'Model {target_table.__name__} has '
f'multiple {owner.__name__} fields: '
f'{", ".join(possible_attrs[:-1])} and '
f'{possible_attrs[-1]}. Specify the right field for '
f'the relation {owner.__name__}.{self.name}.')
target_attr, = possible_attrs
elif target_table.__fields_info__[target_attr] is not owner:
raise TypeError(f'{self.target} is not {owner.__name__}, '
f'so the relation {owner.__name__}.{self.name} is invalid.')
self.target_table = target_table
self.target_attr = target_attr
def __set_name__(self, owner, name):
self.name = name
self.namespace = sys.modules[owner.__module__]
def __get__(self, instance, owner=None):
if instance is None:
return self # FIXME: relation on class?
if not hasattr(self, 'target_table'):
self._find_target(owner)
q = Query(self.target_table, {self.target_attr: instance})
if self.lazy == 'select':
return list(q)
else:
return q

0 comments on commit b858970

Please sign in to comment.