Skip to content

Commit

Permalink
Merge 1e7d6e6 into db1076c
Browse files Browse the repository at this point in the history
  • Loading branch information
cdonovick committed Jul 15, 2019
2 parents db1076c + 1e7d6e6 commit 51f1b7b
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 28 deletions.
22 changes: 0 additions & 22 deletions hwtypes/adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,28 +76,6 @@ def value_dict(self):
d[k] = getattr(self, k)
return MappingProxyType(d)

@classmethod
def from_fields(cls,
class_name: str,
fields: tp.Mapping[str, type],
module: tp.Optional[str] = None,
qualname: tp.Optional[str] = None):
if cls.is_bound:
raise TypeError('Type must not be bound')

ns = {}

if module is None:
module = cls.__module__

if qualname is None:
qualname = class_name

ns['__module__'] = module
ns['__qualname__'] = qualname

return cls._from_fields(fields, class_name, (cls,), ns)

class Sum(metaclass=SumMeta):
def __init__(self, value):
if not isinstance(value, tuple(type(self).fields)):
Expand Down
107 changes: 101 additions & 6 deletions hwtypes/adt_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,22 @@ def __call__(cls, *args, **kwargs):
def __new__(mcs, name, bases, namespace, fields=None, **kwargs):
if '_fields_' in namespace:
raise ReservedNameError('class attribute _fields_ is reserved by the type machinery')
if '_unbound_base_' in namespace:
raise ReservedNameError('class attribute _unbound_base_ is reserved by the type machinery')

bound_types = fields
has_bound_base = False
unbound_bases = []
for base in bases:
if isinstance(base, BoundMeta) and base.is_bound:
if bound_types is None:
bound_types = base.fields
elif bound_types != base.fields:
raise TypeError("Can't inherit from multiple different bound_types")
if isinstance(base, BoundMeta):
if base.is_bound:
has_bound_base = True
if bound_types is None:
bound_types = base.fields
elif bound_types != base.fields:
raise TypeError("Can't inherit from multiple different bound_types")
else:
unbound_bases.append(base)

if bound_types is not None:
if '_fields_cb' in namespace:
Expand All @@ -82,7 +90,19 @@ def __new__(mcs, name, bases, namespace, fields=None, **kwargs):
bound_types = t._fields_cb(bound_types)

namespace['_fields_'] = bound_types
namespace['_unbound_base_'] = None
t = super().__new__(mcs, name, bases, namespace, **kwargs)

if bound_types is None:
# t is a unbound type
t._unbound_base_ = t
elif len(unbound_bases) == 1:
# t is constructed from an unbound type
t._unbound_base_ = unbound_bases[0]
elif not has_bound_base:
# this shouldn't be reachable
raise AssertionError("Unreachable code")

return t

def _fields_cb(cls, idx):
Expand Down Expand Up @@ -128,9 +148,27 @@ def fields_dict(cls):
def is_bound(cls) -> bool:
return cls.fields is not None

@property
def unbound_t(cls) -> 'BoundMeta':
t = cls._unbound_base_
if t is not None:
return t
else:
raise AttributeError(f'type {cls} has no unbound_t')

def __repr__(cls):
return f"{cls.__name__}"

def rebind(cls, A : type, B : type):
new_fields = []
for T in cls.fields:
if T == A:
new_fields.append(B)
elif isinstance(T, BoundMeta):
new_fields.append(T.rebind(A, B))
else:
new_fields.append(T)
return cls.unbound_t[new_fields]

class TupleMeta(BoundMeta):
def __getitem__(cls, idx):
Expand Down Expand Up @@ -255,8 +293,17 @@ def __init__(self, {type_sig}):
exec(__init__, gs, ls)
t.__init__ = ls['__init__']

product_base = None
for base in bases:
if isinstance(base, mcs):
if product_base is None:
product_base = base
else:
raise TypeError('Can only inherit from one product type')

if product_base is not None and not product_base.is_bound:
t._unbound_base_ = product_base

# Store the field indexs
return t


Expand All @@ -277,6 +324,38 @@ def __repr__(cls):
def field_dict(cls):
return MappingProxyType(cls._field_table_)

def from_fields(cls,
class_name: str,
fields: tp.Mapping[str, type],
module: tp.Optional[str] = None,
qualname: tp.Optional[str] = None):
if cls.is_bound:
raise TypeError('Type must not be bound')

ns = {}

if module is None:
module = cls.__module__

if qualname is None:
qualname = class_name

ns['__module__'] = module
ns['__qualname__'] = qualname

return cls._from_fields(fields, class_name, (cls,), ns)


def rebind(cls, A : type, B : type):
new_fields = OrderedDict()
for field, T in cls.field_dict.items():
if T == A:
new_fields[field] = B
elif isinstance(T, BoundMeta):
new_fields[field] = T.rebind(A, B)
else:
new_fields[field] = T
return cls.unbound_t.from_fields(cls.__name__, new_fields, cls.__module__, cls.__qualname__)

class SumMeta(BoundMeta):
def _fields_cb(cls, idx):
Expand Down Expand Up @@ -331,6 +410,17 @@ def __new__(mcs, cls_name, bases, namespace, **kwargs):

t._fields_ = tuple(name_table.values())

enum_base = None
for base in bases:
if isinstance(base, mcs):
if enum_base is None:
enum_base = base
else:
raise TypeError('Can only inherit from one enum type')

if enum_base is not None:
t._unbound_base_ = enum_base

return t

def __call__(cls, elem):
Expand All @@ -344,3 +434,8 @@ def field_dict(cls):

def enumerate(cls):
yield from cls.fields

def rebind(cls, A : type, B : type):
# Enums aren't bound to types
# could potentialy rebind values but that seems annoying
return cls
50 changes: 50 additions & 0 deletions tests/test_adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ class En2(Enum):
d = 1


class En3(Enum):
e = 3
f = 4


class Pr(Product):
x = En1
y = En2
Expand Down Expand Up @@ -46,6 +51,7 @@ def test_enum():
assert issubclass(En1, Enum)
assert isinstance(En1.a, Enum)
assert isinstance(En1.a, En1)
assert En1.is_bound

with pytest.raises(AttributeError):
En1.a.b
Expand Down Expand Up @@ -211,3 +217,47 @@ class _(T):
'''
with pytest.raises(ReservedNameError):
exec(cls_str, l_dict)

@pytest.mark.parametrize("t, base", [
(En1, Enum),
(Pr, Product),
(Su, Sum),
(Tu, Tuple),
])
def test_unbound_t(t, base):
assert t.unbound_t == base
class sub_t(t): pass
with pytest.raises(AttributeError):
sub_t.unbound_t

@pytest.mark.parametrize("T", [Tu, Su, Pr])
def test_rebind(T):
assert En1 in T.fields
assert En3 not in T.fields
T2 = T.rebind(En1, En3)
assert En1 not in T2.fields
assert En3 in T2.fields


class A: pass
class B: pass
class C: pass
class D: pass
class P1(Product):
A = A
B = B

S1 = Sum[C, P1]

class P2(Product):
S1 = S1
C = C

def test_rebind_recusrive():
P3 = P2.rebind(A, D)
assert P3.S1.field_dict['P1'].A == D
assert P3.S1.field_dict['P1'].B == B
assert C in P3.S1.fields
P4 = P3.rebind(C, D)
assert P4.C == D
assert D in P4.S1.fields

0 comments on commit 51f1b7b

Please sign in to comment.