Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datajoint/free_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def batch_insert(self, data, **kwargs):
"""
self.iter_insert(data.__iter__(), **kwargs)

def insert(self, tup, ignore_errors=False, replace=False): # TODO: in progress (issue #8)
def insert(self, tup, ignore_errors=False, replace=False):
"""
Insert one data record or one Mapping (like a dictionary).

Expand Down
5 changes: 3 additions & 2 deletions datajoint/relational_operand.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def make_condition(arg):
return ' AND '.join(conditions)

condition_string = []
for r in self._restrictions:
for r in self.restrictions:
negate = isinstance(r, Not)
if negate:
r = r.restriction
Expand All @@ -227,7 +227,8 @@ def make_condition(arg):
r = '('+') OR ('.join([make_condition(q) for q in r])+')'
elif isinstance(r, RelationalOperand):
common_attributes = ','.join([q for q in self.heading.names if r.heading.names])
r = '(%s) in (SELECT %s FROM %s)' % (common_attributes, common_attributes, r.from_clause)
r = '(%s) in (SELECT %s FROM %s%s)' % (
common_attributes, common_attributes, r.from_clause, r.where_clause)

assert isinstance(r, str), 'condition must be converted into a string'
r = '('+r+')'
Expand Down
1 change: 0 additions & 1 deletion demos/demo1.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
"""

import datajoint as dj
import os

print("Welcome to the database 'demo1'")

Expand Down
20 changes: 20 additions & 0 deletions tests/schemata/schema1/test1.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import datajoint as dj
from .. import schema2


class Subjects(dj.Relation):
definition = """
test1.Subjects (manual) # Basic subject info
Expand All @@ -16,6 +17,21 @@ class Subjects(dj.Relation):
species = "mouse" : enum('mouse', 'monkey', 'human') # species
"""


class Trials(dj.Relation):
definition = """
test1.Trials (manual) # info about trials

-> test1.Subjects
trial_id : int
---
outcome : int # result of experiment

notes="" : varchar(4096) # other comments
trial_ts=CURRENT_TIMESTAMP : timestamp # automatic
"""


# test reference to another table in same schema
class Experiments(dj.Relation):
definition = """
Expand All @@ -26,6 +42,7 @@ class Experiments(dj.Relation):
exp_data_file : varchar(255) # data file
"""


# refers to a table in dj_test2 (bound to test2) but without a class
class Sessions(dj.Relation):
definition = """
Expand All @@ -37,6 +54,7 @@ class Sessions(dj.Relation):
session_comment : varchar(255) # comment about the session
"""


class Match(dj.Relation):
definition = """
test1.Match (manual) # Match between subject and color
Expand All @@ -45,6 +63,7 @@ class Match(dj.Relation):
dob : date # date of birth
"""


# this tries to reference a table in database directly without ORM
class TrainingSession(dj.Relation):
definition = """
Expand All @@ -53,5 +72,6 @@ class TrainingSession(dj.Relation):
session_id : int # training session id
"""


class Empty(dj.Relation):
pass
3 changes: 3 additions & 0 deletions tests/schemata/schema1/test2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class Experiments(dj.Relation):
species = "mouse" : enum('mouse', 'monkey', 'human') # species
"""


# references to another schema
class Conditions(dj.Relation):
definition = """
Expand All @@ -27,13 +28,15 @@ class Conditions(dj.Relation):
condition_name : varchar(255) # description of the condition
"""


class FoodPreference(dj.Relation):
definition = """
test2.FoodPreference (manual) # Food preference of each subject
-> animals.Subjects
preferred_food : enum('banana', 'apple', 'oranges')
"""


class Session(dj.Relation):
definition = """
test2.Session (manual) # Experiment sessions
Expand Down
1 change: 0 additions & 1 deletion tests/schemata/schema2/test1.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import datajoint as dj



class Subjects(dj.Relation):
definition = """
schema2.Subjects (manual) # Basic subject info
Expand Down
10 changes: 9 additions & 1 deletion tests/test_blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@


def test_pack():
x = np.random.randn(10, 10)
x = np.random.randn(8, 10)
assert_array_equal(x, unpack(pack(x)), "Arrays do not match!")

x = np.random.randn(10)
assert_array_equal(x, unpack(pack(x)), "Arrays do not match!")

x = np.float32(np.random.randn(3, 4, 5))
assert_array_equal(x, unpack(pack(x)), "Arrays do not match!")

x = np.int16(np.random.randn(1, 2, 3))
assert_array_equal(x, unpack(pack(x)), "Arrays do not match!")
70 changes: 42 additions & 28 deletions tests/test_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@
import numpy as np
from numpy.testing import assert_array_equal
from datajoint.free_relation import FreeRelation
import numpy as np


def trial_faker(n=10):
def iter():
for s in [1, 2]:
for i in range(n):
yield dict(trial_id=i, subject_id=s, outcome=int(np.random.randint(10)), notes= 'no comment')
return iter()


def setup():
"""
Expand All @@ -23,7 +33,7 @@ def setup():

class TestTableObject(object):
def __init__(self):
self.relvar = None
self.subjects = None
self.setup()

"""
Expand All @@ -38,60 +48,64 @@ def setup(self):
"""
cleanup() # drop all databases with PREFIX
test1.__dict__.pop('conn', None)
test4.__dict__.pop('conn', None) # make sure conn is not defined at schema level
test4.__dict__.pop('conn', None) # make sure conn is not defined at schema level

self.conn = Connection(**CONN_INFO)
test1.conn = self.conn
test4.conn = self.conn
self.conn.bind(test1.__name__, PREFIX + '_test1')
self.conn.bind(test4.__name__, PREFIX + '_test4')
self.relvar = test1.Subjects()
self.subjects = test1.Subjects()
self.relvar_blob = test4.Matrix()
self.trials = test1.Trials()

def teardown(self):
cleanup()

def test_compound_restriction(self):
s = self.subjects
t = self.trials

s.insert(dict(subject_id=1, real_id='M' ))
s.insert(dict(subject_id=2, real_id='F' ))
t.iter_insert(trial_faker(20))

tM = t & (s & "real_id = 'M'")
t1 = t & "subject_id = 1"

# def test_tuple_insert(self):
# "Test whether tuple insert works"
# testt = (1, 'Peter', 'mouse')
# self.relvar.insert(testt)
# testt2 = tuple((self.relvar & 'subject_id = 1').fetch()[0])
# assert_equal(testt2, testt, "Inserted and fetched tuple do not match!")
assert_equal(tM.count, t1.count, "Results of compound request does not have same length")

# def test_list_insert(self):
# "Test whether tuple insert works"
# testt = [1, 'Peter', 'mouse']
# self.relvar.insert(testt)
# testt2 = list((self.relvar & 'subject_id = 1').fetch()[0])
# assert_equal(testt2, testt, "Inserted and fetched tuple do not match!")
for t1_item, tM_item in zip(sorted(t1, key=lambda item: item['trial_id']),
sorted(tM, key=lambda item: item['trial_id'])):
assert_dict_equal(t1_item, tM_item,
'Dictionary elements do not agree in compound statement')

def test_record_insert(self):
"Test whether record insert works"
tmp = np.array([(2, 'Klara', 'monkey')],
dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')])

self.relvar.insert(tmp[0])
testt2 = (self.relvar & 'subject_id = 2').fetch()[0]
self.subjects.insert(tmp[0])
testt2 = (self.subjects & 'subject_id = 2').fetch()[0]
assert_equal(tuple(tmp[0]), tuple(testt2), "Inserted and fetched record do not match!")

def test_record_insert_different_order(self):
"Test whether record insert works"
tmp = np.array([('Klara', 2, 'monkey')],
dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')])

self.relvar.insert(tmp[0])
testt2 = (self.relvar & 'subject_id = 2').fetch()[0]
assert_equal((2, 'Klara', 'monkey'), tuple(testt2), "Inserted and fetched record do not match!")
self.subjects.insert(tmp[0])
testt2 = (self.subjects & 'subject_id = 2').fetch()[0]
assert_equal((2, 'Klara', 'monkey'), tuple(testt2),
"Inserted and fetched record do not match!")

@raises(KeyError)
def test_wrong_key_insert_records(self):
"Test whether record insert works"
tmp = np.array([('Klara', 2, 'monkey')],
dtype=[('real_deal', 'O'), ('subject_id', '>i4'), ('species', 'O')])

self.relvar.insert(tmp[0])
self.subjects.insert(tmp[0])


def test_dict_insert(self):
Expand All @@ -100,8 +114,8 @@ def test_dict_insert(self):
'subject_id': 3,
'species': 'human'}

self.relvar.insert(tmp)
testt2 = (self.relvar & 'subject_id = 3').fetch()[0]
self.subjects.insert(tmp)
testt2 = (self.subjects & 'subject_id = 3').fetch()[0]
assert_equal((3, 'Brunhilda', 'human'), tuple(testt2), "Inserted and fetched record do not match!")

@raises(KeyError)
Expand All @@ -111,19 +125,19 @@ def test_wrong_key_insert(self):
'subject_database': 3,
'species': 'human'}

self.relvar.insert(tmp)
self.subjects.insert(tmp)

def test_batch_insert(self):
"Test whether record insert works"
tmp = np.array([('Klara', 2, 'monkey'), ('Brunhilda', 3, 'mouse'), ('Mickey', 1, 'human')],
dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')])

self.relvar.batch_insert(tmp)
self.subjects.batch_insert(tmp)

expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'),
(3, 'Brunhilda', 'mouse')],
dtype=[('subject_id', '<i4'), ('real_id', 'O'), ('species', 'O')])
delivered = self.relvar.fetch()
delivered = self.subjects.fetch()

for e,d in zip(expected, delivered):
assert_equal(tuple(e), tuple(d),'Inserted and fetched records do not match')
Expand All @@ -133,12 +147,12 @@ def test_iter_insert(self):
tmp = np.array([('Klara', 2, 'monkey'), ('Brunhilda', 3, 'mouse'), ('Mickey', 1, 'human')],
dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')])

self.relvar.iter_insert(tmp.__iter__())
self.subjects.iter_insert(tmp.__iter__())

expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'),
(3, 'Brunhilda', 'mouse')],
dtype=[('subject_id', '<i4'), ('real_id', 'O'), ('species', 'O')])
delivered = self.relvar.fetch()
delivered = self.subjects.fetch()

for e,d in zip(expected, delivered):
assert_equal(tuple(e), tuple(d),'Inserted and fetched records do not match')
Expand Down