diff --git a/datajoint/__init__.py b/datajoint/__init__.py index f44d86939..9e0142e2b 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -1,22 +1,20 @@ import logging import os -__author__ = "Dimitri Yatsenko, Edgar Walker, and Fabian Sinz at Baylor College of Medicine" +__author__ = "Dimitri Yatsenko, Edgar Walker, and Fabian Sinz at Baylor College of Medicine" __version__ = "0.2" __all__ = ['__author__', '__version__', 'Connection', 'Heading', 'Base', 'Not', - 'AutoPopulate', 'TaskQueue', 'conn', 'DataJointError', 'blob'] + 'AutoPopulate', 'conn', 'DataJointError', 'blob'] + -# ------------ define datajoint error before the import hierarchy is flattened ------------ class DataJointError(Exception): """ - Base class for errors specific to DataJoint internal - operation. + Base class for errors specific to DataJoint internal operation. """ pass - # ----------- loads local configuration from file ---------------- from .settings import Config, logger config = Config() @@ -37,10 +35,6 @@ class DataJointError(Exception): # ------------- flatten import hierarchy ------------------------- from .connection import conn, Connection from .base import Base -from .task import TaskQueue from .autopopulate import AutoPopulate from . import blob -from .relational import Not - - - +from .relational import Not \ No newline at end of file diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 30996ac06..8e32ef43d 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -1,54 +1,48 @@ -from .relational import _Relational +from .relational import Relation +from . import DataJointError import pprint import abc #noinspection PyExceptionInherit,PyCallingNonCallable + + class AutoPopulate(metaclass=abc.ABCMeta): """ - Class datajoint.AutoPopulate is a mixin that adds the method populate() to a dj.Relvar class. - Auto-populated relvars must inherit from both datajoint.Relvar and datajoint.AutoPopulate, - must define the property popRel, and must define the callback method makeTuples. + AutoPopulate is a mixin class that adds the method populate() to a Base class. + Auto-populated relations must inherit from both Base and AutoPopulate, + must define the property pop_rel, and must define the callback method make_tuples. """ @abc.abstractproperty - def popRel(self): + def pop_rel(self): """ - Derived classes must implement the read-only property popRel (populate relation) which is the relational - expression (a dj.Relvar object) that defines how keys are generated for the populate call. + Derived classes must implement the read-only property pop_rel (populate relation) which is the relational + expression (a Relation object) that defines how keys are generated for the populate call. """ pass - @abc.abstractmethod - def makeTuples(self, key): + def make_tuples(self, key): """ - Derived classes must implement methods makeTuples that fetches data from parent tables, restricting by + Derived classes must implement method make_tuples that fetches data from parent tables, restricting by the given key, computes dependent attributes, and inserts the new tuples into self. """ pass - - def populate(self, catchErrors=False, reserveJobs=False, restrict=None): + def populate(self, catch_errors=False, reserve_jobs=False, restrict=None): """ - rel.populate() will call rel.makeTuples(key) for every primary key in self.popRel + rel.populate() will call rel.make_tuples(key) for every primary key in self.pop_rel for which there is not already a tuple in rel. 
""" - + if not isinstance(self.pop_rel, Relation): + raise DataJointError('') self.conn.cancel_transaction() - # enumerate unpopulated keys - unpopulated = self.popRel - if ~isinstance(unpopulated, _Relational): - unpopulated = unpopulated() # instantiate - + unpopulated = self.pop_rel - self if not unpopulated.count: - print('Nothing to populate') - else: - unpopulated = unpopulated(*args, **kwargs) # - self # TODO: implement antijoin - - # execute - if catchErrors: - errKeys, errors = [], [] + print('Nothing to populate', flush=True) # TODO: use logging? + if catch_errors: + error_keys, errors = [], [] for key in unpopulated.fetch(): self.conn.start_transaction() n = self(key).count @@ -57,17 +51,16 @@ def populate(self, catchErrors=False, reserveJobs=False, restrict=None): else: print('Populating:') pprint.pprint(key) - try: - self.makeTuples(key) + self.make_tuples(key) except Exception as e: self.conn.cancel_transaction() - if not catchErrors: + if not catch_errors: raise print(e) errors += [e] - errKeys+= [key] + error_keys += [key] else: self.conn.commit_transaction() - if catchErrors: - return errors, errKeys + if catch_errors: + return errors, error_keys \ No newline at end of file diff --git a/datajoint/base.py b/datajoint/base.py index a43c60268..96a34859d 100644 --- a/datajoint/base.py +++ b/datajoint/base.py @@ -1,70 +1,18 @@ import importlib -import re +import abc from types import ModuleType -import numpy as np from enum import Enum -from .utils import from_camel_case from . import DataJointError -from .relational import _Relational -from .heading import Heading +from .table import Table import logging - -# table names have prefixes that designate their roles in the processing chain logger = logging.getLogger(__name__) -# Todo: Shouldn't this go into the settings module? -Role = Enum('Role', 'manual lookup imported computed job') -role_to_prefix = { - Role.manual: '', - Role.lookup: '#', - Role.imported: '_', - Role.computed: '__', - Role.job: '~' -} -prefix_to_role = dict(zip(role_to_prefix.values(), role_to_prefix.keys())) - -mysql_constants = ['CURRENT_TIMESTAMP'] - - -class Base(_Relational): +class Base(Table, metaclass=abc.ABCMeta): """ - Base integrates all data manipulation and data declaration functions. - An instance of the class provides an interface to a single table in the database. - - An instance of the class can be produced in two ways: - - 1. direct instantiation (used mostly for debugging and odd jobs) - - 2. instantiation from a derived class (regular recommended use) - - With direct instantiation, instance parameters must be explicitly specified. - With a derived class, all the instance parameters are taken from the module - of the deriving class. The module must declare the connection object conn. - The name of the deriving class is used as the table's className. - - The table associated with an instance of Base is identified by the className - property, which is a string in CamelCase. The actual table name is obtained - by converting className from CamelCase to underscore_separated_words and - prefixing according to the table's role. - - The table declaration can be specified in the doc string of the inheriting - class, in the DataJoint table declaration syntax. - - Base also implements the methods insert and delete to insert and delete tuples - from the table. It can also be an argument in relational operators: restrict, - join, pro, and aggr. See class :mod:`datajoint.relational`. 
diff --git a/datajoint/base.py b/datajoint/base.py
index a43c60268..96a34859d 100644
--- a/datajoint/base.py
+++ b/datajoint/base.py
@@ -1,70 +1,18 @@
 import importlib
-import re
+import abc
 from types import ModuleType
-import numpy as np
 from enum import Enum
-from .utils import from_camel_case
 from . import DataJointError
-from .relational import _Relational
-from .heading import Heading
+from .table import Table
 import logging
-
-# table names have prefixes that designate their roles in the processing chain
 logger = logging.getLogger(__name__)

-# Todo: Shouldn't this go into the settings module?
-Role = Enum('Role', 'manual lookup imported computed job')
-role_to_prefix = {
-    Role.manual: '',
-    Role.lookup: '#',
-    Role.imported: '_',
-    Role.computed: '__',
-    Role.job: '~'
-}
-prefix_to_role = dict(zip(role_to_prefix.values(), role_to_prefix.keys()))
-
-mysql_constants = ['CURRENT_TIMESTAMP']
-
-
-class Base(_Relational):
+class Base(Table, metaclass=abc.ABCMeta):
     """
-    Base integrates all data manipulation and data declaration functions.
-    An instance of the class provides an interface to a single table in the database.
-
-    An instance of the class can be produced in two ways:
-
-    1. direct instantiation  (used mostly for debugging and odd jobs)
-
-    2. instantiation from a derived class (regular recommended use)
-
-    With direct instantiation, instance parameters must be explicitly specified.
-    With a derived class, all the instance parameters are taken from the module
-    of the deriving class. The module must declare the connection object conn.
-    The name of the deriving class is used as the table's className.
-
-    The table associated with an instance of Base is identified by the className
-    property, which is a string in CamelCase. The actual table name is obtained
-    by converting className from CamelCase to underscore_separated_words and
-    prefixing according to the table's role.
-
-    The table declaration can be specified in the doc string of the inheriting
-    class, in the DataJoint table declaration syntax.
-
-    Base also implements the methods insert and delete to insert and delete tuples
-    from the table. It can also be an argument in relational operators: restrict,
-    join, pro, and aggr. See class :mod:`datajoint.relational`.
-
-    Base instances return their table's heading by looking it up in the connection
-    object. This ensures that Base instances contain the current table definition
-    even after tables are modified after the instance is created.
-
-    :param conn=None: :mod:`datajoint.connection.Connection` object. Only used when Base is
-        instantiated directly.
-    :param dbname=None: Name of the database. Only used when Base is instantiated directly.
-    :param class_name=None: Class name. Only used when Base is instantiated directly.
-    :param table_def=None: Declaration of the table. Only used when Base is instantiated directly.
+    Base is a Table that implements data definition functions.
+    It is an abstract class with the abstract property 'definition'.

     Example for a usage of Base::

@@ -72,9 +20,8 @@ class Base(_Relational):
         import datajoint as dj


         class Subjects(dj.Base):
-            _table_def = '''
+            definition = '''
                 test1.Subjects (manual)  # Basic subject info
-
                 subject_id : int  # unique subject id
                 ---
                 real_id : varchar(40)  # real-world name
@@ -83,248 +30,39 @@ class Subjects(dj.Base):
     """

-    def __init__(self, conn=None, dbname=None, class_name=None, table_def=None):
-        self._use_package = False
-        if self.__class__ is Base:
-            # instantiate without subclassing
-            if not (conn and dbname and class_name):
-                raise DataJointError(
-                    'Missing argument: please specify conn, dbname, and class name.')
-            self.class_name = class_name
-            self.conn = conn
-            self.dbname = dbname
-            self._table_def = table_def
-            # register with a fake module, enclosed in back quotes
+    @abc.abstractproperty
+    def definition(self):
+        """
+        :return: string containing the table declaration using the DataJoint Data Definition Language.
+        The DataJoint DDL is described at: TODO
+        """
+        pass

-            if dbname not in self.conn.db_to_mod:
-                self.conn.bind('`{0}`'.format(dbname), dbname)
-        else:
-            # instantiate a derived class
-            if conn or dbname or class_name or table_def:
-                raise DataJointError(
-                    'With derived classes, constructor arguments are ignored')  # TODO: consider changing this to a warning instead
-            self.class_name = self.__class__.__name__
-            module = self.__module__
-            mod_obj = importlib.import_module(module)
+    def __init__(self):
+        self.class_name = self.__class__.__name__
+        module = self.__module__
+        mod_obj = importlib.import_module(module)
+        use_package = False
+        try:
+            conn = mod_obj.conn
+        except AttributeError:
             try:
-                self.conn = mod_obj.conn
+                pkg_obj = importlib.import_module(mod_obj.__package__)
+                conn = pkg_obj.conn
+                use_package = True
             except AttributeError:
-                try:
-                    pkg_obj = importlib.import_module(mod_obj.__package__)
-                    self.conn = pkg_obj.conn
-                    self._use_package = True
-                except AttributeError:
-                    raise DataJointError(
-                        "Please define object 'conn' in '{}' or in its containing package.".format(self.__module__))
-            try:
-                if (self._use_package):
-                    pkg_name = '.'.join(module.split('.')[:-1])
-                    self.dbname = self.conn.mod_to_db[pkg_name]
-                else:
-                    self.dbname = self.conn.mod_to_db[module]
-            except KeyError:
                 raise DataJointError(
-                    'Module {} is not bound to a database. See datajoint.connection.bind'.format(self.__module__))
-
-            if hasattr(self, '_table_def'):
-                self._table_def = self._table_def
+                    "Please define object 'conn' in '{}' or in its containing package.".format(self.__module__))
+        try:
+            if use_package:
+                pkg_name = '.'.join(module.split('.')[:-1])
+                dbname = conn.mod_to_db[pkg_name]
             else:
-                self._table_def = None
-
-    # todo: do we support records and named tuples for tup?
-    def insert(self, tup, ignore_errors=False, replace=False):
-        """
-        Insert one data tuple.
- - :param tup: Data tuple. Can be an iterable in matching order, a dict with named fields, or an np.void. - :param ignore_errors=False: Ignores errors if True. - :param replace=False: Replaces data tuple if True. - - Example:: - - b = djtest.Subject() - b.insert( dict(subject_id = 7, species="mouse",\\ - real_id = 1007, date_of_birth = "2014-09-01") ) - """ - - if issubclass(type(tup), tuple) or issubclass(type(tup), list): - valueList = ','.join([repr(q) for q in tup]) - fieldList = '`' + '`,`'.join(self.heading.names[0:len(tup)]) + '`' - elif issubclass(type(tup), dict): - valueList = ','.join([repr(tup[q]) - for q in self.heading.names if q in tup]) - fieldList = '`' + \ - '`,`'.join([q for q in self.heading.names if q in tup]) + '`' - elif issubclass(type(tup), np.void): - valueList = ','.join([repr(tup[q]) - for q in self.heading.names if q in tup]) - fieldList = '`' + '`,`'.join(tup.dtype.fields) + '`' - else: - raise DataJointError('Datatype %s cannot be inserted' % type(tup)) - if replace: - sql = 'REPLACE' - elif ignore_errors: - sql = 'INSERT IGNORE' - else: - sql = 'INSERT' - sql += " INTO %s (%s) VALUES (%s)" % (self.full_table_name, - fieldList, valueList) - logger.info(sql) - self.conn.query(sql) - - def drop(self): - """ - Drops the table associated to this object. - """ - # TODO make cascading (github issue #16) - self.conn.query('DROP TABLE %s' % self.full_table_name) - self.conn.clear_dependencies(dbname=self.dbname) - self.conn.load_headings(dbname=self.dbname, force=True) - logger.debug("Dropped table %s" % self.full_table_name) - - @property - def sql(self): - return self.full_table_name + self._whereClause - - @property - def heading(self): - self.declare() - return self.conn.headings[self.dbname][self.table_name] - - @property - def is_declared(self): - """ - :returns: True if table is found in the database - """ - self.conn.load_headings(self.dbname) - return self.class_name in self.conn.table_names[self.dbname] - - @property - def table_name(self): - """ - :return: name of the associated table - """ - self.declare() - return self.conn.table_names[self.dbname][self.class_name] - - @property - def full_table_name(self): - """ - :return: full name of the associated table - """ - return '`%s`.`%s`' % (self.dbname, self.table_name) - - @property - def full_class_name(self): - """ - :return: full class name - """ - return '{}.{}'.format(self.__module__, self.class_name) - - @property - def primary_key(self): - """ - :return: primary key of the table - """ - return self.heading.primary_key - - def declare(self): - """ - Declare the table in database if it doesn't already exist. - - :raises: DataJointError if the table cannot be declared. - """ - if not self.is_declared: - self._declare() - if not self.is_declared: - raise DataJointError( - 'Table could not be declared for %s' % self.class_name) - - """ - Data definition functionalities - """ - - def set_table_comment(self, newComment): - """ - Update the table comment in the table declaration. - - :param newComment: new comment as string - - """ - # TODO: add verification procedure (github issue #24) - self.alter('COMMENT="%s"' % newComment) - - def add_attribute(self, definition, after=None): - """ - Add a new attribute to the table. A full line from the table definition - is passed in as definition. - - The definition can specify where to place the new attribute. Use after=None - to add the attribute as the first attribute or after='attribute' to place it - after an existing attribute. 
-
-        :param definition: table definition
-        :param after=None: After which attribute of the table the new attribute is inserted.
-            If None, the attribute is inserted in front.
-        """
-        position = ' FIRST' if after is None else (
-            ' AFTER %s' % after if after else '')
-        sql = self._field_to_SQL(self._parse_attr_def(definition))
-        self._alter('ADD COLUMN %s%s' % (sql[:-2], position))
-
-    def drop_attribute(self, attr_name):
-        """
-        Drops the attribute attrName from this table.
-
-        :param attr_name: Name of the attribute that is dropped.
-        """
-        self._alter('DROP COLUMN `%s`' % attr_name)
-
-    def alter_attribute(self, attr_name, new_definition):
-        """
-        Alter the definition of the field attr_name in this table using the new definition.
-
-        :param attr_name: field that is redefined
-        :param new_definition: new definition of the field
-        """
-        sql = self._field_to_SQL(self._parse_attr_def(new_definition))
-        self._alter('CHANGE COLUMN `%s` %s' % (attr_name, sql[:-2]))
-
-    def erd(self, subset=None, prog='dot'):
-        """
-        Plot the schema's entity relationship diagram (ERD).
-        The layout programs can be 'dot' (default), 'neato', 'fdp', 'sfdp', 'circo', 'twopi'
-        """
-        if not subset:
-            g = self.graph
-        else:
-            g = self.graph.copy()
-        # todo: make erd work (github issue #7)
-        """
-        g = self.graph
-        else:
-            g = self.graph.copy()
-            for i in g.nodes():
-                if i not in subset:
-                    g.remove_node(i)
-        def tablelist(tier):
-            return [i for i in g if self.tables[i].tier==tier]
+                dbname = conn.mod_to_db[module]
+        except KeyError:
+            raise DataJointError(
+                'Module {} is not bound to a database. See datajoint.connection.bind'.format(self.__module__))
+        super().__init__(conn=conn, dbname=dbname, class_name=self.class_name)

-        pos=nx.graphviz_layout(g,prog=prog,args='')
-        plt.figure(figsize=(8,8))
-        nx.draw_networkx_edges(g, pos, alpha=0.3)
-        nx.draw_networkx_nodes(g, pos, nodelist=tablelist('manual'),
-                               node_color='g', node_size=200, alpha=0.3)
-        nx.draw_networkx_nodes(g, pos, nodelist=tablelist('computed'),
-                               node_color='r', node_size=200, alpha=0.3)
-        nx.draw_networkx_nodes(g, pos, nodelist=tablelist('imported'),
-                               node_color='b', node_size=200, alpha=0.3)
-        nx.draw_networkx_nodes(g, pos, nodelist=tablelist('lookup'),
-                               node_color='gray', node_size=120, alpha=0.3)
-        nx.draw_networkx_labels(g, pos, nodelist = subset, font_weight='bold', font_size=9)
-        nx.draw(g,pos,alpha=0,with_labels=false)
-        plt.show()
-        """

     @classmethod
     def get_module(cls, module_name):
@@ -355,268 +93,3 @@ def get_module(cls, module_name):
             return importlib.import_module(module_name)
         except ImportError:
             return None
-
-    def get_base(self, module_name, class_name):
-        """
-        Loads the base relation from the module. If the base relation is not defined in
-        the module, then construct it using Base constructor.
-
-        :param module_name: module name
-        :param class_name: class name
-        :returns: the base relation
-        """
-        mod_obj = self.get_module(module_name)
-        try:
-            ret = getattr(mod_obj, class_name)()
-        except KeyError:
-            ret = self.__class__(conn=self.conn,
-                                 dbname=self.conn.schemas[module_name],
-                                 class_name=class_name)
-        return ret
-
-    # ////////////////////////////////////////////////////////////
-    # Private Methods
-    # ////////////////////////////////////////////////////////////
-
-    def _field_to_SQL(self, field):
-        """
-        Converts an attribute definition tuple into SQL code.
- - :param field: attribute definition - :rtype : SQL code - """ - if field.isNullable: - default = 'DEFAULT NULL' - else: - default = 'NOT NULL' - # if some default specified - if field.default: - # enclose value in quotes (even numeric), except special SQL values - # or values already enclosed by the user - if field.default.upper() in mysql_constants or field.default[:1] in ["'", '"']: - default = '%s DEFAULT %s' % (default, field.default) - else: - default = '%s DEFAULT "%s"' % (default, field.default) - - # TODO: escape instead! - same goes for Matlab side implementation - assert not any((c in r'\"' for c in field.comment)), \ - 'Illegal characters in attribute comment "%s"' % field.comment - - return '`{name}` {type} {default} COMMENT "{comment}",\n'.format( - name=field.name, type=field.type, default=default, comment=field.comment) - - def _alter(self, alter_statement): - """ - Execute ALTER TABLE statement for this table. The schema - will be reloaded within the connection object. - - :param alter_statement: alter statement - """ - sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement) - self.conn.query(sql) - self.conn.load_headings(self.dbname, force=True) - # TODO: place table definition sync mechanism - - def _declare(self): - """ - Declares the table in the data base if no table in the database matches this object. - """ - if not self._table_def: - raise DataJointError('Table declaration is missing!') - table_info, parents, referenced, fieldDefs, indexDefs = self._parse_declaration() - defined_name = table_info['module'] + '.' + table_info['className'] - expected_name = self.__module__.split('.')[-1] + '.' + self.class_name - if not defined_name == expected_name: - raise DataJointError('Table name {} does not match the declared' - 'name {}'.format(expected_name, defined_name)) - - # compile the CREATE TABLE statement - # TODO: support prefix - table_name = role_to_prefix[ - table_info['tier']] + from_camel_case(self.class_name) - sql = 'CREATE TABLE `%s`.`%s` (\n' % (self.dbname, table_name) - - # add inherited primary key fields - primary_key_fields = set() - non_key_fields = set() - for p in parents: - for key in p.primary_key: - field = p.heading[key] - if field.name not in primary_key_fields: - primary_key_fields.add(field.name) - sql += self._field_to_SQL(field) - else: - logger.debug('Field definition of {} in {} ignored'.format( - field.name, p.full_class_name)) - - # add newly defined primary key fields - for field in (f for f in fieldDefs if f.isKey): - if field.isNullable: - raise DataJointError('Primary key {} cannot be nullable'.format( - field.name)) - if field.name in primary_key_fields: - raise DataJointError('Duplicate declaration of the primary key ' - '{key}. 
Check to make sure that the key ' - 'is not declared already in referenced ' - 'tables'.format(key=field.name)) - primary_key_fields.add(field.name) - sql += self._field_to_SQL(field) - - # add secondary foreign key attributes - for r in referenced: - keys = (x for x in r.heading.attrs.values() if x.isKey) - for field in keys: - if field.name not in primary_key_fields | non_key_fields: - non_key_fields.add(field.name) - sql += self._field_to_SQL(field) - - # add dependent attributes - for field in (f for f in fieldDefs if not f.isKey): - non_key_fields.add(field.name) - sql += self._field_to_SQL(field) - - # add primary key declaration - assert len(primary_key_fields) > 0, 'table must have a primary key' - keys = ', '.join(primary_key_fields) - sql += 'PRIMARY KEY (%s),\n' % keys - - # add foreign key declarations - for ref in parents + referenced: - keys = ', '.join(ref.primary_key) - sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \ - (keys, ref.full_table_name, keys) - - # add secondary index declarations - # gather implicit indexes due to foreign keys first - implicit_indices = [] - for fk_source in parents + referenced: - implicit_indices.append(fk_source.primary_key) - - # for index in indexDefs: - # TODO: finish this up... - - # close the declaration - sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % ( - sql[:-2], table_info['comment']) - - # make sure that the table does not alredy exist - self.conn.load_headings(self.dbname, force=True) - if not self.is_declared: - # execute declaration - logger.debug('\n\n' + sql + '\n\n') - self.conn.query(sql) - self.conn.load_headings(self.dbname, force=True) - - def _parse_declaration(self): - """ - Parse declaration and create new SQL table accordingly. - """ - parents = [] - referenced = [] - index_defs = [] - field_defs = [] - declaration = re.split(r'\s*\n\s*', self._table_def.strip()) - - # remove comment lines - declaration = [x for x in declaration if not x.startswith('#')] - ptrn = """ - ^(?P\w+)\.(?P\w+)\s* # module.className - \(\s*(?P\w+)\s*\)\s* # (tier) - \#\s*(?P.*)$ # comment - """ - p = re.compile(ptrn, re.X) - table_info = p.match(declaration[0]).groupdict() - if table_info['tier'] not in Role.__members__: - raise DataJointError('InvalidTableTier: Invalid tier {tier} for table\ - {module}.{cls}'.format(tier=table_info['tier'], - module=table_info[ - 'module'], - cls=table_info['className'])) - table_info['tier'] = Role[table_info['tier']] # convert into enum - - in_key = True # parse primary keys - field_ptrn = """ - ^[a-z][a-z\d_]*\s* # name - (=\s*\S+(\s+\S+)*\s*)? # optional defaults - :\s*\w.*$ # type, comment - """ - fieldP = re.compile(field_ptrn, re.I + re.X) # ignore case and verbose - - for line in declaration[1:]: - if line.startswith('---'): - in_key = False # start parsing non-PK fields - elif line.startswith('->'): - # foreign key - module_name, class_name = line[2:].strip().split('.') - rel = self.get_base(module_name, class_name) - (parents if in_key else referenced).append(rel) - elif re.match(r'^(unique\s+)?index[^:]*$', line): - index_defs.append(self._parse_index_def(line)) - elif fieldP.match(line): - field_defs.append(self._parse_attr_def(line, in_key)) - else: - raise DataJointError( - 'Invalid table declaration line "%s"' % line) - - return table_info, parents, referenced, field_defs, index_defs - - def _parse_attr_def(self, line, in_key=False): # todo add docu for in_key - """ - Parse attribute definition line in the declaration and returns - an attribute tuple. 
- - :param line: attribution line - :param in_key: - :returns: attribute tuple - """ - line = line.strip() - attr_ptrn = """ - ^(?P[a-z][a-z\d_]*)\s* # field name - (=\s*(?P\S+(\s+\S+)*?)\s*)? # default value - :\s*(?P\w[^\#]*[^\#\s])\s* # datatype - (\#\s*(?P\S*(\s+\S+)*)\s*)?$ # comment - """ - - attrP = re.compile(attr_ptrn, re.I + re.X) - m = attrP.match(line) - assert m, 'Invalid field declaration "%s"' % line - attr_info = m.groupdict() - if not attr_info['comment']: - attr_info['comment'] = '' - if not attr_info['default']: - attr_info['default'] = '' - attr_info['isNullable'] = attr_info['default'].lower() == 'null' - assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['isNullable']), \ - 'BIGINT attributes cannot be nullable in "%s"' % line - - attr_info['isKey'] = in_key - attr_info['isAutoincrement'] = None - attr_info['isNumeric'] = None - attr_info['isString'] = None - attr_info['isBlob'] = None - attr_info['computation'] = None - attr_info['dtype'] = None - - return Heading.AttrTuple(**attr_info) - - def _parse_index_def(self, line): - """ - Parses index definition. - - :param line: definition line - :return: groupdict with index info - """ - line = line.strip() - index_ptrn = """ - ^(?PUNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX - \((?P[^\)]+)\)$ # (attr1, attr2) - """ - indexP = re.compile(index_ptrn, re.I + re.X) - m = indexP.match(line) - assert m, 'Invalid index declaration "%s"' % line - index_info = m.groupdict() - attributes = re.split(r'\s*,\s*', index_info['attributes'].strip()) - index_info['attributes'] = attributes - assert len(attributes) == len(set(attributes)), \ - 'Duplicate attributes in index declaration "%s"' % line - return index_info diff --git a/datajoint/blob.py b/datajoint/blob.py index 82ac9e338..e66ca11ab 100644 --- a/datajoint/blob.py +++ b/datajoint/blob.py @@ -3,29 +3,27 @@ import numpy as np from . 
import DataJointError - mxClassID = collections.OrderedDict( # see http://www.mathworks.com/help/techdoc/apiref/mxclassid.html - mxUNKNOWN_CLASS = None, - mxCELL_CLASS = None, # not yet implemented - mxSTRUCT_CLASS = None, # not yet implemented - mxLOGICAL_CLASS = np.dtype('bool'), - mxCHAR_CLASS = np.dtype('c'), - mxVOID_CLASS = None, - mxDOUBLE_CLASS = np.dtype('float64'), - mxSINGLE_CLASS = np.dtype('float32'), - mxINT8_CLASS = np.dtype('int8'), - mxUINT8_CLASS = np.dtype('uint8'), - mxINT16_CLASS = np.dtype('int16'), - mxUINT16_CLASS = np.dtype('uint16'), - mxINT32_CLASS = np.dtype('int32'), - mxUINT32_CLASS = np.dtype('uint32'), - mxINT64_CLASS = np.dtype('int64'), - mxUINT64_CLASS = np.dtype('uint64'), - mxFUNCTION_CLASS= None - ) + mxUNKNOWN_CLASS=None, + mxCELL_CLASS=None, # not implemented + mxSTRUCT_CLASS=None, # not implemented + mxLOGICAL_CLASS=np.dtype('bool'), + mxCHAR_CLASS=np.dtype('c'), + mxVOID_CLASS=None, + mxDOUBLE_CLASS=np.dtype('float64'), + mxSINGLE_CLASS=np.dtype('float32'), + mxINT8_CLASS=np.dtype('int8'), + mxUINT8_CLASS=np.dtype('uint8'), + mxINT16_CLASS=np.dtype('int16'), + mxUINT16_CLASS=np.dtype('uint16'), + mxINT32_CLASS=np.dtype('int32'), + mxUINT32_CLASS=np.dtype('uint32'), + mxINT64_CLASS=np.dtype('int64'), + mxUINT64_CLASS=np.dtype('uint64'), + mxFUNCTION_CLASS=None) -reverseClassID = {v:i for i,v in enumerate(mxClassID.values())} +reverseClassID = {v: i for i, v in enumerate(mxClassID.values())} def pack(obj): @@ -35,55 +33,53 @@ def pack(obj): if not isinstance(obj, np.ndarray): raise DataJointError("Only numpy arrays can be saved in blobs") - blob = b"mYm\0A" # TODO: extend to process other datatypes besides arrays - blob += np.asarray((len(obj.shape),)+obj.shape,dtype=np.uint64).tostring() + blob = b"mYm\0A" # TODO: extend to process other data types besides arrays + blob += np.asarray((len(obj.shape),) + obj.shape, dtype=np.uint64).tostring() - isComplex = np.iscomplexobj(obj) - if isComplex: - obj, objImag = np.real(obj), np.imag(obj) + is_complex = np.iscomplexobj(obj) + if is_complex: + obj, imaginary = np.real(obj), np.imag(obj) - typeNum = reverseClassID[obj.dtype] - blob+= np.asarray(typeNum, dtype=np.uint32).tostring() - blob+= np.int8(isComplex).tostring() + b'\0\0\0' - blob+= obj.tostring() + type_number = reverseClassID[obj.dtype] + blob += np.asarray(type_number, dtype=np.uint32).tostring() + blob += np.int8(is_complex).tostring() + b'\0\0\0' + blob += obj.tostring() - if isComplex: - blob+= objImag.tostring() + if is_complex: + blob += imaginary.tostring() - if len(blob)>1000: - compressed = b'ZL123\0'+np.asarray(len(blob),dtype=np.uint64).tostring() + zlib.compress(blob) + if len(blob) > 1000: + compressed = b'ZL123\0'+np.asarray(len(blob), dtype=np.uint64).tostring() + zlib.compress(blob) if len(compressed) < len(blob): blob = compressed return blob - def unpack(blob): """ unpack blob into a numpy array """ # decompress if necessary - if blob[0:5]==b'ZL123': - blobLen = np.fromstring(blob[6:14],dtype=np.uint64)[0] + if blob[0:5] == b'ZL123': + blob_length = np.fromstring(blob[6:14], dtype=np.uint64)[0] blob = zlib.decompress(blob[14:]) - assert(len(blob)==blobLen) + assert len(blob) == blob_length - blobType = blob[4] - if blobType!=65: # TODO: also process structure arrays, cell arrays, etc. + blob_type = blob[4] + if blob_type != 65: # TODO: also process structure arrays, cell arrays, etc. 
raise DataJointError('only arrays are currently allowed in blobs') p = 5 - ndims = np.fromstring(blob[p:p+8], dtype=np.uint64) + dimensions = np.fromstring(blob[p:p+8], dtype=np.uint64) p += 8 - arrDims = np.fromstring(blob[p:p+8*ndims], dtype=np.uint64) - p += 8 * ndims - mxType, dtype = [q for q in mxClassID.items()][np.fromstring(blob[p:p+4],dtype=np.uint32)[0]] + array_shape = np.fromstring(blob[p:p+8*dimensions], dtype=np.uint64) + p += 8 * dimensions + mx_type, dtype = [q for q in mxClassID.items()][np.fromstring(blob[p:p+4], dtype=np.uint32)[0]] if dtype is None: - raise DataJointError('Unsupported matlab datatype '+mxType+' in blob') + raise DataJointError('Unsupported MATLAB data type '+mx_type+' in blob') p += 4 - complexity = np.fromstring(blob[p:p+4],dtype=np.uint32)[0] + is_complex = np.fromstring(blob[p:p+4], dtype=np.uint32)[0] p += 4 obj = np.fromstring(blob[p:], dtype=dtype) - if complexity: + if is_complex: obj = obj[:len(obj)/2] + 1j*obj[len(obj)/2:] - - return obj.reshape(arrDims) + return obj.reshape(array_shape) \ No newline at end of file diff --git a/datajoint/connection.py b/datajoint/connection.py index 2653950a3..ca6551cea 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -2,13 +2,9 @@ import re from .utils import to_camel_case from . import DataJointError -import os from .heading import Heading -from .base import prefix_to_role +from .settings import prefix_to_role import logging -import networkx as nx -from networkx import pygraphviz_layout -import matplotlib.pyplot as plt from .erd import DBConnGraph from . import config @@ -26,7 +22,7 @@ def conn_container(): """ _connObj = None # persistent connection object used by dj.conn() - def conn(host=None, user=None, passwd=None, initFun=None, reset=False): + def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False): """ Manage a persistent connection object. This is one of several ways to configure and access a datajoint connection. @@ -39,31 +35,29 @@ def conn(host=None, user=None, passwd=None, initFun=None, reset=False): host = host if host is not None else config['database.host'] user = user if user is not None else config['database.user'] passwd = passwd if passwd is not None else config['database.password'] - initFun = initFun if initFun is not None else config['connection.init_function'] - _connObj = Connection(host, user, passwd, initFun) + init_fun = init_fun if init_fun is not None else config['connection.init_function'] + _connObj = Connection(host, user, passwd, init_fun) return _connObj - return conn - + return conn_function # The function conn is used by others to obtain the package wide persistent connection object conn = conn_container() - class Connection: """ A dj.Connection object manages a connection to a database server. 
It also catalogues modules, schemas, tables, and their dependencies (foreign keys)
    """

-    def __init__(self, host, user, passwd, initFun=None):
+    def __init__(self, host, user, passwd, init_fun=None):
         if ':' in host:
             host, port = host.split(':')
             port = int(port)
         else:
             port = config['database.port']
         self.conn_info = dict(host=host, port=port, user=user, passwd=passwd)
-        self._conn = pymysql.connect(init_command=initFun, **self.conn_info)
+        self._conn = pymysql.connect(init_command=init_fun, **self.conn_info)
         # TODO Do something if connection cannot be established
         if self.is_connected:
             print("Connected", user + '@' + host + ':' + str(port))
@@ -140,11 +134,11 @@ def bind(self, module, dbname):
             self.mod_to_db[module] = dbname
         elif count == 0:
             # Database doesn't exist, attempt to create
-            logger.warning("Database `{dbname}` could not be found. "
-                           "Attempting to create the database.".format(dbname=dbname))
+            logger.info("Database `{dbname}` could not be found. "
+                        "Attempting to create the database.".format(dbname=dbname))
             try:
-                cur = self.query("CREATE DATABASE `{dbname}`".format(dbname=dbname))
-                logger.warning('Created database `{dbname}`.'.format(dbname=dbname))
+                self.query("CREATE DATABASE `{dbname}`".format(dbname=dbname))
+                logger.info('Created database `{dbname}`.'.format(dbname=dbname))
                 self.db_to_mod[dbname] = module
                 self.mod_to_db[module] = dbname
             except pymysql.OperationalError:
@@ -227,7 +221,8 @@ def load_dependencies(self, dbname):  # TODO: Perhaps consider making this "priv
             self.referenced[full_table_name] = []
             for m in re.finditer(ptrn, table_def["Create Table"], re.X):  # iterate through foreign key statements
-                assert m.group('attr1') == m.group('attr2'), 'Foreign keys must link identically named attributes'
+                assert m.group('attr1') == m.group('attr2'), \
+                    'Foreign keys must link identically named attributes'
                 attrs = m.group('attr1')
                 attrs = re.split(r',\s+', re.sub(r'`(.*?)`', r'\1', attrs))  # remove ` around attrs and split into list
                 pk = self.headings[dbname][tabName].primary_key
@@ -270,14 +265,14 @@ def parents_of(self, child_table):

     def children_of(self, parent_table):
         """
-        Returnis a list of tables for which parentTable is a parent (primary foreign key)
+        Returns a list of tables for which parent_table is a parent (primary foreign key)
         """
         return [child_table for child_table, parents in self.parents.items() if parent_table in parents]

     def referenced_by(self, referencing_table):
         """
         Returns a list of tables that are referenced by non-primary foreign key
-        by the referencingTable.
+        by the referencing_table.
         """
         return self.referenced.get(referencing_table, []).copy()
diff --git a/datajoint/declare.py b/datajoint/declare.py
new file mode 100644
index 000000000..af09334ae
--- /dev/null
+++ b/datajoint/declare.py
@@ -0,0 +1,239 @@
+import re
+import logging
+from .heading import Heading
+from . import DataJointError
+from .utils import from_camel_case
+from .settings import Role, role_to_prefix
+
+mysql_constants = ['CURRENT_TIMESTAMP']
+
+logger = logging.getLogger(__name__)
+
+
+def declare(conn, dbname, definition, class_name):
+    """
+    Declares the table in the database if no table in the database matches this object.
+    """
+    table_info, parents, referenced, field_definitions, index_definitions = parse_declaration(definition)
+    defined_name = table_info['module'] + '.' + table_info['className']
+    if not defined_name == class_name:
+        raise DataJointError('Table name {} does not match the declared '
+                             'name {}'.format(class_name, defined_name))
+
+    # compile the CREATE TABLE statement
+    table_name = role_to_prefix[table_info['tier']] + from_camel_case(table_info['className'])
+    sql = 'CREATE TABLE `%s`.`%s` (\n' % (dbname, table_name)
+
+    # add inherited primary key fields
+    primary_key_fields = set()
+    non_key_fields = set()
+    for p in parents:
+        for key in p.primary_key:
+            field = p.heading[key]
+            if field.name not in primary_key_fields:
+                primary_key_fields.add(field.name)
+                sql += field_to_sql(field)
+            else:
+                logger.debug('Field definition of {} in {} ignored'.format(
+                    field.name, p.full_class_name))
+
+    # add newly defined primary key fields
+    for field in (f for f in field_definitions if f.in_key):
+        if field.nullable:
+            raise DataJointError('Primary key {} cannot be nullable'.format(
+                field.name))
+        if field.name in primary_key_fields:
+            raise DataJointError('Duplicate declaration of the primary key '
+                                 '{key}. Check to make sure that the key '
+                                 'is not declared already in referenced '
+                                 'tables'.format(key=field.name))
+        primary_key_fields.add(field.name)
+        sql += field_to_sql(field)
+
+    # add secondary foreign key attributes
+    for r in referenced:
+        keys = (x for x in r.heading.attributes.values() if x.in_key)
+        for field in keys:
+            if field.name not in primary_key_fields | non_key_fields:
+                non_key_fields.add(field.name)
+                sql += field_to_sql(field)
+
+    # add dependent attributes
+    for field in (f for f in field_definitions if not f.in_key):
+        non_key_fields.add(field.name)
+        sql += field_to_sql(field)
+
+    # add primary key declaration
+    assert len(primary_key_fields) > 0, 'table must have a primary key'
+    keys = ', '.join(primary_key_fields)
+    sql += 'PRIMARY KEY (%s),\n' % keys
+
+    # add foreign key declarations
+    for ref in parents + referenced:
+        keys = ', '.join(ref.primary_key)
+        sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \
+               (keys, ref.full_table_name, keys)
+
+    # add secondary index declarations
+    # gather implicit indexes due to foreign keys first
+    implicit_indices = []
+    for fk_source in parents + referenced:
+        implicit_indices.append(fk_source.primary_key)
+
+    # for index in index_definitions:
+    # TODO: finish this up...
+
+    # close the declaration
+    sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % (
+        sql[:-2], table_info['comment'])
+
+    # make sure that the table does not already exist
+    conn.load_headings(dbname, force=True)
+    if table_info['className'] not in conn.table_names[dbname]:
+        # execute declaration
+        logger.debug('\n\n' + sql + '\n\n')
+        conn.query(sql)
+        conn.load_headings(dbname, force=True)
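# ---------------------------------------------------------------------------
# Editor's illustration (not part of the patch): the kind of definition string
# that declare()/parse_declaration() above consume. Names are hypothetical.
definition = """
test1.Subjects (manual)  # Basic subject info
subject_id : int         # unique subject id
---
real_id : varchar(40)    # real-world name
species = "mouse" : enum('mouse', 'monkey', 'human')  # species
"""
# Under the stated assumptions, parse_declaration(definition) would return
#   table_info -> {'module': 'test1', 'className': 'Subjects',
#                  'tier': Role.manual, 'comment': 'Basic subject info'}
#   field_defs -> subject_id (in_key=True); real_id, species (in_key=False)
# and declare() would compile and run a
#   CREATE TABLE `<dbname>`.`subjects` (...) statement.
# ---------------------------------------------------------------------------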
+def parse_declaration(definition, get_base=None):
+    """
+    Parse declaration and create new SQL table accordingly.
+    """
+    parents = []
+    referenced = []
+    index_defs = []
+    field_defs = []
+    declaration = re.split(r'\s*\n\s*', definition.strip())
+
+    # remove comment lines
+    declaration = [x for x in declaration if not x.startswith('#')]
+    ptrn = """
+    ^(?P<module>\w+)\.(?P<className>\w+)\s*  # module.className
+    \(\s*(?P<tier>\w+)\s*\)\s*               # (tier)
+    \#\s*(?P<comment>.*)$                    # comment
+    """
+    p = re.compile(ptrn, re.X)
+    table_info = p.match(declaration[0]).groupdict()
+    if table_info['tier'] not in Role.__members__:
+        raise DataJointError('InvalidTableTier: Invalid tier {tier} for table '
+                             '{module}.{cls}'.format(tier=table_info['tier'],
+                                                     module=table_info['module'],
+                                                     cls=table_info['className']))
+    table_info['tier'] = Role[table_info['tier']]  # convert into enum
+
+    in_key = True  # parse primary keys
+    field_ptrn = """
+    ^[a-z][a-z\d_]*\s*          # name
+    (=\s*\S+(\s+\S+)*\s*)?      # optional defaults
+    :\s*\w.*$                   # type, comment
+    """
+    fieldP = re.compile(field_ptrn, re.I + re.X)  # ignore case and verbose
+
+    for line in declaration[1:]:
+        if line.startswith('---'):
+            in_key = False  # start parsing non-PK fields
+        elif line.startswith('->'):
+            # foreign key; get_base must resolve 'module.Class' references
+            module_name, class_name = line[2:].strip().split('.')
+            rel = get_base(module_name, class_name)  # TODO: Base.get_base was removed; a resolver must be passed in
+            (parents if in_key else referenced).append(rel)
+        elif re.match(r'^(unique\s+)?index[^:]*$', line):
+            index_defs.append(parse_index_definition(line))
+        elif fieldP.match(line):
+            field_defs.append(parse_attribute_definition(line, in_key))
+        else:
+            raise DataJointError(
+                'Invalid table declaration line "%s"' % line)
+
+    return table_info, parents, referenced, field_defs, index_defs
+
+
+def field_to_sql(field):
+    """
+    Converts an attribute definition tuple into SQL code.
+    :param field: attribute definition
+    :rtype : SQL code
+    """
+    if field.nullable:
+        default = 'DEFAULT NULL'
+    else:
+        default = 'NOT NULL'
+    # if some default specified
+    if field.default:
+        # enclose value in quotes (even numeric), except special SQL values
+        # or values already enclosed by the user
+        if field.default.upper() in mysql_constants or field.default[:1] in ["'", '"']:
+            default = '%s DEFAULT %s' % (default, field.default)
+        else:
+            default = '%s DEFAULT "%s"' % (default, field.default)
+
+    # TODO: escape instead! - same goes for Matlab side implementation
+    assert not any((c in r'\"' for c in field.comment)), \
+        'Illegal characters in attribute comment "%s"' % field.comment
+
+    return '`{name}` {type} {default} COMMENT "{comment}",\n'.format(
+        name=field.name, type=field.type, default=default, comment=field.comment)
+
+
+def parse_attribute_definition(line, in_key=False):
+    """
+    Parse attribute definition line in the declaration and returns
+    an attribute tuple.
+    :param line: attribute definition line
+    :param in_key: set to True if the attribute is part of the primary key
+    :returns: attribute tuple
+    """
+    line = line.strip()
+    attr_ptrn = """
+    ^(?P<name>[a-z][a-z\d_]*)\s*             # field name
+    (=\s*(?P<default>\S+(\s+\S+)*?)\s*)?     # default value
+    :\s*(?P<type>\w[^\#]*[^\#\s])\s*         # datatype
+    (\#\s*(?P<comment>\S*(\s+\S+)*)\s*)?$    # comment
+    """
+
+    attrP = re.compile(attr_ptrn, re.I + re.X)
+    m = attrP.match(line)
+    assert m, 'Invalid field declaration "%s"' % line
+    attr_info = m.groupdict()
+    if not attr_info['comment']:
+        attr_info['comment'] = ''
+    if not attr_info['default']:
+        attr_info['default'] = ''
+    attr_info['nullable'] = attr_info['default'].lower() == 'null'
+    assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \
+        'BIGINT attributes cannot be nullable in "%s"' % line
+
+    # use the field names of the new Heading.AttrTuple
+    attr_info['in_key'] = in_key
+    attr_info['autoincrement'] = None
+    attr_info['numeric'] = None
+    attr_info['string'] = None
+    attr_info['is_blob'] = None
+    attr_info['computation'] = None
+    attr_info['dtype'] = None
+
+    return Heading.AttrTuple(**attr_info)
+
+
+def parse_index_definition(line):
+    """
+    Parses index definition.
+
+    :param line: definition line
+    :return: groupdict with index info
+    """
+    line = line.strip()
+    index_ptrn = """
+    ^(?P<unique>UNIQUE)?\s*INDEX\s*      # [UNIQUE] INDEX
+    \((?P<attributes>[^\)]+)\)$          # (attr1, attr2)
+    """
+    indexP = re.compile(index_ptrn, re.I + re.X)
+    m = indexP.match(line)
+    assert m, 'Invalid index declaration "%s"' % line
+    index_info = m.groupdict()
+    attributes = re.split(r'\s*,\s*', index_info['attributes'].strip())
+    index_info['attributes'] = attributes
+    assert len(attributes) == len(set(attributes)), \
+        'Duplicate attributes in index declaration "%s"' % line
+    return index_info
diff --git a/datajoint/erd.py b/datajoint/erd.py
index a7e1bc024..af6107390 100644
--- a/datajoint/erd.py
+++ b/datajoint/erd.py
@@ -14,6 +14,7 @@
 logger = logging.getLogger(__name__)

+
 class RelGraph(DiGraph):
     """
     Represents relations between tables and databases
diff --git a/datajoint/fetch.py b/datajoint/fetch.py
deleted file mode 100644
index e658bf1e6..000000000
--- a/datajoint/fetch.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Wed Aug 20 22:05:29 2014
-
-@author: dimitri
-"""
-
-from .blob import unpack
-import numpy as np
-import logging
-
-logger = logging.getLogger(__name__)
-
-
-class Fetch:
-    """
-    Fetch defines callable objects that fetch data from a relation
-    """
-
-    def __init__(self, relational):
-        self.rel = relational
-        self._orderBy = None
-        self._offset = 0
-        self._limit = None
-
-    def limit(self, n, offset=0):
-        self._limit = n
-        self._offset = offset
-        return self
-
-    def order_by(self, *attrs):
-        self._orderBy = attrs
-        return self
-
-    def __call__(self, *attrs, **renames):
-        """
-        fetch relation from database into an np.array
-        """
-        cur = self._cursor(*attrs, **renames)
-        heading = self.rel.pro(*attrs, **renames).heading
-        ret = np.array(list(cur), dtype=heading.asdtype)
-        # unpack blobs
-        for i in range(len(ret)):
-            for f in heading.blobs:
-                ret[i][f] = unpack(ret[i][f])
-        return ret
-
-    def _cursor(self, *attrs, **renames):
-        rel = self.rel.pro(*attrs, **renames)
-        sql = 'SELECT ' + rel.heading.asSQL + ' FROM ' + rel.sql
-        # add ORDER BY clause
-        if self._orderBy:
-            sql += ' ORDER BY ' + ', '.join(self._orderBy)
-
-        # add LIMIT clause
-        if self._limit:
-            sql += ' LIMIT %d' % self._limit
-            if self._offset:
-                sql += ' OFFSET %d ' % self._offset
-
-        logger.debug(sql)
-        return self.rel.conn.query(sql)
diff --git a/datajoint/heading.py b/datajoint/heading.py
index ca7495db0..18e75df62 100644
--- a/datajoint/heading.py
+++ b/datajoint/heading.py
@@ -1,9 +1,3 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Mon Aug 4 01:29:51 2014 - -@author: dimitri, eywalker -""" import re from collections import OrderedDict, namedtuple @@ -13,219 +7,214 @@ class Heading: """ - local class for relationals' headings. + local class for relations' headings. """ + AttrTuple = namedtuple('AttrTuple', + ('name', 'type', 'in_key', 'nullable', 'default', 'comment', 'autoincrement', + 'numeric', 'string', 'is_blob', 'computation', 'dtype')) - AttrTuple = namedtuple('AttrTuple',('name','type','isKey','isNullable', - 'default','comment','isAutoincrement','isNumeric','isString','isBlob', - 'computation','dtype')) - - def __init__(self, attrs): - # Input: attrs -list of dicts with attribute descriptions - self.attrs = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attrs]) + def __init__(self, attributes): + # Input: attributes -list of dicts with attribute descriptions + self.attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes]) @property def names(self): - return [k for k in self.attrs] + return [k for k in self.attributes] @property def primary_key(self): - return [k for k,v in self.attrs.items() if v.isKey] + return [k for k, v in self.attributes.items() if v.in_key] @property def dependent_fields(self): - return [k for k,v in self.attrs.items() if not v.isKey] + return [k for k, v in self.attributes.items() if not v.in_key] @property def blobs(self): - return [k for k,v in self.attrs.items() if v.isBlob] + return [k for k, v in self.attributes.items() if v.is_blob] @property def non_blobs(self): - return [k for k,v in self.attrs.items() if not v.isBlob] + return [k for k, v in self.attributes.items() if not v.is_blob] @property def computed(self): - return [k for k,v in self.attrs.items() if v.computation] + return [k for k, v in self.attributes.items() if v.computation] - def __getitem__(self,name): + def __getitem__(self, name): """shortcut to the attribute""" - return self.attrs[name] + return self.attributes[name] def __repr__(self): - autoIncrementString = {False:'', True:' auto_increment'} + autoincrement_string = {False: '', True: ' auto_increment'} return '\n'.join(['%-20s : %-28s # %s' % ( - k if v.default is None else '%s="%s"'%(k,v.default), - '%s%s' % (v.type, autoIncrementString[v.isAutoincrement]), + k if v.default is None else '%s="%s"' % (k, v.default), + '%s%s' % (v.type, autoincrement_string[v.autoincrement]), v.comment) - for k,v in self.attrs.items()]) + for k, v in self.attributes.items()]) @property - def asdtype(self): + def as_dtype(self): """ represent the heading as a numpy dtype """ return np.dtype(dict( names=self.names, - formats=[v.dtype for k,v in self.attrs.items()])) + formats=[v.dtype for k, v in self.attributes.items()])) @property - def asSQL(self): - """represent heading as SQL field list""" - attrNames = ['`%s`' % name if self.attrs[name].computation is None else '%s as `%s`' % (self.attrs[name].computation, name) - for name in self.names] - return ','.join(attrNames) - - # Use heading as a dictionary like object + def as_sql(self): + """ + represent heading as SQL field list + """ + return ','.join(['`%s`' % name + if self.attributes[name].computation is None + else '%s as `%s`' % (self.attributes[name].computation, name) + for name in self.names]) def keys(self): - return self.attrs.keys() + return self.attributes.keys() def values(self): - return self.attrs.values() + return self.attributes.values() def items(self): - return self.attrs.items() - + return self.attributes.items() @classmethod - def init_from_database(cls, 
conn, dbname, tabname): + def init_from_database(cls, conn, dbname, table_name): """ initialize heading from a database table """ cur = conn.query( - 'SHOW FULL COLUMNS FROM `{tabname}` IN `{dbname}`'.format( - tabname=tabname, dbname=dbname),asDict=True) - attrs = cur.fetchall() + 'SHOW FULL COLUMNS FROM `{table_name}` IN `{dbname}`'.format( + table_name=table_name, dbname=dbname), asDict=True) + attributes = cur.fetchall() rename_map = { - 'Field' : 'name', - 'Type' : 'type', - 'Null' : 'isNullable', + 'Field': 'name', + 'Type': 'type', + 'Null': 'nullable', 'Default': 'default', - 'Key' : 'isKey', + 'Key': 'in_key', 'Comment': 'comment'} - dropFields = ('Privileges', 'Collation') # unncessary + fields_to_drop = ('Privileges', 'Collation') # rename and drop attributes - attrs = [{rename_map[k] if k in rename_map else k: v - for k, v in x.items() if k not in dropFields} - for x in attrs] - numTypes ={ - ('float',False):np.float32, - ('float',True):np.float32, - ('double',False):np.float32, - ('double',True):np.float64, - ('tinyint',False):np.int8, - ('tinyint',True):np.uint8, - ('smallint',False):np.int16, - ('smallint',True):np.uint16, - ('mediumint',False):np.int32, - ('mediumint',True):np.uint32, - ('int',False):np.int32, - ('int',True):np.uint32, - ('bigint',False):np.int64, - ('bigint',True):np.uint64 - # TODO: include decimal and numeric datatypes + attributes = [{rename_map[k] if k in rename_map else k: v + for k, v in x.items() if k not in fields_to_drop} + for x in attributes] + + numeric_types = { + ('float', False): np.float32, + ('float', True): np.float32, + ('double', False): np.float32, + ('double', True): np.float64, + ('tinyint', False): np.int8, + ('tinyint', True): np.uint8, + ('smallint', False): np.int16, + ('smallint', True): np.uint16, + ('mediumint', False): np.int32, + ('mediumint', True): np.uint32, + ('int', False): np.int32, + ('int', True): np.uint32, + ('bigint', False): np.int64, + ('bigint', True): np.uint64 + # TODO: include types DECIMAL and NUMERIC } - # additional attribute properties - for attr in attrs: - attr['isNullable'] = (attr['isNullable'] == 'YES') - attr['isKey'] = (attr['isKey'] == 'PRI') - attr['isAutoincrement'] = bool(re.search(r'auto_increment', attr['Extra'], flags=re.IGNORECASE)) - attr['isNumeric'] = bool(re.match(r'(tiny|small|medium|big)?int|decimal|double|float', attr['type'])) - attr['isString'] = bool(re.match(r'(var)?char|enum|date|time|timestamp', attr['type'])) - attr['isBlob'] = bool(re.match(r'(tiny|medium|long)?blob', attr['type'])) + for attr in attributes: + attr['nullable'] = (attr['nullable'] == 'YES') + attr['in_key'] = (attr['in_key'] == 'PRI') + attr['autoincrement'] = bool(re.search(r'auto_increment', attr['Extra'], flags=re.IGNORECASE)) + attr['numeric'] = bool(re.match(r'(tiny|small|medium|big)?int|decimal|double|float', attr['type'])) + attr['string'] = bool(re.match(r'(var)?char|enum|date|time|timestamp', attr['type'])) + attr['is_blob'] = bool(re.match(r'(tiny|medium|long)?blob', attr['type'])) # strip field lengths off integer types attr['type'] = re.sub(r'((tiny|small|medium|big)?int)\(\d+\)', r'\1', attr['type']) attr['computation'] = None - if not (attr['isNumeric'] or attr['isString'] or attr['isBlob']): - raise DataJointError('Unsupported field type {field} in `{dbname}`.`{tabname}`'.format( - field=attr['type'], dbname=dbname, tabname=tabname)) + if not (attr['numeric'] or attr['string'] or attr['is_blob']): + raise DataJointError('Unsupported field type {field} in `{dbname}`.`{table_name}`'.format( + 
field=attr['type'], dbname=dbname, table_name=table_name)) attr.pop('Extra') # fill out the dtype. All floats and non-nullable integers are turned into specific dtypes attr['dtype'] = object - if attr['isNumeric'] : - isInteger = bool(re.match(r'(tiny|small|medium|big)?int',attr['type'])) - isFloat = bool(re.match(r'(double|float)',attr['type'])) - if isInteger and not attr['isNullable'] or isFloat: - isUnsigned = bool(re.match('\sunsigned', attr['type'], flags=re.IGNORECASE)) + if attr['numeric']: + is_integer = bool(re.match(r'(tiny|small|medium|big)?int', attr['type'])) + is_float = bool(re.match(r'(double|float)', attr['type'])) + if is_integer and not attr['nullable'] or is_float: + is_unsigned = bool(re.match('\sunsigned', attr['type'], flags=re.IGNORECASE)) t = attr['type'] - t = re.sub(r'\(.*\)','',t) # remove parentheses - t = re.sub(r' unsigned$','',t) # remove unsigned - assert (t,isUnsigned) in numTypes, 'dtype not found for type %s' % t - attr['dtype'] = numTypes[(t,isUnsigned)] - - return cls(attrs) + t = re.sub(r'\(.*\)', '', t) # remove parentheses + t = re.sub(r' unsigned$', '', t) # remove unsigned + assert (t, is_unsigned) in numeric_types, 'dtype not found for type %s' % t + attr['dtype'] = numeric_types[(t, is_unsigned)] + return cls(attributes) - def pro(self, *attrList, **renameDict): + def pro(self, *attribute_list, **rename_dict): """ derive a new heading by selecting, renaming, or computing attributes. In relational algebra these operators are known as project, rename, and expand. The primary key is always included. """ - # include all if '*' is in attrSet, always include primary key - attrSet = set(self.names) if '*' in attrList \ - else set(attrList).union(self.primary_key) + # include all if '*' is in attribute_set, always include primary key + attribute_set = set(self.names) if '*' in attribute_list \ + else set(attribute_list).union(self.primary_key) # report missing attributes - missing = attrSet.difference(self.names) + missing = attribute_set.difference(self.names) if missing: raise DataJointError('Attributes %s are not found' % str(missing)) - # make attrList a list of dicts for initializing a Heading - attrList = [v._asdict() for k,v in self.attrs.items() - if k in attrSet and k not in renameDict.values()] + # make attribute_list a list of dicts for initializing a Heading + attribute_list = [v._asdict() for k, v in self.attributes.items() + if k in attribute_set and k not in rename_dict.values()] # add renamed and computed attributes - for newName, computation in renameDict.items(): + for new_name, computation in rename_dict.items(): if computation in self.names: # renamed attribute - newAttr = self.attrs[computation]._asdict() - newAttr['name'] = newName - newAttr['computation'] = '`' + computation + '`' + new_attr = self.attributes[computation]._asdict() + new_attr['name'] = new_name + new_attr['computation'] = '`' + computation + '`' else: # computed attribute - newAttr = dict( - name = newName, - type = 'computed', - isKey = False, - isNullable = False, - default = None, - comment = 'computed attribute', - isAutoincrement = False, - isNumeric = None, - isString = None, - isBlob = False, - computation = computation, - dtype = object) - attrList.append(newAttr) - - return Heading(attrList) - + new_attr = dict( + name=new_name, + type='computed', + in_key=False, + nullable=False, + default=None, + comment='computed attribute', + autoincrement=False, + numeric=None, + string=None, + is_blob=False, + computation=computation, + dtype=object) + 
attribute_list.append(new_attr) + + return Heading(attribute_list) def join(self, other): """ join two headings """ - assert isinstance(other,Heading) - attrList = [v._asdict() for v in self.attrs.values()] + assert isinstance(other, Heading) + attribute_list = [v._asdict() for v in self.attributes.values()] for name in other.names: if name not in self.names: - attrList.append(other.attrs[name]._asdict()) - return Heading(attrList) + attribute_list.append(other.attributes[name]._asdict()) + return Heading(attribute_list) - - def resolveComputations(self): + def resolve_computations(self): """ Remove computations. To be done after computations have been resolved in a subquery """ - return Heading( [dict(v._asdict(),computation=None) for v in self.attrs.values()] ) - + return Heading([dict(v._asdict(), computation=None) for v in self.attributes.values()]) \ No newline at end of file diff --git a/datajoint/relational.py b/datajoint/relational.py index f02edbc30..f72bd1369 100644 --- a/datajoint/relational.py +++ b/datajoint/relational.py @@ -1,32 +1,28 @@ -# -*- coding: utf-8 -*- """ -Created on Thu Aug 7 17:00:02 2014 - -@author: dimitri, eywalker +classes for relational algebra """ + import numpy as np import abc from copy import copy from datajoint import DataJointError -from .fetch import Fetch +from .blob import unpack +import logging + +logger = logging.getLogger(__name__) + -class _Relational(metaclass=abc.ABCMeta): +class Relation(metaclass=abc.ABCMeta): """ - Relational implements relational operators. - Relational objects reference other relational objects linked by operators. - The leaves of this tree of objects are base relvars. - When fetching data from the database, this tree of objects is compiled into - and SQL expression. - It is a mixin class that provides relational operators, iteration, and - fetch capability. - Relational operators are: restrict, pro, aggr, and join. + Relation implements relational algebra and fetch methods. + Relation objects reference other relation objects linked by operators. + The leaves of this tree of objects are base relations. + When fetching data from the database, this tree of objects is compiled into an SQL expression. + It is a mixin class that provides relational operators, iteration, and fetch capability. + Relation operators are: restrict, pro, and join. """ _restrictions = [] - _limit = None - _offset = 0 - _order_by = [] - - #### abstract properties that subclasses must define ##### + @abc.abstractproperty def sql(self): return NotImplemented @@ -34,139 +30,217 @@ def sql(self): @abc.abstractproperty def heading(self): return NotImplemented + + @property + def restrictions(self): + return self._restrictions - ###### Relational algebra ############## def __mul__(self, other): - "relational join" - return Join(self,other) + """ + relational join + """ + return Join(self, other) - def pro(self, *arg, _sub=None, **kwarg): - "relational projection abd aggregation" - return Projection(self, _sub=_sub, *arg, **kwarg) + def __mod__(self, attribute_list): + """ + relational projection operator. + :param attribute_list: list of attribute specifications. + The attribute specifications are strings in following forms: + 'name' - specific attribute + 'name->new_name' - rename attribute. The old attribute is kept only if specifically included. + 'sql_expression->new_name' - extend attribute, i.e. a new computed attribute. 
+        """
+        return self.project(*attribute_list)
+
+    def project(self, *selection, **aliases):
+        """
+        Relational projection operator.
+        :param selection: attribute names to be included in the result.
+        :param aliases: attributes to be renamed, as new_name=old_name keyword arguments.
+        :return: a new relation with the selected fields
+        Primary key attributes are always selected and cannot be excluded.
+        Therefore obj.project() produces a relation with only the primary key attributes.
+        If selection includes the string '*', all attributes are selected.
+        Each attribute can only be used once in selection or aliases. Therefore, the projected
+        relation cannot have more attributes than the original relation.
+        """
+        group = selection[0] if selection and isinstance(selection[0], Relation) else None
+        if group is not None:
+            selection = selection[1:]
+        return self.aggregate(group, *selection, **aliases)
+
+    def aggregate(self, group, *selection, **aliases):
+        """
+        Relational aggregation operator
+        :param group: a relation to aggregate over, or None for a plain projection
+        :param selection: attribute names to be included in the result
+        :param aliases: renamed and computed attributes
+        :return: a new relation with the specified heading
+        """
+        if group is not None and not isinstance(group, Relation):
+            raise DataJointError('The second argument of aggregate must be a relation')
+        # TODO: convert the string arrow notation ('old->new') into keyword aliases
+        # TODO: apply grouping by `group` in the generated SQL
+        return Projection(self, *selection, **aliases)

     def __iand__(self, restriction):
-        "in-place relational restriction or semijoin"
+        """
+        in-place relational restriction or semijoin
+        """
         if self._restrictions is None:
             self._restrictions = []
         self._restrictions.append(restriction)
         return self

     def __and__(self, restriction):
-        "relational restriction or semijoin"
+        """
+        relational restriction or semijoin
+        """
         if self._restrictions is None:
             self._restrictions = []
-        ret = copy(self) # todo: why not deepcopy it?
+        ret = copy(self)  # todo: why not deepcopy it?
         ret._restrictions = list(ret._restrictions) # copy restriction
         ret &= restriction
         return ret

     def __isub__(self, restriction):
-        "in-place inverted restriction aka antijoin"
+        """
+        in-place inverted restriction aka antijoin
+        """
         self &= Not(restriction)
         return self

     def __sub__(self, restriction):
-        "inverted restriction aka antijoin"
+        """
+        inverted restriction aka antijoin
+        """
         return self & Not(restriction)
-
-    ###### Fetching the data ##############

     @property
     def count(self):
-        sql = 'SELECT count(*) FROM ' + self.sql + self._whereClause
+        sql = 'SELECT count(*) FROM ' + self.sql + self._where_clause
         cur = self.conn.query(sql)
-        return cur.fetchone()[0] #todo: should we assert that this only returns one result?
+        return cur.fetchone()[0]

-    @property
-    def fetch(self):
-        return Fetch(self)
+    def __call__(self, offset=0, limit=None, order_by=None, descending=False):
+        """
+        fetches the relation from the database table into an np.array and unpacks blob attributes.
+        :param offset: the number of tuples to skip in the returned result
+        :param limit: the maximum number of tuples to return
+        :param order_by: the list of attributes to order the results. No ordering should be assumed if order_by=None.
+        :param descending: if True, the results are returned in descending order
+        :return: the contents of the relation in the form of a structured numpy.array
+        """
+        cur = self.cursor(offset, limit, order_by, descending)
+        ret = np.array(list(cur), dtype=self.heading.asdtype)
+        for f in self.heading.blobs:
+            for i in range(len(ret)):
+                ret[i][f] = unpack(ret[i][f])
+        return ret
+
+    def cursor(self, offset=0, limit=None, order_by=None, descending=False):
+        """
+        :param offset: the number of tuples to skip in the returned result
+        :param limit: the maximum number of tuples to return
+        :param order_by: the list of attributes to order the results. No ordering should be assumed if order_by=None.
+        :param descending: if True, the results are returned in descending order
+        :return: cursor to the query
+        """
+        if offset and limit is None:
+            raise DataJointError('offset requires a limit to be specified')
+        sql = 'SELECT ' + self.heading.as_sql + ' FROM ' + self.sql + self._where_clause
+        if order_by is not None:
+            sql += ' ORDER BY ' + ', '.join(order_by)
+            if descending:
+                sql += ' DESC'
+        if limit is not None:
+            sql += ' LIMIT %d' % limit
+        if offset:
+            sql += ' OFFSET %d' % offset
+        logger.debug(sql)
+        return self.conn.query(sql)

     def __repr__(self):
         header = self.heading.non_blobs
         limit = 13
         width = 12
         template = '%%-%d.%ds' % (width, width)
-        ret_val = ' '.join([template % column for column in header]) + '\n'
-        ret_val += ' '.join(['+' + '-' * (width - 2) + '+' for column in header]) + '\n'
-
-        tuples = self.fetch.limit(limit)(*header)
+        repr_string = ' '.join([template % column for column in header]) + '\n'
+        repr_string += ' '.join(['+' + '-' * (width - 2) + '+' for _ in header]) + '\n'
+        tuples = self.project(*header)(limit=limit)
         for tup in tuples:
-            ret_val += ' '.join([template % column for column in tup]) + '\n'
-        cnt = self.count
-        if cnt > limit:
-            ret_val += '...\n'
-        ret_val += '%d tuples\n' % self.count
-        return ret_val
+            repr_string += ' '.join([template % column for column in tup]) + '\n'
+        if self.count > limit:
+            repr_string += '...\n'
+        repr_string += '%d tuples\n' % self.count
+        return repr_string

-    ######## iterator ###############
     def __iter__(self):
         """
         iterator yields primary key tuples
         """
-        cur, h = self.fetch._cursor()
-        dtype = h.asdtype
+        projection = self.project()
+        cur = projection.cursor()
         q = cur.fetchone()
         while q:
-            yield np.array([q,],dtype=dtype) #todo: why convert that into an array?
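+            # wrap the row in a structured array scalar so attributes remain accessible by name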
+            yield np.array([q, ], dtype=projection.heading.asdtype)
             q = cur.fetchone()

-    @property
-    def _whereClause(self):
-        "make there WHERE clause based on the current restriction"
-
+    @property
+    def _where_clause(self):
+        """
+        make the WHERE clause based on the current restrictions
+        """
         if not self._restrictions:
             return ''
-
-        def makeCondition(arg):
-            if isinstance(arg,dict):
-                conds = ['`%s`=%s'%(k,repr(v)) for k,v in arg.items()]
-            elif isinstance(arg,np.void):
-                conds = ['`%s`=%s'%(k, arg[k]) for k in arg.dtype.fields]
+
+        def make_condition(arg):
+            if isinstance(arg, dict):
+                conditions = ['`%s`=%s' % (k, repr(v)) for k, v in arg.items()]
+            elif isinstance(arg, np.void):
+                conditions = ['`%s`=%s' % (k, arg[k]) for k in arg.dtype.fields]
             else:
                 raise DataJointError('invalid restriction type')
-            return ' AND '.join(conds)
+            return ' AND '.join(conditions)

-        condStr = []
+        condition_string = []
         for r in self._restrictions:
-            negate = isinstance(r,Not)
+            negate = isinstance(r, Not)
             if negate:
                 r = r._restriction
-            if isinstance(r,dict) or isinstance(r,np.void):
-                r = makeCondition(r)
-            elif isinstance(r,np.ndarray) or isinstance(r,list):
-                r = '('+') OR ('.join([makeCondition(q) for q in r])+')'
-            elif isinstance(r,_Relational):
-                commonAttrs = ','.join([q for q in self.heading.names if r.heading.names])
-                r = '(%s) in (SELECT %s FROM %s)' % (commonAttrs, commonAttrs, r.sql)
+            if isinstance(r, dict) or isinstance(r, np.void):
+                r = make_condition(r)
+            elif isinstance(r, np.ndarray) or isinstance(r, list):
+                r = '('+') OR ('.join([make_condition(q) for q in r])+')'
+            elif isinstance(r, Relation):
+                common_attributes = ','.join([q for q in self.heading.names if q in r.heading.names])
+                r = '(%s) in (SELECT %s FROM %s)' % (common_attributes, common_attributes, r.sql)

-            assert isinstance(r,str), 'condition must be converted into a string'
+            assert isinstance(r, str), 'condition must be converted into a string'
             r = '('+r+')'
             if negate:
-                r = 'NOT '+r;
-            condStr.append(r)
+                r = 'NOT '+r
+            condition_string.append(r)

-        return ' WHERE ' + ' AND '.join(condStr)
+        return ' WHERE ' + ' AND '.join(condition_string)


 class Not:
-    "inverse of a restriction"
-    def __init__(self,restriction):
+    """
+    inverse of a restriction
+    """
+    def __init__(self, restriction):
         self._restriction = restriction


-class Join(_Relational):
+class Join(Relation):
+    alias_counter = 0

-    aliasCounter = 0
-
-    def __init__(self,rel1,rel2):
-        if not isinstance(rel2,_Relational):
+    def __init__(self, rel1, rel2):
+        if not isinstance(rel2, Relation):
-            raise DataJointError('relvars can only be joined with other relvars')
+            raise DataJointError('relations can only be joined with other relations')
         if rel1.conn is not rel2.conn:
             raise DataJointError('Cannot join relations with different database connections')
         self.conn = rel1.conn
-        self._rel1 = rel1;
-        self._rel2 = rel2;
+        self._rel1 = rel1
+        self._rel2 = rel2

     @property
     def heading(self):
@@ -174,23 +248,22 @@ def heading(self):

     @property
     def sql(self):
-        Join.aliasCounter += 1
-        return '%s NATURAL JOIN %s as `j%x`' % (self._rel1.sql, self._rel2.sql, Join.aliasCounter)
+        Join.alias_counter += 1
+        return '%s NATURAL JOIN %s as `j%x`' % (self._rel1.sql, self._rel2.sql, Join.alias_counter)


-class Projection(_Relational):
+class Projection(Relation):
+    alias_counter = 0

-    aliasCounter = 0
+    def __init__(self, relation, *attributes, **renames):
+        """
+        See Relation.project()
+        """
+        self.conn = relation.conn
+        self._rel = relation
+        self._selection = attributes
+        self._renames = renames

-    def __init__(self, rel, *arg, _sub, **kwarg):
-        if _sub and isinstance(_sub, _Relational):
-            raise DataJointError('Relational join must receive two relations')
-        self.conn = rel.conn
-        self._rel = rel
-        self._sub = _sub
-        self._selection = arg
-        self._renames = kwarg
-
     @property
     def sql(self):
         return self._rel.sql
@@ -200,19 +273,18 @@ def heading(self):
         return self._rel.heading.pro(*self._selection, **self._renames)


-class Subquery(_Relational):
-
-    aliasCounter = 0;
+class Subquery(Relation):
+    alias_counter = 0

     def __init__(self, rel):
-        self.conn = rel.conn;
-        self._rel = rel;
+        self.conn = rel.conn
+        self._rel = rel

     @property
     def sql(self):
-        self.aliasCounter = self.aliasCounter + 1;
-        return '(SELECT ' + self._rel.heading.asSQL + ' FROM ' + self._rel.sql + ') as `s%x`' % self.aliasCounter
+        self.alias_counter += 1
+        return '(SELECT ' + self._rel.heading.as_sql + ' FROM ' + self._rel.sql + ') as `s%x`' % self.alias_counter

     @property
     def heading(self):
-        return self._rel.heading.resolveComputations()
+        return self._rel.heading.resolve_computations()
\ No newline at end of file
diff --git a/datajoint/settings.py b/datajoint/settings.py
index 9ad4d97c6..22afcadb6 100644
--- a/datajoint/settings.py
+++ b/datajoint/settings.py
@@ -8,10 +8,22 @@
 __author__ = 'eywalker'

 import logging
 import collections
+from enum import Enum

 validators = collections.defaultdict(lambda: lambda value: True)

+Role = Enum('Role', 'manual lookup imported computed job')
+role_to_prefix = {
+    Role.manual: '',
+    Role.lookup: '#',
+    Role.imported: '_',
+    Role.computed: '__',
+    Role.job: '~'
+}
+prefix_to_role = dict(zip(role_to_prefix.values(), role_to_prefix.keys()))
+
+
 default = {
     'database.host': 'localhost',
     'database.password': 'datajoint',
@@ -24,6 +36,7 @@
     'config.varname': 'DJ_LOCAL_CONF'
 }

+
 class Config(collections.MutableMapping):
     """
     Stores datajoint settings. Behaves like a dictionary, but applies validator functions
@@ -31,7 +44,6 @@ class Config(collections.MutableMapping):

     The default parameters are stored in datajoint.settings.default . If a local config file
     exists, the settings specified in this file override the default settings.
-
     """

     def __init__(self, *args, **kwargs):
@@ -90,4 +102,4 @@ def load(self, filename):
 #############################################################################
 logger = logging.getLogger()
-logger.setLevel(logging.DEBUG) #set package wide logger level TODO:make this respond to environmental variable
+logger.setLevel(logging.DEBUG)  # set package-wide logger level. TODO: make this respond to an environment variable
\ No newline at end of file
diff --git a/datajoint/table.py b/datajoint/table.py
new file mode 100644
index 000000000..d1c6c3bb9
--- /dev/null
+++ b/datajoint/table.py
@@ -0,0 +1,228 @@
+import numpy as np
+import logging
+from . import DataJointError
+from .relational import Relation
+from .declare import declare
+
+logger = logging.getLogger(__name__)
+
+
+class Table(Relation):
+    """
+    A Table object is a relation associated with a table.
+    A Table object provides insert and delete methods.
+    Table objects are only used internally and for debugging.
+    The table must already exist in the schema for its Table object to work.
+
+    The table associated with an instance of Base is identified by its class_name
+    property, which is a string in CamelCase. The actual table name is obtained
+    by converting class_name from CamelCase to underscore_separated_words and
+    prefixing according to the table's role.
+
+    Base instances obtain their table's heading by looking it up in the connection object.
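+    The connection reloads headings from the database via its load_headings() method.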
+    This ensures that Base instances contain the current table definition
+    even after tables are modified after the instance is created.
+    """
+
+    def __init__(self, conn=None, dbname=None, class_name=None, definition=None):
+        self._use_package = False
+        self.class_name = class_name
+        self.conn = conn
+        self.dbname = dbname
+        self.conn.load_headings(self.dbname)
+
+        if dbname not in self.conn.db_to_mod:
+            # register with a fake module, enclosed in back quotes
+            self.conn.bind('`{0}`'.format(dbname), dbname)
+
+        # TODO: delay the loading until first use (move out of __init__)
+        self.conn.load_headings()
+        if self.class_name not in self.conn.table_names[self.dbname]:
+            if definition is None:
+                raise DataJointError('The table is not declared')
+            else:
+                declare(conn, definition, class_name)
+
+    @property
+    def sql(self):
+        return self.full_table_name
+
+    @property
+    def heading(self):
+        return self.conn.headings[self.dbname][self.table_name]
+
+    @property
+    def full_table_name(self):
+        """
+        :return: full name of the associated table
+        """
+        return '`%s`.`%s`' % (self.dbname, self.table_name)
+
+    @property
+    def table_name(self):
+        """
+        :return: name of the associated table
+        """
+        return self.conn.table_names[self.dbname][self.class_name]
+
+    @property
+    def full_class_name(self):
+        """
+        :return: full class name
+        """
+        return '{}.{}'.format(self.__module__, self.class_name)
+
+    @property
+    def primary_key(self):
+        """
+        :return: primary key of the table
+        """
+        return self.heading.primary_key
+
+    def insert(self, tup, ignore_errors=False, replace=False):  # TODO: in progress (issue #8)
+        """
+        Insert one data tuple.
+
+        :param tup: Data tuple. Can be an iterable in matching order, a dict with named fields, or an np.void.
+        :param ignore_errors=False: Ignores errors if True.
+        :param replace=False: Replaces data tuple if True.
+
+        Example::
+
+            b = djtest.Subject()
+            b.insert(dict(subject_id = 7, species="mouse",\\
+                   real_id = 1007, date_of_birth = "2014-09-01"))
+        """
+        # todo: do we support records and named tuples for tup?
+
+        if issubclass(type(tup), tuple) or issubclass(type(tup), list):
+            value_list = ','.join([repr(q) for q in tup])
+            attribute_list = '`'+'`,`'.join(self.heading.names[0:len(tup)]) + '`'
+        elif issubclass(type(tup), dict):
+            value_list = ','.join([repr(tup[q])
+                                   for q in self.heading.names if q in tup])
+            attribute_list = '`' + '`,`'.join([q for q in self.heading.names if q in tup]) + '`'
+        elif issubclass(type(tup), np.void):
+            # iterate the dtype's field names so values and attribute names stay aligned
+            value_list = ','.join([repr(tup[q]) for q in tup.dtype.names])
+            attribute_list = '`' + '`,`'.join(tup.dtype.names) + '`'
+        else:
+            raise DataJointError('Datatype %s cannot be inserted' % type(tup))
+        if replace:
+            sql = 'REPLACE'
+        elif ignore_errors:
+            sql = 'INSERT IGNORE'
+        else:
+            sql = 'INSERT'
+        sql += " INTO %s (%s) VALUES (%s)" % (self.full_table_name,
+                                              attribute_list, value_list)
+        logger.info(sql)
+        self.conn.query(sql)
+
+    def delete(self):  # TODO: (issues #14 and #15)
+        pass
+
+    def drop(self):
+        """
+        Drops the table associated to this object.
+        """
+        # TODO: make cascading (issue #16)
+        self.conn.query('DROP TABLE %s' % self.full_table_name)
+        self.conn.clear_dependencies(dbname=self.dbname)
+        self.conn.load_headings(dbname=self.dbname, force=True)
+        logger.debug("Dropped table %s" % self.full_table_name)
+
+    def set_table_comment(self, comment):
+        """
+        Update the table comment in the table definition.
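+        The comment is written to the table definition on the database server via ALTER TABLE.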
+
+        :param comment: new comment as string
+        """
+        # TODO: add verification procedure (github issue #24)
+        self._alter('COMMENT="%s"' % comment)
+
+    def add_attribute(self, definition, after=None):
+        """
+        Add a new attribute to the table. A full line from the table definition
+        is passed in as definition.
+
+        The definition can specify where to place the new attribute. Use after=None
+        to add the attribute as the first attribute or after='attribute' to place it
+        after an existing attribute.
+
+        :param definition: definition of the new attribute, one line of the table definition
+        :param after=None: After which attribute of the table the new attribute is inserted.
+                           If None, the attribute is inserted in front.
+        """
+        position = ' FIRST' if after is None else (
+            ' AFTER %s' % after if after else '')
+        # TODO: parse_attribute_definition and field_to_sql are not yet available in this module
+        sql = self.field_to_sql(parse_attribute_definition(definition))
+        self._alter('ADD COLUMN %s%s' % (sql[:-2], position))
+
+    def drop_attribute(self, attr_name):
+        """
+        Drops the attribute attr_name from this table.
+
+        :param attr_name: Name of the attribute that is dropped.
+        """
+        self._alter('DROP COLUMN `%s`' % attr_name)
+
+    def alter_attribute(self, attr_name, new_definition):
+        """
+        Alter the definition of the field attr_name in this table using the new definition.
+
+        :param attr_name: field that is redefined
+        :param new_definition: new definition of the field
+        """
+        sql = self.field_to_sql(parse_attribute_definition(new_definition))
+        self._alter('CHANGE COLUMN `%s` %s' % (attr_name, sql[:-2]))
+
+    def erd(self, subset=None, prog='dot'):
+        """
+        Plot the schema's entity relationship diagram (ERD).
+        The layout programs can be 'dot' (default), 'neato', 'fdp', 'sfdp', 'circo', 'twopi'
+        """
+        if not subset:
+            g = self.graph
+        else:
+            g = self.graph.copy()
+        # todo: make erd work (github issue #7)
+        """
+        if not subset:
+            g = self.graph
+        else:
+            g = self.graph.copy()
+            for i in g.nodes():
+                if i not in subset:
+                    g.remove_node(i)
+
+        def tablelist(tier):
+            return [i for i in g if self.tables[i].tier==tier]
+
+        pos=nx.graphviz_layout(g,prog=prog,args='')
+        plt.figure(figsize=(8,8))
+        nx.draw_networkx_edges(g, pos, alpha=0.3)
+        nx.draw_networkx_nodes(g, pos, nodelist=tablelist('manual'),
+                               node_color='g', node_size=200, alpha=0.3)
+        nx.draw_networkx_nodes(g, pos, nodelist=tablelist('computed'),
+                               node_color='r', node_size=200, alpha=0.3)
+        nx.draw_networkx_nodes(g, pos, nodelist=tablelist('imported'),
+                               node_color='b', node_size=200, alpha=0.3)
+        nx.draw_networkx_nodes(g, pos, nodelist=tablelist('lookup'),
+                               node_color='gray', node_size=120, alpha=0.3)
+        nx.draw_networkx_labels(g, pos, nodelist=subset, font_weight='bold', font_size=9)
+        nx.draw(g, pos, alpha=0, with_labels=False)
+        plt.show()
+        """
+
+    def _alter(self, alter_statement):
+        """
+        Execute ALTER TABLE statement for this table. The schema
+        will be reloaded within the connection object.
+
+        :param alter_statement: alter statement
+        """
+        sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement)
+        self.conn.query(sql)
+        self.conn.load_headings(self.dbname, force=True)
+        # TODO: place table definition sync mechanism
\ No newline at end of file
diff --git a/datajoint/task.py b/datajoint/task.py
deleted file mode 100644
index 156c46281..000000000
--- a/datajoint/task.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import queue
-import threading
-
-
-def _ping():
-    print("The task thread is running")
-
-
-class TaskQueue:
-    """
-    Executes tasks in a single parallel thread in FIFO sequence.
- Example: - queue = TaskQueue() - queue.submit(func1, arg1, arg2, arg3) - queue.submit(func2) - queue.quit() # wait until the last task is done and stop thread - - Datajoint applications may use a task queue for delayed inserts. - """ - def __init__(self): - self.queue = queue.Queue() - self.thread = threading.Thread(target=self._worker) - self.thread.daemon = True - self.thread.start() - - def empty(self): - return self.queue.empty() - - def submit(self, func=_ping, *args): - """Submit task for execution""" - self.queue.put((func, args)) - - def quit(self, timeout=3.0): - """Wait until all tasks finish""" - self.queue.put('quit') - self.thread.join(timeout) - if self.thread.isAlive(): - raise Exception('Task thread is still executing. Try quitting again.') - - def _worker(self): - while True: - msg = self.queue.get() - if msg=='quit': - self.queue.task_done() - break - fun, args = msg - try: - fun(*args) - except Exception as e: - print("Exception in the task thread:") - print(e) - self.queue.task_done() - diff --git a/datajoint/utils.py b/datajoint/utils.py index 9d1bf85de..47aacdeeb 100644 --- a/datajoint/utils.py +++ b/datajoint/utils.py @@ -1,13 +1,7 @@ import re -# package-wide settings that control execution - -# setup root logger from . import DataJointError - - - def to_camel_case(s): """ Convert names with under score (_) separation @@ -24,7 +18,7 @@ def to_upper(matchobj): def from_camel_case(s): """ - Conver names in camel case into underscore + Convert names in camel case into underscore (_) separated names Example: @@ -37,7 +31,8 @@ def from_camel_case(s): raise DataJointError('String cannot begin with a digit') if not re.match(r'^[a-zA-Z0-9]*$', s): raise DataJointError('String can only contain alphanumeric characters') + def conv(matchobj): return ('_' if matchobj.groups()[0] else '') + matchobj.group(0).lower() - return re.sub(r'(\B[A-Z])|(\b[A-Z])', conv, s) + return re.sub(r'(\B[A-Z])|(\b[A-Z])', conv, s) \ No newline at end of file diff --git a/demos/demo1.py b/demos/demo1.py index 18e54a4bf..689905730 100644 --- a/demos/demo1.py +++ b/demos/demo1.py @@ -6,6 +6,7 @@ """ import datajoint as dj +import os print("Welcome to the database 'demo1'") @@ -13,11 +14,9 @@ conn.bind(module=__name__, dbname='dj_test') # bind this module to the database - class Subject(dj.Base): - _table_def = """ + definition = """ demo1.Subject (manual) # Basic subject info - subject_id : int # internal subject id --- real_id : varchar(40) # real-world name @@ -28,41 +27,49 @@ class Subject(dj.Base): animal_notes="" : varchar(4096) # strain, genetic manipulations, etc """ -class Exp2(dj.Base): - pass class Experiment(dj.Base): - _table_def = """ + definition = """ demo1.Experiment (manual) # Basic subject info - -> demo1.Subject experiment : smallint # experiment number for this subject --- + experiment_folder : varchar(255) # folder path experiment_date : date # experiment start date experiment_notes="" : varchar(4096) experiment_ts=CURRENT_TIMESTAMP : timestamp # automatic timestamp """ -class TwoPhotonSession(dj.Base): - _table_def = """ - demo1.TwoPhotonSession (manual) # a two-photon imaging session - +class Session(dj.Base): + definition = """ + demo1.Session (manual) # a two-photon imaging session -> demo1.Experiment - tp_session : tinyint # two-photon session within this experiment - ---- + session_id : tinyint # two-photon session within this experiment + ----------- setup : tinyint # experimental setup - lens : tinyint # lens e.g.: 10x, 20x. 
25x, 60x
+        lens : tinyint # lens e.g.: 10x, 20x, 25x, 60x
     """

-class EphysSetup(dj.Base):
-    _table_def = """
-    demo1.EphysSetup (manual)  # Ephys setup
-    setup_id : tinyint # unique seutp id
+

+class Scan(dj.Base):
+    definition = """
+    demo1.Scan (manual)    # a two-photon scan
+    -> demo1.Session
+    scan_id : tinyint # scan number within the session
+    ----
+    depth : float # depth from surface
+    wavelength : smallint # (nm) laser wavelength
+    mwatts : numeric(4,1) # (mW) laser power to brain
     """

-class EphysExperiment(dj.Base):
-    _table_def = """
-    demo1.EphysExperiment (manual) # Ephys experiment
-    -> demo1.Subject
-    -> demo1.EphysSetup
-    """
\ No newline at end of file
+
+class ScanInfo(dj.Base, dj.AutoPopulate):
+    definition = None
+
+    @property
+    def pop_rel(self):
+        return Session()
+
+    def make_tuples(self, key):
+        info = (Experiment()*Session()*Scan() & key).project('experiment_folder')()
+        filename = os.path.join(info['experiment_folder'][0], 'scan_%03')
+
+
diff --git a/demos/rundemo1.py b/demos/rundemo1.py
index ece931611..88a8f4502 100644
--- a/demos/rundemo1.py
+++ b/demos/rundemo1.py
@@ -8,35 +8,35 @@
 import demo1

 s = demo1.Subject()
-# insert as dict
+e = demo1.Experiment()
+
 s.insert(dict(subject_id=1,
-              real_id='George',
+              real_id="George",
               species="monkey",
               date_of_birth="2011-01-01",
               sex="M",
               caretaker="Arthur",
               animal_notes="this is a test"))
+
 s.insert(dict(subject_id=2,
               real_id='1373',
               date_of_birth="2014-08-01",
               caretaker="Joe"))

-# insert as tuple. Attributes must be in the same order as in table declaration
-s.insert((3,'Dennis','monkey','2012-09-01'))
-
-# TODO: insert as ndarray
+s.insert((3, 'Dennis', 'monkey', '2012-09-01'))
+s.insert((12430, 'C0430', 'mouse', '2012-09-01', 'M'))
+s.insert((12431, 'C0431', 'mouse', '2012-09-01', 'F'))

 print('inserted keys into Subject:')
 for key in s:
     print(key)

+e.insert(dict(subject_id=1,
+              experiment=1,
+              experiment_date="2014-08-28",
+              experiment_notes="my first experiment"))

-#
-e = demo1.Experiment()
-e.insert(dict(subject_id=1,experiment=1,experiment_date="2014-08-28",experiment_notes="my first experiment"))
-e.insert(dict(subject_id=1,experiment=2,experiment_date="2014-08-28",experiment_notes="my second experiment"))
-
-
-# drop the tables
-#s.drop
-#e.drop
+e.insert(dict(subject_id=1,
+              experiment=2,
+              experiment_date="2014-08-28",
+              experiment_notes="my second experiment"))
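
A minimal usage sketch (not part of the diff) of how the renamed interface is intended to compose, assuming the demo1 schema above, a bound connection, and the Relation methods as revised in relational.py:

    import demo1

    subjects = demo1.Subject().project('real_id', 'species')  # primary key attributes are always included
    print(subjects.count)      # issues SELECT count(*) against the projected relation
    data = subjects(limit=5)   # calling a relation fetches it into a structured numpy array
    print(data['real_id'])     # blob attributes, if any, are unpacked automatically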