diff --git a/.travis.yml b/.travis.yml index 7c3f7bf8b..87387a166 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,7 @@ sudo: required language: python env: - - DJ_TEST_HOST="127.0.0.1" DJ_TEST_USER="root" DJ_TEST_PASSWORD="" DJ_HOST="127.0.0.1" DJ_USER="root" DJ_PASSWORD="" + - DJ_TEST_HOST="127.0.0.1" DJ_TEST_USER="datajoint" DJ_TEST_PASSWORD="datajoint" DJ_HOST="127.0.0.1" DJ_USER="datajoint" DJ_PASSWORD="datajoint" python: - "3.4" - "3.5" @@ -9,6 +9,7 @@ services: mysql before_install: - sudo apt-get -qq update - sudo apt-get install -y libblas-dev liblapack-dev libatlas-dev gfortran + - mysql -e "create user 'datajoint'@'%' identified by 'datajoint'; GRANT ALL PRIVILEGES ON \`djtest\_%\`.* TO 'datajoint'@'%';" -uroot install: - travis_wait 30 pip install -r requirements.txt - pip install nose nose-cov python-coveralls diff --git a/README.md b/README.md index f835814af..f65eba9b5 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ DataJoint for Python is a high-level programming interface for relational databa DataJoint was initially developed in 2009 by Dimitri Yatsenko in Andreas Tolias' Lab for the distributed processing and management of large volumes of data streaming from regular experiments. Starting in 2011, DataJoint has been available as an open-source project adopted by other labs and improved through contributions from several developers. + ## Quick start guide The current pip version is ancient. We will update it as soon as datajoint release 1.0 is out. ~~To install datajoint using `pip` just run:~~ @@ -27,3 +28,26 @@ pip install datajoint ~~However, please be aware that DataJoint for Python is still undergoing major changes, and thus what's available on PyPI via `pip` is in **pre-release state**!~~ +## Tutorial +1. [Setup](tutorial-notebooks/Primer00.ipynb) +1. [Connect](tutorial-notebooks/Primer01.ipynb) +1. [Create a schema, define a table](tutorial-notebooks/Primer02.ipynb) +1. [Dependencies](tutorial-notebooks/Primer03.ipynb) +1. [Schemas as Python modules](tutorial-notebooks/Primer04.ipynb) +1. [Lookup tables](tutorial-notebooks/Primer05.ipynb) +1. [Queries 1: restrictions and joins](tutorial-notebooks/Primer06.ipynb) +1. [Dependencies 2: non-primary](tutorial-notebooks/Primer07.ipynb) +1. [Queries 2: projections](tutorial-notebooks/Primer08.ipynb) +1. [Dependencies 3: aliased foreign keys](tutorial-notebooks/Primer09.ipynb) +1. [Computations](tutorial-notebooks/Primer10.ipynb) +1. [Parameterized Computations](tutorial-notebooks/Primer11.ipynb) +1. [Master-part relationships](tutorial-notebooks/Primer12.ipynb) +1. [Understanding transactions](tutorial-notebooks/Primer13.ipynb) +1. [Job management for distributed computation](tutorial-notebooks/Primer14.ipynb) +1. [Projection and aggregation](tutorial-notebooks/Primer15.ipynb) +1. [Relation U](tutorial-notebooks/Primer16.ipynb) +1. [Dependencies 4: mapped dependencies](tutorial-notebooks/Primer17.ipynb) +1. [Representing graphs](tutorial-notebooks/Primer18.ipynb) +1. [Customizing computations](tutorial-notebooks/Primer19.ipynb) +1. [BOSS interface](tutorial-notebooks/Primer20.ipynb) +1. [Web interfaces](tutorial-notebooks/Primer21.ipynb) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 45df9c9c8..c353e9ed7 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -6,18 +6,25 @@ DataJoint is free software under the LGPL License. In addition, we request that any use of DataJoint leading to a publication be acknowledged in the publication. 
+ +Please cite: + http://biorxiv.org/content/early/2015/11/14/031658 + http://dx.doi.org/10.1101/031658 """ import logging import os __author__ = "Dimitri Yatsenko, Edgar Walker, and Fabian Sinz at Baylor College of Medicine" -__version__ = "0.2" +__version__ = "0.2.1" +__date__ = "June 1, 2016" __all__ = ['__author__', '__version__', 'config', 'conn', 'kill', 'Connection', 'Heading', 'BaseRelation', 'FreeRelation', 'Not', 'schema', 'Manual', 'Lookup', 'Imported', 'Computed', 'Part', - 'AndList', 'OrList'] + 'AndList', 'OrList', 'ERD', 'U'] + +print('DataJoint', __version__, '('+__date__+')') class key: @@ -65,12 +72,12 @@ class DataJointError(Exception): logger.setLevel(log_levels[config['loglevel']]) - # ------------- flatten import hierarchy ------------------------- from .connection import conn, Connection from .base_relation import BaseRelation from .user_relations import Manual, Lookup, Imported, Computed, Part -from .relational_operand import Not, AndList, OrList +from .relational_operand import Not, AndList, OrList, U from .heading import Heading from .schema import Schema as schema from .kill import kill +from .erd import ERD diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 8b4c652ec..f4dde71fe 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -3,6 +3,7 @@ import logging import datetime import random +from pymysql import OperationalError from .relational_operand import RelationalOperand, AndList from . import DataJointError from .base_relation import FreeRelation @@ -12,17 +13,17 @@ logger = logging.getLogger(__name__) -class AutoPopulate(metaclass=abc.ABCMeta): +class AutoPopulate: """ AutoPopulate is a mixin class that adds the method populate() to a Relation class. Auto-populated relations must inherit from both Relation and AutoPopulate, - must define the property populated_from, and must define the callback method _make_tuples. + must define the property `key_source`, and must define the callback method _make_tuples. """ _jobs = None _populated_from = None @property - def populated_from(self): + def key_source(self): """ :return: the relation whose primary key values are passed, sequentially, to the `_make_tuples` method when populate() is called.The default value is the @@ -30,23 +31,22 @@ def populated_from(self): or the scope of populate() calls. """ if self._populated_from is None: - self.connection.dependencies.load() - parents = [FreeRelation(self.target.connection, rel) for rel in self.target.parents] + self.connection.dependencies.load(self.full_table_name) + parents = self.target.parents(primary=True) if not parents: raise DataJointError('A relation must have parent relations to be able to be populated') - ret = parents.pop(0) + self._populated_from = FreeRelation(self.connection, parents.pop(0)).proj() while parents: - ret *= parents.pop(0) - self._populated_from = ret + self._populated_from *= FreeRelation(self.connection, parents.pop(0)).proj() return self._populated_from - @abc.abstractmethod def _make_tuples(self, key): """ Derived classes must implement method _make_tuples that fetches data from tables that are above them in the dependency hierarchy, restricting by the given key, computes dependent attributes, and inserts the new tuples into self. 
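For orientation, the renamed `key_source` property (formerly `populated_from`) defines the set of keys that `populate()` iterates over, and a subclass may override it to narrow that set. A minimal sketch, assuming hypothetical `Image` and `my_filter` defined elsewhere:

```python
import datajoint as dj

schema = dj.schema('mydb', locals())  # hypothetical database name


@schema
class FilteredImage(dj.Computed):
    definition = """
    -> Image
    ---
    filtered_image: longblob
    """

    @property
    def key_source(self):
        # narrow the default key source (the join of the primary parents)
        return Image() & 'image_quality > 0.5'

    def _make_tuples(self, key):
        image = (Image() & key).fetch1()['image']  # fetch from the parent
        key['filtered_image'] = my_filter(image)   # hypothetical computation
        self.insert1(key)
```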
""" + raise NotImplementedError('Subclasses of AutoPopulate must implement the method "_make_tuples"') @property def target(self): @@ -58,10 +58,10 @@ def target(self): def populate(self, *restrictions, suppress_errors=False, reserve_jobs=False, order="original"): """ - rel.populate() calls rel._make_tuples(key) for every primary key in self.populated_from + rel.populate() calls rel._make_tuples(key) for every primary key in self.key_source for which there is not already a tuple in rel. - :param restrictions: a list of restrictions each restrict (rel.populated_from - target.proj()) + :param restrictions: a list of restrictions each restrict (rel.key_source - target.proj()) :param suppress_errors: suppresses error if true :param reserve_jobs: if true, reserves job to populate in asynchronous fashion :param order: "original"|"reverse"|"random" - the order of execution @@ -73,10 +73,10 @@ def populate(self, *restrictions, suppress_errors=False, reserve_jobs=False, ord if order not in valid_order: raise DataJointError('The order argument must be one of %s' % str(valid_order)) - todo = self.populated_from + todo = self.key_source if not isinstance(todo, RelationalOperand): - raise DataJointError('Invalid populated_from value') - todo.restrict(AndList(restrictions)) + raise DataJointError('Invalid key_source value') + todo = todo & AndList(restrictions) error_list = [] if suppress_errors else None @@ -104,7 +104,11 @@ def populate(self, *restrictions, suppress_errors=False, reserve_jobs=False, ord try: self._make_tuples(dict(key)) except Exception as error: - self.connection.cancel_transaction() + try: + self.connection.cancel_transaction() + except OperationalError: + pass + if reserve_jobs: jobs.error(self.target.table_name, key, error_message=str(error)) if not suppress_errors: @@ -118,14 +122,14 @@ def populate(self, *restrictions, suppress_errors=False, reserve_jobs=False, ord jobs.complete(self.target.table_name, key) return error_list - def progress(self, restriction=None, display=True): + def progress(self, *restrictions, display=True): """ report progress of populating this table :return: remaining, total -- tuples to be populated """ - todo = self.populated_from & restriction + todo = self.key_source & AndList(restrictions) total = len(todo) - remaining = len(todo - self.target.project()) + remaining = len(todo - self.target.proj()) if display: print('%-20s' % self.__class__.__name__, flush=True, end=': ') print('Completed %d of %d (%2.1f%%) %s' % diff --git a/datajoint/base_relation.py b/datajoint/base_relation.py index a2834e7ff..a0b0b2894 100644 --- a/datajoint/base_relation.py +++ b/datajoint/base_relation.py @@ -1,11 +1,9 @@ -from collections.abc import Mapping -from collections import OrderedDict, defaultdict +import collections +import itertools import numpy as np import logging -import abc import binascii -from . import config -from . import DataJointError +from . import config, DataJointError from .declare import declare from .relational_operand import RelationalOperand from .blob import pack @@ -15,7 +13,7 @@ logger = logging.getLogger(__name__) -class BaseRelation(RelationalOperand, metaclass=abc.ABCMeta): +class BaseRelation(RelationalOperand): """ BaseRelation is an abstract class that represents a base relation, i.e. a table in the database. 
To make it a concrete class, override the abstract properties specifying the connection, @@ -28,49 +26,38 @@ class BaseRelation(RelationalOperand, metaclass=abc.ABCMeta): # ---------- abstract properties ------------ # @property - @abc.abstractmethod def table_name(self): """ :return: the name of the table in the database """ + raise NotImplementedError('Subclasses of BaseRelation must implement the property "table_name"') @property - @abc.abstractmethod def definition(self): """ :return: a string containing the table definition using the DataJoint DDL """ + raise NotImplementedError('Subclasses of BaseRelation must implement the property "definition"') # -------------- required by RelationalOperand ----------------- # - @property - def connection(self): - """ - :return: the connection object of the relation - """ - return self._connection - @property def heading(self): """ Returns the table heading. If the table is not declared, attempts to declare it and return heading. - :return: table heading """ if self._heading is None: self._heading = Heading() # instance-level heading - if not self._heading: # heading is not initialized - self.declare() + if not self._heading: # lazy loading of heading self._heading.init_from_database(self.connection, self.database, self.table_name) - return self._heading def declare(self): """ Loads the table heading. If the table is not declared, use self.definition to declare """ - if not self.is_declared: - self.connection.query( - declare(self.full_table_name, self.definition, self._context)) + self.connection.query( + declare(self.full_table_name, self.definition, self._context)) @property def from_clause(self): @@ -86,67 +73,25 @@ def select_fields(self): """ return '*' - def erd(self, *args, fill=True, mode='updown', **kwargs): - """ - :param mode: diffent methods of creating a graph pertaining to this relation. - Currently options includes the following: - * 'updown': Contains this relation and all other nodes that can be reached within specific - number of ups and downs in the graph. ups(=2) and downs(=2) are optional keyword arguments - * 'ancestors': Returs - :return: the entity relationship diagram object of this relation - """ - erd = self.connection.erd() - if mode == 'updown': - nodes = erd.up_down_neighbors(self.full_table_name, *args, **kwargs) - elif mode == 'ancestors': - nodes = erd.ancestors(self.full_table_name, *args, **kwargs) - elif mode == 'descendants': - nodes = erd.descendants(self.full_table_name, *args, **kwargs) - else: - raise DataJointError('Unsupported erd mode', 'Mode "%s" is currently not supported' % mode) - return erd.restrict_by_tables(nodes, fill=fill) - - # ------------- dependencies ---------- # - @property - def parents(self): + def parents(self, primary=None): """ - :return: the parent relation of this relation + :param primary: if None, then all parents are returned. If True, then only foreign keys composed of + primary key attributes are considered. If False, the only foreign keys including at least one non-primary + attribute are considered. 
+        :return: list of tables referenced with self's foreign keys
         """
-        return self.connection.dependencies.parents[self.full_table_name]
+        return [p[0] for p in self.connection.dependencies.in_edges(self.full_table_name, data=True)
+                if primary is None or p[2]['primary'] == primary]
 
-    @property
-    def children(self):
-        """
-        :return: the child relations of this relation
-        """
-        return self.connection.dependencies.children[self.full_table_name]
-
-    @property
-    def references(self):
+    def children(self, primary=None):
         """
-        :return: list of tables that this tables refers to
+        :param primary: if None, then all children are returned. If True, then only foreign keys composed of
+        primary key attributes are considered. If False, then only foreign keys including at least one non-primary
+        attribute are considered.
+        :return: list of tables with foreign keys referencing self
         """
-        return self.connection.dependencies.references[self.full_table_name]
-
-    @property
-    def referenced(self):
-        """
-        :return: list of tables for which this table is referenced by
-        """
-        return self.connection.dependencies.referenced[self.full_table_name]
-
-    @property
-    def descendants(self):
-        """
-        Returns a list of relation objects for all children and references, recursively,
-        in order of dependence. The returned values do not include self.
-        This is helpful for cascading delete or drop operations.
-
-        :return: list of descendants
-        """
-        relations = (FreeRelation(self.connection, table)
-                     for table in self.connection.dependencies.get_descendants(self.full_table_name))
-        return [relation for relation in relations if relation.is_declared]
+        return [p[1] for p in self.connection.dependencies.out_edges(self.full_table_name, data=True)
+                if primary is None or p[2]['primary'] == primary]
 
     def _repr_helper(self):
         """
@@ -157,7 +102,7 @@ def _repr_helper(self):
     @property
     def is_declared(self):
         """
-        :return: True is the table is declared
+        :return: True if the table is declared in the database
         """
         cur = self.connection.query(
             'SHOW TABLES in `{database}`LIKE "{table_name}"'.format(
@@ -183,16 +128,14 @@ def insert(self, rows, replace=False, ignore_errors=False, skip_duplicates=False
         Insert a collection of rows. Additional keyword arguments are passed to insert1.
 
         :param rows: An iterable where each element is a valid argument for insert1.
-        :param replace=False: If True, replaces the matching data tuple in the table if it exists.
-        :param ignore_errors=False: If True, ignore errors: e.g. constraint violations.
-        :param skip_dublicates=False: If True, silently skip duplicate inserts.
+        :param replace: If True, replaces the matching data tuple in the table if it exists.
+        :param ignore_errors: If True, ignore errors: e.g. constraint violations.
+        :param skip_duplicates: If True, silently skip duplicate inserts.
 
         Example::
-
         >>> relation.insert([
         >>>     dict(subject_id=7, species="mouse", date_of_birth="2014-09-01"),
        >>>      dict(subject_id=8, species="mouse", date_of_birth="2014-09-02")])
-
         """
         heading = self.heading
         field_list = None  # ensures that all rows have the same attributes in the same order as the first row.
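The `parents()`/`children()` methods replace the removed `references`/`referenced` properties; the `primary` flag selects foreign keys that are (or are not) contained in the primary key. A sketch, assuming the hypothetical `FilteredImage` above:

```python
rel = FilteredImage()
rel.connection.dependencies.load(rel.full_table_name)

rel.parents(primary=True)   # tables referenced through primary-key foreign keys
rel.parents(primary=False)  # tables referenced through non-primary foreign keys
rel.children()              # all tables whose foreign keys reference rel
```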
@@ -244,14 +187,14 @@ def check_fields(fields): check_fields(row.dtype.fields) attributes = [make_placeholder(name, row[name]) for name in heading if name in row.dtype.fields] - elif isinstance(row, Mapping): # dict-based + elif isinstance(row, collections.abc.Mapping): # dict-based check_fields(row.keys()) attributes = [make_placeholder(name, row[name]) for name in heading if name in row] else: # positional try: if len(row) != len(heading): raise DataJointError( - 'Incorrect number of attributes: ' + 'Invalid insert argument. Incorrect number of attributes: ' '{given} given; {expected} expected'.format( given=len(row), expected=len(heading))) except TypeError: @@ -282,38 +225,22 @@ def row_exists(row): """ primary_key_value = dict((name, value) for (name, value) in zip(row['names'], row['values']) if heading[name].in_key) - return self & primary_key_value + return primary_key_value in self rows = list(make_row_to_insert(row) for row in rows) - - if not rows: - return - - # skip duplicates only if the entire primary key is specified. - skip_duplicates = skip_duplicates and set(heading.primary_key).issubset(set(field_list)) - - if skip_duplicates: - rows = list(row for row in rows if not row_exists(row)) - - if not rows: - return - - if replace: - sql = 'REPLACE' - elif ignore_errors: - sql = 'INSERT IGNORE' - else: - sql = 'INSERT' - - sql += " INTO %s (`%s`) VALUES " % (self.from_clause, '`,`'.join(field_list)) - - # add placeholders to sql - sql += ','.join('(' + ','.join(row['placeholders']) + ')' for row in rows) - # compile all values into one list - args = [] - for r in rows: - args.extend(v for v in r['values'] if v is not None) - self.connection.query(sql, args=args) + if rows: + # skip duplicates only if the entire primary key is specified. 
+ skip_duplicates = skip_duplicates and set(heading.primary_key).issubset(set(field_list)) + if skip_duplicates: + rows = list(row for row in rows if not row_exists(row)) + if rows: + self.connection.query( + "{command} INTO {destination}(`{fields}`) VALUES {placeholders}".format( + command='REPLACE' if replace else 'INSERT IGNORE' if ignore_errors else 'INSERT', + destination=self.from_clause, + fields='`,`'.join(field_list), + placeholders=','.join('(' + ','.join(row['placeholders']) + ')' for row in rows)), + args=list(itertools.chain.from_iterable((v for v in r['values'] if v is not None) for r in rows))) def delete_quick(self): """ @@ -329,49 +256,49 @@ def delete(self): """ self.connection.dependencies.load() - # construct a list (OrderedDict) of relations to delete - relations = OrderedDict((r.full_table_name, r) for r in self.descendants) + relations_to_delete = collections.OrderedDict( + (r, FreeRelation(self.connection, r)) + for r in self.connection.dependencies.descendants(self.full_table_name)) # construct restrictions for each relation restrict_by_me = set() - restrictions = defaultdict(list) + restrictions = collections.defaultdict(list) if self.restrictions: restrict_by_me.add(self.full_table_name) restrictions[self.full_table_name].append(self.restrictions) # copy own restrictions - for r in relations.values(): - restrict_by_me.update(r.references) - for name, r in relations.items(): - for dep in (r.children + r.references): + for r in relations_to_delete.values(): + restrict_by_me.update(r.children(primary=False)) + for name, r in relations_to_delete.items(): + for dep in r.children(): if name in restrict_by_me: restrictions[dep].append(r) else: restrictions[dep].extend(restrictions[name]) # apply restrictions - for name, r in relations.items(): + for name, r in relations_to_delete.items(): if restrictions[name]: # do not restrict by an empty list - r.restrict([r.project() if isinstance(r, RelationalOperand) else r + r.restrict([r.proj() if isinstance(r, RelationalOperand) else r for r in restrictions[name]]) # project - # execute do_delete = False # indicate if there is anything to delete if config['safemode']: print('The contents of the following tables are about to be deleted:') - for relation in list(relations.values()): + for relation in list(relations_to_delete.values()): count = len(relation) if count: do_delete = True if config['safemode']: print(relation.full_table_name, '(%d tuples)' % count) else: - relations.pop(relation.full_table_name) + relations_to_delete.pop(relation.full_table_name) if not do_delete: if config['safemode']: print('Nothing to delete') else: if not config['safemode'] or user_choice("Proceed?", default='no') == 'yes': with self.connection.transaction: - for r in reversed(list(relations.values())): + for r in reversed(list(relations_to_delete.values())): r.delete_quick() print('Done') @@ -380,7 +307,6 @@ def drop_quick(self): Drops the table associated with this relation without cascading and without user prompt. If the table has any dependent table(s), this call will fail with an error. 
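`delete()` now propagates restrictions down the dependency graph before anything is removed, while `delete_quick()` touches only the table itself. A usage sketch (the safemode toggle is shown only for illustration):

```python
import datajoint as dj

dj.config['safemode'] = False  # skip the interactive confirmation
(FilteredImage() & 'image_quality < 0.1').delete()  # cascades to dependent tuples
FilteredImage().delete_quick()  # no cascading, no prompt
```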
""" - # TODO: give a better exception message if self.is_declared: self.connection.query('DROP TABLE %s' % self.full_table_name) logger.info("Dropped table %s" % self.full_table_name) @@ -394,15 +320,15 @@ def drop(self): """ self.connection.dependencies.load() do_drop = True - relations = self.descendants + tables = self.connection.dependencies.descendants(self.full_table_name) if config['safemode']: - for relation in relations: - print(relation.full_table_name, '(%d tuples)' % len(relation)) + for table in tables: + print(table, '(%d tuples)' % len(FreeRelation(self.connection, table))) do_drop = user_choice("Proceed?", default='no') == 'yes' if do_drop: - while relations: - relations.pop().drop_quick() - print('Tables dropped.') + for table in reversed(tables): + FreeRelation(self.connection, table).drop_quick() + print('Tables dropped. Restart kernel.') @property def size_on_disk(self): @@ -415,7 +341,7 @@ def size_on_disk(self): return ret['Data_length'] + ret['Index_length'] # --------- functionality used by the decorator --------- - def _prepare(self): + def prepare(self): """ This method is overridden by the user_relations subclasses. It is called on an instance once when the class is declared. @@ -427,25 +353,27 @@ class FreeRelation(BaseRelation): """ A base relation without a dedicated class. Each instance is associated with a table specified by full_table_name. + :param arg: a dj.Connection or a dj.FreeRelation """ - def __init__(self, connection, full_table_name, definition=None, context=None): - self.database, self._table_name = (s.strip('`') for s in full_table_name.split('.')) - self._connection = connection - self._definition = definition - self._context = context + def __init__(self, arg, full_table_name=None): + super().__init__() + if isinstance(arg, FreeRelation): + # copy constructor + self.database = arg.database + self._table_name = arg._table_name + self._connection = arg._connection + else: + self.database, self._table_name = (s.strip('`') for s in full_table_name.split('.')) + self._connection = arg + def __repr__(self): return "FreeRelation(`%s`.`%s`)" % (self.database, self._table_name) @property def definition(self): - """ - Definition of the table. - - :return: the definition - """ - return self._definition + return None @property def connection(self): diff --git a/datajoint/blob.py b/datajoint/blob.py index debae6a35..be930bba0 100644 --- a/datajoint/blob.py +++ b/datajoint/blob.py @@ -34,7 +34,8 @@ b'ZL123\0': zlib.decompress } -class BlobReader(object): + +class BlobReader: def __init__(self, blob, simplify=False): self._simplify = simplify self._blob = blob @@ -200,7 +201,6 @@ def __str__(self): return str(self._blob[self.pos:]) - def pack(obj): """ Packs an object into a blob to be compatible with mym.mex diff --git a/datajoint/connection.py b/datajoint/connection.py index d63ad93a4..849de6e70 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -7,7 +7,6 @@ import logging from . import config from . import DataJointError -from datajoint.erd import ERD from .dependencies import Dependencies from .jobs import JobManager @@ -92,10 +91,6 @@ def is_connected(self): """ return self._conn.ping() - def erd(self): - self.dependencies.load() - return ERD.create_from_dependencies(self.dependencies) - def query(self, query, args=(), as_dict=False): """ Execute the specified query and return the tuple generator (cursor). 
diff --git a/datajoint/declare.py b/datajoint/declare.py index e853cae40..d01bb87a2 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -11,6 +11,96 @@ logger = logging.getLogger(__name__) +def build_foreign_key_parser(): + attribute_name = pp.Word(pp.srange('[a-z]'), pp.srange('[a-z0-9_]')) + new_attributes = pp.Optional(pp.delimitedList(attribute_name)).setResultsName('new_attributes') + arrow = pp.Literal('->').suppress() + ref_table = pp.Word(pp.alphas, pp.alphanums + '._').setResultsName('ref_table') + left = pp.Literal('(').suppress() + right = pp.Literal(')').suppress() + ref_attrs = pp.Optional(left + pp.delimitedList(attribute_name) + right).setResultsName('ref_attrs') + return new_attributes + arrow + ref_table + ref_attrs + + +def build_attribute_parser(): + quoted = pp.Or(pp.QuotedString('"'), pp.QuotedString("'")) + colon = pp.Literal(':').suppress() + attribute_name = pp.Word(pp.srange('[a-z]'), pp.srange('[a-z0-9_]')).setResultsName('name') + data_type = pp.Combine(pp.Word(pp.alphas) + pp.SkipTo("#", ignore=quoted)).setResultsName('type') + default = pp.Literal('=').suppress() + pp.SkipTo(colon, ignore=quoted).setResultsName('default') + comment = pp.Literal('#').suppress() + pp.restOfLine.setResultsName('comment') + return attribute_name + pp.Optional(default) + colon + data_type + comment + + +foreign_key_parser = build_foreign_key_parser() +attribute_parser = build_attribute_parser() + + +def is_foreign_key(line): + """ + :param line: a line from the table definition + :return: true if the line appears to be a foreign key definition + """ + arrow_position = line.find('->') + return arrow_position >= 0 and not any(c in line[0:arrow_position] for c in '"#\'') + + +def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreign_key_sql): + """ + :param line: a line from a table definition + :param context: namespace containing referenced objects + :param attributes: list of attribute names already in the declaration -- to be updated by this function + :param primary_key: None if the current foreign key is made from the dependent section. Otherwise it is the list + of primary key attributes thus far -- to be updated by the function + :param attr_sql: a list of sql statements defining attributes -- to be updated by this function. + :param foreign_key_sql: a list of sql statements specifying foreign key constraints -- to be updated by this function. + """ + from .base_relation import BaseRelation + try: + result = foreign_key_parser.parseString(line) + except pp.ParseException as err: + raise DataJointError('Parsing error in line "%s". %s.' 
+                              % (line, err))
+    try:
+        referenced_class = eval(result.ref_table, context)
+    except NameError:
+        raise DataJointError('Foreign key reference %s could not be resolved' % result.ref_table)
+    if not issubclass(referenced_class, BaseRelation):
+        raise DataJointError('Foreign key reference %s must be a subclass of BaseRelation' % result.ref_table)
+    if result.ref_attrs and len(result.new_attributes) != len(result.ref_attrs):
+        raise DataJointError('The number of new attributes and referenced attributes does not match in "%s"' % line)
+    ref = referenced_class()
+    if not result.new_attributes:
+        # a simple foreign key
+        for attr in ref.primary_key:
+            if attr not in attributes:
+                attributes.append(attr)
+                attr_sql.append(ref.heading[attr].sql)
+                if primary_key is not None:
+                    primary_key.append(attr)
+        fk = ref.primary_key
+    elif len(result.new_attributes) == 1 and not result.ref_attrs:
+        # a one-alias foreign key
+        ref_attr = (ref.primary_key if len(ref.primary_key) == 1 else
+                    [attr for attr in ref.primary_key if attr not in attributes])
+        if len(ref_attr) != 1:
+            raise DataJointError('Mismatched attributes in foreign key "%s"' % line)
+        ref_attr = ref_attr[0]
+        attr = result.new_attributes[0]
+        attributes.append(attr)
+        assert ref.heading[ref_attr].sql.startswith('`%s`' % ref_attr)
+        attr_sql.append(ref.heading[ref_attr].sql.replace(ref_attr, attr, 1))
+        if primary_key is not None:
+            primary_key.append(attr)
+        fk = [attr if k == ref_attr else k for k in ref.primary_key]
+    else:
+        # a mapped foreign key
+        raise NotImplementedError('TBD mapped foreign keys')
+
+    foreign_key_sql.append(
+        'FOREIGN KEY (`{fk}`) REFERENCES {ref} (`{pk}`) ON UPDATE CASCADE ON DELETE RESTRICT'.format(
+            fk='`,`'.join(fk), pk='`,`'.join(ref.primary_key), ref=ref.full_table_name))
+
+
 def declare(full_table_name, definition, context):
     """
     Parse declaration and create new SQL table accordingly.
@@ -21,10 +111,8 @@ def declare(full_table_name, definition, context):
     """
     # split definition into lines
     definition = re.split(r'\s*\n\s*', definition.strip())
-    # check for optional table comment
     table_comment = definition.pop(0)[1:].strip() if definition[0].startswith('#') else ''
-
     in_key = True  # parse primary keys
     primary_key = []
     attributes = []
@@ -37,34 +125,10 @@ def declare(full_table_name, definition, context):
             pass
         elif line.startswith('---') or line.startswith('___'):
             in_key = False  # start parsing dependent attributes
-        elif line.startswith('->'):
-            # foreign
-            # TODO: clean up import order
-            from .base_relation import BaseRelation
-            # TODO: break this step into finer steps, checking the type of reference before calling it
-            try:
-                ref = eval(line[2:], context)()
-            except NameError:
-                raise DataJointError('Foreign key reference %s could not be resolved' % line[2:])
-            except TypeError:
-                raise DataJointError('Foreign key reference %s could not be instantiated.'
- 'Make sure %s is a valid BaseRelation subclass' % line[2:]) - # TODO: consider the case where line[2:] is a function that returns an instance of BaseRelation - if not isinstance(ref, BaseRelation): - raise DataJointError('Foreign key reference %s must be a subclass of BaseRelation' % line[2:]) - - foreign_key_sql.append( - 'FOREIGN KEY ({primary_key})' - ' REFERENCES {ref} ({primary_key})' - ' ON UPDATE CASCADE ON DELETE RESTRICT'.format( - primary_key='`' + '`,`'.join(ref.primary_key) + '`', ref=ref.full_table_name) - ) - for name in ref.primary_key: - if in_key and name not in primary_key: - primary_key.append(name) - if name not in attributes: - attributes.append(name) - attribute_sql.append(ref.heading[name].sql) + elif is_foreign_key(line): + compile_foreign_key(line, context, attributes, + primary_key if in_key else None, + attribute_sql, foreign_key_sql) elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): # index index_sql.append(line) # the SQL syntax is identical to DataJoint's else: @@ -74,7 +138,6 @@ def declare(full_table_name, definition, context): if name not in attributes: attributes.append(name) attribute_sql.append(sql) - # compile SQL if not primary_key: raise DataJointError('Table must have a primary key') @@ -97,15 +160,6 @@ def compile_attribute(line, in_key=False): :param in_key: set to True if attribute is in primary key set :returns: (name, sql) -- attribute name and sql code for its declaration """ - quoted = pp.Or(pp.QuotedString('"'), pp.QuotedString("'")) - colon = pp.Literal(':').suppress() - attribute_name = pp.Word(pp.srange('[a-z]'), pp.srange('[a-z0-9_]')).setResultsName('name') - - data_type = pp.Combine(pp.Word(pp.alphas)+pp.SkipTo("#", ignore=quoted)).setResultsName('type') - default = pp.Literal('=').suppress() + pp.SkipTo(colon, ignore=quoted).setResultsName('default') - comment = pp.Literal('#').suppress() + pp.restOfLine.setResultsName('comment') - - attribute_parser = attribute_name + pp.Optional(default) + colon + data_type + comment match = attribute_parser.parseString(line+'#', parseAll=True) match['comment'] = match['comment'].rstrip('#') diff --git a/datajoint/dependencies.py b/datajoint/dependencies.py index 722fd5244..d4fe79adc 100644 --- a/datajoint/dependencies.py +++ b/datajoint/dependencies.py @@ -1,59 +1,23 @@ -from collections import defaultdict import pyparsing as pp +import networkx as nx from . import DataJointError -from functools import wraps -class Dependencies: +class Dependencies(nx.DiGraph): """ Lookup for dependencies between tables """ __primary_key_parser = (pp.CaselessLiteral('PRIMARY KEY') + pp.QuotedString('(', endQuoteChar=')').setResultsName('primary_key')) - def __init__(self, conn): - self._conn = conn - self._parents = dict() - self._referenced = dict() - self._children = defaultdict(list) - self._references = defaultdict(list) - - - @property - def parents(self): - return self._parents - - @property - def children(self): - return self._children - - @property - def references(self): - return self._references - - @property - def referenced(self): - return self._referenced - - def get_descendants(self, full_table_name): - """ - :param full_table_name: a table name in the format `database`.`table_name` - :return: list of all children and references, in order of dependence. - This is helpful for cascading delete or drop operations. 
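The parser introduced above accepts plain foreign keys (`-> Ref`) and single-alias foreign keys (`new_attr -> Ref`); mapped foreign keys still raise `NotImplementedError`. A sketch of a definition the parser is meant to accept, assuming a hypothetical `Image` table with a single primary key attribute:

```python
@schema
class Match(dj.Manual):
    definition = """
    # pairing of an image with a template image (hypothetical)
    -> Image            # plain foreign key: inherits Image's primary key
    match_id: smallint  # additional primary key attribute
    ---
    template -> Image   # aliased foreign key: `template` stands in for Image's primary key
    """
```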
- """ - ret = defaultdict(lambda: 0) - - def recurse(name, level): - ret[name] = max(ret[name], level) - for child in self.children[name] + self.references[name]: - recurse(child, level+1) - - recurse(full_table_name, 0) - return sorted(ret.keys(), key=ret.__getitem__) + def __init__(self, connection): + self._conn = connection + self.loaded_tables = set() + super().__init__(self) @staticmethod def __foreign_key_parser(database): - def add_database(string, loc, toc): + def paste_database(unused1, unused2, toc): return ['`{database}`.`{table}`'.format(database=database, table=toc[0])] return (pp.CaselessLiteral('CONSTRAINT').suppress() + @@ -62,67 +26,60 @@ def add_database(string, loc, toc): pp.QuotedString('(', endQuoteChar=')').setResultsName('attributes') + pp.CaselessLiteral('REFERENCES') + pp.Or([ - pp.QuotedString('`').setParseAction(add_database), + pp.QuotedString('`').setParseAction(paste_database), pp.Combine(pp.QuotedString('`', unquoteResults=False) + '.' + pp.QuotedString('`', unquoteResults=False))]).setResultsName('referenced_table') + pp.QuotedString('(', endQuoteChar=')').setResultsName('referenced_attributes')) - def load(self): + def add_table(self, table_name): """ - Load dependencies for all tables that have not yet been loaded in all registered schemas + Adds table to the dependency graph + :param table_name: in format `schema`.`table` """ - for database in self._conn.schemas: - cur = self._conn.query('SHOW TABLES FROM `{database}`'.format(database=database)) - fk_parser = self.__foreign_key_parser(database) - # list tables that have not yet been loaded, exclude service tables that starts with ~ - tables = ('`{database}`.`{table_name}`'.format(database=database, table_name=info[0]) - for info in cur if not info[0].startswith('~')) - tables = (name for name in tables if name not in self._parents) - for name in tables: - self._parents[name] = list() - self._referenced[name] = list() - # fetch the CREATE TABLE statement - create_statement = self._conn.query('SHOW CREATE TABLE %s' % name).fetchone() - create_statement = create_statement[1].split('\n') - primary_key = None - for line in create_statement: - if primary_key is None: - try: - result = self.__primary_key_parser.parseString(line) - except pp.ParseException: - pass - else: - primary_key = [s.strip(' `') for s in result.primary_key.split(',')] - try: - result = fk_parser.parseString(line) - except pp.ParseException: - pass - else: - if not primary_key: - raise DataJointError('No primary key found %s' % name) - if result.referenced_attributes != result.attributes: - raise DataJointError( - "%s's foreign key refers to differently named attributes in %s" - % (self.__class__.__name__, result.referenced_table)) - if all(q in primary_key for q in [s.strip('` ') for s in result.attributes.split(',')]): - self._parents[name].append(result.referenced_table) - self._children[result.referenced_table].append(name) - else: - self._referenced[name].append(result.referenced_table) - self._references[result.referenced_table].append(name) - - def get_descendants(self, full_table_name): + if table_name in self.loaded_tables: + return + fk_parser = self.__foreign_key_parser(table_name.split('.')[0].strip('`')) + self.loaded_tables.add(table_name) + self.add_node(table_name) + create_statement = self._conn.query('SHOW CREATE TABLE %s' % table_name).fetchone()[1].split('\n') + primary_key = None + for line in create_statement: + if primary_key is None: + try: + result = self.__primary_key_parser.parseString(line) + except 
pp.ParseException: + pass + else: + primary_key = [s.strip(' `') for s in result.primary_key.split(',')] + try: + result = fk_parser.parseString(line) + except pp.ParseException: + pass + else: + self.add_edge(result.referenced_table, table_name, + primary=all(r.strip('` ') in primary_key for r in result.attributes.split(','))) + + def load(self, target=None): """ - :param full_table_name: a table name in the format `database`.`table_name` - :return: list of all children and references, in order of dependence. - This is helpful for cascading delete or drop operations. + Load dependencies for all loaded schemas. + This method gets called before any operation that requires dependencies: delete, drop, populate, progress. """ - ret = defaultdict(lambda: 0) - - def recurse(name, level): - ret[name] = max(ret[name], level) - for child in self.children[name] + self.references[name]: - recurse(child, level+1) - - recurse(full_table_name, 0) - return sorted(ret.keys(), key=ret.__getitem__) + if target is not None and '.' in target: # `database`.`table` + self.add_table(target) + else: + databases = self._conn.schemas if target is None else [target] + for database in databases: + for row in self._conn.query('SHOW TABLES FROM `{database}`'.format(database=database)): + table = row[0] + if not table.startswith('~'): # exclude service tables + self.add_table('`{db}`.`{tab}`'.format(db=database, tab=table)) + if not nx.is_directed_acyclic_graph(self): + raise DataJointError('DataJoint can only work with acyclic dependencies') + + def descendants(self, full_table_name): + """ + :param full_table_name: In form `schema`.`table_name` + :return: all dependent tables sorted in topological order. Self is included. + """ + nodes = nx.algorithms.dag.descendants(self, full_table_name) + return [full_table_name] + nx.algorithms.dag.topological_sort(self, nodes) diff --git a/datajoint/erd.py b/datajoint/erd.py index 14858d530..8984e98fb 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -1,473 +1,272 @@ -from matplotlib import transforms - -import numpy as np - -import logging -from collections import defaultdict import networkx as nx -from networkx import DiGraph -from functools import cmp_to_key -import operator +import numpy as np +import re +from scipy.optimize import basinhopping +import itertools +import inspect +from . import Manual, Imported, Computed, Lookup, Part, DataJointError -# use pygraphviz if available -try: - from networkx import pygraphviz_layout -except: - pygraphviz_layout = None +user_relation_classes = (Manual, Lookup, Computed, Imported, Part) -import matplotlib.pyplot as plt -from inspect import isabstract -from .user_relations import UserRelation, Part -logger = logging.getLogger(__name__) +def _get_concrete_subclasses(class_list): + for cls in class_list: + for subclass in cls.__subclasses__(): + if not inspect.isabstract(subclass): + yield subclass + yield from _get_concrete_subclasses([subclass]) -def get_concrete_subclasses(cls): - desc = [] - child= cls.__subclasses__() - for c in child: - if not isabstract(c): - desc.append(c) - desc.extend(get_concrete_subclasses(c)) - return desc +def _get_tier(table_name): + try: + return next(tier for tier in user_relation_classes + if re.fullmatch(tier.tier_regexp, table_name)) + except StopIteration: + return None -class ERD(DiGraph): - """ - A directed graph representing dependencies between Relations within and across - multiple databases. +class ERD(nx.DiGraph): """ + Entity relationship diagram. 
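`Dependencies` is now a `networkx.DiGraph` keyed by full table names, with each edge flagged by whether the foreign key is part of the primary key, so it can be inspected directly. A sketch with a hypothetical table name:

```python
import datajoint as dj

deps = dj.conn().dependencies
deps.load()  # parses SHOW CREATE TABLE for every table in the loaded schemas
order = deps.descendants('`mydb`.`image`')  # topologically sorted, self included first
```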
- def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + Usage: + >>> erd = Erd(source) + source can be a base relation object, a base relation class, a schema, or a module that has a schema + or source can be a sequence of such objects. - @property - def node_labels(self): - """ - :return: dictionary of key : label pairs for plotting - """ - def full_class_name(user_class): - if issubclass(user_class, Part): - return '{module}.{master}.{cls}'.format( - module=user_class.__module__, - master=user_class.master.__name__, - cls=user_class.__name__) - else: - return '{module}.{cls}'.format( - module=user_class.__module__, - cls=user_class.__name__) + >>> erd.draw() + draws the diagram using pyplot - name_map = {rel.full_table_name: full_class_name(rel) for rel in get_concrete_subclasses(UserRelation)} - return {k: self.get_label(k, name_map) for k in self.nodes()} + erd1 + erd2 - combines the two ERDs. + erd + n - adds n levels of successors + erd - n - adds n levens of predecessors + Thus dj.ERD(schema.Table)+1-1 defines the diagram of immediate ancestors and descendants of schema.Table - def get_label(self, node, name_map=None): - label = self.node[node].get('label', '') - if label.strip(): - return label - if name_map is not None and node in name_map: - return name_map[node] - # no other name exists, so just use full table now - return node - - @property - def lone_nodes(self): - """ - :return: list of nodes that are not connected to any other node - """ - return list(x for x in self.root_nodes if len(self.out_edges(x)) == 0) - - @property - def pk_edges(self): - """ - :return: list of edges representing primary key foreign relations - """ - return [edge for edge in self.edges() - if self[edge[0]][edge[1]].get('rel') == 'parent'] - - @property - def non_pk_edges(self): - """ - :return: list of edges representing non primary key foreign relations - """ - return [edge for edge in self.edges() - if self[edge[0]][edge[1]].get('rel') == 'referenced'] - - def highlight(self, nodes): - """ - Highlights specified nodes when plotting - :param nodes: list of nodes, specified by full table names, to be highlighted - """ - for node in nodes: - self.node[node]['highlight'] = True - - def remove_highlight(self, nodes=None): - """ - Remove highlights from specified nodes when plotting. If specified node is not - highlighted to begin with, nothing happens. - :param nodes: list of nodes, specified by full table names, to remove highlights from - """ - if not nodes: - nodes = self.nodes_iter() - for node in nodes: - self.node[node]['highlight'] = False + Note that erd + 1 - 1 may differ from erd - 1 + 1 and so forth. + Only those tables that are loaded in the connection object are displayed + """ + def __init__(self, source): - # TODO: make this take in various config parameters for plotting - def plot(self): - """ - Plots an entity relation diagram (ERD) among all nodes that is part - of the current graph. 
- """ - if not self.nodes(): # There is nothing to plot - logger.warning('Nothing to plot') + if isinstance(source, ERD): + # copy constructor + self.nodes_to_show = set(source.nodes_to_show) + super().__init__(source) return - if pygraphviz_layout is None: - logger.warning('Failed to load Pygraphviz - plotting not supported at this time') - return - - pos = pygraphviz_layout(self, prog='dot') - fig = plt.figure(figsize=[10, 7]) - ax = fig.add_subplot(111) - nx.draw_networkx_nodes(self, pos, node_size=200, node_color='g') - text_dict = nx.draw_networkx_labels(self, pos, self.node_labels) - trans = ax.transData + \ - transforms.ScaledTranslation(12/72, 0, fig.dpi_scale_trans) - for text in text_dict.values(): - text.set_horizontalalignment('left') - text.set_transform(trans) - # draw primary key relations - nx.draw_networkx_edges(self, pos, self.pk_edges, arrows=False) - # draw non-primary key relations - nx.draw_networkx_edges(self, pos, self.non_pk_edges, style='dashed', arrows=False) - apos = np.array(list(pos.values())) - xmax = apos[:, 0].max() + 200 # TODO: use something more sensible than hard fixed number - xmin = apos[:, 0].min() - 100 - ax.set_xlim(xmin, xmax) - ax.axis('off') # hide axis - - def __repr__(self): - return self.repr_path() - - def restrict_by_databases(self, databases, fill=False): - """ - Creates a subgraph containing only tables in the specified database. - :param databases: list of database names - :param fill: if True, automatically include nodes connecting two nodes in the specified modules - :return: a subgraph with specified nodes - """ - nodes = [n for n in self.nodes() if n.split('.')[0].strip('`') in databases] - if fill: - nodes = self.fill_connection_nodes(nodes) - return self.subgraph(nodes) - - def restrict_by_tables(self, tables, fill=False): - """ - Creates a subgraph containing only specified tables. - :param tables: list of tables to keep in the subgraph. Tables are specified using full table names - :param fill: set True to automatically include nodes connecting two nodes in the specified list - of tables - :return: a subgraph with specified nodes - """ - nodes = [n for n in self.nodes() if n in tables] - if fill: - nodes = self.fill_connection_nodes(nodes) - return self.subgraph(nodes) - - def restrict_by_tables_in_module(self, module, tables, fill=False): - nodes = [n for n in self.nodes() if self.node[n].get('mod') in module and - self.node[n].get('cls') in tables] - if fill: - nodes = self.fill_connection_nodes(nodes) - return self.subgraph(nodes) - - def fill_connection_nodes(self, nodes): - """ - For given set of nodes, find and add nodes that serves as - connection points for two nodes in the set. - :param nodes: list of nodes for which connection nodes are to be filled in - """ - graph = self.subgraph(self.ancestors_of_all(nodes)) - return graph.descendants_of_all(nodes) - - def ancestors_of_all(self, nodes, n=-1): - """ - Find and return a set of all ancestors of the given - nodes. The set will also contain the specified nodes. - :param nodes: list of nodes for which ancestors are to be found - :param n: maximum number of generations to go up for each node. - If set to a negative number, will return all ancestors. - :return: a set containing passed in nodes and all of their ancestors - """ - s = set() - for node in nodes: - s.update(self.ancestors(node, n)) - return s - - def descendants_of_all(self, nodes, n=-1): - """ - Find and return a set including all descendants of the given - nodes. 
The set will also contain the given nodes as well. - :param nodes: list of nodes for which descendants are to be found - :param n: maximum number of generations to go down for each node. - If set to a negative number, will return all descendants. - :return: a set containing passed in nodes and all of their descendants - """ - s = set() - for node in nodes: - s.update(self.descendants(node, n)) - return s - - def copy_graph(self, *args, **kwargs): - return self.__class__(self, *args, **kwargs) - - def ancestors(self, node, n=-1): - """ - Find and return a set containing all ancestors of the specified - node. For convenience in plotting, this set will also include - the specified node as well (may change in future). - :param node: node for which all ancestors are to be discovered - :param n: maximum number of generations to go up. If set to a negative number, - will return all ancestors. - :return: a set containing the node and all of its ancestors - """ - s = {node} - if n == 0: - return s - for p in self.predecessors_iter(node): - s.update(self.ancestors(p, n-1)) - return s - - def descendants(self, node, n=-1): - """ - Find and return a set containing all descendants of the specified - node. For convenience in plotting, this set will also include - the specified node as well (may change in future). - :param node: node for which all descendants are to be discovered - :param n: maximum number of generations to go down. If set to a negative number, - will return all descendants - :return: a set containing the node and all of its descendants - """ - s = {node} - if n == 0: - return s - for c in self.successors_iter(node): - s.update(self.descendants(c, n-1)) - return s - - def up_down_neighbors(self, node, ups=2, downs=2, _prev=None): - """ - Returns a set of all nodes that can be reached from the specified node by - moving up and down the ancestry tree with specific number of ups and downs. - - Example: - up_down_neighbors(node, ups=2, downs=1) will return all nodes that can be reached by - any combinations of two up tracing and 1 down tracing of the ancestry tree. This includes - all children of a grand-parent (two ups and one down), all grand parents of all children (one down - and then two ups), and all siblings parents (one up, one down, and one up). - - It must be noted that except for some special cases, there is no generalized interpretations for - the relationship among nodes captured by this method. However, it does tend to produce a fairy - good concise view of the relationships surrounding the specified node. - - - :param node: node to base all discovery on - :param ups: number of times to go up the ancestry tree (go up to parent) - :param downs: number of times to go down the ancestry tree (go down to children) - :param _prev: previously visited node. This will be excluded from up down search in this recursion - :return: a set of all nodes that can be reached within specified numbers of ups and downs from the source node - """ - s = {node} - if ups > 0: - for x in self.predecessors_iter(node): - if x != _prev: - s.update(self.up_down_neighbors(x, ups-1, downs, node)) - if downs > 0: - for x in self.successors_iter(node): - if x != _prev: - s.update(self.up_down_neighbors(x, ups, downs-1, node)) - return s - - def n_neighbors(self, node, n, directed=False, prev=None): - """ - Returns a set of n degree neighbors for the - specified node. The set will contain the node itself. - - n degree neighbors are defined as node that can be reached - within n edges from the root node. 
- - By default all edges (incoming and outgoing) will be followed. - Set directed=True to follow only outgoing edges. - """ - s = {node} - if n == 1: - s.update(self.predecessors(node)) - s.update(self.successors(node)) - elif n > 1: - if not directed: - for x in self.predecesors_iter(): - if x != prev: # skip prev point - s.update(self.n_neighbors(x, n-1, prev)) - for x in self.succesors_iter(): - if x != prev: - s.update(self.n_neighbors(x, n-1, prev)) - return s - - @property - def root_nodes(self): - return {node for node in self.nodes() if len(self.predecessors(node)) == 0} - - @property - def leaf_nodes(self): - return {node for node in self.nodes() if len(self.successors(node)) == 0} - - def nodes_by_depth(self): - """ - Return all nodes, ordered by their depth in the hierarchy - :returns: list of nodes, ordered by depth from shallowest to deepest - """ - ret = defaultdict(lambda: 0) - roots = self.root_nodes - - def recurse(node, depth): - if depth > ret[node]: - ret[node] = depth - for child in self.successors_iter(node): - recurse(child, depth+1) - - for root in roots: - recurse(root, 0) - - return sorted(ret.items(), key=operator.itemgetter(1)) - - def get_longest_path(self): - """ - :returns: a list of graph nodes defining th longest path in the graph - """ - # no path exists if there is not an edge! - if not self.edges(): - return [] - node_depth_list = self.nodes_by_depth() - node_depth_lookup = dict(node_depth_list) - path = [] - - leaf = node_depth_list[-1][0] - predecessors = [leaf] - while predecessors: - leaf = sorted(predecessors, key=node_depth_lookup.get)[-1] - path.insert(0, leaf) - predecessors = self.predecessors(leaf) - - return path - - def remove_edges_in_path(self, path): - """ - Removes all shared edges between this graph and the path - :param path: a list of nodes defining a path. All edges in this path will be removed from the graph if found - """ - if len(path) <= 1: # no path exists! + # if source is not a list, make it a list + try: + source[0] + except (TypeError, KeyError): + source = [source] + + # find connection in the first item in the list + try: + connection = source[0].connection + except AttributeError: + try: + connection = source[0].schema.connection + except AttributeError: + raise DataJointError('Could find database connection in %s' % repr(source[0])) + + # initialize graph from dependencies + connection.dependencies.load() + super().__init__(connection.dependencies) + + # Enumerate nodes from all the items in the list + self.nodes_to_show = set() + for source in source: + try: + self.nodes_to_show.add(source.full_table_name) + except AttributeError: + try: + database = source.database + except AttributeError: + try: + database = source.schema.database + except AttributeError: + raise DataJointError('Cannot plot ERD for %s' % repr(source)) + for node in self: + if node.startswith('`%s`' % database): + self.nodes_to_show.add(node) + + def __add__(self, arg): + """ + :param arg: either another ERD or a positive integer. + :return: Union of the ERDs when arg is another ERD or an expansion downstream when arg is a positive integer. + """ + self = ERD(self) # copy + try: + self.nodes_to_show.update(arg.nodes_to_show) + except AttributeError: + for i in range(arg): + new = nx.algorithms.boundary.node_boundary(self, self.nodes_to_show) + if not new: + break + self.nodes_to_show.update(new) + return self + + def __sub__(self, arg): + """ + :param arg: either another ERD or a positive integer. 
+ :return: Difference of the ERDs when arg is another ERD or an expansion upstream when arg is a positive integer. + """ + self = ERD(self) # copy + try: + self.nodes_to_show.difference_update(arg.nodes_to_show) + except AttributeError: + for i in range(arg): + new = nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), self.nodes_to_show) + if not new: + break + self.nodes_to_show.update(new) + return self + + def __mul__(self, arg): + """ + Intersection of two ERDs + :param arg: another ERD + :return: a new ERD comprising nodes that are present in both operands. + """ + self = ERD(self) # copy + self.nodes_to_show.intersection_update(arg.nodes_to_show) + return self + + def _make_graph(self, prefix_module): + """ + Make the self.graph - a graph object ready for drawing + """ + graph = nx.DiGraph(self).subgraph(self.nodes_to_show) + nx.set_node_attributes(graph, 'node_type', {n: _get_tier(n.split('`')[-2]) for n in graph}) + # relabel nodes to class names + class_list = list(cls for cls in _get_concrete_subclasses(user_relation_classes)) + mapping = { + cls.full_table_name: + (cls._context['__name__'] + '.' + if (prefix_module and cls._context['__name__'] != '__main__') else '') + + (cls._master.__name__+'.' if issubclass(cls, Part) else '') + cls.__name__ + for cls in class_list if cls.full_table_name in graph} + new_names = [mapping.values()] + if len(new_names) > len(set(new_names)): + raise DataJointError('Some classes have identical names. The ERD cannot be plotted.') + nx.relabel_nodes(graph, mapping, copy=False) + return graph + + def draw(self, pos=None, layout=None, prefix_module=True, font_scale=1.2, **layout_options): + if not self.nodes_to_show: + print('There is nothing to plot') return - for a, b in zip(path[:-1], path[1:]): - self.remove_edge(a, b) - - def longest_paths(self): - """ - :return: list of paths from longest to shortest. A path is a list of nodes. - """ - g = self.copy_graph() - paths = [] - path = g.get_longest_path() - while path: - paths.append(path) - g.remove_edges_in_path(path) - path = g.get_longest_path() - return paths - - def repr_path(self): - """ - Construct string representation of the erm, summarizing dependencies between - tables - :return: string representation of the erm - """ - if len(self) == 0: - return "No relations to show" - - paths = self.longest_paths() - - # turn comparator into Key object for use in sort - k = cmp_to_key(self.compare_path) - sorted_paths = sorted(paths, key=k) - - # table name will be padded to match the longest table name - node_labels = self.node_labels - n = max([len(x) for x in node_labels.values()]) + 1 - rep = '' - for path in sorted_paths: - rep += self.repr_path_with_depth(path, n) - - for node in self.lone_nodes: - rep += node_labels[node] + '\n' - - return rep - - def compare_path(self, path1, path2): - """ - Comparator between two paths: path1 and path2 based on a combination of rules. 
- Path 1 is greater than path2 if: - 1) i^th node in path1 is at greater depth than the i^th node in path2 OR - 2) if i^th nodes are at the same depth, i^th node in path 1 is alphabetically less than i^th node - in path 2 - 3) if neither of the above statement is true even if path1 and path2 are switched, proceed to i+1^th node - If path2 is a subpath start at node 1, then path1 is greater than path2 - :param path1: path 1 of 2 to be compared - :param path2: path 2 of 2 to be compared - :return: return 1 if path1 is greater than path2, -1 if path1 is less than path2, and 0 if they are identical - """ - node_depth_lookup = dict(self.nodes_by_depth()) - for node1, node2 in zip(path1, path2): - if node_depth_lookup[node1] != node_depth_lookup[node2]: - return -1 if node_depth_lookup[node1] < node_depth_lookup[node2] else 1 - if node1 != node2: - return -1 if node1 < node2 else 1 - if len(node1) != len(node2): - return -1 if len(node1) < len(node2) else 1 - return 0 - - def repr_path_with_depth(self, path, n=20, m=2): - node_depth_lookup = dict(self.nodes_by_depth()) - node_labels = self.node_labels - space = '-' * n - rep = '' - prev_depth = 0 - first = True - for (i, node) in enumerate(path): - depth = node_depth_lookup[node] - label = node_labels[node] - if first: - rep += (' '*(n+m))*(depth-prev_depth) - else: - rep += space.join(['-'*m]*(depth-prev_depth))[:-1] + '>' - first = False - prev_depth = depth - if i == len(path)-1: - rep += label - else: - rep += label.ljust(n, '-') - rep += '\n' - return rep - - @classmethod - def create_from_dependencies(cls, dependencies, *args, **kwargs): - obj = cls(*args, **kwargs) - - for full_table, parents in dependencies.parents.items(): - database, table = (x.strip('`') for x in full_table.split('.')) - obj.add_node(full_table, database=database, table=table) - for parent in parents: - obj.add_edge(parent, full_table, rel='parent') - - # create non primary key foreign connections - for full_table, referenced in dependencies.referenced.items(): - for ref in referenced: - obj.add_edge(ref, full_table, rel='referenced') - - return obj + graph = self._make_graph(prefix_module) + if pos is None: + pos = (layout if layout else self._layout)(graph, **layout_options) + import matplotlib.pyplot as plt + + edge_list = graph.edges(data=True) + edge_styles = ['solid' if e[2]['primary'] else 'dashed' for e in edge_list] + nx.draw_networkx_edges(graph, pos=pos, edgelist=edge_list, style=edge_styles, alpha=0.2) + + label_props = { # http://matplotlib.org/examples/color/named_colors.html + None: dict(bbox=dict(boxstyle='round,pad=0.1', facecolor='yellow', alpha=0.3), size=round(font_scale*8)), + Manual: dict(bbox=dict(boxstyle='round,pad=0.1', edgecolor='white', facecolor='darkgreen', alpha=0.3), size=round(font_scale*10)), + Lookup: dict(bbox=dict(boxstyle='round,pad=0.1', edgecolor='white', facecolor='gray', alpha=0.2), size=round(font_scale*8)), + Computed: dict(bbox=dict(boxstyle='round,pad=0.1', edgecolor='white', facecolor='red', alpha=0.2), size=round(font_scale*10)), + Imported: dict(bbox=dict(boxstyle='round,pad=0.1', edgecolor='white', facecolor='darkblue', alpha=0.2), size=round(font_scale*10)), + Part: dict(size=round(font_scale*7))} + ax = plt.gca() + for node in graph.nodes(data=True): + ax.text(pos[node[0]][0], pos[node[0]][1], node[0], + horizontalalignment=('right' if pos[node[0]][0] < 0.5 else 'left'), + **label_props[node[1]['node_type']]) + ax = plt.gca() + ax.axis('off') + ax.set_xlim([-0.4, 1.4]) # allow a margin for labels + plt.show() + + 
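A hedged usage sketch of `draw()` together with the ERD algebra defined above; `Session`, `schema_a`, and `schema_b` are hypothetical names, not part of this diff:

```python
# plot Session, its immediate ancestors, and two levels of descendants
(dj.ERD(Session) - 1 + 2).draw()

# plot only the nodes that two ERDs have in common
(dj.ERD(schema_a) * dj.ERD(schema_b)).draw(font_scale=1.0)
```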
    @staticmethod
+    def _layout(graph, quality=2):
+        """
+        :param graph: a networkx.DiGraph object
+        :param quality: 0=dirty, 1=draft, 2=good, 3=great, 4=publish
+        :return: position dict keyed by node names
+        """
+        if not nx.is_directed_acyclic_graph(graph):
+            raise DataJointError('This layout only works for acyclic graphs')
+
+        # assign depths
+        nodes = set(node for node in graph.nodes() if not graph.in_edges(node))  # root
+        depth = 0
+        depths = {}
+        while nodes:
+            depths = dict(depths, **dict.fromkeys(nodes, depth))
+            nodes = set(edge[1] for edge in graph.out_edges(nodes))
+            depth += 1
+        # push depth down as far as possible
+        updated = True
+        while updated:
+            updated = False
+            for node in graph.nodes():
+                if graph.successors(node):
+                    m = min(depths[n] for n in graph.successors(node)) - 1
+                    updated = updated or m > depths[node]
+                    depths[node] = m
+        longest_path = nx.dag_longest_path(graph)  # place at x=0
+
+        # assign initial x positions
+        x = dict.fromkeys(graph, 0)
+        unplaced = set(node for node in graph if node not in longest_path)
+        for node in sorted(unplaced, key=graph.degree, reverse=True):
+            neighbors = set(nx.all_neighbors(graph, node))
+            placed_neighbors = neighbors.difference(unplaced)
+            placed_other = set(graph.nodes()).difference(unplaced).difference(neighbors)
+            x[node] = (sum(x[n] for n in placed_neighbors) -
+                       sum(x[n] for n in placed_other) +
+                       0.05*(np.random.ranf()-0.5))/(len(placed_neighbors) + len(placed_other) + 0.01)
+            x[node] += 2*(x[node] > 0)-1
+            unplaced.remove(node)
+
+        nodes = nx.topological_sort(graph)
+        x = np.array([x[n] for n in nodes])
+
+        intersecting_edge_pairs = list(
+            [[nodes.index(n) for n in edge1],
+             [nodes.index(n) for n in edge2]]
+            for edge1, edge2 in itertools.combinations(graph.edges(), 2)
+            if len(set(edge1 + edge2)) == 4 and (
+                depths[edge1[1]] > depths[edge2[0]] and
+                depths[edge2[1]] > depths[edge1[0]]))
+        depths = depth - np.array([depths[n] for n in nodes])
+
+        # minimize layout cost function (for x-coordinate only)
+        A = np.asarray(nx.to_numpy_matrix(graph, dtype=bool))  # adjacency matrix
+        A = np.logical_or(A, A.transpose())
+        D = np.zeros_like(A, dtype=bool)  # neighbor matrix
+        for d in set(depths):
+            ix = depths == d
+            D[np.outer(ix, ix)] = True
+        D = np.logical_xor(D, np.identity(len(nodes), bool))
+
+        def cost(xx):
+            xx = np.expand_dims(xx, 1)
+            g = xx.transpose()-xx
+            h = g**2 + 1e-8
+            crossings = sum((xx[edge1[0]][0] > xx[edge2[0]][0]) != (xx[edge1[1]][0] > xx[edge2[1]][0])
+                            for edge1, edge2 in intersecting_edge_pairs)
+            return crossings*1000 + h[A].sum() + 0.1*h[D].sum() + (1/h[D]).sum()
+
+        def grad(xx):
+            xx = np.expand_dims(xx, 1)
+            g = xx.transpose()-xx
+            h = g**2 + 1e-8
+            return -2*((A*g).sum(axis=1) + 0.1*(D*g).sum(axis=1) - (D*g/h**2).sum(axis=1))
+
+        niter = [100, 200, 500, 1000, 3000][quality]
+        maxiter = [1, 2, 3, 4, 4][quality]
+        x = basinhopping(cost, x, niter=niter, interval=40, T=30, stepsize=1.0, disp=False,
+                         minimizer_kwargs=dict(jac=grad, options=dict(maxiter=maxiter))).x
+        # normalize coordinates to unit square
+        phi = np.pi*20/180  # rotate coordinates slightly
+        cs, sn = np.cos(phi), np.sin(phi)
+        x, depths = cs*x - sn*depths, sn*x + cs*depths
+        x -= x.min()
+        x /= x.max()+0.01
+        depths -= depths.min()
+        depths = depths/(depths.max()+0.01)
+        return {node: (x, y) for node, x, y in zip(nodes, x, depths)}
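The `cost`/`grad` pair above is handed to scipy's `basinhopping`, which hops between local minima of the crossing-penalty landscape. A self-contained toy with the same call shape (not DataJoint code, just an illustration) might look like:

```python
import numpy as np
from scipy.optimize import basinhopping

def cost(x):
    # a rugged landscape standing in for the layout cost
    return (x ** 2).sum() + np.cos(3 * x).sum()

def grad(x):
    # analytic gradient of the toy cost
    return 2 * x - 3 * np.sin(3 * x)

best_x = basinhopping(cost, np.zeros(4), niter=100, stepsize=1.0,
                      minimizer_kwargs=dict(jac=grad)).x
```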
diff --git a/datajoint/fetch.py b/datajoint/fetch.py
index 96c7b9754..3f2e7394f 100644
--- a/datajoint/fetch.py
+++ b/datajoint/fetch.py
@@ -1,105 +1,72 @@
 from collections import OrderedDict
 from collections.abc import Callable, Iterable
-from functools import wraps
+import numpy as np
 import warnings
-
 from .blob import unpack
-import numpy as np
-from datajoint import DataJointError
+from . import DataJointError
 from . import key as PRIMARY_KEY
-from . import config


-def prepare_attributes(relation, item):
-    """
-    Used by fetch.__getitem__ to deal with slices
-
-    :param relation: the relation that created the fetch object
-    :param item: the item passed to __getitem__. Can be a string, a tuple, a list, or a slice.
+class FetchBase:

-    :return: a tuple of items to fetch, a list of the corresponding attributes
-    :raise DataJointError: if item does not match one of the datatypes above
-    """
-    if isinstance(item, str) or item is PRIMARY_KEY:
-        item = (item,)
-    elif isinstance(item, int):
-        item = (relation.heading.names[item],)
-    elif isinstance(item, slice):
-        attributes = relation.heading.names
-        start = attributes.index(item.start) if isinstance(item.start, str) else item.start
-        stop = attributes.index(item.stop) if isinstance(item.stop, str) else item.stop
-        item = attributes[slice(start, stop, item.step)]
-    try:
-        attributes = tuple(i for i in item if i is not PRIMARY_KEY)
-    except TypeError:
-        raise DataJointError("Index must be a slice, a tuple, a list, a string.")
-    return item, attributes
-
-
-def copy_first(f):
-    """
-    Decorates methods that return an altered copy of self
-    """
-
-    @wraps(f)
-    def ret(*args, **kwargs):
-        args = list(args)
-        args[0] = args[0].__class__(args[0])  # call copy constructor
-        return f(*args, **kwargs)
-
-    return ret
+    @staticmethod
+    def _prepare_attributes(item):
+        """
+        Used by fetch.__getitem__ to deal with slices
+        :param item: the item passed to __getitem__. Can be a string, a tuple, a list, or a slice.
+        :return: a tuple of items to fetch, a list of the corresponding attributes
+        :raise DataJointError: if item does not match one of the datatypes above
+        """
+        if isinstance(item, str) or item is PRIMARY_KEY:
+            item = (item,)
+        try:
+            attributes = tuple(i for i in item if i is not PRIMARY_KEY)
+        except TypeError:
+            raise DataJointError("Index must be a sequence or a string.")
+        return item, attributes


-class Fetch(Iterable, Callable):
+class Fetch(FetchBase, Callable, Iterable):
    """
    A fetch object that handles retrieving elements from the database table.

    :param relation: relation the fetch object retrieves data from.
    """

-    def __init__(self, relation):
-        if isinstance(relation, Fetch):  # copy constructor
-            self.behavior = dict(relation.behavior)
-            self._relation = relation._relation
+    def __init__(self, arg):
+        if isinstance(arg, Fetch):
+            self.behavior = dict(arg.behavior)
+            self._relation = arg._relation
        else:
            self.behavior = dict(offset=None, limit=None, order_by=None, as_dict=False)
-            self._relation = relation
+            self._relation = arg

-    @copy_first
    def order_by(self, *args):
        """
        Changes the state of the fetch object to order the results by a particular attribute.
        The commands are handed down to MySQL.
-
        :param args: the attributes to sort by. If DESC is passed after the name, then the order is descending.

        :return: a copy of the fetch object
-        Example:
-        >>> my_relation.fetch.order_by('language', 'name DESC')
-
        """
+        self = Fetch(self)
        if len(args) > 0:
            self.behavior['order_by'] = args
        return self
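An editorial sketch of the chained-copy style that these modifiers enable; `rel` stands for a hypothetical relation instance, not a name from this diff:

```python
arr = rel.fetch()                                        # whole relation as a numpy recarray
top = rel.fetch.order_by('subject_id DESC').limit(10)()  # each modifier returns a copy; call to execute
dicts = rel.fetch.as_dict()                              # rows as dictionaries instead
```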
    @property
-    @copy_first
    def as_dict(self):
        """
        Changes the state of the fetch object to return dictionaries.
-
-        :return: a copy of the fetch object
-        Example:
-        >>> my_relation.fetch.as_dict()
-
        """
-        self.behavior['as_dict'] = True
-        return self
+        ret = Fetch(self)
+        ret.behavior['as_dict'] = True
+        return ret

-    @copy_first
    def limit(self, limit):
        """
        Limits the number of items fetched.
@@ -107,10 +74,10 @@ def limit(self, limit):
        :param limit: limit on the number of items
        :return: a copy of the fetch object
        """
-        self.behavior['limit'] = limit
-        return self
+        ret = Fetch(self)
+        ret.behavior['limit'] = limit
+        return ret

-    @copy_first
    def offset(self, offset):
        """
        Offsets the number of items fetched. Needs to be applied with limit.
@@ -118,21 +85,11 @@ def offset(self, offset):
        :param offset: offset
        :return: a copy of the fetch object
        """
-        if self.behavior['limit'] is None:
+        ret = Fetch(self)
+        if ret.behavior['limit'] is None:
            warnings.warn('You should supply a limit together with an offset.')
-        self.behavior['offset'] = offset
-        return self
-
-    @copy_first
-    def set_behavior(self, **kwargs):
-        """
-        Sets the behavior like offset, limit, or order_by via keywords arguments.
-
-        :param kwargs: keyword arguments
-        :return: a copy of the fetch object
-        """
-        self.behavior.update(kwargs)
-        return self
+        ret.behavior['offset'] = offset
+        return ret

    def __call__(self, **kwargs):
        """
@@ -141,10 +98,8 @@ def __call__(self, **kwargs):
        :param offset: the number of tuples to skip in the returned result
        :param limit: the maximum number of tuples to return
        :param order_by: the list of attributes to order the results. No ordering should be assumed if order_by=None.
-        :param descending: the list of attributes to order the results
        :param as_dict: returns a list of dictionaries instead of a record array
        :return: the contents of the relation in the form of a structured numpy.array
-
        """
        behavior = dict(self.behavior, **kwargs)
        if behavior['limit'] is None and behavior['offset'] is not None:
@@ -188,9 +143,9 @@ def __iter__(self):

    def keys(self, **kwargs):
        """
-        Iterator that returns primary keys.
+        Iterator that returns primary keys as a sequence of dicts.
        """
-        yield from self._relation.proj().fetch.set_behavior(**dict(self.behavior, as_dict=True, **kwargs))
+        yield from self._relation.proj().fetch(**dict(self.behavior, as_dict=True, **kwargs))

    def __getitem__(self, item):
        """
@@ -199,16 +154,12 @@ def __getitem__(self, item):
        :return: tuple with an entry for each element of item
        Examples:
-        >>> a, b = relation['a', 'b']
        >>> a, b, key = relation['a', 'b', datajoint.key]
-        >>> results = relation['a':'z']  # return attributes a-z as a tuple
-        >>> results = relation[:-1]  # return all but the last attribute
        """
        single_output = isinstance(item, str) or item is PRIMARY_KEY or isinstance(item, int)
-        item, attributes = prepare_attributes(self._relation, item)
-
-        result = self._relation.project(*attributes).fetch(**self.behavior)
+        item, attributes = self._prepare_attributes(item)
+        result = self._relation.proj(*attributes).fetch(**self.behavior)
        return_values = [
            np.ndarray(result.shape,
                       np.dtype({name: result.dtype.fields[name] for name in self._relation.primary_key}),
@@ -229,7 +180,7 @@ def __len__(self):
        return len(self._relation)


-class Fetch1(Callable):
+class Fetch1(FetchBase, Callable):
    """
    Fetch object for fetching exactly one row.
@@ -242,7 +193,6 @@ def __init__(self, relation):

    def __call__(self):
        """
        This version of fetch is called when self is expected to contain exactly one tuple.
- :return: the one tuple in the relation in the form of a dict """ heading = self._relation.heading @@ -251,7 +201,6 @@ def __call__(self): ret = cur.fetchone() if not ret or cur.fetchone(): raise DataJointError('fetch1 should only be used for relations with exactly one tuple') - return OrderedDict((name, unpack(ret[name]) if heading[name].is_blob else ret[name]) for name in heading.names) @@ -259,30 +208,22 @@ def __getitem__(self, item): """ Fetch attributes as separate outputs. datajoint.key is a special value that requests the entire primary key - :return: tuple with an entry for each element of item Examples: - >>> a, b = relation['a', 'b'] >>> a, b, key = relation['a', 'b', datajoint.key] - >>> results = relation['a':'z'] # return attributes a-z as a tuple - >>> results = relation[:-1] # return all but the last attribute - """ single_output = isinstance(item, str) or item is PRIMARY_KEY or isinstance(item, int) - item, attributes = prepare_attributes(self._relation, item) - - result = self._relation.project(*attributes).fetch() + item, attributes = self._prepare_attributes(item) + result = self._relation.proj(*attributes).fetch() if len(result) != 1: - raise DataJointError('Fetch1 should only return one tuple') - + raise DataJointError('fetch1 should only return one tuple. %d tuples were found' % len(result)) return_values = tuple( np.ndarray(result.shape, np.dtype({name: result.dtype.fields[name] for name in self._relation.primary_key}), result, 0, result.strides) if attribute is PRIMARY_KEY else result[attribute][0] - for attribute in item - ) + for attribute in item) return return_values[0] if single_output else return_values diff --git a/datajoint/heading.py b/datajoint/heading.py index 0e81e0ce4..bcb05aa6d 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -42,13 +42,14 @@ class Heading: the attribute names and the values are Attributes. """ - def __init__(self, attributes=None): + def __init__(self, arg=None): """ - :param attributes: a list of dicts with the same keys as Attribute + :param arg: a list of dicts with the same keys as Attribute """ - if attributes: - attributes = OrderedDict([(q['name'], Attribute(**q)) for q in attributes]) - self.attributes = attributes + assert not isinstance(arg, Heading), 'Headings cannot be copied' + self.table_info = None + self.attributes = None if arg is None else OrderedDict( + (q['name'], Attribute(**q)) for q in arg) def __len__(self): return 0 if self.attributes is None else len(self.attributes) @@ -116,6 +117,12 @@ def init_from_database(self, conn, database, table_name): """ initialize heading from a database table. The table must exist already. """ + info = conn.query('SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format( + table_name=table_name, database=database), as_dict=True).fetchone() + if info is None: + raise DataJointError('The table is not defined.') + self.table_info = {k.lower(): v for k, v in info.items()} + cur = conn.query( 'SHOW FULL COLUMNS FROM `{table_name}` IN `{database}`'.format( table_name=table_name, database=database), as_dict=True) @@ -184,19 +191,19 @@ def init_from_database(self, conn, database, table_name): attr['dtype'] = numeric_types[(t, is_unsigned)] self.attributes = OrderedDict([(q['name'], Attribute(**q)) for q in attributes]) - def proj(self, *attribute_list, **renamed_attributes): + def proj(self, attribute_list, renamed_attributes, include_primary_key): """ derive a new heading by selecting, renaming, or computing attributes. 
In relational algebra these operators are known as project, rename, and expand. The primary key is always included. """ - # check missing attributes - missing = [a for a in attribute_list if a not in self.names] - if missing: - raise DataJointError('Attributes `%s` are not found' % '`, `'.join(missing)) + try: # check for missing attributes + raise DataJointError('Attribute `%s` is not found' % next(a for a in attribute_list if a not in self.names)) + except StopIteration: + pass - # always add primary key attributes - attribute_list = self.primary_key + [a for a in attribute_list if a not in self.primary_key] + if include_primary_key: + attribute_list = self.primary_key + [a for a in attribute_list if a not in self.primary_key] # convert attribute_list into a list of dicts but exclude renamed attributes attribute_list = [v._asdict() for k, v in self.attributes.items() @@ -228,22 +235,34 @@ def proj(self, *attribute_list, **renamed_attributes): return Heading(attribute_list) - def join(self, other, left): + def join(self, other, aggregated): """ - Joins two headings. + Join two headings into a new one. """ assert isinstance(other, Heading) attribute_list = [v._asdict() for v in self.attributes.values()] for name in other.names: if name not in self.names: - attribute = other.attributes[name]._asdict(); - if left: + attribute = other.attributes[name]._asdict() + if aggregated: attribute['in_key'] = False attribute_list.append(attribute) return Heading(attribute_list) def resolve(self): """ - Remove attribute computations after they have been resolved in a subquery + Create a new heading with removed attribute computations. + Used by subqueries, which resolve the computations. + """ + return Heading(dict(v._asdict(), computation=None) for v in self.attributes.values()) + + def extend_primary_key(self, new_attributes): + """ + Create a new heading in which the primary key also includes new_attributes. + :param new_attributes: new attributes to be added to the primary key. """ - return Heading([dict(v._asdict(), computation=None) for v in self.attributes.values()]) + try: # check for missing attributes + raise DataJointError('Attribute `%s` is not found' % next(a for a in new_attributes if a not in self.names)) + except StopIteration: + return Heading(dict(v._asdict(), in_key=v.in_key or v.name in new_attributes) + for v in self.attributes.values()) diff --git a/datajoint/jobs.py b/datajoint/jobs.py index eaa631d77..c9e22a818 100644 --- a/datajoint/jobs.py +++ b/datajoint/jobs.py @@ -18,12 +18,20 @@ class JobRelation(BaseRelation): """ A base relation with no definition. 
    Allows reserving jobs
    """
+    _table_name = '~jobs'
+
+    def __init__(self, arg, database=None):
+        super().__init__()
+        if isinstance(arg, JobRelation):
+            # copy constructor
+            self.database = arg.database
+            self._connection = arg._connection
+            self._definition = arg._definition
+            return

-    def __init__(self, connection, database):
        self.database = database
-        self._table_name = '~jobs'
-        self._connection = connection
-        self._definition = """    # job reservation table
+        self._connection = arg
+        self._definition = """    # job reservation table for `{database}`
        table_name  :varchar(255)  # className of the table
        key_hash    :char(32)      # key hash
        ---
@@ -34,7 +42,9 @@ def __init__(self, connection, database):
        host=""     :varchar(255)  # system hostname
        pid=0       :int unsigned  # system process id
        timestamp=CURRENT_TIMESTAMP  :timestamp  # automatic timestamp
-        """
+        """.format(database=database)
+        if not self.is_declared:
+            self.declare()

    @property
    def definition(self):
@@ -48,11 +58,19 @@ def connection(self):
    def table_name(self):
        return self._table_name

+    def delete(self):
+        """ bypass interactive prompts """
+        self.delete_quick()
+
+    def drop(self):
+        """ bypass interactive prompts """
+        self.drop_quick()
+
    def reserve(self, table_name, key):
        """
        Reserve a job for computation. When a job is reserved, the job table contains an entry for the
        job key, identified by its hash. When jobs are completed, the entry is removed.
-        :param full_table_name: `database`.`table_name`
+        :param table_name: `database`.`table_name`
        :param key: the dict of the job's primary key
        :return: True if the job was reserved successfully. False = the job is already taken
        """

diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py
index 362a8fb21..6181ae786 100644
--- a/datajoint/relational_operand.py
+++ b/datajoint/relational_operand.py
@@ -1,20 +1,27 @@
-from collections.abc import Mapping, Sequence
+import collections
 import numpy as np
-import abc
 import re
-from copy import copy
 import logging
-from . import DataJointError, config
 import datetime
+from . import DataJointError, config
 from .fetch import Fetch, Fetch1

logger = logging.getLogger(__name__)


+def restricts_to_empty(arg):
+    """
+    returns True if restriction to arg will produce the empty relation.
+    """
+    return not isinstance(arg, AndList) and (
+        arg is None or arg is False or isinstance(arg, str) and arg.upper() == "FALSE" or
+        isinstance(arg, (list, set, tuple, np.ndarray, RelationalOperand)) and len(arg) == 0)
+
+
class AndList(list):
    """
-    A list of restrictions to by applied to a relation. The restrictions are ANDed.
-    Each restriction can be a list or set or a relation whose elements are ORed.
+    A list of restrictions to be applied to a relation. The restrictions are AND-ed.
+    Each restriction can be a list or set or a relation whose elements are OR-ed.
    But the elements that are lists can contain other AndLists.

    Example:
@@ -27,7 +34,7 @@ class AndList(list):

class OrList(list):
    """
-    A list of restrictions to by applied to a relation. The restrictions are ORed.
+    A list of restrictions to be applied to a relation. The restrictions are OR-ed.
    If any restriction is .
    But the elements that are lists can contain other AndLists.

@@ -42,27 +49,50 @@ class OrList(list):
    pass
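A brief sketch of how AND-ed and OR-ed restriction lists compose; `Session` and its attributes are hypothetical names used only for illustration:

```python
import datajoint as dj

keepers = Session() & dj.AndList([
    'session_date >= "2016-01-01"',                            # AND-ed with ...
    dj.OrList(['stimulus = "grating"', 'stimulus = "noise"'])  # ... an OR of two conditions
])
```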
-class RelationalOperand(metaclass=abc.ABCMeta):
+class RelationalOperand:
    """
-    RelationalOperand implements relational algebra and fetch methods.
-    RelationalOperand objects reference other relation objects linked by operators.
+    RelationalOperand implements the relational algebra.
+    RelationalOperand objects link other relational operands with relational operators.
    The leaves of this tree of objects are base relations.
    When fetching data from the database, this tree of objects is compiled into an SQL expression.
-    It is a mixin class that provides relational operators, iteration, and fetch capability.
-    RelationalOperand operators are: restrict, pro, and join.
+    RelationalOperand operators are restrict, join, proj, and aggregate.
    """
-    _restrictions = None

+    def __init__(self, arg=None):
+        assert arg is None or isinstance(arg, RelationalOperand), \
+            'Cannot construct RelationalOperand from %s' % arg.__class__.__name__
+        self._restrictions = AndList(() if arg is None else arg._restrictions)
+
+    # --------- abstract properties -----------
+
+    @property
+    def connection(self):
+        """
+        :return: a datajoint.Connection object
+        """
+        raise NotImplementedError('Subclasses of RelationalOperand must implement the property "connection"')
+
+    @property
+    def from_clause(self):
+        """
+        :return: a string containing the FROM clause of the SQL SELECT statement
+        """
+        raise NotImplementedError('Subclasses of RelationalOperand must implement the property "from_clause"')
+
+    @property
+    def heading(self):
+        """
+        :return: a valid datajoint.Heading object
+        """
+        raise NotImplementedError('Subclasses of RelationalOperand must implement the property "heading"')
+
+    # ---------- derived properties --------

    @property
    def restrictions(self):
-        if self._restrictions is None:
-            self._restrictions = AndList()
+        assert isinstance(self._restrictions, AndList)
        return self._restrictions

-    def clear_restrictions(self):
-        self._restrictions = None
-
    @property
    def primary_key(self):
        return self.heading.primary_key

@@ -91,11 +121,11 @@ def make_condition(arg, _negate=False):
                condition = '({fields}) {not_}in ({subquery})'.format(
                    fields=common_attributes,
                    not_="not " if _negate else "",
-                    subquery=arg.make_select(common_attributes))
+                    subquery=arg.make_sql(common_attributes))
                return condition, False  # _negate is cleared
            # mappings are turned into ANDed equality conditions
-            elif isinstance(arg, Mapping):
+            elif isinstance(arg, collections.abc.Mapping):
                condition = ['`%s`=%r' % (k, v if not isinstance(v, (datetime.date, datetime.datetime, datetime.time)) else str(v))
                             for k, v in arg.items() if k in self.heading]
@@ -111,7 +141,7 @@ def make_condition(arg, _negate=False):

        # An empty or-list in the restrictions immediately causes an empty result
        assert isinstance(self.restrictions, AndList)
-        if any(is_empty_or_list(r) for r in self.restrictions):
+        if any(restricts_to_empty(r) for r in self.restrictions):
            return ' WHERE FALSE'

        conditions = []
@@ -121,7 +151,7 @@ def make_condition(arg, _negate=False):
                item = item.restriction  # NOT is added below
            if isinstance(item, (list, tuple, set, np.ndarray)):
                # process an OR list
-                temp = [make_condition(q)[0] for q in item if q is not is_empty_or_list(q)]
+                temp = [make_condition(q)[0] for q in item if not restricts_to_empty(q)]
                item = '(' + ') OR ('.join(temp) + ')' if temp else 'FALSE'
            else:
                item, negate = make_condition(item, negate)
@@ -130,29 +160,6 @@ def make_condition(arg, _negate=False):
            conditions.append(('NOT (%s)' if negate else '(%s)') % item)
        return ' WHERE ' + ' AND '.join(conditions)

-    # --------- abstract properties -----------
-
-    @property
-    @abc.abstractmethod
-    def connection(self):
-        """
-        :return: a datajoint.Connection object
-        """
-
-    @property
-    @abc.abstractmethod
-    def from_clause(self):
-        """
-
:return: a string containing the FROM clause of the SQL SELECT statement - """ - - @property - @abc.abstractmethod - def heading(self): - """ - :return: a valid datajoint.Heading object - """ - @property def select_fields(self): """ @@ -160,8 +167,11 @@ def select_fields(self): """ return self.heading.as_sql - @property def _grouped(self): + """ + If grouped, then GROUP BY the primary key. Used for aggregation. + :return: True for aggregation, False otherwise + """ return False # --------- relational operators ----------- @@ -170,19 +180,7 @@ def __mul__(self, other): """ relational join """ - return Join(self, other) - - def __mod__(self, attributes=None): - """ - relational projection operator. See RelationalOperand.project - """ - return self.proj(*attributes) - - def project(self, *args, **kwargs): - """ - alias for self.proj() for backward compatibility - """ - return self.proj(*args, **kwargs) + return other*self if isinstance(other, U) else Join(self, other) def proj(self, *attributes, **renamed_attributes): """ @@ -195,22 +193,22 @@ def proj(self, *attributes, **renamed_attributes): Each attribute can only be used once in attributes or renamed_attributes. Therefore, the projected relation cannot have more attributes than the original relation. """ - return Projection(self, *attributes, **renamed_attributes) + return Projection(self, attributes, renamed_attributes) - def aggregate(self, group, *attributes, **renamed_attributes): + def aggregate(self, group, *attributes, keep_all_rows=False, **renamed_attributes): """ Relational aggregation operator :param group: relation whose tuples can be used in aggregation operators :param attributes: - :param renamed_attributes: + :param keep_all_rows: True = preserve the number of tuples in the result (equivalent of LEFT JOIN in SQL) + :param renamed_attributes: a dict of renamings and computations :return: a relation representing the aggregation/projection operator result """ if not isinstance(group, RelationalOperand): raise DataJointError('The second argument must be a relation') - return Aggregation( - Join(self, group, left=True), - *attributes, **renamed_attributes) + return Aggregation(self, group, keep_all_rows=keep_all_rows, + attributes=attributes, renamed_attributes=renamed_attributes) def __iand__(self, restriction): """ @@ -225,27 +223,21 @@ def __and__(self, restriction): """ relational restriction or semijoin :return: a restricted copy of the argument - See relational_operand.restrict for more detail. """ - ret = copy(self) - ret.clear_restrictions() - ret.restrict(self.restrictions) - ret.restrict(restriction) - return ret + return self.__class__(self).restrict(restriction) def __isub__(self, restriction): """ - in-place inverted restriction + in-place inverted restriction aka antijoin See relational_operand.restrict for more detail. """ - self.restrict(Not(restriction)) - return self + return self.restrict(Not(restriction)) def __sub__(self, restriction): """ - inverse restriction aka antijoin + inverted restriction aka antijoin :return: a restricted copy of the argument See relational_operand.restrict for more detail. @@ -265,20 +257,22 @@ def restrict(self, restriction): Inverse restriction is accomplished by either using the subtraction operator or the Not class. 
        The expressions in each row are equivalent:
+        rel & True                      rel
+        rel & False                     the empty relation
        rel & 'TRUE'                    rel
        rel & 'FALSE'                   the empty relation
        rel - cond                      rel & Not(cond)
-        rel - 'TRUE'                    rel & 'FALSE'
+        rel - 'TRUE'                    rel & False
        rel - 'FALSE'                   rel
        rel & AndList((cond1,cond2))    rel & cond1 & cond2
        rel & AndList()                 rel
        rel & [cond1, cond2]            rel & OrList((cond1, cond2))
-        rel & []                        rel & 'FALSE'
-        rel & None                      rel & 'FALSE'
-        rel & any_empty_relation        rel & 'FALSE'
+        rel & []                        rel & False
+        rel & None                      rel & False
+        rel & any_empty_relation        rel & False
        rel - AndList((cond1,cond2))    rel & [Not(cond1), Not(cond2)]
        rel - [cond1, cond2]            rel & Not(cond1) & Not(cond2)
-        rel - AndList()                 rel & 'FALSE'
+        rel - AndList()                 rel & False
        rel - []                        rel
        rel - None                      rel
        rel - any_empty_relation        rel
@@ -296,13 +290,17 @@ def restrict(self, restriction):
        :param restriction: a sequence or an array (treated as OR list), another relation, an SQL condition
        string, or an AndList.
        """
+        # ineffective restrictions
+        if isinstance(restriction, U) or restriction is True or \
+                isinstance(restriction, str) and restriction.upper() == "TRUE":
+            return self
        if isinstance(restriction, AndList):
            self.restrictions.extend(restriction)
-        elif is_empty_or_list(restriction):
-            self.clear_restrictions()
-            self.restrictions.append('FALSE')
+        elif restricts_to_empty(restriction):
+            self._restrictions = AndList(['FALSE'])
        else:
            self.restrictions.append(restriction)
+        return self

    @property
    def fetch1(self):
@@ -320,79 +318,67 @@ def attributes_in_restrictions(self):
        s = self.where_clause
        return set(name for name in self.heading.names if name in s)

-    @abc.abstractmethod
    def _repr_helper(self):
        """
        :return: (string) basic representation of the relation
        """
+        raise NotImplementedError('Subclasses of RelationalOperand must implement the method "_repr_helper"')

    def __repr__(self):
        if config['loglevel'].lower() == 'debug':
            ret = self._repr_helper()
-            if self._restrictions:
-                ret += ' & %r' % self._restrictions
-        else:
-            rel = self.proj(*self.heading.non_blobs)  # project out blobs
-            limit = config['display.limit']
-            width = config['display.width']
-
-            tups = rel.fetch(limit=limit)
-            columns = rel.heading.names
-
-            widths = {f: min(max([len(f)] + [len(str(e)) for e in tups[f]])+4, width) for f in columns}
-
-            templates = {f: '%%-%d.%ds' % (widths[f], widths[f]) for f in columns}
-            repr_string = ' '.join([templates[column] % column for column in columns]) + '\n'
-            repr_string += ' '.join(['+' + '-' * (widths[column] - 2) + '+' for column in columns]) + '\n'
-            for tup in tups:
-                repr_string += ' '.join([templates[column] % tup[column] for column in columns]) + '\n'
-            if len(rel) > limit:
-                repr_string += '...\n'
-            repr_string += ' (%d tuples)\n' % len(rel)
-            return repr_string
-
-        return ret
+            if self.restrictions:
+                ret += ' & %r' % self.restrictions
+            return ret
+        rel = self.proj(*self.heading.non_blobs)  # project out blobs
+        limit = config['display.limit']
+        width = config['display.width']
+        tups = rel.fetch(limit=limit)
+        columns = rel.heading.names
+        widths = {f: min(max([len(f)] + [len(str(e)) for e in tups[f]])+4, width) for f in columns}
+        templates = {f: '%%-%d.%ds' % (widths[f], widths[f]) for f in columns}
+        return (
+            ' '.join([templates[f] % ('*'+f if f in rel.primary_key else f) for f in columns]) + '\n' +
+            ' '.join(['+' + '-' * (widths[column] - 2) + '+' for column in columns]) + '\n' +
+            '\n'.join(' '.join(templates[f] % tup[f] for f in columns) for tup in tups) +
+            ('\n...\n' if len(rel) > limit else '\n') +
+            ' (%d tuples)\n' % len(rel))
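The equivalences tabulated in `restrict()` above can be exercised directly. A sketch with a hypothetical `Session` relation, assuming the behavior this diff implements:

```python
assert len(Session() & 'TRUE') == len(Session())   # no-op restriction
assert len(Session() & 'FALSE') == 0               # restricts to the empty relation
assert len(Session() & []) == 0                    # an empty OR-list yields nothing
assert len(Session() - []) == len(Session())       # subtracting an empty OR-list is a no-op
```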
    def _repr_html_(self):
        limit = config['display.limit']
        rel = self.proj(*self.heading.non_blobs)  # project out blobs
        columns = rel.heading.names
+        info = self.heading.table_info
        content = dict(
-            head='</th><th>'.join(columns),
+            title="" if info is None else "<p><b>%s</b></p>" % info['comment'],
+            head='</th><th>'.join("<em>" + c + "</em>" if c in self.primary_key else c for c in columns),
+            body='</tr><tr>'.join(
+                ['\n'.join(['<td>%s</td>' % column for column in tup])
+                 for tup in rel.fetch(limit=limit)]),
-            tuples=len(rel)
-        )
-        return """
-        <table border="1" class="dataframe">
-            <thead> <tr style="text-align: right;"> <th> %(head)s </th> </tr> </thead>
-            <tbody> <tr> %(body)s </tr> </tbody>
-        </table>
-        <p>%(tuples)i tuples</p>
-        """ % content
-
-    def make_select(self, select_fields=None):
+            tuples=len(rel))
+        return """%(title)s
+        <table border="1" class="dataframe">
+            <thead> <tr style="text-align: right;"> <th> %(head)s </th> </tr> </thead>
+            <tbody> <tr> %(body)s </tr> </tbody>
+        </table>
+        <p>%(tuples)i tuples</p>
+ """ % content + + def make_sql(self, select_fields=None): return 'SELECT {fields} FROM {from_}{where}{group}'.format( fields=select_fields if select_fields else self.select_fields, from_=self.from_clause, where=self.where_clause, - group=' GROUP BY `%s`' % '`,`'.join(self.primary_key) if self._grouped else '') + group=' GROUP BY `%s`' % '`,`'.join(self.primary_key) if self._grouped() else '') def __len__(self): """ number of tuples in the relation. This also takes care of the truth value """ - if self._grouped: + if self._grouped(): return len(Subquery(self)) - cur = self.connection.query(self.make_select('count(*)')) + cur = self.connection.query(self.make_sql('count(*)')) return cur.fetchone()[0] def __contains__(self, item): @@ -409,7 +395,7 @@ def cursor(self, offset=0, limit=None, order_by=None, as_dict=False): """ if offset and limit is None: raise DataJointError('limit is required when offset is set') - sql = self.make_select() + sql = self.make_sql() if order_by is not None: sql += ' ORDER BY ' + ', '.join(order_by) @@ -427,7 +413,7 @@ class Not: """ def __init__(self, restriction): - self.restriction = restriction + self.restriction = True if isinstance(restriction, U) else restriction class Join(RelationalOperand): @@ -435,17 +421,28 @@ class Join(RelationalOperand): Relational join """ - def __init__(self, arg1, arg2, left=False): - if not isinstance(arg2, RelationalOperand): - raise DataJointError('a relation can only be joined with another relation') - if arg1.connection != arg2.connection: - raise DataJointError('Cannot join relations with different database connections') - self._arg1 = Subquery(arg1) if isinstance(arg1, Projection) else arg1 - self._arg2 = Subquery(arg2) if isinstance(arg2, Projection) else arg2 - self._heading = self._arg1.heading.join(self._arg2.heading, left=left) - self.restrict(self._arg1.restrictions) - self.restrict(self._arg2.restrictions) - self._left = left + def __init__(self, arg1, arg2=None, aggregated=False, keep_all_rows=None): + if arg2 is None and isinstance(arg1, Join): + # copy constructor + super().__init__(arg1) + self._arg1 = arg1._arg1 + self._arg2 = arg1._arg2 + self._heading = arg1._heading + self._left = arg1._left + else: + super().__init__() + assert aggregated or keep_all_rows is None # keep_all_rows should be set only for aggregation + assert not any(isinstance(arg, U) for arg in (arg1, arg2)), 'Cannot join with Relation U' + if not isinstance(arg2, RelationalOperand): + raise DataJointError('a relation can only be joined with another relation') + if arg1.connection != arg2.connection: + raise DataJointError('Cannot join relations with different database connections') + self._arg1 = Subquery(arg1) if isinstance(arg1, Projection) else arg1 + self._arg2 = Subquery(arg2) if isinstance(arg2, Projection) else arg2 + self._heading = self._arg1.heading.join(self._arg2.heading, aggregated=aggregated) + self.restrict(self._arg1.restrictions) + self.restrict(self._arg2.restrictions) + self._left = keep_all_rows def _repr_helper(self): return "(%r) * (%r)" % (self._arg1, self._arg2) @@ -470,25 +467,38 @@ def select_fields(self): return self.heading.as_sql +attribute_alias_parser = re.compile('^\s*(?P\S(.*\S)?)\s*->\s*(?P[a-z][a-z_0-9]*)\s*$') + + class Projection(RelationalOperand): - def __init__(self, arg, *attributes, **renamed_attributes): + + def __init__(self, arg, attributes=None, renamed_attributes=None, include_primary_key=True): """ See RelationalOperand.proj() """ + if attributes is None: + # copy constructor + assert 
class Projection(RelationalOperand):
-    def __init__(self, arg, *attributes, **renamed_attributes):
+
+    def __init__(self, arg, attributes=None, renamed_attributes=None, include_primary_key=True):
        """
        See RelationalOperand.proj()
        """
+        if attributes is None:
+            # copy constructor
+            assert isinstance(arg, Projection), 'Projection can only be copied from another projection.'
+            super().__init__(arg)  # copy restrictions
+            self._arg = arg._arg
+            self._renamed_attributes = arg._renamed_attributes  # ok not to copy
+            self._attributes = arg._attributes  # ok not to copy
+            self._include_primary_key = arg._include_primary_key
+            return
+
+        super().__init__()
        # parse attributes in the form 'sql_expression -> new_attribute'
-        alias_parser = re.compile(
-            '^\s*(?P<sql_expression>\S(.*\S)?)\s*->\s*(?P<alias>[a-z][a-z_0-9]*)\s*$')
        self._attributes = []
        self._renamed_attributes = renamed_attributes
+        self._include_primary_key = include_primary_key
        for attribute in attributes:
-            alias_match = alias_parser.match(attribute)
+            alias_match = attribute_alias_parser.match(attribute)
            if alias_match:
                d = alias_match.groupdict()
                self._renamed_attributes.update({d['alias']: d['sql_expression']})
            else:
                self._attributes.append(attribute)
        self._arg = arg
-
        restricting_on_removed_attributes = bool(
            arg.attributes_in_restrictions() - set(self.heading.names))
        use_subquery = restricting_on_removed_attributes or arg.heading.computed
@@ -506,11 +516,11 @@ def connection(self):

    @property
    def heading(self):
-        return self._arg.heading.proj(*self._attributes, **self._renamed_attributes)
+        return self._arg.heading.proj(
+            self._attributes, self._renamed_attributes, include_primary_key=self._include_primary_key)

-    @property
    def _grouped(self):
-        return self._arg._grouped
+        return self._arg._grouped()

    @property
    def from_clause(self):
@@ -519,27 +529,50 @@ def from_clause(self):

    def __and__(self, restriction):
        has_restriction = isinstance(restriction, RelationalOperand) or bool(restriction)
        do_subquery = has_restriction and self.heading.computed
-        ret = Subquery(self) if do_subquery else self
-        ret.restrict(restriction)
-        return ret
+        return (Subquery(self) if do_subquery else self).restrict(restriction)


class Aggregation(Projection):
-    @property
+
+    def __init__(self, arg, group=None, attributes=None, renamed_attributes=None, keep_all_rows=None):
+        """
+        See: RelationalOperand.aggregate
+        """
+        if group is None and isinstance(arg, Aggregation):
+            # copy constructor
+            super().__init__(arg)
+            self._left_arg = arg._left_arg
+            self._group = arg._group
+        else:
+            super().__init__(
+                Join(arg, group, aggregated=True, keep_all_rows=keep_all_rows),
+                attributes=attributes, renamed_attributes=renamed_attributes)
+            self._left_arg = arg
+            self._group = group
+
    def _grouped(self):
        return True

+    def _repr_helper(self):
+        return "(%r).aggregate(%r, %r)" % (self._arg, self._group, self._attributes)
+

class Subquery(RelationalOperand):
    """
    A Subquery encapsulates its argument in a SELECT statement, enabling its use as a subquery.
-    The attribute list and the WHERE clause are resolved.
-    As such, a subquery does not have any renamed attributes.
+    The attribute list and the WHERE clause are resolved. Thus, a subquery no longer has any renamed attributes.
+    A subquery of a subquery is just a copy of the subquery with no change in SQL.
""" __counter = 0 def __init__(self, arg): - self._arg = arg + if isinstance(arg, Subquery): + # copy constructor + super().__init__(arg) + self._arg = arg._arg + else: + super().__init__() + self._arg = arg @property def connection(self): @@ -552,7 +585,7 @@ def counter(self): @property def from_clause(self): - return '(' + self._arg.make_select() + ') as `_s%x`' % self.counter + return '(' + self._arg.make_sql() + ') as `_s%x`' % self.counter @property def select_fields(self): @@ -566,10 +599,94 @@ def _repr_helper(self): return "%r" % self._arg -def is_empty_or_list(arg): +class U(RelationalOperand): """ - returns true if the argument is equivalent to an empty OR list. + dj.U objects are special relations representing all possible values their attributes. + dj.U objects cannot be queried on their own but are useful for forming some relational queries. + dj.U('attr1', ..., 'attrn') represents a relation with the primary key attributes attr1 ... attrn. + The body of the relation is filled with all possible combinations of values of the attributes. + Without any attributes, dj.U() represents the relation with one tuple and no attributes. + The Third Manifesto refers to dj.U() as TABLE_DEE. + + Relational restriction: + dj.U can be used to enumerate unique combinations of values of attributes from other relations. + + The following expression produces a relation containing all unique combinations of contrast and brightness + found in relation stimulus: + dj.U('contrast', 'brightness') & stimulus + + The following expression produces a relation containing all unique combinations of contrast and brightness that is + contained in relation1 but not contained in relation 2. + (dj.U('contrast', 'brightness') & relation1) - relation2 + + Relational aggregation: + In aggregation, dj.U is used to compute aggregate expressions on the entire relation. + + The following expression produces a relation with one tuple and one attribute s containing the total number + of tuples in relation: + dj.U().aggregate(relation, n='count(*)') + + The following expression produces a relation with one tuple containing the number n of distinct values of attr + in relation. + dj.U().aggregate(relation, n='count(distinct attr)') + + The following expression produces a relation with one tuple and one attribute s containing the total sum of attr + from relation: + dj.U().aggregate(relation, s='sum(attr)') # sum of attr from the entire relation + + The following expression produces a relation with the count n of tuples in relation containing each unique + combination of values in attr1 and attr2. + dj.U(attr1,attr2).aggregate(relation, n='count(*)') + + Joins: + If relation rel has attributes 'attr1' and 'attr2', then rel*dj.U('attr1','attr2') or produces a relation that is + identical to rel except attr1 and attr2 are included in the primary key. This is useful for producing a join on + non-primary key attributes. + For example, if attr is in both rel1 and rel2 but not in their primary keys, then rel1*rel2 will throw an error + because in most cases, it does not make sense to join on non-primary key attributes and users must first rename + attr in one of the operands. The expression dj.U('attr')*rel1*rel2 overrides this constraint. + Join is commutative. 
""" - return not isinstance(arg, AndList) and ( - arg is None or - isinstance(arg, (list, set, tuple, np.ndarray, RelationalOperand)) and len(arg) == 0) + + def __init__(self, *primary_key): + super().__init__() + if len(primary_key) == 1 and isinstance(primary_key[0], U): + # copy constructor + self._primary_key = primary_key[0]._primary_key # ok not to copy + else: + # regular constructor + self._primary_key = primary_key + + # ----------- prohibited operations ------------- # + @property + def connection(self): + raise DataJointError('Relation U does not support this operation') + + @property + def from_clause(self): + raise DataJointError('Relation U does not support this operation') + + # ------------- overriden operations ---------------- # + + def _repr_helper(self): + return 'U(%s)' % (','.join(self.primary_key)) + + @property + def heading(self): + raise DataJointError('Relation U does not support this operation') + + @property + def primary_key(self): + return self._primary_key + + def restrict(self, relation): + if not isinstance(relation, RelationalOperand): + raise DataJointError('Relation U can only be restricted with another relation.') + return Projection(relation, attributes=self.primary_key, renamed_attributes=dict(), include_primary_key=False) + + def __mul__(self, relation): + if not isinstance(relation, RelationalOperand): + raise DataJointError('Relation U can only be joined with another relation.') + copy = relation.__class__(relation) + copy._heading = copy.heading.extend_primary_key(self.primary_key) + return copy diff --git a/datajoint/schema.py b/datajoint/schema.py index bbad08846..4ad7e1acb 100644 --- a/datajoint/schema.py +++ b/datajoint/schema.py @@ -1,10 +1,10 @@ -from operator import itemgetter import pymysql import logging import re -from . import conn, DataJointError +from . import conn, DataJointError, config from datajoint.utils import to_camel_case from .heading import Heading +from .utils import user_choice from .user_relations import Part, Computed, Imported, Manual, Lookup import inspect @@ -13,12 +13,10 @@ class Schema: """ - A schema object can be used as a decorator that associates a Relation class to a database as - well as a namespace for looking up foreign key references in table declaration. + A schema object is a decorator for UserRelation classes that binds them to their database. + It also specifies the namespace `context` in which other UserRelation classes are defined. """ - table2class = {} - def __init__(self, database, context, connection=None): """ Associates the specified database with this schema object. If the target database does not exist @@ -34,7 +32,7 @@ def __init__(self, database, context, connection=None): self.connection = connection self.context = context if not self.exists: - # create schema + # create database logger.info("Database `{database}` could not be found. " "Attempting to create the database.".format(database=database)) try: @@ -44,65 +42,64 @@ def __init__(self, database, context, connection=None): raise DataJointError("Database named `{database}` was not defined, and" " an attempt to create has failed. 
Check" " permissions.".format(database=database)) - else: - for table_name in map(itemgetter(0), - connection.query( - 'SHOW TABLES in {database}'.format(database=self.database)).fetchall()): - class_name, class_obj = self.create_userrelation_from_table(table_name) - - # decorate class with @schema if it is not None and not a dj.Part - context[class_name] = self(class_obj) \ - if class_obj is not None and not issubclass(class_obj, Part) else class_obj - - connection.register(self) - def create_userrelation_from_table(self, table_name): + def spawn_missing_classes(self): """ - Creates the appropriate python user relation classes from tables in a database. The tier of the class - is inferred from the table name. - - Schema stores the class objects in a class dictionary and returns those - when prompted for the same table from the same database again. This way, the id - of both returned class objects is the same and comparison with python's "is" works - correctly. + Creates the appropriate python user relation classes from tables in the database and places them + in the context. """ - class_name = to_camel_case(table_name) - - def _make_tuples(other, key): - raise NotImplementedError("This is an automatically created class. _make_tuples is not implemented.") - - if (self.database, table_name) in Schema.table2class: - class_name, class_obj = Schema.table2class[self.database, table_name] - else: - if re.fullmatch(Part._regexp, table_name): - groups = re.fullmatch(Part._regexp, table_name).groupdict() - master_table_name = groups['master'] - master_name, master_class = self.create_userrelation_from_table(master_table_name) - class_name = to_camel_case(groups['part']) - class_obj = type(class_name, (Part,), dict(definition=...)) - setattr(master_class, class_name, class_obj) - class_name, class_obj = master_name, master_class - - elif re.fullmatch(Computed._regexp, table_name): - class_obj = type(class_name, (Computed,), dict(definition=..., _make_tuples=_make_tuples)) - elif re.fullmatch(Imported._regexp, table_name): - class_obj = type(class_name, (Imported,), dict(definition=..., _make_tuples=_make_tuples)) - elif re.fullmatch(Lookup._regexp, table_name): - class_obj = type(class_name, (Lookup,), dict(definition=...)) - elif re.fullmatch(Manual._regexp, table_name): - class_obj = type(class_name, (Manual,), dict(definition=...)) + + def _make_tuples_stub(unused_self, unused_key): + raise NotImplementedError( + "This is an automatically created user relation class. 
_make_tuples is not implemented.") + + tables = [row[0] for row in self.connection.query('SHOW TABLES in `%s`' % self.database)] + + # declare master relation classes + master_classes = {} + part_tables = [] + for table_name in tables: + class_name = to_camel_case(table_name) + if class_name not in self.context: + try: + cls = next(cls for cls in (Lookup, Manual, Imported, Computed) + if re.fullmatch(cls.tier_regexp, table_name)) + except StopIteration: + if re.fullmatch(Part.tier_regexp, table_name): + part_tables.append(table_name) + else: + master_classes[table_name] = type(class_name, (cls,), + dict(definition=..., _make_tuples=_make_tuples_stub)) + # attach parts to masters + for part_table in part_tables: + groups = re.fullmatch(Part.tier_regexp, part_table).groupdict() + class_name = to_camel_case(groups['part']) + try: + master_class = master_classes[groups['master']] + except KeyError: + # if master not found among the spawned classes, check in the context + master_class = self.context[to_camel_case(groups['master'])] + if not hasattr(master_class, class_name): + part_class = type(class_name, (Part,), dict(definition=...)) + part_class._master = master_class + self.process_relation_class(part_class, context=self.context, assert_declared=True) + setattr(master_class, class_name, part_class) else: - class_obj = None + setattr(master_class, class_name, type(class_name, (Part,), dict(definition=...))) - Schema.table2class[self.database, table_name] = class_name, class_obj - return class_name, class_obj + # place classes in context upon decorating them with the schema + for cls in master_classes.values(): + self.context[cls.__name__] = self(cls) def drop(self): """ Drop the associated database if it exists """ - if self.exists: + if not self.exists: + logger.info("Database named `{database}` does not exist. Doing nothing.".format(database=self.database)) + elif (not config['safemode'] or + user_choice("Proceed to delete entire schema `%s`?" % self.database, default='no') == 'yes'): logger.info("Dropping `{database}`.".format(database=self.database)) try: self.connection.query("DROP DATABASE `{database}`".format(database=self.database)) @@ -110,8 +107,6 @@ def drop(self): except pymysql.OperationalError: raise DataJointError("An attempt to drop database named `{database}` " "has failed. Check permissions.".format(database=self.database)) - else: - logger.info("Database named `{database}` does not exist. Doing nothing.".format(database=self.database)) @property def exists(self): @@ -121,46 +116,49 @@ def exists(self): cur = self.connection.query("SHOW DATABASES LIKE '{database}'".format(database=self.database)) return cur.rowcount > 0 + def process_relation_class(self, relation_class, context, assert_declared=False): + """ + assign schema properties to the relation class and declare the table + """ + relation_class.database = self.database + relation_class._connection = self.connection + relation_class._heading = Heading() + relation_class._context = context + # instantiate the class, declare the table if not already, and fill it with initial values. + instance = relation_class() + if not instance.is_declared: + assert not assert_declared, 'incorrect table name generation' + instance.declare() + if hasattr(instance, 'contents'): + total = len(instance) + contents_keys = [dict(zip(instance.primary_key, c)) for c in instance.contents] + if total > len(instance & contents_keys) and 'yes' == user_choice( + '%s contains data that are no longer in its contents. 
' + 'Would you like to delete it?' % relation_class.__name__): + (instance - contents_keys).delete() + total = len(instance) + if len(instance.contents) > total: + instance.insert(instance.contents, skip_duplicates=True) + def __call__(self, cls): """ Binds the passed in class object to a database. This is intended to be used as a decorator. :param cls: class to be decorated """ - def process_relation_class(relation_class, context): - """ - assign schema properties to the relation class and declare the table - """ - relation_class.database = self.database - relation_class._connection = self.connection - relation_class._heading = Heading() - relation_class._context = context - # instantiate the class and declare the table in database if not already present - relation_class().declare() - if issubclass(cls, Part): raise DataJointError('The schema decorator should not be applied to Part relations') - process_relation_class(cls, context=self.context) + self.process_relation_class(cls, context=self.context) # Process part relations - def is_part(x): - return inspect.isclass(x) and issubclass(x, Part) - - parts = list() - for part in dir(cls): + for part in cls._ordered_class_members: if part[0].isupper(): part = getattr(cls, part) - if is_part(part): - parts.append(part) + if inspect.isclass(part) and issubclass(part, Part): part._master = cls - process_relation_class(part, context=dict(self.context, **{cls.__name__: cls})) - - # invoke Relation._prepare() on class and its part relations. - cls()._prepare() - for part in parts: - part()._prepare() - + # allow addressing master + self.process_relation_class(part, context=dict(self.context, **{cls.__name__: cls})) return cls @property diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 96f7dccc3..42af6a69c 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -2,22 +2,31 @@ Hosts the table tiers, user relations should be derived from. """ +import collections from .base_relation import BaseRelation from .autopopulate import AutoPopulate -from .utils import from_camel_case +from .utils import from_camel_case, ClassProperty +from . import DataJointError _base_regexp = r'[a-z]+[a-z0-9]*(_[a-z]+[a-z0-9]*)*' -class classproperty: - def __init__(self, f): - self.f = f +class OrderedClass(type): + """ + Class whose members are ordered + See https://docs.python.org/3/reference/datamodel.html#metaclass-example + """ + @classmethod + def __prepare__(metacls, name, bases, **kwds): + return collections.OrderedDict() - def __get__(self, obj, owner): - return self.f(owner) + def __new__(cls, name, bases, namespace, **kwds): + result = type.__new__(cls, name, bases, dict(namespace)) + result._ordered_class_members = tuple(namespace) + return result -class UserRelation(BaseRelation): +class UserRelation(BaseRelation, metaclass=OrderedClass): """ A subclass of UserRelation is a dedicated class interfacing a base relation. UserRelation is initialized by the decorator generated by schema(). @@ -25,15 +34,24 @@ class UserRelation(BaseRelation): _connection = None _context = None _heading = None - _regexp = None + tier_regexp = None _prefix = None - @classproperty + @ClassProperty def connection(cls): return cls._connection - @classproperty + @ClassProperty + def table_name(cls): + """ + :returns: the table name of the table formatted for mysql. 
+ """ + return cls._prefix + from_camel_case(cls.__name__) + + @ClassProperty def full_table_name(cls): + if cls.database is None: + raise DataJointError('Class %s is not properly declared (schema decorator not applied?)' % cls.__name__) return r"`{0:s}`.`{1:s}`".format(cls.database, cls.table_name) @@ -43,14 +61,7 @@ class Manual(UserRelation): """ _prefix = r'' - _regexp = r'(?P' + _prefix + _base_regexp + ')' - - @classproperty - def table_name(cls): - """ - :returns: the table name of the table formatted for SQL. - """ - return from_camel_case(cls.__name__) + tier_regexp = r'(?P' + _prefix + _base_regexp + ')' class Lookup(UserRelation): @@ -61,21 +72,7 @@ class Lookup(UserRelation): """ _prefix = '#' - _regexp = r'(?P' + _prefix + _base_regexp.replace('TIER', 'lookup') + ')' - - @classproperty - def table_name(cls): - """ - :returns: the table name of the table formatted for mysql. - """ - return cls._prefix + from_camel_case(cls.__name__) - - def _prepare(self): - """ - Checks whether the instance has a property called `contents` and inserts its elements. - """ - if hasattr(self, 'contents'): - self.insert(self.contents, skip_duplicates=True) + tier_regexp = r'(?P' + _prefix + _base_regexp.replace('TIER', 'lookup') + ')' class Imported(UserRelation, AutoPopulate): @@ -85,14 +82,7 @@ class Imported(UserRelation, AutoPopulate): """ _prefix = '_' - _regexp = r'(?P' + _prefix + _base_regexp + ')' - - @classproperty - def table_name(cls): - """ - :returns: the table name of the table formatted for mysql. - """ - return cls._prefix + from_camel_case(cls.__name__) + tier_regexp = r'(?P' + _prefix + _base_regexp + ')' class Computed(UserRelation, AutoPopulate): @@ -102,17 +92,10 @@ class Computed(UserRelation, AutoPopulate): """ _prefix = '__' - _regexp = r'(?P' + _prefix + _base_regexp + ')' - - @classproperty - def table_name(cls): - """ - :returns: the table name of the table formatted for SQL. - """ - return cls._prefix + from_camel_case(cls.__name__) + tier_regexp = r'(?P' + _prefix + _base_regexp + ')' -class Part(BaseRelation): +class Part(UserRelation): """ Inherit from this class if the table's values are details of an entry in another relation and if this table is populated by this relation. For example, the entries inheriting from @@ -120,16 +103,28 @@ class Part(BaseRelation): Part relations are implemented as classes inside classes. 
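A minimal sketch of the master-part pattern this class supports, mirroring `Ephys` and its `Channel` part from the test schema (definitions abridged). The part is declared as a nested class; the schema decorator is applied only to the master, since decorating a `Part` directly raises `DataJointError`, and the decorator binds nested part classes automatically:

    @schema
    class Ephys(dj.Imported):
        definition = ...  # abridged

        class Channel(dj.Part):
            definition = """
            -> Ephys
            channel : tinyint unsigned  # channel number within Ephys
            ---
            voltage : longblob
            """
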
""" - _regexp = r'(?P' + '|'.join( - [c._regexp for c in [Manual, Imported, Computed, Lookup]] + _connection = None + _context = None + _heading = None + _master = None + + tier_regexp = r'(?P' + '|'.join( + [c.tier_regexp for c in (Manual, Lookup, Imported, Computed)] ) + r'){1,1}' + '__' + r'(?P' + _base_regexp + ')' - _master = None + @ClassProperty + def connection(cls): + return cls._connection + + @ClassProperty + def full_table_name(cls): + return None if cls.database is None or cls.table_name is None else r"`{0:s}`.`{1:s}`".format( + cls.database, cls.table_name) - @classproperty + @ClassProperty def master(cls): return cls._master - @classproperty + @ClassProperty def table_name(cls): - return cls.master.table_name + '__' + from_camel_case(cls.__name__) + return None if cls.master is None else cls.master.table_name + '__' + from_camel_case(cls.__name__) \ No newline at end of file diff --git a/datajoint/utils.py b/datajoint/utils.py index dcf0cb5a7..e5beccec9 100644 --- a/datajoint/utils.py +++ b/datajoint/utils.py @@ -1,8 +1,15 @@ -import numpy as np import re from datajoint import DataJointError +class ClassProperty: + def __init__(self, f): + self.f = f + + def __get__(self, obj, owner): + return self.f(owner) + + def user_choice(prompt, choices=("yes", "no"), default=None): """ Prompts the user for confirmation. The default value, if any, is capitalized. @@ -21,20 +28,6 @@ def user_choice(prompt, choices=("yes", "no"), default=None): return response -# TODO: This will be removed after dj.All is implemented. See issue #112 -def group_by(rel, *attributes, sortby=None): - r = rel.project(*attributes).fetch() - dtype2 = np.dtype({name:r.dtype.fields[name] for name in attributes}) - r2 = np.unique(np.ndarray(r.shape, dtype2, r, 0, r.strides)) - r2.sort(order=sortby if sortby is not None else attributes) - for nk in r2: - restr = ' and '.join(["%s='%s'" % (fn, str(v)) for fn, v in zip(r2.dtype.names, nk)]) - if len(nk) == 1: - yield nk[0], rel & restr - else: - yield nk, rel & restr - - def to_camel_case(s): """ Convert names with under score (_) separation into camel case names. 
diff --git a/demos/demo1.py b/demos/demo1.py deleted file mode 100644 index d30894617..000000000 --- a/demos/demo1.py +++ /dev/null @@ -1,67 +0,0 @@ -# -*- coding: utf-8 -*- - -import datajoint as dj - -print("Welcome to the database 'demo1'") - -schema = dj.schema('dj_test', locals()) - -@schema -class Subject(dj.Manual): - definition = """ - # Basic subject info - subject_id : int # internal subject id - --- - real_id : varchar(40) # real-world name - species = "mouse" : enum('mouse', 'monkey', 'human') # species - date_of_birth=null : date # animal's date of birth - sex="unknown" : enum('M','F','unknown') # - caretaker="Unknown" : varchar(20) # person responsible for working with this subject - animal_notes="" : varchar(4096) # strain, genetic manipulations, etc - """ - -s = Subject() -p = s.primary_key - - -@schema -class Experiment(dj.Manual): - definition = """ - # Basic subject info - - -> Subject - experiment : smallint # experiment number for this subject - --- - experiment_folder : varchar(255) # folder path - experiment_date : date # experiment start date - experiment_notes="" : varchar(4096) - experiment_ts=CURRENT_TIMESTAMP : timestamp # automatic timestamp - """ - - -@schema -class Session(dj.Manual): - definition = """ - # a two-photon imaging session - - -> Experiment - session_id : tinyint # two-photon session within this experiment - ----------- - setup : tinyint # experimental setup - lens : tinyint # lens e.g.: 10x, 20x, 25x, 60x - """ - - -@schema -class Scan(dj.Manual): - definition = """ - # a two-photon imaging session - - -> Session - scan_id : tinyint # two-photon session within this experiment - ---- - depth : float # depth from surface - wavelength : smallint # (nm) laser wavelength - mwatts: numeric(4,1) # (mW) laser power to brain - """ - diff --git a/demos/rundemo1.py b/demos/rundemo1.py deleted file mode 100644 index 48a9db6e8..000000000 --- a/demos/rundemo1.py +++ /dev/null @@ -1,110 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Created on Thu Aug 28 00:46:11 2014 - -@author: dimitri -""" -import logging -import demo1 -from collections import namedtuple - -logging.basicConfig(level=logging.DEBUG) - -subject = demo1.Subject() -experiment = demo1.Experiment() -session = demo1.Session() -scan = demo1.Scan() - -scan.drop() -session.drop() -experiment.drop() -subject.drop() - -subject.insert(dict(subject_id=1, - real_id="George", - species="monkey", - date_of_birth="2011-01-01", - sex="M", - caretaker="Arthur", - animal_notes="this is a test")) - -subject.insert(dict(subject_id=2, - real_id='1373', - date_of_birth="2014-08-01", - caretaker="Joe")) - - -def tup(*arg): - return dict(zip(subject.heading.names, arg)) - -subject.insert(tup(3, 'Alice', 'monkey', '2012-09-01', 'M', 'Joe', '')) -subject.insert(tup(4, 'Dennis', 'monkey', '2012-09-01', 'M', 'Joe', '')) -subject.insert(tup(5, 'Warren', 'monkey', '2012-09-01', 'M', 'Joe', '')) -subject.insert(tup(6, 'Franky', 'monkey', '2012-09-01', 'M', 'Joe', '')) -subject.insert(tup(7, 'Simon', 'monkey', '2012-09-01', 'F', 'Joe', '')) -subject.insert(tup(8, 'Ferocious', 'monkey', '2012-09-01', 'M', 'Joe', '')) -subject.insert(tup(9, 'Simon', 'monkey', '2012-09-01', 'm', 'Joe', '')) -subject.insert(tup(10, 'Ferocious', 'monkey', '2012-09-01', 'F', 'Joe', '')) -subject.insert(tup(11, 'Simon', 'monkey', '2012-09-01', 'm', 'Joe', '')) -subject.insert(tup(12, 'Ferocious', 'monkey', '2012-09-01', 'M', 'Joe', '')) -subject.insert(tup(13, 'Dauntless', 'monkey', '2012-09-01', 'F', 'Joe', '')) -subject.insert(tup(14, 'Dawn', 'monkey', 
'2012-09-01', 'F', 'Joe', '')) -subject.insert(tup(12430, 'C0430', 'mouse', '2012-09-01', 'M', 'Joe', '')) -subject.insert(tup(12431, 'C0431', 'mouse', '2012-09-01', 'F', 'Joe', '')) - -(subject & 'subject_id=1').fetch1() -print(subject) -print(subject.project()) -print(subject.project(name='real_id', dob='date_of_birth', sex='sex') & 'sex="M"') - -(subject & dict(subject_id=12431)).delete() -print(subject) - -experiment.insert(dict( - subject_id=1, - experiment=1, - experiment_date="2014-08-28", - experiment_notes="my first experiment")) - -experiment.insert(dict( - subject_id=1, - experiment=2, - experiment_date="2014-08-28", - experiment_notes="my second experiment")) - -experiment.insert(dict( - subject_id=2, - experiment=1, - experiment_date="2015-05-01" -)) - -print(experiment) -print(experiment * subject) -print(subject & experiment) -print(subject - experiment) - -session.insert(dict( - subject_id=1, - experiment=2, - session_id=1, - setup=0, - lens="20x" -)) - -scan.insert(dict( - subject_id=1, - experiment=2, - session_id=1, - scan_id=1, - depth=250, - wavelength=980, - mwatts=30.5 -)) - -print((scan * experiment) % ('wavelength->lambda', 'experiment_date')) - -# cleanup -scan.drop() -session.drop() -experiment.drop() -subject.drop() \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py index b875cbfb6..c3f68581a 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -18,7 +18,6 @@ __all__ = ['__author__', 'PREFIX', 'CONN_INFO'] - # Connection for testing CONN_INFO = dict( host=environ.get('DJ_TEST_HOST', 'localhost'), diff --git a/tests/schema.py b/tests/schema.py index 494e4066a..79b497f52 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -3,7 +3,6 @@ """ import random - import numpy as np import datajoint as dj from . 
import PREFIX, CONN_INFO @@ -50,20 +49,15 @@ class Subject(dj.Manual): [1552, '1552', 'mouse', '2015-06-15', ''], [1553, '1553', 'mouse', '2016-07-01', '']] - def _prepare(self): - self.insert(self.contents, ignore_errors=True) - @schema class Language(dj.Lookup): definition = """ # languages spoken by some of the developers - name : varchar(40) # name of the developer language : varchar(40) # language --- """ - contents = [ ('Fabian', 'English'), ('Edgar', 'English'), @@ -95,12 +89,12 @@ def _make_tuples(self, key): from datetime import date, timedelta users = User().fetch()['username'] random.seed('Amazing Seed') - for experiment_id in range(self.fake_experiments_per_subject): - self.insert1( - dict(key, - experiment_id=experiment_id, - experiment_date=(date.today() - timedelta(random.expovariate(1 / 30))).isoformat(), - username=random.choice(users))) + self.insert( + dict(key, + experiment_id=experiment_id, + experiment_date=(date.today() - timedelta(random.expovariate(1 / 30))).isoformat(), + username=random.choice(users)) + for experiment_id in range(self.fake_experiments_per_subject)) @schema @@ -117,12 +111,11 @@ def _make_tuples(self, key): populate with random data (pretend reading from raw files) """ random.seed('Amazing Seed') - for trial_id in range(10): - self.insert1( - dict(key, - trial_id=trial_id, - start_time=random.random() * 1e9 - )) + self.insert( + dict(key, + trial_id=trial_id, + start_time=random.random() * 1e9) + for trial_id in range(10)) @schema @@ -131,7 +124,7 @@ class Ephys(dj.Imported): -> Trial ---- sampling_frequency :double # (Hz) - duration :double # (s) + duration :decimal(7,3) # (s) """ class Channel(dj.Part): @@ -139,7 +132,8 @@ class Channel(dj.Part): -> Ephys channel :tinyint unsigned # channel number within Ephys ---- - voltage :longblob + voltage : longblob + current = null : longblob # optional current to test null handling """ def _make_tuples(self, key): @@ -151,13 +145,13 @@ def _make_tuples(self, key): sampling_frequency=6000, duration=np.minimum(2, random.expovariate(1))) self.insert1(row) - number_samples = round(row['duration'] * row['sampling_frequency']) + number_samples = int(row['duration'] * row['sampling_frequency']+0.5) sub = self.Channel() - for channel in range(2): - sub.insert1( - dict(key, - channel=channel, - voltage=np.float32(np.random.randn(number_samples)))) + sub.insert( + dict(key, + channel=channel, + voltage=np.float32(np.random.randn(number_samples))) + for channel in range(2)) @schema @@ -176,6 +170,7 @@ class UberTrash(dj.Manual): id : int --- """ + contents = [(1,)] @schema @@ -185,8 +180,4 @@ class UnterTrash(dj.Manual): my_id : int --- """ - - def _prepare(self): - UberTrash().insert1((1,), skip_duplicates=True) - self.insert1((1, 1), skip_duplicates=True) - self.insert1((1, 2), skip_duplicates=True) + contents = [(1, 1), (1, 2)] \ No newline at end of file diff --git a/tests/schema_advanced.py b/tests/schema_advanced.py new file mode 100644 index 000000000..7a1e18207 --- /dev/null +++ b/tests/schema_advanced.py @@ -0,0 +1,62 @@ +import datajoint as dj +from . import PREFIX, CONN_INFO + +schema = dj.schema(PREFIX + '_advanced', locals(), connection=dj.conn(**CONN_INFO)) + + +@schema +class Person(dj.Manual): + definition = """ + person_id : int + ---- + full_name : varchar(60) + sex : enum('M','F') + """ + + def fill(self): + """ + fill fake names from www.fakenamegenerator.com + """ + self.insert(( + (0, "May K. Hall", "F"), + (1, "Jeffrey E. Gillen", "M"), + (2, "Hanna R. Walters", "F"), + (3, "Russel S. 
James", "M"), + (4, "Robbin J. Fletcher", "F"), + (5, "Wade J. Sullivan", "M"), + (6, "Dorothy J. Chen", "F"), + (7, "Michael L. Kowalewski", "M"), + (8, "Kimberly J. Stringer", "F"), + (9, "Mark G. Hair", "M"), + (10, "Mary R. Thompson", "F"), + (11, "Graham C. Gilpin", "M"), + (12, "Nelda T. Ruggeri", "F"), + (13, "Bryan M. Cummings", "M"), + (14, "Sara C. Le", "F"), + (15, "Myron S. Jaramillo", "M") + )) + + +@schema +class Parent(dj.Manual): + definition = """ + -> Person + parent_sex : enum('M','F') + --- + parent -> Person + """ + + def fill(self): + + def make_parent(pid, parent): + return dict(person_id=pid, + parent=parent, + parent_sex=(Person() & dict(person_id=parent)).fetch['sex'][0]) + + self.insert(make_parent(*r) for r in ( + (0, 2), (0, 3), (1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 7), (4, 7), + (4, 8), (5, 9), (5, 10), (6, 9), (6, 10), (7, 11), (7, 12), (8, 11), (8, 14), + (9, 11), (9, 12), (10, 13), (10, 14), (11, 14), (11, 15), (12, 14), (12, 15))) + + + diff --git a/tests/schema_empty.py b/tests/schema_empty.py index ebe641160..31b923692 100644 --- a/tests/schema_empty.py +++ b/tests/schema_empty.py @@ -8,3 +8,10 @@ schema = dj.schema(PREFIX + '_test1', locals(), connection=dj.conn(**CONN_INFO)) + +@schema +class Ephys(dj.Imported): + definition = """ # This is already declare in ./schema.py + """ + +schema.spawn_missing_classes() # load the rest of the classes diff --git a/tests/schema_simple.py b/tests/schema_simple.py index e5478cfe9..de01badb1 100644 --- a/tests/schema_simple.py +++ b/tests/schema_simple.py @@ -47,7 +47,7 @@ def _make_tuples(self, key): sigma = random.lognormvariate(0, 4) n = random.randint(0, 10) self.insert1(dict(key, mu=mu, sigma=sigma, n=n)) - sub.insert((dict(key, id_c=j, value=random.normalvariate(mu, sigma)) for j in range(n))) + sub.insert(dict(key, id_c=j, value=random.normalvariate(mu, sigma)) for j in range(n)) @schema @@ -73,8 +73,7 @@ def _make_tuples(self, key): # make reference to a random tuple from L random.seed(str(key)) lookup = list(L().fetch.keys()) - for i in range(4): - self.insert1(dict(key, id_d=i, **random.choice(lookup))) + self.insert(dict(key, id_d=i, **random.choice(lookup)) for i in range(4)) @schema @@ -100,9 +99,7 @@ def _make_tuples(self, key): sub = E.F() references = list((B.C() & key).fetch.keys()) random.shuffle(references) - for i, ref in enumerate(references): - if random.getrandbits(1): - sub.insert1(dict(key, id_f=i, **ref)) + sub.insert(dict(key, id_f=i, **ref) for i, ref in enumerate(references) if random.getrandbits(1)) @schema @@ -112,10 +109,7 @@ class DataA(dj.Lookup): --- a : int """ - - @property - def contents(self): - yield from zip(range(5), range(5)) + contents = list(zip(range(5), range(5))) @schema @@ -125,7 +119,4 @@ class DataB(dj.Lookup): --- a : int """ - - @property - def contents(self): - yield from zip(range(5), range(5, 10)) + contents = list(zip(range(5), range(5, 10))) diff --git a/tests/test_autopopulate.py b/tests/test_autopopulate.py index 0c5a30dc3..83da52e24 100644 --- a/tests/test_autopopulate.py +++ b/tests/test_autopopulate.py @@ -34,12 +34,14 @@ def test_populate(self): # test restricted populate assert_false(self.trial, 'table already filled?') - restriction = dict(subject_id=self.subject.project().fetch()['subject_id'][0]) + restriction = dict(subject_id=self.subject.proj().fetch()['subject_id'][0]) + d = self.trial.connection.dependencies + d.load() self.trial.populate(restriction) assert_true(self.trial, 'table was not populated') - poprel = self.trial.populated_from - 
assert_equal(len(poprel & self.trial), len(poprel & restriction)) - assert_equal(len(poprel - self.trial), len(poprel - restriction)) + key_source = self.trial.key_source + assert_equal(len(key_source & self.trial), len(key_source & restriction)) + assert_equal(len(key_source - self.trial), len(key_source - restriction)) # test subtable populate assert_false(self.ephys) diff --git a/tests/test_cascading_delete.py b/tests/test_cascading_delete.py index 917748cec..2ff31f512 100644 --- a/tests/test_cascading_delete.py +++ b/tests/test_cascading_delete.py @@ -10,8 +10,8 @@ def setup(): """ class-level test setup. Executes before each test method. """ - A()._prepare() - L()._prepare() + A().insert(A.contents, skip_duplicates=True) + L().insert(L.contents, skip_duplicates=True) B().populate() D().populate() E().populate() @@ -70,7 +70,11 @@ def test_delete_lookup(): assert_false(dj.config['safemode'], 'safemode must be off for testing') assert_true(bool(L() and A() and B() and B.C() and D() and E() and E.F()), 'schema is not populated') - L().delete() + try: + L().delete() + except Exception as e: + raise + assert_false(bool(L() or D() or E() or E.F()), 'incomplete delete') A().delete() # delete all is necessary because delete L deletes from subtables. diff --git a/tests/test_declare.py b/tests/test_declare.py index 57d475efa..df6519e75 100644 --- a/tests/test_declare.py +++ b/tests/test_declare.py @@ -1,4 +1,3 @@ -import warnings from nose.tools import assert_true, assert_false, assert_equal, assert_list_equal, raises from . import schema import datajoint as dj @@ -13,6 +12,13 @@ class TestDeclare: + + @staticmethod + def test_schema_decorator(): + assert_true(issubclass(schema.Subject, dj.BaseRelation)) + assert_true(issubclass(schema.Subject, dj.Manual)) + assert_true(not issubclass(schema.Subject, dj.Part)) + @staticmethod def test_attributes(): # test autoincrement declaration @@ -45,25 +51,24 @@ def test_attributes(): ['subject_id', 'experiment_id', 'trial_id']) assert_list_equal(channel.heading.names, - ['subject_id', 'experiment_id', 'trial_id', 'channel', 'voltage']) + ['subject_id', 'experiment_id', 'trial_id', 'channel', 'voltage', 'current']) assert_list_equal(channel.primary_key, ['subject_id', 'experiment_id', 'trial_id', 'channel']) assert_true(channel.heading.attributes['voltage'].is_blob) def test_dependencies(self): - assert_equal(user.references, [experiment.full_table_name]) - assert_equal(experiment.referenced, [user.full_table_name]) - - assert_equal(subject.children, [experiment.full_table_name]) - assert_equal(experiment.parents, [subject.full_table_name]) + assert_equal(user.children(primary=False), [experiment.full_table_name]) + assert_equal(experiment.parents(primary=False), [user.full_table_name]) - assert_equal(experiment.children, [trial.full_table_name]) - assert_equal(trial.parents, [experiment.full_table_name]) + assert_equal(subject.children(primary=True), [experiment.full_table_name]) + assert_equal(experiment.parents(primary=True), [subject.full_table_name]) - assert_equal(trial.children, [ephys.full_table_name]) - assert_equal(ephys.parents, [trial.full_table_name]) + assert_equal(experiment.children(primary=True), [trial.full_table_name]) + assert_equal(trial.parents(primary=True), [experiment.full_table_name]) - assert_equal(ephys.children, [channel.full_table_name]) - assert_equal(channel.parents, [ephys.full_table_name]) + assert_equal(trial.children(primary=True), [ephys.full_table_name]) + assert_equal(ephys.parents(primary=True), 
[trial.full_table_name]) + assert_equal(ephys.children(primary=True), [channel.full_table_name]) + assert_equal(channel.parents(primary=True), [ephys.full_table_name]) diff --git a/tests/test_erd.py b/tests/test_erd.py new file mode 100644 index 000000000..0cd21892d --- /dev/null +++ b/tests/test_erd.py @@ -0,0 +1,49 @@ +from nose.tools import assert_false, assert_true +import datajoint as dj +from .schema_simple import A, B, D, E, L, schema + + +class TestERD: + + @staticmethod + def setup(): + """ + class-level test setup. Executes before each test method. + """ + + @staticmethod + def test_decorator(): + assert_true(issubclass(A, dj.BaseRelation)) + assert_false(issubclass(A, dj.Part)) + assert_true(B.database == schema.database) + assert_true(issubclass(B.C, dj.Part)) + assert_true(B.C.database == schema.database) + assert_true(B.C.master is B and E.F.master is E) + + @staticmethod + def test_dependencies(): + deps = schema.connection.dependencies + assert_true(all(cls.full_table_name in deps for cls in (A, B, B.C, D, E, E.F, L))) + assert_true(set(A().children()) == set([B.full_table_name, D.full_table_name])) + assert_true(set(D().parents(primary=True)) == set([A.full_table_name])) + assert_true(set(D().parents(primary=False)) == set([L.full_table_name])) + assert_true(set(deps.descendants(L.full_table_name)).issubset(cls.full_table_name for cls in (L, D, E, E.F))) + + @staticmethod + def test_erd(): + erd = dj.ERD(schema) + graph = erd._make_graph(prefix_module=False) + assert_true(set(cls.__name__ for cls in (A, B, D, E, L)).issubset(graph.nodes())) + pos = erd._layout(graph) + assert_true(set(cls.__name__ for cls in (A, B, D, E, L)).issubset(pos.keys())) + + @staticmethod + def test_erd_algebra(): + erd0 = dj.ERD(B) + erd1 = erd0 + 3 + erd2 = dj.ERD(E) - 3 + erd3 = erd1 * erd2 + assert_true(erd0.nodes_to_show == set(cls.full_table_name for cls in [B])) + assert_true(erd1.nodes_to_show == set(cls.full_table_name for cls in (B, B.C, E, E.F))) + assert_true(erd2.nodes_to_show == set(cls.full_table_name for cls in (A, B, D, E, L))) + assert_true(erd3.nodes_to_show == set(cls.full_table_name for cls in (B, E))) diff --git a/tests/test_fetch.py b/tests/test_fetch.py index e4757998c..7b3ced61a 100644 --- a/tests/test_fetch.py +++ b/tests/test_fetch.py @@ -11,108 +11,99 @@ class TestFetch: def __init__(self): self.subject = schema.Subject() - self.lang = schema.Language() def test_getitem(self): """Testing Fetch.__getitem__""" - - np.testing.assert_array_equal(sorted(self.subject.project().fetch(), key=itemgetter(0)), + np.testing.assert_array_equal(sorted(self.subject.proj().fetch(), key=itemgetter(0)), sorted(self.subject.fetch[dj.key], key=itemgetter(0)), 'Primary key is not returned correctly') tmp = self.subject.fetch(order_by=['subject_id']) - for column, field in zip(self.subject.fetch[:], [e[0] for e in tmp.dtype.descr]): - np.testing.assert_array_equal(sorted(tmp[field]), sorted(column), 'slice : does not work correctly') - subject_notes, key, real_id = self.subject.fetch['subject_notes', dj.key, 'real_id'] - # + np.testing.assert_array_equal(sorted(subject_notes), sorted(tmp['subject_notes'])) np.testing.assert_array_equal(sorted(real_id), sorted(tmp['real_id'])) np.testing.assert_array_equal(sorted(key, key=itemgetter(0)), - sorted(self.subject.project().fetch(), key=itemgetter(0))) - - for column, field in zip(self.subject.fetch['subject_id'::2], [e[0] for e in tmp.dtype.descr][::2]): - np.testing.assert_array_equal(sorted(tmp[field]), sorted(column), 'slice : does not work 
correctly') + sorted(self.subject.proj().fetch(), key=itemgetter(0))) def test_getitem_for_fetch1(self): """Testing Fetch1.__getitem__""" assert_true((self.subject & "subject_id=10").fetch1['subject_id'] == 10) assert_equal((self.subject & "subject_id=10").fetch1['subject_id', 'species'], (10, 'monkey')) - assert_equal((self.subject & "subject_id=10").fetch1['subject_id':'species'], - (10, 'Curious George')) def test_order_by(self): """Tests order_by sorting order""" - langs = schema.Language.contents + languages = schema.Language.contents for ord_name, ord_lang in itertools.product(*2 * [['ASC', 'DESC']]): cur = self.lang.fetch.order_by('name ' + ord_name, 'language ' + ord_lang)() - langs.sort(key=itemgetter(1), reverse=ord_lang == 'DESC') - langs.sort(key=itemgetter(0), reverse=ord_name == 'DESC') - for c, l in zip(cur, langs): + languages.sort(key=itemgetter(1), reverse=ord_lang == 'DESC') + languages.sort(key=itemgetter(0), reverse=ord_name == 'DESC') + for c, l in zip(cur, languages): assert_true(np.all(cc == ll for cc, ll in zip(c, l)), 'Sorting order is different') def test_order_by_default(self): """Tests order_by sorting order with defaults""" - langs = schema.Language.contents - + languages = schema.Language.contents cur = self.lang.fetch.order_by('language', 'name DESC')() - langs.sort(key=itemgetter(0), reverse=True) - langs.sort(key=itemgetter(1), reverse=False) + languages.sort(key=itemgetter(0), reverse=True) + languages.sort(key=itemgetter(1), reverse=False) - for c, l in zip(cur, langs): + for c, l in zip(cur, languages): assert_true(np.all([cc == ll for cc, ll in zip(c, l)]), 'Sorting order is different') def test_order_by_direct(self): """Tests order_by sorting order passing it to __call__""" - langs = schema.Language.contents - + languages = schema.Language.contents cur = self.lang.fetch(order_by=['language', 'name DESC']) - langs.sort(key=itemgetter(0), reverse=True) - langs.sort(key=itemgetter(1), reverse=False) - for c, l in zip(cur, langs): + languages.sort(key=itemgetter(0), reverse=True) + languages.sort(key=itemgetter(1), reverse=False) + for c, l in zip(cur, languages): assert_true(np.all([cc == ll for cc, ll in zip(c, l)]), 'Sorting order is different') def test_limit(self): """Test the limit function """ - langs = schema.Language.contents + languages = schema.Language.contents cur = self.lang.fetch.limit(4)(order_by=['language', 'name DESC']) - langs.sort(key=itemgetter(0), reverse=True) - langs.sort(key=itemgetter(1), reverse=False) + languages.sort(key=itemgetter(0), reverse=True) + languages.sort(key=itemgetter(1), reverse=False) assert_equal(len(cur), 4, 'Length is not correct') - for c, l in list(zip(cur, langs))[:4]: + for c, l in list(zip(cur, languages))[:4]: assert_true(np.all([cc == ll for cc, ll in zip(c, l)]), 'Sorting order is different') def test_limit_offset(self): """Test the limit and offset functions together""" - langs = schema.Language.contents + languages = schema.Language.contents cur = self.lang.fetch(offset=2, limit=4, order_by=['language', 'name DESC']) - langs.sort(key=itemgetter(0), reverse=True) - langs.sort(key=itemgetter(1), reverse=False) + languages.sort(key=itemgetter(0), reverse=True) + languages.sort(key=itemgetter(1), reverse=False) assert_equal(len(cur), 4, 'Length is not correct') - for c, l in list(zip(cur, langs[2:6])): + for c, l in list(zip(cur, languages[2:6])): assert_true(np.all([cc == ll for cc, ll in zip(c, l)]), 'Sorting order is different') def test_iter(self): """Test iterator""" - langs = 
schema.Language.contents - + languages = schema.Language.contents cur = self.lang.fetch.order_by('language', 'name DESC') - langs.sort(key=itemgetter(0), reverse=True) - langs.sort(key=itemgetter(1), reverse=False) - for (name, lang), (tname, tlang) in list(zip(cur, langs)): + languages.sort(key=itemgetter(0), reverse=True) + languages.sort(key=itemgetter(1), reverse=False) + for (name, lang), (tname, tlang) in list(zip(cur, languages)): assert_true(name == tname and lang == tlang, 'Values are not the same') + # now as dict + cur = self.lang.fetch.as_dict.order_by('language', 'name DESC') + for row, (tname, tlang) in list(zip(cur, languages)): + assert_true(row['name'] == tname and row['language'] == tlang, 'Values are not the same') def test_keys(self): """test key iterator""" - langs = schema.Language.contents - langs.sort(key=itemgetter(0), reverse=True) - langs.sort(key=itemgetter(1), reverse=False) + languages = schema.Language.contents + languages.sort(key=itemgetter(0), reverse=True) + languages.sort(key=itemgetter(1), reverse=False) cur = self.lang.fetch.order_by('language', 'name DESC')['name', 'language'] cur2 = list(self.lang.fetch.order_by('language', 'name DESC').keys()) @@ -120,22 +111,14 @@ def test_keys(self): for c, c2 in zip(zip(*cur), cur2): assert_true(c == tuple(c2.values()), 'Values are not the same') - def test_fetch1(self): + def test_fetch1_step1(self): key = {'name': 'Edgar', 'language': 'Japanese'} true = schema.Language.contents[-1] - dat = (self.lang & key).fetch1() for k, (ke, c) in zip(true, dat.items()): assert_true(k == c == (self.lang & key).fetch1[ke], 'Values are not the same') - def test_copy(self): - """Test whether modifications copy the object""" - f = self.lang.fetch - f2 = f.order_by('name') - assert_true(f.behavior['order_by'] is None and len(f2.behavior['order_by']) == 1, - 'Object was not copied') - def test_repr(self): """Test string representation of fetch, returning table preview""" repr = self.subject.fetch.__repr__() @@ -164,11 +147,11 @@ def test_asdict_with_call(self): def test_offset(self): """Tests offset""" cur = self.lang.fetch.limit(4).offset(1)(order_by=['language', 'name DESC']) - langs = self.lang.contents - langs.sort(key=itemgetter(0), reverse=True) - langs.sort(key=itemgetter(1), reverse=False) + languages = self.lang.contents + languages.sort(key=itemgetter(0), reverse=True) + languages.sort(key=itemgetter(1), reverse=False) assert_equal(len(cur), 4, 'Length is not correct') - for c, l in list(zip(cur, langs[1:]))[:4]: + for c, l in list(zip(cur, languages[1:]))[:4]: assert_true(np.all([cc == ll for cc, ll in zip(c, l)]), 'Sorting order is different') def test_limit_warning(self): @@ -179,9 +162,14 @@ def test_limit_warning(self): def test_len(self): """Tests __len__""" - assert_true(len(self.lang.fetch) == len(self.lang),'__len__ is not behaving properly') + assert_true(len(self.lang.fetch) == len(self.lang), '__len__ is not behaving properly') + + @raises(dj.DataJointError) + def test_fetch1_step2(self): + """Tests whether fetch1 raises error""" + self.lang.fetch1() @raises(dj.DataJointError) - def test_fetch1(self): + def test_fetch1_step3(self): """Tests whether fetch1 raises error""" - self.lang.fetch1() \ No newline at end of file + self.lang.fetch1['name'] diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py new file mode 100644 index 000000000..9a9459449 --- /dev/null +++ b/tests/test_foreign_keys.py @@ -0,0 +1,13 @@ +from nose.tools import assert_equal + +from . 
import schema_advanced + + +def test_aliased_fk(): + person = schema_advanced.Person() + parent = schema_advanced.Parent() + person.fill() + parent.fill() + parents = person*parent*person.proj(parent_name='full_name', parent='person_id') + parents &= dict(full_name="May K. Hall") + assert_equal(set(parents.fetch['parent_name']), {'Hanna R. Walters', 'Russel S. James'}) diff --git a/tests/test_relation.py b/tests/test_relation.py index b5aa7391e..d2db8258b 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -1,12 +1,8 @@ from inspect import getmembers import re -import itertools -from numpy.testing import assert_array_equal import numpy as np -from nose.tools import assert_raises, assert_equal, assert_not_equal, \ - assert_false, assert_true, assert_list_equal, \ - assert_tuple_equal, assert_dict_equal, raises +from nose.tools import assert_equal, assert_not_equal, assert_true, assert_list_equal, raises from . import schema from pymysql import IntegrityError, ProgrammingError import datajoint as dj @@ -142,9 +138,15 @@ def test_table_regexp(self): """Test whether table names are matched by regular expressions""" tiers = [dj.Imported, dj.Manual, dj.Lookup, dj.Computed] for name, rel in getmembers(schema, relation_selector): - assert_true(re.match(rel._regexp, rel().table_name), + assert_true(re.match(rel.tier_regexp, rel.table_name), 'Regular expression does not match for {name}'.format(name=name)) + for tier in tiers: + assert_true(issubclass(rel, tier) or not re.match(tier.tier_regexp, rel.table_name), + 'Regular expression matches for {name} but should not'.format(name=name)) - for tier in itertools.filterfalse(lambda t: issubclass(rel, t), tiers): - assert_false(re.match(tier._regexp, rel().table_name), - 'Regular expression matches for {name} but should not'.format(name=name)) + def test_table_size(self): + """test getting the size of the table and its indices in bytes""" + assert_true(self.experiment.size_on_disk > 100) + + def test_repr_html(self): + assert_true(self.ephys._repr_html_().strip().startswith("

")) diff --git a/tests/test_relation_u.py b/tests/test_relation_u.py new file mode 100644 index 000000000..25a569862 --- /dev/null +++ b/tests/test_relation_u.py @@ -0,0 +1,45 @@ +from nose.tools import assert_equal, assert_true, raises, assert_list_equal +from . import schema +import datajoint as dj + + +class TestU: + """ + Test base relations: insert, delete + """ + + def __init__(self): + self.user = schema.User() + self.language = schema.Language() + self.subject = schema.Subject() + self.experiment = schema.Experiment() + self.trial = schema.Trial() + self.ephys = schema.Ephys() + self.channel = schema.Ephys.Channel() + self.img = schema.Image() + self.trash = schema.UberTrash() + + def test_restriction(self): + rel = dj.U('language') & self.language + assert_list_equal(rel.heading.names, ['language']) + languages = rel.fetch() + + def test_ineffective_restriction(self): + rel = self.language & dj.U('language') + assert_true(rel.make_sql() == self.language.make_sql()) + + def test_join(self): + rel = self.experiment*dj.U('experiment_date') + assert_equal(self.experiment.primary_key, ['subject_id', 'experiment_id']) + assert_equal(rel.primary_key, self.experiment.primary_key + ['experiment_date']) + + rel = dj.U('experiment_date')*self.experiment + assert_equal(self.experiment.primary_key, ['subject_id', 'experiment_id']) + assert_equal(rel.primary_key, self.experiment.primary_key + ['experiment_date']) + + @raises(dj.DataJointError) + def test_invalid_join(self): + rel = dj.U('language') * dict(language="English") + + # def test_aggregations(self): + # rel = dj.U('language').aggregate(n='count(*)') \ No newline at end of file diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index 928f05ba1..96c8bb39e 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -4,15 +4,15 @@ assert_tuple_equal, assert_dict_equal, raises import datajoint as dj from .schema_simple import A, B, D, E, L, DataA, DataB -import datetime from .schema import Experiment + def setup(): """ module-level test setup """ - A()._prepare() - L()._prepare() + A().insert(A.contents, skip_duplicates=True) + L().insert(L.contents, skip_duplicates=True) B().populate() D().populate() E().populate() @@ -69,7 +69,7 @@ def test_join(): 'incorrect join primary_key') # test renamed join - x = B().project(i='id_a') # rename the common attribute to achieve full cartesian product + x = B().proj(i='id_a') # rename the common attribute to achieve full cartesian product y = D() rel = x*y assert_equal(len(rel), len(x)*len(y), @@ -79,8 +79,8 @@ def test_join(): assert_equal(set(x.primary_key).union(y.primary_key), set(rel.primary_key), 'incorrect join primary_key') - # test the % notation - x = B() % ['id_a->a'] + # test the -> notation + x = B().proj('id_a->a') y = D() rel = x*y assert_equal(len(rel), len(x)*len(y), @@ -106,10 +106,10 @@ def test_join(): @staticmethod def test_project(): - x = A().project(a='id_a') # rename + x = A().proj(a='id_a') # rename assert_equal(x.heading.names, ['a'], 'renaming does not work') - x = A().project(a='(id_a)') # extend + x = A().proj(a='(id_a)') # extend assert_equal(set(x.heading.names), set(('id_a', 'a')), 'extend does not work') @@ -117,12 +117,12 @@ def test_project(): cond = L() & 'cond_in_l' assert_equal(len(D() & cond) + len(D() - cond), len(D()), 'failed semijoin or antijoin') - assert_equal(len((D() & cond).project()), len((D() & cond)), + assert_equal(len((D() & cond).proj()), len((D() & cond)), 'projection failed: altered its 
argument''s cardinality') @staticmethod def test_aggregate(): - x = B().aggregate(B.C(), 'n', count='count(id_c)', mean='avg(value)', max='max(value)') + x = B().aggregate(B.C(), 'n', count='count(id_c)', mean='avg(value)', max='max(value)', keep_all_rows=True) assert_equal(len(x), len(B())) for n, count, mean, max_, key in zip(*x.fetch['n', 'count', 'mean', 'max', dj.key]): assert_equal(n, count, 'aggregation failed (count)') @@ -186,7 +186,7 @@ def test_datetime(): @staticmethod def test_join_project_optimization(): """Test optimization for join of projected relations with matching non-primary key""" - print(DataA().project() * DataB().proj()) + print(DataA().proj() * DataB().proj()) print(DataA()) - assert_true(len(DataA().project() * DataB().proj()) == len(DataA()) == len(DataB()), + assert_true(len(DataA().proj() * DataB().proj()) == len(DataA()) == len(DataB()), "Join of projected relations does not work") diff --git a/tests/test_schema.py b/tests/test_schema.py index 9e25cf559..786b75b81 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -5,6 +5,7 @@ from inspect import getmembers from . import schema from . import schema_empty +from . import PREFIX, CONN_INFO def relation_selector(attr): @@ -30,3 +31,47 @@ def test_namespace_population(): assert_true(hasattr(rel, name_part), '{name_part} not found in {name}'.format(name_part=name_part, name=name)) assert_true(getattr(rel, name_part).__base__ is dj.Part, 'Wrong tier for {name}'.format(name=name_part)) + + +@raises(dj.DataJointError) +def test_undecorated_table(): + """ + Undecorated user relation classes should raise an informative exception upon first use + """ + + class UndecoratedClass(dj.Manual): + definition = "" + + a = UndecoratedClass() + a.full_table_name + + +@raises(dj.DataJointError) +def test_reject_decorated_part(): + """ + Decorating a dj.Part table should raise an informative exception. + """ + + @schema.schema + class A(dj.Manual): + definition = ... + + @schema.schema + class B(dj.Part): + definition = ... + + +@raises(dj.DataJointError) +def test_unauthorized_database(): + """ + an attempt to create a database to which user has no privileges should raise an informative exception. + """ + dj.schema('unauthorized_schema', locals(), connection=dj.conn(**CONN_INFO)) + + +def test_drop_database(): + schema = dj.schema(PREFIX + '_drop_test', locals(), connection=dj.conn(**CONN_INFO)) + assert_true(schema.exists) + schema.drop() + assert_false(schema.exists) + schema.drop() # should do nothing \ No newline at end of file diff --git a/tutorial-notebooks/Primer00.ipynb b/tutorial-notebooks/Primer00.ipynb new file mode 100644 index 000000000..622362a93 --- /dev/null +++ b/tutorial-notebooks/Primer00.ipynb @@ -0,0 +1,86 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "DataJoint Primer. Section 0. \n", + "# Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Database access\n", + "This section pertains to system administrators rather than users of DataJoint.\n", + "\n", + "If you are collaborating with a team who already use DataJoint, simply request your database credentials and skip this section.\n", + "\n", + "### Hosting the database server\n", + "DataJoint relies on a MySQL-compatible database server (e.g. MySQL, MariaDB, Amazon Aurora) for all its data operations. 
\n",
+ "We advise that you turn to an IT professional for configuring your server, especially for a large distributed team.\n",
+ "However, numerous online resources exist to help you accomplish this task.\n",
+ "\n",
+ "The first decision you need to make is where this server will be hosted and how it will be administered. \n",
+ "The server may be hosted on your personal computer or on a dedicated machine in your lab. \n",
+ "Increasingly, many teams make use of cloud-hosted database services such as Amazon's [RDS](https://aws.amazon.com/rds/), which allow great flexibility and easy administration.\n",
+ "\n",
+ "For more info see http://datajoint.github.io/installation/.\n",
+ "\n",
+ "### Database server configuration\n",
+ "Typical default configurations of MySQL servers are not adequate and need to be adjusted to enforce stricter data checks (e.g. strict SQL modes) and to allow larger data packet sizes (e.g. a larger `max_allowed_packet`). \n",
+ "\n",
+ "### User account and privileges\n",
+ "Create a user account on the MySQL server. \n",
+ "\n",
+ "For example, if your username is `alice`, the SQL code for this step is \n",
+ "```SQL\n",
+ "CREATE USER 'alice'@'%' IDENTIFIED BY 'alices-secret-password';\n",
+ "```\n",
+ "\n",
+ "Teams that use DataJoint typically divide their data into _schemas_ grouped together by common prefixes. For example, a lab may have a collection of schemas that begin with `common_`. \n",
+ "Some common processing may be organized into several schemas that begin with `pipeline_`.\n",
+ "Typically each user has all privileges to schemas that begin with her username.\n",
+ "\n",
+ "For example, `alice` may have privileges to select and insert data from the common schemas (but not create new tables), and have all privileges to the pipeline schemas.\n",
+ "\n",
+ "Then the SQL code to grant her privileges might look like\n",
+ "```SQL\n",
+ "GRANT SELECT, INSERT ON `common\\_%`.* TO 'alice'@'%';\n",
+ "GRANT ALL PRIVILEGES ON `pipeline\\_%`.* TO 'alice'@'%';\n",
+ "GRANT ALL PRIVILEGES ON `alice\\_%`.* TO 'alice'@'%';\n",
+ "```\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "[Next](Primer01.ipynb)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.5.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/tutorial-notebooks/Primer01.ipynb b/tutorial-notebooks/Primer01.ipynb
new file mode 100644
index 000000000..246a8660d
--- /dev/null
+++ b/tutorial-notebooks/Primer01.ipynb
@@ -0,0 +1,154 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "[Prev](Primer00.ipynb)\n",
+ "\n",
+ "DataJoint Primer. 
Section 1.\n",
+ "# Connect"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Install DataJoint\n",
+ "Install the `datajoint` library and its dependencies.\n",
+ "The installation instructions are found at http://datajoint.github.io/installation/\n",
+ "\n",
+ "If you run into any issues, submit them at https://github.com/datajoint/datajoint-python/issues"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "collapsed": true
+ },
+ "source": [
+ "## Configure connection\n",
+ "When you import `datajoint` in Python for the first time, it will prompt you to specify your database credentials:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "DataJoint 0.2.1 (June 1, 2016)\n",
+ "Cannot find configuration settings. Using default configuration. To change that, either\n",
+ " * modify the local copy of dj_local_conf.json that datajoint just saved for you\n",
+ " * put a file named .datajoint_config.json with the same configuration format in your home\n",
+ " * specify the environment variables DJ_USER, DJ_HOST, DJ_PASS\n",
+ " \n"
+ ]
+ }
+ ],
+ "source": [
+ "import datajoint"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Follow these instructions to specify your database credentials (host address, username, and password), which you should have received from your database administrator.\n",
+ "\n",
+ "If you specify your credentials in a config file, be careful not to share it with others, for example when sharing your code. If you use a version control system such as `git` or `svn`, be sure to exclude the config file."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "When you restart the Python kernel and import `datajoint` again, it will indicate which configuration was used."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "DataJoint 0.2.1 (June 1, 2016)\n",
+ "Loading local settings from /Users/dimitri/.datajoint_config.json\n"
+ ]
+ }
+ ],
+ "source": [
+ "import datajoint as dj"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You may test the database connection using the `conn` function. 
If you supplied correct credentials, you should get the status of your connection, e.g.:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DataJoint connection (connected) dimitri@ninai.cluster-chjk7zcxhsgn.us-east-1.rds.amazonaws.com:3306"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dj.conn()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "collapsed": true
+ },
+ "source": [
+ "[Next](Primer02.ipynb)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.5.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/tutorial-notebooks/Primer02.ipynb b/tutorial-notebooks/Primer02.ipynb
new file mode 100644
index 000000000..b3a5b3179
--- /dev/null
+++ b/tutorial-notebooks/Primer02.ipynb
@@ -0,0 +1,455 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "[Prev](Primer01.ipynb)\n",
+ "\n",
+ "DataJoint Primer. Section 2.\n",
+ "# Defining a table"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Some terminology\n",
+ "DataJoint relies on the *Relational Data Model* for organizing its data.\n",
+ "\n",
+ "The crux of the Relational Data Model is that all data are stored in simple **tables** with **rows** and **columns**.\n",
+ "\n",
+ "In more theoretical and academic settings, the tables are called **relations**, columns are **attributes**, and rows are **tuples**. We may use both sets of terms interchangeably.\n",
+ "\n",
+ "The tables are simple: they may not be nested, and each table can always be accessed directly without following paths of links.\n",
+ "\n",
+ "Groups of related tables are called **schemas**. The word **database** is synonymous with **schema** in this context and we may use them interchangeably.\n",
+ "\n",
+ "To summarize, when speaking about databases\n",
+ "> \"schema\" == \"database\"\n",
+ "\n",
+ "> \"relation\" == \"table\"\n",
+ "\n",
+ "> \"attribute\" == \"column\"\n",
+ "\n",
+ "> \"tuple\" == \"row\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Create a schema\n",
+ "As described in Primer 0, I have all privileges to any schema that starts with `dimitri_`.\n",
+ "\n",
+ "Therefore, let me create a new schema for data about things in our lab. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "DataJoint 0.2.1 (June 1, 2016)\n",
+ "Loading local settings from /Users/dimitri/.datajoint_config.json\n"
+ ]
+ }
+ ],
+ "source": [
+ "import datajoint as dj\n",
+ "schema = dj.schema('dimitri_lab', locals())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Create a table\n",
+ "Now I will create a table in the database `dimitri_lab` by declaring a special Python class. For example, let's create the table `Person` to store information about people who work in our lab."
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "@schema\n", + "class Person(dj.Manual):\n", + " definition = \"\"\" # members of the lab\n", + " username : char(16) # short unique name\n", + " ----\n", + " full_name : varchar(25)\n", + " \"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This defines a table with two columns `username` and `full_name`. \n", + "## Inserting data manually\n", + "We can now enter single rows of data into `Person` one at a time using the `insert1` method:" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "Person().insert1(('dimitri', 'Dimitri Yatsenko'))\n", + "Person().insert1(dict(username='shan', full_name=\"Shan Shen\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Or multiple rows at once using the `insert` method:" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "Person().insert((\n", + " ('andreas', 'Andreas S. Tolias'),\n", + " ('jake', 'Jacob Reimer'),\n", + " ('fabee', 'Fabian Sinz'),\n", + " {'full_name': 'Edgar Y. Walker', 'username': 'edgar'}\n", + " ))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Previewing data\n", + "You may get a preview of the contents of the table from the instance of its class:" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "

[stripped HTML table preview: caption "members of the lab", columns username and full_name, six rows, "6 tuples" footer — matching the text/plain rendering below]
\n", + " " + ], + "text/plain": [ + "username full_name \n", + "+----------+ +------------+\n", + "andreas Andreas S. Tol\n", + "dimitri Dimitri Yatsen\n", + "edgar Edgar Y. Walke\n", + "fabee Fabian Sinz \n", + "jake Jacob Reimer \n", + "shan Shan Shen \n", + " (6 tuples)" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Person()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "DataJoint will not allow entering two rows with the same `username` because it is the *primary key* of `Person`.\n", + "\n", + "The primary key attributes are listed above the separator `---` in the `definition` string. Every row must have unique values in the primary key attributes." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "ename": "IntegrityError", + "evalue": "(1062, \"Duplicate entry 'jake' for key 'PRIMARY'\")", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mIntegrityError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mPerson\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minsert1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'jake'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'Jacob W'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/Users/dimitri/dev/datajoint-python/datajoint/base_relation.py\u001b[0m in \u001b[0;36minsert1\u001b[0;34m(self, row, **kwargs)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0mparam\u001b[0m \u001b[0mrow\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mData\u001b[0m \u001b[0mrecord\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m \u001b[0mMapping\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mlike\u001b[0m \u001b[0ma\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0ma\u001b[0m \u001b[0mlist\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mtuple\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mordered\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 124\u001b[0m \"\"\"\n\u001b[0;32m--> 125\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minsert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrow\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 126\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 127\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minsert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrows\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreplace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mignore_errors\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mskip_duplicates\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/Users/dimitri/dev/datajoint-python/datajoint/base_relation.py\u001b[0m in \u001b[0;36minsert\u001b[0;34m(self, rows, replace, ignore_errors, 
skip_duplicates)\u001b[0m\n\u001b[1;32m 242\u001b[0m \u001b[0mfields\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'`,`'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfield_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 243\u001b[0m placeholders=','.join('(' + ','.join(row['placeholders']) + ')' for row in rows)),\n\u001b[0;32m--> 244\u001b[0;31m args=list(itertools.chain.from_iterable((v for v in r['values'] if v is not None) for r in rows)))\n\u001b[0m\u001b[1;32m 245\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 246\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdelete_quick\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/Users/dimitri/dev/datajoint-python/datajoint/connection.py\u001b[0m in \u001b[0;36mquery\u001b[0;34m(self, query, args, as_dict)\u001b[0m\n\u001b[1;32m 106\u001b[0m \u001b[0;31m# Log the query\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Executing SQL:\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mquery\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m300\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 108\u001b[0;31m \u001b[0mcur\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquery\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 109\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 110\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcur\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.5/site-packages/pymysql/cursors.py\u001b[0m in \u001b[0;36mexecute\u001b[0;34m(self, query, args)\u001b[0m\n\u001b[1;32m 156\u001b[0m \u001b[0mquery\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmogrify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquery\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 157\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 158\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_query\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquery\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 159\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_executed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mquery\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.5/site-packages/pymysql/cursors.py\u001b[0m in \u001b[0;36m_query\u001b[0;34m(self, q)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[0mconn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_db\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 307\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_last_executed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mq\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 308\u001b[0;31m 
\u001b[0mconn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquery\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mq\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 309\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_do_get_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 310\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrowcount\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.5/site-packages/pymysql/connections.py\u001b[0m in \u001b[0;36mquery\u001b[0;34m(self, sql, unbuffered)\u001b[0m\n\u001b[1;32m 818\u001b[0m \u001b[0msql\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msql\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencoding\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'surrogateescape'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 819\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_execute_command\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mCOMMAND\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mCOM_QUERY\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msql\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 820\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_affected_rows\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_read_query_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0munbuffered\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0munbuffered\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 821\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_affected_rows\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 822\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.5/site-packages/pymysql/connections.py\u001b[0m in \u001b[0;36m_read_query_result\u001b[0;34m(self, unbuffered)\u001b[0m\n\u001b[1;32m 1000\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1001\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMySQLResult\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1002\u001b[0;31m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1003\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1004\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mserver_status\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.5/site-packages/pymysql/connections.py\u001b[0m in \u001b[0;36mread\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1283\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1284\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1285\u001b[0;31m \u001b[0mfirst_packet\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconnection\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_read_packet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1286\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1287\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfirst_packet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_ok_packet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.5/site-packages/pymysql/connections.py\u001b[0m in \u001b[0;36m_read_packet\u001b[0;34m(self, packet_type)\u001b[0m\n\u001b[1;32m 964\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 965\u001b[0m \u001b[0mpacket\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpacket_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuff\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencoding\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 966\u001b[0;31m \u001b[0mpacket\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck_error\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 967\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mpacket\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 968\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.5/site-packages/pymysql/connections.py\u001b[0m in \u001b[0;36mcheck_error\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 392\u001b[0m \u001b[0merrno\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_uint16\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 393\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mDEBUG\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"errno =\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0merrno\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 394\u001b[0;31m \u001b[0merr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraise_mysql_exception\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 395\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 396\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdump\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.5/site-packages/pymysql/err.py\u001b[0m in \u001b[0;36mraise_mysql_exception\u001b[0;34m(data)\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mraise_mysql_exception\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0merrinfo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_get_error_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 120\u001b[0;31m \u001b[0m_check_mysql_exception\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merrinfo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/usr/local/lib/python3.5/site-packages/pymysql/err.py\u001b[0m in \u001b[0;36m_check_mysql_exception\u001b[0;34m(errinfo)\u001b[0m\n\u001b[1;32m 110\u001b[0m \u001b[0merrorclass\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0merror_map\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merrno\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0merrorclass\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 112\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0merrorclass\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merrno\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0merrorvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 113\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0;31m# couldn't find the right error number\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mIntegrityError\u001b[0m: (1062, \"Duplicate entry 'jake' for key 'PRIMARY'\")" + ] + } + ], + "source": [ + "Person().insert1(('jake', 'Jacob W'))" + ] + },
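The traceback above incidentally reveals the `insert` signature in this version: `insert(self, rows, replace=False, ignore_errors=False, skip_duplicates=False)`. A minimal sketch of tolerating an existing primary key instead of raising `IntegrityError`, assuming the `skip_duplicates` flag behaves as its name suggests (this flag is not exercised in the notebook itself):

```python
# Sketch: attempt the same insert but skip rows whose primary key already
# exists, relying on the skip_duplicates flag shown in the traceback above.
Person().insert([('jake', 'Jacob W')], skip_duplicates=True)
```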
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Deleting all rows from a table" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `delete` method deletes all rows from the table:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The contents of the following tables are about to be deleted:\n", + "`dimitri_lab`.`person` (6 tuples)\n", + "Proceed? [yes, No]: yes\n", + "Done\n" + ] + } + ], + "source": [ + "Person().delete()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "text/html": [
(garbled HTML table output elided; it rendered the now-empty relation: comment "members of the lab", columns `username` and `full_name`, "0 tuples")
\n", + " " + ], + "text/plain": [ + "username full_name \n", + "+----------+ +-----------+\n", + " (0 tuples)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Person()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dropping the table\n", + "Deleting all rows from the table still leaves the empy table in the database.\n", + "The `drop` method removes the table from the database:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "`dimitri_lab`.`person` (0 tuples)\n", + "Proceed? [yes, No]: yes\n", + "Tables dropped. Restart kernel.\n" + ] + } + ], + "source": [ + "Person().drop()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It suggests to restart the Python kernel. This is because the class `Person` remains but its table is no longer defined in the database. Using the `Person` class after dropping its table will cause errors (see below). When the class is defined again, its table will be created again from its `definition` string. However, if the table is not dropped, the new class will use the existing table and its `definition` string will have no effect." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "ename": "DataJointError", + "evalue": "The table is not defined.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mDataJointError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/usr/local/lib/python3.5/site-packages/IPython/core/formatters.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, obj)\u001b[0m\n\u001b[1;32m 341\u001b[0m \u001b[0mmethod\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_safe_get_formatter_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprint_method\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 342\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmethod\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 343\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 344\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 345\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/Users/dimitri/dev/datajoint-python/datajoint/relational_operand.py\u001b[0m in \u001b[0;36m_repr_html_\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 360\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_repr_html_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0mlimit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'display.limit'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 362\u001b[0;31m \u001b[0mrel\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheading\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnon_blobs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# project out blobs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 363\u001b[0m \u001b[0mcolumns\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheading\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnames\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 364\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheading\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtable_info\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/Users/dimitri/dev/datajoint-python/datajoint/base_relation.py\u001b[0m in \u001b[0;36mheading\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_heading\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mHeading\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# instance-level heading\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_heading\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# lazy loading of heading\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 53\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_heading\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit_from_database\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconnection\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdatabase\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtable_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 54\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_heading\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/Users/dimitri/dev/datajoint-python/datajoint/heading.py\u001b[0m in \u001b[0;36minit_from_database\u001b[0;34m(self, conn, database, table_name)\u001b[0m\n\u001b[1;32m 121\u001b[0m table_name=table_name, database=database), as_dict=True).fetchone()\n\u001b[1;32m 122\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 123\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mDataJointError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'The table is not defined.'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 124\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtable_info\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0minfo\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDataJointError\u001b[0m: The table is not defined." 
+ ] + }, + { + "ename": "DataJointError", + "evalue": "The table is not defined.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mDataJointError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/usr/local/lib/python3.5/site-packages/IPython/core/formatters.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, obj)\u001b[0m\n\u001b[1;32m 697\u001b[0m \u001b[0mtype_pprinters\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtype_printers\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 698\u001b[0m deferred_pprinters=self.deferred_printers)\n\u001b[0;32m--> 699\u001b[0;31m \u001b[0mprinter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpretty\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 700\u001b[0m \u001b[0mprinter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflush\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 701\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mstream\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetvalue\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.5/site-packages/IPython/lib/pretty.py\u001b[0m in \u001b[0;36mpretty\u001b[0;34m(self, obj)\u001b[0m\n\u001b[1;32m 381\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcallable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmeth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 382\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmeth\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcycle\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 383\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_default_pprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcycle\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 384\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 385\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mend_group\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.5/site-packages/IPython/lib/pretty.py\u001b[0m in \u001b[0;36m_default_pprint\u001b[0;34m(obj, p, cycle)\u001b[0m\n\u001b[1;32m 501\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_safe_getattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mklass\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'__repr__'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0m_baseclass_reprs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 502\u001b[0m \u001b[0;31m# A user-provided repr. 
Find newlines and replace them with p.break_()\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 503\u001b[0;31m \u001b[0m_repr_pprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcycle\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 504\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 505\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbegin_group\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'<'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.5/site-packages/IPython/lib/pretty.py\u001b[0m in \u001b[0;36m_repr_pprint\u001b[0;34m(obj, p, cycle)\u001b[0m\n\u001b[1;32m 692\u001b[0m \u001b[0;34m\"\"\"A pprint that just redirects to the normal repr function.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 693\u001b[0m \u001b[0;31m# Find newlines and replace them with p.break_()\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 694\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrepr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 695\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0moutput_line\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplitlines\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 696\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/Users/dimitri/dev/datajoint-python/datajoint/relational_operand.py\u001b[0m in \u001b[0;36m__repr__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 337\u001b[0m \u001b[0mret\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m' & %r'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_restrictions\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 338\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 339\u001b[0;31m \u001b[0mrel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheading\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnon_blobs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# project out blobs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 340\u001b[0m \u001b[0mlimit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'display.limit'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 341\u001b[0m \u001b[0mwidth\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'display.width'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/Users/dimitri/dev/datajoint-python/datajoint/base_relation.py\u001b[0m in \u001b[0;36mheading\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_heading\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mHeading\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# instance-level 
heading\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_heading\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# lazy loading of heading\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 53\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_heading\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit_from_database\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconnection\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdatabase\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtable_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 54\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_heading\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/Users/dimitri/dev/datajoint-python/datajoint/heading.py\u001b[0m in \u001b[0;36minit_from_database\u001b[0;34m(self, conn, database, table_name)\u001b[0m\n\u001b[1;32m 121\u001b[0m table_name=table_name, database=database), as_dict=True).fetchone()\n\u001b[1;32m 122\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 123\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mDataJointError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'The table is not defined.'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 124\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtable_info\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0minfo\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDataJointError\u001b[0m: The table is not defined." + ] + } + ], + "source": [ + "Person()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# A closer look\n", + "\n", + "## What makes a relation class\n", + "- Class `Person` is decorated with the `schema` object created earlier. It links the table to the database. \n", + "- Class `Person` inherits from `dj.Manual`, which indicates that data will be entered into `Person` manually. We will discuss automatically populated tables in future sections.\n", + "- Class `Person` defines the multiline string property `definition`, which defines the attributes (columns) of the table.\n", + "- The class name **must** be in CamelCase and can only contain alphanumeric characters (underscores are not allowed). This is important because DataJoint converts the class name into the corresponding table name in the database.\n", + "\n", + "## Format of the table definition string\n", + "The first line beginning with a `#` describes the contents of the table. 
\n", + "Each of the subsequent lines defines an attribute (column) of the table in the format\n", + "> `attribute_name : type # comment`\n", + "\n", + "All attribute names must be in lowercase, must start with a letter and can only contain alphanumerical characters and underscores.\n", + "\n", + "The separator \n", + "> `----`\n", + "\n", + "comprises three or more dashes and separates the **primary key** attributes above from the **dependent attributes** below. \n", + "\n", + "The primary key attributes uniquely identify each row in the table; the table cannot contain two rows with the same values of the primary key attributes.\n", + "\n", + "## Attribute types\n", + "The attribute types are MySQL data types and are summarized here http://datajoint.github.io/datatypes/. \n", + "\n", + "Most commonly used in datajoint are \n", + "- signed integers: ** `tinyint`, `smallint`, `int`, `bigint` **\n", + "- unsigned integers: ** `tinyint unsigned`, `smallint unsigned`, `int unsigned`, `bigint unsigned` ** \n", + "- floating-point and fixed-point fractional numbers: **`float`, `double`, `decimal`**\n", + "- enumeration: **`enum`**\n", + "- true/false **`boolean`**\n", + "- strings **`char`, `varchar`**\n", + "- dates and times **`date`, `timestamp`**\n", + "- arbitrary things such as images, traces, etc: **`longblob`** " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[Next](Primer03.ipynb)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.0" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tutorial-notebooks/Primer03.ipynb b/tutorial-notebooks/Primer03.ipynb new file mode 100644 index 000000000..128f9bb73 --- /dev/null +++ b/tutorial-notebooks/Primer03.ipynb @@ -0,0 +1,918 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[Prev](Primer02.ipynb)\n", + "\n", + "DataJoint Primer. 
diff --git a/tutorial-notebooks/Primer03.ipynb b/tutorial-notebooks/Primer03.ipynb new file mode 100644 index 000000000..128f9bb73 --- /dev/null +++ b/tutorial-notebooks/Primer03.ipynb @@ -0,0 +1,918 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[Prev](Primer02.ipynb)\n", + "\n", + "DataJoint Primer. Section 3.\n", + "# Dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DataJoint 0.2.1 (June 1, 2016)\n", + "Loading local settings from /Users/dimitri/.datajoint_config.json\n" + ] + } + ], + "source": [ + "%matplotlib notebook\n", + "import datajoint as dj" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "schema = dj.schema('dimitri_experiment', locals())\n", + "\n", + "@schema\n", + "class Subject(dj.Manual):\n", + " definition = \"\"\"\n", + " # Basic subject info\n", + " subject_id : int # internal subject id\n", + " ---\n", + " real_id : varchar(40) # real-world name\n", + " species = \"mouse\" : enum('mouse', 'monkey', 'human') # species\n", + " date_of_birth=null : date # animal's date of birth\n", + " sex=\"unknown\" : enum('M','F','unknown') #\n", + " caretaker=\"Unknown\" : varchar(20) # person responsible for working with this subject\n", + " animal_notes=\"\" : varchar(4096) # strain, genetic manipulations, etc\n", + " \"\"\"\n", + "\n", + "\n", + "@schema\n", + "class Experiment(dj.Manual):\n", + " definition = \"\"\"\n", + " # Basic experiment info\n", + "\n", + " -> Subject\n", + " experiment : smallint # experiment number for this subject\n", + " ---\n", + " experiment_folder : varchar(255) # folder path\n", + " experiment_date : date # experiment start date\n", + " experiment_notes=\"\" : varchar(4096)\n", + " experiment_ts=CURRENT_TIMESTAMP : timestamp # automatic timestamp\n", + " \"\"\"\n", + "\n", + "\n", + "@schema\n", + "class Session(dj.Manual):\n", + " definition = \"\"\"\n", + " # a two-photon imaging session\n", + "\n", + " -> Experiment\n", + " session_id : tinyint # two-photon session within this experiment\n", + " -----------\n", + " setup : tinyint # experimental setup\n", + " lens : tinyint # lens e.g.: 10x, 20x, 25x, 60x\n", + " \"\"\"\n", + "\n", + "\n", + "@schema\n", + "class Scan(dj.Manual):\n", + " definition = \"\"\"\n", + " # a two-photon scan\n", + "\n", + " -> Session\n", + " scan_id : tinyint # scan within this session\n", + " ----\n", + " depth : float # depth from surface\n", + " wavelength : smallint # (nm) laser wavelength\n", + " mwatts : numeric(4,1) # (mW) laser power to brain\n", + " \"\"\"" + ] + },
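Before the lengthy matplotlib widget JavaScript emitted by the next output cell, a brief sketch of what the foreign keys above (`-> Subject`, `-> Experiment`, `-> Session`) imply for data entry: each child row carries its parent's primary key attributes, so rows are inserted top-down. The values below are hypothetical; `insert1` accepting a `dict` follows its docstring quoted earlier ("a Mapping (like a dict), or a list or tuple with ordered values").

```python
# Sketch: populating the dependency chain defined above (hypothetical values).
# A child row must reference an existing parent row, so parents come first.
Subject().insert1(dict(subject_id=1, real_id='M-001'))
Experiment().insert1(dict(subject_id=1, experiment=1,
                          experiment_folder='/data/M-001/exp1',
                          experiment_date='2016-05-01'))        # -> Subject
Session().insert1(dict(subject_id=1, experiment=1, session_id=1,
                       setup=1, lens=20))                       # -> Experiment
Scan().insert1(dict(subject_id=1, experiment=1, session_id=1, scan_id=1,
                    depth=150.0, wavelength=920, mwatts=25.0))  # -> Session
```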
' +\n", + " 'Firefox 4 and 5 are also supported but you ' +\n", + " 'have to enable WebSockets in about:config.');\n", + " };\n", + "}\n", + "\n", + "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", + " this.id = figure_id;\n", + "\n", + " this.ws = websocket;\n", + "\n", + " this.supports_binary = (this.ws.binaryType != undefined);\n", + "\n", + " if (!this.supports_binary) {\n", + " var warnings = document.getElementById(\"mpl-warnings\");\n", + " if (warnings) {\n", + " warnings.style.display = 'block';\n", + " warnings.textContent = (\n", + " \"This browser does not support binary websocket messages. \" +\n", + " \"Performance may be slow.\");\n", + " }\n", + " }\n", + "\n", + " this.imageObj = new Image();\n", + "\n", + " this.context = undefined;\n", + " this.message = undefined;\n", + " this.canvas = undefined;\n", + " this.rubberband_canvas = undefined;\n", + " this.rubberband_context = undefined;\n", + " this.format_dropdown = undefined;\n", + "\n", + " this.image_mode = 'full';\n", + "\n", + " this.root = $('
');\n", + " this._root_extra_style(this.root)\n", + " this.root.attr('style', 'display: inline-block');\n", + "\n", + " $(parent_element).append(this.root);\n", + "\n", + " this._init_header(this);\n", + " this._init_canvas(this);\n", + " this._init_toolbar(this);\n", + "\n", + " var fig = this;\n", + "\n", + " this.waiting = false;\n", + "\n", + " this.ws.onopen = function () {\n", + " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", + " fig.send_message(\"send_image_mode\", {});\n", + " fig.send_message(\"refresh\", {});\n", + " }\n", + "\n", + " this.imageObj.onload = function() {\n", + " if (fig.image_mode == 'full') {\n", + " // Full images could contain transparency (where diff images\n", + " // almost always do), so we need to clear the canvas so that\n", + " // there is no ghosting.\n", + " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", + " }\n", + " fig.context.drawImage(fig.imageObj, 0, 0);\n", + " };\n", + "\n", + " this.imageObj.onunload = function() {\n", + " this.ws.close();\n", + " }\n", + "\n", + " this.ws.onmessage = this._make_on_message_function(this);\n", + "\n", + " this.ondownload = ondownload;\n", + "}\n", + "\n", + "mpl.figure.prototype._init_header = function() {\n", + " var titlebar = $(\n", + " '
');\n", + " var titletext = $(\n", + " '
');\n", + " titlebar.append(titletext)\n", + " this.root.append(titlebar);\n", + " this.header = titletext[0];\n", + "}\n", + "\n", + "\n", + "\n", + "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", + "\n", + "}\n", + "\n", + "\n", + "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", + "\n", + "}\n", + "\n", + "mpl.figure.prototype._init_canvas = function() {\n", + " var fig = this;\n", + "\n", + " var canvas_div = $('
');\n", + "\n", + " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", + "\n", + " function canvas_keyboard_event(event) {\n", + " return fig.key_event(event, event['data']);\n", + " }\n", + "\n", + " canvas_div.keydown('key_press', canvas_keyboard_event);\n", + " canvas_div.keyup('key_release', canvas_keyboard_event);\n", + " this.canvas_div = canvas_div\n", + " this._canvas_extra_style(canvas_div)\n", + " this.root.append(canvas_div);\n", + "\n", + " var canvas = $('');\n", + " canvas.addClass('mpl-canvas');\n", + " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", + "\n", + " this.canvas = canvas[0];\n", + " this.context = canvas[0].getContext(\"2d\");\n", + "\n", + " var rubberband = $('');\n", + " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", + "\n", + " var pass_mouse_events = true;\n", + "\n", + " canvas_div.resizable({\n", + " start: function(event, ui) {\n", + " pass_mouse_events = false;\n", + " },\n", + " resize: function(event, ui) {\n", + " fig.request_resize(ui.size.width, ui.size.height);\n", + " },\n", + " stop: function(event, ui) {\n", + " pass_mouse_events = true;\n", + " fig.request_resize(ui.size.width, ui.size.height);\n", + " },\n", + " });\n", + "\n", + " function mouse_event_fn(event) {\n", + " if (pass_mouse_events)\n", + " return fig.mouse_event(event, event['data']);\n", + " }\n", + "\n", + " rubberband.mousedown('button_press', mouse_event_fn);\n", + " rubberband.mouseup('button_release', mouse_event_fn);\n", + " // Throttle sequential mouse events to 1 every 20ms.\n", + " rubberband.mousemove('motion_notify', mouse_event_fn);\n", + "\n", + " rubberband.mouseenter('figure_enter', mouse_event_fn);\n", + " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", + "\n", + " canvas_div.on(\"wheel\", function (event) {\n", + " event = event.originalEvent;\n", + " event['data'] = 'scroll'\n", + " if (event.deltaY < 0) {\n", + " event.step = 1;\n", + " } else {\n", + " event.step = -1;\n", + " }\n", + " mouse_event_fn(event);\n", + " });\n", + "\n", + " canvas_div.append(canvas);\n", + " canvas_div.append(rubberband);\n", + "\n", + " this.rubberband = rubberband;\n", + " this.rubberband_canvas = rubberband[0];\n", + " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", + " this.rubberband_context.strokeStyle = \"#000000\";\n", + "\n", + " this._resize_canvas = function(width, height) {\n", + " // Keep the size of the canvas, canvas container, and rubber band\n", + " // canvas in synch.\n", + " canvas_div.css('width', width)\n", + " canvas_div.css('height', height)\n", + "\n", + " canvas.attr('width', width);\n", + " canvas.attr('height', height);\n", + "\n", + " rubberband.attr('width', width);\n", + " rubberband.attr('height', height);\n", + " }\n", + "\n", + " // Set the figure to an initial 600x600px, this will subsequently be updated\n", + " // upon first draw.\n", + " this._resize_canvas(600, 600);\n", + "\n", + " // Disable right mouse context menu.\n", + " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", + " return false;\n", + " });\n", + "\n", + " function set_focus () {\n", + " canvas.focus();\n", + " canvas_div.focus();\n", + " }\n", + "\n", + " window.setTimeout(set_focus, 100);\n", + "}\n", + "\n", + "mpl.figure.prototype._init_toolbar = function() {\n", + " var fig = this;\n", + "\n", + " var nav_element = $('
')\n", + " nav_element.attr('style', 'width: 100%');\n", + " this.root.append(nav_element);\n", + "\n", + " // Define a callback function for later on.\n", + " function toolbar_event(event) {\n", + " return fig.toolbar_button_onclick(event['data']);\n", + " }\n", + " function toolbar_mouse_event(event) {\n", + " return fig.toolbar_button_onmouseover(event['data']);\n", + " }\n", + "\n", + " for(var toolbar_ind in mpl.toolbar_items) {\n", + " var name = mpl.toolbar_items[toolbar_ind][0];\n", + " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", + " var image = mpl.toolbar_items[toolbar_ind][2];\n", + " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", + "\n", + " if (!name) {\n", + " // put a spacer in here.\n", + " continue;\n", + " }\n", + " var button = $('