Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

callback to abort ggml_graph_compute() #328

Merged
merged 13 commits into from
Jul 11, 2023
11 changes: 10 additions & 1 deletion include/ggml/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,13 @@
#define GGML_MAX_NAME 48
#define GGML_DEFAULT_N_THREADS 4


#define GGML_EXIT_SUCCESS 0
#define GGML_EXIT_ABORTED 1

ggerganov marked this conversation as resolved.
Show resolved Hide resolved
#define GGML_UNUSED(x) (void)(x)


#define GGML_ASSERT(x) \
do { \
if (!(x)) { \
Expand Down Expand Up @@ -442,6 +447,10 @@ extern "C" {

// the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
int n_tasks[GGML_MAX_NODES];

// abort ggml_graph_compute when true
bool (*abort_callback)(void * data);
void * abort_callback_data;
};

// computation graph
Expand Down Expand Up @@ -1303,7 +1312,7 @@ extern "C" {
// ggml_graph_plan() has to be called before ggml_graph_compute()
// when plan.work_size > 0, caller must allocate memory for plan.work_data
GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
GGML_API void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
GGML_API int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);

// same as ggml_graph_compute() but the work data is allocated as a part of the context
Expand Down
22 changes: 18 additions & 4 deletions src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <float.h>
#include <limits.h>
#include <stdarg.h>
#include <signal.h>

#ifdef GGML_USE_METAL
#include <unistd.h>
Expand Down Expand Up @@ -15946,6 +15947,9 @@ struct ggml_compute_state_shared {
// synchronization primitives
atomic_int n_active; // num active threads
atomic_int node_n; // active graph node

bool (*abort_callback)(void * data); // abort ggml_graph_compute when true
void * abort_callback_data;
};

struct ggml_compute_state {
Expand Down Expand Up @@ -15977,6 +15981,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
int node_n = -1;

while (true) {
if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
return GGML_EXIT_ABORTED;
}
if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
// all other threads are finished and spinning
// do finalize and init here so we don't have synchronize again
Expand Down Expand Up @@ -16030,6 +16037,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
} else {
break;
}

if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
break;
}
}

atomic_store(&state->shared->n_active, n_threads);
Expand Down Expand Up @@ -16063,9 +16074,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
}
}

return 0;
return GGML_EXIT_SUCCESS;
}

static bool always_false(void * data) { return false; }
struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
if (n_threads <= 0) {
n_threads = GGML_DEFAULT_N_THREADS;
Expand Down Expand Up @@ -16403,7 +16415,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
return cplan;
}

void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
{
GGML_ASSERT(cplan);
GGML_ASSERT(cplan->n_threads > 0);
Expand Down Expand Up @@ -16452,12 +16464,12 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
const int64_t perf_start_time_us = ggml_perf_time_us();

// this is a work thread too
ggml_graph_compute_thread(&workers[0]);
int compute_status = ggml_graph_compute_thread(&workers[0]);

// don't leave affinity set on the main thread
clear_numa_thread_affinity();

// join thread pool
// join or kill thread pool
if (n_threads > 1) {
for (int j = 1; j < n_threads; j++) {
const int rc = ggml_thread_join(workers[j].thrd, NULL);
Expand All @@ -16481,6 +16493,8 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
(double) perf_time_us_cur / 1000.0,
(double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
}

return compute_status;
}

void ggml_graph_reset(struct ggml_cgraph * cgraph) {
Expand Down