Skip to content

Commit

Permalink
Correctly model undefined behavior in {tensor|memref}.dim
Browse files Browse the repository at this point in the history
These operations have undefined behavior if the index is not less than the rank of the source tensor / memref, so they cannot be freely speculated like they were before this patch.  After this patch we speculate them only if we can prove that they don't have UB.

Depends on D135505.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D135748
  • Loading branch information
Sanjoy Das committed Oct 13, 2022
1 parent 5a52c5c commit adabce4
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 2 deletions.
5 changes: 4 additions & 1 deletion mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ def MemRef_DeallocOp : MemRef_Op<"dealloc", [MemRefsNormalizable]> {
def MemRef_DimOp : MemRef_Op<"dim", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
MemRefsNormalizable,
Pure,
ConditionallySpeculatable, NoMemoryEffect,
ShapedDimOpInterface]> {
let summary = "dimension index operation";
let description = [{
Expand Down Expand Up @@ -593,6 +593,9 @@ def MemRef_DimOp : MemRef_Op<"dim", [

/// Interface method of ShapedDimOpInterface: Return the dimension.
OpFoldResult getDimension() { return getIndex(); }

/// Interface method for ConditionallySpeculatable.
Speculation::Speculatability getSpeculatability();
}];

let hasCanonicalizer = 1;
Expand Down
5 changes: 4 additions & 1 deletion mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def Tensor_CastOp : Tensor_Op<"cast", [

def Tensor_DimOp : Tensor_Op<"dim", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
Pure,
ConditionallySpeculatable, NoMemoryEffect,
ShapedDimOpInterface]> {
let summary = "dimension index operation";
let description = [{
Expand Down Expand Up @@ -135,6 +135,9 @@ def Tensor_DimOp : Tensor_Op<"dim", [

/// Interface method of ShapedDimOpInterface: Return the dimension.
OpFoldResult getDimension() { return getIndex(); }

/// Interface method for ConditionallySpeculatable.
Speculation::Speculatability getSpeculatability();
}];

let hasCanonicalizer = 1;
Expand Down
14 changes: 14 additions & 0 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,20 @@ Optional<int64_t> DimOp::getConstantIndex() {
return {};
}

Speculation::Speculatability DimOp::getSpeculatability() {
auto constantIndex = getConstantIndex();
if (!constantIndex)
return Speculation::NotSpeculatable;

auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
if (!rankedSourceType)
return Speculation::NotSpeculatable;

// The verifier rejects operations that violate this assertion.
assert(constantIndex < rankedSourceType.getRank());
return Speculation::Speculatable;
}

LogicalResult DimOp::verify() {
// Assume unknown index to be in range.
Optional<int64_t> index = getConstantIndex();
Expand Down
14 changes: 14 additions & 0 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,20 @@ Optional<int64_t> DimOp::getConstantIndex() {
return {};
}

Speculation::Speculatability DimOp::getSpeculatability() {
auto constantIndex = getConstantIndex();
if (!constantIndex)
return Speculation::NotSpeculatable;

auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
if (!rankedSourceType)
return Speculation::NotSpeculatable;

// The verifier rejects operations that violate this assertion.
assert(constantIndex < rankedSourceType.getRank());
return Speculation::Speculatable;
}

LogicalResult DimOp::verify() {
// Assume unknown index to be in range.
Optional<int64_t> index = getConstantIndex();
Expand Down
104 changes: 104 additions & 0 deletions mlir/test/Transforms/loop-invariant-code-motion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -503,3 +503,107 @@ func.func @test_recursively_speculatable_op_failure(%lb: index, %ub: index, %ste

return
}

// -----

func.func @speculate_tensor_dim_unknown_rank_unknown_dim(
// CHECK-LABEL: @speculate_tensor_dim_unknown_rank_unknown_dim
%t: tensor<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
// CHECK: scf.for
// CHECK-NEXT: tensor.dim
scf.for %i = %lb to %ub step %step {
%val = tensor.dim %t, %dim_idx : tensor<*xf32>
}

return
}

func.func @speculate_tensor_dim_known_rank_unknown_dim(
// CHECK-LABEL: @speculate_tensor_dim_known_rank_unknown_dim
%t: tensor<?x?x?x?xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
// CHECK: scf.for
// CHECK-NEXT: tensor.dim
scf.for %i = %lb to %ub step %step {
%val = tensor.dim %t, %dim_idx : tensor<?x?x?x?xf32>
}

return
}

func.func @speculate_tensor_dim_unknown_rank_known_dim(
// CHECK-LABEL: @speculate_tensor_dim_unknown_rank_known_dim
%t: tensor<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
%c0 = arith.constant 0 : index
// CHECK: scf.for
// CHECK-NEXT: tensor.dim
scf.for %i = %lb to %ub step %step {
%val = tensor.dim %t, %c0 : tensor<*xf32>
}

return
}

func.func @speculate_tensor_dim_known_rank_known_dim_inbounds(
// CHECK-LABEL: @speculate_tensor_dim_known_rank_known_dim_inbounds
%t: tensor<?x?x?x?xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
%c1 = arith.constant 1 : index
// CHECK: tensor.dim
// CHECK-NEXT: scf.for
scf.for %i = %lb to %ub step %step {
%val = tensor.dim %t, %c1 : tensor<?x?x?x?xf32>
}

return
}

// -----

func.func @speculate_memref_dim_unknown_rank_unknown_dim(
// CHECK-LABEL: @speculate_memref_dim_unknown_rank_unknown_dim
%t: memref<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
// CHECK: scf.for
// CHECK-NEXT: memref.dim
scf.for %i = %lb to %ub step %step {
%val = memref.dim %t, %dim_idx : memref<*xf32>
}

return
}

func.func @speculate_memref_dim_known_rank_unknown_dim(
// CHECK-LABEL: @speculate_memref_dim_known_rank_unknown_dim
%t: memref<?x?x?x?xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
// CHECK: scf.for
// CHECK-NEXT: memref.dim
scf.for %i = %lb to %ub step %step {
%val = memref.dim %t, %dim_idx : memref<?x?x?x?xf32>
}

return
}

func.func @speculate_memref_dim_unknown_rank_known_dim(
// CHECK-LABEL: @speculate_memref_dim_unknown_rank_known_dim
%t: memref<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
%c0 = arith.constant 0 : index
// CHECK: scf.for
// CHECK-NEXT: memref.dim
scf.for %i = %lb to %ub step %step {
%val = memref.dim %t, %c0 : memref<*xf32>
}

return
}

func.func @speculate_memref_dim_known_rank_known_dim_inbounds(
// CHECK-LABEL: @speculate_memref_dim_known_rank_known_dim_inbounds
%t: memref<?x?x?x?xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
%c1 = arith.constant 1 : index
// CHECK: memref.dim
// CHECK-NEXT: scf.for
scf.for %i = %lb to %ub step %step {
%val = memref.dim %t, %c1 : memref<?x?x?x?xf32>
}

return
}

0 comments on commit adabce4

Please sign in to comment.