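"""
Typing for enums: attribute, static-getitem and comparison templates for
Python ``enum`` classes and their members.
"""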
import operator
from numba.core import types
from numba.core.typing.templates import (AbstractTemplate, AttributeTemplate,
signature, Registry)
# Module-local registry collecting every typing template declared below,
# plus short decorator aliases for its three registration entry points.
registry = Registry()
infer, infer_global, infer_getattr = (
    registry.register,          # generic templates (e.g. static_getitem)
    registry.register_global,   # templates keyed on a global callable
    registry.register_attr,     # attribute templates
)
@infer_getattr
class EnumAttribute(AttributeTemplate):
    """Attribute typing for enum *members* (``types.EnumMember``).

    Registered via ``infer_getattr``. Without the decorator the template is
    defined but never reaches the registry, so it would have no effect.
    """
    key = types.EnumMember

    def resolve_value(self, ty):
        """Type the ``.value`` attribute of an enum member as ``ty.dtype``.

        ``AttributeTemplate`` dispatches attribute ``value`` here through its
        ``resolve_<attr>`` naming convention.
        """
        return ty.dtype
@infer_getattr
class EnumClassAttribute(AttributeTemplate):
    """Attribute typing for enum *classes* (``types.EnumClass``).

    Looking up a member name on the class types as the class's member type.
    Registered via ``infer_getattr`` so the typing machinery can find it.
    """
    key = types.EnumClass

    def generic_resolve(self, ty, attr):
        """
        Resolve attributes of an enum class as enum members.

        Returns ``ty.member_type`` when *attr* is one of the member names in
        ``ty.instance_class.__members__``. Otherwise returns None, which means
        this template does not handle the attribute.
        """
        if attr in ty.instance_class.__members__:
            return ty.member_type
        return None
@infer
class EnumClassStaticGetItem(AbstractTemplate):
    """Type ``EnumClass[name]`` lookups made with a static index.

    Registered via ``infer`` under the ``"static_getitem"`` key.
    """
    key = "static_getitem"

    def generic(self, args, kws):
        """
        Return ``signature(member_type, *args)`` when the first argument is an
        enum class type and the index names one of its members. Otherwise
        return None.
        """
        # Named enum_cls (not ``enum``) to avoid shadowing the stdlib module.
        enum_cls, idx = args
        if not isinstance(enum_cls, types.EnumClass):
            return None
        try:
            is_member = idx in enum_cls.instance_class.__members__
        except TypeError:
            # An unhashable constant index (e.g. a slice before Python 3.12)
            # cannot name a member. Decline rather than raise out of typing.
            return None
        if is_member:
            return signature(enum_cls.member_type, *args)
        return None
class EnumCompare(AbstractTemplate):
    """Shared typing logic for binary comparisons of two enum members.

    Concrete subclasses such as ``EnumEq`` reuse this ``generic`` unchanged.
    """

    def generic(self, args, kws):
        """Return a boolean signature when both operands are enum-member
        types that compare equal as types; otherwise None."""
        lhs, rhs = args
        operands_are_members = (isinstance(lhs, types.EnumMember)
                                and isinstance(rhs, types.EnumMember))
        if not operands_are_members:
            return None
        if not (lhs == rhs):
            # Members of different enum types are not typed here.
            return None
        return signature(types.boolean, lhs, rhs)
@infer_global(operator.eq)
class EnumEq(EnumCompare):
    """Typing of ``==`` between two enum members of the same type.

    All logic is inherited from ``EnumCompare``. The decorator keys the
    template on ``operator.eq``; without it the class would never be
    registered.
    """
class EnumNe(EnumCompare):