Skip to content

Commit

Permalink
Fix #1159
Browse files (browse the repository at this point in the history)
Signed-off-by: Ettore Tiotto <ettore.tiotto@intel.com>
  • Loading branch information
etiotto committed Jun 7, 2024
1 parent 951c0ba commit 2f8ac03
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 18 deletions.
12 changes: 7 additions & 5 deletions third_party/intel/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class XPUOptions:
extern_libs: dict = None
debug: bool = False
backend_name: str = 'intel'
is_block_ptr_enabled: bool = os.getenv("TRITON_INTEL_ENABLE_BLOCK_PTR", "0") == "1"

def __post_init__(self):
default_libdir = Path(__file__).parent / 'lib'
Expand All @@ -71,7 +70,7 @@ class XPUBackend(BaseBackend):
class Experimental:

@staticmethod
def make_ttgir(mod, metadata, opt, device_arch):
def make_ttgir(mod, metadata, opt):
pm = ir.pass_manager(mod.context)
pm.enable_debug()

Expand Down Expand Up @@ -152,8 +151,9 @@ def make_ttir(mod, metadata, opt):

@staticmethod
def make_ttgir(mod, metadata, opt, device_arch):
if XPUOptions.is_block_ptr_enabled:
return XPUBackend.Experimental.make_ttgir(mod, metadata, opt, device_arch)
is_lts = Version(metadata["target"].arch['driver_version']) == Version("1.3.27642")
if (not is_lts and os.getenv("TRITON_INTEL_ENABLE_BLOCK_PTR", "0") == "1"):
return XPUBackend.Experimental.make_ttgir(mod, metadata, opt)

# TTIR -> TTGIR
pm = ir.pass_manager(mod.context)
Expand Down Expand Up @@ -187,6 +187,8 @@ def make_ttgir(mod, metadata, opt, device_arch):

@staticmethod
def make_llir(src, metadata, options):
is_lts = Version(metadata["target"].arch['driver_version']) == Version("1.3.27642")

# warp-specialization mutates num_warps
num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta")
if num_warp_groups is not None:
Expand All @@ -201,7 +203,7 @@ def make_llir(src, metadata, options):
passes.convert.add_scf_to_cf(pm)
passes.convert.add_index_to_llvmir(pm)
intel.passes.ttgpuir.add_allocate_shared_memory(pm)
intel.passes.ttgpuir.add_to_llvmir(pm)
intel.passes.ttgpuir.add_to_llvmir(pm, is_lts)
passes.convert.add_arith_to_llvmir(pm)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
Expand Down
6 changes: 6 additions & 0 deletions third_party/intel/include/TritonIntelGPUToLLVM/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ def ConvertTritonIntelGPUToLLVM
"mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonGEN::TritonGENDialect"];

let options = [
Option<"isLTSDriver", "is-lts-driver",
"bool", /*default*/"false",
"Target is LTS driver or not">,
];
}

#endif // TRITONINTELGPU_CONVERSION_PASSES
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class TritonIntelGPUToLLVMTypeConverter : public TritonGPUToLLVMTypeConverter {
using TypeConverter::convertType;

TritonIntelGPUToLLVMTypeConverter(
MLIRContext *ctx, LowerToLLVMOptions &option,
MLIRContext *ctx, LowerToLLVMOptions &option, bool isLTSDriver,
const DataLayoutAnalysis *analysis = nullptr);
};

Expand Down
4 changes: 2 additions & 2 deletions third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,9 @@ struct AddSPIRVEnvPattern : public mlir::OpRewritePattern<ModuleOp> {
/// block pointers or not.
class TritonGPUToLLVMPipelineManager {
public:
TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx)
TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx, bool isLTSDriver)
: mod(mod), ctx(ctx),
blockPtrPathIsEnabled(
blockPtrPathIsEnabled(!isLTSDriver &&
mlir::triton::tools::getBoolEnv("TRITON_INTEL_ENABLE_BLOCK_PTR")) {}

/// FIXME: remove once the block ptr conversion path is capable of handling
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ struct ConvertTritonGPUToLLVM
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();

intel::TritonGPUToLLVMPipelineManager pipelineManager(mod, context);
intel::TritonGPUToLLVMPipelineManager pipelineManager(mod, context, isLTSDriver);
mlir::LowerToLLVMOptions option(context);
option.overrideIndexBitwidth(32);
TritonIntelGPUToLLVMTypeConverter typeConverter(context, option);
TritonIntelGPUToLLVMTypeConverter typeConverter(context, option, isLTSDriver);
TritonLLVMConversionTarget convTarget(*context);
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
Expand All @@ -95,7 +95,7 @@ struct ConvertTritonGPUToLLVM
// Lower functions
{
mlir::LowerToLLVMOptions option(context);
TritonIntelGPUToLLVMTypeConverter typeConverter(context, option);
TritonIntelGPUToLLVMTypeConverter typeConverter(context, option, isLTSDriver);
TritonLLVMFunctionConversionTarget funcTarget(*context);
RewritePatternSet funcPatterns(context);
pipelineManager.populateFunctionConversionPatterns(
Expand Down
8 changes: 3 additions & 5 deletions third_party/intel/lib/TritonIntelGPUToLLVM/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,14 @@
#include "intel/include/TritonIntelGPUToLLVM/TypeConverter.h"
#include "triton/Tools/Sys/GetEnv.hpp"

using namespace mlir;
using namespace mlir::triton;

TritonIntelGPUToLLVMTypeConverter::TritonIntelGPUToLLVMTypeConverter(
MLIRContext *ctx, LowerToLLVMOptions &option,
MLIRContext *ctx, LowerToLLVMOptions &option, bool isLTSDriver,
const DataLayoutAnalysis *analysis)
: TritonGPUToLLVMTypeConverter(ctx, option, analysis) {
// Augment/overwrite type conversions required for the Intel conversion
// passes.
if (mlir::triton::tools::getBoolEnv("TRITON_INTEL_ENABLE_BLOCK_PTR")) {
if (!isLTSDriver &&
mlir::triton::tools::getBoolEnv("TRITON_INTEL_ENABLE_BLOCK_PTR")) {
// tt::pointer to v2i32.
addConversion([&](PointerType type) -> std::optional<Type> {
if (isa<RankedTensorType>(type.getPointeeType())) {
Expand Down
4 changes: 2 additions & 2 deletions third_party/intel/triton_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) {
.value("PVC", gpu::intel::DeviceArch::PVC)
.export_values();

ADD_PASS_WRAPPER_0("add_to_llvmir",
gpu::intel::createConvertTritonIntelGPUToLLVM);
ADD_PASS_WRAPPER_OPT_1("add_to_llvmir",
gpu::intel::createConvertTritonIntelGPUToLLVM, bool);
ADD_PASS_WRAPPER_0("add_accelerate_matmul",
gpu::intel::createTritonIntelGPUAccelerateMatmul);
ADD_PASS_WRAPPER_0("add_decompose_unsupported_conversions",
Expand Down

0 comments on commit 2f8ac03

Please sign in to comment.