From 4f1b12b289bf0ee24066443321f172184050e04e Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Wed, 20 May 2015 09:00:27 -0500 Subject: [PATCH 1/2] inserted decorator to avoid implicit commits in transaction --- datajoint/__init__.py | 7 +++++++ datajoint/connection.py | 2 ++ datajoint/decorators.py | 22 ++++++++++++++++++++++ datajoint/free_relation.py | 3 +++ requirements.txt | 1 + setup.py | 2 +- tests/test_relation.py | 25 +++++++++++++++++++++++-- 7 files changed, 59 insertions(+), 3 deletions(-) create mode 100644 datajoint/decorators.py diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 38cef2aa6..cddc141bb 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -15,6 +15,13 @@ class DataJointError(Exception): pass +class TransactionError(DataJointError): + """ + Base class for errors specific to DataJoint internal operation. + """ + pass + + # ----------- loads local configuration from file ---------------- from .settings import Config, CONFIGVAR, LOCALCONFIG, logger, log_levels config = Config() diff --git a/datajoint/connection.py b/datajoint/connection.py index 64607de5b..b27bc6287 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -78,10 +78,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is None and exc_val is None and exc_tb is None: self.conn._commit_transaction() self.conn._in_transaction = False + return True else: self.conn._cancel_transaction() self.conn._in_transaction = False logger.debug("Transaction cancled because of an error.", exc_info=(exc_type, exc_val, exc_tb)) + return False class Connection(object): diff --git a/datajoint/decorators.py b/datajoint/decorators.py new file mode 100644 index 000000000..44837baad --- /dev/null +++ b/datajoint/decorators.py @@ -0,0 +1,22 @@ +from decorator import decorator +from . import DataJointError, TransactionError + + +def _not_in_transaction(f, *args, **kwargs): + if not hasattr(args[0], '_conn'): + raise DataJointError(u"{0:s} does not have a member called _conn".format(args[0].__class__.__name__, )) + if not hasattr(args[0]._conn, 'in_transaction'): + raise DataJointError( + u"{0:s}._conn does not have a property in_transaction".format(args[0].__class__.__name__, )) + if args[0]._conn.in_transaction: + raise TransactionError( + u"{0:s} is currently in transaction. Operation not allowed to avoid implicit commits.".format( + args[0].__class__.__name__)) + return f(*args, **kwargs) + + +def not_in_transaction(f): + """ + This decorator raises an error if the function is called during a transaction. + """ + return decorator(_not_in_transaction, f) diff --git a/datajoint/free_relation.py b/datajoint/free_relation.py index fe08b3f51..b6c9d04e2 100644 --- a/datajoint/free_relation.py +++ b/datajoint/free_relation.py @@ -2,6 +2,7 @@ import numpy as np import logging from . import DataJointError, config +from .decorators import not_in_transaction from .relational_operand import RelationalOperand from .blob import pack from .heading import Heading @@ -296,6 +297,7 @@ def tablelist(tier): plt.show() """ + @not_in_transaction def _alter(self, alter_statement): """ Execute ALTER TABLE statement for this table. The schema @@ -343,6 +345,7 @@ def ref_name(self): """ return '`{0}`'.format(self.dbname) + '.' + self.class_name + @not_in_transaction def _declare(self): """ Declares the table in the database if no table in the database matches this object. diff --git a/requirements.txt b/requirements.txt index 82abfa7ea..1e85cd6e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ networkx matplotlib sphinx_rtd_theme mock +decorator diff --git a/setup.py b/setup.py index 5c56b483c..5320801b0 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,6 @@ description='An object-relational mapping and relational algebra to facilitate data definition and data manipulation in MySQL databases.', url='https://github.com/datajoint/datajoint-python', packages=['datajoint'], - requires=['numpy', 'pymysql', 'networkx', 'matplotlib', 'sphinx_rtd_theme', 'mock', 'json'], + requires=['numpy', 'pymysql', 'networkx', 'matplotlib', 'sphinx_rtd_theme', 'mock', 'json', 'decorator'], license = "MIT", ) diff --git a/tests/test_relation.py b/tests/test_relation.py index a2564a80c..07ccada1d 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -9,7 +9,7 @@ from datajoint.connection import Connection from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal,\ assert_tuple_equal, assert_dict_equal, raises -from datajoint import DataJointError +from datajoint import DataJointError, TransactionError import numpy as np from numpy.testing import assert_array_equal from datajoint.free_relation import FreeRelation @@ -117,7 +117,7 @@ def test_delete(self): # assert_true(len(self.trials) == 1, 'Length does not match 1.') def test_short_hand_foreign_reference(self): - self.animals.heading; + self.animals.heading @@ -131,6 +131,27 @@ def test_record_insert_different_order(self): assert_equal((2, 'Klara', 'monkey'), tuple(testt2), "Inserted and fetched record do not match!") + @raises(TransactionError) + def test_transaction_error(self): + "Test whether declaration in transaction is prohibited" + + tmp = np.array([('Klara', 2, 'monkey')], + dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) + with self.conn.transaction() as tr: + self.subjects.insert(tmp[0]) + + def test_transaction_error2(self): + "If table is declared, we are allow to insert within a transaction" + + tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], + dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) + self.subjects.insert(tmp[0]) + print(self.subjects) + with self.conn.transaction() as tr: + self.subjects.insert(tmp[1]) + + + @raises(KeyError) def test_wrong_key_insert_records(self): "Test whether record insert works" From 74350da302b010b21199ecd12f3ab73cbc23c788 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Wed, 20 May 2015 14:02:14 -0500 Subject: [PATCH 2/2] Fixed implicit commits --- datajoint/__init__.py | 13 +++- datajoint/autopopulate.py | 52 +++++++++------- datajoint/connection.py | 35 ++++++++--- datajoint/decorators.py | 2 +- tests/schemata/schema1/test1.py | 37 ++++++++++++ tests/test_connection.py | 13 ++++ tests/test_relation.py | 103 +++++++++++++++++++++++++++++++- 7 files changed, 222 insertions(+), 33 deletions(-) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index cddc141bb..f95c04803 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -19,7 +19,18 @@ class TransactionError(DataJointError): """ Base class for errors specific to DataJoint internal operation. """ - pass + def __init__(self, msg, f, args, kwargs): + super(TransactionError, self).__init__(msg) + self.operations = (f, args, kwargs) + + def resolve(self): + f, args, kwargs = self.operations + return f(*args, **kwargs) + + @property + def culprit(self): + return self.operations[0].__name__ + # ----------- loads local configuration from file ---------------- diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 64fb4f06b..28d5625c2 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -1,9 +1,9 @@ from .relational_operand import RelationalOperand -from . import DataJointError +from . import DataJointError, TransactionError import abc import logging -#noinspection PyExceptionInherit,PyCallingNonCallable +# noinspection PyExceptionInherit,PyCallingNonCallable logger = logging.getLogger(__name__) @@ -39,31 +39,41 @@ def target(self): def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): """ - rel.populate() calls rel._make_tuples(key) for every primary key in self.pop_rel + rel.populate() calls rel._make_tuples(key) for every primary key in self.populate_relation for which there is not already a tuple in rel. + + :param restriction: restriction on rel.populate_relation - target + :param suppress_errors: suppresses error if true + :param reserve_jobs: currently not implemented """ - assert not reserve_jobs, NotImplemented # issue #5 + assert not reserve_jobs, NotImplemented # issue #5 + error_list = [] if suppress_errors else None + if not isinstance(self.populate_relation, RelationalOperand): raise DataJointError('Invalid populate_relation value') - self.conn._cancel_transaction() # rollback previous transaction, if any + unpopulated = (self.populate_relation - self.target) & restriction + for key in unpopulated.project(): - self.conn._start_transaction() - if key in self.target: # already populated - self.conn._cancel_transaction() - else: - logger.info('Populating: ' + str(key)) - try: - self._make_tuples(key) - except Exception as error: - self.conn._cancel_transaction() - if not suppress_errors: - raise - else: - print(error) - error_list.append((key, error)) + try: + while True: + try: + with self.conn.transaction(): + if not key in self.target: # already populated + logger.info('Populating: ' + str(key)) + self._make_tuples(dict(key)) + break + except TransactionError as tr_err: + if suppress_errors: + error_list.append((key,tr_err)) + tr_err.resolve() + logger.info('Resolved transaction error raised by {0:s}.'.format(tr_err.culprit)) + except Exception as error: + if not suppress_errors: + raise else: - self.conn._commit_transaction() + print(error) + error_list.append((key, error)) logger.info('Done populating.') - return error_list \ No newline at end of file + return error_list diff --git a/datajoint/connection.py b/datajoint/connection.py index b27bc6287..a1eb81247 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -41,6 +41,7 @@ def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False) init_fun = init_fun if init_fun is not None else config['connection.init_function'] _connObj = Connection(host, user, passwd, init_fun) return _connObj + return conn_function # The function conn is used by others to obtain the package wide persistent connection object @@ -51,11 +52,16 @@ class Transaction(object): """ Class that defines a transaction. Mainly for use in a with statement. + :param ignore_errors=False: if True, all errors are not passed on. However, the transaction is still + rolled back if an error is raised. + :param conn: connection object that opens the transaction. """ - def __init__(self, conn): + def __init__(self, conn, ignore_errors=False): self.conn = conn + self._do_not_raise_error_again = False + self.ignore_errors = ignore_errors def __enter__(self): assert self.conn.is_connected, "Connection is not connected" @@ -74,8 +80,19 @@ def is_active(self): """ return self.conn.is_connected and self.conn.in_transaction - def __exit__(self, exc_type, exc_val, exc_tb): - if exc_type is None and exc_val is None and exc_tb is None: + def cancel(self): + """ + Cancels an ongoing transaction and rolls back. + + """ + self._do_not_raise_error_again = True + raise DataJointError("Transaction cancelled by user.") + + def __exit__(self, exc_type, exc_val, exc_tb): # TODO: assert XOR and only exc_type is None + + if exc_type is None: + assert exc_type is None and exc_val is None and exc_tb is None, \ + "Either all of exc_type, exc_val, exc_tb should be None, or neither of them" self.conn._commit_transaction() self.conn._in_transaction = False return True @@ -83,7 +100,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.conn._cancel_transaction() self.conn._in_transaction = False logger.debug("Transaction cancled because of an error.", exc_info=(exc_type, exc_val, exc_tb)) - return False + return self._do_not_raise_error_again or self.ignore_errors # if True is returned, errors are not raised again class Connection(object): @@ -300,7 +317,7 @@ def clear_dependencies(self, dbname=None): if key in self.referenced: self.referenced.pop(key) - def parents_of(self, child_table): #TODO: this function is not clear to me after reading the docu + def parents_of(self, child_table): # TODO: this function is not clear to me after reading the docu """ Returns a list of tables that are parents for the childTable based on primary foreign keys. @@ -309,7 +326,7 @@ def parents_of(self, child_table): #TODO: this function is not clear to me after """ return self.parents.get(child_table, []).copy() - def children_of(self, parent_table):#TODO: this function is not clear to me after reading the docu + def children_of(self, parent_table): # TODO: this function is not clear to me after reading the docu """ Returns a list of tables for which parent_table is a parent (primary foreign key) @@ -385,10 +402,12 @@ def query(self, query, args=(), as_dict=False): cur.execute(query, args) return cur - def transaction(self): + def transaction(self, ignore_errors=False): """ Context manager to be used with python's with statement. + :param ignore_errors=False: if True, all errors are not passed on. However, the transaction is still + rolled back if an error is raised. :return: a :class:`Transaction` object :Example: @@ -397,7 +416,7 @@ def transaction(self): >>> with conn.transaction() as tr: ... # do magic """ - return Transaction(self) + return Transaction(self, ignore_errors) @property def in_transaction(self): diff --git a/datajoint/decorators.py b/datajoint/decorators.py index 44837baad..a333bd1cc 100644 --- a/datajoint/decorators.py +++ b/datajoint/decorators.py @@ -11,7 +11,7 @@ def _not_in_transaction(f, *args, **kwargs): if args[0]._conn.in_transaction: raise TransactionError( u"{0:s} is currently in transaction. Operation not allowed to avoid implicit commits.".format( - args[0].__class__.__name__)) + args[0].__class__.__name__), f, args, kwargs) return f(*args, **kwargs) diff --git a/tests/schemata/schema1/test1.py b/tests/schemata/schema1/test1.py index c29eb7240..7a33a4372 100644 --- a/tests/schemata/schema1/test1.py +++ b/tests/schemata/schema1/test1.py @@ -42,6 +42,43 @@ class Trials(dj.Relation): """ + +class SquaredScore(dj.Relation, dj.AutoPopulate): + definition = """ + test1.SquaredScore (computed) # cumulative outcome of trials + + -> test1.Subjects + -> test1.Trials + --- + squared : int # squared result of Trials outcome + """ + + @property + def populate_relation(self): + return Subjects() * Trials() + + def _make_tuples(self, key): + tmp = (Trials() & key).fetch1() + tmp2 = SquaredSubtable() & key + + self.insert(dict(key, squared=tmp['outcome']**2)) + + ss = SquaredSubtable() + + for i in range(10): + key['dummy'] = i + ss.insert(key) + +class SquaredSubtable(dj.Relation): + definition = """ + test1.SquaredSubtable (computed) # cumulative outcome of trials + + -> test1.SquaredScore + dummy : int # dummy primary attribute + --- + """ + + # test reference to another table in same schema class Experiments(dj.Relation): definition = """ diff --git a/tests/test_connection.py b/tests/test_connection.py index 429e06304..43fbadc99 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -223,6 +223,19 @@ def test_rollback(self): testt2 = (self.relvar & 'subject_id = 2').fetch() assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") + def test_cancel(self): + """Tests cancelling a transaction""" + tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], + dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) + + self.relvar.insert(tmp[0]) + with self.conn.transaction() as transaction: + self.relvar.insert(tmp[1]) + transaction.cancel() + + testt2 = (self.relvar & 'subject_id = 2').fetch() + assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") + # class TestConnectionWithBindings(object): diff --git a/tests/test_relation.py b/tests/test_relation.py index 07ccada1d..7c1458d4d 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -140,13 +140,50 @@ def test_transaction_error(self): with self.conn.transaction() as tr: self.subjects.insert(tmp[0]) + def test_transaction_suppress_error(self): + "Test whether ignore_errors ignores the errors." + + tmp = np.array([('Klara', 2, 'monkey')], + dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) + with self.conn.transaction(ignore_errors=True) as tr: + self.subjects.insert(tmp[0]) + + + @raises(TransactionError) + def test_transaction_error_not_resolve(self): + "Test whether declaration in transaction is prohibited" + + tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], + dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) + try: + with self.conn.transaction() as tr: + self.subjects.insert(tmp[0]) + except TransactionError as te: + pass + with self.conn.transaction() as tr: + self.subjects.insert(tmp[0]) + + def test_transaction_error_resolve(self): + "Test whether declaration in transaction is prohibited" + + tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], + dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) + try: + with self.conn.transaction() as tr: + self.subjects.insert(tmp[0]) + except TransactionError as te: + te.resolve() + + with self.conn.transaction() as tr: + self.subjects.insert(tmp[0]) + def test_transaction_error2(self): - "If table is declared, we are allow to insert within a transaction" + "If table is declared, we are allowed to insert within a transaction" tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) self.subjects.insert(tmp[0]) - print(self.subjects) + with self.conn.transaction() as tr: self.subjects.insert(tmp[1]) @@ -357,3 +394,65 @@ def test_fetch_dicts(self): assert_equal(t['comment'], t2['comment'], 'inserted and retrieved dicts do not match') assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved dicts do not match') + + +class TestAutopopulate(object): + def __init__(self): + self.relvar = None + self.setup() + + """ + Test cases for Iterators in Relations objects + """ + + def setup(self): + """ + Create a connection object and prepare test modules + as follows: + test1 - has conn and bounded + """ + cleanup() # drop all databases with PREFIX + test1.__dict__.pop('conn', None) # make sure conn is not defined at schema level + + self.conn = Connection(**CONN_INFO) + test1.conn = self.conn + self.conn.bind(test1.__name__, PREFIX + '_test1') + + self.subjects = test1.Subjects() + self.trials = test1.Trials() + self.squared = test1.SquaredScore() + self.dummy = test1.SquaredSubtable() + self.fill_relation() + + def fill_relation(self): + tmp = np.array([('Klara', 2, 'monkey'), ('Peter', 3, 'mouse')], + dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) + self.subjects.batch_insert(tmp) + + for trial_id in range(1,11): + self.trials.insert(dict(subject_id=2, trial_id=trial_id, outcome=np.random.randint(0,10))) + + + def teardown(self): + cleanup() + + + def test_autopopulate(self): + self.squared.populate() + assert_equal(len(self.squared), 10) + + for trial in self.trials*self.squared: + assert_equal(trial['outcome']**2, trial['squared']) + + def test_autopopulate_restriction(self): + self.squared.populate(restriction='trial_id <= 5') + assert_equal(len(self.squared), 5) + + for trial in self.trials*self.squared: + assert_equal(trial['outcome']**2, trial['squared']) + + + def test_autopopulate_transaction_error(self): + errors = self.squared.populate(suppress_errors=True) + assert_equal(len(errors), 1) + assert_true(isinstance(errors[0][1], TransactionError))