From 09c6f835f0d8202e0cb47e45a587e7248f103cb8 Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Mon, 14 Oct 2024 13:06:33 -0500 Subject: [PATCH 1/2] Add AtenSliceTOp Canonicalization to SimplifyShapeCalculations pass --- .../Transforms/SimplifyShapeCalculations.cpp | 1 + .../Torch/simplify-shape-calculations.mlir | 39 +++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp index 6d2008a28407..281704e6548f 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp @@ -141,6 +141,7 @@ class SimplifyShapeCalculationsPass AtenSizeOp::getCanonicalizationPatterns(patterns, context); AtenLenTOp::getCanonicalizationPatterns(patterns, context); AtenAddTOp::getCanonicalizationPatterns(patterns, context); + AtenSliceTOp::getCanonicalizationPatterns(patterns, context); // TODO: Debug visitation order to make this more efficient. // A single linear scan should suffice. 
diff --git a/test/Dialect/Torch/simplify-shape-calculations.mlir b/test/Dialect/Torch/simplify-shape-calculations.mlir
index b7e7cf17ba0e..aa440ac9de40 100644
--- a/test/Dialect/Torch/simplify-shape-calculations.mlir
+++ b/test/Dialect/Torch/simplify-shape-calculations.mlir
@@ -489,3 +489,42 @@ func.func @shape_calc_with_two_uses(%arg0: !torch.vtensor<[2],f32>) -> !torch.vt
 
   return %arg0 : !torch.vtensor<[2],f32>
 }
+
+// CHECK-LABEL: func.func @unflat_shape_partial_dyn
+func.func @unflat_shape_partial_dyn(%arg0: !torch.vtensor<[?,?,3072],f32>) -> !torch.vtensor<[?,?,4,?],f32> {
+  %int768 = torch.constant.int 768
+  // CHECK-DAG: %[[INT768:.*]] = torch.constant.int 768
+  %int3072 = torch.constant.int 3072
+  %int0 = torch.constant.int 0
+  // CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
+  %int3 = torch.constant.int 3
+  %int1 = torch.constant.int 1
+  // CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
+  %none = torch.constant.none
+  %int-1 = torch.constant.int -1
+  %int2 = torch.constant.int 2
+  %int4 = torch.constant.int 4
+  // CHECK-DAG: %[[INT4:.*]] = torch.constant.int 4
+  %0 = torch.prim.ListConstruct %int4, %int-1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.shape.calculate {
+    %2 = torch.aten.unflatten.int %arg0, %int2, %0 : !torch.vtensor<[?,?,3072],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,4,?],f32>
+    torch.shape.calculate.yield %2 : !torch.vtensor<[?,?,4,?],f32>
+  } shapes {
+    %2 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int
+    %3 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int
+    %4 = torch.prim.ListConstruct %2, %3, %int3072 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %5 = torch.prim.ListConstruct %int4, %int768 : (!torch.int, !torch.int) -> !torch.list<int>
+    %6 = torch.aten.slice.t %4, %none, %int2, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>
+    %7 = torch.aten.add.t %6, %5 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
+    %8 = torch.aten.slice.t %4, %int3, %none, %int1 : !torch.list<int>, !torch.int, !torch.none, !torch.int -> !torch.list<int>
+    %9 = torch.aten.add.t %7, %8 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
+    torch.shape.calculate.yield.shapes %9 : !torch.list<int>
+  } : !torch.vtensor<[?,?,4,?],f32>
+  // CHECK: } shapes {
+  // CHECK: %[[SZE0:.*]] = torch.aten.size.int %arg0, %[[INT0]] : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int
+  // CHECK: %[[SZE1:.*]] = torch.aten.size.int %arg0, %[[INT1]] : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int
+  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[SZE0]], %[[SZE1]], %[[INT4]], %[[INT768]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: torch.shape.calculate.yield.shapes %[[LIST]] : !torch.list<int>
+  // CHECK: } : !torch.vtensor<[?,?,4,768],f32>
+  return %1 : !torch.vtensor<[?,?,4,?],f32>
+}

From e2be336c0898a94c27751abedfc6f99b67c055c7 Mon Sep 17 00:00:00 2001
From: zjgarvey
Date: Mon, 14 Oct 2024 14:19:40 -0500
Subject: [PATCH 2/2] Move checks to be consistent with rest of file

---
 .../Torch/simplify-shape-calculations.mlir | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/test/Dialect/Torch/simplify-shape-calculations.mlir b/test/Dialect/Torch/simplify-shape-calculations.mlir
index aa440ac9de40..59884616f13f 100644
--- a/test/Dialect/Torch/simplify-shape-calculations.mlir
+++ b/test/Dialect/Torch/simplify-shape-calculations.mlir
@@ -491,20 +491,26 @@ func.func @shape_calc_with_two_uses(%arg0: !torch.vtensor<[2],f32>) -> !torch.vt
 }
 
 // CHECK-LABEL: func.func @unflat_shape_partial_dyn
+// CHECK-DAG: %[[INT768:.*]] = torch.constant.int 768
+// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
+// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
+// CHECK-DAG: %[[INT4:.*]] = torch.constant.int 4
+// CHECK: } shapes {
+// CHECK: %[[SZE0:.*]] = torch.aten.size.int %arg0, %[[INT0]] : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int
+// CHECK: %[[SZE1:.*]] = torch.aten.size.int %arg0, %[[INT1]] : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int
+// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[SZE0]], %[[SZE1]], %[[INT4]], %[[INT768]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: torch.shape.calculate.yield.shapes %[[LIST]] : !torch.list<int>
+// CHECK: } : !torch.vtensor<[?,?,4,768],f32>
 func.func @unflat_shape_partial_dyn(%arg0: !torch.vtensor<[?,?,3072],f32>) -> !torch.vtensor<[?,?,4,?],f32> {
   %int768 = torch.constant.int 768
-  // CHECK-DAG: %[[INT768:.*]] = torch.constant.int 768
   %int3072 = torch.constant.int 3072
   %int0 = torch.constant.int 0
-  // CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
   %int3 = torch.constant.int 3
   %int1 = torch.constant.int 1
-  // CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
   %none = torch.constant.none
   %int-1 = torch.constant.int -1
   %int2 = torch.constant.int 2
   %int4 = torch.constant.int 4
-  // CHECK-DAG: %[[INT4:.*]] = torch.constant.int 4
   %0 = torch.prim.ListConstruct %int4, %int-1 : (!torch.int, !torch.int) -> !torch.list<int>
   %1 = torch.shape.calculate {
     %2 = torch.aten.unflatten.int %arg0, %int2, %0 : !torch.vtensor<[?,?,3072],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,4,?],f32>
@@ -520,11 +526,5 @@ func.func @unflat_shape_partial_dyn(%arg0: !torch.vtensor<[?,?,3072],f32>) -> !t
     %9 = torch.aten.add.t %7, %8 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
     torch.shape.calculate.yield.shapes %9 : !torch.list<int>
   } : !torch.vtensor<[?,?,4,?],f32>
-  // CHECK: } shapes {
-  // CHECK: %[[SZE0:.*]] = torch.aten.size.int %arg0, %[[INT0]] : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int
-  // CHECK: %[[SZE1:.*]] = torch.aten.size.int %arg0, %[[INT1]] : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int
-  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[SZE0]], %[[SZE1]], %[[INT4]], %[[INT768]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: torch.shape.calculate.yield.shapes %[[LIST]] : !torch.list<int>
-  // CHECK: } : !torch.vtensor<[?,?,4,768],f32>
   return %1 : !torch.vtensor<[?,?,4,?],f32>
 }