diff --git a/llama.cpp b/llama.cpp index 0faaac0134395..a45de326deda5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15227,7 +15227,9 @@ size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama uint32_t size_t_size; memcpy(&size_t_size, inp, sizeof(size_t_size)); inp += sizeof(size_t_size); - GGML_ASSERT(size_t_size == sizeof(size_t)); + if (size_t_size != sizeof(size_t)) { + return -1; + } // Read the cell count uint32_t cell_count; @@ -15244,6 +15246,18 @@ size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama memcpy(&n_embd_v_gqa_ref, inp, sizeof(n_embd_v_gqa_ref)); inp += sizeof(n_embd_v_gqa_ref); + // Sanity check model compatibility + const auto& hparams = ctx->model.hparams; + const uint32_t n_layer = hparams.n_layer; + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); + if (n_layer != n_layer_ref) { + return -2; + } + if (n_embd_v_gqa != n_embd_v_gqa_ref) { + return -2; + } + // Allocate the new cells for the slot { llama_batch batch = llama_batch_init(cell_count, 0, 1); @@ -15259,7 +15273,7 @@ size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama } if (!llama_kv_cache_find_slot(kv_self, batch)) { llama_batch_free(batch); - return 0; + return -3; } // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values) @@ -15274,10 +15288,6 @@ size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama llama_batch_free(batch); } - const auto& hparams = ctx->model.hparams; - const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); const uint32_t kv_size = kv_self.size; const uint32_t kv_head = kv_self.head; GGML_ASSERT(n_layer == n_layer_ref); diff --git a/llama.h b/llama.h index 33164a33af217..a29fed38986e1 100644 --- a/llama.h +++ b/llama.h @@ -632,6 +632,13 @@ extern "C" { uint8_t * dst, llama_seq_id seq_id); + // Copy the sequence data (originally copied with `llama_copy_seq_data`) into a sequence. + // Returns: + // - Positive: Ok + // - Negative: An error of some kind + // - -1: `size_t` is incorrect size + // - -2: Model mismatch + // - -3: Cannot find space in KV cache LLAMA_API size_t llama_set_seq_data( struct llama_context * ctx, const uint8_t * src,