diff --git a/Makefile b/Makefile index 77a432a28..137068e47 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ all: @echo 'MakeFile for DataJoint packaging ' @echo ' ' @echo 'make sdist Creates source distribution ' - @echo 'make wheel Creates Wheel dstribution ' + @echo 'make wheel Creates Wheel distribution ' @echo 'make pypi Package and upload to PyPI ' @echo 'make pypitest Package and upload to PyPI test server' @echo 'make purge Remove all build related directories ' diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 47ce89b30..f9ad15b43 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -15,7 +15,7 @@ __version__ = "0.2" __all__ = ['__author__', '__version__', 'config', 'conn', 'kill', - 'Connection', 'Heading', 'Relation', 'FreeRelation', 'Not', 'schema', + 'Connection', 'Heading', 'BaseRelation', 'FreeRelation', 'Not', 'schema', 'Manual', 'Lookup', 'Imported', 'Computed', 'Part'] @@ -52,9 +52,9 @@ class DataJointError(Exception): # ------------- flatten import hierarchy ------------------------- from .connection import conn, Connection -from .relation import Relation +from .base_relation import BaseRelation from .user_relations import Manual, Lookup, Imported, Computed, Part from .relational_operand import Not from .heading import Heading -from .schema import schema +from .schema import Schema as schema from .kill import kill diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 1cef4445e..753e42e5b 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -5,7 +5,7 @@ import random from .relational_operand import RelationalOperand from . 
import DataJointError -from .relation import FreeRelation +from .base_relation import FreeRelation # noinspection PyExceptionInherit,PyCallingNonCallable @@ -60,13 +60,9 @@ def populate(self, restriction=None, suppress_errors=False, :param restriction: restriction on rel.populated_from - target :param suppress_errors: suppresses error if true - :param reserve_jobs: currently not implemented - :param batch: batch size of a single job + :param reserve_jobs: if true, reserves job to populate in asynchronous fashion :param order: "original"|"reverse"|"random" - the order of execution """ - if not isinstance(self.populated_from, RelationalOperand): - raise DataJointError('Invalid populated_from value') - if self.connection.in_transaction: raise DataJointError('Populate cannot be called during a transaction.') @@ -74,6 +70,11 @@ def populate(self, restriction=None, suppress_errors=False, if order not in valid_order: raise DataJointError('The order argument must be one of %s' % str(valid_order)) + self.connection.dependencies.load() + + if not isinstance(self.populated_from, RelationalOperand): + raise DataJointError('Invalid populated_from value') + error_list = [] if suppress_errors else None jobs = self.connection.jobs[self.target.database] @@ -112,6 +113,7 @@ def populate(self, restriction=None, suppress_errors=False, jobs.complete(table_name, key) return error_list + def progress(self, restriction=None, display=True): """ report progress of populating this table diff --git a/datajoint/relation.py b/datajoint/base_relation.py similarity index 81% rename from datajoint/relation.py rename to datajoint/base_relation.py index 9a7bf4f7a..d9100a096 100644 --- a/datajoint/relation.py +++ b/datajoint/base_relation.py @@ -16,15 +16,16 @@ logger = logging.getLogger(__name__) -class Relation(RelationalOperand, metaclass=abc.ABCMeta): +class BaseRelation(RelationalOperand, metaclass=abc.ABCMeta): """ - Relation is an abstract class that represents a base relation, i.e. 
a table in the database. + BaseRelation is an abstract class that represents a base relation, i.e. a table in the database. To make it a concrete class, override the abstract properties specifying the connection, table name, database, context, and definition. A Relation implements insert and delete methods in addition to inherited relational operators. """ _heading = None _context = None + database = None # ---------- abstract properties ------------ # @property @@ -58,8 +59,10 @@ def heading(self): """ if self._heading is None: self._heading = Heading() # instance-level heading - if not self._heading: + if not self._heading: # heading is not initialized self.declare() + self._heading.init_from_database(self.connection, self.database, self.table_name) + return self._heading def declare(self): @@ -69,9 +72,6 @@ def declare(self): if not self.is_declared: self.connection.query( declare(self.full_table_name, self.definition, self._context)) - if self.is_declared: - self.connection.erm.load_dependencies(self.full_table_name) - self._heading.init_from_database(self.connection, self.database, self.table_name) @property def from_clause(self): @@ -87,13 +87,25 @@ def select_fields(self): """ return '*' - def erd(self, *args, **kwargs): + def erd(self, *args, fill=True, mode='updown', **kwargs): """ + :param mode: diffent methods of creating a graph pertaining to this relation. + Currently options includes the following: + * 'updown': Contains this relation and all other nodes that can be reached within specific + number of ups and downs in the graph. 
ups(=2) and downs(=2) are optional keyword arguments + * 'ancestors': Returs :return: the entity relationship diagram object of this relation """ erd = self.connection.erd() - nodes = erd.up_down_neighbors(self.full_table_name) - return erd.restrict_by_tables(nodes) + if mode == 'updown': + nodes = erd.up_down_neighbors(self.full_table_name, *args, **kwargs) + elif mode == 'ancestors': + nodes = erd.ancestors(self.full_table_name, *args, **kwargs) + elif mode == 'descendants': + nodes = erd.descendants(self.full_table_name, *args, **kwargs) + else: + raise DataJointError('Unsupported erd mode', 'Mode "%s" is currently not supported' % mode) + return erd.restrict_by_tables(nodes, fill=fill) # ------------- dependencies ---------- # @property @@ -101,28 +113,28 @@ def parents(self): """ :return: the parent relation of this relation """ - return self.connection.erm.parents[self.full_table_name] + return self.connection.dependencies.parents[self.full_table_name] @property def children(self): """ :return: the child relations of this relation """ - return self.connection.erm.children[self.full_table_name] + return self.connection.dependencies.children[self.full_table_name] @property def references(self): """ :return: list of tables that this tables refers to """ - return self.connection.erm.references[self.full_table_name] + return self.connection.dependencies.references[self.full_table_name] @property def referenced(self): """ :return: list of tables for which this table is referenced by """ - return self.connection.erm.referenced[self.full_table_name] + return self.connection.dependencies.referenced[self.full_table_name] @property def descendants(self): @@ -134,10 +146,13 @@ def descendants(self): :return: list of descendants """ relations = (FreeRelation(self.connection, table) - for table in self.connection.erm.get_descendants(self.full_table_name)) + for table in self.connection.dependencies.get_descendants(self.full_table_name)) return [relation for relation in 
relations if relation.is_declared] def _repr_helper(self): + """ + :return: String representation of this object + """ return "%s.%s()" % (self.__module__, self.__class__.__name__) # --------- SQL functionality --------- # @@ -162,7 +177,7 @@ def insert(self, rows, **kwargs): """ Insert a collection of rows. Additional keyword arguments are passed to insert1. - :param iter: Must be an iterator that generates a sequence of valid arguments for insert. + :param rows: An iterable where an element is a valid arguments for insert1. """ for row in rows: self.insert1(row, **kwargs) @@ -172,9 +187,9 @@ def insert1(self, tup, replace=False, ignore_errors=False, skip_duplicates=False Insert one data record or one Mapping (like a dict). :param tup: Data record, a Mapping (like a dict), or a list or tuple with ordered values. - :param replace=False: Replaces data tuple if True. + :param replace=False: If True, replaces the matching data tuple in the table if it exists. :param ignore_errors=False: If True, ignore errors: e.g. constraint violations. - :param skip_dublicates=False: If True, ignore duplicate inserts. + :param skip_dublicates=False: If True, silently skip duplicate inserts. Example:: @@ -185,14 +200,18 @@ def insert1(self, tup, replace=False, ignore_errors=False, skip_duplicates=False heading = self.heading def check_fields(fields): + """ + Validates that all items in `fields` are valid attributes in the heading + """ for field in fields: if field not in heading: raise KeyError(u'{0:s} is not in the attribute list'.format(field)) def make_attribute(name, value): """ - For a given attribute, return its value or value placeholder as a string to be included - in the query and the value, if any to be submitted for processing by mysql API. + For a given attribute `name` with `value, return its processed value or value placeholder + as a string to be included in the query and the value, if any to be submitted for + processing by mysql API. 
""" if heading[name].is_blob: value = pack(value) @@ -249,8 +268,10 @@ def make_attribute(name, value): def delete_quick(self): """ - Deletes the table without cascading and without user prompt. + Deletes the table without cascading and without user prompt. If this table has any dependent + table(s), this will fail. """ + #TODO: give a better exception message self.connection.query('DELETE FROM ' + self.from_clause + self.where_clause) def delete(self): @@ -258,6 +279,7 @@ def delete(self): Deletes the contents of the table and its dependent tables, recursively. User is prompted for confirmation if config['safemode'] is set to True. """ + self.connection.dependencies.load() # construct a list (OrderedDict) of relations to delete relations = OrderedDict((r.full_table_name, r) for r in self.descendants) @@ -281,7 +303,7 @@ def delete(self): for name, r in relations.items(): if restrictions[name]: # do not restrict by an empty list r.restrict([r.project() if isinstance(r, RelationalOperand) else r - for r in restrictions[name]]) # project + for r in restrictions[name]]) # project # execute do_delete = False # indicate if there is anything to delete @@ -308,19 +330,21 @@ def delete(self): def drop_quick(self): """ Drops the table associated with this relation without cascading and without user prompt. + If the table has any dependent table(s), this call will fail with an error. """ + #TODO: give a better exception message if self.is_declared: self.connection.query('DROP TABLE %s' % self.full_table_name) - self.connection.erm.clear_dependencies(self.full_table_name) - if self._heading: - self._heading.reset() logger.info("Dropped table %s" % self.full_table_name) + else: + logger.info("Nothing to drop: table %s is not declared" % self.full_table_name) def drop(self): """ Drop the table and all tables that reference it, recursively. User is prompted for confirmation if config['safemode'] is set to True. 
""" + self.connection.dependencies.load() do_drop = True relations = self.descendants if config['safemode']: @@ -351,9 +375,10 @@ def _prepare(self): pass -class FreeRelation(Relation): +class FreeRelation(BaseRelation): """ - A base relation without a dedicated class. The table name is explicitly set. + A base relation without a dedicated class. Each instance is associated with a table + specified by full_table_name. """ def __init__(self, connection, full_table_name, definition=None, context=None): self.database, self._table_name = (s.strip('`') for s in full_table_name.split('.')) diff --git a/datajoint/connection.py b/datajoint/connection.py index 58eb65dd6..02a6d9795 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -8,7 +8,8 @@ import logging from . import config from . import DataJointError -from .erd import ERM +from datajoint.erd import ERD +from .dependencies import Dependencies from .jobs import JobManager logger = logging.getLogger(__name__) @@ -67,7 +68,8 @@ def __init__(self, host, user, passwd, init_fun=None): self._conn.autocommit(True) self._in_transaction = False self.jobs = JobManager(self) - self.erm = ERM(self) + self.schemas = dict() + self.dependencies = Dependencies(self) def __del__(self): logger.info('Disconnecting {user}@{host}:{port}'.format(**self.conn_info)) @@ -81,8 +83,8 @@ def __repr__(self): return "DataJoint connection ({connected}) {user}@{host}:{port}".format( connected=connected, **self.conn_info) - def erd(self, *args, **kwargs): - return self.erm.copy_graph(*args, **kwargs) + def register(self, schema): + self.schemas[schema.database] = schema @property def is_connected(self): @@ -91,6 +93,10 @@ def is_connected(self): """ return self._conn.ping() + def erd(self): + self.dependencies.load() + return ERD.create_from_dependencies(self.dependencies) + def query(self, query, args=(), as_dict=False): """ Execute the specified query and return the tuple generator (cursor). 
diff --git a/datajoint/declare.py b/datajoint/declare.py index 4f2cd2463..e853cae40 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -38,8 +38,21 @@ def declare(full_table_name, definition, context): elif line.startswith('---') or line.startswith('___'): in_key = False # start parsing dependent attributes elif line.startswith('->'): - # foreign key - ref = eval(line[2:], context)() # TODO: surround this with try...except... to give a better error message + # foreign + # TODO: clean up import order + from .base_relation import BaseRelation + # TODO: break this step into finer steps, checking the type of reference before calling it + try: + ref = eval(line[2:], context)() + except NameError: + raise DataJointError('Foreign key reference %s could not be resolved' % line[2:]) + except TypeError: + raise DataJointError('Foreign key reference %s could not be instantiated.' + 'Make sure %s is a valid BaseRelation subclass' % line[2:]) + # TODO: consider the case where line[2:] is a function that returns an instance of BaseRelation + if not isinstance(ref, BaseRelation): + raise DataJointError('Foreign key reference %s must be a subclass of BaseRelation' % line[2:]) + foreign_key_sql.append( 'FOREIGN KEY ({primary_key})' ' REFERENCES {ref} ({primary_key})' diff --git a/datajoint/dependencies.py b/datajoint/dependencies.py new file mode 100644 index 000000000..3ca943a35 --- /dev/null +++ b/datajoint/dependencies.py @@ -0,0 +1,127 @@ +from collections import defaultdict +import pyparsing as pp +from . 
import DataJointError + + +class Dependencies: + """ + Lookup for dependencies between tables + """ + __primary_key_parser = (pp.CaselessLiteral('PRIMARY KEY') + + pp.QuotedString('(', endQuoteChar=')').setResultsName('primary_key')) + + def __init__(self, conn): + self._conn = conn + self._parents = dict() + self._referenced = dict() + self._children = defaultdict(list) + self._references = defaultdict(list) + + @property + def parents(self): + return self._parents + + + @property + def children(self): + return self._children + + @property + def references(self): + return self._references + + @property + def referenced(self): + return self._referenced + + def get_descendants(self, full_table_name): + """ + :param full_table_name: a table name in the format `database`.`table_name` + :return: list of all children and references, in order of dependence. + This is helpful for cascading delete or drop operations. + """ + ret = defaultdict(lambda: 0) + + def recurse(name, level): + ret[name] = max(ret[name], level) + for child in self.children[name] + self.references[name]: + recurse(child, level+1) + + recurse(full_table_name, 0) + return sorted(ret.keys(), key=ret.__getitem__) + + @staticmethod + def __foreign_key_parser(database): + def add_database(string, loc, toc): + return ['`{database}`.`{table}`'.format(database=database, table=toc[0])] + + return (pp.CaselessLiteral('CONSTRAINT').suppress() + + pp.QuotedString('`').suppress() + + pp.CaselessLiteral('FOREIGN KEY').suppress() + + pp.QuotedString('(', endQuoteChar=')').setResultsName('attributes') + + pp.CaselessLiteral('REFERENCES') + + pp.Or([ + pp.QuotedString('`').setParseAction(add_database), + pp.Combine(pp.QuotedString('`', unquoteResults=False) + '.' 
+ + pp.QuotedString('`', unquoteResults=False))]).setResultsName('referenced_table') + + pp.QuotedString('(', endQuoteChar=')').setResultsName('referenced_attributes')) + + def load(self): + """ + Load dependencies for all tables that have not yet been loaded in all registered schemas + """ + for database in self._conn.schemas: + cur = self._conn.query('SHOW TABLES FROM `{database}`'.format(database=database)) + fk_parser = self.__foreign_key_parser(database) + # list tables that have not yet been loaded, exclude service tables that starts with ~ + tables = ('`{database}`.`{table_name}`'.format(database=database, table_name=info[0]) + for info in cur if not info[0].startswith('~')) + tables = (name for name in tables if name not in self._parents) + for name in tables: + self._parents[name] = list() + self._referenced[name] = list() + # fetch the CREATE TABLE statement + create_statement = self._conn.query('SHOW CREATE TABLE %s' % name).fetchone() + create_statement = create_statement[1].split('\n') + primary_key = None + for line in create_statement: + if primary_key is None: + try: + result = self.__primary_key_parser.parseString(line) + except pp.ParseException: + pass + else: + primary_key = [s.strip(' `') for s in result.primary_key.split(',')] + try: + result = fk_parser.parseString(line) + except pp.ParseException: + pass + else: + if not primary_key: + raise DataJointError('No primary key found %s' % name) + if result.referenced_attributes != result.attributes: + raise DataJointError( + "%s's foreign key refers to differently named attributes in %s" + % (self.__class__.__name__, result.referenced_table)) + if all(q in primary_key for q in [s.strip('` ') for s in result.attributes.split(',')]): + self._parents[name].append(result.referenced_table) + self._children[result.referenced_table].append(name) + else: + self._referenced[name].append(result.referenced_table) + self._references[result.referenced_table].append(name) + + def get_descendants(self, 
full_table_name): + """ + :param full_table_name: a table name in the format `database`.`table_name` + :return: list of all children and references, in order of dependence. + This is helpful for cascading delete or drop operations. + """ + ret = defaultdict(lambda: 0) + + def recurse(name, level): + ret[name] = max(ret[name], level) + for child in self.children[name] + self.references[name]: + recurse(child, level+1) + + recurse(full_table_name, 0) + return sorted(ret.keys(), key=ret.__getitem__) diff --git a/datajoint/erd.py b/datajoint/erd.py index aae2f4055..5655f08e3 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -7,26 +7,88 @@ import pyparsing as pp import networkx as nx from networkx import DiGraph -from networkx import pygraphviz_layout +from functools import cmp_to_key +import operator + +from collections import OrderedDict + +# use pygraphviz if available +try: + from networkx import pygraphviz_layout +except: + pygraphviz_layout = None + import matplotlib.pyplot as plt from . import DataJointError +from functools import wraps from .utils import to_camel_case +from .base_relation import BaseRelation logger = logging.getLogger(__name__) +from inspect import isabstract + + +def get_concrete_descendants(cls): + desc = [] + child= cls.__subclasses__() + for c in child: + if not isabstract(c): + desc.append(c) + desc.extend(get_concrete_descendants(c)) + return desc + + +def parse_base_relations(rels): + name_map = {} + for r in rels: + try: + name_map[r().full_table_name] = '{module}.{cls}'.format(module=r.__module__, cls=r.__name__) + except TypeError: + # skip if failed to instantiate BaseRelation derivative + pass + return name_map -class RelGraph(DiGraph): + +def get_table_relation_name_map(): + rels = get_concrete_descendants(BaseRelation) + return parse_base_relations(rels) + + +class ERD(DiGraph): """ - A directed graph representing relations between tables within and across - multiple databases found. 
+ A directed graph representing dependencies between Relations within and across + multiple databases. """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + @property def node_labels(self): """ :return: dictionary of key : label pairs for plotting """ - return {k: attr['label'] for k, attr in self.node.items()} + name_map = get_table_relation_name_map() + return {k: self.get_label(k, name_map) for k in self.nodes()} + + def get_label(self, node, name_map=None): + label = self.node[node].get('label', '') + if label.strip(): + return label + + # it's not efficient to recreate name-map on every call! + if name_map is not None and node in name_map: + return name_map[node] + # no other name exists, so just use full table now + return node + + @property + def lone_nodes(self): + """ + :return: list of nodes that are not connected to any other node + """ + return list(x for x in self.root_nodes if len(self.out_edges(x)) == 0) @property def pk_edges(self): @@ -47,7 +109,7 @@ def non_pk_edges(self): def highlight(self, nodes): """ Highlights specified nodes when plotting - :param nodes: list of nodes to be highlighted + :param nodes: list of nodes, specified by full table names, to be highlighted """ for node in nodes: self.node[node]['highlight'] = True @@ -56,7 +118,7 @@ def remove_highlight(self, nodes=None): """ Remove highlights from specified nodes when plotting. If specified node is not highlighted to begin with, nothing happens. 
- :param nodes: list of nodes to remove highlights from + :param nodes: list of nodes, specified by full table names, to remove highlights from """ if not nodes: nodes = self.nodes_iter() @@ -72,6 +134,10 @@ def plot(self): if not self.nodes(): # There is nothing to plot logger.warning('Nothing to plot') return + if pygraphviz_layout is None: + logger.warning('Failed to load Pygraphviz - plotting not supported at this time') + return + pos = pygraphviz_layout(self, prog='dot') fig = plt.figure(figsize=[10, 7]) ax = fig.add_subplot(111) @@ -93,17 +159,16 @@ def plot(self): ax.axis('off') # hide axis def __repr__(self): - #TODO: provide string version of ERD - return "RelGraph: to be implemented" + return self.repr_path() - def restrict_by_modules(self, modules, fill=False): + def restrict_by_databases(self, databases, fill=False): """ - Creates a subgraph containing only tables in the specified modules. - :param modules: list of module names - :param fill: set True to automatically include nodes connecting two nodes in the specified modules + Creates a subgraph containing only tables in the specified database. + :param databases: list of database names + :param fill: if True, automatically include nodes connecting two nodes in the specified modules :return: a subgraph with specified nodes """ - nodes = [n for n in self.nodes() if self.node[n].get('mod') in modules] + nodes = [n for n in self.nodes() if n.split('.')[0].strip('`') in databases] if fill: nodes = self.fill_connection_nodes(nodes) return self.subgraph(nodes) @@ -111,12 +176,12 @@ def restrict_by_modules(self, modules, fill=False): def restrict_by_tables(self, tables, fill=False): """ Creates a subgraph containing only specified tables. - :param tables: list of tables to keep in the subgraph + :param tables: list of tables to keep in the subgraph. 
Tables are specified using full table names :param fill: set True to automatically include nodes connecting two nodes in the specified list of tables :return: a subgraph with specified nodes """ - nodes = [n for n in self.nodes() if self.node[n].get('label') in tables] + nodes = [n for n in self.nodes() if n in tables] if fill: nodes = self.fill_connection_nodes(nodes) return self.subgraph(nodes) @@ -125,7 +190,7 @@ def restrict_by_tables_in_module(self, module, tables, fill=False): nodes = [n for n in self.nodes() if self.node[n].get('mod') in module and self.node[n].get('cls') in tables] if fill: - nodes = self.fill_connection_nodes(nodes) + nodes = self.fill_connection_nodes(nodes) return self.subgraph(nodes) def fill_connection_nodes(self, nodes): @@ -137,54 +202,69 @@ def fill_connection_nodes(self, nodes): graph = self.subgraph(self.ancestors_of_all(nodes)) return graph.descendants_of_all(nodes) - def ancestors_of_all(self, nodes): + def ancestors_of_all(self, nodes, n=-1): """ Find and return a set of all ancestors of the given nodes. The set will also contain the specified nodes. :param nodes: list of nodes for which ancestors are to be found + :param n: maximum number of generations to go up for each node. + If set to a negative number, will return all ancestors. :return: a set containing passed in nodes and all of their ancestors """ s = set() - for n in nodes: - s.update(self.ancestors(n)) + for node in nodes: + s.update(self.ancestors(node, n)) return s - def descendants_of_all(self, nodes): + def descendants_of_all(self, nodes, n=-1): """ Find and return a set including all descendants of the given nodes. The set will also contain the given nodes as well. :param nodes: list of nodes for which descendants are to be found + :param n: maximum number of generations to go down for each node. + If set to a negative number, will return all descendants. 
:return: a set containing passed in nodes and all of their descendants """ s = set() - for n in nodes: - s.update(self.descendants(n)) + for node in nodes: + s.update(self.descendants(node, n)) return s - def ancestors(self, node): + def copy_graph(self, *args, **kwargs): + return self.__class__(self, *args, **kwargs) + + def ancestors(self, node, n=-1): """ Find and return a set containing all ancestors of the specified node. For convenience in plotting, this set will also include the specified node as well (may change in future). :param node: node for which all ancestors are to be discovered + :param n: maximum number of generations to go up. If set to a negative number, + will return all ancestors. :return: a set containing the node and all of its ancestors """ s = {node} + if n == 0: + return s for p in self.predecessors_iter(node): - s.update(self.ancestors(p)) + s.update(self.ancestors(p, n-1)) return s - def descendants(self, node): + def descendants(self, node, n=-1): """ Find and return a set containing all descendants of the specified node. For convenience in plotting, this set will also include the specified node as well (may change in future). :param node: node for which all descendants are to be discovered + :param n: maximum number of generations to go down. 
If set to a negative number, + will return all descendants :return: a set containing the node and all of its descendants """ s = {node} + if n == 0: + return s for c in self.successors_iter(node): - s.update(self.descendants(c)) + s.update(self.descendants(c, n-1)) return s def up_down_neighbors(self, node, ups=2, downs=2, _prev=None): @@ -245,157 +325,163 @@ def n_neighbors(self, node, n, directed=False, prev=None): s.update(self.n_neighbors(x, n-1, prev)) return s + @property + def root_nodes(self): + return {node for node in self.nodes() if len(self.predecessors(node)) == 0} -class ERM(RelGraph): - """ - Entity Relation Map - - Represents known relation between tables - """ - # _checked_dependencies = set() - - def __init__(self, conn, *args, **kwargs): - super().__init__(*args, **kwargs) - self._conn = conn - self._parents = dict() - self._referenced = dict() - self._children = defaultdict(list) - self._references = defaultdict(list) - if conn.is_connected: - self._conn = conn - else: - raise DataJointError('The connection is broken') #TODO: make better exception message - self.update_graph() - - def update_graph(self, reload=False): - self.clear() - - # create primary key foreign connections - for table, parents in self._parents.items(): - mod, cls = (x.strip('`') for x in table.split('.')) - self.add_node(table, label=table, mod=mod, cls=cls) - for parent in parents: - self.add_edge(parent, table, rel='parent') - - # create non primary key foreign connections - for table, referenced in self._referenced.items(): - for ref in referenced: - self.add_edge(ref, table, rel='referenced') + @property + def leaf_nodes(self): + return {node for node in self.nodes() if len(self.successors(node)) == 0} - def copy_graph(self, *args, **kwargs): + def nodes_by_depth(self): """ - Return copy of the graph represented by this object at the - time of call. Note that the returned graph is no longer - bound to a connection. 
+ Return all nodes, ordered by their depth in the hierarchy + :returns: list of nodes, ordered by depth from shallowest to deepest """ - return RelGraph(self, *args, **kwargs) + ret = defaultdict(lambda: 0) + roots = self.root_nodes - def subgraph(self, *args, **kwargs): - return RelGraph(self).subgraph(*args, **kwargs) + def recurse(node, depth): + if depth > ret[node]: + ret[node] = depth + for child in self.successors_iter(node): + recurse(child, depth+1) - def load_dependencies(self, full_table_name): - # check if already loaded. Use clear_dependencies before reloading - if full_table_name in self._parents: - return - self._parents[full_table_name] = list() - self._referenced[full_table_name] = list() - - # fetch the CREATE TABLE statement - cur = self._conn.query('SHOW CREATE TABLE %s' % full_table_name) - create_statement = cur.fetchone() - if not create_statement: - raise DataJointError('Could not load the definition for %s' % full_table_name) - create_statement = create_statement[1].split('\n') - - # build foreign key fk_parser - database = full_table_name.split('.')[0].strip('`') - add_database = lambda string, loc, toc: ['`{database}`.`{table}`'.format(database=database, table=toc[0])] - - # primary key parser - pk_parser = pp.CaselessLiteral('PRIMARY KEY') - pk_parser += pp.QuotedString('(', endQuoteChar=')').setResultsName('primary_key') - - # foreign key parser - fk_parser = pp.CaselessLiteral('CONSTRAINT').suppress() - fk_parser += pp.QuotedString('`').suppress() - fk_parser += pp.CaselessLiteral('FOREIGN KEY').suppress() - fk_parser += pp.QuotedString('(', endQuoteChar=')').setResultsName('attributes') - fk_parser += pp.CaselessLiteral('REFERENCES') - fk_parser += pp.Or([ - pp.QuotedString('`').setParseAction(add_database), - pp.Combine(pp.QuotedString('`', unquoteResults=False) + - '.' 
+ pp.QuotedString('`', unquoteResults=False)) - ]).setResultsName('referenced_table') - fk_parser += pp.QuotedString('(', endQuoteChar=')').setResultsName('referenced_attributes') - - # parse foreign keys - primary_key = None - for line in create_statement: - if primary_key is None: - try: - result = pk_parser.parseString(line) - except pp.ParseException: - pass - else: - primary_key = [s.strip(' `') for s in result.primary_key.split(',')] - try: - result = fk_parser.parseString(line) - except pp.ParseException: - pass - else: - if not primary_key: - raise DataJointError('No primary key found %s' % full_table_name) - if result.referenced_attributes != result.attributes: - raise DataJointError( - "%s's foreign key refers to differently named attributes in %s" - % (self.__class__.__name__, result.referenced_table)) - if all(q in primary_key for q in [s.strip('` ') for s in result.attributes.split(',')]): - self._parents[full_table_name].append(result.referenced_table) - self._children[result.referenced_table].append(full_table_name) - else: - self._referenced[full_table_name].append(result.referenced_table) - self._references[result.referenced_table].append(full_table_name) - - self.update_graph() - - def clear_dependencies(self, full_table_name): - for ref in self._parents.pop(full_table_name, []): - if full_table_name in self._children[ref]: - self._children[ref].remove(full_table_name) - for ref in self._referenced.pop(full_table_name, []): - if full_table_name in self._references[ref]: - self._references[ref].remove(full_table_name) + for root in roots: + recurse(root, 0) - @property - def parents(self): - return self._parents + return sorted(ret.items(), key=operator.itemgetter(1)) - @property - def children(self): - return self._children + def get_longest_path(self): + """ + :returns: a list of graph nodes defining th longest path in the graph + """ + # no path exists if there is not an edge! 
+ if not self.edges(): + return [] - @property - def references(self): - return self._references + node_depth_list = self.nodes_by_depth() + node_depth_lookup = dict(node_depth_list) + path = [] - @property - def referenced(self): - return self._referenced + leaf = node_depth_list[-1][0] + predecessors = [leaf] + while predecessors: + leaf = sorted(predecessors, key=node_depth_lookup.get)[-1] + path.insert(0, leaf) + predecessors = self.predecessors(leaf) - def get_descendants(self, full_table_name): + return path + + def remove_edges_in_path(self, path): """ - :param full_table_name: a table name in the format `database`.`table_name` - :return: list of all children and references, in order of dependence. - This is helpful for cascading delete or drop operations. + Removes all shared edges between this graph and the path + :param path: a list of nodes defining a path. All edges in this path will be removed from the graph if found """ - ret = defaultdict(lambda: 0) + if len(path) <= 1: # no path exists! + return + for a, b in zip(path[:-1], path[1:]): + self.remove_edge(a, b) + + def longest_paths(self): + """ + :return: list of paths from longest to shortest. A path is a list of nodes. 
+ """ + g = self.copy_graph() + paths = [] + path = g.get_longest_path() + while path: + paths.append(path) + g.remove_edges_in_path(path) + path = g.get_longest_path() + return paths + + def repr_path(self): + """ + Construct string representation of the erm, summarizing dependencies between + tables + :return: string representation of the erm + """ + if len(self) == 0: + return "No relations to show" + + paths = self.longest_paths() + + # turn comparator into Key object for use in sort + k = cmp_to_key(self.compare_path) + sorted_paths = sorted(paths, key=k) + + # table name will be padded to match the longest table name + node_labels = self.node_labels + n = max([len(x) for x in node_labels.values()]) + 1 + rep = '' + for path in sorted_paths: + rep += self.repr_path_with_depth(path, n) + + for node in self.lone_nodes: + rep += node_labels[node] + '\n' + + return rep + + def compare_path(self, path1, path2): + """ + Comparator between two paths: path1 and path2 based on a combination of rules. 
+ Path 1 is greater than path2 if: + 1) i^th node in path1 is at greater depth than the i^th node in path2 OR + 2) if i^th nodes are at the same depth, i^th node in path 1 is alphabetically less than i^th node + in path 2 + 3) if neither of the above statements is true even if path1 and path2 are switched, proceed to i+1^th node + If path2 is a subpath starting at node 1, then path1 is greater than path2 + :param path1: path 1 of 2 to be compared + :param path2: path 2 of 2 to be compared + :return: return 1 if path1 is greater than path2, -1 if path1 is less than path2, and 0 if they are identical + """ + node_depth_lookup = dict(self.nodes_by_depth()) + for node1, node2 in zip(path1, path2): + if node_depth_lookup[node1] != node_depth_lookup[node2]: + return -1 if node_depth_lookup[node1] < node_depth_lookup[node2] else 1 + if node1 != node2: + return -1 if node1 < node2 else 1 + if len(node1) != len(node2): + return -1 if len(node1) < len(node2) else 1 + return 0 + + def repr_path_with_depth(self, path, n=20, m=2): + node_depth_lookup = dict(self.nodes_by_depth()) + node_labels = self.node_labels + space = '-' * n + rep = '' + prev_depth = 0 + first = True + for (i, node) in enumerate(path): + depth = node_depth_lookup[node] + label = node_labels[node] + if first: + rep += (' '*(n+m))*(depth-prev_depth) + else: + rep += space.join(['-'*m]*(depth-prev_depth))[:-1] + '>' + first = False + prev_depth = depth + if i == len(path)-1: + rep += label + else: + rep += label.ljust(n, '-') + rep += '\n' + return rep + + @classmethod + def create_from_dependencies(cls, dependencies, *args, **kwargs): + obj = cls(*args, **kwargs) - def recurse(full_table_name, level): - if level > ret[full_table_name]: - ret[full_table_name] = level - for child in self.children[full_table_name] + self.references[full_table_name]: - recurse(child, level+1) + for full_table, parents in dependencies.parents.items(): + database, table = (x.strip('`') for x in full_table.split('.')) + 
obj.add_node(full_table, database=database, table=table) + for parent in parents: + obj.add_edge(parent, full_table, rel='parent') - recurse(full_table_name, 0) - return sorted(ret.keys(), key=ret.__getitem__) + # create non primary key foreign connections + for full_table, referenced in dependencies.referenced.items(): + for ref in referenced: + obj.add_edge(ref, full_table, rel='referenced') + return obj diff --git a/datajoint/fetch.py b/datajoint/fetch.py index 395885528..61a911ba5 100644 --- a/datajoint/fetch.py +++ b/datajoint/fetch.py @@ -188,10 +188,7 @@ def keys(self, **kwargs): """ Iterator that returns primary keys. """ - b = dict(self.behavior, **kwargs) - if 'as_dict' not in kwargs: - b['as_dict'] = True - yield from self._relation.project().fetch.set_behavior(**b) + yield from self._relation.project().fetch.set_behavior(**dict(self.behavior, as_dict=True, **kwargs)) def __getitem__(self, item): """ @@ -216,8 +213,7 @@ def __getitem__(self, item): result, 0, result.strides) if attribute is PRIMARY_KEY else result[attribute] - for attribute in item - ] + for attribute in item] return return_values[0] if single_output else return_values def __repr__(self): diff --git a/datajoint/jobs.py b/datajoint/jobs.py index db41e0eac..eaa631d77 100644 --- a/datajoint/jobs.py +++ b/datajoint/jobs.py @@ -1,7 +1,7 @@ import hashlib import os import pymysql -from .relation import Relation +from .base_relation import BaseRelation def key_hash(key): @@ -14,7 +14,7 @@ def key_hash(key): return hashed.hexdigest() -class JobRelation(Relation): +class JobRelation(BaseRelation): """ A base relation with no definition. Allows reserving jobs """ @@ -56,14 +56,12 @@ def reserve(self, table_name, key): :param key: the dict of the job's primary key :return: True if reserved job successfully. 
False = the jobs is already taken """ - job_key = dict(table_name=table_name, key_hash=key_hash(key)) + try: - self.insert1( - dict(job_key, - status="reserved", - host=os.uname().nodename, - pid=os.getpid())) - except pymysql.err.IntegrityError: + job_key = dict(table_name=table_name, key_hash=key_hash(key), + status='reserved', host=os.uname().nodename, pid=os.getpid()) + self.insert1(job_key) + except pymysql.err.IntegrityError: #TODO check for other exceptions! return False else: return True diff --git a/datajoint/kill.py b/datajoint/kill.py index 48b9cbd54..8496f99ab 100644 --- a/datajoint/kill.py +++ b/datajoint/kill.py @@ -26,7 +26,7 @@ def kill(restriction=None, connection=None): while True: print(' ID USER STATE TIME INFO') print('+--+ +----------+ +-----------+ +--+') - for process in connection.query(query, as_dict=True).fetchall(): + for process in connection.query(query, as_dict=True).fetchall(): try: print('{ID:>4d} {USER:<12s} {STATE:<12s} {TIME:>5d} {INFO}'.format(**process)) except TypeError as err: @@ -40,6 +40,7 @@ def kill(restriction=None, connection=None): pid = int(response) except ValueError: pass # ignore non-numeric input + #TODO: check behavior when invalid input given else: try: connection.query('kill %d' % pid) diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index c7bc94ec0..3ee4e3bf4 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence import numpy as np import abc import re @@ -11,6 +11,82 @@ logger = logging.getLogger(__name__) +class AndList(Sequence): + """ + A list of restrictions to by applied to a relation. The restrictions are ANDed. + Each restriction can be a list or set or a relation whose elements are ORed. 
+ But the elements that are lists can contain + """ + def __init__(self, heading): + self.heading = heading + self._list = [] + + def __len__(self): + return len(self._list) + + def __getitem__(self, i): + return self._list[i] + + def add(self, *args): + # remove Nones and duplicates + args = [r for r in args if r is not None and r not in self] + if args: + if any(is_empty_set(r) for r in args): + # if any condition is an empty list, return FALSE + self._list = ['FALSE'] + else: + self._list.extend(args) + + def where_clause(self): + """ + convert to a WHERE clause string + """ + def make_condition(arg, _negate=False): + if isinstance(arg, (str, AndList)): + return str(arg), _negate + + # semijoin or antijoin + if isinstance(arg, RelationalOperand): + common_attributes = [q for q in self.heading.names if q in arg.heading.names] + if not common_attributes: + condition = 'FALSE' if negate else 'TRUE' + else: + common_attributes = '`'+'`,`'.join(common_attributes)+'`' + condition = '({fields}) {not_}in ({subquery})'.format( + fields=common_attributes, + not_="not " if negate else "", + subquery=arg.make_select(common_attributes)) + return condition, False # negate is cleared + + # mappings are turned into ANDed equality conditions + if isinstance(arg, Mapping): + condition = ['`%s`=%s' % (k, repr(v)) for k, v in arg.items() if k in self.heading] + elif isinstance(arg, np.void): + # element of a record array + condition = ['`%s`=%s' % (k, arg[k]) for k in arg.dtype.fields if k in self.heading] + else: + raise DataJointError('invalid restriction type') + return ' AND '.join(condition) if condition else 'TRUE', _negate + + if not self: + return '' + + conditions = [] + for item in self: + negate = isinstance(item, Not) + if negate: + item = item.restriction + if isinstance(item, (list, tuple, set, np.ndarray)): + # sets of conditions are ORed + item = '(' + ') OR ('.join([make_condition(q)[0] for q in item]) + ')' + else: + item, negate = make_condition(item, negate) + 
if not item: + raise DataJointError('Empty condition') + conditions.append(('NOT (%s)' if negate else '(%s)') % item) + return ' WHERE ' + ' AND '.join(conditions) + + class RelationalOperand(metaclass=abc.ABCMeta): """ RelationalOperand implements relational algebra and fetch methods. @@ -25,12 +101,21 @@ class RelationalOperand(metaclass=abc.ABCMeta): @property def restrictions(self): - return [] if self._restrictions is None else self._restrictions + if self._restrictions is None: + self._restrictions = AndList(self.heading) + return self._restrictions + + def clear_restrictions(self): + self._restrictions = None @property def primary_key(self): return self.heading.primary_key + @property + def where_clause(self): + return self.restrictions.where_clause() + # --------- abstract properties ----------- @property @@ -107,13 +192,20 @@ def aggregate(self, group, *attributes, **renamed_attributes): Join(self, group, left=True), *attributes, **renamed_attributes) + def __iand__(self, restriction): + """ + in-place restriction by a single condition + """ + self.restrict(restriction) + def __and__(self, restriction): """ relational restriction or semijoin :return: a restricted copy of the argument """ ret = copy(self) - ret.restrict(restriction, *ret.restrictions) + ret.clear_restrictions() + ret.restrict(restriction, *list(self.restrictions)) return ret def restrict(self, *restrictions): @@ -124,25 +216,15 @@ def restrict(self, *restrictions): However, each member of restrictions can be a list of conditions, which are combined with OR. :param restrictions: list of restrictions. 
""" - # remove Nones and duplicates - restrictions = [r for r in restrictions if r is not None and r not in self.restrictions] - if restrictions: - if any(is_empty_set(r) for r in restrictions): - # if any condition is an empty list, return empty - self._restrictions = ['FALSE'] - else: - if self._restrictions is None: - self._restrictions = restrictions - else: - self._restrictions.extend(restrictions) + self.restrictions.add(*restrictions) def attributes_in_restrictions(self): """ :return: list of attributes that are probably used in the restrictions. This is used internally for optimizing SQL statements """ - where_clause = self.where_clause - return set(name for name in self.heading.names if name in where_clause) + s = self.restrictions.where_clause() # avoid calling multiple times + return set(name for name in self.heading.names if name in s) def __sub__(self, restriction): """ @@ -166,9 +248,8 @@ def make_select(self, select_fields=None): return 'SELECT {fields} FROM {from_}{where}{group}'.format( fields=select_fields if select_fields else self.select_fields, from_=self.from_clause, - where=self.where_clause, - group=' GROUP BY `%s`' % '`,`'.join(self.primary_key) if self._grouped else '' - ) + where=self.restrictions.where_clause(), + group=' GROUP BY `%s`' % '`,`'.join(self.primary_key) if self._grouped else '') def __len__(self): """ @@ -212,54 +293,10 @@ def fetch1(self): def fetch(self): return Fetch(self) - @property - def where_clause(self): - """ - convert the restriction into an SQL WHERE - """ - if not self.restrictions: - return '' - - def make_condition(arg, _negate=False): - if isinstance(arg, str): - condition = [arg] - elif isinstance(arg, RelationalOperand): - common_attributes = [q for q in self.heading.names if q in arg.heading.names] - if not common_attributes: - condition = ['FALSE' if negate else 'TRUE'] - else: - common_attributes = '`'+'`,`'.join(common_attributes)+'`' - condition = ['({fields}) {not_}in ({subquery})'.format( - 
fields=common_attributes, - not_="not " if negate else "", - subquery=arg.make_select(common_attributes))] - _negate = False - elif isinstance(arg, Mapping): - condition = ['`%s`=%s' % (k, repr(v)) for k, v in arg.items() if k in self.heading] - elif isinstance(arg, np.void): - # element of a record array - condition = ['`%s`=%s' % (k, arg[k]) for k in arg.dtype.fields if k in self.heading] - else: - raise DataJointError('invalid restriction type') - return ' AND '.join(condition) if condition else 'TRUE', _negate - - conditions = [] - for r in self.restrictions: - negate = isinstance(r, Not) - if negate: - r = r.restriction - if isinstance(r, (list, tuple, set, np.ndarray)): - r = '(' + ') OR ('.join([make_condition(q)[0] for q in r]) + ')' - else: - r, negate = make_condition(r, negate) - if r: - conditions.append('%s(%s)' % ('not ' if negate else '', r)) - return ' WHERE ' + ' AND '.join(conditions) - class Not: """ - inverse restriction + invert restriction """ def __init__(self, restriction): self.restriction = restriction @@ -277,10 +314,10 @@ def __init__(self, arg1, arg2, left=False): raise DataJointError('Cannot join relations with different database connections') self._arg1 = Subquery(arg1) if arg1.heading.computed else arg1 self._arg2 = Subquery(arg2) if arg2.heading.computed else arg2 - self.restrict(*self._arg1.restrictions) - self.restrict(*self._arg2.restrictions) - self._left = left self._heading = self._arg1.heading.join(self._arg2.heading, left=left) + self.restrict(*list(self._arg1.restrictions)) + self.restrict(*list(self._arg2.restrictions)) + self._left = left def _repr_helper(self): return "(%r) * (%r)" % (self._arg1, self._arg2) @@ -330,7 +367,7 @@ def __init__(self, arg, *attributes, **renamed_attributes): if use_subquery: self._arg = Subquery(arg) else: - self.restrict(*arg.restrictions) + self.restrict(*list(arg.restrictions)) def _repr_helper(self): return "(%r).project(%r)" % (self._arg, self._attributes) @@ -352,12 +389,21 @@ def 
from_clause(self): return self._arg.from_clause def __and__(self, restriction): + has_restriction = isinstance(restriction, RelationalOperand) or bool(restriction) + do_subquery = has_restriction and self.heading.computed + ret = Subquery(self) if do_subquery else self + ret.restrict(restriction) + return ret + + def restrict(self, *restrictions): """ - When restricting on renamed attributes, enclose in subquery + Override restrict: when restricting on renamed attributes, enclose in subquery """ - has_restriction = isinstance(restriction, RelationalOperand) or restriction + has_restriction = any(isinstance(r, RelationalOperand) or r for r in restrictions) do_subquery = has_restriction and self.heading.computed - return Subquery(self) & restriction if do_subquery else super().__and__(restriction) + if do_subquery: + raise DataJointError('In-place restriction on renamed attributes is not allowed') + super().restrict(*restrictions) class Aggregation(Projection): @@ -370,6 +416,7 @@ class Subquery(RelationalOperand): """ A Subquery encapsulates its argument in a SELECT statement, enabling its use as a subquery. The attribute list and the WHERE clause are resolved. + As such, a subquery does not have any renamed attributes. """ __counter = 0 diff --git a/datajoint/schema.py b/datajoint/schema.py index 002c90f21..985e038b4 100644 --- a/datajoint/schema.py +++ b/datajoint/schema.py @@ -3,19 +3,23 @@ from . import conn, DataJointError from .heading import Heading -from .relation import Relation +from .base_relation import BaseRelation from .user_relations import Part +import inspect logger = logging.getLogger(__name__) -class schema: +class Schema: """ A schema object can be used as a decorator that associates a Relation class to a database as - well as a namespace for looking up foreign key references. + well as a namespace for looking up foreign key references in table declaration. 
""" def __init__(self, database, context, connection=None): """ + Associates the specified database with this schema object. If the target database does not exist + already, will attempt on creating the database. + :param database: name of the database to associate the decorated class with :param context: dictionary for looking up foreign keys references, usually set to locals() :param connection: Connection object. Defaults to datajoint.conn() @@ -25,10 +29,8 @@ def __init__(self, database, context, connection=None): self.database = database self.connection = connection self.context = context - - # if the database does not exist, create it - cur = connection.query("SHOW DATABASES LIKE '{database}'".format(database=database)) - if cur.rowcount == 0: + if not self.exists: + # create schema logger.info("Database `{database}` could not be found. " "Attempting to create the database.".format(database=database)) try: @@ -38,10 +40,34 @@ def __init__(self, database, context, connection=None): raise DataJointError("Database named `{database}` was not defined, and" " an attempt to create has failed. Check" " permissions.".format(database=database)) + connection.register(self) + + def drop(self): + """ + Drop the associated database if it exists + """ + if self.exists: + logger.info("Dropping `{database}`.".format(database=self.database)) + try: + self.connection.query("DROP DATABASE `{database}`".format(database=self.database)) + logger.info("Database `{database}` was dropped successfully.".format(database=self.database)) + except pymysql.OperationalError: + raise DataJointError("An attempt to drop database named `{database}` " + "has failed. Check permissions.".format(database=self.database)) + else: + logger.info("Database named `{database}` does not exist. 
Doing nothing.".format(database=self.database)) + + @property + def exists(self): + """ + :return: true if the associated database exists on the server + """ + cur = self.connection.query("SHOW DATABASES LIKE '{database}'".format(database=self.database)) + return cur.rowcount > 0 def __call__(self, cls): """ - The decorator binds its argument class object to a database + Binds the passed in class object to a database. This is intended to be used as a decorator. :param cls: class to be decorated """ @@ -53,28 +79,23 @@ def process_relation_class(relation_class, context): relation_class._connection = self.connection relation_class._heading = Heading() relation_class._context = context + # instantiate the class and declare the table in database if not already present relation_class().declare() if issubclass(cls, Part): - raise DataJointError('The schema decorator should not apply to part relations') + raise DataJointError('The schema decorator should not be applied to Part relations') process_relation_class(cls, context=self.context) - # Process subordinate relations + # Process subordinate relations parts = list() - for name in (name for name in dir(cls) if not name.startswith('_')): - part = getattr(cls, name) - try: - is_sub = issubclass(part, Part) - except TypeError: - pass - else: - if is_sub: - parts.append(part) - part._master = cls - process_relation_class(part, context=dict(self.context, **{cls.__name__: cls})) - elif issubclass(part, Relation): - raise DataJointError('Part relations must subclass from datajoint.Part') + is_part = lambda x: inspect.isclass(x) and issubclass(x, Part) + + for var, part in inspect.getmembers(cls, is_part): + parts.append(part) + part._master = cls + # TODO: look into local namespace for the subclasses + process_relation_class(part, context=dict(self.context, **{cls.__name__: cls})) # invoke Relation._prepare() on class and its part relations. 
cls()._prepare() diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 8ffcd2e2f..4f5875bac 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -3,13 +3,13 @@ """ import abc -from .relation import Relation +from .base_relation import BaseRelation from .autopopulate import AutoPopulate from .utils import from_camel_case from . import DataJointError -class Part(Relation, metaclass=abc.ABCMeta): +class Part(BaseRelation, metaclass=abc.ABCMeta): """ Inherit from this class if the table's values are details of an entry in another relation and if this table is populated by this relation. For example, the entries inheriting from @@ -29,7 +29,7 @@ def table_name(self): return self.master().table_name + '__' + from_camel_case(self.__class__.__name__) -class Manual(Relation, metaclass=abc.ABCMeta): +class Manual(BaseRelation, metaclass=abc.ABCMeta): """ Inherit from this class if the table's values are entered manually. """ @@ -42,7 +42,7 @@ def table_name(self): return from_camel_case(self.__class__.__name__) -class Lookup(Relation, metaclass=abc.ABCMeta): +class Lookup(BaseRelation, metaclass=abc.ABCMeta): """ Inherit from this class if the table's values are for lookup. This is currently equivalent to defining the table as Manual and serves semantic @@ -64,7 +64,7 @@ def _prepare(self): self.insert(self.contents, skip_duplicates=True) -class Imported(Relation, AutoPopulate, metaclass=abc.ABCMeta): +class Imported(BaseRelation, AutoPopulate, metaclass=abc.ABCMeta): """ Inherit from this class if the table's values are imported from external data sources. The inherited class must at least provide the function `_make_tuples`. 
@@ -78,7 +78,7 @@ def table_name(self): return "_" + from_camel_case(self.__class__.__name__) -class Computed(Relation, AutoPopulate, metaclass=abc.ABCMeta): +class Computed(BaseRelation, AutoPopulate, metaclass=abc.ABCMeta): """ Inherit from this class if the table's values are computed from other relations in the schema. The inherited class must at least provide the function `_make_tuples`. diff --git a/requirements.txt b/requirements.txt index e8a5dcbf0..af4d48f13 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,3 @@ networkx matplotlib sphinx_rtd_theme mock -pygraphviz==1.3rc1 diff --git a/tests/__init__.py b/tests/__init__.py index 149175c82..e1123a920 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -9,11 +9,15 @@ import logging from os import environ + +# turn on verbose logging +logging.basicConfig(level=logging.DEBUG) + import datajoint as dj __all__ = ['__author__', 'PREFIX', 'CONN_INFO'] -logging.basicConfig(level=logging.DEBUG) + # Connection for testing CONN_INFO = dict( @@ -27,11 +31,10 @@ def setup_package(): """ Package-level unit test setup - :return: + Turns off safemode """ dj.config['safemode'] = False - def teardown_package(): """ Package-level unit test teardown. diff --git a/tests/schema.py b/tests/schema.py index 3b6b0a479..e1444a189 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -1,5 +1,5 @@ """ -Test schema definition +Sample scheme with realistic tables for testing """ import random diff --git a/tests/test_cascading_delete.py b/tests/test_cascading_delete.py index 8182ad6c1..fa499ba1e 100644 --- a/tests/test_cascading_delete.py +++ b/tests/test_cascading_delete.py @@ -63,7 +63,7 @@ def test_delete_lookup(): 'schema is not populated') L().delete() assert_false(bool(L() or D() or E() or E.F()), 'incomplete delete') - A().delete() # delete all is necessary because delete L deletes from subtables. TODO: submit this as an issue + A().delete() # delete all is necessary because delete L deletes from subtables. 
# @staticmethod # def test_delete_lookup_restricted():