From 2f8ac036654fb4292ab59cdba9fb16f70a157c76 Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Fri, 7 Jun 2024 16:09:00 -0700 Subject: [PATCH] Fix #1159 Signed-off-by: Ettore Tiotto --- third_party/intel/backend/compiler.py | 12 +++++++----- .../intel/include/TritonIntelGPUToLLVM/Passes.td | 6 ++++++ .../include/TritonIntelGPUToLLVM/TypeConverter.h | 2 +- .../intel/lib/TritonIntelGPUToLLVM/PipelineManager.h | 4 ++-- .../lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp | 6 +++--- .../intel/lib/TritonIntelGPUToLLVM/TypeConverter.cpp | 8 +++----- third_party/intel/triton_xpu.cc | 4 ++-- 7 files changed, 24 insertions(+), 18 deletions(-) diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 34ac8fe78..fe55d9c41 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -48,7 +48,6 @@ class XPUOptions: extern_libs: dict = None debug: bool = False backend_name: str = 'intel' - is_block_ptr_enabled: bool = os.getenv("TRITON_INTEL_ENABLE_BLOCK_PTR", "0") == "1" def __post_init__(self): default_libdir = Path(__file__).parent / 'lib' @@ -71,7 +70,7 @@ class XPUBackend(BaseBackend): class Experimental: @staticmethod - def make_ttgir(mod, metadata, opt, device_arch): + def make_ttgir(mod, metadata, opt): pm = ir.pass_manager(mod.context) pm.enable_debug() @@ -152,8 +151,9 @@ def make_ttir(mod, metadata, opt): @staticmethod def make_ttgir(mod, metadata, opt, device_arch): - if XPUOptions.is_block_ptr_enabled: - return XPUBackend.Experimental.make_ttgir(mod, metadata, opt, device_arch) + is_lts = Version(metadata["target"].arch['driver_version']) == Version("1.3.27642") + if (not is_lts and os.getenv("TRITON_INTEL_ENABLE_BLOCK_PTR", "0") == "1"): + return XPUBackend.Experimental.make_ttgir(mod, metadata, opt) # TTIR -> TTGIR pm = ir.pass_manager(mod.context) @@ -187,6 +187,8 @@ def make_ttgir(mod, metadata, opt, device_arch): @staticmethod def make_llir(src, metadata, options): + is_lts = Version(metadata["target"].arch['driver_version']) == Version("1.3.27642") + # warp-specialization mutates num_warps num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta") if num_warp_groups is not None: @@ -201,7 +203,7 @@ def make_llir(src, metadata, options): passes.convert.add_scf_to_cf(pm) passes.convert.add_index_to_llvmir(pm) intel.passes.ttgpuir.add_allocate_shared_memory(pm) - intel.passes.ttgpuir.add_to_llvmir(pm) + intel.passes.ttgpuir.add_to_llvmir(pm, is_lts) passes.convert.add_arith_to_llvmir(pm) passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) diff --git a/third_party/intel/include/TritonIntelGPUToLLVM/Passes.td b/third_party/intel/include/TritonIntelGPUToLLVM/Passes.td index 5a85f0bf8..9b1ef01c8 100644 --- a/third_party/intel/include/TritonIntelGPUToLLVM/Passes.td +++ b/third_party/intel/include/TritonIntelGPUToLLVM/Passes.td @@ -26,6 +26,12 @@ def ConvertTritonIntelGPUToLLVM "mlir::triton::TritonDialect", "mlir::triton::gpu::TritonGPUDialect", "mlir::triton::TritonGEN::TritonGENDialect"]; + + let options = [ + Option<"isLTSDriver", "is-lts-driver", + "bool", /*default*/"false", + "Target is LTS driver or not">, + ]; } #endif // TRITONINTELGPU_CONVERSION_PASSES diff --git a/third_party/intel/include/TritonIntelGPUToLLVM/TypeConverter.h b/third_party/intel/include/TritonIntelGPUToLLVM/TypeConverter.h index 2ca21ab34..92df7a3f3 100644 --- a/third_party/intel/include/TritonIntelGPUToLLVM/TypeConverter.h +++ b/third_party/intel/include/TritonIntelGPUToLLVM/TypeConverter.h @@ -18,7 +18,7 @@ class TritonIntelGPUToLLVMTypeConverter : public TritonGPUToLLVMTypeConverter { using TypeConverter::convertType; TritonIntelGPUToLLVMTypeConverter( - MLIRContext *ctx, LowerToLLVMOptions &option, + MLIRContext *ctx, LowerToLLVMOptions &option, bool isLTSDriver, const DataLayoutAnalysis *analysis = nullptr); }; diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h b/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h index 1052ff55e..946491486 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h @@ -178,9 +178,9 @@ struct AddSPIRVEnvPattern : public mlir::OpRewritePattern { /// block pointers or not. class TritonGPUToLLVMPipelineManager { public: - TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx) + TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx, bool isLTSDriver) : mod(mod), ctx(ctx), - blockPtrPathIsEnabled( + blockPtrPathIsEnabled(!isLTSDriver && mlir::triton::tools::getBoolEnv("TRITON_INTEL_ENABLE_BLOCK_PTR")) {} /// FIXME: remove once the block ptr conversion path is capable of handling diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp index 86523cf57..9d18ca703 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp @@ -76,10 +76,10 @@ struct ConvertTritonGPUToLLVM MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); - intel::TritonGPUToLLVMPipelineManager pipelineManager(mod, context); + intel::TritonGPUToLLVMPipelineManager pipelineManager(mod, context, isLTSDriver); mlir::LowerToLLVMOptions option(context); option.overrideIndexBitwidth(32); - TritonIntelGPUToLLVMTypeConverter typeConverter(context, option); + TritonIntelGPUToLLVMTypeConverter typeConverter(context, option, isLTSDriver); TritonLLVMConversionTarget convTarget(*context); int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); @@ -95,7 +95,7 @@ struct ConvertTritonGPUToLLVM // Lower functions { mlir::LowerToLLVMOptions option(context); - TritonIntelGPUToLLVMTypeConverter typeConverter(context, option); + TritonIntelGPUToLLVMTypeConverter typeConverter(context, option, isLTSDriver); TritonLLVMFunctionConversionTarget funcTarget(*context); RewritePatternSet funcPatterns(context); pipelineManager.populateFunctionConversionPatterns( diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TypeConverter.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TypeConverter.cpp index 0d596d816..3de72a164 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TypeConverter.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TypeConverter.cpp @@ -9,16 +9,14 @@ #include "intel/include/TritonIntelGPUToLLVM/TypeConverter.h" #include "triton/Tools/Sys/GetEnv.hpp" -using namespace mlir; -using namespace mlir::triton; - TritonIntelGPUToLLVMTypeConverter::TritonIntelGPUToLLVMTypeConverter( - MLIRContext *ctx, LowerToLLVMOptions &option, + MLIRContext *ctx, LowerToLLVMOptions &option, bool isLTSDriver, const DataLayoutAnalysis *analysis) : TritonGPUToLLVMTypeConverter(ctx, option, analysis) { // Augment/overwrite type conversions required for the Intel conversion // passes. - if (mlir::triton::tools::getBoolEnv("TRITON_INTEL_ENABLE_BLOCK_PTR")) { + if (!isLTSDriver && + mlir::triton::tools::getBoolEnv("TRITON_INTEL_ENABLE_BLOCK_PTR")) { // tt::pointer to v2i32. addConversion([&](PointerType type) -> std::optional { if (isa(type.getPointeeType())) { diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index 4dfd89246..38f3318a9 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -55,8 +55,8 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) { .value("PVC", gpu::intel::DeviceArch::PVC) .export_values(); - ADD_PASS_WRAPPER_0("add_to_llvmir", - gpu::intel::createConvertTritonIntelGPUToLLVM); + ADD_PASS_WRAPPER_OPT_1("add_to_llvmir", + gpu::intel::createConvertTritonIntelGPUToLLVM, bool); ADD_PASS_WRAPPER_0("add_accelerate_matmul", gpu::intel::createTritonIntelGPUAccelerateMatmul); ADD_PASS_WRAPPER_0("add_decompose_unsupported_conversions",