diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 644d27938177b..26ab46d99dec1 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1653,6 +1653,10 @@ def ConvertXeGPUToXeVMPass : Pass<"convert-xegpu-to-xevm"> { "memref::MemRefDialect", "arith::ArithDialect", "LLVM::LLVMDialect", "index::IndexDialect", "gpu::GPUDialect", "scf::SCFDialect"]; + let options = [Option<"use64bitIndex", "use-64bit-index", "bool", + /*default=*/"true", + "Use 64-bit integers to convert index types">, + ]; } #endif // MLIR_CONVERSION_PASSES diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 6df209438447b..d09b565e34a3c 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -1049,14 +1049,25 @@ struct ConvertXeGPUToXeVMPass using Base::Base; void runOnOperation() override { - LLVMTypeConverter typeConverter(&getContext()); + MLIRContext *context = &getContext(); + + // The XeVM type converter is based on the LLVM type converter with the + // following customizations. + // First, type conversion rules are added for the xegpu custom types, + // TensorDescType and MemDescType. + // Second, MemRefType is lowered to a single integer type. + // Third, a single-element or 0-D VectorType is converted to its + // element type. Otherwise, the vector type is flattened to 1-D. + LowerToLLVMOptions options(context); + options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32); + LLVMTypeConverter typeConverter(context, options); + + Type xevmIndexType = typeConverter.convertType(IndexType::get(context)); + Type i32Type = IntegerType::get(context, 32); typeConverter.addConversion([&](VectorType type) -> Type { - unsigned rank = type.getRank(); - auto elemType = type.getElementType(); - // If the element type is index, convert it to i64. 
- if (llvm::isa(elemType)) - elemType = IntegerType::get(&getContext(), 64); + auto elemType = typeConverter.convertType(type.getElementType()); // If the vector rank is 0 or has a single element, return the element + unsigned rank = type.getRank(); if (rank == 0 || type.getNumElements() == 1) return elemType; // Otherwise, convert the vector to a flat vector type. @@ -1068,17 +1079,21 @@ struct ConvertXeGPUToXeVMPass if (type.isScattered()) return {}; if (type.getRank() == 1) - return IntegerType::get(&getContext(), 64); - auto i32Type = IntegerType::get(&getContext(), 32); + return xevmIndexType; return VectorType::get(8, i32Type); }); + // SLM access related type conversions. + // TODO: LLVM DLTI provides a clean way of representing different pointer + // sizes based on address space. Currently the pointer size for SLM access + // is hard coded to 32 bits. Update to use DLTI when switching the overall + // XeGPU lowering to DLTI instead of the use64bitIndex option used above. + // Convert MemDescType into i32 for SLM - typeConverter.addConversion([&](xegpu::MemDescType type) -> Type { - return IntegerType::get(&getContext(), 32); - }); + typeConverter.addConversion( + [&](xegpu::MemDescType type) -> Type { return i32Type; }); typeConverter.addConversion([&](MemRefType type) -> Type { - return IntegerType::get(&getContext(), (isSharedMemRef(type) ? 32 : 64)); + return isSharedMemRef(type) ? i32Type : xevmIndexType; }); // LLVM type converter puts unrealized casts for the following cases: @@ -1188,9 +1203,9 @@ struct ConvertXeGPUToXeVMPass return {}; }; - // Materialization to convert - // - bitcast vector of same rank - // - shape vector of different rank but same element type + // Materialization to convert between vector types: + // - Add a shape cast when the shapes differ. + // - Add a bitcast when the element types differ. // Applies to both source and target materialization. 
auto vectorToVectorMaterializationCast = [](OpBuilder &builder, Type type, ValueRange inputs, @@ -1200,17 +1215,22 @@ struct ConvertXeGPUToXeVMPass auto input = inputs.front(); if (auto vecTy = dyn_cast(input.getType())) { if (auto targetVecTy = dyn_cast(type)) { - // If the target type is a vector of same rank, - // bitcast to the target type. - if (targetVecTy.getRank() == vecTy.getRank()) - return vector::BitCastOp::create(builder, loc, targetVecTy, input) - .getResult(); - else if (targetVecTy.getElementType() == vecTy.getElementType()) { - // If the target type is a vector of different rank but same element - // type, reshape to the target type. - return vector::ShapeCastOp::create(builder, loc, targetVecTy, input) - .getResult(); + Value cast = input; + // If the target type has a different shape, add a shape cast first. + // If the target type has a different element type, add a bitcast. + if (targetVecTy.getShape() != vecTy.getShape()) { + cast = vector::ShapeCastOp::create( + builder, loc, + VectorType::get(targetVecTy.getShape(), + vecTy.getElementType()), + cast) + .getResult(); } + if (targetVecTy.getElementType() != vecTy.getElementType()) { + cast = vector::BitCastOp::create(builder, loc, targetVecTy, cast) + .getResult(); + } + return cast; } } return {}; @@ -1226,26 +1246,31 @@ struct ConvertXeGPUToXeVMPass return {}; auto input = inputs.front(); if (auto vecTy = dyn_cast(input.getType())) { - if (type == vecTy.getElementType() || - ((vecTy.getElementType() == builder.getIndexType()) && - type.isInteger())) { - // If the vector rank is 0 or has a single element, - // extract scalar of target type. 
- auto rank = vecTy.getRank(); - Value cast; - if (rank == 0) { - cast = - vector::ExtractOp::create(builder, loc, input, {}).getResult(); - } else { - cast = vector::ExtractOp::create(builder, loc, input, - SmallVector(rank, 0)) - .getResult(); - } - if (type != vecTy.getElementType()) - cast = arith::IndexCastUIOp::create(builder, loc, type, cast) - .getResult(); - return cast; + // Source needs to be a single element (or rank 0) vector. + auto rank = vecTy.getRank(); + if (rank != 0 && vecTy.getNumElements() != 1) + return {}; + auto inElemTy = vecTy.getElementType(); + // Extract the scalar. + Value cast = input; + if (rank == 0) { + cast = vector::ExtractOp::create(builder, loc, cast, {}).getResult(); + } else { + cast = vector::ExtractOp::create(builder, loc, cast, + SmallVector(rank, 0)) + .getResult(); + } + // The extracted element may need a type conversion. + // Two cases: + // 1. Index type to integer type (index cast). + // 2. Other element type mismatch (bitcast). + if (inElemTy.isIndex()) { + cast = arith::IndexCastUIOp::create(builder, loc, type, cast) + .getResult(); + } else if (inElemTy != type) { + cast = arith::BitcastOp::create(builder, loc, type, cast).getResult(); } + return cast; } return {}; }; @@ -1254,7 +1279,8 @@ struct ConvertXeGPUToXeVMPass // - single element of vector element type to single element vector // If result type of original op is single element vector and lowered type // is scalar. This materialization cast creates a single element vector by - // broadcasting the scalar value. + // first converting the element type if needed and then broadcasting it + // to a single element vector. // Applies only to source materialization. 
auto singleElementToVectorMaterializationCast = [](OpBuilder &builder, Type type, ValueRange inputs, @@ -1262,21 +1288,26 @@ struct ConvertXeGPUToXeVMPass if (inputs.size() != 1) return {}; auto input = inputs.front(); + auto inTy = input.getType(); + if (!inTy.isIntOrFloat()) + return {}; // If the target type is a vector of rank 0 or single element vector // of element type matching input type, broadcast input to target type. if (auto vecTy = dyn_cast(type)) { - if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) { - if (input.getType() == vecTy.getElementType()) { - return vector::BroadcastOp::create(builder, loc, vecTy, input) - .getResult(); - } else if (vecTy.getElementType() == builder.getIndexType()) { - Value cast = arith::IndexCastUIOp::create( - builder, loc, builder.getIndexType(), input) - .getResult(); - return vector::BroadcastOp::create(builder, loc, vecTy, cast) - .getResult(); - } + if (vecTy.getRank() != 0 && vecTy.getNumElements() != 1) + return {}; + auto outElemTy = vecTy.getElementType(); + Value cast = input; + if (outElemTy.isIndex()) { + cast = arith::IndexCastUIOp::create(builder, loc, + builder.getIndexType(), cast) + .getResult(); + } else if (inTy != outElemTy) { + cast = arith::BitcastOp::create(builder, loc, outElemTy, cast) + .getResult(); } + return vector::BroadcastOp::create(builder, loc, vecTy, cast) + .getResult(); } return {}; }; @@ -1289,14 +1320,14 @@ struct ConvertXeGPUToXeVMPass typeConverter.addTargetMaterialization( vectorToSingleElementMaterializationCast); typeConverter.addTargetMaterialization(vectorToVectorMaterializationCast); - ConversionTarget target(getContext()); + ConversionTarget target(*context); target.addLegalDialect(); target.addIllegalDialect(); - RewritePatternSet patterns(&getContext()); + RewritePatternSet patterns(context); populateXeGPUToXeVMConversionPatterns(typeConverter, patterns); scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, target); diff --git 
a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp index fbb7bb8aeb4bc..fb260f45e5ddd 100644 --- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp +++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp @@ -97,7 +97,10 @@ void buildGPUPassPipeline(OpPassManager &pm, pm.addNestedPass(xegpu::createXeGPUVectorLinearize()); } pm.addNestedPass(createConvertMathToXeVM()); - pm.addNestedPass(createConvertXeGPUToXeVMPass()); + ConvertXeGPUToXeVMPassOptions xegpuToXeVMOptions; + xegpuToXeVMOptions.use64bitIndex = options.use64bitIndex; + pm.addNestedPass( + createConvertXeGPUToXeVMPass(xegpuToXeVMOptions)); { ConvertGpuOpsToLLVMSPVOpsOptions gpuToLLVMSPVOptions; gpuToLLVMSPVOptions.use64bitIndex = options.use64bitIndex;