diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index f030ea0136a95..5df6dc96a3b2e 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2234,7 +2234,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, ACL_MEM_MALLOC_HUGE_FIRST)); acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), - theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); + theta_scale_ne, theta_scale_nb, 1); float start = 0; float step = 1; @@ -2251,7 +2251,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float)); void * yarn_ramp_buffer = yarn_ramp_allocator.get(); acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, - theta_scale_nb, GGML_MAX_DIMS); + theta_scale_nb, 1); float zero_value = 0, one_value = 1; float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); aclScalar * low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT); diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 8bd5449f1f75f..51345742ee59e 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -67,19 +67,30 @@ GGML_ABORT("CANN error"); } +// Thread-local variable to record the current device of this thread. +thread_local int g_current_cann_device = -1; + /** - * @brief Sets the device to be used by CANN. + * @brief Set the CANN device to be used. * - * @param device The device ID to set. + * @param device The target device ID to set. */ void ggml_cann_set_device(const int32_t device) { - int current_device = -1; - aclrtGetDevice(¤t_device); + // int current_device = -1; + // Note: In some CANN versions, if no device has been set yet, + // aclrtGetDevice(¤t_device) may return 0 by default. + // aclrtGetDevice(¤t_device); - if (device == current_device) { + // If the current device is already the target one, no need to switch. + if (device == g_current_cann_device) { return; } + + // Switch to the new device. ACL_CHECK(aclrtSetDevice(device)); + + // Update the global device record. + g_current_cann_device = device; } /**