diff --git a/datajoint/__init__.py b/datajoint/__init__.py
index b562aa8d1..584995a6d 100644
--- a/datajoint/__init__.py
+++ b/datajoint/__init__.py
@@ -1,13 +1,24 @@
+"""
+DataJoint for Python is a high-level programming interface for MySQL databases
+to support data processing chains in science labs. DataJoint is built on the
+foundation of the relational data model and prescribes a consistent method for
+organizing, populating, and querying data.
+
+DataJoint is free software under the LGPL License. In addition, we request
+that any use of DataJoint leading to a publication be acknowledged in the publication.
+"""
+
 import logging
 import os
 
 __author__ = "Dimitri Yatsenko, Edgar Walker, and Fabian Sinz at Baylor College of Medicine"
 __version__ = "0.2"
 __all__ = ['__author__', '__version__',
+           'config', 'Connection', 'Heading', 'Relation', 'FreeRelation', 'Not',
-           'Relation',
+           'Relation', 'schema',
            'Manual', 'Lookup', 'Imported', 'Computed',
-           'AutoPopulate', 'conn', 'DataJointError', 'blob']
+           'conn', 'DataJointError']
 
 
 class DataJointError(Exception):
@@ -39,8 +50,6 @@ class DataJointError(Exception):
 from .connection import conn, Connection
 from .relation import Relation
 from .user_relations import Manual, Lookup, Imported, Computed, Subordinate
-from .autopopulate import AutoPopulate
-from . import blob
 from .relational_operand import Not
 from .heading import Heading
-from .relation import schema
\ No newline at end of file
+from .schema import schema
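The `__init__.py` hunk above trims the package namespace: `AutoPopulate` and `blob` are no longer re-exported, while `config`, `schema`, `FreeRelation`, and `Not` now are. A minimal sketch of the resulting top-level API; the host and credential values are placeholders, not part of this diff:

```python
import datajoint as dj

dj.config['database.host'] = 'localhost'   # placeholder configuration values
dj.config['database.user'] = 'datajoint'

connection = dj.conn()    # prompts for the password if none is configured
print(dj.__version__)     # '0.2'
```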
""" - pass + parents = [FreeRelation(self.target.connection, rel) for rel in self.target.parents] + if not parents: + raise DataJointError('A relation must have parent relations to be able to be populated') + ret = parents.pop(0) + while parents: + ret *= parents.pop(0) + return ret @abc.abstractmethod def _make_tuples(self, key): @@ -39,52 +50,55 @@ def target(self): def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): """ - rel.populate() calls rel._make_tuples(key) for every primary key in self.populate_relation + rel.populate() calls rel._make_tuples(key) for every primary key in self.populated_from for which there is not already a tuple in rel. - :param restriction: restriction on rel.populate_relation - target + :param restriction: restriction on rel.populated_from - target :param suppress_errors: suppresses error if true :param reserve_jobs: currently not implemented + :param batch: batch size of a single job """ - - assert not reserve_jobs, NotImplemented # issue #5 error_list = [] if suppress_errors else None - if not isinstance(self.populate_relation, RelationalOperand): - raise DataJointError('Invalid populate_relation value') + if not isinstance(self.populated_from, RelationalOperand): + raise DataJointError('Invalid populated_from value') - self.connection.cancel_transaction() # rollback previous transaction, if any + if self.connection.in_transaction: + raise DataJointError('Populate cannot be called during a transaction.') - if not isinstance(self, Relation): - raise DataJointError( - 'AutoPopulate is a mixin for Relation and must therefore subclass Relation') - - unpopulated = (self.populate_relation - self.target) & restriction + jobs = self.connection.jobs[self.target.database] + table_name = self.target.table_name + unpopulated = (self.populated_from - self.target) & restriction for key in unpopulated.project(): - self.connection.start_transaction() - if key in self.target: # already populated - self.connection.cancel_transaction() - else: - logger.info('Populating: ' + str(key)) - try: - self._make_tuples(dict(key)) - except Exception as error: + if not reserve_jobs or jobs.reserve(table_name, key): + self.connection.start_transaction() + if key in self.target: # already populated self.connection.cancel_transaction() - if not suppress_errors: - raise - else: - logger.error(error) - error_list.append((key, error)) + if reserve_jobs: + jobs.complete(table_name, key) else: - self.connection.commit_transaction() - logger.info('Done populating.') + logger.info('Populating: ' + str(key)) + try: + self._make_tuples(dict(key)) + except Exception as error: + self.connection.cancel_transaction() + if reserve_jobs: + jobs.error(table_name, key, error_message=str(error)) + if not suppress_errors: + raise + else: + logger.error(error) + error_list.append((key, error)) + else: + self.connection.commit_transaction() + if reserve_jobs: + jobs.complete(table_name, key) return error_list - def progress(self): """ report progress of populating this table """ - total = len(self.populate_relation) - remaining = len(self.populate_relation - self.target) - print('Remaining %d of %d (%2.1f%%)' % (remaining, total, 100*remaining/total) + total = len(self.populated_from) + remaining = len(self.populated_from - self.target) + print('Completed %d of %d (%2.1f%%)' % (total - remaining, total, 100 - 100 * remaining / total) if remaining else 'Complete', flush=True) diff --git a/datajoint/blob.py b/datajoint/blob.py index 7988e82d5..d513532d8 100644 --- a/datajoint/blob.py 
diff --git a/datajoint/blob.py b/datajoint/blob.py
index 7988e82d5..d513532d8 100644
--- a/datajoint/blob.py
+++ b/datajoint/blob.py
@@ -1,3 +1,7 @@
+"""
+Provides serialization methods for numpy.ndarrays that ensure compatibility with Matlab.
+"""
+
 import zlib
 from collections import OrderedDict
 import numpy as np
@@ -29,7 +33,10 @@ def pack(obj):
     """
-    packs an object into a blob to be compatible with mym.mex
+    Packs an object into a blob to be compatible with mym.mex
+
+    :param obj: object to be packed
+    :type obj: numpy.ndarray
     """
     if not isinstance(obj, np.ndarray):
         raise DataJointError("Only numpy arrays can be saved in blobs")
@@ -58,9 +65,16 @@ def pack(obj):
 
 def unpack(blob):
     """
-    unpack blob into a numpy array
+    Unpacks blob data into a numpy array.
+
+    :param blob: mysql blob
+    :returns: unpacked data
+    :rtype: numpy.ndarray
     """
     # decompress if necessary
+    if blob is None:
+        return None
+
     if blob[0:5] == b'ZL123':
         blob_length = np.fromstring(blob[6:14], dtype=np.uint64)[0]
         blob = zlib.decompress(blob[14:])
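A round-trip sketch for the serializer above, assuming `pack` and `unpack` are inverses for plain numeric arrays (which is what the mym.mex-compatible format is designed for):

```python
import numpy as np
from datajoint import blob   # still importable as a submodule

arr = np.random.randn(3, 4)
packed = blob.pack(arr)            # bytes in the mym.mex-compatible layout
restored = blob.unpack(packed)     # numpy array again
assert np.allclose(arr, restored)  # assumes a lossless round-trip for float arrays
assert blob.unpack(None) is None   # new None handling added in this diff
```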
diff --git a/datajoint/connection.py b/datajoint/connection.py
index 325d1c021..42d70e21c 100644
--- a/datajoint/connection.py
+++ b/datajoint/connection.py
@@ -1,43 +1,42 @@
+"""
+This module hosts the Connection class that manages the connection to the mysql database via
+`pymysql`, and the `conn` function that provides access to a persistent connection in datajoint.
+
+"""
 from contextlib import contextmanager
 import pymysql
-from . import DataJointError
 import logging
-from . import config
+from collections import defaultdict
+from . import DataJointError, config
 from .erd import ERD
+from .jobs import JobManager
 
 logger = logging.getLogger(__name__)
 
 
-def conn_container():
+def conn(host=None, user=None, passwd=None, init_fun=None, reset=False):
     """
-    creates a persistent connections for everyone to use
+    Returns a persistent connection object to be shared by multiple modules.
+    If the connection is not yet established or reset=True, a new connection is set up.
+    If connection information is not provided, it is taken from config, which takes the
+    information from dj_local_conf.json. If the password is not specified in that file
+    datajoint prompts for the password.
+
+    :param host: hostname
+    :param user: mysql user
+    :param passwd: mysql password
+    :param init_fun: initialization function
+    :param reset: whether the connection should be reset or not
     """
-    _connection = None  # persistent connection object used by dj.conn()
-
-    def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False):  # TODO: thin wrapping layer to mimic singleton
-        """
-        Manage a persistent connection object.
-        This is one of several ways to configure and access a datajoint connection.
-        Users may customize their own connection manager.
-
-        Set rest=True to reset the persistent connection object
-        """
-        nonlocal _connection
-        if not _connection or reset:
-            host = host if host is not None else config['database.host']
-            user = user if user is not None else config['database.user']
-            passwd = passwd if passwd is not None else config['database.password']
-
-            if passwd is None: passwd = input("Please enter database password: ")
-
-            init_fun = init_fun if init_fun is not None else config['connection.init_function']
-            _connection = Connection(host, user, passwd, init_fun)
-        return _connection
-
-    return conn_function
-
-# The function conn is used by others to obtain a connection object
-conn = conn_container()
+    if not hasattr(conn, 'connection') or reset:
+        host = host if host is not None else config['database.host']
+        user = user if user is not None else config['database.user']
+        passwd = passwd if passwd is not None else config['database.password']
+        if passwd is None:
+            passwd = input("Please enter database password: ")
+        init_fun = init_fun if init_fun is not None else config['connection.init_function']
+        conn.connection = Connection(host, user, passwd, init_fun)
+    return conn.connection
 
 
 class Connection:
@@ -51,7 +50,6 @@ class Connection:
     :param user: user name
     :param passwd: password
    :param init_fun: initialization function
-
     """
 
     def __init__(self, host, user, passwd, init_fun=None):
@@ -69,6 +67,7 @@ def __init__(self, host, user, passwd, init_fun=None):
             raise DataJointError('Connection failed.')
         self._conn.autocommit(True)
         self._in_transaction = False
+        self.jobs = JobManager(self)
 
     def __del__(self):
         logger.info('Disconnecting {user}@{host}:{port}'.format(**self.conn_info))
@@ -89,16 +88,15 @@ def __repr__(self):
         return "DataJoint connection ({connected}) {user}@{host}:{port}".format(
             connected=connected, **self.conn_info)
 
-    def __del__(self):
-        logger.info('Disconnecting {user}@{host}:{port}'.format(**self.conn_info))
-        self._conn.close()
 
     def query(self, query, args=(), as_dict=False):
         """
         Execute the specified query and return the tuple generator.
 
-        If as_dict is set to True, the returned cursor objects returns
-        query results as dictionary.
+        :param query: mysql query
+        :param args: additional arguments for the pymysql.cursor
+        :param as_dict: If as_dict is set to True, the returned cursor objects returns
+                        query results as dictionary.
         """
         cursor = pymysql.cursors.DictCursor if as_dict else pymysql.cursors.Cursor
         cur = self._conn.cursor(cursor=cursor)
@@ -108,13 +106,21 @@ def query(self, query, args=(), as_dict=False):
         cur.execute(query, args)
         return cur
 
-    # ---------- transaction processing ------------------
+    # ---------- transaction processing
     @property
     def in_transaction(self):
+        """
+        :return: True if there is an open transaction.
+        """
         self._in_transaction = self._in_transaction and self.is_connected
         return self._in_transaction
 
     def start_transaction(self):
+        """
+        Starts a transaction.
+
+        :raise DataJointError: if there is an ongoing transaction.
+        """
         if self.in_transaction:
             raise DataJointError("Nested connections are not supported.")
         self.query('START TRANSACTION WITH CONSISTENT SNAPSHOT')
@@ -122,19 +128,39 @@ def start_transaction(self):
         logger.info("Transaction started")
 
     def cancel_transaction(self):
+        """
+        Cancels the current transaction and rolls back all changes made during the transaction.
+
+        """
         self.query('ROLLBACK')
         self._in_transaction = False
         logger.info("Transaction cancelled. Rolling back ...")
 
     def commit_transaction(self):
+        """
+        Commit all changes made during the transaction and close it.
+
+        """
         self.query('COMMIT')
         self._in_transaction = False
         logger.info("Transaction committed and closed.")
 
-
-    #-------- context manager for transactions
+    # -------- context manager for transactions
+    @property
     @contextmanager
     def transaction(self):
+        """
+        Context manager for transactions. Opens a transaction and closes it after the with statement.
+        If an error is caught during the transaction, all changes are rolled back. All
+        errors are raised again.
+
+        Example:
+        >>> import datajoint as dj
+        >>> with dj.conn().transaction as conn:
+        >>>     # transaction is open here
+
+        """
         try:
             self.start_transaction()
             yield self
@@ -142,5 +168,4 @@ def transaction(self):
             self.cancel_transaction()
             raise
         else:
-            self.commit_transaction()
-
+            self.commit_transaction()
\ No newline at end of file
diff --git a/datajoint/declare.py b/datajoint/declare.py
index 18679150a..eceb70b48 100644
--- a/datajoint/declare.py
+++ b/datajoint/declare.py
@@ -1,3 +1,7 @@
+"""
+This module hosts functions to convert DataJoint table definitions into mysql table definitions, and to
+declare the corresponding mysql tables.
+"""
 import re
 import pyparsing as pp
 import logging
@@ -12,6 +16,10 @@ def declare(full_table_name, definition, context):
     """
     Parse declaration and create new SQL table accordingly.
+
+    :param full_table_name: full name of the table
+    :param definition: DataJoint table definition
+    :param context: dictionary of objects that might be referred to in the table. Usually this will be locals()
     """
     # split definition into lines
     definition = re.split(r'\s*\n\s*', definition.strip())
@@ -72,6 +80,7 @@ def declare(full_table_name, definition, context):
 def compile_attribute(line, in_key=False):
     """
     Convert attribute definition from DataJoint format to SQL
+
     :param line: attribute line from the table definition
     :param in_key: set to True if attribute is in primary key set
     :returns: (name, sql) -- attribute name and sql code for its declaration
diff --git a/datajoint/erd.py b/datajoint/erd.py
index 73af00c7c..a90f4b017 100644
--- a/datajoint/erd.py
+++ b/datajoint/erd.py
@@ -16,12 +16,19 @@
 
 class ERD:
-    _parents = defaultdict(set)
-    _children = defaultdict(set)
-    _references = defaultdict(set)
-    _referenced = defaultdict(set)
+    _checked_dependencies = set()
+    _parents = dict()
+    _referenced = dict()
+    _children = defaultdict(list)
+    _references = defaultdict(list)
+
+    def load_dependencies(self, connection, full_table_name):
+        # check if already loaded. Use clear_dependencies before reloading
+        if full_table_name in self._parents:
+            return
+        self._parents[full_table_name] = list()
+        self._referenced[full_table_name] = list()
 
-    def load_dependencies(self, connection, full_table_name, primary_key):
         # fetch the CREATE TABLE statement
         cur = connection.query('SHOW CREATE TABLE %s' % full_table_name)
         create_statement = cur.fetchone()
@@ -29,39 +36,62 @@ def load_dependencies(self, connection, full_table_name, primary_key):
             raise DataJointError('Could not load the definition table %s' % full_table_name)
         create_statement = create_statement[1].split('\n')
 
-        # build foreign key parser
         database = full_table_name.split('.')[0].strip('`')
         add_database = lambda string, loc, toc: ['`{database}`.`{table}`'.format(database=database, table=toc[0])]
 
-        parser = pp.CaselessLiteral('CONSTRAINT').suppress()
-        parser += pp.QuotedString('`').suppress()
-        parser += pp.CaselessLiteral('FOREIGN KEY').suppress()
-        parser += pp.QuotedString('(', endQuoteChar=')').setResultsName('attributes')
-        parser += pp.CaselessLiteral('REFERENCES')
-        parser += pp.Or([
+        # primary key parser
+        pk_parser = pp.CaselessLiteral('PRIMARY KEY')
+        pk_parser += pp.QuotedString('(', endQuoteChar=')').setResultsName('primary_key')
+
+        # foreign key parser
+        fk_parser = pp.CaselessLiteral('CONSTRAINT').suppress()
+        fk_parser += pp.QuotedString('`').suppress()
+        fk_parser += pp.CaselessLiteral('FOREIGN KEY').suppress()
+        fk_parser += pp.QuotedString('(', endQuoteChar=')').setResultsName('attributes')
+        fk_parser += pp.CaselessLiteral('REFERENCES')
+        fk_parser += pp.Or([
             pp.QuotedString('`').setParseAction(add_database),
             pp.Combine(pp.QuotedString('`', unquoteResults=False) + '.' + pp.QuotedString('`', unquoteResults=False))
         ]).setResultsName('referenced_table')
-        parser += pp.QuotedString('(', endQuoteChar=')').setResultsName('referenced_attributes')
+        fk_parser += pp.QuotedString('(', endQuoteChar=')').setResultsName('referenced_attributes')
 
         # parse foreign keys
+        primary_key = None
         for line in create_statement:
+            if primary_key is None:
+                try:
+                    result = pk_parser.parseString(line)
+                except pp.ParseException:
+                    pass
+                else:
+                    primary_key = [s.strip(' `') for s in result.primary_key.split(',')]
             try:
-                result = parser.parseString(line)
+                result = fk_parser.parseString(line)
             except pp.ParseException:
                 pass
             else:
+                if not primary_key:
+                    raise DataJointError('No primary key found in %s' % full_table_name)
                 if result.referenced_attributes != result.attributes:
                     raise DataJointError(
                         "%s's foreign key refers to differently named attributes in %s"
                         % (self.__class__.__name__, result.referenced_table))
                 if all(q in primary_key for q in [s.strip('` ') for s in result.attributes.split(',')]):
-                    self._parents[full_table_name].add(result.referenced_table)
-                    self._children[result.referenced_table].add(full_table_name)
+                    self._parents[full_table_name].append(result.referenced_table)
+                    self._children[result.referenced_table].append(full_table_name)
                 else:
-                    self._referenced[full_table_name].add(result.referenced_table)
-                    self._references[result.referenced_table].add(full_table_name)
+                    self._referenced[full_table_name].append(result.referenced_table)
+                    self._references[result.referenced_table].append(full_table_name)
+
+    def clear_dependencies(self, full_table_name):
+        for ref in self._parents.pop(full_table_name, []):
+            if full_table_name in self._children[ref]:
+                self._children[ref].remove(full_table_name)
+        for ref in self._referenced.pop(full_table_name, []):
+            if full_table_name in self._references[ref]:
+                self._references[ref].remove(full_table_name)
 
     @property
     def parents(self):
@@ -79,7 +109,22 @@ def references(self):
     def referenced(self):
         return self._referenced
 
+    def get_descendants(self, full_table_name):
+        """
+        :param full_table_name: a table name in the format `database`.`table_name`
+        :return: list of all children and references, in order of dependence.
+            This is helpful for cascading delete or drop operations.
+        """
+        ret = defaultdict(lambda: 0)
+
+        def recurse(full_table_name, level):
+            if level > ret[full_table_name]:
+                ret[full_table_name] = level
+            for child in self.children[full_table_name] + self.references[full_table_name]:
+                recurse(child, level+1)
+
+        recurse(full_table_name, 0)
+        return sorted(ret.keys(), key=ret.__getitem__)
 
 
 def to_camel_case(s):
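A rough sketch of how the dependency bookkeeping above is used; the table name is hypothetical, and `connection.erd` is the `ERD` instance that `relation.py` consults:

```python
import datajoint as dj

connection = dj.conn()
connection.erd.load_dependencies(connection, '`mydb`.`session`')
# children and referencing tables, with dependents ordered last -- the order
# in which cascading delete/drop can pop them off safely in reverse
print(connection.erd.get_descendants('`mydb`.`session`'))
```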
diff --git a/datajoint/heading.py b/datajoint/heading.py
index 5c332fff6..85d54ee3e 100644
--- a/datajoint/heading.py
+++ b/datajoint/heading.py
@@ -5,50 +5,57 @@
 from datajoint import DataJointError
 
 
+class Attribute(namedtuple('Attribute',
+                           ('name', 'type', 'in_key', 'nullable', 'default',
+                            'comment', 'autoincrement', 'numeric', 'string', 'is_blob',
+                            'computation', 'dtype'))):
+    def _asdict(self):
+        """
+        for some reason the inherited _asdict does not work after subclassing from namedtuple
+        """
+        return OrderedDict((name, self[i]) for i, name in enumerate(self._fields))
+
+    def sql(self):
+        """
+        Convert attribute tuple into its SQL CREATE TABLE clause.
+
+        :rtype: SQL code
+        """
+        literals = ['CURRENT_TIMESTAMP']
+        if self.nullable:
+            default = 'DEFAULT NULL'
+        else:
+            default = 'NOT NULL'
+            if self.default:
+                # enclose value in quotes except special SQL values or already enclosed
+                quote = self.default.upper() not in literals and self.default[0] not in '"\''
+                default += ' DEFAULT ' + ('"%s"' if quote else "%s") % self.default
+        if any((c in r'\"' for c in self.comment)):
+            raise DataJointError('Illegal characters in attribute comment "%s"' % self.comment)
+        return '`{name}` {type} {default} COMMENT "{comment}"'.format(
+            name=self.name, type=self.type, default=default, comment=self.comment)
+
+
 class Heading:
     """
     Local class for relations' headings.
     Heading contains the property attributes, which is an OrderedDict in which the keys are
-    the attribute names and the values are AttrTuples.
+    the attribute names and the values are Attributes.
     """
 
-    class AttrTuple(namedtuple('AttrTuple',
-                               ('name', 'type', 'in_key', 'nullable', 'default',
-                                'comment', 'autoincrement', 'numeric', 'string', 'is_blob',
-                                'computation', 'dtype'))):
-        def _asdict(self):
-            """
-            for some reason the inherted _asdict does not work after subclassing from namedtuple
-            """
-            return OrderedDict((name, self[i]) for i, name in enumerate(self._fields))
-
-        def sql(self):
-            """
-            Convert attribute tuple into its SQL CREATE TABLE clause.
-            :rtype : SQL code
-            """
-            literals = ['CURRENT_TIMESTAMP']
-            if self.nullable:
-                default = 'DEFAULT NULL'
-            else:
-                default = 'NOT NULL'
-                if self.default:
-                    # enclose value in quotes except special SQL values or already enclosed
-                    quote = self.default.upper() not in literals and self.default[0] not in '"\''
-                    default += ' DEFAULT ' + ('"%s"' if quote else "%s") % self.default
-            if any((c in r'\"' for c in self.comment)):
-                raise DataJointError('Illegal characters in attribute comment "%s"' % self.comment)
-            return '`{name}` {type} {default} COMMENT "{comment}"'.format(
-                name=self.name, type=self.type, default=default, comment=self.comment)
-
     def __init__(self, attributes=None):
         """
-        :param attributes: a list of dicts with the same keys as AttrTuple
+        :param attributes: a list of dicts with the same keys as Attribute
         """
         if attributes:
-            attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes])
+            attributes = OrderedDict([(q['name'], Attribute(**q)) for q in attributes])
         self.attributes = attributes
 
+    def reset(self):
+        self.attributes = None
+
+    def __len__(self):
+        return 0 if self.attributes is None else len(self.attributes)
+
     def __bool__(self):
         return self.attributes is not None
@@ -194,7 +201,7 @@ def init_from_database(self, conn, database, table_name):
             t = re.sub(r' unsigned$', '', t)  # remove unsigned
             assert (t, is_unsigned) in numeric_types, 'dtype not found for type %s' % t
             attr['dtype'] = numeric_types[(t, is_unsigned)]
-        self.attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes])
+        self.attributes = OrderedDict([(q['name'], Attribute(**q)) for q in attributes])
 
     def project(self, *attribute_list, **renamed_attributes):
         """
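`Attribute` is now a module-level class rather than the nested `Heading.AttrTuple`. A sketch of one attribute rendered as SQL; the field values are illustrative:

```python
from datajoint.heading import Attribute

attr = Attribute(
    name='session_date', type='date', in_key=False, nullable=False,
    default=None, comment='date of recording', autoincrement=False,
    numeric=False, string=False, is_blob=False, computation=None, dtype=object)
print(attr.sql())  # `session_date` date NOT NULL COMMENT "date of recording"
```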
diff --git a/datajoint/jobs.py b/datajoint/jobs.py
new file mode 100644
index 000000000..db41e0eac
--- /dev/null
+++ b/datajoint/jobs.py
@@ -0,0 +1,105 @@
+import hashlib
+import os
+import pymysql
+from .relation import Relation
+
+
+def key_hash(key):
+    """
+    32-character hash used for lookup of primary keys of jobs
+    """
+    hashed = hashlib.md5()
+    for k, v in sorted(key.items()):
+        hashed.update(str(v).encode())
+    return hashed.hexdigest()
+
+
+class JobRelation(Relation):
+    """
+    A base relation with no definition. Allows reserving jobs
+    """
+
+    def __init__(self, connection, database):
+        self.database = database
+        self._table_name = '~jobs'
+        self._connection = connection
+        self._definition = """    # job reservation table
+        table_name  :varchar(255)  # className of the table
+        key_hash    :char(32)      # key hash
+        ---
+        status  :enum('reserved','error','ignore')  # if tuple is missing, the job is available
+        key=null  :blob  # structure containing the key
+        error_message=""  :varchar(1023)  # error message returned if failed
+        error_stack=null  :blob  # error stack if failed
+        host=""  :varchar(255)  # system hostname
+        pid=0  :int unsigned  # system process id
+        timestamp=CURRENT_TIMESTAMP  :timestamp  # automatic timestamp
+        """
+
+    @property
+    def definition(self):
+        return self._definition
+
+    @property
+    def connection(self):
+        return self._connection
+
+    @property
+    def table_name(self):
+        return self._table_name
+
+    def reserve(self, table_name, key):
+        """
+        Reserve a job for computation. When a job is reserved, the job table contains an entry for the
+        job key, identified by its hash. When jobs are completed, the entry is removed.
+
+        :param table_name: `database`.`table_name`
+        :param key: the dict of the job's primary key
+        :return: True if the job was reserved successfully; False if the job is already taken
+        """
+        job_key = dict(table_name=table_name, key_hash=key_hash(key))
+        try:
+            self.insert1(
+                dict(job_key,
+                     status="reserved",
+                     host=os.uname().nodename,
+                     pid=os.getpid()))
+        except pymysql.err.IntegrityError:
+            return False
+        else:
+            return True
+
+    def complete(self, table_name, key):
+        """
+        Log a completed job. When a job is completed, its reservation entry is deleted.
+
+        :param table_name: `database`.`table_name`
+        :param key: the dict of the job's primary key
+        """
+        job_key = dict(table_name=table_name, key_hash=key_hash(key))
+        (self & job_key).delete_quick()
+
+    def error(self, table_name, key, error_message):
+        """
+        Log an error message. The job reservation is replaced with an error entry
+        describing the problem.
+
+        :param table_name: `database`.`table_name`
+        :param key: the dict of the job's primary key
+        :param error_message: string error message
+        """
+        job_key = dict(table_name=table_name, key_hash=key_hash(key))
+        self.insert1(
+            dict(job_key,
+                 status="error",
+                 host=os.uname().nodename,
+                 pid=os.getpid(),
+                 error_message=error_message), replace=True)
+
+
+class JobManager:
+    def __init__(self, connection):
+        self.connection = connection
+        self._jobs = {}
+
+    def __getitem__(self, database):
+        if database not in self._jobs:
+            self._jobs[database] = JobRelation(self.connection, database)
+        return self._jobs[database]
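A sketch of the reservation protocol that `populate(reserve_jobs=True)` follows, using the API of the new module; the database and table names are hypothetical:

```python
import datajoint as dj

connection = dj.conn()
jobs = connection.jobs['mylab_computed']   # one JobRelation (`~jobs`) per database
key = dict(subject_id=1, experiment_id=3)

if jobs.reserve('__spike_rate', key):      # False if another worker holds the job
    try:
        pass  # ... compute and insert here ...
    except Exception as err:
        jobs.error('__spike_rate', key, error_message=str(err))
    else:
        jobs.complete('__spike_rate', key)  # deletes the reservation entry
```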
Check" - " permissions.".format(database=database)) - - def decorator(cls): - """ - The decorator declares the table and binds the class to the database table - """ - cls.database = database - cls._connection = connection - cls._heading = Heading() - instance = cls() if isinstance(cls, type) else cls - if not instance.heading: - connection.query( - declare( - full_table_name=instance.full_table_name, - definition=instance.definition, - context=context)) - connection.erd.load_dependencies(connection, instance.full_table_name, instance.primary_key) - return cls - - return decorator - - class Relation(RelationalOperand, metaclass=abc.ABCMeta): """ Relation is an abstract class that represents a base relation, i.e. a table in the database. @@ -66,6 +20,8 @@ class Relation(RelationalOperand, metaclass=abc.ABCMeta): table name, database, context, and definition. A Relation implements insert and delete methods in addition to inherited relational operators. """ + _heading = None + _context = None # ---------- abstract properties ------------ # @property @@ -91,8 +47,20 @@ def connection(self): @property def heading(self): - if not self._heading and self.is_declared: - self._heading.init_from_database(self.connection, self.database, self.table_name) + """ + Get the table headng. + If the table is not declared, attempts to declare it and return heading. + :return: + """ + if self._heading is None: + self._heading = Heading() # instance-level heading + if not self._heading: + if not self.is_declared: + self.connection.query( + declare(self.full_table_name, self.definition, self._context)) + if self.is_declared: + self.connection.erd.load_dependencies(self.connection, self.full_table_name) + self._heading.init_from_database(self.connection, self.database, self.table_name) return self._heading @property @@ -102,15 +70,6 @@ def from_clause(self): """ return self.full_table_name - def iter_insert(self, rows, **kwargs): - """ - Inserts a collection of tuples. Additional keyword arguments are passed to insert. - - :param iter: Must be an iterator that generates a sequence of valid arguments for insert. - """ - for row in rows: - self.insert(row, **kwargs) - # ------------- dependencies ---------- # @property def parents(self): @@ -128,6 +87,16 @@ def references(self): def referenced(self): return self.connection.erd.referenced[self.full_table_name] + @property + def descendants(self): + """ + :return: list of relation objects for all children and references, recursively, + in order of dependence. + This is helpful for cascading delete or drop operations. + """ + relations = (FreeRelation(self.connection, table) + for table in self.connection.erd.get_descendants(self.full_table_name)) + return [relation for relation in relations if relation.is_declared] # --------- SQL functionality --------- # @property @@ -137,55 +106,64 @@ def is_declared(self): database=self.database, table_name=self.table_name)) return cur.rowcount > 0 - def batch_insert(self, data, **kwargs): - """ - Inserts an entire batch of entries. Additional keyword arguments are passed to insert. - - :param data: must be iterable, each row must be a valid argument for insert - """ - self.iter_insert(data.__iter__(), **kwargs) - @property def full_table_name(self): return r"`{0:s}`.`{1:s}`".format(self.database, self.table_name) - def insert(self, tup, ignore_errors=False, replace=False): + def insert(self, rows, **kwargs): """ - Insert one data record or one Mapping (like a dictionary). + Inserts a collection of tuples. 
Additional keyword arguments are passed to insert1. + + :param iter: Must be an iterator that generates a sequence of valid arguments for insert. + """ + for row in rows: + self.insert1(row, **kwargs) - :param tup: Data record, or a Mapping (like a dictionary). - :param ignore_errors=False: Ignores errors if True. + def insert1(self, tup, replace=False, ignore_errors=False): + """ + Insert one data record or one Mapping (like a dict). + + :param tup: Data record, a Mapping (like a dict), or a list or tuple with ordered values. :param replace=False: Replaces data tuple if True. + :param ignore_errors=False: If True, ignore errors: e.g. constraint violations or duplicates Example:: - - b = djtest.Subject() - b.insert(dict(subject_id = 7, species="mouse",\\ - real_id = 1007, date_of_birth = "2014-09-01")) + relation.insert1(dict(subject_id=7, species="mouse", date_of_birth="2014-09-01")) """ - heading = self.heading - if isinstance(tup, np.void): + + if isinstance(tup, np.void): # np.array insert for fieldname in tup.dtype.fields: if fieldname not in heading: - raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname, )) + raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname)) value_list = ','.join([repr(tup[name]) if not heading[name].is_blob else '%s' for name in heading if name in tup.dtype.fields]) - args = tuple(pack(tup[name]) for name in heading if name in tup.dtype.fields and heading[name].is_blob) attribute_list = '`' + '`,`'.join(q for q in heading if q in tup.dtype.fields) + '`' - elif isinstance(tup, Mapping): + + elif isinstance(tup, Mapping): # dict-based insert for fieldname in tup.keys(): if fieldname not in heading: - raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname, )) + raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname)) value_list = ','.join(repr(tup[name]) if not heading[name].is_blob else '%s' for name in heading if name in tup) args = tuple(pack(tup[name]) for name in heading if name in tup and heading[name].is_blob) attribute_list = '`' + '`,`'.join(name for name in heading if name in tup) + '`' - else: - raise DataJointError('Datatype %s cannot be inserted' % type(tup)) + + else: # positional insert + try: + if len(tup) != len(self.heading): + raise DataJointError( + 'Tuple size does not match the number of relation attributes') + except TypeError: + raise DataJointError('Datatype %s cannot be inserted' % type(tup)) + else: + pairs = zip(heading, tup) + value_list = ','.join('%s' if heading[name].is_blob else repr(value) for name, value in pairs) + attribute_list = '`' + '`,`'.join(heading.names) + '`' + args = tuple(pack(value) for name, value in pairs if heading[name].is_blob) if replace: sql = 'REPLACE' elif ignore_errors: @@ -196,94 +174,101 @@ def insert(self, tup, ignore_errors=False, replace=False): logger.info(sql) self.connection.query(sql, args=args) + def delete_quick(self): + """ + delete without cascading and without user prompt + """ + self.connection.query('DELETE FROM ' + self.from_clause + self.where_clause) + def delete(self): - if not config['safemode'] or user_choice( - "You are about to delete data from a table. This operation cannot be undone.\n" - "Proceed?", default='no') == 'yes': - self.connection.query('DELETE FROM ' + self.from_clause + self.where_clause) # TODO: make cascading (issue #15) + """ + Delete the contents of the table and its dependent tables, recursively. 
+ User is prompted for confirmation if config['safemode'] + """ + relations = self.descendants + if self.restrictions and len(relations)>1: + raise NotImplementedError('Restricted cascading deletes are not yet implemented') + do_delete = True + if config['safemode']: + do_delete = False + print('The contents of the following tables are about to be deleted:') + for relation in relations: + count = len(relation) + if count: + do_delete = True + print(relation.full_table_name, '(%d tuples)' % count) + do_delete = do_delete and user_choice("Proceed?", default='no') == 'yes' + if do_delete: + with self.connection.transaction: + while relations: + relations.pop().delete_quick() + + def drop_quick(self): + """ + Drops the table associated with this relation without cascading and without user prompt. + """ + if self.is_declared: + self.connection.query('DROP TABLE %s' % self.full_table_name) + self.connection.erd.clear_dependencies(self.full_table_name) + if self._heading: + self._heading.reset() + logger.info("Dropped table %s" % self.full_table_name) def drop(self): """ - Drops the table associated to this class. + Drop the table and all tables that reference it, recursively. + User is prompted for confirmation if config['safemode'] """ - if self.is_declared: - if not config['safemode'] or user_choice( - "You are about to drop an entire table. This operation cannot be undone.\n" - "Proceed?", default='no') == 'yes': - self.connection.query('DROP TABLE %s' % self.full_table_name) # TODO: make cascading (issue #16) - # cls.connection.clear_dependencies(dbname=cls.dbname) #TODO: reimplement because clear_dependencies will be gone - # cls.connection.load_headings(dbname=cls.dbname, force=True) #TODO: reimplement because load_headings is gone - logger.info("Dropped table %s" % self.full_table_name) + do_drop = True + relations = self.descendants + if config['safemode']: + print('The following tables are about to be dropped:') + for relation in relations: + print(relation.full_table_name, '(%d tuples)' % len(relation)) + do_drop = user_choice("Proceed?", default='no') == 'yes' + if do_drop: + while relations: + relations.pop().drop_quick() + print('Tables dropped.') + @property def size_on_disk(self): """ - :return: size of data and indices in GiB taken by the table on the storage device + :return: size of data and indices in bytes on the storage device """ ret = self.connection.query( 'SHOW TABLE STATUS FROM `{database}` WHERE NAME="{table}"'.format( database=self.database, table=self.table_name), as_dict=True ).fetchone() - return (ret['Data_length'] + ret['Index_length'])/1024**2 - - def set_table_comment(self, comment): - """ - Update the table comment in the table definition. - :param comment: new comment as string - """ - self._alter('COMMENT="%s"' % comment) + return ret['Data_length'] + ret['Index_length'] - def add_attribute(self, definition, after=None): + # --------- functionality used by the decorator --------- + def prepare(self): """ - Add a new attribute to the table. A full line from the table definition - is passed in as definition. - - The definition can specify where to place the new attribute. Use after=None - to add the attribute as the first attribute or after='attribute' to place it - after an existing attribute. - - :param definition: table definition - :param after=None: After which attribute of the table the new attribute is inserted. - If None, the attribute is inserted in front. + This method is overridden by the user_relations subclasses. 
It is called on an instance + once when the class is declared. """ - position = ' FIRST' if after is None else ( - ' AFTER %s' % after if after else '') - sql = compile_attribute(definition)[1] - self._alter('ADD COLUMN %s%s' % (sql, position)) + pass - def drop_attribute(self, attribute_name): - """ - Drops the attribute attrName from this table. - :param attribute_name: Name of the attribute that is dropped. - """ - if not config['safemode'] or user_choice( - "You are about to drop an attribute from a table." - "This operation cannot be undone.\n" - "Proceed?", default='no') == 'yes': - self._alter('DROP COLUMN `%s`' % attribute_name) - def alter_attribute(self, attribute_name, definition): - """ - Alter attribute definition +class FreeRelation(Relation): + """ + A base relation without a dedicated class. The table name is explicitly set. + """ + def __init__(self, connection, full_table_name, definition=None, context=None): + self.database, self._table_name = (s.strip('`') for s in full_table_name.split('.')) + self._connection = connection + self._definition = definition + self._context = context - :param attribute_name: field that is redefined - :param definition: new definition of the field - """ - sql = compile_attribute(definition)[1] - self._alter('CHANGE COLUMN `%s` %s' % (attribute_name, sql)) + @property + def definition(self): + return self._definition - def erd(self, subset=None): - """ - Plot the schema's entity relationship diagram (ERD). - """ - NotImplemented + @property + def connection(self): + return self._connection - def _alter(self, alter_statement): - """ - Execute an ALTER TABLE statement. - :param alter_statement: alter statement - """ - if self.connection.in_transaction: - raise DataJointError("Table definition cannot be altered during a transaction.") - sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement) - self.connection.query(sql) - self.heading.init_from_database(self.connection, self.database, self.table_name) + @property + def table_name(self): + return self._table_name diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index 367fd2d1e..a8f633c28 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -95,6 +95,7 @@ def project(self, *attributes, **renamed_attributes): def aggregate(self, _group, *attributes, **renamed_attributes): """ Relational aggregation operator + :param group: relation whose tuples can be used in aggregation operators :param extensions: :return: a relation representing the aggregation/projection operator result @@ -103,7 +104,7 @@ def aggregate(self, _group, *attributes, **renamed_attributes): raise DataJointError('The second argument must be a relation or None') return Projection(self, _group, *attributes, **renamed_attributes) - def __iand__(self, restriction): + def _restrict(self, restriction): """ in-place relational restriction or semijoin """ @@ -113,6 +114,12 @@ def __iand__(self, restriction): self._restrictions.append(restriction) return self + def __iand__(self, restriction): + """ + in-place relational restriction or semijoin + """ + return self._restrict(restriction) + def __and__(self, restriction): """ relational restriction or semijoin @@ -151,7 +158,7 @@ def __len__(self): def __contains__(self, item): """ - "item in relation" is equivalient to "len(relation & item)>0" + "item in relation" is equivalent to "len(relation & item)>0" """ return len(self & item) > 0 @@ -161,6 +168,41 @@ def __call__(self, *args, **kwargs): """ return 
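The insert API above splits into `insert1` for a single record (mapping, record-array row, or ordered sequence) and `insert` for an iterable of such records. A sketch against the `Subject` table defined in `tests/schema.py` further down in this diff; the values are illustrative:

```python
subject = Subject()
subject.insert1(dict(subject_id=7, real_id='F-007', species='mouse',
                     date_of_birth='2014-09-01', subject_notes=''))
# positional inserts are new: values must follow the heading's attribute order
subject.insert([
    (8, 'F-008', 'mouse', '2014-09-02', ''),
    (9, 'F-009', 'mouse', '2014-09-03', ''),
], ignore_errors=True)
```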
diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py
index 367fd2d1e..a8f633c28 100644
--- a/datajoint/relational_operand.py
+++ b/datajoint/relational_operand.py
@@ -95,6 +95,7 @@ def project(self, *attributes, **renamed_attributes):
     def aggregate(self, _group, *attributes, **renamed_attributes):
         """
         Relational aggregation operator
+
         :param group: relation whose tuples can be used in aggregation operators
         :param extensions:
         :return: a relation representing the aggregation/projection operator result
@@ -103,7 +104,7 @@ def aggregate(self, _group, *attributes, **renamed_attributes):
             raise DataJointError('The second argument must be a relation or None')
         return Projection(self, _group, *attributes, **renamed_attributes)
 
-    def __iand__(self, restriction):
+    def _restrict(self, restriction):
         """
         in-place relational restriction or semijoin
         """
@@ -113,6 +114,12 @@ def __iand__(self, restriction):
             self._restrictions.append(restriction)
         return self
 
+    def __iand__(self, restriction):
+        """
+        in-place relational restriction or semijoin
+        """
+        return self._restrict(restriction)
+
     def __and__(self, restriction):
         """
         relational restriction or semijoin
@@ -151,7 +158,7 @@ def __len__(self):
 
     def __contains__(self, item):
         """
-        "item in relation" is equivalient to "len(relation & item)>0"
+        "item in relation" is equivalent to "len(relation & item)>0"
         """
         return len(self & item) > 0
 
@@ -161,6 +168,41 @@ def __call__(self, *args, **kwargs):
         """
         return self.fetch(*args, **kwargs)
 
+    def cursor(self, offset=0, limit=None, order_by=None, descending=False, as_dict=False):
+        """
+        Return query cursor.
+        See Relation.fetch() for input description.
+
+        :return: cursor to the query
+        """
+        if offset and limit is None:
+            raise DataJointError('limit is required when offset is set')
+        sql = self.make_select()
+        if order_by is not None:
+            sql += ' ORDER BY ' + ', '.join(order_by)
+        if descending:
+            sql += ' DESC'
+        if limit is not None:
+            sql += ' LIMIT %d' % limit
+        if offset:
+            sql += ' OFFSET %d' % offset
+        logger.debug(sql)
+        return self.connection.query(sql, as_dict=as_dict)
+
+    def __repr__(self):
+        limit = config['display.limit']
+        width = config['display.width']
+        rel = self.project(*self.heading.non_blobs)  # project out blobs
+        template = '%%-%d.%ds' % (width, width)
+        columns = rel.heading.names
+        repr_string = ' '.join([template % column for column in columns]) + '\n'
+        repr_string += ' '.join(['+' + '-' * (width - 2) + '+' for _ in columns]) + '\n'
+        for tup in rel.fetch(limit=limit):
+            repr_string += ' '.join([template % column for column in tup]) + '\n'
+        if len(self) > limit:
+            repr_string += '...\n'
+        repr_string += ' (%d tuples)\n' % len(self)
+        return repr_string
+
     def fetch1(self):
         """
         This version of fetch is called when self is expected to contain exactly one tuple.
@@ -197,53 +239,36 @@ def fetch(self, offset=0, limit=None, order_by=None, descending=False, as_dict=F
             ret[blob_name] = list(map(unpack, ret[blob_name]))
         return ret
 
-    def cursor(self, offset=0, limit=None, order_by=None, descending=False, as_dict=False):
+    def values(self, offset=0, limit=None, order_by=None, descending=False, as_dict=False):
         """
-        Return query cursor.
-        See Relation.fetch() for input description.
-
-        :return: cursor to the query
+        Iterator that yields the contents of the relation.
         """
-        if offset and limit is None:
-            raise DataJointError('limit is required when offset is set')
-        sql = self.make_select()
-        if order_by is not None:
-            sql += ' ORDER BY ' + ', '.join(order_by)
-        if descending:
-            sql += ' DESC'
-        if limit is not None:
-            sql += ' LIMIT %d' % limit
-        if offset:
-            sql += ' OFFSET %d' % offset
-        logger.debug(sql)
-        return self.connection.query(sql, as_dict=as_dict)
+        cur = self.cursor(offset=offset, limit=limit, order_by=order_by,
+                          descending=descending, as_dict=as_dict)
+        heading = self.heading
+        do_unpack = tuple(h in heading.blobs for h in heading.names)
+        values = cur.fetchone()
+        while values:
+            if as_dict:
+                yield OrderedDict((field_name, unpack(values[field_name])) if up else (field_name, values[field_name])
+                                  for field_name, up in zip(heading.names, do_unpack))
+            else:
+                yield tuple(unpack(value) if up else value for up, value in zip(do_unpack, values))
+            values = cur.fetchone()
 
-    def __repr__(self):
-        limit = config['display.limit']
-        width = config['display.width']
-        rel = self.project(*self.heading.non_blobs)  # project out blobs
-        template = '%%-%d.%ds' % (width, width)
-        columns = rel.heading.names
-        repr_string = ' '.join([template % column for column in columns]) + '\n'
-        repr_string += ' '.join(['+' + '-' * (width - 2) + '+' for _ in columns]) + '\n'
-        for tup in rel.fetch(limit=limit):
-            repr_string += ' '.join([template % column for column in tup]) + '\n'
-        if len(self) > limit:
-            repr_string += '...\n'
-        repr_string += ' (%d tuples)\n' % len(self)
-        return repr_string
 
     def __iter__(self):
         """
         Iterator that yields individual tuples of the current relation as dictionaries.
         """
-        cur = self.cursor()
-        heading = self.heading  # construct once for efficiency
-        do_unpack = tuple(h in heading.blobs for h in heading.names)
-        values = cur.fetchone()
-        while values:
-            yield {field_name: unpack(value) if up else value
-                   for field_name, up, value in zip(heading.names, do_unpack, values)}
-            values = cur.fetchone()
+        yield from self.values(as_dict=True)
+
+    def keys(self, *args, **kwargs):
+        """
+        Iterator that returns primary keys.
+        """
+        yield from self.project().values(*args, **kwargs)
 
     @property
     def where_clause(self):
@@ -272,7 +297,7 @@ def make_condition(arg):
             elif isinstance(r, np.ndarray) or isinstance(r, list):
                 r = '(' + ') OR ('.join([make_condition(q) for q in r]) + ')'
             elif isinstance(r, RelationalOperand):
-                common_attributes = ','.join([q for q in self.heading.names if r.heading.names])
+                common_attributes = ','.join([q for q in self.heading.names if q in r.heading.names])
                 r = '(%s) in (SELECT %s FROM %s%s)' % (
                     common_attributes, common_attributes, r.from_clause, r.where_clause)
@@ -372,6 +397,15 @@ def from_clause(self):
             self._arg.from_clause, self._group.from_clause,
             '`,`'.join(self.heading.primary_key))
 
+    def _restrict(self, restriction):
+        """
+        Projection is enclosed in Subquery when restricted if it has renamed attributes
+        """
+        if self.heading.computed:
+            return Subquery(self) & restriction
+        else:
+            return super()._restrict(restriction)
+
 
 class Subquery(RelationalOperand):
     """
Check" + " permissions.".format(database=database)) + + def __call__(self, cls): + """ + The decorator binds its argument class object to a database + :param cls: class to be decorated + """ + # class-level attributes + cls.database = self.database + cls._connection = self.connection + cls._heading = Heading() + cls._context = self.context + + # trigger table declaration by requesting the heading from an instance + instance = cls() + instance.heading + instance.prepare() + return cls + + @property + def jobs(self): + return self.connection.jobs[self.database] diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 1b0352a03..26a21b0db 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -1,3 +1,7 @@ +""" +Hosts the table tiers, user relations should be derived from. +""" + import re import abc from datajoint.relation import Relation @@ -6,26 +10,65 @@ class Manual(Relation): + """ + Inherit from this class if the table's values are entered manually. + """ + @property def table_name(self): + """ + :returns: the table name of the table formatted for mysql. + """ return from_camel_case(self.__class__.__name__) class Lookup(Relation): + """ + Inherit from this class if the table's values are for lookup. This is + currently equivalent to defining the table as Manual and serves semantic + purposes only. + """ + @property def table_name(self): + """ + :returns: the table name of the table formatted for mysql. + """ return '#' + from_camel_case(self.__class__.__name__) + def prepare(self): + """ + Checks whether the instance has a property called `contents` and inserts its elements. + """ + if hasattr(self, 'contents'): + self.insert(self.contents, ignore_errors=True) + class Imported(Relation, AutoPopulate): + """ + Inherit from this class if the table's values are imported from external data sources. + The inherited class must at least provide the function `_make_tuples`. + """ + @property def table_name(self): + """ + :returns: the table name of the table formatted for mysql. + """ return "_" + from_camel_case(self.__class__.__name__) class Computed(Relation, AutoPopulate): + """ + Inherit from this class if the table's values are computed from other relations in the schema. + The inherited class must at least provide the function `_make_tuples`. + """ + @property def table_name(self): + """ + :returns: the table name of the table formatted for mysql. + """ return "__" + from_camel_case(self.__class__.__name__) @@ -33,11 +76,25 @@ class Subordinate: """ Mix-in to make computed tables subordinate """ + @property - def populate_relation(self): + def populated_from(self): + """ + Overrides the `populate_from` property because subtables should not be populated + directly. + + :return: None + """ return None def _make_tuples(self, key): + """ + Overrides the `_make_tuples` property because subtables should not be populated + directly. Raises an error if this method is called (usually from populate of the + inheriting object). 
diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py
index 1b0352a03..26a21b0db 100644
--- a/datajoint/user_relations.py
+++ b/datajoint/user_relations.py
@@ -1,3 +1,7 @@
+"""
+Hosts the table tiers from which user relations should be derived.
+"""
+
 import re
 import abc
 from datajoint.relation import Relation
@@ -6,26 +10,65 @@
 
 
 class Manual(Relation):
+    """
+    Inherit from this class if the table's values are entered manually.
+    """
+
     @property
     def table_name(self):
+        """
+        :returns: the table name formatted for MySQL.
+        """
         return from_camel_case(self.__class__.__name__)
 
 
 class Lookup(Relation):
+    """
+    Inherit from this class if the table's values are for lookup. This is
+    currently equivalent to defining the table as Manual and serves semantic
+    purposes only.
+    """
+
     @property
     def table_name(self):
+        """
+        :returns: the table name formatted for MySQL.
+        """
         return '#' + from_camel_case(self.__class__.__name__)
 
+    def prepare(self):
+        """
+        Checks whether the instance has a property called `contents` and inserts its elements.
+        """
+        if hasattr(self, 'contents'):
+            self.insert(self.contents, ignore_errors=True)
+
 
 class Imported(Relation, AutoPopulate):
+    """
+    Inherit from this class if the table's values are imported from external data sources.
+    The inherited class must at least provide the function `_make_tuples`.
+    """
+
     @property
     def table_name(self):
+        """
+        :returns: the table name formatted for MySQL.
+        """
         return "_" + from_camel_case(self.__class__.__name__)
 
 
 class Computed(Relation, AutoPopulate):
+    """
+    Inherit from this class if the table's values are computed from other relations in the schema.
+    The inherited class must at least provide the function `_make_tuples`.
+    """
+
     @property
     def table_name(self):
+        """
+        :returns: the table name formatted for MySQL.
+        """
         return "__" + from_camel_case(self.__class__.__name__)
 
 
@@ -33,11 +76,25 @@ class Subordinate:
     """
     Mix-in to make computed tables subordinate
     """
+
     @property
-    def populate_relation(self):
+    def populated_from(self):
+        """
+        Overrides the `populated_from` property because subtables should not be populated
+        directly.
+
+        :return: None
+        """
         return None
 
     def _make_tuples(self, key):
+        """
+        Overrides the `_make_tuples` method because subtables should not be populated
+        directly. Raises an error if this method is called (usually from populate of the
+        inheriting object).
+
+        :raises: NotImplementedError
+        """
         raise NotImplementedError('Subtables should not be populated directly.')
 
 
@@ -50,6 +107,7 @@ def from_camel_case(s):
     >>>from_camel_case("TableName")
     "table_name"
     """
+
     def convert(match):
         return ('_' if match.groups()[0] else '') + match.group(0).lower()
 
@@ -57,4 +115,3 @@ def convert(match):
         raise DataJointError(
             'ClassName must be alphanumeric in CamelCase, begin with a capital letter')
     return re.sub(r'(\B[A-Z])|(\b[A-Z])', convert, s)
-
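The new `prepare` hook on `Lookup` auto-inserts a class-level `contents` attribute at declaration time. A sketch, reusing the hypothetical `schema` object from the previous example:

```python
@schema
class Species(dj.Lookup):
    definition = """
    # allowed species names
    species : varchar(30)
    """
    contents = [['mouse'], ['monkey'], ['human']]  # inserted by prepare()
```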
diff --git a/datajoint/utils.py b/datajoint/utils.py
index f4b0edb57..a142e2707 100644
--- a/datajoint/utils.py
+++ b/datajoint/utils.py
@@ -1,10 +1,12 @@
+import numpy as np
+
 def user_choice(prompt, choices=("yes", "no"), default=None):
     """
     Prompts the user for confirmation. The default value, if any, is capitalized.
 
     :param prompt: Information to display to the user.
     :param choices: an iterable of possible choices.
-    :param default=None: default choice
+    :param default: default choice
     :return: the user's choice
     """
     choice_list = ', '.join((choice.title() if choice == default else choice for choice in choices))
@@ -14,3 +16,16 @@ def user_choice(prompt, choices=("yes", "no"), default=None):
         response = response if response else default
         valid = response in choices
     return response
+
+
+def group_by(rel, *attributes, sortby=None):
+    """
+    Yields (value, restricted relation) pairs, one for each distinct value of `attributes` in `rel`.
+    """
+    r = rel.project(*attributes).fetch()
+    dtype2 = np.dtype({name: r.dtype.fields[name] for name in attributes})
+    r2 = np.unique(np.ndarray(r.shape, dtype2, r, 0, r.strides))
+    r2.sort(order=sortby if sortby is not None else attributes)
+    for nk in r2:
+        restr = ' and '.join(["%s='%s'" % (fn, str(v)) for fn, v in zip(r2.dtype.names, nk)])
+        if len(nk) == 1:
+            yield nk[0], rel & restr
+        else:
+            yield nk, rel & restr
\ No newline at end of file
diff --git a/doc/source/autopopulate.rst b/doc/source/autopopulate.rst
new file mode 100644
index 000000000..75d5a632f
--- /dev/null
+++ b/doc/source/autopopulate.rst
@@ -0,0 +1,5 @@
+AutoPopulate
+============
+
+.. automodule:: datajoint.autopopulate
+    :members:
diff --git a/doc/source/blob.rst b/doc/source/blob.rst
new file mode 100644
index 000000000..bea29b965
--- /dev/null
+++ b/doc/source/blob.rst
@@ -0,0 +1,5 @@
+Serialization Module
+====================
+
+.. automodule:: datajoint.blob
+    :members:
diff --git a/doc/source/declare.rst b/doc/source/declare.rst
new file mode 100644
index 000000000..3817225d9
--- /dev/null
+++ b/doc/source/declare.rst
@@ -0,0 +1,6 @@
+Mysql table declaration
+=======================
+
+.. automodule:: datajoint.declare
+    :members:
+    :inherited-members:
diff --git a/doc/source/index.rst b/doc/source/index.rst
index b02acdad1..d70f80705 100644
--- a/doc/source/index.rst
+++ b/doc/source/index.rst
@@ -11,10 +11,11 @@ API:
 .. toctree::
     :maxdepth: 2
 
+    tiers.rst
     relation.rst
-    relational_operand.rst
     connection.rst
-
+    blob.rst
+    declare.rst
 
 Indices and tables
 ==================
diff --git a/doc/source/tiers.rst b/doc/source/tiers.rst
new file mode 100644
index 000000000..b1ebd1ebc
--- /dev/null
+++ b/doc/source/tiers.rst
@@ -0,0 +1,9 @@
+Datajoint Tiers
+===============
+
+.. toctree::
+    :maxdepth: 2
+
+    user_relations.rst
+    autopopulate.rst
+
diff --git a/doc/source/user_relations.rst b/doc/source/user_relations.rst
new file mode 100644
index 000000000..73724209b
--- /dev/null
+++ b/doc/source/user_relations.rst
@@ -0,0 +1,5 @@
+User Tiers
+==========
+
+.. automodule:: datajoint.user_relations
+    :members:
diff --git a/tests/__init__.py b/tests/__init__.py
index 611f34bc2..2bf44d4cf 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -5,98 +5,44 @@
 after the test.
 """
 
-import pymysql
+__author__ = 'Edgar Walker, Fabian Sinz, Dimitri Yatsenko'
+
 import logging
 from os import environ
 import datajoint as dj
 
-logging.basicConfig(level=logging.DEBUG)
+__all__ = ['__author__', 'PREFIX', 'CONN_INFO']
 
-# Connection information for testing
-CONN_INFO = {
-    'host': environ.get('DJ_TEST_HOST', 'localhost'),
-    'user': environ.get('DJ_TEST_USER', 'datajoint'),
-    'passwd': environ.get('DJ_TEST_PASSWORD', 'datajoint')
-}
+logging.basicConfig(level=logging.DEBUG)
 
-conn = dj.conn(**CONN_INFO)
+# Connection for testing
+CONN_INFO = dict(
+    host=environ.get('DJ_TEST_HOST', 'localhost'),
+    user=environ.get('DJ_TEST_USER', 'datajoint'),
+    passwd=environ.get('DJ_TEST_PASSWORD', 'datajoint'))
 
 # Prefix for all databases used during testing
 PREFIX = environ.get('DJ_TEST_DB_PREFIX', 'djtest')
 
-# Bare connection used for verification of query results
-BASE_CONN = pymysql.connect(**CONN_INFO)
-BASE_CONN.autocommit(True)
 
-def setup():
-    cleanup()
+def setup_package():
+    """
+    Package-level unit test setup
+    """
     dj.config['safemode'] = False
 
-def teardown():
-    cur = BASE_CONN.cursor()
-    # cancel any unfinished transactions
-    cur.execute("ROLLBACK")
-    # start a transaction now
-    cur.execute("SHOW DATABASES LIKE '{}\_%'".format(PREFIX))
-    dbs = [x[0] for x in cur.fetchall()]
-    cur.execute('SET FOREIGN_KEY_CHECKS=0')  # unset foreign key check while deleting
-    for db in dbs:
-        cur.execute('DROP DATABASE `{}`'.format(db))
-    cur.execute('SET FOREIGN_KEY_CHECKS=1')  # set foreign key check back on
-    cur.execute("COMMIT")
 
-def cleanup():
+def teardown_package():
     """
-    Removes all databases with name starting with the prefix.
+    Package-level unit test teardown.
+    Removes all databases with name starting with PREFIX.
+    To deal with possible foreign key constraints, it will unset
+    and then later reset the FOREIGN_KEY_CHECKS flag
     """
-
-    cur = BASE_CONN.cursor()
-    # cancel any unfinished transactions
-    cur.execute("ROLLBACK")
-    # start a transaction now
-    cur.execute("START TRANSACTION WITH CONSISTENT SNAPSHOT")
-    cur.execute("SHOW DATABASES LIKE '{}\_%'".format(PREFIX))
-    dbs = [x[0] for x in cur.fetchall()]
-    cur.execute('SET FOREIGN_KEY_CHECKS=0')  # unset foreign key check while deleting
-    for db in dbs:
-        cur.execute("USE %s" % (db))
-        cur.execute("SHOW TABLES")
-        for table in [x[0] for x in cur.fetchall()]:
-            cur.execute('DELETE FROM `{}`'.format(table))
-    cur.execute('SET FOREIGN_KEY_CHECKS=1')  # set foreign key check back on
-    cur.execute("COMMIT")
-#
-# def setup_sample_db():
-#     """
-#     Helper method to setup databases with tables to be used
-#     during the test
-#     """
-#     cur = BASE_CONN.cursor()
-#     cur.execute("CREATE DATABASE IF NOT EXISTS `{}_test1`".format(PREFIX))
-#     cur.execute("CREATE DATABASE IF NOT EXISTS `{}_test2`".format(PREFIX))
-#     query1 = """
-#     CREATE TABLE `{prefix}_test1`.`subjects`
-#     (
-#       subject_id    SMALLINT      COMMENT 'Unique subject ID',
-#       subject_name  VARCHAR(255)  COMMENT 'Subject name',
-#       subject_email VARCHAR(255)  COMMENT 'Subject email address',
-#       PRIMARY KEY (subject_id)
-#     )
-#     """.format(prefix=PREFIX)
-#     cur.execute(query1)
-#     query2 = """
-#     CREATE TABLE `{prefix}_test2`.`experimenter`
-#     (
-#       experimenter_id    SMALLINT      COMMENT 'Unique experimenter ID',
-#       experimenter_name  VARCHAR(255)  COMMENT 'Experimenter name',
-#       PRIMARY KEY (experimenter_id)
-#     )""".format(prefix=PREFIX)
-#     cur.execute(query2)
-#
-#
-#
-#
-#
-#
+    conn = dj.conn(**CONN_INFO)
+    conn.query('SET FOREIGN_KEY_CHECKS=0')
+    cur = conn.query('SHOW DATABASES LIKE "{}\_%%"'.format(PREFIX))
+    for db in cur.fetchall():
+        conn.query('DROP DATABASE `{}`'.format(db[0]))
+    conn.query('SET FOREIGN_KEY_CHECKS=1')
purpose of experiment + entry_time=CURRENT_TIMESTAMP :timestamp # automatic timestamp + """ + + def _make_tuples(self, key): + """ + populate with random data + """ + from datetime import date, timedelta + experiments_per_subject = 5 + users = User().fetch()['username'] + for experiment_id in range(experiments_per_subject): + self.insert1( + dict(key, + experiment_id=experiment_id, + experiment_date=(date.today() - timedelta(random.expovariate(1/30))).isoformat(), + username=random.choice(users))) + + +@schema +class Trial(dj.Imported): + definition = """ # a trial within an experiment + -> Experiment + trial_id :smallint # trial number + --- + start_time :double # (s) + """ + + def _make_tuples(self, key): + """ + populate with random data (pretend reading from raw files) + """ + for trial_id in range(10): + self.insert1( + dict(key, + trial_id=trial_id, + start_time=random.random()*1e9 + )) + + +@schema +class Ephys(dj.Imported): + definition = """ # some kind of electrophysiological recording + -> Trial + ---- + sampling_frequency :double # (Hz) + duration :double # (s) + """ + + def _make_tuples(self, key): + """ + populate with random data + """ + row = dict(key, + sampling_frequency=16000, + duration=random.expovariate(1/30)) + self.insert1(row) + EphysChannel().fill(key, number_samples=round(row['duration'] * row['sampling_frequency'])) + + +@schema +class EphysChannel(dj.Subordinate, dj.Imported): + definition = """ # subtable containing individual channels + -> Ephys + channel :tinyint unsigned # channel number within Ephys + ---- + voltage :longblob + """ + + def fill(self, key, number_samples): + """ + populate with a random trace of the specified length + """ + import numpy as np + for channel in range(16): + self.insert1( + dict(key, + channel=channel, + voltage=np.float32(np.random.randn(number_samples)) + )) diff --git a/tests/schemata/__init__.py b/tests/schemata/__init__.py deleted file mode 100644 index 9fa8b9ad1..000000000 --- a/tests/schemata/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__author__ = "eywalker, fabiansinz" \ No newline at end of file diff --git a/tests/schemata/schema1/__init__.py b/tests/schemata/schema1/__init__.py deleted file mode 100644 index cae90cec9..000000000 --- a/tests/schemata/schema1/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# __author__ = 'eywalker' -# import datajoint as dj -# -# print(__name__) -# from .test3 import * \ No newline at end of file diff --git a/tests/schemata/schema1/test2.py b/tests/schemata/schema1/test2.py deleted file mode 100644 index aded2e4fb..000000000 --- a/tests/schemata/schema1/test2.py +++ /dev/null @@ -1,48 +0,0 @@ -# """ -# Test 2 Schema definition -# """ -# __author__ = 'eywalker' -# -# import datajoint as dj -# from .
import test1 as alias -# #from ..schema2 import test2 as test1 -# -# -# # references to another schema -# class Experiments(dj.Relation): -# definition = """ -# test2.Experiments (manual) # Basic subject info -# -> test1.Subjects -# experiment_id : int # unique experiment id -# --- -# real_id : varchar(40) # real-world name -# species = "mouse" : enum('mouse', 'monkey', 'human') # species -# """ -# -# -# # references to another schema -# class Conditions(dj.Relation): -# definition = """ -# test2.Conditions (manual) # Subject conditions -# -> alias.Subjects -# condition_name : varchar(255) # description of the condition -# """ -# -# -# class FoodPreference(dj.Relation): -# definition = """ -# test2.FoodPreference (manual) # Food preference of each subject -# -> animals.Subjects -# preferred_food : enum('banana', 'apple', 'oranges') -# """ -# -# -# class Session(dj.Relation): -# definition = """ -# test2.Session (manual) # Experiment sessions -# -> test1.Subjects -# -> test2.Experimenter -# session_id : int # unique session id -# --- -# session_comment : varchar(255) # comment about the session -# """ \ No newline at end of file diff --git a/tests/schemata/schema1/test3.py b/tests/schemata/schema1/test3.py deleted file mode 100644 index e00a01afb..000000000 --- a/tests/schemata/schema1/test3.py +++ /dev/null @@ -1,21 +0,0 @@ -# """ -# Test 3 Schema definition - no binding, no conn -# -# To be bound at the package level -# """ -# __author__ = 'eywalker' -# -# import datajoint as dj -# -# -# class Subjects(dj.Relation): -# definition = """ -# schema1.Subjects (manual) # Basic subject info -# -# subject_id : int # unique subject id -# dob : date # date of birth -# --- -# real_id : varchar(40) # real-world name -# species = "mouse" : enum('mouse', 'monkey', 'human') # species -# """ -# diff --git a/tests/schemata/schema1/test4.py b/tests/schemata/schema1/test4.py deleted file mode 100644 index 9860cb030..000000000 --- a/tests/schemata/schema1/test4.py +++ /dev/null @@ -1,17 +0,0 @@ -# """ -# Test 1 Schema definition - fully bound and has connection object -# """ -# __author__ = 'fabee' -# -# import datajoint as dj -# -# -# class Matrix(dj.Relation): -# definition = """ -# test4.Matrix (manual) # Some numpy array -# -# matrix_id : int # unique matrix id -# --- -# data : longblob # data -# comment : varchar(1000) # comment -# """ diff --git a/tests/schemata/schema2/__init__.py b/tests/schemata/schema2/__init__.py deleted file mode 100644 index d79e02cc3..000000000 --- a/tests/schemata/schema2/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# __author__ = 'eywalker' -# from .test1 import * \ No newline at end of file diff --git a/tests/schemata/schema2/test1.py b/tests/schemata/schema2/test1.py deleted file mode 100644 index 4005aa670..000000000 --- a/tests/schemata/schema2/test1.py +++ /dev/null @@ -1,16 +0,0 @@ -# """ -# Test 2 Schema definition -# """ -# __author__ = 'eywalker' -# -# import datajoint as dj -# -# -# class Subjects(dj.Relation): -# definition = """ -# schema2.Subjects (manual) # Basic subject info -# pop_id : int # unique experiment id -# --- -# real_id : varchar(40) # real-world name -# species = "mouse" : enum('mouse', 'monkey', 'human') # species -# """ \ No newline at end of file diff --git a/tests/schemata/test1.py b/tests/schemata/test1.py deleted file mode 100644 index a9a03be64..000000000 --- a/tests/schemata/test1.py +++ /dev/null @@ -1,175 +0,0 @@ -""" -Test 1 Schema definition -""" -__author__ = 'eywalker' - -import datajoint as dj -# from .. import schema2 -from .. 
import PREFIX - -testschema = dj.schema(PREFIX + '_test1', locals()) - -@testschema -class Subjects(dj.Manual): - definition = """ - #Basic subject info - - subject_id : int # unique subject id - --- - real_id : varchar(40) # real-world name - species = "mouse" : enum('mouse', 'monkey', 'human') # species - """ - -# test for shorthand -@testschema -class Animals(dj.Manual): - definition = """ - # Listing of all info - - -> Subjects - --- - animal_dob :date # date of birth - """ - -@testschema -class Trials(dj.Manual): - definition = """ - # info about trials - - -> Subjects - trial_id : int - --- - outcome : int # result of experiment - - notes="" : varchar(4096) # other comments - trial_ts=CURRENT_TIMESTAMP : timestamp # automatic - """ - -@testschema -class Matrix(dj.Manual): - definition = """ - # Some numpy array - - matrix_id : int # unique matrix id - --- - data : longblob # data - comment : varchar(1000) # comment - """ - - -@testschema -class SquaredScore(dj.Computed): - definition = """ - # cumulative outcome of trials - - -> Subjects - -> Trials - --- - squared : int # squared result of Trials outcome - """ - - @property - def populate_relation(self): - return Subjects() * Trials() - - def _make_tuples(self, key): - tmp = (Trials() & key).fetch1() - tmp2 = SquaredSubtable() & key - - self.insert(dict(key, squared=tmp['outcome']**2)) - - ss = SquaredSubtable() - - for i in range(10): - key['dummy'] = i - ss.insert(key) - -@testschema -class WrongImplementation(dj.Computed): - definition = """ - # ignore - - -> Subjects - -> Trials - --- - dummy : int # ignore - """ - - @property - def populate_relation(self): - return {'subject_id':2} - - def _make_tuples(self, key): - pass - -class ErrorGenerator(dj.Computed): - definition = """ - # ignore - - -> Subjects - -> Trials - --- - dummy : int # ignore - """ - - @property - def populate_relation(self): - return Subjects() * Trials() - - def _make_tuples(self, key): - raise Exception("This is for testing") - -@testschema -class SquaredSubtable(dj.Subordinate, dj.Manual): - definition = """ - # cumulative outcome of trials - - -> SquaredScore - dummy : int # dummy primary attribute - --- - """ -# -# -# # test reference to another table in same schema -# class Experiments(dj.Relation): -# definition = """ -# test1.Experiments (imported) # Experiment info -# -> test1.Subjects -# exp_id : int # unique id for experiment -# --- -# exp_data_file : varchar(255) # data file -# """ -# -# -# # refers to a table in dj_test2 (bound to test2) but without a class -# class Sessions(dj.Relation): -# definition = """ -# test1.Sessions (manual) # Experiment sessions -# -> test1.Subjects -# -> test2.Experimenter -# session_id : int # unique session id -# --- -# session_comment : varchar(255) # comment about the session -# """ -# -# -# class Match(dj.Relation): -# definition = """ -# test1.Match (manual) # Match between subject and color -# -> schema2.Subjects -# --- -# dob : date # date of birth -# """ -# -# -# # this tries to reference a table in database directly without ORM -# class TrainingSession(dj.Relation): -# definition = """ -# test1.TrainingSession (manual) # training sessions -# -> `dj_test2`.Experimenter -# session_id : int # training session id -# """ -# -# -# class Empty(dj.Relation): -# pass diff --git a/tests/test_blob.py b/tests/test_blob.py index 322be02cc..2265a1192 100644 --- a/tests/test_blob.py +++ b/tests/test_blob.py @@ -1,6 +1,3 @@ -from datajoint import DataJointError - -__author__ = 'fabee' import numpy as np from datajoint.blob 
import pack, unpack from numpy.testing import assert_array_equal, raises @@ -19,15 +16,12 @@ def test_pack(): x = np.int16(np.random.randn(1, 2, 3)) assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") -@raises(DataJointError) -def test_error(): - pack(dict()) def test_complex(): z = np.random.randn(8, 10) + 1j*np.random.randn(8,10) assert_array_equal(z, unpack(pack(z)), "Arrays do not match!") - z = np.random.randn(10)+ 1j*np.random.randn(10) + z = np.random.randn(10) + 1j*np.random.randn(10) assert_array_equal(z, unpack(pack(z)), "Arrays do not match!") x = np.float32(np.random.randn(3, 4, 5)) + 1j*np.float32(np.random.randn(3, 4, 5)) diff --git a/tests/test_connection.py b/tests/test_connection.py index b94a07adf..11f50beea 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,17 +1,12 @@ """ Collection of test cases to test connection module. """ -from tests.schemata.test1 import Subjects -__author__ = 'eywalker, fabee' -from . import CONN_INFO, PREFIX, BASE_CONN, cleanup from nose.tools import assert_true, assert_raises, assert_equal, raises import datajoint as dj -from datajoint import DataJointError import numpy as np - -def setup(): - cleanup() +from datajoint import DataJointError +from . import CONN_INFO, PREFIX def test_dj_conn(): @@ -24,88 +19,86 @@ def test_dj_conn(): def test_persistent_dj_conn(): """ - conn() method should provide persistent connection - across calls. + conn() method should provide persistent connection across calls. + Setting reset=True should create a new persistent connection. """ c1 = dj.conn(**CONN_INFO) c2 = dj.conn() + c3 = dj.conn(**CONN_INFO) + c4 = dj.conn(reset=True, **CONN_INFO) + c5 = dj.conn(**CONN_INFO) assert_true(c1 is c2) - - -def test_dj_conn_reset(): - """ - Passing in reset=True should allow for new persistent - connection to be created. 
- """ - c1 = dj.conn(**CONN_INFO) - c2 = dj.conn(reset=True, **CONN_INFO) - assert_true(c1 is not c2) + assert_true(c1 is c3) + assert_true(c1 is not c4) + assert_true(c4 is c5) def test_repr(): c1 = dj.conn(**CONN_INFO) - assert_true('disconnected' not in c1.__repr__() and 'connected' in c1.__repr__()) - -def test_del(): - c1 = dj.conn(**CONN_INFO) - assert_true('disconnected' not in c1.__repr__() and 'connected' in c1.__repr__()) - del c1 - + assert_true('disconnected' not in repr(c1) and 'connected' in repr(c1)) -class TestContextManager(object): - def __init__(self): - self.relvar = None - self.setup() - +class TestTransactions: """ - Test cases for FreeRelation objects + test transaction management """ - def setup(self): - """ - Create a connection object and prepare test modules - as follows: - test1 - has conn and bounded + schema = dj.schema(PREFIX + '_transactions', locals(), connection=dj.conn(**CONN_INFO)) + + @schema + class Subjects(dj.Manual): + definition = """ + #Basic subject + subject_id : int # unique subject id + --- + real_id : varchar(40) # real-world name + species = "mouse" : enum('mouse', 'monkey', 'human') # species """ - cleanup() # drop all databases with PREFIX - self.conn = dj.conn() - self.relvar = Subjects() + + def __init__(self): + self.relation = self.Subjects() + self.conn = dj.conn(**CONN_INFO) def teardown(self): - cleanup() + self.relation.delete_quick() def test_active(self): - with self.conn.transaction() as conn: + with self.conn.transaction as conn: assert_true(conn.in_transaction, "Transaction is not active") - def test_rollback(self): - - tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], - dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) + def test_transaction_rollback(self): + """Test transaction cancellation using a with statement""" + tmp = np.array([ + (1, 'Peter', 'mouse'), + (2, 'Klara', 'monkey') + ], self.relation.heading.as_dtype) - self.relvar.insert(tmp[0]) + self.relation.delete() + with self.conn.transaction: + self.relation.insert1(tmp[0]) try: - with self.conn.transaction(): - self.relvar.insert(tmp[1]) - raise DataJointError("Just to test") - except DataJointError as e: + with self.conn.transaction: + self.relation.insert1(tmp[1]) + raise DataJointError("Testing rollback") + except DataJointError: pass - testt2 = (self.relvar & 'subject_id = 2').fetch() - assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") + assert_equal(len(self.relation), 1, + "Length is not 1. Expected because rollback should have happened.") + assert_equal(len(self.relation & 'subject_id = 2'), 0, + "Length is not 0. Expected because rollback should have happened.") def test_cancel(self): - """Tests cancelling a transaction""" - tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], - dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) - - self.relvar.insert(tmp[0]) - with self.conn.transaction() as conn: - self.relvar.insert(tmp[1]) - conn.cancel_transaction() - - testt2 = (self.relvar & 'subject_id = 2').fetch() - assert_equal(len(testt2), 0, "Length is not 0. 
Expected because rollback should have happened.") - - - + """Tests cancelling a transaction explicitly""" + tmp = np.array([ + (1, 'Peter', 'mouse'), + (2, 'Klara', 'monkey') + ], self.relation.heading.as_dtype) + self.relation.delete_quick() + self.relation.insert1(tmp[0]) + self.conn.start_transaction() + self.relation.insert1(tmp[1]) + self.conn.cancel_transaction() + assert_equal(len(self.relation), 1, + "Length is not 1. Expected because rollback should have happened.") + assert_equal(len(self.relation & 'subject_id = 2'), 0, + "Length is not 0. Expected because rollback should have happened.") diff --git a/tests/test_declare.py b/tests/test_declare.py new file mode 100644 index 000000000..e15b49180 --- /dev/null +++ b/tests/test_declare.py @@ -0,0 +1,60 @@ +from nose.tools import assert_true, assert_false, assert_equal, assert_list_equal +from . import schema + + +class TestDeclare: + def __init__(self): + self.user = schema.User() + self.subject = schema.Subject() + self.experiment = schema.Experiment() + self.trial = schema.Trial() + self.ephys = schema.Ephys() + self.channel = schema.EphysChannel() + + def test_attributes(self): + assert_list_equal(self.subject.heading.names, + ['subject_id', 'real_id', 'species', 'date_of_birth', 'subject_notes']) + assert_list_equal(self.subject.primary_key, + ['subject_id']) + assert_true(self.subject.heading.attributes['subject_id'].numeric) + assert_false(self.subject.heading.attributes['real_id'].numeric) + + experiment = schema.Experiment() + assert_list_equal(experiment.heading.names, + ['subject_id', 'experiment_id', 'experiment_date', + 'username', 'data_path', + 'notes', 'entry_time']) + assert_list_equal(experiment.primary_key, + ['subject_id', 'experiment_id']) + + assert_list_equal(self.trial.heading.names, + ['subject_id', 'experiment_id', 'trial_id', 'start_time']) + assert_list_equal(self.trial.primary_key, + ['subject_id', 'experiment_id', 'trial_id']) + + assert_list_equal(self.ephys.heading.names, + ['subject_id', 'experiment_id', 'trial_id', 'sampling_frequency', 'duration']) + assert_list_equal(self.ephys.primary_key, + ['subject_id', 'experiment_id', 'trial_id']) + + assert_list_equal(self.channel.heading.names, + ['subject_id', 'experiment_id', 'trial_id', 'channel', 'voltage']) + assert_list_equal(self.channel.primary_key, + ['subject_id', 'experiment_id', 'trial_id', 'channel']) + assert_true(self.channel.heading.attributes['voltage'].is_blob) + + def test_dependencies(self): + assert_equal(self.user.references, [self.experiment.full_table_name]) + assert_equal(self.experiment.referenced, [self.user.full_table_name]) + + assert_equal(self.subject.children, [self.experiment.full_table_name]) + assert_equal(self.experiment.parents, [self.subject.full_table_name]) + + assert_equal(self.experiment.children, [self.trial.full_table_name]) + assert_equal(self.trial.parents, [self.experiment.full_table_name]) + + assert_equal(self.trial.children, [self.ephys.full_table_name]) + assert_equal(self.ephys.parents, [self.trial.full_table_name]) + + assert_equal(self.ephys.children, [self.channel.full_table_name]) + assert_equal(self.channel.parents, [self.ephys.full_table_name]) diff --git a/tests/test_free_relation.py b/tests/test_free_relation.py deleted file mode 100644 index 285bf7487..000000000 --- a/tests/test_free_relation.py +++ /dev/null @@ -1,205 +0,0 @@ -# """ -# Collection of test cases for base module. 
Tests functionalities such as -# creating tables using docstring table declarations -# """ -# from .schemata import schema1, schema2 -# from .schemata.schema1 import test1, test2, test3 -# -# -# __author__ = 'eywalker' -# -# from . import BASE_CONN, CONN_INFO, PREFIX, cleanup, setup_sample_db -# from datajoint.connection import Connection -# from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, raises -# from datajoint import DataJointError -# -# -# def setup(): -# """ -# Setup connections and bindings -# """ -# pass -# -# -# class TestRelationInstantiations(object): -# """ -# Test cases for instantiating Relation objects -# """ -# def __init__(self): -# self.conn = None -# -# def setup(self): -# """ -# Create a connection object and prepare test modules -# as follows: -# test1 - has conn and bounded -# """ -# self.conn = Connection(**CONN_INFO) -# cleanup() # drop all databases with PREFIX -# #test1.conn = self.conn -# #self.conn.bind(test1.__name__, PREFIX+'_test1') -# -# #test2.conn = self.conn -# -# #test3.__dict__.pop('conn', None) # make sure conn is not defined in test3 -# test1.__dict__.pop('conn', None) -# schema1.__dict__.pop('conn', None) # make sure conn is not defined at schema level -# -# -# def teardown(self): -# cleanup() -# -# -# def test_instantiation_from_unbound_module_should_fail(self): -# """ -# Attempting to instantiate a Relation derivative from a module with -# connection defined but not bound to a database should raise error -# """ -# test1.conn = self.conn -# with assert_raises(DataJointError) as e: -# test1.Subjects() -# assert_regexp_matches(e.exception.args[0], r".*not bound.*") -# -# def test_instantiation_from_module_without_conn_should_fail(self): -# """ -# Attempting to instantiate a Relation derivative from a module that lacks -# `conn` object should raise error -# """ -# with assert_raises(DataJointError) as e: -# test1.Subjects() -# assert_regexp_matches(e.exception.args[0], r".*define.*conn.*") -# -# def test_instantiation_of_base_derivatives(self): -# """ -# Test instantiation and initialization of objects derived from -# Relation class -# """ -# test1.conn = self.conn -# self.conn.bind(test1.__name__, PREFIX + '_test1') -# s = test1.Subjects() -# assert_equal(s.dbname, PREFIX + '_test1') -# assert_equal(s.conn, self.conn) -# assert_equal(s.definition, test1.Subjects.definition) -# -# def test_packagelevel_binding(self): -# schema2.conn = self.conn -# self.conn.bind(schema2.__name__, PREFIX + '_test1') -# s = schema2.test1.Subjects() -# -# -# class TestRelationDeclaration(object): -# """ -# Test declaration (creation of table) from -# definition in Relation under various circumstances -# """ -# -# def setup(self): -# cleanup() -# -# self.conn = Connection(**CONN_INFO) -# test1.conn = self.conn -# self.conn.bind(test1.__name__, PREFIX + '_test1') -# test2.conn = self.conn -# self.conn.bind(test2.__name__, PREFIX + '_test2') -# -# def test_is_declared(self): -# """ -# The table should not be created immediately after instantiation, -# but should be created when declare method is called -# :return: -# """ -# s = test1.Subjects() -# assert_false(s.is_declared) -# s.declare() -# assert_true(s.is_declared) -# -# def test_calling_heading_should_trigger_declaration(self): -# s = test1.Subjects() -# assert_false(s.is_declared) -# a = s.heading -# assert_true(s.is_declared) -# -# def test_foreign_key_ref_in_same_schema(self): -# s = test1.Experiments() -# assert_true('subject_id' in s.heading.primary_key) -# -# 
def test_foreign_key_ref_in_another_schema(self): -# s = test2.Experiments() -# assert_true('subject_id' in s.heading.primary_key) -# -# def test_aliased_module_name_should_resolve(self): -# """ -# Module names that were aliased in the definition should -# be properly resolved. -# """ -# s = test2.Conditions() -# assert_true('subject_id' in s.heading.primary_key) -# -# def test_reference_to_unknown_module_in_definition_should_fail(self): -# """ -# Module names in table definition that is not aliased via import -# results in error -# """ -# s = test2.FoodPreference() -# with assert_raises(DataJointError) as e: -# s.declare() -# -# -# class TestRelationWithExistingTables(object): -# """ -# Test base derivatives behaviors when some of the tables -# already exists in the database -# """ -# def setup(self): -# cleanup() -# self.conn = Connection(**CONN_INFO) -# setup_sample_db() -# test1.conn = self.conn -# self.conn.bind(test1.__name__, PREFIX + '_test1') -# test2.conn = self.conn -# self.conn.bind(test2.__name__, PREFIX + '_test2') -# self.conn.load_headings(force=True) -# -# schema2.conn = self.conn -# self.conn.bind(schema2.__name__, PREFIX + '_package') -# -# def teardown(selfself): -# schema1.__dict__.pop('conn', None) -# cleanup() -# -# def test_detection_of_existing_table(self): -# """ -# The Relation instance should be able to detect if the -# corresponding table already exists in the database -# """ -# s = test1.Subjects() -# assert_true(s.is_declared) -# -# def test_definition_referring_to_existing_table_without_class(self): -# s1 = test1.Sessions() -# assert_true('experimenter_id' in s1.primary_key) -# -# s2 = test2.Session() -# assert_true('experimenter_id' in s2.primary_key) -# -# def test_reference_to_package_level_table(self): -# s = test1.Match() -# s.declare() -# assert_true('pop_id' in s.primary_key) -# -# def test_direct_reference_to_existing_table_should_fail(self): -# """ -# When deriving from Relation, definition should not contain direct reference -# to a database name -# """ -# s = test1.TrainingSession() -# with assert_raises(DataJointError): -# s.declare() -# -# @raises(TypeError) -# def test_instantiation_of_base_derivative_without_definition_should_fail(): -# test1.Empty() -# -# -# -# diff --git a/tests/test_heading.py b/tests/test_heading.py deleted file mode 100644 index 25d6f77b7..000000000 --- a/tests/test_heading.py +++ /dev/null @@ -1 +0,0 @@ -__author__ = 'eywalker' diff --git a/tests/test_relation.py b/tests/test_relation.py index fd27876ce..d4fa37ddb 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -1,424 +1,51 @@ -# import random -# import string -# -# __author__ = 'fabee' -# -# from .schemata.schema1 import test1, test4 import random import string -import pymysql -from datajoint import DataJointError -from .schemata.test1 import Subjects, Animals, Matrix, Trials, SquaredScore, SquaredSubtable, WrongImplementation, \ - ErrorGenerator, testschema -from . 
import BASE_CONN, CONN_INFO, PREFIX, cleanup -# from datajoint.connection import Connection -from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal, \ - assert_tuple_equal, assert_dict_equal, raises -# from datajoint import DataJointError, TransactionError, AutoPopulate, Relation -import numpy as np from numpy.testing import assert_array_equal import numpy as np -import datajoint as dj - -# -# -def trial_faker(n=10): - def iter(): - for s in [1, 2]: - for i in range(n): - yield dict(trial_id=i, subject_id=s, outcome=int(np.random.randint(10)), notes='no comment') - - return iter() - - -class TestTableObject(object): - def __init__(self): - self.subjects = None - self.setup() - - """ - Test cases for FreeRelation objects - """ - - def setup(self): - """ - Create a connection object and prepare test modules - as follows: - test1 - has conn and bounded - """ - cleanup() # delete everything from all tables of databases with PREFIX - self.subjects = Subjects() - self.animals = Animals() - self.relvar_blob = Matrix() - self.trials = Trials() - self.score = SquaredScore() - self.subtable = SquaredSubtable() - - def test_table_name_manual(self): - assert_true(not self.subjects.table_name.startswith('#') and - not self.subjects.table_name.startswith('_') and not self.subjects.table_name.startswith('__')) - - def test_table_name_computed(self): - assert_true(self.score.table_name.startswith('__')) - - def test_population_relation_subordinate(self): - assert_true(self.subtable.populate_relation is None) - - @raises(NotImplementedError) - def test_make_tubles_not_implemented_subordinate(self): - self.subtable._make_tuples(None) - - def test_instantiate_relation(self): - s = Subjects() - - def teardown(self): - cleanup() - - def test_compound_restriction(self): - s = self.subjects - t = self.trials - - s.insert(dict(subject_id=1, real_id='M')) - s.insert(dict(subject_id=2, real_id='F')) - t.iter_insert(trial_faker(20)) - - tM = t & (s & "real_id = 'M'") - t1 = t & "subject_id = 1" - - assert_equal(len(tM), len(t1), "Results of compound request does not have same length") - - for t1_item, tM_item in zip(sorted(t1, key=lambda item: item['trial_id']), - sorted(tM, key=lambda item: item['trial_id'])): - assert_dict_equal(t1_item, tM_item, - 'Dictionary elements do not agree in compound statement') - - def test_record_insert(self): - "Test whether record insert works" - tmp = np.array([(2, 'Klara', 'monkey')], - dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) - - self.subjects.insert(tmp[0]) - testt2 = (self.subjects & 'subject_id = 2').fetch()[0] - assert_equal(tuple(tmp[0]), tuple(testt2), "Inserted and fetched record do not match!") - - def test_delete(self): - "Test whether delete works" - tmp = np.array([(2, 'Klara', 'monkey'), (1, 'Peter', 'mouse')], - dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) - - self.subjects.batch_insert(tmp) - assert_true(len(self.subjects) == 2, 'Length does not match 2.') - self.subjects.delete() - assert_true(len(self.subjects) == 0, 'Length does not match 0.') - - # - # # def test_cascading_delete(self): - # # "Test whether delete works" - # # tmp = np.array([(2, 'Klara', 'monkey'), (1,'Peter', 'mouse')], - # # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) - # # - # # self.subjects.batch_insert(tmp) - # # - # # self.trials.insert(dict(subject_id=1, trial_id=1, outcome=0)) - # # self.trials.insert(dict(subject_id=1, trial_id=2, outcome=1)) - # # 
self.trials.insert(dict(subject_id=2, trial_id=3, outcome=2)) - # # assert_true(len(self.subjects) == 2, 'Length does not match 2.') - # # assert_true(len(self.trials) == 3, 'Length does not match 3.') - # # (self.subjects & 'subject_id=1').delete() - # # assert_true(len(self.subjects) == 1, 'Length does not match 1.') - # # assert_true(len(self.trials) == 1, 'Length does not match 1.') - # - # def test_short_hand_foreign_reference(self): - # self.animals.heading - # - # - # - def test_record_insert_different_order(self): - "Test whether record insert works" - tmp = np.array([('Klara', 2, 'monkey')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - - self.subjects.insert(tmp[0]) - testt2 = (self.subjects & 'subject_id = 2').fetch()[0] - assert_equal((2, 'Klara', 'monkey'), tuple(testt2), - "Inserted and fetched record do not match!") - - @raises(KeyError) - def test_wrong_key_insert_records(self): - "Test whether record insert works" - tmp = np.array([('Klara', 2, 'monkey')], - dtype=[('real_deal', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - - self.subjects.insert(tmp[0]) - - def test_dict_insert(self): - "Test whether record insert works" - tmp = {'real_id': 'Brunhilda', - 'subject_id': 3, - 'species': 'human'} - - self.subjects.insert(tmp) - testt2 = (self.subjects & 'subject_id = 3').fetch()[0] - assert_equal((3, 'Brunhilda', 'human'), tuple(testt2), "Inserted and fetched record do not match!") - - @raises(KeyError) - def test_wrong_key_insert(self): - "Test whether a correct error is generated when inserting wrong attribute name" - tmp = {'real_deal': 'Brunhilda', - 'subject_database': 3, - 'species': 'human'} - - self.subjects.insert(tmp) - - # - def test_batch_insert(self): - "Test whether record insert works" - tmp = np.array([('Klara', 2, 'monkey'), ('Brunhilda', 3, 'mouse'), ('Mickey', 1, 'human')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - - self.subjects.batch_insert(tmp) - - expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'), - (3, 'Brunhilda', 'mouse')], - dtype=[('subject_id', 'i4'), ('species', 'O')]) - - self.subjects.iter_insert(tmp.__iter__()) - - expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'), - (3, 'Brunhilda', 'mouse')], - dtype=[('subject_id', ' `dj_free`.Animals -# rec_session_id : int # recording session identifier -# """ -# table = FreeRelation(self.conn, 'dj_free', 'Recordings', definition) -# assert_raises(DataJointError, table.declare) -# -# def test_reference_to_existing_table(self): -# definition1 = """ -# `dj_free`.Animals (manual) # my animal table -# animal_id : int # unique id for the animal -# --- -# animal_name : varchar(128) # name of the animal -# """ -# table1 = FreeRelation(self.conn, 'dj_free', 'Animals', definition1) -# table1.declare() -# -# definition2 = """ -# `dj_free`.Recordings (manual) # recordings -# -> `dj_free`.Animals -# rec_session_id : int # recording session identifier -# """ -# table2 = FreeRelation(self.conn, 'dj_free', 'Recordings', definition2) -# table2.declare() -# assert_true('animal_id' in table2.primary_key) -# -# -def id_generator(size=6, chars=string.ascii_uppercase + string.digits): - return ''.join(random.choice(chars) for _ in range(size)) +from nose.tools import assert_raises, assert_equal, \ + assert_false, assert_true, assert_list_equal, \ + assert_tuple_equal, assert_dict_equal, raises +from . 
import schema -class TestIterator(object): - def __init__(self): - self.relvar = None - self.setup() +class TestRelation: """ - Test cases for Iterators in Relations objects + Test base relations: insert, delete """ - def setup(self): - """ - Create a connection object and prepare test modules - as follows: - test1 - has conn and bounded - """ - cleanup() # drop all databases with PREFIX - self.relvar_blob = Matrix() - - def teardown(self): - cleanup() - - # - # - def test_blob_iteration(self): - "Tests the basic call of the iterator" - - dicts = [] - for i in range(10): - c = id_generator() - - t = {'matrix_id': i, - 'data': np.random.randn(4, 4, 4), - 'comment': c} - self.relvar_blob.insert(t) - dicts.append(t) - - for t, t2 in zip(dicts, self.relvar_blob): - assert_true(isinstance(t2, dict), 'iterator does not return dict') - - assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved tuples do not match') - assert_equal(t['comment'], t2['comment'], 'inserted and retrieved tuples do not match') - assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved tuples do not match') - - def test_fetch(self): - dicts = [] - for i in range(10): - c = id_generator() - - t = {'matrix_id': i, - 'data': np.random.randn(4, 4, 4), - 'comment': c} - self.relvar_blob.insert(t) - dicts.append(t) - - tuples2 = self.relvar_blob.fetch() - assert_true(isinstance(tuples2, np.ndarray), "Return value of fetch does not have proper type.") - assert_true(isinstance(tuples2[0], np.void), "Return value of fetch does not have proper type.") - for t, t2 in zip(dicts, tuples2): - assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved tuples do not match') - assert_equal(t['comment'], t2['comment'], 'inserted and retrieved tuples do not match') - assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved tuples do not match') - - def test_fetch_dicts(self): - dicts = [] - for i in range(10): - c = id_generator() - - t = {'matrix_id': i, - 'data': np.random.randn(4, 4, 4), - 'comment': c} - self.relvar_blob.insert(t) - dicts.append(t) - - tuples2 = self.relvar_blob.fetch(as_dict=True) - assert_true(isinstance(tuples2, list), "Return value of fetch with as_dict=True does not have proper type.") - assert_true(isinstance(tuples2[0], dict), "Return value of fetch with as_dict=True does not have proper type.") - for t, t2 in zip(dicts, tuples2): - assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved dicts do not match') - assert_equal(t['comment'], t2['comment'], 'inserted and retrieved dicts do not match') - assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved dicts do not match') - - -# -class TestAutopopulate: def __init__(self): - self.relvar = None - self.setup() - - """ - Test cases for Iterators in Relations objects - """ - - def setup(self): + self.user = schema.User() + self.subject = schema.Subject() + self.experiment = schema.Experiment() + self.trial = schema.Trial() + self.ephys = schema.Ephys() + self.channel = schema.EphysChannel() + + def test_contents(self): """ - Create a connection object and prepare test modules - as follows: - test1 - has conn and bounded + test the ability of tables to self-populate using the contents property """ - cleanup() # drop all databases with PREFIX - - self.subjects = Subjects() - self.trials = Trials() - self.squared = SquaredScore() - self.dummy = SquaredSubtable() - self.dummy1 = WrongImplementation() - self.error_generator = ErrorGenerator() - self.fill_relation() - - def fill_relation(self): - 
tmp = np.array([('Klara', 2, 'monkey'), ('Peter', 3, 'mouse')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - self.subjects.batch_insert(tmp) - - for trial_id in range(1, 11): - self.trials.insert(dict(subject_id=2, trial_id=trial_id, outcome=np.random.randint(0, 10))) - - def teardown(self): - cleanup() - - def test_autopopulate(self): - self.squared.populate() - assert_equal(len(self.squared), 10) - - for trial in self.trials * self.squared: - assert_equal(trial['outcome'] ** 2, trial['squared']) - - def test_autopopulate_restriction(self): - self.squared.populate(restriction='trial_id <= 5') - assert_equal(len(self.squared), 5) - - for trial in self.trials * self.squared: - assert_equal(trial['outcome'] ** 2, trial['squared']) - - @raises(DataJointError) - def test_autopopulate_relation_check(self): - @testschema - class dummy(dj.Computed): - def populate_relation(self): - return None - - def _make_tuples(self, key): - pass - - du = dummy() - du.populate() - - @raises(DataJointError) - def test_autopopulate_relation_check(self): - self.dummy1.populate() - - @raises(Exception) - def test_autopopulate_relation_check(self): - self.error_generator.populate() - @raises(Exception) - def test_autopopulate_relation_check2(self): - tmp = self.dummy2.populate(suppress_errors=True) - assert_equal(len(tmp), 1, 'Error list should have length 1.') + # test contents + assert_true(self.user) + assert_true(len(self.user) == len(self.user.contents)) + u = self.user.fetch(order_by=['username']) + assert_list_equal(list(u['username']), sorted([s[0] for s in self.user.contents])) + + # test prepare + assert_true(self.subject) + assert_true(len(self.subject) == len(self.subject.contents)) + u = self.subject.fetch(order_by=['subject_id']) + assert_list_equal(list(u['subject_id']), sorted([s[0] for s in self.subject.contents])) + + def test_delete_quick(self): + tmp = np.array([ + (2, 'Klara', 'monkey', '2010-01-01', ''), + (1, 'Peter', 'mouse', '2015-01-01', '')], + dtype=self.subject.heading.as_dtype) + self.subject.insert(tmp) + s = self.subject & ('subject_id in (%s)' % ','.join(str(r) for r in tmp['subject_id'])) + assert_true(len(s) == 2, 'insert did not work.') + s.delete_quick() + assert_true(len(s) == 0, 'delete did not work.') diff --git a/tests/test_settings.py b/tests/test_settings.py index d10323b10..2e9f18328 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -67,7 +67,6 @@ def test_save(): os.rename(tmpfile, settings.LOCALCONFIG) def test_load_save(): - filename_old = dj.settings.LOCALCONFIG filename = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(50)) + '.json' dj.settings.LOCALCONFIG = filename diff --git a/tests/test_utils.py b/tests/test_utils.py index 6c70150b2..2973c8d81 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,14 +1,10 @@ """ Collection of test cases to test core module. """ -from datajoint.user_relations import from_camel_case - -__author__ = 'eywalker' from nose.tools import assert_true, assert_raises, assert_equal -# from datajoint.utils import to_camel_case, from_camel_case +from datajoint.user_relations import from_camel_case from datajoint import DataJointError - def setup(): pass
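Reviewer's note (not part of the patch): a minimal sketch of how the new pieces fit together, assuming the test schema defined in tests/schema.py and the connection settings in tests/__init__.py above:

    from datajoint.utils import group_by
    from tests import schema

    # Auto-populated tiers: populate() calls _make_tuples() once for every key of
    # populated_from that is not yet present in the target relation.
    schema.Experiment().populate()
    schema.Trial().populate()
    schema.Ephys().populate()  # Ephys._make_tuples also fills EphysChannel via fill()

    # The new group_by helper yields a (value, restricted relation) pair for each
    # distinct value of the grouping attribute(s).
    for species, subjects in group_by(schema.Subject(), 'species'):
        print(species, len(subjects))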