Skip to content
9 changes: 5 additions & 4 deletions datajoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
'config',
'Connection', 'Heading', 'Relation', 'FreeRelation', 'Not',
'Relation', 'schema',
'Manual', 'Lookup', 'Imported', 'Computed',
'Manual', 'Lookup', 'Imported', 'Computed', 'Part',
'conn', 'kill']

# define an object that identifies the primary key in RelationalOperand.__getitem__
class PrimaryKey: pass

# define an object that identifies the primary key in RelationalOperand.__getitem__
class PrimaryKey:
pass

key = PrimaryKey

Expand Down Expand Up @@ -54,7 +55,7 @@ class DataJointError(Exception):
# ------------- flatten import hierarchy -------------------------
from .connection import conn, Connection
from .relation import Relation
from .user_relations import Manual, Lookup, Imported, Computed, Subordinate
from .user_relations import Manual, Lookup, Imported, Computed, Part
from .relational_operand import Not
from .heading import Heading
from .schema import schema
Expand Down
51 changes: 38 additions & 13 deletions datajoint/schema.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pymysql
import logging

from . import conn
from . import DataJointError
from . import conn, DataJointError
from .heading import Heading

from .relation import Relation
from .user_relations import Part
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -44,18 +44,43 @@ def __call__(self, cls):
The decorator binds its argument class object to a database
:param cls: class to be decorated
"""
# class-level attributes
cls.database = self.database
cls._connection = self.connection
cls._heading = Heading()
cls._context = self.context

# trigger table declaration by requesting the heading from an instance
instance = cls()
instance.heading
instance._prepare()

def process_relation_class(class_object, context):
"""
assign schema properties to the relation class and declare the table
"""
class_object.database = self.database
class_object._connection = self.connection
class_object._heading = Heading()
class_object._context = context
instance = class_object()
instance.heading # trigger table declaration
instance._prepare()

if issubclass(cls, Part):
raise DataJointError('The schema decorator should not apply to part relations')

process_relation_class(cls, context=self.context)

# Process subordinate relations
for name in (name for name in dir(cls) if not name.startswith('_')):
part = getattr(cls, name)
try:
is_sub = issubclass(part, Part)
except TypeError:
pass
else:
if is_sub:
part._master = cls
process_relation_class(part, context=dict(self.context, **{cls.__name__: cls}))
elif issubclass(part, Relation):
raise DataJointError('Part relations must subclass from datajoint.Part')
return cls

@property
def jobs(self):
"""
schema.jobs provides a view of the job reservation table for the schema
:return: jobs relation
"""
return self.connection.jobs[self.database]
58 changes: 15 additions & 43 deletions datajoint/user_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,21 @@
from datajoint.relation import Relation
from .autopopulate import AutoPopulate
from .utils import from_camel_case
from . import DataJointError


class Part(Relation):

@property
def master(self):
if not hasattr(self, '_master'):
raise DataJointError(
'Part relations must be declared inside a base relation class')
return self._master

@property
def table_name(self):
return self.master().table_name + '__' + from_camel_case(self.__class__.__name__)


class Manual(Relation):
Expand Down Expand Up @@ -68,46 +83,3 @@ def table_name(self):
:returns: the table name of the table formatted for mysql.
"""
return "__" + from_camel_case(self.__class__.__name__)


class Subordinate:
"""
Mix-in to make computed tables subordinate
"""

@property
def populated_from(self):
"""
Overrides the `populate_from` property because subtables should not be populated
directly.

:return: None
"""
return None

def _make_tuples(self, key):
"""
Overrides the `_make_tuples` property because subtables should not be populated
directly. Raises an error if this method is called (usually from populate of the
inheriting object).

:raises: NotImplementedError
"""
raise NotImplementedError(
'This table is subordinate: it cannot be populated directly. Refer to its parent table.')

def progress(self):
"""
Overrides the `progress` method because subtables should not be populated directly.
"""
raise NotImplementedError(
'This table is subordinate: it cannot be populated directly. Refer to its parent table.')

def populate(self, *args, **kwargs):
raise NotImplementedError(
'This table is subordinate: it cannot be populated directly. Refer to its parent table.')





48 changes: 20 additions & 28 deletions tests/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,12 @@ class Language(dj.Lookup):
"""

contents = [
('Fabian', 'English'),
('Edgar', 'English'),
('Dimitri', 'English'),
('Dimitri', 'Ukrainian'),
('Fabian', 'German'),
('Edgar', 'Japanese'),
]
('Fabian', 'English'),
('Edgar', 'English'),
('Dimitri', 'English'),
('Dimitri', 'Ukrainian'),
('Fabian', 'German'),
('Edgar', 'Japanese')]


@schema
Expand Down Expand Up @@ -136,36 +135,29 @@ class Ephys(dj.Imported):
duration :double # (s)
"""

class Channel(dj.Part):
definition = """ # subtable containing individual channels
-> Ephys
channel :tinyint unsigned # channel number within Ephys
----
voltage :longblob
"""

def _make_tuples(self, key):
"""
populate with random data
"""
random.seed('Amazing seed')
random.seed(str(key))
row = dict(key,
sampling_frequency=6000,
duration=np.minimum(2, random.expovariate(1)))
self.insert1(row)
number_samples = round(row['duration'] * row['sampling_frequency'])
EphysChannel().fill(key, number_samples=number_samples)


@schema
class EphysChannel(dj.Subordinate, dj.Imported):
definition = """ # subtable containing individual channels
-> Ephys
channel :tinyint unsigned # channel number within Ephys
----
voltage :longblob
"""

def fill(self, key, number_samples):
"""
populate random trace of specified length
"""
random.seed('Amazing seed')
sub = self.Channel()
for channel in range(2):
self.insert1(
sub.insert1(
dict(key,
channel=channel,
voltage=np.float32(np.random.randn(number_samples))
))
voltage=np.float32(np.random.randn(number_samples))))


104 changes: 104 additions & 0 deletions tests/schema_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""
A simple, abstract schema to test relational algebra
"""
import random
import datajoint as dj
from . import PREFIX, CONN_INFO

schema = dj.schema(PREFIX + '_relational', locals(), connection=dj.conn(**CONN_INFO))


@schema
class A(dj.Lookup):
definition = """
id_a :int
---
cond_in_a :tinyint
"""
contents = [(i, i % 4 > i % 3) for i in range(10)]


@schema
class B(dj.Computed):
definition = """
-> A
id_b :int
---
mu :float # mean value
sigma :float # standard deviation
n :smallint # number samples
"""

class C(dj.Part):
definition = """
-> B
id_c :int
---
value :float # normally distributed variables according to parameters in B
"""

def _make_tuples(self, key):
random.seed(str(key))
sub = B.C()
for i in range(4):
key['id_b'] = i
mu = random.normalvariate(0, 10)
sigma = random.lognormvariate(0, 4)
n = random.randint(0, 10)
self.insert1(dict(key, mu=mu, sigma=sigma, n=n))
sub.insert((dict(key, id_c=j, value=random.normalvariate(mu, sigma)) for j in range(n)))


@schema
class L(dj.Lookup):
definition = """
id_l: int
---
cond_in_l :tinyint
"""
contents = [(i, i % 3 >= i % 5) for i in range(30)]


@schema
class D(dj.Computed):
definition = """
-> A
id_d :int
---
-> L
"""

def _make_tuples(self, key):
# make reference to a random tuple from L
random.seed(str(key))
lookup = list(L().fetch.keys())
for i in range(4):
self.insert1(dict(key, id_d=i, **random.choice(lookup)))


@schema
class E(dj.Computed):
definition = """
-> B
-> D
---
-> L
"""

class F(dj.Part):
definition = """
-> E
id_f :int
---
-> B.C
"""

def _make_tuples(self, key):
random.seed(str(key))
self.insert1(dict(key, **random.choice(list(L().fetch.keys()))))
sub = E.F()
references = list((B.C() & key).fetch.keys())
random.shuffle(references)
for i, ref in enumerate(references):
if random.getrandbits(1):
sub.insert1(dict(key, id_f=i, **ref))
2 changes: 1 addition & 1 deletion tests/test_autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self):
self.experiment = schema.Experiment()
self.trial = schema.Trial()
self.ephys = schema.Ephys()
self.channel = schema.EphysChannel()
self.channel = schema.Ephys.Channel()

# delete automatic tables just in case
self.channel.delete_quick()
Expand Down
Loading