talk-llama : add --session support (ggerganov#845)

* feat: adding session support

* readme: adding --session info in examples/talk-llama

* llama: adding session fixes

* readme: updating session doc

* talk-llama: set need_to_save_session to true so the session is saved in the subsequent interaction

* talk-llama: adding missing function which updates session_tokens
herrera-luis committed May 1, 2023
1 parent 71fbfa3 commit 75665af
Showing 4 changed files with 171 additions and 42 deletions.
14 changes: 14 additions & 0 deletions examples/talk-llama/README.md
@@ -25,6 +25,20 @@ make talk-llama
- The `-mw` argument specifies the Whisper model that you would like to use. Recommended `base` or `small` for real-time experience
- The `-ml` argument specifies the LLaMA model that you would like to use. Read the instructions in https://github.com/ggerganov/llama.cpp for information about how to obtain a `ggml` compatible LLaMA model

## Session

The `talk-llama` tool supports session management to enable more coherent and continuous conversations. By maintaining context from previous interactions, it can better understand and respond to user requests in a more natural way.

To enable session support, use the `--session FILE` command line option when running the program. The `talk-llama` model state will be saved to the specified file after each interaction. If the file does not exist, it will be created. If the file exists, the model state will be loaded from it, allowing you to resume a previous session.

This feature is especially helpful for maintaining context in long conversations, or when interacting with the AI assistant across multiple sessions: the assistant remembers previous interactions and can give more relevant, contextual responses.

Example usage:

```bash
./talk-llama --session ./my-session-file -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
```
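
Because the model state is saved after every interaction, running the same command again with the same `--session` path picks the conversation back up where it left off (the command below simply repeats the example above, shown only to illustrate the resume behavior):

```bash
# second run: the session file now exists, so the saved state is loaded from it
./talk-llama --session ./my-session-file -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
```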

## TTS

For best experience, this example needs a TTS tool to convert the generated text responses to voice.
95 changes: 60 additions & 35 deletions examples/talk-llama/llama.cpp
@@ -2695,56 +2695,81 @@ std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx) {
     return ctx->model.tensors_by_name;
 }
 
-size_t llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    // TODO leverage mmap
+bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
     llama_file file(path_session, "rb");
-    const uint32_t magic = file.read_u32();
-    const uint32_t version = file.read_u32();
 
-    if (!(magic == 'ggsn' && version == 0)) {
-        fprintf(stderr, "%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
-        return 0;
-    }
+    // sanity checks
+    {
+        const uint32_t magic   = file.read_u32();
+        const uint32_t version = file.read_u32();
 
-    llama_hparams session_hparams;
-    file.read_raw(&session_hparams, sizeof(llama_hparams));
+        if (!(magic == LLAMA_SESSION_MAGIC && version == LLAMA_SESSION_VERSION)) {
+            fprintf(stderr, "%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
+            return false;
+        }
 
-    // REVIEW
-    if (session_hparams != ctx->model.hparams) {
-        fprintf(stderr, "%s : model hparams didn't match from session file!\n", __func__);
-        return 0;
+        llama_hparams session_hparams;
+        file.read_raw(&session_hparams, sizeof(llama_hparams));
+
+        if (session_hparams != ctx->model.hparams) {
+            fprintf(stderr, "%s : model hparams didn't match from session file!\n", __func__);
+            return false;
+        }
     }
 
-    const uint32_t n_token_count = file.read_u32();
-    LLAMA_ASSERT(n_token_capacity >= n_token_count);
-    file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
-    *n_token_count_out = n_token_count;
+    // load the prompt
+    {
+        const uint32_t n_token_count = file.read_u32();
 
-    const size_t n_state_size = file.size - file.tell();
-    const size_t n_orig_state_size = llama_get_state_size(ctx);
-    if (n_state_size != n_orig_state_size) {
-        fprintf(stderr, "%s : failed to validate state size\n", __func__);
+        if (n_token_count > n_token_capacity) {
+            fprintf(stderr, "%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
+            return false;
+        }
+
+        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
+        *n_token_count_out = n_token_count;
     }
-    std::unique_ptr<uint8_t[]> state_data(new uint8_t[n_state_size]);
-    file.read_raw(state_data.get(), n_state_size);
-    return llama_set_state_data(ctx, state_data.get());
+
+    // restore the context state
+    {
+        const size_t n_state_size_cur = file.size - file.tell();
+        const size_t n_state_size_exp = llama_get_state_size(ctx);
+
+        if (n_state_size_cur != n_state_size_exp) {
+            fprintf(stderr, "%s : the state size in session file didn't match! expected %zu, got %zu\n", __func__, n_state_size_exp, n_state_size_cur);
+            return false;
+        }
+
+        std::vector<uint8_t> state_data(n_state_size_cur);
+        file.read_raw(state_data.data(), n_state_size_cur);
+
+        llama_set_state_data(ctx, state_data.data());
+    }
+
+    return true;
 }
 
-size_t llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
-    // TODO save temp & swap
+bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
     llama_file file(path_session, "wb");
 
-    const size_t n_state_size = llama_get_state_size(ctx);
-    std::unique_ptr<uint8_t[]> state_data(new uint8_t[n_state_size]);
-    llama_copy_state_data(ctx, state_data.get());
+    file.write_u32(LLAMA_SESSION_MAGIC);
+    file.write_u32(LLAMA_SESSION_VERSION);
 
-    file.write_u32('ggsn'); // magic
-    file.write_u32(0); // version
     file.write_raw(&ctx->model.hparams, sizeof(llama_hparams));
 
-    file.write_u32((uint32_t) n_token_count); // REVIEW
+    // save the prompt
+    file.write_u32((uint32_t) n_token_count);
    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
 
-    file.write_raw(state_data.get(), n_state_size);
-    return n_state_size; // REVIEW
-}
+    // save the context state
+    {
+        const size_t n_state_size = llama_get_state_size(ctx);
+
+        std::vector<uint8_t> state_data(n_state_size);
+        llama_copy_state_data(ctx, state_data.data());
+
+        file.write_raw(state_data.data(), n_state_size);
+    }
+
+    return true;
+}
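
As an aside for readers of this diff: a minimal sketch (not part of the commit) of how a caller might drive the new `bool`-returning session API. The model path, the placeholder prompt tokens, and the minimal error handling are illustrative assumptions only:

```cpp
#include "llama.h"

#include <cstdio>
#include <vector>

int main() {
    llama_context_params lparams = llama_context_default_params();
    llama_context * ctx = llama_init_from_file("ggml-model-q4_0.bin", lparams); // illustrative path
    if (ctx == nullptr) {
        return 1;
    }

    // pretend we evaluated a prompt and kept its tokens around
    std::vector<llama_token> tokens = { llama_token_bos() };

    // save: writes magic/version, hparams, the prompt tokens, then the full model state
    if (!llama_save_session_file(ctx, "session.bin", tokens.data(), tokens.size())) {
        fprintf(stderr, "failed to save session\n");
    }

    // load: checks magic/version and hparams, then restores the tokens and state
    std::vector<llama_token> loaded(llama_n_ctx(ctx));
    size_t n_loaded = 0;
    if (!llama_load_session_file(ctx, "session.bin", loaded.data(), loaded.size(), &n_loaded)) {
        fprintf(stderr, "failed to load session (missing or incompatible file)\n");
    }
    loaded.resize(n_loaded);

    llama_free(ctx);
    return 0;
}
```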
13 changes: 7 additions & 6 deletions examples/talk-llama/llama.h
@@ -19,9 +19,11 @@
 # define LLAMA_API
 #endif
 
-#define LLAMA_FILE_VERSION 1
-#define LLAMA_FILE_MAGIC 0x67676a74 // 'ggjt' in hex
-#define LLAMA_FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files
+#define LLAMA_FILE_VERSION           1
+#define LLAMA_FILE_MAGIC             'ggjt'
+#define LLAMA_FILE_MAGIC_UNVERSIONED 'ggml'
+#define LLAMA_SESSION_MAGIC          'ggsn'
+#define LLAMA_SESSION_VERSION        0
 
 #ifdef __cplusplus
 extern "C" {

@@ -138,9 +140,8 @@ extern "C" {
     LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src);
 
     // Save/load session file
-    LLAMA_API size_t llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
-    LLAMA_API size_t llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);
-
+    LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
+    LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);
+
     // Run the llama inference to obtain the logits and probabilities for the next token.
     // tokens + n_tokens is the provided batch of new tokens to process
     // n_past is the number of tokens to use from previous eval calls
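
Read together with `llama_save_session_file` above, the on-disk session file layout is as follows (the field names in this sketch are descriptive, not taken from the header):

```cpp
// session file layout, as written by llama_save_session_file:
//
//   uint32_t      magic                  // LLAMA_SESSION_MAGIC, 'ggsn'
//   uint32_t      version                // LLAMA_SESSION_VERSION, 0
//   llama_hparams hparams                // checked against the model on load
//   uint32_t      n_token_count          // number of prompt tokens that follow
//   llama_token   tokens[n_token_count]  // the saved prompt
//   uint8_t       state[]                // llama_get_state_size(ctx) bytes, to end of file
```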
91 changes: 90 additions & 1 deletion examples/talk-llama/talk-llama.cpp
@@ -52,6 +52,7 @@ struct whisper_params {
    std::string speak = "./examples/talk-llama/speak.sh";
    std::string prompt = "";
    std::string fname_out;
    std::string path_session = ""; // path to file for saving/loading model eval state
};

void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -78,6 +79,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "--verbose-prompt") { params.verbose_prompt = true; }
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
else if (arg == "--session") { params.path_session = argv[++i];}
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; }
@@ -124,6 +126,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, " --n-parts-llama N [%-7d] num parts in llama model file\n", params.n_parts_llama);
fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str());
fprintf(stderr, " --prompt-file FNAME [%-7s] file with custom prompt to start dialog\n", "");
fprintf(stderr, " --session FNAME file to cache model state in (may be large!) (default: none)\n");
fprintf(stderr, " --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false");
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, "\n");
@@ -348,6 +351,57 @@ int main(int argc, char ** argv) {
        fflush(stdout);
    }

    // init session
    std::string path_session = params.path_session;
    std::vector<llama_token> session_tokens;

    if (!path_session.empty()) {
        fprintf(stderr, "%s: attempting to load saved session from %s\n", __func__, path_session.c_str());

        // fopen to check for existing session
        FILE * fp = std::fopen(path_session.c_str(), "rb");
        if (fp != NULL) {
            std::fclose(fp);

            session_tokens.resize(lparams.n_ctx);
            size_t n_token_count_out = 0;
            if (!llama_load_session_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
                fprintf(stderr, "%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
                return 1;
            }
            session_tokens.resize(n_token_count_out);

            fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
        } else {
            fprintf(stderr, "%s: session file does not exist, will create\n", __func__);
        }
    }

    // debug message about similarity of saved session, if applicable
    size_t n_matching_session_tokens = 0;
    if (session_tokens.size()) {
        for (llama_token id : session_tokens) {
            if (n_matching_session_tokens >= embd_inp.size() || id != embd_inp[n_matching_session_tokens]) {
                break;
            }
            n_matching_session_tokens++;
        }
        if (n_matching_session_tokens >= embd_inp.size()) {
            fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__);
        } else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
            fprintf(stderr, "%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
                __func__, n_matching_session_tokens, embd_inp.size());
        } else {
            fprintf(stderr, "%s: session file matches %zu / %zu tokens of prompt\n",
                __func__, n_matching_session_tokens, embd_inp.size());
        }
    }

    // HACK - because session saving incurs a non-negligible delay, for now skip re-saving session
    // if we loaded a session with at least 75% similarity. It's currently just used to speed up the
    // initial prompt so it doesn't need to be an exact match.
    bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < (embd_inp.size() * 3 / 4);

    printf("%s : done! start speaking in the microphone\n", __func__);
    printf("\n");
    printf("%s%s", params.person.c_str(), chat_symb.c_str());
@@ -363,6 +417,7 @@ int main(int argc, char ** argv) {

    int n_past = n_keep;
    int n_prev = 64; // TODO arg
    int n_session_consumed = 0;

    std::vector<llama_token> embd;

@@ -450,7 +505,8 @@ int main(int argc, char ** argv) {

                // insert n_left/2 tokens at the start of embd from last_n_tokens
                embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());

                // stop saving session if we run out of context
                path_session = "";

                //printf("\n---\n");
                //printf("resetting: '");
                //for (int i = 0; i < (int) embd.size(); i++) {
@@ -460,6 +516,29 @@ int main(int argc, char ** argv) {
//printf("\n---\n");
}

// try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
// REVIEW
if (n_session_consumed < (int) session_tokens.size()) {
size_t i = 0;
for ( ; i < embd.size(); i++) {
if (embd[i] != session_tokens[n_session_consumed]) {
session_tokens.resize(n_session_consumed);
break;
}

n_past++;
n_session_consumed++;

if (n_session_consumed >= (int) session_tokens.size()) {
i++;
break;
}
}
if (i > 0) {
embd.erase(embd.begin(), embd.begin() + i);
}
}

if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
@@ -470,6 +549,10 @@ int main(int argc, char ** argv) {

            embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
            n_past += embd.size();

            if (embd.size() > 0 && !path_session.empty()) {
                session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
                n_session_consumed = session_tokens.size();
            }

            embd.clear();

            if (done) break;
@@ -483,6 +566,11 @@ int main(int argc, char ** argv) {

            const int repeat_last_n = 256;

            if (!path_session.empty() && need_to_save_session) {
                need_to_save_session = false;
                llama_save_session_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.size());
            }

            llama_token id = 0;

            {
@@ -542,6 +630,7 @@ int main(int argc, char ** argv) {
                    done = true;
                    text_to_speak = ::replace(text_to_speak, antiprompt, "");
                    fflush(stdout);
                    need_to_save_session = true;
                    break;
                }
            }
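
The heart of the talk-llama changes is the prefix-reuse loop in the hunks above. As a standalone illustration (not from the commit, with plain `int` tokens standing in for `llama_token`), the idea reduces to:

```cpp
// Standalone sketch of the session prefix-reuse idea: count how many leading
// tokens of the new prompt already match the saved session, so that only the
// unmatched tail has to be re-evaluated.
#include <cstdio>
#include <vector>

int main() {
    const std::vector<int> session = { 1, 5, 9, 4, 2 };    // tokens restored from a session file
    const std::vector<int> prompt  = { 1, 5, 9, 7, 8, 3 }; // tokens of the current prompt

    size_t n_matching = 0;
    while (n_matching < session.size() && n_matching < prompt.size() &&
           session[n_matching] == prompt[n_matching]) {
        n_matching++;
    }

    // matched tokens are already covered by the restored state (n_past);
    // only the remaining tokens need to go through llama_eval
    printf("reusing %zu tokens, re-evaluating %zu\n", n_matching, prompt.size() - n_matching);
    return 0;
}
```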