Skip to content

Commit

Permalink
Add EMBEDDING_MODEL_FORMAT in API config (#152)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanaysoni committed Jun 16, 2020
1 parent 42f5667 commit af5fc79
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
1 change: 1 addition & 0 deletions haystack/api/config.py
Expand Up @@ -39,6 +39,7 @@
if EXCLUDE_META_DATA_FIELDS:
EXCLUDE_META_DATA_FIELDS = ast.literal_eval(EXCLUDE_META_DATA_FIELDS)
EMBEDDING_MODEL_PATH = os.getenv("EMBEDDING_MODEL_PATH", None)
EMBEDDING_MODEL_FORMAT = os.getenv("EMBEDDING_MODEL_FORMAT", "farm")

# Monitoring
APM_SERVER = os.getenv("APM_SERVER", None)
Expand Down
10 changes: 8 additions & 2 deletions haystack/api/controller/search.py
Expand Up @@ -11,7 +11,8 @@
from haystack.api.config import DB_HOST, DB_PORT, DB_USER, DB_PW, DB_INDEX, ES_CONN_SCHEME, TEXT_FIELD_NAME, SEARCH_FIELD_NAME, \
EMBEDDING_DIM, EMBEDDING_FIELD_NAME, EXCLUDE_META_DATA_FIELDS, EMBEDDING_MODEL_PATH, USE_GPU, READER_MODEL_PATH, \
BATCHSIZE, CONTEXT_WINDOW_SIZE, TOP_K_PER_CANDIDATE, NO_ANS_BOOST, MAX_PROCESSES, MAX_SEQ_LEN, DOC_STRIDE, \
DEFAULT_TOP_K_READER, DEFAULT_TOP_K_RETRIEVER, CONCURRENT_REQUEST_PER_WORKER, FAQ_QUESTION_FIELD_NAME
DEFAULT_TOP_K_READER, DEFAULT_TOP_K_RETRIEVER, CONCURRENT_REQUEST_PER_WORKER, FAQ_QUESTION_FIELD_NAME, \
EMBEDDING_MODEL_FORMAT
from haystack.api.controller.utils import RequestLimiter
from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.reader.farm import FARMReader
Expand Down Expand Up @@ -41,7 +42,12 @@


if EMBEDDING_MODEL_PATH:
retriever = EmbeddingRetriever(document_store=document_store, embedding_model=EMBEDDING_MODEL_PATH, gpu=USE_GPU) # type: BaseRetriever
retriever = EmbeddingRetriever(
document_store=document_store,
embedding_model=EMBEDDING_MODEL_PATH,
model_format=EMBEDDING_MODEL_FORMAT, # type: ignore
gpu=USE_GPU
) # type: BaseRetriever
else:
retriever = ElasticsearchRetriever(document_store=document_store)

Expand Down
7 changes: 4 additions & 3 deletions haystack/retriever/elasticsearch.py
Expand Up @@ -2,6 +2,7 @@
from typing import List, Union

from farm.infer import Inferencer
from typing_extensions import Literal

from haystack.database.base import Document
from haystack.database.elasticsearch import ElasticsearchDocumentStore
Expand Down Expand Up @@ -107,8 +108,8 @@ def __init__(
document_store: ElasticsearchDocumentStore,
embedding_model: str,
gpu: bool = True,
model_format: str = "farm",
pooling_strategy: str = "reduce_mean",
model_format: Literal["farm", "transformers", "sentence_transformers"] = "farm",
pooling_strategy: Literal["cls_token", "reduce_mean", "reduce_max", "per_token", "s3e"] = "reduce_mean",
emb_extraction_layer: int = -1,
):
"""
Expand Down Expand Up @@ -162,7 +163,7 @@ def create_embedding(self, texts: Union[List[str], str]) -> List[List[float]]:
texts = [texts] # type: ignore
assert type(texts) == list, "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"

if self.model_format == "farm":
if self.model_format == "farm" or self.model_format == "transformers":
res = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts]) # type: ignore
emb = [list(r["vec"]) for r in res] #cast from numpy
elif self.model_format == "sentence_transformers":
Expand Down

0 comments on commit af5fc79

Please sign in to comment.