Skip to content
Permalink
 
 
Cannot retrieve contributors at this time
192 lines (150 sloc) 6.05 KB
import operator
from numba.core import types
from .templates import (ConcreteTemplate, AbstractTemplate, AttributeTemplate,
CallableTemplate, Registry, signature, bound_function,
make_callable_template)
# Ensure set is typed as a collection as well
from numba.core.typing import collections
registry = Registry()
infer = registry.register
infer_global = registry.register_global
infer_getattr = registry.register_attr
@infer_global(set)
class SetBuiltin(AbstractTemplate):
def generic(self, args, kws):
assert not kws
if args:
# set(iterable)
iterable, = args
if isinstance(iterable, types.IterableType):
dtype = iterable.iterator_type.yield_type
if isinstance(dtype, types.Hashable):
return signature(types.Set(dtype), iterable)
else:
# set()
return signature(types.Set(types.undefined))
@infer_getattr
class SetAttribute(AttributeTemplate):
key = types.Set
@bound_function("set.add")
def resolve_add(self, set, args, kws):
item, = args
assert not kws
unified = self.context.unify_pairs(set.dtype, item)
if unified is not None:
sig = signature(types.none, unified)
sig = sig.replace(recvr=set.copy(dtype=unified))
return sig
@bound_function("set.clear")
def resolve_clear(self, set, args, kws):
assert not kws
if not args:
return signature(types.none)
@bound_function("set.copy")
def resolve_copy(self, set, args, kws):
assert not kws
if not args:
return signature(set)
@bound_function("set.discard")
def resolve_discard(self, set, args, kws):
item, = args
assert not kws
return signature(types.none, set.dtype)
@bound_function("set.pop")
def resolve_pop(self, set, args, kws):
assert not kws
if not args:
return signature(set.dtype)
@bound_function("set.remove")
def resolve_remove(self, set, args, kws):
item, = args
assert not kws
return signature(types.none, set.dtype)
@bound_function("set.update")
def resolve_update(self, set, args, kws):
iterable, = args
assert not kws
if not isinstance(iterable, types.IterableType):
return
dtype = iterable.iterator_type.yield_type
unified = self.context.unify_pairs(set.dtype, dtype)
if unified is not None:
sig = signature(types.none, iterable)
sig = sig.replace(recvr=set.copy(dtype=unified))
return sig
def _resolve_xxx_update(self, set, args, kws):
assert not kws
iterable, = args
# Set arguments only supported for now
# (note we can mix non-reflected and reflected arguments)
if isinstance(iterable, types.Set) and iterable.dtype == set.dtype:
return signature(types.none, iterable)
@bound_function("set.difference_update")
def resolve_difference_update(self, set, args, kws):
return self._resolve_xxx_update(set, args, kws)
@bound_function("set.intersection_update")
def resolve_intersection_update(self, set, args, kws):
return self._resolve_xxx_update(set, args, kws)
@bound_function("set.symmetric_difference_update")
def resolve_symmetric_difference_update(self, set, args, kws):
return self._resolve_xxx_update(set, args, kws)
def _resolve_operator(self, set, args, kws):
assert not kws
iterable, = args
# Set arguments only supported for now
# (note we can mix non-reflected and reflected arguments)
if isinstance(iterable, types.Set) and iterable.dtype == set.dtype:
return signature(set, iterable)
@bound_function("set.difference")
def resolve_difference(self, set, args, kws):
return self._resolve_operator(set, args, kws)
@bound_function("set.intersection")
def resolve_intersection(self, set, args, kws):
return self._resolve_operator(set, args, kws)
@bound_function("set.symmetric_difference")
def resolve_symmetric_difference(self, set, args, kws):
return self._resolve_operator(set, args, kws)
@bound_function("set.union")
def resolve_union(self, set, args, kws):
return self._resolve_operator(set, args, kws)
def _resolve_comparator(self, set, args, kws):
assert not kws
arg, = args
if arg == set:
return signature(types.boolean, arg)
@bound_function("set.isdisjoint")
def resolve_isdisjoint(self, set, args, kws):
return self._resolve_comparator(set, args, kws)
@bound_function("set.issubset")
def resolve_issubset(self, set, args, kws):
return self._resolve_comparator(set, args, kws)
@bound_function("set.issuperset")
def resolve_issuperset(self, set, args, kws):
return self._resolve_comparator(set, args, kws)
class SetOperator(AbstractTemplate):
def generic(self, args, kws):
if len(args) != 2:
return
a, b = args
if (isinstance(a, types.Set) and isinstance(b, types.Set)
and a.dtype == b.dtype):
return signature(a, *args)
class SetComparison(AbstractTemplate):
def generic(self, args, kws):
if len(args) != 2:
return
a, b = args
if isinstance(a, types.Set) and isinstance(b, types.Set) and a == b:
return signature(types.boolean, *args)
for op_key in (operator.add, operator.sub, operator.and_, operator.or_, operator.xor, operator.invert):
@infer_global(op_key)
class ConcreteSetOperator(SetOperator):
key = op_key
for op_key in (operator.iadd, operator.isub, operator.iand, operator.ior, operator.ixor):
@infer_global(op_key)
class ConcreteInplaceSetOperator(SetOperator):
key = op_key
for op_key in (operator.eq, operator.ne, operator.lt, operator.le, operator.ge, operator.gt):
@infer_global(op_key)
class ConcreteSetComparison(SetComparison):
key = op_key
You can’t perform that action at this time.