From bdef6c35d287a3e2a8a9f89d05522fe674e968ca Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Fri, 24 Jul 2015 15:44:49 -0500 Subject: [PATCH 1/4] bugfixes and more tests --- datajoint/fetch.py | 57 ++++++--------- datajoint/relational_operand.py | 31 +++++--- tests/schema.py | 20 +++++ tests/test_fetch.py | 121 +++++++++++++++++++++++++++++++ tests/test_relational_operand.py | 25 ------- 5 files changed, 182 insertions(+), 72 deletions(-) create mode 100644 tests/test_fetch.py diff --git a/datajoint/fetch.py b/datajoint/fetch.py index d5b99c38f..4369f82ad 100644 --- a/datajoint/fetch.py +++ b/datajoint/fetch.py @@ -1,8 +1,11 @@ from collections import OrderedDict +import itertools from .blob import unpack import numpy as np from datajoint import DataJointError from . import key as PRIMARY_KEY +from collections import abc + def prepare_attributes(relation, item): if isinstance(item, str) or item is PRIMARY_KEY: @@ -20,43 +23,35 @@ def prepare_attributes(relation, item): raise DataJointError("Index must be a slice, a tuple, a list, a string.") return item, attributes -class FetchQuery: +class Fetch: def __init__(self, relation): """ """ + self.behavior = dict( - offset=0, limit=None, order_by=None, descending=False, as_dict=False, map=None + offset=0, limit=None, order_by=None, as_dict=False ) self._relation = relation - def from_to(self, fro, to): self.behavior['offset'] = fro self.behavior['limit'] = to - fro return self - def order_by(self, order_by): - self.behavior['order_by'] = order_by + def order_by(self, *args): + + if len(args) > 0: + self.behavior['order_by'] = self.behavior['order_by'] if self.behavior['order_by'] is not None else [] + self.behavior['order_by'].extend(args) return self def as_dict(self): self.behavior['as_dict'] = True - def ascending(self): - self.behavior['descending'] = False - return self - - def descending(self): - self.behavior['descending'] = True - return self - - def apply(self, f): - self.behavior['map'] = f - return self - def limit_by(self, limit): + def limit_to(self, limit): self.behavior['limit'] = limit return self @@ -78,9 +73,7 @@ def __call__(self, **kwargs): """ behavior = dict(self.behavior, **kwargs) - cur = self._relation.cursor(offset=behavior['offset'], limit=behavior['limit'], - order_by=behavior['order_by'], descending=behavior['descending'], - as_dict=behavior['as_dict']) + cur = self._relation.cursor(**behavior) heading = self._relation.heading if behavior['as_dict']: @@ -92,22 +85,15 @@ def __call__(self, **kwargs): for blob_name in heading.blobs: ret[blob_name] = list(map(unpack, ret[blob_name])) - if behavior['map'] is not None: - f = behavior['map'] - for i in range(len(ret)): - ret[i] = f(ret[i]) - return ret def __iter__(self): """ Iterator that returns the contents of the database. """ - behavior = self.behavior + behavior = dict(self.behavior) - cur = self._relation.cursor(offset=behavior['offset'], limit=behavior['limit'], - order_by=behavior['order_by'], descending=behavior['descending'], - as_dict=behavior['as_dict']) + cur = self._relation.cursor(**behavior) heading = self._relation.heading do_unpack = tuple(h in heading.blobs for h in heading.names) @@ -126,10 +112,10 @@ def keys(self, **kwargs): """ Iterator that returns primary keys. """ + b = dict(self.behavior, **kwargs) if 'as_dict' not in kwargs: - kwargs['as_dict'] = True - yield from self._relation.project().fetch.set_behavior(**kwargs) - + b['as_dict'] = True + yield from self._relation.project().fetch.set_behavior(**b) def __getitem__(self, item): """ @@ -146,7 +132,7 @@ def __getitem__(self, item): single_output = isinstance(item, str) or item is PRIMARY_KEY or isinstance(item, int) item, attributes = prepare_attributes(self._relation, item) - result = self._relation.project(*attributes).fetch() + result = self._relation.project(*attributes).fetch(**self.behavior) return_values = [ np.ndarray(result.shape, np.dtype({name: result.dtype.fields[name] for name in self._relation.primary_key}), @@ -158,8 +144,7 @@ def __getitem__(self, item): return return_values[0] if single_output else return_values -class Fetch1Query: - +class Fetch1: def __init__(self, relation): self._relation = relation @@ -202,4 +187,4 @@ def __getitem__(self, item): else result[attribute][0] for attribute in item ) - return return_values[0] if single_output else return_values \ No newline at end of file + return return_values[0] if single_output else return_values diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index a89cd2c76..249dba843 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -10,7 +10,7 @@ from . import DataJointError import logging -from .fetch import FetchQuery, Fetch1Query +from .fetch import Fetch, Fetch1 logger = logging.getLogger(__name__) @@ -171,7 +171,7 @@ def __call__(self, *args, **kwargs): """ return self.fetch(*args, **kwargs) - def cursor(self, offset=0, limit=None, order_by=None, descending=False, as_dict=False): + def cursor(self, offset=0, limit=None, order_by=None, as_dict=False): """ Return query cursor. See Relation.fetch() for input description. @@ -181,9 +181,19 @@ def cursor(self, offset=0, limit=None, order_by=None, descending=False, as_dict= raise DataJointError('limit is required when offset is set') sql = self.make_select() if order_by is not None: - sql += ' ORDER BY ' + ', '.join(order_by) - if descending: - sql += ' DESC' + order_attr, order = [], [] + for a in order_by: + a = [e.strip() for e in a.split('=')] + if len(a) == 1: + order_attr.append(a[0]) + order.append('ASC') + else: + order_attr.append(a[0]) + order.append(a[1]) + + sql += ' ORDER BY ' + ', '.join(['`%s` %s' % (attr, val) for attr, val in + zip(order_attr, order)]) + if limit is not None: sql += ' LIMIT %d' % limit if offset: @@ -206,12 +216,13 @@ def __repr__(self): repr_string += ' (%d tuples)\n' % len(self) return repr_string + @property def fetch1(self): - return Fetch1Query(self) + return Fetch1(self) @property def fetch(self): - return FetchQuery(self) + return Fetch(self) @property def where_clause(self): @@ -253,8 +264,6 @@ def make_condition(arg): return ' WHERE ' + ' AND '.join(condition_string) - - class Not: """ inverse restriction @@ -319,9 +328,9 @@ def __init__(self, arg, group=None, *attributes, **renamed_attributes): self._arg = Subquery(arg) else: self._group = None - if arg.heading.computed or\ + if arg.heading.computed or \ (isinstance(arg.restrictions, RelationalOperand) and \ - all(attr in self._attributes for attr in arg.restrictions.heading.names)) : + all(attr in self._attributes for attr in arg.restrictions.heading.names)): # can simply the expression because all restrictions attrs are projected out anyway! self._arg = arg self._restrictions = self._arg.restrictions diff --git a/tests/schema.py b/tests/schema.py index 55e2e094f..f496e7b92 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -39,6 +39,26 @@ class Subject(dj.Manual): def prepare(self): self.insert(self.contents, ignore_errors=True) +@schema +class Language(dj.Lookup): + + definition = """ + # languages spoken by some of the developers + + entry_id : int + --- + name : varchar(40) # name of the developer + language : varchar(40) # language + """ + + contents = [ + (0, 'Fabian', 'English'), + (1, 'Edgar', 'English'), + (2, 'Dimitri', 'English'), + (3, 'Dimitri', 'Ukrainian'), + (4, 'Fabian', 'German'), + (5, 'Edgar', 'Japanese'), + ] @schema class Experiment(dj.Imported): diff --git a/tests/test_fetch.py b/tests/test_fetch.py new file mode 100644 index 000000000..5dd24b93e --- /dev/null +++ b/tests/test_fetch.py @@ -0,0 +1,121 @@ +from operator import itemgetter, attrgetter +import itertools +from nose.tools import assert_true +from numpy.testing import assert_array_equal, assert_equal +import numpy as np + +from . import schema +import datajoint as dj + + +class TestFetch: + def __init__(self): + self.subject = schema.Subject() + self.lang = schema.Language() + + def test_getitem(self): + """Testing Fetch.__getitem__""" + + np.testing.assert_array_equal(sorted(self.subject.project().fetch(), key=itemgetter(0)), + sorted(self.subject.fetch[dj.key], key=itemgetter(0)), + 'Primary key is not returned correctly') + + tmp = self.subject.fetch(order_by=['subject_id']) + + for column, field in zip(self.subject.fetch[:], [e[0] for e in tmp.dtype.descr]): + np.testing.assert_array_equal(sorted(tmp[field]), sorted(column), 'slice : does not work correctly') + + subject_notes, key, real_id = self.subject.fetch['subject_notes', dj.key, 'real_id'] + # + np.testing.assert_array_equal(sorted(subject_notes), sorted(tmp['subject_notes'])) + np.testing.assert_array_equal(sorted(real_id), sorted(tmp['real_id'])) + np.testing.assert_array_equal(sorted(key, key=itemgetter(0)), + sorted(self.subject.project().fetch(), key=itemgetter(0))) + + for column, field in zip(self.subject.fetch['subject_id'::2], [e[0] for e in tmp.dtype.descr][::2]): + np.testing.assert_array_equal(sorted(tmp[field]), sorted(column), 'slice : does not work correctly') + + def test_order_by(self): + """Tests order_by sorting order""" + langs = schema.Language.contents + + for ord_name, ord_lang in itertools.product(*2 * [['ASC', 'DESC']]): + cur = self.lang.fetch.order_by('name=' + ord_name, 'language=' + ord_lang)() + langs.sort(key=itemgetter(2), reverse=ord_lang == 'DESC') + langs.sort(key=itemgetter(1), reverse=ord_name == 'DESC') + for c, l in zip(cur, langs): + assert_true(np.all(cc == ll for cc, ll in zip(c, l)), 'Sorting order is different') + + def test_order_by_default(self): + """Tests order_by sorting order with defaults""" + langs = schema.Language.contents + + cur = self.lang.fetch.order_by('language', 'name=DESC')() + langs.sort(key=itemgetter(1), reverse=True) + langs.sort(key=itemgetter(2), reverse=False) + + for c, l in zip(cur, langs): + assert_true(np.all([cc == ll for cc, ll in zip(c, l)]), 'Sorting order is different') + + def test_order_by_direct(self): + """Tests order_by sorting order passing it to __call__""" + langs = schema.Language.contents + + cur = self.lang.fetch(order_by=['language', 'name=DESC']) + langs.sort(key=itemgetter(1), reverse=True) + langs.sort(key=itemgetter(2), reverse=False) + for c, l in zip(cur, langs): + assert_true(np.all([cc == ll for cc, ll in zip(c, l)]), 'Sorting order is different') + + def test_limit_to(self): + """Test the limit_to function """ + langs = schema.Language.contents + + cur = self.lang.fetch.limit_to(4)(order_by=['language', 'name=DESC']) + langs.sort(key=itemgetter(1), reverse=True) + langs.sort(key=itemgetter(2), reverse=False) + assert_equal(len(cur), 4, 'Length is not correct') + for c, l in list(zip(cur, langs))[:4]: + assert_true(np.all([cc == ll for cc, ll in zip(c, l)]), 'Sorting order is different') + + def test_from_to(self): + """Test the from_to function """ + langs = schema.Language.contents + + cur = self.lang.fetch.from_to(2, 6)(order_by=['language', 'name=DESC']) + langs.sort(key=itemgetter(1), reverse=True) + langs.sort(key=itemgetter(2), reverse=False) + assert_equal(len(cur), 4, 'Length is not correct') + for c, l in list(zip(cur, langs[2:6])): + assert_true(np.all([cc == ll for cc, ll in zip(c, l)]), 'Sorting order is different') + + def test_iter(self): + """Test iterator""" + langs = schema.Language.contents + + cur = self.lang.fetch.order_by('language', 'name=DESC') + langs.sort(key=itemgetter(1), reverse=True) + langs.sort(key=itemgetter(2), reverse=False) + for (_, name, lang), (_, tname, tlang) in list(zip(cur, langs)): + assert_true(name == tname and lang == tlang, 'Values are not the same') + + def test_keys(self): + """test key iterator""" + langs = schema.Language.contents + langs.sort(key=itemgetter(1), reverse=True) + langs.sort(key=itemgetter(2), reverse=False) + + cur = self.lang.fetch.order_by('language', 'name=DESC')['entry_id'] + cur2 = [e['entry_id'] for e in self.lang.fetch.order_by('language', 'name=DESC').keys()] + + keys, _, _ = list(zip(*langs)) + for k, c, c2 in zip(keys, cur, cur2): + assert_true(k == c == c2, 'Values are not the same') + + def test_fetch1(self): + key = {'entry_id': 0} + true = schema.Language.contents[0] + + dat = (self.lang & key).fetch1() + for k, (ke, c) in zip(true, dat.items()): + assert_true(k == c == (self.lang & key).fetch1[ke], 'Values are not the same') diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index e62389ddc..0a8fabc22 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -53,28 +53,3 @@ # def test_not(self): # pass -class TestRelationalOperand: - def __init__(self): - self.subject = schema.Subject() - - def test_getitem(self): - """Testing RelationalOperand.__getitem__""" - - np.testing.assert_array_equal(sorted(self.subject.project().fetch(), key=itemgetter(0)), - sorted(self.subject.fetch[dj.key], key=itemgetter(0)), - 'Primary key is not returned correctly') - - tmp = self.subject.fetch(order_by=['subject_id']) - - for column, field in zip(self.subject.fetch[:], [e[0] for e in tmp.dtype.descr]): - np.testing.assert_array_equal(sorted(tmp[field]), sorted(column), 'slice : does not work correctly') - - subject_notes, key, real_id = self.subject.fetch['subject_notes', dj.key, 'real_id'] - # - np.testing.assert_array_equal(sorted(subject_notes), sorted(tmp['subject_notes'])) - np.testing.assert_array_equal(sorted(real_id), sorted(tmp['real_id'])) - np.testing.assert_array_equal(sorted(key, key=itemgetter(0)), - sorted(self.subject.project().fetch(), key=itemgetter(0))) - - for column, field in zip(self.subject.fetch['subject_id'::2], [e[0] for e in tmp.dtype.descr][::2]): - np.testing.assert_array_equal(sorted(tmp[field]), sorted(column), 'slice : does not work correctly') From 9b28802e1078bf9bbac7fc5f142c0ffaf59115b9 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Fri, 24 Jul 2015 17:01:14 -0500 Subject: [PATCH 2/4] fetch object is copied and order_by attributes are overwritten --- doc/source/_static/.dummy | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 doc/source/_static/.dummy diff --git a/doc/source/_static/.dummy b/doc/source/_static/.dummy new file mode 100644 index 000000000..e69de29bb From 13f5b0c4536994f1f78a74b7301e591d264fd506 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Fri, 24 Jul 2015 17:01:17 -0500 Subject: [PATCH 3/4] fetch object is copied and order_by attributes are overwritten --- datajoint/fetch.py | 36 ++++++++++++++++++++++++++++-------- tests/test_fetch.py | 12 ++++++++++++ 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/datajoint/fetch.py b/datajoint/fetch.py index 4369f82ad..910ff5106 100644 --- a/datajoint/fetch.py +++ b/datajoint/fetch.py @@ -1,5 +1,7 @@ from collections import OrderedDict +from functools import wraps import itertools +import re from .blob import unpack import numpy as np from datajoint import DataJointError @@ -23,38 +25,56 @@ def prepare_attributes(relation, item): raise DataJointError("Index must be a slice, a tuple, a list, a string.") return item, attributes +def copy_first(f): + @wraps(f) + def ret(*args, **kwargs): + args = list(args) + args[0] = args[0].__class__(args[0]) # call copy constructor + return f(*args, **kwargs) + + return ret class Fetch: def __init__(self, relation): - """ + if isinstance(relation, Fetch): # copy constructor + self.behavior = dict(relation.behavior) + self._relation = relation._relation + else: + self.behavior = dict( + offset=0, limit=None, order_by=None, as_dict=False + ) + self._relation = relation - """ - - self.behavior = dict( - offset=0, limit=None, order_by=None, as_dict=False - ) - self._relation = relation + @copy_first def from_to(self, fro, to): self.behavior['offset'] = fro self.behavior['limit'] = to - fro return self + @copy_first def order_by(self, *args): - if len(args) > 0: self.behavior['order_by'] = self.behavior['order_by'] if self.behavior['order_by'] is not None else [] + for a in args: # remove duplicates + name = a.split('=')[0].strip() + pat = re.compile(r"%s\s*(=\s*(DESC|ASC)\s*|)?$" % (name,), re.I) + self.behavior['order_by'] = [e for e in self.behavior['order_by'] if not pat.match(e)] self.behavior['order_by'].extend(args) return self + @copy_first def as_dict(self): self.behavior['as_dict'] = True + return self + @copy_first def limit_to(self, limit): self.behavior['limit'] = limit return self + @copy_first def set_behavior(self, **kwargs): self.behavior.update(kwargs) return self diff --git a/tests/test_fetch.py b/tests/test_fetch.py index 5dd24b93e..001f266bb 100644 --- a/tests/test_fetch.py +++ b/tests/test_fetch.py @@ -119,3 +119,15 @@ def test_fetch1(self): dat = (self.lang & key).fetch1() for k, (ke, c) in zip(true, dat.items()): assert_true(k == c == (self.lang & key).fetch1[ke], 'Values are not the same') + + def test_copy(self): + """Test whether modifications copy the object""" + f = self.lang.fetch + f2 = f.order_by('name') + assert_true(f.behavior['order_by'] is None and len(f2.behavior['order_by']) == 1, 'Object was not copied') + + def test_overwrite(self): + """Test whether order_by overwrites duplicates""" + f = self.lang.fetch.order_by('name = DeSc ') + f2 = f.order_by('name') + assert_true(f2.behavior['order_by'] == ['name'], 'order_by attribute was not overwritten') \ No newline at end of file From dad10f85c630d7bc2f3b19ab4161171f44cd7a88 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Fri, 24 Jul 2015 17:23:58 -0500 Subject: [PATCH 4/4] syntax change --- datajoint/fetch.py | 5 +++-- datajoint/relational_operand.py | 13 +------------ tests/test_fetch.py | 18 +++++++++--------- 3 files changed, 13 insertions(+), 23 deletions(-) diff --git a/datajoint/fetch.py b/datajoint/fetch.py index 910ff5106..13758d42f 100644 --- a/datajoint/fetch.py +++ b/datajoint/fetch.py @@ -56,9 +56,10 @@ def from_to(self, fro, to): def order_by(self, *args): if len(args) > 0: self.behavior['order_by'] = self.behavior['order_by'] if self.behavior['order_by'] is not None else [] + namepat = re.compile(r"\s*(?P\w+).*") for a in args: # remove duplicates - name = a.split('=')[0].strip() - pat = re.compile(r"%s\s*(=\s*(DESC|ASC)\s*|)?$" % (name,), re.I) + name = namepat.match(a).group('name') + pat = re.compile(r"%s(\s*$|\s+(\S*\s*)*$)" % (name,)) self.behavior['order_by'] = [e for e in self.behavior['order_by'] if not pat.match(e)] self.behavior['order_by'].extend(args) return self diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index 249dba843..769dd7986 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -181,18 +181,7 @@ def cursor(self, offset=0, limit=None, order_by=None, as_dict=False): raise DataJointError('limit is required when offset is set') sql = self.make_select() if order_by is not None: - order_attr, order = [], [] - for a in order_by: - a = [e.strip() for e in a.split('=')] - if len(a) == 1: - order_attr.append(a[0]) - order.append('ASC') - else: - order_attr.append(a[0]) - order.append(a[1]) - - sql += ' ORDER BY ' + ', '.join(['`%s` %s' % (attr, val) for attr, val in - zip(order_attr, order)]) + sql += ' ORDER BY ' + ', '.join(order_by) if limit is not None: sql += ' LIMIT %d' % limit diff --git a/tests/test_fetch.py b/tests/test_fetch.py index 001f266bb..37746da07 100644 --- a/tests/test_fetch.py +++ b/tests/test_fetch.py @@ -40,7 +40,7 @@ def test_order_by(self): langs = schema.Language.contents for ord_name, ord_lang in itertools.product(*2 * [['ASC', 'DESC']]): - cur = self.lang.fetch.order_by('name=' + ord_name, 'language=' + ord_lang)() + cur = self.lang.fetch.order_by('name ' + ord_name, 'language ' + ord_lang)() langs.sort(key=itemgetter(2), reverse=ord_lang == 'DESC') langs.sort(key=itemgetter(1), reverse=ord_name == 'DESC') for c, l in zip(cur, langs): @@ -50,7 +50,7 @@ def test_order_by_default(self): """Tests order_by sorting order with defaults""" langs = schema.Language.contents - cur = self.lang.fetch.order_by('language', 'name=DESC')() + cur = self.lang.fetch.order_by('language', 'name DESC')() langs.sort(key=itemgetter(1), reverse=True) langs.sort(key=itemgetter(2), reverse=False) @@ -61,7 +61,7 @@ def test_order_by_direct(self): """Tests order_by sorting order passing it to __call__""" langs = schema.Language.contents - cur = self.lang.fetch(order_by=['language', 'name=DESC']) + cur = self.lang.fetch(order_by=['language', 'name DESC']) langs.sort(key=itemgetter(1), reverse=True) langs.sort(key=itemgetter(2), reverse=False) for c, l in zip(cur, langs): @@ -71,7 +71,7 @@ def test_limit_to(self): """Test the limit_to function """ langs = schema.Language.contents - cur = self.lang.fetch.limit_to(4)(order_by=['language', 'name=DESC']) + cur = self.lang.fetch.limit_to(4)(order_by=['language', 'name DESC']) langs.sort(key=itemgetter(1), reverse=True) langs.sort(key=itemgetter(2), reverse=False) assert_equal(len(cur), 4, 'Length is not correct') @@ -82,7 +82,7 @@ def test_from_to(self): """Test the from_to function """ langs = schema.Language.contents - cur = self.lang.fetch.from_to(2, 6)(order_by=['language', 'name=DESC']) + cur = self.lang.fetch.from_to(2, 6)(order_by=['language', 'name DESC']) langs.sort(key=itemgetter(1), reverse=True) langs.sort(key=itemgetter(2), reverse=False) assert_equal(len(cur), 4, 'Length is not correct') @@ -93,7 +93,7 @@ def test_iter(self): """Test iterator""" langs = schema.Language.contents - cur = self.lang.fetch.order_by('language', 'name=DESC') + cur = self.lang.fetch.order_by('language', 'name DESC') langs.sort(key=itemgetter(1), reverse=True) langs.sort(key=itemgetter(2), reverse=False) for (_, name, lang), (_, tname, tlang) in list(zip(cur, langs)): @@ -105,8 +105,8 @@ def test_keys(self): langs.sort(key=itemgetter(1), reverse=True) langs.sort(key=itemgetter(2), reverse=False) - cur = self.lang.fetch.order_by('language', 'name=DESC')['entry_id'] - cur2 = [e['entry_id'] for e in self.lang.fetch.order_by('language', 'name=DESC').keys()] + cur = self.lang.fetch.order_by('language', 'name DESC')['entry_id'] + cur2 = [e['entry_id'] for e in self.lang.fetch.order_by('language', 'name DESC').keys()] keys, _, _ = list(zip(*langs)) for k, c, c2 in zip(keys, cur, cur2): @@ -128,6 +128,6 @@ def test_copy(self): def test_overwrite(self): """Test whether order_by overwrites duplicates""" - f = self.lang.fetch.order_by('name = DeSc ') + f = self.lang.fetch.order_by('name DeSc ') f2 = f.order_by('name') assert_true(f2.behavior['order_by'] == ['name'], 'order_by attribute was not overwritten') \ No newline at end of file