Merged. This repository was archived by the owner on Jul 4, 2025, and is now read-only.
69 changes: 39 additions & 30 deletions controllers/llamaCPP.cc
@@ -6,16 +6,17 @@
 using namespace inferences;
 using json = nlohmann::json;

-struct State {
-  bool isStopped = false;
+struct inferenceState {
+  bool is_stopped = false;
+  bool is_streaming = false;
   int task_id;
   llamaCPP *instance;

-  State(int tid, llamaCPP *inst) : task_id(tid), instance(inst) {}
+  inferenceState(llamaCPP *inst) : instance(inst) {}
 };

-std::shared_ptr<State> createState(int task_id, llamaCPP *instance) {
-  return std::make_shared<State>(task_id, instance);
+std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
+  return std::make_shared<inferenceState>(instance);
 }

 // --------------------------------------------
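
For context on the refactor above: the per-request state is heap-allocated and handed to the streaming callback through a std::shared_ptr, so it outlives the handler that created it and is reused on every invocation of the chunked content provider. Below is a minimal sketch of that ownership pattern, using an illustrative Engine stand-in for llamaCPP; all names in it are assumptions for the sketch, not part of the patch.

#include <functional>
#include <iostream>
#include <memory>

struct Engine {                       // stand-in for llamaCPP / the llama server
  int next_task = 0;
  int request_completion() { return next_task++; }
};

struct inferenceState {               // same shape as the struct in the patch
  bool is_stopped = false;
  bool is_streaming = false;
  int task_id = -1;
  Engine *instance;
  explicit inferenceState(Engine *inst) : instance(inst) {}
};

std::function<bool()> make_provider(Engine *engine) {
  // The state lives on the heap and the lambda captures the shared_ptr by
  // value, so the state survives after make_provider returns.
  auto state = std::make_shared<inferenceState>(engine);
  return [state]() {
    if (!state->is_streaming) {       // lazily start the task on the first call
      state->task_id = state->instance->request_completion();
      state->is_streaming = true;
    }
    std::cout << "chunk for task " << state->task_id << "\n";
    return !state->is_stopped;
  };
}

int main() {
  Engine engine;
  auto provider = make_provider(&engine);
  provider();                         // starts task 0, streams a chunk
  provider();                         // same task, no new request issued
}
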
@@ -295,41 +296,35 @@ void llamaCPP::chatCompletion(
 #endif
   int task_id;

-  if (llama.params.n_parallel == 1) {
-    while (true) {
-      if (!single_queue_is_busy) {
-        task_id = llama.request_completion(data, false, false, -1);
-        single_queue_is_busy = true;
-        break;
-      } else {
-        std::this_thread::sleep_for(
-            std::chrono::milliseconds(500)); // Sleep for 500 milliseconds
-      }
-    }
-  } else {
-    task_id = llama.request_completion(data, false, false, -1);
-  }
-
   LOG_INFO << "Resolved request for task_id:" << task_id;

   if (is_streamed) {
-    auto state = createState(task_id, this);
-
+    auto state = create_inference_state(this);
+    state->task_id = task_id;
     auto chunked_content_provider =
-        [this, state](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
+        [state, data](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
+      if (!state->is_streaming) {
+        state->task_id =
+            state->instance->llama.request_completion(data, false, false, -1);
+        state->instance->single_queue_is_busy = true;
+      }
       if (!pBuffer) {
         LOG_INFO << "Connection closed or buffer is null. Reset context";
         state->instance->llama.request_cancel(state->task_id);
-        single_queue_is_busy = false;
+        state->is_streaming = false;
+        state->instance->single_queue_is_busy = false;
         return 0;
       }
-      if (state->isStopped) {
-        single_queue_is_busy = false;
+      if (state->is_stopped) {
+        state->is_streaming = false;
+        state->instance->single_queue_is_busy = false;
         return 0;
       }

       task_result result = state->instance->llama.next_result(state->task_id);
       if (!result.error) {
+        // Update streaming state to being streamed
+        state->is_streaming = true;
         const std::string to_send = result.result_json["content"];
         const std::string str =
             "data: " +
@@ -351,16 +346,30 @@ void llamaCPP::chatCompletion(
           std::size_t nRead = std::min(str.size(), nBuffSize);
           memcpy(pBuffer, str.data(), nRead);
           LOG_INFO << "reached result stop";
-          state->isStopped = true;
+          state->is_stopped = true;
+          state->instance->llama.request_cancel(state->task_id);
+          state->is_streaming = false;
+          state->instance->single_queue_is_busy = false;

           return nRead;
         }
         return nRead;
       } else {
-        single_queue_is_busy = false;
-        return 0;
+        if (state->instance->llama.params.n_parallel == 1) {
+          while (state->instance->single_queue_is_busy) {
+            LOG_INFO << "Waiting for task to be released status:"
+                     << state->instance->single_queue_is_busy;
+            std::this_thread::sleep_for(std::chrono::milliseconds(500)); // Waiting in 500 miliseconds step
+          }
+        }
+        std::string str = "\n\n";
+        std::size_t nRead = str.size();
+        memcpy(pBuffer, str.data(), nRead);
+        LOG_INFO << "Failing retrying now";
+        return nRead;
       }
-      single_queue_is_busy = false;
+      state->is_streaming = false;
+      state->instance->single_queue_is_busy = false;
       return 0;
     };
     auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
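
Taken together, the new provider defers request_completion until the first time the streaming callback runs, marks the single inference slot busy, and, when a result cannot be produced while n_parallel == 1, keeps the connection open with blank "\n\n" chunks until the slot is released. The sketch below illustrates that slot-sharing idea in isolation; it is a simplification under stated assumptions (the engine is reduced to a counter, and the patch's 500 ms sleep loop is swapped for a non-blocking retry), not the actual nitro code.

#include <atomic>
#include <iostream>
#include <string>

std::atomic<bool> single_queue_is_busy{false};  // the one shared inference slot
std::atomic<int> next_task_id{0};

struct StreamState {
  bool is_streaming = false;
  int task_id = -1;
};

// One invocation of a chunked provider; returns the chunk to send.
std::string provider_step(StreamState &state) {
  if (!state.is_streaming) {
    bool expected = false;
    // Try to claim the single slot; on failure, send a keep-alive and retry
    // on the next invocation (standing in for the "\n\n" retry branch above).
    if (!single_queue_is_busy.compare_exchange_strong(expected, true)) {
      return "\n\n";
    }
    state.task_id = next_task_id++;   // stand-in for request_completion()
    state.is_streaming = true;
  }
  return "data: chunk for task " + std::to_string(state.task_id) + "\n\n";
}

void finish(StreamState &state) {
  state.is_streaming = false;
  single_queue_is_busy = false;       // release the slot at result.stop
}

int main() {
  StreamState a, b;
  std::cout << provider_step(a);      // A claims the slot and streams
  std::cout << provider_step(b);      // B only gets a keep-alive chunk
  finish(a);                          // A reaches "stop" and releases the slot
  std::cout << provider_step(b);      // now B starts its own task
}
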