Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion datajoint/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from . import DataJointError
from .table import Table
import logging
from .declare import declare

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -42,16 +43,19 @@ def __init__(self):
self.class_name = self.__class__.__name__
module = self.__module__
mod_obj = importlib.import_module(module)
use_package = False
try:
conn = mod_obj.conn
except AttributeError:
try:
# check if database bound at the package level instead
pkg_obj = importlib.import_module(mod_obj.__package__)
conn = pkg_obj.conn
use_package = True
except AttributeError:
raise DataJointError(
"Please define object 'conn' in '{}' or in its containing package.".format(self.__module__))
self.conn = conn
try:
if use_package:
pkg_name = '.'.join(module.split('.')[:-1])
Expand All @@ -61,7 +65,8 @@ def __init__(self):
except KeyError:
raise DataJointError(
'Module {} is not bound to a database. See datajoint.connection.bind'.format(self.__module__))
super().__init__(self, conn=conn, dbname=dbname, class_name=self.__class__.__name__)
declare(self.conn, self.definition, self.full_class_name)
super().__init__(conn=conn, dbname=dbname, class_name=self.__class__.__name__)


@classmethod
Expand Down
28 changes: 15 additions & 13 deletions datajoint/declare.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@ def declare(conn, definition, class_name):
"""
Declares the table in the data base if no table in the database matches this object.
"""
table_info, parents, referenced, field_definitions, index_definitions = parse_declaration(definition)
table_info, parents, referenced, field_definitions, index_definitions = _parse_declaration(conn, definition)
defined_name = table_info['module'] + '.' + table_info['className']
if not defined_name == class_name:
raise DataJointError('Table name {} does not match the declared'
'name {}'.format(class_name, defined_name))
# TODO: clean up this mess... currently just ignoring the name used to define the table
#if not defined_name == class_name:
# raise DataJointError('Table name {} does not match the declared'
# 'name {}'.format(class_name, defined_name))

# compile the CREATE TABLE statement
table_name = role_to_prefix[table_info['tier']] + from_camel_case(class_name)
table_name = role_to_prefix[table_info['tier']] + from_camel_case(defined_name)
sql = 'CREATE TABLE `%s`.`%s` (\n' % (self.dbname, table_name)

# add inherited primary key fields
Expand Down Expand Up @@ -88,23 +89,24 @@ def declare(conn, definition, class_name):
sql[:-2], table_info['comment'])

# make sure that the table does not alredy exist
self.conn.load_headings(self.dbname, force=True)
# TODO: there will be a problem with resolving the module here...
conn.load_headings(self.dbname, force=True)
if not self.is_declared:
# execute declaration
logger.debug('\n<SQL>\n' + sql + '</SQL>\n\n')
self.conn.query(sql)
self.conn.load_headings(self.dbname, force=True)


def _parse_declaration(self):
def _parse_declaration(conn, definition):
"""
Parse declaration and create new SQL table accordingly.
"""
parents = []
referenced = []
index_defs = []
field_defs = []
declaration = re.split(r'\s*\n\s*', self.definition.strip())
declaration = re.split(r'\s*\n\s*', definition.strip())

# remove comment lines
declaration = [x for x in declaration if not x.startswith('#')]
Expand Down Expand Up @@ -205,11 +207,11 @@ def parse_attribute_definition(line, in_key=False): # todo add docu for in_key
assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \
'BIGINT attributes cannot be nullable in "%s"' % line

attr_info['isKey'] = in_key
attr_info['isAutoincrement'] = None
attr_info['isNumeric'] = None
attr_info['isString'] = None
attr_info['isBlob'] = None
attr_info['in_key'] = in_key
attr_info['autoincrement'] = None
attr_info['numeric'] = None
attr_info['string'] = None
attr_info['is_blob'] = None
attr_info['computation'] = None
attr_info['dtype'] = None

Expand Down
3 changes: 3 additions & 0 deletions datajoint/heading.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ def values(self):
def items(self):
return self.attributes.items()

def __iter__(self):
return iter(self.attributes)

@classmethod
def init_from_database(cls, conn, dbname, table_name):
"""
Expand Down
29 changes: 16 additions & 13 deletions datajoint/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,8 @@ def project(self, *selection, **aliases):
Each attribute can only be used once in attributes or renames. Therefore, the projected
relation cannot have more attributes than the original relation.
"""
return self.aggregate(
group=selection.pop[0] if selection and isinstance(selection[0], Relation) else None,
*selection, **aliases)
group = selection.pop[0] if selection and isinstance(selection[0], Relation) else None
return self.aggregate(group, *selection, **aliases)

def aggregate(self, group, *selection, **aliases):
"""
Expand All @@ -79,8 +78,9 @@ def aggregate(self, group, *selection, **aliases):
if group is not None and not isinstance(group, Relation):
raise DataJointError('The second argument of aggregate must be a relation')
# convert the string notation for aliases to

return Projection(group=group, *attriutes, **aliases)
# handling of the variable group is unclear here
# and thus ommitted
return Projection(self, *selection, **aliases)

def __iand__(self, restriction):
"""
Expand Down Expand Up @@ -121,6 +121,9 @@ def count(self):
cur = self.conn.query(sql)
return cur.fetchone()[0]

def fetch(self, *args, **kwargs):
return self(*args, **kwargs)

def __call__(self, offset=0, limit=None, order_by=None, descending=False):
"""
fetches the relation from the database table into an np.array and unpacks blob attributes.
Expand All @@ -131,7 +134,7 @@ def __call__(self, offset=0, limit=None, order_by=None, descending=False):
:return: the contents of the relation in the form of a structured numpy.array
"""
cur = self.cursor(offset, limit, order_by, descending)
ret = np.array(list(cur), dtype=self.heading.asdtype)
ret = np.array(list(cur), dtype=self.heading.as_dtype)
for f in self.heading.blobs:
for i in range(len(ret)):
ret[i][f] = unpack(ret[i][f])
Expand Down Expand Up @@ -160,12 +163,12 @@ def cursor(self, offset=0, limit=None, order_by=None, descending=False):
return self.conn.query(sql)

def __repr__(self):
limit = 13
limit = 13 #TODO: move some of these display settings into the config
width = 12
template = '%%-%d.%ds' % (width, width)
repr_string = ' '.join([template % column for column in header]) + '\n'
repr_string += ' '.join(['+' + '-' * (width - 2) + '+' for _ in header]) + '\n'
tuples = self.pro(*self.heading.non_blobs).fetch(limit=limit)
repr_string = ' '.join([template % column for column in self.heading]) + '\n'
repr_string += ' '.join(['+' + '-' * (width - 2) + '+' for _ in self.heading]) + '\n'
tuples = self.project(*self.heading.non_blobs).fetch(limit=limit)
for tup in tuples:
repr_string += ' '.join([template % column for column in tup]) + '\n'
if self.count > limit:
Expand All @@ -177,7 +180,7 @@ def __iter__(self):
"""
iterator yields primary key tuples
"""
cur, h = self.pro().cursor()
cur, h = self.project().cursor()
q = cur.fetchone()
while q:
yield np.array([q, ], dtype=h.asdtype)
Expand Down Expand Up @@ -266,11 +269,11 @@ def __init__(self, relation, *attributes, **renames):

@property
def sql(self):
return self._rel.sql
return self._relation.sql

@property
def heading(self):
return self._rel.heading.pro(*self._selection, **self._renames)
return self._relation.heading.pro(*self._projection_attributes, **self._renamed_attributes)


class Subquery(Relation):
Expand Down
2 changes: 1 addition & 1 deletion datajoint/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from . import DataJointError
from .relational import Relation
from .declare import declare
from .declare import (declare, parse_attribute_definition)

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion tests/schemata/schema1/test1.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class Subjects(dj.Base):
_table_def = """
definition = """
test1.Subjects (manual) # Basic subject info

subject_id : int # unique subject id
Expand Down
2 changes: 1 addition & 1 deletion tests/schemata/schema1/test2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class Experiments(dj.Base):
_table_def = """
definition = """
test2.Experiments (manual) # Basic subject info
-> test1.Subjects
experiment_id : int # unique experiment id
Expand Down
2 changes: 1 addition & 1 deletion tests/schemata/schema1/test3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class Subjects(dj.Base):
_table_def = """
definition = """
test3.Subjects (manual) # Basic subject info

subject_id : int # unique subject id
Expand Down