diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 3223b1b47..2a0379c21 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -1,6 +1,5 @@ from .relational_operand import RelationalOperand from . import DataJointError -import pprint import abc import logging @@ -24,7 +23,7 @@ def pop_rel(self): pass @abc.abstractmethod - def make_tuples(self, key): + def _make_tuples(self, key): """ Derived classes must implement method make_tuples that fetches data from parent tables, restricting by the given key, computes dependent attributes, and inserts the new tuples into self. @@ -35,38 +34,32 @@ def make_tuples(self, key): def target(self): return self - def populate(self, catch_errors=False, reserve_jobs=False, restrict=None): + def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): """ - rel.populate() will call rel.make_tuples(key) for every primary key in self.pop_rel + rel.populate() calls rel._make_tuples(key) for every primary key in self.pop_rel for which there is not already a tuple in rel. """ + error_list = [] if suppress_errors else None if not isinstance(self.pop_rel, RelationalOperand): - raise DataJointError('') - self.conn._cancel_transaction() - - unpopulated = self.pop_rel - self.target - if not unpopulated.count: - logger.info('Nothing to populate', flush=True) - if catch_errors: - error_keys, errors = [], [] - for key in unpopulated.fetch(): - self.conn._start_transaction() - n = self(key).count - if n: # already populated + raise DataJointError('Invalid pop_rel value') + self.conn._cancel_transaction() # rollback previous transaction, if any + unpopulated = (self.pop_rel - self.target) & restriction + for key in unpopulated.project(): + self.conn._start_transaction() + if key in self.target: # already populated + self.conn._cancel_transaction() + else: + logger.info('Populating: ' + str(key)) + try: + self._make_tuples(key) + except Exception as error: self.conn._cancel_transaction() - else: - print('Populating:') - pprint.pprint(key) - try: - self.make_tuples(key) - except Exception as e: - self.conn._cancel_transaction() - if not catch_errors: - raise - print(e) - errors += [e] - error_keys += [key] + if not suppress_errors: + raise else: - self.conn._commit_transaction() - if catch_errors: - return errors, error_keys \ No newline at end of file + print(error) + error_list.append((key, error)) + else: + self.conn._commit_transaction() + logger.info('Done populating.', flush=True) + return error_list \ No newline at end of file diff --git a/datajoint/free_relation.py b/datajoint/free_relation.py index 922f30fdd..ff27236e8 100644 --- a/datajoint/free_relation.py +++ b/datajoint/free_relation.py @@ -86,13 +86,9 @@ def _field_to_sql(field): #TODO move this into Attribute Tuple default = 'NOT NULL' # if some default specified if field.default: - # enclose value in quotes (even numeric), except special SQL values - # or values already enclosed by the user - if field.default.upper() in mysql_constants or field.default[:1] in ["'", '"']: - default = '%s DEFAULT %s' % (default, field.default) - else: - default = '%s DEFAULT "%s"' % (default, field.default) - + # enclose value in quotes except special SQL values or already enclosed + quote = field.default.upper() not in mysql_constants and field.default[0] not in '"\'' + default += ' DEFAULT ' + ('"%s"' if quote else "%s") % field.default if any((c in r'\"' for c in field.comment)): raise DataJointError('Illegal characters in attribute comment "%s"' % field.comment) @@ -336,7 +332,7 @@ def _declare(self): if not defined_name == self.ref_name: raise DataJointError('Table name {} does not match the declared' - 'name {}'.format(expected_name, defined_name)) + 'name {}'.format(self.ref_name, defined_name)) # compile the CREATE TABLE statement # TODO: support prefix diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index 916f7cca7..396048d4d 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -88,9 +88,10 @@ def __iand__(self, restriction): """ in-place relational restriction or semijoin """ - if self._restrictions is None: - self._restrictions = [] - self._restrictions.append(restriction) + if restriction is not None: + if self._restrictions is None: + self._restrictions = [] + self._restrictions.append(restriction) return self def __and__(self, restriction): @@ -122,12 +123,23 @@ def make_select(self, attribute_spec=None): attribute_spec = self.heading.as_sql return 'SELECT ' + attribute_spec + ' FROM ' + self.from_clause + self.where_clause - @property - def count(self): + def __len__(self): + """ + number of tuples in the relation. This also takes care of the truth value + """ cur = self.conn.query(self.make_select('count(*)')) return cur.fetchone()[0] + def __contains__(self, item): + """ + "item in relation" is equivalient to "len(relation & item)>0" + """ + return len(self & item)>0 + def __call__(self, *args, **kwargs): + """ + calling a relation is equivalent to fetching from it + """ return self.fetch(*args, **kwargs) def fetch(self, offset=0, limit=None, order_by=None, descending=False, as_dict=False): @@ -181,9 +193,9 @@ def __repr__(self): repr_string += ' '.join(['+' + '-'*(width-2) + '+' for _ in columns]) + '\n' for tup in rel.fetch(limit=limit): repr_string += ' '.join([template % column for column in tup]) + '\n' - if self.count > limit: + if len(self) > limit: repr_string += '...\n' - repr_string += ' (%d tuples)\n' % self.count + repr_string += ' (%d tuples)\n' % len(self) return repr_string def __iter__(self): diff --git a/tests/test_relation.py b/tests/test_relation.py index 29c7cb8f9..830dea26e 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -73,7 +73,7 @@ def test_compound_restriction(self): tM = t & (s & "real_id = 'M'") t1 = t & "subject_id = 1" - assert_equal(tM.count, t1.count, "Results of compound request does not have same length") + assert_equal(len(tM), len(t1), "Results of compound request does not have same length") for t1_item, tM_item in zip(sorted(t1, key=lambda item: item['trial_id']), sorted(tM, key=lambda item: item['trial_id'])): diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index ae7dd31ca..3a25a87d0 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -31,7 +31,7 @@ def test_isub(self): def test_sub(self): pass - def test_count(self): + def test_len(self): pass def test_fetch(self):