diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 34bac55f8..38cef2aa6 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -4,7 +4,7 @@ __author__ = "Dimitri Yatsenko, Edgar Walker, and Fabian Sinz at Baylor College of Medicine" __version__ = "0.2" __all__ = ['__author__', '__version__', - 'Connection', 'Heading', 'Relation', 'Not', + 'Connection', 'Heading', 'Relation', 'FreeRelation', 'Not', 'AutoPopulate', 'conn', 'DataJointError', 'blob'] @@ -20,11 +20,8 @@ class DataJointError(Exception): config = Config() local_config_file = os.environ.get(CONFIGVAR, None) if local_config_file is None: - local_config_file = os.path.expanduser(LOCALCONFIG) -else: - local_config_file = os.path.expanduser(local_config_file) - - + local_config_file = LOCALCONFIG +local_config_file = os.path.expanduser(local_config_file) try: logger.log(logging.INFO, "Loading local settings from {0:s}".format(local_config_file)) @@ -43,6 +40,4 @@ class DataJointError(Exception): from . import blob from .relational_operand import Not from .free_relation import FreeRelation - - -############################################################################# +from .heading import Heading \ No newline at end of file diff --git a/datajoint/blob.py b/datajoint/blob.py index e66ca11ab..b77547320 100644 --- a/datajoint/blob.py +++ b/datajoint/blob.py @@ -6,8 +6,8 @@ mxClassID = collections.OrderedDict( # see http://www.mathworks.com/help/techdoc/apiref/mxclassid.html mxUNKNOWN_CLASS=None, - mxCELL_CLASS=None, # not implemented - mxSTRUCT_CLASS=None, # not implemented + mxCELL_CLASS=None, # TODO: implement + mxSTRUCT_CLASS=None, # TODO: implement mxLOGICAL_CLASS=np.dtype('bool'), mxCHAR_CLASS=np.dtype('c'), mxVOID_CLASS=None, @@ -48,10 +48,9 @@ def pack(obj): if is_complex: blob += imaginary.tostring() - if len(blob) > 1000: - compressed = b'ZL123\0'+np.asarray(len(blob), dtype=np.uint64).tostring() + zlib.compress(blob) - if len(compressed) < len(blob): - blob = compressed + compressed = b'ZL123\0' + np.uint64(len(blob)).tostring() + zlib.compress(blob) + if len(compressed) < len(blob): + blob = compressed return blob diff --git a/datajoint/connection.py b/datajoint/connection.py index 935cab86d..64607de5b 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -3,7 +3,7 @@ from .utils import to_camel_case from . import DataJointError from .heading import Heading -from .settings import prefix_to_role, DEFAULT_PORT +from .settings import prefix_to_role import logging from .erd import DBConnGraph from . import config @@ -74,7 +74,6 @@ def is_active(self): """ return self.conn.is_connected and self.conn.in_transaction - def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is None and exc_val is None and exc_tb is None: self.conn._commit_transaction() @@ -85,7 +84,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): logger.debug("Transaction cancled because of an error.", exc_info=(exc_type, exc_val, exc_tb)) - class Connection(object): """ A dj.Connection object manages a connection to a database server. @@ -108,9 +106,10 @@ def __init__(self, host, user, passwd, init_fun=None): port = config['database.port'] self.conn_info = dict(host=host, port=port, user=user, passwd=passwd) self._conn = pymysql.connect(init_command=init_fun, **self.conn_info) - # TODO Do something if connection cannot be established if self.is_connected: print("Connected", user + '@' + host + ':' + str(port)) + else: + raise DataJointError('Connection failed.') self._conn.autocommit(True) self.db_to_mod = {} # modules indexed by dbnames diff --git a/datajoint/free_relation.py b/datajoint/free_relation.py index e45af6ad2..318738131 100644 --- a/datajoint/free_relation.py +++ b/datajoint/free_relation.py @@ -195,7 +195,6 @@ def drop(self): Drops the table associated to this object. """ # TODO: make cascading (issue #16) - if self.is_declared: self.conn.query('DROP TABLE %s' % self.full_table_name) self.conn.clear_dependencies(dbname=self.dbname) @@ -443,12 +442,11 @@ def _parse_declaration(self): table_info['tier'] = Role[table_info['tier']] # convert into enum in_key = True # parse primary keys - field_ptrn = """ + attribute_regexp = re.compile(""" ^[a-z][a-z\d_]*\s* # name (=\s*\S+(\s+\S+)*\s*)? # optional defaults :\s*\w.*$ # type, comment - """ - fieldP = re.compile(field_ptrn, re.I + re.X) # ignore case and verbose + """, re.I + re.X) # ignore case and verbose for line in declaration[1:]: if line.startswith('---'): @@ -456,11 +454,11 @@ def _parse_declaration(self): elif line.startswith('->'): # foreign key module_name, class_name = line[2:].strip().split('.') - rel = self.get_base(module_name, class_name) - (parents if in_key else referenced).append(rel) - elif re.match(r'^(unique\s+)?index[^:]*$', line): + ref = parents if in_key else referenced + ref.append(self.get_base(module_name, class_name)) + elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): index_defs.append(self._parse_index_def(line)) - elif fieldP.match(line): + elif attribute_regexp.match(line): field_defs.append(parse_attribute_definition(line, in_key)) else: raise DataJointError( @@ -486,7 +484,8 @@ def parse_attribute_definition(line, in_key=False): (\#\s*(?P\S*(\s+\S+)*)\s*)?$ # comment """, re.X) m = attribute_regexp.match(line) - assert m, 'Invalid field declaration "%s"' % line + if not m: + raise DataJointError('Invalid field declaration "%s"' % line) attr_info = m.groupdict() if not attr_info['comment']: attr_info['comment'] = '' @@ -496,12 +495,13 @@ def parse_attribute_definition(line, in_key=False): assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \ 'BIGINT attributes cannot be nullable in "%s"' % line - attr_info['in_key'] = in_key - attr_info['autoincrement'] = None - attr_info['numeric'] = None - attr_info['string'] = None - attr_info['is_blob'] = None - attr_info['computation'] = None - attr_info['dtype'] = None - - return Heading.AttrTuple(**attr_info) \ No newline at end of file + return Heading.AttrTuple( + in_key=in_key, + autoincrement=None, + numeric=None, + string=None, + is_blob=None, + computation=None, + dtype=None, + **attr_info + ) diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index 71e9c459f..a36b84502 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -147,20 +147,18 @@ def fetch(self, offset=0, limit=None, order_by=None, descending=False, as_dict=F for n, e in zip(attr_names, t)) for t in cur.fetchall()] else: ret = np.array(list(cur.fetchall()), dtype=self.heading.as_dtype) - for bname in self.heading.blobs: - ret[bname] = list(map(unpack, ret[bname])) + for blob_name in self.heading.blobs: + ret[blob_name] = list(map(unpack, ret[blob_name])) return ret def cursor(self, offset=0, limit=None, order_by=None, descending=False): """ - :param offset: the number of tuples to skip in the returned result - :param limit: the maximum number of tuples to return - :param order_by: the list of attributes to order the results. No ordering should be assumed if order_by=None. - :param descending: the list of attributes to order the results + Return query cursor. + See Relation.fetch() for input description. :return: cursor to the query """ if offset and limit is None: - raise DataJointError('offset cannot be set without setting a limit') + raise DataJointError('limit is required when offset is set') sql = self.make_select() if order_by is not None: sql += ' ORDER BY ' + ', '.join(order_by) @@ -173,8 +171,6 @@ def cursor(self, offset=0, limit=None, order_by=None, descending=False): logger.debug(sql) return self.conn.query(sql) - - def __repr__(self): limit = 7 #TODO: move some of these display settings into the config width = 14 @@ -182,7 +178,7 @@ def __repr__(self): template = '%%-%d.%ds' % (width, width) columns = rel.heading.names repr_string = ' '.join([template % column for column in columns]) + '\n' - repr_string += ' '.join(['+' + '-' * (width - 2) + '+' for _ in columns]) + '\n' + repr_string += ' '.join(['+' + '-'*(width-2) + '+' for _ in columns]) + '\n' for tup in rel.fetch(limit=limit): repr_string += ' '.join([template % column for column in tup]) + '\n' if self.count > limit: @@ -193,20 +189,15 @@ def __repr__(self): def __iter__(self): """ Iterator that yields individual tuples of the current table dictionaries. - - - :param offset: parameter passed to the :func:`cursor` - :param limit: parameter passed to the :func:`cursor` - :param order_by: parameter passed to the :func:`cursor` - :param descending: parameter passed to the :func:`cursor` """ cur = self.cursor() - do_unpack = tuple(h in self.heading.blobs for h in self.heading.names) - q = cur.fetchone() - while q: - yield dict( (fieldname,unpack(field)) if up else (fieldname,field) - for fieldname, up, field in zip(self.heading.names, do_unpack, q)) - q = cur.fetchone() + heading = self.heading # construct once for efficiency + do_unpack = tuple(h in heading.blobs for h in self.heading.names) + values = cur.fetchone() + while values: + yield {field_name: unpack(value) if up else value + for field_name, up, value in zip(heading.names, do_unpack, values)} + values = cur.fetchone() @property def where_clause(self): diff --git a/datajoint/settings.py b/datajoint/settings.py index 345245733..1ab7ea0a0 100644 --- a/datajoint/settings.py +++ b/datajoint/settings.py @@ -6,14 +6,12 @@ import pprint from collections import OrderedDict -__author__ = 'eywalker' import logging import collections from enum import Enum LOCALCONFIG = 'dj_local_conf.json' CONFIGVAR = 'DJ_LOCAL_CONF' -DEFAULT_PORT = 3306 validators = collections.defaultdict(lambda: lambda value: True) validators['database.port'] = lambda a: isinstance(a, int) @@ -52,9 +50,11 @@ class Borg: _shared_state = {} + def __init__(self): self.__dict__ = self._shared_state + class Config(Borg, collections.MutableMapping): """ Stores datajoint settings. Behaves like a dictionary, but applies validator functions diff --git a/datajoint/utils.py b/datajoint/utils.py index 1189a9307..252fa1187 100644 --- a/datajoint/utils.py +++ b/datajoint/utils.py @@ -19,21 +19,16 @@ def to_upper(match): def from_camel_case(s): """ - Convert names in camel case into underscore - (_) separated names + Convert names in camel case into underscore (_) separated names Example: >>>from_camel_case("TableName") "table_name" """ - if re.search(r'\s', s): - raise DataJointError('Input cannot contain white space') - if re.match(r'\d.*', s): - raise DataJointError('Input cannot begin with a digit') - if not re.match(r'^[a-zA-Z0-9]*$', s): - raise DataJointError('String can only contain alphanumeric characters') - def convert(match): return ('_' if match.groups()[0] else '') + match.group(0).lower() - return re.sub(r'(\B[A-Z])|(\b[A-Z])', convert, s) + if not re.match(r'[A-Z][a-zA-Z0-9]*', s): + raise DataJointError( + 'ClassName must be alphanumeric in CamelCase, begin with a capital letter') + return re.sub(r'(\B[A-Z])|(\b[A-Z])', convert, s) \ No newline at end of file diff --git a/tests/test_utils.py b/tests/test_utils.py index 9bce1de8a..655884ce0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -23,7 +23,8 @@ def test_to_camel_case(): def test_from_camel_case(): assert_equal(from_camel_case('AllGroups'), 'all_groups') - assert_equal(from_camel_case('repNames'), 'rep_names') + with assert_raises(DataJointError): + from_camel_case('repNames') with assert_raises(DataJointError): from_camel_case('10_all') with assert_raises(DataJointError):