From 541dcf5af1d792ba5130bbebdc7581a720b2fc20 Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Wed, 15 Feb 2023 20:38:26 +0000 Subject: [PATCH 1/2] bloom needed minor changes on Ops --- onnxscript/function_libs/torch_aten/ops/core.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 67d0106684..3fe4c8538f 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -769,6 +769,12 @@ def aten_bitwise_not(self: TInt) -> TInt: return op.BitwiseNot(self) +@torch_op("aten::bitwise_not", overload=True) +def aten_bitwise_not_bool(self: BOOL) -> BOOL: + # bitwise_not(Tensor self) -> Tensor + return op.Not(self) + + @torch_op("aten::bitwise_or") def aten_bitwise_or(self: TInt, other: TInt) -> TInt: # bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor @@ -3146,9 +3152,8 @@ def aten_margin_ranking_loss( @torch_op("aten::masked_fill") def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor: # masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor - mask_cast = op.Cast(mask, to=BOOL.dtype) value_cast = op.CastLike(value, self) - return op.Where(mask_cast, value_cast, self) + return op.Where(mask, value_cast, self) def aten_masked_scatter(self: TensorType, mask: TensorType, source: TensorType) -> TensorType: From c1f9206c7dd50ece17ef5a326b27ad4d79b03895 Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Wed, 15 Feb 2023 20:50:30 +0000 Subject: [PATCH 2/2] add comments --- onnxscript/function_libs/torch_aten/ops/core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 3fe4c8538f..b49e5fe40b 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -3152,6 +3152,8 @@ def aten_margin_ranking_loss( @torch_op("aten::masked_fill") def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor: # masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor + # NOTE: Do not attempt to cast `mask` to BOOL because mask should not take any other types. + # `mask` coming in as other types is often an error and should fail the model. value_cast = op.CastLike(value, self) return op.Where(mask, value_cast, self)