Migrate the OpenAI embedding calls to the openai>=1.0 client API.

* Build the client with the API key and organization passed explicitly: an
  explicitly constructed client does not pick up the module-level
  openai.api_key / openai.organization assignments, and on its own it looks
  for the organization in OPENAI_ORG_ID rather than OPENAI_ORG_KEY.
* Read the embedding through attribute access; the 1.x response is a typed
  object, not a dict, so ['data'][0]['embedding'] no longer works.
* Make the default `errors` argument of retry_with_delay a real one-element
  tuple, matching its annotation.

Hunk headers account for the added comment lines. The post-image blob ids on
the `index` lines are carried over from the original patch.

diff --git a/pyserini/encode/_openai.py b/pyserini/encode/_openai.py
index ce0c41531..7b20a8420 100644
--- a/pyserini/encode/_openai.py
+++ b/pyserini/encode/_openai.py
@@ -8,9 +8,11 @@
 openai.api_key = os.getenv("OPENAI_API_KEY")
 openai.organization = os.getenv("OPENAI_ORG_KEY")
 
+# An explicit v1 client ignores the module-level settings above, so pass them in.
+client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"), organization=os.getenv("OPENAI_ORG_KEY"))
 OPENAI_API_RETRY_DELAY = 5
 
-def retry_with_delay(func, delay: int = OPENAI_API_RETRY_DELAY, max_retries: int = 10, errors: tuple = (openai.error.RateLimitError)):
+def retry_with_delay(func, delay: int = OPENAI_API_RETRY_DELAY, max_retries: int = 10, errors: tuple = (openai.RateLimitError,)):
     def wrapper(*args, **kwargs):
         num_retries = 0
         while True:
@@ -49,7 +51,8 @@ def __init__(self, model_name: str = 'text-embedding-ada-002', tokenizer_name: s
 
     @retry_with_delay
     def get_embedding(self, text: str):
-        return np.array(openai.Embedding.create(input=text, model=self.model)['data'][0]['embedding'])
+        # openai>=1.0 returns a typed response object, not a dict: use attribute access.
+        return np.array(client.embeddings.create(input=text, model=self.model).data[0].embedding)
 
     def encode(self, text: str, max_length: int = 512, **kwargs):
         inputs = self.tokenizer.encode(text=text)[:max_length]
diff --git a/pyserini/search/faiss/_searcher.py b/pyserini/search/faiss/_searcher.py
index 7c29b8f4f..7986f94da 100644
--- a/pyserini/search/faiss/_searcher.py
+++ b/pyserini/search/faiss/_searcher.py
@@ -321,6 +321,8 @@ def __init__(self, encoder_dir: str = None, encoded_query_dir: str = None,
         if encoder_dir:
             openai.api_key = os.getenv("OPENAI_API_KEY")
             openai.organization = os.getenv("OPENAI_ORG_KEY")
+            # An explicit v1 client ignores the module-level settings above, so pass them in.
+            self.client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"), organization=os.getenv("OPENAI_ORG_KEY"))
             self.model = encoder_dir
             self.tokenizer = tiktoken.get_encoding(tokenizer_name)
             self.max_length = max_length
@@ -330,7 +332,8 @@ def __init__(self, encoder_dir: str = None, encoded_query_dir: str = None,
 
     @retry_with_delay
     def get_embedding(self, text: str):
-        return np.array(openai.Embedding.create(input=text, model=self.model)['data'][0]['embedding'])
+        # openai>=1.0 returns a typed response object, not a dict: use attribute access.
+        return np.array(self.client.embeddings.create(input=text, model=self.model).data[0].embedding)
 
     def encode(self, query: str, **kwargs):
         if self.has_model:
diff --git a/requirements.txt b/requirements.txt
index 7af7a66b7..1da88dd34 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,5 +12,5 @@ onnxruntime>=1.8.1
 lightgbm>=3.3.2
 spacy>=3.2.1
 pyyaml
-openai>=0.28.0
+openai>=1.0.0
 tiktoken>=0.4.0
\ No newline at end of file