Skip to content

Commit

Permalink
Add reflected operators to bit and bv types
Browse files Browse the repository at this point in the history
  • Loading branch information
cdonovick committed Aug 2, 2023
1 parent aa051b3 commit 4948fe4
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 244 deletions.
196 changes: 74 additions & 122 deletions hwtypes/bit_vector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import typing as tp
from .bit_vector_abc import AbstractBitVector, AbstractBit, TypeFamily, InconsistentSizeError
from .bit_vector_util import build_ite
from .util import Method
from .compatibility import IntegerTypes, StringTypes

import functools
Expand Down Expand Up @@ -71,14 +72,27 @@ def __ne__(self, other):
def __and__(self, other):
return type(self)(self._value & other._value)

@bit_cast
def __rand__(self, other):
return type(self)(other._value & self._value)

@bit_cast
def __or__(self, other):
return type(self)(self._value | other._value)

@bit_cast
def __ror__(self, other):
return type(self)(other._value | self._value)

@bit_cast
def __xor__(self, other):
return type(self)(self._value ^ other._value)

@bit_cast
def __rxor__(self, other):
return type(self)(other._value ^ self._value)


def ite(self, t_branch, f_branch):
'''
typing works as follows:
Expand Down Expand Up @@ -132,6 +146,34 @@ def wrapped(self : 'BitVector', other : tp.Any) -> tp.Any:
return fn(self, other)
return wrapped



def dispatch_oper(method: tp.MethodDescriptorType):
def oper(self, other):
try:
return method(self, other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented

return Method(oper)


# A little inefficient because of double _coerce but whate;er
def dispatch_roper(method: Method):
def roper(self, other):
try:
other = _coerce(type(self), other)
except inconsistentsizeerror as e:
raise e from None
except TypeError:
return NotImplemented
return method(other, self)

return Method(roper)


class BitVector(AbstractBitVector):
@staticmethod
def get_family() -> TypeFamily:
Expand Down Expand Up @@ -328,8 +370,6 @@ def bvurem(self, other):
return self
return type(self)(self.as_uint() % other)

# bvumod

@bv_cast
def bvsdiv(self, other):
other = other.as_sint()
Expand All @@ -344,140 +384,46 @@ def bvsrem(self, other):
return self
return type(self)(self.as_sint() % other)

# bvsmod
def __invert__(self): return self.bvnot()

def __and__(self, other):
try:
return self.bvand(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented
__and__ = dispatch_oper(bvand)
__rand__ = dispatch_roper(__and__)

def __or__(self, other):
try:
return self.bvor(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented
__or__ = dispatch_oper(bvor)
__ror__ = dispatch_roper(__or__)

def __xor__(self, other):
try:
return self.bvxor(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented
__xor__ = dispatch_oper(bvxor)
__rxor__ = dispatch_roper(__xor__)

__lshift__ = dispatch_oper(bvshl)
__rlshift__ = dispatch_roper(__lshift__)

def __lshift__(self, other):
try:
return self.bvshl(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented

def __rshift__(self, other):
try:
return self.bvlshr(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented
__rshift__ = dispatch_oper(bvlshr)
__rrshift__ = dispatch_oper(__rshift__)

def __neg__(self): return self.bvneg()

def __add__(self, other):
try:
return self.bvadd(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented

def __sub__(self, other):
try:
return self.bvsub(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented

def __mul__(self, other):
try:
return self.bvmul(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented
__add__ = dispatch_oper(bvadd)
__radd__ = dispatch_roper(__add__)

def __floordiv__(self, other):
try:
return self.bvudiv(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented
__sub__ = dispatch_oper(bvsub)
__rsub__ = dispatch_roper(__sub__)

def __mod__(self, other):
try:
return self.bvurem(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented
__mul__ = dispatch_oper(bvmul)
__rmul__ = dispatch_roper(__mul__)

__floordiv__ = dispatch_oper(bvudiv)
__rfloordiv__ = dispatch_roper(__floordiv__)

def __eq__(self, other):
try:
return self.bveq(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented
__mod__ = dispatch_oper(bvurem)
__rmod__ = dispatch_roper(__mod__)

def __ne__(self, other):
try:
return self.bvne(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented

def __ge__(self, other):
try:
return self.bvuge(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented

def __gt__(self, other):
try:
return self.bvugt(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented

def __le__(self, other):
try:
return self.bvule(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented

def __lt__(self, other):
try:
return self.bvult(other)
except InconsistentSizeError as e:
raise e from None
except TypeError as e:
return NotImplemented
__eq__ = dispatch_oper(bveq)
__ne__ = dispatch_oper(AbstractBitVector.bvne)
__ge__ = dispatch_oper(AbstractBitVector.bvuge)
__gt__ = dispatch_oper(AbstractBitVector.bvugt)
__le__ = dispatch_oper(AbstractBitVector.bvule)
__lt__ = dispatch_oper(bvult)

def as_uint(self):
return self._value
Expand Down Expand Up @@ -565,6 +511,8 @@ def __rshift__(self, other):
except TypeError:
return NotImplemented

__rrshift__ = dispatch_roper(__rshift__)

def __floordiv__(self, other):
try:
return self.bvsdiv(other)
Expand All @@ -573,6 +521,8 @@ def __floordiv__(self, other):
except TypeError:
return NotImplemented

__rfloordiv__ = dispatch_roper(__floordiv__)

def __mod__(self, other):
try:
return self.bvsrem(other)
Expand All @@ -581,6 +531,8 @@ def __mod__(self, other):
except TypeError:
return NotImplemented

__rmod__ = dispatch_roper(__mod__)

def __ge__(self, other):
try:
return self.bvsge(other)
Expand Down
3 changes: 2 additions & 1 deletion hwtypes/bit_vector_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import inspect
import types

from .util import Method
from .bit_vector_abc import InconsistentSizeError
from .bit_vector_abc import BitVectorMeta, AbstractBitVector, AbstractBit

Expand Down Expand Up @@ -172,7 +173,7 @@ def VCall(*args, **kwargs):
if v0 is NotImplemented or v0 is NotImplemented:
return NotImplemented
return select.ite(v0, v1)
return VCall
return Method(VCall)


def get_branch_type(branch):
Expand Down

0 comments on commit 4948fe4

Please sign in to comment.