In [1]:
from gamma import *

# One rewrite test: the input graph, the rule applied to it, and the value
# the rewritten graph is compared against.
testcase = namedtuple('testcase', ['graph', 'rule', 'result'])

def run(testcase):
    """Apply a test case's rule to its graph and check the outcome.

    Draws the input graph, the rule and the rewritten graph (in that order),
    then asserts that the rewritten graph equals ``testcase.result``.

    Args:
        testcase: Object exposing ``graph``, ``rule`` and ``result``
            attributes. Note that inside this function the parameter shadows
            the module-level ``testcase`` namedtuple.

    Raises:
        AssertionError: If the rewritten graph differs from the expected one.
    """
    draw(testcase.graph)
    draw(testcase.rule)
    res = apply_rule(testcase.graph, testcase.rule)
    draw(res)
    # `assert` is a statement, not a function, so no call-style parentheses;
    # the message makes a failing test show both graphs.
    assert res == testcase.result, (
        'rewritten graph does not match expected result:\n'
        '  got:      %r\n'
        '  expected: %r' % (res, testcase.result)
    )

#### Test 1: remove nodes

In [2]:
def make_node(t, params=None):
    """Build a node description dict of the form {'type': t, 'params': params}.

    Args:
        t: Node type label (e.g. 'Identity').
        params: Optional parameter dict for the node. When omitted, a fresh
            empty dict is created on every call.

    Returns:
        dict with the keys 'type' and 'params'.
    """
    # A `params={}` default is created once and would be shared by every node
    # built without explicit params, so mutating one node's params would
    # silently change all the others.
    if params is None:
        params = {}
    return {'type': t, 'params': params}


identity = make_node("Identity", {'important': False})
important_identity = make_node('Identity', {'important': True})


@bind_vars
def remove_identities(name, _in):
    """Rewrite rule that drops a plain `identity` node.

    Returns a 3-tuple:
      * the pattern: a single node `name` holding the module-level
        `identity` value, with `_in` as its only input;
      * the replacement: an empty graph;
      * the pair `(name, _in)`.

    NOTE(review): presumably `apply_rule` uses the pair to rewire consumers
    of `name` onto `_in`, and `bind_vars` supplies the arguments when the
    rule is called as `remove_identities()` - confirm against gamma.
    """
    pattern = {name: (identity, [_in])}
    replacement = {}
    return pattern, replacement, (name, _in)

# Input graph for the remove-nodes test: twelve entries under the 'scope'
# prefix, mixing plain identity nodes with three typed nodes (Type_A, Type_B,
# Type_C) and one identity whose params carry important=True.
# NOTE(review): entries without an explicit input list presumably take the
# preceding entry as their input - confirm against gamma's `pipeline`.
g = pipeline([
    ('A', identity, ['input']),
    ('B', identity),
    ('C', make_node('Type_A')),
    ('D', identity),
    ('E', identity, ['C']),
    ('F', identity),
    ('G', identity),
    ('H', make_node('Type_B')),
    ('I', identity),
    ('J', important_identity),
    ('K', identity, ['H']),
    ('L', make_node('Type_C'), ['I'])
], prefix='scope')

# Expected graph after rewriting: none of the plain identities (A, B, D, E,
# F, G, I, K) appear, the important=True identity J is kept, and each
# remaining node takes 'input' or another remaining node as its input.
res = {
    'scope/C': ({'type': 'Type_A', 'params': {}}, ['input']),
    'scope/H': ({'type': 'Type_B', 'params': {}}, ['scope/C']),
    'scope/J': ({'type': 'Identity', 'params': {'important': True}}, ['scope/H']),
    'scope/L': ({'type': 'Type_C', 'params': {}}, ['scope/H'])
}
# The rule is built by calling `remove_identities()` with no arguments.
run(testcase(g, remove_identities(), res))

#### Test 2: match with dangling edges

In [3]:
# Node descriptions for the fusion test: an Add, a Scale, and the AddScale
# node used as their fused replacement.
add = dict(type='Add', params={})
scale = dict(type='Scale', params={})
add_scale = dict(type='AddScale', params={})

@bind_vars
def fuse_add_scale(add_name, scale_name, input_name):
    """Rewrite rule that replaces an Add followed by a Scale with one AddScale.

    Returns a 3-tuple:
      * the pattern: a two-entry pipeline, `add_name` (an `add` node fed by
        `input_name`) followed by `scale_name` (a `scale` node);
      * the replacement: a single `add_scale` node named
        `path(scale_name, 'fused')`, fed by `input_name`;
      * the pair `(scale_name, path(scale_name, 'fused'))`.

    NOTE(review): presumably `apply_rule` uses the pair to rewire consumers
    of the Scale node onto the fused node - confirm against gamma.
    """
    pattern = pipeline([
        (add_name, add, [input_name]),
        (scale_name, scale),
    ])
    replacement = {path(scale_name, 'fused'): (add_scale, [input_name])}
    redirect = (scale_name, path(scale_name, 'fused'))
    return pattern, replacement, redirect

# Input graph for the dangling-edges test: two Add -> Scale sequences
# (A, B and D, E) with a lone Add (C) between them. 'D' names 'A' explicitly
# as its input.
# NOTE(review): entries without an explicit input list presumably take the
# preceding entry as their input, which would make A feed both B and D -
# confirm against gamma's `pipeline`.
g = pipeline([
    ('A', add, ['input']),
    ('B', scale),
    ('C', add),
    ('D', add, ['A']),
    ('E', scale)
])

# Expected graph after rewriting: A, B and C are still present as Add, Scale
# and Add, D and E no longer appear, and a new 'E/fused' AddScale node takes
# 'A' as its input.
# NOTE(review): presumably the A -> B pair stays unfused because A is also
# used by D (the "dangling edge" of the test title) - confirm.
res = {
    'A': ({'type': 'Add', 'params': {}}, ['input']),
    'B': ({'type': 'Scale', 'params': {}}, ['A']),
    'C': ({'type': 'Add', 'params': {}}, ['B']),
    'E/fused': ({'type': 'AddScale', 'params': {}}, ['A'])
}

# The rule is built by calling `fuse_add_scale()` with no arguments.
run(testcase(g, fuse_add_scale(), res))