llama : ggml-backend integration #4766

Merged (39 commits, Jan 12, 2024)
Commits
33f0761  llama : ggml-backend integration (slaren, Dec 28, 2023)
6483328  ggml-backend : add names to buffers (slaren, Jan 5, 2024)
a1ab35c  fix unmap after loading (slaren, Jan 5, 2024)
1fa7ee2  batched-bench : add tensor_split param (ggerganov, Jan 5, 2024)
863ef45  llama : check for null tensor_split (slaren, Jan 5, 2024)
d107459  ggml-backend : increase GGML_MAX_BACKENDS (slaren, Jan 5, 2024)
ece0b0d  improve graph splitting, partial fix for --no-kv-offload (slaren, Jan 5, 2024)
2f2c367  cuda : add ggml-backend split buffer support (slaren, Jan 6, 2024)
72b74f3  cuda : do not create buffer types for devices that don't exist (fixes… (slaren, Jan 6, 2024)
f77c72f  ggml : fix null backend dereference (#4807) (ggerganov, Jan 7, 2024)
7c16cf1  test-backend-ops : check buffer allocation failures (slaren, Jan 7, 2024)
87c8207  Merge remote-tracking branch 'origin/master' into sl/backend-sched (slaren, Jan 7, 2024)
5e879c9  llama : add cparam (split_mode) and command line argument (--split-mo… (slaren, Jan 7, 2024)
ac145fd  ggml : fix mul_mat_id work size (slaren, Jan 8, 2024)
444b975  llama : rewrite session kv load/set without graphs (slaren, Jan 8, 2024)
d41cef9  minor (slaren, Jan 8, 2024)
5a62db3  llama : only initialize used backends, free backends on context free (slaren, Jan 8, 2024)
4813e17  llama : abort ctx if cuda backend init fails (slaren, Jan 8, 2024)
11583c1  llama : rewrite lora with ggml-backend and compute on CPU (slaren, Jan 8, 2024)
4ed5f62  llama : only map to a backend buffer the region of the file mapping c… (slaren, Jan 8, 2024)
fa76201  opencl : add ggml-backend buffer type (slaren, Jan 9, 2024)
2e7814a  Merge remote-tracking branch 'origin/master' into sl/backend-sched (slaren, Jan 9, 2024)
5d2dffc  cuda : only use batched_cublas with batched mat muls (fixes fp16 tg p… (slaren, Jan 10, 2024)
3cb1c1f  Merge remote-tracking branch 'origin/master' into sl/backend-sched (slaren, Jan 10, 2024)
07a1b05  llama : on Metal, by default offload the full model (ggerganov, Jan 10, 2024)
3cd0cbb  metal : page align the data ptr (#4854) (ggerganov, Jan 10, 2024)
74066f8  Apply suggestions from code review (slaren, Jan 10, 2024)
c522c11  cuda : fix split buffer free (slaren, Jan 10, 2024)
9d4ba6e  address review comments (slaren, Jan 11, 2024)
d83c084  llama-bench : add split-mode parameter (slaren, Jan 11, 2024)
6dcc42b  fix whitespace (slaren, Jan 11, 2024)
42aa835  opencl : fix double initialization (slaren, Jan 11, 2024)
c3681af  Merge remote-tracking branch 'origin/master' into sl/backend-sched (slaren, Jan 11, 2024)
c486719  server : add --split-mode parameter (slaren, Jan 11, 2024)
23c14ef  use async copy and compute to improve multi-gpu performance (slaren, Jan 11, 2024)
e73009e  use async memcpys to copy the graph outputs to the CPU (slaren, Jan 12, 2024)
1e7694e  fix opencl (slaren, Jan 12, 2024)
458674c  Merge remote-tracking branch 'origin/master' into sl/backend-sched (slaren, Jan 12, 2024)
53ae0dd  use a host buffer for the cpu compute buffer for faster copies to the… (slaren, Jan 12, 2024)
65 changes: 39 additions & 26 deletions common/common.cpp
@@ -543,9 +543,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
params.n_gpu_layers = std::stoi(argv[i]);
#else
#ifndef LLAMA_SUPPORTS_GPU_OFFLOAD
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
#endif
@@ -554,9 +553,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
params.n_gpu_layers_draft = std::stoi(argv[i]);
#else
#ifndef LLAMA_SUPPORTS_GPU_OFFLOAD
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
#endif
@@ -565,40 +563,53 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
#ifdef GGML_USE_CUBLAS
params.main_gpu = std::stoi(argv[i]);
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
#endif
#ifndef GGML_USE_CUBLAS
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Setting the main GPU has no effect.\n");
#endif // GGML_USE_CUBLAS
} else if (arg == "--split-mode" || arg == "-sm") {
if (++i >= argc) {
invalid_param = true;
break;
}
std::string arg_next = argv[i];
if (arg_next == "none") {
params.split_mode = LLAMA_SPLIT_NONE;
} else if (arg_next == "layer") {
params.split_mode = LLAMA_SPLIT_LAYER;
} else if (arg_next == "row") {
params.split_mode = LLAMA_SPLIT_ROW;
} else {
invalid_param = true;
break;
}
#ifndef GGML_USE_CUBLAS
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Setting the split mode has no effect.\n");
#endif // GGML_USE_CUBLAS
} else if (arg == "--tensor-split" || arg == "-ts") {
if (++i >= argc) {
invalid_param = true;
break;
}
#ifdef GGML_USE_CUBLAS
std::string arg_next = argv[i];

// split string by , and /
const std::regex regex{R"([,/]+)"};
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);

if (split_arg.size() >= LLAMA_MAX_DEVICES) {
invalid_param = true;
break;
}
for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) {
if (i < split_arg.size()) {
params.tensor_split[i] = std::stof(split_arg[i]);
} else {
params.tensor_split[i] = 0.0f;
}
}
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
#endif // GGML_USE_CUBLAS
} else if (arg == "--no-mul-mat-q" || arg == "-nommq") {
#ifdef GGML_USE_CUBLAS
params.mul_mat_q = false;
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Disabling mul_mat_q kernels has no effect.\n");
#ifndef GGML_USE_CUBLAS
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Setting a tensor split has no effect.\n");
#endif // GGML_USE_CUBLAS
} else if (arg == "--no-mmap") {
params.use_mmap = false;
@@ -909,14 +920,15 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" number of layers to store in VRAM\n");
printf(" -ngld N, --n-gpu-layers-draft N\n");
printf(" number of layers to store in VRAM for the draft model\n");
printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
printf(" how to split the model across multiple GPUs, one of:\n");
printf(" - none: use one GPU only\n");
printf(" - layer (default): split layers and KV across GPUs\n");
printf(" - row: split rows across GPUs\n");
printf(" -ts SPLIT --tensor-split SPLIT\n");
printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
#ifdef GGML_USE_CUBLAS
printf(" -nommq, --no-mul-mat-q\n");
printf(" use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n");
printf(" Not recommended since this is both slower and uses more VRAM.\n");
#endif // GGML_USE_CUBLAS
printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu);
#endif
printf(" -gan N, --grp-attn-n N\n");
printf(" group-attention factor (default: %d)\n", params.grp_attn_n);
@@ -1033,6 +1045,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
mparams.n_gpu_layers = params.n_gpu_layers;
}
mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode;
mparams.tensor_split = params.tensor_split;
mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock;
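The model parameters wired through in this file (split_mode, tensor_split, main_gpu) map directly onto the new `-sm` / `-ts` / `-mg` command line options shown in the help text above. The snippet below is a minimal sketch, not part of the PR, of how an application might set the same fields through the C API, using only symbols that appear in this diff plus a hypothetical model path:

```cpp
// Minimal sketch (illustrative, not from this PR): configuring the new
// split_mode and tensor_split model parameters via the llama.cpp C API.
#include "llama.h"

#include <vector>

int main() {
    llama_model_params mparams = llama_model_default_params();

    mparams.n_gpu_layers = 99;                // offload as many layers as possible
    mparams.split_mode   = LLAMA_SPLIT_LAYER; // default: split whole layers across GPUs
    mparams.main_gpu     = 0;                 // used with LLAMA_SPLIT_NONE / LLAMA_SPLIT_ROW

    // per-device proportions; unused entries stay at 0.0f
    std::vector<float> tensor_split(LLAMA_MAX_DEVICES, 0.0f);
    tensor_split[0] = 3.0f;                   // e.g. a 3:1 split between two GPUs
    if (LLAMA_MAX_DEVICES > 1) {
        tensor_split[1] = 1.0f;
    }
    mparams.tensor_split = tensor_split.data();

    llama_model * model = llama_load_model_from_file("model.gguf", mparams); // hypothetical path
    if (model == NULL) {
        return 1;
    }

    llama_free_model(model);
    return 0;
}
```

On the command line, the rough equivalent would be something like `-ngl 99 -sm layer -ts 3,1 -mg 0`, per the updated help text above.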
1 change: 1 addition & 0 deletions common/common.h
@@ -59,6 +59,7 @@ struct gpt_params {
float p_split = 0.1f; // speculative decoding split probability
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
llama_split_mode split_mode = LLAMA_SPLIT_LAYER; // how to split the model across GPUs
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
int32_t n_beams = 0; // if non-zero then use beam search of given width.
3 changes: 3 additions & 0 deletions examples/batched-bench/batched-bench.cpp
@@ -88,7 +88,10 @@ int main(int argc, char ** argv) {

llama_model_params model_params = llama_model_default_params();

const std::vector<float> t_split (LLAMA_MAX_DEVICES, 0.0f);

model_params.n_gpu_layers = n_gpu_layers;
model_params.tensor_split = t_split.data();

llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);

12 changes: 12 additions & 0 deletions ggml-alloc.c
@@ -229,6 +229,7 @@ void ggml_tallocr_reset(ggml_tallocr_t alloc) {
alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
} else {
alloc->free_blocks[0].size = ggml_backend_buffer_get_size(alloc->buffer) - align_offset;
ggml_backend_buffer_reset(alloc->buffer);
}
}

@@ -779,10 +780,21 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte

if (nbytes == 0) {
// all the tensors in the context are already allocated
#ifndef NDEBUG
fprintf(stderr, "%s: all tensors in the context are already allocated\n", __func__);
#endif
return NULL;
}

ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, nbytes);
if (buffer == NULL) {
// failed to allocate buffer
#ifndef NDEBUG
fprintf(stderr, "%s: failed to allocate buffer\n", __func__);
#endif
return NULL;
}

ggml_tallocr_t tallocr = ggml_tallocr_new_from_buffer(buffer);

for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
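With the two early returns added above, `ggml_backend_alloc_ctx_tensors_from_buft` now signals both "nothing left to allocate" and a failed buffer allocation by returning NULL, so callers are expected to check the result. Below is a minimal sketch of such a caller, assuming the CPU buffer type from ggml-backend (`ggml_backend_cpu_buffer_type()`, which is not part of this diff):

```cpp
// Sketch (not from this PR): allocate all tensors of a ggml context into a
// backend buffer and handle the NULL return added above.
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

#include <stdio.h>

int main() {
    struct ggml_init_params ip = {
        /* .mem_size   = */ ggml_tensor_overhead() * 8,
        /* .mem_buffer = */ NULL,
        /* .no_alloc   = */ true, // tensor data comes from the backend buffer
    };
    struct ggml_context * ctx = ggml_init(ip);

    ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
    ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);

    ggml_backend_buffer_t buf =
        ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cpu_buffer_type());
    if (buf == NULL) {
        // nothing to allocate, or the buffer allocation failed
        fprintf(stderr, "tensor allocation failed\n");
        ggml_free(ctx);
        return 1;
    }

    // ... use the tensors ...

    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
    return 0;
}
```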
29 changes: 16 additions & 13 deletions ggml-backend-impl.h
@@ -16,9 +16,10 @@ extern "C" {
typedef void * ggml_backend_buffer_type_context_t;

struct ggml_backend_buffer_type_i {
const char * (*get_name) (ggml_backend_buffer_type_t buft);
ggml_backend_buffer_t (*alloc_buffer) (ggml_backend_buffer_type_t buft, size_t size);
size_t (*get_alignment) (ggml_backend_buffer_type_t buft); // tensor alignment
size_t (*get_alloc_size) (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding
size_t (*get_alloc_size) (ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding
bool (*supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend
// check if tensor data is in host memory
// should be equivalent to supports_backend(buft, ggml_backend_cpu_init())
@@ -34,23 +35,25 @@ extern "C" {
typedef void * ggml_backend_buffer_context_t;

struct ggml_backend_buffer_i {
void (*free_buffer) (ggml_backend_buffer_t buffer);
//void (*reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
void * (*get_base) (ggml_backend_buffer_t buffer);
void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
const char * (*get_name) (ggml_backend_buffer_t buffer);
void (*free_buffer) (ggml_backend_buffer_t buffer);
void * (*get_base) (ggml_backend_buffer_t buffer);
void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
// (optional) copy tensor between different buffer types, allows for single-copy transfers
void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to) (ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*clear) (ggml_backend_buffer_t buffer, uint8_t value);
void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst);
void (*clear) (ggml_backend_buffer_t buffer, uint8_t value);
void (*reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
};

struct ggml_backend_buffer {
struct ggml_backend_buffer_i iface;
ggml_backend_buffer_type_t buft;
ggml_backend_buffer_context_t context;
size_t size;
enum ggml_backend_buffer_usage usage;
};

ggml_backend_buffer_t ggml_backend_buffer_init(
@@ -79,13 +82,13 @@ extern "C" {
void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);

// (optional) asynchronous tensor copy
void (*cpy_tensor_from_async)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to_async) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_from_async)(ggml_backend_t backend, const struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to_async) (ggml_backend_t backend, const struct ggml_tensor * src, struct ggml_tensor * dst);

void (*synchronize)(ggml_backend_t backend);

// compute graph with a plan
ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);

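The reordered `ggml_backend_buffer_i` above puts `get_name` first and promotes the previously commented-out `reset` hook into a real (optional) entry. The sketch below shows roughly how a backend might fill this vtable; the `example_*` callbacks and the flat malloc-style buffer context are hypothetical and not taken from this PR.

```cpp
// Hypothetical sketch of filling the revised ggml_backend_buffer_i vtable.
// The example_* callbacks are placeholders; optional hooks are left NULL here.
#include "ggml-backend-impl.h"

#include <stdlib.h>
#include <string.h>

static const char * example_buffer_get_name(ggml_backend_buffer_t buffer) {
    (void) buffer;
    return "EXAMPLE";
}

static void example_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    free(buffer->context); // context is assumed to be a plain malloc'd block
}

static void * example_buffer_get_base(ggml_backend_buffer_t buffer) {
    return buffer->context;
}

static void example_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
                                      const void * data, size_t offset, size_t size) {
    (void) buffer;
    memcpy((char *) tensor->data + offset, data, size);
}

static void example_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor,
                                      void * data, size_t offset, size_t size) {
    (void) buffer;
    memcpy(data, (const char *) tensor->data + offset, size);
}

static void example_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    memset(buffer->context, value, ggml_backend_buffer_get_size(buffer));
}

static struct ggml_backend_buffer_i example_buffer_interface = {
    /* .get_name        = */ example_buffer_get_name,
    /* .free_buffer     = */ example_buffer_free_buffer,
    /* .get_base        = */ example_buffer_get_base,
    /* .init_tensor     = */ NULL, // no per-tensor extras in this sketch
    /* .set_tensor      = */ example_buffer_set_tensor,
    /* .get_tensor      = */ example_buffer_get_tensor,
    /* .cpy_tensor_from = */ NULL, // optional single-copy transfers not implemented
    /* .cpy_tensor_to   = */ NULL,
    /* .clear           = */ example_buffer_clear,
    /* .reset           = */ NULL, // new optional hook: reset state such as tensor extras
};
```

In this sketch the optional hooks (`init_tensor`, the `cpy_tensor_from`/`cpy_tensor_to` pair, and `reset`) stay NULL because the buffer keeps no per-tensor state and has no special transfer path.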