Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
Expand Down Expand Up @@ -389,7 +390,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
// Load result or Store valye Type can be vector or scalar.
Type valOrResTy;
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
valOrResTy = op.getResult().getType();
valOrResTy =
this->getTypeConverter()->convertType(op.getResult().getType());
else
valOrResTy = adaptor.getValue().getType();
VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
Expand Down Expand Up @@ -879,10 +881,30 @@ struct ConvertXeGPUToXeVMPass
}
return {};
};
typeConverter.addSourceMaterialization(memrefMaterializationCast);
typeConverter.addSourceMaterialization(ui64MaterializationCast);
typeConverter.addSourceMaterialization(ui32MaterializationCast);
typeConverter.addSourceMaterialization(vectorMaterializationCast);

// If result type of original op is single element vector and lowered type
// is scalar. This materialization cast creates a single element vector by
// broadcasting the scalar value.
auto singleElementVectorMaterializationCast =
[](OpBuilder &builder, Type type, ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
if (input.getType().isIntOrIndexOrFloat()) {
// If the input is a scalar, and the target type is a vector of single
// element, create a single element vector by broadcasting.
if (auto vecTy = dyn_cast<VectorType>(type)) {
if (vecTy.getNumElements() == 1) {
return vector::BroadcastOp::create(builder, loc, vecTy, input)
.getResult();
}
}
}
return {};
};
typeConverter.addSourceMaterialization(
singleElementVectorMaterializationCast);
typeConverter.addTargetMaterialization(memrefMaterializationCast);
typeConverter.addTargetMaterialization(ui32MaterializationCast);
typeConverter.addTargetMaterialization(ui64MaterializationCast);
Expand Down
29 changes: 23 additions & 6 deletions mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,36 @@ gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>)
// CHECK: %[[VAR4:.*]] = arith.addi %[[ARG0]], %[[VAR3]] : i64
// CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>
// CHECK: %[[VAR6:.*]] = scf.if %[[VAR2]] -> (f16) {
// CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> vector<1xf16>
// CHECK: %[[VAR8:.*]] = vector.extract %[[VAR7]][0] : f16 from vector<1xf16>
// CHECK: scf.yield %[[VAR8]] : f16
// CHECK: } else {
// CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf16>
// CHECK: %[[VAR7:.*]] = vector.extract %[[CST_0]][0] : f16 from vector<1xf16>
// CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> f16
// CHECK: scf.yield %[[VAR7]] : f16
// CHECK: } else {
// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f16
// CHECK: scf.yield %[[CST_0]] : f16
// CHECK: }
%3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
: i64, vector<1xindex>, vector<1xi1> -> vector<1xf16>
gpu.return
}
}

// -----
gpu.module @test {
// CHECK-LABEL: @source_materialize_single_elem_vec
// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: memref<1xf16>
gpu.func @source_materialize_single_elem_vec(%src: i64, %offset: vector<1xindex>, %dst: memref<1xf16>) {
%1 = arith.constant dense<1>: vector<1xi1>
%3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
: i64, vector<1xindex>, vector<1xi1> -> vector<1xf16>
// CHECK: %[[VAR_IF:.*]] = scf.if
// CHECK: %[[VAR_RET:.*]] = vector.broadcast %[[VAR_IF]] : f16 to vector<1xf16>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: vector.store %[[VAR_RET]], %[[ARG2]][%[[C0]]] : memref<1xf16>, vector<1xf16>
%c0 = arith.constant 0 : index
vector.store %3, %dst[%c0] : memref<1xf16>, vector<1xf16>
gpu.return
}
}

// -----

gpu.module @test {
Expand Down