2 changes: 1 addition & 1 deletion python/setup_tools/utils/__init__.py
@@ -20,7 +20,7 @@ class FlagTreeBackend:
     FlagTreeBackend(name="cambricon", url="https://github.com/Cambricon/triton-linalg.git",
                     tag="00f51c2e48a943922f86f03d58e29f514def646d"),
     FlagTreeBackend(name="flir", url="git@github.com:FlagTree/flir.git",
-                    tag="318ed13e396d4d0ed84773975c8507c6e3f0275d"),
+                    tag="243690fb1c8b9f032c1f938271414831a6cfe406"),
     FlagTreeBackend(
         name="ascend",
         url="https://gitee.com/ascend/triton-ascend.git",
1 change: 1 addition & 0 deletions third_party/aipu/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_subdirectory(include)
 add_subdirectory(lib)
+add_subdirectory(tools)
 
 add_triton_plugin(TritonAIPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_aipu.cc)
 target_include_directories(TritonAIPU PRIVATE ${CMAKE_SOURCE_DIR}/third_party/flir/include)
44 changes: 43 additions & 1 deletion third_party/aipu/backend/aipu_torch_dev.cpp
@@ -11,6 +11,8 @@
 
 #include <ATen/EmptyTensor.h>
 #include <ATen/InferSize.h>
+#include <ATen/core/GeneratorForPrivateuseone.h>
+#include <ATen/native/CPUFallback.h>
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/DistributionTemplates.h>
 #include <ATen/native/cpu/DistributionTemplates.h>
@@ -228,9 +230,29 @@ Tensor aipu_copy_from(const Tensor &self, const Tensor &dst,
   auto status = aipu_memcpy(aipu_ctx_, dst.data_ptr(), self.data_ptr(),
                             self.nbytes(), kind);
   AIPU_DRIVER_HANDLE_ERROR(status);
-  return self;
+  return dst;
 }
+Tensor aipu_copy_from_and_resize(const Tensor &self, const Tensor &dst) {
+  if (self.sizes() != dst.sizes()) {
+    auto new_dst =
+        custom_empty_symint(self.sizes(), self.scalar_type(), c10::nullopt,
+                            c10::nullopt, c10::nullopt, c10::nullopt);
+    auto kind = AIPU_MEMCPY_HOST_TO_DEVICE;
+    if (StrStartsWith(self.device().str(), "aipu")) {
+      kind = AIPU_MEMCPY_DEVICE_TO_HOST;
+      if (StrStartsWith(dst.device().str(), "aipu")) {
+        kind = AIPU_MEMCPY_DEVICE_TO_DEVICE;
+      }
+    }
+    auto aipu_ctx_ = AIPUAllocator::aipu_ctx_;
+    auto status = aipu_memcpy(aipu_ctx_, new_dst.data_ptr(), self.data_ptr(),
+                              self.nbytes(), kind);
+    AIPU_DRIVER_HANDLE_ERROR(status);
+
+    return new_dst;
+  }
+  return aipu_copy_from(self, dst, false);
+}
 template <template <typename> class RND>
 Tensor &random_kernel(Tensor &self, double cond1, double cond2,
                       c10::optional<Generator> gen) {
@@ -323,12 +345,30 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
   m.impl("aten::uniform_", &random_kernel<uniform_real_distribution>);
   m.impl("aten::normal_", &random_kernel<normal_distribution>);
   m.impl("aten::_copy_from", &aipu_copy_from);
+  m.impl("aten::_copy_from_and_resize", &aipu_copy_from_and_resize);
   m.impl("aten::random_.from",
          &random_from_to_kernel<uniform_int_from_to_distribution>);
   m.impl("aten::_local_scalar_dense", &_local_scalar_dense_aipu);
   m.impl("aten::fill_.Scalar", &fill_scalar_aipu);
 }
 
+void custom_cpu_fallback(const c10::OperatorHandle &op,
+                         torch::jit::Stack *stack) {
+  at::native::cpu_fallback(op, stack);
+}
+
+TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
+  m.fallback(torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
+}
+
+at::Generator make_custom_generator(c10::DeviceIndex device_index) {
+  return at::detail::getDefaultCPUGenerator();
+}
+
+REGISTER_GENERATOR_PRIVATEUSE1(make_custom_generator)
+
+C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::AIPUGuardImpl)
+
 // Register the autograd dispatch key for operators that have no dispatches
 TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
   m.impl("isfinite", torch::autograd::autogradNotImplementedFallback());
@@ -352,6 +392,7 @@ struct _DeviceGuard {
 
 struct _Device {
   _Device(c10::Device device) { idx = device.index(); }
+  _Device(int index) { idx = index; }
 
   int idx = 0;
   int prev_idx = -1;
@@ -375,6 +416,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   py::class_<_Device>(m, "device", py::module_local())
       .def(py::init(
           [](c10::Device device) { return std::make_unique<_Device>(device); }))
+      .def(py::init([](int index) { return std::make_unique<_Device>(index); }))
       .def("__enter__", [](_Device &self) { ; })
       .def("__exit__",
            [](_Device &self, pybind11::object type, pybind11::object value,
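Taken together, these additions give the PrivateUse1 ("aipu") backend a boxed CPU fallback, a default random generator, a device guard, and a resize-aware copy that the fallback path uses to write results back into device tensors. A rough sketch of how this surface is driven from Python follows; the extension module name and the rename call are illustrative assumptions, not part of this diff:

import torch

# Assumption: importing the compiled extension runs the TORCH_LIBRARY_IMPL /
# REGISTER_GENERATOR_PRIVATEUSE1 registrations shown above.
# import aipu_torch_dev  # hypothetical module name
torch.utils.rename_privateuse1_backend("aipu")   # expose PrivateUse1 as "aipu"

x = torch.arange(16, dtype=torch.float32).reshape(4, 4)
x_dev = x.to("aipu")     # host -> device copy dispatched to aten::_copy_from
x_dev.uniform_()         # served by the registered uniform_ kernel

# An op with no PrivateUse1 kernel is boxed and routed through
# at::native::cpu_fallback; its CPU results are copied back to the device via
# aten::_copy_from_and_resize, which allocates a correctly sized buffer when
# the destination shape does not match.
y = torch.sort(x_dev).values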
6 changes: 5 additions & 1 deletion third_party/aipu/backend/analysis/__init__.py
@@ -1,3 +1,7 @@
 from .determine_vfactor import determine_vectorization_factor
+from .get_linalg_generic_size import get_linalg_generic_size
 
-__all__ = ["determine_vectorization_factor"]
+__all__ = [
+    "determine_vectorization_factor",
+    "get_linalg_generic_size",
+]
16 changes: 12 additions & 4 deletions third_party/aipu/backend/analysis/determine_vfactor.py
@@ -16,15 +16,23 @@ def determine_vectorization_factor(module, target_bitwidth=256, debug=False):
     """
     min_width = target_bitwidth
 
+    def _get_width(value):
+        elem_type = (value.type.element_type if hasattr(value.type, 'element_type') else value.type)
+        elem_width = 32 if isinstance(elem_type, ir.IndexType) else elem_type.width
+        # Skip bool (i1) element types
+        if elem_width == 1:
+            return None
+        return elem_width
+
     def walk_callback(op):
         nonlocal min_width
         if op.name == "affine.for":
             all_ops = (_op for region in op.regions for block in region.blocks for _op in block.operations)
             for _op in all_ops:
-                for result in _op.results:
-                    elem_type = (result.type.element_type if hasattr(result.type, 'element_type') else result.type)
-                    elem_width = 32 if isinstance(elem_type, ir.IndexType) else elem_type.width
-                    min_width = min(min_width, elem_width)
+                items = _op.results or _op.operands
+                for item in items:
+                    if width := _get_width(item):
+                        min_width = min(min_width, width)
 
         return ir.WalkResult.ADVANCE
 
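The reworked walk now falls back to an op's operands when it produces no results (e.g. affine.store) and ignores i1/bool element types, so min_width ends up as the narrowest non-bool element width seen inside affine.for bodies. A rough usage sketch, assuming the bundled MLIR Python bindings can parse standard affine/arith IR; the textual IR and the import path are illustrative only:

from mlir import ir

# Illustrative import; in-tree the function is re-exported from backend/analysis.
from determine_vfactor import determine_vectorization_factor

asm = """
module {
  func.func @axpy(%a: memref<1024xf16>, %b: memref<1024xf16>) {
    affine.for %i = 0 to 1024 {
      %x = affine.load %a[%i] : memref<1024xf16>
      %y = affine.load %b[%i] : memref<1024xf16>
      %s = arith.addf %x, %y : f16
      affine.store %s, %b[%i] : memref<1024xf16>
    }
    return
  }
}
"""

with ir.Context():
    module = ir.Module.parse(asm)
    # The loop body only touches f16 values (the affine.store contributes its
    # operands, including the memref element type), so min_width becomes 16.
    determine_vectorization_factor(module, target_bitwidth=256)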
25 changes: 25 additions & 0 deletions third_party/aipu/backend/analysis/get_linalg_generic_size.py
@@ -0,0 +1,25 @@
+from mlir import ir
+
+
+def get_linalg_generic_size(module):
+    """
+    Count the linalg.generic operations in the module.
+
+    Args:
+        module: The Triton module to analyze.
+
+    Returns:
+        int: The number of linalg.generic operations found.
+    """
+    size = 0
+
+    def walk_callback(op):
+        nonlocal size
+        if op.name == "linalg.generic":
+            size += 1
+
+        return ir.WalkResult.ADVANCE
+
+    module.operation.walk(walk_callback, ir.WalkOrder.PRE_ORDER)
+
+    return size
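A small sketch of how the counter might be used, again assuming the bundled MLIR bindings accept standard linalg syntax; the IR and import path are illustrative, not taken from this PR:

from mlir import ir

# Illustrative import; in-tree it is re-exported from backend/analysis.
from get_linalg_generic_size import get_linalg_generic_size

asm = """
module {
  func.func @add(%a: memref<8xf32>, %b: memref<8xf32>, %c: memref<8xf32>) {
    linalg.generic
        {indexing_maps = [affine_map<(d0) -> (d0)>,
                          affine_map<(d0) -> (d0)>,
                          affine_map<(d0) -> (d0)>],
         iterator_types = ["parallel"]}
        ins(%a, %b : memref<8xf32>, memref<8xf32>)
        outs(%c : memref<8xf32>) {
      ^bb0(%x: f32, %y: f32, %out: f32):
        %s = arith.addf %x, %y : f32
        linalg.yield %s : f32
    }
    return
  }
}
"""

with ir.Context():
    module = ir.Module.parse(asm)
    assert get_linalg_generic_size(module) == 1   # one linalg.generic op in the module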