diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 276e1697d466c..812bf2530491a 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -151,7 +151,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1 = std::numeric_limits::max(); } - // models like Mamba or RWKV can't have a state partially erased + // models like Mamba or RWKV can't have a state partially erased at the end + // of the sequence because their state isn't preserved for previous tokens if (seq_id >= (int64_t) size) { // could be fatal return false; @@ -160,8 +161,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos int32_t & tail_id = cells[seq_id].tail; if (tail_id >= 0) { const auto & cell = cells[tail_id]; - // partial intersection is invalid - if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) { + // partial intersection is invalid if it includes the final pos + if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n"); return false; } diff --git a/tools/main/main.cpp b/tools/main/main.cpp index 498e00e3a5e58..33e8862335793 100644 --- a/tools/main/main.cpp +++ b/tools/main/main.cpp @@ -354,7 +354,11 @@ int main(int argc, char ** argv) { } // remove any "future" tokens that we might have inherited from the previous session - llama_memory_seq_rm(mem, -1, n_matching_session_tokens, -1); + if (!llama_memory_seq_rm(mem, -1, n_matching_session_tokens, -1)) { + LOG_INF("%s: unable to resuse common prefix\n", __func__); + n_matching_session_tokens = 0; + llama_memory_seq_rm(mem, -1, -1, -1); + } } LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",