Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions datajoint/autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class AutoPopulate(metaclass=abc.ABCMeta):
must define the property populated_from, and must define the callback method _make_tuples.
"""
_jobs = None
_populated_from = None

@property
def populated_from(self):
Expand All @@ -28,13 +29,16 @@ def populated_from(self):
join of the parent relations. Users may override to change the granularity
or the scope of populate() calls.
"""
parents = [FreeRelation(self.target.connection, rel) for rel in self.target.parents]
if not parents:
raise DataJointError('A relation must have parent relations to be able to be populated')
ret = parents.pop(0)
while parents:
ret *= parents.pop(0)
return ret
if self._populated_from is None:
self.connection.dependencies.load()
parents = [FreeRelation(self.target.connection, rel) for rel in self.target.parents]
if not parents:
raise DataJointError('A relation must have parent relations to be able to be populated')
ret = parents.pop(0)
while parents:
ret *= parents.pop(0)
self._populated_from = ret
return self._populated_from

@abc.abstractmethod
def _make_tuples(self, key):
Expand Down Expand Up @@ -70,17 +74,16 @@ def populate(self, restriction=None, suppress_errors=False,
if order not in valid_order:
raise DataJointError('The order argument must be one of %s' % str(valid_order))

self.connection.dependencies.load()

if not isinstance(self.populated_from, RelationalOperand):
todo = self.populated_from
if not isinstance(todo, RelationalOperand):
raise DataJointError('Invalid populated_from value')
todo.restrict(restriction)

error_list = [] if suppress_errors else None

jobs = self.connection.jobs[self.target.database]
table_name = self.target.table_name
unpopulated = (self.populated_from & restriction) - self.target.project()
keys = unpopulated.fetch.keys()
keys = (todo - self.target.project()).fetch.keys()
if order == "reverse":
keys = list(keys).reverse()
elif order == "random":
Expand Down Expand Up @@ -113,14 +116,14 @@ def populate(self, restriction=None, suppress_errors=False,
jobs.complete(table_name, key)
return error_list


def progress(self, restriction=None, display=True):
"""
report progress of populating this table
:return: remaining, total -- tuples to be populated
"""
total = len(self.populated_from & restriction)
remaining = len((self.populated_from & restriction) - self.target.project())
todo = self.populated_from & restriction
total = len(todo)
remaining = len(todo - self.target.project())
if display:
print('%-20s' % self.__class__.__name__, flush=True, end=': ')
print('Completed %d of %d (%2.1f%%) %s' %
Expand Down
17 changes: 0 additions & 17 deletions datajoint/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,6 @@
from . import DataJointError
from functools import wraps

def load_dependencies(func):
"""
Decorator that ensures that dependencies are loaded
"""

@wraps(func)
def f(*args, **kwargs):
args[0].load()
return func(*args, **kwargs)
return f



class Dependencies:
"""
Expand All @@ -32,23 +20,18 @@ def __init__(self, conn):


@property
@load_dependencies
def parents(self):
return self._parents


@property
@load_dependencies
def children(self):
return self._children

@property
@load_dependencies
def references(self):
return self._references

@property
@load_dependencies
def referenced(self):
return self._referenced

Expand Down
10 changes: 3 additions & 7 deletions datajoint/erd.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,11 @@
pygraphviz_layout = None

import matplotlib.pyplot as plt
from . import DataJointError
from functools import wraps
from .utils import to_camel_case
from inspect import isabstract
from .base_relation import BaseRelation

logger = logging.getLogger(__name__)

from inspect import isabstract


def get_concrete_descendants(cls):
desc = []
Expand Down Expand Up @@ -153,7 +149,7 @@ def plot(self):
# draw non-primary key relations
nx.draw_networkx_edges(self, pos, self.non_pk_edges, style='dashed', arrows=False)
apos = np.array(list(pos.values()))
xmax = apos[:, 0].max() + 200 #TODO: use something more sensible then hard fixed number
xmax = apos[:, 0].max() + 200 # TODO: use something more sensible than hard fixed number
xmin = apos[:, 0].min() - 100
ax.set_xlim(xmin, xmax)
ax.axis('off') # hide axis
Expand Down Expand Up @@ -378,7 +374,7 @@ def remove_edges_in_path(self, path):
Removes all shared edges between this graph and the path
:param path: a list of nodes defining a path. All edges in this path will be removed from the graph if found
"""
if len(path) <= 1: # no path exists!
if len(path) <= 1: # no path exists!
return
for a, b in zip(path[:-1], path[1:]):
self.remove_edge(a, b)
Expand Down
1 change: 1 addition & 0 deletions datajoint/relational_operand.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ def restrict(self, *restrictions):
has_restriction = any(isinstance(r, RelationalOperand) or r for r in restrictions)
do_subquery = has_restriction and self.heading.computed
if do_subquery:
# TODO fix this for the case (r & key).aggregate(m='compute')
raise DataJointError('In-place restriction on renamed attributes is not allowed')
super().restrict(*restrictions)

Expand Down