Skip to content

Commit

Permalink
Add preliminary implementation of relations
Browse files Browse the repository at this point in the history
  • Loading branch information
gvx committed Jan 31, 2021
1 parent 6008098 commit 7fcc6bd
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 27 deletions.
105 changes: 105 additions & 0 deletions tests/test_relations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from dataclasses import dataclass
import sqlite3

import wurm
import pytest


@pytest.fixture
def connection():
token = wurm.setup_connection(sqlite3.connect(":memory:"))
yield
wurm.close_connection(token)

@dataclass
class Parent(wurm.Table):
children = wurm.relation('Child.parent')

@dataclass
class Child(wurm.Table):
parent: Parent

@dataclass
class Parent2(wurm.Table):
children = wurm.relation('Child2')

@dataclass
class Child2(wurm.Table):
parent1: Parent2
parent2: Parent2

@dataclass
class WrongRelation(wurm.Table):
oh_no = wurm.relation('test_relation_1')

@dataclass
class Parent3(wurm.Table):
children = wurm.relation('Child')

@dataclass
class Parent4(wurm.Table):
children = wurm.relation('Child4')

@dataclass
class Child4(wurm.Table):
parent: Parent4

@dataclass
class Parent5(wurm.Table):
children = wurm.relation('Child.parent')

@dataclass
class Parent6(wurm.Table):
children = wurm.relation('Child6.parent', lazy='query')

@dataclass
class Child6(wurm.Table):
parent: Parent6


def test_relation_1(connection):
p = Parent()
p.insert()
Child(parent=p).insert()
Child(parent=p).insert()
Child(parent=p).insert()
assert len(p.children) == 3

def test_ambiguous_relation():
with pytest.raises(TypeError, match='multiple Parent2 fields'):
Parent2().children

def test_wrong_relation_type():
with pytest.raises(TypeError, match='invalid target'):
WrongRelation().oh_no

def test_relation_no_relevant_field():
with pytest.raises(TypeError, match='does not have a Parent3 field'):
Parent3().children

def test_relation_2(connection):
p = Parent4()
p.insert()
Child4(parent=p).insert()
Child4(parent=p).insert()
Child4(parent=p).insert()
assert len(p.children) == 3

def test_relation_field_wrong_type():
with pytest.raises(TypeError, match=r'Child\.parent is not Parent5'):
Parent5().children

def test_relation_on_class():
# FIXME: what should this do?
assert Parent.children.target == 'Child.parent'

def test_access_relation_twice(connection):
p = Parent()
p.insert()
Child(parent=p).insert()
Child(parent=p).insert()
Child(parent=p).insert()
assert len(p.children) == len(p.children)

def test_relation_3(connection):
assert isinstance(Parent6().children, wurm.Query)
28 changes: 19 additions & 9 deletions tests/test_wurm.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from dataclasses import dataclass
from datetime import date, time, datetime
from pathlib import Path
import sqlite3

import pytest

import wurm

import sqlite3

@pytest.fixture
def connection():
token = wurm.setup_connection(sqlite3.connect(":memory:"))
yield
wurm.close_connection(token)


@dataclass
class Point(wurm.Table):
Expand Down Expand Up @@ -78,14 +85,9 @@ class MultiColumnFields(wurm.Table):
class ForeignKeyTest2(wurm.Table):
point: CompositeKey


@pytest.fixture
def connection():
wurm.setup_connection(sqlite3.connect(":memory:"))

def test_no_connection():
with pytest.raises(wurm.WurmError):
list(Point)
with pytest.raises(wurm.WurmError, match=r'setup_connection\(\) not called'):
Point(0,0).insert()

def test_model(connection):
assert Point.__table_name__ == 'Point'
Expand Down Expand Up @@ -340,7 +342,7 @@ def test_foreign_keys_1(connection):
fk.insert()
assert isinstance(ForeignKeyTest.query(rowid=1).one().point, Point)

def test_foreign_keys_2():
def test_foreign_keys_2(connection):
c = CompositeKey(1, 2)
c.insert()
ForeignKeyTest2(c).insert()
Expand Down Expand Up @@ -372,3 +374,11 @@ def test_multicolumnfields(connection):
MultiColumnFields(Color(0., 2., 8., 1.)).insert()
o, = MultiColumnFields
assert o.color == Color(0., 2., 8., 1.)

def test_register_tuple_type():
class Test:
pass
wurm.register_type(Test, (str, int), encode=..., decode=...)
from wurm import typemaps
assert list(typemaps.columns_for('field', Test)) == ['field_0',
'field_1']
6 changes: 3 additions & 3 deletions wurm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
__version__ = '0.1.0'

from .typemaps import register_type, register_dataclass, Unique, Primary
from .tables import WithoutRowid, Table
from .connection import WurmError, setup_connection
from .tables import WithoutRowid, Table, relation
from .connection import WurmError, setup_connection, close_connection
from .queries import lt, gt, le, ge, ne, eq, Query

__all__ = ['register_type', 'register_dataclass', 'Unique', 'Primary',
'WurmError', 'WithoutRowid', 'Table', 'setup_connection', 'lt', 'gt', 'le',
'ge', 'ne', 'eq', 'Query']
'ge', 'ne', 'eq', 'Query', 'relation', 'close_connection']
6 changes: 5 additions & 1 deletion wurm/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ def setup_connection(conn):
:class:`sqlite3.Connection`, before accessing the database via wurm.
This records the connection and ensures all tables are created."""
connection.set(conn)
token = connection.set(conn)
execute('PRAGMA foreign_keys = ON', conn=conn)
from .tables import BaseTable, create_tables
create_tables(BaseTable, conn)
return token

def close_connection(token):
connection.reset(token)
35 changes: 21 additions & 14 deletions wurm/queries.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any
from typing import Any, Generic, TypeVar, Type, Dict, Optional, Iterator

from . import sql
from .connection import execute, WurmError
Expand Down Expand Up @@ -104,21 +104,28 @@ def decode_row(table, row):
item.rowid = rowid
return item

class Query:
T = TypeVar('T')

class Query(Generic[T]):
"""Represents one or more queries on a specified table.
:samp:`Query({table}, filters)` is equivalent to :samp:`{table}.query(**filters)`"""
def __init__(self, table: type, filters: dict):
table: Type[T]
filters: Dict[str, Comparison]
comparisons: str
values: Dict[str, Any]
def __init__(self, table: Type[T], filters: Dict[str, Any]) -> None:
self.table = table
self.filters = {key: ensure_comparison(value) for key, value
in filters.items()}
self.comparisons = ' and '.join(f'{key}{value.op}:{key}'
for key, value in self.filters.items())
self.values = {column: cooked
values = [(column, value.op, cooked)
for key, value in self.filters.items()
for column, cooked
in encode_query_value(table, key, value.value).items()}
def __len__(self):
in encode_query_value(table, key, value.value).items()]
self.values = {column: cooked for column, _, cooked in values}
self.comparisons = ' and '.join(f'{column}{op}:{column}'
for column, op, _ in values)
def __len__(self) -> int:
"""Returns the number of rows matching this query.
.. note:: This method accesses the connected database.
Expand All @@ -128,7 +135,7 @@ def __len__(self):
"""
c, = execute(sql.count(self.table, self.comparisons), self.values).fetchone()
return c
def select_with_limit(self, limit=None):
def select_with_limit(self, limit: Optional[int] = None) -> Iterator[T]:
"""Create an iterator over the results of this query.
.. note:: This method accesses the connected database.
Expand All @@ -142,34 +149,34 @@ def select_with_limit(self, limit=None):
values = self.values
for row in execute(sql.select(self.table, self.comparisons, limit is not None), values):
yield decode_row(self.table, row)
def __iter__(self):
def __iter__(self) -> Iterator[T]:
"""Iterate over the results of this query.
.. note:: This method accesses the connected database.
Equivalent to :meth:`select_with_limit` without specifying *limit*."""
return self.select_with_limit()
def _only_first(self, *, of):
def _only_first(self, *, of: int) -> T:
try:
i, = self.select_with_limit(of)
except ValueError as e:
raise WurmError(e.args[0].replace('values to unpack', 'rows returned')) from None
return i
def first(self):
def first(self) -> T:
"""Return the first result of this query.
.. note:: This method accesses the connected database.
:raises WurmError: if this query returns zero results"""
return self._only_first(of=1)
def one(self):
def one(self) -> T:
"""Return the only result of this query.
.. note:: This method accesses the connected database.
:raises WurmError: if this query returns zero results or more than one"""
return self._only_first(of=2)
def delete(self):
def delete(self) -> None:
"""Delete the objects matching this query.
.. warning:: Calling this on an empty query deletes all rows
Expand Down
65 changes: 65 additions & 0 deletions wurm/tables.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
import sys
from typing import ClassVar, Tuple, Dict, get_type_hints
from types import MappingProxyType

Expand Down Expand Up @@ -179,3 +180,67 @@ def create_tables(tbl, conn):
for create_index_query in sql.create_indexes(table):
execute(create_index_query, conn=conn)
create_tables(table, conn)

class relation:
"""Describe a relationship between two tables.
:param str target: The name of either the target table, or the
specific field being referenced.
:param str lazy: How the relationship is loaded. Possible options:
``'select'`` (the default): the relationship is loaded lazily,
as a list of target objects; ``'query'``: the relationship is
loaded lazily, as a query on the target model; ``'strict'``: the
relationship is loaded as a list when objects of the current
model are loaded."""
def __init__(self, target: str, lazy: str = 'select'):
self.target = target
assert lazy in {'select', 'query', 'strict'}
self.lazy = lazy
def _find_target(self, owner):
search_path = self.target.split('.')
ns = self.namespace
for name in search_path[:-1]:
ns = getattr(ns, name)
if isinstance(ns, TableMeta):
target_attr = search_path[-1]
else:
ns = getattr(ns, search_path[-1])
target_attr = None
if not isinstance(ns, TableMeta):
raise TypeError(f'invalid target {self.target} for '
f'{owner.__name__}.{self.name}')
target_table = ns
if target_attr is None:
possible_attrs = [fieldname
for fieldname, ty
in target_table.__fields_info__.items()
if ty is owner]
if not possible_attrs:
raise TypeError(f'Model {target_table.__name__} does '
f'not have a {owner.__name__} field, so the '
f'relation {owner.__name__}.{self.name} is invalid.')
if len(possible_attrs) > 1:
raise TypeError(f'Model {target_table.__name__} has '
f'multiple {owner.__name__} fields: '
f'{", ".join(possible_attrs[:-1])} and '
f'{possible_attrs[-1]}. Specify the right field for '
f'the relation {owner.__name__}.{self.name}.')
target_attr, = possible_attrs
elif target_table.__fields_info__[target_attr] is not owner:
raise TypeError(f'{self.target} is not {owner.__name__}, '
f'so the relation {owner.__name__}.{self.name} is invalid.')
self.target_table = target_table
self.target_attr = target_attr
def __set_name__(self, owner, name):
self.name = name
self.namespace = sys.modules[owner.__module__]
def __get__(self, instance, owner=None):
if instance is None:
return self # FIXME: relation on class?
if not hasattr(self, 'target_table'):
self._find_target(owner)
q = Query(self.target_table, {self.target_attr: instance})
if self.lazy == 'select':
return list(q)
else:
return q

0 comments on commit 7fcc6bd

Please sign in to comment.