diff --git a/datajoint/__init__.py b/datajoint/__init__.py
index f44d86939..9e0142e2b 100644
--- a/datajoint/__init__.py
+++ b/datajoint/__init__.py
@@ -1,22 +1,20 @@
import logging
import os
-__author__ = "Dimitri Yatsenko, Edgar Walker, and Fabian Sinz at Baylor College of Medicine"
+__author__ = "Dimitri Yatsenko, Edgar Walker, and Fabian Sinz at Baylor College of Medicine"
__version__ = "0.2"
__all__ = ['__author__', '__version__',
'Connection', 'Heading', 'Base', 'Not',
- 'AutoPopulate', 'TaskQueue', 'conn', 'DataJointError', 'blob']
+ 'AutoPopulate', 'conn', 'DataJointError', 'blob']
+
-# ------------ define datajoint error before the import hierarchy is flattened ------------
class DataJointError(Exception):
"""
- Base class for errors specific to DataJoint internal
- operation.
+ Base class for errors specific to DataJoint internal operation.
"""
pass
-
# ----------- loads local configuration from file ----------------
from .settings import Config, logger
config = Config()
@@ -37,10 +35,6 @@ class DataJointError(Exception):
# ------------- flatten import hierarchy -------------------------
from .connection import conn, Connection
from .base import Base
-from .task import TaskQueue
from .autopopulate import AutoPopulate
from . import blob
-from .relational import Not
-
-
-
+from .relational import Not
\ No newline at end of file
diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py
index 30996ac06..8e32ef43d 100644
--- a/datajoint/autopopulate.py
+++ b/datajoint/autopopulate.py
@@ -1,54 +1,48 @@
-from .relational import _Relational
+from .relational import Relation
+from . import DataJointError
import pprint
import abc
#noinspection PyExceptionInherit,PyCallingNonCallable
+
+
class AutoPopulate(metaclass=abc.ABCMeta):
"""
- Class datajoint.AutoPopulate is a mixin that adds the method populate() to a dj.Relvar class.
- Auto-populated relvars must inherit from both datajoint.Relvar and datajoint.AutoPopulate,
- must define the property popRel, and must define the callback method makeTuples.
+ AutoPopulate is a mixin class that adds the method populate() to a Base class.
+ Auto-populated relations must inherit from both Base and AutoPopulate,
+ must define the property pop_rel, and must define the callback method make_tuples.
"""
@abc.abstractproperty
- def popRel(self):
+ def pop_rel(self):
"""
- Derived classes must implement the read-only property popRel (populate relation) which is the relational
- expression (a dj.Relvar object) that defines how keys are generated for the populate call.
+ Derived classes must implement the read-only property pop_rel (populate relation) which is the relational
+ expression (a Relation object) that defines how keys are generated for the populate call.
"""
pass
-
@abc.abstractmethod
- def makeTuples(self, key):
+ def make_tuples(self, key):
"""
- Derived classes must implement methods makeTuples that fetches data from parent tables, restricting by
+ Derived classes must implement method make_tuples that fetches data from parent tables, restricting by
the given key, computes dependent attributes, and inserts the new tuples into self.
"""
pass
-
- def populate(self, catchErrors=False, reserveJobs=False, restrict=None):
+ def populate(self, catch_errors=False, reserve_jobs=False, restrict=None):
"""
- rel.populate() will call rel.makeTuples(key) for every primary key in self.popRel
+ rel.populate() will call rel.make_tuples(key) for every primary key in self.pop_rel
for which there is not already a tuple in rel.
"""
-
+        if not isinstance(self.pop_rel, Relation):
+            raise DataJointError('Property pop_rel must be a Relation object')
self.conn.cancel_transaction()
- # enumerate unpopulated keys
- unpopulated = self.popRel
- if ~isinstance(unpopulated, _Relational):
- unpopulated = unpopulated() # instantiate
-
+ unpopulated = self.pop_rel - self
if not unpopulated.count:
- print('Nothing to populate')
- else:
- unpopulated = unpopulated(*args, **kwargs) # - self # TODO: implement antijoin
-
- # execute
- if catchErrors:
- errKeys, errors = [], []
+ print('Nothing to populate', flush=True) # TODO: use logging?
+ if catch_errors:
+ error_keys, errors = [], []
for key in unpopulated.fetch():
self.conn.start_transaction()
n = self(key).count
@@ -57,17 +51,16 @@ def populate(self, catchErrors=False, reserveJobs=False, restrict=None):
else:
print('Populating:')
pprint.pprint(key)
-
try:
- self.makeTuples(key)
+ self.make_tuples(key)
except Exception as e:
self.conn.cancel_transaction()
- if not catchErrors:
+ if not catch_errors:
raise
print(e)
errors += [e]
- errKeys+= [key]
+ error_keys += [key]
else:
self.conn.commit_transaction()
- if catchErrors:
- return errors, errKeys
+ if catch_errors:
+ return errors, error_keys
\ No newline at end of file
diff --git a/datajoint/base.py b/datajoint/base.py
index a43c60268..96a34859d 100644
--- a/datajoint/base.py
+++ b/datajoint/base.py
@@ -1,70 +1,18 @@
import importlib
-import re
+import abc
from types import ModuleType
-import numpy as np
from enum import Enum
-from .utils import from_camel_case
from . import DataJointError
-from .relational import _Relational
-from .heading import Heading
+from .table import Table
import logging
-
-# table names have prefixes that designate their roles in the processing chain
logger = logging.getLogger(__name__)
-# Todo: Shouldn't this go into the settings module?
-Role = Enum('Role', 'manual lookup imported computed job')
-role_to_prefix = {
- Role.manual: '',
- Role.lookup: '#',
- Role.imported: '_',
- Role.computed: '__',
- Role.job: '~'
-}
-prefix_to_role = dict(zip(role_to_prefix.values(), role_to_prefix.keys()))
-
-mysql_constants = ['CURRENT_TIMESTAMP']
-
-
-class Base(_Relational):
+class Base(Table, metaclass=abc.ABCMeta):
"""
- Base integrates all data manipulation and data declaration functions.
- An instance of the class provides an interface to a single table in the database.
-
- An instance of the class can be produced in two ways:
-
- 1. direct instantiation (used mostly for debugging and odd jobs)
-
- 2. instantiation from a derived class (regular recommended use)
-
- With direct instantiation, instance parameters must be explicitly specified.
- With a derived class, all the instance parameters are taken from the module
- of the deriving class. The module must declare the connection object conn.
- The name of the deriving class is used as the table's className.
-
- The table associated with an instance of Base is identified by the className
- property, which is a string in CamelCase. The actual table name is obtained
- by converting className from CamelCase to underscore_separated_words and
- prefixing according to the table's role.
-
- The table declaration can be specified in the doc string of the inheriting
- class, in the DataJoint table declaration syntax.
-
- Base also implements the methods insert and delete to insert and delete tuples
- from the table. It can also be an argument in relational operators: restrict,
- join, pro, and aggr. See class :mod:`datajoint.relational`.
-
- Base instances return their table's heading by looking it up in the connection
- object. This ensures that Base instances contain the current table definition
- even after tables are modified after the instance is created.
-
- :param conn=None: :mod:`datajoint.connection.Connection` object. Only used when Base is
- instantiated directly.
- :param dbname=None: Name of the database. Only used when Base is instantiated directly.
- :param class_name=None: Class name. Only used when Base is instantiated directly.
- :param table_def=None: Declaration of the table. Only used when Base is instantiated directly.
+ Base is a Table that implements data definition functions.
+ It is an abstract class with the abstract property 'definition'.
Example for a usage of Base::
@@ -72,9 +20,8 @@ class Base(_Relational):
class Subjects(dj.Base):
- _table_def = '''
+ definition = '''
test1.Subjects (manual) # Basic subject info
-
subject_id : int # unique subject id
---
real_id : varchar(40) # real-world name
@@ -83,248 +30,39 @@ class Subjects(dj.Base):
"""
- def __init__(self, conn=None, dbname=None, class_name=None, table_def=None):
- self._use_package = False
- if self.__class__ is Base:
- # instantiate without subclassing
- if not (conn and dbname and class_name):
- raise DataJointError(
- 'Missing argument: please specify conn, dbname, and class name.')
- self.class_name = class_name
- self.conn = conn
- self.dbname = dbname
- self._table_def = table_def
- # register with a fake module, enclosed in back quotes
+ @abc.abstractproperty
+ def definition(self):
+ """
+ :return: string containing the table declaration using the DataJoint Data Definition Language.
+ The DataJoint DDL is described at: TODO
+ """
+ pass
- if dbname not in self.conn.db_to_mod:
- self.conn.bind('`{0}`'.format(dbname), dbname)
- else:
- # instantiate a derived class
- if conn or dbname or class_name or table_def:
- raise DataJointError(
- 'With derived classes, constructor arguments are ignored') # TODO: consider changing this to a warning instead
- self.class_name = self.__class__.__name__
- module = self.__module__
- mod_obj = importlib.import_module(module)
+ def __init__(self):
+ self.class_name = self.__class__.__name__
+ module = self.__module__
+        mod_obj = importlib.import_module(module)
+        use_package = False
+ try:
+ conn = mod_obj.conn
+ except AttributeError:
try:
- self.conn = mod_obj.conn
+ pkg_obj = importlib.import_module(mod_obj.__package__)
+ conn = pkg_obj.conn
+ use_package = True
except AttributeError:
- try:
- pkg_obj = importlib.import_module(mod_obj.__package__)
- self.conn = pkg_obj.conn
- self._use_package = True
- except AttributeError:
- raise DataJointError(
- "Please define object 'conn' in '{}' or in its containing package.".format(self.__module__))
- try:
- if (self._use_package):
- pkg_name = '.'.join(module.split('.')[:-1])
- self.dbname = self.conn.mod_to_db[pkg_name]
- else:
- self.dbname = self.conn.mod_to_db[module]
- except KeyError:
raise DataJointError(
- 'Module {} is not bound to a database. See datajoint.connection.bind'.format(self.__module__))
-
- if hasattr(self, '_table_def'):
- self._table_def = self._table_def
+ "Please define object 'conn' in '{}' or in its containing package.".format(self.__module__))
+ try:
+ if use_package:
+ pkg_name = '.'.join(module.split('.')[:-1])
+                dbname = conn.mod_to_db[pkg_name]
else:
- self._table_def = None
-
- # todo: do we support records and named tuples for tup?
- def insert(self, tup, ignore_errors=False, replace=False):
- """
- Insert one data tuple.
-
- :param tup: Data tuple. Can be an iterable in matching order, a dict with named fields, or an np.void.
- :param ignore_errors=False: Ignores errors if True.
- :param replace=False: Replaces data tuple if True.
-
- Example::
-
- b = djtest.Subject()
- b.insert( dict(subject_id = 7, species="mouse",\\
- real_id = 1007, date_of_birth = "2014-09-01") )
- """
-
- if issubclass(type(tup), tuple) or issubclass(type(tup), list):
- valueList = ','.join([repr(q) for q in tup])
- fieldList = '`' + '`,`'.join(self.heading.names[0:len(tup)]) + '`'
- elif issubclass(type(tup), dict):
- valueList = ','.join([repr(tup[q])
- for q in self.heading.names if q in tup])
- fieldList = '`' + \
- '`,`'.join([q for q in self.heading.names if q in tup]) + '`'
- elif issubclass(type(tup), np.void):
- valueList = ','.join([repr(tup[q])
- for q in self.heading.names if q in tup])
- fieldList = '`' + '`,`'.join(tup.dtype.fields) + '`'
- else:
- raise DataJointError('Datatype %s cannot be inserted' % type(tup))
- if replace:
- sql = 'REPLACE'
- elif ignore_errors:
- sql = 'INSERT IGNORE'
- else:
- sql = 'INSERT'
- sql += " INTO %s (%s) VALUES (%s)" % (self.full_table_name,
- fieldList, valueList)
- logger.info(sql)
- self.conn.query(sql)
-
- def drop(self):
- """
- Drops the table associated to this object.
- """
- # TODO make cascading (github issue #16)
- self.conn.query('DROP TABLE %s' % self.full_table_name)
- self.conn.clear_dependencies(dbname=self.dbname)
- self.conn.load_headings(dbname=self.dbname, force=True)
- logger.debug("Dropped table %s" % self.full_table_name)
-
- @property
- def sql(self):
- return self.full_table_name + self._whereClause
-
- @property
- def heading(self):
- self.declare()
- return self.conn.headings[self.dbname][self.table_name]
-
- @property
- def is_declared(self):
- """
- :returns: True if table is found in the database
- """
- self.conn.load_headings(self.dbname)
- return self.class_name in self.conn.table_names[self.dbname]
-
- @property
- def table_name(self):
- """
- :return: name of the associated table
- """
- self.declare()
- return self.conn.table_names[self.dbname][self.class_name]
-
- @property
- def full_table_name(self):
- """
- :return: full name of the associated table
- """
- return '`%s`.`%s`' % (self.dbname, self.table_name)
-
- @property
- def full_class_name(self):
- """
- :return: full class name
- """
- return '{}.{}'.format(self.__module__, self.class_name)
-
- @property
- def primary_key(self):
- """
- :return: primary key of the table
- """
- return self.heading.primary_key
-
- def declare(self):
- """
- Declare the table in database if it doesn't already exist.
-
- :raises: DataJointError if the table cannot be declared.
- """
- if not self.is_declared:
- self._declare()
- if not self.is_declared:
- raise DataJointError(
- 'Table could not be declared for %s' % self.class_name)
-
- """
- Data definition functionalities
- """
-
- def set_table_comment(self, newComment):
- """
- Update the table comment in the table declaration.
-
- :param newComment: new comment as string
-
- """
- # TODO: add verification procedure (github issue #24)
- self.alter('COMMENT="%s"' % newComment)
-
- def add_attribute(self, definition, after=None):
- """
- Add a new attribute to the table. A full line from the table definition
- is passed in as definition.
-
- The definition can specify where to place the new attribute. Use after=None
- to add the attribute as the first attribute or after='attribute' to place it
- after an existing attribute.
-
- :param definition: table definition
- :param after=None: After which attribute of the table the new attribute is inserted.
- If None, the attribute is inserted in front.
- """
- position = ' FIRST' if after is None else (
- ' AFTER %s' % after if after else '')
- sql = self._field_to_SQL(self._parse_attr_def(definition))
- self._alter('ADD COLUMN %s%s' % (sql[:-2], position))
-
- def drop_attribute(self, attr_name):
- """
- Drops the attribute attrName from this table.
-
- :param attr_name: Name of the attribute that is dropped.
- """
- self._alter('DROP COLUMN `%s`' % attr_name)
-
- def alter_attribute(self, attr_name, new_definition):
- """
- Alter the definition of the field attr_name in this table using the new definition.
-
- :param attr_name: field that is redefined
- :param new_definition: new definition of the field
- """
- sql = self._field_to_SQL(self._parse_attr_def(new_definition))
- self._alter('CHANGE COLUMN `%s` %s' % (attr_name, sql[:-2]))
-
- def erd(self, subset=None, prog='dot'):
- """
- Plot the schema's entity relationship diagram (ERD).
- The layout programs can be 'dot' (default), 'neato', 'fdp', 'sfdp', 'circo', 'twopi'
- """
- if not subset:
- g = self.graph
- else:
- g = self.graph.copy()
- # todo: make erd work (github issue #7)
- """
- g = self.graph
- else:
- g = self.graph.copy()
- for i in g.nodes():
- if i not in subset:
- g.remove_node(i)
- def tablelist(tier):
- return [i for i in g if self.tables[i].tier==tier]
+                dbname = conn.mod_to_db[module]
+ except KeyError:
+ raise DataJointError(
+ 'Module {} is not bound to a database. See datajoint.connection.bind'.format(self.__module__))
+        super().__init__(conn=conn, dbname=dbname, class_name=self.class_name)
- pos=nx.graphviz_layout(g,prog=prog,args='')
- plt.figure(figsize=(8,8))
- nx.draw_networkx_edges(g, pos, alpha=0.3)
- nx.draw_networkx_nodes(g, pos, nodelist=tablelist('manual'),
- node_color='g', node_size=200, alpha=0.3)
- nx.draw_networkx_nodes(g, pos, nodelist=tablelist('computed'),
- node_color='r', node_size=200, alpha=0.3)
- nx.draw_networkx_nodes(g, pos, nodelist=tablelist('imported'),
- node_color='b', node_size=200, alpha=0.3)
- nx.draw_networkx_nodes(g, pos, nodelist=tablelist('lookup'),
- node_color='gray', node_size=120, alpha=0.3)
- nx.draw_networkx_labels(g, pos, nodelist = subset, font_weight='bold', font_size=9)
- nx.draw(g,pos,alpha=0,with_labels=false)
- plt.show()
- """
@classmethod
def get_module(cls, module_name):
@@ -355,268 +93,3 @@ def get_module(cls, module_name):
return importlib.import_module(module_name)
except ImportError:
return None
-
- def get_base(self, module_name, class_name):
- """
- Loads the base relation from the module. If the base relation is not defined in
- the module, then construct it using Base constructor.
-
- :param module_name: module name
- :param class_name: class name
- :returns: the base relation
- """
- mod_obj = self.get_module(module_name)
- try:
- ret = getattr(mod_obj, class_name)()
- except KeyError:
- ret = self.__class__(conn=self.conn,
- dbname=self.conn.schemas[module_name],
- class_name=class_name)
- return ret
-
- # ////////////////////////////////////////////////////////////
- # Private Methods
- # ////////////////////////////////////////////////////////////
-
- def _field_to_SQL(self, field):
- """
- Converts an attribute definition tuple into SQL code.
-
- :param field: attribute definition
- :rtype : SQL code
- """
- if field.isNullable:
- default = 'DEFAULT NULL'
- else:
- default = 'NOT NULL'
- # if some default specified
- if field.default:
- # enclose value in quotes (even numeric), except special SQL values
- # or values already enclosed by the user
- if field.default.upper() in mysql_constants or field.default[:1] in ["'", '"']:
- default = '%s DEFAULT %s' % (default, field.default)
- else:
- default = '%s DEFAULT "%s"' % (default, field.default)
-
- # TODO: escape instead! - same goes for Matlab side implementation
- assert not any((c in r'\"' for c in field.comment)), \
- 'Illegal characters in attribute comment "%s"' % field.comment
-
- return '`{name}` {type} {default} COMMENT "{comment}",\n'.format(
- name=field.name, type=field.type, default=default, comment=field.comment)
-
- def _alter(self, alter_statement):
- """
- Execute ALTER TABLE statement for this table. The schema
- will be reloaded within the connection object.
-
- :param alter_statement: alter statement
- """
- sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement)
- self.conn.query(sql)
- self.conn.load_headings(self.dbname, force=True)
- # TODO: place table definition sync mechanism
-
- def _declare(self):
- """
- Declares the table in the data base if no table in the database matches this object.
- """
- if not self._table_def:
- raise DataJointError('Table declaration is missing!')
- table_info, parents, referenced, fieldDefs, indexDefs = self._parse_declaration()
- defined_name = table_info['module'] + '.' + table_info['className']
- expected_name = self.__module__.split('.')[-1] + '.' + self.class_name
- if not defined_name == expected_name:
- raise DataJointError('Table name {} does not match the declared'
- 'name {}'.format(expected_name, defined_name))
-
- # compile the CREATE TABLE statement
- # TODO: support prefix
- table_name = role_to_prefix[
- table_info['tier']] + from_camel_case(self.class_name)
- sql = 'CREATE TABLE `%s`.`%s` (\n' % (self.dbname, table_name)
-
- # add inherited primary key fields
- primary_key_fields = set()
- non_key_fields = set()
- for p in parents:
- for key in p.primary_key:
- field = p.heading[key]
- if field.name not in primary_key_fields:
- primary_key_fields.add(field.name)
- sql += self._field_to_SQL(field)
- else:
- logger.debug('Field definition of {} in {} ignored'.format(
- field.name, p.full_class_name))
-
- # add newly defined primary key fields
- for field in (f for f in fieldDefs if f.isKey):
- if field.isNullable:
- raise DataJointError('Primary key {} cannot be nullable'.format(
- field.name))
- if field.name in primary_key_fields:
- raise DataJointError('Duplicate declaration of the primary key '
- '{key}. Check to make sure that the key '
- 'is not declared already in referenced '
- 'tables'.format(key=field.name))
- primary_key_fields.add(field.name)
- sql += self._field_to_SQL(field)
-
- # add secondary foreign key attributes
- for r in referenced:
- keys = (x for x in r.heading.attrs.values() if x.isKey)
- for field in keys:
- if field.name not in primary_key_fields | non_key_fields:
- non_key_fields.add(field.name)
- sql += self._field_to_SQL(field)
-
- # add dependent attributes
- for field in (f for f in fieldDefs if not f.isKey):
- non_key_fields.add(field.name)
- sql += self._field_to_SQL(field)
-
- # add primary key declaration
- assert len(primary_key_fields) > 0, 'table must have a primary key'
- keys = ', '.join(primary_key_fields)
- sql += 'PRIMARY KEY (%s),\n' % keys
-
- # add foreign key declarations
- for ref in parents + referenced:
- keys = ', '.join(ref.primary_key)
- sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \
- (keys, ref.full_table_name, keys)
-
- # add secondary index declarations
- # gather implicit indexes due to foreign keys first
- implicit_indices = []
- for fk_source in parents + referenced:
- implicit_indices.append(fk_source.primary_key)
-
- # for index in indexDefs:
- # TODO: finish this up...
-
- # close the declaration
- sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % (
- sql[:-2], table_info['comment'])
-
- # make sure that the table does not alredy exist
- self.conn.load_headings(self.dbname, force=True)
- if not self.is_declared:
- # execute declaration
- logger.debug('\n\n' + sql + '\n\n')
- self.conn.query(sql)
- self.conn.load_headings(self.dbname, force=True)
-
- def _parse_declaration(self):
- """
- Parse declaration and create new SQL table accordingly.
- """
- parents = []
- referenced = []
- index_defs = []
- field_defs = []
- declaration = re.split(r'\s*\n\s*', self._table_def.strip())
-
- # remove comment lines
- declaration = [x for x in declaration if not x.startswith('#')]
- ptrn = """
-        ^(?P<module>\w+)\.(?P<className>\w+)\s*  # module.className
-        \(\s*(?P<tier>\w+)\s*\)\s*  # (tier)
-        \#\s*(?P<comment>.*)$  # comment
- """
- p = re.compile(ptrn, re.X)
- table_info = p.match(declaration[0]).groupdict()
- if table_info['tier'] not in Role.__members__:
- raise DataJointError('InvalidTableTier: Invalid tier {tier} for table\
- {module}.{cls}'.format(tier=table_info['tier'],
- module=table_info[
- 'module'],
- cls=table_info['className']))
- table_info['tier'] = Role[table_info['tier']] # convert into enum
-
- in_key = True # parse primary keys
- field_ptrn = """
- ^[a-z][a-z\d_]*\s* # name
- (=\s*\S+(\s+\S+)*\s*)? # optional defaults
- :\s*\w.*$ # type, comment
- """
- fieldP = re.compile(field_ptrn, re.I + re.X) # ignore case and verbose
-
- for line in declaration[1:]:
- if line.startswith('---'):
- in_key = False # start parsing non-PK fields
- elif line.startswith('->'):
- # foreign key
- module_name, class_name = line[2:].strip().split('.')
- rel = self.get_base(module_name, class_name)
- (parents if in_key else referenced).append(rel)
- elif re.match(r'^(unique\s+)?index[^:]*$', line):
- index_defs.append(self._parse_index_def(line))
- elif fieldP.match(line):
- field_defs.append(self._parse_attr_def(line, in_key))
- else:
- raise DataJointError(
- 'Invalid table declaration line "%s"' % line)
-
- return table_info, parents, referenced, field_defs, index_defs
-
- def _parse_attr_def(self, line, in_key=False): # todo add docu for in_key
- """
- Parse attribute definition line in the declaration and returns
- an attribute tuple.
-
- :param line: attribution line
- :param in_key:
- :returns: attribute tuple
- """
- line = line.strip()
- attr_ptrn = """
-        ^(?P<name>[a-z][a-z\d_]*)\s*  # field name
-        (=\s*(?P<default>\S+(\s+\S+)*?)\s*)?  # default value
-        :\s*(?P<type>\w[^\#]*[^\#\s])\s*  # datatype
-        (\#\s*(?P<comment>\S*(\s+\S+)*)\s*)?$  # comment
- """
-
- attrP = re.compile(attr_ptrn, re.I + re.X)
- m = attrP.match(line)
- assert m, 'Invalid field declaration "%s"' % line
- attr_info = m.groupdict()
- if not attr_info['comment']:
- attr_info['comment'] = ''
- if not attr_info['default']:
- attr_info['default'] = ''
- attr_info['isNullable'] = attr_info['default'].lower() == 'null'
- assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['isNullable']), \
- 'BIGINT attributes cannot be nullable in "%s"' % line
-
- attr_info['isKey'] = in_key
- attr_info['isAutoincrement'] = None
- attr_info['isNumeric'] = None
- attr_info['isString'] = None
- attr_info['isBlob'] = None
- attr_info['computation'] = None
- attr_info['dtype'] = None
-
- return Heading.AttrTuple(**attr_info)
-
- def _parse_index_def(self, line):
- """
- Parses index definition.
-
- :param line: definition line
- :return: groupdict with index info
- """
- line = line.strip()
- index_ptrn = """
-        ^(?P<unique>UNIQUE)?\s*INDEX\s*  # [UNIQUE] INDEX
-        \((?P<attributes>[^\)]+)\)$  # (attr1, attr2)
- """
- indexP = re.compile(index_ptrn, re.I + re.X)
- m = indexP.match(line)
- assert m, 'Invalid index declaration "%s"' % line
- index_info = m.groupdict()
- attributes = re.split(r'\s*,\s*', index_info['attributes'].strip())
- index_info['attributes'] = attributes
- assert len(attributes) == len(set(attributes)), \
- 'Duplicate attributes in index declaration "%s"' % line
- return index_info
diff --git a/datajoint/blob.py b/datajoint/blob.py
index 82ac9e338..e66ca11ab 100644
--- a/datajoint/blob.py
+++ b/datajoint/blob.py
@@ -3,29 +3,27 @@
import numpy as np
from . import DataJointError
-
mxClassID = collections.OrderedDict(
# see http://www.mathworks.com/help/techdoc/apiref/mxclassid.html
- mxUNKNOWN_CLASS = None,
- mxCELL_CLASS = None, # not yet implemented
- mxSTRUCT_CLASS = None, # not yet implemented
- mxLOGICAL_CLASS = np.dtype('bool'),
- mxCHAR_CLASS = np.dtype('c'),
- mxVOID_CLASS = None,
- mxDOUBLE_CLASS = np.dtype('float64'),
- mxSINGLE_CLASS = np.dtype('float32'),
- mxINT8_CLASS = np.dtype('int8'),
- mxUINT8_CLASS = np.dtype('uint8'),
- mxINT16_CLASS = np.dtype('int16'),
- mxUINT16_CLASS = np.dtype('uint16'),
- mxINT32_CLASS = np.dtype('int32'),
- mxUINT32_CLASS = np.dtype('uint32'),
- mxINT64_CLASS = np.dtype('int64'),
- mxUINT64_CLASS = np.dtype('uint64'),
- mxFUNCTION_CLASS= None
- )
+ mxUNKNOWN_CLASS=None,
+ mxCELL_CLASS=None, # not implemented
+ mxSTRUCT_CLASS=None, # not implemented
+ mxLOGICAL_CLASS=np.dtype('bool'),
+ mxCHAR_CLASS=np.dtype('c'),
+ mxVOID_CLASS=None,
+ mxDOUBLE_CLASS=np.dtype('float64'),
+ mxSINGLE_CLASS=np.dtype('float32'),
+ mxINT8_CLASS=np.dtype('int8'),
+ mxUINT8_CLASS=np.dtype('uint8'),
+ mxINT16_CLASS=np.dtype('int16'),
+ mxUINT16_CLASS=np.dtype('uint16'),
+ mxINT32_CLASS=np.dtype('int32'),
+ mxUINT32_CLASS=np.dtype('uint32'),
+ mxINT64_CLASS=np.dtype('int64'),
+ mxUINT64_CLASS=np.dtype('uint64'),
+ mxFUNCTION_CLASS=None)
-reverseClassID = {v:i for i,v in enumerate(mxClassID.values())}
+reverseClassID = {v: i for i, v in enumerate(mxClassID.values())}
def pack(obj):
@@ -35,55 +33,53 @@ def pack(obj):
if not isinstance(obj, np.ndarray):
raise DataJointError("Only numpy arrays can be saved in blobs")
- blob = b"mYm\0A" # TODO: extend to process other datatypes besides arrays
- blob += np.asarray((len(obj.shape),)+obj.shape,dtype=np.uint64).tostring()
+ blob = b"mYm\0A" # TODO: extend to process other data types besides arrays
+ blob += np.asarray((len(obj.shape),) + obj.shape, dtype=np.uint64).tostring()
- isComplex = np.iscomplexobj(obj)
- if isComplex:
- obj, objImag = np.real(obj), np.imag(obj)
+ is_complex = np.iscomplexobj(obj)
+ if is_complex:
+ obj, imaginary = np.real(obj), np.imag(obj)
- typeNum = reverseClassID[obj.dtype]
- blob+= np.asarray(typeNum, dtype=np.uint32).tostring()
- blob+= np.int8(isComplex).tostring() + b'\0\0\0'
- blob+= obj.tostring()
+ type_number = reverseClassID[obj.dtype]
+ blob += np.asarray(type_number, dtype=np.uint32).tostring()
+ blob += np.int8(is_complex).tostring() + b'\0\0\0'
+ blob += obj.tostring()
- if isComplex:
- blob+= objImag.tostring()
+ if is_complex:
+ blob += imaginary.tostring()
- if len(blob)>1000:
- compressed = b'ZL123\0'+np.asarray(len(blob),dtype=np.uint64).tostring() + zlib.compress(blob)
+ if len(blob) > 1000:
+ compressed = b'ZL123\0'+np.asarray(len(blob), dtype=np.uint64).tostring() + zlib.compress(blob)
if len(compressed) < len(blob):
blob = compressed
return blob
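+
+# Serialized layout sketch for a real 2x3 float64 array (before optional zlib
+# compression), as produced by pack() above:
+#   b"mYm\0A" | uint64 ndims=2 | uint64 shape 2, 3 | uint32 mxClassID index
+#   | int8 is_complex flag + b'\0\0\0' | raw array bytes
+# Round trip: unpack(pack(a)) recovers the array a.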
-
def unpack(blob):
"""
unpack blob into a numpy array
"""
# decompress if necessary
- if blob[0:5]==b'ZL123':
- blobLen = np.fromstring(blob[6:14],dtype=np.uint64)[0]
+ if blob[0:5] == b'ZL123':
+ blob_length = np.fromstring(blob[6:14], dtype=np.uint64)[0]
blob = zlib.decompress(blob[14:])
- assert(len(blob)==blobLen)
+ assert len(blob) == blob_length
- blobType = blob[4]
- if blobType!=65: # TODO: also process structure arrays, cell arrays, etc.
+ blob_type = blob[4]
+ if blob_type != 65: # TODO: also process structure arrays, cell arrays, etc.
raise DataJointError('only arrays are currently allowed in blobs')
p = 5
- ndims = np.fromstring(blob[p:p+8], dtype=np.uint64)
+    dimensions = int(np.fromstring(blob[p:p+8], dtype=np.uint64)[0])
p += 8
- arrDims = np.fromstring(blob[p:p+8*ndims], dtype=np.uint64)
- p += 8 * ndims
- mxType, dtype = [q for q in mxClassID.items()][np.fromstring(blob[p:p+4],dtype=np.uint32)[0]]
+ array_shape = np.fromstring(blob[p:p+8*dimensions], dtype=np.uint64)
+ p += 8 * dimensions
+ mx_type, dtype = [q for q in mxClassID.items()][np.fromstring(blob[p:p+4], dtype=np.uint32)[0]]
if dtype is None:
- raise DataJointError('Unsupported matlab datatype '+mxType+' in blob')
+ raise DataJointError('Unsupported MATLAB data type '+mx_type+' in blob')
p += 4
- complexity = np.fromstring(blob[p:p+4],dtype=np.uint32)[0]
+ is_complex = np.fromstring(blob[p:p+4], dtype=np.uint32)[0]
p += 4
obj = np.fromstring(blob[p:], dtype=dtype)
- if complexity:
+ if is_complex:
-        obj = obj[:len(obj)/2] + 1j*obj[len(obj)/2:]
+        obj = obj[:len(obj) // 2] + 1j * obj[len(obj) // 2:]
-
- return obj.reshape(arrDims)
+ return obj.reshape(array_shape)
\ No newline at end of file
diff --git a/datajoint/connection.py b/datajoint/connection.py
index 2653950a3..ca6551cea 100644
--- a/datajoint/connection.py
+++ b/datajoint/connection.py
@@ -2,13 +2,9 @@
import re
from .utils import to_camel_case
from . import DataJointError
-import os
from .heading import Heading
-from .base import prefix_to_role
+from .settings import prefix_to_role
import logging
-import networkx as nx
-from networkx import pygraphviz_layout
-import matplotlib.pyplot as plt
from .erd import DBConnGraph
from . import config
@@ -26,7 +22,7 @@ def conn_container():
"""
_connObj = None # persistent connection object used by dj.conn()
- def conn(host=None, user=None, passwd=None, initFun=None, reset=False):
+ def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False):
"""
Manage a persistent connection object.
This is one of several ways to configure and access a datajoint connection.
@@ -39,31 +35,29 @@ def conn(host=None, user=None, passwd=None, initFun=None, reset=False):
host = host if host is not None else config['database.host']
user = user if user is not None else config['database.user']
passwd = passwd if passwd is not None else config['database.password']
- initFun = initFun if initFun is not None else config['connection.init_function']
- _connObj = Connection(host, user, passwd, initFun)
+ init_fun = init_fun if init_fun is not None else config['connection.init_function']
+ _connObj = Connection(host, user, passwd, init_fun)
return _connObj
- return conn
-
+ return conn_function
# The function conn is used by others to obtain the package wide persistent connection object
conn = conn_container()
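+# e.g. c = conn() creates the persistent connection from config on first use;
+# conn(reset=True) discards it and connects anew (a sketch of intended use)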
-
class Connection:
"""
A dj.Connection object manages a connection to a database server.
It also catalogues modules, schemas, tables, and their dependencies (foreign keys)
"""
- def __init__(self, host, user, passwd, initFun=None):
+ def __init__(self, host, user, passwd, init_fun=None):
if ':' in host:
host, port = host.split(':')
port = int(port)
else:
port = config['database.port']
self.conn_info = dict(host=host, port=port, user=user, passwd=passwd)
- self._conn = pymysql.connect(init_command=initFun, **self.conn_info)
+ self._conn = pymysql.connect(init_command=init_fun, **self.conn_info)
# TODO Do something if connection cannot be established
if self.is_connected:
print("Connected", user + '@' + host + ':' + str(port))
@@ -140,11 +134,11 @@ def bind(self, module, dbname):
self.mod_to_db[module] = dbname
elif count == 0:
# Database doesn't exist, attempt to create
- logger.warning("Database `{dbname}` could not be found. "
- "Attempting to create the database.".format(dbname=dbname))
+ logger.info("Database `{dbname}` could not be found. "
+ "Attempting to create the database.".format(dbname=dbname))
try:
- cur = self.query("CREATE DATABASE `{dbname}`".format(dbname=dbname))
- logger.warning('Created database `{dbname}`.'.format(dbname=dbname))
+ self.query("CREATE DATABASE `{dbname}`".format(dbname=dbname))
+ logger.info('Created database `{dbname}`.'.format(dbname=dbname))
self.db_to_mod[dbname] = module
self.mod_to_db[module] = dbname
except pymysql.OperationalError:
@@ -227,7 +221,8 @@ def load_dependencies(self, dbname): # TODO: Perhaps consider making this "priv
self.referenced[full_table_name] = []
for m in re.finditer(ptrn, table_def["Create Table"], re.X): # iterate through foreign key statements
- assert m.group('attr1') == m.group('attr2'), 'Foreign keys must link identically named attributes'
+ assert m.group('attr1') == m.group('attr2'), \
+ 'Foreign keys must link identically named attributes'
attrs = m.group('attr1')
attrs = re.split(r',\s+', re.sub(r'`(.*?)`', r'\1', attrs)) # remove ` around attrs and split into list
pk = self.headings[dbname][tabName].primary_key
@@ -270,14 +265,14 @@ def parents_of(self, child_table):
def children_of(self, parent_table):
"""
- Returnis a list of tables for which parentTable is a parent (primary foreign key)
+ Returns a list of tables for which parent_table is a parent (primary foreign key)
"""
return [child_table for child_table, parents in self.parents.items() if parent_table in parents]
def referenced_by(self, referencing_table):
"""
Returns a list of tables that are referenced by non-primary foreign key
- by the referencingTable.
+ by the referencing_table.
"""
return self.referenced.get(referencing_table, []).copy()
diff --git a/datajoint/declare.py b/datajoint/declare.py
new file mode 100644
index 000000000..af09334ae
--- /dev/null
+++ b/datajoint/declare.py
@@ -0,0 +1,239 @@
+import re
+import importlib
+import logging
+from .heading import Heading
+from . import DataJointError
+from .utils import from_camel_case
+from .settings import Role, role_to_prefix
+
+mysql_constants = ['CURRENT_TIMESTAMP']
+
+logger = logging.getLogger(__name__)
+
+
+def declare(conn, definition, class_name, dbname):
+    """
+    Declares the table in the database if it does not already exist.
+    """
+ table_info, parents, referenced, field_definitions, index_definitions = parse_declaration(definition)
+ defined_name = table_info['module'] + '.' + table_info['className']
+    if defined_name != class_name:
+        raise DataJointError('Table name {} does not match the declared '
+                             'name {}'.format(class_name, defined_name))
+
+ # compile the CREATE TABLE statement
+    table_name = role_to_prefix[table_info['tier']] + from_camel_case(class_name.split('.')[-1])
+    sql = 'CREATE TABLE `%s`.`%s` (\n' % (dbname, table_name)
+
+ # add inherited primary key fields
+ primary_key_fields = set()
+ non_key_fields = set()
+ for p in parents:
+ for key in p.primary_key:
+ field = p.heading[key]
+ if field.name not in primary_key_fields:
+ primary_key_fields.add(field.name)
+ sql += field_to_sql(field)
+ else:
+ logger.debug('Field definition of {} in {} ignored'.format(
+ field.name, p.full_class_name))
+
+ # add newly defined primary key fields
+    for field in (f for f in field_definitions if f.in_key):
+ if field.nullable:
+ raise DataJointError('Primary key {} cannot be nullable'.format(
+ field.name))
+ if field.name in primary_key_fields:
+ raise DataJointError('Duplicate declaration of the primary key '
+ '{key}. Check to make sure that the key '
+ 'is not declared already in referenced '
+ 'tables'.format(key=field.name))
+ primary_key_fields.add(field.name)
+ sql += field_to_sql(field)
+
+ # add secondary foreign key attributes
+ for r in referenced:
+        keys = (x for x in r.heading.attributes.values() if x.in_key)
+ for field in keys:
+ if field.name not in primary_key_fields | non_key_fields:
+ non_key_fields.add(field.name)
+ sql += field_to_sql(field)
+
+ # add dependent attributes
+    for field in (f for f in field_definitions if not f.in_key):
+ non_key_fields.add(field.name)
+ sql += field_to_sql(field)
+
+ # add primary key declaration
+ assert len(primary_key_fields) > 0, 'table must have a primary key'
+ keys = ', '.join(primary_key_fields)
+ sql += 'PRIMARY KEY (%s),\n' % keys
+
+ # add foreign key declarations
+ for ref in parents + referenced:
+ keys = ', '.join(ref.primary_key)
+ sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \
+ (keys, ref.full_table_name, keys)
+
+ # add secondary index declarations
+ # gather implicit indexes due to foreign keys first
+ implicit_indices = []
+ for fk_source in parents + referenced:
+ implicit_indices.append(fk_source.primary_key)
+
+ # for index in index_definitions:
+ # TODO: finish this up...
+
+ # close the declaration
+ sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % (
+ sql[:-2], table_info['comment'])
+
+    # make sure that the table does not already exist
+    conn.load_headings(dbname, force=True)
+    if class_name.split('.')[-1] not in conn.table_names[dbname]:
+        # execute declaration
+        logger.debug('\n\n' + sql + '\n\n')
+        conn.query(sql)
+        conn.load_headings(dbname, force=True)
+
+
+def parse_declaration(definition):
+ """
+ Parse declaration and create new SQL table accordingly.
+ """
+ parents = []
+ referenced = []
+ index_defs = []
+ field_defs = []
+    declaration = re.split(r'\s*\n\s*', definition.strip())
+
+ # remove comment lines
+ declaration = [x for x in declaration if not x.startswith('#')]
+ ptrn = """
+    ^(?P<module>\w+)\.(?P<className>\w+)\s*  # module.className
+    \(\s*(?P<tier>\w+)\s*\)\s*  # (tier)
+    \#\s*(?P<comment>.*)$  # comment
+ """
+ p = re.compile(ptrn, re.X)
+ table_info = p.match(declaration[0]).groupdict()
+ if table_info['tier'] not in Role.__members__:
+        raise DataJointError('InvalidTableTier: Invalid tier {tier} for table '
+                             '{module}.{cls}'.format(tier=table_info['tier'],
+                                                     module=table_info['module'],
+                                                     cls=table_info['className']))
+ table_info['tier'] = Role[table_info['tier']] # convert into enum
+
+ in_key = True # parse primary keys
+ field_ptrn = """
+ ^[a-z][a-z\d_]*\s* # name
+ (=\s*\S+(\s+\S+)*\s*)? # optional defaults
+ :\s*\w.*$ # type, comment
+ """
+    field_regexp = re.compile(field_ptrn, re.I + re.X)  # ignore case and verbose
+
+ for line in declaration[1:]:
+ if line.startswith('---'):
+ in_key = False # start parsing non-PK fields
+ elif line.startswith('->'):
+ # foreign key
+ module_name, class_name = line[2:].strip().split('.')
+            # assumption: the referenced class is importable by its module name,
+            # mirroring the removed Base.get_base()
+            mod_obj = importlib.import_module(module_name)
+            rel = getattr(mod_obj, class_name)()  # instantiate the referenced relation
+ (parents if in_key else referenced).append(rel)
+ elif re.match(r'^(unique\s+)?index[^:]*$', line):
+            index_defs.append(parse_index_definition(line))
+        elif field_regexp.match(line):
+ field_defs.append(parse_attribute_definition(line, in_key))
+ else:
+ raise DataJointError(
+ 'Invalid table declaration line "%s"' % line)
+
+ return table_info, parents, referenced, field_defs, index_defs
+
+
+def field_to_sql(field):
+ """
+ Converts an attribute definition tuple into SQL code.
+ :param field: attribute definition
+ :rtype : SQL code
+ """
+ if field.nullable:
+ default = 'DEFAULT NULL'
+ else:
+ default = 'NOT NULL'
+ # if some default specified
+ if field.default:
+ # enclose value in quotes (even numeric), except special SQL values
+ # or values already enclosed by the user
+ if field.default.upper() in mysql_constants or field.default[:1] in ["'", '"']:
+ default = '%s DEFAULT %s' % (default, field.default)
+ else:
+ default = '%s DEFAULT "%s"' % (default, field.default)
+
+ # TODO: escape instead! - same goes for Matlab side implementation
+ assert not any((c in r'\"' for c in field.comment)), \
+ 'Illegal characters in attribute comment "%s"' % field.comment
+
+ return '`{name}` {type} {default} COMMENT "{comment}",\n'.format(
+ name=field.name, type=field.type, default=default, comment=field.comment)
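+
+# e.g. an attribute tuple with name='subject_id', type='int', default='', and
+# comment='unique subject id' renders as (a sketch):
+#   `subject_id` int NOT NULL COMMENT "unique subject id",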
+
+
+def parse_attribute_definition(line, in_key=False): # todo add docu for in_key
+ """
+ Parse attribute definition line in the declaration and returns
+ an attribute tuple.
+ :param line: attribution line
+ :param in_key:
+ :returns: attribute tuple
+ """
+ line = line.strip()
+ attr_ptrn = """
+    ^(?P<name>[a-z][a-z\d_]*)\s*  # field name
+    (=\s*(?P<default>\S+(\s+\S+)*?)\s*)?  # default value
+    :\s*(?P<type>\w[^\#]*[^\#\s])\s*  # datatype
+    (\#\s*(?P<comment>\S*(\s+\S+)*)\s*)?$  # comment
+ """
+
+    attr_regexp = re.compile(attr_ptrn, re.I + re.X)
+    m = attr_regexp.match(line)
+ assert m, 'Invalid field declaration "%s"' % line
+ attr_info = m.groupdict()
+ if not attr_info['comment']:
+ attr_info['comment'] = ''
+ if not attr_info['default']:
+ attr_info['default'] = ''
+ attr_info['nullable'] = attr_info['default'].lower() == 'null'
+ assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \
+ 'BIGINT attributes cannot be nullable in "%s"' % line
+
+    attr_info['in_key'] = in_key
+    attr_info['autoincrement'] = None
+    attr_info['numeric'] = None
+    attr_info['string'] = None
+    attr_info['is_blob'] = None
+    attr_info['computation'] = None
+    attr_info['dtype'] = None
+
+ return Heading.AttrTuple(**attr_info)
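+
+# e.g. parse_attribute_definition('real_id = null : varchar(40) # real-world name')
+# yields an AttrTuple with name='real_id', type='varchar(40)', default='null',
+# nullable=True, and comment='real-world name' (a sketch)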
+
+
+def parse_index_definition(line):
+ """
+ Parses index definition.
+
+ :param line: definition line
+ :return: groupdict with index info
+ """
+ line = line.strip()
+ index_ptrn = """
+    ^(?P<unique>UNIQUE)?\s*INDEX\s*  # [UNIQUE] INDEX
+    \((?P<attributes>[^\)]+)\)$  # (attr1, attr2)
+ """
+    index_regexp = re.compile(index_ptrn, re.I + re.X)
+    m = index_regexp.match(line)
+ assert m, 'Invalid index declaration "%s"' % line
+ index_info = m.groupdict()
+ attributes = re.split(r'\s*,\s*', index_info['attributes'].strip())
+ index_info['attributes'] = attributes
+ assert len(attributes) == len(set(attributes)), \
+ 'Duplicate attributes in index declaration "%s"' % line
+ return index_info
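+
+# e.g. parse_index_definition('UNIQUE INDEX (subject_id, session)') returns
+# {'unique': 'UNIQUE', 'attributes': ['subject_id', 'session']} (a sketch)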
diff --git a/datajoint/erd.py b/datajoint/erd.py
index a7e1bc024..af6107390 100644
--- a/datajoint/erd.py
+++ b/datajoint/erd.py
@@ -14,6 +14,7 @@
logger = logging.getLogger(__name__)
+
class RelGraph(DiGraph):
"""
Represents relations between tables and databases
diff --git a/datajoint/fetch.py b/datajoint/fetch.py
deleted file mode 100644
index e658bf1e6..000000000
--- a/datajoint/fetch.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Wed Aug 20 22:05:29 2014
-
-@author: dimitri
-"""
-
-from .blob import unpack
-import numpy as np
-import logging
-
-logger = logging.getLogger(__name__)
-
-
-class Fetch:
- """
- Fetch defines callable objects that fetch data from a relation
- """
-
- def __init__(self, relational):
- self.rel = relational
- self._orderBy = None
- self._offset = 0
- self._limit = None
-
- def limit(self, n, offset=0):
- self._limit = n
- self._offset = offset
- return self
-
- def order_by(self, *attrs):
- self._orderBy = attrs
- return self
-
- def __call__(self, *attrs, **renames):
- """
- fetch relation from database into an np.array
- """
- cur = self._cursor(*attrs, **renames)
- heading = self.rel.pro(*attrs, **renames).heading
- ret = np.array(list(cur), dtype=heading.asdtype)
- # unpack blobs
- for i in range(len(ret)):
- for f in heading.blobs:
- ret[i][f] = unpack(ret[i][f])
- return ret
-
- def _cursor(self, *attrs, **renames):
- rel = self.rel.pro(*attrs, **renames)
- sql = 'SELECT ' + rel.heading.asSQL + ' FROM ' + rel.sql
- # add ORDER BY clause
- if self._orderBy:
- sql += ' ORDER BY ' + ', '.join(self._orderBy)
-
- # add LIMIT clause
- if self._limit:
- sql += ' LIMIT %d' % self._limit
- if self._offset:
- sql += ' OFFSET %d ' % self._offset
-
- logger.debug(sql)
- return self.rel.conn.query(sql)
diff --git a/datajoint/heading.py b/datajoint/heading.py
index ca7495db0..18e75df62 100644
--- a/datajoint/heading.py
+++ b/datajoint/heading.py
@@ -1,9 +1,3 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Mon Aug 4 01:29:51 2014
-
-@author: dimitri, eywalker
-"""
import re
from collections import OrderedDict, namedtuple
@@ -13,219 +7,214 @@
class Heading:
"""
- local class for relationals' headings.
+ local class for relations' headings.
"""
+ AttrTuple = namedtuple('AttrTuple',
+ ('name', 'type', 'in_key', 'nullable', 'default', 'comment', 'autoincrement',
+ 'numeric', 'string', 'is_blob', 'computation', 'dtype'))
- AttrTuple = namedtuple('AttrTuple',('name','type','isKey','isNullable',
- 'default','comment','isAutoincrement','isNumeric','isString','isBlob',
- 'computation','dtype'))
-
- def __init__(self, attrs):
- # Input: attrs -list of dicts with attribute descriptions
- self.attrs = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attrs])
+ def __init__(self, attributes):
+ # Input: attributes -list of dicts with attribute descriptions
+ self.attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes])
@property
def names(self):
- return [k for k in self.attrs]
+ return [k for k in self.attributes]
@property
def primary_key(self):
- return [k for k,v in self.attrs.items() if v.isKey]
+ return [k for k, v in self.attributes.items() if v.in_key]
@property
def dependent_fields(self):
- return [k for k,v in self.attrs.items() if not v.isKey]
+ return [k for k, v in self.attributes.items() if not v.in_key]
@property
def blobs(self):
- return [k for k,v in self.attrs.items() if v.isBlob]
+ return [k for k, v in self.attributes.items() if v.is_blob]
@property
def non_blobs(self):
- return [k for k,v in self.attrs.items() if not v.isBlob]
+ return [k for k, v in self.attributes.items() if not v.is_blob]
@property
def computed(self):
- return [k for k,v in self.attrs.items() if v.computation]
+ return [k for k, v in self.attributes.items() if v.computation]
- def __getitem__(self,name):
+ def __getitem__(self, name):
"""shortcut to the attribute"""
- return self.attrs[name]
+ return self.attributes[name]
def __repr__(self):
- autoIncrementString = {False:'', True:' auto_increment'}
+ autoincrement_string = {False: '', True: ' auto_increment'}
return '\n'.join(['%-20s : %-28s # %s' % (
- k if v.default is None else '%s="%s"'%(k,v.default),
- '%s%s' % (v.type, autoIncrementString[v.isAutoincrement]),
+ k if v.default is None else '%s="%s"' % (k, v.default),
+ '%s%s' % (v.type, autoincrement_string[v.autoincrement]),
v.comment)
- for k,v in self.attrs.items()])
+ for k, v in self.attributes.items()])
@property
- def asdtype(self):
+ def as_dtype(self):
"""
represent the heading as a numpy dtype
"""
return np.dtype(dict(
names=self.names,
- formats=[v.dtype for k,v in self.attrs.items()]))
+ formats=[v.dtype for k, v in self.attributes.items()]))
@property
- def asSQL(self):
- """represent heading as SQL field list"""
- attrNames = ['`%s`' % name if self.attrs[name].computation is None else '%s as `%s`' % (self.attrs[name].computation, name)
- for name in self.names]
- return ','.join(attrNames)
-
- # Use heading as a dictionary like object
+ def as_sql(self):
+ """
+ represent heading as SQL field list
+ """
+ return ','.join(['`%s`' % name
+ if self.attributes[name].computation is None
+ else '%s as `%s`' % (self.attributes[name].computation, name)
+ for name in self.names])
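+        # e.g. a heading with stored attribute 'subject_id' and computed
+        # attribute total='sum(amount)' renders as (a sketch):
+        #   `subject_id`,sum(amount) as `total`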
def keys(self):
- return self.attrs.keys()
+ return self.attributes.keys()
def values(self):
- return self.attrs.values()
+ return self.attributes.values()
def items(self):
- return self.attrs.items()
-
+ return self.attributes.items()
@classmethod
- def init_from_database(cls, conn, dbname, tabname):
+ def init_from_database(cls, conn, dbname, table_name):
"""
initialize heading from a database table
"""
cur = conn.query(
- 'SHOW FULL COLUMNS FROM `{tabname}` IN `{dbname}`'.format(
- tabname=tabname, dbname=dbname),asDict=True)
- attrs = cur.fetchall()
+ 'SHOW FULL COLUMNS FROM `{table_name}` IN `{dbname}`'.format(
+ table_name=table_name, dbname=dbname), asDict=True)
+ attributes = cur.fetchall()
rename_map = {
- 'Field' : 'name',
- 'Type' : 'type',
- 'Null' : 'isNullable',
+ 'Field': 'name',
+ 'Type': 'type',
+ 'Null': 'nullable',
'Default': 'default',
- 'Key' : 'isKey',
+ 'Key': 'in_key',
'Comment': 'comment'}
- dropFields = ('Privileges', 'Collation') # unncessary
+ fields_to_drop = ('Privileges', 'Collation')
# rename and drop attributes
- attrs = [{rename_map[k] if k in rename_map else k: v
- for k, v in x.items() if k not in dropFields}
- for x in attrs]
- numTypes ={
- ('float',False):np.float32,
- ('float',True):np.float32,
- ('double',False):np.float32,
- ('double',True):np.float64,
- ('tinyint',False):np.int8,
- ('tinyint',True):np.uint8,
- ('smallint',False):np.int16,
- ('smallint',True):np.uint16,
- ('mediumint',False):np.int32,
- ('mediumint',True):np.uint32,
- ('int',False):np.int32,
- ('int',True):np.uint32,
- ('bigint',False):np.int64,
- ('bigint',True):np.uint64
- # TODO: include decimal and numeric datatypes
+ attributes = [{rename_map[k] if k in rename_map else k: v
+ for k, v in x.items() if k not in fields_to_drop}
+ for x in attributes]
+
+ numeric_types = {
+ ('float', False): np.float32,
+ ('float', True): np.float32,
+            ('double', False): np.float64,
+ ('double', True): np.float64,
+ ('tinyint', False): np.int8,
+ ('tinyint', True): np.uint8,
+ ('smallint', False): np.int16,
+ ('smallint', True): np.uint16,
+ ('mediumint', False): np.int32,
+ ('mediumint', True): np.uint32,
+ ('int', False): np.int32,
+ ('int', True): np.uint32,
+ ('bigint', False): np.int64,
+ ('bigint', True): np.uint64
+ # TODO: include types DECIMAL and NUMERIC
}
-
# additional attribute properties
- for attr in attrs:
- attr['isNullable'] = (attr['isNullable'] == 'YES')
- attr['isKey'] = (attr['isKey'] == 'PRI')
- attr['isAutoincrement'] = bool(re.search(r'auto_increment', attr['Extra'], flags=re.IGNORECASE))
- attr['isNumeric'] = bool(re.match(r'(tiny|small|medium|big)?int|decimal|double|float', attr['type']))
- attr['isString'] = bool(re.match(r'(var)?char|enum|date|time|timestamp', attr['type']))
- attr['isBlob'] = bool(re.match(r'(tiny|medium|long)?blob', attr['type']))
+ for attr in attributes:
+ attr['nullable'] = (attr['nullable'] == 'YES')
+ attr['in_key'] = (attr['in_key'] == 'PRI')
+ attr['autoincrement'] = bool(re.search(r'auto_increment', attr['Extra'], flags=re.IGNORECASE))
+ attr['numeric'] = bool(re.match(r'(tiny|small|medium|big)?int|decimal|double|float', attr['type']))
+ attr['string'] = bool(re.match(r'(var)?char|enum|date|time|timestamp', attr['type']))
+ attr['is_blob'] = bool(re.match(r'(tiny|medium|long)?blob', attr['type']))
# strip field lengths off integer types
attr['type'] = re.sub(r'((tiny|small|medium|big)?int)\(\d+\)', r'\1', attr['type'])
attr['computation'] = None
- if not (attr['isNumeric'] or attr['isString'] or attr['isBlob']):
- raise DataJointError('Unsupported field type {field} in `{dbname}`.`{tabname}`'.format(
- field=attr['type'], dbname=dbname, tabname=tabname))
+ if not (attr['numeric'] or attr['string'] or attr['is_blob']):
+ raise DataJointError('Unsupported field type {field} in `{dbname}`.`{table_name}`'.format(
+ field=attr['type'], dbname=dbname, table_name=table_name))
attr.pop('Extra')
# fill out the dtype. All floats and non-nullable integers are turned into specific dtypes
attr['dtype'] = object
- if attr['isNumeric'] :
- isInteger = bool(re.match(r'(tiny|small|medium|big)?int',attr['type']))
- isFloat = bool(re.match(r'(double|float)',attr['type']))
- if isInteger and not attr['isNullable'] or isFloat:
- isUnsigned = bool(re.match('\sunsigned', attr['type'], flags=re.IGNORECASE))
+ if attr['numeric']:
+ is_integer = bool(re.match(r'(tiny|small|medium|big)?int', attr['type']))
+ is_float = bool(re.match(r'(double|float)', attr['type']))
+ if is_integer and not attr['nullable'] or is_float:
+                is_unsigned = bool(re.search(r'\sunsigned', attr['type'], flags=re.IGNORECASE))
t = attr['type']
- t = re.sub(r'\(.*\)','',t) # remove parentheses
- t = re.sub(r' unsigned$','',t) # remove unsigned
- assert (t,isUnsigned) in numTypes, 'dtype not found for type %s' % t
- attr['dtype'] = numTypes[(t,isUnsigned)]
-
- return cls(attrs)
+ t = re.sub(r'\(.*\)', '', t) # remove parentheses
+ t = re.sub(r' unsigned$', '', t) # remove unsigned
+ assert (t, is_unsigned) in numeric_types, 'dtype not found for type %s' % t
+ attr['dtype'] = numeric_types[(t, is_unsigned)]
+ return cls(attributes)
- def pro(self, *attrList, **renameDict):
+ def pro(self, *attribute_list, **rename_dict):
"""
derive a new heading by selecting, renaming, or computing attributes.
In relational algebra these operators are known as project, rename, and expand.
The primary key is always included.
"""
- # include all if '*' is in attrSet, always include primary key
- attrSet = set(self.names) if '*' in attrList \
- else set(attrList).union(self.primary_key)
+ # include all if '*' is in attribute_set, always include primary key
+ attribute_set = set(self.names) if '*' in attribute_list \
+ else set(attribute_list).union(self.primary_key)
# report missing attributes
- missing = attrSet.difference(self.names)
+ missing = attribute_set.difference(self.names)
if missing:
raise DataJointError('Attributes %s are not found' % str(missing))
- # make attrList a list of dicts for initializing a Heading
- attrList = [v._asdict() for k,v in self.attrs.items()
- if k in attrSet and k not in renameDict.values()]
+ # make attribute_list a list of dicts for initializing a Heading
+ attribute_list = [v._asdict() for k, v in self.attributes.items()
+ if k in attribute_set and k not in rename_dict.values()]
# add renamed and computed attributes
- for newName, computation in renameDict.items():
+ for new_name, computation in rename_dict.items():
if computation in self.names:
# renamed attribute
- newAttr = self.attrs[computation]._asdict()
- newAttr['name'] = newName
- newAttr['computation'] = '`' + computation + '`'
+ new_attr = self.attributes[computation]._asdict()
+ new_attr['name'] = new_name
+ new_attr['computation'] = '`' + computation + '`'
else:
# computed attribute
- newAttr = dict(
- name = newName,
- type = 'computed',
- isKey = False,
- isNullable = False,
- default = None,
- comment = 'computed attribute',
- isAutoincrement = False,
- isNumeric = None,
- isString = None,
- isBlob = False,
- computation = computation,
- dtype = object)
- attrList.append(newAttr)
-
- return Heading(attrList)
-
+ new_attr = dict(
+ name=new_name,
+ type='computed',
+ in_key=False,
+ nullable=False,
+ default=None,
+ comment='computed attribute',
+ autoincrement=False,
+ numeric=None,
+ string=None,
+ is_blob=False,
+ computation=computation,
+ dtype=object)
+ attribute_list.append(new_attr)
+
+ return Heading(attribute_list)
def join(self, other):
"""
join two headings
"""
- assert isinstance(other,Heading)
- attrList = [v._asdict() for v in self.attrs.values()]
+ assert isinstance(other, Heading)
+ attribute_list = [v._asdict() for v in self.attributes.values()]
for name in other.names:
if name not in self.names:
- attrList.append(other.attrs[name]._asdict())
- return Heading(attrList)
+ attribute_list.append(other.attributes[name]._asdict())
+ return Heading(attribute_list)
-
- def resolveComputations(self):
+ def resolve_computations(self):
"""
Remove computations. To be done after computations have been resolved in a subquery
"""
- return Heading( [dict(v._asdict(),computation=None) for v in self.attrs.values()] )
-
+ return Heading([dict(v._asdict(), computation=None) for v in self.attributes.values()])
\ No newline at end of file
diff --git a/datajoint/relational.py b/datajoint/relational.py
index f02edbc30..f72bd1369 100644
--- a/datajoint/relational.py
+++ b/datajoint/relational.py
@@ -1,32 +1,28 @@
-# -*- coding: utf-8 -*-
"""
-Created on Thu Aug 7 17:00:02 2014
-
-@author: dimitri, eywalker
+classes for relational algebra
"""
+
import numpy as np
import abc
from copy import copy
from datajoint import DataJointError
-from .fetch import Fetch
+from .blob import unpack
+import logging
+
+logger = logging.getLogger(__name__)
+
-class _Relational(metaclass=abc.ABCMeta):
+class Relation(metaclass=abc.ABCMeta):
"""
- Relational implements relational operators.
- Relational objects reference other relational objects linked by operators.
- The leaves of this tree of objects are base relvars.
- When fetching data from the database, this tree of objects is compiled into
- and SQL expression.
- It is a mixin class that provides relational operators, iteration, and
- fetch capability.
- Relational operators are: restrict, pro, aggr, and join.
+ Relation implements relational algebra and fetch methods.
+ Relation objects reference other relation objects linked by operators.
+ The leaves of this tree of objects are base relations.
+ When fetching data from the database, this tree of objects is compiled into an SQL expression.
+ It is a mixin class that provides relational operators, iteration, and fetch capability.
+ Relation operators are: restrict, project, aggregate, and join.
"""
- _restrictions = []
+ _restrictions = None
- _limit = None
- _offset = 0
- _order_by = []
-
- #### abstract properties that subclasses must define #####
+
@abc.abstractproperty
def sql(self):
return NotImplemented
@@ -34,139 +30,217 @@ def sql(self):
@abc.abstractproperty
def heading(self):
return NotImplemented
+
+ @property
+ def restrictions(self):
+ return self._restrictions
- ###### Relational algebra ##############
def __mul__(self, other):
- "relational join"
- return Join(self,other)
+ """
+ relational join
+ """
+ return Join(self, other)
- def pro(self, *arg, _sub=None, **kwarg):
- "relational projection abd aggregation"
- return Projection(self, _sub=_sub, *arg, **kwarg)
+ def __mod__(self, attribute_list):
+ """
+ relational projection operator.
+ :param attribute_list: list of attribute specifications.
+ The attribute specifications are strings in the following forms:
+ 'name' - specific attribute
+ 'name->new_name' - rename attribute. The old attribute is kept only if specifically included.
+ 'sql_expression->new_name' - extend attribute, i.e. a new computed attribute.
+ :return: a new relation with specified heading
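+
+ Example (a sketch with hypothetical attributes)::
+
+ rel % ['session_date', 'trace->waveform']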
+ """
+ return self.project(*attribute_list)
+
+ def project(self, *selection, **aliases):
+ """
+ Relational projection operator.
+ :param selection: a list of attribute names to be included in the result.
+ :param aliases: keyword arguments of the form new_name='old_name_or_expression' specifying renamed or computed attributes
+ :return: a new relation with selected fields
+ Primary key attributes are always selected and cannot be excluded.
+ Therefore obj.project() produces a relation with only the primary key attributes.
+ If selection includes the string '*', all attributes are selected.
+ Each attribute can only be used once in selection or aliases. Therefore, the projected
+ relation cannot have more attributes than the original relation.
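+
+ Example (hypothetical relation and attributes)::
+
+ rel.project('duration', rig='setup') # keep `duration`, rename `setup` to `rig`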
+ """
+ selection = list(selection)
+ group = selection.pop(0) if selection and isinstance(selection[0], Relation) else None
+ return self.aggregate(group, *selection, **aliases)
+
+ def aggregate(self, group, *selection, **aliases):
+ """
+ Relational aggregation operator
+ :param group: the relation to aggregate over, or None for plain projection
+ :param selection: attribute specifications, as in project()
+ :param aliases: renamed and computed attributes, as in project()
+ :return: a new relation with the specified heading
+ """
+ if group is not None and not isinstance(group, Relation):
+ raise DataJointError('The group argument of aggregate must be a relation')
+ # TODO: convert the string notation for aliases and pass group once grouping is implemented
+ return Projection(self, *selection, **aliases)
def __iand__(self, restriction):
- "in-place relational restriction or semijoin"
+ """
+ in-place relational restriction or semijoin
+ """
if self._restrictions is None:
self._restrictions = []
self._restrictions.append(restriction)
return self
def __and__(self, restriction):
- "relational restriction or semijoin"
+ """
+ relational restriction or semijoin
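+
+ Example (hypothetical)::
+
+ good_sessions = sessions & dict(quality='good')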
+ """
if self._restrictions is None:
self._restrictions = []
- ret = copy(self) # todo: why not deepcopy it?
+ ret = copy(self) # todo: why not deepcopy it?
ret._restrictions = list(ret._restrictions) # copy restriction
ret &= restriction
return ret
def __isub__(self, restriction):
- "in-place inverted restriction aka antijoin"
+ """
+ in-place inverted restriction aka antijoin
+ """
self &= Not(restriction)
return self
def __sub__(self, restriction):
- "inverted restriction aka antijoin"
+ """
+ inverted restriction aka antijoin
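+
+ Example (hypothetical)::
+
+ unprocessed = sessions - processed_sessions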
+ """
return self & Not(restriction)
-
- ###### Fetching the data ##############
@property
def count(self):
- sql = 'SELECT count(*) FROM ' + self.sql + self._whereClause
+ sql = 'SELECT count(*) FROM ' + self.sql + self._where_clause
cur = self.conn.query(sql)
- return cur.fetchone()[0] #todo: should we assert that this only returns one result?
+ return cur.fetchone()[0]
- @property
- def fetch(self):
- return Fetch(self)
+ def __call__(self, offset=0, limit=None, order_by=None, descending=False):
+ """
+ fetches the relation from the database table into an np.array and unpacks blob attributes.
+ :param offset: the number of tuples to skip in the returned result
+ :param limit: the maximum number of tuples to return
+ :param order_by: the list of attributes to order the results. No ordering should be assumed if order_by=None.
+ :param descending: if True, the results are ordered in descending order
+ :return: the contents of the relation in the form of a structured numpy.array
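+
+ Example (hypothetical)::
+
+ data = rel(limit=10, order_by=['subject_id'])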
+ """
+ cur = self.cursor(offset, limit, order_by, descending)
+ ret = np.array(list(cur), dtype=self.heading.asdtype)
+ for f in self.heading.blobs:
+ for i in range(len(ret)):
+ ret[i][f] = unpack(ret[i][f])
+ return ret
+
+ def cursor(self, offset=0, limit=None, order_by=None, descending=False):
+ """
+ :param offset: the number of tuples to skip in the returned result
+ :param limit: the maximum number of tuples to return
+ :param order_by: the list of attributes to order the results. No ordering should be assumed if order_by=None.
+ :param descending: if True, the results are ordered in descending order
+ :return: cursor to the query
+ """
+ if offset and limit is None:
+ raise DataJointError('limit is required when offset is set')
+ sql = 'SELECT ' + self.heading.as_sql + ' FROM ' + self.sql
+ if order_by is not None:
+ sql += ' ORDER BY ' + ', '.join(order_by)
+ if descending:
+ sql += ' DESC'
+ if limit is not None:
+ sql += ' LIMIT %d' % limit
+ if offset:
+ sql += ' OFFSET %d' % offset
+ logger.debug(sql)
+ return self.conn.query(sql)
def __repr__(self):
header = self.heading.non_blobs
limit = 13
width = 12
template = '%%-%d.%ds' % (width, width)
- ret_val = ' '.join([template % column for column in header]) + '\n'
- ret_val += ' '.join(['+' + '-' * (width - 2) + '+' for column in header]) + '\n'
-
- tuples = self.fetch.limit(limit)(*header)
+ repr_string = ' '.join([template % column for column in header]) + '\n'
+ repr_string += ' '.join(['+' + '-' * (width - 2) + '+' for _ in header]) + '\n'
+ tuples = self.project(*header)(limit=limit)
for tup in tuples:
- ret_val += ' '.join([template % column for column in tup]) + '\n'
- cnt = self.count
- if cnt > limit:
- ret_val += '...\n'
- ret_val += '%d tuples\n' % self.count
- return ret_val
+ repr_string += ' '.join([template % column for column in tup]) + '\n'
+ count = self.count
+ if count > limit:
+ repr_string += '...\n'
+ repr_string += '%d tuples\n' % count
+ return repr_string
- ######## iterator ###############
def __iter__(self):
"""
iterator yields primary key tuples
"""
- cur, h = self.fetch._cursor()
- dtype = h.asdtype
+ proj = self.project()
+ cur = proj.cursor()
q = cur.fetchone()
while q:
- yield np.array([q,],dtype=dtype) #todo: why convert that into an array?
+ yield np.array([q, ], dtype=proj.heading.asdtype)
q = cur.fetchone()
-
@property
- def _whereClause(self):
- "make there WHERE clause based on the current restriction"
-
+ def _where_clause(self):
+ """
+ make the WHERE clause based on the current restrictions
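+
+ For example, the restriction dict(subject_id=2) compiles to
+ " WHERE (`subject_id`=2)", and a list of such dicts compiles to an OR
+ of the corresponding conditions.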
+ """
if not self._restrictions:
return ''
-
- def makeCondition(arg):
- if isinstance(arg,dict):
- conds = ['`%s`=%s'%(k,repr(v)) for k,v in arg.items()]
- elif isinstance(arg,np.void):
- conds = ['`%s`=%s'%(k, arg[k]) for k in arg.dtype.fields]
+
+ def make_condition(arg):
+ if isinstance(arg, dict):
+ conditions = ['`%s`=%s' % (k, repr(v)) for k, v in arg.items()]
+ elif isinstance(arg, np.void):
+ conditions = ['`%s`=%s' % (k, arg[k]) for k in arg.dtype.fields]
else:
raise DataJointError('invalid restriction type')
- return ' AND '.join(conds)
+ return ' AND '.join(conditions)
- condStr = []
+ condition_string = []
for r in self._restrictions:
- negate = isinstance(r,Not)
+ negate = isinstance(r, Not)
if negate:
- r = r._restriction
- if isinstance(r,dict) or isinstance(r,np.void):
- r = makeCondition(r)
- elif isinstance(r,np.ndarray) or isinstance(r,list):
- r = '('+') OR ('.join([makeCondition(q) for q in r])+')'
- elif isinstance(r,_Relational):
- commonAttrs = ','.join([q for q in self.heading.names if r.heading.names])
- r = '(%s) in (SELECT %s FROM %s)' % (commonAttrs, commonAttrs, r.sql)
+ r = r._restriction
+ if isinstance(r, dict) or isinstance(r, np.void):
+ r = make_condition(r)
+ elif isinstance(r, np.ndarray) or isinstance(r, list):
+ r = '('+') OR ('.join([make_condition(q) for q in r])+')'
+ elif isinstance(r, Relation):
+ common_attributes = ','.join([q for q in self.heading.names if q in r.heading.names])
+ r = '(%s) in (SELECT %s FROM %s)' % (common_attributes, common_attributes, r.sql)
- assert isinstance(r,str), 'condition must be converted into a string'
+ assert isinstance(r, str), 'condition must be converted into a string'
r = '('+r+')'
if negate:
- r = 'NOT '+r;
- condStr.append(r)
+ r = 'NOT '+r
+ condition_string.append(r)
- return ' WHERE ' + ' AND '.join(condStr)
+ return ' WHERE ' + ' AND '.join(condition_string)
class Not:
- "inverse of a restriction"
- def __init__(self,restriction):
+ """
+ inverse of a restriction
+ """
+ def __init__(self, restriction):
self._restriction = restriction
-class Join(_Relational):
+class Join(Relation):
+ alias_counter = 0
- aliasCounter = 0
-
- def __init__(self,rel1,rel2):
- if not isinstance(rel2,_Relational):
+ def __init__(self, rel1, rel2):
+ if not isinstance(rel2, Relation):
- raise DataJointError('relvars can only be joined with other relvars')
+ raise DataJointError('relations can only be joined with other relations')
if rel1.conn is not rel2.conn:
raise DataJointError('Cannot join relations with different database connections')
self.conn = rel1.conn
- self._rel1 = rel1;
- self._rel2 = rel2;
+ self._rel1 = rel1
+ self._rel2 = rel2
@property
def heading(self):
@@ -174,23 +248,22 @@ def heading(self):
@property
def sql(self):
- Join.aliasCounter += 1
- return '%s NATURAL JOIN %s as `j%x`' % (self._rel1.sql, self._rel2.sql, Join.aliasCounter)
+ Join.alias_counter += 1
+ return '%s NATURAL JOIN %s as `j%x`' % (self._rel1.sql, self._rel2.sql, Join.alias_counter)
-class Projection(_Relational):
+class Projection(Relation):
+ alias_counter = 0
- aliasCounter = 0
+ def __init__(self, relation, *attributes, **renames):
+ """
+ See Relation.project()
+ """
+ self.conn = relation.conn
+ self._relation = relation
+ self._projection_attributes = attributes
+ self._renamed_attributes = renames
- def __init__(self, rel, *arg, _sub, **kwarg):
- if _sub and isinstance(_sub, _Relational):
- raise DataJointError('Relational join must receive two relations')
- self.conn = rel.conn
- self._rel = rel
- self._sub = _sub
- self._selection = arg
- self._renames = kwarg
-
@property
def sql(self):
- return self._rel.sql
+ return self._relation.sql
@@ -200,19 +273,18 @@ def heading(self):
- return self._rel.heading.pro(*self._selection, **self._renames)
+ return self._relation.heading.pro(*self._projection_attributes, **self._renamed_attributes)
-class Subquery(_Relational):
-
- aliasCounter = 0;
+class Subquery(Relation):
+ alias_counter = 0
def __init__(self, rel):
- self.conn = rel.conn;
- self._rel = rel;
+ self.conn = rel.conn
+ self._rel = rel
@property
def sql(self):
- self.aliasCounter = self.aliasCounter + 1;
- return '(SELECT ' + self._rel.heading.asSQL + ' FROM ' + self._rel.sql + ') as `s%x`' % self.aliasCounter
+ self.alias_counter += 1
+ return '(SELECT ' + self._rel.heading.as_sql + ' FROM ' + self._rel.sql + ') as `s%x`' % self.alias_counter
@property
def heading(self):
- return self._rel.heading.resolveComputations()
+ return self._rel.heading.resolve_computations()
\ No newline at end of file
diff --git a/datajoint/settings.py b/datajoint/settings.py
index 9ad4d97c6..22afcadb6 100644
--- a/datajoint/settings.py
+++ b/datajoint/settings.py
@@ -8,10 +8,22 @@
__author__ = 'eywalker'
import logging
import collections
+from enum import Enum
validators = collections.defaultdict(lambda: lambda value: True)
+Role = Enum('Role', 'manual lookup imported computed job')
+role_to_prefix = {
+ Role.manual: '',
+ Role.lookup: '#',
+ Role.imported: '_',
+ Role.computed: '__',
+ Role.job: '~'
+}
+prefix_to_role = dict(zip(role_to_prefix.values(), role_to_prefix.keys()))
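+# For example, role_to_prefix[Role.computed] == '__' and
+# prefix_to_role['__'] is Role.computed.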
+
+
default = {
'database.host': 'localhost',
'database.password': 'datajoint',
@@ -24,6 +36,7 @@
'config.varname': 'DJ_LOCAL_CONF'
}
+
class Config(collections.MutableMapping):
"""
Stores datajoint settings. Behaves like a dictionary, but applies validator functions
@@ -31,7 +44,6 @@ class Config(collections.MutableMapping):
The default parameters are stored in datajoint.settings.default . If a local config file
exists, the settings specified in this file override the default settings.
-
"""
def __init__(self, *args, **kwargs):
@@ -90,4 +102,4 @@ def load(self, filename):
#############################################################################
logger = logging.getLogger()
-logger.setLevel(logging.DEBUG) #set package wide logger level TODO:make this respond to environmental variable
+logger.setLevel(logging.DEBUG)  # set package-wide logger level. TODO: make this respond to an environment variable
\ No newline at end of file
diff --git a/datajoint/table.py b/datajoint/table.py
new file mode 100644
index 000000000..d1c6c3bb9
--- /dev/null
+++ b/datajoint/table.py
@@ -0,0 +1,228 @@
+import numpy as np
+import logging
+from . import DataJointError
+from .relational import Relation
+from .declare import declare
+
+logger = logging.getLogger(__name__)
+
+
+class Table(Relation):
+ """
+ A Table object is a relation associated with a table.
+ A Table object provides insert and delete methods.
+ Table objects are only used internally and for debugging.
+ The table must already exist in the schema for its Table object to work.
+
+ The table associated with an instance of Base is identified by its 'class_name'
+ property, which is a string in CamelCase. The actual table name is obtained
+ by converting class_name from CamelCase to underscore_separated_words and
+ prefixing according to the table's role.
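+
+ For example, a computed class named ScanInfo maps to the table name
+ '__scan_info' (the '__' prefix marks computed tables).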
+
+ Base instances obtain their table's heading by looking it up in the connection
+ object. This ensures that Base instances contain the current table definition
+ even after tables are modified after the instance is created.
+ """
+
+ def __init__(self, conn=None, dbname=None, class_name=None, definition=None):
+ self._use_package = False
+ self.class_name = class_name
+ self.conn = conn
+ self.dbname = dbname
+ self.conn.load_headings(self.dbname)
+
+ if dbname not in self.conn.db_to_mod:
+ # register with a fake module, enclosed in back quotes
+ self.conn.bind('`{0}`'.format(dbname), dbname)
+
+ #TODO: delay the loading until first use (move out of __init__)
+ self.conn.load_headings()
+ if self.class_name not in self.conn.table_names[self.dbname]:
+ if definition is None:
+ raise DataJointError('The table is not declared')
+ else:
+ declare(conn, definition, class_name)
+
+ @property
+ def sql(self):
+ return self.full_table_name
+
+ @property
+ def heading(self):
+ return self.conn.headings[self.dbname][self.table_name]
+
+ @property
+ def full_table_name(self):
+ """
+ :return: full name of the associated table
+ """
+ return '`%s`.`%s`' % (self.dbname, self.table_name)
+
+ @property
+ def table_name(self):
+ """
+ :return: name of the associated table
+ """
+ return self.conn.table_names[self.dbname][self.class_name]
+
+ @property
+ def full_class_name(self):
+ """
+ :return: full class name
+ """
+ return '{}.{}'.format(self.__module__, self.class_name)
+
+ @property
+ def primary_key(self):
+ """
+ :return: primary key of the table
+ """
+ return self.heading.primary_key
+
+ def insert(self, tup, ignore_errors=False, replace=False): # TODO: in progress (issue #8)
+ """
+ Insert one data tuple.
+
+ :param tup: Data tuple. Can be an iterable in matching order, a dict with named fields, or an np.void.
+ :param ignore_errors=False: Ignores errors if True.
+ :param replace=False: Replaces data tuple if True.
+
+ Example::
+
+ b = djtest.Subject()
+ b.insert(dict(subject_id = 7, species="mouse",\\
+ real_id = 1007, date_of_birth = "2014-09-01"))
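+
+ A tuple or list in the order of the table declaration also works
+ (a sketch, assuming that declaration order)::
+
+ b.insert((7, "1007", "mouse", "2014-09-01"))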
+ """
+ # todo: do we support records and named tuples for tup?
+
+ if issubclass(type(tup), tuple) or issubclass(type(tup), list):
+ value_list = ','.join([repr(q) for q in tup])
+ attribute_list = '`'+'`,`'.join(self.heading.names[0:len(tup)]) + '`'
+ elif issubclass(type(tup), dict):
+ value_list = ','.join([repr(tup[q])
+ for q in self.heading.names if q in tup])
+ attribute_list = '`' + '`,`'.join([q for q in self.heading.names if q in tup]) + '`'
+ elif issubclass(type(tup), np.void):
+ value_list = ','.join([repr(tup[q])
+ for q in self.heading.names if q in tup.dtype.fields])
+ attribute_list = '`' + '`,`'.join([q for q in self.heading.names if q in tup.dtype.fields]) + '`'
+ else:
+ raise DataJointError('Datatype %s cannot be inserted' % type(tup))
+ if replace:
+ sql = 'REPLACE'
+ elif ignore_errors:
+ sql = 'INSERT IGNORE'
+ else:
+ sql = 'INSERT'
+ sql += " INTO %s (%s) VALUES (%s)" % (self.full_table_name,
+ attribute_list, value_list)
+ logger.info(sql)
+ self.conn.query(sql)
+
+ def delete(self): # TODO: (issues #14 and #15)
+ pass
+
+ def drop(self):
+ """
+ Drops the table associated to this object.
+ """
+ # TODO: make cascading (issue #16)
+ self.conn.query('DROP TABLE %s' % self.full_table_name)
+ self.conn.clear_dependencies(dbname=self.dbname)
+ self.conn.load_headings(dbname=self.dbname, force=True)
+ logger.debug("Dropped table %s" % self.full_table_name)
+
+ def set_table_comment(self, comment):
+ """
+ Update the table comment in the table definition.
+
+ :param comment: new comment as string
+
+ """
+ # TODO: add verification procedure (github issue #24)
+ self.alter('COMMENT="%s"' % comment)
+
+ def add_attribute(self, definition, after=None):
+ """
+ Add a new attribute to the table. A full line from the table definition
+ is passed in as definition.
+
+ The after argument specifies where to place the new attribute: use after=None
+ to add the attribute as the first attribute, or after='attribute' to place it
+ after an existing attribute.
+
+ :param definition: attribute definition, one full line of a table definition
+ :param after=None: After which attribute of the table the new attribute is inserted.
+ If None, the attribute is inserted in front.
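+
+ Example (hypothetical attribute definition)::
+
+ table.add_attribute("weight=null : double # body weight", after="species")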
+ """
+ position = ' FIRST' if after is None else (
+ ' AFTER %s' % after if after else '')
+ sql = self.field_to_sql(parse_attribute_definition(definition))
+ self._alter('ADD COLUMN %s%s' % (sql[:-2], position))
+
+ def drop_attribute(self, attr_name):
+ """
+ Drops the attribute attr_name from this table.
+
+ :param attr_name: Name of the attribute that is dropped.
+ """
+ self._alter('DROP COLUMN `%s`' % attr_name)
+
+ def alter_attribute(self, attr_name, new_definition):
+ """
+ Alter the definition of the field attr_name in this table using the new definition.
+
+ :param attr_name: field that is redefined
+ :param new_definition: new definition of the field
+ """
+ sql = self.field_to_sql(parse_attribute_definition(new_definition))
+ self._alter('CHANGE COLUMN `%s` %s' % (attr_name, sql[:-2]))
+
+ def erd(self, subset=None, prog='dot'):
+ """
+ Plot the schema's entity relationship diagram (ERD).
+ The layout programs can be 'dot' (default), 'neato', 'fdp', 'sfdp', 'circo', 'twopi'
+ """
+ # TODO: make erd work (github issue #7)
+ # draft implementation kept below for reference:
+ """
+ if not subset:
+ g = self.graph
+ else:
+ g = self.graph.copy()
+ for i in g.nodes():
+ if i not in subset:
+ g.remove_node(i)
+ def tablelist(tier):
+ return [i for i in g if self.tables[i].tier==tier]
+
+ pos=nx.graphviz_layout(g,prog=prog,args='')
+ plt.figure(figsize=(8,8))
+ nx.draw_networkx_edges(g, pos, alpha=0.3)
+ nx.draw_networkx_nodes(g, pos, nodelist=tablelist('manual'),
+ node_color='g', node_size=200, alpha=0.3)
+ nx.draw_networkx_nodes(g, pos, nodelist=tablelist('computed'),
+ node_color='r', node_size=200, alpha=0.3)
+ nx.draw_networkx_nodes(g, pos, nodelist=tablelist('imported'),
+ node_color='b', node_size=200, alpha=0.3)
+ nx.draw_networkx_nodes(g, pos, nodelist=tablelist('lookup'),
+ node_color='gray', node_size=120, alpha=0.3)
+ nx.draw_networkx_labels(g, pos, nodelist = subset, font_weight='bold', font_size=9)
+ nx.draw(g, pos, alpha=0, with_labels=False)
+ plt.show()
+ """
+
+ def _alter(self, alter_statement):
+ """
+ Execute ALTER TABLE statement for this table. The schema
+ will be reloaded within the connection object.
+
+ :param alter_statement: alter statement
+ """
+ sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement)
+ self.conn.query(sql)
+ self.conn.load_headings(self.dbname, force=True)
+ # TODO: add a table definition sync mechanism
\ No newline at end of file
diff --git a/datajoint/task.py b/datajoint/task.py
deleted file mode 100644
index 156c46281..000000000
--- a/datajoint/task.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import queue
-import threading
-
-
-def _ping():
- print("The task thread is running")
-
-
-class TaskQueue:
- """
- Executes tasks in a single parallel thread in FIFO sequence.
- Example:
- queue = TaskQueue()
- queue.submit(func1, arg1, arg2, arg3)
- queue.submit(func2)
- queue.quit() # wait until the last task is done and stop thread
-
- Datajoint applications may use a task queue for delayed inserts.
- """
- def __init__(self):
- self.queue = queue.Queue()
- self.thread = threading.Thread(target=self._worker)
- self.thread.daemon = True
- self.thread.start()
-
- def empty(self):
- return self.queue.empty()
-
- def submit(self, func=_ping, *args):
- """Submit task for execution"""
- self.queue.put((func, args))
-
- def quit(self, timeout=3.0):
- """Wait until all tasks finish"""
- self.queue.put('quit')
- self.thread.join(timeout)
- if self.thread.isAlive():
- raise Exception('Task thread is still executing. Try quitting again.')
-
- def _worker(self):
- while True:
- msg = self.queue.get()
- if msg=='quit':
- self.queue.task_done()
- break
- fun, args = msg
- try:
- fun(*args)
- except Exception as e:
- print("Exception in the task thread:")
- print(e)
- self.queue.task_done()
-
diff --git a/datajoint/utils.py b/datajoint/utils.py
index 9d1bf85de..47aacdeeb 100644
--- a/datajoint/utils.py
+++ b/datajoint/utils.py
@@ -1,13 +1,7 @@
import re
-# package-wide settings that control execution
-
-# setup root logger
from . import DataJointError
-
-
-
def to_camel_case(s):
"""
Convert names with under score (_) separation
@@ -24,7 +18,7 @@ def to_upper(matchobj):
def from_camel_case(s):
"""
- Conver names in camel case into underscore
+ Convert names in camel case into underscore
(_) separated names
Example:
@@ -37,7 +31,8 @@ def from_camel_case(s):
raise DataJointError('String cannot begin with a digit')
if not re.match(r'^[a-zA-Z0-9]*$', s):
raise DataJointError('String can only contain alphanumeric characters')
+
def conv(matchobj):
return ('_' if matchobj.groups()[0] else '') + matchobj.group(0).lower()
- return re.sub(r'(\B[A-Z])|(\b[A-Z])', conv, s)
+ return re.sub(r'(\B[A-Z])|(\b[A-Z])', conv, s)
\ No newline at end of file
diff --git a/demos/demo1.py b/demos/demo1.py
index 18e54a4bf..689905730 100644
--- a/demos/demo1.py
+++ b/demos/demo1.py
@@ -6,6 +6,7 @@
"""
import datajoint as dj
+import os
print("Welcome to the database 'demo1'")
@@ -13,11 +14,9 @@
conn.bind(module=__name__, dbname='dj_test') # bind this module to the database
-
class Subject(dj.Base):
- _table_def = """
+ definition = """
demo1.Subject (manual) # Basic subject info
-
subject_id : int # internal subject id
---
real_id : varchar(40) # real-world name
@@ -28,41 +27,49 @@ class Subject(dj.Base):
animal_notes="" : varchar(4096) # strain, genetic manipulations, etc
"""
-class Exp2(dj.Base):
- pass
class Experiment(dj.Base):
- _table_def = """
+ definition = """
demo1.Experiment (manual) # Basic subject info
-
-> demo1.Subject
experiment : smallint # experiment number for this subject
---
+ experiment_folder : varchar(255) # folder path
experiment_date : date # experiment start date
experiment_notes="" : varchar(4096)
experiment_ts=CURRENT_TIMESTAMP : timestamp # automatic timestamp
"""
-class TwoPhotonSession(dj.Base):
- _table_def = """
- demo1.TwoPhotonSession (manual) # a two-photon imaging session
-
+class Session(dj.Base):
+ definition = """
+ demo1.Session (manual) # a two-photon imaging session
-> demo1.Experiment
- tp_session : tinyint # two-photon session within this experiment
- ----
+ session_id : tinyint # two-photon session within this experiment
+ -----------
setup : tinyint # experimental setup
- lens : tinyint # lens e.g.: 10x, 20x. 25x, 60x
+ lens : tinyint # lens e.g.: 10x, 20x, 25x, 60x
"""
-class EphysSetup(dj.Base):
- _table_def = """
- demo1.EphysSetup (manual) # Ephys setup
- setup_id : tinyint # unique seutp id
+
+
+class Scan(dj.Base):
+ definition = """
+ demo1.Scan (manual) # a scan within a two-photon session
+ -> demo1.Session
+ scan_id : tinyint # scan within this session
+ ----
+ depth : float # depth from surface
+ wavelength : smallint # (nm) laser wavelength
+ mwatts : numeric(4,1) # (mW) laser power to brain
"""
-class EphysExperiment(dj.Base):
- _table_def = """
- demo1.EphysExperiment (manual) # Ephys experiment
- -> demo1.Subject
- -> demo1.EphysSetup
- """
\ No newline at end of file
+
+class ScanInfo(dj.Base, dj.AutoPopulate):
+ definition = None
+
+ @property
+ def pop_rel(self):
+ return Session()
+
+ def make_tuples(self, key):
+ info = (Session()*Scan() & key).project('experiment_folder')()
+ filename = os.path.join(info['experiment_folder'][0], 'scan_%03')
+
+
diff --git a/demos/rundemo1.py b/demos/rundemo1.py
index ece931611..88a8f4502 100644
--- a/demos/rundemo1.py
+++ b/demos/rundemo1.py
@@ -8,35 +8,35 @@
import demo1
s = demo1.Subject()
-# insert as dict
+e = demo1.Experiment()
+
s.insert(dict(subject_id=1,
- real_id='George',
+ real_id="George",
species="monkey",
date_of_birth="2011-01-01",
sex="M",
caretaker="Arthur",
animal_notes="this is a test"))
+
s.insert(dict(subject_id=2,
real_id='1373',
date_of_birth="2014-08-01",
caretaker="Joe"))
-# insert as tuple. Attributes must be in the same order as in table declaration
-s.insert((3,'Dennis','monkey','2012-09-01'))
-
-# TODO: insert as ndarray
+s.insert((3, 'Dennis', 'monkey', '2012-09-01'))
+s.insert((12430, 'C0430', 'mouse', '2012-09-01', 'M'))
+s.insert((12431, 'C0431', 'mouse', '2012-09-01', 'F'))
print('inserted keys into Subject:')
for key in s:
print(key)
+e.insert(dict(subject_id=1,
+ experiment=1,
+ experiment_date="2014-08-28",
+ experiment_notes="my first experiment"))
-#
-e = demo1.Experiment()
-e.insert(dict(subject_id=1,experiment=1,experiment_date="2014-08-28",experiment_notes="my first experiment"))
-e.insert(dict(subject_id=1,experiment=2,experiment_date="2014-08-28",experiment_notes="my second experiment"))
-
-
-# drop the tables
-#s.drop
-#e.drop
+e.insert(dict(subject_id=1,
+ experiment=2,
+ experiment_date="2014-08-28",
+ experiment_notes="my second experiment"))