diff --git a/tests/schema.py b/tests/schema.py index 14061e6ff..5ae94480f 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -10,6 +10,20 @@ schema = dj.schema(PREFIX + '_test1', locals(), connection=dj.conn(**CONN_INFO)) +@schema +class Auto(dj.Lookup): + definition = """ + id :int auto_increment + --- + name :varchar(12) + """ + contents = ( + dict(name="Godel"), + dict(name="Escher"), + dict(name="Bach") + ) + + @schema class User(dj.Lookup): definition = """ # lab members diff --git a/tests/test_declare.py b/tests/test_declare.py index ef760d2ce..d856c750c 100644 --- a/tests/test_declare.py +++ b/tests/test_declare.py @@ -2,24 +2,31 @@ from . import schema +auto = schema.Auto() +user = schema.User() +subject = schema.Subject() +experiment = schema.Experiment() +trial = schema.Trial() +ephys = schema.Ephys() +channel = schema.EphysChannel() + + class TestDeclare: - def __init__(self): - self.user = schema.User() - self.subject = schema.Subject() - self.experiment = schema.Experiment() - self.trial = schema.Trial() - self.ephys = schema.Ephys() - self.channel = schema.EphysChannel() - - def test_attributes(self): - assert_list_equal(self.subject.heading.names, + + @staticmethod + def test_attributes(): + # test autoincrement declaration + assert_list_equal(auto.heading.names, ['id', 'name']) + assert_true(auto.heading.attributes['id'].autoincrement) + + # test attribute declarations + assert_list_equal(subject.heading.names, ['subject_id', 'real_id', 'species', 'date_of_birth', 'subject_notes']) - assert_list_equal(self.subject.primary_key, + assert_list_equal(subject.primary_key, ['subject_id']) - assert_true(self.subject.heading.attributes['subject_id'].numeric) - assert_false(self.subject.heading.attributes['real_id'].numeric) + assert_true(subject.heading.attributes['subject_id'].numeric) + assert_false(subject.heading.attributes['real_id'].numeric) - experiment = schema.Experiment() assert_list_equal(experiment.heading.names, ['subject_id', 'experiment_id', 'experiment_date', 'username', 'data_path', @@ -27,34 +34,35 @@ def test_attributes(self): assert_list_equal(experiment.primary_key, ['subject_id', 'experiment_id']) - assert_list_equal(self.trial.heading.names, + assert_list_equal(trial.heading.names, ['subject_id', 'experiment_id', 'trial_id', 'start_time']) - assert_list_equal(self.trial.primary_key, + assert_list_equal(trial.primary_key, ['subject_id', 'experiment_id', 'trial_id']) - assert_list_equal(self.ephys.heading.names, + assert_list_equal(ephys.heading.names, ['subject_id', 'experiment_id', 'trial_id', 'sampling_frequency', 'duration']) - assert_list_equal(self.ephys.primary_key, + assert_list_equal(ephys.primary_key, ['subject_id', 'experiment_id', 'trial_id']) - assert_list_equal(self.channel.heading.names, + assert_list_equal(channel.heading.names, ['subject_id', 'experiment_id', 'trial_id', 'channel', 'voltage']) - assert_list_equal(self.channel.primary_key, + assert_list_equal(channel.primary_key, ['subject_id', 'experiment_id', 'trial_id', 'channel']) - assert_true(self.channel.heading.attributes['voltage'].is_blob) + assert_true(channel.heading.attributes['voltage'].is_blob) + def test_dependencies(self): - assert_equal(self.user.references, [self.experiment.full_table_name]) - assert_equal(self.experiment.referenced, [self.user.full_table_name]) + assert_equal(user.references, [experiment.full_table_name]) + assert_equal(experiment.referenced, [user.full_table_name]) - assert_equal(self.subject.children, [self.experiment.full_table_name]) - assert_equal(self.experiment.parents, [self.subject.full_table_name]) + assert_equal(subject.children, [experiment.full_table_name]) + assert_equal(experiment.parents, [subject.full_table_name]) - assert_equal(self.experiment.children, [self.trial.full_table_name]) - assert_equal(self.trial.parents, [self.experiment.full_table_name]) + assert_equal(experiment.children, [trial.full_table_name]) + assert_equal(trial.parents, [experiment.full_table_name]) - assert_equal(self.trial.children, [self.ephys.full_table_name]) - assert_equal(self.ephys.parents, [self.trial.full_table_name]) + assert_equal(trial.children, [ephys.full_table_name]) + assert_equal(ephys.parents, [trial.full_table_name]) - assert_equal(self.ephys.children, [self.channel.full_table_name]) - assert_equal(self.channel.parents, [self.ephys.full_table_name]) + assert_equal(ephys.children, [channel.full_table_name]) + assert_equal(channel.parents, [ephys.full_table_name])