Skip to content

Commit

Permalink
embeddings embeddingsfor OpenAI
Browse files Browse the repository at this point in the history
  • Loading branch information
wujjpp committed Jan 29, 2024
1 parent fbe7dfa commit 9461329
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 0 deletions.
15 changes: 15 additions & 0 deletions examples/server/oai.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,18 @@ inline static std::vector<json> format_partial_response_oaicompat(const task_res

return std::vector<json>({ret});
}

inline static json format_embeddings_response_oaicompat(const json &request, const json &embeddings)
{
json res =
json{
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"},
{"usage",
json{{"prompt_tokens", 0},
{"total_tokens", 0}}},
{"data", embeddings}
};
return res;
}

60 changes: 60 additions & 0 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2929,6 +2929,66 @@ int main(int argc, char **argv)
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
});

svr.Post("/v1/embeddings", [&llama](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);

json prompt;
if (body.count("input") != 0)
{
prompt = body["input"];
// batch
if(prompt.is_array()) {
json data = json::array();
int i = 0;
for (const json &elem : prompt) {
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, { {"prompt", elem}, { "n_predict", 0} }, false, true, -1);

// get the result
task_result result = llama.queue_results.recv(task_id);
llama.queue_results.remove_waiting_task_id(task_id);

json embedding = json{
{"embedding", json_value(result.result_json, "embedding", json::array())},
{"index", i++},
{"object", "embedding"}
};
data.push_back(embedding);
}
json result = format_embeddings_response_oaicompat(body, data);
return res.set_content(result.dump(), "application/json; charset=utf-8");
}
}
else
{
prompt = "";
}

// create and queue the task
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}}, false, true, -1);

// get the result
task_result result = llama.queue_results.recv(task_id);
llama.queue_results.remove_waiting_task_id(task_id);

json data = json::array({json{
{"embedding", json_value(result.result_json, "embedding", json::array())},
{"index", 0},
{"object", "embedding"}
}}
);

json root = format_embeddings_response_oaicompat(body, data);

// send the result
return res.set_content(root.dump(), "application/json; charset=utf-8");
});

// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
// "Bus error: 10" - this is on macOS, it does not crash on Linux
//std::thread t2([&]()
Expand Down

0 comments on commit 9461329

Please sign in to comment.