diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 753e42e5b..8efdcb37b 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -19,6 +19,7 @@ class AutoPopulate(metaclass=abc.ABCMeta): must define the property populated_from, and must define the callback method _make_tuples. """ _jobs = None + _populated_from = None @property def populated_from(self): @@ -28,13 +29,16 @@ def populated_from(self): join of the parent relations. Users may override to change the granularity or the scope of populate() calls. """ - parents = [FreeRelation(self.target.connection, rel) for rel in self.target.parents] - if not parents: - raise DataJointError('A relation must have parent relations to be able to be populated') - ret = parents.pop(0) - while parents: - ret *= parents.pop(0) - return ret + if self._populated_from is None: + self.connection.dependencies.load() + parents = [FreeRelation(self.target.connection, rel) for rel in self.target.parents] + if not parents: + raise DataJointError('A relation must have parent relations to be able to be populated') + ret = parents.pop(0) + while parents: + ret *= parents.pop(0) + self._populated_from = ret + return self._populated_from @abc.abstractmethod def _make_tuples(self, key): @@ -70,17 +74,16 @@ def populate(self, restriction=None, suppress_errors=False, if order not in valid_order: raise DataJointError('The order argument must be one of %s' % str(valid_order)) - self.connection.dependencies.load() - - if not isinstance(self.populated_from, RelationalOperand): + todo = self.populated_from + if not isinstance(todo, RelationalOperand): raise DataJointError('Invalid populated_from value') + todo.restrict(restriction) error_list = [] if suppress_errors else None jobs = self.connection.jobs[self.target.database] table_name = self.target.table_name - unpopulated = (self.populated_from & restriction) - self.target.project() - keys = unpopulated.fetch.keys() + keys = (todo - self.target.project()).fetch.keys() if order == "reverse": keys = list(keys).reverse() elif order == "random": @@ -113,14 +116,14 @@ def populate(self, restriction=None, suppress_errors=False, jobs.complete(table_name, key) return error_list - def progress(self, restriction=None, display=True): """ report progress of populating this table :return: remaining, total -- tuples to be populated """ - total = len(self.populated_from & restriction) - remaining = len((self.populated_from & restriction) - self.target.project()) + todo = self.populated_from & restriction + total = len(todo) + remaining = len(todo - self.target.project()) if display: print('%-20s' % self.__class__.__name__, flush=True, end=': ') print('Completed %d of %d (%2.1f%%) %s' % diff --git a/datajoint/dependencies.py b/datajoint/dependencies.py index abc42c84d..722fd5244 100644 --- a/datajoint/dependencies.py +++ b/datajoint/dependencies.py @@ -3,18 +3,6 @@ from . import DataJointError from functools import wraps -def load_dependencies(func): - """ - Decorator that ensures that dependencies are loaded - """ - - @wraps(func) - def f(*args, **kwargs): - args[0].load() - return func(*args, **kwargs) - return f - - class Dependencies: """ @@ -32,23 +20,18 @@ def __init__(self, conn): @property - @load_dependencies def parents(self): return self._parents - @property - @load_dependencies def children(self): return self._children @property - @load_dependencies def references(self): return self._references @property - @load_dependencies def referenced(self): return self._referenced diff --git a/datajoint/erd.py b/datajoint/erd.py index 5655f08e3..fefbc8031 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -19,15 +19,11 @@ pygraphviz_layout = None import matplotlib.pyplot as plt -from . import DataJointError -from functools import wraps -from .utils import to_camel_case +from inspect import isabstract from .base_relation import BaseRelation logger = logging.getLogger(__name__) -from inspect import isabstract - def get_concrete_descendants(cls): desc = [] @@ -153,7 +149,7 @@ def plot(self): # draw non-primary key relations nx.draw_networkx_edges(self, pos, self.non_pk_edges, style='dashed', arrows=False) apos = np.array(list(pos.values())) - xmax = apos[:, 0].max() + 200 #TODO: use something more sensible then hard fixed number + xmax = apos[:, 0].max() + 200 # TODO: use something more sensible than hard fixed number xmin = apos[:, 0].min() - 100 ax.set_xlim(xmin, xmax) ax.axis('off') # hide axis @@ -378,7 +374,7 @@ def remove_edges_in_path(self, path): Removes all shared edges between this graph and the path :param path: a list of nodes defining a path. All edges in this path will be removed from the graph if found """ - if len(path) <= 1: # no path exists! + if len(path) <= 1: # no path exists! return for a, b in zip(path[:-1], path[1:]): self.remove_edge(a, b) diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index d2faade5a..698201015 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -436,6 +436,7 @@ def restrict(self, *restrictions): has_restriction = any(isinstance(r, RelationalOperand) or r for r in restrictions) do_subquery = has_restriction and self.heading.computed if do_subquery: + # TODO fix this for the case (r & key).aggregate(m='compute') raise DataJointError('In-place restriction on renamed attributes is not allowed') super().restrict(*restrictions)