Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions .github/workflows/aipu-build-and-test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
name: AIPU-Build-And-Test

on:
push:
branches: [ "triton_v3.3.x" ]
pull_request:
branches: [ "triton_v3.3.x" ]

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true

jobs:
aipu-build-and-test:
runs-on: aipu
steps:
- name: Checkout code (attempt 1)
id: checkout1
uses: actions/checkout@v4
continue-on-error: true

- name: Sleep before checkout2
if: steps.checkout1.outcome == 'failure'
run: |
echo "First checkout attempt failed. Sleeping for 120 seconds before retry..."
sleep 120

- name: Checkout code (attempt 2)
id: checkout2
if: steps.checkout1.outcome == 'failure'
uses: actions/checkout@v4
continue-on-error: true

- name: Sleep before final checkout
if: steps.checkout1.outcome == 'failure' && steps.checkout2.outcome == 'failure'
run: |
echo "Second checkout attempt failed. Sleeping for 180 seconds before final retry..."
sleep 180

- name: Checkout code (final attempt)
if: steps.checkout1.outcome == 'failure' && steps.checkout2.outcome == 'failure'
uses: actions/checkout@v4

- name: Verify checkout success
if: success()
run: echo "Checkout completed successfully"

- name: FlagTree Build on AIPU
shell: bash
run: |
source ~/env.sh
source ~/env_setup.sh
export FLAGTREE_BACKEND=aipu
cd python
python3.10 -m pip install . --no-build-isolation -v

- name: FlagTree Test on AIPU
shell: bash
run: |
source ~/env_setup.sh
python3.10 third_party/aipu/python/test/test_01_vector_add.py
python3.10 third_party/aipu/python/test/test_02_fused_softmax.py
21 changes: 0 additions & 21 deletions .github/workflows/code-format-check-master.yml

This file was deleted.

4 changes: 2 additions & 2 deletions .github/workflows/code-format-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ on:
schedule:
- cron: '0 21 * * *'
push:
branches: [ "main" ]
branches: [ "main", "triton_v3.3.x" ]
pull_request:
branches: [ "main" ]
branches: [ "main", "triton_v3.3.x" ]

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/iluvatar-build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
export FLAGTREE_BACKEND=iluvatar
source ~/env.sh
cd python
MAX_JOBS=20 pip3 install . --no-build-isolation
MAX_JOBS=32 pip3 install . --no-build-isolation

- name: FlagTree Test on Iluvatar
shell: bash
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/metax-build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
source ~/env.sh
export FLAGTREE_BACKEND=metax
cd python
MAX_JOBS=20 pip3 install . --no-build-isolation
MAX_JOBS=32 pip3 install . --no-build-isolation

- name: FlagTree Test on Metax
shell: bash
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/mthreads-build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
source ~/env.sh
export FLAGTREE_BACKEND=mthreads
cd python
MAX_JOBS=20 pip3 install . --no-build-isolation
MAX_JOBS=32 pip3 install . --no-build-isolation

- name: FlagTree Test on Mthreads
shell: bash
Expand Down
30 changes: 25 additions & 5 deletions .github/workflows/nv-build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ on:
schedule:
- cron: '0 21 * * *'
push:
branches: [ "main" ]
branches: [ "main", "triton_v3.3.x" ]
pull_request:
branches: [ "main" ]
branches: [ "main", "triton_v3.3.x" ]

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
Expand All @@ -19,14 +19,34 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4

- name: FlagTree Build on NVIDIA-A100
- name: Detect Target Branch
shell: bash
run: |
if [ "${{ github.event_name }}" = "pull_request" ]; then
TARGET_BRANCH="${{ github.base_ref }}"
else
TARGET_BRANCH="${{ github.ref_name }}"
fi
echo "TARGET_BRANCH=$TARGET_BRANCH" >> $GITHUB_ENV
echo "TARGET_BRANCH=$TARGET_BRANCH"

- name: FlagTree Build (Main branch)
if: ${{ env.TARGET_BRANCH == 'main' }}
shell: bash
run: |
source ~/env.sh
cd python
MAX_JOBS=20 pip3.11 install . --no-build-isolation
MAX_JOBS=32 pip3.11 install . --no-build-isolation

- name: FlagTree Build (triton_v3.3.x branch)
if: ${{ env.TARGET_BRANCH == 'triton_v3.3.x' }}
shell: bash
run: |
source ~/env-3.3.sh
cd python
MAX_JOBS=32 pip3.11 install . --no-build-isolation

- name: FlagTree Test on NVIDIA-A100
- name: FlagTree Test
shell: bash
run: |
pytest -s python/test/unit
13 changes: 8 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ elseif(FLAGTREE_BACKEND STREQUAL "mthreads")
set(CMAKE_C_COMPILER clang)
set(CMAKE_CXX_COMPILER clang++)
set(ENV{FLAGTREE_PLUGIN} $ENV{FLAGTREE_BACKEND})
elseif(FLAGTREE_BACKEND STREQUAL "aipu")
add_definitions(-D__NVIDIA__)
add_definitions(-D__AMD__)
endif()
set(FLAGTREE_PLUGIN "$ENV{FLAGTREE_PLUGIN}")
if(FLAGTREE_PLUGIN)
Expand Down Expand Up @@ -201,7 +204,7 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party)
include_directories(${PROJECT_BINARY_DIR}/third_party) # Tablegen'd files

# link_directories(${LLVM_LIBRARY_DIR})
if (FLAGTREE_BACKEND STREQUAL "cambricon")
if (FLAGTREE_BACKEND MATCHES "^(cambricon|aipu)$")
include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
add_subdirectory(include)
Expand Down Expand Up @@ -263,10 +266,10 @@ if(TRITON_BUILD_PYTHON_MODULE)
if (TRITON_BUILD_PROTON)
add_definitions(-D__PROTON__)
add_subdirectory(third_party/proton)
# We always build proton dialect
list(APPEND TRITON_PLUGIN_NAMES "proton")
add_subdirectory(third_party/proton/dialect)
endif()
# We always build proton dialect
list(APPEND TRITON_PLUGIN_NAMES "proton")
add_subdirectory(third_party/proton/dialect)

get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)
get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS)
Expand Down Expand Up @@ -443,7 +446,7 @@ find_package(Threads REQUIRED)

add_subdirectory(third_party/f2reduce)

if(NOT FLAGTREE_BACKEND)
if(NOT FLAGTREE_BACKEND OR FLAGTREE_BACKEND STREQUAL "aipu")
add_subdirectory(bin)
add_subdirectory(test)
endif()
Expand Down
1 change: 1 addition & 0 deletions include/triton/Dialect/Triton/IR/TritonAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define TRITON_ATTR_DEFS

include "mlir/IR/EnumAttr.td"
include "mlir/IR/AttrTypeBase.td"

// Attributes for LoadOp and StoreOp
def TT_CacheModifierAttr : I32EnumAttr<
Expand Down
23 changes: 22 additions & 1 deletion include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
include "mlir/IR/BuiltinAttributes.td"


//
Expand Down Expand Up @@ -248,13 +249,33 @@ def TT_LoadOp : TT_Op<"load", [
OptionalAttr<TT_PaddingOptionAttr>:$padding,
DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
DefaultValuedAttr<TT_EvictionPolicyAttr, "::mlir::triton::EvictionPolicy::NORMAL">:$evict,
DefaultValuedAttr<BoolAttr, "false">:$isVolatile
DefaultValuedAttr<BoolAttr, "false">:$isVolatile,
// TODO: now flagtree_hints is string, default value of an empty string (""), needed redesign
DefaultValuedAttr<StrAttr, "\"\"">:$flagtree_hints
);

let results = (outs TT_Type:$result);

let builders = [
// A tensor of pointers or a pointer to a scalar
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "mlir::StringAttr":$flagtree_hints)>,
// A tensor pointer with boundary check and padding
OpBuilder<(ins "Value":$ptr, "ArrayRef<int32_t>":$boundaryCheck,
"std::optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "mlir::StringAttr":$flagtree_hints)>,
// A tensor of pointers or a pointer to a scalar with mask
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "mlir::StringAttr":$flagtree_hints)>,
// A tensor of pointers or a pointer to a scalar with mask and other
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "mlir::StringAttr":$flagtree_hints)>,
// A utility function to build the operation with all attributes
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other,
"ArrayRef<int32_t>":$boundaryCheck,
"std::optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "mlir::StringAttr":$flagtree_hints)>,
// A tensor of pointers or a pointer to a scalar
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
// A tensor pointer with boundary check and padding
Expand Down
55 changes: 55 additions & 0 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
cache, evict, isVolatile);
}

// implementatio with flagtree_hints
void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
CacheModifier cache, EvictionPolicy evict, bool isVolatile,
mlir::StringAttr flagtree_hints) {
LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{},
/*boundaryCheck=*/ArrayRef<int32_t>{}, /*padding=*/std::nullopt,
cache, evict, isVolatile, flagtree_hints);
}

void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
ArrayRef<int32_t> boundaryCheck,
std::optional<PaddingOption> padding, CacheModifier cache,
Expand All @@ -53,6 +62,16 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
padding, cache, evict, isVolatile);
}

// implementatio with flagtree_hints
void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
ArrayRef<int32_t> boundaryCheck,
std::optional<PaddingOption> padding, CacheModifier cache,
EvictionPolicy evict, bool isVolatile,
mlir::StringAttr flagtree_hints) {
LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, boundaryCheck,
padding, cache, evict, isVolatile, flagtree_hints);
}

void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
Value mask, CacheModifier cache, EvictionPolicy evict,
bool isVolatile) {
Expand All @@ -61,6 +80,16 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
/*padding=*/std::nullopt, cache, evict, isVolatile);
}

// implementatio with flagtree_hints
void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
Value mask, CacheModifier cache, EvictionPolicy evict,
bool isVolatile, mlir::StringAttr flagtree_hints) {
LoadOp::build(builder, state, ptr, mask, /*other=*/{},
/*boundaryCheck=*/ArrayRef<int32_t>{},
/*padding=*/std::nullopt, cache, evict, isVolatile,
flagtree_hints);
}

void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
Value mask, Value other, CacheModifier cache,
EvictionPolicy evict, bool isVolatile) {
Expand All @@ -69,6 +98,17 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
/*padding=*/std::nullopt, cache, evict, isVolatile);
}

// implementatio with flagtree_hints
void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
Value mask, Value other, CacheModifier cache,
EvictionPolicy evict, bool isVolatile,
mlir::StringAttr flagtree_hints) {
LoadOp::build(builder, state, ptr, mask, other,
/*boundaryCheck=*/ArrayRef<int32_t>{},
/*padding=*/std::nullopt, cache, evict, isVolatile,
flagtree_hints);
}

void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
Value mask, Value other, ArrayRef<int32_t> boundaryCheck,
std::optional<PaddingOption> padding, CacheModifier cache,
Expand All @@ -82,6 +122,21 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
evict, isVolatile);
}

// implementatio with flagtree_hints
void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
Value mask, Value other, ArrayRef<int32_t> boundaryCheck,
std::optional<PaddingOption> padding, CacheModifier cache,
EvictionPolicy evict, bool isVolatile,
mlir::StringAttr flagtree_hints) {
auto paddingAttr =
padding.has_value()
? PaddingOptionAttr::get(builder.getContext(), padding.value())
: PaddingOptionAttr();
LoadOp::build(builder, state, ptr, mask, other,
builder.getDenseI32ArrayAttr(boundaryCheck), paddingAttr, cache,
evict, isVolatile, flagtree_hints);
}

// load(ptr, splat(1), ...) -> load(ptr, ...)
// load(ptr, splat(0), other, ...) -> other
struct CanonicalizeMaskedLoadPattern : public OpRewritePattern<LoadOp> {
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,8 @@ class RewriteTensorPointerPass
if (auto loadOp = dyn_cast<triton::LoadOp>(op)) {
auto newResult = builder.create<triton::LoadOp>(
loadOp.getLoc(), newPtr, newMask, newOther, loadOp.getCache(),
loadOp.getEvict(), loadOp.getIsVolatile());
loadOp.getEvict(), loadOp.getIsVolatile(),
loadOp.getFlagtreeHintsAttr());
op->getResult(0).replaceAllUsesWith(newResult);
if (op->getAttr("async_task_id"))
newResult->setAttr("async_task_id", op->getAttr("async_task_id"));
Expand Down
8 changes: 7 additions & 1 deletion python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,13 @@ def build_extension(self, ext):
)

if helper.flagtree_backend:
backends = [*BackendInstaller.copy(helper.extend_backends), *BackendInstaller.copy_externals()]
if helper.flagtree_backend == "aipu":
backends = [
*BackendInstaller.copy(helper.default_backends + helper.extend_backends),
*BackendInstaller.copy_externals(),
]
else:
backends = [*BackendInstaller.copy(helper.extend_backends), *BackendInstaller.copy_externals()]
else:
backends = [*BackendInstaller.copy(helper.default_backends), *BackendInstaller.copy_externals()]

Expand Down
Loading