Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions datajoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,17 @@
__author__ = "Dimitri Yatsenko, Edgar Walker, and Fabian Sinz at Baylor College of Medicine"
__version__ = "0.2"
__all__ = ['__author__', '__version__',
'config',
'Connection', 'Heading', 'Relation', 'FreeRelation', 'Not',
'Relation', 'schema',
'Manual', 'Lookup', 'Imported', 'Computed', 'Part',
'conn', 'kill']
'config', 'conn', 'kill',
'Connection', 'Heading', 'Relation', 'FreeRelation', 'Not', 'schema',
'Manual', 'Lookup', 'Imported', 'Computed', 'Part']


# define an object that identifies the primary key in RelationalOperand.__getitem__
class PrimaryKey:
class key:
"""
object that allows requesting the primary key in Fetch.__getitem__
"""
pass

key = PrimaryKey


class DataJointError(Exception):
"""
Expand Down
21 changes: 18 additions & 3 deletions datajoint/autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import abc
import logging
import datetime
import random
from .relational_operand import RelationalOperand
from . import DataJointError
from .relation import FreeRelation
Expand Down Expand Up @@ -52,7 +53,8 @@ def target(self):
"""
return self

def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False):
def populate(self, restriction=None, suppress_errors=False,
reserve_jobs=False, order="original"):
"""
rel.populate() calls rel._make_tuples(key) for every primary key in self.populated_from
for which there is not already a tuple in rel.
Expand All @@ -61,18 +63,31 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False):
:param suppress_errors: suppresses error if true
:param reserve_jobs: currently not implemented
:param batch: batch size of a single job
:param order: "original"|"reverse"|"random" - the order of execution
"""
error_list = [] if suppress_errors else None
if not isinstance(self.populated_from, RelationalOperand):
raise DataJointError('Invalid populated_from value')

if self.connection.in_transaction:
raise DataJointError('Populate cannot be called during a transaction.')

valid_order = ['original', 'reverse', 'random']
if order not in valid_order:
raise DataJointError('The order argument must be one of %s' % str(valid_order))

error_list = [] if suppress_errors else None

jobs = self.connection.jobs[self.target.database]
table_name = self.target.table_name
unpopulated = (self.populated_from & restriction) - self.target.project()
for key in unpopulated.fetch.keys():
# Collect the keys to process, then reorder them according to `order`.
keys = unpopulated.fetch.keys()
if order == "reverse":
    # list.reverse() reverses in place and returns None, so assigning its
    # result would leave `keys` as None; build the reversed list explicitly.
    keys = list(keys)[::-1]
elif order == "random":
    # random.shuffle also works in place, so materialize a list first.
    keys = list(keys)
    random.shuffle(keys)

for key in keys:
if not reserve_jobs or jobs.reserve(table_name, key):
self.connection.start_transaction()
if key in self.target: # already populated
Expand Down
2 changes: 1 addition & 1 deletion datajoint/kill.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def kill(restriction=None, connection=None):
except TypeError as err:
print(process)

response = input('process to kill or "q" to quit)')
response = input('process to kill or "q" to quit > ')
if response == 'q':
break
if response:
Expand Down
46 changes: 29 additions & 17 deletions datajoint/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,20 @@ def heading(self):
if self._heading is None:
self._heading = Heading() # instance-level heading
if not self._heading:
if not self.is_declared:
self.connection.query(
declare(self.full_table_name, self.definition, self._context))
if self.is_declared:
self.connection.erm.load_dependencies(self.full_table_name)
self._heading.init_from_database(self.connection, self.database, self.table_name)
self.declare()
return self._heading

def declare(self):
"""
Load the table heading, declaring the table first if it is not declared yet.

If self.is_declared is false, the result of
declare(self.full_table_name, self.definition, self._context) is passed to
self.connection.query. self.is_declared is then checked again before the
dependencies are loaded through self.connection.erm; self._heading is
initialized from the database using self.connection, self.database and
self.table_name.
"""
# The bare name `declare` below is looked up in the enclosing module scope,
# not on the instance, so it does not refer to this method.
if not self.is_declared:
self.connection.query(
declare(self.full_table_name, self.definition, self._context))
# NOTE(review): indentation is lost in this view, so it is unclear whether
# init_from_database is guarded by this check -- confirm against the real file.
if self.is_declared:
self.connection.erm.load_dependencies(self.full_table_name)
self._heading.init_from_database(self.connection, self.database, self.table_name)

@property
def from_clause(self):
"""
Expand Down Expand Up @@ -115,7 +121,6 @@ def descendants(self):
for table in self.connection.erm.get_descendants(self.full_table_name))
return [relation for relation in relations if relation.is_declared]


def _repr_helper(self):
return "%s.%s()" % (self.__module__, self.__class__.__name__)

Expand All @@ -133,7 +138,7 @@ def full_table_name(self):

def insert(self, rows, **kwargs):
"""
Inserts a collection of tuples. Additional keyword arguments are passed to insert1.
Insert a collection of tuples. Additional keyword arguments are passed to insert1.

:param iter: Must be an iterator that generates a sequence of valid arguments for insert.
"""
Expand Down Expand Up @@ -172,9 +177,11 @@ def insert1(self, tup, replace=False, ignore_errors=False, skip_duplicates=False
if name in tup and heading[name].is_blob)
else: # positional insert
try:
if len(tup) != len(self.heading):
if len(tup) != len(heading):
raise DataJointError(
'Tuple size does not match the number of relation attributes')
'Incorrect number of attributes: '
'{given} given; {expected} expected'.format(
given=len(tup), expected=len(heading)))
except TypeError:
raise DataJointError('Datatype %s cannot be inserted' % type(tup))
else:
Expand Down Expand Up @@ -225,7 +232,8 @@ def delete(self):
relations[dep] &= r.project() if name in restrict_by_me else r.restrictions

do_delete = False # indicate if there is anything to delete
print('The contents of the following tables are about to be deleted:')
if config['safemode']:
print('The contents of the following tables are about to be deleted:')
for relation in relations.values():
count = len(relation)
if count:
Expand All @@ -234,10 +242,15 @@ def delete(self):
print(relation.full_table_name, '(%d tuples)' % count)
else:
relations.pop(relation.full_table_name)
if do_delete and (not config['safemode'] or user_choice("Proceed?", default='no') == 'yes'):
with self.connection.transaction:
for r in reversed(list(relations.values())):
r.delete_quick()
if not do_delete:
if config['safemode']:
print('Nothing to delete')
else:
if not config['safemode'] or user_choice("Proceed?", default='no') == 'yes':
with self.connection.transaction:
for r in reversed(list(relations.values())):
r.delete_quick()
print('Done')

def drop_quick(self):
"""
Expand Down Expand Up @@ -274,8 +287,7 @@ def size_on_disk(self):
"""
ret = self.connection.query(
'SHOW TABLE STATUS FROM `{database}` WHERE NAME="{table}"'.format(
database=self.database, table=self.table_name), as_dict=True
).fetchone()
database=self.database, table=self.table_name), as_dict=True).fetchone()
return ret['Data_length'] + ret['Index_length']

# --------- functionality used by the decorator ---------
Expand Down
22 changes: 14 additions & 8 deletions datajoint/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,24 +45,23 @@ def __call__(self, cls):
:param cls: class to be decorated
"""

def process_relation_class(class_object, context):
def process_relation_class(relation_class, context):
"""
assign schema properties to the relation class and declare the table
"""
class_object.database = self.database
class_object._connection = self.connection
class_object._heading = Heading()
class_object._context = context
instance = class_object()
instance.heading # trigger table declaration
instance._prepare()
relation_class.database = self.database
relation_class._connection = self.connection
relation_class._heading = Heading()
relation_class._context = context
relation_class().declare()

if issubclass(cls, Part):
raise DataJointError('The schema decorator should not apply to part relations')

process_relation_class(cls, context=self.context)

# Process subordinate relations
parts = list()
for name in (name for name in dir(cls) if not name.startswith('_')):
part = getattr(cls, name)
try:
Expand All @@ -71,10 +70,17 @@ def process_relation_class(class_object, context):
pass
else:
if is_sub:
parts.append(part)
part._master = cls
process_relation_class(part, context=dict(self.context, **{cls.__name__: cls}))
elif issubclass(part, Relation):
raise DataJointError('Part relations must subclass from datajoint.Part')

# invoke Relation._prepare() on class and its part relations.
cls()._prepare()
for part in parts:
part()._prepare()

return cls

@property
Expand Down
15 changes: 8 additions & 7 deletions datajoint/user_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
Hosts the table tiers, user relations should be derived from.
"""

from datajoint.relation import Relation
import abc
from .relation import Relation
from .autopopulate import AutoPopulate
from .utils import from_camel_case
from . import DataJointError


class Part(Relation):
class Part(Relation, metaclass=abc.ABCMeta):

@property
def master(self):
Expand All @@ -22,7 +23,7 @@ def table_name(self):
return self.master().table_name + '__' + from_camel_case(self.__class__.__name__)


class Manual(Relation):
class Manual(Relation, metaclass=abc.ABCMeta):
"""
Inherit from this class if the table's values are entered manually.
"""
Expand All @@ -35,7 +36,7 @@ def table_name(self):
return from_camel_case(self.__class__.__name__)


class Lookup(Relation):
class Lookup(Relation, metaclass=abc.ABCMeta):
"""
Inherit from this class if the table's values are for lookup. This is
currently equivalent to defining the table as Manual and serves semantic
Expand All @@ -54,10 +55,10 @@ def _prepare(self):
Checks whether the instance has a property called `contents` and inserts its elements.
"""
if hasattr(self, 'contents'):
self.insert(self.contents, ignore_errors=False, skip_duplicates=True)
self.insert(self.contents, skip_duplicates=True)


class Imported(Relation, AutoPopulate):
class Imported(Relation, AutoPopulate, metaclass=abc.ABCMeta):
"""
Inherit from this class if the table's values are imported from external data sources.
The inherited class must at least provide the function `_make_tuples`.
Expand All @@ -71,7 +72,7 @@ def table_name(self):
return "_" + from_camel_case(self.__class__.__name__)


class Computed(Relation, AutoPopulate):
class Computed(Relation, AutoPopulate, metaclass=abc.ABCMeta):
"""
Inherit from this class if the table's values are computed from other relations in the schema.
The inherited class must at least provide the function `_make_tuples`.
Expand Down