Skip to content
Permalink
 
 
Cannot retrieve contributors at this time
64 lines (46 sloc) 1.47 KB
"""
Typing for enums.
"""
import operator
from numba.core import types
from numba.core.typing.templates import (AbstractTemplate, AttributeTemplate,
signature, Registry)
registry = Registry()

# Shorthand decorators bound to this module's registry, used below as:
#   infer         -- on a template keyed by a string ("static_getitem")
#   infer_global  -- on templates for a global callable (operator.eq / ne)
#   infer_getattr -- on AttributeTemplate subclasses (attribute typing)
infer, infer_global, infer_getattr = (
    registry.register,
    registry.register_global,
    registry.register_attr,
)
@infer_getattr
class EnumAttribute(AttributeTemplate):
    """Attribute typing for enum member types (``types.EnumMember``)."""

    key = types.EnumMember

    def resolve_value(self, ty):
        """Type the ``value`` attribute of an enum member.

        The result is the member type's ``dtype``. Per AttributeTemplate's
        ``resolve_<attr>`` naming, this presumably handles ``member.value``.
        """
        value_type = ty.dtype
        return value_type
@infer_getattr
class EnumClassAttribute(AttributeTemplate):
    """Attribute typing for enum class types (``types.EnumClass``).

    An attribute whose name is one of the enum's members is typed as the
    class's member type. Any other name is left unresolved (``None``).
    """

    key = types.EnumClass

    def generic_resolve(self, ty, attr):
        """
        Resolve attributes of an enum class as enum members.
        """
        member_names = ty.instance_class.__members__
        if attr not in member_names:
            # Not a member name: let other resolution paths handle it.
            return None
        return ty.member_type
@infer
class EnumClassStaticGetItem(AbstractTemplate):
    """Typing for ``EnumClass[<constant>]`` lookups by member name.

    Registered under the ``"static_getitem"`` key. It accepts an enum class
    type indexed by the name of one of its members and types the result as
    that class's member type. Anything else is left unresolved (``None``).
    """

    key = "static_getitem"

    def generic(self, args, kws):
        # Named ``enum_ty`` so the local does not shadow the ``enum`` module.
        enum_ty, idx = args
        if not isinstance(enum_ty, types.EnumClass):
            return None
        # ``__members__`` maps member *names* (always ``str``) to members.
        # Checking the index type first means a non-string constant index
        # cannot raise ``TypeError`` out of the membership test; such an index
        # is simply not handled here. Examples are a list, or a slice (which
        # is unhashable before Python 3.12).
        if isinstance(idx, str) and idx in enum_ty.instance_class.__members__:
            return signature(enum_ty.member_type, *args)
        return None
class EnumCompare(AbstractTemplate):
    """Shared typing for binary comparisons between enum members.

    Expects exactly two argument types. When both are ``types.EnumMember``
    and they are the same type, the comparison is typed as returning
    ``types.boolean``. Any other pair is left unresolved (``None``).
    """

    def generic(self, args, kws):
        lhs, rhs = args
        both_members = all(
            isinstance(operand, types.EnumMember) for operand in (lhs, rhs)
        )
        # Only members of the *same* enum member type are comparable here.
        if not (both_members and lhs == rhs):
            return None
        return signature(types.boolean, lhs, rhs)
@infer_global(operator.eq)
class EnumEq(EnumCompare):
    """Typing of ``operator.eq`` (``==``) via the EnumCompare rules."""
@infer_global(operator.ne)
class EnumNe(EnumCompare):
    """Typing of ``operator.ne`` (``!=``) via the EnumCompare rules."""
You can’t perform that action at this time.