From 69b5f422c5d087df662c38d743840382a8c4ceb5 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Mon, 8 Jun 2015 17:24:34 -0500 Subject: [PATCH 01/20] * merged Relation and ClassLevelRelation * converted almost everything to @classproperty or @classmethod --- datajoint/__init__.py | 2 +- .../{abstract_relation.py => relation.py} | 267 ++++++++++++------ datajoint/relations.py | 94 ------ datajoint/user_relations.py | 2 +- tests/test_relation.py | 2 +- 5 files changed, 183 insertions(+), 184 deletions(-) rename datajoint/{abstract_relation.py => relation.py} (65%) delete mode 100644 datajoint/relations.py diff --git a/datajoint/__init__.py b/datajoint/__init__.py index e3a8007a6..4eb1eb6d1 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -56,7 +56,7 @@ def culprit(self): # ------------- flatten import hierarchy ------------------------- from .connection import conn, Connection from .user_relations import Manual, Lookup, Imported, Computed -from .abstract_relation import Relation +from .relation import Relation from .autopopulate import AutoPopulate from . import blob from .relational_operand import Not diff --git a/datajoint/abstract_relation.py b/datajoint/relation.py similarity index 65% rename from datajoint/abstract_relation.py rename to datajoint/relation.py index da191fef6..ef9e43938 100644 --- a/datajoint/abstract_relation.py +++ b/datajoint/relation.py @@ -1,3 +1,4 @@ +from collections import namedtuple from collections.abc import MutableMapping, Mapping import numpy as np import logging @@ -5,6 +6,8 @@ import abc from . import DataJointError, config, TransactionError +import pymysql +from datajoint import DataJointError, conn from .relational_operand import RelationalOperand from .blob import pack from .utils import user_choice @@ -13,6 +16,65 @@ logger = logging.getLogger(__name__) +SharedInfo = namedtuple( + 'SharedInfo', + ('database', 'context', 'connection', 'heading', 'parents', 'children', 'references', 'referenced')) + + +class classproperty: + def __init__(self, getf): + self._getf = getf + + def __get__(self, instance, owner): + return self._getf(owner) + + +def schema(database, context, connection=None): + """ + Returns a schema decorator that can be used to associate a Relation class to a + specific database with :param name:. Name reference to other tables in the table definition + will be resolved by looking up the corresponding key entry in the passed in context dictionary. + It is most common to set context equal to the return value of call to locals() in the module. + For more details, please refer to the tutorial online. + + :param database: name of the database to associate the decorated class with + :param context: dictionary used to resolve (any) name references within the table definition string + :param connection: connection object to the database server. If ommited, will try to establish connection according to + config values + :return: a decorator function to be used on Relation derivative classes + """ + if connection is None: + connection = conn() + + # if the database does not exist, create it + cur = connection.query("SHOW DATABASES LIKE '{database}'".format(database=database)) + if cur.rowcount == 0: + logger.info("Database `{database}` could not be found. " + "Attempting to create the database.".format(database=database)) + try: + connection.query("CREATE DATABASE `{database}`".format(database=database)) + logger.info('Created database `{database}`.'.format(database=database)) + except pymysql.OperationalError: + raise DataJointError("Database named `{database}` was not defined, and" + "an attempt to create has failed. Check" + " permissions.".format(database=database)) + + def decorator(cls): + cls._shared_info = SharedInfo( + database=database, + context=context, + connection=connection, + heading=None, + parents=[], + children=[], + references=[], + referenced=[] + ) + return cls + + return decorator + + class Relation(RelationalOperand, metaclass=abc.ABCMeta): """ @@ -26,103 +88,109 @@ class Relation(RelationalOperand, metaclass=abc.ABCMeta): __heading = None + _shared_info = None + + def __init__(self): # TODO: Think about it + if self._shared_info is None: + raise DataJointError('The class must define _shared_info') + # ---------- abstract properties ------------ # - @property + @classproperty @abc.abstractmethod - def table_name(self): + def table_name(cls): """ :return: the name of the table in the database """ pass - @property + @classproperty @abc.abstractmethod - def database(self): + def database(cls): """ :return: string containing the database name on the server """ pass - @property + @classproperty @abc.abstractmethod - def definition(self): + def definition(cls): """ :return: a string containing the table definition using the DataJoint DDL """ pass - @property + @classproperty @abc.abstractmethod - def context(self): + def context(cls): """ :return: a dict with other relations that can be referenced by foreign keys """ pass # --------- base relation functionality --------- # - @property - def is_declared(self): - if self.__heading is not None: + @classproperty + def is_declared(cls): + if cls.__heading is not None: return True - cur = self.query( + cur = cls._shared_info.connection.query( 'SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format( - table_name=self.table_name)) + table_name=cls.table_name)) return cur.rowcount == 1 - @property - def heading(self): + @classproperty + def heading(cls): """ Required by relational operand :return: a datajoint.Heading object """ - if self.__heading is None: - if not self.is_declared and self.definition: - self.declare() - if self.is_declared: - self.__heading = Heading.init_from_database( - self.connection, self.database, self.table_name) - return self.__heading + if cls.__heading is None: + cls.__heading = Heading.init_from_database(cls.connection, cls.database, cls.table_name) + return cls.__heading - @property - def from_clause(self): + @classproperty + def from_clause(cls): """ Required by the Relational class, this property specifies the contents of the FROM clause for the SQL SELECT statements. :return: """ - return '`%s`.`%s`' % (self.database, self.table_name) + return '`%s`.`%s`' % (cls.database, cls.table_name) - def declare(self): + @classmethod + def declare(cls): """ Declare the table in database if it doesn't already exist. :raises: DataJointError if the table cannot be declared. """ - if not self.is_declared: - self._declare() - # verify that declaration completed successfully - if not self.is_declared: - raise DataJointError( - 'Relation could not be declared for %s' % self.class_name) + if not cls.is_declared: + cls._declare() - def iter_insert(self, rows, **kwargs): + @classmethod + def iter_insert(cls, rows, **kwargs): """ Inserts an entire batch of entries. Additional keyword arguments are passed to insert. :param iter: Must be an iterator that generates a sequence of valid arguments for insert. """ for row in rows: - self.insert(row, **kwargs) + cls.insert(row, **kwargs) - def batch_insert(self, data, **kwargs): + @classmethod + def batch_insert(cls, data, **kwargs): """ Inserts an entire batch of entries. Additional keyword arguments are passed to insert. :param data: must be iterable, each row must be a valid argument for insert """ - self.iter_insert(data.__iter__(), **kwargs) + cls.iter_insert(data.__iter__(), **kwargs) - def insert(self, tup, ignore_errors=False, replace=False): + @classproperty + def full_table_name(cls): + return r"`{0:s}`.`{1:s}`".format(cls.database, cls.table_name) + + @classmethod + def insert(cls, tup, ignore_errors=False, replace=False): """ Insert one data record or one Mapping (like a dictionary). @@ -137,7 +205,7 @@ def insert(self, tup, ignore_errors=False, replace=False): real_id = 1007, date_of_birth = "2014-09-01")) """ - heading = self.heading + heading = cls.heading if isinstance(tup, np.void): for fieldname in tup.dtype.fields: if fieldname not in heading: @@ -167,10 +235,10 @@ def insert(self, tup, ignore_errors=False, replace=False): sql = 'INSERT IGNORE' else: sql = 'INSERT' - sql += " INTO %s (%s) VALUES (%s)" % (self.full_table_name, + sql += " INTO %s (%s) VALUES (%s)" % (cls.full_table_name, attribute_list, value_list) logger.info(sql) - self.connection.query(sql, args=args) + cls.connection.query(sql, args=args) def delete(self): if not config['safemode'] or user_choice( @@ -178,39 +246,41 @@ def delete(self): "Proceed?", default='no') == 'yes': self.connection.query('DELETE FROM ' + self.from_clause + self.where_clause) # TODO: make cascading (issue #15) - def drop(self): + @classmethod + def drop(cls): """ - Drops the table associated to this object. + Drops the table associated to this class. """ - if self.is_declared: + if cls.is_declared: if not config['safemode'] or user_choice( "You are about to drop an entire table. This operation cannot be undone.\n" "Proceed?", default='no') == 'yes': - self.connection.query('DROP TABLE %s' % self.full_table_name) # TODO: make cascading (issue #16) - self.connection.clear_dependencies(dbname=self.dbname) - self.connection.load_headings(dbname=self.dbname, force=True) - logger.info("Dropped table %s" % self.full_table_name) + cls.connection.query('DROP TABLE %s' % cls.full_table_name) # TODO: make cascading (issue #16) + # cls.connection.clear_dependencies(dbname=cls.dbname) #TODO: reimplement because clear_dependencies will be gone + # cls.connection.load_headings(dbname=cls.dbname, force=True) #TODO: reimplement because load_headings is gone + logger.info("Dropped table %s" % cls.full_table_name) - @property - def size_on_disk(self): + @classproperty + def size_on_disk(cls): """ :return: size of data and indices in MiB taken by the table on the storage device """ - cur = self.connection.query( + cur = cls.connection.query( 'SHOW TABLE STATUS FROM `{dbname}` WHERE NAME="{table}"'.format( - dbname=self.dbname, table=self.table_name), as_dict=True) + dbname=cls.dbname, table=cls.table_name), as_dict=True) ret = cur.fetchone() return (ret['Data_length'] + ret['Index_length'])/1024**2 - def set_table_comment(self, comment): + @classmethod + def set_table_comment(cls, comment): """ Update the table comment in the table definition. :param comment: new comment as string """ - # TODO: add verification procedure (github issue #24) - self.alter('COMMENT="%s"' % comment) + cls.alter('COMMENT="%s"' % comment) - def add_attribute(self, definition, after=None): + @classmethod + def add_attribute(cls, definition, after=None): """ Add a new attribute to the table. A full line from the table definition is passed in as definition. @@ -226,9 +296,10 @@ def add_attribute(self, definition, after=None): position = ' FIRST' if after is None else ( ' AFTER %s' % after if after else '') sql = field_to_sql(parse_attribute_definition(definition)) - self._alter('ADD COLUMN %s%s' % (sql[:-2], position)) + cls._alter('ADD COLUMN %s%s' % (sql[:-2], position)) - def drop_attribute(self, attr_name): + @classmethod + def drop_attribute(cls, attr_name): """ Drops the attribute attrName from this table. @@ -238,9 +309,10 @@ def drop_attribute(self, attr_name): "You are about to drop an attribute from a table." "This operation cannot be undone.\n" "Proceed?", default='no') == 'yes': - self._alter('DROP COLUMN `%s`' % attr_name) + cls._alter('DROP COLUMN `%s`' % attr_name) - def alter_attribute(self, attr_name, new_definition): + @classmethod + def alter_attribute(cls, attr_name, new_definition): """ Alter the definition of the field attr_name in this table using the new definition. @@ -248,44 +320,46 @@ def alter_attribute(self, attr_name, new_definition): :param new_definition: new definition of the field """ sql = field_to_sql(parse_attribute_definition(new_definition)) - self._alter('CHANGE COLUMN `%s` %s' % (attr_name, sql[:-2])) + cls._alter('CHANGE COLUMN `%s` %s' % (attr_name, sql[:-2])) - def erd(self, subset=None): + @classmethod + def erd(cls, subset=None): """ Plot the schema's entity relationship diagram (ERD). """ - def _alter(self, alter_statement): + @classmethod + def _alter(cls, alter_statement): """ Execute ALTER TABLE statement for this table. The schema will be reloaded within the connection object. :param alter_statement: alter statement """ - if self._conn.in_transaction: + if cls.connection.in_transaction: raise TransactionError( u"_alter is currently in transaction. Operation not allowed to avoid implicit commits.", - self._alter, args=(alter_statement,)) + cls._alter, args=(alter_statement,)) - sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement) - self.connection.query(sql) - self.connection.load_headings(self.dbname, force=True) + sql = 'ALTER TABLE %s %s' % (cls.full_table_name, alter_statement) + cls.connection.query(sql) + cls.connection.load_headings(cls.dbname, force=True) # TODO: place table definition sync mechanism - @staticmethod - def _declare(self): + @classmethod + def _declare(cls): """ Declares the table in the database if no table in the database matches this object. """ - if self.connection.in_transaction: + if cls.connection.in_transaction: raise TransactionError( - u"_declare is currently in transaction. Operation not allowed to avoid implicit commits.", self._declare) + u"_declare is currently in transaction. Operation not allowed to avoid implicit commits.", cls._declare) - if not self.definition: # if empty definition was supplied + if not cls.definition: # if empty definition was supplied raise DataJointError('Table definition is missing!') - table_info, parents, referenced, field_defs, index_defs = self._parse_declaration() + table_info, parents, referenced, field_defs, index_defs = cls._parse_declaration() - sql = 'CREATE TABLE %s (\n' % self.full_table_name + sql = 'CREATE TABLE %s (\n' % cls.full_table_name # add inherited primary key fields primary_key_fields = set() @@ -349,15 +423,16 @@ def _declare(self): sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % ( sql[:-2], table_info['comment']) - # make sure that the table does not alredy exist - self.load_heading() - if not self.is_declared: - # execute declaration - logger.debug('\n\n' + sql + '\n\n') - self.connection.query(sql) - self.load_heading() + # # make sure that the table does not alredy exist + # cls.load_heading() + # if not cls.is_declared: + # # execute declaration + # logger.debug('\n\n' + sql + '\n\n') + # cls.connection.query(sql) + # cls.load_heading() - def _parse_declaration(self): + @classmethod + def _parse_declaration(cls): """ Parse declaration and create new SQL table accordingly. """ @@ -365,7 +440,7 @@ def _parse_declaration(self): referenced = [] index_defs = [] field_defs = [] - declaration = re.split(r'\s*\n\s*', self.definition.strip()) + declaration = re.split(r'\s*\n\s*', cls.definition.strip()) # remove comment lines declaration = [x for x in declaration if not x.startswith('#')] @@ -391,7 +466,7 @@ def _parse_declaration(self): # foreign key ref_name = line[2:].strip() ref_list = parents if in_key else referenced - ref_list.append(self.lookup_name(ref_name)) + ref_list.append(cls.lookup_name(ref_name)) elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): index_defs.append(parse_index_definition(line)) elif attribute_regexp.match(line): @@ -402,7 +477,8 @@ def _parse_declaration(self): return table_info, parents, referenced, field_defs, index_defs - def lookup_name(self, name): + @classmethod + def lookup_name(cls, name): """ Lookup the referenced name in the context dictionary @@ -411,7 +487,7 @@ def lookup_name(self, name): """ parts = name.strip().split('.') try: - ref = self.context.get(parts[0]) + ref = cls.context.get(parts[0]) for attr in parts[1:]: ref = getattr(ref, attr) except (KeyError, AttributeError): @@ -419,4 +495,21 @@ def lookup_name(self, name): 'Foreign key reference to %s could not be resolved.' 'Please make sure the name exists' 'in the context of the class' % name) - return ref \ No newline at end of file + return ref + + @classproperty + def connection(cls): + """ + Returns the connection object of the class + + :return: the connection object + """ + return cls.connection + + @classproperty + def database(cls): + return cls._shared_info.database + + @classproperty + def context(cls): + return cls._shared_info.context diff --git a/datajoint/relations.py b/datajoint/relations.py deleted file mode 100644 index 08f3f8900..000000000 --- a/datajoint/relations.py +++ /dev/null @@ -1,94 +0,0 @@ -import abc -import logging -from collections import namedtuple -import pymysql -from .connection import conn -from .abstract_relation import Relation -from . import DataJointError - - -logger = logging.getLogger(__name__) - - -SharedInfo = namedtuple( - 'SharedInfo', - ('database', 'context', 'connection', 'heading', 'parents', 'children', 'references', 'referenced')) - - -def schema(database, context, connection=None): - """ - Returns a schema decorator that can be used to associate a Relation class to a - specific database with :param name:. Name reference to other tables in the table definition - will be resolved by looking up the corresponding key entry in the passed in context dictionary. - It is most common to set context equal to the return value of call to locals() in the module. - For more details, please refer to the tutorial online. - - :param database: name of the database to associate the decorated class with - :param context: dictionary used to resolve (any) name references within the table definition string - :param connection: connection object to the database server. If ommited, will try to establish connection according to - config values - :return: a decorator function to be used on Relation derivative classes - """ - if connection is None: - connection = conn() - - # if the database does not exist, create it - cur = connection.query("SHOW DATABASES LIKE '{database}'".format(database=database)) - if cur.rowcount == 0: - logger.info("Database `{database}` could not be found. " - "Attempting to create the database.".format(database=database)) - try: - connection.query("CREATE DATABASE `{database}`".format(database=database)) - logger.info('Created database `{database}`.'.format(database=database)) - except pymysql.OperationalError: - raise DataJointError("Database named `{database}` was not defined, and" - "an attempt to create has failed. Check" - " permissions.".format(database=database)) - - def decorator(cls): - cls._shared_info = SharedInfo( - database=database, - context=context, - connection=connection, - heading=None, - parents=[], - children=[], - references=[], - referenced=[] - ) - return cls - - return decorator - - -class ClassBoundRelation(Relation): - """ - Abstract class for dedicated table classes. - Subclasses of ClassBoundRelation are dedicated interfaces to a single table. - The main purpose of ClassBoundRelation is to encapsulated sharedInfo containing the table heading - and dependency information shared by all instances of - """ - - _shared_info = None - - def __init__(self): - if self._shared_info is None: - raise DataJointError('The class must define _shared_info') - - @property - def database(self): - return self._shared_info.database - - @property - def connection(self): - return self._shared_info.connection - - @property - def context(self): - return self._shared_info.context - - @property - def heading(self): - if self._shared_info.heading is None: - self._shared_info.heading = super().heading - return self._shared_info.heading diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 7fb5aba1a..56556508d 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -1,4 +1,4 @@ -from .relations import ClassBoundRelation +from datajoint.relation import ClassBoundRelation from .autopopulate import AutoPopulate from .utils import from_camel_case diff --git a/tests/test_relation.py b/tests/test_relation.py index 3facc6721..15f52f34a 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -12,7 +12,7 @@ from datajoint import DataJointError, TransactionError, AutoPopulate, Relation import numpy as np from numpy.testing import assert_array_equal -from datajoint.abstract_relation import FreeRelation +from datajoint.relation import FreeRelation import numpy as np From 2593680d8449bf290e0672dff775b9ccaba95787 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Mon, 8 Jun 2015 17:26:37 -0500 Subject: [PATCH 02/20] * decorator declares table --- datajoint/relation.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/datajoint/relation.py b/datajoint/relation.py index ef9e43938..729c92651 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -65,11 +65,8 @@ def decorator(cls): context=context, connection=connection, heading=None, - parents=[], - children=[], - references=[], - referenced=[] ) + cls.declare() return cls return decorator From 4097adbfea786fa78d0fea5f173712c683ace01f Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Mon, 8 Jun 2015 17:32:35 -0500 Subject: [PATCH 03/20] * removed implicit commit handling --- datajoint/__init__.py | 18 ------------------ datajoint/autopopulate.py | 13 +------------ datajoint/relation.py | 9 ++++----- 3 files changed, 5 insertions(+), 35 deletions(-) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 4eb1eb6d1..528e8d74d 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -17,24 +17,6 @@ class DataJointError(Exception): pass -class TransactionError(DataJointError): - """ - Base class for errors specific to DataJoint internal operation. - """ - def __init__(self, msg, f, args=None, kwargs=None): - super(TransactionError, self).__init__(msg) - self.operations = (f, args if args is not None else tuple(), - kwargs if kwargs is not None else {}) - - def resolve(self): - f, args, kwargs = self.operations - return f(*args, **kwargs) - - @property - def culprit(self): - return self.operations[0].__name__ - - # ----------- loads local configuration from file ---------------- from .settings import Config, CONFIGVAR, LOCALCONFIG, logger, log_levels config = Config() diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index a266b62f9..57de7ae8e 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -66,18 +66,7 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, else: logger.info('Populating: ' + str(key)) try: - for attempts in range(max_attempts): - try: - self._make_tuples(dict(key)) - break - except TransactionError as tr_err: - self.conn.cancel_transaction() - tr_err.resolve() - self.conn.start_transaction() - logger.info('Transaction error in {0:s}.'.format(tr_err.culprit)) - else: - raise DataJointError( - '%s._make_tuples failed after %i attempts, giving up' % (self.__class__,max_attempts)) + self._make_tuples(dict(key)) except Exception as error: self.conn.cancel_transaction() if not suppress_errors: diff --git a/datajoint/relation.py b/datajoint/relation.py index 729c92651..448feb526 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -334,9 +334,8 @@ def _alter(cls, alter_statement): :param alter_statement: alter statement """ if cls.connection.in_transaction: - raise TransactionError( - u"_alter is currently in transaction. Operation not allowed to avoid implicit commits.", - cls._alter, args=(alter_statement,)) + raise DataJointError( + u"_alter is currently in transaction. Operation not allowed to avoid implicit commits.") sql = 'ALTER TABLE %s %s' % (cls.full_table_name, alter_statement) cls.connection.query(sql) @@ -349,8 +348,8 @@ def _declare(cls): Declares the table in the database if no table in the database matches this object. """ if cls.connection.in_transaction: - raise TransactionError( - u"_declare is currently in transaction. Operation not allowed to avoid implicit commits.", cls._declare) + raise DataJointError( + u"_declare is currently in transaction. Operation not allowed to avoid implicit commits.") if not cls.definition: # if empty definition was supplied raise DataJointError('Table definition is missing!') From b6c3f96d872f933ed5132b8b5bbe3c0ff61351e7 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Mon, 8 Jun 2015 18:33:56 -0500 Subject: [PATCH 04/20] first test runs without failing --- datajoint/__init__.py | 2 +- datajoint/autopopulate.py | 2 +- datajoint/connection.py | 2 +- datajoint/relation.py | 5 +- datajoint/user_relations.py | 10 +- tests/__init__.py | 70 ++- tests/schemata/schema1/__init__.py | 10 +- tests/schemata/schema1/test1.py | 332 +++++----- tests/schemata/schema1/test2.py | 96 +-- tests/schemata/schema1/test3.py | 42 +- tests/schemata/schema1/test4.py | 34 +- tests/schemata/schema2/__init__.py | 4 +- tests/schemata/schema2/test1.py | 32 +- tests/test_connection.py | 492 +++++++-------- tests/test_free_relation.py | 410 ++++++------ tests/test_relation.py | 980 ++++++++++++++--------------- tests/test_relational_operand.py | 94 +-- tests/test_settings.py | 134 ++-- tests/test_utils.py | 66 +- 19 files changed, 1410 insertions(+), 1407 deletions(-) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 528e8d74d..8ebe93992 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -37,8 +37,8 @@ class DataJointError(Exception): # ------------- flatten import hierarchy ------------------------- from .connection import conn, Connection -from .user_relations import Manual, Lookup, Imported, Computed from .relation import Relation +from .user_relations import Manual, Lookup, Imported, Computed from .autopopulate import AutoPopulate from . import blob from .relational_operand import Not diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 57de7ae8e..ba349b250 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -1,5 +1,5 @@ from .relational_operand import RelationalOperand -from . import DataJointError, TransactionError, Relation +from . import DataJointError, Relation import abc import logging diff --git a/datajoint/connection.py b/datajoint/connection.py index 52dae7598..3ca01b6bd 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -17,7 +17,7 @@ def conn_container(): """ _connection = None # persistent connection object used by dj.conn() - def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False): + def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False): # TODO: thin wrapping layer to mimic singleton """ Manage a persistent connection object. This is one of several ways to configure and access a datajoint connection. diff --git a/datajoint/relation.py b/datajoint/relation.py index 448feb526..19c2ebac9 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -5,7 +5,7 @@ import re import abc -from . import DataJointError, config, TransactionError +from . import DataJointError, config import pymysql from datajoint import DataJointError, conn from .relational_operand import RelationalOperand @@ -21,6 +21,7 @@ ('database', 'context', 'connection', 'heading', 'parents', 'children', 'references', 'referenced')) + class classproperty: def __init__(self, getf): self._getf = getf @@ -500,7 +501,7 @@ def connection(cls): :return: the connection object """ - return cls.connection + return cls._shared_info.connection @classproperty def database(cls): diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 56556508d..60e013e6f 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -1,30 +1,30 @@ -from datajoint.relation import ClassBoundRelation +from datajoint.relation import Relation from .autopopulate import AutoPopulate from .utils import from_camel_case -class Manual(ClassBoundRelation): +class Manual(Relation): @property @classmethod def table_name(cls): return from_camel_case(cls.__name__) -class Lookup(ClassBoundRelation): +class Lookup(Relation): @property @classmethod def table_name(cls): return '#' + from_camel_case(cls.__name__) -class Imported(ClassBoundRelation, AutoPopulate): +class Imported(Relation, AutoPopulate): @property @classmethod def table_name(cls): return "_" + from_camel_case(cls.__name__) -class Computed(ClassBoundRelation, AutoPopulate): +class Computed(Relation, AutoPopulate): @property @classmethod def table_name(cls): diff --git a/tests/__init__.py b/tests/__init__.py index 09e358e98..4d4101116 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -18,6 +18,9 @@ 'user': environ.get('DJ_TEST_USER', 'datajoint'), 'passwd': environ.get('DJ_TEST_PASSWORD', 'datajoint') } + +conn = dj.conn(**CONN_INFO) + # Prefix for all databases used during testing PREFIX = environ.get('DJ_TEST_DB_PREFIX', 'dj') # Bare connection used for verification of query results @@ -32,7 +35,6 @@ def setup(): def teardown(): cleanup() - def cleanup(): """ Removes all databases with name starting with the prefix. @@ -52,36 +54,36 @@ def cleanup(): cur.execute('DROP DATABASE `{}`'.format(db)) cur.execute('SET FOREIGN_KEY_CHECKS=1') # set foreign key check back on cur.execute("COMMIT") - -def setup_sample_db(): - """ - Helper method to setup databases with tables to be used - during the test - """ - cur = BASE_CONN.cursor() - cur.execute("CREATE DATABASE IF NOT EXISTS `{}_test1`".format(PREFIX)) - cur.execute("CREATE DATABASE IF NOT EXISTS `{}_test2`".format(PREFIX)) - query1 = """ - CREATE TABLE `{prefix}_test1`.`subjects` - ( - subject_id SMALLINT COMMENT 'Unique subject ID', - subject_name VARCHAR(255) COMMENT 'Subject name', - subject_email VARCHAR(255) COMMENT 'Subject email address', - PRIMARY KEY (subject_id) - ) - """.format(prefix=PREFIX) - cur.execute(query1) - query2 = """ - CREATE TABLE `{prefix}_test2`.`experimenter` - ( - experimenter_id SMALLINT COMMENT 'Unique experimenter ID', - experimenter_name VARCHAR(255) COMMENT 'Experimenter name', - PRIMARY KEY (experimenter_id) - )""".format(prefix=PREFIX) - cur.execute(query2) - - - - - - +# +# def setup_sample_db(): +# """ +# Helper method to setup databases with tables to be used +# during the test +# """ +# cur = BASE_CONN.cursor() +# cur.execute("CREATE DATABASE IF NOT EXISTS `{}_test1`".format(PREFIX)) +# cur.execute("CREATE DATABASE IF NOT EXISTS `{}_test2`".format(PREFIX)) +# query1 = """ +# CREATE TABLE `{prefix}_test1`.`subjects` +# ( +# subject_id SMALLINT COMMENT 'Unique subject ID', +# subject_name VARCHAR(255) COMMENT 'Subject name', +# subject_email VARCHAR(255) COMMENT 'Subject email address', +# PRIMARY KEY (subject_id) +# ) +# """.format(prefix=PREFIX) +# cur.execute(query1) +# query2 = """ +# CREATE TABLE `{prefix}_test2`.`experimenter` +# ( +# experimenter_id SMALLINT COMMENT 'Unique experimenter ID', +# experimenter_name VARCHAR(255) COMMENT 'Experimenter name', +# PRIMARY KEY (experimenter_id) +# )""".format(prefix=PREFIX) +# cur.execute(query2) +# +# +# +# +# +# diff --git a/tests/schemata/schema1/__init__.py b/tests/schemata/schema1/__init__.py index 6032e7bd6..cae90cec9 100644 --- a/tests/schemata/schema1/__init__.py +++ b/tests/schemata/schema1/__init__.py @@ -1,5 +1,5 @@ -__author__ = 'eywalker' -import datajoint as dj - -print(__name__) -from .test3 import * \ No newline at end of file +# __author__ = 'eywalker' +# import datajoint as dj +# +# print(__name__) +# from .test3 import * \ No newline at end of file diff --git a/tests/schemata/schema1/test1.py b/tests/schemata/schema1/test1.py index 4c8df082f..d0d91707f 100644 --- a/tests/schemata/schema1/test1.py +++ b/tests/schemata/schema1/test1.py @@ -1,166 +1,166 @@ -""" -Test 1 Schema definition -""" -__author__ = 'eywalker' - -import datajoint as dj -from .. import schema2 - - -class Subjects(dj.Relation): - definition = """ - test1.Subjects (manual) # Basic subject info - - subject_id : int # unique subject id - --- - real_id : varchar(40) # real-world name - species = "mouse" : enum('mouse', 'monkey', 'human') # species - """ - -# test for shorthand -class Animals(dj.Relation): - definition = """ - test1.Animals (manual) # Listing of all info - - -> Subjects - --- - animal_dob :date # date of birth - """ - - -class Trials(dj.Relation): - definition = """ - test1.Trials (manual) # info about trials - - -> test1.Subjects - trial_id : int - --- - outcome : int # result of experiment - - notes="" : varchar(4096) # other comments - trial_ts=CURRENT_TIMESTAMP : timestamp # automatic - """ - - - -class SquaredScore(dj.Relation, dj.AutoPopulate): - definition = """ - test1.SquaredScore (computed) # cumulative outcome of trials - - -> test1.Subjects - -> test1.Trials - --- - squared : int # squared result of Trials outcome - """ - - @property - def populate_relation(self): - return Subjects() * Trials() - - def _make_tuples(self, key): - tmp = (Trials() & key).fetch1() - tmp2 = SquaredSubtable() & key - - self.insert(dict(key, squared=tmp['outcome']**2)) - - ss = SquaredSubtable() - - for i in range(10): - key['dummy'] = i - ss.insert(key) - - -class WrongImplementation(dj.Relation, dj.AutoPopulate): - definition = """ - test1.WrongImplementation (computed) # ignore - - -> test1.Subjects - -> test1.Trials - --- - dummy : int # ignore - """ - - @property - def populate_relation(self): - return {'subject_id':2} - - def _make_tuples(self, key): - pass - - - -class ErrorGenerator(dj.Relation, dj.AutoPopulate): - definition = """ - test1.ErrorGenerator (computed) # ignore - - -> test1.Subjects - -> test1.Trials - --- - dummy : int # ignore - """ - - @property - def populate_relation(self): - return Subjects() * Trials() - - def _make_tuples(self, key): - raise Exception("This is for testing") - - - - - - -class SquaredSubtable(dj.Relation): - definition = """ - test1.SquaredSubtable (computed) # cumulative outcome of trials - - -> test1.SquaredScore - dummy : int # dummy primary attribute - --- - """ - - -# test reference to another table in same schema -class Experiments(dj.Relation): - definition = """ - test1.Experiments (imported) # Experiment info - -> test1.Subjects - exp_id : int # unique id for experiment - --- - exp_data_file : varchar(255) # data file - """ - - -# refers to a table in dj_test2 (bound to test2) but without a class -class Sessions(dj.Relation): - definition = """ - test1.Sessions (manual) # Experiment sessions - -> test1.Subjects - -> test2.Experimenter - session_id : int # unique session id - --- - session_comment : varchar(255) # comment about the session - """ - - -class Match(dj.Relation): - definition = """ - test1.Match (manual) # Match between subject and color - -> schema2.Subjects - --- - dob : date # date of birth - """ - - -# this tries to reference a table in database directly without ORM -class TrainingSession(dj.Relation): - definition = """ - test1.TrainingSession (manual) # training sessions - -> `dj_test2`.Experimenter - session_id : int # training session id - """ - - -class Empty(dj.Relation): - pass +# """ +# Test 1 Schema definition +# """ +# __author__ = 'eywalker' +# +# import datajoint as dj +# from .. import schema2 +# +# +# class Subjects(dj.Relation): +# definition = """ +# test1.Subjects (manual) # Basic subject info +# +# subject_id : int # unique subject id +# --- +# real_id : varchar(40) # real-world name +# species = "mouse" : enum('mouse', 'monkey', 'human') # species +# """ +# +# # test for shorthand +# class Animals(dj.Relation): +# definition = """ +# test1.Animals (manual) # Listing of all info +# +# -> Subjects +# --- +# animal_dob :date # date of birth +# """ +# +# +# class Trials(dj.Relation): +# definition = """ +# test1.Trials (manual) # info about trials +# +# -> test1.Subjects +# trial_id : int +# --- +# outcome : int # result of experiment +# +# notes="" : varchar(4096) # other comments +# trial_ts=CURRENT_TIMESTAMP : timestamp # automatic +# """ +# +# +# +# class SquaredScore(dj.Relation, dj.AutoPopulate): +# definition = """ +# test1.SquaredScore (computed) # cumulative outcome of trials +# +# -> test1.Subjects +# -> test1.Trials +# --- +# squared : int # squared result of Trials outcome +# """ +# +# @property +# def populate_relation(self): +# return Subjects() * Trials() +# +# def _make_tuples(self, key): +# tmp = (Trials() & key).fetch1() +# tmp2 = SquaredSubtable() & key +# +# self.insert(dict(key, squared=tmp['outcome']**2)) +# +# ss = SquaredSubtable() +# +# for i in range(10): +# key['dummy'] = i +# ss.insert(key) +# +# +# class WrongImplementation(dj.Relation, dj.AutoPopulate): +# definition = """ +# test1.WrongImplementation (computed) # ignore +# +# -> test1.Subjects +# -> test1.Trials +# --- +# dummy : int # ignore +# """ +# +# @property +# def populate_relation(self): +# return {'subject_id':2} +# +# def _make_tuples(self, key): +# pass +# +# +# +# class ErrorGenerator(dj.Relation, dj.AutoPopulate): +# definition = """ +# test1.ErrorGenerator (computed) # ignore +# +# -> test1.Subjects +# -> test1.Trials +# --- +# dummy : int # ignore +# """ +# +# @property +# def populate_relation(self): +# return Subjects() * Trials() +# +# def _make_tuples(self, key): +# raise Exception("This is for testing") +# +# +# +# +# +# +# class SquaredSubtable(dj.Relation): +# definition = """ +# test1.SquaredSubtable (computed) # cumulative outcome of trials +# +# -> test1.SquaredScore +# dummy : int # dummy primary attribute +# --- +# """ +# +# +# # test reference to another table in same schema +# class Experiments(dj.Relation): +# definition = """ +# test1.Experiments (imported) # Experiment info +# -> test1.Subjects +# exp_id : int # unique id for experiment +# --- +# exp_data_file : varchar(255) # data file +# """ +# +# +# # refers to a table in dj_test2 (bound to test2) but without a class +# class Sessions(dj.Relation): +# definition = """ +# test1.Sessions (manual) # Experiment sessions +# -> test1.Subjects +# -> test2.Experimenter +# session_id : int # unique session id +# --- +# session_comment : varchar(255) # comment about the session +# """ +# +# +# class Match(dj.Relation): +# definition = """ +# test1.Match (manual) # Match between subject and color +# -> schema2.Subjects +# --- +# dob : date # date of birth +# """ +# +# +# # this tries to reference a table in database directly without ORM +# class TrainingSession(dj.Relation): +# definition = """ +# test1.TrainingSession (manual) # training sessions +# -> `dj_test2`.Experimenter +# session_id : int # training session id +# """ +# +# +# class Empty(dj.Relation): +# pass diff --git a/tests/schemata/schema1/test2.py b/tests/schemata/schema1/test2.py index 563fe6b52..aded2e4fb 100644 --- a/tests/schemata/schema1/test2.py +++ b/tests/schemata/schema1/test2.py @@ -1,48 +1,48 @@ -""" -Test 2 Schema definition -""" -__author__ = 'eywalker' - -import datajoint as dj -from . import test1 as alias -#from ..schema2 import test2 as test1 - - -# references to another schema -class Experiments(dj.Relation): - definition = """ - test2.Experiments (manual) # Basic subject info - -> test1.Subjects - experiment_id : int # unique experiment id - --- - real_id : varchar(40) # real-world name - species = "mouse" : enum('mouse', 'monkey', 'human') # species - """ - - -# references to another schema -class Conditions(dj.Relation): - definition = """ - test2.Conditions (manual) # Subject conditions - -> alias.Subjects - condition_name : varchar(255) # description of the condition - """ - - -class FoodPreference(dj.Relation): - definition = """ - test2.FoodPreference (manual) # Food preference of each subject - -> animals.Subjects - preferred_food : enum('banana', 'apple', 'oranges') - """ - - -class Session(dj.Relation): - definition = """ - test2.Session (manual) # Experiment sessions - -> test1.Subjects - -> test2.Experimenter - session_id : int # unique session id - --- - session_comment : varchar(255) # comment about the session - """ \ No newline at end of file +# """ +# Test 2 Schema definition +# """ +# __author__ = 'eywalker' +# +# import datajoint as dj +# from . import test1 as alias +# #from ..schema2 import test2 as test1 +# +# +# # references to another schema +# class Experiments(dj.Relation): +# definition = """ +# test2.Experiments (manual) # Basic subject info +# -> test1.Subjects +# experiment_id : int # unique experiment id +# --- +# real_id : varchar(40) # real-world name +# species = "mouse" : enum('mouse', 'monkey', 'human') # species +# """ +# +# +# # references to another schema +# class Conditions(dj.Relation): +# definition = """ +# test2.Conditions (manual) # Subject conditions +# -> alias.Subjects +# condition_name : varchar(255) # description of the condition +# """ +# +# +# class FoodPreference(dj.Relation): +# definition = """ +# test2.FoodPreference (manual) # Food preference of each subject +# -> animals.Subjects +# preferred_food : enum('banana', 'apple', 'oranges') +# """ +# +# +# class Session(dj.Relation): +# definition = """ +# test2.Session (manual) # Experiment sessions +# -> test1.Subjects +# -> test2.Experimenter +# session_id : int # unique session id +# --- +# session_comment : varchar(255) # comment about the session +# """ \ No newline at end of file diff --git a/tests/schemata/schema1/test3.py b/tests/schemata/schema1/test3.py index 2004a8736..e00a01afb 100644 --- a/tests/schemata/schema1/test3.py +++ b/tests/schemata/schema1/test3.py @@ -1,21 +1,21 @@ -""" -Test 3 Schema definition - no binding, no conn - -To be bound at the package level -""" -__author__ = 'eywalker' - -import datajoint as dj - - -class Subjects(dj.Relation): - definition = """ - schema1.Subjects (manual) # Basic subject info - - subject_id : int # unique subject id - dob : date # date of birth - --- - real_id : varchar(40) # real-world name - species = "mouse" : enum('mouse', 'monkey', 'human') # species - """ - +# """ +# Test 3 Schema definition - no binding, no conn +# +# To be bound at the package level +# """ +# __author__ = 'eywalker' +# +# import datajoint as dj +# +# +# class Subjects(dj.Relation): +# definition = """ +# schema1.Subjects (manual) # Basic subject info +# +# subject_id : int # unique subject id +# dob : date # date of birth +# --- +# real_id : varchar(40) # real-world name +# species = "mouse" : enum('mouse', 'monkey', 'human') # species +# """ +# diff --git a/tests/schemata/schema1/test4.py b/tests/schemata/schema1/test4.py index a2004affd..9860cb030 100644 --- a/tests/schemata/schema1/test4.py +++ b/tests/schemata/schema1/test4.py @@ -1,17 +1,17 @@ -""" -Test 1 Schema definition - fully bound and has connection object -""" -__author__ = 'fabee' - -import datajoint as dj - - -class Matrix(dj.Relation): - definition = """ - test4.Matrix (manual) # Some numpy array - - matrix_id : int # unique matrix id - --- - data : longblob # data - comment : varchar(1000) # comment - """ +# """ +# Test 1 Schema definition - fully bound and has connection object +# """ +# __author__ = 'fabee' +# +# import datajoint as dj +# +# +# class Matrix(dj.Relation): +# definition = """ +# test4.Matrix (manual) # Some numpy array +# +# matrix_id : int # unique matrix id +# --- +# data : longblob # data +# comment : varchar(1000) # comment +# """ diff --git a/tests/schemata/schema2/__init__.py b/tests/schemata/schema2/__init__.py index e6b482590..d79e02cc3 100644 --- a/tests/schemata/schema2/__init__.py +++ b/tests/schemata/schema2/__init__.py @@ -1,2 +1,2 @@ -__author__ = 'eywalker' -from .test1 import * \ No newline at end of file +# __author__ = 'eywalker' +# from .test1 import * \ No newline at end of file diff --git a/tests/schemata/schema2/test1.py b/tests/schemata/schema2/test1.py index 83bb3a19e..4005aa670 100644 --- a/tests/schemata/schema2/test1.py +++ b/tests/schemata/schema2/test1.py @@ -1,16 +1,16 @@ -""" -Test 2 Schema definition -""" -__author__ = 'eywalker' - -import datajoint as dj - - -class Subjects(dj.Relation): - definition = """ - schema2.Subjects (manual) # Basic subject info - pop_id : int # unique experiment id - --- - real_id : varchar(40) # real-world name - species = "mouse" : enum('mouse', 'monkey', 'human') # species - """ \ No newline at end of file +# """ +# Test 2 Schema definition +# """ +# __author__ = 'eywalker' +# +# import datajoint as dj +# +# +# class Subjects(dj.Relation): +# definition = """ +# schema2.Subjects (manual) # Basic subject info +# pop_id : int # unique experiment id +# --- +# real_id : varchar(40) # real-world name +# species = "mouse" : enum('mouse', 'monkey', 'human') # species +# """ \ No newline at end of file diff --git a/tests/test_connection.py b/tests/test_connection.py index 1fb581468..29fee4f64 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,257 +1,257 @@ -""" -Collection of test cases to test connection module. -""" -from .schemata import schema1 -from .schemata.schema1 import test1 -import numpy as np - -__author__ = 'eywalker, fabee' -from . import (CONN_INFO, PREFIX, BASE_CONN, cleanup) -from nose.tools import assert_true, assert_raises, assert_equal, raises -import datajoint as dj -from datajoint.utils import DataJointError - - -def setup(): - cleanup() - - -def test_dj_conn(): - """ - Should be able to establish a connection - """ - c = dj.conn(**CONN_INFO) - assert c.is_connected - - -def test_persistent_dj_conn(): - """ - conn() method should provide persistent connection - across calls. - """ - c1 = dj.conn(**CONN_INFO) - c2 = dj.conn() - assert_true(c1 is c2) - - -def test_dj_conn_reset(): - """ - Passing in reset=True should allow for new persistent - connection to be created. - """ - c1 = dj.conn(**CONN_INFO) - c2 = dj.conn(reset=True, **CONN_INFO) - assert_true(c1 is not c2) - - - -def setup_sample_db(): - """ - Helper method to setup databases with tables to be used - during the test - """ - cur = BASE_CONN.cursor() - cur.execute("CREATE DATABASE `{}_test1`".format(PREFIX)) - cur.execute("CREATE DATABASE `{}_test2`".format(PREFIX)) - query1 = """ - CREATE TABLE `{prefix}_test1`.`subjects` - ( - subject_id SMALLINT COMMENT 'Unique subject ID', - subject_name VARCHAR(255) COMMENT 'Subject name', - subject_email VARCHAR(255) COMMENT 'Subject email address', - PRIMARY KEY (subject_id) - ) - """.format(prefix=PREFIX) - cur.execute(query1) - # query2 = """ - # CREATE TABLE `{prefix}_test2`.`experiments` - # ( - # experiment_id SMALLINT COMMENT 'Unique experiment ID', - # experiment_name VARCHAR(255) COMMENT 'Experiment name', - # subject_id SMALLINT, - # CONSTRAINT FOREIGN KEY (`subject_id`) REFERENCES `dj_test1`.`subjects` (`subject_id`) ON UPDATE CASCADE ON DELETE RESTRICT, - # PRIMARY KEY (subject_id, experiment_id) - # )""".format(prefix=PREFIX) - # cur.execute(query2) - - -class TestConnectionWithoutBindings(object): - """ - Test methods from Connection that does not - depend on presence of module to database bindings. - This includes tests for `bind` method itself. - """ - def setup(self): - self.conn = dj.Connection(**CONN_INFO) - test1.__dict__.pop('conn', None) - schema1.__dict__.pop('conn', None) - setup_sample_db() - - def teardown(self): - cleanup() - - def check_binding(self, db_name, module): - """ - Helper method to check if the specified database-module pairing exists - """ - assert_equal(self.conn.db_to_mod[db_name], module) - assert_equal(self.conn.mod_to_db[module], db_name) - - def test_bind_to_existing_database(self): - """ - Should be able to bind a module to an existing database - """ - db_name = PREFIX + '_test1' - module = test1.__name__ - self.conn.bind(module, db_name) - self.check_binding(db_name, module) - - def test_bind_at_package_level(self): - db_name = PREFIX + '_test1' - package = schema1.__name__ - self.conn.bind(package, db_name) - self.check_binding(db_name, package) - - def test_bind_to_non_existing_database(self): - """ - Should be able to bind a module to a non-existing database by creating target - """ - db_name = PREFIX + '_test3' - module = test1.__name__ - cur = BASE_CONN.cursor() - - # Ensure target database doesn't exist - if cur.execute("SHOW DATABASES LIKE '{}'".format(db_name)): - cur.execute("DROP DATABASE IF EXISTS `{}`".format(db_name)) - # Bind module to non-existing database - self.conn.bind(module, db_name) - # Check that target database was created - assert_equal(cur.execute("SHOW DATABASES LIKE '{}'".format(db_name)), 1) - self.check_binding(db_name, module) - # Remove the target database - cur.execute("DROP DATABASE IF EXISTS `{}`".format(db_name)) - - def test_cannot_bind_to_multiple_databases(self): - """ - Bind will fail when db_name is a pattern that - matches multiple databases - """ - db_name = PREFIX + "_test%%" - module = test1.__name__ - with assert_raises(DataJointError): - self.conn.bind(module, db_name) - - def test_basic_sql_query(self): - """ - Test execution of basic SQL query using connection - object. - """ - cur = self.conn.query('SHOW DATABASES') - results1 = cur.fetchall() - cur2 = BASE_CONN.cursor() - cur2.execute('SHOW DATABASES') - results2 = cur2.fetchall() - assert_equal(results1, results2) - - def test_transaction_commit(self): - """ - Test transaction commit - """ - table_name = PREFIX + '_test1.subjects' - self.conn.start_transaction() - self.conn.query("INSERT INTO {table} VALUES (0, 'dj_user', 'dj_user@example.com')".format(table=table_name)) - cur = BASE_CONN.cursor() - assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 0) - self.conn.commit_transaction() - assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 1) - - def test_transaction_rollback(self): - """ - Test transaction rollback - """ - table_name = PREFIX + '_test1.subjects' - self.conn.start_transaction() - self.conn.query("INSERT INTO {table} VALUES (0, 'dj_user', 'dj_user@example.com')".format(table=table_name)) - cur = BASE_CONN.cursor() - assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 0) - self.conn.cancel_transaction() - assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 0) - -# class TestContextManager(object): -# def __init__(self): -# self.relvar = None -# self.setup() +# """ +# Collection of test cases to test connection module. +# """ +# from .schemata import schema1 +# from .schemata.schema1 import test1 +# import numpy as np # +# __author__ = 'eywalker, fabee' +# from . import (CONN_INFO, PREFIX, BASE_CONN, cleanup) +# from nose.tools import assert_true, assert_raises, assert_equal, raises +# import datajoint as dj +# from datajoint.utils import DataJointError +# +# +# def setup(): +# cleanup() +# +# +# def test_dj_conn(): # """ -# Test cases for FreeRelation objects +# Should be able to establish a connection # """ +# c = dj.conn(**CONN_INFO) +# assert c.is_connected # -# def setup(self): -# """ -# Create a connection object and prepare test modules -# as follows: -# test1 - has conn and bounded -# """ -# cleanup() # drop all databases with PREFIX -# test1.__dict__.pop('conn', None) # -# self.conn = dj.Connection(**CONN_INFO) -# test1.conn = self.conn -# self.conn.bind(test1.__name__, PREFIX + '_test1') -# self.relvar = test1.Subjects() +# def test_persistent_dj_conn(): +# """ +# conn() method should provide persistent connection +# across calls. +# """ +# c1 = dj.conn(**CONN_INFO) +# c2 = dj.conn() +# assert_true(c1 is c2) # -# def teardown(self): -# cleanup() # -# # def test_active(self): -# # with self.conn.transaction() as tr: -# # assert_true(tr.is_active, "Transaction is not active") -# -# # def test_rollback(self): -# # -# # tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], -# # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) -# # -# # self.relvar.insert(tmp[0]) -# # try: -# # with self.conn.transaction(): -# # self.relvar.insert(tmp[1]) -# # raise DataJointError("Just to test") -# # except DataJointError as e: -# # pass -# # -# # testt2 = (self.relvar & 'subject_id = 2').fetch() -# # assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") -# -# # def test_cancel(self): -# # """Tests cancelling a transaction""" -# # tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], -# # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) -# # -# # self.relvar.insert(tmp[0]) -# # with self.conn.transaction() as transaction: -# # self.relvar.insert(tmp[1]) -# # transaction.cancel() -# # -# # testt2 = (self.relvar & 'subject_id = 2').fetch() -# # assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") - - - -# class TestConnectionWithBindings(object): +# def test_dj_conn_reset(): +# """ +# Passing in reset=True should allow for new persistent +# connection to be created. +# """ +# c1 = dj.conn(**CONN_INFO) +# c2 = dj.conn(reset=True, **CONN_INFO) +# assert_true(c1 is not c2) +# +# +# +# def setup_sample_db(): +# """ +# Helper method to setup databases with tables to be used +# during the test +# """ +# cur = BASE_CONN.cursor() +# cur.execute("CREATE DATABASE `{}_test1`".format(PREFIX)) +# cur.execute("CREATE DATABASE `{}_test2`".format(PREFIX)) +# query1 = """ +# CREATE TABLE `{prefix}_test1`.`subjects` +# ( +# subject_id SMALLINT COMMENT 'Unique subject ID', +# subject_name VARCHAR(255) COMMENT 'Subject name', +# subject_email VARCHAR(255) COMMENT 'Subject email address', +# PRIMARY KEY (subject_id) +# ) +# """.format(prefix=PREFIX) +# cur.execute(query1) +# # query2 = """ +# # CREATE TABLE `{prefix}_test2`.`experiments` +# # ( +# # experiment_id SMALLINT COMMENT 'Unique experiment ID', +# # experiment_name VARCHAR(255) COMMENT 'Experiment name', +# # subject_id SMALLINT, +# # CONSTRAINT FOREIGN KEY (`subject_id`) REFERENCES `dj_test1`.`subjects` (`subject_id`) ON UPDATE CASCADE ON DELETE RESTRICT, +# # PRIMARY KEY (subject_id, experiment_id) +# # )""".format(prefix=PREFIX) +# # cur.execute(query2) +# +# +# class TestConnectionWithoutBindings(object): # """ -# Tests heading and dependency loadings +# Test methods from Connection that does not +# depend on presence of module to database bindings. +# This includes tests for `bind` method itself. # """ # def setup(self): # self.conn = dj.Connection(**CONN_INFO) -# cur.execute(query) - - - - - - - - - - +# test1.__dict__.pop('conn', None) +# schema1.__dict__.pop('conn', None) +# setup_sample_db() +# +# def teardown(self): +# cleanup() +# +# def check_binding(self, db_name, module): +# """ +# Helper method to check if the specified database-module pairing exists +# """ +# assert_equal(self.conn.db_to_mod[db_name], module) +# assert_equal(self.conn.mod_to_db[module], db_name) +# +# def test_bind_to_existing_database(self): +# """ +# Should be able to bind a module to an existing database +# """ +# db_name = PREFIX + '_test1' +# module = test1.__name__ +# self.conn.bind(module, db_name) +# self.check_binding(db_name, module) +# +# def test_bind_at_package_level(self): +# db_name = PREFIX + '_test1' +# package = schema1.__name__ +# self.conn.bind(package, db_name) +# self.check_binding(db_name, package) +# +# def test_bind_to_non_existing_database(self): +# """ +# Should be able to bind a module to a non-existing database by creating target +# """ +# db_name = PREFIX + '_test3' +# module = test1.__name__ +# cur = BASE_CONN.cursor() +# +# # Ensure target database doesn't exist +# if cur.execute("SHOW DATABASES LIKE '{}'".format(db_name)): +# cur.execute("DROP DATABASE IF EXISTS `{}`".format(db_name)) +# # Bind module to non-existing database +# self.conn.bind(module, db_name) +# # Check that target database was created +# assert_equal(cur.execute("SHOW DATABASES LIKE '{}'".format(db_name)), 1) +# self.check_binding(db_name, module) +# # Remove the target database +# cur.execute("DROP DATABASE IF EXISTS `{}`".format(db_name)) +# +# def test_cannot_bind_to_multiple_databases(self): +# """ +# Bind will fail when db_name is a pattern that +# matches multiple databases +# """ +# db_name = PREFIX + "_test%%" +# module = test1.__name__ +# with assert_raises(DataJointError): +# self.conn.bind(module, db_name) +# +# def test_basic_sql_query(self): +# """ +# Test execution of basic SQL query using connection +# object. +# """ +# cur = self.conn.query('SHOW DATABASES') +# results1 = cur.fetchall() +# cur2 = BASE_CONN.cursor() +# cur2.execute('SHOW DATABASES') +# results2 = cur2.fetchall() +# assert_equal(results1, results2) +# +# def test_transaction_commit(self): +# """ +# Test transaction commit +# """ +# table_name = PREFIX + '_test1.subjects' +# self.conn.start_transaction() +# self.conn.query("INSERT INTO {table} VALUES (0, 'dj_user', 'dj_user@example.com')".format(table=table_name)) +# cur = BASE_CONN.cursor() +# assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 0) +# self.conn.commit_transaction() +# assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 1) +# +# def test_transaction_rollback(self): +# """ +# Test transaction rollback +# """ +# table_name = PREFIX + '_test1.subjects' +# self.conn.start_transaction() +# self.conn.query("INSERT INTO {table} VALUES (0, 'dj_user', 'dj_user@example.com')".format(table=table_name)) +# cur = BASE_CONN.cursor() +# assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 0) +# self.conn.cancel_transaction() +# assert_equal(cur.execute("SELECT * FROM {}".format(table_name)), 0) +# +# # class TestContextManager(object): +# # def __init__(self): +# # self.relvar = None +# # self.setup() +# # +# # """ +# # Test cases for FreeRelation objects +# # """ +# # +# # def setup(self): +# # """ +# # Create a connection object and prepare test modules +# # as follows: +# # test1 - has conn and bounded +# # """ +# # cleanup() # drop all databases with PREFIX +# # test1.__dict__.pop('conn', None) +# # +# # self.conn = dj.Connection(**CONN_INFO) +# # test1.conn = self.conn +# # self.conn.bind(test1.__name__, PREFIX + '_test1') +# # self.relvar = test1.Subjects() +# # +# # def teardown(self): +# # cleanup() +# # +# # # def test_active(self): +# # # with self.conn.transaction() as tr: +# # # assert_true(tr.is_active, "Transaction is not active") +# # +# # # def test_rollback(self): +# # # +# # # tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], +# # # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) +# # # +# # # self.relvar.insert(tmp[0]) +# # # try: +# # # with self.conn.transaction(): +# # # self.relvar.insert(tmp[1]) +# # # raise DataJointError("Just to test") +# # # except DataJointError as e: +# # # pass +# # # +# # # testt2 = (self.relvar & 'subject_id = 2').fetch() +# # # assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") +# # +# # # def test_cancel(self): +# # # """Tests cancelling a transaction""" +# # # tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], +# # # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) +# # # +# # # self.relvar.insert(tmp[0]) +# # # with self.conn.transaction() as transaction: +# # # self.relvar.insert(tmp[1]) +# # # transaction.cancel() +# # # +# # # testt2 = (self.relvar & 'subject_id = 2').fetch() +# # # assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") +# +# +# +# # class TestConnectionWithBindings(object): +# # """ +# # Tests heading and dependency loadings +# # """ +# # def setup(self): +# # self.conn = dj.Connection(**CONN_INFO) +# # cur.execute(query) +# +# +# +# +# +# +# +# +# +# diff --git a/tests/test_free_relation.py b/tests/test_free_relation.py index e4ebaa872..285bf7487 100644 --- a/tests/test_free_relation.py +++ b/tests/test_free_relation.py @@ -1,205 +1,205 @@ -""" -Collection of test cases for base module. Tests functionalities such as -creating tables using docstring table declarations -""" -from .schemata import schema1, schema2 -from .schemata.schema1 import test1, test2, test3 - - -__author__ = 'eywalker' - -from . import BASE_CONN, CONN_INFO, PREFIX, cleanup, setup_sample_db -from datajoint.connection import Connection -from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, raises -from datajoint import DataJointError - - -def setup(): - """ - Setup connections and bindings - """ - pass - - -class TestRelationInstantiations(object): - """ - Test cases for instantiating Relation objects - """ - def __init__(self): - self.conn = None - - def setup(self): - """ - Create a connection object and prepare test modules - as follows: - test1 - has conn and bounded - """ - self.conn = Connection(**CONN_INFO) - cleanup() # drop all databases with PREFIX - #test1.conn = self.conn - #self.conn.bind(test1.__name__, PREFIX+'_test1') - - #test2.conn = self.conn - - #test3.__dict__.pop('conn', None) # make sure conn is not defined in test3 - test1.__dict__.pop('conn', None) - schema1.__dict__.pop('conn', None) # make sure conn is not defined at schema level - - - def teardown(self): - cleanup() - - - def test_instantiation_from_unbound_module_should_fail(self): - """ - Attempting to instantiate a Relation derivative from a module with - connection defined but not bound to a database should raise error - """ - test1.conn = self.conn - with assert_raises(DataJointError) as e: - test1.Subjects() - assert_regexp_matches(e.exception.args[0], r".*not bound.*") - - def test_instantiation_from_module_without_conn_should_fail(self): - """ - Attempting to instantiate a Relation derivative from a module that lacks - `conn` object should raise error - """ - with assert_raises(DataJointError) as e: - test1.Subjects() - assert_regexp_matches(e.exception.args[0], r".*define.*conn.*") - - def test_instantiation_of_base_derivatives(self): - """ - Test instantiation and initialization of objects derived from - Relation class - """ - test1.conn = self.conn - self.conn.bind(test1.__name__, PREFIX + '_test1') - s = test1.Subjects() - assert_equal(s.dbname, PREFIX + '_test1') - assert_equal(s.conn, self.conn) - assert_equal(s.definition, test1.Subjects.definition) - - def test_packagelevel_binding(self): - schema2.conn = self.conn - self.conn.bind(schema2.__name__, PREFIX + '_test1') - s = schema2.test1.Subjects() - - -class TestRelationDeclaration(object): - """ - Test declaration (creation of table) from - definition in Relation under various circumstances - """ - - def setup(self): - cleanup() - - self.conn = Connection(**CONN_INFO) - test1.conn = self.conn - self.conn.bind(test1.__name__, PREFIX + '_test1') - test2.conn = self.conn - self.conn.bind(test2.__name__, PREFIX + '_test2') - - def test_is_declared(self): - """ - The table should not be created immediately after instantiation, - but should be created when declare method is called - :return: - """ - s = test1.Subjects() - assert_false(s.is_declared) - s.declare() - assert_true(s.is_declared) - - def test_calling_heading_should_trigger_declaration(self): - s = test1.Subjects() - assert_false(s.is_declared) - a = s.heading - assert_true(s.is_declared) - - def test_foreign_key_ref_in_same_schema(self): - s = test1.Experiments() - assert_true('subject_id' in s.heading.primary_key) - - def test_foreign_key_ref_in_another_schema(self): - s = test2.Experiments() - assert_true('subject_id' in s.heading.primary_key) - - def test_aliased_module_name_should_resolve(self): - """ - Module names that were aliased in the definition should - be properly resolved. - """ - s = test2.Conditions() - assert_true('subject_id' in s.heading.primary_key) - - def test_reference_to_unknown_module_in_definition_should_fail(self): - """ - Module names in table definition that is not aliased via import - results in error - """ - s = test2.FoodPreference() - with assert_raises(DataJointError) as e: - s.declare() - - -class TestRelationWithExistingTables(object): - """ - Test base derivatives behaviors when some of the tables - already exists in the database - """ - def setup(self): - cleanup() - self.conn = Connection(**CONN_INFO) - setup_sample_db() - test1.conn = self.conn - self.conn.bind(test1.__name__, PREFIX + '_test1') - test2.conn = self.conn - self.conn.bind(test2.__name__, PREFIX + '_test2') - self.conn.load_headings(force=True) - - schema2.conn = self.conn - self.conn.bind(schema2.__name__, PREFIX + '_package') - - def teardown(selfself): - schema1.__dict__.pop('conn', None) - cleanup() - - def test_detection_of_existing_table(self): - """ - The Relation instance should be able to detect if the - corresponding table already exists in the database - """ - s = test1.Subjects() - assert_true(s.is_declared) - - def test_definition_referring_to_existing_table_without_class(self): - s1 = test1.Sessions() - assert_true('experimenter_id' in s1.primary_key) - - s2 = test2.Session() - assert_true('experimenter_id' in s2.primary_key) - - def test_reference_to_package_level_table(self): - s = test1.Match() - s.declare() - assert_true('pop_id' in s.primary_key) - - def test_direct_reference_to_existing_table_should_fail(self): - """ - When deriving from Relation, definition should not contain direct reference - to a database name - """ - s = test1.TrainingSession() - with assert_raises(DataJointError): - s.declare() - -@raises(TypeError) -def test_instantiation_of_base_derivative_without_definition_should_fail(): - test1.Empty() - - - - +# """ +# Collection of test cases for base module. Tests functionalities such as +# creating tables using docstring table declarations +# """ +# from .schemata import schema1, schema2 +# from .schemata.schema1 import test1, test2, test3 +# +# +# __author__ = 'eywalker' +# +# from . import BASE_CONN, CONN_INFO, PREFIX, cleanup, setup_sample_db +# from datajoint.connection import Connection +# from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, raises +# from datajoint import DataJointError +# +# +# def setup(): +# """ +# Setup connections and bindings +# """ +# pass +# +# +# class TestRelationInstantiations(object): +# """ +# Test cases for instantiating Relation objects +# """ +# def __init__(self): +# self.conn = None +# +# def setup(self): +# """ +# Create a connection object and prepare test modules +# as follows: +# test1 - has conn and bounded +# """ +# self.conn = Connection(**CONN_INFO) +# cleanup() # drop all databases with PREFIX +# #test1.conn = self.conn +# #self.conn.bind(test1.__name__, PREFIX+'_test1') +# +# #test2.conn = self.conn +# +# #test3.__dict__.pop('conn', None) # make sure conn is not defined in test3 +# test1.__dict__.pop('conn', None) +# schema1.__dict__.pop('conn', None) # make sure conn is not defined at schema level +# +# +# def teardown(self): +# cleanup() +# +# +# def test_instantiation_from_unbound_module_should_fail(self): +# """ +# Attempting to instantiate a Relation derivative from a module with +# connection defined but not bound to a database should raise error +# """ +# test1.conn = self.conn +# with assert_raises(DataJointError) as e: +# test1.Subjects() +# assert_regexp_matches(e.exception.args[0], r".*not bound.*") +# +# def test_instantiation_from_module_without_conn_should_fail(self): +# """ +# Attempting to instantiate a Relation derivative from a module that lacks +# `conn` object should raise error +# """ +# with assert_raises(DataJointError) as e: +# test1.Subjects() +# assert_regexp_matches(e.exception.args[0], r".*define.*conn.*") +# +# def test_instantiation_of_base_derivatives(self): +# """ +# Test instantiation and initialization of objects derived from +# Relation class +# """ +# test1.conn = self.conn +# self.conn.bind(test1.__name__, PREFIX + '_test1') +# s = test1.Subjects() +# assert_equal(s.dbname, PREFIX + '_test1') +# assert_equal(s.conn, self.conn) +# assert_equal(s.definition, test1.Subjects.definition) +# +# def test_packagelevel_binding(self): +# schema2.conn = self.conn +# self.conn.bind(schema2.__name__, PREFIX + '_test1') +# s = schema2.test1.Subjects() +# +# +# class TestRelationDeclaration(object): +# """ +# Test declaration (creation of table) from +# definition in Relation under various circumstances +# """ +# +# def setup(self): +# cleanup() +# +# self.conn = Connection(**CONN_INFO) +# test1.conn = self.conn +# self.conn.bind(test1.__name__, PREFIX + '_test1') +# test2.conn = self.conn +# self.conn.bind(test2.__name__, PREFIX + '_test2') +# +# def test_is_declared(self): +# """ +# The table should not be created immediately after instantiation, +# but should be created when declare method is called +# :return: +# """ +# s = test1.Subjects() +# assert_false(s.is_declared) +# s.declare() +# assert_true(s.is_declared) +# +# def test_calling_heading_should_trigger_declaration(self): +# s = test1.Subjects() +# assert_false(s.is_declared) +# a = s.heading +# assert_true(s.is_declared) +# +# def test_foreign_key_ref_in_same_schema(self): +# s = test1.Experiments() +# assert_true('subject_id' in s.heading.primary_key) +# +# def test_foreign_key_ref_in_another_schema(self): +# s = test2.Experiments() +# assert_true('subject_id' in s.heading.primary_key) +# +# def test_aliased_module_name_should_resolve(self): +# """ +# Module names that were aliased in the definition should +# be properly resolved. +# """ +# s = test2.Conditions() +# assert_true('subject_id' in s.heading.primary_key) +# +# def test_reference_to_unknown_module_in_definition_should_fail(self): +# """ +# Module names in table definition that is not aliased via import +# results in error +# """ +# s = test2.FoodPreference() +# with assert_raises(DataJointError) as e: +# s.declare() +# +# +# class TestRelationWithExistingTables(object): +# """ +# Test base derivatives behaviors when some of the tables +# already exists in the database +# """ +# def setup(self): +# cleanup() +# self.conn = Connection(**CONN_INFO) +# setup_sample_db() +# test1.conn = self.conn +# self.conn.bind(test1.__name__, PREFIX + '_test1') +# test2.conn = self.conn +# self.conn.bind(test2.__name__, PREFIX + '_test2') +# self.conn.load_headings(force=True) +# +# schema2.conn = self.conn +# self.conn.bind(schema2.__name__, PREFIX + '_package') +# +# def teardown(selfself): +# schema1.__dict__.pop('conn', None) +# cleanup() +# +# def test_detection_of_existing_table(self): +# """ +# The Relation instance should be able to detect if the +# corresponding table already exists in the database +# """ +# s = test1.Subjects() +# assert_true(s.is_declared) +# +# def test_definition_referring_to_existing_table_without_class(self): +# s1 = test1.Sessions() +# assert_true('experimenter_id' in s1.primary_key) +# +# s2 = test2.Session() +# assert_true('experimenter_id' in s2.primary_key) +# +# def test_reference_to_package_level_table(self): +# s = test1.Match() +# s.declare() +# assert_true('pop_id' in s.primary_key) +# +# def test_direct_reference_to_existing_table_should_fail(self): +# """ +# When deriving from Relation, definition should not contain direct reference +# to a database name +# """ +# s = test1.TrainingSession() +# with assert_raises(DataJointError): +# s.declare() +# +# @raises(TypeError) +# def test_instantiation_of_base_derivative_without_definition_should_fail(): +# test1.Empty() +# +# +# +# diff --git a/tests/test_relation.py b/tests/test_relation.py index 15f52f34a..94adef010 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -1,490 +1,490 @@ -import random -import string - -__author__ = 'fabee' - -from .schemata.schema1 import test1, test4 - -from . import BASE_CONN, CONN_INFO, PREFIX, cleanup -from datajoint.connection import Connection -from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal,\ - assert_tuple_equal, assert_dict_equal, raises -from datajoint import DataJointError, TransactionError, AutoPopulate, Relation -import numpy as np -from numpy.testing import assert_array_equal -from datajoint.relation import FreeRelation -import numpy as np - - -def trial_faker(n=10): - def iter(): - for s in [1, 2]: - for i in range(n): - yield dict(trial_id=i, subject_id=s, outcome=int(np.random.randint(10)), notes= 'no comment') - return iter() - - -def setup(): - """ - Setup connections and bindings - """ - pass - - -class TestTableObject(object): - def __init__(self): - self.subjects = None - self.setup() - - """ - Test cases for FreeRelation objects - """ - - def setup(self): - """ - Create a connection object and prepare test modules - as follows: - test1 - has conn and bounded - """ - cleanup() # drop all databases with PREFIX - test1.__dict__.pop('conn', None) - test4.__dict__.pop('conn', None) # make sure conn is not defined at schema level - - self.conn = Connection(**CONN_INFO) - test1.conn = self.conn - test4.conn = self.conn - self.conn.bind(test1.__name__, PREFIX + '_test1') - self.conn.bind(test4.__name__, PREFIX + '_test4') - self.subjects = test1.Subjects() - self.animals = test1.Animals() - self.relvar_blob = test4.Matrix() - self.trials = test1.Trials() - - def teardown(self): - cleanup() - - def test_compound_restriction(self): - s = self.subjects - t = self.trials - - s.insert(dict(subject_id=1, real_id='M')) - s.insert(dict(subject_id=2, real_id='F')) - t.iter_insert(trial_faker(20)) - - tM = t & (s & "real_id = 'M'") - t1 = t & "subject_id = 1" - - assert_equal(len(tM), len(t1), "Results of compound request does not have same length") - - for t1_item, tM_item in zip(sorted(t1, key=lambda item: item['trial_id']), - sorted(tM, key=lambda item: item['trial_id'])): - assert_dict_equal(t1_item, tM_item, - 'Dictionary elements do not agree in compound statement') - - def test_record_insert(self): - "Test whether record insert works" - tmp = np.array([(2, 'Klara', 'monkey')], - dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) - - self.subjects.insert(tmp[0]) - testt2 = (self.subjects & 'subject_id = 2').fetch()[0] - assert_equal(tuple(tmp[0]), tuple(testt2), "Inserted and fetched record do not match!") - - def test_delete(self): - "Test whether delete works" - tmp = np.array([(2, 'Klara', 'monkey'), (1,'Peter', 'mouse')], - dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) - - self.subjects.batch_insert(tmp) - assert_true(len(self.subjects) == 2, 'Length does not match 2.') - self.subjects.delete() - assert_true(len(self.subjects) == 0, 'Length does not match 0.') - - # def test_cascading_delete(self): - # "Test whether delete works" - # tmp = np.array([(2, 'Klara', 'monkey'), (1,'Peter', 'mouse')], - # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) - # - # self.subjects.batch_insert(tmp) - # - # self.trials.insert(dict(subject_id=1, trial_id=1, outcome=0)) - # self.trials.insert(dict(subject_id=1, trial_id=2, outcome=1)) - # self.trials.insert(dict(subject_id=2, trial_id=3, outcome=2)) - # assert_true(len(self.subjects) == 2, 'Length does not match 2.') - # assert_true(len(self.trials) == 3, 'Length does not match 3.') - # (self.subjects & 'subject_id=1').delete() - # assert_true(len(self.subjects) == 1, 'Length does not match 1.') - # assert_true(len(self.trials) == 1, 'Length does not match 1.') - - def test_short_hand_foreign_reference(self): - self.animals.heading - - - - def test_record_insert_different_order(self): - "Test whether record insert works" - tmp = np.array([('Klara', 2, 'monkey')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - - self.subjects.insert(tmp[0]) - testt2 = (self.subjects & 'subject_id = 2').fetch()[0] - assert_equal((2, 'Klara', 'monkey'), tuple(testt2), - "Inserted and fetched record do not match!") - - @raises(TransactionError) - def test_transaction_error(self): - "Test whether declaration in transaction is prohibited" - - tmp = np.array([('Klara', 2, 'monkey')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - self.conn.start_transaction() - self.subjects.insert(tmp[0]) - - # def test_transaction_suppress_error(self): - # "Test whether ignore_errors ignores the errors." - # - # tmp = np.array([('Klara', 2, 'monkey')], - # dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - # with self.conn.transaction(ignore_errors=True) as tr: - # self.subjects.insert(tmp[0]) - - - @raises(TransactionError) - def test_transaction_error_not_resolve(self): - "Test whether declaration in transaction is prohibited" - - tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - try: - self.conn.start_transaction() - self.subjects.insert(tmp[0]) - except TransactionError as te: - self.conn.cancel_transaction() - - self.conn.start_transaction() - self.subjects.insert(tmp[0]) - - def test_transaction_error_resolve(self): - "Test whether declaration in transaction is prohibited" - - tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - try: - self.conn.start_transaction() - self.subjects.insert(tmp[0]) - except TransactionError as te: - self.conn.cancel_transaction() - te.resolve() - - self.conn.start_transaction() - self.subjects.insert(tmp[0]) - self.conn.commit_transaction() - - def test_transaction_error2(self): - "If table is declared, we are allowed to insert within a transaction" - - tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - self.subjects.insert(tmp[0]) - - self.conn.start_transaction() - self.subjects.insert(tmp[1]) - self.conn.commit_transaction() - - - @raises(KeyError) - def test_wrong_key_insert_records(self): - "Test whether record insert works" - tmp = np.array([('Klara', 2, 'monkey')], - dtype=[('real_deal', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - - self.subjects.insert(tmp[0]) - - - def test_dict_insert(self): - "Test whether record insert works" - tmp = {'real_id': 'Brunhilda', - 'subject_id': 3, - 'species': 'human'} - - self.subjects.insert(tmp) - testt2 = (self.subjects & 'subject_id = 3').fetch()[0] - assert_equal((3, 'Brunhilda', 'human'), tuple(testt2), "Inserted and fetched record do not match!") - - @raises(KeyError) - def test_wrong_key_insert(self): - "Test whether a correct error is generated when inserting wrong attribute name" - tmp = {'real_deal': 'Brunhilda', - 'subject_database': 3, - 'species': 'human'} - - self.subjects.insert(tmp) - - def test_batch_insert(self): - "Test whether record insert works" - tmp = np.array([('Klara', 2, 'monkey'), ('Brunhilda', 3, 'mouse'), ('Mickey', 1, 'human')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - - self.subjects.batch_insert(tmp) - - expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'), - (3, 'Brunhilda', 'mouse')], - dtype=[('subject_id', 'i4'), ('species', 'O')]) - - self.subjects.iter_insert(tmp.__iter__()) - - expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'), - (3, 'Brunhilda', 'mouse')], - dtype=[('subject_id', ' `dj_free`.Animals - rec_session_id : int # recording session identifier - """ - table = FreeRelation(self.conn, 'dj_free', 'Recordings', definition) - assert_raises(DataJointError, table.declare) - - def test_reference_to_existing_table(self): - definition1 = """ - `dj_free`.Animals (manual) # my animal table - animal_id : int # unique id for the animal - --- - animal_name : varchar(128) # name of the animal - """ - table1 = FreeRelation(self.conn, 'dj_free', 'Animals', definition1) - table1.declare() - - definition2 = """ - `dj_free`.Recordings (manual) # recordings - -> `dj_free`.Animals - rec_session_id : int # recording session identifier - """ - table2 = FreeRelation(self.conn, 'dj_free', 'Recordings', definition2) - table2.declare() - assert_true('animal_id' in table2.primary_key) - - -def id_generator(size=6, chars=string.ascii_uppercase + string.digits): - return ''.join(random.choice(chars) for _ in range(size)) - -class TestIterator(object): - def __init__(self): - self.relvar = None - self.setup() - - """ - Test cases for Iterators in Relations objects - """ - - def setup(self): - """ - Create a connection object and prepare test modules - as follows: - test1 - has conn and bounded - """ - cleanup() # drop all databases with PREFIX - test4.__dict__.pop('conn', None) # make sure conn is not defined at schema level - - self.conn = Connection(**CONN_INFO) - test4.conn = self.conn - self.conn.bind(test4.__name__, PREFIX + '_test4') - self.relvar_blob = test4.Matrix() - - def teardown(self): - cleanup() - - - def test_blob_iteration(self): - "Tests the basic call of the iterator" - - dicts = [] - for i in range(10): - - c = id_generator() - - t = {'matrix_id':i, - 'data': np.random.randn(4,4,4), - 'comment': c} - self.relvar_blob.insert(t) - dicts.append(t) - - for t, t2 in zip(dicts, self.relvar_blob): - assert_true(isinstance(t2, dict), 'iterator does not return dict') - - assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved tuples do not match') - assert_equal(t['comment'], t2['comment'], 'inserted and retrieved tuples do not match') - assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved tuples do not match') - - def test_fetch(self): - dicts = [] - for i in range(10): - - c = id_generator() - - t = {'matrix_id':i, - 'data': np.random.randn(4,4,4), - 'comment': c} - self.relvar_blob.insert(t) - dicts.append(t) - - tuples2 = self.relvar_blob.fetch() - assert_true(isinstance(tuples2, np.ndarray), "Return value of fetch does not have proper type.") - assert_true(isinstance(tuples2[0], np.void), "Return value of fetch does not have proper type.") - for t, t2 in zip(dicts, tuples2): - - assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved tuples do not match') - assert_equal(t['comment'], t2['comment'], 'inserted and retrieved tuples do not match') - assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved tuples do not match') - - def test_fetch_dicts(self): - dicts = [] - for i in range(10): - - c = id_generator() - - t = {'matrix_id':i, - 'data': np.random.randn(4,4,4), - 'comment': c} - self.relvar_blob.insert(t) - dicts.append(t) - - tuples2 = self.relvar_blob.fetch(as_dict=True) - assert_true(isinstance(tuples2, list), "Return value of fetch with as_dict=True does not have proper type.") - assert_true(isinstance(tuples2[0], dict), "Return value of fetch with as_dict=True does not have proper type.") - for t, t2 in zip(dicts, tuples2): - assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved dicts do not match') - assert_equal(t['comment'], t2['comment'], 'inserted and retrieved dicts do not match') - assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved dicts do not match') - - - -class TestAutopopulate(object): - def __init__(self): - self.relvar = None - self.setup() - - """ - Test cases for Iterators in Relations objects - """ - - def setup(self): - """ - Create a connection object and prepare test modules - as follows: - test1 - has conn and bounded - """ - cleanup() # drop all databases with PREFIX - test1.__dict__.pop('conn', None) # make sure conn is not defined at schema level - - self.conn = Connection(**CONN_INFO) - test1.conn = self.conn - self.conn.bind(test1.__name__, PREFIX + '_test1') - - self.subjects = test1.Subjects() - self.trials = test1.Trials() - self.squared = test1.SquaredScore() - self.dummy = test1.SquaredSubtable() - self.dummy1 = test1.WrongImplementation() - self.error_generator = test1.ErrorGenerator() - self.fill_relation() - - - - def fill_relation(self): - tmp = np.array([('Klara', 2, 'monkey'), ('Peter', 3, 'mouse')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - self.subjects.batch_insert(tmp) - - for trial_id in range(1,11): - self.trials.insert(dict(subject_id=2, trial_id=trial_id, outcome=np.random.randint(0,10))) - - def teardown(self): - cleanup() - - def test_autopopulate(self): - self.squared.populate() - assert_equal(len(self.squared), 10) - - for trial in self.trials*self.squared: - assert_equal(trial['outcome']**2, trial['squared']) - - def test_autopopulate_restriction(self): - self.squared.populate(restriction='trial_id <= 5') - assert_equal(len(self.squared), 5) - - for trial in self.trials*self.squared: - assert_equal(trial['outcome']**2, trial['squared']) - - - # def test_autopopulate_transaction_error(self): - # errors = self.squared.populate(suppress_errors=True) - # assert_equal(len(errors), 1) - # assert_true(isinstance(errors[0][1], TransactionError)) - - @raises(DataJointError) - def test_autopopulate_relation_check(self): - - class dummy(AutoPopulate): - - def populate_relation(self): - return None - - def _make_tuples(self, key): - pass - - du = dummy() - du.populate() \ - - @raises(DataJointError) - def test_autopopulate_relation_check(self): - self.dummy1.populate() - - @raises(Exception) - def test_autopopulate_relation_check(self): - self.error_generator.populate()\ - - @raises(Exception) - def test_autopopulate_relation_check2(self): - tmp = self.dummy2.populate(suppress_errors=True) - assert_equal(len(tmp), 1, 'Error list should have length 1.') +# import random +# import string +# +# __author__ = 'fabee' +# +# from .schemata.schema1 import test1, test4 +# +# from . import BASE_CONN, CONN_INFO, PREFIX, cleanup +# from datajoint.connection import Connection +# from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal,\ +# assert_tuple_equal, assert_dict_equal, raises +# from datajoint import DataJointError, TransactionError, AutoPopulate, Relation +# import numpy as np +# from numpy.testing import assert_array_equal +# from datajoint.relation import FreeRelation +# import numpy as np +# +# +# def trial_faker(n=10): +# def iter(): +# for s in [1, 2]: +# for i in range(n): +# yield dict(trial_id=i, subject_id=s, outcome=int(np.random.randint(10)), notes= 'no comment') +# return iter() +# +# +# def setup(): +# """ +# Setup connections and bindings +# """ +# pass +# +# +# class TestTableObject(object): +# def __init__(self): +# self.subjects = None +# self.setup() +# +# """ +# Test cases for FreeRelation objects +# """ +# +# def setup(self): +# """ +# Create a connection object and prepare test modules +# as follows: +# test1 - has conn and bounded +# """ +# cleanup() # drop all databases with PREFIX +# test1.__dict__.pop('conn', None) +# test4.__dict__.pop('conn', None) # make sure conn is not defined at schema level +# +# self.conn = Connection(**CONN_INFO) +# test1.conn = self.conn +# test4.conn = self.conn +# self.conn.bind(test1.__name__, PREFIX + '_test1') +# self.conn.bind(test4.__name__, PREFIX + '_test4') +# self.subjects = test1.Subjects() +# self.animals = test1.Animals() +# self.relvar_blob = test4.Matrix() +# self.trials = test1.Trials() +# +# def teardown(self): +# cleanup() +# +# def test_compound_restriction(self): +# s = self.subjects +# t = self.trials +# +# s.insert(dict(subject_id=1, real_id='M')) +# s.insert(dict(subject_id=2, real_id='F')) +# t.iter_insert(trial_faker(20)) +# +# tM = t & (s & "real_id = 'M'") +# t1 = t & "subject_id = 1" +# +# assert_equal(len(tM), len(t1), "Results of compound request does not have same length") +# +# for t1_item, tM_item in zip(sorted(t1, key=lambda item: item['trial_id']), +# sorted(tM, key=lambda item: item['trial_id'])): +# assert_dict_equal(t1_item, tM_item, +# 'Dictionary elements do not agree in compound statement') +# +# def test_record_insert(self): +# "Test whether record insert works" +# tmp = np.array([(2, 'Klara', 'monkey')], +# dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) +# +# self.subjects.insert(tmp[0]) +# testt2 = (self.subjects & 'subject_id = 2').fetch()[0] +# assert_equal(tuple(tmp[0]), tuple(testt2), "Inserted and fetched record do not match!") +# +# def test_delete(self): +# "Test whether delete works" +# tmp = np.array([(2, 'Klara', 'monkey'), (1,'Peter', 'mouse')], +# dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) +# +# self.subjects.batch_insert(tmp) +# assert_true(len(self.subjects) == 2, 'Length does not match 2.') +# self.subjects.delete() +# assert_true(len(self.subjects) == 0, 'Length does not match 0.') +# +# # def test_cascading_delete(self): +# # "Test whether delete works" +# # tmp = np.array([(2, 'Klara', 'monkey'), (1,'Peter', 'mouse')], +# # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) +# # +# # self.subjects.batch_insert(tmp) +# # +# # self.trials.insert(dict(subject_id=1, trial_id=1, outcome=0)) +# # self.trials.insert(dict(subject_id=1, trial_id=2, outcome=1)) +# # self.trials.insert(dict(subject_id=2, trial_id=3, outcome=2)) +# # assert_true(len(self.subjects) == 2, 'Length does not match 2.') +# # assert_true(len(self.trials) == 3, 'Length does not match 3.') +# # (self.subjects & 'subject_id=1').delete() +# # assert_true(len(self.subjects) == 1, 'Length does not match 1.') +# # assert_true(len(self.trials) == 1, 'Length does not match 1.') +# +# def test_short_hand_foreign_reference(self): +# self.animals.heading +# +# +# +# def test_record_insert_different_order(self): +# "Test whether record insert works" +# tmp = np.array([('Klara', 2, 'monkey')], +# dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) +# +# self.subjects.insert(tmp[0]) +# testt2 = (self.subjects & 'subject_id = 2').fetch()[0] +# assert_equal((2, 'Klara', 'monkey'), tuple(testt2), +# "Inserted and fetched record do not match!") +# +# @raises(TransactionError) +# def test_transaction_error(self): +# "Test whether declaration in transaction is prohibited" +# +# tmp = np.array([('Klara', 2, 'monkey')], +# dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) +# self.conn.start_transaction() +# self.subjects.insert(tmp[0]) +# +# # def test_transaction_suppress_error(self): +# # "Test whether ignore_errors ignores the errors." +# # +# # tmp = np.array([('Klara', 2, 'monkey')], +# # dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) +# # with self.conn.transaction(ignore_errors=True) as tr: +# # self.subjects.insert(tmp[0]) +# +# +# @raises(TransactionError) +# def test_transaction_error_not_resolve(self): +# "Test whether declaration in transaction is prohibited" +# +# tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], +# dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) +# try: +# self.conn.start_transaction() +# self.subjects.insert(tmp[0]) +# except TransactionError as te: +# self.conn.cancel_transaction() +# +# self.conn.start_transaction() +# self.subjects.insert(tmp[0]) +# +# def test_transaction_error_resolve(self): +# "Test whether declaration in transaction is prohibited" +# +# tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], +# dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) +# try: +# self.conn.start_transaction() +# self.subjects.insert(tmp[0]) +# except TransactionError as te: +# self.conn.cancel_transaction() +# te.resolve() +# +# self.conn.start_transaction() +# self.subjects.insert(tmp[0]) +# self.conn.commit_transaction() +# +# def test_transaction_error2(self): +# "If table is declared, we are allowed to insert within a transaction" +# +# tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], +# dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) +# self.subjects.insert(tmp[0]) +# +# self.conn.start_transaction() +# self.subjects.insert(tmp[1]) +# self.conn.commit_transaction() +# +# +# @raises(KeyError) +# def test_wrong_key_insert_records(self): +# "Test whether record insert works" +# tmp = np.array([('Klara', 2, 'monkey')], +# dtype=[('real_deal', 'O'), ('subject_id', '>i4'), ('species', 'O')]) +# +# self.subjects.insert(tmp[0]) +# +# +# def test_dict_insert(self): +# "Test whether record insert works" +# tmp = {'real_id': 'Brunhilda', +# 'subject_id': 3, +# 'species': 'human'} +# +# self.subjects.insert(tmp) +# testt2 = (self.subjects & 'subject_id = 3').fetch()[0] +# assert_equal((3, 'Brunhilda', 'human'), tuple(testt2), "Inserted and fetched record do not match!") +# +# @raises(KeyError) +# def test_wrong_key_insert(self): +# "Test whether a correct error is generated when inserting wrong attribute name" +# tmp = {'real_deal': 'Brunhilda', +# 'subject_database': 3, +# 'species': 'human'} +# +# self.subjects.insert(tmp) +# +# def test_batch_insert(self): +# "Test whether record insert works" +# tmp = np.array([('Klara', 2, 'monkey'), ('Brunhilda', 3, 'mouse'), ('Mickey', 1, 'human')], +# dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) +# +# self.subjects.batch_insert(tmp) +# +# expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'), +# (3, 'Brunhilda', 'mouse')], +# dtype=[('subject_id', 'i4'), ('species', 'O')]) +# +# self.subjects.iter_insert(tmp.__iter__()) +# +# expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'), +# (3, 'Brunhilda', 'mouse')], +# dtype=[('subject_id', ' `dj_free`.Animals +# rec_session_id : int # recording session identifier +# """ +# table = FreeRelation(self.conn, 'dj_free', 'Recordings', definition) +# assert_raises(DataJointError, table.declare) +# +# def test_reference_to_existing_table(self): +# definition1 = """ +# `dj_free`.Animals (manual) # my animal table +# animal_id : int # unique id for the animal +# --- +# animal_name : varchar(128) # name of the animal +# """ +# table1 = FreeRelation(self.conn, 'dj_free', 'Animals', definition1) +# table1.declare() +# +# definition2 = """ +# `dj_free`.Recordings (manual) # recordings +# -> `dj_free`.Animals +# rec_session_id : int # recording session identifier +# """ +# table2 = FreeRelation(self.conn, 'dj_free', 'Recordings', definition2) +# table2.declare() +# assert_true('animal_id' in table2.primary_key) +# +# +# def id_generator(size=6, chars=string.ascii_uppercase + string.digits): +# return ''.join(random.choice(chars) for _ in range(size)) +# +# class TestIterator(object): +# def __init__(self): +# self.relvar = None +# self.setup() +# +# """ +# Test cases for Iterators in Relations objects +# """ +# +# def setup(self): +# """ +# Create a connection object and prepare test modules +# as follows: +# test1 - has conn and bounded +# """ +# cleanup() # drop all databases with PREFIX +# test4.__dict__.pop('conn', None) # make sure conn is not defined at schema level +# +# self.conn = Connection(**CONN_INFO) +# test4.conn = self.conn +# self.conn.bind(test4.__name__, PREFIX + '_test4') +# self.relvar_blob = test4.Matrix() +# +# def teardown(self): +# cleanup() +# +# +# def test_blob_iteration(self): +# "Tests the basic call of the iterator" +# +# dicts = [] +# for i in range(10): +# +# c = id_generator() +# +# t = {'matrix_id':i, +# 'data': np.random.randn(4,4,4), +# 'comment': c} +# self.relvar_blob.insert(t) +# dicts.append(t) +# +# for t, t2 in zip(dicts, self.relvar_blob): +# assert_true(isinstance(t2, dict), 'iterator does not return dict') +# +# assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved tuples do not match') +# assert_equal(t['comment'], t2['comment'], 'inserted and retrieved tuples do not match') +# assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved tuples do not match') +# +# def test_fetch(self): +# dicts = [] +# for i in range(10): +# +# c = id_generator() +# +# t = {'matrix_id':i, +# 'data': np.random.randn(4,4,4), +# 'comment': c} +# self.relvar_blob.insert(t) +# dicts.append(t) +# +# tuples2 = self.relvar_blob.fetch() +# assert_true(isinstance(tuples2, np.ndarray), "Return value of fetch does not have proper type.") +# assert_true(isinstance(tuples2[0], np.void), "Return value of fetch does not have proper type.") +# for t, t2 in zip(dicts, tuples2): +# +# assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved tuples do not match') +# assert_equal(t['comment'], t2['comment'], 'inserted and retrieved tuples do not match') +# assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved tuples do not match') +# +# def test_fetch_dicts(self): +# dicts = [] +# for i in range(10): +# +# c = id_generator() +# +# t = {'matrix_id':i, +# 'data': np.random.randn(4,4,4), +# 'comment': c} +# self.relvar_blob.insert(t) +# dicts.append(t) +# +# tuples2 = self.relvar_blob.fetch(as_dict=True) +# assert_true(isinstance(tuples2, list), "Return value of fetch with as_dict=True does not have proper type.") +# assert_true(isinstance(tuples2[0], dict), "Return value of fetch with as_dict=True does not have proper type.") +# for t, t2 in zip(dicts, tuples2): +# assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved dicts do not match') +# assert_equal(t['comment'], t2['comment'], 'inserted and retrieved dicts do not match') +# assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved dicts do not match') +# +# +# +# class TestAutopopulate(object): +# def __init__(self): +# self.relvar = None +# self.setup() +# +# """ +# Test cases for Iterators in Relations objects +# """ +# +# def setup(self): +# """ +# Create a connection object and prepare test modules +# as follows: +# test1 - has conn and bounded +# """ +# cleanup() # drop all databases with PREFIX +# test1.__dict__.pop('conn', None) # make sure conn is not defined at schema level +# +# self.conn = Connection(**CONN_INFO) +# test1.conn = self.conn +# self.conn.bind(test1.__name__, PREFIX + '_test1') +# +# self.subjects = test1.Subjects() +# self.trials = test1.Trials() +# self.squared = test1.SquaredScore() +# self.dummy = test1.SquaredSubtable() +# self.dummy1 = test1.WrongImplementation() +# self.error_generator = test1.ErrorGenerator() +# self.fill_relation() +# +# +# +# def fill_relation(self): +# tmp = np.array([('Klara', 2, 'monkey'), ('Peter', 3, 'mouse')], +# dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) +# self.subjects.batch_insert(tmp) +# +# for trial_id in range(1,11): +# self.trials.insert(dict(subject_id=2, trial_id=trial_id, outcome=np.random.randint(0,10))) +# +# def teardown(self): +# cleanup() +# +# def test_autopopulate(self): +# self.squared.populate() +# assert_equal(len(self.squared), 10) +# +# for trial in self.trials*self.squared: +# assert_equal(trial['outcome']**2, trial['squared']) +# +# def test_autopopulate_restriction(self): +# self.squared.populate(restriction='trial_id <= 5') +# assert_equal(len(self.squared), 5) +# +# for trial in self.trials*self.squared: +# assert_equal(trial['outcome']**2, trial['squared']) +# +# +# # def test_autopopulate_transaction_error(self): +# # errors = self.squared.populate(suppress_errors=True) +# # assert_equal(len(errors), 1) +# # assert_true(isinstance(errors[0][1], TransactionError)) +# +# @raises(DataJointError) +# def test_autopopulate_relation_check(self): +# +# class dummy(AutoPopulate): +# +# def populate_relation(self): +# return None +# +# def _make_tuples(self, key): +# pass +# +# du = dummy() +# du.populate() \ +# +# @raises(DataJointError) +# def test_autopopulate_relation_check(self): +# self.dummy1.populate() +# +# @raises(Exception) +# def test_autopopulate_relation_check(self): +# self.error_generator.populate()\ +# +# @raises(Exception) +# def test_autopopulate_relation_check2(self): +# tmp = self.dummy2.populate(suppress_errors=True) +# assert_equal(len(tmp), 1, 'Error list should have length 1.') diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index 3a25a87d0..3ceeb4964 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -1,47 +1,47 @@ -""" -Collection of test cases to test relational methods -""" - -__author__ = 'eywalker' - - -def setup(): - """ - Setup - :return: - """ - -class TestRelationalAlgebra(object): - - def setup(self): - pass - - def test_mul(self): - pass - - def test_project(self): - pass - - def test_iand(self): - pass - - def test_isub(self): - pass - - def test_sub(self): - pass - - def test_len(self): - pass - - def test_fetch(self): - pass - - def test_repr(self): - pass - - def test_iter(self): - pass - - def test_not(self): - pass \ No newline at end of file +# """ +# Collection of test cases to test relational methods +# """ +# +# __author__ = 'eywalker' +# +# +# def setup(): +# """ +# Setup +# :return: +# """ +# +# class TestRelationalAlgebra(object): +# +# def setup(self): +# pass +# +# def test_mul(self): +# pass +# +# def test_project(self): +# pass +# +# def test_iand(self): +# pass +# +# def test_isub(self): +# pass +# +# def test_sub(self): +# pass +# +# def test_len(self): +# pass +# +# def test_fetch(self): +# pass +# +# def test_repr(self): +# pass +# +# def test_iter(self): +# pass +# +# def test_not(self): +# pass \ No newline at end of file diff --git a/tests/test_settings.py b/tests/test_settings.py index 6b8100806..24b05d5d4 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,67 +1,67 @@ -import os -import pprint -import random -import string -from datajoint import settings - -__author__ = 'Fabian Sinz' - -from nose.tools import assert_true, assert_raises, assert_equal, raises, assert_dict_equal -import datajoint as dj - - -def test_load_save(): - dj.config.save('tmp.json') - conf = dj.Config() - conf.load('tmp.json') - assert_true(conf == dj.config, 'Two config files do not match.') - os.remove('tmp.json') - -def test_singleton(): - dj.config.save('tmp.json') - conf = dj.Config() - conf.load('tmp.json') - conf['dummy.val'] = 2 - - assert_true(conf == dj.config, 'Config does not behave like a singleton.') - os.remove('tmp.json') - - -@raises(ValueError) -def test_nested_check(): - dummy = {'dummy.testval': {'notallowed': 2}} - dj.config.update(dummy) - -@raises(dj.DataJointError) -def test_validator(): - dj.config['database.port'] = 'harbor' - -def test_del(): - dj.config['peter'] = 2 - assert_true('peter' in dj.config) - del dj.config['peter'] - assert_true('peter' not in dj.config) - -def test_len(): - assert_equal(len(dj.config), len(dj.config._conf)) - -def test_str(): - assert_equal(str(dj.config), pprint.pformat(dj.config._conf, indent=4)) - -def test_repr(): - assert_equal(repr(dj.config), pprint.pformat(dj.config._conf, indent=4)) - -@raises(ValueError) -def test_nested_check2(): - dj.config['dummy'] = {'dummy2':2} - -def test_save(): - tmpfile = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(20)) - moved = False - if os.path.isfile(settings.LOCALCONFIG): - os.rename(settings.LOCALCONFIG, tmpfile) - moved = True - dj.config.save() - assert_true(os.path.isfile(settings.LOCALCONFIG)) - if moved: - os.rename(tmpfile, settings.LOCALCONFIG) +# import os +# import pprint +# import random +# import string +# from datajoint import settings +# +# __author__ = 'Fabian Sinz' +# +# from nose.tools import assert_true, assert_raises, assert_equal, raises, assert_dict_equal +# import datajoint as dj +# +# +# def test_load_save(): +# dj.config.save('tmp.json') +# conf = dj.Config() +# conf.load('tmp.json') +# assert_true(conf == dj.config, 'Two config files do not match.') +# os.remove('tmp.json') +# +# def test_singleton(): +# dj.config.save('tmp.json') +# conf = dj.Config() +# conf.load('tmp.json') +# conf['dummy.val'] = 2 +# +# assert_true(conf == dj.config, 'Config does not behave like a singleton.') +# os.remove('tmp.json') +# +# +# @raises(ValueError) +# def test_nested_check(): +# dummy = {'dummy.testval': {'notallowed': 2}} +# dj.config.update(dummy) +# +# @raises(dj.DataJointError) +# def test_validator(): +# dj.config['database.port'] = 'harbor' +# +# def test_del(): +# dj.config['peter'] = 2 +# assert_true('peter' in dj.config) +# del dj.config['peter'] +# assert_true('peter' not in dj.config) +# +# def test_len(): +# assert_equal(len(dj.config), len(dj.config._conf)) +# +# def test_str(): +# assert_equal(str(dj.config), pprint.pformat(dj.config._conf, indent=4)) +# +# def test_repr(): +# assert_equal(repr(dj.config), pprint.pformat(dj.config._conf, indent=4)) +# +# @raises(ValueError) +# def test_nested_check2(): +# dj.config['dummy'] = {'dummy2':2} +# +# def test_save(): +# tmpfile = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(20)) +# moved = False +# if os.path.isfile(settings.LOCALCONFIG): +# os.rename(settings.LOCALCONFIG, tmpfile) +# moved = True +# dj.config.save() +# assert_true(os.path.isfile(settings.LOCALCONFIG)) +# if moved: +# os.rename(tmpfile, settings.LOCALCONFIG) diff --git a/tests/test_utils.py b/tests/test_utils.py index 655884ce0..2322cf18a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,33 +1,33 @@ -""" -Collection of test cases to test core module. -""" - -__author__ = 'eywalker' -from nose.tools import assert_true, assert_raises, assert_equal -from datajoint.utils import to_camel_case, from_camel_case -from datajoint import DataJointError - - -def setup(): - pass - - -def teardown(): - pass - - -def test_to_camel_case(): - assert_equal(to_camel_case('basic_sessions'), 'BasicSessions') - assert_equal(to_camel_case('_another_table'), 'AnotherTable') - - -def test_from_camel_case(): - assert_equal(from_camel_case('AllGroups'), 'all_groups') - with assert_raises(DataJointError): - from_camel_case('repNames') - with assert_raises(DataJointError): - from_camel_case('10_all') - with assert_raises(DataJointError): - from_camel_case('hello world') - with assert_raises(DataJointError): - from_camel_case('#baisc_names') +# """ +# Collection of test cases to test core module. +# """ +# +# __author__ = 'eywalker' +# from nose.tools import assert_true, assert_raises, assert_equal +# from datajoint.utils import to_camel_case, from_camel_case +# from datajoint import DataJointError +# +# +# def setup(): +# pass +# +# +# def teardown(): +# pass +# +# +# def test_to_camel_case(): +# assert_equal(to_camel_case('basic_sessions'), 'BasicSessions') +# assert_equal(to_camel_case('_another_table'), 'AnotherTable') +# +# +# def test_from_camel_case(): +# assert_equal(from_camel_case('AllGroups'), 'all_groups') +# with assert_raises(DataJointError): +# from_camel_case('repNames') +# with assert_raises(DataJointError): +# from_camel_case('10_all') +# with assert_raises(DataJointError): +# from_camel_case('hello world') +# with assert_raises(DataJointError): +# from_camel_case('#baisc_names') From 0c52d6cab08946c8ab4e2233b6d471e2e2ec95fe Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Mon, 8 Jun 2015 18:35:46 -0500 Subject: [PATCH 05/20] first test runs without failing --- tests/test_utils.py | 66 ++++++++++++++++++++++----------------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 2322cf18a..655884ce0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,33 +1,33 @@ -# """ -# Collection of test cases to test core module. -# """ -# -# __author__ = 'eywalker' -# from nose.tools import assert_true, assert_raises, assert_equal -# from datajoint.utils import to_camel_case, from_camel_case -# from datajoint import DataJointError -# -# -# def setup(): -# pass -# -# -# def teardown(): -# pass -# -# -# def test_to_camel_case(): -# assert_equal(to_camel_case('basic_sessions'), 'BasicSessions') -# assert_equal(to_camel_case('_another_table'), 'AnotherTable') -# -# -# def test_from_camel_case(): -# assert_equal(from_camel_case('AllGroups'), 'all_groups') -# with assert_raises(DataJointError): -# from_camel_case('repNames') -# with assert_raises(DataJointError): -# from_camel_case('10_all') -# with assert_raises(DataJointError): -# from_camel_case('hello world') -# with assert_raises(DataJointError): -# from_camel_case('#baisc_names') +""" +Collection of test cases to test core module. +""" + +__author__ = 'eywalker' +from nose.tools import assert_true, assert_raises, assert_equal +from datajoint.utils import to_camel_case, from_camel_case +from datajoint import DataJointError + + +def setup(): + pass + + +def teardown(): + pass + + +def test_to_camel_case(): + assert_equal(to_camel_case('basic_sessions'), 'BasicSessions') + assert_equal(to_camel_case('_another_table'), 'AnotherTable') + + +def test_from_camel_case(): + assert_equal(from_camel_case('AllGroups'), 'all_groups') + with assert_raises(DataJointError): + from_camel_case('repNames') + with assert_raises(DataJointError): + from_camel_case('10_all') + with assert_raises(DataJointError): + from_camel_case('hello world') + with assert_raises(DataJointError): + from_camel_case('#baisc_names') From eda035bd74c12596562d5184153654a0ce2b4543 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Mon, 8 Jun 2015 18:36:07 -0500 Subject: [PATCH 06/20] 13 tests pass --- tests/test_settings.py | 134 ++++++++++++++++++++--------------------- 1 file changed, 67 insertions(+), 67 deletions(-) diff --git a/tests/test_settings.py b/tests/test_settings.py index 24b05d5d4..6b8100806 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,67 +1,67 @@ -# import os -# import pprint -# import random -# import string -# from datajoint import settings -# -# __author__ = 'Fabian Sinz' -# -# from nose.tools import assert_true, assert_raises, assert_equal, raises, assert_dict_equal -# import datajoint as dj -# -# -# def test_load_save(): -# dj.config.save('tmp.json') -# conf = dj.Config() -# conf.load('tmp.json') -# assert_true(conf == dj.config, 'Two config files do not match.') -# os.remove('tmp.json') -# -# def test_singleton(): -# dj.config.save('tmp.json') -# conf = dj.Config() -# conf.load('tmp.json') -# conf['dummy.val'] = 2 -# -# assert_true(conf == dj.config, 'Config does not behave like a singleton.') -# os.remove('tmp.json') -# -# -# @raises(ValueError) -# def test_nested_check(): -# dummy = {'dummy.testval': {'notallowed': 2}} -# dj.config.update(dummy) -# -# @raises(dj.DataJointError) -# def test_validator(): -# dj.config['database.port'] = 'harbor' -# -# def test_del(): -# dj.config['peter'] = 2 -# assert_true('peter' in dj.config) -# del dj.config['peter'] -# assert_true('peter' not in dj.config) -# -# def test_len(): -# assert_equal(len(dj.config), len(dj.config._conf)) -# -# def test_str(): -# assert_equal(str(dj.config), pprint.pformat(dj.config._conf, indent=4)) -# -# def test_repr(): -# assert_equal(repr(dj.config), pprint.pformat(dj.config._conf, indent=4)) -# -# @raises(ValueError) -# def test_nested_check2(): -# dj.config['dummy'] = {'dummy2':2} -# -# def test_save(): -# tmpfile = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(20)) -# moved = False -# if os.path.isfile(settings.LOCALCONFIG): -# os.rename(settings.LOCALCONFIG, tmpfile) -# moved = True -# dj.config.save() -# assert_true(os.path.isfile(settings.LOCALCONFIG)) -# if moved: -# os.rename(tmpfile, settings.LOCALCONFIG) +import os +import pprint +import random +import string +from datajoint import settings + +__author__ = 'Fabian Sinz' + +from nose.tools import assert_true, assert_raises, assert_equal, raises, assert_dict_equal +import datajoint as dj + + +def test_load_save(): + dj.config.save('tmp.json') + conf = dj.Config() + conf.load('tmp.json') + assert_true(conf == dj.config, 'Two config files do not match.') + os.remove('tmp.json') + +def test_singleton(): + dj.config.save('tmp.json') + conf = dj.Config() + conf.load('tmp.json') + conf['dummy.val'] = 2 + + assert_true(conf == dj.config, 'Config does not behave like a singleton.') + os.remove('tmp.json') + + +@raises(ValueError) +def test_nested_check(): + dummy = {'dummy.testval': {'notallowed': 2}} + dj.config.update(dummy) + +@raises(dj.DataJointError) +def test_validator(): + dj.config['database.port'] = 'harbor' + +def test_del(): + dj.config['peter'] = 2 + assert_true('peter' in dj.config) + del dj.config['peter'] + assert_true('peter' not in dj.config) + +def test_len(): + assert_equal(len(dj.config), len(dj.config._conf)) + +def test_str(): + assert_equal(str(dj.config), pprint.pformat(dj.config._conf, indent=4)) + +def test_repr(): + assert_equal(repr(dj.config), pprint.pformat(dj.config._conf, indent=4)) + +@raises(ValueError) +def test_nested_check2(): + dj.config['dummy'] = {'dummy2':2} + +def test_save(): + tmpfile = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(20)) + moved = False + if os.path.isfile(settings.LOCALCONFIG): + os.rename(settings.LOCALCONFIG, tmpfile) + moved = True + dj.config.save() + assert_true(os.path.isfile(settings.LOCALCONFIG)) + if moved: + os.rename(tmpfile, settings.LOCALCONFIG) From 52d4100eca6ca2e15592bb1ee8e6023a8fb7cdb9 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Mon, 8 Jun 2015 18:58:37 -0500 Subject: [PATCH 07/20] instantiating Subject fails --- datajoint/__init__.py | 3 ++- datajoint/relation.py | 32 +++++++++++----------- datajoint/user_relations.py | 14 ++++------ tests/schemata/__init__.py | 2 +- tests/schemata/{schema1 => }/test1.py | 39 ++++++++++++++------------- tests/test_relation.py | 6 +++++ 6 files changed, 51 insertions(+), 45 deletions(-) rename tests/schemata/{schema1 => }/test1.py (89%) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 8ebe93992..796ded488 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -42,4 +42,5 @@ class DataJointError(Exception): from .autopopulate import AutoPopulate from . import blob from .relational_operand import Not -from .heading import Heading \ No newline at end of file +from .heading import Heading +from .relation import schema \ No newline at end of file diff --git a/datajoint/relation.py b/datajoint/relation.py index 19c2ebac9..3c48690d9 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -18,7 +18,7 @@ SharedInfo = namedtuple( 'SharedInfo', - ('database', 'context', 'connection', 'heading', 'parents', 'children', 'references', 'referenced')) + ('database', 'context', 'connection', 'heading')) @@ -101,13 +101,13 @@ def table_name(cls): """ pass - @classproperty - @abc.abstractmethod - def database(cls): - """ - :return: string containing the database name on the server - """ - pass + # @classproperty + # @abc.abstractmethod + # def database(cls): + # """ + # :return: string containing the database name on the server + # """ + # pass @classproperty @abc.abstractmethod @@ -117,13 +117,13 @@ def definition(cls): """ pass - @classproperty - @abc.abstractmethod - def context(cls): - """ - :return: a dict with other relations that can be referenced by foreign keys - """ - pass + # @classproperty + # @abc.abstractmethod + # def context(cls): + # """ + # :return: a dict with other relations that can be referenced by foreign keys + # """ + # pass # --------- base relation functionality --------- # @classproperty @@ -132,7 +132,7 @@ def is_declared(cls): return True cur = cls._shared_info.connection.query( 'SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format( - table_name=cls.table_name)) + database=cls.database , table_name=cls.table_name)) return cur.rowcount == 1 @classproperty diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 60e013e6f..62c9951bc 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -1,31 +1,27 @@ -from datajoint.relation import Relation +from datajoint.relation import Relation, classproperty from .autopopulate import AutoPopulate from .utils import from_camel_case class Manual(Relation): - @property - @classmethod + @classproperty def table_name(cls): return from_camel_case(cls.__name__) class Lookup(Relation): - @property - @classmethod + @classproperty def table_name(cls): return '#' + from_camel_case(cls.__name__) class Imported(Relation, AutoPopulate): - @property - @classmethod + @classproperty def table_name(cls): return "_" + from_camel_case(cls.__name__) class Computed(Relation, AutoPopulate): - @property - @classmethod + @classproperty def table_name(cls): return "__" + from_camel_case(cls.__name__) \ No newline at end of file diff --git a/tests/schemata/__init__.py b/tests/schemata/__init__.py index 6f391d065..9fa8b9ad1 100644 --- a/tests/schemata/__init__.py +++ b/tests/schemata/__init__.py @@ -1 +1 @@ -__author__ = "eywalker" \ No newline at end of file +__author__ = "eywalker, fabiansinz" \ No newline at end of file diff --git a/tests/schemata/schema1/test1.py b/tests/schemata/test1.py similarity index 89% rename from tests/schemata/schema1/test1.py rename to tests/schemata/test1.py index d0d91707f..5b4ac723a 100644 --- a/tests/schemata/schema1/test1.py +++ b/tests/schemata/test1.py @@ -1,22 +1,25 @@ -# """ -# Test 1 Schema definition -# """ -# __author__ = 'eywalker' -# -# import datajoint as dj +""" +Test 1 Schema definition +""" +__author__ = 'eywalker' + +import datajoint as dj # from .. import schema2 -# -# -# class Subjects(dj.Relation): -# definition = """ -# test1.Subjects (manual) # Basic subject info -# -# subject_id : int # unique subject id -# --- -# real_id : varchar(40) # real-world name -# species = "mouse" : enum('mouse', 'monkey', 'human') # species -# """ -# +from .. import PREFIX + +testschema = dj.schema(PREFIX + '_test1', locals()) + +@testschema +class Subjects(dj.Manual): + definition = """ + # Basic subject info + + subject_id : int # unique subject id + --- + real_id : varchar(40) # real-world name + species = "mouse" : enum('mouse', 'monkey', 'human') # species + """ + # # test for shorthand # class Animals(dj.Relation): # definition = """ diff --git a/tests/test_relation.py b/tests/test_relation.py index 94adef010..76f87f707 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -4,6 +4,12 @@ # __author__ = 'fabee' # # from .schemata.schema1 import test1, test4 +from .schemata.test1 import Subjects + + +def test_instantiate_relation(): + s = Subjects() + # # from . import BASE_CONN, CONN_INFO, PREFIX, cleanup # from datajoint.connection import Connection From 2f0971ad4c892d722441528597df98b18b7fe286 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 09:49:24 -0500 Subject: [PATCH 08/20] implemented lookup_name using eval --- datajoint/autopopulate.py | 10 +++++----- datajoint/relation.py | 18 ++---------------- demos/demo1.py | 5 ----- 3 files changed, 7 insertions(+), 26 deletions(-) diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index ba349b250..5db012eff 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -53,29 +53,29 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, if not isinstance(self.populate_relation, RelationalOperand): raise DataJointError('Invalid populate_relation value') - self.conn.cancel_transaction() # rollback previous transaction, if any + self.connection.cancel_transaction() # rollback previous transaction, if any if not isinstance(self, Relation): raise DataJointError('Autopopulate is a mixin for Relation and must therefore subclass Relation') unpopulated = (self.populate_relation - self.target) & restriction for key in unpopulated.project(): - self.conn.start_transaction() + self.connection.start_transaction() if key in self.target: # already populated - self.conn.cancel_transaction() + self.connection.cancel_transaction() else: logger.info('Populating: ' + str(key)) try: self._make_tuples(dict(key)) except Exception as error: - self.conn.cancel_transaction() + self.connection.cancel_transaction() if not suppress_errors: raise else: logger.error(error) error_list.append((key, error)) else: - self.conn.commit_transaction() + self.connection.commit_transaction() logger.info('Done populating.') return error_list diff --git a/datajoint/relation.py b/datajoint/relation.py index 3c48690d9..4588040a1 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -21,7 +21,6 @@ ('database', 'context', 'connection', 'heading')) - class classproperty: def __init__(self, getf): self._getf = getf @@ -88,7 +87,7 @@ class Relation(RelationalOperand, metaclass=abc.ABCMeta): _shared_info = None - def __init__(self): # TODO: Think about it + def __init__(self): if self._shared_info is None: raise DataJointError('The class must define _shared_info') @@ -478,21 +477,8 @@ def _parse_declaration(cls): def lookup_name(cls, name): """ Lookup the referenced name in the context dictionary - - e.g. for reference `common.Animals`, it will first check if `context` dictionary contains key - `common`. If found, it then checks for attribute `Animals` in `common`, and returns the result. """ - parts = name.strip().split('.') - try: - ref = cls.context.get(parts[0]) - for attr in parts[1:]: - ref = getattr(ref, attr) - except (KeyError, AttributeError): - raise DataJointError( - 'Foreign key reference to %s could not be resolved.' - 'Please make sure the name exists' - 'in the context of the class' % name) - return ref + return eval(name, locals=cls.context) @classproperty def connection(cls): diff --git a/demos/demo1.py b/demos/demo1.py index e85d6ead3..4376a4982 100644 --- a/demos/demo1.py +++ b/demos/demo1.py @@ -1,9 +1,4 @@ # -*- coding: utf-8 -*- -""" -Created on Tue Aug 26 17:42:52 2014 - -@author: dimitri -""" import datajoint as dj From 6265e6877e87415abbabf146e63cc0436ef54811 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 09:55:22 -0500 Subject: [PATCH 09/20] moved from_camel_case to user_relations.py --- datajoint/user_relations.py | 24 ++++++++++++++++++++++-- datajoint/utils.py | 34 ---------------------------------- 2 files changed, 22 insertions(+), 36 deletions(-) diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 62c9951bc..8348b6ccc 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -1,7 +1,8 @@ +import re from datajoint.relation import Relation, classproperty from .autopopulate import AutoPopulate from .utils import from_camel_case - +from . import DataJointError class Manual(Relation): @classproperty @@ -24,4 +25,23 @@ def table_name(cls): class Computed(Relation, AutoPopulate): @classproperty def table_name(cls): - return "__" + from_camel_case(cls.__name__) \ No newline at end of file + return "__" + from_camel_case(cls.__name__) + + + +def from_camel_case(s): + """ + Convert names in camel case into underscore (_) separated names + + Example: + >>>from_camel_case("TableName") + "table_name" + """ + def convert(match): + return ('_' if match.groups()[0] else '') + match.group(0).lower() + + if not re.match(r'[A-Z][a-zA-Z0-9]*', s): + raise DataJointError( + 'ClassName must be alphanumeric in CamelCase, begin with a capital letter') + return re.sub(r'(\B[A-Z])|(\b[A-Z])', convert, s) + diff --git a/datajoint/utils.py b/datajoint/utils.py index ec506472e..f4b0edb57 100644 --- a/datajoint/utils.py +++ b/datajoint/utils.py @@ -1,37 +1,3 @@ -import re -from . import DataJointError - - -def to_camel_case(s): - """ - Convert names with under score (_) separation - into camel case names. - - Example: - >>>to_camel_case("table_name") - "TableName" - """ - def to_upper(match): - return match.group(0)[-1].upper() - return re.sub('(^|[_\W])+[a-zA-Z]', to_upper, s) - - -def from_camel_case(s): - """ - Convert names in camel case into underscore (_) separated names - - Example: - >>>from_camel_case("TableName") - "table_name" - """ - def convert(match): - return ('_' if match.groups()[0] else '') + match.group(0).lower() - - if not re.match(r'[A-Z][a-zA-Z0-9]*', s): - raise DataJointError( - 'ClassName must be alphanumeric in CamelCase, begin with a capital letter') - return re.sub(r'(\B[A-Z])|(\b[A-Z])', convert, s) - def user_choice(prompt, choices=("yes", "no"), default=None): """ From a48e66fe5650ee2f15a5812fad6c8eca141e6671 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 10:22:16 -0500 Subject: [PATCH 10/20] minor cleanup --- datajoint/relation.py | 86 +++++++++++++------------------------ datajoint/user_relations.py | 2 - 2 files changed, 31 insertions(+), 57 deletions(-) diff --git a/datajoint/relation.py b/datajoint/relation.py index 4588040a1..8c2ebdc2b 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -16,8 +16,8 @@ logger = logging.getLogger(__name__) -SharedInfo = namedtuple( - 'SharedInfo', +TableInfo = namedtuple( + 'TableInfo', ('database', 'context', 'connection', 'heading')) @@ -60,11 +60,11 @@ def schema(database, context, connection=None): " permissions.".format(database=database)) def decorator(cls): - cls._shared_info = SharedInfo( + cls._table_info = TableInfo( database=database, context=context, connection=connection, - heading=None, + heading=None ) cls.declare() return cls @@ -85,11 +85,28 @@ class Relation(RelationalOperand, metaclass=abc.ABCMeta): __heading = None - _shared_info = None + _table_info = None def __init__(self): - if self._shared_info is None: - raise DataJointError('The class must define _shared_info') + if self._table_info is None: + raise DataJointError('The class must define _table_info') + + @classproperty + def connection(cls): + """ + Returns the connection object of the class + + :return: the connection object + """ + return cls._table_info.connection + + @classproperty + def database(cls): + return cls._table_info.database + + @classproperty + def context(cls): + return cls._table_info.context # ---------- abstract properties ------------ # @classproperty @@ -100,14 +117,6 @@ def table_name(cls): """ pass - # @classproperty - # @abc.abstractmethod - # def database(cls): - # """ - # :return: string containing the database name on the server - # """ - # pass - @classproperty @abc.abstractmethod def definition(cls): @@ -116,20 +125,12 @@ def definition(cls): """ pass - # @classproperty - # @abc.abstractmethod - # def context(cls): - # """ - # :return: a dict with other relations that can be referenced by foreign keys - # """ - # pass - - # --------- base relation functionality --------- # + # --------- SQL functionality --------- # @classproperty def is_declared(cls): if cls.__heading is not None: return True - cur = cls._shared_info.connection.query( + cur = cls._table_info.connection.query( 'SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format( database=cls.database , table_name=cls.table_name)) return cur.rowcount == 1 @@ -151,7 +152,7 @@ def from_clause(cls): for the SQL SELECT statements. :return: """ - return '`%s`.`%s`' % (cls.database, cls.table_name) + return cls.full_table_name @classmethod def declare(cls): @@ -413,13 +414,13 @@ def _declare(cls): implicit_indices.append(fk_source.primary_key) # for index in indexDefs: - # TODO: finish this up... + # TODO: add index declaration # close the declaration sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % ( sql[:-2], table_info['comment']) - # # make sure that the table does not alredy exist + # # make sure that the table does not already exist # cls.load_heading() # if not cls.is_declared: # # execute declaration @@ -462,37 +463,12 @@ def _parse_declaration(cls): # foreign key ref_name = line[2:].strip() ref_list = parents if in_key else referenced - ref_list.append(cls.lookup_name(ref_name)) + ref_list.append(eval(ref_name, locals=cls.context)) elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): index_defs.append(parse_index_definition(line)) elif attribute_regexp.match(line): field_defs.append(parse_attribute_definition(line, in_key)) else: - raise DataJointError( - 'Invalid table declaration line "%s"' % line) + raise DataJointError('Invalid table declaration line "%s"' % line) return table_info, parents, referenced, field_defs, index_defs - - @classmethod - def lookup_name(cls, name): - """ - Lookup the referenced name in the context dictionary - """ - return eval(name, locals=cls.context) - - @classproperty - def connection(cls): - """ - Returns the connection object of the class - - :return: the connection object - """ - return cls._shared_info.connection - - @classproperty - def database(cls): - return cls._shared_info.database - - @classproperty - def context(cls): - return cls._shared_info.context diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 8348b6ccc..a615247ee 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -1,7 +1,6 @@ import re from datajoint.relation import Relation, classproperty from .autopopulate import AutoPopulate -from .utils import from_camel_case from . import DataJointError class Manual(Relation): @@ -28,7 +27,6 @@ def table_name(cls): return "__" + from_camel_case(cls.__name__) - def from_camel_case(s): """ Convert names in camel case into underscore (_) separated names From 622fd5607978aa5b20917143693c7e63fd7b3ddb Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 11:15:24 -0500 Subject: [PATCH 11/20] intermediate: removed classmethods from Relation. Moved all declaration functionality into declare.py --- datajoint/declare.py | 211 ++++++++++++++++++++++++++ datajoint/parsing.py | 88 ----------- datajoint/relation.py | 295 ++++++++---------------------------- datajoint/user_relations.py | 6 +- requirements.txt | 1 + 5 files changed, 273 insertions(+), 328 deletions(-) create mode 100644 datajoint/declare.py delete mode 100644 datajoint/parsing.py diff --git a/datajoint/declare.py b/datajoint/declare.py new file mode 100644 index 000000000..39d375701 --- /dev/null +++ b/datajoint/declare.py @@ -0,0 +1,211 @@ +import re +import pyparsing as pp +import logging + +from . import DataJointError + + +logger = logging.getLogger(__name__) + + +def compile_attribute(line, in_key=False): + """ + Convert attribute definition from DataJoint format to SQL + :param line: attribution line + :param in_key: set to True if attribute is in primary key set + :returns: attribute name and sql code for its declaration + """ + quoted = pp.Or(pp.QuotedString('"'), pp.QuotedString("'")) + colon = pp.Literal(':').suppress() + attribute_name = pp.Word(pp.srange('[a-z]'), pp.srange('[a-z0-9_]')).setResultsName('name') + + data_type = pp.Combine(pp.Word(pp.alphas)+pp.SkipTo("#", ignore=quoted)).setResultsName('type') + default = pp.Literal('=').suppress() + pp.SkipTo(colon, ignore=quoted).setResultsName('default') + comment = pp.Literal('#').suppress() + pp.restOfLine.setResultsName('comment') + + attribute_parser = attribute_name + pp.Optional(default) + colon + data_type + comment + + match = attribute_parser.parseString(line+'#', parseAll=True) + match['comment'] = match['comment'].rstrip('#') + if 'default' not in match: + match['default'] = '' + match = {k: v.strip() for k, v in match.items()} + match['nullable'] = match['default'].lower() == 'null' + + sql_literals = ['CURRENT_TIMESTAMP'] # not to be enclosed in quotes + assert not re.match(r'^bigint', match['type'], re.I) or not match['nullable'], \ + 'BIGINT attributes cannot be nullable in "%s"' % line # TODO: This was a MATLAB limitation. Handle this correctly. + if match['nullable']: + if in_key: + raise DataJointError('Primary key attributes cannot be nullable in line %s' % line) + match['default'] = 'DEFAULT NULL' # nullable attributes default to null + else: + if match['default']: + quote = match['default'].upper() not in sql_literals and match['default'][0] not in '"\'' + match['default'] = ('NOT NULL DEFAULT ' + + ('"%s"' if quote else "%s") % match['default']) + else: + match['default'] = 'NOT NULL' + match['comment'] = match['comment'].replace('"','\\"') # escape double quotes in comment + sql = ('`{name}` {type} {default}' + (' COMMENT "{comment}"' if match['comment'] else '') + ).format(**match) + return match['name'], sql + + +def parse_index(line): + """ + Parses index definition. + + :param line: definition line + :return: groupdict with index info + """ + line = line.strip() + index_regexp = re.compile(""" + ^(?PUNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX + \((?P[^\)]+)\)$ # (attr1, attr2) + """, re.I + re.X) + m = index_regexp.match(line) + assert m, 'Invalid index declaration "%s"' % line + index_info = m.groupdict() + attributes = re.split(r'\s*,\s*', index_info['attributes'].strip()) + index_info['attributes'] = attributes + assert len(attributes) == len(set(attributes)), \ + 'Duplicate attributes in index declaration "%s"' % line + return index_info + + +def parse_declaration(cls): + """ + Parse declaration and create new SQL table accordingly. + """ + parents = [] + referenced = [] + index_defs = [] + field_defs = [] + declaration = re.split(r'\s*\n\s*', cls.definition.strip()) + + # remove comment lines + declaration = [x for x in declaration if not x.startswith('#')] + ptrn = """ + \#\s*(?P.*)$ # comment + """ + p = re.compile(ptrn, re.X) + table_info = p.search(declaration[0]).groupdict() + + #table_info['tier'] = Role[table_info['tier']] # convert into enum + + in_key = True # parse primary keys + attribute_regexp = re.compile(""" + ^[a-z][a-z\d_]*\s* # name + (=\s*\S+(\s+\S+)*\s*)? # optional defaults + :\s*\w.*$ # type, comment + """, re.I + re.X) # ignore case and verbose + + for line in declaration[1:]: + if line.startswith('---'): + in_key = False # start parsing non-PK fields + elif line.startswith('->'): + # foreign key + ref_name = line[2:].strip() + ref_list = parents if in_key else referenced + ref_list.append(eval(ref_name, locals=cls.context)) + elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): + index_defs.append(parse_index_definition(line)) + elif attribute_regexp.match(line): + field_defs.append(parse_attribute_definition(line, in_key)) + else: + raise DataJointError('Invalid table declaration line "%s"' % line) + + return table_info, parents, referenced, field_defs, index_defs + + +def declare(base_relation): + """ + Declares the table in the database if it does not exist already + """ + cur = base_relation.connection.query( + 'SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format( + database=base_relation.database, table_name=base_relation.table_name)) + if cur.rowcount: + return + + if base_relation.connection.in_transaction: + raise DataJointError("Tables cannot be declared during a transaction.") + + if not base_relation.definition: + raise DataJointError('Table definition is missing.') + + table_info, parents, referenced, field_defs, index_defs = cls._parse_declaration() + + sql = 'CREATE TABLE %s (\n' % cls.full_table_name + + # add inherited primary key fields + primary_key_fields = set() + non_key_fields = set() + for p in parents: + for key in p.primary_key: + field = p.heading[key] + if field.name not in primary_key_fields: + primary_key_fields.add(field.name) + sql += field_to_sql(field) + else: + logger.debug('Field definition of {} in {} ignored'.format( + field.name, p.full_class_name)) + + # add newly defined primary key fields + for field in (f for f in field_defs if f.in_key): + if field.nullable: + raise DataJointError('Primary key attribute {} cannot be nullable'.format( + field.name)) + if field.name in primary_key_fields: + raise DataJointError('Duplicate declaration of the primary attribute {key}. ' + 'Ensure that the attribute is not already declared ' + 'in referenced tables'.format(key=field.name)) + primary_key_fields.add(field.name) + sql += field_to_sql(field) + + # add secondary foreign key attributes + for r in referenced: + for key in r.primary_key: + field = r.heading[key] + if field.name not in primary_key_fields | non_key_fields: + non_key_fields.add(field.name) + sql += field_to_sql(field) + + # add dependent attributes + for field in (f for f in field_defs if not f.in_key): + non_key_fields.add(field.name) + sql += field_to_sql(field) + + # add primary key declaration + assert len(primary_key_fields) > 0, 'table must have a primary key' + keys = ', '.join(primary_key_fields) + sql += 'PRIMARY KEY (%s),\n' % keys + + # add foreign key declarations + for ref in parents + referenced: + keys = ', '.join(ref.primary_key) + sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \ + (keys, ref.full_table_name, keys) + + # add secondary index declarations + # gather implicit indexes due to foreign keys first + implicit_indices = [] + for fk_source in parents + referenced: + implicit_indices.append(fk_source.primary_key) + + # for index in indexDefs: + # TODO: add index declaration + + # close the declaration + sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % ( + sql[:-2], table_info['comment']) + + # # make sure that the table does not already exist + # cls.load_heading() + # if not cls.is_declared: + # # execute declaration + # logger.debug('\n\n' + sql + '\n\n') + # cls.connection.query(sql) + # cls.load_heading() + diff --git a/datajoint/parsing.py b/datajoint/parsing.py deleted file mode 100644 index 85e367c96..000000000 --- a/datajoint/parsing.py +++ /dev/null @@ -1,88 +0,0 @@ -import re -from . import DataJointError -from .heading import Heading - - -def parse_attribute_definition(line, in_key=False): - """ - Parse attribute definition line in the declaration and returns - an attribute tuple. - - :param line: attribution line - :param in_key: set to True if attribute is in primary key set - :returns: attribute tuple - """ - line = line.strip() - attribute_regexp = re.compile(""" - ^(?P[a-z][a-z\d_]*)\s* # field name - (=\s*(?P\S+(\s+\S+)*?)\s*)? # default value - :\s*(?P\w[^\#]*[^\#\s])\s* # datatype - (\#\s*(?P\S*(\s+\S+)*)\s*)?$ # comment - """, re.X) - m = attribute_regexp.match(line) - if not m: - raise DataJointError('Invalid field declaration "%s"' % line) - attr_info = m.groupdict() - if not attr_info['comment']: - attr_info['comment'] = '' - if not attr_info['default']: - attr_info['default'] = '' - attr_info['nullable'] = attr_info['default'].lower() == 'null' - assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \ - 'BIGINT attributes cannot be nullable in "%s"' % line - - return Heading.AttrTuple( - in_key=in_key, - autoincrement=None, - numeric=None, - string=None, - is_blob=None, - computation=None, - dtype=None, - **attr_info - ) - - -def field_to_sql(field): # TODO move this into Attribute Tuple - """ - Converts an attribute definition tuple into SQL code. - :param field: attribute definition - :rtype : SQL code - """ - mysql_constants = ['CURRENT_TIMESTAMP'] - if field.nullable: - default = 'DEFAULT NULL' - else: - default = 'NOT NULL' - # if some default specified - if field.default: - # enclose value in quotes except special SQL values or already enclosed - quote = field.default.upper() not in mysql_constants and field.default[0] not in '"\'' - default += ' DEFAULT ' + ('"%s"' if quote else "%s") % field.default - if any((c in r'\"' for c in field.comment)): - raise DataJointError('Illegal characters in attribute comment "%s"' % field.comment) - - return '`{name}` {type} {default} COMMENT "{comment}",\n'.format( - name=field.name, type=field.type, default=default, comment=field.comment) - - -def parse_index_definition(line): - """ - Parses index definition. - - :param line: definition line - :return: groupdict with index info - """ - line = line.strip() - index_regexp = re.compile(""" - ^(?PUNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX - \((?P[^\)]+)\)$ # (attr1, attr2) - """, re.I + re.X) - m = index_regexp.match(line) - assert m, 'Invalid index declaration "%s"' % line - index_info = m.groupdict() - attributes = re.split(r'\s*,\s*', index_info['attributes'].strip()) - index_info['attributes'] = attributes - assert len(attributes) == len(set(attributes)), \ - 'Duplicate attributes in index declaration "%s"' % line - return index_info diff --git a/datajoint/relation.py b/datajoint/relation.py index 8c2ebdc2b..7b7acd5b6 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -1,47 +1,33 @@ from collections import namedtuple -from collections.abc import MutableMapping, Mapping +from collections.abc import Mapping import numpy as np import logging -import re import abc - -from . import DataJointError, config import pymysql -from datajoint import DataJointError, conn + + +from . import DataJointError, config, conn +from .declare import declare from .relational_operand import RelationalOperand from .blob import pack from .utils import user_choice -from .parsing import parse_attribute_definition, field_to_sql, parse_index_definition from .heading import Heading logger = logging.getLogger(__name__) -TableInfo = namedtuple( - 'TableInfo', +TableLink = namedtuple( + 'TableLink', ('database', 'context', 'connection', 'heading')) -class classproperty: - def __init__(self, getf): - self._getf = getf - - def __get__(self, instance, owner): - return self._getf(owner) - - def schema(database, context, connection=None): """ - Returns a schema decorator that can be used to associate a Relation class to a - specific database with :param name:. Name reference to other tables in the table definition - will be resolved by looking up the corresponding key entry in the passed in context dictionary. - It is most common to set context equal to the return value of call to locals() in the module. - For more details, please refer to the tutorial online. + Returns a decorator that can be used to associate a Relation class to a database. :param database: name of the database to associate the decorated class with - :param context: dictionary used to resolve (any) name references within the table definition string - :param connection: connection object to the database server. If ommited, will try to establish connection according to - config values - :return: a decorator function to be used on Relation derivative classes + :param context: dictionary for looking up foreign keys references, usually set to locals() + :param connection: Connection object. Defaults to datajoint.conn() + :return: a decorator for Relation subclasses """ if connection is None: connection = conn() @@ -60,13 +46,16 @@ def schema(database, context, connection=None): " permissions.".format(database=database)) def decorator(cls): - cls._table_info = TableInfo( + """ + The decorator declares the table and binds the class to the database table + """ + cls._table_info = TableLink( database=database, context=context, connection=connection, heading=None ) - cls.declare() + declare(cls()) return cls return decorator @@ -83,112 +72,82 @@ class Relation(RelationalOperand, metaclass=abc.ABCMeta): It also handles the table declaration based on its definition property """ - __heading = None - _table_info = None def __init__(self): if self._table_info is None: raise DataJointError('The class must define _table_info') - @classproperty - def connection(cls): - """ - Returns the connection object of the class - - :return: the connection object - """ - return cls._table_info.connection - - @classproperty - def database(cls): - return cls._table_info.database - - @classproperty - def context(cls): - return cls._table_info.context - # ---------- abstract properties ------------ # - @classproperty + @property @abc.abstractmethod - def table_name(cls): + def table_name(self): """ :return: the name of the table in the database """ pass - @classproperty + @property @abc.abstractmethod - def definition(cls): + def definition(self): """ :return: a string containing the table definition using the DataJoint DDL """ pass - # --------- SQL functionality --------- # - @classproperty - def is_declared(cls): - if cls.__heading is not None: - return True - cur = cls._table_info.connection.query( - 'SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format( - database=cls.database , table_name=cls.table_name)) - return cur.rowcount == 1 - - @classproperty - def heading(cls): - """ - Required by relational operand - :return: a datajoint.Heading object - """ - if cls.__heading is None: - cls.__heading = Heading.init_from_database(cls.connection, cls.database, cls.table_name) - return cls.__heading + # -------------- table info ----------------- # + @property + def connection(self): + return self._table_info.connection + + @property + def database(self): + return self._table_info.database + + @property + def context(self): + return self._table_info.context + + @property + def heading(self): + if self._table_info.heading is None: + self._table_info.heading = Heading.init_from_database( + self.connection, self.database, self.table_name) + return self._table_info.heading - @classproperty - def from_clause(cls): + + # --------- SQL functionality --------- # + @property + def from_clause(self): """ Required by the Relational class, this property specifies the contents of the FROM clause for the SQL SELECT statements. :return: """ - return cls.full_table_name - - @classmethod - def declare(cls): - """ - Declare the table in database if it doesn't already exist. + return self.full_table_name - :raises: DataJointError if the table cannot be declared. - """ - if not cls.is_declared: - cls._declare() - - @classmethod - def iter_insert(cls, rows, **kwargs): + def iter_insert(self, rows, **kwargs): """ Inserts an entire batch of entries. Additional keyword arguments are passed to insert. :param iter: Must be an iterator that generates a sequence of valid arguments for insert. """ for row in rows: - cls.insert(row, **kwargs) + self.insert(row, **kwargs) - @classmethod - def batch_insert(cls, data, **kwargs): + def batch_insert(self, data, **kwargs): """ Inserts an entire batch of entries. Additional keyword arguments are passed to insert. :param data: must be iterable, each row must be a valid argument for insert """ - cls.iter_insert(data.__iter__(), **kwargs) + self.iter_insert(data.__iter__(), **kwargs) - @classproperty - def full_table_name(cls): - return r"`{0:s}`.`{1:s}`".format(cls.database, cls.table_name) + def full_table_name(self): + return r"`{0:s}`.`{1:s}`".format(self.database, self.table_name) @classmethod - def insert(cls, tup, ignore_errors=False, replace=False): + def insert(self, tup, ignore_errors=False, replace=False): """ Insert one data record or one Mapping (like a dictionary). @@ -203,7 +162,7 @@ def insert(cls, tup, ignore_errors=False, replace=False): real_id = 1007, date_of_birth = "2014-09-01")) """ - heading = cls.heading + heading = self.heading if isinstance(tup, np.void): for fieldname in tup.dtype.fields: if fieldname not in heading: @@ -233,10 +192,9 @@ def insert(cls, tup, ignore_errors=False, replace=False): sql = 'INSERT IGNORE' else: sql = 'INSERT' - sql += " INTO %s (%s) VALUES (%s)" % (cls.full_table_name, - attribute_list, value_list) + sql += " INTO %s (%s) VALUES (%s)" % (self.from_caluse, attribute_list, value_list) logger.info(sql) - cls.connection.query(sql, args=args) + self.connection.query(sql, args=args) def delete(self): if not config['safemode'] or user_choice( @@ -244,22 +202,20 @@ def delete(self): "Proceed?", default='no') == 'yes': self.connection.query('DELETE FROM ' + self.from_clause + self.where_clause) # TODO: make cascading (issue #15) - @classmethod - def drop(cls): + def drop(self): """ Drops the table associated to this class. """ - if cls.is_declared: + if self.is_declared: if not config['safemode'] or user_choice( "You are about to drop an entire table. This operation cannot be undone.\n" "Proceed?", default='no') == 'yes': - cls.connection.query('DROP TABLE %s' % cls.full_table_name) # TODO: make cascading (issue #16) + self.connection.query('DROP TABLE %s' % cls.full_table_name) # TODO: make cascading (issue #16) # cls.connection.clear_dependencies(dbname=cls.dbname) #TODO: reimplement because clear_dependencies will be gone # cls.connection.load_headings(dbname=cls.dbname, force=True) #TODO: reimplement because load_headings is gone logger.info("Dropped table %s" % cls.full_table_name) - @classproperty - def size_on_disk(cls): + def size_on_disk(self): """ :return: size of data and indices in MiB taken by the table on the storage device """ @@ -335,140 +291,9 @@ def _alter(cls, alter_statement): :param alter_statement: alter statement """ if cls.connection.in_transaction: - raise DataJointError( - u"_alter is currently in transaction. Operation not allowed to avoid implicit commits.") + raise DataJointError("Table definition cannot be altered during a transaction.") sql = 'ALTER TABLE %s %s' % (cls.full_table_name, alter_statement) cls.connection.query(sql) cls.connection.load_headings(cls.dbname, force=True) - # TODO: place table definition sync mechanism - - @classmethod - def _declare(cls): - """ - Declares the table in the database if no table in the database matches this object. - """ - if cls.connection.in_transaction: - raise DataJointError( - u"_declare is currently in transaction. Operation not allowed to avoid implicit commits.") - - if not cls.definition: # if empty definition was supplied - raise DataJointError('Table definition is missing!') - table_info, parents, referenced, field_defs, index_defs = cls._parse_declaration() - - sql = 'CREATE TABLE %s (\n' % cls.full_table_name - - # add inherited primary key fields - primary_key_fields = set() - non_key_fields = set() - for p in parents: - for key in p.primary_key: - field = p.heading[key] - if field.name not in primary_key_fields: - primary_key_fields.add(field.name) - sql += field_to_sql(field) - else: - logger.debug('Field definition of {} in {} ignored'.format( - field.name, p.full_class_name)) - - # add newly defined primary key fields - for field in (f for f in field_defs if f.in_key): - if field.nullable: - raise DataJointError('Primary key attribute {} cannot be nullable'.format( - field.name)) - if field.name in primary_key_fields: - raise DataJointError('Duplicate declaration of the primary attribute {key}. ' - 'Ensure that the attribute is not already declared ' - 'in referenced tables'.format(key=field.name)) - primary_key_fields.add(field.name) - sql += field_to_sql(field) - - # add secondary foreign key attributes - for r in referenced: - for key in r.primary_key: - field = r.heading[key] - if field.name not in primary_key_fields | non_key_fields: - non_key_fields.add(field.name) - sql += field_to_sql(field) - - # add dependent attributes - for field in (f for f in field_defs if not f.in_key): - non_key_fields.add(field.name) - sql += field_to_sql(field) - - # add primary key declaration - assert len(primary_key_fields) > 0, 'table must have a primary key' - keys = ', '.join(primary_key_fields) - sql += 'PRIMARY KEY (%s),\n' % keys - - # add foreign key declarations - for ref in parents + referenced: - keys = ', '.join(ref.primary_key) - sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \ - (keys, ref.full_table_name, keys) - - # add secondary index declarations - # gather implicit indexes due to foreign keys first - implicit_indices = [] - for fk_source in parents + referenced: - implicit_indices.append(fk_source.primary_key) - - # for index in indexDefs: - # TODO: add index declaration - - # close the declaration - sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % ( - sql[:-2], table_info['comment']) - - # # make sure that the table does not already exist - # cls.load_heading() - # if not cls.is_declared: - # # execute declaration - # logger.debug('\n\n' + sql + '\n\n') - # cls.connection.query(sql) - # cls.load_heading() - - @classmethod - def _parse_declaration(cls): - """ - Parse declaration and create new SQL table accordingly. - """ - parents = [] - referenced = [] - index_defs = [] - field_defs = [] - declaration = re.split(r'\s*\n\s*', cls.definition.strip()) - - # remove comment lines - declaration = [x for x in declaration if not x.startswith('#')] - ptrn = """ - \#\s*(?P.*)$ # comment - """ - p = re.compile(ptrn, re.X) - table_info = p.search(declaration[0]).groupdict() - - #table_info['tier'] = Role[table_info['tier']] # convert into enum - - in_key = True # parse primary keys - attribute_regexp = re.compile(""" - ^[a-z][a-z\d_]*\s* # name - (=\s*\S+(\s+\S+)*\s*)? # optional defaults - :\s*\w.*$ # type, comment - """, re.I + re.X) # ignore case and verbose - - for line in declaration[1:]: - if line.startswith('---'): - in_key = False # start parsing non-PK fields - elif line.startswith('->'): - # foreign key - ref_name = line[2:].strip() - ref_list = parents if in_key else referenced - ref_list.append(eval(ref_name, locals=cls.context)) - elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): - index_defs.append(parse_index_definition(line)) - elif attribute_regexp.match(line): - field_defs.append(parse_attribute_definition(line, in_key)) - else: - raise DataJointError('Invalid table declaration line "%s"' % line) - - return table_info, parents, referenced, field_defs, index_defs + # TODO: place table definition sync mechanism \ No newline at end of file diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index a615247ee..a3429f457 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -1,28 +1,24 @@ import re -from datajoint.relation import Relation, classproperty +from datajoint.relation import Relation from .autopopulate import AutoPopulate from . import DataJointError class Manual(Relation): - @classproperty def table_name(cls): return from_camel_case(cls.__name__) class Lookup(Relation): - @classproperty def table_name(cls): return '#' + from_camel_case(cls.__name__) class Imported(Relation, AutoPopulate): - @classproperty def table_name(cls): return "_" + from_camel_case(cls.__name__) class Computed(Relation, AutoPopulate): - @classproperty def table_name(cls): return "__" + from_camel_case(cls.__name__) diff --git a/requirements.txt b/requirements.txt index 82abfa7ea..af4d48f13 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ numpy pymysql +pyparsing networkx matplotlib sphinx_rtd_theme From 3b8420e922919165e5fd3081cefec5241d389dcf Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 11:31:58 -0500 Subject: [PATCH 12/20] amending, removed remaining classmethods from Relation --- datajoint/declare.py | 2 +- datajoint/relation.py | 60 ++++++++++++++++--------------------- datajoint/user_relations.py | 16 +++++----- 3 files changed, 35 insertions(+), 43 deletions(-) diff --git a/datajoint/declare.py b/datajoint/declare.py index 39d375701..bde62a7e4 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -13,7 +13,7 @@ def compile_attribute(line, in_key=False): Convert attribute definition from DataJoint format to SQL :param line: attribution line :param in_key: set to True if attribute is in primary key set - :returns: attribute name and sql code for its declaration + :returns: (name, sql) -- attribute name and sql code for its declaration """ quoted = pp.Or(pp.QuotedString('"'), pp.QuotedString("'")) colon = pp.Literal(':').suppress() diff --git a/datajoint/relation.py b/datajoint/relation.py index 7b7acd5b6..5533853ee 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -5,13 +5,13 @@ import abc import pymysql - from . import DataJointError, config, conn from .declare import declare from .relational_operand import RelationalOperand from .blob import pack from .utils import user_choice from .heading import Heading +from .declare import compile_attribute logger = logging.getLogger(__name__) @@ -210,7 +210,7 @@ def drop(self): if not config['safemode'] or user_choice( "You are about to drop an entire table. This operation cannot be undone.\n" "Proceed?", default='no') == 'yes': - self.connection.query('DROP TABLE %s' % cls.full_table_name) # TODO: make cascading (issue #16) + self.connection.query('DROP TABLE %s' % self.full_table_name) # TODO: make cascading (issue #16) # cls.connection.clear_dependencies(dbname=cls.dbname) #TODO: reimplement because clear_dependencies will be gone # cls.connection.load_headings(dbname=cls.dbname, force=True) #TODO: reimplement because load_headings is gone logger.info("Dropped table %s" % cls.full_table_name) @@ -219,22 +219,20 @@ def size_on_disk(self): """ :return: size of data and indices in MiB taken by the table on the storage device """ - cur = cls.connection.query( - 'SHOW TABLE STATUS FROM `{dbname}` WHERE NAME="{table}"'.format( - dbname=cls.dbname, table=cls.table_name), as_dict=True) - ret = cur.fetchone() + ret = self.connection.query( + 'SHOW TABLE STATUS FROM `(database}` WHERE NAME="{table}"'.format( + database=self.database, table=self.table_name), as_dict=True + ).fetchone() return (ret['Data_length'] + ret['Index_length'])/1024**2 - @classmethod - def set_table_comment(cls, comment): + def set_table_comment(self, comment): """ Update the table comment in the table definition. :param comment: new comment as string """ - cls.alter('COMMENT="%s"' % comment) + self._alter('COMMENT="%s"' % comment) - @classmethod - def add_attribute(cls, definition, after=None): + def add_attribute(self, definition, after=None): """ Add a new attribute to the table. A full line from the table definition is passed in as definition. @@ -249,51 +247,45 @@ def add_attribute(cls, definition, after=None): """ position = ' FIRST' if after is None else ( ' AFTER %s' % after if after else '') - sql = field_to_sql(parse_attribute_definition(definition)) - cls._alter('ADD COLUMN %s%s' % (sql[:-2], position)) + sql = compile_attribute(definition)[1] + self._alter('ADD COLUMN %s%s' % (sql, position)) - @classmethod - def drop_attribute(cls, attr_name): + def drop_attribute(self, attribute_name): """ Drops the attribute attrName from this table. - - :param attr_name: Name of the attribute that is dropped. + :param attribute_name: Name of the attribute that is dropped. """ if not config['safemode'] or user_choice( "You are about to drop an attribute from a table." "This operation cannot be undone.\n" "Proceed?", default='no') == 'yes': - cls._alter('DROP COLUMN `%s`' % attr_name) + self._alter('DROP COLUMN `%s`' % attribute_name) - @classmethod - def alter_attribute(cls, attr_name, new_definition): + def alter_attribute(self, attribute_name, definition): """ Alter the definition of the field attr_name in this table using the new definition. - :param attr_name: field that is redefined - :param new_definition: new definition of the field + :param attribute_name: field that is redefined + :param definition: new definition of the field """ - sql = field_to_sql(parse_attribute_definition(new_definition)) - cls._alter('CHANGE COLUMN `%s` %s' % (attr_name, sql[:-2])) + sql = compile_attribute(definition)[1] + self._alter('CHANGE COLUMN `%s` %s' % (attribute_name, sql)) - @classmethod - def erd(cls, subset=None): + def erd(self, subset=None): """ Plot the schema's entity relationship diagram (ERD). """ + NotImplemented - @classmethod - def _alter(cls, alter_statement): + def _alter(self, alter_statement): """ Execute ALTER TABLE statement for this table. The schema will be reloaded within the connection object. :param alter_statement: alter statement """ - if cls.connection.in_transaction: + if self.connection.in_transaction: raise DataJointError("Table definition cannot be altered during a transaction.") - - sql = 'ALTER TABLE %s %s' % (cls.full_table_name, alter_statement) - cls.connection.query(sql) - cls.connection.load_headings(cls.dbname, force=True) - # TODO: place table definition sync mechanism \ No newline at end of file + sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement) + self.connection.query(sql) + self._table_info.heading = None diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index a3429f457..9431e28c2 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -4,23 +4,23 @@ from . import DataJointError class Manual(Relation): - def table_name(cls): - return from_camel_case(cls.__name__) + def table_name(self): + return from_camel_case(self.__class__.__name__) class Lookup(Relation): - def table_name(cls): - return '#' + from_camel_case(cls.__name__) + def table_name(self): + return '#' + from_camel_case(self.__class__.__name__) class Imported(Relation, AutoPopulate): - def table_name(cls): - return "_" + from_camel_case(cls.__name__) + def table_name(self): + return "_" + from_camel_case(self.__class__.__name__) class Computed(Relation, AutoPopulate): - def table_name(cls): - return "__" + from_camel_case(cls.__name__) + def table_name(self): + return "__" + from_camel_case(self.__class__.__name__) def from_camel_case(s): From dafcbbc98d1309e37f284f11c468bcc993a5619c Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 14:26:19 -0500 Subject: [PATCH 13/20] converted Heading.init_from_database into an instance method --- datajoint/connection.py | 5 --- datajoint/declare.py | 22 ++++++------- datajoint/heading.py | 12 +++---- datajoint/relation.py | 63 ++++++++++++++++--------------------- datajoint/user_relations.py | 6 ++++ 5 files changed, 50 insertions(+), 58 deletions(-) diff --git a/datajoint/connection.py b/datajoint/connection.py index 3ca01b6bd..0ba65b532 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -1,11 +1,6 @@ import pymysql -import re -from .utils import to_camel_case from . import DataJointError -from .heading import Heading -from .settings import prefix_to_role import logging -from .erd import DBConnGraph from . import config logger = logging.getLogger(__name__) diff --git a/datajoint/declare.py b/datajoint/declare.py index bde62a7e4..4551cd72a 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -32,7 +32,7 @@ def compile_attribute(line, in_key=False): match = {k: v.strip() for k, v in match.items()} match['nullable'] = match['default'].lower() == 'null' - sql_literals = ['CURRENT_TIMESTAMP'] # not to be enclosed in quotes + literals = ['CURRENT_TIMESTAMP'] # not to be enclosed in quotes assert not re.match(r'^bigint', match['type'], re.I) or not match['nullable'], \ 'BIGINT attributes cannot be nullable in "%s"' % line # TODO: This was a MATLAB limitation. Handle this correctly. if match['nullable']: @@ -41,12 +41,12 @@ def compile_attribute(line, in_key=False): match['default'] = 'DEFAULT NULL' # nullable attributes default to null else: if match['default']: - quote = match['default'].upper() not in sql_literals and match['default'][0] not in '"\'' + quote = match['default'].upper() not in literals and match['default'][0] not in '"\'' match['default'] = ('NOT NULL DEFAULT ' + ('"%s"' if quote else "%s") % match['default']) else: match['default'] = 'NOT NULL' - match['comment'] = match['comment'].replace('"','\\"') # escape double quotes in comment + match['comment'] = match['comment'].replace('"', '\\"') # escape double quotes in comment sql = ('`{name}` {type} {default}' + (' COMMENT "{comment}"' if match['comment'] else '') ).format(**match) return match['name'], sql @@ -110,7 +110,7 @@ def parse_declaration(cls): ref_list = parents if in_key else referenced ref_list.append(eval(ref_name, locals=cls.context)) elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): - index_defs.append(parse_index_definition(line)) + index_defs.append(parse_index(line)) elif attribute_regexp.match(line): field_defs.append(parse_attribute_definition(line, in_key)) else: @@ -119,23 +119,23 @@ def parse_declaration(cls): return table_info, parents, referenced, field_defs, index_defs -def declare(base_relation): +def declare(relation): """ Declares the table in the database if it does not exist already """ - cur = base_relation.connection.query( + cur = relation.connection.query( 'SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format( - database=base_relation.database, table_name=base_relation.table_name)) + database=relation.database, table_name=relation.table_name)) if cur.rowcount: return - if base_relation.connection.in_transaction: + if relation.connection.in_transaction: raise DataJointError("Tables cannot be declared during a transaction.") - if not base_relation.definition: - raise DataJointError('Table definition is missing.') + if not relation.definition: + raise DataJointError('Missing table definition.') - table_info, parents, referenced, field_defs, index_defs = cls._parse_declaration() + table_info, parents, referenced, field_defs, index_defs = parse_declaration() sql = 'CREATE TABLE %s (\n' % cls.full_table_name diff --git a/datajoint/heading.py b/datajoint/heading.py index 73d6c30f2..e37caf85e 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -17,11 +17,13 @@ class Heading: 'computation', 'dtype')) AttrTuple.as_dict = AttrTuple._asdict # renaming to make public - def __init__(self, attributes): + def __init__(self, attributes=None): """ :param attributes: a list of dicts with the same keys as AttrTuple """ - self.attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes]) + if attributes: + attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes]) + self.attributes = attributes @property def names(self): @@ -90,8 +92,7 @@ def items(self): def __iter__(self): return iter(self.attributes) - @classmethod - def init_from_database(cls, conn, database, table_name): + def init_from_database(self, conn, database, table_name): """ initialize heading from a database table """ @@ -163,8 +164,7 @@ def init_from_database(cls, conn, database, table_name): t = re.sub(r' unsigned$', '', t) # remove unsigned assert (t, is_unsigned) in numeric_types, 'dtype not found for type %s' % t attr['dtype'] = numeric_types[(t, is_unsigned)] - - return cls(attributes) + self.attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes]) def project(self, *attribute_list, **renamed_attributes): """ diff --git a/datajoint/relation.py b/datajoint/relation.py index 5533853ee..56c4eb11a 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -1,4 +1,4 @@ -from collections import namedtuple +from types import SimpleNamespace from collections.abc import Mapping import numpy as np import logging @@ -6,18 +6,16 @@ import pymysql from . import DataJointError, config, conn -from .declare import declare +from .declare import declare, compile_attribute from .relational_operand import RelationalOperand from .blob import pack from .utils import user_choice from .heading import Heading -from .declare import compile_attribute logger = logging.getLogger(__name__) -TableLink = namedtuple( - 'TableLink', - ('database', 'context', 'connection', 'heading')) +TableLink = namedtuple('TableLink', + ('database', 'context', 'connection', 'heading')) def schema(database, context, connection=None): @@ -49,11 +47,11 @@ def decorator(cls): """ The decorator declares the table and binds the class to the database table """ - cls._table_info = TableLink( + cls._table_link = TableLink( database=database, context=context, connection=connection, - heading=None + heading=Heading() ) declare(cls()) return cls @@ -72,11 +70,11 @@ class Relation(RelationalOperand, metaclass=abc.ABCMeta): It also handles the table declaration based on its definition property """ - _table_info = None + _table_link = None def __init__(self): - if self._table_info is None: - raise DataJointError('The class must define _table_info') + if self._table_link is None: + raise DataJointError('The class must define _table_link') # ---------- abstract properties ------------ # @property @@ -85,7 +83,7 @@ def table_name(self): """ :return: the name of the table in the database """ - pass + raise NotImplementedError('Relation subclasses must define property table_name') @property @abc.abstractmethod @@ -98,37 +96,34 @@ def definition(self): # -------------- table info ----------------- # @property def connection(self): - return self._table_info.connection + return self._table_link.connection @property def database(self): - return self._table_info.database + return self._table_link.database @property def context(self): - return self._table_info.context + return self._table_link.context @property def heading(self): - if self._table_info.heading is None: - self._table_info.heading = Heading.init_from_database( - self.connection, self.database, self.table_name) - return self._table_info.heading - + heading = self._table_link.heading + if not heading: + heading.init_from_database(self.connection, self.database, self.table_name) + return heading # --------- SQL functionality --------- # @property def from_clause(self): """ - Required by the Relational class, this property specifies the contents of the FROM clause - for the SQL SELECT statements. - :return: + :return: the FROM clause of SQL SELECT statements. """ return self.full_table_name def iter_insert(self, rows, **kwargs): """ - Inserts an entire batch of entries. Additional keyword arguments are passed to insert. + Inserts a collection of tuples. Additional keyword arguments are passed to insert. :param iter: Must be an iterator that generates a sequence of valid arguments for insert. """ @@ -172,18 +167,16 @@ def insert(self, tup, ignore_errors=False, replace=False): args = tuple(pack(tup[name]) for name in heading if name in tup.dtype.fields and heading[name].is_blob) - attribute_list = '`' + '`,`'.join( - [q for q in heading if q in tup.dtype.fields]) + '`' + attribute_list = '`' + '`,`'.join(q for q in heading if q in tup.dtype.fields) + '`' elif isinstance(tup, Mapping): for fieldname in tup.keys(): if fieldname not in heading: raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname, )) - value_list = ','.join([repr(tup[name]) if not heading[name].is_blob else '%s' - for name in heading if name in tup]) + value_list = ','.join(repr(tup[name]) if not heading[name].is_blob else '%s' + for name in heading if name in tup) args = tuple(pack(tup[name]) for name in heading if name in tup and heading[name].is_blob) - attribute_list = '`' + '`,`'.join( - [name for name in heading if name in tup]) + '`' + attribute_list = '`' + '`,`'.join(name for name in heading if name in tup) + '`' else: raise DataJointError('Datatype %s cannot be inserted' % type(tup)) if replace: @@ -213,7 +206,7 @@ def drop(self): self.connection.query('DROP TABLE %s' % self.full_table_name) # TODO: make cascading (issue #16) # cls.connection.clear_dependencies(dbname=cls.dbname) #TODO: reimplement because clear_dependencies will be gone # cls.connection.load_headings(dbname=cls.dbname, force=True) #TODO: reimplement because load_headings is gone - logger.info("Dropped table %s" % cls.full_table_name) + logger.info("Dropped table %s" % self.full_table_name) def size_on_disk(self): """ @@ -263,7 +256,7 @@ def drop_attribute(self, attribute_name): def alter_attribute(self, attribute_name, definition): """ - Alter the definition of the field attr_name in this table using the new definition. + Alter attribute definition :param attribute_name: field that is redefined :param definition: new definition of the field @@ -279,13 +272,11 @@ def erd(self, subset=None): def _alter(self, alter_statement): """ - Execute ALTER TABLE statement for this table. The schema - will be reloaded within the connection object. - + Execute an ALTER TABLE statement. :param alter_statement: alter statement """ if self.connection.in_transaction: raise DataJointError("Table definition cannot be altered during a transaction.") sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement) self.connection.query(sql) - self._table_info.heading = None + self._table_link.heading = None \ No newline at end of file diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 9431e28c2..767f87b33 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -3,26 +3,32 @@ from .autopopulate import AutoPopulate from . import DataJointError + class Manual(Relation): + @property def table_name(self): return from_camel_case(self.__class__.__name__) class Lookup(Relation): + @property def table_name(self): return '#' + from_camel_case(self.__class__.__name__) class Imported(Relation, AutoPopulate): + @property def table_name(self): return "_" + from_camel_case(self.__class__.__name__) class Computed(Relation, AutoPopulate): + @property def table_name(self): return "__" + from_camel_case(self.__class__.__name__) +# ---------------- utilities -------------------- def from_camel_case(s): """ Convert names in camel case into underscore (_) separated names From 76a94b461a5ab51b0c39843569551dcd3419d4b5 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 14:48:56 -0500 Subject: [PATCH 14/20] fixed heading initialization --- datajoint/declare.py | 4 ++-- datajoint/heading.py | 3 +++ datajoint/relation.py | 4 ++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/datajoint/declare.py b/datajoint/declare.py index 4551cd72a..444c5e4af 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -149,8 +149,8 @@ def declare(relation): primary_key_fields.add(field.name) sql += field_to_sql(field) else: - logger.debug('Field definition of {} in {} ignored'.format( - field.name, p.full_class_name)) + logger.debug( + 'Field definition of {} in {} ignored'.format(field.name, p.full_class_name)) # add newly defined primary key fields for field in (f for f in field_defs if f.in_key): diff --git a/datajoint/heading.py b/datajoint/heading.py index e37caf85e..6596a0673 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -25,6 +25,9 @@ def __init__(self, attributes=None): attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes]) self.attributes = attributes + def __bool__(self): + return self.attributes is not None + @property def names(self): return [k for k in self.attributes] diff --git a/datajoint/relation.py b/datajoint/relation.py index 56c4eb11a..c38f0ed87 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -1,4 +1,4 @@ -from types import SimpleNamespace +from collections import namedtuple from collections.abc import Mapping import numpy as np import logging @@ -138,10 +138,10 @@ def batch_insert(self, data, **kwargs): """ self.iter_insert(data.__iter__(), **kwargs) + @property def full_table_name(self): return r"`{0:s}`.`{1:s}`".format(self.database, self.table_name) - @classmethod def insert(self, tup, ignore_errors=False, replace=False): """ Insert one data record or one Mapping (like a dictionary). From bab7460d12c78c84f4954782bbd35be0f4fc91ce Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 15:13:04 -0500 Subject: [PATCH 15/20] implemented dj.Subordinate to support subtables --- datajoint/__init__.py | 2 +- datajoint/autopopulate.py | 6 +++--- datajoint/user_relations.py | 21 +++++++++++++++++---- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 796ded488..b562aa8d1 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -38,7 +38,7 @@ class DataJointError(Exception): # ------------- flatten import hierarchy ------------------------- from .connection import conn, Connection from .relation import Relation -from .user_relations import Manual, Lookup, Imported, Computed +from .user_relations import Manual, Lookup, Imported, Computed, Subordinate from .autopopulate import AutoPopulate from . import blob from .relational_operand import Not diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 5db012eff..cb4c0d673 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -37,7 +37,7 @@ def _make_tuples(self, key): def target(self): return self - def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, max_attempts=10): + def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): """ rel.populate() calls rel._make_tuples(key) for every primary key in self.populate_relation for which there is not already a tuple in rel. @@ -45,7 +45,6 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, :param restriction: restriction on rel.populate_relation - target :param suppress_errors: suppresses error if true :param reserve_jobs: currently not implemented - :param max_attempts: maximal number of times a TransactionError is caught before populate gives up """ assert not reserve_jobs, NotImplemented # issue #5 @@ -56,7 +55,8 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, self.connection.cancel_transaction() # rollback previous transaction, if any if not isinstance(self, Relation): - raise DataJointError('Autopopulate is a mixin for Relation and must therefore subclass Relation') + raise DataJointError( + 'AutoPopulate is a mixin for Relation and must therefore subclass Relation') unpopulated = (self.populate_relation - self.target) & restriction for key in unpopulated.project(): diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 767f87b33..8fa469504 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -1,33 +1,46 @@ import re +import abc from datajoint.relation import Relation from .autopopulate import AutoPopulate from . import DataJointError -class Manual(Relation): +class Manual(Relation, metaclass=abc.ABCMeta): @property def table_name(self): return from_camel_case(self.__class__.__name__) -class Lookup(Relation): +class Lookup(Relation, metaclass=abc.ABCMeta): @property def table_name(self): return '#' + from_camel_case(self.__class__.__name__) -class Imported(Relation, AutoPopulate): +class Imported(Relation, AutoPopulate, metaclass=abc.ABCMeta): @property def table_name(self): return "_" + from_camel_case(self.__class__.__name__) -class Computed(Relation, AutoPopulate): +class Computed(Relation, AutoPopulate, metaclass=abc.ABCMeta): @property def table_name(self): return "__" + from_camel_case(self.__class__.__name__) +class Subordinate: + """ + Mix-in to make computed tables subordinate + """ + @property + def populate_relation(self): + return None + + def _make_tuples(self, key): + raise NotImplementedError('_make_tuples not defined.') + + # ---------------- utilities -------------------- def from_camel_case(s): """ From 706dff0cf5174171ab93f2310c0cbc01f1034e6a Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 16:58:17 -0500 Subject: [PATCH 16/20] In Relation, split table_info attribute into its parts; made declare() independent --- datajoint/declare.py | 4 +-- datajoint/relation.py | 51 +++++++++++++------------------------ datajoint/user_relations.py | 10 ++++---- 3 files changed, 23 insertions(+), 42 deletions(-) diff --git a/datajoint/declare.py b/datajoint/declare.py index 444c5e4af..99d189464 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -2,8 +2,6 @@ import pyparsing as pp import logging -from . import DataJointError - logger = logging.getLogger(__name__) @@ -119,7 +117,7 @@ def parse_declaration(cls): return table_info, parents, referenced, field_defs, index_defs -def declare(relation): +def declare(full_table_name, definition, context): """ Declares the table in the database if it does not exist already """ diff --git a/datajoint/relation.py b/datajoint/relation.py index c38f0ed87..463ed0424 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -14,9 +14,6 @@ logger = logging.getLogger(__name__) -TableLink = namedtuple('TableLink', - ('database', 'context', 'connection', 'heading')) - def schema(database, context, connection=None): """ @@ -47,13 +44,16 @@ def decorator(cls): """ The decorator declares the table and binds the class to the database table """ - cls._table_link = TableLink( - database=database, - context=context, - connection=connection, - heading=Heading() - ) - declare(cls()) + cls.database = database + cls._connection = connection + cls._heading = Heading() + instance = cls() if isinstance(cls, type) else cls + if not cls.heading: + cls.connection.query( + declare( + table_name=instance.full_table_name, + definition=instance.definition, + context=context)) return cls return decorator @@ -66,16 +66,8 @@ class Relation(RelationalOperand, metaclass=abc.ABCMeta): To make it a concrete class, override the abstract properties specifying the connection, table name, database, context, and definition. A Relation implements insert and delete methods in addition to inherited relational operators. - It also loads table heading and dependencies from the database. - It also handles the table declaration based on its definition property """ - _table_link = None - - def __init__(self): - if self._table_link is None: - raise DataJointError('The class must define _table_link') - # ---------- abstract properties ------------ # @property @abc.abstractmethod @@ -93,27 +85,17 @@ def definition(self): """ pass - # -------------- table info ----------------- # + # -------------- required by RelationalOperand ----------------- # @property def connection(self): - return self._table_link.connection - - @property - def database(self): - return self._table_link.database - - @property - def context(self): - return self._table_link.context + return self._connection @property def heading(self): - heading = self._table_link.heading - if not heading: - heading.init_from_database(self.connection, self.database, self.table_name) - return heading + if not self._heading: + self._heading.init_from_database(self.connection, self.database, self.table_name) + return self._heading - # --------- SQL functionality --------- # @property def from_clause(self): """ @@ -130,6 +112,7 @@ def iter_insert(self, rows, **kwargs): for row in rows: self.insert(row, **kwargs) + # --------- SQL functionality --------- # def batch_insert(self, data, **kwargs): """ Inserts an entire batch of entries. Additional keyword arguments are passed to insert. @@ -279,4 +262,4 @@ def _alter(self, alter_statement): raise DataJointError("Table definition cannot be altered during a transaction.") sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement) self.connection.query(sql) - self._table_link.heading = None \ No newline at end of file + self.heading.init_from_database(self.connection, self.database, self.table_name) \ No newline at end of file diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 8fa469504..1b0352a03 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -5,25 +5,25 @@ from . import DataJointError -class Manual(Relation, metaclass=abc.ABCMeta): +class Manual(Relation): @property def table_name(self): return from_camel_case(self.__class__.__name__) -class Lookup(Relation, metaclass=abc.ABCMeta): +class Lookup(Relation): @property def table_name(self): return '#' + from_camel_case(self.__class__.__name__) -class Imported(Relation, AutoPopulate, metaclass=abc.ABCMeta): +class Imported(Relation, AutoPopulate): @property def table_name(self): return "_" + from_camel_case(self.__class__.__name__) -class Computed(Relation, AutoPopulate, metaclass=abc.ABCMeta): +class Computed(Relation, AutoPopulate): @property def table_name(self): return "__" + from_camel_case(self.__class__.__name__) @@ -38,7 +38,7 @@ def populate_relation(self): return None def _make_tuples(self, key): - raise NotImplementedError('_make_tuples not defined.') + raise NotImplementedError('Subtables should not be populated directly.') # ---------------- utilities -------------------- From 2b261fcbffe54506a388928caa860eacfca3069a Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 17:20:47 -0500 Subject: [PATCH 17/20] typo --- datajoint/relation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datajoint/relation.py b/datajoint/relation.py index 463ed0424..98fa24309 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -168,7 +168,7 @@ def insert(self, tup, ignore_errors=False, replace=False): sql = 'INSERT IGNORE' else: sql = 'INSERT' - sql += " INTO %s (%s) VALUES (%s)" % (self.from_caluse, attribute_list, value_list) + sql += " INTO %s (%s) VALUES (%s)" % (self.from_clause, attribute_list, value_list) logger.info(sql) self.connection.query(sql, args=args) From de8f0098e211bbdeb9bafeb9e312046b06c21fca Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 9 Jun 2015 17:26:29 -0500 Subject: [PATCH 18/20] typo --- datajoint/relation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datajoint/relation.py b/datajoint/relation.py index 98fa24309..958e6ed27 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -196,7 +196,7 @@ def size_on_disk(self): :return: size of data and indices in MiB taken by the table on the storage device """ ret = self.connection.query( - 'SHOW TABLE STATUS FROM `(database}` WHERE NAME="{table}"'.format( + 'SHOW TABLE STATUS FROM `{database}` WHERE NAME="{table}"'.format( database=self.database, table=self.table_name), as_dict=True ).fetchone() return (ret['Data_length'] + ret['Index_length'])/1024**2 From c8b9400b313d69196529dc95fe572f4b8fa051b8 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 12 Jun 2015 15:56:51 -0500 Subject: [PATCH 19/20] debugged and optimized table declaration in declare.py --- datajoint/declare.py | 221 ++++++++++++------------------------------ datajoint/heading.py | 60 +++++++++--- datajoint/relation.py | 17 +++- demos/demo1.py | 38 +++++--- 4 files changed, 140 insertions(+), 196 deletions(-) diff --git a/datajoint/declare.py b/datajoint/declare.py index 99d189464..b3c2968ca 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -6,6 +6,66 @@ logger = logging.getLogger(__name__) + +def declare(full_table_name, definition, context): + """ + Parse declaration and create new SQL table accordingly. + """ + # split definition into lines + definition = re.split(r'\s*\n\s*', definition.strip()) + + table_comment = definition.pop(0)[1:] if definition[0].startswith('#') else '' + + in_key = True # parse primary keys + primary_key = [] + attributes = [] + attribute_sql = [] + foreign_key_sql = [] + index_sql = [] + + for line in definition: + if line.startswith('#'): # additional comments are ignored + pass + elif line.startswith('---'): + in_key = False # start parsing dependent attributes + elif line.startswith('->'): + # foreign key + ref = eval(line[2:], context)() + foreign_key_sql.append( + 'FOREIGN KEY ({primary_key})' + ' REFERENCES {ref} ({primary_key})' + ' ON UPDATE CASCADE ON DELETE RESTRICT'.format( + primary_key='`' + '`,`'.join(primary_key) + '`', ref=ref.full_table_name) + ) + for name in ref.primary_key: + if in_key and name not in primary_key: + primary_key.append(name) + if name not in attributes: + attributes.append(name) + attribute_sql.append(ref.heading[name].sql()) + elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): # index + index_sql.append(line) # the SQL syntax is identical to DataJoint's + else: + name, sql = compile_attribute(line, in_key) + if in_key and name not in primary_key: + primary_key.append(name) + if name not in attributes: + attributes.append(name) + attribute_sql.append(sql) + + # compile SQL + sql = 'CREATE TABLE %s (\n ' % full_table_name + sql += ', \n'.join(attribute_sql) + if foreign_key_sql: + sql += ', \n' + ', \n'.join(foreign_key_sql) + if index_sql: + sql += ', \n' + ', \n'.join(index_sql) + sql += '\n) ENGINE = InnoDB, COMMENT "%s"' % table_comment + return sql + + + + def compile_attribute(line, in_key=False): """ Convert attribute definition from DataJoint format to SQL @@ -31,8 +91,6 @@ def compile_attribute(line, in_key=False): match['nullable'] = match['default'].lower() == 'null' literals = ['CURRENT_TIMESTAMP'] # not to be enclosed in quotes - assert not re.match(r'^bigint', match['type'], re.I) or not match['nullable'], \ - 'BIGINT attributes cannot be nullable in "%s"' % line # TODO: This was a MATLAB limitation. Handle this correctly. if match['nullable']: if in_key: raise DataJointError('Primary key attributes cannot be nullable in line %s' % line) @@ -48,162 +106,3 @@ def compile_attribute(line, in_key=False): sql = ('`{name}` {type} {default}' + (' COMMENT "{comment}"' if match['comment'] else '') ).format(**match) return match['name'], sql - - -def parse_index(line): - """ - Parses index definition. - - :param line: definition line - :return: groupdict with index info - """ - line = line.strip() - index_regexp = re.compile(""" - ^(?PUNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX - \((?P[^\)]+)\)$ # (attr1, attr2) - """, re.I + re.X) - m = index_regexp.match(line) - assert m, 'Invalid index declaration "%s"' % line - index_info = m.groupdict() - attributes = re.split(r'\s*,\s*', index_info['attributes'].strip()) - index_info['attributes'] = attributes - assert len(attributes) == len(set(attributes)), \ - 'Duplicate attributes in index declaration "%s"' % line - return index_info - - -def parse_declaration(cls): - """ - Parse declaration and create new SQL table accordingly. - """ - parents = [] - referenced = [] - index_defs = [] - field_defs = [] - declaration = re.split(r'\s*\n\s*', cls.definition.strip()) - - # remove comment lines - declaration = [x for x in declaration if not x.startswith('#')] - ptrn = """ - \#\s*(?P.*)$ # comment - """ - p = re.compile(ptrn, re.X) - table_info = p.search(declaration[0]).groupdict() - - #table_info['tier'] = Role[table_info['tier']] # convert into enum - - in_key = True # parse primary keys - attribute_regexp = re.compile(""" - ^[a-z][a-z\d_]*\s* # name - (=\s*\S+(\s+\S+)*\s*)? # optional defaults - :\s*\w.*$ # type, comment - """, re.I + re.X) # ignore case and verbose - - for line in declaration[1:]: - if line.startswith('---'): - in_key = False # start parsing non-PK fields - elif line.startswith('->'): - # foreign key - ref_name = line[2:].strip() - ref_list = parents if in_key else referenced - ref_list.append(eval(ref_name, locals=cls.context)) - elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): - index_defs.append(parse_index(line)) - elif attribute_regexp.match(line): - field_defs.append(parse_attribute_definition(line, in_key)) - else: - raise DataJointError('Invalid table declaration line "%s"' % line) - - return table_info, parents, referenced, field_defs, index_defs - - -def declare(full_table_name, definition, context): - """ - Declares the table in the database if it does not exist already - """ - cur = relation.connection.query( - 'SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format( - database=relation.database, table_name=relation.table_name)) - if cur.rowcount: - return - - if relation.connection.in_transaction: - raise DataJointError("Tables cannot be declared during a transaction.") - - if not relation.definition: - raise DataJointError('Missing table definition.') - - table_info, parents, referenced, field_defs, index_defs = parse_declaration() - - sql = 'CREATE TABLE %s (\n' % cls.full_table_name - - # add inherited primary key fields - primary_key_fields = set() - non_key_fields = set() - for p in parents: - for key in p.primary_key: - field = p.heading[key] - if field.name not in primary_key_fields: - primary_key_fields.add(field.name) - sql += field_to_sql(field) - else: - logger.debug( - 'Field definition of {} in {} ignored'.format(field.name, p.full_class_name)) - - # add newly defined primary key fields - for field in (f for f in field_defs if f.in_key): - if field.nullable: - raise DataJointError('Primary key attribute {} cannot be nullable'.format( - field.name)) - if field.name in primary_key_fields: - raise DataJointError('Duplicate declaration of the primary attribute {key}. ' - 'Ensure that the attribute is not already declared ' - 'in referenced tables'.format(key=field.name)) - primary_key_fields.add(field.name) - sql += field_to_sql(field) - - # add secondary foreign key attributes - for r in referenced: - for key in r.primary_key: - field = r.heading[key] - if field.name not in primary_key_fields | non_key_fields: - non_key_fields.add(field.name) - sql += field_to_sql(field) - - # add dependent attributes - for field in (f for f in field_defs if not f.in_key): - non_key_fields.add(field.name) - sql += field_to_sql(field) - - # add primary key declaration - assert len(primary_key_fields) > 0, 'table must have a primary key' - keys = ', '.join(primary_key_fields) - sql += 'PRIMARY KEY (%s),\n' % keys - - # add foreign key declarations - for ref in parents + referenced: - keys = ', '.join(ref.primary_key) - sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \ - (keys, ref.full_table_name, keys) - - # add secondary index declarations - # gather implicit indexes due to foreign keys first - implicit_indices = [] - for fk_source in parents + referenced: - implicit_indices.append(fk_source.primary_key) - - # for index in indexDefs: - # TODO: add index declaration - - # close the declaration - sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % ( - sql[:-2], table_info['comment']) - - # # make sure that the table does not already exist - # cls.load_heading() - # if not cls.is_declared: - # # execute declaration - # logger.debug('\n\n' + sql + '\n\n') - # cls.connection.query(sql) - # cls.load_heading() - diff --git a/datajoint/heading.py b/datajoint/heading.py index 6596a0673..71ba2b8cf 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -11,11 +11,38 @@ class Heading: Heading contains the property attributes, which is an OrderedDict in which the keys are the attribute names and the values are AttrTuples. """ - AttrTuple = namedtuple('AttrTuple', + + class AttrTuple(namedtuple('AttrTuple', ('name', 'type', 'in_key', 'nullable', 'default', 'comment', 'autoincrement', 'numeric', 'string', 'is_blob', - 'computation', 'dtype')) - AttrTuple.as_dict = AttrTuple._asdict # renaming to make public + 'computation', 'dtype'))): + def _asdict(self): + """ + for some reason the inherted _asdict does not work after subclassing from namedtuple + """ + return OrderedDict((name, self[i]) for i, name in enumerate(self._fields)) + + + def sql(self): + """ + Convert attribute tuple into its SQL CREATE TABLE clause. + :rtype : SQL code + """ + literals = ['CURRENT_TIMESTAMP'] + if self.nullable: + default = 'DEFAULT NULL' + else: + default = 'NOT NULL' + if self.default: + # enclose value in quotes except special SQL values or already enclosed + quote = self.default.upper() not in literals and self.default[0] not in '"\'' + default += ' DEFAULT ' + ('"%s"' if quote else "%s") % self.default + if any((c in r'\"' for c in self.comment)): + raise DataJointError('Illegal characters in attribute comment "%s"' % self.comment) + return '`{name}` {type} {default} COMMENT "{comment}"'.format( + name=self.name, type=self.type, default=default, comment=self.comment) + + def __init__(self, attributes=None): """ @@ -57,12 +84,14 @@ def __getitem__(self, name): return self.attributes[name] def __repr__(self): - autoincrement_string = {False: '', True: ' auto_increment'} - return '\n'.join(['%-20s : %-28s # %s' % ( - k if v.default is None else '%s="%s"' % (k, v.default), - '%s%s' % (v.type, autoincrement_string[v.autoincrement]), - v.comment) - for k, v in self.attributes.items()]) + if self.attributes is None: + return 'Empty heading' + else: + return '\n'.join(['%-20s : %-28s # %s' % ( + k if v.default is None else '%s="%s"' % (k, v.default), + '%s%s' % (v.type, 'auto_increment' if v.autoincrement else ''), + v.comment) + for k, v in self.attributes.items()]) @property def as_dtype(self): @@ -97,11 +126,12 @@ def __iter__(self): def init_from_database(self, conn, database, table_name): """ - initialize heading from a database table + initialize heading from a database table. The table must exist already. """ cur = conn.query( 'SHOW FULL COLUMNS FROM `{table_name}` IN `{database}`'.format( table_name=table_name, database=database), as_dict=True) + attributes = cur.fetchall() rename_map = { @@ -184,14 +214,14 @@ def project(self, *attribute_list, **renamed_attributes): attribute_list = self.primary_key + [a for a in attribute_list if a not in self.primary_key] # convert attribute_list into a list of dicts but exclude renamed attributes - attribute_list = [v.as_dict() for k, v in self.attributes.items() + attribute_list = [v._asdict() for k, v in self.attributes.items() if k in attribute_list and k not in renamed_attributes.values()] # add renamed and computed attributes for new_name, computation in renamed_attributes.items(): if computation in self.names: # renamed attribute - new_attr = self.attributes[computation].as_dict() + new_attr = self.attributes[computation]._asdict() new_attr['name'] = new_name new_attr['computation'] = '`' + computation + '`' else: @@ -218,14 +248,14 @@ def __add__(self, other): join two headings """ assert isinstance(other, Heading) - attribute_list = [v.as_dict() for v in self.attributes.values()] + attribute_list = [v._asdict() for v in self.attributes.values()] for name in other.names: if name not in self.names: - attribute_list.append(other.attributes[name].as_dict()) + attribute_list.append(other.attributes[name]._asdict()) return Heading(attribute_list) def resolve(self): """ Remove attribute computations after they have been resolved in a subquery """ - return Heading([dict(v.as_dict(), computation=None) for v in self.attributes.values()]) \ No newline at end of file + return Heading([dict(v._asdict(), computation=None) for v in self.attributes.values()]) \ No newline at end of file diff --git a/datajoint/relation.py b/datajoint/relation.py index 958e6ed27..8f558831f 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -48,10 +48,10 @@ def decorator(cls): cls._connection = connection cls._heading = Heading() instance = cls() if isinstance(cls, type) else cls - if not cls.heading: - cls.connection.query( + if not instance.heading: + connection.query( declare( - table_name=instance.full_table_name, + full_table_name=instance.full_table_name, definition=instance.definition, context=context)) return cls @@ -92,7 +92,7 @@ def connection(self): @property def heading(self): - if not self._heading: + if not self._heading and self.is_declared: self._heading.init_from_database(self.connection, self.database, self.table_name) return self._heading @@ -113,6 +113,13 @@ def iter_insert(self, rows, **kwargs): self.insert(row, **kwargs) # --------- SQL functionality --------- # + @property + def is_declared(self): + cur = self.connection.query( + 'SHOW TABLES in `{database}`LIKE "{table_name}"'.format( + database=self.database, table_name=self.table_name)) + return cur.rowcount>0 + def batch_insert(self, data, **kwargs): """ Inserts an entire batch of entries. Additional keyword arguments are passed to insert. @@ -193,7 +200,7 @@ def drop(self): def size_on_disk(self): """ - :return: size of data and indices in MiB taken by the table on the storage device + :return: size of data and indices in GiB taken by the table on the storage device """ ret = self.connection.query( 'SHOW TABLE STATUS FROM `{database}` WHERE NAME="{table}"'.format( diff --git a/demos/demo1.py b/demos/demo1.py index 4376a4982..d10dc30e3 100644 --- a/demos/demo1.py +++ b/demos/demo1.py @@ -4,13 +4,12 @@ print("Welcome to the database 'demo1'") -conn = dj.conn() # connect to database; conn must be defined in module namespace -conn.bind(module=__name__, dbname='dj_test') # bind this module to the database +schema = dj.schema('dj_test', locals()) - -class Subject(dj.Relation): +@schema +class Subject(dj.Manual): definition = """ - demo1.Subject (manual) # Basic subject info + # Basic subject info subject_id : int # internal subject id --- real_id : varchar(40) # real-world name @@ -22,9 +21,10 @@ class Subject(dj.Relation): """ -class Experiment(dj.Relation): +@schema +class Experiment(dj.Manual): definition = """ - demo1.Experiment (manual) # Basic subject info + # Basic subject info -> demo1.Subject experiment : smallint # experiment number for this subject --- @@ -35,9 +35,11 @@ class Experiment(dj.Relation): """ -class Session(dj.Relation): +@schema +class Session(dj.Manual): definition = """ - demo1.Session (manual) # a two-photon imaging session + # a two-photon imaging session + -> demo1.Experiment session_id : tinyint # two-photon session within this experiment ----------- @@ -46,9 +48,11 @@ class Session(dj.Relation): """ -class Scan(dj.Relation): +@schema +class Scan(dj.Manual): definition = """ - demo1.Scan (manual) # a two-photon imaging session + # a two-photon imaging session + -> demo1.Session -> Config scan_id : tinyint # two-photon session within this experiment @@ -58,16 +62,20 @@ class Scan(dj.Relation): mwatts: numeric(4,1) # (mW) laser power to brain """ -class Config(dj.Relation): +@schema +class Config(dj.Manual): definition = """ - demo1.Config (manual) # configuration for scanner + # configuration for scanner + config_id : tinyint # unique id for config setup --- ->ConfigParam """ -class ConfigParam(dj.Relation): +@schema +class ConfigParam(dj.Manual): definition = """ - demo1.ConfigParam (lookup) # params for configurations + # params for configurations + param_set_id : tinyint # id for params """ \ No newline at end of file From d0690f29e6b99a4e48b4bbb6d5e1ed72d0c1c8d6 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 12 Jun 2015 16:40:24 -0500 Subject: [PATCH 20/20] fixed primary key declaration --- datajoint/declare.py | 11 ++++++++--- datajoint/heading.py | 3 --- demos/demo1.py | 28 +++++++--------------------- 3 files changed, 15 insertions(+), 27 deletions(-) diff --git a/datajoint/declare.py b/datajoint/declare.py index b3c2968ca..c0c8e44a9 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -2,6 +2,8 @@ import pyparsing as pp import logging +from . import DataJointError + logger = logging.getLogger(__name__) @@ -14,7 +16,7 @@ def declare(full_table_name, definition, context): # split definition into lines definition = re.split(r'\s*\n\s*', definition.strip()) - table_comment = definition.pop(0)[1:] if definition[0].startswith('#') else '' + table_comment = definition.pop(0)[1:].strip() if definition[0].startswith('#') else '' in_key = True # parse primary keys primary_key = [] @@ -35,7 +37,7 @@ def declare(full_table_name, definition, context): 'FOREIGN KEY ({primary_key})' ' REFERENCES {ref} ({primary_key})' ' ON UPDATE CASCADE ON DELETE RESTRICT'.format( - primary_key='`' + '`,`'.join(primary_key) + '`', ref=ref.full_table_name) + primary_key='`' + '`,`'.join(ref.primary_key) + '`', ref=ref.full_table_name) ) for name in ref.primary_key: if in_key and name not in primary_key: @@ -54,8 +56,11 @@ def declare(full_table_name, definition, context): attribute_sql.append(sql) # compile SQL + if not primary_key: + raise DataJointError('Table must have a primary key') sql = 'CREATE TABLE %s (\n ' % full_table_name - sql += ', \n'.join(attribute_sql) + sql += ',\n '.join(attribute_sql) + sql += ',\n PRIMARY KEY (`' + '`,`'.join(primary_key) + '`)' if foreign_key_sql: sql += ', \n' + ', \n'.join(foreign_key_sql) if index_sql: diff --git a/datajoint/heading.py b/datajoint/heading.py index 71ba2b8cf..5c332fff6 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -22,7 +22,6 @@ def _asdict(self): """ return OrderedDict((name, self[i]) for i, name in enumerate(self._fields)) - def sql(self): """ Convert attribute tuple into its SQL CREATE TABLE clause. @@ -42,8 +41,6 @@ def sql(self): return '`{name}` {type} {default} COMMENT "{comment}"'.format( name=self.name, type=self.type, default=default, comment=self.comment) - - def __init__(self, attributes=None): """ :param attributes: a list of dicts with the same keys as AttrTuple diff --git a/demos/demo1.py b/demos/demo1.py index d10dc30e3..46ab53fa6 100644 --- a/demos/demo1.py +++ b/demos/demo1.py @@ -20,12 +20,16 @@ class Subject(dj.Manual): animal_notes="" : varchar(4096) # strain, genetic manipulations, etc """ +s = Subject() +p = s.primary_key + @schema class Experiment(dj.Manual): definition = """ # Basic subject info - -> demo1.Subject + + -> Subject experiment : smallint # experiment number for this subject --- experiment_folder : varchar(255) # folder path @@ -40,7 +44,7 @@ class Session(dj.Manual): definition = """ # a two-photon imaging session - -> demo1.Experiment + -> Experiment session_id : tinyint # two-photon session within this experiment ----------- setup : tinyint # experimental setup @@ -53,8 +57,7 @@ class Scan(dj.Manual): definition = """ # a two-photon imaging session - -> demo1.Session - -> Config + -> Session scan_id : tinyint # two-photon session within this experiment ---- depth : float # depth from surface @@ -62,20 +65,3 @@ class Scan(dj.Manual): mwatts: numeric(4,1) # (mW) laser power to brain """ -@schema -class Config(dj.Manual): - definition = """ - # configuration for scanner - - config_id : tinyint # unique id for config setup - --- - ->ConfigParam - """ - -@schema -class ConfigParam(dj.Manual): - definition = """ - # params for configurations - - param_set_id : tinyint # id for params - """ \ No newline at end of file