Skip to content

Commit

Permalink
Extends input types of RemoteWhisperTranscriber (#6218)
Browse files Browse the repository at this point in the history
* fix tests

* reno

* tests

* retain file name

* paths are strings for openai sdk

* streams->sources

* feedback

* always add name to file

* mypy

* test placeholder with extension

* fallback

* paths

* path test

* path must be a string

* fix test
  • Loading branch information
ZanSara committed Nov 22, 2023
1 parent e6c8374 commit b751978
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 25 deletions.
19 changes: 13 additions & 6 deletions haystack/preview/components/audio/whisper_remote.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import io
import logging
import os
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from pathlib import Path

import openai

Expand Down Expand Up @@ -111,7 +112,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "RemoteWhisperTranscriber":
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(self, sources: List[Union[str, Path, ByteStream]]):
    """
    Transcribe the audio files into a list of Documents, one for each input file.

    :param sources: a list of audio sources to transcribe: file paths (as `str` or
        `Path`) or `ByteStream` objects.
    :returns: a dictionary with a "documents" key holding one Document per source,
        with the transcription as content and the source's metadata as meta.
    """
    documents = []

    for source in sources:
        if not isinstance(source, ByteStream):
            # Plain paths are loaded into a ByteStream; remember the original
            # path value so it is propagated into the output Document's meta.
            path = source
            source = ByteStream.from_file_path(Path(source))
            source.metadata["file_path"] = path

        file = io.BytesIO(source.data)
        # The OpenAI SDK infers the audio format from the file name, so one must
        # always be set; fall back to a .wav placeholder when no path is known.
        file.name = str(source.metadata["file_path"]) if "file_path" in source.metadata else "__fallback__.wav"

        content = openai.Audio.transcribe(file=file, model=self.model_name, **self.whisper_params)
        doc = Document(content=content["text"], meta=source.metadata)
        documents.append(doc)

    return {"documents": documents}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
preview:
- Extends input types of RemoteWhisperTranscriber from List[ByteStream] to List[Union[str, Path, ByteStream]] to make it possible to connect it to FileTypeRouter.
53 changes: 34 additions & 19 deletions test/preview/components/audio/test_whisper_remote.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from unittest.mock import patch
from pathlib import Path

import openai
import pytest
Expand Down Expand Up @@ -182,7 +183,33 @@ def test_from_dict_with_defualt_parameters_no_env_var(self, monkeypatch):
RemoteWhisperTranscriber.from_dict(data)

@pytest.mark.unit
def test_run(self, preview_samples_path):
def test_run_str(self, preview_samples_path):
    """The transcriber accepts a plain string path and keeps it in the Document meta."""
    model = "whisper-1"
    audio_path = str(preview_samples_path / "audio" / "this is the content of the document.wav")

    with patch("haystack.preview.components.audio.whisper_remote.openai.Audio") as openai_audio_patch:
        openai_audio_patch.transcribe.side_effect = mock_openai_response

        transcriber = RemoteWhisperTranscriber(api_key="test_api_key", model_name=model, response_format="json")
        output = transcriber.run(sources=[audio_path])

    docs = output["documents"]
    assert docs[0].content == "test transcription"
    assert docs[0].meta["file_path"] == audio_path

@pytest.mark.unit
def test_run_path(self, preview_samples_path):
    """The transcriber accepts a pathlib.Path and keeps it in the Document meta."""
    model = "whisper-1"
    audio_path = preview_samples_path / "audio" / "this is the content of the document.wav"

    with patch("haystack.preview.components.audio.whisper_remote.openai.Audio") as openai_audio_patch:
        openai_audio_patch.transcribe.side_effect = mock_openai_response

        transcriber = RemoteWhisperTranscriber(api_key="test_api_key", model_name=model, response_format="json")
        output = transcriber.run(sources=[audio_path])

    docs = output["documents"]
    assert docs[0].content == "test transcription"
    assert docs[0].meta["file_path"] == audio_path

@pytest.mark.unit
def test_run_bytestream(self, preview_samples_path):
with patch("haystack.preview.components.audio.whisper_remote.openai.Audio") as openai_audio_patch:
model = "whisper-1"
file_path = preview_samples_path / "audio" / "this is the content of the document.wav"
Expand All @@ -193,7 +220,7 @@ def test_run(self, preview_samples_path):
byte_stream = audio_stream.read()
audio_file = ByteStream(byte_stream, metadata={"file_path": str(file_path.absolute())})

result = transcriber.run(streams=[audio_file])
result = transcriber.run(sources=[audio_file])

assert result["documents"][0].content == "test transcription"
assert result["documents"][0].meta["file_path"] == str(file_path.absolute())
Expand All @@ -208,32 +235,20 @@ def test_whisper_remote_transcriber(self, preview_samples_path):

paths = [
preview_samples_path / "audio" / "this is the content of the document.wav",
preview_samples_path / "audio" / "the context for this answer is here.wav",
preview_samples_path / "audio" / "answer.wav",
str(preview_samples_path / "audio" / "the context for this answer is here.wav"),
ByteStream.from_file_path(preview_samples_path / "audio" / "answer.wav"),
]

audio_files = []
for file_path in paths:
with open(file_path, "rb") as audio_stream:
byte_stream = audio_stream.read()
audio_file = ByteStream(byte_stream, metadata={"file_path": str(file_path.absolute())})
audio_files.append(audio_file)

output = transcriber.run(streams=audio_files)
output = transcriber.run(sources=paths)

docs = output["documents"]
assert len(docs) == 3
assert docs[0].content.strip().lower() == "this is the content of the document."
assert (
str((preview_samples_path / "audio" / "this is the content of the document.wav").absolute())
== docs[0].meta["file_path"]
)
assert preview_samples_path / "audio" / "this is the content of the document.wav" == docs[0].meta["file_path"]

assert docs[1].content.strip().lower() == "the context for this answer is here."
assert (
str((preview_samples_path / "audio" / "the context for this answer is here.wav").absolute())
== docs[1].meta["file_path"]
str(preview_samples_path / "audio" / "the context for this answer is here.wav") == docs[1].meta["file_path"]
)

assert docs[2].content.strip().lower() == "answer."
assert str((preview_samples_path / "audio" / "answer.wav").absolute()) == docs[2].meta["file_path"]

0 comments on commit b751978

Please sign in to comment.