diff --git a/datajoint/declare.py b/datajoint/declare.py index 8b357b8c0..4f2cd2463 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -114,6 +114,5 @@ def compile_attribute(line, in_key=False): else: match['default'] = 'NOT NULL' match['comment'] = match['comment'].replace('"', '\\"') # escape double quotes in comment - sql = ('`{name}` {type} {default}' + (' COMMENT "{comment}"' if match['comment'] else '') - ).format(**match) + sql = ('`{name}` {type} {default}' + (' COMMENT "{comment}"' if match['comment'] else '')).format(**match) return match['name'], sql diff --git a/datajoint/fetch.py b/datajoint/fetch.py index 955116892..395885528 100644 --- a/datajoint/fetch.py +++ b/datajoint/fetch.py @@ -1,4 +1,5 @@ from collections import OrderedDict +from collections.abc import Callable, Iterable from functools import wraps import warnings from .blob import unpack @@ -47,7 +48,7 @@ def ret(*args, **kwargs): return ret -class Fetch: +class Fetch(Iterable, Callable): """ A fetch object that handles retrieving elements from the database table. @@ -59,9 +60,7 @@ def __init__(self, relation): self.behavior = dict(relation.behavior) self._relation = relation._relation else: - self.behavior = dict( - offset=None, limit=None, order_by=None, as_dict=False - ) + self.behavior = dict(offset=None, limit=None, order_by=None, as_dict=False) self._relation = relation @copy_first @@ -240,7 +239,7 @@ def __len__(self): return len(self._relation) -class Fetch1: +class Fetch1(Callable): """ Fetch object for fetching exactly one row. diff --git a/datajoint/heading.py b/datajoint/heading.py index e3986c032..7a742416b 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -164,9 +164,6 @@ def init_from_database(self, conn, database, table_name): attr['string'] = bool(re.match(r'(var)?char|enum|date|time|timestamp', attr['type'])) attr['is_blob'] = bool(re.match(r'(tiny|medium|long)?blob', attr['type'])) - # strip field lengths off integer types - attr['type'] = re.sub(r'((tiny|small|medium|big)?int)\(\d+\)', r'\1', attr['type']) - attr['computation'] = None if not (attr['numeric'] or attr['string'] or attr['is_blob']): raise DataJointError('Unsupported field type {field} in `{database}`.`{table_name}`'.format( diff --git a/datajoint/relation.py b/datajoint/relation.py index 9f3ba6f49..ce02351dd 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -1,4 +1,5 @@ -from collections import Mapping, OrderedDict +from collections.abc import Mapping +from collections import OrderedDict import numpy as np import logging import abc diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index 2374f5d1b..d1733d67b 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -1,3 +1,4 @@ +from collections.abc import Iterable, Mapping import numpy as np import abc import re @@ -103,7 +104,7 @@ def aggregate(self, group, *attributes, **renamed_attributes): if not isinstance(group, RelationalOperand): raise DataJointError('The second argument must be a relation') return Aggregation( - Join(self, Subquery(group), left=True), + Join(self, group, left=True), *attributes, **renamed_attributes) def __and__(self, restriction): @@ -112,8 +113,7 @@ def __and__(self, restriction): :return: a restricted copy of the argument """ ret = copy(self) - ret._restrictions = list(ret.restrictions) # copy restriction list - ret.restrict(restriction) + ret.restrict(restriction, *ret.restrictions) return ret def restrict(self, *restrictions): @@ -124,7 +124,8 @@ def restrict(self, *restrictions): However, each member of restrictions can be a list of conditions, which are combined with OR. :param restrictions: list of restrictions. """ - restrictions = [r for r in restrictions if r is not None] # remove Nones + # remove Nones and duplicates + restrictions = [r for r in restrictions if r is not None and r not in self.restrictions] if restrictions: if any(is_empty_set(r) for r in restrictions): # if any condition is an empty list, return empty @@ -135,6 +136,14 @@ def restrict(self, *restrictions): else: self._restrictions.extend(restrictions) + def attributes_in_restrictions(self): + """ + :return: list of attributes that are probably used in the restrictions. + This is used internally for optimizing SQL statements + """ + where_clause = self.where_clause + return set(name for name in self.heading.names if name in where_clause) + def __sub__(self, restriction): """ inverted restriction aka antijoin @@ -212,7 +221,7 @@ def where_clause(self): return '' def make_condition(arg): - if isinstance(arg, dict): + if isinstance(arg, Mapping): condition = ['`%s`=%s' % (k, repr(v)) for k, v in arg.items() if k in self.heading] elif isinstance(arg, np.void): condition = ['`%s`=%s' % (k, arg[k]) for k in arg.dtype.fields if k in self.heading] @@ -225,20 +234,20 @@ def make_condition(arg): negate = isinstance(r, Not) if negate: r = r.restriction - if isinstance(r, dict) or isinstance(r, np.void): + if isinstance(r, Mapping) or isinstance(r, np.void): r = make_condition(r) elif isinstance(r, np.ndarray) or isinstance(r, list): r = '(' + ') OR ('.join([make_condition(q) for q in r]) + ')' elif isinstance(r, RelationalOperand): - common_attributes = ','.join([q for q in self.heading.names if q in r.heading.names]) + common_attributes = [q for q in self.heading.names if q in r.heading.names] if not common_attributes: r = 'FALSE' if negate else 'TRUE' else: - r = '({fields}) {not_}in (SELECT {fields} FROM {from_}{where})'.format( + common_attributes = '`'+'`,`'.join(common_attributes)+'`' + r = '({fields}) {not_}in ({subquery})'.format( fields=common_attributes, not_="not " if negate else "", - from_=r.from_clause, - where=r.where_clause) + subquery=r.make_select(common_attributes)) negate = False if not isinstance(r, str): raise DataJointError('Invalid restriction object') @@ -267,7 +276,8 @@ def __init__(self, arg1, arg2, left=False): raise DataJointError('Cannot join relations with different database connections') self._arg1 = Subquery(arg1) if arg1.heading.computed else arg1 self._arg2 = Subquery(arg2) if arg2.heading.computed else arg2 - self._restrictions = self._arg1.restrictions + self._arg2.restrictions + self.restrict(*self._arg1.restrictions) + self.restrict(*self._arg2.restrictions) self._left = left self._heading = self._arg1.heading.join(self._arg2.heading, left=left) @@ -311,12 +321,15 @@ def __init__(self, arg, *attributes, **renamed_attributes): self._renamed_attributes.update({d['alias']: d['sql_expression']}) else: self._attributes.append(attribute) + self._arg = arg - if arg.heading.computed or arg.restrictions: + restricting_on_removed_attributes = bool( + arg.attributes_in_restrictions() - set(self.heading.names)) + use_subquery = restricting_on_removed_attributes or arg.heading.computed + if use_subquery: self._arg = Subquery(arg) else: - self._arg = arg - self._restrictions = arg.restrictions + self.restrict(*arg.restrictions) def _repr_helper(self): return "(%r).project(%r)" % (self._arg, self._attributes) @@ -329,13 +342,17 @@ def connection(self): def heading(self): return self._arg.heading.project(*self._attributes, **self._renamed_attributes) + @property + def _grouped(self): + return self._arg._grouped + @property def from_clause(self): return self._arg.from_clause def __and__(self, restriction): """ - When projection has renamed attributes, it must be enclosed in a subquery before restriction + When restricting on renamed attributes, enclose in subquery """ has_restriction = isinstance(restriction, RelationalOperand) or restriction do_subquery = has_restriction and self.heading.computed diff --git a/setup.py b/setup.py index 112b22ce7..1e95bf956 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ long_description=long_description, author='Dimitri Yatsenko', author_email='Dimitri.Yatsenko@gmail.com', - license = "GNU LGPL", + license="GNU LGPL", url='https://github.com/datajoint/datajoint-python', keywords='database organization', packages=find_packages(exclude=['contrib', 'docs', 'tests*']), diff --git a/tests/test_fetch.py b/tests/test_fetch.py index bac9eda77..e4757998c 100644 --- a/tests/test_fetch.py +++ b/tests/test_fetch.py @@ -38,10 +38,11 @@ def test_getitem(self): def test_getitem_for_fetch1(self): """Testing Fetch1.__getitem__""" - assert_true( (self.subject & "subject_id=10").fetch1['subject_id'] == 10) - assert_true( (self.subject & "subject_id=10").fetch1['subject_id','species'] == (10, 'monkey')) - assert_true( (self.subject & "subject_id=10").fetch1['subject_id':'species'] == (10, 'Curious George')) - + assert_true((self.subject & "subject_id=10").fetch1['subject_id'] == 10) + assert_equal((self.subject & "subject_id=10").fetch1['subject_id', 'species'], + (10, 'monkey')) + assert_equal((self.subject & "subject_id=10").fetch1['subject_id':'species'], + (10, 'Curious George')) def test_order_by(self): """Tests order_by sorting order""" @@ -113,14 +114,14 @@ def test_keys(self): langs.sort(key=itemgetter(0), reverse=True) langs.sort(key=itemgetter(1), reverse=False) - cur = self.lang.fetch.order_by('language', 'name DESC')['name','language'] + cur = self.lang.fetch.order_by('language', 'name DESC')['name', 'language'] cur2 = list(self.lang.fetch.order_by('language', 'name DESC').keys()) for c, c2 in zip(zip(*cur), cur2): assert_true(c == tuple(c2.values()), 'Values are not the same') def test_fetch1(self): - key = {'name': 'Edgar', 'language':'Japanese'} + key = {'name': 'Edgar', 'language': 'Japanese'} true = schema.Language.contents[-1] dat = (self.lang & key).fetch1() @@ -170,7 +171,6 @@ def test_offset(self): for c, l in list(zip(cur, langs[1:]))[:4]: assert_true(np.all([cc == ll for cc, ll in zip(c, l)]), 'Sorting order is different') - def test_limit_warning(self): """Tests whether warning is raised if offset is used without limit.""" with warnings.catch_warnings(record=True) as w: