From 5fb598974ac670ba66cefd73bf0d704ae3c7748c Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 26 Jun 2015 12:28:48 -0500 Subject: [PATCH 01/34] implemented default populate relation (fixed issue #59) --- datajoint/autopopulate.py | 18 ++++++++++++------ datajoint/erd.py | 18 ++++++++---------- datajoint/relation.py | 32 +++++++++++++++++++++++++------- 3 files changed, 45 insertions(+), 23 deletions(-) diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index d07fd1bf0..d8ec6d297 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -1,5 +1,6 @@ from .relational_operand import RelationalOperand from . import DataJointError, Relation +from .relation import FreeRelation import abc import logging @@ -12,17 +13,22 @@ class AutoPopulate(metaclass=abc.ABCMeta): """ AutoPopulate is a mixin class that adds the method populate() to a Relation class. Auto-populated relations must inherit from both Relation and AutoPopulate, - must define the property pop_rel, and must define the callback method make_tuples. + must define the property populate_relation, and must define the callback method _make_tuples. """ - @abc.abstractproperty + @property def populate_relation(self): """ - Derived classes must implement the read-only property populate_relation, which is the - relational expression that defines how keys are generated for the populate call. - By default, populate relation is the join of the primary dependencies of the table. + :return: the relation whose primary key values are passed, sequentially, to the + _make_tuples method when populate() is called. + The default value is the join of the parent relations. Users may override to change + the granularity or the scope of populate() calls. """ - pass + parents = [FreeRelation(self.target.connection, rel) for rel in self.target.parents] + ret = parents.pop(0) + while parents: + ret *= parents.pop(0) + return ret @abc.abstractmethod def _make_tuples(self, key): diff --git a/datajoint/erd.py b/datajoint/erd.py index 73af00c7c..f224279ec 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -16,10 +16,10 @@ class ERD: - _parents = defaultdict(set) - _children = defaultdict(set) - _references = defaultdict(set) - _referenced = defaultdict(set) + _parents = defaultdict(list) + _children = defaultdict(list) + _references = defaultdict(list) + _referenced = defaultdict(list) def load_dependencies(self, connection, full_table_name, primary_key): # fetch the CREATE TABLE statement @@ -57,11 +57,11 @@ def load_dependencies(self, connection, full_table_name, primary_key): "%s's foreign key refers to differently named attributes in %s" % (self.__class__.__name__, result.referenced_table)) if all(q in primary_key for q in [s.strip('` ') for s in result.attributes.split(',')]): - self._parents[full_table_name].add(result.referenced_table) - self._children[result.referenced_table].add(full_table_name) + self._parents[full_table_name].append(result.referenced_table) + self._children[result.referenced_table].append(full_table_name) else: - self._referenced[full_table_name].add(result.referenced_table) - self._references[result.referenced_table].add(full_table_name) + self._referenced[full_table_name].append(result.referenced_table) + self._references[result.referenced_table].append(full_table_name) @property def parents(self): @@ -80,8 +80,6 @@ def referenced(self): return self._referenced - - def to_camel_case(s): """ Convert names with under score (_) separation diff --git a/datajoint/relation.py 
b/datajoint/relation.py index 3cdf8e919..534e564aa 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -151,19 +151,15 @@ def full_table_name(self): def insert(self, tup, ignore_errors=False, replace=False): """ - Insert one data record or one Mapping (like a dictionary). + Insert one data record or one Mapping (like a dict). - :param tup: Data record, or a Mapping (like a dictionary). + :param tup: Data record, or a Mapping (like a dict). :param ignore_errors=False: Ignores errors if True. :param replace=False: Replaces data tuple if True. Example:: - - b = djtest.Subject() - b.insert(dict(subject_id = 7, species="mouse",\\ - real_id = 1007, date_of_birth = "2014-09-01")) + rel.insert(dict(subject_id = 7, species="mouse", date_of_birth = "2014-09-01")) """ - heading = self.heading if isinstance(tup, np.void): for fieldname in tup.dtype.fields: @@ -287,3 +283,25 @@ def _alter(self, alter_statement): sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement) self.connection.query(sql) self.heading.init_from_database(self.connection, self.database, self.table_name) + + +class FreeRelation(Relation): + """ + A relation with no definition. Its table must already exist in the database. + """ + definition = None + + def __init__(self, connection, full_table_name): + [database, table_name] = full_table_name.split('.') + self.database = database.strip('`') + self._table_name = table_name.strip('`') + self._heading = Heading() + self._connection = connection + + @property + def connection(self): + return self._connection + + @property + def table_name(self): + return self._table_name \ No newline at end of file From 194f5c2f8f4b8f7fc70d9e5fb191ebf83923c28d Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 26 Jun 2015 12:31:41 -0500 Subject: [PATCH 02/34] renamed populate_relation into populated_from --- datajoint/autopopulate.py | 18 +++++++++--------- datajoint/user_relations.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index d8ec6d297..2721d2475 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -13,11 +13,11 @@ class AutoPopulate(metaclass=abc.ABCMeta): """ AutoPopulate is a mixin class that adds the method populate() to a Relation class. Auto-populated relations must inherit from both Relation and AutoPopulate, - must define the property populate_relation, and must define the callback method _make_tuples. + must define the property populated_from, and must define the callback method _make_tuples. """ @property - def populate_relation(self): + def populated_from(self): """ :return: the relation whose primary key values are passed, sequentially, to the _make_tuples method when populate() is called. @@ -45,18 +45,18 @@ def target(self): def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): """ - rel.populate() calls rel._make_tuples(key) for every primary key in self.populate_relation + rel.populate() calls rel._make_tuples(key) for every primary key in self.populated_from for which there is not already a tuple in rel. 
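As a usage sketch of this mechanism (mirroring the test schema that appears later in this series; the dj.schema setup and the Subjects/Trials declarations are abbreviated): the default populated_from of SquaredScore is the join of its parents, Subjects() * Trials(), so populate() visits every subject/trial pair that is not yet in the table:

    import datajoint as dj

    schema = dj.schema('djtest_test1', locals())  # abbreviated; see tests/schemas.py below

    @schema
    class SquaredScore(dj.Computed):
        definition = """
        # cumulative outcome of trials
        -> Subjects
        -> Trials
        ---
        squared: int  # squared result of Trials outcome
        """

        # no populated_from override: it defaults to Subjects() * Trials()
        def _make_tuples(self, key):
            outcome = (Trials() & key).fetch1()['outcome']
            self.insert(dict(key, squared=outcome ** 2))

    SquaredScore().populate()  # one _make_tuples call per missing key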
- :param restriction: restriction on rel.populate_relation - target + :param restriction: restriction on rel.populated_from - target :param suppress_errors: suppresses error if true :param reserve_jobs: currently not implemented """ assert not reserve_jobs, NotImplemented # issue #5 error_list = [] if suppress_errors else None - if not isinstance(self.populate_relation, RelationalOperand): - raise DataJointError('Invalid populate_relation value') + if not isinstance(self.populated_from, RelationalOperand): + raise DataJointError('Invalid populated_from value') self.connection.cancel_transaction() # rollback previous transaction, if any @@ -64,7 +64,7 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): raise DataJointError( 'AutoPopulate is a mixin for Relation and must therefore subclass Relation') - unpopulated = (self.populate_relation - self.target) & restriction + unpopulated = (self.populated_from - self.target) & restriction for key in unpopulated.project(): self.connection.start_transaction() if key in self.target: # already populated @@ -90,7 +90,7 @@ def progress(self): """ report progress of populating this table """ - total = len(self.populate_relation) - remaining = len(self.populate_relation - self.target) + total = len(self.populated_from) + remaining = len(self.populated_from - self.target) print('Remaining %d of %d (%2.1f%%)' % (remaining, total, 100*remaining/total) if remaining else 'Complete', flush=True) diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 1b0352a03..38f149615 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -34,7 +34,7 @@ class Subordinate: Mix-in to make computed tables subordinate """ @property - def populate_relation(self): + def populated_from(self): return None def _make_tuples(self, key): From 0b9add91230a7b930163923d8d01e08673f31a08 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 26 Jun 2015 15:08:18 -0500 Subject: [PATCH 03/34] erd.load_dependencies no longer needs the primary_key argument --- datajoint/__init__.py | 2 +- datajoint/autopopulate.py | 6 +++--- datajoint/erd.py | 35 +++++++++++++++++++++++++---------- datajoint/relation.py | 12 +++++++++--- 4 files changed, 38 insertions(+), 17 deletions(-) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index b562aa8d1..309d443cf 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -5,7 +5,7 @@ __version__ = "0.2" __all__ = ['__author__', '__version__', 'Connection', 'Heading', 'Relation', 'FreeRelation', 'Not', - 'Relation', + 'Relation', 'schema', 'Manual', 'Lookup', 'Imported', 'Computed', 'AutoPopulate', 'conn', 'DataJointError', 'blob'] diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 2721d2475..d222fbf83 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -1,6 +1,6 @@ from .relational_operand import RelationalOperand -from . import DataJointError, Relation -from .relation import FreeRelation +from . import DataJointError +from .relation import Relation, FreeRelation, schema import abc import logging @@ -15,6 +15,7 @@ class AutoPopulate(metaclass=abc.ABCMeta): Auto-populated relations must inherit from both Relation and AutoPopulate, must define the property populated_from, and must define the callback method _make_tuples. 
""" + _jobs = None @property def populated_from(self): @@ -85,7 +86,6 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): logger.info('Done populating.') return error_list - def progress(self): """ report progress of populating this table diff --git a/datajoint/erd.py b/datajoint/erd.py index f224279ec..6e39e05ab 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -21,7 +21,7 @@ class ERD: _references = defaultdict(list) _referenced = defaultdict(list) - def load_dependencies(self, connection, full_table_name, primary_key): + def load_dependencies(self, connection, full_table_name): # fetch the CREATE TABLE statement cur = connection.query('SHOW CREATE TABLE %s' % full_table_name) create_statement = cur.fetchone() @@ -29,29 +29,44 @@ def load_dependencies(self, connection, full_table_name, primary_key): raise DataJointError('Could not load the definition table %s' % full_table_name) create_statement = create_statement[1].split('\n') - # build foreign key parser + # build foreign key fk_parser database = full_table_name.split('.')[0].strip('`') add_database = lambda string, loc, toc: ['`{database}`.`{table}`'.format(database=database, table=toc[0])] - parser = pp.CaselessLiteral('CONSTRAINT').suppress() - parser += pp.QuotedString('`').suppress() - parser += pp.CaselessLiteral('FOREIGN KEY').suppress() - parser += pp.QuotedString('(', endQuoteChar=')').setResultsName('attributes') - parser += pp.CaselessLiteral('REFERENCES') - parser += pp.Or([ + # primary key parser + pk_parser = pp.CaselessLiteral('PRIMARY KEY') + pk_parser += pp.QuotedString('(', endQuoteChar=')').setResultsName('primary_key') + + # foreign key parser + fk_parser = pp.CaselessLiteral('CONSTRAINT').suppress() + fk_parser += pp.QuotedString('`').suppress() + fk_parser += pp.CaselessLiteral('FOREIGN KEY').suppress() + fk_parser += pp.QuotedString('(', endQuoteChar=')').setResultsName('attributes') + fk_parser += pp.CaselessLiteral('REFERENCES') + fk_parser += pp.Or([ pp.QuotedString('`').setParseAction(add_database), pp.Combine(pp.QuotedString('`', unquoteResults=False) + '.' 
+ pp.QuotedString('`', unquoteResults=False)) ]).setResultsName('referenced_table') - parser += pp.QuotedString('(', endQuoteChar=')').setResultsName('referenced_attributes') + fk_parser += pp.QuotedString('(', endQuoteChar=')').setResultsName('referenced_attributes') # parse foreign keys + primary_key = None for line in create_statement: + if primary_key is None: + try: + result = pk_parser.parseString(line) + except pp.ParseException: + pass + else: + primary_key = [s.strip(' `') for s in result.primary_key.split(',')] try: - result = parser.parseString(line) + result = fk_parser.parseString(line) except pp.ParseException: pass else: + if not primary_key: + raise DataJointErrr('No primary key found %s' % full_table_name) if result.referenced_attributes != result.attributes: raise DataJointError( "%s's foreign key refers to differently named attributes in %s" diff --git a/datajoint/relation.py b/datajoint/relation.py index 534e564aa..34692b421 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -53,7 +53,7 @@ def decorator(cls): full_table_name=instance.full_table_name, definition=instance.definition, context=context)) - connection.erd.load_dependencies(connection, instance.full_table_name, instance.primary_key) + connection.erd.load_dependencies(connection, instance.full_table_name) return cls return decorator @@ -192,11 +192,17 @@ def insert(self, tup, ignore_errors=False, replace=False): logger.info(sql) self.connection.query(sql, args=args) - def delete(self): + def delete_quick(self): + """ + delete without cascading and without prompting + """ + self.connection.query('DELETE FROM ' + self.from_clause + self.where_clause) + + def delete(self): # TODO: impelment cascading (issue #15) if not config['safemode'] or user_choice( "You are about to delete data from a table. This operation cannot be undone.\n" "Proceed?", default='no') == 'yes': - self.connection.query('DELETE FROM ' + self.from_clause + self.where_clause) # TODO: make cascading (issue #15) + self.delete_quick() def drop(self): """ From d990b35ec5577970d5d35200aa6003e965ce3df3 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 26 Jun 2015 16:40:28 -0500 Subject: [PATCH 04/34] implemented job reservations for distributed computing (issue #5) --- datajoint/autopopulate.py | 51 +++++++++++----------- datajoint/connection.py | 9 ++-- datajoint/jobs.py | 90 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 121 insertions(+), 29 deletions(-) create mode 100644 datajoint/jobs.py diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index d222fbf83..ca2b11876 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -1,8 +1,9 @@ +import abc +import logging from .relational_operand import RelationalOperand from . import DataJointError from .relation import Relation, FreeRelation, schema -import abc -import logging +from . 
import jobs # noinspection PyExceptionInherit,PyCallingNonCallable @@ -53,37 +54,39 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): :param suppress_errors: suppresses error if true :param reserve_jobs: currently not implemented """ - - assert not reserve_jobs, NotImplemented # issue #5 error_list = [] if suppress_errors else None if not isinstance(self.populated_from, RelationalOperand): raise DataJointError('Invalid populated_from value') - self.connection.cancel_transaction() # rollback previous transaction, if any - - if not isinstance(self, Relation): - raise DataJointError( - 'AutoPopulate is a mixin for Relation and must therefore subclass Relation') + if self.connection.in_transaction: + raise DataJointError('Populate cannot be called during a transaction.') unpopulated = (self.populated_from - self.target) & restriction for key in unpopulated.project(): - self.connection.start_transaction() - if key in self.target: # already populated - self.connection.cancel_transaction() - else: - logger.info('Populating: ' + str(key)) - try: - self._make_tuples(dict(key)) - except Exception as error: + if not reserve_jobs or jobs.reserve(self.target.full_table_name, key): + self.connection.start_transaction() + if key in self.target: # already populated self.connection.cancel_transaction() - if not suppress_errors: - raise - else: - logger.error(error) - error_list.append((key, error)) + if reserve_jobs: + jobs.complete(self.target.full_table_name, key) else: - self.connection.commit_transaction() - logger.info('Done populating.') + logger.info('Populating: ' + str(key)) + try: + self._make_tuples(dict(key)) + except Exception as error: + self.connection.cancel_transaction() + if reserve_jobs: + jobs.log_error(self.target.full_table_name, key, + error_message=str(error)) + if not suppress_errors: + raise + else: + logger.error(error) + error_list.append((key, error)) + else: + self.connection.commit_transaction() + if reserve_jobs: + jobs.complete(self.target.full_table_name, key) return error_list def progress(self): diff --git a/datajoint/connection.py b/datajoint/connection.py index 325d1c021..ac279e8a4 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -14,13 +14,13 @@ def conn_container(): """ _connection = None # persistent connection object used by dj.conn() - def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False): # TODO: thin wrapping layer to mimic singleton + def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False): """ Manage a persistent connection object. This is one of several ways to configure and access a datajoint connection. Users may customize their own connection manager. 
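The intended call pattern, as a sketch (the host and user values are placeholders; any argument left out falls back to dj.config, and a missing password is prompted for interactively):

    import datajoint as dj

    c1 = dj.conn()   # first call opens the connection using dj.config values
    c2 = dj.conn()   # later calls return the same persistent object
    assert c1 is c2

    # rebuild the persistent connection with explicit parameters
    c3 = dj.conn(host='db.example.com', user='alice', reset=True)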
- Set rest=True to reset the persistent connection object + Set reset=True to reset the persistent connection object with new connection parameters """ nonlocal _connection if not _connection or reset: @@ -108,7 +108,7 @@ def query(self, query, args=(), as_dict=False): cur.execute(query, args) return cur - # ---------- transaction processing ------------------ + # ---------- transaction processing @property def in_transaction(self): self._in_transaction = self._in_transaction and self.is_connected @@ -131,8 +131,7 @@ def commit_transaction(self): self._in_transaction = False logger.info("Transaction committed and closed.") - - #-------- context manager for transactions + # -------- context manager for transactions @contextmanager def transaction(self): try: diff --git a/datajoint/jobs.py b/datajoint/jobs.py new file mode 100644 index 000000000..9542bdea5 --- /dev/null +++ b/datajoint/jobs.py @@ -0,0 +1,90 @@ +import hashlib +import os +import pymysql + +from .relation import Relation, schema + + +def get_jobs_table(database): + """ + :return: the base relation of the job reservation table for database + """ + self = get_jobs_table + if not hasattr(self, 'lookup'): + self.lookup = {} + + if database not in self.lookup: + @schema(database, context={}) + class JobsRelation(Relation): + definition = """ + # the job reservation table + table_name : varchar(255) # className of the table + key_hash : char(32) # key hash + --- + status : enum('reserved','error','ignore')# if tuple is missing, the job is available + key=null : blob # structure containing the key + error_message="" : varchar(1023) # error message returned if failed + error_stack=null : blob # error stack if failed + host="" : varchar(255) # system hostname + pid=0 : int unsigned # system process id + timestamp=CURRENT_TIMESTAMP : timestamp # automatic timestamp + """ + + @property + def table_name(self): + return '~jobs' + + self.lookup[database] = JobsRelation() + + return self.lookup[database] + +def split_name(full_table_name): + [database, table_name] = full_table_name.split('.') + return database.strip('` '), table_name.strip('` ') + + +def key_hash(key): + hashed = hashlib.md5() + for k, v in sorted(key.items()): + hashed.update(str(v).encode()) + return hashed.hexdigest() + + +def reserve(full_table_name, key): + """ + Insert a reservation record in the jobs table + :return: True if reserved job successfully + """ + database, table_name = split_name(full_table_name) + jobs = get_jobs_table(database) + job_key = dict(table_name=table_name, key_hash=key_hash(key)) + if jobs & job_key: + return False + try: + jobs.insert(dict(job_key, status="reserved", host=os.uname().nodename, pid=os.getpid())) + except pymysql.err.IntegrityError: + success = False + else: + success = True + return success + + +def complete(full_table_name, key): + """ + upon job completion the job entry is removed + """ + database, table_name = split_name(full_table_name) + job_key = dict(table_name=table_name, key_hash=key_hash(key)) + entry = get_jobs_table(full_table_name) & job_key + entry.delete_quick() + + +def log_error(full_table_name, key, error_message, error_stack=None): + database, table_name = split_name(full_table_name) + job_key = dict(table_name=table_name, key_hash=key_hash(key)) + jobs = get_jobs_table(database) + jobs.insert(dict(job_key, + status="error", + host=os.uname(), + pid=os.getpid(), + error_message=error_message), replace=True) From 13fc228f7ce0ffc81d92b395401a27bc850b466c Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: 
Fri, 26 Jun 2015 17:02:47 -0500 Subject: [PATCH 05/34] simplified dj.conn() --- datajoint/autopopulate.py | 3 +-- datajoint/connection.py | 41 +++++++++++++-------------------------- datajoint/jobs.py | 23 ++++++++++++---------- 3 files changed, 27 insertions(+), 40 deletions(-) diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index ca2b11876..a7c08ab74 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -76,8 +76,7 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): except Exception as error: self.connection.cancel_transaction() if reserve_jobs: - jobs.log_error(self.target.full_table_name, key, - error_message=str(error)) + jobs.error(self.target.full_table_name, key, error_message=str(error)) if not suppress_errors: raise else: diff --git a/datajoint/connection.py b/datajoint/connection.py index ac279e8a4..393b9862c 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -8,36 +8,21 @@ logger = logging.getLogger(__name__) -def conn_container(): +def conn(host=None, user=None, passwd=None, init_fun=None, reset=False): """ - creates a persistent connections for everyone to use + Returns a persistent connection object to be shared by multiple modules. + If the connection is not yet established or reset=True, a new connection is set up. + If connection information is not provided, it is taken from config. """ - _connection = None # persistent connection object used by dj.conn() - - def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False): - """ - Manage a persistent connection object. - This is one of several ways to configure and access a datajoint connection. - Users may customize their own connection manager. - - Set reset=True to reset the persistent connection object with new connection parameters - """ - nonlocal _connection - if not _connection or reset: - host = host if host is not None else config['database.host'] - user = user if user is not None else config['database.user'] - passwd = passwd if passwd is not None else config['database.password'] - - if passwd is None: passwd = input("Please enter database password: ") - - init_fun = init_fun if init_fun is not None else config['connection.init_function'] - _connection = Connection(host, user, passwd, init_fun) - return _connection - - return conn_function - -# The function conn is used by others to obtain a connection object -conn = conn_container() + if not hasattr(conn, 'connection') or reset: + host = host if host is not None else config['database.host'] + user = user if user is not None else config['database.user'] + passwd = passwd if passwd is not None else config['database.password'] + if passwd is None: + passwd = input("Please enter database password: ") + init_fun = init_fun if init_fun is not None else config['connection.init_function'] + conn.connection = Connection(host, user, passwd, init_fun) + return conn.connection class Connection: diff --git a/datajoint/jobs.py b/datajoint/jobs.py index 9542bdea5..1dc0c820c 100644 --- a/datajoint/jobs.py +++ b/datajoint/jobs.py @@ -18,16 +18,16 @@ def get_jobs_table(database): class JobsRelation(Relation): definition = """ # the job reservation table - table_name : varchar(255) # className of the table - key_hash : char(32) # key hash + table_name: varchar(255) # className of the table + key_hash: char(32) # key hash --- - status : enum('reserved','error','ignore')# if tuple is missing, the job is available - key=null : blob # structure containing the key - 
error_message="" : varchar(1023) # error message returned if failed - error_stack=null : blob # error stack if failed - host="" : varchar(255) # system hostname - pid=0 : int unsigned # system process id - timestamp=CURRENT_TIMESTAMP : timestamp # automatic timestamp + status: enum('reserved','error','ignore')# if tuple is missing, the job is available + key=null: blob # structure containing the key + error_message="": varchar(1023) # error message returned if failed + error_stack=null: blob # error stack if failed + host="": varchar(255) # system hostname + pid=0: int unsigned # system process id + timestamp=CURRENT_TIMESTAMP: timestamp # automatic timestamp """ @property @@ -79,7 +79,10 @@ def complete(full_table_name, key): entry.delete_quick() -def log_error(full_table_name, key, error_message, error_stack=None): +def error(full_table_name, key, error_message): + """ + if an error occurs, leave an entry describing the problem + """ database, table_name = split_name(full_table_name) job_key = dict(table_name=table_name, key_hash=key_hash(key)) jobs = get_jobs_table(database) From ad98158b6e2e9a5045d95dd152f7af34541991fb Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 26 Jun 2015 17:14:38 -0500 Subject: [PATCH 06/34] minor bugfix --- datajoint/autopopulate.py | 12 +++++------ datajoint/erd.py | 2 +- datajoint/jobs.py | 45 ++++++++++++++++++++++----------------- 3 files changed, 31 insertions(+), 28 deletions(-) diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index a7c08ab74..bc1c725bc 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -61,22 +61,21 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): if self.connection.in_transaction: raise DataJointError('Populate cannot be called during a transaction.') + full_table_name = self.target.full_table_name unpopulated = (self.populated_from - self.target) & restriction for key in unpopulated.project(): - if not reserve_jobs or jobs.reserve(self.target.full_table_name, key): + if jobs.reserve(reserve_jobs, full_table_name, key): self.connection.start_transaction() if key in self.target: # already populated self.connection.cancel_transaction() - if reserve_jobs: - jobs.complete(self.target.full_table_name, key) + jobs.complete(reserve_jobs, full_table_name, key) else: logger.info('Populating: ' + str(key)) try: self._make_tuples(dict(key)) except Exception as error: self.connection.cancel_transaction() - if reserve_jobs: - jobs.error(self.target.full_table_name, key, error_message=str(error)) + jobs.error(reserve_jobs, full_table_name, key, error_message=str(error)) if not suppress_errors: raise else: @@ -84,8 +83,7 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): error_list.append((key, error)) else: self.connection.commit_transaction() - if reserve_jobs: - jobs.complete(self.target.full_table_name, key) + jobs.complete(reserve_jobs, full_table_name, key) return error_list def progress(self): diff --git a/datajoint/erd.py b/datajoint/erd.py index 6e39e05ab..23e83d426 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -66,7 +66,7 @@ def load_dependencies(self, connection, full_table_name): pass else: if not primary_key: - raise DataJointErrr('No primary key found %s' % full_table_name) + raise DataJointError('No primary key found %s' % full_table_name) if result.referenced_attributes != result.attributes: raise DataJointError( "%s's foreign key refers to differently named attributes in %s" diff --git a/datajoint/jobs.py 
b/datajoint/jobs.py index 1dc0c820c..07e8a8f8c 100644 --- a/datajoint/jobs.py +++ b/datajoint/jobs.py @@ -14,20 +14,21 @@ def get_jobs_table(database): self.lookup = {} if database not in self.lookup: + @schema(database, context={}) class JobsRelation(Relation): definition = """ # the job reservation table - table_name: varchar(255) # className of the table - key_hash: char(32) # key hash + table_name: varchar(255) # className of the table + key_hash: char(32) # key hash --- - status: enum('reserved','error','ignore')# if tuple is missing, the job is available - key=null: blob # structure containing the key + status: enum('reserved','error','ignore') # if tuple is missing, the job is available + key=null: blob # structure containing the key error_message="": varchar(1023) # error message returned if failed error_stack=null: blob # error stack if failed host="": varchar(255) # system hostname pid=0: int unsigned # system process id - timestamp=CURRENT_TIMESTAMP: timestamp # automatic timestamp + timestamp=CURRENT_TIMESTAMP: timestamp # automatic timestamp """ @property @@ -50,11 +51,13 @@ def key_hash(key): return hashed.hexdigest() -def reserve(full_table_name, key): +def reserve(reserve_jobs, full_table_name, key): """ Insert a reservation record in the jobs table :return: True if reserved job successfully """ + if not reserve_jobs: + return True database, table_name = split_name(full_table_name) jobs = get_jobs_table(database) job_key = dict(table_name=table_name, key_hash=key_hash(key)) @@ -69,25 +72,27 @@ def reserve(full_table_name, key): return success -def complete(full_table_name, key): +def complete(reserve_jobs, full_table_name, key): """ upon job completion the job entry is removed """ - database, table_name = split_name(full_table_name) - job_key = dict(table_name=table_name, key_hash=key_hash(key)) - entry = get_jobs_table(full_table_name) & job_key - entry.delete_quick() + if reserve_jobs: + database, table_name = split_name(full_table_name) + job_key = dict(table_name=table_name, key_hash=key_hash(key)) + entry = get_jobs_table(full_table_name) & job_key + entry.delete_quick() -def error(full_table_name, key, error_message): +def error(reserve_jobs, full_table_name, key, error_message): """ if an error occurs, leave an entry describing the problem """ - database, table_name = split_name(full_table_name) - job_key = dict(table_name=table_name, key_hash=key_hash(key)) - jobs = get_jobs_table(database) - jobs.insert(dict(job_key, - status="error", - host=os.uname(), - pid=os.getpid(), - error_message=error_message), replace=True) + if reserve_jobs: + database, table_name = split_name(full_table_name) + job_key = dict(table_name=table_name, key_hash=key_hash(key)) + jobs = get_jobs_table(database) + jobs.insert(dict(job_key, + status="error", + host=os.uname(), + pid=os.getpid(), + error_message=error_message), replace=True) From e5d854442eb19e22a60ec41f751272b6065c8956 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Sun, 28 Jun 2015 22:00:06 -0500 Subject: [PATCH 07/34] implemented recursive drop: Fixed issue #16. 
--- datajoint/erd.py | 18 ++++++++++++++++++ datajoint/jobs.py | 2 +- datajoint/relation.py | 40 +++++++++++++++++++++++++++++++--------- 3 files changed, 50 insertions(+), 10 deletions(-) diff --git a/datajoint/erd.py b/datajoint/erd.py index 23e83d426..d3782b574 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -94,6 +94,24 @@ def references(self): def referenced(self): return self._referenced + def get_descendants(self, full_table_name): + """ + :param full_table_name: a table name in the format `database`.`table_name` + :return: list of all children and references, in order of dependence. + This is helpful for cascading delete or drop operations. + """ + ret = defaultdict(lambda: 0) + + def recurse(full_table_name, level): + if level > ret[full_table_name]: + ret[full_table_name] = level + for child in self.children[full_table_name] + self.references[full_table_name]: + recurse(child, level+1) + + recurse(full_table_name, 0) + return sorted(ret.keys(), key=ret.__getitem__) + + def to_camel_case(s): """ diff --git a/datajoint/jobs.py b/datajoint/jobs.py index 07e8a8f8c..2774b3488 100644 --- a/datajoint/jobs.py +++ b/datajoint/jobs.py @@ -14,7 +14,7 @@ def get_jobs_table(database): self.lookup = {} if database not in self.lookup: - + @schema(database, context={}) class JobsRelation(Relation): definition = """ diff --git a/datajoint/relation.py b/datajoint/relation.py index 34692b421..c81c22f17 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -1,4 +1,5 @@ from collections.abc import Mapping +from collections import defaultdict import numpy as np import logging import abc @@ -128,6 +129,15 @@ def references(self): def referenced(self): return self.connection.erd.referenced[self.full_table_name] + @property + def descendants(self): + """ + :return: list of relation objects for all children and references, recursively, + in order of dependence. + This is helpful for cascading delete or drop operations. + """ + return [FreeRelation(self.connection, table) + for table in self.connection.erd.get_descendants(self.full_table_name)] # --------- SQL functionality --------- # @property @@ -204,18 +214,30 @@ def delete(self): # TODO: impelment cascading (issue #15) "Proceed?", default='no') == 'yes': self.delete_quick() - def drop(self): + def drop_quick(self): """ - Drops the table associated to this class. + Drops the table associated with this relation without cascading and without user prompt. """ if self.is_declared: - if not config['safemode'] or user_choice( - "You are about to drop an entire table. This operation cannot be undone.\n" - "Proceed?", default='no') == 'yes': - self.connection.query('DROP TABLE %s' % self.full_table_name) # TODO: make cascading (issue #16) - # cls.connection.clear_dependencies(dbname=cls.dbname) #TODO: reimplement because clear_dependencies will be gone - # cls.connection.load_headings(dbname=cls.dbname, force=True) #TODO: reimplement because load_headings is gone - logger.info("Dropped table %s" % self.full_table_name) + self.connection.query('DROP TABLE %s' % self.full_table_name) + logger.info("Dropped table %s" % self.full_table_name) + + def drop(self): + """ + Drop the table and all tables that reference it, recursively. + User is prompted before. 
+ """ + do_drop = True + relations = self.descendants + if not config['safemode']: + print('The following tables are about to be dropped:') + for relation in relations: + print(relation.full_table_name, '(%d tuples)' % len(relation)) + do_drop = user_choice("Proceed?", default='no') == 'yes' + if do_drop: + while relations: + relations.pop().drop_quick() + print('Dropped tables..') def size_on_disk(self): """ From 69a28f3b4eda954b606e4865984e7875b4f5530c Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Sun, 28 Jun 2015 22:28:53 -0500 Subject: [PATCH 08/34] group_by --- datajoint/relational_operand.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index 367fd2d1e..31336af90 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -95,6 +95,7 @@ def project(self, *attributes, **renamed_attributes): def aggregate(self, _group, *attributes, **renamed_attributes): """ Relational aggregation operator + :param group: relation whose tuples can be used in aggregation operators :param extensions: :return: a relation representing the aggregation/projection operator result @@ -103,6 +104,17 @@ def aggregate(self, _group, *attributes, **renamed_attributes): raise DataJointError('The second argument must be a relation or None') return Projection(self, _group, *attributes, **renamed_attributes) + def group_by(self, *attributes): + r = self.project(*attributes).fetch() + dtype2 = np.dtype({name:r.dtype.fields[name] for name in attributes}) + r2 = np.unique(np.ndarray(r.shape, dtype2, r, 0, r.strides)) + for nk in r2: + restr = ' and '.join(["%s='%s'" % (fn, str(v)) for fn, v in zip(r2.dtype.names, nk)]) + if len(nk) == 1: + yield nk[0], self & restr + else: + yield nk, self & restr + def __iand__(self, restriction): """ in-place relational restriction or semijoin From ed5dd2a70c63f29d77f0ce4fc0fc2fb3712061de Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Sun, 28 Jun 2015 23:34:24 -0500 Subject: [PATCH 09/34] removed alter methods from Relation in preparation for issue #110. --- datajoint/erd.py | 1 - datajoint/jobs.py | 2 +- datajoint/relation.py | 65 +++---------------------------------------- 3 files changed, 5 insertions(+), 63 deletions(-) diff --git a/datajoint/erd.py b/datajoint/erd.py index d3782b574..17f8f630c 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -112,7 +112,6 @@ def recurse(full_table_name, level): return sorted(ret.keys(), key=ret.__getitem__) - def to_camel_case(s): """ Convert names with under score (_) separation diff --git a/datajoint/jobs.py b/datajoint/jobs.py index 2774b3488..f73cd1063 100644 --- a/datajoint/jobs.py +++ b/datajoint/jobs.py @@ -14,7 +14,6 @@ def get_jobs_table(database): self.lookup = {} if database not in self.lookup: - @schema(database, context={}) class JobsRelation(Relation): definition = """ @@ -39,6 +38,7 @@ def table_name(self): return self.lookup[database] + def split_name(full_table_name): [database, table_name] = full_table_name.split('.') return database.strip('` '), table_name.strip('` ') diff --git a/datajoint/relation.py b/datajoint/relation.py index c81c22f17..76923c8ff 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -136,8 +136,9 @@ def descendants(self): in order of dependence. This is helpful for cascading delete or drop operations. 
""" - return [FreeRelation(self.connection, table) - for table in self.connection.erd.get_descendants(self.full_table_name)] + relations = (FreeRelation(self.connection, table) + for table in self.connection.erd.get_descendants(self.full_table_name)) + return [relation for relation in relations if relation.is_declared] # --------- SQL functionality --------- # @property @@ -249,70 +250,12 @@ def size_on_disk(self): ).fetchone() return (ret['Data_length'] + ret['Index_length'])/1024**2 - def set_table_comment(self, comment): - """ - Update the table comment in the table definition. - :param comment: new comment as string - """ - self._alter('COMMENT="%s"' % comment) - - def add_attribute(self, definition, after=None): - """ - Add a new attribute to the table. A full line from the table definition - is passed in as definition. - - The definition can specify where to place the new attribute. Use after=None - to add the attribute as the first attribute or after='attribute' to place it - after an existing attribute. - - :param definition: table definition - :param after=None: After which attribute of the table the new attribute is inserted. - If None, the attribute is inserted in front. - """ - position = ' FIRST' if after is None else ( - ' AFTER %s' % after if after else '') - sql = compile_attribute(definition)[1] - self._alter('ADD COLUMN %s%s' % (sql, position)) - - def drop_attribute(self, attribute_name): - """ - Drops the attribute attrName from this table. - :param attribute_name: Name of the attribute that is dropped. - """ - if not config['safemode'] or user_choice( - "You are about to drop an attribute from a table." - "This operation cannot be undone.\n" - "Proceed?", default='no') == 'yes': - self._alter('DROP COLUMN `%s`' % attribute_name) - - def alter_attribute(self, attribute_name, definition): - """ - Alter attribute definition - - :param attribute_name: field that is redefined - :param definition: new definition of the field - """ - sql = compile_attribute(definition)[1] - self._alter('CHANGE COLUMN `%s` %s' % (attribute_name, sql)) - def erd(self, subset=None): """ Plot the schema's entity relationship diagram (ERD). """ NotImplemented - def _alter(self, alter_statement): - """ - Execute an ALTER TABLE statement. - :param alter_statement: alter statement - """ - if self.connection.in_transaction: - raise DataJointError("Table definition cannot be altered during a transaction.") - sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement) - self.connection.query(sql) - self.heading.init_from_database(self.connection, self.database, self.table_name) - - class FreeRelation(Relation): """ A relation with no definition. Its table must already exist in the database. 
@@ -332,4 +275,4 @@ def connection(self): @property def table_name(self): - return self._table_name \ No newline at end of file + return self._table_name From 0ef3f8200a5bd72b5aabee0d73343405430c169f Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Mon, 29 Jun 2015 07:50:50 -0500 Subject: [PATCH 10/34] added sorted --- datajoint/relational_operand.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index 31336af90..a24d4f7ef 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -104,10 +104,11 @@ def aggregate(self, _group, *attributes, **renamed_attributes): raise DataJointError('The second argument must be a relation or None') return Projection(self, _group, *attributes, **renamed_attributes) - def group_by(self, *attributes): + def group_by(self, *attributes, sortby=None): r = self.project(*attributes).fetch() dtype2 = np.dtype({name:r.dtype.fields[name] for name in attributes}) r2 = np.unique(np.ndarray(r.shape, dtype2, r, 0, r.strides)) + r2.sort(order=sortby if sortby is not None else attributes) for nk in r2: restr = ' and '.join(["%s='%s'" % (fn, str(v)) for fn, v in zip(r2.dtype.names, nk)]) if len(nk) == 1: From 36e360fa57ff4e5820f9fdd6514f16ae3128af98 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 30 Jun 2015 19:21:14 -0500 Subject: [PATCH 11/34] now clearing dependencies when a table is dropped. --- datajoint/erd.py | 6 ++++++ datajoint/relation.py | 12 ++++++------ tests/__init__.py | 4 ++-- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/datajoint/erd.py b/datajoint/erd.py index 17f8f630c..03877ce54 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -78,6 +78,12 @@ def load_dependencies(self, connection, full_table_name): self._referenced[full_table_name].append(result.referenced_table) self._references[result.referenced_table].append(full_table_name) + def clear_dependency(self, full_table_name): + for ref in self.children.pop(full_table_name): + self.parents.remove(ref) + for ref in self.references.pop(full_table_name): + self.referenced.remove(ref) + @property def parents(self): return self._parents diff --git a/datajoint/relation.py b/datajoint/relation.py index 76923c8ff..e3b0e66ac 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -175,17 +175,16 @@ def insert(self, tup, ignore_errors=False, replace=False): if isinstance(tup, np.void): for fieldname in tup.dtype.fields: if fieldname not in heading: - raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname, )) + raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname)) value_list = ','.join([repr(tup[name]) if not heading[name].is_blob else '%s' for name in heading if name in tup.dtype.fields]) - args = tuple(pack(tup[name]) for name in heading if name in tup.dtype.fields and heading[name].is_blob) attribute_list = '`' + '`,`'.join(q for q in heading if q in tup.dtype.fields) + '`' elif isinstance(tup, Mapping): for fieldname in tup.keys(): if fieldname not in heading: - raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname, )) + raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname)) value_list = ','.join(repr(tup[name]) if not heading[name].is_blob else '%s' for name in heading if name in tup) args = tuple(pack(tup[name]) for name in heading @@ -205,7 +204,7 @@ def insert(self, tup, ignore_errors=False, replace=False): def delete_quick(self): """ - delete without cascading and without prompting + 
delete without cascading and without user prompt """ self.connection.query('DELETE FROM ' + self.from_clause + self.where_clause) @@ -221,16 +220,17 @@ def drop_quick(self): """ if self.is_declared: self.connection.query('DROP TABLE %s' % self.full_table_name) + self.connection.erd.clear_dependency(self.full_table_name) logger.info("Dropped table %s" % self.full_table_name) def drop(self): """ Drop the table and all tables that reference it, recursively. - User is prompted before. + User is prompted for confirmation if config['safemode'] """ do_drop = True relations = self.descendants - if not config['safemode']: + if config['safemode']: print('The following tables are about to be dropped:') for relation in relations: print(relation.full_table_name, '(%d tuples)' % len(relation)) diff --git a/tests/__init__.py b/tests/__init__.py index 611f34bc2..143b718da 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -39,10 +39,10 @@ def teardown(): # start a transaction now cur.execute("SHOW DATABASES LIKE '{}\_%'".format(PREFIX)) dbs = [x[0] for x in cur.fetchall()] - cur.execute('SET FOREIGN_KEY_CHECKS=0') # unset foreign key check while deleting + cur.execute('SET FOREIGN_KEY_CHECKS=0') # unset foreign key check while deleting for db in dbs: cur.execute('DROP DATABASE `{}`'.format(db)) - cur.execute('SET FOREIGN_KEY_CHECKS=1') # set foreign key check back on + cur.execute('SET FOREIGN_KEY_CHECKS=1') # set foreign key check back on cur.execute("COMMIT") def cleanup(): From 7c461cedf04ad299831313d2f7c29d3c462d6ddd Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 30 Jun 2015 19:48:57 -0500 Subject: [PATCH 12/34] fixed nosetests --- datajoint/relation.py | 6 ------ tests/test_blob.py | 7 ------- tests/test_relation.py | 14 ++------------ 3 files changed, 2 insertions(+), 25 deletions(-) diff --git a/datajoint/relation.py b/datajoint/relation.py index e3b0e66ac..2d34e0b90 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -250,12 +250,6 @@ def size_on_disk(self): ).fetchone() return (ret['Data_length'] + ret['Index_length'])/1024**2 - def erd(self, subset=None): - """ - Plot the schema's entity relationship diagram (ERD). - """ - NotImplemented - class FreeRelation(Relation): """ A relation with no definition. Its table must already exist in the database. diff --git a/tests/test_blob.py b/tests/test_blob.py index 322be02cc..8809d74b3 100644 --- a/tests/test_blob.py +++ b/tests/test_blob.py @@ -1,6 +1,3 @@ -from datajoint import DataJointError - -__author__ = 'fabee' import numpy as np from datajoint.blob import pack, unpack from numpy.testing import assert_array_equal, raises @@ -19,10 +16,6 @@ def test_pack(): x = np.int16(np.random.randn(1, 2, 3)) assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") -@raises(DataJointError) -def test_error(): - pack(dict()) - def test_complex(): z = np.random.randn(8, 10) + 1j*np.random.randn(8,10) assert_array_equal(z, unpack(pack(z)), "Arrays do not match!") diff --git a/tests/test_relation.py b/tests/test_relation.py index fd27876ce..9db5ab162 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -1,17 +1,9 @@ -# import random -# import string -# -# __author__ = 'fabee' -# -# from .schemata.schema1 import test1, test4 import random import string -import pymysql from datajoint import DataJointError from .schemata.test1 import Subjects, Animals, Matrix, Trials, SquaredScore, SquaredSubtable, WrongImplementation, \ ErrorGenerator, testschema from . 
import BASE_CONN, CONN_INFO, PREFIX, cleanup -# from datajoint.connection import Connection from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal, \ assert_tuple_equal, assert_dict_equal, raises # from datajoint import DataJointError, TransactionError, AutoPopulate, Relation @@ -20,8 +12,6 @@ import numpy as np import datajoint as dj -# -# def trial_faker(n=10): def iter(): for s in [1, 2]: @@ -62,7 +52,7 @@ def test_table_name_computed(self): assert_true(self.score.table_name.startswith('__')) def test_population_relation_subordinate(self): - assert_true(self.subtable.populate_relation is None) + assert_true(self.subtable.populated_from is None) @raises(NotImplementedError) def test_make_tubles_not_implemented_subordinate(self): @@ -401,7 +391,7 @@ def test_autopopulate_restriction(self): def test_autopopulate_relation_check(self): @testschema class dummy(dj.Computed): - def populate_relation(self): + def populated_from(self): return None def _make_tuples(self, key): From 94a859de627f48ec521e3a637b267fe7e43f92c1 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Wed, 1 Jul 2015 14:59:43 -0500 Subject: [PATCH 13/34] cleaned up tests, made Connection.transaction a property since it will not take any arguments --- datajoint/connection.py | 3 +- datajoint/relation.py | 23 +++- tests/__init__.py | 100 ++++---------- tests/schemas.py | 90 +++++++++++++ tests/schemata/__init__.py | 1 - tests/schemata/schema1/__init__.py | 5 - tests/schemata/schema1/test2.py | 48 ------- tests/schemata/schema1/test3.py | 21 --- tests/schemata/schema1/test4.py | 17 --- tests/schemata/schema2/__init__.py | 2 - tests/schemata/schema2/test1.py | 16 --- tests/schemata/test1.py | 175 ------------------------ tests/test_blob.py | 3 +- tests/test_connection.py | 127 +++++++++--------- tests/test_free_relation.py | 205 ----------------------------- tests/test_heading.py | 1 - tests/test_relation.py | 169 +++++++++--------------- tests/test_settings.py | 1 - tests/test_utils.py | 6 +- 19 files changed, 259 insertions(+), 754 deletions(-) create mode 100644 tests/schemas.py delete mode 100644 tests/schemata/__init__.py delete mode 100644 tests/schemata/schema1/__init__.py delete mode 100644 tests/schemata/schema1/test2.py delete mode 100644 tests/schemata/schema1/test3.py delete mode 100644 tests/schemata/schema1/test4.py delete mode 100644 tests/schemata/schema2/__init__.py delete mode 100644 tests/schemata/schema2/test1.py delete mode 100644 tests/schemata/test1.py delete mode 100644 tests/test_free_relation.py delete mode 100644 tests/test_heading.py diff --git a/datajoint/connection.py b/datajoint/connection.py index 393b9862c..eb2943c02 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -117,8 +117,9 @@ def commit_transaction(self): logger.info("Transaction committed and closed.") # -------- context manager for transactions + @property @contextmanager - def transaction(self): + def transaction(self): # TODO: make a property since it does not take any arguments try: self.start_transaction() yield self diff --git a/datajoint/relation.py b/datajoint/relation.py index 2d34e0b90..b269180c7 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -208,11 +208,24 @@ def delete_quick(self): """ self.connection.query('DELETE FROM ' + self.from_clause + self.where_clause) - def delete(self): # TODO: impelment cascading (issue #15) - if not config['safemode'] or user_choice( - "You are about to delete data from a table. 
This operation cannot be undone.\n"
-                "Proceed?", default='no') == 'yes':
-            self.delete_quick()
+    def delete(self):
+        """
+        Delete the contents of the table and its dependent tables, recursively.
+        User is prompted for confirmation if config['safemode']
+        """
+        relations = self.descendants
+        if self.restrictions and len(relations)>1:
+            raise NotImplementedError('Restricted cascading deletes are not yet implemented')
+        do_delete = True
+        if config['safemode']:
+            print('The contents of the following tables are about to be deleted:')
+            for relation in relations:
+                print(relation.full_table_name, '(%d tuples)' % len(relation))
+            do_delete = user_choice("Proceed?", default='no') == 'yes'
+        if do_delete:
+            with self.connection.transaction:
+                while relations:
+                    relations.pop().delete_quick()

     def drop_quick(self):
         """
diff --git a/tests/__init__.py b/tests/__init__.py
index 143b718da..8bd43ace9 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -5,98 +5,44 @@
 after the test.
 """

-import pymysql
+__author__ = 'Edgar Walker, Fabian Sinz, Dimitri Yatsenko'
+
 import logging
 from os import environ
 import datajoint as dj

-logging.basicConfig(level=logging.DEBUG)
+__all__ = ['__author__', 'PREFIX', 'CONN_INFO']

-# Connection information for testing
-CONN_INFO = {
-    'host': environ.get('DJ_TEST_HOST', 'localhost'),
-    'user': environ.get('DJ_TEST_USER', 'datajoint'),
-    'passwd': environ.get('DJ_TEST_PASSWORD', 'datajoint')
-}
+logging.basicConfig(level=logging.DEBUG)

-conn = dj.conn(**CONN_INFO)
+# Connection for testing
+CONN_INFO = dict(
+    host=environ.get('DJ_TEST_HOST', 'localhost'),
+    user=environ.get('DJ_TEST_USER', 'datajoint'),
+    passwd=environ.get('DJ_TEST_PASSWORD', 'datajoint'))

 # Prefix for all databases used during testing
 PREFIX = environ.get('DJ_TEST_DB_PREFIX', 'djtest')

-# Bare connection used for verification of query results
-BASE_CONN = pymysql.connect(**CONN_INFO)
-BASE_CONN.autocommit(True)

-def setup():
-    cleanup()
+def setup_package():
+    """
+    Package-level unit test setup
+    :return:
+    """
     dj.config['safemode'] = False

-def teardown():
-    cur = BASE_CONN.cursor()
-    # cancel any unfinished transactions
-    cur.execute("ROLLBACK")
-    # start a transaction now
-    cur.execute("SHOW DATABASES LIKE '{}\_%'".format(PREFIX))
-    dbs = [x[0] for x in cur.fetchall()]
-    cur.execute('SET FOREIGN_KEY_CHECKS=0')  # unset foreign key check while deleting
-    for db in dbs:
-        cur.execute('DROP DATABASE `{}`'.format(db))
-    cur.execute('SET FOREIGN_KEY_CHECKS=1')  # set foreign key check back on
-    cur.execute("COMMIT")

-def cleanup():
+def teardown_package():
     """
-    Removes all databases with name starting with the prefix.
+    Package-level unit test teardown.
+    Removes all databases with name starting with PREFIX.
To deal with possible foreign key constraints, it will unset and then later reset FOREIGN_KEY_CHECKS flag """ - - cur = BASE_CONN.cursor() - # cancel any unfinished transactions - cur.execute("ROLLBACK") - # start a transaction now - cur.execute("START TRANSACTION WITH CONSISTENT SNAPSHOT") - cur.execute("SHOW DATABASES LIKE '{}\_%'".format(PREFIX)) - dbs = [x[0] for x in cur.fetchall()] - cur.execute('SET FOREIGN_KEY_CHECKS=0') # unset foreign key check while deleting - for db in dbs: - cur.execute("USE %s" % (db)) - cur.execute("SHOW TABLES") - for table in [x[0] for x in cur.fetchall()]: - cur.execute('DELETE FROM `{}`'.format(table)) - cur.execute('SET FOREIGN_KEY_CHECKS=1') # set foreign key check back on - cur.execute("COMMIT") -# -# def setup_sample_db(): -# """ -# Helper method to setup databases with tables to be used -# during the test -# """ -# cur = BASE_CONN.cursor() -# cur.execute("CREATE DATABASE IF NOT EXISTS `{}_test1`".format(PREFIX)) -# cur.execute("CREATE DATABASE IF NOT EXISTS `{}_test2`".format(PREFIX)) -# query1 = """ -# CREATE TABLE `{prefix}_test1`.`subjects` -# ( -# subject_id SMALLINT COMMENT 'Unique subject ID', -# subject_name VARCHAR(255) COMMENT 'Subject name', -# subject_email VARCHAR(255) COMMENT 'Subject email address', -# PRIMARY KEY (subject_id) -# ) -# """.format(prefix=PREFIX) -# cur.execute(query1) -# query2 = """ -# CREATE TABLE `{prefix}_test2`.`experimenter` -# ( -# experimenter_id SMALLINT COMMENT 'Unique experimenter ID', -# experimenter_name VARCHAR(255) COMMENT 'Experimenter name', -# PRIMARY KEY (experimenter_id) -# )""".format(prefix=PREFIX) -# cur.execute(query2) -# -# -# -# -# -# + conn = dj.conn(**CONN_INFO) + conn.query('SET FOREIGN_KEY_CHECKS=0') + cur = conn.query('SHOW DATABASES LIKE "{}\_%%"'.format(PREFIX)) + for db in cur.fetchone(): + conn.query('DROP DATABASE `{}`'.format(db)) + conn.query('SET FOREIGN_KEY_CHECKS=1') diff --git a/tests/schemas.py b/tests/schemas.py new file mode 100644 index 000000000..378ec3d24 --- /dev/null +++ b/tests/schemas.py @@ -0,0 +1,90 @@ +""" +Test schema definition +""" + +import datajoint as dj +from . 
import PREFIX, CONN_INFO + +schema = dj.schema(PREFIX + '_test1', locals(), connection=dj.conn(**CONN_INFO)) + +@schema +class Subjects(dj.Manual): + definition = """ + #Basic subject + subject_id : int # unique subject id + --- + real_id : varchar(40) # real-world name + species = "mouse" : enum('mouse', 'monkey', 'human') # species + """ + +@schema +class Animals(dj.Manual): + definition = """ + # information of non-human subjects + -> Subjects + --- + animal_dob :date # date of birth + """ + +@schema +class Trials(dj.Manual): + definition = """ + # info about trials + -> Subjects + trial_id: int + --- + outcome: int # result of experiment + notes="": varchar(4096) # other comments + trial_ts=CURRENT_TIMESTAMP: timestamp # automatic + """ + +@schema +class Matrix(dj.Manual): + definition = """ + # Some numpy array + matrix_id: int # unique matrix id + --- + data: longblob # data + comment: varchar(1000) # comment + """ + +@schema +class SquaredScore(dj.Computed): + definition = """ + # cumulative outcome of trials + -> Subjects + -> Trials + --- + squared: int # squared result of Trials outcome + """ + + def _make_tuples(self, key): + outcome = (Trials() & key).fetch1()['outcome'] + self.insert(dict(key, squared=outcome**2)) + ss = SquaredSubtable() + for i in range(10): + ss.insert(dict(key, dummy=i)) + + +@schema +class ErrorGenerator(dj.Computed): + definition = """ + # ignore + -> Subjects + -> Trials + --- + dummy: int # ignore + """ + + def _make_tuples(self, key): + raise Exception("This is for testing") + + +@schema +class SquaredSubtable(dj.Subordinate, dj.Manual): + definition = """ + # cumulative outcome of trials + -> SquaredScore + dummy: int # dummy primary attribute + --- + """ \ No newline at end of file diff --git a/tests/schemata/__init__.py b/tests/schemata/__init__.py deleted file mode 100644 index 9fa8b9ad1..000000000 --- a/tests/schemata/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__author__ = "eywalker, fabiansinz" \ No newline at end of file diff --git a/tests/schemata/schema1/__init__.py b/tests/schemata/schema1/__init__.py deleted file mode 100644 index cae90cec9..000000000 --- a/tests/schemata/schema1/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# __author__ = 'eywalker' -# import datajoint as dj -# -# print(__name__) -# from .test3 import * \ No newline at end of file diff --git a/tests/schemata/schema1/test2.py b/tests/schemata/schema1/test2.py deleted file mode 100644 index aded2e4fb..000000000 --- a/tests/schemata/schema1/test2.py +++ /dev/null @@ -1,48 +0,0 @@ -# """ -# Test 2 Schema definition -# """ -# __author__ = 'eywalker' -# -# import datajoint as dj -# from . 
import test1 as alias -# #from ..schema2 import test2 as test1 -# -# -# # references to another schema -# class Experiments(dj.Relation): -# definition = """ -# test2.Experiments (manual) # Basic subject info -# -> test1.Subjects -# experiment_id : int # unique experiment id -# --- -# real_id : varchar(40) # real-world name -# species = "mouse" : enum('mouse', 'monkey', 'human') # species -# """ -# -# -# # references to another schema -# class Conditions(dj.Relation): -# definition = """ -# test2.Conditions (manual) # Subject conditions -# -> alias.Subjects -# condition_name : varchar(255) # description of the condition -# """ -# -# -# class FoodPreference(dj.Relation): -# definition = """ -# test2.FoodPreference (manual) # Food preference of each subject -# -> animals.Subjects -# preferred_food : enum('banana', 'apple', 'oranges') -# """ -# -# -# class Session(dj.Relation): -# definition = """ -# test2.Session (manual) # Experiment sessions -# -> test1.Subjects -# -> test2.Experimenter -# session_id : int # unique session id -# --- -# session_comment : varchar(255) # comment about the session -# """ \ No newline at end of file diff --git a/tests/schemata/schema1/test3.py b/tests/schemata/schema1/test3.py deleted file mode 100644 index e00a01afb..000000000 --- a/tests/schemata/schema1/test3.py +++ /dev/null @@ -1,21 +0,0 @@ -# """ -# Test 3 Schema definition - no binding, no conn -# -# To be bound at the package level -# """ -# __author__ = 'eywalker' -# -# import datajoint as dj -# -# -# class Subjects(dj.Relation): -# definition = """ -# schema1.Subjects (manual) # Basic subject info -# -# subject_id : int # unique subject id -# dob : date # date of birth -# --- -# real_id : varchar(40) # real-world name -# species = "mouse" : enum('mouse', 'monkey', 'human') # species -# """ -# diff --git a/tests/schemata/schema1/test4.py b/tests/schemata/schema1/test4.py deleted file mode 100644 index 9860cb030..000000000 --- a/tests/schemata/schema1/test4.py +++ /dev/null @@ -1,17 +0,0 @@ -# """ -# Test 1 Schema definition - fully bound and has connection object -# """ -# __author__ = 'fabee' -# -# import datajoint as dj -# -# -# class Matrix(dj.Relation): -# definition = """ -# test4.Matrix (manual) # Some numpy array -# -# matrix_id : int # unique matrix id -# --- -# data : longblob # data -# comment : varchar(1000) # comment -# """ diff --git a/tests/schemata/schema2/__init__.py b/tests/schemata/schema2/__init__.py deleted file mode 100644 index d79e02cc3..000000000 --- a/tests/schemata/schema2/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# __author__ = 'eywalker' -# from .test1 import * \ No newline at end of file diff --git a/tests/schemata/schema2/test1.py b/tests/schemata/schema2/test1.py deleted file mode 100644 index 4005aa670..000000000 --- a/tests/schemata/schema2/test1.py +++ /dev/null @@ -1,16 +0,0 @@ -# """ -# Test 2 Schema definition -# """ -# __author__ = 'eywalker' -# -# import datajoint as dj -# -# -# class Subjects(dj.Relation): -# definition = """ -# schema2.Subjects (manual) # Basic subject info -# pop_id : int # unique experiment id -# --- -# real_id : varchar(40) # real-world name -# species = "mouse" : enum('mouse', 'monkey', 'human') # species -# """ \ No newline at end of file diff --git a/tests/schemata/test1.py b/tests/schemata/test1.py deleted file mode 100644 index a9a03be64..000000000 --- a/tests/schemata/test1.py +++ /dev/null @@ -1,175 +0,0 @@ -""" -Test 1 Schema definition -""" -__author__ = 'eywalker' - -import datajoint as dj -# from .. import schema2 -from .. 
import PREFIX - -testschema = dj.schema(PREFIX + '_test1', locals()) - -@testschema -class Subjects(dj.Manual): - definition = """ - #Basic subject info - - subject_id : int # unique subject id - --- - real_id : varchar(40) # real-world name - species = "mouse" : enum('mouse', 'monkey', 'human') # species - """ - -# test for shorthand -@testschema -class Animals(dj.Manual): - definition = """ - # Listing of all info - - -> Subjects - --- - animal_dob :date # date of birth - """ - -@testschema -class Trials(dj.Manual): - definition = """ - # info about trials - - -> Subjects - trial_id : int - --- - outcome : int # result of experiment - - notes="" : varchar(4096) # other comments - trial_ts=CURRENT_TIMESTAMP : timestamp # automatic - """ - -@testschema -class Matrix(dj.Manual): - definition = """ - # Some numpy array - - matrix_id : int # unique matrix id - --- - data : longblob # data - comment : varchar(1000) # comment - """ - - -@testschema -class SquaredScore(dj.Computed): - definition = """ - # cumulative outcome of trials - - -> Subjects - -> Trials - --- - squared : int # squared result of Trials outcome - """ - - @property - def populate_relation(self): - return Subjects() * Trials() - - def _make_tuples(self, key): - tmp = (Trials() & key).fetch1() - tmp2 = SquaredSubtable() & key - - self.insert(dict(key, squared=tmp['outcome']**2)) - - ss = SquaredSubtable() - - for i in range(10): - key['dummy'] = i - ss.insert(key) - -@testschema -class WrongImplementation(dj.Computed): - definition = """ - # ignore - - -> Subjects - -> Trials - --- - dummy : int # ignore - """ - - @property - def populate_relation(self): - return {'subject_id':2} - - def _make_tuples(self, key): - pass - -class ErrorGenerator(dj.Computed): - definition = """ - # ignore - - -> Subjects - -> Trials - --- - dummy : int # ignore - """ - - @property - def populate_relation(self): - return Subjects() * Trials() - - def _make_tuples(self, key): - raise Exception("This is for testing") - -@testschema -class SquaredSubtable(dj.Subordinate, dj.Manual): - definition = """ - # cumulative outcome of trials - - -> SquaredScore - dummy : int # dummy primary attribute - --- - """ -# -# -# # test reference to another table in same schema -# class Experiments(dj.Relation): -# definition = """ -# test1.Experiments (imported) # Experiment info -# -> test1.Subjects -# exp_id : int # unique id for experiment -# --- -# exp_data_file : varchar(255) # data file -# """ -# -# -# # refers to a table in dj_test2 (bound to test2) but without a class -# class Sessions(dj.Relation): -# definition = """ -# test1.Sessions (manual) # Experiment sessions -# -> test1.Subjects -# -> test2.Experimenter -# session_id : int # unique session id -# --- -# session_comment : varchar(255) # comment about the session -# """ -# -# -# class Match(dj.Relation): -# definition = """ -# test1.Match (manual) # Match between subject and color -# -> schema2.Subjects -# --- -# dob : date # date of birth -# """ -# -# -# # this tries to reference a table in database directly without ORM -# class TrainingSession(dj.Relation): -# definition = """ -# test1.TrainingSession (manual) # training sessions -# -> `dj_test2`.Experimenter -# session_id : int # training session id -# """ -# -# -# class Empty(dj.Relation): -# pass diff --git a/tests/test_blob.py b/tests/test_blob.py index 8809d74b3..2265a1192 100644 --- a/tests/test_blob.py +++ b/tests/test_blob.py @@ -16,11 +16,12 @@ def test_pack(): x = np.int16(np.random.randn(1, 2, 3)) assert_array_equal(x, unpack(pack(x)), 
"Arrays do not match!") + def test_complex(): z = np.random.randn(8, 10) + 1j*np.random.randn(8,10) assert_array_equal(z, unpack(pack(z)), "Arrays do not match!") - z = np.random.randn(10)+ 1j*np.random.randn(10) + z = np.random.randn(10) + 1j*np.random.randn(10) assert_array_equal(z, unpack(pack(z)), "Arrays do not match!") x = np.float32(np.random.randn(3, 4, 5)) + 1j*np.float32(np.random.randn(3, 4, 5)) diff --git a/tests/test_connection.py b/tests/test_connection.py index b94a07adf..656373621 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,17 +1,12 @@ """ Collection of test cases to test connection module. """ -from tests.schemata.test1 import Subjects -__author__ = 'eywalker, fabee' -from . import CONN_INFO, PREFIX, BASE_CONN, cleanup from nose.tools import assert_true, assert_raises, assert_equal, raises import datajoint as dj -from datajoint import DataJointError import numpy as np - -def setup(): - cleanup() +from datajoint import DataJointError +from . import CONN_INFO, PREFIX def test_dj_conn(): @@ -24,88 +19,86 @@ def test_dj_conn(): def test_persistent_dj_conn(): """ - conn() method should provide persistent connection - across calls. + conn() method should provide persistent connection across calls. + Setting reset=True should create a new persistent connection. """ c1 = dj.conn(**CONN_INFO) c2 = dj.conn() + c3 = dj.conn(**CONN_INFO) + c4 = dj.conn(reset=True, **CONN_INFO) + c5 = dj.conn(**CONN_INFO) assert_true(c1 is c2) - - -def test_dj_conn_reset(): - """ - Passing in reset=True should allow for new persistent - connection to be created. - """ - c1 = dj.conn(**CONN_INFO) - c2 = dj.conn(reset=True, **CONN_INFO) - assert_true(c1 is not c2) + assert_true(c1 is c3) + assert_true(c1 is not c4) + assert_true(c4 is c5) def test_repr(): c1 = dj.conn(**CONN_INFO) - assert_true('disconnected' not in c1.__repr__() and 'connected' in c1.__repr__()) - -def test_del(): - c1 = dj.conn(**CONN_INFO) - assert_true('disconnected' not in c1.__repr__() and 'connected' in c1.__repr__()) - del c1 - + assert_true('disconnected' not in repr(c1) and 'connected' in repr(c1)) -class TestContextManager(object): - def __init__(self): - self.relvar = None - self.setup() - +class TestTransactions: """ - Test cases for FreeRelation objects + test transaction management """ - def setup(self): - """ - Create a connection object and prepare test modules - as follows: - test1 - has conn and bounded + schema = dj.schema(PREFIX + '_transactions', locals(), connection=dj.conn(**CONN_INFO)) + + @schema + class Subjects(dj.Manual): + definition = """ + #Basic subject + subject_id : int # unique subject id + --- + real_id : varchar(40) # real-world name + species = "mouse" : enum('mouse', 'monkey', 'human') # species """ - cleanup() # drop all databases with PREFIX - self.conn = dj.conn() - self.relvar = Subjects() + + def __init__(self): + self.relation = self.Subjects() + self.conn = dj.conn(**CONN_INFO) def teardown(self): - cleanup() + self.relation.delete_quick() def test_active(self): - with self.conn.transaction() as conn: + with self.conn.transaction as conn: assert_true(conn.in_transaction, "Transaction is not active") - def test_rollback(self): - - tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], - dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) + def test_transaction_rollback(self): + """Test transaction cancellation using a with statement""" + tmp = np.array([ + (1, 'Peter', 'mouse'), + (2, 'Klara', 'monkey') + ], self.relation.heading.as_dtype) - 
self.relvar.insert(tmp[0]) + self.relation.delete() + with self.conn.transaction: + self.relation.insert(tmp[0]) try: - with self.conn.transaction(): - self.relvar.insert(tmp[1]) - raise DataJointError("Just to test") - except DataJointError as e: + with self.conn.transaction: + self.relation.insert(tmp[1]) + raise DataJointError("Testing rollback") + except DataJointError: pass - testt2 = (self.relvar & 'subject_id = 2').fetch() - assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") + assert_equal(len(self.relation), 1, + "Length is not 1. Expected because rollback should have happened.") + assert_equal(len(self.relation & 'subject_id = 2'), 0, + "Length is not 0. Expected because rollback should have happened.") def test_cancel(self): - """Tests cancelling a transaction""" - tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], - dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) - - self.relvar.insert(tmp[0]) - with self.conn.transaction() as conn: - self.relvar.insert(tmp[1]) - conn.cancel_transaction() - - testt2 = (self.relvar & 'subject_id = 2').fetch() - assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") - - - + """Tests cancelling a transaction explicitly""" + tmp = np.array([ + (1, 'Peter', 'mouse'), + (2, 'Klara', 'monkey') + ], self.relation.heading.as_dtype) + self.relation.delete_quick() + self.relation.insert(tmp[0]) + self.conn.start_transaction() + self.relation.insert(tmp[1]) + self.conn.cancel_transaction() + assert_equal(len(self.relation), 1, + "Length is not 1. Expected because rollback should have happened.") + assert_equal(len(self.relation & 'subject_id = 2'), 0, + "Length is not 0. Expected because rollback should have happened.") diff --git a/tests/test_free_relation.py b/tests/test_free_relation.py deleted file mode 100644 index 285bf7487..000000000 --- a/tests/test_free_relation.py +++ /dev/null @@ -1,205 +0,0 @@ -# """ -# Collection of test cases for base module. Tests functionalities such as -# creating tables using docstring table declarations -# """ -# from .schemata import schema1, schema2 -# from .schemata.schema1 import test1, test2, test3 -# -# -# __author__ = 'eywalker' -# -# from . 
import BASE_CONN, CONN_INFO, PREFIX, cleanup, setup_sample_db -# from datajoint.connection import Connection -# from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, raises -# from datajoint import DataJointError -# -# -# def setup(): -# """ -# Setup connections and bindings -# """ -# pass -# -# -# class TestRelationInstantiations(object): -# """ -# Test cases for instantiating Relation objects -# """ -# def __init__(self): -# self.conn = None -# -# def setup(self): -# """ -# Create a connection object and prepare test modules -# as follows: -# test1 - has conn and bounded -# """ -# self.conn = Connection(**CONN_INFO) -# cleanup() # drop all databases with PREFIX -# #test1.conn = self.conn -# #self.conn.bind(test1.__name__, PREFIX+'_test1') -# -# #test2.conn = self.conn -# -# #test3.__dict__.pop('conn', None) # make sure conn is not defined in test3 -# test1.__dict__.pop('conn', None) -# schema1.__dict__.pop('conn', None) # make sure conn is not defined at schema level -# -# -# def teardown(self): -# cleanup() -# -# -# def test_instantiation_from_unbound_module_should_fail(self): -# """ -# Attempting to instantiate a Relation derivative from a module with -# connection defined but not bound to a database should raise error -# """ -# test1.conn = self.conn -# with assert_raises(DataJointError) as e: -# test1.Subjects() -# assert_regexp_matches(e.exception.args[0], r".*not bound.*") -# -# def test_instantiation_from_module_without_conn_should_fail(self): -# """ -# Attempting to instantiate a Relation derivative from a module that lacks -# `conn` object should raise error -# """ -# with assert_raises(DataJointError) as e: -# test1.Subjects() -# assert_regexp_matches(e.exception.args[0], r".*define.*conn.*") -# -# def test_instantiation_of_base_derivatives(self): -# """ -# Test instantiation and initialization of objects derived from -# Relation class -# """ -# test1.conn = self.conn -# self.conn.bind(test1.__name__, PREFIX + '_test1') -# s = test1.Subjects() -# assert_equal(s.dbname, PREFIX + '_test1') -# assert_equal(s.conn, self.conn) -# assert_equal(s.definition, test1.Subjects.definition) -# -# def test_packagelevel_binding(self): -# schema2.conn = self.conn -# self.conn.bind(schema2.__name__, PREFIX + '_test1') -# s = schema2.test1.Subjects() -# -# -# class TestRelationDeclaration(object): -# """ -# Test declaration (creation of table) from -# definition in Relation under various circumstances -# """ -# -# def setup(self): -# cleanup() -# -# self.conn = Connection(**CONN_INFO) -# test1.conn = self.conn -# self.conn.bind(test1.__name__, PREFIX + '_test1') -# test2.conn = self.conn -# self.conn.bind(test2.__name__, PREFIX + '_test2') -# -# def test_is_declared(self): -# """ -# The table should not be created immediately after instantiation, -# but should be created when declare method is called -# :return: -# """ -# s = test1.Subjects() -# assert_false(s.is_declared) -# s.declare() -# assert_true(s.is_declared) -# -# def test_calling_heading_should_trigger_declaration(self): -# s = test1.Subjects() -# assert_false(s.is_declared) -# a = s.heading -# assert_true(s.is_declared) -# -# def test_foreign_key_ref_in_same_schema(self): -# s = test1.Experiments() -# assert_true('subject_id' in s.heading.primary_key) -# -# def test_foreign_key_ref_in_another_schema(self): -# s = test2.Experiments() -# assert_true('subject_id' in s.heading.primary_key) -# -# def test_aliased_module_name_should_resolve(self): -# """ -# Module names that were aliased in 
the definition should -# be properly resolved. -# """ -# s = test2.Conditions() -# assert_true('subject_id' in s.heading.primary_key) -# -# def test_reference_to_unknown_module_in_definition_should_fail(self): -# """ -# Module names in table definition that is not aliased via import -# results in error -# """ -# s = test2.FoodPreference() -# with assert_raises(DataJointError) as e: -# s.declare() -# -# -# class TestRelationWithExistingTables(object): -# """ -# Test base derivatives behaviors when some of the tables -# already exists in the database -# """ -# def setup(self): -# cleanup() -# self.conn = Connection(**CONN_INFO) -# setup_sample_db() -# test1.conn = self.conn -# self.conn.bind(test1.__name__, PREFIX + '_test1') -# test2.conn = self.conn -# self.conn.bind(test2.__name__, PREFIX + '_test2') -# self.conn.load_headings(force=True) -# -# schema2.conn = self.conn -# self.conn.bind(schema2.__name__, PREFIX + '_package') -# -# def teardown(selfself): -# schema1.__dict__.pop('conn', None) -# cleanup() -# -# def test_detection_of_existing_table(self): -# """ -# The Relation instance should be able to detect if the -# corresponding table already exists in the database -# """ -# s = test1.Subjects() -# assert_true(s.is_declared) -# -# def test_definition_referring_to_existing_table_without_class(self): -# s1 = test1.Sessions() -# assert_true('experimenter_id' in s1.primary_key) -# -# s2 = test2.Session() -# assert_true('experimenter_id' in s2.primary_key) -# -# def test_reference_to_package_level_table(self): -# s = test1.Match() -# s.declare() -# assert_true('pop_id' in s.primary_key) -# -# def test_direct_reference_to_existing_table_should_fail(self): -# """ -# When deriving from Relation, definition should not contain direct reference -# to a database name -# """ -# s = test1.TrainingSession() -# with assert_raises(DataJointError): -# s.declare() -# -# @raises(TypeError) -# def test_instantiation_of_base_derivative_without_definition_should_fail(): -# test1.Empty() -# -# -# -# diff --git a/tests/test_heading.py b/tests/test_heading.py deleted file mode 100644 index 25d6f77b7..000000000 --- a/tests/test_heading.py +++ /dev/null @@ -1 +0,0 @@ -__author__ = 'eywalker' diff --git a/tests/test_relation.py b/tests/test_relation.py index 9db5ab162..c830ee8a2 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -1,55 +1,49 @@ import random import string -from datajoint import DataJointError -from .schemata.test1 import Subjects, Animals, Matrix, Trials, SquaredScore, SquaredSubtable, WrongImplementation, \ - ErrorGenerator, testschema -from . 
import BASE_CONN, CONN_INFO, PREFIX, cleanup -from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal, \ - assert_tuple_equal, assert_dict_equal, raises -# from datajoint import DataJointError, TransactionError, AutoPopulate, Relation -import numpy as np from numpy.testing import assert_array_equal import numpy as np -import datajoint as dj - -def trial_faker(n=10): - def iter(): - for s in [1, 2]: - for i in range(n): - yield dict(trial_id=i, subject_id=s, outcome=int(np.random.randint(10)), notes='no comment') +from nose.tools import assert_raises, assert_equal, assert_regexp_matches, \ + assert_false, assert_true, assert_list_equal, \ + assert_tuple_equal, assert_dict_equal, raises - return iter() +from datajoint import DataJointError +from .schemas import Subjects, Animals, Matrix, Trials, SquaredScore, SquaredSubtable, \ + ErrorGenerator, schema +import datajoint as dj -class TestTableObject(object): - def __init__(self): - self.subjects = None - self.setup() +class TestRelation: """ - Test cases for FreeRelation objects + Test creation of relations, their dependencies """ - def setup(self): + def __init__(self): """ Create a connection object and prepare test modules as follows: test1 - has conn and bounded """ - cleanup() # delete everything from all tables of databases with PREFIX self.subjects = Subjects() self.animals = Animals() - self.relvar_blob = Matrix() + self.matrix = Matrix() self.trials = Trials() self.score = SquaredScore() self.subtable = SquaredSubtable() + def setup(self): + pass + + def teardown(self): + self.subjects.delete() + def test_table_name_manual(self): assert_true(not self.subjects.table_name.startswith('#') and not self.subjects.table_name.startswith('_') and not self.subjects.table_name.startswith('__')) def test_table_name_computed(self): assert_true(self.score.table_name.startswith('__')) + assert_true(self.subtable.table_name.startswith('__')) def test_population_relation_subordinate(self): assert_true(self.subtable.populated_from is None) @@ -61,16 +55,22 @@ def test_make_tubles_not_implemented_subordinate(self): def test_instantiate_relation(self): s = Subjects() - def teardown(self): - cleanup() - def test_compound_restriction(self): s = self.subjects t = self.trials s.insert(dict(subject_id=1, real_id='M')) s.insert(dict(subject_id=2, real_id='F')) - t.iter_insert(trial_faker(20)) + + # insert trials + n_trials = 20 + for subject_id in [1, 2]: + for trial_id in range(n_trials): + t.insert( + trial_id=trial_id, + subject_id=subject_id, + outcome=int(np.random.randint(10)), + notes='no comment') tM = t & (s & "real_id = 'M'") t1 = t & "subject_id = 1" @@ -196,60 +196,11 @@ def test_iter_insert(self): def test_blob_insert(self): x = np.random.randn(10) t = {'matrix_id': 0, 'data': x, 'comment': 'this is a random image'} - self.relvar_blob.insert(t) - x2 = self.relvar_blob.fetch()[0][1] + self.matrix.insert(t) + x2 = self.matrix.fetch()[0][1] assert_array_equal(x, x2, 'inserted blob does not match') -# -# class TestUnboundTables(object): -# """ -# Test usages of FreeRelation objects not connected to a module. 
-# """ -# def setup(self): -# cleanup() -# self.conn = Connection(**CONN_INFO) -# -# def test_creation_from_definition(self): -# definition = """ -# `dj_free`.Animals (manual) # my animal table -# animal_id : int # unique id for the animal -# --- -# animal_name : varchar(128) # name of the animal -# """ -# table = FreeRelation(self.conn, 'dj_free', 'Animals', definition) -# table.declare() -# assert_true('animal_id' in table.primary_key) -# -# def test_reference_to_non_existant_table_should_fail(self): -# definition = """ -# `dj_free`.Recordings (manual) # recordings -# -> `dj_free`.Animals -# rec_session_id : int # recording session identifier -# """ -# table = FreeRelation(self.conn, 'dj_free', 'Recordings', definition) -# assert_raises(DataJointError, table.declare) -# -# def test_reference_to_existing_table(self): -# definition1 = """ -# `dj_free`.Animals (manual) # my animal table -# animal_id : int # unique id for the animal -# --- -# animal_name : varchar(128) # name of the animal -# """ -# table1 = FreeRelation(self.conn, 'dj_free', 'Animals', definition1) -# table1.declare() -# -# definition2 = """ -# `dj_free`.Recordings (manual) # recordings -# -> `dj_free`.Animals -# rec_session_id : int # recording session identifier -# """ -# table2 = FreeRelation(self.conn, 'dj_free', 'Recordings', definition2) -# table2.declare() -# assert_true('animal_id' in table2.primary_key) -# -# def id_generator(size=6, chars=string.ascii_uppercase + string.digits): return ''.join(random.choice(chars) for _ in range(size)) @@ -269,33 +220,33 @@ def setup(self): as follows: test1 - has conn and bounded """ - cleanup() # drop all databases with PREFIX - self.relvar_blob = Matrix() + self.matrix = Matrix() def teardown(self): - cleanup() + pass # # def test_blob_iteration(self): - "Tests the basic call of the iterator" - + """Tests the basic call of the iterator""" dicts = [] for i in range(10): c = id_generator() - t = {'matrix_id': i, 'data': np.random.randn(4, 4, 4), 'comment': c} - self.relvar_blob.insert(t) + self.matrix.insert(t) dicts.append(t) - for t, t2 in zip(dicts, self.relvar_blob): + for t, t2 in zip(dicts, self.matrix): assert_true(isinstance(t2, dict), 'iterator does not return dict') - assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved tuples do not match') - assert_equal(t['comment'], t2['comment'], 'inserted and retrieved tuples do not match') - assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved tuples do not match') + assert_equal(t['matrix_id'], t2['matrix_id'], + 'inserted and retrieved tuples do not match') + assert_equal(t['comment'], t2['comment'], + 'inserted and retrieved tuples do not match') + assert_true(np.all(t['data'] == t2['data']), + 'inserted and retrieved tuples do not match') def test_fetch(self): dicts = [] @@ -305,10 +256,10 @@ def test_fetch(self): t = {'matrix_id': i, 'data': np.random.randn(4, 4, 4), 'comment': c} - self.relvar_blob.insert(t) + self.matrix.insert(t) dicts.append(t) - tuples2 = self.relvar_blob.fetch() + tuples2 = self.matrix.fetch() assert_true(isinstance(tuples2, np.ndarray), "Return value of fetch does not have proper type.") assert_true(isinstance(tuples2[0], np.void), "Return value of fetch does not have proper type.") for t, t2 in zip(dicts, tuples2): @@ -324,16 +275,21 @@ def test_fetch_dicts(self): t = {'matrix_id': i, 'data': np.random.randn(4, 4, 4), 'comment': c} - self.relvar_blob.insert(t) + self.matrix.insert(t) dicts.append(t) - tuples2 = self.relvar_blob.fetch(as_dict=True) - 
assert_true(isinstance(tuples2, list), "Return value of fetch with as_dict=True does not have proper type.") - assert_true(isinstance(tuples2[0], dict), "Return value of fetch with as_dict=True does not have proper type.") + tuples2 = self.matrix.fetch(as_dict=True) + assert_true(isinstance(tuples2, list), + "Return value of fetch with as_dict=True does not have proper type.") + assert_true(isinstance(tuples2[0], dict), + "Return value of fetch with as_dict=True does not have proper type.") for t, t2 in zip(dicts, tuples2): - assert_equal(t['matrix_id'], t2['matrix_id'], 'inserted and retrieved dicts do not match') - assert_equal(t['comment'], t2['comment'], 'inserted and retrieved dicts do not match') - assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved dicts do not match') + assert_equal(t['matrix_id'], t2['matrix_id'], + 'inserted and retrieved dicts do not match') + assert_equal(t['comment'], t2['comment'], + 'inserted and retrieved dicts do not match') + assert_true(np.all(t['data'] == t2['data']), + 'inserted and retrieved dicts do not match') # @@ -352,16 +308,17 @@ def setup(self): as follows: test1 - has conn and bounded """ - cleanup() # drop all databases with PREFIX - self.subjects = Subjects() self.trials = Trials() self.squared = SquaredScore() self.dummy = SquaredSubtable() - self.dummy1 = WrongImplementation() self.error_generator = ErrorGenerator() self.fill_relation() + def teardown(self): + self.error_generator.delete_quick() + + def fill_relation(self): tmp = np.array([('Klara', 2, 'monkey'), ('Peter', 3, 'mouse')], dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) @@ -371,7 +328,7 @@ def fill_relation(self): self.trials.insert(dict(subject_id=2, trial_id=trial_id, outcome=np.random.randint(0, 10))) - def teardown(self): - cleanup() def test_autopopulate(self): self.squared.populate() @@ -389,15 +346,15 @@ def test_autopopulate_restriction(self): @raises(DataJointError) def test_autopopulate_relation_check(self): - @testschema - class dummy(dj.Computed): + @schema + class Dummy(dj.Computed): def populated_from(self): return None def _make_tuples(self, key): pass - du = dummy() + du = Dummy() du.populate() @raises(DataJointError) diff --git a/tests/test_settings.py b/tests/test_settings.py index d10323b10..2e9f18328 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -67,7 +67,6 @@ def test_save(): os.rename(tmpfile, settings.LOCALCONFIG) def test_load_save(): - filename_old = dj.settings.LOCALCONFIG filename = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(50)) + '.json' dj.settings.LOCALCONFIG = filename diff --git a/tests/test_utils.py b/tests/test_utils.py index 6c70150b2..2973c8d81 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,14 +1,10 @@ """ Collection of test cases to test core module. """ -from datajoint.user_relations import from_camel_case - -__author__ = 'eywalker' from nose.tools import assert_true, assert_raises, assert_equal -# from datajoint.utils import to_camel_case, from_camel_case +from datajoint.user_relations import from_camel_case from datajoint import DataJointError - def setup(): pass From 6d0f473066074d514343aab1050e339e66783b62 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Wed, 1 Jul 2015 16:56:00 -0500 Subject: [PATCH 14/34] tables are now declared when relation.heading is requested. 
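This makes table declaration lazy: the @schema decorator no longer issues the CREATE TABLE statement itself; it only binds the class to its database and stores the declaration context, while the heading property declares the table on first access. A minimal sketch of the resulting behavior (the database name and table definition below are illustrative only, not part of this patch):

    import datajoint as dj

    schema = dj.schema('example_db', locals())

    @schema
    class Subject(dj.Manual):
        definition = """
        # example subject table
        subject_id : int  # unique subject id
        ---
        real_id : varchar(40)  # real-world name
        """

    # the decorator requests Subject().heading, which declares the
    # table in the database if it does not yet exist
    assert Subject().is_declared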
--- datajoint/autopopulate.py | 2 ++ datajoint/erd.py | 8 ++++- datajoint/heading.py | 3 ++ datajoint/relation.py | 50 +++++++++++++++++++------------ tests/__init__.py | 4 +-- tests/{schemas.py => schema.py} | 0 tests/test_declare.py | 52 +++++++++++++++++++++++++++++++++ tests/test_relation.py | 2 +- 8 files changed, 99 insertions(+), 22 deletions(-) rename tests/{schemas.py => schema.py} (100%) create mode 100644 tests/test_declare.py diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index bc1c725bc..c3f08bd6e 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -27,6 +27,8 @@ def populated_from(self): the granularity or the scope of populate() calls. """ parents = [FreeRelation(self.target.connection, rel) for rel in self.target.parents] + if not parents: + raise DataJointError('A relation must have parent relations to be able to be populated') ret = parents.pop(0) while parents: ret *= parents.pop(0) diff --git a/datajoint/erd.py b/datajoint/erd.py index 03877ce54..3452648d4 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -16,12 +16,17 @@ class ERD: + _checked_dependencies = set() _parents = defaultdict(list) _children = defaultdict(list) _references = defaultdict(list) _referenced = defaultdict(list) def load_dependencies(self, connection, full_table_name): + # check if already loaded. Use clear_dependencies before reloading + if full_table_name in self._checked_dependencies: + return + self._checked_dependencies.add(full_table_name) # fetch the CREATE TABLE statement cur = connection.query('SHOW CREATE TABLE %s' % full_table_name) create_statement = cur.fetchone() @@ -78,7 +83,8 @@ def load_dependencies(self, connection, full_table_name): self._referenced[full_table_name].append(result.referenced_table) self._references[result.referenced_table].append(full_table_name) - def clear_dependency(self, full_table_name): + def clear_dependencies(self, full_table_name): + self._checked_dependencies.remove(full_table_name) for ref in self.children.pop(full_table_name): self.parents.remove(ref) for ref in self.references.pop(full_table_name): diff --git a/datajoint/heading.py b/datajoint/heading.py index 5c332fff6..b84935d0e 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -49,6 +49,9 @@ def __init__(self, attributes=None): attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes]) self.attributes = attributes + def reset(self): + self.attributes = None + def __bool__(self): return self.attributes is not None diff --git a/datajoint/relation.py b/datajoint/relation.py index b269180c7..c6b49fbe8 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -1,12 +1,11 @@ from collections.abc import Mapping -from collections import defaultdict import numpy as np import logging import abc import pymysql from . 
import DataJointError, config, conn -from .declare import declare, compile_attribute +from .declare import declare from .relational_operand import RelationalOperand from .blob import pack from .utils import user_choice @@ -42,19 +41,16 @@ def schema(database, context, connection=None): def decorator(cls): """ - The decorator declares the table and binds the class to the database table + The decorator binds its argument class object to a database + :param cls: class to be decorated """ + # class-level attributes cls.database = database cls._connection = connection cls._heading = Heading() - instance = cls() if isinstance(cls, type) else cls - if not instance.heading: - connection.query( - declare( - full_table_name=instance.full_table_name, - definition=instance.definition, - context=context)) - connection.erd.load_dependencies(connection, instance.full_table_name) + cls._context = context + # trigger table declaration by requesting the heading from an instance + cls().heading return cls return decorator @@ -67,6 +63,8 @@ class Relation(RelationalOperand, metaclass=abc.ABCMeta): table name, database, context, and definition. A Relation implements insert and delete methods in addition to inherited relational operators. """ + _heading = None + _context = None # ---------- abstract properties ------------ # @property @@ -92,8 +90,20 @@ def connection(self): @property def heading(self): - if not self._heading and self.is_declared: - self._heading.init_from_database(self.connection, self.database, self.table_name) + """ + Get the table heading. + If the table is not declared, attempts to declare it and returns the heading. + :return: the table heading + """ + if self._heading is None: + self._heading = Heading() # instance-level heading + if not self._heading: + if not self.is_declared: + self.connection.query( + declare(self.full_table_name, self.definition, self._context)) + if self.is_declared: + self.connection.erd.load_dependencies(self.connection, self.full_table_name) + self._heading.init_from_database(self.connection, self.database, self.table_name) return self._heading @property @@ -233,7 +243,8 @@ def drop_quick(self): """ if self.is_declared: self.connection.query('DROP TABLE %s' % self.full_table_name) - self.connection.erd.clear_dependency(self.full_table_name) + self.connection.erd.clear_dependencies(self.full_table_name) + self.heading.reset() logger.info("Dropped table %s" % self.full_table_name) def drop(self): @@ -267,14 +278,17 @@ class FreeRelation(Relation): """ A relation with no definition. Its table must already exist in the database. 
""" - definition = None - - def __init__(self, connection, full_table_name): + def __init__(self, connection, full_table_name, definition=None, context=None): [database, table_name] = full_table_name.split('.') self.database = database.strip('`') self._table_name = table_name.strip('`') - self._heading = Heading() self._connection = connection + self._definition = definition + self._context = context + + @property + def definition(self): + return self._definition @property def connection(self): diff --git a/tests/__init__.py b/tests/__init__.py index 8bd43ace9..2bf44d4cf 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -43,6 +43,6 @@ def teardown_package(): conn = dj.conn(**CONN_INFO) conn.query('SET FOREIGN_KEY_CHECKS=0') cur = conn.query('SHOW DATABASES LIKE "{}\_%%"'.format(PREFIX)) - for db in cur.fetchone(): - conn.query('DROP DATABASE `{}`'.format(db)) + for db in cur.fetchall(): + conn.query('DROP DATABASE `{}`'.format(db[0])) conn.query('SET FOREIGN_KEY_CHECKS=1') diff --git a/tests/schemas.py b/tests/schema.py similarity index 100% rename from tests/schemas.py rename to tests/schema.py diff --git a/tests/test_declare.py b/tests/test_declare.py new file mode 100644 index 000000000..b7af31bcd --- /dev/null +++ b/tests/test_declare.py @@ -0,0 +1,52 @@ +import datajoint as dj +from . import PREFIX, CONN_INFO + + + +class TestDeclare: + """ + Tests declaration, heading, dependencies, and drop + """ + + schema = dj.schema(PREFIX + '_test1', locals(), connection=dj.conn(**CONN_INFO)) + + @schema + class Subjects(dj.Manual): + definition = """ + #Basic subject + subject_id : int # unique subject id + --- + real_id : varchar(40) # real-world name + species = "mouse" : enum('mouse', 'monkey', 'human') # species + """ + + @schema + class Experiments(dj.Manual): + definition = """ + # information of non-human subjects + -> Subjects + --- + animal_dob :date # date of birth + """ + + @schema + class Trials(dj.Manual): + definition = """ + # info about trials + -> Subjects + trial_id: int + --- + outcome: int # result of experiment + notes="": varchar(4096) # other comments + trial_ts=CURRENT_TIMESTAMP: timestamp # automatic + """ + + @schema + class Matrix(dj.Manual): + definition = """ + # Some numpy array + matrix_id: int # unique matrix id + --- + data: longblob # data + comment: varchar(1000) # comment + """ \ No newline at end of file diff --git a/tests/test_relation.py b/tests/test_relation.py index c830ee8a2..bdc26076d 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -7,7 +7,7 @@ assert_tuple_equal, assert_dict_equal, raises from datajoint import DataJointError -from .schemas import Subjects, Animals, Matrix, Trials, SquaredScore, SquaredSubtable, \ +from .schema import Subjects, Animals, Matrix, Trials, SquaredScore, SquaredSubtable, \ ErrorGenerator, schema import datajoint as dj From 96bcc6b53904864b1fca9e1a39aa512a62a05fad Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Wed, 1 Jul 2015 17:34:27 -0500 Subject: [PATCH 15/34] fixed clearing and reloading of dependencies when tables are dropped. 
--- datajoint/__init__.py | 1 + datajoint/erd.py | 21 ++++++++++++--------- datajoint/relation.py | 3 ++- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 309d443cf..de3d720fa 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -4,6 +4,7 @@ __author__ = "Dimitri Yatsenko, Edgar Walker, and Fabian Sinz at Baylor College of Medicine" __version__ = "0.2" __all__ = ['__author__', '__version__', + 'config', 'Connection', 'Heading', 'Relation', 'FreeRelation', 'Not', 'Relation', 'schema', 'Manual', 'Lookup', 'Imported', 'Computed', diff --git a/datajoint/erd.py b/datajoint/erd.py index 3452648d4..a90f4b017 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -17,16 +17,18 @@ class ERD: _checked_dependencies = set() - _parents = defaultdict(list) + _parents = dict() + _referenced = dict() _children = defaultdict(list) _references = defaultdict(list) - _referenced = defaultdict(list) def load_dependencies(self, connection, full_table_name): # check if already loaded. Use clear_dependencies before reloading - if full_table_name in self._checked_dependencies: + if full_table_name in self._parents: return - self._checked_dependencies.add(full_table_name) + self._parents[full_table_name] = list() + self._referenced[full_table_name] = list() + # fetch the CREATE TABLE statement cur = connection.query('SHOW CREATE TABLE %s' % full_table_name) create_statement = cur.fetchone() @@ -84,11 +86,12 @@ def load_dependencies(self, connection, full_table_name): self._references[result.referenced_table].append(full_table_name) def clear_dependencies(self, full_table_name): - self._checked_dependencies.remove(full_table_name) - for ref in self.children.pop(full_table_name): - self.parents.remove(ref) - for ref in self.references.pop(full_table_name): - self.referenced.remove(ref) + for ref in self._parents.pop(full_table_name, []): + if full_table_name in self._children[ref]: + self._children[ref].remove(full_table_name) + for ref in self._referenced.pop(full_table_name, []): + if full_table_name in self._references[ref]: + self._references[ref].remove(full_table_name) @property def parents(self): diff --git a/datajoint/relation.py b/datajoint/relation.py index c6b49fbe8..e87fad25a 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -244,7 +244,8 @@ def drop_quick(self): if self.is_declared: self.connection.query('DROP TABLE %s' % self.full_table_name) self.connection.erd.clear_dependencies(self.full_table_name) - self.heading.reset() + if self._heading: + self._heading.reset() logger.info("Dropped table %s" % self.full_table_name) def drop(self): From 0cdf45617966a9331afa47dddf91cd9b513a9d07 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Sun, 5 Jul 2015 11:24:53 -0500 Subject: [PATCH 16/34] more documentation --- datajoint/__init__.py | 10 +++++++ datajoint/autopopulate.py | 7 +++-- datajoint/blob.py | 15 ++++++++-- datajoint/connection.py | 52 +++++++++++++++++++++++++++++++---- datajoint/declare.py | 9 ++++++ datajoint/jobs.py | 1 + datajoint/user_relations.py | 45 ++++++++++++++++++++++++++++++ datajoint/utils.py | 2 +- doc/source/autopopulate.rst | 5 ++++ doc/source/blob.rst | 5 ++++ doc/source/declare.rst | 6 ++++ doc/source/index.rst | 5 ++-- doc/source/tiers.rst | 9 ++++++ doc/source/user_relations.rst | 5 ++++ 14 files changed, 162 insertions(+), 14 deletions(-) create mode 100644 doc/source/autopopulate.rst create mode 100644 doc/source/blob.rst create mode 100644 doc/source/declare.rst create mode 100644 
doc/source/tiers.rst create mode 100644 doc/source/user_relations.rst diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 309d443cf..7bb14b8a8 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -1,3 +1,13 @@ +""" +DataJoint for Python is a high-level programming interface for MySQL databases +to support data processing chains in science labs. DataJoint is built on the +foundation of the relational data model and prescribes a consistent method for +organizing, populating, and querying data. + +DataJoint is free software under the LGPL License. In addition, we request +that any use of DataJoint leading to a publication be acknowledged in the publication. +""" + import logging import os diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index bc1c725bc..abdfa1b4e 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -1,3 +1,4 @@ +"""The autopopulate module contains the dj.AutoPopulate class. See `dj.AutoPopulate` for more info.""" import abc import logging from .relational_operand import RelationalOperand @@ -22,9 +23,9 @@ def populated_from(self): """ :return: the relation whose primary key values are passed, sequentially, to the - _make_tuples method when populate() is called. - The default value is the join of the parent relations. Users may override to change - the granularity or the scope of populate() calls. + `_make_tuples` method when populate() is called. The default value is the + join of the parent relations. Users may override to change the granularity + or the scope of populate() calls. """ parents = [FreeRelation(self.target.connection, rel) for rel in self.target.parents] ret = parents.pop(0) while parents: ret *= parents.pop(0) diff --git a/datajoint/blob.py b/datajoint/blob.py index 7988e82d5..e002be24e 100644 --- a/datajoint/blob.py +++ b/datajoint/blob.py @@ -1,3 +1,7 @@ +""" +Provides serialization methods for numpy.ndarrays that ensure compatibility with Matlab. +""" + import zlib from collections import OrderedDict import numpy as np @@ -29,7 +33,10 @@ def pack(obj): """ - packs an object into a blob to be compatible with mym.mex + Packs an object into a blob to be compatible with mym.mex + + :param obj: object to be packed + :type obj: numpy.ndarray """ if not isinstance(obj, np.ndarray): raise DataJointError("Only numpy arrays can be saved in blobs") @@ -58,7 +65,11 @@ def pack(obj): def unpack(blob): """ - unpack blob into a numpy array + Unpacks blob data into a numpy array. + + :param blob: MySQL blob + :returns: unpacked data + :rtype: numpy.ndarray """ # decompress if necessary if blob[0:5] == b'ZL123': diff --git a/datajoint/connection.py b/datajoint/connection.py index 393b9862c..25db62048 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -1,3 +1,8 @@ +""" +This module hosts the Connection class that manages the connection to the MySQL database via +`pymysql`, and the `conn` function that provides access to a persistent connection in DataJoint. + +""" from contextlib import contextmanager import pymysql from . import DataJointError @@ -12,7 +17,15 @@ def conn(host=None, user=None, passwd=None, init_fun=None, reset=False): """ Returns a persistent connection object to be shared by multiple modules. If the connection is not yet established or reset=True, a new connection is set up. - If connection information is not provided, it is taken from config. + If connection information is not provided, it is taken from config, which takes the + information from dj_local_conf.json. 
If the password is not specified in that file, + DataJoint prompts for the password. + + :param host: hostname + :param user: mysql user + :param passwd: mysql password + :param init_fun: initialization function + :param reset: whether the connection should be reset or not """ if not hasattr(conn, 'connection') or reset: host = host if host is not None else config['database.host'] @@ -74,16 +87,15 @@ def __repr__(self): return "DataJoint connection ({connected}) {user}@{host}:{port}".format( connected=connected, **self.conn_info) - def __del__(self): - logger.info('Disconnecting {user}@{host}:{port}'.format(**self.conn_info)) - self._conn.close() def query(self, query, args=(), as_dict=False): """ Execute the specified query and return the tuple generator. - If as_dict is set to True, the returned cursor objects returns - query results as dictionary. + :param query: MySQL query + :param args: additional arguments for the pymysql.cursor + :param as_dict: if set to True, the returned cursor returns + query results as dictionaries. """ cursor = pymysql.cursors.DictCursor if as_dict else pymysql.cursors.Cursor cur = self._conn.cursor(cursor=cursor) @@ -96,10 +108,18 @@ def query(self, query, args=(), as_dict=False): # ---------- transaction processing @property def in_transaction(self): + """ + :return: True if there is an open transaction. + """ self._in_transaction = self._in_transaction and self.is_connected return self._in_transaction def start_transaction(self): + """ + Starts a transaction. + + :raise DataJointError: if there is an ongoing transaction. + """ if self.in_transaction: raise DataJointError("Nested connections are not supported.") self.query('START TRANSACTION WITH CONSISTENT SNAPSHOT') @@ -107,11 +127,19 @@ def start_transaction(self): def cancel_transaction(self): + """ + Cancels the current transaction and rolls back all changes made during the transaction. + + """ self.query('ROLLBACK') self._in_transaction = False logger.info("Transaction cancelled. Rolling back ...") def commit_transaction(self): + """ + Commits all changes made during the transaction and closes it. + + """ self.query('COMMIT') self._in_transaction = False logger.info("Transaction committed and closed.") @@ -119,6 +147,18 @@ def commit_transaction(self): # -------- context manager for transactions @contextmanager def transaction(self): + """ + Context manager for transactions. Opens a transaction and closes it after the with statement. + If an error is raised inside the transaction, all changes are rolled back automatically and the + error is re-raised. + + Example: + >>> import datajoint as dj + >>> with dj.conn().transaction as conn: + ...     pass  # transaction is open here + + + """ try: self.start_transaction() yield self diff --git a/datajoint/declare.py b/datajoint/declare.py index 18679150a..eceb70b48 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -1,3 +1,7 @@ +""" +This module hosts functions to convert DataJoint table definitions into MySQL table definitions, and to +declare the corresponding MySQL tables. +""" import re import pyparsing as pp import logging @@ -12,6 +16,10 @@ def declare(full_table_name, definition, context): """ Parse declaration and create new SQL table accordingly. + + :param full_table_name: full name of the table + :param definition: DataJoint table definition + :param context: dictionary of objects that might be referred to in the table. 
Usually this will be locals() """ # split definition into lines definition = re.split(r'\s*\n\s*', definition.strip()) @@ -72,6 +80,7 @@ def declare(full_table_name, definition, context): def compile_attribute(line, in_key=False): """ Convert attribute definition from DataJoint format to SQL + :param line: attribute definition line :param in_key: set to True if attribute is in primary key set :returns: (name, sql) -- attribute name and sql code for its declaration diff --git a/datajoint/jobs.py b/datajoint/jobs.py index f73cd1063..65344a4a4 100644 --- a/datajoint/jobs.py +++ b/datajoint/jobs.py @@ -7,6 +7,7 @@ def get_jobs_table(database): """ + :return: the base relation of the job reservation table for the given database """ self = get_jobs_table diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 38f149615..4f9fdcbdc 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -1,3 +1,7 @@ +""" +Hosts the table tiers that user relations should be derived from. +""" + import re import abc from datajoint.relation import Relation @@ -6,26 +10,54 @@ class Manual(Relation): + """ + Inherit from this class if the table's values are entered manually. + """ @property def table_name(self): + """ + :returns: the table name formatted for MySQL. + """ return from_camel_case(self.__class__.__name__) class Lookup(Relation): + """ + Inherit from this class if the table's values are for lookup. This is + currently equivalent to defining the table as Manual and serves semantic + purposes only. + """ @property def table_name(self): + """ + :returns: the table name formatted for MySQL. + """ return '#' + from_camel_case(self.__class__.__name__) class Imported(Relation, AutoPopulate): + """ + Inherit from this class if the table's values are imported from external data sources. + The inherited class must at least provide the function `_make_tuples`. + """ @property def table_name(self): + """ + :returns: the table name formatted for MySQL. + """ return "_" + from_camel_case(self.__class__.__name__) class Computed(Relation, AutoPopulate): + """ + Inherit from this class if the table's values are computed from other relations in the schema. + The inherited class must at least provide the function `_make_tuples`. + """ @property def table_name(self): + """ + :returns: the table name formatted for MySQL. + """ return "__" + from_camel_case(self.__class__.__name__) @@ -35,9 +67,22 @@ class Subordinate: """ @property def populated_from(self): + """ + Overrides the `populated_from` property because subtables should not be populated + directly. + + :return: None + """ return None def _make_tuples(self, key): + """ + Overrides the `_make_tuples` method because subtables should not be populated + directly. Raises an error if this method is called (usually from populate of the + inheriting object). + + :raises: NotImplementedError + """ raise NotImplementedError('Subtables should not be populated directly.') diff --git a/datajoint/utils.py b/datajoint/utils.py index f4b0edb57..e1e3d2fda 100644 --- a/datajoint/utils.py +++ b/datajoint/utils.py @@ -4,7 +4,7 @@ def user_choice(prompt, choices=("yes", "no"), default=None): Prompts the user for confirmation. The default value, if any, is capitalized. :param prompt: Information to display to the user. :param choices: an iterable of possible choices. 
- :param default=None: default choice + :param default: default choice :return: the user's choice """ choice_list = ', '.join((choice.title() if choice == default else choice for choice in choices)) diff --git a/doc/source/autopopulate.rst b/doc/source/autopopulate.rst new file mode 100644 index 000000000..75d5a632f --- /dev/null +++ b/doc/source/autopopulate.rst @@ -0,0 +1,5 @@ +AutoPopulate +============ + +.. automodule:: datajoint.autopopulate + :members: diff --git a/doc/source/blob.rst b/doc/source/blob.rst new file mode 100644 index 000000000..bea29b965 --- /dev/null +++ b/doc/source/blob.rst @@ -0,0 +1,5 @@ +Serialization Module +==================== + +.. automodule:: datajoint.blob + :members: diff --git a/doc/source/declare.rst b/doc/source/declare.rst new file mode 100644 index 000000000..3817225d9 --- /dev/null +++ b/doc/source/declare.rst @@ -0,0 +1,6 @@ +MySQL table declaration +======================= + +.. automodule:: datajoint.declare + :members: + :inherited-members: diff --git a/doc/source/index.rst b/doc/source/index.rst index b02acdad1..d70f80705 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -11,10 +11,11 @@ API: .. toctree:: :maxdepth: 2 + tiers.rst relation.rst - relational_operand.rst connection.rst - + blob.rst + declare.rst Indices and tables ================== diff --git a/doc/source/tiers.rst b/doc/source/tiers.rst new file mode 100644 index 000000000..b1ebd1ebc --- /dev/null +++ b/doc/source/tiers.rst @@ -0,0 +1,9 @@ +DataJoint Tiers +=============== + +.. toctree:: + :maxdepth: 2 + + user_relations.rst + autopopulate.rst + diff --git a/doc/source/user_relations.rst b/doc/source/user_relations.rst new file mode 100644 index 000000000..73724209b --- /dev/null +++ b/doc/source/user_relations.rst @@ -0,0 +1,5 @@ +User Tiers +========== + +.. automodule:: datajoint.user_relations + :members: From ba92501871e15177115714f147b1cf93af5b10e3 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Mon, 6 Jul 2015 16:49:31 -0500 Subject: [PATCH 17/34] changed insert to insert1. --- datajoint/relation.py | 20 +++++++------------- tests/schema.py | 4 ++-- tests/test_connection.py | 8 ++++---- tests/test_relation.py | 24 ++++++++++++------------ 4 files changed, 25 insertions(+), 31 deletions(-) diff --git a/datajoint/relation.py b/datajoint/relation.py index e87fad25a..69b363080 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -113,14 +113,6 @@ def from_clause(self): """ return self.full_table_name - def iter_insert(self, rows, **kwargs): - """ - Inserts a collection of tuples. Additional keyword arguments are passed to insert. - - :param iter: Must be an iterator that generates a sequence of valid arguments for insert. - """ - for row in rows: - self.insert(row, **kwargs) @@ -158,19 +150,21 @@ def is_declared(self): database=self.database, table_name=self.table_name)) return cur.rowcount > 0 - def batch_insert(self, data, **kwargs): + def insert(self, rows, **kwargs): """ - Inserts an entire batch of entries. Additional keyword arguments are passed to insert. + Inserts a collection of tuples. Additional keyword arguments are passed to insert1. - :param data: must be iterable, each row must be a valid argument for insert + :param rows: an iterable that generates a sequence of valid arguments for insert1. 
""" - self.iter_insert(data.__iter__(), **kwargs) + for row in rows: + self.insert1(row, **kwargs) + @property def full_table_name(self): return r"`{0:s}`.`{1:s}`".format(self.database, self.table_name) - def insert(self, tup, ignore_errors=False, replace=False): + def insert1(self, tup, ignore_errors=False, replace=False): """ Insert one data record or one Mapping (like a dict). diff --git a/tests/schema.py b/tests/schema.py index 378ec3d24..112bc1596 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -60,10 +60,10 @@ class SquaredScore(dj.Computed): def _make_tuples(self, key): outcome = (Trials() & key).fetch1()['outcome'] - self.insert(dict(key, squared=outcome**2)) + self.insert1(dict(key, squared=outcome ** 2)) ss = SquaredSubtable() for i in range(10): - ss.insert(dict(key, dummy=i)) + ss.insert1(dict(key, dummy=i)) @schema diff --git a/tests/test_connection.py b/tests/test_connection.py index 656373621..11f50beea 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -75,10 +75,10 @@ def test_transaction_rollback(self): self.relation.delete() with self.conn.transaction: - self.relation.insert(tmp[0]) + self.relation.insert1(tmp[0]) try: with self.conn.transaction: - self.relation.insert(tmp[1]) + self.relation.insert1(tmp[1]) raise DataJointError("Testing rollback") except DataJointError: pass @@ -94,9 +94,9 @@ def test_cancel(self): (2, 'Klara', 'monkey') ], self.relation.heading.as_dtype) self.relation.delete_quick() - self.relation.insert(tmp[0]) + self.relation.insert1(tmp[0]) self.conn.start_transaction() - self.relation.insert(tmp[1]) + self.relation.insert1(tmp[1]) self.conn.cancel_transaction() assert_equal(len(self.relation), 1, "Length is not 1. Expected because rollback should have happened.") diff --git a/tests/test_relation.py b/tests/test_relation.py index bdc26076d..11fe25298 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -59,14 +59,14 @@ def test_compound_restriction(self): s = self.subjects t = self.trials - s.insert(dict(subject_id=1, real_id='M')) - s.insert(dict(subject_id=2, real_id='F')) + s.insert1(dict(subject_id=1, real_id='M')) + s.insert1(dict(subject_id=2, real_id='F')) # insert trials n_trials = 20 for subject_id in [1, 2]: for trial_id in range(n_trials): - t.insert( + t.insert1( trial_id=trial_id, subject_id=subject_id, outcome=int(np.random.randint(10)), @@ -87,7 +87,7 @@ def test_record_insert(self): tmp = np.array([(2, 'Klara', 'monkey')], dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) - self.subjects.insert(tmp[0]) + self.subjects.insert1(tmp[0]) testt2 = (self.subjects & 'subject_id = 2').fetch()[0] assert_equal(tuple(tmp[0]), tuple(testt2), "Inserted and fetched record do not match!") @@ -128,7 +128,7 @@ def test_record_insert_different_order(self): tmp = np.array([('Klara', 2, 'monkey')], dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - self.subjects.insert(tmp[0]) + self.subjects.insert1(tmp[0]) testt2 = (self.subjects & 'subject_id = 2').fetch()[0] assert_equal((2, 'Klara', 'monkey'), tuple(testt2), "Inserted and fetched record do not match!") @@ -139,7 +139,7 @@ def test_wrong_key_insert_records(self): tmp = np.array([('Klara', 2, 'monkey')], dtype=[('real_deal', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - self.subjects.insert(tmp[0]) + self.subjects.insert1(tmp[0]) def test_dict_insert(self): "Test whether record insert works" @@ -158,7 +158,7 @@ def test_wrong_key_insert(self): 'subject_database': 3, 'species': 'human'} - self.subjects.insert(tmp) + 
self.subjects.insert1(tmp) # def test_batch_insert(self): @@ -196,7 +196,7 @@ def test_iter_insert(self): def test_blob_insert(self): x = np.random.randn(10) t = {'matrix_id': 0, 'data': x, 'comment': 'this is a random image'} - self.matrix.insert(t) + self.matrix.insert1(t) x2 = self.matrix.fetch()[0][1] assert_array_equal(x, x2, 'inserted blob does not match') @@ -235,7 +235,7 @@ def test_blob_iteration(self): t = {'matrix_id': i, 'data': np.random.randn(4, 4, 4), 'comment': c} - self.matrix.insert(t) + self.matrix.insert1(t) dicts.append(t) for t, t2 in zip(dicts, self.matrix): @@ -256,7 +256,7 @@ def test_fetch(self): t = {'matrix_id': i, 'data': np.random.randn(4, 4, 4), 'comment': c} - self.matrix.insert(t) + self.matrix.insert1(t) dicts.append(t) tuples2 = self.matrix.fetch() @@ -275,7 +275,7 @@ def test_fetch_dicts(self): t = {'matrix_id': i, 'data': np.random.randn(4, 4, 4), 'comment': c} - self.matrix.insert(t) + self.matrix.insert1(t) dicts.append(t) tuples2 = self.matrix.fetch(as_dict=True) @@ -325,7 +325,7 @@ def fill_relation(self): self.subjects.batch_insert(tmp) for trial_id in range(1, 11): - self.trials.insert(dict(subject_id=2, trial_id=trial_id, outcome=np.random.randint(0, 10))) + self.trials.insert1(dict(subject_id=2, trial_id=trial_id, outcome=np.random.randint(0, 10))) def teardown(self): pass From 719907acc74cfd433a80bc3d261affdd7c34be95 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Wed, 8 Jul 2015 11:56:35 -0500 Subject: [PATCH 18/34] added positional insert of tuples (attributes are identified by position, not name) --- datajoint/heading.py | 3 + datajoint/jobs.py | 16 +++- datajoint/relation.py | 59 ++++++------ datajoint/relational_operand.py | 2 +- tests/schema.py | 153 ++++++++++++++++++++------------ tests/test_declare.py | 52 ++--------- 6 files changed, 154 insertions(+), 131 deletions(-) diff --git a/datajoint/heading.py b/datajoint/heading.py index b84935d0e..5c1ca0f76 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -52,6 +52,9 @@ def __init__(self, attributes=None): def reset(self): self.attributes = None + def __len__(self): + return 0 if self.attributes is None else len(self.attributes) + def __bool__(self): return self.attributes is not None diff --git a/datajoint/jobs.py b/datajoint/jobs.py index f73cd1063..f3ce5efab 100644 --- a/datajoint/jobs.py +++ b/datajoint/jobs.py @@ -53,7 +53,11 @@ def key_hash(key): def reserve(reserve_jobs, full_table_name, key): """ - Insert a reservation record in the jobs table + Reserve a job for computation. When a job is reserved, the job table contains an entry for the + job key, identified by its hash. When jobs are completed, the entry is removed. + :param reserve_jobs: if True, use job reservation + :param full_table_name: `database`.`table_name` + :param key: the dict of the job's primary key :return: True if reserved job successfully """ if not reserve_jobs: @@ -74,7 +78,10 @@ def reserve(reserve_jobs, full_table_name, key): def complete(reserve_jobs, full_table_name, key): """ - upon job completion the job entry is removed + Log a completed job. When a job is completed, its reservation entry is deleted. + :param reserve_jobs: if True, use job reservation + :param full_table_name: `database`.`table_name` + :param key: the dict of the job's primary key """ if reserve_jobs: database, table_name = split_name(full_table_name) @@ -85,7 +92,12 @@ def complete(reserve_jobs, full_table_name, key): def error(reserve_jobs, full_table_name, key, error_message): """ + Log an error message. 
The job reservation is replaced with an error entry.
     if an error occurs, leave an entry describing the problem
+    :param reserve_jobs: if True, use job reservation
+    :param full_table_name: `database`.`table_name`
+    :param key: the dict of the job's primary key
+    :param error_message: string error message
     """
     if reserve_jobs:
         database, table_name = split_name(full_table_name)
diff --git a/datajoint/relation.py b/datajoint/relation.py
index e87fad25a..ff453b0e5 100644
--- a/datajoint/relation.py
+++ b/datajoint/relation.py
@@ -113,15 +113,6 @@ def from_clause(self):
         """
         return self.full_table_name

-    def iter_insert(self, rows, **kwargs):
-        """
-        Inserts a collection of tuples. Additional keyword arguments are passed to insert.
-
-        :param iter: Must be an iterator that generates a sequence of valid arguments for insert.
-        """
-        for row in rows:
-            self.insert(row, **kwargs)
-
     # ------------- dependencies ---------- #
     @property
     def parents(self):
@@ -158,31 +149,33 @@ def is_declared(self):
                 database=self.database, table_name=self.table_name))
         return cur.rowcount > 0

-    def batch_insert(self, data, **kwargs):
-        """
-        Inserts an entire batch of entries. Additional keyword arguments are passed to insert.
-
-        :param data: must be iterable, each row must be a valid argument for insert
-        """
-        self.iter_insert(data.__iter__(), **kwargs)
-
     @property
     def full_table_name(self):
         return r"`{0:s}`.`{1:s}`".format(self.database, self.table_name)

-    def insert(self, tup, ignore_errors=False, replace=False):
+    def insert(self, rows, **kwargs):
+        """
+        Inserts a collection of tuples. Additional keyword arguments are passed to insert1.
+
+        :param rows: Must be an iterator that generates a sequence of valid arguments for insert1.
+        """
+        for row in rows:
+            self.insert1(row, **kwargs)
+
+    def insert1(self, tup, ignore_errors=False, replace=False):
         """
         Insert one data record or one Mapping (like a dict).

-        :param tup: Data record, or a Mapping (like a dict).
+        :param tup: Data record, a Mapping (like a dict), or a list or tuple with ordered values.
         :param ignore_errors=False: Ignores errors if True.
         :param replace=False: Replaces data tuple if True.
Example::
-            rel.insert(dict(subject_id = 7, species="mouse", date_of_birth = "2014-09-01"))
+            relation.insert1(dict(subject_id=7, species="mouse", date_of_birth="2014-09-01"))
         """
         heading = self.heading
-        if isinstance(tup, np.void):
+
+        if isinstance(tup, np.void):   # np.array insert
             for fieldname in tup.dtype.fields:
                 if fieldname not in heading:
                     raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname))
@@ -191,7 +184,8 @@ def insert(self, tup, ignore_errors=False, replace=False):
             args = tuple(pack(tup[name]) for name in heading
                          if name in tup.dtype.fields and heading[name].is_blob)
             attribute_list = '`' + '`,`'.join(q for q in heading if q in tup.dtype.fields) + '`'
-        elif isinstance(tup, Mapping):
+
+        elif isinstance(tup, Mapping):   # dict-based insert
             for fieldname in tup.keys():
                 if fieldname not in heading:
                     raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname))
@@ -200,8 +194,19 @@ def insert(self, tup, ignore_errors=False, replace=False):
             args = tuple(pack(tup[name]) for name in heading
                          if name in tup and heading[name].is_blob)
             attribute_list = '`' + '`,`'.join(name for name in heading if name in tup) + '`'
-        else:
-            raise DataJointError('Datatype %s cannot be inserted' % type(tup))
+
+        else:   # positional insert
+            try:
+                if len(tup) != len(self.heading):
+                    raise DataJointError(
+                        'Tuple size does not match the number of relation attributes')
+            except TypeError:
+                raise DataJointError('Datatype %s cannot be inserted' % type(tup))
+            else:
+                pairs = list(zip(heading, tup))
+                value_list = ','.join('%s' if heading[name].is_blob else repr(value) for name, value in pairs)
+                attribute_list = '`' + '`,`'.join(heading.names) + '`'
+                args = tuple(pack(value) for name, value in pairs if heading[name].is_blob)
         if replace:
             sql = 'REPLACE'
         elif ignore_errors:
@@ -230,7 +235,9 @@ def delete(self):
         if config['safemode']:
             print('The contents of the following tables are about to be deleted:')
             for relation in relations:
-                print(relation.full_table_name, '(%d tuples)'' % len(relation)')
+                count = len(relation)
+                if count:
+                    print(relation.full_table_name, '(%d tuples)' % count)
             do_delete = user_choice("Proceed?", default='no') == 'yes'
         if do_delete:
             with self.connection.transaction:
@@ -263,7 +270,7 @@ def drop(self):
         if do_drop:
             while relations:
                 relations.pop().drop_quick()
-            print('Dropped tables..')
+            print('Tables dropped.')

     def size_on_disk(self):
         """
diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py
index 367fd2d1e..3a380df79 100644
--- a/datajoint/relational_operand.py
+++ b/datajoint/relational_operand.py
@@ -151,7 +151,7 @@ def __len__(self):

     def __contains__(self, item):
         """
-        "item in relation" is equivalient to "len(relation & item)>0"
+        "item in relation" is equivalent to "len(relation & item)>0"
         """
         return len(self & item) > 0

diff --git a/tests/schema.py b/tests/schema.py
index 378ec3d24..3bac2dca0 100644
--- a/tests/schema.py
+++ b/tests/schema.py
@@ -2,89 +2,128 @@
 Test schema definition
 """

+import random
 import datajoint as dj
 from .
import PREFIX, CONN_INFO schema = dj.schema(PREFIX + '_test1', locals(), connection=dj.conn(**CONN_INFO)) @schema -class Subjects(dj.Manual): - definition = """ - #Basic subject - subject_id : int # unique subject id - --- - real_id : varchar(40) # real-world name - species = "mouse" : enum('mouse', 'monkey', 'human') # species +class User(dj.Manual): + definition = """ # lab members + username: varchar(12) """ -@schema -class Animals(dj.Manual): - definition = """ - # information of non-human subjects - -> Subjects - --- - animal_dob :date # date of birth - """ + def fill(self): + names = [['Jake'], ['Cathryn'], ['Shan'], ['Fabian'], ['Edgar'], ['George'], ['Dimitri']] + self.insert(names) + return names @schema -class Trials(dj.Manual): - definition = """ - # info about trials - -> Subjects - trial_id: int +class Subject(dj.Manual): + definition = """ # Basic information about animal subjects used in experiments + subject_id :int # unique subject id --- - outcome: int # result of experiment - notes="": varchar(4096) # other comments - trial_ts=CURRENT_TIMESTAMP: timestamp # automatic + real_id :varchar(40) # real-world name. Omit if the same as subject_id + species = "mouse" :enum('mouse', 'monkey', 'human') + date_of_birth :date + subject_notes :varchar(4000) + unique index (real_id, species) """ + def fill(self): + data = [ + [1551, '', 'mouse', '2015-04-01', 'genetically engineered super mouse'], + [1, 'Curious George', 'monkey', '2008-06-30', ''], + [1552, '', 'mouse', '2015-06-15', ''], + [1553, '', 'mouse', '2016-07-01', ''] + ] + self.insert(data) + return data + @schema -class Matrix(dj.Manual): - definition = """ - # Some numpy array - matrix_id: int # unique matrix id +class Experiment(dj.Imported): + definition = """ # information about experiments + -> Subject + experiment_id :smallint # experiment number for this subject --- - data: longblob # data - comment: varchar(1000) # comment + experiment_date :date # date when experiment was started + -> User + data_path="" :varchar(255) # file path to recorded data + notes="" :varchar(2048) # e.g. 
purpose of experiment
+    entry_time=CURRENT_TIMESTAMP :timestamp  # automatic timestamp
     """

+    def _make_tuples(self, key):
+        """
+        populate with random data
+        """
+        from datetime import date, timedelta
+        experiments_per_subject = 5
+        users = User().fetch()['username']
+        for experiment_id in range(experiments_per_subject):
+            self.insert1(
+                dict(key,
+                     experiment_id=experiment_id,
+                     experiment_date=(date.today()-timedelta(random.expovariate(1/30))).isoformat(),
+                     username=random.choice(users)))
+

 @schema
-class SquaredScore(dj.Computed):
-    definition = """
-    # cumulative outcome of trials
-    -> Subjects
-    -> Trials
+class Trial(dj.Imported):
+    definition = """  # a trial within an experiment
+    -> Experiment
+    trial_id   :smallint  # trial number
     ---
-    squared: int # squared result of Trials outcome
+    start_time :double    # (s)
     """

     def _make_tuples(self, key):
-        outcome = (Trials() & key).fetch1()['outcome']
-        self.insert1(dict(key, squared=outcome ** 2))
-        ss = SquaredSubtable()
-        for i in range(10):
-            ss.insert1(dict(key, dummy=i))
-
+        """
+        populate with random data (pretend reading from raw files)
+        """
+        for trial_id in range(10):
+            self.insert1(
+                dict(key,
+                     trial_id=trial_id,
+                     start_time=random.random()*1e9
+                     ))

 @schema
-class ErrorGenerator(dj.Computed):
-    definition = """
-    # ignore
-    -> Subjects
-    -> Trials
-    ---
-    dummy: int # ignore
+class Ephys(dj.Imported):
+    definition = """  # some kind of electrophysiological recording
+    -> Trial
+    ----
+    sampling_frequency :double  # (Hz)
+    duration           :double  # (s)
     """

     def _make_tuples(self, key):
-        raise Exception("This is for testing")
-
+        """
+        populate with random data
+        """
+        row = dict(key,
+                   sampling_frequency=16000,
+                   duration=random.expovariate(1/30))
+        self.insert1(row)
+        EphysChannel().fill(key, number_samples=round(row['duration'] * row['sampling_frequency']))

 @schema
-class SquaredSubtable(dj.Subordinate, dj.Manual):
-    definition = """
-    # cumulative outcome of trials
-    -> SquaredScore
-    dummy: int # dummy primary attribute
-    ---
-    """
\ No newline at end of file
+class EphysChannel(dj.Subordinate, dj.Imported):
+    definition = """  # subtable containing individual channels
+    -> Ephys
+    channel :tinyint unsigned  # channel number within Ephys
+    ----
+    voltage :longblob
+    """
+
+    def fill(self, key, number_samples):
+        """
+        populate random trace of specified length
+        """
+        import numpy as np
+        for channel in range(16):
+            self.insert1(
+                dict(key,
+                     channel=channel,
+                     voltage=np.float32(np.random.randn(number_samples))
+                     ))
diff --git a/tests/test_declare.py b/tests/test_declare.py
index b7af31bcd..6b080a45f 100644
--- a/tests/test_declare.py
+++ b/tests/test_declare.py
@@ -1,52 +1,14 @@
-import datajoint as dj
-from . import PREFIX, CONN_INFO
-
-
+from nose.tools import assert_true, assert_raises, assert_equal, raises
+from .
import schema


class TestDeclare:
    """
    Tests declaration, heading, dependencies, and drop
    """

-    schema = dj.schema(PREFIX + '_test1', locals(), connection=dj.conn(**CONN_INFO))
-
-    @schema
-    class Subjects(dj.Manual):
-        definition = """
-        #Basic subject
-        subject_id : int     # unique subject id
-        ---
-        real_id : varchar(40)  # real-world name
-        species = "mouse" : enum('mouse', 'monkey', 'human')  # species
-        """
-
-    @schema
-    class Experiments(dj.Manual):
-        definition = """
-        # information of non-human subjects
-        -> Subjects
-        ---
-        animal_dob :date  # date of birth
-        """
+    def test_attributes(self):
+        assert_equal(schema.Subject().heading.names,
+                     ['subject_id', 'real_id', 'species', 'date_of_birth', 'subject_notes'])
+        assert_equal(schema.Subject().primary_key, ['subject_id'])
+        assert_equal(schema.Experiment().heading.names,
+                     ['subject_id', 'experiment_id', 'experiment_date', 'username', 'data_path', 'notes', 'entry_time'])
+        assert_equal(schema.Experiment().primary_key, ['subject_id', 'experiment_id'])

-    @schema
-    class Trials(dj.Manual):
-        definition = """
-        # info about trials
-        -> Subjects
-        trial_id: int
-        ---
-        outcome: int  # result of experiment
-        notes="": varchar(4096)  # other comments
-        trial_ts=CURRENT_TIMESTAMP: timestamp  # automatic
-        """
-
-    @schema
-    class Matrix(dj.Manual):
-        definition = """
-        # Some numpy array
-        matrix_id: int  # unique matrix id
-        ---
-        data: longblob  # data
-        comment: varchar(1000)  # comment
-        """
\ No newline at end of file
From ac82e1a4451e9df791ce53a796efd30e9bf9ac77 Mon Sep 17 00:00:00 2001
From: Fabian Sinz
Date: Thu, 9 Jul 2015 09:19:14 -0500
Subject: [PATCH 19/34] added prepare and content

---
 datajoint/relation.py           | 13 ++++++++++++-
 datajoint/relational_operand.py | 12 ------------
 datajoint/user_relations.py     | 15 ++++++++++++++-
 datajoint/utils.py              | 15 +++++++++++++++
 4 files changed, 41 insertions(+), 14 deletions(-)

diff --git a/datajoint/relation.py b/datajoint/relation.py
index 69b363080..97054dae6 100644
--- a/datajoint/relation.py
+++ b/datajoint/relation.py
@@ -50,7 +50,9 @@ def decorator(cls):
         cls._heading = Heading()
         cls._context = context
         # trigger table declaration by requesting the heading from an instance
-        cls().heading
+        inst = cls()
+        inst.heading
+        inst.prepare()
         return cls

     return decorator
@@ -269,6 +271,15 @@ def size_on_disk(self):
         ).fetchone()
         return (ret['Data_length'] + ret['Index_length'])/1024**2

+    # --------- functionality used by the decorator ---------
+    def prepare(self):
+        """
+        This function is overriden by the user_relation subclasses. It is called on an instance
+        once when the class object is created and prepares the table depending on the tier.
+        """
+        pass
+

 class FreeRelation(Relation):
     """
     A relation with no definition. Its table must already exist in the database.
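The effect of the new prepare() hook is easiest to see from the declaration flow: instantiating the
class inside the decorator declares the table and then lets the tier customize it. A minimal sketch
(database and table names here are hypothetical, not part of this patch)::

    import datajoint as dj

    schema = dj.schema('tutorial_db', locals())   # hypothetical database name

    @schema                     # binds the class to the database, declares the
    class Scan(dj.Manual):      # table, then calls prepare() on an instance
        definition = """
        scan_id : int     # unique scan number
        ---
        depth : double    # (um) imaging depth
        """
    # prepare() is a no-op for Manual tables; tier subclasses override it.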
diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index a24d4f7ef..7d8e5ed56 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -104,18 +104,6 @@ def aggregate(self, _group, *attributes, **renamed_attributes): raise DataJointError('The second argument must be a relation or None') return Projection(self, _group, *attributes, **renamed_attributes) - def group_by(self, *attributes, sortby=None): - r = self.project(*attributes).fetch() - dtype2 = np.dtype({name:r.dtype.fields[name] for name in attributes}) - r2 = np.unique(np.ndarray(r.shape, dtype2, r, 0, r.strides)) - r2.sort(order=sortby if sortby is not None else attributes) - for nk in r2: - restr = ' and '.join(["%s='%s'" % (fn, str(v)) for fn, v in zip(r2.dtype.names, nk)]) - if len(nk) == 1: - yield nk[0], self & restr - else: - yield nk, self & restr - def __iand__(self, restriction): """ in-place relational restriction or semijoin diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 4f9fdcbdc..bf270e119 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -13,6 +13,7 @@ class Manual(Relation): """ Inherit from this class if the table's values are entered manually. """ + @property def table_name(self): """ @@ -27,6 +28,7 @@ class Lookup(Relation): currently equivalent to defining the table as Manual and serves semantic purposes only. """ + @property def table_name(self): """ @@ -34,12 +36,21 @@ def table_name(self): """ return '#' + from_camel_case(self.__class__.__name__) + def prepare(self): + """ + Checks whether the instance has a property called `content` and inserts its elements. + """ + if hasattr(self, 'content'): + for row in self.content: + self.insert1(row, ignore_errors=True) + class Imported(Relation, AutoPopulate): """ Inherit from this class if the table's values are imported from external data sources. The inherited class must at least provide the function `_make_tuples`. """ + @property def table_name(self): """ @@ -53,6 +64,7 @@ class Computed(Relation, AutoPopulate): Inherit from this class if the table's values are computed from other relations in the schema. The inherited class must at least provide the function `_make_tuples`. 
""" + @property def table_name(self): """ @@ -65,6 +77,7 @@ class Subordinate: """ Mix-in to make computed tables subordinate """ + @property def populated_from(self): """ @@ -95,6 +108,7 @@ def from_camel_case(s): >>>from_camel_case("TableName") "table_name" """ + def convert(match): return ('_' if match.groups()[0] else '') + match.group(0).lower() @@ -102,4 +116,3 @@ def convert(match): raise DataJointError( 'ClassName must be alphanumeric in CamelCase, begin with a capital letter') return re.sub(r'(\B[A-Z])|(\b[A-Z])', convert, s) - diff --git a/datajoint/utils.py b/datajoint/utils.py index e1e3d2fda..a142e2707 100644 --- a/datajoint/utils.py +++ b/datajoint/utils.py @@ -1,3 +1,5 @@ +import numpy as np + def user_choice(prompt, choices=("yes", "no"), default=None): """ @@ -14,3 +16,16 @@ def user_choice(prompt, choices=("yes", "no"), default=None): response = response if response else default valid = response in choices return response + + +def group_by(rel, *attributes, sortby=None): + r = rel.project(*attributes).fetch() + dtype2 = np.dtype({name:r.dtype.fields[name] for name in attributes}) + r2 = np.unique(np.ndarray(r.shape, dtype2, r, 0, r.strides)) + r2.sort(order=sortby if sortby is not None else attributes) + for nk in r2: + restr = ' and '.join(["%s='%s'" % (fn, str(v)) for fn, v in zip(r2.dtype.names, nk)]) + if len(nk) == 1: + yield nk[0], rel & restr + else: + yield nk, rel & restr \ No newline at end of file From 1d96ebfc2d1b68c8b99d73a0893f85ee7d90943e Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Thu, 9 Jul 2015 09:23:00 -0500 Subject: [PATCH 20/34] content -> contents --- datajoint/user_relations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index bf270e119..00bcb6fd0 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -40,8 +40,8 @@ def prepare(self): """ Checks whether the instance has a property called `content` and inserts its elements. """ - if hasattr(self, 'content'): - for row in self.content: + if hasattr(self, 'contents'): + for row in self.contents: self.insert1(row, ignore_errors=True) From 2e8e46d3fc55449067ca32cd1f3fab4f2c14de51 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Thu, 9 Jul 2015 09:30:22 -0500 Subject: [PATCH 21/34] converted size_on_disk into a property --- datajoint/relation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/datajoint/relation.py b/datajoint/relation.py index 1a2e06804..40b9b146c 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -273,6 +273,7 @@ def drop(self): relations.pop().drop_quick() print('Tables dropped.') + @property def size_on_disk(self): """ :return: size of data and indices in GiB taken by the table on the storage device From 1dd616d4da004a3702f7422962ddca3e6637f80d Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Thu, 9 Jul 2015 10:04:57 -0500 Subject: [PATCH 22/34] minor cleanup. size_on_disk in now in bytes. 
--- datajoint/relation.py | 14 +++++++------- datajoint/user_relations.py | 5 ++--- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/datajoint/relation.py b/datajoint/relation.py index d7b2d163e..714711588 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -50,9 +50,9 @@ def decorator(cls): cls._heading = Heading() cls._context = context # trigger table declaration by requesting the heading from an instance - inst = cls() - inst.heading - inst.prepare() + instance = cls() + instance.heading # trigger table declaration + instance.prepare() return cls return decorator @@ -278,19 +278,19 @@ def drop(self): @property def size_on_disk(self): """ - :return: size of data and indices in GiB taken by the table on the storage device + :return: size of data and indices in bytes on the storage device """ ret = self.connection.query( 'SHOW TABLE STATUS FROM `{database}` WHERE NAME="{table}"'.format( database=self.database, table=self.table_name), as_dict=True ).fetchone() - return (ret['Data_length'] + ret['Index_length'])/1024**2 + return ret['Data_length'] + ret['Index_length'] # --------- functionality used by the decorator --------- def prepare(self): """ - This function is overriden by the user_relation subclasses. It is called on an instance - once when the class object is created and prepares the table depending on the tier. + This method is overridden by the user_relations subclasses. It is called on an instance + once when the class is declared. """ pass diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 00bcb6fd0..26a21b0db 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -38,11 +38,10 @@ def table_name(self): def prepare(self): """ - Checks whether the instance has a property called `content` and inserts its elements. + Checks whether the instance has a property called `contents` and inserts its elements. """ if hasattr(self, 'contents'): - for row in self.contents: - self.insert1(row, ignore_errors=True) + self.insert(self.contents, ignore_errors=True) class Imported(Relation, AutoPopulate): From e07efe8429b0602ef44ecb10a9a38aa6c1470fd1 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Thu, 9 Jul 2015 12:48:39 -0500 Subject: [PATCH 23/34] bugfix in job reservations --- datajoint/__init__.py | 5 +++-- datajoint/autopopulate.py | 2 +- datajoint/jobs.py | 15 ++++++++------- datajoint/relation.py | 4 +++- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 265d61b8d..507ce06b1 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -18,7 +18,8 @@ 'Connection', 'Heading', 'Relation', 'FreeRelation', 'Not', 'Relation', 'schema', 'Manual', 'Lookup', 'Imported', 'Computed', - 'AutoPopulate', 'conn', 'DataJointError', 'blob'] + 'get_jobs', + 'conn', 'DataJointError', 'blob'] class DataJointError(Exception): @@ -50,7 +51,7 @@ class DataJointError(Exception): from .connection import conn, Connection from .relation import Relation from .user_relations import Manual, Lookup, Imported, Computed, Subordinate -from .autopopulate import AutoPopulate +from .jobs import get_jobs from . 
import blob from .relational_operand import Not from .heading import Heading diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 1ffe18656..2382caae4 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -95,5 +95,5 @@ def progress(self): """ total = len(self.populated_from) remaining = len(self.populated_from - self.target) - print('Remaining %d of %d (%2.1f%%)' % (remaining, total, 100*remaining/total) + print('Completed %d of %d (%2.1f%%)' % (total-remaining, total, 100-100*remaining/total) if remaining else 'Complete', flush=True) diff --git a/datajoint/jobs.py b/datajoint/jobs.py index 76fd5b5f9..6fbea5c23 100644 --- a/datajoint/jobs.py +++ b/datajoint/jobs.py @@ -5,12 +5,13 @@ from .relation import Relation, schema -def get_jobs_table(database): +def get_jobs(database): """ - :return: the base relation of the job reservation table for database """ - self = get_jobs_table + + # cache for containing an instance + self = get_jobs if not hasattr(self, 'lookup'): self.lookup = {} @@ -64,12 +65,12 @@ def reserve(reserve_jobs, full_table_name, key): if not reserve_jobs: return True database, table_name = split_name(full_table_name) - jobs = get_jobs_table(database) + jobs = get_jobs(database) job_key = dict(table_name=table_name, key_hash=key_hash(key)) if jobs & job_key: return False try: - jobs.insert(dict(job_key, status="reserved", host=os.uname().nodename, pid=os.getpid())) + jobs.insert1(dict(job_key, status="reserved", host=os.uname().nodename, pid=os.getpid())) except pymysql.err.IntegrityError: success = False else: @@ -87,7 +88,7 @@ def complete(reserve_jobs, full_table_name, key): if reserve_jobs: database, table_name = split_name(full_table_name) job_key = dict(table_name=table_name, key_hash=key_hash(key)) - entry = get_jobs_table(full_table_name) & job_key + entry = get_jobs(database) & job_key entry.delete_quick() @@ -103,7 +104,7 @@ def error(reserve_jobs, full_table_name, key, error_message): if reserve_jobs: database, table_name = split_name(full_table_name) job_key = dict(table_name=table_name, key_hash=key_hash(key)) - jobs = get_jobs_table(database) + jobs = get_jobs(database) jobs.insert(dict(job_key, status="error", host=os.uname(), diff --git a/datajoint/relation.py b/datajoint/relation.py index 714711588..74e4e3a1e 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -236,12 +236,14 @@ def delete(self): raise NotImplementedError('Restricted cascading deletes are not yet implemented') do_delete = True if config['safemode']: + do_delete = False print('The contents of the following tables are about to be deleted:') for relation in relations: count = len(relation) if count: + do_delete = True print(relation.full_table_name, '(%d tuples)' % count) - do_delete = user_choice("Proceed?", default='no') == 'yes' + do_delete = do_delete and user_choice("Proceed?", default='no') == 'yes' if do_delete: with self.connection.transaction: while relations: From 7f5d6235123f7bad1a1506f137dc5ba9ce8e8993 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Thu, 9 Jul 2015 13:12:48 -0500 Subject: [PATCH 24/34] bugfix in error of jobs --- datajoint/jobs.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/datajoint/jobs.py b/datajoint/jobs.py index 6fbea5c23..64e52c1b7 100644 --- a/datajoint/jobs.py +++ b/datajoint/jobs.py @@ -105,8 +105,9 @@ def error(reserve_jobs, full_table_name, key, error_message): database, table_name = split_name(full_table_name) job_key = dict(table_name=table_name, key_hash=key_hash(key)) 
jobs = get_jobs(database) - jobs.insert(dict(job_key, + + jobs.insert1(dict(job_key, status="error", - host=os.uname(), + host=os.uname().nodename, pid=os.getpid(), error_message=error_message), replace=True) From 6aed9b74f9428cb01916cadeaa93107d48b7e746 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Fri, 10 Jul 2015 08:46:39 -0500 Subject: [PATCH 25/34] before batches --- datajoint/autopopulate.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 2382caae4..1108dd9b5 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -11,6 +11,12 @@ logger = logging.getLogger(__name__) +def batches(iter, n=1): + iter_len = len(iter) + for start in range(0, iter_len, n): + yield iter[start:min(start + n, iter_len)] + + class AutoPopulate(metaclass=abc.ABCMeta): """ AutoPopulate is a mixin class that adds the method populate() to a Relation class. @@ -48,7 +54,7 @@ def _make_tuples(self, key): def target(self): return self - def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): + def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, batch=1): """ rel.populate() calls rel._make_tuples(key) for every primary key in self.populated_from for which there is not already a tuple in rel. @@ -56,6 +62,7 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): :param restriction: restriction on rel.populated_from - target :param suppress_errors: suppresses error if true :param reserve_jobs: currently not implemented + :param batch: batch size of a single job """ error_list = [] if suppress_errors else None if not isinstance(self.populated_from, RelationalOperand): @@ -95,5 +102,5 @@ def progress(self): """ total = len(self.populated_from) remaining = len(self.populated_from - self.target) - print('Completed %d of %d (%2.1f%%)' % (total-remaining, total, 100-100*remaining/total) + print('Completed %d of %d (%2.1f%%)' % (total - remaining, total, 100 - 100 * remaining / total) if remaining else 'Complete', flush=True) From bf1e670aac3d239c1f1fb1fb05aa482b9c3f8fc4 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Fri, 10 Jul 2015 09:41:29 -0500 Subject: [PATCH 26/34] implemented batch processing in populate --- datajoint/autopopulate.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 1108dd9b5..54549d46d 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -17,6 +17,19 @@ def batches(iter, n=1): yield iter[start:min(start + n, iter_len)] +def reserved_batches(rel, reserve_jobs, full_table_name, n=1): + ret = [] + for key in rel: + if jobs.reserve(reserve_jobs, full_table_name, key): + ret.append(key) + if len(ret) == n: + yield ret + ret = [] + else: + if len(ret) > 0: + yield ret + + class AutoPopulate(metaclass=abc.ABCMeta): """ AutoPopulate is a mixin class that adds the method populate() to a Relation class. @@ -54,7 +67,7 @@ def _make_tuples(self, key): def target(self): return self - def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, batch=1): + def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, batch_size=1): """ rel.populate() calls rel._make_tuples(key) for every primary key in self.populated_from for which there is not already a tuple in rel. 
@@ -62,7 +75,7 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, :param restriction: restriction on rel.populated_from - target :param suppress_errors: suppresses error if true :param reserve_jobs: currently not implemented - :param batch: batch size of a single job + :param batch_size: batch size of a single job """ error_list = [] if suppress_errors else None if not isinstance(self.populated_from, RelationalOperand): @@ -73,8 +86,8 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, full_table_name = self.target.full_table_name unpopulated = (self.populated_from - self.target) & restriction - for key in unpopulated.project(): - if jobs.reserve(reserve_jobs, full_table_name, key): + for batch in reserved_batches(unpopulated.project(), reserve_jobs, full_table_name, n=batch_size): + for key in batch: self.connection.start_transaction() if key in self.target: # already populated self.connection.cancel_transaction() From 1f71a2da02796ce62b8e3c2fdc687ac06bdb330a Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 10 Jul 2015 10:42:47 -0500 Subject: [PATCH 27/34] refactored job reservations and the schema decorator. Moved job manager to the Connection class. Converted the schema decorator into a class and extracted into its own module. --- datajoint/__init__.py | 7 +- datajoint/autopopulate.py | 16 ++-- datajoint/connection.py | 10 +- datajoint/jobs.py | 187 +++++++++++++++++++------------------- datajoint/relation.py | 54 +---------- datajoint/schema.py | 62 +++++++++++++ 6 files changed, 174 insertions(+), 162 deletions(-) create mode 100644 datajoint/schema.py diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 507ce06b1..c99d13635 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -18,8 +18,7 @@ 'Connection', 'Heading', 'Relation', 'FreeRelation', 'Not', 'Relation', 'schema', 'Manual', 'Lookup', 'Imported', 'Computed', - 'get_jobs', - 'conn', 'DataJointError', 'blob'] + 'conn', 'DataJointError'] class DataJointError(Exception): @@ -51,8 +50,6 @@ class DataJointError(Exception): from .connection import conn, Connection from .relation import Relation from .user_relations import Manual, Lookup, Imported, Computed, Subordinate -from .jobs import get_jobs -from . import blob from .relational_operand import Not from .heading import Heading -from .relation import schema \ No newline at end of file +from .schema import schema \ No newline at end of file diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 2382caae4..2a6d96c5e 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -3,7 +3,7 @@ import logging from .relational_operand import RelationalOperand from . import DataJointError -from .relation import Relation, FreeRelation, schema +from .relation import Relation, FreeRelation from . 
import jobs # noinspection PyExceptionInherit,PyCallingNonCallable @@ -64,21 +64,24 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): if self.connection.in_transaction: raise DataJointError('Populate cannot be called during a transaction.') - full_table_name = self.target.full_table_name + jobs = self.connection.jobs[self.target.database] + table_name = self.target.table_name unpopulated = (self.populated_from - self.target) & restriction for key in unpopulated.project(): - if jobs.reserve(reserve_jobs, full_table_name, key): + if not reserve_jobs or jobs.reserve(table_name, key): self.connection.start_transaction() if key in self.target: # already populated self.connection.cancel_transaction() - jobs.complete(reserve_jobs, full_table_name, key) + if reserve_jobs: + jobs.complete(table_name, key) else: logger.info('Populating: ' + str(key)) try: self._make_tuples(dict(key)) except Exception as error: self.connection.cancel_transaction() - jobs.error(reserve_jobs, full_table_name, key, error_message=str(error)) + if reserve_jobs: + jobs.error(table_name, key, error_message=str(error)) if not suppress_errors: raise else: @@ -86,7 +89,8 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): error_list.append((key, error)) else: self.connection.commit_transaction() - jobs.complete(reserve_jobs, full_table_name, key) + if reserve_jobs: + jobs.complete(table_name, key) return error_list def progress(self): diff --git a/datajoint/connection.py b/datajoint/connection.py index 4f67d8e6a..42d70e21c 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -5,10 +5,11 @@ """ from contextlib import contextmanager import pymysql -from . import DataJointError import logging -from . import config +from collections import defaultdict +from . 
import DataJointError, config from .erd import ERD +from .jobs import JobManager logger = logging.getLogger(__name__) @@ -49,7 +50,6 @@ class Connection: :param user: user name :param passwd: password :param init_fun: initialization function - """ def __init__(self, host, user, passwd, init_fun=None): @@ -67,6 +67,7 @@ def __init__(self, host, user, passwd, init_fun=None): raise DataJointError('Connection failed.') self._conn.autocommit(True) self._in_transaction = False + self.jobs = JobManager(self) def __del__(self): logger.info('Disconnecting {user}@{host}:{port}'.format(**self.conn_info)) @@ -167,5 +168,4 @@ def transaction(self): self.cancel_transaction() raise else: - self.commit_transaction() - + self.commit_transaction() \ No newline at end of file diff --git a/datajoint/jobs.py b/datajoint/jobs.py index 6fbea5c23..921a26af1 100644 --- a/datajoint/jobs.py +++ b/datajoint/jobs.py @@ -1,112 +1,109 @@ import hashlib import os import pymysql +from .relation import Relation -from .relation import Relation, schema - -def get_jobs(database): +def key_hash(key): """ - :return: the base relation of the job reservation table for database + 32-byte hash used for lookup of primary keys of jobs """ - - # cache for containing an instance - self = get_jobs - if not hasattr(self, 'lookup'): - self.lookup = {} - - if database not in self.lookup: - @schema(database, context={}) - class JobsRelation(Relation): - definition = """ - # the job reservation table - table_name: varchar(255) # className of the table - key_hash: char(32) # key hash - --- - status: enum('reserved','error','ignore') # if tuple is missing, the job is available - key=null: blob # structure containing the key - error_message="": varchar(1023) # error message returned if failed - error_stack=null: blob # error stack if failed - host="": varchar(255) # system hostname - pid=0: int unsigned # system process id - timestamp=CURRENT_TIMESTAMP: timestamp # automatic timestamp - """ - - @property - def table_name(self): - return '~jobs' - - self.lookup[database] = JobsRelation() - - return self.lookup[database] - - -def split_name(full_table_name): - [database, table_name] = full_table_name.split('.') - return database.strip('` '), table_name.strip('` ') - - -def key_hash(key): hashed = hashlib.md5() for k, v in sorted(key.items()): hashed.update(str(v).encode()) return hashed.hexdigest() -def reserve(reserve_jobs, full_table_name, key): - """ - Reserve a job for computation. When a job is reserved, the job table contains an entry for the - job key, identified by its hash. When jobs are completed, the entry is removed. - :param reserve_jobs: if True, use job reservation - :param full_table_name: `database`.`table_name` - :param key: the dict of the job's primary key - :return: True if reserved job successfully +class JobRelation(Relation): """ - if not reserve_jobs: - return True - database, table_name = split_name(full_table_name) - jobs = get_jobs(database) - job_key = dict(table_name=table_name, key_hash=key_hash(key)) - if jobs & job_key: - return False - try: - jobs.insert1(dict(job_key, status="reserved", host=os.uname().nodename, pid=os.getpid())) - except pymysql.err.IntegrityError: - success = False - else: - success = True - return success - - -def complete(reserve_jobs, full_table_name, key): + A base relation with no definition. Allows reserving jobs """ - Log a completed job. When a job is completed, its reservation entry is deleted. 
- :param reserve_jobs: if True, use job reservation - :param full_table_name: `database`.`table_name` - :param key: the dict of the job's primary key - """ - if reserve_jobs: - database, table_name = split_name(full_table_name) + def __init__(self, connection, database): + self.database = database + self._table_name = '~jobs' + self._connection = connection + self._definition = """ # job reservation table + table_name :varchar(255) # className of the table + key_hash :char(32) # key hash + --- + status :enum('reserved','error','ignore') # if tuple is missing, the job is available + key=null :blob # structure containing the key + error_message="" :varchar(1023) # error message returned if failed + error_stack=null :blob # error stack if failed + host="" :varchar(255) # system hostname + pid=0 :int unsigned # system process id + timestamp=CURRENT_TIMESTAMP :timestamp # automatic timestamp + """ + + @property + def definition(self): + return self._definition + + @property + def connection(self): + return self._connection + + @property + def table_name(self): + return self._table_name + + + def reserve(self, table_name, key): + """ + Reserve a job for computation. When a job is reserved, the job table contains an entry for the + job key, identified by its hash. When jobs are completed, the entry is removed. + :param full_table_name: `database`.`table_name` + :param key: the dict of the job's primary key + :return: True if reserved job successfully. False = the jobs is already taken + """ job_key = dict(table_name=table_name, key_hash=key_hash(key)) - entry = get_jobs(database) & job_key - entry.delete_quick() - - -def error(reserve_jobs, full_table_name, key, error_message): - """ - Log an error message. The job reservation is replaced with an error entry. - if an error occurs, leave an entry describing the problem - :param reserve_jobs: if True, use job reservation - :param full_table_name: `database`.`table_name` - :param key: the dict of the job's primary key - :param error_message: string error message - """ - if reserve_jobs: - database, table_name = split_name(full_table_name) + try: + self.insert1( + dict(job_key, + status="reserved", + host=os.uname().nodename, + pid=os.getpid())) + except pymysql.err.IntegrityError: + return False + else: + return True + + + def complete(self, table_name, key): + """ + Log a completed job. When a job is completed, its reservation entry is deleted. + :param reserve_jobs: if True, use job reservation + :param full_table_name: `database`.`table_name` + :param key: the dict of the job's primary key + """ + job_key = dict(table_name=table_name, key_hash=key_hash(key)) + (self & job_key).delete_quick() + + + def error(self, table_name, key, error_message): + """ + Log an error message. The job reservation is replaced with an error entry. 
+ if an error occurs, leave an entry describing the problem + :param reserve_jobs: if True, use job reservation + :param full_table_name: `database`.`table_name` + :param key: the dict of the job's primary key + :param error_message: string error message + """ job_key = dict(table_name=table_name, key_hash=key_hash(key)) - jobs = get_jobs(database) - jobs.insert(dict(job_key, - status="error", - host=os.uname(), - pid=os.getpid(), - error_message=error_message), replace=True) + self.insert1( + dict(job_key, + status="error", + host=os.uname().nodename, + pid=os.getpid(), + error_message=error_message), replace=True) + +class JobManager: + + def __init__(self, connection): + self.connection = connection + self._jobs = {} + + def __getitem__(self, database): + if database not in self._jobs: + self._jobs[database] = JobRelation(self.connection, database) + return self._jobs[database] diff --git a/datajoint/relation.py b/datajoint/relation.py index 74e4e3a1e..551e86690 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -2,9 +2,8 @@ import numpy as np import logging import abc -import pymysql -from . import DataJointError, config, conn +from . import DataJointError, config from .declare import declare from .relational_operand import RelationalOperand from .blob import pack @@ -14,50 +13,6 @@ logger = logging.getLogger(__name__) -def schema(database, context, connection=None): - """ - Returns a decorator that can be used to associate a Relation class to a database. - - :param database: name of the database to associate the decorated class with - :param context: dictionary for looking up foreign keys references, usually set to locals() - :param connection: Connection object. Defaults to datajoint.conn() - :return: a decorator for Relation subclasses - """ - if connection is None: - connection = conn() - - # if the database does not exist, create it - cur = connection.query("SHOW DATABASES LIKE '{database}'".format(database=database)) - if cur.rowcount == 0: - logger.info("Database `{database}` could not be found. " - "Attempting to create the database.".format(database=database)) - try: - connection.query("CREATE DATABASE `{database}`".format(database=database)) - logger.info('Created database `{database}`.'.format(database=database)) - except pymysql.OperationalError: - raise DataJointError("Database named `{database}` was not defined, and" - "an attempt to create has failed. Check" - " permissions.".format(database=database)) - - def decorator(cls): - """ - The decorator binds its argument class object to a database - :param cls: class to be decorated - """ - # class-level attributes - cls.database = database - cls._connection = connection - cls._heading = Heading() - cls._context = context - # trigger table declaration by requesting the heading from an instance - instance = cls() - instance.heading # trigger table declaration - instance.prepare() - return cls - - return decorator - - class Relation(RelationalOperand, metaclass=abc.ABCMeta): """ Relation is an abstract class that represents a base relation, i.e. a table in the database. @@ -115,7 +70,6 @@ def from_clause(self): """ return self.full_table_name - # ------------- dependencies ---------- # @property def parents(self): @@ -299,12 +253,10 @@ def prepare(self): class FreeRelation(Relation): """ - A relation with no definition. Its table must already exist in the database. + A base relation without a dedicated class. The table name is explicitly set. 
""" def __init__(self, connection, full_table_name, definition=None, context=None): - [database, table_name] = full_table_name.split('.') - self.database = database.strip('`') - self._table_name = table_name.strip('`') + self.database, self._table_name = (s.strip('`') for s in full_table_name.split('.')) self._connection = connection self._definition = definition self._context = context diff --git a/datajoint/schema.py b/datajoint/schema.py new file mode 100644 index 000000000..3a6e64025 --- /dev/null +++ b/datajoint/schema.py @@ -0,0 +1,62 @@ +import pymysql +import logging + +from . import DataJointError, conn +from .heading import Heading + + +logger = logging.getLogger(__name__) + + +class schema: + """ + A schema object can be used as a decorator that associates a Relation class to a database as + well as a namespace for looking up foreign key references. + """ + + def __init__(self, database, context, connection=None): + """ + :param database: name of the database to associate the decorated class with + :param context: dictionary for looking up foreign keys references, usually set to locals() + :param connection: Connection object. Defaults to datajoint.conn() + """ + if connection is None: + connection = conn() + self.database = database + self.connection = connection + self.context = context + + # if the database does not exist, create it + cur = connection.query("SHOW DATABASES LIKE '{database}'".format(database=database)) + if cur.rowcount == 0: + logger.info("Database `{database}` could not be found. " + "Attempting to create the database.".format(database=database)) + try: + connection.query("CREATE DATABASE `{database}`".format(database=database)) + logger.info('Created database `{database}`.'.format(database=database)) + except pymysql.OperationalError: + raise DataJointError("Database named `{database}` was not defined, and" + "an attempt to create has failed. 
Check" + " permissions.".format(database=database)) + + def __call__(self, cls): + """ + The decorator binds its argument class object to a database + :param cls: class to be decorated + """ + # class-level attributes + cls.database = self.database + cls._connection = self.connection + cls._heading = Heading() + cls._context = self.context + + # trigger table declaration by requesting the heading from an instance + instance = cls() + instance.heading # trigger table declaration + instance.prepare() + return cls + + @property + def jobs(self): + return get_jobs(self.connection, self.database) + From a54af9840f842894802d153b80c0f6de9f469c53 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 10 Jul 2015 11:06:22 -0500 Subject: [PATCH 28/34] schema now has the attribute `jobs` containing the job reservation table --- datajoint/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datajoint/schema.py b/datajoint/schema.py index 3a6e64025..da12ec8eb 100644 --- a/datajoint/schema.py +++ b/datajoint/schema.py @@ -58,5 +58,5 @@ def __call__(self, cls): @property def jobs(self): - return get_jobs(self.connection, self.database) + return self.connection.jobs[self.database] From 9169efc9b6f820de0d70804c301dfb5ea4466184 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Fri, 10 Jul 2015 11:06:11 -0500 Subject: [PATCH 29/34] tuples iterator --- datajoint/autopopulate.py | 27 ++++----------------------- datajoint/relational_operand.py | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 54549d46d..5e2e528f8 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -11,25 +11,6 @@ logger = logging.getLogger(__name__) -def batches(iter, n=1): - iter_len = len(iter) - for start in range(0, iter_len, n): - yield iter[start:min(start + n, iter_len)] - - -def reserved_batches(rel, reserve_jobs, full_table_name, n=1): - ret = [] - for key in rel: - if jobs.reserve(reserve_jobs, full_table_name, key): - ret.append(key) - if len(ret) == n: - yield ret - ret = [] - else: - if len(ret) > 0: - yield ret - - class AutoPopulate(metaclass=abc.ABCMeta): """ AutoPopulate is a mixin class that adds the method populate() to a Relation class. @@ -67,7 +48,7 @@ def _make_tuples(self, key): def target(self): return self - def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, batch_size=1): + def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, batch=1): """ rel.populate() calls rel._make_tuples(key) for every primary key in self.populated_from for which there is not already a tuple in rel. 
@@ -75,7 +56,7 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, :param restriction: restriction on rel.populated_from - target :param suppress_errors: suppresses error if true :param reserve_jobs: currently not implemented - :param batch_size: batch size of a single job + :param batch: batch size of a single job """ error_list = [] if suppress_errors else None if not isinstance(self.populated_from, RelationalOperand): @@ -86,8 +67,8 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, full_table_name = self.target.full_table_name unpopulated = (self.populated_from - self.target) & restriction - for batch in reserved_batches(unpopulated.project(), reserve_jobs, full_table_name, n=batch_size): - for key in batch: + for key in unpopulated.project(): + if jobs.reserve(reserve_jobs, full_table_name, key): self.connection.start_transaction() if key in self.target: # already populated self.connection.cancel_transaction() diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index 303ffd39d..d634c0691 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -246,6 +246,20 @@ def __iter__(self): for field_name, up, value in zip(heading.names, do_unpack, values)} values = cur.fetchone() + @property + def tuples(self): + """ + Iterator that returns the contents of the database as tuples. + """ + cur = self.cursor() + heading = self.heading # construct once for efficiency + do_unpack = tuple(h in heading.blobs for h in heading.names) + values = cur.fetchone() + while values: + yield tuple(unpack(value) if up else value for up, value in zip(do_unpack, values)) + values = cur.fetchone() + + @property def where_clause(self): """ From a8ccb785cd9e84514549fc0e710457466de1b461 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Fri, 10 Jul 2015 12:09:55 -0500 Subject: [PATCH 30/34] introduced new iterators values() and keys() --- datajoint/relational_operand.py | 104 +++++++++++++++++--------------- 1 file changed, 54 insertions(+), 50 deletions(-) diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index d634c0691..38c1d8ea8 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -162,6 +162,41 @@ def __call__(self, *args, **kwargs): """ return self.fetch(*args, **kwargs) + def cursor(self, offset=0, limit=None, order_by=None, descending=False, as_dict=False): + """ + Return query cursor. + See Relation.fetch() for input description. 
+ :return: cursor to the query + """ + if offset and limit is None: + raise DataJointError('limit is required when offset is set') + sql = self.make_select() + if order_by is not None: + sql += ' ORDER BY ' + ', '.join(order_by) + if descending: + sql += ' DESC' + if limit is not None: + sql += ' LIMIT %d' % limit + if offset: + sql += ' OFFSET %d' % offset + logger.debug(sql) + return self.connection.query(sql, as_dict=as_dict) + + def __repr__(self): + limit = config['display.limit'] + width = config['display.width'] + rel = self.project(*self.heading.non_blobs) # project out blobs + template = '%%-%d.%ds' % (width, width) + columns = rel.heading.names + repr_string = ' '.join([template % column for column in columns]) + '\n' + repr_string += ' '.join(['+' + '-' * (width - 2) + '+' for _ in columns]) + '\n' + for tup in rel.fetch(limit=limit): + repr_string += ' '.join([template % column for column in tup]) + '\n' + if len(self) > limit: + repr_string += '...\n' + repr_string += ' (%d tuples)\n' % len(self) + return repr_string + def fetch1(self): """ This version of fetch is called when self is expected to contain exactly one tuple. @@ -198,67 +233,36 @@ def fetch(self, offset=0, limit=None, order_by=None, descending=False, as_dict=F ret[blob_name] = list(map(unpack, ret[blob_name])) return ret - def cursor(self, offset=0, limit=None, order_by=None, descending=False, as_dict=False): + def values(self, offset=0, limit=None, order_by=None, descending=False, as_dict=False): """ - Return query cursor. - See Relation.fetch() for input description. - :return: cursor to the query + Iterator that returns the contents of the database. """ - if offset and limit is None: - raise DataJointError('limit is required when offset is set') - sql = self.make_select() - if order_by is not None: - sql += ' ORDER BY ' + ', '.join(order_by) - if descending: - sql += ' DESC' - if limit is not None: - sql += ' LIMIT %d' % limit - if offset: - sql += ' OFFSET %d' % offset - logger.debug(sql) - return self.connection.query(sql, as_dict=as_dict) + cur = self.cursor(offset=offset, limit=limit, order_by=order_by, + descending=descending, as_dict=as_dict) + heading = self.heading + do_unpack = tuple(h in heading.blobs for h in heading.names) + values = cur.fetchone() + while values: + if as_dict: + yield OrderedDict((field_name, unpack(values[field_name])) if up else (field_name, values[field_name]) + for field_name, up in zip(heading.names, do_unpack)) - def __repr__(self): - limit = config['display.limit'] - width = config['display.width'] - rel = self.project(*self.heading.non_blobs) # project out blobs - template = '%%-%d.%ds' % (width, width) - columns = rel.heading.names - repr_string = ' '.join([template % column for column in columns]) + '\n' - repr_string += ' '.join(['+' + '-' * (width - 2) + '+' for _ in columns]) + '\n' - for tup in rel.fetch(limit=limit): - repr_string += ' '.join([template % column for column in tup]) + '\n' - if len(self) > limit: - repr_string += '...\n' - repr_string += ' (%d tuples)\n' % len(self) - return repr_string + else: + yield tuple(unpack(value) if up else value for up, value in zip(do_unpack, values)) + + values = cur.fetchone() def __iter__(self): """ Iterator that yields individual tuples of the current table dictionaries. 
""" - cur = self.cursor() - heading = self.heading # construct once for efficiency - do_unpack = tuple(h in heading.blobs for h in heading.names) - values = cur.fetchone() - while values: - yield {field_name: unpack(value) if up else value - for field_name, up, value in zip(heading.names, do_unpack, values)} - values = cur.fetchone() + yield from self.values(as_dict=True) - @property - def tuples(self): + def keys(self, *args, **kwargs): """ - Iterator that returns the contents of the database as tuples. + Iterator that returns primary keys. """ - cur = self.cursor() - heading = self.heading # construct once for efficiency - do_unpack = tuple(h in heading.blobs for h in heading.names) - values = cur.fetchone() - while values: - yield tuple(unpack(value) if up else value for up, value in zip(do_unpack, values)) - values = cur.fetchone() - + yield from self.project().values(*args, **kwargs) @property def where_clause(self): From 67f311b38df10075f266bb81e78aac54b10c097e Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 10 Jul 2015 12:56:07 -0500 Subject: [PATCH 31/34] bugfix in relational algebra. Projections with renames can now be restricted. Also, extracted AttrTuple into the standalone class Attribute --- datajoint/heading.py | 67 +++++++++++++++++---------------- datajoint/jobs.py | 8 +--- datajoint/relational_operand.py | 17 ++++++++- datajoint/schema.py | 2 +- 4 files changed, 53 insertions(+), 41 deletions(-) diff --git a/datajoint/heading.py b/datajoint/heading.py index 5c1ca0f76..85d54ee3e 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -5,48 +5,49 @@ from datajoint import DataJointError +class Attribute(namedtuple('Attribute', + ('name', 'type', 'in_key', 'nullable', 'default', + 'comment', 'autoincrement', 'numeric', 'string', 'is_blob', + 'computation', 'dtype'))): + def _asdict(self): + """ + for some reason the inherted _asdict does not work after subclassing from namedtuple + """ + return OrderedDict((name, self[i]) for i, name in enumerate(self._fields)) + + def sql(self): + """ + Convert attribute tuple into its SQL CREATE TABLE clause. + :rtype : SQL code + """ + literals = ['CURRENT_TIMESTAMP'] + if self.nullable: + default = 'DEFAULT NULL' + else: + default = 'NOT NULL' + if self.default: + # enclose value in quotes except special SQL values or already enclosed + quote = self.default.upper() not in literals and self.default[0] not in '"\'' + default += ' DEFAULT ' + ('"%s"' if quote else "%s") % self.default + if any((c in r'\"' for c in self.comment)): + raise DataJointError('Illegal characters in attribute comment "%s"' % self.comment) + return '`{name}` {type} {default} COMMENT "{comment}"'.format( + name=self.name, type=self.type, default=default, comment=self.comment) + + class Heading: """ Local class for relations' headings. Heading contains the property attributes, which is an OrderedDict in which the keys are - the attribute names and the values are AttrTuples. + the attribute names and the values are Attributes. """ - class AttrTuple(namedtuple('AttrTuple', - ('name', 'type', 'in_key', 'nullable', 'default', - 'comment', 'autoincrement', 'numeric', 'string', 'is_blob', - 'computation', 'dtype'))): - def _asdict(self): - """ - for some reason the inherted _asdict does not work after subclassing from namedtuple - """ - return OrderedDict((name, self[i]) for i, name in enumerate(self._fields)) - - def sql(self): - """ - Convert attribute tuple into its SQL CREATE TABLE clause. 
-            :rtype : SQL code
-            """
-            literals = ['CURRENT_TIMESTAMP']
-            if self.nullable:
-                default = 'DEFAULT NULL'
-            else:
-                default = 'NOT NULL'
-                if self.default:
-                    # enclose value in quotes except special SQL values or already enclosed
-                    quote = self.default.upper() not in literals and self.default[0] not in '"\''
-                    default += ' DEFAULT ' + ('"%s"' if quote else "%s") % self.default
-            if any((c in r'\"' for c in self.comment)):
-                raise DataJointError('Illegal characters in attribute comment "%s"' % self.comment)
-            return '`{name}` {type} {default} COMMENT "{comment}"'.format(
-                name=self.name, type=self.type, default=default, comment=self.comment)
-
     def __init__(self, attributes=None):
         """
-        :param attributes: a list of dicts with the same keys as AttrTuple
+        :param attributes: a list of dicts with the same keys as Attribute
         """
         if attributes:
-            attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes])
+            attributes = OrderedDict([(q['name'], Attribute(**q)) for q in attributes])
         self.attributes = attributes
 
     def reset(self):
@@ -200,7 +201,7 @@ def init_from_database(self, conn, database, table_name):
             t = re.sub(r' unsigned$', '', t)  # remove unsigned
             assert (t, is_unsigned) in numeric_types, 'dtype not found for type %s' % t
             attr['dtype'] = numeric_types[(t, is_unsigned)]
-        self.attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes])
+        self.attributes = OrderedDict([(q['name'], Attribute(**q)) for q in attributes])
 
     def project(self, *attribute_list, **renamed_attributes):
         """

diff --git a/datajoint/jobs.py b/datajoint/jobs.py
index 921a26af1..ec2f3d797 100644
--- a/datajoint/jobs.py
+++ b/datajoint/jobs.py
@@ -68,24 +68,20 @@ def reserve(self, table_name, key):
         else:
             return True
 
-
     def complete(self, table_name, key):
         """
         Log a completed job. When a job is completed, its reservation entry is deleted.
-        :param reserve_jobs: if True, use job reservation
-        :param full_table_name: `database`.`table_name`
+        :param table_name: `database`.`table_name`
         :param key: the dict of the job's primary key
         """
         job_key = dict(table_name=table_name, key_hash=key_hash(key))
         (self & job_key).delete_quick()
 
-
    def error(self, table_name, key, error_message):
        """
        Log an error message. The job reservation is replaced with an error entry.
        if an error occurs, leave an entry describing the problem
-        :param reserve_jobs: if True, use job reservation
-        :param full_table_name: `database`.`table_name`
+        :param table_name: `database`.`table_name`
         :param key: the dict of the job's primary key
         :param error_message: string error message
         """

diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py
index 303ffd39d..e94e3c6b2 100644
--- a/datajoint/relational_operand.py
+++ b/datajoint/relational_operand.py
@@ -104,7 +104,7 @@ def aggregate(self, _group, *attributes, **renamed_attributes):
             raise DataJointError('The second argument must be a relation or None')
         return Projection(self, _group, *attributes, **renamed_attributes)
 
-    def __iand__(self, restriction):
+    def _restrict(self, restriction):
         """
         in-place relational restriction or semijoin
         """
@@ -114,6 +114,12 @@
         self._restrictions.append(restriction)
         return self
 
+    def __iand__(self, restriction):
+        """
+        in-place relational restriction or semijoin
+        """
+        return self._restrict(restriction)
+
     def __and__(self, restriction):
         """
         relational restriction or semijoin
@@ -373,6 +379,15 @@ def from_clause(self):
             self._arg.from_clause, self._group.from_clause,
             '`,`'.join(self.heading.primary_key))
 
+    def _restrict(self, restriction):
+        """
+        Projection is enclosed in Subquery when restricted if it has renamed attributes
+        """
+        if self.heading.computed:
+            return Subquery(self) & restriction
+        else:
+            return super()._restrict(restriction)
+
 
 class Subquery(RelationalOperand):
     """

diff --git a/datajoint/schema.py b/datajoint/schema.py
index da12ec8eb..11e417fc8 100644
--- a/datajoint/schema.py
+++ b/datajoint/schema.py
@@ -52,7 +52,7 @@ def __call__(self, cls):
 
         # trigger table declaration by requesting the heading from an instance
         instance = cls()
-        instance.heading  # trigger table declaration
+        instance.heading
         instance.prepare()
         return cls

From 2ece673bbd060de80c51fc8df0c880aff7218f6e Mon Sep 17 00:00:00 2001
From: Dimitri Yatsenko
Date: Fri, 10 Jul 2015 13:09:52 -0500
Subject: [PATCH 32/34] minor

---
 datajoint/autopopulate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py
index 9b44c71ec..62b7754d1 100644
--- a/datajoint/autopopulate.py
+++ b/datajoint/autopopulate.py
@@ -48,7 +48,7 @@ def _make_tuples(self, key):
     def target(self):
         return self
 
-    def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False, batch=1):
+    def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False):
        """
        rel.populate() calls rel._make_tuples(key) for every primary key in self.populated_from
        for which there is not already a tuple in rel.
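
The relational_operand.py hunks above settle on a single iteration API: fetch() for arrays, values() for lazy iteration, __iter__ for dicts, and keys() for primary keys. A minimal sketch of how these compose, not part of the patches themselves, assuming the Subject relation defined in tests/schema.py later in this series::

    subjects = Subject()

    # fetch() accepts limit/offset/order_by; offset without a limit raises DataJointError
    rows = subjects.fetch(limit=5, offset=10, order_by=['subject_id'])

    for row in subjects:          # __iter__ delegates to values(as_dict=True)
        print(row['subject_id'])  # one dict per tuple, blob attributes unpacked

    for key in subjects.keys():   # keys() projects to the primary key first
        print(key)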
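The standalone Attribute class extracted in patch 31 renders one column of a CREATE TABLE statement. A sketch of what sql() produces; the field values are hypothetical, chosen only to fill the namedtuple::

    attr = Attribute(name='species', type='varchar(24)', in_key=False, nullable=False,
                     default='mouse', comment='animal species', autoincrement=False,
                     numeric=False, string=True, is_blob=False, computation=None,
                     dtype=object)
    print(attr.sql())
    # `species` varchar(24) NOT NULL DEFAULT "mouse" COMMENT "animal species"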
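Projection._restrict in patch 31 is what lets a projection with renamed attributes be restricted by the new names: the projection is wrapped in a Subquery, so the SQL alias becomes visible to the WHERE clause. A sketch against the Experiment relation from tests/schema.py later in the series; the rename is arbitrary::

    rel = Experiment().project('username', date='experiment_date')  # rename makes the heading computed
    recent = rel & 'date > "2015-01-01"'   # restriction is now applied to the subquery
    print(len(recent))
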
From b01bedef8959230e6a334e65bf00f0887b197949 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 10 Jul 2015 16:42:53 -0500 Subject: [PATCH 33/34] bugfix in relational algebra (semijoin) --- datajoint/relational_operand.py | 2 +- tests/schema.py | 24 +++++++++++------------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index 59961c7a9..a8f633c28 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -297,7 +297,7 @@ def make_condition(arg): elif isinstance(r, np.ndarray) or isinstance(r, list): r = '(' + ') OR ('.join([make_condition(q) for q in r]) + ')' elif isinstance(r, RelationalOperand): - common_attributes = ','.join([q for q in self.heading.names if r.heading.names]) + common_attributes = ','.join([q for q in self.heading.names if q in r.heading.names]) r = '(%s) in (SELECT %s FROM %s%s)' % ( common_attributes, common_attributes, r.from_clause, r.where_clause) diff --git a/tests/schema.py b/tests/schema.py index 3bac2dca0..f80acdb13 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -14,10 +14,10 @@ class User(dj.Manual): username: varchar(12) """ - def fill(self): - names = [['Jake'], ['Cathryn'], ['Shan'], ['Fabian'], ['Edgar'], ['George'], ['Dimitri']] - self.insert(names) - return names + def prepare(self): + self.insert( + [['Jake'], ['Cathryn'], ['Shan'], ['Fabian'], ['Edgar'], ['George'], ['Dimitri']], + ignore_errors=True) @schema class Subject(dj.Manual): @@ -31,15 +31,13 @@ class Subject(dj.Manual): unique index (real_id, species) """ - def fill(self): - data = [ - [1551, '', 'mouse', '2015-04-01', 'genetically engineered super mouse'], - [1, 'Curious George', 'monkey', '2008-06-30', ''], - [1552, '', 'mouse', '2015-06-15', ''], - [1553, '', 'mouse', '2016-07-01', ''] - ] - self.insert(data) - return data + def prepare(self): + self.insert([ + [1551, '', 'mouse', '2015-04-01', 'genetically engineered super mouse'], + [1, 'Curious George', 'monkey', '2008-06-30', ''], + [1552, '', 'mouse', '2015-06-15', ''], + [1553, '', 'mouse', '2016-07-01', ''] + ], ignore_errors=True) @schema class Experiment(dj.Imported): From c79fcc1e99f7cfbd760ad6a35972f7adb50409ed Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Wed, 15 Jul 2015 17:35:03 -0500 Subject: [PATCH 34/34] updated nose tests --- datajoint/blob.py | 3 + datajoint/relation.py | 4 +- tests/schema.py | 27 +-- tests/test_declare.py | 62 ++++++- tests/test_relation.py | 388 ++++------------------------------------- 5 files changed, 108 insertions(+), 376 deletions(-) diff --git a/datajoint/blob.py b/datajoint/blob.py index e002be24e..d513532d8 100644 --- a/datajoint/blob.py +++ b/datajoint/blob.py @@ -72,6 +72,9 @@ def unpack(blob): :rtype: numpy.ndarray """ # decompress if necessary + if blob is None: + return None + if blob[0:5] == b'ZL123': blob_length = np.fromstring(blob[6:14], dtype=np.uint64)[0] blob = zlib.decompress(blob[14:]) diff --git a/datajoint/relation.py b/datajoint/relation.py index 551e86690..140b2858d 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -119,13 +119,13 @@ def insert(self, rows, **kwargs): for row in rows: self.insert1(row, **kwargs) - def insert1(self, tup, ignore_errors=False, replace=False): + def insert1(self, tup, replace=False, ignore_errors=False): """ Insert one data record or one Mapping (like a dict). :param tup: Data record, a Mapping (like a dict), or a list or tuple with ordered values. 
- :param ignore_errors=False: Ignores errors if True. :param replace=False: Replaces data tuple if True. + :param ignore_errors=False: If True, ignore errors: e.g. constraint violations or duplicates Example:: relation.insert1(dict(subject_id=7, species="mouse", date_of_birth="2014-09-01")) diff --git a/tests/schema.py b/tests/schema.py index f80acdb13..981e205eb 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -8,16 +8,14 @@ schema = dj.schema(PREFIX + '_test1', locals(), connection=dj.conn(**CONN_INFO)) + @schema -class User(dj.Manual): +class User(dj.Lookup): definition = """ # lab members username: varchar(12) """ + contents = [['Jake'], ['Cathryn'], ['Shan'], ['Fabian'], ['Edgar'], ['George'], ['Dimitri']] - def prepare(self): - self.insert( - [['Jake'], ['Cathryn'], ['Shan'], ['Fabian'], ['Edgar'], ['George'], ['Dimitri']], - ignore_errors=True) @schema class Subject(dj.Manual): @@ -31,13 +29,15 @@ class Subject(dj.Manual): unique index (real_id, species) """ + contents = [ + [1551, '1551', 'mouse', '2015-04-01', 'genetically engineered super mouse'], + [10, 'Curious George', 'monkey', '2008-06-30', ''], + [1552, '1552', 'mouse', '2015-06-15', ''], + [1553, '1553', 'mouse', '2016-07-01', '']] + def prepare(self): - self.insert([ - [1551, '', 'mouse', '2015-04-01', 'genetically engineered super mouse'], - [1, 'Curious George', 'monkey', '2008-06-30', ''], - [1552, '', 'mouse', '2015-06-15', ''], - [1553, '', 'mouse', '2016-07-01', ''] - ], ignore_errors=True) + self.insert(self.contents, ignore_errors=True) + @schema class Experiment(dj.Imported): @@ -66,6 +66,7 @@ def _make_tuples(self, key): experiment_date=(date.today()-timedelta(random.expovariate(1/30))).isoformat(), username=random.choice(users))) + @schema class Trial(dj.Imported): definition = """ # a trial within an experiment @@ -86,6 +87,7 @@ def _make_tuples(self, key): start_time=random.random()*1e9 )) + @schema class Ephys(dj.Imported): definition = """ # some kind of electrophysiological recording @@ -103,7 +105,8 @@ def _make_tuples(self, key): sampling_frequency=16000, duration=random.expovariate(1/30)) self.insert1(row) - EphysChannel().fill(number_samples=round(row.duration*row.sampling_frequency)) + EphysChannel().fill(key, number_samples=round(row.duration*row.sampling_frequency)) + @schema class EphysChannel(dj.Subordinate, dj.Imported): diff --git a/tests/test_declare.py b/tests/test_declare.py index 6b080a45f..e15b49180 100644 --- a/tests/test_declare.py +++ b/tests/test_declare.py @@ -1,14 +1,60 @@ -from nose.tools import assert_true, assert_raises, assert_equal, raises +from nose.tools import assert_true, assert_false, assert_equal, assert_list_equal from . 
import schema + class TestDeclare: - """ - Tests declaration, heading, dependencies, and drop - """ + def __init__(self): + self.user = schema.User() + self.subject = schema.Subject() + self.experiment = schema.Experiment() + self.trial = schema.Trial() + self.ephys = schema.Ephys() + self.channel = schema.EphysChannel() def test_attributes(self): - assert_equal(schema.Subjects().heading.names, ['subject_id', 'real_id', 'species']) - assert_equal(schema.Subjects().primary_key, ['subject_id']) - assert_equal(schema.Experiments().heading.names, ['subject_id', 'animal_doob']) - assert_equal(schema.Subjects().primary_key, ['subject_id']) + assert_list_equal(self.subject.heading.names, + ['subject_id', 'real_id', 'species', 'date_of_birth', 'subject_notes']) + assert_list_equal(self.subject.primary_key, + ['subject_id']) + assert_true(self.subject.heading.attributes['subject_id'].numeric) + assert_false(self.subject.heading.attributes['real_id'].numeric) + + experiment = schema.Experiment() + assert_list_equal(experiment.heading.names, + ['subject_id', 'experiment_id', 'experiment_date', + 'username', 'data_path', + 'notes', 'entry_time']) + assert_list_equal(experiment.primary_key, + ['subject_id', 'experiment_id']) + + assert_list_equal(self.trial.heading.names, + ['subject_id', 'experiment_id', 'trial_id', 'start_time']) + assert_list_equal(self.trial.primary_key, + ['subject_id', 'experiment_id', 'trial_id']) + + assert_list_equal(self.ephys.heading.names, + ['subject_id', 'experiment_id', 'trial_id', 'sampling_frequency', 'duration']) + assert_list_equal(self.ephys.primary_key, + ['subject_id', 'experiment_id', 'trial_id']) + + assert_list_equal(self.channel.heading.names, + ['subject_id', 'experiment_id', 'trial_id', 'channel', 'voltage']) + assert_list_equal(self.channel.primary_key, + ['subject_id', 'experiment_id', 'trial_id', 'channel']) + assert_true(self.channel.heading.attributes['voltage'].is_blob) + + def test_dependencies(self): + assert_equal(self.user.references, [self.experiment.full_table_name]) + assert_equal(self.experiment.referenced, [self.user.full_table_name]) + + assert_equal(self.subject.children, [self.experiment.full_table_name]) + assert_equal(self.experiment.parents, [self.subject.full_table_name]) + + assert_equal(self.experiment.children, [self.trial.full_table_name]) + assert_equal(self.trial.parents, [self.experiment.full_table_name]) + + assert_equal(self.trial.children, [self.ephys.full_table_name]) + assert_equal(self.ephys.parents, [self.trial.full_table_name]) + assert_equal(self.ephys.children, [self.channel.full_table_name]) + assert_equal(self.channel.parents, [self.ephys.full_table_name]) diff --git a/tests/test_relation.py b/tests/test_relation.py index 11fe25298..d4fa37ddb 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -2,370 +2,50 @@ import string from numpy.testing import assert_array_equal import numpy as np -from nose.tools import assert_raises, assert_equal, assert_regexp_matches, \ +from nose.tools import assert_raises, assert_equal, \ assert_false, assert_true, assert_list_equal, \ assert_tuple_equal, assert_dict_equal, raises -from datajoint import DataJointError -from .schema import Subjects, Animals, Matrix, Trials, SquaredScore, SquaredSubtable, \ - ErrorGenerator, schema - -import datajoint as dj +from . 
import schema class TestRelation: """ - Test creation of relations, their dependencies + Test base relations: insert, delete """ def __init__(self): + self.user = schema.User() + self.subject = schema.Subject() + self.experiment = schema.Experiment() + self.trial = schema.Trial() + self.ephys = schema.Ephys() + self.channel = schema.EphysChannel() + + def test_contents(self): """ - Create a connection object and prepare test modules - as follows: - test1 - has conn and bounded + test the ability of tables to self-populate using the contents property """ - self.subjects = Subjects() - self.animals = Animals() - self.matrix = Matrix() - self.trials = Trials() - self.score = SquaredScore() - self.subtable = SquaredSubtable() - - def setup(self): - pass - - def teardown(self): - self.subjects.delete() - - def test_table_name_manual(self): - assert_true(not self.subjects.table_name.startswith('#') and - not self.subjects.table_name.startswith('_') and not self.subjects.table_name.startswith('__')) - - def test_table_name_computed(self): - assert_true(self.score.table_name.startswith('__')) - assert_true(self.subtable.table_name.startswith('__')) - - def test_population_relation_subordinate(self): - assert_true(self.subtable.populated_from is None) - - @raises(NotImplementedError) - def test_make_tubles_not_implemented_subordinate(self): - self.subtable._make_tuples(None) - - def test_instantiate_relation(self): - s = Subjects() - - def test_compound_restriction(self): - s = self.subjects - t = self.trials - - s.insert1(dict(subject_id=1, real_id='M')) - s.insert1(dict(subject_id=2, real_id='F')) - - # insert trials - n_trials = 20 - for subject_id in [1, 2]: - for trial_id in range(n_trials): - t.insert1( - trial_id=trial_id, - subject_id=subject_id, - outcome=int(np.random.randint(10)), - notes='no comment') - - tM = t & (s & "real_id = 'M'") - t1 = t & "subject_id = 1" - - assert_equal(len(tM), len(t1), "Results of compound request does not have same length") - - for t1_item, tM_item in zip(sorted(t1, key=lambda item: item['trial_id']), - sorted(tM, key=lambda item: item['trial_id'])): - assert_dict_equal(t1_item, tM_item, - 'Dictionary elements do not agree in compound statement') - - def test_record_insert(self): - "Test whether record insert works" - tmp = np.array([(2, 'Klara', 'monkey')], - dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) - - self.subjects.insert1(tmp[0]) - testt2 = (self.subjects & 'subject_id = 2').fetch()[0] - assert_equal(tuple(tmp[0]), tuple(testt2), "Inserted and fetched record do not match!") - - def test_delete(self): - "Test whether delete works" - tmp = np.array([(2, 'Klara', 'monkey'), (1, 'Peter', 'mouse')], - dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) - - self.subjects.batch_insert(tmp) - assert_true(len(self.subjects) == 2, 'Length does not match 2.') - self.subjects.delete() - assert_true(len(self.subjects) == 0, 'Length does not match 0.') - - # - # # def test_cascading_delete(self): - # # "Test whether delete works" - # # tmp = np.array([(2, 'Klara', 'monkey'), (1,'Peter', 'mouse')], - # # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) - # # - # # self.subjects.batch_insert(tmp) - # # - # # self.trials.insert(dict(subject_id=1, trial_id=1, outcome=0)) - # # self.trials.insert(dict(subject_id=1, trial_id=2, outcome=1)) - # # self.trials.insert(dict(subject_id=2, trial_id=3, outcome=2)) - # # assert_true(len(self.subjects) == 2, 'Length does not match 2.') - # # assert_true(len(self.trials) == 
3, 'Length does not match 3.') - # # (self.subjects & 'subject_id=1').delete() - # # assert_true(len(self.subjects) == 1, 'Length does not match 1.') - # # assert_true(len(self.trials) == 1, 'Length does not match 1.') - # - # def test_short_hand_foreign_reference(self): - # self.animals.heading - # - # - # - def test_record_insert_different_order(self): - "Test whether record insert works" - tmp = np.array([('Klara', 2, 'monkey')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - - self.subjects.insert1(tmp[0]) - testt2 = (self.subjects & 'subject_id = 2').fetch()[0] - assert_equal((2, 'Klara', 'monkey'), tuple(testt2), - "Inserted and fetched record do not match!") - - @raises(KeyError) - def test_wrong_key_insert_records(self): - "Test whether record insert works" - tmp = np.array([('Klara', 2, 'monkey')], - dtype=[('real_deal', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - - self.subjects.insert1(tmp[0]) - - def test_dict_insert(self): - "Test whether record insert works" - tmp = {'real_id': 'Brunhilda', - 'subject_id': 3, - 'species': 'human'} - - self.subjects.insert(tmp) - testt2 = (self.subjects & 'subject_id = 3').fetch()[0] - assert_equal((3, 'Brunhilda', 'human'), tuple(testt2), "Inserted and fetched record do not match!") - - @raises(KeyError) - def test_wrong_key_insert(self): - "Test whether a correct error is generated when inserting wrong attribute name" - tmp = {'real_deal': 'Brunhilda', - 'subject_database': 3, - 'species': 'human'} - - self.subjects.insert1(tmp) - - # - def test_batch_insert(self): - "Test whether record insert works" - tmp = np.array([('Klara', 2, 'monkey'), ('Brunhilda', 3, 'mouse'), ('Mickey', 1, 'human')], - dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) - - self.subjects.batch_insert(tmp) - - expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'), - (3, 'Brunhilda', 'mouse')], - dtype=[('subject_id', 'i4'), ('species', 'O')]) - - self.subjects.iter_insert(tmp.__iter__()) - - expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'), - (3, 'Brunhilda', 'mouse')], - dtype=[('subject_id', 'i4'), ('species', 'O')]) - self.subjects.batch_insert(tmp) - - for trial_id in range(1, 11): - self.trials.insert1(dict(subject_id=2, trial_id=trial_id, outcome=np.random.randint(0, 10))) - - def teardown(self): - pass - - def test_autopopulate(self): - self.squared.populate() - assert_equal(len(self.squared), 10) - - for trial in self.trials * self.squared: - assert_equal(trial['outcome'] ** 2, trial['squared']) - - def test_autopopulate_restriction(self): - self.squared.populate(restriction='trial_id <= 5') - assert_equal(len(self.squared), 5) - - for trial in self.trials * self.squared: - assert_equal(trial['outcome'] ** 2, trial['squared']) - - @raises(DataJointError) - def test_autopopulate_relation_check(self): - @schema - class Dummy(dj.Computed): - def populated_from(self): - return None - - def _make_tuples(self, key): - pass - - du = Dummy() - du.populate() - - @raises(DataJointError) - def test_autopopulate_relation_check(self): - self.dummy1.populate() - - @raises(Exception) - def test_autopopulate_relation_check(self): - self.error_generator.populate() - @raises(Exception) - def test_autopopulate_relation_check2(self): - tmp = self.dummy2.populate(suppress_errors=True) - assert_equal(len(tmp), 1, 'Error list should have length 1.') + # test contents + assert_true(self.user) + assert_true(len(self.user) == len(self.user.contents)) + u = self.user.fetch(order_by=['username']) + 
assert_list_equal(list(u['username']), sorted([s[0] for s in self.user.contents]))
+
+        # test prepare
+        assert_true(self.subject)
+        assert_true(len(self.subject) == len(self.subject.contents))
+        u = self.subject.fetch(order_by=['subject_id'])
+        assert_list_equal(list(u['subject_id']), sorted([s[0] for s in self.subject.contents]))
+
+    def test_delete_quick(self):
+        tmp = np.array([
+            (2, 'Klara', 'monkey', '2010-01-01', ''),
+            (1, 'Peter', 'mouse', '2015-01-01', '')],
+            dtype=self.subject.heading.as_dtype)
+        self.subject.insert(tmp)
+        s = self.subject & ('subject_id in (%s)' % ','.join(str(r) for r in tmp['subject_id']))
+        assert_true(len(s) == 2, 'insert did not work.')
+        s.delete_quick()
+        assert_true(len(s) == 0, 'delete did not work.')
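
The one-character semijoin fix in patch 33 (`q in r.heading.names`) makes restriction by another relation compare only the attributes common to both headings; before the fix, the condition `r.heading.names` was always true, so every attribute of the left operand leaked into the IN clause. A sketch with the test schema::

    # subjects that have at least one experiment: semijoin on the shared subject_id
    tested = Subject() & Experiment()
    # generated condition (sketch):
    #   (subject_id) in (SELECT subject_id FROM ...experiment...)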
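The blob.py change in patch 34 makes unpack a pass-through on NULL values, so nullable blob attributes fetch as None instead of failing on the blob[0:5] check::

    from datajoint.blob import unpack
    assert unpack(None) is None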
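Patch 34 also turns User into a dj.Lookup with a contents property, and test_contents above asserts that such tables self-populate from it. A sketch of the idiom with a hypothetical lookup table, assuming, as the tests do, that contents is inserted when the table is declared::

    @schema
    class Species(dj.Lookup):   # hypothetical lookup table
        definition = """  # species used in the lab
        species: varchar(24)
        """
        contents = [['mouse'], ['monkey'], ['human']]

    assert len(Species()) == len(Species.contents)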