Skip to content

Commit

Permalink
BUG: Fix model path changed since speechbrain v1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
ghisvail committed May 21, 2024
1 parent e8a27f9 commit 0c3fcfe
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
4 changes: 3 additions & 1 deletion medkit/audio/transcription/sb_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def __init__(
self.batch_size = batch_size
self._torch_device = "cpu" if self.device < 0 else f"cuda:{self.device}"

asr_class = speechbrain.pretrained.EncoderDecoderASR if needs_decoder else speechbrain.pretrained.EncoderASR
asr_class = (
speechbrain.inference.ASR.EncoderDecoderASR if needs_decoder else speechbrain.inference.ASR.EncoderASR
)

self._asr = asr_class.from_hparams(source=model, savedir=cache_dir, run_opts={"device": self._torch_device})

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ metrics-text-classification = [
"scikit-learn>=1.3.2",
]
metrics-transcription = [
"speechbrain>=0.5",
"speechbrain>=1.0",
]
nlstruct = [
"huggingface-hub",
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/audio/transcription/test_sb_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def __init__(self):

@pytest.fixture(scope="module", autouse=True)
def _mocked_asr(module_mocker):
module_mocker.patch("speechbrain.pretrained.EncoderASR", _MockSpeechbrainASR)
module_mocker.patch("speechbrain.pretrained.EncoderDecoderASR", _MockSpeechbrainASR)
module_mocker.patch("speechbrain.inference.ASR.EncoderASR", _MockSpeechbrainASR)
module_mocker.patch("speechbrain.inference.ASR.EncoderDecoderASR", _MockSpeechbrainASR)


def _gen_segment(nb_samples) -> Segment:
Expand Down

0 comments on commit 0c3fcfe

Please sign in to comment.