Skip to content
Permalink
 
 
Cannot retrieve contributors at this time
"""
Typing for enums.
"""
import operator
from numba.core import types
from numba.core.typing.templates import (AbstractTemplate, AttributeTemplate,
signature, Registry)
# Module-level typing registry. Its bound registration methods are aliased
# below and used as class decorators on the typing templates in this module.
registry = Registry()
infer, infer_global, infer_getattr = (
    registry.register,
    registry.register_global,
    registry.register_attr,
)
@infer_getattr
class EnumAttribute(AttributeTemplate):
    """Attribute typing for enum members (``types.EnumMember``)."""

    key = types.EnumMember

    def resolve_value(self, ty):
        """
        Resolve the ``value`` attribute (numba's ``resolve_<attr>`` hook
        naming) of an enum member type.

        The result is the member type's ``dtype``.
        """
        value_type = ty.dtype
        return value_type
@infer_getattr
class EnumClassAttribute(AttributeTemplate):
    """Attribute typing for enum classes (``types.EnumClass``)."""

    key = types.EnumClass

    def generic_resolve(self, ty, attr):
        """
        Resolve attributes of an enum class as enum members.

        A name present in the class's ``__members__`` mapping is typed as
        ``ty.member_type``. Any other name is left unresolved (``None``).
        """
        members = ty.instance_class.__members__
        if attr not in members:
            return None
        return ty.member_type
@infer
class EnumClassStaticGetItem(AbstractTemplate):
    """Typing for ``static_getitem`` on enum classes (``EnumClass[name]``)."""

    key = "static_getitem"

    def generic(self, args, kws):
        """
        Type ``enum_cls[idx]`` as a member of that enum class.

        ``args`` is ``(enum, idx)``. ``idx`` is looked up directly in the
        Python class's ``__members__`` mapping, so it is presumably the
        constant index value rather than a numba type (confirm against
        callers).

        Returns a signature whose return type is ``enum.member_type`` when
        ``enum`` is a ``types.EnumClass`` and ``idx`` names one of its
        members. Otherwise returns ``None``.
        """
        enum, idx = args
        if not isinstance(enum, types.EnumClass):
            return None
        try:
            is_member = idx in enum.instance_class.__members__
        except TypeError:
            # The membership test hashes ``idx``. An unhashable constant
            # index (e.g. a ``slice`` on Python < 3.12) can never name a
            # member, so decline (``None``) instead of raising out of the
            # template.
            return None
        if is_member:
            return signature(enum.member_type, *args)
        return None
class EnumCompare(AbstractTemplate):
    """Shared typing rule for binary comparisons between enum members."""

    def generic(self, args, kws):
        """
        Accept two equal ``types.EnumMember`` operand types.

        The result is a boolean-returning signature ``(lhs, rhs) -> boolean``.
        Any other operand pair yields ``None``.
        """
        lhs, rhs = args
        if not isinstance(lhs, types.EnumMember):
            return None
        if not isinstance(rhs, types.EnumMember):
            return None
        if lhs == rhs:
            return signature(types.boolean, lhs, rhs)
        return None
@infer_global(operator.eq)
class EnumEq(EnumCompare):
    """Typing of ``operator.eq`` on enum members (rule from ``EnumCompare``)."""
@infer_global(operator.ne)
class EnumNe(EnumCompare):
    """Typing of ``operator.ne`` on enum members (rule from ``EnumCompare``)."""