llama : custom attention mask + parallel decoding + no context swaps #3228

Merged: 57 commits from custom-attention-mask into master, Sep 28, 2023
Commits:
c5df72e
tests : verify that RoPE is "additive"
ggerganov Sep 17, 2023
3b4bab6
llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask)
ggerganov Sep 17, 2023
1fb033f
ggml : ggml_rope now takes a vector with positions instead of n_past
ggerganov Sep 17, 2023
fad5693
metal : add rope_f16 kernel + optimize cpy kernels
ggerganov Sep 17, 2023
d29e769
llama : unified KV cache + batch inference API
ggerganov Sep 18, 2023
58bb511
Merge branch 'master' into custom-attention-mask
ggerganov Sep 18, 2023
9f42e75
llama : add new llama_decode() API that works with llama_batch
ggerganov Sep 18, 2023
6952a46
llama : add cell_max heuristic for more efficient kv_cache
ggerganov Sep 18, 2023
4d76d76
llama : extend llama_kv_cache API
ggerganov Sep 18, 2023
f015b26
llama : more robust cell_max heuristic + wip shift
ggerganov Sep 18, 2023
86c90e3
metal : disable concurrency optimization
ggerganov Sep 18, 2023
0cbf3bf
llama : add llama_kv_cache_shift_seq + no more context swaps
ggerganov Sep 18, 2023
7c1bdd0
llama : apply K-cache roping for Falcon and Baichuan
ggerganov Sep 18, 2023
1f17ea6
speculative : fix KV cache management
ggerganov Sep 18, 2023
0161372
parallel : example for serving multiple users in parallel
ggerganov Sep 18, 2023
466b513
parallel : disable hot-plug to avoid cache fragmentation
ggerganov Sep 18, 2023
897cacc
fixes : speculative KV cache + llama worst-case graph
ggerganov Sep 18, 2023
fa0e677
llama : extend batch API to select which logits to output
ggerganov Sep 18, 2023
daf4c6d
llama : fix worst case graph build
ggerganov Sep 19, 2023
7e2b997
ggml-cuda : update rope implementation for parallel decoding (#3254)
slaren Sep 19, 2023
25bd254
make : add parallel to build + fix static functions in llama.cpp
ggerganov Sep 19, 2023
467e307
simple : fix token counting
ggerganov Sep 19, 2023
36714e1
parallel : various improvements
ggerganov Sep 19, 2023
ddad227
llama : fix cell_max logic + rename functions
ggerganov Sep 19, 2023
806d397
parallel : try smaller batches when the KV cache is fragmented
ggerganov Sep 19, 2023
16090a5
parallel : fix sequence termination criteria
ggerganov Sep 19, 2023
d37081a
llama : silence errors KV cache errors
ggerganov Sep 19, 2023
82e20e9
parallel : remove new line from prompt
ggerganov Sep 19, 2023
4b5f3cd
parallel : process system prompt once + configurable paramters + llam…
ggerganov Sep 19, 2023
8a9aca3
parallel : remove question with short answers
ggerganov Sep 19, 2023
eed3fd4
parallel : count cache misses
ggerganov Sep 19, 2023
6028879
parallel : print misses on each request
ggerganov Sep 19, 2023
7b7472e
parallel : minor
ggerganov Sep 19, 2023
e1067ef
llama : fix n_kv to never become 0
ggerganov Sep 20, 2023
a1327c7
parallel : rename hot-plug to continuous-batching
ggerganov Sep 20, 2023
addae65
llama : improve llama_batch API + simplify parallel example
ggerganov Sep 20, 2023
b377bf2
simple : add parallel decoding support
ggerganov Sep 20, 2023
db0fc2d
simple : improve comments + free batch
ggerganov Sep 20, 2023
e04dc51
ggml-cuda : add rope f16, restore performance with parallel decoding …
slaren Sep 20, 2023
5420696
llama : disable MPI for now
ggerganov Sep 20, 2023
2f3a46f
train : make KQ_pos memory buffer permanent via dummy scale op
ggerganov Sep 20, 2023
1be2b8c
ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275)
slaren Sep 20, 2023
ee1d670
parallel : fix bug (extra BOS) + smaller token_prev array
ggerganov Sep 20, 2023
ded9b43
parallel : fix cases where the input prompts can overflow the batch
ggerganov Sep 20, 2023
b2debf6
parallel : add disabled experimental batch chunking in powers of two
ggerganov Sep 20, 2023
5a3369d
llama : llama.h formatting + comments
ggerganov Sep 21, 2023
8845160
simple : add README.md
ggerganov Sep 21, 2023
c1596f6
llama : fix kv cache heuristic when context is less than 32
ggerganov Sep 27, 2023
2585690
Merge branch 'master' into custom-attention-mask
ggerganov Sep 28, 2023
4ad0676
parallel : fix crash when `-n -1`
ggerganov Sep 28, 2023
e946379
llama : simplify returns if/else branches
ggerganov Sep 28, 2023
4c72ab1
metal : use mm kernels for batch size > 2
ggerganov Sep 28, 2023
d008733
examples : utilize new llama_get_logits_ith()
ggerganov Sep 28, 2023
a207561
examples : add example for batched decoding
ggerganov Sep 28, 2023
2b8830a
examples : do not eval prompt 2 times (close #3348)
ggerganov Sep 28, 2023
ce2d995
server : clear the KV cache beyond n_past before llama_decode
ggerganov Sep 28, 2023
c5650ed
server : avoid context swaps by shifting the KV cache
ggerganov Sep 28, 2023
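
The commits above converge on one headline API change: `llama_eval()` with an `n_past` counter is superseded by `llama_decode()`, which consumes a `llama_batch` (tokens together with their positions and sequence ids) against a unified KV cache, and callers read back logits with `llama_get_logits_ith()`. As orientation before the diffs below, here is a minimal single-sequence sketch of that call pattern; it uses only calls visible in this PR's diff (note that `llama_decode()` still takes `n_threads` at this point) and is illustrative, not code from the PR.

```cpp
// Minimal single-sequence sketch of the new decode path (illustrative only).
// Uses only calls that appear in this PR's diff.
#include "llama.h"

#include <vector>

static void decode_once(llama_context * ctx, int n_threads) {
    std::vector<llama_token> tokens = { llama_token_bos(ctx) };

    // pack the tokens into a batch: positions start at 0, sequence id 0
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size(), 0, 0);

    // evaluate the batch against the unified KV cache
    llama_decode(ctx, batch, n_threads);

    // read back the logits of the last (here, only) token in the batch
    const float * logits = llama_get_logits_ith(ctx, (int32_t) tokens.size() - 1);
    (void) logits; // sample from these as usual

    // drop the cells this run occupied, as the warm-up code below does
    llama_kv_cache_tokens_rm(ctx, -1, -1);
}
```

The `common/common.cpp` warm-up hunk further down shows the same pattern in context.
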
2 changes: 2 additions & 0 deletions .gitignore
@@ -51,7 +51,9 @@ models-mnt
/save-load-state
/server
/simple
/batched
/speculative
/parallel
/train-text-from-scratch
/vdot
build-info.h
8 changes: 7 additions & 1 deletion Makefile
@@ -1,5 +1,5 @@
# Define the default target now so that it is always the first target
BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple save-load-state server embd-input-test gguf llama-bench baby-llama beam-search speculative tests/test-c.o
BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple batched save-load-state server embd-input-test gguf llama-bench baby-llama beam-search speculative parallel tests/test-c.o

# Binaries only useful for tests
TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama
@@ -519,6 +519,9 @@ main: examples/main/main.cpp build-info.h ggml.
simple: examples/simple/simple.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

batched: examples/batched/batched.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

@@ -565,6 +568,9 @@ beam-search: examples/beam-search/beam-search.cpp build-info.h ggml.o llama.o co
speculative: examples/speculative/speculative.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

parallel: examples/parallel/parallel.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

ifdef LLAMA_METAL
metal: examples/metal/metal.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
43 changes: 29 additions & 14 deletions common/common.cpp
@@ -317,6 +317,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_chunks = std::stoi(argv[i]);
} else if (arg == "-np" || arg == "--parallel") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.n_parallel = std::stoi(argv[i]);
} else if (arg == "-ns" || arg == "--sequences") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.n_sequences = std::stoi(argv[i]);
} else if (arg == "-m" || arg == "--model") {
if (++i >= argc) {
invalid_param = true;
@@ -360,6 +372,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.multiline_input = true;
} else if (arg == "--simple-io") {
params.simple_io = true;
} else if (arg == "-cb" || arg == "--cont-batching") {
params.cont_batching = true;
} else if (arg == "--color") {
params.use_color = true;
} else if (arg == "--mlock") {
@@ -436,8 +450,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.use_mmap = false;
} else if (arg == "--numa") {
params.numa = true;
} else if (arg == "--export") {
params.export_cgraph = true;
} else if (arg == "--verbose-prompt") {
params.verbose_prompt = true;
} else if (arg == "-r" || arg == "--reverse-prompt") {
@@ … @@
if (params.logdir.back() != DIRECTORY_SEPARATOR) {
params.logdir += DIRECTORY_SEPARATOR;
}
} else if (arg == "--perplexity") {
params.perplexity = true;
} else if (arg == "--perplexity" || arg == "--all-logits") {
params.logits_all = true;
} else if (arg == "--ppl-stride") {
if (++i >= argc) {
invalid_param = true;
@@ -655,12 +667,15 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
printf(" --temp N temperature (default: %.1f)\n", (double)params.temp);
printf(" --perplexity compute perplexity over each ctx window of the prompt\n");
printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n");
printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n");
printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel);
printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
if (llama_mlock_supported()) {
printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n");
}
@@ … @@
printf(" Not recommended since this is both slower and uses more VRAM.\n");
#endif // GGML_USE_CUBLAS
#endif
printf(" --export export the computation graph to 'llama.ggml'\n");
printf(" --verbose-prompt print prompt before generation\n");
fprintf(stderr, " --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n");
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
@@ -738,7 +752,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
lparams.f16_kv = params.memory_f16;
lparams.use_mmap = params.use_mmap;
lparams.use_mlock = params.use_mlock;
lparams.logits_all = params.perplexity;
lparams.logits_all = params.logits_all;
lparams.embedding = params.embedding;
lparams.rope_freq_base = params.rope_freq_base;
lparams.rope_freq_scale = params.rope_freq_scale;
@@ -782,8 +796,9 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
{
LOG("warming up the model with an empty run\n");

const std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads);
std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads);
llama_kv_cache_tokens_rm(lctx, -1, -1);
llama_reset_timings(lctx);
}

@@ -890,7 +905,7 @@ llama_token llama_sample_token(

llama_token id = 0;

float * logits = llama_get_logits(ctx) + idx * n_vocab;
float * logits = llama_get_logits_ith(ctx, idx);

// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
@@ -941,19 +956,19 @@ llama_token llama_sample_token(
if (mirostat == 1) {
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temperature(ctx, &cur_p, temp);
llama_sample_temp(ctx, &cur_p, temp);
id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
} else if (mirostat == 2) {
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temperature(ctx, &cur_p, temp);
llama_sample_temp(ctx, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
} else {
// Temperature sampling
llama_sample_top_k (ctx, &cur_p, top_k, 1);
llama_sample_tail_free (ctx, &cur_p, tfs_z, 1);
llama_sample_typical (ctx, &cur_p, typical_p, 1);
llama_sample_top_p (ctx, &cur_p, top_p, 1);
llama_sample_temperature(ctx, &cur_p, temp);
llama_sample_temp(ctx, &cur_p, temp);

{
const int n_top = 10;
@@ -1182,7 +1197,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
fprintf(stream, "export: %s # default: false\n", params.export_cgraph ? "true" : "false");
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", params.frequency_penalty);
dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str());
@@ -1256,6 +1270,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale);
fprintf(stream, "seed: %d # default: -1 (random seed)\n", params.seed);
fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
fprintf(stream, "temp: %f # default: 0.8\n", params.temp);

const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES);
8 changes: 5 additions & 3 deletions common/common.h
@@ -42,6 +42,8 @@ struct gpt_params {
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_draft = 16; // number of tokens to draft during speculative decoding
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_parallel = 1; // number of parallel sequences to decode
int32_t n_sequences = 1; // number of sequences to decode
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
@@ -107,16 +109,16 @@ struct gpt_params {
bool interactive_first = false; // wait for user input immediately
bool multiline_input = false; // reverse the usage of `\`
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = false; // insert new sequences for decoding on-the-fly

bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool ignore_eos = false; // ignore generated EOS tokens
bool instruct = false; // instruction mode (used for Alpaca models)
bool penalize_nl = true; // consider newlines as a repeatable token
bool perplexity = false; // compute perplexity over the prompt
bool logits_all = false; // return logits for all tokens in the batch
bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory
bool numa = false; // attempt optimizations that help on some NUMA systems
bool export_cgraph = false; // export the computation graph
bool verbose_prompt = false; // print prompt tokens before generation
};

@@ -181,7 +183,7 @@ std::string llama_detokenize_bpe(
// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
// - grammar: grammar to use for sampling, ignore if NULL
// - last_tokens: needed for repetition penalty, ignore if empty
// - idx: sample from llama_get_logits(ctx) + idx * n_vocab
// - idx: sample from llama_get_logits_ith(ctx, idx)
//
// returns:
// - token: sampled token
2 changes: 2 additions & 0 deletions examples/CMakeLists.txt
@@ -23,7 +23,9 @@ else()
add_subdirectory(train-text-from-scratch)
add_subdirectory(convert-llama2c-to-ggml)
add_subdirectory(simple)
add_subdirectory(batched)
add_subdirectory(speculative)
add_subdirectory(parallel)
add_subdirectory(embd-input)
add_subdirectory(llama-bench)
add_subdirectory(beam-search)
37 changes: 31 additions & 6 deletions examples/baby-llama/baby-llama.cpp
@@ -554,6 +554,14 @@ static struct ggml_tensor * forward(
struct ggml_tensor * kc = kv_self.k;
struct ggml_tensor * vc = kv_self.v;

struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
{
int * data = (int *) KQ_pos->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}

// inpL shape [n_embd,N,1,1]
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
for (int il = 0; il < n_layer; ++il) {
@@ -581,8 +589,8 @@ static struct ggml_tensor * forward(
// wk shape [n_embd, n_embd, 1, 1]
// Qcur shape [n_embd/n_head, n_head, N, 1]
// Kcur shape [n_embd/n_head, n_head, N, 1]
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0);
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0);

// store key and value to memory
{
@@ -808,9 +816,18 @@ static struct ggml_tensor * forward_batch(
struct ggml_tensor * kc = kv_self.k;
struct ggml_tensor * vc = kv_self.v;

struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
{
int * data = (int *) KQ_pos->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}

// inpL shape [n_embd,N*n_batch,1]
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
assert_shape_2d(inpL, n_embd, N*n_batch);

for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;

@@ -838,8 +855,8 @@ static struct ggml_tensor * forward_batch(
// wk shape [n_embd, n_embd, 1, 1]
// Qcur shape [n_embd/n_head, n_head, N, n_batch]
// Kcur shape [n_embd/n_head, n_head, N, n_batch]
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0);
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0);
assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);

@@ -1097,6 +1114,14 @@ static struct ggml_tensor * forward_lora(
struct ggml_tensor * kc = kv_self.k;
struct ggml_tensor * vc = kv_self.v;

struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
{
int * data = (int *) KQ_pos->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}

// inpL shape [n_embd,N,1,1]
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
for (int il = 0; il < n_layer; ++il) {
@@ -1130,7 +1155,7 @@ static struct ggml_tensor * forward_lora(
model->layers[il].wqb,
cur)),
n_embd/n_head, n_head, N),
n_past, n_rot, 0, 0);
KQ_pos, n_rot, 0, 0);
struct ggml_tensor * Kcur = ggml_rope(ctx0,
ggml_reshape_3d(ctx0,
ggml_mul_mat(ctx0,
@@ … @@
model->layers[il].wkb,
cur)),
n_embd/n_head, n_head, N),
n_past, n_rot, 0, 0);
KQ_pos, n_rot, 0, 0);

// store key and value to memory
{
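
A pattern worth calling out from the baby-llama changes above: `ggml_rope` no longer takes an `n_past` scalar; it now takes an I32 tensor holding the absolute position of every token in the batch, which is what makes per-sequence positions and KV cache shifts possible. The helper below condenses the `KQ_pos` blocks added in this file; `rope_with_positions` is a hypothetical name for illustration, not a function from the PR.

```cpp
// Condensed form of the KQ_pos blocks added in forward(), forward_batch() and
// forward_lora() above. KQ_pos holds the absolute position of each of the N
// tokens in the current batch.
#include "ggml.h"

static struct ggml_tensor * rope_with_positions(
        struct ggml_context * ctx0,
        struct ggml_tensor  * Qcur,   // reshaped query (or key) tensor
        int n_past, int N, int n_rot) {
    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
    int * data = (int *) KQ_pos->data;
    for (int i = 0; i < N; ++i) {
        data[i] = n_past + i;  // absolute position of token i
    }
    // previously: ggml_rope(ctx0, Qcur, n_past, n_rot, 0, 0)
    return ggml_rope(ctx0, Qcur, KQ_pos, n_rot, 0, 0);
}
```
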
5 changes: 5 additions & 0 deletions examples/batched/CMakeLists.txt
@@ -0,0 +1,5 @@
set(TARGET batched)
add_executable(${TARGET} batched.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
44 changes: 44 additions & 0 deletions examples/batched/README.md
@@ -0,0 +1,44 @@
# llama.cpp/example/batched

The example demonstrates batched generation from a given prompt

```bash
./batched ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is" 4

...

main: n_len = 32, n_ctx = 2048, n_parallel = 4, n_kv_req = 113

Hello my name is

main: generating 4 sequences ...

main: stream 0 finished
main: stream 1 finished
main: stream 2 finished
main: stream 3 finished

sequence 0:

Hello my name is Shirley. I am a 25-year-old female who has been working for over 5 years as a b

sequence 1:

Hello my name is Renee and I'm a 32 year old female from the United States. I'm looking for a man between

sequence 2:

Hello my name is Diana. I am looking for a housekeeping job. I have experience with children and have my own transportation. I am

sequence 3:

Hello my name is Cody. I am a 3 year old neutered male. I am a very friendly cat. I am very playful and

main: decoded 108 tokens in 3.57 s, speed: 30.26 t/s

llama_print_timings: load time = 587.00 ms
llama_print_timings: sample time = 2.56 ms / 112 runs ( 0.02 ms per token, 43664.72 tokens per second)
llama_print_timings: prompt eval time = 4089.11 ms / 118 tokens ( 34.65 ms per token, 28.86 tokens per second)
llama_print_timings: eval time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second)
llama_print_timings: total time = 4156.04 ms
```
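
For orientation, the core of the batched example is filling a `llama_batch` by hand: one entry per token, each tagged with its position and sequence id, with logits requested only where a stream needs to sample. The sketch below is a simplified reconstruction rather than the example source; it assumes the flat `token`/`pos`/`seq_id`/`logits` arrays and the `llama_batch_init`/`llama_batch_free` helpers as introduced in this PR, and that the shared prompt has already been decoded into every sequence's KV cells.

```cpp
// Hedged sketch: advance n_parallel sequences by one token each with a single
// llama_decode() call. Field names follow the llama_batch layout from this PR
// and may differ in later versions.
#include "llama.h"

#include <vector>

static void decode_parallel_step(llama_context * ctx,
                                 const std::vector<llama_token> & new_tokens, // one per sequence
                                 llama_pos n_cur, int n_threads) {
    const int32_t n_parallel = (int32_t) new_tokens.size();
    llama_batch batch = llama_batch_init(n_parallel, 0);

    batch.n_tokens = n_parallel;
    for (int32_t s = 0; s < n_parallel; ++s) {
        batch.token [s] = new_tokens[s]; // next token chosen for sequence s
        batch.pos   [s] = n_cur;         // all streams are at the same position here
        batch.seq_id[s] = s;             // keep each stream in its own KV cells
        batch.logits[s] = true;          // request logits for every stream
    }

    llama_decode(ctx, batch, n_threads);

    for (int32_t s = 0; s < n_parallel; ++s) {
        const float * logits = llama_get_logits_ith(ctx, s); // logits for stream s
        (void) logits; // ... sample the next token for stream s here ...
    }

    llama_batch_free(batch);
}
```

Each stream is kept apart in the shared KV cache by its `seq_id`, which is what lets the prompt be decoded once and reused across all sequences.
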