From 965568dcd722462466afc1a729be55fb884ab64c Mon Sep 17 00:00:00 2001
From: Jakob Frick
Date: Sun, 2 Jul 2023 14:48:02 -0400
Subject: [PATCH] dolly : add interactive prompt and port mode (#319)

* update basic function to execute prompt
* try to factor out prediction loop
* update code
* update prompt things
* only render at the end
* add basic server port
* refactor
* fix client file descriptor
* undo common.h style changes
* undo style changes to main.cpp
* fix check for interactive port
---
 examples/common.cpp        |   5 +
 examples/common.h          |   3 +
 examples/dolly-v2/main.cpp | 253 +++++++++++++++++++++++++++---------
 3 files changed, 197 insertions(+), 64 deletions(-)

diff --git a/examples/common.cpp b/examples/common.cpp
index 960656def..7d215ae1f 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -47,6 +47,11 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.n_batch = std::stoi(argv[++i]);
         } else if (arg == "-m" || arg == "--model") {
             params.model = argv[++i];
+        } else if (arg == "-i" || arg == "--interactive") {
+            params.interactive = true;
+        } else if (arg == "-ip" || arg == "--interactive-port") {
+            params.interactive = true;
+            params.interactive_port = std::stoi(argv[++i]);
         } else if (arg == "-h" || arg == "--help") {
             gpt_print_usage(argc, argv, params);
             exit(0);
diff --git a/examples/common.h b/examples/common.h
index 12b2b339d..74655cbfc 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -31,6 +31,9 @@ struct gpt_params {
     std::string model = "models/gpt-2-117M/ggml-model.bin"; // model path
     std::string prompt = "";
     std::string token_test = "";
+
+    bool interactive = false;
+    int interactive_port = -1;
 };
 
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
diff --git a/examples/dolly-v2/main.cpp b/examples/dolly-v2/main.cpp
index 0d511b4e6..9bc5e1a79 100644
--- a/examples/dolly-v2/main.cpp
+++ b/examples/dolly-v2/main.cpp
@@ -9,10 +9,16 @@
 #include <cstdio>
 #include <cstring>
 #include <fstream>
+#include <iostream>
 #include <map>
 #include <string>
 #include <vector>
 
+#include <unistd.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <string.h>
+
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
@@ -671,61 +677,25 @@ bool dollyv2_eval(
     return true;
 }
 
-int main(int argc, char ** argv) {
-    ggml_time_init();
-
-    const int64_t t_main_start_us = ggml_time_us();
-
-    gpt_params params;
-    params.model = "models/dolly-v2-3b/ggml-model-f16.bin";
-
-    if (gpt_params_parse(argc, argv, params) == false) {
-        return 1;
-    }
-
-    if (params.seed < 0) {
-        params.seed = time(NULL);
-    }
-
-    printf("%s: seed = %d\n", __func__, params.seed);
-
-    std::mt19937 rng(params.seed);
-    if (params.prompt.empty()) {
-        params.prompt = gpt_random_prompt(rng);
-    }
-
-    const std::string prompt = prompt_for_generation(params.prompt);
-
-    int64_t t_load_us = 0;
-
-    gpt_vocab vocab;
-    dollyv2_model model;
-
-    // load the model
-    {
-        const int64_t t_start_us = ggml_time_us();
-
-        if (!dollyv2_model_load(params.model, model, vocab)) {
-            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
-            return 1;
-        }
-
-        t_load_us = ggml_time_us() - t_start_us;
-
-        test_gpt_tokenizer(vocab, params.token_test);
-    }
-
-    int n_past = 0;
-
-    int64_t t_sample_us = 0;
-    int64_t t_predict_us = 0;
-
+std::string execute_prompt(
+    const dollyv2_model &model,
+    gpt_vocab &vocab,
+    const std::string &prompt,
+    gpt_params &params,
+    std::mt19937 &rng,
+    int64_t t_load_us,
+    int64_t t_sample_us,
+    int64_t t_predict_us,
+    size_t mem_per_token,
+    int n_past,
+    bool stream_response_to_cout = false) {
+    std::string output = "";
     std::vector<float> logits;
 
     // tokenize the prompt
     std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, prompt);
 
-    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
+    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int)embd_inp.size());
 
     printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
     for (int i = 0; i < embd_inp.size(); i++) {
@@ -735,9 +705,7 @@ int main(int argc, char ** argv) {
 
     std::vector<gpt_vocab::id> embd;
 
-    // determine the required inference memory per token:
-    size_t mem_per_token = 0;
-    dollyv2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+    dollyv2_eval(model, params.n_threads, 0, {0, 1, 2, 3}, logits, mem_per_token);
 
     const int32_t end_token = vocab.token_to_id["### End"];
 
@@ -748,7 +716,7 @@
 
         if (!dollyv2_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
             printf("Failed to predict\n");
-            return 1;
+            return output;
         }
 
         t_predict_us += ggml_time_us() - t_start_us;
@@ -759,9 +727,9 @@
 
         if (i >= embd_inp.size()) {
             // sample next token
-            const int   top_k = params.top_k;
+            const int top_k = params.top_k;
             const float top_p = params.top_p;
-            const float temp  = params.temp;
+            const float temp = params.temp;
 
             const int n_vocab = model.hparams.n_vocab;
 
@@ -777,7 +745,6 @@ int main(int argc, char ** argv) {
 
             // add it to the context
             embd.push_back(id);
-
         } else {
             // if here, it means we are still processing the input prompt
             for (int k = i; k < embd_inp.size(); k++) {
@@ -791,15 +758,169 @@ int main(int argc, char ** argv) {
 
         // display text
         for (auto id : embd) {
-            printf("%s", vocab.id_to_token[id].c_str());
+            output += vocab.id_to_token[id];
+            if (stream_response_to_cout) {
+                printf("%s", vocab.id_to_token[id].c_str());
+            }
+        }
+        if (stream_response_to_cout) {
+            fflush(stdout);
         }
-        fflush(stdout);
 
         // end of text token
        if (embd.back() == 0 || (end_token > 0 && embd.back() == end_token)) {
-            break;
+            return output;
         }
     }
 
+    return output;
+}
+
+int setup_port(const int port) {
+    int sockfd = socket(AF_INET, SOCK_STREAM, 0);
+    if (sockfd < 0) {
+        std::cerr << "Failed to create socket\n";
+        return -1;
+    }
+
+    sockaddr_in servaddr;
+    std::memset(&servaddr, 0, sizeof(servaddr));
+
+    servaddr.sin_family = AF_INET;
+    servaddr.sin_addr.s_addr = htonl(INADDR_ANY);
+    servaddr.sin_port = htons(port);
+
+    if (bind(sockfd, (struct sockaddr *)&servaddr, sizeof(servaddr)) < 0) {
+        std::cerr << "Failed to bind to port\n";
+        return -1;
+    }
+
+    if (listen(sockfd, 10) < 0) {
+        std::cerr << "Failed to listen on socket\n";
+        return -1;
+    }
+    return sockfd;
+}
+
+std::string read_from_port(int sockfd, int clientfd) {
+    if (clientfd < 0) {
+        std::cerr << "Failed to accept new connection\n";
+        return "";
+    }
+
+    char buffer[4096];
+    std::memset(buffer, 0, sizeof(buffer));
+
+    if (read(clientfd, buffer, sizeof(buffer)) < 0) {
+        std::cerr << "Failed to read from client\n";
+    } else {
+        std::cout << "Received: " << buffer;
+        return std::string(buffer);
+    }
+    return std::string("");
+}
+
+int main(int argc, char ** argv) {
+    ggml_time_init();
+
+    const int64_t t_main_start_us = ggml_time_us();
+
+    gpt_params params;
+    params.model = "models/dolly-v2-3b/ggml-model-f16.bin";
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    if (params.seed < 0) {
+        params.seed = time(NULL);
+    }
+
+    printf("%s: seed = %d\n", __func__, params.seed);
+
+    std::mt19937 rng(params.seed);
+
+    int64_t t_load_us = 0;
+    int64_t t_sample_us = 0;
+    int64_t t_predict_us = 0;
+
+    // determine the required inference memory per token:
+    size_t mem_per_token = 0;
+
+    int n_past = 0;
+
+    gpt_vocab vocab;
+    dollyv2_model model;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!dollyv2_model_load(params.model, model, vocab)) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+
+        t_load_us = ggml_time_us() - t_start_us;
+
+        test_gpt_tokenizer(vocab, params.token_test);
+    }
+
+    int sockfd;
+    if (params.interactive_port != -1) {
+        sockfd = setup_port(params.interactive_port);
+        if (sockfd == -1) {
+            return 1;
+        }
+        fprintf(stdout, "Model is ready on port %i\n", params.interactive_port);
+        fflush(stdout);
+    }
+
+    if (params.interactive or params.interactive_port != -1) {
+        while (true) {
+            std::string prompt_input;
+            int clientfd;
+            if (params.interactive_port != -1) {
+                sockaddr_in clientaddr;
+                socklen_t clientaddrlen = sizeof(clientaddr);
+                clientfd = accept(sockfd, (struct sockaddr *)&clientaddr, &clientaddrlen);
+                prompt_input = read_from_port(sockfd, clientfd);
+            } else {
+                printf("Please enter your question:\n>");
+                fflush(stdout);
+
+                std::getline(std::cin, prompt_input);
+            }
+
+            if (strcmp(prompt_input.c_str(), "exit") == 0) {
+                break;
+            }
+
+            const std::string prompt = prompt_for_generation(prompt_input);
+            // call the model
+            const std::string response = execute_prompt(model, vocab, prompt, params, rng, t_load_us, t_sample_us, t_predict_us, mem_per_token, n_past, true);
+
+            if (params.interactive_port != -1) {
+                if (write(clientfd, response.c_str(), response.size()) < 0) {
+                    std::cerr << "Failed to write to client\n";
+                }
+
+                if (close(clientfd) < 0) {
+                    std::cerr << "Failed to close client socket\n";
+                }
+            }
+            else {
+                printf("%s\n\n", response.c_str());
+            }
+            fflush(stdout);
+        }
+    } else {
+        if (params.prompt.empty()) {
+            params.prompt = gpt_random_prompt(rng);
+        }
+
+        const std::string prompt = prompt_for_generation(params.prompt);
+        execute_prompt(model, vocab, prompt, params, rng, t_load_us, t_sample_us, t_predict_us, mem_per_token, n_past, true);
+    }
 
     // report timing
     {
@@ -807,13 +928,17 @@ int main(int argc, char ** argv) {
 
         printf("\n\n");
         printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
-        printf("%s:     load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
-        printf("%s:   sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
-        printf("%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
-        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+        printf("%s:     load time = %8.2f ms\n", __func__, t_load_us / 1000.0f);
+        printf("%s:   sample time = %8.2f ms\n", __func__, t_sample_us / 1000.0f);
+        printf("%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us / 1000.0f, t_predict_us / 1000.0f / n_past);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f);
     }
 
     ggml_free(model.ctx);
 
+    if (params.interactive_port != -1 && close(sockfd) < 0) {
+        std::cerr << "Failed to close server socket\n";
+    }
+
     return 0;
 }
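
The port mode in this patch is a one-shot protocol per TCP connection: the server accepts a connection, reads a single buffer as the raw prompt, wraps it with prompt_for_generation() and runs execute_prompt(), writes the whole response back, and closes the connection; receiving the exact string "exit" (no trailing newline) ends the loop instead. Below is a minimal client sketch for exercising that path, assuming the example was started with the new flag, e.g. ./bin/dolly-v2 -m models/dolly-v2-3b/ggml-model-f16.bin -ip 8080; the binary path, the port 8080, the loopback address, and the prompt text are illustrative assumptions, not part of the patch.

    // client.cpp: minimal sketch of a client for the interactive port mode.
    // Assumptions (not part of the patch): the server runs on the same host
    // and was started with "-ip 8080"; the prompt below is just an example.
    #include <arpa/inet.h>
    #include <netinet/in.h>
    #include <sys/socket.h>
    #include <unistd.h>

    #include <cstdio>
    #include <cstring>
    #include <string>

    int main() {
        const int port = 8080;                           // must match the -ip argument
        const std::string prompt = "Who was John Nash?"; // sent as the raw prompt text

        int sockfd = socket(AF_INET, SOCK_STREAM, 0);
        if (sockfd < 0) {
            perror("socket");
            return 1;
        }

        sockaddr_in servaddr;
        std::memset(&servaddr, 0, sizeof(servaddr));
        servaddr.sin_family = AF_INET;
        servaddr.sin_port = htons(port);
        inet_pton(AF_INET, "127.0.0.1", &servaddr.sin_addr);

        if (connect(sockfd, (struct sockaddr *)&servaddr, sizeof(servaddr)) < 0) {
            perror("connect");
            close(sockfd);
            return 1;
        }

        // send the prompt; the server wraps it with prompt_for_generation() itself
        if (write(sockfd, prompt.c_str(), prompt.size()) < 0) {
            perror("write");
            close(sockfd);
            return 1;
        }

        // the server writes the full response and then closes the connection,
        // so read until EOF
        char buffer[4096];
        ssize_t n;
        while ((n = read(sockfd, buffer, sizeof(buffer))) > 0) {
            fwrite(buffer, 1, (size_t) n, stdout);
        }
        printf("\n");

        close(sockfd);
        return 0;
    }

Since the wire format is just raw bytes over TCP, a netcat one-liner such as printf 'Who was John Nash?' | nc 127.0.0.1 8080 should behave the same way, subject to how the installed netcat variant handles waiting for the response after stdin closes; the server handles one connection at a time and blocks while generating.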