From f60393a60c3f4ebedd52fe5161989ef4ee609f71 Mon Sep 17 00:00:00 2001
From: AllenTiTaiWang
Date: Wed, 15 Feb 2023 06:11:16 +0000
Subject: [PATCH 1/4] minor adjustments on Ops to unblock bloom

---
 .../function_libs/torch_aten/ops/core.py | 42 ++++++++++---------
 1 file changed, 22 insertions(+), 20 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 67d0106684..0bbee510ac 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -59,6 +59,8 @@ def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     # add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
     alpha = op.CastLike(alpha, other)
     other = op.Mul(other, alpha)
+    # TODO(titaiwang): Delete this when we have type promotion
+    other = op.CastLike(other, self)
     return op.Add(self, other)
 
 
@@ -205,7 +207,7 @@ def aten_amax(self: TReal, dim: Optional[int] = None, keepdim: bool = False) ->
 
     # TODO(justinchuby): Make dim INT64 after we upgrade to onnxruntime 1.14
     if dim is None:
-        return opset17.ReduceMax(self, keepdims=keepdim)
+        return op.ReduceMax(self, dim, keepdims=keepdim)
     if not isinstance(dim, Sequence):
         dims = [dim]
     else:
@@ -216,11 +218,11 @@ def aten_amax(self: TReal, dim: Optional[int] = None, keepdim: bool = False) ->
 @torch_op("aten::amax", overload=True)
 def _aten_amax_onnx(self: TReal, axes: Sequence[int], keepdims: bool) -> TReal:
     # TODO(justinchuby): Use opset18 after we upgrade to onnxruntime 1.14
-    if opset17.Size(opset17.Shape(self)) == 0:
+    if op.Size(op.Shape(self)) == 0:
         # Scalar
         result = self
     else:
-        result = opset17.ReduceMax(self, axes=axes, keepdims=keepdims)
+        result = op.ReduceMax(self, axes, keepdims=keepdims)
     return result
 
 
@@ -230,7 +232,7 @@ def aten_amin(self: TReal, dim: Optional[int] = None, keepdim: bool = False) ->
 
     # TODO(justinchuby): Make dim INT64 after we upgrade to onnxruntime 1.14
     if dim is None:
-        return opset17.ReduceMin(self, keepdims=keepdim)
+        return op.ReduceMin(self, dim, keepdims=keepdim)
     if not isinstance(dim, Sequence):
         dims = [dim]
     else:
@@ -241,11 +243,11 @@ def aten_amin(self: TReal, dim: Optional[int] = None, keepdim: bool = False) ->
 @torch_op("aten::amin", overload=True)
 def _aten_amin_onnx(self: TReal, axes: Sequence[int], keepdims: bool) -> TReal:
     # TODO(justinchuby): Use opset18 after we upgrade to onnxruntime 1.14
-    if opset17.Size(opset17.Shape(self)) == 0:
+    if op.Size(op.Shape(self)) == 0:
         # Scalar
         result = self
     else:
-        result = opset17.ReduceMin(self, axes=axes, keepdims=keepdims)
+        result = op.ReduceMin(self, axes, keepdims=keepdims)
     return result
 
 
@@ -765,7 +767,7 @@ def aten_bitwise_left_shift(self: TInt, other: TInt) -> TInt:
 @torch_op("aten::bitwise_not")
 def aten_bitwise_not(self: TInt) -> TInt:
     # bitwise_not(Tensor self) -> Tensor
-
+    self = op.Cast(self, to=INT64.dtype)
     return op.BitwiseNot(self)
 
 
@@ -3874,18 +3876,18 @@ def _aten_native_layer_norm_onnx(
 ) -> Tuple[TReal, TReal, TReal]:
 
     # FIXME(justinchuby): Use opset18 when it is supported by onnxruntime
-    mean = opset17.ReduceMean(input, axes=axes)
-    numerator = opset17.Sub(input, mean)
-    power_num = opset17.Pow(numerator, 2.0)
-    variance = opset17.ReduceMean(power_num, axes=axes)
-    variance_eps = opset17.Add(variance, eps)
-    denominator = opset17.Sqrt(variance_eps)
-    result = opset17.Div(numerator, denominator)
-    weight = opset17.CastLike(weight, result)
-    result = opset17.Mul(result, weight)
-    bias = opset17.CastLike(bias, result)
-    result = opset17.Add(result, bias)
-    rdenominator = opset17.Reciprocal(denominator)
+    mean = op.ReduceMean(input, axes)
+    numerator = op.Sub(input, mean)
+    power_num = op.Pow(numerator, 2.0)
+    variance = op.ReduceMean(power_num, axes)
+    variance_eps = op.Add(variance, eps)
+    denominator = op.Sqrt(variance_eps)
+    result = op.Div(numerator, denominator)
+    weight = op.CastLike(weight, result)
+    result = op.Mul(result, weight)
+    bias = op.CastLike(bias, result)
+    result = op.Add(result, bias)
+    rdenominator = op.Reciprocal(denominator)
     return result, mean, rdenominator
 
 
@@ -4848,7 +4850,7 @@ def aten_slice(
     else:
         step = op.Constant(value_ints=[1])
 
-    return op.Slice(self, start, end, dim, step)
+    return op.Cast(op.Slice(self, start, end, dim, step), to=FLOAT.dtype)
 
 
 def aten_slice_backward(

From f176db92e80422598f2014fa1ed233a437298b7a Mon Sep 17 00:00:00 2001
From: AllenTiTaiWang
Date: Wed, 15 Feb 2023 17:36:21 +0000
Subject: [PATCH 2/4] bloom modified

---
 .../function_libs/torch_aten/ops/core.py | 25 +++++++++++++------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 0bbee510ac..a7e2257a2c 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -26,7 +26,6 @@
     TRealUnlessInt16OrInt8,
     TTensor,
 )
-from onnxscript.onnx_opset import opset17
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
 
@@ -57,10 +56,10 @@ def aten_acosh(self: TFloat) -> TFloat:
 @torch_op("aten::add")
 def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     # add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
-    alpha = op.CastLike(alpha, other)
-    other = op.Mul(other, alpha)
     # TODO(titaiwang): Delete this when we have type promotion
     other = op.CastLike(other, self)
+    alpha = op.CastLike(alpha, other)
+    other = op.Mul(other, alpha)
     return op.Add(self, other)
 
 
@@ -767,10 +766,18 @@ def aten_bitwise_left_shift(self: TInt, other: TInt) -> TInt:
 @torch_op("aten::bitwise_not")
 def aten_bitwise_not(self: TInt) -> TInt:
     # bitwise_not(Tensor self) -> Tensor
-    self = op.Cast(self, to=INT64.dtype)
+    # TODO(titaiwang): Support BOOL input
     return op.BitwiseNot(self)
 
 
+@torch_op("aten::bitwise_not", overload=True)
+def aten_bitwise_not_bool(self: BOOL) -> BOOL:
+    # bitwise_not(Tensor self) -> Tensor
+    # FIXME(titaiwang): This is a hack to get around the fact that we don't have op.BitwiseNot supporting bool now.
+    # We should remove this once we have a proper implementation.
+    return op.Not(self)
+
+
 @torch_op("aten::bitwise_or")
 def aten_bitwise_or(self: TInt, other: TInt) -> TInt:
     # bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor
@@ -3148,9 +3155,8 @@ def aten_margin_ranking_loss(
 @torch_op("aten::masked_fill")
 def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor:
     # masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor
-    mask_cast = op.Cast(mask, to=BOOL.dtype)
     value_cast = op.CastLike(value, self)
-    return op.Where(mask_cast, value_cast, self)
+    return op.Where(mask, value_cast, self)
 
 
 def aten_masked_scatter(self: TensorType, mask: TensorType, source: TensorType) -> TensorType:
@@ -3650,7 +3656,8 @@ def aten_msort(self: TensorType) -> TensorType:
 @torch_op("aten::mul")
 def aten_mul(self: TReal, other: TReal) -> TReal:
     # mul.Tensor(Tensor self, Tensor other) -> Tensor
-
+    # TODO(titaiwang): Delete this when we have type promotion
+    other = op.CastLike(other, self)
     return op.Mul(self, other)
 
 
@@ -4699,6 +4706,7 @@ def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
 def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     # rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
     alpha = op.CastLike(alpha, self)
+
     return op.Sub(other, op.Mul(self, alpha))
 
 
@@ -5019,9 +5027,10 @@ def aten_stft(
 @torch_op("aten::sub")
 def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     # sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+    # TODO(titaiwang): Delete this when we have type promotion
+    other = op.CastLike(other, self)
     alpha = op.CastLike(alpha, other)
     other = op.Mul(other, alpha)
-
     return op.Sub(self, other)
 
 

From 9397ce331db068188924a98d8dd895577020e0a8 Mon Sep 17 00:00:00 2001
From: AllenTiTaiWang
Date: Wed, 15 Feb 2023 18:04:35 +0000
Subject: [PATCH 3/4] fall back to 17

---
 .../function_libs/torch_aten/ops/core.py | 37 ++++++++++---------
 1 file changed, 19 insertions(+), 18 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index a7e2257a2c..cb8d402e63 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -26,6 +26,7 @@
     TRealUnlessInt16OrInt8,
     TTensor,
 )
+from onnxscript.onnx_opset import opset17
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
 
@@ -206,7 +207,7 @@ def aten_amax(self: TReal, dim: Optional[int] = None, keepdim: bool = False) ->
 
     # TODO(justinchuby): Make dim INT64 after we upgrade to onnxruntime 1.14
     if dim is None:
-        return op.ReduceMax(self, dim, keepdims=keepdim)
+        return opset17.ReduceMax(self, keepdims=keepdim)
     if not isinstance(dim, Sequence):
         dims = [dim]
     else:
@@ -217,11 +218,11 @@ def aten_amax(self: TReal, dim: Optional[int] = None, keepdim: bool = False) ->
 @torch_op("aten::amax", overload=True)
 def _aten_amax_onnx(self: TReal, axes: Sequence[int], keepdims: bool) -> TReal:
     # TODO(justinchuby): Use opset18 after we upgrade to onnxruntime 1.14
-    if op.Size(op.Shape(self)) == 0:
+    if opset17.Size(opset17.Shape(self)) == 0:
         # Scalar
         result = self
     else:
-        result = op.ReduceMax(self, axes, keepdims=keepdims)
+        result = opset17.ReduceMax(self, axes=axes, keepdims=keepdims)
     return result
 
 
@@ -231,7 +232,7 @@ def aten_amin(self: TReal, dim: Optional[int] = None, keepdim: bool = False) ->
 
     # TODO(justinchuby): Make dim INT64 after we upgrade to onnxruntime 1.14
     if dim is None:
-        return op.ReduceMin(self, dim, keepdims=keepdim)
+        return opset17.ReduceMin(self, keepdims=keepdim)
     if not isinstance(dim, Sequence):
         dims = [dim]
     else:
@@ -242,11 +243,11 @@ def aten_amin(self: TReal, dim: Optional[int] = None, keepdim: bool = False) ->
 @torch_op("aten::amin", overload=True)
 def _aten_amin_onnx(self: TReal, axes: Sequence[int], keepdims: bool) -> TReal:
     # TODO(justinchuby): Use opset18 after we upgrade to onnxruntime 1.14
-    if op.Size(op.Shape(self)) == 0:
+    if opset17.Size(opset17.Shape(self)) == 0:
         # Scalar
         result = self
     else:
-        result = op.ReduceMin(self, axes, keepdims=keepdims)
+        result = opset17.ReduceMin(self, axes=axes, keepdims=keepdims)
     return result
 
 
@@ -3883,18 +3884,18 @@ def _aten_native_layer_norm_onnx(
 ) -> Tuple[TReal, TReal, TReal]:
 
     # FIXME(justinchuby): Use opset18 when it is supported by onnxruntime
-    mean = op.ReduceMean(input, axes)
-    numerator = op.Sub(input, mean)
-    power_num = op.Pow(numerator, 2.0)
-    variance = op.ReduceMean(power_num, axes)
-    variance_eps = op.Add(variance, eps)
-    denominator = op.Sqrt(variance_eps)
-    result = op.Div(numerator, denominator)
-    weight = op.CastLike(weight, result)
-    result = op.Mul(result, weight)
-    bias = op.CastLike(bias, result)
-    result = op.Add(result, bias)
-    rdenominator = op.Reciprocal(denominator)
+    mean = opset17.ReduceMean(input, axes=axes)
+    numerator = opset17.Sub(input, mean)
+    power_num = opset17.Pow(numerator, 2.0)
+    variance = opset17.ReduceMean(power_num, axes=axes)
+    variance_eps = opset17.Add(variance, eps)
+    denominator = opset17.Sqrt(variance_eps)
+    result = opset17.Div(numerator, denominator)
+    weight = opset17.CastLike(weight, result)
+    result = opset17.Mul(result, weight)
+    bias = opset17.CastLike(bias, result)
+    result = opset17.Add(result, bias)
+    rdenominator = opset17.Reciprocal(denominator)
     return result, mean, rdenominator
 
 

From 6e8f0c86be09fab16fcd0dd6cc3d901e1a844e33 Mon Sep 17 00:00:00 2001
From: AllenTiTaiWang
Date: Wed, 15 Feb 2023 18:11:40 +0000
Subject: [PATCH 4/4] comments

---
 onnxscript/function_libs/torch_aten/ops/core.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index cb8d402e63..5d310e6f0f 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4859,6 +4859,7 @@ def aten_slice(
     else:
         step = op.Constant(value_ints=[1])
 
+    # TODO(titaiwang): Delete this Cast when we have type promotion
     return op.Cast(op.Slice(self, start, end, dim, step), to=FLOAT.dtype)
 
 