diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 67d0106684..b49e5fe40b 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -769,6 +769,12 @@ def aten_bitwise_not(self: TInt) -> TInt:
     return op.BitwiseNot(self)
 
 
+@torch_op("aten::bitwise_not", overload=True)
+def aten_bitwise_not_bool(self: BOOL) -> BOOL:
+    # bitwise_not(Tensor self) -> Tensor
+    return op.Not(self)
+
+
 @torch_op("aten::bitwise_or")
 def aten_bitwise_or(self: TInt, other: TInt) -> TInt:
     # bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor
@@ -3146,9 +3152,10 @@ def aten_margin_ranking_loss(
 @torch_op("aten::masked_fill")
 def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor:
     # masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor
-    mask_cast = op.Cast(mask, to=BOOL.dtype)
+    # NOTE: Do not attempt to cast `mask` to BOOL, because `mask` should never have any
+    # other type. A mask coming in with another dtype is usually an error and should fail the model.
     value_cast = op.CastLike(value, self)
-    return op.Where(mask_cast, value_cast, self)
+    return op.Where(mask, value_cast, self)
 
 
 def aten_masked_scatter(self: TensorType, mask: TensorType, source: TensorType) -> TensorType:
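
Reviewer note: a minimal sketch of the eager-mode semantics the new `aten_bitwise_not_bool` overload mirrors, assuming PyTorch is available (this snippet is illustrative, not part of the patch). On boolean tensors, `torch.bitwise_not` is logical negation, which is what ONNX `Not` computes; ONNX `BitwiseNot` only accepts integer types, hence the separate BOOL overload.

    import torch

    # On bool tensors, bitwise_not degenerates to logical NOT (ONNX `Not`).
    b = torch.tensor([True, False, True])
    assert torch.equal(torch.bitwise_not(b), torch.logical_not(b))

    # On integer tensors it is a true bitwise complement (ONNX `BitwiseNot`): ~x == -x - 1.
    i = torch.tensor([0, 1, -2], dtype=torch.int32)
    assert torch.equal(torch.bitwise_not(i), torch.tensor([-1, -2, 1], dtype=torch.int32))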
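
A similar sketch for the `aten_masked_fill` change, assuming a recent PyTorch build that rejects non-boolean masks (again illustrative, not part of the patch): eager PyTorch already fails on a non-bool `mask`, so casting in the exporter would silently accept models that PyTorch itself refuses to run.

    import torch

    x = torch.zeros(3)
    mask = torch.tensor([True, False, True])
    print(x.masked_fill(mask, 1.0))  # tensor([1., 0., 1.])

    # Recent PyTorch versions raise for non-boolean masks instead of casting them.
    byte_mask = torch.tensor([1, 0, 1], dtype=torch.uint8)
    try:
        x.masked_fill(byte_mask, 1.0)
    except RuntimeError as err:
        print(err)  # masked_fill only supports boolean masks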