Skip to content

Commit

Permalink
talk-llama : optional wake-up command and audio confirmation (#1765)
Browse files Browse the repository at this point in the history
* talk-llama: add optional wake-word detection from command

* talk-llama: add optional audio confirmation before generating answer

* talk-llama: fix small formatting issue in output

* talk-llama.cpp: fix Windows build
  • Loading branch information
Rakksor committed Jan 16, 2024
1 parent f5f159c commit f661415
Showing 1 changed file with 62 additions and 2 deletions.
64 changes: 62 additions & 2 deletions examples/talk-llama/talk-llama.cpp
Expand Up @@ -14,6 +14,7 @@
#include <thread>
#include <vector>
#include <regex>
#include <sstream>

std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
auto * model = llama_get_model(ctx);
Expand Down Expand Up @@ -68,6 +69,8 @@ struct whisper_params {

std::string person = "Georgi";
std::string bot_name = "LLaMA";
std::string wake_cmd = "";
std::string heard_ok = "";
std::string language = "en";
std::string model_wsp = "models/ggml-base.en.bin";
std::string model_llama = "models/ggml-llama-7B.bin";
Expand Down Expand Up @@ -104,6 +107,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
else if (arg == "-bn" || arg == "--bot-name") { params.bot_name = argv[++i]; }
else if (arg == "--session") { params.path_session = argv[++i]; }
else if (arg == "-w" || arg == "--wake-command") { params.wake_cmd = argv[++i]; }
else if (arg == "-ho" || arg == "--heard-ok") { params.heard_ok = argv[++i]; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; }
Expand Down Expand Up @@ -149,6 +154,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
fprintf(stderr, " -bn NAME, --bot-name NAME [%-7s] bot name (to display)\n", params.bot_name.c_str());
fprintf(stderr, " -w TEXT, --wake-command T [%-7s] wake-up command to listen for\n", params.wake_cmd.c_str());
fprintf(stderr, " -ho TEXT, --heard-ok TEXT [%-7s] said by TTS before generating reply\n", params.heard_ok.c_str());
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
fprintf(stderr, " -ml FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str());
Expand Down Expand Up @@ -227,6 +234,18 @@ std::string transcribe(
return result;
}

std::vector<std::string> get_words(const std::string &txt) {
std::vector<std::string> words;

std::istringstream iss(txt);
std::string word;
while (iss >> word) {
words.push_back(word);
}

return words;
}

const std::string k_prompt_whisper = R"(A conversation with a person called {1}.)";

const std::string k_prompt_llama = R"(Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
Expand Down Expand Up @@ -441,6 +460,16 @@ int main(int argc, char ** argv) {
bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < (embd_inp.size() * 3 / 4);

printf("%s : done! start speaking in the microphone\n", __func__);

// show wake command if enabled
const std::string wake_cmd = params.wake_cmd;
const int wake_cmd_length = get_words(wake_cmd).size();
const bool use_wake_cmd = wake_cmd_length > 0;

if (use_wake_cmd) {
printf("%s : the wake-up command is: '%s%s%s'\n", __func__, "\033[1m", wake_cmd.c_str(), "\033[0m");
}

printf("\n");
printf("%s%s", params.person.c_str(), chat_symb.c_str());
fflush(stdout);
Expand Down Expand Up @@ -486,10 +515,41 @@ int main(int argc, char ** argv) {

audio.get(params.voice_ms, pcmf32_cur);

std::string text_heard;
std::string all_heard;

if (!force_speak) {
text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prompt_whisper, prob0, t_ms));
all_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prompt_whisper, prob0, t_ms));
}

const auto words = get_words(all_heard);

std::string wake_cmd_heard;
std::string text_heard;

for (int i = 0; i < (int) words.size(); ++i) {
if (i < wake_cmd_length) {
wake_cmd_heard += words[i] + " ";
} else {
text_heard += words[i] + " ";
}
}

// check if audio starts with the wake-up command if enabled
if (use_wake_cmd) {
const float sim = similarity(wake_cmd_heard, wake_cmd);

if ((sim < 0.7f) || (text_heard.empty())) {
audio.clear();
continue;
}
}

// optionally give audio feedback that the current text is being processed
if (!params.heard_ok.empty()) {
int ret = system((params.speak + " " + std::to_string(voice_id) + " '" + params.heard_ok + "'").c_str());
if (ret != 0) {
fprintf(stderr, "%s: failed to speak\n", __func__);
}
}

// remove text between brackets using regex
Expand Down

0 comments on commit f661415

Please sign in to comment.