From 2d3208c29a46be196092898915ae161787356bb6 Mon Sep 17 00:00:00 2001
From: Edgar Walker
Date: Wed, 6 May 2015 19:02:12 -0500
Subject: [PATCH 1/2] Fix all parts that were broken from #39

---
 datajoint/base.py  | 278 ++++++++++++++++++++++++++++++++++++++++++++-
 datajoint/table.py |   1 -
 tests/test_base.py |   4 +-
 3 files changed, 275 insertions(+), 8 deletions(-)

diff --git a/datajoint/base.py b/datajoint/base.py
index d9ca80581..37cd4116d 100644
--- a/datajoint/base.py
+++ b/datajoint/base.py
@@ -5,6 +5,10 @@
 from . import DataJointError
 from .table import Table
 import logging
+import re
+from .settings import Role, role_to_prefix
+from .utils import from_camel_case
+from .heading import Heading
 from .declare import declare
 
 logger = logging.getLogger(__name__)
@@ -43,7 +47,7 @@ def __init__(self):
         self.class_name = self.__class__.__name__
         module = self.__module__
         mod_obj = importlib.import_module(module)
-        use_package = False
+        self._use_package = False
         try:
             conn = mod_obj.conn
         except AttributeError:
@@ -51,13 +55,13 @@ def __init__(self):
             # check if database bound at the package level instead
             pkg_obj = importlib.import_module(mod_obj.__package__)
             conn = pkg_obj.conn
-            use_package = True
+            self._use_package = True
         except AttributeError:
             raise DataJointError(
                 "Please define object 'conn' in '{}' or in its containing package.".format(self.__module__))
         self.conn = conn
         try:
-            if use_package:
+            if self._use_package:
                 pkg_name = '.'.join(module.split('.')[:-1])
                 dbname = self.conn.mod_to_db[pkg_name]
             else:
@@ -65,9 +69,275 @@ def __init__(self):
         except KeyError:
             raise DataJointError(
                 'Module {} is not bound to a database. See datajoint.connection.bind'.format(self.__module__))
-        declare(self.conn, self.definition, self.full_class_name)
+        self.dbname = dbname
+        self.declare()
         super().__init__(conn=conn, dbname=dbname, class_name=self.__class__.__name__)
 
+    @property
+    def is_declared(self):
+        self.conn.load_headings(self.dbname)
+        return self.class_name in self.conn.table_names[self.dbname]
+
+    def declare(self):
+        """
+        Declare the table in the database if it doesn't already exist.
+
+        :raises: DataJointError if the table cannot be declared.
+        """
+        if not self.is_declared:
+            self._declare()
+            if not self.is_declared:
+                raise DataJointError(
+                    'Table could not be declared for %s' % self.class_name)
+
+    def _field_to_sql(self, field):
+        """
+        Converts an attribute definition tuple into SQL code.
+
+        :param field: attribute definition
+        :rtype: SQL code
+        """
+        mysql_constants = ['CURRENT_TIMESTAMP']
+        if field.nullable:
+            default = 'DEFAULT NULL'
+        else:
+            default = 'NOT NULL'
+            # if some default specified
+            if field.default:
+                # enclose value in quotes (even numeric), except special SQL values
+                # or values already enclosed by the user
+                if field.default.upper() in mysql_constants or field.default[:1] in ["'", '"']:
+                    default = '%s DEFAULT %s' % (default, field.default)
+                else:
+                    default = '%s DEFAULT "%s"' % (default, field.default)
+
+        # TODO: escape instead! - same goes for Matlab side implementation
+        assert not any((c in r'\"' for c in field.comment)), \
+            'Illegal characters in attribute comment "%s"' % field.comment
+
+        return '`{name}` {type} {default} COMMENT "{comment}",\n'.format(
+            name=field.name, type=field.type, default=default, comment=field.comment)
+
+    def _declare(self):
+        """
+        Declares the table in the database if no matching table already exists.
+ """ + if not self.definition: + raise DataJointError('Table declaration is missing!') + table_info, parents, referenced, fieldDefs, indexDefs = self._parse_declaration() + defined_name = table_info['module'] + '.' + table_info['className'] + expected_name = self.__module__.split('.')[-1] + '.' + self.class_name + if not defined_name == expected_name: + raise DataJointError('Table name {} does not match the declared' + 'name {}'.format(expected_name, defined_name)) + + # compile the CREATE TABLE statement + # TODO: support prefix + table_name = role_to_prefix[ + table_info['tier']] + from_camel_case(self.class_name) + sql = 'CREATE TABLE `%s`.`%s` (\n' % (self.dbname, table_name) + + # add inherited primary key fields + primary_key_fields = set() + non_key_fields = set() + for p in parents: + for key in p.primary_key: + field = p.heading[key] + if field.name not in primary_key_fields: + primary_key_fields.add(field.name) + sql += self._field_to_SQL(field) + else: + logger.debug('Field definition of {} in {} ignored'.format( + field.name, p.full_class_name)) + + # add newly defined primary key fields + for field in (f for f in fieldDefs if f.in_key): + if field.nullable: + raise DataJointError('Primary key {} cannot be nullable'.format( + field.name)) + if field.name in primary_key_fields: + raise DataJointError('Duplicate declaration of the primary key ' + '{key}. Check to make sure that the key ' + 'is not declared already in referenced ' + 'tables'.format(key=field.name)) + primary_key_fields.add(field.name) + sql += self._field_to_sql(field) + + # add secondary foreign key attributes + for r in referenced: + keys = (x for x in r.heading.attrs.values() if x.in_key) + for field in keys: + if field.name not in primary_key_fields | non_key_fields: + non_key_fields.add(field.name) + sql += self._field_to_sql(field) + + # add dependent attributes + for field in (f for f in fieldDefs if not f.in_key): + non_key_fields.add(field.name) + sql += self._field_to_sql(field) + + # add primary key declaration + assert len(primary_key_fields) > 0, 'table must have a primary key' + keys = ', '.join(primary_key_fields) + sql += 'PRIMARY KEY (%s),\n' % keys + + # add foreign key declarations + for ref in parents + referenced: + keys = ', '.join(ref.primary_key) + sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \ + (keys, ref.full_table_name, keys) + + # add secondary index declarations + # gather implicit indexes due to foreign keys first + implicit_indices = [] + for fk_source in parents + referenced: + implicit_indices.append(fk_source.primary_key) + + # for index in indexDefs: + # TODO: finish this up... + + # close the declaration + sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % ( + sql[:-2], table_info['comment']) + + # make sure that the table does not alredy exist + self.conn.load_headings(self.dbname, force=True) + if not self.is_declared: + # execute declaration + logger.debug('\n\n' + sql + '\n\n') + self.conn.query(sql) + self.conn.load_headings(self.dbname, force=True) + + def _parse_declaration(self): + """ + Parse declaration and create new SQL table accordingly. 
+ """ + parents = [] + referenced = [] + index_defs = [] + field_defs = [] + declaration = re.split(r'\s*\n\s*', self.definition.strip()) + + # remove comment lines + declaration = [x for x in declaration if not x.startswith('#')] + ptrn = """ + ^(?P\w+)\.(?P\w+)\s* # module.className + \(\s*(?P\w+)\s*\)\s* # (tier) + \#\s*(?P.*)$ # comment + """ + p = re.compile(ptrn, re.X) + table_info = p.match(declaration[0]).groupdict() + if table_info['tier'] not in Role.__members__: + raise DataJointError('InvalidTableTier: Invalid tier {tier} for table\ + {module}.{cls}'.format(tier=table_info['tier'], + module=table_info[ + 'module'], + cls=table_info['className'])) + table_info['tier'] = Role[table_info['tier']] # convert into enum + + in_key = True # parse primary keys + field_ptrn = """ + ^[a-z][a-z\d_]*\s* # name + (=\s*\S+(\s+\S+)*\s*)? # optional defaults + :\s*\w.*$ # type, comment + """ + fieldP = re.compile(field_ptrn, re.I + re.X) # ignore case and verbose + + for line in declaration[1:]: + if line.startswith('---'): + in_key = False # start parsing non-PK fields + elif line.startswith('->'): + # foreign key + module_name, class_name = line[2:].strip().split('.') + rel = self.get_base(module_name, class_name) + (parents if in_key else referenced).append(rel) + elif re.match(r'^(unique\s+)?index[^:]*$', line): + index_defs.append(self._parse_index_def(line)) + elif fieldP.match(line): + field_defs.append(self._parse_attr_def(line, in_key)) + else: + raise DataJointError( + 'Invalid table declaration line "%s"' % line) + + return table_info, parents, referenced, field_defs, index_defs + + def _parse_attr_def(self, line, in_key=False): # todo add docu for in_key + """ + Parse attribute definition line in the declaration and returns + an attribute tuple. + + :param line: attribution line + :param in_key: + :returns: attribute tuple + """ + line = line.strip() + attr_ptrn = """ + ^(?P[a-z][a-z\d_]*)\s* # field name + (=\s*(?P\S+(\s+\S+)*?)\s*)? # default value + :\s*(?P\w[^\#]*[^\#\s])\s* # datatype + (\#\s*(?P\S*(\s+\S+)*)\s*)?$ # comment + """ + + attrP = re.compile(attr_ptrn, re.I + re.X) + m = attrP.match(line) + assert m, 'Invalid field declaration "%s"' % line + attr_info = m.groupdict() + if not attr_info['comment']: + attr_info['comment'] = '' + if not attr_info['default']: + attr_info['default'] = '' + attr_info['nullable'] = attr_info['default'].lower() == 'null' + assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \ + 'BIGINT attributes cannot be nullable in "%s"' % line + + attr_info['in_key'] = in_key + attr_info['autoincrement'] = None + attr_info['numeric'] = None + attr_info['string'] = None + attr_info['is_blob'] = None + attr_info['computation'] = None + attr_info['dtype'] = None + + return Heading.AttrTuple(**attr_info) + + def _parse_index_def(self, line): + """ + Parses index definition. 
+
+        :param line: definition line
+        :return: groupdict with index info
+        """
+        line = line.strip()
+        index_ptrn = """
+        ^(?P<unique>UNIQUE)?\s*INDEX\s*     # [UNIQUE] INDEX
+        \((?P<attributes>[^\)]+)\)$         # (attr1, attr2)
+        """
+        indexP = re.compile(index_ptrn, re.I + re.X)
+        m = indexP.match(line)
+        assert m, 'Invalid index declaration "%s"' % line
+        index_info = m.groupdict()
+        attributes = re.split(r'\s*,\s*', index_info['attributes'].strip())
+        index_info['attributes'] = attributes
+        assert len(attributes) == len(set(attributes)), \
+            'Duplicate attributes in index declaration "%s"' % line
+        return index_info
+
+    def get_base(self, module_name, class_name):
+        """
+        Load the base relation from the module. If the base relation is not defined in
+        the module, construct it using the Base constructor.
+
+        :param module_name: module name
+        :param class_name: class name
+        :returns: the base relation
+        """
+        mod_obj = self.get_module(module_name)
+        try:
+            ret = getattr(mod_obj, class_name)()
+        except AttributeError:
+            ret = self.__class__(conn=self.conn,
+                                 dbname=self.conn.schemas[module_name],
+                                 class_name=class_name)
+        return ret
 
     @classmethod
     def get_module(cls, module_name):
diff --git a/datajoint/table.py b/datajoint/table.py
index 14971e868..7cb6d6b97 100644
--- a/datajoint/table.py
+++ b/datajoint/table.py
@@ -25,7 +25,6 @@ class Table(Relation):
     """
 
     def __init__(self, conn=None, dbname=None, class_name=None, definition=None):
-        self._use_package = False
         self.class_name = class_name
         self.conn = conn
         self.dbname = dbname
diff --git a/tests/test_base.py b/tests/test_base.py
index 42e11bacf..61ad99e8d 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -76,12 +76,10 @@ def test_instantiation_of_base_derivatives(self):
         s = test1.Subjects()
         assert_equal(s.dbname, PREFIX + '_test1')
         assert_equal(s.conn, self.conn)
-        assert_equal(s._table_def, test1.Subjects._table_def)
+        assert_equal(s.definition, test1.Subjects.definition)
 
     def test_declaration_status(self):
         b = test1.Subjects()
-        assert_false(b.is_declared)
-        b.declare()
         assert_true(b.is_declared)
 
     def test_declaration_from_doc_string(self):

From a17cf837b52ca1cb7ce88acbd4ffc836bceef19f Mon Sep 17 00:00:00 2001
From: Edgar Walker
Date: Wed, 6 May 2015 19:05:22 -0500
Subject: [PATCH 2/2] Delete declare.py

---
 datajoint/base.py    |   1 -
 datajoint/declare.py | 241 -------------------------------------------
 datajoint/table.py   |   1 -
 3 files changed, 243 deletions(-)
 delete mode 100644 datajoint/declare.py

diff --git a/datajoint/base.py b/datajoint/base.py
index 37cd4116d..ce30adfe9 100644
--- a/datajoint/base.py
+++ b/datajoint/base.py
@@ -9,7 +9,6 @@
 from .settings import Role, role_to_prefix
 from .utils import from_camel_case
 from .heading import Heading
-from .declare import declare
 
 logger = logging.getLogger(__name__)
 
diff --git a/datajoint/declare.py b/datajoint/declare.py
deleted file mode 100644
index d95468e8e..000000000
--- a/datajoint/declare.py
+++ /dev/null
@@ -1,241 +0,0 @@
-import re
-import logging
-from .heading import Heading
-from . import DataJointError
-from .utils import from_camel_case
-from .settings import Role, role_to_prefix
-
-mysql_constants = ['CURRENT_TIMESTAMP']
-
-logger = logging.getLogger(__name__)
-
-
-def declare(conn, definition, class_name):
-    """
-    Declares the table in the data base if no table in the database matches this object.
-    """
-    table_info, parents, referenced, field_definitions, index_definitions = _parse_declaration(conn, definition)
-    defined_name = table_info['module'] + '.' + table_info['className']
-    # TODO: clean up this mess... currently just ignoring the name used to define the table
-    #if not defined_name == class_name:
-    #    raise DataJointError('Table name {} does not match the declared'
-    #                         'name {}'.format(class_name, defined_name))
-
-    # compile the CREATE TABLE statement
-    table_name = role_to_prefix[table_info['tier']] + from_camel_case(defined_name)
-    sql = 'CREATE TABLE `%s`.`%s` (\n' % (self.dbname, table_name)
-
-    # add inherited primary key fields
-    primary_key_fields = set()
-    non_key_fields = set()
-    for p in parents:
-        for key in p.primary_key:
-            field = p.heading[key]
-            if field.name not in primary_key_fields:
-                primary_key_fields.add(field.name)
-                sql += field_to_sql(field)
-            else:
-                logger.debug('Field definition of {} in {} ignored'.format(
-                    field.name, p.full_class_name))
-
-    # add newly defined primary key fields
-    for field in (f for f in field_definitions if f.isKey):
-        if field.nullable:
-            raise DataJointError('Primary key {} cannot be nullable'.format(
-                field.name))
-        if field.name in primary_key_fields:
-            raise DataJointError('Duplicate declaration of the primary key '
-                                 '{key}. Check to make sure that the key '
-                                 'is not declared already in referenced '
-                                 'tables'.format(key=field.name))
-        primary_key_fields.add(field.name)
-        sql += field_to_sql(field)
-
-    # add secondary foreign key attributes
-    for r in referenced:
-        keys = (x for x in r.heading.attrs.values() if x.isKey)
-        for field in keys:
-            if field.name not in primary_key_fields | non_key_fields:
-                non_key_fields.add(field.name)
-                sql += field_to_sql(field)
-
-    # add dependent attributes
-    for field in (f for f in field_definitions if not f.isKey):
-        non_key_fields.add(field.name)
-        sql += field_to_sql(field)
-
-    # add primary key declaration
-    assert len(primary_key_fields) > 0, 'table must have a primary key'
-    keys = ', '.join(primary_key_fields)
-    sql += 'PRIMARY KEY (%s),\n' % keys
-
-    # add foreign key declarations
-    for ref in parents + referenced:
-        keys = ', '.join(ref.primary_key)
-        sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \
-               (keys, ref.full_table_name, keys)
-
-    # add secondary index declarations
-    # gather implicit indexes due to foreign keys first
-    implicit_indices = []
-    for fk_source in parents + referenced:
-        implicit_indices.append(fk_source.primary_key)
-
-    # for index in index_definitions:
-    #     TODO: finish this up...
-
-    # close the declaration
-    sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % (
-        sql[:-2], table_info['comment'])
-
-    # make sure that the table does not alredy exist
-    # TODO: there will be a problem with resolving the module here...
-    conn.load_headings(self.dbname, force=True)
-    if not self.is_declared:
-        # execute declaration
-        logger.debug('\n\n' + sql + '\n\n')
-        self.conn.query(sql)
-        self.conn.load_headings(self.dbname, force=True)
-
-
-def _parse_declaration(conn, definition):
-    """
-    Parse declaration and create new SQL table accordingly.
- """ - parents = [] - referenced = [] - index_defs = [] - field_defs = [] - declaration = re.split(r'\s*\n\s*', definition.strip()) - - # remove comment lines - declaration = [x for x in declaration if not x.startswith('#')] - ptrn = """ - ^(?P\w+)\.(?P\w+)\s* # module.className - \(\s*(?P\w+)\s*\)\s* # (tier) - \#\s*(?P.*)$ # comment - """ - p = re.compile(ptrn, re.X) - table_info = p.match(declaration[0]).groupdict() - if table_info['tier'] not in Role.__members__: - raise DataJointError('InvalidTableTier: Invalid tier {tier} for table\ - {module}.{cls}'.format(tier=table_info['tier'], - module=table_info[ - 'module'], - cls=table_info['className'])) - table_info['tier'] = Role[table_info['tier']] # convert into enum - - in_key = True # parse primary keys - field_ptrn = """ - ^[a-z][a-z\d_]*\s* # name - (=\s*\S+(\s+\S+)*\s*)? # optional defaults - :\s*\w.*$ # type, comment - """ - fieldP = re.compile(field_ptrn, re.I + re.X) # ignore case and verbose - - for line in declaration[1:]: - if line.startswith('---'): - in_key = False # start parsing non-PK fields - elif line.startswith('->'): - # foreign key - module_name, class_name = line[2:].strip().split('.') - rel = self.get_base(module_name, class_name) - (parents if in_key else referenced).append(rel) - elif re.match(r'^(unique\s+)?index[^:]*$', line): - index_defs.append(parse_index_defnition(line)) - elif fieldP.match(line): - field_defs.append(parse_attribute_definition(line, in_key)) - else: - raise DataJointError( - 'Invalid table declaration line "%s"' % line) - - return table_info, parents, referenced, field_defs, index_defs - - -def field_to_sql(field): - """ - Converts an attribute definition tuple into SQL code. - :param field: attribute definition - :rtype : SQL code - """ - if field.nullable: - default = 'DEFAULT NULL' - else: - default = 'NOT NULL' - # if some default specified - if field.default: - # enclose value in quotes (even numeric), except special SQL values - # or values already enclosed by the user - if field.default.upper() in mysql_constants or field.default[:1] in ["'", '"']: - default = '%s DEFAULT %s' % (default, field.default) - else: - default = '%s DEFAULT "%s"' % (default, field.default) - - # TODO: escape instead! - same goes for Matlab side implementation - assert not any((c in r'\"' for c in field.comment)), \ - 'Illegal characters in attribute comment "%s"' % field.comment - - return '`{name}` {type} {default} COMMENT "{comment}",\n'.format( - name=field.name, type=field.type, default=default, comment=field.comment) - - -def parse_attribute_definition(line, in_key=False): # todo add docu for in_key - """ - Parse attribute definition line in the declaration and returns - an attribute tuple. - :param line: attribution line - :param in_key: - :returns: attribute tuple - """ - line = line.strip() - attr_ptrn = """ - ^(?P[a-z][a-z\d_]*)\s* # field name - (=\s*(?P\S+(\s+\S+)*?)\s*)? 
-    :\s*(?P<type>\w[^\#]*[^\#\s])\s*            # datatype
-    (\#\s*(?P<comment>\S*(\s+\S+)*)\s*)?$       # comment
-    """
-
-    attrP = re.compile(attr_ptrn, re.I + re.X)
-    m = attrP.match(line)
-    assert m, 'Invalid field declaration "%s"' % line
-    attr_info = m.groupdict()
-    if not attr_info['comment']:
-        attr_info['comment'] = ''
-    if not attr_info['default']:
-        attr_info['default'] = ''
-    attr_info['nullable'] = attr_info['default'].lower() == 'null'
-    assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \
-        'BIGINT attributes cannot be nullable in "%s"' % line
-
-    attr_info['in_key'] = in_key
-    attr_info['autoincrement'] = None
-    attr_info['numeric'] = None
-    attr_info['string'] = None
-    attr_info['is_blob'] = None
-    attr_info['computation'] = None
-    attr_info['dtype'] = None
-
-    return Heading.AttrTuple(**attr_info)
-
-
-def parse_index_definition(line):  # why is this a method of Base instead of a local function?
-    """
-    Parses index definition.
-
-    :param line: definition line
-    :return: groupdict with index info
-    """
-    line = line.strip()
-    index_ptrn = """
-    ^(?P<unique>UNIQUE)?\s*INDEX\s*     # [UNIQUE] INDEX
-    \((?P<attributes>[^\)]+)\)$         # (attr1, attr2)
-    """
-    indexP = re.compile(index_ptrn, re.I + re.X)
-    m = indexP.match(line)
-    assert m, 'Invalid index declaration "%s"' % line
-    index_info = m.groupdict()
-    attributes = re.split(r'\s*,\s*', index_info['attributes'].strip())
-    index_info['attributes'] = attributes
-    assert len(attributes) == len(set(attributes)), \
-        'Duplicate attributes in index declaration "%s"' % line
-    return index_info
diff --git a/datajoint/table.py b/datajoint/table.py
index 7cb6d6b97..77164e395 100644
--- a/datajoint/table.py
+++ b/datajoint/table.py
@@ -2,7 +2,6 @@
 import logging
 from . import DataJointError
 from .relational import Relation
-from .declare import (declare, parse_attribute_definition)
 
 logger = logging.getLogger(__name__)
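
Reviewer note (not part of the patch): below is a minimal sketch of a definition string
that the declaration parser added in PATCH 1/2 would accept. The module, class, tier, and
attribute names (test1, Subjects, manual, subject_id, ...) are hypothetical; the snippet
assumes that Base is importable from the datajoint package, that 'manual' is a member of
Role, and that the test1 module defines a 'conn' object and is bound to a database (see
datajoint.connection.bind, as referenced in __init__ above).

    from datajoint import Base  # assumed import path

    class Subjects(Base):
        definition = """
        test1.Subjects (manual)   # experimental subjects
        subject_id        : int            # unique subject id
        ---
        species = "mouse" : varchar(30)    # subject species
        date_of_birth = null : date        # birth date, if known
        """

    # Instantiating the class runs declare(): the definition is parsed by
    # _parse_declaration() and, if the table is not yet declared, the generated
    # CREATE TABLE statement is executed.
    subjects = Subjects()

The first line must match module.ClassName(tier) # comment, '---' separates primary-key
attributes from dependent ones, '-> module.Class' lines would declare foreign keys, and
each attribute line follows name = default : type # comment, as enforced by the regular
expressions in _parse_declaration() and _parse_attr_def().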