diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 38cef2aa6..f95c04803 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -15,6 +15,24 @@ class DataJointError(Exception): pass +class TransactionError(DataJointError): + """ + Error raised when an operation is not allowed inside a transaction; it stores the offending call so that resolve() can re-run it. + """ + def __init__(self, msg, f, args, kwargs): + super(TransactionError, self).__init__(msg) + self.operations = (f, args, kwargs) + + def resolve(self): + f, args, kwargs = self.operations + return f(*args, **kwargs) + + @property + def culprit(self): + return self.operations[0].__name__ + + + # ----------- loads local configuration from file ---------------- from .settings import Config, CONFIGVAR, LOCALCONFIG, logger, log_levels config = Config() diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 64fb4f06b..28d5625c2 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -1,9 +1,9 @@ from .relational_operand import RelationalOperand -from . import DataJointError +from . import DataJointError, TransactionError import abc import logging -#noinspection PyExceptionInherit,PyCallingNonCallable +# noinspection PyExceptionInherit,PyCallingNonCallable logger = logging.getLogger(__name__) @@ -39,31 +39,41 @@ def target(self): def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): """ - rel.populate() calls rel._make_tuples(key) for every primary key in self.pop_rel + rel.populate() calls rel._make_tuples(key) for every primary key in self.populate_relation for which there is not already a tuple in rel. 
+ + :param restriction: restriction on rel.populate_relation - target + :param suppress_errors: if True, errors are collected and returned instead of being raised + :param reserve_jobs: currently not implemented """ - assert not reserve_jobs, NotImplemented # issue #5 + assert not reserve_jobs, NotImplemented # issue #5 + error_list = [] if suppress_errors else None + if not isinstance(self.populate_relation, RelationalOperand): raise DataJointError('Invalid populate_relation value') - self.conn._cancel_transaction() # rollback previous transaction, if any + unpopulated = (self.populate_relation - self.target) & restriction + for key in unpopulated.project(): - self.conn._start_transaction() - if key in self.target: # already populated - self.conn._cancel_transaction() - else: - logger.info('Populating: ' + str(key)) - try: - self._make_tuples(key) - except Exception as error: - self.conn._cancel_transaction() - if not suppress_errors: - raise - else: - print(error) - error_list.append((key, error)) + try: + while True: + try: + with self.conn.transaction(): + if not key in self.target: # skip keys that are already populated + logger.info('Populating: ' + str(key)) + self._make_tuples(dict(key)) + break + except TransactionError as tr_err: + if suppress_errors: + error_list.append((key,tr_err)) + tr_err.resolve() + logger.info('Resolved transaction error raised by {0:s}.'.format(tr_err.culprit)) + except Exception as error: + if not suppress_errors: + raise else: - self.conn._commit_transaction() + print(error) + error_list.append((key, error)) logger.info('Done populating.') - return error_list \ No newline at end of file + return error_list diff --git a/datajoint/connection.py b/datajoint/connection.py index 51ed2d1bd..fc00ad951 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -41,6 +41,7 @@ def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False) init_fun = init_fun if init_fun is not None else config['connection.init_function'] _connObj = Connection(host, user, passwd, 
init_fun) return _connObj + return conn_function # The function conn is used by others to obtain the package wide persistent connection object @@ -51,11 +52,16 @@ class Transaction(object): """ Class that defines a transaction. Mainly for use in a with statement. + :param ignore_errors=False: if True, errors raised inside the with block are suppressed instead of re-raised. However, the transaction is still + rolled back if an error is raised. + :param conn: connection object that opens the transaction. """ - def __init__(self, conn): + def __init__(self, conn, ignore_errors=False): self.conn = conn + self._do_not_raise_error_again = False + self.ignore_errors = ignore_errors def __enter__(self): assert self.conn.is_connected, "Connection is not connected" @@ -74,14 +80,27 @@ def is_active(self): """ return self.conn.is_connected and self.conn.in_transaction - def __exit__(self, exc_type, exc_val, exc_tb): - if exc_type is None and exc_val is None and exc_tb is None: + def cancel(self): + """ + Cancels an ongoing transaction and rolls back. 
+ + """ + self._do_not_raise_error_again = True + raise DataJointError("Transaction cancelled by user.") + + def __exit__(self, exc_type, exc_val, exc_tb): # TODO: assert XOR and only exc_type is None + + if exc_type is None: + assert exc_type is None and exc_val is None and exc_tb is None, \ + "Either all of exc_type, exc_val, exc_tb should be None, or neither of them" self.conn._commit_transaction() self.conn._in_transaction = False + return True else: self.conn._cancel_transaction() self.conn._in_transaction = False logger.debug("Transaction cancled because of an error.", exc_info=(exc_type, exc_val, exc_tb)) + return self._do_not_raise_error_again or self.ignore_errors # if True is returned, errors are not raised again class Connection(object): @@ -387,10 +406,12 @@ def query(self, query, args=(), as_dict=False): cur.execute(query, args) return cur - def transaction(self): + def transaction(self, ignore_errors=False): """ Context manager to be used with python's with statement. + :param ignore_errors=False: if True, all errors are not passed on. However, the transaction is still + rolled back if an error is raised. :return: a :class:`Transaction` object :Example: @@ -399,7 +420,7 @@ def transaction(self): >>> with conn.transaction() as tr: ... # do magic """ - return Transaction(self) + return Transaction(self, ignore_errors) @property def in_transaction(self): diff --git a/datajoint/decorators.py b/datajoint/decorators.py new file mode 100644 index 000000000..a333bd1cc --- /dev/null +++ b/datajoint/decorators.py @@ -0,0 +1,22 @@ +from decorator import decorator +from . 
import DataJointError, TransactionError + + +def _not_in_transaction(f, *args, **kwargs): + if not hasattr(args[0], '_conn'): + raise DataJointError(u"{0:s} does not have a member called _conn".format(args[0].__class__.__name__, )) + if not hasattr(args[0]._conn, 'in_transaction'): + raise DataJointError( + u"{0:s}._conn does not have a property in_transaction".format(args[0].__class__.__name__, )) + if args[0]._conn.in_transaction: + raise TransactionError( + u"{0:s} is currently in transaction. Operation not allowed to avoid implicit commits.".format( + args[0].__class__.__name__), f, args, kwargs) + return f(*args, **kwargs) + + +def not_in_transaction(f): + """ + This decorator raises an error if the function is called during a transaction. + """ + return decorator(_not_in_transaction, f) diff --git a/datajoint/free_relation.py b/datajoint/free_relation.py index 3484e68d5..9e5b49203 100644 --- a/datajoint/free_relation.py +++ b/datajoint/free_relation.py @@ -2,6 +2,7 @@ import numpy as np import logging from . import DataJointError, config +from .decorators import not_in_transaction from .relational_operand import RelationalOperand from .blob import pack from .heading import Heading @@ -265,7 +266,7 @@ def erd(self, subset=None): Plot the schema's entity relationship diagram (ERD). """ - + @not_in_transaction def _alter(self, alter_statement): """ Execute ALTER TABLE statement for this table. The schema @@ -313,6 +314,7 @@ def ref_name(self): """ return '`{0}`'.format(self.dbname) + '.' + self.class_name + @not_in_transaction def _declare(self): """ Declares the table in the database if no table in the database matches this object. 
diff --git a/requirements.txt b/requirements.txt index 82abfa7ea..1e85cd6e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ networkx matplotlib sphinx_rtd_theme mock +decorator diff --git a/setup.py b/setup.py index 5c56b483c..5320801b0 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,6 @@ description='An object-relational mapping and relational algebra to facilitate data definition and data manipulation in MySQL databases.', url='https://github.com/datajoint/datajoint-python', packages=['datajoint'], - requires=['numpy', 'pymysql', 'networkx', 'matplotlib', 'sphinx_rtd_theme', 'mock', 'json'], + requires=['numpy', 'pymysql', 'networkx', 'matplotlib', 'sphinx_rtd_theme', 'mock', 'json', 'decorator'], license = "MIT", ) diff --git a/tests/schemata/schema1/test1.py b/tests/schemata/schema1/test1.py index c29eb7240..7a33a4372 100644 --- a/tests/schemata/schema1/test1.py +++ b/tests/schemata/schema1/test1.py @@ -42,6 +42,43 @@ class Trials(dj.Relation): """ + +class SquaredScore(dj.Relation, dj.AutoPopulate): + definition = """ + test1.SquaredScore (computed) # cumulative outcome of trials + + -> test1.Subjects + -> test1.Trials + --- + squared : int # squared result of Trials outcome + """ + + @property + def populate_relation(self): + return Subjects() * Trials() + + def _make_tuples(self, key): + tmp = (Trials() & key).fetch1() + tmp2 = SquaredSubtable() & key + + self.insert(dict(key, squared=tmp['outcome']**2)) + + ss = SquaredSubtable() + + for i in range(10): + key['dummy'] = i + ss.insert(key) + +class SquaredSubtable(dj.Relation): + definition = """ + test1.SquaredSubtable (computed) # cumulative outcome of trials + + -> test1.SquaredScore + dummy : int # dummy primary attribute + --- + """ + + # test reference to another table in same schema class Experiments(dj.Relation): definition = """ diff --git a/tests/test_connection.py b/tests/test_connection.py index 266b9f3fb..838ddd2df 100644 --- a/tests/test_connection.py +++ 
b/tests/test_connection.py @@ -224,6 +224,19 @@ def test_rollback(self): testt2 = (self.relvar & 'subject_id = 2').fetch() assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") + def test_cancel(self): + """Tests cancelling a transaction""" + tmp = np.array([(1,'Peter','mouse'),(2, 'Klara', 'monkey')], + dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) + + self.relvar.insert(tmp[0]) + with self.conn.transaction() as transaction: + self.relvar.insert(tmp[1]) + transaction.cancel() + + testt2 = (self.relvar & 'subject_id = 2').fetch() + assert_equal(len(testt2), 0, "Length is not 0. Expected because rollback should have happened.") + # class TestConnectionWithBindings(object): diff --git a/tests/test_relation.py b/tests/test_relation.py index a2564a80c..7c1458d4d 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -9,7 +9,7 @@ from datajoint.connection import Connection from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal,\ assert_tuple_equal, assert_dict_equal, raises -from datajoint import DataJointError +from datajoint import DataJointError, TransactionError import numpy as np from numpy.testing import assert_array_equal from datajoint.free_relation import FreeRelation @@ -117,7 +117,7 @@ def test_delete(self): # assert_true(len(self.trials) == 1, 'Length does not match 1.') def test_short_hand_foreign_reference(self): - self.animals.heading; + self.animals.heading @@ -131,6 +131,64 @@ def test_record_insert_different_order(self): assert_equal((2, 'Klara', 'monkey'), tuple(testt2), "Inserted and fetched record do not match!") + @raises(TransactionError) + def test_transaction_error(self): + "Test whether declaration in transaction is prohibited" + + tmp = np.array([('Klara', 2, 'monkey')], + dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) + with self.conn.transaction() as tr: + 
self.subjects.insert(tmp[0]) + + def test_transaction_suppress_error(self): + "Test whether ignore_errors ignores the errors." + + tmp = np.array([('Klara', 2, 'monkey')], + dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) + with self.conn.transaction(ignore_errors=True) as tr: + self.subjects.insert(tmp[0]) + + + @raises(TransactionError) + def test_transaction_error_not_resolve(self): + "Test whether declaration in transaction is still prohibited if a previous TransactionError was not resolved" + + tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], + dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) + try: + with self.conn.transaction() as tr: + self.subjects.insert(tmp[0]) + except TransactionError as te: + pass + with self.conn.transaction() as tr: + self.subjects.insert(tmp[0]) + + def test_transaction_error_resolve(self): + "Test whether resolving a TransactionError allows the insert within a later transaction" + + tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], + dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) + try: + with self.conn.transaction() as tr: + self.subjects.insert(tmp[0]) + except TransactionError as te: + te.resolve() + + with self.conn.transaction() as tr: + self.subjects.insert(tmp[0]) + + def test_transaction_error2(self): + "If table is declared, we are allowed to insert within a transaction" + + tmp = np.array([('Klara', 2, 'monkey'), ('Klara', 3, 'monkey')], + dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) + self.subjects.insert(tmp[0]) + + with self.conn.transaction() as tr: + self.subjects.insert(tmp[1]) + + + @raises(KeyError) def test_wrong_key_insert_records(self): "Test whether record insert works" @@ -336,3 +394,65 @@ def test_fetch_dicts(self): assert_equal(t['comment'], t2['comment'], 'inserted and retrieved dicts do not match') assert_true(np.all(t['data'] == t2['data']), 'inserted and retrieved dicts do not match') + + +class TestAutopopulate(object): + def __init__(self): + self.relvar = None + self.setup() + 
+ """ + Test cases for Iterators in Relations objects + """ + + def setup(self): + """ + Create a connection object and prepare test modules + as follows: + test1 - has conn and bounded + """ + cleanup() # drop all databases with PREFIX + test1.__dict__.pop('conn', None) # make sure conn is not defined at schema level + + self.conn = Connection(**CONN_INFO) + test1.conn = self.conn + self.conn.bind(test1.__name__, PREFIX + '_test1') + + self.subjects = test1.Subjects() + self.trials = test1.Trials() + self.squared = test1.SquaredScore() + self.dummy = test1.SquaredSubtable() + self.fill_relation() + + def fill_relation(self): + tmp = np.array([('Klara', 2, 'monkey'), ('Peter', 3, 'mouse')], + dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) + self.subjects.batch_insert(tmp) + + for trial_id in range(1,11): + self.trials.insert(dict(subject_id=2, trial_id=trial_id, outcome=np.random.randint(0,10))) + + + def teardown(self): + cleanup() + + + def test_autopopulate(self): + self.squared.populate() + assert_equal(len(self.squared), 10) + + for trial in self.trials*self.squared: + assert_equal(trial['outcome']**2, trial['squared']) + + def test_autopopulate_restriction(self): + self.squared.populate(restriction='trial_id <= 5') + assert_equal(len(self.squared), 5) + + for trial in self.trials*self.squared: + assert_equal(trial['outcome']**2, trial['squared']) + + + def test_autopopulate_transaction_error(self): + errors = self.squared.populate(suppress_errors=True) + assert_equal(len(errors), 1) + assert_true(isinstance(errors[0][1], TransactionError))