Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
.ipynb_checkpoints/
*.json
*/.*.swp
*/.*.swo
*/*.pyc
Expand Down
10 changes: 8 additions & 2 deletions datajoint/autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
from . import DataJointError
import pprint
import abc
import logging

#noinspection PyExceptionInherit,PyCallingNonCallable

logger = logging.getLogger(__name__)

class AutoPopulate(metaclass=abc.ABCMeta):
"""
Expand All @@ -29,6 +31,10 @@ def make_tuples(self, key):
"""
pass

@property
def target(self):
return self

def populate(self, catch_errors=False, reserve_jobs=False, restrict=None):
"""
rel.populate() will call rel.make_tuples(key) for every primary key in self.pop_rel
Expand All @@ -38,9 +44,9 @@ def populate(self, catch_errors=False, reserve_jobs=False, restrict=None):
raise DataJointError('')
self.conn.cancel_transaction()

unpopulated = self.pop_rel - self
unpopulated = self.pop_rel - self.target
if not unpopulated.count:
print('Nothing to populate', flush=True) # TODO: use logging?
logger.info('Nothing to populate', flush=True)
if catch_errors:
error_keys, errors = [], []
for key in unpopulated.fetch():
Expand Down
31 changes: 14 additions & 17 deletions datajoint/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import importlib
import abc
from types import ModuleType
from enum import Enum
from . import DataJointError
from .table import Table
import logging
import re
from .settings import Role, role_to_prefix
from .utils import from_camel_case
from .heading import Heading


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,7 +45,7 @@ def full_class_name(self):
return '{}.{}'.format(self.__module__, self.class_name)

@property
def access_name(self):
def ref_name(self):
"""
:return: name by which this class should be accessible as
"""
Expand All @@ -60,12 +56,12 @@ def access_name(self):
return parent + '.' + self.class_name



def __init__(self): #TODO: support taking in conn obj
self.class_name = self.__class__.__name__
module = self.__module__
mod_obj = importlib.import_module(module)
class_name = self.__class__.__name__
module_name = self.__module__
mod_obj = importlib.import_module(module_name)
self._use_package = False
# first, find the conn object
try:
conn = mod_obj.conn
except AttributeError:
Expand All @@ -76,19 +72,20 @@ def __init__(self): #TODO: support taking in conn obj
self._use_package = True
except AttributeError:
raise DataJointError(
"Please define object 'conn' in '{}' or in its containing package.".format(self.__module__))
self.conn = conn
"Please define object 'conn' in '{}' or in its containing package.".format(module_name))
# now use the conn object to determine the dbname this belongs to
try:
if self._use_package:
# the database is bound to the package
pkg_name = '.'.join(module.split('.')[:-1])
dbname = self.conn.mod_to_db[pkg_name]
pkg_name = '.'.join(module_name.split('.')[:-1])
dbname = conn.mod_to_db[pkg_name]
else:
dbname = self.conn.mod_to_db[module]
dbname = conn.mod_to_db[module_name]
except KeyError:
raise DataJointError(
'Module {} is not bound to a database. See datajoint.connection.bind'.format(self.__module__))
self.dbname = dbname
'Module {} is not bound to a database. See datajoint.connection.bind'.format(module_name))
# initialize using super class's constructor
super().__init__(conn, dbname, class_name)


def get_base(self, module_name, class_name):
Expand Down
47 changes: 26 additions & 21 deletions datajoint/heading.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,19 @@
class Heading:
"""
local class for relations' headings.
Heading contains the property attributes, which is an OrderedDict in which the keys are
the attribute names and the values are AttrTuples.
"""
AttrTuple = namedtuple('AttrTuple',
('name', 'type', 'in_key', 'nullable', 'default', 'comment', 'autoincrement',
'numeric', 'string', 'is_blob', 'computation', 'dtype'))
('name', 'type', 'in_key', 'nullable', 'default',
'comment', 'autoincrement', 'numeric', 'string', 'is_blob',
'computation', 'dtype'))
AttrTuple.as_dict = AttrTuple._asdict # rename the method into a nicer name

def __init__(self, attributes):
# Input: attributes -list of dicts with attribute descriptions
"""
:param attributes: a list of dicts with the same keys as AttrTuple
"""
self.attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes])

@property
Expand Down Expand Up @@ -91,7 +97,7 @@ def init_from_database(cls, conn, dbname, table_name):
"""
cur = conn.query(
'SHOW FULL COLUMNS FROM `{table_name}` IN `{dbname}`'.format(
table_name=table_name, dbname=dbname), asDict=True)
table_name=table_name, dbname=dbname), asDict=True)
attributes = cur.fetchall()

rename_map = {
Expand Down Expand Up @@ -145,7 +151,7 @@ def init_from_database(cls, conn, dbname, table_name):
field=attr['type'], dbname=dbname, table_name=table_name))
attr.pop('Extra')

# fill out the dtype. All floats and non-nullable integers are turned into specific dtypes
# fill out dtype. All floats and non-nullable integers are turned into specific dtypes
attr['dtype'] = object
if attr['numeric']:
is_integer = bool(re.match(r'(tiny|small|medium|big)?int', attr['type']))
Expand All @@ -160,30 +166,29 @@ def init_from_database(cls, conn, dbname, table_name):

return cls(attributes)

def pro(self, *attribute_list, **rename_dict):
def project(self, *attribute_list, **renamed_attributes):
"""
derive a new heading by selecting, renaming, or computing attributes.
In relational algebra these operators are known as project, rename, and expand.
The primary key is always included.
"""
# include all if '*' is in attribute_set, always include primary key
attribute_set = set(self.names) if '*' in attribute_list \
else set(attribute_list).union(self.primary_key)

# report missing attributes
missing = attribute_set.difference(self.names)
# check missing attributes
missing = [a for a in attribute_list if a not in self.names]
if missing:
raise DataJointError('Attributes %s are not found' % str(missing))
raise DataJointError('Attributes `%s` are not found' % '`, `'.join(missing))

# always add primary key attributes
attribute_list = self.primary_key + [a for a in attribute_list if a not in self.primary_key]

# make attribute_list a list of dicts for initializing a Heading
attribute_list = [v._asdict() for k, v in self.attributes.items()
if k in attribute_set and k not in rename_dict.values()]
# convert attribute_list into a list of dicts but exclude renamed attributes
attribute_list = [v.as_dict() for k, v in self.attributes.items()
if k in attribute_list and k not in renamed_attributes.values()]

# add renamed and computed attributes
for new_name, computation in rename_dict.items():
for new_name, computation in renamed_attributes.items():
if computation in self.names:
# renamed attribute
new_attr = self.attributes[computation]._asdict()
new_attr = self.attributes[computation].as_dict()
new_attr['name'] = new_name
new_attr['computation'] = '`' + computation + '`'
else:
Expand All @@ -210,14 +215,14 @@ def join(self, other):
join two headings
"""
assert isinstance(other, Heading)
attribute_list = [v._asdict() for v in self.attributes.values()]
attribute_list = [v.as_dict() for v in self.attributes.values()]
for name in other.names:
if name not in self.names:
attribute_list.append(other.attributes[name]._asdict())
attribute_list.append(other.attributes[name].as_dict())
return Heading(attribute_list)

def resolve_computations(self):
"""
Remove computations. To be done after computations have been resolved in a subquery
"""
return Heading([dict(v._asdict(), computation=None) for v in self.attributes.values()])
return Heading([dict(v.as_dict(), computation=None) for v in self.attributes.values()])
Loading