Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ggml/src/ggml-cann/aclnn_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2234,7 +2234,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
ACL_MEM_MALLOC_HUGE_FIRST));

acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
theta_scale_ne, theta_scale_nb, 1);

float start = 0;
float step = 1;
Expand All @@ -2251,7 +2251,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
void * yarn_ramp_buffer = yarn_ramp_allocator.get();
acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne,
theta_scale_nb, GGML_MAX_DIMS);
theta_scale_nb, 1);
float zero_value = 0, one_value = 1;
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
aclScalar * low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);
Expand Down
21 changes: 16 additions & 5 deletions ggml/src/ggml-cann/ggml-cann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,30 @@
GGML_ABORT("CANN error");
}

// Thread-local variable to record the current device of this thread.
thread_local int g_current_cann_device = -1;

/**
* @brief Sets the device to be used by CANN.
* @brief Set the CANN device to be used.
*
* @param device The device ID to set.
* @param device The target device ID to set.
*/
void ggml_cann_set_device(const int32_t device) {
int current_device = -1;
aclrtGetDevice(&current_device);
// int current_device = -1;
// Note: In some CANN versions, if no device has been set yet,
// aclrtGetDevice(&current_device) may return 0 by default.
// aclrtGetDevice(&current_device);

if (device == current_device) {
// If the current device is already the target one, no need to switch.
if (device == g_current_cann_device) {
return;
}

// Switch to the new device.
ACL_CHECK(aclrtSetDevice(device));

// Update the global device record.
g_current_cann_device = device;
}

/**
Expand Down
Loading