Skip to content

Commit

Permalink
Adding descendant_exclude constraints to Node generation
Browse files Browse the repository at this point in the history
  • Loading branch information
mbodenhamer committed Feb 21, 2017
1 parent f2015d7 commit 824dba2
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 16 deletions.
10 changes: 9 additions & 1 deletion syn/tree/b/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ def __nonzero__(self):
def __bool__(self):
return True

@classmethod
def _generate(cls, **kwargs):
if cls._opts.descendant_exclude:
excludes = list(kwargs.get('exclude_types', []))
excludes.extend(list(cls._opts.descendant_exclude))
kwargs['exclude_types'] = excludes
return super(Node, cls)._generate(**kwargs)

@init_hook
def _initial_node_count(self):
self._node_count = 1
Expand Down Expand Up @@ -268,7 +276,7 @@ def validate(self):
if self._parent is not None:
raise TreeError("node must be root, but has parent")

dex = self._opts.descendant_exclude
dex = tuple(self._opts.descendant_exclude)
if dex:
for d in self.descendants():
if isinstance(d, dex):
Expand Down
54 changes: 39 additions & 15 deletions syn/tree/b/tests/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@
from nose.tools import assert_raises
from syn.base.b import Attr
from syn.tree.b import Node, TreeError
from syn.type.a import Schema
from syn.schema.b.sequence import Sequence
from syn.base.b.tests.test_base import check_idempotence
from syn.base_utils import assert_equivalent, assert_inequivalent, consume
from syn.base_utils import assert_equivalent, assert_inequivalent, consume, \
ngzwarn
from syn.types.a import generate
from syn.five import xrange

from syn.globals import TEST_SAMPLES as SAMPLES
SAMPLES //= 10
SAMPLES = max(SAMPLES, 1)
ngzwarn(SAMPLES, 'SAMPLES')

#-------------------------------------------------------------------------------
# Tree Node Test 1
Expand Down Expand Up @@ -291,34 +297,51 @@ class CT2(Node):
pass

class CTTest(Node):
_opts = dict(init_validate = True)
_opts = dict(init_validate = True,
min_len = 1)
types = [CT1]

def test_child_types():
CTTest()
CTTest(CT1())
assert_raises(TypeError, CTTest, CT1(), CT2())

for k in xrange(SAMPLES):
val = generate(CTTest)
val.validate()

#-------------------------------------------------------------------------------
# Descendant Exclude

class DE1(Node):
pass

class DE2(Node):
class DE3(Node):
pass

class DE4(Node):
pass

class DE2(Node):
_opts = dict(min_len = 1)
types = [DE3, DE4]

class DETest(Node):
_opts = dict(descendant_exclude = (DE2,),
types = [DE1, DE2]
_opts = dict(descendant_exclude = [DE4],
min_len = 1,
init_validate = True)

def test_descendant_exclude():
DETest()
DETest(DE1())

n = DE1(DE2())
DETest(DE1(DE2(DE3())))

n = DE2(DE4())
assert_raises(TypeError, DETest, n)

for k in xrange(SAMPLES):
val = generate(DETest)
val.validate()

#-------------------------------------------------------------------------------
# Schema Attrs

Expand All @@ -340,12 +363,13 @@ def test_schema_attrs():
SchemaTest(SA1(), SA2(), a=1)
assert_raises(TypeError, SchemaTest, SA1(), SA3(), a=2)

val = generate(SchemaTest)
assert type(val) is SchemaTest
assert isinstance(val.a, int)
assert type(val[0]) is SA1
assert type(val[1]) is SA2
assert len(val) == 2
for k in xrange(SAMPLES):
val = generate(SchemaTest)
assert type(val) is SchemaTest
assert isinstance(val.a, int)
assert type(val[0]) is SA1
assert type(val[1]) is SA2
assert len(val) == 2

def bad():
class SchemaTest2(Node):
Expand Down

0 comments on commit 824dba2

Please sign in to comment.