From 6afb2319c4c5288a941d39fcb9fc74e2035ebec2 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Wed, 13 May 2015 12:14:19 -0500 Subject: [PATCH 1/5] minor cosmetic changes --- datajoint/blob.py | 11 +++++------ datajoint/connection.py | 7 +++---- datajoint/settings.py | 3 ++- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/datajoint/blob.py b/datajoint/blob.py index e66ca11ab..b77547320 100644 --- a/datajoint/blob.py +++ b/datajoint/blob.py @@ -6,8 +6,8 @@ mxClassID = collections.OrderedDict( # see http://www.mathworks.com/help/techdoc/apiref/mxclassid.html mxUNKNOWN_CLASS=None, - mxCELL_CLASS=None, # not implemented - mxSTRUCT_CLASS=None, # not implemented + mxCELL_CLASS=None, # TODO: implement + mxSTRUCT_CLASS=None, # TODO: implement mxLOGICAL_CLASS=np.dtype('bool'), mxCHAR_CLASS=np.dtype('c'), mxVOID_CLASS=None, @@ -48,10 +48,9 @@ def pack(obj): if is_complex: blob += imaginary.tostring() - if len(blob) > 1000: - compressed = b'ZL123\0'+np.asarray(len(blob), dtype=np.uint64).tostring() + zlib.compress(blob) - if len(compressed) < len(blob): - blob = compressed + compressed = b'ZL123\0' + np.uint64(len(blob)).tostring() + zlib.compress(blob) + if len(compressed) < len(blob): + blob = compressed return blob diff --git a/datajoint/connection.py b/datajoint/connection.py index 935cab86d..64607de5b 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -3,7 +3,7 @@ from .utils import to_camel_case from . import DataJointError from .heading import Heading -from .settings import prefix_to_role, DEFAULT_PORT +from .settings import prefix_to_role import logging from .erd import DBConnGraph from . import config @@ -74,7 +74,6 @@ def is_active(self): """ return self.conn.is_connected and self.conn.in_transaction - def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is None and exc_val is None and exc_tb is None: self.conn._commit_transaction() @@ -85,7 +84,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): logger.debug("Transaction cancled because of an error.", exc_info=(exc_type, exc_val, exc_tb)) - class Connection(object): """ A dj.Connection object manages a connection to a database server. @@ -108,9 +106,10 @@ def __init__(self, host, user, passwd, init_fun=None): port = config['database.port'] self.conn_info = dict(host=host, port=port, user=user, passwd=passwd) self._conn = pymysql.connect(init_command=init_fun, **self.conn_info) - # TODO Do something if connection cannot be established if self.is_connected: print("Connected", user + '@' + host + ':' + str(port)) + else: + raise DataJointError('Connection failed.') self._conn.autocommit(True) self.db_to_mod = {} # modules indexed by dbnames diff --git a/datajoint/settings.py b/datajoint/settings.py index 345245733..ad1ec4f82 100644 --- a/datajoint/settings.py +++ b/datajoint/settings.py @@ -13,7 +13,6 @@ LOCALCONFIG = 'dj_local_conf.json' CONFIGVAR = 'DJ_LOCAL_CONF' -DEFAULT_PORT = 3306 validators = collections.defaultdict(lambda: lambda value: True) validators['database.port'] = lambda a: isinstance(a, int) @@ -52,9 +51,11 @@ class Borg: _shared_state = {} + def __init__(self): self.__dict__ = self._shared_state + class Config(Borg, collections.MutableMapping): """ Stores datajoint settings. Behaves like a dictionary, but applies validator functions From f5359b2ff0d08ef500ccc1f73fff62ba4d589ffe Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Wed, 13 May 2015 14:28:46 -0500 Subject: [PATCH 2/5] minor improvements --- datajoint/__init__.py | 3 ++- datajoint/free_relation.py | 36 ++++++++++++++++++------------------ 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 34bac55f8..f24fae271 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -4,7 +4,7 @@ __author__ = "Dimitri Yatsenko, Edgar Walker, and Fabian Sinz at Baylor College of Medicine" __version__ = "0.2" __all__ = ['__author__', '__version__', - 'Connection', 'Heading', 'Relation', 'Not', + 'Connection', 'Heading', 'Relation', 'FreeRelation', 'Not', 'AutoPopulate', 'conn', 'DataJointError', 'blob'] @@ -43,6 +43,7 @@ class DataJointError(Exception): from . import blob from .relational_operand import Not from .free_relation import FreeRelation +from .heading import Heading ############################################################################# diff --git a/datajoint/free_relation.py b/datajoint/free_relation.py index d21e17580..fcba3b43d 100644 --- a/datajoint/free_relation.py +++ b/datajoint/free_relation.py @@ -194,7 +194,6 @@ def drop(self): Drops the table associated to this object. """ # TODO: make cascading (issue #16) - if self.is_declared: self.conn.query('DROP TABLE %s' % self.full_table_name) self.conn.clear_dependencies(dbname=self.dbname) @@ -442,12 +441,11 @@ def _parse_declaration(self): table_info['tier'] = Role[table_info['tier']] # convert into enum in_key = True # parse primary keys - field_ptrn = """ + attribute_regexp = re.compile(""" ^[a-z][a-z\d_]*\s* # name (=\s*\S+(\s+\S+)*\s*)? # optional defaults :\s*\w.*$ # type, comment - """ - fieldP = re.compile(field_ptrn, re.I + re.X) # ignore case and verbose + """, re.I + re.X) # ignore case and verbose for line in declaration[1:]: if line.startswith('---'): @@ -455,11 +453,11 @@ def _parse_declaration(self): elif line.startswith('->'): # foreign key module_name, class_name = line[2:].strip().split('.') - rel = self.get_base(module_name, class_name) - (parents if in_key else referenced).append(rel) - elif re.match(r'^(unique\s+)?index[^:]*$', line): + ref = parents if in_key else referenced + ref.append(self.get_base(module_name, class_name)) + elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): index_defs.append(self._parse_index_def(line)) - elif fieldP.match(line): + elif attribute_regexp.match(line): field_defs.append(parse_attribute_definition(line, in_key)) else: raise DataJointError( @@ -485,7 +483,8 @@ def parse_attribute_definition(line, in_key=False): (\#\s*(?P\S*(\s+\S+)*)\s*)?$ # comment """, re.X) m = attribute_regexp.match(line) - assert m, 'Invalid field declaration "%s"' % line + if not m: + raise DataJointError('Invalid field declaration "%s"' % line) attr_info = m.groupdict() if not attr_info['comment']: attr_info['comment'] = '' @@ -495,12 +494,13 @@ def parse_attribute_definition(line, in_key=False): assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \ 'BIGINT attributes cannot be nullable in "%s"' % line - attr_info['in_key'] = in_key - attr_info['autoincrement'] = None - attr_info['numeric'] = None - attr_info['string'] = None - attr_info['is_blob'] = None - attr_info['computation'] = None - attr_info['dtype'] = None - - return Heading.AttrTuple(**attr_info) \ No newline at end of file + return Heading.AttrTuple( + in_key=in_key, + autoincrement=None, + numeric=None, + string=None, + is_blob=None, + computation=None, + dtype=None, + **attr_info + ) \ No newline at end of file From 01bc7a7d562b9fcfb1e5b2eaa602f0f07a701a8f Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Wed, 13 May 2015 15:34:47 -0500 Subject: [PATCH 3/5] improved error messages for parsing table definitions --- datajoint/__init__.py | 12 +++--------- datajoint/settings.py | 1 - datajoint/utils.py | 13 ++++--------- 3 files changed, 7 insertions(+), 19 deletions(-) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index f24fae271..38cef2aa6 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -20,11 +20,8 @@ class DataJointError(Exception): config = Config() local_config_file = os.environ.get(CONFIGVAR, None) if local_config_file is None: - local_config_file = os.path.expanduser(LOCALCONFIG) -else: - local_config_file = os.path.expanduser(local_config_file) - - + local_config_file = LOCALCONFIG +local_config_file = os.path.expanduser(local_config_file) try: logger.log(logging.INFO, "Loading local settings from {0:s}".format(local_config_file)) @@ -43,7 +40,4 @@ class DataJointError(Exception): from . import blob from .relational_operand import Not from .free_relation import FreeRelation -from .heading import Heading - - -############################################################################# +from .heading import Heading \ No newline at end of file diff --git a/datajoint/settings.py b/datajoint/settings.py index ad1ec4f82..1ab7ea0a0 100644 --- a/datajoint/settings.py +++ b/datajoint/settings.py @@ -6,7 +6,6 @@ import pprint from collections import OrderedDict -__author__ = 'eywalker' import logging import collections from enum import Enum diff --git a/datajoint/utils.py b/datajoint/utils.py index af2e8e310..7593a3aae 100644 --- a/datajoint/utils.py +++ b/datajoint/utils.py @@ -18,21 +18,16 @@ def to_upper(match): def from_camel_case(s): """ - Convert names in camel case into underscore - (_) separated names + Convert names in camel case into underscore (_) separated names Example: >>>from_camel_case("TableName") "table_name" """ - if re.search(r'\s', s): - raise DataJointError('Input cannot contain white space') - if re.match(r'\d.*', s): - raise DataJointError('Input cannot begin with a digit') - if not re.match(r'^[a-zA-Z0-9]*$', s): - raise DataJointError('String can only contain alphanumeric characters') - def convert(match): return ('_' if match.groups()[0] else '') + match.group(0).lower() + if not re.match(r'[A-Z][a-zA-Z0-9]*', s): + raise DataJointError( + 'ClassName must be alphanumeric in CamelCase, begin with a capital letter') return re.sub(r'(\B[A-Z])|(\b[A-Z])', convert, s) \ No newline at end of file From 6947d680df75b0d2e4752d633120a29d89a61ab9 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Wed, 13 May 2015 16:04:06 -0500 Subject: [PATCH 4/5] cleaned up RelationalOperand.fetch() --- datajoint/relational_operand.py | 16 +++++----------- tests/test_utils.py | 3 ++- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index d98311cef..d2bd57e17 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -189,20 +189,14 @@ def __repr__(self): def __iter__(self): """ Iterator that yields individual tuples of the current table dictionaries. - - - :param offset: parameter passed to the :func:`cursor` - :param limit: parameter passed to the :func:`cursor` - :param order_by: parameter passed to the :func:`cursor` - :param descending: parameter passed to the :func:`cursor` """ cur = self.cursor() do_unpack = tuple(h in self.heading.blobs for h in self.heading.names) - q = cur.fetchone() - while q: - yield dict( (fieldname,unpack(field)) if up else (fieldname,field) - for fieldname, up, field in zip(self.heading.names, do_unpack, q)) - q = cur.fetchone() + values = cur.fetchone() + while values: + yield {field_name: unpack(value) if up else value + for field_name, up, value in zip(self.heading.names, do_unpack, values)} + values = cur.fetchone() @property def where_clause(self): diff --git a/tests/test_utils.py b/tests/test_utils.py index 9bce1de8a..655884ce0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -23,7 +23,8 @@ def test_to_camel_case(): def test_from_camel_case(): assert_equal(from_camel_case('AllGroups'), 'all_groups') - assert_equal(from_camel_case('repNames'), 'rep_names') + with assert_raises(DataJointError): + from_camel_case('repNames') with assert_raises(DataJointError): from_camel_case('10_all') with assert_raises(DataJointError): From 7e9bba7e1ab4d3ccaff41e2c7cc4816477d2c4de Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Wed, 13 May 2015 16:18:32 -0500 Subject: [PATCH 5/5] cleanup --- datajoint/relational_operand.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index d2bd57e17..a36b84502 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -191,11 +191,12 @@ def __iter__(self): Iterator that yields individual tuples of the current table dictionaries. """ cur = self.cursor() - do_unpack = tuple(h in self.heading.blobs for h in self.heading.names) + heading = self.heading # construct once for efficiency + do_unpack = tuple(h in heading.blobs for h in self.heading.names) values = cur.fetchone() while values: yield {field_name: unpack(value) if up else value - for field_name, up, value in zip(self.heading.names, do_unpack, values)} + for field_name, up, value in zip(heading.names, do_unpack, values)} values = cur.fetchone() @property