Skip to content
Permalink
 
 
Cannot retrieve contributors at this time
268 lines (212 sloc) 7.95 KB
"""
Typing declarations for np.timedelta64.
"""
from itertools import product
import operator
from numba.core import types
from numba.core.typing.templates import (AttributeTemplate, ConcreteTemplate,
AbstractTemplate, infer_global, infer,
infer_getattr, signature)
from numba.np import npdatetime_helpers
# timedelta64-only operations
class TimedeltaUnaryOp(AbstractTemplate):
    """Typing template for unary operations on a single timedelta64."""

    def generic(self, args, kws):
        """
        (timedelta64) -> timedelta64

        Returns None (no match) for two-argument calls or when the
        operand is not an NPTimedelta.
        """
        # The same operators (+ and -) also exist in binary form;
        # those calls are not handled by this template.
        if len(args) == 2:
            return None
        (operand,) = args
        if isinstance(operand, types.NPTimedelta):
            # The result keeps the exact type of the operand.
            return signature(operand, operand)
        return None
class TimedeltaBinOp(AbstractTemplate):
    """Typing template for binary operations between two timedelta64."""

    def generic(self, args, kws):
        """
        (timedelta64, timedelta64) -> timedelta64

        The result type is whichever operand type the other operand's
        unit can be cast to; None is returned when neither direction
        is castable or when an operand is not an NPTimedelta.
        """
        # Unary + and - share the same operator keys; skip them here.
        if len(args) == 1:
            return None
        lhs, rhs = args
        if not (isinstance(lhs, types.NPTimedelta)
                and isinstance(rhs, types.NPTimedelta)):
            return None

        can_cast = npdatetime_helpers.can_cast_timedelta_units
        if can_cast(lhs.unit, rhs.unit):
            result = rhs
        elif can_cast(rhs.unit, lhs.unit):
            result = lhs
        else:
            # Units are not castable in either direction: no match.
            return None
        return signature(result, lhs, rhs)
class TimedeltaCmpOp(AbstractTemplate):
    """Typing template for equality-style comparisons of timedelta64."""

    def generic(self, args, kws):
        """
        (timedelta64, timedelta64) -> boolean

        No unit check is performed: for equality comparisons, all units
        are inter-comparable.
        """
        lhs, rhs = args
        for operand in (lhs, rhs):
            if not isinstance(operand, types.NPTimedelta):
                return None
        return signature(types.boolean, lhs, rhs)
class TimedeltaOrderedCmpOp(AbstractTemplate):
    """Typing template for ordered comparisons of timedelta64."""

    def generic(self, args, kws):
        """
        (timedelta64, timedelta64) -> boolean

        Unlike equality, ordered comparisons require compatible units:
        one operand's unit must be castable to the other's.
        """
        lhs, rhs = args
        if not (isinstance(lhs, types.NPTimedelta)
                and isinstance(rhs, types.NPTimedelta)):
            return None

        can_cast = npdatetime_helpers.can_cast_timedelta_units
        if not (can_cast(lhs.unit, rhs.unit)
                or can_cast(rhs.unit, lhs.unit)):
            # Incompatible units: no match.
            return None
        return signature(types.boolean, lhs, rhs)
class TimedeltaMixOp(AbstractTemplate):
    """Typing template mixing a timedelta64 with a plain number."""

    def generic(self, args, kws):
        """
        (timedelta64, {int, float}) -> timedelta64
        ({int, float}, timedelta64) -> timedelta64
        """
        lhs, rhs = args
        # The right operand is tested first, as a timedelta64 may sit on
        # either side of the operator.
        if isinstance(rhs, types.NPTimedelta):
            td_is_left = False
            td, number = rhs, lhs
        elif isinstance(lhs, types.NPTimedelta):
            td_is_left = True
            td, number = lhs, rhs
        else:
            return None

        if isinstance(number, types.Integer):
            # Force integer types to convert to signed because it matches
            # timedelta64 semantics better.
            number = types.int64
        elif not isinstance(number, types.Float):
            return None

        # Keep the argument order of the original call in the signature.
        if td_is_left:
            return signature(td, td, number)
        return signature(td, number, td)
class TimedeltaDivOp(AbstractTemplate):
    """Typing template for division with a timedelta64 dividend."""

    def generic(self, args, kws):
        """
        (timedelta64, {int, float}) -> timedelta64
        (timedelta64, timedelta64) -> float
        """
        dividend, divisor = args
        if not isinstance(dividend, types.NPTimedelta):
            return None

        if isinstance(divisor, types.NPTimedelta):
            # A timedelta / timedelta ratio needs units that are castable
            # in at least one direction.
            can_cast = npdatetime_helpers.can_cast_timedelta_units
            if (can_cast(dividend.unit, divisor.unit)
                    or can_cast(divisor.unit, dividend.unit)):
                return signature(types.float64, dividend, divisor)
            return None

        if isinstance(divisor, types.Float):
            return signature(dividend, dividend, divisor)

        if isinstance(divisor, types.Integer):
            # Force integer types to convert to signed because it matches
            # timedelta64 semantics better.
            return signature(dividend, dividend, types.int64)

        return None
@infer_global(operator.pos)
class TimedeltaUnaryPos(TimedeltaUnaryOp):
    """Unary ``+`` on timedelta64, typed by TimedeltaUnaryOp."""

    key = operator.pos
@infer_global(operator.neg)
class TimedeltaUnaryNeg(TimedeltaUnaryOp):
    """Unary ``-`` on timedelta64, typed by TimedeltaUnaryOp."""

    key = operator.neg
@infer_global(operator.add)
@infer_global(operator.iadd)
class TimedeltaBinAdd(TimedeltaBinOp):
    """``+`` and ``+=`` between two timedelta64, typed by TimedeltaBinOp."""

    key = operator.add
@infer_global(operator.sub)
@infer_global(operator.isub)
class TimedeltaBinSub(TimedeltaBinOp):
    """``-`` and ``-=`` between two timedelta64, typed by TimedeltaBinOp."""

    key = operator.sub
@infer_global(operator.mul)
@infer_global(operator.imul)
class TimedeltaBinMult(TimedeltaMixOp):
    """``*`` and ``*=`` of a timedelta64 and a number, typed by TimedeltaMixOp."""

    key = operator.mul
@infer_global(operator.truediv)
@infer_global(operator.itruediv)
class TimedeltaTrueDiv(TimedeltaDivOp):
    """``/`` and ``/=`` with a timedelta64 dividend, typed by TimedeltaDivOp."""

    key = operator.truediv
@infer_global(operator.floordiv)
@infer_global(operator.ifloordiv)
class TimedeltaFloorDiv(TimedeltaDivOp):
    """``//`` and ``//=`` with a timedelta64 dividend, typed by TimedeltaDivOp."""

    key = operator.floordiv
@infer_global(operator.eq)
class TimedeltaCmpEq(TimedeltaCmpOp):
    """``==`` between two timedelta64, typed by TimedeltaCmpOp."""

    key = operator.eq
@infer_global(operator.ne)
class TimedeltaCmpNe(TimedeltaCmpOp):
    """``!=`` between two timedelta64, typed by TimedeltaCmpOp."""

    key = operator.ne
@infer_global(operator.lt)
class TimedeltaCmpLt(TimedeltaOrderedCmpOp):
    """``<`` between two timedelta64, typed by TimedeltaOrderedCmpOp."""

    key = operator.lt
@infer_global(operator.le)
class TimedeltaCmpLE(TimedeltaOrderedCmpOp):
    """``<=`` between two timedelta64, typed by TimedeltaOrderedCmpOp."""

    key = operator.le
@infer_global(operator.gt)
class TimedeltaCmpGt(TimedeltaOrderedCmpOp):
    """``>`` between two timedelta64, typed by TimedeltaOrderedCmpOp."""

    key = operator.gt
@infer_global(operator.ge)
class TimedeltaCmpGE(TimedeltaOrderedCmpOp):
    """``>=`` between two timedelta64, typed by TimedeltaOrderedCmpOp."""

    key = operator.ge
@infer_global(abs)
class TimedeltaAbs(TimedeltaUnaryOp):
    """Built-in ``abs()`` on timedelta64; typing is inherited unchanged
    from TimedeltaUnaryOp."""
# datetime64 operations
@infer_global(operator.add)
@infer_global(operator.iadd)
class DatetimePlusTimedelta(AbstractTemplate):
    """Typing template for adding a datetime64 and a timedelta64."""

    key = operator.add

    def generic(self, args, kws):
        """
        (datetime64, timedelta64) -> datetime64
        (timedelta64, datetime64) -> datetime64

        The result unit comes from
        npdatetime_helpers.combine_datetime_timedelta_units(); when that
        returns None there is no match.
        """
        # Unary + shares the same operator key; skip it here.
        if len(args) == 1:
            return None
        lhs, rhs = args

        # Work out which side is the timedelta; the other side is then
        # expected to be the datetime.
        if isinstance(rhs, types.NPTimedelta):
            dt, td = lhs, rhs
        elif isinstance(lhs, types.NPTimedelta):
            dt, td = rhs, lhs
        else:
            return None

        if not isinstance(dt, types.NPDatetime):
            return None

        combined = npdatetime_helpers.combine_datetime_timedelta_units(
            dt.unit, td.unit)
        if combined is None:
            return None
        # The signature keeps the operands in their original order.
        return signature(types.NPDatetime(combined), lhs, rhs)
@infer_global(operator.sub)
@infer_global(operator.isub)
class DatetimeMinusTimedelta(AbstractTemplate):
    """Typing template for subtracting a timedelta64 from a datetime64."""

    key = operator.sub

    def generic(self, args, kws):
        """
        (datetime64, timedelta64) -> datetime64

        The result unit comes from
        npdatetime_helpers.combine_datetime_timedelta_units(); when that
        returns None there is no match.
        """
        # Unary - shares the same operator key; skip it here.
        if len(args) == 1:
            return None
        minuend, delta = args
        if not isinstance(minuend, types.NPDatetime):
            return None
        if not isinstance(delta, types.NPTimedelta):
            return None

        combined = npdatetime_helpers.combine_datetime_timedelta_units(
            minuend.unit, delta.unit)
        if combined is None:
            return None
        return signature(types.NPDatetime(combined), minuend, delta)
@infer_global(operator.sub)
class DatetimeMinusDatetime(AbstractTemplate):
    """Typing template for the difference of two datetime64 values."""

    key = operator.sub

    def generic(self, args, kws):
        """
        (datetime64, datetime64) -> timedelta64

        The result unit is chosen by npdatetime_helpers.get_best_unit().
        """
        # Unary - shares the same operator key; skip it here.
        if len(args) == 1:
            return None
        lhs, rhs = args
        if not isinstance(lhs, types.NPDatetime):
            return None
        if not isinstance(rhs, types.NPDatetime):
            return None

        # All units compatible...
        best = npdatetime_helpers.get_best_unit(lhs.unit, rhs.unit)
        return signature(types.NPTimedelta(best), lhs, rhs)
class DatetimeCmpOp(AbstractTemplate):
    """Typing template for comparisons between two datetime64 values."""

    def generic(self, args, kws):
        """
        (datetime64, datetime64) -> boolean

        No unit check is performed: for datetime64 comparisons, all
        units are inter-comparable.
        """
        lhs, rhs = args
        if (isinstance(lhs, types.NPDatetime)
                and isinstance(rhs, types.NPDatetime)):
            return signature(types.boolean, lhs, rhs)
        return None
@infer_global(operator.eq)
class DatetimeCmpEq(DatetimeCmpOp):
    """``==`` between two datetime64, typed by DatetimeCmpOp."""

    key = operator.eq
@infer_global(operator.ne)
class DatetimeCmpNe(DatetimeCmpOp):
    """``!=`` between two datetime64, typed by DatetimeCmpOp."""

    key = operator.ne
@infer_global(operator.lt)
class DatetimeCmpLt(DatetimeCmpOp):
    """``<`` between two datetime64, typed by DatetimeCmpOp."""

    key = operator.lt
@infer_global(operator.le)
class DatetimeCmpLE(DatetimeCmpOp):
    """``<=`` between two datetime64, typed by DatetimeCmpOp."""

    key = operator.le
@infer_global(operator.gt)
class DatetimeCmpGt(DatetimeCmpOp):
    """``>`` between two datetime64, typed by DatetimeCmpOp."""

    key = operator.gt
@infer_global(operator.ge)
class DatetimeCmpGE(DatetimeCmpOp):
    """``>=`` between two datetime64, typed by DatetimeCmpOp."""

    key = operator.ge
You can’t perform that action at this time.