diff --git a/onnxscript/rewriter/_matcher.py b/onnxscript/rewriter/_matcher.py
index e347b98375..f54b77033f 100644
--- a/onnxscript/rewriter/_matcher.py
+++ b/onnxscript/rewriter/_matcher.py
@@ -87,7 +87,7 @@ def _match_constant(self, pattern_constant: _pattern_ir.Constant, value: ir.Valu
             )
 
         try:
-            constant_value_numpy = constant_value.numpy()
+            numpy_value = constant_value.numpy()
         except FileNotFoundError:
             return self.fail(f"Constant value of {value.name} not available.")
 
@@ -95,11 +95,13 @@ def _match_constant(self, pattern_constant: _pattern_ir.Constant, value: ir.Valu
 
         if isinstance(pattern_constant_value, list):
             expected_shape = (len(pattern_constant_value),)
-            if constant_value_numpy.shape != expected_shape:
-                return self.fail(f"Value has mismatching shape, expecting {expected_shape}.")
+            if numpy_value.shape != expected_shape:
+                return self.fail(
+                    f"Value {value.name} has shape {numpy_value.shape}, expecting {expected_shape}."
+                )
             if not all(
                 math.isclose(
-                    constant_value_numpy.item(i),
+                    numpy_value.item(i),
                     pattern_constant_value[i],
                     rel_tol=pattern_constant._rel_tol,
                     abs_tol=pattern_constant._abs_tol,
@@ -107,24 +109,24 @@ def _match_constant(self, pattern_constant: _pattern_ir.Constant, value: ir.Valu
                 for i in range(len(pattern_constant_value))
             ):
                 return self.fail(
-                    f"Value mismatch: expected {pattern_constant_value}, got {constant_value_numpy}."
+                    f"Value mismatch: expected {pattern_constant_value}, got {numpy_value}."
                 )
             return True
 
         # TODO (rama): allow users to specify shape requirement, if desired.
-        if constant_value_numpy.size != 1:
+        if numpy_value.ndim != 0:
             return self.fail(
                 f"Value {value.name} is not a scalar, expecting {pattern_constant_value}.",
             )
 
         if not math.isclose(
-            constant_value_numpy.item(),
+            numpy_value.item(),
             pattern_constant_value,
             rel_tol=pattern_constant._rel_tol,
             abs_tol=pattern_constant._abs_tol,
         ):
             return self.fail(
-                f"Constant value mismatch: expected {pattern_constant_value}, got {constant_value_numpy.item()}.",
+                f"Constant value mismatch: expected {pattern_constant_value}, got {numpy_value.item()}.",
             )
 
         return True
diff --git a/onnxscript/rewriter/models/_rotary_embedding_models.py b/onnxscript/rewriter/models/_rotary_embedding_models.py
index ecdb7d138b..3709cd04f7 100644
--- a/onnxscript/rewriter/models/_rotary_embedding_models.py
+++ b/onnxscript/rewriter/models/_rotary_embedding_models.py
@@ -26,8 +26,8 @@ def _test_case_1_script(x: FLOAT[1, 4, 8, 8], position_ids: INT64[1, 8]) -> FLOA
     emb = op.Concat(freqs, freqs, axis=-1)
     cos = op.Cos(emb)
     sin = op.Sin(emb)
-    cos_4d = op.Unsqueeze(cos, 1)
-    sin_4d = op.Unsqueeze(sin, 1)
+    cos_4d = op.Unsqueeze(cos, [1])
+    sin_4d = op.Unsqueeze(sin, [1])
 
     x1 = op.Slice(x, [0], [4], [3], [1])
     x2 = op.Slice(x, [4], [8], [3], [1])
@@ -73,8 +73,8 @@ def _test_case_2_script(x: FLOAT[1, 4, 8, 8], position_ids: INT64[8]) -> FLOAT[1
     emb = op.Concat(freqs, freqs, axis=-1)
     cos = op.Cos(emb)
    sin = op.Sin(emb)
-    cos_4d = op.Unsqueeze(cos, 1)
-    sin_4d = op.Unsqueeze(sin, 1)
+    cos_4d = op.Unsqueeze(cos, [1])
+    sin_4d = op.Unsqueeze(sin, [1])
 
     x1 = op.Slice(x, [0], [4], [3], [1])
     x2 = op.Slice(x, [4], [8], [3], [1])
@@ -127,8 +127,8 @@ def _partial_rotary_script(position_ids, query):
     # Split the query for partial embedding
     to_embed = op.Slice(query, [0], [32], [3], [1])
     unembedded = op.Slice(query, [32], [9223372036854775807], [3], [1])
-    cos_4d = op.Unsqueeze(cos_3d, 1)  # [B, 1, S, rd]
-    sin_4d = op.Unsqueeze(sin_3d, 1)  # [B, 1, S, rd]
+    cos_4d = op.Unsqueeze(cos_3d, [1])  # [B, 1, S, rd]
+    sin_4d = op.Unsqueeze(sin_3d, [1])  # [B, 1, S, rd]
     # Compute rotation of X as X * cos + rotate_half(X) * sin, where rotate_half(X)
     # essentially represents X rotated by 90 degrees
     to_embed_times_cos = op.Mul(to_embed, cos_4d)
diff --git a/onnxscript/rewriter/models/_smollm_1.py b/onnxscript/rewriter/models/_smollm_1.py
index d592eb2572..e3efecfe17 100644
--- a/onnxscript/rewriter/models/_smollm_1.py
+++ b/onnxscript/rewriter/models/_smollm_1.py
@@ -59,8 +59,8 @@ def main_graph(
         minus_inf_10x10 = opset18.ConstantOfShape([10, 10], [-3.4028234663852886e38])
         mask_10x10 = opset18.Trilu(minus_inf_10x10, 1)
         slice_5 = opset18.Reshape(mask_10x10, [1, 1, 10, 10])
-        unsqueeze_2 = opset18.Unsqueeze(input1, 1)
-        unsqueeze_3 = opset18.Unsqueeze(unsqueeze_2, 2)
+        unsqueeze_2 = opset18.Unsqueeze(input1, [1])
+        unsqueeze_3 = opset18.Unsqueeze(unsqueeze_2, [2])
         add = slice_5 + unsqueeze_3
         eq = add == 0.0
         slice_10 = slice_5
@@ -69,7 +69,7 @@ def main_graph(
         slice_scatter = opset18.Transpose(val_179, perm=[2, 1, 0, 3])
         val_191 = opset18.Transpose(slice_scatter, perm=[1, 0, 2, 3])
         slice_scatter_1 = opset18.Transpose(val_191, perm=[1, 0, 2, 3])
-        unsqueeze_6 = opset18.Unsqueeze(input2, 1)
+        unsqueeze_6 = opset18.Unsqueeze(input2, [1])
         to_copy_1 = opset18.Cast(unsqueeze_6, to=1)
         view_1 = opset18.Constant(
             value=ir.tensor(
@@ -138,8 +138,8 @@ def main_graph(
         transpose_2 = opset18.Transpose(view_11, perm=[0, 2, 1, 3])
         view_12 = opset18.Reshape(view_9, [1, 10, 32, 64], allowzero=0)
         transpose_3 = opset18.Transpose(view_12, perm=[0, 2, 1, 3])
-        unsqueeze_7 = opset18.Unsqueeze(cos, 1)
-        unsqueeze_8 = opset18.Unsqueeze(sin, 1)
+        unsqueeze_7 = opset18.Unsqueeze(cos, [1])
+        unsqueeze_8 = opset18.Unsqueeze(sin, [1])
         mul_5 = transpose_1 * unsqueeze_7
         val_267 = opset18.Constant(value_ints=[1])
         slice_19 = opset18.Slice(transpose_1, [0], [32], [3], val_267)
diff --git a/onnxscript/rewriter/models/_smollm_2.py b/onnxscript/rewriter/models/_smollm_2.py
index 62d857a2d6..47ad451895 100644
--- a/onnxscript/rewriter/models/_smollm_2.py
+++ b/onnxscript/rewriter/models/_smollm_2.py
@@ -51,7 +51,7 @@ def main_graph(
         gt = arange_1 > view
         convert_element_type_default = opset18.Cast(gt, to=1)
         mul = triu * convert_element_type_default
-        dim__2 = opset18.Constant(value_int=0)
+        dim__2 = opset18.Constant(value_ints=[0])
         dim_0__2 = opset18.Cast(dim__2, to=7)
         unsqueeze = opset18.Unsqueeze(model_rotary_emb_inv_freq, dim_0__2)
         val_15 = opset18.Cast(0, to=7)
@@ -65,7 +65,7 @@ def main_graph(
         val_25 = opset18.Reshape(val_23, val_24, allowzero=0)
         val_26 = opset18.Constant(value_ints=[1])
         slice_1 = opset18.Slice(unsqueeze, val_17, val_21, val_25, val_26)
-        dim__3 = opset18.Constant(value_int=2)
+        dim__3 = opset18.Constant(value_ints=[2])
         dim_0__3 = opset18.Cast(dim__3, to=7)
         unsqueeze_1 = opset18.Unsqueeze(slice_1, dim_0__3)
         _to_copy = opset18.Cast(unsqueeze_1, to=1)
@@ -83,7 +83,7 @@ def main_graph(
         val_36 = opset18.Reshape(val_34, val_35, allowzero=0)
         val_37 = opset18.Constant(value_ints=[1])
         slice_2 = opset18.Slice(position_ids, val_30, val_33, val_36, val_37)
-        dim__5 = opset18.Constant(value_int=1)
+        dim__5 = opset18.Constant(value_ints=[1])
         dim_0__5 = opset18.Cast(dim__5, to=7)
         unsqueeze_2 = opset18.Unsqueeze(slice_2, dim_0__5)
         val_38 = opset18.Cast(0, to=7)
@@ -160,10 +160,10 @@ def main_graph(
         val_71 = opset18.Cast([1, 30, 32, 64], to=7)
         view_12 = opset18.Reshape(view_9, val_71, allowzero=0)
         transpose_3 = opset18.Transpose(view_12, perm=[0, 2, 1, 3])
-        dim__8 = opset18.Constant(value_int=1)
+        dim__8 = opset18.Constant(value_ints=[1])
         dim_0__8 = opset18.Cast(dim__8, to=7)
         unsqueeze_3 = opset18.Unsqueeze(_to_copy_4, dim_0__8)
-        dim__9 = opset18.Constant(value_int=1)
+        dim__9 = opset18.Constant(value_ints=[1])
         dim_0__9 = opset18.Cast(dim__9, to=7)
         unsqueeze_4 = opset18.Unsqueeze(_to_copy_5, dim_0__9)
         mul_5 = transpose_1 * unsqueeze_3
@@ -222,10 +222,10 @@ def main_graph(
         add_2 = mul_7 + mul_8
         cat_3 = opset18.Concat(past_key_values_0_0, add_2, axis=-2)
         cat_4 = opset18.Concat(past_key_values_0_1, transpose_3, axis=-2)
-        dim__10 = opset18.Constant(value_int=0)
+        dim__10 = opset18.Constant(value_ints=[0])
         dim_0__10 = opset18.Cast(dim__10, to=7)
         unsqueeze_5 = opset18.Unsqueeze(mul, dim_0__10)
-        dim__11 = opset18.Constant(value_int=1)
+        dim__11 = opset18.Constant(value_ints=[1])
         dim_0__11 = opset18.Cast(dim__11, to=7)
         unsqueeze_6 = opset18.Unsqueeze(unsqueeze_5, dim_0__11)
         val_114 = opset18.Cast(0, to=7)
diff --git a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py
index cba06d2fb7..8e6ec1d9da 100644
--- a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py
+++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py
@@ -148,8 +148,8 @@ def pattern(
         sin = op.Sin(emb)
         if self._cast:
             sin = op.Cast(sin, to=dtype)
-        cos_4d = op.Unsqueeze(cos, 1)  # convert
-        sin_4d = op.Unsqueeze(sin, 1)
+        cos_4d = op.Unsqueeze(cos, [1])  # convert
+        sin_4d = op.Unsqueeze(sin, [1])
         return op.RotaryEmbedding(
             x,
             cos_4d,
diff --git a/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py
index 48842aa429..4245916c64 100644
--- a/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py
+++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py
@@ -45,7 +45,7 @@ def test_cos_sin_fusion(self, name, test_data_constructor):
         original_outputs = ort_run("original", model, inputs)
         count = fuse_rotary_embedding(model)
         self.assertGreater(count, 0)
-        count = fuse_cos_sin_cache(model)
+        count = fuse_cos_sin_cache(model, debug=True)
         self.assertGreater(count, 0)
         new_outputs = ort_run("optimized", model, inputs)
         assert_allclose(new_outputs, original_outputs)
diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py
index 907ffe27bc..bf883c58bc 100644
--- a/onnxscript/rewriter/ort_fusions/gqa.py
+++ b/onnxscript/rewriter/ort_fusions/gqa.py
@@ -223,7 +223,7 @@ def pattern(
         key_seq_BHkvTDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2)
         # Concat with past_key is optional:
         key_seq_BHkvTDh = pattern.OrValue([key_seq_BHkvTDh, key_BHkvSDh_rope])
-        key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, 2)
+        key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, [2])
         key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, pattern.ANY_VALUE)
         key_seq_BHTDh = op.Reshape(
             key_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["key_seq_BHTDh"]
@@ -234,7 +234,7 @@ def pattern(
         value_seq_BHkvTDh = op.Concat(past_value, value_BHkvSDh, axis=-2)
         # Concat with past_value is optional:
         value_seq_BHkvTDh = pattern.OrValue([value_seq_BHkvTDh, value_BHkvSDh])
-        value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, 2)
+        value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, [2])
         value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, pattern.ANY_VALUE)
         value_seq_BHTDh = op.Reshape(
             value_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["value_seq_BHTDh"]
diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py
index 038749c017..1a79b9c29f 100644
--- a/onnxscript/rewriter/ort_fusions/gqa_test.py
+++ b/onnxscript/rewriter/ort_fusions/gqa_test.py
@@ -195,11 +195,11 @@ def gqa(query, key, value, past_key, past_value, cos, sin):
             value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2)
 
             # Now, expand from shared heads to all heads
-            key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2)
+            key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, [2])
             key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh)
             key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh)
 
-            value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, 2)
+            value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, [2])
             value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh)
             value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh)
 
@@ -527,11 +527,11 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
             value_seq_BHkvSkvDh = value_BHkvSDh
 
             # Now, expand from shared heads to all heads
-            key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2)
+            key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, [2])
             key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh)
             key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh)
 
-            value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, 2)
+            value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, [2])
             value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh)
             value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh)
 
diff --git a/onnxscript/rewriter/rules/common/_no_op_test.py b/onnxscript/rewriter/rules/common/_no_op_test.py
index 7815473e34..2c2f9e6e2b 100644
--- a/onnxscript/rewriter/rules/common/_no_op_test.py
+++ b/onnxscript/rewriter/rules/common/_no_op_test.py
@@ -15,6 +15,11 @@ def _check(self, model_text: str) -> None:
         self.assertEqual(count, 1)
         self.assertEqual(model.graph[-1].op_type, "Identity")
 
+    def _check_no_optimization(self, model_text: str) -> None:
+        model = ir.from_onnx_text(model_text)
+        count = _no_op.rules.apply_to_model(model)
+        self.assertEqual(count, 0)
+
     @parameterized.parameterized.expand(
         [
             ("float one input", "float[M]", "value_float=1.0", "one, input"),
@@ -195,6 +200,17 @@ def test_dropout_zero_or_inference_no_op_with_initializer(self, _, attribute: st
         )
 
     # TODO: Test the negative cases
+    def test_broadcast_is_not_eliminated(self):
+        model_text = """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[M] input) => (float[1, 1, M] output)
+            <float[1, 1, 1] zero = {0.0}>
+            {
+                output = Add(zero, input)
+            }
+        """
+        self._check_no_optimization(model_text)
+
 
 if __name__ == "__main__":
     unittest.main()