From 78966f4fd6ced3539a48217f9b4cc0d2f66a4cd5 Mon Sep 17 00:00:00 2001 From: hiro Date: Sat, 16 Dec 2023 12:52:12 +0700 Subject: [PATCH 01/31] chore: init gitsubmodule whisper.cpp --- .gitmodules | 3 +++ whisper.cpp | 1 + 2 files changed, 4 insertions(+) create mode 160000 whisper.cpp diff --git a/.gitmodules b/.gitmodules index a10a6776d..e2f71d456 100644 --- a/.gitmodules +++ b/.gitmodules @@ -2,3 +2,6 @@ path = llama.cpp url = https://github.com/ggerganov/llama.cpp branch = master +[submodule "whisper.cpp"] + path = whisper.cpp + url = https://github.com/ggerganov/whisper.cpp.git diff --git a/whisper.cpp b/whisper.cpp new file mode 160000 index 000000000..940de9dbe --- /dev/null +++ b/whisper.cpp @@ -0,0 +1 @@ +Subproject commit 940de9dbe9c90624dc99521cb34c8a97b86d543c From e4713f920b05694e42a5c117f9697257afd0d4ef Mon Sep 17 00:00:00 2001 From: hiro Date: Sat, 16 Dec 2023 12:52:22 +0700 Subject: [PATCH 02/31] chore: Add whisper.cpp to cmake file --- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 39c988e09..220fa9720 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -42,6 +42,7 @@ if(DEBUG) endif() add_subdirectory(llama.cpp) +add_subdirectory(whisper.cpp) add_executable(${PROJECT_NAME} main.cc) # ############################################################################## From 72e6ff7f0cff1ca90dde0233b2221115cf0e5b90 Mon Sep 17 00:00:00 2001 From: hiro Date: Sat, 16 Dec 2023 12:52:46 +0700 Subject: [PATCH 03/31] chore: Add setUpload path --- main.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/main.cc b/main.cc index a1c6187fd..46d5e92a5 100644 --- a/main.cc +++ b/main.cc @@ -38,8 +38,9 @@ int main(int argc, char *argv[]) { nitro_utils::nitro_logo(); LOG_INFO << "Server started, listening at: " << host << ":" << port; LOG_INFO << "Please load your model"; - drogon::app().addListener(host, port); - drogon::app().setThreadNum(thread_num + 1); + drogon::app().addListener(host, 
port) + .setThreadNum(thread_num + 1) + .setUploadPath("./uploads"); LOG_INFO << "Number of thread is:" << drogon::app().getThreadNum(); drogon::app().run(); From 3dd606ca1097ff2a2ec6c46097325882826f6365 Mon Sep 17 00:00:00 2001 From: hiro Date: Sat, 16 Dec 2023 12:53:09 +0700 Subject: [PATCH 04/31] feat: Add transcription function --- controllers/llamaCPP.cc | 228 ++++++++++++++++++++++++++++------------ 1 file changed, 163 insertions(+), 65 deletions(-) diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc index b76140a25..6bf63b036 100644 --- a/controllers/llamaCPP.cc +++ b/controllers/llamaCPP.cc @@ -14,7 +14,8 @@ using namespace inferences; using json = nlohmann::json; // To store state of each inference request -struct State { +struct State +{ bool isStopped = false; int task_id; llamaCPP *instance; @@ -22,14 +23,16 @@ struct State { State(int tid, llamaCPP *inst) : task_id(tid), instance(inst) {} }; -std::shared_ptr createState(int task_id, llamaCPP *instance) { +std::shared_ptr createState(int task_id, llamaCPP *instance) +{ return std::make_shared(task_id, instance); } // -------------------------------------------- std::string create_embedding_payload(const std::vector &embedding, - int prompt_tokens) { + int prompt_tokens) +{ Json::Value root; root["object"] = "list"; @@ -40,7 +43,8 @@ std::string create_embedding_payload(const std::vector &embedding, dataItem["object"] = "embedding"; Json::Value embeddingArray(Json::arrayValue); - for (const auto &value : embedding) { + for (const auto &value : embedding) + { embeddingArray.append(value); } dataItem["embedding"] = embeddingArray; @@ -67,7 +71,8 @@ std::string create_full_return_json(const std::string &id, const std::string &content, const std::string &system_fingerprint, int prompt_tokens, int completion_tokens, - Json::Value finish_reason = Json::Value()) { + Json::Value finish_reason = Json::Value()) +{ Json::Value root; @@ -103,7 +108,8 @@ std::string create_full_return_json(const 
std::string &id, std::string create_return_json(const std::string &id, const std::string &model, const std::string &content, - Json::Value finish_reason = Json::Value()) { + Json::Value finish_reason = Json::Value()) +{ Json::Value root; @@ -130,7 +136,8 @@ std::string create_return_json(const std::string &id, const std::string &model, return Json::writeString(writer, root); } -void llamaCPP::warmupModel() { +void llamaCPP::warmupModel() +{ json pseudo; pseudo["prompt"] = "Hello"; @@ -139,7 +146,8 @@ void llamaCPP::warmupModel() { const int task_id = llama.request_completion(pseudo, false, false); std::string completion_text; task_result result = llama.next_result(task_id); - if (!result.error && result.stop) { + if (!result.error && result.stop) + { LOG_INFO << result.result_json.dump(-1, ' ', false, json::error_handler_t::replace); } @@ -148,7 +156,8 @@ void llamaCPP::warmupModel() { void llamaCPP::chatCompletionPrelight( const HttpRequestPtr &req, - std::function &&callback) { + std::function &&callback) +{ auto resp = drogon::HttpResponse::newHttpResponse(); resp->setStatusCode(drogon::HttpStatusCode::k200OK); resp->addHeader("Access-Control-Allow-Origin", "*"); @@ -159,9 +168,11 @@ void llamaCPP::chatCompletionPrelight( void llamaCPP::chatCompletion( const HttpRequestPtr &req, - std::function &&callback) { + std::function &&callback) +{ - if (!model_loaded) { + if (!model_loaded) + { Json::Value jsonResp; jsonResp["message"] = "Model has not been loaded, please load model into nitro"; @@ -178,10 +189,12 @@ void llamaCPP::chatCompletion( int no_images = 0; // To set default value - if (jsonBody) { + if (jsonBody) + { // Increase number of chats received and clean the prompt no_of_chats++; - if (no_of_chats % clean_cache_threshold == 0) { + if (no_of_chats % clean_cache_threshold == 0) + { LOG_INFO << "Clean cache threshold reached!"; llama.kv_cache_clear(); LOG_INFO << "Cache cleaned"; @@ -203,47 +216,63 @@ void llamaCPP::chatCompletion( 
data["presence_penalty"] = (*jsonBody).get("presence_penalty", 0).asFloat(); const Json::Value &messages = (*jsonBody)["messages"]; - if (!llama.multimodal) { + if (!llama.multimodal) + { - for (const auto &message : messages) { + for (const auto &message : messages) + { std::string input_role = message["role"].asString(); std::string role; - if (input_role == "user") { + if (input_role == "user") + { role = user_prompt; std::string content = message["content"].asString(); formatted_output += role + content; - } else if (input_role == "assistant") { + } + else if (input_role == "assistant") + { role = ai_prompt; std::string content = message["content"].asString(); formatted_output += role + content; - } else if (input_role == "system") { + } + else if (input_role == "system") + { role = system_prompt; std::string content = message["content"].asString(); formatted_output = role + content + formatted_output; - - } else { + } + else + { role = input_role; std::string content = message["content"].asString(); formatted_output += role + content; } } formatted_output += ai_prompt; - } else { + } + else + { data["image_data"] = json::array(); - for (const auto &message : messages) { + for (const auto &message : messages) + { std::string input_role = message["role"].asString(); std::string role; - if (input_role == "user") { + if (input_role == "user") + { formatted_output += role; - for (auto content_piece : message["content"]) { + for (auto content_piece : message["content"]) + { role = user_prompt; auto content_piece_type = content_piece["type"].asString(); - if (content_piece_type == "text") { + if (content_piece_type == "text") + { auto text = content_piece["text"].asString(); formatted_output += text; - } else if (content_piece_type == "image_url") { + } + else if (content_piece_type == "image_url") + { auto image_url = content_piece["image_url"]["url"].asString(); auto base64_image_data = nitro_utils::extractBase64(image_url); LOG_INFO << base64_image_data; @@ 
-256,17 +285,21 @@ void llamaCPP::chatCompletion( no_images++; } } - - } else if (input_role == "assistant") { + } + else if (input_role == "assistant") + { role = ai_prompt; std::string content = message["content"].asString(); formatted_output += role + content; - } else if (input_role == "system") { + } + else if (input_role == "system") + { role = system_prompt; std::string content = message["content"].asString(); formatted_output = role + content + formatted_output; - - } else { + } + else + { role = input_role; std::string content = message["content"].asString(); formatted_output += role + content; @@ -277,7 +310,8 @@ void llamaCPP::chatCompletion( } data["prompt"] = formatted_output; - for (const auto &stop_word : (*jsonBody)["stop"]) { + for (const auto &stop_word : (*jsonBody)["stop"]) + { stopWords.push_back(stop_word.asString()); } // specify default stop words @@ -296,22 +330,27 @@ void llamaCPP::chatCompletion( const int task_id = llama.request_completion(data, false, false); LOG_INFO << "Resolved request for task_id:" << task_id; - if (is_streamed) { + if (is_streamed) + { auto state = createState(task_id, this); auto chunked_content_provider = - [state](char *pBuffer, std::size_t nBuffSize) -> std::size_t { - if (!pBuffer) { + [state](char *pBuffer, std::size_t nBuffSize) -> std::size_t + { + if (!pBuffer) + { LOG_INFO << "Connection closed or buffer is null. 
Reset context"; state->instance->llama.request_cancel(state->task_id); return 0; } - if (state->isStopped) { + if (state->isStopped) + { return 0; } task_result result = state->instance->llama.next_result(state->task_id); - if (!result.error) { + if (!result.error) + { const std::string to_send = result.result_json["content"]; const std::string str = "data: " + @@ -322,7 +361,8 @@ void llamaCPP::chatCompletion( std::size_t nRead = std::min(str.size(), nBuffSize); memcpy(pBuffer, str.data(), nRead); - if (result.stop) { + if (result.stop) + { const std::string str = "data: " + create_return_json(nitro_utils::generate_random_string(20), "_", @@ -338,7 +378,9 @@ void llamaCPP::chatCompletion( return nRead; } return nRead; - } else { + } + else + { return 0; } return 0; @@ -348,14 +390,18 @@ void llamaCPP::chatCompletion( callback(resp); return; - } else { + } + else + { Json::Value respData; auto resp = nitro_utils::nitroHttpResponse(); respData["testing"] = "thunghiem value moi"; - if (!json_value(data, "stream", false)) { + if (!json_value(data, "stream", false)) + { std::string completion_text; task_result result = llama.next_result(task_id); - if (!result.error && result.stop) { + if (!result.error && result.stop) + { int prompt_tokens = result.result_json["tokens_evaluated"]; int predicted_tokens = result.result_json["tokens_predicted"]; std::string full_return = @@ -363,7 +409,9 @@ void llamaCPP::chatCompletion( "_", result.result_json["content"], "_", prompt_tokens, predicted_tokens); resp->setBody(full_return); - } else { + } + else + { resp->setBody("internal error during inference"); return; } @@ -374,13 +422,17 @@ void llamaCPP::chatCompletion( } void llamaCPP::embedding( const HttpRequestPtr &req, - std::function &&callback) { + std::function &&callback) +{ const auto &jsonBody = req->getJsonObject(); json prompt; - if (jsonBody->isMember("input") != 0) { + if (jsonBody->isMember("input") != 0) + { prompt = (*jsonBody)["input"].asString(); - } else { + } + 
else + { prompt = ""; } const int task_id = llama.request_completion( @@ -395,12 +447,39 @@ void llamaCPP::embedding( return; } +void llamaCPP::transcription( + const HttpRequestPtr &req, + std::function &&callback) +{ + MultiPartParser fileUpload; + if (fileUpload.parse(req) != 0 || fileUpload.getFiles().size() != 1) + { + auto resp = HttpResponse::newHttpResponse(); + resp->setBody("Must only be one file"); + resp->setStatusCode(k403Forbidden); + callback(resp); + return; + } + auto &file = fileUpload.getFiles()[0]; + auto md5 = file.getMd5(); + auto resp = HttpResponse::newHttpResponse(); + resp->setBody( + "The server has calculated the file's MD5 hash to be " + md5); + file.save(); + LOG_INFO << "The uploaded file has been saved to the ./uploads " + "directory"; + callback(resp); + return; +} + void llamaCPP::unloadModel( const HttpRequestPtr &req, - std::function &&callback) { + std::function &&callback) +{ Json::Value jsonResp; jsonResp["message"] = "No model loaded"; - if (model_loaded) { + if (model_loaded) + { stopBackgroundTask(); llama_free(llama.ctx); @@ -415,13 +494,17 @@ void llamaCPP::unloadModel( } void llamaCPP::modelStatus( const HttpRequestPtr &req, - std::function &&callback) { + std::function &&callback) +{ Json::Value jsonResp; bool is_model_loaded = this->model_loaded; - if (is_model_loaded) { + if (is_model_loaded) + { jsonResp["model_loaded"] = is_model_loaded; jsonResp["model_data"] = llama.get_model_props().dump(); - } else { + } + else + { jsonResp["model_loaded"] = is_model_loaded; } @@ -430,15 +513,18 @@ void llamaCPP::modelStatus( return; } -bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) { +bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) +{ gpt_params params; // By default will setting based on number of handlers int drogon_thread = drogon::app().getThreadNum() - 1; LOG_INFO << "Drogon thread is:" << drogon_thread; - if (jsonBody) { - if (!jsonBody["mmproj"].isNull()) { + if (jsonBody) + { + if 
(!jsonBody["mmproj"].isNull()) + { LOG_INFO << "MMPROJ FILE detected, multi-model enabled!"; params.mmproj = jsonBody["mmproj"].asString(); } @@ -468,7 +554,8 @@ bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) { LOG_INFO << "Setting up GGML CUBLAS PARAMS"; params.mul_mat_q = false; #endif // GGML_USE_CUBLAS - if (params.model_alias == "unknown") { + if (params.model_alias == "unknown") + { params.model_alias = params.model; } @@ -484,7 +571,8 @@ bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) { }); // load the model - if (!llama.load_model(params)) { + if (!llama.load_model(params)) + { LOG_ERROR << "Error loading the model"; return false; // Indicate failure } @@ -498,9 +586,11 @@ bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) { void llamaCPP::loadModel( const HttpRequestPtr &req, - std::function &&callback) { + std::function &&callback) +{ - if (model_loaded) { + if (model_loaded) + { LOG_INFO << "model loaded"; Json::Value jsonResp; jsonResp["message"] = "Model already loaded"; @@ -511,14 +601,17 @@ void llamaCPP::loadModel( } const auto &jsonBody = req->getJsonObject(); - if (!loadModelImpl(*jsonBody)) { + if (!loadModelImpl(*jsonBody)) + { // Error occurred during model loading Json::Value jsonResp; jsonResp["message"] = "Failed to load model"; auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); resp->setStatusCode(drogon::k500InternalServerError); callback(resp); - } else { + } + else + { // Model loaded successfully Json::Value jsonResp; jsonResp["message"] = "Model loaded successfully"; @@ -527,8 +620,10 @@ void llamaCPP::loadModel( } } -void llamaCPP::backgroundTask() { - while (model_loaded) { +void llamaCPP::backgroundTask() +{ + while (model_loaded) + { // model_loaded = llama.update_slots(); } @@ -538,11 +633,14 @@ void llamaCPP::backgroundTask() { return; } -void llamaCPP::stopBackgroundTask() { - if (model_loaded) { +void llamaCPP::stopBackgroundTask() +{ + if (model_loaded) + { model_loaded = false; LOG_INFO << 
"changed to false"; - if (backgroundThread.joinable()) { + if (backgroundThread.joinable()) + { backgroundThread.join(); } } From aea075702ca4fe92a30dc808f2607b2d29d7dabd Mon Sep 17 00:00:00 2001 From: hiro Date: Sat, 16 Dec 2023 12:53:29 +0700 Subject: [PATCH 05/31] feat: Add transcription function declaration --- controllers/llamaCPP.h | 932 ++++++++++++++++++++++++++--------------- 1 file changed, 605 insertions(+), 327 deletions(-) diff --git a/controllers/llamaCPP.h b/controllers/llamaCPP.h index a7f8762b4..7d5c95670 100644 --- a/controllers/llamaCPP.h +++ b/controllers/llamaCPP.h @@ -45,7 +45,8 @@ using json = nlohmann::json; -struct server_params { +struct server_params +{ std::string hostname = "127.0.0.1"; std::string public_path = "examples/server/public"; int32_t port = 8080; @@ -58,19 +59,21 @@ static bool server_verbose = false; #if SERVER_VERBOSE != 1 #define LOG_VERBOSE(MSG, ...) #else -#define LOG_VERBOSE(MSG, ...) \ - do { \ - if (server_verbose) { \ - server_log("VERBOSE", __func__, __LINE__, MSG, __VA_ARGS__); \ - } \ +#define LOG_VERBOSE(MSG, ...) \ + do \ + { \ + if (server_verbose) \ + { \ + server_log("VERBOSE", __func__, __LINE__, MSG, __VA_ARGS__); \ + } \ } while (0) #endif -#define LOG_ERROR_LLAMA(MSG, ...) \ +#define LOG_ERROR_LLAMA(MSG, ...) \ server_log("ERROR", __func__, __LINE__, MSG, __VA_ARGS__) -#define LOG_WARNING_LLAMA(MSG, ...) \ +#define LOG_WARNING_LLAMA(MSG, ...) \ server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__) -#define LOG_INFO_LLAMA(MSG, ...) \ +#define LOG_INFO_LLAMA(MSG, ...) 
\ server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) // @@ -81,11 +84,13 @@ static const std::string base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" "abcdefghijklmnopqrstuvwxyz" "0123456789+/"; -static inline bool is_base64(uint8_t c) { +static inline bool is_base64(uint8_t c) +{ return (isalnum(c) || (c == '+') || (c == '/')); } -static std::vector base64_decode(std::string const &encoded_string) { +static std::vector base64_decode(std::string const &encoded_string) +{ int i = 0; int j = 0; int in_ = 0; @@ -98,11 +103,14 @@ static std::vector base64_decode(std::string const &encoded_string) { std::vector ret; while (in_len-- && (encoded_string[in_] != '=') && - is_base64(encoded_string[in_])) { + is_base64(encoded_string[in_])) + { char_array_4[i++] = encoded_string[in_]; in_++; - if (i == 4) { - for (i = 0; i < 4; i++) { + if (i == 4) + { + for (i = 0; i < 4; i++) + { char_array_4[i] = base64_chars.find(char_array_4[i]); } @@ -112,19 +120,23 @@ static std::vector base64_decode(std::string const &encoded_string) { ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - for (i = 0; (i < 3); i++) { + for (i = 0; (i < 3); i++) + { ret.push_back(char_array_3[i]); } i = 0; } } - if (i) { - for (j = i; j < 4; j++) { + if (i) + { + for (j = i; j < 4; j++) + { char_array_4[j] = 0; } - for (j = 0; j < 4; j++) { + for (j = 0; j < 4; j++) + { char_array_4[j] = base64_chars.find(char_array_4[j]); } @@ -134,7 +146,8 @@ static std::vector base64_decode(std::string const &encoded_string) { ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - for (j = 0; (j < i - 1); j++) { + for (j = 0; (j < i - 1); j++) + { ret.push_back(char_array_3[j]); } } @@ -146,9 +159,14 @@ static std::vector base64_decode(std::string const &encoded_string) { // parallel // -enum task_type { COMPLETION_TASK, CANCEL_TASK }; +enum task_type +{ + 
COMPLETION_TASK, + CANCEL_TASK +}; -struct task_server { +struct task_server +{ int id; int target_id; task_type type; @@ -157,7 +175,8 @@ struct task_server { bool embedding_mode = false; }; -struct task_result { +struct task_result +{ int id; bool stop; bool error; @@ -165,18 +184,21 @@ struct task_result { }; // TODO: can become bool if we can't find use of more states -enum slot_state { +enum slot_state +{ IDLE, PROCESSING, }; -enum slot_command { +enum slot_command +{ NONE, LOAD_PROMPT, RELEASE, }; -struct slot_params { +struct slot_params +{ bool stream = true; bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt @@ -191,7 +213,8 @@ struct slot_params { json input_suffix; }; -struct slot_image { +struct slot_image +{ int32_t id; bool request_encode_image = false; @@ -204,8 +227,10 @@ struct slot_image { }; // completion token output with probabilities -struct completion_token_output { - struct token_prob { +struct completion_token_output +{ + struct token_prob + { llama_token tok; float prob; }; @@ -216,31 +241,40 @@ struct completion_token_output { }; static size_t common_part(const std::vector &a, - const std::vector &b) { + const std::vector &b) +{ size_t i; - for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) { + for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) + { } return i; } -enum stop_type { +enum stop_type +{ STOP_FULL, STOP_PARTIAL, }; -static bool ends_with(const std::string &str, const std::string &suffix) { +static bool ends_with(const std::string &str, const std::string &suffix) +{ return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); } static size_t find_partial_stop_string(const std::string &stop, - const std::string &text) { - if (!text.empty() && !stop.empty()) { + const std::string &text) +{ + if (!text.empty() && !stop.empty()) + { const char text_last_char = text.back(); - for (int64_t char_index = stop.size() - 1; char_index >= 
0; char_index--) { - if (stop[char_index] == text_last_char) { + for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) + { + if (stop[char_index] == text_last_char) + { const std::string current_partial = stop.substr(0, char_index + 1); - if (ends_with(text, current_partial)) { + if (ends_with(text, current_partial)) + { return text.size() - char_index - 1; } } @@ -251,9 +285,11 @@ static size_t find_partial_stop_string(const std::string &stop, // TODO: reuse llama_detokenize template -static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) { +static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) +{ std::string ret; - for (; begin != end; ++begin) { + for (; begin != end; ++begin) + { ret += llama_token_to_piece(ctx, *begin); } return ret; @@ -261,14 +297,18 @@ static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) { static void server_log(const char *level, const char *function, int line, const char *message, - const nlohmann::ordered_json &extra) { + const nlohmann::ordered_json &extra) +{ nlohmann::ordered_json log{ - {"timestamp", time(nullptr)}, {"level", level}, - {"function", function}, {"line", line}, + {"timestamp", time(nullptr)}, + {"level", level}, + {"function", function}, + {"line", line}, {"message", message}, }; - if (!extra.empty()) { + if (!extra.empty()) + { log.merge_patch(extra); } @@ -279,11 +319,13 @@ static void server_log(const char *level, const char *function, int line, // format incomplete utf-8 multibyte character for output static std::string tokens_to_output_formatted_string(const llama_context *ctx, - const llama_token token) { + const llama_token token) +{ std::string out = token == -1 ? 
"" : llama_token_to_piece(ctx, token); // if the size is 1 and first bit is 1, meaning it's a partial character // (size > 1 meaning it's already a known token) - if (out.size() == 1 && (out[0] & 0x80) == 0x80) { + if (out.size() == 1 && (out[0] & 0x80) == 0x80) + { std::stringstream ss; ss << std::hex << (out[0] & 0xff); std::string res(ss.str()); @@ -295,11 +337,14 @@ static std::string tokens_to_output_formatted_string(const llama_context *ctx, // convert a vector of completion_token_output to json static json probs_vector_to_json(const llama_context *ctx, - const std::vector &probs) { + const std::vector &probs) +{ json out = json::array(); - for (const auto &prob : probs) { + for (const auto &prob : probs) + { json probs_for_token = json::array(); - for (const auto &p : prob.probs) { + for (const auto &p : prob.probs) + { std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok); probs_for_token.push_back(json{ {"tok_str", tok_str}, @@ -317,14 +362,16 @@ probs_vector_to_json(const llama_context *ctx, template static T json_value(const json &body, const std::string &key, - const T &default_value) { + const T &default_value) +{ // Fallback null to default value return body.contains(key) && !body.at(key).is_null() ? body.value(key, default_value) : default_value; } -struct llama_client_slot { +struct llama_client_slot +{ int id; int task_id = -1; @@ -380,7 +427,8 @@ struct llama_client_slot { double t_prompt_processing; // ms double t_token_generation; // ms - void reset() { + void reset() + { num_prompt_tokens = 0; generated_text = ""; truncated = false; @@ -396,7 +444,8 @@ struct llama_client_slot { generated_token_probs.clear(); - for (slot_image &img : images) { + for (slot_image &img : images) + { free(img.image_embedding); delete[] img.img_data.data; img.prefix_prompt = ""; @@ -406,11 +455,15 @@ struct llama_client_slot { // llama_set_rng_seed(ctx, params.seed); in batched the seed matter??????? 
} - bool has_budget(gpt_params &global_params) { + bool has_budget(gpt_params &global_params) + { n_remaining = -1; - if (params.n_predict != -1) { + if (params.n_predict != -1) + { n_remaining = params.n_predict - n_decoded; - } else if (global_params.n_predict != -1) { + } + else if (global_params.n_predict != -1) + { n_remaining = global_params.n_predict - n_decoded; } return n_remaining > 0 || n_remaining == -1; // no budget || limitless @@ -418,26 +471,32 @@ struct llama_client_slot { bool available() const { return state == IDLE && command == NONE; } - bool is_processing() const { + bool is_processing() const + { return (state == IDLE && command == LOAD_PROMPT) || state == PROCESSING; } - void add_token_string(const completion_token_output &token) { - if (command == RELEASE) { + void add_token_string(const completion_token_output &token) + { + if (command == RELEASE) + { return; } cache_tokens.push_back(token.tok); generated_token_probs.push_back(token); } - void release() { - if (state == IDLE || state == PROCESSING) { + void release() + { + if (state == IDLE || state == PROCESSING) + { t_token_generation = (ggml_time_us() - t_start_genereration) / 1e3; command = RELEASE; } } - json get_formated_timings() { + json get_formated_timings() + { return json{ {"prompt_n", num_prompt_tokens_processed}, {"prompt_ms", t_prompt_processing}, @@ -453,7 +512,8 @@ struct llama_client_slot { }; } - void print_timings() { + void print_timings() + { LOG_TEE("\n"); LOG_TEE("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per " "token, %8.2f tokens per second)\n", @@ -470,7 +530,8 @@ struct llama_client_slot { } }; -struct llama_server_context { +struct llama_server_context +{ llama_model *model = nullptr; llama_context *ctx = nullptr; @@ -504,45 +565,55 @@ struct llama_server_context { std::mutex mutex_tasks; std::mutex mutex_results; - ~llama_server_context() { - if (ctx) { + ~llama_server_context() + { + if (ctx) + { llama_free(ctx); ctx = nullptr; } - if (model) { 
+ if (model) + { llama_free_model(model); model = nullptr; } } - bool load_model(const gpt_params ¶ms_) { + bool load_model(const gpt_params ¶ms_) + { params = params_; - if (!params.mmproj.empty()) { + if (!params.mmproj.empty()) + { multimodal = true; LOG_TEE("Multi Modal Mode Enabled"); clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/1); - if (clp_ctx == nullptr) { + if (clp_ctx == nullptr) + { LOG_ERROR_LLAMA("unable to load clip model", {{"model", params.mmproj}}); return false; } if (params.n_ctx < - 2048) { // request larger context for the image embedding + 2048) + { // request larger context for the image embedding params.n_ctx = 2048; } } std::tie(model, ctx) = llama_init_from_gpt_params(params); - if (model == nullptr) { + if (model == nullptr) + { LOG_ERROR_LLAMA("unable to load model", {{"model", params.model}}); return false; } - if (multimodal) { + if (multimodal) + { const int n_embd_clip = clip_n_mmproj_embd(clp_ctx); const int n_embd_llm = llama_n_embd(model); - if (n_embd_clip != n_embd_llm) { + if (n_embd_clip != n_embd_llm) + { LOG_TEE("%s: embedding dim of the multimodal projector (%d) is not " "equal to that of LLaMA (%d). Make sure that you use the " "correct mmproj file.\n", @@ -558,7 +629,8 @@ struct llama_server_context { return true; } - void initialize() { + void initialize() + { id_gen = 0; // create slots @@ -567,7 +639,8 @@ struct llama_server_context { const int32_t n_ctx_slot = n_ctx / params.n_parallel; LOG_TEE("Available slots:\n"); - for (int i = 0; i < params.n_parallel; i++) { + for (int i = 0; i < params.n_parallel; i++) + { llama_client_slot slot; slot.id = i; @@ -586,32 +659,44 @@ struct llama_server_context { } std::vector tokenize(const json &json_prompt, - bool add_bos) const { + bool add_bos) const + { // If `add_bos` is true, we only add BOS, when json_prompt is a string, // or the first element of the json_prompt array is a string. 
std::vector prompt_tokens; - if (json_prompt.is_array()) { + if (json_prompt.is_array()) + { bool first = true; - for (const auto &p : json_prompt) { - if (p.is_string()) { + for (const auto &p : json_prompt) + { + if (p.is_string()) + { auto s = p.template get(); std::vector p; - if (first) { + if (first) + { p = ::llama_tokenize(ctx, s, add_bos); first = false; - } else { + } + else + { p = ::llama_tokenize(ctx, s, false); } prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); - } else { - if (first) { + } + else + { + if (first) + { first = false; } prompt_tokens.push_back(p.template get()); } } - } else { + } + else + { auto s = json_prompt.template get(); prompt_tokens = ::llama_tokenize(ctx, s, add_bos); } @@ -619,16 +704,20 @@ struct llama_server_context { return prompt_tokens; } - llama_client_slot *get_slot(int id) { + llama_client_slot *get_slot(int id) + { int64_t t_last = ggml_time_us(); llama_client_slot *last_used = nullptr; - for (llama_client_slot &slot : slots) { - if (slot.id == id && slot.available()) { + for (llama_client_slot &slot : slots) + { + if (slot.id == id && slot.available()) + { return &slot; } - if (slot.available() && slot.t_last_used < t_last) { + if (slot.available() && slot.t_last_used < t_last) + { last_used = &slot; t_last = slot.t_last_used; } @@ -637,7 +726,8 @@ struct llama_server_context { return last_used; } - bool launch_slot_with_data(llama_client_slot *&slot, json data) { + bool launch_slot_with_data(llama_client_slot *&slot, json data) + { slot_params default_params; llama_sampling_params default_sparams; @@ -675,40 +765,57 @@ struct llama_server_context { json_value(data, "n_probs", default_sparams.n_probs); // infill - if (data.count("input_prefix") != 0) { + if (data.count("input_prefix") != 0) + { slot->params.input_prefix = data["input_prefix"]; - } else { + } + else + { slot->params.input_prefix = ""; } - if (data.count("input_suffix") != 0) { + if (data.count("input_suffix") != 0) + { 
slot->params.input_suffix = data["input_suffix"]; - } else { + } + else + { slot->params.input_suffix = ""; } - if (data.count("prompt") != 0) { + if (data.count("prompt") != 0) + { slot->prompt = data["prompt"]; - } else { + } + else + { slot->prompt = ""; } slot->sparams.logit_bias.clear(); - if (json_value(data, "ignore_eos", false)) { + if (json_value(data, "ignore_eos", false)) + { slot->sparams.logit_bias[llama_token_eos(model)] = -INFINITY; } const auto &logit_bias = data.find("logit_bias"); - if (logit_bias != data.end() && logit_bias->is_array()) { + if (logit_bias != data.end() && logit_bias->is_array()) + { const int n_vocab = llama_n_vocab(model); - for (const auto &el : *logit_bias) { - if (el.is_array() && el.size() == 2 && el[0].is_number_integer()) { + for (const auto &el : *logit_bias) + { + if (el.is_array() && el.size() == 2 && el[0].is_number_integer()) + { llama_token tok = el[0].get(); - if (tok >= 0 && tok < n_vocab) { - if (el[1].is_number()) { + if (tok >= 0 && tok < n_vocab) + { + if (el[1].is_number()) + { slot->sparams.logit_bias[tok] = el[1].get(); - } else if (el[1].is_boolean() && !el[1].get()) { + } + else if (el[1].is_boolean() && !el[1].get()) + { slot->sparams.logit_bias[tok] = -INFINITY; } } @@ -719,18 +826,24 @@ struct llama_server_context { slot->params.antiprompt.clear(); const auto &stop = data.find("stop"); - if (stop != data.end() && stop->is_array()) { - for (const auto &word : *stop) { - if (!word.empty()) { + if (stop != data.end() && stop->is_array()) + { + for (const auto &word : *stop) + { + if (!word.empty()) + { slot->params.antiprompt.push_back(word); } } } - if (multimodal) { + if (multimodal) + { const auto &images_data = data.find("image_data"); - if (images_data != data.end() && images_data->is_array()) { - for (const auto &img : *images_data) { + if (images_data != data.end() && images_data->is_array()) + { + for (const auto &img : *images_data) + { std::string data_b64 = img["data"].get(); slot_image img_sl; 
img_sl.id = @@ -741,7 +854,8 @@ struct llama_server_context { auto data = stbi_load_from_memory(image_buffer.data(), image_buffer.size(), &width, &height, &channels, 3); - if (!data) { + if (!data) + { LOG_TEE("slot %i - failed to load image [id: %i]\n", slot->id, img_sl.id); return false; @@ -761,21 +875,27 @@ struct llama_server_context { // example: system prompt [img-102] user [img-103] describe [img-134] -> // [{id: 102, prefix: 'system prompt '}, {id: 103, prefix: ' user '}, // {id: 134, prefix: ' describe '}]} - if (slot->images.size() > 0 && !slot->prompt.is_array()) { + if (slot->images.size() > 0 && !slot->prompt.is_array()) + { std::string prompt = slot->prompt.get(); size_t pos = 0, begin_prefix = 0; std::string pattern = "[img-"; - while ((pos = prompt.find(pattern, pos)) != std::string::npos) { + while ((pos = prompt.find(pattern, pos)) != std::string::npos) + { size_t end_prefix = pos; pos += pattern.length(); size_t end_pos = prompt.find("]", pos); - if (end_pos != std::string::npos) { + if (end_pos != std::string::npos) + { std::string image_id = prompt.substr(pos, end_pos - pos); - try { + try + { int img_id = std::stoi(image_id); bool found = false; - for (slot_image &img : slot->images) { - if (img.id == img_id) { + for (slot_image &img : slot->images) + { + if (img.id == img_id) + { found = true; img.prefix_prompt = prompt.substr(begin_prefix, end_prefix - begin_prefix); @@ -783,12 +903,15 @@ struct llama_server_context { break; } } - if (!found) { + if (!found) + { LOG_TEE("ERROR: Image with id: %i, not found.\n", img_id); slot->images.clear(); return false; } - } catch (const std::invalid_argument &e) { + } + catch (const std::invalid_argument &e) + { LOG_TEE("Invalid image number id in prompt\n"); slot->images.clear(); return false; @@ -803,7 +926,8 @@ struct llama_server_context { } } - if (slot->ctx_sampling != nullptr) { + if (slot->ctx_sampling != nullptr) + { llama_sampling_free(slot->ctx_sampling); } slot->ctx_sampling = 
llama_sampling_init(slot->sparams); @@ -816,30 +940,35 @@ struct llama_server_context { return true; } - void kv_cache_clear() { + void kv_cache_clear() + { // clear the entire KV cache llama_kv_cache_clear(ctx); clean_kv_cache = false; } - void update_system_prompt() { + void update_system_prompt() + { system_tokens = ::llama_tokenize(ctx, system_prompt, true); llama_batch_clear(batch); kv_cache_clear(); - for (int i = 0; i < (int)system_tokens.size(); ++i) { + for (int i = 0; i < (int)system_tokens.size(); ++i) + { llama_batch_add(batch, system_tokens[i], i, {0}, false); } - if (llama_decode(ctx, batch) != 0) { + if (llama_decode(ctx, batch) != 0) + { LOG_TEE("%s: llama_decode() failed\n", __func__); return; } // assign the system KV cache to all parallel sequences - for (int32_t i = 1; i < params.n_parallel; ++i) { + for (int32_t i = 1; i < params.n_parallel; ++i) + { llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size()); } @@ -847,21 +976,25 @@ struct llama_server_context { system_need_update = false; } - void notify_system_prompt_changed() { + void notify_system_prompt_changed() + { // release all slots - for (llama_client_slot &slot : slots) { + for (llama_client_slot &slot : slots) + { slot.release(); } system_need_update = true; } - void process_system_prompt_data(const json &sys_props) { + void process_system_prompt_data(const json &sys_props) + { system_prompt = sys_props.value("prompt", ""); name_user = sys_props.value("anti_prompt", ""); name_assistant = sys_props.value("assistant_name", ""); - if (slots.size() > 0) { + if (slots.size() > 0) + { notify_system_prompt_changed(); } } @@ -869,21 +1002,28 @@ struct llama_server_context { static size_t find_stopping_strings(const std::string &text, const size_t last_token_size, const stop_type type, - llama_client_slot &slot) { + llama_client_slot &slot) + { size_t stop_pos = std::string::npos; - for (const std::string &word : slot.params.antiprompt) { + for (const std::string &word : 
slot.params.antiprompt) + { size_t pos; - if (type == STOP_FULL) { + if (type == STOP_FULL) + { const size_t tmp = word.size() + last_token_size; const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; pos = text.find(word, from_pos); - } else { + } + else + { pos = find_partial_stop_string(word, text); } if (pos != std::string::npos && - (stop_pos == std::string::npos || pos < stop_pos)) { - if (type == STOP_FULL) { + (stop_pos == std::string::npos || pos < stop_pos)) + { + if (type == STOP_FULL) + { slot.stopped_word = true; slot.stopping_word = word; slot.has_next_token = false; @@ -895,7 +1035,8 @@ struct llama_server_context { return stop_pos; } - bool process_token(completion_token_output &result, llama_client_slot &slot) { + bool process_token(completion_token_output &result, llama_client_slot &slot) + { // remember which tokens were sampled - used for repetition penalties during // sampling const std::string token_str = llama_token_to_piece(ctx, result.tok); @@ -905,36 +1046,50 @@ struct llama_server_context { slot.generated_text += token_str; slot.has_next_token = true; - if (slot.multibyte_pending > 0) { + if (slot.multibyte_pending > 0) + { slot.multibyte_pending -= token_str.size(); - } else if (token_str.size() == 1) { + } + else if (token_str.size() == 1) + { const char c = token_str[0]; // 2-byte characters: 110xxxxx 10xxxxxx - if ((c & 0xE0) == 0xC0) { + if ((c & 0xE0) == 0xC0) + { slot.multibyte_pending = 1; // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx - } else if ((c & 0xF0) == 0xE0) { + } + else if ((c & 0xF0) == 0xE0) + { slot.multibyte_pending = 2; // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx - } else if ((c & 0xF8) == 0xF0) { + } + else if ((c & 0xF8) == 0xF0) + { slot.multibyte_pending = 3; - } else { + } + else + { slot.multibyte_pending = 0; } } - if (slot.multibyte_pending == 0) { + if (slot.multibyte_pending == 0) + { size_t pos = std::min(slot.sent_count, slot.generated_text.size()); const std::string str_test 
= slot.generated_text.substr(pos); bool is_stop_full = false; size_t stop_pos = find_stopping_strings(str_test, token_str.size(), STOP_FULL, slot); - if (stop_pos != std::string::npos) { + if (stop_pos != std::string::npos) + { is_stop_full = true; slot.generated_text.erase(slot.generated_text.begin() + pos + stop_pos, slot.generated_text.end()); pos = std::min(slot.sent_count, slot.generated_text.size()); - } else { + } + else + { is_stop_full = false; stop_pos = find_stopping_strings(str_test, token_str.size(), STOP_PARTIAL, slot); @@ -942,7 +1097,8 @@ struct llama_server_context { // check if there is any token to predict if (stop_pos == std::string::npos || - (!slot.has_next_token && !is_stop_full && stop_pos > 0)) { + (!slot.has_next_token && !is_stop_full && stop_pos > 0)) + { // no send the stop word in the response result.text_to_send = slot.generated_text.substr(pos, std::string::npos); @@ -950,22 +1106,26 @@ struct llama_server_context { // add the token to slot queue and cache } slot.add_token_string(result); - if (slot.params.stream) { + if (slot.params.stream) + { send_partial_response(slot, result); } } - if (slot.multibyte_pending > 0 && !slot.has_next_token) { + if (slot.multibyte_pending > 0 && !slot.has_next_token) + { slot.has_next_token = true; } // check the limits - if (slot.n_decoded > 2 && slot.has_next_token && !slot.has_budget(params)) { + if (slot.n_decoded > 2 && slot.has_next_token && !slot.has_budget(params)) + { slot.stopped_limit = true; slot.has_next_token = false; } - if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model)) { + if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model)) + { slot.stopped_eos = true; slot.has_next_token = false; LOG_VERBOSE("eos token found", {}); @@ -988,28 +1148,34 @@ struct llama_server_context { return slot.has_next_token; // continue } - bool process_images(llama_client_slot &slot) const { - for (slot_image &img : slot.images) { - if (!img.request_encode_image) { + 
bool process_images(llama_client_slot &slot) const + { + for (slot_image &img : slot.images) + { + if (!img.request_encode_image) + { continue; } clip_image_f32 img_res; if (!clip_image_preprocess(clp_ctx, &img.img_data, &img_res, - /*pad2square =*/true)) { + /*pad2square =*/true)) + { LOG_TEE("Error processing the given image"); clip_free(clp_ctx); return false; } img.image_tokens = clip_n_patches(clp_ctx); img.image_embedding = (float *)malloc(clip_embd_nbytes(clp_ctx)); - if (!img.image_embedding) { + if (!img.image_embedding) + { LOG_TEE("Unable to allocate memory for image embeddings\n"); clip_free(clp_ctx); return false; } LOG_TEE("slot %i - encoding image [id: %i]\n", slot.id, img.id); if (!clip_image_encode(clp_ctx, params.n_threads, &img_res, - img.image_embedding)) { + img.image_embedding)) + { LOG_TEE("Unable to encode image\n"); return false; } @@ -1019,7 +1185,8 @@ struct llama_server_context { return slot.images.size() > 0; } - void send_error(int id, std::string error) { + void send_error(int id, std::string error) + { std::lock_guard lock(mutex_results); task_result res; res.id = id; @@ -1030,7 +1197,8 @@ struct llama_server_context { json get_model_props() { return get_formated_generation(slots[0]); } - json get_formated_generation(llama_client_slot &slot) { + json get_formated_generation(llama_client_slot &slot) + { const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model)); const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && @@ -1064,7 +1232,8 @@ struct llama_server_context { } void send_partial_response(llama_client_slot &slot, - completion_token_output tkn) { + completion_token_output tkn) + { std::lock_guard lock(mutex_results); task_result res; res.id = slot.task_id; @@ -1076,7 +1245,8 @@ struct llama_server_context { {"slot_id", slot.id}, {"multimodal", multimodal}}; - if (slot.sparams.n_probs > 0) { + if (slot.sparams.n_probs > 0) + { std::vector probs_output = {}; const 
std::vector to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false); @@ -1085,7 +1255,8 @@ struct llama_server_context { size_t probs_stop_pos = std::min(slot.sent_token_probs_index + to_send_toks.size(), slot.generated_token_probs.size()); - if (probs_pos < probs_stop_pos) { + if (probs_pos < probs_stop_pos) + { probs_output = std::vector( slot.generated_token_probs.begin() + probs_pos, slot.generated_token_probs.begin() + probs_stop_pos); @@ -1098,7 +1269,8 @@ struct llama_server_context { queue_results.push_back(res); } - void send_final_response(llama_client_slot &slot) { + void send_final_response(llama_client_slot &slot) + { std::lock_guard lock(mutex_results); task_result res; res.id = slot.task_id; @@ -1122,15 +1294,19 @@ struct llama_server_context { {"tokens_cached", slot.n_past}, {"timings", slot.get_formated_timings()}}; - if (slot.sparams.n_probs > 0) { + if (slot.sparams.n_probs > 0) + { std::vector probs = {}; - if (!slot.params.stream && slot.stopped_word) { + if (!slot.params.stream && slot.stopped_word) + { const std::vector stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false); probs = std::vector( slot.generated_token_probs.begin(), slot.generated_token_probs.end() - stop_word_toks.size()); - } else { + } + else + { probs = std::vector( slot.generated_token_probs.begin(), slot.generated_token_probs.begin() + slot.sent_token_probs_index); @@ -1142,7 +1318,8 @@ struct llama_server_context { queue_results.push_back(res); } - void send_embedding(llama_client_slot &slot) { + void send_embedding(llama_client_slot &slot) + { std::lock_guard lock(mutex_results); task_result res; res.id = slot.task_id; @@ -1150,7 +1327,8 @@ struct llama_server_context { res.stop = true; const int n_embd = llama_n_embd(model); - if (!params.embedding) { + if (!params.embedding) + { LOG_WARNING_LLAMA("embedding disabled", { {"params.embedding", params.embedding}, @@ -1158,7 +1336,9 @@ struct llama_server_context { res.result_json = json{ {"embedding", 
std::vector(n_embd, 0.0f)}, }; - } else { + } + else + { const float *data = llama_get_embeddings(ctx); std::vector embedding(data, data + n_embd); res.result_json = json{ @@ -1168,7 +1348,8 @@ struct llama_server_context { queue_results.push_back(res); } - int request_completion(json data, bool infill, bool embedding) { + int request_completion(json data, bool infill, bool embedding) + { std::lock_guard lock(mutex_tasks); task_server task; task.id = id_gen++; @@ -1180,17 +1361,22 @@ struct llama_server_context { return task.id; } - task_result next_result(int task_id) { - while (true) { + task_result next_result(int task_id) + { + while (true) + { std::this_thread::sleep_for(std::chrono::microseconds(5)); std::lock_guard lock(mutex_results); - if (queue_results.empty()) { + if (queue_results.empty()) + { continue; } - for (int i = 0; i < (int)queue_results.size(); i++) { - if (queue_results[i].id == task_id) { + for (int i = 0; i < (int)queue_results.size(); i++) + { + if (queue_results[i].id == task_id) + { task_result res = queue_results[i]; queue_results.erase(queue_results.begin() + i); return res; @@ -1203,14 +1389,17 @@ struct llama_server_context { } // for multiple images processing - bool ingest_images(llama_client_slot &slot, int n_batch) { + bool ingest_images(llama_client_slot &slot, int n_batch) + { int image_idx = 0; - while (image_idx < (int)slot.images.size()) { + while (image_idx < (int)slot.images.size()) + { slot_image &img = slot.images[image_idx]; // process prefix prompt - for (int32_t i = 0; i < (int32_t)batch.n_tokens; i += n_batch) { + for (int32_t i = 0; i < (int32_t)batch.n_tokens; i += n_batch) + { const int32_t n_tokens = std::min(n_batch, (int32_t)(batch.n_tokens - i)); llama_batch batch_view = { @@ -1225,27 +1414,37 @@ struct llama_server_context { 0, 0, // unused }; - if (llama_decode(ctx, batch_view)) { + if (llama_decode(ctx, batch_view)) + { LOG_TEE("%s : failed to eval\n", __func__); return false; } } // process image with llm - 
for (int i = 0; i < img.image_tokens; i += n_batch) { + for (int i = 0; i < img.image_tokens; i += n_batch) + { int n_eval = img.image_tokens - i; - if (n_eval > n_batch) { + if (n_eval > n_batch) + { n_eval = n_batch; } const int n_embd = llama_n_embd(model); llama_batch batch_img = { - n_eval, nullptr, (img.image_embedding + i * n_embd), - nullptr, nullptr, nullptr, - nullptr, slot.n_past, 1, + n_eval, + nullptr, + (img.image_embedding + i * n_embd), + nullptr, + nullptr, + nullptr, + nullptr, + slot.n_past, + 1, 0, }; - if (llama_decode(ctx, batch_img)) { + if (llama_decode(ctx, batch_img)) + { LOG_TEE("%s : failed to eval image\n", __func__); return false; } @@ -1264,7 +1463,8 @@ struct llama_server_context { std::vector append_tokens = tokenize(json_prompt, false); // has next image - for (int i = 0; i < (int)append_tokens.size(); ++i) { + for (int i = 0; i < (int)append_tokens.size(); ++i) + { llama_batch_add(batch, append_tokens[i], slot.n_past, {slot.id}, true); slot.n_past += 1; } @@ -1273,7 +1473,8 @@ struct llama_server_context { return true; } - void request_cancel(int task_id) { + void request_cancel(int task_id) + { std::lock_guard lock(mutex_tasks); task_server task; task.id = id_gen++; @@ -1282,23 +1483,29 @@ struct llama_server_context { queue_tasks.push_back(task); } - void process_tasks() { + void process_tasks() + { std::lock_guard lock(mutex_tasks); - while (!queue_tasks.empty()) { + while (!queue_tasks.empty()) + { task_server task = queue_tasks.front(); queue_tasks.erase(queue_tasks.begin()); - switch (task.type) { - case COMPLETION_TASK: { + switch (task.type) + { + case COMPLETION_TASK: + { llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1)); - if (slot == nullptr) { + if (slot == nullptr) + { LOG_TEE("slot unavailable\n"); // send error result send_error(task.id, "slot unavailable"); return; } - if (task.data.contains("system_prompt")) { + if (task.data.contains("system_prompt")) + { 
process_system_prompt_data(task.data["system_prompt"]); } @@ -1308,38 +1515,48 @@ struct llama_server_context { slot->embedding = task.embedding_mode; slot->task_id = task.id; - if (!launch_slot_with_data(slot, task.data)) { + if (!launch_slot_with_data(slot, task.data)) + { // send error result send_error(task.id, "internal_error"); break; } - } break; - case CANCEL_TASK: { // release slot linked with the task id - for (auto &slot : slots) { - if (slot.task_id == task.target_id) { + } + break; + case CANCEL_TASK: + { // release slot linked with the task id + for (auto &slot : slots) + { + if (slot.task_id == task.target_id) + { slot.release(); break; } } - } break; + } + break; } } } - bool update_slots() { + bool update_slots() + { // attend tasks process_tasks(); // update the system prompt wait until all slots are idle state - if (system_need_update && all_slots_are_idle) { + if (system_need_update && all_slots_are_idle) + { LOG_TEE("updating system prompt\n"); update_system_prompt(); } llama_batch_clear(batch); - if (all_slots_are_idle) { - if (system_prompt.empty() && clean_kv_cache) { + if (all_slots_are_idle) + { + if (system_prompt.empty() && clean_kv_cache) + { LOG_TEE("all slots are idle and system prompt is empty, clear the KV " "cache\n"); kv_cache_clear(); @@ -1348,9 +1565,11 @@ struct llama_server_context { std::this_thread::sleep_for(std::chrono::milliseconds(5)); } - for (llama_client_slot &slot : slots) { + for (llama_client_slot &slot : slots) + { if (slot.is_processing() && - slot.cache_tokens.size() >= (size_t)slot.n_ctx) { + slot.cache_tokens.size() >= (size_t)slot.n_ctx) + { // Shift context const int n_left = slot.n_past - slot.params.n_keep - 1; const int n_discard = n_left / 2; @@ -1365,7 +1584,8 @@ struct llama_server_context { slot.n_past, -n_discard); for (size_t i = slot.params.n_keep + 1 + n_discard; - i < slot.cache_tokens.size(); i++) { + i < slot.cache_tokens.size(); i++) + { slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; 
} @@ -1384,9 +1604,11 @@ struct llama_server_context { } // decode any currently ongoing sequences - for (auto &slot : slots) { + for (auto &slot : slots) + { // release the slot - if (slot.command == RELEASE) { + if (slot.command == RELEASE) + { slot.state = IDLE; slot.command = NONE; slot.t_last_used = ggml_time_us(); @@ -1397,7 +1619,8 @@ struct llama_server_context { continue; } - if (slot.state == IDLE) { + if (slot.state == IDLE) + { continue; } @@ -1414,15 +1637,18 @@ struct llama_server_context { int32_t n_batch = params.n_batch; // assign workload to the slots - if (params.cont_batching || batch.n_tokens == 0) { - for (auto &slot : slots) { + if (params.cont_batching || batch.n_tokens == 0) + { + for (auto &slot : slots) + { const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get().empty()) || !slot.images.empty(); // empty prompt passed -> release the slot and send empty response - if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt) { + if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt) + { slot.release(); slot.print_timings(); send_final_response(slot); @@ -1430,17 +1656,20 @@ struct llama_server_context { } // need process the prompt - if (slot.state == IDLE && slot.command == LOAD_PROMPT) { + if (slot.state == IDLE && slot.command == LOAD_PROMPT) + { slot.state = PROCESSING; slot.command = NONE; std::vector prompt_tokens; slot.t_start_process_prompt = ggml_time_us(); slot.t_start_genereration = 0; - if (slot.infill) { + if (slot.infill) + { bool suff_rm_leading_spc = true; if (params.input_suffix.find_first_of(' ') == 0 && - params.input_suffix.size() > 1) { + params.input_suffix.size() > 1) + { params.input_suffix.erase(0, 1); suff_rm_leading_spc = false; } @@ -1449,7 +1678,8 @@ struct llama_server_context { const int space_token = 29871; // TODO: this should not be hardcoded if (suff_rm_leading_spc && !suffix_tokens.empty() && - suffix_tokens[0] == space_token) { + 
suffix_tokens[0] == space_token) + { suffix_tokens.erase(suffix_tokens.begin()); } @@ -1463,7 +1693,9 @@ struct llama_server_context { suffix_tokens.end()); prefix_tokens.push_back(llama_token_middle(model)); prompt_tokens = prefix_tokens; - } else { + } + else + { prompt_tokens = tokenize( slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt @@ -1471,19 +1703,24 @@ struct llama_server_context { slot.num_prompt_tokens = prompt_tokens.size(); - if (!slot.params.cache_prompt) { + if (!slot.params.cache_prompt) + { llama_sampling_reset(slot.ctx_sampling); slot.n_past = 0; slot.num_prompt_tokens_processed = slot.num_prompt_tokens; - } else { - if (slot.params.n_keep < 0) { + } + else + { + if (slot.params.n_keep < 0) + { slot.params.n_keep = slot.num_prompt_tokens; } slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); // if input prompt is too big, truncate it - if (slot.num_prompt_tokens >= slot.n_ctx) { + if (slot.num_prompt_tokens >= slot.n_ctx) + { const int n_left = slot.n_ctx - slot.params.n_keep; const int n_block_size = n_left / 2; const int erased_blocks = @@ -1515,7 +1752,8 @@ struct llama_server_context { } // push the prompt into the sampling context (do not apply grammar) - for (auto &token : prompt_tokens) { + for (auto &token : prompt_tokens) + { llama_sampling_accept(slot.ctx_sampling, ctx, token, false); } @@ -1535,7 +1773,8 @@ struct llama_server_context { slot.cache_tokens = prompt_tokens; - if (slot.n_past == slot.num_prompt_tokens) { + if (slot.n_past == slot.num_prompt_tokens) + { // we have to evaluate at least 1 token to generate logits. LOG_TEE("slot %d : we have to evaluate at least 1 token to " "generate logits\n", @@ -1561,19 +1800,22 @@ struct llama_server_context { std::vector prefix_tokens = has_images ? 
tokenize(slot.images[0].prefix_prompt, true) : prompt_tokens; - for (; slot.n_past < (int)prefix_tokens.size(); ++slot.n_past) { + for (; slot.n_past < (int)prefix_tokens.size(); ++slot.n_past) + { llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, {slot.id}, false); } - if (has_images && !ingest_images(slot, n_batch)) { + if (has_images && !ingest_images(slot, n_batch)) + { LOG_TEE("failed processing images\n"); return false; } // extract the logits only for the last token - if (batch.n_tokens > 0) { + if (batch.n_tokens > 0) + { batch.logits[batch.n_tokens - 1] = true; } @@ -1583,12 +1825,14 @@ struct llama_server_context { } } - if (batch.n_tokens == 0) { + if (batch.n_tokens == 0) + { all_slots_are_idle = true; return true; } - for (int32_t i = 0; i < (int32_t)batch.n_tokens; i += n_batch) { + for (int32_t i = 0; i < (int32_t)batch.n_tokens; i += n_batch) + { const int32_t n_tokens = std::min(n_batch, (int32_t)(batch.n_tokens - i)); llama_batch batch_view = { n_tokens, @@ -1604,8 +1848,10 @@ struct llama_server_context { }; const int ret = llama_decode(ctx, batch_view); - if (ret != 0) { - if (n_batch == 1 || ret < 0) { + if (ret != 0) + { + if (n_batch == 1 || ret < 0) + { // if you get here, it means the KV cache is full - try increasing it // via the context size LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", @@ -1624,13 +1870,16 @@ struct llama_server_context { continue; } - for (auto &slot : slots) { - if (slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) { + for (auto &slot : slots) + { + if (slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) + { continue; } // prompt evaluated for embedding - if (slot.embedding) { + if (slot.embedding) + { send_embedding(slot); slot.release(); slot.i_batch = -1; @@ -1643,7 +1892,8 @@ struct llama_server_context { llama_sampling_accept(slot.ctx_sampling, ctx, id, true); - if (slot.n_decoded == 1) { + if (slot.n_decoded == 1) + { 
slot.t_start_genereration = ggml_time_us(); slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3; @@ -1654,16 +1904,19 @@ struct llama_server_context { result.tok = id; const int32_t n_probs = slot.sparams.n_probs; - if (slot.sparams.temp <= 0 && n_probs > 0) { + if (slot.sparams.temp <= 0 && n_probs > 0) + { // for llama_sample_token_greedy we need to sort candidates llama_sample_softmax(ctx, &cur_p); } - for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i) { + for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i) + { result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p}); } - if (!process_token(result, slot)) { + if (!process_token(result, slot)) + { slot.release(); slot.print_timings(); send_final_response(slot); @@ -1677,7 +1930,8 @@ struct llama_server_context { }; static void server_print_usage(const char *argv0, const gpt_params ¶ms, - const server_params &sparams) { + const server_params &sparams) +{ printf("usage: %s [options]\n", argv0); printf("\n"); printf("options:\n"); @@ -1716,11 +1970,13 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, "key+value (default: disabled)\n"); printf(" not recommended: doubles context memory " "required and no measurable increase in quality\n"); - if (llama_mlock_supported()) { + if (llama_mlock_supported()) + { printf(" --mlock force system to keep model in RAM rather " "than swapping or compressing\n"); } - if (llama_mmap_supported()) { + if (llama_mmap_supported()) + { printf(" --no-mmap do not memory-map model (slower load but " "may reduce pageouts if not using mlock)\n"); } @@ -1778,139 +2034,161 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, static json format_partial_response(llama_server_context &llama, llama_client_slot *slot, const std::string &content, - const std::vector &probs) { + const std::vector &probs) +{ json res = json{{"content", content}, {"stop", false}, {"slot_id", slot->id}, 
{"multimodal", llama.multimodal}}; - if (slot->sparams.n_probs > 0) { + if (slot->sparams.n_probs > 0) + { res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs); } return res; } -static json format_tokenizer_response(const std::vector &tokens) { +static json format_tokenizer_response(const std::vector &tokens) +{ return json{{"tokens", tokens}}; } -static json format_detokenized_response(std::string content) { +static json format_detokenized_response(std::string content) +{ return json{{"content", content}}; } -struct token_translator { +struct token_translator +{ llama_context *ctx; - std::string operator()(llama_token tok) const { + std::string operator()(llama_token tok) const + { return llama_token_to_piece(ctx, tok); } - std::string operator()(const completion_token_output &cto) const { + std::string operator()(const completion_token_output &cto) const + { return (*this)(cto.tok); } }; static void append_to_generated_text_from_generated_token_probs(llama_server_context &llama, - llama_client_slot *slot) { + llama_client_slot *slot) +{ auto >ps = slot->generated_token_probs; auto translator = token_translator{llama.ctx}; - auto add_strlen = [=](size_t sum, const completion_token_output &cto) { + auto add_strlen = [=](size_t sum, const completion_token_output &cto) + { return sum + translator(cto).size(); }; const size_t len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen); - if (slot->generated_text.capacity() < slot->generated_text.size() + len) { + if (slot->generated_text.capacity() < slot->generated_text.size() + len) + { slot->generated_text.reserve(slot->generated_text.size() + len); } - for (const completion_token_output &cto : gtps) { + for (const completion_token_output &cto : gtps) + { slot->generated_text += translator(cto); } } using namespace drogon; -namespace inferences { -class llamaCPP : public drogon::HttpController { -public: - llamaCPP() { - // Some default values for now below - // log_disable(); // 
Disable the log to file feature, reduce bloat for - // target - // system () - std::vector llama_models = - nitro_utils::listFilesInDir(nitro_utils::models_folder); - std::string model_index; - if (llama_models.size() > 0) { - LOG_INFO << "Found models folder, here are the llama models you have:"; - int index_val = 0; - for (auto llama_model : llama_models) { - LOG_INFO << "index: " << index_val++ << "| model: " << llama_model; - std::cout - << "Please type the index of the model you want to load here >> "; - std::cin >> model_index; - Json::Value jsonBody; - jsonBody["llama_model_path"] = nitro_utils::models_folder + "/" + - llama_models[std::stoi(model_index)]; - loadModelImpl(jsonBody); +namespace inferences +{ + class llamaCPP : public drogon::HttpController + { + public: + llamaCPP() + { + // Some default values for now below + // log_disable(); // Disable the log to file feature, reduce bloat for + // target + // system () + std::vector llama_models = + nitro_utils::listFilesInDir(nitro_utils::models_folder); + std::string model_index; + if (llama_models.size() > 0) + { + LOG_INFO << "Found models folder, here are the llama models you have:"; + int index_val = 0; + for (auto llama_model : llama_models) + { + LOG_INFO << "index: " << index_val++ << "| model: " << llama_model; + std::cout + << "Please type the index of the model you want to load here >> "; + std::cin >> model_index; + Json::Value jsonBody; + jsonBody["llama_model_path"] = nitro_utils::models_folder + "/" + + llama_models[std::stoi(model_index)]; + loadModelImpl(jsonBody); + } + } + else + { + LOG_INFO << "Not found models folder, start server as usual"; } - } else { - LOG_INFO << "Not found models folder, start server as usual"; } - } - METHOD_LIST_BEGIN - // list path definitions here; - METHOD_ADD(llamaCPP::chatCompletion, "chat_completion", Post); - METHOD_ADD(llamaCPP::embedding, "embedding", Post); - METHOD_ADD(llamaCPP::loadModel, "loadmodel", Post); - METHOD_ADD(llamaCPP::unloadModel, 
"unloadmodel", Get); - METHOD_ADD(llamaCPP::modelStatus, "modelstatus", Get); - - // Openai compatible path - ADD_METHOD_TO(llamaCPP::chatCompletion, "/v1/chat/completions", Post); - ADD_METHOD_TO(llamaCPP::chatCompletionPrelight, "/v1/chat/completions", - Options); - - ADD_METHOD_TO(llamaCPP::embedding, "/v1/embeddings", Post); - - // PATH_ADD("/llama/chat_completion", Post); - METHOD_LIST_END - void chatCompletion(const HttpRequestPtr &req, - std::function &&callback); - void chatCompletionPrelight( - const HttpRequestPtr &req, - std::function &&callback); - void embedding(const HttpRequestPtr &req, - std::function &&callback); - void loadModel(const HttpRequestPtr &req, - std::function &&callback); - void unloadModel(const HttpRequestPtr &req, + METHOD_LIST_BEGIN + // list path definitions here; + METHOD_ADD(llamaCPP::chatCompletion, "chat_completion", Post); + METHOD_ADD(llamaCPP::embedding, "embedding", Post); + METHOD_ADD(llamaCPP::loadModel, "loadmodel", Post); + METHOD_ADD(llamaCPP::unloadModel, "unloadmodel", Get); + METHOD_ADD(llamaCPP::modelStatus, "modelstatus", Get); + + // Openai compatible path + ADD_METHOD_TO(llamaCPP::chatCompletion, "/v1/chat/completions", Post); + ADD_METHOD_TO(llamaCPP::chatCompletionPrelight, "/v1/chat/completions", + Options); + + ADD_METHOD_TO(llamaCPP::transcription, "/v1/audio/transcription", Post); + + ADD_METHOD_TO(llamaCPP::embedding, "/v1/embeddings", Post); + + // PATH_ADD("/llama/chat_completion", Post); + METHOD_LIST_END + void chatCompletion(const HttpRequestPtr &req, + std::function &&callback); + void chatCompletionPrelight( + const HttpRequestPtr &req, + std::function &&callback); + void embedding(const HttpRequestPtr &req, std::function &&callback); - - void modelStatus(const HttpRequestPtr &req, + void transcription(const HttpRequestPtr &req, + std::function &&callback); + void loadModel(const HttpRequestPtr &req, std::function &&callback); - - bool loadModelImpl(const Json::Value &jsonBody); - - void 
warmupModel(); - - void backgroundTask(); - - void stopBackgroundTask(); - -private: - llama_server_context llama; - std::atomic model_loaded = false; - size_t sent_count = 0; - size_t sent_token_probs_index = 0; - std::thread backgroundThread; - std::string user_prompt; - std::string ai_prompt; - std::string system_prompt; - std::string pre_prompt; - int repeat_last_n; - bool caching_enabled; - std::atomic no_of_chats = 0; - int clean_cache_threshold; -}; + void unloadModel(const HttpRequestPtr &req, + std::function &&callback); + + void modelStatus(const HttpRequestPtr &req, + std::function &&callback); + + bool loadModelImpl(const Json::Value &jsonBody); + + void warmupModel(); + + void backgroundTask(); + + void stopBackgroundTask(); + + private: + llama_server_context llama; + std::atomic model_loaded = false; + size_t sent_count = 0; + size_t sent_token_probs_index = 0; + std::thread backgroundThread; + std::string user_prompt; + std::string ai_prompt; + std::string system_prompt; + std::string pre_prompt; + int repeat_last_n; + bool caching_enabled; + std::atomic no_of_chats = 0; + int clean_cache_threshold; + }; }; // namespace inferences From 863ebceb565891fc154d37ad1b18c41fdfe14426 Mon Sep 17 00:00:00 2001 From: hiro Date: Sat, 16 Dec 2023 23:56:19 +0700 Subject: [PATCH 06/31] chore: Add uploads folder to gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index f75df993e..13a9b93d2 100644 --- a/.gitignore +++ b/.gitignore @@ -563,3 +563,5 @@ FodyWeavers.xsd build build_deps .DS_Store + +uploads/** \ No newline at end of file From 0990139c8dfabc622aa626e0d9a18df3f2c2f99e Mon Sep 17 00:00:00 2001 From: hiro Date: Sat, 16 Dec 2023 23:56:26 +0700 Subject: [PATCH 07/31] chore: delete config.json --- config.json | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 config.json diff --git a/config.json b/config.json deleted file mode 100644 index add7da3a6..000000000 --- a/config.json +++ /dev/null @@ 
-1,8 +0,0 @@ -{ - "listeners": [ - { - "address": "127.0.0.1", - "port": 3928 - } - ] -} From b53f4a71959884a7f32e0d2644d246c07b7af250 Mon Sep 17 00:00:00 2001 From: hiro Date: Sat, 16 Dec 2023 23:56:52 +0700 Subject: [PATCH 08/31] feat: Add OAI compatible APIs for /audio/transcriptions and /audio/translations --- controllers/llamaCPP.cc | 49 ++++++++++++++++++++++++++++++++--------- controllers/llamaCPP.h | 10 ++++++++- 2 files changed, 48 insertions(+), 11 deletions(-) diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc index 6bf63b036..0a02cbd53 100644 --- a/controllers/llamaCPP.cc +++ b/controllers/llamaCPP.cc @@ -451,23 +451,52 @@ void llamaCPP::transcription( const HttpRequestPtr &req, std::function &&callback) { - MultiPartParser fileUpload; - if (fileUpload.parse(req) != 0 || fileUpload.getFiles().size() != 1) + MultiPartParser partParser; + Json::Value jsonResp; + + if (partParser.parse(req) != 0 || partParser.getFiles().size() != 1) + { + auto resp = HttpResponse::newHttpResponse(); + resp->setBody("Must have exactly one file"); + resp->setStatusCode(k403Forbidden); + callback(resp); + return; + } + auto &file = partParser.getFiles()[0]; + const auto &formFields = partParser.getParameters(); + std::string model = formFields.at("model"); + file.save(); + + jsonResp["text"] = "handling text"; + + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + callback(resp); + return; +} + +void llamaCPP::translation( + const HttpRequestPtr &req, + std::function &&callback) +{ + MultiPartParser partParser; + Json::Value jsonResp; + + if (partParser.parse(req) != 0 || partParser.getFiles().size() != 1) { auto resp = HttpResponse::newHttpResponse(); - resp->setBody("Must only be one file"); + resp->setBody("Must have exactly one file"); resp->setStatusCode(k403Forbidden); callback(resp); return; } - auto &file = fileUpload.getFiles()[0]; - auto md5 = file.getMd5(); - auto resp = HttpResponse::newHttpResponse(); - resp->setBody( - "The server has 
calculated the file's MD5 hash to be " + md5); + auto &file = partParser.getFiles()[0]; + const auto &formFields = partParser.getParameters(); + std::string model = formFields.at("model"); file.save(); - LOG_INFO << "The uploaded file has been saved to the ./uploads " - "directory"; + + jsonResp["text"] = "handling text"; + + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); callback(resp); return; } diff --git a/controllers/llamaCPP.h b/controllers/llamaCPP.h index 7d5c95670..04b1a6386 100644 --- a/controllers/llamaCPP.h +++ b/controllers/llamaCPP.h @@ -2145,7 +2145,8 @@ namespace inferences ADD_METHOD_TO(llamaCPP::chatCompletionPrelight, "/v1/chat/completions", Options); - ADD_METHOD_TO(llamaCPP::transcription, "/v1/audio/transcription", Post); + ADD_METHOD_TO(llamaCPP::transcription, "/v1/audio/transcriptions", Post); + ADD_METHOD_TO(llamaCPP::translation, "/v1/audio/translations", Post); ADD_METHOD_TO(llamaCPP::embedding, "/v1/embeddings", Post); @@ -2156,12 +2157,19 @@ namespace inferences void chatCompletionPrelight( const HttpRequestPtr &req, std::function &&callback); + void embedding(const HttpRequestPtr &req, std::function &&callback); + void transcription(const HttpRequestPtr &req, std::function &&callback); + + void translation(const HttpRequestPtr &req, + std::function &&callback); + void loadModel(const HttpRequestPtr &req, std::function &&callback); + void unloadModel(const HttpRequestPtr &req, std::function &&callback); From 018450da8d9db8b9c19aa0915c7438a5be28d6cf Mon Sep 17 00:00:00 2001 From: hiro Date: Sun, 17 Dec 2023 15:01:27 +0700 Subject: [PATCH 09/31] chore: Fix warning message in compiler --- utils/nitro_utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/nitro_utils.h b/utils/nitro_utils.h index 987fbd7e1..271c4a908 100644 --- a/utils/nitro_utils.h +++ b/utils/nitro_utils.h @@ -96,7 +96,7 @@ inline void nitro_logo() { std::string resetColor = "\033[0m"; std::string asciiArt = " ___ ___ ___ \n" - " /__/\ 
___ ___ / /\\ / /\\ \n" + " /__/ ___ ___ / /\\ / /\\ \n" " \\ \\:\\ / /\\ / /\\ / /::\\ / /::\\ " " \n" " \\ \\:\\ / /:/ / /:/ / /:/\\:\\ / /:/\\:\\ " From 1288b586fbce9cf30df9367b9f921d071e35425e Mon Sep 17 00:00:00 2001 From: hiro Date: Sun, 17 Dec 2023 15:02:31 +0700 Subject: [PATCH 10/31] chore: Add whisper.cpp to target --- CMakeLists.txt | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 220fa9720..38622f5a9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,8 +23,8 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) set(OPENSSL_USE_STATIC_LIBS TRUE) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -set(CMAKE_PREFIX_PATH ${CMAKE_CURRENT_SOURCE_DIR}/build_deps/_install -)# This is the critical line for installing another package +set(CMAKE_PREFIX_PATH ${CMAKE_CURRENT_SOURCE_DIR}/build_deps/_install) +# This is the critical line for installing another package if(LLAMA_CUBLAS) cmake_minimum_required(VERSION 3.17) @@ -41,8 +41,9 @@ if(DEBUG) add_compile_definitions(ALLOW_ALL_CORS) endif() -add_subdirectory(llama.cpp) add_subdirectory(whisper.cpp) +add_subdirectory(llama.cpp) + add_executable(${PROJECT_NAME} main.cc) # ############################################################################## @@ -52,7 +53,7 @@ add_executable(${PROJECT_NAME} main.cc) # # and comment out the following lines find_package(Drogon CONFIG REQUIRED) -target_link_libraries(${PROJECT_NAME} PRIVATE Drogon::Drogon common llama llava +target_link_libraries(${PROJECT_NAME} PRIVATE Drogon::Drogon common llama llava whisper ${CMAKE_THREAD_LIBS_INIT}) # ############################################################################## From bf1fdb8a6b6527b87e6991da11552e154ce48ee1 Mon Sep 17 00:00:00 2001 From: hiro Date: Sun, 17 Dec 2023 15:02:47 +0700 Subject: [PATCH 11/31] fix: Add whisper.h --- controllers/llamaCPP.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/controllers/llamaCPP.cc 
b/controllers/llamaCPP.cc index 0a02cbd53..b8858fc78 100644 --- a/controllers/llamaCPP.cc +++ b/controllers/llamaCPP.cc @@ -1,4 +1,6 @@ #include "llamaCPP.h" +#include "ggml.h" +#include "whisper.h" #include "llama.h" #include "utils/nitro_utils.h" #include @@ -546,6 +548,7 @@ bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) { gpt_params params; + whisper_full_params whisper_params; // By default will setting based on number of handlers int drogon_thread = drogon::app().getThreadNum() - 1; @@ -557,6 +560,11 @@ bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) LOG_INFO << "MMPROJ FILE detected, multi-model enabled!"; params.mmproj = jsonBody["mmproj"].asString(); } + if (!jsonBody["whisper"].isNull()) + { + LOG_INFO << "WHISPER FILE detected, whisper enabled!"; + whisper_params.whisper = jsonBody["whisper"].asString(); + } params.model = jsonBody["llama_model_path"].asString(); params.n_gpu_layers = jsonBody.get("ngl", 100).asInt(); params.n_ctx = jsonBody.get("ctx_len", 2048).asInt(); From 77972e782cfe14cc595847670f3252096958ae6d Mon Sep 17 00:00:00 2001 From: hiro Date: Sun, 17 Dec 2023 18:13:35 +0700 Subject: [PATCH 12/31] chore: Migrate whisper.cpp definition to another controller file --- controllers/llamaCPP.cc | 61 --------------------------------------- controllers/llamaCPP.h | 9 ------ controllers/whisperCPP.cc | 60 ++++++++++++++++++++++++++++++++++++++ controllers/whisperCPP.h | 27 +++++++++++++++++ 4 files changed, 87 insertions(+), 70 deletions(-) create mode 100644 controllers/whisperCPP.cc create mode 100644 controllers/whisperCPP.h diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc index b8858fc78..fb6c35b31 100644 --- a/controllers/llamaCPP.cc +++ b/controllers/llamaCPP.cc @@ -1,6 +1,5 @@ #include "llamaCPP.h" #include "ggml.h" -#include "whisper.h" #include "llama.h" #include "utils/nitro_utils.h" #include @@ -449,60 +448,6 @@ void llamaCPP::embedding( return; } -void llamaCPP::transcription( - const HttpRequestPtr 
&req, - std::function &&callback) -{ - MultiPartParser partParser; - Json::Value jsonResp; - - if (partParser.parse(req) != 0 || partParser.getFiles().size() != 1) - { - auto resp = HttpResponse::newHttpResponse(); - resp->setBody("Must have exactly one file"); - resp->setStatusCode(k403Forbidden); - callback(resp); - return; - } - auto &file = partParser.getFiles()[0]; - const auto &formFields = partParser.getParameters(); - std::string model = formFields.at("model"); - file.save(); - - jsonResp["text"] = "handling text"; - - auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); - callback(resp); - return; -} - -void llamaCPP::translation( - const HttpRequestPtr &req, - std::function &&callback) -{ - MultiPartParser partParser; - Json::Value jsonResp; - - if (partParser.parse(req) != 0 || partParser.getFiles().size() != 1) - { - auto resp = HttpResponse::newHttpResponse(); - resp->setBody("Must have exactly one file"); - resp->setStatusCode(k403Forbidden); - callback(resp); - return; - } - auto &file = partParser.getFiles()[0]; - const auto &formFields = partParser.getParameters(); - std::string model = formFields.at("model"); - file.save(); - - jsonResp["text"] = "handling text"; - - auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); - callback(resp); - return; -} - void llamaCPP::unloadModel( const HttpRequestPtr &req, std::function &&callback) @@ -548,7 +493,6 @@ bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) { gpt_params params; - whisper_full_params whisper_params; // By default will setting based on number of handlers int drogon_thread = drogon::app().getThreadNum() - 1; @@ -560,11 +504,6 @@ bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) LOG_INFO << "MMPROJ FILE detected, multi-model enabled!"; params.mmproj = jsonBody["mmproj"].asString(); } - if (!jsonBody["whisper"].isNull()) - { - LOG_INFO << "WHISPER FILE detected, whisper enabled!"; - whisper_params.whisper = jsonBody["whisper"].asString(); - } params.model = 
jsonBody["llama_model_path"].asString(); params.n_gpu_layers = jsonBody.get("ngl", 100).asInt(); params.n_ctx = jsonBody.get("ctx_len", 2048).asInt(); diff --git a/controllers/llamaCPP.h b/controllers/llamaCPP.h index 04b1a6386..49f90ed4d 100644 --- a/controllers/llamaCPP.h +++ b/controllers/llamaCPP.h @@ -2145,9 +2145,6 @@ namespace inferences ADD_METHOD_TO(llamaCPP::chatCompletionPrelight, "/v1/chat/completions", Options); - ADD_METHOD_TO(llamaCPP::transcription, "/v1/audio/transcriptions", Post); - ADD_METHOD_TO(llamaCPP::translation, "/v1/audio/translations", Post); - ADD_METHOD_TO(llamaCPP::embedding, "/v1/embeddings", Post); // PATH_ADD("/llama/chat_completion", Post); @@ -2161,12 +2158,6 @@ namespace inferences void embedding(const HttpRequestPtr &req, std::function &&callback); - void transcription(const HttpRequestPtr &req, - std::function &&callback); - - void translation(const HttpRequestPtr &req, - std::function &&callback); - void loadModel(const HttpRequestPtr &req, std::function &&callback); diff --git a/controllers/whisperCPP.cc b/controllers/whisperCPP.cc new file mode 100644 index 000000000..43f1302f5 --- /dev/null +++ b/controllers/whisperCPP.cc @@ -0,0 +1,60 @@ +#include "whisperCPP.h" +#include "whisper.h" +// #include "llama.h" + +#include "utils/nitro_utils.h" + +// Add definition of your processing function here +void whisperCPP::transcription( + const HttpRequestPtr &req, + std::function &&callback) +{ + MultiPartParser partParser; + Json::Value jsonResp; + + if (partParser.parse(req) != 0 || partParser.getFiles().size() != 1) + { + auto resp = HttpResponse::newHttpResponse(); + resp->setBody("Must have exactly one file"); + resp->setStatusCode(k403Forbidden); + callback(resp); + return; + } + auto &file = partParser.getFiles()[0]; + const auto &formFields = partParser.getParameters(); + std::string model = formFields.at("model"); + file.save(); + + jsonResp["text"] = "handling text"; + + auto resp = 
nitro_utils::nitroHttpJsonResponse(jsonResp); + callback(resp); + return; +} + +void whisperCPP::translation( + const HttpRequestPtr &req, + std::function &&callback) +{ + MultiPartParser partParser; + Json::Value jsonResp; + + if (partParser.parse(req) != 0 || partParser.getFiles().size() != 1) + { + auto resp = HttpResponse::newHttpResponse(); + resp->setBody("Must have exactly one file"); + resp->setStatusCode(k403Forbidden); + callback(resp); + return; + } + auto &file = partParser.getFiles()[0]; + const auto &formFields = partParser.getParameters(); + std::string model = formFields.at("model"); + file.save(); + + jsonResp["text"] = "handling text"; + + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + callback(resp); + return; +} \ No newline at end of file diff --git a/controllers/whisperCPP.h b/controllers/whisperCPP.h new file mode 100644 index 000000000..a9d4ec006 --- /dev/null +++ b/controllers/whisperCPP.h @@ -0,0 +1,27 @@ +#pragma once + +#include +#include "whisper.h" + +using namespace drogon; + +class whisperCPP : public drogon::HttpController +{ +public: + METHOD_LIST_BEGIN + // use METHOD_ADD to add your custom processing function here; + // METHOD_ADD(whisperCPP::get, "/{2}/{1}", Get); // path is /whisperCPP/{arg2}/{arg1} + // METHOD_ADD(whisperCPP::your_method_name, "/{1}/{2}/list", Get); // path is /whisperCPP/{arg1}/{arg2}/list + // ADD_METHOD_TO(whisperCPP::your_method_name, "/absolute/path/{1}/{2}/list", Get); // path is /absolute/path/{arg1}/{arg2}/list + ADD_METHOD_TO(whisperCPP::transcription, "/v1/audio/transcriptions", Post); + ADD_METHOD_TO(whisperCPP::translation, "/v1/audio/translations", Post); + METHOD_LIST_END + // your declaration of processing function maybe like this: + // void get(const HttpRequestPtr& req, std::function &&callback, int p1, std::string p2); + // void your_method_name(const HttpRequestPtr& req, std::function &&callback, double p1, int p2) const; + void transcription(const HttpRequestPtr &req, + 
std::function &&callback); + + void translation(const HttpRequestPtr &req, + std::function &&callback); +}; From 3ec6224227bab45186bc273a96add4a2af0dbe24 Mon Sep 17 00:00:00 2001 From: hiro Date: Sun, 17 Dec 2023 18:15:39 +0700 Subject: [PATCH 13/31] chore: remove ggml.h in llamaCPP.cc --- controllers/llamaCPP.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc index fb6c35b31..1eb0abc91 100644 --- a/controllers/llamaCPP.cc +++ b/controllers/llamaCPP.cc @@ -1,5 +1,4 @@ #include "llamaCPP.h" -#include "ggml.h" #include "llama.h" #include "utils/nitro_utils.h" #include From 79e0081dfe6796acaca24d45263166cc5596ed67 Mon Sep 17 00:00:00 2001 From: hiro-v Date: Wed, 24 Jan 2024 23:53:44 +0700 Subject: [PATCH 14/31] WIP --- controllers/whisperCPP.cc | 1052 +++++- controllers/whisperCPP.h | 196 +- utils/dr_wav.h | 6434 +++++++++++++++++++++++++++++++++++++ 3 files changed, 7631 insertions(+), 51 deletions(-) create mode 100644 utils/dr_wav.h diff --git a/controllers/whisperCPP.cc b/controllers/whisperCPP.cc index 43f1302f5..7a7d0de8a 100644 --- a/controllers/whisperCPP.cc +++ b/controllers/whisperCPP.cc @@ -1,60 +1,1038 @@ #include "whisperCPP.h" -#include "whisper.h" +// #include "whisper.h" // #include "llama.h" -#include "utils/nitro_utils.h" +bool read_wav(const std::string &fname, std::vector &pcmf32, std::vector> &pcmf32s, bool stereo) +{ + drwav wav; + std::vector wav_data; // used for pipe input from stdin -// Add definition of your processing function here -void whisperCPP::transcription( + if (fname == "-") + { + { + uint8_t buf[1024]; + while (true) + { + const size_t n = fread(buf, 1, sizeof(buf), stdin); + if (n == 0) + { + break; + } + wav_data.insert(wav_data.end(), buf, buf + n); + } + } + + if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) + { + fprintf(stderr, "error: failed to open WAV file from stdin\n"); + return false; + } + + fprintf(stderr, "%s: read %zu bytes from 
stdin\n", __func__, wav_data.size()); + } + else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) + { + fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str()); + return false; + } + + if (wav.channels != 1 && wav.channels != 2) + { + fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, fname.c_str()); + return false; + } + + if (stereo && wav.channels != 2) + { + fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization\n", __func__, fname.c_str()); + return false; + } + + if (wav.sampleRate != COMMON_SAMPLE_RATE) + { + fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", __func__, fname.c_str(), COMMON_SAMPLE_RATE / 1000); + return false; + } + + if (wav.bitsPerSample != 16) + { + fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", __func__, fname.c_str()); + return false; + } + + const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size() / (wav.channels * wav.bitsPerSample / 8); + + std::vector pcm16; + pcm16.resize(n * wav.channels); + drwav_read_pcm_frames_s16(&wav, n, pcm16.data()); + drwav_uninit(&wav); + + // convert to mono, float + pcmf32.resize(n); + if (wav.channels == 1) + { + for (uint64_t i = 0; i < n; i++) + { + pcmf32[i] = float(pcm16[i]) / 32768.0f; + } + } + else + { + for (uint64_t i = 0; i < n; i++) + { + pcmf32[i] = float(pcm16[2 * i] + pcm16[2 * i + 1]) / 65536.0f; + } + } + + if (stereo) + { + // convert to stereo, float + pcmf32s.resize(2); + + pcmf32s[0].resize(n); + pcmf32s[1].resize(n); + for (uint64_t i = 0; i < n; i++) + { + pcmf32s[0][i] = float(pcm16[2 * i]) / 32768.0f; + pcmf32s[1][i] = float(pcm16[2 * i + 1]) / 32768.0f; + } + } + + return true; +} + +std::string output_str(struct whisper_context *ctx, const whisper_params ¶ms, std::vector> pcmf32s) +{ + std::stringstream result; + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) + { + const char *text = whisper_full_get_segment_text(ctx, i); + std::string 
speaker = ""; + + if (params.diarize && pcmf32s.size() == 2) + { + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + speaker = estimate_diarization_speaker(pcmf32s, t0, t1); + } + + result << speaker << text << "\n"; + } + return result.str(); +} + +std::string estimate_diarization_speaker(std::vector> pcmf32s, int64_t t0, int64_t t1, bool id_only) +{ + std::string speaker = ""; + const int64_t n_samples = pcmf32s[0].size(); + + const int64_t is0 = timestamp_to_sample(t0, n_samples); + const int64_t is1 = timestamp_to_sample(t1, n_samples); + + double energy0 = 0.0f; + double energy1 = 0.0f; + + for (int64_t j = is0; j < is1; j++) + { + energy0 += fabs(pcmf32s[0][j]); + energy1 += fabs(pcmf32s[1][j]); + } + + if (energy0 > 1.1 * energy1) + { + speaker = "0"; + } + else if (energy1 > 1.1 * energy0) + { + speaker = "1"; + } + else + { + speaker = "?"; + } + + // printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, speaker = %s\n", is0, is1, energy0, energy1, speaker.c_str()); + + if (!id_only) + { + speaker.insert(0, "(speaker "); + speaker.append(")"); + } + + return speaker; +} + +// 500 -> 00:05.000 +// 6000 -> 01:00.000 +std::string to_timestamp(int64_t t, bool comma) +{ + int64_t msec = t * 10; + int64_t hr = msec / (1000 * 60 * 60); + msec = msec - hr * (1000 * 60 * 60); + int64_t min = msec / (1000 * 60); + msec = msec - min * (1000 * 60); + int64_t sec = msec / 1000; + msec = msec - sec * 1000; + + char buf[32]; + snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int)hr, (int)min, (int)sec, comma ? 
"," : ".", (int)msec); + + return std::string(buf); +} + +int timestamp_to_sample(int64_t t, int n_samples) +{ + return std::max(0, std::min((int)n_samples - 1, (int)((t * WHISPER_SAMPLE_RATE) / 100))); +} + +bool is_file_exist(const char *fileName) +{ + std::ifstream infile(fileName); + return infile.good(); +} + +void whisper_print_usage(int /*argc*/, char **argv, const whisper_params ¶ms) +{ + fprintf(stderr, "\n"); + fprintf(stderr, "usage: %s [options] \n", argv[0]); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); + fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); + fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); + fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); + fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); + fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); + fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? 
"true" : "false"); + fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); + fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); + fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); + fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); + fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); + // fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); + fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); + fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); + fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); + fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); + fprintf(stderr, " -pr, --print-realtime [%-7s] print output in realtime\n", params.print_realtime ? "true" : "false"); + fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); + fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? 
"true" : "false"); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); + fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false"); + fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); + fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", params.ffmpeg_converter ? "true" : "false"); + fprintf(stderr, "\n"); +} + +bool whisper_params_parse(int argc, char **argv, whisper_params ¶ms) +{ + for (int i = 1; i < argc; i++) + { + std::string arg = argv[i]; + + if (arg == "-h" || arg == "--help") + { + whisper_print_usage(argc, argv, params); + exit(0); + } + else if (arg == "-t" || arg == "--threads") + { + params.n_threads = std::stoi(argv[++i]); + } + else if (arg == "-p" || arg == "--processors") + { + params.n_processors = std::stoi(argv[++i]); + } + else if (arg == "-ot" || arg == "--offset-t") + { + params.offset_t_ms = std::stoi(argv[++i]); + } + else if (arg == "-on" || arg == "--offset-n") + { + params.offset_n = std::stoi(argv[++i]); + } + else if (arg == "-d" || arg == "--duration") + { + params.duration_ms = std::stoi(argv[++i]); + } + else if (arg == "-mc" || arg == "--max-context") + { + params.max_context = std::stoi(argv[++i]); + } + else if (arg == "-ml" || arg == "--max-len") + { + params.max_len = std::stoi(argv[++i]); + } + else if (arg == "-bo" || arg == "--best-of") + { + params.best_of = std::stoi(argv[++i]); + } + else if (arg == "-bs" || arg == "--beam-size") + { + params.beam_size = std::stoi(argv[++i]); + } + else if (arg == "-wt" || arg == "--word-thold") + { + params.word_thold = 
std::stof(argv[++i]); + } + else if (arg == "-et" || arg == "--entropy-thold") + { + params.entropy_thold = std::stof(argv[++i]); + } + else if (arg == "-lpt" || arg == "--logprob-thold") + { + params.logprob_thold = std::stof(argv[++i]); + } + // else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } + else if (arg == "-debug" || arg == "--debug-mode") + { + params.debug_mode = true; + } + else if (arg == "-tr" || arg == "--translate") + { + params.translate = true; + } + else if (arg == "-di" || arg == "--diarize") + { + params.diarize = true; + } + else if (arg == "-tdrz" || arg == "--tinydiarize") + { + params.tinydiarize = true; + } + else if (arg == "-sow" || arg == "--split-on-word") + { + params.split_on_word = true; + } + else if (arg == "-nf" || arg == "--no-fallback") + { + params.no_fallback = true; + } + else if (arg == "-fp" || arg == "--font-path") + { + params.font_path = argv[++i]; + } + else if (arg == "-ps" || arg == "--print-special") + { + params.print_special = true; + } + else if (arg == "-pc" || arg == "--print-colors") + { + params.print_colors = true; + } + else if (arg == "-pr" || arg == "--print-realtime") + { + params.print_realtime = true; + } + else if (arg == "-pp" || arg == "--print-progress") + { + params.print_progress = true; + } + else if (arg == "-nt" || arg == "--no-timestamps") + { + params.no_timestamps = true; + } + else if (arg == "-l" || arg == "--language") + { + params.language = argv[++i]; + } + else if (arg == "-dl" || arg == "--detect-language") + { + params.detect_language = true; + } + else if (arg == "--prompt") + { + params.prompt = argv[++i]; + } + else if (arg == "-m" || arg == "--model") + { + params.model = argv[++i]; + } + else if (arg == "-oved" || arg == "--ov-e-device") + { + params.openvino_encode_device = argv[++i]; + } + else if (arg == "-ng" || arg == "--no-gpu") + { + params.use_gpu = false; + } + else if (arg == "--convert") + { + params.ffmpeg_converter = true; + } + else + { + 
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + whisper_print_usage(argc, argv, params); + exit(0); + } + } + + return true; +} + +void check_ffmpeg_availibility() +{ + int result = system("ffmpeg -version"); + + if (result == 0) + { + std::cout << "ffmpeg is available." << std::endl; + } + else + { + // ffmpeg is not available + std::cout << "ffmpeg is not found. Please ensure that ffmpeg is installed "; + std::cout << "and that its executable is included in your system's PATH. "; + exit(0); + } +} + +bool convert_to_wav(const std::string &temp_filename, std::string &error_resp) +{ + std::ostringstream cmd_stream; + std::string converted_filename_temp = temp_filename + "_temp.wav"; + cmd_stream << "ffmpeg -i \"" << temp_filename << "\" -ar 16000 -ac 1 -c:a pcm_s16le \"" << converted_filename_temp << "\" 2>&1"; + std::string cmd = cmd_stream.str(); + + int status = std::system(cmd.c_str()); + if (status != 0) + { + error_resp = "{\"error\":\"FFmpeg conversion failed.\"}"; + return false; + } + + // Remove the original file + if (remove(temp_filename.c_str()) != 0) + { + error_resp = "{\"error\":\"Failed to remove the original file.\"}"; + return false; + } + + // Rename the temporary file to match the original filename + if (rename(converted_filename_temp.c_str(), temp_filename.c_str()) != 0) + { + error_resp = "{\"error\":\"Failed to rename the temporary file.\"}"; + return false; + } + return true; +} + +void whisper_print_progress_callback(struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void *user_data) +{ + int progress_step = ((whisper_print_user_data *)user_data)->params->progress_step; + int *progress_prev = &(((whisper_print_user_data *)user_data)->progress_prev); + if (progress >= *progress_prev + progress_step) + { + *progress_prev += progress_step; + fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress); + } +} + +void whisper_print_segment_callback(struct whisper_context *ctx, struct 
whisper_state * /*state*/, int n_new, void *user_data) +{ + const auto ¶ms = *((whisper_print_user_data *)user_data)->params; + const auto &pcmf32s = *((whisper_print_user_data *)user_data)->pcmf32s; + + const int n_segments = whisper_full_n_segments(ctx); + + std::string speaker = ""; + + int64_t t0 = 0; + int64_t t1 = 0; + + // print the last n_new segments + const int s0 = n_segments - n_new; + + if (s0 == 0) + { + printf("\n"); + } + + for (int i = s0; i < n_segments; i++) + { + if (!params.no_timestamps || params.diarize) + { + t0 = whisper_full_get_segment_t0(ctx, i); + t1 = whisper_full_get_segment_t1(ctx, i); + } + + if (!params.no_timestamps) + { + printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); + } + + if (params.diarize && pcmf32s.size() == 2) + { + speaker = estimate_diarization_speaker(pcmf32s, t0, t1); + } + + if (params.print_colors) + { + for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) + { + if (params.print_special == false) + { + const whisper_token id = whisper_full_get_token_id(ctx, i, j); + if (id >= whisper_token_eot(ctx)) + { + continue; + } + } + + const char *text = whisper_full_get_token_text(ctx, i, j); + const float p = whisper_full_get_token_p(ctx, i, j); + + const int col = std::max(0, std::min((int)k_colors.size() - 1, (int)(std::pow(p, 3) * float(k_colors.size())))); + + printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m"); + } + } + else + { + const char *text = whisper_full_get_segment_text(ctx, i); + + printf("%s%s", speaker.c_str(), text); + } + + if (params.tinydiarize) + { + if (whisper_full_get_segment_speaker_turn_next(ctx, i)) + { + printf("%s", params.tdrz_speaker_turn.c_str()); + } + } + + // with timestamps or speakers: each segment on new line + if (!params.no_timestamps || params.diarize) + { + printf("\n"); + } + fflush(stdout); + } +} + +bool parse_str_to_bool(const std::string &s) +{ + if (s == "true" || s == "1" || s == "yes" || s == "y") + { + return true; 
+ } + return false; +} + +bool whisper_server_context::load_model(std::string &model_path) +{ + whisper_mutex.lock(); + + // clean up + whisper_free(ctx); + + // whisper init + ctx = whisper_init_from_file_with_params(model_path.c_str(), cparams); + + // TODO perhaps load prior model here instead of exit + if (ctx == nullptr) + { + whisper_mutex.unlock(); + return false; + } + + // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured + whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr); + + // check if the model is in the file system + whisper_mutex.unlock(); + return true; +} + +std::string whisper_server_context::inference(std::string &input_file_path, std::string language, std::string prompt, + std::string response_format, float temperature, bool translate) +{ + // acquire whisper model mutex lock + whisper_mutex.lock(); + + // audio arrays + std::vector pcmf32; // mono-channel F32 PCM + std::vector> pcmf32s; // stereo-channel F32 PCM + + // if file is not wav, convert to wav + if (params.ffmpeg_converter) + { + std::string error_resp = "Failed to execute ffmpeg command converting " + input_file_path + " to wav"; + const bool is_converted = convert_to_wav(input_file_path, error_resp); + if (!is_converted) + { + whisper_mutex.unlock(); + LOG_ERROR << error_resp; + throw std::runtime_error(error_resp); + } + } + + // read wav content into pcmf32 + if (!read_wav(input_file_path, pcmf32, pcmf32s, params.diarize)) + { + std::string error_resp = "Failed to read WAV file " + input_file_path; + LOG_ERROR << error_resp; + whisper_mutex.unlock(); + throw std::runtime_error(error_resp); + } + + printf("Successfully loaded %s\n", input_file_path.c_str()); + + params.translate = translate; + params.language = language; + if (!whisper_is_multilingual(ctx)) + { + if (params.language != "en" || params.translate) + { + params.language = "en"; + params.translate = false; + LOG_WARN 
<< "Model " << model_id << " is not multilingual, ignoring language and translation options"; + } + } + if (params.detect_language) + { + params.language = "auto"; + } + + // print some processing info + std::string processing_info = "Model " + model_id + "processing " + input_file_path + " (" + std::to_string(pcmf32.size()) + " samples, " + std::to_string(float(pcmf32.size()) / WHISPER_SAMPLE_RATE) + " sec), " + std::to_string(params.n_threads) + " threads, " + std::to_string(params.n_processors) + " processors, lang = " + params.language + ", task = " + (params.translate ? "translate" : "transcribe") + ", " + (params.tinydiarize ? "tdrz = 1, " : "") + (params.no_timestamps ? "timestamps = 0" : "timestamps = 1"); + LOG_INFO << processing_info; + + // run the inference + { + std::string msg = "Running whisper.cpp inference of model " + model_id + " on " + input_file_path; + LOG_INFO << msg; + whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + + wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY; + + wparams.print_realtime = false; + wparams.print_progress = params.print_progress; + wparams.print_timestamps = !params.no_timestamps; + wparams.print_special = params.print_special; + wparams.translate = params.translate; + wparams.language = params.language.c_str(); + wparams.detect_language = params.detect_language; + wparams.n_threads = params.n_threads; + wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; + wparams.offset_ms = params.offset_t_ms; + wparams.duration_ms = params.duration_ms; + + wparams.thold_pt = params.word_thold; + wparams.max_len = params.max_len == 0 ? 
60 : params.max_len; + wparams.split_on_word = params.split_on_word; + + wparams.speed_up = params.speed_up; + wparams.debug_mode = params.debug_mode; + + wparams.tdrz_enable = params.tinydiarize; // [TDRZ] + + wparams.initial_prompt = prompt.c_str(); + + wparams.greedy.best_of = params.best_of; + wparams.beam_search.beam_size = params.beam_size; + + wparams.temperature = temperature; + wparams.temperature_inc = params.temperature_inc; + wparams.entropy_thold = params.entropy_thold; + wparams.logprob_thold = params.logprob_thold; + + wparams.no_timestamps = params.no_timestamps; + + whisper_print_user_data user_data = {¶ms, &pcmf32s, 0}; + + // this callback is called on each new segment + if (params.print_realtime) + { + wparams.new_segment_callback = whisper_print_segment_callback; + wparams.new_segment_callback_user_data = &user_data; + } + + if (wparams.print_progress) + { + wparams.progress_callback = whisper_print_progress_callback; + wparams.progress_callback_user_data = &user_data; + } + + // examples for abort mechanism + // in examples below, we do not abort the processing, but we could if the flag is set to true + + // the callback is called before every encoder run - if it returns false, the processing is aborted + { + static bool is_aborted = false; // NOTE: this should be atomic to avoid data race + + wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void *user_data) + { + bool is_aborted = *(bool *)user_data; + return !is_aborted; + }; + wparams.encoder_begin_callback_user_data = &is_aborted; + } + + // the callback is called before every computation - if it returns true, the computation is aborted + { + static bool is_aborted = false; // NOTE: this should be atomic to avoid data race + + wparams.abort_callback = [](void *user_data) + { + bool is_aborted = *(bool *)user_data; + return is_aborted; + }; + wparams.abort_callback_user_data = &is_aborted; + } + + if (whisper_full_parallel(ctx, wparams, 
pcmf32.data(), pcmf32.size(), params.n_processors) != 0) + { + std::string error_resp = "Failed to process audio"; + LOG_ERROR << error_resp; + whisper_mutex.unlock(); + throw std::runtime_error(error_resp); + } + } + + // return results to user + std::string result; + if (response_format == text_format) + { + result = output_str(ctx, params, pcmf32s); + } + else if (response_format == srt_format) + { + std::stringstream ss; + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) + { + const char *text = whisper_full_get_segment_text(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + std::string speaker = ""; + + if (params.diarize && pcmf32s.size() == 2) + { + speaker = estimate_diarization_speaker(pcmf32s, t0, t1); + } + + ss << i + 1 + params.offset_n << "\n"; + ss << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n"; + ss << speaker << text << "\n\n"; + } + result = ss.str(); + } + else if (params.response_format == vtt_format) + { + std::stringstream ss; + + ss << "WEBVTT\n\n"; + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) + { + const char *text = whisper_full_get_segment_text(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + std::string speaker = ""; + + if (params.diarize && pcmf32s.size() == 2) + { + speaker = estimate_diarization_speaker(pcmf32s, t0, t1, true); + speaker.insert(0, ""); + } + + ss << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n"; + ss << speaker << text << "\n\n"; + } + result = ss.str(); + } + // TODO add more output formats + else + { + std::string results = output_str(ctx, params, pcmf32s); + json jres = json{ + {"text", results}}; + result = jres.dump(-1, ' ', false, json::error_handler_t::replace); + } + + // reset params to thier defaults + params = default_params; + + 
// return whisper model mutex lock + whisper_mutex.unlock(); + + return result; +} + +whisper_server_context::~whisper_server_context() +{ + if (ctx) + { + whisper_print_timings(ctx); + whisper_free(ctx); + ctx = nullptr; + } +} + +std::optional whisperCPP::parse_model_id( + const std::shared_ptr &jsonBody, + const std::function &callback) +{ + if (!jsonBody->isMember("model_id")) { + LOG_INFO << "No model_id found in request body"; + Json::Value jsonResp; + jsonResp["message"] = "No model_id found in request body"; + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setStatusCode(k400BadRequest); + callback(resp); + return std::nullopt; // Signal that an error occurred + } + + return (*jsonBody)["model_id"].asString(); +} + +void whisperCPP::load_model( const HttpRequestPtr &req, std::function &&callback) { - MultiPartParser partParser; - Json::Value jsonResp; + const auto jsonBody = req->getJsonObject(); + auto optional_model_id = parse_model_id(jsonBody, callback); + if (!optional_model_id) { + return; + } + std::string model_id = *optional_model_id; + + // Check if model is already loaded + if (whispers.find(model_id) != whispers.end()) + { + std::string error_msg = "Model " + model_id + "has not been loaded, please load that model into nitro"; + LOG_INFO << error_msg; + Json::Value jsonResp; + jsonResp["message"] = error_msg; + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setStatusCode(k409Conflict); + callback(resp); + return; + } - if (partParser.parse(req) != 0 || partParser.getFiles().size() != 1) - { - auto resp = HttpResponse::newHttpResponse(); - resp->setBody("Must have exactly one file"); - resp->setStatusCode(k403Forbidden); + // Model not loaded, load it + // Parse model path from request + std::string model_path = (*jsonBody)["model_path"].asString(); + if (!is_file_exist(model_path.c_str())) + { + std::string error_msg = "Model " + model_path + " not found"; + LOG_INFO << error_msg; + Json::Value jsonResp; + 
jsonResp["message"] = error_msg; + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setStatusCode(k404NotFound); + callback(resp); + return; + } + + whisper_server_context whisper; + bool model_loaded = whisper.load_model(model_path); + // If model failed to load, return a 500 error + if (!model_loaded) + { + whisper.~whisper_server_context(); + std::string error_msg = "Failed to load model " + model_path; + LOG_INFO << error_msg; + Json::Value jsonResp; + jsonResp["message"] = error_msg; + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setStatusCode(k500InternalServerError); + callback(resp); + return; + } + + // Model loaded successfully, add it to the map of loaded models + // and return a 200 response + // whispers.emplace(model_id, std::move(whisper)); + // whispers[model_id] = std::move(whisper); + whispers[model_id] = whisper; + Json::Value jsonResp; + std::string success_msg = "Model " + model_id + " loaded successfully"; + jsonResp["message"] = success_msg; + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setStatusCode(k200OK); callback(resp); return; - } - auto &file = partParser.getFiles()[0]; - const auto &formFields = partParser.getParameters(); - std::string model = formFields.at("model"); - file.save(); - jsonResp["text"] = "handling text"; +} + +void whisperCPP::unload_model( + const HttpRequestPtr &req, + std::function &&callback) +{ + const auto &jsonBody = req->getJsonObject(); + auto optional_model_id = parse_model_id(jsonBody, callback); + if (!optional_model_id) { + return; + } + std::string model_id = *optional_model_id; + + // If model is not loaded, return a 404 error + if (whispers.find(model_id) == whispers.end()) + { + std::string error_msg = "Model " + model_id + " has not been loaded, please load that model into nitro"; + LOG_INFO << error_msg; + Json::Value jsonResp; + jsonResp["message"] = error_msg; + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + 
resp->setStatusCode(k404NotFound); + callback(resp); + return; + } - auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); - callback(resp); - return; + // Model loaded, unload it + whispers[model_id].~whisper_server_context(); + whispers.erase(model_id); + + // Return a 200 response + Json::Value jsonResp; + std::string success_msg = "Model " + model_id + " unloaded successfully"; + LOG_INFO << success_msg; + jsonResp["message"] = success_msg; + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setStatusCode(k200OK); + callback(resp); + return; } -void whisperCPP::translation( +void whisperCPP::model_status( const HttpRequestPtr &req, std::function &&callback) { - MultiPartParser partParser; - Json::Value jsonResp; + // Return a list of all loaded models + Json::Value jsonResp; + Json::Value models; + for (auto const &model : whispers) + { + models.append(model.first); + } + jsonResp["models"] = models; + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setStatusCode(k200OK); + callback(resp); + return; +} - if (partParser.parse(req) != 0 || partParser.getFiles().size() != 1) - { - auto resp = HttpResponse::newHttpResponse(); - resp->setBody("Must have exactly one file"); - resp->setStatusCode(k403Forbidden); +void whisperCPP::transcription_impl( + const HttpRequestPtr &req, + std::function &&callback, + bool translate) +{ + MultiPartParser partParser; + Json::Value jsonResp; + if (partParser.parse(req) != 0 || partParser.getFiles().size() != 1) + { + auto resp = HttpResponse::newHttpResponse(); + resp->setBody("Must have exactly one file"); + resp->setStatusCode(k403Forbidden); + callback(resp); + return; + } + auto &file = partParser.getFiles()[0]; + const auto &formFields = partParser.getParameters(); + + // Check if model_id are present in the request. 
If not, return a 400 error + if (formFields.find("model_id") == formFields.end()) + { + LOG_INFO << "No model_id found in request body"; + Json::Value jsonResp; + jsonResp["message"] = "No model_id found in request body"; + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + std::string model_id = formFields.at("model_id"); + + // Parse all other optional parameters from the request + std::string language = formFields.find("language") != formFields.end() ? formFields.at("language") : "en"; + std::string prompt = formFields.find("prompt") != formFields.end() ? formFields.at("prompt") : ""; + std::string response_format = formFields.find("response_format") != formFields.end() ? formFields.at("response_format") : json_format; + float temperature = formFields.find("temperature") != formFields.end() ? std::stof(formFields.at("temperature")) : 0; + + // Check if model is loaded. If not, return a 404 error + if (whispers.find(model_id) == whispers.end()) + { + std::string error_msg = "Model " + model_id + " has not been loaded, please load that model into nitro"; + LOG_INFO << error_msg; + Json::Value jsonResp; + jsonResp["message"] = error_msg; + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setStatusCode(k404NotFound); + callback(resp); + return; + } + + // Save input file to temp location + std::string temp_file_path = std::filesystem::temp_directory_path().string() + "/" + std::to_string(std::chrono::system_clock::now().time_since_epoch().count()) + ".wav"; + file.save(temp_file_path); + + + // Run inference + std::string result; + try { + result = whispers[model_id].inference(temp_file_path, language, prompt, response_format, temperature, translate); + } catch (const std::exception &e) { + std::remove(temp_file_path.c_str()); + Json::Value jsonResp; + jsonResp["message"] = e.what(); + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + 
resp->setStatusCode(k500InternalServerError); + callback(resp); + return; + } + + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setBody(result); + resp->setStatusCode(k200OK); + // Set content type based on response format + if (response_format == json_format) + { + resp->addHeader("Content-Type", "application/json"); + } + else if (response_format == text_format) + { + resp->addHeader("Content-Type", "text/html"); + } + else if (response_format == srt_format) + { + resp->addHeader("Content-Type", "application/x-subrip"); + } + else if (response_format == vtt_format) + { + resp->addHeader("Content-Type", "text/vtt"); + } + std::remove(temp_file_path.c_str()); callback(resp); return; - } - auto &file = partParser.getFiles()[0]; - const auto &formFields = partParser.getParameters(); - std::string model = formFields.at("model"); - file.save(); +} - jsonResp["text"] = "handling text"; - auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); - callback(resp); - return; +void whisperCPP::transcription( + const HttpRequestPtr &req, + std::function &&callback) +{ + return transcription_impl(req, std::move(callback), false); +} + + +void whisperCPP::translation( + const HttpRequestPtr &req, + std::function &&callback) +{ + return transcription_impl(req, std::move(callback), true); } \ No newline at end of file diff --git a/controllers/whisperCPP.h b/controllers/whisperCPP.h index a9d4ec006..3eaeba30a 100644 --- a/controllers/whisperCPP.h +++ b/controllers/whisperCPP.h @@ -1,27 +1,195 @@ #pragma once +#include +#include #include +#include #include "whisper.h" +#define DR_WAV_IMPLEMENTATION +#include "utils/dr_wav.h" + +#include "utils/nitro_utils.h" +#include "utils/json.hpp" + +using json = nlohmann::ordered_json; + +// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9] +// Lowest is red, middle is yellow, highest is green. 
+const std::vector k_colors = { + "\033[38;5;196m", + "\033[38;5;202m", + "\033[38;5;208m", + "\033[38;5;214m", + "\033[38;5;220m", + "\033[38;5;226m", + "\033[38;5;190m", + "\033[38;5;154m", + "\033[38;5;118m", + "\033[38;5;82m", +}; + +// output formats +const std::string json_format = "json"; +const std::string text_format = "text"; +const std::string srt_format = "srt"; +const std::string vjson_format = "verbose_json"; +const std::string vtt_format = "vtt"; + +struct whisper_params +{ + int32_t n_threads = std::min(4, (int32_t)std::thread::hardware_concurrency()); + int32_t n_processors = 1; + int32_t offset_t_ms = 0; + int32_t offset_n = 0; + int32_t duration_ms = 0; + int32_t progress_step = 5; + int32_t max_context = -1; + int32_t max_len = 0; + int32_t best_of = 2; + int32_t beam_size = -1; + + float word_thold = 0.01f; + float entropy_thold = 2.40f; + float logprob_thold = -1.00f; + float temperature = 0.00f; + float temperature_inc = 0.20f; + + bool speed_up = false; + bool debug_mode = false; + bool translate = false; + bool detect_language = false; + bool diarize = false; + bool tinydiarize = false; + bool split_on_word = false; + bool no_fallback = false; + bool print_special = false; + bool print_colors = false; + bool print_realtime = false; + bool print_progress = false; + bool no_timestamps = false; + bool use_gpu = true; + bool ffmpeg_converter = false; + + std::string language = "en"; + std::string prompt = ""; + std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; + std::string model = "models/ggml-base.en.bin"; + + std::string response_format = json_format; + + // [TDRZ] speaker turn string + std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line + + std::string openvino_encode_device = "CPU"; +}; + +struct whisper_print_user_data +{ + const whisper_params *params; + + const std::vector> *pcmf32s; + int progress_prev; +}; + +#define COMMON_SAMPLE_RATE 16000 + +// Read WAV audio file and 
store the PCM data into pcmf32 +// The sample rate of the audio must be equal to COMMON_SAMPLE_RATE +// If stereo flag is set and the audio has 2 channels, the pcmf32s will contain 2 channel PCM +bool read_wav( + const std::string &fname, + std::vector &pcmf32, + std::vector> &pcmf32s, + bool stereo); + +std::string output_str(struct whisper_context *ctx, const whisper_params ¶ms, std::vector> pcmf32s); + +std::string estimate_diarization_speaker(std::vector> pcmf32s, int64_t t0, int64_t t1, bool id_only = false); + +// 500 -> 00:05.000 +// 6000 -> 01:00.000 +std::string to_timestamp(int64_t t, bool comma = false); + +int timestamp_to_sample(int64_t t, int n_samples); + +bool is_file_exist(const char *fileName); + +void whisper_print_usage(int /*argc*/, char **argv, const whisper_params ¶ms); + +bool whisper_params_parse(int argc, char **argv, whisper_params ¶ms); + +void check_ffmpeg_availibility(); + +bool convert_to_wav(const std::string &temp_filename, std::string &error_resp); + +void whisper_print_progress_callback(struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void *user_data); + +void whisper_print_segment_callback(struct whisper_context *ctx, struct whisper_state * /*state*/, int n_new, void *user_data); + +bool parse_str_to_bool(const std::string &s); + +struct whisper_server_context +{ + whisper_params params; + // store default params so we can reset after each inference request + whisper_params default_params = params; + std::mutex whisper_mutex; + std::string model_id; + + struct whisper_context_params cparams; + struct whisper_context *ctx = nullptr; + + bool load_model(std::string &model_path); + + std::string inference(std::string &input_file_path, std::string languague, std::string prompt, + std::string response_format, float temperature, bool translate); + + ~whisper_server_context(); +}; + using namespace drogon; class whisperCPP : public drogon::HttpController { public: - METHOD_LIST_BEGIN - // use METHOD_ADD 
to add your custom processing function here; - // METHOD_ADD(whisperCPP::get, "/{2}/{1}", Get); // path is /whisperCPP/{arg2}/{arg1} - // METHOD_ADD(whisperCPP::your_method_name, "/{1}/{2}/list", Get); // path is /whisperCPP/{arg1}/{arg2}/list - // ADD_METHOD_TO(whisperCPP::your_method_name, "/absolute/path/{1}/{2}/list", Get); // path is /absolute/path/{arg1}/{arg2}/list - ADD_METHOD_TO(whisperCPP::transcription, "/v1/audio/transcriptions", Post); - ADD_METHOD_TO(whisperCPP::translation, "/v1/audio/translations", Post); - METHOD_LIST_END - // your declaration of processing function maybe like this: - // void get(const HttpRequestPtr& req, std::function &&callback, int p1, std::string p2); - // void your_method_name(const HttpRequestPtr& req, std::function &&callback, double p1, int p2) const; - void transcription(const HttpRequestPtr &req, + METHOD_LIST_BEGIN + + METHOD_ADD(whisperCPP::load_model, "load_model", Post); + METHOD_ADD(whisperCPP::unload_model, "unload_model", Post); + METHOD_ADD(whisperCPP::model_status, "model_status", Get); + + ADD_METHOD_TO(whisperCPP::transcription, "/v1/audio/transcriptions", Post); + ADD_METHOD_TO(whisperCPP::translation, "/v1/audio/translations", Post); + + METHOD_LIST_END + + whisperCPP() { + whisper_print_system_info(); + } + + void load_model(const HttpRequestPtr &req, + std::function &&callback); + + void unload_model(const HttpRequestPtr &req, + std::function &&callback); + + void model_status(const HttpRequestPtr &req, + std::function &&callback); + + void transcription(const HttpRequestPtr &req, + std::function &&callback); + + void translation(const HttpRequestPtr &req, std::function &&callback); - void translation(const HttpRequestPtr &req, - std::function &&callback); +private: + std::unordered_map whispers; + + std::optionalparse_model_id(const std::shared_ptr &jsonBody, + const std::function &callback); + + void transcription_impl(const HttpRequestPtr &req, + std::function &&callback, + bool translate); }; diff 
--git a/utils/dr_wav.h b/utils/dr_wav.h new file mode 100644 index 000000000..fd3e95b34 --- /dev/null +++ b/utils/dr_wav.h @@ -0,0 +1,6434 @@ +/* +WAV audio loader and writer. Choice of public domain or MIT-0. See license statements at the end of this file. +dr_wav - v0.12.16 - 2020-12-02 + +David Reid - mackron@gmail.com + +GitHub: https://github.com/mackron/dr_libs +*/ + +/* +RELEASE NOTES - VERSION 0.12 +============================ +Version 0.12 includes breaking changes to custom chunk handling. + + +Changes to Chunk Callback +------------------------- +dr_wav supports the ability to fire a callback when a chunk is encounted (except for WAVE and FMT chunks). The callback has been updated to include both the +container (RIFF or Wave64) and the FMT chunk which contains information about the format of the data in the wave file. + +Previously, there was no direct way to determine the container, and therefore no way to discriminate against the different IDs in the chunk header (RIFF and +Wave64 containers encode chunk ID's differently). The `container` parameter can be used to know which ID to use. + +Sometimes it can be useful to know the data format at the time the chunk callback is fired. A pointer to a `drwav_fmt` object is now passed into the chunk +callback which will give you information about the data format. To determine the sample format, use `drwav_fmt_get_format()`. This will return one of the +`DR_WAVE_FORMAT_*` tokens. +*/ + +/* +Introduction +============ +This is a single file library. To use it, do something like the following in one .c file. + + ```c + #define DR_WAV_IMPLEMENTATION + #include "dr_wav.h" + ``` + +You can then #include this file in other parts of the program as you would with any other header file. Do something like the following to read audio data: + + ```c + drwav wav; + if (!drwav_init_file(&wav, "my_song.wav", NULL)) { + // Error opening WAV file. 
+ } + + drwav_int32* pDecodedInterleavedPCMFrames = malloc(wav.totalPCMFrameCount * wav.channels * sizeof(drwav_int32)); + size_t numberOfSamplesActuallyDecoded = drwav_read_pcm_frames_s32(&wav, wav.totalPCMFrameCount, pDecodedInterleavedPCMFrames); + + ... + + drwav_uninit(&wav); + ``` + +If you just want to quickly open and read the audio data in a single operation you can do something like this: + + ```c + unsigned int channels; + unsigned int sampleRate; + drwav_uint64 totalPCMFrameCount; + float* pSampleData = drwav_open_file_and_read_pcm_frames_f32("my_song.wav", &channels, &sampleRate, &totalPCMFrameCount, NULL); + if (pSampleData == NULL) { + // Error opening and reading WAV file. + } + + ... + + drwav_free(pSampleData); + ``` + +The examples above use versions of the API that convert the audio data to a consistent format (32-bit signed PCM, in this case), but you can still output the +audio data in its internal format (see notes below for supported formats): + + ```c + size_t framesRead = drwav_read_pcm_frames(&wav, wav.totalPCMFrameCount, pDecodedInterleavedPCMFrames); + ``` + +You can also read the raw bytes of audio data, which could be useful if dr_wav does not have native support for a particular data format: + + ```c + size_t bytesRead = drwav_read_raw(&wav, bytesToRead, pRawDataBuffer); + ``` + +dr_wav can also be used to output WAV files. This does not currently support compressed formats. To use this, look at `drwav_init_write()`, +`drwav_init_file_write()`, etc. Use `drwav_write_pcm_frames()` to write samples, or `drwav_write_raw()` to write raw data in the "data" chunk. + + ```c + drwav_data_format format; + format.container = drwav_container_riff; // <-- drwav_container_riff = normal WAV files, drwav_container_w64 = Sony Wave64. + format.format = DR_WAVE_FORMAT_PCM; // <-- Any of the DR_WAVE_FORMAT_* codes. 
+ format.channels = 2; + format.sampleRate = 44100; + format.bitsPerSample = 16; + drwav_init_file_write(&wav, "data/recording.wav", &format, NULL); + + ... + + drwav_uint64 framesWritten = drwav_write_pcm_frames(pWav, frameCount, pSamples); + ``` + +dr_wav has seamless support the Sony Wave64 format. The decoder will automatically detect it and it should Just Work without any manual intervention. + + +Build Options +============= +#define these options before including this file. + +#define DR_WAV_NO_CONVERSION_API + Disables conversion APIs such as `drwav_read_pcm_frames_f32()` and `drwav_s16_to_f32()`. + +#define DR_WAV_NO_STDIO + Disables APIs that initialize a decoder from a file such as `drwav_init_file()`, `drwav_init_file_write()`, etc. + + + +Notes +===== +- Samples are always interleaved. +- The default read function does not do any data conversion. Use `drwav_read_pcm_frames_f32()`, `drwav_read_pcm_frames_s32()` and `drwav_read_pcm_frames_s16()` + to read and convert audio data to 32-bit floating point, signed 32-bit integer and signed 16-bit integer samples respectively. Tested and supported internal + formats include the following: + - Unsigned 8-bit PCM + - Signed 12-bit PCM + - Signed 16-bit PCM + - Signed 24-bit PCM + - Signed 32-bit PCM + - IEEE 32-bit floating point + - IEEE 64-bit floating point + - A-law and u-law + - Microsoft ADPCM + - IMA ADPCM (DVI, format code 0x11) +- dr_wav will try to read the WAV file as best it can, even if it's not strictly conformant to the WAV format. +*/ + +#ifndef dr_wav_h +#define dr_wav_h + +#ifdef __cplusplus +extern "C" { +#endif + +#define DRWAV_STRINGIFY(x) #x +#define DRWAV_XSTRINGIFY(x) DRWAV_STRINGIFY(x) + +#define DRWAV_VERSION_MAJOR 0 +#define DRWAV_VERSION_MINOR 12 +#define DRWAV_VERSION_REVISION 16 +#define DRWAV_VERSION_STRING DRWAV_XSTRINGIFY(DRWAV_VERSION_MAJOR) "." DRWAV_XSTRINGIFY(DRWAV_VERSION_MINOR) "." DRWAV_XSTRINGIFY(DRWAV_VERSION_REVISION) + +#include /* For size_t. */ + +/* Sized types. 
*/ +typedef signed char drwav_int8; +typedef unsigned char drwav_uint8; +typedef signed short drwav_int16; +typedef unsigned short drwav_uint16; +typedef signed int drwav_int32; +typedef unsigned int drwav_uint32; +#if defined(_MSC_VER) + typedef signed __int64 drwav_int64; + typedef unsigned __int64 drwav_uint64; +#else + #if defined(__clang__) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6))) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wlong-long" + #if defined(__clang__) + #pragma GCC diagnostic ignored "-Wc++11-long-long" + #endif + #endif + typedef signed long long drwav_int64; + typedef unsigned long long drwav_uint64; + #if defined(__clang__) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6))) + #pragma GCC diagnostic pop + #endif +#endif +#if defined(__LP64__) || defined(_WIN64) || (defined(__x86_64__) && !defined(__ILP32__)) || defined(_M_X64) || defined(__ia64) || defined (_M_IA64) || defined(__aarch64__) || defined(__powerpc64__) + typedef drwav_uint64 drwav_uintptr; +#else + typedef drwav_uint32 drwav_uintptr; +#endif +typedef drwav_uint8 drwav_bool8; +typedef drwav_uint32 drwav_bool32; +#define DRWAV_TRUE 1 +#define DRWAV_FALSE 0 + +#if !defined(DRWAV_API) + #if defined(DRWAV_DLL) + #if defined(_WIN32) + #define DRWAV_DLL_IMPORT __declspec(dllimport) + #define DRWAV_DLL_EXPORT __declspec(dllexport) + #define DRWAV_DLL_PRIVATE static + #else + #if defined(__GNUC__) && __GNUC__ >= 4 + #define DRWAV_DLL_IMPORT __attribute__((visibility("default"))) + #define DRWAV_DLL_EXPORT __attribute__((visibility("default"))) + #define DRWAV_DLL_PRIVATE __attribute__((visibility("hidden"))) + #else + #define DRWAV_DLL_IMPORT + #define DRWAV_DLL_EXPORT + #define DRWAV_DLL_PRIVATE static + #endif + #endif + + #if defined(DR_WAV_IMPLEMENTATION) || defined(DRWAV_IMPLEMENTATION) + #define DRWAV_API DRWAV_DLL_EXPORT + #else + #define DRWAV_API DRWAV_DLL_IMPORT + #endif + #define 
DRWAV_PRIVATE DRWAV_DLL_PRIVATE + #else + #define DRWAV_API extern + #define DRWAV_PRIVATE static + #endif +#endif + +typedef drwav_int32 drwav_result; +#define DRWAV_SUCCESS 0 +#define DRWAV_ERROR -1 /* A generic error. */ +#define DRWAV_INVALID_ARGS -2 +#define DRWAV_INVALID_OPERATION -3 +#define DRWAV_OUT_OF_MEMORY -4 +#define DRWAV_OUT_OF_RANGE -5 +#define DRWAV_ACCESS_DENIED -6 +#define DRWAV_DOES_NOT_EXIST -7 +#define DRWAV_ALREADY_EXISTS -8 +#define DRWAV_TOO_MANY_OPEN_FILES -9 +#define DRWAV_INVALID_FILE -10 +#define DRWAV_TOO_BIG -11 +#define DRWAV_PATH_TOO_LONG -12 +#define DRWAV_NAME_TOO_LONG -13 +#define DRWAV_NOT_DIRECTORY -14 +#define DRWAV_IS_DIRECTORY -15 +#define DRWAV_DIRECTORY_NOT_EMPTY -16 +#define DRWAV_END_OF_FILE -17 +#define DRWAV_NO_SPACE -18 +#define DRWAV_BUSY -19 +#define DRWAV_IO_ERROR -20 +#define DRWAV_INTERRUPT -21 +#define DRWAV_UNAVAILABLE -22 +#define DRWAV_ALREADY_IN_USE -23 +#define DRWAV_BAD_ADDRESS -24 +#define DRWAV_BAD_SEEK -25 +#define DRWAV_BAD_PIPE -26 +#define DRWAV_DEADLOCK -27 +#define DRWAV_TOO_MANY_LINKS -28 +#define DRWAV_NOT_IMPLEMENTED -29 +#define DRWAV_NO_MESSAGE -30 +#define DRWAV_BAD_MESSAGE -31 +#define DRWAV_NO_DATA_AVAILABLE -32 +#define DRWAV_INVALID_DATA -33 +#define DRWAV_TIMEOUT -34 +#define DRWAV_NO_NETWORK -35 +#define DRWAV_NOT_UNIQUE -36 +#define DRWAV_NOT_SOCKET -37 +#define DRWAV_NO_ADDRESS -38 +#define DRWAV_BAD_PROTOCOL -39 +#define DRWAV_PROTOCOL_UNAVAILABLE -40 +#define DRWAV_PROTOCOL_NOT_SUPPORTED -41 +#define DRWAV_PROTOCOL_FAMILY_NOT_SUPPORTED -42 +#define DRWAV_ADDRESS_FAMILY_NOT_SUPPORTED -43 +#define DRWAV_SOCKET_NOT_SUPPORTED -44 +#define DRWAV_CONNECTION_RESET -45 +#define DRWAV_ALREADY_CONNECTED -46 +#define DRWAV_NOT_CONNECTED -47 +#define DRWAV_CONNECTION_REFUSED -48 +#define DRWAV_NO_HOST -49 +#define DRWAV_IN_PROGRESS -50 +#define DRWAV_CANCELLED -51 +#define DRWAV_MEMORY_ALREADY_MAPPED -52 +#define DRWAV_AT_END -53 + +/* Common data formats. 
*/ +#define DR_WAVE_FORMAT_PCM 0x1 +#define DR_WAVE_FORMAT_ADPCM 0x2 +#define DR_WAVE_FORMAT_IEEE_FLOAT 0x3 +#define DR_WAVE_FORMAT_ALAW 0x6 +#define DR_WAVE_FORMAT_MULAW 0x7 +#define DR_WAVE_FORMAT_DVI_ADPCM 0x11 +#define DR_WAVE_FORMAT_EXTENSIBLE 0xFFFE + +/* Constants. */ +#ifndef DRWAV_MAX_SMPL_LOOPS +#define DRWAV_MAX_SMPL_LOOPS 1 +#endif + +/* Flags to pass into drwav_init_ex(), etc. */ +#define DRWAV_SEQUENTIAL 0x00000001 + +DRWAV_API void drwav_version(drwav_uint32* pMajor, drwav_uint32* pMinor, drwav_uint32* pRevision); +DRWAV_API const char* drwav_version_string(void); + +typedef enum +{ + drwav_seek_origin_start, + drwav_seek_origin_current +} drwav_seek_origin; + +typedef enum +{ + drwav_container_riff, + drwav_container_w64, + drwav_container_rf64 +} drwav_container; + +typedef struct +{ + union + { + drwav_uint8 fourcc[4]; + drwav_uint8 guid[16]; + } id; + + /* The size in bytes of the chunk. */ + drwav_uint64 sizeInBytes; + + /* + RIFF = 2 byte alignment. + W64 = 8 byte alignment. + */ + unsigned int paddingSize; +} drwav_chunk_header; + +typedef struct +{ + /* + The format tag exactly as specified in the wave file's "fmt" chunk. This can be used by applications + that require support for data formats not natively supported by dr_wav. + */ + drwav_uint16 formatTag; + + /* The number of channels making up the audio data. When this is set to 1 it is mono, 2 is stereo, etc. */ + drwav_uint16 channels; + + /* The sample rate. Usually set to something like 44100. */ + drwav_uint32 sampleRate; + + /* Average bytes per second. You probably don't need this, but it's left here for informational purposes. */ + drwav_uint32 avgBytesPerSec; + + /* Block align. This is equal to the number of channels * bytes per sample. */ + drwav_uint16 blockAlign; + + /* Bits per sample. */ + drwav_uint16 bitsPerSample; + + /* The size of the extended data. Only used internally for validation, but left here for informational purposes. 
*/ + drwav_uint16 extendedSize; + + /* + The number of valid bits per sample. When is equal to WAVE_FORMAT_EXTENSIBLE, + is always rounded up to the nearest multiple of 8. This variable contains information about exactly how + many bits are valid per sample. Mainly used for informational purposes. + */ + drwav_uint16 validBitsPerSample; + + /* The channel mask. Not used at the moment. */ + drwav_uint32 channelMask; + + /* The sub-format, exactly as specified by the wave file. */ + drwav_uint8 subFormat[16]; +} drwav_fmt; + +DRWAV_API drwav_uint16 drwav_fmt_get_format(const drwav_fmt* pFMT); + + +/* +Callback for when data is read. Return value is the number of bytes actually read. + +pUserData [in] The user data that was passed to drwav_init() and family. +pBufferOut [out] The output buffer. +bytesToRead [in] The number of bytes to read. + +Returns the number of bytes actually read. + +A return value of less than bytesToRead indicates the end of the stream. Do _not_ return from this callback until +either the entire bytesToRead is filled or you have reached the end of the stream. +*/ +typedef size_t (* drwav_read_proc)(void* pUserData, void* pBufferOut, size_t bytesToRead); + +/* +Callback for when data is written. Returns value is the number of bytes actually written. + +pUserData [in] The user data that was passed to drwav_init_write() and family. +pData [out] A pointer to the data to write. +bytesToWrite [in] The number of bytes to write. + +Returns the number of bytes actually written. + +If the return value differs from bytesToWrite, it indicates an error. +*/ +typedef size_t (* drwav_write_proc)(void* pUserData, const void* pData, size_t bytesToWrite); + +/* +Callback for when data needs to be seeked. + +pUserData [in] The user data that was passed to drwav_init() and family. +offset [in] The number of bytes to move, relative to the origin. Will never be negative. +origin [in] The origin of the seek - the current position or the start of the stream. 
+ +Returns whether or not the seek was successful. + +Whether or not it is relative to the beginning or current position is determined by the "origin" parameter which will be either drwav_seek_origin_start or +drwav_seek_origin_current. +*/ +typedef drwav_bool32 (* drwav_seek_proc)(void* pUserData, int offset, drwav_seek_origin origin); + +/* +Callback for when drwav_init_ex() finds a chunk. + +pChunkUserData [in] The user data that was passed to the pChunkUserData parameter of drwav_init_ex() and family. +onRead [in] A pointer to the function to call when reading. +onSeek [in] A pointer to the function to call when seeking. +pReadSeekUserData [in] The user data that was passed to the pReadSeekUserData parameter of drwav_init_ex() and family. +pChunkHeader [in] A pointer to an object containing basic header information about the chunk. Use this to identify the chunk. +container [in] Whether or not the WAV file is a RIFF or Wave64 container. If you're unsure of the difference, assume RIFF. +pFMT [in] A pointer to the object containing the contents of the "fmt" chunk. + +Returns the number of bytes read + seeked. + +To read data from the chunk, call onRead(), passing in pReadSeekUserData as the first parameter. Do the same for seeking with onSeek(). The return value must +be the total number of bytes you have read _plus_ seeked. + +Use the `container` argument to discriminate the fields in `pChunkHeader->id`. If the container is `drwav_container_riff` or `drwav_container_rf64` you should +use `id.fourcc`, otherwise you should use `id.guid`. + +The `pFMT` parameter can be used to determine the data format of the wave file. Use `drwav_fmt_get_format()` to get the sample format, which will be one of the +`DR_WAVE_FORMAT_*` identifiers. + +The read pointer will be sitting on the first byte after the chunk's header. You must not attempt to read beyond the boundary of the chunk. 
+*/ +typedef drwav_uint64 (* drwav_chunk_proc)(void* pChunkUserData, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pReadSeekUserData, const drwav_chunk_header* pChunkHeader, drwav_container container, const drwav_fmt* pFMT); + +typedef struct +{ + void* pUserData; + void* (* onMalloc)(size_t sz, void* pUserData); + void* (* onRealloc)(void* p, size_t sz, void* pUserData); + void (* onFree)(void* p, void* pUserData); +} drwav_allocation_callbacks; + +/* Structure for internal use. Only used for loaders opened with drwav_init_memory(). */ +typedef struct +{ + const drwav_uint8* data; + size_t dataSize; + size_t currentReadPos; +} drwav__memory_stream; + +/* Structure for internal use. Only used for writers opened with drwav_init_memory_write(). */ +typedef struct +{ + void** ppData; + size_t* pDataSize; + size_t dataSize; + size_t dataCapacity; + size_t currentWritePos; +} drwav__memory_stream_write; + +typedef struct +{ + drwav_container container; /* RIFF, W64. */ + drwav_uint32 format; /* DR_WAVE_FORMAT_* */ + drwav_uint32 channels; + drwav_uint32 sampleRate; + drwav_uint32 bitsPerSample; +} drwav_data_format; + + +/* See the following for details on the 'smpl' chunk: https://sites.google.com/site/musicgapi/technical-documents/wav-file-format#smpl */ +typedef struct +{ + drwav_uint32 cuePointId; + drwav_uint32 type; + drwav_uint32 start; + drwav_uint32 end; + drwav_uint32 fraction; + drwav_uint32 playCount; +} drwav_smpl_loop; + + typedef struct +{ + drwav_uint32 manufacturer; + drwav_uint32 product; + drwav_uint32 samplePeriod; + drwav_uint32 midiUnityNotes; + drwav_uint32 midiPitchFraction; + drwav_uint32 smpteFormat; + drwav_uint32 smpteOffset; + drwav_uint32 numSampleLoops; + drwav_uint32 samplerData; + drwav_smpl_loop loops[DRWAV_MAX_SMPL_LOOPS]; +} drwav_smpl; + +typedef struct +{ + /* A pointer to the function to call when more data is needed. */ + drwav_read_proc onRead; + + /* A pointer to the function to call when data needs to be written. 
Only used when the drwav object is opened in write mode. */ + drwav_write_proc onWrite; + + /* A pointer to the function to call when the wav file needs to be seeked. */ + drwav_seek_proc onSeek; + + /* The user data to pass to callbacks. */ + void* pUserData; + + /* Allocation callbacks. */ + drwav_allocation_callbacks allocationCallbacks; + + + /* Whether or not the WAV file is formatted as a standard RIFF file or W64. */ + drwav_container container; + + + /* Structure containing format information exactly as specified by the wav file. */ + drwav_fmt fmt; + + /* The sample rate. Will be set to something like 44100. */ + drwav_uint32 sampleRate; + + /* The number of channels. This will be set to 1 for monaural streams, 2 for stereo, etc. */ + drwav_uint16 channels; + + /* The bits per sample. Will be set to something like 16, 24, etc. */ + drwav_uint16 bitsPerSample; + + /* Equal to fmt.formatTag, or the value specified by fmt.subFormat if fmt.formatTag is equal to 65534 (WAVE_FORMAT_EXTENSIBLE). */ + drwav_uint16 translatedFormatTag; + + /* The total number of PCM frames making up the audio data. */ + drwav_uint64 totalPCMFrameCount; + + + /* The size in bytes of the data chunk. */ + drwav_uint64 dataChunkDataSize; + + /* The position in the stream of the first byte of the data chunk. This is used for seeking. */ + drwav_uint64 dataChunkDataPos; + + /* The number of bytes remaining in the data chunk. */ + drwav_uint64 bytesRemaining; + + + /* + Only used in sequential write mode. Keeps track of the desired size of the "data" chunk at the point of initialization time. Always + set to 0 for non-sequential writes and when the drwav object is opened in read mode. Used for validation. + */ + drwav_uint64 dataChunkDataSizeTargetWrite; + + /* Keeps track of whether or not the wav writer was initialized in sequential mode. */ + drwav_bool32 isSequentialWrite; + + + /* smpl chunk. 
*/ + drwav_smpl smpl; + + + /* A hack to avoid a DRWAV_MALLOC() when opening a decoder with drwav_init_memory(). */ + drwav__memory_stream memoryStream; + drwav__memory_stream_write memoryStreamWrite; + + /* Generic data for compressed formats. This data is shared across all block-compressed formats. */ + struct + { + drwav_uint64 iCurrentPCMFrame; /* The index of the next PCM frame that will be read by drwav_read_*(). This is used with "totalPCMFrameCount" to ensure we don't read excess samples at the end of the last block. */ + } compressed; + + /* Microsoft ADPCM specific data. */ + struct + { + drwav_uint32 bytesRemainingInBlock; + drwav_uint16 predictor[2]; + drwav_int32 delta[2]; + drwav_int32 cachedFrames[4]; /* Samples are stored in this cache during decoding. */ + drwav_uint32 cachedFrameCount; + drwav_int32 prevFrames[2][2]; /* The previous 2 samples for each channel (2 channels at most). */ + } msadpcm; + + /* IMA ADPCM specific data. */ + struct + { + drwav_uint32 bytesRemainingInBlock; + drwav_int32 predictor[2]; + drwav_int32 stepIndex[2]; + drwav_int32 cachedFrames[16]; /* Samples are stored in this cache during decoding. */ + drwav_uint32 cachedFrameCount; + } ima; +} drwav; + + +/* +Initializes a pre-allocated drwav object for reading. + +pWav [out] A pointer to the drwav object being initialized. +onRead [in] The function to call when data needs to be read from the client. +onSeek [in] The function to call when the read position of the client data needs to move. +onChunk [in, optional] The function to call when a chunk is enumerated at initialized time. +pUserData, pReadSeekUserData [in, optional] A pointer to application defined data that will be passed to onRead and onSeek. +pChunkUserData [in, optional] A pointer to application defined data that will be passed to onChunk. +flags [in, optional] A set of flags for controlling how things are loaded. + +Returns true if successful; false otherwise. + +Close the loader with drwav_uninit(). 
+ +This is the lowest level function for initializing a WAV file. You can also use drwav_init_file() and drwav_init_memory() +to open the stream from a file or from a block of memory respectively. + +Possible values for flags: + DRWAV_SEQUENTIAL: Never perform a backwards seek while loading. This disables the chunk callback and will cause this function + to return as soon as the data chunk is found. Any chunks after the data chunk will be ignored. + +drwav_init() is equivalent to "drwav_init_ex(pWav, onRead, onSeek, NULL, pUserData, NULL, 0);". + +The onChunk callback is not called for the WAVE or FMT chunks. The contents of the FMT chunk can be read from pWav->fmt +after the function returns. + +See also: drwav_init_file(), drwav_init_memory(), drwav_uninit() +*/ +DRWAV_API drwav_bool32 drwav_init(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_ex(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, drwav_chunk_proc onChunk, void* pReadSeekUserData, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks); + +/* +Initializes a pre-allocated drwav object for writing. + +onWrite [in] The function to call when data needs to be written. +onSeek [in] The function to call when the write position needs to move. +pUserData [in, optional] A pointer to application defined data that will be passed to onWrite and onSeek. + +Returns true if successful; false otherwise. + +Close the writer with drwav_uninit(). + +This is the lowest level function for initializing a WAV file. You can also use drwav_init_file_write() and drwav_init_memory_write() +to open the stream from a file or from a block of memory respectively. + +If the total sample count is known, you can use drwav_init_write_sequential(). 
This avoids the need for dr_wav to perform +a post-processing step for storing the total sample count and the size of the data chunk which requires a backwards seek. + +See also: drwav_init_file_write(), drwav_init_memory_write(), drwav_uninit() +*/ +DRWAV_API drwav_bool32 drwav_init_write(drwav* pWav, const drwav_data_format* pFormat, drwav_write_proc onWrite, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_write_sequential(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_write_sequential_pcm_frames(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks); + +/* +Utility function to determine the target size of the entire data to be written (including all headers and chunks). + +Returns the target size in bytes. + +Useful if the application needs to know the size to allocate. + +Only writing to the RIFF chunk and one data chunk is currently supported. + +See also: drwav_init_write(), drwav_init_file_write(), drwav_init_memory_write() +*/ +DRWAV_API drwav_uint64 drwav_target_write_size_bytes(const drwav_data_format* pFormat, drwav_uint64 totalSampleCount); + +/* +Uninitializes the given drwav object. + +Use this only for objects initialized with drwav_init*() functions (drwav_init(), drwav_init_ex(), drwav_init_write(), drwav_init_write_sequential()). +*/ +DRWAV_API drwav_result drwav_uninit(drwav* pWav); + + +/* +Reads raw audio data. + +This is the lowest level function for reading audio data. It simply reads the given number of +bytes of the raw internal sample data. 
+ +Consider using drwav_read_pcm_frames_s16(), drwav_read_pcm_frames_s32() or drwav_read_pcm_frames_f32() for +reading sample data in a consistent format. + +pBufferOut can be NULL in which case a seek will be performed. + +Returns the number of bytes actually read. +*/ +DRWAV_API size_t drwav_read_raw(drwav* pWav, size_t bytesToRead, void* pBufferOut); + +/* +Reads up to the specified number of PCM frames from the WAV file. + +The output data will be in the file's internal format, converted to native-endian byte order. Use +drwav_read_pcm_frames_s16/f32/s32() to read data in a specific format. + +If the return value is less than it means the end of the file has been reached or +you have requested more PCM frames than can possibly fit in the output buffer. + +This function will only work when sample data is of a fixed size and uncompressed. If you are +using a compressed format consider using drwav_read_raw() or drwav_read_pcm_frames_s16/s32/f32(). + +pBufferOut can be NULL in which case a seek will be performed. +*/ +DRWAV_API drwav_uint64 drwav_read_pcm_frames(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut); +DRWAV_API drwav_uint64 drwav_read_pcm_frames_le(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut); +DRWAV_API drwav_uint64 drwav_read_pcm_frames_be(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut); + +/* +Seeks to the given PCM frame. + +Returns true if successful; false otherwise. +*/ +DRWAV_API drwav_bool32 drwav_seek_to_pcm_frame(drwav* pWav, drwav_uint64 targetFrameIndex); + + +/* +Writes raw audio data. + +Returns the number of bytes actually written. If this differs from bytesToWrite, it indicates an error. +*/ +DRWAV_API size_t drwav_write_raw(drwav* pWav, size_t bytesToWrite, const void* pData); + +/* +Writes PCM frames. + +Returns the number of PCM frames written. + +Input samples need to be in native-endian byte order. On big-endian architectures the input data will be converted to +little-endian. 
Use drwav_write_raw() to write raw audio data without performing any conversion. +*/ +DRWAV_API drwav_uint64 drwav_write_pcm_frames(drwav* pWav, drwav_uint64 framesToWrite, const void* pData); +DRWAV_API drwav_uint64 drwav_write_pcm_frames_le(drwav* pWav, drwav_uint64 framesToWrite, const void* pData); +DRWAV_API drwav_uint64 drwav_write_pcm_frames_be(drwav* pWav, drwav_uint64 framesToWrite, const void* pData); + + +/* Conversion Utilities */ +#ifndef DR_WAV_NO_CONVERSION_API + +/* +Reads a chunk of audio data and converts it to signed 16-bit PCM samples. + +pBufferOut can be NULL in which case a seek will be performed. + +Returns the number of PCM frames actually read. + +If the return value is less than it means the end of the file has been reached. +*/ +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut); +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16le(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut); +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16be(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut); + +/* Low-level function for converting unsigned 8-bit PCM samples to signed 16-bit PCM samples. */ +DRWAV_API void drwav_u8_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting signed 24-bit PCM samples to signed 16-bit PCM samples. */ +DRWAV_API void drwav_s24_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting signed 32-bit PCM samples to signed 16-bit PCM samples. */ +DRWAV_API void drwav_s32_to_s16(drwav_int16* pOut, const drwav_int32* pIn, size_t sampleCount); + +/* Low-level function for converting IEEE 32-bit floating point samples to signed 16-bit PCM samples. 
*/ +DRWAV_API void drwav_f32_to_s16(drwav_int16* pOut, const float* pIn, size_t sampleCount); + +/* Low-level function for converting IEEE 64-bit floating point samples to signed 16-bit PCM samples. */ +DRWAV_API void drwav_f64_to_s16(drwav_int16* pOut, const double* pIn, size_t sampleCount); + +/* Low-level function for converting A-law samples to signed 16-bit PCM samples. */ +DRWAV_API void drwav_alaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting u-law samples to signed 16-bit PCM samples. */ +DRWAV_API void drwav_mulaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount); + + +/* +Reads a chunk of audio data and converts it to IEEE 32-bit floating point samples. + +pBufferOut can be NULL in which case a seek will be performed. + +Returns the number of PCM frames actually read. + +If the return value is less than it means the end of the file has been reached. +*/ +DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut); +DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32le(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut); +DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32be(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut); + +/* Low-level function for converting unsigned 8-bit PCM samples to IEEE 32-bit floating point samples. */ +DRWAV_API void drwav_u8_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting signed 16-bit PCM samples to IEEE 32-bit floating point samples. */ +DRWAV_API void drwav_s16_to_f32(float* pOut, const drwav_int16* pIn, size_t sampleCount); + +/* Low-level function for converting signed 24-bit PCM samples to IEEE 32-bit floating point samples. */ +DRWAV_API void drwav_s24_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting signed 32-bit PCM samples to IEEE 32-bit floating point samples. 
*/ +DRWAV_API void drwav_s32_to_f32(float* pOut, const drwav_int32* pIn, size_t sampleCount); + +/* Low-level function for converting IEEE 64-bit floating point samples to IEEE 32-bit floating point samples. */ +DRWAV_API void drwav_f64_to_f32(float* pOut, const double* pIn, size_t sampleCount); + +/* Low-level function for converting A-law samples to IEEE 32-bit floating point samples. */ +DRWAV_API void drwav_alaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting u-law samples to IEEE 32-bit floating point samples. */ +DRWAV_API void drwav_mulaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount); + + +/* +Reads a chunk of audio data and converts it to signed 32-bit PCM samples. + +pBufferOut can be NULL in which case a seek will be performed. + +Returns the number of PCM frames actually read. + +If the return value is less than it means the end of the file has been reached. +*/ +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut); +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32le(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut); +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32be(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut); + +/* Low-level function for converting unsigned 8-bit PCM samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_u8_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting signed 16-bit PCM samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_s16_to_s32(drwav_int32* pOut, const drwav_int16* pIn, size_t sampleCount); + +/* Low-level function for converting signed 24-bit PCM samples to signed 32-bit PCM samples. 
*/ +DRWAV_API void drwav_s24_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting IEEE 32-bit floating point samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_f32_to_s32(drwav_int32* pOut, const float* pIn, size_t sampleCount); + +/* Low-level function for converting IEEE 64-bit floating point samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_f64_to_s32(drwav_int32* pOut, const double* pIn, size_t sampleCount); + +/* Low-level function for converting A-law samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_alaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting u-law samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_mulaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount); + +#endif /* DR_WAV_NO_CONVERSION_API */ + + +/* High-Level Convenience Helpers */ + +#ifndef DR_WAV_NO_STDIO +/* +Helper for initializing a wave file for reading using stdio. + +This holds the internal FILE object until drwav_uninit() is called. Keep this in mind if you're caching drwav +objects because the operating system may restrict the number of file handles an application can have open at +any given time. 
+*/ +DRWAV_API drwav_bool32 drwav_init_file(drwav* pWav, const char* filename, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_ex(drwav* pWav, const char* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_w(drwav* pWav, const wchar_t* filename, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_ex_w(drwav* pWav, const wchar_t* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks); + +/* +Helper for initializing a wave file for writing using stdio. + +This holds the internal FILE object until drwav_uninit() is called. Keep this in mind if you're caching drwav +objects because the operating system may restrict the number of file handles an application can have open at +any given time. +*/ +DRWAV_API drwav_bool32 drwav_init_file_write(drwav* pWav, const char* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_write_sequential(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_write_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* 
pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks); +#endif /* DR_WAV_NO_STDIO */ + +/* +Helper for initializing a loader from a pre-allocated memory buffer. + +This does not create a copy of the data. It is up to the application to ensure the buffer remains valid for +the lifetime of the drwav object. + +The buffer should contain the contents of the entire wave file, not just the sample data. +*/ +DRWAV_API drwav_bool32 drwav_init_memory(drwav* pWav, const void* data, size_t dataSize, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_memory_ex(drwav* pWav, const void* data, size_t dataSize, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks); + +/* +Helper for initializing a writer which outputs data to a memory buffer. + +dr_wav will manage the memory allocations, however it is up to the caller to free the data with drwav_free(). + +The buffer will remain allocated even after drwav_uninit() is called. The buffer should not be considered valid +until after drwav_uninit() has been called. 
+*/ +DRWAV_API drwav_bool32 drwav_init_memory_write(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_memory_write_sequential(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_memory_write_sequential_pcm_frames(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks); + + +#ifndef DR_WAV_NO_CONVERSION_API +/* +Opens and reads an entire wav file in a single operation. + +The return value is a heap-allocated buffer containing the audio data. Use drwav_free() to free the buffer. +*/ +DRWAV_API drwav_int16* drwav_open_and_read_pcm_frames_s16(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API float* drwav_open_and_read_pcm_frames_f32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_int32* drwav_open_and_read_pcm_frames_s32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +#ifndef DR_WAV_NO_STDIO +/* +Opens and decodes an entire wav file in a single operation. + +The return value is a heap-allocated buffer containing the audio data. Use drwav_free() to free the buffer. 
+*/ +DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +#endif +/* +Opens and decodes an entire wav file from a block of memory in a single operation. + +The return value is a heap-allocated buffer containing the audio data. Use drwav_free() to free the buffer. 
+*/ +DRWAV_API drwav_int16* drwav_open_memory_and_read_pcm_frames_s16(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API float* drwav_open_memory_and_read_pcm_frames_f32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_int32* drwav_open_memory_and_read_pcm_frames_s32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +#endif + +/* Frees data that was allocated internally by dr_wav. */ +DRWAV_API void drwav_free(void* p, const drwav_allocation_callbacks* pAllocationCallbacks); + +/* Converts bytes from a wav stream to a sized type of native endian. */ +DRWAV_API drwav_uint16 drwav_bytes_to_u16(const drwav_uint8* data); +DRWAV_API drwav_int16 drwav_bytes_to_s16(const drwav_uint8* data); +DRWAV_API drwav_uint32 drwav_bytes_to_u32(const drwav_uint8* data); +DRWAV_API drwav_int32 drwav_bytes_to_s32(const drwav_uint8* data); +DRWAV_API drwav_uint64 drwav_bytes_to_u64(const drwav_uint8* data); +DRWAV_API drwav_int64 drwav_bytes_to_s64(const drwav_uint8* data); + +/* Compares a GUID for the purpose of checking the type of a Wave64 chunk. */ +DRWAV_API drwav_bool32 drwav_guid_equal(const drwav_uint8 a[16], const drwav_uint8 b[16]); + +/* Compares a four-character-code for the purpose of checking the type of a RIFF chunk. 
*/ +DRWAV_API drwav_bool32 drwav_fourcc_equal(const drwav_uint8* a, const char* b); + +#ifdef __cplusplus +} +#endif +#endif /* dr_wav_h */ + + +/************************************************************************************************************************************************************ + ************************************************************************************************************************************************************ + + IMPLEMENTATION + + ************************************************************************************************************************************************************ + ************************************************************************************************************************************************************/ +#if defined(DR_WAV_IMPLEMENTATION) || defined(DRWAV_IMPLEMENTATION) +#ifndef dr_wav_c +#define dr_wav_c + +#include +#include /* For memcpy(), memset() */ +#include /* For INT_MAX */ + +#ifndef DR_WAV_NO_STDIO +#include +#include +#endif + +/* Standard library stuff. */ +#ifndef DRWAV_ASSERT +#include +#define DRWAV_ASSERT(expression) assert(expression) +#endif +#ifndef DRWAV_MALLOC +#define DRWAV_MALLOC(sz) malloc((sz)) +#endif +#ifndef DRWAV_REALLOC +#define DRWAV_REALLOC(p, sz) realloc((p), (sz)) +#endif +#ifndef DRWAV_FREE +#define DRWAV_FREE(p) free((p)) +#endif +#ifndef DRWAV_COPY_MEMORY +#define DRWAV_COPY_MEMORY(dst, src, sz) memcpy((dst), (src), (sz)) +#endif +#ifndef DRWAV_ZERO_MEMORY +#define DRWAV_ZERO_MEMORY(p, sz) memset((p), 0, (sz)) +#endif +#ifndef DRWAV_ZERO_OBJECT +#define DRWAV_ZERO_OBJECT(p) DRWAV_ZERO_MEMORY((p), sizeof(*p)) +#endif + +#define drwav_countof(x) (sizeof(x) / sizeof(x[0])) +#define drwav_align(x, a) ((((x) + (a) - 1) / (a)) * (a)) +#define drwav_min(a, b) (((a) < (b)) ? (a) : (b)) +#define drwav_max(a, b) (((a) > (b)) ? 
(a) : (b)) +#define drwav_clamp(x, lo, hi) (drwav_max((lo), drwav_min((hi), (x)))) + +#define DRWAV_MAX_SIMD_VECTOR_SIZE 64 /* 64 for AVX-512 in the future. */ + +/* CPU architecture. */ +#if defined(__x86_64__) || defined(_M_X64) + #define DRWAV_X64 +#elif defined(__i386) || defined(_M_IX86) + #define DRWAV_X86 +#elif defined(__arm__) || defined(_M_ARM) + #define DRWAV_ARM +#endif + +#ifdef _MSC_VER + #define DRWAV_INLINE __forceinline +#elif defined(__GNUC__) + /* + I've had a bug report where GCC is emitting warnings about functions possibly not being inlineable. This warning happens when + the __attribute__((always_inline)) attribute is defined without an "inline" statement. I think therefore there must be some + case where "__inline__" is not always defined, thus the compiler emitting these warnings. When using -std=c89 or -ansi on the + command line, we cannot use the "inline" keyword and instead need to use "__inline__". In an attempt to work around this issue + I am using "__inline__" only when we're compiling in strict ANSI mode. 
+ */ + #if defined(__STRICT_ANSI__) + #define DRWAV_INLINE __inline__ __attribute__((always_inline)) + #else + #define DRWAV_INLINE inline __attribute__((always_inline)) + #endif +#elif defined(__WATCOMC__) + #define DRWAV_INLINE __inline +#else + #define DRWAV_INLINE +#endif + +#if defined(SIZE_MAX) + #define DRWAV_SIZE_MAX SIZE_MAX +#else + #if defined(_WIN64) || defined(_LP64) || defined(__LP64__) + #define DRWAV_SIZE_MAX ((drwav_uint64)0xFFFFFFFFFFFFFFFF) + #else + #define DRWAV_SIZE_MAX 0xFFFFFFFF + #endif +#endif + +#if defined(_MSC_VER) && _MSC_VER >= 1400 + #define DRWAV_HAS_BYTESWAP16_INTRINSIC + #define DRWAV_HAS_BYTESWAP32_INTRINSIC + #define DRWAV_HAS_BYTESWAP64_INTRINSIC +#elif defined(__clang__) + #if defined(__has_builtin) + #if __has_builtin(__builtin_bswap16) + #define DRWAV_HAS_BYTESWAP16_INTRINSIC + #endif + #if __has_builtin(__builtin_bswap32) + #define DRWAV_HAS_BYTESWAP32_INTRINSIC + #endif + #if __has_builtin(__builtin_bswap64) + #define DRWAV_HAS_BYTESWAP64_INTRINSIC + #endif + #endif +#elif defined(__GNUC__) + #if ((__GNUC__ > 4) || (__GNUC__ == 4 && __GNUC_MINOR__ >= 3)) + #define DRWAV_HAS_BYTESWAP32_INTRINSIC + #define DRWAV_HAS_BYTESWAP64_INTRINSIC + #endif + #if ((__GNUC__ > 4) || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8)) + #define DRWAV_HAS_BYTESWAP16_INTRINSIC + #endif +#endif + +DRWAV_API void drwav_version(drwav_uint32* pMajor, drwav_uint32* pMinor, drwav_uint32* pRevision) +{ + if (pMajor) { + *pMajor = DRWAV_VERSION_MAJOR; + } + + if (pMinor) { + *pMinor = DRWAV_VERSION_MINOR; + } + + if (pRevision) { + *pRevision = DRWAV_VERSION_REVISION; + } +} + +DRWAV_API const char* drwav_version_string(void) +{ + return DRWAV_VERSION_STRING; +} + +/* +These limits are used for basic validation when initializing the decoder. If you exceed these limits, first of all: what on Earth are +you doing?! (Let me know, I'd be curious!) Second, you can adjust these by #define-ing them before the dr_wav implementation. 
+*/ +#ifndef DRWAV_MAX_SAMPLE_RATE +#define DRWAV_MAX_SAMPLE_RATE 384000 +#endif +#ifndef DRWAV_MAX_CHANNELS +#define DRWAV_MAX_CHANNELS 256 +#endif +#ifndef DRWAV_MAX_BITS_PER_SAMPLE +#define DRWAV_MAX_BITS_PER_SAMPLE 64 +#endif + +static const drwav_uint8 drwavGUID_W64_RIFF[16] = {0x72,0x69,0x66,0x66, 0x2E,0x91, 0xCF,0x11, 0xA5,0xD6, 0x28,0xDB,0x04,0xC1,0x00,0x00}; /* 66666972-912E-11CF-A5D6-28DB04C10000 */ +static const drwav_uint8 drwavGUID_W64_WAVE[16] = {0x77,0x61,0x76,0x65, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 65766177-ACF3-11D3-8CD1-00C04F8EDB8A */ +/*static const drwav_uint8 drwavGUID_W64_JUNK[16] = {0x6A,0x75,0x6E,0x6B, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A};*/ /* 6B6E756A-ACF3-11D3-8CD1-00C04F8EDB8A */ +static const drwav_uint8 drwavGUID_W64_FMT [16] = {0x66,0x6D,0x74,0x20, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 20746D66-ACF3-11D3-8CD1-00C04F8EDB8A */ +static const drwav_uint8 drwavGUID_W64_FACT[16] = {0x66,0x61,0x63,0x74, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 74636166-ACF3-11D3-8CD1-00C04F8EDB8A */ +static const drwav_uint8 drwavGUID_W64_DATA[16] = {0x64,0x61,0x74,0x61, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 61746164-ACF3-11D3-8CD1-00C04F8EDB8A */ +static const drwav_uint8 drwavGUID_W64_SMPL[16] = {0x73,0x6D,0x70,0x6C, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 6C706D73-ACF3-11D3-8CD1-00C04F8EDB8A */ + +static DRWAV_INLINE drwav_bool32 drwav__guid_equal(const drwav_uint8 a[16], const drwav_uint8 b[16]) +{ + int i; + for (i = 0; i < 16; i += 1) { + if (a[i] != b[i]) { + return DRWAV_FALSE; + } + } + + return DRWAV_TRUE; +} + +static DRWAV_INLINE drwav_bool32 drwav__fourcc_equal(const drwav_uint8* a, const char* b) +{ + return + a[0] == b[0] && + a[1] == b[1] && + a[2] == b[2] && + a[3] == b[3]; +} + + + +static DRWAV_INLINE int drwav__is_little_endian(void) +{ +#if defined(DRWAV_X86) || 
defined(DRWAV_X64) + return DRWAV_TRUE; +#elif defined(__BYTE_ORDER) && defined(__LITTLE_ENDIAN) && __BYTE_ORDER == __LITTLE_ENDIAN + return DRWAV_TRUE; +#else + int n = 1; + return (*(char*)&n) == 1; +#endif +} + +static DRWAV_INLINE drwav_uint16 drwav__bytes_to_u16(const drwav_uint8* data) +{ + return (data[0] << 0) | (data[1] << 8); +} + +static DRWAV_INLINE drwav_int16 drwav__bytes_to_s16(const drwav_uint8* data) +{ + return (short)drwav__bytes_to_u16(data); +} + +static DRWAV_INLINE drwav_uint32 drwav__bytes_to_u32(const drwav_uint8* data) +{ + return (data[0] << 0) | (data[1] << 8) | (data[2] << 16) | (data[3] << 24); +} + +static DRWAV_INLINE drwav_int32 drwav__bytes_to_s32(const drwav_uint8* data) +{ + return (drwav_int32)drwav__bytes_to_u32(data); +} + +static DRWAV_INLINE drwav_uint64 drwav__bytes_to_u64(const drwav_uint8* data) +{ + return + ((drwav_uint64)data[0] << 0) | ((drwav_uint64)data[1] << 8) | ((drwav_uint64)data[2] << 16) | ((drwav_uint64)data[3] << 24) | + ((drwav_uint64)data[4] << 32) | ((drwav_uint64)data[5] << 40) | ((drwav_uint64)data[6] << 48) | ((drwav_uint64)data[7] << 56); +} + +static DRWAV_INLINE drwav_int64 drwav__bytes_to_s64(const drwav_uint8* data) +{ + return (drwav_int64)drwav__bytes_to_u64(data); +} + +static DRWAV_INLINE void drwav__bytes_to_guid(const drwav_uint8* data, drwav_uint8* guid) +{ + int i; + for (i = 0; i < 16; ++i) { + guid[i] = data[i]; + } +} + + +static DRWAV_INLINE drwav_uint16 drwav__bswap16(drwav_uint16 n) +{ +#ifdef DRWAV_HAS_BYTESWAP16_INTRINSIC + #if defined(_MSC_VER) + return _byteswap_ushort(n); + #elif defined(__GNUC__) || defined(__clang__) + return __builtin_bswap16(n); + #else + #error "This compiler does not support the byte swap intrinsic." 
+ #endif +#else + return ((n & 0xFF00) >> 8) | + ((n & 0x00FF) << 8); +#endif +} + +static DRWAV_INLINE drwav_uint32 drwav__bswap32(drwav_uint32 n) +{ +#ifdef DRWAV_HAS_BYTESWAP32_INTRINSIC + #if defined(_MSC_VER) + return _byteswap_ulong(n); + #elif defined(__GNUC__) || defined(__clang__) + #if defined(DRWAV_ARM) && (defined(__ARM_ARCH) && __ARM_ARCH >= 6) && !defined(DRWAV_64BIT) /* <-- 64-bit inline assembly has not been tested, so disabling for now. */ + /* Inline assembly optimized implementation for ARM. In my testing, GCC does not generate optimized code with __builtin_bswap32(). */ + drwav_uint32 r; + __asm__ __volatile__ ( + #if defined(DRWAV_64BIT) + "rev %w[out], %w[in]" : [out]"=r"(r) : [in]"r"(n) /* <-- This is untested. If someone in the community could test this, that would be appreciated! */ + #else + "rev %[out], %[in]" : [out]"=r"(r) : [in]"r"(n) + #endif + ); + return r; + #else + return __builtin_bswap32(n); + #endif + #else + #error "This compiler does not support the byte swap intrinsic." + #endif +#else + return ((n & 0xFF000000) >> 24) | + ((n & 0x00FF0000) >> 8) | + ((n & 0x0000FF00) << 8) | + ((n & 0x000000FF) << 24); +#endif +} + +static DRWAV_INLINE drwav_uint64 drwav__bswap64(drwav_uint64 n) +{ +#ifdef DRWAV_HAS_BYTESWAP64_INTRINSIC + #if defined(_MSC_VER) + return _byteswap_uint64(n); + #elif defined(__GNUC__) || defined(__clang__) + return __builtin_bswap64(n); + #else + #error "This compiler does not support the byte swap intrinsic." + #endif +#else + /* Weird "<< 32" bitshift is required for C89 because it doesn't support 64-bit constants. Should be optimized out by a good compiler. 
*/ + return ((n & ((drwav_uint64)0xFF000000 << 32)) >> 56) | + ((n & ((drwav_uint64)0x00FF0000 << 32)) >> 40) | + ((n & ((drwav_uint64)0x0000FF00 << 32)) >> 24) | + ((n & ((drwav_uint64)0x000000FF << 32)) >> 8) | + ((n & ((drwav_uint64)0xFF000000 )) << 8) | + ((n & ((drwav_uint64)0x00FF0000 )) << 24) | + ((n & ((drwav_uint64)0x0000FF00 )) << 40) | + ((n & ((drwav_uint64)0x000000FF )) << 56); +#endif +} + + +static DRWAV_INLINE drwav_int16 drwav__bswap_s16(drwav_int16 n) +{ + return (drwav_int16)drwav__bswap16((drwav_uint16)n); +} + +static DRWAV_INLINE void drwav__bswap_samples_s16(drwav_int16* pSamples, drwav_uint64 sampleCount) +{ + drwav_uint64 iSample; + for (iSample = 0; iSample < sampleCount; iSample += 1) { + pSamples[iSample] = drwav__bswap_s16(pSamples[iSample]); + } +} + + +static DRWAV_INLINE void drwav__bswap_s24(drwav_uint8* p) +{ + drwav_uint8 t; + t = p[0]; + p[0] = p[2]; + p[2] = t; +} + +static DRWAV_INLINE void drwav__bswap_samples_s24(drwav_uint8* pSamples, drwav_uint64 sampleCount) +{ + drwav_uint64 iSample; + for (iSample = 0; iSample < sampleCount; iSample += 1) { + drwav_uint8* pSample = pSamples + (iSample*3); + drwav__bswap_s24(pSample); + } +} + + +static DRWAV_INLINE drwav_int32 drwav__bswap_s32(drwav_int32 n) +{ + return (drwav_int32)drwav__bswap32((drwav_uint32)n); +} + +static DRWAV_INLINE void drwav__bswap_samples_s32(drwav_int32* pSamples, drwav_uint64 sampleCount) +{ + drwav_uint64 iSample; + for (iSample = 0; iSample < sampleCount; iSample += 1) { + pSamples[iSample] = drwav__bswap_s32(pSamples[iSample]); + } +} + + +static DRWAV_INLINE float drwav__bswap_f32(float n) +{ + union { + drwav_uint32 i; + float f; + } x; + x.f = n; + x.i = drwav__bswap32(x.i); + + return x.f; +} + +static DRWAV_INLINE void drwav__bswap_samples_f32(float* pSamples, drwav_uint64 sampleCount) +{ + drwav_uint64 iSample; + for (iSample = 0; iSample < sampleCount; iSample += 1) { + pSamples[iSample] = drwav__bswap_f32(pSamples[iSample]); + } +} + + +static 
DRWAV_INLINE double drwav__bswap_f64(double n) +{ + union { + drwav_uint64 i; + double f; + } x; + x.f = n; + x.i = drwav__bswap64(x.i); + + return x.f; +} + +static DRWAV_INLINE void drwav__bswap_samples_f64(double* pSamples, drwav_uint64 sampleCount) +{ + drwav_uint64 iSample; + for (iSample = 0; iSample < sampleCount; iSample += 1) { + pSamples[iSample] = drwav__bswap_f64(pSamples[iSample]); + } +} + + +static DRWAV_INLINE void drwav__bswap_samples_pcm(void* pSamples, drwav_uint64 sampleCount, drwav_uint32 bytesPerSample) +{ + /* Assumes integer PCM. Floating point PCM is done in drwav__bswap_samples_ieee(). */ + switch (bytesPerSample) + { + case 2: /* s16, s12 (loosely packed) */ + { + drwav__bswap_samples_s16((drwav_int16*)pSamples, sampleCount); + } break; + case 3: /* s24 */ + { + drwav__bswap_samples_s24((drwav_uint8*)pSamples, sampleCount); + } break; + case 4: /* s32 */ + { + drwav__bswap_samples_s32((drwav_int32*)pSamples, sampleCount); + } break; + default: + { + /* Unsupported format. */ + DRWAV_ASSERT(DRWAV_FALSE); + } break; + } +} + +static DRWAV_INLINE void drwav__bswap_samples_ieee(void* pSamples, drwav_uint64 sampleCount, drwav_uint32 bytesPerSample) +{ + switch (bytesPerSample) + { + #if 0 /* Contributions welcome for f16 support. */ + case 2: /* f16 */ + { + drwav__bswap_samples_f16((drwav_float16*)pSamples, sampleCount); + } break; + #endif + case 4: /* f32 */ + { + drwav__bswap_samples_f32((float*)pSamples, sampleCount); + } break; + case 8: /* f64 */ + { + drwav__bswap_samples_f64((double*)pSamples, sampleCount); + } break; + default: + { + /* Unsupported format. 
*/ + DRWAV_ASSERT(DRWAV_FALSE); + } break; + } +} + +static DRWAV_INLINE void drwav__bswap_samples(void* pSamples, drwav_uint64 sampleCount, drwav_uint32 bytesPerSample, drwav_uint16 format) +{ + switch (format) + { + case DR_WAVE_FORMAT_PCM: + { + drwav__bswap_samples_pcm(pSamples, sampleCount, bytesPerSample); + } break; + + case DR_WAVE_FORMAT_IEEE_FLOAT: + { + drwav__bswap_samples_ieee(pSamples, sampleCount, bytesPerSample); + } break; + + case DR_WAVE_FORMAT_ALAW: + case DR_WAVE_FORMAT_MULAW: + { + drwav__bswap_samples_s16((drwav_int16*)pSamples, sampleCount); + } break; + + case DR_WAVE_FORMAT_ADPCM: + case DR_WAVE_FORMAT_DVI_ADPCM: + default: + { + /* Unsupported format. */ + DRWAV_ASSERT(DRWAV_FALSE); + } break; + } +} + + +static void* drwav__malloc_default(size_t sz, void* pUserData) +{ + (void)pUserData; + return DRWAV_MALLOC(sz); +} + +static void* drwav__realloc_default(void* p, size_t sz, void* pUserData) +{ + (void)pUserData; + return DRWAV_REALLOC(p, sz); +} + +static void drwav__free_default(void* p, void* pUserData) +{ + (void)pUserData; + DRWAV_FREE(p); +} + + +static void* drwav__malloc_from_callbacks(size_t sz, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pAllocationCallbacks == NULL) { + return NULL; + } + + if (pAllocationCallbacks->onMalloc != NULL) { + return pAllocationCallbacks->onMalloc(sz, pAllocationCallbacks->pUserData); + } + + /* Try using realloc(). */ + if (pAllocationCallbacks->onRealloc != NULL) { + return pAllocationCallbacks->onRealloc(NULL, sz, pAllocationCallbacks->pUserData); + } + + return NULL; +} + +static void* drwav__realloc_from_callbacks(void* p, size_t szNew, size_t szOld, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pAllocationCallbacks == NULL) { + return NULL; + } + + if (pAllocationCallbacks->onRealloc != NULL) { + return pAllocationCallbacks->onRealloc(p, szNew, pAllocationCallbacks->pUserData); + } + + /* Try emulating realloc() in terms of malloc()/free(). 
*/ + if (pAllocationCallbacks->onMalloc != NULL && pAllocationCallbacks->onFree != NULL) { + void* p2; + + p2 = pAllocationCallbacks->onMalloc(szNew, pAllocationCallbacks->pUserData); + if (p2 == NULL) { + return NULL; + } + + if (p != NULL) { + DRWAV_COPY_MEMORY(p2, p, szOld); + pAllocationCallbacks->onFree(p, pAllocationCallbacks->pUserData); + } + + return p2; + } + + return NULL; +} + +static void drwav__free_from_callbacks(void* p, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (p == NULL || pAllocationCallbacks == NULL) { + return; + } + + if (pAllocationCallbacks->onFree != NULL) { + pAllocationCallbacks->onFree(p, pAllocationCallbacks->pUserData); + } +} + + +static drwav_allocation_callbacks drwav_copy_allocation_callbacks_or_defaults(const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pAllocationCallbacks != NULL) { + /* Copy. */ + return *pAllocationCallbacks; + } else { + /* Defaults. */ + drwav_allocation_callbacks allocationCallbacks; + allocationCallbacks.pUserData = NULL; + allocationCallbacks.onMalloc = drwav__malloc_default; + allocationCallbacks.onRealloc = drwav__realloc_default; + allocationCallbacks.onFree = drwav__free_default; + return allocationCallbacks; + } +} + + +static DRWAV_INLINE drwav_bool32 drwav__is_compressed_format_tag(drwav_uint16 formatTag) +{ + return + formatTag == DR_WAVE_FORMAT_ADPCM || + formatTag == DR_WAVE_FORMAT_DVI_ADPCM; +} + +static unsigned int drwav__chunk_padding_size_riff(drwav_uint64 chunkSize) +{ + return (unsigned int)(chunkSize % 2); +} + +static unsigned int drwav__chunk_padding_size_w64(drwav_uint64 chunkSize) +{ + return (unsigned int)(chunkSize % 8); +} + +static drwav_uint64 drwav_read_pcm_frames_s16__msadpcm(drwav* pWav, drwav_uint64 samplesToRead, drwav_int16* pBufferOut); +static drwav_uint64 drwav_read_pcm_frames_s16__ima(drwav* pWav, drwav_uint64 samplesToRead, drwav_int16* pBufferOut); +static drwav_bool32 drwav_init_write__internal(drwav* pWav, const 
drwav_data_format* pFormat, drwav_uint64 totalSampleCount); + +static drwav_result drwav__read_chunk_header(drwav_read_proc onRead, void* pUserData, drwav_container container, drwav_uint64* pRunningBytesReadOut, drwav_chunk_header* pHeaderOut) +{ + if (container == drwav_container_riff || container == drwav_container_rf64) { + drwav_uint8 sizeInBytes[4]; + + if (onRead(pUserData, pHeaderOut->id.fourcc, 4) != 4) { + return DRWAV_AT_END; + } + + if (onRead(pUserData, sizeInBytes, 4) != 4) { + return DRWAV_INVALID_FILE; + } + + pHeaderOut->sizeInBytes = drwav__bytes_to_u32(sizeInBytes); + pHeaderOut->paddingSize = drwav__chunk_padding_size_riff(pHeaderOut->sizeInBytes); + *pRunningBytesReadOut += 8; + } else { + drwav_uint8 sizeInBytes[8]; + + if (onRead(pUserData, pHeaderOut->id.guid, 16) != 16) { + return DRWAV_AT_END; + } + + if (onRead(pUserData, sizeInBytes, 8) != 8) { + return DRWAV_INVALID_FILE; + } + + pHeaderOut->sizeInBytes = drwav__bytes_to_u64(sizeInBytes) - 24; /* <-- Subtract 24 because w64 includes the size of the header. */ + pHeaderOut->paddingSize = drwav__chunk_padding_size_w64(pHeaderOut->sizeInBytes); + *pRunningBytesReadOut += 24; + } + + return DRWAV_SUCCESS; +} + +static drwav_bool32 drwav__seek_forward(drwav_seek_proc onSeek, drwav_uint64 offset, void* pUserData) +{ + drwav_uint64 bytesRemainingToSeek = offset; + while (bytesRemainingToSeek > 0) { + if (bytesRemainingToSeek > 0x7FFFFFFF) { + if (!onSeek(pUserData, 0x7FFFFFFF, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + bytesRemainingToSeek -= 0x7FFFFFFF; + } else { + if (!onSeek(pUserData, (int)bytesRemainingToSeek, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + bytesRemainingToSeek = 0; + } + } + + return DRWAV_TRUE; +} + +static drwav_bool32 drwav__seek_from_start(drwav_seek_proc onSeek, drwav_uint64 offset, void* pUserData) +{ + if (offset <= 0x7FFFFFFF) { + return onSeek(pUserData, (int)offset, drwav_seek_origin_start); + } + + /* Larger than 32-bit seek. 
*/ + if (!onSeek(pUserData, 0x7FFFFFFF, drwav_seek_origin_start)) { + return DRWAV_FALSE; + } + offset -= 0x7FFFFFFF; + + for (;;) { + if (offset <= 0x7FFFFFFF) { + return onSeek(pUserData, (int)offset, drwav_seek_origin_current); + } + + if (!onSeek(pUserData, 0x7FFFFFFF, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + offset -= 0x7FFFFFFF; + } + + /* Should never get here. */ + /*return DRWAV_TRUE; */ +} + + +static drwav_bool32 drwav__read_fmt(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, drwav_container container, drwav_uint64* pRunningBytesReadOut, drwav_fmt* fmtOut) +{ + drwav_chunk_header header; + drwav_uint8 fmt[16]; + + if (drwav__read_chunk_header(onRead, pUserData, container, pRunningBytesReadOut, &header) != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + + /* Skip non-fmt chunks. */ + while (((container == drwav_container_riff || container == drwav_container_rf64) && !drwav__fourcc_equal(header.id.fourcc, "fmt ")) || (container == drwav_container_w64 && !drwav__guid_equal(header.id.guid, drwavGUID_W64_FMT))) { + if (!drwav__seek_forward(onSeek, header.sizeInBytes + header.paddingSize, pUserData)) { + return DRWAV_FALSE; + } + *pRunningBytesReadOut += header.sizeInBytes + header.paddingSize; + + /* Try the next header. */ + if (drwav__read_chunk_header(onRead, pUserData, container, pRunningBytesReadOut, &header) != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + } + + + /* Validation. 
*/ + if (container == drwav_container_riff || container == drwav_container_rf64) { + if (!drwav__fourcc_equal(header.id.fourcc, "fmt ")) { + return DRWAV_FALSE; + } + } else { + if (!drwav__guid_equal(header.id.guid, drwavGUID_W64_FMT)) { + return DRWAV_FALSE; + } + } + + + if (onRead(pUserData, fmt, sizeof(fmt)) != sizeof(fmt)) { + return DRWAV_FALSE; + } + *pRunningBytesReadOut += sizeof(fmt); + + fmtOut->formatTag = drwav__bytes_to_u16(fmt + 0); + fmtOut->channels = drwav__bytes_to_u16(fmt + 2); + fmtOut->sampleRate = drwav__bytes_to_u32(fmt + 4); + fmtOut->avgBytesPerSec = drwav__bytes_to_u32(fmt + 8); + fmtOut->blockAlign = drwav__bytes_to_u16(fmt + 12); + fmtOut->bitsPerSample = drwav__bytes_to_u16(fmt + 14); + + fmtOut->extendedSize = 0; + fmtOut->validBitsPerSample = 0; + fmtOut->channelMask = 0; + memset(fmtOut->subFormat, 0, sizeof(fmtOut->subFormat)); + + if (header.sizeInBytes > 16) { + drwav_uint8 fmt_cbSize[2]; + int bytesReadSoFar = 0; + + if (onRead(pUserData, fmt_cbSize, sizeof(fmt_cbSize)) != sizeof(fmt_cbSize)) { + return DRWAV_FALSE; /* Expecting more data. */ + } + *pRunningBytesReadOut += sizeof(fmt_cbSize); + + bytesReadSoFar = 18; + + fmtOut->extendedSize = drwav__bytes_to_u16(fmt_cbSize); + if (fmtOut->extendedSize > 0) { + /* Simple validation. */ + if (fmtOut->formatTag == DR_WAVE_FORMAT_EXTENSIBLE) { + if (fmtOut->extendedSize != 22) { + return DRWAV_FALSE; + } + } + + if (fmtOut->formatTag == DR_WAVE_FORMAT_EXTENSIBLE) { + drwav_uint8 fmtext[22]; + if (onRead(pUserData, fmtext, fmtOut->extendedSize) != fmtOut->extendedSize) { + return DRWAV_FALSE; /* Expecting more data. 
*/ + } + + fmtOut->validBitsPerSample = drwav__bytes_to_u16(fmtext + 0); + fmtOut->channelMask = drwav__bytes_to_u32(fmtext + 2); + drwav__bytes_to_guid(fmtext + 6, fmtOut->subFormat); + } else { + if (!onSeek(pUserData, fmtOut->extendedSize, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + } + *pRunningBytesReadOut += fmtOut->extendedSize; + + bytesReadSoFar += fmtOut->extendedSize; + } + + /* Seek past any leftover bytes. For w64 the leftover will be defined based on the chunk size. */ + if (!onSeek(pUserData, (int)(header.sizeInBytes - bytesReadSoFar), drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + *pRunningBytesReadOut += (header.sizeInBytes - bytesReadSoFar); + } + + if (header.paddingSize > 0) { + if (!onSeek(pUserData, header.paddingSize, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + *pRunningBytesReadOut += header.paddingSize; + } + + return DRWAV_TRUE; +} + + +static size_t drwav__on_read(drwav_read_proc onRead, void* pUserData, void* pBufferOut, size_t bytesToRead, drwav_uint64* pCursor) +{ + size_t bytesRead; + + DRWAV_ASSERT(onRead != NULL); + DRWAV_ASSERT(pCursor != NULL); + + bytesRead = onRead(pUserData, pBufferOut, bytesToRead); + *pCursor += bytesRead; + return bytesRead; +} + +#if 0 +static drwav_bool32 drwav__on_seek(drwav_seek_proc onSeek, void* pUserData, int offset, drwav_seek_origin origin, drwav_uint64* pCursor) +{ + DRWAV_ASSERT(onSeek != NULL); + DRWAV_ASSERT(pCursor != NULL); + + if (!onSeek(pUserData, offset, origin)) { + return DRWAV_FALSE; + } + + if (origin == drwav_seek_origin_start) { + *pCursor = offset; + } else { + *pCursor += offset; + } + + return DRWAV_TRUE; +} +#endif + + + +static drwav_uint32 drwav_get_bytes_per_pcm_frame(drwav* pWav) +{ + /* + The bytes per frame is a bit ambiguous. It can be either be based on the bits per sample, or the block align. 
The way I'm doing it here + is that if the bits per sample is a multiple of 8, use floor(bitsPerSample*channels/8), otherwise fall back to the block align. + */ + if ((pWav->bitsPerSample & 0x7) == 0) { + /* Bits per sample is a multiple of 8. */ + return (pWav->bitsPerSample * pWav->fmt.channels) >> 3; + } else { + return pWav->fmt.blockAlign; + } +} + +DRWAV_API drwav_uint16 drwav_fmt_get_format(const drwav_fmt* pFMT) +{ + if (pFMT == NULL) { + return 0; + } + + if (pFMT->formatTag != DR_WAVE_FORMAT_EXTENSIBLE) { + return pFMT->formatTag; + } else { + return drwav__bytes_to_u16(pFMT->subFormat); /* Only the first two bytes are required. */ + } +} + +static drwav_bool32 drwav_preinit(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pReadSeekUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pWav == NULL || onRead == NULL || onSeek == NULL) { + return DRWAV_FALSE; + } + + DRWAV_ZERO_MEMORY(pWav, sizeof(*pWav)); + pWav->onRead = onRead; + pWav->onSeek = onSeek; + pWav->pUserData = pReadSeekUserData; + pWav->allocationCallbacks = drwav_copy_allocation_callbacks_or_defaults(pAllocationCallbacks); + + if (pWav->allocationCallbacks.onFree == NULL || (pWav->allocationCallbacks.onMalloc == NULL && pWav->allocationCallbacks.onRealloc == NULL)) { + return DRWAV_FALSE; /* Invalid allocation callbacks. */ + } + + return DRWAV_TRUE; +} + +static drwav_bool32 drwav_init__internal(drwav* pWav, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags) +{ + /* This function assumes drwav_preinit() has been called beforehand. */ + + drwav_uint64 cursor; /* <-- Keeps track of the byte position so we can seek to specific locations. */ + drwav_bool32 sequential; + drwav_uint8 riff[4]; + drwav_fmt fmt; + unsigned short translatedFormatTag; + drwav_bool32 foundDataChunk; + drwav_uint64 dataChunkSize = 0; /* <-- Important! Don't explicitly set this to 0 anywhere else. 
Calculation of the size of the data chunk is performed in different paths depending on the container. */ + drwav_uint64 sampleCountFromFactChunk = 0; /* Same as dataChunkSize - make sure this is the only place this is initialized to 0. */ + drwav_uint64 chunkSize; + + cursor = 0; + sequential = (flags & DRWAV_SEQUENTIAL) != 0; + + /* The first 4 bytes should be the RIFF identifier. */ + if (drwav__on_read(pWav->onRead, pWav->pUserData, riff, sizeof(riff), &cursor) != sizeof(riff)) { + return DRWAV_FALSE; + } + + /* + The first 4 bytes can be used to identify the container. For RIFF files it will start with "RIFF" and for + w64 it will start with "riff". + */ + if (drwav__fourcc_equal(riff, "RIFF")) { + pWav->container = drwav_container_riff; + } else if (drwav__fourcc_equal(riff, "riff")) { + int i; + drwav_uint8 riff2[12]; + + pWav->container = drwav_container_w64; + + /* Check the rest of the GUID for validity. */ + if (drwav__on_read(pWav->onRead, pWav->pUserData, riff2, sizeof(riff2), &cursor) != sizeof(riff2)) { + return DRWAV_FALSE; + } + + for (i = 0; i < 12; ++i) { + if (riff2[i] != drwavGUID_W64_RIFF[i+4]) { + return DRWAV_FALSE; + } + } + } else if (drwav__fourcc_equal(riff, "RF64")) { + pWav->container = drwav_container_rf64; + } else { + return DRWAV_FALSE; /* Unknown or unsupported container. */ + } + + + if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) { + drwav_uint8 chunkSizeBytes[4]; + drwav_uint8 wave[4]; + + /* RIFF/WAVE */ + if (drwav__on_read(pWav->onRead, pWav->pUserData, chunkSizeBytes, sizeof(chunkSizeBytes), &cursor) != sizeof(chunkSizeBytes)) { + return DRWAV_FALSE; + } + + if (pWav->container == drwav_container_riff) { + if (drwav__bytes_to_u32(chunkSizeBytes) < 36) { + return DRWAV_FALSE; /* Chunk size should always be at least 36 bytes. */ + } + } else { + if (drwav__bytes_to_u32(chunkSizeBytes) != 0xFFFFFFFF) { + return DRWAV_FALSE; /* Chunk size should always be set to -1/0xFFFFFFFF for RF64. 
The actual size is retrieved later. */ + } + } + + if (drwav__on_read(pWav->onRead, pWav->pUserData, wave, sizeof(wave), &cursor) != sizeof(wave)) { + return DRWAV_FALSE; + } + + if (!drwav__fourcc_equal(wave, "WAVE")) { + return DRWAV_FALSE; /* Expecting "WAVE". */ + } + } else { + drwav_uint8 chunkSizeBytes[8]; + drwav_uint8 wave[16]; + + /* W64 */ + if (drwav__on_read(pWav->onRead, pWav->pUserData, chunkSizeBytes, sizeof(chunkSizeBytes), &cursor) != sizeof(chunkSizeBytes)) { + return DRWAV_FALSE; + } + + if (drwav__bytes_to_u64(chunkSizeBytes) < 80) { + return DRWAV_FALSE; + } + + if (drwav__on_read(pWav->onRead, pWav->pUserData, wave, sizeof(wave), &cursor) != sizeof(wave)) { + return DRWAV_FALSE; + } + + if (!drwav__guid_equal(wave, drwavGUID_W64_WAVE)) { + return DRWAV_FALSE; + } + } + + + /* For RF64, the "ds64" chunk must come next, before the "fmt " chunk. */ + if (pWav->container == drwav_container_rf64) { + drwav_uint8 sizeBytes[8]; + drwav_uint64 bytesRemainingInChunk; + drwav_chunk_header header; + drwav_result result = drwav__read_chunk_header(pWav->onRead, pWav->pUserData, pWav->container, &cursor, &header); + if (result != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + if (!drwav__fourcc_equal(header.id.fourcc, "ds64")) { + return DRWAV_FALSE; /* Expecting "ds64". */ + } + + bytesRemainingInChunk = header.sizeInBytes + header.paddingSize; + + /* We don't care about the size of the RIFF chunk - skip it. */ + if (!drwav__seek_forward(pWav->onSeek, 8, pWav->pUserData)) { + return DRWAV_FALSE; + } + bytesRemainingInChunk -= 8; + cursor += 8; + + + /* Next 8 bytes is the size of the "data" chunk. */ + if (drwav__on_read(pWav->onRead, pWav->pUserData, sizeBytes, sizeof(sizeBytes), &cursor) != sizeof(sizeBytes)) { + return DRWAV_FALSE; + } + bytesRemainingInChunk -= 8; + dataChunkSize = drwav__bytes_to_u64(sizeBytes); + + + /* Next 8 bytes is the same count which we would usually derived from the FACT chunk if it was available. 
*/ + if (drwav__on_read(pWav->onRead, pWav->pUserData, sizeBytes, sizeof(sizeBytes), &cursor) != sizeof(sizeBytes)) { + return DRWAV_FALSE; + } + bytesRemainingInChunk -= 8; + sampleCountFromFactChunk = drwav__bytes_to_u64(sizeBytes); + + + /* Skip over everything else. */ + if (!drwav__seek_forward(pWav->onSeek, bytesRemainingInChunk, pWav->pUserData)) { + return DRWAV_FALSE; + } + cursor += bytesRemainingInChunk; + } + + + /* The next bytes should be the "fmt " chunk. */ + if (!drwav__read_fmt(pWav->onRead, pWav->onSeek, pWav->pUserData, pWav->container, &cursor, &fmt)) { + return DRWAV_FALSE; /* Failed to read the "fmt " chunk. */ + } + + /* Basic validation. */ + if ((fmt.sampleRate == 0 || fmt.sampleRate > DRWAV_MAX_SAMPLE_RATE) || + (fmt.channels == 0 || fmt.channels > DRWAV_MAX_CHANNELS) || + (fmt.bitsPerSample == 0 || fmt.bitsPerSample > DRWAV_MAX_BITS_PER_SAMPLE) || + fmt.blockAlign == 0) { + return DRWAV_FALSE; /* Probably an invalid WAV file. */ + } + + + /* Translate the internal format. */ + translatedFormatTag = fmt.formatTag; + if (translatedFormatTag == DR_WAVE_FORMAT_EXTENSIBLE) { + translatedFormatTag = drwav__bytes_to_u16(fmt.subFormat + 0); + } + + + /* + We need to enumerate over each chunk for two reasons: + 1) The "data" chunk may not be the next one + 2) We may want to report each chunk back to the client + + In order to correctly report each chunk back to the client we will need to keep looping until the end of the file. + */ + foundDataChunk = DRWAV_FALSE; + + /* The next chunk we care about is the "data" chunk. This is not necessarily the next chunk so we'll need to loop. */ + for (;;) + { + drwav_chunk_header header; + drwav_result result = drwav__read_chunk_header(pWav->onRead, pWav->pUserData, pWav->container, &cursor, &header); + if (result != DRWAV_SUCCESS) { + if (!foundDataChunk) { + return DRWAV_FALSE; + } else { + break; /* Probably at the end of the file. Get out of the loop. */ + } + } + + /* Tell the client about this chunk. 
*/ + if (!sequential && onChunk != NULL) { + drwav_uint64 callbackBytesRead = onChunk(pChunkUserData, pWav->onRead, pWav->onSeek, pWav->pUserData, &header, pWav->container, &fmt); + + /* + dr_wav may need to read the contents of the chunk, so we now need to seek back to the position before + we called the callback. + */ + if (callbackBytesRead > 0) { + if (!drwav__seek_from_start(pWav->onSeek, cursor, pWav->pUserData)) { + return DRWAV_FALSE; + } + } + } + + + if (!foundDataChunk) { + pWav->dataChunkDataPos = cursor; + } + + chunkSize = header.sizeInBytes; + if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) { + if (drwav__fourcc_equal(header.id.fourcc, "data")) { + foundDataChunk = DRWAV_TRUE; + if (pWav->container != drwav_container_rf64) { /* The data chunk size for RF64 will always be set to 0xFFFFFFFF here. It was set to it's true value earlier. */ + dataChunkSize = chunkSize; + } + } + } else { + if (drwav__guid_equal(header.id.guid, drwavGUID_W64_DATA)) { + foundDataChunk = DRWAV_TRUE; + dataChunkSize = chunkSize; + } + } + + /* + If at this point we have found the data chunk and we're running in sequential mode, we need to break out of this loop. The reason for + this is that we would otherwise require a backwards seek which sequential mode forbids. + */ + if (foundDataChunk && sequential) { + break; + } + + /* Optional. Get the total sample count from the FACT chunk. This is useful for compressed formats. */ + if (pWav->container == drwav_container_riff) { + if (drwav__fourcc_equal(header.id.fourcc, "fact")) { + drwav_uint32 sampleCount; + if (drwav__on_read(pWav->onRead, pWav->pUserData, &sampleCount, 4, &cursor) != 4) { + return DRWAV_FALSE; + } + chunkSize -= 4; + + if (!foundDataChunk) { + pWav->dataChunkDataPos = cursor; + } + + /* + The sample count in the "fact" chunk is either unreliable, or I'm not understanding it properly. For now I am only enabling this + for Microsoft ADPCM formats. 
+ */ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + sampleCountFromFactChunk = sampleCount; + } else { + sampleCountFromFactChunk = 0; + } + } + } else if (pWav->container == drwav_container_w64) { + if (drwav__guid_equal(header.id.guid, drwavGUID_W64_FACT)) { + if (drwav__on_read(pWav->onRead, pWav->pUserData, &sampleCountFromFactChunk, 8, &cursor) != 8) { + return DRWAV_FALSE; + } + chunkSize -= 8; + + if (!foundDataChunk) { + pWav->dataChunkDataPos = cursor; + } + } + } else if (pWav->container == drwav_container_rf64) { + /* We retrieved the sample count from the ds64 chunk earlier so no need to do that here. */ + } + + /* "smpl" chunk. */ + if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) { + if (drwav__fourcc_equal(header.id.fourcc, "smpl")) { + drwav_uint8 smplHeaderData[36]; /* 36 = size of the smpl header section, not including the loop data. */ + if (chunkSize >= sizeof(smplHeaderData)) { + drwav_uint64 bytesJustRead = drwav__on_read(pWav->onRead, pWav->pUserData, smplHeaderData, sizeof(smplHeaderData), &cursor); + chunkSize -= bytesJustRead; + + if (bytesJustRead == sizeof(smplHeaderData)) { + drwav_uint32 iLoop; + + pWav->smpl.manufacturer = drwav__bytes_to_u32(smplHeaderData+0); + pWav->smpl.product = drwav__bytes_to_u32(smplHeaderData+4); + pWav->smpl.samplePeriod = drwav__bytes_to_u32(smplHeaderData+8); + pWav->smpl.midiUnityNotes = drwav__bytes_to_u32(smplHeaderData+12); + pWav->smpl.midiPitchFraction = drwav__bytes_to_u32(smplHeaderData+16); + pWav->smpl.smpteFormat = drwav__bytes_to_u32(smplHeaderData+20); + pWav->smpl.smpteOffset = drwav__bytes_to_u32(smplHeaderData+24); + pWav->smpl.numSampleLoops = drwav__bytes_to_u32(smplHeaderData+28); + pWav->smpl.samplerData = drwav__bytes_to_u32(smplHeaderData+32); + + for (iLoop = 0; iLoop < pWav->smpl.numSampleLoops && iLoop < drwav_countof(pWav->smpl.loops); ++iLoop) { + drwav_uint8 smplLoopData[24]; /* 24 = size of a loop section in the smpl chunk. 
*/ + bytesJustRead = drwav__on_read(pWav->onRead, pWav->pUserData, smplLoopData, sizeof(smplLoopData), &cursor); + chunkSize -= bytesJustRead; + + if (bytesJustRead == sizeof(smplLoopData)) { + pWav->smpl.loops[iLoop].cuePointId = drwav__bytes_to_u32(smplLoopData+0); + pWav->smpl.loops[iLoop].type = drwav__bytes_to_u32(smplLoopData+4); + pWav->smpl.loops[iLoop].start = drwav__bytes_to_u32(smplLoopData+8); + pWav->smpl.loops[iLoop].end = drwav__bytes_to_u32(smplLoopData+12); + pWav->smpl.loops[iLoop].fraction = drwav__bytes_to_u32(smplLoopData+16); + pWav->smpl.loops[iLoop].playCount = drwav__bytes_to_u32(smplLoopData+20); + } else { + break; /* Break from the smpl loop for loop. */ + } + } + } + } else { + /* Looks like invalid data. Ignore the chunk. */ + } + } + } else { + if (drwav__guid_equal(header.id.guid, drwavGUID_W64_SMPL)) { + /* + This path will be hit when a W64 WAV file contains a smpl chunk. I don't have a sample file to test this path, so a contribution + is welcome to add support for this. + */ + } + } + + /* Make sure we seek past the padding. */ + chunkSize += header.paddingSize; + if (!drwav__seek_forward(pWav->onSeek, chunkSize, pWav->pUserData)) { + break; + } + cursor += chunkSize; + + if (!foundDataChunk) { + pWav->dataChunkDataPos = cursor; + } + } + + /* If we haven't found a data chunk, return an error. */ + if (!foundDataChunk) { + return DRWAV_FALSE; + } + + /* We may have moved passed the data chunk. If so we need to move back. If running in sequential mode we can assume we are already sitting on the data chunk. */ + if (!sequential) { + if (!drwav__seek_from_start(pWav->onSeek, pWav->dataChunkDataPos, pWav->pUserData)) { + return DRWAV_FALSE; + } + cursor = pWav->dataChunkDataPos; + } + + + /* At this point we should be sitting on the first byte of the raw audio data. 
*/ + + pWav->fmt = fmt; + pWav->sampleRate = fmt.sampleRate; + pWav->channels = fmt.channels; + pWav->bitsPerSample = fmt.bitsPerSample; + pWav->bytesRemaining = dataChunkSize; + pWav->translatedFormatTag = translatedFormatTag; + pWav->dataChunkDataSize = dataChunkSize; + + if (sampleCountFromFactChunk != 0) { + pWav->totalPCMFrameCount = sampleCountFromFactChunk; + } else { + pWav->totalPCMFrameCount = dataChunkSize / drwav_get_bytes_per_pcm_frame(pWav); + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + drwav_uint64 totalBlockHeaderSizeInBytes; + drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign; + + /* Make sure any trailing partial block is accounted for. */ + if ((blockCount * fmt.blockAlign) < dataChunkSize) { + blockCount += 1; + } + + /* We decode two samples per byte. There will be blockCount headers in the data chunk. This is enough to know how to calculate the total PCM frame count. */ + totalBlockHeaderSizeInBytes = blockCount * (6*fmt.channels); + pWav->totalPCMFrameCount = ((dataChunkSize - totalBlockHeaderSizeInBytes) * 2) / fmt.channels; + } + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + drwav_uint64 totalBlockHeaderSizeInBytes; + drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign; + + /* Make sure any trailing partial block is accounted for. */ + if ((blockCount * fmt.blockAlign) < dataChunkSize) { + blockCount += 1; + } + + /* We decode two samples per byte. There will be blockCount headers in the data chunk. This is enough to know how to calculate the total PCM frame count. */ + totalBlockHeaderSizeInBytes = blockCount * (4*fmt.channels); + pWav->totalPCMFrameCount = ((dataChunkSize - totalBlockHeaderSizeInBytes) * 2) / fmt.channels; + + /* The header includes a decoded sample for each channel which acts as the initial predictor sample. */ + pWav->totalPCMFrameCount += blockCount; + } + } + + /* Some formats only support a certain number of channels. 
*/ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM || pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + if (pWav->channels > 2) { + return DRWAV_FALSE; + } + } + +#ifdef DR_WAV_LIBSNDFILE_COMPAT + /* + I use libsndfile as a benchmark for testing, however in the version I'm using (from the Windows installer on the libsndfile website), + it appears the total sample count libsndfile uses for MS-ADPCM is incorrect. It would seem they are computing the total sample count + from the number of blocks, however this results in the inclusion of extra silent samples at the end of the last block. The correct + way to know the total sample count is to inspect the "fact" chunk, which should always be present for compressed formats, and should + always include the sample count. This little block of code below is only used to emulate the libsndfile logic so I can properly run my + correctness tests against libsndfile, and is disabled by default. + */ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign; + pWav->totalPCMFrameCount = (((blockCount * (fmt.blockAlign - (6*pWav->channels))) * 2)) / fmt.channels; /* x2 because two samples per byte. 
*/ + } + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign; + pWav->totalPCMFrameCount = (((blockCount * (fmt.blockAlign - (4*pWav->channels))) * 2) + (blockCount * pWav->channels)) / fmt.channels; + } +#endif + + return DRWAV_TRUE; +} + +DRWAV_API drwav_bool32 drwav_init(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_ex(pWav, onRead, onSeek, NULL, pUserData, NULL, 0, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_ex(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, drwav_chunk_proc onChunk, void* pReadSeekUserData, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (!drwav_preinit(pWav, onRead, onSeek, pReadSeekUserData, pAllocationCallbacks)) { + return DRWAV_FALSE; + } + + return drwav_init__internal(pWav, onChunk, pChunkUserData, flags); +} + + +static drwav_uint32 drwav__riff_chunk_size_riff(drwav_uint64 dataChunkSize) +{ + drwav_uint64 chunkSize = 4 + 24 + dataChunkSize + drwav__chunk_padding_size_riff(dataChunkSize); /* 4 = "WAVE". 24 = "fmt " chunk. */ + if (chunkSize > 0xFFFFFFFFUL) { + chunkSize = 0xFFFFFFFFUL; + } + + return (drwav_uint32)chunkSize; /* Safe cast due to the clamp above. */ +} + +static drwav_uint32 drwav__data_chunk_size_riff(drwav_uint64 dataChunkSize) +{ + if (dataChunkSize <= 0xFFFFFFFFUL) { + return (drwav_uint32)dataChunkSize; + } else { + return 0xFFFFFFFFUL; + } +} + +static drwav_uint64 drwav__riff_chunk_size_w64(drwav_uint64 dataChunkSize) +{ + drwav_uint64 dataSubchunkPaddingSize = drwav__chunk_padding_size_w64(dataChunkSize); + + return 80 + 24 + dataChunkSize + dataSubchunkPaddingSize; /* +24 because W64 includes the size of the GUID and size fields. 
*/ +} + +static drwav_uint64 drwav__data_chunk_size_w64(drwav_uint64 dataChunkSize) +{ + return 24 + dataChunkSize; /* +24 because W64 includes the size of the GUID and size fields. */ +} + +static drwav_uint64 drwav__riff_chunk_size_rf64(drwav_uint64 dataChunkSize) +{ + drwav_uint64 chunkSize = 4 + 36 + 24 + dataChunkSize + drwav__chunk_padding_size_riff(dataChunkSize); /* 4 = "WAVE". 36 = "ds64" chunk. 24 = "fmt " chunk. */ + if (chunkSize > 0xFFFFFFFFUL) { + chunkSize = 0xFFFFFFFFUL; + } + + return chunkSize; +} + +static drwav_uint64 drwav__data_chunk_size_rf64(drwav_uint64 dataChunkSize) +{ + return dataChunkSize; +} + + +static size_t drwav__write(drwav* pWav, const void* pData, size_t dataSize) +{ + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->onWrite != NULL); + + /* Generic write. Assumes no byte reordering required. */ + return pWav->onWrite(pWav->pUserData, pData, dataSize); +} + +static size_t drwav__write_u16ne_to_le(drwav* pWav, drwav_uint16 value) +{ + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->onWrite != NULL); + + if (!drwav__is_little_endian()) { + value = drwav__bswap16(value); + } + + return drwav__write(pWav, &value, 2); +} + +static size_t drwav__write_u32ne_to_le(drwav* pWav, drwav_uint32 value) +{ + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->onWrite != NULL); + + if (!drwav__is_little_endian()) { + value = drwav__bswap32(value); + } + + return drwav__write(pWav, &value, 4); +} + +static size_t drwav__write_u64ne_to_le(drwav* pWav, drwav_uint64 value) +{ + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->onWrite != NULL); + + if (!drwav__is_little_endian()) { + value = drwav__bswap64(value); + } + + return drwav__write(pWav, &value, 8); +} + + +static drwav_bool32 drwav_preinit_write(drwav* pWav, const drwav_data_format* pFormat, drwav_bool32 isSequential, drwav_write_proc onWrite, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pWav == NULL || onWrite == NULL) { + 
return DRWAV_FALSE; + } + + if (!isSequential && onSeek == NULL) { + return DRWAV_FALSE; /* <-- onSeek is required when in non-sequential mode. */ + } + + /* Not currently supporting compressed formats. Will need to add support for the "fact" chunk before we enable this. */ + if (pFormat->format == DR_WAVE_FORMAT_EXTENSIBLE) { + return DRWAV_FALSE; + } + if (pFormat->format == DR_WAVE_FORMAT_ADPCM || pFormat->format == DR_WAVE_FORMAT_DVI_ADPCM) { + return DRWAV_FALSE; + } + + DRWAV_ZERO_MEMORY(pWav, sizeof(*pWav)); + pWav->onWrite = onWrite; + pWav->onSeek = onSeek; + pWav->pUserData = pUserData; + pWav->allocationCallbacks = drwav_copy_allocation_callbacks_or_defaults(pAllocationCallbacks); + + if (pWav->allocationCallbacks.onFree == NULL || (pWav->allocationCallbacks.onMalloc == NULL && pWav->allocationCallbacks.onRealloc == NULL)) { + return DRWAV_FALSE; /* Invalid allocation callbacks. */ + } + + pWav->fmt.formatTag = (drwav_uint16)pFormat->format; + pWav->fmt.channels = (drwav_uint16)pFormat->channels; + pWav->fmt.sampleRate = pFormat->sampleRate; + pWav->fmt.avgBytesPerSec = (drwav_uint32)((pFormat->bitsPerSample * pFormat->sampleRate * pFormat->channels) / 8); + pWav->fmt.blockAlign = (drwav_uint16)((pFormat->channels * pFormat->bitsPerSample) / 8); + pWav->fmt.bitsPerSample = (drwav_uint16)pFormat->bitsPerSample; + pWav->fmt.extendedSize = 0; + pWav->isSequentialWrite = isSequential; + + return DRWAV_TRUE; +} + +static drwav_bool32 drwav_init_write__internal(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount) +{ + /* The function assumes drwav_preinit_write() was called beforehand. */ + + size_t runningPos = 0; + drwav_uint64 initialDataChunkSize = 0; + drwav_uint64 chunkSizeFMT; + + /* + The initial values for the "RIFF" and "data" chunks depends on whether or not we are initializing in sequential mode or not. 
In + sequential mode we set this to its final values straight away since they can be calculated from the total sample count. In non- + sequential mode we initialize it all to zero and fill it out in drwav_uninit() using a backwards seek. + */ + if (pWav->isSequentialWrite) { + initialDataChunkSize = (totalSampleCount * pWav->fmt.bitsPerSample) / 8; + + /* + The RIFF container has a limit on the number of samples. drwav is not allowing this. There's no practical limits for Wave64 + so for the sake of simplicity I'm not doing any validation for that. + */ + if (pFormat->container == drwav_container_riff) { + if (initialDataChunkSize > (0xFFFFFFFFUL - 36)) { + return DRWAV_FALSE; /* Not enough room to store every sample. */ + } + } + } + + pWav->dataChunkDataSizeTargetWrite = initialDataChunkSize; + + + /* "RIFF" chunk. */ + if (pFormat->container == drwav_container_riff) { + drwav_uint32 chunkSizeRIFF = 28 + (drwav_uint32)initialDataChunkSize; /* +28 = "WAVE" + [sizeof "fmt " chunk] */ + runningPos += drwav__write(pWav, "RIFF", 4); + runningPos += drwav__write_u32ne_to_le(pWav, chunkSizeRIFF); + runningPos += drwav__write(pWav, "WAVE", 4); + } else if (pFormat->container == drwav_container_w64) { + drwav_uint64 chunkSizeRIFF = 80 + 24 + initialDataChunkSize; /* +24 because W64 includes the size of the GUID and size fields. */ + runningPos += drwav__write(pWav, drwavGUID_W64_RIFF, 16); + runningPos += drwav__write_u64ne_to_le(pWav, chunkSizeRIFF); + runningPos += drwav__write(pWav, drwavGUID_W64_WAVE, 16); + } else if (pFormat->container == drwav_container_rf64) { + runningPos += drwav__write(pWav, "RF64", 4); + runningPos += drwav__write_u32ne_to_le(pWav, 0xFFFFFFFF); /* Always 0xFFFFFFFF for RF64. Set to a proper value in the "ds64" chunk. */ + runningPos += drwav__write(pWav, "WAVE", 4); + } + + + /* "ds64" chunk (RF64 only). 
*/ + if (pFormat->container == drwav_container_rf64) { + drwav_uint32 initialds64ChunkSize = 28; /* 28 = [Size of RIFF (8 bytes)] + [Size of DATA (8 bytes)] + [Sample Count (8 bytes)] + [Table Length (4 bytes)]. Table length always set to 0. */ + drwav_uint64 initialRiffChunkSize = 8 + initialds64ChunkSize + initialDataChunkSize; /* +8 for the ds64 header. */ + + runningPos += drwav__write(pWav, "ds64", 4); + runningPos += drwav__write_u32ne_to_le(pWav, initialds64ChunkSize); /* Size of ds64. */ + runningPos += drwav__write_u64ne_to_le(pWav, initialRiffChunkSize); /* Size of RIFF. Set to true value at the end. */ + runningPos += drwav__write_u64ne_to_le(pWav, initialDataChunkSize); /* Size of DATA. Set to true value at the end. */ + runningPos += drwav__write_u64ne_to_le(pWav, totalSampleCount); /* Sample count. */ + runningPos += drwav__write_u32ne_to_le(pWav, 0); /* Table length. Always set to zero in our case since we're not doing any other chunks than "DATA". */ + } + + + /* "fmt " chunk. */ + if (pFormat->container == drwav_container_riff || pFormat->container == drwav_container_rf64) { + chunkSizeFMT = 16; + runningPos += drwav__write(pWav, "fmt ", 4); + runningPos += drwav__write_u32ne_to_le(pWav, (drwav_uint32)chunkSizeFMT); + } else if (pFormat->container == drwav_container_w64) { + chunkSizeFMT = 40; + runningPos += drwav__write(pWav, drwavGUID_W64_FMT, 16); + runningPos += drwav__write_u64ne_to_le(pWav, chunkSizeFMT); + } + + runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.formatTag); + runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.channels); + runningPos += drwav__write_u32ne_to_le(pWav, pWav->fmt.sampleRate); + runningPos += drwav__write_u32ne_to_le(pWav, pWav->fmt.avgBytesPerSec); + runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.blockAlign); + runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.bitsPerSample); + + pWav->dataChunkDataPos = runningPos; + + /* "data" chunk. 
*/ + if (pFormat->container == drwav_container_riff) { + drwav_uint32 chunkSizeDATA = (drwav_uint32)initialDataChunkSize; + runningPos += drwav__write(pWav, "data", 4); + runningPos += drwav__write_u32ne_to_le(pWav, chunkSizeDATA); + } else if (pFormat->container == drwav_container_w64) { + drwav_uint64 chunkSizeDATA = 24 + initialDataChunkSize; /* +24 because W64 includes the size of the GUID and size fields. */ + runningPos += drwav__write(pWav, drwavGUID_W64_DATA, 16); + runningPos += drwav__write_u64ne_to_le(pWav, chunkSizeDATA); + } else if (pFormat->container == drwav_container_rf64) { + runningPos += drwav__write(pWav, "data", 4); + runningPos += drwav__write_u32ne_to_le(pWav, 0xFFFFFFFF); /* Always set to 0xFFFFFFFF for RF64. The true size of the data chunk is specified in the ds64 chunk. */ + } + + /* + The runningPos variable is incremented in the section above but is left unused which is causing some static analysis tools to detect it + as a dead store. I'm leaving this as-is for safety just in case I want to expand this function later to include other tags and want to + keep track of the running position for whatever reason. The line below should silence the static analysis tools. + */ + (void)runningPos; + + /* Set some properties for the client's convenience. 
*/ + pWav->container = pFormat->container; + pWav->channels = (drwav_uint16)pFormat->channels; + pWav->sampleRate = pFormat->sampleRate; + pWav->bitsPerSample = (drwav_uint16)pFormat->bitsPerSample; + pWav->translatedFormatTag = (drwav_uint16)pFormat->format; + + return DRWAV_TRUE; +} + + +DRWAV_API drwav_bool32 drwav_init_write(drwav* pWav, const drwav_data_format* pFormat, drwav_write_proc onWrite, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (!drwav_preinit_write(pWav, pFormat, DRWAV_FALSE, onWrite, onSeek, pUserData, pAllocationCallbacks)) { + return DRWAV_FALSE; + } + + return drwav_init_write__internal(pWav, pFormat, 0); /* DRWAV_FALSE = Not Sequential */ +} + +DRWAV_API drwav_bool32 drwav_init_write_sequential(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (!drwav_preinit_write(pWav, pFormat, DRWAV_TRUE, onWrite, NULL, pUserData, pAllocationCallbacks)) { + return DRWAV_FALSE; + } + + return drwav_init_write__internal(pWav, pFormat, totalSampleCount); /* DRWAV_TRUE = Sequential */ +} + +DRWAV_API drwav_bool32 drwav_init_write_sequential_pcm_frames(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pFormat == NULL) { + return DRWAV_FALSE; + } + + return drwav_init_write_sequential(pWav, pFormat, totalPCMFrameCount*pFormat->channels, onWrite, pUserData, pAllocationCallbacks); +} + +DRWAV_API drwav_uint64 drwav_target_write_size_bytes(const drwav_data_format* pFormat, drwav_uint64 totalSampleCount) +{ + /* Casting totalSampleCount to drwav_int64 for VC6 compatibility. No issues in practice because nobody is going to exhaust the whole 63 bits. 
*/ + drwav_uint64 targetDataSizeBytes = (drwav_uint64)((drwav_int64)totalSampleCount * pFormat->channels * pFormat->bitsPerSample/8.0); + drwav_uint64 riffChunkSizeBytes; + drwav_uint64 fileSizeBytes = 0; + + if (pFormat->container == drwav_container_riff) { + riffChunkSizeBytes = drwav__riff_chunk_size_riff(targetDataSizeBytes); + fileSizeBytes = (8 + riffChunkSizeBytes); /* +8 because WAV doesn't include the size of the ChunkID and ChunkSize fields. */ + } else if (pFormat->container == drwav_container_w64) { + riffChunkSizeBytes = drwav__riff_chunk_size_w64(targetDataSizeBytes); + fileSizeBytes = riffChunkSizeBytes; + } else if (pFormat->container == drwav_container_rf64) { + riffChunkSizeBytes = drwav__riff_chunk_size_rf64(targetDataSizeBytes); + fileSizeBytes = (8 + riffChunkSizeBytes); /* +8 because WAV doesn't include the size of the ChunkID and ChunkSize fields. */ + } + + return fileSizeBytes; +} + + +#ifndef DR_WAV_NO_STDIO + +/* drwav_result_from_errno() is only used for fopen() and wfopen() so putting it inside DR_WAV_NO_STDIO for now. If something else needs this later we can move it out. 
*/ +#include <errno.h> +static drwav_result drwav_result_from_errno(int e) +{ + switch (e) + { + case 0: return DRWAV_SUCCESS; + #ifdef EPERM + case EPERM: return DRWAV_INVALID_OPERATION; + #endif + #ifdef ENOENT + case ENOENT: return DRWAV_DOES_NOT_EXIST; + #endif + #ifdef ESRCH + case ESRCH: return DRWAV_DOES_NOT_EXIST; + #endif + #ifdef EINTR + case EINTR: return DRWAV_INTERRUPT; + #endif + #ifdef EIO + case EIO: return DRWAV_IO_ERROR; + #endif + #ifdef ENXIO + case ENXIO: return DRWAV_DOES_NOT_EXIST; + #endif + #ifdef E2BIG + case E2BIG: return DRWAV_INVALID_ARGS; + #endif + #ifdef ENOEXEC + case ENOEXEC: return DRWAV_INVALID_FILE; + #endif + #ifdef EBADF + case EBADF: return DRWAV_INVALID_FILE; + #endif + #ifdef ECHILD + case ECHILD: return DRWAV_ERROR; + #endif + #ifdef EAGAIN + case EAGAIN: return DRWAV_UNAVAILABLE; + #endif + #ifdef ENOMEM + case ENOMEM: return DRWAV_OUT_OF_MEMORY; + #endif + #ifdef EACCES + case EACCES: return DRWAV_ACCESS_DENIED; + #endif + #ifdef EFAULT + case EFAULT: return DRWAV_BAD_ADDRESS; + #endif + #ifdef ENOTBLK + case ENOTBLK: return DRWAV_ERROR; + #endif + #ifdef EBUSY + case EBUSY: return DRWAV_BUSY; + #endif + #ifdef EEXIST + case EEXIST: return DRWAV_ALREADY_EXISTS; + #endif + #ifdef EXDEV + case EXDEV: return DRWAV_ERROR; + #endif + #ifdef ENODEV + case ENODEV: return DRWAV_DOES_NOT_EXIST; + #endif + #ifdef ENOTDIR + case ENOTDIR: return DRWAV_NOT_DIRECTORY; + #endif + #ifdef EISDIR + case EISDIR: return DRWAV_IS_DIRECTORY; + #endif + #ifdef EINVAL + case EINVAL: return DRWAV_INVALID_ARGS; + #endif + #ifdef ENFILE + case ENFILE: return DRWAV_TOO_MANY_OPEN_FILES; + #endif + #ifdef EMFILE + case EMFILE: return DRWAV_TOO_MANY_OPEN_FILES; + #endif + #ifdef ENOTTY + case ENOTTY: return DRWAV_INVALID_OPERATION; + #endif + #ifdef ETXTBSY + case ETXTBSY: return DRWAV_BUSY; + #endif + #ifdef EFBIG + case EFBIG: return DRWAV_TOO_BIG; + #endif + #ifdef ENOSPC + case ENOSPC: return DRWAV_NO_SPACE; + #endif + #ifdef ESPIPE + case ESPIPE: return
DRWAV_BAD_SEEK; + #endif + #ifdef EROFS + case EROFS: return DRWAV_ACCESS_DENIED; + #endif + #ifdef EMLINK + case EMLINK: return DRWAV_TOO_MANY_LINKS; + #endif + #ifdef EPIPE + case EPIPE: return DRWAV_BAD_PIPE; + #endif + #ifdef EDOM + case EDOM: return DRWAV_OUT_OF_RANGE; + #endif + #ifdef ERANGE + case ERANGE: return DRWAV_OUT_OF_RANGE; + #endif + #ifdef EDEADLK + case EDEADLK: return DRWAV_DEADLOCK; + #endif + #ifdef ENAMETOOLONG + case ENAMETOOLONG: return DRWAV_PATH_TOO_LONG; + #endif + #ifdef ENOLCK + case ENOLCK: return DRWAV_ERROR; + #endif + #ifdef ENOSYS + case ENOSYS: return DRWAV_NOT_IMPLEMENTED; + #endif + #ifdef ENOTEMPTY + case ENOTEMPTY: return DRWAV_DIRECTORY_NOT_EMPTY; + #endif + #ifdef ELOOP + case ELOOP: return DRWAV_TOO_MANY_LINKS; + #endif + #ifdef ENOMSG + case ENOMSG: return DRWAV_NO_MESSAGE; + #endif + #ifdef EIDRM + case EIDRM: return DRWAV_ERROR; + #endif + #ifdef ECHRNG + case ECHRNG: return DRWAV_ERROR; + #endif + #ifdef EL2NSYNC + case EL2NSYNC: return DRWAV_ERROR; + #endif + #ifdef EL3HLT + case EL3HLT: return DRWAV_ERROR; + #endif + #ifdef EL3RST + case EL3RST: return DRWAV_ERROR; + #endif + #ifdef ELNRNG + case ELNRNG: return DRWAV_OUT_OF_RANGE; + #endif + #ifdef EUNATCH + case EUNATCH: return DRWAV_ERROR; + #endif + #ifdef ENOCSI + case ENOCSI: return DRWAV_ERROR; + #endif + #ifdef EL2HLT + case EL2HLT: return DRWAV_ERROR; + #endif + #ifdef EBADE + case EBADE: return DRWAV_ERROR; + #endif + #ifdef EBADR + case EBADR: return DRWAV_ERROR; + #endif + #ifdef EXFULL + case EXFULL: return DRWAV_ERROR; + #endif + #ifdef ENOANO + case ENOANO: return DRWAV_ERROR; + #endif + #ifdef EBADRQC + case EBADRQC: return DRWAV_ERROR; + #endif + #ifdef EBADSLT + case EBADSLT: return DRWAV_ERROR; + #endif + #ifdef EBFONT + case EBFONT: return DRWAV_INVALID_FILE; + #endif + #ifdef ENOSTR + case ENOSTR: return DRWAV_ERROR; + #endif + #ifdef ENODATA + case ENODATA: return DRWAV_NO_DATA_AVAILABLE; + #endif + #ifdef ETIME + case ETIME: return 
DRWAV_TIMEOUT; + #endif + #ifdef ENOSR + case ENOSR: return DRWAV_NO_DATA_AVAILABLE; + #endif + #ifdef ENONET + case ENONET: return DRWAV_NO_NETWORK; + #endif + #ifdef ENOPKG + case ENOPKG: return DRWAV_ERROR; + #endif + #ifdef EREMOTE + case EREMOTE: return DRWAV_ERROR; + #endif + #ifdef ENOLINK + case ENOLINK: return DRWAV_ERROR; + #endif + #ifdef EADV + case EADV: return DRWAV_ERROR; + #endif + #ifdef ESRMNT + case ESRMNT: return DRWAV_ERROR; + #endif + #ifdef ECOMM + case ECOMM: return DRWAV_ERROR; + #endif + #ifdef EPROTO + case EPROTO: return DRWAV_ERROR; + #endif + #ifdef EMULTIHOP + case EMULTIHOP: return DRWAV_ERROR; + #endif + #ifdef EDOTDOT + case EDOTDOT: return DRWAV_ERROR; + #endif + #ifdef EBADMSG + case EBADMSG: return DRWAV_BAD_MESSAGE; + #endif + #ifdef EOVERFLOW + case EOVERFLOW: return DRWAV_TOO_BIG; + #endif + #ifdef ENOTUNIQ + case ENOTUNIQ: return DRWAV_NOT_UNIQUE; + #endif + #ifdef EBADFD + case EBADFD: return DRWAV_ERROR; + #endif + #ifdef EREMCHG + case EREMCHG: return DRWAV_ERROR; + #endif + #ifdef ELIBACC + case ELIBACC: return DRWAV_ACCESS_DENIED; + #endif + #ifdef ELIBBAD + case ELIBBAD: return DRWAV_INVALID_FILE; + #endif + #ifdef ELIBSCN + case ELIBSCN: return DRWAV_INVALID_FILE; + #endif + #ifdef ELIBMAX + case ELIBMAX: return DRWAV_ERROR; + #endif + #ifdef ELIBEXEC + case ELIBEXEC: return DRWAV_ERROR; + #endif + #ifdef EILSEQ + case EILSEQ: return DRWAV_INVALID_DATA; + #endif + #ifdef ERESTART + case ERESTART: return DRWAV_ERROR; + #endif + #ifdef ESTRPIPE + case ESTRPIPE: return DRWAV_ERROR; + #endif + #ifdef EUSERS + case EUSERS: return DRWAV_ERROR; + #endif + #ifdef ENOTSOCK + case ENOTSOCK: return DRWAV_NOT_SOCKET; + #endif + #ifdef EDESTADDRREQ + case EDESTADDRREQ: return DRWAV_NO_ADDRESS; + #endif + #ifdef EMSGSIZE + case EMSGSIZE: return DRWAV_TOO_BIG; + #endif + #ifdef EPROTOTYPE + case EPROTOTYPE: return DRWAV_BAD_PROTOCOL; + #endif + #ifdef ENOPROTOOPT + case ENOPROTOOPT: return DRWAV_PROTOCOL_UNAVAILABLE; + #endif + 
#ifdef EPROTONOSUPPORT + case EPROTONOSUPPORT: return DRWAV_PROTOCOL_NOT_SUPPORTED; + #endif + #ifdef ESOCKTNOSUPPORT + case ESOCKTNOSUPPORT: return DRWAV_SOCKET_NOT_SUPPORTED; + #endif + #ifdef EOPNOTSUPP + case EOPNOTSUPP: return DRWAV_INVALID_OPERATION; + #endif + #ifdef EPFNOSUPPORT + case EPFNOSUPPORT: return DRWAV_PROTOCOL_FAMILY_NOT_SUPPORTED; + #endif + #ifdef EAFNOSUPPORT + case EAFNOSUPPORT: return DRWAV_ADDRESS_FAMILY_NOT_SUPPORTED; + #endif + #ifdef EADDRINUSE + case EADDRINUSE: return DRWAV_ALREADY_IN_USE; + #endif + #ifdef EADDRNOTAVAIL + case EADDRNOTAVAIL: return DRWAV_ERROR; + #endif + #ifdef ENETDOWN + case ENETDOWN: return DRWAV_NO_NETWORK; + #endif + #ifdef ENETUNREACH + case ENETUNREACH: return DRWAV_NO_NETWORK; + #endif + #ifdef ENETRESET + case ENETRESET: return DRWAV_NO_NETWORK; + #endif + #ifdef ECONNABORTED + case ECONNABORTED: return DRWAV_NO_NETWORK; + #endif + #ifdef ECONNRESET + case ECONNRESET: return DRWAV_CONNECTION_RESET; + #endif + #ifdef ENOBUFS + case ENOBUFS: return DRWAV_NO_SPACE; + #endif + #ifdef EISCONN + case EISCONN: return DRWAV_ALREADY_CONNECTED; + #endif + #ifdef ENOTCONN + case ENOTCONN: return DRWAV_NOT_CONNECTED; + #endif + #ifdef ESHUTDOWN + case ESHUTDOWN: return DRWAV_ERROR; + #endif + #ifdef ETOOMANYREFS + case ETOOMANYREFS: return DRWAV_ERROR; + #endif + #ifdef ETIMEDOUT + case ETIMEDOUT: return DRWAV_TIMEOUT; + #endif + #ifdef ECONNREFUSED + case ECONNREFUSED: return DRWAV_CONNECTION_REFUSED; + #endif + #ifdef EHOSTDOWN + case EHOSTDOWN: return DRWAV_NO_HOST; + #endif + #ifdef EHOSTUNREACH + case EHOSTUNREACH: return DRWAV_NO_HOST; + #endif + #ifdef EALREADY + case EALREADY: return DRWAV_IN_PROGRESS; + #endif + #ifdef EINPROGRESS + case EINPROGRESS: return DRWAV_IN_PROGRESS; + #endif + #ifdef ESTALE + case ESTALE: return DRWAV_INVALID_FILE; + #endif + #ifdef EUCLEAN + case EUCLEAN: return DRWAV_ERROR; + #endif + #ifdef ENOTNAM + case ENOTNAM: return DRWAV_ERROR; + #endif + #ifdef ENAVAIL + case ENAVAIL: return 
DRWAV_ERROR; + #endif + #ifdef EISNAM + case EISNAM: return DRWAV_ERROR; + #endif + #ifdef EREMOTEIO + case EREMOTEIO: return DRWAV_IO_ERROR; + #endif + #ifdef EDQUOT + case EDQUOT: return DRWAV_NO_SPACE; + #endif + #ifdef ENOMEDIUM + case ENOMEDIUM: return DRWAV_DOES_NOT_EXIST; + #endif + #ifdef EMEDIUMTYPE + case EMEDIUMTYPE: return DRWAV_ERROR; + #endif + #ifdef ECANCELED + case ECANCELED: return DRWAV_CANCELLED; + #endif + #ifdef ENOKEY + case ENOKEY: return DRWAV_ERROR; + #endif + #ifdef EKEYEXPIRED + case EKEYEXPIRED: return DRWAV_ERROR; + #endif + #ifdef EKEYREVOKED + case EKEYREVOKED: return DRWAV_ERROR; + #endif + #ifdef EKEYREJECTED + case EKEYREJECTED: return DRWAV_ERROR; + #endif + #ifdef EOWNERDEAD + case EOWNERDEAD: return DRWAV_ERROR; + #endif + #ifdef ENOTRECOVERABLE + case ENOTRECOVERABLE: return DRWAV_ERROR; + #endif + #ifdef ERFKILL + case ERFKILL: return DRWAV_ERROR; + #endif + #ifdef EHWPOISON + case EHWPOISON: return DRWAV_ERROR; + #endif + default: return DRWAV_ERROR; + } +} + +static drwav_result drwav_fopen(FILE** ppFile, const char* pFilePath, const char* pOpenMode) +{ +#if _MSC_VER && _MSC_VER >= 1400 + errno_t err; +#endif + + if (ppFile != NULL) { + *ppFile = NULL; /* Safety. */ + } + + if (pFilePath == NULL || pOpenMode == NULL || ppFile == NULL) { + return DRWAV_INVALID_ARGS; + } + +#if _MSC_VER && _MSC_VER >= 1400 + err = fopen_s(ppFile, pFilePath, pOpenMode); + if (err != 0) { + return drwav_result_from_errno(err); + } +#else +#if defined(_WIN32) || defined(__APPLE__) + *ppFile = fopen(pFilePath, pOpenMode); +#else + #if defined(_FILE_OFFSET_BITS) && _FILE_OFFSET_BITS == 64 && defined(_LARGEFILE64_SOURCE) + *ppFile = fopen64(pFilePath, pOpenMode); + #else + *ppFile = fopen(pFilePath, pOpenMode); + #endif +#endif + if (*ppFile == NULL) { + drwav_result result = drwav_result_from_errno(errno); + if (result == DRWAV_SUCCESS) { + result = DRWAV_ERROR; /* Just a safety check to make sure we never ever return success when pFile == NULL. 
*/ + } + + return result; + } +#endif + + return DRWAV_SUCCESS; +} + +/* +_wfopen() isn't always available in all compilation environments. + + * Windows only. + * MSVC seems to support it universally as far back as VC6 from what I can tell (haven't checked further back). + * MinGW-64 (both 32- and 64-bit) seems to support it. + * MinGW wraps it in !defined(__STRICT_ANSI__). + * OpenWatcom wraps it in !defined(_NO_EXT_KEYS). + +This can be reviewed as compatibility issues arise. The preference is to use _wfopen_s() and _wfopen() as opposed to the wcsrtombs() +fallback, so if you notice your compiler not detecting this properly I'm happy to look at adding support. +*/ +#if defined(_WIN32) + #if defined(_MSC_VER) || defined(__MINGW64__) || (!defined(__STRICT_ANSI__) && !defined(_NO_EXT_KEYS)) + #define DRWAV_HAS_WFOPEN + #endif +#endif + +static drwav_result drwav_wfopen(FILE** ppFile, const wchar_t* pFilePath, const wchar_t* pOpenMode, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (ppFile != NULL) { + *ppFile = NULL; /* Safety. */ + } + + if (pFilePath == NULL || pOpenMode == NULL || ppFile == NULL) { + return DRWAV_INVALID_ARGS; + } + +#if defined(DRWAV_HAS_WFOPEN) + { + /* Use _wfopen() on Windows. */ + #if defined(_MSC_VER) && _MSC_VER >= 1400 + errno_t err = _wfopen_s(ppFile, pFilePath, pOpenMode); + if (err != 0) { + return drwav_result_from_errno(err); + } + #else + *ppFile = _wfopen(pFilePath, pOpenMode); + if (*ppFile == NULL) { + return drwav_result_from_errno(errno); + } + #endif + (void)pAllocationCallbacks; + } +#else + /* + Use fopen() on anything other than Windows. Requires a conversion. This is annoying because fopen() is locale specific. The only real way I can + think of to do this is with wcsrtombs(). Note that wcstombs() is apparently not thread-safe because it uses a static global mbstate_t object for + maintaining state. 
I've checked this with -std=c89 and it works, but if somebody get's a compiler error I'll look into improving compatibility. + */ + { + mbstate_t mbs; + size_t lenMB; + const wchar_t* pFilePathTemp = pFilePath; + char* pFilePathMB = NULL; + char pOpenModeMB[32] = {0}; + + /* Get the length first. */ + DRWAV_ZERO_OBJECT(&mbs); + lenMB = wcsrtombs(NULL, &pFilePathTemp, 0, &mbs); + if (lenMB == (size_t)-1) { + return drwav_result_from_errno(errno); + } + + pFilePathMB = (char*)drwav__malloc_from_callbacks(lenMB + 1, pAllocationCallbacks); + if (pFilePathMB == NULL) { + return DRWAV_OUT_OF_MEMORY; + } + + pFilePathTemp = pFilePath; + DRWAV_ZERO_OBJECT(&mbs); + wcsrtombs(pFilePathMB, &pFilePathTemp, lenMB + 1, &mbs); + + /* The open mode should always consist of ASCII characters so we should be able to do a trivial conversion. */ + { + size_t i = 0; + for (;;) { + if (pOpenMode[i] == 0) { + pOpenModeMB[i] = '\0'; + break; + } + + pOpenModeMB[i] = (char)pOpenMode[i]; + i += 1; + } + } + + *ppFile = fopen(pFilePathMB, pOpenModeMB); + + drwav__free_from_callbacks(pFilePathMB, pAllocationCallbacks); + } + + if (*ppFile == NULL) { + return DRWAV_ERROR; + } +#endif + + return DRWAV_SUCCESS; +} + + +static size_t drwav__on_read_stdio(void* pUserData, void* pBufferOut, size_t bytesToRead) +{ + return fread(pBufferOut, 1, bytesToRead, (FILE*)pUserData); +} + +static size_t drwav__on_write_stdio(void* pUserData, const void* pData, size_t bytesToWrite) +{ + return fwrite(pData, 1, bytesToWrite, (FILE*)pUserData); +} + +static drwav_bool32 drwav__on_seek_stdio(void* pUserData, int offset, drwav_seek_origin origin) +{ + return fseek((FILE*)pUserData, offset, (origin == drwav_seek_origin_current) ? 
SEEK_CUR : SEEK_SET) == 0; +} + +DRWAV_API drwav_bool32 drwav_init_file(drwav* pWav, const char* filename, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_ex(pWav, filename, NULL, NULL, 0, pAllocationCallbacks); +} + + +static drwav_bool32 drwav_init_file__internal_FILE(drwav* pWav, FILE* pFile, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav_bool32 result; + + result = drwav_preinit(pWav, drwav__on_read_stdio, drwav__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + if (result != DRWAV_TRUE) { + fclose(pFile); + return result; + } + + result = drwav_init__internal(pWav, onChunk, pChunkUserData, flags); + if (result != DRWAV_TRUE) { + fclose(pFile); + return result; + } + + return DRWAV_TRUE; +} + +DRWAV_API drwav_bool32 drwav_init_file_ex(drwav* pWav, const char* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + FILE* pFile; + if (drwav_fopen(&pFile, filename, "rb") != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + /* This takes ownership of the FILE* object. */ + return drwav_init_file__internal_FILE(pWav, pFile, onChunk, pChunkUserData, flags, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_w(drwav* pWav, const wchar_t* filename, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_ex_w(pWav, filename, NULL, NULL, 0, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_ex_w(drwav* pWav, const wchar_t* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + FILE* pFile; + if (drwav_wfopen(&pFile, filename, L"rb", pAllocationCallbacks) != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + /* This takes ownership of the FILE* object. 
*/ + return drwav_init_file__internal_FILE(pWav, pFile, onChunk, pChunkUserData, flags, pAllocationCallbacks); +} + + +static drwav_bool32 drwav_init_file_write__internal_FILE(drwav* pWav, FILE* pFile, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav_bool32 result; + + result = drwav_preinit_write(pWav, pFormat, isSequential, drwav__on_write_stdio, drwav__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + if (result != DRWAV_TRUE) { + fclose(pFile); + return result; + } + + result = drwav_init_write__internal(pWav, pFormat, totalSampleCount); + if (result != DRWAV_TRUE) { + fclose(pFile); + return result; + } + + return DRWAV_TRUE; +} + +static drwav_bool32 drwav_init_file_write__internal(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + FILE* pFile; + if (drwav_fopen(&pFile, filename, "wb") != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + /* This takes ownership of the FILE* object. */ + return drwav_init_file_write__internal_FILE(pWav, pFile, pFormat, totalSampleCount, isSequential, pAllocationCallbacks); +} + +static drwav_bool32 drwav_init_file_write_w__internal(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + FILE* pFile; + if (drwav_wfopen(&pFile, filename, L"wb", pAllocationCallbacks) != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + /* This takes ownership of the FILE* object. 
*/ + return drwav_init_file_write__internal_FILE(pWav, pFile, pFormat, totalSampleCount, isSequential, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write(drwav* pWav, const char* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_write__internal(pWav, filename, pFormat, 0, DRWAV_FALSE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write_sequential(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_write__internal(pWav, filename, pFormat, totalSampleCount, DRWAV_TRUE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pFormat == NULL) { + return DRWAV_FALSE; + } + + return drwav_init_file_write_sequential(pWav, filename, pFormat, totalPCMFrameCount*pFormat->channels, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_write_w__internal(pWav, filename, pFormat, 0, DRWAV_FALSE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_write_w__internal(pWav, filename, pFormat, totalSampleCount, DRWAV_TRUE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const 
drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pFormat == NULL) { + return DRWAV_FALSE; + } + + return drwav_init_file_write_sequential_w(pWav, filename, pFormat, totalPCMFrameCount*pFormat->channels, pAllocationCallbacks); +} +#endif /* DR_WAV_NO_STDIO */ + + +static size_t drwav__on_read_memory(void* pUserData, void* pBufferOut, size_t bytesToRead) +{ + drwav* pWav = (drwav*)pUserData; + size_t bytesRemaining; + + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->memoryStream.dataSize >= pWav->memoryStream.currentReadPos); + + bytesRemaining = pWav->memoryStream.dataSize - pWav->memoryStream.currentReadPos; + if (bytesToRead > bytesRemaining) { + bytesToRead = bytesRemaining; + } + + if (bytesToRead > 0) { + DRWAV_COPY_MEMORY(pBufferOut, pWav->memoryStream.data + pWav->memoryStream.currentReadPos, bytesToRead); + pWav->memoryStream.currentReadPos += bytesToRead; + } + + return bytesToRead; +} + +static drwav_bool32 drwav__on_seek_memory(void* pUserData, int offset, drwav_seek_origin origin) +{ + drwav* pWav = (drwav*)pUserData; + DRWAV_ASSERT(pWav != NULL); + + if (origin == drwav_seek_origin_current) { + if (offset > 0) { + if (pWav->memoryStream.currentReadPos + offset > pWav->memoryStream.dataSize) { + return DRWAV_FALSE; /* Trying to seek too far forward. */ + } + } else { + if (pWav->memoryStream.currentReadPos < (size_t)-offset) { + return DRWAV_FALSE; /* Trying to seek too far backwards. */ + } + } + + /* This will never underflow thanks to the clamps above. */ + pWav->memoryStream.currentReadPos += offset; + } else { + if ((drwav_uint32)offset <= pWav->memoryStream.dataSize) { + pWav->memoryStream.currentReadPos = offset; + } else { + return DRWAV_FALSE; /* Trying to seek too far forward. 
*/ + } + } + + return DRWAV_TRUE; +} + +static size_t drwav__on_write_memory(void* pUserData, const void* pDataIn, size_t bytesToWrite) +{ + drwav* pWav = (drwav*)pUserData; + size_t bytesRemaining; + + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->memoryStreamWrite.dataCapacity >= pWav->memoryStreamWrite.currentWritePos); + + bytesRemaining = pWav->memoryStreamWrite.dataCapacity - pWav->memoryStreamWrite.currentWritePos; + if (bytesRemaining < bytesToWrite) { + /* Need to reallocate. */ + void* pNewData; + size_t newDataCapacity = (pWav->memoryStreamWrite.dataCapacity == 0) ? 256 : pWav->memoryStreamWrite.dataCapacity * 2; + + /* If doubling wasn't enough, just make it the minimum required size to write the data. */ + if ((newDataCapacity - pWav->memoryStreamWrite.currentWritePos) < bytesToWrite) { + newDataCapacity = pWav->memoryStreamWrite.currentWritePos + bytesToWrite; + } + + pNewData = drwav__realloc_from_callbacks(*pWav->memoryStreamWrite.ppData, newDataCapacity, pWav->memoryStreamWrite.dataCapacity, &pWav->allocationCallbacks); + if (pNewData == NULL) { + return 0; + } + + *pWav->memoryStreamWrite.ppData = pNewData; + pWav->memoryStreamWrite.dataCapacity = newDataCapacity; + } + + DRWAV_COPY_MEMORY(((drwav_uint8*)(*pWav->memoryStreamWrite.ppData)) + pWav->memoryStreamWrite.currentWritePos, pDataIn, bytesToWrite); + + pWav->memoryStreamWrite.currentWritePos += bytesToWrite; + if (pWav->memoryStreamWrite.dataSize < pWav->memoryStreamWrite.currentWritePos) { + pWav->memoryStreamWrite.dataSize = pWav->memoryStreamWrite.currentWritePos; + } + + *pWav->memoryStreamWrite.pDataSize = pWav->memoryStreamWrite.dataSize; + + return bytesToWrite; +} + +static drwav_bool32 drwav__on_seek_memory_write(void* pUserData, int offset, drwav_seek_origin origin) +{ + drwav* pWav = (drwav*)pUserData; + DRWAV_ASSERT(pWav != NULL); + + if (origin == drwav_seek_origin_current) { + if (offset > 0) { + if (pWav->memoryStreamWrite.currentWritePos + offset > 
pWav->memoryStreamWrite.dataSize) { + offset = (int)(pWav->memoryStreamWrite.dataSize - pWav->memoryStreamWrite.currentWritePos); /* Trying to seek too far forward. */ + } + } else { + if (pWav->memoryStreamWrite.currentWritePos < (size_t)-offset) { + offset = -(int)pWav->memoryStreamWrite.currentWritePos; /* Trying to seek too far backwards. */ + } + } + + /* This will never underflow thanks to the clamps above. */ + pWav->memoryStreamWrite.currentWritePos += offset; + } else { + if ((drwav_uint32)offset <= pWav->memoryStreamWrite.dataSize) { + pWav->memoryStreamWrite.currentWritePos = offset; + } else { + pWav->memoryStreamWrite.currentWritePos = pWav->memoryStreamWrite.dataSize; /* Trying to seek too far forward. */ + } + } + + return DRWAV_TRUE; +} + +DRWAV_API drwav_bool32 drwav_init_memory(drwav* pWav, const void* data, size_t dataSize, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_memory_ex(pWav, data, dataSize, NULL, NULL, 0, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_memory_ex(drwav* pWav, const void* data, size_t dataSize, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (data == NULL || dataSize == 0) { + return DRWAV_FALSE; + } + + if (!drwav_preinit(pWav, drwav__on_read_memory, drwav__on_seek_memory, pWav, pAllocationCallbacks)) { + return DRWAV_FALSE; + } + + pWav->memoryStream.data = (const drwav_uint8*)data; + pWav->memoryStream.dataSize = dataSize; + pWav->memoryStream.currentReadPos = 0; + + return drwav_init__internal(pWav, onChunk, pChunkUserData, flags); +} + + +static drwav_bool32 drwav_init_memory_write__internal(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (ppData == NULL || pDataSize == NULL) { + return DRWAV_FALSE; + } + + *ppData = NULL; /* Important 
because we're using realloc()! */ + *pDataSize = 0; + + if (!drwav_preinit_write(pWav, pFormat, isSequential, drwav__on_write_memory, drwav__on_seek_memory_write, pWav, pAllocationCallbacks)) { + return DRWAV_FALSE; + } + + pWav->memoryStreamWrite.ppData = ppData; + pWav->memoryStreamWrite.pDataSize = pDataSize; + pWav->memoryStreamWrite.dataSize = 0; + pWav->memoryStreamWrite.dataCapacity = 0; + pWav->memoryStreamWrite.currentWritePos = 0; + + return drwav_init_write__internal(pWav, pFormat, totalSampleCount); +} + +DRWAV_API drwav_bool32 drwav_init_memory_write(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_memory_write__internal(pWav, ppData, pDataSize, pFormat, 0, DRWAV_FALSE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_memory_write_sequential(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_memory_write__internal(pWav, ppData, pDataSize, pFormat, totalSampleCount, DRWAV_TRUE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_memory_write_sequential_pcm_frames(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pFormat == NULL) { + return DRWAV_FALSE; + } + + return drwav_init_memory_write_sequential(pWav, ppData, pDataSize, pFormat, totalPCMFrameCount*pFormat->channels, pAllocationCallbacks); +} + + + +DRWAV_API drwav_result drwav_uninit(drwav* pWav) +{ + drwav_result result = DRWAV_SUCCESS; + + if (pWav == NULL) { + return DRWAV_INVALID_ARGS; + } + + /* + If the drwav object was opened in write mode we'll need to finalize a few things: + - Make sure the "data" chunk is aligned to 16-bits for RIFF containers, or 64 bits for W64 containers. 
+ - Set the size of the "data" chunk. + */ + if (pWav->onWrite != NULL) { + drwav_uint32 paddingSize = 0; + + /* Padding. Do not adjust pWav->dataChunkDataSize - this should not include the padding. */ + if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) { + paddingSize = drwav__chunk_padding_size_riff(pWav->dataChunkDataSize); + } else { + paddingSize = drwav__chunk_padding_size_w64(pWav->dataChunkDataSize); + } + + if (paddingSize > 0) { + drwav_uint64 paddingData = 0; + drwav__write(pWav, &paddingData, paddingSize); /* Byte order does not matter for this. */ + } + + /* + Chunk sizes. When using sequential mode, these will have been filled in at initialization time. We only need + to do this when using non-sequential mode. + */ + if (pWav->onSeek && !pWav->isSequentialWrite) { + if (pWav->container == drwav_container_riff) { + /* The "RIFF" chunk size. */ + if (pWav->onSeek(pWav->pUserData, 4, drwav_seek_origin_start)) { + drwav_uint32 riffChunkSize = drwav__riff_chunk_size_riff(pWav->dataChunkDataSize); + drwav__write_u32ne_to_le(pWav, riffChunkSize); + } + + /* the "data" chunk size. */ + if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos + 4, drwav_seek_origin_start)) { + drwav_uint32 dataChunkSize = drwav__data_chunk_size_riff(pWav->dataChunkDataSize); + drwav__write_u32ne_to_le(pWav, dataChunkSize); + } + } else if (pWav->container == drwav_container_w64) { + /* The "RIFF" chunk size. */ + if (pWav->onSeek(pWav->pUserData, 16, drwav_seek_origin_start)) { + drwav_uint64 riffChunkSize = drwav__riff_chunk_size_w64(pWav->dataChunkDataSize); + drwav__write_u64ne_to_le(pWav, riffChunkSize); + } + + /* The "data" chunk size. 
*/ + if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos + 16, drwav_seek_origin_start)) { + drwav_uint64 dataChunkSize = drwav__data_chunk_size_w64(pWav->dataChunkDataSize); + drwav__write_u64ne_to_le(pWav, dataChunkSize); + } + } else if (pWav->container == drwav_container_rf64) { + /* We only need to update the ds64 chunk. The "RIFF" and "data" chunks always have their sizes set to 0xFFFFFFFF for RF64. */ + int ds64BodyPos = 12 + 8; + + /* The "RIFF" chunk size. */ + if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 0, drwav_seek_origin_start)) { + drwav_uint64 riffChunkSize = drwav__riff_chunk_size_rf64(pWav->dataChunkDataSize); + drwav__write_u64ne_to_le(pWav, riffChunkSize); + } + + /* The "data" chunk size. */ + if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 8, drwav_seek_origin_start)) { + drwav_uint64 dataChunkSize = drwav__data_chunk_size_rf64(pWav->dataChunkDataSize); + drwav__write_u64ne_to_le(pWav, dataChunkSize); + } + } + } + + /* Validation for sequential mode. */ + if (pWav->isSequentialWrite) { + if (pWav->dataChunkDataSize != pWav->dataChunkDataSizeTargetWrite) { + result = DRWAV_INVALID_FILE; + } + } + } + +#ifndef DR_WAV_NO_STDIO + /* + If we opened the file with drwav_open_file() we will want to close the file handle. We can know whether or not drwav_open_file() + was used by looking at the onRead and onSeek callbacks. + */ + if (pWav->onRead == drwav__on_read_stdio || pWav->onWrite == drwav__on_write_stdio) { + fclose((FILE*)pWav->pUserData); + } +#endif + + return result; +} + + + +DRWAV_API size_t drwav_read_raw(drwav* pWav, size_t bytesToRead, void* pBufferOut) +{ + size_t bytesRead; + + if (pWav == NULL || bytesToRead == 0) { + return 0; + } + + if (bytesToRead > pWav->bytesRemaining) { + bytesToRead = (size_t)pWav->bytesRemaining; + } + + if (pBufferOut != NULL) { + bytesRead = pWav->onRead(pWav->pUserData, pBufferOut, bytesToRead); + } else { + /* We need to seek. 
If we fail, we need to read-and-discard to make sure we get a good byte count. */ + bytesRead = 0; + while (bytesRead < bytesToRead) { + size_t bytesToSeek = (bytesToRead - bytesRead); + if (bytesToSeek > 0x7FFFFFFF) { + bytesToSeek = 0x7FFFFFFF; + } + + if (pWav->onSeek(pWav->pUserData, (int)bytesToSeek, drwav_seek_origin_current) == DRWAV_FALSE) { + break; + } + + bytesRead += bytesToSeek; + } + + /* When we get here we may need to read-and-discard some data. */ + while (bytesRead < bytesToRead) { + drwav_uint8 buffer[4096]; + size_t bytesSeeked; + size_t bytesToSeek = (bytesToRead - bytesRead); + if (bytesToSeek > sizeof(buffer)) { + bytesToSeek = sizeof(buffer); + } + + bytesSeeked = pWav->onRead(pWav->pUserData, buffer, bytesToSeek); + bytesRead += bytesSeeked; + + if (bytesSeeked < bytesToSeek) { + break; /* Reached the end. */ + } + } + } + + pWav->bytesRemaining -= bytesRead; + return bytesRead; +} + + + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_le(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut) +{ + drwav_uint32 bytesPerFrame; + drwav_uint64 bytesToRead; /* Intentionally uint64 instead of size_t so we can do a check that we're not reading too much on 32-bit builds. */ + + if (pWav == NULL || framesToRead == 0) { + return 0; + } + + /* Cannot use this function for compressed formats. */ + if (drwav__is_compressed_format_tag(pWav->translatedFormatTag)) { + return 0; + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + /* Don't try to read more samples than can potentially fit in the output buffer. */ + bytesToRead = framesToRead * bytesPerFrame; + if (bytesToRead > DRWAV_SIZE_MAX) { + bytesToRead = (DRWAV_SIZE_MAX / bytesPerFrame) * bytesPerFrame; /* Round the number of bytes to read to a clean frame boundary. */ + } + + /* + Doing an explicit check here just to make it clear that we don't want to be attempt to read anything if there's no bytes to read. 
There + *could* be a time where it evaluates to 0 due to overflowing. + */ + if (bytesToRead == 0) { + return 0; + } + + return drwav_read_raw(pWav, (size_t)bytesToRead, pBufferOut) / bytesPerFrame; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_be(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_le(pWav, framesToRead, pBufferOut); + + if (pBufferOut != NULL) { + drwav__bswap_samples(pBufferOut, framesRead*pWav->channels, drwav_get_bytes_per_pcm_frame(pWav)/pWav->channels, pWav->translatedFormatTag); + } + + return framesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut) +{ + if (drwav__is_little_endian()) { + return drwav_read_pcm_frames_le(pWav, framesToRead, pBufferOut); + } else { + return drwav_read_pcm_frames_be(pWav, framesToRead, pBufferOut); + } +} + + + +DRWAV_API drwav_bool32 drwav_seek_to_first_pcm_frame(drwav* pWav) +{ + if (pWav->onWrite != NULL) { + return DRWAV_FALSE; /* No seeking in write mode. */ + } + + if (!pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos, drwav_seek_origin_start)) { + return DRWAV_FALSE; + } + + if (drwav__is_compressed_format_tag(pWav->translatedFormatTag)) { + pWav->compressed.iCurrentPCMFrame = 0; + + /* Cached data needs to be cleared for compressed formats. */ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + DRWAV_ZERO_OBJECT(&pWav->msadpcm); + } else if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + DRWAV_ZERO_OBJECT(&pWav->ima); + } else { + DRWAV_ASSERT(DRWAV_FALSE); /* If this assertion is triggered it means I've implemented a new compressed format but forgot to add a branch for it here. */ + } + } + + pWav->bytesRemaining = pWav->dataChunkDataSize; + return DRWAV_TRUE; +} + +DRWAV_API drwav_bool32 drwav_seek_to_pcm_frame(drwav* pWav, drwav_uint64 targetFrameIndex) +{ + /* Seeking should be compatible with wave files > 2GB. 
*/ + + if (pWav == NULL || pWav->onSeek == NULL) { + return DRWAV_FALSE; + } + + /* No seeking in write mode. */ + if (pWav->onWrite != NULL) { + return DRWAV_FALSE; + } + + /* If there are no samples, just return DRWAV_TRUE without doing anything. */ + if (pWav->totalPCMFrameCount == 0) { + return DRWAV_TRUE; + } + + /* Make sure the sample is clamped. */ + if (targetFrameIndex >= pWav->totalPCMFrameCount) { + targetFrameIndex = pWav->totalPCMFrameCount - 1; + } + + /* + For compressed formats we just use a slow generic seek. If we are seeking forward we just seek forward. If we are going backwards we need + to seek back to the start. + */ + if (drwav__is_compressed_format_tag(pWav->translatedFormatTag)) { + /* TODO: This can be optimized. */ + + /* + If we're seeking forward it's simple - just keep reading samples until we hit the sample we're requesting. If we're seeking backwards, + we first need to seek back to the start and then just do the same thing as a forward seek. + */ + if (targetFrameIndex < pWav->compressed.iCurrentPCMFrame) { + if (!drwav_seek_to_first_pcm_frame(pWav)) { + return DRWAV_FALSE; + } + } + + if (targetFrameIndex > pWav->compressed.iCurrentPCMFrame) { + drwav_uint64 offsetInFrames = targetFrameIndex - pWav->compressed.iCurrentPCMFrame; + + drwav_int16 devnull[2048]; + while (offsetInFrames > 0) { + drwav_uint64 framesRead = 0; + drwav_uint64 framesToRead = offsetInFrames; + if (framesToRead > drwav_countof(devnull)/pWav->channels) { + framesToRead = drwav_countof(devnull)/pWav->channels; + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + framesRead = drwav_read_pcm_frames_s16__msadpcm(pWav, framesToRead, devnull); + } else if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + framesRead = drwav_read_pcm_frames_s16__ima(pWav, framesToRead, devnull); + } else { + DRWAV_ASSERT(DRWAV_FALSE); /* If this assertion is triggered it means I've implemented a new compressed format but forgot to add a branch for it here. 
*/ + } + + if (framesRead != framesToRead) { + return DRWAV_FALSE; + } + + offsetInFrames -= framesRead; + } + } + } else { + drwav_uint64 totalSizeInBytes; + drwav_uint64 currentBytePos; + drwav_uint64 targetBytePos; + drwav_uint64 offset; + + totalSizeInBytes = pWav->totalPCMFrameCount * drwav_get_bytes_per_pcm_frame(pWav); + DRWAV_ASSERT(totalSizeInBytes >= pWav->bytesRemaining); + + currentBytePos = totalSizeInBytes - pWav->bytesRemaining; + targetBytePos = targetFrameIndex * drwav_get_bytes_per_pcm_frame(pWav); + + if (currentBytePos < targetBytePos) { + /* Offset forwards. */ + offset = (targetBytePos - currentBytePos); + } else { + /* Offset backwards. */ + if (!drwav_seek_to_first_pcm_frame(pWav)) { + return DRWAV_FALSE; + } + offset = targetBytePos; + } + + while (offset > 0) { + int offset32 = ((offset > INT_MAX) ? INT_MAX : (int)offset); + if (!pWav->onSeek(pWav->pUserData, offset32, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + + pWav->bytesRemaining -= offset32; + offset -= offset32; + } + } + + return DRWAV_TRUE; +} + + +DRWAV_API size_t drwav_write_raw(drwav* pWav, size_t bytesToWrite, const void* pData) +{ + size_t bytesWritten; + + if (pWav == NULL || bytesToWrite == 0 || pData == NULL) { + return 0; + } + + bytesWritten = pWav->onWrite(pWav->pUserData, pData, bytesToWrite); + pWav->dataChunkDataSize += bytesWritten; + + return bytesWritten; +} + + +DRWAV_API drwav_uint64 drwav_write_pcm_frames_le(drwav* pWav, drwav_uint64 framesToWrite, const void* pData) +{ + drwav_uint64 bytesToWrite; + drwav_uint64 bytesWritten; + const drwav_uint8* pRunningData; + + if (pWav == NULL || framesToWrite == 0 || pData == NULL) { + return 0; + } + + bytesToWrite = ((framesToWrite * pWav->channels * pWav->bitsPerSample) / 8); + if (bytesToWrite > DRWAV_SIZE_MAX) { + return 0; + } + + bytesWritten = 0; + pRunningData = (const drwav_uint8*)pData; + + while (bytesToWrite > 0) { + size_t bytesJustWritten; + drwav_uint64 bytesToWriteThisIteration; + + 
bytesToWriteThisIteration = bytesToWrite; + DRWAV_ASSERT(bytesToWriteThisIteration <= DRWAV_SIZE_MAX); /* <-- This is checked above. */ + + bytesJustWritten = drwav_write_raw(pWav, (size_t)bytesToWriteThisIteration, pRunningData); + if (bytesJustWritten == 0) { + break; + } + + bytesToWrite -= bytesJustWritten; + bytesWritten += bytesJustWritten; + pRunningData += bytesJustWritten; + } + + return (bytesWritten * 8) / pWav->bitsPerSample / pWav->channels; +} + +DRWAV_API drwav_uint64 drwav_write_pcm_frames_be(drwav* pWav, drwav_uint64 framesToWrite, const void* pData) +{ + drwav_uint64 bytesToWrite; + drwav_uint64 bytesWritten; + drwav_uint32 bytesPerSample; + const drwav_uint8* pRunningData; + + if (pWav == NULL || framesToWrite == 0 || pData == NULL) { + return 0; + } + + bytesToWrite = ((framesToWrite * pWav->channels * pWav->bitsPerSample) / 8); + if (bytesToWrite > DRWAV_SIZE_MAX) { + return 0; + } + + bytesWritten = 0; + pRunningData = (const drwav_uint8*)pData; + + bytesPerSample = drwav_get_bytes_per_pcm_frame(pWav) / pWav->channels; + + while (bytesToWrite > 0) { + drwav_uint8 temp[4096]; + drwav_uint32 sampleCount; + size_t bytesJustWritten; + drwav_uint64 bytesToWriteThisIteration; + + bytesToWriteThisIteration = bytesToWrite; + DRWAV_ASSERT(bytesToWriteThisIteration <= DRWAV_SIZE_MAX); /* <-- This is checked above. */ + + /* + WAV files are always little-endian. We need to byte swap on big-endian architectures. Since our input buffer is read-only we need + to use an intermediary buffer for the conversion. 
+ */ + sampleCount = sizeof(temp)/bytesPerSample; + + if (bytesToWriteThisIteration > ((drwav_uint64)sampleCount)*bytesPerSample) { + bytesToWriteThisIteration = ((drwav_uint64)sampleCount)*bytesPerSample; + } + + DRWAV_COPY_MEMORY(temp, pRunningData, (size_t)bytesToWriteThisIteration); + drwav__bswap_samples(temp, sampleCount, bytesPerSample, pWav->translatedFormatTag); + + bytesJustWritten = drwav_write_raw(pWav, (size_t)bytesToWriteThisIteration, temp); + if (bytesJustWritten == 0) { + break; + } + + bytesToWrite -= bytesJustWritten; + bytesWritten += bytesJustWritten; + pRunningData += bytesJustWritten; + } + + return (bytesWritten * 8) / pWav->bitsPerSample / pWav->channels; +} + +DRWAV_API drwav_uint64 drwav_write_pcm_frames(drwav* pWav, drwav_uint64 framesToWrite, const void* pData) +{ + if (drwav__is_little_endian()) { + return drwav_write_pcm_frames_le(pWav, framesToWrite, pData); + } else { + return drwav_write_pcm_frames_be(pWav, framesToWrite, pData); + } +} + + +static drwav_uint64 drwav_read_pcm_frames_s16__msadpcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 totalFramesRead = 0; + + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(framesToRead > 0); + + /* TODO: Lots of room for optimization here. */ + + while (framesToRead > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) { + /* If there are no cached frames we need to load a new block. */ + if (pWav->msadpcm.cachedFrameCount == 0 && pWav->msadpcm.bytesRemainingInBlock == 0) { + if (pWav->channels == 1) { + /* Mono. 
*/ + drwav_uint8 header[7]; + if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) { + return totalFramesRead; + } + pWav->msadpcm.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); + + pWav->msadpcm.predictor[0] = header[0]; + pWav->msadpcm.delta[0] = drwav__bytes_to_s16(header + 1); + pWav->msadpcm.prevFrames[0][1] = (drwav_int32)drwav__bytes_to_s16(header + 3); + pWav->msadpcm.prevFrames[0][0] = (drwav_int32)drwav__bytes_to_s16(header + 5); + pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][0]; + pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[0][1]; + pWav->msadpcm.cachedFrameCount = 2; + } else { + /* Stereo. */ + drwav_uint8 header[14]; + if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) { + return totalFramesRead; + } + pWav->msadpcm.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); + + pWav->msadpcm.predictor[0] = header[0]; + pWav->msadpcm.predictor[1] = header[1]; + pWav->msadpcm.delta[0] = drwav__bytes_to_s16(header + 2); + pWav->msadpcm.delta[1] = drwav__bytes_to_s16(header + 4); + pWav->msadpcm.prevFrames[0][1] = (drwav_int32)drwav__bytes_to_s16(header + 6); + pWav->msadpcm.prevFrames[1][1] = (drwav_int32)drwav__bytes_to_s16(header + 8); + pWav->msadpcm.prevFrames[0][0] = (drwav_int32)drwav__bytes_to_s16(header + 10); + pWav->msadpcm.prevFrames[1][0] = (drwav_int32)drwav__bytes_to_s16(header + 12); + + pWav->msadpcm.cachedFrames[0] = pWav->msadpcm.prevFrames[0][0]; + pWav->msadpcm.cachedFrames[1] = pWav->msadpcm.prevFrames[1][0]; + pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][1]; + pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[1][1]; + pWav->msadpcm.cachedFrameCount = 2; + } + } + + /* Output anything that's cached. 
*/ + while (framesToRead > 0 && pWav->msadpcm.cachedFrameCount > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) { + if (pBufferOut != NULL) { + drwav_uint32 iSample = 0; + for (iSample = 0; iSample < pWav->channels; iSample += 1) { + pBufferOut[iSample] = (drwav_int16)pWav->msadpcm.cachedFrames[(drwav_countof(pWav->msadpcm.cachedFrames) - (pWav->msadpcm.cachedFrameCount*pWav->channels)) + iSample]; + } + + pBufferOut += pWav->channels; + } + + framesToRead -= 1; + totalFramesRead += 1; + pWav->compressed.iCurrentPCMFrame += 1; + pWav->msadpcm.cachedFrameCount -= 1; + } + + if (framesToRead == 0) { + return totalFramesRead; + } + + + /* + If there's nothing left in the cache, just go ahead and load more. If there's nothing left to load in the current block we just continue to the next + loop iteration which will trigger the loading of a new block. + */ + if (pWav->msadpcm.cachedFrameCount == 0) { + if (pWav->msadpcm.bytesRemainingInBlock == 0) { + continue; + } else { + static drwav_int32 adaptationTable[] = { + 230, 230, 230, 230, 307, 409, 512, 614, + 768, 614, 512, 409, 307, 230, 230, 230 + }; + static drwav_int32 coeff1Table[] = { 256, 512, 0, 192, 240, 460, 392 }; + static drwav_int32 coeff2Table[] = { 0, -256, 0, 64, 0, -208, -232 }; + + drwav_uint8 nibbles; + drwav_int32 nibble0; + drwav_int32 nibble1; + + if (pWav->onRead(pWav->pUserData, &nibbles, 1) != 1) { + return totalFramesRead; + } + pWav->msadpcm.bytesRemainingInBlock -= 1; + + /* TODO: Optimize away these if statements. */ + nibble0 = ((nibbles & 0xF0) >> 4); if ((nibbles & 0x80)) { nibble0 |= 0xFFFFFFF0UL; } + nibble1 = ((nibbles & 0x0F) >> 0); if ((nibbles & 0x08)) { nibble1 |= 0xFFFFFFF0UL; } + + if (pWav->channels == 1) { + /* Mono. 
*/ + drwav_int32 newSample0; + drwav_int32 newSample1; + + newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; + newSample0 += nibble0 * pWav->msadpcm.delta[0]; + newSample0 = drwav_clamp(newSample0, -32768, 32767); + + pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0xF0) >> 4)] * pWav->msadpcm.delta[0]) >> 8; + if (pWav->msadpcm.delta[0] < 16) { + pWav->msadpcm.delta[0] = 16; + } + + pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1]; + pWav->msadpcm.prevFrames[0][1] = newSample0; + + + newSample1 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; + newSample1 += nibble1 * pWav->msadpcm.delta[0]; + newSample1 = drwav_clamp(newSample1, -32768, 32767); + + pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0x0F) >> 0)] * pWav->msadpcm.delta[0]) >> 8; + if (pWav->msadpcm.delta[0] < 16) { + pWav->msadpcm.delta[0] = 16; + } + + pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1]; + pWav->msadpcm.prevFrames[0][1] = newSample1; + + + pWav->msadpcm.cachedFrames[2] = newSample0; + pWav->msadpcm.cachedFrames[3] = newSample1; + pWav->msadpcm.cachedFrameCount = 2; + } else { + /* Stereo. */ + drwav_int32 newSample0; + drwav_int32 newSample1; + + /* Left. 
*/ + newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; + newSample0 += nibble0 * pWav->msadpcm.delta[0]; + newSample0 = drwav_clamp(newSample0, -32768, 32767); + + pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0xF0) >> 4)] * pWav->msadpcm.delta[0]) >> 8; + if (pWav->msadpcm.delta[0] < 16) { + pWav->msadpcm.delta[0] = 16; + } + + pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1]; + pWav->msadpcm.prevFrames[0][1] = newSample0; + + + /* Right. */ + newSample1 = ((pWav->msadpcm.prevFrames[1][1] * coeff1Table[pWav->msadpcm.predictor[1]]) + (pWav->msadpcm.prevFrames[1][0] * coeff2Table[pWav->msadpcm.predictor[1]])) >> 8; + newSample1 += nibble1 * pWav->msadpcm.delta[1]; + newSample1 = drwav_clamp(newSample1, -32768, 32767); + + pWav->msadpcm.delta[1] = (adaptationTable[((nibbles & 0x0F) >> 0)] * pWav->msadpcm.delta[1]) >> 8; + if (pWav->msadpcm.delta[1] < 16) { + pWav->msadpcm.delta[1] = 16; + } + + pWav->msadpcm.prevFrames[1][0] = pWav->msadpcm.prevFrames[1][1]; + pWav->msadpcm.prevFrames[1][1] = newSample1; + + pWav->msadpcm.cachedFrames[2] = newSample0; + pWav->msadpcm.cachedFrames[3] = newSample1; + pWav->msadpcm.cachedFrameCount = 1; + } + } + } + } + + return totalFramesRead; +} + + +static drwav_uint64 drwav_read_pcm_frames_s16__ima(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 totalFramesRead = 0; + drwav_uint32 iChannel; + + static drwav_int32 indexTable[16] = { + -1, -1, -1, -1, 2, 4, 6, 8, + -1, -1, -1, -1, 2, 4, 6, 8 + }; + + static drwav_int32 stepTable[89] = { + 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, + 19, 21, 23, 25, 28, 31, 34, 37, 41, 45, + 50, 55, 60, 66, 73, 80, 88, 97, 107, 118, + 130, 143, 157, 173, 190, 209, 230, 253, 279, 307, + 337, 371, 408, 449, 494, 544, 598, 658, 724, 796, + 876, 963, 1060, 1166, 1282, 1411, 1552, 1707, 1878, 2066, + 2272, 2499, 2749, 3024, 3327, 
3660, 4026, 4428, 4871, 5358, + 5894, 6484, 7132, 7845, 8630, 9493, 10442, 11487, 12635, 13899, + 15289, 16818, 18500, 20350, 22385, 24623, 27086, 29794, 32767 + }; + + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(framesToRead > 0); + + /* TODO: Lots of room for optimization here. */ + + while (framesToRead > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) { + /* If there are no cached samples we need to load a new block. */ + if (pWav->ima.cachedFrameCount == 0 && pWav->ima.bytesRemainingInBlock == 0) { + if (pWav->channels == 1) { + /* Mono. */ + drwav_uint8 header[4]; + if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) { + return totalFramesRead; + } + pWav->ima.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); + + if (header[2] >= drwav_countof(stepTable)) { + pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, drwav_seek_origin_current); + pWav->ima.bytesRemainingInBlock = 0; + return totalFramesRead; /* Invalid data. */ + } + + pWav->ima.predictor[0] = drwav__bytes_to_s16(header + 0); + pWav->ima.stepIndex[0] = header[2]; + pWav->ima.cachedFrames[drwav_countof(pWav->ima.cachedFrames) - 1] = pWav->ima.predictor[0]; + pWav->ima.cachedFrameCount = 1; + } else { + /* Stereo. */ + drwav_uint8 header[8]; + if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) { + return totalFramesRead; + } + pWav->ima.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); + + if (header[2] >= drwav_countof(stepTable) || header[6] >= drwav_countof(stepTable)) { + pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, drwav_seek_origin_current); + pWav->ima.bytesRemainingInBlock = 0; + return totalFramesRead; /* Invalid data. 
*/ + } + + pWav->ima.predictor[0] = drwav__bytes_to_s16(header + 0); + pWav->ima.stepIndex[0] = header[2]; + pWav->ima.predictor[1] = drwav__bytes_to_s16(header + 4); + pWav->ima.stepIndex[1] = header[6]; + + pWav->ima.cachedFrames[drwav_countof(pWav->ima.cachedFrames) - 2] = pWav->ima.predictor[0]; + pWav->ima.cachedFrames[drwav_countof(pWav->ima.cachedFrames) - 1] = pWav->ima.predictor[1]; + pWav->ima.cachedFrameCount = 1; + } + } + + /* Output anything that's cached. */ + while (framesToRead > 0 && pWav->ima.cachedFrameCount > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) { + if (pBufferOut != NULL) { + drwav_uint32 iSample; + for (iSample = 0; iSample < pWav->channels; iSample += 1) { + pBufferOut[iSample] = (drwav_int16)pWav->ima.cachedFrames[(drwav_countof(pWav->ima.cachedFrames) - (pWav->ima.cachedFrameCount*pWav->channels)) + iSample]; + } + pBufferOut += pWav->channels; + } + + framesToRead -= 1; + totalFramesRead += 1; + pWav->compressed.iCurrentPCMFrame += 1; + pWav->ima.cachedFrameCount -= 1; + } + + if (framesToRead == 0) { + return totalFramesRead; + } + + /* + If there's nothing left in the cache, just go ahead and load more. If there's nothing left to load in the current block we just continue to the next + loop iteration which will trigger the loading of a new block. + */ + if (pWav->ima.cachedFrameCount == 0) { + if (pWav->ima.bytesRemainingInBlock == 0) { + continue; + } else { + /* + From what I can tell with stereo streams, it looks like every 4 bytes (8 samples) is for one channel. So it goes 4 bytes for the + left channel, 4 bytes for the right channel. 
+ */ + pWav->ima.cachedFrameCount = 8; + for (iChannel = 0; iChannel < pWav->channels; ++iChannel) { + drwav_uint32 iByte; + drwav_uint8 nibbles[4]; + if (pWav->onRead(pWav->pUserData, &nibbles, 4) != 4) { + pWav->ima.cachedFrameCount = 0; + return totalFramesRead; + } + pWav->ima.bytesRemainingInBlock -= 4; + + for (iByte = 0; iByte < 4; ++iByte) { + drwav_uint8 nibble0 = ((nibbles[iByte] & 0x0F) >> 0); + drwav_uint8 nibble1 = ((nibbles[iByte] & 0xF0) >> 4); + + drwav_int32 step = stepTable[pWav->ima.stepIndex[iChannel]]; + drwav_int32 predictor = pWav->ima.predictor[iChannel]; + + drwav_int32 diff = step >> 3; + if (nibble0 & 1) diff += step >> 2; + if (nibble0 & 2) diff += step >> 1; + if (nibble0 & 4) diff += step; + if (nibble0 & 8) diff = -diff; + + predictor = drwav_clamp(predictor + diff, -32768, 32767); + pWav->ima.predictor[iChannel] = predictor; + pWav->ima.stepIndex[iChannel] = drwav_clamp(pWav->ima.stepIndex[iChannel] + indexTable[nibble0], 0, (drwav_int32)drwav_countof(stepTable)-1); + pWav->ima.cachedFrames[(drwav_countof(pWav->ima.cachedFrames) - (pWav->ima.cachedFrameCount*pWav->channels)) + (iByte*2+0)*pWav->channels + iChannel] = predictor; + + + step = stepTable[pWav->ima.stepIndex[iChannel]]; + predictor = pWav->ima.predictor[iChannel]; + + diff = step >> 3; + if (nibble1 & 1) diff += step >> 2; + if (nibble1 & 2) diff += step >> 1; + if (nibble1 & 4) diff += step; + if (nibble1 & 8) diff = -diff; + + predictor = drwav_clamp(predictor + diff, -32768, 32767); + pWav->ima.predictor[iChannel] = predictor; + pWav->ima.stepIndex[iChannel] = drwav_clamp(pWav->ima.stepIndex[iChannel] + indexTable[nibble1], 0, (drwav_int32)drwav_countof(stepTable)-1); + pWav->ima.cachedFrames[(drwav_countof(pWav->ima.cachedFrames) - (pWav->ima.cachedFrameCount*pWav->channels)) + (iByte*2+1)*pWav->channels + iChannel] = predictor; + } + } + } + } + } + + return totalFramesRead; +} + + +#ifndef DR_WAV_NO_CONVERSION_API +static unsigned short g_drwavAlawTable[256] = { + 
0xEA80, 0xEB80, 0xE880, 0xE980, 0xEE80, 0xEF80, 0xEC80, 0xED80, 0xE280, 0xE380, 0xE080, 0xE180, 0xE680, 0xE780, 0xE480, 0xE580, + 0xF540, 0xF5C0, 0xF440, 0xF4C0, 0xF740, 0xF7C0, 0xF640, 0xF6C0, 0xF140, 0xF1C0, 0xF040, 0xF0C0, 0xF340, 0xF3C0, 0xF240, 0xF2C0, + 0xAA00, 0xAE00, 0xA200, 0xA600, 0xBA00, 0xBE00, 0xB200, 0xB600, 0x8A00, 0x8E00, 0x8200, 0x8600, 0x9A00, 0x9E00, 0x9200, 0x9600, + 0xD500, 0xD700, 0xD100, 0xD300, 0xDD00, 0xDF00, 0xD900, 0xDB00, 0xC500, 0xC700, 0xC100, 0xC300, 0xCD00, 0xCF00, 0xC900, 0xCB00, + 0xFEA8, 0xFEB8, 0xFE88, 0xFE98, 0xFEE8, 0xFEF8, 0xFEC8, 0xFED8, 0xFE28, 0xFE38, 0xFE08, 0xFE18, 0xFE68, 0xFE78, 0xFE48, 0xFE58, + 0xFFA8, 0xFFB8, 0xFF88, 0xFF98, 0xFFE8, 0xFFF8, 0xFFC8, 0xFFD8, 0xFF28, 0xFF38, 0xFF08, 0xFF18, 0xFF68, 0xFF78, 0xFF48, 0xFF58, + 0xFAA0, 0xFAE0, 0xFA20, 0xFA60, 0xFBA0, 0xFBE0, 0xFB20, 0xFB60, 0xF8A0, 0xF8E0, 0xF820, 0xF860, 0xF9A0, 0xF9E0, 0xF920, 0xF960, + 0xFD50, 0xFD70, 0xFD10, 0xFD30, 0xFDD0, 0xFDF0, 0xFD90, 0xFDB0, 0xFC50, 0xFC70, 0xFC10, 0xFC30, 0xFCD0, 0xFCF0, 0xFC90, 0xFCB0, + 0x1580, 0x1480, 0x1780, 0x1680, 0x1180, 0x1080, 0x1380, 0x1280, 0x1D80, 0x1C80, 0x1F80, 0x1E80, 0x1980, 0x1880, 0x1B80, 0x1A80, + 0x0AC0, 0x0A40, 0x0BC0, 0x0B40, 0x08C0, 0x0840, 0x09C0, 0x0940, 0x0EC0, 0x0E40, 0x0FC0, 0x0F40, 0x0CC0, 0x0C40, 0x0DC0, 0x0D40, + 0x5600, 0x5200, 0x5E00, 0x5A00, 0x4600, 0x4200, 0x4E00, 0x4A00, 0x7600, 0x7200, 0x7E00, 0x7A00, 0x6600, 0x6200, 0x6E00, 0x6A00, + 0x2B00, 0x2900, 0x2F00, 0x2D00, 0x2300, 0x2100, 0x2700, 0x2500, 0x3B00, 0x3900, 0x3F00, 0x3D00, 0x3300, 0x3100, 0x3700, 0x3500, + 0x0158, 0x0148, 0x0178, 0x0168, 0x0118, 0x0108, 0x0138, 0x0128, 0x01D8, 0x01C8, 0x01F8, 0x01E8, 0x0198, 0x0188, 0x01B8, 0x01A8, + 0x0058, 0x0048, 0x0078, 0x0068, 0x0018, 0x0008, 0x0038, 0x0028, 0x00D8, 0x00C8, 0x00F8, 0x00E8, 0x0098, 0x0088, 0x00B8, 0x00A8, + 0x0560, 0x0520, 0x05E0, 0x05A0, 0x0460, 0x0420, 0x04E0, 0x04A0, 0x0760, 0x0720, 0x07E0, 0x07A0, 0x0660, 0x0620, 0x06E0, 0x06A0, + 0x02B0, 0x0290, 0x02F0, 0x02D0, 0x0230, 0x0210, 
0x0270, 0x0250, 0x03B0, 0x0390, 0x03F0, 0x03D0, 0x0330, 0x0310, 0x0370, 0x0350 +}; + +static unsigned short g_drwavMulawTable[256] = { + 0x8284, 0x8684, 0x8A84, 0x8E84, 0x9284, 0x9684, 0x9A84, 0x9E84, 0xA284, 0xA684, 0xAA84, 0xAE84, 0xB284, 0xB684, 0xBA84, 0xBE84, + 0xC184, 0xC384, 0xC584, 0xC784, 0xC984, 0xCB84, 0xCD84, 0xCF84, 0xD184, 0xD384, 0xD584, 0xD784, 0xD984, 0xDB84, 0xDD84, 0xDF84, + 0xE104, 0xE204, 0xE304, 0xE404, 0xE504, 0xE604, 0xE704, 0xE804, 0xE904, 0xEA04, 0xEB04, 0xEC04, 0xED04, 0xEE04, 0xEF04, 0xF004, + 0xF0C4, 0xF144, 0xF1C4, 0xF244, 0xF2C4, 0xF344, 0xF3C4, 0xF444, 0xF4C4, 0xF544, 0xF5C4, 0xF644, 0xF6C4, 0xF744, 0xF7C4, 0xF844, + 0xF8A4, 0xF8E4, 0xF924, 0xF964, 0xF9A4, 0xF9E4, 0xFA24, 0xFA64, 0xFAA4, 0xFAE4, 0xFB24, 0xFB64, 0xFBA4, 0xFBE4, 0xFC24, 0xFC64, + 0xFC94, 0xFCB4, 0xFCD4, 0xFCF4, 0xFD14, 0xFD34, 0xFD54, 0xFD74, 0xFD94, 0xFDB4, 0xFDD4, 0xFDF4, 0xFE14, 0xFE34, 0xFE54, 0xFE74, + 0xFE8C, 0xFE9C, 0xFEAC, 0xFEBC, 0xFECC, 0xFEDC, 0xFEEC, 0xFEFC, 0xFF0C, 0xFF1C, 0xFF2C, 0xFF3C, 0xFF4C, 0xFF5C, 0xFF6C, 0xFF7C, + 0xFF88, 0xFF90, 0xFF98, 0xFFA0, 0xFFA8, 0xFFB0, 0xFFB8, 0xFFC0, 0xFFC8, 0xFFD0, 0xFFD8, 0xFFE0, 0xFFE8, 0xFFF0, 0xFFF8, 0x0000, + 0x7D7C, 0x797C, 0x757C, 0x717C, 0x6D7C, 0x697C, 0x657C, 0x617C, 0x5D7C, 0x597C, 0x557C, 0x517C, 0x4D7C, 0x497C, 0x457C, 0x417C, + 0x3E7C, 0x3C7C, 0x3A7C, 0x387C, 0x367C, 0x347C, 0x327C, 0x307C, 0x2E7C, 0x2C7C, 0x2A7C, 0x287C, 0x267C, 0x247C, 0x227C, 0x207C, + 0x1EFC, 0x1DFC, 0x1CFC, 0x1BFC, 0x1AFC, 0x19FC, 0x18FC, 0x17FC, 0x16FC, 0x15FC, 0x14FC, 0x13FC, 0x12FC, 0x11FC, 0x10FC, 0x0FFC, + 0x0F3C, 0x0EBC, 0x0E3C, 0x0DBC, 0x0D3C, 0x0CBC, 0x0C3C, 0x0BBC, 0x0B3C, 0x0ABC, 0x0A3C, 0x09BC, 0x093C, 0x08BC, 0x083C, 0x07BC, + 0x075C, 0x071C, 0x06DC, 0x069C, 0x065C, 0x061C, 0x05DC, 0x059C, 0x055C, 0x051C, 0x04DC, 0x049C, 0x045C, 0x041C, 0x03DC, 0x039C, + 0x036C, 0x034C, 0x032C, 0x030C, 0x02EC, 0x02CC, 0x02AC, 0x028C, 0x026C, 0x024C, 0x022C, 0x020C, 0x01EC, 0x01CC, 0x01AC, 0x018C, + 0x0174, 0x0164, 0x0154, 0x0144, 0x0134, 
0x0124, 0x0114, 0x0104, 0x00F4, 0x00E4, 0x00D4, 0x00C4, 0x00B4, 0x00A4, 0x0094, 0x0084, + 0x0078, 0x0070, 0x0068, 0x0060, 0x0058, 0x0050, 0x0048, 0x0040, 0x0038, 0x0030, 0x0028, 0x0020, 0x0018, 0x0010, 0x0008, 0x0000 +}; + +static DRWAV_INLINE drwav_int16 drwav__alaw_to_s16(drwav_uint8 sampleIn) +{ + return (short)g_drwavAlawTable[sampleIn]; +} + +static DRWAV_INLINE drwav_int16 drwav__mulaw_to_s16(drwav_uint8 sampleIn) +{ + return (short)g_drwavMulawTable[sampleIn]; +} + + + +static void drwav__pcm_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample) +{ + unsigned int i; + + /* Special case for 8-bit sample data because it's treated as unsigned. */ + if (bytesPerSample == 1) { + drwav_u8_to_s16(pOut, pIn, totalSampleCount); + return; + } + + + /* Slightly more optimal implementation for common formats. */ + if (bytesPerSample == 2) { + for (i = 0; i < totalSampleCount; ++i) { + *pOut++ = ((const drwav_int16*)pIn)[i]; + } + return; + } + if (bytesPerSample == 3) { + drwav_s24_to_s16(pOut, pIn, totalSampleCount); + return; + } + if (bytesPerSample == 4) { + drwav_s32_to_s16(pOut, (const drwav_int32*)pIn, totalSampleCount); + return; + } + + + /* Anything more than 64 bits per sample is not supported. */ + if (bytesPerSample > 8) { + DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut)); + return; + } + + + /* Generic, slow converter. 
*/ + for (i = 0; i < totalSampleCount; ++i) { + drwav_uint64 sample = 0; + unsigned int shift = (8 - bytesPerSample) * 8; + + unsigned int j; + for (j = 0; j < bytesPerSample; j += 1) { + DRWAV_ASSERT(j < 8); + sample |= (drwav_uint64)(pIn[j]) << shift; + shift += 8; + } + + pIn += j; + *pOut++ = (drwav_int16)((drwav_int64)sample >> 48); + } +} + +static void drwav__ieee_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample) +{ + if (bytesPerSample == 4) { + drwav_f32_to_s16(pOut, (const float*)pIn, totalSampleCount); + return; + } else if (bytesPerSample == 8) { + drwav_f64_to_s16(pOut, (const double*)pIn, totalSampleCount); + return; + } else { + /* Only supporting 32- and 64-bit float. Output silence in all other cases. Contributions welcome for 16-bit float. */ + DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut)); + return; + } +} + +static drwav_uint64 drwav_read_pcm_frames_s16__pcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint32 bytesPerFrame; + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + /* Fast path. 
*/ + if ((pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM && pWav->bitsPerSample == 16) || pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__pcm_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s16__ieee(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame; + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__ieee_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s16__alaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame; + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + bytesPerFrame = 
drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_alaw_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s16__mulaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame; + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_mulaw_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + if (pWav == NULL || framesToRead == 0) { + return 0; + } + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + /* Don't try to read more samples than can potentially fit in the output buffer. 
*/ + if (framesToRead * pWav->channels * sizeof(drwav_int16) > DRWAV_SIZE_MAX) { + framesToRead = DRWAV_SIZE_MAX / sizeof(drwav_int16) / pWav->channels; + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM) { + return drwav_read_pcm_frames_s16__pcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT) { + return drwav_read_pcm_frames_s16__ieee(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ALAW) { + return drwav_read_pcm_frames_s16__alaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_MULAW) { + return drwav_read_pcm_frames_s16__mulaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + return drwav_read_pcm_frames_s16__msadpcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + return drwav_read_pcm_frames_s16__ima(pWav, framesToRead, pBufferOut); + } + + return 0; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16le(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_FALSE) { + drwav__bswap_samples_s16(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16be(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_TRUE) { + drwav__bswap_samples_s16(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + + +DRWAV_API void drwav_u8_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + int r; + size_t i; + for (i = 0; i < sampleCount; ++i) { + int x = pIn[i]; + r = x << 8; + r = r - 32768; + pOut[i] = 
(short)r; + } +} + +DRWAV_API void drwav_s24_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + int r; + size_t i; + for (i = 0; i < sampleCount; ++i) { + int x = ((int)(((unsigned int)(((const drwav_uint8*)pIn)[i*3+0]) << 8) | ((unsigned int)(((const drwav_uint8*)pIn)[i*3+1]) << 16) | ((unsigned int)(((const drwav_uint8*)pIn)[i*3+2])) << 24)) >> 8; + r = x >> 8; + pOut[i] = (short)r; + } +} + +DRWAV_API void drwav_s32_to_s16(drwav_int16* pOut, const drwav_int32* pIn, size_t sampleCount) +{ + int r; + size_t i; + for (i = 0; i < sampleCount; ++i) { + int x = pIn[i]; + r = x >> 16; + pOut[i] = (short)r; + } +} + +DRWAV_API void drwav_f32_to_s16(drwav_int16* pOut, const float* pIn, size_t sampleCount) +{ + int r; + size_t i; + for (i = 0; i < sampleCount; ++i) { + float x = pIn[i]; + float c; + c = ((x < -1) ? -1 : ((x > 1) ? 1 : x)); + c = c + 1; + r = (int)(c * 32767.5f); + r = r - 32768; + pOut[i] = (short)r; + } +} + +DRWAV_API void drwav_f64_to_s16(drwav_int16* pOut, const double* pIn, size_t sampleCount) +{ + int r; + size_t i; + for (i = 0; i < sampleCount; ++i) { + double x = pIn[i]; + double c; + c = ((x < -1) ? -1 : ((x > 1) ? 1 : x)); + c = c + 1; + r = (int)(c * 32767.5); + r = r - 32768; + pOut[i] = (short)r; + } +} + +DRWAV_API void drwav_alaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + for (i = 0; i < sampleCount; ++i) { + pOut[i] = drwav__alaw_to_s16(pIn[i]); + } +} + +DRWAV_API void drwav_mulaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + for (i = 0; i < sampleCount; ++i) { + pOut[i] = drwav__mulaw_to_s16(pIn[i]); + } +} + + + +static void drwav__pcm_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount, unsigned int bytesPerSample) +{ + unsigned int i; + + /* Special case for 8-bit sample data because it's treated as unsigned. 
*/ + if (bytesPerSample == 1) { + drwav_u8_to_f32(pOut, pIn, sampleCount); + return; + } + + /* Slightly more optimal implementation for common formats. */ + if (bytesPerSample == 2) { + drwav_s16_to_f32(pOut, (const drwav_int16*)pIn, sampleCount); + return; + } + if (bytesPerSample == 3) { + drwav_s24_to_f32(pOut, pIn, sampleCount); + return; + } + if (bytesPerSample == 4) { + drwav_s32_to_f32(pOut, (const drwav_int32*)pIn, sampleCount); + return; + } + + + /* Anything more than 64 bits per sample is not supported. */ + if (bytesPerSample > 8) { + DRWAV_ZERO_MEMORY(pOut, sampleCount * sizeof(*pOut)); + return; + } + + + /* Generic, slow converter. */ + for (i = 0; i < sampleCount; ++i) { + drwav_uint64 sample = 0; + unsigned int shift = (8 - bytesPerSample) * 8; + + unsigned int j; + for (j = 0; j < bytesPerSample; j += 1) { + DRWAV_ASSERT(j < 8); + sample |= (drwav_uint64)(pIn[j]) << shift; + shift += 8; + } + + pIn += j; + *pOut++ = (float)((drwav_int64)sample / 9223372036854775807.0); + } +} + +static void drwav__ieee_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount, unsigned int bytesPerSample) +{ + if (bytesPerSample == 4) { + unsigned int i; + for (i = 0; i < sampleCount; ++i) { + *pOut++ = ((const float*)pIn)[i]; + } + return; + } else if (bytesPerSample == 8) { + drwav_f64_to_f32(pOut, (const double*)pIn, sampleCount); + return; + } else { + /* Only supporting 32- and 64-bit float. Output silence in all other cases. Contributions welcome for 16-bit float. 
*/ + DRWAV_ZERO_MEMORY(pOut, sampleCount * sizeof(*pOut)); + return; + } +} + + +static drwav_uint64 drwav_read_pcm_frames_f32__pcm(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__pcm_to_f32(pBufferOut, sampleData, (size_t)framesRead*pWav->channels, bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_f32__msadpcm(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + /* + We're just going to borrow the implementation from the drwav_read_s16() since ADPCM is a little bit more complicated than other formats and I don't + want to duplicate that code. + */ + drwav_uint64 totalFramesRead = 0; + drwav_int16 samples16[2048]; + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16); + if (framesRead == 0) { + break; + } + + drwav_s16_to_f32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. 
*/ + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_f32__ima(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + /* + We're just going to borrow the implementation from the drwav_read_s16() since IMA-ADPCM is a little bit more complicated than other formats and I don't + want to duplicate that code. + */ + drwav_uint64 totalFramesRead = 0; + drwav_int16 samples16[2048]; + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16); + if (framesRead == 0) { + break; + } + + drwav_s16_to_f32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. */ + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_f32__ieee(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame; + + /* Fast path. 
*/ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT && pWav->bitsPerSample == 32) { + return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__ieee_to_f32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_f32__alaw(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_alaw_to_f32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_f32__mulaw(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + 
break; + } + + drwav_mulaw_to_f32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + if (pWav == NULL || framesToRead == 0) { + return 0; + } + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + /* Don't try to read more samples than can potentially fit in the output buffer. */ + if (framesToRead * pWav->channels * sizeof(float) > DRWAV_SIZE_MAX) { + framesToRead = DRWAV_SIZE_MAX / sizeof(float) / pWav->channels; + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM) { + return drwav_read_pcm_frames_f32__pcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + return drwav_read_pcm_frames_f32__msadpcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT) { + return drwav_read_pcm_frames_f32__ieee(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ALAW) { + return drwav_read_pcm_frames_f32__alaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_MULAW) { + return drwav_read_pcm_frames_f32__mulaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + return drwav_read_pcm_frames_f32__ima(pWav, framesToRead, pBufferOut); + } + + return 0; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32le(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_f32(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_FALSE) { + drwav__bswap_samples_f32(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + +DRWAV_API drwav_uint64 
drwav_read_pcm_frames_f32be(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_f32(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_TRUE) { + drwav__bswap_samples_f32(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + + +DRWAV_API void drwav_u8_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + +#ifdef DR_WAV_LIBSNDFILE_COMPAT + /* + It appears libsndfile uses slightly different logic for the u8 -> f32 conversion to dr_wav, which in my opinion is incorrect. It appears + libsndfile performs the conversion something like "f32 = (u8 / 256) * 2 - 1", however I think it should be "f32 = (u8 / 255) * 2 - 1" (note + the divisor of 256 vs 255). I use libsndfile as a benchmark for testing, so I'm therefore leaving this block here just for my automated + correctness testing. This is disabled by default. 
+ */ + for (i = 0; i < sampleCount; ++i) { + *pOut++ = (pIn[i] / 256.0f) * 2 - 1; + } +#else + for (i = 0; i < sampleCount; ++i) { + float x = pIn[i]; + x = x * 0.00784313725490196078f; /* 0..255 to 0..2 */ + x = x - 1; /* 0..2 to -1..1 */ + + *pOut++ = x; + } +#endif +} + +DRWAV_API void drwav_s16_to_f32(float* pOut, const drwav_int16* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = pIn[i] * 0.000030517578125f; + } +} + +DRWAV_API void drwav_s24_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + double x; + drwav_uint32 a = ((drwav_uint32)(pIn[i*3+0]) << 8); + drwav_uint32 b = ((drwav_uint32)(pIn[i*3+1]) << 16); + drwav_uint32 c = ((drwav_uint32)(pIn[i*3+2]) << 24); + + x = (double)((drwav_int32)(a | b | c) >> 8); + *pOut++ = (float)(x * 0.00000011920928955078125); + } +} + +DRWAV_API void drwav_s32_to_f32(float* pOut, const drwav_int32* pIn, size_t sampleCount) +{ + size_t i; + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = (float)(pIn[i] / 2147483648.0); + } +} + +DRWAV_API void drwav_f64_to_f32(float* pOut, const double* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = (float)pIn[i]; + } +} + +DRWAV_API void drwav_alaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = drwav__alaw_to_s16(pIn[i]) / 32768.0f; + } +} + +DRWAV_API void drwav_mulaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = drwav__mulaw_to_s16(pIn[i]) / 32768.0f; + } 
+} + + + +static void drwav__pcm_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample) +{ + unsigned int i; + + /* Special case for 8-bit sample data because it's treated as unsigned. */ + if (bytesPerSample == 1) { + drwav_u8_to_s32(pOut, pIn, totalSampleCount); + return; + } + + /* Slightly more optimal implementation for common formats. */ + if (bytesPerSample == 2) { + drwav_s16_to_s32(pOut, (const drwav_int16*)pIn, totalSampleCount); + return; + } + if (bytesPerSample == 3) { + drwav_s24_to_s32(pOut, pIn, totalSampleCount); + return; + } + if (bytesPerSample == 4) { + for (i = 0; i < totalSampleCount; ++i) { + *pOut++ = ((const drwav_int32*)pIn)[i]; + } + return; + } + + + /* Anything more than 64 bits per sample is not supported. */ + if (bytesPerSample > 8) { + DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut)); + return; + } + + + /* Generic, slow converter. */ + for (i = 0; i < totalSampleCount; ++i) { + drwav_uint64 sample = 0; + unsigned int shift = (8 - bytesPerSample) * 8; + + unsigned int j; + for (j = 0; j < bytesPerSample; j += 1) { + DRWAV_ASSERT(j < 8); + sample |= (drwav_uint64)(pIn[j]) << shift; + shift += 8; + } + + pIn += j; + *pOut++ = (drwav_int32)((drwav_int64)sample >> 32); + } +} + +static void drwav__ieee_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample) +{ + if (bytesPerSample == 4) { + drwav_f32_to_s32(pOut, (const float*)pIn, totalSampleCount); + return; + } else if (bytesPerSample == 8) { + drwav_f64_to_s32(pOut, (const double*)pIn, totalSampleCount); + return; + } else { + /* Only supporting 32- and 64-bit float. Output silence in all other cases. Contributions welcome for 16-bit float. 
*/ + DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut)); + return; + } +} + + +static drwav_uint64 drwav_read_pcm_frames_s32__pcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame; + + /* Fast path. */ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM && pWav->bitsPerSample == 32) { + return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__pcm_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s32__msadpcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + /* + We're just going to borrow the implementation from the drwav_read_s16() since ADPCM is a little bit more complicated than other formats and I don't + want to duplicate that code. + */ + drwav_uint64 totalFramesRead = 0; + drwav_int16 samples16[2048]; + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16); + if (framesRead == 0) { + break; + } + + drwav_s16_to_s32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. 
*/ + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s32__ima(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + /* + We're just going to borrow the implementation from the drwav_read_s16() since IMA-ADPCM is a little bit more complicated than other formats and I don't + want to duplicate that code. + */ + drwav_uint64 totalFramesRead = 0; + drwav_int16 samples16[2048]; + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16); + if (framesRead == 0) { + break; + } + + drwav_s16_to_s32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. */ + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s32__ieee(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__ieee_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s32__alaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + drwav_uint32 
bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_alaw_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s32__mulaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_mulaw_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + if (pWav == NULL || framesToRead == 0) { + return 0; + } + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + /* Don't try to read more samples than can potentially fit in the output buffer. 
*/ + if (framesToRead * pWav->channels * sizeof(drwav_int32) > DRWAV_SIZE_MAX) { + framesToRead = DRWAV_SIZE_MAX / sizeof(drwav_int32) / pWav->channels; + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM) { + return drwav_read_pcm_frames_s32__pcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + return drwav_read_pcm_frames_s32__msadpcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT) { + return drwav_read_pcm_frames_s32__ieee(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ALAW) { + return drwav_read_pcm_frames_s32__alaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_MULAW) { + return drwav_read_pcm_frames_s32__mulaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + return drwav_read_pcm_frames_s32__ima(pWav, framesToRead, pBufferOut); + } + + return 0; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32le(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_s32(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_FALSE) { + drwav__bswap_samples_s32(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32be(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_s32(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_TRUE) { + drwav__bswap_samples_s32(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + + +DRWAV_API void drwav_u8_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = 
((int)pIn[i] - 128) << 24; + } +} + +DRWAV_API void drwav_s16_to_s32(drwav_int32* pOut, const drwav_int16* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = pIn[i] << 16; + } +} + +DRWAV_API void drwav_s24_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + unsigned int s0 = pIn[i*3 + 0]; + unsigned int s1 = pIn[i*3 + 1]; + unsigned int s2 = pIn[i*3 + 2]; + + drwav_int32 sample32 = (drwav_int32)((s0 << 8) | (s1 << 16) | (s2 << 24)); + *pOut++ = sample32; + } +} + +DRWAV_API void drwav_f32_to_s32(drwav_int32* pOut, const float* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = (drwav_int32)(2147483648.0 * pIn[i]); + } +} + +DRWAV_API void drwav_f64_to_s32(drwav_int32* pOut, const double* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = (drwav_int32)(2147483648.0 * pIn[i]); + } +} + +DRWAV_API void drwav_alaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = ((drwav_int32)drwav__alaw_to_s16(pIn[i])) << 16; + } +} + +DRWAV_API void drwav_mulaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i= 0; i < sampleCount; ++i) { + *pOut++ = ((drwav_int32)drwav__mulaw_to_s16(pIn[i])) << 16; + } +} + + + +static drwav_int16* drwav__read_pcm_frames_and_close_s16(drwav* pWav, unsigned int* channels, unsigned int* sampleRate, drwav_uint64* totalFrameCount) +{ + drwav_uint64 sampleDataSize; + drwav_int16* pSampleData; + drwav_uint64 
framesRead; + + DRWAV_ASSERT(pWav != NULL); + + sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(drwav_int16); + if (sampleDataSize > DRWAV_SIZE_MAX) { + drwav_uninit(pWav); + return NULL; /* File's too big. */ + } + + pSampleData = (drwav_int16*)drwav__malloc_from_callbacks((size_t)sampleDataSize, &pWav->allocationCallbacks); /* <-- Safe cast due to the check above. */ + if (pSampleData == NULL) { + drwav_uninit(pWav); + return NULL; /* Failed to allocate memory. */ + } + + framesRead = drwav_read_pcm_frames_s16(pWav, (size_t)pWav->totalPCMFrameCount, pSampleData); + if (framesRead != pWav->totalPCMFrameCount) { + drwav__free_from_callbacks(pSampleData, &pWav->allocationCallbacks); + drwav_uninit(pWav); + return NULL; /* There was an error reading the samples. */ + } + + drwav_uninit(pWav); + + if (sampleRate) { + *sampleRate = pWav->sampleRate; + } + if (channels) { + *channels = pWav->channels; + } + if (totalFrameCount) { + *totalFrameCount = pWav->totalPCMFrameCount; + } + + return pSampleData; +} + +static float* drwav__read_pcm_frames_and_close_f32(drwav* pWav, unsigned int* channels, unsigned int* sampleRate, drwav_uint64* totalFrameCount) +{ + drwav_uint64 sampleDataSize; + float* pSampleData; + drwav_uint64 framesRead; + + DRWAV_ASSERT(pWav != NULL); + + sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(float); + if (sampleDataSize > DRWAV_SIZE_MAX) { + drwav_uninit(pWav); + return NULL; /* File's too big. */ + } + + pSampleData = (float*)drwav__malloc_from_callbacks((size_t)sampleDataSize, &pWav->allocationCallbacks); /* <-- Safe cast due to the check above. */ + if (pSampleData == NULL) { + drwav_uninit(pWav); + return NULL; /* Failed to allocate memory. 
*/ + } + + framesRead = drwav_read_pcm_frames_f32(pWav, (size_t)pWav->totalPCMFrameCount, pSampleData); + if (framesRead != pWav->totalPCMFrameCount) { + drwav__free_from_callbacks(pSampleData, &pWav->allocationCallbacks); + drwav_uninit(pWav); + return NULL; /* There was an error reading the samples. */ + } + + drwav_uninit(pWav); + + if (sampleRate) { + *sampleRate = pWav->sampleRate; + } + if (channels) { + *channels = pWav->channels; + } + if (totalFrameCount) { + *totalFrameCount = pWav->totalPCMFrameCount; + } + + return pSampleData; +} + +static drwav_int32* drwav__read_pcm_frames_and_close_s32(drwav* pWav, unsigned int* channels, unsigned int* sampleRate, drwav_uint64* totalFrameCount) +{ + drwav_uint64 sampleDataSize; + drwav_int32* pSampleData; + drwav_uint64 framesRead; + + DRWAV_ASSERT(pWav != NULL); + + sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(drwav_int32); + if (sampleDataSize > DRWAV_SIZE_MAX) { + drwav_uninit(pWav); + return NULL; /* File's too big. */ + } + + pSampleData = (drwav_int32*)drwav__malloc_from_callbacks((size_t)sampleDataSize, &pWav->allocationCallbacks); /* <-- Safe cast due to the check above. */ + if (pSampleData == NULL) { + drwav_uninit(pWav); + return NULL; /* Failed to allocate memory. */ + } + + framesRead = drwav_read_pcm_frames_s32(pWav, (size_t)pWav->totalPCMFrameCount, pSampleData); + if (framesRead != pWav->totalPCMFrameCount) { + drwav__free_from_callbacks(pSampleData, &pWav->allocationCallbacks); + drwav_uninit(pWav); + return NULL; /* There was an error reading the samples. 
*/ + } + + drwav_uninit(pWav); + + if (sampleRate) { + *sampleRate = pWav->sampleRate; + } + if (channels) { + *channels = pWav->channels; + } + if (totalFrameCount) { + *totalFrameCount = pWav->totalPCMFrameCount; + } + + return pSampleData; +} + + + +DRWAV_API drwav_int16* drwav_open_and_read_pcm_frames_s16(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API float* drwav_open_and_read_pcm_frames_f32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API drwav_int32* drwav_open_and_read_pcm_frames_s32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if 
(!drwav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +#ifndef DR_WAV_NO_STDIO +DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); 
+} + + +DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (channelsOut) { + *channelsOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file_w(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (channelsOut) { + *channelsOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file_w(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (channelsOut) { + *channelsOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file_w(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} +#endif + +DRWAV_API drwav_int16* drwav_open_memory_and_read_pcm_frames_s16(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* 
totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_memory(&wav, data, dataSize, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API float* drwav_open_memory_and_read_pcm_frames_f32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_memory(&wav, data, dataSize, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API drwav_int32* drwav_open_memory_and_read_pcm_frames_s32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_memory(&wav, data, dataSize, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} +#endif /* DR_WAV_NO_CONVERSION_API */ + + +DRWAV_API void drwav_free(void* p, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pAllocationCallbacks != NULL) { + drwav__free_from_callbacks(p, pAllocationCallbacks); + } else { + drwav__free_default(p, NULL); + } +} + +DRWAV_API drwav_uint16 
drwav_bytes_to_u16(const drwav_uint8* data) +{ + return drwav__bytes_to_u16(data); +} + +DRWAV_API drwav_int16 drwav_bytes_to_s16(const drwav_uint8* data) +{ + return drwav__bytes_to_s16(data); +} + +DRWAV_API drwav_uint32 drwav_bytes_to_u32(const drwav_uint8* data) +{ + return drwav__bytes_to_u32(data); +} + +DRWAV_API drwav_int32 drwav_bytes_to_s32(const drwav_uint8* data) +{ + return drwav__bytes_to_s32(data); +} + +DRWAV_API drwav_uint64 drwav_bytes_to_u64(const drwav_uint8* data) +{ + return drwav__bytes_to_u64(data); +} + +DRWAV_API drwav_int64 drwav_bytes_to_s64(const drwav_uint8* data) +{ + return drwav__bytes_to_s64(data); +} + + +DRWAV_API drwav_bool32 drwav_guid_equal(const drwav_uint8 a[16], const drwav_uint8 b[16]) +{ + return drwav__guid_equal(a, b); +} + +DRWAV_API drwav_bool32 drwav_fourcc_equal(const drwav_uint8* a, const char* b) +{ + return drwav__fourcc_equal(a, b); +} + +#endif /* dr_wav_c */ +#endif /* DR_WAV_IMPLEMENTATION */ + +/* +RELEASE NOTES - v0.11.0 +======================= +Version 0.11.0 has breaking API changes. + +Improved Client-Defined Memory Allocation +----------------------------------------- +The main change with this release is the addition of a more flexible way of implementing custom memory allocation routines. The +existing system of DRWAV_MALLOC, DRWAV_REALLOC and DRWAV_FREE are still in place and will be used by default when no custom +allocation callbacks are specified. + +To use the new system, you pass in a pointer to a drwav_allocation_callbacks object to drwav_init() and family, like this: + + void* my_malloc(size_t sz, void* pUserData) + { + return malloc(sz); + } + void* my_realloc(void* p, size_t sz, void* pUserData) + { + return realloc(p, sz); + } + void my_free(void* p, void* pUserData) + { + free(p); + } + + ... 
+ + drwav_allocation_callbacks allocationCallbacks; + allocationCallbacks.pUserData = &myData; + allocationCallbacks.onMalloc = my_malloc; + allocationCallbacks.onRealloc = my_realloc; + allocationCallbacks.onFree = my_free; + drwav_init_file(&wav, "my_file.wav", &allocationCallbacks); + +The advantage of this new system is that it allows you to specify user data which will be passed in to the allocation routines. + +Passing in null for the allocation callbacks object will cause dr_wav to use defaults which is the same as DRWAV_MALLOC, +DRWAV_REALLOC and DRWAV_FREE and the equivalent of how it worked in previous versions. + +Every API that opens a drwav object now takes this extra parameter. These include the following: + + drwav_init() + drwav_init_ex() + drwav_init_file() + drwav_init_file_ex() + drwav_init_file_w() + drwav_init_file_w_ex() + drwav_init_memory() + drwav_init_memory_ex() + drwav_init_write() + drwav_init_write_sequential() + drwav_init_write_sequential_pcm_frames() + drwav_init_file_write() + drwav_init_file_write_sequential() + drwav_init_file_write_sequential_pcm_frames() + drwav_init_file_write_w() + drwav_init_file_write_sequential_w() + drwav_init_file_write_sequential_pcm_frames_w() + drwav_init_memory_write() + drwav_init_memory_write_sequential() + drwav_init_memory_write_sequential_pcm_frames() + drwav_open_and_read_pcm_frames_s16() + drwav_open_and_read_pcm_frames_f32() + drwav_open_and_read_pcm_frames_s32() + drwav_open_file_and_read_pcm_frames_s16() + drwav_open_file_and_read_pcm_frames_f32() + drwav_open_file_and_read_pcm_frames_s32() + drwav_open_file_and_read_pcm_frames_s16_w() + drwav_open_file_and_read_pcm_frames_f32_w() + drwav_open_file_and_read_pcm_frames_s32_w() + drwav_open_memory_and_read_pcm_frames_s16() + drwav_open_memory_and_read_pcm_frames_f32() + drwav_open_memory_and_read_pcm_frames_s32() + +Endian Improvements +------------------- +Previously, the following APIs returned little-endian audio data. 
These now return native-endian data. This improves compatibility +on big-endian architectures. + + drwav_read_pcm_frames() + drwav_read_pcm_frames_s16() + drwav_read_pcm_frames_s32() + drwav_read_pcm_frames_f32() + drwav_open_and_read_pcm_frames_s16() + drwav_open_and_read_pcm_frames_s32() + drwav_open_and_read_pcm_frames_f32() + drwav_open_file_and_read_pcm_frames_s16() + drwav_open_file_and_read_pcm_frames_s32() + drwav_open_file_and_read_pcm_frames_f32() + drwav_open_file_and_read_pcm_frames_s16_w() + drwav_open_file_and_read_pcm_frames_s32_w() + drwav_open_file_and_read_pcm_frames_f32_w() + drwav_open_memory_and_read_pcm_frames_s16() + drwav_open_memory_and_read_pcm_frames_s32() + drwav_open_memory_and_read_pcm_frames_f32() + +APIs have been added to give you explicit control over whether or not audio data is read or written in big- or little-endian byte +order: + + drwav_read_pcm_frames_le() + drwav_read_pcm_frames_be() + drwav_read_pcm_frames_s16le() + drwav_read_pcm_frames_s16be() + drwav_read_pcm_frames_f32le() + drwav_read_pcm_frames_f32be() + drwav_read_pcm_frames_s32le() + drwav_read_pcm_frames_s32be() + drwav_write_pcm_frames_le() + drwav_write_pcm_frames_be() + +Removed APIs +------------ +The following APIs were deprecated in version 0.10.0 and have now been removed: + + drwav_open() + drwav_open_ex() + drwav_open_write() + drwav_open_write_sequential() + drwav_open_file() + drwav_open_file_ex() + drwav_open_file_write() + drwav_open_file_write_sequential() + drwav_open_memory() + drwav_open_memory_ex() + drwav_open_memory_write() + drwav_open_memory_write_sequential() + drwav_close() + + + +RELEASE NOTES - v0.10.0 +======================= +Version 0.10.0 has breaking API changes. There are no significant bug fixes in this release, so if you are affected you do +not need to upgrade. 
+ +Removed APIs +------------ +The following APIs were deprecated in version 0.9.0 and have been completely removed in version 0.10.0: + + drwav_read() + drwav_read_s16() + drwav_read_f32() + drwav_read_s32() + drwav_seek_to_sample() + drwav_write() + drwav_open_and_read_s16() + drwav_open_and_read_f32() + drwav_open_and_read_s32() + drwav_open_file_and_read_s16() + drwav_open_file_and_read_f32() + drwav_open_file_and_read_s32() + drwav_open_memory_and_read_s16() + drwav_open_memory_and_read_f32() + drwav_open_memory_and_read_s32() + drwav::totalSampleCount + +See release notes for version 0.9.0 at the bottom of this file for replacement APIs. + +Deprecated APIs +--------------- +The following APIs have been deprecated. There is a confusing and completely arbitrary difference between drwav_init*() and +drwav_open*(), where drwav_init*() initializes a pre-allocated drwav object, whereas drwav_open*() will first allocated a +drwav object on the heap and then initialize it. drwav_open*() has been deprecated which means you must now use a pre- +allocated drwav object with drwav_init*(). If you need the previous functionality, you can just do a malloc() followed by +a called to one of the drwav_init*() APIs. + + drwav_open() + drwav_open_ex() + drwav_open_write() + drwav_open_write_sequential() + drwav_open_file() + drwav_open_file_ex() + drwav_open_file_write() + drwav_open_file_write_sequential() + drwav_open_memory() + drwav_open_memory_ex() + drwav_open_memory_write() + drwav_open_memory_write_sequential() + drwav_close() + +These APIs will be removed completely in a future version. The rationale for this change is to remove confusion between the +two different ways to initialize a drwav object. +*/ + +/* +REVISION HISTORY +================ +v0.12.16 - 2020-12-02 + - Fix a bug when trying to read more bytes than can fit in a size_t. + +v0.12.15 - 2020-11-21 + - Fix compilation with OpenWatcom. + +v0.12.14 - 2020-11-13 + - Minor code clean up. 
+ +v0.12.13 - 2020-11-01 + - Improve compiler support for older versions of GCC. + +v0.12.12 - 2020-09-28 + - Add support for RF64. + - Fix a bug in writing mode where the size of the RIFF chunk incorrectly includes the header section. + +v0.12.11 - 2020-09-08 + - Fix a compilation error on older compilers. + +v0.12.10 - 2020-08-24 + - Fix a bug when seeking with ADPCM formats. + +v0.12.9 - 2020-08-02 + - Simplify sized types. + +v0.12.8 - 2020-07-25 + - Fix a compilation warning. + +v0.12.7 - 2020-07-15 + - Fix some bugs on big-endian architectures. + - Fix an error in s24 to f32 conversion. + +v0.12.6 - 2020-06-23 + - Change drwav_read_*() to allow NULL to be passed in as the output buffer which is equivalent to a forward seek. + - Fix a buffer overflow when trying to decode invalid IMA-ADPCM files. + - Add include guard for the implementation section. + +v0.12.5 - 2020-05-27 + - Minor documentation fix. + +v0.12.4 - 2020-05-16 + - Replace assert() with DRWAV_ASSERT(). + - Add compile-time and run-time version querying. + - DRWAV_VERSION_MINOR + - DRWAV_VERSION_MAJOR + - DRWAV_VERSION_REVISION + - DRWAV_VERSION_STRING + - drwav_version() + - drwav_version_string() + +v0.12.3 - 2020-04-30 + - Fix compilation errors with VC6. + +v0.12.2 - 2020-04-21 + - Fix a bug where drwav_init_file() does not close the file handle after attempting to load an erroneous file. + +v0.12.1 - 2020-04-13 + - Fix some pedantic warnings. + +v0.12.0 - 2020-04-04 + - API CHANGE: Add container and format parameters to the chunk callback. + - Minor documentation updates. + +v0.11.5 - 2020-03-07 + - Fix compilation error with Visual Studio .NET 2003. + +v0.11.4 - 2020-01-29 + - Fix some static analysis warnings. + - Fix a bug when reading f32 samples from an A-law encoded stream. + +v0.11.3 - 2020-01-12 + - Minor changes to some f32 format conversion routines. + - Minor bug fix for ADPCM conversion when end of file is reached. 
+ +v0.11.2 - 2019-12-02 + - Fix a possible crash when using custom memory allocators without a custom realloc() implementation. + - Fix an integer overflow bug. + - Fix a null pointer dereference bug. + - Add limits to sample rate, channels and bits per sample to tighten up some validation. + +v0.11.1 - 2019-10-07 + - Internal code clean up. + +v0.11.0 - 2019-10-06 + - API CHANGE: Add support for user defined memory allocation routines. This system allows the program to specify their own memory allocation + routines with a user data pointer for client-specific contextual data. This adds an extra parameter to the end of the following APIs: + - drwav_init() + - drwav_init_ex() + - drwav_init_file() + - drwav_init_file_ex() + - drwav_init_file_w() + - drwav_init_file_w_ex() + - drwav_init_memory() + - drwav_init_memory_ex() + - drwav_init_write() + - drwav_init_write_sequential() + - drwav_init_write_sequential_pcm_frames() + - drwav_init_file_write() + - drwav_init_file_write_sequential() + - drwav_init_file_write_sequential_pcm_frames() + - drwav_init_file_write_w() + - drwav_init_file_write_sequential_w() + - drwav_init_file_write_sequential_pcm_frames_w() + - drwav_init_memory_write() + - drwav_init_memory_write_sequential() + - drwav_init_memory_write_sequential_pcm_frames() + - drwav_open_and_read_pcm_frames_s16() + - drwav_open_and_read_pcm_frames_f32() + - drwav_open_and_read_pcm_frames_s32() + - drwav_open_file_and_read_pcm_frames_s16() + - drwav_open_file_and_read_pcm_frames_f32() + - drwav_open_file_and_read_pcm_frames_s32() + - drwav_open_file_and_read_pcm_frames_s16_w() + - drwav_open_file_and_read_pcm_frames_f32_w() + - drwav_open_file_and_read_pcm_frames_s32_w() + - drwav_open_memory_and_read_pcm_frames_s16() + - drwav_open_memory_and_read_pcm_frames_f32() + - drwav_open_memory_and_read_pcm_frames_s32() + Set this extra parameter to NULL to use defaults which is the same as the previous behaviour. 
Setting this NULL will use + DRWAV_MALLOC, DRWAV_REALLOC and DRWAV_FREE. + - Add support for reading and writing PCM frames in an explicit endianness. New APIs: + - drwav_read_pcm_frames_le() + - drwav_read_pcm_frames_be() + - drwav_read_pcm_frames_s16le() + - drwav_read_pcm_frames_s16be() + - drwav_read_pcm_frames_f32le() + - drwav_read_pcm_frames_f32be() + - drwav_read_pcm_frames_s32le() + - drwav_read_pcm_frames_s32be() + - drwav_write_pcm_frames_le() + - drwav_write_pcm_frames_be() + - Remove deprecated APIs. + - API CHANGE: The following APIs now return native-endian data. Previously they returned little-endian data. + - drwav_read_pcm_frames() + - drwav_read_pcm_frames_s16() + - drwav_read_pcm_frames_s32() + - drwav_read_pcm_frames_f32() + - drwav_open_and_read_pcm_frames_s16() + - drwav_open_and_read_pcm_frames_s32() + - drwav_open_and_read_pcm_frames_f32() + - drwav_open_file_and_read_pcm_frames_s16() + - drwav_open_file_and_read_pcm_frames_s32() + - drwav_open_file_and_read_pcm_frames_f32() + - drwav_open_file_and_read_pcm_frames_s16_w() + - drwav_open_file_and_read_pcm_frames_s32_w() + - drwav_open_file_and_read_pcm_frames_f32_w() + - drwav_open_memory_and_read_pcm_frames_s16() + - drwav_open_memory_and_read_pcm_frames_s32() + - drwav_open_memory_and_read_pcm_frames_f32() + +v0.10.1 - 2019-08-31 + - Correctly handle partial trailing ADPCM blocks. + +v0.10.0 - 2019-08-04 + - Remove deprecated APIs. + - Add wchar_t variants for file loading APIs: + drwav_init_file_w() + drwav_init_file_ex_w() + drwav_init_file_write_w() + drwav_init_file_write_sequential_w() + - Add drwav_target_write_size_bytes() which calculates the total size in bytes of a WAV file given a format and sample count. 
+ - Add APIs for specifying the PCM frame count instead of the sample count when opening in sequential write mode: + drwav_init_write_sequential_pcm_frames() + drwav_init_file_write_sequential_pcm_frames() + drwav_init_file_write_sequential_pcm_frames_w() + drwav_init_memory_write_sequential_pcm_frames() + - Deprecate drwav_open*() and drwav_close(): + drwav_open() + drwav_open_ex() + drwav_open_write() + drwav_open_write_sequential() + drwav_open_file() + drwav_open_file_ex() + drwav_open_file_write() + drwav_open_file_write_sequential() + drwav_open_memory() + drwav_open_memory_ex() + drwav_open_memory_write() + drwav_open_memory_write_sequential() + drwav_close() + - Minor documentation updates. + +v0.9.2 - 2019-05-21 + - Fix warnings. + +v0.9.1 - 2019-05-05 + - Add support for C89. + - Change license to choice of public domain or MIT-0. + +v0.9.0 - 2018-12-16 + - API CHANGE: Add new reading APIs for reading by PCM frames instead of samples. Old APIs have been deprecated and + will be removed in v0.10.0. 
Deprecated APIs and their replacements: + drwav_read() -> drwav_read_pcm_frames() + drwav_read_s16() -> drwav_read_pcm_frames_s16() + drwav_read_f32() -> drwav_read_pcm_frames_f32() + drwav_read_s32() -> drwav_read_pcm_frames_s32() + drwav_seek_to_sample() -> drwav_seek_to_pcm_frame() + drwav_write() -> drwav_write_pcm_frames() + drwav_open_and_read_s16() -> drwav_open_and_read_pcm_frames_s16() + drwav_open_and_read_f32() -> drwav_open_and_read_pcm_frames_f32() + drwav_open_and_read_s32() -> drwav_open_and_read_pcm_frames_s32() + drwav_open_file_and_read_s16() -> drwav_open_file_and_read_pcm_frames_s16() + drwav_open_file_and_read_f32() -> drwav_open_file_and_read_pcm_frames_f32() + drwav_open_file_and_read_s32() -> drwav_open_file_and_read_pcm_frames_s32() + drwav_open_memory_and_read_s16() -> drwav_open_memory_and_read_pcm_frames_s16() + drwav_open_memory_and_read_f32() -> drwav_open_memory_and_read_pcm_frames_f32() + drwav_open_memory_and_read_s32() -> drwav_open_memory_and_read_pcm_frames_s32() + drwav::totalSampleCount -> drwav::totalPCMFrameCount + - API CHANGE: Rename drwav_open_and_read_file_*() to drwav_open_file_and_read_*(). + - API CHANGE: Rename drwav_open_and_read_memory_*() to drwav_open_memory_and_read_*(). + - Add built-in support for smpl chunks. + - Add support for firing a callback for each chunk in the file at initialization time. + - This is enabled through the drwav_init_ex(), etc. family of APIs. + - Handle invalid FMT chunks more robustly. + +v0.8.5 - 2018-09-11 + - Const correctness. + - Fix a potential stack overflow. + +v0.8.4 - 2018-08-07 + - Improve 64-bit detection. + +v0.8.3 - 2018-08-05 + - Fix C++ build on older versions of GCC. + +v0.8.2 - 2018-08-02 + - Fix some big-endian bugs. + +v0.8.1 - 2018-06-29 + - Add support for sequential writing APIs. + - Disable seeking in write mode. + - Fix bugs with Wave64. + - Fix typos. + +v0.8 - 2018-04-27 + - Bug fix. + - Start using major.minor.revision versioning. 
+ +v0.7f - 2018-02-05 + - Restrict ADPCM formats to a maximum of 2 channels. + +v0.7e - 2018-02-02 + - Fix a crash. + +v0.7d - 2018-02-01 + - Fix a crash. + +v0.7c - 2018-02-01 + - Set drwav.bytesPerSample to 0 for all compressed formats. + - Fix a crash when reading 16-bit floating point WAV files. In this case dr_wav will output silence for + all format conversion reading APIs (*_s16, *_s32, *_f32 APIs). + - Fix some divide-by-zero errors. + +v0.7b - 2018-01-22 + - Fix errors with seeking of compressed formats. + - Fix compilation error when DR_WAV_NO_CONVERSION_API + +v0.7a - 2017-11-17 + - Fix some GCC warnings. + +v0.7 - 2017-11-04 + - Add writing APIs. + +v0.6 - 2017-08-16 + - API CHANGE: Rename dr_* types to drwav_*. + - Add support for custom implementations of malloc(), realloc(), etc. + - Add support for Microsoft ADPCM. + - Add support for IMA ADPCM (DVI, format code 0x11). + - Optimizations to drwav_read_s16(). + - Bug fixes. + +v0.5g - 2017-07-16 + - Change underlying type for booleans to unsigned. + +v0.5f - 2017-04-04 + - Fix a minor bug with drwav_open_and_read_s16() and family. + +v0.5e - 2016-12-29 + - Added support for reading samples as signed 16-bit integers. Use the _s16() family of APIs for this. + - Minor fixes to documentation. + +v0.5d - 2016-12-28 + - Use drwav_int* and drwav_uint* sized types to improve compiler support. + +v0.5c - 2016-11-11 + - Properly handle JUNK chunks that come before the FMT chunk. + +v0.5b - 2016-10-23 + - A minor change to drwav_bool8 and drwav_bool32 types. + +v0.5a - 2016-10-11 + - Fixed a bug with drwav_open_and_read() and family due to incorrect argument ordering. + - Improve A-law and mu-law efficiency. + +v0.5 - 2016-09-29 + - API CHANGE. Swap the order of "channels" and "sampleRate" parameters in drwav_open_and_read*(). Rationale for this is to + keep it consistent with dr_audio and dr_flac. + +v0.4b - 2016-09-18 + - Fixed a typo in documentation. + +v0.4a - 2016-09-18 + - Fixed a typo. 
+ - Change date format to ISO 8601 (YYYY-MM-DD) + +v0.4 - 2016-07-13 + - API CHANGE. Make onSeek consistent with dr_flac. + - API CHANGE. Rename drwav_seek() to drwav_seek_to_sample() for clarity and consistency with dr_flac. + - Added support for Sony Wave64. + +v0.3a - 2016-05-28 + - API CHANGE. Return drwav_bool32 instead of int in onSeek callback. + - Fixed a memory leak. + +v0.3 - 2016-05-22 + - Lots of API changes for consistency. + +v0.2a - 2016-05-16 + - Fixed Linux/GCC build. + +v0.2 - 2016-05-11 + - Added support for reading data as signed 32-bit PCM for consistency with dr_flac. + +v0.1a - 2016-05-07 + - Fixed a bug in drwav_open_file() where the file handle would not be closed if the loader failed to initialize. + +v0.1 - 2016-05-04 + - Initial versioned release. +*/ + +/* +This software is available as a choice of the following licenses. Choose +whichever you prefer. + +=============================================================================== +ALTERNATIVE 1 - Public Domain (www.unlicense.org) +=============================================================================== +This is free and unencumbered software released into the public domain. + +Anyone is free to copy, modify, publish, use, compile, sell, or distribute this +software, either in source code form or as a compiled binary, for any purpose, +commercial or non-commercial, and by any means. + +In jurisdictions that recognize copyright laws, the author or authors of this +software dedicate any and all copyright interest in the software to the public +domain. We make this dedication for the benefit of the public at large and to +the detriment of our heirs and successors. We intend this dedication to be an +overt act of relinquishment in perpetuity of all present and future rights to +this software under copyright law. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +For more information, please refer to + +=============================================================================== +ALTERNATIVE 2 - MIT No Attribution +=============================================================================== +Copyright 2020 David Reid + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+*/ From 8a0a2910af98081e38465ca4a152dcecc00945a1 Mon Sep 17 00:00:00 2001 From: hiro-v Date: Fri, 26 Jan 2024 00:56:13 +0700 Subject: [PATCH 15/31] Movable mutex --- controllers/whisperCPP.cc | 7 +++---- controllers/whisperCPP.h | 33 ++++++++++++++++++++++++++++----- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/controllers/whisperCPP.cc b/controllers/whisperCPP.cc index 7a7d0de8a..93b1f792c 100644 --- a/controllers/whisperCPP.cc +++ b/controllers/whisperCPP.cc @@ -838,7 +838,8 @@ void whisperCPP::load_model( return; } - whisper_server_context whisper; + + whisper_server_context whisper = whisper_server_context(model_id); bool model_loaded = whisper.load_model(model_path); // If model failed to load, return a 500 error if (!model_loaded) @@ -856,9 +857,7 @@ void whisperCPP::load_model( // Model loaded successfully, add it to the map of loaded models // and return a 200 response - // whispers.emplace(model_id, std::move(whisper)); - // whispers[model_id] = std::move(whisper); - whispers[model_id] = whisper; + whispers.emplace(model_id, std::move(whisper)); Json::Value jsonResp; std::string success_msg = "Model " + model_id + " loaded successfully"; jsonResp["message"] = success_msg; diff --git a/controllers/whisperCPP.h b/controllers/whisperCPP.h index 3eaeba30a..929540aea 100644 --- a/controllers/whisperCPP.h +++ b/controllers/whisperCPP.h @@ -132,14 +132,37 @@ bool parse_str_to_bool(const std::string &s); struct whisper_server_context { whisper_params params; - // store default params so we can reset after each inference request - whisper_params default_params = params; + whisper_params default_params; std::mutex whisper_mutex; std::string model_id; struct whisper_context_params cparams; struct whisper_context *ctx = nullptr; + whisper_server_context() = default; // add this line + + // Constructor + whisper_server_context(const std::string &model_id) + { + this->model_id = model_id; + this->cparams = whisper_context_params(); + this->ctx = 
nullptr; + // store default params so we can reset after each inference request + this->default_params = whisper_params(); + this->params = whisper_params(); + } + + // Move constructor + whisper_server_context(whisper_server_context&& other) noexcept + : params(std::move(other.params)) + , default_params(std::move(other.default_params)) + , whisper_mutex() // std::mutex is not movable, so we initialize a new one + , model_id(std::move(other.model_id)) + , cparams(std::move(other.cparams)) + , ctx(std::exchange(other.ctx, nullptr)) // ctx is a raw pointer, so we use std::exchange + { + } + bool load_model(std::string &model_path); std::string inference(std::string &input_file_path, std::string languague, std::string prompt, @@ -155,9 +178,9 @@ class whisperCPP : public drogon::HttpController public: METHOD_LIST_BEGIN - METHOD_ADD(whisperCPP::load_model, "load_model", Post); - METHOD_ADD(whisperCPP::unload_model, "unload_model", Post); - METHOD_ADD(whisperCPP::model_status, "model_status", Get); + ADD_METHOD_TO(whisperCPP::load_model, "/v1/audio/load_model", Post); + ADD_METHOD_TO(whisperCPP::unload_model, "/v1/audio/unload_model", Post); + ADD_METHOD_TO(whisperCPP::model_status, "/v1/audio/model_status", Get); ADD_METHOD_TO(whisperCPP::transcription, "/v1/audio/transcriptions", Post); ADD_METHOD_TO(whisperCPP::translation, "/v1/audio/translations", Post); From 34d36a10623001fdae822f262ec9c2c7e6677564 Mon Sep 17 00:00:00 2001 From: hiro-v Date: Fri, 26 Jan 2024 01:14:26 +0700 Subject: [PATCH 16/31] json happy case working --- controllers/whisperCPP.cc | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/controllers/whisperCPP.cc b/controllers/whisperCPP.cc index 93b1f792c..3e8053133 100644 --- a/controllers/whisperCPP.cc +++ b/controllers/whisperCPP.cc @@ -768,6 +768,7 @@ std::string whisper_server_context::inference(std::string &input_file_path, std: // return whisper model mutex lock whisper_mutex.unlock(); + LOG_INFO << 
"Successfully processed " << input_file_path << ": " << result; return result; } @@ -813,7 +814,7 @@ void whisperCPP::load_model( // Check if model is already loaded if (whispers.find(model_id) != whispers.end()) { - std::string error_msg = "Model " + model_id + "has not been loaded, please load that model into nitro"; + std::string error_msg = "Model " + model_id + " already loaded"; LOG_INFO << error_msg; Json::Value jsonResp; jsonResp["message"] = error_msg; @@ -977,9 +978,12 @@ void whisperCPP::transcription_impl( } // Save input file to temp location - std::string temp_file_path = std::filesystem::temp_directory_path().string() + "/" + std::to_string(std::chrono::system_clock::now().time_since_epoch().count()) + ".wav"; - file.save(temp_file_path); - + std::string temp_dir = std::filesystem::temp_directory_path().string() + "/" + std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + // Create the directory + std::filesystem::create_directory(temp_dir); + // Save the file to the directory, with its original name + std::string temp_file_path = temp_dir + "/" + file.getFileName(); + file.saveAs(temp_file_path); // Run inference std::string result; @@ -994,8 +998,10 @@ void whisperCPP::transcription_impl( callback(resp); return; } + // TODO: Need to remove the entire temp directory, not just the file + std::remove(temp_file_path.c_str()); - auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + auto resp = nitro_utils::nitroHttpResponse(); resp->setBody(result); resp->setStatusCode(k200OK); // Set content type based on response format @@ -1015,7 +1021,6 @@ void whisperCPP::transcription_impl( { resp->addHeader("Content-Type", "text/vtt"); } - std::remove(temp_file_path.c_str()); callback(resp); return; } From 56a509b1ee0205b181a3cf3e066a6d490f5fdd8b Mon Sep 17 00:00:00 2001 From: hiro-v Date: Fri, 26 Jan 2024 22:08:50 +0700 Subject: [PATCH 17/31] All output format working --- controllers/whisperCPP.cc | 47 
++++++++++++++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/controllers/whisperCPP.cc b/controllers/whisperCPP.cc index 3e8053133..a5b6967f7 100644 --- a/controllers/whisperCPP.cc +++ b/controllers/whisperCPP.cc @@ -590,6 +590,7 @@ std::string whisper_server_context::inference(std::string &input_file_path, std: params.translate = translate; params.language = language; + params.response_format = response_format; if (!whisper_is_multilingual(ctx)) { if (params.language != "en" || params.translate) @@ -702,11 +703,11 @@ std::string whisper_server_context::inference(std::string &input_file_path, std: // return results to user std::string result; - if (response_format == text_format) + if (params.response_format == text_format) { result = output_str(ctx, params, pcmf32s); } - else if (response_format == srt_format) + else if (params.response_format == srt_format) { std::stringstream ss; const int n_segments = whisper_full_n_segments(ctx); @@ -754,7 +755,43 @@ std::string whisper_server_context::inference(std::string &input_file_path, std: } result = ss.str(); } - // TODO add more output formats + else if (params.response_format == vjson_format) { + /* try to match openai/whisper's Python format */ + std::string results = output_str(ctx, params, pcmf32s); + json jres = json{{"text", results}}; + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) + { + json segment = json{ + {"id", i}, + {"text", whisper_full_get_segment_text(ctx, i)}, + }; + + if (!params.no_timestamps) { + segment["start"] = whisper_full_get_segment_t0(ctx, i) * 0.01; + segment["end"] = whisper_full_get_segment_t1(ctx, i) * 0.01; + } + + const int n_tokens = whisper_full_n_tokens(ctx, i); + for (int j = 0; j < n_tokens; ++j) { + whisper_token_data token = whisper_full_get_token_data(ctx, i, j); + if (token.id >= whisper_token_eot(ctx)) { + continue; + } + + segment["tokens"].push_back(token.id); + json word = 
json{{"word", whisper_full_get_token_text(ctx, i, j)}}; + if (!params.no_timestamps) { + word["start"] = token.t0 * 0.01; + word["end"] = token.t1 * 0.01; + } + word["probability"] = token.p; + segment["words"].push_back(word); + } + jres["segments"].push_back(segment); + } + result = jres.dump(-1, ' ', false, json::error_handler_t::replace); + } else { std::string results = output_str(ctx, params, pcmf32s); @@ -908,7 +945,7 @@ void whisperCPP::unload_model( return; } -void whisperCPP::model_status( +void whisperCPP::list_model( const HttpRequestPtr &req, std::function &&callback) { @@ -1005,7 +1042,7 @@ void whisperCPP::transcription_impl( resp->setBody(result); resp->setStatusCode(k200OK); // Set content type based on response format - if (response_format == json_format) + if (response_format == json_format || response_format == vjson_format) { resp->addHeader("Content-Type", "application/json"); } From cf4c0342d648cde2032807b74012dd3c9284de58 Mon Sep 17 00:00:00 2001 From: hiro-v Date: Fri, 26 Jan 2024 22:09:04 +0700 Subject: [PATCH 18/31] Rename model_status to list_model --- controllers/whisperCPP.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/controllers/whisperCPP.h b/controllers/whisperCPP.h index 929540aea..a786c2df2 100644 --- a/controllers/whisperCPP.h +++ b/controllers/whisperCPP.h @@ -180,7 +180,7 @@ class whisperCPP : public drogon::HttpController ADD_METHOD_TO(whisperCPP::load_model, "/v1/audio/load_model", Post); ADD_METHOD_TO(whisperCPP::unload_model, "/v1/audio/unload_model", Post); - ADD_METHOD_TO(whisperCPP::model_status, "/v1/audio/model_status", Get); + ADD_METHOD_TO(whisperCPP::list_model, "/v1/audio/list_model", Get); ADD_METHOD_TO(whisperCPP::transcription, "/v1/audio/transcriptions", Post); ADD_METHOD_TO(whisperCPP::translation, "/v1/audio/translations", Post); @@ -197,8 +197,8 @@ class whisperCPP : public drogon::HttpController void unload_model(const HttpRequestPtr &req, std::function &&callback); - void 
model_status(const HttpRequestPtr &req, - std::function &&callback); + void list_model(const HttpRequestPtr &req, + std::function &&callback); void transcription(const HttpRequestPtr &req, std::function &&callback); From 43ffb9ddd8f8a805577b780bb94f474e8c7c6e3a Mon Sep 17 00:00:00 2001 From: hiro Date: Sat, 27 Jan 2024 00:11:39 +0700 Subject: [PATCH 19/31] chore: space inline --- controllers/llamaCPP.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc index 806f26000..3d703e491 100644 --- a/controllers/llamaCPP.cc +++ b/controllers/llamaCPP.cc @@ -203,6 +203,7 @@ void llamaCPP::chatCompletion( role = system_prompt; std::string content = message["content"].asString(); formatted_output = role + content + formatted_output; + } else { role = input_role; std::string content = message["content"].asString(); @@ -253,6 +254,7 @@ void llamaCPP::chatCompletion( no_images++; } } + } else if (input_role == "assistant") { role = ai_prompt; std::string content = message["content"].asString(); @@ -261,6 +263,7 @@ void llamaCPP::chatCompletion( role = system_prompt; std::string content = message["content"].asString(); formatted_output = role + content + formatted_output; + } else { role = input_role; std::string content = message["content"].asString(); @@ -609,4 +612,4 @@ void llamaCPP::stopBackgroundTask() { backgroundThread.join(); } } -} +} \ No newline at end of file From 167cddcb43a54c3a1e85fb3db8af259704e63b28 Mon Sep 17 00:00:00 2001 From: hiro-v Date: Sat, 27 Jan 2024 00:32:59 +0700 Subject: [PATCH 20/31] Init temp audio doc --- audio.md | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 audio.md diff --git a/audio.md b/audio.md new file mode 100644 index 000000000..a44d51643 --- /dev/null +++ b/audio.md @@ -0,0 +1,43 @@ +## Whisper.cpp build instruction + +### For NVIDIA GPU on Linux +- CUDA Toolkit 10.2, with nvcc in PATH +```bash +mkdir build && cd 
build +cmake -DLLAMA_CUBLAS=ON -DWHISPER_CUBLAS=ON .. +make -j$(nproc) +``` + +### For x86 CPU on Linux +```bash +mkdir build && cd build +cmake .. +make -j$(nproc) +``` + +## Sample test command +- Download `ggml-base.en.bin` with [whisper.cpp/models/download-ggml-model.sh](whisper.cpp/models/download-ggml-model.sh) +- Load model +```bash +curl 127.0.0.1:3928/v1/audio/load_model \ +-X POST -H "Content-Type: application/json" \ +-d '{"model_id":"ggml-base.en","model_path":"/abs/path/to/whisper.cpp/models/ggml-base.en.bin"}' +``` +- List model: +```bash +curl 127.0.0.1:3928/v1/audio/list_model +``` +- Download sample audio file from: +```bash +wget https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav +``` +- Sample transcription: +```bash +curl -X POST 127.0.0.1:3928/v1/audio/transcriptions \ +-H "Content-Type: multipart/form-data" \ +-F file="@/abs/path/to/jfk.wav" \ +-F model_id="ggml-base.en" \ +-F temperature="0.0" \ +-F response_format="verbose_json" # \ +# -F prompt="The transcript is about OpenAI which makes technology like DALL·E, GPT-3, and ChatGPT with the hope of one day building an AGI system that benefits all of humanity. The president is trying to raly people to support the cause." +``` \ No newline at end of file From d39a939df385b83a1d30bf5cdcd0fb36480c342c Mon Sep 17 00:00:00 2001 From: hiro Date: Sat, 27 Jan 2024 10:21:28 +0700 Subject: [PATCH 21/31] fix: Update mac silicon build note --- audio.md | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/audio.md b/audio.md index a44d51643..42a4cc7ec 100644 --- a/audio.md +++ b/audio.md @@ -1,37 +1,56 @@ ## Whisper.cpp build instruction ### For NVIDIA GPU on Linux + - CUDA Toolkit 10.2, with nvcc in PATH + ```bash mkdir build && cd build -cmake -DLLAMA_CUBLAS=ON -DWHISPER_CUBLAS=ON .. +cmake -DLLAMA_CUBLAS=ON -DWHISPER_CUBLAS=ON .. make -j$(nproc) ``` ### For x86 CPU on Linux + ```bash mkdir build && cd build cmake .. 
make -j$(nproc) ``` +### For Mac Silicon with CoreML support + +``` +# Download model in `.bin` and `.mlmodelc` in order to use on Mac Silicon +cmake -B build -DWHISPER_COREML=1 +cmake --build build -j --config Release +``` + ## Sample test command + - Download `ggml-base.en.bin` with [whisper.cpp/models/download-ggml-model.sh](whisper.cpp/models/download-ggml-model.sh) - Load model + ```bash curl 127.0.0.1:3928/v1/audio/load_model \ -X POST -H "Content-Type: application/json" \ -d '{"model_id":"ggml-base.en","model_path":"/abs/path/to/whisper.cpp/models/ggml-base.en.bin"}' ``` + - List model: + ```bash curl 127.0.0.1:3928/v1/audio/list_model ``` + - Download sample audio file from: + ```bash wget https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav ``` + - Sample transcription: + ```bash curl -X POST 127.0.0.1:3928/v1/audio/transcriptions \ -H "Content-Type: multipart/form-data" \ @@ -40,4 +59,4 @@ curl -X POST 127.0.0.1:3928/v1/audio/transcriptions \ -F temperature="0.0" \ -F response_format="verbose_json" # \ # -F prompt="The transcript is about OpenAI which makes technology like DALL·E, GPT-3, and ChatGPT with the hope of one day building an AGI system that benefits all of humanity. The president is trying to raly people to support the cause." 
-``` \ No newline at end of file +``` From 2f94bb2010597cd1be5cfd8de2a85ebe14eff7fa Mon Sep 17 00:00:00 2001 From: hiro Date: Sat, 27 Jan 2024 10:24:18 +0700 Subject: [PATCH 22/31] fix: add note for using mac silicon coreml --- audio.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/audio.md b/audio.md index 42a4cc7ec..e4e90e629 100644 --- a/audio.md +++ b/audio.md @@ -34,7 +34,10 @@ cmake --build build -j --config Release ```bash curl 127.0.0.1:3928/v1/audio/load_model \ -X POST -H "Content-Type: application/json" \ --d '{"model_id":"ggml-base.en","model_path":"/abs/path/to/whisper.cpp/models/ggml-base.en.bin"}' +-d '{"model_id":"ggml-base.en.bin","model_path":"/abs/path/to/whisper.cpp/models/ggml-base.en.bin"}' + + +# If we enable CoreML on Mac silicon, we need to include `ggml-base.mlmodelc` file in the same folder as `ggml-base.en.bin` ``` - List model: From 8340e57d31935c45ed10cf9eef5a6c2a8dd077ee Mon Sep 17 00:00:00 2001 From: hiro Date: Sat, 27 Jan 2024 11:59:03 +0700 Subject: [PATCH 23/31] fix(whispercpp): Update std::min and std::max on windows build --- controllers/whisperCPP.cc | 1828 ++++++++++++++++++------------------- controllers/whisperCPP.h | 284 +++--- 2 files changed, 1040 insertions(+), 1072 deletions(-) diff --git a/controllers/whisperCPP.cc b/controllers/whisperCPP.cc index a5b6967f7..e6d495fc5 100644 --- a/controllers/whisperCPP.cc +++ b/controllers/whisperCPP.cc @@ -2,1078 +2,1046 @@ // #include "whisper.h" // #include "llama.h" -bool read_wav(const std::string &fname, std::vector &pcmf32, std::vector> &pcmf32s, bool stereo) -{ - drwav wav; - std::vector wav_data; // used for pipe input from stdin +bool read_wav(const std::string &fname, std::vector &pcmf32, + std::vector> &pcmf32s, bool stereo) { + drwav wav; + std::vector wav_data; // used for pipe input from stdin - if (fname == "-") + if (fname == "-") { { - { - uint8_t buf[1024]; - while (true) - { - const size_t n = fread(buf, 1, sizeof(buf), stdin); - if (n == 
0) - { - break; - } - wav_data.insert(wav_data.end(), buf, buf + n); - } + uint8_t buf[1024]; + while (true) { + const size_t n = fread(buf, 1, sizeof(buf), stdin); + if (n == 0) { + break; } - - if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) - { - fprintf(stderr, "error: failed to open WAV file from stdin\n"); - return false; - } - - fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size()); - } - else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) - { - fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str()); - return false; - } - - if (wav.channels != 1 && wav.channels != 2) - { - fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, fname.c_str()); - return false; + wav_data.insert(wav_data.end(), buf, buf + n); + } } - if (stereo && wav.channels != 2) - { - fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization\n", __func__, fname.c_str()); - return false; + if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == + false) { + fprintf(stderr, "error: failed to open WAV file from stdin\n"); + return false; } - if (wav.sampleRate != COMMON_SAMPLE_RATE) - { - fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", __func__, fname.c_str(), COMMON_SAMPLE_RATE / 1000); - return false; - } + fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, + wav_data.size()); + } else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) { + fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str()); + return false; + } - if (wav.bitsPerSample != 16) - { - fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", __func__, fname.c_str()); - return false; - } + if (wav.channels != 1 && wav.channels != 2) { + fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, + fname.c_str()); + return false; + } - const uint64_t n = wav_data.empty() ? 
wav.totalPCMFrameCount : wav_data.size() / (wav.channels * wav.bitsPerSample / 8); + if (stereo && wav.channels != 2) { + fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization\n", + __func__, fname.c_str()); + return false; + } - std::vector pcm16; - pcm16.resize(n * wav.channels); - drwav_read_pcm_frames_s16(&wav, n, pcm16.data()); - drwav_uninit(&wav); + if (wav.sampleRate != COMMON_SAMPLE_RATE) { + fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", __func__, + fname.c_str(), COMMON_SAMPLE_RATE / 1000); + return false; + } - // convert to mono, float - pcmf32.resize(n); - if (wav.channels == 1) - { - for (uint64_t i = 0; i < n; i++) - { - pcmf32[i] = float(pcm16[i]) / 32768.0f; - } + if (wav.bitsPerSample != 16) { + fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", __func__, + fname.c_str()); + return false; + } + + const uint64_t n = + wav_data.empty() + ? wav.totalPCMFrameCount + : wav_data.size() / (wav.channels * wav.bitsPerSample / 8); + + std::vector pcm16; + pcm16.resize(n * wav.channels); + drwav_read_pcm_frames_s16(&wav, n, pcm16.data()); + drwav_uninit(&wav); + + // convert to mono, float + pcmf32.resize(n); + if (wav.channels == 1) { + for (uint64_t i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[i]) / 32768.0f; } - else - { - for (uint64_t i = 0; i < n; i++) - { - pcmf32[i] = float(pcm16[2 * i] + pcm16[2 * i + 1]) / 65536.0f; - } + } else { + for (uint64_t i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[2 * i] + pcm16[2 * i + 1]) / 65536.0f; } + } - if (stereo) - { - // convert to stereo, float - pcmf32s.resize(2); - - pcmf32s[0].resize(n); - pcmf32s[1].resize(n); - for (uint64_t i = 0; i < n; i++) - { - pcmf32s[0][i] = float(pcm16[2 * i]) / 32768.0f; - pcmf32s[1][i] = float(pcm16[2 * i + 1]) / 32768.0f; - } + if (stereo) { + // convert to stereo, float + pcmf32s.resize(2); + + pcmf32s[0].resize(n); + pcmf32s[1].resize(n); + for (uint64_t i = 0; i < n; i++) { + pcmf32s[0][i] = float(pcm16[2 * i]) / 32768.0f; + pcmf32s[1][i] = 
float(pcm16[2 * i + 1]) / 32768.0f; } + } - return true; + return true; } -std::string output_str(struct whisper_context *ctx, const whisper_params ¶ms, std::vector> pcmf32s) -{ - std::stringstream result; - const int n_segments = whisper_full_n_segments(ctx); - for (int i = 0; i < n_segments; ++i) - { - const char *text = whisper_full_get_segment_text(ctx, i); - std::string speaker = ""; - - if (params.diarize && pcmf32s.size() == 2) - { - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - speaker = estimate_diarization_speaker(pcmf32s, t0, t1); - } +std::string output_str(struct whisper_context *ctx, + const whisper_params ¶ms, + std::vector> pcmf32s) { + std::stringstream result; + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char *text = whisper_full_get_segment_text(ctx, i); + std::string speaker = ""; - result << speaker << text << "\n"; + if (params.diarize && pcmf32s.size() == 2) { + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + speaker = estimate_diarization_speaker(pcmf32s, t0, t1); } - return result.str(); + + result << speaker << text << "\n"; + } + return result.str(); } -std::string estimate_diarization_speaker(std::vector> pcmf32s, int64_t t0, int64_t t1, bool id_only) -{ - std::string speaker = ""; - const int64_t n_samples = pcmf32s[0].size(); +std::string +estimate_diarization_speaker(std::vector> pcmf32s, + int64_t t0, int64_t t1, bool id_only) { + std::string speaker = ""; + const int64_t n_samples = pcmf32s[0].size(); - const int64_t is0 = timestamp_to_sample(t0, n_samples); - const int64_t is1 = timestamp_to_sample(t1, n_samples); + const int64_t is0 = timestamp_to_sample(t0, n_samples); + const int64_t is1 = timestamp_to_sample(t1, n_samples); - double energy0 = 0.0f; - double energy1 = 0.0f; + double energy0 = 0.0f; + double energy1 = 0.0f; - for 
(int64_t j = is0; j < is1; j++) - { - energy0 += fabs(pcmf32s[0][j]); - energy1 += fabs(pcmf32s[1][j]); - } + for (int64_t j = is0; j < is1; j++) { + energy0 += fabs(pcmf32s[0][j]); + energy1 += fabs(pcmf32s[1][j]); + } - if (energy0 > 1.1 * energy1) - { - speaker = "0"; - } - else if (energy1 > 1.1 * energy0) - { - speaker = "1"; - } - else - { - speaker = "?"; - } + if (energy0 > 1.1 * energy1) { + speaker = "0"; + } else if (energy1 > 1.1 * energy0) { + speaker = "1"; + } else { + speaker = "?"; + } - // printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, speaker = %s\n", is0, is1, energy0, energy1, speaker.c_str()); + // printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, speaker = + // %s\n", is0, is1, energy0, energy1, speaker.c_str()); - if (!id_only) - { - speaker.insert(0, "(speaker "); - speaker.append(")"); - } + if (!id_only) { + speaker.insert(0, "(speaker "); + speaker.append(")"); + } - return speaker; + return speaker; } // 500 -> 00:05.000 // 6000 -> 01:00.000 -std::string to_timestamp(int64_t t, bool comma) -{ - int64_t msec = t * 10; - int64_t hr = msec / (1000 * 60 * 60); - msec = msec - hr * (1000 * 60 * 60); - int64_t min = msec / (1000 * 60); - msec = msec - min * (1000 * 60); - int64_t sec = msec / 1000; - msec = msec - sec * 1000; - - char buf[32]; - snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int)hr, (int)min, (int)sec, comma ? "," : ".", (int)msec); - - return std::string(buf); +std::string to_timestamp(int64_t t, bool comma) { + int64_t msec = t * 10; + int64_t hr = msec / (1000 * 60 * 60); + msec = msec - hr * (1000 * 60 * 60); + int64_t min = msec / (1000 * 60); + msec = msec - min * (1000 * 60); + int64_t sec = msec / 1000; + msec = msec - sec * 1000; + + char buf[32]; + snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int)hr, (int)min, + (int)sec, comma ? 
"," : ".", (int)msec); + + return std::string(buf); } -int timestamp_to_sample(int64_t t, int n_samples) -{ - return std::max(0, std::min((int)n_samples - 1, (int)((t * WHISPER_SAMPLE_RATE) / 100))); +int timestamp_to_sample(int64_t t, int n_samples) { + return (std::max)(0, (std::min)((int)n_samples - 1, + (int)((t * WHISPER_SAMPLE_RATE) / 100))); } -bool is_file_exist(const char *fileName) -{ - std::ifstream infile(fileName); - return infile.good(); +bool is_file_exist(const char *fileName) { + std::ifstream infile(fileName); + return infile.good(); } -void whisper_print_usage(int /*argc*/, char **argv, const whisper_params ¶ms) -{ - fprintf(stderr, "\n"); - fprintf(stderr, "usage: %s [options] \n", argv[0]); - fprintf(stderr, "\n"); - fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help [default] show this help message and exit\n"); - fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); - fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); - fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); - fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); - fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); - fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); - fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); - fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? 
"true" : "false"); - fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); - fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); - fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); - fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); - fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); - // fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); - fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); - fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); - fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); - fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); - fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); - fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); - fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); - fprintf(stderr, " -pr, --print-realtime [%-7s] print output in realtime\n", params.print_realtime ? "true" : "false"); - fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); - fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? 
"true" : "false"); - fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); - fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false"); - fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); - fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); - fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); - fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", params.ffmpeg_converter ? "true" : "false"); - fprintf(stderr, "\n"); +void whisper_print_usage(int /*argc*/, char **argv, + const whisper_params ¶ms) { + fprintf(stderr, "\n"); + fprintf(stderr, "usage: %s [options] \n", argv[0]); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help [default] show this help " + "message and exit\n"); + fprintf(stderr, + " -t N, --threads N [%-7d] number of threads to use " + "during computation\n", + params.n_threads); + fprintf(stderr, + " -p N, --processors N [%-7d] number of processors to use " + "during computation\n", + params.n_processors); + fprintf( + stderr, + " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", + params.offset_t_ms); + fprintf(stderr, + " -on N, --offset-n N [%-7d] segment index offset\n", + params.offset_n); + fprintf(stderr, + " -d N, --duration N [%-7d] duration of audio to " + "process in milliseconds\n", + params.duration_ms); + fprintf(stderr, + " -mc N, --max-context N [%-7d] maximum number of text " + "context tokens to store\n", + params.max_context); + fprintf(stderr, + " -ml N, --max-len N [%-7d] maximum segment length in " + "characters\n", + params.max_len); + fprintf(stderr, + " -sow, --split-on-word [%-7s] split on word rather than " + "on token\n", + 
params.split_on_word ? "true" : "false"); + fprintf(stderr, + " -bo N, --best-of N [%-7d] number of best candidates " + "to keep\n", + params.best_of); + fprintf(stderr, + " -bs N, --beam-size N [%-7d] beam size for beam search\n", + params.beam_size); + fprintf(stderr, + " -wt N, --word-thold N [%-7.2f] word timestamp " + "probability threshold\n", + params.word_thold); + fprintf(stderr, + " -et N, --entropy-thold N [%-7.2f] entropy threshold for " + "decoder fail\n", + params.entropy_thold); + fprintf(stderr, + " -lpt N, --logprob-thold N [%-7.2f] log probability threshold " + "for decoder fail\n", + params.logprob_thold); + // fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by + // x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); + fprintf(stderr, + " -debug, --debug-mode [%-7s] enable debug mode (eg. dump " + "log_mel)\n", + params.debug_mode ? "true" : "false"); + fprintf(stderr, + " -tr, --translate [%-7s] translate from source " + "language to english\n", + params.translate ? "true" : "false"); + fprintf(stderr, + " -di, --diarize [%-7s] stereo audio diarization\n", + params.diarize ? "true" : "false"); + fprintf(stderr, + " -tdrz, --tinydiarize [%-7s] enable tinydiarize " + "(requires a tdrz model)\n", + params.tinydiarize ? "true" : "false"); + fprintf(stderr, + " -nf, --no-fallback [%-7s] do not use temperature " + "fallback while decoding\n", + params.no_fallback ? "true" : "false"); + fprintf(stderr, + " -ps, --print-special [%-7s] print special tokens\n", + params.print_special ? "true" : "false"); + fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", + params.print_colors ? "true" : "false"); + fprintf(stderr, + " -pr, --print-realtime [%-7s] print output in realtime\n", + params.print_realtime ? "true" : "false"); + fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", + params.print_progress ? 
"true" : "false"); + fprintf(stderr, + " -nt, --no-timestamps [%-7s] do not print timestamps\n", + params.no_timestamps ? "true" : "false"); + fprintf(stderr, + " -l LANG, --language LANG [%-7s] spoken language ('auto' for " + "auto-detect)\n", + params.language.c_str()); + fprintf(stderr, + " -dl, --detect-language [%-7s] exit after automatically " + "detecting language\n", + params.detect_language ? "true" : "false"); + fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", + params.prompt.c_str()); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", + params.model.c_str()); + fprintf(stderr, + " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used " + "for encode inference\n", + params.openvino_encode_device.c_str()); + fprintf(stderr, + " --convert, [%-7s] Convert audio to WAV, " + "requires ffmpeg on the server", + params.ffmpeg_converter ? "true" : "false"); + fprintf(stderr, "\n"); } -bool whisper_params_parse(int argc, char **argv, whisper_params ¶ms) -{ - for (int i = 1; i < argc; i++) - { - std::string arg = argv[i]; - - if (arg == "-h" || arg == "--help") - { - whisper_print_usage(argc, argv, params); - exit(0); - } - else if (arg == "-t" || arg == "--threads") - { - params.n_threads = std::stoi(argv[++i]); - } - else if (arg == "-p" || arg == "--processors") - { - params.n_processors = std::stoi(argv[++i]); - } - else if (arg == "-ot" || arg == "--offset-t") - { - params.offset_t_ms = std::stoi(argv[++i]); - } - else if (arg == "-on" || arg == "--offset-n") - { - params.offset_n = std::stoi(argv[++i]); - } - else if (arg == "-d" || arg == "--duration") - { - params.duration_ms = std::stoi(argv[++i]); - } - else if (arg == "-mc" || arg == "--max-context") - { - params.max_context = std::stoi(argv[++i]); - } - else if (arg == "-ml" || arg == "--max-len") - { - params.max_len = std::stoi(argv[++i]); - } - else if (arg == "-bo" || arg == "--best-of") - { - params.best_of = std::stoi(argv[++i]); - } - else if (arg == "-bs" || arg 
== "--beam-size") - { - params.beam_size = std::stoi(argv[++i]); - } - else if (arg == "-wt" || arg == "--word-thold") - { - params.word_thold = std::stof(argv[++i]); - } - else if (arg == "-et" || arg == "--entropy-thold") - { - params.entropy_thold = std::stof(argv[++i]); - } - else if (arg == "-lpt" || arg == "--logprob-thold") - { - params.logprob_thold = std::stof(argv[++i]); - } - // else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } - else if (arg == "-debug" || arg == "--debug-mode") - { - params.debug_mode = true; - } - else if (arg == "-tr" || arg == "--translate") - { - params.translate = true; - } - else if (arg == "-di" || arg == "--diarize") - { - params.diarize = true; - } - else if (arg == "-tdrz" || arg == "--tinydiarize") - { - params.tinydiarize = true; - } - else if (arg == "-sow" || arg == "--split-on-word") - { - params.split_on_word = true; - } - else if (arg == "-nf" || arg == "--no-fallback") - { - params.no_fallback = true; - } - else if (arg == "-fp" || arg == "--font-path") - { - params.font_path = argv[++i]; - } - else if (arg == "-ps" || arg == "--print-special") - { - params.print_special = true; - } - else if (arg == "-pc" || arg == "--print-colors") - { - params.print_colors = true; - } - else if (arg == "-pr" || arg == "--print-realtime") - { - params.print_realtime = true; - } - else if (arg == "-pp" || arg == "--print-progress") - { - params.print_progress = true; - } - else if (arg == "-nt" || arg == "--no-timestamps") - { - params.no_timestamps = true; - } - else if (arg == "-l" || arg == "--language") - { - params.language = argv[++i]; - } - else if (arg == "-dl" || arg == "--detect-language") - { - params.detect_language = true; - } - else if (arg == "--prompt") - { - params.prompt = argv[++i]; - } - else if (arg == "-m" || arg == "--model") - { - params.model = argv[++i]; - } - else if (arg == "-oved" || arg == "--ov-e-device") - { - params.openvino_encode_device = argv[++i]; - } - else if (arg == 
"-ng" || arg == "--no-gpu") - { - params.use_gpu = false; - } - else if (arg == "--convert") - { - params.ffmpeg_converter = true; - } - else - { - fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); - whisper_print_usage(argc, argv, params); - exit(0); - } +bool whisper_params_parse(int argc, char **argv, whisper_params ¶ms) { + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "-h" || arg == "--help") { + whisper_print_usage(argc, argv, params); + exit(0); + } else if (arg == "-t" || arg == "--threads") { + params.n_threads = std::stoi(argv[++i]); + } else if (arg == "-p" || arg == "--processors") { + params.n_processors = std::stoi(argv[++i]); + } else if (arg == "-ot" || arg == "--offset-t") { + params.offset_t_ms = std::stoi(argv[++i]); + } else if (arg == "-on" || arg == "--offset-n") { + params.offset_n = std::stoi(argv[++i]); + } else if (arg == "-d" || arg == "--duration") { + params.duration_ms = std::stoi(argv[++i]); + } else if (arg == "-mc" || arg == "--max-context") { + params.max_context = std::stoi(argv[++i]); + } else if (arg == "-ml" || arg == "--max-len") { + params.max_len = std::stoi(argv[++i]); + } else if (arg == "-bo" || arg == "--best-of") { + params.best_of = std::stoi(argv[++i]); + } else if (arg == "-bs" || arg == "--beam-size") { + params.beam_size = std::stoi(argv[++i]); + } else if (arg == "-wt" || arg == "--word-thold") { + params.word_thold = std::stof(argv[++i]); + } else if (arg == "-et" || arg == "--entropy-thold") { + params.entropy_thold = std::stof(argv[++i]); + } else if (arg == "-lpt" || arg == "--logprob-thold") { + params.logprob_thold = std::stof(argv[++i]); + } + // else if (arg == "-su" || arg == "--speed-up") { params.speed_up + // = true; } + else if (arg == "-debug" || arg == "--debug-mode") { + params.debug_mode = true; + } else if (arg == "-tr" || arg == "--translate") { + params.translate = true; + } else if (arg == "-di" || arg == "--diarize") { + params.diarize = true; + } 
else if (arg == "-tdrz" || arg == "--tinydiarize") { + params.tinydiarize = true; + } else if (arg == "-sow" || arg == "--split-on-word") { + params.split_on_word = true; + } else if (arg == "-nf" || arg == "--no-fallback") { + params.no_fallback = true; + } else if (arg == "-fp" || arg == "--font-path") { + params.font_path = argv[++i]; + } else if (arg == "-ps" || arg == "--print-special") { + params.print_special = true; + } else if (arg == "-pc" || arg == "--print-colors") { + params.print_colors = true; + } else if (arg == "-pr" || arg == "--print-realtime") { + params.print_realtime = true; + } else if (arg == "-pp" || arg == "--print-progress") { + params.print_progress = true; + } else if (arg == "-nt" || arg == "--no-timestamps") { + params.no_timestamps = true; + } else if (arg == "-l" || arg == "--language") { + params.language = argv[++i]; + } else if (arg == "-dl" || arg == "--detect-language") { + params.detect_language = true; + } else if (arg == "--prompt") { + params.prompt = argv[++i]; + } else if (arg == "-m" || arg == "--model") { + params.model = argv[++i]; + } else if (arg == "-oved" || arg == "--ov-e-device") { + params.openvino_encode_device = argv[++i]; + } else if (arg == "-ng" || arg == "--no-gpu") { + params.use_gpu = false; + } else if (arg == "--convert") { + params.ffmpeg_converter = true; + } else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + whisper_print_usage(argc, argv, params); + exit(0); } + } - return true; + return true; } -void check_ffmpeg_availibility() -{ - int result = system("ffmpeg -version"); - - if (result == 0) - { - std::cout << "ffmpeg is available." << std::endl; - } - else - { - // ffmpeg is not available - std::cout << "ffmpeg is not found. Please ensure that ffmpeg is installed "; - std::cout << "and that its executable is included in your system's PATH. 
"; - exit(0); - } +void check_ffmpeg_availibility() { + int result = system("ffmpeg -version"); + + if (result == 0) { + std::cout << "ffmpeg is available." << std::endl; + } else { + // ffmpeg is not available + std::cout << "ffmpeg is not found. Please ensure that ffmpeg is installed "; + std::cout << "and that its executable is included in your system's PATH. "; + exit(0); + } } -bool convert_to_wav(const std::string &temp_filename, std::string &error_resp) -{ - std::ostringstream cmd_stream; - std::string converted_filename_temp = temp_filename + "_temp.wav"; - cmd_stream << "ffmpeg -i \"" << temp_filename << "\" -ar 16000 -ac 1 -c:a pcm_s16le \"" << converted_filename_temp << "\" 2>&1"; - std::string cmd = cmd_stream.str(); - - int status = std::system(cmd.c_str()); - if (status != 0) - { - error_resp = "{\"error\":\"FFmpeg conversion failed.\"}"; - return false; - } +bool convert_to_wav(const std::string &temp_filename, std::string &error_resp) { + std::ostringstream cmd_stream; + std::string converted_filename_temp = temp_filename + "_temp.wav"; + cmd_stream << "ffmpeg -i \"" << temp_filename + << "\" -ar 16000 -ac 1 -c:a pcm_s16le \"" + << converted_filename_temp << "\" 2>&1"; + std::string cmd = cmd_stream.str(); + + int status = std::system(cmd.c_str()); + if (status != 0) { + error_resp = "{\"error\":\"FFmpeg conversion failed.\"}"; + return false; + } - // Remove the original file - if (remove(temp_filename.c_str()) != 0) - { - error_resp = "{\"error\":\"Failed to remove the original file.\"}"; - return false; - } + // Remove the original file + if (remove(temp_filename.c_str()) != 0) { + error_resp = "{\"error\":\"Failed to remove the original file.\"}"; + return false; + } - // Rename the temporary file to match the original filename - if (rename(converted_filename_temp.c_str(), temp_filename.c_str()) != 0) - { - error_resp = "{\"error\":\"Failed to rename the temporary file.\"}"; - return false; - } - return true; + // Rename the temporary file to 
match the original filename + if (rename(converted_filename_temp.c_str(), temp_filename.c_str()) != 0) { + error_resp = "{\"error\":\"Failed to rename the temporary file.\"}"; + return false; + } + return true; } -void whisper_print_progress_callback(struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void *user_data) -{ - int progress_step = ((whisper_print_user_data *)user_data)->params->progress_step; - int *progress_prev = &(((whisper_print_user_data *)user_data)->progress_prev); - if (progress >= *progress_prev + progress_step) - { - *progress_prev += progress_step; - fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress); - } +void whisper_print_progress_callback(struct whisper_context * /*ctx*/, + struct whisper_state * /*state*/, + int progress, void *user_data) { + int progress_step = + ((whisper_print_user_data *)user_data)->params->progress_step; + int *progress_prev = &(((whisper_print_user_data *)user_data)->progress_prev); + if (progress >= *progress_prev + progress_step) { + *progress_prev += progress_step; + fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress); + } } -void whisper_print_segment_callback(struct whisper_context *ctx, struct whisper_state * /*state*/, int n_new, void *user_data) -{ - const auto ¶ms = *((whisper_print_user_data *)user_data)->params; - const auto &pcmf32s = *((whisper_print_user_data *)user_data)->pcmf32s; +void whisper_print_segment_callback(struct whisper_context *ctx, + struct whisper_state * /*state*/, int n_new, + void *user_data) { + const auto ¶ms = *((whisper_print_user_data *)user_data)->params; + const auto &pcmf32s = *((whisper_print_user_data *)user_data)->pcmf32s; - const int n_segments = whisper_full_n_segments(ctx); + const int n_segments = whisper_full_n_segments(ctx); - std::string speaker = ""; + std::string speaker = ""; - int64_t t0 = 0; - int64_t t1 = 0; + int64_t t0 = 0; + int64_t t1 = 0; - // print the last n_new segments - const int s0 = n_segments - 
n_new; + // print the last n_new segments + const int s0 = n_segments - n_new; - if (s0 == 0) - { - printf("\n"); + if (s0 == 0) { + printf("\n"); + } + + for (int i = s0; i < n_segments; i++) { + if (!params.no_timestamps || params.diarize) { + t0 = whisper_full_get_segment_t0(ctx, i); + t1 = whisper_full_get_segment_t1(ctx, i); } - for (int i = s0; i < n_segments; i++) - { - if (!params.no_timestamps || params.diarize) - { - t0 = whisper_full_get_segment_t0(ctx, i); - t1 = whisper_full_get_segment_t1(ctx, i); - } + if (!params.no_timestamps) { + printf("[%s --> %s] ", to_timestamp(t0).c_str(), + to_timestamp(t1).c_str()); + } - if (!params.no_timestamps) - { - printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); - } + if (params.diarize && pcmf32s.size() == 2) { + speaker = estimate_diarization_speaker(pcmf32s, t0, t1); + } - if (params.diarize && pcmf32s.size() == 2) - { - speaker = estimate_diarization_speaker(pcmf32s, t0, t1); + if (params.print_colors) { + for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { + if (params.print_special == false) { + const whisper_token id = whisper_full_get_token_id(ctx, i, j); + if (id >= whisper_token_eot(ctx)) { + continue; + } } - if (params.print_colors) - { - for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) - { - if (params.print_special == false) - { - const whisper_token id = whisper_full_get_token_id(ctx, i, j); - if (id >= whisper_token_eot(ctx)) - { - continue; - } - } - - const char *text = whisper_full_get_token_text(ctx, i, j); - const float p = whisper_full_get_token_p(ctx, i, j); - - const int col = std::max(0, std::min((int)k_colors.size() - 1, (int)(std::pow(p, 3) * float(k_colors.size())))); - - printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m"); - } - } - else - { - const char *text = whisper_full_get_segment_text(ctx, i); + const char *text = whisper_full_get_token_text(ctx, i, j); + const float p = whisper_full_get_token_p(ctx, i, j); - 
printf("%s%s", speaker.c_str(), text); - } + const int col = (std::max)( + 0, (std::min)((int)k_colors.size() - 1, + (int)((std::pow)(p, 3) * float(k_colors.size())))); - if (params.tinydiarize) - { - if (whisper_full_get_segment_speaker_turn_next(ctx, i)) - { - printf("%s", params.tdrz_speaker_turn.c_str()); - } - } + printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, + "\033[0m"); + } + } else { + const char *text = whisper_full_get_segment_text(ctx, i); - // with timestamps or speakers: each segment on new line - if (!params.no_timestamps || params.diarize) - { - printf("\n"); - } - fflush(stdout); + printf("%s%s", speaker.c_str(), text); } -} -bool parse_str_to_bool(const std::string &s) -{ - if (s == "true" || s == "1" || s == "yes" || s == "y") - { - return true; + if (params.tinydiarize) { + if (whisper_full_get_segment_speaker_turn_next(ctx, i)) { + printf("%s", params.tdrz_speaker_turn.c_str()); + } } - return false; -} -bool whisper_server_context::load_model(std::string &model_path) -{ - whisper_mutex.lock(); + // with timestamps or speakers: each segment on new line + if (!params.no_timestamps || params.diarize) { + printf("\n"); + } + fflush(stdout); + } +} - // clean up - whisper_free(ctx); +bool parse_str_to_bool(const std::string &s) { + if (s == "true" || s == "1" || s == "yes" || s == "y") { + return true; + } + return false; +} - // whisper init - ctx = whisper_init_from_file_with_params(model_path.c_str(), cparams); +bool whisper_server_context::load_model(std::string &model_path) { + whisper_mutex.lock(); - // TODO perhaps load prior model here instead of exit - if (ctx == nullptr) - { - whisper_mutex.unlock(); - return false; - } + // clean up + whisper_free(ctx); - // initialize openvino encoder. 
this has no effect on whisper.cpp builds that don't have OpenVINO configured - whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr); + // whisper init + ctx = whisper_init_from_file_with_params(model_path.c_str(), cparams); - // check if the model is in the file system + // TODO perhaps load prior model here instead of exit + if (ctx == nullptr) { whisper_mutex.unlock(); - return true; -} + return false; + } -std::string whisper_server_context::inference(std::string &input_file_path, std::string language, std::string prompt, - std::string response_format, float temperature, bool translate) -{ - // acquire whisper model mutex lock - whisper_mutex.lock(); + // initialize openvino encoder. this has no effect on whisper.cpp builds that + // don't have OpenVINO configured + whisper_ctx_init_openvino_encoder( + ctx, nullptr, params.openvino_encode_device.c_str(), nullptr); - // audio arrays - std::vector pcmf32; // mono-channel F32 PCM - std::vector> pcmf32s; // stereo-channel F32 PCM + // check if the model is in the file system + whisper_mutex.unlock(); + return true; +} - // if file is not wav, convert to wav - if (params.ffmpeg_converter) - { - std::string error_resp = "Failed to execute ffmpeg command converting " + input_file_path + " to wav"; - const bool is_converted = convert_to_wav(input_file_path, error_resp); - if (!is_converted) - { - whisper_mutex.unlock(); - LOG_ERROR << error_resp; - throw std::runtime_error(error_resp); - } +std::string whisper_server_context::inference( + std::string &input_file_path, std::string language, std::string prompt, + std::string response_format, float temperature, bool translate) { + // acquire whisper model mutex lock + whisper_mutex.lock(); + + // audio arrays + std::vector pcmf32; // mono-channel F32 PCM + std::vector> pcmf32s; // stereo-channel F32 PCM + + // if file is not wav, convert to wav + if (params.ffmpeg_converter) { + std::string error_resp = "Failed to execute ffmpeg 
command converting " + + input_file_path + " to wav"; + const bool is_converted = convert_to_wav(input_file_path, error_resp); + if (!is_converted) { + whisper_mutex.unlock(); + LOG_ERROR << error_resp; + throw std::runtime_error(error_resp); } + } - // read wav content into pcmf32 - if (!read_wav(input_file_path, pcmf32, pcmf32s, params.diarize)) - { - std::string error_resp = "Failed to read WAV file " + input_file_path; - LOG_ERROR << error_resp; - whisper_mutex.unlock(); - throw std::runtime_error(error_resp); + // read wav content into pcmf32 + if (!read_wav(input_file_path, pcmf32, pcmf32s, params.diarize)) { + std::string error_resp = "Failed to read WAV file " + input_file_path; + LOG_ERROR << error_resp; + whisper_mutex.unlock(); + throw std::runtime_error(error_resp); + } + + printf("Successfully loaded %s\n", input_file_path.c_str()); + + params.translate = translate; + params.language = language; + params.response_format = response_format; + if (!whisper_is_multilingual(ctx)) { + if (params.language != "en" || params.translate) { + params.language = "en"; + params.translate = false; + LOG_WARN + << "Model " << model_id + << " is not multilingual, ignoring language and translation options"; } - - printf("Successfully loaded %s\n", input_file_path.c_str()); - - params.translate = translate; - params.language = language; - params.response_format = response_format; - if (!whisper_is_multilingual(ctx)) - { - if (params.language != "en" || params.translate) - { - params.language = "en"; - params.translate = false; - LOG_WARN << "Model " << model_id << " is not multilingual, ignoring language and translation options"; - } + } + if (params.detect_language) { + params.language = "auto"; + } + + // print some processing info + std::string processing_info = + "Model " + model_id + "processing " + input_file_path + " (" + + std::to_string(pcmf32.size()) + " samples, " + + std::to_string(float(pcmf32.size()) / WHISPER_SAMPLE_RATE) + " sec), " + + 
std::to_string(params.n_threads) + " threads, " + + std::to_string(params.n_processors) + + " processors, lang = " + params.language + + ", task = " + (params.translate ? "translate" : "transcribe") + ", " + + (params.tinydiarize ? "tdrz = 1, " : "") + + (params.no_timestamps ? "timestamps = 0" : "timestamps = 1"); + LOG_INFO << processing_info; + + // run the inference + { + std::string msg = "Running whisper.cpp inference of model " + model_id + + " on " + input_file_path; + LOG_INFO << msg; + whisper_full_params wparams = + whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + + wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH + : WHISPER_SAMPLING_GREEDY; + + wparams.print_realtime = false; + wparams.print_progress = params.print_progress; + wparams.print_timestamps = !params.no_timestamps; + wparams.print_special = params.print_special; + wparams.translate = params.translate; + wparams.language = params.language.c_str(); + wparams.detect_language = params.detect_language; + wparams.n_threads = params.n_threads; + wparams.n_max_text_ctx = + params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; + wparams.offset_ms = params.offset_t_ms; + wparams.duration_ms = params.duration_ms; + + wparams.thold_pt = params.word_thold; + wparams.max_len = params.max_len == 0 ? 
60 : params.max_len; + wparams.split_on_word = params.split_on_word; + + wparams.speed_up = params.speed_up; + wparams.debug_mode = params.debug_mode; + + wparams.tdrz_enable = params.tinydiarize; // [TDRZ] + + wparams.initial_prompt = prompt.c_str(); + + wparams.greedy.best_of = params.best_of; + wparams.beam_search.beam_size = params.beam_size; + + wparams.temperature = temperature; + wparams.temperature_inc = params.temperature_inc; + wparams.entropy_thold = params.entropy_thold; + wparams.logprob_thold = params.logprob_thold; + + wparams.no_timestamps = params.no_timestamps; + + whisper_print_user_data user_data = {¶ms, &pcmf32s, 0}; + + // this callback is called on each new segment + if (params.print_realtime) { + wparams.new_segment_callback = whisper_print_segment_callback; + wparams.new_segment_callback_user_data = &user_data; } - if (params.detect_language) - { - params.language = "auto"; + + if (wparams.print_progress) { + wparams.progress_callback = whisper_print_progress_callback; + wparams.progress_callback_user_data = &user_data; } - // print some processing info - std::string processing_info = "Model " + model_id + "processing " + input_file_path + " (" + std::to_string(pcmf32.size()) + " samples, " + std::to_string(float(pcmf32.size()) / WHISPER_SAMPLE_RATE) + " sec), " + std::to_string(params.n_threads) + " threads, " + std::to_string(params.n_processors) + " processors, lang = " + params.language + ", task = " + (params.translate ? "translate" : "transcribe") + ", " + (params.tinydiarize ? "tdrz = 1, " : "") + (params.no_timestamps ? 
"timestamps = 0" : "timestamps = 1"); - LOG_INFO << processing_info; + // examples for abort mechanism + // in examples below, we do not abort the processing, but we could if the + // flag is set to true - // run the inference + // the callback is called before every encoder run - if it returns false, + // the processing is aborted { - std::string msg = "Running whisper.cpp inference of model " + model_id + " on " + input_file_path; - LOG_INFO << msg; - whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - - wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY; - - wparams.print_realtime = false; - wparams.print_progress = params.print_progress; - wparams.print_timestamps = !params.no_timestamps; - wparams.print_special = params.print_special; - wparams.translate = params.translate; - wparams.language = params.language.c_str(); - wparams.detect_language = params.detect_language; - wparams.n_threads = params.n_threads; - wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; - wparams.offset_ms = params.offset_t_ms; - wparams.duration_ms = params.duration_ms; - - wparams.thold_pt = params.word_thold; - wparams.max_len = params.max_len == 0 ? 
60 : params.max_len; - wparams.split_on_word = params.split_on_word; - - wparams.speed_up = params.speed_up; - wparams.debug_mode = params.debug_mode; - - wparams.tdrz_enable = params.tinydiarize; // [TDRZ] - - wparams.initial_prompt = prompt.c_str(); - - wparams.greedy.best_of = params.best_of; - wparams.beam_search.beam_size = params.beam_size; - - wparams.temperature = temperature; - wparams.temperature_inc = params.temperature_inc; - wparams.entropy_thold = params.entropy_thold; - wparams.logprob_thold = params.logprob_thold; - - wparams.no_timestamps = params.no_timestamps; - - whisper_print_user_data user_data = {¶ms, &pcmf32s, 0}; - - // this callback is called on each new segment - if (params.print_realtime) - { - wparams.new_segment_callback = whisper_print_segment_callback; - wparams.new_segment_callback_user_data = &user_data; - } - - if (wparams.print_progress) - { - wparams.progress_callback = whisper_print_progress_callback; - wparams.progress_callback_user_data = &user_data; - } - - // examples for abort mechanism - // in examples below, we do not abort the processing, but we could if the flag is set to true - - // the callback is called before every encoder run - if it returns false, the processing is aborted - { - static bool is_aborted = false; // NOTE: this should be atomic to avoid data race - - wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void *user_data) - { - bool is_aborted = *(bool *)user_data; - return !is_aborted; - }; - wparams.encoder_begin_callback_user_data = &is_aborted; - } - - // the callback is called before every computation - if it returns true, the computation is aborted - { - static bool is_aborted = false; // NOTE: this should be atomic to avoid data race - - wparams.abort_callback = [](void *user_data) - { - bool is_aborted = *(bool *)user_data; - return is_aborted; - }; - wparams.abort_callback_user_data = &is_aborted; - } - - if (whisper_full_parallel(ctx, wparams, 
pcmf32.data(), pcmf32.size(), params.n_processors) != 0) - { - std::string error_resp = "Failed to process audio"; - LOG_ERROR << error_resp; - whisper_mutex.unlock(); - throw std::runtime_error(error_resp); - } + static bool is_aborted = + false; // NOTE: this should be atomic to avoid data race + + wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, + struct whisper_state * /*state*/, + void *user_data) { + bool is_aborted = *(bool *)user_data; + return !is_aborted; + }; + wparams.encoder_begin_callback_user_data = &is_aborted; } - // return results to user - std::string result; - if (params.response_format == text_format) + // the callback is called before every computation - if it returns true, the + // computation is aborted { - result = output_str(ctx, params, pcmf32s); + static bool is_aborted = + false; // NOTE: this should be atomic to avoid data race + + wparams.abort_callback = [](void *user_data) { + bool is_aborted = *(bool *)user_data; + return is_aborted; + }; + wparams.abort_callback_user_data = &is_aborted; } - else if (params.response_format == srt_format) - { - std::stringstream ss; - const int n_segments = whisper_full_n_segments(ctx); - for (int i = 0; i < n_segments; ++i) - { - const char *text = whisper_full_get_segment_text(ctx, i); - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - std::string speaker = ""; - - if (params.diarize && pcmf32s.size() == 2) - { - speaker = estimate_diarization_speaker(pcmf32s, t0, t1); - } - - ss << i + 1 + params.offset_n << "\n"; - ss << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n"; - ss << speaker << text << "\n\n"; - } - result = ss.str(); + + if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), + params.n_processors) != 0) { + std::string error_resp = "Failed to process audio"; + LOG_ERROR << error_resp; + whisper_mutex.unlock(); + throw std::runtime_error(error_resp); } - else if 
(params.response_format == vtt_format) - { - std::stringstream ss; - - ss << "WEBVTT\n\n"; - - const int n_segments = whisper_full_n_segments(ctx); - for (int i = 0; i < n_segments; ++i) - { - const char *text = whisper_full_get_segment_text(ctx, i); - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - std::string speaker = ""; - - if (params.diarize && pcmf32s.size() == 2) - { - speaker = estimate_diarization_speaker(pcmf32s, t0, t1, true); - speaker.insert(0, ""); - } - - ss << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n"; - ss << speaker << text << "\n\n"; - } - result = ss.str(); + } + + // return results to user + std::string result; + if (params.response_format == text_format) { + result = output_str(ctx, params, pcmf32s); + } else if (params.response_format == srt_format) { + std::stringstream ss; + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char *text = whisper_full_get_segment_text(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + std::string speaker = ""; + + if (params.diarize && pcmf32s.size() == 2) { + speaker = estimate_diarization_speaker(pcmf32s, t0, t1); + } + + ss << i + 1 + params.offset_n << "\n"; + ss << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n"; + ss << speaker << text << "\n\n"; + } + result = ss.str(); + } else if (params.response_format == vtt_format) { + std::stringstream ss; + + ss << "WEBVTT\n\n"; + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char *text = whisper_full_get_segment_text(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + std::string speaker = ""; + + if (params.diarize && pcmf32s.size() == 2) { + speaker = estimate_diarization_speaker(pcmf32s, t0, t1, 
true); + speaker.insert(0, ""); + } + + ss << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n"; + ss << speaker << text << "\n\n"; } - else if (params.response_format == vjson_format) { - /* try to match openai/whisper's Python format */ - std::string results = output_str(ctx, params, pcmf32s); - json jres = json{{"text", results}}; - const int n_segments = whisper_full_n_segments(ctx); - for (int i = 0; i < n_segments; ++i) - { - json segment = json{ - {"id", i}, - {"text", whisper_full_get_segment_text(ctx, i)}, - }; - - if (!params.no_timestamps) { - segment["start"] = whisper_full_get_segment_t0(ctx, i) * 0.01; - segment["end"] = whisper_full_get_segment_t1(ctx, i) * 0.01; - } - - const int n_tokens = whisper_full_n_tokens(ctx, i); - for (int j = 0; j < n_tokens; ++j) { - whisper_token_data token = whisper_full_get_token_data(ctx, i, j); - if (token.id >= whisper_token_eot(ctx)) { - continue; - } - - segment["tokens"].push_back(token.id); - json word = json{{"word", whisper_full_get_token_text(ctx, i, j)}}; - if (!params.no_timestamps) { - word["start"] = token.t0 * 0.01; - word["end"] = token.t1 * 0.01; - } - word["probability"] = token.p; - segment["words"].push_back(word); - } - jres["segments"].push_back(segment); + result = ss.str(); + } else if (params.response_format == vjson_format) { + /* try to match openai/whisper's Python format */ + std::string results = output_str(ctx, params, pcmf32s); + json jres = json{{"text", results}}; + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + json segment = json{ + {"id", i}, + {"text", whisper_full_get_segment_text(ctx, i)}, + }; + + if (!params.no_timestamps) { + segment["start"] = whisper_full_get_segment_t0(ctx, i) * 0.01; + segment["end"] = whisper_full_get_segment_t1(ctx, i) * 0.01; + } + + const int n_tokens = whisper_full_n_tokens(ctx, i); + for (int j = 0; j < n_tokens; ++j) { + whisper_token_data token = whisper_full_get_token_data(ctx, i, j); + if 
(token.id >= whisper_token_eot(ctx)) { + continue; } - result = jres.dump(-1, ' ', false, json::error_handler_t::replace); + + segment["tokens"].push_back(token.id); + json word = json{{"word", whisper_full_get_token_text(ctx, i, j)}}; + if (!params.no_timestamps) { + word["start"] = token.t0 * 0.01; + word["end"] = token.t1 * 0.01; } - else - { - std::string results = output_str(ctx, params, pcmf32s); - json jres = json{ - {"text", results}}; - result = jres.dump(-1, ' ', false, json::error_handler_t::replace); + word["probability"] = token.p; + segment["words"].push_back(word); + } + jres["segments"].push_back(segment); } + result = jres.dump(-1, ' ', false, json::error_handler_t::replace); + } else { + std::string results = output_str(ctx, params, pcmf32s); + json jres = json{{"text", results}}; + result = jres.dump(-1, ' ', false, json::error_handler_t::replace); + } - // reset params to thier defaults - params = default_params; + // reset params to thier defaults + params = default_params; - // return whisper model mutex lock - whisper_mutex.unlock(); - LOG_INFO << "Successfully processed " << input_file_path << ": " << result; + // return whisper model mutex lock + whisper_mutex.unlock(); + LOG_INFO << "Successfully processed " << input_file_path << ": " << result; - return result; + return result; } -whisper_server_context::~whisper_server_context() -{ - if (ctx) - { - whisper_print_timings(ctx); - whisper_free(ctx); - ctx = nullptr; - } +whisper_server_context::~whisper_server_context() { + if (ctx) { + whisper_print_timings(ctx); + whisper_free(ctx); + ctx = nullptr; + } } std::optional whisperCPP::parse_model_id( const std::shared_ptr &jsonBody, - const std::function &callback) -{ - if (!jsonBody->isMember("model_id")) { - LOG_INFO << "No model_id found in request body"; - Json::Value jsonResp; - jsonResp["message"] = "No model_id found in request body"; - auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); - resp->setStatusCode(k400BadRequest); - 
callback(resp); - return std::nullopt; // Signal that an error occurred - } + const std::function &callback) { + if (!jsonBody->isMember("model_id")) { + LOG_INFO << "No model_id found in request body"; + Json::Value jsonResp; + jsonResp["message"] = "No model_id found in request body"; + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setStatusCode(k400BadRequest); + callback(resp); + return std::nullopt; // Signal that an error occurred + } - return (*jsonBody)["model_id"].asString(); + return (*jsonBody)["model_id"].asString(); } void whisperCPP::load_model( const HttpRequestPtr &req, - std::function &&callback) -{ - const auto jsonBody = req->getJsonObject(); - auto optional_model_id = parse_model_id(jsonBody, callback); - if (!optional_model_id) { - return; - } - std::string model_id = *optional_model_id; - - // Check if model is already loaded - if (whispers.find(model_id) != whispers.end()) - { - std::string error_msg = "Model " + model_id + " already loaded"; - LOG_INFO << error_msg; - Json::Value jsonResp; - jsonResp["message"] = error_msg; - auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); - resp->setStatusCode(k409Conflict); - callback(resp); - return; - } - - // Model not loaded, load it - // Parse model path from request - std::string model_path = (*jsonBody)["model_path"].asString(); - if (!is_file_exist(model_path.c_str())) - { - std::string error_msg = "Model " + model_path + " not found"; - LOG_INFO << error_msg; - Json::Value jsonResp; - jsonResp["message"] = error_msg; - auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); - resp->setStatusCode(k404NotFound); - callback(resp); - return; - } - - - whisper_server_context whisper = whisper_server_context(model_id); - bool model_loaded = whisper.load_model(model_path); - // If model failed to load, return a 500 error - if (!model_loaded) - { - whisper.~whisper_server_context(); - std::string error_msg = "Failed to load model " + model_path; - LOG_INFO << error_msg; - 
Json::Value jsonResp; - jsonResp["message"] = error_msg; - auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); - resp->setStatusCode(k500InternalServerError); - callback(resp); - return; - } + std::function &&callback) { + const auto jsonBody = req->getJsonObject(); + auto optional_model_id = parse_model_id(jsonBody, callback); + if (!optional_model_id) { + return; + } + std::string model_id = *optional_model_id; - // Model loaded successfully, add it to the map of loaded models - // and return a 200 response - whispers.emplace(model_id, std::move(whisper)); + // Check if model is already loaded + if (whispers.find(model_id) != whispers.end()) { + std::string error_msg = "Model " + model_id + " already loaded"; + LOG_INFO << error_msg; Json::Value jsonResp; - std::string success_msg = "Model " + model_id + " loaded successfully"; - jsonResp["message"] = success_msg; + jsonResp["message"] = error_msg; auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); - resp->setStatusCode(k200OK); + resp->setStatusCode(k409Conflict); callback(resp); return; - + } + + // Model not loaded, load it + // Parse model path from request + std::string model_path = (*jsonBody)["model_path"].asString(); + if (!is_file_exist(model_path.c_str())) { + std::string error_msg = "Model " + model_path + " not found"; + LOG_INFO << error_msg; + Json::Value jsonResp; + jsonResp["message"] = error_msg; + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setStatusCode(k404NotFound); + callback(resp); + return; + } + + whisper_server_context whisper = whisper_server_context(model_id); + bool model_loaded = whisper.load_model(model_path); + // If model failed to load, return a 500 error + if (!model_loaded) { + whisper.~whisper_server_context(); + std::string error_msg = "Failed to load model " + model_path; + LOG_INFO << error_msg; + Json::Value jsonResp; + jsonResp["message"] = error_msg; + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + 
resp->setStatusCode(k500InternalServerError); + callback(resp); + return; + } + + // Model loaded successfully, add it to the map of loaded models + // and return a 200 response + whispers.emplace(model_id, std::move(whisper)); + Json::Value jsonResp; + std::string success_msg = "Model " + model_id + " loaded successfully"; + jsonResp["message"] = success_msg; + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setStatusCode(k200OK); + callback(resp); + return; } void whisperCPP::unload_model( const HttpRequestPtr &req, - std::function &&callback) -{ - const auto &jsonBody = req->getJsonObject(); - auto optional_model_id = parse_model_id(jsonBody, callback); - if (!optional_model_id) { - return; - } - std::string model_id = *optional_model_id; - - // If model is not loaded, return a 404 error - if (whispers.find(model_id) == whispers.end()) - { - std::string error_msg = "Model " + model_id + " has not been loaded, please load that model into nitro"; - LOG_INFO << error_msg; - Json::Value jsonResp; - jsonResp["message"] = error_msg; - auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); - resp->setStatusCode(k404NotFound); - callback(resp); - return; - } - - // Model loaded, unload it - whispers[model_id].~whisper_server_context(); - whispers.erase(model_id); - - // Return a 200 response + std::function &&callback) { + const auto &jsonBody = req->getJsonObject(); + auto optional_model_id = parse_model_id(jsonBody, callback); + if (!optional_model_id) { + return; + } + std::string model_id = *optional_model_id; + + // If model is not loaded, return a 404 error + if (whispers.find(model_id) == whispers.end()) { + std::string error_msg = + "Model " + model_id + + " has not been loaded, please load that model into nitro"; + LOG_INFO << error_msg; Json::Value jsonResp; - std::string success_msg = "Model " + model_id + " unloaded successfully"; - LOG_INFO << success_msg; - jsonResp["message"] = success_msg; + jsonResp["message"] = error_msg; auto resp 
= nitro_utils::nitroHttpJsonResponse(jsonResp); - resp->setStatusCode(k200OK); + resp->setStatusCode(k404NotFound); callback(resp); return; + } + + // Model loaded, unload it + whispers[model_id].~whisper_server_context(); + whispers.erase(model_id); + + // Return a 200 response + Json::Value jsonResp; + std::string success_msg = "Model " + model_id + " unloaded successfully"; + LOG_INFO << success_msg; + jsonResp["message"] = success_msg; + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setStatusCode(k200OK); + callback(resp); + return; } void whisperCPP::list_model( const HttpRequestPtr &req, - std::function &&callback) -{ - // Return a list of all loaded models - Json::Value jsonResp; - Json::Value models; - for (auto const &model : whispers) - { - models.append(model.first); - } - jsonResp["models"] = models; - auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); - resp->setStatusCode(k200OK); - callback(resp); - return; + std::function &&callback) { + // Return a list of all loaded models + Json::Value jsonResp; + Json::Value models; + for (auto const &model : whispers) { + models.append(model.first); + } + jsonResp["models"] = models; + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setStatusCode(k200OK); + callback(resp); + return; } void whisperCPP::transcription_impl( const HttpRequestPtr &req, - std::function &&callback, - bool translate) -{ - MultiPartParser partParser; - Json::Value jsonResp; - if (partParser.parse(req) != 0 || partParser.getFiles().size() != 1) - { - auto resp = HttpResponse::newHttpResponse(); - resp->setBody("Must have exactly one file"); - resp->setStatusCode(k403Forbidden); - callback(resp); - return; - } - auto &file = partParser.getFiles()[0]; - const auto &formFields = partParser.getParameters(); - - // Check if model_id are present in the request. 
If not, return a 400 error - if (formFields.find("model_id") == formFields.end()) - { - LOG_INFO << "No model_id found in request body"; - Json::Value jsonResp; - jsonResp["message"] = "No model_id found in request body"; - auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); - resp->setStatusCode(k400BadRequest); - callback(resp); - return; - } - - std::string model_id = formFields.at("model_id"); - - // Parse all other optional parameters from the request - std::string language = formFields.find("language") != formFields.end() ? formFields.at("language") : "en"; - std::string prompt = formFields.find("prompt") != formFields.end() ? formFields.at("prompt") : ""; - std::string response_format = formFields.find("response_format") != formFields.end() ? formFields.at("response_format") : json_format; - float temperature = formFields.find("temperature") != formFields.end() ? std::stof(formFields.at("temperature")) : 0; - - // Check if model is loaded. If not, return a 404 error - if (whispers.find(model_id) == whispers.end()) - { - std::string error_msg = "Model " + model_id + " has not been loaded, please load that model into nitro"; - LOG_INFO << error_msg; - Json::Value jsonResp; - jsonResp["message"] = error_msg; - auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); - resp->setStatusCode(k404NotFound); - callback(resp); - return; - } + std::function &&callback, bool translate) { + MultiPartParser partParser; + Json::Value jsonResp; + if (partParser.parse(req) != 0 || partParser.getFiles().size() != 1) { + auto resp = HttpResponse::newHttpResponse(); + resp->setBody("Must have exactly one file"); + resp->setStatusCode(k403Forbidden); + callback(resp); + return; + } + auto &file = partParser.getFiles()[0]; + const auto &formFields = partParser.getParameters(); - // Save input file to temp location - std::string temp_dir = std::filesystem::temp_directory_path().string() + "/" + 
std::to_string(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); - // Create the directory - std::filesystem::create_directory(temp_dir); - // Save the file to the directory, with its original name - std::string temp_file_path = temp_dir + "/" + file.getFileName(); - file.saveAs(temp_file_path); - - // Run inference - std::string result; - try { - result = whispers[model_id].inference(temp_file_path, language, prompt, response_format, temperature, translate); - } catch (const std::exception &e) { - std::remove(temp_file_path.c_str()); - Json::Value jsonResp; - jsonResp["message"] = e.what(); - auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); - resp->setStatusCode(k500InternalServerError); - callback(resp); - return; - } - // TODO: Need to remove the entire temp directory, not just the file + // Check if model_id are present in the request. If not, return a 400 error + if (formFields.find("model_id") == formFields.end()) { + LOG_INFO << "No model_id found in request body"; + Json::Value jsonResp; + jsonResp["message"] = "No model_id found in request body"; + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + std::string model_id = formFields.at("model_id"); + + // Parse all other optional parameters from the request + std::string language = formFields.find("language") != formFields.end() + ? formFields.at("language") + : "en"; + std::string prompt = formFields.find("prompt") != formFields.end() + ? formFields.at("prompt") + : ""; + std::string response_format = + formFields.find("response_format") != formFields.end() + ? formFields.at("response_format") + : json_format; + float temperature = formFields.find("temperature") != formFields.end() + ? std::stof(formFields.at("temperature")) + : 0; + + // Check if model is loaded. 
If not, return a 404 error + if (whispers.find(model_id) == whispers.end()) { + std::string error_msg = + "Model " + model_id + + " has not been loaded, please load that model into nitro"; + LOG_INFO << error_msg; + Json::Value jsonResp; + jsonResp["message"] = error_msg; + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setStatusCode(k404NotFound); + callback(resp); + return; + } + + // Save input file to temp location + std::string temp_dir = + std::filesystem::temp_directory_path().string() + "/" + + std::to_string(std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); + // Create the directory + std::filesystem::create_directory(temp_dir); + // Save the file to the directory, with its original name + std::string temp_file_path = temp_dir + "/" + file.getFileName(); + file.saveAs(temp_file_path); + + // Run inference + std::string result; + try { + result = + whispers[model_id].inference(temp_file_path, language, prompt, + response_format, temperature, translate); + } catch (const std::exception &e) { std::remove(temp_file_path.c_str()); - - auto resp = nitro_utils::nitroHttpResponse(); - resp->setBody(result); - resp->setStatusCode(k200OK); - // Set content type based on response format - if (response_format == json_format || response_format == vjson_format) - { - resp->addHeader("Content-Type", "application/json"); - } - else if (response_format == text_format) - { - resp->addHeader("Content-Type", "text/html"); - } - else if (response_format == srt_format) - { - resp->addHeader("Content-Type", "application/x-subrip"); - } - else if (response_format == vtt_format) - { - resp->addHeader("Content-Type", "text/vtt"); - } + Json::Value jsonResp; + jsonResp["message"] = e.what(); + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setStatusCode(k500InternalServerError); callback(resp); return; + } + // TODO: Need to remove the entire temp directory, not just the file + 
std::remove(temp_file_path.c_str()); + + auto resp = nitro_utils::nitroHttpResponse(); + resp->setBody(result); + resp->setStatusCode(k200OK); + // Set content type based on response format + if (response_format == json_format || response_format == vjson_format) { + resp->addHeader("Content-Type", "application/json"); + } else if (response_format == text_format) { + resp->addHeader("Content-Type", "text/html"); + } else if (response_format == srt_format) { + resp->addHeader("Content-Type", "application/x-subrip"); + } else if (response_format == vtt_format) { + resp->addHeader("Content-Type", "text/vtt"); + } + callback(resp); + return; } - void whisperCPP::transcription( const HttpRequestPtr &req, - std::function &&callback) -{ - return transcription_impl(req, std::move(callback), false); + std::function &&callback) { + return transcription_impl(req, std::move(callback), false); } - void whisperCPP::translation( const HttpRequestPtr &req, - std::function &&callback) -{ - return transcription_impl(req, std::move(callback), true); + std::function &&callback) { + return transcription_impl(req, std::move(callback), true); } \ No newline at end of file diff --git a/controllers/whisperCPP.h b/controllers/whisperCPP.h index a786c2df2..77b4b3898 100644 --- a/controllers/whisperCPP.h +++ b/controllers/whisperCPP.h @@ -1,32 +1,25 @@ #pragma once +#include "whisper.h" +#include #include #include -#include #include -#include "whisper.h" #define DR_WAV_IMPLEMENTATION #include "utils/dr_wav.h" -#include "utils/nitro_utils.h" #include "utils/json.hpp" +#include "utils/nitro_utils.h" using json = nlohmann::ordered_json; // Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9] // Lowest is red, middle is yellow, highest is green. 
const std::vector k_colors = { - "\033[38;5;196m", - "\033[38;5;202m", - "\033[38;5;208m", - "\033[38;5;214m", - "\033[38;5;220m", - "\033[38;5;226m", - "\033[38;5;190m", - "\033[38;5;154m", - "\033[38;5;118m", - "\033[38;5;82m", + "\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", + "\033[38;5;220m", "\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", + "\033[38;5;118m", "\033[38;5;82m", }; // output formats @@ -36,76 +29,79 @@ const std::string srt_format = "srt"; const std::string vjson_format = "verbose_json"; const std::string vtt_format = "vtt"; -struct whisper_params -{ - int32_t n_threads = std::min(4, (int32_t)std::thread::hardware_concurrency()); - int32_t n_processors = 1; - int32_t offset_t_ms = 0; - int32_t offset_n = 0; - int32_t duration_ms = 0; - int32_t progress_step = 5; - int32_t max_context = -1; - int32_t max_len = 0; - int32_t best_of = 2; - int32_t beam_size = -1; - - float word_thold = 0.01f; - float entropy_thold = 2.40f; - float logprob_thold = -1.00f; - float temperature = 0.00f; - float temperature_inc = 0.20f; - - bool speed_up = false; - bool debug_mode = false; - bool translate = false; - bool detect_language = false; - bool diarize = false; - bool tinydiarize = false; - bool split_on_word = false; - bool no_fallback = false; - bool print_special = false; - bool print_colors = false; - bool print_realtime = false; - bool print_progress = false; - bool no_timestamps = false; - bool use_gpu = true; - bool ffmpeg_converter = false; - - std::string language = "en"; - std::string prompt = ""; - std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; - std::string model = "models/ggml-base.en.bin"; - - std::string response_format = json_format; - - // [TDRZ] speaker turn string - std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line - - std::string openvino_encode_device = "CPU"; +struct whisper_params { + int32_t n_threads = + (std::min)(4, 
(int32_t)std::thread::hardware_concurrency()); + int32_t n_processors = 1; + int32_t offset_t_ms = 0; + int32_t offset_n = 0; + int32_t duration_ms = 0; + int32_t progress_step = 5; + int32_t max_context = -1; + int32_t max_len = 0; + int32_t best_of = 2; + int32_t beam_size = -1; + + float word_thold = 0.01f; + float entropy_thold = 2.40f; + float logprob_thold = -1.00f; + float temperature = 0.00f; + float temperature_inc = 0.20f; + + bool speed_up = false; + bool debug_mode = false; + bool translate = false; + bool detect_language = false; + bool diarize = false; + bool tinydiarize = false; + bool split_on_word = false; + bool no_fallback = false; + bool print_special = false; + bool print_colors = false; + bool print_realtime = false; + bool print_progress = false; + bool no_timestamps = false; + bool use_gpu = true; + bool ffmpeg_converter = false; + + std::string language = "en"; + std::string prompt = ""; + std::string font_path = + "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; + std::string model = "models/ggml-base.en.bin"; + + std::string response_format = json_format; + + // [TDRZ] speaker turn string + std::string tdrz_speaker_turn = + " [SPEAKER_TURN]"; // TODO: set from command line + + std::string openvino_encode_device = "CPU"; }; -struct whisper_print_user_data -{ - const whisper_params *params; +struct whisper_print_user_data { + const whisper_params *params; - const std::vector> *pcmf32s; - int progress_prev; + const std::vector> *pcmf32s; + int progress_prev; }; #define COMMON_SAMPLE_RATE 16000 // Read WAV audio file and store the PCM data into pcmf32 // The sample rate of the audio must be equal to COMMON_SAMPLE_RATE -// If stereo flag is set and the audio has 2 channels, the pcmf32s will contain 2 channel PCM -bool read_wav( - const std::string &fname, - std::vector &pcmf32, - std::vector> &pcmf32s, - bool stereo); +// If stereo flag is set and the audio has 2 channels, the pcmf32s will contain +// 2 channel PCM +bool 
read_wav(const std::string &fname, std::vector &pcmf32, + std::vector> &pcmf32s, bool stereo); -std::string output_str(struct whisper_context *ctx, const whisper_params ¶ms, std::vector> pcmf32s); +std::string output_str(struct whisper_context *ctx, + const whisper_params ¶ms, + std::vector> pcmf32s); -std::string estimate_diarization_speaker(std::vector> pcmf32s, int64_t t0, int64_t t1, bool id_only = false); +std::string +estimate_diarization_speaker(std::vector> pcmf32s, + int64_t t0, int64_t t1, bool id_only = false); // 500 -> 00:05.000 // 6000 -> 01:00.000 @@ -115,7 +111,8 @@ int timestamp_to_sample(int64_t t, int n_samples); bool is_file_exist(const char *fileName); -void whisper_print_usage(int /*argc*/, char **argv, const whisper_params ¶ms); +void whisper_print_usage(int /*argc*/, char **argv, + const whisper_params ¶ms); bool whisper_params_parse(int argc, char **argv, whisper_params ¶ms); @@ -123,96 +120,99 @@ void check_ffmpeg_availibility(); bool convert_to_wav(const std::string &temp_filename, std::string &error_resp); -void whisper_print_progress_callback(struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void *user_data); +void whisper_print_progress_callback(struct whisper_context * /*ctx*/, + struct whisper_state * /*state*/, + int progress, void *user_data); -void whisper_print_segment_callback(struct whisper_context *ctx, struct whisper_state * /*state*/, int n_new, void *user_data); +void whisper_print_segment_callback(struct whisper_context *ctx, + struct whisper_state * /*state*/, int n_new, + void *user_data); bool parse_str_to_bool(const std::string &s); -struct whisper_server_context -{ - whisper_params params; - whisper_params default_params; - std::mutex whisper_mutex; - std::string model_id; - - struct whisper_context_params cparams; - struct whisper_context *ctx = nullptr; - - whisper_server_context() = default; // add this line - - // Constructor - whisper_server_context(const std::string &model_id) - { 
- this->model_id = model_id; - this->cparams = whisper_context_params(); - this->ctx = nullptr; - // store default params so we can reset after each inference request - this->default_params = whisper_params(); - this->params = whisper_params(); - } - - // Move constructor - whisper_server_context(whisper_server_context&& other) noexcept - : params(std::move(other.params)) - , default_params(std::move(other.default_params)) - , whisper_mutex() // std::mutex is not movable, so we initialize a new one - , model_id(std::move(other.model_id)) - , cparams(std::move(other.cparams)) - , ctx(std::exchange(other.ctx, nullptr)) // ctx is a raw pointer, so we use std::exchange - { - } - - bool load_model(std::string &model_path); - - std::string inference(std::string &input_file_path, std::string languague, std::string prompt, - std::string response_format, float temperature, bool translate); - - ~whisper_server_context(); +struct whisper_server_context { + whisper_params params; + whisper_params default_params; + std::mutex whisper_mutex; + std::string model_id; + + struct whisper_context_params cparams; + struct whisper_context *ctx = nullptr; + + whisper_server_context() = default; // add this line + + // Constructor + whisper_server_context(const std::string &model_id) { + this->model_id = model_id; + this->cparams = whisper_context_params(); + this->ctx = nullptr; + // store default params so we can reset after each inference request + this->default_params = whisper_params(); + this->params = whisper_params(); + } + + // Move constructor + whisper_server_context(whisper_server_context &&other) noexcept + : params(std::move(other.params)), + default_params(std::move(other.default_params)), + whisper_mutex() // std::mutex is not movable, so we initialize a new one + , + model_id(std::move(other.model_id)), cparams(std::move(other.cparams)), + ctx(std::exchange( + other.ctx, + nullptr)) // ctx is a raw pointer, so we use std::exchange + {} + + bool load_model(std::string 
&model_path); + + std::string inference(std::string &input_file_path, std::string languague, + std::string prompt, std::string response_format, + float temperature, bool translate); + + ~whisper_server_context(); }; using namespace drogon; -class whisperCPP : public drogon::HttpController -{ +class whisperCPP : public drogon::HttpController { public: - METHOD_LIST_BEGIN - - ADD_METHOD_TO(whisperCPP::load_model, "/v1/audio/load_model", Post); - ADD_METHOD_TO(whisperCPP::unload_model, "/v1/audio/unload_model", Post); - ADD_METHOD_TO(whisperCPP::list_model, "/v1/audio/list_model", Get); + METHOD_LIST_BEGIN - ADD_METHOD_TO(whisperCPP::transcription, "/v1/audio/transcriptions", Post); - ADD_METHOD_TO(whisperCPP::translation, "/v1/audio/translations", Post); + ADD_METHOD_TO(whisperCPP::load_model, "/v1/audio/load_model", Post); + ADD_METHOD_TO(whisperCPP::unload_model, "/v1/audio/unload_model", Post); + ADD_METHOD_TO(whisperCPP::list_model, "/v1/audio/list_model", Get); - METHOD_LIST_END + ADD_METHOD_TO(whisperCPP::transcription, "/v1/audio/transcriptions", Post); + ADD_METHOD_TO(whisperCPP::translation, "/v1/audio/translations", Post); - whisperCPP() { - whisper_print_system_info(); - } + METHOD_LIST_END - void load_model(const HttpRequestPtr &req, - std::function &&callback); + whisperCPP() { whisper_print_system_info(); } - void unload_model(const HttpRequestPtr &req, - std::function &&callback); + void load_model(const HttpRequestPtr &req, + std::function &&callback); - void list_model(const HttpRequestPtr &req, + void unload_model(const HttpRequestPtr &req, std::function &&callback); - void transcription(const HttpRequestPtr &req, - std::function &&callback); + void list_model(const HttpRequestPtr &req, + std::function &&callback); - void translation(const HttpRequestPtr &req, + void transcription(const HttpRequestPtr &req, std::function &&callback); + void translation(const HttpRequestPtr &req, + std::function &&callback); + private: - std::unordered_map whispers; 
+ std::unordered_map whispers; - std::optionalparse_model_id(const std::shared_ptr &jsonBody, - const std::function &callback); + std::optional + parse_model_id(const std::shared_ptr &jsonBody, + const std::function &callback); - void transcription_impl(const HttpRequestPtr &req, - std::function &&callback, - bool translate); + void + transcription_impl(const HttpRequestPtr &req, + std::function &&callback, + bool translate); }; From 496bb0041e22f1f55362f44377eba50cadb711c0 Mon Sep 17 00:00:00 2001 From: hiro Date: Sat, 27 Jan 2024 12:18:23 +0700 Subject: [PATCH 24/31] fix: Update windows with cuda build note --- audio.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/audio.md b/audio.md index e4e90e629..bb56a12b2 100644 --- a/audio.md +++ b/audio.md @@ -26,6 +26,17 @@ cmake -B build -DWHISPER_COREML=1 cmake --build build -j --config Release ``` +### For Windows with CUDA + +``` +mkdir -p build +cd build +cmake .. -DLLAMA_CUBLAS=ON -DBUILD_SHARED_LIBS=ON -DWHISPER_CUBLAS=ON -DWHISPER_SDL2=ON +cmake --build . 
--config Release + +# Then copy llama.dll, whisper.dll and zlib.dll +``` + ## Sample test command - Download `ggml-base.en.bin` with [whisper.cpp/models/download-ggml-model.sh](whisper.cpp/models/download-ggml-model.sh) @@ -52,6 +63,12 @@ curl 127.0.0.1:3928/v1/audio/list_model wget https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav ``` +- The input needs to be converted to expected one + +``` +ffmpeg -i INPUT.MP3 -ar 16000 -ac 1 -c:a pcm_s16le OUTPUT.WAV +``` + - Sample transcription: ```bash From b71a2ef11a7da99772d1ff374fd56cf08ad43c47 Mon Sep 17 00:00:00 2001 From: hiro Date: Sat, 27 Jan 2024 12:18:35 +0700 Subject: [PATCH 25/31] feat: Add whisper.cpp build for nitro audio support --- .github/workflows/build.yml | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 5418e9b37..b356b1f22 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -196,7 +196,7 @@ jobs: run: | ./install_deps.sh mkdir build && cd build - cmake -DLLAMA_CUBLAS=ON -DLLAMA_NATIVE=OFF -DNITRO_VERSION=${{ needs.set-nitro-version.outputs.version }} .. + cmake -DLLAMA_NATIVE=OFF -DLLAMA_CUBLAS=ON -DLLAMA_CUBLAS=ON -DWHISPER_CUBLAS=ON -DNITRO_VERSION=${{ needs.set-nitro-version.outputs.version }} .. make -j $(nproc) ls -la @@ -249,14 +249,14 @@ jobs: continue-on-error: true run: | brew update - brew install cmake + brew install cmake sdl2 - name: Build id: cmake_build run: | ./install_deps.sh mkdir build && cd build - cmake -DNITRO_VERSION=${{ needs.set-nitro-version.outputs.version }} .. + cmake -DWHISPER_COREML=1 -DNITRO_VERSION=${{ needs.set-nitro-version.outputs.version }} .. 
CC=gcc-8 make -j $(sysctl -n hw.ncp) ls -la @@ -310,6 +310,7 @@ jobs: continue-on-error: true run: | brew update + brew install sdl2 - name: Build id: cmake_build @@ -374,6 +375,15 @@ jobs: env: ACTIONS_ALLOW_UNSECURE_COMMANDS: true + - name: Add msbuild to PATH + uses: microsoft/setup-msbuild@v1 + + - name: Fetch SDL2 and set SDL2_DIR version 2.28.5 + run: | + C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-2.28.5/SDL2-devel-2.28.5-VC.zip + 7z x sdl2.zip + echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-2.28.5/cmake" >> $env:GITHUB_ENV + - name: actions-setup-cmake uses: jwlawson/actions-setup-cmake@v1.14.1 @@ -385,7 +395,7 @@ jobs: cmake --build ./build_deps/nitro_deps --config Release mkdir -p build cd build - cmake .. -DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_BLAS=ON -DBUILD_SHARED_LIBS=ON -DNITRO_VERSION=${{ needs.set-nitro-version.outputs.version }} + cmake .. -DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_BLAS=ON -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=RELEASE -DWHISPER_SDL2=ON -DNITRO_VERSION=${{ needs.set-nitro-version.outputs.version }} cmake --build . 
--config Release -j "%NUMBER_OF_PROCESSORS%" - name: Pack artifacts @@ -394,6 +404,8 @@ jobs: run: | robocopy build_deps\_install\bin\ .\build\Release\ zlib.dll robocopy build\bin\Release\ .\build\Release\ llama.dll + robocopy build\bin\Release\ .\build\Release\ whisper.dll + robocopy "$env:SDL2_DIR\..\lib\2.28.5\" .\build\Release\ SDL2.dll dotnet tool install --global AzureSignTool azuresigntool.exe sign -kvu "${{ secrets.AZURE_KEY_VAULT_URI }}" -kvi "${{ secrets.AZURE_CLIENT_ID }}" -kvt "${{ secrets.AZURE_TENANT_ID }}" -kvs "${{ secrets.AZURE_CLIENT_SECRET }}" -kvc ${{ secrets.AZURE_CERT_NAME }} -tr http://timestamp.globalsign.com/tsa/r6advanced1 -v ".\build\Release\nitro.exe" 7z a -ttar temp.tar .\build\Release\* @@ -442,6 +454,15 @@ jobs: env: ACTIONS_ALLOW_UNSECURE_COMMANDS: true + - name: Add msbuild to PATH + uses: microsoft/setup-msbuild@v1 + + - name: Fetch SDL2 and set SDL2_DIR version 2.28.5 + run: | + C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-2.28.5/SDL2-devel-2.28.5-VC.zip + 7z x sdl2.zip + echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-2.28.5/cmake" >> $env:GITHUB_ENV + - name: actions-setup-cmake uses: jwlawson/actions-setup-cmake@v1.14.1 @@ -471,7 +492,7 @@ jobs: cmake --build ./build_deps/nitro_deps --config Release mkdir -p build cd build - cmake .. -DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_CUBLAS=ON -DBUILD_SHARED_LIBS=ON -DNITRO_VERSION=${{ needs.set-nitro-version.outputs.version }} + cmake .. -DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_CUBLAS=ON -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=RELEASE -DWHISPER_SDL2=ON -DWHISPER_CUBLAS=ON -DNITRO_VERSION=${{ needs.set-nitro-version.outputs.version }} cmake --build . 
--config Release -j "%NUMBER_OF_PROCESSORS%" - name: Pack artifacts @@ -481,6 +502,8 @@ jobs: set PATH=%PATH%;C:\Program Files\7-Zip\ robocopy build_deps\_install\bin\ .\build\Release\ zlib.dll robocopy build\bin\Release\ .\build\Release\ llama.dll + robocopy build\bin\Release\ .\build\Release\ whisper.dll + robocopy "$env:SDL2_DIR\..\lib\2.28.5\" .\build\Release\ SDL2.dll dotnet tool install --global AzureSignTool %USERPROFILE%\.dotnet\tools\azuresigntool.exe sign -kvu "${{ secrets.AZURE_KEY_VAULT_URI }}" -kvi "${{ secrets.AZURE_CLIENT_ID }}" -kvt "${{ secrets.AZURE_TENANT_ID }}" -kvs "${{ secrets.AZURE_CLIENT_SECRET }}" -kvc ${{ secrets.AZURE_CERT_NAME }} -tr http://timestamp.globalsign.com/tsa/r6advanced1 -v ".\build\Release\nitro.exe" 7z a -ttar temp.tar .\build\Release\* From 81978656448d9094f8b641b0659b3b222639535f Mon Sep 17 00:00:00 2001 From: hiro Date: Sat, 27 Jan 2024 12:49:10 +0700 Subject: [PATCH 26/31] fix: Migrate old e2e test script to new file name for llama --- ...mac.sh => e2e-test-llama-linux-and-mac.sh} | 28 +++++++++---------- ...windows.bat => e2e-test-llama-windows.bat} | 4 +-- 2 files changed, 16 insertions(+), 16 deletions(-) rename .github/scripts/{e2e-test-linux-and-mac.sh => e2e-test-llama-linux-and-mac.sh} (77%) rename .github/scripts/{e2e-test-windows.bat => e2e-test-llama-windows.bat} (96%) diff --git a/.github/scripts/e2e-test-linux-and-mac.sh b/.github/scripts/e2e-test-llama-linux-and-mac.sh similarity index 77% rename from .github/scripts/e2e-test-linux-and-mac.sh rename to .github/scripts/e2e-test-llama-linux-and-mac.sh index 1c5fdf7f4..d7f1b5ab8 100644 --- a/.github/scripts/e2e-test-linux-and-mac.sh +++ b/.github/scripts/e2e-test-llama-linux-and-mac.sh @@ -21,12 +21,12 @@ range=$((max - min + 1)) PORT=$((RANDOM % range + min)) # Start the binary file -"$BINARY_PATH" 1 127.0.0.1 $PORT > /tmp/nitro.log 2>&1 & +"$BINARY_PATH" 1 127.0.0.1 $PORT >/tmp/nitro.log 2>&1 & # Get the process id of the binary file pid=$! -if ! 
ps -p $pid > /dev/null; then +if ! ps -p $pid >/dev/null; then echo "nitro failed to start. Logs:" cat /tmp/nitro.log exit 1 @@ -35,26 +35,27 @@ fi # Wait for a few seconds to let the server start sleep 5 -# Check if /tmp/testmodel exists, if not, download it -if [[ ! -f "/tmp/testmodel" ]]; then - wget $DOWNLOAD_URL -O /tmp/testmodel +# Check if /tmp/testllm exists, if not, download it +if [[ ! -f "/tmp/testllm" ]]; then + wget $DOWNLOAD_URL -O /tmp/testllm fi # Run the curl commands response1=$(curl -o /tmp/response1.log -s -w "%{http_code}" --location "http://127.0.0.1:$PORT/inferences/llamacpp/loadModel" \ ---header 'Content-Type: application/json' \ ---data '{ - "llama_model_path": "/tmp/testmodel", + --header 'Content-Type: application/json' \ + --data '{ + "llama_model_path": "/tmp/testllm", "ctx_len": 50, "ngl": 32, "embedding": false }' 2>&1) -response2=$(curl -o /tmp/response2.log -s -w "%{http_code}" --location "http://127.0.0.1:$PORT/inferences/llamacpp/chat_completion" \ ---header 'Content-Type: application/json' \ ---header 'Accept: text/event-stream' \ ---header 'Access-Control-Allow-Origin: *' \ ---data '{ +response2=$( + curl -o /tmp/response2.log -s -w "%{http_code}" --location "http://127.0.0.1:$PORT/v1/chat/completions" \ + --header 'Content-Type: application/json' \ + --header 'Accept: text/event-stream' \ + --header 'Access-Control-Allow-Origin: *' \ + --data '{ "messages": [ {"content": "Hello there", "role": "assistant"}, {"content": "Write a long and sad story for me", "role": "user"} @@ -98,7 +99,6 @@ echo "----------------------" echo "Log run test:" cat /tmp/response2.log - echo "Nitro test run successfully!" 
# Kill the server process diff --git a/.github/scripts/e2e-test-windows.bat b/.github/scripts/e2e-test-llama-windows.bat similarity index 96% rename from .github/scripts/e2e-test-windows.bat rename to .github/scripts/e2e-test-llama-windows.bat index 96a5385de..9a758d5c1 100644 --- a/.github/scripts/e2e-test-windows.bat +++ b/.github/scripts/e2e-test-llama-windows.bat @@ -1,7 +1,7 @@ @echo off set "TEMP=C:\Users\%UserName%\AppData\Local\Temp" -set "MODEL_PATH=%TEMP%\testmodel" +set "MODEL_PATH=%TEMP%\testllm" rem Check for required arguments if "%~2"=="" ( @@ -62,7 +62,7 @@ echo curl_data2=%curl_data2% rem Run the curl commands and capture the status code curl.exe -o %TEMP%\response1.log -s -w "%%{http_code}" --location "http://127.0.0.1:%PORT%/inferences/llamacpp/loadModel" --header "Content-Type: application/json" --data "%curl_data1%" > %TEMP%\response1_code.log 2>&1 -curl.exe -o %TEMP%\response2.log -s -w "%%{http_code}" --location "http://127.0.0.1:%PORT%/inferences/llamacpp/chat_completion" ^ +curl.exe -o %TEMP%\response2.log -s -w "%%{http_code}" --location "http://127.0.0.1:%PORT%/v1/chat/completions" ^ --header "Content-Type: application/json" ^ --header "Accept: text/event-stream" ^ --header "Access-Control-Allow-Origin: *" ^ From 3323a08c64040864ad4a89cc9ba70a62b0a4d015 Mon Sep 17 00:00:00 2001 From: hiro Date: Sat, 27 Jan 2024 12:49:22 +0700 Subject: [PATCH 27/31] feat: Add e2e script to test whisper.cpp --- .../scripts/e2e-test-whisper-linux-and-mac.sh | 93 +++++++++++++++ .github/scripts/e2e-test-whisper-windows.bat | 109 ++++++++++++++++++ .github/workflows/build.yml | 2 +- 3 files changed, 203 insertions(+), 1 deletion(-) create mode 100644 .github/scripts/e2e-test-whisper-linux-and-mac.sh create mode 100644 .github/scripts/e2e-test-whisper-windows.bat diff --git a/.github/scripts/e2e-test-whisper-linux-and-mac.sh b/.github/scripts/e2e-test-whisper-linux-and-mac.sh new file mode 100644 index 000000000..d1ce007cd --- /dev/null +++ 
b/.github/scripts/e2e-test-whisper-linux-and-mac.sh @@ -0,0 +1,93 @@ +#!/bin/bash + +## Example run command +# ./linux-and-mac.sh './jan/plugins/@janhq/inference-plugin/dist/nitro/nitro_mac_arm64' https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/resolve/main/tinyllama-1.1b-chat-v0.3.Q2_K.gguf + +# Check for required arguments +if [[ $# -ne 2 ]]; then + echo "Usage: $0 " + exit 1 +fi + +rm /tmp/response1.log /tmp/response2.log /tmp/nitro.log + +BINARY_PATH=$1 +DOWNLOAD_URL=$2 + +# Random port to ensure it's not used +min=10000 +max=11000 +range=$((max - min + 1)) +PORT=$((RANDOM % range + min)) + +# Start the binary file +"$BINARY_PATH" 1 127.0.0.1 $PORT >/tmp/nitro.log 2>&1 & + +# Get the process id of the binary file +pid=$! + +if ! ps -p $pid >/dev/null; then + echo "nitro failed to start. Logs:" + cat /tmp/nitro.log + exit 1 +fi + +# Wait for a few seconds to let the server start +sleep 5 + +# Check if /tmp/testwhisper exists, if not, download it +if [[ ! -f "/tmp/testwhisper" ]]; then + wget $DOWNLOAD_URL -O /tmp/testwhisper +fi + +# Run the curl commands +response1=$(curl -o /tmp/response1.log -s -w "%{http_code}" --location "http://127.0.0.1:$PORT/v1/audio/load_model" \ + --header 'Content-Type: application/json' \ + --data '{ + "model_path": "/tmp/testwhisper", + "model_id": "whisper.cpp" +}' 2>&1) + +response2=$( + curl -o /tmp/response2.log -s -w "%{http_code}" --location "http://127.0.0.1:$PORT/v1/audio/transcriptions" \ + --header 'Access-Control-Allow-Origin: *' \ + --form 'file=@"whisper.cpp/samples/jfk.wav"' \ + --form 'model_id="whisper.cpp"' \ + --form 'temperature="0.0"' \ + --form 'prompt="The transcript is about OpenAI which makes technology like DALL·E, GPT-3, and ChatGPT with the hope of one day building an AGI system that benefits all of humanity. 
The president is trying to raly people to support the cause."' + 2>&1 +) + +error_occurred=0 +if [[ "$response1" -ne 200 ]]; then + echo "The first curl command failed with status code: $response1" + cat /tmp/response1.log + error_occurred=1 +fi + +if [[ "$response2" -ne 200 ]]; then + echo "The second curl command failed with status code: $response2" + cat /tmp/response2.log + error_occurred=1 +fi + +if [[ "$error_occurred" -eq 1 ]]; then + echo "Nitro test run failed!!!!!!!!!!!!!!!!!!!!!!" + echo "Nitro Error Logs:" + cat /tmp/nitro.log + kill $pid + exit 1 +fi + +echo "----------------------" +echo "Log load model:" +cat /tmp/response1.log + +echo "----------------------" +echo "Log run test:" +cat /tmp/response2.log + +echo "Nitro test run successfully!" + +# Kill the server process +kill $pid diff --git a/.github/scripts/e2e-test-whisper-windows.bat b/.github/scripts/e2e-test-whisper-windows.bat new file mode 100644 index 000000000..2b95f4f2f --- /dev/null +++ b/.github/scripts/e2e-test-whisper-windows.bat @@ -0,0 +1,109 @@ +@echo off + +set "TEMP=C:\Users\%UserName%\AppData\Local\Temp" +set "MODEL_PATH=%TEMP%\testwhisper" + +rem Check for required arguments +if "%~2"=="" ( + echo Usage: %~0 ^ ^ + exit /b 1 +) + +set "BINARY_PATH=%~1" +set "DOWNLOAD_URL=%~2" + +for %%i in ("%BINARY_PATH%") do set "BINARY_NAME=%%~nxi" + +echo BINARY_NAME=%BINARY_NAME% + +del %TEMP%\response1.log 2>nul +del %TEMP%\response2.log 2>nul +del %TEMP%\nitro.log 2>nul + +set /a min=9999 +set /a max=11000 +set /a range=max-min+1 +set /a PORT=%min% + %RANDOM% %% %range% + +rem Start the binary file +start /B "" "%BINARY_PATH%" 1 "127.0.0.1" %PORT% > %TEMP%\nitro.log 2>&1 + +ping -n 6 127.0.0.1 %PORT% > nul + +rem Capture the PID of the started process with "nitro" in its name +for /f "tokens=2" %%a in ('tasklist /fi "imagename eq %BINARY_NAME%" /fo list ^| findstr /B "PID:"') do ( + set "pid=%%a" +) + +echo pid=%pid% + +if not defined pid ( + echo nitro failed to start. 
Logs:
+    type %TEMP%\nitro.log
+    exit /b 1
+)
+
+rem Wait for a few seconds to let the server start
+
+rem Check if %TEMP%\testwhisper exists, if not, download it
+if not exist "%MODEL_PATH%" (
+    bitsadmin.exe /transfer "DownloadTestModel" %DOWNLOAD_URL% "%MODEL_PATH%"
+)
+
+rem Define JSON strings for curl data
+call set "MODEL_PATH_STRING=%%MODEL_PATH:\=\\%%"
+set "curl_data1={\"model_path\":\"%MODEL_PATH_STRING%\",\"model_id\":\"whisper.cpp\"}"
+
+rem Print the values of curl_data1 for debugging
+echo curl_data1=%curl_data1%
+
+rem Run the curl commands and capture the status code
+curl.exe -o %TEMP%\response1.log -s -w "%%{http_code}" --location "http://127.0.0.1:%PORT%/v1/audio/load_model" --header "Content-Type: application/json" --data "%curl_data1%" > %TEMP%\response1_code.log 2>&1
+
+curl.exe -o %TEMP%\response2.log -s -w "%%{http_code}" --location "http://127.0.0.1:%PORT%/v1/audio/transcriptions" ^
+--header "Access-Control-Allow-Origin: *" ^
+--form 'model_id="whisper.cpp"' ^
+--form 'file=@"whisper.cpp\samples\jfk.wav"' ^
+--form 'temperature="0.0"' ^
+--form 'prompt="The transcript is about OpenAI which makes technology like DALL·E, GPT-3, and ChatGPT with the hope of one day building an AGI system that benefits all of humanity. The president is trying to rally people to support the cause."' ^
+> %TEMP%\response2_code.log 2>&1
+
+set "error_occurred=0"
+
+rem Read the status codes from the log files
+for /f %%a in (%TEMP%\response1_code.log) do set "response1=%%a"
+for /f %%a in (%TEMP%\response2_code.log) do set "response2=%%a"
+
+if "%response1%" neq "200" (
+    echo The first curl command failed with status code: %response1%
+    type %TEMP%\response1.log
+    set "error_occurred=1"
+)
+
+if "%response2%" neq "200" (
+    echo The second curl command failed with status code: %response2%
+    type %TEMP%\response2.log
+    set "error_occurred=1"
+)
+
+if "%error_occurred%"=="1" (
+    echo Nitro test run failed!!!!!!!!!!!!!!!!!!!!!! 
+ echo Nitro Error Logs: + type %TEMP%\nitro.log + taskkill /f /pid %pid% + exit /b 1 +) + + +echo ---------------------- +echo Log load model: +type %TEMP%\response1.log + +echo ---------------------- +echo "Log run test:" +type %TEMP%\response2.log + +echo Nitro test run successfully! + +rem Kill the server process +taskkill /f /pid %pid% diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b356b1f22..a98269af5 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -459,7 +459,7 @@ jobs: - name: Fetch SDL2 and set SDL2_DIR version 2.28.5 run: | - C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-2.28.5/SDL2-devel-2.28.5-VC.zip + curl -L -o sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-2.28.5/SDL2-devel-2.28.5-VC.zip 7z x sdl2.zip echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-2.28.5/cmake" >> $env:GITHUB_ENV From f18a2a8836f4f2a39defb7fac43d173c9e8431d6 Mon Sep 17 00:00:00 2001 From: hiro Date: Sat, 27 Jan 2024 12:59:56 +0700 Subject: [PATCH 28/31] fix(ci): Add e2e to CI --- .github/workflows/build.yml | 87 +++++++++++++++++++++++++++---------- 1 file changed, 64 insertions(+), 23 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a98269af5..860d26ac5 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -49,7 +49,8 @@ on: env: BRANCH_NAME: ${{ github.head_ref || github.ref_name }} - MODEL_URL: https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/resolve/main/tinyllama-1.1b-chat-v0.3.Q2_K.gguf + LLM_MODEL_URL: https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/resolve/main/tinyllama-1.1b-chat-v0.3.Q2_K.gguf + WHISPER_MODEL_URL: https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.bin jobs: create-draft-release: @@ -157,12 +158,19 @@ jobs: name: nitro-linux-amd64 path: ./nitro - - name: Run e2e testing + - name: Run e2e testing - LLama.CPP shell: bash 
run: | # run e2e testing cd nitro - chmod +x ../.github/scripts/e2e-test-linux-and-mac.sh && ../.github/scripts/e2e-test-linux-and-mac.sh ./nitro ${{ env.MODEL_URL }} + chmod +x ../.github/scripts/e2e-test-llama-linux-and-mac.sh && ../.github/scripts/e2e-test-llama-linux-and-mac.sh ./nitro ${{ env.LLM_MODEL_URL }} + + - name: Run e2e testing - Whisper.CPP + shell: bash + run: | + # run e2e testing + cd nitro + chmod +x ../.github/scripts/e2e-test-whisper-linux-and-mac.sh && ../.github/scripts/e2e-test-whisper-linux-and-mac.sh ./nitro ${{ env.WHISPER_MODEL_URL }} - uses: actions/upload-release-asset@v1.0.1 if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') @@ -214,12 +222,19 @@ jobs: name: nitro-linux-amd64-cuda-${{ matrix.cuda }} path: ./nitro - # - name: Run e2e testing - # shell: bash - # run: | - # # run e2e testing - # cd nitro - # chmod +x ../.github/scripts/e2e-test-linux-and-mac.sh && ../.github/scripts/e2e-test-linux-and-mac.sh ./nitro ${{ env.MODEL_URL }} + - name: Run e2e testing - LLama.CPP + shell: bash + run: | + # run e2e testing + cd nitro + chmod +x ../.github/scripts/e2e-test-llama-linux-and-mac.sh && ../.github/scripts/e2e-test-llama-linux-and-mac.sh ./nitro ${{ env.LLM_MODEL_URL }} + + - name: Run e2e testing - Whisper.CPP + shell: bash + run: | + # run e2e testing + cd nitro + chmod +x ../.github/scripts/e2e-test-whisper-linux-and-mac.sh && ../.github/scripts/e2e-test-whisper-linux-and-mac.sh ./nitro ${{ env.WHISPER_MODEL_URL }} - uses: actions/upload-release-asset@v1.0.1 if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') @@ -275,12 +290,19 @@ jobs: name: nitro-mac-arm64 path: ./nitro - - name: Run e2e testing + - name: Run e2e testing - LLama.CPP + shell: bash + run: | + # run e2e testing + cd nitro + chmod +x ../.github/scripts/e2e-test-llama-linux-and-mac.sh && ../.github/scripts/e2e-test-llama-linux-and-mac.sh ./nitro ${{ env.LLM_MODEL_URL }} + + - name: Run e2e testing - Whisper.CPP shell: bash run: 
| # run e2e testing cd nitro - chmod +x ../.github/scripts/e2e-test-linux-and-mac.sh && ../.github/scripts/e2e-test-linux-and-mac.sh ./nitro ${{ env.MODEL_URL }} + chmod +x ../.github/scripts/e2e-test-whisper-linux-and-mac.sh && ../.github/scripts/e2e-test-whisper-linux-and-mac.sh ./nitro ${{ env.WHISPER_MODEL_URL }} - uses: actions/upload-release-asset@v1.0.1 if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') @@ -335,12 +357,19 @@ jobs: name: nitro-mac-amd64 path: ./nitro - - name: Run e2e testing + - name: Run e2e testing - LLama.CPP shell: bash run: | # run e2e testing cd nitro - chmod +x ../.github/scripts/e2e-test-linux-and-mac.sh && ../.github/scripts/e2e-test-linux-and-mac.sh ./nitro ${{ env.MODEL_URL }} + chmod +x ../.github/scripts/e2e-test-llama-linux-and-mac.sh && ../.github/scripts/e2e-test-llama-linux-and-mac.sh ./nitro ${{ env.LLM_MODEL_URL }} + + - name: Run e2e testing - Whisper.CPP + shell: bash + run: | + # run e2e testing + cd nitro + chmod +x ../.github/scripts/e2e-test-whisper-linux-and-mac.sh && ../.github/scripts/e2e-test-whisper-linux-and-mac.sh ./nitro ${{ env.WHISPER_MODEL_URL }} - uses: actions/upload-release-asset@v1.0.1 if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') @@ -418,11 +447,17 @@ jobs: name: nitro-win-amd64 path: ./build/Release - # - name: Run e2e testing - # shell: cmd - # run: | - # cd .\build\Release - # ..\..\.github\scripts\e2e-test-windows.bat .\nitro.exe ${{ env.MODEL_URL }} + - name: Run e2e testing - Llama.cpp + shell: cmd + run: | + cd .\build\Release + ..\..\.github\scripts\e2e-test-llama-windows.bat .\nitro.exe ${{ env.LLM_MODEL_URL }} + + - name: Run e2e testing - Whisper.cpp + shell: cmd + run: | + cd .\build\Release + ..\..\.github\scripts\e2e-test-whisper-windows.bat .\nitro.exe ${{ env.WHISPER_MODEL_URL }} - uses: actions/upload-release-asset@v1.0.1 if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') @@ -516,11 +551,17 @@ jobs: name: 
nitro-win-amd64-cuda-${{ matrix.cuda }} path: ./build/Release - # - name: run e2e testing - # shell: cmd - # run: | - # cd .\build\Release - # ..\..\.github\scripts\e2e-test-windows.bat .\nitro.exe ${{ env.MODEL_URL }} + - name: Run e2e testing - Llama.cpp + shell: cmd + run: | + cd .\build\Release + ..\..\.github\scripts\e2e-test-llama-windows.bat .\nitro.exe ${{ env.LLM_MODEL_URL }} + + - name: Run e2e testing - Whisper.cpp + shell: cmd + run: | + cd .\build\Release + ..\..\.github\scripts\e2e-test-whisper-windows.bat .\nitro.exe ${{ env.WHISPER_MODEL_URL }} - uses: actions/upload-release-asset@v1.0.1 if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') From 74e9a11afba6f5e7e8d06b9ac519c28e0a8e850a Mon Sep 17 00:00:00 2001 From: hiro-v Date: Sat, 27 Jan 2024 14:23:37 +0700 Subject: [PATCH 29/31] Add model warm up --- audio.md | 15 ++++++++++----- controllers/whisperCPP.cc | 29 ++++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/audio.md b/audio.md index bb56a12b2..2d39be0d2 100644 --- a/audio.md +++ b/audio.md @@ -33,10 +33,10 @@ mkdir -p build cd build cmake .. -DLLAMA_CUBLAS=ON -DBUILD_SHARED_LIBS=ON -DWHISPER_CUBLAS=ON -DWHISPER_SDL2=ON cmake --build . --config Release - -# Then copy llama.dll, whisper.dll and zlib.dll ``` +Then copy llama.dll, whisper.dll and zlib.dll + ## Sample test command - Download `ggml-base.en.bin` with [whisper.cpp/models/download-ggml-model.sh](whisper.cpp/models/download-ggml-model.sh) @@ -45,11 +45,16 @@ cmake --build . 
--config Release ```bash curl 127.0.0.1:3928/v1/audio/load_model \ -X POST -H "Content-Type: application/json" \ --d '{"model_id":"ggml-base.en.bin","model_path":"/abs/path/to/whisper.cpp/models/ggml-base.en.bin"}' +-d '{ + "model_id":"ggml-base.en.bin", + "model_path":"/abs/path/to/whisper.cpp/models/ggml-base.en.bin", + "warm_up_audio_path":"/abs/path/to/samples.wav" +}' +``` +`warm_up_audio_path` is optional -# If we enable CoreML on Mac silicon, we need to include `ggml-base.mlmodelc` file in the same folder as `ggml-base.en.bin` -``` +If we enable CoreML on Mac silicon, we need to include `ggml-base.mlmodelc` file in the same folder as `ggml-base.en.bin` - List model: diff --git a/controllers/whisperCPP.cc b/controllers/whisperCPP.cc index e6d495fc5..b9b9b8bea 100644 --- a/controllers/whisperCPP.cc +++ b/controllers/whisperCPP.cc @@ -583,7 +583,7 @@ std::string whisper_server_context::inference( // print some processing info std::string processing_info = - "Model " + model_id + "processing " + input_file_path + " (" + + "Model " + model_id + " processing " + input_file_path + " (" + std::to_string(pcmf32.size()) + " samples, " + std::to_string(float(pcmf32.size()) / WHISPER_SAMPLE_RATE) + " sec), " + std::to_string(params.n_threads) + " threads, " + @@ -863,6 +863,33 @@ void whisperCPP::load_model( return; } + // Warm up the model + // Parse warm up audio path from request + if (jsonBody->isMember("warm_up_audio_path")) { + std::string warm_up_msg = "Warming up model " + model_id; + LOG_INFO << warm_up_msg; + std::string warm_up_audio_path = + (*jsonBody)["warm_up_audio_path"].asString(); + // Return 400 error if warm up audio path is not found + if (!is_file_exist(warm_up_audio_path.c_str())) { + std::string error_msg = "Warm up audio " + warm_up_audio_path + " not found, please provide a valid path or don't specify it at all"; + LOG_INFO << error_msg; + Json::Value jsonResp; + jsonResp["message"] = error_msg; + auto resp =
nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } else { + LOG_INFO << "Warming up model " << model_id << " with audio " << warm_up_audio_path << " ..."; + std::string warm_up_result = whisper.inference( + warm_up_audio_path, "en", "", text_format, 0, false); + LOG_INFO << "Warm up model " << model_id << " completed"; + } + } else { + LOG_INFO << "No warm up audio provided, skipping warm up"; + } + // Model loaded successfully, add it to the map of loaded models // and return a 200 response whispers.emplace(model_id, std::move(whisper)); From c5a45d5f063a2f54476de82b6ee772055cb4e320 Mon Sep 17 00:00:00 2001 From: hiro Date: Sat, 27 Jan 2024 16:16:06 +0700 Subject: [PATCH 30/31] fix(ci): Fix bug in path --- .github/scripts/e2e-test-llama-windows.bat | 2 +- .github/scripts/e2e-test-whisper-linux-and-mac.sh | 6 +++--- .github/scripts/e2e-test-whisper-windows.bat | 2 +- .github/workflows/build.yml | 4 ++++ 4 files changed, 9 insertions(+), 5 deletions(-) mode change 100644 => 100755 .github/scripts/e2e-test-whisper-linux-and-mac.sh diff --git a/.github/scripts/e2e-test-llama-windows.bat b/.github/scripts/e2e-test-llama-windows.bat index 9a758d5c1..84e6c33a0 100644 --- a/.github/scripts/e2e-test-llama-windows.bat +++ b/.github/scripts/e2e-test-llama-windows.bat @@ -62,7 +62,7 @@ echo curl_data2=%curl_data2% rem Run the curl commands and capture the status code curl.exe -o %TEMP%\response1.log -s -w "%%{http_code}" --location "http://127.0.0.1:%PORT%/inferences/llamacpp/loadModel" --header "Content-Type: application/json" --data "%curl_data1%" > %TEMP%\response1_code.log 2>&1 -curl.exe -o %TEMP%\response2.log -s -w "%%{http_code}" --location "http://127.0.0.1:%PORT%/v1/chat/completions" ^ +curl.exe -o %TEMP%\response2.log -s -w "%%{http_code}" --location "http://127.0.0.1:%PORT%/inferences/llamacpp/chat_completion" ^ --header "Content-Type: application/json" ^ --header "Accept: text/event-stream" ^ 
--header "Access-Control-Allow-Origin: *" ^ diff --git a/.github/scripts/e2e-test-whisper-linux-and-mac.sh b/.github/scripts/e2e-test-whisper-linux-and-mac.sh old mode 100644 new mode 100755 index d1ce007cd..90421dff3 --- a/.github/scripts/e2e-test-whisper-linux-and-mac.sh +++ b/.github/scripts/e2e-test-whisper-linux-and-mac.sh @@ -51,11 +51,11 @@ response1=$(curl -o /tmp/response1.log -s -w "%{http_code}" --location "http://1 response2=$( curl -o /tmp/response2.log -s -w "%{http_code}" --location "http://127.0.0.1:$PORT/v1/audio/transcriptions" \ --header 'Access-Control-Allow-Origin: *' \ - --form 'file=@"whisper.cpp/samples/jfk.wav"' \ + --form 'file=@"../whisper.cpp/samples/jfk.wav"' \ --form 'model_id="whisper.cpp"' \ --form 'temperature="0.0"' \ - --form 'prompt="The transcript is about OpenAI which makes technology like DALL·E, GPT-3, and ChatGPT with the hope of one day building an AGI system that benefits all of humanity. The president is trying to raly people to support the cause."' - 2>&1 + --form 'prompt="The transcript is about OpenAI which makes technology like DALL·E, GPT-3, and ChatGPT with the hope of one day building an AGI system that benefits all of humanity. The president is trying to raly people to support the cause."' \ + 2>&1 ) error_occurred=0 diff --git a/.github/scripts/e2e-test-whisper-windows.bat b/.github/scripts/e2e-test-whisper-windows.bat index 2b95f4f2f..99019e101 100644 --- a/.github/scripts/e2e-test-whisper-windows.bat +++ b/.github/scripts/e2e-test-whisper-windows.bat @@ -63,7 +63,7 @@ curl.exe -o %TEMP%\response1.log -s -w "%%{http_code}" --location "http://127.0. 
curl.exe -o %TEMP%\response2.log -s -w "%%{http_code}" --location "http://127.0.0.1:%PORT%/v1/audio/transcriptions" ^ --header "Access-Control-Allow-Origin: *" ^ --form 'model_id="whisper.cpp"' ^ ---form 'file=@"whisper.cpp\samples\jfk.wav"' ^ +--form 'file=@"..\whisper.cpp\samples\jfk.wav"' ^ --form 'temperature="0.0"' ^ --form 'prompt="The transcript is about OpenAI which makes technology like DALL·E, GPT-3, and ChatGPT with the hope of one day building an AGI system that benefits all of humanity. The president is trying to raly people to support the cause."' ^ > %TEMP%\response2_code.log 2>&1 diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 860d26ac5..8cbca243f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -300,6 +300,10 @@ jobs: - name: Run e2e testing - Whisper.CPP shell: bash run: | + # To test with CoreML + wget https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en-encoder.mlmodelc.zip + unzip ggml-tiny.en-encoder.mlmodelc.zip + mv ggml-tiny.en-encoder.mlmodelc /tmp/testwhisper-encoder.mlmodelc # run e2e testing cd nitro chmod +x ../.github/scripts/e2e-test-whisper-linux-and-mac.sh && ../.github/scripts/e2e-test-whisper-linux-and-mac.sh ./nitro ${{ env.WHISPER_MODEL_URL }} From d2a776a796db162e471459e3d2ed850208e5afa8 Mon Sep 17 00:00:00 2001 From: hiro Date: Sat, 27 Jan 2024 16:35:05 +0700 Subject: [PATCH 31/31] chore(ci): Temp disable CI for e2e testing on windows --- .github/workflows/build.yml | 44 ++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 8cbca243f..43532a27a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -451,17 +451,17 @@ jobs: name: nitro-win-amd64 path: ./build/Release - - name: Run e2e testing - Llama.cpp - shell: cmd - run: | - cd .\build\Release - ..\..\.github\scripts\e2e-test-llama-windows.bat .\nitro.exe ${{ 
env.LLM_MODEL_URL }} - - - name: Run e2e testing - Whisper.cpp - shell: cmd - run: | - cd .\build\Release - ..\..\.github\scripts\e2e-test-whisper-windows.bat .\nitro.exe ${{ env.WHISPER_MODEL_URL }} + # - name: Run e2e testing - Llama.cpp + # shell: cmd + # run: | + # cd .\build\Release + # ..\..\.github\scripts\e2e-test-llama-windows.bat .\nitro.exe ${{ env.LLM_MODEL_URL }} + + # - name: Run e2e testing - Whisper.cpp + # shell: cmd + # run: | + # cd .\build\Release + # ..\..\.github\scripts\e2e-test-whisper-windows.bat .\nitro.exe ${{ env.WHISPER_MODEL_URL }} - uses: actions/upload-release-asset@v1.0.1 if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') @@ -555,17 +555,17 @@ jobs: name: nitro-win-amd64-cuda-${{ matrix.cuda }} path: ./build/Release - - name: Run e2e testing - Llama.cpp - shell: cmd - run: | - cd .\build\Release - ..\..\.github\scripts\e2e-test-llama-windows.bat .\nitro.exe ${{ env.LLM_MODEL_URL }} - - - name: Run e2e testing - Whisper.cpp - shell: cmd - run: | - cd .\build\Release - ..\..\.github\scripts\e2e-test-whisper-windows.bat .\nitro.exe ${{ env.WHISPER_MODEL_URL }} + # - name: Run e2e testing - Llama.cpp + # shell: cmd + # run: | + # cd .\build\Release + # ..\..\.github\scripts\e2e-test-llama-windows.bat .\nitro.exe ${{ env.LLM_MODEL_URL }} + + # - name: Run e2e testing - Whisper.cpp + # shell: cmd + # run: | + # cd .\build\Release + # ..\..\.github\scripts\e2e-test-whisper-windows.bat .\nitro.exe ${{ env.WHISPER_MODEL_URL }} - uses: actions/upload-release-asset@v1.0.1 if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')