Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions datajoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
__author__ = "Dimitri Yatsenko, Edgar Walker, and Fabian Sinz at Baylor College of Medicine"
__version__ = "0.2"
__all__ = ['__author__', '__version__',
'Connection', 'Heading', 'Relation', 'Not',
'Connection', 'Heading', 'Relation', 'FreeRelation', 'Not',
'AutoPopulate', 'conn', 'DataJointError', 'blob']


Expand All @@ -20,11 +20,8 @@ class DataJointError(Exception):
config = Config()
local_config_file = os.environ.get(CONFIGVAR, None)
if local_config_file is None:
local_config_file = os.path.expanduser(LOCALCONFIG)
else:
local_config_file = os.path.expanduser(local_config_file)


local_config_file = LOCALCONFIG
local_config_file = os.path.expanduser(local_config_file)

try:
logger.log(logging.INFO, "Loading local settings from {0:s}".format(local_config_file))
Expand All @@ -43,6 +40,4 @@ class DataJointError(Exception):
from . import blob
from .relational_operand import Not
from .free_relation import FreeRelation


#############################################################################
from .heading import Heading
11 changes: 5 additions & 6 deletions datajoint/blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
mxClassID = collections.OrderedDict(
# see http://www.mathworks.com/help/techdoc/apiref/mxclassid.html
mxUNKNOWN_CLASS=None,
mxCELL_CLASS=None, # not implemented
mxSTRUCT_CLASS=None, # not implemented
mxCELL_CLASS=None, # TODO: implement
mxSTRUCT_CLASS=None, # TODO: implement
mxLOGICAL_CLASS=np.dtype('bool'),
mxCHAR_CLASS=np.dtype('c'),
mxVOID_CLASS=None,
Expand Down Expand Up @@ -48,10 +48,9 @@ def pack(obj):
if is_complex:
blob += imaginary.tostring()

if len(blob) > 1000:
compressed = b'ZL123\0'+np.asarray(len(blob), dtype=np.uint64).tostring() + zlib.compress(blob)
if len(compressed) < len(blob):
blob = compressed
compressed = b'ZL123\0' + np.uint64(len(blob)).tostring() + zlib.compress(blob)
if len(compressed) < len(blob):
blob = compressed
return blob


Expand Down
7 changes: 3 additions & 4 deletions datajoint/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .utils import to_camel_case
from . import DataJointError
from .heading import Heading
from .settings import prefix_to_role, DEFAULT_PORT
from .settings import prefix_to_role
import logging
from .erd import DBConnGraph
from . import config
Expand Down Expand Up @@ -74,7 +74,6 @@ def is_active(self):
"""
return self.conn.is_connected and self.conn.in_transaction


def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is None and exc_val is None and exc_tb is None:
self.conn._commit_transaction()
Expand All @@ -85,7 +84,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
logger.debug("Transaction cancled because of an error.", exc_info=(exc_type, exc_val, exc_tb))



class Connection(object):
"""
A dj.Connection object manages a connection to a database server.
Expand All @@ -108,9 +106,10 @@ def __init__(self, host, user, passwd, init_fun=None):
port = config['database.port']
self.conn_info = dict(host=host, port=port, user=user, passwd=passwd)
self._conn = pymysql.connect(init_command=init_fun, **self.conn_info)
# TODO Do something if connection cannot be established
if self.is_connected:
print("Connected", user + '@' + host + ':' + str(port))
else:
raise DataJointError('Connection failed.')
self._conn.autocommit(True)

self.db_to_mod = {} # modules indexed by dbnames
Expand Down
36 changes: 18 additions & 18 deletions datajoint/free_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ def drop(self):
Drops the table associated to this object.
"""
# TODO: make cascading (issue #16)

if self.is_declared:
self.conn.query('DROP TABLE %s' % self.full_table_name)
self.conn.clear_dependencies(dbname=self.dbname)
Expand Down Expand Up @@ -443,24 +442,23 @@ def _parse_declaration(self):
table_info['tier'] = Role[table_info['tier']] # convert into enum

in_key = True # parse primary keys
field_ptrn = """
attribute_regexp = re.compile("""
^[a-z][a-z\d_]*\s* # name
(=\s*\S+(\s+\S+)*\s*)? # optional defaults
:\s*\w.*$ # type, comment
"""
fieldP = re.compile(field_ptrn, re.I + re.X) # ignore case and verbose
""", re.I + re.X) # ignore case and verbose

for line in declaration[1:]:
if line.startswith('---'):
in_key = False # start parsing non-PK fields
elif line.startswith('->'):
# foreign key
module_name, class_name = line[2:].strip().split('.')
rel = self.get_base(module_name, class_name)
(parents if in_key else referenced).append(rel)
elif re.match(r'^(unique\s+)?index[^:]*$', line):
ref = parents if in_key else referenced
ref.append(self.get_base(module_name, class_name))
elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I):
index_defs.append(self._parse_index_def(line))
elif fieldP.match(line):
elif attribute_regexp.match(line):
field_defs.append(parse_attribute_definition(line, in_key))
else:
raise DataJointError(
Expand All @@ -486,7 +484,8 @@ def parse_attribute_definition(line, in_key=False):
(\#\s*(?P<comment>\S*(\s+\S+)*)\s*)?$ # comment
""", re.X)
m = attribute_regexp.match(line)
assert m, 'Invalid field declaration "%s"' % line
if not m:
raise DataJointError('Invalid field declaration "%s"' % line)
attr_info = m.groupdict()
if not attr_info['comment']:
attr_info['comment'] = ''
Expand All @@ -496,12 +495,13 @@ def parse_attribute_definition(line, in_key=False):
assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \
'BIGINT attributes cannot be nullable in "%s"' % line

attr_info['in_key'] = in_key
attr_info['autoincrement'] = None
attr_info['numeric'] = None
attr_info['string'] = None
attr_info['is_blob'] = None
attr_info['computation'] = None
attr_info['dtype'] = None

return Heading.AttrTuple(**attr_info)
return Heading.AttrTuple(
in_key=in_key,
autoincrement=None,
numeric=None,
string=None,
is_blob=None,
computation=None,
dtype=None,
**attr_info
)
35 changes: 13 additions & 22 deletions datajoint/relational_operand.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,20 +147,18 @@ def fetch(self, offset=0, limit=None, order_by=None, descending=False, as_dict=F
for n, e in zip(attr_names, t)) for t in cur.fetchall()]
else:
ret = np.array(list(cur.fetchall()), dtype=self.heading.as_dtype)
for bname in self.heading.blobs:
ret[bname] = list(map(unpack, ret[bname]))
for blob_name in self.heading.blobs:
ret[blob_name] = list(map(unpack, ret[blob_name]))
return ret

def cursor(self, offset=0, limit=None, order_by=None, descending=False):
"""
:param offset: the number of tuples to skip in the returned result
:param limit: the maximum number of tuples to return
:param order_by: the list of attributes to order the results. No ordering should be assumed if order_by=None.
:param descending: the list of attributes to order the results
Return query cursor.
See Relation.fetch() for input description.
:return: cursor to the query
"""
if offset and limit is None:
raise DataJointError('offset cannot be set without setting a limit')
raise DataJointError('limit is required when offset is set')
sql = self.make_select()
if order_by is not None:
sql += ' ORDER BY ' + ', '.join(order_by)
Expand All @@ -173,16 +171,14 @@ def cursor(self, offset=0, limit=None, order_by=None, descending=False):
logger.debug(sql)
return self.conn.query(sql)



def __repr__(self):
limit = 7 #TODO: move some of these display settings into the config
width = 14
rel = self.project(*self.heading.non_blobs)
template = '%%-%d.%ds' % (width, width)
columns = rel.heading.names
repr_string = ' '.join([template % column for column in columns]) + '\n'
repr_string += ' '.join(['+' + '-' * (width - 2) + '+' for _ in columns]) + '\n'
repr_string += ' '.join(['+' + '-'*(width-2) + '+' for _ in columns]) + '\n'
for tup in rel.fetch(limit=limit):
repr_string += ' '.join([template % column for column in tup]) + '\n'
if self.count > limit:
Expand All @@ -193,20 +189,15 @@ def __repr__(self):
def __iter__(self):
"""
Iterator that yields individual tuples of the current table dictionaries.


:param offset: parameter passed to the :func:`cursor`
:param limit: parameter passed to the :func:`cursor`
:param order_by: parameter passed to the :func:`cursor`
:param descending: parameter passed to the :func:`cursor`
"""
cur = self.cursor()
do_unpack = tuple(h in self.heading.blobs for h in self.heading.names)
q = cur.fetchone()
while q:
yield dict( (fieldname,unpack(field)) if up else (fieldname,field)
for fieldname, up, field in zip(self.heading.names, do_unpack, q))
q = cur.fetchone()
heading = self.heading # construct once for efficiency
do_unpack = tuple(h in heading.blobs for h in self.heading.names)
values = cur.fetchone()
while values:
yield {field_name: unpack(value) if up else value
for field_name, up, value in zip(heading.names, do_unpack, values)}
values = cur.fetchone()

@property
def where_clause(self):
Expand Down
4 changes: 2 additions & 2 deletions datajoint/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
import pprint
from collections import OrderedDict

__author__ = 'eywalker'
import logging
import collections
from enum import Enum

LOCALCONFIG = 'dj_local_conf.json'
CONFIGVAR = 'DJ_LOCAL_CONF'
DEFAULT_PORT = 3306

validators = collections.defaultdict(lambda: lambda value: True)
validators['database.port'] = lambda a: isinstance(a, int)
Expand Down Expand Up @@ -52,9 +50,11 @@

class Borg:
    """Mixin implementing the Borg (monostate) pattern: all instances
    share a single attribute dictionary, so attribute state set through
    any one instance is visible on every other instance."""

    # The one dict shared by every instance of the class.
    _shared_state = {}

    def __init__(self):
        # Rebind this instance's attribute dict to the class-level shared
        # dict, making all instances aliases of the same state.
        self.__dict__ = self._shared_state


class Config(Borg, collections.MutableMapping):
"""
Stores datajoint settings. Behaves like a dictionary, but applies validator functions
Expand Down
15 changes: 5 additions & 10 deletions datajoint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,16 @@ def to_upper(match):

def from_camel_case(s):
"""
Convert names in camel case into underscore
(_) separated names
Convert names in camel case into underscore (_) separated names

Example:
>>>from_camel_case("TableName")
"table_name"
"""
if re.search(r'\s', s):
raise DataJointError('Input cannot contain white space')
if re.match(r'\d.*', s):
raise DataJointError('Input cannot begin with a digit')
if not re.match(r'^[a-zA-Z0-9]*$', s):
raise DataJointError('String can only contain alphanumeric characters')

def convert(match):
return ('_' if match.groups()[0] else '') + match.group(0).lower()

return re.sub(r'(\B[A-Z])|(\b[A-Z])', convert, s)
if not re.match(r'[A-Z][a-zA-Z0-9]*', s):
raise DataJointError(
'ClassName must be alphanumeric in CamelCase, begin with a capital letter')
return re.sub(r'(\B[A-Z])|(\b[A-Z])', convert, s)
3 changes: 2 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def test_to_camel_case():

def test_from_camel_case():
assert_equal(from_camel_case('AllGroups'), 'all_groups')
assert_equal(from_camel_case('repNames'), 'rep_names')
with assert_raises(DataJointError):
from_camel_case('repNames')
with assert_raises(DataJointError):
from_camel_case('10_all')
with assert_raises(DataJointError):
Expand Down