From 035f0f425e4524a66f34fa886dcc9c45235539c4 Mon Sep 17 00:00:00 2001 From: Louis Date: Thu, 22 Feb 2024 21:58:10 +0700 Subject: [PATCH] refactor: whisper extends base class --- common/base.h | 67 ++- controllers/{whisperCPP.cc => audio.cc} | 169 +++--- controllers/{whisperCPP.h => audio.h} | 136 ++--- controllers/llamaCPP.cc | 128 ++--- controllers/llamaCPP.h | 654 +++++++++++++----------- 5 files changed, 627 insertions(+), 527 deletions(-) rename controllers/{whisperCPP.cc => audio.cc} (88%) rename controllers/{whisperCPP.h => audio.h} (51%) diff --git a/common/base.h b/common/base.h index 4813592fd..0d82d2ef7 100644 --- a/common/base.h +++ b/common/base.h @@ -3,35 +3,56 @@ using namespace drogon; -class BaseProvider { -public: - virtual ~BaseProvider() {} - - // General inference method - virtual void - inference(const HttpRequestPtr &req, - std::function &&callback) = 0; +class BaseModel { + public: + virtual ~BaseModel() {} // Model management - virtual void - loadModel(const HttpRequestPtr &req, - std::function &&callback) = 0; - virtual void - unloadModel(const HttpRequestPtr &req, - std::function &&callback) = 0; - virtual void - modelStatus(const HttpRequestPtr &req, - std::function &&callback) = 0; + virtual void LoadModel( + const HttpRequestPtr &req, + std::function &&callback) = 0; + virtual void UnloadModel( + const HttpRequestPtr &req, + std::function &&callback) = 0; + virtual void ModelStatus( + const HttpRequestPtr &req, + std::function &&callback) = 0; +}; + +class BaseChatCompletion { + public: + virtual ~BaseChatCompletion() {} + + // General chat method + virtual void ChatCompletion( + const HttpRequestPtr &req, + std::function &&callback) = 0; }; -class ChatProvider : public BaseProvider { -public: - virtual ~ChatProvider() {} +class BaseEmbedding { + public: + virtual ~BaseEmbedding() {} // Implement embedding functionality specific to chat - virtual void - embedding(const HttpRequestPtr &req, - std::function &&callback) = 0; + virtual void Embedding( + const HttpRequestPtr &req, + std::function &&callback) = 0; // The derived class can also override other methods if needed }; + +class BaseAudio { + public: + virtual ~BaseAudio() {} + // Transcribes audio into the input language. + virtual void CreateTranscription( + const HttpRequestPtr &req, + std::function &&callback) = 0; + + // Translates audio into the input language. + virtual void CreateTranslation( + const HttpRequestPtr &req, + std::function &&callback) = 0; + + // The derived class can also override other methods if needed +}; \ No newline at end of file diff --git a/controllers/whisperCPP.cc b/controllers/audio.cc similarity index 88% rename from controllers/whisperCPP.cc rename to controllers/audio.cc index a2039f396..d83867565 100644 --- a/controllers/whisperCPP.cc +++ b/controllers/audio.cc @@ -1,11 +1,12 @@ -#include "whisperCPP.h" +#include "audio.h" // #include "whisper.h" // #include "llama.h" +using namespace v1; -bool read_wav(const std::string &fname, std::vector &pcmf32, - std::vector> &pcmf32s, bool stereo) { +bool read_wav(const std::string& fname, std::vector& pcmf32, + std::vector>& pcmf32s, bool stereo) { drwav wav; - std::vector wav_data; // used for pipe input from stdin + std::vector wav_data; // used for pipe input from stdin if (fname == "-") { { @@ -93,13 +94,13 @@ bool read_wav(const std::string &fname, std::vector &pcmf32, return true; } -std::string output_str(struct whisper_context *ctx, - const whisper_params ¶ms, +std::string output_str(struct whisper_context* ctx, + const whisper_params& params, std::vector> pcmf32s) { std::stringstream result; const int n_segments = whisper_full_n_segments(ctx); for (int i = 0; i < n_segments; ++i) { - const char *text = whisper_full_get_segment_text(ctx, i); + const char* text = whisper_full_get_segment_text(ctx, i); std::string speaker = ""; if (params.diarize && pcmf32s.size() == 2) { @@ -113,9 +114,9 @@ std::string output_str(struct whisper_context *ctx, return result.str(); } -std::string -estimate_diarization_speaker(std::vector> pcmf32s, - int64_t t0, int64_t t1, bool id_only) { +std::string estimate_diarization_speaker( + std::vector> pcmf32s, int64_t t0, int64_t t1, + bool id_only) { std::string speaker = ""; const int64_t n_samples = pcmf32s[0].size(); @@ -172,19 +173,20 @@ int timestamp_to_sample(int64_t t, int n_samples) { (int)((t * WHISPER_SAMPLE_RATE) / 100))); } -bool is_file_exist(const char *fileName) { +bool is_file_exist(const char* fileName) { std::ifstream infile(fileName); return infile.good(); } -void whisper_print_usage(int /*argc*/, char **argv, - const whisper_params ¶ms) { +void whisper_print_usage(int /*argc*/, char** argv, + const whisper_params& params) { fprintf(stderr, "\n"); fprintf(stderr, "usage: %s [options] \n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help [default] show this help " - "message and exit\n"); + fprintf(stderr, + " -h, --help [default] show this help " + "message and exit\n"); fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use " "during computation\n", @@ -292,7 +294,7 @@ void whisper_print_usage(int /*argc*/, char **argv, fprintf(stderr, "\n"); } -bool whisper_params_parse(int argc, char **argv, whisper_params ¶ms) { +bool whisper_params_parse(int argc, char** argv, whisper_params& params) { for (int i = 1; i < argc; i++) { std::string arg = argv[i]; @@ -387,7 +389,7 @@ void check_ffmpeg_availibility() { } } -bool convert_to_wav(const std::string &temp_filename, std::string &error_resp) { +bool convert_to_wav(const std::string& temp_filename, std::string& error_resp) { std::ostringstream cmd_stream; std::string converted_filename_temp = temp_filename + "_temp.wav"; cmd_stream << "ffmpeg -i \"" << temp_filename @@ -415,23 +417,23 @@ bool convert_to_wav(const std::string &temp_filename, std::string &error_resp) { return true; } -void whisper_print_progress_callback(struct whisper_context * /*ctx*/, - struct whisper_state * /*state*/, - int progress, void *user_data) { +void whisper_print_progress_callback(struct whisper_context* /*ctx*/, + struct whisper_state* /*state*/, + int progress, void* user_data) { int progress_step = - ((whisper_print_user_data *)user_data)->params->progress_step; - int *progress_prev = &(((whisper_print_user_data *)user_data)->progress_prev); + ((whisper_print_user_data*)user_data)->params->progress_step; + int* progress_prev = &(((whisper_print_user_data*)user_data)->progress_prev); if (progress >= *progress_prev + progress_step) { *progress_prev += progress_step; fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress); } } -void whisper_print_segment_callback(struct whisper_context *ctx, - struct whisper_state * /*state*/, int n_new, - void *user_data) { - const auto ¶ms = *((whisper_print_user_data *)user_data)->params; - const auto &pcmf32s = *((whisper_print_user_data *)user_data)->pcmf32s; +void whisper_print_segment_callback(struct whisper_context* ctx, + struct whisper_state* /*state*/, int n_new, + void* user_data) { + const auto& params = *((whisper_print_user_data*)user_data)->params; + const auto& pcmf32s = *((whisper_print_user_data*)user_data)->pcmf32s; const int n_segments = whisper_full_n_segments(ctx); @@ -471,7 +473,7 @@ void whisper_print_segment_callback(struct whisper_context *ctx, } } - const char *text = whisper_full_get_token_text(ctx, i, j); + const char* text = whisper_full_get_token_text(ctx, i, j); const float p = whisper_full_get_token_p(ctx, i, j); const int col = (std::max)( @@ -482,7 +484,7 @@ void whisper_print_segment_callback(struct whisper_context *ctx, "\033[0m"); } } else { - const char *text = whisper_full_get_segment_text(ctx, i); + const char* text = whisper_full_get_segment_text(ctx, i); printf("%s%s", speaker.c_str(), text); } @@ -501,14 +503,14 @@ void whisper_print_segment_callback(struct whisper_context *ctx, } } -bool parse_str_to_bool(const std::string &s) { +bool parse_str_to_bool(const std::string& s) { if (s == "true" || s == "1" || s == "yes" || s == "y") { return true; } return false; } -bool whisper_server_context::load_model(std::string &model_path) { +bool whisper_server_context::load_model(std::string& model_path) { whisper_mutex.lock(); // clean up @@ -534,14 +536,14 @@ bool whisper_server_context::load_model(std::string &model_path) { } std::string whisper_server_context::inference( - std::string &input_file_path, std::string language, std::string prompt, + std::string& input_file_path, std::string language, std::string prompt, std::string response_format, float temperature, bool translate) { // acquire whisper model mutex lock whisper_mutex.lock(); // audio arrays - std::vector pcmf32; // mono-channel F32 PCM - std::vector> pcmf32s; // stereo-channel F32 PCM + std::vector pcmf32; // mono-channel F32 PCM + std::vector> pcmf32s; // stereo-channel F32 PCM // if file is not wav, convert to wav if (params.ffmpeg_converter) { @@ -625,7 +627,7 @@ std::string whisper_server_context::inference( wparams.speed_up = params.speed_up; wparams.debug_mode = params.debug_mode; - wparams.tdrz_enable = params.tinydiarize; // [TDRZ] + wparams.tdrz_enable = params.tinydiarize; // [TDRZ] wparams.initial_prompt = prompt.c_str(); @@ -660,12 +662,12 @@ std::string whisper_server_context::inference( // the processing is aborted { static bool is_aborted = - false; // NOTE: this should be atomic to avoid data race + false; // NOTE: this should be atomic to avoid data race - wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, - struct whisper_state * /*state*/, - void *user_data) { - bool is_aborted = *(bool *)user_data; + wparams.encoder_begin_callback = [](struct whisper_context* /*ctx*/, + struct whisper_state* /*state*/, + void* user_data) { + bool is_aborted = *(bool*)user_data; return !is_aborted; }; wparams.encoder_begin_callback_user_data = &is_aborted; @@ -675,10 +677,10 @@ std::string whisper_server_context::inference( // computation is aborted { static bool is_aborted = - false; // NOTE: this should be atomic to avoid data race + false; // NOTE: this should be atomic to avoid data race - wparams.abort_callback = [](void *user_data) { - bool is_aborted = *(bool *)user_data; + wparams.abort_callback = [](void* user_data) { + bool is_aborted = *(bool*)user_data; return is_aborted; }; wparams.abort_callback_user_data = &is_aborted; @@ -701,7 +703,7 @@ std::string whisper_server_context::inference( std::stringstream ss; const int n_segments = whisper_full_n_segments(ctx); for (int i = 0; i < n_segments; ++i) { - const char *text = whisper_full_get_segment_text(ctx, i); + const char* text = whisper_full_get_segment_text(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); std::string speaker = ""; @@ -722,7 +724,7 @@ std::string whisper_server_context::inference( const int n_segments = whisper_full_n_segments(ctx); for (int i = 0; i < n_segments; ++i) { - const char *text = whisper_full_get_segment_text(ctx, i); + const char* text = whisper_full_get_segment_text(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); std::string speaker = ""; @@ -796,9 +798,15 @@ whisper_server_context::~whisper_server_context() { } } -std::optional whisperCPP::parse_model_id( - const std::shared_ptr &jsonBody, - const std::function &callback) { +audio::audio() { + whisper_print_system_info(); +}; + +audio::~audio() {} + +std::optional audio::ParseModelId( + const std::shared_ptr& jsonBody, + const std::function& callback) { if (!jsonBody->isMember("model_id")) { LOG_INFO << "No model_id found in request body"; Json::Value jsonResp; @@ -806,17 +814,16 @@ std::optional whisperCPP::parse_model_id( auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); resp->setStatusCode(k400BadRequest); callback(resp); - return std::nullopt; // Signal that an error occurred + return std::nullopt; // Signal that an error occurred } return (*jsonBody)["model_id"].asString(); } -void whisperCPP::load_model( - const HttpRequestPtr &req, - std::function &&callback) { +void audio::LoadModel(const HttpRequestPtr& req, + std::function&& callback) { const auto jsonBody = req->getJsonObject(); - auto optional_model_id = parse_model_id(jsonBody, callback); + auto optional_model_id = ParseModelId(jsonBody, callback); if (!optional_model_id) { return; } @@ -905,11 +912,11 @@ void whisperCPP::load_model( return; } -void whisperCPP::unload_model( - const HttpRequestPtr &req, - std::function &&callback) { - const auto &jsonBody = req->getJsonObject(); - auto optional_model_id = parse_model_id(jsonBody, callback); +void audio::UnloadModel( + const HttpRequestPtr& req, + std::function&& callback) { + const auto& jsonBody = req->getJsonObject(); + auto optional_model_id = ParseModelId(jsonBody, callback); if (!optional_model_id) { return; } @@ -944,13 +951,12 @@ void whisperCPP::unload_model( return; } -void whisperCPP::list_model( - const HttpRequestPtr &req, - std::function &&callback) { +void audio::ListModels(const HttpRequestPtr& req, + std::function&& callback) { // Return a list of all loaded models Json::Value jsonResp; Json::Value models; - for (auto const &model : whispers) { + for (auto const& model : whispers) { models.append(model.first); } jsonResp["models"] = models; @@ -960,9 +966,9 @@ void whisperCPP::list_model( return; } -void whisperCPP::transcription_impl( - const HttpRequestPtr &req, - std::function &&callback, bool translate) { +void audio::TranscriptionImpl( + const HttpRequestPtr& req, + std::function&& callback, bool translate) { MultiPartParser partParser; Json::Value jsonResp; if (partParser.parse(req) != 0 || partParser.getFiles().size() != 1) { @@ -972,8 +978,8 @@ void whisperCPP::transcription_impl( callback(resp); return; } - auto &file = partParser.getFiles()[0]; - const auto &formFields = partParser.getParameters(); + auto& file = partParser.getFiles()[0]; + const auto& formFields = partParser.getParameters(); // Check if model_id are present in the request. If not, return a 400 error if (formFields.find("model_id") == formFields.end()) { @@ -1035,7 +1041,7 @@ void whisperCPP::transcription_impl( result = whispers[model_id].inference(temp_file_path, language, prompt, response_format, temperature, translate); - } catch (const std::exception &e) { + } catch (const std::exception& e) { std::remove(temp_file_path.c_str()); Json::Value jsonResp; jsonResp["message"] = e.what(); @@ -1064,14 +1070,25 @@ void whisperCPP::transcription_impl( return; } -void whisperCPP::transcription( - const HttpRequestPtr &req, - std::function &&callback) { - return transcription_impl(req, std::move(callback), false); +// TODO: Unimplemented +void audio::ModelStatus( + const HttpRequestPtr& req, + std::function&& callback) { + Json::Value jsonResp; + jsonResp["message"] = "Unimplemented"; + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setStatusCode(k404NotFound); + callback(resp); +} + +void audio::CreateTranscription( + const HttpRequestPtr& req, + std::function&& callback) { + return TranscriptionImpl(req, std::move(callback), false); } -void whisperCPP::translation( - const HttpRequestPtr &req, - std::function &&callback) { - return transcription_impl(req, std::move(callback), true); +void audio::CreateTranslation( + const HttpRequestPtr& req, + std::function&& callback) { + return TranscriptionImpl(req, std::move(callback), true); } \ No newline at end of file diff --git a/controllers/whisperCPP.h b/controllers/audio.h similarity index 51% rename from controllers/whisperCPP.h rename to controllers/audio.h index 77b4b3898..fb4f5730e 100644 --- a/controllers/whisperCPP.h +++ b/controllers/audio.h @@ -1,10 +1,11 @@ #pragma once -#include "whisper.h" #include +#include #include #include -#include +#include "common/base.h" +#include "whisper.h" #define DR_WAV_IMPLEMENTATION #include "utils/dr_wav.h" @@ -74,15 +75,15 @@ struct whisper_params { // [TDRZ] speaker turn string std::string tdrz_speaker_turn = - " [SPEAKER_TURN]"; // TODO: set from command line + " [SPEAKER_TURN]"; // TODO: set from command line std::string openvino_encode_device = "CPU"; }; struct whisper_print_user_data { - const whisper_params *params; + const whisper_params* params; - const std::vector> *pcmf32s; + const std::vector>* pcmf32s; int progress_prev; }; @@ -92,16 +93,16 @@ struct whisper_print_user_data { // The sample rate of the audio must be equal to COMMON_SAMPLE_RATE // If stereo flag is set and the audio has 2 channels, the pcmf32s will contain // 2 channel PCM -bool read_wav(const std::string &fname, std::vector &pcmf32, - std::vector> &pcmf32s, bool stereo); +bool read_wav(const std::string& fname, std::vector& pcmf32, + std::vector>& pcmf32s, bool stereo); -std::string output_str(struct whisper_context *ctx, - const whisper_params ¶ms, +std::string output_str(struct whisper_context* ctx, + const whisper_params& params, std::vector> pcmf32s); -std::string -estimate_diarization_speaker(std::vector> pcmf32s, - int64_t t0, int64_t t1, bool id_only = false); +std::string estimate_diarization_speaker( + std::vector> pcmf32s, int64_t t0, int64_t t1, + bool id_only = false); // 500 -> 00:05.000 // 6000 -> 01:00.000 @@ -109,26 +110,26 @@ std::string to_timestamp(int64_t t, bool comma = false); int timestamp_to_sample(int64_t t, int n_samples); -bool is_file_exist(const char *fileName); +bool is_file_exist(const char* fileName); -void whisper_print_usage(int /*argc*/, char **argv, - const whisper_params ¶ms); +void whisper_print_usage(int /*argc*/, char** argv, + const whisper_params& params); -bool whisper_params_parse(int argc, char **argv, whisper_params ¶ms); +bool whisper_params_parse(int argc, char** argv, whisper_params& params); void check_ffmpeg_availibility(); -bool convert_to_wav(const std::string &temp_filename, std::string &error_resp); +bool convert_to_wav(const std::string& temp_filename, std::string& error_resp); -void whisper_print_progress_callback(struct whisper_context * /*ctx*/, - struct whisper_state * /*state*/, - int progress, void *user_data); +void whisper_print_progress_callback(struct whisper_context* /*ctx*/, + struct whisper_state* /*state*/, + int progress, void* user_data); -void whisper_print_segment_callback(struct whisper_context *ctx, - struct whisper_state * /*state*/, int n_new, - void *user_data); +void whisper_print_segment_callback(struct whisper_context* ctx, + struct whisper_state* /*state*/, int n_new, + void* user_data); -bool parse_str_to_bool(const std::string &s); +bool parse_str_to_bool(const std::string& s); struct whisper_server_context { whisper_params params; @@ -137,12 +138,12 @@ struct whisper_server_context { std::string model_id; struct whisper_context_params cparams; - struct whisper_context *ctx = nullptr; + struct whisper_context* ctx = nullptr; - whisper_server_context() = default; // add this line + whisper_server_context() = default; // add this line // Constructor - whisper_server_context(const std::string &model_id) { + whisper_server_context(const std::string& model_id) { this->model_id = model_id; this->cparams = whisper_context_params(); this->ctx = nullptr; @@ -152,20 +153,21 @@ struct whisper_server_context { } // Move constructor - whisper_server_context(whisper_server_context &&other) noexcept + whisper_server_context(whisper_server_context&& other) noexcept : params(std::move(other.params)), default_params(std::move(other.default_params)), - whisper_mutex() // std::mutex is not movable, so we initialize a new one + whisper_mutex() // std::mutex is not movable, so we initialize a new one , - model_id(std::move(other.model_id)), cparams(std::move(other.cparams)), + model_id(std::move(other.model_id)), + cparams(std::move(other.cparams)), ctx(std::exchange( other.ctx, - nullptr)) // ctx is a raw pointer, so we use std::exchange + nullptr)) // ctx is a raw pointer, so we use std::exchange {} - bool load_model(std::string &model_path); + bool load_model(std::string& model_path); - std::string inference(std::string &input_file_path, std::string languague, + std::string inference(std::string& input_file_path, std::string languague, std::string prompt, std::string response_format, float temperature, bool translate); @@ -174,45 +176,59 @@ struct whisper_server_context { using namespace drogon; -class whisperCPP : public drogon::HttpController { -public: +namespace v1 { +class audio : public drogon::HttpController