Skip to content

Commit

Permalink
Merge bcc7379 into 0650775
Browse files Browse the repository at this point in the history
  • Loading branch information
cdonovick committed Jul 10, 2019
2 parents 0650775 + bcc7379 commit c3ad699
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 17 deletions.
34 changes: 23 additions & 11 deletions hwtypes/adt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .adt_meta import TupleMeta, ProductMeta, SumMeta, EnumMeta, is_adt_type
from collections import OrderedDict
from types import MappingProxyType
import typing as tp

__all__ = ['Tuple', 'Product', 'Sum', 'Enum']
__all__ += ['new_instruction', 'is_adt_type']
Expand Down Expand Up @@ -65,17 +66,6 @@ def value_dict(self):


class Product(Tuple, metaclass=ProductMeta):
def __new__(cls, *args, **kwargs):
if cls.is_bound:
return super().__new__(cls, *args, **kwargs)
elif len(args) == 1:
#fields, name, bases, namespace
t = type(cls).from_fields(kwargs, args[0], (cls,), {})
return t

else:
raise TypeError('Cannot instance unbound product type')

def __repr__(self):
return f'{type(self).__name__}({", ".join(f"{k}={v}" for k,v in self.value_dict.items())})'

Expand All @@ -86,6 +76,28 @@ def value_dict(self):
d[k] = getattr(self, k)
return MappingProxyType(d)

@classmethod
def from_fields(cls,
class_name: str,
fields: tp.Mapping[str, type],
module: tp.Optional[str] = None,
qualname: tp.Optional[str] = None):
if cls.is_bound:
raise TypeError('Type must not be bound')

ns = {}

if module is None:
module = cls.__module__

if qualname is None:
qualname = class_name

ns['__module__'] = module
ns['__qualname__'] = qualname

return cls._from_fields(fields, class_name, (cls,), ns)

class Sum(metaclass=SumMeta):
def __init__(self, value):
if not isinstance(value, tuple(type(self).fields)):
Expand Down
10 changes: 6 additions & 4 deletions hwtypes/adt_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,6 @@ def __new__(mcs, name, bases, namespace, **kwargs):
for k, v in namespace.items():
if _is_sunder(k) or _is_dunder(k) or _is_descriptor(v):
ns[k] = v
elif k in _RESERVED_NAMES:
raise ReservedNameError(f'Field name {k} is reserved by the type machinery')
elif isinstance(v, type):
if k in fields:
raise TypeError(f'Conflicting definitions of field {k}')
Expand All @@ -180,19 +178,23 @@ def __new__(mcs, name, bases, namespace, **kwargs):
ns[k] = v

if fields:
return mcs.from_fields(fields, name, bases, ns, **kwargs)
return mcs._from_fields(fields, name, bases, ns, **kwargs)
else:
return super().__new__(mcs, name, bases, ns, **kwargs)

@classmethod
def from_fields(mcs, fields, name, bases, ns, **kwargs):
def _from_fields(mcs, fields, name, bases, ns, **kwargs):
# not strictly necessary could iterative over class dict finding
# TypedProperty to reconstruct _field_table_ but that seems bad
if '_field_table_' in ns:
raise ReservedNameError('class attribute _field_table_ is reserved by the type machinery')
else:
ns['_field_table_'] = OrderedDict()

for field in fields:
if field in _RESERVED_NAMES:
raise ReservedNameError(f'Field name {field} is reserved by the type machinery')

def _get_tuple_base(bases):
for base in bases:
if not isinstance(base, ProductMeta) and isinstance(base, TupleMeta):
Expand Down
32 changes: 30 additions & 2 deletions tests/test_adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,33 @@ class En1(Enum):
a = 0
b = 1


class En2(Enum):
c = 0
d = 1


class Pr(Product):
x = En1
y = En2


class Pr2(Product):
x = En1
y = En2


class Pr3(Product):
y = En2
x = En1


Su = Sum[En1, Pr]


Tu = Tuple[En1, En2]


def test_enum():
assert set(En1.enumerate()) == {
En1.a,
Expand All @@ -43,6 +50,7 @@ def test_enum():
with pytest.raises(AttributeError):
En1.a.b


def test_tuple():
assert set(Tu.enumerate()) == {
Tu(En1.a, En2.c),
Expand Down Expand Up @@ -71,6 +79,7 @@ def test_tuple():
with pytest.raises(TypeError):
t[1] = 1


def test_product():
assert set(Pr.enumerate()) == {
Pr(En1.a, En2.c),
Expand Down Expand Up @@ -119,6 +128,26 @@ def test_product():
assert Pr.field_dict != Pr3.field_dict


def test_product_from_fields():
P = Product.from_fields('P', {'A' : int, 'B' : str})
assert issubclass(P, Product)
assert issubclass(P, Tuple[int, str])
assert P.A == int
assert P.B == str
assert P.__name__ == 'P'
assert P.__module__ == Product.__module__
assert P.__qualname__ == 'P'

P = Product.from_fields('P', {'A' : int, 'B' : str}, module='foo')
assert P.__module__ == 'foo'

P = Product.from_fields('P', {'A' : int, 'B' : str}, qualname='Foo.P')
assert P.__qualname__ == 'Foo.P'

with pytest.raises(TypeError):
Pr.from_fields('P', {'A' : int, 'B' : str})


def test_sum():
assert set(Su.enumerate()) == {
Su(En1.a),
Expand Down Expand Up @@ -159,6 +188,7 @@ def test_new():
t = new(Sum, (En1, Pr), module=__name__)
assert t.__module__ == __name__


@pytest.mark.parametrize("T", [En1, Tu, Su, Pr])
def test_repr(T):
s = repr(T)
Expand All @@ -181,5 +211,3 @@ class _(T):
'''
with pytest.raises(ReservedNameError):
exec(cls_str, l_dict)


0 comments on commit c3ad699

Please sign in to comment.