/
embedding.py
73 lines (56 loc) · 2.4 KB
/
embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import numpy as np
from openai import OpenAI
from loguru import logger
from typing import List, Dict
from sentence_transformers import SentenceTransformer
def cosine_similarity(embedding1: List, embedding2: List) -> float:
    """Return the cosine similarity between two embeddings.

    Args:
        embedding1: First embedding, a sequence of numbers.
        embedding2: Second embedding; must be compatible with
            ``embedding1`` for ``np.dot`` (same length for 1-D vectors).

    Returns:
        The dot product of the two embeddings divided by the product of
        their Euclidean norms. ``0.0`` is returned when either embedding has
        zero norm (all zeros or empty), because the similarity is undefined
        in that case.

    Raises:
        ValueError: Propagated from ``np.dot`` when the shapes do not align.
    """
    product = np.dot(embedding1, embedding2)
    norm = np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
    # A zero norm would make the division 0/0, which yields nan plus a
    # RuntimeWarning; report "no similarity" instead.
    if norm == 0:
        return 0.0
    return product / norm
class Embedder:
    """Generate text embeddings through OpenAI or a SentenceTransformer model.

    The backend is chosen once, at construction time: the literal model name
    ``'openai'`` selects the OpenAI embeddings API, any other value is passed
    to ``SentenceTransformer`` as the model to load.
    """

    def __init__(self, model: str, openai_key: str = None,
                 openai_model: str = 'text-embedding-ada-002'):
        """Set up the selected embedding backend.

        Args:
            model: ``'openai'`` to use the OpenAI API; otherwise the name
                handed to ``SentenceTransformer``.
            openai_key: API key for OpenAI. Required when ``model`` is
                ``'openai'``; not used otherwise.
            openai_model: Name of the OpenAI embedding model to request.
                Only used when ``model`` is ``'openai'``.

        Raises:
            ValueError: If ``model`` is ``'openai'`` and no key was given, or
                if the SentenceTransformer model could not be loaded.
            Exception: If the initial request to the OpenAI API fails.
        """
        self.name = 'embedder'
        self.model_name = model
        self.openai_model = openai_model

        if model == 'openai':
            logger.info('Using OpenAI')
            if openai_key is None:
                logger.error('No OpenAI API key passed to embedder.')
                raise ValueError("No OpenAI API key provided.")

            self.client = OpenAI(api_key=openai_key)
            # Listing models is used as a connectivity check so that a bad
            # key or unreachable API fails here rather than on first use.
            try:
                self.client.models.list()
            except Exception as err:
                logger.error(f'Failed to connect to OpenAI API: {err}')
                raise Exception(
                    f"Connection to OpenAI API failed: {err}"
                ) from err
            self.embed_func = self._openai
        else:
            logger.info(f'Using SentenceTransformer: {model}')
            try:
                self.model = SentenceTransformer(model)
            except Exception as err:
                logger.error(f'Failed to load model: {model} error="{err}"')
                raise ValueError(
                    f"Failed to load SentenceTransformer model: {err}"
                ) from err
            logger.success(f'Loaded model: {model}')
            self.embed_func = self._transformer

        logger.success('Loaded embedder')

    def generate(self, input_data: str) -> List:
        """Return the embedding of ``input_data`` from the configured backend.

        Args:
            input_data: Text to embed.

        Returns:
            The embedding as a list, or an empty list if the backend failed
            (the failure is logged by the backend method).
        """
        logger.info(f'Generating with: {self.model_name}')
        return self.embed_func(input_data)

    def _openai(self, input_data: str) -> List:
        """Embed ``input_data`` with the OpenAI embeddings endpoint.

        Returns the embedding of the first item in the response, or ``[]``
        when the request fails; the error is logged, not raised.
        """
        try:
            response = self.client.embeddings.create(
                input=input_data, model=self.openai_model
            )
            return response.data[0].embedding
        except Exception as err:
            # Best-effort: callers receive an empty embedding on failure.
            logger.error(f'Failed to generate embedding: {err}')
            return []

    def _transformer(self, input_data: str) -> List:
        """Embed ``input_data`` with the loaded SentenceTransformer model.

        Returns the encoded result converted with ``tolist()``, or ``[]``
        when encoding fails; the error is logged, not raised.
        """
        try:
            return self.model.encode(input_data).tolist()
        except Exception as err:
            # Best-effort: callers receive an empty embedding on failure.
            logger.error(f'Failed to generate embedding: {err}')
            return []