Merged
7 changes: 4 additions & 3 deletions .github/workflows/nv-build-and-test.yml
@@ -20,14 +20,15 @@ jobs:
       - name: Checkout code
         uses: actions/checkout@v4

-      - name: FlagTree Build on NVIDIA-A100
+      - name: FlagTree Build
         shell: bash
         run: |
           source ~/env.sh
           cd python
-          MAX_JOBS=20 pip3.11 install . --no-build-isolation
+          MAX_JOBS=32 pip3.11 install . --no-build-isolation

-      - name: FlagTree Test on NVIDIA-A100
+      - name: FlagTree Test
         shell: bash
         run: |
           pytest -s python/test/unit
+          pytest -s python/test/operators
1 change: 1 addition & 0 deletions include/triton/Dialect/Triton/Transforms/Passes.h
@@ -11,6 +11,7 @@ std::unique_ptr<Pass> createCombineOpsPass();
 std::unique_ptr<Pass> createReorderBroadcastPass();
 std::unique_ptr<Pass> createRewriteTensorPointerPass();

+std::unique_ptr<Pass> createExpressionRestructingPass();
 } // namespace triton

#define GEN_PASS_REGISTRATION
14 changes: 14 additions & 0 deletions include/triton/Dialect/Triton/Transforms/Passes.td
@@ -41,4 +41,18 @@ def TritonRewriteTensorPointer : Pass</*cli-arg*/"triton-rewrite-tensor-pointer"
let dependentDialects = ["mlir::triton::TritonDialect"];
}

def TritonExpressionRestructing : Pass</*cli-arg*/"triton-expression-resturcting", /*Op*/"mlir::ModuleOp"> {
let summary = "ExpressionRestructing";
let description = [{
transform a = b / c; d = a / e; to a = c * e; d = b / a;
transform a = b + c; d = a + c; to a = c + c; d = b + a;
transform a = b - c; d = a - c; to a = c + c; d = b - a;
transform a = b * c; d = a * c; to a = c * c; d = b * a;
}];

let constructor = "mlir::triton::createExpressionRestructingPass()";

let dependentDialects = ["mlir::triton::TritonDialect", "mlir::arith::ArithDialect"];
}

#endif
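For intuition, the four rewrites above are plain reassociation identities. A minimal sketch in Python (editorial illustration, not part of this PR) checks them numerically; note that reassociating floating-point arithmetic is not bit-exact, which is why the accuracy test later in this PR compares with a tolerance rather than exact equality.

# Illustrative check of the four identities behind ExpressionRestructing
# (editorial sketch, not part of the PR).
def check_identities(b: float, c: float, e: float, tol: float = 1e-12) -> None:
    # a = b / c; d = a / e   ->   a = c * e; d = b / a
    assert abs((b / c) / e - b / (c * e)) < tol
    # a = b + c; d = a + c   ->   a = c + c; d = b + a
    assert abs((b + c) + c - (b + (c + c))) < tol
    # a = b - c; d = a - c   ->   a = c + c; d = b - a
    assert abs((b - c) - c - (b - (c + c))) < tol
    # a = b * c; d = a * c   ->   a = c * c; d = b * a
    assert abs((b * c) * c - b * (c * c)) < tol

check_identities(3.7, 1.9, 0.4)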
2 changes: 2 additions & 0 deletions lib/Dialect/Triton/Transforms/CMakeLists.txt
@@ -6,6 +6,8 @@ add_triton_library(TritonTransforms
   Combine.cpp
   ReorderBroadcast.cpp
   RewriteTensorPointer.cpp
+  ExpressionRestructing.cpp
+

   DEPENDS
   TritonTransformsIncGen
189 changes: 189 additions & 0 deletions lib/Dialect/Triton/Transforms/ExpressionRestructing.cpp
@@ -0,0 +1,189 @@
#include <memory>

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"

#define GEN_PASS_CLASSES
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"

using namespace mlir;
using llvm::ArrayRef;
namespace mlir::triton {

// Rewrites a = b / c; d = a / e; into a = c * e; d = b / a, replacing the
// second division with a multiplication.
struct Div2Mul : public OpRewritePattern<arith::DivFOp> {
using OpRewritePattern<arith::DivFOp>::OpRewritePattern;

LogicalResult matchAndRewrite(arith::DivFOp op,
PatternRewriter &rewriter) const override {
Value result = op.getResult();
Value l = op.getLhs();
Value r = op.getRhs();
auto loc = op.getLoc();

    // The rewrite repurposes the intermediate value, so it must have exactly
    // one consumer (the DivFOp handled below).
    if (!result.hasOneUse())
      return failure();
for (auto &use : result.getUses()) {
if (!dyn_cast<arith::DivFOp>(use.getOwner()))
return failure();
auto DivUser = dyn_cast<arith::DivFOp>(use.getOwner());
if (DivUser.getLhs() != op.getResult())
return failure();
auto originalInsertionPoint = rewriter.saveInsertionPoint();
rewriter.setInsertionPointAfter(DivUser);
auto loc_div = DivUser.getLoc();
auto product =
rewriter.create<arith::MulFOp>(loc_div, r, DivUser.getRhs());
rewriter.setInsertionPointAfter(product);
auto ResultEnd =
rewriter.create<arith::DivFOp>(loc_div, l, product.getResult());
rewriter.restoreInsertionPoint(originalInsertionPoint);
rewriter.replaceOp(op, product.getResult());
DivUser.replaceAllUsesWith(ResultEnd.getResult());
rewriter.eraseOp(DivUser);
}
return success();
}
};

// Rewrites a = b * c; d = a * c (or with two constant multipliers) into
// a = c * c; d = b * a.
struct Mul2Mul : public OpRewritePattern<arith::MulFOp> {
using OpRewritePattern<arith::MulFOp>::OpRewritePattern;

LogicalResult matchAndRewrite(arith::MulFOp op,
PatternRewriter &rewriter) const override {
Value result = op.getResult();
Value l = op.getLhs();
Value r = op.getRhs();
auto loc = op.getLoc();

if (!result.hasOneUse())
return failure();
for (auto &use : result.getUses()) {
if (!dyn_cast<arith::MulFOp>(use.getOwner()))
return failure();
auto MulUser = dyn_cast<arith::MulFOp>(use.getOwner());
if (!(MulUser.getLhs() == op.getResult() &&
((MulUser.getRhs().getDefiningOp<arith::ConstantOp>() &&
r.getDefiningOp<arith::ConstantOp>()) ||
(r == MulUser.getRhs()))))
return failure();
auto originalInsertionPoint = rewriter.saveInsertionPoint();
rewriter.setInsertionPointAfter(MulUser);
auto loc_mul = MulUser.getLoc();
auto product =
rewriter.create<arith::MulFOp>(loc_mul, r, MulUser.getRhs());
rewriter.setInsertionPointAfter(product);
auto ResultEnd =
rewriter.create<arith::MulFOp>(loc_mul, l, product.getResult());
rewriter.restoreInsertionPoint(originalInsertionPoint);
rewriter.replaceOp(op, product.getResult());
MulUser.replaceAllUsesWith(ResultEnd.getResult());
rewriter.eraseOp(MulUser);
}
return success();
}
};

// Rewrites a = b + c; d = a + c (or with two constant addends) into
// a = c + c; d = b + a.
struct Add2Add : public OpRewritePattern<arith::AddFOp> {
using OpRewritePattern<arith::AddFOp>::OpRewritePattern;

LogicalResult matchAndRewrite(arith::AddFOp op,
PatternRewriter &rewriter) const override {
Value result = op.getResult();
Value l = op.getLhs();
Value r = op.getRhs();
auto loc = op.getLoc();

if (!result.hasOneUse())
return failure();
for (auto &use : result.getUses()) {
if (!dyn_cast<arith::AddFOp>(use.getOwner()))
return failure();
auto AddUser = dyn_cast<arith::AddFOp>(use.getOwner());
if (!(AddUser.getLhs() == op.getResult() &&
((AddUser.getRhs().getDefiningOp<arith::ConstantOp>() &&
r.getDefiningOp<arith::ConstantOp>()) ||
(r == AddUser.getRhs()))))
return failure();
auto originalInsertionPoint = rewriter.saveInsertionPoint();
rewriter.setInsertionPointAfter(AddUser);
auto loc_add = AddUser.getLoc();
auto sum = rewriter.create<arith::AddFOp>(loc_add, r, AddUser.getRhs());
rewriter.setInsertionPointAfter(sum);
auto ResultEnd =
rewriter.create<arith::AddFOp>(loc_add, l, sum.getResult());
rewriter.restoreInsertionPoint(originalInsertionPoint);
rewriter.replaceOp(op, sum.getResult());
AddUser.replaceAllUsesWith(ResultEnd.getResult());
rewriter.eraseOp(AddUser);
}
return success();
}
};

// Rewrites a = b - c; d = a - c (or with two constant subtrahends) into
// a = c + c; d = b - a.
struct Sub2Add : public OpRewritePattern<arith::SubFOp> {
using OpRewritePattern<arith::SubFOp>::OpRewritePattern;

LogicalResult matchAndRewrite(arith::SubFOp op,
PatternRewriter &rewriter) const override {
Value result = op.getResult();
Value l = op.getLhs();
Value r = op.getRhs();
auto loc = op.getLoc();

if (!result.hasOneUse())
return failure();
for (auto &use : result.getUses()) {
if (!dyn_cast<arith::SubFOp>(use.getOwner()))
return failure();
auto SubUser = dyn_cast<arith::SubFOp>(use.getOwner());
if (!(SubUser.getLhs() == op.getResult() &&
((SubUser.getRhs().getDefiningOp<arith::ConstantOp>() &&
r.getDefiningOp<arith::ConstantOp>()) ||
(r == SubUser.getRhs()))))
return failure();
auto originalInsertionPoint = rewriter.saveInsertionPoint();
rewriter.setInsertionPointAfter(SubUser);
auto loc_sub = SubUser.getLoc();
auto sum = rewriter.create<arith::AddFOp>(loc_sub, r, SubUser.getRhs());
rewriter.setInsertionPointAfter(sum);
auto ResultEnd =
rewriter.create<arith::SubFOp>(loc_sub, l, sum.getResult());
rewriter.restoreInsertionPoint(originalInsertionPoint);
rewriter.replaceOp(op, sum.getResult());
SubUser.replaceAllUsesWith(ResultEnd.getResult());
rewriter.eraseOp(SubUser);
}
return success();
}
};

// Greedily applies the four restructuring patterns over the whole module.
class ExpressionRestructingPass
: public TritonExpressionRestructingBase<ExpressionRestructingPass> {
public:
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
ModuleOp m = getOperation();
patterns.add<Div2Mul>(context);
patterns.add<Mul2Mul>(context);
patterns.add<Add2Add>(context);
patterns.add<Sub2Add>(context);

if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
signalPassFailure();
}
};

std::unique_ptr<mlir::Pass> createExpressionRestructingPass() {
return std::make_unique<ExpressionRestructingPass>();
}

} // namespace mlir::triton
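A note on the hasOneUse guard in each pattern above: the rewrite repurposes the intermediate SSA value (after Div2Mul, a holds c * e rather than b / c), so a second consumer of a would silently observe a different value. A small Python sketch of the hazard (hypothetical names, editorial illustration only):

# Why each pattern bails out unless the intermediate has exactly one use.
def original(b, c, e):
    a = b / c            # intermediate: a == b / c
    d = a / e
    return d, a          # 'a' escapes to a second consumer

def restructured(b, c, e):
    a = c * e            # the rewrite repurposes the intermediate
    d = b / a
    return d, a          # 'a' now equals c * e, not b / c

d0, a0 = original(8.0, 2.0, 4.0)      # a0 == 4.0
d1, a1 = restructured(8.0, 2.0, 4.0)  # a1 == 8.0
assert abs(d0 - d1) < 1e-12 and a0 != a1  # d agrees; the escaped 'a' does not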
2 changes: 2 additions & 0 deletions python/src/passes.cc
@@ -37,6 +37,8 @@ void init_triton_passes_ttir(py::module &&m) {
   using namespace mlir::triton;
   ADD_PASS_WRAPPER_0("add_combine", createCombineOpsPass);
   ADD_PASS_WRAPPER_0("add_reorder_broadcast", createReorderBroadcastPass);
+  ADD_PASS_WRAPPER_0("add_expression_restructing",
+                     createExpressionRestructingPass);
   ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer",
                      createRewriteTensorPointerPass);
   ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir",
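The binding registered above is what compiler.py invokes further down. For reference, a hypothetical standalone driver (a sketch assuming Triton 3.x's triton._C.libtriton layout and an already-parsed TTIR module mod; it mirrors make_ttir in compiler.py below):

# Hypothetical sketch: run only the new pass on a TTIR module
# (assumes Triton 3.x bindings; not part of this PR).
from triton._C.libtriton import ir, passes

def restructure(mod):
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    passes.ttir.add_expression_restructing(pm)
    pm.run(mod)
    return mod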
65 changes: 65 additions & 0 deletions python/test/operators/test_expression_restructuring.py
@@ -0,0 +1,65 @@
import triton
import triton.language as tl
import torch

import pytest

VEC_SHAPES = [[64, 640], [32, 128], [128, 256]]


def custom_rand_strided(shape, strides, device, dtype, seed=0):
torch.manual_seed(seed)
total_size = sum((s - 1) * st for s, st in zip(shape, strides)) + 1
storage = torch.randn(total_size, device=device, dtype=dtype)
return torch.as_strided(storage, size=shape, stride=strides)


def torch_equivalent(arg_0, arg_1, arg_2, arg_3):
reshaped_arg_0 = arg_0.view(arg_2.shape[0], arg_2.shape[0], arg_2.shape[2])
reshaped_arg_3 = arg_3.squeeze(-1)
tmp0 = -reshaped_arg_0
tmp4 = arg_1 * arg_2
tmp7 = reshaped_arg_3 + 1e-06
tmp8 = tmp4 / tmp7.unsqueeze(-1)
tmp9 = tmp8 / tmp7.unsqueeze(-1)
result = tmp0 * tmp9
return result


@triton.jit
def expression_restructuring_function_test(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, rnumel):
XBLOCK: tl.constexpr = 1
xoffset = tl.program_id(0) * XBLOCK
RBLOCK: tl.constexpr = 1024
xindex = tl.full([1], xoffset, tl.int32)
rindex = tl.arange(0, RBLOCK)[:]
rmask = rindex < rnumel
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + (rnumel * x0)), rmask, other=0)
tmp2 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0)
tmp3 = tl.load(in_ptr2 + (r1 + (rnumel * x0)), rmask, other=0)
tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
tmp1 = -tmp0
tmp4 = tmp2 * tmp3
tmp6 = 1e-06
tmp7 = tmp5 + tmp6
tmp8 = tmp4 / tmp7
tmp9 = tmp8 / tmp7
tmp10 = tmp1 * tmp9
tl.store(out_ptr2 + (r1 + (rnumel * x0)), tmp10, rmask)


@pytest.mark.parametrize("vec_shape", VEC_SHAPES)
def test_accuracy_kernel(vec_shape):
x = vec_shape[0]
y = vec_shape[1]
arg_0 = custom_rand_strided((x * x, y), (y, 1), dtype=torch.float32, device='cuda')
arg_1 = custom_rand_strided((y, ), (1, ), dtype=torch.float32, device='cuda')
arg_2 = custom_rand_strided((x, x, y), (x * y, y, 1), dtype=torch.float32, device='cuda')
arg_3 = custom_rand_strided((x, x, 1), (x, 1, 1), dtype=torch.float32, device='cuda')
triton_result = custom_rand_strided((x, x, y), (x * y, y, 1), dtype=torch.float32, device='cuda')
grid = lambda meta: (x * x, )
expression_restructuring_function_test[grid](arg_0, arg_1, arg_2, arg_3, triton_result, y)
torch_result = torch_equivalent(arg_0, arg_1, arg_2, arg_3)
torch.testing.assert_close(triton_result, torch_result)
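Since the pass reassociates floating-point operations, the kernel's output is close to, but not bit-identical with, the eager PyTorch reference; hence torch.testing.assert_close rather than an exact comparison. A short editorial sketch (assuming only torch) for quantifying that gap:

import torch

# Editorial sketch: measure the reassociation error directly.
def max_abs_err(actual: torch.Tensor, expected: torch.Tensor) -> float:
    return (actual - expected).abs().max().item()

# Usage after the launch in test_accuracy_kernel, e.g.:
#   print(max_abs_err(triton_result, torch_result))  # typically ~1e-6 for float32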
1 change: 1 addition & 0 deletions third_party/nvidia/backend/compiler.py
@@ -144,6 +144,7 @@ def make_ttir(mod, metadata, opt):
     passes.common.add_canonicalizer(pm)
     passes.ttir.add_reorder_broadcast(pm)
     passes.common.add_cse(pm)
+    passes.ttir.add_expression_restructing(pm)
     passes.common.add_licm(pm)
     passes.common.add_symbol_dce(pm)
     pm.run(mod)