1111
1212import openai
1313
14- from typing import Dict , Union , Optional
14+ from typing import Dict , List , Union , Optional
1515from collections import OrderedDict
1616from flask import Flask , request , jsonify , abort
1717from sentence_transformers import SentenceTransformer
@@ -85,16 +85,26 @@ def __init__(self, sbert_model: Optional[str] = None, openai_key: Optional[str]
8585
8686 if openai_key is not None :
8787 openai .api_key = self .openai_key
88- logger .info ('enabled model: text-embedding-ada-002' )
88+ try :
89+ openai .Model .list ()
90+ logger .info ('enabled model: text-embedding-ada-002' )
91+ except Exception as err :
92+ logger .error (f'Failed to connect to OpenAI API; disabling OpenAI model: { err } ' )
8993
90- def generate (self , text : str , model_type : str ) -> Dict [str , Union [str , float , list ]]:
94+ def generate (self , text_batch : List [ str ] , model_type : str ) -> Dict [str , Union [str , float , list ]]:
9195 start_time = time .time ()
92- result = {'status' : 'success' }
96+ result = {
97+ 'status' : 'success' ,
98+ 'message' : '' ,
99+ 'model' : '' ,
100+ 'elapsed' : 0 ,
101+ 'embeddings' : []
102+ }
93103
94104 if model_type == 'openai' :
95105 try :
96- response = openai .Embedding .create (input = text , model = 'text-embedding-ada-002' )
97- result ['embedding ' ] = response [ ' data' ][ 0 ][ 'embedding' ]
106+ response = openai .Embedding .create (input = text_batch , model = 'text-embedding-ada-002' )
107+ result ['embeddings ' ] = [ data [ 'embedding' ] for data in response [ 'data' ] ]
98108 result ['model' ] = 'text-embedding-ada-002'
99109 except Exception as err :
100110 logger .error (f'Failed to get OpenAI embeddings: { err } ' )
@@ -103,8 +113,8 @@ def generate(self, text: str, model_type: str) -> Dict[str, Union[str, float, li
103113
104114 else :
105115 try :
106- embedding = self .model .encode (text ).tolist ()
107- result ['embedding ' ] = embedding
116+ embedding = self .model .encode (text_batch , batch_size = len ( text_batch ), device = 'cuda' ).tolist ()
117+ result ['embeddings ' ] = embedding
108118 result ['model' ] = self .sbert_model
109119 except Exception as err :
110120 logger .error (f'Failed to get sentence-transformers embeddings: { err } ' )
@@ -145,33 +155,18 @@ def submit_text():
145155 if text_data is None :
146156 abort (400 , 'Missing text data to embed' )
147157
148- if model_type not in ['local' , 'openai' ]:
149- abort (400 , 'model field must be one of: local, openai' )
150-
151- if isinstance (text_data , str ):
152- text_data = [text_data ]
153-
154158 if not all (isinstance (text , str ) for text in text_data ):
155159 abort (400 , 'all data must be text strings' )
156160
157161 results = []
158- for text in text_data :
159- result = None
160-
161- if embedding_cache :
162- result = embedding_cache .get (text , model_type )
163- if result :
164- logger .info ('found embedding in cache!' )
165- result = {'embedding' : result , 'cache' : True , "status" : 'success' }
166-
167- if result is None :
168- result = embedding_generator .generate (text , model_type )
162+ result = embedding_generator .generate (text_data , model_type )
169163
170- if embedding_cache and result ['status' ] == 'success' :
171- embedding_cache .set (text , model_type , result ['embedding' ])
172- logger .info ('added to cache' )
164+ if embedding_cache and result ['status' ] == 'success' :
165+ for text , embedding in zip (text_data , result ['embeddings' ]):
166+ embedding_cache .set (text , model_type , embedding )
167+ logger .info ('added to cache' )
173168
174- results .append (result )
169+ results .append (result )
175170
176171 return jsonify (results )
177172
0 commit comments