diff --git a/datajoint/free_relation.py b/datajoint/free_relation.py index ff27236e8..ffa54a3fa 100644 --- a/datajoint/free_relation.py +++ b/datajoint/free_relation.py @@ -1,13 +1,13 @@ from _collections_abc import MutableMapping, Mapping import numpy as np import logging -from . import DataJointError +from . import DataJointError, config from .relational_operand import RelationalOperand from .blob import pack from .heading import Heading import re from .settings import Role, role_to_prefix -from .utils import from_camel_case +from .utils import from_camel_case, user_confirmation logger = logging.getLogger(__name__) @@ -50,7 +50,6 @@ def heading(self): self.declare() return self.conn.headings[self.dbname][self.table_name] - @property def definition(self): return self._definition @@ -73,7 +72,7 @@ def declare(self): 'FreeRelation could not be declared for %s' % self.class_name) @staticmethod - def _field_to_sql(field): #TODO move this into Attribute Tuple + def _field_to_sql(field): # TODO move this into Attribute Tuple """ Converts an attribute definition tuple into SQL code. :param field: attribute definition @@ -183,6 +182,11 @@ def insert(self, tup, ignore_errors=False, replace=False): self.conn.query(sql, args=args) def delete(self): + if config['safemode'] and \ + user_confirmation( + """You are about to delete data from a table. This operation cannot be undone. + Do you want to proceed?""", ['y', 'n'], 'n') == 'n': + return # TODO: make cascading (issue #15) self.conn.query('DELETE FROM ' + self.from_clause + self.where_clause) @@ -190,6 +194,11 @@ def drop(self): """ Drops the table associated to this object. """ + if config['safemode'] and \ + user_confirmation( + """You are about to drop an entire table. This operation cannot be undone. + Do you want to proceed?""", ['y', 'n'], 'n') == 'n': + return # TODO: make cascading (issue #16) if self.is_declared: self.conn.query('DROP TABLE %s' % self.full_table_name) @@ -229,6 +238,11 @@ def drop_attribute(self, attr_name): :param attr_name: Name of the attribute that is dropped. """ + if config['safemode'] and \ + user_confirmation( + """You are about to drop an attribute from a table. This operation cannot be undone. + Do you want to proceed?""", ['y', 'n'], 'n') == 'n': + return self._alter('DROP COLUMN `%s`' % attr_name) def alter_attribute(self, attr_name, new_definition): diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index 396048d4d..fdc15c88a 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -6,7 +6,7 @@ import abc import re from copy import copy -from datajoint import DataJointError +from datajoint import DataJointError, config from .blob import unpack import logging import numpy.lib.recfunctions as rfn @@ -184,8 +184,8 @@ def cursor(self, offset=0, limit=None, order_by=None, descending=False): return self.conn.query(sql) def __repr__(self): - limit = 7 #TODO: move some of these display settings into the config - width = 14 + limit = config['display.limit'] + width = config['display.width'] rel = self.project(*self.heading.non_blobs) template = '%%-%d.%ds' % (width, width) columns = rel.heading.names diff --git a/datajoint/settings.py b/datajoint/settings.py index 1ab7ea0a0..767171b17 100644 --- a/datajoint/settings.py +++ b/datajoint/settings.py @@ -34,7 +34,12 @@ # 'connection.init_function': None, # - 'loglevel': 'DEBUG' + 'loglevel': 'DEBUG', + # + 'safemode': False, + # + 'display.limit': 7, + 'display.width': 14 }) logger = logging.getLogger() diff --git a/datajoint/utils.py b/datajoint/utils.py index 252fa1187..d104a9cc8 100644 --- a/datajoint/utils.py +++ b/datajoint/utils.py @@ -31,4 +31,26 @@ def convert(match): if not re.match(r'[A-Z][a-zA-Z0-9]*', s): raise DataJointError( 'ClassName must be alphanumeric in CamelCase, begin with a capital letter') - return re.sub(r'(\B[A-Z])|(\b[A-Z])', convert, s) \ No newline at end of file + return re.sub(r'(\B[A-Z])|(\b[A-Z])', convert, s) + +def user_confirmation(infostring, choices, default=None): + """ + Prompts the user for confirmation. + + :param infostring: Information to display to the user. + :param choices: an iterable of possible choices. + :param default=None: default choice + :return: the user's choice + """ + print(infostring) + cho = list(choices) + if default is not None: + cho[cho.index(default)] += ' (default)' + cho = ', '.join(cho) + + response = input('Please answer ' + cho) + while not ((response in choices) or (default is not None and len(response.strip())==0)): + response = input('Please answer (' + cho + '):') + if default is not None and len(response.strip())==0: + response = choices[choices.index(default)] + return response diff --git a/tests/test_relation.py b/tests/test_relation.py index 830dea26e..8d6ed3a6e 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -89,6 +89,36 @@ def test_record_insert(self): testt2 = (self.subjects & 'subject_id = 2').fetch()[0] assert_equal(tuple(tmp[0]), tuple(testt2), "Inserted and fetched record do not match!") + def test_delete(self): + "Test whether delete works" + tmp = np.array([(2, 'Klara', 'monkey'), (1,'Peter', 'mouse')], + dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) + + self.subjects.batch_insert(tmp) + assert_true(len(self.subjects) == 2, 'Length does not match 2.') + self.subjects.delete() + assert_true(len(self.subjects) == 0, 'Length does not match 0.') + + # def test_cascading_delete(self): + # "Test whether delete works" + # tmp = np.array([(2, 'Klara', 'monkey'), (1,'Peter', 'mouse')], + # dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) + # + # self.subjects.batch_insert(tmp) + # + # self.trials.insert(dict(subject_id=1, trial_id=1, outcome=0)) + # self.trials.insert(dict(subject_id=1, trial_id=2, outcome=1)) + # self.trials.insert(dict(subject_id=2, trial_id=3, outcome=2)) + # assert_true(len(self.subjects) == 2, 'Length does not match 2.') + # assert_true(len(self.trials) == 3, 'Length does not match 3.') + # (self.subjects & 'subject_id=1').delete() + # assert_true(len(self.subjects) == 1, 'Length does not match 1.') + # assert_true(len(self.trials) == 1, 'Length does not match 1.') + + + + + def test_record_insert_different_order(self): "Test whether record insert works" tmp = np.array([('Klara', 2, 'monkey')],