Skip to content

Commit

Permalink
Merge d7b3720 into ad78d8b
Browse files Browse the repository at this point in the history
  • Loading branch information
rdaly525 committed Sep 18, 2019
2 parents ad78d8b + d7b3720 commit 5f26a82
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 1 deletion.
19 changes: 19 additions & 0 deletions hwtypes/smt_bit_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def wrapped(self, other):
return fn(self, other)
return wrapped

SMTBit=None
class SMTBit(AbstractBit):
@staticmethod
def get_family() -> TypeFamily:
Expand Down Expand Up @@ -163,6 +164,14 @@ def ite(self, t_branch, f_branch):

return T(smt.Ite(self.value, t_branch.value, f_branch.value))

def substitute(self, *subs : tp.List[tp.Tuple['SMTBit', 'SMTBit']]):
return SMTBit(
self.value.substitute(
{from_.value:to.value for from_, to in subs}
)
)


def _coerce(T : tp.Type['SMTBitVector'], val : tp.Any) -> 'SMTBitVector':
if not isinstance(val, SMTBitVector):
return T(val)
Expand Down Expand Up @@ -630,6 +639,14 @@ def zext(self, ext):
raise ValueError()
return type(self).unsized_t[self.size + ext](smt.BVZExt(self.value, ext))

def substitute(self, *subs : tp.List[tp.Tuple["SBV", "SBV"]]):
return SMTBitVector[self.size](
self.value.substitute(
{from_.value:to.value for from_, to in subs}
)
)


# def bits(self):
# return [(self >> i) & 1 for i in range(self.size)]
#
Expand Down Expand Up @@ -711,4 +728,6 @@ def __le__(self, other):
return NotImplemented




_Family_ = TypeFamily(SMTBit, SMTBitVector, SMTUIntVector, SMTSIntVector)
11 changes: 11 additions & 0 deletions tests/test_smt_bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,14 @@ def test_bin_op(op, Bit):
@pytest.mark.parametrize("Bit", [SMTBit, z3Bit])
def test_unary_op(op, Bit):
assert isinstance(op(Bit()), Bit)

def test_substitute():
a0 = SMTBit()
a1 = SMTBit()
b0 = SMTBit()
b1 = SMTBit()
expr0 = a0|b0
expr1 = expr0.substitute((a0, a1), (b0, b1))
assert expr1.value is (a1|b1).value


12 changes: 11 additions & 1 deletion tests/test_smt_bv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import operator
from hwtypes import SMTBitVector, z3BitVector
from hwtypes import SMTBitVector, z3BitVector, SMTBit


WIDTHS = [1,2,4,8]
Expand Down Expand Up @@ -42,3 +42,13 @@ def test_unary_op(width, op, BV):
@pytest.mark.parametrize("BV", [SMTBitVector, z3BitVector])
def test_bit_op(width, op, BV):
assert isinstance(op(BV[width](), BV[width]()), BV.get_family().Bit)

def test_substitute():
a0 = SMTBitVector[3]()
a1 = SMTBitVector[3]()
b0 = SMTBitVector[3]()
b1 = SMTBitVector[3]()
expr0 = a0 + b0*a0
expr1 = expr0.substitute((a0, a1), (b0, b1))
assert expr1.value is (a1 + b1*a1).value

0 comments on commit 5f26a82

Please sign in to comment.