fix: responses from /chat/completions endpoint contain a leading space in the content (#488)

Co-authored-by: vansangpfiev <sang@jan.ai>
vansangpfiev and sangjanai committed Apr 12, 2024
1 parent 4419a4d commit f64a90f
Showing 10 changed files with 183 additions and 21 deletions.
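In outline: per the commit title, responses from the /chat/completions endpoint begin with a leading space, so the handler now left-trims only the first chunk of each response, in both the streamed and non-streamed paths. A condensed sketch of that pattern follows; the flag and helper names come from the diff below, while the wrapper function and its surroundings are illustrative, not the real handler in controllers/llamaCPP.cc.

// Hedged sketch only; see the actual changes in controllers/llamaCPP.cc below.
#include <algorithm>
#include <cctype>
#include <string>
#include <utility>  // std::exchange

// Stand-in for the is_first_token flag added to inferenceState in this commit.
struct inferenceState {
  bool is_first_token = true;
};

// Stand-in for nitro_utils::ltrim, added in utils/nitro_utils.h below.
inline void ltrim(std::string& s) {
  s.erase(s.begin(), std::find_if(s.begin(), s.end(),
                                  [](unsigned char ch) { return !std::isspace(ch); }));
}

// Only the very first chunk of a response is trimmed: std::exchange returns the
// previous flag value and stores false in one step.
std::string prepare_chunk(inferenceState& state, std::string content) {
  if (std::exchange(state.is_first_token, false)) {
    ltrim(content);
  }
  return content;
}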
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -58,6 +58,7 @@ add_compile_definitions(NITRO_VERSION="${NITRO_VERSION}")
add_subdirectory(llama.cpp/examples/llava)
add_subdirectory(llama.cpp)
add_subdirectory(whisper.cpp)
add_subdirectory(test)

add_executable(${PROJECT_NAME} main.cc)

55 changes: 38 additions & 17 deletions controllers/llamaCPP.cc
@@ -3,8 +3,8 @@
#include <fstream>
#include <iostream>
#include "log.h"
#include "utils/nitro_utils.h"
#include "utils/logging_utils.h"
#include "utils/nitro_utils.h"

// External
#include "common.h"
@@ -29,6 +29,8 @@ struct inferenceState {
int task_id;
InferenceStatus inference_status = PENDING;
llamaCPP* instance;
// Check if we receive the first token, set it to false after receiving
bool is_first_token = true;

inferenceState(llamaCPP* inst) : instance(inst) {}
};
@@ -208,7 +210,8 @@ void llamaCPP::InferenceImpl(

// Passing load value
data["repeat_last_n"] = this->repeat_last_n;
LOG_INFO_REQUEST(request_id) << "Stop words:" << completion.stop.toStyledString();
LOG_INFO_REQUEST(request_id)
<< "Stop words:" << completion.stop.toStyledString();

data["stream"] = completion.stream;
data["n_predict"] = completion.max_tokens;
@@ -267,7 +270,8 @@ void llamaCPP::InferenceImpl(
auto image_url = content_piece["image_url"]["url"].asString();
std::string base64_image_data;
if (image_url.find("http") != std::string::npos) {
LOG_INFO_REQUEST(request_id) << "Remote image detected but not supported yet";
LOG_INFO_REQUEST(request_id)
<< "Remote image detected but not supported yet";
} else if (image_url.find("data:image") != std::string::npos) {
LOG_INFO_REQUEST(request_id) << "Base64 image detected";
base64_image_data = nitro_utils::extractBase64(image_url);
@@ -328,16 +332,19 @@ void llamaCPP::InferenceImpl(
if (is_streamed) {
LOG_INFO_REQUEST(request_id) << "Streamed, waiting for respone";
auto state = create_inference_state(this);
auto chunked_content_provider =
[state, data, request_id](char* pBuffer, std::size_t nBuffSize) -> std::size_t {

auto chunked_content_provider = [state, data, request_id](
char* pBuffer,
std::size_t nBuffSize) -> std::size_t {
if (state->inference_status == PENDING) {
state->inference_status = RUNNING;
} else if (state->inference_status == FINISHED) {
return 0;
}

if (!pBuffer) {
LOG_WARN_REQUEST(request_id) "Connection closed or buffer is null. Reset context";
LOG_WARN_REQUEST(request_id)
"Connection closed or buffer is null. Reset context";
state->inference_status = FINISHED;
return 0;
}
@@ -350,7 +357,8 @@ void llamaCPP::InferenceImpl(
"stop") +
"\n\n" + "data: [DONE]" + "\n\n";

LOG_VERBOSE("data stream", {{"request_id": request_id}, {"to_send", str}});
LOG_VERBOSE("data stream",
{{"request_id": request_id}, {"to_send", str}});
std::size_t nRead = std::min(str.size(), nBuffSize);
memcpy(pBuffer, str.data(), nRead);
state->inference_status = FINISHED;
@@ -359,7 +367,13 @@

task_result result = state->instance->llama.next_result(state->task_id);
if (!result.error) {
const std::string to_send = result.result_json["content"];
std::string to_send = result.result_json["content"];

// trim the leading space if it is the first token
if (std::exchange(state->is_first_token, false)) {
nitro_utils::ltrim(to_send);
}

const std::string str =
"data: " +
create_return_json(nitro_utils::generate_random_string(20), "_",
@@ -410,7 +424,8 @@ void llamaCPP::InferenceImpl(
retries += 1;
}
if (state->inference_status != RUNNING)
LOG_INFO_REQUEST(request_id) << "Wait for task to be released:" << state->task_id;
LOG_INFO_REQUEST(request_id)
<< "Wait for task to be released:" << state->task_id;
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
LOG_INFO_REQUEST(request_id) << "Task completed, release it";
@@ -428,9 +443,11 @@ void llamaCPP::InferenceImpl(
if (!result.error && result.stop) {
int prompt_tokens = result.result_json["tokens_evaluated"];
int predicted_tokens = result.result_json["tokens_predicted"];
respData = create_full_return_json(nitro_utils::generate_random_string(20),
"_", result.result_json["content"], "_",
prompt_tokens, predicted_tokens);
std::string to_send = result.result_json["content"];
nitro_utils::ltrim(to_send);
respData = create_full_return_json(
nitro_utils::generate_random_string(20), "_", to_send, "_",
prompt_tokens, predicted_tokens);
} else {
respData["message"] = "Internal error during inference";
LOG_ERROR_REQUEST(request_id) << "Error during inference";
@@ -463,7 +480,8 @@ void llamaCPP::EmbeddingImpl(
// Queue embedding task
auto state = create_inference_state(this);

state->instance->queue->runTaskInQueue([this, state, jsonBody, callback, request_id]() {
state->instance->queue->runTaskInQueue([this, state, jsonBody, callback,
request_id]() {
Json::Value responseData(Json::arrayValue);

if (jsonBody->isMember("input")) {
@@ -535,7 +553,7 @@ void llamaCPP::ModelStatus(
auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
callback(resp);
LOG_INFO << "Model status responded";
}
}
}

void llamaCPP::LoadModel(
@@ -545,10 +563,12 @@ void llamaCPP::LoadModel(
if (!nitro_utils::isAVX2Supported() && ggml_cpu_has_avx2()) {
LOG_ERROR << "AVX2 is not supported by your processor";
Json::Value jsonResp;
jsonResp["message"] = "AVX2 is not supported by your processor, please download and replace the correct Nitro asset version";
jsonResp["message"] =
"AVX2 is not supported by your processor, please download and replace "
"the correct Nitro asset version";
auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
resp->setStatusCode(drogon::k500InternalServerError);
callback(resp);
callback(resp);
return;
}

@@ -615,7 +635,8 @@ bool llamaCPP::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
if (model_path.isNull()) {
LOG_ERROR << "Missing model path in request";
} else {
if (std::filesystem::exists(std::filesystem::path(model_path.asString()))) {
if (std::filesystem::exists(
std::filesystem::path(model_path.asString()))) {
params.model = model_path.asString();
} else {
LOG_ERROR << "Could not find model in path " << model_path.asString();
4 changes: 2 additions & 2 deletions models/chat_completion_request.h
@@ -5,8 +5,8 @@ namespace inferences {
struct ChatCompletionRequest {
bool stream = false;
int max_tokens = 500;
float top_p = 0.95;
float temperature = 0.8;
float top_p = 0.95f;
float temperature = 0.8f;
float frequency_penalty = 0;
float presence_penalty = 0;
Json::Value stop = Json::Value(Json::arrayValue);
15 changes: 14 additions & 1 deletion nitro_deps/CMakeLists.txt
@@ -79,8 +79,21 @@ ExternalProject_Add(
-DCMAKE_INSTALL_PREFIX=${THIRD_PARTY_INSTALL_PATH}
)

# Fix trantor cmakelists to link c-ares on Windows
# Download and install GoogleTest
ExternalProject_Add(
gtest
GIT_REPOSITORY https://github.com/google/googletest
GIT_TAG v1.14.0
CMAKE_ARGS
-Dgtest_force_shared_crt=ON
-DCMAKE_BUILD_TYPE=release
-DCMAKE_PREFIX_PATH=${THIRD_PARTY_INSTALL_PATH}
-DCMAKE_INSTALL_PREFIX=${THIRD_PARTY_INSTALL_PATH}
)


if(WIN32)
# Fix trantor cmakelists to link c-ares on Windows
set(TRANTOR_CMAKE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/../build_deps/nitro_deps/drogon-prefix/src/drogon/trantor/CMakeLists.txt)
ExternalProject_Add_Step(drogon trantor_custom_target
COMMAND ${CMAKE_COMMAND} -E echo add_definitions(-DCARES_STATICLIB) >> ${TRANTOR_CMAKE_FILE}
2 changes: 2 additions & 0 deletions test/CMakeLists.txt
@@ -0,0 +1,2 @@

add_subdirectory(components)
16 changes: 16 additions & 0 deletions test/components/CMakeLists.txt
@@ -0,0 +1,16 @@
file(GLOB SRCS *.cc)
project(test-components)

enable_testing()

add_executable(${PROJECT_NAME} ${SRCS})

find_package(Drogon CONFIG REQUIRED)
find_package(GTest CONFIG REQUIRED)

target_link_libraries(${PROJECT_NAME} PRIVATE Drogon::Drogon GTest::gtest GTest::gmock
${CMAKE_THREAD_LIBS_INIT})
target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../)

add_test(NAME ${PROJECT_NAME}
COMMAND ${PROJECT_NAME})
9 changes: 9 additions & 0 deletions test/components/main.cc
@@ -0,0 +1,9 @@
#include "gtest/gtest.h"
#include <drogon/HttpAppFramework.h>
#include <drogon/drogon.h>

int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
int ret = RUN_ALL_TESTS();
return ret;
}
53 changes: 53 additions & 0 deletions test/components/test_models.cc
@@ -0,0 +1,53 @@
#include "gtest/gtest.h"
#include "models/chat_completion_request.h"

using inferences::ChatCompletionRequest;

class ModelTest : public ::testing::Test {
};


TEST_F(ModelTest, should_parse_request) {
{
Json::Value data;
auto req = drogon::HttpRequest::newHttpJsonRequest(data);

auto res =
drogon::fromRequest<inferences::ChatCompletionRequest>(*req.get());

EXPECT_EQ(res.stream, false);
EXPECT_EQ(res.max_tokens, 500);
EXPECT_EQ(res.top_p, 0.95f);
EXPECT_EQ(res.temperature, 0.8f);
EXPECT_EQ(res.frequency_penalty, 0);
EXPECT_EQ(res.presence_penalty, 0);
EXPECT_EQ(res.stop, Json::Value{});
EXPECT_EQ(res.messages, Json::Value{});
}

{
Json::Value data;
data["stream"] = true;
data["max_tokens"] = 400;
data["top_p"] = 0.8;
data["temperature"] = 0.7;
data["frequency_penalty"] = 0.1;
data["presence_penalty"] = 0.2;
data["messages"] = "message";
data["stop"] = "stop";

auto req = drogon::HttpRequest::newHttpJsonRequest(data);

auto res =
drogon::fromRequest<inferences::ChatCompletionRequest>(*req.get());

EXPECT_EQ(res.stream, true);
EXPECT_EQ(res.max_tokens, 400);
EXPECT_EQ(res.top_p, 0.8f);
EXPECT_EQ(res.temperature, 0.7f);
EXPECT_EQ(res.frequency_penalty, 0.1f);
EXPECT_EQ(res.presence_penalty, 0.2f);
EXPECT_EQ(res.stop, Json::Value{"stop"});
EXPECT_EQ(res.messages, Json::Value{"message"});
}
}
41 changes: 41 additions & 0 deletions test/components/test_nitro_utils.cc
@@ -0,0 +1,41 @@
#include "gtest/gtest.h"
#include "utils/nitro_utils.h"

class NitroUtilTest : public ::testing::Test {
};

TEST_F(NitroUtilTest, left_trim) {
{
std::string empty;
nitro_utils::ltrim(empty);
EXPECT_EQ(empty, "");
}

{
std::string s = "abc";
std::string expected = "abc";
nitro_utils::ltrim(s);
EXPECT_EQ(s, expected);
}

{
std::string s = " abc";
std::string expected = "abc";
nitro_utils::ltrim(s);
EXPECT_EQ(s, expected);
}

{
std::string s = "1 abc 2 ";
std::string expected = "1 abc 2 ";
nitro_utils::ltrim(s);
EXPECT_EQ(s, expected);
}

{
std::string s = " |abc";
std::string expected = "|abc";
nitro_utils::ltrim(s);
EXPECT_EQ(s, expected);
}
}
8 changes: 7 additions & 1 deletion utils/nitro_utils.h
@@ -165,7 +165,7 @@ inline std::string generate_random_string(std::size_t length) {
std::random_device rd;
std::mt19937 generator(rd());

std::uniform_int_distribution<> distribution(0, characters.size() - 1);
std::uniform_int_distribution<> distribution(0, static_cast<int>(characters.size()) - 1);

std::string random_string(length, '\0');
std::generate_n(random_string.begin(), length,
@@ -276,4 +276,10 @@ inline drogon::HttpResponsePtr nitroStreamResponse(
return resp;
}

inline void ltrim(std::string& s) {
s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
return !std::isspace(ch);
}));
};

} // namespace nitro_utils
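
For reference, a hedged usage sketch of the new helper; the names come from the header above, but the example program itself is illustrative and not part of the commit.

#include <iostream>
#include <string>
#include "utils/nitro_utils.h"

int main() {
  std::string first_token = " Hello";  // leading space as described in the commit title
  nitro_utils::ltrim(first_token);     // in-place left trim
  std::cout << first_token << "\n";    // prints "Hello"
  return 0;
}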
