diff --git a/python/setup_tools/utils/__init__.py b/python/setup_tools/utils/__init__.py index dd4f98db3..103b96178 100644 --- a/python/setup_tools/utils/__init__.py +++ b/python/setup_tools/utils/__init__.py @@ -20,7 +20,7 @@ class FlagTreeBackend: FlagTreeBackend(name="cambricon", url="https://github.com/Cambricon/triton-linalg.git", tag="00f51c2e48a943922f86f03d58e29f514def646d"), FlagTreeBackend(name="flir", url="git@github.com:FlagTree/flir.git", - tag="318ed13e396d4d0ed84773975c8507c6e3f0275d"), + tag="243690fb1c8b9f032c1f938271414831a6cfe406"), FlagTreeBackend( name="ascend", url="https://gitee.com/ascend/triton-ascend.git", diff --git a/third_party/aipu/CMakeLists.txt b/third_party/aipu/CMakeLists.txt index 2ab8d44cb..691d3e393 100644 --- a/third_party/aipu/CMakeLists.txt +++ b/third_party/aipu/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(include) add_subdirectory(lib) +add_subdirectory(tools) add_triton_plugin(TritonAIPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_aipu.cc) target_include_directories(TritonAIPU PRIVATE ${CMAKE_SOURCE_DIR}/third_party/flir/include) diff --git a/third_party/aipu/backend/aipu_torch_dev.cpp b/third_party/aipu/backend/aipu_torch_dev.cpp index dc5dfc6b8..31036a944 100644 --- a/third_party/aipu/backend/aipu_torch_dev.cpp +++ b/third_party/aipu/backend/aipu_torch_dev.cpp @@ -11,6 +11,8 @@ #include #include +#include +#include #include #include #include @@ -228,9 +230,29 @@ Tensor aipu_copy_from(const Tensor &self, const Tensor &dst, auto status = aipu_memcpy(aipu_ctx_, dst.data_ptr(), self.data_ptr(), self.nbytes(), kind); AIPU_DRIVER_HANDLE_ERROR(status); - return self; + return dst; } +Tensor aipu_copy_from_and_resize(const Tensor &self, const Tensor &dst) { + if (self.sizes() != dst.sizes()) { + auto new_dst = + custom_empty_symint(self.sizes(), self.scalar_type(), c10::nullopt, + c10::nullopt, c10::nullopt, c10::nullopt); + auto kind = AIPU_MEMCPY_HOST_TO_DEVICE; + if (StrStartsWith(self.device().str(), "aipu")) { + kind = AIPU_MEMCPY_DEVICE_TO_HOST; + if (StrStartsWith(dst.device().str(), "aipu")) { + kind = AIPU_MEMCPY_DEVICE_TO_DEVICE; + } + } + auto aipu_ctx_ = AIPUAllocator::aipu_ctx_; + auto status = aipu_memcpy(aipu_ctx_, new_dst.data_ptr(), self.data_ptr(), + self.nbytes(), kind); + AIPU_DRIVER_HANDLE_ERROR(status); + return new_dst; + } + return aipu_copy_from(self, dst, false); +} template