Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/unit-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ jobs:
python-version: 3.9.12
- name: Install Python dependencies
run: pip install -e .[test,dev,torch,st]
- uses: FedericoCarboni/setup-ffmpeg@v2
id: setup-ffmpeg
- name: Run Unit test_const
run: python -m pytest -s -v ./tests/unit/test_const.py
- name: Run Unit test_handler
Expand Down
3 changes: 3 additions & 0 deletions dockerfiles/starlette/pytorch/Dockerfile.cpu
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ RUN pip install --no-cache-dir -r /tmp/requirements.txt && rm /tmp/requirements
# Think about a better solution -> base container has pt 1.13. that's why we need below 0.14
RUN pip install --no-cache-dir sentence_transformers torchvision~="0.14.0" diffusers=="0.9.0" accelerate=="0.14.0"

# Add upgrade due to issue in base container upgrade https://github.com/mamba-org/mamba/issues/2170
RUN pip install transformers==4.25.1 --no-cache-dir --upgrade

# copy application
COPY src/huggingface_inference_toolkit huggingface_inference_toolkit
COPY src/huggingface_inference_toolkit/webservice_starlette.py webservice_starlette.py
Expand Down
3 changes: 3 additions & 0 deletions dockerfiles/starlette/pytorch/Dockerfile.gpu
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ RUN pip install --no-cache-dir -r /tmp/requirements.txt && rm /tmp/requirements
# Think about a better solution -> base container has pt 1.13. that's why we need below 0.14
RUN pip install --no-cache-dir sentence_transformers torchvision~="0.14.0" diffusers=="0.9.0" accelerate=="0.14.0"

# Add upgrade due to issue in base container upgrade https://github.com/mamba-org/mamba/issues/2170
RUN pip install transformers==4.25.1 --no-cache-dir --upgrade

# copy application
COPY src/huggingface_inference_toolkit huggingface_inference_toolkit
COPY src/huggingface_inference_toolkit/webservice_starlette.py webservice_starlette.py
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

install_requires = [
# transformers
"transformers[sklearn,sentencepiece]>=4.20.1",
"transformers[sklearn,sentencepiece]>=4.25.1",
# api stuff
"orjson",
# "robyn",
Expand Down
15 changes: 13 additions & 2 deletions src/huggingface_inference_toolkit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from huggingface_hub import HfApi, login
from huggingface_hub.file_download import cached_download, hf_hub_url
from huggingface_hub.utils import filter_repo_objects
from transformers import pipeline
from transformers import WhisperForConditionalGeneration, pipeline
from transformers.file_utils import is_tf_available, is_torch_available
from transformers.pipelines import Conversation, Pipeline

Expand Down Expand Up @@ -282,5 +282,16 @@ def get_pipeline(task: str, model_dir: Path, **kwargs) -> Pipeline:
# wrapp specific pipeline to support better ux
if task == "conversational":
hf_pipeline = wrap_conversation_pipeline(hf_pipeline)

elif task == "automatic-speech-recognition" and isinstance(hf_pipeline.model, WhisperForConditionalGeneration):
# set chunk length to 30s for whisper to enable long audio files
hf_pipeline._preprocess_params["chunk_length_s"] = 30
hf_pipeline._preprocess_params["ignore_warning"] = True
# set decoder to english by default
# TODO: replace when transformers 4.26.0 is release with
# hf_pipeline.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=lang, task="transcribe")
hf_pipeline.tokenizer.language = "english"
hf_pipeline.tokenizer.task = "transcribe"
hf_pipeline.model.config.forced_decoder_ids = [
(rank + 1, token) for rank, token in enumerate(hf_pipeline.tokenizer.prefix_tokens[1:])
]
return hf_pipeline
Binary file added tests/resources/audio/long_sample.mp3
Binary file not shown.
11 changes: 11 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,17 @@ def test_get_pipeline():
assert "score" in res[0]


@require_torch
def test_whisper_long_audio():
with tempfile.TemporaryDirectory() as tmpdirname:

storage_dir = _load_repository_from_hf("openai/whisper-tiny", tmpdirname, framework="pytorch")
pipe = get_pipeline("automatic-speech-recognition", storage_dir.as_posix())
res = pipe(os.path.join(os.getcwd(), "tests/resources/audio", "long_sample.mp3"))

assert len(res["text"]) > 700


@require_torch
def test_wrap_conversation_pipeline():
init_pipeline = pipeline(
Expand Down