diff --git a/datajoint/__init__.py b/datajoint/__init__.py index ac832b322..b562aa8d1 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -5,6 +5,8 @@ __version__ = "0.2" __all__ = ['__author__', '__version__', 'Connection', 'Heading', 'Relation', 'FreeRelation', 'Not', + 'Relation', + 'Manual', 'Lookup', 'Imported', 'Computed', 'AutoPopulate', 'conn', 'DataJointError', 'blob'] @@ -15,26 +17,6 @@ class DataJointError(Exception): pass -class TransactionError(DataJointError): - """ - Base class for errors specific to DataJoint internal operation. - """ - def __init__(self, msg, f, args=None, kwargs=None): - super(TransactionError, self).__init__(msg) - self.operations = (f, args if args is not None else tuple(), - kwargs if kwargs is not None else {}) - - def resolve(self): - f, args, kwargs = self.operations - return f(*args, **kwargs) - - - @property - def culprit(self): - return self.operations[0].__name__ - - - # ----------- loads local configuration from file ---------------- from .settings import Config, CONFIGVAR, LOCALCONFIG, logger, log_levels config = Config() @@ -56,8 +38,9 @@ def culprit(self): # ------------- flatten import hierarchy ------------------------- from .connection import conn, Connection from .relation import Relation +from .user_relations import Manual, Lookup, Imported, Computed, Subordinate from .autopopulate import AutoPopulate from . import blob from .relational_operand import Not -from .free_relation import FreeRelation -from .heading import Heading \ No newline at end of file +from .heading import Heading +from .relation import schema \ No newline at end of file diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index a266b62f9..d07fd1bf0 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -1,5 +1,5 @@ from .relational_operand import RelationalOperand -from . import DataJointError, TransactionError, Relation +from . 
import DataJointError, Relation import abc import logging @@ -37,7 +37,7 @@ def _make_tuples(self, key): def target(self): return self - def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, max_attempts=10): + def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): """ rel.populate() calls rel._make_tuples(key) for every primary key in self.populate_relation for which there is not already a tuple in rel. @@ -45,7 +45,6 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, :param restriction: restriction on rel.populate_relation - target :param suppress_errors: suppresses error if true :param reserve_jobs: currently not implemented - :param max_attempts: maximal number of times a TransactionError is caught before populate gives up """ assert not reserve_jobs, NotImplemented # issue #5 @@ -53,40 +52,39 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, if not isinstance(self.populate_relation, RelationalOperand): raise DataJointError('Invalid populate_relation value') - self.conn.cancel_transaction() # rollback previous transaction, if any + self.connection.cancel_transaction() # rollback previous transaction, if any if not isinstance(self, Relation): - raise DataJointError('Autopopulate is a mixin for Relation and must therefore subclass Relation') + raise DataJointError( + 'AutoPopulate is a mixin for Relation and must therefore subclass Relation') unpopulated = (self.populate_relation - self.target) & restriction for key in unpopulated.project(): - self.conn.start_transaction() + self.connection.start_transaction() if key in self.target: # already populated - self.conn.cancel_transaction() + self.connection.cancel_transaction() else: logger.info('Populating: ' + str(key)) try: - for attempts in range(max_attempts): - try: - self._make_tuples(dict(key)) - break - except TransactionError as tr_err: - self.conn.cancel_transaction() - tr_err.resolve() - 
self.conn.start_transaction() - logger.info('Transaction error in {0:s}.'.format(tr_err.culprit)) - else: - raise DataJointError( - '%s._make_tuples failed after %i attempts, giving up' % (self.__class__,max_attempts)) + self._make_tuples(dict(key)) except Exception as error: - self.conn.cancel_transaction() + self.connection.cancel_transaction() if not suppress_errors: raise else: logger.error(error) error_list.append((key, error)) else: - self.conn.commit_transaction() + self.connection.commit_transaction() logger.info('Done populating.') return error_list + + def progress(self): + """ + report progress of populating this table + """ + total = len(self.populate_relation) + remaining = len(self.populate_relation - self.target) + print('Remaining %d of %d (%2.1f%%)' % (remaining, total, 100*remaining/total) + if remaining else 'Complete', flush=True) diff --git a/datajoint/blob.py b/datajoint/blob.py index b77547320..7dad919cf 100644 --- a/datajoint/blob.py +++ b/datajoint/blob.py @@ -1,29 +1,30 @@ import zlib -import collections +from collections import OrderedDict import numpy as np from . 
import DataJointError -mxClassID = collections.OrderedDict( +mxClassID = OrderedDict(( # see http://www.mathworks.com/help/techdoc/apiref/mxclassid.html - mxUNKNOWN_CLASS=None, - mxCELL_CLASS=None, # TODO: implement - mxSTRUCT_CLASS=None, # TODO: implement - mxLOGICAL_CLASS=np.dtype('bool'), - mxCHAR_CLASS=np.dtype('c'), - mxVOID_CLASS=None, - mxDOUBLE_CLASS=np.dtype('float64'), - mxSINGLE_CLASS=np.dtype('float32'), - mxINT8_CLASS=np.dtype('int8'), - mxUINT8_CLASS=np.dtype('uint8'), - mxINT16_CLASS=np.dtype('int16'), - mxUINT16_CLASS=np.dtype('uint16'), - mxINT32_CLASS=np.dtype('int32'), - mxUINT32_CLASS=np.dtype('uint32'), - mxINT64_CLASS=np.dtype('int64'), - mxUINT64_CLASS=np.dtype('uint64'), - mxFUNCTION_CLASS=None) + ('mxUNKNOWN_CLASS', None), + ('mxCELL_CLASS', None), # TODO: implement + ('mxSTRUCT_CLASS', None), # TODO: implement + ('mxLOGICAL_CLASS', np.dtype('bool')), + ('mxCHAR_CLASS', np.dtype('c')), + ('mxVOID_CLASS', None), + ('mxDOUBLE_CLASS', np.dtype('float64')), + ('mxSINGLE_CLASS', np.dtype('float32')), + ('mxINT8_CLASS', np.dtype('int8')), + ('mxUINT8_CLASS', np.dtype('uint8')), + ('mxINT16_CLASS', np.dtype('int16')), + ('mxUINT16_CLASS', np.dtype('uint16')), + ('mxINT32_CLASS', np.dtype('int32')), + ('mxUINT32_CLASS', np.dtype('uint32')), + ('mxINT64_CLASS', np.dtype('int64')), + ('mxUINT64_CLASS', np.dtype('uint64')), + ('mxFUNCTION_CLASS', None))) reverseClassID = {v: i for i, v in enumerate(mxClassID.values())} +dtypeList = list(mxClassID.values()) def pack(obj): @@ -41,6 +42,7 @@ def pack(obj): obj, imaginary = np.real(obj), np.imag(obj) type_number = reverseClassID[obj.dtype] + assert dtypeList[type_number] is obj.dtype, 'ambigous or unknown array type' blob += np.asarray(type_number, dtype=np.uint32).tostring() blob += np.int8(is_complex).tostring() + b'\0\0\0' blob += obj.tostring() @@ -72,9 +74,10 @@ def unpack(blob): p += 8 array_shape = np.fromstring(blob[p:p+8*dimensions], dtype=np.uint64) p += 8 * dimensions - mx_type, dtype = [q for 
q in mxClassID.items()][np.fromstring(blob[p:p+4], dtype=np.uint32)[0]] + type_number = np.fromstring(blob[p:p+4], dtype=np.uint32)[0] + dtype = dtypeList[type_number] if dtype is None: - raise DataJointError('Unsupported MATLAB data type '+mx_type+' in blob') + raise DataJointError('Unsupported MATLAB data type '+type_number+' in blob') p += 4 is_complex = np.fromstring(blob[p:p+4], dtype=np.uint32)[0] p += 4 diff --git a/datajoint/connection.py b/datajoint/connection.py index 572cede34..0c668d9fe 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -1,28 +1,19 @@ +from contextlib import contextmanager import pymysql -import re -from .utils import to_camel_case from . import DataJointError -from .heading import Heading -from .settings import prefix_to_role import logging -from .erd import DBConnGraph from . import config logger = logging.getLogger(__name__) -# The following two regular expression are equivalent but one works in python -# and the other works in MySQL -table_name_regexp_sql = re.compile('^(#|_|__|~)?[a-z][a-z0-9_]*$') -table_name_regexp = re.compile('^(|#|_|__|~)[a-z][a-z0-9_]*$') # MySQL does not accept this but MariaDB does - def conn_container(): """ creates a persistent connections for everyone to use """ - _connObj = None # persistent connection object used by dj.conn() + _connection = None # persistent connection object used by dj.conn() - def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False): + def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False): # TODO: thin wrapping layer to mimic singleton """ Manage a persistent connection object. This is one of several ways to configure and access a datajoint connection. 
@@ -30,8 +21,8 @@ def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False) Set rest=True to reset the persistent connection object """ - nonlocal _connObj - if not _connObj or reset: + nonlocal _connection + if not _connection or reset: host = host if host is not None else config['database.host'] user = user if user is not None else config['database.user'] passwd = passwd if passwd is not None else config['database.password'] @@ -39,18 +30,16 @@ def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False) if passwd is None: passwd = input("Please enter database password: ") init_fun = init_fun if init_fun is not None else config['connection.init_function'] - _connObj = Connection(host, user, passwd, init_fun) - return _connObj + _connection = Connection(host, user, passwd, init_fun) + return _connection return conn_function -# The function conn is used by others to obtain the package wide persistent connection object +# The function conn is used by others to obtain a connection object conn = conn_container() - - -class Connection(object): +class Connection: """ A dj.Connection object manages a connection to a database server. It also catalogues modules, schemas, tables, and their dependencies (foreign keys). 
@@ -73,23 +62,16 @@ def __init__(self, host, user, passwd, init_fun=None): self.conn_info = dict(host=host, port=port, user=user, passwd=passwd) self._conn = pymysql.connect(init_command=init_fun, **self.conn_info) if self.is_connected: - print("Connected", user + '@' + host + ':' + str(port)) + logger.info("Connected " + user + '@' + host + ':' + str(port)) else: raise DataJointError('Connection failed.') self._conn.autocommit(True) - - self.db_to_mod = {} # modules indexed by dbnames - self.mod_to_db = {} # database names indexed by modules - self.table_names = {} # tables names indexed by [dbname][class_name] - self.headings = {} # contains headings indexed by [dbname][table_name] - self.tableInfo = {} # table information indexed by [dbname][table_name] - - # dependencies from foreign keys - self.parents = {} # maps table names to their parent table names (primary foreign key) - self.referenced = {} # maps table names to table names they reference (non-primary foreign key - self._graph = DBConnGraph(self) # initialize an empty connection graph self._in_transaction = False + def __del__(self): + logger.info('Disconnecting {user}@{host}:{port}'.format(**self.conn_info)) + self._conn.close() + def __eq__(self, other): return self.conn_info == other.conn_info @@ -100,216 +82,6 @@ def is_connected(self): """ return self._conn.ping() - def get_full_module_name(self, module): - """ - Returns full module name of the module. - - :param module: module for which the name is requested. - :return: full module name - """ - return '.'.join(self.root_package, module) - - def bind(self, module, dbname): - """ - Binds the `module` name to the database named `dbname`. - Throws an error if `dbname` is already bound to another module. - - If the database `dbname` does not exist in the server, attempts - to create the database and then bind the module. - - - :param module: module name. - :param dbname: database name. It should be a valid database identifier and not a match pattern. 
- """ - - if dbname in self.db_to_mod: - raise DataJointError('Database `%s` is already bound to module `%s`' - % (dbname, self.db_to_mod[dbname])) - - cur = self.query("SHOW DATABASES LIKE '{dbname}'".format(dbname=dbname)) - count = cur.rowcount - - if count == 1: - # Database exists - self.db_to_mod[dbname] = module - self.mod_to_db[module] = dbname - elif count == 0: - # Database doesn't exist, attempt to create - logger.info("Database `{dbname}` could not be found. " - "Attempting to create the database.".format(dbname=dbname)) - try: - self.query("CREATE DATABASE `{dbname}`".format(dbname=dbname)) - logger.info('Created database `{dbname}`.'.format(dbname=dbname)) - self.db_to_mod[dbname] = module - self.mod_to_db[module] = dbname - except pymysql.OperationalError: - raise DataJointError("Database named `{dbname}` was not defined, and" - " an attempt to create has failed. Check" - " permissions.".format(dbname=dbname)) - else: - raise DataJointError("Database name {dbname} matched more than one " - "existing databases. Database name should not be " - "a pattern.".format(dbname=dbname)) - - def load_headings(self, dbname=None, force=False): - """ - Load table information including roles and list of attributes for all - tables within dbname by examining respective table status. - - If dbname is not specified or None, will load headings for all - databases that are bound to a module. - - By default, the heading is not loaded again if it already exists. - Setting force=True will result in reloading of the heading even if one - already exists. - - :param dbname=None: database name - :param force=False: force reloading the heading - """ - if dbname: - self._load_headings(dbname, force) - return - - for dbname in self.db_to_mod: - self._load_headings(dbname, force) - - def _load_headings(self, dbname, force=False): - """ - Load table information including roles and list of attributes for all - tables within dbname by examining respective table status. 
- - By default, the heading is not loaded again if it already exists. - Setting force=True will result in reloading of the heading even if one - already exists. - - :param dbname: database name - :param force: force reloading the heading - """ - if dbname not in self.headings or force: - logger.info('Loading table definitions from `{dbname}`...'.format(dbname=dbname)) - self.table_names[dbname] = {} - self.headings[dbname] = {} - self.tableInfo[dbname] = {} - - cur = self.query('SHOW TABLE STATUS FROM `{dbname}` WHERE name REGEXP "{sqlPtrn}"'.format( - dbname=dbname, sqlPtrn=table_name_regexp_sql.pattern), as_dict=True) - - for info in cur: - info = {k.lower(): v for k, v in info.items()} # lowercase it - table_name = info.pop('name') - # look up role by table name prefix - role = prefix_to_role[table_name_regexp.match(table_name).group(1)] - class_name = to_camel_case(table_name) - self.table_names[dbname][class_name] = table_name - self.tableInfo[dbname][table_name] = dict(info, role=role) - self.headings[dbname][table_name] = Heading.init_from_database(self, dbname, table_name) - self.load_dependencies(dbname) - - def load_dependencies(self, dbname): # TODO: Perhaps consider making this "private" by preceding with underscore? - """ - Load dependencies (foreign keys) between tables by examining their - respective CREATE TABLE statements. 
- - :param dbname: database name - """ - - foreign_key_regexp = re.compile(r""" - FOREIGN\ KEY\s+\((?P[`\w ,]+)\)\s+ # list of keys in this table - REFERENCES\s+(?P[^\s]+)\s+ # table referenced - \((?P[`\w ,]+)\) # list of keys in the referenced table - """, re.X) - - logger.info('Loading dependencies for `{dbname}`'.format(dbname=dbname)) - - for tabName in self.tableInfo[dbname]: - cur = self.query('SHOW CREATE TABLE `{dbname}`.`{tabName}`'.format(dbname=dbname, tabName=tabName), - as_dict=True) - table_def = cur.fetchone() - full_table_name = '`%s`.`%s`' % (dbname, tabName) - self.parents[full_table_name] = [] - self.referenced[full_table_name] = [] - - for m in foreign_key_regexp.finditer(table_def["Create Table"]): # iterate through foreign key statements - assert m.group('attr1') == m.group('attr2'), \ - 'Foreign keys must link identically named attributes' - attrs = m.group('attr1') - attrs = re.split(r',\s+', re.sub(r'`(.*?)`', r'\1', attrs)) # remove ` around attrs and split into list - pk = self.headings[dbname][tabName].primary_key - is_primary = all([k in pk for k in attrs]) - ref = m.group('ref') # referenced table - - if not re.search(r'`\.`', ref): # if referencing other table in same schema - ref = '`%s`.%s' % (dbname, ref) # convert to full-table name - - (self.parents if is_primary else self.referenced)[full_table_name].append(ref) - self.parents.setdefault(ref, []) - self.referenced.setdefault(ref, []) - - def clear_dependencies(self, dbname=None): - """ - Clears dependency mapping originating from `dbname`. If `dbname` is not - specified, dependencies for all databases will be cleared. 
- - - :param dbname: database name - """ - if dbname is None: # clear out all dependencies - self.parents.clear() - self.referenced.clear() - else: - table_keys = ('`%s`.`%s`' % (dbname, tblName) for tblName in self.tableInfo[dbname]) - for key in table_keys: - if key in self.parents: - self.parents.pop(key) - if key in self.referenced: - self.referenced.pop(key) - - def parents_of(self, child_table): - """ - Returns a list of tables that are parents of the specified child_table. Parent-child relationship is defined - based on the presence of primary-key foreign reference: table that holds a foreign key relation to another table - is the child table. - - :param child_table: the child table - :return: list of parent tables - """ - return self.parents.get(child_table, []).copy() - - def children_of(self, parent_table): - """ - Returns a list of tables for which parent_table is a parent (primary foreign key). Parent-child relationship - is defined based on the presence of primary-key foreign reference: table that holds a foreign key relation to - another table is the child table. - - :param parent_table: parent table - :return: list of child tables - """ - return [child_table for child_table, parents in self.parents.items() if parent_table in parents] - - def referenced_by(self, referencing_table): - """ - Returns a list of tables that are referenced by non-primary foreign key relation - by the referencing_table. 
- - :param referencing_table: referencing table - :return: list of tables that are referenced by the target table - """ - return self.referenced.get(referencing_table, []).copy() - - def referencing(self, referenced_table): - """ - Returns a list of tables that references referenced_table as non-primary foreign key - - :param referenced_table: referenced table - :return: list of tables that refers to the target table - """ - return [referencing for referencing, referenced in self.referenced.items() - if referenced_table in referenced] - - # TODO: Reimplement __str__ - def __str__(self): - return self.__repr__() # placeholder until more suitable __str__ is implemented - def __repr__(self): connected = "connected" if self.is_connected else "disconnected" return "DataJoint connection ({connected}) {user}@{host}:{port}".format( @@ -319,25 +91,6 @@ def __del__(self): logger.info('Disconnecting {user}@{host}:{port}'.format(**self.conn_info)) self._conn.close() - def erd(self, databases=None, tables=None, fill=True, reload=False): - """ - Creates Entity Relation Diagram for the database or specified subset of - tables. - - Set `fill` to False to only display specified tables. (By default - connection tables are automatically included) - """ - self._graph.update_graph(reload=reload) # update the graph - - graph = self._graph.copy_graph() - if databases: - graph = graph.restrict_by_modules(databases, fill) - - if tables: - graph = graph.restrict_by_tables(tables, fill) - - return graph - def query(self, query, args=(), as_dict=False): """ Execute the specified query and return the tuple generator. 
@@ -353,7 +106,7 @@ def query(self, query, args=(), as_dict=False): cur.execute(query, args) return cur - + # ---------- transaction processing ------------------ @property def in_transaction(self): self._in_transaction = self._in_transaction and self.is_connected @@ -374,5 +127,18 @@ def cancel_transaction(self): def commit_transaction(self): self.query('COMMIT') self._in_transaction = False - logger.info("Transaction commited and closed.") + logger.info("Transaction committed and closed.") + + + #-------- context manager for transactions + @contextmanager + def transaction(self): + try: + self.start_transaction() + yield self + except: + self.cancel_transaction() + raise + else: + self.commit_transaction() diff --git a/datajoint/declare.py b/datajoint/declare.py new file mode 100644 index 000000000..c0c8e44a9 --- /dev/null +++ b/datajoint/declare.py @@ -0,0 +1,113 @@ +import re +import pyparsing as pp +import logging + +from . import DataJointError + + +logger = logging.getLogger(__name__) + + + +def declare(full_table_name, definition, context): + """ + Parse declaration and create new SQL table accordingly. 
+ """ + # split definition into lines + definition = re.split(r'\s*\n\s*', definition.strip()) + + table_comment = definition.pop(0)[1:].strip() if definition[0].startswith('#') else '' + + in_key = True # parse primary keys + primary_key = [] + attributes = [] + attribute_sql = [] + foreign_key_sql = [] + index_sql = [] + + for line in definition: + if line.startswith('#'): # additional comments are ignored + pass + elif line.startswith('---'): + in_key = False # start parsing dependent attributes + elif line.startswith('->'): + # foreign key + ref = eval(line[2:], context)() + foreign_key_sql.append( + 'FOREIGN KEY ({primary_key})' + ' REFERENCES {ref} ({primary_key})' + ' ON UPDATE CASCADE ON DELETE RESTRICT'.format( + primary_key='`' + '`,`'.join(ref.primary_key) + '`', ref=ref.full_table_name) + ) + for name in ref.primary_key: + if in_key and name not in primary_key: + primary_key.append(name) + if name not in attributes: + attributes.append(name) + attribute_sql.append(ref.heading[name].sql()) + elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): # index + index_sql.append(line) # the SQL syntax is identical to DataJoint's + else: + name, sql = compile_attribute(line, in_key) + if in_key and name not in primary_key: + primary_key.append(name) + if name not in attributes: + attributes.append(name) + attribute_sql.append(sql) + + # compile SQL + if not primary_key: + raise DataJointError('Table must have a primary key') + sql = 'CREATE TABLE %s (\n ' % full_table_name + sql += ',\n '.join(attribute_sql) + sql += ',\n PRIMARY KEY (`' + '`,`'.join(primary_key) + '`)' + if foreign_key_sql: + sql += ', \n' + ', \n'.join(foreign_key_sql) + if index_sql: + sql += ', \n' + ', \n'.join(index_sql) + sql += '\n) ENGINE = InnoDB, COMMENT "%s"' % table_comment + return sql + + + + +def compile_attribute(line, in_key=False): + """ + Convert attribute definition from DataJoint format to SQL + :param line: attribution line + :param in_key: set to True if attribute is in 
primary key set + :returns: (name, sql) -- attribute name and sql code for its declaration + """ + quoted = pp.Or(pp.QuotedString('"'), pp.QuotedString("'")) + colon = pp.Literal(':').suppress() + attribute_name = pp.Word(pp.srange('[a-z]'), pp.srange('[a-z0-9_]')).setResultsName('name') + + data_type = pp.Combine(pp.Word(pp.alphas)+pp.SkipTo("#", ignore=quoted)).setResultsName('type') + default = pp.Literal('=').suppress() + pp.SkipTo(colon, ignore=quoted).setResultsName('default') + comment = pp.Literal('#').suppress() + pp.restOfLine.setResultsName('comment') + + attribute_parser = attribute_name + pp.Optional(default) + colon + data_type + comment + + match = attribute_parser.parseString(line+'#', parseAll=True) + match['comment'] = match['comment'].rstrip('#') + if 'default' not in match: + match['default'] = '' + match = {k: v.strip() for k, v in match.items()} + match['nullable'] = match['default'].lower() == 'null' + + literals = ['CURRENT_TIMESTAMP'] # not to be enclosed in quotes + if match['nullable']: + if in_key: + raise DataJointError('Primary key attributes cannot be nullable in line %s' % line) + match['default'] = 'DEFAULT NULL' # nullable attributes default to null + else: + if match['default']: + quote = match['default'].upper() not in literals and match['default'][0] not in '"\'' + match['default'] = ('NOT NULL DEFAULT ' + + ('"%s"' if quote else "%s") % match['default']) + else: + match['default'] = 'NOT NULL' + match['comment'] = match['comment'].replace('"', '\\"') # escape double quotes in comment + sql = ('`{name}` {type} {default}' + (' COMMENT "{comment}"' if match['comment'] else '') + ).format(**match) + return match['name'], sql diff --git a/datajoint/free_relation.py b/datajoint/free_relation.py deleted file mode 100644 index 11ff6fe0c..000000000 --- a/datajoint/free_relation.py +++ /dev/null @@ -1,506 +0,0 @@ -from _collections_abc import MutableMapping, Mapping -import numpy as np -import logging -from . 
import DataJointError, config, TransactionError -from .relational_operand import RelationalOperand -from .blob import pack -from .heading import Heading -import re -from .settings import Role, role_to_prefix -from .utils import from_camel_case, user_choice - -logger = logging.getLogger(__name__) - - -class FreeRelation(RelationalOperand): - """ - A FreeRelation object is a relation associated with a table. - A FreeRelation object provides insert and delete methods. - FreeRelation objects are only used internally and for debugging. - The table must already exist in the schema for its FreeRelation object to work. - - The table associated with an instance of Relation is identified by its 'class name'. - property, which is a string in CamelCase. The actual table name is obtained - by converting className from CamelCase to underscore_separated_words and - prefixing according to the table's role. - - Relation instances obtain their table's heading by looking it up in the connection - object. This ensures that Relation instances contain the current table definition - even after tables are modified after the instance is created. 
- """ - - def __init__(self, conn, dbname, class_name=None, definition=None): - self.class_name = class_name - self._conn = conn - self.dbname = dbname - self._definition = definition - - if dbname not in self.conn.db_to_mod: - # register with a fake module, enclosed in back quotes - # necessary for loading mechanism - self.conn.bind('`{0}`'.format(dbname), dbname) - super().__init__(conn) - - @property - def from_clause(self): - return self.full_table_name - - @property - def heading(self): - self.declare() - return self.conn.headings[self.dbname][self.table_name] - - @property - def definition(self): - return self._definition - - @property - def is_declared(self): - self.conn.load_headings(self.dbname) - return self.class_name in self.conn.table_names[self.dbname] - - def declare(self): - """ - Declare the table in database if it doesn't already exist. - - :raises: DataJointError if the table cannot be declared. - """ - if not self.is_declared: - self._declare() - if not self.is_declared: - raise DataJointError( - 'FreeRelation could not be declared for %s' % self.class_name) - - @staticmethod - def _field_to_sql(field): # TODO move this into Attribute Tuple - """ - Converts an attribute definition tuple into SQL code. 
- :param field: attribute definition - :rtype : SQL code - """ - mysql_constants = ['CURRENT_TIMESTAMP'] - if field.nullable: - default = 'DEFAULT NULL' - else: - default = 'NOT NULL' - # if some default specified - if field.default: - # enclose value in quotes except special SQL values or already enclosed - quote = field.default.upper() not in mysql_constants and field.default[0] not in '"\'' - default += ' DEFAULT ' + ('"%s"' if quote else "%s") % field.default - if any((c in r'\"' for c in field.comment)): - raise DataJointError('Illegal characters in attribute comment "%s"' % field.comment) - - return '`{name}` {type} {default} COMMENT "{comment}",\n'.format( - name=field.name, type=field.type, default=default, comment=field.comment) - - @property - def full_table_name(self): - """ - :return: full name of the associated table - """ - return '`%s`.`%s`' % (self.dbname, self.table_name) - - @property - def table_name(self): - """ - :return: name of the associated table - """ - return self.conn.table_names[self.dbname][self.class_name] if self.is_declared else None - - @property - def primary_key(self): - """ - :return: primary key of the table - """ - return self.heading.primary_key - - def iter_insert(self, iter, **kwargs): - """ - Inserts an entire batch of entries. Additional keyword arguments are passed to insert. - - :param iter: Must be an iterator that generates a sequence of valid arguments for insert. - """ - for row in iter: - self.insert(row, **kwargs) - - def batch_insert(self, data, **kwargs): - """ - Inserts an entire batch of entries. Additional keyword arguments are passed to insert. - - :param data: must be iterable, each row must be a valid argument for insert - """ - self.iter_insert(data.__iter__(), **kwargs) - - def insert(self, tup, ignore_errors=False, replace=False): - """ - Insert one data record or one Mapping (like a dictionary). - - :param tup: Data record, or a Mapping (like a dictionary). 
- :param ignore_errors=False: Ignores errors if True. - :param replace=False: Replaces data tuple if True. - - Example:: - - b = djtest.Subject() - b.insert(dict(subject_id = 7, species="mouse",\\ - real_id = 1007, date_of_birth = "2014-09-01")) - """ - - heading = self.heading - if isinstance(tup, np.void): - for fieldname in tup.dtype.fields: - if fieldname not in heading: - raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname, )) - value_list = ','.join([repr(tup[name]) if not heading[name].is_blob else '%s' - for name in heading if name in tup.dtype.fields]) - - args = tuple(pack(tup[name]) for name in heading - if name in tup.dtype.fields and heading[name].is_blob) - attribute_list = '`' + '`,`'.join( - [q for q in heading if q in tup.dtype.fields]) + '`' - elif isinstance(tup, Mapping): - for fieldname in tup.keys(): - if fieldname not in heading: - raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname, )) - value_list = ','.join([repr(tup[name]) if not heading[name].is_blob else '%s' - for name in heading if name in tup]) - args = tuple(pack(tup[name]) for name in heading - if name in tup and heading[name].is_blob) - attribute_list = '`' + '`,`'.join( - [name for name in heading if name in tup]) + '`' - else: - raise DataJointError('Datatype %s cannot be inserted' % type(tup)) - if replace: - sql = 'REPLACE' - elif ignore_errors: - sql = 'INSERT IGNORE' - else: - sql = 'INSERT' - sql += " INTO %s (%s) VALUES (%s)" % (self.full_table_name, - attribute_list, value_list) - logger.info(sql) - self.conn.query(sql, args=args) - - def delete(self): - if not config['safemode'] or user_choice( - "You are about to delete data from a table. This operation cannot be undone.\n" - "Proceed?", default='no') == 'yes': - self.conn.query('DELETE FROM ' + self.from_clause + self.where_clause) # TODO: make cascading (issue #15) - - def drop(self): - """ - Drops the table associated to this object. 
- """ - if self.is_declared: - if not config['safemode'] or user_choice( - "You are about to drop an entire table. This operation cannot be undone.\n" - "Proceed?", default='no') == 'yes': - self.conn.query('DROP TABLE %s' % self.full_table_name) # TODO: make cascading (issue #16) - self.conn.clear_dependencies(dbname=self.dbname) - self.conn.load_headings(dbname=self.dbname, force=True) - logger.info("Dropped table %s" % self.full_table_name) - - @property - def size_on_disk(self): - """ - :return: size of data and indices in MiB taken by the table on the storage device - """ - cur = self.conn.query( - 'SHOW TABLE STATUS FROM `{dbname}` WHERE NAME="{table}"'.format( - dbname=self.dbname, table=self.table_name), as_dict=True) - ret = cur.fetchone() - return (ret['Data_length'] + ret['Index_length'])/1024**2 - - def set_table_comment(self, comment): - """ - Update the table comment in the table definition. - :param comment: new comment as string - """ - # TODO: add verification procedure (github issue #24) - self.alter('COMMENT="%s"' % comment) - - def add_attribute(self, definition, after=None): - """ - Add a new attribute to the table. A full line from the table definition - is passed in as definition. - - The definition can specify where to place the new attribute. Use after=None - to add the attribute as the first attribute or after='attribute' to place it - after an existing attribute. - - :param definition: table definition - :param after=None: After which attribute of the table the new attribute is inserted. - If None, the attribute is inserted in front. - """ - position = ' FIRST' if after is None else ( - ' AFTER %s' % after if after else '') - sql = self.field_to_sql(parse_attribute_definition(definition)) - self._alter('ADD COLUMN %s%s' % (sql[:-2], position)) - - def drop_attribute(self, attr_name): - """ - Drops the attribute attrName from this table. - - :param attr_name: Name of the attribute that is dropped. 
- """ - if not config['safemode'] or user_choice( - "You are about to drop an attribute from a table." - "This operation cannot be undone.\n" - "Proceed?", default='no') == 'yes': - self._alter('DROP COLUMN `%s`' % attr_name) - - def alter_attribute(self, attr_name, new_definition): - """ - Alter the definition of the field attr_name in this table using the new definition. - - :param attr_name: field that is redefined - :param new_definition: new definition of the field - """ - sql = self.field_to_sql(parse_attribute_definition(new_definition)) - self._alter('CHANGE COLUMN `%s` %s' % (attr_name, sql[:-2])) - - def erd(self, subset=None): - """ - Plot the schema's entity relationship diagram (ERD). - """ - - def _alter(self, alter_statement): - """ - Execute ALTER TABLE statement for this table. The schema - will be reloaded within the connection object. - - :param alter_statement: alter statement - """ - if self._conn.in_transaction: - raise TransactionError( - u"_alter is currently in transaction. Operation not allowed to avoid implicit commits.", - self._alter, args=(alter_statement,)) - - sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement) - self.conn.query(sql) - self.conn.load_headings(self.dbname, force=True) - # TODO: place table definition sync mechanism - - @staticmethod - def _parse_index_def(line): - """ - Parses index definition. 
- - :param line: definition line - :return: groupdict with index info - """ - line = line.strip() - index_regexp = re.compile(""" - ^(?PUNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX - \((?P[^\)]+)\)$ # (attr1, attr2) - """, re.I + re.X) - m = index_regexp.match(line) - assert m, 'Invalid index declaration "%s"' % line - index_info = m.groupdict() - attributes = re.split(r'\s*,\s*', index_info['attributes'].strip()) - index_info['attributes'] = attributes - assert len(attributes) == len(set(attributes)), \ - 'Duplicate attributes in index declaration "%s"' % line - return index_info - - def get_base(self, module_name, class_name): - if not module_name: - module_name = r'`{dbname}`'.format(self.dbname) - m = re.match(r'`(\w+)`', module_name) - return FreeRelation(self.conn, m.group(1), class_name) if m else None - - @property - def ref_name(self): - """ - :return: the name to refer to this class, taking form module.class or `database`.class - """ - return '`{0}`'.format(self.dbname) + '.' + self.class_name - - def _declare(self): - """ - Declares the table in the database if no table in the database matches this object. - """ - if self._conn.in_transaction: - raise TransactionError( - u"_alter is currently in transaction. Operation not allowed to avoid implicit commits.", self._declare) - - if not self.definition: - raise DataJointError('Table definition is missing!') - table_info, parents, referenced, field_defs, index_defs = self._parse_declaration() - defined_name = table_info['module'] + '.' 
+ table_info['className'] - - if not defined_name == self.ref_name: - raise DataJointError('Table name {} does not match the declared' - 'name {}'.format(self.ref_name, defined_name)) - - # compile the CREATE TABLE statement - table_name = role_to_prefix[table_info['tier']] + from_camel_case(self.class_name) - sql = 'CREATE TABLE `%s`.`%s` (\n' % (self.dbname, table_name) - - # add inherited primary key fields - primary_key_fields = set() - non_key_fields = set() - for p in parents: - for key in p.primary_key: - field = p.heading[key] - if field.name not in primary_key_fields: - primary_key_fields.add(field.name) - sql += self._field_to_sql(field) - else: - logger.debug('Field definition of {} in {} ignored'.format( - field.name, p.full_class_name)) - - # add newly defined primary key fields - for field in (f for f in field_defs if f.in_key): - if field.nullable: - raise DataJointError('Primary key attribute {} cannot be nullable'.format( - field.name)) - if field.name in primary_key_fields: - raise DataJointError('Duplicate declaration of the primary attribute {key}. 
' - 'Ensure that the attribute is not already declared ' - 'in referenced tables'.format(key=field.name)) - primary_key_fields.add(field.name) - sql += self._field_to_sql(field) - - # add secondary foreign key attributes - for r in referenced: - for key in r.primary_key: - field = r.heading[key] - if field.name not in primary_key_fields | non_key_fields: - non_key_fields.add(field.name) - sql += self._field_to_sql(field) - - # add dependent attributes - for field in (f for f in field_defs if not f.in_key): - non_key_fields.add(field.name) - sql += self._field_to_sql(field) - - # add primary key declaration - assert len(primary_key_fields) > 0, 'table must have a primary key' - keys = ', '.join(primary_key_fields) - sql += 'PRIMARY KEY (%s),\n' % keys - - # add foreign key declarations - for ref in parents + referenced: - keys = ', '.join(ref.primary_key) - sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \ - (keys, ref.full_table_name, keys) - - # add secondary index declarations - # gather implicit indexes due to foreign keys first - implicit_indices = [] - for fk_source in parents + referenced: - implicit_indices.append(fk_source.primary_key) - - # for index in indexDefs: - # TODO: finish this up... - - # close the declaration - sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % ( - sql[:-2], table_info['comment']) - - # make sure that the table does not alredy exist - self.conn.load_headings(self.dbname, force=True) - if not self.is_declared: - # execute declaration - logger.debug('\n\n' + sql + '\n\n') - self.conn.query(sql) - self.conn.load_headings(self.dbname, force=True) - - def _parse_declaration(self): - """ - Parse declaration and create new SQL table accordingly. 
- """ - parents = [] - referenced = [] - index_defs = [] - field_defs = [] - declaration = re.split(r'\s*\n\s*', self.definition.strip()) - - # remove comment lines - declaration = [x for x in declaration if not x.startswith('#')] - ptrn = """ - ^(?P[\w\`]+)\.(?P\w+)\s* # module.className - \(\s*(?P\w+)\s*\)\s* # (tier) - \#\s*(?P.*)$ # comment - """ - p = re.compile(ptrn, re.X) - table_info = p.match(declaration[0]).groupdict() - if table_info['tier'] not in Role.__members__: - raise DataJointError('InvalidTableTier: Invalid tier {tier} for table\ - {module}.{cls}'.format(tier=table_info['tier'], - module=table_info[ - 'module'], - cls=table_info['className'])) - table_info['tier'] = Role[table_info['tier']] # convert into enum - - in_key = True # parse primary keys - attribute_regexp = re.compile(""" - ^[a-z][a-z\d_]*\s* # name - (=\s*\S+(\s+\S+)*\s*)? # optional defaults - :\s*\w.*$ # type, comment - """, re.I + re.X) # ignore case and verbose - - for line in declaration[1:]: - if line.startswith('---'): - in_key = False # start parsing non-PK fields - elif line.startswith('->'): - # foreign key - if '.' in line[2:]: - module_name, class_name = line[2:].strip().split('.') - else: - # assume it's a shorthand - module_name = '' - class_name = line[2:].strip() - ref = parents if in_key else referenced - ref.append(self.get_base(module_name, class_name)) - elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): - index_defs.append(self._parse_index_def(line)) - elif attribute_regexp.match(line): - field_defs.append(parse_attribute_definition(line, in_key)) - else: - raise DataJointError( - 'Invalid table declaration line "%s"' % line) - - return table_info, parents, referenced, field_defs, index_defs - - -def parse_attribute_definition(line, in_key=False): - """ - Parse attribute definition line in the declaration and returns - an attribute tuple. 
- - :param line: attribution line - :param in_key: set to True if attribute is in primary key set - :returns: attribute tuple - """ - line = line.strip() - attribute_regexp = re.compile(""" - ^(?P[a-z][a-z\d_]*)\s* # field name - (=\s*(?P\S+(\s+\S+)*?)\s*)? # default value - :\s*(?P\w[^\#]*[^\#\s])\s* # datatype - (\#\s*(?P\S*(\s+\S+)*)\s*)?$ # comment - """, re.X) - m = attribute_regexp.match(line) - if not m: - raise DataJointError('Invalid field declaration "%s"' % line) - attr_info = m.groupdict() - if not attr_info['comment']: - attr_info['comment'] = '' - if not attr_info['default']: - attr_info['default'] = '' - attr_info['nullable'] = attr_info['default'].lower() == 'null' - assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \ - 'BIGINT attributes cannot be nullable in "%s"' % line - - return Heading.AttrTuple( - in_key=in_key, - autoincrement=None, - numeric=None, - string=None, - is_blob=None, - computation=None, - dtype=None, - **attr_info - ) diff --git a/datajoint/heading.py b/datajoint/heading.py index 620c93751..5c332fff6 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -11,17 +11,46 @@ class Heading: Heading contains the property attributes, which is an OrderedDict in which the keys are the attribute names and the values are AttrTuples. """ - AttrTuple = namedtuple('AttrTuple', + + class AttrTuple(namedtuple('AttrTuple', ('name', 'type', 'in_key', 'nullable', 'default', 'comment', 'autoincrement', 'numeric', 'string', 'is_blob', - 'computation', 'dtype')) - AttrTuple.as_dict = AttrTuple._asdict # renaming to make public - - def __init__(self, attributes): + 'computation', 'dtype'))): + def _asdict(self): + """ + for some reason the inherited _asdict does not work after subclassing from namedtuple + """ + return OrderedDict((name, self[i]) for i, name in enumerate(self._fields)) + + def sql(self): + """ + Convert attribute tuple into its SQL CREATE TABLE clause. 
+ :rtype : SQL code + """ + literals = ['CURRENT_TIMESTAMP'] + if self.nullable: + default = 'DEFAULT NULL' + else: + default = 'NOT NULL' + if self.default: + # enclose value in quotes except special SQL values or already enclosed + quote = self.default.upper() not in literals and self.default[0] not in '"\'' + default += ' DEFAULT ' + ('"%s"' if quote else "%s") % self.default + if any((c in r'\"' for c in self.comment)): + raise DataJointError('Illegal characters in attribute comment "%s"' % self.comment) + return '`{name}` {type} {default} COMMENT "{comment}"'.format( + name=self.name, type=self.type, default=default, comment=self.comment) + + def __init__(self, attributes=None): """ :param attributes: a list of dicts with the same keys as AttrTuple """ - self.attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes]) + if attributes: + attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes]) + self.attributes = attributes + + def __bool__(self): + return self.attributes is not None @property def names(self): @@ -52,12 +81,14 @@ def __getitem__(self, name): return self.attributes[name] def __repr__(self): - autoincrement_string = {False: '', True: ' auto_increment'} - return '\n'.join(['%-20s : %-28s # %s' % ( - k if v.default is None else '%s="%s"' % (k, v.default), - '%s%s' % (v.type, autoincrement_string[v.autoincrement]), - v.comment) - for k, v in self.attributes.items()]) + if self.attributes is None: + return 'Empty heading' + else: + return '\n'.join(['%-20s : %-28s # %s' % ( + k if v.default is None else '%s="%s"' % (k, v.default), + '%s%s' % (v.type, 'auto_increment' if v.autoincrement else ''), + v.comment) + for k, v in self.attributes.items()]) @property def as_dtype(self): @@ -90,14 +121,14 @@ def items(self): def __iter__(self): return iter(self.attributes) - @classmethod - def init_from_database(cls, conn, dbname, table_name): + def init_from_database(self, conn, database, table_name): """ - 
initialize heading from a database table + initialize heading from a database table. The table must exist already. """ cur = conn.query( - 'SHOW FULL COLUMNS FROM `{table_name}` IN `{dbname}`'.format( - table_name=table_name, dbname=dbname), as_dict=True) + 'SHOW FULL COLUMNS FROM `{table_name}` IN `{database}`'.format( + table_name=table_name, database=database), as_dict=True) + attributes = cur.fetchall() rename_map = { @@ -147,8 +178,8 @@ def init_from_database(cls, conn, dbname, table_name): attr['computation'] = None if not (attr['numeric'] or attr['string'] or attr['is_blob']): - raise DataJointError('Unsupported field type {field} in `{dbname}`.`{table_name}`'.format( - field=attr['type'], dbname=dbname, table_name=table_name)) + raise DataJointError('Unsupported field type {field} in `{database}`.`{table_name}`'.format( + field=attr['type'], database=database, table_name=table_name)) attr.pop('Extra') # fill out dtype. All floats and non-nullable integers are turned into specific dtypes @@ -163,8 +194,7 @@ def init_from_database(cls, conn, dbname, table_name): t = re.sub(r' unsigned$', '', t) # remove unsigned assert (t, is_unsigned) in numeric_types, 'dtype not found for type %s' % t attr['dtype'] = numeric_types[(t, is_unsigned)] - - return cls(attributes) + self.attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes]) def project(self, *attribute_list, **renamed_attributes): """ @@ -181,14 +211,14 @@ def project(self, *attribute_list, **renamed_attributes): attribute_list = self.primary_key + [a for a in attribute_list if a not in self.primary_key] # convert attribute_list into a list of dicts but exclude renamed attributes - attribute_list = [v.as_dict() for k, v in self.attributes.items() + attribute_list = [v._asdict() for k, v in self.attributes.items() if k in attribute_list and k not in renamed_attributes.values()] # add renamed and computed attributes for new_name, computation in renamed_attributes.items(): if 
computation in self.names: # renamed attribute - new_attr = self.attributes[computation].as_dict() + new_attr = self.attributes[computation]._asdict() new_attr['name'] = new_name new_attr['computation'] = '`' + computation + '`' else: @@ -215,14 +245,14 @@ def __add__(self, other): join two headings """ assert isinstance(other, Heading) - attribute_list = [v.as_dict() for v in self.attributes.values()] + attribute_list = [v._asdict() for v in self.attributes.values()] for name in other.names: if name not in self.names: - attribute_list.append(other.attributes[name].as_dict()) + attribute_list.append(other.attributes[name]._asdict()) return Heading(attribute_list) def resolve(self): """ Remove attribute computations after they have been resolved in a subquery """ - return Heading([dict(v.as_dict(), computation=None) for v in self.attributes.values()]) \ No newline at end of file + return Heading([dict(v._asdict(), computation=None) for v in self.attributes.values()]) \ No newline at end of file diff --git a/datajoint/relation.py b/datajoint/relation.py index 7112087ca..3c9ce03dc 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -1,144 +1,271 @@ -import importlib -import abc -from types import ModuleType -from . import DataJointError -from .free_relation import FreeRelation +from collections.abc import Mapping +import numpy as np import logging +import abc +import pymysql +from . import DataJointError, config, conn +from .declare import declare, compile_attribute +from .relational_operand import RelationalOperand +from .blob import pack +from .utils import user_choice +from .heading import Heading logger = logging.getLogger(__name__) -class Relation(FreeRelation, metaclass=abc.ABCMeta): +def schema(database, context, connection=None): """ - Relation is a Table that implements data definition functions. - It is an abstract class with the abstract property 'definition'. + Returns a decorator that can be used to associate a Relation class to a database. 
- Example for a usage of Relation:: + :param database: name of the database to associate the decorated class with + :param context: dictionary for looking up foreign key references, usually set to locals() + :param connection: Connection object. Defaults to datajoint.conn() + :return: a decorator for Relation subclasses + """ + if connection is None: + connection = conn() - import datajoint as dj + # if the database does not exist, create it + cur = connection.query("SHOW DATABASES LIKE '{database}'".format(database=database)) + if cur.rowcount == 0: + logger.info("Database `{database}` could not be found. " + "Attempting to create the database.".format(database=database)) + try: + connection.query("CREATE DATABASE `{database}`".format(database=database)) + logger.info('Created database `{database}`.'.format(database=database)) + except pymysql.OperationalError: + raise DataJointError("Database named `{database}` was not defined, and " + "an attempt to create has failed. Check" + " permissions.".format(database=database)) + def decorator(cls): + """ + The decorator declares the table and binds the class to the database table + """ + cls.database = database + cls._connection = connection + cls._heading = Heading() + instance = cls() if isinstance(cls, type) else cls + if not instance.heading: + connection.query( + declare( + full_table_name=instance.full_table_name, + definition=instance.definition, + context=context)) + return cls + + return decorator - class Subjects(dj.Relation): - definition = ''' - test1.Subjects (manual) # Basic subject info - subject_id : int # unique subject id - --- - real_id : varchar(40) # real-world name - species = "mouse" : enum('mouse', 'monkey', 'human') # species - ''' + +class Relation(RelationalOperand, metaclass=abc.ABCMeta): """ - @abc.abstractproperty - def definition(self): + Relation is an abstract class that represents a base relation, i.e. a table in the database. 
+ To make it a concrete class, override the abstract properties specifying the connection, + table name, database, context, and definition. + A Relation implements insert and delete methods in addition to inherited relational operators. + """ + + # ---------- abstract properties ------------ # + @property + @abc.abstractmethod + def table_name(self): + """ + :return: the name of the table in the database """ - :return: string containing the table declaration using the DataJoint Data Definition Language. + raise NotImplementedError('Relation subclasses must define property table_name') - The DataJoint DDL is described at: http://datajoint.github.com + @property + @abc.abstractmethod + def definition(self): + """ + :return: a string containing the table definition using the DataJoint DDL """ pass + # -------------- required by RelationalOperand ----------------- # + @property + def connection(self): + return self._connection + + @property + def heading(self): + if not self._heading and self.is_declared: + self._heading.init_from_database(self.connection, self.database, self.table_name) + return self._heading + @property - def full_class_name(self): + def from_clause(self): """ - :return: full class name including the entire package hierarchy + :return: the FROM clause of SQL SELECT statements. + """ + return self.full_table_name + + def iter_insert(self, rows, **kwargs): """ - return '{}.{}'.format(self.__module__, self.class_name) + Inserts a collection of tuples. Additional keyword arguments are passed to insert. + :param iter: Must be an iterator that generates a sequence of valid arguments for insert. 
+ """ + for row in rows: + self.insert(row, **kwargs) + + # --------- SQL functionality --------- # @property - def ref_name(self): + def is_declared(self): + cur = self.connection.query( + 'SHOW TABLES in `{database}`LIKE "{table_name}"'.format( + database=self.database, table_name=self.table_name)) + return cur.rowcount>0 + + def batch_insert(self, data, **kwargs): """ - :return: name by which this class should be accessible + Inserts an entire batch of entries. Additional keyword arguments are passed to insert. + + :param data: must be iterable, each row must be a valid argument for insert """ - parent = self.__module__.split('.')[-2 if self._use_package else -1] - return parent + '.' + self.class_name + self.iter_insert(data.__iter__(), **kwargs) - def __init__(self): #TODO: support taking in conn obj - class_name = self.__class__.__name__ - module_name = self.__module__ - mod_obj = importlib.import_module(module_name) - self._use_package = False - # first, find the conn object - try: - conn = mod_obj.conn - except AttributeError: - try: - # check if database bound at the package level instead - pkg_obj = importlib.import_module(mod_obj.__package__) - conn = pkg_obj.conn - self._use_package = True - except AttributeError: - raise DataJointError( - "Please define object 'conn' in '{}' or in its containing package.".format(module_name)) - # now use the conn object to determine the dbname this belongs to - try: - if self._use_package: - # the database is bound to the package - pkg_name = '.'.join(module_name.split('.')[:-1]) - dbname = conn.mod_to_db[pkg_name] - else: - dbname = conn.mod_to_db[module_name] - except KeyError: - raise DataJointError( - 'Module {} is not bound to a database. See datajoint.connection.bind'.format(module_name)) - # initialize using super class's constructor - super().__init__(conn, dbname, class_name) - - def get_base(self, module_name, class_name): - """ - Loads the base relation from the module. 
If the base relation is not defined in - the module, then construct it using Relation constructor. - - :param module_name: module name - :param class_name: class name - :returns: the base relation - """ - if not module_name: - module_name = self.__module__.split('.')[-1] - - mod_obj = self.get_module(module_name) - if not mod_obj: - raise DataJointError('Module named {mod_name} was not found. Please make' - ' sure that it is in the path or you import the module.'.format(mod_name=module_name)) - try: - ret = getattr(mod_obj, class_name)() - except AttributeError: - ret = FreeRelation(conn=self.conn, - dbname=self.conn.mod_to_db[mod_obj.__name__], - class_name=class_name) - return ret - - @classmethod - def get_module(cls, module_name): - """ - Resolve short name reference to a module and return the corresponding module object - - :param module_name: short name for the module, whose reference is to be resolved - :return: resolved module object. If no module matches the short name, `None` will be returned - - The module_name resolution steps in the following order: - - 1. Global reference to a module of the same name defined in the module that contains this Relation derivative. - This is the recommended use case. - 2. Module of the same name defined in the package containing this Relation derivative. This will only look for the - most immediate containing package (e.g. if this class is contained in package.subpackage.module, it will - check within `package.subpackage` but not inside `package`). - 3. Globally accessible module with the same name. - """ - # from IPython import embed - # embed() - mod_obj = importlib.import_module(cls.__module__) - if cls.__module__.split('.')[-1] == module_name: - return mod_obj - attr = getattr(mod_obj, module_name, None) - if isinstance(attr, ModuleType): - return attr - if mod_obj.__package__: - try: - return importlib.import_module('.' 
+ module_name, mod_obj.__package__) - except ImportError: - pass - try: - return importlib.import_module(module_name) - except ImportError: - return None + @property + def full_table_name(self): + return r"`{0:s}`.`{1:s}`".format(self.database, self.table_name) + + def insert(self, tup, ignore_errors=False, replace=False): + """ + Insert one data record or one Mapping (like a dictionary). + + :param tup: Data record, or a Mapping (like a dictionary). + :param ignore_errors=False: Ignores errors if True. + :param replace=False: Replaces data tuple if True. + + Example:: + + b = djtest.Subject() + b.insert(dict(subject_id = 7, species="mouse",\\ + real_id = 1007, date_of_birth = "2014-09-01")) + """ + + heading = self.heading + if isinstance(tup, np.void): + for fieldname in tup.dtype.fields: + if fieldname not in heading: + raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname, )) + value_list = ','.join([repr(tup[name]) if not heading[name].is_blob else '%s' + for name in heading if name in tup.dtype.fields]) + + args = tuple(pack(tup[name]) for name in heading + if name in tup.dtype.fields and heading[name].is_blob) + attribute_list = '`' + '`,`'.join(q for q in heading if q in tup.dtype.fields) + '`' + elif isinstance(tup, Mapping): + for fieldname in tup.keys(): + if fieldname not in heading: + raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname, )) + value_list = ','.join(repr(tup[name]) if not heading[name].is_blob else '%s' + for name in heading if name in tup) + args = tuple(pack(tup[name]) for name in heading + if name in tup and heading[name].is_blob) + attribute_list = '`' + '`,`'.join(name for name in heading if name in tup) + '`' + else: + raise DataJointError('Datatype %s cannot be inserted' % type(tup)) + if replace: + sql = 'REPLACE' + elif ignore_errors: + sql = 'INSERT IGNORE' + else: + sql = 'INSERT' + sql += " INTO %s (%s) VALUES (%s)" % (self.from_clause, attribute_list, value_list) + logger.info(sql) + 
 self.connection.query(sql, args=args) + + def delete(self): + if not config['safemode'] or user_choice( + "You are about to delete data from a table. This operation cannot be undone.\n" + "Proceed?", default='no') == 'yes': + self.connection.query('DELETE FROM ' + self.from_clause + self.where_clause) # TODO: make cascading (issue #15) + + def drop(self): + """ + Drops the table associated to this class. + """ + if self.is_declared: + if not config['safemode'] or user_choice( + "You are about to drop an entire table. This operation cannot be undone.\n" + "Proceed?", default='no') == 'yes': + self.connection.query('DROP TABLE %s' % self.full_table_name) # TODO: make cascading (issue #16) + # cls.connection.clear_dependencies(dbname=cls.dbname) #TODO: reimplement because clear_dependencies will be gone + # cls.connection.load_headings(dbname=cls.dbname, force=True) #TODO: reimplement because load_headings is gone + logger.info("Dropped table %s" % self.full_table_name) + + def size_on_disk(self): + """ + :return: size of data and indices in MiB taken by the table on the storage device + """ + ret = self.connection.query( + 'SHOW TABLE STATUS FROM `{database}` WHERE NAME="{table}"'.format( + database=self.database, table=self.table_name), as_dict=True + ).fetchone() + return (ret['Data_length'] + ret['Index_length'])/1024**2 + + def set_table_comment(self, comment): + """ + Update the table comment in the table definition. + :param comment: new comment as string + """ + self._alter('COMMENT="%s"' % comment) + + def add_attribute(self, definition, after=None): + """ + Add a new attribute to the table. A full line from the table definition + is passed in as definition. + + The definition can specify where to place the new attribute. Use after=None + to add the attribute as the first attribute or after='attribute' to place it + after an existing attribute. 
+ + :param definition: table definition + :param after=None: After which attribute of the table the new attribute is inserted. + If None, the attribute is inserted in front. + """ + position = ' FIRST' if after is None else ( + ' AFTER %s' % after if after else '') + sql = compile_attribute(definition)[1] + self._alter('ADD COLUMN %s%s' % (sql, position)) + + def drop_attribute(self, attribute_name): + """ + Drops the attribute attribute_name from this table. + :param attribute_name: Name of the attribute that is dropped. + """ + if not config['safemode'] or user_choice( + "You are about to drop an attribute from a table. " + "This operation cannot be undone.\n" + "Proceed?", default='no') == 'yes': + self._alter('DROP COLUMN `%s`' % attribute_name) + + def alter_attribute(self, attribute_name, definition): + """ + Alter attribute definition + + :param attribute_name: field that is redefined + :param definition: new definition of the field + """ + sql = compile_attribute(definition)[1] + self._alter('CHANGE COLUMN `%s` %s' % (attribute_name, sql)) + + def erd(self, subset=None): + """ + Plot the schema's entity relationship diagram (ERD). + """ + NotImplemented + + def _alter(self, alter_statement): + """ + Execute an ALTER TABLE statement. 
+ :param alter_statement: alter statement + """ + if self.connection.in_transaction: + raise DataJointError("Table definition cannot be altered during a transaction.") + sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement) + self.connection.query(sql) + self.heading.init_from_database(self.connection, self.database, self.table_name) diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index 51f81c43a..367fd2d1e 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -5,11 +5,12 @@ import numpy as np import abc import re +from collections import OrderedDict from copy import copy from datajoint import DataJointError, config -from .blob import unpack import logging -import numpy.lib.recfunctions as rfn + +from .blob import unpack logger = logging.getLogger(__name__) @@ -24,26 +25,44 @@ class RelationalOperand(metaclass=abc.ABCMeta): RelationalOperand operators are: restrict, pro, and join. """ - def __init__(self, conn, restrictions=None): - self._conn = conn - self._restrictions = [] if restrictions is None else restrictions + _restrictions = None @property - def conn(self): - return self._conn + def restrictions(self): + return [] if self._restrictions is None else self._restrictions @property - def restrictions(self): - return self._restrictions + def primary_key(self): + return self.heading.primary_key + + # --------- abstract properties ----------- - @abc.abstractproperty + @property + @abc.abstractmethod + def connection(self): + """ + :return: a datajoint.Connection object + """ + pass + + @property + @abc.abstractmethod def from_clause(self): + """ + :return: a string containing the FROM clause of the SQL SELECT statement + """ pass - @abc.abstractproperty + @property + @abc.abstractmethod def heading(self): + """ + :return: a valid datajoint.Heading object + """ pass + # --------- relational operators ----------- + def __mul__(self, other): """ relational join @@ -98,10 +117,8 @@ def 
__and__(self, restriction): """ relational restriction or semijoin """ - if self._restrictions is None: - self._restrictions = [] - ret = copy(self) # todo: why not deepcopy it? - ret._restrictions = list(ret._restrictions) # copy restriction + ret = copy(self) + ret._restrictions = list(ret.restrictions) # copy restriction list ret &= restriction return ret @@ -118,6 +135,8 @@ def __sub__(self, restriction): """ return self & Not(restriction) + # ------ data retrieval methods ----------- + def make_select(self, attribute_spec=None): if attribute_spec is None: attribute_spec = self.heading.as_sql @@ -127,7 +146,7 @@ def __len__(self): """ number of tuples in the relation. This also takes care of the truth value """ - cur = self.conn.query(self.make_select('count(*)')) + cur = self.connection.query(self.make_select('count(*)')) return cur.fetchone()[0] def __contains__(self, item): @@ -152,8 +171,8 @@ def fetch1(self): ret = cur.fetchone() if not ret or cur.fetchone(): raise DataJointError('fetch1 should only be used for relations with exactly one tuple') - ret = {k: unpack(v) if heading[k].is_blob else v for k, v in ret.items()} - return ret + return OrderedDict((name, unpack(ret[name]) if heading[name].is_blob else ret[name]) + for name in self.heading.names) def fetch(self, offset=0, limit=None, order_by=None, descending=False, as_dict=False): """ @@ -169,8 +188,8 @@ def fetch(self, offset=0, limit=None, order_by=None, descending=False, as_dict=F descending=descending, as_dict=as_dict) heading = self.heading if as_dict: - ret = [{k: unpack(v) if heading[k].is_blob else v - for k, v in d.items()} + ret = [OrderedDict((name, unpack(d[name]) if heading[name].is_blob else d[name]) + for name in self.heading.names) for d in cur.fetchall()] else: ret = np.array(list(cur.fetchall()), dtype=heading.as_dtype) @@ -196,7 +215,7 @@ def cursor(self, offset=0, limit=None, order_by=None, descending=False, as_dict= if offset: sql += ' OFFSET %d' % offset logger.debug(sql) - 
return self.conn.query(sql, as_dict=as_dict) + return self.connection.query(sql, as_dict=as_dict) def __repr__(self): limit = config['display.limit'] @@ -236,7 +255,7 @@ def where_clause(self): def make_condition(arg): if isinstance(arg, dict): - conditions = ['`%s`=%s' % (k, repr(v)) for k, v in arg.items()] + conditions = ['`%s`=%s' % (k, repr(v)) for k, v in arg.items() if k in self.heading] elif isinstance(arg, np.void): conditions = ['`%s`=%s' % (k, arg[k]) for k in arg.dtype.fields] else: @@ -281,11 +300,15 @@ class Join(RelationalOperand): def __init__(self, arg1, arg2): if not isinstance(arg2, RelationalOperand): raise DataJointError('a relation can only be joined with another relation') - if arg1.conn != arg2.conn: + if arg1.connection != arg2.connection: raise DataJointError('Cannot join relations with different database connections') self._arg1 = Subquery(arg1) if arg1.heading.computed else arg1 self._arg2 = Subquery(arg1) if arg2.heading.computed else arg2 - super().__init__(arg1.conn, self._arg1.restrictions + self._arg2.restrictions) + self._restrictions = self._arg1.restrictions + self._arg2.restrictions + + @property + def connection(self): + return self._arg1.connection @property def counter(self): @@ -318,9 +341,8 @@ def __init__(self, arg, group=None, *attributes, **renamed_attributes): self._renamed_attributes.update({d['alias']: d['sql_expression']}) else: self._attributes.append(attribute) - super().__init__(arg.conn) if group: - if arg.conn != group.conn: + if arg.connection != group.connection: raise DataJointError('Cannot join relations with different database connections') self._group = Subquery(group) self._arg = Subquery(arg) @@ -333,6 +355,10 @@ def __init__(self, arg, group=None, *attributes, **renamed_attributes): self._arg = arg self._restrictions = self._arg.restrictions + @property + def connection(self): + return self._arg.connection + @property def heading(self): return self._arg.heading.project(*self._attributes, 
**self._renamed_attributes) @@ -356,7 +382,10 @@ class Subquery(RelationalOperand): def __init__(self, arg): self._arg = arg - super().__init__(arg.conn) + + @property + def connection(self): + return self._arg.connection @property def counter(self): diff --git a/datajoint/settings.py b/datajoint/settings.py index 7e5bfe04a..7dac8186a 100644 --- a/datajoint/settings.py +++ b/datajoint/settings.py @@ -107,7 +107,6 @@ def save(self, filename=None): :param filename: filename of the local JSON settings file. If None, the local config file is used. """ if filename is None: - import datajoint as dj filename = LOCALCONFIG with open(filename, 'w') as fid: json.dump(self._conf, fid, indent=4) @@ -119,8 +118,7 @@ def load(self, filename): :param filename=None: filename of the local JSON settings file. If None, the local config file is used. """ if filename is None: - import datajoint as dj - filename = dj.config['config.file'] + filename = LOCALCONFIG with open(filename, 'r') as fid: self.update(json.load(fid)) diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py new file mode 100644 index 000000000..1b0352a03 --- /dev/null +++ b/datajoint/user_relations.py @@ -0,0 +1,60 @@ +import re +import abc +from datajoint.relation import Relation +from .autopopulate import AutoPopulate +from . 
import DataJointError + + +class Manual(Relation): + @property + def table_name(self): + return from_camel_case(self.__class__.__name__) + + +class Lookup(Relation): + @property + def table_name(self): + return '#' + from_camel_case(self.__class__.__name__) + + +class Imported(Relation, AutoPopulate): + @property + def table_name(self): + return "_" + from_camel_case(self.__class__.__name__) + + +class Computed(Relation, AutoPopulate): + @property + def table_name(self): + return "__" + from_camel_case(self.__class__.__name__) + + +class Subordinate: + """ + Mix-in to make computed tables subordinate + """ + @property + def populate_relation(self): + return None + + def _make_tuples(self, key): + raise NotImplementedError('Subtables should not be populated directly.') + + +# ---------------- utilities -------------------- +def from_camel_case(s): + """ + Convert names in camel case into underscore (_) separated names + + Example: + >>>from_camel_case("TableName") + "table_name" + """ + def convert(match): + return ('_' if match.groups()[0] else '') + match.group(0).lower() + + if not re.match(r'[A-Z][a-zA-Z0-9]*', s): + raise DataJointError( + 'ClassName must be alphanumeric in CamelCase, begin with a capital letter') + return re.sub(r'(\B[A-Z])|(\b[A-Z])', convert, s) + diff --git a/datajoint/utils.py b/datajoint/utils.py index ec506472e..f4b0edb57 100644 --- a/datajoint/utils.py +++ b/datajoint/utils.py @@ -1,37 +1,3 @@ -import re -from . import DataJointError - - -def to_camel_case(s): - """ - Convert names with under score (_) separation - into camel case names. 
- - Example: - >>>to_camel_case("table_name") - "TableName" - """ - def to_upper(match): - return match.group(0)[-1].upper() - return re.sub('(^|[_\W])+[a-zA-Z]', to_upper, s) - - -def from_camel_case(s): - """ - Convert names in camel case into underscore (_) separated names - - Example: - >>>from_camel_case("TableName") - "table_name" - """ - def convert(match): - return ('_' if match.groups()[0] else '') + match.group(0).lower() - - if not re.match(r'[A-Z][a-zA-Z0-9]*', s): - raise DataJointError( - 'ClassName must be alphanumeric in CamelCase, begin with a capital letter') - return re.sub(r'(\B[A-Z])|(\b[A-Z])', convert, s) - def user_choice(prompt, choices=("yes", "no"), default=None): """ diff --git a/demos/demo1.py b/demos/demo1.py index e85d6ead3..d30894617 100644 --- a/demos/demo1.py +++ b/demos/demo1.py @@ -1,21 +1,15 @@ # -*- coding: utf-8 -*- -""" -Created on Tue Aug 26 17:42:52 2014 - -@author: dimitri -""" import datajoint as dj print("Welcome to the database 'demo1'") -conn = dj.conn() # connect to database; conn must be defined in module namespace -conn.bind(module=__name__, dbname='dj_test') # bind this module to the database - +schema = dj.schema('dj_test', locals()) -class Subject(dj.Relation): +@schema +class Subject(dj.Manual): definition = """ - demo1.Subject (manual) # Basic subject info + # Basic subject info subject_id : int # internal subject id --- real_id : varchar(40) # real-world name @@ -26,11 +20,16 @@ class Subject(dj.Relation): animal_notes="" : varchar(4096) # strain, genetic manipulations, etc """ +s = Subject() +p = s.primary_key + -class Experiment(dj.Relation): +@schema +class Experiment(dj.Manual): definition = """ - demo1.Experiment (manual) # Basic subject info - -> demo1.Subject + # Basic subject info + + -> Subject experiment : smallint # experiment number for this subject --- experiment_folder : varchar(255) # folder path @@ -40,10 +39,12 @@ class Experiment(dj.Relation): """ -class Session(dj.Relation): +@schema +class 
Session(dj.Manual): definition = """ - demo1.Session (manual) # a two-photon imaging session - -> demo1.Experiment + # a two-photon imaging session + + -> Experiment session_id : tinyint # two-photon session within this experiment ----------- setup : tinyint # experimental setup @@ -51,11 +52,12 @@ class Session(dj.Relation): """ -class Scan(dj.Relation): +@schema +class Scan(dj.Manual): definition = """ - demo1.Scan (manual) # a two-photon imaging session - -> demo1.Session - -> Config + # a two-photon imaging session + + -> Session scan_id : tinyint # two-photon session within this experiment ---- depth : float # depth from surface @@ -63,16 +65,3 @@ class Scan(dj.Relation): mwatts: numeric(4,1) # (mW) laser power to brain """ -class Config(dj.Relation): - definition = """ - demo1.Config (manual) # configuration for scanner - config_id : tinyint # unique id for config setup - --- - ->ConfigParam - """ - -class ConfigParam(dj.Relation): - definition = """ - demo1.ConfigParam (lookup) # params for configurations - param_set_id : tinyint # id for params - """ \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 82abfa7ea..af4d48f13 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ numpy pymysql +pyparsing networkx matplotlib sphinx_rtd_theme diff --git a/tests/__init__.py b/tests/__init__.py index 09e358e98..611f34bc2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -18,8 +18,11 @@ 'user': environ.get('DJ_TEST_USER', 'datajoint'), 'passwd': environ.get('DJ_TEST_PASSWORD', 'datajoint') } + +conn = dj.conn(**CONN_INFO) + # Prefix for all databases used during testing -PREFIX = environ.get('DJ_TEST_DB_PREFIX', 'dj') +PREFIX = environ.get('DJ_TEST_DB_PREFIX', 'djtest') # Bare connection used for verification of query results BASE_CONN = pymysql.connect(**CONN_INFO) BASE_CONN.autocommit(True) @@ -30,8 +33,17 @@ def setup(): dj.config['safemode'] = False def teardown(): - cleanup() - + cur = BASE_CONN.cursor() + # 
cancel any unfinished transactions + cur.execute("ROLLBACK") + # start a transaction now + cur.execute("SHOW DATABASES LIKE '{}\_%'".format(PREFIX)) + dbs = [x[0] for x in cur.fetchall()] + cur.execute('SET FOREIGN_KEY_CHECKS=0') # unset foreign key check while deleting + for db in dbs: + cur.execute('DROP DATABASE `{}`'.format(db)) + cur.execute('SET FOREIGN_KEY_CHECKS=1') # set foreign key check back on + cur.execute("COMMIT") def cleanup(): """ @@ -49,39 +61,42 @@ def cleanup(): dbs = [x[0] for x in cur.fetchall()] cur.execute('SET FOREIGN_KEY_CHECKS=0') # unset foreign key check while deleting for db in dbs: - cur.execute('DROP DATABASE `{}`'.format(db)) + cur.execute("USE %s" % (db)) + cur.execute("SHOW TABLES") + for table in [x[0] for x in cur.fetchall()]: + cur.execute('DELETE FROM `{}`'.format(table)) cur.execute('SET FOREIGN_KEY_CHECKS=1') # set foreign key check back on cur.execute("COMMIT") - -def setup_sample_db(): - """ - Helper method to setup databases with tables to be used - during the test - """ - cur = BASE_CONN.cursor() - cur.execute("CREATE DATABASE IF NOT EXISTS `{}_test1`".format(PREFIX)) - cur.execute("CREATE DATABASE IF NOT EXISTS `{}_test2`".format(PREFIX)) - query1 = """ - CREATE TABLE `{prefix}_test1`.`subjects` - ( - subject_id SMALLINT COMMENT 'Unique subject ID', - subject_name VARCHAR(255) COMMENT 'Subject name', - subject_email VARCHAR(255) COMMENT 'Subject email address', - PRIMARY KEY (subject_id) - ) - """.format(prefix=PREFIX) - cur.execute(query1) - query2 = """ - CREATE TABLE `{prefix}_test2`.`experimenter` - ( - experimenter_id SMALLINT COMMENT 'Unique experimenter ID', - experimenter_name VARCHAR(255) COMMENT 'Experimenter name', - PRIMARY KEY (experimenter_id) - )""".format(prefix=PREFIX) - cur.execute(query2) - - - - - - +# +# def setup_sample_db(): +# """ +# Helper method to setup databases with tables to be used +# during the test +# """ +# cur = BASE_CONN.cursor() +# cur.execute("CREATE DATABASE IF NOT EXISTS 
`{}_test1`".format(PREFIX)) +# cur.execute("CREATE DATABASE IF NOT EXISTS `{}_test2`".format(PREFIX)) +# query1 = """ +# CREATE TABLE `{prefix}_test1`.`subjects` +# ( +# subject_id SMALLINT COMMENT 'Unique subject ID', +# subject_name VARCHAR(255) COMMENT 'Subject name', +# subject_email VARCHAR(255) COMMENT 'Subject email address', +# PRIMARY KEY (subject_id) +# ) +# """.format(prefix=PREFIX) +# cur.execute(query1) +# query2 = """ +# CREATE TABLE `{prefix}_test2`.`experimenter` +# ( +# experimenter_id SMALLINT COMMENT 'Unique experimenter ID', +# experimenter_name VARCHAR(255) COMMENT 'Experimenter name', +# PRIMARY KEY (experimenter_id) +# )""".format(prefix=PREFIX) +# cur.execute(query2) +# +# +# +# +# +# diff --git a/tests/schemata/__init__.py b/tests/schemata/__init__.py index 6f391d065..9fa8b9ad1 100644 --- a/tests/schemata/__init__.py +++ b/tests/schemata/__init__.py @@ -1 +1 @@ -__author__ = "eywalker" \ No newline at end of file +__author__ = "eywalker, fabiansinz" \ No newline at end of file diff --git a/tests/schemata/schema1/__init__.py b/tests/schemata/schema1/__init__.py index 6032e7bd6..cae90cec9 100644 --- a/tests/schemata/schema1/__init__.py +++ b/tests/schemata/schema1/__init__.py @@ -1,5 +1,5 @@ -__author__ = 'eywalker' -import datajoint as dj - -print(__name__) -from .test3 import * \ No newline at end of file +# __author__ = 'eywalker' +# import datajoint as dj +# +# print(__name__) +# from .test3 import * \ No newline at end of file diff --git a/tests/schemata/schema1/test1.py b/tests/schemata/schema1/test1.py deleted file mode 100644 index 4c8df082f..000000000 --- a/tests/schemata/schema1/test1.py +++ /dev/null @@ -1,166 +0,0 @@ -""" -Test 1 Schema definition -""" -__author__ = 'eywalker' - -import datajoint as dj -from .. 
import schema2 - - -class Subjects(dj.Relation): - definition = """ - test1.Subjects (manual) # Basic subject info - - subject_id : int # unique subject id - --- - real_id : varchar(40) # real-world name - species = "mouse" : enum('mouse', 'monkey', 'human') # species - """ - -# test for shorthand -class Animals(dj.Relation): - definition = """ - test1.Animals (manual) # Listing of all info - - -> Subjects - --- - animal_dob :date # date of birth - """ - - -class Trials(dj.Relation): - definition = """ - test1.Trials (manual) # info about trials - - -> test1.Subjects - trial_id : int - --- - outcome : int # result of experiment - - notes="" : varchar(4096) # other comments - trial_ts=CURRENT_TIMESTAMP : timestamp # automatic - """ - - - -class SquaredScore(dj.Relation, dj.AutoPopulate): - definition = """ - test1.SquaredScore (computed) # cumulative outcome of trials - - -> test1.Subjects - -> test1.Trials - --- - squared : int # squared result of Trials outcome - """ - - @property - def populate_relation(self): - return Subjects() * Trials() - - def _make_tuples(self, key): - tmp = (Trials() & key).fetch1() - tmp2 = SquaredSubtable() & key - - self.insert(dict(key, squared=tmp['outcome']**2)) - - ss = SquaredSubtable() - - for i in range(10): - key['dummy'] = i - ss.insert(key) - - -class WrongImplementation(dj.Relation, dj.AutoPopulate): - definition = """ - test1.WrongImplementation (computed) # ignore - - -> test1.Subjects - -> test1.Trials - --- - dummy : int # ignore - """ - - @property - def populate_relation(self): - return {'subject_id':2} - - def _make_tuples(self, key): - pass - - - -class ErrorGenerator(dj.Relation, dj.AutoPopulate): - definition = """ - test1.ErrorGenerator (computed) # ignore - - -> test1.Subjects - -> test1.Trials - --- - dummy : int # ignore - """ - - @property - def populate_relation(self): - return Subjects() * Trials() - - def _make_tuples(self, key): - raise Exception("This is for testing") - - - - - - -class 
SquaredSubtable(dj.Relation): - definition = """ - test1.SquaredSubtable (computed) # cumulative outcome of trials - - -> test1.SquaredScore - dummy : int # dummy primary attribute - --- - """ - - -# test reference to another table in same schema -class Experiments(dj.Relation): - definition = """ - test1.Experiments (imported) # Experiment info - -> test1.Subjects - exp_id : int # unique id for experiment - --- - exp_data_file : varchar(255) # data file - """ - - -# refers to a table in dj_test2 (bound to test2) but without a class -class Sessions(dj.Relation): - definition = """ - test1.Sessions (manual) # Experiment sessions - -> test1.Subjects - -> test2.Experimenter - session_id : int # unique session id - --- - session_comment : varchar(255) # comment about the session - """ - - -class Match(dj.Relation): - definition = """ - test1.Match (manual) # Match between subject and color - -> schema2.Subjects - --- - dob : date # date of birth - """ - - -# this tries to reference a table in database directly without ORM -class TrainingSession(dj.Relation): - definition = """ - test1.TrainingSession (manual) # training sessions - -> `dj_test2`.Experimenter - session_id : int # training session id - """ - - -class Empty(dj.Relation): - pass diff --git a/tests/schemata/schema1/test2.py b/tests/schemata/schema1/test2.py index 563fe6b52..aded2e4fb 100644 --- a/tests/schemata/schema1/test2.py +++ b/tests/schemata/schema1/test2.py @@ -1,48 +1,48 @@ -""" -Test 2 Schema definition -""" -__author__ = 'eywalker' - -import datajoint as dj -from . 
import test1 as alias -#from ..schema2 import test2 as test1 - - -# references to another schema -class Experiments(dj.Relation): - definition = """ - test2.Experiments (manual) # Basic subject info - -> test1.Subjects - experiment_id : int # unique experiment id - --- - real_id : varchar(40) # real-world name - species = "mouse" : enum('mouse', 'monkey', 'human') # species - """ - - -# references to another schema -class Conditions(dj.Relation): - definition = """ - test2.Conditions (manual) # Subject conditions - -> alias.Subjects - condition_name : varchar(255) # description of the condition - """ - - -class FoodPreference(dj.Relation): - definition = """ - test2.FoodPreference (manual) # Food preference of each subject - -> animals.Subjects - preferred_food : enum('banana', 'apple', 'oranges') - """ - - -class Session(dj.Relation): - definition = """ - test2.Session (manual) # Experiment sessions - -> test1.Subjects - -> test2.Experimenter - session_id : int # unique session id - --- - session_comment : varchar(255) # comment about the session - """ \ No newline at end of file +# """ +# Test 2 Schema definition +# """ +# __author__ = 'eywalker' +# +# import datajoint as dj +# from . 
import test1 as alias +# #from ..schema2 import test2 as test1 +# +# +# # references to another schema +# class Experiments(dj.Relation): +# definition = """ +# test2.Experiments (manual) # Basic subject info +# -> test1.Subjects +# experiment_id : int # unique experiment id +# --- +# real_id : varchar(40) # real-world name +# species = "mouse" : enum('mouse', 'monkey', 'human') # species +# """ +# +# +# # references to another schema +# class Conditions(dj.Relation): +# definition = """ +# test2.Conditions (manual) # Subject conditions +# -> alias.Subjects +# condition_name : varchar(255) # description of the condition +# """ +# +# +# class FoodPreference(dj.Relation): +# definition = """ +# test2.FoodPreference (manual) # Food preference of each subject +# -> animals.Subjects +# preferred_food : enum('banana', 'apple', 'oranges') +# """ +# +# +# class Session(dj.Relation): +# definition = """ +# test2.Session (manual) # Experiment sessions +# -> test1.Subjects +# -> test2.Experimenter +# session_id : int # unique session id +# --- +# session_comment : varchar(255) # comment about the session +# """ \ No newline at end of file diff --git a/tests/schemata/schema1/test3.py b/tests/schemata/schema1/test3.py index 2004a8736..e00a01afb 100644 --- a/tests/schemata/schema1/test3.py +++ b/tests/schemata/schema1/test3.py @@ -1,21 +1,21 @@ -""" -Test 3 Schema definition - no binding, no conn - -To be bound at the package level -""" -__author__ = 'eywalker' - -import datajoint as dj - - -class Subjects(dj.Relation): - definition = """ - schema1.Subjects (manual) # Basic subject info - - subject_id : int # unique subject id - dob : date # date of birth - --- - real_id : varchar(40) # real-world name - species = "mouse" : enum('mouse', 'monkey', 'human') # species - """ - +# """ +# Test 3 Schema definition - no binding, no conn +# +# To be bound at the package level +# """ +# __author__ = 'eywalker' +# +# import datajoint as dj +# +# +# class Subjects(dj.Relation): +# 
definition = """ +# schema1.Subjects (manual) # Basic subject info +# +# subject_id : int # unique subject id +# dob : date # date of birth +# --- +# real_id : varchar(40) # real-world name +# species = "mouse" : enum('mouse', 'monkey', 'human') # species +# """ +# diff --git a/tests/schemata/schema1/test4.py b/tests/schemata/schema1/test4.py index a2004affd..9860cb030 100644 --- a/tests/schemata/schema1/test4.py +++ b/tests/schemata/schema1/test4.py @@ -1,17 +1,17 @@ -""" -Test 1 Schema definition - fully bound and has connection object -""" -__author__ = 'fabee' - -import datajoint as dj - - -class Matrix(dj.Relation): - definition = """ - test4.Matrix (manual) # Some numpy array - - matrix_id : int # unique matrix id - --- - data : longblob # data - comment : varchar(1000) # comment - """ +# """ +# Test 1 Schema definition - fully bound and has connection object +# """ +# __author__ = 'fabee' +# +# import datajoint as dj +# +# +# class Matrix(dj.Relation): +# definition = """ +# test4.Matrix (manual) # Some numpy array +# +# matrix_id : int # unique matrix id +# --- +# data : longblob # data +# comment : varchar(1000) # comment +# """ diff --git a/tests/schemata/schema2/__init__.py b/tests/schemata/schema2/__init__.py index e6b482590..d79e02cc3 100644 --- a/tests/schemata/schema2/__init__.py +++ b/tests/schemata/schema2/__init__.py @@ -1,2 +1,2 @@ -__author__ = 'eywalker' -from .test1 import * \ No newline at end of file +# __author__ = 'eywalker' +# from .test1 import * \ No newline at end of file diff --git a/tests/schemata/schema2/test1.py b/tests/schemata/schema2/test1.py index 83bb3a19e..4005aa670 100644 --- a/tests/schemata/schema2/test1.py +++ b/tests/schemata/schema2/test1.py @@ -1,16 +1,16 @@ -""" -Test 2 Schema definition -""" -__author__ = 'eywalker' - -import datajoint as dj - - -class Subjects(dj.Relation): - definition = """ - schema2.Subjects (manual) # Basic subject info - pop_id : int # unique experiment id - --- - real_id : varchar(40) # 
real-world name - species = "mouse" : enum('mouse', 'monkey', 'human') # species - """ \ No newline at end of file +# """ +# Test 2 Schema definition +# """ +# __author__ = 'eywalker' +# +# import datajoint as dj +# +# +# class Subjects(dj.Relation): +# definition = """ +# schema2.Subjects (manual) # Basic subject info +# pop_id : int # unique experiment id +# --- +# real_id : varchar(40) # real-world name +# species = "mouse" : enum('mouse', 'monkey', 'human') # species +# """ \ No newline at end of file diff --git a/tests/schemata/test1.py b/tests/schemata/test1.py new file mode 100644 index 000000000..a9a03be64 --- /dev/null +++ b/tests/schemata/test1.py @@ -0,0 +1,175 @@ +""" +Test 1 Schema definition +""" +__author__ = 'eywalker' + +import datajoint as dj +# from .. import schema2 +from .. import PREFIX + +testschema = dj.schema(PREFIX + '_test1', locals()) + +@testschema +class Subjects(dj.Manual): + definition = """ + #Basic subject info + + subject_id : int # unique subject id + --- + real_id : varchar(40) # real-world name + species = "mouse" : enum('mouse', 'monkey', 'human') # species + """ + +# test for shorthand +@testschema +class Animals(dj.Manual): + definition = """ + # Listing of all info + + -> Subjects + --- + animal_dob :date # date of birth + """ + +@testschema +class Trials(dj.Manual): + definition = """ + # info about trials + + -> Subjects + trial_id : int + --- + outcome : int # result of experiment + + notes="" : varchar(4096) # other comments + trial_ts=CURRENT_TIMESTAMP : timestamp # automatic + """ + +@testschema +class Matrix(dj.Manual): + definition = """ + # Some numpy array + + matrix_id : int # unique matrix id + --- + data : longblob # data + comment : varchar(1000) # comment + """ + + +@testschema +class SquaredScore(dj.Computed): + definition = """ + # cumulative outcome of trials + + -> Subjects + -> Trials + --- + squared : int # squared result of Trials outcome + """ + + @property + def populate_relation(self): + return 
Subjects() * Trials() + + def _make_tuples(self, key): + tmp = (Trials() & key).fetch1() + tmp2 = SquaredSubtable() & key + + self.insert(dict(key, squared=tmp['outcome']**2)) + + ss = SquaredSubtable() + + for i in range(10): + key['dummy'] = i + ss.insert(key) + +@testschema +class WrongImplementation(dj.Computed): + definition = """ + # ignore + + -> Subjects + -> Trials + --- + dummy : int # ignore + """ + + @property + def populate_relation(self): + return {'subject_id':2} + + def _make_tuples(self, key): + pass + +class ErrorGenerator(dj.Computed): + definition = """ + # ignore + + -> Subjects + -> Trials + --- + dummy : int # ignore + """ + + @property + def populate_relation(self): + return Subjects() * Trials() + + def _make_tuples(self, key): + raise Exception("This is for testing") + +@testschema +class SquaredSubtable(dj.Subordinate, dj.Manual): + definition = """ + # cumulative outcome of trials + + -> SquaredScore + dummy : int # dummy primary attribute + --- + """ +# +# +# # test reference to another table in same schema +# class Experiments(dj.Relation): +# definition = """ +# test1.Experiments (imported) # Experiment info +# -> test1.Subjects +# exp_id : int # unique id for experiment +# --- +# exp_data_file : varchar(255) # data file +# """ +# +# +# # refers to a table in dj_test2 (bound to test2) but without a class +# class Sessions(dj.Relation): +# definition = """ +# test1.Sessions (manual) # Experiment sessions +# -> test1.Subjects +# -> test2.Experimenter +# session_id : int # unique session id +# --- +# session_comment : varchar(255) # comment about the session +# """ +# +# +# class Match(dj.Relation): +# definition = """ +# test1.Match (manual) # Match between subject and color +# -> schema2.Subjects +# --- +# dob : date # date of birth +# """ +# +# +# # this tries to reference a table in database directly without ORM +# class TrainingSession(dj.Relation): +# definition = """ +# test1.TrainingSession (manual) # training sessions +# -> 
`dj_test2`.Experimenter +# session_id : int # training session id +# """ +# +# +# class Empty(dj.Relation): +# pass diff --git a/tests/test_blob.py b/tests/test_blob.py index 489707f31..322be02cc 100644 --- a/tests/test_blob.py +++ b/tests/test_blob.py @@ -1,7 +1,9 @@ +from datajoint import DataJointError + __author__ = 'fabee' import numpy as np from datajoint.blob import pack, unpack -from numpy.testing import assert_array_equal +from numpy.testing import assert_array_equal, raises def test_pack(): @@ -16,3 +18,20 @@ def test_pack(): x = np.int16(np.random.randn(1, 2, 3)) assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") + +@raises(DataJointError) +def test_error(): + pack(dict()) + +def test_complex(): + z = np.random.randn(8, 10) + 1j*np.random.randn(8,10) + assert_array_equal(z, unpack(pack(z)), "Arrays do not match!") + + z = np.random.randn(10)+ 1j*np.random.randn(10) + assert_array_equal(z, unpack(pack(z)), "Arrays do not match!") + + x = np.float32(np.random.randn(3, 4, 5)) + 1j*np.float32(np.random.randn(3, 4, 5)) + assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") + + x = np.int16(np.random.randn(1, 2, 3)) + 1j*np.int16(np.random.randn(1, 2, 3)) + assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") diff --git a/tests/test_connection.py b/tests/test_connection.py index 1fb581468..b94a07adf 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,16 +1,14 @@ """ Collection of test cases to test connection module. """ -from .schemata import schema1 -from .schemata.schema1 import test1 -import numpy as np +from tests.schemata.test1 import Subjects __author__ = 'eywalker, fabee' -from . import (CONN_INFO, PREFIX, BASE_CONN, cleanup) +from . 
import CONN_INFO, PREFIX, BASE_CONN, cleanup from nose.tools import assert_true, assert_raises, assert_equal, raises import datajoint as dj -from datajoint.utils import DataJointError - +from datajoint import DataJointError +import numpy as np def setup(): cleanup() @@ -44,214 +42,70 @@ def test_dj_conn_reset(): assert_true(c1 is not c2) +def test_repr(): + c1 = dj.conn(**CONN_INFO) + assert_true('disconnected' not in c1.__repr__() and 'connected' in c1.__repr__()) -def setup_sample_db(): - """ - Helper method to setup databases with tables to be used - during the test - """ - cur = BASE_CONN.cursor() - cur.execute("CREATE DATABASE `{}_test1`".format(PREFIX)) - cur.execute("CREATE DATABASE `{}_test2`".format(PREFIX)) - query1 = """ - CREATE TABLE `{prefix}_test1`.`subjects` - ( - subject_id SMALLINT COMMENT 'Unique subject ID', - subject_name VARCHAR(255) COMMENT 'Subject name', - subject_email VARCHAR(255) COMMENT 'Subject email address', - PRIMARY KEY (subject_id) - ) - """.format(prefix=PREFIX) - cur.execute(query1) - # query2 = """ - # CREATE TABLE `{prefix}_test2`.`experiments` - # ( - # experiment_id SMALLINT COMMENT 'Unique experiment ID', - # experiment_name VARCHAR(255) COMMENT 'Experiment name', - # subject_id SMALLINT, - # CONSTRAINT FOREIGN KEY (`subject_id`) REFERENCES `dj_test1`.`subjects` (`subject_id`) ON UPDATE CASCADE ON DELETE RESTRICT, - # PRIMARY KEY (subject_id, experiment_id) - # )""".format(prefix=PREFIX) - # cur.execute(query2) - - -class TestConnectionWithoutBindings(object): - """ - Test methods from Connection that does not - depend on presence of module to database bindings. - This includes tests for `bind` method itself. 
- """ - def setup(self): - self.conn = dj.Connection(**CONN_INFO) - test1.__dict__.pop('conn', None) - schema1.__dict__.pop('conn', None) - setup_sample_db() - - def teardown(self): - cleanup() - - def check_binding(self, db_name, module): - """ - Helper method to check if the specified database-module pairing exists - """ - assert_equal(self.conn.db_to_mod[db_name], module) - assert_equal(self.conn.mod_to_db[module], db_name) - - def test_bind_to_existing_database(self): - """ - Should be able to bind a module to an existing database - """ - db_name = PREFIX + '_test1' - module = test1.__name__ - self.conn.bind(module, db_name) - self.check_binding(db_name, module) - - def test_bind_at_package_level(self): - db_name = PREFIX + '_test1' - package = schema1.__name__ - self.conn.bind(package, db_name) - self.check_binding(db_name, package) - - def test_bind_to_non_existing_database(self): - """ - Should be able to bind a module to a non-existing database by creating target - """ - db_name = PREFIX + '_test3' - module = test1.__name__ - cur = BASE_CONN.cursor() +def test_del(): + c1 = dj.conn(**CONN_INFO) + assert_true('disconnected' not in c1.__repr__() and 'connected' in c1.__repr__()) + del c1 - # Ensure target database doesn't exist - if cur.execute("SHOW DATABASES LIKE '{}'".format(db_name)): - cur.execute("DROP DATABASE IF EXISTS `{}`".format(db_name)) - # Bind module to non-existing database - self.conn.bind(module, db_name) - # Check that target database was created - assert_equal(cur.execute("SHOW DATABASES LIKE '{}'".format(db_name)), 1) - self.check_binding(db_name, module) - # Remove the target database - cur.execute("DROP DATABASE IF EXISTS `{}`".format(db_name)) - def test_cannot_bind_to_multiple_databases(self): - """ - Bind will fail when db_name is a pattern that - matches multiple databases - """ - db_name = PREFIX + "_test%%" - module = test1.__name__ - with assert_raises(DataJointError): - self.conn.bind(module, db_name) - def 
test_basic_sql_query(self): - """ - Test execution of basic SQL query using connection - object. - """ - cur = self.conn.query('SHOW DATABASES') - results1 = cur.fetchall() - cur2 = BASE_CONN.cursor() - cur2.execute('SHOW DATABASES') - results2 = cur2.fetchall() - assert_equal(results1, results2) +class TestContextManager(object): + def __init__(self): + self.relvar = None + self.setup() - def test_transaction_commit(self): - """ - Test transaction commit - """ - table_name = PREFIX + '_test1.subjects' - self.conn.start_transaction() - self.conn.query("INSERT INTO {table} VALUES (0, 'dj_user', 'dj_user@example.com')".format(table=table_name)) - cur = BASE_CONN.cursor() - assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 0) - self.conn.commit_transaction() - assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 1) + """ + Test cases for FreeRelation objects + """ - def test_transaction_rollback(self): + def setup(self): """ - Test transaction rollback + Create a connection object and prepare test modules + as follows: + test1 - has conn and bounded """ - table_name = PREFIX + '_test1.subjects' - self.conn.start_transaction() - self.conn.query("INSERT INTO {table} VALUES (0, 'dj_user', 'dj_user@example.com')".format(table=table_name)) - cur = BASE_CONN.cursor() - assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 0) - self.conn.cancel_transaction() - assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 0) - -# class TestContextManager(object): -# def __init__(self): -# self.relvar = None -# self.setup() -# -# """ -# Test cases for FreeRelation objects -# """ -# -# def setup(self): -# """ -# Create a connection object and prepare test modules -# as follows: -# test1 - has conn and bounded -# """ -# cleanup() # drop all databases with PREFIX -# test1.__dict__.pop('conn', None) -# -# self.conn = dj.Connection(**CONN_INFO) -# test1.conn = self.conn -# self.conn.bind(test1.__name__, PREFIX + '_test1') -# self.relvar = 
test1.Subjects() -# -# def teardown(self): -# cleanup() -# -# # def test_active(self): -# # with self.conn.transaction() as tr: -# # assert_true(tr.is_active, "Transaction is not active") -# -# # def test_rollback(self): -# # -# # tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], -# # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) -# # -# # self.relvar.insert(tmp[0]) -# # try: -# # with self.conn.transaction(): -# # self.relvar.insert(tmp[1]) -# # raise DataJointError("Just to test") -# # except DataJointError as e: -# # pass -# # -# # testt2 = (self.relvar & 'subject_id = 2').fetch() -# # assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") -# -# # def test_cancel(self): -# # """Tests cancelling a transaction""" -# # tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], -# # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) -# # -# # self.relvar.insert(tmp[0]) -# # with self.conn.transaction() as transaction: -# # self.relvar.insert(tmp[1]) -# # transaction.cancel() -# # -# # testt2 = (self.relvar & 'subject_id = 2').fetch() -# # assert_equal(len(testt2), 0, "Length is not 0. 
Expected because rollback should have happened.") - - - -# class TestConnectionWithBindings(object): -# """ -# Tests heading and dependency loadings -# """ -# def setup(self): -# self.conn = dj.Connection(**CONN_INFO) -# cur.execute(query) - - - - - + cleanup() # drop all databases with PREFIX + self.conn = dj.conn() + self.relvar = Subjects() + def teardown(self): + cleanup() + def test_active(self): + with self.conn.transaction() as conn: + assert_true(conn.in_transaction, "Transaction is not active") + + def test_rollback(self): + + tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], + dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) + + self.relvar.insert(tmp[0]) + try: + with self.conn.transaction(): + self.relvar.insert(tmp[1]) + raise DataJointError("Just to test") + except DataJointError as e: + pass + testt2 = (self.relvar & 'subject_id = 2').fetch() + assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") + + def test_cancel(self): + """Tests cancelling a transaction""" + tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], + dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) + + self.relvar.insert(tmp[0]) + with self.conn.transaction() as conn: + self.relvar.insert(tmp[1]) + conn.cancel_transaction() + + testt2 = (self.relvar & 'subject_id = 2').fetch() + assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") diff --git a/tests/test_free_relation.py b/tests/test_free_relation.py index e4ebaa872..285bf7487 100644 --- a/tests/test_free_relation.py +++ b/tests/test_free_relation.py @@ -1,205 +1,205 @@ -""" -Collection of test cases for base module. Tests functionalities such as -creating tables using docstring table declarations -""" -from .schemata import schema1, schema2 -from .schemata.schema1 import test1, test2, test3 - - -__author__ = 'eywalker' - -from . 
import BASE_CONN, CONN_INFO, PREFIX, cleanup, setup_sample_db -from datajoint.connection import Connection -from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, raises -from datajoint import DataJointError - - -def setup(): - """ - Setup connections and bindings - """ - pass - - -class TestRelationInstantiations(object): - """ - Test cases for instantiating Relation objects - """ - def __init__(self): - self.conn = None - - def setup(self): - """ - Create a connection object and prepare test modules - as follows: - test1 - has conn and bounded - """ - self.conn = Connection(**CONN_INFO) - cleanup() # drop all databases with PREFIX - #test1.conn = self.conn - #self.conn.bind(test1.__name__, PREFIX+'_test1') - - #test2.conn = self.conn - - #test3.__dict__.pop('conn', None) # make sure conn is not defined in test3 - test1.__dict__.pop('conn', None) - schema1.__dict__.pop('conn', None) # make sure conn is not defined at schema level - - - def teardown(self): - cleanup() - - - def test_instantiation_from_unbound_module_should_fail(self): - """ - Attempting to instantiate a Relation derivative from a module with - connection defined but not bound to a database should raise error - """ - test1.conn = self.conn - with assert_raises(DataJointError) as e: - test1.Subjects() - assert_regexp_matches(e.exception.args[0], r".*not bound.*") - - def test_instantiation_from_module_without_conn_should_fail(self): - """ - Attempting to instantiate a Relation derivative from a module that lacks - `conn` object should raise error - """ - with assert_raises(DataJointError) as e: - test1.Subjects() - assert_regexp_matches(e.exception.args[0], r".*define.*conn.*") - - def test_instantiation_of_base_derivatives(self): - """ - Test instantiation and initialization of objects derived from - Relation class - """ - test1.conn = self.conn - self.conn.bind(test1.__name__, PREFIX + '_test1') - s = test1.Subjects() - assert_equal(s.dbname, PREFIX + 
'_test1') - assert_equal(s.conn, self.conn) - assert_equal(s.definition, test1.Subjects.definition) - - def test_packagelevel_binding(self): - schema2.conn = self.conn - self.conn.bind(schema2.__name__, PREFIX + '_test1') - s = schema2.test1.Subjects() - - -class TestRelationDeclaration(object): - """ - Test declaration (creation of table) from - definition in Relation under various circumstances - """ - - def setup(self): - cleanup() - - self.conn = Connection(**CONN_INFO) - test1.conn = self.conn - self.conn.bind(test1.__name__, PREFIX + '_test1') - test2.conn = self.conn - self.conn.bind(test2.__name__, PREFIX + '_test2') - - def test_is_declared(self): - """ - The table should not be created immediately after instantiation, - but should be created when declare method is called - :return: - """ - s = test1.Subjects() - assert_false(s.is_declared) - s.declare() - assert_true(s.is_declared) - - def test_calling_heading_should_trigger_declaration(self): - s = test1.Subjects() - assert_false(s.is_declared) - a = s.heading - assert_true(s.is_declared) - - def test_foreign_key_ref_in_same_schema(self): - s = test1.Experiments() - assert_true('subject_id' in s.heading.primary_key) - - def test_foreign_key_ref_in_another_schema(self): - s = test2.Experiments() - assert_true('subject_id' in s.heading.primary_key) - - def test_aliased_module_name_should_resolve(self): - """ - Module names that were aliased in the definition should - be properly resolved. 
- """ - s = test2.Conditions() - assert_true('subject_id' in s.heading.primary_key) - - def test_reference_to_unknown_module_in_definition_should_fail(self): - """ - Module names in table definition that is not aliased via import - results in error - """ - s = test2.FoodPreference() - with assert_raises(DataJointError) as e: - s.declare() - - -class TestRelationWithExistingTables(object): - """ - Test base derivatives behaviors when some of the tables - already exists in the database - """ - def setup(self): - cleanup() - self.conn = Connection(**CONN_INFO) - setup_sample_db() - test1.conn = self.conn - self.conn.bind(test1.__name__, PREFIX + '_test1') - test2.conn = self.conn - self.conn.bind(test2.__name__, PREFIX + '_test2') - self.conn.load_headings(force=True) - - schema2.conn = self.conn - self.conn.bind(schema2.__name__, PREFIX + '_package') - - def teardown(selfself): - schema1.__dict__.pop('conn', None) - cleanup() - - def test_detection_of_existing_table(self): - """ - The Relation instance should be able to detect if the - corresponding table already exists in the database - """ - s = test1.Subjects() - assert_true(s.is_declared) - - def test_definition_referring_to_existing_table_without_class(self): - s1 = test1.Sessions() - assert_true('experimenter_id' in s1.primary_key) - - s2 = test2.Session() - assert_true('experimenter_id' in s2.primary_key) - - def test_reference_to_package_level_table(self): - s = test1.Match() - s.declare() - assert_true('pop_id' in s.primary_key) - - def test_direct_reference_to_existing_table_should_fail(self): - """ - When deriving from Relation, definition should not contain direct reference - to a database name - """ - s = test1.TrainingSession() - with assert_raises(DataJointError): - s.declare() - -@raises(TypeError) -def test_instantiation_of_base_derivative_without_definition_should_fail(): - test1.Empty() - - - - +# """ +# Collection of test cases for base module. 
Tests functionalities such as +# creating tables using docstring table declarations +# """ +# from .schemata import schema1, schema2 +# from .schemata.schema1 import test1, test2, test3 +# +# +# __author__ = 'eywalker' +# +# from . import BASE_CONN, CONN_INFO, PREFIX, cleanup, setup_sample_db +# from datajoint.connection import Connection +# from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, raises +# from datajoint import DataJointError +# +# +# def setup(): +# """ +# Setup connections and bindings +# """ +# pass +# +# +# class TestRelationInstantiations(object): +# """ +# Test cases for instantiating Relation objects +# """ +# def __init__(self): +# self.conn = None +# +# def setup(self): +# """ +# Create a connection object and prepare test modules +# as follows: +# test1 - has conn and bounded +# """ +# self.conn = Connection(**CONN_INFO) +# cleanup() # drop all databases with PREFIX +# #test1.conn = self.conn +# #self.conn.bind(test1.__name__, PREFIX+'_test1') +# +# #test2.conn = self.conn +# +# #test3.__dict__.pop('conn', None) # make sure conn is not defined in test3 +# test1.__dict__.pop('conn', None) +# schema1.__dict__.pop('conn', None) # make sure conn is not defined at schema level +# +# +# def teardown(self): +# cleanup() +# +# +# def test_instantiation_from_unbound_module_should_fail(self): +# """ +# Attempting to instantiate a Relation derivative from a module with +# connection defined but not bound to a database should raise error +# """ +# test1.conn = self.conn +# with assert_raises(DataJointError) as e: +# test1.Subjects() +# assert_regexp_matches(e.exception.args[0], r".*not bound.*") +# +# def test_instantiation_from_module_without_conn_should_fail(self): +# """ +# Attempting to instantiate a Relation derivative from a module that lacks +# `conn` object should raise error +# """ +# with assert_raises(DataJointError) as e: +# test1.Subjects() +# assert_regexp_matches(e.exception.args[0], 
r".*define.*conn.*") +# +# def test_instantiation_of_base_derivatives(self): +# """ +# Test instantiation and initialization of objects derived from +# Relation class +# """ +# test1.conn = self.conn +# self.conn.bind(test1.__name__, PREFIX + '_test1') +# s = test1.Subjects() +# assert_equal(s.dbname, PREFIX + '_test1') +# assert_equal(s.conn, self.conn) +# assert_equal(s.definition, test1.Subjects.definition) +# +# def test_packagelevel_binding(self): +# schema2.conn = self.conn +# self.conn.bind(schema2.__name__, PREFIX + '_test1') +# s = schema2.test1.Subjects() +# +# +# class TestRelationDeclaration(object): +# """ +# Test declaration (creation of table) from +# definition in Relation under various circumstances +# """ +# +# def setup(self): +# cleanup() +# +# self.conn = Connection(**CONN_INFO) +# test1.conn = self.conn +# self.conn.bind(test1.__name__, PREFIX + '_test1') +# test2.conn = self.conn +# self.conn.bind(test2.__name__, PREFIX + '_test2') +# +# def test_is_declared(self): +# """ +# The table should not be created immediately after instantiation, +# but should be created when declare method is called +# :return: +# """ +# s = test1.Subjects() +# assert_false(s.is_declared) +# s.declare() +# assert_true(s.is_declared) +# +# def test_calling_heading_should_trigger_declaration(self): +# s = test1.Subjects() +# assert_false(s.is_declared) +# a = s.heading +# assert_true(s.is_declared) +# +# def test_foreign_key_ref_in_same_schema(self): +# s = test1.Experiments() +# assert_true('subject_id' in s.heading.primary_key) +# +# def test_foreign_key_ref_in_another_schema(self): +# s = test2.Experiments() +# assert_true('subject_id' in s.heading.primary_key) +# +# def test_aliased_module_name_should_resolve(self): +# """ +# Module names that were aliased in the definition should +# be properly resolved. 
+# """ +# s = test2.Conditions() +# assert_true('subject_id' in s.heading.primary_key) +# +# def test_reference_to_unknown_module_in_definition_should_fail(self): +# """ +# Module names in table definition that is not aliased via import +# results in error +# """ +# s = test2.FoodPreference() +# with assert_raises(DataJointError) as e: +# s.declare() +# +# +# class TestRelationWithExistingTables(object): +# """ +# Test base derivatives behaviors when some of the tables +# already exists in the database +# """ +# def setup(self): +# cleanup() +# self.conn = Connection(**CONN_INFO) +# setup_sample_db() +# test1.conn = self.conn +# self.conn.bind(test1.__name__, PREFIX + '_test1') +# test2.conn = self.conn +# self.conn.bind(test2.__name__, PREFIX + '_test2') +# self.conn.load_headings(force=True) +# +# schema2.conn = self.conn +# self.conn.bind(schema2.__name__, PREFIX + '_package') +# +# def teardown(selfself): +# schema1.__dict__.pop('conn', None) +# cleanup() +# +# def test_detection_of_existing_table(self): +# """ +# The Relation instance should be able to detect if the +# corresponding table already exists in the database +# """ +# s = test1.Subjects() +# assert_true(s.is_declared) +# +# def test_definition_referring_to_existing_table_without_class(self): +# s1 = test1.Sessions() +# assert_true('experimenter_id' in s1.primary_key) +# +# s2 = test2.Session() +# assert_true('experimenter_id' in s2.primary_key) +# +# def test_reference_to_package_level_table(self): +# s = test1.Match() +# s.declare() +# assert_true('pop_id' in s.primary_key) +# +# def test_direct_reference_to_existing_table_should_fail(self): +# """ +# When deriving from Relation, definition should not contain direct reference +# to a database name +# """ +# s = test1.TrainingSession() +# with assert_raises(DataJointError): +# s.declare() +# +# @raises(TypeError) +# def test_instantiation_of_base_derivative_without_definition_should_fail(): +# test1.Empty() +# +# +# +# diff --git 
a/tests/test_relation.py b/tests/test_relation.py index ea7ef21d9..fd27876ce 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -1,34 +1,34 @@ +# import random +# import string +# +# __author__ = 'fabee' +# +# from .schemata.schema1 import test1, test4 import random import string - -__author__ = 'fabee' - -from .schemata.schema1 import test1, test4 - +import pymysql +from datajoint import DataJointError +from .schemata.test1 import Subjects, Animals, Matrix, Trials, SquaredScore, SquaredSubtable, WrongImplementation, \ + ErrorGenerator, testschema from . import BASE_CONN, CONN_INFO, PREFIX, cleanup -from datajoint.connection import Connection -from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal,\ +# from datajoint.connection import Connection +from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal, \ assert_tuple_equal, assert_dict_equal, raises -from datajoint import DataJointError, TransactionError, AutoPopulate, Relation +# from datajoint import DataJointError, TransactionError, AutoPopulate, Relation import numpy as np from numpy.testing import assert_array_equal -from datajoint.free_relation import FreeRelation import numpy as np +import datajoint as dj - +# +# def trial_faker(n=10): def iter(): for s in [1, 2]: for i in range(n): - yield dict(trial_id=i, subject_id=s, outcome=int(np.random.randint(10)), notes= 'no comment') - return iter() + yield dict(trial_id=i, subject_id=s, outcome=int(np.random.randint(10)), notes='no comment') - -def setup(): - """ - Setup connections and bindings - """ - pass + return iter() class TestTableObject(object): @@ -46,19 +46,30 @@ def setup(self): as follows: test1 - has conn and bounded """ - cleanup() # drop all databases with PREFIX - test1.__dict__.pop('conn', None) - test4.__dict__.pop('conn', None) # make sure conn is not defined at schema level - - self.conn = 
Connection(**CONN_INFO) - test1.conn = self.conn - test4.conn = self.conn - self.conn.bind(test1.__name__, PREFIX + '_test1') - self.conn.bind(test4.__name__, PREFIX + '_test4') - self.subjects = test1.Subjects() - self.animals = test1.Animals() - self.relvar_blob = test4.Matrix() - self.trials = test1.Trials() + cleanup() # delete everything from all tables of databases with PREFIX + self.subjects = Subjects() + self.animals = Animals() + self.relvar_blob = Matrix() + self.trials = Trials() + self.score = SquaredScore() + self.subtable = SquaredSubtable() + + def test_table_name_manual(self): + assert_true(not self.subjects.table_name.startswith('#') and + not self.subjects.table_name.startswith('_') and not self.subjects.table_name.startswith('__')) + + def test_table_name_computed(self): + assert_true(self.score.table_name.startswith('__')) + + def test_population_relation_subordinate(self): + assert_true(self.subtable.populate_relation is None) + + @raises(NotImplementedError) + def test_make_tubles_not_implemented_subordinate(self): + self.subtable._make_tuples(None) + + def test_instantiate_relation(self): + s = Subjects() def teardown(self): cleanup() @@ -92,7 +103,7 @@ def test_record_insert(self): def test_delete(self): "Test whether delete works" - tmp = np.array([(2, 'Klara', 'monkey'), (1,'Peter', 'mouse')], + tmp = np.array([(2, 'Klara', 'monkey'), (1, 'Peter', 'mouse')], dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) self.subjects.batch_insert(tmp) @@ -100,27 +111,28 @@ def test_delete(self): self.subjects.delete() assert_true(len(self.subjects) == 0, 'Length does not match 0.') - # def test_cascading_delete(self): - # "Test whether delete works" - # tmp = np.array([(2, 'Klara', 'monkey'), (1,'Peter', 'mouse')], - # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) # - # self.subjects.batch_insert(tmp) + # # def test_cascading_delete(self): + # # "Test whether delete works" + # # tmp = np.array([(2, 'Klara', 
'monkey'), (1,'Peter', 'mouse')], + # # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) + # # + # # self.subjects.batch_insert(tmp) + # # + # # self.trials.insert(dict(subject_id=1, trial_id=1, outcome=0)) + # # self.trials.insert(dict(subject_id=1, trial_id=2, outcome=1)) + # # self.trials.insert(dict(subject_id=2, trial_id=3, outcome=2)) + # # assert_true(len(self.subjects) == 2, 'Length does not match 2.') + # # assert_true(len(self.trials) == 3, 'Length does not match 3.') + # # (self.subjects & 'subject_id=1').delete() + # # assert_true(len(self.subjects) == 1, 'Length does not match 1.') + # # assert_true(len(self.trials) == 1, 'Length does not match 1.') + # + # def test_short_hand_foreign_reference(self): + # self.animals.heading + # + # # - # self.trials.insert(dict(subject_id=1, trial_id=1, outcome=0)) - # self.trials.insert(dict(subject_id=1, trial_id=2, outcome=1)) - # self.trials.insert(dict(subject_id=2, trial_id=3, outcome=2)) - # assert_true(len(self.subjects) == 2, 'Length does not match 2.') - # assert_true(len(self.trials) == 3, 'Length does not match 3.') - # (self.subjects & 'subject_id=1').delete() - # assert_true(len(self.subjects) == 1, 'Length does not match 1.') - # assert_true(len(self.trials) == 1, 'Length does not match 1.') - - def test_short_hand_foreign_reference(self): - self.animals.heading - - - def test_record_insert_different_order(self): "Test whether record insert works" tmp = np.array([('Klara', 2, 'monkey')], @@ -131,67 +143,6 @@ def test_record_insert_different_order(self): assert_equal((2, 'Klara', 'monkey'), tuple(testt2), "Inserted and fetched record do not match!") - @raises(TransactionError) - def test_transaction_error(self): - "Test whether declaration in transaction is prohibited" - - tmp = np.array([('Klara', 2, 'monkey')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - self.conn.start_transaction() - self.subjects.insert(tmp[0]) - - # def 
test_transaction_suppress_error(self): - # "Test whether ignore_errors ignores the errors." - # - # tmp = np.array([('Klara', 2, 'monkey')], - # dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - # with self.conn.transaction(ignore_errors=True) as tr: - # self.subjects.insert(tmp[0]) - - - @raises(TransactionError) - def test_transaction_error_not_resolve(self): - "Test whether declaration in transaction is prohibited" - - tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - try: - self.conn.start_transaction() - self.subjects.insert(tmp[0]) - except TransactionError as te: - self.conn.cancel_transaction() - - self.conn.start_transaction() - self.subjects.insert(tmp[0]) - - def test_transaction_error_resolve(self): - "Test whether declaration in transaction is prohibited" - - tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - try: - self.conn.start_transaction() - self.subjects.insert(tmp[0]) - except TransactionError as te: - self.conn.cancel_transaction() - te.resolve() - - self.conn.start_transaction() - self.subjects.insert(tmp[0]) - self.conn.commit_transaction() - - def test_transaction_error2(self): - "If table is declared, we are allowed to insert within a transaction" - - tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - self.subjects.insert(tmp[0]) - - self.conn.start_transaction() - self.subjects.insert(tmp[1]) - self.conn.commit_transaction() - - @raises(KeyError) def test_wrong_key_insert_records(self): "Test whether record insert works" @@ -200,16 +151,15 @@ def test_wrong_key_insert_records(self): self.subjects.insert(tmp[0]) + def test_dict_insert(self): + "Test whether record insert works" + tmp = {'real_id': 'Brunhilda', + 'subject_id': 3, + 'species': 'human'} - def 
test_dict_insert(self): - "Test whether record insert works" - tmp = {'real_id': 'Brunhilda', - 'subject_id': 3, - 'species': 'human'} - - self.subjects.insert(tmp) - testt2 = (self.subjects & 'subject_id = 3').fetch()[0] - assert_equal((3, 'Brunhilda', 'human'), tuple(testt2), "Inserted and fetched record do not match!") + self.subjects.insert(tmp) + testt2 = (self.subjects & 'subject_id = 3').fetch()[0] + assert_equal((3, 'Brunhilda', 'human'), tuple(testt2), "Inserted and fetched record do not match!") @raises(KeyError) def test_wrong_key_insert(self): @@ -220,6 +170,7 @@ def test_wrong_key_insert(self): self.subjects.insert(tmp) + # def test_batch_insert(self): "Test whether record insert works" tmp = np.array([('Klara', 2, 'monkey'), ('Brunhilda', 3, 'mouse'), ('Mickey', 1, 'human')], @@ -232,8 +183,9 @@ def test_batch_insert(self): dtype=[('subject_id', ' `dj_free`.Animals - rec_session_id : int # recording session identifier - """ - table = FreeRelation(self.conn, 'dj_free', 'Recordings', definition) - assert_raises(DataJointError, table.declare) - - def test_reference_to_existing_table(self): - definition1 = """ - `dj_free`.Animals (manual) # my animal table - animal_id : int # unique id for the animal - --- - animal_name : varchar(128) # name of the animal - """ - table1 = FreeRelation(self.conn, 'dj_free', 'Animals', definition1) - table1.declare() - - definition2 = """ - `dj_free`.Recordings (manual) # recordings - -> `dj_free`.Animals - rec_session_id : int # recording session identifier - """ - table2 = FreeRelation(self.conn, 'dj_free', 'Recordings', definition2) - table2.declare() - assert_true('animal_id' in table2.primary_key) - - + assert_array_equal(x, x2, 'inserted blob does not match') + + +# +# class TestUnboundTables(object): +# """ +# Test usages of FreeRelation objects not connected to a module. 
+# """ +# def setup(self): +# cleanup() +# self.conn = Connection(**CONN_INFO) +# +# def test_creation_from_definition(self): +# definition = """ +# `dj_free`.Animals (manual) # my animal table +# animal_id : int # unique id for the animal +# --- +# animal_name : varchar(128) # name of the animal +# """ +# table = FreeRelation(self.conn, 'dj_free', 'Animals', definition) +# table.declare() +# assert_true('animal_id' in table.primary_key) +# +# def test_reference_to_non_existant_table_should_fail(self): +# definition = """ +# `dj_free`.Recordings (manual) # recordings +# -> `dj_free`.Animals +# rec_session_id : int # recording session identifier +# """ +# table = FreeRelation(self.conn, 'dj_free', 'Recordings', definition) +# assert_raises(DataJointError, table.declare) +# +# def test_reference_to_existing_table(self): +# definition1 = """ +# `dj_free`.Animals (manual) # my animal table +# animal_id : int # unique id for the animal +# --- +# animal_name : varchar(128) # name of the animal +# """ +# table1 = FreeRelation(self.conn, 'dj_free', 'Animals', definition1) +# table1.declare() +# +# definition2 = """ +# `dj_free`.Recordings (manual) # recordings +# -> `dj_free`.Animals +# rec_session_id : int # recording session identifier +# """ +# table2 = FreeRelation(self.conn, 'dj_free', 'Recordings', definition2) +# table2.declare() +# assert_true('animal_id' in table2.primary_key) +# +# def id_generator(size=6, chars=string.ascii_uppercase + string.digits): return ''.join(random.choice(chars) for _ in range(size)) + class TestIterator(object): def __init__(self): self.relvar = None @@ -324,27 +280,22 @@ def setup(self): test1 - has conn and bounded """ cleanup() # drop all databases with PREFIX - test4.__dict__.pop('conn', None) # make sure conn is not defined at schema level - - self.conn = Connection(**CONN_INFO) - test4.conn = self.conn - self.conn.bind(test4.__name__, PREFIX + '_test4') - self.relvar_blob = test4.Matrix() + self.relvar_blob = Matrix() def 
teardown(self): cleanup() - + # + # def test_blob_iteration(self): "Tests the basic call of the iterator" dicts = [] for i in range(10): - c = id_generator() - t = {'matrix_id':i, - 'data': np.random.randn(4,4,4), + t = {'matrix_id': i, + 'data': np.random.randn(4, 4, 4), 'comment': c} self.relvar_blob.insert(t) dicts.append(t) @@ -359,11 +310,10 @@ def test_blob_iteration(self): def test_fetch(self): dicts = [] for i in range(10): - c = id_generator() - t = {'matrix_id':i, - 'data': np.random.randn(4,4,4), + t = {'matrix_id': i, + 'data': np.random.randn(4, 4, 4), 'comment': c} self.relvar_blob.insert(t) dicts.append(t) @@ -372,7 +322,6 @@ def test_fetch(self): assert_true(isinstance(tuples2, np.ndarray), "Return value of fetch does not have proper type.") assert_true(isinstance(tuples2[0], np.void), "Return value of fetch does not have proper type.") for t, t2 in zip(dicts, tuples2): - assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved tuples do not match') assert_equal(t['comment'], t2['comment'], 'inserted and retrieved tuples do not match') assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved tuples do not match') @@ -380,11 +329,10 @@ def test_fetch(self): def test_fetch_dicts(self): dicts = [] for i in range(10): - c = id_generator() - t = {'matrix_id':i, - 'data': np.random.randn(4,4,4), + t = {'matrix_id': i, + 'data': np.random.randn(4, 4, 4), 'comment': c} self.relvar_blob.insert(t) dicts.append(t) @@ -398,8 +346,8 @@ def test_fetch_dicts(self): assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved dicts do not match') - -class TestAutopopulate(object): +# +class TestAutopopulate: def __init__(self): self.relvar = None self.setup() @@ -415,29 +363,22 @@ def setup(self): test1 - has conn and bounded """ cleanup() # drop all databases with PREFIX - test1.__dict__.pop('conn', None) # make sure conn is not defined at schema level - - self.conn = Connection(**CONN_INFO) - test1.conn = self.conn - 
self.conn.bind(test1.__name__, PREFIX + '_test1') - - self.subjects = test1.Subjects() - self.trials = test1.Trials() - self.squared = test1.SquaredScore() - self.dummy = test1.SquaredSubtable() - self.dummy1 = test1.WrongImplementation() - self.error_generator = test1.ErrorGenerator() - self.fill_relation() - + self.subjects = Subjects() + self.trials = Trials() + self.squared = SquaredScore() + self.dummy = SquaredSubtable() + self.dummy1 = WrongImplementation() + self.error_generator = ErrorGenerator() + self.fill_relation() def fill_relation(self): tmp = np.array([('Klara', 2, 'monkey'), ('Peter', 3, 'mouse')], dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) self.subjects.batch_insert(tmp) - for trial_id in range(1,11): - self.trials.insert(dict(subject_id=2, trial_id=trial_id, outcome=np.random.randint(0,10))) + for trial_id in range(1, 11): + self.trials.insert(dict(subject_id=2, trial_id=trial_id, outcome=np.random.randint(0, 10))) def teardown(self): cleanup() @@ -446,27 +387,20 @@ def test_autopopulate(self): self.squared.populate() assert_equal(len(self.squared), 10) - for trial in self.trials*self.squared: - assert_equal(trial['outcome']**2, trial['squared']) + for trial in self.trials * self.squared: + assert_equal(trial['outcome'] ** 2, trial['squared']) def test_autopopulate_restriction(self): self.squared.populate(restriction='trial_id <= 5') assert_equal(len(self.squared), 5) - for trial in self.trials*self.squared: - assert_equal(trial['outcome']**2, trial['squared']) - - - # def test_autopopulate_transaction_error(self): - # errors = self.squared.populate(suppress_errors=True) - # assert_equal(len(errors), 1) - # assert_true(isinstance(errors[0][1], TransactionError)) + for trial in self.trials * self.squared: + assert_equal(trial['outcome'] ** 2, trial['squared']) @raises(DataJointError) def test_autopopulate_relation_check(self): - - class dummy(AutoPopulate): - + @testschema + class dummy(dj.Computed): def 
populate_relation(self): return None @@ -474,7 +408,7 @@ def _make_tuples(self, key): pass du = dummy() - du.populate() \ + du.populate() @raises(DataJointError) def test_autopopulate_relation_check(self): @@ -482,7 +416,7 @@ def test_autopopulate_relation_check(self): @raises(Exception) def test_autopopulate_relation_check(self): - self.error_generator.populate()\ + self.error_generator.populate() @raises(Exception) def test_autopopulate_relation_check2(self): diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index 3a25a87d0..3ceeb4964 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -1,47 +1,47 @@ -""" -Collection of test cases to test relational methods -""" - -__author__ = 'eywalker' - - -def setup(): - """ - Setup - :return: - """ - -class TestRelationalAlgebra(object): - - def setup(self): - pass - - def test_mul(self): - pass - - def test_project(self): - pass - - def test_iand(self): - pass - - def test_isub(self): - pass - - def test_sub(self): - pass - - def test_len(self): - pass - - def test_fetch(self): - pass - - def test_repr(self): - pass - - def test_iter(self): - pass - - def test_not(self): - pass \ No newline at end of file +# """ +# Collection of test cases to test relational methods +# """ +# +# __author__ = 'eywalker' +# +# +# def setup(): +# """ +# Setup +# :return: +# """ +# +# class TestRelationalAlgebra(object): +# +# def setup(self): +# pass +# +# def test_mul(self): +# pass +# +# def test_project(self): +# pass +# +# def test_iand(self): +# pass +# +# def test_isub(self): +# pass +# +# def test_sub(self): +# pass +# +# def test_len(self): +# pass +# +# def test_fetch(self): +# pass +# +# def test_repr(self): +# pass +# +# def test_iter(self): +# pass +# +# def test_not(self): +# pass \ No newline at end of file diff --git a/tests/test_settings.py b/tests/test_settings.py index 6b8100806..d10323b10 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ 
-8,7 +8,7 @@ from nose.tools import assert_true, assert_raises, assert_equal, raises, assert_dict_equal import datajoint as dj - +import os def test_load_save(): dj.config.save('tmp.json') @@ -65,3 +65,13 @@ def test_save(): assert_true(os.path.isfile(settings.LOCALCONFIG)) if moved: os.rename(tmpfile, settings.LOCALCONFIG) + +def test_load_save(): + + filename_old = dj.settings.LOCALCONFIG + filename = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(50)) + '.json' + dj.settings.LOCALCONFIG = filename + dj.config.save() + dj.config.load(filename=filename) + dj.settings.LOCALCONFIG = filename_old + os.remove(filename) diff --git a/tests/test_utils.py b/tests/test_utils.py index 655884ce0..6c70150b2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,11 @@ """ Collection of test cases to test core module. """ +from datajoint.user_relations import from_camel_case __author__ = 'eywalker' from nose.tools import assert_true, assert_raises, assert_equal -from datajoint.utils import to_camel_case, from_camel_case +# from datajoint.utils import to_camel_case, from_camel_case from datajoint import DataJointError @@ -16,11 +17,6 @@ def teardown(): pass -def test_to_camel_case(): - assert_equal(to_camel_case('basic_sessions'), 'BasicSessions') - assert_equal(to_camel_case('_another_table'), 'AnotherTable') - - def test_from_camel_case(): assert_equal(from_camel_case('AllGroups'), 'all_groups') with assert_raises(DataJointError):