
Add retries with exponential backoff
Addresses run-llama#210.
kahkeng committed Jan 12, 2023
1 parent c64c906 commit 3742610
Showing 3 changed files with 80 additions and 4 deletions.
8 changes: 6 additions & 2 deletions gpt_index/langchain_helpers/chain_wrapper.py
@@ -3,13 +3,14 @@
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple
 
+import openai
 from langchain import Cohere, LLMChain, OpenAI
 from langchain.llms import AI21
 from langchain.llms.base import BaseLLM
 
 from gpt_index.constants import MAX_CHUNK_SIZE, NUM_OUTPUTS
 from gpt_index.prompts.base import Prompt
-from gpt_index.utils import globals_helper
+from gpt_index.utils import globals_helper, retry_on_exceptions_with_backoff
 
 
 @dataclass
@@ -79,7 +80,10 @@ def _predict(self, prompt: Prompt, **prompt_args: Any) -> str:
         # Note: we don't pass formatted_prompt to llm_chain.predict because
         # langchain does the same formatting under the hood
         full_prompt_args = prompt.get_full_format_args(prompt_args)
-        llm_prediction = llm_chain.predict(**full_prompt_args)
+        llm_prediction = retry_on_exceptions_with_backoff(
+            lambda: llm_chain.predict(**full_prompt_args),
+            [openai.error.RateLimitError],
+        )
         return llm_prediction
 
     def predict(self, prompt: Prompt, **prompt_args: Any) -> Tuple[str, str]:
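With this change, `_predict` retries the LangChain call whenever OpenAI signals a rate limit instead of failing immediately. A minimal usage sketch of the new helper in isolation (hypothetical caller, not part of this commit; assumes the pre-1.0 `openai` package, where `openai.error.RateLimitError` is defined):

```python
import openai

from gpt_index.utils import retry_on_exceptions_with_backoff


def call_model() -> str:
    # Hypothetical completion call that may raise RateLimitError under load.
    response = openai.Completion.create(
        model="text-davinci-003", prompt="hello", max_tokens=16
    )
    return response["choices"][0]["text"]


# Retry only on rate limits; any other exception propagates immediately.
text = retry_on_exceptions_with_backoff(
    lambda: call_model(),
    [openai.error.RateLimitError],
    max_tries=5,  # hypothetical override; the helper defaults to 10
)
```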
27 changes: 26 additions & 1 deletion gpt_index/utils.py
@@ -2,9 +2,11 @@
 
 import random
 import sys
+import time
+import traceback
 import uuid
 from contextlib import contextmanager
-from typing import Any, Callable, Generator, List, Optional, Set, cast
+from typing import Any, Callable, Generator, List, Optional, Set, Type, cast
 
 import nltk
 
@@ -99,3 +101,26 @@ def temp_set_attrs(obj: Any, **kwargs: Any) -> Generator:
     finally:
         for k, v in prev_values.items():
             setattr(obj, k, v)
+
+
+def retry_on_exceptions_with_backoff(
+    lambda_fn: Callable,
+    exception_classes: List[Type[Exception]],
+    max_tries: int = 10,
+    min_backoff_secs: float = 0.5,
+    max_backoff_secs: float = 60.0,
+) -> Any:
+    """Execute lambda function with retries and exponential backoff."""
+    exception_class_tuples = tuple(exception_classes)
+    backoff_secs = min_backoff_secs
+    tries = 0
+    while True:
+        try:
+            return lambda_fn()
+        except exception_class_tuples:
+            traceback.print_exc()
+            tries += 1
+            if tries >= max_tries:
+                raise
+            time.sleep(backoff_secs)
+            backoff_secs = min(backoff_secs * 2, max_backoff_secs)
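The defaults above give a sleep schedule that doubles from `min_backoff_secs` (0.5s) up to a cap of `max_backoff_secs` (60s); the final failed try re-raises without sleeping, so at most `max_tries - 1` sleeps occur. A small sketch of that arithmetic (illustration only, not part of the commit):

```python
def backoff_schedule(
    max_tries: int = 10,
    min_backoff_secs: float = 0.5,
    max_backoff_secs: float = 60.0,
) -> list:
    """Sleep durations retry_on_exceptions_with_backoff would use."""
    sleeps = []
    secs = min_backoff_secs
    for _ in range(max_tries - 1):  # no sleep after the final failed try
        sleeps.append(secs)
        secs = min(secs * 2, max_backoff_secs)
    return sleeps


print(backoff_schedule())
# [0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 60.0, 60.0]
```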
49 changes: 48 additions & 1 deletion tests/test_utils.py
@@ -1,6 +1,10 @@
 """Test utils."""
 
-from gpt_index.utils import globals_helper
+from typing import Optional, Type
+
+import pytest
+
+from gpt_index.utils import globals_helper, retry_on_exceptions_with_backoff
 
 
 def test_tokenizer() -> None:
@@ -12,3 +16,46 @@ def test_tokenizer() -> None:
     text = "hello world foo bar"
     tokenizer = globals_helper.tokenizer
     assert len(tokenizer(text)) == 4
+
+
+call_count = 0
+
+
+def fn_with_exception(exception_cls: Optional[Type[Exception]]) -> bool:
+    """Return True unless an exception class is specified, in which case raise it."""
+    global call_count
+    call_count += 1
+    if exception_cls:
+        raise exception_cls
+    return True
+
+
+def test_retry_on_exceptions_with_backoff() -> None:
+    """Make sure retry function has accurate number of attempts."""
+    global call_count
+    assert fn_with_exception(None)
+
+    call_count = 0
+    with pytest.raises(ValueError):
+        fn_with_exception(ValueError)
+    assert call_count == 1
+
+    call_count = 0
+    with pytest.raises(ValueError):
+        retry_on_exceptions_with_backoff(
+            lambda: fn_with_exception(ValueError),
+            [ValueError],
+            max_tries=3,
+            min_backoff_secs=0.0,
+        )
+    assert call_count == 3
+
+    # a different exception will not be retried
+    call_count = 0
+    with pytest.raises(TypeError):
+        retry_on_exceptions_with_backoff(
+            lambda: fn_with_exception(TypeError),
+            [ValueError],
+            max_tries=3,
+        )
+    assert call_count == 1
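The test passes `min_backoff_secs=0.0` so the retry loop does not actually sleep between attempts. A sketch of running just this test programmatically (hypothetical invocation; requires pytest installed in a dev checkout):

```python
# Equivalent to `pytest tests/test_utils.py -k retry_on_exceptions` from the shell.
import pytest

pytest.main(["tests/test_utils.py", "-k", "retry_on_exceptions"])
```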
