Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions onnxscript/rewriter/_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,44 +87,46 @@ def _match_constant(self, pattern_constant: _pattern_ir.Constant, value: ir.Valu
)

try:
constant_value_numpy = constant_value.numpy()
numpy_value = constant_value.numpy()
except FileNotFoundError:
return self.fail(f"Constant value of {value.name} not available.")

pattern_constant_value = pattern_constant._value

if isinstance(pattern_constant_value, list):
expected_shape = (len(pattern_constant_value),)
if constant_value_numpy.shape != expected_shape:
return self.fail(f"Value has mismatching shape, expecting {expected_shape}.")
if numpy_value.shape != expected_shape:
return self.fail(
f"Value {value.name} has shape {numpy_value.shape}, expecting {expected_shape}."
)
if not all(
math.isclose(
constant_value_numpy.item(i),
numpy_value.item(i),
pattern_constant_value[i],
rel_tol=pattern_constant._rel_tol,
abs_tol=pattern_constant._abs_tol,
)
for i in range(len(pattern_constant_value))
):
return self.fail(
f"Value mismatch: expected {pattern_constant_value}, got {constant_value_numpy}."
f"Value mismatch: expected {pattern_constant_value}, got {numpy_value}."
)
return True

# TODO (rama): allow users to specify shape requirement, if desired.
if constant_value_numpy.size != 1:
if numpy_value.ndim != 0:
return self.fail(
f"Value {value.name} is not a scalar, expecting {pattern_constant_value}.",
)

if not math.isclose(
constant_value_numpy.item(),
numpy_value.item(),
pattern_constant_value,
rel_tol=pattern_constant._rel_tol,
abs_tol=pattern_constant._abs_tol,
):
return self.fail(
f"Constant value mismatch: expected {pattern_constant_value}, got {constant_value_numpy.item()}.",
f"Constant value mismatch: expected {pattern_constant_value}, got {numpy_value.item()}.",
)

return True
Expand Down
12 changes: 6 additions & 6 deletions onnxscript/rewriter/models/_rotary_embedding_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def _test_case_1_script(x: FLOAT[1, 4, 8, 8], position_ids: INT64[1, 8]) -> FLOA
emb = op.Concat(freqs, freqs, axis=-1)
cos = op.Cos(emb)
sin = op.Sin(emb)
cos_4d = op.Unsqueeze(cos, 1)
sin_4d = op.Unsqueeze(sin, 1)
cos_4d = op.Unsqueeze(cos, [1])
sin_4d = op.Unsqueeze(sin, [1])

x1 = op.Slice(x, [0], [4], [3], [1])
x2 = op.Slice(x, [4], [8], [3], [1])
Expand Down Expand Up @@ -73,8 +73,8 @@ def _test_case_2_script(x: FLOAT[1, 4, 8, 8], position_ids: INT64[8]) -> FLOAT[1
emb = op.Concat(freqs, freqs, axis=-1)
cos = op.Cos(emb)
sin = op.Sin(emb)
cos_4d = op.Unsqueeze(cos, 1)
sin_4d = op.Unsqueeze(sin, 1)
cos_4d = op.Unsqueeze(cos, [1])
sin_4d = op.Unsqueeze(sin, [1])

x1 = op.Slice(x, [0], [4], [3], [1])
x2 = op.Slice(x, [4], [8], [3], [1])
Expand Down Expand Up @@ -127,8 +127,8 @@ def _partial_rotary_script(position_ids, query):
# Split the query for partial embedding
to_embed = op.Slice(query, [0], [32], [3], [1])
unembedded = op.Slice(query, [32], [9223372036854775807], [3], [1])
cos_4d = op.Unsqueeze(cos_3d, 1) # [B, 1, S, rd]
sin_4d = op.Unsqueeze(sin_3d, 1) # [B, 1, S, rd]
cos_4d = op.Unsqueeze(cos_3d, [1]) # [B, 1, S, rd]
sin_4d = op.Unsqueeze(sin_3d, [1]) # [B, 1, S, rd]
# Compute rotation of X as X * cos + rotate_half(X) * sin, where rotate_half(X)
# essentially represents X rotated by 90 degrees
to_embed_times_cos = op.Mul(to_embed, cos_4d)
Expand Down
10 changes: 5 additions & 5 deletions onnxscript/rewriter/models/_smollm_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def main_graph(
minus_inf_10x10 = opset18.ConstantOfShape([10, 10], [-3.4028234663852886e38])
mask_10x10 = opset18.Trilu(minus_inf_10x10, 1)
slice_5 = opset18.Reshape(mask_10x10, [1, 1, 10, 10])
unsqueeze_2 = opset18.Unsqueeze(input1, 1)
unsqueeze_3 = opset18.Unsqueeze(unsqueeze_2, 2)
unsqueeze_2 = opset18.Unsqueeze(input1, [1])
unsqueeze_3 = opset18.Unsqueeze(unsqueeze_2, [2])
add = slice_5 + unsqueeze_3
eq = add == 0.0
slice_10 = slice_5
Expand All @@ -69,7 +69,7 @@ def main_graph(
slice_scatter = opset18.Transpose(val_179, perm=[2, 1, 0, 3])
val_191 = opset18.Transpose(slice_scatter, perm=[1, 0, 2, 3])
slice_scatter_1 = opset18.Transpose(val_191, perm=[1, 0, 2, 3])
unsqueeze_6 = opset18.Unsqueeze(input2, 1)
unsqueeze_6 = opset18.Unsqueeze(input2, [1])
to_copy_1 = opset18.Cast(unsqueeze_6, to=1)
view_1 = opset18.Constant(
value=ir.tensor(
Expand Down Expand Up @@ -138,8 +138,8 @@ def main_graph(
transpose_2 = opset18.Transpose(view_11, perm=[0, 2, 1, 3])
view_12 = opset18.Reshape(view_9, [1, 10, 32, 64], allowzero=0)
transpose_3 = opset18.Transpose(view_12, perm=[0, 2, 1, 3])
unsqueeze_7 = opset18.Unsqueeze(cos, 1)
unsqueeze_8 = opset18.Unsqueeze(sin, 1)
unsqueeze_7 = opset18.Unsqueeze(cos, [1])
unsqueeze_8 = opset18.Unsqueeze(sin, [1])
mul_5 = transpose_1 * unsqueeze_7
val_267 = opset18.Constant(value_ints=[1])
slice_19 = opset18.Slice(transpose_1, [0], [32], [3], val_267)
Expand Down
14 changes: 7 additions & 7 deletions onnxscript/rewriter/models/_smollm_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def main_graph(
gt = arange_1 > view
convert_element_type_default = opset18.Cast(gt, to=1)
mul = triu * convert_element_type_default
dim__2 = opset18.Constant(value_int=0)
dim__2 = opset18.Constant(value_ints=[0])
dim_0__2 = opset18.Cast(dim__2, to=7)
unsqueeze = opset18.Unsqueeze(model_rotary_emb_inv_freq, dim_0__2)
val_15 = opset18.Cast(0, to=7)
Expand All @@ -65,7 +65,7 @@ def main_graph(
val_25 = opset18.Reshape(val_23, val_24, allowzero=0)
val_26 = opset18.Constant(value_ints=[1])
slice_1 = opset18.Slice(unsqueeze, val_17, val_21, val_25, val_26)
dim__3 = opset18.Constant(value_int=2)
dim__3 = opset18.Constant(value_ints=[2])
dim_0__3 = opset18.Cast(dim__3, to=7)
unsqueeze_1 = opset18.Unsqueeze(slice_1, dim_0__3)
_to_copy = opset18.Cast(unsqueeze_1, to=1)
Expand All @@ -83,7 +83,7 @@ def main_graph(
val_36 = opset18.Reshape(val_34, val_35, allowzero=0)
val_37 = opset18.Constant(value_ints=[1])
slice_2 = opset18.Slice(position_ids, val_30, val_33, val_36, val_37)
dim__5 = opset18.Constant(value_int=1)
dim__5 = opset18.Constant(value_ints=[1])
dim_0__5 = opset18.Cast(dim__5, to=7)
unsqueeze_2 = opset18.Unsqueeze(slice_2, dim_0__5)
val_38 = opset18.Cast(0, to=7)
Expand Down Expand Up @@ -160,10 +160,10 @@ def main_graph(
val_71 = opset18.Cast([1, 30, 32, 64], to=7)
view_12 = opset18.Reshape(view_9, val_71, allowzero=0)
transpose_3 = opset18.Transpose(view_12, perm=[0, 2, 1, 3])
dim__8 = opset18.Constant(value_int=1)
dim__8 = opset18.Constant(value_ints=[1])
dim_0__8 = opset18.Cast(dim__8, to=7)
unsqueeze_3 = opset18.Unsqueeze(_to_copy_4, dim_0__8)
dim__9 = opset18.Constant(value_int=1)
dim__9 = opset18.Constant(value_ints=[1])
dim_0__9 = opset18.Cast(dim__9, to=7)
unsqueeze_4 = opset18.Unsqueeze(_to_copy_5, dim_0__9)
mul_5 = transpose_1 * unsqueeze_3
Expand Down Expand Up @@ -222,10 +222,10 @@ def main_graph(
add_2 = mul_7 + mul_8
cat_3 = opset18.Concat(past_key_values_0_0, add_2, axis=-2)
cat_4 = opset18.Concat(past_key_values_0_1, transpose_3, axis=-2)
dim__10 = opset18.Constant(value_int=0)
dim__10 = opset18.Constant(value_ints=[0])
dim_0__10 = opset18.Cast(dim__10, to=7)
unsqueeze_5 = opset18.Unsqueeze(mul, dim_0__10)
dim__11 = opset18.Constant(value_int=1)
dim__11 = opset18.Constant(value_ints=[1])
dim_0__11 = opset18.Cast(dim__11, to=7)
unsqueeze_6 = opset18.Unsqueeze(unsqueeze_5, dim_0__11)
val_114 = opset18.Cast(0, to=7)
Expand Down
4 changes: 2 additions & 2 deletions onnxscript/rewriter/ort_fusions/cos_sin_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def pattern(
sin = op.Sin(emb)
if self._cast:
sin = op.Cast(sin, to=dtype)
cos_4d = op.Unsqueeze(cos, 1) # convert
sin_4d = op.Unsqueeze(sin, 1)
cos_4d = op.Unsqueeze(cos, [1]) # convert
sin_4d = op.Unsqueeze(sin, [1])
return op.RotaryEmbedding(
x,
cos_4d,
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_cos_sin_fusion(self, name, test_data_constructor):
original_outputs = ort_run("original", model, inputs)
count = fuse_rotary_embedding(model)
self.assertGreater(count, 0)
count = fuse_cos_sin_cache(model)
count = fuse_cos_sin_cache(model, debug=True)
self.assertGreater(count, 0)
new_outputs = ort_run("optimized", model, inputs)
assert_allclose(new_outputs, original_outputs)
Expand Down
4 changes: 2 additions & 2 deletions onnxscript/rewriter/ort_fusions/gqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def pattern(
key_seq_BHkvTDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2)
# Concat with past_key is optional:
key_seq_BHkvTDh = pattern.OrValue([key_seq_BHkvTDh, key_BHkvSDh_rope])
key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, 2)
key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, [2])
key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, pattern.ANY_VALUE)
key_seq_BHTDh = op.Reshape(
key_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["key_seq_BHTDh"]
Expand All @@ -234,7 +234,7 @@ def pattern(
value_seq_BHkvTDh = op.Concat(past_value, value_BHkvSDh, axis=-2)
# Concat with past_value is optional:
value_seq_BHkvTDh = pattern.OrValue([value_seq_BHkvTDh, value_BHkvSDh])
value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, 2)
value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, [2])
value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, pattern.ANY_VALUE)
value_seq_BHTDh = op.Reshape(
value_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["value_seq_BHTDh"]
Expand Down
8 changes: 4 additions & 4 deletions onnxscript/rewriter/ort_fusions/gqa_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,11 @@ def gqa(query, key, value, past_key, past_value, cos, sin):
value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2)

# Now, expand from shared heads to all heads
key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2)
key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, [2])
key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh)
key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh)

value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, 2)
value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, [2])
value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh)
value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh)

Expand Down Expand Up @@ -527,11 +527,11 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
value_seq_BHkvSkvDh = value_BHkvSDh

# Now, expand from shared heads to all heads
key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2)
key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, [2])
key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh)
key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh)

value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, 2)
value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, [2])
value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh)
value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh)

Expand Down
16 changes: 16 additions & 0 deletions onnxscript/rewriter/rules/common/_no_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ def _check(self, model_text: str) -> None:
self.assertEqual(count, 1)
self.assertEqual(model.graph[-1].op_type, "Identity")

def _check_no_optimization(self, model_text: str) -> None:
    """Assert that the no-op rules report zero applications for a model.

    Args:
        model_text: Model in ONNX textual syntax; it is parsed with
            ``ir.from_onnx_text`` before the rules are applied.
    """
    parsed_model = ir.from_onnx_text(model_text)
    # apply_to_model's result is compared to 0: no rule may have been applied.
    self.assertEqual(_no_op.rules.apply_to_model(parsed_model), 0)

@parameterized.parameterized.expand(
[
("float one input", "float[M]", "value_float=1.0", "one, input"),
Expand Down Expand Up @@ -195,6 +200,17 @@ def test_dropout_zero_or_inference_no_op_with_initializer(self, _, attribute: st
)
# TODO: Test the negative cases

def test_broadcast_is_not_eliminated(self):
    """Check that an Add with a broadcasting zero is left in place.

    The zero initializer is declared ``float[1,1,1]`` while the input is
    ``float[M]`` and the output ``float[1, 1, M]``. The model is handed to
    ``_check_no_optimization``.
    """
    self._check_no_optimization(
        """
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[M] input) => (float[1, 1, M] output)
<float[1,1,1] zero = {0.0}>
{
output = Add(zero, input)
}
"""
    )


if __name__ == "__main__":
unittest.main()
Loading