9 changes: 9 additions & 0 deletions common/common.cpp
@@ -335,6 +335,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.yarn_beta_slow = std::stof(argv[i]);
+} else if (arg == "--defrag-thold" || arg == "-dt") {
+if (++i >= argc) {
+invalid_param = true;
+break;
+}
+params.defrag_thold = std::stof(argv[i]);
} else if (arg == "--samplers") {
if (++i >= argc) {
invalid_param = true;
@@ -1004,6 +1010,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
+printf(" -dt N, --defrag-thold N\n");
+printf(" KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
printf(" --no-penalize-nl do not penalize newline token\n");
printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp);
@@ -1285,6 +1293,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
+cparams.defrag_thold = params.defrag_thold;
cparams.offload_kqv = !params.no_kv_offload;

cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
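Note (editor's sketch, not part of the diff): a minimal illustration of how the new threshold flows from the command line into the context parameters, assuming the helpers changed above; the value 0.1f is a placeholder.

    #include "common.h"

    // editor's sketch: equivalent of passing `--defrag-thold 0.1` (or `-dt 0.1`)
    llama_context_params sketch_defrag_params() {
        gpt_params params;          // defrag_thold defaults to -1.0f (disabled)
        params.defrag_thold = 0.1f; // what the new argument parsing stores

        // the line added to llama_context_params_from_gpt_params above copies
        // the threshold into the context parameters
        llama_context_params cparams = llama_context_params_from_gpt_params(params);
        return cparams; // decode will queue a defrag once fragmentation exceeds 10%
    }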
1 change: 1 addition & 0 deletions common/common.h
@@ -75,6 +75,7 @@ struct gpt_params {
float yarn_beta_fast = 32.0f; // YaRN low correction dim
float yarn_beta_slow = 1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length
+float defrag_thold = -1.0f; // KV cache defragmentation threshold
int32_t rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;

4 changes: 2 additions & 2 deletions examples/passkey/passkey.cpp
@@ -182,7 +182,7 @@ int main(int argc, char ** argv) {

llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
-llama_kv_cache_defrag (ctx);
+//llama_kv_cache_defrag (ctx);
llama_kv_cache_update (ctx);

n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
@@ -213,7 +213,7 @@

llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
-llama_kv_cache_defrag (ctx);
+//llama_kv_cache_defrag (ctx);
llama_kv_cache_update (ctx);

n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
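Note (editor's sketch, not part of the diff): after this change the passkey example no longer defragments explicitly; with a non-negative defrag_thold the defrag is queued by llama_decode and applied by llama_kv_cache_update. The shift step would look roughly like this, assuming an already-created ctx:

    #include "llama.h"

    // editor's sketch of the context-shift step used in passkey
    static void shift_cache(llama_context * ctx, int n_keep, int n_discard, int n_ctx) {
        // drop the oldest n_discard cells after the n_keep prefix ...
        llama_kv_cache_seq_rm (ctx, 0, n_keep,             n_keep + n_discard);
        // ... and slide the remaining cells back by n_discard positions
        llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
        // applies any pending K-shift and any queued defragmentation;
        // no explicit llama_kv_cache_defrag(ctx) call is needed anymore
        llama_kv_cache_update (ctx);
    }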
97 changes: 69 additions & 28 deletions llama.cpp
@@ -1641,6 +1641,7 @@ struct llama_cparams {
float yarn_attn_factor;
float yarn_beta_fast;
float yarn_beta_slow;
+float defrag_thold;

bool mul_mat_q;
bool offload_kqv;
@@ -5114,16 +5115,16 @@ struct llm_build_context {
struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

-for (int i = 0; i < n_kv; ++i) {
-const int id = ids[i];
+for (uint32_t i = 0; i < ids.size(); ++i) {
+const uint32_t id = ids[i];

-if (i == id || id == n_kv) {
+if (i == id || id == ids.size()) {
continue;
}

-int nm = 1;
+uint32_t nm = 1;

-while (i + nm < n_kv && (int) ids[i + nm] == id + nm) {
+while (i + nm < ids.size() && ids[i + nm] == id + nm) {
nm++;
}

@@ -5155,6 +5156,8 @@ struct llm_build_context {
i += nm - 1;
}

+//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
+
return gf;
}

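Note (editor's sketch, not part of the diff): the rewritten loop above walks the ids mapping (ids[i] is the destination of cell i; ids[i] == i or ids.size() means "no move") and coalesces runs of consecutive moves so each run costs one copy per K/V tensor instead of one per cell. The run detection in isolation, with hypothetical example data:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // hypothetical mapping: cells 4..6 move to 1..3, cell 9 moves to 7;
        // 10 (== ids.size()) marks cells that do not move
        std::vector<uint32_t> ids = {0, 10, 10, 10, 1, 2, 3, 10, 10, 7};

        for (uint32_t i = 0; i < ids.size(); ++i) {
            const uint32_t id = ids[i];
            if (i == id || id == ids.size()) {
                continue; // cell stays in place or is not moved
            }
            uint32_t nm = 1; // length of the consecutive run starting at i
            while (i + nm < ids.size() && ids[i + nm] == id + nm) {
                nm++;
            }
            std::printf("move cells [%u, %u) -> [%u, %u)\n", i, i + nm, id, id + nm);
            i += nm - 1; // skip the rest of the run
        }
        return 0;
    }

With the data above this prints two moves, [4, 7) -> [1, 4) and [9, 10) -> [7, 8): four cells relocated with only two block copies.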
@@ -7935,6 +7938,8 @@ static int llama_decode_internal(
batch.seq_id = seq_id_arr.data();
}

+llama_kv_cache_update(&lctx);
+
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*n_tokens) {
@@ -7953,8 +7958,6 @@

//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);

-llama_kv_cache_update(&lctx);
-
ggml_backend_sched_reset(lctx.sched);
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);

@@ -8004,6 +8007,18 @@ static int llama_decode_internal(
}
}

+// decide if we need to defrag the kv cache
+if (cparams.defrag_thold >= 0.0f) {
+const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used + n_tokens)/float(kv_self.n) : 0.0f;
+
+// queue defragmentation for next llama_kv_cache_update
+if (fragmentation > cparams.defrag_thold) {
+//LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
+
+llama_kv_cache_defrag(kv_self);
+}
+}
+
#ifdef GGML_PERF
// print timing information per ggml operation (for debugging purposes)
// requires GGML_PERF to be defined
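Note (editor's worked example, not part of the diff): with the formula above, fragmentation is the share of the active window [0, kv_self.n) that the used cells plus the incoming batch do not cover, and small caches (n < 128) never trigger. With hypothetical numbers:

    #include <cstdio>

    int main() {
        const int n        = 512; // kv_self.n: active KV cache window
        const int used     = 400; // kv_self.used: occupied cells
        const int n_tokens = 32;  // tokens in the current batch

        // mirrors: kv_self.n >= 128 ? 1.0f - float(kv_self.used + n_tokens)/float(kv_self.n) : 0.0f
        const float fragmentation = n >= 128 ? 1.0f - float(used + n_tokens)/float(n) : 0.0f;

        std::printf("%.3f\n", fragmentation); // 1 - 432/512 = 0.156
        // with --defrag-thold 0.1 this exceeds the threshold, so
        // llama_kv_cache_defrag() queues a defrag for the next update
        return 0;
    }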
@@ -8095,12 +8110,16 @@ static int llama_decode_internal(
static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
auto & kv_self = lctx.kv_self;

+const auto & hparams = lctx.model.hparams;
+
+const uint32_t n_layer = hparams.n_layer;
+
const uint32_t n_kv = llama_kv_cache_cell_max(kv_self);
const uint32_t n_used = kv_self.used;

assert(n_used <= n_kv);

-const int64_t t_start = ggml_time_us();
+//const int64_t t_start = ggml_time_us();

// number of cells moved
uint32_t n_moves = 0;
@@ -8124,15 +8143,26 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {

// found a hole - fill it with data from the end of the cache

+// determine the size of the hole
uint32_t nh = 1;

-// determine the size of the hole
while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) {
nh++;
}

-// starting from the end, find nh non-empty cells
+// each move requires 6*n_layer tensors (see build_defrag)
+// - source view, destination view, copy operation
+// - x2 for keys and values
+//
+if (6*(n_moves + nh)*n_layer >= LLAMA_MAX_NODES) {
+// the graph is too big, we cannot move more cells
+break;
+}
+
uint32_t nf = 0;
uint32_t is = n_kv - 1;
+
+// starting from the end, find nh non-empty cells
for (; is > i0; --is) {
const auto & cell1 = kv_self.cells[is];

@@ -8153,11 +8183,17 @@

nf = 0;

+uint32_t i1 = is;
+
+// are we moving a continuous block of memory?
+bool cont = false;
+
// go back and move the nf cells to the hole
-for (uint32_t i1 = is; i1 < n_kv; ++i1) {
-const auto & cell1 = kv_self.cells[i1];
+for (; i1 < n_kv; ++i1) {
+auto & cell1 = kv_self.cells[i1];

if (cell1.is_empty() || ids[i1] != n_kv) {
+cont = false;
continue;
}

@@ -8167,11 +8203,23 @@
// move the cell meta data
kv_self.cells[i0 + nf] = cell1;

-n_moves++;
+// clear the old cell and move the head there
+cell1 = llama_kv_cell();
+kv_self.head = n_used;
+
+if (!cont) {
+n_moves++;
+cont = true;
+}
+
nf++;
+
+if (nf == nh) {
+break;
+}
}

-LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, n_kv, i0, i0 + nh);
+//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);

i0 += nh - 1;
}
@@ -8180,15 +8228,9 @@
return;
}

-LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
-
-kv_self.head = n_used;
-kv_self.used = n_used;
+//LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);

-// zero the rest of the cells
-for (uint32_t i = n_used; i < n_kv; ++i) {
-kv_self.cells[i] = llama_kv_cell();
-}
+//LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);

#if 0
// CPU defrag
@@ -8200,9 +8242,6 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
// likely not worth the effort, as we have ggml_graph based defrag
//

-const auto & hparams = lctx.model.hparams;
-
-const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();

@@ -8271,9 +8310,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
#endif

-const int64_t t_end = ggml_time_us();
+//const int64_t t_end = ggml_time_us();

-LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
+//LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
}

static void llama_kv_cache_update_internal(struct llama_context & lctx) {
@@ -11635,6 +11674,7 @@ struct llama_context_params llama_context_default_params() {
/*.yarn_beta_fast =*/ 32.0f,
/*.yarn_beta_slow =*/ 1.0f,
/*.yarn_orig_ctx =*/ 0,
+/*.defrag_thold =*/ -1.0f,
/*.cb_eval =*/ nullptr,
/*.cb_eval_user_data =*/ nullptr,
/*.type_k =*/ GGML_TYPE_F16,
@@ -11799,6 +11839,7 @@ struct llama_context * llama_new_context_with_model(
cparams.yarn_attn_factor = params.yarn_attn_factor;
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
+cparams.defrag_thold = params.defrag_thold;
cparams.mul_mat_q = params.mul_mat_q;
cparams.offload_kqv = params.offload_kqv;
cparams.do_pooling = params.do_pooling;
@@ -12000,7 +12041,7 @@ struct llama_context * llama_new_context_with_model(
}

// buffer used to store the computation graph and the tensor meta data
-ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead());
+ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead_custom(LLAMA_MAX_NODES, false));

ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES);

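Note (editor's sketch, not part of the diff): the planning pass in llama_kv_cache_defrag_internal scans [0, n_used) for holes and fills each one with cells taken from the back of the cache, recording the moves in ids and stopping once 6*(n_moves + nh)*n_layer would exceed LLAMA_MAX_NODES. A heavily simplified version over a plain occupancy array (one cell per hole, no node budget):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // hypothetical occupancy: holes at cells 1 and 3; cells 5 and 7 at the back are occupied
        std::vector<bool> occ = {true, false, true, false, true, true, false, true};
        const uint32_t n_kv   = (uint32_t) occ.size();
        const uint32_t n_used = 5; // number of occupied cells
        std::vector<uint32_t> ids(n_kv, n_kv); // ids[i] == n_kv means "cell i does not move"

        uint32_t is = n_kv - 1; // candidate source, scanning from the end
        for (uint32_t i0 = 0; i0 < n_used; ++i0) {
            if (occ[i0]) {
                continue; // not a hole
            }
            while (is > i0 && !occ[is]) {
                --is; // find an occupied cell at the back of the cache
            }
            if (is <= i0) {
                break; // nothing left to move
            }
            ids[is] = i0; // plan the move: cell `is` fills the hole at i0
            occ[i0] = true;
            occ[is] = false;
        }

        for (uint32_t i = 0; i < n_kv; ++i) {
            if (ids[i] != n_kv) {
                std::printf("cell %u -> cell %u\n", i, ids[i]);
            }
        }
        return 0; // prints: cell 5 -> cell 3, cell 7 -> cell 1
    }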
1 change: 1 addition & 0 deletions llama.h
@@ -243,6 +243,7 @@ extern "C" {
float yarn_beta_fast; // YaRN low correction dim
float yarn_beta_slow; // YaRN high correction dim
uint32_t yarn_orig_ctx; // YaRN original context size
+float defrag_thold; // defragment the KV cache if holes/size > thold, < 0 disabled (default)

ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
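Note (editor's sketch, not part of the diff): enabling the new parameter directly through the C API, assuming a model is loaded elsewhere; the value 0.1f is a placeholder.

    #include "llama.h"

    // editor's sketch: threshold-based KV cache defragmentation via the new field
    llama_context_params sketch_params() {
        llama_context_params cparams = llama_context_default_params();
        cparams.defrag_thold = 0.1f; // defrag once holes/size > 0.1; the default
                                     // of -1.0f keeps defragmentation disabled
        return cparams;              // pass to llama_new_context_with_model(model, cparams)
    }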