Skip to content

Commit

Permalink
Merge f3b1f75 into 692621c
Browse files Browse the repository at this point in the history
  • Loading branch information
rdaly525 committed Aug 5, 2019
2 parents 692621c + f3b1f75 commit 2124e3b
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 5 deletions.
26 changes: 24 additions & 2 deletions hwtypes/adt_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@
from .bit_vector_abc import AbstractBitVectorMeta, AbstractBitVector

from .util import _issubclass
from hwtypes.modifiers import unwrap_modifier, wrap_modifier, is_modified

def rebind_bitvector(
adt,
bv_type_0: AbstractBitVectorMeta,
bv_type_1: AbstractBitVectorMeta):
bv_type_1: AbstractBitVectorMeta,
keep_modifiers=False):
if keep_modifiers and is_modified(adt):
unmod, mods = unwrap_modifier(adt)
return wrap_modifier(rebind_bitvector(unmod,bv_type_0,bv_type_1,True),mods)

if _issubclass(adt, bv_type_0):
if adt.is_sized:
return bv_type_1[adt.size]
Expand All @@ -15,7 +21,23 @@ def rebind_bitvector(
elif isinstance(adt, BoundMeta):
new_adt = adt
for field in adt.fields:
new_field = rebind_bitvector(field, bv_type_0, bv_type_1)
new_field = rebind_bitvector(field, bv_type_0, bv_type_1,keep_modifiers)
new_adt = new_adt.rebind(field, new_field)
return new_adt
else:
return adt

def rebind_keep_modifiers(adt, A, B):
if is_modified(adt):
unmod, mods = unwrap_modifier(adt)
return wrap_modifier(rebind_keep_modifiers(unmod,A,B),mods)

if _issubclass(adt,A):
return B
elif isinstance(adt, BoundMeta):
new_adt = adt
for field in adt.fields:
new_field = rebind_keep_modifiers(field, A, B)
new_adt = new_adt.rebind(field, new_field)
return new_adt
else:
Expand Down
20 changes: 20 additions & 0 deletions hwtypes/modifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ def get_unmodified(T):
else:
raise TypeError(f'{T} has no modifiers')

def get_all_modifiers(T):
if is_modified(T):
yield from get_all_modifiers(get_unmodified(T))
yield get_modifier(T)

_mod_cache = weakref.WeakValueDictionary()
# This is a factory for type modifiers.
def make_modifier(name, cache=False):
Expand All @@ -108,3 +113,18 @@ def make_modifier(name, cache=False):
return _mod_cache.setdefault(name, ModType)

return ModType

def unwrap_modifier(T):
if not is_modified(T):
return T, []
mod = get_modifier(T)
unmod = get_unmodified(T)
unmod, mods = unwrap_modifier(unmod)
mods.append(mod)
return unmod, mods

def wrap_modifier(T, mods):
wrapped = T
for mod in mods:
wrapped = mod(wrapped)
return wrapped
12 changes: 11 additions & 1 deletion tests/test_modifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from hwtypes import Bit, AbstractBit
import hwtypes.modifiers as modifiers
from hwtypes.modifiers import make_modifier, is_modified, is_modifier
from hwtypes.modifiers import make_modifier, is_modified, is_modifier, unwrap_modifier, wrap_modifier
from hwtypes.modifiers import get_modifier, get_unmodified
from hwtypes.adt import Tuple, Product, Sum

Expand Down Expand Up @@ -72,3 +72,13 @@ def test_cache():

assert G1 is G2
assert G1 is not G3

def test_nested():
A = make_modifier("A")
B = make_modifier("B")
C = make_modifier("C")
ABCBit = C(B(A(Bit)))
base, mods = unwrap_modifier(ABCBit)
assert base is Bit
assert mods == [A,B,C]
assert wrap_modifier(Bit,mods) == ABCBit
17 changes: 15 additions & 2 deletions tests/test_rebind.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import pytest

from hwtypes.adt import Product, Sum, Enum, Tuple
from hwtypes.adt_util import rebind_bitvector
from hwtypes.adt_util import rebind_bitvector, rebind_keep_modifiers
from hwtypes.bit_vector import AbstractBitVector, BitVector, AbstractBit, Bit
from hwtypes.smt_bit_vector import SMTBit
from hwtypes.smt_bit_vector import SMTBit, SMTBitVector
from hwtypes.util import _issubclass
from hwtypes.modifiers import make_modifier

class A: pass
class B: pass
Expand Down Expand Up @@ -147,3 +148,15 @@ class A(Product):

A_smt = A.rebind(AbstractBit, SMTBit, True)
assert A_smt.a is SMTBit

def test_rebind_mod():
M = make_modifier("M")
class A(Product):
a=M(Bit)
b=M(BitVector[4])

A_smt = rebind_bitvector(A,AbstractBitVector, SMTBitVector, True)
A_smt = rebind_keep_modifiers(A_smt, AbstractBit, SMTBit)
assert A_smt.b == M(SMTBitVector[4])
assert A_smt.a == M(SMTBit)
test_rebind_mod()

0 comments on commit 2124e3b

Please sign in to comment.