diff --git a/datajoint/__init__.py b/datajoint/__init__.py index e59a008cc..e193f88a8 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -5,6 +5,8 @@ __version__ = "0.2" __all__ = ['__author__', '__version__', 'Connection', 'Heading', 'Relation', 'FreeRelation', 'Not', + 'BaseRelation', + 'ManualRelation', 'LookupRelation', 'ImportedRelation', 'ComputedRelation', 'AutoPopulate', 'conn', 'DataJointError', 'blob'] @@ -28,13 +30,11 @@ def resolve(self): f, args, kwargs = self.operations return f(*args, **kwargs) - @property def culprit(self): return self.operations[0].__name__ - # ----------- loads local configuration from file ---------------- from .settings import Config, CONFIGVAR, LOCALCONFIG, logger, log_levels config = Config() @@ -55,9 +55,9 @@ def culprit(self): # ------------- flatten import hierarchy ------------------------- from .connection import conn, Connection -from .relation import Relation +from .user_relations import ManualRelation, LookupRelation, ImportedRelation, ComputedRelation +from .base_relation import BaseRelation from .autopopulate import AutoPopulate from . import blob from .relational_operand import Not -from .relation import FreeRelation from .heading import Heading \ No newline at end of file diff --git a/datajoint/relation.py b/datajoint/base_relation.py similarity index 57% rename from datajoint/relation.py rename to datajoint/base_relation.py index 4b3f67146..fb30ae959 100644 --- a/datajoint/relation.py +++ b/datajoint/base_relation.py @@ -1,124 +1,96 @@ -from _collections_abc import MutableMapping, Mapping +from collections.abc import MutableMapping, Mapping import numpy as np import logging +import re +import abc + from . import DataJointError, config, TransactionError from .relational_operand import RelationalOperand from .blob import pack +from .utils import user_choice +from .parsing import parse_attribute_definition, field_to_sql, parse_index_definition from .heading import Heading -import re -from .settings import Role, role_to_prefix -from .utils import from_camel_case, user_choice -from .connection import conn -import abc logger = logging.getLogger(__name__) -class Relation(RelationalOperand): +class BaseRelation(RelationalOperand, metaclass=abc.ABCMeta): """ - A FreeRelation object is a relation associated with a table. - A FreeRelation object provides insert and delete methods. - FreeRelation objects are only used internally and for debugging. - The table must already exist in the schema for its FreeRelation object to work. - - The table associated with an instance of Relation is identified by its 'class name'. - property, which is a string in CamelCase. The actual table name is obtained - by converting className from CamelCase to underscore_separated_words and - prefixing according to the table's role. - - Relation instances obtain their table's heading by looking it up in the connection - object. This ensures that Relation instances contain the current table definition - even after tables are modified after the instance is created. + BaseRelation is an abstract class that represents a base relation, i.e. a table in the database. + To make it a concrete class, override the abstract properties specifying the connection, + table name, database, context, and definition. + A BaseRelation implements insert and delete methods in addition to inherited relational operators. + It also loads table heading and dependencies from the database. + It also handles the table declaration based on its definition property """ - # defines class properties - - - def __init__(self, table_name, schema_name=None, connection=None, definition=None, context=None): - self._table_name = table_name - self._schema_name = schema_name - if connection is None: - connection = conn() - self._connection = connection - self._definition = definition - if context is None: - context = {} - self._context = context - self._heading = None + __heading = None + # ---------- abstract properties ------------ # @property - def schema_name(self): - return self._schema_name + @abc.abstractmethod + def table_name(self): + """ + :return: the name of the table in the database + """ + pass @property - def connection(self): - return self._connection + @abc.abstractmethod + def database(self): + """ + :return: string containing the database name on the server + """ + pass @property + @abc.abstractmethod def definition(self): - return self._definition + """ + :return: a string containing the table definition using the DataJoint DDL + """ + pass @property + @abc.abstractmethod def context(self): - return self._context - - @property - def heading(self): - return self._heading - - @heading.setter - def heading(self, new_heading): - self._heading = new_heading - - @property - def table_prefix(self): - return '' - - @property - def table_name(self): """ - TODO: allow table kind to be specified - :return: name of the table. This is equal to table_prefix + class name with underscores + :return: a dict with other relations that can be referenced by foreign keys """ - return self._table_name + pass + # --------- base relation functionality --------- # @property - def definition(self): - return self._definition - - - # ============================== Shared implementations ============================== + def is_declared(self): + if self.__heading is not None: + return True + cur = self.query( + 'SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format( + table_name=self.table_name)) + return cur.rowcount == 1 @property - def full_table_name(self): + def heading(self): """ - :return: full name of the associated table + Required by relational operand + :return: a datajoint.Heading object """ - return '`%s`.`%s`' % (self.schema_name, self.table_name) + if self.__heading is None: + if not self.is_declared and self.definition: + self.declare() + if self.is_declared: + self.__heading = Heading.init_from_database( + self.connection, self.database, self.table_name) + return self.__heading @property def from_clause(self): - return self.full_table_name - - # TODO: consider if this should be a class method for derived classes - def load_heading(self, forced=False): """ - Load the heading information for this table. If the table does not exist in the database server, Heading will be - set to None if the table is not yet defined in the database. + Required by the Relational class, this property specifies the contents of the FROM clause + for the SQL SELECT statements. + :return: """ - pass - # TODO: I want to be able to tell whether load_heading has already been attempted in the past... `self.heading is None` is not informative - # TODO: make sure to assign new heading to self.heading, not to self._heading or any other direct variables - - @property - def is_declared(self): - #TODO: this implementation is rather expensive and stupid - # - if table is not declared yet, repeated call to this method causes loading attempt each time - - if self.heading is None: - self.load_heading() - return self.heading is not None - + return '`%s`.`%s`' % (self.database, self.table_name) def declare(self): """ @@ -131,47 +103,15 @@ def declare(self): # verify that declaration completed successfully if not self.is_declared: raise DataJointError( - 'FreeRelation could not be declared for %s' % self.class_name) - - @staticmethod - def _field_to_sql(field): # TODO move this into Attribute Tuple - """ - Converts an attribute definition tuple into SQL code. - :param field: attribute definition - :rtype : SQL code - """ - mysql_constants = ['CURRENT_TIMESTAMP'] - if field.nullable: - default = 'DEFAULT NULL' - else: - default = 'NOT NULL' - # if some default specified - if field.default: - # enclose value in quotes except special SQL values or already enclosed - quote = field.default.upper() not in mysql_constants and field.default[0] not in '"\'' - default += ' DEFAULT ' + ('"%s"' if quote else "%s") % field.default - if any((c in r'\"' for c in field.comment)): - raise DataJointError('Illegal characters in attribute comment "%s"' % field.comment) - - return '`{name}` {type} {default} COMMENT "{comment}",\n'.format( - name=field.name, type=field.type, default=default, comment=field.comment) - - - - @property - def primary_key(self): - """ - :return: primary key of the table - """ - return self.heading.primary_key + 'BaseRelation could not be declared for %s' % self.class_name) - def iter_insert(self, iter, **kwargs): + def iter_insert(self, rows, **kwargs): """ Inserts an entire batch of entries. Additional keyword arguments are passed to insert. :param iter: Must be an iterator that generates a sequence of valid arguments for insert. """ - for row in iter: + for row in rows: self.insert(row, **kwargs) def batch_insert(self, data, **kwargs): @@ -285,7 +225,7 @@ def add_attribute(self, definition, after=None): """ position = ' FIRST' if after is None else ( ' AFTER %s' % after if after else '') - sql = self.field_to_sql(parse_attribute_definition(definition)) + sql = field_to_sql(parse_attribute_definition(definition)) self._alter('ADD COLUMN %s%s' % (sql[:-2], position)) def drop_attribute(self, attr_name): @@ -307,7 +247,7 @@ def alter_attribute(self, attr_name, new_definition): :param attr_name: field that is redefined :param new_definition: new definition of the field """ - sql = self.field_to_sql(parse_attribute_definition(new_definition)) + sql = field_to_sql(parse_attribute_definition(new_definition)) self._alter('CHANGE COLUMN `%s` %s' % (attr_name, sql[:-2])) def erd(self, subset=None): @@ -333,28 +273,6 @@ def _alter(self, alter_statement): # TODO: place table definition sync mechanism @staticmethod - def _parse_index_def(line): - """ - Parses index definition. - - :param line: definition line - :return: groupdict with index info - """ - line = line.strip() - index_regexp = re.compile(""" - ^(?PUNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX - \((?P[^\)]+)\)$ # (attr1, attr2) - """, re.I + re.X) - m = index_regexp.match(line) - assert m, 'Invalid index declaration "%s"' % line - index_info = m.groupdict() - attributes = re.split(r'\s*,\s*', index_info['attributes'].strip()) - index_info['attributes'] = attributes - assert len(attributes) == len(set(attributes)), \ - 'Duplicate attributes in index declaration "%s"' % line - return index_info - - def _declare(self): """ Declares the table in the database if no table in the database matches this object. @@ -377,7 +295,7 @@ def _declare(self): field = p.heading[key] if field.name not in primary_key_fields: primary_key_fields.add(field.name) - sql += self._field_to_sql(field) + sql += field_to_sql(field) else: logger.debug('Field definition of {} in {} ignored'.format( field.name, p.full_class_name)) @@ -392,7 +310,7 @@ def _declare(self): 'Ensure that the attribute is not already declared ' 'in referenced tables'.format(key=field.name)) primary_key_fields.add(field.name) - sql += self._field_to_sql(field) + sql += field_to_sql(field) # add secondary foreign key attributes for r in referenced: @@ -400,12 +318,12 @@ def _declare(self): field = r.heading[key] if field.name not in primary_key_fields | non_key_fields: non_key_fields.add(field.name) - sql += self._field_to_sql(field) + sql += field_to_sql(field) # add dependent attributes for field in (f for f in field_defs if not f.in_key): non_key_fields.add(field.name) - sql += self._field_to_sql(field) + sql += field_to_sql(field) # add primary key declaration assert len(primary_key_fields) > 0, 'table must have a primary key' @@ -475,7 +393,7 @@ def _parse_declaration(self): ref_list = parents if in_key else referenced ref_list.append(self.lookup_name(ref_name)) elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): - index_defs.append(self._parse_index_def(line)) + index_defs.append(parse_index_definition(line)) elif attribute_regexp.match(line): field_defs.append(parse_attribute_definition(line, in_key)) else: @@ -497,147 +415,8 @@ def lookup_name(self, name): for attr in parts[1:]: ref = getattr(ref, attr) except (KeyError, AttributeError): - raise DataJointError('Foreign reference %s could not be resolved. Please make sure the name exists' - 'in the context of the class' % name) - return ref - -class ClassRelation(Relation, metaclass=abc.ABCMeta): - """ - A relation object that is handled at class level. All instances of the derived classes - share common connection and schema binding - """ - - _connection = None # connection information - _schema_name = None # name of schema this relation belongs to - _heading = None # heading information for this relation - _context = None # name reference lookup context - - def __init__(self, schema_name=None, connection=None, context=None): - """ - Use this constructor to specify class level - """ - if schema_name is not None: - self.schema_name = schema_name - - # TODO: Think about this implementation carefully - if connection is not None: - self.connection = connection - elif self.connection is None: - self.connection = conn() - - if context is not None: - self.context = context - elif self.context is None: - self.context = {} # initialize with an empty dictionary - - @property - def schema_name(self): - return self.__class__._schema_name - - @schema_name.setter - def schema_name(self, new_schema_name): - if self.schema_name is not None: - logger.warn('Overriding associated schema for class %s' - '- this will affect all existing instances!' % self.__class__.__name__) - self.__class__._schema_name = new_schema_name - - @property - def connection(self): - return self.__class__._connection - - @connection.setter - def connection(self, new_connection): - if self.connection is not None: - logger.warn('Overriding associated connection for class %s' - '- this will affect all existing instances!' % self.__class__.__name__) - self.__class__._connection = new_connection - - @property - def context(self): - # TODO: should this be a copy or the original? - return self.__class__._context.copy() - - @context.setter - def context(self, new_context): - if self.context is not None: - logger.warn('Overriding associated reference context for class %s' - '- this will affect all existing instances!' % self.__class__.__name__) - self.__class__._context = new_context - - @property - def heading(self): - return self.__class__._heading - - @heading.setter - def heading(self, new_heading): - self.__class__._heading = new_heading - - @abc.abstractproperty - def definition(self): - """ - Inheriting class must override this property with a valid table definition string - """ - pass - - @abc.abstractproperty - def table_prefix(self): - pass - - -class ManualRelation(ClassRelation): - @property - def table_prefix(self): - return "" - - -class AutoRelation(ClassRelation): - pass - - -class ComputedRelation(AutoRelation): - @property - def table_prefix(self): - return "_" - - - - - -def parse_attribute_definition(line, in_key=False): - """ - Parse attribute definition line in the declaration and returns - an attribute tuple. - - :param line: attribution line - :param in_key: set to True if attribute is in primary key set - :returns: attribute tuple - """ - line = line.strip() - attribute_regexp = re.compile(""" - ^(?P[a-z][a-z\d_]*)\s* # field name - (=\s*(?P\S+(\s+\S+)*?)\s*)? # default value - :\s*(?P\w[^\#]*[^\#\s])\s* # datatype - (\#\s*(?P\S*(\s+\S+)*)\s*)?$ # comment - """, re.X) - m = attribute_regexp.match(line) - if not m: - raise DataJointError('Invalid field declaration "%s"' % line) - attr_info = m.groupdict() - if not attr_info['comment']: - attr_info['comment'] = '' - if not attr_info['default']: - attr_info['default'] = '' - attr_info['nullable'] = attr_info['default'].lower() == 'null' - assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \ - 'BIGINT attributes cannot be nullable in "%s"' % line - - return Heading.AttrTuple( - in_key=in_key, - autoincrement=None, - numeric=None, - string=None, - is_blob=None, - computation=None, - dtype=None, - **attr_info - ) + raise DataJointError( + 'Foreign key reference to %s could not be resolved.' + 'Please make sure the name exists' + 'in the context of the class' % name) + return ref \ No newline at end of file diff --git a/datajoint/connection.py b/datajoint/connection.py index 572cede34..52dae7598 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -10,17 +10,12 @@ logger = logging.getLogger(__name__) -# The following two regular expression are equivalent but one works in python -# and the other works in MySQL -table_name_regexp_sql = re.compile('^(#|_|__|~)?[a-z][a-z0-9_]*$') -table_name_regexp = re.compile('^(|#|_|__|~)[a-z][a-z0-9_]*$') # MySQL does not accept this but MariaDB does - def conn_container(): """ creates a persistent connections for everyone to use """ - _connObj = None # persistent connection object used by dj.conn() + _connection = None # persistent connection object used by dj.conn() def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False): """ @@ -30,8 +25,8 @@ def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False) Set rest=True to reset the persistent connection object """ - nonlocal _connObj - if not _connObj or reset: + nonlocal _connection + if not _connection or reset: host = host if host is not None else config['database.host'] user = user if user is not None else config['database.user'] passwd = passwd if passwd is not None else config['database.password'] @@ -39,18 +34,16 @@ def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False) if passwd is None: passwd = input("Please enter database password: ") init_fun = init_fun if init_fun is not None else config['connection.init_function'] - _connObj = Connection(host, user, passwd, init_fun) - return _connObj + _connection = Connection(host, user, passwd, init_fun) + return _connection return conn_function -# The function conn is used by others to obtain the package wide persistent connection object +# The function conn is used by others to obtain a connection object conn = conn_container() - - -class Connection(object): +class Connection: """ A dj.Connection object manages a connection to a database server. It also catalogues modules, schemas, tables, and their dependencies (foreign keys). @@ -73,23 +66,16 @@ def __init__(self, host, user, passwd, init_fun=None): self.conn_info = dict(host=host, port=port, user=user, passwd=passwd) self._conn = pymysql.connect(init_command=init_fun, **self.conn_info) if self.is_connected: - print("Connected", user + '@' + host + ':' + str(port)) + logger.info("Connected " + user + '@' + host + ':' + str(port)) else: raise DataJointError('Connection failed.') self._conn.autocommit(True) - - self.db_to_mod = {} # modules indexed by dbnames - self.mod_to_db = {} # database names indexed by modules - self.table_names = {} # tables names indexed by [dbname][class_name] - self.headings = {} # contains headings indexed by [dbname][table_name] - self.tableInfo = {} # table information indexed by [dbname][table_name] - - # dependencies from foreign keys - self.parents = {} # maps table names to their parent table names (primary foreign key) - self.referenced = {} # maps table names to table names they reference (non-primary foreign key - self._graph = DBConnGraph(self) # initialize an empty connection graph self._in_transaction = False + def __del__(self): + logger.info('Disconnecting {user}@{host}:{port}'.format(**self.conn_info)) + self._conn.close() + def __eq__(self, other): return self.conn_info == other.conn_info @@ -100,216 +86,6 @@ def is_connected(self): """ return self._conn.ping() - def get_full_module_name(self, module): - """ - Returns full module name of the module. - - :param module: module for which the name is requested. - :return: full module name - """ - return '.'.join(self.root_package, module) - - def bind(self, module, dbname): - """ - Binds the `module` name to the database named `dbname`. - Throws an error if `dbname` is already bound to another module. - - If the database `dbname` does not exist in the server, attempts - to create the database and then bind the module. - - - :param module: module name. - :param dbname: database name. It should be a valid database identifier and not a match pattern. - """ - - if dbname in self.db_to_mod: - raise DataJointError('Database `%s` is already bound to module `%s`' - % (dbname, self.db_to_mod[dbname])) - - cur = self.query("SHOW DATABASES LIKE '{dbname}'".format(dbname=dbname)) - count = cur.rowcount - - if count == 1: - # Database exists - self.db_to_mod[dbname] = module - self.mod_to_db[module] = dbname - elif count == 0: - # Database doesn't exist, attempt to create - logger.info("Database `{dbname}` could not be found. " - "Attempting to create the database.".format(dbname=dbname)) - try: - self.query("CREATE DATABASE `{dbname}`".format(dbname=dbname)) - logger.info('Created database `{dbname}`.'.format(dbname=dbname)) - self.db_to_mod[dbname] = module - self.mod_to_db[module] = dbname - except pymysql.OperationalError: - raise DataJointError("Database named `{dbname}` was not defined, and" - " an attempt to create has failed. Check" - " permissions.".format(dbname=dbname)) - else: - raise DataJointError("Database name {dbname} matched more than one " - "existing databases. Database name should not be " - "a pattern.".format(dbname=dbname)) - - def load_headings(self, dbname=None, force=False): - """ - Load table information including roles and list of attributes for all - tables within dbname by examining respective table status. - - If dbname is not specified or None, will load headings for all - databases that are bound to a module. - - By default, the heading is not loaded again if it already exists. - Setting force=True will result in reloading of the heading even if one - already exists. - - :param dbname=None: database name - :param force=False: force reloading the heading - """ - if dbname: - self._load_headings(dbname, force) - return - - for dbname in self.db_to_mod: - self._load_headings(dbname, force) - - def _load_headings(self, dbname, force=False): - """ - Load table information including roles and list of attributes for all - tables within dbname by examining respective table status. - - By default, the heading is not loaded again if it already exists. - Setting force=True will result in reloading of the heading even if one - already exists. - - :param dbname: database name - :param force: force reloading the heading - """ - if dbname not in self.headings or force: - logger.info('Loading table definitions from `{dbname}`...'.format(dbname=dbname)) - self.table_names[dbname] = {} - self.headings[dbname] = {} - self.tableInfo[dbname] = {} - - cur = self.query('SHOW TABLE STATUS FROM `{dbname}` WHERE name REGEXP "{sqlPtrn}"'.format( - dbname=dbname, sqlPtrn=table_name_regexp_sql.pattern), as_dict=True) - - for info in cur: - info = {k.lower(): v for k, v in info.items()} # lowercase it - table_name = info.pop('name') - # look up role by table name prefix - role = prefix_to_role[table_name_regexp.match(table_name).group(1)] - class_name = to_camel_case(table_name) - self.table_names[dbname][class_name] = table_name - self.tableInfo[dbname][table_name] = dict(info, role=role) - self.headings[dbname][table_name] = Heading.init_from_database(self, dbname, table_name) - self.load_dependencies(dbname) - - def load_dependencies(self, dbname): # TODO: Perhaps consider making this "private" by preceding with underscore? - """ - Load dependencies (foreign keys) between tables by examining their - respective CREATE TABLE statements. - - :param dbname: database name - """ - - foreign_key_regexp = re.compile(r""" - FOREIGN\ KEY\s+\((?P[`\w ,]+)\)\s+ # list of keys in this table - REFERENCES\s+(?P[^\s]+)\s+ # table referenced - \((?P[`\w ,]+)\) # list of keys in the referenced table - """, re.X) - - logger.info('Loading dependencies for `{dbname}`'.format(dbname=dbname)) - - for tabName in self.tableInfo[dbname]: - cur = self.query('SHOW CREATE TABLE `{dbname}`.`{tabName}`'.format(dbname=dbname, tabName=tabName), - as_dict=True) - table_def = cur.fetchone() - full_table_name = '`%s`.`%s`' % (dbname, tabName) - self.parents[full_table_name] = [] - self.referenced[full_table_name] = [] - - for m in foreign_key_regexp.finditer(table_def["Create Table"]): # iterate through foreign key statements - assert m.group('attr1') == m.group('attr2'), \ - 'Foreign keys must link identically named attributes' - attrs = m.group('attr1') - attrs = re.split(r',\s+', re.sub(r'`(.*?)`', r'\1', attrs)) # remove ` around attrs and split into list - pk = self.headings[dbname][tabName].primary_key - is_primary = all([k in pk for k in attrs]) - ref = m.group('ref') # referenced table - - if not re.search(r'`\.`', ref): # if referencing other table in same schema - ref = '`%s`.%s' % (dbname, ref) # convert to full-table name - - (self.parents if is_primary else self.referenced)[full_table_name].append(ref) - self.parents.setdefault(ref, []) - self.referenced.setdefault(ref, []) - - def clear_dependencies(self, dbname=None): - """ - Clears dependency mapping originating from `dbname`. If `dbname` is not - specified, dependencies for all databases will be cleared. - - - :param dbname: database name - """ - if dbname is None: # clear out all dependencies - self.parents.clear() - self.referenced.clear() - else: - table_keys = ('`%s`.`%s`' % (dbname, tblName) for tblName in self.tableInfo[dbname]) - for key in table_keys: - if key in self.parents: - self.parents.pop(key) - if key in self.referenced: - self.referenced.pop(key) - - def parents_of(self, child_table): - """ - Returns a list of tables that are parents of the specified child_table. Parent-child relationship is defined - based on the presence of primary-key foreign reference: table that holds a foreign key relation to another table - is the child table. - - :param child_table: the child table - :return: list of parent tables - """ - return self.parents.get(child_table, []).copy() - - def children_of(self, parent_table): - """ - Returns a list of tables for which parent_table is a parent (primary foreign key). Parent-child relationship - is defined based on the presence of primary-key foreign reference: table that holds a foreign key relation to - another table is the child table. - - :param parent_table: parent table - :return: list of child tables - """ - return [child_table for child_table, parents in self.parents.items() if parent_table in parents] - - def referenced_by(self, referencing_table): - """ - Returns a list of tables that are referenced by non-primary foreign key relation - by the referencing_table. - - :param referencing_table: referencing table - :return: list of tables that are referenced by the target table - """ - return self.referenced.get(referencing_table, []).copy() - - def referencing(self, referenced_table): - """ - Returns a list of tables that references referenced_table as non-primary foreign key - - :param referenced_table: referenced table - :return: list of tables that refers to the target table - """ - return [referencing for referencing, referenced in self.referenced.items() - if referenced_table in referenced] - - # TODO: Reimplement __str__ - def __str__(self): - return self.__repr__() # placeholder until more suitable __str__ is implemented - def __repr__(self): connected = "connected" if self.is_connected else "disconnected" return "DataJoint connection ({connected}) {user}@{host}:{port}".format( @@ -319,25 +95,6 @@ def __del__(self): logger.info('Disconnecting {user}@{host}:{port}'.format(**self.conn_info)) self._conn.close() - def erd(self, databases=None, tables=None, fill=True, reload=False): - """ - Creates Entity Relation Diagram for the database or specified subset of - tables. - - Set `fill` to False to only display specified tables. (By default - connection tables are automatically included) - """ - self._graph.update_graph(reload=reload) # update the graph - - graph = self._graph.copy_graph() - if databases: - graph = graph.restrict_by_modules(databases, fill) - - if tables: - graph = graph.restrict_by_tables(tables, fill) - - return graph - def query(self, query, args=(), as_dict=False): """ Execute the specified query and return the tuple generator. @@ -353,7 +110,7 @@ def query(self, query, args=(), as_dict=False): cur.execute(query, args) return cur - + # ---------- transaction processing ------------------ @property def in_transaction(self): self._in_transaction = self._in_transaction and self.is_connected @@ -374,5 +131,4 @@ def cancel_transaction(self): def commit_transaction(self): self.query('COMMIT') self._in_transaction = False - logger.info("Transaction commited and closed.") - + logger.info("Transaction committed and closed.") diff --git a/datajoint/decorators.py b/datajoint/decorators.py deleted file mode 100644 index 9ed650dd7..000000000 --- a/datajoint/decorators.py +++ /dev/null @@ -1,29 +0,0 @@ -__author__ = 'eywalker' -from .connection import conn - -def schema(name, context, connection=None): #TODO consider moving this into relation module - """ - Returns a schema decorator that can be used to associate a Relation class to a - specific database with :param name:. Name reference to other tables in the table definition - will be resolved by looking up the corresponding key entry in the passed in context dictionary. - It is most common to set context equal to the return value of call to locals() in the module. - For more details, please refer to the tutorial online. - - :param name: name of the database to associate the decorated class with - :param context: dictionary used to resolve (any) name references within the table definition string - :param connection: connection object to the database server. If ommited, will try to establish connection according to - config values - :return: a decorator function to be used on Relation derivative classes - """ - if connection is None: - connection = conn() - - def _dec(cls): - cls._schema_name = name - cls._context = context - cls._connection = connection - return cls - - return _dec - - diff --git a/datajoint/heading.py b/datajoint/heading.py index 620c93751..73d6c30f2 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -91,13 +91,13 @@ def __iter__(self): return iter(self.attributes) @classmethod - def init_from_database(cls, conn, dbname, table_name): + def init_from_database(cls, conn, database, table_name): """ initialize heading from a database table """ cur = conn.query( - 'SHOW FULL COLUMNS FROM `{table_name}` IN `{dbname}`'.format( - table_name=table_name, dbname=dbname), as_dict=True) + 'SHOW FULL COLUMNS FROM `{table_name}` IN `{database}`'.format( + table_name=table_name, database=database), as_dict=True) attributes = cur.fetchall() rename_map = { @@ -147,8 +147,8 @@ def init_from_database(cls, conn, dbname, table_name): attr['computation'] = None if not (attr['numeric'] or attr['string'] or attr['is_blob']): - raise DataJointError('Unsupported field type {field} in `{dbname}`.`{table_name}`'.format( - field=attr['type'], dbname=dbname, table_name=table_name)) + raise DataJointError('Unsupported field type {field} in `{database}`.`{table_name}`'.format( + field=attr['type'], database=database, table_name=table_name)) attr.pop('Extra') # fill out dtype. All floats and non-nullable integers are turned into specific dtypes diff --git a/datajoint/parsing.py b/datajoint/parsing.py new file mode 100644 index 000000000..85e367c96 --- /dev/null +++ b/datajoint/parsing.py @@ -0,0 +1,88 @@ +import re +from . import DataJointError +from .heading import Heading + + +def parse_attribute_definition(line, in_key=False): + """ + Parse attribute definition line in the declaration and returns + an attribute tuple. + + :param line: attribution line + :param in_key: set to True if attribute is in primary key set + :returns: attribute tuple + """ + line = line.strip() + attribute_regexp = re.compile(""" + ^(?P[a-z][a-z\d_]*)\s* # field name + (=\s*(?P\S+(\s+\S+)*?)\s*)? # default value + :\s*(?P\w[^\#]*[^\#\s])\s* # datatype + (\#\s*(?P\S*(\s+\S+)*)\s*)?$ # comment + """, re.X) + m = attribute_regexp.match(line) + if not m: + raise DataJointError('Invalid field declaration "%s"' % line) + attr_info = m.groupdict() + if not attr_info['comment']: + attr_info['comment'] = '' + if not attr_info['default']: + attr_info['default'] = '' + attr_info['nullable'] = attr_info['default'].lower() == 'null' + assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \ + 'BIGINT attributes cannot be nullable in "%s"' % line + + return Heading.AttrTuple( + in_key=in_key, + autoincrement=None, + numeric=None, + string=None, + is_blob=None, + computation=None, + dtype=None, + **attr_info + ) + + +def field_to_sql(field): # TODO move this into Attribute Tuple + """ + Converts an attribute definition tuple into SQL code. + :param field: attribute definition + :rtype : SQL code + """ + mysql_constants = ['CURRENT_TIMESTAMP'] + if field.nullable: + default = 'DEFAULT NULL' + else: + default = 'NOT NULL' + # if some default specified + if field.default: + # enclose value in quotes except special SQL values or already enclosed + quote = field.default.upper() not in mysql_constants and field.default[0] not in '"\'' + default += ' DEFAULT ' + ('"%s"' if quote else "%s") % field.default + if any((c in r'\"' for c in field.comment)): + raise DataJointError('Illegal characters in attribute comment "%s"' % field.comment) + + return '`{name}` {type} {default} COMMENT "{comment}",\n'.format( + name=field.name, type=field.type, default=default, comment=field.comment) + + +def parse_index_definition(line): + """ + Parses index definition. + + :param line: definition line + :return: groupdict with index info + """ + line = line.strip() + index_regexp = re.compile(""" + ^(?PUNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX + \((?P[^\)]+)\)$ # (attr1, attr2) + """, re.I + re.X) + m = index_regexp.match(line) + assert m, 'Invalid index declaration "%s"' % line + index_info = m.groupdict() + attributes = re.split(r'\s*,\s*', index_info['attributes'].strip()) + index_info['attributes'] = attributes + assert len(attributes) == len(set(attributes)), \ + 'Duplicate attributes in index declaration "%s"' % line + return index_info diff --git a/datajoint/relation_class.py b/datajoint/relation_class.py new file mode 100644 index 000000000..7fab544b7 --- /dev/null +++ b/datajoint/relation_class.py @@ -0,0 +1,94 @@ +import abc +import logging +from collections import namedtuple +import pymysql +from .connection import conn +from .base_relation import BaseRelation +from . import DataJointError + + +logger = logging.getLogger(__name__) + + +SharedInfo = namedtuple( + 'SharedInfo', + ('database', 'context', 'connection', 'heading', 'parents', 'children', 'references', 'referenced')) + + +def schema(database, context, connection=None): + """ + Returns a schema decorator that can be used to associate a BaseRelation class to a + specific database with :param name:. Name reference to other tables in the table definition + will be resolved by looking up the corresponding key entry in the passed in context dictionary. + It is most common to set context equal to the return value of call to locals() in the module. + For more details, please refer to the tutorial online. + + :param database: name of the database to associate the decorated class with + :param context: dictionary used to resolve (any) name references within the table definition string + :param connection: connection object to the database server. If ommited, will try to establish connection according to + config values + :return: a decorator function to be used on BaseRelation derivative classes + """ + if connection is None: + connection = conn() + + # if the database does not exist, create it + cur = connection.query("SHOW DATABASES LIKE '{database}'".format(database=database)) + if cur.rowcount == 0: + logger.info("Database `{database}` could not be found. " + "Attempting to create the database.".format(database=database)) + try: + connection.query("CREATE DATABASE `{database}`".format(database=database)) + logger.info('Created database `{database}`.'.format(database=database)) + except pymysql.OperationalError: + raise DataJointError("Database named `{database}` was not defined, and" + "an attempt to create has failed. Check" + " permissions.".format(database=database)) + + def decorator(cls): + cls._shared_info = SharedInfo( + database=database, + context=context, + connection=connection, + heading=None, + parents=[], + children=[], + references=[], + referenced=[] + ) + return cls + + return decorator + + +class RelationClass(BaseRelation): + """ + Abstract class for dedicated table classes. + Subclasses of RelationClass are dedicated interfaces to a single table. + The main purpose of RelationClass is to encapsulated sharedInfo containing the table heading + and dependency information shared by all instances of + """ + + _shared_info = None + + def __init__(self): + if self._shared_info is None: + raise DataJointError('The class must define _shared_info') + + @property + def database(self): + return self._shared_info.database + + @property + def connection(self): + return self._shared_info.connection + + @property + def context(self): + return self._shared_info.context + + @property + def heading(self): + if self._shared_info.heading is None: + self._shared_info.heading = super().heading + return self._shared_info.heading diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index 2a1ff588a..ac330a005 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -9,7 +9,6 @@ from datajoint import DataJointError, config from .blob import unpack import logging -import numpy.lib.recfunctions as rfn logger = logging.getLogger(__name__) @@ -24,26 +23,44 @@ class RelationalOperand(metaclass=abc.ABCMeta): RelationalOperand operators are: restrict, pro, and join. """ - def __init__(self, conn, restrictions=None): - self._conn = conn - self._restrictions = [] if restrictions is None else restrictions - - @property - def connection(self): - return self._conn + _restrictions = None @property def restrictions(self): return self._restrictions - @abc.abstractproperty + @property + def primary_key(self): + return self.heading.primary_key + + # --------- abstract properties ----------- + + @property + @abc.abstractmethod + def connection(self): + """ + :return: a datajoint.Connection object + """ + pass + + @property + @abc.abstractmethod def from_clause(self): + """ + :return: a string containing the FROM clause of the SQL SELECT statement + """ pass - @abc.abstractproperty + @property + @abc.abstractmethod def heading(self): + """ + :return: a valid datajoint.Heading object + """ pass + # --------- relational operators ----------- + def __mul__(self, other): """ relational join @@ -118,6 +135,8 @@ def __sub__(self, restriction): """ return self & Not(restriction) + # ------ data retrieval methods ----------- + def make_select(self, attribute_spec=None): if attribute_spec is None: attribute_spec = self.heading.as_sql @@ -285,7 +304,11 @@ def __init__(self, arg1, arg2): raise DataJointError('Cannot join relations with different database connections') self._arg1 = Subquery(arg1) if arg1.heading.computed else arg1 self._arg2 = Subquery(arg1) if arg2.heading.computed else arg2 - super().__init__(arg1.connection, self._arg1.restrictions + self._arg2.restrictions) + self._restrictions = self._arg1.restrictions + self._arg2.restrictions + + @property + def connection(self): + return self._arg1.connection @property def counter(self): @@ -318,7 +341,6 @@ def __init__(self, arg, group=None, *attributes, **renamed_attributes): self._renamed_attributes.update({d['alias']: d['sql_expression']}) else: self._attributes.append(attribute) - super().__init__(arg.connection) if group: if arg.connection != group.connection: raise DataJointError('Cannot join relations with different database connections') @@ -333,6 +355,10 @@ def __init__(self, arg, group=None, *attributes, **renamed_attributes): self._arg = arg self._restrictions = self._arg.restrictions + @property + def connection(self): + return self._arg.connection + @property def heading(self): return self._arg.heading.project(*self._attributes, **self._renamed_attributes) @@ -356,7 +382,10 @@ class Subquery(RelationalOperand): def __init__(self, arg): self._arg = arg - super().__init__(arg.connection) + + @property + def connection(self): + return self._arg.connection @property def counter(self): diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py new file mode 100644 index 000000000..7de6567c1 --- /dev/null +++ b/datajoint/user_relations.py @@ -0,0 +1,31 @@ +from .relation_class import RelationClass +from .autopopulate import AutoPopulate +from .utils import from_camel_case + + +class ManualRelation(RelationClass): + @property + @classmethod + def table_name(cls): + return from_camel_case(cls.__name__) + + +class LookupRelation(RelationClass): + @property + @classmethod + def table_name(cls): + return '#' + from_camel_case(cls.__name__) + + +class ImportedRelation(RelationClass, AutoPopulate): + @property + @classmethod + def table_name(cls): + return "_" + from_camel_case(cls.__name__) + + +class ComputedRelation(RelationClass, AutoPopulate): + @property + @classmethod + def table_name(cls): + return "__" + from_camel_case(cls.__name__) \ No newline at end of file