Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ZanSara committed May 17, 2023
1 parent 9a58a2f commit c69d138
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 110 deletions.
39 changes: 39 additions & 0 deletions e2e/preview/components/test_transcriber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os
from pathlib import Path

import pytest

from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber


SAMPLES_PATH = Path(__file__).parent.parent / "test_files"


@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_whisperremotetranscriber():
comp = RemoteWhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY"))

output = comp.run(
audio_files=[
SAMPLES_PATH / "audio" / "this is the content of the document.wav",
str((SAMPLES_PATH / "audio" / "the context for this answer is here.wav").absolute()),
open(SAMPLES_PATH / "audio" / "answer.wav", "rb"),
]
)
docs = output.documents
assert len(docs) == 3

assert "this is the content of the document." == docs[0].content.strip().lower()
assert SAMPLES_PATH / "audio" / "this is the content of the document.wav" == docs[0].metadata["audio_file"]

assert "the context for this answer is here." == docs[1].content.strip().lower()
assert (
str((SAMPLES_PATH / "audio" / "the context for this answer is here.wav").absolute())
== docs[1].metadata["audio_file"]
)

assert "answer." == docs[2].content.strip().lower()
assert "<<binary stream>>" == docs[2].metadata["audio_file"]
Binary file added e2e/preview/test_files/audio/answer.wav
Binary file not shown.
Binary file not shown.
13 changes: 0 additions & 13 deletions e2e/preview/test_whisper_remote.py

This file was deleted.

10 changes: 7 additions & 3 deletions haystack/preview/components/audio/whisper_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, api_key: str, model_name_or_path: WhisperRemoteModel = "whisp
self.api_key = api_key
self.model_name = model_name_or_path

def run(self, audio_files: List[Path], whisper_params: Dict[str, Any]) -> Output:
def run(self, audio_files: List[Path], whisper_params: Optional[Dict[str, Any]] = None) -> Output:
"""
Transcribe the audio files into a list of Documents, one for each input file.
Expand All @@ -66,6 +66,8 @@ def run(self, audio_files: List[Path], whisper_params: Dict[str, Any]) -> Output
alignment data. Another key called `audio_file` contains the path to the audio file used for the
transcription.
"""
if not whisper_params:
whisper_params = {}
documents = self.transcribe(audio_files, **whisper_params)
return RemoteWhisperTranscriber.Output(documents)

Expand All @@ -84,6 +86,8 @@ def transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs
documents = []
for audio, transcript in zip(audio_files, transcriptions):
content = transcript.pop("text")
if not isinstance(audio, (str, Path)):
audio = "<<binary stream>>"
doc = Document(content=content, metadata={"audio_file": audio, **transcript})
documents.append(doc)
return documents
Expand All @@ -108,8 +112,8 @@ def _raw_transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **k
transcriptions = []
for audio_file in audio_files:
if isinstance(audio_file, (str, Path)):
with open(audio_file, "rb") as audio_file:
transcription = self._invoke_api(audio_file, url, data, headers)
audio_file = open(audio_file, "rb")
transcription = self._invoke_api(audio_file, url, data, headers)
transcriptions.append(transcription)
return transcriptions

Expand Down
217 changes: 123 additions & 94 deletions test/preview/components/audio/test_whisper_remote.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import os
import sys
from pathlib import Path
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest
import torch
import whisper
from generalimport import FakeModule
import requests

from haystack.preview.dataclasses import Document
from haystack.preview.components import RemoteWhisperTranscriber
from haystack.preview.components.audio.whisper_remote import (
RemoteWhisperTranscriber,
OPENAI_TIMEOUT,
OpenAIError,
OpenAIRateLimitError,
)

from test.preview.components.test_component_base import BaseTestComponent
from test.preview.components.base import BaseTestComponent


SAMPLES_PATH = Path(__file__).parent.parent / "test_files"
SAMPLES_PATH = Path(__file__).parent.parent.parent / "test_files"


class TestRemoteWhisperTranscriber(BaseTestComponent):
Expand All @@ -26,109 +27,137 @@ class TestRemoteWhisperTranscriber(BaseTestComponent):
def components(self):
return [RemoteWhisperTranscriber(api_key="just a test")]

@pytest.fixture
def mock_models(self, monkeypatch):
def mock_transcribe(_, audio_file, **kwargs):
return {
"text": "test transcription",
"other_metadata": ["other", "meta", "data"],
"kwargs received": kwargs,
}

monkeypatch.setattr(RemoteWhisperTranscriber, "_transcribe_with_api", mock_transcribe)

@pytest.mark.unit
def test_init_remote_unknown_model(self):
def test_init_unknown_model(self):
with pytest.raises(ValueError, match="not recognized"):
RemoteWhisperTranscriber(model_name_or_path="anything")
RemoteWhisperTranscriber(model_name_or_path="anything", api_key="something")

@pytest.mark.unit
def test_init_default_remote_missing_key(self):
with pytest.raises(ValueError, match="API key"):
RemoteWhisperTranscriber()

@pytest.mark.unit
def test_init_explicit_remote_missing_key(self):
with pytest.raises(ValueError, match="API key"):
RemoteWhisperTranscriber(model_name_or_path="whisper-1")

@pytest.mark.unit
def test_init_remote(self):
def test_init_default(self):
transcriber = RemoteWhisperTranscriber(api_key="just a test")
assert transcriber.model_name == "whisper-1"
assert not transcriber.use_local_whisper
assert not hasattr(transcriber, "device")
assert hasattr(transcriber, "_model") and transcriber._model is None
assert transcriber.api_key == "just a test"

@pytest.mark.unit
def test_init_local(self):
transcriber = RemoteWhisperTranscriber(model_name_or_path="large-v2")
assert transcriber.model_name == "large-v2" # Doesn't matter if it's huge, the model is not loaded in init.
assert transcriber.use_local_whisper
assert hasattr(transcriber, "device") and transcriber.device == torch.device("cpu")
assert hasattr(transcriber, "_model") and transcriber._model is None

@pytest.mark.unit
def test_init_local_with_api_key(self):
transcriber = RemoteWhisperTranscriber(model_name_or_path="large-v2")
assert transcriber.model_name == "large-v2" # Doesn't matter if it's huge, the model is not loaded in init.
assert transcriber.use_local_whisper
assert hasattr(transcriber, "device") and transcriber.device == torch.device("cpu")
assert hasattr(transcriber, "_model") and transcriber._model is None
def test_run_with_path(self):
with patch("haystack.preview.components.audio.whisper_remote.requests") as mocked_requests:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
mocked_requests.post.return_value = mock_response
comp = RemoteWhisperTranscriber(api_key="whatever")

result = comp.run(audio_files=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"])
expected = Document(
content="test transcription",
metadata={
"audio_file": SAMPLES_PATH / "audio" / "this is the content of the document.wav",
"other_metadata": ["other", "meta", "data"],
},
)
assert result.documents == [expected]

@pytest.mark.unit
def test_init_missing_whisper_lib_local_model(self, monkeypatch):
monkeypatch.setitem(sys.modules, "whisper", FakeModule(spec=MagicMock(), message="test"))
with pytest.raises(ValueError, match="audio extra"):
RemoteWhisperTranscriber(model_name_or_path="large-v2")
def test_run_with_str(self):
with patch("haystack.preview.components.audio.whisper_remote.requests") as mocked_requests:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
mocked_requests.post.return_value = mock_response
comp = RemoteWhisperTranscriber(api_key="whatever")

result = comp.run(
audio_files=[str((SAMPLES_PATH / "audio" / "this is the content of the document.wav").absolute())]
)
expected = Document(
content="test transcription",
metadata={
"audio_file": str((SAMPLES_PATH / "audio" / "this is the content of the document.wav").absolute()),
"other_metadata": ["other", "meta", "data"],
},
)
assert result.documents == [expected]

@pytest.mark.unit
def test_init_missing_whisper_lib_remote_model(self, monkeypatch):
monkeypatch.setitem(sys.modules, "whisper", FakeModule(spec=MagicMock(), message="test"))
# Should not fail if the lib is missing and we're using API
RemoteWhisperTranscriber(model_name_or_path="whisper-1", api_key="doesn't matter")
def test_transcribe_with_stream(self):
with patch("haystack.preview.components.audio.whisper_remote.requests") as mocked_requests:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
mocked_requests.post.return_value = mock_response
comp = RemoteWhisperTranscriber(api_key="whatever")

with open(SAMPLES_PATH / "audio" / "this is the content of the document.wav", "rb") as audio_stream:
result = comp.transcribe(audio_files=[audio_stream])
expected = Document(
content="test transcription",
metadata={"audio_file": "<<binary stream>>", "other_metadata": ["other", "meta", "data"]},
)
assert result == [expected]

@pytest.mark.unit
def test_warmup_remote_model(self, monkeypatch):
load_model = MagicMock()
monkeypatch.setattr(whisper, "load_model", load_model)
component = RemoteWhisperTranscriber(model_name_or_path="whisper-1", api_key="doesn't matter")
component.warm_up()
assert not load_model.called
def test_api_transcription(self):
with patch("haystack.preview.components.audio.whisper_remote.requests") as mocked_requests:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
mocked_requests.post.return_value = mock_response
comp = RemoteWhisperTranscriber(api_key="whatever")

comp.run(audio_files=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"])

requests_params = mocked_requests.post.call_args.kwargs
requests_params.pop("files")
assert requests_params == {
"url": "https://api.openai.com/v1/audio/transcriptions",
"data": {"model": "whisper-1"},
"headers": {"Authorization": f"Bearer whatever"},
"timeout": OPENAI_TIMEOUT,
}

@pytest.mark.unit
def test_warmup_local_model(self, monkeypatch):
load_model = MagicMock()
load_model.side_effect = ["FAKE MODEL"]
monkeypatch.setattr(whisper, "load_model", load_model)

component = RemoteWhisperTranscriber(model_name_or_path="large-v2")
component.warm_up()
def test_api_translation(self):
with patch("haystack.preview.components.audio.whisper_remote.requests") as mocked_requests:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = '{"text": "test transcription", "other_metadata": ["other", "meta", "data"]}'
mocked_requests.post.return_value = mock_response
comp = RemoteWhisperTranscriber(api_key="whatever")

comp.run(
audio_files=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"],
whisper_params={"translate": True},
)

assert hasattr(component, "_model")
assert component._model == "FAKE MODEL"
load_model.assert_called_with("large-v2", device=torch.device(type="cpu"))
requests_params = mocked_requests.post.call_args.kwargs
requests_params.pop("files")
assert requests_params == {
"url": "https://api.openai.com/v1/audio/translations",
"data": {"model": "whisper-1"},
"headers": {"Authorization": f"Bearer whatever"},
"timeout": OPENAI_TIMEOUT,
}

@pytest.mark.unit
def test_warmup_local_model_doesnt_reload(self, monkeypatch):
load_model = MagicMock()
monkeypatch.setattr(whisper, "load_model", load_model)
component = RemoteWhisperTranscriber(model_name_or_path="large-v2")
component.warm_up()
component.warm_up()
load_model.assert_called_once()
def test_api_fails(self):
with patch("haystack.preview.components.audio.whisper_remote.requests") as mocked_requests:
mock_response = MagicMock()
mock_response.status_code = 500
mock_response.content = '{"error": "something went wrong on our end!"}'
mocked_requests.post.return_value = mock_response
comp = RemoteWhisperTranscriber(api_key="whatever")

with pytest.raises(OpenAIError):
comp.run(audio_files=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"])

@pytest.mark.unit
def test_transcribe_to_documents(self, mock_models):
comp = RemoteWhisperTranscriber(model_name_or_path="large-v2")
output = comp.transcribe(audio_files=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"])
assert output == [
Document(
content="test transcription",
metadata={
"audio_file": SAMPLES_PATH / "audio" / "this is the content of the document.wav",
"other_metadata": ["other", "meta", "data"],
"kwargs received": {},
},
)
]
def test_api_rate_limiting(self):
with patch("haystack.preview.components.audio.whisper_remote.requests") as mocked_requests:
mock_response = MagicMock()
mock_response.status_code = 429
mock_response.content = '{"error": "do not sent this many requests please :)"}'
mocked_requests.post.return_value = mock_response
comp = RemoteWhisperTranscriber(api_key="whatever")

with pytest.raises(OpenAIRateLimitError):
comp.run(audio_files=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"])
Binary file not shown.
Binary file not shown.

0 comments on commit c69d138

Please sign in to comment.