diff --git a/datajoint/fetch.py b/datajoint/fetch.py index d5b99c38f..13758d42f 100644 --- a/datajoint/fetch.py +++ b/datajoint/fetch.py @@ -1,8 +1,13 @@ from collections import OrderedDict +from functools import wraps +import itertools +import re from .blob import unpack import numpy as np from datajoint import DataJointError from . import key as PRIMARY_KEY +from collections import abc + def prepare_attributes(relation, item): if isinstance(item, str) or item is PRIMARY_KEY: @@ -20,46 +25,57 @@ def prepare_attributes(relation, item): raise DataJointError("Index must be a slice, a tuple, a list, a string.") return item, attributes -class FetchQuery: +def copy_first(f): + @wraps(f) + def ret(*args, **kwargs): + args = list(args) + args[0] = args[0].__class__(args[0]) # call copy constructor + return f(*args, **kwargs) - def __init__(self, relation): - """ + return ret - """ - self.behavior = dict( - offset=0, limit=None, order_by=None, descending=False, as_dict=False, map=None - ) - self._relation = relation +class Fetch: + def __init__(self, relation): + if isinstance(relation, Fetch): # copy constructor + self.behavior = dict(relation.behavior) + self._relation = relation._relation + else: + self.behavior = dict( + offset=0, limit=None, order_by=None, as_dict=False + ) + self._relation = relation + @copy_first def from_to(self, fro, to): self.behavior['offset'] = fro self.behavior['limit'] = to - fro return self - def order_by(self, order_by): - self.behavior['order_by'] = order_by + @copy_first + def order_by(self, *args): + if len(args) > 0: + self.behavior['order_by'] = self.behavior['order_by'] if self.behavior['order_by'] is not None else [] + namepat = re.compile(r"\s*(?P\w+).*") + for a in args: # remove duplicates + name = namepat.match(a).group('name') + pat = re.compile(r"%s(\s*$|\s+(\S*\s*)*$)" % (name,)) + self.behavior['order_by'] = [e for e in self.behavior['order_by'] if not pat.match(e)] + self.behavior['order_by'].extend(args) return self + @copy_first def as_dict(self): self.behavior['as_dict'] = True - - def ascending(self): - self.behavior['descending'] = False return self - def descending(self): - self.behavior['descending'] = True - return self - - def apply(self, f): - self.behavior['map'] = f - return self - def limit_by(self, limit): + @copy_first + def limit_to(self, limit): self.behavior['limit'] = limit return self + @copy_first def set_behavior(self, **kwargs): self.behavior.update(kwargs) return self @@ -78,9 +94,7 @@ def __call__(self, **kwargs): """ behavior = dict(self.behavior, **kwargs) - cur = self._relation.cursor(offset=behavior['offset'], limit=behavior['limit'], - order_by=behavior['order_by'], descending=behavior['descending'], - as_dict=behavior['as_dict']) + cur = self._relation.cursor(**behavior) heading = self._relation.heading if behavior['as_dict']: @@ -92,22 +106,15 @@ def __call__(self, **kwargs): for blob_name in heading.blobs: ret[blob_name] = list(map(unpack, ret[blob_name])) - if behavior['map'] is not None: - f = behavior['map'] - for i in range(len(ret)): - ret[i] = f(ret[i]) - return ret def __iter__(self): """ Iterator that returns the contents of the database. """ - behavior = self.behavior + behavior = dict(self.behavior) - cur = self._relation.cursor(offset=behavior['offset'], limit=behavior['limit'], - order_by=behavior['order_by'], descending=behavior['descending'], - as_dict=behavior['as_dict']) + cur = self._relation.cursor(**behavior) heading = self._relation.heading do_unpack = tuple(h in heading.blobs for h in heading.names) @@ -126,10 +133,10 @@ def keys(self, **kwargs): """ Iterator that returns primary keys. """ + b = dict(self.behavior, **kwargs) if 'as_dict' not in kwargs: - kwargs['as_dict'] = True - yield from self._relation.project().fetch.set_behavior(**kwargs) - + b['as_dict'] = True + yield from self._relation.project().fetch.set_behavior(**b) def __getitem__(self, item): """ @@ -146,7 +153,7 @@ def __getitem__(self, item): single_output = isinstance(item, str) or item is PRIMARY_KEY or isinstance(item, int) item, attributes = prepare_attributes(self._relation, item) - result = self._relation.project(*attributes).fetch() + result = self._relation.project(*attributes).fetch(**self.behavior) return_values = [ np.ndarray(result.shape, np.dtype({name: result.dtype.fields[name] for name in self._relation.primary_key}), @@ -158,8 +165,7 @@ def __getitem__(self, item): return return_values[0] if single_output else return_values -class Fetch1Query: - +class Fetch1: def __init__(self, relation): self._relation = relation @@ -202,4 +208,4 @@ def __getitem__(self, item): else result[attribute][0] for attribute in item ) - return return_values[0] if single_output else return_values \ No newline at end of file + return return_values[0] if single_output else return_values diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index a89cd2c76..769dd7986 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -10,7 +10,7 @@ from . import DataJointError import logging -from .fetch import FetchQuery, Fetch1Query +from .fetch import Fetch, Fetch1 logger = logging.getLogger(__name__) @@ -171,7 +171,7 @@ def __call__(self, *args, **kwargs): """ return self.fetch(*args, **kwargs) - def cursor(self, offset=0, limit=None, order_by=None, descending=False, as_dict=False): + def cursor(self, offset=0, limit=None, order_by=None, as_dict=False): """ Return query cursor. See Relation.fetch() for input description. @@ -182,8 +182,7 @@ def cursor(self, offset=0, limit=None, order_by=None, descending=False, as_dict= sql = self.make_select() if order_by is not None: sql += ' ORDER BY ' + ', '.join(order_by) - if descending: - sql += ' DESC' + if limit is not None: sql += ' LIMIT %d' % limit if offset: @@ -206,12 +205,13 @@ def __repr__(self): repr_string += ' (%d tuples)\n' % len(self) return repr_string + @property def fetch1(self): - return Fetch1Query(self) + return Fetch1(self) @property def fetch(self): - return FetchQuery(self) + return Fetch(self) @property def where_clause(self): @@ -253,8 +253,6 @@ def make_condition(arg): return ' WHERE ' + ' AND '.join(condition_string) - - class Not: """ inverse restriction @@ -319,9 +317,9 @@ def __init__(self, arg, group=None, *attributes, **renamed_attributes): self._arg = Subquery(arg) else: self._group = None - if arg.heading.computed or\ + if arg.heading.computed or \ (isinstance(arg.restrictions, RelationalOperand) and \ - all(attr in self._attributes for attr in arg.restrictions.heading.names)) : + all(attr in self._attributes for attr in arg.restrictions.heading.names)): # can simply the expression because all restrictions attrs are projected out anyway! self._arg = arg self._restrictions = self._arg.restrictions diff --git a/doc/source/_static/.dummy b/doc/source/_static/.dummy new file mode 100644 index 000000000..e69de29bb diff --git a/tests/schema.py b/tests/schema.py index 55e2e094f..f496e7b92 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -39,6 +39,26 @@ class Subject(dj.Manual): def prepare(self): self.insert(self.contents, ignore_errors=True) +@schema +class Language(dj.Lookup): + + definition = """ + # languages spoken by some of the developers + + entry_id : int + --- + name : varchar(40) # name of the developer + language : varchar(40) # language + """ + + contents = [ + (0, 'Fabian', 'English'), + (1, 'Edgar', 'English'), + (2, 'Dimitri', 'English'), + (3, 'Dimitri', 'Ukrainian'), + (4, 'Fabian', 'German'), + (5, 'Edgar', 'Japanese'), + ] @schema class Experiment(dj.Imported): diff --git a/tests/test_fetch.py b/tests/test_fetch.py new file mode 100644 index 000000000..37746da07 --- /dev/null +++ b/tests/test_fetch.py @@ -0,0 +1,133 @@ +from operator import itemgetter, attrgetter +import itertools +from nose.tools import assert_true +from numpy.testing import assert_array_equal, assert_equal +import numpy as np + +from . import schema +import datajoint as dj + + +class TestFetch: + def __init__(self): + self.subject = schema.Subject() + self.lang = schema.Language() + + def test_getitem(self): + """Testing Fetch.__getitem__""" + + np.testing.assert_array_equal(sorted(self.subject.project().fetch(), key=itemgetter(0)), + sorted(self.subject.fetch[dj.key], key=itemgetter(0)), + 'Primary key is not returned correctly') + + tmp = self.subject.fetch(order_by=['subject_id']) + + for column, field in zip(self.subject.fetch[:], [e[0] for e in tmp.dtype.descr]): + np.testing.assert_array_equal(sorted(tmp[field]), sorted(column), 'slice : does not work correctly') + + subject_notes, key, real_id = self.subject.fetch['subject_notes', dj.key, 'real_id'] + # + np.testing.assert_array_equal(sorted(subject_notes), sorted(tmp['subject_notes'])) + np.testing.assert_array_equal(sorted(real_id), sorted(tmp['real_id'])) + np.testing.assert_array_equal(sorted(key, key=itemgetter(0)), + sorted(self.subject.project().fetch(), key=itemgetter(0))) + + for column, field in zip(self.subject.fetch['subject_id'::2], [e[0] for e in tmp.dtype.descr][::2]): + np.testing.assert_array_equal(sorted(tmp[field]), sorted(column), 'slice : does not work correctly') + + def test_order_by(self): + """Tests order_by sorting order""" + langs = schema.Language.contents + + for ord_name, ord_lang in itertools.product(*2 * [['ASC', 'DESC']]): + cur = self.lang.fetch.order_by('name ' + ord_name, 'language ' + ord_lang)() + langs.sort(key=itemgetter(2), reverse=ord_lang == 'DESC') + langs.sort(key=itemgetter(1), reverse=ord_name == 'DESC') + for c, l in zip(cur, langs): + assert_true(np.all(cc == ll for cc, ll in zip(c, l)), 'Sorting order is different') + + def test_order_by_default(self): + """Tests order_by sorting order with defaults""" + langs = schema.Language.contents + + cur = self.lang.fetch.order_by('language', 'name DESC')() + langs.sort(key=itemgetter(1), reverse=True) + langs.sort(key=itemgetter(2), reverse=False) + + for c, l in zip(cur, langs): + assert_true(np.all([cc == ll for cc, ll in zip(c, l)]), 'Sorting order is different') + + def test_order_by_direct(self): + """Tests order_by sorting order passing it to __call__""" + langs = schema.Language.contents + + cur = self.lang.fetch(order_by=['language', 'name DESC']) + langs.sort(key=itemgetter(1), reverse=True) + langs.sort(key=itemgetter(2), reverse=False) + for c, l in zip(cur, langs): + assert_true(np.all([cc == ll for cc, ll in zip(c, l)]), 'Sorting order is different') + + def test_limit_to(self): + """Test the limit_to function """ + langs = schema.Language.contents + + cur = self.lang.fetch.limit_to(4)(order_by=['language', 'name DESC']) + langs.sort(key=itemgetter(1), reverse=True) + langs.sort(key=itemgetter(2), reverse=False) + assert_equal(len(cur), 4, 'Length is not correct') + for c, l in list(zip(cur, langs))[:4]: + assert_true(np.all([cc == ll for cc, ll in zip(c, l)]), 'Sorting order is different') + + def test_from_to(self): + """Test the from_to function """ + langs = schema.Language.contents + + cur = self.lang.fetch.from_to(2, 6)(order_by=['language', 'name DESC']) + langs.sort(key=itemgetter(1), reverse=True) + langs.sort(key=itemgetter(2), reverse=False) + assert_equal(len(cur), 4, 'Length is not correct') + for c, l in list(zip(cur, langs[2:6])): + assert_true(np.all([cc == ll for cc, ll in zip(c, l)]), 'Sorting order is different') + + def test_iter(self): + """Test iterator""" + langs = schema.Language.contents + + cur = self.lang.fetch.order_by('language', 'name DESC') + langs.sort(key=itemgetter(1), reverse=True) + langs.sort(key=itemgetter(2), reverse=False) + for (_, name, lang), (_, tname, tlang) in list(zip(cur, langs)): + assert_true(name == tname and lang == tlang, 'Values are not the same') + + def test_keys(self): + """test key iterator""" + langs = schema.Language.contents + langs.sort(key=itemgetter(1), reverse=True) + langs.sort(key=itemgetter(2), reverse=False) + + cur = self.lang.fetch.order_by('language', 'name DESC')['entry_id'] + cur2 = [e['entry_id'] for e in self.lang.fetch.order_by('language', 'name DESC').keys()] + + keys, _, _ = list(zip(*langs)) + for k, c, c2 in zip(keys, cur, cur2): + assert_true(k == c == c2, 'Values are not the same') + + def test_fetch1(self): + key = {'entry_id': 0} + true = schema.Language.contents[0] + + dat = (self.lang & key).fetch1() + for k, (ke, c) in zip(true, dat.items()): + assert_true(k == c == (self.lang & key).fetch1[ke], 'Values are not the same') + + def test_copy(self): + """Test whether modifications copy the object""" + f = self.lang.fetch + f2 = f.order_by('name') + assert_true(f.behavior['order_by'] is None and len(f2.behavior['order_by']) == 1, 'Object was not copied') + + def test_overwrite(self): + """Test whether order_by overwrites duplicates""" + f = self.lang.fetch.order_by('name DeSc ') + f2 = f.order_by('name') + assert_true(f2.behavior['order_by'] == ['name'], 'order_by attribute was not overwritten') \ No newline at end of file diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index e62389ddc..0a8fabc22 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -53,28 +53,3 @@ # def test_not(self): # pass -class TestRelationalOperand: - def __init__(self): - self.subject = schema.Subject() - - def test_getitem(self): - """Testing RelationalOperand.__getitem__""" - - np.testing.assert_array_equal(sorted(self.subject.project().fetch(), key=itemgetter(0)), - sorted(self.subject.fetch[dj.key], key=itemgetter(0)), - 'Primary key is not returned correctly') - - tmp = self.subject.fetch(order_by=['subject_id']) - - for column, field in zip(self.subject.fetch[:], [e[0] for e in tmp.dtype.descr]): - np.testing.assert_array_equal(sorted(tmp[field]), sorted(column), 'slice : does not work correctly') - - subject_notes, key, real_id = self.subject.fetch['subject_notes', dj.key, 'real_id'] - # - np.testing.assert_array_equal(sorted(subject_notes), sorted(tmp['subject_notes'])) - np.testing.assert_array_equal(sorted(real_id), sorted(tmp['real_id'])) - np.testing.assert_array_equal(sorted(key, key=itemgetter(0)), - sorted(self.subject.project().fetch(), key=itemgetter(0))) - - for column, field in zip(self.subject.fetch['subject_id'::2], [e[0] for e in tmp.dtype.descr][::2]): - np.testing.assert_array_equal(sorted(tmp[field]), sorted(column), 'slice : does not work correctly')