Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion datajoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
'Manual', 'Lookup', 'Imported', 'Computed',
'conn']

# define an object that identifies the primary key in RelationalOperand.__getitem__
class PrimaryKey: pass


key = PrimaryKey


class DataJointError(Exception):
"""
Base class for errors specific to DataJoint internal operation.
Expand Down Expand Up @@ -51,4 +58,3 @@ class DataJointError(Exception):
from .relational_operand import Not
from .heading import Heading
from .schema import schema

4 changes: 2 additions & 2 deletions datajoint/declare.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def declare(full_table_name, definition, context):
in_key = False # start parsing dependent attributes
elif line.startswith('->'):
# foreign key
ref = eval(line[2:], context)() # TODO: surround this with try...except... to give a better error message
ref = eval(line[2:], context)() # TODO: surround this with try...except... to give a better error message
foreign_key_sql.append(
'FOREIGN KEY ({primary_key})'
' REFERENCES {ref} ({primary_key})'
Expand All @@ -65,7 +65,7 @@ def declare(full_table_name, definition, context):
# compile SQL
if not primary_key:
raise DataJointError('Table must have a primary key')
sql = 'CREATE TABLE %s (\n ' % full_table_name
sql = 'CREATE TABLE IF NOT EXISTS %s (\n ' % full_table_name
sql += ',\n '.join(attribute_sql)
sql += ',\n PRIMARY KEY (`' + '`,`'.join(primary_key) + '`)'
if foreign_key_sql:
Expand Down
3 changes: 1 addition & 2 deletions datajoint/erd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class ERM:

Represents known relation between tables
"""
#_checked_dependencies = set()
# _checked_dependencies = set()

def __init__(self, conn):
self._conn = conn
Expand All @@ -31,7 +31,6 @@ def __init__(self, conn):
self._children = defaultdict(list)
self._references = defaultdict(list)


def load_dependencies(self, full_table_name):
# check if already loaded. Use clear_dependencies before reloading
if full_table_name in self._parents:
Expand Down
37 changes: 37 additions & 0 deletions datajoint/relational_operand.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections import OrderedDict
from copy import copy
from . import config
from . import key as PRIMARY_KEY
from . import DataJointError
import logging

Expand Down Expand Up @@ -310,6 +311,41 @@ def make_condition(arg):

return ' WHERE ' + ' AND '.join(condition_string)

def __getitem__(self, item): # TODO: implement dj.key and primary key return

attr_keys = list(self.heading.attributes.keys())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

attr_keys = self.heading.names

key_index = None

# prepare arguments for project
if isinstance(item, str):
args = (item,)
elif item is PRIMARY_KEY: # this one we return directly, since it is easy
return self.project().fetch()
elif isinstance(item, list) or isinstance(item, tuple):
args = tuple(i for i in item if not i is PRIMARY_KEY)
if PRIMARY_KEY in item:
key_index = item.index(PRIMARY_KEY)
elif isinstance(item, slice):
start = attr_keys.index(item.start) if isinstance(item.start, str) else item.start
stop = attr_keys.index(item.stop) if isinstance(item.stop, str) else item.stop
item = slice(start, stop, item.step)
args = attr_keys[item]
elif isinstance(item, int):
args = attr_keys[item]
else:
raise DataJointError("Index must be a slice, a tuple, a list, or a string.")

tmp = self.project(*args).fetch()
if key_index is None:
return tuple(tmp[e] for e in args)
else:
retval = [tmp[e] for e in args]

dtype2 = np.dtype({name: tmp.dtype.fields[name] for name in self.primary_key})
tmp2 = np.unique(np.ndarray(tmp.shape, dtype2, tmp, 0, tmp.strides))
retval.insert(key_index, tmp2)
return retval


class Not:
"""
Expand Down Expand Up @@ -370,6 +406,7 @@ def __init__(self, arg, group=None, *attributes, **renamed_attributes):
if group:
if arg.connection != group.connection:
raise DataJointError('Cannot join relations with different database connections')
# TODO: don't Subquery if not necessary (if does not have some types of restrictions)
self._group = Subquery(group)
self._arg = Subquery(arg)
else:
Expand Down
6 changes: 2 additions & 4 deletions datajoint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ def to_camel_case(s):
>>>to_camel_case("table_name")
"TableName"
"""

def to_upper(match):
return match.group(0)[-1].upper()

return re.sub('(^|[_\W])+[a-zA-Z]', to_upper, s)


Expand All @@ -63,7 +65,3 @@ def convert(match):
raise DataJointError(
'ClassName must be alphanumeric in CamelCase, begin with a capital letter')
return re.sub(r'(\B[A-Z])|(\b[A-Z])', convert, s)




6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

here = path.abspath(path.dirname(__file__))

#with open(path.join(here, 'VERSION')) as version_file:
# with open(path.join(here, 'VERSION')) as version_file:
# version = version_file.read().strip()
long_description="An object-relational mapping and relational algebra to facilitate data definition and data manipulation in MySQL databases."
long_description = "An object-relational mapping and relational algebra to facilitate data definition and data manipulation in MySQL databases."


setup(
Expand All @@ -16,7 +16,7 @@
long_description=long_description,
author='Dimitri Yatsenko',
author_email='Dimitri.Yatsenko@gmail.com',
license = "MIT",
license = "GNU LGPL",
url='https://github.com/datajoint/datajoint-python',
keywords='database organization',
packages=find_packages(exclude=['contrib', 'docs', 'tests*']),
Expand Down
2 changes: 1 addition & 1 deletion tests/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _make_tuples(self, key):
sampling_frequency=16000,
duration=random.expovariate(1/30))
self.insert1(row)
EphysChannel().fill(key, number_samples=round(row.duration*row.sampling_frequency))
EphysChannel().fill(key, number_samples=round(row.duration * row.sampling_frequency))


@schema
Expand Down
24 changes: 12 additions & 12 deletions tests/test_declare.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,34 +13,34 @@ def __init__(self):

def test_attributes(self):
assert_list_equal(self.subject.heading.names,
['subject_id', 'real_id', 'species', 'date_of_birth', 'subject_notes'])
['subject_id', 'real_id', 'species', 'date_of_birth', 'subject_notes'])
assert_list_equal(self.subject.primary_key,
['subject_id'])
['subject_id'])
assert_true(self.subject.heading.attributes['subject_id'].numeric)
assert_false(self.subject.heading.attributes['real_id'].numeric)

experiment = schema.Experiment()
assert_list_equal(experiment.heading.names,
['subject_id', 'experiment_id', 'experiment_date',
'username', 'data_path',
'notes', 'entry_time'])
['subject_id', 'experiment_id', 'experiment_date',
'username', 'data_path',
'notes', 'entry_time'])
assert_list_equal(experiment.primary_key,
['subject_id', 'experiment_id'])
['subject_id', 'experiment_id'])

assert_list_equal(self.trial.heading.names,
['subject_id', 'experiment_id', 'trial_id', 'start_time'])
['subject_id', 'experiment_id', 'trial_id', 'start_time'])
assert_list_equal(self.trial.primary_key,
['subject_id', 'experiment_id', 'trial_id'])
['subject_id', 'experiment_id', 'trial_id'])

assert_list_equal(self.ephys.heading.names,
['subject_id', 'experiment_id', 'trial_id', 'sampling_frequency', 'duration'])
['subject_id', 'experiment_id', 'trial_id', 'sampling_frequency', 'duration'])
assert_list_equal(self.ephys.primary_key,
['subject_id', 'experiment_id', 'trial_id'])
['subject_id', 'experiment_id', 'trial_id'])

assert_list_equal(self.channel.heading.names,
['subject_id', 'experiment_id', 'trial_id', 'channel', 'voltage'])
['subject_id', 'experiment_id', 'trial_id', 'channel', 'voltage'])
assert_list_equal(self.channel.primary_key,
['subject_id', 'experiment_id', 'trial_id', 'channel'])
['subject_id', 'experiment_id', 'trial_id', 'channel'])
assert_true(self.channel.heading.attributes['voltage'].is_blob)

def test_dependencies(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
assert_tuple_equal, assert_dict_equal, raises

from . import schema

import datajoint as dj

class TestRelation:
"""
Expand Down
35 changes: 34 additions & 1 deletion tests/test_relational_operand.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
from operator import itemgetter
from numpy.testing import assert_array_equal
import numpy as np

from . import schema
import datajoint as dj

# """
# Collection of test cases to test relational methods
# """
Expand Down Expand Up @@ -44,4 +51,30 @@
# pass
#
# def test_not(self):
# pass
# pass

class TestRelationalOperand:
def __init__(self):
self.subject = schema.Subject()

def test_getitem(self):
"""Testing RelationalOperand.__getitem__"""

np.testing.assert_array_equal(sorted(self.subject.project().fetch(), key=itemgetter(0)),
sorted(self.subject[dj.key], key=itemgetter(0)),
'Primary key is not returned correctly')

tmp = self.subject.fetch(order_by=['subject_id'])

for column, field in zip(self.subject[:], [e[0] for e in tmp.dtype.descr]):
np.testing.assert_array_equal(sorted(tmp[field]), sorted(column), 'slice : does not work correctly')

subject_notes, key, real_id = self.subject['subject_notes', dj.key, 'real_id']

np.testing.assert_array_equal(sorted(subject_notes), sorted(tmp['subject_notes']))
np.testing.assert_array_equal(sorted(real_id), sorted(tmp['real_id']))
np.testing.assert_array_equal(sorted(key, key=itemgetter(0)),
sorted(self.subject.project().fetch(), key=itemgetter(0)))

for column, field in zip(self.subject['subject_id'::2], [e[0] for e in tmp.dtype.descr][::2]):
np.testing.assert_array_equal(sorted(tmp[field]), sorted(column), 'slice : does not work correctly')