Skip to content

Commit

Permalink
feat: Add mixedbread ai support (#122)
Browse files Browse the repository at this point in the history
* feat: add mixedbread ai support

* Update embedding_service.hpp

* Update embedding_service.hpp

---------

Co-authored-by: richard-epsilla <131846445+richard-epsilla@users.noreply.github.com>
  • Loading branch information
juliuslipp and richard-epsilla committed Jan 28, 2024
1 parent 8e54857 commit 9e2dbc5
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 4 deletions.
12 changes: 12 additions & 0 deletions engine/server/web_server/web_controller.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ class WebController : public oatpp::web::server::api::ApiController {
if (headerValue != nullptr) {
headers[VOYAGEAI_KEY_HEADER] = headerValue->c_str();
}
headerValue = request->getHeader(MIXEDBREADAI_KEY_HEADER);
if (headerValue != nullptr) {
headers[MIXEDBREADAI_KEY_HEADER] = headerValue->c_str();
}

std::string db_path = parsedBody.GetString("path");
std::string db_name = parsedBody.GetString("name");
Expand Down Expand Up @@ -408,6 +412,10 @@ class WebController : public oatpp::web::server::api::ApiController {
if (headerValue != nullptr) {
headers[VOYAGEAI_KEY_HEADER] = headerValue->c_str();
}
headerValue = request->getHeader(MIXEDBREADAI_KEY_HEADER);
if (headerValue != nullptr) {
headers[MIXEDBREADAI_KEY_HEADER] = headerValue->c_str();
}

auto data = parsedBody.GetArray("data");
vectordb::Status insert_status = db_server->Insert(db_name, table_name, data, headers, upsert);
Expand Down Expand Up @@ -638,6 +646,10 @@ class WebController : public oatpp::web::server::api::ApiController {
if (headerValue != nullptr) {
headers[VOYAGEAI_KEY_HEADER] = headerValue->c_str();
}
headerValue = request->getHeader(MIXEDBREADAI_KEY_HEADER);
if (headerValue != nullptr) {
headers[MIXEDBREADAI_KEY_HEADER] = headerValue->c_str();
}

vectordb::Json result;
vectordb::Status search_status;
Expand Down
16 changes: 14 additions & 2 deletions engine/services/embedding_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ Status EmbeddingService::denseEmbedDocuments(
std::string openai_key = "";
std::string jinaai_key = "";
std::string voyageai_key = "";
std::string mixedbreadai_key = "";
// Inject 3rd party service key based on their model name.
if (server::CommonUtil::StartsWith(model_name, "openai/")) {
if (headers.find(OPENAI_KEY_HEADER) == headers.end()) {
Expand All @@ -77,6 +78,11 @@ Status EmbeddingService::denseEmbedDocuments(
return Status(INVALID_PAYLOAD, "Missing VoyageAI API key.");
}
voyageai_key = headers[VOYAGEAI_KEY_HEADER];
} else if (server::CommonUtil::StartsWith(model_name, "mixedbreadai/")) {
if (headers.find(MIXEDBREADAI_KEY_HEADER) == headers.end()) {
return Status(INVALID_PAYLOAD, "Missing mixedbread ai API key.");
}
mixedbreadai_key = headers[MIXEDBREADAI_KEY_HEADER];
}

// Constructing documents list from attr_column_container
Expand All @@ -85,7 +91,7 @@ Status EmbeddingService::denseEmbedDocuments(
// Assuming attr_column_container[idx] returns a string or can be converted to string
requestBody->documents->push_back(oatpp::String(std::get<std::string>(attr_column_container[idx]).c_str()));
}
auto response = m_client->denseEmbedDocuments("/v1/embeddings", openai_key, jinaai_key, voyageai_key, requestBody);
auto response = m_client->denseEmbedDocuments("/v1/embeddings", openai_key, jinaai_key, voyageai_key, mixedbreadai_key, requestBody);
auto responseBody = response->readBodyToString();
// std::cout << "Embedding response: " << responseBody->c_str() << std::endl;
vectordb::Json json;
Expand Down Expand Up @@ -136,6 +142,7 @@ Status EmbeddingService::denseEmbedQuery(
std::string openai_key = "";
std::string jinaai_key = "";
std::string voyageai_key = "";
std::string mixedbreadai_key = "";
// Inject 3rd party service key based on their model name.
if (server::CommonUtil::StartsWith(model_name, "openai/")) {
if (headers.find(OPENAI_KEY_HEADER) == headers.end()) {
Expand All @@ -152,9 +159,14 @@ Status EmbeddingService::denseEmbedQuery(
return Status(INVALID_PAYLOAD, "Missing VoyageAI API key.");
}
voyageai_key = headers[VOYAGEAI_KEY_HEADER];
} else if (server::CommonUtil::StartsWith(model_name, "mixedbreadai/")) {
if (headers.find(MIXEDBREADAI_KEY_HEADER) == headers.end()) {
return Status(INVALID_PAYLOAD, "Missing mixedbread ai API key.");
}
mixedbreadai_key = headers[MIXEDBREADAI_KEY_HEADER];
}

auto response = m_client->denseEmbedDocuments("/v1/embeddings", openai_key, jinaai_key, voyageai_key, requestBody);
auto response = m_client->denseEmbedDocuments("/v1/embeddings", openai_key, jinaai_key, voyageai_key, mixedbreadai_key, requestBody);
auto responseBody = response->readBodyToString();
// std::cout << "Embedding response: " << responseBody->c_str() << std::endl;
vectordb::Json json;
Expand Down
4 changes: 2 additions & 2 deletions engine/services/embedding_service.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class MyApiClient : public oatpp::web::client::ApiClient {
API_CLIENT_INIT(MyApiClient)

API_CALL("GET", "{path}", getEmbeddings, PATH(String, path))
API_CALL("POST", "{path}", denseEmbedDocuments, PATH(String, path), HEADER(String, openaiHeader, OPENAI_KEY_HEADER), HEADER(String, jinaaiHeader, JINAAI_KEY_HEADER), HEADER(String, voyageaiHeader, VOYAGEAI_KEY_HEADER), BODY_DTO(Object<EmbeddingRequestBody>, body))
API_CALL("POST", "{path}", denseEmbedDocuments, PATH(String, path), HEADER(String, openaiHeader, OPENAI_KEY_HEADER), HEADER(String, jinaaiHeader, JINAAI_KEY_HEADER), HEADER(String, voyageaiHeader, VOYAGEAI_KEY_HEADER), HEADER(String, mixedbreadaiHeader, MIXEDBREADAI_KEY_HEADER), BODY_DTO(Object<EmbeddingRequestBody>, body))


#include OATPP_CODEGEN_END(ApiClient) //<- End codegen
Expand Down Expand Up @@ -85,4 +85,4 @@ class EmbeddingService {
};

} // namespace engine
} // namespace vectordb
} // namespace vectordb
1 change: 1 addition & 0 deletions engine/utils/constants.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ namespace vectordb {
constexpr const char* OPENAI_KEY_HEADER = "X-OpenAI-API-Key";
constexpr const char* JINAAI_KEY_HEADER = "X-JinaAI-API-Key";
constexpr const char* VOYAGEAI_KEY_HEADER = "X-VoyageAI-API-Key";
constexpr const char* MIXEDBREADAI_KEY_HEADER = "X-MixedbreadAI-API-Key";
} // namespace vectordb

0 comments on commit 9e2dbc5

Please sign in to comment.