Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rebind #66

Merged
merged 4 commits into from
Jul 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 0 additions & 22 deletions hwtypes/adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,28 +76,6 @@ def value_dict(self):
d[k] = getattr(self, k)
return MappingProxyType(d)

@classmethod
def from_fields(cls,
class_name: str,
fields: tp.Mapping[str, type],
module: tp.Optional[str] = None,
qualname: tp.Optional[str] = None):
if cls.is_bound:
raise TypeError('Type must not be bound')

ns = {}

if module is None:
module = cls.__module__

if qualname is None:
qualname = class_name

ns['__module__'] = module
ns['__qualname__'] = qualname

return cls._from_fields(fields, class_name, (cls,), ns)

class Sum(metaclass=SumMeta):
def __init__(self, value):
if not isinstance(value, tuple(type(self).fields)):
Expand Down
107 changes: 101 additions & 6 deletions hwtypes/adt_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,22 @@ def __call__(cls, *args, **kwargs):
def __new__(mcs, name, bases, namespace, fields=None, **kwargs):
if '_fields_' in namespace:
raise ReservedNameError('class attribute _fields_ is reserved by the type machinery')
if '_unbound_base_' in namespace:
raise ReservedNameError('class attribute _unbound_base_ is reserved by the type machinery')

bound_types = fields
has_bound_base = False
unbound_bases = []
for base in bases:
if isinstance(base, BoundMeta) and base.is_bound:
if bound_types is None:
bound_types = base.fields
elif bound_types != base.fields:
raise TypeError("Can't inherit from multiple different bound_types")
if isinstance(base, BoundMeta):
if base.is_bound:
has_bound_base = True
if bound_types is None:
bound_types = base.fields
elif bound_types != base.fields:
raise TypeError("Can't inherit from multiple different bound_types")
else:
unbound_bases.append(base)

if bound_types is not None:
if '_fields_cb' in namespace:
Expand All @@ -82,7 +90,19 @@ def __new__(mcs, name, bases, namespace, fields=None, **kwargs):
bound_types = t._fields_cb(bound_types)

namespace['_fields_'] = bound_types
namespace['_unbound_base_'] = None
t = super().__new__(mcs, name, bases, namespace, **kwargs)

if bound_types is None:
# t is a unbound type
t._unbound_base_ = t
elif len(unbound_bases) == 1:
# t is constructed from an unbound type
t._unbound_base_ = unbound_bases[0]
elif not has_bound_base:
# this shouldn't be reachable
raise AssertionError("Unreachable code")

return t

def _fields_cb(cls, idx):
Expand Down Expand Up @@ -128,9 +148,27 @@ def fields_dict(cls):
def is_bound(cls) -> bool:
return cls.fields is not None

@property
def unbound_t(cls) -> 'BoundMeta':
t = cls._unbound_base_
if t is not None:
return t
else:
raise AttributeError(f'type {cls} has no unbound_t')

def __repr__(cls):
return f"{cls.__name__}"

def rebind(cls, A : type, B : type):
new_fields = []
for T in cls.fields:
if T == A:
new_fields.append(B)
elif isinstance(T, BoundMeta):
new_fields.append(T.rebind(A, B))
else:
new_fields.append(T)
return cls.unbound_t[new_fields]

class TupleMeta(BoundMeta):
def __getitem__(cls, idx):
Expand Down Expand Up @@ -255,8 +293,17 @@ def __init__(self, {type_sig}):
exec(__init__, gs, ls)
t.__init__ = ls['__init__']

product_base = None
for base in bases:
if isinstance(base, mcs):
if product_base is None:
product_base = base
else:
raise TypeError('Can only inherit from one product type')

if product_base is not None and not product_base.is_bound:
t._unbound_base_ = product_base

# Store the field indexs
return t


Expand All @@ -277,6 +324,38 @@ def __repr__(cls):
def field_dict(cls):
return MappingProxyType(cls._field_table_)

def from_fields(cls,
class_name: str,
fields: tp.Mapping[str, type],
module: tp.Optional[str] = None,
qualname: tp.Optional[str] = None):
if cls.is_bound:
raise TypeError('Type must not be bound')

ns = {}

if module is None:
module = cls.__module__

if qualname is None:
qualname = class_name

ns['__module__'] = module
ns['__qualname__'] = qualname

return cls._from_fields(fields, class_name, (cls,), ns)


def rebind(cls, A : type, B : type):
new_fields = OrderedDict()
for field, T in cls.field_dict.items():
if T == A:
new_fields[field] = B
elif isinstance(T, BoundMeta):
new_fields[field] = T.rebind(A, B)
else:
new_fields[field] = T
return cls.unbound_t.from_fields(cls.__name__, new_fields, cls.__module__, cls.__qualname__)

class SumMeta(BoundMeta):
def _fields_cb(cls, idx):
Expand Down Expand Up @@ -331,6 +410,17 @@ def __new__(mcs, cls_name, bases, namespace, **kwargs):

t._fields_ = tuple(name_table.values())

enum_base = None
for base in bases:
if isinstance(base, mcs):
if enum_base is None:
enum_base = base
else:
raise TypeError('Can only inherit from one enum type')

if enum_base is not None:
t._unbound_base_ = enum_base

return t

def __call__(cls, elem):
Expand All @@ -344,3 +434,8 @@ def field_dict(cls):

def enumerate(cls):
yield from cls.fields

def rebind(cls, A : type, B : type):
# Enums aren't bound to types
# could potentialy rebind values but that seems annoying
return cls
50 changes: 50 additions & 0 deletions tests/test_adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ class En2(Enum):
d = 1


class En3(Enum):
e = 3
f = 4


class Pr(Product):
x = En1
y = En2
Expand Down Expand Up @@ -46,6 +51,7 @@ def test_enum():
assert issubclass(En1, Enum)
assert isinstance(En1.a, Enum)
assert isinstance(En1.a, En1)
assert En1.is_bound

with pytest.raises(AttributeError):
En1.a.b
Expand Down Expand Up @@ -211,3 +217,47 @@ class _(T):
'''
with pytest.raises(ReservedNameError):
exec(cls_str, l_dict)

@pytest.mark.parametrize("t, base", [
(En1, Enum),
(Pr, Product),
(Su, Sum),
(Tu, Tuple),
])
def test_unbound_t(t, base):
assert t.unbound_t == base
class sub_t(t): pass
with pytest.raises(AttributeError):
sub_t.unbound_t

@pytest.mark.parametrize("T", [Tu, Su, Pr])
def test_rebind(T):
assert En1 in T.fields
assert En3 not in T.fields
T2 = T.rebind(En1, En3)
assert En1 not in T2.fields
assert En3 in T2.fields


class A: pass
class B: pass
class C: pass
class D: pass
class P1(Product):
A = A
B = B

S1 = Sum[C, P1]

class P2(Product):
S1 = S1
C = C

def test_rebind_recusrive():
P3 = P2.rebind(A, D)
assert P3.S1.field_dict['P1'].A == D
assert P3.S1.field_dict['P1'].B == B
assert C in P3.S1.fields
P4 = P3.rebind(C, D)
assert P4.C == D
assert D in P4.S1.fields