-
Notifications
You must be signed in to change notification settings - Fork 2.8k
/
hf_text_embedding.py
62 lines (50 loc) · 2.14 KB
/
hf_text_embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Copyright (c) Microsoft. All rights reserved.
import logging
from typing import Any
import sentence_transformers
import torch
from numpy import array, ndarray
from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import EmbeddingGeneratorBase
from semantic_kernel.exceptions import ServiceResponseException
from semantic_kernel.utils.experimental_decorator import experimental_class
# Module-level logger named after this module; generate_embeddings logs through it.
logger: logging.Logger = logging.getLogger(__name__)
@experimental_class
class HuggingFaceTextEmbedding(EmbeddingGeneratorBase):
    """Embedding generator backed by a sentence-transformers model."""

    # Resolved device string passed to SentenceTransformer: "cpu" or "cuda:<index>".
    device: str
    # The loaded sentence_transformers.SentenceTransformer instance.
    generator: Any

    def __init__(
        self,
        ai_model_id: str,
        device: int | None = -1,
        service_id: str | None = None,
    ) -> None:
        """
        Initializes a new instance of the HuggingFaceTextEmbedding class.

        Arguments:
            ai_model_id {str} -- Hugging Face model card string, see
                https://huggingface.co/sentence-transformers
            device {Optional[int]} -- Device to run the model on, -1 for CPU, 0+ for GPU.
                None is treated the same as -1 (CPU). A GPU index is only honoured
                when torch reports CUDA as available; otherwise the model runs on CPU.
            service_id {Optional[str]} -- Optional service identifier, passed through
                to the base class.

        Note that this model will be downloaded from the Hugging Face model hub.
        """
        # Guard against None before comparing: `None >= 0` raises TypeError even
        # though the signature explicitly allows None.
        use_cuda = device is not None and device >= 0 and torch.cuda.is_available()
        resolved_device = f"cuda:{device}" if use_cuda else "cpu"
        super().__init__(
            ai_model_id=ai_model_id,
            service_id=service_id,
            device=resolved_device,
            generator=sentence_transformers.SentenceTransformer(
                model_name_or_path=ai_model_id, device=resolved_device
            ),
        )

    async def generate_embeddings(self, texts: list[str], **kwargs: Any) -> ndarray:
        """
        Generates embeddings for a list of texts.

        Arguments:
            texts {List[str]} -- Texts to generate embeddings for.
            **kwargs -- Extra keyword arguments forwarded unchanged to the
                generator's encode() call.

        Returns:
            ndarray -- Embeddings for the texts.

        Raises:
            ServiceResponseException -- If anything in the encoding step fails;
                the original exception is chained as the cause.
        """
        try:
            # Lazy %-style args: the message is only formatted if INFO is enabled.
            logger.info("Generating embeddings for %d texts", len(texts))
            # NOTE(review): encode() is a plain synchronous call (not awaited), so it
            # blocks the event loop while the model runs.
            embeddings = self.generator.encode(texts, **kwargs)
            return array(embeddings)
        except Exception as e:
            # Normalize every failure to the connector's service exception type.
            raise ServiceResponseException("Hugging Face embeddings failed", e) from e