
Commit

adapt to upstream changes
mudler committed Feb 1, 2024
1 parent ff8e910 commit 160468e
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions backend/cpp/llama/grpc-server.cpp
@@ -527,14 +527,6 @@ struct llama_server_context
         slot_params default_params;
         llama_sampling_params default_sparams;
 
-        if (data.count("__oaicompat") != 0) {
-            slot->oaicompat = true;
-            slot->oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
-        } else {
-            slot->oaicompat = false;
-            slot->oaicompat_model = "";
-        }
-
         slot->params.stream = json_value(data, "stream", false);
         slot->params.cache_prompt = json_value(data, "cache_prompt", false);
         slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict);
@@ -2032,9 +2024,9 @@ static void params_parse(const backend::ModelOptions* request,
         std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 };
         std::vector<std::string> split_arg{ it, {} };
 
-        GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);
+        GGML_ASSERT(split_arg.size() <= llama_max_devices());
 
-        for (size_t i_device = 0; i_device < LLAMA_MAX_DEVICES; ++i_device) {
+        for (size_t i_device = 0; i_device < llama_max_devices(); ++i_device) {
             if (i_device < split_arg.size()) {
                 params.tensor_split[i_device] = std::stof(split_arg[i_device]);
             }
@@ -2116,7 +2108,9 @@ class BackendServiceImpl final : public backend::Backend::Service {
     }
     grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
         json data = parse_options(true, request, llama);
-        const int task_id = llama.request_completion(data, false, false, -1);
+        const int task_id = llama.queue_tasks.get_new_id();
+        llama.queue_results.add_waiting_task_id(task_id);
+        llama.request_completion(task_id, data, false, false, -1);
         while (true)
         {
             task_result result = llama.next_result(task_id);
@@ -2152,7 +2146,9 @@ class BackendServiceImpl final : public backend::Backend::Service {
 
     grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) {
         json data = parse_options(false, request, llama);
-        const int task_id = llama.request_completion(data, false, false, -1);
+        const int task_id = llama.queue_tasks.get_new_id();
+        llama.queue_results.add_waiting_task_id(task_id);
+        llama.request_completion(task_id, data, false, false, -1);
         std::string completion_text;
         task_result result = llama.next_result(task_id);
         if (!result.error && result.stop) {
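Note: both Predict and PredictStream are adapted here to the newer upstream llama.cpp server flow, where the caller allocates a task id from the task queue and registers it with the result queue before submitting the completion request, instead of letting request_completion return the id. Below is a minimal sketch of that lifecycle, assuming only the llama_server_context members visible in this diff; the result_json access and the remove_waiting_task_id cleanup call are assumptions about the upstream API, not part of this commit.

// Sketch under assumptions: llama_server_context, task_result and json come
// from the surrounding grpc-server.cpp / llama.cpp server code; names not
// shown in the diff above (result_json access, remove_waiting_task_id) are
// assumptions about the upstream API rather than code from this commit.
static std::string run_blocking_completion(llama_server_context &llama, const json &data)
{
    // The caller now allocates the task id explicitly from the task queue...
    const int task_id = llama.queue_tasks.get_new_id();
    // ...and registers it so the result queue knows a consumer is waiting.
    llama.queue_results.add_waiting_task_id(task_id);
    // request_completion takes the pre-allocated id instead of returning one.
    llama.request_completion(task_id, data, false, false, -1);

    std::string completion_text;
    task_result result = llama.next_result(task_id);
    if (!result.error && result.stop) {
        // Assumed: the generated text is stored under "content" in result_json.
        completion_text = result.result_json.value("content", "");
    }
    // Assumed cleanup: deregister the waiting task once its result is consumed.
    llama.queue_results.remove_waiting_task_id(task_id);
    return completion_text;
}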
