Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama : add pipeline parallelism support #6017

Merged
merged 23 commits into from Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
822121f
llama : add pipeline parallelism support for batch processing with mu…
slaren Feb 13, 2024
1ac668e
server : add -ub, --ubatch-size parameter
slaren Mar 12, 2024
4ddccc2
fix server embedding test
slaren Mar 12, 2024
937966d
llama : fix Mamba inference for pipeline parallelism
compilade Mar 12, 2024
00a415d
llama : limit max batch size to n_batch
slaren Mar 12, 2024
89bfa1f
add LLAMA_SCHED_MAX_COPIES to configure the number of input copies fo…
slaren Mar 12, 2024
aa1e2f8
fix hip build
slaren Mar 12, 2024
deb3e24
Merge remote-tracking branch 'origin/master' into sl/pipeline-paralle…
slaren Mar 12, 2024
ead5c8b
fix sycl build (disable cpy_tensor_async)
slaren Mar 12, 2024
255c1ec
fix hip build
slaren Mar 12, 2024
4400153
llama : limit n_batch and n_ubatch to n_ctx during context creation
slaren Mar 13, 2024
9e7cecc
llama : fix norm backend
slaren Mar 13, 2024
b25a0f1
batched-bench : sync after decode
ggerganov Mar 13, 2024
529e749
swiftui : sync after decode
ggerganov Mar 13, 2024
54cdd47
ggml : allow ggml_get_rows to use multiple threads if they are available
slaren Mar 13, 2024
cda49d3
check n_ubatch >= n_tokens with non-causal attention
slaren Mar 13, 2024
015e1bf
llama : do not limit n_batch to n_ctx with non-causal attn
slaren Mar 13, 2024
0d934ee
server : construct batch with size of llama_n_batch
ggerganov Mar 13, 2024
3c38789
ggml_backend_cpu_graph_compute : fix return value when alloc fails
slaren Mar 13, 2024
9092883
llama : better n_batch and n_ubatch comment
slaren Mar 13, 2024
cb580a6
fix merge
slaren Mar 13, 2024
1f56481
small fix
slaren Mar 13, 2024
976176d
reduce default n_batch to 2048
slaren Mar 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 12 additions & 2 deletions common/common.cpp
Expand Up @@ -483,6 +483,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_batch = std::stoi(argv[i]);
} else if (arg == "-ub" || arg == "--ubatch-size") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.n_ubatch = std::stoi(argv[i]);
} else if (arg == "--keep") {
if (++i >= argc) {
invalid_param = true;
Expand Down Expand Up @@ -977,7 +983,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" binary file containing multiple choice tasks.\n");
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" -b N, --batch-size N logical maximum batch size (default: %d)\n", params.n_batch);
printf(" -ub N, --ubatch-size N\n");
printf(" physical maximum batch size (default: %d)\n", params.n_ubatch);
printf(" --samplers samplers that will be used for generation in the order, separated by \';\'\n");
printf(" (default: %s)\n", sampler_type_names.c_str());
printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sampler_type_chars.c_str());
Expand Down Expand Up @@ -1287,8 +1295,9 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
auto cparams = llama_context_default_params();

cparams.n_ctx = params.n_ctx;
cparams.n_batch = params.n_batch;
cparams.n_seq_max = params.n_parallel;
cparams.n_batch = params.n_batch;
cparams.n_ubatch = params.n_ubatch;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.seed = params.seed;
Expand Down Expand Up @@ -1379,6 +1388,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
llama_kv_cache_clear(lctx);
llama_synchronize(lctx);
llama_reset_timings(lctx);
}

Expand Down
3 changes: 2 additions & 1 deletion common/common.h
Expand Up @@ -51,7 +51,8 @@ struct gpt_params {
int32_t n_threads_batch_draft = -1;
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_batch = 4096; // logical batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_draft = 5; // number of tokens to draft during speculative decoding
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
Expand Down
2 changes: 1 addition & 1 deletion examples/embedding/embedding.cpp
Expand Up @@ -107,7 +107,7 @@ int main(int argc, char ** argv) {

// max batch size
const uint64_t n_batch = params.n_batch;
GGML_ASSERT(params.n_batch == params.n_ctx);
GGML_ASSERT(params.n_batch >= params.n_ctx);

// tokenize the prompts and trim
std::vector<std::vector<int32_t>> inputs;
Expand Down
53 changes: 43 additions & 10 deletions examples/llama-bench/llama-bench.cpp
Expand Up @@ -164,6 +164,7 @@ struct cmd_params {
std::vector<int> n_prompt;
std::vector<int> n_gen;
std::vector<int> n_batch;
std::vector<int> n_ubatch;
std::vector<ggml_type> type_k;
std::vector<ggml_type> type_v;
std::vector<int> n_threads;
Expand All @@ -183,7 +184,8 @@ static const cmd_params cmd_params_defaults = {
/* model */ {"models/7B/ggml-model-q4_0.gguf"},
/* n_prompt */ {512},
/* n_gen */ {128},
/* n_batch */ {512},
/* n_batch */ {4096},
/* n_ubatch */ {512},
/* type_k */ {GGML_TYPE_F16},
/* type_v */ {GGML_TYPE_F16},
/* n_threads */ {get_num_physical_cores()},
Expand All @@ -208,6 +210,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -p, --n-prompt <n> (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str());
printf(" -n, --n-gen <n> (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
printf(" -b, --batch-size <n> (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str());
printf(" -ub N, --ubatch-size <n> (default: %s)\n", join(cmd_params_defaults.n_ubatch, ",").c_str());
printf(" -ctk <t>, --cache-type-k <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str());
printf(" -ctv <t>, --cache-type-v <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str());
printf(" -t, --threads <n> (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str());
Expand All @@ -217,7 +220,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
printf(" -ts, --tensor_split <ts0/ts1/..> (default: 0)\n");
printf(" -ts, --tensor-split <ts0/ts1/..> (default: 0)\n");
printf(" -r, --repetitions <n> (default: %d)\n", cmd_params_defaults.reps);
printf(" -o, --output <csv|json|md|sql> (default: %s)\n", output_format_str(cmd_params_defaults.output_format));
printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
Expand Down Expand Up @@ -297,6 +300,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
}
auto p = split<int>(argv[i], split_delim);
params.n_batch.insert(params.n_batch.end(), p.begin(), p.end());
} else if (arg == "-ub" || arg == "--ubatch-size") {
if (++i >= argc) {
invalid_param = true;
break;
}
auto p = split<int>(argv[i], split_delim);
params.n_ubatch.insert(params.n_ubatch.end(), p.begin(), p.end());
} else if (arg == "-ctk" || arg == "--cache-type-k") {
if (++i >= argc) {
invalid_param = true;
Expand Down Expand Up @@ -455,6 +465,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.n_prompt.empty()) { params.n_prompt = cmd_params_defaults.n_prompt; }
if (params.n_gen.empty()) { params.n_gen = cmd_params_defaults.n_gen; }
if (params.n_batch.empty()) { params.n_batch = cmd_params_defaults.n_batch; }
if (params.n_ubatch.empty()) { params.n_ubatch = cmd_params_defaults.n_ubatch; }
if (params.type_k.empty()) { params.type_k = cmd_params_defaults.type_k; }
if (params.type_v.empty()) { params.type_v = cmd_params_defaults.type_v; }
if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; }
Expand All @@ -474,6 +485,7 @@ struct cmd_params_instance {
int n_prompt;
int n_gen;
int n_batch;
int n_ubatch;
ggml_type type_k;
ggml_type type_v;
int n_threads;
Expand Down Expand Up @@ -511,6 +523,7 @@ struct cmd_params_instance {

cparams.n_ctx = n_prompt + n_gen;
cparams.n_batch = n_batch;
cparams.n_ubatch = n_ubatch;
cparams.type_k = type_k;
cparams.type_v = type_v;
cparams.offload_kqv = !no_kv_offload;
Expand All @@ -532,6 +545,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
for (const auto & mmp : params.use_mmap)
for (const auto & embd : params.embeddings)
for (const auto & nb : params.n_batch)
for (const auto & nub : params.n_ubatch)
for (const auto & tk : params.type_k)
for (const auto & tv : params.type_v)
for (const auto & nkvo : params.no_kv_offload)
Expand All @@ -545,6 +559,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .n_prompt = */ n_prompt,
/* .n_gen = */ 0,
/* .n_batch = */ nb,
/* .n_ubatch = */ nub,
/* .type_k = */ tk,
/* .type_v = */ tv,
/* .n_threads = */ nt,
Expand All @@ -568,6 +583,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .n_prompt = */ 0,
/* .n_gen = */ n_gen,
/* .n_batch = */ nb,
/* .n_ubatch = */ nub,
/* .type_k = */ tk,
/* .type_v = */ tv,
/* .n_threads = */ nt,
Expand Down Expand Up @@ -604,6 +620,7 @@ struct test {
uint64_t model_size;
uint64_t model_n_params;
int n_batch;
int n_ubatch;
int n_threads;
ggml_type type_k;
ggml_type type_v;
Expand All @@ -627,6 +644,7 @@ struct test {
model_size = llama_model_size(lmodel);
model_n_params = llama_model_n_params(lmodel);
n_batch = inst.n_batch;
n_ubatch = inst.n_ubatch;
n_threads = inst.n_threads;
type_k = inst.type_k;
type_v = inst.type_v;
Expand Down Expand Up @@ -705,7 +723,8 @@ struct test {
"cuda", "opencl", "vulkan", "kompute", "metal", "sycl", "gpu_blas", "blas",
"cpu_info", "gpu_info",
"model_filename", "model_type", "model_size", "model_n_params",
"n_batch", "n_threads", "type_k", "type_v",
"n_batch", "n_ubatch",
"n_threads", "type_k", "type_v",
"n_gpu_layers", "split_mode",
"main_gpu", "no_kv_offload",
"tensor_split", "use_mmap", "embeddings",
Expand All @@ -719,7 +738,8 @@ struct test {
enum field_type {STRING, BOOL, INT, FLOAT};

static field_type get_field_type(const std::string & field) {
if (field == "build_number" || field == "n_batch" || field == "n_threads" ||
if (field == "build_number" || field == "n_batch" || field == "n_ubatch" ||
field == "n_threads" ||
field == "model_size" || field == "model_n_params" ||
field == "n_gpu_layers" || field == "main_gpu" ||
field == "n_prompt" || field == "n_gen" ||
Expand Down Expand Up @@ -759,7 +779,8 @@ struct test {
std::to_string(metal), std::to_string(sycl), std::to_string(gpu_blas), std::to_string(blas),
cpu_info, gpu_info,
model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
std::to_string(n_batch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
std::to_string(n_batch), std::to_string(n_ubatch),
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
std::to_string(n_gpu_layers), split_mode_str(split_mode),
std::to_string(main_gpu), std::to_string(no_kv_offload),
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
Expand Down Expand Up @@ -957,6 +978,9 @@ struct markdown_printer : public printer {
if (params.n_batch.size() > 1 || params.n_batch != cmd_params_defaults.n_batch) {
fields.emplace_back("n_batch");
}
if (params.n_ubatch.size() > 1 || params.n_ubatch != cmd_params_defaults.n_ubatch) {
fields.emplace_back("n_ubatch");
}
if (params.type_k.size() > 1 || params.type_k != cmd_params_defaults.type_k) {
fields.emplace_back("type_k");
}
Expand Down Expand Up @@ -1096,25 +1120,32 @@ struct sql_printer : public printer {
};

static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads);

//std::vector<llama_token> tokens(n_prompt, llama_token_bos(llama_get_model(ctx)));
//llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt, n_past, 0));
//GGML_UNUSED(n_batch);

std::vector<llama_token> tokens(n_batch, llama_token_bos(llama_get_model(ctx)));
int n_processed = 0;

llama_set_n_threads(ctx, n_threads, n_threads);

while (n_processed < n_prompt) {
int n_tokens = std::min(n_prompt - n_processed, n_batch);
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
n_processed += n_tokens;
}

llama_synchronize(ctx);
}

static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
llama_token token = llama_token_bos(llama_get_model(ctx));

llama_set_n_threads(ctx, n_threads, n_threads);

llama_token token = llama_token_bos(llama_get_model(ctx));

for (int i = 0; i < n_gen; i++) {
llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0));
llama_synchronize(ctx);
}
}

Expand Down Expand Up @@ -1203,7 +1234,8 @@ int main(int argc, char ** argv) {

// warmup run
if (t.n_prompt > 0) {
test_prompt(ctx, std::min(2, t.n_batch), 0, t.n_batch, t.n_threads);
//test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
}
if (t.n_gen > 0) {
test_gen(ctx, 1, 0, t.n_threads);
Expand All @@ -1219,6 +1251,7 @@ int main(int argc, char ** argv) {
if (t.n_gen > 0) {
test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
}

uint64_t t_ns = get_time_ns() - t_start;
t.samples_ns.push_back(t_ns);
}
Expand Down
3 changes: 2 additions & 1 deletion examples/perplexity/perplexity.cpp
Expand Up @@ -589,9 +589,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
}
}

const auto t_end = std::chrono::high_resolution_clock::now();

if (i == 0) {
llama_synchronize(ctx);
const auto t_end = std::chrono::high_resolution_clock::now();
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total*n_chunk/n_seq);
Expand Down
9 changes: 8 additions & 1 deletion examples/server/server.cpp
Expand Up @@ -2157,7 +2157,8 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
printf(" --pooling {none,mean,cls} pooling type for embeddings, use model default if unspecified\n");
printf(" -dt N, --defrag-thold N\n");
printf(" KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" -b N, --batch-size N logical maximum batch size (default: %d)\n", params.n_batch);
printf(" -ub N, --ubatch-size N physical maximum batch size (default: %d)\n", params.n_ubatch);
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
if (llama_supports_mlock()) {
Expand Down Expand Up @@ -2424,6 +2425,12 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
break;
}
params.n_batch = std::stoi(argv[i]);
} else if (arg == "-ub" || arg == "--ubatch-size") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.n_ubatch = std::stoi(argv[i]);
} else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
if (++i >= argc) {
invalid_param = true;
Expand Down