Skip to content

Commit

Permalink
review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
ZanSara committed May 17, 2023
1 parent 9a3addb commit 0104afb
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 60 deletions.
27 changes: 12 additions & 15 deletions haystack/preview/utils/requests_with_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
logger = logging.getLogger(__file__)


def request_with_retry(attempts: int = 3, status_codes: Optional[List[int]] = None, **kwargs) -> requests.Response:
def request_with_retry(
attempts: int = 3, retry_on_status_codes: Optional[List[int]] = None, **kwargs
) -> requests.Response:
"""
request_with_retry is a simple wrapper function that executes an HTTP request
with a configurable exponential backoff retry on failures.
Expand All @@ -25,7 +27,7 @@ def request_with_retry(attempts: int = 3, status_codes: Optional[List[int]] = No
res = request_with_retry(method="GET", url="https://example.com", attempts=10)
# Sending an HTTP request with custom HTTP codes to retry
res = request_with_retry(method="GET", url="https://example.com", status_codes=[408, 503])
res = request_with_retry(method="GET", url="https://example.com", retry_on_status_codes=[408, 503])
# Sending an HTTP request with custom timeout in seconds
res = request_with_retry(method="GET", url="https://example.com", timeout=5)
Expand All @@ -44,25 +46,22 @@ def __call__(self, r):
url="https://example.com",
auth=CustomAuth(),
attempts=10,
status_codes[408, 503],
retry_on_status_codes=[408, 503],
timeout=5
)
# Sending a POST request
res = request_with_retry(method="POST", url="https://example.com", data={"key": "value"}, attempts=10)
# Retry all 5xx status codes
res = request_with_retry(method="GET", url="https://example.com", status_codes=list(range(500, 600)))
res = request_with_retry(method="GET", url="https://example.com", retry_on_status_codes=list(range(500, 600)))
:param attempts: Maximum number of attempts to retry the request, defaults to 3
:param status_codes: List of HTTP status codes that will trigger a retry, defaults to [408, 418, 429, 503]
:param **kwargs: Optional arguments that ``request`` takes.
:return: :class:`Response <Response>` object
:param retry_on_status_codes: List of HTTP status codes that will trigger a retry, defaults to >= 400
:param **kwargs: Optional arguments that ``request.request()`` takes.
:return: :class:`requests.Response` object
"""

if status_codes is None:
status_codes = [408, 418, 429, 503]

@retry(
reraise=True,
wait=wait_exponential(),
Expand All @@ -72,12 +71,10 @@ def __call__(self, r):
after=after_log(logger, logging.DEBUG),
)
def run():
# We ignore the missing-timeout Pylint rule as we set a default
kwargs.setdefault("timeout", 10)
res = requests.request(**kwargs) # pylint: disable=missing-timeout
timeout = kwargs.pop("timeout", 10)
res = requests.request(**kwargs, timeout=timeout)

if res.status_code in status_codes:
# We raise only for the status codes that must trigger a retry
if not retry_on_status_codes or res.status_code in retry_on_status_codes:
res.raise_for_status()

return res
Expand Down
84 changes: 39 additions & 45 deletions test/preview/components/audio/test_whisper_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ def test_init_no_key(self):

@pytest.mark.unit
def test_run_with_path(self):
with patch("haystack.preview.components.audio.whisper_remote.request_with_retry") as mocked_requests:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
mocked_requests.post.return_value = mock_response
comp = RemoteWhisperTranscriber(api_key="whatever")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
comp = RemoteWhisperTranscriber(api_key="whatever")

with patch("haystack.preview.utils.requests_with_retry.requests") as mocked_requests:
mocked_requests.request.return_value = mock_response

result = comp.run(audio_files=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"])
expected = Document(
Expand All @@ -59,12 +60,13 @@ def test_run_with_path(self):

@pytest.mark.unit
def test_run_with_str(self):
with patch("haystack.preview.components.audio.whisper_remote.request_with_retry") as mocked_requests:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
mocked_requests.post.return_value = mock_response
comp = RemoteWhisperTranscriber(api_key="whatever")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
comp = RemoteWhisperTranscriber(api_key="whatever")

with patch("haystack.preview.utils.requests_with_retry.requests") as mocked_requests:
mocked_requests.request.return_value = mock_response

result = comp.run(
audio_files=[str((SAMPLES_PATH / "audio" / "this is the content of the document.wav").absolute())]
Expand All @@ -80,12 +82,13 @@ def test_run_with_str(self):

@pytest.mark.unit
def test_transcribe_with_stream(self):
with patch("haystack.preview.components.audio.whisper_remote.request_with_retry") as mocked_requests:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
mocked_requests.post.return_value = mock_response
comp = RemoteWhisperTranscriber(api_key="whatever")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
comp = RemoteWhisperTranscriber(api_key="whatever")

with patch("haystack.preview.utils.requests_with_retry.requests") as mocked_requests:
mocked_requests.request.return_value = mock_response

with open(SAMPLES_PATH / "audio" / "this is the content of the document.wav", "rb") as audio_stream:
result = comp.transcribe(audio_files=[audio_stream])
Expand All @@ -97,18 +100,20 @@ def test_transcribe_with_stream(self):

@pytest.mark.unit
def test_api_transcription(self):
with patch("haystack.preview.components.audio.whisper_remote.request_with_retry") as mocked_requests:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
mocked_requests.post.return_value = mock_response
comp = RemoteWhisperTranscriber(api_key="whatever")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
comp = RemoteWhisperTranscriber(api_key="whatever")

with patch("haystack.preview.utils.requests_with_retry.requests") as mocked_requests:
mocked_requests.request.return_value = mock_response

comp.run(audio_files=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"])

requests_params = mocked_requests.post.call_args.kwargs
requests_params = mocked_requests.request.call_args.kwargs
requests_params.pop("files")
assert requests_params == {
"method": "post",
"url": "https://api.openai.com/v1/audio/transcriptions",
"data": {"model": "whisper-1"},
"headers": {"Authorization": f"Bearer whatever"},
Expand All @@ -117,35 +122,24 @@ def test_api_transcription(self):

@pytest.mark.unit
def test_api_translation(self):
with patch("haystack.preview.components.audio.whisper_remote.request_with_retry") as mocked_requests:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
mocked_requests.post.return_value = mock_response
comp = RemoteWhisperTranscriber(api_key="whatever")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
comp = RemoteWhisperTranscriber(api_key="whatever")

with patch("haystack.preview.utils.requests_with_retry.requests") as mocked_requests:
mocked_requests.request.return_value = mock_response

comp.run(
audio_files=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"],
whisper_params={"translate": True},
)

requests_params = mocked_requests.post.call_args.kwargs
requests_params = mocked_requests.request.call_args.kwargs
requests_params.pop("files")
assert requests_params == {
"method": "post",
"url": "https://api.openai.com/v1/audio/translations",
"data": {"model": "whisper-1"},
"headers": {"Authorization": f"Bearer whatever"},
"timeout": OPENAI_TIMEOUT,
}

@pytest.mark.unit
def test_api_fails(self):
with patch("haystack.preview.components.audio.whisper_remote.request_with_retry") as mocked_requests:
mock_response = MagicMock()
mock_response.status_code = 500
mock_response.content = '{"error": "something went wrong on our end!"}'
mocked_requests.post.return_value = mock_response
comp = RemoteWhisperTranscriber(api_key="whatever")

with pytest.raises(requests.HTTPError):
comp.run(audio_files=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"])

0 comments on commit 0104afb

Please sign in to comment.