diff --git a/hwtypes/adt.py b/hwtypes/adt.py index 1be6a4a..e65d70f 100644 --- a/hwtypes/adt.py +++ b/hwtypes/adt.py @@ -76,28 +76,6 @@ def value_dict(self): d[k] = getattr(self, k) return MappingProxyType(d) - @classmethod - def from_fields(cls, - class_name: str, - fields: tp.Mapping[str, type], - module: tp.Optional[str] = None, - qualname: tp.Optional[str] = None): - if cls.is_bound: - raise TypeError('Type must not be bound') - - ns = {} - - if module is None: - module = cls.__module__ - - if qualname is None: - qualname = class_name - - ns['__module__'] = module - ns['__qualname__'] = qualname - - return cls._from_fields(fields, class_name, (cls,), ns) - class Sum(metaclass=SumMeta): def __init__(self, value): if not isinstance(value, tuple(type(self).fields)): diff --git a/hwtypes/adt_meta.py b/hwtypes/adt_meta.py index 8559537..489409a 100644 --- a/hwtypes/adt_meta.py +++ b/hwtypes/adt_meta.py @@ -64,14 +64,22 @@ def __call__(cls, *args, **kwargs): def __new__(mcs, name, bases, namespace, fields=None, **kwargs): if '_fields_' in namespace: raise ReservedNameError('class attribute _fields_ is reserved by the type machinery') + if '_unbound_base_' in namespace: + raise ReservedNameError('class attribute _unbound_base_ is reserved by the type machinery') bound_types = fields + has_bound_base = False + unbound_bases = [] for base in bases: - if isinstance(base, BoundMeta) and base.is_bound: - if bound_types is None: - bound_types = base.fields - elif bound_types != base.fields: - raise TypeError("Can't inherit from multiple different bound_types") + if isinstance(base, BoundMeta): + if base.is_bound: + has_bound_base = True + if bound_types is None: + bound_types = base.fields + elif bound_types != base.fields: + raise TypeError("Can't inherit from multiple different bound_types") + else: + unbound_bases.append(base) if bound_types is not None: if '_fields_cb' in namespace: @@ -82,7 +90,19 @@ def __new__(mcs, name, bases, namespace, fields=None, **kwargs): bound_types = t._fields_cb(bound_types) namespace['_fields_'] = bound_types + namespace['_unbound_base_'] = None t = super().__new__(mcs, name, bases, namespace, **kwargs) + + if bound_types is None: + # t is a unbound type + t._unbound_base_ = t + elif len(unbound_bases) == 1: + # t is constructed from an unbound type + t._unbound_base_ = unbound_bases[0] + elif not has_bound_base: + # this shouldn't be reachable + raise AssertionError("Unreachable code") + return t def _fields_cb(cls, idx): @@ -128,9 +148,27 @@ def fields_dict(cls): def is_bound(cls) -> bool: return cls.fields is not None + @property + def unbound_t(cls) -> 'BoundMeta': + t = cls._unbound_base_ + if t is not None: + return t + else: + raise AttributeError(f'type {cls} has no unbound_t') + def __repr__(cls): return f"{cls.__name__}" + def rebind(cls, A : type, B : type): + new_fields = [] + for T in cls.fields: + if T == A: + new_fields.append(B) + elif isinstance(T, BoundMeta): + new_fields.append(T.rebind(A, B)) + else: + new_fields.append(T) + return cls.unbound_t[new_fields] class TupleMeta(BoundMeta): def __getitem__(cls, idx): @@ -255,8 +293,17 @@ def __init__(self, {type_sig}): exec(__init__, gs, ls) t.__init__ = ls['__init__'] + product_base = None + for base in bases: + if isinstance(base, mcs): + if product_base is None: + product_base = base + else: + raise TypeError('Can only inherit from one product type') + + if product_base is not None and not product_base.is_bound: + t._unbound_base_ = product_base - # Store the field indexs return t @@ -277,6 +324,38 @@ def __repr__(cls): def field_dict(cls): return MappingProxyType(cls._field_table_) + def from_fields(cls, + class_name: str, + fields: tp.Mapping[str, type], + module: tp.Optional[str] = None, + qualname: tp.Optional[str] = None): + if cls.is_bound: + raise TypeError('Type must not be bound') + + ns = {} + + if module is None: + module = cls.__module__ + + if qualname is None: + qualname = class_name + + ns['__module__'] = module + ns['__qualname__'] = qualname + + return cls._from_fields(fields, class_name, (cls,), ns) + + + def rebind(cls, A : type, B : type): + new_fields = OrderedDict() + for field, T in cls.field_dict.items(): + if T == A: + new_fields[field] = B + elif isinstance(T, BoundMeta): + new_fields[field] = T.rebind(A, B) + else: + new_fields[field] = T + return cls.unbound_t.from_fields(cls.__name__, new_fields, cls.__module__, cls.__qualname__) class SumMeta(BoundMeta): def _fields_cb(cls, idx): @@ -331,6 +410,17 @@ def __new__(mcs, cls_name, bases, namespace, **kwargs): t._fields_ = tuple(name_table.values()) + enum_base = None + for base in bases: + if isinstance(base, mcs): + if enum_base is None: + enum_base = base + else: + raise TypeError('Can only inherit from one enum type') + + if enum_base is not None: + t._unbound_base_ = enum_base + return t def __call__(cls, elem): @@ -344,3 +434,8 @@ def field_dict(cls): def enumerate(cls): yield from cls.fields + + def rebind(cls, A : type, B : type): + # Enums aren't bound to types + # could potentialy rebind values but that seems annoying + return cls diff --git a/tests/test_adt.py b/tests/test_adt.py index 6c71ca4..6549226 100644 --- a/tests/test_adt.py +++ b/tests/test_adt.py @@ -13,6 +13,11 @@ class En2(Enum): d = 1 +class En3(Enum): + e = 3 + f = 4 + + class Pr(Product): x = En1 y = En2 @@ -46,6 +51,7 @@ def test_enum(): assert issubclass(En1, Enum) assert isinstance(En1.a, Enum) assert isinstance(En1.a, En1) + assert En1.is_bound with pytest.raises(AttributeError): En1.a.b @@ -211,3 +217,47 @@ class _(T): ''' with pytest.raises(ReservedNameError): exec(cls_str, l_dict) + +@pytest.mark.parametrize("t, base", [ + (En1, Enum), + (Pr, Product), + (Su, Sum), + (Tu, Tuple), + ]) +def test_unbound_t(t, base): + assert t.unbound_t == base + class sub_t(t): pass + with pytest.raises(AttributeError): + sub_t.unbound_t + +@pytest.mark.parametrize("T", [Tu, Su, Pr]) +def test_rebind(T): + assert En1 in T.fields + assert En3 not in T.fields + T2 = T.rebind(En1, En3) + assert En1 not in T2.fields + assert En3 in T2.fields + + +class A: pass +class B: pass +class C: pass +class D: pass +class P1(Product): + A = A + B = B + +S1 = Sum[C, P1] + +class P2(Product): + S1 = S1 + C = C + +def test_rebind_recusrive(): + P3 = P2.rebind(A, D) + assert P3.S1.field_dict['P1'].A == D + assert P3.S1.field_dict['P1'].B == B + assert C in P3.S1.fields + P4 = P3.rebind(C, D) + assert P4.C == D + assert D in P4.S1.fields