"""
Typing for enums.
"""
import operator
from numba.core import types
from numba.core.typing.templates import (AbstractTemplate, AttributeTemplate,
signature, Registry)
# Typing-template registry for this module.  The registry's three
# registration hooks are bound below to short module-level names.
registry = Registry()
infer, infer_global, infer_getattr = (
    registry.register,
    registry.register_global,
    registry.register_attr,
)
# Registered via ``infer_getattr``; without the decorator this template is
# never added to ``registry`` and is therefore never consulted.
@infer_getattr
class EnumAttribute(AttributeTemplate):
    """Attribute typing for enum member instances (``types.EnumMember``)."""

    key = types.EnumMember

    def resolve_value(self, ty):
        """Type the ``value`` attribute of an enum member.

        :param ty: the ``types.EnumMember`` type being queried.
        :returns: ``ty.dtype``, the type carried by the member type for its
            values.
        """
        return ty.dtype
# Registered via ``infer_getattr`` so that ``registry`` actually carries it.
@infer_getattr
class EnumClassAttribute(AttributeTemplate):
    """Attribute typing for enum classes (``types.EnumClass``)."""

    key = types.EnumClass

    def generic_resolve(self, ty, attr):
        """
        Resolve attributes of an enum class as enum members.

        :param ty: the ``types.EnumClass`` type being queried.
        :param attr: the attribute name being looked up.
        :returns: ``ty.member_type`` when *attr* names a member of the
            wrapped Python enum class (``ty.instance_class``); otherwise
            ``None``, leaving the attribute unresolved by this template.
        """
        if attr in ty.instance_class.__members__:
            return ty.member_type
        return None
# Registered via ``infer`` so that ``registry`` actually carries it.
@infer
class EnumClassStaticGetItem(AbstractTemplate):
    """Type ``EnumClass[<constant>]`` lookups of enum members by name."""

    key = "static_getitem"

    def generic(self, args, kws):
        """Return a signature for ``enum_cls[idx]`` when *idx* is a member name.

        :param args: ``(enum, idx)``, the indexed type and the constant index.
        :param kws: unused.
        :returns: ``signature(enum.member_type, enum, idx)`` when *enum* is a
            ``types.EnumClass`` and *idx* names one of its members, else
            ``None``.
        """
        enum, idx = args
        # Enum member names are always strings.  Checking the type first
        # also avoids a TypeError from probing the members mapping with an
        # unhashable constant index such as a slice (unhashable before
        # Python 3.12).
        if (isinstance(enum, types.EnumClass)
                and isinstance(idx, str)
                and idx in enum.instance_class.__members__):
            return signature(enum.member_type, *args)
        return None
class EnumCompare(AbstractTemplate):
    """Shared typing rule for a comparison between two enum members.

    The subclasses ``EnumEq`` and ``EnumNe`` reuse this rule unchanged.
    """

    def generic(self, args, kws):
        """Return ``(member, member) -> boolean`` for matching enum members.

        :param args: the two operand types ``(left, right)``.
        :param kws: unused.
        :returns: a boolean-returning signature when both operands are
            ``types.EnumMember`` and equal to each other, else ``None``.
        """
        left, right = args
        # Both operands must be enum-member types, and the very same one.
        both_members = all(isinstance(t, types.EnumMember)
                           for t in (left, right))
        if not (both_members and left == right):
            return None
        return signature(types.boolean, left, right)
# Bound to ``operator.eq`` through ``infer_global``; the original class had
# no body (a syntax error) and was never registered.
@infer_global(operator.eq)
class EnumEq(EnumCompare):
    """Type ``a == b`` for two enum members of the same enum member type."""
# Bound to ``operator.ne`` through ``infer_global``; the original class had
# no body (a syntax error) and was never registered.
@infer_global(operator.ne)
class EnumNe(EnumCompare):
    """Type ``a != b`` for two enum members of the same enum member type."""