diff --git a/datajoint/erd.py b/datajoint/erd.py index fefbc8031..14858d530 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -4,14 +4,11 @@ import logging from collections import defaultdict -import pyparsing as pp import networkx as nx from networkx import DiGraph from functools import cmp_to_key import operator -from collections import OrderedDict - # use pygraphviz if available try: from networkx import pygraphviz_layout @@ -20,37 +17,21 @@ import matplotlib.pyplot as plt from inspect import isabstract -from .base_relation import BaseRelation +from .user_relations import UserRelation, Part logger = logging.getLogger(__name__) -def get_concrete_descendants(cls): +def get_concrete_subclasses(cls): desc = [] child= cls.__subclasses__() for c in child: if not isabstract(c): desc.append(c) - desc.extend(get_concrete_descendants(c)) + desc.extend(get_concrete_subclasses(c)) return desc -def parse_base_relations(rels): - name_map = {} - for r in rels: - try: - name_map[r().full_table_name] = '{module}.{cls}'.format(module=r.__module__, cls=r.__name__) - except TypeError: - # skip if failed to instantiate BaseRelation derivative - pass - return name_map - - -def get_table_relation_name_map(): - rels = get_concrete_descendants(BaseRelation) - return parse_base_relations(rels) - - class ERD(DiGraph): """ A directed graph representing dependencies between Relations within and across @@ -65,15 +46,24 @@ def node_labels(self): """ :return: dictionary of key : label pairs for plotting """ - name_map = get_table_relation_name_map() + def full_class_name(user_class): + if issubclass(user_class, Part): + return '{module}.{master}.{cls}'.format( + module=user_class.__module__, + master=user_class.master.__name__, + cls=user_class.__name__) + else: + return '{module}.{cls}'.format( + module=user_class.__module__, + cls=user_class.__name__) + + name_map = {rel.full_table_name: full_class_name(rel) for rel in get_concrete_subclasses(UserRelation)} return {k: self.get_label(k, name_map) for k in self.nodes()} def get_label(self, node, name_map=None): label = self.node[node].get('label', '') if label.strip(): return label - - # it's not efficient to recreate name-map on every call! if name_map is not None and node in name_map: return name_map[node] # no other name exists, so just use full table now diff --git a/datajoint/schema.py b/datajoint/schema.py index 985e038b4..44be4c1a5 100644 --- a/datajoint/schema.py +++ b/datajoint/schema.py @@ -87,15 +87,18 @@ def process_relation_class(relation_class, context): process_relation_class(cls, context=self.context) - # Process subordinate relations - parts = list() - is_part = lambda x: inspect.isclass(x) and issubclass(x, Part) + # Process part relations + def is_part(x): + return inspect.isclass(x) and issubclass(x, Part) - for var, part in inspect.getmembers(cls, is_part): - parts.append(part) - part._master = cls - # TODO: look into local namespace for the subclasses - process_relation_class(part, context=dict(self.context, **{cls.__name__: cls})) + parts = list() + for part in dir(cls): + if part[0].isupper(): + part = getattr(cls, part) + if is_part(part): + parts.append(part) + part._master = cls + process_relation_class(part, context=dict(self.context, **{cls.__name__: cls})) # invoke Relation._prepare() on class and its part relations. cls()._prepare() diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 4f5875bac..9ffe75045 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -2,59 +2,83 @@ Hosts the table tiers, user relations should be derived from. """ -import abc from .base_relation import BaseRelation from .autopopulate import AutoPopulate from .utils import from_camel_case from . import DataJointError -class Part(BaseRelation, metaclass=abc.ABCMeta): +class classproperty: + + def __init__(self, f): + self.f = f + + def __get__(self, obj, owner): + return self.f(owner) + + +class UserRelation(BaseRelation): + """ + A subclass of UserRelation defines is a dedicated class interfacing a base relation. + UserRelation is initialized by the decorator generated by schema(). + """ + _connection = None + _context = None + _heading = None + + @classproperty + def connection(cls): + return cls._connection + + @classproperty + def full_table_name(cls): + return r"`{0:s}`.`{1:s}`".format(cls.database, cls.table_name) + + +class Part(UserRelation): """ Inherit from this class if the table's values are details of an entry in another relation and if this table is populated by this relation. For example, the entries inheriting from dj.Part could be single entries of a matrix, while the parent table refers to the entire matrix. Part relations are implemented as classes inside classes. """ + _master = None - @property - def master(self): - if not hasattr(self, '_master'): - raise DataJointError( - 'Part relations must be declared inside a base relation class') - return self._master + @classproperty + def master(cls): + return cls._master - @property - def table_name(self): - return self.master().table_name + '__' + from_camel_case(self.__class__.__name__) + @classproperty + def table_name(cls): + return cls.master.table_name + '__' + from_camel_case(cls.__name__) -class Manual(BaseRelation, metaclass=abc.ABCMeta): +class Manual(UserRelation): """ Inherit from this class if the table's values are entered manually. """ - @property - def table_name(self): + @classproperty + def table_name(cls): """ :returns: the table name of the table formatted for mysql. """ - return from_camel_case(self.__class__.__name__) + return from_camel_case(cls.__name__) -class Lookup(BaseRelation, metaclass=abc.ABCMeta): +class Lookup(UserRelation): """ Inherit from this class if the table's values are for lookup. This is currently equivalent to defining the table as Manual and serves semantic purposes only. """ - @property - def table_name(self): + @classproperty + def table_name(cls): """ :returns: the table name of the table formatted for mysql. """ - return '#' + from_camel_case(self.__class__.__name__) + return '#' + from_camel_case(cls.__name__) def _prepare(self): """ @@ -64,29 +88,29 @@ def _prepare(self): self.insert(self.contents, skip_duplicates=True) -class Imported(BaseRelation, AutoPopulate, metaclass=abc.ABCMeta): +class Imported(UserRelation, AutoPopulate): """ Inherit from this class if the table's values are imported from external data sources. The inherited class must at least provide the function `_make_tuples`. """ - @property - def table_name(self): + @classproperty + def table_name(cls): """ :returns: the table name of the table formatted for mysql. """ - return "_" + from_camel_case(self.__class__.__name__) + return "_" + from_camel_case(cls.__name__) -class Computed(BaseRelation, AutoPopulate, metaclass=abc.ABCMeta): +class Computed(UserRelation, AutoPopulate): """ Inherit from this class if the table's values are computed from other relations in the schema. The inherited class must at least provide the function `_make_tuples`. """ - @property - def table_name(self): + @classproperty + def table_name(cls): """ :returns: the table name of the table formatted for mysql. """ - return "__" + from_camel_case(self.__class__.__name__) + return "__" + from_camel_case(cls.__name__)