Skip to content

Commit

Permalink
[mlir][spirv] Support conversion of extract op from vector<1xT> type
Browse files Browse the repository at this point in the history
Differential Revision: https://reviews.llvm.org/D100814
  • Loading branch information
ThomasRaoux committed Apr 20, 2021
1 parent da76462 commit b2e72cd
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
5 changes: 5 additions & 0 deletions mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
Expand Up @@ -89,6 +89,11 @@ struct VectorExtractOpConvert final
return failure();

vector::ExtractOp::Adaptor adaptor(operands);
if (adaptor.vector().getType().isa<spirv::ScalarType>()) {
rewriter.replaceOp(extractOp, adaptor.vector());
return success();
}

int32_t id = getFirstIntValue(extractOp.position());
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
extractOp, adaptor.vector(), id);
Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Conversion/VectorToSPIRV/simple.mlir
Expand Up @@ -40,6 +40,24 @@ func @extract(%arg0 : vector<2xf32>) {

// -----

module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Float16], []>, {}> } {

// CHECK-LABEL: func @extract_scalar
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf16>
// CHECK-SAME: %[[ARG1:.+]]: vector<4xf32>
// CHECK: %[[S:.+]] = spv.Bitcast %[[ARG0]] : vector<2xf16> to f32
// CHECK: spv.CompositeInsert %[[S]], %[[ARG1]][0 : i32] : f32 into vector<4xf32>
func @extract_scalar(%arg0 : vector<2xf16>, %arg1 : vector<4xf32>) {
%0 = vector.bitcast %arg0 : vector<2xf16> to vector<1xf32>
%1 = vector.extract %0[0] : vector<1xf32>
%2 = vector.insert %1, %arg1[0] : f32 into vector<4xf32>
spv.Return
}

} // end module

// -----

// CHECK-LABEL: extract_insert
// CHECK-SAME: %[[V:.*]]: vector<4xf32>
// CHECK: %[[S:.*]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
Expand Down

0 comments on commit b2e72cd

Please sign in to comment.