From 9f2297db5f214c8c5c0c5e5091fc5c719eac5928 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Wed, 29 Nov 2023 14:52:34 -0600 Subject: [PATCH 1/6] test: :white_check_mark: convert simpler tests to pytest syntax --- tests/__init__.py | 7 + tests/schema.py | 489 +++++++++++++++++++ tests/schema_advanced.py | 147 ++++++ tests/schema_simple.py | 279 +++++++++++ {tests_old => tests}/test_blob.py | 100 ++-- {tests_old => tests}/test_blob_matlab.py | 58 ++- {tests_old => tests}/test_dependencies.py | 75 ++- tests/test_erd.py | 76 +++ {tests_old => tests}/test_foreign_keys.py | 23 +- {tests_old => tests}/test_groupby.py | 0 tests/test_hash.py | 6 + {tests_old => tests}/test_json.py | 6 +- {tests_old => tests}/test_log.py | 3 +- {tests_old => tests}/test_nan.py | 14 +- {tests_old => tests}/test_plugin.py | 0 {tests_old => tests}/test_relation_u.py | 59 ++- {tests_old => tests}/test_schema_keywords.py | 12 +- {tests_old => tests}/test_settings.py | 30 +- tests/test_utils.py | 33 ++ tests/test_virtual_module.py | 10 + tests_old/test_erd.py | 87 ---- tests_old/test_hash.py | 7 - tests_old/test_utils.py | 33 -- tests_old/test_virtual_module.py | 12 - 24 files changed, 1219 insertions(+), 347 deletions(-) create mode 100644 tests/schema.py create mode 100644 tests/schema_advanced.py create mode 100644 tests/schema_simple.py rename {tests_old => tests}/test_blob.py (73%) rename {tests_old => tests}/test_blob_matlab.py (83%) rename {tests_old => tests}/test_dependencies.py (64%) create mode 100644 tests/test_erd.py rename {tests_old => tests}/test_foreign_keys.py (72%) rename {tests_old => tests}/test_groupby.py (100%) create mode 100644 tests/test_hash.py rename {tests_old => tests}/test_json.py (98%) rename {tests_old => tests}/test_log.py (69%) rename {tests_old => tests}/test_nan.py (73%) rename {tests_old => tests}/test_plugin.py (100%) rename {tests_old => tests}/test_relation_u.py (52%) rename {tests_old => tests}/test_schema_keywords.py (67%) rename {tests_old => tests}/test_settings.py (69%) create mode 100644 tests/test_utils.py create mode 100644 tests/test_virtual_module.py delete mode 100644 tests_old/test_erd.py delete mode 100644 tests_old/test_hash.py delete mode 100644 tests_old/test_utils.py delete mode 100644 tests_old/test_virtual_module.py diff --git a/tests/__init__.py b/tests/__init__.py index 8b825a042..0fd907166 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -5,6 +5,13 @@ PREFIX = "djtest" +# Connection for testing +CONN_INFO = dict( + host=os.getenv("DJ_HOST"), + user=os.getenv("DJ_USER"), + password=os.getenv("DJ_PASS"), +) + CONN_INFO_ROOT = dict( host=os.getenv("DJ_HOST"), user=os.getenv("DJ_USER"), diff --git a/tests/schema.py b/tests/schema.py new file mode 100644 index 000000000..dafd481da --- /dev/null +++ b/tests/schema.py @@ -0,0 +1,489 @@ +""" +Sample schema with realistic tables for testing +""" + +import random +import numpy as np +import datajoint as dj +import inspect +from . import PREFIX, CONN_INFO + +schema = dj.Schema(PREFIX + "_test1", connection=dj.conn(**CONN_INFO)) + + +@schema +class TTest(dj.Lookup): + """ + doc string + """ + + definition = """ + key : int # key + --- + value : int # value + """ + contents = [(k, 2 * k) for k in range(10)] + + +@schema +class TTest2(dj.Manual): + definition = """ + key : int # key + --- + value : int # value + """ + + +@schema +class TTest3(dj.Manual): + definition = """ + key : int + --- + value : varchar(300) + """ + + +@schema +class NullableNumbers(dj.Manual): + definition = """ + key : int + --- + fvalue = null : float + dvalue = null : double + ivalue = null : int + """ + + +@schema +class TTestExtra(dj.Manual): + """ + clone of Test but with an extra field + """ + + definition = TTest.definition + "\nextra : int # extra int\n" + + +@schema +class TTestNoExtra(dj.Manual): + """ + clone of Test but with no extra fields + """ + + definition = TTest.definition + + +@schema +class Auto(dj.Lookup): + definition = """ + id :int auto_increment + --- + name :varchar(12) + """ + + def fill(self): + if not self: + self.insert([dict(name="Godel"), dict(name="Escher"), dict(name="Bach")]) + + +@schema +class User(dj.Lookup): + definition = """ # lab members + username: varchar(12) + """ + contents = [ + ["Jake"], + ["Cathryn"], + ["Shan"], + ["Fabian"], + ["Edgar"], + ["George"], + ["Dimitri"], + ] + + +@schema +class Subject(dj.Lookup): + definition = """ # Basic information about animal subjects used in experiments + subject_id :int # unique subject id + --- + real_id :varchar(40) # real-world name. Omit if the same as subject_id + species = "mouse" :enum('mouse', 'monkey', 'human') + date_of_birth :date + subject_notes :varchar(4000) + unique index (real_id, species) + """ + + contents = [ + [1551, "1551", "mouse", "2015-04-01", "genetically engineered super mouse"], + [10, "Curious George", "monkey", "2008-06-30", ""], + [1552, "1552", "mouse", "2015-06-15", ""], + [1553, "1553", "mouse", "2016-07-01", ""], + ] + + +@schema +class Language(dj.Lookup): + definition = """ + # languages spoken by some of the developers + # additional comments are ignored + name : varchar(40) # name of the developer + language : varchar(40) # language + """ + contents = [ + ("Fabian", "English"), + ("Edgar", "English"), + ("Dimitri", "English"), + ("Dimitri", "Ukrainian"), + ("Fabian", "German"), + ("Edgar", "Japanese"), + ] + + +@schema +class Experiment(dj.Imported): + definition = """ # information about experiments + -> Subject + experiment_id :smallint # experiment number for this subject + --- + experiment_date :date # date when experiment was started + -> [nullable] User + data_path="" :varchar(255) # file path to recorded data + notes="" :varchar(2048) # e.g. purpose of experiment + entry_time=CURRENT_TIMESTAMP :timestamp # automatic timestamp + """ + + fake_experiments_per_subject = 5 + + def make(self, key): + """ + populate with random data + """ + from datetime import date, timedelta + + users = [None, None] + list(User().fetch()["username"]) + random.seed("Amazing Seed") + self.insert( + dict( + key, + experiment_id=experiment_id, + experiment_date=( + date.today() - timedelta(random.expovariate(1 / 30)) + ).isoformat(), + username=random.choice(users), + ) + for experiment_id in range(self.fake_experiments_per_subject) + ) + + +@schema +class Trial(dj.Imported): + definition = """ # a trial within an experiment + -> Experiment.proj(animal='subject_id') + trial_id :smallint # trial number + --- + start_time :double # (s) + """ + + class Condition(dj.Part): + definition = """ # trial conditions + -> Trial + cond_idx : smallint # condition number + ---- + orientation : float # degrees + """ + + def make(self, key): + """populate with random data (pretend reading from raw files)""" + random.seed("Amazing Seed") + trial = self.Condition() + for trial_id in range(10): + key["trial_id"] = trial_id + self.insert1(dict(key, start_time=random.random() * 1e9)) + trial.insert( + dict(key, cond_idx=cond_idx, orientation=random.random() * 360) + for cond_idx in range(30) + ) + + +@schema +class Ephys(dj.Imported): + definition = """ # some kind of electrophysiological recording + -> Trial + ---- + sampling_frequency :double # (Hz) + duration :decimal(7,3) # (s) + """ + + class Channel(dj.Part): + definition = """ # subtable containing individual channels + -> master + channel :tinyint unsigned # channel number within Ephys + ---- + voltage : longblob + current = null : longblob # optional current to test null handling + """ + + def _make_tuples(self, key): + """ + populate with random data + """ + random.seed(str(key)) + row = dict( + key, sampling_frequency=6000, duration=np.minimum(2, random.expovariate(1)) + ) + self.insert1(row) + number_samples = int(row["duration"] * row["sampling_frequency"] + 0.5) + sub = self.Channel() + sub.insert( + dict( + key, + channel=channel, + voltage=np.float32(np.random.randn(number_samples)), + ) + for channel in range(2) + ) + + +@schema +class Image(dj.Manual): + definition = """ + # table for testing blob inserts + id : int # image identifier + --- + img : longblob # image + """ + + +@schema +class UberTrash(dj.Lookup): + definition = """ + id : int + --- + """ + contents = [(1,)] + + +@schema +class UnterTrash(dj.Lookup): + definition = """ + -> UberTrash + my_id : int + --- + """ + contents = [(1, 1), (1, 2)] + + +@schema +class SimpleSource(dj.Lookup): + definition = """ + id : int # id + """ + contents = ((x,) for x in range(10)) + + +@schema +class SigIntTable(dj.Computed): + definition = """ + -> SimpleSource + """ + + def _make_tuples(self, key): + raise KeyboardInterrupt + + +@schema +class SigTermTable(dj.Computed): + definition = """ + -> SimpleSource + """ + + def make(self, key): + raise SystemExit("SIGTERM received") + + +@schema +class DjExceptionName(dj.Lookup): + definition = """ + dj_exception_name: char(64) + """ + + @property + def contents(self): + return [ + [member_name] + for member_name, member_type in inspect.getmembers(dj.errors) + if inspect.isclass(member_type) and issubclass(member_type, Exception) + ] + + +@schema +class ErrorClass(dj.Computed): + definition = """ + -> DjExceptionName + """ + + def make(self, key): + exception_name = key["dj_exception_name"] + raise getattr(dj.errors, exception_name) + + +@schema +class DecimalPrimaryKey(dj.Lookup): + definition = """ + id : decimal(4,3) + """ + contents = zip((0.1, 0.25, 3.99)) + + +@schema +class IndexRich(dj.Manual): + definition = """ + -> Subject + --- + -> [unique, nullable] User.proj(first="username") + first_date : date + value : int + index (first_date, value) + """ + + +# Schema for issue 656 +@schema +class ThingA(dj.Manual): + definition = """ + a: int + """ + + +@schema +class ThingB(dj.Manual): + definition = """ + b1: int + b2: int + --- + b3: int + """ + + +@schema +class ThingC(dj.Manual): + definition = """ + -> ThingA + --- + -> [unique, nullable] ThingB + """ + + +@schema +class Parent(dj.Lookup): + definition = """ + parent_id: int + --- + name: varchar(30) + """ + contents = [(1, "Joe")] + + +@schema +class Child(dj.Lookup): + definition = """ + -> Parent + child_id: int + --- + name: varchar(30) + """ + contents = [(1, 12, "Dan")] + + +# Related to issue #886 (8), #883 (5) +@schema +class ComplexParent(dj.Lookup): + definition = "\n".join(["parent_id_{}: int".format(i + 1) for i in range(8)]) + contents = [tuple(i for i in range(8))] + + +@schema +class ComplexChild(dj.Lookup): + definition = "\n".join( + ["-> ComplexParent"] + ["child_id_{}: int".format(i + 1) for i in range(1)] + ) + contents = [tuple(i for i in range(9))] + + +@schema +class SubjectA(dj.Lookup): + definition = """ + subject_id: varchar(32) + --- + dob : date + sex : enum('M', 'F', 'U') + """ + contents = [ + ("mouse1", "2020-09-01", "M"), + ("mouse2", "2020-03-19", "F"), + ("mouse3", "2020-08-23", "F"), + ] + + +@schema +class SessionA(dj.Lookup): + definition = """ + -> SubjectA + session_start_time: datetime + --- + session_dir='' : varchar(32) + """ + contents = [ + ("mouse1", "2020-12-01 12:32:34", ""), + ("mouse1", "2020-12-02 12:32:34", ""), + ("mouse1", "2020-12-03 12:32:34", ""), + ("mouse1", "2020-12-04 12:32:34", ""), + ] + + +@schema +class SessionStatusA(dj.Lookup): + definition = """ + -> SessionA + --- + status: enum('in_training', 'trained_1a', 'trained_1b', 'ready4ephys') + """ + contents = [ + ("mouse1", "2020-12-01 12:32:34", "in_training"), + ("mouse1", "2020-12-02 12:32:34", "trained_1a"), + ("mouse1", "2020-12-03 12:32:34", "trained_1b"), + ("mouse1", "2020-12-04 12:32:34", "ready4ephys"), + ] + + +@schema +class SessionDateA(dj.Lookup): + definition = """ + -> SubjectA + session_date: date + """ + contents = [ + ("mouse1", "2020-12-01"), + ("mouse1", "2020-12-02"), + ("mouse1", "2020-12-03"), + ("mouse1", "2020-12-04"), + ] + + +@schema +class Stimulus(dj.Lookup): + definition = """ + id: int + --- + contrast: int + brightness: int + """ + + +@schema +class Longblob(dj.Manual): + definition = """ + id: int + --- + data: longblob + """ diff --git a/tests/schema_advanced.py b/tests/schema_advanced.py new file mode 100644 index 000000000..7580611e2 --- /dev/null +++ b/tests/schema_advanced.py @@ -0,0 +1,147 @@ +import datajoint as dj +from . import PREFIX, CONN_INFO + +schema = dj.Schema(PREFIX + "_advanced", locals(), connection=dj.conn(**CONN_INFO)) + + +@schema +class Person(dj.Manual): + definition = """ + person_id : int + ---- + full_name : varchar(60) + sex : enum('M','F') + """ + + def fill(self): + """ + fill fake names from www.fakenamegenerator.com + """ + self.insert( + ( + (0, "May K. Hall", "F"), + (1, "Jeffrey E. Gillen", "M"), + (2, "Hanna R. Walters", "F"), + (3, "Russel S. James", "M"), + (4, "Robbin J. Fletcher", "F"), + (5, "Wade J. Sullivan", "M"), + (6, "Dorothy J. Chen", "F"), + (7, "Michael L. Kowalewski", "M"), + (8, "Kimberly J. Stringer", "F"), + (9, "Mark G. Hair", "M"), + (10, "Mary R. Thompson", "F"), + (11, "Graham C. Gilpin", "M"), + (12, "Nelda T. Ruggeri", "F"), + (13, "Bryan M. Cummings", "M"), + (14, "Sara C. Le", "F"), + (15, "Myron S. Jaramillo", "M"), + ) + ) + + +@schema +class Parent(dj.Manual): + definition = """ + -> Person + parent_sex : enum('M','F') + --- + -> Person.proj(parent='person_id') + """ + + def fill(self): + def make_parent(pid, parent): + return dict( + person_id=pid, + parent=parent, + parent_sex=(Person & {"person_id": parent}).fetch1("sex"), + ) + + self.insert( + make_parent(*r) + for r in ( + (0, 2), + (0, 3), + (1, 4), + (1, 5), + (2, 4), + (2, 5), + (3, 4), + (3, 7), + (4, 7), + (4, 8), + (5, 9), + (5, 10), + (6, 9), + (6, 10), + (7, 11), + (7, 12), + (8, 11), + (8, 14), + (9, 11), + (9, 12), + (10, 13), + (10, 14), + (11, 14), + (11, 15), + (12, 14), + (12, 15), + ) + ) + + +@schema +class Subject(dj.Manual): + definition = """ + subject : int + --- + -> [unique, nullable] Person + """ + + +@schema +class Prep(dj.Manual): + definition = """ + prep : int + """ + + +@schema +class Slice(dj.Manual): + definition = """ + -> Prep + slice : int + """ + + +@schema +class Cell(dj.Manual): + definition = """ + -> Slice + cell : int + """ + + +@schema +class InputCell(dj.Manual): + definition = """ # a synapse within the slice + -> Cell + -> Cell.proj(input="cell") + """ + + +@schema +class LocalSynapse(dj.Manual): + definition = """ # a synapse within the slice + -> Cell.proj(presynaptic='cell') + -> Cell.proj(postsynaptic='cell') + """ + + +@schema +class GlobalSynapse(dj.Manual): + # Mix old-style and new-style projected foreign keys + definition = """ + # a synapse within the slice + -> Cell.proj(pre_slice="slice", pre_cell="cell") + -> Cell.proj(post_slice="slice", post_cell="cell") + """ diff --git a/tests/schema_simple.py b/tests/schema_simple.py new file mode 100644 index 000000000..78f64d036 --- /dev/null +++ b/tests/schema_simple.py @@ -0,0 +1,279 @@ +""" +A simple, abstract schema to test relational algebra +""" +import random +import datajoint as dj +import itertools +import hashlib +import uuid +import faker +from . import PREFIX, CONN_INFO +import numpy as np +from datetime import date, timedelta + +schema = dj.Schema(PREFIX + "_relational", locals(), connection=dj.conn(**CONN_INFO)) + + +@schema +class IJ(dj.Lookup): + definition = """ # tests restrictions + i : int + j : int + """ + contents = list(dict(i=i, j=j + 2) for i in range(3) for j in range(3)) + + +@schema +class JI(dj.Lookup): + definition = """ # tests restrictions by relations when attributes are reordered + j : int + i : int + """ + contents = list(dict(i=i + 1, j=j) for i in range(3) for j in range(3)) + + +@schema +class A(dj.Lookup): + definition = """ + id_a :int + --- + cond_in_a :tinyint + """ + contents = [(i, i % 4 > i % 3) for i in range(10)] + + +@schema +class B(dj.Computed): + definition = """ + -> A + id_b :int + --- + mu :float # mean value + sigma :float # standard deviation + n :smallint # number samples + """ + + class C(dj.Part): + definition = """ + -> B + id_c :int + --- + value :float # normally distributed variables according to parameters in B + """ + + def make(self, key): + random.seed(str(key)) + sub = B.C() + for i in range(4): + key["id_b"] = i + mu = random.normalvariate(0, 10) + sigma = random.lognormvariate(0, 4) + n = random.randint(0, 10) + self.insert1(dict(key, mu=mu, sigma=sigma, n=n)) + sub.insert( + dict(key, id_c=j, value=random.normalvariate(mu, sigma)) + for j in range(n) + ) + + +@schema +class L(dj.Lookup): + definition = """ + id_l: int + --- + cond_in_l :tinyint + """ + contents = [(i, i % 3 >= i % 5) for i in range(30)] + + +@schema +class D(dj.Computed): + definition = """ + -> A + id_d :int + --- + -> L + """ + + def _make_tuples(self, key): + # make reference to a random tuple from L + random.seed(str(key)) + lookup = list(L().fetch("KEY")) + self.insert(dict(key, id_d=i, **random.choice(lookup)) for i in range(4)) + + +@schema +class E(dj.Computed): + definition = """ + -> B + -> D + --- + -> L + """ + + class F(dj.Part): + definition = """ + -> E + id_f :int + --- + -> B.C + """ + + def make(self, key): + random.seed(str(key)) + self.insert1(dict(key, **random.choice(list(L().fetch("KEY"))))) + sub = E.F() + references = list((B.C() & key).fetch("KEY")) + random.shuffle(references) + sub.insert( + dict(key, id_f=i, **ref) + for i, ref in enumerate(references) + if random.getrandbits(1) + ) + + +@schema +class F(dj.Manual): + definition = """ + id: int + ---- + date=null: date + """ + + +@schema +class DataA(dj.Lookup): + definition = """ + idx : int + --- + a : int + """ + contents = list(zip(range(5), range(5))) + + +@schema +class DataB(dj.Lookup): + definition = """ + idx : int + --- + a : int + """ + contents = list(zip(range(5), range(5, 10))) + + +@schema +class Website(dj.Lookup): + definition = """ + url_hash : uuid + --- + url : varchar(1000) + """ + + def insert1_url(self, url): + hashed = hashlib.sha1() + hashed.update(url.encode()) + url_hash = uuid.UUID(bytes=hashed.digest()[:16]) + self.insert1(dict(url=url, url_hash=url_hash), skip_duplicates=True) + return url_hash + + +@schema +class Profile(dj.Manual): + definition = """ + ssn : char(11) + --- + name : varchar(70) + residence : varchar(255) + blood_group : enum('A+', 'A-', 'AB+', 'AB-', 'B+', 'B-', 'O+', 'O-') + username : varchar(120) + birthdate : date + job : varchar(120) + sex : enum('M', 'F') + """ + + class Website(dj.Part): + definition = """ + -> master + -> Website + """ + + def populate_random(self, n=10): + fake = faker.Faker() + faker.Faker.seed(0) # make test deterministic + for _ in range(n): + profile = fake.profile() + with self.connection.transaction: + self.insert1(profile, ignore_extra_fields=True) + for url in profile["website"]: + self.Website().insert1( + dict(ssn=profile["ssn"], url_hash=Website().insert1_url(url)) + ) + + +@schema +class TTestUpdate(dj.Lookup): + definition = """ + primary_key : int + --- + string_attr : varchar(255) + num_attr=null : float + blob_attr : longblob + """ + + contents = [ + (0, "my_string", 0.0, np.random.randn(10, 2)), + (1, "my_other_string", 1.0, np.random.randn(20, 1)), + ] + + +@schema +class ArgmaxTest(dj.Lookup): + definition = """ + primary_key : int + --- + secondary_key : char(2) + val : float + """ + + n = 10 + + @property + def contents(self): + n = self.n + yield from zip( + range(n**2), + itertools.chain(*itertools.repeat(tuple(map(chr, range(100, 100 + n))), n)), + np.random.rand(n**2), + ) + + +@schema +class ReservedWord(dj.Manual): + definition = """ + # Test of SQL reserved words + key : int + --- + in : varchar(25) + from : varchar(25) + int : int + select : varchar(25) + """ + + +@schema +class OutfitLaunch(dj.Lookup): + definition = """ + # Monthly released designer outfits + release_id: int + --- + day: date + """ + contents = [(0, date.today() - timedelta(days=15))] + + class OutfitPiece(dj.Part, dj.Lookup): + definition = """ + # Outfit piece associated with outfit + -> OutfitLaunch + piece: varchar(20) + """ + contents = [(0, "jeans"), (0, "sneakers"), (0, "polo")] diff --git a/tests_old/test_blob.py b/tests/test_blob.py similarity index 73% rename from tests_old/test_blob.py rename to tests/test_blob.py index 3765edc57..562d78f2b 100644 --- a/tests_old/test_blob.py +++ b/tests/test_blob.py @@ -7,15 +7,7 @@ from datetime import datetime from datajoint.blob import pack, unpack from numpy.testing import assert_array_equal -from nose.tools import ( - assert_equal, - assert_true, - assert_false, - assert_list_equal, - assert_set_equal, - assert_tuple_equal, - assert_dict_equal, -) +from pytest import approx def test_pack(): @@ -24,19 +16,19 @@ def test_pack(): -3.7e-2, np.float64(3e31), -np.inf, - np.int8(-3), - np.uint8(-1), + np.array(-3).astype(np.uint8), + np.array(-1).astype(np.uint8), np.int16(-33), - np.uint16(-33), + np.array(-33).astype(np.uint16), np.int32(-3), - np.uint32(-1), + np.array(-1).astype(np.uint32), np.int64(373), - np.uint64(-3), + np.array(-3).astype(np.uint64), ): - assert_equal(x, unpack(pack(x)), "Scalars don't match!") + assert x == approx(unpack(pack(x)), rel=1e-6), "Scalars don't match!" x = np.nan - assert_true(np.isnan(unpack(pack(x))), "nan scalar did not match!") + assert np.isnan(unpack(pack(x))), "nan scalar did not match!" x = np.random.randn(8, 10) assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") @@ -45,7 +37,7 @@ def test_pack(): assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") x = 7j - assert_equal(x, unpack(pack(x)), "Complex scalar does not match") + assert x == unpack(pack(x)), "Complex scalar does not match" x = np.float32(np.random.randn(3, 4, 5)) assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") @@ -54,41 +46,37 @@ def test_pack(): assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") x = None - assert_true(unpack(pack(x)) is None, "None did not match") + assert unpack(pack(x)) is None, "None did not match" x = -255 y = unpack(pack(x)) - assert_true( - x == y and isinstance(y, int) and not isinstance(y, np.ndarray), - "Scalar int did not match", - ) + assert ( + x == y and isinstance(y, int) and not isinstance(y, np.ndarray) + ), "Scalar int did not match" x = -25523987234234287910987234987098245697129798713407812347 y = unpack(pack(x)) - assert_true( - x == y and isinstance(y, int) and not isinstance(y, np.ndarray), - "Unbounded int did not match", - ) + assert ( + x == y and isinstance(y, int) and not isinstance(y, np.ndarray) + ), "Unbounded int did not match" x = 7.0 y = unpack(pack(x)) - assert_true( - x == y and isinstance(y, float) and not isinstance(y, np.ndarray), - "Scalar float did not match", - ) + assert ( + x == y and isinstance(y, float) and not isinstance(y, np.ndarray) + ), "Scalar float did not match" x = 7j y = unpack(pack(x)) - assert_true( - x == y and isinstance(y, complex) and not isinstance(y, np.ndarray), - "Complex scalar did not match", - ) + assert ( + x == y and isinstance(y, complex) and not isinstance(y, np.ndarray) + ), "Complex scalar did not match" x = True - assert_true(unpack(pack(x)) is True, "Scalar bool did not match") + assert unpack(pack(x)) is True, "Scalar bool did not match" x = [None] - assert_list_equal(x, unpack(pack(x))) + assert [None] == unpack(pack(x)) x = { "name": "Anonymous", @@ -98,22 +86,22 @@ def test_pack(): (11, 12): None, } y = unpack(pack(x)) - assert_dict_equal(x, y, "Dict do not match!") - assert_false( - isinstance(["range"][0], np.ndarray), "Scalar int was coerced into array." - ) + assert x == y, "Dict do not match!" + assert not isinstance( + ["range"][0], np.ndarray + ), "Scalar int was coerced into array." x = uuid.uuid4() - assert_equal(x, unpack(pack(x)), "UUID did not match") + assert x == unpack(pack(x)), "UUID did not match" x = Decimal("-112122121.000003000") - assert_equal(x, unpack(pack(x)), "Decimal did not pack/unpack correctly") + assert x == unpack(pack(x)), "Decimal did not pack/unpack correctly" x = [1, datetime.now(), {1: "one", "two": 2}, (1, 2)] - assert_list_equal(x, unpack(pack(x)), "List did not pack/unpack correctly") + assert x == unpack(pack(x)), "List did not pack/unpack correctly" x = (1, datetime.now(), {1: "one", "two": 2}, (uuid.uuid4(), 2)) - assert_tuple_equal(x, unpack(pack(x)), "Tuple did not pack/unpack correctly") + assert x == unpack(pack(x)), "Tuple did not pack/unpack correctly" x = ( 1, @@ -121,36 +109,34 @@ def test_pack(): {"yes!": [1, 2, np.array((3, 4))]}, ) y = unpack(pack(x)) - assert_dict_equal(x[1], y[1]) + assert x[1] == y[1] assert_array_equal(x[2]["yes!"][2], y[2]["yes!"][2]) x = {"elephant"} - assert_set_equal(x, unpack(pack(x)), "Set did not pack/unpack correctly") + assert x == unpack(pack(x)), "Set did not pack/unpack correctly" x = tuple(range(10)) - assert_tuple_equal( - x, unpack(pack(range(10))), "Iterator did not pack/unpack correctly" - ) + assert x == unpack(pack(range(10))), "Iterator did not pack/unpack correctly" x = Decimal("1.24") - assert_true(x == unpack(pack(x)), "Decimal object did not pack/unpack correctly") + assert x == approx(unpack(pack(x))), "Decimal object did not pack/unpack correctly" x = datetime.now() - assert_true(x == unpack(pack(x)), "Datetime object did not pack/unpack correctly") + assert x == unpack(pack(x)), "Datetime object did not pack/unpack correctly" x = np.bool_(True) - assert_true(x == unpack(pack(x)), "Numpy bool object did not pack/unpack correctly") + assert x == unpack(pack(x)), "Numpy bool object did not pack/unpack correctly" x = "test" - assert_true(x == unpack(pack(x)), "String object did not pack/unpack correctly") + assert x == unpack(pack(x)), "String object did not pack/unpack correctly" x = np.array(["yes"]) - assert_true( - x == unpack(pack(x)), "Numpy string array object did not pack/unpack correctly" - ) + assert x == unpack( + pack(x) + ), "Numpy string array object did not pack/unpack correctly" x = np.datetime64("1998").astype("datetime64[us]") - assert_true(x == unpack(pack(x))) + assert x == unpack(pack(x)) def test_recarrays(): diff --git a/tests_old/test_blob_matlab.py b/tests/test_blob_matlab.py similarity index 83% rename from tests_old/test_blob_matlab.py rename to tests/test_blob_matlab.py index 6104c9291..ecb698fec 100644 --- a/tests_old/test_blob_matlab.py +++ b/tests/test_blob_matlab.py @@ -1,8 +1,6 @@ import numpy as np import datajoint as dj from datajoint.blob import pack, unpack - -from nose.tools import assert_equal, assert_true, assert_tuple_equal, assert_false from numpy.testing import assert_array_equal from . import PREFIX, CONN_INFO @@ -58,7 +56,8 @@ def insert_blobs(): class TestFetch: @classmethod def setup_class(cls): - assert_false(dj.config["safemode"], "safemode must be disabled") + dj.config["safemode"] = False # temp + assert not dj.config["safemode"], "safemode must be disabled" Blob().delete() insert_blobs() @@ -70,43 +69,43 @@ def test_complex_matlab_blobs(): blobs = Blob().fetch("blob", order_by="KEY") blob = blobs[0] # 'simple string' 'character string' - assert_equal(blob[0], "character string") + assert blob[0] == "character string" blob = blobs[1] # '1D vector' 1:15:180 assert_array_equal(blob, np.r_[1:180:15][None, :]) assert_array_equal(blob, unpack(pack(blob))) blob = blobs[2] # 'string array' {'string1' 'string2'} - assert_true(isinstance(blob, dj.MatCell)) + assert isinstance(blob, dj.MatCell) assert_array_equal(blob, np.array([["string1", "string2"]])) assert_array_equal(blob, unpack(pack(blob))) blob = blobs[ 3 ] # 'struct array' struct('a', {1,2}, 'b', {struct('c', magic(3)), struct('C', magic(5))}) - assert_true(isinstance(blob, dj.MatStruct)) - assert_tuple_equal(blob.dtype.names, ("a", "b")) + assert isinstance(blob, dj.MatStruct) + assert tuple(blob.dtype.names) == ("a", "b") assert_array_equal(blob.a[0, 0], np.array([[1.0]])) assert_array_equal(blob.a[0, 1], np.array([[2.0]])) - assert_true(isinstance(blob.b[0, 1], dj.MatStruct)) - assert_tuple_equal(blob.b[0, 1].C[0, 0].shape, (5, 5)) + assert isinstance(blob.b[0, 1], dj.MatStruct) + assert tuple(blob.b[0, 1].C[0, 0].shape) == (5, 5) b = unpack(pack(blob)) assert_array_equal(b[0, 0].b[0, 0].c, blob[0, 0].b[0, 0].c) assert_array_equal(b[0, 1].b[0, 0].C, blob[0, 1].b[0, 0].C) blob = blobs[4] # '3D double array' reshape(1:24, [2,3,4]) assert_array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F")) - assert_true(blob.dtype == "float64") + assert blob.dtype == "float64" assert_array_equal(blob, unpack(pack(blob))) blob = blobs[5] # reshape(uint8(1:24), [2,3,4]) - assert_true(np.array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F"))) - assert_true(blob.dtype == "uint8") + assert np.array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F")) + assert blob.dtype == "uint8" assert_array_equal(blob, unpack(pack(blob))) blob = blobs[6] # fftn(reshape(1:24, [2,3,4])) - assert_tuple_equal(blob.shape, (2, 3, 4)) - assert_true(blob.dtype == "complex128") + assert tuple(blob.shape) == (2, 3, 4) + assert blob.dtype == "complex128" assert_array_equal(blob, unpack(pack(blob))) @staticmethod @@ -117,7 +116,7 @@ def test_complex_matlab_squeeze(): blob = (Blob & "id=1").fetch1( "blob", squeeze=True ) # 'simple string' 'character string' - assert_equal(blob, "character string") + assert blob == "character string" blob = (Blob & "id=2").fetch1( "blob", squeeze=True @@ -127,14 +126,14 @@ def test_complex_matlab_squeeze(): blob = (Blob & "id=3").fetch1( "blob", squeeze=True ) # 'string array' {'string1' 'string2'} - assert_true(isinstance(blob, dj.MatCell)) + assert isinstance(blob, dj.MatCell) assert_array_equal(blob, np.array(["string1", "string2"])) blob = (Blob & "id=4").fetch1( "blob", squeeze=True ) # 'struct array' struct('a', {1,2}, 'b', {struct('c', magic(3)), struct('C', magic(5))}) - assert_true(isinstance(blob, dj.MatStruct)) - assert_tuple_equal(blob.dtype.names, ("a", "b")) + assert isinstance(blob, dj.MatStruct) + assert tuple(blob.dtype.names) == ("a", "b") assert_array_equal( blob.a, np.array( @@ -144,32 +143,31 @@ def test_complex_matlab_squeeze(): ] ), ) - assert_true(isinstance(blob[1].b, dj.MatStruct)) - assert_tuple_equal(blob[1].b.C.item().shape, (5, 5)) + assert isinstance(blob[1].b, dj.MatStruct) + assert tuple(blob[1].b.C.item().shape) == (5, 5) blob = (Blob & "id=5").fetch1( "blob", squeeze=True ) # '3D double array' reshape(1:24, [2,3,4]) - assert_true(np.array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F"))) - assert_true(blob.dtype == "float64") + assert np.array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F")) + assert blob.dtype == "float64" blob = (Blob & "id=6").fetch1( "blob", squeeze=True ) # reshape(uint8(1:24), [2,3,4]) - assert_true(np.array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F"))) - assert_true(blob.dtype == "uint8") + assert np.array_equal(blob, np.r_[1:25].reshape((2, 3, 4), order="F")) + assert blob.dtype == "uint8" blob = (Blob & "id=7").fetch1( "blob", squeeze=True ) # fftn(reshape(1:24, [2,3,4])) - assert_tuple_equal(blob.shape, (2, 3, 4)) - assert_true(blob.dtype == "complex128") + assert tuple(blob.shape) == (2, 3, 4) + assert blob.dtype == "complex128" - @staticmethod - def test_iter(): + def test_iter(self): """ test iterator over the entity set """ from_iter = {d["id"]: d for d in Blob()} - assert_equal(len(from_iter), len(Blob())) - assert_equal(from_iter[1]["blob"], "character string") + assert len(from_iter) == len(Blob()) + assert from_iter[1]["blob"] == "character string" diff --git a/tests_old/test_dependencies.py b/tests/test_dependencies.py similarity index 64% rename from tests_old/test_dependencies.py rename to tests/test_dependencies.py index c359b602a..1e8b1da41 100644 --- a/tests_old/test_dependencies.py +++ b/tests/test_dependencies.py @@ -1,57 +1,54 @@ -from nose.tools import assert_true, raises, assert_list_equal +import datajoint as dj +from datajoint import errors +from pytest import raises + from .schema import * from datajoint.dependencies import unite_master_parts def test_unite_master_parts(): - assert_list_equal( - unite_master_parts( - [ - "`s`.`a`", - "`s`.`a__q`", - "`s`.`b`", - "`s`.`c`", - "`s`.`c__q`", - "`s`.`b__q`", - "`s`.`d`", - "`s`.`a__r`", - ] - ), + assert unite_master_parts( [ "`s`.`a`", "`s`.`a__q`", - "`s`.`a__r`", "`s`.`b`", - "`s`.`b__q`", "`s`.`c`", "`s`.`c__q`", + "`s`.`b__q`", "`s`.`d`", - ], - ) - assert_list_equal( - unite_master_parts( - [ - "`lab`.`#equipment`", - "`cells`.`cell_analysis_method`", - "`cells`.`cell_analysis_method_task_type`", - "`cells`.`cell_analysis_method_users`", - "`cells`.`favorite_selection`", - "`cells`.`cell_analysis_method__cell_selection_params`", - "`lab`.`#equipment__config`", - "`cells`.`cell_analysis_method__field_detect_params`", - ] - ), + "`s`.`a__r`", + ] + ) == [ + "`s`.`a`", + "`s`.`a__q`", + "`s`.`a__r`", + "`s`.`b`", + "`s`.`b__q`", + "`s`.`c`", + "`s`.`c__q`", + "`s`.`d`", + ] + assert unite_master_parts( [ "`lab`.`#equipment`", - "`lab`.`#equipment__config`", "`cells`.`cell_analysis_method`", - "`cells`.`cell_analysis_method__cell_selection_params`", - "`cells`.`cell_analysis_method__field_detect_params`", "`cells`.`cell_analysis_method_task_type`", "`cells`.`cell_analysis_method_users`", "`cells`.`favorite_selection`", - ], - ) + "`cells`.`cell_analysis_method__cell_selection_params`", + "`lab`.`#equipment__config`", + "`cells`.`cell_analysis_method__field_detect_params`", + ] + ) == [ + "`lab`.`#equipment`", + "`lab`.`#equipment__config`", + "`cells`.`cell_analysis_method`", + "`cells`.`cell_analysis_method__cell_selection_params`", + "`cells`.`cell_analysis_method__field_detect_params`", + "`cells`.`cell_analysis_method_task_type`", + "`cells`.`cell_analysis_method_users`", + "`cells`.`favorite_selection`", + ] def test_nullable_dependency(): @@ -80,10 +77,9 @@ def test_nullable_dependency(): c.insert1(dict(a=3, b1=1, b2=1)) c.insert1(dict(a=4, b1=1, b2=2)) - assert_true(len(c) == len(c.fetch()) == 5) + assert len(c) == len(c.fetch()) == 5 -@raises(dj.errors.DuplicateError) def test_unique_dependency(): """test nullable unique foreign key""" @@ -104,4 +100,5 @@ def test_unique_dependency(): c.insert1(dict(a=0, b1=1, b2=1)) # duplicate foreign key attributes = not ok - c.insert1(dict(a=1, b1=1, b2=1)) + with raises(errors.DuplicateError): + c.insert1(dict(a=1, b1=1, b2=1)) diff --git a/tests/test_erd.py b/tests/test_erd.py new file mode 100644 index 000000000..991410995 --- /dev/null +++ b/tests/test_erd.py @@ -0,0 +1,76 @@ +import datajoint as dj +from .schema_simple import A, B, D, E, L, schema, OutfitLaunch +from . import schema_advanced + +namespace = locals() + + +class TestERD: + @staticmethod + def setup_method(): + """ + class-level test setup. Executes before each test method. + """ + + @staticmethod + def test_decorator(): + assert issubclass(A, dj.Lookup) + assert not issubclass(A, dj.Part) + assert B.database == schema.database + assert issubclass(B.C, dj.Part) + assert B.C.database == schema.database + assert B.C.master is B and E.F.master is E + + @staticmethod + def test_dependencies(): + deps = schema.connection.dependencies + deps.load() + assert all(cls.full_table_name in deps for cls in (A, B, B.C, D, E, E.F, L)) + assert set(A().children()) == set([B.full_table_name, D.full_table_name]) + assert set(D().parents(primary=True)) == set([A.full_table_name]) + assert set(D().parents(primary=False)) == set([L.full_table_name]) + assert set(deps.descendants(L.full_table_name)).issubset( + cls.full_table_name for cls in (L, D, E, E.F) + ) + + @staticmethod + def test_erd(): + assert dj.diagram.diagram_active, "Failed to import networkx and pydot" + erd = dj.ERD(schema, context=namespace) + graph = erd._make_graph() + assert set(cls.__name__ for cls in (A, B, D, E, L)).issubset(graph.nodes()) + + @staticmethod + def test_erd_algebra(): + erd0 = dj.ERD(B) + erd1 = erd0 + 3 + erd2 = dj.Di(E) - 3 + erd3 = erd1 * erd2 + erd4 = (erd0 + E).add_parts() - B - E + assert erd0.nodes_to_show == set(cls.full_table_name for cls in [B]) + assert erd1.nodes_to_show == set( + cls.full_table_name for cls in (B, B.C, E, E.F) + ) + assert erd2.nodes_to_show == set(cls.full_table_name for cls in (A, B, D, E, L)) + assert erd3.nodes_to_show == set(cls.full_table_name for cls in (B, E)) + assert erd4.nodes_to_show == set(cls.full_table_name for cls in (B.C, E.F)) + + @staticmethod + def test_repr_svg(): + erd = dj.ERD(schema_advanced, context=namespace) + svg = erd._repr_svg_() + assert svg.startswith("") + + @staticmethod + def test_make_image(): + erd = dj.ERD(schema, context=namespace) + img = erd.make_image() + assert img.ndim == 3 and img.shape[2] in (3, 4) + + @staticmethod + def test_part_table_parsing(): + # https://github.com/datajoint/datajoint-python/issues/882 + erd = dj.Di(schema) + graph = erd._make_graph() + assert "OutfitLaunch" in graph.nodes() + assert "OutfitLaunch.OutfitPiece" in graph.nodes() diff --git a/tests_old/test_foreign_keys.py b/tests/test_foreign_keys.py similarity index 72% rename from tests_old/test_foreign_keys.py rename to tests/test_foreign_keys.py index d082960e4..05d87c041 100644 --- a/tests_old/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -1,4 +1,3 @@ -from nose.tools import assert_equal, assert_false, assert_true from datajoint.declare import declare from . import schema_advanced @@ -8,18 +7,16 @@ def test_aliased_fk(): person = schema_advanced.Person() parent = schema_advanced.Parent() person.delete() - assert_false(person) - assert_false(parent) + assert not person + assert not parent person.fill() parent.fill() - assert_true(person) - assert_true(parent) + assert person + assert parent link = person.proj(parent_name="full_name", parent="person_id") parents = person * parent * link parents &= dict(full_name="May K. Hall") - assert_equal( - set(parents.fetch("parent_name")), {"Hanna R. Walters", "Russel S. James"} - ) + assert set(parents.fetch("parent_name")) == {"Hanna R. Walters", "Russel S. James"} delete_count = person.delete() assert delete_count == 16 @@ -33,19 +30,19 @@ def test_describe(): )[0].split("\n") s2 = declare(rel.full_table_name, describe, globals())[0].split("\n") for c1, c2 in zip(s1, s2): - assert_equal(c1, c2) + assert c1 == c2 def test_delete(): person = schema_advanced.Person() parent = schema_advanced.Parent() person.delete() - assert_false(person) - assert_false(parent) + assert not person + assert not parent person.fill() parent.fill() - assert_true(parent) + assert parent original_len = len(parent) to_delete = len(parent & "11 in (person_id, parent)") (person & "person_id=11").delete() - assert_true(to_delete and len(parent) == original_len - to_delete) + assert to_delete and len(parent) == original_len - to_delete diff --git a/tests_old/test_groupby.py b/tests/test_groupby.py similarity index 100% rename from tests_old/test_groupby.py rename to tests/test_groupby.py diff --git a/tests/test_hash.py b/tests/test_hash.py new file mode 100644 index 000000000..a88c45316 --- /dev/null +++ b/tests/test_hash.py @@ -0,0 +1,6 @@ +from datajoint import hash + + +def test_hash(): + assert hash.uuid_from_buffer(b"abc").hex == "900150983cd24fb0d6963f7d28e17f72" + assert hash.uuid_from_buffer(b"").hex == "d41d8cd98f00b204e9800998ecf8427e" diff --git a/tests_old/test_json.py b/tests/test_json.py similarity index 98% rename from tests_old/test_json.py rename to tests/test_json.py index b9b13e4ee..760475a1a 100644 --- a/tests_old/test_json.py +++ b/tests/test_json.py @@ -2,12 +2,10 @@ from datajoint.declare import declare import datajoint as dj import numpy as np -from distutils.version import LooseVersion +from packaging.version import Version from . import PREFIX -if LooseVersion(dj.conn().query("select @@version;").fetchone()[0]) >= LooseVersion( - "8.0.0" -): +if Version(dj.conn().query("select @@version;").fetchone()[0]) >= Version("8.0.0"): schema = dj.Schema(PREFIX + "_json") Team = None diff --git a/tests_old/test_log.py b/tests/test_log.py similarity index 69% rename from tests_old/test_log.py rename to tests/test_log.py index 86a48bc37..a3aafa992 100644 --- a/tests_old/test_log.py +++ b/tests/test_log.py @@ -1,4 +1,3 @@ -from nose.tools import assert_true from . import schema @@ -6,4 +5,4 @@ def test_log(): ts, events = (schema.schema.log & 'event like "Declared%%"').fetch( "timestamp", "event" ) - assert_true(len(ts) >= 2) + assert len(ts) >= 2 diff --git a/tests_old/test_nan.py b/tests/test_nan.py similarity index 73% rename from tests_old/test_nan.py rename to tests/test_nan.py index b06848fdf..ad4e6239e 100644 --- a/tests_old/test_nan.py +++ b/tests/test_nan.py @@ -1,5 +1,4 @@ import numpy as np -from nose.tools import assert_true import datajoint as dj from . import PREFIX, CONN_INFO @@ -28,15 +27,10 @@ def setup_class(cls): def test_insert_nan(self): """Test fetching of null values""" b = self.rel.fetch("value", order_by="id") - assert_true( - (np.isnan(self.a) == np.isnan(b)).all(), "incorrect handling of Nans" - ) - assert_true( - np.allclose( - self.a[np.logical_not(np.isnan(self.a))], b[np.logical_not(np.isnan(b))] - ), - "incorrect storage of floats", - ) + (np.isnan(self.a) == np.isnan(b)).all(), "incorrect handling of Nans" + np.allclose( + self.a[np.logical_not(np.isnan(self.a))], b[np.logical_not(np.isnan(b))] + ), "incorrect storage of floats" def test_nulls_do_not_affect_primary_keys(self): """Test against a case that previously caused a bug when skipping existing entries.""" diff --git a/tests_old/test_plugin.py b/tests/test_plugin.py similarity index 100% rename from tests_old/test_plugin.py rename to tests/test_plugin.py diff --git a/tests_old/test_relation_u.py b/tests/test_relation_u.py similarity index 52% rename from tests_old/test_relation_u.py rename to tests/test_relation_u.py index ff30711b3..44033708d 100644 --- a/tests_old/test_relation_u.py +++ b/tests/test_relation_u.py @@ -1,6 +1,6 @@ -from nose.tools import assert_equal, assert_true, raises, assert_list_equal -from . import schema, schema_simple import datajoint as dj +from pytest import raises +from . import schema, schema_simple class TestU: @@ -23,37 +23,35 @@ def setup_class(cls): def test_restriction(self): language_set = {s[1] for s in self.language.contents} rel = dj.U("language") & self.language - assert_list_equal(rel.heading.names, ["language"]) - assert_true(len(rel) == len(language_set)) - assert_true(set(rel.fetch("language")) == language_set) + assert list(rel.heading.names) == ["language"] + assert len(rel) == len(language_set) + assert set(rel.fetch("language")) == language_set # Test for issue #342 rel = self.trial * dj.U("start_time") - assert_list_equal(rel.primary_key, self.trial.primary_key + ["start_time"]) - assert_list_equal(rel.primary_key, (rel & "trial_id>3").primary_key) - assert_list_equal((dj.U("start_time") & self.trial).primary_key, ["start_time"]) + assert list(rel.primary_key) == self.trial.primary_key + ["start_time"] + assert list(rel.primary_key) == list((rel & "trial_id>3").primary_key) + assert list((dj.U("start_time") & self.trial).primary_key) == ["start_time"] - @staticmethod - @raises(dj.DataJointError) - def test_invalid_restriction(): - result = dj.U("color") & dict(color="red") + def test_invalid_restriction(self): + with raises(dj.DataJointError): + result = dj.U("color") & dict(color="red") def test_ineffective_restriction(self): rel = self.language & dj.U("language") - assert_true(rel.make_sql() == self.language.make_sql()) + assert rel.make_sql() == self.language.make_sql() def test_join(self): rel = self.experiment * dj.U("experiment_date") - assert_equal(self.experiment.primary_key, ["subject_id", "experiment_id"]) - assert_equal(rel.primary_key, self.experiment.primary_key + ["experiment_date"]) + assert self.experiment.primary_key == ["subject_id", "experiment_id"] + assert rel.primary_key == self.experiment.primary_key + ["experiment_date"] rel = dj.U("experiment_date") * self.experiment - assert_equal(self.experiment.primary_key, ["subject_id", "experiment_id"]) - assert_equal(rel.primary_key, self.experiment.primary_key + ["experiment_date"]) + assert self.experiment.primary_key == ["subject_id", "experiment_id"] + assert rel.primary_key == self.experiment.primary_key + ["experiment_date"] - @staticmethod - @raises(dj.DataJointError) - def test_invalid_join(): - rel = dj.U("language") * dict(language="English") + def test_invalid_join(self): + with raises(dj.DataJointError): + rel = dj.U("language") * dict(language="English") def test_repr_without_attrs(self): """test dj.U() display""" @@ -64,25 +62,24 @@ def test_aggregations(self): lang = schema.Language() # test total aggregation on expression object n1 = dj.U().aggr(lang, n="count(*)").fetch1("n") - assert_equal(n1, len(lang.fetch())) + assert n1 == len(lang.fetch()) # test total aggregation on expression class n2 = dj.U().aggr(schema.Language, n="count(*)").fetch1("n") - assert_equal(n1, n2) + assert n1 == n2 rel = dj.U("language").aggr(schema.Language, number_of_speakers="count(*)") - assert_equal(len(rel), len(set(l[1] for l in schema.Language.contents))) - assert_equal((rel & 'language="English"').fetch1("number_of_speakers"), 3) + assert len(rel) == len(set(l[1] for l in schema.Language.contents)) + assert (rel & 'language="English"').fetch1("number_of_speakers") == 3 def test_argmax(self): rel = schema.TTest() - # get the tuples corresponding to maximum value + # get the tuples corresponding to the maximum value mx = (rel * dj.U().aggr(rel, mx="max(value)")) & "mx=value" - assert_equal(mx.fetch("value")[0], max(rel.fetch("value"))) + assert mx.fetch("value")[0] == max(rel.fetch("value")) def test_aggr(self): rel = schema_simple.ArgmaxTest() amax1 = (dj.U("val") * rel) & dj.U("secondary_key").aggr(rel, val="min(val)") amax2 = (dj.U("val") * rel) * dj.U("secondary_key").aggr(rel, val="min(val)") - assert_true( - len(amax1) == len(amax2) == rel.n, - "Aggregated argmax with join and restriction does not yield same length.", - ) + assert ( + len(amax1) == len(amax2) == rel.n + ), "Aggregated argmax with join and restriction does not yield the same length." diff --git a/tests_old/test_schema_keywords.py b/tests/test_schema_keywords.py similarity index 67% rename from tests_old/test_schema_keywords.py rename to tests/test_schema_keywords.py index 49f380f57..1853852ed 100644 --- a/tests_old/test_schema_keywords.py +++ b/tests/test_schema_keywords.py @@ -1,7 +1,5 @@ from . import PREFIX, CONN_INFO import datajoint as dj -from nose.tools import assert_true - schema = dj.Schema(PREFIX + "_keywords", connection=dj.conn(**CONN_INFO)) @@ -39,8 +37,8 @@ class D(B): def test_inherited_part_table(): - assert_true("a_id" in D().heading.attributes) - assert_true("b_id" in D().heading.attributes) - assert_true("a_id" in D.C().heading.attributes) - assert_true("b_id" in D.C().heading.attributes) - assert_true("name" in D.C().heading.attributes) + assert "a_id" in D().heading.attributes + assert "b_id" in D().heading.attributes + assert "a_id" in D.C().heading.attributes + assert "b_id" in D.C().heading.attributes + assert "name" in D.C().heading.attributes diff --git a/tests_old/test_settings.py b/tests/test_settings.py similarity index 69% rename from tests_old/test_settings.py rename to tests/test_settings.py index 63c3dad36..b937d5ad3 100644 --- a/tests_old/test_settings.py +++ b/tests/test_settings.py @@ -1,8 +1,8 @@ import pprint import random import string -from datajoint import settings -from nose.tools import assert_true, assert_equal, raises +import pytest +from datajoint import DataJointError, settings import datajoint as dj import os @@ -14,7 +14,7 @@ def test_load_save(): dj.config.save("tmp.json") conf = settings.Config() conf.load("tmp.json") - assert_true(conf == dj.config, "Two config files do not match.") + assert conf == dj.config os.remove("tmp.json") @@ -25,7 +25,7 @@ def test_singleton(): conf.load("tmp.json") conf["dummy.val"] = 2 - assert_true(conf == dj.config, "Config does not behave like a singleton.") + assert conf == dj.config os.remove("tmp.json") @@ -34,36 +34,36 @@ def test_singleton2(): conf = settings.Config() conf["dummy.val"] = 2 _ = settings.Config() # a new instance should not delete dummy.val - assert_true(conf["dummy.val"] == 2, "Config does not behave like a singleton.") + assert conf["dummy.val"] == 2 -@raises(dj.DataJointError) def test_validator(): """Testing validator""" - dj.config["database.port"] = "harbor" + with pytest.raises(DataJointError): + dj.config["database.port"] = "harbor" def test_del(): """Testing del""" dj.config["peter"] = 2 - assert_true("peter" in dj.config) + assert "peter" in dj.config del dj.config["peter"] - assert_true("peter" not in dj.config) + assert "peter" not in dj.config def test_len(): """Testing len""" - assert_equal(len(dj.config), len(dj.config._conf)) + len(dj.config) == len(dj.config._conf) def test_str(): """Testing str""" - assert_equal(str(dj.config), pprint.pformat(dj.config._conf, indent=4)) + str(dj.config) == pprint.pformat(dj.config._conf, indent=4) def test_repr(): """Testing repr""" - assert_equal(repr(dj.config), pprint.pformat(dj.config._conf, indent=4)) + repr(dj.config) == pprint.pformat(dj.config._conf, indent=4) def test_save(): @@ -76,7 +76,7 @@ def test_save(): os.rename(settings.LOCALCONFIG, tmpfile) moved = True dj.config.save_local() - assert_true(os.path.isfile(settings.LOCALCONFIG)) + assert os.path.isfile(settings.LOCALCONFIG) if moved: os.rename(tmpfile, settings.LOCALCONFIG) @@ -101,5 +101,5 @@ def test_contextmanager(): """Testing context manager""" dj.config["arbitrary.stuff"] = 7 with dj.config(arbitrary__stuff=10): - assert_true(dj.config["arbitrary.stuff"] == 10) - assert_true(dj.config["arbitrary.stuff"] == 7) + assert dj.config["arbitrary.stuff"] == 10 + assert dj.config["arbitrary.stuff"] == 7 diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 000000000..936badb1c --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,33 @@ +""" +Collection of test cases to test core module. +""" +from datajoint import DataJointError +from datajoint.utils import from_camel_case, to_camel_case +import pytest + + +def setup(): + pass + + +def teardown(): + pass + + +def test_from_camel_case(): + assert from_camel_case("AllGroups") == "all_groups" + with pytest.raises(DataJointError): + from_camel_case("repNames") + with pytest.raises(DataJointError): + from_camel_case("10_all") + with pytest.raises(DataJointError): + from_camel_case("hello world") + with pytest.raises(DataJointError): + from_camel_case("#baisc_names") + + +def test_to_camel_case(): + assert to_camel_case("all_groups") == "AllGroups" + assert to_camel_case("hello") == "Hello" + assert to_camel_case("this_is_a_sample_case") == "ThisIsASampleCase" + assert to_camel_case("This_is_Mixed") == "ThisIsMixed" diff --git a/tests/test_virtual_module.py b/tests/test_virtual_module.py new file mode 100644 index 000000000..d3546c488 --- /dev/null +++ b/tests/test_virtual_module.py @@ -0,0 +1,10 @@ +import datajoint as dj +from datajoint.user_tables import UserTable +from . import CONN_INFO + + +def test_virtual_module(schema_obj): + module = dj.VirtualModule( + "module", schema_obj.schema.database, connection=dj.conn(**CONN_INFO) + ) + assert issubclass(module.Experiment, UserTable) diff --git a/tests_old/test_erd.py b/tests_old/test_erd.py deleted file mode 100644 index 1a6293431..000000000 --- a/tests_old/test_erd.py +++ /dev/null @@ -1,87 +0,0 @@ -from nose.tools import assert_false, assert_true -import datajoint as dj -from .schema_simple import A, B, D, E, L, schema, OutfitLaunch -from . import schema_advanced - -namespace = locals() - - -class TestERD: - @staticmethod - def setup(): - """ - class-level test setup. Executes before each test method. - """ - - @staticmethod - def test_decorator(): - assert_true(issubclass(A, dj.Lookup)) - assert_false(issubclass(A, dj.Part)) - assert_true(B.database == schema.database) - assert_true(issubclass(B.C, dj.Part)) - assert_true(B.C.database == schema.database) - assert_true(B.C.master is B and E.F.master is E) - - @staticmethod - def test_dependencies(): - deps = schema.connection.dependencies - deps.load() - assert_true( - all(cls.full_table_name in deps for cls in (A, B, B.C, D, E, E.F, L)) - ) - assert_true(set(A().children()) == set([B.full_table_name, D.full_table_name])) - assert_true(set(D().parents(primary=True)) == set([A.full_table_name])) - assert_true(set(D().parents(primary=False)) == set([L.full_table_name])) - assert_true( - set(deps.descendants(L.full_table_name)).issubset( - cls.full_table_name for cls in (L, D, E, E.F) - ) - ) - - @staticmethod - def test_erd(): - assert_true(dj.diagram.diagram_active, "Failed to import networkx and pydot") - erd = dj.ERD(schema, context=namespace) - graph = erd._make_graph() - assert_true( - set(cls.__name__ for cls in (A, B, D, E, L)).issubset(graph.nodes()) - ) - - @staticmethod - def test_erd_algebra(): - erd0 = dj.ERD(B) - erd1 = erd0 + 3 - erd2 = dj.Di(E) - 3 - erd3 = erd1 * erd2 - erd4 = (erd0 + E).add_parts() - B - E - assert_true(erd0.nodes_to_show == set(cls.full_table_name for cls in [B])) - assert_true( - erd1.nodes_to_show == set(cls.full_table_name for cls in (B, B.C, E, E.F)) - ) - assert_true( - erd2.nodes_to_show == set(cls.full_table_name for cls in (A, B, D, E, L)) - ) - assert_true(erd3.nodes_to_show == set(cls.full_table_name for cls in (B, E))) - assert_true( - erd4.nodes_to_show == set(cls.full_table_name for cls in (B.C, E.F)) - ) - - @staticmethod - def test_repr_svg(): - erd = dj.ERD(schema_advanced, context=namespace) - svg = erd._repr_svg_() - assert_true(svg.startswith("")) - - @staticmethod - def test_make_image(): - erd = dj.ERD(schema, context=namespace) - img = erd.make_image() - assert_true(img.ndim == 3 and img.shape[2] in (3, 4)) - - @staticmethod - def test_part_table_parsing(): - # https://github.com/datajoint/datajoint-python/issues/882 - erd = dj.Di(schema) - graph = erd._make_graph() - assert "OutfitLaunch" in graph.nodes() - assert "OutfitLaunch.OutfitPiece" in graph.nodes() diff --git a/tests_old/test_hash.py b/tests_old/test_hash.py deleted file mode 100644 index dc88290eb..000000000 --- a/tests_old/test_hash.py +++ /dev/null @@ -1,7 +0,0 @@ -from nose.tools import assert_equal -from datajoint import hash - - -def test_hash(): - assert_equal(hash.uuid_from_buffer(b"abc").hex, "900150983cd24fb0d6963f7d28e17f72") - assert_equal(hash.uuid_from_buffer(b"").hex, "d41d8cd98f00b204e9800998ecf8427e") diff --git a/tests_old/test_utils.py b/tests_old/test_utils.py deleted file mode 100644 index b5ed96af3..000000000 --- a/tests_old/test_utils.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -Collection of test cases to test core module. -""" -from nose.tools import assert_true, assert_raises, assert_equal -from datajoint import DataJointError -from datajoint.utils import from_camel_case, to_camel_case - - -def setup(): - pass - - -def teardown(): - pass - - -def test_from_camel_case(): - assert_equal(from_camel_case("AllGroups"), "all_groups") - with assert_raises(DataJointError): - from_camel_case("repNames") - with assert_raises(DataJointError): - from_camel_case("10_all") - with assert_raises(DataJointError): - from_camel_case("hello world") - with assert_raises(DataJointError): - from_camel_case("#baisc_names") - - -def test_to_camel_case(): - assert_equal(to_camel_case("all_groups"), "AllGroups") - assert_equal(to_camel_case("hello"), "Hello") - assert_equal(to_camel_case("this_is_a_sample_case"), "ThisIsASampleCase") - assert_equal(to_camel_case("This_is_Mixed"), "ThisIsMixed") diff --git a/tests_old/test_virtual_module.py b/tests_old/test_virtual_module.py deleted file mode 100644 index 58180916f..000000000 --- a/tests_old/test_virtual_module.py +++ /dev/null @@ -1,12 +0,0 @@ -from nose.tools import assert_true -import datajoint as dj -from datajoint.user_tables import UserTable -from . import schema -from . import CONN_INFO - - -def test_virtual_module(): - module = dj.VirtualModule( - "module", schema.schema.database, connection=dj.conn(**CONN_INFO) - ) - assert_true(issubclass(module.Experiment, UserTable)) From da85e97f13659575347523c0819f5275a22f232a Mon Sep 17 00:00:00 2001 From: A-Baji Date: Thu, 30 Nov 2023 13:14:28 -0600 Subject: [PATCH 2/6] feat: :sparkles: implement schema fixtures --- tests/__init__.py | 57 -------------------- tests/conftest.py | 61 ++++++++++++++++++++++ tests/test_blob_matlab.py | 98 +++++++++++++++++++---------------- tests/test_connection.py | 2 +- tests/test_nan.py | 40 +++++++------- tests/test_schema_keywords.py | 18 ++++--- tests/test_virtual_module.py | 5 +- 7 files changed, 152 insertions(+), 129 deletions(-) create mode 100644 tests/conftest.py diff --git a/tests/__init__.py b/tests/__init__.py index 0fd907166..70381c090 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -17,60 +17,3 @@ user=os.getenv("DJ_USER"), password=os.getenv("DJ_PASS"), ) - - -@pytest.fixture -def connection_root(): - """Root user database connection.""" - dj.config["safemode"] = False - connection = dj.Connection( - host=os.getenv("DJ_HOST"), - user=os.getenv("DJ_USER"), - password=os.getenv("DJ_PASS"), - ) - yield connection - dj.config["safemode"] = True - connection.close() - - -@pytest.fixture -def connection_test(connection_root): - """Test user database connection.""" - database = f"{PREFIX}%%" - credentials = dict( - host=os.getenv("DJ_HOST"), user="datajoint", password="datajoint" - ) - permission = "ALL PRIVILEGES" - - # Create MySQL users - if version.parse( - connection_root.query("select @@version;").fetchone()[0] - ) >= version.parse("8.0.0"): - # create user if necessary on mysql8 - connection_root.query( - f""" - CREATE USER IF NOT EXISTS '{credentials["user"]}'@'%%' - IDENTIFIED BY '{credentials["password"]}'; - """ - ) - connection_root.query( - f""" - GRANT {permission} ON `{database}`.* - TO '{credentials["user"]}'@'%%'; - """ - ) - else: - # grant permissions. For MySQL 5.7 this also automatically creates user - # if not exists - connection_root.query( - f""" - GRANT {permission} ON `{database}`.* - TO '{credentials["user"]}'@'%%' - IDENTIFIED BY '{credentials["password"]}'; - """ - ) - - connection = dj.Connection(**credentials) - yield connection - connection_root.query(f"""DROP USER `{credentials["user"]}`""") - connection.close() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..49c1bb5b4 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,61 @@ +import datajoint as dj +from packaging import version +import os +import pytest +from . import schema, PREFIX + +@pytest.fixture(scope="session") +def connection_root(): + """Root user database connection.""" + dj.config["safemode"] = False + connection = dj.Connection( + host=os.getenv("DJ_HOST"), + user=os.getenv("DJ_USER"), + password=os.getenv("DJ_PASS"), + ) + yield connection + dj.config["safemode"] = True + connection.close() + + +@pytest.fixture(scope="session") +def connection_test(connection_root): + """Test user database connection.""" + database = f"{PREFIX}%%" + credentials = dict( + host=os.getenv("DJ_HOST"), user="datajoint", password="datajoint" + ) + permission = "ALL PRIVILEGES" + + # Create MySQL users + if version.parse( + connection_root.query("select @@version;").fetchone()[0] + ) >= version.parse("8.0.0"): + # create user if necessary on mysql8 + connection_root.query( + f""" + CREATE USER IF NOT EXISTS '{credentials["user"]}'@'%%' + IDENTIFIED BY '{credentials["password"]}'; + """ + ) + connection_root.query( + f""" + GRANT {permission} ON `{database}`.* + TO '{credentials["user"]}'@'%%'; + """ + ) + else: + # grant permissions. For MySQL 5.7 this also automatically creates user + # if not exists + connection_root.query( + f""" + GRANT {permission} ON `{database}`.* + TO '{credentials["user"]}'@'%%' + IDENTIFIED BY '{credentials["password"]}'; + """ + ) + + connection = dj.Connection(**credentials) + yield connection + connection_root.query(f"""DROP USER `{credentials["user"]}`""") + connection.close() diff --git a/tests/test_blob_matlab.py b/tests/test_blob_matlab.py index ecb698fec..504a4c52e 100644 --- a/tests/test_blob_matlab.py +++ b/tests/test_blob_matlab.py @@ -1,14 +1,12 @@ import numpy as np +import pytest import datajoint as dj from datajoint.blob import pack, unpack from numpy.testing import assert_array_equal -from . import PREFIX, CONN_INFO +from . import PREFIX -schema = dj.Schema(PREFIX + "_test1", locals(), connection=dj.conn(**CONN_INFO)) - -@schema class Blob(dj.Manual): definition = """ # diverse types of blobs id : int @@ -18,51 +16,63 @@ class Blob(dj.Manual): """ -def insert_blobs(): - """ - This function inserts blobs resulting from the following datajoint-matlab code: - - self.insert({ - 1 'simple string' 'character string' - 2 '1D vector' 1:15:180 - 3 'string array' {'string1' 'string2'} - 4 'struct array' struct('a', {1,2}, 'b', {struct('c', magic(3)), struct('C', magic(5))}) - 5 '3D double array' reshape(1:24, [2,3,4]) - 6 '3D uint8 array' reshape(uint8(1:24), [2,3,4]) - 7 '3D complex array' fftn(reshape(1:24, [2,3,4])) - }) - - and then dumped using the command - mysqldump -u username -p --hex-blob test_schema blob_table > blob.sql - """ +@pytest.fixture(scope="module") +def schema(connection_test): + schema = dj.Schema(PREFIX + "_test1", locals(), connection=dj.conn(connection_test)) + schema(Blob) + yield schema + schema.drop() - schema.connection.query( + +@pytest.fixture(scope="module") +def insert_blobs_func(schema): + def insert_blobs(): + """ + This function inserts blobs resulting from the following datajoint-matlab code: + + self.insert({ + 1 'simple string' 'character string' + 2 '1D vector' 1:15:180 + 3 'string array' {'string1' 'string2'} + 4 'struct array' struct('a', {1,2}, 'b', {struct('c', magic(3)), struct('C', magic(5))}) + 5 '3D double array' reshape(1:24, [2,3,4]) + 6 '3D uint8 array' reshape(uint8(1:24), [2,3,4]) + 7 '3D complex array' fftn(reshape(1:24, [2,3,4])) + }) + + and then dumped using the command + mysqldump -u username -p --hex-blob test_schema blob_table > blob.sql """ - INSERT INTO {table_name} VALUES - (1,'simple string',0x6D596D00410200000000000000010000000000000010000000000000000400000000000000630068006100720061006300740065007200200073007400720069006E006700), - (2,'1D vector',0x6D596D0041020000000000000001000000000000000C000000000000000600000000000000000000000000F03F00000000000030400000000000003F4000000000000047400000000000804E4000000000000053400000000000C056400000000000805A400000000000405E4000000000000061400000000000E062400000000000C06440), - (3,'string array',0x6D596D00430200000000000000010000000000000002000000000000002F0000000000000041020000000000000001000000000000000700000000000000040000000000000073007400720069006E00670031002F0000000000000041020000000000000001000000000000000700000000000000040000000000000073007400720069006E0067003200), - (4,'struct array',0x6D596D005302000000000000000100000000000000020000000000000002000000610062002900000000000000410200000000000000010000000000000001000000000000000600000000000000000000000000F03F9000000000000000530200000000000000010000000000000001000000000000000100000063006900000000000000410200000000000000030000000000000003000000000000000600000000000000000000000000204000000000000008400000000000001040000000000000F03F0000000000001440000000000000224000000000000018400000000000001C40000000000000004029000000000000004102000000000000000100000000000000010000000000000006000000000000000000000000000040100100000000000053020000000000000001000000000000000100000000000000010000004300E9000000000000004102000000000000000500000000000000050000000000000006000000000000000000000000003140000000000000374000000000000010400000000000002440000000000000264000000000000038400000000000001440000000000000184000000000000028400000000000003240000000000000F03F0000000000001C400000000000002A400000000000003340000000000000394000000000000020400000000000002C400000000000003440000000000000354000000000000000400000000000002E400000000000003040000000000000364000000000000008400000000000002240), - (5,'3D double array',0x6D596D004103000000000000000200000000000000030000000000000004000000000000000600000000000000000000000000F03F000000000000004000000000000008400000000000001040000000000000144000000000000018400000000000001C40000000000000204000000000000022400000000000002440000000000000264000000000000028400000000000002A400000000000002C400000000000002E40000000000000304000000000000031400000000000003240000000000000334000000000000034400000000000003540000000000000364000000000000037400000000000003840), - (6,'3D uint8 array',0x6D596D0041030000000000000002000000000000000300000000000000040000000000000009000000000000000102030405060708090A0B0C0D0E0F101112131415161718), - (7,'3D complex array',0x6D596D0041030000000000000002000000000000000300000000000000040000000000000006000000010000000000000000C0724000000000000028C000000000000038C0000000000000000000000000000038C0000000000000000000000000000052C00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000052C00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000052C00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000AA4C58E87AB62B400000000000000000AA4C58E87AB62BC0000000000000008000000000000052400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000080000000000000008000000000000052C000000000000000800000000000000080000000000000008000000000000000800000000000000080 - ); - """.format( - table_name=Blob.full_table_name + + schema.connection.query( + """ + INSERT INTO {table_name} VALUES + (1,'simple string',0x6D596D00410200000000000000010000000000000010000000000000000400000000000000630068006100720061006300740065007200200073007400720069006E006700), + (2,'1D vector',0x6D596D0041020000000000000001000000000000000C000000000000000600000000000000000000000000F03F00000000000030400000000000003F4000000000000047400000000000804E4000000000000053400000000000C056400000000000805A400000000000405E4000000000000061400000000000E062400000000000C06440), + (3,'string array',0x6D596D00430200000000000000010000000000000002000000000000002F0000000000000041020000000000000001000000000000000700000000000000040000000000000073007400720069006E00670031002F0000000000000041020000000000000001000000000000000700000000000000040000000000000073007400720069006E0067003200), + (4,'struct array',0x6D596D005302000000000000000100000000000000020000000000000002000000610062002900000000000000410200000000000000010000000000000001000000000000000600000000000000000000000000F03F9000000000000000530200000000000000010000000000000001000000000000000100000063006900000000000000410200000000000000030000000000000003000000000000000600000000000000000000000000204000000000000008400000000000001040000000000000F03F0000000000001440000000000000224000000000000018400000000000001C40000000000000004029000000000000004102000000000000000100000000000000010000000000000006000000000000000000000000000040100100000000000053020000000000000001000000000000000100000000000000010000004300E9000000000000004102000000000000000500000000000000050000000000000006000000000000000000000000003140000000000000374000000000000010400000000000002440000000000000264000000000000038400000000000001440000000000000184000000000000028400000000000003240000000000000F03F0000000000001C400000000000002A400000000000003340000000000000394000000000000020400000000000002C400000000000003440000000000000354000000000000000400000000000002E400000000000003040000000000000364000000000000008400000000000002240), + (5,'3D double array',0x6D596D004103000000000000000200000000000000030000000000000004000000000000000600000000000000000000000000F03F000000000000004000000000000008400000000000001040000000000000144000000000000018400000000000001C40000000000000204000000000000022400000000000002440000000000000264000000000000028400000000000002A400000000000002C400000000000002E40000000000000304000000000000031400000000000003240000000000000334000000000000034400000000000003540000000000000364000000000000037400000000000003840), + (6,'3D uint8 array',0x6D596D0041030000000000000002000000000000000300000000000000040000000000000009000000000000000102030405060708090A0B0C0D0E0F101112131415161718), + (7,'3D complex array',0x6D596D0041030000000000000002000000000000000300000000000000040000000000000006000000010000000000000000C0724000000000000028C000000000000038C0000000000000000000000000000038C0000000000000000000000000000052C00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000052C00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000052C00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000AA4C58E87AB62B400000000000000000AA4C58E87AB62BC0000000000000008000000000000052400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000080000000000000008000000000000052C000000000000000800000000000000080000000000000008000000000000000800000000000000080 + ); + """.format( + table_name=Blob.full_table_name + ) ) - ) + yield insert_blobs -class TestFetch: - @classmethod - def setup_class(cls): - dj.config["safemode"] = False # temp - assert not dj.config["safemode"], "safemode must be disabled" - Blob().delete() - insert_blobs() +@pytest.fixture(scope="class") +def setup_class(schema, insert_blobs_func): + assert not dj.config["safemode"], "safemode must be disabled" + Blob().delete() + insert_blobs_func() + + +class TestFetch: @staticmethod - def test_complex_matlab_blobs(): + def test_complex_matlab_blobs(setup_class): """ test correct de-serialization of various blob types """ @@ -109,7 +119,7 @@ def test_complex_matlab_blobs(): assert_array_equal(blob, unpack(pack(blob))) @staticmethod - def test_complex_matlab_squeeze(): + def test_complex_matlab_squeeze(setup_class): """ test correct de-serialization of various blob types """ @@ -164,7 +174,7 @@ def test_complex_matlab_squeeze(): assert tuple(blob.shape) == (2, 3, 4) assert blob.dtype == "complex128" - def test_iter(self): + def test_iter(self, setup_class): """ test iterator over the entity set """ diff --git a/tests/test_connection.py b/tests/test_connection.py index 1916da951..76b6d2389 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -5,7 +5,7 @@ import datajoint as dj from datajoint import DataJointError import numpy as np -from . import CONN_INFO_ROOT, connection_root, connection_test +from . import CONN_INFO_ROOT from . import PREFIX import pytest diff --git a/tests/test_nan.py b/tests/test_nan.py index ad4e6239e..1b3fb9f00 100644 --- a/tests/test_nan.py +++ b/tests/test_nan.py @@ -1,11 +1,8 @@ import numpy as np import datajoint as dj -from . import PREFIX, CONN_INFO +from . import PREFIX +import pytest -schema = dj.Schema(PREFIX + "_nantest", locals(), connection=dj.conn(**CONN_INFO)) - - -@schema class NanTest(dj.Manual): definition = """ id :int @@ -13,26 +10,33 @@ class NanTest(dj.Manual): value=null :double """ +@pytest.fixture(scope="module") +def schema(connection_test): + schema = dj.Schema(PREFIX + "_nantest", locals(), connection=dj.conn(connection_test)) + schema(NanTest) + yield schema + schema.drop() -class TestNaNInsert: - @classmethod - def setup_class(cls): - cls.rel = NanTest() - with dj.config(safemode=False): - cls.rel.delete() - a = np.array([0, 1 / 3, np.nan, np.pi, np.nan]) - cls.rel.insert(((i, value) for i, value in enumerate(a))) - cls.a = a +@pytest.fixture(scope="class") +def setup_class(request, schema): + rel = NanTest() + with dj.config(safemode=False): + rel.delete() + a = np.array([0, 1 / 3, np.nan, np.pi, np.nan]) + rel.insert(((i, value) for i, value in enumerate(a))) + request.cls.rel = rel + request.cls.a = a - def test_insert_nan(self): +class TestNaNInsert: + def test_insert_nan(self, setup_class): """Test fetching of null values""" b = self.rel.fetch("value", order_by="id") - (np.isnan(self.a) == np.isnan(b)).all(), "incorrect handling of Nans" - np.allclose( + assert (np.isnan(self.a) == np.isnan(b)).all(), "incorrect handling of Nans" + assert np.allclose( self.a[np.logical_not(np.isnan(self.a))], b[np.logical_not(np.isnan(b))] ), "incorrect storage of floats" - def test_nulls_do_not_affect_primary_keys(self): + def test_nulls_do_not_affect_primary_keys(self, setup_class): """Test against a case that previously caused a bug when skipping existing entries.""" self.rel.insert( ((i, value) for i, value in enumerate(self.a)), skip_duplicates=True diff --git a/tests/test_schema_keywords.py b/tests/test_schema_keywords.py index 1853852ed..e8354ec26 100644 --- a/tests/test_schema_keywords.py +++ b/tests/test_schema_keywords.py @@ -1,10 +1,8 @@ -from . import PREFIX, CONN_INFO +from . import PREFIX import datajoint as dj +import pytest -schema = dj.Schema(PREFIX + "_keywords", connection=dj.conn(**CONN_INFO)) - -@schema class A(dj.Manual): definition = """ a_id: int # a id @@ -31,12 +29,20 @@ class C(dj.Part): """ -@schema class D(B): source = A -def test_inherited_part_table(): +@pytest.fixture(scope="module") +def schema(connection_test): + schema = dj.Schema(PREFIX + "_keywords", connection=dj.conn(connection_test)) + schema(A) + schema(D) + yield schema + schema.drop() + + +def test_inherited_part_table(schema): assert "a_id" in D().heading.attributes assert "b_id" in D().heading.attributes assert "a_id" in D.C().heading.attributes diff --git a/tests/test_virtual_module.py b/tests/test_virtual_module.py index d3546c488..fbb05002c 100644 --- a/tests/test_virtual_module.py +++ b/tests/test_virtual_module.py @@ -1,10 +1,9 @@ import datajoint as dj from datajoint.user_tables import UserTable -from . import CONN_INFO -def test_virtual_module(schema_obj): +def test_virtual_module(schema_obj, connection_test): module = dj.VirtualModule( - "module", schema_obj.schema.database, connection=dj.conn(**CONN_INFO) + "module", schema_obj.schema.database, connection=dj.conn(connection_test) ) assert issubclass(module.Experiment, UserTable) From 5b53e156d25741aa78498a18eac7ceb7f2d28cd3 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Thu, 30 Nov 2023 14:13:21 -0600 Subject: [PATCH 3/6] convert schema.py to fixture [WIP] --- tests/conftest.py | 52 ++++++++++++++++++++++++++++++++++- tests/schema.py | 39 -------------------------- tests/test_blob.py | 2 +- tests/test_blob_matlab.py | 2 +- tests/test_nan.py | 2 +- tests/test_schema_keywords.py | 2 +- tests/test_virtual_module.py | 2 +- 7 files changed, 56 insertions(+), 45 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 49c1bb5b4..bea480b85 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,15 @@ +import sys import datajoint as dj from packaging import version import os import pytest -from . import schema, PREFIX +import inspect +from . import PREFIX +from .schema import * + +# all_classes = [] +# for _, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass): +# all_classes.append(obj) @pytest.fixture(scope="session") def connection_root(): @@ -59,3 +66,46 @@ def connection_test(connection_root): yield connection connection_root.query(f"""DROP USER `{credentials["user"]}`""") connection.close() + +@pytest.fixture +def schema_fixture(connection_test): + schema = dj.Schema(PREFIX + "_test1", connection=connection_test) + schema(TTest) + schema(TTest) + schema(TTest2) + schema(TTest3) + schema(NullableNumbers) + schema(TTestExtra) + schema(TTestNoExtra) + schema(Auto) + schema(User) + schema(Subject) + schema(Language) + schema(Experiment) + schema(Trial) + schema(Ephys) + schema(Image) + schema(UberTrash) + schema(UnterTrash) + schema(SimpleSource) + schema(SigIntTable) + schema(SigTermTable) + schema(DjExceptionName) + schema(ErrorClass) + schema(DecimalPrimaryKey) + schema(IndexRich) + schema(ThingA) + schema(ThingB) + schema(ThingC) + schema(Parent) + schema(Child) + schema(ComplexParent) + schema(ComplexChild) + schema(SubjectA) + schema(SessionA) + schema(SessionStatusA) + schema(SessionDateA) + schema(Stimulus) + schema(Longblob) + yield schema + schema.drop() \ No newline at end of file diff --git a/tests/schema.py b/tests/schema.py index dafd481da..4128ddd30 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -6,12 +6,8 @@ import numpy as np import datajoint as dj import inspect -from . import PREFIX, CONN_INFO -schema = dj.Schema(PREFIX + "_test1", connection=dj.conn(**CONN_INFO)) - -@schema class TTest(dj.Lookup): """ doc string @@ -25,7 +21,6 @@ class TTest(dj.Lookup): contents = [(k, 2 * k) for k in range(10)] -@schema class TTest2(dj.Manual): definition = """ key : int # key @@ -34,7 +29,6 @@ class TTest2(dj.Manual): """ -@schema class TTest3(dj.Manual): definition = """ key : int @@ -43,7 +37,6 @@ class TTest3(dj.Manual): """ -@schema class NullableNumbers(dj.Manual): definition = """ key : int @@ -54,7 +47,6 @@ class NullableNumbers(dj.Manual): """ -@schema class TTestExtra(dj.Manual): """ clone of Test but with an extra field @@ -63,7 +55,6 @@ class TTestExtra(dj.Manual): definition = TTest.definition + "\nextra : int # extra int\n" -@schema class TTestNoExtra(dj.Manual): """ clone of Test but with no extra fields @@ -72,7 +63,6 @@ class TTestNoExtra(dj.Manual): definition = TTest.definition -@schema class Auto(dj.Lookup): definition = """ id :int auto_increment @@ -85,7 +75,6 @@ def fill(self): self.insert([dict(name="Godel"), dict(name="Escher"), dict(name="Bach")]) -@schema class User(dj.Lookup): definition = """ # lab members username: varchar(12) @@ -101,7 +90,6 @@ class User(dj.Lookup): ] -@schema class Subject(dj.Lookup): definition = """ # Basic information about animal subjects used in experiments subject_id :int # unique subject id @@ -121,7 +109,6 @@ class Subject(dj.Lookup): ] -@schema class Language(dj.Lookup): definition = """ # languages spoken by some of the developers @@ -139,7 +126,6 @@ class Language(dj.Lookup): ] -@schema class Experiment(dj.Imported): definition = """ # information about experiments -> Subject @@ -175,7 +161,6 @@ def make(self, key): ) -@schema class Trial(dj.Imported): definition = """ # a trial within an experiment -> Experiment.proj(animal='subject_id') @@ -205,7 +190,6 @@ def make(self, key): ) -@schema class Ephys(dj.Imported): definition = """ # some kind of electrophysiological recording -> Trial @@ -244,7 +228,6 @@ def _make_tuples(self, key): ) -@schema class Image(dj.Manual): definition = """ # table for testing blob inserts @@ -254,7 +237,6 @@ class Image(dj.Manual): """ -@schema class UberTrash(dj.Lookup): definition = """ id : int @@ -263,7 +245,6 @@ class UberTrash(dj.Lookup): contents = [(1,)] -@schema class UnterTrash(dj.Lookup): definition = """ -> UberTrash @@ -273,7 +254,6 @@ class UnterTrash(dj.Lookup): contents = [(1, 1), (1, 2)] -@schema class SimpleSource(dj.Lookup): definition = """ id : int # id @@ -281,7 +261,6 @@ class SimpleSource(dj.Lookup): contents = ((x,) for x in range(10)) -@schema class SigIntTable(dj.Computed): definition = """ -> SimpleSource @@ -291,7 +270,6 @@ def _make_tuples(self, key): raise KeyboardInterrupt -@schema class SigTermTable(dj.Computed): definition = """ -> SimpleSource @@ -301,7 +279,6 @@ def make(self, key): raise SystemExit("SIGTERM received") -@schema class DjExceptionName(dj.Lookup): definition = """ dj_exception_name: char(64) @@ -316,7 +293,6 @@ def contents(self): ] -@schema class ErrorClass(dj.Computed): definition = """ -> DjExceptionName @@ -327,7 +303,6 @@ def make(self, key): raise getattr(dj.errors, exception_name) -@schema class DecimalPrimaryKey(dj.Lookup): definition = """ id : decimal(4,3) @@ -335,7 +310,6 @@ class DecimalPrimaryKey(dj.Lookup): contents = zip((0.1, 0.25, 3.99)) -@schema class IndexRich(dj.Manual): definition = """ -> Subject @@ -348,14 +322,12 @@ class IndexRich(dj.Manual): # Schema for issue 656 -@schema class ThingA(dj.Manual): definition = """ a: int """ -@schema class ThingB(dj.Manual): definition = """ b1: int @@ -365,7 +337,6 @@ class ThingB(dj.Manual): """ -@schema class ThingC(dj.Manual): definition = """ -> ThingA @@ -374,7 +345,6 @@ class ThingC(dj.Manual): """ -@schema class Parent(dj.Lookup): definition = """ parent_id: int @@ -384,7 +354,6 @@ class Parent(dj.Lookup): contents = [(1, "Joe")] -@schema class Child(dj.Lookup): definition = """ -> Parent @@ -396,13 +365,11 @@ class Child(dj.Lookup): # Related to issue #886 (8), #883 (5) -@schema class ComplexParent(dj.Lookup): definition = "\n".join(["parent_id_{}: int".format(i + 1) for i in range(8)]) contents = [tuple(i for i in range(8))] -@schema class ComplexChild(dj.Lookup): definition = "\n".join( ["-> ComplexParent"] + ["child_id_{}: int".format(i + 1) for i in range(1)] @@ -410,7 +377,6 @@ class ComplexChild(dj.Lookup): contents = [tuple(i for i in range(9))] -@schema class SubjectA(dj.Lookup): definition = """ subject_id: varchar(32) @@ -425,7 +391,6 @@ class SubjectA(dj.Lookup): ] -@schema class SessionA(dj.Lookup): definition = """ -> SubjectA @@ -441,7 +406,6 @@ class SessionA(dj.Lookup): ] -@schema class SessionStatusA(dj.Lookup): definition = """ -> SessionA @@ -456,7 +420,6 @@ class SessionStatusA(dj.Lookup): ] -@schema class SessionDateA(dj.Lookup): definition = """ -> SubjectA @@ -470,7 +433,6 @@ class SessionDateA(dj.Lookup): ] -@schema class Stimulus(dj.Lookup): definition = """ id: int @@ -480,7 +442,6 @@ class Stimulus(dj.Lookup): """ -@schema class Longblob(dj.Manual): definition = """ id: int diff --git a/tests/test_blob.py b/tests/test_blob.py index 562d78f2b..761b02cf5 100644 --- a/tests/test_blob.py +++ b/tests/test_blob.py @@ -169,7 +169,7 @@ def test_complex(): assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") -def test_insert_longblob(): +def test_insert_longblob(schema_fixture): insert_dj_blob = {"id": 1, "data": [1, 2, 3]} schema.Longblob.insert1(insert_dj_blob) assert (schema.Longblob & "id=1").fetch1() == insert_dj_blob diff --git a/tests/test_blob_matlab.py b/tests/test_blob_matlab.py index 504a4c52e..06154b1fc 100644 --- a/tests/test_blob_matlab.py +++ b/tests/test_blob_matlab.py @@ -18,7 +18,7 @@ class Blob(dj.Manual): @pytest.fixture(scope="module") def schema(connection_test): - schema = dj.Schema(PREFIX + "_test1", locals(), connection=dj.conn(connection_test)) + schema = dj.Schema(PREFIX + "_test1", locals(), connection=connection_test) schema(Blob) yield schema schema.drop() diff --git a/tests/test_nan.py b/tests/test_nan.py index 1b3fb9f00..38dd5036f 100644 --- a/tests/test_nan.py +++ b/tests/test_nan.py @@ -12,7 +12,7 @@ class NanTest(dj.Manual): @pytest.fixture(scope="module") def schema(connection_test): - schema = dj.Schema(PREFIX + "_nantest", locals(), connection=dj.conn(connection_test)) + schema = dj.Schema(PREFIX + "_nantest", locals(), connection=connection_test) schema(NanTest) yield schema schema.drop() diff --git a/tests/test_schema_keywords.py b/tests/test_schema_keywords.py index e8354ec26..c8b7d5a24 100644 --- a/tests/test_schema_keywords.py +++ b/tests/test_schema_keywords.py @@ -35,7 +35,7 @@ class D(B): @pytest.fixture(scope="module") def schema(connection_test): - schema = dj.Schema(PREFIX + "_keywords", connection=dj.conn(connection_test)) + schema = dj.Schema(PREFIX + "_keywords", connection=connection_test) schema(A) schema(D) yield schema diff --git a/tests/test_virtual_module.py b/tests/test_virtual_module.py index fbb05002c..b7c3f23bb 100644 --- a/tests/test_virtual_module.py +++ b/tests/test_virtual_module.py @@ -4,6 +4,6 @@ def test_virtual_module(schema_obj, connection_test): module = dj.VirtualModule( - "module", schema_obj.schema.database, connection=dj.conn(connection_test) + "module", schema_obj.schema.database, connection=connection_test ) assert issubclass(module.Experiment, UserTable) From aaee0a1af8761d01ddb7332d76482971411fc4c3 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Fri, 1 Dec 2023 13:59:38 -0600 Subject: [PATCH 4/6] convert schema files to fixtures --- tests/conftest.py | 137 +++++++++++++++++++++++------------ tests/schema.py | 2 + tests/schema_advanced.py | 12 +-- tests/schema_simple.py | 19 +---- tests/test_blob.py | 23 +++--- tests/test_dependencies.py | 7 +- tests/test_erd.py | 118 ++++++++++++++---------------- tests/test_foreign_keys.py | 25 +++---- tests/test_groupby.py | 2 +- tests/test_log.py | 7 +- tests/test_nan.py | 6 +- tests/test_relation_u.py | 61 ++++++++-------- tests/test_virtual_module.py | 6 +- 13 files changed, 216 insertions(+), 209 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index bea480b85..8335b1c11 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,12 +4,10 @@ import os import pytest import inspect -from . import PREFIX -from .schema import * +from . import PREFIX, schema, schema_simple, schema_advanced + +namespace = locals() -# all_classes = [] -# for _, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass): -# all_classes.append(obj) @pytest.fixture(scope="session") def connection_root(): @@ -67,45 +65,92 @@ def connection_test(connection_root): connection_root.query(f"""DROP USER `{credentials["user"]}`""") connection.close() -@pytest.fixture -def schema_fixture(connection_test): - schema = dj.Schema(PREFIX + "_test1", connection=connection_test) - schema(TTest) - schema(TTest) - schema(TTest2) - schema(TTest3) - schema(NullableNumbers) - schema(TTestExtra) - schema(TTestNoExtra) - schema(Auto) - schema(User) - schema(Subject) - schema(Language) - schema(Experiment) - schema(Trial) - schema(Ephys) - schema(Image) - schema(UberTrash) - schema(UnterTrash) - schema(SimpleSource) - schema(SigIntTable) - schema(SigTermTable) - schema(DjExceptionName) - schema(ErrorClass) - schema(DecimalPrimaryKey) - schema(IndexRich) - schema(ThingA) - schema(ThingB) - schema(ThingC) - schema(Parent) - schema(Child) - schema(ComplexParent) - schema(ComplexChild) - schema(SubjectA) - schema(SessionA) - schema(SessionStatusA) - schema(SessionDateA) - schema(Stimulus) - schema(Longblob) + +@pytest.fixture(scope="module") +def schema_any(connection_test): + schema_any = dj.Schema( + PREFIX + "_test1", schema.__dict__, connection=connection_test + ) + schema_any(schema.TTest) + schema_any(schema.TTest2) + schema_any(schema.TTest3) + schema_any(schema.NullableNumbers) + schema_any(schema.TTestExtra) + schema_any(schema.TTestNoExtra) + schema_any(schema.Auto) + schema_any(schema.User) + schema_any(schema.Subject) + schema_any(schema.Language) + schema_any(schema.Experiment) + schema_any(schema.Trial) + schema_any(schema.Ephys) + schema_any(schema.Image) + schema_any(schema.UberTrash) + schema_any(schema.UnterTrash) + schema_any(schema.SimpleSource) + schema_any(schema.SigIntTable) + schema_any(schema.SigTermTable) + schema_any(schema.DjExceptionName) + schema_any(schema.ErrorClass) + schema_any(schema.DecimalPrimaryKey) + schema_any(schema.IndexRich) + schema_any(schema.ThingA) + schema_any(schema.ThingB) + schema_any(schema.ThingC) + schema_any(schema.Parent) + schema_any(schema.Child) + schema_any(schema.ComplexParent) + schema_any(schema.ComplexChild) + schema_any(schema.SubjectA) + schema_any(schema.SessionA) + schema_any(schema.SessionStatusA) + schema_any(schema.SessionDateA) + schema_any(schema.Stimulus) + schema_any(schema.Longblob) + yield schema_any + schema_any.drop() + + +@pytest.fixture(scope="module") +def schema_simp(connection_test): + schema = dj.Schema( + PREFIX + "_relational", schema_simple.__dict__, connection=connection_test + ) + schema(schema_simple.IJ) + schema(schema_simple.JI) + schema(schema_simple.A) + schema(schema_simple.B) + schema(schema_simple.L) + schema(schema_simple.D) + schema(schema_simple.E) + schema(schema_simple.F) + schema(schema_simple.F) + schema(schema_simple.DataA) + schema(schema_simple.DataB) + schema(schema_simple.Website) + schema(schema_simple.Profile) + schema(schema_simple.Website) + schema(schema_simple.TTestUpdate) + schema(schema_simple.ArgmaxTest) + schema(schema_simple.ReservedWord) + schema(schema_simple.OutfitLaunch) + yield schema + schema.drop() + + +@pytest.fixture(scope="module") +def schema_adv(connection_test): + schema = dj.Schema( + PREFIX + "_advanced", schema_advanced.__dict__, connection=connection_test + ) + schema(schema_advanced.Person) + schema(schema_advanced.Parent) + schema(schema_advanced.Subject) + schema(schema_advanced.Prep) + schema(schema_advanced.Slice) + schema(schema_advanced.Cell) + schema(schema_advanced.InputCell) + schema(schema_advanced.LocalSynapse) + schema(schema_advanced.GlobalSynapse) yield schema - schema.drop() \ No newline at end of file + schema.drop() diff --git a/tests/schema.py b/tests/schema.py index 4128ddd30..864c5efe4 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -7,6 +7,8 @@ import datajoint as dj import inspect +LOCALS_ANY = locals() + class TTest(dj.Lookup): """ diff --git a/tests/schema_advanced.py b/tests/schema_advanced.py index 7580611e2..104e4d1e4 100644 --- a/tests/schema_advanced.py +++ b/tests/schema_advanced.py @@ -1,10 +1,8 @@ import datajoint as dj -from . import PREFIX, CONN_INFO -schema = dj.Schema(PREFIX + "_advanced", locals(), connection=dj.conn(**CONN_INFO)) +LOCALS_ADVANCED = locals() -@schema class Person(dj.Manual): definition = """ person_id : int @@ -39,7 +37,6 @@ def fill(self): ) -@schema class Parent(dj.Manual): definition = """ -> Person @@ -89,7 +86,6 @@ def make_parent(pid, parent): ) -@schema class Subject(dj.Manual): definition = """ subject : int @@ -98,14 +94,12 @@ class Subject(dj.Manual): """ -@schema class Prep(dj.Manual): definition = """ prep : int """ -@schema class Slice(dj.Manual): definition = """ -> Prep @@ -113,7 +107,6 @@ class Slice(dj.Manual): """ -@schema class Cell(dj.Manual): definition = """ -> Slice @@ -121,7 +114,6 @@ class Cell(dj.Manual): """ -@schema class InputCell(dj.Manual): definition = """ # a synapse within the slice -> Cell @@ -129,7 +121,6 @@ class InputCell(dj.Manual): """ -@schema class LocalSynapse(dj.Manual): definition = """ # a synapse within the slice -> Cell.proj(presynaptic='cell') @@ -137,7 +128,6 @@ class LocalSynapse(dj.Manual): """ -@schema class GlobalSynapse(dj.Manual): # Mix old-style and new-style projected foreign keys definition = """ diff --git a/tests/schema_simple.py b/tests/schema_simple.py index 78f64d036..bb5c21ff5 100644 --- a/tests/schema_simple.py +++ b/tests/schema_simple.py @@ -7,14 +7,12 @@ import hashlib import uuid import faker -from . import PREFIX, CONN_INFO import numpy as np from datetime import date, timedelta -schema = dj.Schema(PREFIX + "_relational", locals(), connection=dj.conn(**CONN_INFO)) +LOCALS_SIMPLE = locals() -@schema class IJ(dj.Lookup): definition = """ # tests restrictions i : int @@ -23,7 +21,6 @@ class IJ(dj.Lookup): contents = list(dict(i=i, j=j + 2) for i in range(3) for j in range(3)) -@schema class JI(dj.Lookup): definition = """ # tests restrictions by relations when attributes are reordered j : int @@ -32,7 +29,6 @@ class JI(dj.Lookup): contents = list(dict(i=i + 1, j=j) for i in range(3) for j in range(3)) -@schema class A(dj.Lookup): definition = """ id_a :int @@ -42,7 +38,6 @@ class A(dj.Lookup): contents = [(i, i % 4 > i % 3) for i in range(10)] -@schema class B(dj.Computed): definition = """ -> A @@ -76,7 +71,6 @@ def make(self, key): ) -@schema class L(dj.Lookup): definition = """ id_l: int @@ -86,7 +80,6 @@ class L(dj.Lookup): contents = [(i, i % 3 >= i % 5) for i in range(30)] -@schema class D(dj.Computed): definition = """ -> A @@ -102,7 +95,6 @@ def _make_tuples(self, key): self.insert(dict(key, id_d=i, **random.choice(lookup)) for i in range(4)) -@schema class E(dj.Computed): definition = """ -> B @@ -132,7 +124,6 @@ def make(self, key): ) -@schema class F(dj.Manual): definition = """ id: int @@ -141,7 +132,6 @@ class F(dj.Manual): """ -@schema class DataA(dj.Lookup): definition = """ idx : int @@ -151,7 +141,6 @@ class DataA(dj.Lookup): contents = list(zip(range(5), range(5))) -@schema class DataB(dj.Lookup): definition = """ idx : int @@ -161,7 +150,6 @@ class DataB(dj.Lookup): contents = list(zip(range(5), range(5, 10))) -@schema class Website(dj.Lookup): definition = """ url_hash : uuid @@ -177,7 +165,6 @@ def insert1_url(self, url): return url_hash -@schema class Profile(dj.Manual): definition = """ ssn : char(11) @@ -210,7 +197,6 @@ def populate_random(self, n=10): ) -@schema class TTestUpdate(dj.Lookup): definition = """ primary_key : int @@ -226,7 +212,6 @@ class TTestUpdate(dj.Lookup): ] -@schema class ArgmaxTest(dj.Lookup): definition = """ primary_key : int @@ -247,7 +232,6 @@ def contents(self): ) -@schema class ReservedWord(dj.Manual): definition = """ # Test of SQL reserved words @@ -260,7 +244,6 @@ class ReservedWord(dj.Manual): """ -@schema class OutfitLaunch(dj.Lookup): definition = """ # Monthly released designer outfits diff --git a/tests/test_blob.py b/tests/test_blob.py index 761b02cf5..a3de2e9a9 100644 --- a/tests/test_blob.py +++ b/tests/test_blob.py @@ -1,13 +1,14 @@ +import pytest import datajoint as dj import timeit import numpy as np import uuid -from . import schema from decimal import Decimal from datetime import datetime from datajoint.blob import pack, unpack from numpy.testing import assert_array_equal from pytest import approx +from .schema import * def test_pack(): @@ -169,18 +170,16 @@ def test_complex(): assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") -def test_insert_longblob(schema_fixture): +def test_insert_longblob(schema_any): insert_dj_blob = {"id": 1, "data": [1, 2, 3]} - schema.Longblob.insert1(insert_dj_blob) - assert (schema.Longblob & "id=1").fetch1() == insert_dj_blob - (schema.Longblob & "id=1").delete() + Longblob.insert1(insert_dj_blob) + assert (Longblob & "id=1").fetch1() == insert_dj_blob + (Longblob & "id=1").delete() query_mym_blob = {"id": 1, "data": np.array([1, 2, 3])} - schema.Longblob.insert1(query_mym_blob) - assert (schema.Longblob & "id=1").fetch1()["data"].all() == query_mym_blob[ - "data" - ].all() - (schema.Longblob & "id=1").delete() + Longblob.insert1(query_mym_blob) + assert (Longblob & "id=1").fetch1()["data"].all() == query_mym_blob["data"].all() + (Longblob & "id=1").delete() query_32_blob = ( "INSERT INTO djtest_test1.longblob (id, data) VALUES (1, " @@ -193,7 +192,7 @@ def test_insert_longblob(schema_fixture): ) dj.conn().query(query_32_blob).fetchall() dj.blob.use_32bit_dims = True - assert (schema.Longblob & "id=1").fetch1() == { + assert (Longblob & "id=1").fetch1() == { "id": 1, "data": np.rec.array( [ @@ -209,7 +208,7 @@ def test_insert_longblob(schema_fixture): dtype=[("hits", "O"), ("sides", "O"), ("tasks", "O"), ("stage", "O")], ), } - (schema.Longblob & "id=1").delete() + (Longblob & "id=1").delete() dj.blob.use_32bit_dims = False diff --git a/tests/test_dependencies.py b/tests/test_dependencies.py index 1e8b1da41..312e5f8ad 100644 --- a/tests/test_dependencies.py +++ b/tests/test_dependencies.py @@ -1,9 +1,8 @@ import datajoint as dj from datajoint import errors from pytest import raises - -from .schema import * from datajoint.dependencies import unite_master_parts +from .schema import * def test_unite_master_parts(): @@ -51,7 +50,7 @@ def test_unite_master_parts(): ] -def test_nullable_dependency(): +def test_nullable_dependency(schema_any): """test nullable unique foreign key""" # Thing C has a nullable dependency on B whose primary key is composite a = ThingA() @@ -80,7 +79,7 @@ def test_nullable_dependency(): assert len(c) == len(c.fetch()) == 5 -def test_unique_dependency(): +def test_unique_dependency(schema_any): """test nullable unique foreign key""" # Thing C has a nullable dependency on B whose primary key is composite diff --git a/tests/test_erd.py b/tests/test_erd.py index 991410995..f1274ec1b 100644 --- a/tests/test_erd.py +++ b/tests/test_erd.py @@ -1,76 +1,64 @@ import datajoint as dj -from .schema_simple import A, B, D, E, L, schema, OutfitLaunch -from . import schema_advanced +from .schema_simple import LOCALS_SIMPLE, A, B, D, E, L, OutfitLaunch +from .schema_advanced import * -namespace = locals() +def test_decorator(schema_simp): + assert issubclass(A, dj.Lookup) + assert not issubclass(A, dj.Part) + assert B.database == schema_simp.database + assert issubclass(B.C, dj.Part) + assert B.C.database == schema_simp.database + assert B.C.master is B and E.F.master is E -class TestERD: - @staticmethod - def setup_method(): - """ - class-level test setup. Executes before each test method. - """ - @staticmethod - def test_decorator(): - assert issubclass(A, dj.Lookup) - assert not issubclass(A, dj.Part) - assert B.database == schema.database - assert issubclass(B.C, dj.Part) - assert B.C.database == schema.database - assert B.C.master is B and E.F.master is E +def test_dependencies(schema_simp): + deps = schema_simp.connection.dependencies + deps.load() + assert all(cls.full_table_name in deps for cls in (A, B, B.C, D, E, E.F, L)) + assert set(A().children()) == set([B.full_table_name, D.full_table_name]) + assert set(D().parents(primary=True)) == set([A.full_table_name]) + assert set(D().parents(primary=False)) == set([L.full_table_name]) + assert set(deps.descendants(L.full_table_name)).issubset( + cls.full_table_name for cls in (L, D, E, E.F) + ) - @staticmethod - def test_dependencies(): - deps = schema.connection.dependencies - deps.load() - assert all(cls.full_table_name in deps for cls in (A, B, B.C, D, E, E.F, L)) - assert set(A().children()) == set([B.full_table_name, D.full_table_name]) - assert set(D().parents(primary=True)) == set([A.full_table_name]) - assert set(D().parents(primary=False)) == set([L.full_table_name]) - assert set(deps.descendants(L.full_table_name)).issubset( - cls.full_table_name for cls in (L, D, E, E.F) - ) - @staticmethod - def test_erd(): - assert dj.diagram.diagram_active, "Failed to import networkx and pydot" - erd = dj.ERD(schema, context=namespace) - graph = erd._make_graph() - assert set(cls.__name__ for cls in (A, B, D, E, L)).issubset(graph.nodes()) +def test_erd(schema_simp): + assert dj.diagram.diagram_active, "Failed to import networkx and pydot" + erd = dj.ERD(schema_simp, context=LOCALS_SIMPLE) + graph = erd._make_graph() + assert set(cls.__name__ for cls in (A, B, D, E, L)).issubset(graph.nodes()) - @staticmethod - def test_erd_algebra(): - erd0 = dj.ERD(B) - erd1 = erd0 + 3 - erd2 = dj.Di(E) - 3 - erd3 = erd1 * erd2 - erd4 = (erd0 + E).add_parts() - B - E - assert erd0.nodes_to_show == set(cls.full_table_name for cls in [B]) - assert erd1.nodes_to_show == set( - cls.full_table_name for cls in (B, B.C, E, E.F) - ) - assert erd2.nodes_to_show == set(cls.full_table_name for cls in (A, B, D, E, L)) - assert erd3.nodes_to_show == set(cls.full_table_name for cls in (B, E)) - assert erd4.nodes_to_show == set(cls.full_table_name for cls in (B.C, E.F)) - @staticmethod - def test_repr_svg(): - erd = dj.ERD(schema_advanced, context=namespace) - svg = erd._repr_svg_() - assert svg.startswith("") +def test_erd_algebra(schema_simp): + erd0 = dj.ERD(B) + erd1 = erd0 + 3 + erd2 = dj.Di(E) - 3 + erd3 = erd1 * erd2 + erd4 = (erd0 + E).add_parts() - B - E + assert erd0.nodes_to_show == set(cls.full_table_name for cls in [B]) + assert erd1.nodes_to_show == set(cls.full_table_name for cls in (B, B.C, E, E.F)) + assert erd2.nodes_to_show == set(cls.full_table_name for cls in (A, B, D, E, L)) + assert erd3.nodes_to_show == set(cls.full_table_name for cls in (B, E)) + assert erd4.nodes_to_show == set(cls.full_table_name for cls in (B.C, E.F)) - @staticmethod - def test_make_image(): - erd = dj.ERD(schema, context=namespace) - img = erd.make_image() - assert img.ndim == 3 and img.shape[2] in (3, 4) - @staticmethod - def test_part_table_parsing(): - # https://github.com/datajoint/datajoint-python/issues/882 - erd = dj.Di(schema) - graph = erd._make_graph() - assert "OutfitLaunch" in graph.nodes() - assert "OutfitLaunch.OutfitPiece" in graph.nodes() +def test_repr_svg(schema_adv): + erd = dj.ERD(schema_adv, context=locals()) + svg = erd._repr_svg_() + assert svg.startswith("") + + +def test_make_image(schema_simp): + erd = dj.ERD(schema_simp, context=locals()) + img = erd.make_image() + assert img.ndim == 3 and img.shape[2] in (3, 4) + + +def test_part_table_parsing(schema_simp): + # https://github.com/datajoint/datajoint-python/issues/882 + erd = dj.Di(schema_simp) + graph = erd._make_graph() + assert "OutfitLaunch" in graph.nodes() + assert "OutfitLaunch.OutfitPiece" in graph.nodes() diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index 05d87c041..18daa952a 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -1,11 +1,10 @@ from datajoint.declare import declare +from .schema_advanced import * -from . import schema_advanced - -def test_aliased_fk(): - person = schema_advanced.Person() - parent = schema_advanced.Parent() +def test_aliased_fk(schema_adv): + person = Person() + parent = Parent() person.delete() assert not person assert not parent @@ -21,21 +20,21 @@ def test_aliased_fk(): assert delete_count == 16 -def test_describe(): +def test_describe(schema_adv): """real_definition should match original definition""" - for rel in (schema_advanced.LocalSynapse, schema_advanced.GlobalSynapse): + for rel in (LocalSynapse, GlobalSynapse): describe = rel.describe() - s1 = declare( - rel.full_table_name, rel.definition, schema_advanced.schema.context - )[0].split("\n") + s1 = declare(rel.full_table_name, rel.definition, schema_adv.context)[0].split( + "\n" + ) s2 = declare(rel.full_table_name, describe, globals())[0].split("\n") for c1, c2 in zip(s1, s2): assert c1 == c2 -def test_delete(): - person = schema_advanced.Person() - parent = schema_advanced.Parent() +def test_delete(schema_adv): + person = Person() + parent = Parent() person.delete() assert not person assert not parent diff --git a/tests/test_groupby.py b/tests/test_groupby.py index 3d3be530e..109972760 100644 --- a/tests/test_groupby.py +++ b/tests/test_groupby.py @@ -1,7 +1,7 @@ from .schema_simple import A, D -def test_aggr_with_proj(): +def test_aggr_with_proj(schema_simp): # issue #944 - only breaks with MariaDB # MariaDB implements the SQL:1992 standard that prohibits fields in the select statement that are # not also in the GROUP BY statement. diff --git a/tests/test_log.py b/tests/test_log.py index a3aafa992..4b6e64613 100644 --- a/tests/test_log.py +++ b/tests/test_log.py @@ -1,8 +1,5 @@ -from . import schema - - -def test_log(): - ts, events = (schema.schema.log & 'event like "Declared%%"').fetch( +def test_log(schema_any): + ts, events = (schema_any.log & 'event like "Declared%%"').fetch( "timestamp", "event" ) assert len(ts) >= 2 diff --git a/tests/test_nan.py b/tests/test_nan.py index 38dd5036f..299c0d9f8 100644 --- a/tests/test_nan.py +++ b/tests/test_nan.py @@ -3,6 +3,7 @@ from . import PREFIX import pytest + class NanTest(dj.Manual): definition = """ id :int @@ -10,13 +11,15 @@ class NanTest(dj.Manual): value=null :double """ + @pytest.fixture(scope="module") def schema(connection_test): - schema = dj.Schema(PREFIX + "_nantest", locals(), connection=connection_test) + schema = dj.Schema(PREFIX + "_nantest", connection=connection_test) schema(NanTest) yield schema schema.drop() + @pytest.fixture(scope="class") def setup_class(request, schema): rel = NanTest() @@ -27,6 +30,7 @@ def setup_class(request, schema): request.cls.rel = rel request.cls.a = a + class TestNaNInsert: def test_insert_nan(self, setup_class): """Test fetching of null values""" diff --git a/tests/test_relation_u.py b/tests/test_relation_u.py index 44033708d..d225bccbb 100644 --- a/tests/test_relation_u.py +++ b/tests/test_relation_u.py @@ -1,6 +1,21 @@ +import pytest import datajoint as dj from pytest import raises -from . import schema, schema_simple +from .schema import * +from .schema_simple import * + + +@pytest.fixture(scope="class") +def setup_class(request, schema_any): + request.cls.user = User() + request.cls.language = Language() + request.cls.subject = Subject() + request.cls.experiment = Experiment() + request.cls.trial = Trial() + request.cls.ephys = Ephys() + request.cls.channel = Ephys.Channel() + request.cls.img = Image() + request.cls.trash = UberTrash() class TestU: @@ -8,19 +23,7 @@ class TestU: Test tables: insert, delete """ - @classmethod - def setup_class(cls): - cls.user = schema.User() - cls.language = schema.Language() - cls.subject = schema.Subject() - cls.experiment = schema.Experiment() - cls.trial = schema.Trial() - cls.ephys = schema.Ephys() - cls.channel = schema.Ephys.Channel() - cls.img = schema.Image() - cls.trash = schema.UberTrash() - - def test_restriction(self): + def test_restriction(self, setup_class): language_set = {s[1] for s in self.language.contents} rel = dj.U("language") & self.language assert list(rel.heading.names) == ["language"] @@ -32,15 +35,15 @@ def test_restriction(self): assert list(rel.primary_key) == list((rel & "trial_id>3").primary_key) assert list((dj.U("start_time") & self.trial).primary_key) == ["start_time"] - def test_invalid_restriction(self): + def test_invalid_restriction(self, setup_class): with raises(dj.DataJointError): result = dj.U("color") & dict(color="red") - def test_ineffective_restriction(self): + def test_ineffective_restriction(self, setup_class): rel = self.language & dj.U("language") assert rel.make_sql() == self.language.make_sql() - def test_join(self): + def test_join(self, setup_class): rel = self.experiment * dj.U("experiment_date") assert self.experiment.primary_key == ["subject_id", "experiment_id"] assert rel.primary_key == self.experiment.primary_key + ["experiment_date"] @@ -49,35 +52,35 @@ def test_join(self): assert self.experiment.primary_key == ["subject_id", "experiment_id"] assert rel.primary_key == self.experiment.primary_key + ["experiment_date"] - def test_invalid_join(self): + def test_invalid_join(self, setup_class): with raises(dj.DataJointError): rel = dj.U("language") * dict(language="English") - def test_repr_without_attrs(self): + def test_repr_without_attrs(self, setup_class): """test dj.U() display""" - query = dj.U().aggr(schema.Language, n="count(*)") + query = dj.U().aggr(Language, n="count(*)") repr(query) - def test_aggregations(self): - lang = schema.Language() + def test_aggregations(self, setup_class): + lang = Language() # test total aggregation on expression object n1 = dj.U().aggr(lang, n="count(*)").fetch1("n") assert n1 == len(lang.fetch()) # test total aggregation on expression class - n2 = dj.U().aggr(schema.Language, n="count(*)").fetch1("n") + n2 = dj.U().aggr(Language, n="count(*)").fetch1("n") assert n1 == n2 - rel = dj.U("language").aggr(schema.Language, number_of_speakers="count(*)") - assert len(rel) == len(set(l[1] for l in schema.Language.contents)) + rel = dj.U("language").aggr(Language, number_of_speakers="count(*)") + assert len(rel) == len(set(l[1] for l in Language.contents)) assert (rel & 'language="English"').fetch1("number_of_speakers") == 3 - def test_argmax(self): - rel = schema.TTest() + def test_argmax(self, setup_class): + rel = TTest() # get the tuples corresponding to the maximum value mx = (rel * dj.U().aggr(rel, mx="max(value)")) & "mx=value" assert mx.fetch("value")[0] == max(rel.fetch("value")) - def test_aggr(self): - rel = schema_simple.ArgmaxTest() + def test_aggr(self, setup_class, schema_simp): + rel = ArgmaxTest() amax1 = (dj.U("val") * rel) & dj.U("secondary_key").aggr(rel, val="min(val)") amax2 = (dj.U("val") * rel) * dj.U("secondary_key").aggr(rel, val="min(val)") assert ( diff --git a/tests/test_virtual_module.py b/tests/test_virtual_module.py index b7c3f23bb..bd8a0c754 100644 --- a/tests/test_virtual_module.py +++ b/tests/test_virtual_module.py @@ -2,8 +2,6 @@ from datajoint.user_tables import UserTable -def test_virtual_module(schema_obj, connection_test): - module = dj.VirtualModule( - "module", schema_obj.schema.database, connection=connection_test - ) +def test_virtual_module(schema_any, connection_test): + module = dj.VirtualModule("module", schema_any.database, connection=connection_test) assert issubclass(module.Experiment, UserTable) From 8f09fe9c3cf0b018bb0959266550982f89fd61b6 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Fri, 1 Dec 2023 14:05:56 -0600 Subject: [PATCH 5/6] remove temp conn info --- tests/__init__.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index 70381c090..de57f6eab 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -5,13 +5,6 @@ PREFIX = "djtest" -# Connection for testing -CONN_INFO = dict( - host=os.getenv("DJ_HOST"), - user=os.getenv("DJ_USER"), - password=os.getenv("DJ_PASS"), -) - CONN_INFO_ROOT = dict( host=os.getenv("DJ_HOST"), user=os.getenv("DJ_USER"), From e27147f69f52184b90dd815e9fa3f9b0da938346 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Fri, 1 Dec 2023 14:09:43 -0600 Subject: [PATCH 6/6] import cleanup --- tests/conftest.py | 2 -- tests/test_blob.py | 1 - tests/test_connection.py | 1 - 3 files changed, 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 8335b1c11..e13a13632 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,7 @@ -import sys import datajoint as dj from packaging import version import os import pytest -import inspect from . import PREFIX, schema, schema_simple, schema_advanced namespace = locals() diff --git a/tests/test_blob.py b/tests/test_blob.py index a3de2e9a9..23de7be76 100644 --- a/tests/test_blob.py +++ b/tests/test_blob.py @@ -1,4 +1,3 @@ -import pytest import datajoint as dj import timeit import numpy as np diff --git a/tests/test_connection.py b/tests/test_connection.py index 76b6d2389..795d3761e 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -6,7 +6,6 @@ from datajoint import DataJointError import numpy as np from . import CONN_INFO_ROOT - from . import PREFIX import pytest