Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions datajoint/declare.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,5 @@ def compile_attribute(line, in_key=False):
else:
match['default'] = 'NOT NULL'
match['comment'] = match['comment'].replace('"', '\\"') # escape double quotes in comment
sql = ('`{name}` {type} {default}' + (' COMMENT "{comment}"' if match['comment'] else '')
).format(**match)
sql = ('`{name}` {type} {default}' + (' COMMENT "{comment}"' if match['comment'] else '')).format(**match)
return match['name'], sql
9 changes: 4 additions & 5 deletions datajoint/fetch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import OrderedDict
from collections.abc import Callable, Iterable
from functools import wraps
import warnings
from .blob import unpack
Expand Down Expand Up @@ -47,7 +48,7 @@ def ret(*args, **kwargs):
return ret


class Fetch:
class Fetch(Iterable, Callable):
"""
A fetch object that handles retrieving elements from the database table.

Expand All @@ -59,9 +60,7 @@ def __init__(self, relation):
self.behavior = dict(relation.behavior)
self._relation = relation._relation
else:
self.behavior = dict(
offset=None, limit=None, order_by=None, as_dict=False
)
self.behavior = dict(offset=None, limit=None, order_by=None, as_dict=False)
self._relation = relation

@copy_first
Expand Down Expand Up @@ -240,7 +239,7 @@ def __len__(self):
return len(self._relation)


class Fetch1:
class Fetch1(Callable):
"""
Fetch object for fetching exactly one row.

Expand Down
3 changes: 0 additions & 3 deletions datajoint/heading.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,6 @@ def init_from_database(self, conn, database, table_name):
attr['string'] = bool(re.match(r'(var)?char|enum|date|time|timestamp', attr['type']))
attr['is_blob'] = bool(re.match(r'(tiny|medium|long)?blob', attr['type']))

# strip field lengths off integer types
attr['type'] = re.sub(r'((tiny|small|medium|big)?int)\(\d+\)', r'\1', attr['type'])

attr['computation'] = None
if not (attr['numeric'] or attr['string'] or attr['is_blob']):
raise DataJointError('Unsupported field type {field} in `{database}`.`{table_name}`'.format(
Expand Down
3 changes: 2 additions & 1 deletion datajoint/relation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import Mapping, OrderedDict
from collections.abc import Mapping
from collections import OrderedDict
import numpy as np
import logging
import abc
Expand Down
47 changes: 32 additions & 15 deletions datajoint/relational_operand.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Iterable, Mapping
import numpy as np
import abc
import re
Expand Down Expand Up @@ -103,7 +104,7 @@ def aggregate(self, group, *attributes, **renamed_attributes):
if not isinstance(group, RelationalOperand):
raise DataJointError('The second argument must be a relation')
return Aggregation(
Join(self, Subquery(group), left=True),
Join(self, group, left=True),
*attributes, **renamed_attributes)

def __and__(self, restriction):
Expand All @@ -112,8 +113,7 @@ def __and__(self, restriction):
:return: a restricted copy of the argument
"""
ret = copy(self)
ret._restrictions = list(ret.restrictions) # copy restriction list
ret.restrict(restriction)
ret.restrict(restriction, *ret.restrictions)
return ret

def restrict(self, *restrictions):
Expand All @@ -124,7 +124,8 @@ def restrict(self, *restrictions):
However, each member of restrictions can be a list of conditions, which are combined with OR.
:param restrictions: list of restrictions.
"""
restrictions = [r for r in restrictions if r is not None] # remove Nones
# remove Nones and duplicates
restrictions = [r for r in restrictions if r is not None and r not in self.restrictions]
if restrictions:
if any(is_empty_set(r) for r in restrictions):
# if any condition is an empty list, return empty
Expand All @@ -135,6 +136,14 @@ def restrict(self, *restrictions):
else:
self._restrictions.extend(restrictions)

def attributes_in_restrictions(self):
"""
:return: list of attributes that are probably used in the restrictions.
This is used internally for optimizing SQL statements
"""
where_clause = self.where_clause
return set(name for name in self.heading.names if name in where_clause)

def __sub__(self, restriction):
"""
inverted restriction aka antijoin
Expand Down Expand Up @@ -212,7 +221,7 @@ def where_clause(self):
return ''

def make_condition(arg):
if isinstance(arg, dict):
if isinstance(arg, Mapping):
condition = ['`%s`=%s' % (k, repr(v)) for k, v in arg.items() if k in self.heading]
elif isinstance(arg, np.void):
condition = ['`%s`=%s' % (k, arg[k]) for k in arg.dtype.fields if k in self.heading]
Expand All @@ -225,20 +234,20 @@ def make_condition(arg):
negate = isinstance(r, Not)
if negate:
r = r.restriction
if isinstance(r, dict) or isinstance(r, np.void):
if isinstance(r, Mapping) or isinstance(r, np.void):
r = make_condition(r)
elif isinstance(r, np.ndarray) or isinstance(r, list):
r = '(' + ') OR ('.join([make_condition(q) for q in r]) + ')'
elif isinstance(r, RelationalOperand):
common_attributes = ','.join([q for q in self.heading.names if q in r.heading.names])
common_attributes = [q for q in self.heading.names if q in r.heading.names]
if not common_attributes:
r = 'FALSE' if negate else 'TRUE'
else:
r = '({fields}) {not_}in (SELECT {fields} FROM {from_}{where})'.format(
common_attributes = '`'+'`,`'.join(common_attributes)+'`'
r = '({fields}) {not_}in ({subquery})'.format(
fields=common_attributes,
not_="not " if negate else "",
from_=r.from_clause,
where=r.where_clause)
subquery=r.make_select(common_attributes))
negate = False
if not isinstance(r, str):
raise DataJointError('Invalid restriction object')
Expand Down Expand Up @@ -267,7 +276,8 @@ def __init__(self, arg1, arg2, left=False):
raise DataJointError('Cannot join relations with different database connections')
self._arg1 = Subquery(arg1) if arg1.heading.computed else arg1
self._arg2 = Subquery(arg2) if arg2.heading.computed else arg2
self._restrictions = self._arg1.restrictions + self._arg2.restrictions
self.restrict(*self._arg1.restrictions)
self.restrict(*self._arg2.restrictions)
self._left = left
self._heading = self._arg1.heading.join(self._arg2.heading, left=left)

Expand Down Expand Up @@ -311,12 +321,15 @@ def __init__(self, arg, *attributes, **renamed_attributes):
self._renamed_attributes.update({d['alias']: d['sql_expression']})
else:
self._attributes.append(attribute)
self._arg = arg

if arg.heading.computed or arg.restrictions:
restricting_on_removed_attributes = bool(
arg.attributes_in_restrictions() - set(self.heading.names))
use_subquery = restricting_on_removed_attributes or arg.heading.computed
if use_subquery:
self._arg = Subquery(arg)
else:
self._arg = arg
self._restrictions = arg.restrictions
self.restrict(*arg.restrictions)

def _repr_helper(self):
return "(%r).project(%r)" % (self._arg, self._attributes)
Expand All @@ -329,13 +342,17 @@ def connection(self):
def heading(self):
return self._arg.heading.project(*self._attributes, **self._renamed_attributes)

@property
def _grouped(self):
return self._arg._grouped

@property
def from_clause(self):
return self._arg.from_clause

def __and__(self, restriction):
"""
When projection has renamed attributes, it must be enclosed in a subquery before restriction
When restricting on renamed attributes, enclose in subquery
"""
has_restriction = isinstance(restriction, RelationalOperand) or restriction
do_subquery = has_restriction and self.heading.computed
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
long_description=long_description,
author='Dimitri Yatsenko',
author_email='Dimitri.Yatsenko@gmail.com',
license = "GNU LGPL",
license="GNU LGPL",
url='https://github.com/datajoint/datajoint-python',
keywords='database organization',
packages=find_packages(exclude=['contrib', 'docs', 'tests*']),
Expand Down
14 changes: 7 additions & 7 deletions tests/test_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ def test_getitem(self):

def test_getitem_for_fetch1(self):
"""Testing Fetch1.__getitem__"""
assert_true( (self.subject & "subject_id=10").fetch1['subject_id'] == 10)
assert_true( (self.subject & "subject_id=10").fetch1['subject_id','species'] == (10, 'monkey'))
assert_true( (self.subject & "subject_id=10").fetch1['subject_id':'species'] == (10, 'Curious George'))

assert_true((self.subject & "subject_id=10").fetch1['subject_id'] == 10)
assert_equal((self.subject & "subject_id=10").fetch1['subject_id', 'species'],
(10, 'monkey'))
assert_equal((self.subject & "subject_id=10").fetch1['subject_id':'species'],
(10, 'Curious George'))

def test_order_by(self):
"""Tests order_by sorting order"""
Expand Down Expand Up @@ -113,14 +114,14 @@ def test_keys(self):
langs.sort(key=itemgetter(0), reverse=True)
langs.sort(key=itemgetter(1), reverse=False)

cur = self.lang.fetch.order_by('language', 'name DESC')['name','language']
cur = self.lang.fetch.order_by('language', 'name DESC')['name', 'language']
cur2 = list(self.lang.fetch.order_by('language', 'name DESC').keys())

for c, c2 in zip(zip(*cur), cur2):
assert_true(c == tuple(c2.values()), 'Values are not the same')

def test_fetch1(self):
key = {'name': 'Edgar', 'language':'Japanese'}
key = {'name': 'Edgar', 'language': 'Japanese'}
true = schema.Language.contents[-1]

dat = (self.lang & key).fetch1()
Expand Down Expand Up @@ -170,7 +171,6 @@ def test_offset(self):
for c, l in list(zip(cur, langs[1:]))[:4]:
assert_true(np.all([cc == ll for cc, ll in zip(c, l)]), 'Sorting order is different')


def test_limit_warning(self):
"""Tests whether warning is raised if offset is used without limit."""
with warnings.catch_warnings(record=True) as w:
Expand Down