Skip to content

Commit

Permalink
Consolidate op transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
cjdrake committed Feb 10, 2015
1 parent 3679686 commit 538d33b
Showing 1 changed file with 24 additions and 39 deletions.
63 changes: 24 additions & 39 deletions pyeda/boolalg/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,30 +1336,12 @@ def support(self):
return frozenset.union(*[x.support for x in self.xs])

def _urestrict1(self, upoint):
modified = False
ys = list()
for x in self.xs:
y = x._urestrict1(upoint)
if y is not x:
modified = True
ys.append(y)
if modified:
return self.__class__(*ys)
else:
return self
f = lambda x: x._urestrict1(upoint)
return _op_transform(self, f)

def compose(self, mapping):
modified = False
ys = list()
for x in self.xs:
y = x.compose(mapping)
if y is not x:
modified = True
ys.append(y)
if modified:
return self.__class__(*ys).simplify()
else:
return self.simplify()
f = lambda x: x.compose(mapping)
return _op_transform(self, f).simplify()

# From Expression
def _traverse(self, visited):
Expand Down Expand Up @@ -1519,25 +1501,14 @@ def term_index(self):

# FactoredExpression
def _flatten(self, op):
modified = False
ys = list()
for x in self.xs:
y = x._flatten(op)
if y is not x:
modified = True
ys.append(y)
if modified:
f = self.__class__(*ys).simplify()._absorb()
f = lambda x: x._flatten(op)
ex = _op_transform(self, f).simplify()._absorb()
if ex.depth < 2 or isinstance(ex, op.get_dual()):
return ex
else:
f = self

if f.depth < 2 or isinstance(f, op.get_dual()):
return f
else:
args = [x._lits for x in f.xs]
args = {x._lits for x in ex.xs}
prod = {frozenset(t) for t in itertools.product(*args)}
g = op.get_dual()(*[op(*t) for t in prod]).simplify()
return g
return op.get_dual()(*[op(*t) for t in prod]).simplify()

# FlattenedExpression
@cached_property
Expand Down Expand Up @@ -2317,6 +2288,20 @@ def __str__(self):
return "p cnf {0.nvars} {0.nclauses}\n{1}".format(self, formula)


def _op_transform(op, f):
modified = False
ys = list()
for x in op.xs:
y = f(x)
if y is not x:
modified = True
ys.append(y)
if modified:
return op.__class__(*ys)
else:
return op


def _iter_zeros(expr):
"""Iterate through all upoints that map to element zero."""
if expr is EXPRZERO:
Expand Down

0 comments on commit 538d33b

Please sign in to comment.