Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add utility functions for modifiers; Better type hierarchy #83

Merged
merged 1 commit into from
Jul 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 66 additions & 9 deletions hwtypes/modifiers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import types
import weakref

__ALL__ = ['new', 'make_modifier', 'is_modified', 'is_modifier', 'get_modifier', 'get_unmodified']


_DEBUG = False
#special sentinal value
class _MISSING: pass

Expand All @@ -17,24 +21,77 @@ class T(klass): pass
return T

class _ModifierMeta(type):
_modifier_lookup = weakref.WeakKeyDictionary()
def __instancecheck__(cls, obj):
return type(obj) in cls._sub_classes.values()
if cls is AbstractModifier:
return super().__instancecheck__(obj)
else:
return type(obj) in cls._sub_classes

def __subclasscheck__(cls, typ):
return typ in cls._sub_classes.values()
def __subclasscheck__(cls, T):
if cls is AbstractModifier:
return super().__subclasscheck__(T)
else:
return T in cls._sub_classes

def __call__(cls, *args):
if cls is AbstractModifier:
raise TypeError('Cannot instance or apply AbstractModifier')

if len(args) != 1:
return super().__call__(*args)
sub = args[0]

unmod_cls = args[0]
try:
return cls._sub_classes[sub]
return cls._sub_class_cache[unmod_cls]
except KeyError:
pass

mod_sub_name = cls.__name__ + sub.__name__
mod_sub = type(mod_sub_name, (sub,), {})
return cls._sub_classes.setdefault(sub, mod_sub)
mod_name = cls.__name__ + unmod_cls.__name__
bases = [unmod_cls]
for base in unmod_cls.__bases__:
bases.append(cls(base))
mod_cls = type(mod_name, tuple(bases), {})
cls._register_modified(unmod_cls, mod_cls)
return mod_cls

class AbstractModifier(metaclass=_ModifierMeta):
def __init_subclass__(cls, **kwargs):
cls._sub_class_cache = weakref.WeakValueDictionary()
cls._sub_classes = weakref.WeakSet()

@classmethod
def _register_modified(cls, unmod_cls, mod_cls):
type(cls)._modifier_lookup[mod_cls] = cls
cls._sub_classes.add(mod_cls)
cls._sub_class_cache[unmod_cls] = mod_cls
if _DEBUG:
# O(n) assert, but its a pretty key invariant
assert set(cls._sub_classes) == set(cls._sub_class_cache.values())

def is_modified(T):
return T in _ModifierMeta._modifier_lookup

def is_modifier(T):
return issubclass(T, AbstractModifier)

def get_modifier(T):
if is_modified(T):
return _ModifierMeta._modifier_lookup[T]
else:
raise TypeError(f'{T} has no modifiers')

def get_unmodified(T):
if is_modified(T):
unmod = T.__bases__[0]
if _DEBUG:
# Not an expensive assert but as there is a
# already a debug guard might as well use it.
mod = get_modifier(T)
assert mod._sub_class_cache[unmod] is T
return unmod
else:
raise TypeError(f'{T} has no modifiers')

_mod_cache = weakref.WeakValueDictionary()
# This is a factory for type modifiers.
Expand All @@ -45,7 +102,7 @@ def make_modifier(name, cache=False):
except KeyError:
pass

ModType = _ModifierMeta(name, (), {'_sub_classes' : weakref.WeakValueDictionary()})
ModType = _ModifierMeta(name, (AbstractModifier,), {})

if cache:
return _mod_cache.setdefault(name, ModType)
Expand Down
28 changes: 26 additions & 2 deletions tests/test_modifiers.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,46 @@
from hwtypes.modifiers import make_modifier
import pytest

from hwtypes import Bit, AbstractBit
import hwtypes.modifiers as modifiers
from hwtypes.modifiers import make_modifier, is_modified, is_modifier
from hwtypes.modifiers import get_modifier, get_unmodified

modifiers._DEBUG = True

def test_basic():
Global = make_modifier("Global")
GlobalBit = Global(Bit)

assert GlobalBit is Global(Bit)

assert issubclass(GlobalBit, Bit)
assert issubclass(GlobalBit, AbstractBit)
assert issubclass(GlobalBit, Global)
assert issubclass(GlobalBit, Global(AbstractBit))

global_bit = GlobalBit(0)

assert isinstance(global_bit, GlobalBit)
assert isinstance(global_bit, Bit)
assert isinstance(global_bit, AbstractBit)
assert isinstance(global_bit, Global)
assert isinstance(global_bit, Global(AbstractBit))

assert is_modifier(Global)
assert is_modified(GlobalBit)
assert not is_modifier(Bit)
assert not is_modified(Bit)
assert not is_modified(Global)

assert get_modifier(GlobalBit) is Global
assert get_unmodified(GlobalBit) is Bit

with pytest.raises(TypeError):
get_modifier(Bit)

with pytest.raises(TypeError):
get_unmodified(Bit)


def test_cache():
G1 = make_modifier("Global", cache=True)
Expand All @@ -24,4 +49,3 @@ def test_cache():

assert G1 is G2
assert G1 is not G3