 #include <fstream>
 #include <iostream>
 #include "log.h"
-#include "utils/nitro_utils.h"
 #include "utils/logging_utils.h"
+#include "utils/nitro_utils.h"
 
 // External
 #include "common.h"
@@ -29,6 +29,8 @@ struct inferenceState {
   int task_id;
   InferenceStatus inference_status = PENDING;
   llamaCPP* instance;
+  // True until the first token has been received; set to false afterwards
+  bool is_first_token = true;
 
   inferenceState(llamaCPP* inst) : instance(inst) {}
 };
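The new `is_first_token` flag pairs with the `std::exchange` call added further down, so the first-token handling runs exactly once per stream. A minimal sketch of that idiom, using hypothetical names (`StreamState`, `consume_first_token_flag`) rather than the real types:

```cpp
#include <utility>  // std::exchange

// Illustrative stand-in for the per-request streaming state (hypothetical name).
struct StreamState {
  bool is_first_token = true;
};

// std::exchange(flag, false) returns the previous value and stores false,
// so this returns true exactly once per StreamState: on the first token.
bool consume_first_token_flag(StreamState& state) {
  return std::exchange(state.is_first_token, false);
}
```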
@@ -208,7 +210,8 @@ void llamaCPP::InferenceImpl(
 
   // Passing load value
   data["repeat_last_n"] = this->repeat_last_n;
-  LOG_INFO_REQUEST(request_id) << "Stop words:" << completion.stop.toStyledString();
+  LOG_INFO_REQUEST(request_id)
+      << "Stop words:" << completion.stop.toStyledString();
 
   data["stream"] = completion.stream;
   data["n_predict"] = completion.max_tokens;
@@ -267,7 +270,8 @@ void llamaCPP::InferenceImpl(
         auto image_url = content_piece["image_url"]["url"].asString();
         std::string base64_image_data;
         if (image_url.find("http") != std::string::npos) {
-          LOG_INFO_REQUEST(request_id) << "Remote image detected but not supported yet";
+          LOG_INFO_REQUEST(request_id)
+              << "Remote image detected but not supported yet";
         } else if (image_url.find("data:image") != std::string::npos) {
           LOG_INFO_REQUEST(request_id) << "Base64 image detected";
           base64_image_data = nitro_utils::extractBase64(image_url);
@@ -328,16 +332,19 @@ void llamaCPP::InferenceImpl(
   if (is_streamed) {
     LOG_INFO_REQUEST(request_id) << "Streamed, waiting for response";
     auto state = create_inference_state(this);
-    auto chunked_content_provider =
-        [state, data, request_id](char* pBuffer, std::size_t nBuffSize) -> std::size_t {
+
+    auto chunked_content_provider = [state, data, request_id](
+                                        char* pBuffer,
+                                        std::size_t nBuffSize) -> std::size_t {
       if (state->inference_status == PENDING) {
         state->inference_status = RUNNING;
       } else if (state->inference_status == FINISHED) {
         return 0;
       }
 
       if (!pBuffer) {
-        LOG_WARN_REQUEST(request_id) "Connection closed or buffer is null. Reset context";
+        LOG_WARN_REQUEST(request_id)
+            "Connection closed or buffer is null. Reset context";
         state->inference_status = FINISHED;
         return 0;
       }
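For reference, the lambda above follows a pull-style chunked-content-provider contract: copy at most `nBuffSize` bytes into `pBuffer`, return how many bytes were written, and return 0 to end the stream (which the code also does when `pBuffer` is null, i.e. the client disconnected). A self-contained sketch of that contract; `provide_chunk` is illustrative only, not a function in this codebase:

```cpp
#include <algorithm>
#include <cstring>
#include <string>

// Copy at most nBuffSize bytes of `chunk` into pBuffer and report the count.
// Returning 0 tells the HTTP layer that the stream is finished.
std::size_t provide_chunk(const std::string& chunk, char* pBuffer,
                          std::size_t nBuffSize) {
  if (pBuffer == nullptr || chunk.empty()) {
    return 0;  // client went away or nothing left to send
  }
  std::size_t nRead = std::min(chunk.size(), nBuffSize);
  std::memcpy(pBuffer, chunk.data(), nRead);
  return nRead;
}
```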
@@ -350,7 +357,8 @@ void llamaCPP::InferenceImpl(
                                "stop") +
             "\n\n" + "data: [DONE]" + "\n\n";
 
-        LOG_VERBOSE("data stream", {{"request_id", request_id}, {"to_send", str}});
+        LOG_VERBOSE("data stream",
+                    {{"request_id", request_id}, {"to_send", str}});
         std::size_t nRead = std::min(str.size(), nBuffSize);
         memcpy(pBuffer, str.data(), nRead);
         state->inference_status = FINISHED;
@@ -359,7 +367,13 @@ void llamaCPP::InferenceImpl(
 
       task_result result = state->instance->llama.next_result(state->task_id);
       if (!result.error) {
-        const std::string to_send = result.result_json["content"];
+        std::string to_send = result.result_json["content"];
+
+        // trim the leading space if it is the first token
+        if (std::exchange(state->is_first_token, false)) {
+          nitro_utils::ltrim(to_send);
+        }
+
         const std::string str =
             "data: " +
             create_return_json(nitro_utils::generate_random_string(20), "_",
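`nitro_utils::ltrim` itself is not shown in this diff; for readers following along, it is assumed to behave roughly like the in-place left-trim below (a sketch, not the project's actual implementation):

```cpp
#include <algorithm>
#include <cctype>
#include <string>

// Erase leading whitespace in place; assumed behaviour of nitro_utils::ltrim.
inline void ltrim(std::string& s) {
  s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
            return !std::isspace(ch);
          }));
}
```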
@@ -410,7 +424,8 @@ void llamaCPP::InferenceImpl(
         retries += 1;
       }
       if (state->inference_status != RUNNING)
-        LOG_INFO_REQUEST(request_id) << "Wait for task to be released:" << state->task_id;
+        LOG_INFO_REQUEST(request_id)
+            << "Wait for task to be released:" << state->task_id;
       std::this_thread::sleep_for(std::chrono::milliseconds(100));
     }
     LOG_INFO_REQUEST(request_id) << "Task completed, release it";
@@ -428,9 +443,11 @@ void llamaCPP::InferenceImpl(
     if (!result.error && result.stop) {
       int prompt_tokens = result.result_json["tokens_evaluated"];
       int predicted_tokens = result.result_json["tokens_predicted"];
-      respData = create_full_return_json(nitro_utils::generate_random_string(20),
-                                         "_", result.result_json["content"], "_",
-                                         prompt_tokens, predicted_tokens);
+      std::string to_send = result.result_json["content"];
+      nitro_utils::ltrim(to_send);
+      respData = create_full_return_json(
+          nitro_utils::generate_random_string(20), "_", to_send, "_",
+          prompt_tokens, predicted_tokens);
     } else {
       respData["message"] = "Internal error during inference";
       LOG_ERROR_REQUEST(request_id) << "Error during inference";
@@ -463,7 +480,8 @@ void llamaCPP::EmbeddingImpl(
   // Queue embedding task
   auto state = create_inference_state(this);
 
-  state->instance->queue->runTaskInQueue([this, state, jsonBody, callback, request_id]() {
+  state->instance->queue->runTaskInQueue([this, state, jsonBody, callback,
+                                          request_id]() {
     Json::Value responseData(Json::arrayValue);
 
     if (jsonBody->isMember("input")) {
@@ -535,7 +553,7 @@ void llamaCPP::ModelStatus(
     auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
     callback(resp);
     LOG_INFO << "Model status responded";
-  }
+  }
 }
 
 void llamaCPP::LoadModel(
@@ -545,10 +563,12 @@ void llamaCPP::LoadModel(
   if (!nitro_utils::isAVX2Supported() && ggml_cpu_has_avx2()) {
     LOG_ERROR << "AVX2 is not supported by your processor";
     Json::Value jsonResp;
-    jsonResp["message"] = "AVX2 is not supported by your processor, please download and replace the correct Nitro asset version";
+    jsonResp["message"] =
+        "AVX2 is not supported by your processor, please download and replace "
+        "the correct Nitro asset version";
     auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
     resp->setStatusCode(drogon::k500InternalServerError);
-    callback(resp);
+    callback(resp);
     return;
   }
 
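`nitro_utils::isAVX2Supported()` is also outside this diff; on GCC/Clang x86 targets a runtime AVX2 check can be written with a compiler builtin, roughly as sketched below (a hypothetical stand-in, not the project's implementation):

```cpp
// Hypothetical stand-in for nitro_utils::isAVX2Supported().
// __builtin_cpu_supports queries the running CPU at runtime (GCC/Clang, x86).
bool is_avx2_supported() {
#if (defined(__GNUC__) || defined(__clang__)) && \
    (defined(__x86_64__) || defined(__i386__))
  return __builtin_cpu_supports("avx2");
#else
  return false;  // be pessimistic where the builtin is unavailable
#endif
}
```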
@@ -615,7 +635,8 @@ bool llamaCPP::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
   if (model_path.isNull()) {
     LOG_ERROR << "Missing model path in request";
   } else {
-    if (std::filesystem::exists(std::filesystem::path(model_path.asString()))) {
+    if (std::filesystem::exists(
+            std::filesystem::path(model_path.asString()))) {
       params.model = model_path.asString();
     } else {
       LOG_ERROR << "Could not find model in path " << model_path.asString();