diff --git a/datajoint/__init__.py b/datajoint/__init__.py
index cffc707e1..b745ab686 100644
--- a/datajoint/__init__.py
+++ b/datajoint/__init__.py
@@ -18,7 +18,7 @@
            'Connection', 'Heading', 'Relation', 'FreeRelation', 'Not', 'Relation',
            'schema', 'Manual', 'Lookup', 'Imported', 'Computed',
-           'conn']
+           'conn', 'kill']
 
 # define an object that identifies the primary key in RelationalOperand.__getitem__
 class PrimaryKey:
     pass
@@ -58,3 +58,4 @@ class DataJointError(Exception):
 from .relational_operand import Not
 from .heading import Heading
 from .schema import schema
+from .kill import kill
diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py
index 729c22441..7d2b43f95 100644
--- a/datajoint/autopopulate.py
+++ b/datajoint/autopopulate.py
@@ -4,8 +4,7 @@
 import datetime
 from .relational_operand import RelationalOperand
 from . import DataJointError
-from .relation import Relation, FreeRelation
-from . import jobs
+from .relation import FreeRelation
 
 
 # noinspection PyExceptionInherit,PyCallingNonCallable
@@ -47,6 +46,10 @@ def _make_tuples(self, key):
 
     @property
     def target(self):
+        """
+        relation to be populated.
+        Typically, AutoPopulate is mixed into a Relation object and the target is self.
+        """
         return self
 
     def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False):
@@ -68,7 +71,7 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False):
         jobs = self.connection.jobs[self.target.database]
         table_name = self.target.table_name
 
-        unpopulated = (self.populated_from - self.target) & restriction
+        unpopulated = (self.populated_from & restriction) - self.target.project()
         for key in unpopulated.fetch.keys():
             if not reserve_jobs or jobs.reserve(table_name, key):
                 self.connection.start_transaction()
@@ -95,14 +98,17 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False):
                     jobs.complete(table_name, key)
         return error_list
 
-    def progress(self):
+    def progress(self, restriction=None, display=True):
         """
         report progress of populating this table
+        :return: remaining, total -- tuples to be populated
         """
-        total = len(self.populated_from)
-        remaining = len(self.populated_from - self.target)
-        print('Completed %d of %d (%2.1f%%) %s' %
-              (total - remaining, total, 100 - 100 * remaining / total,
-               datetime.datetime.strftime(datetime.datetime.now(), '%Y-%m-%d %H:%M:%S')
-               ) if remaining
-              else 'Complete', flush=True)
+        total = len(self.populated_from & restriction)
+        remaining = len((self.populated_from & restriction) - self.target.project())
+        if display:
+            print('%-20s' % self.__class__.__name__, flush=True, end=': ')
+            print('Completed %d of %d (%2.1f%%) %s' %
+                  (total - remaining, total, 100 - 100 * remaining / (total+1e-12),
+                   datetime.datetime.strftime(datetime.datetime.now(), '%Y-%m-%d %H:%M:%S')
+                   ), flush=True)
+        return remaining, total
diff --git a/datajoint/connection.py b/datajoint/connection.py
index f059debd6..3c7559f94 100644
--- a/datajoint/connection.py
+++ b/datajoint/connection.py
@@ -1,10 +1,10 @@
 """
-This module hosts the Connection class that manages the connection to the mysql database via
-`pymysql`, and the `conn` function that provides access to a persistent connection in datajoint.
-
+This module hosts the Connection class that manages the connection to the mysql database,
+ and the `conn` function that provides access to a persistent connection in datajoint.
 """
+
 from contextlib import contextmanager
-import pymysql
+import pymysql as connector
 import logging
 from . import config
 from . import DataJointError
@@ -59,7 +59,7 @@ def __init__(self, host, user, passwd, init_fun=None):
         else:
             port = config['database.port']
         self.conn_info = dict(host=host, port=port, user=user, passwd=passwd)
-        self._conn = pymysql.connect(init_command=init_fun, **self.conn_info)
+        self._conn = connector.connect(**self.conn_info)
         if self.is_connected:
             logger.info("Connected {user}@{host}:{port}".format(**self.conn_info))
         else:
@@ -96,11 +96,11 @@ def query(self, query, args=(), as_dict=False):
         Execute the specified query and return the tuple generator (cursor).
 
         :param query: mysql query
-        :param args: additional arguments for the pymysql.cursor
+        :param args: additional arguments for the connector.cursor
         :param as_dict: If as_dict is set to True, the returned cursor objects returns query
                         results as dictionary.
         """
-        cursor = pymysql.cursors.DictCursor if as_dict else pymysql.cursors.Cursor
+        cursor = connector.cursors.DictCursor if as_dict else connector.cursors.Cursor
         cur = self._conn.cursor(cursor=cursor)
 
         # Log the query
diff --git a/datajoint/declare.py b/datajoint/declare.py
index f725bdbf5..8b357b8c0 100644
--- a/datajoint/declare.py
+++ b/datajoint/declare.py
@@ -35,7 +35,7 @@ def declare(full_table_name, definition, context):
     for line in definition:
         if line.startswith('#'):  # additional comments are ignored
             pass
-        elif line.startswith('---'):
+        elif line.startswith('---') or line.startswith('___'):
            in_key = False    # start parsing dependent attributes
         elif line.startswith('->'):
             # foreign key
diff --git a/datajoint/erd.py b/datajoint/erd.py
index a66c53be5..0a3e2a9ca 100644
--- a/datajoint/erd.py
+++ b/datajoint/erd.py
@@ -212,14 +212,12 @@ def up_down_neighbors(self, node, ups=2, downs=2, _prev=None):
         s = {node}
         if ups > 0:
             for x in self.predecessors_iter(node):
-                if x == _prev:
-                    continue
-                s.update(self.up_down_neighbors(x, ups-1, downs, node))
+                if x != _prev:
+                    s.update(self.up_down_neighbors(x, ups-1, downs, node))
         if downs > 0:
             for x in self.successors_iter(node):
-                if x == _prev:
-                    continue
-                s.update(self.up_down_neighbors(x, ups, downs-1, node))
+                if x != _prev:
+                    s.update(self.up_down_neighbors(x, ups, downs-1, node))
         return s
 
     def n_neighbors(self, node, n, directed=False, prev=None):
@@ -234,23 +232,20 @@ def n_neighbors(self, node, n, directed=False, prev=None):
         Set directed=True to follow only outgoing edges.
         """
         s = {node}
-        if n < 1:
-            return s
         if n == 1:
             s.update(self.predecessors(node))
             s.update(self.successors(node))
-            return s
-        if not directed:
-            for x in self.predecesors_iter():
-                if x == prev:  # skip prev point
-                    continue
-                s.update(self.n_neighbors(x, n-1, prev))
-            for x in self.succesors_iter():
-                if x == prev:
-                    continue
-                s.update(self.n_neighbors(x, n-1, prev))
+        elif n > 1:
+            if not directed:
+                for x in self.predecesors_iter():
+                    if x != prev:  # skip prev point
+                        s.update(self.n_neighbors(x, n-1, prev))
+                for x in self.succesors_iter():
+                    if x != prev:
+                        s.update(self.n_neighbors(x, n-1, prev))
         return s
 
+
 class ERM(RelGraph):
     """
     Entity Relation Map
@@ -278,9 +273,7 @@ def update_graph(self, reload=False):
         # create primary key foreign connections
         for table, parents in self._parents.items():
             mod, cls = (x.strip('`') for x in table.split('.'))
-
-            self.add_node(table, label=table,
-                          mod=mod, cls=cls)
+            self.add_node(table, label=table, mod=mod, cls=cls)
             for parent in parents:
                 self.add_edge(parent, table, rel='parent')
diff --git a/datajoint/fetch.py b/datajoint/fetch.py
index 6b0e7a76a..5d7a079b5 100644
--- a/datajoint/fetch.py
+++ b/datajoint/fetch.py
@@ -26,6 +26,9 @@ def prepare_attributes(relation, item):
 
 
 def copy_first(f):
+    """
+    decorates methods that return an altered copy of self
+    """
     @wraps(f)
     def ret(*args, **kwargs):
         args = list(args)
@@ -70,7 +73,6 @@ def offset(self, offset):
         self.behavior['offset'] = offset
         return self
 
-    @copy_first
     def set_behavior(self, **kwargs):
         self.behavior.update(kwargs)
 
@@ -90,7 +92,8 @@ def __call__(self, **kwargs):
         """
         behavior = dict(self.behavior, **kwargs)
         if behavior['limit'] is None and behavior['offset'] is not None:
-            warnings.warn('Offset set, but no limit. Setting limit to a large number. Consider setting a limit yourself.')
+            warnings.warn('Offset set, but no limit. Setting limit to a large number. '
+                          'Consider setting a limit explicitly.')
             behavior['limit'] = 2*len(self._relation)
 
         cur = self._relation.cursor(**behavior)
@@ -180,7 +183,9 @@ def __repr__(self):
 
     def __len__(self):
         return len(self._relation)
 
+
 class Fetch1:
+
     def __init__(self, relation):
         self._relation = relation
diff --git a/datajoint/kill.py b/datajoint/kill.py
new file mode 100644
index 000000000..82c9ac38d
--- /dev/null
+++ b/datajoint/kill.py
@@ -0,0 +1,47 @@
+import pymysql
+from . import conn
+
+
+def kill(restriction=None, connection=None):
+    """
+    view and kill database connections.
+    :param restriction: restriction to be applied to processlist
+    :param connection: a datajoint.Connection object. Default calls datajoint.conn()
+
+    Restrictions are specified as strings and can involve any of the attributes of
+    information_schema.processlist: ID, USER, HOST, DB, COMMAND, TIME, STATE, INFO.
+
+    Examples:
+        dj.kill('HOST LIKE "%compute%"') lists only connections from hosts containing "compute".
+        dj.kill('TIME > 600') lists only connections older than 10 minutes.
+    """
+
+    if connection is None:
+        connection = conn()
+
+    query = 'SELECT * FROM information_schema.processlist WHERE id <> CONNECTION_ID()'
+    if restriction is not None:
+        query += ' AND (%s)' % restriction
+
+    while True:
+        print('  ID  USER          STATE         TIME   INFO')
+        print('+--+ +----------+ +-----------+ +--+')
+        for process in connection.query(query, as_dict=True).fetchall():
+            try:
+                print('{ID:>4d} {USER:<12s} {STATE:<12s} {TIME:>5d}  {INFO}'.format(**process))
+            except TypeError as err:
+                print(process)
+
+        response = input('process to kill or "q" to quit: ')
+        if response == 'q':
+            break
+        if response:
+            try:
+                id = int(response)
+            except ValueError:
+                pass  # ignore non-numeric input
+            else:
+                try:
+                    connection.query('kill %d' % id)
+                except pymysql.err.InternalError:
+                    print('Process not found')
\ No newline at end of file
diff --git a/datajoint/relation.py b/datajoint/relation.py
index d8de9c962..c93826c3a 100644
--- a/datajoint/relation.py
+++ b/datajoint/relation.py
@@ -1,4 +1,4 @@
-from collections import Mapping
+from collections import Mapping, OrderedDict
 import numpy as np
 import logging
 import abc
@@ -108,12 +108,14 @@ def descendants(self):
         """
         :return: list of relation objects for all children and references,
             recursively, in order of dependence.
+            Does not include self.
             This is helpful for cascading delete or drop operations.
         """
         relations = (FreeRelation(self.connection, table)
                      for table in self.connection.erm.get_descendants(self.full_table_name))
         return [relation for relation in relations if relation.is_declared]
 
+
     def _repr_helper(self):
         return "%s.%s()" % (self.__module__, self.__class__.__name__)
 
@@ -206,32 +208,33 @@ def delete(self):
         """
         relations = self.descendants
         restrict_by_me = set()
-        rel_by_name = {r.full_table_name:r for r in relations}
         for r in relations:
-            for ref in r.references:
-                restrict_by_me.add(ref)
+            restrict_by_me.update(r.references)
+        relations = OrderedDict((r.full_table_name, r) for r in relations)
 
         if self.restrictions:
             restrict_by_me.add(self.full_table_name)
-            rel_by_name[self.full_table_name] &= self.restrictions
+            relations[self.full_table_name] &= self.restrictions
 
-        for r in relations:
+        for name in relations:
+            r = relations[name]
             for dep in (r.children + r.references):
-                rel_by_name[dep] &= r.project() if r.full_table_name in restrict_by_me else r.restrictions
-
-        if config['safemode']:
-            do_delete = False  # indicate if there is anything to delete
-            print('The contents of the following tables are about to be deleted:')
-            for relation in relations:
-                count = len(relation)
-                if count:
-                    do_delete = True
+                relations[dep] &= r.project() if name in restrict_by_me else r.restrictions
+
+        do_delete = False  # indicate if there is anything to delete
+        print('The contents of the following tables are about to be deleted:')
+        for relation in relations.values():
+            count = len(relation)
+            if count:
+                do_delete = True
+                if config['safemode']:
                     print(relation.full_table_name, '(%d tuples)' % count)
-            if not do_delete or user_choice("Proceed?", default='no') != 'yes':
-                return
-        with self.connection.transaction:
-            while relations:
-                relations.pop().delete_quick()
+            else:
+                relations.pop(relation.full_table_name)
+        if do_delete and (not config['safemode'] or user_choice("Proceed?", default='no') == 'yes'):
+            with self.connection.transaction:
+                for r in reversed(list(relations.values())):
+                    r.delete_quick()
 
     def drop_quick(self):
         """
diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py
index 1ba820222..3ee608eba 100644
--- a/datajoint/relational_operand.py
+++ b/datajoint/relational_operand.py
@@ -20,11 +20,12 @@ class RelationalOperand(metaclass=abc.ABCMeta):
     RelationalOperand operators are: restrict, pro, and join.
     """
-    _restrictions = None
+    _restrictions = []
 
     @property
     def restrictions(self):
-        return [] if self._restrictions is None else self._restrictions
+        assert self._restrictions is not None
+        return self._restrictions
 
     @property
     def primary_key(self):
@@ -113,14 +114,17 @@ def __and__(self, restriction):
         """
         relational restriction or semijoin
         """
-        if not restriction:
-            return self
+        # make a copy
         ret = copy(self)
-        ret._restrictions = list(ret.restrictions)  # copy restriction list
-        if isinstance(restriction, list) or isinstance(restriction, tuple):
-            ret._restrictions.extend(restriction)
-        else:
-            ret._restrictions.append(restriction)
+        ret._restrictions = list(ret.restrictions)
+        # apply restrictions, if any
+        if isinstance(restriction, RelationalOperand) or restriction:
+            restrictions = restriction \
+                if isinstance(restriction, list) or isinstance(restriction, tuple) \
+                else [restriction]
+            for restriction in restrictions:
+                if restriction not in ret._restrictions:
+                    ret._restrictions.append(restriction)
         return ret
 
     def __sub__(self, restriction):
@@ -129,8 +133,9 @@ def __sub__(self, restriction):
         """
         return self & Not(restriction)
 
+    @abc.abstractmethod
     def _repr_helper(self):
-        return "None"
+        pass
 
     def __repr__(self):
         ret = self._repr_helper()
@@ -224,8 +229,8 @@ def make_condition(arg):
             negate = False
             if not isinstance(r, str):
                 raise DataJointError('Invalid restriction object')
-            conditions.append('%s(%s)' % ('not ' if negate else '', r))
-
+            if r:
+                conditions.append('%s(%s)' % ('not ' if negate else '', r))
         return ' WHERE ' + ' AND '.join(conditions)
 
@@ -249,7 +254,7 @@ def __init__(self, arg1, arg2, left=False):
         if arg1.connection != arg2.connection:
             raise DataJointError('Cannot join relations with different database connections')
         self._arg1 = Subquery(arg1) if arg1.heading.computed else arg1
-        self._arg2 = Subquery(arg1) if arg2.heading.computed else arg2
+        self._arg2 = Subquery(arg2) if arg2.heading.computed else arg2
         self._restrictions = self._arg1.restrictions + self._arg2.restrictions
         self._left = left
         self._heading = self._arg1.heading.join(self._arg2.heading, left=left)
@@ -283,7 +288,7 @@ def select_fields(self):
 
 class Projection(RelationalOperand):
 
-    def __init__(self, arg, *attributes, _aggregate=False, **renamed_attributes):
+    def __init__(self, arg, *attributes, **renamed_attributes):
         """
         See RelationalOperand.project()
         """
@@ -299,7 +304,6 @@ def __init__(self, arg, *attributes, _aggregate=False, **renamed_attributes):
                 self._renamed_attributes.update({d['alias']: d['sql_expression']})
             else:
                 self._attributes.append(attribute)
-        self._aggregate = _aggregate
 
         if arg.heading.computed:
             self._arg = Subquery(arg)
@@ -307,6 +311,9 @@ def __init__(self, arg, *attributes, _aggregate=False, **renamed_attributes):
             self._arg = arg
             self._restrictions = arg.restrictions
 
+    def _repr_helper(self):
+        return "(%r).project(%r)" % (self._arg, self._attributes)
+
     @property
     def connection(self):
         return self._arg.connection
@@ -323,8 +330,9 @@ def __and__(self, restriction):
         """
         When projection has renamed attributes, it must be enclosed in a subquery before restriction
         """
-        if restriction:
-            return Subquery(self) & restriction if self.heading.computed else super().__and__(restriction)
+        has_restriction = isinstance(restriction, RelationalOperand) or restriction
+        do_subquery = has_restriction and self.heading.computed
+        return Subquery(self) & restriction if do_subquery else super().__and__(restriction)
 
 
 class Aggregation(Projection):
@@ -332,10 +340,6 @@ class Aggregation(Projection):
     def _grouped(self):
         return True
 
-    def _repr_helper(self):
-        # TODO: create better repr
-        return "project(%r, %r)" % (self._arg, self._attributes)
-
 
 class Subquery(RelationalOperand):
     """
diff --git a/tests/schema.py b/tests/schema.py
index 57ee76d24..14061e6ff 100644
--- a/tests/schema.py
+++ b/tests/schema.py
@@ -39,6 +39,7 @@ class Subject(dj.Manual):
 
     def _prepare(self):
         self.insert(self.contents, ignore_errors=True)
 
+
 @schema
 class Language(dj.Lookup):
@@ -57,7 +58,8 @@ class Language(dj.Lookup):
         ('Dimitri', 'Ukrainian'),
         ('Fabian', 'German'),
         ('Edgar', 'Japanese'),
-    ]
+        ]
+
 
 @schema
 class Experiment(dj.Imported):
@@ -129,7 +131,7 @@ def _make_tuples(self, key):
                    sampling_frequency=6000,
                    duration=np.minimum(2, random.expovariate(1)))
         self.insert1(row)
-        number_samples = round(row['duration'] * row['sampling_frequency']);
+        number_samples = round(row['duration'] * row['sampling_frequency'])
         EphysChannel().fill(key, number_samples=number_samples)
diff --git a/tests/test_fetch.py b/tests/test_fetch.py
index 8706c47a0..4accfc679 100644
--- a/tests/test_fetch.py
+++ b/tests/test_fetch.py
@@ -117,13 +117,15 @@ def test_fetch1(self):
         dat = (self.lang & key).fetch1()
 
         for k, (ke, c) in zip(true, dat.items()):
-            assert_true(k == c == (self.lang & key).fetch1[ke], 'Values are not the same')
+            assert_true(k == c == (self.lang & key).fetch1[ke],
+                        'Values are not the same')
 
     def test_copy(self):
         """Test whether modifications copy the object"""
         f = self.lang.fetch
         f2 = f.order_by('name')
-        assert_true(f.behavior['order_by'] is None and len(f2.behavior['order_by']) == 1, 'Object was not copied')
+        assert_true(f.behavior['order_by'] is None and len(f2.behavior['order_by']) == 1,
+                    'Object was not copied')
 
     def test_repr(self):
         """Test string representation of fetch, returning table preview"""
diff --git a/tests/test_relation.py b/tests/test_relation.py
index 5077ac3e2..d025775d0 100644
--- a/tests/test_relation.py
+++ b/tests/test_relation.py
@@ -1,5 +1,3 @@
-import random
-import string
 from numpy.testing import assert_array_equal
 import numpy as np
 from nose.tools import assert_raises, assert_equal, \
@@ -7,7 +5,7 @@
     assert_tuple_equal, assert_dict_equal, raises
 
 from . import schema
-import datajoint as dj
+
 
 class TestRelation:
     """
diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py
index 0a8fabc22..5b68bd529 100644
--- a/tests/test_relational_operand.py
+++ b/tests/test_relational_operand.py
@@ -1,55 +1,219 @@
-from operator import itemgetter
-from numpy.testing import assert_array_equal
+import datajoint as dj
+from . import PREFIX, CONN_INFO
+import random
 import numpy as np
+from nose.tools import assert_raises, assert_equal, \
+    assert_false, assert_true, assert_list_equal, \
+    assert_tuple_equal, assert_dict_equal, raises
 
-from . import schema
-import datajoint as dj
-# """
-# Collection of test cases to test relational methods
-# """
-#
-# __author__ = 'eywalker'
-#
-#
-# def setup():
-#     """
-#     Setup
-#     :return:
-#     """
-#
-# class TestRelationalAlgebra(object):
-#
-#     def setup(self):
-#         pass
-#
-#     def test_mul(self):
-#         pass
-#
-#     def test_project(self):
-#         pass
-#
-#     def test_iand(self):
-#         pass
-#
-#     def test_isub(self):
-#         pass
-#
-#     def test_sub(self):
-#         pass
-#
-#     def test_len(self):
-#         pass
-#
-#     def test_fetch(self):
-#         pass
-#
-#     def test_repr(self):
-#         pass
-#
-#     def test_iter(self):
-#         pass
-#
-#     def test_not(self):
-#         pass
+schema = dj.schema(PREFIX + '_relational', locals(), connection=dj.conn(**CONN_INFO))
+
+
+@schema
+class A(dj.Lookup):
+    definition = """
+    id_a       :int
+    ---
+    cond_in_a  :tinyint
+    """
+    contents = [(i, i % 4 > i % 3) for i in range(10)]
+
+
+@schema
+class B(dj.Computed):
+    definition = """
+    -> A
+    id_b       :int
+    ---
+    mu         :float     # mean value
+    sigma      :float     # standard deviation
+    n          :smallint  # number samples
+    """
+
+    def _make_tuples(self, key):
+        random.seed(str(key))
+        sub = C()
+        for i in range(4):
+            key['id_b'] = i
+            mu = random.normalvariate(0, 10)
+            sigma = random.lognormvariate(0, 4)
+            n = random.randint(0, 10)
+            self.insert1(dict(key, mu=mu, sigma=sigma, n=n))
+            for j in range(n):
+                sub.insert1(dict(key, id_c=j, value=random.normalvariate(mu, sigma)))
+
+
+@schema
+class C(dj.Subordinate, dj.Computed):
+    definition = """
+    -> B
+    id_c    :int
+    ---
+    value   :float  # normally distributed variables according to parameters in B
+    """
+
+
+@schema
+class L(dj.Lookup):
+    definition = """
+    id_l: int
+    ---
+    cond_in_l  :tinyint
+    """
+    contents = ((i, i % 3 >= i % 5) for i in range(30))
+
+
+@schema
+class D(dj.Computed):
+    definition = """
+    -> A
+    id_d     :int
+    ---
+    -> L
+    """
+
+    def _make_tuples(self, key):
+        # connect to random L
+        random.seed(str(key))
+        lookup = list(L().fetch.keys())
+        for i in range(4):
+            self.insert1(dict(key, id_d=i, **random.choice(lookup)))
+
+
+@schema
+class E(dj.Computed):
+    definition = """
+    -> B
+    -> D
+    ---
+    -> L
+    """
+
+    def _make_tuples(self, key):
+        random.seed(str(key))
+        self.insert1(dict(key, **random.choice(list(L().fetch.keys()))))
+        sub = F()
+        references = list((C() & key).fetch.keys())
+        random.shuffle(references)
+        for i, ref in enumerate(references):
+            if random.getrandbits(1):
+                sub.insert1(dict(key, id_f=i, **ref))
+
+
+@schema
+class F(dj.Subordinate, dj.Computed):
+    definition = """
+    -> E
+    id_f   :int
+    ---
+    -> C
+    """
+
+
+def setup():
+    """
+    module-level test setup
+    """
+    B().populate()
+    D().populate()
+    E().populate()
+    pass
+
+
+class TestRelational:
+
+    @staticmethod
+    def test_populate():
+        assert_false(B().progress(display=False)[0], 'B incompletely populated')
+        assert_false(D().progress(display=False)[0], 'D incompletely populated')
+        assert_false(E().progress(display=False)[0], 'E incompletely populated')
+
+        assert_true(len(B()) == 40, 'B populated incorrectly')
+        assert_true(len(C()) > 0, 'C populated incorrectly')
+        assert_true(len(D()) == 40, 'D populated incorrectly')
+        assert_true(len(E()) == len(B())*len(D())/len(A()), 'E populated incorrectly')
+        assert_true(len(F()) > 0, 'F populated incorrectly')
+
+    @staticmethod
+    def test_join():
+        # Test cartesian product
+        x = A()
+        y = L()
+        rel = x*y
+        assert_equal(len(rel), len(x)*len(y),
+                     'incorrect join')
+        assert_equal(set(x.heading.names).union(y.heading.names), set(rel.heading.names),
+                     'incorrect join heading')
+        assert_equal(set(x.primary_key).union(y.primary_key), set(rel.primary_key),
+                     'incorrect join primary_key')
+
+        # Test cartesian product of restricted relations
+        x = A() & 'cond_in_a=1'
+        y = L() & 'cond_in_l=1'
+        rel = x*y
+        assert_equal(len(rel), len(x)*len(y),
+                     'incorrect join')
+        assert_equal(set(x.heading.names).union(y.heading.names), set(rel.heading.names),
+                     'incorrect join heading')
+        assert_equal(set(x.primary_key).union(y.primary_key), set(rel.primary_key),
+                     'incorrect join primary_key')
+
+        # Test join with common attributes
+        cond = A() & 'cond_in_a=1'
+        x = B() & cond
+        y = D()
+        rel = x*y
+        assert_true(len(rel) >= len(x) and len(rel) >= len(y), 'incorrect join')
+        assert_false(rel - cond, 'incorrect join, restriction, or antijoin')
+        assert_equal(set(x.heading.names).union(y.heading.names), set(rel.heading.names),
+                     'incorrect join heading')
+        assert_equal(set(x.primary_key).union(y.primary_key), set(rel.primary_key),
+                     'incorrect join primary_key')
+
+        # test renamed join
+        x = B().project(i='id_a')   # rename the common attribute to achieve full cartesian product
+        y = D()
+        rel = x*y
+        assert_equal(len(rel), len(x)*len(y),
+                     'incorrect join')
+        assert_equal(set(x.heading.names).union(y.heading.names), set(rel.heading.names),
+                     'incorrect join heading')
+        assert_equal(set(x.primary_key).union(y.primary_key), set(rel.primary_key),
+                     'incorrect join primary_key')
+
+        # test the % notation
+        x = B() % ['id_a->a']
+        y = D()
+        rel = x*y
+        assert_equal(len(rel), len(x)*len(y),
+                     'incorrect join')
+        assert_equal(set(x.heading.names).union(y.heading.names), set(rel.heading.names),
+                     'incorrect join heading')
+        assert_equal(set(x.primary_key).union(y.primary_key), set(rel.primary_key),
+                     'incorrect join primary_key')
+
+        # test pairing
+        # Approach 1
+        x = A().project(a1='id_a', c1='cond_in_a') & 'c1=0'
+        y = A().project(a2='id_a', c2='cond_in_a') & 'c2=1'
+        rel = x*y & 'c1=0' & 'c2=1'
+        assert_equal(len(x)+len(y), len(A()))
+        assert_equal(len(rel), len(x)*len(y), 'incorrect pairing')
+        # Approach 2
+        x = (A() & 'cond_in_a=0').project(a1='id_a')
+        y = (A() & 'cond_in_a=1').project(a2='id_a')
+        assert_equal(len(rel), len(x*y))
+
+    @staticmethod
+    def test_project():
+        x = A().project(a='id_a')  # rename
+        assert_equal(x.heading.names, ['a'], 'renaming does not work')
+        x = A().project(a='(id_a)')  # extend
+        assert_equal(set(x.heading.names), set(('id_a', 'a')), 'extend does not work')
+
+    @staticmethod
+    def test_aggregate():
+        x = B().aggregate(C(), 'n', computed='count(id_c)')
+        assert_equal(len(x), len(B()))