class PrimaryKey:
    """
    Marker that identifies the primary key in RelationalOperand.__getitem__.

    The class object itself is used as the marker; it carries no state.
    """


# short alias under which the marker is exposed by the package
key = PrimaryKey
to give a better error message foreign_key_sql.append( 'FOREIGN KEY ({primary_key})' ' REFERENCES {ref} ({primary_key})' @@ -65,7 +65,7 @@ def declare(full_table_name, definition, context): # compile SQL if not primary_key: raise DataJointError('Table must have a primary key') - sql = 'CREATE TABLE %s (\n ' % full_table_name + sql = 'CREATE TABLE IF NOT EXISTS %s (\n ' % full_table_name sql += ',\n '.join(attribute_sql) sql += ',\n PRIMARY KEY (`' + '`,`'.join(primary_key) + '`)' if foreign_key_sql: diff --git a/datajoint/erd.py b/datajoint/erd.py index e23643188..0be844496 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -22,7 +22,7 @@ class ERM: Represents known relation between tables """ - #_checked_dependencies = set() + # _checked_dependencies = set() def __init__(self, conn): self._conn = conn @@ -31,7 +31,6 @@ def __init__(self, conn): self._children = defaultdict(list) self._references = defaultdict(list) - def load_dependencies(self, full_table_name): # check if already loaded. Use clear_dependencies before reloading if full_table_name in self._parents: diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index 1f4fa3d21..8ba08ecc0 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -8,6 +8,7 @@ from collections import OrderedDict from copy import copy from . import config +from . import key as PRIMARY_KEY from . 
def __getitem__(self, item):
    """
    Fetch the selected attributes, one column per requested attribute.

    :param item: selects the attributes to fetch. Accepted forms:

        * the primary key marker (imported in this module as PRIMARY_KEY) --
          the result of ``self.project().fetch()`` is returned unchanged;
        * a string -- a single attribute name;
        * an int -- the position of the attribute in ``self.heading.attributes``;
        * a slice -- positions in ``self.heading.attributes``; string bounds are
          translated to the positions of the attributes with those names;
        * a list or tuple of attribute names, which may also contain the
          primary key marker.
    :return: a tuple with one fetched column per selected attribute. If a list
        or tuple containing the primary key marker was passed, a list is
        returned instead, with the primary key columns in place of each marker.
    :raises DataJointError: if item has none of the accepted forms.
    """
    if item is PRIMARY_KEY:
        # only the key is requested: the bare projection is returned directly
        return self.project().fetch()

    attr_keys = list(self.heading.attributes.keys())
    with_key = False

    # normalise every accepted form into a tuple of requested entries
    if isinstance(item, str):
        requested = (item,)
    elif isinstance(item, (list, tuple)):
        requested = tuple(item)
        # identity test, consistent with how the marker is filtered out below
        with_key = any(entry is PRIMARY_KEY for entry in requested)
    elif isinstance(item, slice):
        # string bounds name attributes; translate them to heading positions
        start = attr_keys.index(item.start) if isinstance(item.start, str) else item.start
        stop = attr_keys.index(item.stop) if isinstance(item.stop, str) else item.stop
        requested = tuple(attr_keys[slice(start, stop, item.step)])
    elif isinstance(item, int):
        # wrap in a tuple: a bare name would be unpacked character by
        # character by project(*args) and by the column lookup below
        requested = (attr_keys[item],)
    else:
        raise DataJointError(
            "Index must be a slice, a tuple, a list, an integer, a string, or the primary key.")

    args = tuple(entry for entry in requested if entry is not PRIMARY_KEY)
    tmp = self.project(*args).fetch()
    if not with_key:
        return tuple(tmp[name] for name in args)

    # Re-interpret the fetched buffer through a dtype that exposes only the
    # primary key fields, reusing their (dtype, offset) entries and the
    # original strides.
    key_dtype = np.dtype({name: tmp.dtype.fields[name] for name in self.primary_key})
    # NOTE(review): np.unique returns the key rows sorted, which may not match
    # the row order of the other returned columns -- confirm this is intended.
    keys = np.unique(np.ndarray(tmp.shape, key_dtype, tmp, 0, tmp.strides))

    # every marker position receives the key columns, not just the first one
    return [keys if entry is PRIMARY_KEY else tmp[entry] for entry in requested]
def test_attributes(self):
    # Compares the heading names and primary key of each declared relation
    # with the expected lists, plus a few attribute-type flags.
    # (Kept as comments: the original method has no docstring.)

    def check(relation, expected_names, expected_key):
        # heading first, then primary key -- same order as the inline checks
        assert_list_equal(relation.heading.names, expected_names)
        assert_list_equal(relation.primary_key, expected_key)

    check(self.subject,
          ['subject_id', 'real_id', 'species', 'date_of_birth', 'subject_notes'],
          ['subject_id'])
    assert_true(self.subject.heading.attributes['subject_id'].numeric)
    assert_false(self.subject.heading.attributes['real_id'].numeric)

    # instantiated only after the subject checks have passed
    experiment = schema.Experiment()
    check(experiment,
          ['subject_id', 'experiment_id', 'experiment_date',
           'username', 'data_path',
           'notes', 'entry_time'],
          ['subject_id', 'experiment_id'])

    check(self.trial,
          ['subject_id', 'experiment_id', 'trial_id', 'start_time'],
          ['subject_id', 'experiment_id', 'trial_id'])

    check(self.ephys,
          ['subject_id', 'experiment_id', 'trial_id', 'sampling_frequency', 'duration'],
          ['subject_id', 'experiment_id', 'trial_id'])

    check(self.channel,
          ['subject_id', 'experiment_id', 'trial_id', 'channel', 'voltage'],
          ['subject_id', 'experiment_id', 'trial_id', 'channel'])
    assert_true(self.channel.heading.attributes['voltage'].is_blob)
class TestRelationalOperand:
    # Exercises item access on a relation built from schema.Subject.

    def __init__(self):
        self.subject = schema.Subject()

    def test_getitem(self):
        """Testing RelationalOperand.__getitem__"""
        by_first_field = itemgetter(0)
        slice_message = 'slice : does not work correctly'

        # the primary key marker alone must match the bare projection
        expected_keys = sorted(self.subject.project().fetch(), key=by_first_field)
        fetched_keys = sorted(self.subject[dj.key], key=by_first_field)
        np.testing.assert_array_equal(expected_keys, fetched_keys,
                                      'Primary key is not returned correctly')

        tmp = self.subject.fetch(order_by=['subject_id'])

        # a full slice yields one column per field of the fetched array
        all_columns = self.subject[:]
        field_names = [entry[0] for entry in tmp.dtype.descr]
        for column, field in zip(all_columns, field_names):
            np.testing.assert_array_equal(sorted(tmp[field]), sorted(column), slice_message)

        # attribute names mixed with the primary key marker
        subject_notes, key, real_id = self.subject['subject_notes', dj.key, 'real_id']

        np.testing.assert_array_equal(sorted(subject_notes), sorted(tmp['subject_notes']))
        np.testing.assert_array_equal(sorted(real_id), sorted(tmp['real_id']))
        sorted_key = sorted(key, key=by_first_field)
        reference_key = sorted(self.subject.project().fetch(), key=by_first_field)
        np.testing.assert_array_equal(sorted_key, reference_key)

        # slice starting at an attribute name, taking every second attribute
        stepped_columns = self.subject['subject_id'::2]
        stepped_fields = [entry[0] for entry in tmp.dtype.descr][::2]
        for column, field in zip(stepped_columns, stepped_fields):
            np.testing.assert_array_equal(sorted(tmp[field]), sorted(column), slice_message)