Skip to content

Commit 52b6d47

Browse files
authored
Handle bounds in the CholeskyOp shape function (#887)
Cholesky op's [constraint (C3)](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cholesky) requires the last two dimensions to be the same. The bound rules are as follows (3 being an arbitrary static shape):

```
dynamic type inference rules for the last two dims of A
(case0 is sanity check for dynamic batch dims):
        dim R-2 | dim R-1 | inferred R-2 | inferred R-1
case0:     3    |    3    |      3       |      3
case1:     ?    |    3    |      3       |      3
case2:     ?    |    ?    |      ?       |      ?

dynamic bound inference rules for the last two dims of A
(case0 is sanity check for dynamic batch dims):
        dim R-2  | dim R-1 | inferred R-2 | inferred R-1
case0:   3, ?    |  3, ?   |     3, ?     |     3, ?
case1:   ?, ?    |  3, ?   |     3, ?     |     3, ?
case2:   ?, A<3  |  3, ?   |    error     |    error
case3:   ?, A>=3 |  3, ?   |     3, ?     |     3, ?
case4:   ?, ?    |  ?, ?   |     ?, ?     |     ?, ?
case5:   ?, A    |  ?, ?   |     ?, A     |     ?, A
case6:   ?, A    |  ?, B   |  ?, min(A,B) |  ?, min(A,B)
```

The rules proposed above are reviewed and agreed upon, but there will be no immediate implementation in the current PR to follow recently updated guidelines.

closes #804
1 parent d2cc1fa commit 52b6d47

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

stablehlo/dialect/TypeInference.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1443,8 +1443,10 @@ LogicalResult inferCholeskyOp(
14431443
return emitOptionalError(
14441444
location, "minor dimensions of 'a' must have equal size, got shape ",
14451445
aShape, ".");
1446+
14461447
inferredReturnShapes.emplace_back(aRankedType.getShape(),
1447-
aRankedType.getElementType());
1448+
aRankedType.getElementType(),
1449+
aRankedType.getEncoding());
14481450
return success();
14491451
}
14501452

stablehlo/tests/infer_stablehlo.mlir

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,8 @@ func.func @pad(%arg0: tensor<1x2x3xf16>, %arg1: tensor<f16>) -> tensor<2x4x7xind
8888
// CHECK-LABEL: @cholesky
8989
func.func @cholesky(%arg0: tensor<1x2x2xf32>) -> tensor<1x2x2xindex> {
9090
%0 = "stablehlo.cholesky"(%arg0) { lower = true } : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
91-
%1 = "hlo_test_infer.get_return_type_components"(%0)
92-
: (tensor<1x2x2xf32>) -> tensor<1x2x2xindex>
93-
// CHECK: %1 = "hlo_test_infer.return_type_components"(%0) {dims0 = "[1, 2, 2]", element_type0 = f32} : (tensor<1x2x2xf32>) -> tensor<1x2x2xindex>
91+
%1 = "hlo_test_infer.get_return_type_components"(%0) : (tensor<1x2x2xf32>) -> tensor<1x2x2xindex>
92+
// CHECK: %1 = "hlo_test_infer.return_type_components"(%0) {dims0 = "[1, 2, 2]", element_type0 = f32} : (tensor<1x2x2xf32>) -> tensor<1x2x2xindex>
9493
func.return %1: tensor<1x2x2xindex>
9594
}
9695

@@ -378,7 +377,7 @@ func.func @dynamic_update_slice(%arg0: tensor<4x4xi32>, %arg1: tensor<2x2xi32>,
378377

379378
// -----
380379

381-
func.func @dynamic_update_slice(%input: tensor<3x?x?xi64, #stablehlo.type_extensions<bounds = [?, ?, 5]>>, %update: tensor<1x4x3xi64>, %start1: tensor<i64>, %start2: tensor<i64>, %start3 : tensor<i64>) -> tensor<*xindex> {
380+
func.func @dynamic_update_slice(%input: tensor<3x?x?xi64, #stablehlo.type_extensions<bounds = [?, ?, 5]>>, %update: tensor<1x4x3xi64>, %start1: tensor<i64>, %start2: tensor<i64>, %start3 : tensor<i64>) -> tensor<*xindex> {
382381
%0 = "stablehlo.dynamic_update_slice"(%input, %update, %start1, %start2, %start3) : (tensor<3x?x?xi64, #stablehlo.type_extensions<bounds = [?, ?, 5]>>, tensor<1x4x3xi64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<3x?x?xi64>
383382
%1 = "hlo_test_infer.get_return_types"(%0) : (tensor<3x?x?xi64>) -> tensor<*xindex>
384383
// CHECK: types0 = tensor<3x?x?xi64, #stablehlo.type_extensions<bounds = [?, ?, 5]>>
@@ -1037,6 +1036,14 @@ func.func @pad(%arg0: tensor<?x48x48x32xf32>) -> tensor<4xindex> {
10371036

10381037
// -----
10391038

1039+
// CHECK-LABEL: func @cholesky_bounds
1040+
func.func @cholesky_bounds(%input: tensor<2x?x?xf32, #stablehlo.type_extensions<bounds = [?, 5, ?]>>) -> tensor<*xindex> {
1041+
%0 = "stablehlo.cholesky"(%input) { lower = true } : (tensor<2x?x?xf32, #stablehlo.type_extensions<bounds = [?, 5, ?]>>) -> tensor<*xf32>
1042+
// CHECK: types0 = tensor<2x?x?xf32, #stablehlo.type_extensions<bounds = [?, 5, ?]>>
1043+
%1 = "hlo_test_infer.get_return_types"(%0) : (tensor<*xf32>) -> tensor<*xindex>
1044+
func.return %1 : tensor<*xindex>
1045+
}
1046+
10401047
// CHECK-LABEL: func @concatenate
10411048
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?xi32>, %[[ARG1:.*]]: tensor<?x?xi32>, %[[ARG2:.*]]: tensor<?x?xi32>
10421049
func.func @concatenate(%arg0: tensor<?x?xi32>, %arg1: tensor<?x?xi32>, %arg2: tensor<?x?xi32>) -> tensor<2xindex> {
@@ -1165,7 +1172,7 @@ func.func @broadcast(%arg0: tensor<?xi32>) -> tensor<3xindex> {
11651172

11661173
// CHECK-LABEL: func @transpose
11671174
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?x?x?xi32>
1168-
func.func @transpose(%arg0: tensor<?x?x?x?xi32>) -> tensor<4xindex> {
1175+
func.func @transpose(%arg0: tensor<?x?x?x?xi32>) -> tensor<4xindex> {
11691176
// CHECK: %[[C0:.*]] = arith.constant 0 : index
11701177
// CHECK: %[[C1:.*]] = arith.constant 1 : index
11711178
// CHECK: %[[C2:.*]] = arith.constant 2 : index

0 commit comments

Comments (0)