diff --git a/datajoint/__init__.py b/datajoint/__init__.py
index b745ab686..aea688382 100644
--- a/datajoint/__init__.py
+++ b/datajoint/__init__.py
@@ -17,12 +17,13 @@
            'config', 'Connection', 'Heading', 'Relation', 'FreeRelation', 'Not', 'Relation', 'schema',
-           'Manual', 'Lookup', 'Imported', 'Computed',
+           'Manual', 'Lookup', 'Imported', 'Computed', 'Part',
            'conn', 'kill']
 
-# define an object that identifies the primary key in RelationalOperand.__getitem__
-class PrimaryKey: pass
+# define an object that identifies the primary key in RelationalOperand.__getitem__
+class PrimaryKey:
+    pass
 
 key = PrimaryKey
@@ -54,7 +55,7 @@ class DataJointError(Exception):
 # ------------- flatten import hierarchy -------------------------
 from .connection import conn, Connection
 from .relation import Relation
-from .user_relations import Manual, Lookup, Imported, Computed, Subordinate
+from .user_relations import Manual, Lookup, Imported, Computed, Part
 from .relational_operand import Not
 from .heading import Heading
 from .schema import schema
diff --git a/datajoint/schema.py b/datajoint/schema.py
index 2f538a2f0..bdf985c61 100644
--- a/datajoint/schema.py
+++ b/datajoint/schema.py
@@ -1,10 +1,10 @@
 import pymysql
 import logging
-from . import conn
-from . import DataJointError
+from . import conn, DataJointError
 from .heading import Heading
-
+from .relation import Relation
+from .user_relations import Part
 
 logger = logging.getLogger(__name__)
@@ -44,18 +44,43 @@ def __call__(self, cls):
         """
         The decorator binds its argument class object to a database
         :param cls: class to be decorated
         """
-        # class-level attributes
-        cls.database = self.database
-        cls._connection = self.connection
-        cls._heading = Heading()
-        cls._context = self.context
-
-        # trigger table declaration by requesting the heading from an instance
-        instance = cls()
-        instance.heading
-        instance._prepare()
+
+        def process_relation_class(class_object, context):
+            """
+            assign schema properties to the relation class and declare the table
+            """
+            class_object.database = self.database
+            class_object._connection = self.connection
+            class_object._heading = Heading()
+            class_object._context = context
+            instance = class_object()
+            instance.heading  # trigger table declaration
+            instance._prepare()
+
+        if issubclass(cls, Part):
+            raise DataJointError('The schema decorator should not apply to part relations')
+
+        process_relation_class(cls, context=self.context)
+
+        # process the part relations nested inside cls
+        for name in (name for name in dir(cls) if not name.startswith('_')):
+            part = getattr(cls, name)
+            try:
+                is_part = issubclass(part, Part)
+            except TypeError:
+                pass  # part is not a class
+            else:
+                if is_part:
+                    part._master = cls
+                    process_relation_class(part, context=dict(self.context, **{cls.__name__: cls}))
+                elif issubclass(part, Relation):
+                    raise DataJointError('Nested relations must subclass from datajoint.Part')
         return cls
 
     @property
     def jobs(self):
+        """
+        schema.jobs provides a view of the job reservation table for the schema
+        :return: jobs relation
+        """
         return self.connection.jobs[self.database]
diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py
index 636e86b10..2e140f62f 100644
--- a/datajoint/user_relations.py
+++ b/datajoint/user_relations.py
@@ -5,6 +5,25 @@
 from datajoint.relation import Relation
 from .autopopulate import AutoPopulate
 from .utils import from_camel_case
+from . import DataJointError
+
+
+class Part(Relation):
+    """
+    Inherit from Part to declare a relation that is a part of a master relation.
+    Part relations are declared as nested classes of the master and are populated by it.
+    """
+
+    @property
+    def master(self):
+        if not hasattr(self, '_master'):
+            raise DataJointError(
+                'Part relations must be declared inside a base relation class')
+        return self._master
+
+    @property
+    def table_name(self):
+        return self.master().table_name + '__' + from_camel_case(self.__class__.__name__)
 
 
 class Manual(Relation):
@@ -68,46 +87,3 @@ def table_name(self):
         """
         :returns: the table name of the table formatted for mysql.
         """
         return "__" + from_camel_case(self.__class__.__name__)
-
-
-class Subordinate:
-    """
-    Mix-in to make computed tables subordinate
-    """
-
-    @property
-    def populated_from(self):
-        """
-        Overrides the `populate_from` property because subtables should not be populated
-        directly.
-
-        :return: None
-        """
-        return None
-
-    def _make_tuples(self, key):
-        """
-        Overrides the `_make_tuples` property because subtables should not be populated
-        directly. Raises an error if this method is called (usually from populate of the
-        inheriting object).
-
-        :raises: NotImplementedError
-        """
-        raise NotImplementedError(
-            'This table is subordinate: it cannot be populated directly. Refer to its parent table.')
-
-    def progress(self):
-        """
-        Overrides the `progress` method because subtables should not be populated directly.
-        """
-        raise NotImplementedError(
-            'This table is subordinate: it cannot be populated directly. Refer to its parent table.')
-
-    def populate(self, *args, **kwargs):
-        raise NotImplementedError(
-            'This table is subordinate: it cannot be populated directly. Refer to its parent table.')
-
-
-
-
-
diff --git a/tests/schema.py b/tests/schema.py
index 5ae94480f..564b0fcd4 100644
--- a/tests/schema.py
+++ b/tests/schema.py
@@ -66,13 +66,12 @@ class Language(dj.Lookup):
     """
     contents = [
-        ('Fabian', 'English'),
-        ('Edgar', 'English'),
-        ('Dimitri', 'English'),
-        ('Dimitri', 'Ukrainian'),
-        ('Fabian', 'German'),
-        ('Edgar', 'Japanese'),
-    ]
+        ('Fabian', 'English'),
+        ('Edgar', 'English'),
+        ('Dimitri', 'English'),
+        ('Dimitri', 'Ukrainian'),
+        ('Fabian', 'German'),
+        ('Edgar', 'Japanese')]
 
 
 @schema
@@ -136,36 +135,29 @@ class Ephys(dj.Imported):
     duration :double  # (s)
     """
 
+    class Channel(dj.Part):
+        definition = """  # subtable containing individual channels
+        -> Ephys
+        channel :tinyint unsigned  # channel number within Ephys
+        ----
+        voltage :longblob
+        """
+
     def _make_tuples(self, key):
         """
         populate with random data
         """
-        random.seed('Amazing seed')
+        random.seed(str(key))
         row = dict(key,
                    sampling_frequency=6000,
                    duration=np.minimum(2, random.expovariate(1)))
         self.insert1(row)
         number_samples = round(row['duration'] * row['sampling_frequency'])
-        EphysChannel().fill(key, number_samples=number_samples)
-
-
-@schema
-class EphysChannel(dj.Subordinate, dj.Imported):
-    definition = """  # subtable containing individual channels
-    -> Ephys
-    channel :tinyint unsigned  # channel number within Ephys
-    ----
-    voltage :longblob
-    """
-
-    def fill(self, key, number_samples):
-        """
-        populate random trace of specified length
-        """
-        random.seed('Amazing seed')
+        sub = self.Channel()
         for channel in range(2):
-            self.insert1(
+            sub.insert1(
                 dict(key,
                      channel=channel,
-                     voltage=np.float32(np.random.randn(number_samples))
-                     ))
+                     voltage=np.float32(np.random.randn(number_samples))))
+
+
diff --git a/tests/schema_simple.py b/tests/schema_simple.py
new file mode 100644
index 000000000..61fedbb66
--- /dev/null
+++ b/tests/schema_simple.py
@@ -0,0 +1,104 @@
+"""
+A simple, abstract schema to test relational algebra
+"""
+import random
+import datajoint as dj
+from . import PREFIX, CONN_INFO
+
+schema = dj.schema(PREFIX + '_relational', locals(), connection=dj.conn(**CONN_INFO))
+
+
+@schema
+class A(dj.Lookup):
+    definition = """
+    id_a :int
+    ---
+    cond_in_a :tinyint
+    """
+    contents = [(i, i % 4 > i % 3) for i in range(10)]
+
+
+@schema
+class B(dj.Computed):
+    definition = """
+    -> A
+    id_b :int
+    ---
+    mu :float  # mean value
+    sigma :float  # standard deviation
+    n :smallint  # number samples
+    """
+
+    class C(dj.Part):
+        definition = """
+        -> B
+        id_c :int
+        ---
+        value :float  # normally distributed variables according to parameters in B
+        """
+
+    def _make_tuples(self, key):
+        random.seed(str(key))
+        sub = B.C()
+        for i in range(4):
+            key['id_b'] = i
+            mu = random.normalvariate(0, 10)
+            sigma = random.lognormvariate(0, 4)
+            n = random.randint(0, 10)
+            self.insert1(dict(key, mu=mu, sigma=sigma, n=n))
+            sub.insert((dict(key, id_c=j, value=random.normalvariate(mu, sigma)) for j in range(n)))
+
+
+@schema
+class L(dj.Lookup):
+    definition = """
+    id_l: int
+    ---
+    cond_in_l :tinyint
+    """
+    contents = [(i, i % 3 >= i % 5) for i in range(30)]
+
+
+@schema
+class D(dj.Computed):
+    definition = """
+    -> A
+    id_d :int
+    ---
+    -> L
+    """
+
+    def _make_tuples(self, key):
+        # make reference to a random tuple from L
+        random.seed(str(key))
+        lookup = list(L().fetch.keys())
+        for i in range(4):
+            self.insert1(dict(key, id_d=i, **random.choice(lookup)))
+
+
+@schema
+class E(dj.Computed):
+    definition = """
+    -> B
+    -> D
+    ---
+    -> L
+    """
+
+    class F(dj.Part):
+        definition = """
+        -> E
+        id_f :int
+        ---
+        -> B.C
+        """
+
+    def _make_tuples(self, key):
+        random.seed(str(key))
+        self.insert1(dict(key, **random.choice(list(L().fetch.keys()))))
+        sub = E.F()
+        references = list((B.C() & key).fetch.keys())
+        random.shuffle(references)
+        for i, ref in enumerate(references):
+            if random.getrandbits(1):
+                sub.insert1(dict(key, id_f=i, **ref))
diff --git a/tests/test_autopopulate.py b/tests/test_autopopulate.py
index cbb5142a4..31bc6496f 100644
--- a/tests/test_autopopulate.py
+++ b/tests/test_autopopulate.py
@@ -16,7 +16,7 @@ def __init__(self):
         self.experiment = schema.Experiment()
         self.trial = schema.Trial()
         self.ephys = schema.Ephys()
-        self.channel = schema.EphysChannel()
+        self.channel = schema.Ephys.Channel()
 
         # delete automatic tables just in case
         self.channel.delete_quick()
diff --git a/tests/test_cascading_delete.py b/tests/test_cascading_delete.py
new file mode 100644
index 000000000..8182ad6c1
--- /dev/null
+++ b/tests/test_cascading_delete.py
@@ -0,0 +1,74 @@
+from nose.tools import assert_false, assert_true
+import datajoint as dj
+from .schema_simple import A, B, D, E, L
+
+
+class TestDelete:
+
+    @staticmethod
+    def setup():
+        """
+        class-level test setup. Executes before each test method.
+ """ + A()._prepare() + L()._prepare() + B().populate() + D().populate() + E().populate() + + @staticmethod + def test_delete_tree(): + assert_false(dj.config['safemode'], 'safemode must be off for testing') + assert_true(L() and A() and B() and B.C() and D() and E() and E.F(), + 'schema is not populated') + A().delete() + assert_false(A() or B() or B.C() or D() or E() or E.F(), 'incomplete delete') + + @staticmethod + def test_delete_tree_restricted(): + assert_false(dj.config['safemode'], 'safemode must be off for testing') + assert_true(L() and A() and B() and B.C() and D() and E() and E.F(), + 'schema is not populated') + cond = 'cond_in_a' + rel = A() & cond + rest = dict( + A=len(A())-len(rel), + B=len(B()-rel), + C=len(B.C()-rel), + D=len(D()-rel), + E=len(E()-rel), + F=len(E.F()-rel) + ) + rel.delete() + assert_false(rel or + (B() & rel) or + (B.C() & rel) or + (D() & rel) or + (E() & rel) or + (E.F() & rel), + 'incomplete delete') + assert_true( + len(A()) == rest['A'] and + len(B()) == rest['B'] and + len(B.C()) == rest['C'] and + len(D()) == rest['D'] and + len(E()) == rest['E'] and + len(E.F()) == rest['F'], + 'incorrect restricted delete') + + @staticmethod + def test_delete_lookup(): + assert_false(dj.config['safemode'], 'safemode must be off for testing') + assert_true(bool(L() and A() and B() and B.C() and D() and E() and E.F()), + 'schema is not populated') + L().delete() + assert_false(bool(L() or D() or E() or E.F()), 'incomplete delete') + A().delete() # delete all is necessary because delete L deletes from subtables. TODO: submit this as an issue + + # @staticmethod + # def test_delete_lookup_restricted(): + # assert_false(dj.config['safemode'], 'safemode must be off for testing') + # assert_true(L() and A() and B() and C() and D() and E() and F(), + # 'schema is not populated') + # rel = L() & 'cond_in_l' + # L().delete() diff --git a/tests/test_declare.py b/tests/test_declare.py index d856c750c..b00545599 100644 --- a/tests/test_declare.py +++ b/tests/test_declare.py @@ -8,7 +8,7 @@ experiment = schema.Experiment() trial = schema.Trial() ephys = schema.Ephys() -channel = schema.EphysChannel() +channel = schema.Ephys.Channel() class TestDeclare: diff --git a/tests/test_relation.py b/tests/test_relation.py index d025775d0..a2de3749b 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -18,7 +18,7 @@ def __init__(self): self.experiment = schema.Experiment() self.trial = schema.Trial() self.ephys = schema.Ephys() - self.channel = schema.EphysChannel() + self.channel = schema.Ephys.Channel() def test_contents(self): """ diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index a6c19a46e..5e0e68b67 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -1,124 +1,20 @@ -import datajoint as dj -from . 
-import random
 import numpy as np
 from nose.tools import assert_raises, assert_equal, \
     assert_false, assert_true, assert_list_equal, \
     assert_tuple_equal, assert_dict_equal, raises
-
-
-schema = dj.schema(PREFIX + '_relational', locals(), connection=dj.conn(**CONN_INFO))
-
-
-@schema
-class A(dj.Lookup):
-    definition = """
-    id_a :int
-    ---
-    cond_in_a :tinyint
-    """
-    contents = [(i, i % 4 > i % 3) for i in range(10)]
-
-
-@schema
-class B(dj.Computed):
-    definition = """
-    -> A
-    id_b :int
-    ---
-    mu :float  # mean value
-    sigma :float  # standard deviation
-    n :smallint  # number samples
-    """
-
-    def _make_tuples(self, key):
-        random.seed(str(key))
-        sub = C()
-        for i in range(4):
-            key['id_b'] = i
-            mu = random.normalvariate(0, 10)
-            sigma = random.lognormvariate(0, 4)
-            n = random.randint(0, 10)
-            self.insert1(dict(key, mu=mu, sigma=sigma, n=n))
-            for j in range(n):
-                sub.insert1(dict(key, id_c=j, value=random.normalvariate(mu, sigma)))
-
-
-@schema
-class C(dj.Subordinate, dj.Computed):
-    definition = """
-    -> B
-    id_c :int
-    ---
-    value :float  # normally distributed variables according to parameters in B
-    """
-
-
-@schema
-class L(dj.Lookup):
-    definition = """
-    id_l: int
-    ---
-    cond_in_l :tinyint
-    """
-    contents = ((i, i % 3 >= i % 5) for i in range(30))
-
-
-@schema
-class D(dj.Computed):
-    definition = """
-    -> A
-    id_d :int
-    ---
-    -> L
-    """
-
-    def _make_tuples(self, key):
-        # connect to random L
-        random.seed(str(key))
-        lookup = list(L().fetch.keys())
-        for i in range(4):
-            self.insert1(dict(key, id_d=i, **random.choice(lookup)))
-
-
-@schema
-class E(dj.Computed):
-    definition = """
-    -> B
-    -> D
-    ---
-    -> L
-    """
-
-    def _make_tuples(self, key):
-        random.seed(str(key))
-        self.insert1(dict(key, **random.choice(list(L().fetch.keys()))))
-        sub = F()
-        references = list((C() & key).fetch.keys())
-        random.shuffle(references)
-        for i, ref in enumerate(references):
-            if random.getrandbits(1):
-                sub.insert1(dict(key, id_f=i, **ref))
-
-
-@schema
-class F(dj.Subordinate, dj.Computed):
-    definition = """
-    -> E
-    id_f :int
-    ---
-    -> C
-    """
+import datajoint as dj
+from .schema_simple import A, B, D, E, L
 
 
 def setup():
     """
     module-level test setup
     """
+    A()._prepare()
+    L()._prepare()
     B().populate()
     D().populate()
     E().populate()
-    pass
 
 
 class TestRelational:
@@ -130,10 +26,10 @@ def test_populate():
         assert_false(E().progress(display=False)[0], 'E incompletely populated')
 
         assert_true(len(B()) == 40, 'B populated incorrectly')
-        assert_true(len(C()) > 0, 'C populated incorrectly')
+        assert_true(len(B.C()) > 0, 'C populated incorrectly')
         assert_true(len(D()) == 40, 'D populated incorrectly')
         assert_true(len(E()) == len(B())*len(D())/len(A()), 'E populated incorrectly')
-        assert_true(len(F()) > 0, 'F populated incorrectly')
+        assert_true(len(E.F()) > 0, 'F populated incorrectly')
 
     @staticmethod
     def test_join():
@@ -217,28 +113,23 @@ def test_project():
             'extend does not work')
 
         # projection after restriction
-        assert_equal(
-            len(D() & (L() & 'cond_in_l')) + len(D() - (L() & 'cond_in_l')),
-            len(D()),
-            'failed semijoin or antijoin'
-        )
-        assert_equal(
-            len((D() - (L() & 'cond_in_l')).project()),
-            len(D() - (L() & 'cond_in_l')),
-            'projection altered the cardinality of a restricted relation'
-        )
+        cond = L() & 'cond_in_l'
+        assert_equal(len(D() & cond) + len(D() - cond), len(D()),
+                     'failed semijoin or antijoin')
+        assert_equal(len((D() & cond).project()), len(D() & cond),
+                     "projection failed: altered the argument's cardinality")
 
     @staticmethod
     def test_aggregate():
-        x = B().aggregate(C(), 'n', count='count(id_c)', mean='avg(value)', max='max(value)')
+        x = B().aggregate(B.C(), 'n', count='count(id_c)', mean='avg(value)', max='max(value)')
         assert_equal(len(x), len(B()))
-        for n, count, mean, max, key in zip(*x.fetch['n', 'count', 'mean', 'max', dj.key]):
+        for n, count, mean, max_, key in zip(*x.fetch['n', 'count', 'mean', 'max', dj.key]):
             assert_equal(n, count, 'aggregation failed (count)')
-            values = (C() & key).fetch['value']
+            values = (B.C() & key).fetch['value']
             assert_true(bool(len(values)) == bool(n), 'aggregation failed (restriction)')
             if n:
                 assert_true(np.isclose(mean, values.mean(), rtol=1e-4, atol=1e-5),
                             "aggregation failed (mean)")
-                assert_true(np.isclose(max, values.max(), rtol=1e-4, atol=1e-5),
+                assert_true(np.isclose(max_, values.max(), rtol=1e-4, atol=1e-5),
                             "aggregation failed (max)")
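
For reference, a minimal sketch of the usage pattern this patch introduces, distilled from tests/schema.py: a part relation is declared as a class nested inside its master, the schema decorator detects it and binds it to the master, and the master's `_make_tuples` fills both tables. The `schema`, `Trial`, and constant values below are assumed context from the test schema, not new API.

```python
import numpy as np
import datajoint as dj

# assumes schema = dj.schema(...) and a Trial relation, as in tests/schema.py
@schema
class Ephys(dj.Imported):
    definition = """  # ephys recording
    -> Trial
    ---
    sampling_frequency :double  # (Hz)
    duration           :double  # (s)
    """

    class Channel(dj.Part):
        # parts are nested classes; the schema decorator sets their master
        # and derives the table name as <master>__<part>
        definition = """  # subtable containing individual channels
        -> Ephys
        channel :tinyint unsigned  # channel number within Ephys
        ---
        voltage :longblob
        """

    def _make_tuples(self, key):
        # the master inserts its own tuple first, then fills its parts;
        # parts have no populate() of their own
        self.insert1(dict(key, sampling_frequency=6000, duration=1.0))
        number_samples = 6000
        channel_table = self.Channel()
        for channel in range(2):
            channel_table.insert1(dict(key, channel=channel,
                                       voltage=np.float32(np.random.randn(number_samples))))
```

Calling `Ephys().populate()` fills both the master and its part; queries address the part as `Ephys.Channel()`, which is exactly how the updated tests replace the old `schema.EphysChannel()` references.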