Skip to content

Commit

Permalink
[mlir][memref] Missing type conversion in memref.reshape llvm lowering
Browse files Browse the repository at this point in the history
Shape can be memref of index type, so memref::LoadOp result need to be converted into llvm type.

Differential Revision: https://reviews.llvm.org/D129965
  • Loading branch information
Hardcode84 committed Jul 21, 2022
1 parent 70056d0 commit d4217e6
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
5 changes: 5 additions & 0 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
Expand Up @@ -1161,6 +1161,11 @@ struct MemRefReshapeOpLowering
Value shapeOp = reshapeOp.getShape();
Value index = createIndexConstant(rewriter, loc, i);
dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
Type indexType = getIndexType();
if (dimSize.getType() != indexType)
dimSize = typeConverter->materializeTargetConversion(
rewriter, loc, indexType, dimSize);
assert(dimSize && "Invalid memref element type");
}

desc.setSize(rewriter, loc, i, dimSize);
Expand Down
32 changes: 32 additions & 0 deletions mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
Expand Up @@ -306,3 +306,35 @@ func.func @memref.reshape.dynamic.dim(%arg: memref<?x?x?xf32>, %shape: memref<4x
return %0 : memref<?x?x12x32xf32>
// CHECK: return %[[result_cast]] : memref<?x?x12x32xf32>
}

// -----

// CHECK-LABEL: func @memref.reshape_index
// CHECK-SAME: %[[arg:.*]]: memref<?x?xi32>, %[[shape:.*]]: memref<1xindex>
func.func @memref.reshape_index(%arg0: memref<?x?xi32>, %shape: memref<1xindex>) -> memref<?xi32> {
// CHECK: %[[arg_cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : memref<?x?xi32> to !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[shape_cast:.*]] = builtin.unrealized_conversion_cast %[[shape]] : memref<1xindex> to !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[undef:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[alloc_ptr:.*]] = llvm.extractvalue %[[arg_cast]][0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[align_ptr:.*]] = llvm.extractvalue %[[arg_cast]][1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[insert0:.*]] = llvm.insertvalue %[[alloc_ptr]], %[[undef:.*]][0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[insert1:.*]] = llvm.insertvalue %[[align_ptr]], %[[insert0:.*]][1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>

// CHECK: %[[zero0:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[insert2:.*]] = llvm.insertvalue %[[zero0]], %[[insert1:.*]][2] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>

// CHECK: %[[one0:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[zero1:.*]] = llvm.mlir.constant(0 : index) : i64

// CHECK: %[[shape_ptr0:.*]] = llvm.extractvalue %[[shape_cast:.*]][1] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[shape_gep0:.*]] = llvm.getelementptr %[[shape_ptr0:.*]][%[[zero1:.*]]] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>
// CHECK: %[[shape_load0:.*]] = llvm.load %[[shape_gep0:.*]] : !llvm.ptr<i64>
// CHECK: %[[insert3:.*]] = llvm.insertvalue %[[shape_load0:.*]], %[[insert2:.*]][3, 0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[insert4:.*]] = llvm.insertvalue %[[one0:.*]], %[[insert3:.*]][4, 0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>

// CHECK: %[[result_cast:.*]] = builtin.unrealized_conversion_cast %[[insert4:.*]] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)> to memref<?xi32>
// CHECK: return %[[result_cast:.*]] : memref<?xi32>

%1 = memref.reshape %arg0(%shape) : (memref<?x?xi32>, memref<1xindex>) -> memref<?xi32>
return %1 : memref<?xi32>
}

0 comments on commit d4217e6

Please sign in to comment.