Skip to content

Commit

Permalink
Update openai dependency to >= 1.0.0 (#1709)
Browse files Browse the repository at this point in the history
  • Loading branch information
jasper-xian authored Nov 10, 2023
1 parent 12cbb11 commit 086e16b
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
5 changes: 3 additions & 2 deletions pyserini/encode/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

openai.api_key = os.getenv("OPENAI_API_KEY")
openai.organization = os.getenv("OPENAI_ORG_KEY")
client = openai.OpenAI()
OPENAI_API_RETRY_DELAY = 5

def retry_with_delay(func, delay: int = OPENAI_API_RETRY_DELAY, max_retries: int = 10, errors: tuple = (openai.error.RateLimitError)):
def retry_with_delay(func, delay: int = OPENAI_API_RETRY_DELAY, max_retries: int = 10, errors: tuple = (openai.RateLimitError)):
def wrapper(*args, **kwargs):
num_retries = 0
while True:
Expand Down Expand Up @@ -49,7 +50,7 @@ def __init__(self, model_name: str = 'text-embedding-ada-002', tokenizer_name: s

@retry_with_delay
def get_embedding(self, text: str):
return np.array(openai.Embedding.create(input=text, model=self.model)['data'][0]['embedding'])
return np.array(client.embeddings.create(input=text, model=self.model)['data'][0]['embedding'])

def encode(self, text: str, max_length: int = 512, **kwargs):
inputs = self.tokenizer.encode(text=text)[:max_length]
Expand Down
3 changes: 2 additions & 1 deletion pyserini/search/faiss/_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ def __init__(self, encoder_dir: str = None, encoded_query_dir: str = None,
if encoder_dir:
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.organization = os.getenv("OPENAI_ORG_KEY")
self.client = openai.OpenAI()
self.model = encoder_dir
self.tokenizer = tiktoken.get_encoding(tokenizer_name)
self.max_length = max_length
Expand All @@ -330,7 +331,7 @@ def __init__(self, encoder_dir: str = None, encoded_query_dir: str = None,

@retry_with_delay
def get_embedding(self, text: str):
return np.array(openai.Embedding.create(input=text, model=self.model)['data'][0]['embedding'])
return np.array(self.client.embeddings.create(input=text, model=self.model)['data'][0]['embedding'])

def encode(self, query: str, **kwargs):
if self.has_model:
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ onnxruntime>=1.8.1
lightgbm>=3.3.2
spacy>=3.2.1
pyyaml
openai>=0.28.0
openai>=1.0.0
tiktoken>=0.4.0

0 comments on commit 086e16b

Please sign in to comment.