From d131fb17ad2c398486725eaf8da1aa090a053bd7 Mon Sep 17 00:00:00 2001 From: shuailong616 <452509829@qq.com> Date: Wed, 28 May 2025 16:30:23 +0800 Subject: [PATCH 1/8] add expression restructuring pass --- .../triton/Dialect/Triton/Transforms/Passes.h | 1 + .../Dialect/Triton/Transforms/Passes.td | 11 ++ lib/Dialect/Triton/Transforms/CMakeLists.txt | 1 + .../Transforms/ExpressionRestructing.cpp | 175 ++++++++++++++++++ python/src/passes.cc | 1 + third_party/nvidia/backend/compiler.py | 1 + 6 files changed, 190 insertions(+) create mode 100644 lib/Dialect/Triton/Transforms/ExpressionRestructing.cpp diff --git a/include/triton/Dialect/Triton/Transforms/Passes.h b/include/triton/Dialect/Triton/Transforms/Passes.h index fde54fe17..3f00260ef 100644 --- a/include/triton/Dialect/Triton/Transforms/Passes.h +++ b/include/triton/Dialect/Triton/Transforms/Passes.h @@ -11,6 +11,7 @@ std::unique_ptr createCombineOpsPass(); std::unique_ptr createReorderBroadcastPass(); std::unique_ptr createRewriteTensorPointerPass(); +std::unique_ptr createExpressionRestructingPass(); } // namespace triton #define GEN_PASS_REGISTRATION diff --git a/include/triton/Dialect/Triton/Transforms/Passes.td b/include/triton/Dialect/Triton/Transforms/Passes.td index 4ebff63fa..23ff2a75e 100644 --- a/include/triton/Dialect/Triton/Transforms/Passes.td +++ b/include/triton/Dialect/Triton/Transforms/Passes.td @@ -41,4 +41,15 @@ def TritonRewriteTensorPointer : Pass { + let summary = "ExpressionRestructing"; + let description = [{ + transform a = b / c; d = a /e; to a= c * e; d = b / a; + }]; + + let constructor = "mlir::triton::createExpressionRestructingPass()"; + + let dependentDialects = ["mlir::triton::TritonDialect", "mlir::arith::ArithDialect"]; +} + #endif diff --git a/lib/Dialect/Triton/Transforms/CMakeLists.txt b/lib/Dialect/Triton/Transforms/CMakeLists.txt index 298398750..cbaa66c45 100644 --- a/lib/Dialect/Triton/Transforms/CMakeLists.txt +++ b/lib/Dialect/Triton/Transforms/CMakeLists.txt @@ -6,6 +6,7 @@ add_triton_library(TritonTransforms Combine.cpp ReorderBroadcast.cpp RewriteTensorPointer.cpp + ExpressionRestructing.cpp DEPENDS TritonTransformsIncGen diff --git a/lib/Dialect/Triton/Transforms/ExpressionRestructing.cpp b/lib/Dialect/Triton/Transforms/ExpressionRestructing.cpp new file mode 100644 index 000000000..7fa92556a --- /dev/null +++ b/lib/Dialect/Triton/Transforms/ExpressionRestructing.cpp @@ -0,0 +1,175 @@ +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + + +#define GEN_PASS_CLASSES +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +using namespace mlir; +using llvm::ArrayRef; +namespace mlir::triton{ + + +struct Div2Mul : public OpRewritePattern{ + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::DivFOp op, PatternRewriter &rewriter) const override{ + Value result = op.getResult(); + Value l = op.getLhs(); + Value r = op.getRhs(); + auto loc = op.getLoc(); + + if (!result.hasOneUse()) + return failure(); + for (auto &use : result.getUses()){ + if(!dyn_cast(use.getOwner())) + return failure(); + auto DivUser = dyn_cast(use.getOwner()); + if(DivUser.getLhs()!= op.getResult()) + return failure(); + auto 
originalInsertionPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfter(DivUser); + auto loc_div = DivUser.getLoc(); + auto product = rewriter.create(loc_div, r, DivUser.getRhs()); + rewriter.setInsertionPointAfter(product); + auto ResultEnd = rewriter.create(loc_div, l, product.getResult()); + rewriter.restoreInsertionPoint(originalInsertionPoint); + rewriter.replaceOp(op, product.getResult()); + DivUser.replaceAllUsesWith(ResultEnd.getResult()); + rewriter.eraseOp(DivUser); + } + return success(); + } +}; + +struct Mul2Mul : public OpRewritePattern{ + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::MulFOp op, PatternRewriter &rewriter) const override{ + Value result = op.getResult(); + Value l = op.getLhs(); + Value r = op.getRhs(); + auto loc = op.getLoc(); + + if (!result.hasOneUse()) + return failure(); + for (auto &use : result.getUses()){ + if(!dyn_cast(use.getOwner())) + return failure(); + auto MulUser = dyn_cast(use.getOwner()); + if(!(MulUser.getLhs() == op.getResult() && ((MulUser.getRhs().getDefiningOp()&& r.getDefiningOp())||(r == MulUser.getRhs())))) + return failure(); + auto originalInsertionPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfter(MulUser); + auto loc_mul = MulUser.getLoc(); + auto product = rewriter.create(loc_mul, r, MulUser.getRhs()); + rewriter.setInsertionPointAfter(product); + auto ResultEnd = rewriter.create(loc_mul, l, product.getResult()); + rewriter.restoreInsertionPoint(originalInsertionPoint); + rewriter.replaceOp(op, product.getResult()); + MulUser.replaceAllUsesWith(ResultEnd.getResult()); + rewriter.eraseOp(MulUser); + } + return success(); + } +}; + +struct Add2Add : public OpRewritePattern{ + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::AddFOp op, PatternRewriter &rewriter) const override{ + Value result = op.getResult(); + Value l = op.getLhs(); + Value r = op.getRhs(); + auto loc = op.getLoc(); + + if (!result.hasOneUse()) + return failure(); + for (auto &use : result.getUses()){ + if(!dyn_cast(use.getOwner())) + return failure(); + auto AddUser = dyn_cast(use.getOwner()); + if(!(AddUser.getLhs() == op.getResult() && ((AddUser.getRhs().getDefiningOp()&& r.getDefiningOp())||(r == AddUser.getRhs())))) + return failure(); + auto originalInsertionPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfter(AddUser); + auto loc_add = AddUser.getLoc(); + auto sum = rewriter.create(loc_add, r, AddUser.getRhs()); + rewriter.setInsertionPointAfter(sum); + auto ResultEnd = rewriter.create(loc_add, l, sum.getResult()); + rewriter.restoreInsertionPoint(originalInsertionPoint); + rewriter.replaceOp(op, sum.getResult()); + AddUser.replaceAllUsesWith(ResultEnd.getResult()); + rewriter.eraseOp(AddUser); + } + return success(); + } +}; + +struct Sub2Add : public OpRewritePattern{ + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::SubFOp op, PatternRewriter &rewriter) const override{ + Value result = op.getResult(); + Value l = op.getLhs(); + Value r = op.getRhs(); + auto loc = op.getLoc(); + + if (!result.hasOneUse()) + return failure(); + for (auto &use : result.getUses()){ + if(!dyn_cast(use.getOwner())) + return failure(); + auto SubUser = dyn_cast(use.getOwner()); + if(!(SubUser.getLhs() == op.getResult() && ((SubUser.getRhs().getDefiningOp()&& r.getDefiningOp())||(r == SubUser.getRhs())))) + return failure(); + auto originalInsertionPoint = rewriter.saveInsertionPoint(); + 
rewriter.setInsertionPointAfter(SubUser); + auto loc_sub = SubUser.getLoc(); + auto sum = rewriter.create(loc_sub, r, SubUser.getRhs()); + rewriter.setInsertionPointAfter(sum); + auto ResultEnd = rewriter.create(loc_sub, l, sum.getResult()); + rewriter.restoreInsertionPoint(originalInsertionPoint); + rewriter.replaceOp(op, sum.getResult()); + SubUser.replaceAllUsesWith(ResultEnd.getResult()); + rewriter.eraseOp(SubUser); + } + return success(); + } +}; + +class ExpressionRestructingPass : public TritonExpressionRestructingBase{ +public: + void runOnOperation() override{ + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ModuleOp m = getOperation(); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + + +std::unique_ptr createExpressionRestructingPass(){ + return std::make_unique(); +} + +} + + + diff --git a/python/src/passes.cc b/python/src/passes.cc index 513e811d2..eba1c3e9d 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -37,6 +37,7 @@ void init_triton_passes_ttir(py::module &&m) { using namespace mlir::triton; ADD_PASS_WRAPPER_0("add_combine", createCombineOpsPass); ADD_PASS_WRAPPER_0("add_reorder_broadcast", createReorderBroadcastPass); + ADD_PASS_WRAPPER_0("add_expression_restructing", createExpressionRestructingPass); ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer", createRewriteTensorPointerPass); ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir", diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 6d7994923..24a80484f 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -144,6 +144,7 @@ def make_ttir(mod, metadata, opt): passes.common.add_canonicalizer(pm) passes.ttir.add_reorder_broadcast(pm) passes.common.add_cse(pm) + passes.ttir.add_expression_restructing(pm) passes.common.add_licm(pm) passes.common.add_symbol_dce(pm) pm.run(mod) From 069162d78ec004b210c8a925ee7c6a72e350c7b1 Mon Sep 17 00:00:00 2001 From: AdvancedCompiler Date: Thu, 29 May 2025 09:59:17 +0800 Subject: [PATCH 2/8] code format --- .../Dialect/Triton/Transforms/Passes.td | 5 +- lib/Dialect/Triton/Transforms/CMakeLists.txt | 1 + .../Transforms/ExpressionRestructing.cpp | 294 +++++++++--------- python/src/passes.cc | 3 +- third_party/nvidia/backend/compiler.py | 151 ++++++--- 5 files changed, 261 insertions(+), 193 deletions(-) diff --git a/include/triton/Dialect/Triton/Transforms/Passes.td b/include/triton/Dialect/Triton/Transforms/Passes.td index 23ff2a75e..165b779b6 100644 --- a/include/triton/Dialect/Triton/Transforms/Passes.td +++ b/include/triton/Dialect/Triton/Transforms/Passes.td @@ -44,7 +44,10 @@ def TritonRewriteTensorPointer : Pass { let summary = "ExpressionRestructing"; let description = [{ - transform a = b / c; d = a /e; to a= c * e; d = b / a; + transform a = b / c; d = a / e; to a = c * e; d = b / a; + transform a = b + c; d = a + c; to a = c + c; d = b + a; + transform a = b - c; d = a - c; to a = c + c; d = b - a; + transform a = b * c; d = a * c; to a = c * c; d = b * a; }]; let constructor = "mlir::triton::createExpressionRestructingPass()"; diff --git a/lib/Dialect/Triton/Transforms/CMakeLists.txt b/lib/Dialect/Triton/Transforms/CMakeLists.txt index cbaa66c45..444644a3f 100644 --- a/lib/Dialect/Triton/Transforms/CMakeLists.txt +++ b/lib/Dialect/Triton/Transforms/CMakeLists.txt @@ -8,6 +8,7 @@ 
add_triton_library(TritonTransforms RewriteTensorPointer.cpp ExpressionRestructing.cpp + DEPENDS TritonTransformsIncGen TritonCombineIncGen diff --git a/lib/Dialect/Triton/Transforms/ExpressionRestructing.cpp b/lib/Dialect/Triton/Transforms/ExpressionRestructing.cpp index 7fa92556a..74681d389 100644 --- a/lib/Dialect/Triton/Transforms/ExpressionRestructing.cpp +++ b/lib/Dialect/Triton/Transforms/ExpressionRestructing.cpp @@ -11,165 +11,179 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/Transforms/Passes.h" - #define GEN_PASS_CLASSES #include "triton/Dialect/Triton/Transforms/Passes.h.inc" using namespace mlir; using llvm::ArrayRef; -namespace mlir::triton{ - - -struct Div2Mul : public OpRewritePattern{ - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(arith::DivFOp op, PatternRewriter &rewriter) const override{ - Value result = op.getResult(); - Value l = op.getLhs(); - Value r = op.getRhs(); - auto loc = op.getLoc(); - - if (!result.hasOneUse()) - return failure(); - for (auto &use : result.getUses()){ - if(!dyn_cast(use.getOwner())) - return failure(); - auto DivUser = dyn_cast(use.getOwner()); - if(DivUser.getLhs()!= op.getResult()) - return failure(); - auto originalInsertionPoint = rewriter.saveInsertionPoint(); - rewriter.setInsertionPointAfter(DivUser); - auto loc_div = DivUser.getLoc(); - auto product = rewriter.create(loc_div, r, DivUser.getRhs()); - rewriter.setInsertionPointAfter(product); - auto ResultEnd = rewriter.create(loc_div, l, product.getResult()); - rewriter.restoreInsertionPoint(originalInsertionPoint); - rewriter.replaceOp(op, product.getResult()); - DivUser.replaceAllUsesWith(ResultEnd.getResult()); - rewriter.eraseOp(DivUser); - } - return success(); +namespace mlir::triton { + +struct Div2Mul : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::DivFOp op, + PatternRewriter &rewriter) const override { + Value result = op.getResult(); + Value l = op.getLhs(); + Value r = op.getRhs(); + auto loc = op.getLoc(); + + if (!result.hasOneUse()) + return failure(); + for (auto &use : result.getUses()) { + if (!dyn_cast(use.getOwner())) + return failure(); + auto DivUser = dyn_cast(use.getOwner()); + if (DivUser.getLhs() != op.getResult()) + return failure(); + auto originalInsertionPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfter(DivUser); + auto loc_div = DivUser.getLoc(); + auto product = + rewriter.create(loc_div, r, DivUser.getRhs()); + rewriter.setInsertionPointAfter(product); + auto ResultEnd = + rewriter.create(loc_div, l, product.getResult()); + rewriter.restoreInsertionPoint(originalInsertionPoint); + rewriter.replaceOp(op, product.getResult()); + DivUser.replaceAllUsesWith(ResultEnd.getResult()); + rewriter.eraseOp(DivUser); } + return success(); + } }; -struct Mul2Mul : public OpRewritePattern{ - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(arith::MulFOp op, PatternRewriter &rewriter) const override{ - Value result = op.getResult(); - Value l = op.getLhs(); - Value r = op.getRhs(); - auto loc = op.getLoc(); - - if (!result.hasOneUse()) - return failure(); - for (auto &use : result.getUses()){ - if(!dyn_cast(use.getOwner())) - return failure(); - auto MulUser = dyn_cast(use.getOwner()); - if(!(MulUser.getLhs() == op.getResult() && ((MulUser.getRhs().getDefiningOp()&& r.getDefiningOp())||(r == MulUser.getRhs())))) - return failure(); - auto originalInsertionPoint = 
rewriter.saveInsertionPoint(); - rewriter.setInsertionPointAfter(MulUser); - auto loc_mul = MulUser.getLoc(); - auto product = rewriter.create(loc_mul, r, MulUser.getRhs()); - rewriter.setInsertionPointAfter(product); - auto ResultEnd = rewriter.create(loc_mul, l, product.getResult()); - rewriter.restoreInsertionPoint(originalInsertionPoint); - rewriter.replaceOp(op, product.getResult()); - MulUser.replaceAllUsesWith(ResultEnd.getResult()); - rewriter.eraseOp(MulUser); - } - return success(); +struct Mul2Mul : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::MulFOp op, + PatternRewriter &rewriter) const override { + Value result = op.getResult(); + Value l = op.getLhs(); + Value r = op.getRhs(); + auto loc = op.getLoc(); + + if (!result.hasOneUse()) + return failure(); + for (auto &use : result.getUses()) { + if (!dyn_cast(use.getOwner())) + return failure(); + auto MulUser = dyn_cast(use.getOwner()); + if (!(MulUser.getLhs() == op.getResult() && + ((MulUser.getRhs().getDefiningOp() && + r.getDefiningOp()) || + (r == MulUser.getRhs())))) + return failure(); + auto originalInsertionPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfter(MulUser); + auto loc_mul = MulUser.getLoc(); + auto product = + rewriter.create(loc_mul, r, MulUser.getRhs()); + rewriter.setInsertionPointAfter(product); + auto ResultEnd = + rewriter.create(loc_mul, l, product.getResult()); + rewriter.restoreInsertionPoint(originalInsertionPoint); + rewriter.replaceOp(op, product.getResult()); + MulUser.replaceAllUsesWith(ResultEnd.getResult()); + rewriter.eraseOp(MulUser); } + return success(); + } }; -struct Add2Add : public OpRewritePattern{ - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(arith::AddFOp op, PatternRewriter &rewriter) const override{ - Value result = op.getResult(); - Value l = op.getLhs(); - Value r = op.getRhs(); - auto loc = op.getLoc(); - - if (!result.hasOneUse()) - return failure(); - for (auto &use : result.getUses()){ - if(!dyn_cast(use.getOwner())) - return failure(); - auto AddUser = dyn_cast(use.getOwner()); - if(!(AddUser.getLhs() == op.getResult() && ((AddUser.getRhs().getDefiningOp()&& r.getDefiningOp())||(r == AddUser.getRhs())))) - return failure(); - auto originalInsertionPoint = rewriter.saveInsertionPoint(); - rewriter.setInsertionPointAfter(AddUser); - auto loc_add = AddUser.getLoc(); - auto sum = rewriter.create(loc_add, r, AddUser.getRhs()); - rewriter.setInsertionPointAfter(sum); - auto ResultEnd = rewriter.create(loc_add, l, sum.getResult()); - rewriter.restoreInsertionPoint(originalInsertionPoint); - rewriter.replaceOp(op, sum.getResult()); - AddUser.replaceAllUsesWith(ResultEnd.getResult()); - rewriter.eraseOp(AddUser); - } - return success(); +struct Add2Add : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::AddFOp op, + PatternRewriter &rewriter) const override { + Value result = op.getResult(); + Value l = op.getLhs(); + Value r = op.getRhs(); + auto loc = op.getLoc(); + + if (!result.hasOneUse()) + return failure(); + for (auto &use : result.getUses()) { + if (!dyn_cast(use.getOwner())) + return failure(); + auto AddUser = dyn_cast(use.getOwner()); + if (!(AddUser.getLhs() == op.getResult() && + ((AddUser.getRhs().getDefiningOp() && + r.getDefiningOp()) || + (r == AddUser.getRhs())))) + return failure(); + auto originalInsertionPoint = rewriter.saveInsertionPoint(); + 
rewriter.setInsertionPointAfter(AddUser); + auto loc_add = AddUser.getLoc(); + auto sum = rewriter.create(loc_add, r, AddUser.getRhs()); + rewriter.setInsertionPointAfter(sum); + auto ResultEnd = + rewriter.create(loc_add, l, sum.getResult()); + rewriter.restoreInsertionPoint(originalInsertionPoint); + rewriter.replaceOp(op, sum.getResult()); + AddUser.replaceAllUsesWith(ResultEnd.getResult()); + rewriter.eraseOp(AddUser); } + return success(); + } }; -struct Sub2Add : public OpRewritePattern{ - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(arith::SubFOp op, PatternRewriter &rewriter) const override{ - Value result = op.getResult(); - Value l = op.getLhs(); - Value r = op.getRhs(); - auto loc = op.getLoc(); - - if (!result.hasOneUse()) - return failure(); - for (auto &use : result.getUses()){ - if(!dyn_cast(use.getOwner())) - return failure(); - auto SubUser = dyn_cast(use.getOwner()); - if(!(SubUser.getLhs() == op.getResult() && ((SubUser.getRhs().getDefiningOp()&& r.getDefiningOp())||(r == SubUser.getRhs())))) - return failure(); - auto originalInsertionPoint = rewriter.saveInsertionPoint(); - rewriter.setInsertionPointAfter(SubUser); - auto loc_sub = SubUser.getLoc(); - auto sum = rewriter.create(loc_sub, r, SubUser.getRhs()); - rewriter.setInsertionPointAfter(sum); - auto ResultEnd = rewriter.create(loc_sub, l, sum.getResult()); - rewriter.restoreInsertionPoint(originalInsertionPoint); - rewriter.replaceOp(op, sum.getResult()); - SubUser.replaceAllUsesWith(ResultEnd.getResult()); - rewriter.eraseOp(SubUser); - } - return success(); +struct Sub2Add : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::SubFOp op, + PatternRewriter &rewriter) const override { + Value result = op.getResult(); + Value l = op.getLhs(); + Value r = op.getRhs(); + auto loc = op.getLoc(); + + if (!result.hasOneUse()) + return failure(); + for (auto &use : result.getUses()) { + if (!dyn_cast(use.getOwner())) + return failure(); + auto SubUser = dyn_cast(use.getOwner()); + if (!(SubUser.getLhs() == op.getResult() && + ((SubUser.getRhs().getDefiningOp() && + r.getDefiningOp()) || + (r == SubUser.getRhs())))) + return failure(); + auto originalInsertionPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfter(SubUser); + auto loc_sub = SubUser.getLoc(); + auto sum = rewriter.create(loc_sub, r, SubUser.getRhs()); + rewriter.setInsertionPointAfter(sum); + auto ResultEnd = + rewriter.create(loc_sub, l, sum.getResult()); + rewriter.restoreInsertionPoint(originalInsertionPoint); + rewriter.replaceOp(op, sum.getResult()); + SubUser.replaceAllUsesWith(ResultEnd.getResult()); + rewriter.eraseOp(SubUser); } + return success(); + } }; -class ExpressionRestructingPass : public TritonExpressionRestructingBase{ +class ExpressionRestructingPass + : public TritonExpressionRestructingBase { public: - void runOnOperation() override{ - MLIRContext *context = &getContext(); - RewritePatternSet patterns(context); - ModuleOp m = getOperation(); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - - if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) - signalPassFailure(); - } + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ModuleOp m = getOperation(); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + + if (applyPatternsAndFoldGreedily(m, 
std::move(patterns)).failed()) + signalPassFailure(); + } }; - -std::unique_ptr createExpressionRestructingPass(){ - return std::make_unique(); +std::unique_ptr createExpressionRestructingPass() { + return std::make_unique(); } -} - - - +} // namespace mlir::triton diff --git a/python/src/passes.cc b/python/src/passes.cc index eba1c3e9d..263b20dae 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -37,7 +37,8 @@ void init_triton_passes_ttir(py::module &&m) { using namespace mlir::triton; ADD_PASS_WRAPPER_0("add_combine", createCombineOpsPass); ADD_PASS_WRAPPER_0("add_reorder_broadcast", createReorderBroadcastPass); - ADD_PASS_WRAPPER_0("add_expression_restructing", createExpressionRestructingPass); + ADD_PASS_WRAPPER_0("add_expression_restructing", + createExpressionRestructingPass); ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer", createRewriteTensorPointerPass); ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir", diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 24a80484f..25a86900f 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -1,16 +1,16 @@ -from triton.backends.compiler import BaseBackend, GPUTarget -from triton._C.libtriton import ir, passes, llvm, nvidia - -from dataclasses import dataclass import functools -from typing import Any, Tuple, Optional import hashlib +import os import re -import tempfile import signal -import os import subprocess +import tempfile +from dataclasses import dataclass from pathlib import Path +from typing import Any, Optional, Tuple + +from triton._C.libtriton import ir, llvm, nvidia, passes +from triton.backends.compiler import BaseBackend, GPUTarget @functools.lru_cache() @@ -22,9 +22,15 @@ def _path_to_binary(binary: str): for bin in paths: if os.path.exists(bin) and os.path.isfile(bin): - result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT) + result = subprocess.check_output( + [bin, "--version"], stderr=subprocess.STDOUT + ) if result is not None: - version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) + version = re.search( + r".*release (\d+\.\d+).*", + result.decode("utf-8"), + flags=re.MULTILINE, + ) if version is not None: return bin, version.group(1) raise RuntimeError(f"Cannot find {binary}") @@ -32,17 +38,19 @@ def _path_to_binary(binary: str): @functools.lru_cache() def get_ptxas_version(): - version = subprocess.check_output([_path_to_binary("ptxas")[0], "--version"]).decode("utf-8") + version = subprocess.check_output( + [_path_to_binary("ptxas")[0], "--version"] + ).decode("utf-8") return version @functools.lru_cache() def ptx_get_version(cuda_version) -> int: - ''' + """ Get the highest PTX version supported by the current CUDA driver. 
- ''' + """ assert isinstance(cuda_version, str) - major, minor = map(int, cuda_version.split('.')) + major, minor = map(int, cuda_version.split(".")) if major == 12: return 80 + minor if major == 11: @@ -76,29 +84,33 @@ class CUDAOptions: max_num_imprecise_acc_default: bool = None extern_libs: dict = None debug: bool = False - backend_name: str = 'cuda' + backend_name: str = "cuda" def __post_init__(self): - default_libdir = Path(__file__).parent / 'lib' + default_libdir = Path(__file__).parent / "lib" extern_libs = {} if self.extern_libs is None else dict(self.extern_libs) - if not extern_libs.get('libdevice', None): - extern_libs['libdevice'] = os.getenv("TRITON_LIBDEVICE_PATH", str(default_libdir / 'libdevice.10.bc')) - object.__setattr__(self, 'extern_libs', tuple(extern_libs.items())) - assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \ - "num_warps must be a power of 2" + if not extern_libs.get("libdevice", None): + extern_libs["libdevice"] = os.getenv( + "TRITON_LIBDEVICE_PATH", str(default_libdir / "libdevice.10.bc") + ) + object.__setattr__(self, "extern_libs", tuple(extern_libs.items())) + assert ( + self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0 + ), "num_warps must be a power of 2" def hash(self): hash_dict = dict(self.__dict__) - hash_dict["extern_libs"] = tuple((k, file_hash(v)) for k, v in sorted(hash_dict["extern_libs"])) + hash_dict["extern_libs"] = tuple( + (k, file_hash(v)) for k, v in sorted(hash_dict["extern_libs"]) + ) key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())]) return hashlib.sha256(key.encode("utf-8")).hexdigest() class CUDABackend(BaseBackend): - @staticmethod def supports_target(target: GPUTarget): - return target.backend == 'cuda' + return target.backend == "cuda" def __init__(self, target: GPUTarget) -> None: super().__init__(target) @@ -107,7 +119,9 @@ def __init__(self, target: GPUTarget) -> None: self.binary_ext = "cubin" def parse_options(self, opts) -> Any: - args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts} + args = { + k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts + } args["allow_fp8e4nv"] = self.capability >= 89 args["allow_fp8e4b15"] = self.capability < 90 args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0 @@ -125,9 +139,11 @@ def pack_metadata(self, metadata): def get_codegen_implementation(self): import triton.language.extra.cuda as cuda + codegen_fns = { - "convert_custom_types": - cuda.convert_custom_float8_sm80 if self.capability >= 80 else cuda.convert_custom_float8_sm70 + "convert_custom_types": cuda.convert_custom_float8_sm80 + if self.capability >= 80 + else cuda.convert_custom_float8_sm70 } return codegen_fns @@ -144,6 +160,7 @@ def make_ttir(mod, metadata, opt): passes.common.add_canonicalizer(pm) passes.ttir.add_reorder_broadcast(pm) passes.common.add_cse(pm) + passes.ttir.add_expression_restructing(pm) passes.common.add_licm(pm) passes.common.add_symbol_dce(pm) @@ -160,7 +177,9 @@ def make_ttgir(mod, metadata, opt, capability): # TTIR -> TTGIR pm = ir.pass_manager(mod.context) pm.enable_debug() - passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas) + passes.ttir.add_convert_to_ttgpuir( + pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas + ) # optimize TTGIR passes.ttgpuir.add_coalesce(pm) if capability // 10 >= 8: @@ -188,7 +207,11 @@ def make_ttgir(mod, metadata, opt, capability): nvidia.passes.ttnvgpuir.add_tma_lowering(pm) 
passes.common.add_canonicalizer(pm) pm.run(mod) - metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ) + metadata["cluster_dims"] = ( + cluster_info.clusterDimX, + cluster_info.clusterDimY, + cluster_info.clusterDimZ, + ) return mod @staticmethod @@ -254,17 +277,27 @@ def make_ptx(src, metadata, opt, capability): # like "+ptx8.4 is not a recognized feature for this target". llvm_ptx_version = min(83, ptx_version) - triple = 'nvptx64-nvidia-cuda' - proc = 'sm_90a' if capability == 90 else f'sm_{capability}' - features = f'+ptx{llvm_ptx_version}' - ret = llvm.translate_to_asm(src, triple, proc, features, ['nvptx-short-ptr'], opt.enable_fp_fusion, False) + triple = "nvptx64-nvidia-cuda" + proc = "sm_90a" if capability == 90 else f"sm_{capability}" + features = f"+ptx{llvm_ptx_version}" + ret = llvm.translate_to_asm( + src, + triple, + proc, + features, + ["nvptx-short-ptr"], + opt.enable_fp_fusion, + False, + ) # Find kernel names (there should only be one) names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret) assert len(names) == 1 metadata["name"] = names[0] # post-process - ptx_version = f'{ptx_version//10}.{ptx_version%10}' - ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE) + ptx_version = f"{ptx_version//10}.{ptx_version%10}" + ret = re.sub( + r"\.version \d+\.\d+", f".version {ptx_version}", ret, flags=re.MULTILINE + ) # Remove the debug flag that prevents ptxas from optimizing the code ret = re.sub(r",\s*debug|debug,\s*", "", ret) if os.environ.get("NVPTX_ENABLE_DUMP", "0") == "1": @@ -275,19 +308,24 @@ def make_ptx(src, metadata, opt, capability): @staticmethod def make_cubin(src, metadata, opt, capability): ptxas, _ = _path_to_binary("ptxas") - with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \ - tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog: + with tempfile.NamedTemporaryFile( + delete=False, mode="w", suffix=".ptx" + ) as fsrc, tempfile.NamedTemporaryFile( + delete=False, mode="r", suffix=".log" + ) as flog: fsrc.write(src) fsrc.flush() - fbin = fsrc.name + '.o' + fbin = fsrc.name + ".o" - line_info = '' if os.environ.get('TRITON_DISABLE_LINE_INFO') else ' -lineinfo' - fmad = '' if opt.enable_fp_fusion else ' --fmad=false' - suffix = 'a ' if capability == 90 else ' ' + line_info = ( + "" if os.environ.get("TRITON_DISABLE_LINE_INFO") else " -lineinfo" + ) + fmad = "" if opt.enable_fp_fusion else " --fmad=false" + suffix = "a " if capability == 90 else " " if os.environ.get("DISABLE_PTXAS_OPT", "0") == "1": - cmd = f'{ptxas}{line_info}{fmad} -v --opt-level 0 --gpu-name=sm_{capability}{suffix}{fsrc.name} -o {fbin} 2> {flog.name}' + cmd = f"{ptxas}{line_info}{fmad} -v --opt-level 0 --gpu-name=sm_{capability}{suffix}{fsrc.name} -o {fbin} 2> {flog.name}" else: - cmd = f'{ptxas}{line_info}{fmad} -v --gpu-name=sm_{capability}{suffix}{fsrc.name} -o {fbin} 2> {flog.name}' + cmd = f"{ptxas}{line_info}{fmad} -v --gpu-name=sm_{capability}{suffix}{fsrc.name} -o {fbin} 2> {flog.name}" try: subprocess.run(cmd, shell=True, check=True) @@ -295,19 +333,22 @@ def make_cubin(src, metadata, opt, capability): with open(flog.name) as log_file: log = log_file.read() if e.returncode == 255: - raise RuntimeError(f'Internal Triton PTX codegen error: \n{log}') + raise RuntimeError(f"Internal Triton PTX codegen error: \n{log}") elif e.returncode == 128 + signal.SIGSEGV: raise RuntimeError( - f'Please run `ptxas {fsrc.name}` to confirm that this 
is a bug in `ptxas`\n{log}') + f"Please run `ptxas {fsrc.name}` to confirm that this is a bug in `ptxas`\n{log}" + ) else: - raise RuntimeError(f'`ptxas` failed with error code {e.returncode}: \n{log}') + raise RuntimeError( + f"`ptxas` failed with error code {e.returncode}: \n{log}" + ) finally: if os.path.exists(fsrc.name): os.remove(fsrc.name) if os.path.exists(flog.name): os.remove(flog.name) - with open(fbin, 'rb') as f: + with open(fbin, "rb") as f: cubin = f.read() if os.path.exists(fbin): os.remove(fbin) @@ -315,12 +356,20 @@ def make_cubin(src, metadata, opt, capability): def add_stages(self, stages, options): stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) - stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability) - stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability) - stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.capability) - stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.capability) + stages["ttgir"] = lambda src, metadata: self.make_ttgir( + src, metadata, options, self.capability + ) + stages["llir"] = lambda src, metadata: self.make_llir( + src, metadata, options, self.capability + ) + stages["ptx"] = lambda src, metadata: self.make_ptx( + src, metadata, options, self.capability + ) + stages["cubin"] = lambda src, metadata: self.make_cubin( + src, metadata, options, self.capability + ) @functools.lru_cache() def hash(self): version = get_ptxas_version() - return f'{version}-{self.capability}' + return f"{version}-{self.capability}" From a753b8bedd31d3019dfbfb2449807c79360b2b54 Mon Sep 17 00:00:00 2001 From: AdvancedCompiler Date: Thu, 29 May 2025 10:55:38 +0800 Subject: [PATCH 3/8] code format new --- third_party/nvidia/backend/compiler.py | 152 +++++++++---------------- 1 file changed, 52 insertions(+), 100 deletions(-) diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 25a86900f..cb75f2a91 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -1,16 +1,16 @@ +from triton.backends.compiler import BaseBackend, GPUTarget +from triton._C.libtriton import ir, passes, llvm, nvidia + +from dataclasses import dataclass import functools +from typing import Any, Tuple, Optional import hashlib -import os import re +import tempfile import signal +import os import subprocess -import tempfile -from dataclasses import dataclass from pathlib import Path -from typing import Any, Optional, Tuple - -from triton._C.libtriton import ir, llvm, nvidia, passes -from triton.backends.compiler import BaseBackend, GPUTarget @functools.lru_cache() @@ -22,15 +22,9 @@ def _path_to_binary(binary: str): for bin in paths: if os.path.exists(bin) and os.path.isfile(bin): - result = subprocess.check_output( - [bin, "--version"], stderr=subprocess.STDOUT - ) + result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT) if result is not None: - version = re.search( - r".*release (\d+\.\d+).*", - result.decode("utf-8"), - flags=re.MULTILINE, - ) + version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) if version is not None: return bin, version.group(1) raise RuntimeError(f"Cannot find {binary}") @@ -38,19 +32,17 @@ def _path_to_binary(binary: str): @functools.lru_cache() def get_ptxas_version(): - version = subprocess.check_output( - [_path_to_binary("ptxas")[0], "--version"] - 
).decode("utf-8") + version = subprocess.check_output([_path_to_binary("ptxas")[0], "--version"]).decode("utf-8") return version @functools.lru_cache() def ptx_get_version(cuda_version) -> int: - """ + ''' Get the highest PTX version supported by the current CUDA driver. - """ + ''' assert isinstance(cuda_version, str) - major, minor = map(int, cuda_version.split(".")) + major, minor = map(int, cuda_version.split('.')) if major == 12: return 80 + minor if major == 11: @@ -84,33 +76,29 @@ class CUDAOptions: max_num_imprecise_acc_default: bool = None extern_libs: dict = None debug: bool = False - backend_name: str = "cuda" + backend_name: str = 'cuda' def __post_init__(self): - default_libdir = Path(__file__).parent / "lib" + default_libdir = Path(__file__).parent / 'lib' extern_libs = {} if self.extern_libs is None else dict(self.extern_libs) - if not extern_libs.get("libdevice", None): - extern_libs["libdevice"] = os.getenv( - "TRITON_LIBDEVICE_PATH", str(default_libdir / "libdevice.10.bc") - ) - object.__setattr__(self, "extern_libs", tuple(extern_libs.items())) - assert ( - self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0 - ), "num_warps must be a power of 2" + if not extern_libs.get('libdevice', None): + extern_libs['libdevice'] = os.getenv("TRITON_LIBDEVICE_PATH", str(default_libdir / 'libdevice.10.bc')) + object.__setattr__(self, 'extern_libs', tuple(extern_libs.items())) + assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \ + "num_warps must be a power of 2" def hash(self): hash_dict = dict(self.__dict__) - hash_dict["extern_libs"] = tuple( - (k, file_hash(v)) for k, v in sorted(hash_dict["extern_libs"]) - ) + hash_dict["extern_libs"] = tuple((k, file_hash(v)) for k, v in sorted(hash_dict["extern_libs"])) key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())]) return hashlib.sha256(key.encode("utf-8")).hexdigest() class CUDABackend(BaseBackend): + @staticmethod def supports_target(target: GPUTarget): - return target.backend == "cuda" + return target.backend == 'cuda' def __init__(self, target: GPUTarget) -> None: super().__init__(target) @@ -119,9 +107,7 @@ def __init__(self, target: GPUTarget) -> None: self.binary_ext = "cubin" def parse_options(self, opts) -> Any: - args = { - k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts - } + args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts} args["allow_fp8e4nv"] = self.capability >= 89 args["allow_fp8e4b15"] = self.capability < 90 args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0 @@ -139,11 +125,9 @@ def pack_metadata(self, metadata): def get_codegen_implementation(self): import triton.language.extra.cuda as cuda - codegen_fns = { - "convert_custom_types": cuda.convert_custom_float8_sm80 - if self.capability >= 80 - else cuda.convert_custom_float8_sm70 + "convert_custom_types": + cuda.convert_custom_float8_sm80 if self.capability >= 80 else cuda.convert_custom_float8_sm70 } return codegen_fns @@ -160,7 +144,6 @@ def make_ttir(mod, metadata, opt): passes.common.add_canonicalizer(pm) passes.ttir.add_reorder_broadcast(pm) passes.common.add_cse(pm) - passes.ttir.add_expression_restructing(pm) passes.common.add_licm(pm) passes.common.add_symbol_dce(pm) @@ -177,9 +160,7 @@ def make_ttgir(mod, metadata, opt, capability): # TTIR -> TTGIR pm = ir.pass_manager(mod.context) pm.enable_debug() - passes.ttir.add_convert_to_ttgpuir( - pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas - ) + 
passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas) # optimize TTGIR passes.ttgpuir.add_coalesce(pm) if capability // 10 >= 8: @@ -207,11 +188,7 @@ def make_ttgir(mod, metadata, opt, capability): nvidia.passes.ttnvgpuir.add_tma_lowering(pm) passes.common.add_canonicalizer(pm) pm.run(mod) - metadata["cluster_dims"] = ( - cluster_info.clusterDimX, - cluster_info.clusterDimY, - cluster_info.clusterDimZ, - ) + metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ) return mod @staticmethod @@ -277,27 +254,17 @@ def make_ptx(src, metadata, opt, capability): # like "+ptx8.4 is not a recognized feature for this target". llvm_ptx_version = min(83, ptx_version) - triple = "nvptx64-nvidia-cuda" - proc = "sm_90a" if capability == 90 else f"sm_{capability}" - features = f"+ptx{llvm_ptx_version}" - ret = llvm.translate_to_asm( - src, - triple, - proc, - features, - ["nvptx-short-ptr"], - opt.enable_fp_fusion, - False, - ) + triple = 'nvptx64-nvidia-cuda' + proc = 'sm_90a' if capability == 90 else f'sm_{capability}' + features = f'+ptx{llvm_ptx_version}' + ret = llvm.translate_to_asm(src, triple, proc, features, ['nvptx-short-ptr'], opt.enable_fp_fusion, False) # Find kernel names (there should only be one) names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret) assert len(names) == 1 metadata["name"] = names[0] # post-process - ptx_version = f"{ptx_version//10}.{ptx_version%10}" - ret = re.sub( - r"\.version \d+\.\d+", f".version {ptx_version}", ret, flags=re.MULTILINE - ) + ptx_version = f'{ptx_version//10}.{ptx_version%10}' + ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE) # Remove the debug flag that prevents ptxas from optimizing the code ret = re.sub(r",\s*debug|debug,\s*", "", ret) if os.environ.get("NVPTX_ENABLE_DUMP", "0") == "1": @@ -308,24 +275,19 @@ def make_ptx(src, metadata, opt, capability): @staticmethod def make_cubin(src, metadata, opt, capability): ptxas, _ = _path_to_binary("ptxas") - with tempfile.NamedTemporaryFile( - delete=False, mode="w", suffix=".ptx" - ) as fsrc, tempfile.NamedTemporaryFile( - delete=False, mode="r", suffix=".log" - ) as flog: + with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \ + tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog: fsrc.write(src) fsrc.flush() - fbin = fsrc.name + ".o" + fbin = fsrc.name + '.o' - line_info = ( - "" if os.environ.get("TRITON_DISABLE_LINE_INFO") else " -lineinfo" - ) - fmad = "" if opt.enable_fp_fusion else " --fmad=false" - suffix = "a " if capability == 90 else " " + line_info = '' if os.environ.get('TRITON_DISABLE_LINE_INFO') else ' -lineinfo' + fmad = '' if opt.enable_fp_fusion else ' --fmad=false' + suffix = 'a ' if capability == 90 else ' ' if os.environ.get("DISABLE_PTXAS_OPT", "0") == "1": - cmd = f"{ptxas}{line_info}{fmad} -v --opt-level 0 --gpu-name=sm_{capability}{suffix}{fsrc.name} -o {fbin} 2> {flog.name}" + cmd = f'{ptxas}{line_info}{fmad} -v --opt-level 0 --gpu-name=sm_{capability}{suffix}{fsrc.name} -o {fbin} 2> {flog.name}' else: - cmd = f"{ptxas}{line_info}{fmad} -v --gpu-name=sm_{capability}{suffix}{fsrc.name} -o {fbin} 2> {flog.name}" + cmd = f'{ptxas}{line_info}{fmad} -v --gpu-name=sm_{capability}{suffix}{fsrc.name} -o {fbin} 2> {flog.name}' try: subprocess.run(cmd, shell=True, check=True) @@ -333,22 +295,19 @@ def make_cubin(src, metadata, opt, capability): with open(flog.name) as log_file: log = 
log_file.read() if e.returncode == 255: - raise RuntimeError(f"Internal Triton PTX codegen error: \n{log}") + raise RuntimeError(f'Internal Triton PTX codegen error: \n{log}') elif e.returncode == 128 + signal.SIGSEGV: raise RuntimeError( - f"Please run `ptxas {fsrc.name}` to confirm that this is a bug in `ptxas`\n{log}" - ) + f'Please run `ptxas {fsrc.name}` to confirm that this is a bug in `ptxas`\n{log}') else: - raise RuntimeError( - f"`ptxas` failed with error code {e.returncode}: \n{log}" - ) + raise RuntimeError(f'`ptxas` failed with error code {e.returncode}: \n{log}') finally: if os.path.exists(fsrc.name): os.remove(fsrc.name) if os.path.exists(flog.name): os.remove(flog.name) - with open(fbin, "rb") as f: + with open(fbin, 'rb') as f: cubin = f.read() if os.path.exists(fbin): os.remove(fbin) @@ -356,20 +315,13 @@ def make_cubin(src, metadata, opt, capability): def add_stages(self, stages, options): stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) - stages["ttgir"] = lambda src, metadata: self.make_ttgir( - src, metadata, options, self.capability - ) - stages["llir"] = lambda src, metadata: self.make_llir( - src, metadata, options, self.capability - ) - stages["ptx"] = lambda src, metadata: self.make_ptx( - src, metadata, options, self.capability - ) - stages["cubin"] = lambda src, metadata: self.make_cubin( - src, metadata, options, self.capability - ) + stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability) + stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability) + stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.capability) + stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.capability) @functools.lru_cache() def hash(self): version = get_ptxas_version() - return f"{version}-{self.capability}" + return f'{version}-{self.capability}' + From 6c5906ead09c5021df4303b4193cbc72d767e13d Mon Sep 17 00:00:00 2001 From: AdvancedCompiler Date: Thu, 29 May 2025 11:01:10 +0800 Subject: [PATCH 4/8] Update compiler.py code format --- third_party/nvidia/backend/compiler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index cb75f2a91..24a80484f 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -324,4 +324,3 @@ def add_stages(self, stages, options): def hash(self): version = get_ptxas_version() return f'{version}-{self.capability}' - From fc0be8640932dc2bab3fcc27fcc57cb311b200e6 Mon Sep 17 00:00:00 2001 From: AdvancedCompiler Date: Fri, 30 May 2025 16:28:49 +0800 Subject: [PATCH 5/8] add test for expression restructuring pass --- .../test_expression_restructuring.py | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 python/test/operators/test_expression_restructuring.py diff --git a/python/test/operators/test_expression_restructuring.py b/python/test/operators/test_expression_restructuring.py new file mode 100644 index 000000000..49bebda3c --- /dev/null +++ b/python/test/operators/test_expression_restructuring.py @@ -0,0 +1,65 @@ +import triton +import triton.language as tl +import torch + +import pytest + +VEC_SHAPES = [[64, 640], [32,128], [128,256]] + +def custom_rand_strided(shape, strides, device, dtype, seed=0): + torch.manual_seed(seed) + total_size = sum((s - 1) * st for s, st in zip(shape, strides)) + 1 + storage = torch.randn(total_size, device=device, 
dtype=dtype) + return torch.as_strided(storage, size=shape, stride=strides) + +def torch_equivalent(arg_0, arg_1, arg_2, arg_3): + reshaped_arg_0 = arg_0.view(arg_2.shape[0], arg_2.shape[0], arg_2.shape[2]) + reshaped_arg_3 = arg_3.squeeze(-1) + tmp0 = -reshaped_arg_0 + tmp4 = arg_1 * arg_2 + tmp7 = reshaped_arg_3 + 1e-06 + tmp8 = tmp4 / tmp7.unsqueeze(-1) + tmp9 = tmp8 / tmp7.unsqueeze(-1) + result = tmp0 * tmp9 + return result + +@triton.jit +def expression_restructuring_function_test(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, rnumel): + XBLOCK: tl.constexpr = 1 + xoffset = tl.program_id(0) * XBLOCK + RBLOCK: tl.constexpr = 1024 + xindex = tl.full([1], xoffset, tl.int32) + rindex = tl.arange(0, RBLOCK)[:] + rmask = rindex < rnumel + r1 = rindex + x0 = xindex + tmp0 = tl.load(in_ptr0 + (r1 + (rnumel * x0)), rmask, other=0) + tmp2 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0) + tmp3 = tl.load(in_ptr2 + (r1 + (rnumel * x0)), rmask, other=0) + tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last') + tmp1 = -tmp0 + tmp4 = tmp2 * tmp3 + tmp6 = 1e-06 + tmp7 = tmp5 + tmp6 + tmp8 = tmp4 / tmp7 + tmp9 = tmp8 / tmp7 + tmp10 = tmp1 * tmp9 + tl.store(out_ptr2 + (r1 + (rnumel * x0)), tmp10, rmask) + +@pytest.mark.parametrize("vec_shape", VEC_SHAPES) +def test_accruacy_kernel(vec_shape): + x = vec_shape[0] + y = vec_shape[1] + arg_0 = custom_rand_strided((x * x, y), (y, 1), dtype=torch.float32,device='cuda') + arg_1 = custom_rand_strided((y,), (1,), dtype=torch.float32,device='cuda') + arg_2 = custom_rand_strided((x, x, y), (x * y, y, 1), dtype=torch.float32,device='cuda') + arg_3 = custom_rand_strided((x, x, 1), (x, 1, 1), dtype=torch.float32,device='cuda') + triton_result = custom_rand_strided((x, x, y), (x * y, y, 1), dtype=torch.float32,device='cuda') + grid = lambda meta: (x*x, ) + expression_restructuring_function_test[grid](arg_0, arg_1, arg_2, arg_3, triton_result, y) + torch_result = torch_equivalent(arg_0, arg_1, arg_2, arg_3) + torch.testing.assert_close(triton_result, torch_result) + + + + From 2a292b4685f709d16bdde0d8e7c9657152c065d3 Mon Sep 17 00:00:00 2001 From: AdvancedCompiler Date: Fri, 30 May 2025 16:44:50 +0800 Subject: [PATCH 6/8] code format --- .../test_expression_restructuring.py | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/python/test/operators/test_expression_restructuring.py b/python/test/operators/test_expression_restructuring.py index 49bebda3c..1ba8570ae 100644 --- a/python/test/operators/test_expression_restructuring.py +++ b/python/test/operators/test_expression_restructuring.py @@ -4,7 +4,8 @@ import pytest -VEC_SHAPES = [[64, 640], [32,128], [128,256]] +VEC_SHAPES = [[64, 640], [32, 128], [128, 256]] + def custom_rand_strided(shape, strides, device, dtype, seed=0): torch.manual_seed(seed) @@ -12,17 +13,19 @@ def custom_rand_strided(shape, strides, device, dtype, seed=0): storage = torch.randn(total_size, device=device, dtype=dtype) return torch.as_strided(storage, size=shape, stride=strides) + def torch_equivalent(arg_0, arg_1, arg_2, arg_3): - reshaped_arg_0 = arg_0.view(arg_2.shape[0], arg_2.shape[0], arg_2.shape[2]) - reshaped_arg_3 = arg_3.squeeze(-1) - tmp0 = -reshaped_arg_0 - tmp4 = arg_1 * arg_2 - tmp7 = reshaped_arg_3 + 1e-06 - tmp8 = tmp4 / tmp7.unsqueeze(-1) - tmp9 = tmp8 / tmp7.unsqueeze(-1) + reshaped_arg_0 = arg_0.view(arg_2.shape[0], arg_2.shape[0], arg_2.shape[2]) + reshaped_arg_3 = arg_3.squeeze(-1) + tmp0 = -reshaped_arg_0 + tmp4 = arg_1 * arg_2 + tmp7 = reshaped_arg_3 
+ 1e-06 + tmp8 = tmp4 / tmp7.unsqueeze(-1) + tmp9 = tmp8 / tmp7.unsqueeze(-1) result = tmp0 * tmp9 return result + @triton.jit def expression_restructuring_function_test(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, rnumel): XBLOCK: tl.constexpr = 1 @@ -50,16 +53,12 @@ def expression_restructuring_function_test(in_ptr0, in_ptr1, in_ptr2, in_ptr3, o def test_accruacy_kernel(vec_shape): x = vec_shape[0] y = vec_shape[1] - arg_0 = custom_rand_strided((x * x, y), (y, 1), dtype=torch.float32,device='cuda') - arg_1 = custom_rand_strided((y,), (1,), dtype=torch.float32,device='cuda') - arg_2 = custom_rand_strided((x, x, y), (x * y, y, 1), dtype=torch.float32,device='cuda') - arg_3 = custom_rand_strided((x, x, 1), (x, 1, 1), dtype=torch.float32,device='cuda') - triton_result = custom_rand_strided((x, x, y), (x * y, y, 1), dtype=torch.float32,device='cuda') - grid = lambda meta: (x*x, ) - expression_restructuring_function_test[grid](arg_0, arg_1, arg_2, arg_3, triton_result, y) + arg_0 = custom_rand_strided((x * x, y), (y, 1), dtype=torch.float32, device='cuda') + arg_1 = custom_rand_strided((y, ), (1, ), dtype=torch.float32, device='cuda') + arg_2 = custom_rand_strided((x, x, y), (x * y, y, 1), dtype=torch.float32, device='cuda') + arg_3 = custom_rand_strided((x, x, 1), (x, 1, 1), dtype=torch.float32, device='cuda') + triton_result = custom_rand_strided((x, x, y), (x * y, y, 1), dtype=torch.float32, device='cuda') + grid = lambda meta: (x * x, ) + expression_restructuring_function_test[grid](arg_0, arg_1, arg_2, arg_3, triton_result, y) torch_result = torch_equivalent(arg_0, arg_1, arg_2, arg_3) torch.testing.assert_close(triton_result, torch_result) - - - - From cf2ed7c314270b4d3f1f47590112a6466d603f7f Mon Sep 17 00:00:00 2001 From: AdvancedCompiler Date: Fri, 30 May 2025 16:54:04 +0800 Subject: [PATCH 7/8] Update test_expression_restructuring.py --- python/test/operators/test_expression_restructuring.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/test/operators/test_expression_restructuring.py b/python/test/operators/test_expression_restructuring.py index 1ba8570ae..c85312346 100644 --- a/python/test/operators/test_expression_restructuring.py +++ b/python/test/operators/test_expression_restructuring.py @@ -49,6 +49,7 @@ def expression_restructuring_function_test(in_ptr0, in_ptr1, in_ptr2, in_ptr3, o tmp10 = tmp1 * tmp9 tl.store(out_ptr2 + (r1 + (rnumel * x0)), tmp10, rmask) + @pytest.mark.parametrize("vec_shape", VEC_SHAPES) def test_accruacy_kernel(vec_shape): x = vec_shape[0] From d5130fdb137056359f590417c905389542b145ac Mon Sep 17 00:00:00 2001 From: zhengyang Date: Tue, 3 Jun 2025 10:31:33 +0800 Subject: [PATCH 8/8] [CI/CD] Add operators test to CI --- .github/workflows/nv-build-and-test.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/nv-build-and-test.yml b/.github/workflows/nv-build-and-test.yml index 85e76773f..6678754dd 100644 --- a/.github/workflows/nv-build-and-test.yml +++ b/.github/workflows/nv-build-and-test.yml @@ -20,14 +20,15 @@ jobs: - name: Checkout code uses: actions/checkout@v4 - - name: FlagTree Build on NVIDIA-A100 + - name: FlagTree Build shell: bash run: | source ~/env.sh cd python - MAX_JOBS=20 pip3.11 install . --no-build-isolation + MAX_JOBS=32 pip3.11 install . --no-build-isolation - - name: FlagTree Test on NVIDIA-A100 + - name: FlagTree Test shell: bash run: | pytest -s python/test/unit + pytest -s python/test/operators
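
Reviewer appendix (not part of any patch above): the Passes.td description summarizes the rewrites as "a = b / c; d = a / e ==> a = c * e; d = b / a" plus the analogous add/sub/mul forms. A quick way to convince yourself of the intent is to check the underlying reassociation identities in plain Python. The sketch below is illustrative only — the names check/b/c/e and the tolerance are mine, not part of the patches — and it uses a tolerance on purpose, since the rewrite is only approximately value-preserving for floating point, not bit-exact.

import math

def check(b: float, c: float, e: float, rel_tol: float = 1e-9) -> None:
    # Div2Mul: a = b / c; d = a / e   ->   a = c * e; d = b / a
    assert math.isclose((b / c) / e, b / (c * e), rel_tol=rel_tol)
    # Mul2Mul: a = b * c; d = a * e   ->   a = c * e; d = b * a
    assert math.isclose((b * c) * e, b * (c * e), rel_tol=rel_tol)
    # Add2Add: a = b + c; d = a + e   ->   a = c + e; d = b + a
    assert math.isclose((b + c) + e, b + (c + e), rel_tol=rel_tol)
    # Sub2Add: a = b - c; d = a - e   ->   a = c + e; d = b - a
    assert math.isclose((b - c) - e, b - (c + e), rel_tol=rel_tol)

check(3.7, 0.91, 2.4)

The practical payoff of the Div2Mul form is that the dependent chain b / c / e becomes one multiply plus one divide instead of two divides, which is exactly the shape the accuracy test above exercises (tmp8 = tmp4 / tmp7; tmp9 = tmp8 / tmp7). The other three patterns only fire when the trailing operand is the same Value on both ops, or when both trailing operands satisfy the getDefiningOp guard in ExpressionRestructing.cpp.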