diff --git a/.github/workflows/nv-build-and-test.yml b/.github/workflows/nv-build-and-test.yml
index 85e76773f..6678754dd 100644
--- a/.github/workflows/nv-build-and-test.yml
+++ b/.github/workflows/nv-build-and-test.yml
@@ -20,14 +20,15 @@ jobs:
       - name: Checkout code
         uses: actions/checkout@v4

-      - name: FlagTree Build on NVIDIA-A100
+      - name: FlagTree Build
        shell: bash
        run: |
          source ~/env.sh
          cd python
-          MAX_JOBS=20 pip3.11 install . --no-build-isolation
+          MAX_JOBS=32 pip3.11 install . --no-build-isolation

-      - name: FlagTree Test on NVIDIA-A100
+      - name: FlagTree Test
        shell: bash
        run: |
          pytest -s python/test/unit
+          pytest -s python/test/operators

diff --git a/include/triton/Dialect/Triton/Transforms/Passes.h b/include/triton/Dialect/Triton/Transforms/Passes.h
index fde54fe17..3f00260ef 100644
--- a/include/triton/Dialect/Triton/Transforms/Passes.h
+++ b/include/triton/Dialect/Triton/Transforms/Passes.h
@@ -11,6 +11,7 @@
 std::unique_ptr<Pass> createCombineOpsPass();
 std::unique_ptr<Pass> createReorderBroadcastPass();
 std::unique_ptr<Pass> createRewriteTensorPointerPass();
+std::unique_ptr<Pass> createExpressionRestructingPass();
 } // namespace triton

 #define GEN_PASS_REGISTRATION

diff --git a/include/triton/Dialect/Triton/Transforms/Passes.td b/include/triton/Dialect/Triton/Transforms/Passes.td
index 4ebff63fa..165b779b6 100644
--- a/include/triton/Dialect/Triton/Transforms/Passes.td
+++ b/include/triton/Dialect/Triton/Transforms/Passes.td
@@ -41,4 +41,18 @@ def TritonRewriteTensorPointer : Pass<"triton-rewrite-tensor-pointer", "mlir::ModuleOp"> {
   let constructor = "mlir::triton::createRewriteTensorPointerPass()";
 }

+def TritonExpressionRestructing : Pass<"triton-expression-restructing", "mlir::ModuleOp"> {
+  let summary = "Restructure chained floating-point expressions";
+  let description = [{
+    Transform a = b / c; d = a / e;  into  a = c * e; d = b / a;
+    Transform a = b + c; d = a + c;  into  a = c + c; d = b + a;
+    Transform a = b - c; d = a - c;  into  a = c + c; d = b - a;
+    Transform a = b * c; d = a * c;  into  a = c * c; d = b * a;
+  }];
+
+  let constructor = "mlir::triton::createExpressionRestructingPass()";
+
+  let dependentDialects = ["mlir::triton::TritonDialect", "mlir::arith::ArithDialect"];
+}
+
 #endif
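Note: the four rewrites declared above are plain reassociations. The standalone
Python sketch below (not part of the patch; b, c, e follow the names in the
description field) checks each identity numerically. Reassociation is exact
over the reals but only approximate in floating point, which is why the check
uses a tolerance; the test added later in this patch compares against PyTorch
with a tolerance for the same reason.

    # Numeric check of the four TritonExpressionRestructing rules.
    # "before" is the expression as written, "after" the restructured form.
    cases = [
        ("div", lambda b, c, e: (b / c) / e, lambda b, c, e: b / (c * e)),
        ("add", lambda b, c, e: (b + c) + c, lambda b, c, e: b + (c + c)),
        ("sub", lambda b, c, e: (b - c) - c, lambda b, c, e: b - (c + c)),
        ("mul", lambda b, c, e: (b * c) * c, lambda b, c, e: b * (c * c)),
    ]
    b, c, e = 3.7, 1.9, 0.53  # arbitrary values
    for name, before, after in cases:
        assert abs(before(b, c, e) - after(b, c, e)) < 1e-9, name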
diff --git a/lib/Dialect/Triton/Transforms/CMakeLists.txt b/lib/Dialect/Triton/Transforms/CMakeLists.txt
index 298398750..444644a3f 100644
--- a/lib/Dialect/Triton/Transforms/CMakeLists.txt
+++ b/lib/Dialect/Triton/Transforms/CMakeLists.txt
@@ -6,6 +6,8 @@ add_triton_library(TritonTransforms
   Combine.cpp
   ReorderBroadcast.cpp
   RewriteTensorPointer.cpp
+  ExpressionRestructing.cpp
+
   DEPENDS
   TritonTransformsIncGen

diff --git a/lib/Dialect/Triton/Transforms/ExpressionRestructing.cpp b/lib/Dialect/Triton/Transforms/ExpressionRestructing.cpp
new file mode 100644
index 000000000..74681d389
--- /dev/null
+++ b/lib/Dialect/Triton/Transforms/ExpressionRestructing.cpp
@@ -0,0 +1,189 @@
+#include <memory>
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "triton/Analysis/Utility.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/Triton/Transforms/Passes.h"
+
+#define GEN_PASS_CLASSES
+#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
+
+using namespace mlir;
+using llvm::ArrayRef;
+
+namespace mlir::triton {
+
+// Rewrite a = b / c; d = a / e  into  a = c * e; d = b / a, so the chain
+// costs one multiplication and one division instead of two divisions.
+struct Div2Mul : public OpRewritePattern<arith::DivFOp> {
+  using OpRewritePattern<arith::DivFOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::DivFOp op,
+                                PatternRewriter &rewriter) const override {
+    Value result = op.getResult();
+    Value l = op.getLhs();
+    Value r = op.getRhs();
+
+    if (!result.hasOneUse())
+      return failure();
+    for (auto &use : result.getUses()) {
+      auto DivUser = dyn_cast<arith::DivFOp>(use.getOwner());
+      if (!DivUser)
+        return failure();
+      // Only match when the first result feeds the second numerator.
+      if (DivUser.getLhs() != op.getResult())
+        return failure();
+      auto originalInsertionPoint = rewriter.saveInsertionPoint();
+      rewriter.setInsertionPointAfter(DivUser);
+      auto loc_div = DivUser.getLoc();
+      auto product =
+          rewriter.create<arith::MulFOp>(loc_div, r, DivUser.getRhs());
+      rewriter.setInsertionPointAfter(product);
+      auto ResultEnd =
+          rewriter.create<arith::DivFOp>(loc_div, l, product.getResult());
+      rewriter.restoreInsertionPoint(originalInsertionPoint);
+      rewriter.replaceOp(op, product.getResult());
+      DivUser.replaceAllUsesWith(ResultEnd.getResult());
+      rewriter.eraseOp(DivUser);
+    }
+    return success();
+  }
+};
+
+// Rewrite a = b * c; d = a * c  into  a = c * c; d = b * a.
+struct Mul2Mul : public OpRewritePattern<arith::MulFOp> {
+  using OpRewritePattern<arith::MulFOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::MulFOp op,
+                                PatternRewriter &rewriter) const override {
+    Value result = op.getResult();
+    Value l = op.getLhs();
+    Value r = op.getRhs();
+
+    if (!result.hasOneUse())
+      return failure();
+    for (auto &use : result.getUses()) {
+      auto MulUser = dyn_cast<arith::MulFOp>(use.getOwner());
+      if (!MulUser)
+        return failure();
+      // The two right-hand sides must be the same value, or both constants.
+      if (!(MulUser.getLhs() == op.getResult() &&
+            ((MulUser.getRhs().getDefiningOp<arith::ConstantOp>() &&
+              r.getDefiningOp<arith::ConstantOp>()) ||
+             (r == MulUser.getRhs()))))
+        return failure();
+      auto originalInsertionPoint = rewriter.saveInsertionPoint();
+      rewriter.setInsertionPointAfter(MulUser);
+      auto loc_mul = MulUser.getLoc();
+      auto product =
+          rewriter.create<arith::MulFOp>(loc_mul, r, MulUser.getRhs());
+      rewriter.setInsertionPointAfter(product);
+      auto ResultEnd =
+          rewriter.create<arith::MulFOp>(loc_mul, l, product.getResult());
+      rewriter.restoreInsertionPoint(originalInsertionPoint);
+      rewriter.replaceOp(op, product.getResult());
+      MulUser.replaceAllUsesWith(ResultEnd.getResult());
+      rewriter.eraseOp(MulUser);
+    }
+    return success();
+  }
+};
+
+// Rewrite a = b + c; d = a + c  into  a = c + c; d = b + a.
+struct Add2Add : public OpRewritePattern<arith::AddFOp> {
+  using OpRewritePattern<arith::AddFOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::AddFOp op,
+                                PatternRewriter &rewriter) const override {
+    Value result = op.getResult();
+    Value l = op.getLhs();
+    Value r = op.getRhs();
+
+    if (!result.hasOneUse())
+      return failure();
+    for (auto &use : result.getUses()) {
+      auto AddUser = dyn_cast<arith::AddFOp>(use.getOwner());
+      if (!AddUser)
+        return failure();
+      // The two right-hand sides must be the same value, or both constants.
+      if (!(AddUser.getLhs() == op.getResult() &&
+            ((AddUser.getRhs().getDefiningOp<arith::ConstantOp>() &&
+              r.getDefiningOp<arith::ConstantOp>()) ||
+             (r == AddUser.getRhs()))))
+        return failure();
+      auto originalInsertionPoint = rewriter.saveInsertionPoint();
+      rewriter.setInsertionPointAfter(AddUser);
+      auto loc_add = AddUser.getLoc();
+      auto sum = rewriter.create<arith::AddFOp>(loc_add, r, AddUser.getRhs());
+      rewriter.setInsertionPointAfter(sum);
+      auto ResultEnd =
+          rewriter.create<arith::AddFOp>(loc_add, l, sum.getResult());
+      rewriter.restoreInsertionPoint(originalInsertionPoint);
+      rewriter.replaceOp(op, sum.getResult());
+      AddUser.replaceAllUsesWith(ResultEnd.getResult());
+      rewriter.eraseOp(AddUser);
+    }
+    return success();
+  }
+};
+
+// Rewrite a = b - c; d = a - c  into  a = c + c; d = b - a.
+struct Sub2Add : public OpRewritePattern<arith::SubFOp> {
+  using OpRewritePattern<arith::SubFOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::SubFOp op,
+                                PatternRewriter &rewriter) const override {
+    Value result = op.getResult();
+    Value l = op.getLhs();
+    Value r = op.getRhs();
+
+    if (!result.hasOneUse())
+      return failure();
+    for (auto &use : result.getUses()) {
+      auto SubUser = dyn_cast<arith::SubFOp>(use.getOwner());
+      if (!SubUser)
+        return failure();
+      // The two right-hand sides must be the same value, or both constants.
+      if (!(SubUser.getLhs() == op.getResult() &&
+            ((SubUser.getRhs().getDefiningOp<arith::ConstantOp>() &&
+              r.getDefiningOp<arith::ConstantOp>()) ||
+             (r == SubUser.getRhs()))))
+        return failure();
+      auto originalInsertionPoint = rewriter.saveInsertionPoint();
+      rewriter.setInsertionPointAfter(SubUser);
+      auto loc_sub = SubUser.getLoc();
+      auto sum = rewriter.create<arith::AddFOp>(loc_sub, r, SubUser.getRhs());
+      rewriter.setInsertionPointAfter(sum);
+      auto ResultEnd =
+          rewriter.create<arith::SubFOp>(loc_sub, l, sum.getResult());
+      rewriter.restoreInsertionPoint(originalInsertionPoint);
+      rewriter.replaceOp(op, sum.getResult());
+      SubUser.replaceAllUsesWith(ResultEnd.getResult());
+      rewriter.eraseOp(SubUser);
+    }
+    return success();
+  }
+};
+
+class ExpressionRestructingPass
+    : public TritonExpressionRestructingBase<ExpressionRestructingPass> {
+public:
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    ModuleOp m = getOperation();
+    patterns.add<Div2Mul>(context);
+    patterns.add<Mul2Mul>(context);
+    patterns.add<Add2Add>(context);
+    patterns.add<Sub2Add>(context);
+
+    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
+      signalPassFailure();
+  }
+};
+
+std::unique_ptr<Pass> createExpressionRestructingPass() {
+  return std::make_unique<ExpressionRestructingPass>();
+}
+
+} // namespace mlir::triton
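Note: on the division chain exercised by the test added below (tmp8 = tmp4 /
tmp7; tmp9 = tmp8 / tmp7), Div2Mul leaves one arith.mulf and one arith.divf in
place of two arith.divf ops; floating-point division is considerably more
expensive than multiplication on GPUs, which is the motivation for the pass.
A scalar sketch of the effect, in plain Python with arbitrary stand-in values:

    tmp4, tmp7 = 0.8147, 0.1270  # arbitrary stand-ins for the kernel values

    # Before the pass: two divisions.
    tmp8 = tmp4 / tmp7
    tmp9_before = tmp8 / tmp7

    # After Div2Mul: the intermediate becomes a product, so the chain costs
    # one multiply plus one divide.
    product = tmp7 * tmp7
    tmp9_after = tmp4 / product

    assert abs(tmp9_before - tmp9_after) < 1e-9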
diff --git a/python/src/passes.cc b/python/src/passes.cc
index 513e811d2..263b20dae 100644
--- a/python/src/passes.cc
+++ b/python/src/passes.cc
@@ -37,6 +37,8 @@ void init_triton_passes_ttir(py::module &&m) {
   using namespace mlir::triton;
   ADD_PASS_WRAPPER_0("add_combine", createCombineOpsPass);
   ADD_PASS_WRAPPER_0("add_reorder_broadcast", createReorderBroadcastPass);
+  ADD_PASS_WRAPPER_0("add_expression_restructing",
+                     createExpressionRestructingPass);
   ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer",
                      createRewriteTensorPointerPass);
   ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir",

diff --git a/python/test/operators/test_expression_restructuring.py b/python/test/operators/test_expression_restructuring.py
new file mode 100644
index 000000000..c85312346
--- /dev/null
+++ b/python/test/operators/test_expression_restructuring.py
@@ -0,0 +1,65 @@
+import triton
+import triton.language as tl
+import torch
+
+import pytest
+
+VEC_SHAPES = [[64, 640], [32, 128], [128, 256]]
+
+
+def custom_rand_strided(shape, strides, device, dtype, seed=0):
+    torch.manual_seed(seed)
+    total_size = sum((s - 1) * st for s, st in zip(shape, strides)) + 1
+    storage = torch.randn(total_size, device=device, dtype=dtype)
+    return torch.as_strided(storage, size=shape, stride=strides)
+
+
+def torch_equivalent(arg_0, arg_1, arg_2, arg_3):
+    reshaped_arg_0 = arg_0.view(arg_2.shape[0], arg_2.shape[0], arg_2.shape[2])
+    reshaped_arg_3 = arg_3.squeeze(-1)
+    tmp0 = -reshaped_arg_0
+    tmp4 = arg_1 * arg_2
+    tmp7 = reshaped_arg_3 + 1e-06
+    tmp8 = tmp4 / tmp7.unsqueeze(-1)
+    tmp9 = tmp8 / tmp7.unsqueeze(-1)
+    result = tmp0 * tmp9
+    return result
+
+
+@triton.jit
+def expression_restructuring_function_test(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, rnumel):
+    XBLOCK: tl.constexpr = 1
+    xoffset = tl.program_id(0) * XBLOCK
+    RBLOCK: tl.constexpr = 1024
+    xindex = tl.full([1], xoffset, tl.int32)
+    rindex = tl.arange(0, RBLOCK)[:]
+    rmask = rindex < rnumel
+    r1 = rindex
+    x0 = xindex
+    tmp0 = tl.load(in_ptr0 + (r1 + (rnumel * x0)), rmask, other=0)
+    tmp2 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0)
+    tmp3 = tl.load(in_ptr2 + (r1 + (rnumel * x0)), rmask, other=0)
+    tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
+    tmp1 = -tmp0
+    tmp4 = tmp2 * tmp3
+    tmp6 = 1e-06
+    tmp7 = tmp5 + tmp6
+    tmp8 = tmp4 / tmp7
+    tmp9 = tmp8 / tmp7
+    tmp10 = tmp1 * tmp9
+    tl.store(out_ptr2 + (r1 + (rnumel * x0)), tmp10, rmask)
+
+
+@pytest.mark.parametrize("vec_shape", VEC_SHAPES)
+def test_accuracy_kernel(vec_shape):
+    x = vec_shape[0]
+    y = vec_shape[1]
+    arg_0 = custom_rand_strided((x * x, y), (y, 1), dtype=torch.float32, device='cuda')
+    arg_1 = custom_rand_strided((y, ), (1, ), dtype=torch.float32, device='cuda')
+    arg_2 = custom_rand_strided((x, x, y), (x * y, y, 1), dtype=torch.float32, device='cuda')
+    arg_3 = custom_rand_strided((x, x, 1), (x, 1, 1), dtype=torch.float32, device='cuda')
+    triton_result = custom_rand_strided((x, x, y), (x * y, y, 1), dtype=torch.float32, device='cuda')
+    grid = lambda meta: (x * x, )
+    expression_restructuring_function_test[grid](arg_0, arg_1, arg_2, arg_3, triton_result, y)
+    torch_result = torch_equivalent(arg_0, arg_1, arg_2, arg_3)
+    torch.testing.assert_close(triton_result, torch_result)
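Note: torch_equivalent mirrors the kernel line by line, and the comparison
uses torch.testing.assert_close (default float32 tolerances) rather than
bitwise equality, because the restructured kernel evaluates a reassociated
expression. A minimal standalone torch illustration (not part of the patch)
of why a tolerance is required:

    import torch

    # The pass turns (a / b) / b into a / (b * b); the two orderings agree
    # only up to floating-point rounding, not bit for bit.
    torch.manual_seed(0)
    a = torch.randn(1024)
    b = torch.randn(1024).abs() + 1e-6
    chained = (a / b) / b        # evaluation order written in the source
    restructured = a / (b * b)   # evaluation order after Div2Mul
    torch.testing.assert_close(chained, restructured)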
diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py
index 6d7994923..24a80484f 100644
--- a/third_party/nvidia/backend/compiler.py
+++ b/third_party/nvidia/backend/compiler.py
@@ -144,6 +144,7 @@ def make_ttir(mod, metadata, opt):
         passes.common.add_canonicalizer(pm)
         passes.ttir.add_reorder_broadcast(pm)
         passes.common.add_cse(pm)
+        passes.ttir.add_expression_restructing(pm)
         passes.common.add_licm(pm)
         passes.common.add_symbol_dce(pm)
         pm.run(mod)
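Note: the pass is exposed to Python as add_expression_restructing and is
scheduled in make_ttir after add_cse and before add_licm. Running after CSE
matters because the patterns only fire when the chained ops share the same
right-hand-side SSA value (or both are constants), and CSE is what folds a
repeated divisor expression into a single value. A hedged sketch of the
pipeline fragment, assuming upstream Triton's pass bindings (the real code
is the make_ttir hunk above):

    # Sketch only: binding names come from this patch and upstream Triton.
    from triton._C.libtriton import ir, passes

    def run_ttir_fragment(mod):
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.common.add_cse(pm)                   # unify repeated subexpressions
        passes.ttir.add_expression_restructing(pm)  # then restructure op chains
        passes.common.add_licm(pm)
        passes.common.add_symbol_dce(pm)
        pm.run(mod)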