From 70ca4ff39d0701952601e2a10d97417f08a0c0f1 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 22 May 2015 14:52:46 -0500 Subject: [PATCH 01/42] added Connection.__del__ for closing the connection --- datajoint/connection.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/datajoint/connection.py b/datajoint/connection.py index 572cede34..64bbc9dce 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -90,6 +90,10 @@ def __init__(self, host, user, passwd, init_fun=None): self._graph = DBConnGraph(self) # initialize an empty connection graph self._in_transaction = False + def __del__(self): + logger.info('Disconnecting {user}@{host}:{port}'.format(**self.conn_info)) + self._conn.close() + def __eq__(self, other): return self.conn_info == other.conn_info From 8156593b782e5e3739fcd462b78b47759adc0f77 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sat, 30 May 2015 03:48:22 -0500 Subject: [PATCH 02/42] Add schema decorator factory function and start changing relation --- datajoint/__init__.py | 2 +- datajoint/decorators.py | 29 ++ datajoint/free_relation.py | 506 --------------------------- datajoint/relation.py | 601 ++++++++++++++++++++++++++------ datajoint/relational_operand.py | 16 +- tests/test_relation.py | 2 +- 6 files changed, 528 insertions(+), 628 deletions(-) create mode 100644 datajoint/decorators.py delete mode 100644 datajoint/free_relation.py diff --git a/datajoint/__init__.py b/datajoint/__init__.py index ac832b322..e59a008cc 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -59,5 +59,5 @@ def culprit(self): from .autopopulate import AutoPopulate from . import blob from .relational_operand import Not -from .free_relation import FreeRelation +from .relation import FreeRelation from .heading import Heading \ No newline at end of file diff --git a/datajoint/decorators.py b/datajoint/decorators.py new file mode 100644 index 000000000..9ed650dd7 --- /dev/null +++ b/datajoint/decorators.py @@ -0,0 +1,29 @@ +__author__ = 'eywalker' +from .connection import conn + +def schema(name, context, connection=None): #TODO consider moving this into relation module + """ + Returns a schema decorator that can be used to associate a Relation class to a + specific database with :param name:. Name reference to other tables in the table definition + will be resolved by looking up the corresponding key entry in the passed in context dictionary. + It is most common to set context equal to the return value of call to locals() in the module. + For more details, please refer to the tutorial online. + + :param name: name of the database to associate the decorated class with + :param context: dictionary used to resolve (any) name references within the table definition string + :param connection: connection object to the database server. If ommited, will try to establish connection according to + config values + :return: a decorator function to be used on Relation derivative classes + """ + if connection is None: + connection = conn() + + def _dec(cls): + cls._schema_name = name + cls._context = context + cls._connection = connection + return cls + + return _dec + + diff --git a/datajoint/free_relation.py b/datajoint/free_relation.py deleted file mode 100644 index 11ff6fe0c..000000000 --- a/datajoint/free_relation.py +++ /dev/null @@ -1,506 +0,0 @@ -from _collections_abc import MutableMapping, Mapping -import numpy as np -import logging -from . import DataJointError, config, TransactionError -from .relational_operand import RelationalOperand -from .blob import pack -from .heading import Heading -import re -from .settings import Role, role_to_prefix -from .utils import from_camel_case, user_choice - -logger = logging.getLogger(__name__) - - -class FreeRelation(RelationalOperand): - """ - A FreeRelation object is a relation associated with a table. - A FreeRelation object provides insert and delete methods. - FreeRelation objects are only used internally and for debugging. - The table must already exist in the schema for its FreeRelation object to work. - - The table associated with an instance of Relation is identified by its 'class name'. - property, which is a string in CamelCase. The actual table name is obtained - by converting className from CamelCase to underscore_separated_words and - prefixing according to the table's role. - - Relation instances obtain their table's heading by looking it up in the connection - object. This ensures that Relation instances contain the current table definition - even after tables are modified after the instance is created. - """ - - def __init__(self, conn, dbname, class_name=None, definition=None): - self.class_name = class_name - self._conn = conn - self.dbname = dbname - self._definition = definition - - if dbname not in self.conn.db_to_mod: - # register with a fake module, enclosed in back quotes - # necessary for loading mechanism - self.conn.bind('`{0}`'.format(dbname), dbname) - super().__init__(conn) - - @property - def from_clause(self): - return self.full_table_name - - @property - def heading(self): - self.declare() - return self.conn.headings[self.dbname][self.table_name] - - @property - def definition(self): - return self._definition - - @property - def is_declared(self): - self.conn.load_headings(self.dbname) - return self.class_name in self.conn.table_names[self.dbname] - - def declare(self): - """ - Declare the table in database if it doesn't already exist. - - :raises: DataJointError if the table cannot be declared. - """ - if not self.is_declared: - self._declare() - if not self.is_declared: - raise DataJointError( - 'FreeRelation could not be declared for %s' % self.class_name) - - @staticmethod - def _field_to_sql(field): # TODO move this into Attribute Tuple - """ - Converts an attribute definition tuple into SQL code. - :param field: attribute definition - :rtype : SQL code - """ - mysql_constants = ['CURRENT_TIMESTAMP'] - if field.nullable: - default = 'DEFAULT NULL' - else: - default = 'NOT NULL' - # if some default specified - if field.default: - # enclose value in quotes except special SQL values or already enclosed - quote = field.default.upper() not in mysql_constants and field.default[0] not in '"\'' - default += ' DEFAULT ' + ('"%s"' if quote else "%s") % field.default - if any((c in r'\"' for c in field.comment)): - raise DataJointError('Illegal characters in attribute comment "%s"' % field.comment) - - return '`{name}` {type} {default} COMMENT "{comment}",\n'.format( - name=field.name, type=field.type, default=default, comment=field.comment) - - @property - def full_table_name(self): - """ - :return: full name of the associated table - """ - return '`%s`.`%s`' % (self.dbname, self.table_name) - - @property - def table_name(self): - """ - :return: name of the associated table - """ - return self.conn.table_names[self.dbname][self.class_name] if self.is_declared else None - - @property - def primary_key(self): - """ - :return: primary key of the table - """ - return self.heading.primary_key - - def iter_insert(self, iter, **kwargs): - """ - Inserts an entire batch of entries. Additional keyword arguments are passed to insert. - - :param iter: Must be an iterator that generates a sequence of valid arguments for insert. - """ - for row in iter: - self.insert(row, **kwargs) - - def batch_insert(self, data, **kwargs): - """ - Inserts an entire batch of entries. Additional keyword arguments are passed to insert. - - :param data: must be iterable, each row must be a valid argument for insert - """ - self.iter_insert(data.__iter__(), **kwargs) - - def insert(self, tup, ignore_errors=False, replace=False): - """ - Insert one data record or one Mapping (like a dictionary). - - :param tup: Data record, or a Mapping (like a dictionary). - :param ignore_errors=False: Ignores errors if True. - :param replace=False: Replaces data tuple if True. - - Example:: - - b = djtest.Subject() - b.insert(dict(subject_id = 7, species="mouse",\\ - real_id = 1007, date_of_birth = "2014-09-01")) - """ - - heading = self.heading - if isinstance(tup, np.void): - for fieldname in tup.dtype.fields: - if fieldname not in heading: - raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname, )) - value_list = ','.join([repr(tup[name]) if not heading[name].is_blob else '%s' - for name in heading if name in tup.dtype.fields]) - - args = tuple(pack(tup[name]) for name in heading - if name in tup.dtype.fields and heading[name].is_blob) - attribute_list = '`' + '`,`'.join( - [q for q in heading if q in tup.dtype.fields]) + '`' - elif isinstance(tup, Mapping): - for fieldname in tup.keys(): - if fieldname not in heading: - raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname, )) - value_list = ','.join([repr(tup[name]) if not heading[name].is_blob else '%s' - for name in heading if name in tup]) - args = tuple(pack(tup[name]) for name in heading - if name in tup and heading[name].is_blob) - attribute_list = '`' + '`,`'.join( - [name for name in heading if name in tup]) + '`' - else: - raise DataJointError('Datatype %s cannot be inserted' % type(tup)) - if replace: - sql = 'REPLACE' - elif ignore_errors: - sql = 'INSERT IGNORE' - else: - sql = 'INSERT' - sql += " INTO %s (%s) VALUES (%s)" % (self.full_table_name, - attribute_list, value_list) - logger.info(sql) - self.conn.query(sql, args=args) - - def delete(self): - if not config['safemode'] or user_choice( - "You are about to delete data from a table. This operation cannot be undone.\n" - "Proceed?", default='no') == 'yes': - self.conn.query('DELETE FROM ' + self.from_clause + self.where_clause) # TODO: make cascading (issue #15) - - def drop(self): - """ - Drops the table associated to this object. - """ - if self.is_declared: - if not config['safemode'] or user_choice( - "You are about to drop an entire table. This operation cannot be undone.\n" - "Proceed?", default='no') == 'yes': - self.conn.query('DROP TABLE %s' % self.full_table_name) # TODO: make cascading (issue #16) - self.conn.clear_dependencies(dbname=self.dbname) - self.conn.load_headings(dbname=self.dbname, force=True) - logger.info("Dropped table %s" % self.full_table_name) - - @property - def size_on_disk(self): - """ - :return: size of data and indices in MiB taken by the table on the storage device - """ - cur = self.conn.query( - 'SHOW TABLE STATUS FROM `{dbname}` WHERE NAME="{table}"'.format( - dbname=self.dbname, table=self.table_name), as_dict=True) - ret = cur.fetchone() - return (ret['Data_length'] + ret['Index_length'])/1024**2 - - def set_table_comment(self, comment): - """ - Update the table comment in the table definition. - :param comment: new comment as string - """ - # TODO: add verification procedure (github issue #24) - self.alter('COMMENT="%s"' % comment) - - def add_attribute(self, definition, after=None): - """ - Add a new attribute to the table. A full line from the table definition - is passed in as definition. - - The definition can specify where to place the new attribute. Use after=None - to add the attribute as the first attribute or after='attribute' to place it - after an existing attribute. - - :param definition: table definition - :param after=None: After which attribute of the table the new attribute is inserted. - If None, the attribute is inserted in front. - """ - position = ' FIRST' if after is None else ( - ' AFTER %s' % after if after else '') - sql = self.field_to_sql(parse_attribute_definition(definition)) - self._alter('ADD COLUMN %s%s' % (sql[:-2], position)) - - def drop_attribute(self, attr_name): - """ - Drops the attribute attrName from this table. - - :param attr_name: Name of the attribute that is dropped. - """ - if not config['safemode'] or user_choice( - "You are about to drop an attribute from a table." - "This operation cannot be undone.\n" - "Proceed?", default='no') == 'yes': - self._alter('DROP COLUMN `%s`' % attr_name) - - def alter_attribute(self, attr_name, new_definition): - """ - Alter the definition of the field attr_name in this table using the new definition. - - :param attr_name: field that is redefined - :param new_definition: new definition of the field - """ - sql = self.field_to_sql(parse_attribute_definition(new_definition)) - self._alter('CHANGE COLUMN `%s` %s' % (attr_name, sql[:-2])) - - def erd(self, subset=None): - """ - Plot the schema's entity relationship diagram (ERD). - """ - - def _alter(self, alter_statement): - """ - Execute ALTER TABLE statement for this table. The schema - will be reloaded within the connection object. - - :param alter_statement: alter statement - """ - if self._conn.in_transaction: - raise TransactionError( - u"_alter is currently in transaction. Operation not allowed to avoid implicit commits.", - self._alter, args=(alter_statement,)) - - sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement) - self.conn.query(sql) - self.conn.load_headings(self.dbname, force=True) - # TODO: place table definition sync mechanism - - @staticmethod - def _parse_index_def(line): - """ - Parses index definition. - - :param line: definition line - :return: groupdict with index info - """ - line = line.strip() - index_regexp = re.compile(""" - ^(?PUNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX - \((?P[^\)]+)\)$ # (attr1, attr2) - """, re.I + re.X) - m = index_regexp.match(line) - assert m, 'Invalid index declaration "%s"' % line - index_info = m.groupdict() - attributes = re.split(r'\s*,\s*', index_info['attributes'].strip()) - index_info['attributes'] = attributes - assert len(attributes) == len(set(attributes)), \ - 'Duplicate attributes in index declaration "%s"' % line - return index_info - - def get_base(self, module_name, class_name): - if not module_name: - module_name = r'`{dbname}`'.format(self.dbname) - m = re.match(r'`(\w+)`', module_name) - return FreeRelation(self.conn, m.group(1), class_name) if m else None - - @property - def ref_name(self): - """ - :return: the name to refer to this class, taking form module.class or `database`.class - """ - return '`{0}`'.format(self.dbname) + '.' + self.class_name - - def _declare(self): - """ - Declares the table in the database if no table in the database matches this object. - """ - if self._conn.in_transaction: - raise TransactionError( - u"_alter is currently in transaction. Operation not allowed to avoid implicit commits.", self._declare) - - if not self.definition: - raise DataJointError('Table definition is missing!') - table_info, parents, referenced, field_defs, index_defs = self._parse_declaration() - defined_name = table_info['module'] + '.' + table_info['className'] - - if not defined_name == self.ref_name: - raise DataJointError('Table name {} does not match the declared' - 'name {}'.format(self.ref_name, defined_name)) - - # compile the CREATE TABLE statement - table_name = role_to_prefix[table_info['tier']] + from_camel_case(self.class_name) - sql = 'CREATE TABLE `%s`.`%s` (\n' % (self.dbname, table_name) - - # add inherited primary key fields - primary_key_fields = set() - non_key_fields = set() - for p in parents: - for key in p.primary_key: - field = p.heading[key] - if field.name not in primary_key_fields: - primary_key_fields.add(field.name) - sql += self._field_to_sql(field) - else: - logger.debug('Field definition of {} in {} ignored'.format( - field.name, p.full_class_name)) - - # add newly defined primary key fields - for field in (f for f in field_defs if f.in_key): - if field.nullable: - raise DataJointError('Primary key attribute {} cannot be nullable'.format( - field.name)) - if field.name in primary_key_fields: - raise DataJointError('Duplicate declaration of the primary attribute {key}. ' - 'Ensure that the attribute is not already declared ' - 'in referenced tables'.format(key=field.name)) - primary_key_fields.add(field.name) - sql += self._field_to_sql(field) - - # add secondary foreign key attributes - for r in referenced: - for key in r.primary_key: - field = r.heading[key] - if field.name not in primary_key_fields | non_key_fields: - non_key_fields.add(field.name) - sql += self._field_to_sql(field) - - # add dependent attributes - for field in (f for f in field_defs if not f.in_key): - non_key_fields.add(field.name) - sql += self._field_to_sql(field) - - # add primary key declaration - assert len(primary_key_fields) > 0, 'table must have a primary key' - keys = ', '.join(primary_key_fields) - sql += 'PRIMARY KEY (%s),\n' % keys - - # add foreign key declarations - for ref in parents + referenced: - keys = ', '.join(ref.primary_key) - sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \ - (keys, ref.full_table_name, keys) - - # add secondary index declarations - # gather implicit indexes due to foreign keys first - implicit_indices = [] - for fk_source in parents + referenced: - implicit_indices.append(fk_source.primary_key) - - # for index in indexDefs: - # TODO: finish this up... - - # close the declaration - sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % ( - sql[:-2], table_info['comment']) - - # make sure that the table does not alredy exist - self.conn.load_headings(self.dbname, force=True) - if not self.is_declared: - # execute declaration - logger.debug('\n\n' + sql + '\n\n') - self.conn.query(sql) - self.conn.load_headings(self.dbname, force=True) - - def _parse_declaration(self): - """ - Parse declaration and create new SQL table accordingly. - """ - parents = [] - referenced = [] - index_defs = [] - field_defs = [] - declaration = re.split(r'\s*\n\s*', self.definition.strip()) - - # remove comment lines - declaration = [x for x in declaration if not x.startswith('#')] - ptrn = """ - ^(?P[\w\`]+)\.(?P\w+)\s* # module.className - \(\s*(?P\w+)\s*\)\s* # (tier) - \#\s*(?P.*)$ # comment - """ - p = re.compile(ptrn, re.X) - table_info = p.match(declaration[0]).groupdict() - if table_info['tier'] not in Role.__members__: - raise DataJointError('InvalidTableTier: Invalid tier {tier} for table\ - {module}.{cls}'.format(tier=table_info['tier'], - module=table_info[ - 'module'], - cls=table_info['className'])) - table_info['tier'] = Role[table_info['tier']] # convert into enum - - in_key = True # parse primary keys - attribute_regexp = re.compile(""" - ^[a-z][a-z\d_]*\s* # name - (=\s*\S+(\s+\S+)*\s*)? # optional defaults - :\s*\w.*$ # type, comment - """, re.I + re.X) # ignore case and verbose - - for line in declaration[1:]: - if line.startswith('---'): - in_key = False # start parsing non-PK fields - elif line.startswith('->'): - # foreign key - if '.' in line[2:]: - module_name, class_name = line[2:].strip().split('.') - else: - # assume it's a shorthand - module_name = '' - class_name = line[2:].strip() - ref = parents if in_key else referenced - ref.append(self.get_base(module_name, class_name)) - elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): - index_defs.append(self._parse_index_def(line)) - elif attribute_regexp.match(line): - field_defs.append(parse_attribute_definition(line, in_key)) - else: - raise DataJointError( - 'Invalid table declaration line "%s"' % line) - - return table_info, parents, referenced, field_defs, index_defs - - -def parse_attribute_definition(line, in_key=False): - """ - Parse attribute definition line in the declaration and returns - an attribute tuple. - - :param line: attribution line - :param in_key: set to True if attribute is in primary key set - :returns: attribute tuple - """ - line = line.strip() - attribute_regexp = re.compile(""" - ^(?P[a-z][a-z\d_]*)\s* # field name - (=\s*(?P\S+(\s+\S+)*?)\s*)? # default value - :\s*(?P\w[^\#]*[^\#\s])\s* # datatype - (\#\s*(?P\S*(\s+\S+)*)\s*)?$ # comment - """, re.X) - m = attribute_regexp.match(line) - if not m: - raise DataJointError('Invalid field declaration "%s"' % line) - attr_info = m.groupdict() - if not attr_info['comment']: - attr_info['comment'] = '' - if not attr_info['default']: - attr_info['default'] = '' - attr_info['nullable'] = attr_info['default'].lower() == 'null' - assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \ - 'BIGINT attributes cannot be nullable in "%s"' % line - - return Heading.AttrTuple( - in_key=in_key, - autoincrement=None, - numeric=None, - string=None, - is_blob=None, - computation=None, - dtype=None, - **attr_info - ) diff --git a/datajoint/relation.py b/datajoint/relation.py index 7112087ca..6f89fe486 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -1,144 +1,521 @@ -import importlib -import abc -from types import ModuleType -from . import DataJointError -from .free_relation import FreeRelation +from _collections_abc import MutableMapping, Mapping +import numpy as np import logging - +from . import DataJointError, config, TransactionError +from .relational_operand import RelationalOperand +from .blob import pack +from .heading import Heading +import re +from .settings import Role, role_to_prefix +from .utils import from_camel_case, user_choice +import abc logger = logging.getLogger(__name__) -class Relation(FreeRelation, metaclass=abc.ABCMeta): +class Relation(RelationalOperand, metaclass=abc.ABCMeta): """ - Relation is a Table that implements data definition functions. - It is an abstract class with the abstract property 'definition'. + A FreeRelation object is a relation associated with a table. + A FreeRelation object provides insert and delete methods. + FreeRelation objects are only used internally and for debugging. + The table must already exist in the schema for its FreeRelation object to work. - Example for a usage of Relation:: + The table associated with an instance of Relation is identified by its 'class name'. + property, which is a string in CamelCase. The actual table name is obtained + by converting className from CamelCase to underscore_separated_words and + prefixing according to the table's role. - import datajoint as dj + Relation instances obtain their table's heading by looking it up in the connection + object. This ensures that Relation instances contain the current table definition + even after tables are modified after the instance is created. + """ + _connection = None # connection information + _schema_name = None # name of schema this relation belongs to + _heading = None # heading information for this relation + _context = {} # name reference lookup context - class Subjects(dj.Relation): - definition = ''' - test1.Subjects (manual) # Basic subject info - subject_id : int # unique subject id - --- - real_id : varchar(40) # real-world name - species = "mouse" : enum('mouse', 'monkey', 'human') # species - ''' + # defines class properties + + @property + def connection(self): + return self.__class__._connection + + @property + def schema_name(self): + return self.__class__._schema_name + + @property + def heading(self): + return self.__class__._heading + + @heading.setter + def heading(self, new_heading): + self.__class__._heading = new_heading + + @property + def context(self): + return self.__class__._context + + # object properties + + @property + def table_prefix(self): + return '' + + @property + def table_name(self): + """ + :return: name of the table. This is equal to table_prefix + class name with underscores + """ + return self.table_prefix + from_camel_case(self.__class__.__name__) + + @property + def full_table_name(self): + """ + :return: full name of the associated table + """ + return '`%s`.`%s`' % (self.schema_name, self.table_name) + + @property + def from_clause(self): + return self.full_table_name - """ @abc.abstractproperty def definition(self): - """ - :return: string containing the table declaration using the DataJoint Data Definition Language. + pass - The DataJoint DDL is described at: http://datajoint.github.com + @classmethod + def load_heading(cls): + """ + Load the heading information for this table. If the table does not exist in the database server, Heading will be + set to None if the table is not yet defined in the database. """ pass @property - def full_class_name(self): + def is_declared(self): + if self.heading is None: + self.load_heading() + return self.heading is not None + + + def __init__(self): + self.load_heading() + + def declare(self): + """ + Declare the table in database if it doesn't already exist. + + :raises: DataJointError if the table cannot be declared. + """ + if not self.is_declared: + self._declare() + # verify that declaration completed successfully + if not self.is_declared: + raise DataJointError( + 'FreeRelation could not be declared for %s' % self.class_name) + + @staticmethod + def _field_to_sql(field): # TODO move this into Attribute Tuple """ - :return: full class name including the entire package hierarchy + Converts an attribute definition tuple into SQL code. + :param field: attribute definition + :rtype : SQL code """ - return '{}.{}'.format(self.__module__, self.class_name) + mysql_constants = ['CURRENT_TIMESTAMP'] + if field.nullable: + default = 'DEFAULT NULL' + else: + default = 'NOT NULL' + # if some default specified + if field.default: + # enclose value in quotes except special SQL values or already enclosed + quote = field.default.upper() not in mysql_constants and field.default[0] not in '"\'' + default += ' DEFAULT ' + ('"%s"' if quote else "%s") % field.default + if any((c in r'\"' for c in field.comment)): + raise DataJointError('Illegal characters in attribute comment "%s"' % field.comment) + + return '`{name}` {type} {default} COMMENT "{comment}",\n'.format( + name=field.name, type=field.type, default=default, comment=field.comment) + + @property - def ref_name(self): + def primary_key(self): """ - :return: name by which this class should be accessible + :return: primary key of the table """ - parent = self.__module__.split('.')[-2 if self._use_package else -1] - return parent + '.' + self.class_name + return self.heading.primary_key - def __init__(self): #TODO: support taking in conn obj - class_name = self.__class__.__name__ - module_name = self.__module__ - mod_obj = importlib.import_module(module_name) - self._use_package = False - # first, find the conn object - try: - conn = mod_obj.conn - except AttributeError: - try: - # check if database bound at the package level instead - pkg_obj = importlib.import_module(mod_obj.__package__) - conn = pkg_obj.conn - self._use_package = True - except AttributeError: - raise DataJointError( - "Please define object 'conn' in '{}' or in its containing package.".format(module_name)) - # now use the conn object to determine the dbname this belongs to - try: - if self._use_package: - # the database is bound to the package - pkg_name = '.'.join(module_name.split('.')[:-1]) - dbname = conn.mod_to_db[pkg_name] - else: - dbname = conn.mod_to_db[module_name] - except KeyError: - raise DataJointError( - 'Module {} is not bound to a database. See datajoint.connection.bind'.format(module_name)) - # initialize using super class's constructor - super().__init__(conn, dbname, class_name) + def iter_insert(self, iter, **kwargs): + """ + Inserts an entire batch of entries. Additional keyword arguments are passed to insert. - def get_base(self, module_name, class_name): + :param iter: Must be an iterator that generates a sequence of valid arguments for insert. """ - Loads the base relation from the module. If the base relation is not defined in - the module, then construct it using Relation constructor. + for row in iter: + self.insert(row, **kwargs) - :param module_name: module name - :param class_name: class name - :returns: the base relation + def batch_insert(self, data, **kwargs): """ - if not module_name: - module_name = self.__module__.split('.')[-1] + Inserts an entire batch of entries. Additional keyword arguments are passed to insert. - mod_obj = self.get_module(module_name) - if not mod_obj: - raise DataJointError('Module named {mod_name} was not found. Please make' - ' sure that it is in the path or you import the module.'.format(mod_name=module_name)) - try: - ret = getattr(mod_obj, class_name)() - except AttributeError: - ret = FreeRelation(conn=self.conn, - dbname=self.conn.mod_to_db[mod_obj.__name__], - class_name=class_name) - return ret + :param data: must be iterable, each row must be a valid argument for insert + """ + self.iter_insert(data.__iter__(), **kwargs) - @classmethod - def get_module(cls, module_name): - """ - Resolve short name reference to a module and return the corresponding module object - - :param module_name: short name for the module, whose reference is to be resolved - :return: resolved module object. If no module matches the short name, `None` will be returned - - The module_name resolution steps in the following order: - - 1. Global reference to a module of the same name defined in the module that contains this Relation derivative. - This is the recommended use case. - 2. Module of the same name defined in the package containing this Relation derivative. This will only look for the - most immediate containing package (e.g. if this class is contained in package.subpackage.module, it will - check within `package.subpackage` but not inside `package`). - 3. Globally accessible module with the same name. - """ - # from IPython import embed - # embed() - mod_obj = importlib.import_module(cls.__module__) - if cls.__module__.split('.')[-1] == module_name: - return mod_obj - attr = getattr(mod_obj, module_name, None) - if isinstance(attr, ModuleType): - return attr - if mod_obj.__package__: - try: - return importlib.import_module('.' + module_name, mod_obj.__package__) - except ImportError: - pass + def insert(self, tup, ignore_errors=False, replace=False): + """ + Insert one data record or one Mapping (like a dictionary). + + :param tup: Data record, or a Mapping (like a dictionary). + :param ignore_errors=False: Ignores errors if True. + :param replace=False: Replaces data tuple if True. + + Example:: + + b = djtest.Subject() + b.insert(dict(subject_id = 7, species="mouse",\\ + real_id = 1007, date_of_birth = "2014-09-01")) + """ + + heading = self.heading + if isinstance(tup, np.void): + for fieldname in tup.dtype.fields: + if fieldname not in heading: + raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname, )) + value_list = ','.join([repr(tup[name]) if not heading[name].is_blob else '%s' + for name in heading if name in tup.dtype.fields]) + + args = tuple(pack(tup[name]) for name in heading + if name in tup.dtype.fields and heading[name].is_blob) + attribute_list = '`' + '`,`'.join( + [q for q in heading if q in tup.dtype.fields]) + '`' + elif isinstance(tup, Mapping): + for fieldname in tup.keys(): + if fieldname not in heading: + raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname, )) + value_list = ','.join([repr(tup[name]) if not heading[name].is_blob else '%s' + for name in heading if name in tup]) + args = tuple(pack(tup[name]) for name in heading + if name in tup and heading[name].is_blob) + attribute_list = '`' + '`,`'.join( + [name for name in heading if name in tup]) + '`' + else: + raise DataJointError('Datatype %s cannot be inserted' % type(tup)) + if replace: + sql = 'REPLACE' + elif ignore_errors: + sql = 'INSERT IGNORE' + else: + sql = 'INSERT' + sql += " INTO %s (%s) VALUES (%s)" % (self.full_table_name, + attribute_list, value_list) + logger.info(sql) + self.connection.query(sql, args=args) + + def delete(self): + if not config['safemode'] or user_choice( + "You are about to delete data from a table. This operation cannot be undone.\n" + "Proceed?", default='no') == 'yes': + self.connection.query('DELETE FROM ' + self.from_clause + self.where_clause) # TODO: make cascading (issue #15) + + def drop(self): + """ + Drops the table associated to this object. + """ + if self.is_declared: + if not config['safemode'] or user_choice( + "You are about to drop an entire table. This operation cannot be undone.\n" + "Proceed?", default='no') == 'yes': + self.connection.query('DROP TABLE %s' % self.full_table_name) # TODO: make cascading (issue #16) + self.connection.clear_dependencies(dbname=self.dbname) + self.connection.load_headings(dbname=self.dbname, force=True) + logger.info("Dropped table %s" % self.full_table_name) + + @property + def size_on_disk(self): + """ + :return: size of data and indices in MiB taken by the table on the storage device + """ + cur = self.connection.query( + 'SHOW TABLE STATUS FROM `{dbname}` WHERE NAME="{table}"'.format( + dbname=self.dbname, table=self.table_name), as_dict=True) + ret = cur.fetchone() + return (ret['Data_length'] + ret['Index_length'])/1024**2 + + def set_table_comment(self, comment): + """ + Update the table comment in the table definition. + :param comment: new comment as string + """ + # TODO: add verification procedure (github issue #24) + self.alter('COMMENT="%s"' % comment) + + def add_attribute(self, definition, after=None): + """ + Add a new attribute to the table. A full line from the table definition + is passed in as definition. + + The definition can specify where to place the new attribute. Use after=None + to add the attribute as the first attribute or after='attribute' to place it + after an existing attribute. + + :param definition: table definition + :param after=None: After which attribute of the table the new attribute is inserted. + If None, the attribute is inserted in front. + """ + position = ' FIRST' if after is None else ( + ' AFTER %s' % after if after else '') + sql = self.field_to_sql(parse_attribute_definition(definition)) + self._alter('ADD COLUMN %s%s' % (sql[:-2], position)) + + def drop_attribute(self, attr_name): + """ + Drops the attribute attrName from this table. + + :param attr_name: Name of the attribute that is dropped. + """ + if not config['safemode'] or user_choice( + "You are about to drop an attribute from a table." + "This operation cannot be undone.\n" + "Proceed?", default='no') == 'yes': + self._alter('DROP COLUMN `%s`' % attr_name) + + def alter_attribute(self, attr_name, new_definition): + """ + Alter the definition of the field attr_name in this table using the new definition. + + :param attr_name: field that is redefined + :param new_definition: new definition of the field + """ + sql = self.field_to_sql(parse_attribute_definition(new_definition)) + self._alter('CHANGE COLUMN `%s` %s' % (attr_name, sql[:-2])) + + def erd(self, subset=None): + """ + Plot the schema's entity relationship diagram (ERD). + """ + + def _alter(self, alter_statement): + """ + Execute ALTER TABLE statement for this table. The schema + will be reloaded within the connection object. + + :param alter_statement: alter statement + """ + if self._conn.in_transaction: + raise TransactionError( + u"_alter is currently in transaction. Operation not allowed to avoid implicit commits.", + self._alter, args=(alter_statement,)) + + sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement) + self.connection.query(sql) + self.connection.load_headings(self.dbname, force=True) + # TODO: place table definition sync mechanism + + @staticmethod + def _parse_index_def(line): + """ + Parses index definition. + + :param line: definition line + :return: groupdict with index info + """ + line = line.strip() + index_regexp = re.compile(""" + ^(?PUNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX + \((?P[^\)]+)\)$ # (attr1, attr2) + """, re.I + re.X) + m = index_regexp.match(line) + assert m, 'Invalid index declaration "%s"' % line + index_info = m.groupdict() + attributes = re.split(r'\s*,\s*', index_info['attributes'].strip()) + index_info['attributes'] = attributes + assert len(attributes) == len(set(attributes)), \ + 'Duplicate attributes in index declaration "%s"' % line + return index_info + + + def _declare(self): + """ + Declares the table in the database if no table in the database matches this object. + """ + if self.connection.in_transaction: + raise TransactionError( + u"_declare is currently in transaction. Operation not allowed to avoid implicit commits.", self._declare) + + if not self.definition: # if empty definition was supplied + raise DataJointError('Table definition is missing!') + table_info, parents, referenced, field_defs, index_defs = self._parse_declaration() + + sql = 'CREATE TABLE %s (\n' % self.full_table_name + + # add inherited primary key fields + primary_key_fields = set() + non_key_fields = set() + for p in parents: + for key in p.primary_key: + field = p.heading[key] + if field.name not in primary_key_fields: + primary_key_fields.add(field.name) + sql += self._field_to_sql(field) + else: + logger.debug('Field definition of {} in {} ignored'.format( + field.name, p.full_class_name)) + + # add newly defined primary key fields + for field in (f for f in field_defs if f.in_key): + if field.nullable: + raise DataJointError('Primary key attribute {} cannot be nullable'.format( + field.name)) + if field.name in primary_key_fields: + raise DataJointError('Duplicate declaration of the primary attribute {key}. ' + 'Ensure that the attribute is not already declared ' + 'in referenced tables'.format(key=field.name)) + primary_key_fields.add(field.name) + sql += self._field_to_sql(field) + + # add secondary foreign key attributes + for r in referenced: + for key in r.primary_key: + field = r.heading[key] + if field.name not in primary_key_fields | non_key_fields: + non_key_fields.add(field.name) + sql += self._field_to_sql(field) + + # add dependent attributes + for field in (f for f in field_defs if not f.in_key): + non_key_fields.add(field.name) + sql += self._field_to_sql(field) + + # add primary key declaration + assert len(primary_key_fields) > 0, 'table must have a primary key' + keys = ', '.join(primary_key_fields) + sql += 'PRIMARY KEY (%s),\n' % keys + + # add foreign key declarations + for ref in parents + referenced: + keys = ', '.join(ref.primary_key) + sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \ + (keys, ref.full_table_name, keys) + + # add secondary index declarations + # gather implicit indexes due to foreign keys first + implicit_indices = [] + for fk_source in parents + referenced: + implicit_indices.append(fk_source.primary_key) + + # for index in indexDefs: + # TODO: finish this up... + + # close the declaration + sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % ( + sql[:-2], table_info['comment']) + + # make sure that the table does not alredy exist + self.load_heading() + if not self.is_declared: + # execute declaration + logger.debug('\n\n' + sql + '\n\n') + self.connection.query(sql) + self.load_heading() + + def _parse_declaration(self): + """ + Parse declaration and create new SQL table accordingly. + """ + parents = [] + referenced = [] + index_defs = [] + field_defs = [] + declaration = re.split(r'\s*\n\s*', self.definition.strip()) + + # remove comment lines + declaration = [x for x in declaration if not x.startswith('#')] + ptrn = """ + \#\s*(?P.*)$ # comment + """ + p = re.compile(ptrn, re.X) + table_info = p.search(declaration[0]).groupdict() + + #table_info['tier'] = Role[table_info['tier']] # convert into enum + + in_key = True # parse primary keys + attribute_regexp = re.compile(""" + ^[a-z][a-z\d_]*\s* # name + (=\s*\S+(\s+\S+)*\s*)? # optional defaults + :\s*\w.*$ # type, comment + """, re.I + re.X) # ignore case and verbose + + for line in declaration[1:]: + if line.startswith('---'): + in_key = False # start parsing non-PK fields + elif line.startswith('->'): + # foreign key + ref_name = line[2:].strip() + ref_list = parents if in_key else referenced + ref_list.append(self.lookup_name(ref_name)) + elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): + index_defs.append(self._parse_index_def(line)) + elif attribute_regexp.match(line): + field_defs.append(parse_attribute_definition(line, in_key)) + else: + raise DataJointError( + 'Invalid table declaration line "%s"' % line) + + return table_info, parents, referenced, field_defs, index_defs + + + def lookup_name(self, name): + parts = name.strip().split('.') try: - return importlib.import_module(module_name) - except ImportError: - return None + ref = self.context.get(parts[0]) + for attr in parts[1:]: + ref = getattr(ref, attr) + except (KeyError, AttributeError): + raise DataJointError('Foreign reference %s could not be resolved. Please make sure the name exists' + 'in the context of the class' % name) + return ref + + + +def parse_attribute_definition(line, in_key=False): + """ + Parse attribute definition line in the declaration and returns + an attribute tuple. + + :param line: attribution line + :param in_key: set to True if attribute is in primary key set + :returns: attribute tuple + """ + line = line.strip() + attribute_regexp = re.compile(""" + ^(?P[a-z][a-z\d_]*)\s* # field name + (=\s*(?P\S+(\s+\S+)*?)\s*)? # default value + :\s*(?P\w[^\#]*[^\#\s])\s* # datatype + (\#\s*(?P\S*(\s+\S+)*)\s*)?$ # comment + """, re.X) + m = attribute_regexp.match(line) + if not m: + raise DataJointError('Invalid field declaration "%s"' % line) + attr_info = m.groupdict() + if not attr_info['comment']: + attr_info['comment'] = '' + if not attr_info['default']: + attr_info['default'] = '' + attr_info['nullable'] = attr_info['default'].lower() == 'null' + assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \ + 'BIGINT attributes cannot be nullable in "%s"' % line + + return Heading.AttrTuple( + in_key=in_key, + autoincrement=None, + numeric=None, + string=None, + is_blob=None, + computation=None, + dtype=None, + **attr_info + ) diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index 51f81c43a..2a1ff588a 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -29,7 +29,7 @@ def __init__(self, conn, restrictions=None): self._restrictions = [] if restrictions is None else restrictions @property - def conn(self): + def connection(self): return self._conn @property @@ -127,7 +127,7 @@ def __len__(self): """ number of tuples in the relation. This also takes care of the truth value """ - cur = self.conn.query(self.make_select('count(*)')) + cur = self.connection.query(self.make_select('count(*)')) return cur.fetchone()[0] def __contains__(self, item): @@ -196,7 +196,7 @@ def cursor(self, offset=0, limit=None, order_by=None, descending=False, as_dict= if offset: sql += ' OFFSET %d' % offset logger.debug(sql) - return self.conn.query(sql, as_dict=as_dict) + return self.connection.query(sql, as_dict=as_dict) def __repr__(self): limit = config['display.limit'] @@ -281,11 +281,11 @@ class Join(RelationalOperand): def __init__(self, arg1, arg2): if not isinstance(arg2, RelationalOperand): raise DataJointError('a relation can only be joined with another relation') - if arg1.conn != arg2.conn: + if arg1.connection != arg2.connection: raise DataJointError('Cannot join relations with different database connections') self._arg1 = Subquery(arg1) if arg1.heading.computed else arg1 self._arg2 = Subquery(arg1) if arg2.heading.computed else arg2 - super().__init__(arg1.conn, self._arg1.restrictions + self._arg2.restrictions) + super().__init__(arg1.connection, self._arg1.restrictions + self._arg2.restrictions) @property def counter(self): @@ -318,9 +318,9 @@ def __init__(self, arg, group=None, *attributes, **renamed_attributes): self._renamed_attributes.update({d['alias']: d['sql_expression']}) else: self._attributes.append(attribute) - super().__init__(arg.conn) + super().__init__(arg.connection) if group: - if arg.conn != group.conn: + if arg.connection != group.connection: raise DataJointError('Cannot join relations with different database connections') self._group = Subquery(group) self._arg = Subquery(arg) @@ -356,7 +356,7 @@ class Subquery(RelationalOperand): def __init__(self, arg): self._arg = arg - super().__init__(arg.conn) + super().__init__(arg.connection) @property def counter(self): diff --git a/tests/test_relation.py b/tests/test_relation.py index ea7ef21d9..15f52f34a 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -12,7 +12,7 @@ from datajoint import DataJointError, TransactionError, AutoPopulate, Relation import numpy as np from numpy.testing import assert_array_equal -from datajoint.free_relation import FreeRelation +from datajoint.relation import FreeRelation import numpy as np From 6a41328e09c5232b2b6454a057ca8c0188715cb1 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Wed, 3 Jun 2015 18:51:26 -0500 Subject: [PATCH 03/42] Add ClassRelation --- datajoint/relation.py | 178 +++++++++++++++++++++++++++++++++++------- 1 file changed, 150 insertions(+), 28 deletions(-) diff --git a/datajoint/relation.py b/datajoint/relation.py index 6f89fe486..4b3f67146 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -8,12 +8,13 @@ import re from .settings import Role, role_to_prefix from .utils import from_camel_case, user_choice +from .connection import conn import abc logger = logging.getLogger(__name__) -class Relation(RelationalOperand, metaclass=abc.ABCMeta): +class Relation(RelationalOperand): """ A FreeRelation object is a relation associated with a table. A FreeRelation object provides insert and delete methods. @@ -30,34 +31,44 @@ class Relation(RelationalOperand, metaclass=abc.ABCMeta): even after tables are modified after the instance is created. """ - _connection = None # connection information - _schema_name = None # name of schema this relation belongs to - _heading = None # heading information for this relation - _context = {} # name reference lookup context - # defines class properties - @property - def connection(self): - return self.__class__._connection + + def __init__(self, table_name, schema_name=None, connection=None, definition=None, context=None): + self._table_name = table_name + self._schema_name = schema_name + if connection is None: + connection = conn() + self._connection = connection + self._definition = definition + if context is None: + context = {} + self._context = context + self._heading = None @property def schema_name(self): - return self.__class__._schema_name + return self._schema_name @property - def heading(self): - return self.__class__._heading + def connection(self): + return self._connection - @heading.setter - def heading(self, new_heading): - self.__class__._heading = new_heading + @property + def definition(self): + return self._definition @property def context(self): - return self.__class__._context + return self._context + + @property + def heading(self): + return self._heading - # object properties + @heading.setter + def heading(self, new_heading): + self._heading = new_heading @property def table_prefix(self): @@ -66,9 +77,17 @@ def table_prefix(self): @property def table_name(self): """ + TODO: allow table kind to be specified :return: name of the table. This is equal to table_prefix + class name with underscores """ - return self.table_prefix + from_camel_case(self.__class__.__name__) + return self._table_name + + @property + def definition(self): + return self._definition + + + # ============================== Shared implementations ============================== @property def full_table_name(self): @@ -81,28 +100,26 @@ def full_table_name(self): def from_clause(self): return self.full_table_name - @abc.abstractproperty - def definition(self): - pass - - @classmethod - def load_heading(cls): + # TODO: consider if this should be a class method for derived classes + def load_heading(self, forced=False): """ Load the heading information for this table. If the table does not exist in the database server, Heading will be set to None if the table is not yet defined in the database. """ pass + # TODO: I want to be able to tell whether load_heading has already been attempted in the past... `self.heading is None` is not informative + # TODO: make sure to assign new heading to self.heading, not to self._heading or any other direct variables @property def is_declared(self): + #TODO: this implementation is rather expensive and stupid + # - if table is not declared yet, repeated call to this method causes loading attempt each time + if self.heading is None: self.load_heading() return self.heading is not None - def __init__(self): - self.load_heading() - def declare(self): """ Declare the table in database if it doesn't already exist. @@ -467,8 +484,13 @@ def _parse_declaration(self): return table_info, parents, referenced, field_defs, index_defs - def lookup_name(self, name): + """ + Lookup the referenced name in the context dictionary + + e.g. for reference `common.Animals`, it will first check if `context` dictionary contains key + `common`. If found, it then checks for attribute `Animals` in `common`, and returns the result. + """ parts = name.strip().split('.') try: ref = self.context.get(parts[0]) @@ -479,6 +501,106 @@ def lookup_name(self, name): 'in the context of the class' % name) return ref +class ClassRelation(Relation, metaclass=abc.ABCMeta): + """ + A relation object that is handled at class level. All instances of the derived classes + share common connection and schema binding + """ + + _connection = None # connection information + _schema_name = None # name of schema this relation belongs to + _heading = None # heading information for this relation + _context = None # name reference lookup context + + def __init__(self, schema_name=None, connection=None, context=None): + """ + Use this constructor to specify class level + """ + if schema_name is not None: + self.schema_name = schema_name + + # TODO: Think about this implementation carefully + if connection is not None: + self.connection = connection + elif self.connection is None: + self.connection = conn() + + if context is not None: + self.context = context + elif self.context is None: + self.context = {} # initialize with an empty dictionary + + @property + def schema_name(self): + return self.__class__._schema_name + + @schema_name.setter + def schema_name(self, new_schema_name): + if self.schema_name is not None: + logger.warn('Overriding associated schema for class %s' + '- this will affect all existing instances!' % self.__class__.__name__) + self.__class__._schema_name = new_schema_name + + @property + def connection(self): + return self.__class__._connection + + @connection.setter + def connection(self, new_connection): + if self.connection is not None: + logger.warn('Overriding associated connection for class %s' + '- this will affect all existing instances!' % self.__class__.__name__) + self.__class__._connection = new_connection + + @property + def context(self): + # TODO: should this be a copy or the original? + return self.__class__._context.copy() + + @context.setter + def context(self, new_context): + if self.context is not None: + logger.warn('Overriding associated reference context for class %s' + '- this will affect all existing instances!' % self.__class__.__name__) + self.__class__._context = new_context + + @property + def heading(self): + return self.__class__._heading + + @heading.setter + def heading(self, new_heading): + self.__class__._heading = new_heading + + @abc.abstractproperty + def definition(self): + """ + Inheriting class must override this property with a valid table definition string + """ + pass + + @abc.abstractproperty + def table_prefix(self): + pass + + +class ManualRelation(ClassRelation): + @property + def table_prefix(self): + return "" + + +class AutoRelation(ClassRelation): + pass + + +class ComputedRelation(AutoRelation): + @property + def table_prefix(self): + return "_" + + + def parse_attribute_definition(line, in_key=False): From a0eb8857b5f499b568b595c6367db16ada1c9761 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Thu, 4 Jun 2015 13:09:56 -0500 Subject: [PATCH 04/42] moved the schema management functionality out of the Connection class --- datajoint/{relation.py => base_relation.py} | 362 ++++---------------- datajoint/connection.py | 268 +-------------- datajoint/decorators.py | 29 -- datajoint/heading.py | 10 +- datajoint/parsing.py | 88 +++++ datajoint/relation_class.py | 94 +++++ datajoint/relational_operand.py | 55 ++- datajoint/user_relations.py | 27 ++ 8 files changed, 335 insertions(+), 598 deletions(-) rename datajoint/{relation.py => base_relation.py} (57%) delete mode 100644 datajoint/decorators.py create mode 100644 datajoint/parsing.py create mode 100644 datajoint/relation_class.py create mode 100644 datajoint/user_relations.py diff --git a/datajoint/relation.py b/datajoint/base_relation.py similarity index 57% rename from datajoint/relation.py rename to datajoint/base_relation.py index 4b3f67146..aa5d35bb7 100644 --- a/datajoint/relation.py +++ b/datajoint/base_relation.py @@ -1,124 +1,93 @@ -from _collections_abc import MutableMapping, Mapping +from collections.abc import MutableMapping, Mapping import numpy as np import logging +import re +import abc + from . import DataJointError, config, TransactionError from .relational_operand import RelationalOperand from .blob import pack +from .utils import user_choice +from .parsing import parse_attribute_definition, field_to_sql, parse_index_definition from .heading import Heading -import re -from .settings import Role, role_to_prefix -from .utils import from_camel_case, user_choice -from .connection import conn -import abc logger = logging.getLogger(__name__) -class Relation(RelationalOperand): +class BaseRelation(RelationalOperand, metaclass=abc.ABCMeta): """ - A FreeRelation object is a relation associated with a table. - A FreeRelation object provides insert and delete methods. - FreeRelation objects are only used internally and for debugging. - The table must already exist in the schema for its FreeRelation object to work. - - The table associated with an instance of Relation is identified by its 'class name'. - property, which is a string in CamelCase. The actual table name is obtained - by converting className from CamelCase to underscore_separated_words and - prefixing according to the table's role. - - Relation instances obtain their table's heading by looking it up in the connection - object. This ensures that Relation instances contain the current table definition - even after tables are modified after the instance is created. + BaseRelation is an abstract class that represents a base relation, i.e. a table in the database. + To make it a concrete class, override the abstract properties specifying the connection, + table name, database, context, and definition. + A BaseRelation implements insert and delete methods in addition to inherited relational operators. + It also loads table heading and dependencies from the database. + It also handles the table declaration based on its definition property """ - # defines class properties - - - def __init__(self, table_name, schema_name=None, connection=None, definition=None, context=None): - self._table_name = table_name - self._schema_name = schema_name - if connection is None: - connection = conn() - self._connection = connection - self._definition = definition - if context is None: - context = {} - self._context = context - self._heading = None + _heading = None + # ---------- abstract properties ------------ # @property - def schema_name(self): - return self._schema_name + @abc.abstractmethod + def table_name(self): + """ + :return: the name of the table in the database + """ + pass @property - def connection(self): - return self._connection + @abc.abstractmethod + def database(self): + """ + :return: string containing the database name on the server + """ + pass @property + @abc.abstractmethod def definition(self): - return self._definition + """ + :return: a string containing the table definition using the DataJoint DDL + """ + pass @property + @abc.abstractmethod def context(self): - return self._context - - @property - def heading(self): - return self._heading - - @heading.setter - def heading(self, new_heading): - self._heading = new_heading - - @property - def table_prefix(self): - return '' - - @property - def table_name(self): """ - TODO: allow table kind to be specified - :return: name of the table. This is equal to table_prefix + class name with underscores + :return: a dict with other relations that can be referenced by foreign keys """ - return self._table_name + pass + # --------- base relation functionality --------- # @property - def definition(self): - return self._definition - - - # ============================== Shared implementations ============================== + def is_declared(self): + cur = self.query("SHOW DATABASES LIKE '{database}'".format(database=self.database)) + return cur.rowcount == 1 @property - def full_table_name(self): + def heading(self): """ - :return: full name of the associated table + Required by relational operand + :return: a datajoint.Heading object """ - return '`%s`.`%s`' % (self.schema_name, self.table_name) + if self._heading is None: + if not self.is_declared and self.definition: + self.declare() + if self.is_declared: + self._heading = Heading.init_from_database( + self.connection, self.database, self.table_name) + + return self._heading @property def from_clause(self): - return self.full_table_name - - # TODO: consider if this should be a class method for derived classes - def load_heading(self, forced=False): """ - Load the heading information for this table. If the table does not exist in the database server, Heading will be - set to None if the table is not yet defined in the database. + Required by the Relational class, this property specifies the contents of the FROM clause + for the SQL SELECT statements. + :return: """ - pass - # TODO: I want to be able to tell whether load_heading has already been attempted in the past... `self.heading is None` is not informative - # TODO: make sure to assign new heading to self.heading, not to self._heading or any other direct variables - - @property - def is_declared(self): - #TODO: this implementation is rather expensive and stupid - # - if table is not declared yet, repeated call to this method causes loading attempt each time - - if self.heading is None: - self.load_heading() - return self.heading is not None - + return '`%s`.`%s`' % (self.database, self.table_name) def declare(self): """ @@ -131,47 +100,15 @@ def declare(self): # verify that declaration completed successfully if not self.is_declared: raise DataJointError( - 'FreeRelation could not be declared for %s' % self.class_name) - - @staticmethod - def _field_to_sql(field): # TODO move this into Attribute Tuple - """ - Converts an attribute definition tuple into SQL code. - :param field: attribute definition - :rtype : SQL code - """ - mysql_constants = ['CURRENT_TIMESTAMP'] - if field.nullable: - default = 'DEFAULT NULL' - else: - default = 'NOT NULL' - # if some default specified - if field.default: - # enclose value in quotes except special SQL values or already enclosed - quote = field.default.upper() not in mysql_constants and field.default[0] not in '"\'' - default += ' DEFAULT ' + ('"%s"' if quote else "%s") % field.default - if any((c in r'\"' for c in field.comment)): - raise DataJointError('Illegal characters in attribute comment "%s"' % field.comment) - - return '`{name}` {type} {default} COMMENT "{comment}",\n'.format( - name=field.name, type=field.type, default=default, comment=field.comment) - - - - @property - def primary_key(self): - """ - :return: primary key of the table - """ - return self.heading.primary_key + 'BaseRelation could not be declared for %s' % self.class_name) - def iter_insert(self, iter, **kwargs): + def iter_insert(self, rows, **kwargs): """ Inserts an entire batch of entries. Additional keyword arguments are passed to insert. :param iter: Must be an iterator that generates a sequence of valid arguments for insert. """ - for row in iter: + for row in rows: self.insert(row, **kwargs) def batch_insert(self, data, **kwargs): @@ -285,7 +222,7 @@ def add_attribute(self, definition, after=None): """ position = ' FIRST' if after is None else ( ' AFTER %s' % after if after else '') - sql = self.field_to_sql(parse_attribute_definition(definition)) + sql = field_to_sql(parse_attribute_definition(definition)) self._alter('ADD COLUMN %s%s' % (sql[:-2], position)) def drop_attribute(self, attr_name): @@ -307,7 +244,7 @@ def alter_attribute(self, attr_name, new_definition): :param attr_name: field that is redefined :param new_definition: new definition of the field """ - sql = self.field_to_sql(parse_attribute_definition(new_definition)) + sql = field_to_sql(parse_attribute_definition(new_definition)) self._alter('CHANGE COLUMN `%s` %s' % (attr_name, sql[:-2])) def erd(self, subset=None): @@ -333,28 +270,6 @@ def _alter(self, alter_statement): # TODO: place table definition sync mechanism @staticmethod - def _parse_index_def(line): - """ - Parses index definition. - - :param line: definition line - :return: groupdict with index info - """ - line = line.strip() - index_regexp = re.compile(""" - ^(?PUNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX - \((?P[^\)]+)\)$ # (attr1, attr2) - """, re.I + re.X) - m = index_regexp.match(line) - assert m, 'Invalid index declaration "%s"' % line - index_info = m.groupdict() - attributes = re.split(r'\s*,\s*', index_info['attributes'].strip()) - index_info['attributes'] = attributes - assert len(attributes) == len(set(attributes)), \ - 'Duplicate attributes in index declaration "%s"' % line - return index_info - - def _declare(self): """ Declares the table in the database if no table in the database matches this object. @@ -377,7 +292,7 @@ def _declare(self): field = p.heading[key] if field.name not in primary_key_fields: primary_key_fields.add(field.name) - sql += self._field_to_sql(field) + sql += field_to_sql(field) else: logger.debug('Field definition of {} in {} ignored'.format( field.name, p.full_class_name)) @@ -392,7 +307,7 @@ def _declare(self): 'Ensure that the attribute is not already declared ' 'in referenced tables'.format(key=field.name)) primary_key_fields.add(field.name) - sql += self._field_to_sql(field) + sql += field_to_sql(field) # add secondary foreign key attributes for r in referenced: @@ -400,12 +315,12 @@ def _declare(self): field = r.heading[key] if field.name not in primary_key_fields | non_key_fields: non_key_fields.add(field.name) - sql += self._field_to_sql(field) + sql += field_to_sql(field) # add dependent attributes for field in (f for f in field_defs if not f.in_key): non_key_fields.add(field.name) - sql += self._field_to_sql(field) + sql += field_to_sql(field) # add primary key declaration assert len(primary_key_fields) > 0, 'table must have a primary key' @@ -475,7 +390,7 @@ def _parse_declaration(self): ref_list = parents if in_key else referenced ref_list.append(self.lookup_name(ref_name)) elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): - index_defs.append(self._parse_index_def(line)) + index_defs.append(parse_index_definition(line)) elif attribute_regexp.match(line): field_defs.append(parse_attribute_definition(line, in_key)) else: @@ -497,147 +412,8 @@ def lookup_name(self, name): for attr in parts[1:]: ref = getattr(ref, attr) except (KeyError, AttributeError): - raise DataJointError('Foreign reference %s could not be resolved. Please make sure the name exists' - 'in the context of the class' % name) - return ref - -class ClassRelation(Relation, metaclass=abc.ABCMeta): - """ - A relation object that is handled at class level. All instances of the derived classes - share common connection and schema binding - """ - - _connection = None # connection information - _schema_name = None # name of schema this relation belongs to - _heading = None # heading information for this relation - _context = None # name reference lookup context - - def __init__(self, schema_name=None, connection=None, context=None): - """ - Use this constructor to specify class level - """ - if schema_name is not None: - self.schema_name = schema_name - - # TODO: Think about this implementation carefully - if connection is not None: - self.connection = connection - elif self.connection is None: - self.connection = conn() - - if context is not None: - self.context = context - elif self.context is None: - self.context = {} # initialize with an empty dictionary - - @property - def schema_name(self): - return self.__class__._schema_name - - @schema_name.setter - def schema_name(self, new_schema_name): - if self.schema_name is not None: - logger.warn('Overriding associated schema for class %s' - '- this will affect all existing instances!' % self.__class__.__name__) - self.__class__._schema_name = new_schema_name - - @property - def connection(self): - return self.__class__._connection - - @connection.setter - def connection(self, new_connection): - if self.connection is not None: - logger.warn('Overriding associated connection for class %s' - '- this will affect all existing instances!' % self.__class__.__name__) - self.__class__._connection = new_connection - - @property - def context(self): - # TODO: should this be a copy or the original? - return self.__class__._context.copy() - - @context.setter - def context(self, new_context): - if self.context is not None: - logger.warn('Overriding associated reference context for class %s' - '- this will affect all existing instances!' % self.__class__.__name__) - self.__class__._context = new_context - - @property - def heading(self): - return self.__class__._heading - - @heading.setter - def heading(self, new_heading): - self.__class__._heading = new_heading - - @abc.abstractproperty - def definition(self): - """ - Inheriting class must override this property with a valid table definition string - """ - pass - - @abc.abstractproperty - def table_prefix(self): - pass - - -class ManualRelation(ClassRelation): - @property - def table_prefix(self): - return "" - - -class AutoRelation(ClassRelation): - pass - - -class ComputedRelation(AutoRelation): - @property - def table_prefix(self): - return "_" - - - - - -def parse_attribute_definition(line, in_key=False): - """ - Parse attribute definition line in the declaration and returns - an attribute tuple. - - :param line: attribution line - :param in_key: set to True if attribute is in primary key set - :returns: attribute tuple - """ - line = line.strip() - attribute_regexp = re.compile(""" - ^(?P[a-z][a-z\d_]*)\s* # field name - (=\s*(?P\S+(\s+\S+)*?)\s*)? # default value - :\s*(?P\w[^\#]*[^\#\s])\s* # datatype - (\#\s*(?P\S*(\s+\S+)*)\s*)?$ # comment - """, re.X) - m = attribute_regexp.match(line) - if not m: - raise DataJointError('Invalid field declaration "%s"' % line) - attr_info = m.groupdict() - if not attr_info['comment']: - attr_info['comment'] = '' - if not attr_info['default']: - attr_info['default'] = '' - attr_info['nullable'] = attr_info['default'].lower() == 'null' - assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \ - 'BIGINT attributes cannot be nullable in "%s"' % line - - return Heading.AttrTuple( - in_key=in_key, - autoincrement=None, - numeric=None, - string=None, - is_blob=None, - computation=None, - dtype=None, - **attr_info - ) + raise DataJointError( + 'Foreign key reference to %s could not be resolved.' + 'Please make sure the name exists' + 'in the context of the class' % name) + return ref \ No newline at end of file diff --git a/datajoint/connection.py b/datajoint/connection.py index 64bbc9dce..52dae7598 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -10,17 +10,12 @@ logger = logging.getLogger(__name__) -# The following two regular expression are equivalent but one works in python -# and the other works in MySQL -table_name_regexp_sql = re.compile('^(#|_|__|~)?[a-z][a-z0-9_]*$') -table_name_regexp = re.compile('^(|#|_|__|~)[a-z][a-z0-9_]*$') # MySQL does not accept this but MariaDB does - def conn_container(): """ creates a persistent connections for everyone to use """ - _connObj = None # persistent connection object used by dj.conn() + _connection = None # persistent connection object used by dj.conn() def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False): """ @@ -30,8 +25,8 @@ def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False) Set rest=True to reset the persistent connection object """ - nonlocal _connObj - if not _connObj or reset: + nonlocal _connection + if not _connection or reset: host = host if host is not None else config['database.host'] user = user if user is not None else config['database.user'] passwd = passwd if passwd is not None else config['database.password'] @@ -39,18 +34,16 @@ def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False) if passwd is None: passwd = input("Please enter database password: ") init_fun = init_fun if init_fun is not None else config['connection.init_function'] - _connObj = Connection(host, user, passwd, init_fun) - return _connObj + _connection = Connection(host, user, passwd, init_fun) + return _connection return conn_function -# The function conn is used by others to obtain the package wide persistent connection object +# The function conn is used by others to obtain a connection object conn = conn_container() - - -class Connection(object): +class Connection: """ A dj.Connection object manages a connection to a database server. It also catalogues modules, schemas, tables, and their dependencies (foreign keys). @@ -73,21 +66,10 @@ def __init__(self, host, user, passwd, init_fun=None): self.conn_info = dict(host=host, port=port, user=user, passwd=passwd) self._conn = pymysql.connect(init_command=init_fun, **self.conn_info) if self.is_connected: - print("Connected", user + '@' + host + ':' + str(port)) + logger.info("Connected " + user + '@' + host + ':' + str(port)) else: raise DataJointError('Connection failed.') self._conn.autocommit(True) - - self.db_to_mod = {} # modules indexed by dbnames - self.mod_to_db = {} # database names indexed by modules - self.table_names = {} # tables names indexed by [dbname][class_name] - self.headings = {} # contains headings indexed by [dbname][table_name] - self.tableInfo = {} # table information indexed by [dbname][table_name] - - # dependencies from foreign keys - self.parents = {} # maps table names to their parent table names (primary foreign key) - self.referenced = {} # maps table names to table names they reference (non-primary foreign key - self._graph = DBConnGraph(self) # initialize an empty connection graph self._in_transaction = False def __del__(self): @@ -104,216 +86,6 @@ def is_connected(self): """ return self._conn.ping() - def get_full_module_name(self, module): - """ - Returns full module name of the module. - - :param module: module for which the name is requested. - :return: full module name - """ - return '.'.join(self.root_package, module) - - def bind(self, module, dbname): - """ - Binds the `module` name to the database named `dbname`. - Throws an error if `dbname` is already bound to another module. - - If the database `dbname` does not exist in the server, attempts - to create the database and then bind the module. - - - :param module: module name. - :param dbname: database name. It should be a valid database identifier and not a match pattern. - """ - - if dbname in self.db_to_mod: - raise DataJointError('Database `%s` is already bound to module `%s`' - % (dbname, self.db_to_mod[dbname])) - - cur = self.query("SHOW DATABASES LIKE '{dbname}'".format(dbname=dbname)) - count = cur.rowcount - - if count == 1: - # Database exists - self.db_to_mod[dbname] = module - self.mod_to_db[module] = dbname - elif count == 0: - # Database doesn't exist, attempt to create - logger.info("Database `{dbname}` could not be found. " - "Attempting to create the database.".format(dbname=dbname)) - try: - self.query("CREATE DATABASE `{dbname}`".format(dbname=dbname)) - logger.info('Created database `{dbname}`.'.format(dbname=dbname)) - self.db_to_mod[dbname] = module - self.mod_to_db[module] = dbname - except pymysql.OperationalError: - raise DataJointError("Database named `{dbname}` was not defined, and" - " an attempt to create has failed. Check" - " permissions.".format(dbname=dbname)) - else: - raise DataJointError("Database name {dbname} matched more than one " - "existing databases. Database name should not be " - "a pattern.".format(dbname=dbname)) - - def load_headings(self, dbname=None, force=False): - """ - Load table information including roles and list of attributes for all - tables within dbname by examining respective table status. - - If dbname is not specified or None, will load headings for all - databases that are bound to a module. - - By default, the heading is not loaded again if it already exists. - Setting force=True will result in reloading of the heading even if one - already exists. - - :param dbname=None: database name - :param force=False: force reloading the heading - """ - if dbname: - self._load_headings(dbname, force) - return - - for dbname in self.db_to_mod: - self._load_headings(dbname, force) - - def _load_headings(self, dbname, force=False): - """ - Load table information including roles and list of attributes for all - tables within dbname by examining respective table status. - - By default, the heading is not loaded again if it already exists. - Setting force=True will result in reloading of the heading even if one - already exists. - - :param dbname: database name - :param force: force reloading the heading - """ - if dbname not in self.headings or force: - logger.info('Loading table definitions from `{dbname}`...'.format(dbname=dbname)) - self.table_names[dbname] = {} - self.headings[dbname] = {} - self.tableInfo[dbname] = {} - - cur = self.query('SHOW TABLE STATUS FROM `{dbname}` WHERE name REGEXP "{sqlPtrn}"'.format( - dbname=dbname, sqlPtrn=table_name_regexp_sql.pattern), as_dict=True) - - for info in cur: - info = {k.lower(): v for k, v in info.items()} # lowercase it - table_name = info.pop('name') - # look up role by table name prefix - role = prefix_to_role[table_name_regexp.match(table_name).group(1)] - class_name = to_camel_case(table_name) - self.table_names[dbname][class_name] = table_name - self.tableInfo[dbname][table_name] = dict(info, role=role) - self.headings[dbname][table_name] = Heading.init_from_database(self, dbname, table_name) - self.load_dependencies(dbname) - - def load_dependencies(self, dbname): # TODO: Perhaps consider making this "private" by preceding with underscore? - """ - Load dependencies (foreign keys) between tables by examining their - respective CREATE TABLE statements. - - :param dbname: database name - """ - - foreign_key_regexp = re.compile(r""" - FOREIGN\ KEY\s+\((?P[`\w ,]+)\)\s+ # list of keys in this table - REFERENCES\s+(?P[^\s]+)\s+ # table referenced - \((?P[`\w ,]+)\) # list of keys in the referenced table - """, re.X) - - logger.info('Loading dependencies for `{dbname}`'.format(dbname=dbname)) - - for tabName in self.tableInfo[dbname]: - cur = self.query('SHOW CREATE TABLE `{dbname}`.`{tabName}`'.format(dbname=dbname, tabName=tabName), - as_dict=True) - table_def = cur.fetchone() - full_table_name = '`%s`.`%s`' % (dbname, tabName) - self.parents[full_table_name] = [] - self.referenced[full_table_name] = [] - - for m in foreign_key_regexp.finditer(table_def["Create Table"]): # iterate through foreign key statements - assert m.group('attr1') == m.group('attr2'), \ - 'Foreign keys must link identically named attributes' - attrs = m.group('attr1') - attrs = re.split(r',\s+', re.sub(r'`(.*?)`', r'\1', attrs)) # remove ` around attrs and split into list - pk = self.headings[dbname][tabName].primary_key - is_primary = all([k in pk for k in attrs]) - ref = m.group('ref') # referenced table - - if not re.search(r'`\.`', ref): # if referencing other table in same schema - ref = '`%s`.%s' % (dbname, ref) # convert to full-table name - - (self.parents if is_primary else self.referenced)[full_table_name].append(ref) - self.parents.setdefault(ref, []) - self.referenced.setdefault(ref, []) - - def clear_dependencies(self, dbname=None): - """ - Clears dependency mapping originating from `dbname`. If `dbname` is not - specified, dependencies for all databases will be cleared. - - - :param dbname: database name - """ - if dbname is None: # clear out all dependencies - self.parents.clear() - self.referenced.clear() - else: - table_keys = ('`%s`.`%s`' % (dbname, tblName) for tblName in self.tableInfo[dbname]) - for key in table_keys: - if key in self.parents: - self.parents.pop(key) - if key in self.referenced: - self.referenced.pop(key) - - def parents_of(self, child_table): - """ - Returns a list of tables that are parents of the specified child_table. Parent-child relationship is defined - based on the presence of primary-key foreign reference: table that holds a foreign key relation to another table - is the child table. - - :param child_table: the child table - :return: list of parent tables - """ - return self.parents.get(child_table, []).copy() - - def children_of(self, parent_table): - """ - Returns a list of tables for which parent_table is a parent (primary foreign key). Parent-child relationship - is defined based on the presence of primary-key foreign reference: table that holds a foreign key relation to - another table is the child table. - - :param parent_table: parent table - :return: list of child tables - """ - return [child_table for child_table, parents in self.parents.items() if parent_table in parents] - - def referenced_by(self, referencing_table): - """ - Returns a list of tables that are referenced by non-primary foreign key relation - by the referencing_table. - - :param referencing_table: referencing table - :return: list of tables that are referenced by the target table - """ - return self.referenced.get(referencing_table, []).copy() - - def referencing(self, referenced_table): - """ - Returns a list of tables that references referenced_table as non-primary foreign key - - :param referenced_table: referenced table - :return: list of tables that refers to the target table - """ - return [referencing for referencing, referenced in self.referenced.items() - if referenced_table in referenced] - - # TODO: Reimplement __str__ - def __str__(self): - return self.__repr__() # placeholder until more suitable __str__ is implemented - def __repr__(self): connected = "connected" if self.is_connected else "disconnected" return "DataJoint connection ({connected}) {user}@{host}:{port}".format( @@ -323,25 +95,6 @@ def __del__(self): logger.info('Disconnecting {user}@{host}:{port}'.format(**self.conn_info)) self._conn.close() - def erd(self, databases=None, tables=None, fill=True, reload=False): - """ - Creates Entity Relation Diagram for the database or specified subset of - tables. - - Set `fill` to False to only display specified tables. (By default - connection tables are automatically included) - """ - self._graph.update_graph(reload=reload) # update the graph - - graph = self._graph.copy_graph() - if databases: - graph = graph.restrict_by_modules(databases, fill) - - if tables: - graph = graph.restrict_by_tables(tables, fill) - - return graph - def query(self, query, args=(), as_dict=False): """ Execute the specified query and return the tuple generator. @@ -357,7 +110,7 @@ def query(self, query, args=(), as_dict=False): cur.execute(query, args) return cur - + # ---------- transaction processing ------------------ @property def in_transaction(self): self._in_transaction = self._in_transaction and self.is_connected @@ -378,5 +131,4 @@ def cancel_transaction(self): def commit_transaction(self): self.query('COMMIT') self._in_transaction = False - logger.info("Transaction commited and closed.") - + logger.info("Transaction committed and closed.") diff --git a/datajoint/decorators.py b/datajoint/decorators.py deleted file mode 100644 index 9ed650dd7..000000000 --- a/datajoint/decorators.py +++ /dev/null @@ -1,29 +0,0 @@ -__author__ = 'eywalker' -from .connection import conn - -def schema(name, context, connection=None): #TODO consider moving this into relation module - """ - Returns a schema decorator that can be used to associate a Relation class to a - specific database with :param name:. Name reference to other tables in the table definition - will be resolved by looking up the corresponding key entry in the passed in context dictionary. - It is most common to set context equal to the return value of call to locals() in the module. - For more details, please refer to the tutorial online. - - :param name: name of the database to associate the decorated class with - :param context: dictionary used to resolve (any) name references within the table definition string - :param connection: connection object to the database server. If ommited, will try to establish connection according to - config values - :return: a decorator function to be used on Relation derivative classes - """ - if connection is None: - connection = conn() - - def _dec(cls): - cls._schema_name = name - cls._context = context - cls._connection = connection - return cls - - return _dec - - diff --git a/datajoint/heading.py b/datajoint/heading.py index 620c93751..73d6c30f2 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -91,13 +91,13 @@ def __iter__(self): return iter(self.attributes) @classmethod - def init_from_database(cls, conn, dbname, table_name): + def init_from_database(cls, conn, database, table_name): """ initialize heading from a database table """ cur = conn.query( - 'SHOW FULL COLUMNS FROM `{table_name}` IN `{dbname}`'.format( - table_name=table_name, dbname=dbname), as_dict=True) + 'SHOW FULL COLUMNS FROM `{table_name}` IN `{database}`'.format( + table_name=table_name, database=database), as_dict=True) attributes = cur.fetchall() rename_map = { @@ -147,8 +147,8 @@ def init_from_database(cls, conn, dbname, table_name): attr['computation'] = None if not (attr['numeric'] or attr['string'] or attr['is_blob']): - raise DataJointError('Unsupported field type {field} in `{dbname}`.`{table_name}`'.format( - field=attr['type'], dbname=dbname, table_name=table_name)) + raise DataJointError('Unsupported field type {field} in `{database}`.`{table_name}`'.format( + field=attr['type'], database=database, table_name=table_name)) attr.pop('Extra') # fill out dtype. All floats and non-nullable integers are turned into specific dtypes diff --git a/datajoint/parsing.py b/datajoint/parsing.py new file mode 100644 index 000000000..85e367c96 --- /dev/null +++ b/datajoint/parsing.py @@ -0,0 +1,88 @@ +import re +from . import DataJointError +from .heading import Heading + + +def parse_attribute_definition(line, in_key=False): + """ + Parse attribute definition line in the declaration and returns + an attribute tuple. + + :param line: attribution line + :param in_key: set to True if attribute is in primary key set + :returns: attribute tuple + """ + line = line.strip() + attribute_regexp = re.compile(""" + ^(?P[a-z][a-z\d_]*)\s* # field name + (=\s*(?P\S+(\s+\S+)*?)\s*)? # default value + :\s*(?P\w[^\#]*[^\#\s])\s* # datatype + (\#\s*(?P\S*(\s+\S+)*)\s*)?$ # comment + """, re.X) + m = attribute_regexp.match(line) + if not m: + raise DataJointError('Invalid field declaration "%s"' % line) + attr_info = m.groupdict() + if not attr_info['comment']: + attr_info['comment'] = '' + if not attr_info['default']: + attr_info['default'] = '' + attr_info['nullable'] = attr_info['default'].lower() == 'null' + assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \ + 'BIGINT attributes cannot be nullable in "%s"' % line + + return Heading.AttrTuple( + in_key=in_key, + autoincrement=None, + numeric=None, + string=None, + is_blob=None, + computation=None, + dtype=None, + **attr_info + ) + + +def field_to_sql(field): # TODO move this into Attribute Tuple + """ + Converts an attribute definition tuple into SQL code. + :param field: attribute definition + :rtype : SQL code + """ + mysql_constants = ['CURRENT_TIMESTAMP'] + if field.nullable: + default = 'DEFAULT NULL' + else: + default = 'NOT NULL' + # if some default specified + if field.default: + # enclose value in quotes except special SQL values or already enclosed + quote = field.default.upper() not in mysql_constants and field.default[0] not in '"\'' + default += ' DEFAULT ' + ('"%s"' if quote else "%s") % field.default + if any((c in r'\"' for c in field.comment)): + raise DataJointError('Illegal characters in attribute comment "%s"' % field.comment) + + return '`{name}` {type} {default} COMMENT "{comment}",\n'.format( + name=field.name, type=field.type, default=default, comment=field.comment) + + +def parse_index_definition(line): + """ + Parses index definition. + + :param line: definition line + :return: groupdict with index info + """ + line = line.strip() + index_regexp = re.compile(""" + ^(?PUNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX + \((?P[^\)]+)\)$ # (attr1, attr2) + """, re.I + re.X) + m = index_regexp.match(line) + assert m, 'Invalid index declaration "%s"' % line + index_info = m.groupdict() + attributes = re.split(r'\s*,\s*', index_info['attributes'].strip()) + index_info['attributes'] = attributes + assert len(attributes) == len(set(attributes)), \ + 'Duplicate attributes in index declaration "%s"' % line + return index_info diff --git a/datajoint/relation_class.py b/datajoint/relation_class.py new file mode 100644 index 000000000..7fab544b7 --- /dev/null +++ b/datajoint/relation_class.py @@ -0,0 +1,94 @@ +import abc +import logging +from collections import namedtuple +import pymysql +from .connection import conn +from .base_relation import BaseRelation +from . import DataJointError + + +logger = logging.getLogger(__name__) + + +SharedInfo = namedtuple( + 'SharedInfo', + ('database', 'context', 'connection', 'heading', 'parents', 'children', 'references', 'referenced')) + + +def schema(database, context, connection=None): + """ + Returns a schema decorator that can be used to associate a BaseRelation class to a + specific database with :param name:. Name reference to other tables in the table definition + will be resolved by looking up the corresponding key entry in the passed in context dictionary. + It is most common to set context equal to the return value of call to locals() in the module. + For more details, please refer to the tutorial online. + + :param database: name of the database to associate the decorated class with + :param context: dictionary used to resolve (any) name references within the table definition string + :param connection: connection object to the database server. If ommited, will try to establish connection according to + config values + :return: a decorator function to be used on BaseRelation derivative classes + """ + if connection is None: + connection = conn() + + # if the database does not exist, create it + cur = connection.query("SHOW DATABASES LIKE '{database}'".format(database=database)) + if cur.rowcount == 0: + logger.info("Database `{database}` could not be found. " + "Attempting to create the database.".format(database=database)) + try: + connection.query("CREATE DATABASE `{database}`".format(database=database)) + logger.info('Created database `{database}`.'.format(database=database)) + except pymysql.OperationalError: + raise DataJointError("Database named `{database}` was not defined, and" + "an attempt to create has failed. Check" + " permissions.".format(database=database)) + + def decorator(cls): + cls._shared_info = SharedInfo( + database=database, + context=context, + connection=connection, + heading=None, + parents=[], + children=[], + references=[], + referenced=[] + ) + return cls + + return decorator + + +class RelationClass(BaseRelation): + """ + Abstract class for dedicated table classes. + Subclasses of RelationClass are dedicated interfaces to a single table. + The main purpose of RelationClass is to encapsulated sharedInfo containing the table heading + and dependency information shared by all instances of + """ + + _shared_info = None + + def __init__(self): + if self._shared_info is None: + raise DataJointError('The class must define _shared_info') + + @property + def database(self): + return self._shared_info.database + + @property + def connection(self): + return self._shared_info.connection + + @property + def context(self): + return self._shared_info.context + + @property + def heading(self): + if self._shared_info.heading is None: + self._shared_info.heading = super().heading + return self._shared_info.heading diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index 2a1ff588a..ac330a005 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -9,7 +9,6 @@ from datajoint import DataJointError, config from .blob import unpack import logging -import numpy.lib.recfunctions as rfn logger = logging.getLogger(__name__) @@ -24,26 +23,44 @@ class RelationalOperand(metaclass=abc.ABCMeta): RelationalOperand operators are: restrict, pro, and join. """ - def __init__(self, conn, restrictions=None): - self._conn = conn - self._restrictions = [] if restrictions is None else restrictions - - @property - def connection(self): - return self._conn + _restrictions = None @property def restrictions(self): return self._restrictions - @abc.abstractproperty + @property + def primary_key(self): + return self.heading.primary_key + + # --------- abstract properties ----------- + + @property + @abc.abstractmethod + def connection(self): + """ + :return: a datajoint.Connection object + """ + pass + + @property + @abc.abstractmethod def from_clause(self): + """ + :return: a string containing the FROM clause of the SQL SELECT statement + """ pass - @abc.abstractproperty + @property + @abc.abstractmethod def heading(self): + """ + :return: a valid datajoint.Heading object + """ pass + # --------- relational operators ----------- + def __mul__(self, other): """ relational join @@ -118,6 +135,8 @@ def __sub__(self, restriction): """ return self & Not(restriction) + # ------ data retrieval methods ----------- + def make_select(self, attribute_spec=None): if attribute_spec is None: attribute_spec = self.heading.as_sql @@ -285,7 +304,11 @@ def __init__(self, arg1, arg2): raise DataJointError('Cannot join relations with different database connections') self._arg1 = Subquery(arg1) if arg1.heading.computed else arg1 self._arg2 = Subquery(arg1) if arg2.heading.computed else arg2 - super().__init__(arg1.connection, self._arg1.restrictions + self._arg2.restrictions) + self._restrictions = self._arg1.restrictions + self._arg2.restrictions + + @property + def connection(self): + return self._arg1.connection @property def counter(self): @@ -318,7 +341,6 @@ def __init__(self, arg, group=None, *attributes, **renamed_attributes): self._renamed_attributes.update({d['alias']: d['sql_expression']}) else: self._attributes.append(attribute) - super().__init__(arg.connection) if group: if arg.connection != group.connection: raise DataJointError('Cannot join relations with different database connections') @@ -333,6 +355,10 @@ def __init__(self, arg, group=None, *attributes, **renamed_attributes): self._arg = arg self._restrictions = self._arg.restrictions + @property + def connection(self): + return self._arg.connection + @property def heading(self): return self._arg.heading.project(*self._attributes, **self._renamed_attributes) @@ -356,7 +382,10 @@ class Subquery(RelationalOperand): def __init__(self, arg): self._arg = arg - super().__init__(arg.connection) + + @property + def connection(self): + return self._arg.connection @property def counter(self): diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py new file mode 100644 index 000000000..8b436f85b --- /dev/null +++ b/datajoint/user_relations.py @@ -0,0 +1,27 @@ +from .relation_class import RelationClass +from .autopopulate import AutoPopulate +from .utils import from_camel_case + + +class ManualRelation(RelationClass): + @property + def table_name(self): + return from_camel_case(self.__class__.__name__) + + +class LookupRelation(RelationClass): + @property + def table_name(self): + return '#' + from_camel_case(self.__class__.__name__) + + +class ImportedRelation(RelationClass, AutoPopulate): + @property + def table_name(self): + return "_" + from_camel_case(self.__class__.__name__) + + +class ComputedRelation(RelationClass, AutoPopulate): + @property + def table_name(self): + return "__" + from_camel_case(self.__class__.__name__) \ No newline at end of file From 50df1a384a47a4d5780ac4fc0ac9063bca784e57 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Thu, 4 Jun 2015 17:51:20 -0500 Subject: [PATCH 05/42] fixed BaseRelation.is_declared --- datajoint/__init__.py | 8 ++++---- datajoint/base_relation.py | 15 +++++++++------ 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index e59a008cc..e193f88a8 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -5,6 +5,8 @@ __version__ = "0.2" __all__ = ['__author__', '__version__', 'Connection', 'Heading', 'Relation', 'FreeRelation', 'Not', + 'BaseRelation', + 'ManualRelation', 'LookupRelation', 'ImportedRelation', 'ComputedRelation', 'AutoPopulate', 'conn', 'DataJointError', 'blob'] @@ -28,13 +30,11 @@ def resolve(self): f, args, kwargs = self.operations return f(*args, **kwargs) - @property def culprit(self): return self.operations[0].__name__ - # ----------- loads local configuration from file ---------------- from .settings import Config, CONFIGVAR, LOCALCONFIG, logger, log_levels config = Config() @@ -55,9 +55,9 @@ def culprit(self): # ------------- flatten import hierarchy ------------------------- from .connection import conn, Connection -from .relation import Relation +from .user_relations import ManualRelation, LookupRelation, ImportedRelation, ComputedRelation +from .base_relation import BaseRelation from .autopopulate import AutoPopulate from . import blob from .relational_operand import Not -from .relation import FreeRelation from .heading import Heading \ No newline at end of file diff --git a/datajoint/base_relation.py b/datajoint/base_relation.py index aa5d35bb7..fb30ae959 100644 --- a/datajoint/base_relation.py +++ b/datajoint/base_relation.py @@ -24,7 +24,7 @@ class BaseRelation(RelationalOperand, metaclass=abc.ABCMeta): It also handles the table declaration based on its definition property """ - _heading = None + __heading = None # ---------- abstract properties ------------ # @property @@ -62,7 +62,11 @@ def context(self): # --------- base relation functionality --------- # @property def is_declared(self): - cur = self.query("SHOW DATABASES LIKE '{database}'".format(database=self.database)) + if self.__heading is not None: + return True + cur = self.query( + 'SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format( + table_name=self.table_name)) return cur.rowcount == 1 @property @@ -71,14 +75,13 @@ def heading(self): Required by relational operand :return: a datajoint.Heading object """ - if self._heading is None: + if self.__heading is None: if not self.is_declared and self.definition: self.declare() if self.is_declared: - self._heading = Heading.init_from_database( + self.__heading = Heading.init_from_database( self.connection, self.database, self.table_name) - - return self._heading + return self.__heading @property def from_clause(self): From 9b70f5706d2283edbbe2470b10c6c3dc09a85a4a Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Thu, 4 Jun 2015 18:07:01 -0500 Subject: [PATCH 06/42] made table_name a class property in the user relation classes --- datajoint/user_relations.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 8b436f85b..7de6567c1 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -5,23 +5,27 @@ class ManualRelation(RelationClass): @property - def table_name(self): - return from_camel_case(self.__class__.__name__) + @classmethod + def table_name(cls): + return from_camel_case(cls.__name__) class LookupRelation(RelationClass): @property - def table_name(self): - return '#' + from_camel_case(self.__class__.__name__) + @classmethod + def table_name(cls): + return '#' + from_camel_case(cls.__name__) class ImportedRelation(RelationClass, AutoPopulate): @property - def table_name(self): - return "_" + from_camel_case(self.__class__.__name__) + @classmethod + def table_name(cls): + return "_" + from_camel_case(cls.__name__) class ComputedRelation(RelationClass, AutoPopulate): @property - def table_name(self): - return "__" + from_camel_case(self.__class__.__name__) \ No newline at end of file + @classmethod + def table_name(cls): + return "__" + from_camel_case(cls.__name__) \ No newline at end of file From c2f7ce3cf0c73683230d80676aa2123fd0ad5e07 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Mon, 8 Jun 2015 14:56:33 -0500 Subject: [PATCH 07/42] intermediate --- datajoint/__init__.py | 8 ++++---- datajoint/{base_relation.py => abstract_relation.py} | 8 ++++---- datajoint/{relation_class.py => relations.py} | 12 ++++++------ datajoint/user_relations.py | 10 +++++----- tests/test_relation.py | 2 +- 5 files changed, 20 insertions(+), 20 deletions(-) rename datajoint/{base_relation.py => abstract_relation.py} (97%) rename datajoint/{relation_class.py => relations.py} (87%) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index e193f88a8..e3a8007a6 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -5,8 +5,8 @@ __version__ = "0.2" __all__ = ['__author__', '__version__', 'Connection', 'Heading', 'Relation', 'FreeRelation', 'Not', - 'BaseRelation', - 'ManualRelation', 'LookupRelation', 'ImportedRelation', 'ComputedRelation', + 'Relation', + 'Manual', 'Lookup', 'Imported', 'Computed', 'AutoPopulate', 'conn', 'DataJointError', 'blob'] @@ -55,8 +55,8 @@ def culprit(self): # ------------- flatten import hierarchy ------------------------- from .connection import conn, Connection -from .user_relations import ManualRelation, LookupRelation, ImportedRelation, ComputedRelation -from .base_relation import BaseRelation +from .user_relations import Manual, Lookup, Imported, Computed +from .abstract_relation import Relation from .autopopulate import AutoPopulate from . import blob from .relational_operand import Not diff --git a/datajoint/base_relation.py b/datajoint/abstract_relation.py similarity index 97% rename from datajoint/base_relation.py rename to datajoint/abstract_relation.py index fb30ae959..da191fef6 100644 --- a/datajoint/base_relation.py +++ b/datajoint/abstract_relation.py @@ -14,12 +14,12 @@ logger = logging.getLogger(__name__) -class BaseRelation(RelationalOperand, metaclass=abc.ABCMeta): +class Relation(RelationalOperand, metaclass=abc.ABCMeta): """ - BaseRelation is an abstract class that represents a base relation, i.e. a table in the database. + Relation is an abstract class that represents a base relation, i.e. a table in the database. To make it a concrete class, override the abstract properties specifying the connection, table name, database, context, and definition. - A BaseRelation implements insert and delete methods in addition to inherited relational operators. + A Relation implements insert and delete methods in addition to inherited relational operators. It also loads table heading and dependencies from the database. It also handles the table declaration based on its definition property """ @@ -103,7 +103,7 @@ def declare(self): # verify that declaration completed successfully if not self.is_declared: raise DataJointError( - 'BaseRelation could not be declared for %s' % self.class_name) + 'Relation could not be declared for %s' % self.class_name) def iter_insert(self, rows, **kwargs): """ diff --git a/datajoint/relation_class.py b/datajoint/relations.py similarity index 87% rename from datajoint/relation_class.py rename to datajoint/relations.py index 7fab544b7..08f3f8900 100644 --- a/datajoint/relation_class.py +++ b/datajoint/relations.py @@ -3,7 +3,7 @@ from collections import namedtuple import pymysql from .connection import conn -from .base_relation import BaseRelation +from .abstract_relation import Relation from . import DataJointError @@ -17,7 +17,7 @@ def schema(database, context, connection=None): """ - Returns a schema decorator that can be used to associate a BaseRelation class to a + Returns a schema decorator that can be used to associate a Relation class to a specific database with :param name:. Name reference to other tables in the table definition will be resolved by looking up the corresponding key entry in the passed in context dictionary. It is most common to set context equal to the return value of call to locals() in the module. @@ -27,7 +27,7 @@ def schema(database, context, connection=None): :param context: dictionary used to resolve (any) name references within the table definition string :param connection: connection object to the database server. If ommited, will try to establish connection according to config values - :return: a decorator function to be used on BaseRelation derivative classes + :return: a decorator function to be used on Relation derivative classes """ if connection is None: connection = conn() @@ -61,11 +61,11 @@ def decorator(cls): return decorator -class RelationClass(BaseRelation): +class ClassBoundRelation(Relation): """ Abstract class for dedicated table classes. - Subclasses of RelationClass are dedicated interfaces to a single table. - The main purpose of RelationClass is to encapsulated sharedInfo containing the table heading + Subclasses of ClassBoundRelation are dedicated interfaces to a single table. + The main purpose of ClassBoundRelation is to encapsulated sharedInfo containing the table heading and dependency information shared by all instances of """ diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 7de6567c1..7fb5aba1a 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -1,30 +1,30 @@ -from .relation_class import RelationClass +from .relations import ClassBoundRelation from .autopopulate import AutoPopulate from .utils import from_camel_case -class ManualRelation(RelationClass): +class Manual(ClassBoundRelation): @property @classmethod def table_name(cls): return from_camel_case(cls.__name__) -class LookupRelation(RelationClass): +class Lookup(ClassBoundRelation): @property @classmethod def table_name(cls): return '#' + from_camel_case(cls.__name__) -class ImportedRelation(RelationClass, AutoPopulate): +class Imported(ClassBoundRelation, AutoPopulate): @property @classmethod def table_name(cls): return "_" + from_camel_case(cls.__name__) -class ComputedRelation(RelationClass, AutoPopulate): +class Computed(ClassBoundRelation, AutoPopulate): @property @classmethod def table_name(cls): diff --git a/tests/test_relation.py b/tests/test_relation.py index 15f52f34a..3facc6721 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -12,7 +12,7 @@ from datajoint import DataJointError, TransactionError, AutoPopulate, Relation import numpy as np from numpy.testing import assert_array_equal -from datajoint.relation import FreeRelation +from datajoint.abstract_relation import FreeRelation import numpy as np From 69b5f422c5d087df662c38d743840382a8c4ceb5 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Mon, 8 Jun 2015 17:24:34 -0500 Subject: [PATCH 08/42] * merged Relation and ClassLevelRelation * converted almost everything to @classproperty or @classmethod --- datajoint/__init__.py | 2 +- .../{abstract_relation.py => relation.py} | 267 ++++++++++++------ datajoint/relations.py | 94 ------ datajoint/user_relations.py | 2 +- tests/test_relation.py | 2 +- 5 files changed, 183 insertions(+), 184 deletions(-) rename datajoint/{abstract_relation.py => relation.py} (65%) delete mode 100644 datajoint/relations.py diff --git a/datajoint/__init__.py b/datajoint/__init__.py index e3a8007a6..4eb1eb6d1 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -56,7 +56,7 @@ def culprit(self): # ------------- flatten import hierarchy ------------------------- from .connection import conn, Connection from .user_relations import Manual, Lookup, Imported, Computed -from .abstract_relation import Relation +from .relation import Relation from .autopopulate import AutoPopulate from . import blob from .relational_operand import Not diff --git a/datajoint/abstract_relation.py b/datajoint/relation.py similarity index 65% rename from datajoint/abstract_relation.py rename to datajoint/relation.py index da191fef6..ef9e43938 100644 --- a/datajoint/abstract_relation.py +++ b/datajoint/relation.py @@ -1,3 +1,4 @@ +from collections import namedtuple from collections.abc import MutableMapping, Mapping import numpy as np import logging @@ -5,6 +6,8 @@ import abc from . import DataJointError, config, TransactionError +import pymysql +from datajoint import DataJointError, conn from .relational_operand import RelationalOperand from .blob import pack from .utils import user_choice @@ -13,6 +16,65 @@ logger = logging.getLogger(__name__) +SharedInfo = namedtuple( + 'SharedInfo', + ('database', 'context', 'connection', 'heading', 'parents', 'children', 'references', 'referenced')) + + +class classproperty: + def __init__(self, getf): + self._getf = getf + + def __get__(self, instance, owner): + return self._getf(owner) + + +def schema(database, context, connection=None): + """ + Returns a schema decorator that can be used to associate a Relation class to a + specific database with :param name:. Name reference to other tables in the table definition + will be resolved by looking up the corresponding key entry in the passed in context dictionary. + It is most common to set context equal to the return value of call to locals() in the module. + For more details, please refer to the tutorial online. + + :param database: name of the database to associate the decorated class with + :param context: dictionary used to resolve (any) name references within the table definition string + :param connection: connection object to the database server. If ommited, will try to establish connection according to + config values + :return: a decorator function to be used on Relation derivative classes + """ + if connection is None: + connection = conn() + + # if the database does not exist, create it + cur = connection.query("SHOW DATABASES LIKE '{database}'".format(database=database)) + if cur.rowcount == 0: + logger.info("Database `{database}` could not be found. " + "Attempting to create the database.".format(database=database)) + try: + connection.query("CREATE DATABASE `{database}`".format(database=database)) + logger.info('Created database `{database}`.'.format(database=database)) + except pymysql.OperationalError: + raise DataJointError("Database named `{database}` was not defined, and" + "an attempt to create has failed. Check" + " permissions.".format(database=database)) + + def decorator(cls): + cls._shared_info = SharedInfo( + database=database, + context=context, + connection=connection, + heading=None, + parents=[], + children=[], + references=[], + referenced=[] + ) + return cls + + return decorator + + class Relation(RelationalOperand, metaclass=abc.ABCMeta): """ @@ -26,103 +88,109 @@ class Relation(RelationalOperand, metaclass=abc.ABCMeta): __heading = None + _shared_info = None + + def __init__(self): # TODO: Think about it + if self._shared_info is None: + raise DataJointError('The class must define _shared_info') + # ---------- abstract properties ------------ # - @property + @classproperty @abc.abstractmethod - def table_name(self): + def table_name(cls): """ :return: the name of the table in the database """ pass - @property + @classproperty @abc.abstractmethod - def database(self): + def database(cls): """ :return: string containing the database name on the server """ pass - @property + @classproperty @abc.abstractmethod - def definition(self): + def definition(cls): """ :return: a string containing the table definition using the DataJoint DDL """ pass - @property + @classproperty @abc.abstractmethod - def context(self): + def context(cls): """ :return: a dict with other relations that can be referenced by foreign keys """ pass # --------- base relation functionality --------- # - @property - def is_declared(self): - if self.__heading is not None: + @classproperty + def is_declared(cls): + if cls.__heading is not None: return True - cur = self.query( + cur = cls._shared_info.connection.query( 'SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format( - table_name=self.table_name)) + table_name=cls.table_name)) return cur.rowcount == 1 - @property - def heading(self): + @classproperty + def heading(cls): """ Required by relational operand :return: a datajoint.Heading object """ - if self.__heading is None: - if not self.is_declared and self.definition: - self.declare() - if self.is_declared: - self.__heading = Heading.init_from_database( - self.connection, self.database, self.table_name) - return self.__heading + if cls.__heading is None: + cls.__heading = Heading.init_from_database(cls.connection, cls.database, cls.table_name) + return cls.__heading - @property - def from_clause(self): + @classproperty + def from_clause(cls): """ Required by the Relational class, this property specifies the contents of the FROM clause for the SQL SELECT statements. :return: """ - return '`%s`.`%s`' % (self.database, self.table_name) + return '`%s`.`%s`' % (cls.database, cls.table_name) - def declare(self): + @classmethod + def declare(cls): """ Declare the table in database if it doesn't already exist. :raises: DataJointError if the table cannot be declared. """ - if not self.is_declared: - self._declare() - # verify that declaration completed successfully - if not self.is_declared: - raise DataJointError( - 'Relation could not be declared for %s' % self.class_name) + if not cls.is_declared: + cls._declare() - def iter_insert(self, rows, **kwargs): + @classmethod + def iter_insert(cls, rows, **kwargs): """ Inserts an entire batch of entries. Additional keyword arguments are passed to insert. :param iter: Must be an iterator that generates a sequence of valid arguments for insert. """ for row in rows: - self.insert(row, **kwargs) + cls.insert(row, **kwargs) - def batch_insert(self, data, **kwargs): + @classmethod + def batch_insert(cls, data, **kwargs): """ Inserts an entire batch of entries. Additional keyword arguments are passed to insert. :param data: must be iterable, each row must be a valid argument for insert """ - self.iter_insert(data.__iter__(), **kwargs) + cls.iter_insert(data.__iter__(), **kwargs) - def insert(self, tup, ignore_errors=False, replace=False): + @classproperty + def full_table_name(cls): + return r"`{0:s}`.`{1:s}`".format(cls.database, cls.table_name) + + @classmethod + def insert(cls, tup, ignore_errors=False, replace=False): """ Insert one data record or one Mapping (like a dictionary). @@ -137,7 +205,7 @@ def insert(self, tup, ignore_errors=False, replace=False): real_id = 1007, date_of_birth = "2014-09-01")) """ - heading = self.heading + heading = cls.heading if isinstance(tup, np.void): for fieldname in tup.dtype.fields: if fieldname not in heading: @@ -167,10 +235,10 @@ def insert(self, tup, ignore_errors=False, replace=False): sql = 'INSERT IGNORE' else: sql = 'INSERT' - sql += " INTO %s (%s) VALUES (%s)" % (self.full_table_name, + sql += " INTO %s (%s) VALUES (%s)" % (cls.full_table_name, attribute_list, value_list) logger.info(sql) - self.connection.query(sql, args=args) + cls.connection.query(sql, args=args) def delete(self): if not config['safemode'] or user_choice( @@ -178,39 +246,41 @@ def delete(self): "Proceed?", default='no') == 'yes': self.connection.query('DELETE FROM ' + self.from_clause + self.where_clause) # TODO: make cascading (issue #15) - def drop(self): + @classmethod + def drop(cls): """ - Drops the table associated to this object. + Drops the table associated to this class. """ - if self.is_declared: + if cls.is_declared: if not config['safemode'] or user_choice( "You are about to drop an entire table. This operation cannot be undone.\n" "Proceed?", default='no') == 'yes': - self.connection.query('DROP TABLE %s' % self.full_table_name) # TODO: make cascading (issue #16) - self.connection.clear_dependencies(dbname=self.dbname) - self.connection.load_headings(dbname=self.dbname, force=True) - logger.info("Dropped table %s" % self.full_table_name) + cls.connection.query('DROP TABLE %s' % cls.full_table_name) # TODO: make cascading (issue #16) + # cls.connection.clear_dependencies(dbname=cls.dbname) #TODO: reimplement because clear_dependencies will be gone + # cls.connection.load_headings(dbname=cls.dbname, force=True) #TODO: reimplement because load_headings is gone + logger.info("Dropped table %s" % cls.full_table_name) - @property - def size_on_disk(self): + @classproperty + def size_on_disk(cls): """ :return: size of data and indices in MiB taken by the table on the storage device """ - cur = self.connection.query( + cur = cls.connection.query( 'SHOW TABLE STATUS FROM `{dbname}` WHERE NAME="{table}"'.format( - dbname=self.dbname, table=self.table_name), as_dict=True) + dbname=cls.dbname, table=cls.table_name), as_dict=True) ret = cur.fetchone() return (ret['Data_length'] + ret['Index_length'])/1024**2 - def set_table_comment(self, comment): + @classmethod + def set_table_comment(cls, comment): """ Update the table comment in the table definition. :param comment: new comment as string """ - # TODO: add verification procedure (github issue #24) - self.alter('COMMENT="%s"' % comment) + cls.alter('COMMENT="%s"' % comment) - def add_attribute(self, definition, after=None): + @classmethod + def add_attribute(cls, definition, after=None): """ Add a new attribute to the table. A full line from the table definition is passed in as definition. @@ -226,9 +296,10 @@ def add_attribute(self, definition, after=None): position = ' FIRST' if after is None else ( ' AFTER %s' % after if after else '') sql = field_to_sql(parse_attribute_definition(definition)) - self._alter('ADD COLUMN %s%s' % (sql[:-2], position)) + cls._alter('ADD COLUMN %s%s' % (sql[:-2], position)) - def drop_attribute(self, attr_name): + @classmethod + def drop_attribute(cls, attr_name): """ Drops the attribute attrName from this table. @@ -238,9 +309,10 @@ def drop_attribute(self, attr_name): "You are about to drop an attribute from a table." "This operation cannot be undone.\n" "Proceed?", default='no') == 'yes': - self._alter('DROP COLUMN `%s`' % attr_name) + cls._alter('DROP COLUMN `%s`' % attr_name) - def alter_attribute(self, attr_name, new_definition): + @classmethod + def alter_attribute(cls, attr_name, new_definition): """ Alter the definition of the field attr_name in this table using the new definition. @@ -248,44 +320,46 @@ def alter_attribute(self, attr_name, new_definition): :param new_definition: new definition of the field """ sql = field_to_sql(parse_attribute_definition(new_definition)) - self._alter('CHANGE COLUMN `%s` %s' % (attr_name, sql[:-2])) + cls._alter('CHANGE COLUMN `%s` %s' % (attr_name, sql[:-2])) - def erd(self, subset=None): + @classmethod + def erd(cls, subset=None): """ Plot the schema's entity relationship diagram (ERD). """ - def _alter(self, alter_statement): + @classmethod + def _alter(cls, alter_statement): """ Execute ALTER TABLE statement for this table. The schema will be reloaded within the connection object. :param alter_statement: alter statement """ - if self._conn.in_transaction: + if cls.connection.in_transaction: raise TransactionError( u"_alter is currently in transaction. Operation not allowed to avoid implicit commits.", - self._alter, args=(alter_statement,)) + cls._alter, args=(alter_statement,)) - sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement) - self.connection.query(sql) - self.connection.load_headings(self.dbname, force=True) + sql = 'ALTER TABLE %s %s' % (cls.full_table_name, alter_statement) + cls.connection.query(sql) + cls.connection.load_headings(cls.dbname, force=True) # TODO: place table definition sync mechanism - @staticmethod - def _declare(self): + @classmethod + def _declare(cls): """ Declares the table in the database if no table in the database matches this object. """ - if self.connection.in_transaction: + if cls.connection.in_transaction: raise TransactionError( - u"_declare is currently in transaction. Operation not allowed to avoid implicit commits.", self._declare) + u"_declare is currently in transaction. Operation not allowed to avoid implicit commits.", cls._declare) - if not self.definition: # if empty definition was supplied + if not cls.definition: # if empty definition was supplied raise DataJointError('Table definition is missing!') - table_info, parents, referenced, field_defs, index_defs = self._parse_declaration() + table_info, parents, referenced, field_defs, index_defs = cls._parse_declaration() - sql = 'CREATE TABLE %s (\n' % self.full_table_name + sql = 'CREATE TABLE %s (\n' % cls.full_table_name # add inherited primary key fields primary_key_fields = set() @@ -349,15 +423,16 @@ def _declare(self): sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % ( sql[:-2], table_info['comment']) - # make sure that the table does not alredy exist - self.load_heading() - if not self.is_declared: - # execute declaration - logger.debug('\n\n' + sql + '\n\n') - self.connection.query(sql) - self.load_heading() + # # make sure that the table does not alredy exist + # cls.load_heading() + # if not cls.is_declared: + # # execute declaration + # logger.debug('\n\n' + sql + '\n\n') + # cls.connection.query(sql) + # cls.load_heading() - def _parse_declaration(self): + @classmethod + def _parse_declaration(cls): """ Parse declaration and create new SQL table accordingly. """ @@ -365,7 +440,7 @@ def _parse_declaration(self): referenced = [] index_defs = [] field_defs = [] - declaration = re.split(r'\s*\n\s*', self.definition.strip()) + declaration = re.split(r'\s*\n\s*', cls.definition.strip()) # remove comment lines declaration = [x for x in declaration if not x.startswith('#')] @@ -391,7 +466,7 @@ def _parse_declaration(self): # foreign key ref_name = line[2:].strip() ref_list = parents if in_key else referenced - ref_list.append(self.lookup_name(ref_name)) + ref_list.append(cls.lookup_name(ref_name)) elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): index_defs.append(parse_index_definition(line)) elif attribute_regexp.match(line): @@ -402,7 +477,8 @@ def _parse_declaration(self): return table_info, parents, referenced, field_defs, index_defs - def lookup_name(self, name): + @classmethod + def lookup_name(cls, name): """ Lookup the referenced name in the context dictionary @@ -411,7 +487,7 @@ def lookup_name(self, name): """ parts = name.strip().split('.') try: - ref = self.context.get(parts[0]) + ref = cls.context.get(parts[0]) for attr in parts[1:]: ref = getattr(ref, attr) except (KeyError, AttributeError): @@ -419,4 +495,21 @@ def lookup_name(self, name): 'Foreign key reference to %s could not be resolved.' 'Please make sure the name exists' 'in the context of the class' % name) - return ref \ No newline at end of file + return ref + + @classproperty + def connection(cls): + """ + Returns the connection object of the class + + :return: the connection object + """ + return cls.connection + + @classproperty + def database(cls): + return cls._shared_info.database + + @classproperty + def context(cls): + return cls._shared_info.context diff --git a/datajoint/relations.py b/datajoint/relations.py deleted file mode 100644 index 08f3f8900..000000000 --- a/datajoint/relations.py +++ /dev/null @@ -1,94 +0,0 @@ -import abc -import logging -from collections import namedtuple -import pymysql -from .connection import conn -from .abstract_relation import Relation -from . import DataJointError - - -logger = logging.getLogger(__name__) - - -SharedInfo = namedtuple( - 'SharedInfo', - ('database', 'context', 'connection', 'heading', 'parents', 'children', 'references', 'referenced')) - - -def schema(database, context, connection=None): - """ - Returns a schema decorator that can be used to associate a Relation class to a - specific database with :param name:. Name reference to other tables in the table definition - will be resolved by looking up the corresponding key entry in the passed in context dictionary. - It is most common to set context equal to the return value of call to locals() in the module. - For more details, please refer to the tutorial online. - - :param database: name of the database to associate the decorated class with - :param context: dictionary used to resolve (any) name references within the table definition string - :param connection: connection object to the database server. If ommited, will try to establish connection according to - config values - :return: a decorator function to be used on Relation derivative classes - """ - if connection is None: - connection = conn() - - # if the database does not exist, create it - cur = connection.query("SHOW DATABASES LIKE '{database}'".format(database=database)) - if cur.rowcount == 0: - logger.info("Database `{database}` could not be found. " - "Attempting to create the database.".format(database=database)) - try: - connection.query("CREATE DATABASE `{database}`".format(database=database)) - logger.info('Created database `{database}`.'.format(database=database)) - except pymysql.OperationalError: - raise DataJointError("Database named `{database}` was not defined, and" - "an attempt to create has failed. Check" - " permissions.".format(database=database)) - - def decorator(cls): - cls._shared_info = SharedInfo( - database=database, - context=context, - connection=connection, - heading=None, - parents=[], - children=[], - references=[], - referenced=[] - ) - return cls - - return decorator - - -class ClassBoundRelation(Relation): - """ - Abstract class for dedicated table classes. - Subclasses of ClassBoundRelation are dedicated interfaces to a single table. - The main purpose of ClassBoundRelation is to encapsulated sharedInfo containing the table heading - and dependency information shared by all instances of - """ - - _shared_info = None - - def __init__(self): - if self._shared_info is None: - raise DataJointError('The class must define _shared_info') - - @property - def database(self): - return self._shared_info.database - - @property - def connection(self): - return self._shared_info.connection - - @property - def context(self): - return self._shared_info.context - - @property - def heading(self): - if self._shared_info.heading is None: - self._shared_info.heading = super().heading - return self._shared_info.heading diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 7fb5aba1a..56556508d 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -1,4 +1,4 @@ -from .relations import ClassBoundRelation +from datajoint.relation import ClassBoundRelation from .autopopulate import AutoPopulate from .utils import from_camel_case diff --git a/tests/test_relation.py b/tests/test_relation.py index 3facc6721..15f52f34a 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -12,7 +12,7 @@ from datajoint import DataJointError, TransactionError, AutoPopulate, Relation import numpy as np from numpy.testing import assert_array_equal -from datajoint.abstract_relation import FreeRelation +from datajoint.relation import FreeRelation import numpy as np From 2593680d8449bf290e0672dff775b9ccaba95787 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Mon, 8 Jun 2015 17:26:37 -0500 Subject: [PATCH 09/42] * decorator declares table --- datajoint/relation.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/datajoint/relation.py b/datajoint/relation.py index ef9e43938..729c92651 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -65,11 +65,8 @@ def decorator(cls): context=context, connection=connection, heading=None, - parents=[], - children=[], - references=[], - referenced=[] ) + cls.declare() return cls return decorator From 4097adbfea786fa78d0fea5f173712c683ace01f Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Mon, 8 Jun 2015 17:32:35 -0500 Subject: [PATCH 10/42] * removed implicit commit handling --- datajoint/__init__.py | 18 ------------------ datajoint/autopopulate.py | 13 +------------ datajoint/relation.py | 9 ++++----- 3 files changed, 5 insertions(+), 35 deletions(-) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 4eb1eb6d1..528e8d74d 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -17,24 +17,6 @@ class DataJointError(Exception): pass -class TransactionError(DataJointError): - """ - Base class for errors specific to DataJoint internal operation. - """ - def __init__(self, msg, f, args=None, kwargs=None): - super(TransactionError, self).__init__(msg) - self.operations = (f, args if args is not None else tuple(), - kwargs if kwargs is not None else {}) - - def resolve(self): - f, args, kwargs = self.operations - return f(*args, **kwargs) - - @property - def culprit(self): - return self.operations[0].__name__ - - # ----------- loads local configuration from file ---------------- from .settings import Config, CONFIGVAR, LOCALCONFIG, logger, log_levels config = Config() diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index a266b62f9..57de7ae8e 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -66,18 +66,7 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, else: logger.info('Populating: ' + str(key)) try: - for attempts in range(max_attempts): - try: - self._make_tuples(dict(key)) - break - except TransactionError as tr_err: - self.conn.cancel_transaction() - tr_err.resolve() - self.conn.start_transaction() - logger.info('Transaction error in {0:s}.'.format(tr_err.culprit)) - else: - raise DataJointError( - '%s._make_tuples failed after %i attempts, giving up' % (self.__class__,max_attempts)) + self._make_tuples(dict(key)) except Exception as error: self.conn.cancel_transaction() if not suppress_errors: diff --git a/datajoint/relation.py b/datajoint/relation.py index 729c92651..448feb526 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -334,9 +334,8 @@ def _alter(cls, alter_statement): :param alter_statement: alter statement """ if cls.connection.in_transaction: - raise TransactionError( - u"_alter is currently in transaction. Operation not allowed to avoid implicit commits.", - cls._alter, args=(alter_statement,)) + raise DataJointError( + u"_alter is currently in transaction. Operation not allowed to avoid implicit commits.") sql = 'ALTER TABLE %s %s' % (cls.full_table_name, alter_statement) cls.connection.query(sql) @@ -349,8 +348,8 @@ def _declare(cls): Declares the table in the database if no table in the database matches this object. """ if cls.connection.in_transaction: - raise TransactionError( - u"_declare is currently in transaction. Operation not allowed to avoid implicit commits.", cls._declare) + raise DataJointError( + u"_declare is currently in transaction. Operation not allowed to avoid implicit commits.") if not cls.definition: # if empty definition was supplied raise DataJointError('Table definition is missing!') From b6c3f96d872f933ed5132b8b5bbe3c0ff61351e7 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Mon, 8 Jun 2015 18:33:56 -0500 Subject: [PATCH 11/42] first test runs without failing --- datajoint/__init__.py | 2 +- datajoint/autopopulate.py | 2 +- datajoint/connection.py | 2 +- datajoint/relation.py | 5 +- datajoint/user_relations.py | 10 +- tests/__init__.py | 70 ++- tests/schemata/schema1/__init__.py | 10 +- tests/schemata/schema1/test1.py | 332 +++++----- tests/schemata/schema1/test2.py | 96 +-- tests/schemata/schema1/test3.py | 42 +- tests/schemata/schema1/test4.py | 34 +- tests/schemata/schema2/__init__.py | 4 +- tests/schemata/schema2/test1.py | 32 +- tests/test_connection.py | 492 +++++++-------- tests/test_free_relation.py | 410 ++++++------ tests/test_relation.py | 980 ++++++++++++++--------------- tests/test_relational_operand.py | 94 +-- tests/test_settings.py | 134 ++-- tests/test_utils.py | 66 +- 19 files changed, 1410 insertions(+), 1407 deletions(-) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 528e8d74d..8ebe93992 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -37,8 +37,8 @@ class DataJointError(Exception): # ------------- flatten import hierarchy ------------------------- from .connection import conn, Connection -from .user_relations import Manual, Lookup, Imported, Computed from .relation import Relation +from .user_relations import Manual, Lookup, Imported, Computed from .autopopulate import AutoPopulate from . import blob from .relational_operand import Not diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 57de7ae8e..ba349b250 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -1,5 +1,5 @@ from .relational_operand import RelationalOperand -from . import DataJointError, TransactionError, Relation +from . import DataJointError, Relation import abc import logging diff --git a/datajoint/connection.py b/datajoint/connection.py index 52dae7598..3ca01b6bd 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -17,7 +17,7 @@ def conn_container(): """ _connection = None # persistent connection object used by dj.conn() - def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False): + def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False): # TODO: thin wrapping layer to mimic singleton """ Manage a persistent connection object. This is one of several ways to configure and access a datajoint connection. diff --git a/datajoint/relation.py b/datajoint/relation.py index 448feb526..19c2ebac9 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -5,7 +5,7 @@ import re import abc -from . import DataJointError, config, TransactionError +from . import DataJointError, config import pymysql from datajoint import DataJointError, conn from .relational_operand import RelationalOperand @@ -21,6 +21,7 @@ ('database', 'context', 'connection', 'heading', 'parents', 'children', 'references', 'referenced')) + class classproperty: def __init__(self, getf): self._getf = getf @@ -500,7 +501,7 @@ def connection(cls): :return: the connection object """ - return cls.connection + return cls._shared_info.connection @classproperty def database(cls): diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 56556508d..60e013e6f 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -1,30 +1,30 @@ -from datajoint.relation import ClassBoundRelation +from datajoint.relation import Relation from .autopopulate import AutoPopulate from .utils import from_camel_case -class Manual(ClassBoundRelation): +class Manual(Relation): @property @classmethod def table_name(cls): return from_camel_case(cls.__name__) -class Lookup(ClassBoundRelation): +class Lookup(Relation): @property @classmethod def table_name(cls): return '#' + from_camel_case(cls.__name__) -class Imported(ClassBoundRelation, AutoPopulate): +class Imported(Relation, AutoPopulate): @property @classmethod def table_name(cls): return "_" + from_camel_case(cls.__name__) -class Computed(ClassBoundRelation, AutoPopulate): +class Computed(Relation, AutoPopulate): @property @classmethod def table_name(cls): diff --git a/tests/__init__.py b/tests/__init__.py index 09e358e98..4d4101116 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -18,6 +18,9 @@ 'user': environ.get('DJ_TEST_USER', 'datajoint'), 'passwd': environ.get('DJ_TEST_PASSWORD', 'datajoint') } + +conn = dj.conn(**CONN_INFO) + # Prefix for all databases used during testing PREFIX = environ.get('DJ_TEST_DB_PREFIX', 'dj') # Bare connection used for verification of query results @@ -32,7 +35,6 @@ def setup(): def teardown(): cleanup() - def cleanup(): """ Removes all databases with name starting with the prefix. @@ -52,36 +54,36 @@ def cleanup(): cur.execute('DROP DATABASE `{}`'.format(db)) cur.execute('SET FOREIGN_KEY_CHECKS=1') # set foreign key check back on cur.execute("COMMIT") - -def setup_sample_db(): - """ - Helper method to setup databases with tables to be used - during the test - """ - cur = BASE_CONN.cursor() - cur.execute("CREATE DATABASE IF NOT EXISTS `{}_test1`".format(PREFIX)) - cur.execute("CREATE DATABASE IF NOT EXISTS `{}_test2`".format(PREFIX)) - query1 = """ - CREATE TABLE `{prefix}_test1`.`subjects` - ( - subject_id SMALLINT COMMENT 'Unique subject ID', - subject_name VARCHAR(255) COMMENT 'Subject name', - subject_email VARCHAR(255) COMMENT 'Subject email address', - PRIMARY KEY (subject_id) - ) - """.format(prefix=PREFIX) - cur.execute(query1) - query2 = """ - CREATE TABLE `{prefix}_test2`.`experimenter` - ( - experimenter_id SMALLINT COMMENT 'Unique experimenter ID', - experimenter_name VARCHAR(255) COMMENT 'Experimenter name', - PRIMARY KEY (experimenter_id) - )""".format(prefix=PREFIX) - cur.execute(query2) - - - - - - +# +# def setup_sample_db(): +# """ +# Helper method to setup databases with tables to be used +# during the test +# """ +# cur = BASE_CONN.cursor() +# cur.execute("CREATE DATABASE IF NOT EXISTS `{}_test1`".format(PREFIX)) +# cur.execute("CREATE DATABASE IF NOT EXISTS `{}_test2`".format(PREFIX)) +# query1 = """ +# CREATE TABLE `{prefix}_test1`.`subjects` +# ( +# subject_id SMALLINT COMMENT 'Unique subject ID', +# subject_name VARCHAR(255) COMMENT 'Subject name', +# subject_email VARCHAR(255) COMMENT 'Subject email address', +# PRIMARY KEY (subject_id) +# ) +# """.format(prefix=PREFIX) +# cur.execute(query1) +# query2 = """ +# CREATE TABLE `{prefix}_test2`.`experimenter` +# ( +# experimenter_id SMALLINT COMMENT 'Unique experimenter ID', +# experimenter_name VARCHAR(255) COMMENT 'Experimenter name', +# PRIMARY KEY (experimenter_id) +# )""".format(prefix=PREFIX) +# cur.execute(query2) +# +# +# +# +# +# diff --git a/tests/schemata/schema1/__init__.py b/tests/schemata/schema1/__init__.py index 6032e7bd6..cae90cec9 100644 --- a/tests/schemata/schema1/__init__.py +++ b/tests/schemata/schema1/__init__.py @@ -1,5 +1,5 @@ -__author__ = 'eywalker' -import datajoint as dj - -print(__name__) -from .test3 import * \ No newline at end of file +# __author__ = 'eywalker' +# import datajoint as dj +# +# print(__name__) +# from .test3 import * \ No newline at end of file diff --git a/tests/schemata/schema1/test1.py b/tests/schemata/schema1/test1.py index 4c8df082f..d0d91707f 100644 --- a/tests/schemata/schema1/test1.py +++ b/tests/schemata/schema1/test1.py @@ -1,166 +1,166 @@ -""" -Test 1 Schema definition -""" -__author__ = 'eywalker' - -import datajoint as dj -from .. import schema2 - - -class Subjects(dj.Relation): - definition = """ - test1.Subjects (manual) # Basic subject info - - subject_id : int # unique subject id - --- - real_id : varchar(40) # real-world name - species = "mouse" : enum('mouse', 'monkey', 'human') # species - """ - -# test for shorthand -class Animals(dj.Relation): - definition = """ - test1.Animals (manual) # Listing of all info - - -> Subjects - --- - animal_dob :date # date of birth - """ - - -class Trials(dj.Relation): - definition = """ - test1.Trials (manual) # info about trials - - -> test1.Subjects - trial_id : int - --- - outcome : int # result of experiment - - notes="" : varchar(4096) # other comments - trial_ts=CURRENT_TIMESTAMP : timestamp # automatic - """ - - - -class SquaredScore(dj.Relation, dj.AutoPopulate): - definition = """ - test1.SquaredScore (computed) # cumulative outcome of trials - - -> test1.Subjects - -> test1.Trials - --- - squared : int # squared result of Trials outcome - """ - - @property - def populate_relation(self): - return Subjects() * Trials() - - def _make_tuples(self, key): - tmp = (Trials() & key).fetch1() - tmp2 = SquaredSubtable() & key - - self.insert(dict(key, squared=tmp['outcome']**2)) - - ss = SquaredSubtable() - - for i in range(10): - key['dummy'] = i - ss.insert(key) - - -class WrongImplementation(dj.Relation, dj.AutoPopulate): - definition = """ - test1.WrongImplementation (computed) # ignore - - -> test1.Subjects - -> test1.Trials - --- - dummy : int # ignore - """ - - @property - def populate_relation(self): - return {'subject_id':2} - - def _make_tuples(self, key): - pass - - - -class ErrorGenerator(dj.Relation, dj.AutoPopulate): - definition = """ - test1.ErrorGenerator (computed) # ignore - - -> test1.Subjects - -> test1.Trials - --- - dummy : int # ignore - """ - - @property - def populate_relation(self): - return Subjects() * Trials() - - def _make_tuples(self, key): - raise Exception("This is for testing") - - - - - - -class SquaredSubtable(dj.Relation): - definition = """ - test1.SquaredSubtable (computed) # cumulative outcome of trials - - -> test1.SquaredScore - dummy : int # dummy primary attribute - --- - """ - - -# test reference to another table in same schema -class Experiments(dj.Relation): - definition = """ - test1.Experiments (imported) # Experiment info - -> test1.Subjects - exp_id : int # unique id for experiment - --- - exp_data_file : varchar(255) # data file - """ - - -# refers to a table in dj_test2 (bound to test2) but without a class -class Sessions(dj.Relation): - definition = """ - test1.Sessions (manual) # Experiment sessions - -> test1.Subjects - -> test2.Experimenter - session_id : int # unique session id - --- - session_comment : varchar(255) # comment about the session - """ - - -class Match(dj.Relation): - definition = """ - test1.Match (manual) # Match between subject and color - -> schema2.Subjects - --- - dob : date # date of birth - """ - - -# this tries to reference a table in database directly without ORM -class TrainingSession(dj.Relation): - definition = """ - test1.TrainingSession (manual) # training sessions - -> `dj_test2`.Experimenter - session_id : int # training session id - """ - - -class Empty(dj.Relation): - pass +# """ +# Test 1 Schema definition +# """ +# __author__ = 'eywalker' +# +# import datajoint as dj +# from .. import schema2 +# +# +# class Subjects(dj.Relation): +# definition = """ +# test1.Subjects (manual) # Basic subject info +# +# subject_id : int # unique subject id +# --- +# real_id : varchar(40) # real-world name +# species = "mouse" : enum('mouse', 'monkey', 'human') # species +# """ +# +# # test for shorthand +# class Animals(dj.Relation): +# definition = """ +# test1.Animals (manual) # Listing of all info +# +# -> Subjects +# --- +# animal_dob :date # date of birth +# """ +# +# +# class Trials(dj.Relation): +# definition = """ +# test1.Trials (manual) # info about trials +# +# -> test1.Subjects +# trial_id : int +# --- +# outcome : int # result of experiment +# +# notes="" : varchar(4096) # other comments +# trial_ts=CURRENT_TIMESTAMP : timestamp # automatic +# """ +# +# +# +# class SquaredScore(dj.Relation, dj.AutoPopulate): +# definition = """ +# test1.SquaredScore (computed) # cumulative outcome of trials +# +# -> test1.Subjects +# -> test1.Trials +# --- +# squared : int # squared result of Trials outcome +# """ +# +# @property +# def populate_relation(self): +# return Subjects() * Trials() +# +# def _make_tuples(self, key): +# tmp = (Trials() & key).fetch1() +# tmp2 = SquaredSubtable() & key +# +# self.insert(dict(key, squared=tmp['outcome']**2)) +# +# ss = SquaredSubtable() +# +# for i in range(10): +# key['dummy'] = i +# ss.insert(key) +# +# +# class WrongImplementation(dj.Relation, dj.AutoPopulate): +# definition = """ +# test1.WrongImplementation (computed) # ignore +# +# -> test1.Subjects +# -> test1.Trials +# --- +# dummy : int # ignore +# """ +# +# @property +# def populate_relation(self): +# return {'subject_id':2} +# +# def _make_tuples(self, key): +# pass +# +# +# +# class ErrorGenerator(dj.Relation, dj.AutoPopulate): +# definition = """ +# test1.ErrorGenerator (computed) # ignore +# +# -> test1.Subjects +# -> test1.Trials +# --- +# dummy : int # ignore +# """ +# +# @property +# def populate_relation(self): +# return Subjects() * Trials() +# +# def _make_tuples(self, key): +# raise Exception("This is for testing") +# +# +# +# +# +# +# class SquaredSubtable(dj.Relation): +# definition = """ +# test1.SquaredSubtable (computed) # cumulative outcome of trials +# +# -> test1.SquaredScore +# dummy : int # dummy primary attribute +# --- +# """ +# +# +# # test reference to another table in same schema +# class Experiments(dj.Relation): +# definition = """ +# test1.Experiments (imported) # Experiment info +# -> test1.Subjects +# exp_id : int # unique id for experiment +# --- +# exp_data_file : varchar(255) # data file +# """ +# +# +# # refers to a table in dj_test2 (bound to test2) but without a class +# class Sessions(dj.Relation): +# definition = """ +# test1.Sessions (manual) # Experiment sessions +# -> test1.Subjects +# -> test2.Experimenter +# session_id : int # unique session id +# --- +# session_comment : varchar(255) # comment about the session +# """ +# +# +# class Match(dj.Relation): +# definition = """ +# test1.Match (manual) # Match between subject and color +# -> schema2.Subjects +# --- +# dob : date # date of birth +# """ +# +# +# # this tries to reference a table in database directly without ORM +# class TrainingSession(dj.Relation): +# definition = """ +# test1.TrainingSession (manual) # training sessions +# -> `dj_test2`.Experimenter +# session_id : int # training session id +# """ +# +# +# class Empty(dj.Relation): +# pass diff --git a/tests/schemata/schema1/test2.py b/tests/schemata/schema1/test2.py index 563fe6b52..aded2e4fb 100644 --- a/tests/schemata/schema1/test2.py +++ b/tests/schemata/schema1/test2.py @@ -1,48 +1,48 @@ -""" -Test 2 Schema definition -""" -__author__ = 'eywalker' - -import datajoint as dj -from . import test1 as alias -#from ..schema2 import test2 as test1 - - -# references to another schema -class Experiments(dj.Relation): - definition = """ - test2.Experiments (manual) # Basic subject info - -> test1.Subjects - experiment_id : int # unique experiment id - --- - real_id : varchar(40) # real-world name - species = "mouse" : enum('mouse', 'monkey', 'human') # species - """ - - -# references to another schema -class Conditions(dj.Relation): - definition = """ - test2.Conditions (manual) # Subject conditions - -> alias.Subjects - condition_name : varchar(255) # description of the condition - """ - - -class FoodPreference(dj.Relation): - definition = """ - test2.FoodPreference (manual) # Food preference of each subject - -> animals.Subjects - preferred_food : enum('banana', 'apple', 'oranges') - """ - - -class Session(dj.Relation): - definition = """ - test2.Session (manual) # Experiment sessions - -> test1.Subjects - -> test2.Experimenter - session_id : int # unique session id - --- - session_comment : varchar(255) # comment about the session - """ \ No newline at end of file +# """ +# Test 2 Schema definition +# """ +# __author__ = 'eywalker' +# +# import datajoint as dj +# from . import test1 as alias +# #from ..schema2 import test2 as test1 +# +# +# # references to another schema +# class Experiments(dj.Relation): +# definition = """ +# test2.Experiments (manual) # Basic subject info +# -> test1.Subjects +# experiment_id : int # unique experiment id +# --- +# real_id : varchar(40) # real-world name +# species = "mouse" : enum('mouse', 'monkey', 'human') # species +# """ +# +# +# # references to another schema +# class Conditions(dj.Relation): +# definition = """ +# test2.Conditions (manual) # Subject conditions +# -> alias.Subjects +# condition_name : varchar(255) # description of the condition +# """ +# +# +# class FoodPreference(dj.Relation): +# definition = """ +# test2.FoodPreference (manual) # Food preference of each subject +# -> animals.Subjects +# preferred_food : enum('banana', 'apple', 'oranges') +# """ +# +# +# class Session(dj.Relation): +# definition = """ +# test2.Session (manual) # Experiment sessions +# -> test1.Subjects +# -> test2.Experimenter +# session_id : int # unique session id +# --- +# session_comment : varchar(255) # comment about the session +# """ \ No newline at end of file diff --git a/tests/schemata/schema1/test3.py b/tests/schemata/schema1/test3.py index 2004a8736..e00a01afb 100644 --- a/tests/schemata/schema1/test3.py +++ b/tests/schemata/schema1/test3.py @@ -1,21 +1,21 @@ -""" -Test 3 Schema definition - no binding, no conn - -To be bound at the package level -""" -__author__ = 'eywalker' - -import datajoint as dj - - -class Subjects(dj.Relation): - definition = """ - schema1.Subjects (manual) # Basic subject info - - subject_id : int # unique subject id - dob : date # date of birth - --- - real_id : varchar(40) # real-world name - species = "mouse" : enum('mouse', 'monkey', 'human') # species - """ - +# """ +# Test 3 Schema definition - no binding, no conn +# +# To be bound at the package level +# """ +# __author__ = 'eywalker' +# +# import datajoint as dj +# +# +# class Subjects(dj.Relation): +# definition = """ +# schema1.Subjects (manual) # Basic subject info +# +# subject_id : int # unique subject id +# dob : date # date of birth +# --- +# real_id : varchar(40) # real-world name +# species = "mouse" : enum('mouse', 'monkey', 'human') # species +# """ +# diff --git a/tests/schemata/schema1/test4.py b/tests/schemata/schema1/test4.py index a2004affd..9860cb030 100644 --- a/tests/schemata/schema1/test4.py +++ b/tests/schemata/schema1/test4.py @@ -1,17 +1,17 @@ -""" -Test 1 Schema definition - fully bound and has connection object -""" -__author__ = 'fabee' - -import datajoint as dj - - -class Matrix(dj.Relation): - definition = """ - test4.Matrix (manual) # Some numpy array - - matrix_id : int # unique matrix id - --- - data : longblob # data - comment : varchar(1000) # comment - """ +# """ +# Test 1 Schema definition - fully bound and has connection object +# """ +# __author__ = 'fabee' +# +# import datajoint as dj +# +# +# class Matrix(dj.Relation): +# definition = """ +# test4.Matrix (manual) # Some numpy array +# +# matrix_id : int # unique matrix id +# --- +# data : longblob # data +# comment : varchar(1000) # comment +# """ diff --git a/tests/schemata/schema2/__init__.py b/tests/schemata/schema2/__init__.py index e6b482590..d79e02cc3 100644 --- a/tests/schemata/schema2/__init__.py +++ b/tests/schemata/schema2/__init__.py @@ -1,2 +1,2 @@ -__author__ = 'eywalker' -from .test1 import * \ No newline at end of file +# __author__ = 'eywalker' +# from .test1 import * \ No newline at end of file diff --git a/tests/schemata/schema2/test1.py b/tests/schemata/schema2/test1.py index 83bb3a19e..4005aa670 100644 --- a/tests/schemata/schema2/test1.py +++ b/tests/schemata/schema2/test1.py @@ -1,16 +1,16 @@ -""" -Test 2 Schema definition -""" -__author__ = 'eywalker' - -import datajoint as dj - - -class Subjects(dj.Relation): - definition = """ - schema2.Subjects (manual) # Basic subject info - pop_id : int # unique experiment id - --- - real_id : varchar(40) # real-world name - species = "mouse" : enum('mouse', 'monkey', 'human') # species - """ \ No newline at end of file +# """ +# Test 2 Schema definition +# """ +# __author__ = 'eywalker' +# +# import datajoint as dj +# +# +# class Subjects(dj.Relation): +# definition = """ +# schema2.Subjects (manual) # Basic subject info +# pop_id : int # unique experiment id +# --- +# real_id : varchar(40) # real-world name +# species = "mouse" : enum('mouse', 'monkey', 'human') # species +# """ \ No newline at end of file diff --git a/tests/test_connection.py b/tests/test_connection.py index 1fb581468..29fee4f64 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,257 +1,257 @@ -""" -Collection of test cases to test connection module. -""" -from .schemata import schema1 -from .schemata.schema1 import test1 -import numpy as np - -__author__ = 'eywalker, fabee' -from . import (CONN_INFO, PREFIX, BASE_CONN, cleanup) -from nose.tools import assert_true, assert_raises, assert_equal, raises -import datajoint as dj -from datajoint.utils import DataJointError - - -def setup(): - cleanup() - - -def test_dj_conn(): - """ - Should be able to establish a connection - """ - c = dj.conn(**CONN_INFO) - assert c.is_connected - - -def test_persistent_dj_conn(): - """ - conn() method should provide persistent connection - across calls. - """ - c1 = dj.conn(**CONN_INFO) - c2 = dj.conn() - assert_true(c1 is c2) - - -def test_dj_conn_reset(): - """ - Passing in reset=True should allow for new persistent - connection to be created. - """ - c1 = dj.conn(**CONN_INFO) - c2 = dj.conn(reset=True, **CONN_INFO) - assert_true(c1 is not c2) - - - -def setup_sample_db(): - """ - Helper method to setup databases with tables to be used - during the test - """ - cur = BASE_CONN.cursor() - cur.execute("CREATE DATABASE `{}_test1`".format(PREFIX)) - cur.execute("CREATE DATABASE `{}_test2`".format(PREFIX)) - query1 = """ - CREATE TABLE `{prefix}_test1`.`subjects` - ( - subject_id SMALLINT COMMENT 'Unique subject ID', - subject_name VARCHAR(255) COMMENT 'Subject name', - subject_email VARCHAR(255) COMMENT 'Subject email address', - PRIMARY KEY (subject_id) - ) - """.format(prefix=PREFIX) - cur.execute(query1) - # query2 = """ - # CREATE TABLE `{prefix}_test2`.`experiments` - # ( - # experiment_id SMALLINT COMMENT 'Unique experiment ID', - # experiment_name VARCHAR(255) COMMENT 'Experiment name', - # subject_id SMALLINT, - # CONSTRAINT FOREIGN KEY (`subject_id`) REFERENCES `dj_test1`.`subjects` (`subject_id`) ON UPDATE CASCADE ON DELETE RESTRICT, - # PRIMARY KEY (subject_id, experiment_id) - # )""".format(prefix=PREFIX) - # cur.execute(query2) - - -class TestConnectionWithoutBindings(object): - """ - Test methods from Connection that does not - depend on presence of module to database bindings. - This includes tests for `bind` method itself. - """ - def setup(self): - self.conn = dj.Connection(**CONN_INFO) - test1.__dict__.pop('conn', None) - schema1.__dict__.pop('conn', None) - setup_sample_db() - - def teardown(self): - cleanup() - - def check_binding(self, db_name, module): - """ - Helper method to check if the specified database-module pairing exists - """ - assert_equal(self.conn.db_to_mod[db_name], module) - assert_equal(self.conn.mod_to_db[module], db_name) - - def test_bind_to_existing_database(self): - """ - Should be able to bind a module to an existing database - """ - db_name = PREFIX + '_test1' - module = test1.__name__ - self.conn.bind(module, db_name) - self.check_binding(db_name, module) - - def test_bind_at_package_level(self): - db_name = PREFIX + '_test1' - package = schema1.__name__ - self.conn.bind(package, db_name) - self.check_binding(db_name, package) - - def test_bind_to_non_existing_database(self): - """ - Should be able to bind a module to a non-existing database by creating target - """ - db_name = PREFIX + '_test3' - module = test1.__name__ - cur = BASE_CONN.cursor() - - # Ensure target database doesn't exist - if cur.execute("SHOW DATABASES LIKE '{}'".format(db_name)): - cur.execute("DROP DATABASE IF EXISTS `{}`".format(db_name)) - # Bind module to non-existing database - self.conn.bind(module, db_name) - # Check that target database was created - assert_equal(cur.execute("SHOW DATABASES LIKE '{}'".format(db_name)), 1) - self.check_binding(db_name, module) - # Remove the target database - cur.execute("DROP DATABASE IF EXISTS `{}`".format(db_name)) - - def test_cannot_bind_to_multiple_databases(self): - """ - Bind will fail when db_name is a pattern that - matches multiple databases - """ - db_name = PREFIX + "_test%%" - module = test1.__name__ - with assert_raises(DataJointError): - self.conn.bind(module, db_name) - - def test_basic_sql_query(self): - """ - Test execution of basic SQL query using connection - object. - """ - cur = self.conn.query('SHOW DATABASES') - results1 = cur.fetchall() - cur2 = BASE_CONN.cursor() - cur2.execute('SHOW DATABASES') - results2 = cur2.fetchall() - assert_equal(results1, results2) - - def test_transaction_commit(self): - """ - Test transaction commit - """ - table_name = PREFIX + '_test1.subjects' - self.conn.start_transaction() - self.conn.query("INSERT INTO {table} VALUES (0, 'dj_user', 'dj_user@example.com')".format(table=table_name)) - cur = BASE_CONN.cursor() - assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 0) - self.conn.commit_transaction() - assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 1) - - def test_transaction_rollback(self): - """ - Test transaction rollback - """ - table_name = PREFIX + '_test1.subjects' - self.conn.start_transaction() - self.conn.query("INSERT INTO {table} VALUES (0, 'dj_user', 'dj_user@example.com')".format(table=table_name)) - cur = BASE_CONN.cursor() - assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 0) - self.conn.cancel_transaction() - assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 0) - -# class TestContextManager(object): -# def __init__(self): -# self.relvar = None -# self.setup() +# """ +# Collection of test cases to test connection module. +# """ +# from .schemata import schema1 +# from .schemata.schema1 import test1 +# import numpy as np # +# __author__ = 'eywalker, fabee' +# from . import (CONN_INFO, PREFIX, BASE_CONN, cleanup) +# from nose.tools import assert_true, assert_raises, assert_equal, raises +# import datajoint as dj +# from datajoint.utils import DataJointError +# +# +# def setup(): +# cleanup() +# +# +# def test_dj_conn(): # """ -# Test cases for FreeRelation objects +# Should be able to establish a connection # """ +# c = dj.conn(**CONN_INFO) +# assert c.is_connected # -# def setup(self): -# """ -# Create a connection object and prepare test modules -# as follows: -# test1 - has conn and bounded -# """ -# cleanup() # drop all databases with PREFIX -# test1.__dict__.pop('conn', None) # -# self.conn = dj.Connection(**CONN_INFO) -# test1.conn = self.conn -# self.conn.bind(test1.__name__, PREFIX + '_test1') -# self.relvar = test1.Subjects() +# def test_persistent_dj_conn(): +# """ +# conn() method should provide persistent connection +# across calls. +# """ +# c1 = dj.conn(**CONN_INFO) +# c2 = dj.conn() +# assert_true(c1 is c2) # -# def teardown(self): -# cleanup() # -# # def test_active(self): -# # with self.conn.transaction() as tr: -# # assert_true(tr.is_active, "Transaction is not active") -# -# # def test_rollback(self): -# # -# # tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], -# # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) -# # -# # self.relvar.insert(tmp[0]) -# # try: -# # with self.conn.transaction(): -# # self.relvar.insert(tmp[1]) -# # raise DataJointError("Just to test") -# # except DataJointError as e: -# # pass -# # -# # testt2 = (self.relvar & 'subject_id = 2').fetch() -# # assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") -# -# # def test_cancel(self): -# # """Tests cancelling a transaction""" -# # tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], -# # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) -# # -# # self.relvar.insert(tmp[0]) -# # with self.conn.transaction() as transaction: -# # self.relvar.insert(tmp[1]) -# # transaction.cancel() -# # -# # testt2 = (self.relvar & 'subject_id = 2').fetch() -# # assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") - - - -# class TestConnectionWithBindings(object): +# def test_dj_conn_reset(): +# """ +# Passing in reset=True should allow for new persistent +# connection to be created. +# """ +# c1 = dj.conn(**CONN_INFO) +# c2 = dj.conn(reset=True, **CONN_INFO) +# assert_true(c1 is not c2) +# +# +# +# def setup_sample_db(): +# """ +# Helper method to setup databases with tables to be used +# during the test +# """ +# cur = BASE_CONN.cursor() +# cur.execute("CREATE DATABASE `{}_test1`".format(PREFIX)) +# cur.execute("CREATE DATABASE `{}_test2`".format(PREFIX)) +# query1 = """ +# CREATE TABLE `{prefix}_test1`.`subjects` +# ( +# subject_id SMALLINT COMMENT 'Unique subject ID', +# subject_name VARCHAR(255) COMMENT 'Subject name', +# subject_email VARCHAR(255) COMMENT 'Subject email address', +# PRIMARY KEY (subject_id) +# ) +# """.format(prefix=PREFIX) +# cur.execute(query1) +# # query2 = """ +# # CREATE TABLE `{prefix}_test2`.`experiments` +# # ( +# # experiment_id SMALLINT COMMENT 'Unique experiment ID', +# # experiment_name VARCHAR(255) COMMENT 'Experiment name', +# # subject_id SMALLINT, +# # CONSTRAINT FOREIGN KEY (`subject_id`) REFERENCES `dj_test1`.`subjects` (`subject_id`) ON UPDATE CASCADE ON DELETE RESTRICT, +# # PRIMARY KEY (subject_id, experiment_id) +# # )""".format(prefix=PREFIX) +# # cur.execute(query2) +# +# +# class TestConnectionWithoutBindings(object): # """ -# Tests heading and dependency loadings +# Test methods from Connection that does not +# depend on presence of module to database bindings. +# This includes tests for `bind` method itself. # """ # def setup(self): # self.conn = dj.Connection(**CONN_INFO) -# cur.execute(query) - - - - - - - - - - +# test1.__dict__.pop('conn', None) +# schema1.__dict__.pop('conn', None) +# setup_sample_db() +# +# def teardown(self): +# cleanup() +# +# def check_binding(self, db_name, module): +# """ +# Helper method to check if the specified database-module pairing exists +# """ +# assert_equal(self.conn.db_to_mod[db_name], module) +# assert_equal(self.conn.mod_to_db[module], db_name) +# +# def test_bind_to_existing_database(self): +# """ +# Should be able to bind a module to an existing database +# """ +# db_name = PREFIX + '_test1' +# module = test1.__name__ +# self.conn.bind(module, db_name) +# self.check_binding(db_name, module) +# +# def test_bind_at_package_level(self): +# db_name = PREFIX + '_test1' +# package = schema1.__name__ +# self.conn.bind(package, db_name) +# self.check_binding(db_name, package) +# +# def test_bind_to_non_existing_database(self): +# """ +# Should be able to bind a module to a non-existing database by creating target +# """ +# db_name = PREFIX + '_test3' +# module = test1.__name__ +# cur = BASE_CONN.cursor() +# +# # Ensure target database doesn't exist +# if cur.execute("SHOW DATABASES LIKE '{}'".format(db_name)): +# cur.execute("DROP DATABASE IF EXISTS `{}`".format(db_name)) +# # Bind module to non-existing database +# self.conn.bind(module, db_name) +# # Check that target database was created +# assert_equal(cur.execute("SHOW DATABASES LIKE '{}'".format(db_name)), 1) +# self.check_binding(db_name, module) +# # Remove the target database +# cur.execute("DROP DATABASE IF EXISTS `{}`".format(db_name)) +# +# def test_cannot_bind_to_multiple_databases(self): +# """ +# Bind will fail when db_name is a pattern that +# matches multiple databases +# """ +# db_name = PREFIX + "_test%%" +# module = test1.__name__ +# with assert_raises(DataJointError): +# self.conn.bind(module, db_name) +# +# def test_basic_sql_query(self): +# """ +# Test execution of basic SQL query using connection +# object. +# """ +# cur = self.conn.query('SHOW DATABASES') +# results1 = cur.fetchall() +# cur2 = BASE_CONN.cursor() +# cur2.execute('SHOW DATABASES') +# results2 = cur2.fetchall() +# assert_equal(results1, results2) +# +# def test_transaction_commit(self): +# """ +# Test transaction commit +# """ +# table_name = PREFIX + '_test1.subjects' +# self.conn.start_transaction() +# self.conn.query("INSERT INTO {table} VALUES (0, 'dj_user', 'dj_user@example.com')".format(table=table_name)) +# cur = BASE_CONN.cursor() +# assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 0) +# self.conn.commit_transaction() +# assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 1) +# +# def test_transaction_rollback(self): +# """ +# Test transaction rollback +# """ +# table_name = PREFIX + '_test1.subjects' +# self.conn.start_transaction() +# self.conn.query("INSERT INTO {table} VALUES (0, 'dj_user', 'dj_user@example.com')".format(table=table_name)) +# cur = BASE_CONN.cursor() +# assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 0) +# self.conn.cancel_transaction() +# assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 0) +# +# # class TestContextManager(object): +# # def __init__(self): +# # self.relvar = None +# # self.setup() +# # +# # """ +# # Test cases for FreeRelation objects +# # """ +# # +# # def setup(self): +# # """ +# # Create a connection object and prepare test modules +# # as follows: +# # test1 - has conn and bounded +# # """ +# # cleanup() # drop all databases with PREFIX +# # test1.__dict__.pop('conn', None) +# # +# # self.conn = dj.Connection(**CONN_INFO) +# # test1.conn = self.conn +# # self.conn.bind(test1.__name__, PREFIX + '_test1') +# # self.relvar = test1.Subjects() +# # +# # def teardown(self): +# # cleanup() +# # +# # # def test_active(self): +# # # with self.conn.transaction() as tr: +# # # assert_true(tr.is_active, "Transaction is not active") +# # +# # # def test_rollback(self): +# # # +# # # tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], +# # # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) +# # # +# # # self.relvar.insert(tmp[0]) +# # # try: +# # # with self.conn.transaction(): +# # # self.relvar.insert(tmp[1]) +# # # raise DataJointError("Just to test") +# # # except DataJointError as e: +# # # pass +# # # +# # # testt2 = (self.relvar & 'subject_id = 2').fetch() +# # # assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") +# # +# # # def test_cancel(self): +# # # """Tests cancelling a transaction""" +# # # tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], +# # # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) +# # # +# # # self.relvar.insert(tmp[0]) +# # # with self.conn.transaction() as transaction: +# # # self.relvar.insert(tmp[1]) +# # # transaction.cancel() +# # # +# # # testt2 = (self.relvar & 'subject_id = 2').fetch() +# # # assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") +# +# +# +# # class TestConnectionWithBindings(object): +# # """ +# # Tests heading and dependency loadings +# # """ +# # def setup(self): +# # self.conn = dj.Connection(**CONN_INFO) +# # cur.execute(query) +# +# +# +# +# +# +# +# +# +# diff --git a/tests/test_free_relation.py b/tests/test_free_relation.py index e4ebaa872..285bf7487 100644 --- a/tests/test_free_relation.py +++ b/tests/test_free_relation.py @@ -1,205 +1,205 @@ -""" -Collection of test cases for base module. Tests functionalities such as -creating tables using docstring table declarations -""" -from .schemata import schema1, schema2 -from .schemata.schema1 import test1, test2, test3 - - -__author__ = 'eywalker' - -from . import BASE_CONN, CONN_INFO, PREFIX, cleanup, setup_sample_db -from datajoint.connection import Connection -from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, raises -from datajoint import DataJointError - - -def setup(): - """ - Setup connections and bindings - """ - pass - - -class TestRelationInstantiations(object): - """ - Test cases for instantiating Relation objects - """ - def __init__(self): - self.conn = None - - def setup(self): - """ - Create a connection object and prepare test modules - as follows: - test1 - has conn and bounded - """ - self.conn = Connection(**CONN_INFO) - cleanup() # drop all databases with PREFIX - #test1.conn = self.conn - #self.conn.bind(test1.__name__, PREFIX+'_test1') - - #test2.conn = self.conn - - #test3.__dict__.pop('conn', None) # make sure conn is not defined in test3 - test1.__dict__.pop('conn', None) - schema1.__dict__.pop('conn', None) # make sure conn is not defined at schema level - - - def teardown(self): - cleanup() - - - def test_instantiation_from_unbound_module_should_fail(self): - """ - Attempting to instantiate a Relation derivative from a module with - connection defined but not bound to a database should raise error - """ - test1.conn = self.conn - with assert_raises(DataJointError) as e: - test1.Subjects() - assert_regexp_matches(e.exception.args[0], r".*not bound.*") - - def test_instantiation_from_module_without_conn_should_fail(self): - """ - Attempting to instantiate a Relation derivative from a module that lacks - `conn` object should raise error - """ - with assert_raises(DataJointError) as e: - test1.Subjects() - assert_regexp_matches(e.exception.args[0], r".*define.*conn.*") - - def test_instantiation_of_base_derivatives(self): - """ - Test instantiation and initialization of objects derived from - Relation class - """ - test1.conn = self.conn - self.conn.bind(test1.__name__, PREFIX + '_test1') - s = test1.Subjects() - assert_equal(s.dbname, PREFIX + '_test1') - assert_equal(s.conn, self.conn) - assert_equal(s.definition, test1.Subjects.definition) - - def test_packagelevel_binding(self): - schema2.conn = self.conn - self.conn.bind(schema2.__name__, PREFIX + '_test1') - s = schema2.test1.Subjects() - - -class TestRelationDeclaration(object): - """ - Test declaration (creation of table) from - definition in Relation under various circumstances - """ - - def setup(self): - cleanup() - - self.conn = Connection(**CONN_INFO) - test1.conn = self.conn - self.conn.bind(test1.__name__, PREFIX + '_test1') - test2.conn = self.conn - self.conn.bind(test2.__name__, PREFIX + '_test2') - - def test_is_declared(self): - """ - The table should not be created immediately after instantiation, - but should be created when declare method is called - :return: - """ - s = test1.Subjects() - assert_false(s.is_declared) - s.declare() - assert_true(s.is_declared) - - def test_calling_heading_should_trigger_declaration(self): - s = test1.Subjects() - assert_false(s.is_declared) - a = s.heading - assert_true(s.is_declared) - - def test_foreign_key_ref_in_same_schema(self): - s = test1.Experiments() - assert_true('subject_id' in s.heading.primary_key) - - def test_foreign_key_ref_in_another_schema(self): - s = test2.Experiments() - assert_true('subject_id' in s.heading.primary_key) - - def test_aliased_module_name_should_resolve(self): - """ - Module names that were aliased in the definition should - be properly resolved. - """ - s = test2.Conditions() - assert_true('subject_id' in s.heading.primary_key) - - def test_reference_to_unknown_module_in_definition_should_fail(self): - """ - Module names in table definition that is not aliased via import - results in error - """ - s = test2.FoodPreference() - with assert_raises(DataJointError) as e: - s.declare() - - -class TestRelationWithExistingTables(object): - """ - Test base derivatives behaviors when some of the tables - already exists in the database - """ - def setup(self): - cleanup() - self.conn = Connection(**CONN_INFO) - setup_sample_db() - test1.conn = self.conn - self.conn.bind(test1.__name__, PREFIX + '_test1') - test2.conn = self.conn - self.conn.bind(test2.__name__, PREFIX + '_test2') - self.conn.load_headings(force=True) - - schema2.conn = self.conn - self.conn.bind(schema2.__name__, PREFIX + '_package') - - def teardown(selfself): - schema1.__dict__.pop('conn', None) - cleanup() - - def test_detection_of_existing_table(self): - """ - The Relation instance should be able to detect if the - corresponding table already exists in the database - """ - s = test1.Subjects() - assert_true(s.is_declared) - - def test_definition_referring_to_existing_table_without_class(self): - s1 = test1.Sessions() - assert_true('experimenter_id' in s1.primary_key) - - s2 = test2.Session() - assert_true('experimenter_id' in s2.primary_key) - - def test_reference_to_package_level_table(self): - s = test1.Match() - s.declare() - assert_true('pop_id' in s.primary_key) - - def test_direct_reference_to_existing_table_should_fail(self): - """ - When deriving from Relation, definition should not contain direct reference - to a database name - """ - s = test1.TrainingSession() - with assert_raises(DataJointError): - s.declare() - -@raises(TypeError) -def test_instantiation_of_base_derivative_without_definition_should_fail(): - test1.Empty() - - - - +# """ +# Collection of test cases for base module. Tests functionalities such as +# creating tables using docstring table declarations +# """ +# from .schemata import schema1, schema2 +# from .schemata.schema1 import test1, test2, test3 +# +# +# __author__ = 'eywalker' +# +# from . import BASE_CONN, CONN_INFO, PREFIX, cleanup, setup_sample_db +# from datajoint.connection import Connection +# from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, raises +# from datajoint import DataJointError +# +# +# def setup(): +# """ +# Setup connections and bindings +# """ +# pass +# +# +# class TestRelationInstantiations(object): +# """ +# Test cases for instantiating Relation objects +# """ +# def __init__(self): +# self.conn = None +# +# def setup(self): +# """ +# Create a connection object and prepare test modules +# as follows: +# test1 - has conn and bounded +# """ +# self.conn = Connection(**CONN_INFO) +# cleanup() # drop all databases with PREFIX +# #test1.conn = self.conn +# #self.conn.bind(test1.__name__, PREFIX+'_test1') +# +# #test2.conn = self.conn +# +# #test3.__dict__.pop('conn', None) # make sure conn is not defined in test3 +# test1.__dict__.pop('conn', None) +# schema1.__dict__.pop('conn', None) # make sure conn is not defined at schema level +# +# +# def teardown(self): +# cleanup() +# +# +# def test_instantiation_from_unbound_module_should_fail(self): +# """ +# Attempting to instantiate a Relation derivative from a module with +# connection defined but not bound to a database should raise error +# """ +# test1.conn = self.conn +# with assert_raises(DataJointError) as e: +# test1.Subjects() +# assert_regexp_matches(e.exception.args[0], r".*not bound.*") +# +# def test_instantiation_from_module_without_conn_should_fail(self): +# """ +# Attempting to instantiate a Relation derivative from a module that lacks +# `conn` object should raise error +# """ +# with assert_raises(DataJointError) as e: +# test1.Subjects() +# assert_regexp_matches(e.exception.args[0], r".*define.*conn.*") +# +# def test_instantiation_of_base_derivatives(self): +# """ +# Test instantiation and initialization of objects derived from +# Relation class +# """ +# test1.conn = self.conn +# self.conn.bind(test1.__name__, PREFIX + '_test1') +# s = test1.Subjects() +# assert_equal(s.dbname, PREFIX + '_test1') +# assert_equal(s.conn, self.conn) +# assert_equal(s.definition, test1.Subjects.definition) +# +# def test_packagelevel_binding(self): +# schema2.conn = self.conn +# self.conn.bind(schema2.__name__, PREFIX + '_test1') +# s = schema2.test1.Subjects() +# +# +# class TestRelationDeclaration(object): +# """ +# Test declaration (creation of table) from +# definition in Relation under various circumstances +# """ +# +# def setup(self): +# cleanup() +# +# self.conn = Connection(**CONN_INFO) +# test1.conn = self.conn +# self.conn.bind(test1.__name__, PREFIX + '_test1') +# test2.conn = self.conn +# self.conn.bind(test2.__name__, PREFIX + '_test2') +# +# def test_is_declared(self): +# """ +# The table should not be created immediately after instantiation, +# but should be created when declare method is called +# :return: +# """ +# s = test1.Subjects() +# assert_false(s.is_declared) +# s.declare() +# assert_true(s.is_declared) +# +# def test_calling_heading_should_trigger_declaration(self): +# s = test1.Subjects() +# assert_false(s.is_declared) +# a = s.heading +# assert_true(s.is_declared) +# +# def test_foreign_key_ref_in_same_schema(self): +# s = test1.Experiments() +# assert_true('subject_id' in s.heading.primary_key) +# +# def test_foreign_key_ref_in_another_schema(self): +# s = test2.Experiments() +# assert_true('subject_id' in s.heading.primary_key) +# +# def test_aliased_module_name_should_resolve(self): +# """ +# Module names that were aliased in the definition should +# be properly resolved. +# """ +# s = test2.Conditions() +# assert_true('subject_id' in s.heading.primary_key) +# +# def test_reference_to_unknown_module_in_definition_should_fail(self): +# """ +# Module names in table definition that is not aliased via import +# results in error +# """ +# s = test2.FoodPreference() +# with assert_raises(DataJointError) as e: +# s.declare() +# +# +# class TestRelationWithExistingTables(object): +# """ +# Test base derivatives behaviors when some of the tables +# already exists in the database +# """ +# def setup(self): +# cleanup() +# self.conn = Connection(**CONN_INFO) +# setup_sample_db() +# test1.conn = self.conn +# self.conn.bind(test1.__name__, PREFIX + '_test1') +# test2.conn = self.conn +# self.conn.bind(test2.__name__, PREFIX + '_test2') +# self.conn.load_headings(force=True) +# +# schema2.conn = self.conn +# self.conn.bind(schema2.__name__, PREFIX + '_package') +# +# def teardown(selfself): +# schema1.__dict__.pop('conn', None) +# cleanup() +# +# def test_detection_of_existing_table(self): +# """ +# The Relation instance should be able to detect if the +# corresponding table already exists in the database +# """ +# s = test1.Subjects() +# assert_true(s.is_declared) +# +# def test_definition_referring_to_existing_table_without_class(self): +# s1 = test1.Sessions() +# assert_true('experimenter_id' in s1.primary_key) +# +# s2 = test2.Session() +# assert_true('experimenter_id' in s2.primary_key) +# +# def test_reference_to_package_level_table(self): +# s = test1.Match() +# s.declare() +# assert_true('pop_id' in s.primary_key) +# +# def test_direct_reference_to_existing_table_should_fail(self): +# """ +# When deriving from Relation, definition should not contain direct reference +# to a database name +# """ +# s = test1.TrainingSession() +# with assert_raises(DataJointError): +# s.declare() +# +# @raises(TypeError) +# def test_instantiation_of_base_derivative_without_definition_should_fail(): +# test1.Empty() +# +# +# +# diff --git a/tests/test_relation.py b/tests/test_relation.py index 15f52f34a..94adef010 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -1,490 +1,490 @@ -import random -import string - -__author__ = 'fabee' - -from .schemata.schema1 import test1, test4 - -from . import BASE_CONN, CONN_INFO, PREFIX, cleanup -from datajoint.connection import Connection -from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal,\ - assert_tuple_equal, assert_dict_equal, raises -from datajoint import DataJointError, TransactionError, AutoPopulate, Relation -import numpy as np -from numpy.testing import assert_array_equal -from datajoint.relation import FreeRelation -import numpy as np - - -def trial_faker(n=10): - def iter(): - for s in [1, 2]: - for i in range(n): - yield dict(trial_id=i, subject_id=s, outcome=int(np.random.randint(10)), notes= 'no comment') - return iter() - - -def setup(): - """ - Setup connections and bindings - """ - pass - - -class TestTableObject(object): - def __init__(self): - self.subjects = None - self.setup() - - """ - Test cases for FreeRelation objects - """ - - def setup(self): - """ - Create a connection object and prepare test modules - as follows: - test1 - has conn and bounded - """ - cleanup() # drop all databases with PREFIX - test1.__dict__.pop('conn', None) - test4.__dict__.pop('conn', None) # make sure conn is not defined at schema level - - self.conn = Connection(**CONN_INFO) - test1.conn = self.conn - test4.conn = self.conn - self.conn.bind(test1.__name__, PREFIX + '_test1') - self.conn.bind(test4.__name__, PREFIX + '_test4') - self.subjects = test1.Subjects() - self.animals = test1.Animals() - self.relvar_blob = test4.Matrix() - self.trials = test1.Trials() - - def teardown(self): - cleanup() - - def test_compound_restriction(self): - s = self.subjects - t = self.trials - - s.insert(dict(subject_id=1, real_id='M')) - s.insert(dict(subject_id=2, real_id='F')) - t.iter_insert(trial_faker(20)) - - tM = t & (s & "real_id = 'M'") - t1 = t & "subject_id = 1" - - assert_equal(len(tM), len(t1), "Results of compound request does not have same length") - - for t1_item, tM_item in zip(sorted(t1, key=lambda item: item['trial_id']), - sorted(tM, key=lambda item: item['trial_id'])): - assert_dict_equal(t1_item, tM_item, - 'Dictionary elements do not agree in compound statement') - - def test_record_insert(self): - "Test whether record insert works" - tmp = np.array([(2, 'Klara', 'monkey')], - dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) - - self.subjects.insert(tmp[0]) - testt2 = (self.subjects & 'subject_id = 2').fetch()[0] - assert_equal(tuple(tmp[0]), tuple(testt2), "Inserted and fetched record do not match!") - - def test_delete(self): - "Test whether delete works" - tmp = np.array([(2, 'Klara', 'monkey'), (1,'Peter', 'mouse')], - dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) - - self.subjects.batch_insert(tmp) - assert_true(len(self.subjects) == 2, 'Length does not match 2.') - self.subjects.delete() - assert_true(len(self.subjects) == 0, 'Length does not match 0.') - - # def test_cascading_delete(self): - # "Test whether delete works" - # tmp = np.array([(2, 'Klara', 'monkey'), (1,'Peter', 'mouse')], - # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) - # - # self.subjects.batch_insert(tmp) - # - # self.trials.insert(dict(subject_id=1, trial_id=1, outcome=0)) - # self.trials.insert(dict(subject_id=1, trial_id=2, outcome=1)) - # self.trials.insert(dict(subject_id=2, trial_id=3, outcome=2)) - # assert_true(len(self.subjects) == 2, 'Length does not match 2.') - # assert_true(len(self.trials) == 3, 'Length does not match 3.') - # (self.subjects & 'subject_id=1').delete() - # assert_true(len(self.subjects) == 1, 'Length does not match 1.') - # assert_true(len(self.trials) == 1, 'Length does not match 1.') - - def test_short_hand_foreign_reference(self): - self.animals.heading - - - - def test_record_insert_different_order(self): - "Test whether record insert works" - tmp = np.array([('Klara', 2, 'monkey')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - - self.subjects.insert(tmp[0]) - testt2 = (self.subjects & 'subject_id = 2').fetch()[0] - assert_equal((2, 'Klara', 'monkey'), tuple(testt2), - "Inserted and fetched record do not match!") - - @raises(TransactionError) - def test_transaction_error(self): - "Test whether declaration in transaction is prohibited" - - tmp = np.array([('Klara', 2, 'monkey')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - self.conn.start_transaction() - self.subjects.insert(tmp[0]) - - # def test_transaction_suppress_error(self): - # "Test whether ignore_errors ignores the errors." - # - # tmp = np.array([('Klara', 2, 'monkey')], - # dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - # with self.conn.transaction(ignore_errors=True) as tr: - # self.subjects.insert(tmp[0]) - - - @raises(TransactionError) - def test_transaction_error_not_resolve(self): - "Test whether declaration in transaction is prohibited" - - tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - try: - self.conn.start_transaction() - self.subjects.insert(tmp[0]) - except TransactionError as te: - self.conn.cancel_transaction() - - self.conn.start_transaction() - self.subjects.insert(tmp[0]) - - def test_transaction_error_resolve(self): - "Test whether declaration in transaction is prohibited" - - tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - try: - self.conn.start_transaction() - self.subjects.insert(tmp[0]) - except TransactionError as te: - self.conn.cancel_transaction() - te.resolve() - - self.conn.start_transaction() - self.subjects.insert(tmp[0]) - self.conn.commit_transaction() - - def test_transaction_error2(self): - "If table is declared, we are allowed to insert within a transaction" - - tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - self.subjects.insert(tmp[0]) - - self.conn.start_transaction() - self.subjects.insert(tmp[1]) - self.conn.commit_transaction() - - - @raises(KeyError) - def test_wrong_key_insert_records(self): - "Test whether record insert works" - tmp = np.array([('Klara', 2, 'monkey')], - dtype=[('real_deal', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - - self.subjects.insert(tmp[0]) - - - def test_dict_insert(self): - "Test whether record insert works" - tmp = {'real_id': 'Brunhilda', - 'subject_id': 3, - 'species': 'human'} - - self.subjects.insert(tmp) - testt2 = (self.subjects & 'subject_id = 3').fetch()[0] - assert_equal((3, 'Brunhilda', 'human'), tuple(testt2), "Inserted and fetched record do not match!") - - @raises(KeyError) - def test_wrong_key_insert(self): - "Test whether a correct error is generated when inserting wrong attribute name" - tmp = {'real_deal': 'Brunhilda', - 'subject_database': 3, - 'species': 'human'} - - self.subjects.insert(tmp) - - def test_batch_insert(self): - "Test whether record insert works" - tmp = np.array([('Klara', 2, 'monkey'), ('Brunhilda', 3, 'mouse'), ('Mickey', 1, 'human')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - - self.subjects.batch_insert(tmp) - - expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'), - (3, 'Brunhilda', 'mouse')], - dtype=[('subject_id', 'i4'), ('species', 'O')]) - - self.subjects.iter_insert(tmp.__iter__()) - - expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'), - (3, 'Brunhilda', 'mouse')], - dtype=[('subject_id', ' `dj_free`.Animals - rec_session_id : int # recording session identifier - """ - table = FreeRelation(self.conn, 'dj_free', 'Recordings', definition) - assert_raises(DataJointError, table.declare) - - def test_reference_to_existing_table(self): - definition1 = """ - `dj_free`.Animals (manual) # my animal table - animal_id : int # unique id for the animal - --- - animal_name : varchar(128) # name of the animal - """ - table1 = FreeRelation(self.conn, 'dj_free', 'Animals', definition1) - table1.declare() - - definition2 = """ - `dj_free`.Recordings (manual) # recordings - -> `dj_free`.Animals - rec_session_id : int # recording session identifier - """ - table2 = FreeRelation(self.conn, 'dj_free', 'Recordings', definition2) - table2.declare() - assert_true('animal_id' in table2.primary_key) - - -def id_generator(size=6, chars=string.ascii_uppercase + string.digits): - return ''.join(random.choice(chars) for _ in range(size)) - -class TestIterator(object): - def __init__(self): - self.relvar = None - self.setup() - - """ - Test cases for Iterators in Relations objects - """ - - def setup(self): - """ - Create a connection object and prepare test modules - as follows: - test1 - has conn and bounded - """ - cleanup() # drop all databases with PREFIX - test4.__dict__.pop('conn', None) # make sure conn is not defined at schema level - - self.conn = Connection(**CONN_INFO) - test4.conn = self.conn - self.conn.bind(test4.__name__, PREFIX + '_test4') - self.relvar_blob = test4.Matrix() - - def teardown(self): - cleanup() - - - def test_blob_iteration(self): - "Tests the basic call of the iterator" - - dicts = [] - for i in range(10): - - c = id_generator() - - t = {'matrix_id':i, - 'data': np.random.randn(4,4,4), - 'comment': c} - self.relvar_blob.insert(t) - dicts.append(t) - - for t, t2 in zip(dicts, self.relvar_blob): - assert_true(isinstance(t2, dict), 'iterator does not return dict') - - assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved tuples do not match') - assert_equal(t['comment'], t2['comment'], 'inserted and retrieved tuples do not match') - assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved tuples do not match') - - def test_fetch(self): - dicts = [] - for i in range(10): - - c = id_generator() - - t = {'matrix_id':i, - 'data': np.random.randn(4,4,4), - 'comment': c} - self.relvar_blob.insert(t) - dicts.append(t) - - tuples2 = self.relvar_blob.fetch() - assert_true(isinstance(tuples2, np.ndarray), "Return value of fetch does not have proper type.") - assert_true(isinstance(tuples2[0], np.void), "Return value of fetch does not have proper type.") - for t, t2 in zip(dicts, tuples2): - - assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved tuples do not match') - assert_equal(t['comment'], t2['comment'], 'inserted and retrieved tuples do not match') - assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved tuples do not match') - - def test_fetch_dicts(self): - dicts = [] - for i in range(10): - - c = id_generator() - - t = {'matrix_id':i, - 'data': np.random.randn(4,4,4), - 'comment': c} - self.relvar_blob.insert(t) - dicts.append(t) - - tuples2 = self.relvar_blob.fetch(as_dict=True) - assert_true(isinstance(tuples2, list), "Return value of fetch with as_dict=True does not have proper type.") - assert_true(isinstance(tuples2[0], dict), "Return value of fetch with as_dict=True does not have proper type.") - for t, t2 in zip(dicts, tuples2): - assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved dicts do not match') - assert_equal(t['comment'], t2['comment'], 'inserted and retrieved dicts do not match') - assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved dicts do not match') - - - -class TestAutopopulate(object): - def __init__(self): - self.relvar = None - self.setup() - - """ - Test cases for Iterators in Relations objects - """ - - def setup(self): - """ - Create a connection object and prepare test modules - as follows: - test1 - has conn and bounded - """ - cleanup() # drop all databases with PREFIX - test1.__dict__.pop('conn', None) # make sure conn is not defined at schema level - - self.conn = Connection(**CONN_INFO) - test1.conn = self.conn - self.conn.bind(test1.__name__, PREFIX + '_test1') - - self.subjects = test1.Subjects() - self.trials = test1.Trials() - self.squared = test1.SquaredScore() - self.dummy = test1.SquaredSubtable() - self.dummy1 = test1.WrongImplementation() - self.error_generator = test1.ErrorGenerator() - self.fill_relation() - - - - def fill_relation(self): - tmp = np.array([('Klara', 2, 'monkey'), ('Peter', 3, 'mouse')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - self.subjects.batch_insert(tmp) - - for trial_id in range(1,11): - self.trials.insert(dict(subject_id=2, trial_id=trial_id, outcome=np.random.randint(0,10))) - - def teardown(self): - cleanup() - - def test_autopopulate(self): - self.squared.populate() - assert_equal(len(self.squared), 10) - - for trial in self.trials*self.squared: - assert_equal(trial['outcome']**2, trial['squared']) - - def test_autopopulate_restriction(self): - self.squared.populate(restriction='trial_id <= 5') - assert_equal(len(self.squared), 5) - - for trial in self.trials*self.squared: - assert_equal(trial['outcome']**2, trial['squared']) - - - # def test_autopopulate_transaction_error(self): - # errors = self.squared.populate(suppress_errors=True) - # assert_equal(len(errors), 1) - # assert_true(isinstance(errors[0][1], TransactionError)) - - @raises(DataJointError) - def test_autopopulate_relation_check(self): - - class dummy(AutoPopulate): - - def populate_relation(self): - return None - - def _make_tuples(self, key): - pass - - du = dummy() - du.populate() \ - - @raises(DataJointError) - def test_autopopulate_relation_check(self): - self.dummy1.populate() - - @raises(Exception) - def test_autopopulate_relation_check(self): - self.error_generator.populate()\ - - @raises(Exception) - def test_autopopulate_relation_check2(self): - tmp = self.dummy2.populate(suppress_errors=True) - assert_equal(len(tmp), 1, 'Error list should have length 1.') +# import random +# import string +# +# __author__ = 'fabee' +# +# from .schemata.schema1 import test1, test4 +# +# from . import BASE_CONN, CONN_INFO, PREFIX, cleanup +# from datajoint.connection import Connection +# from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal,\ +# assert_tuple_equal, assert_dict_equal, raises +# from datajoint import DataJointError, TransactionError, AutoPopulate, Relation +# import numpy as np +# from numpy.testing import assert_array_equal +# from datajoint.relation import FreeRelation +# import numpy as np +# +# +# def trial_faker(n=10): +# def iter(): +# for s in [1, 2]: +# for i in range(n): +# yield dict(trial_id=i, subject_id=s, outcome=int(np.random.randint(10)), notes= 'no comment') +# return iter() +# +# +# def setup(): +# """ +# Setup connections and bindings +# """ +# pass +# +# +# class TestTableObject(object): +# def __init__(self): +# self.subjects = None +# self.setup() +# +# """ +# Test cases for FreeRelation objects +# """ +# +# def setup(self): +# """ +# Create a connection object and prepare test modules +# as follows: +# test1 - has conn and bounded +# """ +# cleanup() # drop all databases with PREFIX +# test1.__dict__.pop('conn', None) +# test4.__dict__.pop('conn', None) # make sure conn is not defined at schema level +# +# self.conn = Connection(**CONN_INFO) +# test1.conn = self.conn +# test4.conn = self.conn +# self.conn.bind(test1.__name__, PREFIX + '_test1') +# self.conn.bind(test4.__name__, PREFIX + '_test4') +# self.subjects = test1.Subjects() +# self.animals = test1.Animals() +# self.relvar_blob = test4.Matrix() +# self.trials = test1.Trials() +# +# def teardown(self): +# cleanup() +# +# def test_compound_restriction(self): +# s = self.subjects +# t = self.trials +# +# s.insert(dict(subject_id=1, real_id='M')) +# s.insert(dict(subject_id=2, real_id='F')) +# t.iter_insert(trial_faker(20)) +# +# tM = t & (s & "real_id = 'M'") +# t1 = t & "subject_id = 1" +# +# assert_equal(len(tM), len(t1), "Results of compound request does not have same length") +# +# for t1_item, tM_item in zip(sorted(t1, key=lambda item: item['trial_id']), +# sorted(tM, key=lambda item: item['trial_id'])): +# assert_dict_equal(t1_item, tM_item, +# 'Dictionary elements do not agree in compound statement') +# +# def test_record_insert(self): +# "Test whether record insert works" +# tmp = np.array([(2, 'Klara', 'monkey')], +# dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) +# +# self.subjects.insert(tmp[0]) +# testt2 = (self.subjects & 'subject_id = 2').fetch()[0] +# assert_equal(tuple(tmp[0]), tuple(testt2), "Inserted and fetched record do not match!") +# +# def test_delete(self): +# "Test whether delete works" +# tmp = np.array([(2, 'Klara', 'monkey'), (1,'Peter', 'mouse')], +# dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) +# +# self.subjects.batch_insert(tmp) +# assert_true(len(self.subjects) == 2, 'Length does not match 2.') +# self.subjects.delete() +# assert_true(len(self.subjects) == 0, 'Length does not match 0.') +# +# # def test_cascading_delete(self): +# # "Test whether delete works" +# # tmp = np.array([(2, 'Klara', 'monkey'), (1,'Peter', 'mouse')], +# # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) +# # +# # self.subjects.batch_insert(tmp) +# # +# # self.trials.insert(dict(subject_id=1, trial_id=1, outcome=0)) +# # self.trials.insert(dict(subject_id=1, trial_id=2, outcome=1)) +# # self.trials.insert(dict(subject_id=2, trial_id=3, outcome=2)) +# # assert_true(len(self.subjects) == 2, 'Length does not match 2.') +# # assert_true(len(self.trials) == 3, 'Length does not match 3.') +# # (self.subjects & 'subject_id=1').delete() +# # assert_true(len(self.subjects) == 1, 'Length does not match 1.') +# # assert_true(len(self.trials) == 1, 'Length does not match 1.') +# +# def test_short_hand_foreign_reference(self): +# self.animals.heading +# +# +# +# def test_record_insert_different_order(self): +# "Test whether record insert works" +# tmp = np.array([('Klara', 2, 'monkey')], +# dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) +# +# self.subjects.insert(tmp[0]) +# testt2 = (self.subjects & 'subject_id = 2').fetch()[0] +# assert_equal((2, 'Klara', 'monkey'), tuple(testt2), +# "Inserted and fetched record do not match!") +# +# @raises(TransactionError) +# def test_transaction_error(self): +# "Test whether declaration in transaction is prohibited" +# +# tmp = np.array([('Klara', 2, 'monkey')], +# dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) +# self.conn.start_transaction() +# self.subjects.insert(tmp[0]) +# +# # def test_transaction_suppress_error(self): +# # "Test whether ignore_errors ignores the errors." +# # +# # tmp = np.array([('Klara', 2, 'monkey')], +# # dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) +# # with self.conn.transaction(ignore_errors=True) as tr: +# # self.subjects.insert(tmp[0]) +# +# +# @raises(TransactionError) +# def test_transaction_error_not_resolve(self): +# "Test whether declaration in transaction is prohibited" +# +# tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], +# dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) +# try: +# self.conn.start_transaction() +# self.subjects.insert(tmp[0]) +# except TransactionError as te: +# self.conn.cancel_transaction() +# +# self.conn.start_transaction() +# self.subjects.insert(tmp[0]) +# +# def test_transaction_error_resolve(self): +# "Test whether declaration in transaction is prohibited" +# +# tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], +# dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) +# try: +# self.conn.start_transaction() +# self.subjects.insert(tmp[0]) +# except TransactionError as te: +# self.conn.cancel_transaction() +# te.resolve() +# +# self.conn.start_transaction() +# self.subjects.insert(tmp[0]) +# self.conn.commit_transaction() +# +# def test_transaction_error2(self): +# "If table is declared, we are allowed to insert within a transaction" +# +# tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], +# dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) +# self.subjects.insert(tmp[0]) +# +# self.conn.start_transaction() +# self.subjects.insert(tmp[1]) +# self.conn.commit_transaction() +# +# +# @raises(KeyError) +# def test_wrong_key_insert_records(self): +# "Test whether record insert works" +# tmp = np.array([('Klara', 2, 'monkey')], +# dtype=[('real_deal', 'O'), ('subject_id', '>i4'), ('species', 'O')]) +# +# self.subjects.insert(tmp[0]) +# +# +# def test_dict_insert(self): +# "Test whether record insert works" +# tmp = {'real_id': 'Brunhilda', +# 'subject_id': 3, +# 'species': 'human'} +# +# self.subjects.insert(tmp) +# testt2 = (self.subjects & 'subject_id = 3').fetch()[0] +# assert_equal((3, 'Brunhilda', 'human'), tuple(testt2), "Inserted and fetched record do not match!") +# +# @raises(KeyError) +# def test_wrong_key_insert(self): +# "Test whether a correct error is generated when inserting wrong attribute name" +# tmp = {'real_deal': 'Brunhilda', +# 'subject_database': 3, +# 'species': 'human'} +# +# self.subjects.insert(tmp) +# +# def test_batch_insert(self): +# "Test whether record insert works" +# tmp = np.array([('Klara', 2, 'monkey'), ('Brunhilda', 3, 'mouse'), ('Mickey', 1, 'human')], +# dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) +# +# self.subjects.batch_insert(tmp) +# +# expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'), +# (3, 'Brunhilda', 'mouse')], +# dtype=[('subject_id', 'i4'), ('species', 'O')]) +# +# self.subjects.iter_insert(tmp.__iter__()) +# +# expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'), +# (3, 'Brunhilda', 'mouse')], +# dtype=[('subject_id', ' `dj_free`.Animals +# rec_session_id : int # recording session identifier +# """ +# table = FreeRelation(self.conn, 'dj_free', 'Recordings', definition) +# assert_raises(DataJointError, table.declare) +# +# def test_reference_to_existing_table(self): +# definition1 = """ +# `dj_free`.Animals (manual) # my animal table +# animal_id : int # unique id for the animal +# --- +# animal_name : varchar(128) # name of the animal +# """ +# table1 = FreeRelation(self.conn, 'dj_free', 'Animals', definition1) +# table1.declare() +# +# definition2 = """ +# `dj_free`.Recordings (manual) # recordings +# -> `dj_free`.Animals +# rec_session_id : int # recording session identifier +# """ +# table2 = FreeRelation(self.conn, 'dj_free', 'Recordings', definition2) +# table2.declare() +# assert_true('animal_id' in table2.primary_key) +# +# +# def id_generator(size=6, chars=string.ascii_uppercase + string.digits): +# return ''.join(random.choice(chars) for _ in range(size)) +# +# class TestIterator(object): +# def __init__(self): +# self.relvar = None +# self.setup() +# +# """ +# Test cases for Iterators in Relations objects +# """ +# +# def setup(self): +# """ +# Create a connection object and prepare test modules +# as follows: +# test1 - has conn and bounded +# """ +# cleanup() # drop all databases with PREFIX +# test4.__dict__.pop('conn', None) # make sure conn is not defined at schema level +# +# self.conn = Connection(**CONN_INFO) +# test4.conn = self.conn +# self.conn.bind(test4.__name__, PREFIX + '_test4') +# self.relvar_blob = test4.Matrix() +# +# def teardown(self): +# cleanup() +# +# +# def test_blob_iteration(self): +# "Tests the basic call of the iterator" +# +# dicts = [] +# for i in range(10): +# +# c = id_generator() +# +# t = {'matrix_id':i, +# 'data': np.random.randn(4,4,4), +# 'comment': c} +# self.relvar_blob.insert(t) +# dicts.append(t) +# +# for t, t2 in zip(dicts, self.relvar_blob): +# assert_true(isinstance(t2, dict), 'iterator does not return dict') +# +# assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved tuples do not match') +# assert_equal(t['comment'], t2['comment'], 'inserted and retrieved tuples do not match') +# assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved tuples do not match') +# +# def test_fetch(self): +# dicts = [] +# for i in range(10): +# +# c = id_generator() +# +# t = {'matrix_id':i, +# 'data': np.random.randn(4,4,4), +# 'comment': c} +# self.relvar_blob.insert(t) +# dicts.append(t) +# +# tuples2 = self.relvar_blob.fetch() +# assert_true(isinstance(tuples2, np.ndarray), "Return value of fetch does not have proper type.") +# assert_true(isinstance(tuples2[0], np.void), "Return value of fetch does not have proper type.") +# for t, t2 in zip(dicts, tuples2): +# +# assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved tuples do not match') +# assert_equal(t['comment'], t2['comment'], 'inserted and retrieved tuples do not match') +# assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved tuples do not match') +# +# def test_fetch_dicts(self): +# dicts = [] +# for i in range(10): +# +# c = id_generator() +# +# t = {'matrix_id':i, +# 'data': np.random.randn(4,4,4), +# 'comment': c} +# self.relvar_blob.insert(t) +# dicts.append(t) +# +# tuples2 = self.relvar_blob.fetch(as_dict=True) +# assert_true(isinstance(tuples2, list), "Return value of fetch with as_dict=True does not have proper type.") +# assert_true(isinstance(tuples2[0], dict), "Return value of fetch with as_dict=True does not have proper type.") +# for t, t2 in zip(dicts, tuples2): +# assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved dicts do not match') +# assert_equal(t['comment'], t2['comment'], 'inserted and retrieved dicts do not match') +# assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved dicts do not match') +# +# +# +# class TestAutopopulate(object): +# def __init__(self): +# self.relvar = None +# self.setup() +# +# """ +# Test cases for Iterators in Relations objects +# """ +# +# def setup(self): +# """ +# Create a connection object and prepare test modules +# as follows: +# test1 - has conn and bounded +# """ +# cleanup() # drop all databases with PREFIX +# test1.__dict__.pop('conn', None) # make sure conn is not defined at schema level +# +# self.conn = Connection(**CONN_INFO) +# test1.conn = self.conn +# self.conn.bind(test1.__name__, PREFIX + '_test1') +# +# self.subjects = test1.Subjects() +# self.trials = test1.Trials() +# self.squared = test1.SquaredScore() +# self.dummy = test1.SquaredSubtable() +# self.dummy1 = test1.WrongImplementation() +# self.error_generator = test1.ErrorGenerator() +# self.fill_relation() +# +# +# +# def fill_relation(self): +# tmp = np.array([('Klara', 2, 'monkey'), ('Peter', 3, 'mouse')], +# dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) +# self.subjects.batch_insert(tmp) +# +# for trial_id in range(1,11): +# self.trials.insert(dict(subject_id=2, trial_id=trial_id, outcome=np.random.randint(0,10))) +# +# def teardown(self): +# cleanup() +# +# def test_autopopulate(self): +# self.squared.populate() +# assert_equal(len(self.squared), 10) +# +# for trial in self.trials*self.squared: +# assert_equal(trial['outcome']**2, trial['squared']) +# +# def test_autopopulate_restriction(self): +# self.squared.populate(restriction='trial_id <= 5') +# assert_equal(len(self.squared), 5) +# +# for trial in self.trials*self.squared: +# assert_equal(trial['outcome']**2, trial['squared']) +# +# +# # def test_autopopulate_transaction_error(self): +# # errors = self.squared.populate(suppress_errors=True) +# # assert_equal(len(errors), 1) +# # assert_true(isinstance(errors[0][1], TransactionError)) +# +# @raises(DataJointError) +# def test_autopopulate_relation_check(self): +# +# class dummy(AutoPopulate): +# +# def populate_relation(self): +# return None +# +# def _make_tuples(self, key): +# pass +# +# du = dummy() +# du.populate() \ +# +# @raises(DataJointError) +# def test_autopopulate_relation_check(self): +# self.dummy1.populate() +# +# @raises(Exception) +# def test_autopopulate_relation_check(self): +# self.error_generator.populate()\ +# +# @raises(Exception) +# def test_autopopulate_relation_check2(self): +# tmp = self.dummy2.populate(suppress_errors=True) +# assert_equal(len(tmp), 1, 'Error list should have length 1.') diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index 3a25a87d0..3ceeb4964 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -1,47 +1,47 @@ -""" -Collection of test cases to test relational methods -""" - -__author__ = 'eywalker' - - -def setup(): - """ - Setup - :return: - """ - -class TestRelationalAlgebra(object): - - def setup(self): - pass - - def test_mul(self): - pass - - def test_project(self): - pass - - def test_iand(self): - pass - - def test_isub(self): - pass - - def test_sub(self): - pass - - def test_len(self): - pass - - def test_fetch(self): - pass - - def test_repr(self): - pass - - def test_iter(self): - pass - - def test_not(self): - pass \ No newline at end of file +# """ +# Collection of test cases to test relational methods +# """ +# +# __author__ = 'eywalker' +# +# +# def setup(): +# """ +# Setup +# :return: +# """ +# +# class TestRelationalAlgebra(object): +# +# def setup(self): +# pass +# +# def test_mul(self): +# pass +# +# def test_project(self): +# pass +# +# def test_iand(self): +# pass +# +# def test_isub(self): +# pass +# +# def test_sub(self): +# pass +# +# def test_len(self): +# pass +# +# def test_fetch(self): +# pass +# +# def test_repr(self): +# pass +# +# def test_iter(self): +# pass +# +# def test_not(self): +# pass \ No newline at end of file diff --git a/tests/test_settings.py b/tests/test_settings.py index 6b8100806..24b05d5d4 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,67 +1,67 @@ -import os -import pprint -import random -import string -from datajoint import settings - -__author__ = 'Fabian Sinz' - -from nose.tools import assert_true, assert_raises, assert_equal, raises, assert_dict_equal -import datajoint as dj - - -def test_load_save(): - dj.config.save('tmp.json') - conf = dj.Config() - conf.load('tmp.json') - assert_true(conf == dj.config, 'Two config files do not match.') - os.remove('tmp.json') - -def test_singleton(): - dj.config.save('tmp.json') - conf = dj.Config() - conf.load('tmp.json') - conf['dummy.val'] = 2 - - assert_true(conf == dj.config, 'Config does not behave like a singleton.') - os.remove('tmp.json') - - -@raises(ValueError) -def test_nested_check(): - dummy = {'dummy.testval': {'notallowed': 2}} - dj.config.update(dummy) - -@raises(dj.DataJointError) -def test_validator(): - dj.config['database.port'] = 'harbor' - -def test_del(): - dj.config['peter'] = 2 - assert_true('peter' in dj.config) - del dj.config['peter'] - assert_true('peter' not in dj.config) - -def test_len(): - assert_equal(len(dj.config), len(dj.config._conf)) - -def test_str(): - assert_equal(str(dj.config), pprint.pformat(dj.config._conf, indent=4)) - -def test_repr(): - assert_equal(repr(dj.config), pprint.pformat(dj.config._conf, indent=4)) - -@raises(ValueError) -def test_nested_check2(): - dj.config['dummy'] = {'dummy2':2} - -def test_save(): - tmpfile = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(20)) - moved = False - if os.path.isfile(settings.LOCALCONFIG): - os.rename(settings.LOCALCONFIG, tmpfile) - moved = True - dj.config.save() - assert_true(os.path.isfile(settings.LOCALCONFIG)) - if moved: - os.rename(tmpfile, settings.LOCALCONFIG) +# import os +# import pprint +# import random +# import string +# from datajoint import settings +# +# __author__ = 'Fabian Sinz' +# +# from nose.tools import assert_true, assert_raises, assert_equal, raises, assert_dict_equal +# import datajoint as dj +# +# +# def test_load_save(): +# dj.config.save('tmp.json') +# conf = dj.Config() +# conf.load('tmp.json') +# assert_true(conf == dj.config, 'Two config files do not match.') +# os.remove('tmp.json') +# +# def test_singleton(): +# dj.config.save('tmp.json') +# conf = dj.Config() +# conf.load('tmp.json') +# conf['dummy.val'] = 2 +# +# assert_true(conf == dj.config, 'Config does not behave like a singleton.') +# os.remove('tmp.json') +# +# +# @raises(ValueError) +# def test_nested_check(): +# dummy = {'dummy.testval': {'notallowed': 2}} +# dj.config.update(dummy) +# +# @raises(dj.DataJointError) +# def test_validator(): +# dj.config['database.port'] = 'harbor' +# +# def test_del(): +# dj.config['peter'] = 2 +# assert_true('peter' in dj.config) +# del dj.config['peter'] +# assert_true('peter' not in dj.config) +# +# def test_len(): +# assert_equal(len(dj.config), len(dj.config._conf)) +# +# def test_str(): +# assert_equal(str(dj.config), pprint.pformat(dj.config._conf, indent=4)) +# +# def test_repr(): +# assert_equal(repr(dj.config), pprint.pformat(dj.config._conf, indent=4)) +# +# @raises(ValueError) +# def test_nested_check2(): +# dj.config['dummy'] = {'dummy2':2} +# +# def test_save(): +# tmpfile = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(20)) +# moved = False +# if os.path.isfile(settings.LOCALCONFIG): +# os.rename(settings.LOCALCONFIG, tmpfile) +# moved = True +# dj.config.save() +# assert_true(os.path.isfile(settings.LOCALCONFIG)) +# if moved: +# os.rename(tmpfile, settings.LOCALCONFIG) diff --git a/tests/test_utils.py b/tests/test_utils.py index 655884ce0..2322cf18a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,33 +1,33 @@ -""" -Collection of test cases to test core module. -""" - -__author__ = 'eywalker' -from nose.tools import assert_true, assert_raises, assert_equal -from datajoint.utils import to_camel_case, from_camel_case -from datajoint import DataJointError - - -def setup(): - pass - - -def teardown(): - pass - - -def test_to_camel_case(): - assert_equal(to_camel_case('basic_sessions'), 'BasicSessions') - assert_equal(to_camel_case('_another_table'), 'AnotherTable') - - -def test_from_camel_case(): - assert_equal(from_camel_case('AllGroups'), 'all_groups') - with assert_raises(DataJointError): - from_camel_case('repNames') - with assert_raises(DataJointError): - from_camel_case('10_all') - with assert_raises(DataJointError): - from_camel_case('hello world') - with assert_raises(DataJointError): - from_camel_case('#baisc_names') +# """ +# Collection of test cases to test core module. +# """ +# +# __author__ = 'eywalker' +# from nose.tools import assert_true, assert_raises, assert_equal +# from datajoint.utils import to_camel_case, from_camel_case +# from datajoint import DataJointError +# +# +# def setup(): +# pass +# +# +# def teardown(): +# pass +# +# +# def test_to_camel_case(): +# assert_equal(to_camel_case('basic_sessions'), 'BasicSessions') +# assert_equal(to_camel_case('_another_table'), 'AnotherTable') +# +# +# def test_from_camel_case(): +# assert_equal(from_camel_case('AllGroups'), 'all_groups') +# with assert_raises(DataJointError): +# from_camel_case('repNames') +# with assert_raises(DataJointError): +# from_camel_case('10_all') +# with assert_raises(DataJointError): +# from_camel_case('hello world') +# with assert_raises(DataJointError): +# from_camel_case('#baisc_names') From 0c52d6cab08946c8ab4e2233b6d471e2e2ec95fe Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Mon, 8 Jun 2015 18:35:46 -0500 Subject: [PATCH 12/42] first test runs without failing --- tests/test_utils.py | 66 ++++++++++++++++++++++----------------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 2322cf18a..655884ce0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,33 +1,33 @@ -# """ -# Collection of test cases to test core module. -# """ -# -# __author__ = 'eywalker' -# from nose.tools import assert_true, assert_raises, assert_equal -# from datajoint.utils import to_camel_case, from_camel_case -# from datajoint import DataJointError -# -# -# def setup(): -# pass -# -# -# def teardown(): -# pass -# -# -# def test_to_camel_case(): -# assert_equal(to_camel_case('basic_sessions'), 'BasicSessions') -# assert_equal(to_camel_case('_another_table'), 'AnotherTable') -# -# -# def test_from_camel_case(): -# assert_equal(from_camel_case('AllGroups'), 'all_groups') -# with assert_raises(DataJointError): -# from_camel_case('repNames') -# with assert_raises(DataJointError): -# from_camel_case('10_all') -# with assert_raises(DataJointError): -# from_camel_case('hello world') -# with assert_raises(DataJointError): -# from_camel_case('#baisc_names') +""" +Collection of test cases to test core module. +""" + +__author__ = 'eywalker' +from nose.tools import assert_true, assert_raises, assert_equal +from datajoint.utils import to_camel_case, from_camel_case +from datajoint import DataJointError + + +def setup(): + pass + + +def teardown(): + pass + + +def test_to_camel_case(): + assert_equal(to_camel_case('basic_sessions'), 'BasicSessions') + assert_equal(to_camel_case('_another_table'), 'AnotherTable') + + +def test_from_camel_case(): + assert_equal(from_camel_case('AllGroups'), 'all_groups') + with assert_raises(DataJointError): + from_camel_case('repNames') + with assert_raises(DataJointError): + from_camel_case('10_all') + with assert_raises(DataJointError): + from_camel_case('hello world') + with assert_raises(DataJointError): + from_camel_case('#baisc_names') From eda035bd74c12596562d5184153654a0ce2b4543 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Mon, 8 Jun 2015 18:36:07 -0500 Subject: [PATCH 13/42] 13 tests pass --- tests/test_settings.py | 134 ++++++++++++++++++++--------------------- 1 file changed, 67 insertions(+), 67 deletions(-) diff --git a/tests/test_settings.py b/tests/test_settings.py index 24b05d5d4..6b8100806 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,67 +1,67 @@ -# import os -# import pprint -# import random -# import string -# from datajoint import settings -# -# __author__ = 'Fabian Sinz' -# -# from nose.tools import assert_true, assert_raises, assert_equal, raises, assert_dict_equal -# import datajoint as dj -# -# -# def test_load_save(): -# dj.config.save('tmp.json') -# conf = dj.Config() -# conf.load('tmp.json') -# assert_true(conf == dj.config, 'Two config files do not match.') -# os.remove('tmp.json') -# -# def test_singleton(): -# dj.config.save('tmp.json') -# conf = dj.Config() -# conf.load('tmp.json') -# conf['dummy.val'] = 2 -# -# assert_true(conf == dj.config, 'Config does not behave like a singleton.') -# os.remove('tmp.json') -# -# -# @raises(ValueError) -# def test_nested_check(): -# dummy = {'dummy.testval': {'notallowed': 2}} -# dj.config.update(dummy) -# -# @raises(dj.DataJointError) -# def test_validator(): -# dj.config['database.port'] = 'harbor' -# -# def test_del(): -# dj.config['peter'] = 2 -# assert_true('peter' in dj.config) -# del dj.config['peter'] -# assert_true('peter' not in dj.config) -# -# def test_len(): -# assert_equal(len(dj.config), len(dj.config._conf)) -# -# def test_str(): -# assert_equal(str(dj.config), pprint.pformat(dj.config._conf, indent=4)) -# -# def test_repr(): -# assert_equal(repr(dj.config), pprint.pformat(dj.config._conf, indent=4)) -# -# @raises(ValueError) -# def test_nested_check2(): -# dj.config['dummy'] = {'dummy2':2} -# -# def test_save(): -# tmpfile = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(20)) -# moved = False -# if os.path.isfile(settings.LOCALCONFIG): -# os.rename(settings.LOCALCONFIG, tmpfile) -# moved = True -# dj.config.save() -# assert_true(os.path.isfile(settings.LOCALCONFIG)) -# if moved: -# os.rename(tmpfile, settings.LOCALCONFIG) +import os +import pprint +import random +import string +from datajoint import settings + +__author__ = 'Fabian Sinz' + +from nose.tools import assert_true, assert_raises, assert_equal, raises, assert_dict_equal +import datajoint as dj + + +def test_load_save(): + dj.config.save('tmp.json') + conf = dj.Config() + conf.load('tmp.json') + assert_true(conf == dj.config, 'Two config files do not match.') + os.remove('tmp.json') + +def test_singleton(): + dj.config.save('tmp.json') + conf = dj.Config() + conf.load('tmp.json') + conf['dummy.val'] = 2 + + assert_true(conf == dj.config, 'Config does not behave like a singleton.') + os.remove('tmp.json') + + +@raises(ValueError) +def test_nested_check(): + dummy = {'dummy.testval': {'notallowed': 2}} + dj.config.update(dummy) + +@raises(dj.DataJointError) +def test_validator(): + dj.config['database.port'] = 'harbor' + +def test_del(): + dj.config['peter'] = 2 + assert_true('peter' in dj.config) + del dj.config['peter'] + assert_true('peter' not in dj.config) + +def test_len(): + assert_equal(len(dj.config), len(dj.config._conf)) + +def test_str(): + assert_equal(str(dj.config), pprint.pformat(dj.config._conf, indent=4)) + +def test_repr(): + assert_equal(repr(dj.config), pprint.pformat(dj.config._conf, indent=4)) + +@raises(ValueError) +def test_nested_check2(): + dj.config['dummy'] = {'dummy2':2} + +def test_save(): + tmpfile = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(20)) + moved = False + if os.path.isfile(settings.LOCALCONFIG): + os.rename(settings.LOCALCONFIG, tmpfile) + moved = True + dj.config.save() + assert_true(os.path.isfile(settings.LOCALCONFIG)) + if moved: + os.rename(tmpfile, settings.LOCALCONFIG) From 52d4100eca6ca2e15592bb1ee8e6023a8fb7cdb9 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Mon, 8 Jun 2015 18:58:37 -0500 Subject: [PATCH 14/42] instantiating Subject fails --- datajoint/__init__.py | 3 ++- datajoint/relation.py | 32 +++++++++++----------- datajoint/user_relations.py | 14 ++++------ tests/schemata/__init__.py | 2 +- tests/schemata/{schema1 => }/test1.py | 39 ++++++++++++++------------- tests/test_relation.py | 6 +++++ 6 files changed, 51 insertions(+), 45 deletions(-) rename tests/schemata/{schema1 => }/test1.py (89%) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 8ebe93992..796ded488 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -42,4 +42,5 @@ class DataJointError(Exception): from .autopopulate import AutoPopulate from . import blob from .relational_operand import Not -from .heading import Heading \ No newline at end of file +from .heading import Heading +from .relation import schema \ No newline at end of file diff --git a/datajoint/relation.py b/datajoint/relation.py index 19c2ebac9..3c48690d9 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -18,7 +18,7 @@ SharedInfo = namedtuple( 'SharedInfo', - ('database', 'context', 'connection', 'heading', 'parents', 'children', 'references', 'referenced')) + ('database', 'context', 'connection', 'heading')) @@ -101,13 +101,13 @@ def table_name(cls): """ pass - @classproperty - @abc.abstractmethod - def database(cls): - """ - :return: string containing the database name on the server - """ - pass + # @classproperty + # @abc.abstractmethod + # def database(cls): + # """ + # :return: string containing the database name on the server + # """ + # pass @classproperty @abc.abstractmethod @@ -117,13 +117,13 @@ def definition(cls): """ pass - @classproperty - @abc.abstractmethod - def context(cls): - """ - :return: a dict with other relations that can be referenced by foreign keys - """ - pass + # @classproperty + # @abc.abstractmethod + # def context(cls): + # """ + # :return: a dict with other relations that can be referenced by foreign keys + # """ + # pass # --------- base relation functionality --------- # @classproperty @@ -132,7 +132,7 @@ def is_declared(cls): return True cur = cls._shared_info.connection.query( 'SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format( - table_name=cls.table_name)) + database=cls.database , table_name=cls.table_name)) return cur.rowcount == 1 @classproperty diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 60e013e6f..62c9951bc 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -1,31 +1,27 @@ -from datajoint.relation import Relation +from datajoint.relation import Relation, classproperty from .autopopulate import AutoPopulate from .utils import from_camel_case class Manual(Relation): - @property - @classmethod + @classproperty def table_name(cls): return from_camel_case(cls.__name__) class Lookup(Relation): - @property - @classmethod + @classproperty def table_name(cls): return '#' + from_camel_case(cls.__name__) class Imported(Relation, AutoPopulate): - @property - @classmethod + @classproperty def table_name(cls): return "_" + from_camel_case(cls.__name__) class Computed(Relation, AutoPopulate): - @property - @classmethod + @classproperty def table_name(cls): return "__" + from_camel_case(cls.__name__) \ No newline at end of file diff --git a/tests/schemata/__init__.py b/tests/schemata/__init__.py index 6f391d065..9fa8b9ad1 100644 --- a/tests/schemata/__init__.py +++ b/tests/schemata/__init__.py @@ -1 +1 @@ -__author__ = "eywalker" \ No newline at end of file +__author__ = "eywalker, fabiansinz" \ No newline at end of file diff --git a/tests/schemata/schema1/test1.py b/tests/schemata/test1.py similarity index 89% rename from tests/schemata/schema1/test1.py rename to tests/schemata/test1.py index d0d91707f..5b4ac723a 100644 --- a/tests/schemata/schema1/test1.py +++ b/tests/schemata/test1.py @@ -1,22 +1,25 @@ -# """ -# Test 1 Schema definition -# """ -# __author__ = 'eywalker' -# -# import datajoint as dj +""" +Test 1 Schema definition +""" +__author__ = 'eywalker' + +import datajoint as dj # from .. import schema2 -# -# -# class Subjects(dj.Relation): -# definition = """ -# test1.Subjects (manual) # Basic subject info -# -# subject_id : int # unique subject id -# --- -# real_id : varchar(40) # real-world name -# species = "mouse" : enum('mouse', 'monkey', 'human') # species -# """ -# +from .. import PREFIX + +testschema = dj.schema(PREFIX + '_test1', locals()) + +@testschema +class Subjects(dj.Manual): + definition = """ + # Basic subject info + + subject_id : int # unique subject id + --- + real_id : varchar(40) # real-world name + species = "mouse" : enum('mouse', 'monkey', 'human') # species + """ + # # test for shorthand # class Animals(dj.Relation): # definition = """ diff --git a/tests/test_relation.py b/tests/test_relation.py index 94adef010..76f87f707 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -4,6 +4,12 @@ # __author__ = 'fabee' # # from .schemata.schema1 import test1, test4 +from .schemata.test1 import Subjects + + +def test_instantiate_relation(): + s = Subjects() + # # from . import BASE_CONN, CONN_INFO, PREFIX, cleanup # from datajoint.connection import Connection From 2f0971ad4c892d722441528597df98b18b7fe286 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 09:49:24 -0500 Subject: [PATCH 15/42] implemented lookup_name using eval --- datajoint/autopopulate.py | 10 +++++----- datajoint/relation.py | 18 ++---------------- demos/demo1.py | 5 ----- 3 files changed, 7 insertions(+), 26 deletions(-) diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index ba349b250..5db012eff 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -53,29 +53,29 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, if not isinstance(self.populate_relation, RelationalOperand): raise DataJointError('Invalid populate_relation value') - self.conn.cancel_transaction() # rollback previous transaction, if any + self.connection.cancel_transaction() # rollback previous transaction, if any if not isinstance(self, Relation): raise DataJointError('Autopopulate is a mixin for Relation and must therefore subclass Relation') unpopulated = (self.populate_relation - self.target) & restriction for key in unpopulated.project(): - self.conn.start_transaction() + self.connection.start_transaction() if key in self.target: # already populated - self.conn.cancel_transaction() + self.connection.cancel_transaction() else: logger.info('Populating: ' + str(key)) try: self._make_tuples(dict(key)) except Exception as error: - self.conn.cancel_transaction() + self.connection.cancel_transaction() if not suppress_errors: raise else: logger.error(error) error_list.append((key, error)) else: - self.conn.commit_transaction() + self.connection.commit_transaction() logger.info('Done populating.') return error_list diff --git a/datajoint/relation.py b/datajoint/relation.py index 3c48690d9..4588040a1 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -21,7 +21,6 @@ ('database', 'context', 'connection', 'heading')) - class classproperty: def __init__(self, getf): self._getf = getf @@ -88,7 +87,7 @@ class Relation(RelationalOperand, metaclass=abc.ABCMeta): _shared_info = None - def __init__(self): # TODO: Think about it + def __init__(self): if self._shared_info is None: raise DataJointError('The class must define _shared_info') @@ -478,21 +477,8 @@ def _parse_declaration(cls): def lookup_name(cls, name): """ Lookup the referenced name in the context dictionary - - e.g. for reference `common.Animals`, it will first check if `context` dictionary contains key - `common`. If found, it then checks for attribute `Animals` in `common`, and returns the result. """ - parts = name.strip().split('.') - try: - ref = cls.context.get(parts[0]) - for attr in parts[1:]: - ref = getattr(ref, attr) - except (KeyError, AttributeError): - raise DataJointError( - 'Foreign key reference to %s could not be resolved.' - 'Please make sure the name exists' - 'in the context of the class' % name) - return ref + return eval(name, locals=cls.context) @classproperty def connection(cls): diff --git a/demos/demo1.py b/demos/demo1.py index e85d6ead3..4376a4982 100644 --- a/demos/demo1.py +++ b/demos/demo1.py @@ -1,9 +1,4 @@ # -*- coding: utf-8 -*- -""" -Created on Tue Aug 26 17:42:52 2014 - -@author: dimitri -""" import datajoint as dj From 6265e6877e87415abbabf146e63cc0436ef54811 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 09:55:22 -0500 Subject: [PATCH 16/42] moved from_camel_case to user_relations.py --- datajoint/user_relations.py | 24 ++++++++++++++++++++++-- datajoint/utils.py | 34 ---------------------------------- 2 files changed, 22 insertions(+), 36 deletions(-) diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 62c9951bc..8348b6ccc 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -1,7 +1,8 @@ +import re from datajoint.relation import Relation, classproperty from .autopopulate import AutoPopulate from .utils import from_camel_case - +from . import DataJointError class Manual(Relation): @classproperty @@ -24,4 +25,23 @@ def table_name(cls): class Computed(Relation, AutoPopulate): @classproperty def table_name(cls): - return "__" + from_camel_case(cls.__name__) \ No newline at end of file + return "__" + from_camel_case(cls.__name__) + + + +def from_camel_case(s): + """ + Convert names in camel case into underscore (_) separated names + + Example: + >>>from_camel_case("TableName") + "table_name" + """ + def convert(match): + return ('_' if match.groups()[0] else '') + match.group(0).lower() + + if not re.match(r'[A-Z][a-zA-Z0-9]*', s): + raise DataJointError( + 'ClassName must be alphanumeric in CamelCase, begin with a capital letter') + return re.sub(r'(\B[A-Z])|(\b[A-Z])', convert, s) + diff --git a/datajoint/utils.py b/datajoint/utils.py index ec506472e..f4b0edb57 100644 --- a/datajoint/utils.py +++ b/datajoint/utils.py @@ -1,37 +1,3 @@ -import re -from . import DataJointError - - -def to_camel_case(s): - """ - Convert names with under score (_) separation - into camel case names. - - Example: - >>>to_camel_case("table_name") - "TableName" - """ - def to_upper(match): - return match.group(0)[-1].upper() - return re.sub('(^|[_\W])+[a-zA-Z]', to_upper, s) - - -def from_camel_case(s): - """ - Convert names in camel case into underscore (_) separated names - - Example: - >>>from_camel_case("TableName") - "table_name" - """ - def convert(match): - return ('_' if match.groups()[0] else '') + match.group(0).lower() - - if not re.match(r'[A-Z][a-zA-Z0-9]*', s): - raise DataJointError( - 'ClassName must be alphanumeric in CamelCase, begin with a capital letter') - return re.sub(r'(\B[A-Z])|(\b[A-Z])', convert, s) - def user_choice(prompt, choices=("yes", "no"), default=None): """ From a48e66fe5650ee2f15a5812fad6c8eca141e6671 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 10:22:16 -0500 Subject: [PATCH 17/42] minor cleanup --- datajoint/relation.py | 86 +++++++++++++------------------------ datajoint/user_relations.py | 2 - 2 files changed, 31 insertions(+), 57 deletions(-) diff --git a/datajoint/relation.py b/datajoint/relation.py index 4588040a1..8c2ebdc2b 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -16,8 +16,8 @@ logger = logging.getLogger(__name__) -SharedInfo = namedtuple( - 'SharedInfo', +TableInfo = namedtuple( + 'TableInfo', ('database', 'context', 'connection', 'heading')) @@ -60,11 +60,11 @@ def schema(database, context, connection=None): " permissions.".format(database=database)) def decorator(cls): - cls._shared_info = SharedInfo( + cls._table_info = TableInfo( database=database, context=context, connection=connection, - heading=None, + heading=None ) cls.declare() return cls @@ -85,11 +85,28 @@ class Relation(RelationalOperand, metaclass=abc.ABCMeta): __heading = None - _shared_info = None + _table_info = None def __init__(self): - if self._shared_info is None: - raise DataJointError('The class must define _shared_info') + if self._table_info is None: + raise DataJointError('The class must define _table_info') + + @classproperty + def connection(cls): + """ + Returns the connection object of the class + + :return: the connection object + """ + return cls._table_info.connection + + @classproperty + def database(cls): + return cls._table_info.database + + @classproperty + def context(cls): + return cls._table_info.context # ---------- abstract properties ------------ # @classproperty @@ -100,14 +117,6 @@ def table_name(cls): """ pass - # @classproperty - # @abc.abstractmethod - # def database(cls): - # """ - # :return: string containing the database name on the server - # """ - # pass - @classproperty @abc.abstractmethod def definition(cls): @@ -116,20 +125,12 @@ def definition(cls): """ pass - # @classproperty - # @abc.abstractmethod - # def context(cls): - # """ - # :return: a dict with other relations that can be referenced by foreign keys - # """ - # pass - - # --------- base relation functionality --------- # + # --------- SQL functionality --------- # @classproperty def is_declared(cls): if cls.__heading is not None: return True - cur = cls._shared_info.connection.query( + cur = cls._table_info.connection.query( 'SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format( database=cls.database , table_name=cls.table_name)) return cur.rowcount == 1 @@ -151,7 +152,7 @@ def from_clause(cls): for the SQL SELECT statements. :return: """ - return '`%s`.`%s`' % (cls.database, cls.table_name) + return cls.full_table_name @classmethod def declare(cls): @@ -413,13 +414,13 @@ def _declare(cls): implicit_indices.append(fk_source.primary_key) # for index in indexDefs: - # TODO: finish this up... + # TODO: add index declaration # close the declaration sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % ( sql[:-2], table_info['comment']) - # # make sure that the table does not alredy exist + # # make sure that the table does not already exist # cls.load_heading() # if not cls.is_declared: # # execute declaration @@ -462,37 +463,12 @@ def _parse_declaration(cls): # foreign key ref_name = line[2:].strip() ref_list = parents if in_key else referenced - ref_list.append(cls.lookup_name(ref_name)) + ref_list.append(eval(ref_name, locals=cls.context)) elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): index_defs.append(parse_index_definition(line)) elif attribute_regexp.match(line): field_defs.append(parse_attribute_definition(line, in_key)) else: - raise DataJointError( - 'Invalid table declaration line "%s"' % line) + raise DataJointError('Invalid table declaration line "%s"' % line) return table_info, parents, referenced, field_defs, index_defs - - @classmethod - def lookup_name(cls, name): - """ - Lookup the referenced name in the context dictionary - """ - return eval(name, locals=cls.context) - - @classproperty - def connection(cls): - """ - Returns the connection object of the class - - :return: the connection object - """ - return cls._shared_info.connection - - @classproperty - def database(cls): - return cls._shared_info.database - - @classproperty - def context(cls): - return cls._shared_info.context diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 8348b6ccc..a615247ee 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -1,7 +1,6 @@ import re from datajoint.relation import Relation, classproperty from .autopopulate import AutoPopulate -from .utils import from_camel_case from . import DataJointError class Manual(Relation): @@ -28,7 +27,6 @@ def table_name(cls): return "__" + from_camel_case(cls.__name__) - def from_camel_case(s): """ Convert names in camel case into underscore (_) separated names From 622fd5607978aa5b20917143693c7e63fd7b3ddb Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 11:15:24 -0500 Subject: [PATCH 18/42] intermediate: removed classmethods from Relation. Moved all declaration functionality into declare.py --- datajoint/declare.py | 211 ++++++++++++++++++++++++++ datajoint/parsing.py | 88 ----------- datajoint/relation.py | 295 ++++++++---------------------------- datajoint/user_relations.py | 6 +- requirements.txt | 1 + 5 files changed, 273 insertions(+), 328 deletions(-) create mode 100644 datajoint/declare.py delete mode 100644 datajoint/parsing.py diff --git a/datajoint/declare.py b/datajoint/declare.py new file mode 100644 index 000000000..39d375701 --- /dev/null +++ b/datajoint/declare.py @@ -0,0 +1,211 @@ +import re +import pyparsing as pp +import logging + +from . import DataJointError + + +logger = logging.getLogger(__name__) + + +def compile_attribute(line, in_key=False): + """ + Convert attribute definition from DataJoint format to SQL + :param line: attribution line + :param in_key: set to True if attribute is in primary key set + :returns: attribute name and sql code for its declaration + """ + quoted = pp.Or(pp.QuotedString('"'), pp.QuotedString("'")) + colon = pp.Literal(':').suppress() + attribute_name = pp.Word(pp.srange('[a-z]'), pp.srange('[a-z0-9_]')).setResultsName('name') + + data_type = pp.Combine(pp.Word(pp.alphas)+pp.SkipTo("#", ignore=quoted)).setResultsName('type') + default = pp.Literal('=').suppress() + pp.SkipTo(colon, ignore=quoted).setResultsName('default') + comment = pp.Literal('#').suppress() + pp.restOfLine.setResultsName('comment') + + attribute_parser = attribute_name + pp.Optional(default) + colon + data_type + comment + + match = attribute_parser.parseString(line+'#', parseAll=True) + match['comment'] = match['comment'].rstrip('#') + if 'default' not in match: + match['default'] = '' + match = {k: v.strip() for k, v in match.items()} + match['nullable'] = match['default'].lower() == 'null' + + sql_literals = ['CURRENT_TIMESTAMP'] # not to be enclosed in quotes + assert not re.match(r'^bigint', match['type'], re.I) or not match['nullable'], \ + 'BIGINT attributes cannot be nullable in "%s"' % line # TODO: This was a MATLAB limitation. Handle this correctly. + if match['nullable']: + if in_key: + raise DataJointError('Primary key attributes cannot be nullable in line %s' % line) + match['default'] = 'DEFAULT NULL' # nullable attributes default to null + else: + if match['default']: + quote = match['default'].upper() not in sql_literals and match['default'][0] not in '"\'' + match['default'] = ('NOT NULL DEFAULT ' + + ('"%s"' if quote else "%s") % match['default']) + else: + match['default'] = 'NOT NULL' + match['comment'] = match['comment'].replace('"','\\"') # escape double quotes in comment + sql = ('`{name}` {type} {default}' + (' COMMENT "{comment}"' if match['comment'] else '') + ).format(**match) + return match['name'], sql + + +def parse_index(line): + """ + Parses index definition. + + :param line: definition line + :return: groupdict with index info + """ + line = line.strip() + index_regexp = re.compile(""" + ^(?PUNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX + \((?P[^\)]+)\)$ # (attr1, attr2) + """, re.I + re.X) + m = index_regexp.match(line) + assert m, 'Invalid index declaration "%s"' % line + index_info = m.groupdict() + attributes = re.split(r'\s*,\s*', index_info['attributes'].strip()) + index_info['attributes'] = attributes + assert len(attributes) == len(set(attributes)), \ + 'Duplicate attributes in index declaration "%s"' % line + return index_info + + +def parse_declaration(cls): + """ + Parse declaration and create new SQL table accordingly. + """ + parents = [] + referenced = [] + index_defs = [] + field_defs = [] + declaration = re.split(r'\s*\n\s*', cls.definition.strip()) + + # remove comment lines + declaration = [x for x in declaration if not x.startswith('#')] + ptrn = """ + \#\s*(?P.*)$ # comment + """ + p = re.compile(ptrn, re.X) + table_info = p.search(declaration[0]).groupdict() + + #table_info['tier'] = Role[table_info['tier']] # convert into enum + + in_key = True # parse primary keys + attribute_regexp = re.compile(""" + ^[a-z][a-z\d_]*\s* # name + (=\s*\S+(\s+\S+)*\s*)? # optional defaults + :\s*\w.*$ # type, comment + """, re.I + re.X) # ignore case and verbose + + for line in declaration[1:]: + if line.startswith('---'): + in_key = False # start parsing non-PK fields + elif line.startswith('->'): + # foreign key + ref_name = line[2:].strip() + ref_list = parents if in_key else referenced + ref_list.append(eval(ref_name, locals=cls.context)) + elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): + index_defs.append(parse_index_definition(line)) + elif attribute_regexp.match(line): + field_defs.append(parse_attribute_definition(line, in_key)) + else: + raise DataJointError('Invalid table declaration line "%s"' % line) + + return table_info, parents, referenced, field_defs, index_defs + + +def declare(base_relation): + """ + Declares the table in the database if it does not exist already + """ + cur = base_relation.connection.query( + 'SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format( + database=base_relation.database, table_name=base_relation.table_name)) + if cur.rowcount: + return + + if base_relation.connection.in_transaction: + raise DataJointError("Tables cannot be declared during a transaction.") + + if not base_relation.definition: + raise DataJointError('Table definition is missing.') + + table_info, parents, referenced, field_defs, index_defs = cls._parse_declaration() + + sql = 'CREATE TABLE %s (\n' % cls.full_table_name + + # add inherited primary key fields + primary_key_fields = set() + non_key_fields = set() + for p in parents: + for key in p.primary_key: + field = p.heading[key] + if field.name not in primary_key_fields: + primary_key_fields.add(field.name) + sql += field_to_sql(field) + else: + logger.debug('Field definition of {} in {} ignored'.format( + field.name, p.full_class_name)) + + # add newly defined primary key fields + for field in (f for f in field_defs if f.in_key): + if field.nullable: + raise DataJointError('Primary key attribute {} cannot be nullable'.format( + field.name)) + if field.name in primary_key_fields: + raise DataJointError('Duplicate declaration of the primary attribute {key}. ' + 'Ensure that the attribute is not already declared ' + 'in referenced tables'.format(key=field.name)) + primary_key_fields.add(field.name) + sql += field_to_sql(field) + + # add secondary foreign key attributes + for r in referenced: + for key in r.primary_key: + field = r.heading[key] + if field.name not in primary_key_fields | non_key_fields: + non_key_fields.add(field.name) + sql += field_to_sql(field) + + # add dependent attributes + for field in (f for f in field_defs if not f.in_key): + non_key_fields.add(field.name) + sql += field_to_sql(field) + + # add primary key declaration + assert len(primary_key_fields) > 0, 'table must have a primary key' + keys = ', '.join(primary_key_fields) + sql += 'PRIMARY KEY (%s),\n' % keys + + # add foreign key declarations + for ref in parents + referenced: + keys = ', '.join(ref.primary_key) + sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \ + (keys, ref.full_table_name, keys) + + # add secondary index declarations + # gather implicit indexes due to foreign keys first + implicit_indices = [] + for fk_source in parents + referenced: + implicit_indices.append(fk_source.primary_key) + + # for index in indexDefs: + # TODO: add index declaration + + # close the declaration + sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % ( + sql[:-2], table_info['comment']) + + # # make sure that the table does not already exist + # cls.load_heading() + # if not cls.is_declared: + # # execute declaration + # logger.debug('\n\n' + sql + '\n\n') + # cls.connection.query(sql) + # cls.load_heading() + diff --git a/datajoint/parsing.py b/datajoint/parsing.py deleted file mode 100644 index 85e367c96..000000000 --- a/datajoint/parsing.py +++ /dev/null @@ -1,88 +0,0 @@ -import re -from . import DataJointError -from .heading import Heading - - -def parse_attribute_definition(line, in_key=False): - """ - Parse attribute definition line in the declaration and returns - an attribute tuple. - - :param line: attribution line - :param in_key: set to True if attribute is in primary key set - :returns: attribute tuple - """ - line = line.strip() - attribute_regexp = re.compile(""" - ^(?P[a-z][a-z\d_]*)\s* # field name - (=\s*(?P\S+(\s+\S+)*?)\s*)? # default value - :\s*(?P\w[^\#]*[^\#\s])\s* # datatype - (\#\s*(?P\S*(\s+\S+)*)\s*)?$ # comment - """, re.X) - m = attribute_regexp.match(line) - if not m: - raise DataJointError('Invalid field declaration "%s"' % line) - attr_info = m.groupdict() - if not attr_info['comment']: - attr_info['comment'] = '' - if not attr_info['default']: - attr_info['default'] = '' - attr_info['nullable'] = attr_info['default'].lower() == 'null' - assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \ - 'BIGINT attributes cannot be nullable in "%s"' % line - - return Heading.AttrTuple( - in_key=in_key, - autoincrement=None, - numeric=None, - string=None, - is_blob=None, - computation=None, - dtype=None, - **attr_info - ) - - -def field_to_sql(field): # TODO move this into Attribute Tuple - """ - Converts an attribute definition tuple into SQL code. - :param field: attribute definition - :rtype : SQL code - """ - mysql_constants = ['CURRENT_TIMESTAMP'] - if field.nullable: - default = 'DEFAULT NULL' - else: - default = 'NOT NULL' - # if some default specified - if field.default: - # enclose value in quotes except special SQL values or already enclosed - quote = field.default.upper() not in mysql_constants and field.default[0] not in '"\'' - default += ' DEFAULT ' + ('"%s"' if quote else "%s") % field.default - if any((c in r'\"' for c in field.comment)): - raise DataJointError('Illegal characters in attribute comment "%s"' % field.comment) - - return '`{name}` {type} {default} COMMENT "{comment}",\n'.format( - name=field.name, type=field.type, default=default, comment=field.comment) - - -def parse_index_definition(line): - """ - Parses index definition. - - :param line: definition line - :return: groupdict with index info - """ - line = line.strip() - index_regexp = re.compile(""" - ^(?PUNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX - \((?P[^\)]+)\)$ # (attr1, attr2) - """, re.I + re.X) - m = index_regexp.match(line) - assert m, 'Invalid index declaration "%s"' % line - index_info = m.groupdict() - attributes = re.split(r'\s*,\s*', index_info['attributes'].strip()) - index_info['attributes'] = attributes - assert len(attributes) == len(set(attributes)), \ - 'Duplicate attributes in index declaration "%s"' % line - return index_info diff --git a/datajoint/relation.py b/datajoint/relation.py index 8c2ebdc2b..7b7acd5b6 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -1,47 +1,33 @@ from collections import namedtuple -from collections.abc import MutableMapping, Mapping +from collections.abc import Mapping import numpy as np import logging -import re import abc - -from . import DataJointError, config import pymysql -from datajoint import DataJointError, conn + + +from . import DataJointError, config, conn +from .declare import declare from .relational_operand import RelationalOperand from .blob import pack from .utils import user_choice -from .parsing import parse_attribute_definition, field_to_sql, parse_index_definition from .heading import Heading logger = logging.getLogger(__name__) -TableInfo = namedtuple( - 'TableInfo', +TableLink = namedtuple( + 'TableLink', ('database', 'context', 'connection', 'heading')) -class classproperty: - def __init__(self, getf): - self._getf = getf - - def __get__(self, instance, owner): - return self._getf(owner) - - def schema(database, context, connection=None): """ - Returns a schema decorator that can be used to associate a Relation class to a - specific database with :param name:. Name reference to other tables in the table definition - will be resolved by looking up the corresponding key entry in the passed in context dictionary. - It is most common to set context equal to the return value of call to locals() in the module. - For more details, please refer to the tutorial online. + Returns a decorator that can be used to associate a Relation class to a database. :param database: name of the database to associate the decorated class with - :param context: dictionary used to resolve (any) name references within the table definition string - :param connection: connection object to the database server. If ommited, will try to establish connection according to - config values - :return: a decorator function to be used on Relation derivative classes + :param context: dictionary for looking up foreign keys references, usually set to locals() + :param connection: Connection object. Defaults to datajoint.conn() + :return: a decorator for Relation subclasses """ if connection is None: connection = conn() @@ -60,13 +46,16 @@ def schema(database, context, connection=None): " permissions.".format(database=database)) def decorator(cls): - cls._table_info = TableInfo( + """ + The decorator declares the table and binds the class to the database table + """ + cls._table_info = TableLink( database=database, context=context, connection=connection, heading=None ) - cls.declare() + declare(cls()) return cls return decorator @@ -83,112 +72,82 @@ class Relation(RelationalOperand, metaclass=abc.ABCMeta): It also handles the table declaration based on its definition property """ - __heading = None - _table_info = None def __init__(self): if self._table_info is None: raise DataJointError('The class must define _table_info') - @classproperty - def connection(cls): - """ - Returns the connection object of the class - - :return: the connection object - """ - return cls._table_info.connection - - @classproperty - def database(cls): - return cls._table_info.database - - @classproperty - def context(cls): - return cls._table_info.context - # ---------- abstract properties ------------ # - @classproperty + @property @abc.abstractmethod - def table_name(cls): + def table_name(self): """ :return: the name of the table in the database """ pass - @classproperty + @property @abc.abstractmethod - def definition(cls): + def definition(self): """ :return: a string containing the table definition using the DataJoint DDL """ pass - # --------- SQL functionality --------- # - @classproperty - def is_declared(cls): - if cls.__heading is not None: - return True - cur = cls._table_info.connection.query( - 'SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format( - database=cls.database , table_name=cls.table_name)) - return cur.rowcount == 1 - - @classproperty - def heading(cls): - """ - Required by relational operand - :return: a datajoint.Heading object - """ - if cls.__heading is None: - cls.__heading = Heading.init_from_database(cls.connection, cls.database, cls.table_name) - return cls.__heading + # -------------- table info ----------------- # + @property + def connection(self): + return self._table_info.connection + + @property + def database(self): + return self._table_info.database + + @property + def context(self): + return self._table_info.context + + @property + def heading(self): + if self._table_info.heading is None: + self._table_info.heading = Heading.init_from_database( + self.connection, self.database, self.table_name) + return self._table_info.heading - @classproperty - def from_clause(cls): + + # --------- SQL functionality --------- # + @property + def from_clause(self): """ Required by the Relational class, this property specifies the contents of the FROM clause for the SQL SELECT statements. :return: """ - return cls.full_table_name - - @classmethod - def declare(cls): - """ - Declare the table in database if it doesn't already exist. + return self.full_table_name - :raises: DataJointError if the table cannot be declared. - """ - if not cls.is_declared: - cls._declare() - - @classmethod - def iter_insert(cls, rows, **kwargs): + def iter_insert(self, rows, **kwargs): """ Inserts an entire batch of entries. Additional keyword arguments are passed to insert. :param iter: Must be an iterator that generates a sequence of valid arguments for insert. """ for row in rows: - cls.insert(row, **kwargs) + self.insert(row, **kwargs) - @classmethod - def batch_insert(cls, data, **kwargs): + def batch_insert(self, data, **kwargs): """ Inserts an entire batch of entries. Additional keyword arguments are passed to insert. :param data: must be iterable, each row must be a valid argument for insert """ - cls.iter_insert(data.__iter__(), **kwargs) + self.iter_insert(data.__iter__(), **kwargs) - @classproperty - def full_table_name(cls): - return r"`{0:s}`.`{1:s}`".format(cls.database, cls.table_name) + def full_table_name(self): + return r"`{0:s}`.`{1:s}`".format(self.database, self.table_name) @classmethod - def insert(cls, tup, ignore_errors=False, replace=False): + def insert(self, tup, ignore_errors=False, replace=False): """ Insert one data record or one Mapping (like a dictionary). @@ -203,7 +162,7 @@ def insert(cls, tup, ignore_errors=False, replace=False): real_id = 1007, date_of_birth = "2014-09-01")) """ - heading = cls.heading + heading = self.heading if isinstance(tup, np.void): for fieldname in tup.dtype.fields: if fieldname not in heading: @@ -233,10 +192,9 @@ def insert(cls, tup, ignore_errors=False, replace=False): sql = 'INSERT IGNORE' else: sql = 'INSERT' - sql += " INTO %s (%s) VALUES (%s)" % (cls.full_table_name, - attribute_list, value_list) + sql += " INTO %s (%s) VALUES (%s)" % (self.from_caluse, attribute_list, value_list) logger.info(sql) - cls.connection.query(sql, args=args) + self.connection.query(sql, args=args) def delete(self): if not config['safemode'] or user_choice( @@ -244,22 +202,20 @@ def delete(self): "Proceed?", default='no') == 'yes': self.connection.query('DELETE FROM ' + self.from_clause + self.where_clause) # TODO: make cascading (issue #15) - @classmethod - def drop(cls): + def drop(self): """ Drops the table associated to this class. """ - if cls.is_declared: + if self.is_declared: if not config['safemode'] or user_choice( "You are about to drop an entire table. This operation cannot be undone.\n" "Proceed?", default='no') == 'yes': - cls.connection.query('DROP TABLE %s' % cls.full_table_name) # TODO: make cascading (issue #16) + self.connection.query('DROP TABLE %s' % cls.full_table_name) # TODO: make cascading (issue #16) # cls.connection.clear_dependencies(dbname=cls.dbname) #TODO: reimplement because clear_dependencies will be gone # cls.connection.load_headings(dbname=cls.dbname, force=True) #TODO: reimplement because load_headings is gone logger.info("Dropped table %s" % cls.full_table_name) - @classproperty - def size_on_disk(cls): + def size_on_disk(self): """ :return: size of data and indices in MiB taken by the table on the storage device """ @@ -335,140 +291,9 @@ def _alter(cls, alter_statement): :param alter_statement: alter statement """ if cls.connection.in_transaction: - raise DataJointError( - u"_alter is currently in transaction. Operation not allowed to avoid implicit commits.") + raise DataJointError("Table definition cannot be altered during a transaction.") sql = 'ALTER TABLE %s %s' % (cls.full_table_name, alter_statement) cls.connection.query(sql) cls.connection.load_headings(cls.dbname, force=True) - # TODO: place table definition sync mechanism - - @classmethod - def _declare(cls): - """ - Declares the table in the database if no table in the database matches this object. - """ - if cls.connection.in_transaction: - raise DataJointError( - u"_declare is currently in transaction. Operation not allowed to avoid implicit commits.") - - if not cls.definition: # if empty definition was supplied - raise DataJointError('Table definition is missing!') - table_info, parents, referenced, field_defs, index_defs = cls._parse_declaration() - - sql = 'CREATE TABLE %s (\n' % cls.full_table_name - - # add inherited primary key fields - primary_key_fields = set() - non_key_fields = set() - for p in parents: - for key in p.primary_key: - field = p.heading[key] - if field.name not in primary_key_fields: - primary_key_fields.add(field.name) - sql += field_to_sql(field) - else: - logger.debug('Field definition of {} in {} ignored'.format( - field.name, p.full_class_name)) - - # add newly defined primary key fields - for field in (f for f in field_defs if f.in_key): - if field.nullable: - raise DataJointError('Primary key attribute {} cannot be nullable'.format( - field.name)) - if field.name in primary_key_fields: - raise DataJointError('Duplicate declaration of the primary attribute {key}. ' - 'Ensure that the attribute is not already declared ' - 'in referenced tables'.format(key=field.name)) - primary_key_fields.add(field.name) - sql += field_to_sql(field) - - # add secondary foreign key attributes - for r in referenced: - for key in r.primary_key: - field = r.heading[key] - if field.name not in primary_key_fields | non_key_fields: - non_key_fields.add(field.name) - sql += field_to_sql(field) - - # add dependent attributes - for field in (f for f in field_defs if not f.in_key): - non_key_fields.add(field.name) - sql += field_to_sql(field) - - # add primary key declaration - assert len(primary_key_fields) > 0, 'table must have a primary key' - keys = ', '.join(primary_key_fields) - sql += 'PRIMARY KEY (%s),\n' % keys - - # add foreign key declarations - for ref in parents + referenced: - keys = ', '.join(ref.primary_key) - sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \ - (keys, ref.full_table_name, keys) - - # add secondary index declarations - # gather implicit indexes due to foreign keys first - implicit_indices = [] - for fk_source in parents + referenced: - implicit_indices.append(fk_source.primary_key) - - # for index in indexDefs: - # TODO: add index declaration - - # close the declaration - sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % ( - sql[:-2], table_info['comment']) - - # # make sure that the table does not already exist - # cls.load_heading() - # if not cls.is_declared: - # # execute declaration - # logger.debug('\n\n' + sql + '\n\n') - # cls.connection.query(sql) - # cls.load_heading() - - @classmethod - def _parse_declaration(cls): - """ - Parse declaration and create new SQL table accordingly. - """ - parents = [] - referenced = [] - index_defs = [] - field_defs = [] - declaration = re.split(r'\s*\n\s*', cls.definition.strip()) - - # remove comment lines - declaration = [x for x in declaration if not x.startswith('#')] - ptrn = """ - \#\s*(?P.*)$ # comment - """ - p = re.compile(ptrn, re.X) - table_info = p.search(declaration[0]).groupdict() - - #table_info['tier'] = Role[table_info['tier']] # convert into enum - - in_key = True # parse primary keys - attribute_regexp = re.compile(""" - ^[a-z][a-z\d_]*\s* # name - (=\s*\S+(\s+\S+)*\s*)? # optional defaults - :\s*\w.*$ # type, comment - """, re.I + re.X) # ignore case and verbose - - for line in declaration[1:]: - if line.startswith('---'): - in_key = False # start parsing non-PK fields - elif line.startswith('->'): - # foreign key - ref_name = line[2:].strip() - ref_list = parents if in_key else referenced - ref_list.append(eval(ref_name, locals=cls.context)) - elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): - index_defs.append(parse_index_definition(line)) - elif attribute_regexp.match(line): - field_defs.append(parse_attribute_definition(line, in_key)) - else: - raise DataJointError('Invalid table declaration line "%s"' % line) - - return table_info, parents, referenced, field_defs, index_defs + # TODO: place table definition sync mechanism \ No newline at end of file diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index a615247ee..a3429f457 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -1,28 +1,24 @@ import re -from datajoint.relation import Relation, classproperty +from datajoint.relation import Relation from .autopopulate import AutoPopulate from . import DataJointError class Manual(Relation): - @classproperty def table_name(cls): return from_camel_case(cls.__name__) class Lookup(Relation): - @classproperty def table_name(cls): return '#' + from_camel_case(cls.__name__) class Imported(Relation, AutoPopulate): - @classproperty def table_name(cls): return "_" + from_camel_case(cls.__name__) class Computed(Relation, AutoPopulate): - @classproperty def table_name(cls): return "__" + from_camel_case(cls.__name__) diff --git a/requirements.txt b/requirements.txt index 82abfa7ea..af4d48f13 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ numpy pymysql +pyparsing networkx matplotlib sphinx_rtd_theme From 3b8420e922919165e5fd3081cefec5241d389dcf Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 11:31:58 -0500 Subject: [PATCH 19/42] amending, removed remaining classmethods from Relation --- datajoint/declare.py | 2 +- datajoint/relation.py | 60 ++++++++++++++++--------------------- datajoint/user_relations.py | 16 +++++----- 3 files changed, 35 insertions(+), 43 deletions(-) diff --git a/datajoint/declare.py b/datajoint/declare.py index 39d375701..bde62a7e4 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -13,7 +13,7 @@ def compile_attribute(line, in_key=False): Convert attribute definition from DataJoint format to SQL :param line: attribution line :param in_key: set to True if attribute is in primary key set - :returns: attribute name and sql code for its declaration + :returns: (name, sql) -- attribute name and sql code for its declaration """ quoted = pp.Or(pp.QuotedString('"'), pp.QuotedString("'")) colon = pp.Literal(':').suppress() diff --git a/datajoint/relation.py b/datajoint/relation.py index 7b7acd5b6..5533853ee 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -5,13 +5,13 @@ import abc import pymysql - from . import DataJointError, config, conn from .declare import declare from .relational_operand import RelationalOperand from .blob import pack from .utils import user_choice from .heading import Heading +from .declare import compile_attribute logger = logging.getLogger(__name__) @@ -210,7 +210,7 @@ def drop(self): if not config['safemode'] or user_choice( "You are about to drop an entire table. This operation cannot be undone.\n" "Proceed?", default='no') == 'yes': - self.connection.query('DROP TABLE %s' % cls.full_table_name) # TODO: make cascading (issue #16) + self.connection.query('DROP TABLE %s' % self.full_table_name) # TODO: make cascading (issue #16) # cls.connection.clear_dependencies(dbname=cls.dbname) #TODO: reimplement because clear_dependencies will be gone # cls.connection.load_headings(dbname=cls.dbname, force=True) #TODO: reimplement because load_headings is gone logger.info("Dropped table %s" % cls.full_table_name) @@ -219,22 +219,20 @@ def size_on_disk(self): """ :return: size of data and indices in MiB taken by the table on the storage device """ - cur = cls.connection.query( - 'SHOW TABLE STATUS FROM `{dbname}` WHERE NAME="{table}"'.format( - dbname=cls.dbname, table=cls.table_name), as_dict=True) - ret = cur.fetchone() + ret = self.connection.query( + 'SHOW TABLE STATUS FROM `(database}` WHERE NAME="{table}"'.format( + database=self.database, table=self.table_name), as_dict=True + ).fetchone() return (ret['Data_length'] + ret['Index_length'])/1024**2 - @classmethod - def set_table_comment(cls, comment): + def set_table_comment(self, comment): """ Update the table comment in the table definition. :param comment: new comment as string """ - cls.alter('COMMENT="%s"' % comment) + self._alter('COMMENT="%s"' % comment) - @classmethod - def add_attribute(cls, definition, after=None): + def add_attribute(self, definition, after=None): """ Add a new attribute to the table. A full line from the table definition is passed in as definition. @@ -249,51 +247,45 @@ def add_attribute(cls, definition, after=None): """ position = ' FIRST' if after is None else ( ' AFTER %s' % after if after else '') - sql = field_to_sql(parse_attribute_definition(definition)) - cls._alter('ADD COLUMN %s%s' % (sql[:-2], position)) + sql = compile_attribute(definition)[1] + self._alter('ADD COLUMN %s%s' % (sql, position)) - @classmethod - def drop_attribute(cls, attr_name): + def drop_attribute(self, attribute_name): """ Drops the attribute attrName from this table. - - :param attr_name: Name of the attribute that is dropped. + :param attribute_name: Name of the attribute that is dropped. """ if not config['safemode'] or user_choice( "You are about to drop an attribute from a table." "This operation cannot be undone.\n" "Proceed?", default='no') == 'yes': - cls._alter('DROP COLUMN `%s`' % attr_name) + self._alter('DROP COLUMN `%s`' % attribute_name) - @classmethod - def alter_attribute(cls, attr_name, new_definition): + def alter_attribute(self, attribute_name, definition): """ Alter the definition of the field attr_name in this table using the new definition. - :param attr_name: field that is redefined - :param new_definition: new definition of the field + :param attribute_name: field that is redefined + :param definition: new definition of the field """ - sql = field_to_sql(parse_attribute_definition(new_definition)) - cls._alter('CHANGE COLUMN `%s` %s' % (attr_name, sql[:-2])) + sql = compile_attribute(definition)[1] + self._alter('CHANGE COLUMN `%s` %s' % (attribute_name, sql)) - @classmethod - def erd(cls, subset=None): + def erd(self, subset=None): """ Plot the schema's entity relationship diagram (ERD). """ + NotImplemented - @classmethod - def _alter(cls, alter_statement): + def _alter(self, alter_statement): """ Execute ALTER TABLE statement for this table. The schema will be reloaded within the connection object. :param alter_statement: alter statement """ - if cls.connection.in_transaction: + if self.connection.in_transaction: raise DataJointError("Table definition cannot be altered during a transaction.") - - sql = 'ALTER TABLE %s %s' % (cls.full_table_name, alter_statement) - cls.connection.query(sql) - cls.connection.load_headings(cls.dbname, force=True) - # TODO: place table definition sync mechanism \ No newline at end of file + sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement) + self.connection.query(sql) + self._table_info.heading = None diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index a3429f457..9431e28c2 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -4,23 +4,23 @@ from . import DataJointError class Manual(Relation): - def table_name(cls): - return from_camel_case(cls.__name__) + def table_name(self): + return from_camel_case(self.__class__.__name__) class Lookup(Relation): - def table_name(cls): - return '#' + from_camel_case(cls.__name__) + def table_name(self): + return '#' + from_camel_case(self.__class__.__name__) class Imported(Relation, AutoPopulate): - def table_name(cls): - return "_" + from_camel_case(cls.__name__) + def table_name(self): + return "_" + from_camel_case(self.__class__.__name__) class Computed(Relation, AutoPopulate): - def table_name(cls): - return "__" + from_camel_case(cls.__name__) + def table_name(self): + return "__" + from_camel_case(self.__class__.__name__) def from_camel_case(s): From dafcbbc98d1309e37f284f11c468bcc993a5619c Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 14:26:19 -0500 Subject: [PATCH 20/42] converted Heading.init_from_database into an instance method --- datajoint/connection.py | 5 --- datajoint/declare.py | 22 ++++++------- datajoint/heading.py | 12 +++---- datajoint/relation.py | 63 ++++++++++++++++--------------------- datajoint/user_relations.py | 6 ++++ 5 files changed, 50 insertions(+), 58 deletions(-) diff --git a/datajoint/connection.py b/datajoint/connection.py index 3ca01b6bd..0ba65b532 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -1,11 +1,6 @@ import pymysql -import re -from .utils import to_camel_case from . import DataJointError -from .heading import Heading -from .settings import prefix_to_role import logging -from .erd import DBConnGraph from . import config logger = logging.getLogger(__name__) diff --git a/datajoint/declare.py b/datajoint/declare.py index bde62a7e4..4551cd72a 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -32,7 +32,7 @@ def compile_attribute(line, in_key=False): match = {k: v.strip() for k, v in match.items()} match['nullable'] = match['default'].lower() == 'null' - sql_literals = ['CURRENT_TIMESTAMP'] # not to be enclosed in quotes + literals = ['CURRENT_TIMESTAMP'] # not to be enclosed in quotes assert not re.match(r'^bigint', match['type'], re.I) or not match['nullable'], \ 'BIGINT attributes cannot be nullable in "%s"' % line # TODO: This was a MATLAB limitation. Handle this correctly. if match['nullable']: @@ -41,12 +41,12 @@ def compile_attribute(line, in_key=False): match['default'] = 'DEFAULT NULL' # nullable attributes default to null else: if match['default']: - quote = match['default'].upper() not in sql_literals and match['default'][0] not in '"\'' + quote = match['default'].upper() not in literals and match['default'][0] not in '"\'' match['default'] = ('NOT NULL DEFAULT ' + ('"%s"' if quote else "%s") % match['default']) else: match['default'] = 'NOT NULL' - match['comment'] = match['comment'].replace('"','\\"') # escape double quotes in comment + match['comment'] = match['comment'].replace('"', '\\"') # escape double quotes in comment sql = ('`{name}` {type} {default}' + (' COMMENT "{comment}"' if match['comment'] else '') ).format(**match) return match['name'], sql @@ -110,7 +110,7 @@ def parse_declaration(cls): ref_list = parents if in_key else referenced ref_list.append(eval(ref_name, locals=cls.context)) elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): - index_defs.append(parse_index_definition(line)) + index_defs.append(parse_index(line)) elif attribute_regexp.match(line): field_defs.append(parse_attribute_definition(line, in_key)) else: @@ -119,23 +119,23 @@ def parse_declaration(cls): return table_info, parents, referenced, field_defs, index_defs -def declare(base_relation): +def declare(relation): """ Declares the table in the database if it does not exist already """ - cur = base_relation.connection.query( + cur = relation.connection.query( 'SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format( - database=base_relation.database, table_name=base_relation.table_name)) + database=relation.database, table_name=relation.table_name)) if cur.rowcount: return - if base_relation.connection.in_transaction: + if relation.connection.in_transaction: raise DataJointError("Tables cannot be declared during a transaction.") - if not base_relation.definition: - raise DataJointError('Table definition is missing.') + if not relation.definition: + raise DataJointError('Missing table definition.') - table_info, parents, referenced, field_defs, index_defs = cls._parse_declaration() + table_info, parents, referenced, field_defs, index_defs = parse_declaration() sql = 'CREATE TABLE %s (\n' % cls.full_table_name diff --git a/datajoint/heading.py b/datajoint/heading.py index 73d6c30f2..e37caf85e 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -17,11 +17,13 @@ class Heading: 'computation', 'dtype')) AttrTuple.as_dict = AttrTuple._asdict # renaming to make public - def __init__(self, attributes): + def __init__(self, attributes=None): """ :param attributes: a list of dicts with the same keys as AttrTuple """ - self.attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes]) + if attributes: + attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes]) + self.attributes = attributes @property def names(self): @@ -90,8 +92,7 @@ def items(self): def __iter__(self): return iter(self.attributes) - @classmethod - def init_from_database(cls, conn, database, table_name): + def init_from_database(self, conn, database, table_name): """ initialize heading from a database table """ @@ -163,8 +164,7 @@ def init_from_database(cls, conn, database, table_name): t = re.sub(r' unsigned$', '', t) # remove unsigned assert (t, is_unsigned) in numeric_types, 'dtype not found for type %s' % t attr['dtype'] = numeric_types[(t, is_unsigned)] - - return cls(attributes) + self.attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes]) def project(self, *attribute_list, **renamed_attributes): """ diff --git a/datajoint/relation.py b/datajoint/relation.py index 5533853ee..56c4eb11a 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -1,4 +1,4 @@ -from collections import namedtuple +from types import SimpleNamespace from collections.abc import Mapping import numpy as np import logging @@ -6,18 +6,16 @@ import pymysql from . import DataJointError, config, conn -from .declare import declare +from .declare import declare, compile_attribute from .relational_operand import RelationalOperand from .blob import pack from .utils import user_choice from .heading import Heading -from .declare import compile_attribute logger = logging.getLogger(__name__) -TableLink = namedtuple( - 'TableLink', - ('database', 'context', 'connection', 'heading')) +TableLink = namedtuple('TableLink', + ('database', 'context', 'connection', 'heading')) def schema(database, context, connection=None): @@ -49,11 +47,11 @@ def decorator(cls): """ The decorator declares the table and binds the class to the database table """ - cls._table_info = TableLink( + cls._table_link = TableLink( database=database, context=context, connection=connection, - heading=None + heading=Heading() ) declare(cls()) return cls @@ -72,11 +70,11 @@ class Relation(RelationalOperand, metaclass=abc.ABCMeta): It also handles the table declaration based on its definition property """ - _table_info = None + _table_link = None def __init__(self): - if self._table_info is None: - raise DataJointError('The class must define _table_info') + if self._table_link is None: + raise DataJointError('The class must define _table_link') # ---------- abstract properties ------------ # @property @@ -85,7 +83,7 @@ def table_name(self): """ :return: the name of the table in the database """ - pass + raise NotImplementedError('Relation subclasses must define property table_name') @property @abc.abstractmethod @@ -98,37 +96,34 @@ def definition(self): # -------------- table info ----------------- # @property def connection(self): - return self._table_info.connection + return self._table_link.connection @property def database(self): - return self._table_info.database + return self._table_link.database @property def context(self): - return self._table_info.context + return self._table_link.context @property def heading(self): - if self._table_info.heading is None: - self._table_info.heading = Heading.init_from_database( - self.connection, self.database, self.table_name) - return self._table_info.heading - + heading = self._table_link.heading + if not heading: + heading.init_from_database(self.connection, self.database, self.table_name) + return heading # --------- SQL functionality --------- # @property def from_clause(self): """ - Required by the Relational class, this property specifies the contents of the FROM clause - for the SQL SELECT statements. - :return: + :return: the FROM clause of SQL SELECT statements. """ return self.full_table_name def iter_insert(self, rows, **kwargs): """ - Inserts an entire batch of entries. Additional keyword arguments are passed to insert. + Inserts a collection of tuples. Additional keyword arguments are passed to insert. :param iter: Must be an iterator that generates a sequence of valid arguments for insert. """ @@ -172,18 +167,16 @@ def insert(self, tup, ignore_errors=False, replace=False): args = tuple(pack(tup[name]) for name in heading if name in tup.dtype.fields and heading[name].is_blob) - attribute_list = '`' + '`,`'.join( - [q for q in heading if q in tup.dtype.fields]) + '`' + attribute_list = '`' + '`,`'.join(q for q in heading if q in tup.dtype.fields) + '`' elif isinstance(tup, Mapping): for fieldname in tup.keys(): if fieldname not in heading: raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname, )) - value_list = ','.join([repr(tup[name]) if not heading[name].is_blob else '%s' - for name in heading if name in tup]) + value_list = ','.join(repr(tup[name]) if not heading[name].is_blob else '%s' + for name in heading if name in tup) args = tuple(pack(tup[name]) for name in heading if name in tup and heading[name].is_blob) - attribute_list = '`' + '`,`'.join( - [name for name in heading if name in tup]) + '`' + attribute_list = '`' + '`,`'.join(name for name in heading if name in tup) + '`' else: raise DataJointError('Datatype %s cannot be inserted' % type(tup)) if replace: @@ -213,7 +206,7 @@ def drop(self): self.connection.query('DROP TABLE %s' % self.full_table_name) # TODO: make cascading (issue #16) # cls.connection.clear_dependencies(dbname=cls.dbname) #TODO: reimplement because clear_dependencies will be gone # cls.connection.load_headings(dbname=cls.dbname, force=True) #TODO: reimplement because load_headings is gone - logger.info("Dropped table %s" % cls.full_table_name) + logger.info("Dropped table %s" % self.full_table_name) def size_on_disk(self): """ @@ -263,7 +256,7 @@ def drop_attribute(self, attribute_name): def alter_attribute(self, attribute_name, definition): """ - Alter the definition of the field attr_name in this table using the new definition. + Alter attribute definition :param attribute_name: field that is redefined :param definition: new definition of the field @@ -279,13 +272,11 @@ def erd(self, subset=None): def _alter(self, alter_statement): """ - Execute ALTER TABLE statement for this table. The schema - will be reloaded within the connection object. - + Execute an ALTER TABLE statement. :param alter_statement: alter statement """ if self.connection.in_transaction: raise DataJointError("Table definition cannot be altered during a transaction.") sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement) self.connection.query(sql) - self._table_info.heading = None + self._table_link.heading = None \ No newline at end of file diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 9431e28c2..767f87b33 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -3,26 +3,32 @@ from .autopopulate import AutoPopulate from . import DataJointError + class Manual(Relation): + @property def table_name(self): return from_camel_case(self.__class__.__name__) class Lookup(Relation): + @property def table_name(self): return '#' + from_camel_case(self.__class__.__name__) class Imported(Relation, AutoPopulate): + @property def table_name(self): return "_" + from_camel_case(self.__class__.__name__) class Computed(Relation, AutoPopulate): + @property def table_name(self): return "__" + from_camel_case(self.__class__.__name__) +# ---------------- utilities -------------------- def from_camel_case(s): """ Convert names in camel case into underscore (_) separated names From 76a94b461a5ab51b0c39843569551dcd3419d4b5 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 14:48:56 -0500 Subject: [PATCH 21/42] fixed heading initialization --- datajoint/declare.py | 4 ++-- datajoint/heading.py | 3 +++ datajoint/relation.py | 4 ++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/datajoint/declare.py b/datajoint/declare.py index 4551cd72a..444c5e4af 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -149,8 +149,8 @@ def declare(relation): primary_key_fields.add(field.name) sql += field_to_sql(field) else: - logger.debug('Field definition of {} in {} ignored'.format( - field.name, p.full_class_name)) + logger.debug( + 'Field definition of {} in {} ignored'.format(field.name, p.full_class_name)) # add newly defined primary key fields for field in (f for f in field_defs if f.in_key): diff --git a/datajoint/heading.py b/datajoint/heading.py index e37caf85e..6596a0673 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -25,6 +25,9 @@ def __init__(self, attributes=None): attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes]) self.attributes = attributes + def __bool__(self): + return self.attributes is not None + @property def names(self): return [k for k in self.attributes] diff --git a/datajoint/relation.py b/datajoint/relation.py index 56c4eb11a..c38f0ed87 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -1,4 +1,4 @@ -from types import SimpleNamespace +from collections import namedtuple from collections.abc import Mapping import numpy as np import logging @@ -138,10 +138,10 @@ def batch_insert(self, data, **kwargs): """ self.iter_insert(data.__iter__(), **kwargs) + @property def full_table_name(self): return r"`{0:s}`.`{1:s}`".format(self.database, self.table_name) - @classmethod def insert(self, tup, ignore_errors=False, replace=False): """ Insert one data record or one Mapping (like a dictionary). From bab7460d12c78c84f4954782bbd35be0f4fc91ce Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 15:13:04 -0500 Subject: [PATCH 22/42] implemented dj.Subordinate to support subtables --- datajoint/__init__.py | 2 +- datajoint/autopopulate.py | 6 +++--- datajoint/user_relations.py | 21 +++++++++++++++++---- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 796ded488..b562aa8d1 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -38,7 +38,7 @@ class DataJointError(Exception): # ------------- flatten import hierarchy ------------------------- from .connection import conn, Connection from .relation import Relation -from .user_relations import Manual, Lookup, Imported, Computed +from .user_relations import Manual, Lookup, Imported, Computed, Subordinate from .autopopulate import AutoPopulate from . import blob from .relational_operand import Not diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 5db012eff..cb4c0d673 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -37,7 +37,7 @@ def _make_tuples(self, key): def target(self): return self - def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, max_attempts=10): + def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): """ rel.populate() calls rel._make_tuples(key) for every primary key in self.populate_relation for which there is not already a tuple in rel. @@ -45,7 +45,6 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, :param restriction: restriction on rel.populate_relation - target :param suppress_errors: suppresses error if true :param reserve_jobs: currently not implemented - :param max_attempts: maximal number of times a TransactionError is caught before populate gives up """ assert not reserve_jobs, NotImplemented # issue #5 @@ -56,7 +55,8 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, self.connection.cancel_transaction() # rollback previous transaction, if any if not isinstance(self, Relation): - raise DataJointError('Autopopulate is a mixin for Relation and must therefore subclass Relation') + raise DataJointError( + 'AutoPopulate is a mixin for Relation and must therefore subclass Relation') unpopulated = (self.populate_relation - self.target) & restriction for key in unpopulated.project(): diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 767f87b33..8fa469504 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -1,33 +1,46 @@ import re +import abc from datajoint.relation import Relation from .autopopulate import AutoPopulate from . import DataJointError -class Manual(Relation): +class Manual(Relation, metaclass=abc.ABCMeta): @property def table_name(self): return from_camel_case(self.__class__.__name__) -class Lookup(Relation): +class Lookup(Relation, metaclass=abc.ABCMeta): @property def table_name(self): return '#' + from_camel_case(self.__class__.__name__) -class Imported(Relation, AutoPopulate): +class Imported(Relation, AutoPopulate, metaclass=abc.ABCMeta): @property def table_name(self): return "_" + from_camel_case(self.__class__.__name__) -class Computed(Relation, AutoPopulate): +class Computed(Relation, AutoPopulate, metaclass=abc.ABCMeta): @property def table_name(self): return "__" + from_camel_case(self.__class__.__name__) +class Subordinate: + """ + Mix-in to make computed tables subordinate + """ + @property + def populate_relation(self): + return None + + def _make_tuples(self, key): + raise NotImplementedError('_make_tuples not defined.') + + # ---------------- utilities -------------------- def from_camel_case(s): """ From 706dff0cf5174171ab93f2310c0cbc01f1034e6a Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 16:58:17 -0500 Subject: [PATCH 23/42] In Relation, split table_info attribute into its parts; made declare() independent --- datajoint/declare.py | 4 +-- datajoint/relation.py | 51 +++++++++++++------------------------ datajoint/user_relations.py | 10 ++++---- 3 files changed, 23 insertions(+), 42 deletions(-) diff --git a/datajoint/declare.py b/datajoint/declare.py index 444c5e4af..99d189464 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -2,8 +2,6 @@ import pyparsing as pp import logging -from . import DataJointError - logger = logging.getLogger(__name__) @@ -119,7 +117,7 @@ def parse_declaration(cls): return table_info, parents, referenced, field_defs, index_defs -def declare(relation): +def declare(full_table_name, definition, context): """ Declares the table in the database if it does not exist already """ diff --git a/datajoint/relation.py b/datajoint/relation.py index c38f0ed87..463ed0424 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -14,9 +14,6 @@ logger = logging.getLogger(__name__) -TableLink = namedtuple('TableLink', - ('database', 'context', 'connection', 'heading')) - def schema(database, context, connection=None): """ @@ -47,13 +44,16 @@ def decorator(cls): """ The decorator declares the table and binds the class to the database table """ - cls._table_link = TableLink( - database=database, - context=context, - connection=connection, - heading=Heading() - ) - declare(cls()) + cls.database = database + cls._connection = connection + cls._heading = Heading() + instance = cls() if isinstance(cls, type) else cls + if not cls.heading: + cls.connection.query( + declare( + table_name=instance.full_table_name, + definition=instance.definition, + context=context)) return cls return decorator @@ -66,16 +66,8 @@ class Relation(RelationalOperand, metaclass=abc.ABCMeta): To make it a concrete class, override the abstract properties specifying the connection, table name, database, context, and definition. A Relation implements insert and delete methods in addition to inherited relational operators. - It also loads table heading and dependencies from the database. - It also handles the table declaration based on its definition property """ - _table_link = None - - def __init__(self): - if self._table_link is None: - raise DataJointError('The class must define _table_link') - # ---------- abstract properties ------------ # @property @abc.abstractmethod @@ -93,27 +85,17 @@ def definition(self): """ pass - # -------------- table info ----------------- # + # -------------- required by RelationalOperand ----------------- # @property def connection(self): - return self._table_link.connection - - @property - def database(self): - return self._table_link.database - - @property - def context(self): - return self._table_link.context + return self._connection @property def heading(self): - heading = self._table_link.heading - if not heading: - heading.init_from_database(self.connection, self.database, self.table_name) - return heading + if not self._heading: + self._heading.init_from_database(self.connection, self.database, self.table_name) + return self._heading - # --------- SQL functionality --------- # @property def from_clause(self): """ @@ -130,6 +112,7 @@ def iter_insert(self, rows, **kwargs): for row in rows: self.insert(row, **kwargs) + # --------- SQL functionality --------- # def batch_insert(self, data, **kwargs): """ Inserts an entire batch of entries. Additional keyword arguments are passed to insert. @@ -279,4 +262,4 @@ def _alter(self, alter_statement): raise DataJointError("Table definition cannot be altered during a transaction.") sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement) self.connection.query(sql) - self._table_link.heading = None \ No newline at end of file + self.heading.init_from_database(self.connection, self.database, self.table_name) \ No newline at end of file diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 8fa469504..1b0352a03 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -5,25 +5,25 @@ from . import DataJointError -class Manual(Relation, metaclass=abc.ABCMeta): +class Manual(Relation): @property def table_name(self): return from_camel_case(self.__class__.__name__) -class Lookup(Relation, metaclass=abc.ABCMeta): +class Lookup(Relation): @property def table_name(self): return '#' + from_camel_case(self.__class__.__name__) -class Imported(Relation, AutoPopulate, metaclass=abc.ABCMeta): +class Imported(Relation, AutoPopulate): @property def table_name(self): return "_" + from_camel_case(self.__class__.__name__) -class Computed(Relation, AutoPopulate, metaclass=abc.ABCMeta): +class Computed(Relation, AutoPopulate): @property def table_name(self): return "__" + from_camel_case(self.__class__.__name__) @@ -38,7 +38,7 @@ def populate_relation(self): return None def _make_tuples(self, key): - raise NotImplementedError('_make_tuples not defined.') + raise NotImplementedError('Subtables should not be populated directly.') # ---------------- utilities -------------------- From 2b261fcbffe54506a388928caa860eacfca3069a Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 17:20:47 -0500 Subject: [PATCH 24/42] typo --- datajoint/relation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datajoint/relation.py b/datajoint/relation.py index 463ed0424..98fa24309 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -168,7 +168,7 @@ def insert(self, tup, ignore_errors=False, replace=False): sql = 'INSERT IGNORE' else: sql = 'INSERT' - sql += " INTO %s (%s) VALUES (%s)" % (self.from_caluse, attribute_list, value_list) + sql += " INTO %s (%s) VALUES (%s)" % (self.from_clause, attribute_list, value_list) logger.info(sql) self.connection.query(sql, args=args) From de8f0098e211bbdeb9bafeb9e312046b06c21fca Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 17:26:29 -0500 Subject: [PATCH 25/42] typo --- datajoint/relation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datajoint/relation.py b/datajoint/relation.py index 98fa24309..958e6ed27 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -196,7 +196,7 @@ def size_on_disk(self): :return: size of data and indices in MiB taken by the table on the storage device """ ret = self.connection.query( - 'SHOW TABLE STATUS FROM `(database}` WHERE NAME="{table}"'.format( + 'SHOW TABLE STATUS FROM `{database}` WHERE NAME="{table}"'.format( database=self.database, table=self.table_name), as_dict=True ).fetchone() return (ret['Data_length'] + ret['Index_length'])/1024**2 From c8b9400b313d69196529dc95fe572f4b8fa051b8 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 12 Jun 2015 15:56:51 -0500 Subject: [PATCH 26/42] debugged and optimized table declaration in declare.py --- datajoint/declare.py | 221 ++++++++++++------------------------------ datajoint/heading.py | 60 +++++++++--- datajoint/relation.py | 17 +++- demos/demo1.py | 38 +++++--- 4 files changed, 140 insertions(+), 196 deletions(-) diff --git a/datajoint/declare.py b/datajoint/declare.py index 99d189464..b3c2968ca 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -6,6 +6,66 @@ logger = logging.getLogger(__name__) + +def declare(full_table_name, definition, context): + """ + Parse declaration and create new SQL table accordingly. + """ + # split definition into lines + definition = re.split(r'\s*\n\s*', definition.strip()) + + table_comment = definition.pop(0)[1:] if definition[0].startswith('#') else '' + + in_key = True # parse primary keys + primary_key = [] + attributes = [] + attribute_sql = [] + foreign_key_sql = [] + index_sql = [] + + for line in definition: + if line.startswith('#'): # additional comments are ignored + pass + elif line.startswith('---'): + in_key = False # start parsing dependent attributes + elif line.startswith('->'): + # foreign key + ref = eval(line[2:], context)() + foreign_key_sql.append( + 'FOREIGN KEY ({primary_key})' + ' REFERENCES {ref} ({primary_key})' + ' ON UPDATE CASCADE ON DELETE RESTRICT'.format( + primary_key='`' + '`,`'.join(primary_key) + '`', ref=ref.full_table_name) + ) + for name in ref.primary_key: + if in_key and name not in primary_key: + primary_key.append(name) + if name not in attributes: + attributes.append(name) + attribute_sql.append(ref.heading[name].sql()) + elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): # index + index_sql.append(line) # the SQL syntax is identical to DataJoint's + else: + name, sql = compile_attribute(line, in_key) + if in_key and name not in primary_key: + primary_key.append(name) + if name not in attributes: + attributes.append(name) + attribute_sql.append(sql) + + # compile SQL + sql = 'CREATE TABLE %s (\n ' % full_table_name + sql += ', \n'.join(attribute_sql) + if foreign_key_sql: + sql += ', \n' + ', \n'.join(foreign_key_sql) + if index_sql: + sql += ', \n' + ', \n'.join(index_sql) + sql += '\n) ENGINE = InnoDB, COMMENT "%s"' % table_comment + return sql + + + + def compile_attribute(line, in_key=False): """ Convert attribute definition from DataJoint format to SQL @@ -31,8 +91,6 @@ def compile_attribute(line, in_key=False): match['nullable'] = match['default'].lower() == 'null' literals = ['CURRENT_TIMESTAMP'] # not to be enclosed in quotes - assert not re.match(r'^bigint', match['type'], re.I) or not match['nullable'], \ - 'BIGINT attributes cannot be nullable in "%s"' % line # TODO: This was a MATLAB limitation. Handle this correctly. if match['nullable']: if in_key: raise DataJointError('Primary key attributes cannot be nullable in line %s' % line) @@ -48,162 +106,3 @@ def compile_attribute(line, in_key=False): sql = ('`{name}` {type} {default}' + (' COMMENT "{comment}"' if match['comment'] else '') ).format(**match) return match['name'], sql - - -def parse_index(line): - """ - Parses index definition. - - :param line: definition line - :return: groupdict with index info - """ - line = line.strip() - index_regexp = re.compile(""" - ^(?PUNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX - \((?P[^\)]+)\)$ # (attr1, attr2) - """, re.I + re.X) - m = index_regexp.match(line) - assert m, 'Invalid index declaration "%s"' % line - index_info = m.groupdict() - attributes = re.split(r'\s*,\s*', index_info['attributes'].strip()) - index_info['attributes'] = attributes - assert len(attributes) == len(set(attributes)), \ - 'Duplicate attributes in index declaration "%s"' % line - return index_info - - -def parse_declaration(cls): - """ - Parse declaration and create new SQL table accordingly. - """ - parents = [] - referenced = [] - index_defs = [] - field_defs = [] - declaration = re.split(r'\s*\n\s*', cls.definition.strip()) - - # remove comment lines - declaration = [x for x in declaration if not x.startswith('#')] - ptrn = """ - \#\s*(?P.*)$ # comment - """ - p = re.compile(ptrn, re.X) - table_info = p.search(declaration[0]).groupdict() - - #table_info['tier'] = Role[table_info['tier']] # convert into enum - - in_key = True # parse primary keys - attribute_regexp = re.compile(""" - ^[a-z][a-z\d_]*\s* # name - (=\s*\S+(\s+\S+)*\s*)? # optional defaults - :\s*\w.*$ # type, comment - """, re.I + re.X) # ignore case and verbose - - for line in declaration[1:]: - if line.startswith('---'): - in_key = False # start parsing non-PK fields - elif line.startswith('->'): - # foreign key - ref_name = line[2:].strip() - ref_list = parents if in_key else referenced - ref_list.append(eval(ref_name, locals=cls.context)) - elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): - index_defs.append(parse_index(line)) - elif attribute_regexp.match(line): - field_defs.append(parse_attribute_definition(line, in_key)) - else: - raise DataJointError('Invalid table declaration line "%s"' % line) - - return table_info, parents, referenced, field_defs, index_defs - - -def declare(full_table_name, definition, context): - """ - Declares the table in the database if it does not exist already - """ - cur = relation.connection.query( - 'SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format( - database=relation.database, table_name=relation.table_name)) - if cur.rowcount: - return - - if relation.connection.in_transaction: - raise DataJointError("Tables cannot be declared during a transaction.") - - if not relation.definition: - raise DataJointError('Missing table definition.') - - table_info, parents, referenced, field_defs, index_defs = parse_declaration() - - sql = 'CREATE TABLE %s (\n' % cls.full_table_name - - # add inherited primary key fields - primary_key_fields = set() - non_key_fields = set() - for p in parents: - for key in p.primary_key: - field = p.heading[key] - if field.name not in primary_key_fields: - primary_key_fields.add(field.name) - sql += field_to_sql(field) - else: - logger.debug( - 'Field definition of {} in {} ignored'.format(field.name, p.full_class_name)) - - # add newly defined primary key fields - for field in (f for f in field_defs if f.in_key): - if field.nullable: - raise DataJointError('Primary key attribute {} cannot be nullable'.format( - field.name)) - if field.name in primary_key_fields: - raise DataJointError('Duplicate declaration of the primary attribute {key}. ' - 'Ensure that the attribute is not already declared ' - 'in referenced tables'.format(key=field.name)) - primary_key_fields.add(field.name) - sql += field_to_sql(field) - - # add secondary foreign key attributes - for r in referenced: - for key in r.primary_key: - field = r.heading[key] - if field.name not in primary_key_fields | non_key_fields: - non_key_fields.add(field.name) - sql += field_to_sql(field) - - # add dependent attributes - for field in (f for f in field_defs if not f.in_key): - non_key_fields.add(field.name) - sql += field_to_sql(field) - - # add primary key declaration - assert len(primary_key_fields) > 0, 'table must have a primary key' - keys = ', '.join(primary_key_fields) - sql += 'PRIMARY KEY (%s),\n' % keys - - # add foreign key declarations - for ref in parents + referenced: - keys = ', '.join(ref.primary_key) - sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \ - (keys, ref.full_table_name, keys) - - # add secondary index declarations - # gather implicit indexes due to foreign keys first - implicit_indices = [] - for fk_source in parents + referenced: - implicit_indices.append(fk_source.primary_key) - - # for index in indexDefs: - # TODO: add index declaration - - # close the declaration - sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % ( - sql[:-2], table_info['comment']) - - # # make sure that the table does not already exist - # cls.load_heading() - # if not cls.is_declared: - # # execute declaration - # logger.debug('\n\n' + sql + '\n\n') - # cls.connection.query(sql) - # cls.load_heading() - diff --git a/datajoint/heading.py b/datajoint/heading.py index 6596a0673..71ba2b8cf 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -11,11 +11,38 @@ class Heading: Heading contains the property attributes, which is an OrderedDict in which the keys are the attribute names and the values are AttrTuples. """ - AttrTuple = namedtuple('AttrTuple', + + class AttrTuple(namedtuple('AttrTuple', ('name', 'type', 'in_key', 'nullable', 'default', 'comment', 'autoincrement', 'numeric', 'string', 'is_blob', - 'computation', 'dtype')) - AttrTuple.as_dict = AttrTuple._asdict # renaming to make public + 'computation', 'dtype'))): + def _asdict(self): + """ + for some reason the inherted _asdict does not work after subclassing from namedtuple + """ + return OrderedDict((name, self[i]) for i, name in enumerate(self._fields)) + + + def sql(self): + """ + Convert attribute tuple into its SQL CREATE TABLE clause. + :rtype : SQL code + """ + literals = ['CURRENT_TIMESTAMP'] + if self.nullable: + default = 'DEFAULT NULL' + else: + default = 'NOT NULL' + if self.default: + # enclose value in quotes except special SQL values or already enclosed + quote = self.default.upper() not in literals and self.default[0] not in '"\'' + default += ' DEFAULT ' + ('"%s"' if quote else "%s") % self.default + if any((c in r'\"' for c in self.comment)): + raise DataJointError('Illegal characters in attribute comment "%s"' % self.comment) + return '`{name}` {type} {default} COMMENT "{comment}"'.format( + name=self.name, type=self.type, default=default, comment=self.comment) + + def __init__(self, attributes=None): """ @@ -57,12 +84,14 @@ def __getitem__(self, name): return self.attributes[name] def __repr__(self): - autoincrement_string = {False: '', True: ' auto_increment'} - return '\n'.join(['%-20s : %-28s # %s' % ( - k if v.default is None else '%s="%s"' % (k, v.default), - '%s%s' % (v.type, autoincrement_string[v.autoincrement]), - v.comment) - for k, v in self.attributes.items()]) + if self.attributes is None: + return 'Empty heading' + else: + return '\n'.join(['%-20s : %-28s # %s' % ( + k if v.default is None else '%s="%s"' % (k, v.default), + '%s%s' % (v.type, 'auto_increment' if v.autoincrement else ''), + v.comment) + for k, v in self.attributes.items()]) @property def as_dtype(self): @@ -97,11 +126,12 @@ def __iter__(self): def init_from_database(self, conn, database, table_name): """ - initialize heading from a database table + initialize heading from a database table. The table must exist already. """ cur = conn.query( 'SHOW FULL COLUMNS FROM `{table_name}` IN `{database}`'.format( table_name=table_name, database=database), as_dict=True) + attributes = cur.fetchall() rename_map = { @@ -184,14 +214,14 @@ def project(self, *attribute_list, **renamed_attributes): attribute_list = self.primary_key + [a for a in attribute_list if a not in self.primary_key] # convert attribute_list into a list of dicts but exclude renamed attributes - attribute_list = [v.as_dict() for k, v in self.attributes.items() + attribute_list = [v._asdict() for k, v in self.attributes.items() if k in attribute_list and k not in renamed_attributes.values()] # add renamed and computed attributes for new_name, computation in renamed_attributes.items(): if computation in self.names: # renamed attribute - new_attr = self.attributes[computation].as_dict() + new_attr = self.attributes[computation]._asdict() new_attr['name'] = new_name new_attr['computation'] = '`' + computation + '`' else: @@ -218,14 +248,14 @@ def __add__(self, other): join two headings """ assert isinstance(other, Heading) - attribute_list = [v.as_dict() for v in self.attributes.values()] + attribute_list = [v._asdict() for v in self.attributes.values()] for name in other.names: if name not in self.names: - attribute_list.append(other.attributes[name].as_dict()) + attribute_list.append(other.attributes[name]._asdict()) return Heading(attribute_list) def resolve(self): """ Remove attribute computations after they have been resolved in a subquery """ - return Heading([dict(v.as_dict(), computation=None) for v in self.attributes.values()]) \ No newline at end of file + return Heading([dict(v._asdict(), computation=None) for v in self.attributes.values()]) \ No newline at end of file diff --git a/datajoint/relation.py b/datajoint/relation.py index 958e6ed27..8f558831f 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -48,10 +48,10 @@ def decorator(cls): cls._connection = connection cls._heading = Heading() instance = cls() if isinstance(cls, type) else cls - if not cls.heading: - cls.connection.query( + if not instance.heading: + connection.query( declare( - table_name=instance.full_table_name, + full_table_name=instance.full_table_name, definition=instance.definition, context=context)) return cls @@ -92,7 +92,7 @@ def connection(self): @property def heading(self): - if not self._heading: + if not self._heading and self.is_declared: self._heading.init_from_database(self.connection, self.database, self.table_name) return self._heading @@ -113,6 +113,13 @@ def iter_insert(self, rows, **kwargs): self.insert(row, **kwargs) # --------- SQL functionality --------- # + @property + def is_declared(self): + cur = self.connection.query( + 'SHOW TABLES in `{database}`LIKE "{table_name}"'.format( + database=self.database, table_name=self.table_name)) + return cur.rowcount>0 + def batch_insert(self, data, **kwargs): """ Inserts an entire batch of entries. Additional keyword arguments are passed to insert. @@ -193,7 +200,7 @@ def drop(self): def size_on_disk(self): """ - :return: size of data and indices in MiB taken by the table on the storage device + :return: size of data and indices in GiB taken by the table on the storage device """ ret = self.connection.query( 'SHOW TABLE STATUS FROM `{database}` WHERE NAME="{table}"'.format( diff --git a/demos/demo1.py b/demos/demo1.py index 4376a4982..d10dc30e3 100644 --- a/demos/demo1.py +++ b/demos/demo1.py @@ -4,13 +4,12 @@ print("Welcome to the database 'demo1'") -conn = dj.conn() # connect to database; conn must be defined in module namespace -conn.bind(module=__name__, dbname='dj_test') # bind this module to the database +schema = dj.schema('dj_test', locals()) - -class Subject(dj.Relation): +@schema +class Subject(dj.Manual): definition = """ - demo1.Subject (manual) # Basic subject info + # Basic subject info subject_id : int # internal subject id --- real_id : varchar(40) # real-world name @@ -22,9 +21,10 @@ class Subject(dj.Relation): """ -class Experiment(dj.Relation): +@schema +class Experiment(dj.Manual): definition = """ - demo1.Experiment (manual) # Basic subject info + # Basic subject info -> demo1.Subject experiment : smallint # experiment number for this subject --- @@ -35,9 +35,11 @@ class Experiment(dj.Relation): """ -class Session(dj.Relation): +@schema +class Session(dj.Manual): definition = """ - demo1.Session (manual) # a two-photon imaging session + # a two-photon imaging session + -> demo1.Experiment session_id : tinyint # two-photon session within this experiment ----------- @@ -46,9 +48,11 @@ class Session(dj.Relation): """ -class Scan(dj.Relation): +@schema +class Scan(dj.Manual): definition = """ - demo1.Scan (manual) # a two-photon imaging session + # a two-photon imaging session + -> demo1.Session -> Config scan_id : tinyint # two-photon session within this experiment @@ -58,16 +62,20 @@ class Scan(dj.Relation): mwatts: numeric(4,1) # (mW) laser power to brain """ -class Config(dj.Relation): +@schema +class Config(dj.Manual): definition = """ - demo1.Config (manual) # configuration for scanner + # configuration for scanner + config_id : tinyint # unique id for config setup --- ->ConfigParam """ -class ConfigParam(dj.Relation): +@schema +class ConfigParam(dj.Manual): definition = """ - demo1.ConfigParam (lookup) # params for configurations + # params for configurations + param_set_id : tinyint # id for params """ \ No newline at end of file From d0690f29e6b99a4e48b4bbb6d5e1ed72d0c1c8d6 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 12 Jun 2015 16:40:24 -0500 Subject: [PATCH 27/42] fixed primary key declaration --- datajoint/declare.py | 11 ++++++++--- datajoint/heading.py | 3 --- demos/demo1.py | 28 +++++++--------------------- 3 files changed, 15 insertions(+), 27 deletions(-) diff --git a/datajoint/declare.py b/datajoint/declare.py index b3c2968ca..c0c8e44a9 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -2,6 +2,8 @@ import pyparsing as pp import logging +from . import DataJointError + logger = logging.getLogger(__name__) @@ -14,7 +16,7 @@ def declare(full_table_name, definition, context): # split definition into lines definition = re.split(r'\s*\n\s*', definition.strip()) - table_comment = definition.pop(0)[1:] if definition[0].startswith('#') else '' + table_comment = definition.pop(0)[1:].strip() if definition[0].startswith('#') else '' in_key = True # parse primary keys primary_key = [] @@ -35,7 +37,7 @@ def declare(full_table_name, definition, context): 'FOREIGN KEY ({primary_key})' ' REFERENCES {ref} ({primary_key})' ' ON UPDATE CASCADE ON DELETE RESTRICT'.format( - primary_key='`' + '`,`'.join(primary_key) + '`', ref=ref.full_table_name) + primary_key='`' + '`,`'.join(ref.primary_key) + '`', ref=ref.full_table_name) ) for name in ref.primary_key: if in_key and name not in primary_key: @@ -54,8 +56,11 @@ def declare(full_table_name, definition, context): attribute_sql.append(sql) # compile SQL + if not primary_key: + raise DataJointError('Table must have a primary key') sql = 'CREATE TABLE %s (\n ' % full_table_name - sql += ', \n'.join(attribute_sql) + sql += ',\n '.join(attribute_sql) + sql += ',\n PRIMARY KEY (`' + '`,`'.join(primary_key) + '`)' if foreign_key_sql: sql += ', \n' + ', \n'.join(foreign_key_sql) if index_sql: diff --git a/datajoint/heading.py b/datajoint/heading.py index 71ba2b8cf..5c332fff6 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -22,7 +22,6 @@ def _asdict(self): """ return OrderedDict((name, self[i]) for i, name in enumerate(self._fields)) - def sql(self): """ Convert attribute tuple into its SQL CREATE TABLE clause. @@ -42,8 +41,6 @@ def sql(self): return '`{name}` {type} {default} COMMENT "{comment}"'.format( name=self.name, type=self.type, default=default, comment=self.comment) - - def __init__(self, attributes=None): """ :param attributes: a list of dicts with the same keys as AttrTuple diff --git a/demos/demo1.py b/demos/demo1.py index d10dc30e3..46ab53fa6 100644 --- a/demos/demo1.py +++ b/demos/demo1.py @@ -20,12 +20,16 @@ class Subject(dj.Manual): animal_notes="" : varchar(4096) # strain, genetic manipulations, etc """ +s = Subject() +p = s.primary_key + @schema class Experiment(dj.Manual): definition = """ # Basic subject info - -> demo1.Subject + + -> Subject experiment : smallint # experiment number for this subject --- experiment_folder : varchar(255) # folder path @@ -40,7 +44,7 @@ class Session(dj.Manual): definition = """ # a two-photon imaging session - -> demo1.Experiment + -> Experiment session_id : tinyint # two-photon session within this experiment ----------- setup : tinyint # experimental setup @@ -53,8 +57,7 @@ class Scan(dj.Manual): definition = """ # a two-photon imaging session - -> demo1.Session - -> Config + -> Session scan_id : tinyint # two-photon session within this experiment ---- depth : float # depth from surface @@ -62,20 +65,3 @@ class Scan(dj.Manual): mwatts: numeric(4,1) # (mW) laser power to brain """ -@schema -class Config(dj.Manual): - definition = """ - # configuration for scanner - - config_id : tinyint # unique id for config setup - --- - ->ConfigParam - """ - -@schema -class ConfigParam(dj.Manual): - definition = """ - # params for configurations - - param_set_id : tinyint # id for params - """ \ No newline at end of file From 86abe762c78638008b4bb39103a7c4cce52403fc Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Fri, 12 Jun 2015 21:33:24 -0500 Subject: [PATCH 28/42] fixed camel_case_test --- tests/test_utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 655884ce0..6c70150b2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,11 @@ """ Collection of test cases to test core module. """ +from datajoint.user_relations import from_camel_case __author__ = 'eywalker' from nose.tools import assert_true, assert_raises, assert_equal -from datajoint.utils import to_camel_case, from_camel_case +# from datajoint.utils import to_camel_case, from_camel_case from datajoint import DataJointError @@ -16,11 +17,6 @@ def teardown(): pass -def test_to_camel_case(): - assert_equal(to_camel_case('basic_sessions'), 'BasicSessions') - assert_equal(to_camel_case('_another_table'), 'AnotherTable') - - def test_from_camel_case(): assert_equal(from_camel_case('AllGroups'), 'all_groups') with assert_raises(DataJointError): From ec458a0bb794a26bcfe1565a4b9a2a1df0f9b9b0 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Sat, 13 Jun 2015 09:34:31 -0500 Subject: [PATCH 29/42] many tests in test_relation work again --- datajoint/relational_operand.py | 2 + tests/__init__.py | 20 +- tests/schemata/test1.py | 210 +++++----- tests/test_relation.py | 665 +++++++++++++++----------------- 4 files changed, 446 insertions(+), 451 deletions(-) diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index ac330a005..4964e851b 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -27,6 +27,8 @@ class RelationalOperand(metaclass=abc.ABCMeta): @property def restrictions(self): + if self._restrictions is None: + self._restrictions = [] return self._restrictions @property diff --git a/tests/__init__.py b/tests/__init__.py index 4d4101116..389cc4fa3 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -22,7 +22,7 @@ conn = dj.conn(**CONN_INFO) # Prefix for all databases used during testing -PREFIX = environ.get('DJ_TEST_DB_PREFIX', 'dj') +PREFIX = environ.get('DJ_TEST_DB_PREFIX', 'djtest') # Bare connection used for verification of query results BASE_CONN = pymysql.connect(**CONN_INFO) BASE_CONN.autocommit(True) @@ -33,7 +33,18 @@ def setup(): dj.config['safemode'] = False def teardown(): - cleanup() + cur = BASE_CONN.cursor() + # cancel any unfinished transactions + cur.execute("ROLLBACK") + # start a transaction now + cur.execute("START TRANSACTION WITH CONSISTENT SNAPSHOT") + cur.execute("SHOW DATABASES LIKE '{}\_%'".format(PREFIX)) + dbs = [x[0] for x in cur.fetchall()] + cur.execute('SET FOREIGN_KEY_CHECKS=0') # unset foreign key check while deleting + for db in dbs: + cur.execute('DROP DATABASE `{}`'.format(db)) + cur.execute('SET FOREIGN_KEY_CHECKS=1') # set foreign key check back on + cur.execute("COMMIT") def cleanup(): """ @@ -51,7 +62,10 @@ def cleanup(): dbs = [x[0] for x in cur.fetchall()] cur.execute('SET FOREIGN_KEY_CHECKS=0') # unset foreign key check while deleting for db in dbs: - cur.execute('DROP DATABASE `{}`'.format(db)) + cur.execute("USE %s" % (db)) + cur.execute("SHOW TABLES") + for table in [x[0] for x in cur.fetchall()]: + cur.execute('DELETE FROM `{}`'.format(table)) cur.execute('SET FOREIGN_KEY_CHECKS=1') # set foreign key check back on cur.execute("COMMIT") # diff --git a/tests/schemata/test1.py b/tests/schemata/test1.py index 5b4ac723a..4862d5de7 100644 --- a/tests/schemata/test1.py +++ b/tests/schemata/test1.py @@ -20,108 +20,114 @@ class Subjects(dj.Manual): species = "mouse" : enum('mouse', 'monkey', 'human') # species """ -# # test for shorthand -# class Animals(dj.Relation): -# definition = """ -# test1.Animals (manual) # Listing of all info -# -# -> Subjects -# --- -# animal_dob :date # date of birth -# """ -# -# -# class Trials(dj.Relation): -# definition = """ -# test1.Trials (manual) # info about trials -# -# -> test1.Subjects -# trial_id : int -# --- -# outcome : int # result of experiment -# -# notes="" : varchar(4096) # other comments -# trial_ts=CURRENT_TIMESTAMP : timestamp # automatic -# """ -# -# -# -# class SquaredScore(dj.Relation, dj.AutoPopulate): -# definition = """ -# test1.SquaredScore (computed) # cumulative outcome of trials -# -# -> test1.Subjects -# -> test1.Trials -# --- -# squared : int # squared result of Trials outcome -# """ -# -# @property -# def populate_relation(self): -# return Subjects() * Trials() -# -# def _make_tuples(self, key): -# tmp = (Trials() & key).fetch1() -# tmp2 = SquaredSubtable() & key -# -# self.insert(dict(key, squared=tmp['outcome']**2)) -# -# ss = SquaredSubtable() -# -# for i in range(10): -# key['dummy'] = i -# ss.insert(key) -# -# -# class WrongImplementation(dj.Relation, dj.AutoPopulate): -# definition = """ -# test1.WrongImplementation (computed) # ignore -# -# -> test1.Subjects -# -> test1.Trials -# --- -# dummy : int # ignore -# """ -# -# @property -# def populate_relation(self): -# return {'subject_id':2} -# -# def _make_tuples(self, key): -# pass -# -# -# -# class ErrorGenerator(dj.Relation, dj.AutoPopulate): -# definition = """ -# test1.ErrorGenerator (computed) # ignore -# -# -> test1.Subjects -# -> test1.Trials -# --- -# dummy : int # ignore -# """ -# -# @property -# def populate_relation(self): -# return Subjects() * Trials() -# -# def _make_tuples(self, key): -# raise Exception("This is for testing") -# -# -# -# -# -# -# class SquaredSubtable(dj.Relation): -# definition = """ -# test1.SquaredSubtable (computed) # cumulative outcome of trials -# -# -> test1.SquaredScore -# dummy : int # dummy primary attribute -# --- -# """ +# test for shorthand +@testschema +class Animals(dj.Manual): + definition = """ + # Listing of all info + + -> Subjects + --- + animal_dob :date # date of birth + """ + +@testschema +class Trials(dj.Manual): + definition = """ + # info about trials + + -> Subjects + trial_id : int + --- + outcome : int # result of experiment + + notes="" : varchar(4096) # other comments + trial_ts=CURRENT_TIMESTAMP : timestamp # automatic + """ + +@testschema +class Matrix(dj.Manual): + definition = """ + # Some numpy array + + matrix_id : int # unique matrix id + --- + data : longblob # data + comment : varchar(1000) # comment + """ + + +@testschema +class SquaredScore(dj.Computed): + definition = """ + # cumulative outcome of trials + + -> Subjects + -> Trials + --- + squared : int # squared result of Trials outcome + """ + + @property + def populate_relation(self): + return Subjects() * Trials() + + def _make_tuples(self, key): + tmp = (Trials() & key).fetch1() + tmp2 = SquaredSubtable() & key + + self.insert(dict(key, squared=tmp['outcome']**2)) + + ss = SquaredSubtable() + + for i in range(10): + key['dummy'] = i + ss.insert(key) + +@testschema +class WrongImplementation(dj.Computed): + definition = """ + # ignore + + -> Subjects + -> Trials + --- + dummy : int # ignore + """ + + @property + def populate_relation(self): + return {'subject_id':2} + + def _make_tuples(self, key): + pass + +class ErrorGenerator(dj.Computed): + definition = """ + # ignore + + -> Subjects + -> Trials + --- + dummy : int # ignore + """ + + @property + def populate_relation(self): + return Subjects() * Trials() + + def _make_tuples(self, key): + raise Exception("This is for testing") + +@testschema +class SquaredSubtable(dj.Subordinate, dj.Manual): + definition = """ + # cumulative outcome of trials + + -> SquaredScore + dummy : int # dummy primary attribute + --- + """ # # # # test reference to another table in same schema diff --git a/tests/test_relation.py b/tests/test_relation.py index 76f87f707..ff2a5c735 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -4,107 +4,98 @@ # __author__ = 'fabee' # # from .schemata.schema1 import test1, test4 -from .schemata.test1 import Subjects - - -def test_instantiate_relation(): - s = Subjects() - -# -# from . import BASE_CONN, CONN_INFO, PREFIX, cleanup +import random +import string +import pymysql +from datajoint import DataJointError +from .schemata.test1 import Subjects, Animals, Matrix, Trials, SquaredScore, SquaredSubtable, WrongImplementation, \ + ErrorGenerator, testschema +from . import BASE_CONN, CONN_INFO, PREFIX, cleanup # from datajoint.connection import Connection -# from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal,\ -# assert_tuple_equal, assert_dict_equal, raises +from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal,\ + assert_tuple_equal, assert_dict_equal, raises # from datajoint import DataJointError, TransactionError, AutoPopulate, Relation -# import numpy as np -# from numpy.testing import assert_array_equal -# from datajoint.relation import FreeRelation -# import numpy as np -# -# -# def trial_faker(n=10): -# def iter(): -# for s in [1, 2]: -# for i in range(n): -# yield dict(trial_id=i, subject_id=s, outcome=int(np.random.randint(10)), notes= 'no comment') -# return iter() -# -# -# def setup(): -# """ -# Setup connections and bindings -# """ -# pass -# -# -# class TestTableObject(object): -# def __init__(self): -# self.subjects = None -# self.setup() -# -# """ -# Test cases for FreeRelation objects -# """ -# -# def setup(self): -# """ -# Create a connection object and prepare test modules -# as follows: -# test1 - has conn and bounded -# """ -# cleanup() # drop all databases with PREFIX -# test1.__dict__.pop('conn', None) -# test4.__dict__.pop('conn', None) # make sure conn is not defined at schema level -# -# self.conn = Connection(**CONN_INFO) -# test1.conn = self.conn -# test4.conn = self.conn -# self.conn.bind(test1.__name__, PREFIX + '_test1') -# self.conn.bind(test4.__name__, PREFIX + '_test4') -# self.subjects = test1.Subjects() -# self.animals = test1.Animals() -# self.relvar_blob = test4.Matrix() -# self.trials = test1.Trials() -# -# def teardown(self): -# cleanup() -# -# def test_compound_restriction(self): -# s = self.subjects -# t = self.trials -# -# s.insert(dict(subject_id=1, real_id='M')) -# s.insert(dict(subject_id=2, real_id='F')) -# t.iter_insert(trial_faker(20)) -# -# tM = t & (s & "real_id = 'M'") -# t1 = t & "subject_id = 1" -# -# assert_equal(len(tM), len(t1), "Results of compound request does not have same length") -# -# for t1_item, tM_item in zip(sorted(t1, key=lambda item: item['trial_id']), -# sorted(tM, key=lambda item: item['trial_id'])): -# assert_dict_equal(t1_item, tM_item, -# 'Dictionary elements do not agree in compound statement') -# -# def test_record_insert(self): -# "Test whether record insert works" -# tmp = np.array([(2, 'Klara', 'monkey')], -# dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) -# -# self.subjects.insert(tmp[0]) -# testt2 = (self.subjects & 'subject_id = 2').fetch()[0] -# assert_equal(tuple(tmp[0]), tuple(testt2), "Inserted and fetched record do not match!") +import numpy as np +from numpy.testing import assert_array_equal +import numpy as np +import datajoint as dj + # -# def test_delete(self): -# "Test whether delete works" -# tmp = np.array([(2, 'Klara', 'monkey'), (1,'Peter', 'mouse')], -# dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) # -# self.subjects.batch_insert(tmp) -# assert_true(len(self.subjects) == 2, 'Length does not match 2.') -# self.subjects.delete() -# assert_true(len(self.subjects) == 0, 'Length does not match 0.') +def trial_faker(n=10): + def iter(): + for s in [1, 2]: + for i in range(n): + yield dict(trial_id=i, subject_id=s, outcome=int(np.random.randint(10)), notes= 'no comment') + return iter() + + +class TestTableObject(object): + def __init__(self): + self.subjects = None + self.setup() + + """ + Test cases for FreeRelation objects + """ + + def setup(self): + """ + Create a connection object and prepare test modules + as follows: + test1 - has conn and bounded + """ + cleanup() # delete everything from all tables of databases with PREFIX + self.subjects = Subjects() + self.animals = Animals() + self.relvar_blob = Matrix() + self.trials = Trials() + + + + def test_instantiate_relation(self): + s = Subjects() + + + def teardown(self): + cleanup() + + def test_compound_restriction(self): + s = self.subjects + t = self.trials + + s.insert(dict(subject_id=1, real_id='M')) + s.insert(dict(subject_id=2, real_id='F')) + t.iter_insert(trial_faker(20)) + + tM = t & (s & "real_id = 'M'") + t1 = t & "subject_id = 1" + + assert_equal(len(tM), len(t1), "Results of compound request does not have same length") + + for t1_item, tM_item in zip(sorted(t1, key=lambda item: item['trial_id']), + sorted(tM, key=lambda item: item['trial_id'])): + assert_dict_equal(t1_item, tM_item, + 'Dictionary elements do not agree in compound statement') + + def test_record_insert(self): + "Test whether record insert works" + tmp = np.array([(2, 'Klara', 'monkey')], + dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) + + self.subjects.insert(tmp[0]) + testt2 = (self.subjects & 'subject_id = 2').fetch()[0] + assert_equal(tuple(tmp[0]), tuple(testt2), "Inserted and fetched record do not match!") + + def test_delete(self): + "Test whether delete works" + tmp = np.array([(2, 'Klara', 'monkey'), (1,'Peter', 'mouse')], + dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) + + self.subjects.batch_insert(tmp) + assert_true(len(self.subjects) == 2, 'Length does not match 2.') + self.subjects.delete() + assert_true(len(self.subjects) == 0, 'Length does not match 0.') # # # def test_cascading_delete(self): # # "Test whether delete works" @@ -127,15 +118,15 @@ def test_instantiate_relation(): # # # -# def test_record_insert_different_order(self): -# "Test whether record insert works" -# tmp = np.array([('Klara', 2, 'monkey')], -# dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) -# -# self.subjects.insert(tmp[0]) -# testt2 = (self.subjects & 'subject_id = 2').fetch()[0] -# assert_equal((2, 'Klara', 'monkey'), tuple(testt2), -# "Inserted and fetched record do not match!") + def test_record_insert_different_order(self): + "Test whether record insert works" + tmp = np.array([('Klara', 2, 'monkey')], + dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) + + self.subjects.insert(tmp[0]) + testt2 = (self.subjects & 'subject_id = 2').fetch()[0] + assert_equal((2, 'Klara', 'monkey'), tuple(testt2), + "Inserted and fetched record do not match!") # # @raises(TransactionError) # def test_transaction_error(self): @@ -198,70 +189,70 @@ def test_instantiate_relation(): # self.conn.commit_transaction() # # -# @raises(KeyError) -# def test_wrong_key_insert_records(self): -# "Test whether record insert works" -# tmp = np.array([('Klara', 2, 'monkey')], -# dtype=[('real_deal', 'O'), ('subject_id', '>i4'), ('species', 'O')]) -# -# self.subjects.insert(tmp[0]) -# -# -# def test_dict_insert(self): -# "Test whether record insert works" -# tmp = {'real_id': 'Brunhilda', -# 'subject_id': 3, -# 'species': 'human'} -# -# self.subjects.insert(tmp) -# testt2 = (self.subjects & 'subject_id = 3').fetch()[0] -# assert_equal((3, 'Brunhilda', 'human'), tuple(testt2), "Inserted and fetched record do not match!") -# -# @raises(KeyError) -# def test_wrong_key_insert(self): -# "Test whether a correct error is generated when inserting wrong attribute name" -# tmp = {'real_deal': 'Brunhilda', -# 'subject_database': 3, -# 'species': 'human'} -# -# self.subjects.insert(tmp) -# -# def test_batch_insert(self): -# "Test whether record insert works" -# tmp = np.array([('Klara', 2, 'monkey'), ('Brunhilda', 3, 'mouse'), ('Mickey', 1, 'human')], -# dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) -# -# self.subjects.batch_insert(tmp) -# -# expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'), -# (3, 'Brunhilda', 'mouse')], -# dtype=[('subject_id', 'i4'), ('species', 'O')]) + @raises(KeyError) + def test_wrong_key_insert_records(self): + "Test whether record insert works" + tmp = np.array([('Klara', 2, 'monkey')], + dtype=[('real_deal', 'O'), ('subject_id', '>i4'), ('species', 'O')]) + + self.subjects.insert(tmp[0]) + # -# self.subjects.iter_insert(tmp.__iter__()) + def test_dict_insert(self): + "Test whether record insert works" + tmp = {'real_id': 'Brunhilda', + 'subject_id': 3, + 'species': 'human'} + + self.subjects.insert(tmp) + testt2 = (self.subjects & 'subject_id = 3').fetch()[0] + assert_equal((3, 'Brunhilda', 'human'), tuple(testt2), "Inserted and fetched record do not match!") + + @raises(KeyError) + def test_wrong_key_insert(self): + "Test whether a correct error is generated when inserting wrong attribute name" + tmp = {'real_deal': 'Brunhilda', + 'subject_database': 3, + 'species': 'human'} + + self.subjects.insert(tmp) # -# expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'), -# (3, 'Brunhilda', 'mouse')], -# dtype=[('subject_id', 'i4'), ('species', 'O')]) + + self.subjects.batch_insert(tmp) + + expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'), + (3, 'Brunhilda', 'mouse')], + dtype=[('subject_id', 'i4'), ('species', 'O')]) + + self.subjects.iter_insert(tmp.__iter__()) + + expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'), + (3, 'Brunhilda', 'mouse')], + dtype=[('subject_id', 'i4'), ('species', 'O')]) -# self.subjects.batch_insert(tmp) -# -# for trial_id in range(1,11): -# self.trials.insert(dict(subject_id=2, trial_id=trial_id, outcome=np.random.randint(0,10))) -# -# def teardown(self): -# cleanup() -# -# def test_autopopulate(self): -# self.squared.populate() -# assert_equal(len(self.squared), 10) -# -# for trial in self.trials*self.squared: -# assert_equal(trial['outcome']**2, trial['squared']) -# -# def test_autopopulate_restriction(self): -# self.squared.populate(restriction='trial_id <= 5') -# assert_equal(len(self.squared), 5) -# -# for trial in self.trials*self.squared: -# assert_equal(trial['outcome']**2, trial['squared']) -# -# -# # def test_autopopulate_transaction_error(self): -# # errors = self.squared.populate(suppress_errors=True) -# # assert_equal(len(errors), 1) -# # assert_true(isinstance(errors[0][1], TransactionError)) -# -# @raises(DataJointError) -# def test_autopopulate_relation_check(self): -# -# class dummy(AutoPopulate): -# -# def populate_relation(self): -# return None -# -# def _make_tuples(self, key): -# pass -# -# du = dummy() -# du.populate() \ +def id_generator(size=6, chars=string.ascii_uppercase + string.digits): + return ''.join(random.choice(chars) for _ in range(size)) + +class TestIterator(object): + def __init__(self): + self.relvar = None + self.setup() + + """ + Test cases for Iterators in Relations objects + """ + + def setup(self): + """ + Create a connection object and prepare test modules + as follows: + test1 - has conn and bounded + """ + cleanup() # drop all databases with PREFIX + self.relvar_blob = Matrix() + + def teardown(self): + cleanup() # -# @raises(DataJointError) -# def test_autopopulate_relation_check(self): -# self.dummy1.populate() # -# @raises(Exception) -# def test_autopopulate_relation_check(self): -# self.error_generator.populate()\ + def test_blob_iteration(self): + "Tests the basic call of the iterator" + + dicts = [] + for i in range(10): + + c = id_generator() + + t = {'matrix_id':i, + 'data': np.random.randn(4,4,4), + 'comment': c} + self.relvar_blob.insert(t) + dicts.append(t) + + for t, t2 in zip(dicts, self.relvar_blob): + assert_true(isinstance(t2, dict), 'iterator does not return dict') + + assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved tuples do not match') + assert_equal(t['comment'], t2['comment'], 'inserted and retrieved tuples do not match') + assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved tuples do not match') + + def test_fetch(self): + dicts = [] + for i in range(10): + + c = id_generator() + + t = {'matrix_id':i, + 'data': np.random.randn(4,4,4), + 'comment': c} + self.relvar_blob.insert(t) + dicts.append(t) + + tuples2 = self.relvar_blob.fetch() + assert_true(isinstance(tuples2, np.ndarray), "Return value of fetch does not have proper type.") + assert_true(isinstance(tuples2[0], np.void), "Return value of fetch does not have proper type.") + for t, t2 in zip(dicts, tuples2): + + assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved tuples do not match') + assert_equal(t['comment'], t2['comment'], 'inserted and retrieved tuples do not match') + assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved tuples do not match') + + def test_fetch_dicts(self): + dicts = [] + for i in range(10): + + c = id_generator() + + t = {'matrix_id':i, + 'data': np.random.randn(4,4,4), + 'comment': c} + self.relvar_blob.insert(t) + dicts.append(t) + + tuples2 = self.relvar_blob.fetch(as_dict=True) + assert_true(isinstance(tuples2, list), "Return value of fetch with as_dict=True does not have proper type.") + assert_true(isinstance(tuples2[0], dict), "Return value of fetch with as_dict=True does not have proper type.") + for t, t2 in zip(dicts, tuples2): + assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved dicts do not match') + assert_equal(t['comment'], t2['comment'], 'inserted and retrieved dicts do not match') + assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved dicts do not match') + + # -# @raises(Exception) -# def test_autopopulate_relation_check2(self): -# tmp = self.dummy2.populate(suppress_errors=True) -# assert_equal(len(tmp), 1, 'Error list should have length 1.') +class TestAutopopulate: + def __init__(self): + self.relvar = None + self.setup() + + """ + Test cases for Iterators in Relations objects + """ + + def setup(self): + """ + Create a connection object and prepare test modules + as follows: + test1 - has conn and bounded + """ + cleanup() # drop all databases with PREFIX + + self.subjects = Subjects() + self.trials = Trials() + self.squared = SquaredScore() + self.dummy = SquaredSubtable() + self.dummy1 = WrongImplementation() + self.error_generator = ErrorGenerator() + self.fill_relation() + + def fill_relation(self): + tmp = np.array([('Klara', 2, 'monkey'), ('Peter', 3, 'mouse')], + dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) + self.subjects.batch_insert(tmp) + + for trial_id in range(1,11): + self.trials.insert(dict(subject_id=2, trial_id=trial_id, outcome=np.random.randint(0,10))) + + def teardown(self): + cleanup() + + def test_autopopulate(self): + self.squared.populate() + assert_equal(len(self.squared), 10) + + for trial in self.trials*self.squared: + assert_equal(trial['outcome']**2, trial['squared']) + + def test_autopopulate_restriction(self): + self.squared.populate(restriction='trial_id <= 5') + assert_equal(len(self.squared), 5) + + for trial in self.trials*self.squared: + assert_equal(trial['outcome']**2, trial['squared']) + + @raises(DataJointError) + def test_autopopulate_relation_check(self): + @testschema + class dummy(dj.Computed): + + def populate_relation(self): + return None + + def _make_tuples(self, key): + pass + + du = dummy() + du.populate() \ + + @raises(DataJointError) + def test_autopopulate_relation_check(self): + self.dummy1.populate() + + @raises(Exception) + def test_autopopulate_relation_check(self): + self.error_generator.populate()\ + + @raises(Exception) + def test_autopopulate_relation_check2(self): + tmp = self.dummy2.populate(suppress_errors=True) + assert_equal(len(tmp), 1, 'Error list should have length 1.') From 083704a842b0586dea3a34775e7a345942ed069b Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Mon, 15 Jun 2015 11:03:49 -0500 Subject: [PATCH 30/42] RelationalOperand.restriction now returns [] if empty --- datajoint/relational_operand.py | 2 +- demos/demo1.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index ac330a005..1a6a95378 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -27,7 +27,7 @@ class RelationalOperand(metaclass=abc.ABCMeta): @property def restrictions(self): - return self._restrictions + return [] if self._restrictions is None else self._restrictions @property def primary_key(self): diff --git a/demos/demo1.py b/demos/demo1.py index 46ab53fa6..d30894617 100644 --- a/demos/demo1.py +++ b/demos/demo1.py @@ -28,7 +28,7 @@ class Subject(dj.Manual): class Experiment(dj.Manual): definition = """ # Basic subject info - + -> Subject experiment : smallint # experiment number for this subject --- From 7ed1ba8fce48b116344e55fa3a4f4825a80bfc54 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Mon, 15 Jun 2015 11:43:47 -0500 Subject: [PATCH 31/42] test --- tests/test_relation.py | 66 ++---------------------------------------- 1 file changed, 2 insertions(+), 64 deletions(-) diff --git a/tests/test_relation.py b/tests/test_relation.py index ff2a5c735..5a35db943 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -127,68 +127,7 @@ def test_record_insert_different_order(self): testt2 = (self.subjects & 'subject_id = 2').fetch()[0] assert_equal((2, 'Klara', 'monkey'), tuple(testt2), "Inserted and fetched record do not match!") -# -# @raises(TransactionError) -# def test_transaction_error(self): -# "Test whether declaration in transaction is prohibited" -# -# tmp = np.array([('Klara', 2, 'monkey')], -# dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) -# self.conn.start_transaction() -# self.subjects.insert(tmp[0]) -# -# # def test_transaction_suppress_error(self): -# # "Test whether ignore_errors ignores the errors." -# # -# # tmp = np.array([('Klara', 2, 'monkey')], -# # dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) -# # with self.conn.transaction(ignore_errors=True) as tr: -# # self.subjects.insert(tmp[0]) -# -# -# @raises(TransactionError) -# def test_transaction_error_not_resolve(self): -# "Test whether declaration in transaction is prohibited" -# -# tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], -# dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) -# try: -# self.conn.start_transaction() -# self.subjects.insert(tmp[0]) -# except TransactionError as te: -# self.conn.cancel_transaction() -# -# self.conn.start_transaction() -# self.subjects.insert(tmp[0]) -# -# def test_transaction_error_resolve(self): -# "Test whether declaration in transaction is prohibited" -# -# tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], -# dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) -# try: -# self.conn.start_transaction() -# self.subjects.insert(tmp[0]) -# except TransactionError as te: -# self.conn.cancel_transaction() -# te.resolve() -# -# self.conn.start_transaction() -# self.subjects.insert(tmp[0]) -# self.conn.commit_transaction() -# -# def test_transaction_error2(self): -# "If table is declared, we are allowed to insert within a transaction" -# -# tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], -# dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) -# self.subjects.insert(tmp[0]) -# -# self.conn.start_transaction() -# self.subjects.insert(tmp[1]) -# self.conn.commit_transaction() -# -# + @raises(KeyError) def test_wrong_key_insert_records(self): "Test whether record insert works" @@ -197,8 +136,7 @@ def test_wrong_key_insert_records(self): self.subjects.insert(tmp[0]) -# - def test_dict_insert(self): + def test_dict_insert(self): "Test whether record insert works" tmp = {'real_id': 'Brunhilda', 'subject_id': 3, From 237f61e581b0f0886437261e0e9b3173190b2ee1 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Mon, 15 Jun 2015 20:22:06 -0500 Subject: [PATCH 32/42] fixed bug in restriction by dict --- datajoint/relation.py | 1 - datajoint/relational_operand.py | 13 +++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/datajoint/relation.py b/datajoint/relation.py index 8f558831f..a73e67b66 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -1,4 +1,3 @@ -from collections import namedtuple from collections.abc import Mapping import numpy as np import logging diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index 1a6a95378..5ea38135c 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -5,11 +5,13 @@ import numpy as np import abc import re +from collections import OrderedDict from copy import copy from datajoint import DataJointError, config -from .blob import unpack import logging +from .blob import unpack + logger = logging.getLogger(__name__) @@ -171,8 +173,7 @@ def fetch1(self): ret = cur.fetchone() if not ret or cur.fetchone(): raise DataJointError('fetch1 should only be used for relations with exactly one tuple') - ret = {k: unpack(v) if heading[k].is_blob else v for k, v in ret.items()} - return ret + return OrderedDict((k, unpack(v) if heading[k].is_blob else v) for k, v in ret.items()) def fetch(self, offset=0, limit=None, order_by=None, descending=False, as_dict=False): """ @@ -188,8 +189,8 @@ def fetch(self, offset=0, limit=None, order_by=None, descending=False, as_dict=F descending=descending, as_dict=as_dict) heading = self.heading if as_dict: - ret = [{k: unpack(v) if heading[k].is_blob else v - for k, v in d.items()} + ret = [OrderedDict((k, unpack(v) if heading[k].is_blob else v) + for k, v in d.items()) for d in cur.fetchall()] else: ret = np.array(list(cur.fetchall()), dtype=heading.as_dtype) @@ -255,7 +256,7 @@ def where_clause(self): def make_condition(arg): if isinstance(arg, dict): - conditions = ['`%s`=%s' % (k, repr(v)) for k, v in arg.items()] + conditions = ['`%s`=%s' % (k, repr(v)) for k, v in arg.items() if k in self.heading] elif isinstance(arg, np.void): conditions = ['`%s`=%s' % (k, arg[k]) for k in arg.dtype.fields] else: From 05d84fa542213225df3b748b63d76f14d8027418 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Mon, 15 Jun 2015 21:15:28 -0500 Subject: [PATCH 33/42] fetch1 now returns OrderedDict with correct order of attributes. --- datajoint/relational_operand.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index 5ea38135c..c4e982972 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -173,7 +173,8 @@ def fetch1(self): ret = cur.fetchone() if not ret or cur.fetchone(): raise DataJointError('fetch1 should only be used for relations with exactly one tuple') - return OrderedDict((k, unpack(v) if heading[k].is_blob else v) for k, v in ret.items()) + return OrderedDict((name, unpack(ret[name]) if heading[name].is_blob else ret[name]) + for name in self.heading.names) def fetch(self, offset=0, limit=None, order_by=None, descending=False, as_dict=False): """ @@ -189,8 +190,8 @@ def fetch(self, offset=0, limit=None, order_by=None, descending=False, as_dict=F descending=descending, as_dict=as_dict) heading = self.heading if as_dict: - ret = [OrderedDict((k, unpack(v) if heading[k].is_blob else v) - for k, v in d.items()) + ret = [OrderedDict((name, unpack(ret[name]) if heading[name].is_blob else ret[name]) + for name in self.heading.names) for d in cur.fetchall()] else: ret = np.array(list(cur.fetchall()), dtype=heading.as_dtype) From 5ffacfe9b1ded0782a9ae243614a02cc0d37c4e9 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Tue, 16 Jun 2015 09:33:01 -0500 Subject: [PATCH 34/42] merged --- datajoint/relation.py | 2 +- tests/schemata/test1.py | 4 ++-- tests/test_relation.py | 12 ++++++++++++ 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/datajoint/relation.py b/datajoint/relation.py index 8f558831f..dd6528d12 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -269,4 +269,4 @@ def _alter(self, alter_statement): raise DataJointError("Table definition cannot be altered during a transaction.") sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement) self.connection.query(sql) - self.heading.init_from_database(self.connection, self.database, self.table_name) \ No newline at end of file + self.heading.init_from_database(self.connection, self.database, self.table_name) diff --git a/tests/schemata/test1.py b/tests/schemata/test1.py index 4862d5de7..a9a03be64 100644 --- a/tests/schemata/test1.py +++ b/tests/schemata/test1.py @@ -12,9 +12,9 @@ @testschema class Subjects(dj.Manual): definition = """ - # Basic subject info + #Basic subject info - subject_id : int # unique subject id + subject_id : int # unique subject id --- real_id : varchar(40) # real-world name species = "mouse" : enum('mouse', 'monkey', 'human') # species diff --git a/tests/test_relation.py b/tests/test_relation.py index 5a35db943..dce720b63 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -4,6 +4,7 @@ # __author__ = 'fabee' # # from .schemata.schema1 import test1, test4 +<<<<<<< Updated upstream import random import string import pymysql @@ -11,6 +12,17 @@ from .schemata.test1 import Subjects, Animals, Matrix, Trials, SquaredScore, SquaredSubtable, WrongImplementation, \ ErrorGenerator, testschema from . import BASE_CONN, CONN_INFO, PREFIX, cleanup +======= +from .schemata.test1 import Subjects + + +def test_instantiate_relation(): + s = Subjects() + print(s) + +# +# from . import BASE_CONN, CONN_INFO, PREFIX, cleanup +>>>>>>> Stashed changes # from datajoint.connection import Connection from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal,\ assert_tuple_equal, assert_dict_equal, raises From ae2669bd5be17995312285598b42c43ab69e949a Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Tue, 16 Jun 2015 09:35:12 -0500 Subject: [PATCH 35/42] dropped transaction from teardown in tests --- tests/__init__.py | 1 - tests/test_relation.py | 28 ++++++++-------------------- 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index 389cc4fa3..611f34bc2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -37,7 +37,6 @@ def teardown(): # cancel any unfinished transactions cur.execute("ROLLBACK") # start a transaction now - cur.execute("START TRANSACTION WITH CONSISTENT SNAPSHOT") cur.execute("SHOW DATABASES LIKE '{}\_%'".format(PREFIX)) dbs = [x[0] for x in cur.fetchall()] cur.execute('SET FOREIGN_KEY_CHECKS=0') # unset foreign key check while deleting diff --git a/tests/test_relation.py b/tests/test_relation.py index dce720b63..635124130 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -4,7 +4,6 @@ # __author__ = 'fabee' # # from .schemata.schema1 import test1, test4 -<<<<<<< Updated upstream import random import string import pymysql @@ -12,17 +11,6 @@ from .schemata.test1 import Subjects, Animals, Matrix, Trials, SquaredScore, SquaredSubtable, WrongImplementation, \ ErrorGenerator, testschema from . import BASE_CONN, CONN_INFO, PREFIX, cleanup -======= -from .schemata.test1 import Subjects - - -def test_instantiate_relation(): - s = Subjects() - print(s) - -# -# from . import BASE_CONN, CONN_INFO, PREFIX, cleanup ->>>>>>> Stashed changes # from datajoint.connection import Connection from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal,\ assert_tuple_equal, assert_dict_equal, raises @@ -149,14 +137,14 @@ def test_wrong_key_insert_records(self): self.subjects.insert(tmp[0]) def test_dict_insert(self): - "Test whether record insert works" - tmp = {'real_id': 'Brunhilda', - 'subject_id': 3, - 'species': 'human'} - - self.subjects.insert(tmp) - testt2 = (self.subjects & 'subject_id = 3').fetch()[0] - assert_equal((3, 'Brunhilda', 'human'), tuple(testt2), "Inserted and fetched record do not match!") + "Test whether record insert works" + tmp = {'real_id': 'Brunhilda', + 'subject_id': 3, + 'species': 'human'} + + self.subjects.insert(tmp) + testt2 = (self.subjects & 'subject_id = 3').fetch()[0] + assert_equal((3, 'Brunhilda', 'human'), tuple(testt2), "Inserted and fetched record do not match!") @raises(KeyError) def test_wrong_key_insert(self): From d547f83a6e034cad0b09c5df6a20c7cefeb5ad06 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 16 Jun 2015 16:54:15 -0500 Subject: [PATCH 36/42] added AutoPopulate.progress --- datajoint/autopopulate.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index cb4c0d673..d64f9035a 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -79,3 +79,11 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): logger.info('Done populating.') return error_list + + def progress(self): + total = len(self.populate_relation) + remaining = len(self.populate_relation - self.target) + if remaning: + print('Remaining %d of %d (%2.1f%%)' % (remaining, total, 100*remaining/total), flush=True) + else: + print('Complete') From d3060421a0ba64cb42360ec53a90676e0d72e4ed Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 16 Jun 2015 17:04:41 -0500 Subject: [PATCH 37/42] added doc string to AutoPopulate.progress() --- datajoint/autopopulate.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index d64f9035a..d07fd1bf0 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -81,9 +81,10 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): def progress(self): + """ + report progress of populating this table + """ total = len(self.populate_relation) remaining = len(self.populate_relation - self.target) - if remaning: - print('Remaining %d of %d (%2.1f%%)' % (remaining, total, 100*remaining/total), flush=True) - else: - print('Complete') + print('Remaining %d of %d (%2.1f%%)' % (remaining, total, 100*remaining/total) + if remaining else 'Complete', flush=True) From 77ad63d272fc83a80f75e34ba9f25d6c929c25e8 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 16 Jun 2015 18:09:23 -0500 Subject: [PATCH 38/42] fixed bug in blob packing/unpacking --- datajoint/blob.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/datajoint/blob.py b/datajoint/blob.py index b77547320..e6fe63386 100644 --- a/datajoint/blob.py +++ b/datajoint/blob.py @@ -24,6 +24,7 @@ mxFUNCTION_CLASS=None) reverseClassID = {v: i for i, v in enumerate(mxClassID.values())} +dtypeList = list(mxClassID.values()) def pack(obj): @@ -41,6 +42,7 @@ def pack(obj): obj, imaginary = np.real(obj), np.imag(obj) type_number = reverseClassID[obj.dtype] + assert dtypeList[type_number] is obj.dtype, 'ambigous or unknown array type' blob += np.asarray(type_number, dtype=np.uint32).tostring() blob += np.int8(is_complex).tostring() + b'\0\0\0' blob += obj.tostring() @@ -72,7 +74,8 @@ def unpack(blob): p += 8 array_shape = np.fromstring(blob[p:p+8*dimensions], dtype=np.uint64) p += 8 * dimensions - mx_type, dtype = [q for q in mxClassID.items()][np.fromstring(blob[p:p+4], dtype=np.uint32)[0]] + type_number = np.fromstring(blob[p:p+4], dtype=np.uint32)[0] + dtype = dtypeList[type_number] if dtype is None: raise DataJointError('Unsupported MATLAB data type '+mx_type+' in blob') p += 4 From cf3f30d439c2d1e1db774057f89aec4ed5fc92a9 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 16 Jun 2015 18:43:31 -0500 Subject: [PATCH 39/42] bugfix in blob pack/unpack --- datajoint/blob.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/datajoint/blob.py b/datajoint/blob.py index e6fe63386..7dad919cf 100644 --- a/datajoint/blob.py +++ b/datajoint/blob.py @@ -1,27 +1,27 @@ import zlib -import collections +from collections import OrderedDict import numpy as np from . import DataJointError -mxClassID = collections.OrderedDict( +mxClassID = OrderedDict(( # see http://www.mathworks.com/help/techdoc/apiref/mxclassid.html - mxUNKNOWN_CLASS=None, - mxCELL_CLASS=None, # TODO: implement - mxSTRUCT_CLASS=None, # TODO: implement - mxLOGICAL_CLASS=np.dtype('bool'), - mxCHAR_CLASS=np.dtype('c'), - mxVOID_CLASS=None, - mxDOUBLE_CLASS=np.dtype('float64'), - mxSINGLE_CLASS=np.dtype('float32'), - mxINT8_CLASS=np.dtype('int8'), - mxUINT8_CLASS=np.dtype('uint8'), - mxINT16_CLASS=np.dtype('int16'), - mxUINT16_CLASS=np.dtype('uint16'), - mxINT32_CLASS=np.dtype('int32'), - mxUINT32_CLASS=np.dtype('uint32'), - mxINT64_CLASS=np.dtype('int64'), - mxUINT64_CLASS=np.dtype('uint64'), - mxFUNCTION_CLASS=None) + ('mxUNKNOWN_CLASS', None), + ('mxCELL_CLASS', None), # TODO: implement + ('mxSTRUCT_CLASS', None), # TODO: implement + ('mxLOGICAL_CLASS', np.dtype('bool')), + ('mxCHAR_CLASS', np.dtype('c')), + ('mxVOID_CLASS', None), + ('mxDOUBLE_CLASS', np.dtype('float64')), + ('mxSINGLE_CLASS', np.dtype('float32')), + ('mxINT8_CLASS', np.dtype('int8')), + ('mxUINT8_CLASS', np.dtype('uint8')), + ('mxINT16_CLASS', np.dtype('int16')), + ('mxUINT16_CLASS', np.dtype('uint16')), + ('mxINT32_CLASS', np.dtype('int32')), + ('mxUINT32_CLASS', np.dtype('uint32')), + ('mxINT64_CLASS', np.dtype('int64')), + ('mxUINT64_CLASS', np.dtype('uint64')), + ('mxFUNCTION_CLASS', None))) reverseClassID = {v: i for i, v in enumerate(mxClassID.values())} dtypeList = list(mxClassID.values()) @@ -77,7 +77,7 @@ def unpack(blob): type_number = np.fromstring(blob[p:p+4], dtype=np.uint32)[0] dtype = dtypeList[type_number] if dtype is None: - raise DataJointError('Unsupported MATLAB data type '+mx_type+' in blob') + raise DataJointError('Unsupported MATLAB data type '+type_number+' in blob') p += 4 is_complex = np.fromstring(blob[p:p+4], dtype=np.uint32)[0] p += 4 From 066ca4c6b0adaa8c83f7ad1a03ca0e320ab6b41a Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Sat, 20 Jun 2015 09:31:46 -0500 Subject: [PATCH 40/42] fixed Dimitri's pull request --- datajoint/relational_operand.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index bf9b4921d..367fd2d1e 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -188,8 +188,8 @@ def fetch(self, offset=0, limit=None, order_by=None, descending=False, as_dict=F descending=descending, as_dict=as_dict) heading = self.heading if as_dict: - ret = [OrderedDict((name, unpack(ret[name]) if heading[name].is_blob else ret[name]) - for name in self.heading.names) + ret = [OrderedDict((name, unpack(d[name]) if heading[name].is_blob else d[name]) + for name in self.heading.names) for d in cur.fetchall()] else: ret = np.array(list(cur.fetchall()), dtype=heading.as_dtype) From f6fe321ff408171df05a348c00029622a8140114 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Sat, 20 Jun 2015 12:10:44 -0500 Subject: [PATCH 41/42] more tests and context manager --- datajoint/connection.py | 15 ++ tests/test_blob.py | 21 ++- tests/test_connection.py | 368 ++++++++++++--------------------------- 3 files changed, 146 insertions(+), 258 deletions(-) diff --git a/datajoint/connection.py b/datajoint/connection.py index 0ba65b532..0c668d9fe 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager import pymysql from . import DataJointError import logging @@ -127,3 +128,17 @@ def commit_transaction(self): self.query('COMMIT') self._in_transaction = False logger.info("Transaction committed and closed.") + + + #-------- context manager for transactions + @contextmanager + def transaction(self): + try: + self.start_transaction() + yield self + except: + self.cancel_transaction() + raise + else: + self.commit_transaction() + diff --git a/tests/test_blob.py b/tests/test_blob.py index 489707f31..322be02cc 100644 --- a/tests/test_blob.py +++ b/tests/test_blob.py @@ -1,7 +1,9 @@ +from datajoint import DataJointError + __author__ = 'fabee' import numpy as np from datajoint.blob import pack, unpack -from numpy.testing import assert_array_equal +from numpy.testing import assert_array_equal, raises def test_pack(): @@ -16,3 +18,20 @@ def test_pack(): x = np.int16(np.random.randn(1, 2, 3)) assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") + +@raises(DataJointError) +def test_error(): + pack(dict()) + +def test_complex(): + z = np.random.randn(8, 10) + 1j*np.random.randn(8,10) + assert_array_equal(z, unpack(pack(z)), "Arrays do not match!") + + z = np.random.randn(10)+ 1j*np.random.randn(10) + assert_array_equal(z, unpack(pack(z)), "Arrays do not match!") + + x = np.float32(np.random.randn(3, 4, 5)) + 1j*np.float32(np.random.randn(3, 4, 5)) + assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") + + x = np.int16(np.random.randn(1, 2, 3)) + 1j*np.int16(np.random.randn(1, 2, 3)) + assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") diff --git a/tests/test_connection.py b/tests/test_connection.py index 29fee4f64..b94a07adf 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,257 +1,111 @@ -# """ -# Collection of test cases to test connection module. -# """ -# from .schemata import schema1 -# from .schemata.schema1 import test1 -# import numpy as np -# -# __author__ = 'eywalker, fabee' -# from . import (CONN_INFO, PREFIX, BASE_CONN, cleanup) -# from nose.tools import assert_true, assert_raises, assert_equal, raises -# import datajoint as dj -# from datajoint.utils import DataJointError -# -# -# def setup(): -# cleanup() -# -# -# def test_dj_conn(): -# """ -# Should be able to establish a connection -# """ -# c = dj.conn(**CONN_INFO) -# assert c.is_connected -# -# -# def test_persistent_dj_conn(): -# """ -# conn() method should provide persistent connection -# across calls. -# """ -# c1 = dj.conn(**CONN_INFO) -# c2 = dj.conn() -# assert_true(c1 is c2) -# -# -# def test_dj_conn_reset(): -# """ -# Passing in reset=True should allow for new persistent -# connection to be created. -# """ -# c1 = dj.conn(**CONN_INFO) -# c2 = dj.conn(reset=True, **CONN_INFO) -# assert_true(c1 is not c2) -# -# -# -# def setup_sample_db(): -# """ -# Helper method to setup databases with tables to be used -# during the test -# """ -# cur = BASE_CONN.cursor() -# cur.execute("CREATE DATABASE `{}_test1`".format(PREFIX)) -# cur.execute("CREATE DATABASE `{}_test2`".format(PREFIX)) -# query1 = """ -# CREATE TABLE `{prefix}_test1`.`subjects` -# ( -# subject_id SMALLINT COMMENT 'Unique subject ID', -# subject_name VARCHAR(255) COMMENT 'Subject name', -# subject_email VARCHAR(255) COMMENT 'Subject email address', -# PRIMARY KEY (subject_id) -# ) -# """.format(prefix=PREFIX) -# cur.execute(query1) -# # query2 = """ -# # CREATE TABLE `{prefix}_test2`.`experiments` -# # ( -# # experiment_id SMALLINT COMMENT 'Unique experiment ID', -# # experiment_name VARCHAR(255) COMMENT 'Experiment name', -# # subject_id SMALLINT, -# # CONSTRAINT FOREIGN KEY (`subject_id`) REFERENCES `dj_test1`.`subjects` (`subject_id`) ON UPDATE CASCADE ON DELETE RESTRICT, -# # PRIMARY KEY (subject_id, experiment_id) -# # )""".format(prefix=PREFIX) -# # cur.execute(query2) -# -# -# class TestConnectionWithoutBindings(object): -# """ -# Test methods from Connection that does not -# depend on presence of module to database bindings. -# This includes tests for `bind` method itself. -# """ -# def setup(self): -# self.conn = dj.Connection(**CONN_INFO) -# test1.__dict__.pop('conn', None) -# schema1.__dict__.pop('conn', None) -# setup_sample_db() -# -# def teardown(self): -# cleanup() -# -# def check_binding(self, db_name, module): -# """ -# Helper method to check if the specified database-module pairing exists -# """ -# assert_equal(self.conn.db_to_mod[db_name], module) -# assert_equal(self.conn.mod_to_db[module], db_name) -# -# def test_bind_to_existing_database(self): -# """ -# Should be able to bind a module to an existing database -# """ -# db_name = PREFIX + '_test1' -# module = test1.__name__ -# self.conn.bind(module, db_name) -# self.check_binding(db_name, module) -# -# def test_bind_at_package_level(self): -# db_name = PREFIX + '_test1' -# package = schema1.__name__ -# self.conn.bind(package, db_name) -# self.check_binding(db_name, package) -# -# def test_bind_to_non_existing_database(self): -# """ -# Should be able to bind a module to a non-existing database by creating target -# """ -# db_name = PREFIX + '_test3' -# module = test1.__name__ -# cur = BASE_CONN.cursor() -# -# # Ensure target database doesn't exist -# if cur.execute("SHOW DATABASES LIKE '{}'".format(db_name)): -# cur.execute("DROP DATABASE IF EXISTS `{}`".format(db_name)) -# # Bind module to non-existing database -# self.conn.bind(module, db_name) -# # Check that target database was created -# assert_equal(cur.execute("SHOW DATABASES LIKE '{}'".format(db_name)), 1) -# self.check_binding(db_name, module) -# # Remove the target database -# cur.execute("DROP DATABASE IF EXISTS `{}`".format(db_name)) -# -# def test_cannot_bind_to_multiple_databases(self): -# """ -# Bind will fail when db_name is a pattern that -# matches multiple databases -# """ -# db_name = PREFIX + "_test%%" -# module = test1.__name__ -# with assert_raises(DataJointError): -# self.conn.bind(module, db_name) -# -# def test_basic_sql_query(self): -# """ -# Test execution of basic SQL query using connection -# object. -# """ -# cur = self.conn.query('SHOW DATABASES') -# results1 = cur.fetchall() -# cur2 = BASE_CONN.cursor() -# cur2.execute('SHOW DATABASES') -# results2 = cur2.fetchall() -# assert_equal(results1, results2) -# -# def test_transaction_commit(self): -# """ -# Test transaction commit -# """ -# table_name = PREFIX + '_test1.subjects' -# self.conn.start_transaction() -# self.conn.query("INSERT INTO {table} VALUES (0, 'dj_user', 'dj_user@example.com')".format(table=table_name)) -# cur = BASE_CONN.cursor() -# assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 0) -# self.conn.commit_transaction() -# assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 1) -# -# def test_transaction_rollback(self): -# """ -# Test transaction rollback -# """ -# table_name = PREFIX + '_test1.subjects' -# self.conn.start_transaction() -# self.conn.query("INSERT INTO {table} VALUES (0, 'dj_user', 'dj_user@example.com')".format(table=table_name)) -# cur = BASE_CONN.cursor() -# assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 0) -# self.conn.cancel_transaction() -# assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 0) -# -# # class TestContextManager(object): -# # def __init__(self): -# # self.relvar = None -# # self.setup() -# # -# # """ -# # Test cases for FreeRelation objects -# # """ -# # -# # def setup(self): -# # """ -# # Create a connection object and prepare test modules -# # as follows: -# # test1 - has conn and bounded -# # """ -# # cleanup() # drop all databases with PREFIX -# # test1.__dict__.pop('conn', None) -# # -# # self.conn = dj.Connection(**CONN_INFO) -# # test1.conn = self.conn -# # self.conn.bind(test1.__name__, PREFIX + '_test1') -# # self.relvar = test1.Subjects() -# # -# # def teardown(self): -# # cleanup() -# # -# # # def test_active(self): -# # # with self.conn.transaction() as tr: -# # # assert_true(tr.is_active, "Transaction is not active") -# # -# # # def test_rollback(self): -# # # -# # # tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], -# # # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) -# # # -# # # self.relvar.insert(tmp[0]) -# # # try: -# # # with self.conn.transaction(): -# # # self.relvar.insert(tmp[1]) -# # # raise DataJointError("Just to test") -# # # except DataJointError as e: -# # # pass -# # # -# # # testt2 = (self.relvar & 'subject_id = 2').fetch() -# # # assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") -# # -# # # def test_cancel(self): -# # # """Tests cancelling a transaction""" -# # # tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], -# # # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) -# # # -# # # self.relvar.insert(tmp[0]) -# # # with self.conn.transaction() as transaction: -# # # self.relvar.insert(tmp[1]) -# # # transaction.cancel() -# # # -# # # testt2 = (self.relvar & 'subject_id = 2').fetch() -# # # assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") -# -# -# -# # class TestConnectionWithBindings(object): -# # """ -# # Tests heading and dependency loadings -# # """ -# # def setup(self): -# # self.conn = dj.Connection(**CONN_INFO) -# # cur.execute(query) -# -# -# -# -# -# -# -# -# -# +""" +Collection of test cases to test connection module. +""" +from tests.schemata.test1 import Subjects + +__author__ = 'eywalker, fabee' +from . import CONN_INFO, PREFIX, BASE_CONN, cleanup +from nose.tools import assert_true, assert_raises, assert_equal, raises +import datajoint as dj +from datajoint import DataJointError +import numpy as np + +def setup(): + cleanup() + + +def test_dj_conn(): + """ + Should be able to establish a connection + """ + c = dj.conn(**CONN_INFO) + assert c.is_connected + + +def test_persistent_dj_conn(): + """ + conn() method should provide persistent connection + across calls. + """ + c1 = dj.conn(**CONN_INFO) + c2 = dj.conn() + assert_true(c1 is c2) + + +def test_dj_conn_reset(): + """ + Passing in reset=True should allow for new persistent + connection to be created. + """ + c1 = dj.conn(**CONN_INFO) + c2 = dj.conn(reset=True, **CONN_INFO) + assert_true(c1 is not c2) + + +def test_repr(): + c1 = dj.conn(**CONN_INFO) + assert_true('disconnected' not in c1.__repr__() and 'connected' in c1.__repr__()) + +def test_del(): + c1 = dj.conn(**CONN_INFO) + assert_true('disconnected' not in c1.__repr__() and 'connected' in c1.__repr__()) + del c1 + + + +class TestContextManager(object): + def __init__(self): + self.relvar = None + self.setup() + + """ + Test cases for FreeRelation objects + """ + + def setup(self): + """ + Create a connection object and prepare test modules + as follows: + test1 - has conn and bounded + """ + cleanup() # drop all databases with PREFIX + self.conn = dj.conn() + self.relvar = Subjects() + + def teardown(self): + cleanup() + + def test_active(self): + with self.conn.transaction() as conn: + assert_true(conn.in_transaction, "Transaction is not active") + + def test_rollback(self): + + tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], + dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) + + self.relvar.insert(tmp[0]) + try: + with self.conn.transaction(): + self.relvar.insert(tmp[1]) + raise DataJointError("Just to test") + except DataJointError as e: + pass + testt2 = (self.relvar & 'subject_id = 2').fetch() + assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") + + def test_cancel(self): + """Tests cancelling a transaction""" + tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], + dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) + + self.relvar.insert(tmp[0]) + with self.conn.transaction() as conn: + self.relvar.insert(tmp[1]) + conn.cancel_transaction() + + testt2 = (self.relvar & 'subject_id = 2').fetch() + assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") + + + From 7cb1173fc6591dbf0bce8e3e94ded303b3189d58 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Sat, 20 Jun 2015 14:06:55 -0500 Subject: [PATCH 42/42] more tests --- datajoint/settings.py | 4 +- tests/test_relation.py | 129 +++++++++++++++++++++++------------------ tests/test_settings.py | 12 +++- 3 files changed, 85 insertions(+), 60 deletions(-) diff --git a/datajoint/settings.py b/datajoint/settings.py index 7e5bfe04a..7dac8186a 100644 --- a/datajoint/settings.py +++ b/datajoint/settings.py @@ -107,7 +107,6 @@ def save(self, filename=None): :param filename: filename of the local JSON settings file. If None, the local config file is used. """ if filename is None: - import datajoint as dj filename = LOCALCONFIG with open(filename, 'w') as fid: json.dump(self._conf, fid, indent=4) @@ -119,8 +118,7 @@ def load(self, filename): :param filename=None: filename of the local JSON settings file. If None, the local config file is used. """ if filename is None: - import datajoint as dj - filename = dj.config['config.file'] + filename = LOCALCONFIG with open(filename, 'r') as fid: self.update(json.load(fid)) diff --git a/tests/test_relation.py b/tests/test_relation.py index 635124130..fd27876ce 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -12,7 +12,7 @@ ErrorGenerator, testschema from . import BASE_CONN, CONN_INFO, PREFIX, cleanup # from datajoint.connection import Connection -from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal,\ +from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal, \ assert_tuple_equal, assert_dict_equal, raises # from datajoint import DataJointError, TransactionError, AutoPopulate, Relation import numpy as np @@ -26,7 +26,8 @@ def trial_faker(n=10): def iter(): for s in [1, 2]: for i in range(n): - yield dict(trial_id=i, subject_id=s, outcome=int(np.random.randint(10)), notes= 'no comment') + yield dict(trial_id=i, subject_id=s, outcome=int(np.random.randint(10)), notes='no comment') + return iter() @@ -50,13 +51,26 @@ def setup(self): self.animals = Animals() self.relvar_blob = Matrix() self.trials = Trials() + self.score = SquaredScore() + self.subtable = SquaredSubtable() + + def test_table_name_manual(self): + assert_true(not self.subjects.table_name.startswith('#') and + not self.subjects.table_name.startswith('_') and not self.subjects.table_name.startswith('__')) + def test_table_name_computed(self): + assert_true(self.score.table_name.startswith('__')) + def test_population_relation_subordinate(self): + assert_true(self.subtable.populate_relation is None) + + @raises(NotImplementedError) + def test_make_tubles_not_implemented_subordinate(self): + self.subtable._make_tuples(None) def test_instantiate_relation(self): s = Subjects() - def teardown(self): cleanup() @@ -89,35 +103,36 @@ def test_record_insert(self): def test_delete(self): "Test whether delete works" - tmp = np.array([(2, 'Klara', 'monkey'), (1,'Peter', 'mouse')], + tmp = np.array([(2, 'Klara', 'monkey'), (1, 'Peter', 'mouse')], dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) self.subjects.batch_insert(tmp) assert_true(len(self.subjects) == 2, 'Length does not match 2.') self.subjects.delete() assert_true(len(self.subjects) == 0, 'Length does not match 0.') -# -# # def test_cascading_delete(self): -# # "Test whether delete works" -# # tmp = np.array([(2, 'Klara', 'monkey'), (1,'Peter', 'mouse')], -# # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) -# # -# # self.subjects.batch_insert(tmp) -# # -# # self.trials.insert(dict(subject_id=1, trial_id=1, outcome=0)) -# # self.trials.insert(dict(subject_id=1, trial_id=2, outcome=1)) -# # self.trials.insert(dict(subject_id=2, trial_id=3, outcome=2)) -# # assert_true(len(self.subjects) == 2, 'Length does not match 2.') -# # assert_true(len(self.trials) == 3, 'Length does not match 3.') -# # (self.subjects & 'subject_id=1').delete() -# # assert_true(len(self.subjects) == 1, 'Length does not match 1.') -# # assert_true(len(self.trials) == 1, 'Length does not match 1.') -# -# def test_short_hand_foreign_reference(self): -# self.animals.heading -# -# -# + + # + # # def test_cascading_delete(self): + # # "Test whether delete works" + # # tmp = np.array([(2, 'Klara', 'monkey'), (1,'Peter', 'mouse')], + # # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) + # # + # # self.subjects.batch_insert(tmp) + # # + # # self.trials.insert(dict(subject_id=1, trial_id=1, outcome=0)) + # # self.trials.insert(dict(subject_id=1, trial_id=2, outcome=1)) + # # self.trials.insert(dict(subject_id=2, trial_id=3, outcome=2)) + # # assert_true(len(self.subjects) == 2, 'Length does not match 2.') + # # assert_true(len(self.trials) == 3, 'Length does not match 3.') + # # (self.subjects & 'subject_id=1').delete() + # # assert_true(len(self.subjects) == 1, 'Length does not match 1.') + # # assert_true(len(self.trials) == 1, 'Length does not match 1.') + # + # def test_short_hand_foreign_reference(self): + # self.animals.heading + # + # + # def test_record_insert_different_order(self): "Test whether record insert works" tmp = np.array([('Klara', 2, 'monkey')], @@ -154,7 +169,8 @@ def test_wrong_key_insert(self): 'species': 'human'} self.subjects.insert(tmp) -# + + # def test_batch_insert(self): "Test whether record insert works" tmp = np.array([('Klara', 2, 'monkey'), ('Brunhilda', 3, 'mouse'), ('Mickey', 1, 'human')], @@ -167,9 +183,10 @@ def test_batch_insert(self): dtype=[('subject_id', 'i4'), ('species', 'O')]) self.subjects.batch_insert(tmp) - for trial_id in range(1,11): - self.trials.insert(dict(subject_id=2, trial_id=trial_id, outcome=np.random.randint(0,10))) + for trial_id in range(1, 11): + self.trials.insert(dict(subject_id=2, trial_id=trial_id, outcome=np.random.randint(0, 10))) def teardown(self): cleanup() @@ -369,21 +387,20 @@ def test_autopopulate(self): self.squared.populate() assert_equal(len(self.squared), 10) - for trial in self.trials*self.squared: - assert_equal(trial['outcome']**2, trial['squared']) + for trial in self.trials * self.squared: + assert_equal(trial['outcome'] ** 2, trial['squared']) def test_autopopulate_restriction(self): self.squared.populate(restriction='trial_id <= 5') assert_equal(len(self.squared), 5) - for trial in self.trials*self.squared: - assert_equal(trial['outcome']**2, trial['squared']) + for trial in self.trials * self.squared: + assert_equal(trial['outcome'] ** 2, trial['squared']) @raises(DataJointError) def test_autopopulate_relation_check(self): @testschema class dummy(dj.Computed): - def populate_relation(self): return None @@ -391,7 +408,7 @@ def _make_tuples(self, key): pass du = dummy() - du.populate() \ + du.populate() @raises(DataJointError) def test_autopopulate_relation_check(self): @@ -399,7 +416,7 @@ def test_autopopulate_relation_check(self): @raises(Exception) def test_autopopulate_relation_check(self): - self.error_generator.populate()\ + self.error_generator.populate() @raises(Exception) def test_autopopulate_relation_check2(self): diff --git a/tests/test_settings.py b/tests/test_settings.py index 6b8100806..d10323b10 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -8,7 +8,7 @@ from nose.tools import assert_true, assert_raises, assert_equal, raises, assert_dict_equal import datajoint as dj - +import os def test_load_save(): dj.config.save('tmp.json') @@ -65,3 +65,13 @@ def test_save(): assert_true(os.path.isfile(settings.LOCALCONFIG)) if moved: os.rename(tmpfile, settings.LOCALCONFIG) + +def test_load_save(): + + filename_old = dj.settings.LOCALCONFIG + filename = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(50)) + '.json' + dj.settings.LOCALCONFIG = filename + dj.config.save() + dj.config.load(filename=filename) + dj.settings.LOCALCONFIG = filename_old + os.remove(filename)