From 42c1536f1cc68c99a0c5be67f9a095a341aa02ab Mon Sep 17 00:00:00 2001 From: huaweiZJX <125643694+huaweiZJX@users.noreply.github.com> Date: Wed, 29 Nov 2023 17:24:28 +0800 Subject: [PATCH 1/5] add op. --- docs/en/understand_mmcv/ops.md | 8 +- docs/zh_cn/understand_mmcv/ops.md | 8 +- mmcv/ops/chamfer_distance.py | 8 +- mmcv/ops/csrc/common/pytorch_npu_helper.hpp | 2 +- mmcv/ops/csrc/common/pytorch_npu_util.hpp | 585 ++++++++++++++++++ .../csrc/pytorch/npu/chamfer_distance_npu.cpp | 39 ++ mmcv/ops/csrc/pytorch/npu/common_util.h | 14 + mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 18 +- .../pytorch/npu/fused_bias_leakyrelu_npu.cpp | 4 +- .../csrc/pytorch/npu/gather_points_npu.cpp | 6 +- mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp | 7 +- .../pytorch/npu/roi_align_rotated_npu.cpp | 69 +++ .../pytorch/npu/rotated_feature_align_npu.cpp | 52 ++ .../csrc/pytorch/npu/stack_ball_query_npu.cpp | 23 + .../pytorch/npu/three_interpolate_npu.cpp | 32 +- setup.py | 13 +- tests/test_ops/test_ball_query.py | 49 +- tests/test_ops/test_chamfer_distance.py | 107 ++-- tests/test_ops/test_rotated_feature_align.py | 6 +- 19 files changed, 959 insertions(+), 91 deletions(-) create mode 100644 mmcv/ops/csrc/common/pytorch_npu_util.hpp create mode 100644 mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp create mode 100644 mmcv/ops/csrc/pytorch/npu/common_util.h create mode 100644 mmcv/ops/csrc/pytorch/npu/roi_align_rotated_npu.cpp create mode 100644 mmcv/ops/csrc/pytorch/npu/rotated_feature_align_npu.cpp create mode 100644 mmcv/ops/csrc/pytorch/npu/stack_ball_query_npu.cpp diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index 265327bc5c..259e1ec6ad 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -6,13 +6,13 @@ We implement common ops used in detection, segmentation, etc. | ---------------------------- | --- | ---- | --- | --- | ------ | | ActiveRotatedFilter | √ | √ | | | √ | | AssignScoreWithK | | √ | | | | -| BallQuery | | √ | √ | | | +| BallQuery | | √ | √ | | √ | | BBoxOverlaps | | √ | √ | √ | √ | | BorderAlign | | √ | | | | | BoxIouRotated | √ | √ | √ | | √ | | BoxIouQuadri | √ | √ | | | | | CARAFE | | √ | √ | | | -| ChamferDistance | | √ | | | | +| ChamferDistance | | √ | | | √ | | CrissCrossAttention | | √ | | | | | ContourExpand | √ | | | | | | ConvexIoU | | √ | | | | @@ -41,10 +41,10 @@ We implement common ops used in detection, segmentation, etc. 
| PointsInBoxes | √ | √ | | | | | PointsInPolygons | | √ | | | √ | | PSAMask | √ | √ | √ | | √ | -| RotatedFeatureAlign | √ | √ | √ | | | +| RotatedFeatureAlign | √ | √ | √ | | √ | | RoIPointPool3d | | √ | √ | | | | RoIPool | | √ | √ | | √ | -| RoIAlignRotated | √ | √ | √ | | | +| RoIAlignRotated | √ | √ | √ | | √ | | RiRoIAlignRotated | | √ | | | | | RoIAlign | √ | √ | √ | | √ | | RoIAwarePool3d | | √ | √ | | | diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index ba744daf11..f2d0382260 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -6,13 +6,13 @@ MMCV 提供了检测、分割等任务中常用的算子 | ---------------------------- | --- | ---- | --- | --- | ------ | | ActiveRotatedFilter | √ | √ | | | √ | | AssignScoreWithK | | √ | | | | -| BallQuery | | √ | √ | | | +| BallQuery | | √ | √ | | √ | | BBoxOverlaps | | √ | √ | √ | √ | | BorderAlign | | √ | | | | | BoxIouRotated | √ | √ | √ | | √ | | BoxIouQuadri | √ | √ | | | | | CARAFE | | √ | √ | | | -| ChamferDistance | | √ | | | | +| ChamferDistance | | √ | | | √ | | CrissCrossAttention | | √ | | | | | ContourExpand | √ | | | | | | ConvexIoU | | √ | | | | @@ -41,10 +41,10 @@ MMCV 提供了检测、分割等任务中常用的算子 | PointsInBoxes | √ | √ | | | | | PointsInPolygons | | √ | | | | | PSAMask | √ | √ | √ | | √ | -| RotatedFeatureAlign | √ | √ | √ | | | +| RotatedFeatureAlign | √ | √ | √ | | √ | | RoIPointPool3d | | √ | √ | | | | RoIPool | | √ | √ | | √ | -| RoIAlignRotated | √ | √ | √ | | | +| RoIAlignRotated | √ | √ | √ | | √ | | RiRoIAlignRotated | | √ | | | | | RoIAlign | √ | √ | √ | | √ | | RoIAwarePool3d | | √ | √ | | | diff --git a/mmcv/ops/chamfer_distance.py b/mmcv/ops/chamfer_distance.py index 1f908a5bbc..d95bd47747 100644 --- a/mmcv/ops/chamfer_distance.py +++ b/mmcv/ops/chamfer_distance.py @@ -44,8 +44,8 @@ def forward(ctx, xyz1: Tensor, xyz2: Tensor) -> Sequence[Tensor]: xyz1 = xyz1.contiguous() xyz2 = xyz2.contiguous() - dist1 = torch.zeros(batch_size, n).to(device) - dist2 = torch.zeros(batch_size, m).to(device) + dist1 = torch.zeros(batch_size, n).type(xyz1.dtype).to(device) + dist2 = torch.zeros(batch_size, m).type(xyz2.dtype).to(device) idx1 = torch.zeros(batch_size, n).type(torch.IntTensor).to(device) idx2 = torch.zeros(batch_size, m).type(torch.IntTensor).to(device) @@ -81,8 +81,8 @@ def backward(ctx, device = grad_dist1.device grad_dist1 = grad_dist1.contiguous() grad_dist2 = grad_dist2.contiguous() - grad_xyz1 = torch.zeros(xyz1.size()).to(device) - grad_xyz2 = torch.zeros(xyz2.size()).to(device) + grad_xyz1 = torch.zeros(xyz1.size()).type(xyz1.dtype).to(device) + grad_xyz2 = torch.zeros(xyz2.size()).type(xyz2.dtype).to(device) ext_module.chamfer_distance_backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2, grad_xyz1, diff --git a/mmcv/ops/csrc/common/pytorch_npu_helper.hpp b/mmcv/ops/csrc/common/pytorch_npu_helper.hpp index 073d6b38c3..82e80276ed 100644 --- a/mmcv/ops/csrc/common/pytorch_npu_helper.hpp +++ b/mmcv/ops/csrc/common/pytorch_npu_helper.hpp @@ -18,12 +18,12 @@ #ifndef PYTORCH_NPU_HELPER_HPP_ #define PYTORCH_NPU_HELPER_HPP_ -#include #include #include #include "pytorch_cpp_helper.hpp" #include "pytorch_device_registry.hpp" +#include "pytorch_npu_util.hpp" #define NPU_NAME_SPACE at_npu::native diff --git a/mmcv/ops/csrc/common/pytorch_npu_util.hpp b/mmcv/ops/csrc/common/pytorch_npu_util.hpp new file mode 100644 index 0000000000..8c26a934f0 --- /dev/null +++ b/mmcv/ops/csrc/common/pytorch_npu_util.hpp @@ -0,0 +1,585 @@ 
+/****************************************************************************** + * Copyright (c) 2022 Huawei Technologies Co., Ltd + * All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#ifndef MMCV_OPS_CSRC_COMMON_PYTORCH_NPU_UTIL_HPP_ +#define MMCV_OPS_CSRC_COMMON_PYTORCH_NPU_UTIL_HPP_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "torch_npu/csrc/core/npu/NPUStream.h" +#include "torch_npu/csrc/framework/OpCommand.h" +#include "torch_npu/csrc/framework/interface/EnvVariables.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/OpPreparation.h" + +#define NPU_NAME_SPACE at_npu::native + +typedef struct aclOpExecutor aclOpExecutor; +typedef struct aclTensor aclTensor; +typedef struct aclScalar aclScalar; +typedef struct aclIntArray aclIntArray; +typedef struct aclFloatArray aclFloatArray; +typedef struct aclBoolArray aclBoolArray; +typedef struct aclTensorList aclTensorList; + +typedef aclTensor *(*_aclCreateTensor)( + const int64_t *view_dims, uint64_t view_dims_num, aclDataType data_type, + const int64_t *stride, int64_t offset, aclFormat format, + const int64_t *storage_dims, uint64_t storage_dims_num, void *tensor_data); +typedef aclScalar *(*_aclCreateScalar)(void *value, aclDataType data_type); +typedef aclIntArray *(*_aclCreateIntArray)(const int64_t *value, uint64_t size); +typedef aclFloatArray *(*_aclCreateFloatArray)(const float *value, + uint64_t size); +typedef aclBoolArray *(*_aclCreateBoolArray)(const bool *value, uint64_t size); +typedef aclTensorList *(*_aclCreateTensorList)(const aclTensor *const *value, + uint64_t size); + +typedef int (*_aclDestroyTensor)(const aclTensor *tensor); +typedef int (*_aclDestroyScalar)(const aclScalar *scalar); +typedef int (*_aclDestroyIntArray)(const aclIntArray *array); +typedef int (*_aclDestroyFloatArray)(const aclFloatArray *array); +typedef int (*_aclDestroyBoolArray)(const aclBoolArray *array); +typedef int (*_aclDestroyTensorList)(const aclTensorList *array); + +constexpr int kHashBufSize = 8192; +constexpr int kHashBufMaxSize = kHashBufSize + 1024; +extern thread_local char g_hashBuf[kHashBufSize]; +extern thread_local int g_hashOffset; + +#ifdef MMCV_WITH_XLA +#define DEVICE_TYPE at_npu::key::NativeDeviceType +#else +#define DEVICE_TYPE c10::DeviceType::PrivateUse1 +#endif + +#define AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(_) \ + _(at::ScalarType::Byte, ACL_UINT8) \ + _(at::ScalarType::Char, ACL_INT8) \ + _(at::ScalarType::Short, ACL_INT16) \ + _(at::ScalarType::Int, ACL_INT32) \ + _(at::ScalarType::Long, ACL_INT64) \ + _(at::ScalarType::Half, ACL_FLOAT16) \ + _(at::ScalarType::Float, ACL_FLOAT) \ + _(at::ScalarType::Double, ACL_DOUBLE) \ + _(at::ScalarType::ComplexHalf, 
ACL_DT_UNDEFINED) \ + _(at::ScalarType::ComplexFloat, ACL_COMPLEX64) \ + _(at::ScalarType::ComplexDouble, ACL_COMPLEX128) \ + _(at::ScalarType::Bool, ACL_BOOL) \ + _(at::ScalarType::QInt8, ACL_DT_UNDEFINED) \ + _(at::ScalarType::QUInt8, ACL_DT_UNDEFINED) \ + _(at::ScalarType::QInt32, ACL_DT_UNDEFINED) \ + _(at::ScalarType::BFloat16, ACL_BF16) \ + _(at::ScalarType::QUInt4x2, ACL_DT_UNDEFINED) \ + _(at::ScalarType::QUInt2x4, ACL_DT_UNDEFINED) \ + _(at::ScalarType::Undefined, ACL_DT_UNDEFINED) \ + _(at::ScalarType::NumOptions, ACL_DT_UNDEFINED) + +constexpr aclDataType kATenScalarTypeToAclDataTypeTable + [static_cast(at::ScalarType::NumOptions) + 1] = { +#define DEFINE_ENUM(_1, n) n, + AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(DEFINE_ENUM) +#undef DEFINE_ENUM +}; + +#define GET_OP_API_FUNC(apiName) \ + reinterpret_cast<_##apiName>(GetOpApiFuncAddr(#apiName)) + +#define MEMCPY_TO_BUF(data_expression, size_expression) \ + if (g_hashOffset + (size_expression) > kHashBufSize) { \ + g_hashOffset = kHashBufMaxSize; \ + return; \ + } \ + memcpy(g_hashBuf + g_hashOffset, data_expression, size_expression); \ + g_hashOffset += size_expression; + +inline const char *GetOpApiLibName(void) { return "libopapi.so"; } + +inline const char *GetCustOpApiLibName(void) { return "libcust_opapi.so"; } + +inline void *GetOpApiFuncAddrInLib(void *handler, const char *libName, + const char *apiName) { + auto funcAddr = dlsym(handler, apiName); + if (funcAddr == nullptr) { + ASCEND_LOGW("dlsym %s from %s failed, error:%s.", apiName, libName, + dlerror()); + } + return funcAddr; +} + +inline void *GetOpApiLibHandler(const char *libName) { + auto handler = dlopen(libName, RTLD_LAZY); + if (handler == nullptr) { + ASCEND_LOGW("dlopen %s failed, error:%s.", libName, dlerror()); + } + return handler; +} + +inline void *GetOpApiFuncAddr(const char *apiName) { + static auto custOpApiHandler = GetOpApiLibHandler(GetCustOpApiLibName()); + if (custOpApiHandler != nullptr) { + auto funcAddr = + GetOpApiFuncAddrInLib(custOpApiHandler, GetCustOpApiLibName(), apiName); + if (funcAddr != nullptr) { + return funcAddr; + } + } + + static auto opApiHandler = GetOpApiLibHandler(GetOpApiLibName()); + if (opApiHandler == nullptr) { + return nullptr; + } + return GetOpApiFuncAddrInLib(opApiHandler, GetOpApiLibName(), apiName); +} + +inline c10::Scalar ConvertTensorToScalar(const at::Tensor &tensor) { + c10::Scalar expScalar; + const at::Tensor *aclInput = &tensor; + if (aclInput->scalar_type() == at::ScalarType::Double) { + double value = *(double *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::Long) { + int64_t value = *(int64_t *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::Float) { + float value = *(float *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::Int) { + int value = *(int *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::Half) { + c10::Half value = *(c10::Half *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::Bool) { + int8_t value = *(int8_t *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::ComplexDouble) { + c10::complex value = *(c10::complex 
*)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::ComplexFloat) { + c10::complex value = *(c10::complex *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::BFloat16) { + c10::BFloat16 value = *(c10::BFloat16 *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } + return expScalar; +} + +inline at::Tensor CopyTensorHostToDevice(const at::Tensor &cpu_tensor) { + at::Tensor cpuPinMemTensor = cpu_tensor.pin_memory(); + int deviceIndex = 0; + return cpuPinMemTensor.to( + c10::Device(DEVICE_TYPE, deviceIndex), + cpuPinMemTensor.scalar_type(), true, true); +} + +inline at::Tensor CopyScalarToDevice(const c10::Scalar &cpu_scalar, + at::ScalarType scalar_data_type) { + return CopyTensorHostToDevice( + scalar_to_tensor(cpu_scalar).to(scalar_data_type)); +} + +inline aclTensor *ConvertType(const at::Tensor &at_tensor) { + static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor); + if (aclCreateTensor == nullptr) { + return nullptr; + } + + if (!at_tensor.defined()) { + return nullptr; + } + at::ScalarType scalar_data_type = at_tensor.scalar_type(); + aclDataType acl_data_type = + kATenScalarTypeToAclDataTypeTable[static_cast(scalar_data_type)]; + TORCH_CHECK( + acl_data_type != ACL_DT_UNDEFINED, + std::string(c10::toString(scalar_data_type)) + " has not been supported") + c10::SmallVector storageDims; + // if acl_data_type is ACL_STRING, storageDims is empty. + auto itemsize = at_tensor.itemsize(); + if (itemsize == 0) { + AT_ERROR("When ConvertType, tensor item size of cannot be zero."); + return nullptr; + } + if (acl_data_type != ACL_STRING) { + storageDims.push_back(at_tensor.storage().nbytes() / itemsize); + } + + const auto dimNum = at_tensor.sizes().size(); + aclFormat format = ACL_FORMAT_ND; + switch (dimNum) { + case 3: + format = ACL_FORMAT_NCL; + break; + case 4: + format = ACL_FORMAT_NCHW; + break; + case 5: + format = ACL_FORMAT_NCDHW; + break; + default: + format = ACL_FORMAT_ND; + } + + if (at_tensor.unsafeGetTensorImpl()->is_wrapped_number()) { + c10::Scalar expScalar = ConvertTensorToScalar(at_tensor); + at::Tensor aclInput = CopyScalarToDevice(expScalar, scalar_data_type); + return aclCreateTensor( + aclInput.sizes().data(), aclInput.sizes().size(), acl_data_type, + aclInput.strides().data(), aclInput.storage_offset(), format, + storageDims.data(), storageDims.size(), const_cast(aclInput.storage().data())); + } + + auto acl_tensor = aclCreateTensor( + at_tensor.sizes().data(), at_tensor.sizes().size(), acl_data_type, + at_tensor.strides().data(), at_tensor.storage_offset(), format, + storageDims.data(), storageDims.size(), const_cast(at_tensor.storage().data())); + return acl_tensor; +} + +inline aclScalar *ConvertType(const at::Scalar &at_scalar) { + static const auto aclCreateScalar = GET_OP_API_FUNC(aclCreateScalar); + if (aclCreateScalar == nullptr) { + return nullptr; + } + + at::ScalarType scalar_data_type = at_scalar.type(); + aclDataType acl_data_type = + kATenScalarTypeToAclDataTypeTable[static_cast(scalar_data_type)]; + TORCH_CHECK( + acl_data_type != ACL_DT_UNDEFINED, + std::string(c10::toString(scalar_data_type)) + " has not been supported") + aclScalar *acl_scalar = nullptr; + switch (scalar_data_type) { + case at::ScalarType::Double: { + double value = at_scalar.toDouble(); + acl_scalar = aclCreateScalar(&value, acl_data_type); + break; + } + case at::ScalarType::Long: { + 
int64_t value = at_scalar.toLong(); + acl_scalar = aclCreateScalar(&value, acl_data_type); + break; + } + case at::ScalarType::Bool: { + bool value = at_scalar.toBool(); + acl_scalar = aclCreateScalar(&value, acl_data_type); + break; + } + case at::ScalarType::ComplexDouble: { + auto value = at_scalar.toComplexDouble(); + acl_scalar = aclCreateScalar(&value, acl_data_type); + break; + } + default: + acl_scalar = nullptr; + break; + } + return acl_scalar; +} + +inline aclIntArray *ConvertType(const at::IntArrayRef &at_array) { + static const auto aclCreateIntArray = GET_OP_API_FUNC(aclCreateIntArray); + if (aclCreateIntArray == nullptr) { + return nullptr; + } + auto array = aclCreateIntArray(at_array.data(), at_array.size()); + return array; +} + +template +inline aclBoolArray *ConvertType(const std::array &value) { + static const auto aclCreateBoolArray = GET_OP_API_FUNC(aclCreateBoolArray); + if (aclCreateBoolArray == nullptr) { + return nullptr; + } + + auto array = aclCreateBoolArray(value.data(), value.size()); + return array; +} + +inline aclBoolArray *ConvertType(const at::ArrayRef &value) { + static const auto aclCreateBoolArray = GET_OP_API_FUNC(aclCreateBoolArray); + if (aclCreateBoolArray == nullptr) { + return nullptr; + } + + auto array = aclCreateBoolArray(value.data(), value.size()); + return array; +} + +inline aclTensorList *ConvertType(const at::TensorList &at_tensor_list) { + static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList); + if (aclCreateTensorList == nullptr) { + return nullptr; + } + + std::vector tensor_list(at_tensor_list.size()); + for (size_t i = 0; i < at_tensor_list.size(); i++) { + tensor_list[i] = ConvertType(at_tensor_list[i]); + } + auto acl_tensor_list = + aclCreateTensorList(tensor_list.data(), tensor_list.size()); + return acl_tensor_list; +} + +inline aclTensor *ConvertType(const c10::optional &opt_tensor) { + if (opt_tensor.has_value() && opt_tensor.value().defined()) { + return ConvertType(opt_tensor.value()); + } + return nullptr; +} + +inline aclIntArray *ConvertType( + const c10::optional &opt_array) { + if (opt_array.has_value()) { + return ConvertType(opt_array.value()); + } + return nullptr; +} + +inline aclScalar *ConvertType(const c10::optional &opt_scalar) { + if (opt_scalar.has_value()) { + return ConvertType(opt_scalar.value()); + } + return nullptr; +} + +inline aclDataType ConvertType(const at::ScalarType scalarType) { + return kATenScalarTypeToAclDataTypeTable[static_cast(scalarType)]; +} + +template +T ConvertType(T value) { + return value; +} + +template +auto ConvertToOpApiFunc(const Tuple ¶ms, void *opApiAddr, + std::index_sequence) { + typedef int (*OpApiFunc)( + typename std::decay(params))>::type...); + auto func = reinterpret_cast(opApiAddr); + return func; +} + +template +auto ConvertToOpApiFunc(const Tuple ¶ms, void *opApiAddr) { + static constexpr auto size = std::tuple_size::value; + return ConvertToOpApiFunc(params, opApiAddr, + std::make_index_sequence{}); +} + +inline void Release(aclTensor *p) { + static const auto aclDestroyTensor = GET_OP_API_FUNC(aclDestroyTensor); + if (aclDestroyTensor == nullptr) { + return; + } + aclDestroyTensor(p); +} + +inline void Release(aclScalar *p) { + static const auto aclDestroyScalar = GET_OP_API_FUNC(aclDestroyScalar); + if (aclDestroyScalar == nullptr) { + return; + } + aclDestroyScalar(p); +} + +inline void Release(aclIntArray *p) { + static const auto aclDestroyIntArray = GET_OP_API_FUNC(aclDestroyIntArray); + if (aclDestroyIntArray == nullptr) { + return; 
+ } + + aclDestroyIntArray(p); +} + +inline void Release(aclBoolArray *p) { + static const auto aclDestroyBoolArray = GET_OP_API_FUNC(aclDestroyBoolArray); + if (aclDestroyBoolArray == nullptr) { + return; + } + + aclDestroyBoolArray(p); +} + +inline void Release(aclTensorList *p) { + static const auto aclDestroyTensorList = + GET_OP_API_FUNC(aclDestroyTensorList); + if (aclDestroyTensorList == nullptr) { + return; + } + + aclDestroyTensorList(p); +} + +template +void Release(T value) { + (void)value; +} + +template +void CallRelease(Tuple t, std::index_sequence) { + (void)std::initializer_list{(Release(std::get(t)), 0)...}; +} + +template +void ReleaseConvertTypes(Tuple &t) { + static constexpr auto size = std::tuple_size::value; + CallRelease(t, std::make_index_sequence{}); +} + +template +constexpr auto ConvertTypes(Ts &... args) { + return std::make_tuple(ConvertType(args)...); +} + +template +auto call(Function f, Tuple t, std::index_sequence) { + return f(std::get(t)...); +} + +template +auto call(Function f, Tuple t) { + static constexpr auto size = std::tuple_size::value; + return call(f, t, std::make_index_sequence{}); +} + +template +void AddParamToBuf(const std::array &value) { + MEMCPY_TO_BUF(value.data(), value.size() * sizeof(bool)); +} + +template +void AddParamToBuf(const T &value) { + MEMCPY_TO_BUF(&value, sizeof(T)); +} + +void AddParamToBuf(const at::Tensor &); +void AddParamToBuf(const at::Scalar &); +void AddParamToBuf(const at::IntArrayRef &); +void AddParamToBuf(const at::ArrayRef &); +void AddParamToBuf(const at::TensorList &); +void AddParamToBuf(const c10::optional &); +void AddParamToBuf(const c10::optional &); +void AddParamToBuf(const c10::optional &); +void AddParamToBuf(const at::ScalarType); +void AddParamToBuf(const string &); +void AddParamToBuf(); + +template +void AddParamToBuf(const T &arg, Args &... args) { + AddParamToBuf(arg); + AddParamToBuf(args...); +} + +uint64_t CalcHashId(); +typedef int (*InitHugeMemThreadLocal)(void *, bool); +typedef void (*UnInitHugeMemThreadLocal)(void *, bool); +typedef void (*ReleaseHugeMem)(void *, bool); + +#define EXEC_NPU_CMD(aclnn_api, ...) 
\ + do { \ + static const auto getWorkspaceSizeFuncAddr = \ + GetOpApiFuncAddr(#aclnn_api "GetWorkspaceSize"); \ + static const auto opApiFuncAddr = GetOpApiFuncAddr(#aclnn_api); \ + static const auto initMemAddr = \ + GetOpApiFuncAddr("InitHugeMemThreadLocal"); \ + static const auto unInitMemAddr = \ + GetOpApiFuncAddr("UnInitHugeMemThreadLocal"); \ + static const auto releaseMemAddr = GetOpApiFuncAddr("ReleaseHugeMem"); \ + TORCH_CHECK( \ + getWorkspaceSizeFuncAddr != nullptr && opApiFuncAddr != nullptr, \ + #aclnn_api, " or ", #aclnn_api "GetWorkspaceSize", " not in ", \ + GetOpApiLibName(), ", or ", GetOpApiLibName(), "not found."); \ + auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); \ + uint64_t workspace_size = 0; \ + uint64_t *workspace_size_addr = &workspace_size; \ + aclOpExecutor *executor = nullptr; \ + aclOpExecutor **executor_addr = &executor; \ + InitHugeMemThreadLocal initMemFunc = \ + reinterpret_cast(initMemAddr); \ + UnInitHugeMemThreadLocal unInitMemFunc = \ + reinterpret_cast(unInitMemAddr); \ + if (initMemFunc) { \ + initMemFunc(nullptr, false); \ + } \ + auto converted_params = \ + ConvertTypes(__VA_ARGS__, workspace_size_addr, executor_addr); \ + static auto getWorkspaceSizeFunc = \ + ConvertToOpApiFunc(converted_params, getWorkspaceSizeFuncAddr); \ + auto workspace_status = call(getWorkspaceSizeFunc, converted_params); \ + TORCH_CHECK(workspace_status == 0, \ + "call " #aclnn_api " failed, detail:", aclGetRecentErrMsg()); \ + void *workspace_addr = nullptr; \ + if (workspace_size != 0) { \ + at::TensorOptions options = \ + at::TensorOptions(torch_npu::utils::get_npu_device_type()); \ + auto workspace_tensor = \ + at::empty({workspace_size}, options.dtype(kByte)); \ + workspace_addr = const_cast(workspace_tensor.storage().data()); \ + } \ + auto acl_call = [converted_params, workspace_addr, workspace_size, \ + acl_stream, executor]() -> int { \ + typedef int (*OpApiFunc)(void *, uint64_t, aclOpExecutor *, \ + const aclrtStream); \ + OpApiFunc opApiFunc = reinterpret_cast(opApiFuncAddr); \ + auto api_ret = \ + opApiFunc(workspace_addr, workspace_size, executor, acl_stream); \ + TORCH_CHECK(api_ret == 0, "call " #aclnn_api " failed, detail:", \ + aclGetRecentErrMsg()); \ + ReleaseConvertTypes(converted_params); \ + ReleaseHugeMem releaseMemFunc = \ + reinterpret_cast(releaseMemAddr); \ + if (releaseMemFunc) { \ + releaseMemFunc(nullptr, false); \ + } \ + return api_ret; \ + }; \ + at_npu::native::OpCommand cmd; \ + cmd.Name(#aclnn_api); \ + cmd.SetCustomHandler(acl_call); \ + cmd.Run(); \ + if (unInitMemFunc) { \ + unInitMemFunc(nullptr, false); \ + } \ + } while (false) + +#endif // MMCV_OPS_CSRC_COMMON_PYTORCH_NPU_UTIL_HPP_ diff --git a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp new file mode 100644 index 0000000000..800e636945 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp @@ -0,0 +1,39 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1, + Tensor dist2, Tensor idx1, Tensor idx2) { + at::Tensor xyz1 = at::ones_like(XYZ1); + at::Tensor xyz2 = at::ones_like(XYZ2); + xyz1 = XYZ1.transpose(1, 2).transpose(0, 1); + xyz2 = XYZ2.transpose(1, 2).transpose(0, 1); + OpCommand cmd; + cmd.Name("ChamferDistance") + .Input(xyz1) + .Input(xyz2) + .Output(dist1) + .Output(dist2) + .Output(idx1) + .Output(idx2) + .Run(); +} + +void 
chamfer_distance_backward_npu(Tensor xyz1, Tensor xyz2, + Tensor idx1, Tensor idx2, Tensor grad_dist1, Tensor grad_dist2, + Tensor grad_xyz1, Tensor grad_xyz2) { + EXEC_NPU_CMD(aclnnChamferDistanceBackward, xyz1, xyz2, idx1, idx2, + grad_dist1, grad_dist2, grad_xyz1, grad_xyz2); +} + +void chamfer_distance_forward_impl(Tensor XYZ1, Tensor XYZ2, Tensor dist1, + Tensor dist2, Tensor idx1, Tensor idx2); +REGISTER_NPU_IMPL(chamfer_distance_forward_impl, + chamfer_distance_forward_npu); + +void chamfer_distance_backward_impl(Tensor xyz1, Tensor xyz2, Tensor idx1, Tensor idx2, + Tensor grad_dist1, Tensor grad_dist2, + Tensor grad_xyz1, Tensor grad_xyz2); +REGISTER_NPU_IMPL(chamfer_distance_backward_impl, + chamfer_distance_backward_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/common_util.h b/mmcv/ops/csrc/pytorch/npu/common_util.h new file mode 100644 index 0000000000..5a303e8764 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/common_util.h @@ -0,0 +1,14 @@ +#ifndef MMCV_OPS_CSRC_COMMON__UTIL_HPP_ +#define MMCV_OPS_CSRC_COMMON__UTIL_HPP_ +const int SIZE = 8; + +c10::SmallVector array_to_vector(c10::IntArrayRef shape) { + c10::SmallVector shape_small_vec; + for (uint64_t i = 0; i < shape.size(); i++) { + shape_small_vec.emplace_back(shape[i]); + } + + return shape_small_vec; +} + +#endif // MMCV_OPS_CSRC_COMMON__UTIL_HPP_ diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp index b7c995a223..b39e24f2df 100644 --- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -1,5 +1,4 @@ #include "pytorch_npu_helper.hpp" - using namespace NPU_NAME_SPACE; using namespace std; @@ -100,7 +99,22 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, c10::SmallVector sizes = {n_batch, 1}; at::IntArrayRef offset = at::IntArrayRef(offsets); at::IntArrayRef size = at::IntArrayRef(sizes); - at_npu::native::custom_ops::npu_slice_out(op_output, offset, size, output); + at::IntArrayRef size_array = at::IntArrayRef(sizes); + c10::SmallVector offsetVec; + for (uint64_t i = 0; i < offset.size(); i++) { + offsetVec.emplace_back(offset[i]); + } + c10::SmallVector sizeVec; + for (uint64_t i = 0; i < size_array.size(); i++) { + sizeVec.emplace_back(size_array[i]); + } + OpCommand cmd2; + cmd2.Name("Slice") + .Input(op_output) + .Input(offsetVec) + .Input(sizeVec) + .Output(output) + .Run(); } void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, diff --git a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp index a3d44cacb2..071f1c4021 100644 --- a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp @@ -16,7 +16,9 @@ Tensor fused_bias_leakyrelu_npu(const Tensor &input, const Tensor &bias, auto input_size = input.sizes(); int input_length = input_size.size(); c10::SmallVector input_size_tmp; - input_size_tmp = array_to_small_vector(input_size); + for (uint64_t i = 0; i < input_size.size(); i++) { + input_size_tmp.emplace_back(input_size[i]); + } if (input_length > 1) { for (int i = 0; i < input_length; i++) { if (i != 1) { diff --git a/mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp b/mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp index b84fcfcac2..1035a36ae0 100644 --- a/mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp @@ -32,7 +32,11 @@ void gather_points_backward_npu(int b, int c, int n, int 
npoints, indices.unsqueeze_(0); } int64_t dim = 0; - at::SmallVector pad_size = array_to_small_vector(idx.sizes()); + auto shape = idx.sizes(); + c10::SmallVector pad_size; + for (uint64_t i = 0; i < shape.size(); i++) { + pad_size.emplace_back(shape[i]); + } at::Tensor trans_grad_points = grad_points.transpose(1, 2).contiguous(); at::Tensor grad_points_view = trans_grad_points.view( {trans_grad_points.sizes()[0] * trans_grad_points.sizes()[1], diff --git a/mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp b/mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp index f505b23e18..377563f755 100644 --- a/mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp @@ -41,8 +41,11 @@ void roi_align_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax_y, LOG(WARNING) << "The [aligned] attr in roi_align_grad op is false"; roi_end_mode = 0; } - c10::SmallVector xdiff_shape = - array_to_small_vector(grad_input.sizes()); + auto shape = grad_input.sizes(); + c10::SmallVector xdiff_shape; + for (uint64_t i = 0; i < shape.size(); i++) { + xdiff_shape.emplace_back(shape[i]); + } OpCommand cmd; cmd.Name("ROIAlignGrad") .Input(grad_output) diff --git a/mmcv/ops/csrc/pytorch/npu/roi_align_rotated_npu.cpp b/mmcv/ops/csrc/pytorch/npu/roi_align_rotated_npu.cpp new file mode 100644 index 0000000000..813678f8ca --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/roi_align_rotated_npu.cpp @@ -0,0 +1,69 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +void roi_align_rotated_forward_npu(Tensor input, Tensor rois, Tensor output, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + bool aligned, bool clockwise) { + int64_t aligned_height_64 = aligned_height; + int64_t aligned_width_64 = aligned_width; + int64_t sampling_ratio_64 = sampling_ratio; + OpCommand cmd; + cmd.Name("RoiAlignRotated") + .Input(input) + .Input(rois) + .Output(output) + .Attr("pooled_h", aligned_height_64) + .Attr("pooled_w", aligned_width_64) + .Attr("spatial_scale", spatial_scale) + .Attr("sampling_ratio", sampling_ratio_64) + .Attr("aligned", aligned) + .Attr("clockwise", clockwise) + .Run(); +} + +void roi_align_rotated_backward_npu(Tensor top_grad, Tensor rois, + Tensor bottom_grad, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, bool aligned, + bool clockwise) { + int64_t aligned_height_64 = aligned_height; + int64_t aligned_width_64 = aligned_width; + int64_t sampling_ratio_64 = sampling_ratio; + c10::SmallVector y_grad_shape; + auto shape = bottom_grad.sizes(); + for (uint64_t i = 0; i < shape.size(); i++) { + y_grad_shape.emplace_back(shape[i]); + } + OpCommand cmd; + cmd.Name("RoiAlignRotatedGrad") + .Input(top_grad) + .Input(rois) + .Output(bottom_grad) + .Attr("y_grad_shape", y_grad_shape) + .Attr("pooled_h", aligned_width_64) + .Attr("pooled_w", aligned_height_64) + .Attr("spatial_scale", spatial_scale) + .Attr("sampling_ratio", sampling_ratio_64) + .Attr("aligned", aligned) + .Attr("clockwise", clockwise) + .Run(); +} + +void roi_align_rotated_forward_impl(Tensor input, Tensor rois, Tensor output, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + bool aligned, bool clockwise); + +void roi_align_rotated_backward_impl(Tensor top_grad, Tensor rois, + Tensor bottom_grad, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, bool aligned, + bool clockwise); + +REGISTER_NPU_IMPL(roi_align_rotated_forward_impl, + 
roi_align_rotated_forward_npu); +REGISTER_NPU_IMPL(roi_align_rotated_backward_impl, + roi_align_rotated_backward_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/rotated_feature_align_npu.cpp b/mmcv/ops/csrc/pytorch/npu/rotated_feature_align_npu.cpp new file mode 100644 index 0000000000..21217d350b --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/rotated_feature_align_npu.cpp @@ -0,0 +1,52 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +void rotated_feature_align_forward_impl(const Tensor features, + const Tensor best_bboxes, + const float spatial_scale, + const int points, Tensor output); + +void rotated_feature_align_backward_impl(const Tensor top_grad, + const Tensor best_bboxes, + const float spatial_scale, + const int points, Tensor bottom_grad); + +void rotated_feature_align_forward_npu(const Tensor features, + const Tensor best_bboxes, + const float spatial_scale, + const int points, Tensor output) { + int64_t points_ = (int64_t)points; + at::Tensor best_bboxes_ = best_bboxes.transpose(2, 3).transpose(1, 2); + OpCommand cmd; + cmd.Name("RotatedFeatureAlign") + .Input(features) + .Input(best_bboxes_) + .Output(output) + .Attr("spatial_scale", spatial_scale) + .Attr("points", points_) + .Run(); +} + +void rotated_feature_align_backward_npu(const Tensor top_grad, + const Tensor best_bboxes, + const float spatial_scale, + const int points, Tensor bottom_grad) { + int64_t points_ = (int64_t)points; + at::Tensor best_bboxes_ = best_bboxes.transpose(2, 3).transpose(1, 2); + OpCommand cmd; + cmd.Name("RotatedFeatureAlignGrad") + .Input(top_grad) + .Input(best_bboxes_) + .Output(bottom_grad) + .Attr("spatial_scale", spatial_scale) + .Attr("points", points_) + .Run(); +} + +REGISTER_NPU_IMPL(rotated_feature_align_forward_impl, + rotated_feature_align_forward_npu); + +REGISTER_NPU_IMPL(rotated_feature_align_backward_impl, + rotated_feature_align_backward_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/stack_ball_query_npu.cpp b/mmcv/ops/csrc/pytorch/npu/stack_ball_query_npu.cpp new file mode 100644 index 0000000000..cd8c3ad8c9 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/stack_ball_query_npu.cpp @@ -0,0 +1,23 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +void stack_ball_query_forward_npu(float max_radius, int nsample, + const Tensor new_xyz, + const Tensor new_xyz_batch_cnt, + const Tensor xyz, const Tensor xyz_batch_cnt, + Tensor idx) { + at::Tensor xyz_transpose = xyz.transpose(0, 1).contiguous(); + double max_radius_double = double(max_radius); + EXEC_NPU_CMD(aclnnStackBallQuery, xyz_transpose, new_xyz, xyz_batch_cnt, + new_xyz_batch_cnt, max_radius_double, nsample, idx); +} + +void stack_ball_query_forward_impl(float max_radius, int nsample, + const Tensor new_xyz, + const Tensor new_xyz_batch_cnt, + const Tensor xyz, const Tensor xyz_batch_cnt, + Tensor idx); + +REGISTER_NPU_IMPL(stack_ball_query_forward_impl, stack_ball_query_forward_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp index 0f1b14e7dc..07a5fed04b 100644 --- a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp @@ -1,4 +1,6 @@ #include "pytorch_npu_helper.hpp" +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" using namespace NPU_NAME_SPACE; using namespace std; @@ -6,6 +8,10 @@ using namespace std; void three_interpolate_forward_npu(int b, int 
c, int m, int n, const Tensor points, const Tensor idx, const Tensor weight, Tensor out) { + auto originDtype = points.scalar_type(); + TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf), + "three_interpolate_forward ascend only support fp32 and fp16."); + auto point_c_trans = points.transpose(1, 2); OpCommand cmd; @@ -17,13 +23,37 @@ void three_interpolate_forward_npu(int b, int c, int m, int n, .Run(); auto output = out.view({b, n, c}).transpose(1, 2); - auto res = NpuUtils::format_contiguous(output); + auto res = output.contiguous(); out.copy_(res); } +void three_interpolate_backward_npu(int b, int c, int n, int m, + const Tensor grad_out, const Tensor idx, + const Tensor weight, Tensor grad_points) { + auto originDtype = grad_out.scalar_type(); + TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf), + "three_interpolate_backward ascend only support fp32 and fp16."); + + auto grad_x = at::unsqueeze(grad_out, 3); + auto grad_y = at::unsqueeze(grad_points, 3); + + EXEC_NPU_CMD(aclnnThreeInterpolateBackward, grad_x, idx, weight, m, grad_y); + + auto output = at::squeeze(grad_y, 3); + auto res = output.contiguous(); + grad_points.copy_(res); +} + void three_interpolate_forward_impl(int b, int c, int m, int n, const Tensor points, const Tensor idx, const Tensor weight, Tensor out); +void three_interpolate_backward_impl(int b, int c, int n, int m, + const Tensor grad_out, const Tensor idx, + const Tensor weight, Tensor grad_points); + REGISTER_NPU_IMPL(three_interpolate_forward_impl, three_interpolate_forward_npu); + +REGISTER_NPU_IMPL(three_interpolate_backward_impl, + three_interpolate_backward_npu); diff --git a/setup.py b/setup.py index c0ccde298e..546dc8c9e9 100644 --- a/setup.py +++ b/setup.py @@ -397,12 +397,21 @@ def get_mluops_version(file_path): elif (os.getenv('FORCE_NPU', '0') == '1'): print(f'Compiling {ext_name} only with CPU and NPU') try: + import imp + from torch_npu.utils.cpp_extension import NpuExtension + extra_compile_args['cxx'] += [ + '-D__FILENAME__=\"$$(notdir $$(abspath $$<))\"' + ] + extra_compile_args['cxx'] += [ + '-I' + imp.find_module('torch_npu')[1] + + '/include/third_party/acl/inc' + ] define_macros += [('MMCV_WITH_NPU', None)] extension = NpuExtension - if parse_version(torch.__version__) <= parse_version('2.0.0'): + if parse_version(torch.__version__) < parse_version('2.1.0'): define_macros += [('MMCV_WITH_XLA', None)] - if parse_version(torch.__version__) > parse_version('2.0.0'): + if parse_version(torch.__version__) >= parse_version('2.1.0'): define_macros += [('MMCV_WITH_KPRIVATE', None)] except Exception: raise ImportError('can not find any torch_npu') diff --git a/tests/test_ops/test_ball_query.py b/tests/test_ops/test_ball_query.py index 25899f2e1f..b64d21cbbb 100644 --- a/tests/test_ops/test_ball_query.py +++ b/tests/test_ops/test_ball_query.py @@ -63,20 +63,25 @@ def test_ball_query(device): assert torch.all(idx == expected_idx) -@pytest.mark.skipif( - not torch.cuda.is_available(), reason='requires CUDA support') -def test_stack_ball_query(): - new_xyz = torch.tensor([[-0.0740, 1.3147, -1.3625], - [-2.2769, 2.7817, -0.2334], - [-0.4003, 2.4666, -0.5116], - [-0.0740, 1.3147, -1.3625], - [-0.0740, 1.3147, -1.3625], - [-2.0289, 2.4952, -0.1708], - [-2.0668, 6.0278, -0.4875], - [0.4066, 1.4211, -0.2947], - [-2.0289, 2.4952, -0.1708], - [-2.0289, 2.4952, -0.1708]]).cuda() - new_xyz_batch_cnt = torch.tensor([5, 5], dtype=torch.int32).cuda() +@pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + 
marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) +]) +def test_stack_ball_query(device): + new_xyz = torch.tensor( + [[-0.0740, 1.3147, -1.3625], [-2.2769, 2.7817, -0.2334], + [-0.4003, 2.4666, -0.5116], [-0.0740, 1.3147, -1.3625], + [-0.0740, 1.3147, -1.3625], [-2.0289, 2.4952, -0.1708], + [-2.0668, 6.0278, -0.4875], [0.4066, 1.4211, -0.2947], + [-2.0289, 2.4952, -0.1708], [-2.0289, 2.4952, -0.1708]], + device=device) + new_xyz_batch_cnt = torch.tensor([5, 5], dtype=torch.int32, device=device) xyz = torch.tensor([[-0.0740, 1.3147, -1.3625], [0.5555, 1.0399, -1.3634], [-0.4003, 2.4666, -0.5116], [-0.5251, 2.4379, -0.8466], [-0.9691, 1.1418, -1.3733], [-0.2232, 0.9561, -1.3626], @@ -86,15 +91,15 @@ def test_stack_ball_query(): [-2.0668, 6.0278, -0.4875], [-1.9304, 3.3092, 0.6610], [0.0949, 1.4332, 0.3140], [-1.2879, 2.0008, -0.7791], [-0.7252, 0.9611, -0.6371], [0.4066, 1.4211, -0.2947], - [0.3220, 1.4447, 0.3548], [-0.9744, 2.3856, - -1.2000]]).cuda() - xyz_batch_cnt = torch.tensor([10, 10], dtype=torch.int32).cuda() + [0.3220, 1.4447, 0.3548], [-0.9744, 2.3856, -1.2000]], + device=device) + xyz_batch_cnt = torch.tensor([10, 10], dtype=torch.int32, device=device) idx = ball_query(0, 0.2, 5, xyz, new_xyz, xyz_batch_cnt, new_xyz_batch_cnt) - expected_idx = torch.tensor([[0, 0, 0, 0, 0], [6, 6, 6, 6, 6], - [2, 2, 2, 2, 2], [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], - [2, 2, 2, 2, 2], [7, 7, 7, 7, 7], - [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]).cuda() + expected_idx = torch.tensor( + [[0, 0, 0, 0, 0], [6, 6, 6, 6, 6], [2, 2, 2, 2, 2], [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [2, 2, 2, 2, 2], [7, 7, 7, 7, 7], + [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + device=device) assert torch.all(idx == expected_idx) xyz = xyz.double() diff --git a/tests/test_ops/test_chamfer_distance.py b/tests/test_ops/test_chamfer_distance.py index 522dcdddc7..de45e92ea7 100644 --- a/tests/test_ops/test_chamfer_distance.py +++ b/tests/test_ops/test_chamfer_distance.py @@ -1,57 +1,72 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import numpy as np import pytest import torch from mmcv.ops import chamfer_distance +from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE -@pytest.mark.skipif( - not torch.cuda.is_available(), reason='requires CUDA support') -def test_chamfer_distance(): - pointset1 = torch.tensor( - [[[1.3, 9.39], [2.3, 9.39], [2.3, 10.39], [1.3, 10.39]], - [[1.0, 9.39], [3.0, 9.39], [3.0, 10.39], [1.0, 10.39]], - [[1.6, 9.99], [2.3, 9.99], [2.3, 10.39], [1.6, 10.39]]], - device='cuda', - requires_grad=True) +def chamfer_distance_forward_gloden(xyz1, xyz2, dtype): + bs, ns, ss = xyz1.shape + dist1 = np.zeros((bs, ns)).astype(torch_type_trans(dtype)) + dist2 = np.zeros((bs, ns)).astype(torch_type_trans(dtype)) + idx1 = np.zeros((bs, ns)).astype('int32') + idx2 = np.zeros((bs, ns)).astype('int32') + for b1 in range(bs): + for n1 in range(ns): + x1, y1 = xyz1[b1][n1] + dist1[b1][n1] = 10000000 + for n2 in range(ns): + x2, y2 = xyz2[b1][n2] + dst = (x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2) + if dist1[b1][n1] > dst: + dist1[b1][n1] = dst + idx1[b1][n1] = n2 + for b1 in range(bs): + for n1 in range(ns): + x1, y1 = xyz2[b1][n1] + dist2[b1][n1] = 10000000 + for n2 in range(ns): + x2, y2 = xyz1[b1][n2] + dst = (x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2) + if dist2[b1][n1] > dst: + dist2[b1][n1] = dst + idx2[b1][n1] = n2 + return [dist1, dist2, idx1, idx2] - pointset2 = torch.tensor( - [[[1.0, 9.39], [3.0, 9.39], [3.0, 10.39], [1.0, 10.39]], - [[1.3, 9.39], [2.3, 9.39], [2.3, 10.39], [1.3, 10.39]], - [[1.0, 9.39], [3.0, 9.39], [3.0, 10.39], [1.0, 10.39]]], - device='cuda', - requires_grad=True) - expected_dist1 = torch.tensor( - [[0.0900, 0.4900, 0.4900, 0.0900], [0.0900, 0.4900, 0.4900, 0.0900], - [0.5200, 0.6500, 0.4900, 0.3600]], - device='cuda') - expected_dist2 = torch.tensor( - [[0.0900, 0.4900, 0.4900, 0.0900], [0.0900, 0.4900, 0.4900, 0.0900], - [0.7200, 0.8500, 0.4900, 0.3600]], - device='cuda') +def torch_type_trans(dtype): + if dtype == torch.half: + return np.float16 + elif dtype == torch.float32: + return np.float32 - expected_pointset1_grad = torch.tensor( - [[[0.6000, 0.0000], [-1.4000, 0.0000], [-1.4000, 0.0000], - [0.6000, 0.0000]], - [[-0.6000, 0.0000], [1.4000, 0.0000], [1.4000, 0.0000], - [-0.6000, 0.0000]], - [[1.2000, -0.8000], [-1.4000, -0.8000], [-1.4000, 0.0000], - [1.2000, 0.0000]]], - device='cuda') - expected_pointset2_grad = torch.tensor( - [[[-0.6000, 0.0000], [1.4000, 0.0000], [1.4000, 0.0000], - [-0.6000, 0.0000]], - [[0.6000, 0.0000], [-1.4000, 0.0000], [-1.4000, 0.0000], - [0.6000, 0.0000]], - [[0.0000, 0.0000], [0.0000, 0.0000], [2.8000, 0.8000], - [-2.4000, 0.8000]]], - device='cuda') - - dist1, dist2, idx1, idx2 = chamfer_distance(pointset1, pointset2) - dist1.backward(torch.ones_like(dist1)) - assert torch.allclose(dist1, expected_dist1, 1e-2) - assert torch.allclose(dist2, expected_dist2, 1e-2) - assert torch.allclose(pointset1.grad.data, expected_pointset1_grad, 1e-2) - assert torch.allclose(pointset2.grad.data, expected_pointset2_grad, 1e-2) +@pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) +]) +@pytest.mark.parametrize('dtype', [torch.half, torch.float32]) +@pytest.mark.parametrize('shape', [(2, 600, 2), (2, 600, 2)]) +def test_chamfer_distance_npu_dynamic_shape(dtype, device, shape): + bs = shape[0] + ns = shape[1] + xyz1 = np.random.uniform(-10.0, 10.0, 
+ (bs, ns, 2)).astype(torch_type_trans(dtype)) + xyz2 = np.random.uniform(-10.0, 10.0, + (bs, ns, 2)).astype(torch_type_trans(dtype)) + xyz1_npu = torch.tensor(xyz1, dtype=dtype).to(device) + xyz2_npu = torch.tensor(xyz2, dtype=dtype).to(device) + expected_output = chamfer_distance_forward_gloden(xyz1, xyz2, dtype) + output = chamfer_distance(xyz1_npu, xyz2_npu) + assert np.allclose(output[0].cpu().numpy(), expected_output[0], 1e-3, 1e-4) + assert np.allclose(output[1].cpu().numpy(), expected_output[1], 1e-3, 1e-4) + assert np.allclose(output[2].cpu().numpy(), expected_output[2], 1e-3, 1e-4) + assert np.allclose(output[3].cpu().numpy(), expected_output[3], 1e-3, 1e-4) diff --git a/tests/test_ops/test_rotated_feature_align.py b/tests/test_ops/test_rotated_feature_align.py index 005cbcf01c..23de07e8ef 100644 --- a/tests/test_ops/test_rotated_feature_align.py +++ b/tests/test_ops/test_rotated_feature_align.py @@ -3,7 +3,7 @@ import torch from mmcv.ops import rotated_feature_align -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE @pytest.mark.skipif( @@ -17,6 +17,10 @@ 'mlu', marks=pytest.mark.skipif( not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')), pytest.param( 'cpu', marks=pytest.mark.skipif( From bc6ce53712e60be384344ee87873f0710be0017c Mon Sep 17 00:00:00 2001 From: momo609 <963372609@qq.com> Date: Thu, 28 Dec 2023 18:42:48 +0800 Subject: [PATCH 2/5] FIX LINT --- mmcv/ops/csrc/common/pytorch_npu_util.hpp | 19 +++++++++-------- .../csrc/pytorch/npu/chamfer_distance_npu.cpp | 21 ++++++++++--------- mmcv/ops/csrc/pytorch/npu/common_util.h | 10 ++++----- mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 4 ++-- .../pytorch/npu/fused_bias_leakyrelu_npu.cpp | 2 +- .../csrc/pytorch/npu/gather_points_npu.cpp | 2 +- mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp | 2 +- .../pytorch/npu/roi_align_rotated_npu.cpp | 2 +- .../pytorch/npu/three_interpolate_npu.cpp | 2 +- 9 files changed, 32 insertions(+), 32 deletions(-) diff --git a/mmcv/ops/csrc/common/pytorch_npu_util.hpp b/mmcv/ops/csrc/common/pytorch_npu_util.hpp index 8c26a934f0..3c3712a933 100644 --- a/mmcv/ops/csrc/common/pytorch_npu_util.hpp +++ b/mmcv/ops/csrc/common/pytorch_npu_util.hpp @@ -204,9 +204,8 @@ inline c10::Scalar ConvertTensorToScalar(const at::Tensor &tensor) { inline at::Tensor CopyTensorHostToDevice(const at::Tensor &cpu_tensor) { at::Tensor cpuPinMemTensor = cpu_tensor.pin_memory(); int deviceIndex = 0; - return cpuPinMemTensor.to( - c10::Device(DEVICE_TYPE, deviceIndex), - cpuPinMemTensor.scalar_type(), true, true); + return cpuPinMemTensor.to(c10::Device(DEVICE_TYPE, deviceIndex), + cpuPinMemTensor.scalar_type(), true, true); } inline at::Tensor CopyScalarToDevice(const c10::Scalar &cpu_scalar, @@ -260,16 +259,18 @@ inline aclTensor *ConvertType(const at::Tensor &at_tensor) { if (at_tensor.unsafeGetTensorImpl()->is_wrapped_number()) { c10::Scalar expScalar = ConvertTensorToScalar(at_tensor); at::Tensor aclInput = CopyScalarToDevice(expScalar, scalar_data_type); - return aclCreateTensor( - aclInput.sizes().data(), aclInput.sizes().size(), acl_data_type, - aclInput.strides().data(), aclInput.storage_offset(), format, - storageDims.data(), storageDims.size(), const_cast(aclInput.storage().data())); + return aclCreateTensor(aclInput.sizes().data(), aclInput.sizes().size(), + acl_data_type, aclInput.strides().data(), + 
aclInput.storage_offset(), format, + storageDims.data(), storageDims.size(), + const_cast(aclInput.storage().data())); } auto acl_tensor = aclCreateTensor( at_tensor.sizes().data(), at_tensor.sizes().size(), acl_data_type, at_tensor.strides().data(), at_tensor.storage_offset(), format, - storageDims.data(), storageDims.size(), const_cast(at_tensor.storage().data())); + storageDims.data(), storageDims.size(), + const_cast(at_tensor.storage().data())); return acl_tensor; } @@ -554,7 +555,7 @@ typedef void (*ReleaseHugeMem)(void *, bool); at::TensorOptions(torch_npu::utils::get_npu_device_type()); \ auto workspace_tensor = \ at::empty({workspace_size}, options.dtype(kByte)); \ - workspace_addr = const_cast(workspace_tensor.storage().data()); \ + workspace_addr = const_cast(workspace_tensor.storage().data()); \ } \ auto acl_call = [converted_params, workspace_addr, workspace_size, \ acl_stream, executor]() -> int { \ diff --git a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp index 800e636945..56408282f7 100644 --- a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp @@ -20,20 +20,21 @@ void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1, .Run(); } -void chamfer_distance_backward_npu(Tensor xyz1, Tensor xyz2, - Tensor idx1, Tensor idx2, Tensor grad_dist1, Tensor grad_dist2, - Tensor grad_xyz1, Tensor grad_xyz2) { - EXEC_NPU_CMD(aclnnChamferDistanceBackward, xyz1, xyz2, idx1, idx2, - grad_dist1, grad_dist2, grad_xyz1, grad_xyz2); +void chamfer_distance_backward_npu(Tensor xyz1, Tensor xyz2, Tensor idx1, + Tensor idx2, Tensor grad_dist1, + Tensor grad_dist2, Tensor grad_xyz1, + Tensor grad_xyz2) { + EXEC_NPU_CMD(aclnnChamferDistanceBackward, xyz1, xyz2, idx1, idx2, grad_dist1, + grad_dist2, grad_xyz1, grad_xyz2); } void chamfer_distance_forward_impl(Tensor XYZ1, Tensor XYZ2, Tensor dist1, Tensor dist2, Tensor idx1, Tensor idx2); -REGISTER_NPU_IMPL(chamfer_distance_forward_impl, - chamfer_distance_forward_npu); +REGISTER_NPU_IMPL(chamfer_distance_forward_impl, chamfer_distance_forward_npu); -void chamfer_distance_backward_impl(Tensor xyz1, Tensor xyz2, Tensor idx1, Tensor idx2, - Tensor grad_dist1, Tensor grad_dist2, - Tensor grad_xyz1, Tensor grad_xyz2); +void chamfer_distance_backward_impl(Tensor xyz1, Tensor xyz2, Tensor idx1, + Tensor idx2, Tensor grad_dist1, + Tensor grad_dist2, Tensor grad_xyz1, + Tensor grad_xyz2); REGISTER_NPU_IMPL(chamfer_distance_backward_impl, chamfer_distance_backward_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/common_util.h b/mmcv/ops/csrc/pytorch/npu/common_util.h index 5a303e8764..db05e6a05c 100644 --- a/mmcv/ops/csrc/pytorch/npu/common_util.h +++ b/mmcv/ops/csrc/pytorch/npu/common_util.h @@ -3,12 +3,10 @@ const int SIZE = 8; c10::SmallVector array_to_vector(c10::IntArrayRef shape) { - c10::SmallVector shape_small_vec; - for (uint64_t i = 0; i < shape.size(); i++) { - shape_small_vec.emplace_back(shape[i]); - } - - return shape_small_vec; + c10::SmallVector shape_small_vec; + for (uint64_t i = 0; i < shape.size(); i++) { + shape_small_vec.emplace_back(shape[i]); + } } #endif // MMCV_OPS_CSRC_COMMON__UTIL_HPP_ diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp index b39e24f2df..5030fed0e7 100644 --- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -102,11 +102,11 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor 
target, Tensor weight, at::IntArrayRef size_array = at::IntArrayRef(sizes); c10::SmallVector offsetVec; for (uint64_t i = 0; i < offset.size(); i++) { - offsetVec.emplace_back(offset[i]); + offsetVec.emplace_back(offset[i]); } c10::SmallVector sizeVec; for (uint64_t i = 0; i < size_array.size(); i++) { - sizeVec.emplace_back(size_array[i]); + sizeVec.emplace_back(size_array[i]); } OpCommand cmd2; cmd2.Name("Slice") diff --git a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp index 071f1c4021..2e1270e450 100644 --- a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp @@ -17,7 +17,7 @@ Tensor fused_bias_leakyrelu_npu(const Tensor &input, const Tensor &bias, int input_length = input_size.size(); c10::SmallVector input_size_tmp; for (uint64_t i = 0; i < input_size.size(); i++) { - input_size_tmp.emplace_back(input_size[i]); + input_size_tmp.emplace_back(input_size[i]); } if (input_length > 1) { for (int i = 0; i < input_length; i++) { diff --git a/mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp b/mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp index 1035a36ae0..747380fb09 100644 --- a/mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp @@ -35,7 +35,7 @@ void gather_points_backward_npu(int b, int c, int n, int npoints, auto shape = idx.sizes(); c10::SmallVector pad_size; for (uint64_t i = 0; i < shape.size(); i++) { - pad_size.emplace_back(shape[i]); + pad_size.emplace_back(shape[i]); } at::Tensor trans_grad_points = grad_points.transpose(1, 2).contiguous(); at::Tensor grad_points_view = trans_grad_points.view( diff --git a/mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp b/mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp index 377563f755..0e673614fa 100644 --- a/mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp @@ -44,7 +44,7 @@ void roi_align_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax_y, auto shape = grad_input.sizes(); c10::SmallVector xdiff_shape; for (uint64_t i = 0; i < shape.size(); i++) { - xdiff_shape.emplace_back(shape[i]); + xdiff_shape.emplace_back(shape[i]); } OpCommand cmd; cmd.Name("ROIAlignGrad") diff --git a/mmcv/ops/csrc/pytorch/npu/roi_align_rotated_npu.cpp b/mmcv/ops/csrc/pytorch/npu/roi_align_rotated_npu.cpp index 813678f8ca..2a3ff09e98 100644 --- a/mmcv/ops/csrc/pytorch/npu/roi_align_rotated_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/roi_align_rotated_npu.cpp @@ -35,7 +35,7 @@ void roi_align_rotated_backward_npu(Tensor top_grad, Tensor rois, c10::SmallVector y_grad_shape; auto shape = bottom_grad.sizes(); for (uint64_t i = 0; i < shape.size(); i++) { - y_grad_shape.emplace_back(shape[i]); + y_grad_shape.emplace_back(shape[i]); } OpCommand cmd; cmd.Name("RoiAlignRotatedGrad") diff --git a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp index 07a5fed04b..f908755478 100644 --- a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp @@ -1,6 +1,6 @@ #include "pytorch_npu_helper.hpp" -#include "torch_npu/csrc/framework/utils/OpAdapter.h" #include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "torch_npu/csrc/framework/utils/OpAdapter.h" using namespace NPU_NAME_SPACE; using namespace std; From 8b137d5fc1cf7a284a6d81982bc2a6f1d45fa76e Mon Sep 17 00:00:00 2001 From: momo609 <963372609@qq.com> Date: Fri, 29 Dec 2023 14:26:57 +0800 Subject: 
[PATCH 3/5] FIX LINT --- mmcv/ops/csrc/pytorch/npu/ball_query_npu.cpp | 4 ++-- mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp | 2 +- mmcv/ops/csrc/pytorch/npu/group_points_npu.cpp | 4 ++-- mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/ball_query_npu.cpp b/mmcv/ops/csrc/pytorch/npu/ball_query_npu.cpp index 9167875376..ca743500cd 100644 --- a/mmcv/ops/csrc/pytorch/npu/ball_query_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/ball_query_npu.cpp @@ -15,7 +15,7 @@ void ball_query_forward_npu(int b, int n, int m, float min_radius, at::Tensor xyz_transpose = xyz.transpose(1, 2); // transpose idx from [B, M, nsample] to [M, B, nsample] - at::Tensor idx_transpose = NpuUtils::format_contiguous(idx.transpose(0, 1)); + at::Tensor idx_transpose = idx.transpose(0, 1).contiguous(); OpCommand cmd; cmd.Name("BallQuery") @@ -27,7 +27,7 @@ void ball_query_forward_npu(int b, int n, int m, float min_radius, .Attr("sample_num", nsample_i64) .Run(); - idx_transpose = NpuUtils::format_contiguous(idx_transpose.transpose(0, 1)); + idx_transpose = idx_transpose.transpose(0, 1).contiguous(); idx.copy_(idx_transpose); } diff --git a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp index 56408282f7..8b30fa15df 100644 --- a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp @@ -29,7 +29,7 @@ void chamfer_distance_backward_npu(Tensor xyz1, Tensor xyz2, Tensor idx1, } void chamfer_distance_forward_impl(Tensor XYZ1, Tensor XYZ2, Tensor dist1, - Tensor dist2, Tensor idx1, Tensor idx2); + Tensor dist2, Tensor idx1, Tensor idx2); REGISTER_NPU_IMPL(chamfer_distance_forward_impl, chamfer_distance_forward_npu); void chamfer_distance_backward_impl(Tensor xyz1, Tensor xyz2, Tensor idx1, diff --git a/mmcv/ops/csrc/pytorch/npu/group_points_npu.cpp b/mmcv/ops/csrc/pytorch/npu/group_points_npu.cpp index eabf9118c7..f52789bbcc 100644 --- a/mmcv/ops/csrc/pytorch/npu/group_points_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/group_points_npu.cpp @@ -20,7 +20,7 @@ void group_points_forward_npu(int b, int c, int n, int npoints, int nsample, indices = indices.view({-1}); at::Tensor trans_features = points.transpose(1, 2); - at::Tensor features = NpuUtils::format_contiguous(trans_features); + at::Tensor features = trans_features.contiguous(); features = features.view({b * n, c}); OpCommand cmd; @@ -34,7 +34,7 @@ void group_points_forward_npu(int b, int c, int n, int npoints, int nsample, at::Tensor output = out.view({b, npoints, nsample, c}).transpose(1, 3).transpose(2, 3); - at::Tensor res = NpuUtils::format_contiguous(output); + at::Tensor res = output.contiguous(); out.copy_(res); } diff --git a/mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp b/mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp index f282afeed3..6b8f08635a 100644 --- a/mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp @@ -12,7 +12,7 @@ void points_in_polygons_npu(const Tensor points, Tensor polygons, Tensor output, "The batch of polygons tensor must be less than MAX_POLYGONS_BATCH"); at::Tensor trans_polygons = polygons.transpose(0, 1); OpCommand cmd; - at::Tensor new_trans_polygons = NpuUtils::format_contiguous(trans_polygons); + at::Tensor new_trans_polygons = trans_polygons.contiguous(); cmd.Name("PointsInPolygons") .Input(points, (string) "points") .Input(new_trans_polygons, (string) "polygons") 
From e890c7861f98064b9fa0d06414c93c0e75004b6c Mon Sep 17 00:00:00 2001 From: momo609 <963372609@qq.com> Date: Thu, 4 Jan 2024 17:08:34 +0800 Subject: [PATCH 4/5] FIX --- setup.py | 5 +++-- tests/test_ops/test_chamfer_distance.py | 14 +++++++------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index 546dc8c9e9..02afded81e 100644 --- a/setup.py +++ b/setup.py @@ -397,14 +397,15 @@ def get_mluops_version(file_path): elif (os.getenv('FORCE_NPU', '0') == '1'): print(f'Compiling {ext_name} only with CPU and NPU') try: - import imp + import importlib from torch_npu.utils.cpp_extension import NpuExtension extra_compile_args['cxx'] += [ '-D__FILENAME__=\"$$(notdir $$(abspath $$<))\"' ] extra_compile_args['cxx'] += [ - '-I' + imp.find_module('torch_npu')[1] + + '-I' + importlib.util.find_spec( + 'torch_npu').submodule_search_location[0] + '/include/third_party/acl/inc' ] define_macros += [('MMCV_WITH_NPU', None)] diff --git a/tests/test_ops/test_chamfer_distance.py b/tests/test_ops/test_chamfer_distance.py index de45e92ea7..93563895bd 100644 --- a/tests/test_ops/test_chamfer_distance.py +++ b/tests/test_ops/test_chamfer_distance.py @@ -7,10 +7,10 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE -def chamfer_distance_forward_gloden(xyz1, xyz2, dtype): +def chamfer_distance_forward_groundtruth(xyz1, xyz2, dtype): bs, ns, ss = xyz1.shape - dist1 = np.zeros((bs, ns)).astype(torch_type_trans(dtype)) - dist2 = np.zeros((bs, ns)).astype(torch_type_trans(dtype)) + dist1 = np.zeros((bs, ns)).astype(torch_to_np_type(dtype)) + dist2 = np.zeros((bs, ns)).astype(torch_to_np_type(dtype)) idx1 = np.zeros((bs, ns)).astype('int32') idx2 = np.zeros((bs, ns)).astype('int32') for b1 in range(bs): @@ -36,7 +36,7 @@ def chamfer_distance_forward_gloden(xyz1, xyz2, dtype): return [dist1, dist2, idx1, idx2] -def torch_type_trans(dtype): +def torch_to_np_type(dtype): if dtype == torch.half: return np.float16 elif dtype == torch.float32: @@ -59,12 +59,12 @@ def test_chamfer_distance_npu_dynamic_shape(dtype, device, shape): bs = shape[0] ns = shape[1] xyz1 = np.random.uniform(-10.0, 10.0, - (bs, ns, 2)).astype(torch_type_trans(dtype)) + (bs, ns, 2)).astype(torch_to_np_type(dtype)) xyz2 = np.random.uniform(-10.0, 10.0, - (bs, ns, 2)).astype(torch_type_trans(dtype)) + (bs, ns, 2)).astype(torch_to_np_type(dtype)) xyz1_npu = torch.tensor(xyz1, dtype=dtype).to(device) xyz2_npu = torch.tensor(xyz2, dtype=dtype).to(device) - expected_output = chamfer_distance_forward_gloden(xyz1, xyz2, dtype) + expected_output = chamfer_distance_forward_groundtruth(xyz1, xyz2, dtype) output = chamfer_distance(xyz1_npu, xyz2_npu) assert np.allclose(output[0].cpu().numpy(), expected_output[0], 1e-3, 1e-4) assert np.allclose(output[1].cpu().numpy(), expected_output[1], 1e-3, 1e-4) From d3fa25526441a8b663324586d17157c47914c783 Mon Sep 17 00:00:00 2001 From: momo609 <963372609@qq.com> Date: Thu, 4 Jan 2024 17:10:12 +0800 Subject: [PATCH 5/5] FIX --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 02afded81e..3df847dd5b 100644 --- a/setup.py +++ b/setup.py @@ -405,7 +405,7 @@ def get_mluops_version(file_path): ] extra_compile_args['cxx'] += [ '-I' + importlib.util.find_spec( - 'torch_npu').submodule_search_location[0] + + 'torch_npu').submodule_search_locations[0] + '/include/third_party/acl/inc' ] define_macros += [('MMCV_WITH_NPU', None)]
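
A minimal usage sketch of the ops this series extends for Ascend NPU, assuming torch_npu is installed so that the 'npu' device and mmcv.utils.IS_NPU_AVAILABLE resolve. The call signatures mirror the updated tests (tests/test_ops/test_chamfer_distance.py and tests/test_ops/test_ball_query.py); the tensor shapes and random inputs are illustrative assumptions, not values taken from the patches.

# Sketch only: exercises chamfer_distance and stacked ball_query on an Ascend
# device, following the call patterns used by the updated tests. Shapes and
# random inputs below are illustrative assumptions.
import torch

from mmcv.ops import ball_query, chamfer_distance
from mmcv.utils import IS_NPU_AVAILABLE

if IS_NPU_AVAILABLE:
    device = 'npu'

    # Chamfer distance between two batches of 2-D point sets, shape (B, N, 2).
    xyz1 = torch.rand(2, 600, 2, device=device)
    xyz2 = torch.rand(2, 600, 2, device=device)
    dist1, dist2, idx1, idx2 = chamfer_distance(xyz1, xyz2)
    # dist1: (2, 600) squared distances to the nearest point in xyz2;
    # idx1: (2, 600) int32 indices of those nearest points.

    # Stacked ball query: ball_query(min_radius, max_radius, nsample,
    #     xyz, new_xyz, xyz_batch_cnt, new_xyz_batch_cnt), as in the test.
    xyz = torch.rand(20, 3, device=device)
    new_xyz = torch.rand(10, 3, device=device)
    xyz_batch_cnt = torch.tensor([10, 10], dtype=torch.int32, device=device)
    new_xyz_batch_cnt = torch.tensor([5, 5], dtype=torch.int32, device=device)
    idx = ball_query(0, 0.2, 5, xyz, new_xyz, xyz_batch_cnt, new_xyz_batch_cnt)
    # idx: (10, 5) neighbour indices within max_radius for each query point.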