diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 67d0106684..5d310e6f0f 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -57,6 +57,8 @@ def aten_acosh(self: TFloat) -> TFloat:
 @torch_op("aten::add")
 def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     # add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+    # TODO(titaiwang): Delete this when we have type promotion
+    other = op.CastLike(other, self)
     alpha = op.CastLike(alpha, other)
     other = op.Mul(other, alpha)
     return op.Add(self, other)
@@ -765,10 +767,18 @@ def aten_bitwise_left_shift(self: TInt, other: TInt) -> TInt:
 @torch_op("aten::bitwise_not")
 def aten_bitwise_not(self: TInt) -> TInt:
     # bitwise_not(Tensor self) -> Tensor
-
+    # TODO(titaiwang): Support BOOL input
     return op.BitwiseNot(self)
 
 
+@torch_op("aten::bitwise_not", overload=True)
+def aten_bitwise_not_bool(self: BOOL) -> BOOL:
+    # bitwise_not(Tensor self) -> Tensor
+    # FIXME(titaiwang): This is a hack to get around the fact that we don't have op.BitwiseNot supporting bool now.
+    # We should remove this once we have a proper implementation.
+    return op.Not(self)
+
+
 @torch_op("aten::bitwise_or")
 def aten_bitwise_or(self: TInt, other: TInt) -> TInt:
     # bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor
@@ -3146,9 +3156,8 @@ def aten_margin_ranking_loss(
 @torch_op("aten::masked_fill")
 def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor:
     # masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor
-    mask_cast = op.Cast(mask, to=BOOL.dtype)
     value_cast = op.CastLike(value, self)
-    return op.Where(mask_cast, value_cast, self)
+    return op.Where(mask, value_cast, self)
 
 
 def aten_masked_scatter(self: TensorType, mask: TensorType, source: TensorType) -> TensorType:
@@ -3648,7 +3657,8 @@ def aten_msort(self: TensorType) -> TensorType:
 @torch_op("aten::mul")
 def aten_mul(self: TReal, other: TReal) -> TReal:
     # mul.Tensor(Tensor self, Tensor other) -> Tensor
-
+    # TODO(titaiwang): Delete this when we have type promotion
+    other = op.CastLike(other, self)
     return op.Mul(self, other)
 
 
@@ -4697,6 +4707,7 @@ def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
 def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     # rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
     alpha = op.CastLike(alpha, self)
+
     return op.Sub(other, op.Mul(self, alpha))
 
 
@@ -4848,7 +4859,8 @@ def aten_slice(
     else:
         step = op.Constant(value_ints=[1])
 
-    return op.Slice(self, start, end, dim, step)
+    # TODO(titaiwang): Delete this Cast when we have type promotion
+    return op.Cast(op.Slice(self, start, end, dim, step), to=FLOAT.dtype)
 
 
 def aten_slice_backward(
@@ -5017,9 +5029,10 @@ def aten_stft(
 @torch_op("aten::sub")
 def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     # sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+    # TODO(titaiwang): Delete this when we have type promotion
+    other = op.CastLike(other, self)
     alpha = op.CastLike(alpha, other)
     other = op.Mul(other, alpha)
-
     return op.Sub(self, other)