25 changes: 19 additions & 6 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -57,6 +57,8 @@ def aten_acosh(self: TFloat) -> TFloat:
@torch_op("aten::add")
def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
    # add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
    # TODO(titaiwang): Delete this when we have type promotion
    other = op.CastLike(other, self)
    alpha = op.CastLike(alpha, other)
    other = op.Mul(other, alpha)
    return op.Add(self, other)
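For reference, the "type promotion" mentioned in the TODO is PyTorch's rule of computing a common result dtype for mixed-dtype operands; the CastLike calls above only approximate it by casting `other` and `alpha` to `self`'s dtype. A minimal illustration of the behavior being approximated (assumes PyTorch is available; not part of this diff):

import torch

# PyTorch promotes mixed dtypes to a common result dtype before adding,
# rather than casting one operand to the other's dtype.
a = torch.tensor([1, 2, 3], dtype=torch.int32)
b = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32)
print(torch.result_type(a, b))  # torch.float32
print((a + b).dtype)            # torch.float32, not int32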
@@ -765,10 +767,18 @@ def aten_bitwise_left_shift(self: TInt, other: TInt) -> TInt:
@torch_op("aten::bitwise_not")
def aten_bitwise_not(self: TInt) -> TInt:
    # bitwise_not(Tensor self) -> Tensor

    # TODO(titaiwang): Support BOOL input
    return op.BitwiseNot(self)


@torch_op("aten::bitwise_not", overload=True)
def aten_bitwise_not_bool(self: BOOL) -> BOOL:
    # bitwise_not(Tensor self) -> Tensor
    # FIXME(titaiwang): This is a hack to get around the fact that we don't have op.BitwiseNot supporting bool now.
Collaborator: We will have to go with this I think. Not a hack, as we need a dispatcher to handle different dtypes.

Contributor Author: Do we need an `if` for aten_bitwise_not?

Collaborator (@justinchuby, Feb 15, 2023): Oh, we cannot do `if` because the dtypes are different, and we don't know the dtype within a function.

Collaborator: So even trace-only won't work for us here.

Contributor Author: So how should we know the dtype beforehand?

Contributor Author: Maybe only certain kinds of ops can be inputs of bitwise operations, so we could do the processing in those ops?

Contributor Author: Guess we can forget it...

@_onnx_symbolic("aten::bitwise_not")
@_beartype.beartype
def bitwise_not(g: jit_utils.GraphContext, input):
    if not symbolic_helper._is_bool(input):
        raise errors.SymbolicValueError(
            "ONNX export does NOT support exporting bitwise Not "
            "for non-boolean input values",
            input,
        )
    return g.op("Not", input)

Collaborator (quoting "So how should we know the dtype beforehand?"): The exporter knows about dtypes, so it can choose which function to use.
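A minimal sketch of what that exporter-side dispatch could look like (hypothetical; `select_bitwise_not_overload` is not part of this PR, and the exporter's real registration mechanism may differ):

import torch

# Hypothetical dispatcher sketch: the exporter inspects the input's dtype and
# picks the matching torch_op overload, since the dtype is not visible from
# inside an ONNX Script function. `aten_bitwise_not` and
# `aten_bitwise_not_bool` are the functions defined in this file.
def select_bitwise_not_overload(input_dtype: torch.dtype):
    if input_dtype == torch.bool:
        return aten_bitwise_not_bool  # lowers to ONNX Not
    return aten_bitwise_not  # lowers to ONNX BitwiseNot for integer inputs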

    # We should remove this once we have a proper implementation.
    return op.Not(self)


@torch_op("aten::bitwise_or")
def aten_bitwise_or(self: TInt, other: TInt) -> TInt:
    # bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor
@@ -3146,9 +3156,8 @@ def aten_margin_ranking_loss(
@torch_op("aten::masked_fill")
def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor:
Collaborator: Could you create another PR for the non-hacks we identified, including the changes to masked_fill and bitwise_not (both versions), so that we can merge into main? The rest can stay here for reference, but we are probably not going to merge it.

    # masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor
    mask_cast = op.Cast(mask, to=BOOL.dtype)
    value_cast = op.CastLike(value, self)
    return op.Where(mask_cast, value_cast, self)
    return op.Where(mask, value_cast, self)


def aten_masked_scatter(self: TensorType, mask: TensorType, source: TensorType) -> TensorType:
@@ -3648,7 +3657,8 @@ def aten_msort(self: TensorType) -> TensorType:
@torch_op("aten::mul")
def aten_mul(self: TReal, other: TReal) -> TReal:
    # mul.Tensor(Tensor self, Tensor other) -> Tensor

    # TODO(titaiwang): Delete this when we have type promotion
    other = op.CastLike(other, self)
    return op.Mul(self, other)


@@ -4697,6 +4707,7 @@ def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
    # rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
    alpha = op.CastLike(alpha, self)

    return op.Sub(other, op.Mul(self, alpha))


@@ -4848,7 +4859,8 @@ def aten_slice(
    else:
        step = op.Constant(value_ints=[1])

    return op.Slice(self, start, end, dim, step)
    # TODO(titaiwang): Delete this Cast when we have type promotion
    return op.Cast(op.Slice(self, start, end, dim, step), to=FLOAT.dtype)


def aten_slice_backward(
@@ -5017,9 +5029,10 @@ def aten_stft(
@torch_op("aten::sub")
def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
    # sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
    # TODO(titaiwang): Delete this when we have type promotion
    other = op.CastLike(other, self)
    alpha = op.CastLike(alpha, other)
    other = op.Mul(other, alpha)

    return op.Sub(self, other)

