Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion ggml/src/ggml-cann/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -341,11 +341,18 @@ class cann_task_queue {

#ifdef USE_ACL_GRAPH
struct ggml_graph_node_properties {
// dst tensor
void * node_address;
ggml_op node_op;
int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS];

// src tensor
void * src_address[GGML_MAX_SRC];
int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS];
size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS];

// op
ggml_op node_op;
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
};

Expand Down
48 changes: 37 additions & 11 deletions ggml/src/ggml-cann/ggml-cann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2186,7 +2186,15 @@ static void add_lru_matched_graph_node_properties(
std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);

for (int src = 0; src < GGML_MAX_SRC; ++src) {
prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr;
if (node->src[src]) {
prop.src_address[src] = node->src[src]->data;
std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
} else {
prop.src_address[src] = nullptr;
std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
}
}

memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
Expand All @@ -2206,14 +2214,18 @@ static void add_lru_matched_graph_node_properties(
* @param graph_node_properties The stored properties of a CANN graph node.
* @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
*/
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
static bool ggml_graph_node_has_matching_properties(
ggml_tensor * node,
ggml_graph_node_properties * graph_node_properties) {
if (node->data != graph_node_properties->node_address &&
node->op != GGML_OP_VIEW) {
node->op != GGML_OP_VIEW) {
return false;
}

if (node->op != graph_node_properties->node_op) {
return false;
}

for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (node->ne[i] != graph_node_properties->ne[i]) {
return false;
Expand All @@ -2222,17 +2234,31 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
return false;
}
}

for (int i = 0; i < GGML_MAX_SRC; i++) {
if (node->src[i] &&
node->src[i]->data != graph_node_properties->src_address[i] &&
node->op != GGML_OP_VIEW
) {
return false;
if (node->src[i]) {
if (node->src[i]->data != graph_node_properties->src_address[i] &&
node->op != GGML_OP_VIEW) {
return false;
}

for (int d = 0; d < GGML_MAX_DIMS; d++) {
if (node->src[i]->ne[d] != graph_node_properties->src_ne[i][d]) {
return false;
}
if (node->src[i]->nb[d] != graph_node_properties->src_nb[i][d]) {
return false;
}
}
} else {
if (graph_node_properties->src_address[i] != nullptr) {
return false;
}
}
}
if (node->op == GGML_OP_SCALE &&
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
return false;

if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
return memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
}
return true;
}
Expand Down
Loading